diff options
Diffstat (limited to 'mlir')
645 files changed, 38980 insertions, 6576 deletions
diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt index 1a211f5..9e1e931 100644 --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -124,10 +124,13 @@ set_target_properties(mlir-doc PROPERTIES FOLDER "MLIR/Docs") # Only enable execution engine if the native target is available. if(${LLVM_NATIVE_ARCH} IN_LIST LLVM_TARGETS_TO_BUILD) - set(MLIR_ENABLE_EXECUTION_ENGINE 1) + set(MLIR_ENABLE_EXECUTION_ENGINE_default 1) else() - set(MLIR_ENABLE_EXECUTION_ENGINE 0) + set(MLIR_ENABLE_EXECUTION_ENGINE_default 0) endif() +option(MLIR_ENABLE_EXECUTION_ENGINE + "Enable building the MLIR Execution Engine." + ${MLIR_ENABLE_EXECUTION_ENGINE_default}) # Build the ROCm conversions and run according tests if the AMDGPU backend # is available. @@ -210,6 +213,19 @@ set(MLIR_DISABLE_CONFIGURE_PYTHON_DEV_PACKAGES 0 CACHE BOOL 'Development.Module' and ensure that find_package(pybind11) is \ satisfied (and keep up to date as requirements evolve).") +set(_mlir_python_stubgen_enabled ON) +# Stubgen doesn't work when cross-compiling (stubgen will run in the host interpreter and then fail +# to find the extension module for the host arch). +# Note: Stubgen requires some extra handling to work properly when sanitizers are enabled, +# so we skip running it in that case now. +if(CMAKE_CROSSCOMPILING OR (NOT LLVM_USE_SANITIZER STREQUAL "")) + set(_mlir_python_stubgen_enabled OFF) +endif() + +option(MLIR_PYTHON_STUBGEN_ENABLED + "Generate Python type stubs for the MLIR Python bindings." + ${_mlir_python_stubgen_enabled}) + if(MLIR_ENABLE_BINDINGS_PYTHON) include(MLIRDetectPythonEnv) # Note that both upstream and downstreams often call this macro. It gates diff --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md index b892bbe..37604fc 100644 --- a/mlir/docs/Dialects/Linalg/OpDSL.md +++ b/mlir/docs/Dialects/Linalg/OpDSL.md @@ -311,16 +311,17 @@ An example for a rank polymorphic operation is `fill`: ```python @linalg_structured_op -def fill(value=ScalarDef(T1), - O=TensorDef(U, output=True)): - O[None] = TypeFn.cast_signed(U, value) +def fill(value=ScalarDef(T), + O=TensorDef(T, output=True)): + O[None] = value ``` -The operation sets the elements of the output tensor `O` to `value`. All -operands are either scalars or rank zero tensors that are accessed using the -index `None`. The operation thus performs a scalar computation that trivially -extends to a multi-dimensional pointwise computation. As a result, we may use -`fill` with arbitrary ranked output tensors: +The operation sets the elements of the output tensor `O` to `value`. The value +type must match the element type of the output tensor. All operands are either +scalars or rank zero tensors that are accessed using the index `None`. The +operation thus performs a scalar computation that trivially extends to a +multi-dimensional pointwise computation. As a result, we may use `fill` with +arbitrary ranked output tensors: ```python tensor_2d = tensor.EmptyOp([4, 8], f32) diff --git a/mlir/docs/Dialects/NVVMDialect.md b/mlir/docs/Dialects/NVVMDialect.md new file mode 100644 index 0000000..b2f5e888 --- /dev/null +++ b/mlir/docs/Dialects/NVVMDialect.md @@ -0,0 +1,117 @@ +# 'nvvm' Dialect + +The NVVM dialect is MLIR's LLVM-IR-based, NVIDIA-specific backend dialect. It +models NVVM intrinsics and public ISA functionality and introduces NVIDIA +extensions to the MLIR/LLVM type system and address spaces (e.g., global, +shared, and cluster memory), enabling faithful lowering of GPU kernels to the +NVPTX toolchain. While a NVVM op usually maps to a single LLVM IR intrinsic, +the NVVM dialect uses type polymorphism and other attributes so that a single +NVVM op can map to different LLVM intrinsics. + +[TOC] + +## Scope and Capabilities + +The dialect covers core GPU features such as thread/block builtins, barriers +and atomics, warp-level collectives (e.g., shuffle/vote), matrix/tensor core +operations (e.g., `mma.sync`, `wgmma`), tensor memory accelerator (TMA) +operations, asynchronous copies (`cp.async`, bulk/tensor variants) with memory +barriers, cache and prefetch controls, and NVVM-specific attributes and enums +(e.g., FP rounding modes, memory scopes, and MMA types/layouts). + +## Placement in the Lowering Pipeline + +NVVM sits below target-agnostic dialects like `gpu` and NVIDIA's `nvgpu`. +Typical pipelines convert `gpu`/`nvgpu` ops into NVVM using +`-convert-gpu-to-nvvm` and `-convert-nvgpu-to-nvvm`, then translate into LLVM +for final code generation via NVPTX backend. + +## Target Configuration and Serialization + +NVVM provides a `#nvvm.target` attribute to describe the GPU target (SM, +features, and flags). In conjunction with `gpu` serialization (e.g., +`gpu-module-to-binary`), this enables producing architecture-specific GPU +binaries (such as CUBIN) from nested GPU modules. + +## Inline PTX + +When an intrinsic is unavailable or a performance-critical sequence must be +expressed directly, NVVM provides an `nvvm.inline_ptx` op to embed PTX inline +as a last-resort escape hatch, with explicit operands and results. + +## Memory Spaces + +The NVVM dialect introduces the following memory spaces, each with distinct +scopes and lifetimes: + +| Memory Space | Address Space | Scope | +|-------------------|---------------|----------------------| +| `generic` | 0 | All threads | +| `global` | 1 | All threads (device) | +| `shared` | 3 | Thread block (CTA) | +| `constant` | 4 | All threads | +| `local` | 5 | Single thread | +| `tensor` | 6 | Thread block (CTA) | +| `shared_cluster` | 7 | Thread block cluster | + +### Memory Space Details + +- **generic**: Can point to any memory space; requires runtime resolution of + actual address space. Use when pointer origin is unknown at compile time. + Performance varies based on the underlying memory space. A pointer to this + memory space is represented by `LLVM_PointerGeneric` in the NVVM Ops. +- **global**: Accessible by all threads across all blocks; persists across + kernel launches. Highest latency but largest capacity (device memory). Best + for large data and inter-kernel communication. A pointer to this memory space + is represented by `LLVM_PointerGlobal` in the NVVM Ops. +- **shared**: Shared within a thread block (CTA); very fast on-chip memory for + cooperation between threads in the same block. Limited capacity. Ideal for + block-level collaboration, caching, and reducing global memory traffic. + This memory is usually referred as `shared_cta` in the NVVMOps and as + `shared::cta` in the PTX ISA. A pointer to this memory space is represented + by the `LLVM_PointerShared` type in the NVVM Ops. +- **constant**: Read-only memory cached per SM. Size typically limited to 64KB. + Best for read-only data and uniform values accessed by all threads. A pointer + to this memory space is represented by `LLVM_PointerConst` type in NVVM Ops. +- **local**: Private to each thread. Use for per-thread private data and + automatic variables that don't fit in registers. A pointer to this memory is + represented by `LLVM_PointerLocal` type in NVVM Ops. +- **tensor**: Special memory space for tensor core operations. Used by + `tcgen05` instructions on SM 100+ for tensor input/output operations. + A pointer to this memory space is represented by the `LLVM_PointerTensor` + type in the NVVM Ops. +- **shared_cluster**: Distributed shared memory across thread blocks within a + cluster (SM 90+). Enables collaboration beyond single-block scope with fast + access across cluster threads. This memory is usually referred as + `shared_cluster` in the NVVMOps and as `shared::cluster` in the PTX ISA. + A pointer to this memory space is represented by the `LLVM_PointerSharedCluster` + type in the NVVM Ops. + +## MBarrier objects + +An ``mbarrier`` is a barrier created in shared memory that supports +synchronizing any subset of threads within a CTA. An *mbarrier object* +is an opaque object in shared memory with `.b64` type and an alignment of +8-bytes. Unlike ``nvvm.barrier`` Op which can access only a limited number +of barriers per CTA, the *mbarrier objects* are user-defined and are only +limited by the total shared memory size available. The list of operations +supported on an *mbarrier object* is exposed through the ``nvvm.mbarrier.*`` +family of NVVM Ops. + +## Non-Goals + +NVVM is not a place for convenience or "wrapper" ops. It is not intended to +introduce high-level ops that expand into multiple unrelated NVVM intrinsics or +that lower to no intrinsic at all. Such abstractions belong in higher-level +dialects (e.g., `nvgpu`, `gpu`, or project-specific dialects). The design +intent is a thin, predictable, low-level surface with near-mechanical lowering +to NVVM/LLVM IR. + + +## Operations + +All operations in the NVIDIA's instruction set have a custom form in MLIR. The mnemonic +of an operation is that used in LLVM IR prefixed with "`nvvm.`". + +[include "Dialects/NVVMOps.md"] + diff --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md index 716dd77..dd68e6e 100644 --- a/mlir/docs/Dialects/SPIR-V.md +++ b/mlir/docs/Dialects/SPIR-V.md @@ -566,7 +566,7 @@ merge block. For example, for the given function ```c++ -void loop(bool cond) { +void if(bool cond) { int x = 0; if (cond) { x = 1; @@ -605,6 +605,62 @@ func.func @selection(%cond: i1) -> () { } ``` +Similarly, for the give function with a `switch` statement + +```c++ +void switch(int selector) { + int x = 0; + switch (selector) { + case 0: + x = 2; + break; + case 1: + x = 3; + break; + default: + x = 1; + break; + } + // ... +} +``` + +It will be represented as + +```mlir +func.func @selection(%selector: i32) -> () { + %zero = spirv.Constant 0: i32 + %one = spirv.Constant 1: i32 + %two = spirv.Constant 2: i32 + %three = spirv.Constant 3: i32 + %var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function> + + spirv.mlir.selection { + spirv.Switch %selector : i32, [ + default: ^default, + 0: ^case0, + 1: ^case1 + ] + ^default: + spirv.Store "Function" %var, %one : i32 + spirv.Branch ^merge + + ^case0: + spirv.Store "Function" %var, %two : i32 + spirv.Branch ^merge + + ^case1: + spirv.Store "Function" %var, %three : i32 + spirv.Branch ^merge + + ^merge: + spirv.mlir.merge + } + + // ... +} +``` + The selection can return values by yielding them with `spirv.mlir.merge`. This mechanism allows values defined within the selection region to be used outside of it. Without this, values that were sunk into the selection region, but used outside, would diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md index 7e1c5fe..d3e1888 100644 --- a/mlir/docs/Interfaces.md +++ b/mlir/docs/Interfaces.md @@ -85,6 +85,72 @@ if (DialectInlinerInterface *interface = dyn_cast<DialectInlinerInterface>(diale } ``` +#### Utilizing the ODS framework + +Note: Before reading this section, the reader should have some familiarity with +the concepts described in the +[`Operation Definition Specification`](DefiningDialects/Operations.md) documentation. + +MLIR also supports defining dialect interfaces directly in **TableGen**. +This reduces boilerplate and allows authors to specify high-level interface +structure declaratively. + +For example, the above interface can be defined using ODS as follows: + +```tablegen +def DialectInlinerInterface : DialectInterface<"DialectInlinerInterface"> { + let description = [{ + Define a base inlining interface class to allow for dialects to opt-in to + the inliner. + }]; + + let methods = [ + InterfaceMethod<[{ + Returns true if the given region 'src' can be inlined into the region + 'dest' that is attached to an operation registered to the current dialect. + 'valueMapping' contains any remapped values from within the 'src' region. + This can be used to examine what values will replace entry arguments into + the 'src' region, for example. + }], + "bool", "isLegalToInline", + (ins "Region *":$dest, "Region *":$src, "IRMapping &":$valueMapping), + [{ + return false; + }] + > + ]; +} +``` + +`DialectInterfaces` class make use of the following components: + +* C++ Class Name (Provided via template parameter) + - The name of the C++ interface class. +* Description (`description`) + - A string description of the interface, its invariants, example usages, + etc. +* C++ Namespace (`cppNamespace`) + - The C++ namespace that the interface class should be generated in. +* Methods (`methods`) + - The list of interface hook methods that are defined by the IR object. + - The structure of these methods is defined [here](#interface-methods). + +The header file can be generated via the following command: + +```bash +mlir-tblgen --gen-dialect-interface-decls DialectInterface.td +``` + +To generate dialect interface declarations using the ODS framework in CMake, you would write: + +```cmake +set(LLVM_TARGET_DEFINITIONS DialectInlinerInterface.td) +mlir_tablegen(DialectInlinerInterface.h.inc -gen-dialect-interface-decls) +``` + +An example of this can be found in the DialectInlinerInterface implementation +and the related `CMakeLists.txt` under `mlir/include/mlir/Transforms`. + #### DialectInterfaceCollection An additional utility is provided via `DialectInterfaceCollection`. This class @@ -364,10 +430,6 @@ void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, #### Utilizing the ODS Framework -Note: Before reading this section, the reader should have some familiarity with -the concepts described in the -[`Operation Definition Specification`](DefiningDialects/Operations.md) documentation. - As detailed above, [Interfaces](#attributeoperationtype-interfaces) allow for attributes, operations, and types to expose method calls without requiring that the caller know the specific derived type. The downside to this infrastructure, diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md index 10cfba9..b1da4b9 100644 --- a/mlir/docs/LangRef.md +++ b/mlir/docs/LangRef.md @@ -424,7 +424,7 @@ func.func @simple(i64, i1) -> i64 { **Context:** The "block argument" representation eliminates a number of special cases from the IR compared to traditional "PHI nodes are operations" SSA IRs (like LLVM). For example, the -[parallel copy semantics](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.524.5461&rep=rep1&type=pdf) +[parallel copy semantics](https://ieeexplore.ieee.org/document/4907656) of SSA is immediately apparent, and function arguments are no longer a special case: they become arguments to the entry block [[more rationale](Rationale/Rationale.md/#block-arguments-vs-phi-nodes)]. Blocks diff --git a/mlir/examples/standalone/python/CMakeLists.txt b/mlir/examples/standalone/python/CMakeLists.txt index df19fa8..8469bff 100644 --- a/mlir/examples/standalone/python/CMakeLists.txt +++ b/mlir/examples/standalone/python/CMakeLists.txt @@ -74,12 +74,7 @@ add_mlir_python_common_capi_library(StandalonePythonCAPI set(StandalonePythonModules_ROOT_PREFIX "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}") -set(_mlir_python_stubgen_enabled ON) -if(CMAKE_CROSSCOMPILING OR (NOT LLVM_USE_SANITIZER STREQUAL "")) - set(_mlir_python_stubgen_enabled OFF) -endif() - -if(_mlir_python_stubgen_enabled) +if(MLIR_PYTHON_STUBGEN_ENABLED) # Everything here is very tightly coupled. See the ample descriptions at the bottom of # mlir/python/CMakeLists.txt. @@ -146,7 +141,7 @@ set(_declared_sources ) # For an external projects build, the MLIRPythonExtension.Core.type_stub_gen # target already exists and can just be added to DECLARED_SOURCES. -if(EXTERNAL_PROJECT_BUILD AND _mlir_python_stubgen_enabled) +if(EXTERNAL_PROJECT_BUILD AND MLIR_PYTHON_STUBGEN_ENABLED) list(APPEND _declared_sources MLIRPythonExtension.Core.type_stub_gen) endif() @@ -158,7 +153,7 @@ add_mlir_python_modules(StandalonePythonModules StandalonePythonCAPI ) -if(_mlir_python_stubgen_enabled) +if(MLIR_PYTHON_STUBGEN_ENABLED) if(NOT EXTERNAL_PROJECT_BUILD) add_dependencies(StandalonePythonModules "${_mlir_typestub_gen_target}") endif() diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index c1ade9e..cc7f09f 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -23,6 +23,8 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(LLVM, llvm); MLIR_CAPI_EXPORTED MlirType mlirLLVMPointerTypeGet(MlirContext ctx, unsigned addressSpace); +MLIR_CAPI_EXPORTED MlirTypeID mlirLLVMPointerTypeGetTypeID(void); + /// Returns `true` if the type is an LLVM dialect pointer type. MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMPointerType(MlirType type); @@ -58,6 +60,8 @@ MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGetReturnType(MlirType type); /// Returns `true` if the type is an LLVM dialect struct type. MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMStructType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirLLVMStructTypeGetTypeID(void); + /// Returns `true` if the type is a literal (unnamed) LLVM struct type. MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsLiteral(MlirType type); diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h index 339e63d..003b0cde 100644 --- a/mlir/include/mlir-c/Dialect/Linalg.h +++ b/mlir/include/mlir-c/Dialect/Linalg.h @@ -10,6 +10,7 @@ #ifndef MLIR_C_DIALECT_LINALG_H #define MLIR_C_DIALECT_LINALG_H +#include "mlir-c/AffineMap.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" @@ -34,6 +35,10 @@ typedef struct MlirLinalgContractionDimensions { MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions mlirLinalgInferContractionDimensions(MlirOperation op); +MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions +mlirLinalgInferContractionDimensionsFromMaps(const MlirAffineMap *indexingMaps, + size_t numMaps); + MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op); typedef struct MlirLinalgConvolutionDimensions { diff --git a/mlir/include/mlir-c/ExecutionEngine.h b/mlir/include/mlir-c/ExecutionEngine.h index 1a58d68..2a81798 100644 --- a/mlir/include/mlir-c/ExecutionEngine.h +++ b/mlir/include/mlir-c/ExecutionEngine.h @@ -41,10 +41,13 @@ DEFINE_C_API_STRUCT(MlirExecutionEngine, void); /// generation. The number and array of paths corresponding to shared libraries /// that will be loaded are specified via `numPaths` and `sharedLibPaths` /// respectively. +/// The `enablePIC` arguments controls the relocation model, when true the +/// generated code is emitted as "position independent", making it possible to +/// save it and reload it as a shared object in another process. /// TODO: figure out other options. MLIR_CAPI_EXPORTED MlirExecutionEngine mlirExecutionEngineCreate( MlirModule op, int optLevel, int numPaths, - const MlirStringRef *sharedLibPaths, bool enableObjectDump); + const MlirStringRef *sharedLibPaths, bool enableObjectDump, bool enablePIC); /// Initialize the ExecutionEngine. Global constructors specified by /// `llvm.mlir.global_ctors` will be run. One common scenario is that kernel diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index c464e4d..d2f4762 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -1051,6 +1051,10 @@ MLIR_CAPI_EXPORTED intptr_t mlirBlockArgumentGetArgNumber(MlirValue value); MLIR_CAPI_EXPORTED void mlirBlockArgumentSetType(MlirValue value, MlirType type); +/// Sets the location of the block argument to the given location. +MLIR_CAPI_EXPORTED void mlirBlockArgumentSetLocation(MlirValue value, + MlirLocation loc); + /// Returns an operation that produced this value as its result. Asserts if the /// value is not an op result. MLIR_CAPI_EXPORTED MlirOperation mlirOpResultGetOwner(MlirValue value); diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h index 72ec606..4975ced 100644 --- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h @@ -48,7 +48,7 @@ class IntegerRangeAnalysis public: using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis; - /// At an entry point, we cannot reason about interger value ranges. + /// At an entry point, we cannot reason about integer value ranges. void setToEntryState(IntegerValueRangeLattice *lattice) override { propagateIfChanged(lattice, lattice->join(IntegerValueRange::getMaxRange( lattice->getAnchor()))); diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h index f865357..2ae4bef 100644 --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -196,6 +196,7 @@ public: inline DynamicAPInt atIneq(unsigned i, unsigned j) const { return inequalities(i, j); } + /// The same, but casts to int64_t. This is unsafe and will assert-fail if the /// value does not fit in an int64_t. inline int64_t atIneq64(unsigned i, unsigned j) const { @@ -209,6 +210,19 @@ public: return getNumInequalities() + getNumEqualities(); } + /// Unified indexing into the constraints. Index into the inequalities + /// if i < getNumInequalities() and into the equalities otherwise. + inline int64_t atConstraint64(unsigned i, unsigned j) const { + assert(i < getNumConstraints()); + unsigned numIneqs = getNumInequalities(); + return i < numIneqs ? atIneq64(i, j) : atEq64(i - numIneqs, j); + } + inline DynamicAPInt &atConstraint(unsigned i, unsigned j) { + assert(i < getNumConstraints()); + unsigned numIneqs = getNumInequalities(); + return i < numIneqs ? atIneq(i, j) : atEq(i - numIneqs, j); + } + unsigned getNumDomainVars() const { return space.getNumDomainVars(); } unsigned getNumRangeVars() const { return space.getNumRangeVars(); } unsigned getNumSymbolVars() const { return space.getNumSymbolVars(); } @@ -351,6 +365,7 @@ public: void removeEquality(unsigned pos); void removeInequality(unsigned pos); + void removeConstraint(unsigned pos); /// Remove the (in)equalities at positions [start, end). void removeEqualityRange(unsigned start, unsigned end); @@ -511,6 +526,34 @@ public: void projectOut(unsigned pos, unsigned num); inline void projectOut(unsigned pos) { return projectOut(pos, 1); } + /// The function removes some constraints that do not impose any bound on the + /// specified variable. + /// + /// The set of constraints (equations/inequalities) can be modeled as an + /// undirected graph where: + /// 1. Variables are the nodes. + /// 2. Constraints are the edges connecting those nodes. + /// + /// Variables and constraints belonging to different connected components + /// are irrelevant to each other. This property allows for safe pruning of + /// constraints. + /// + /// For example, given the following constraints: + /// - Inequalities: (1) d0 + d1 > 0, (2) d1 >= 2, (3) d4 > 5 + /// - Equalities: (4) d3 + d4 = 1, (5) d0 - d2 = 3 + /// + /// These form two connected components: + /// - Component 1: {d0, d1, d2} (related by constraints 1, 2, 5) + /// - Component 2: {d3, d4} (related by constraint 4) + /// + /// If we are querying the bound of variable `d0`, constraints related to + /// Component 2 (e.g., constraints 3 and 4) can be safely pruned as they + /// have no impact on the solution space of Component 1. + /// This function prunes irrelevant constraints by identifying all variables + /// and constraints that belong to the same connected component as the + /// target variable. + void pruneOrthogonalConstraints(unsigned pos); + /// Tries to fold the specified variable to a constant using a trivial /// equality detection; if successful, the constant is substituted for the /// variable everywhere in the constraint system and then removed from the diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h index 97573b6..c7808e7 100644 --- a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h +++ b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h @@ -128,7 +128,7 @@ private: /// variable and q is a local variable. Let us put the constraints: /// `1 <= x <= 7, x = 2q` /// on this space to get the set: -/// `(x) : (exists q : q <= x <= 7, x = 2q)`. +/// `(x) : (exists q : 1 <= x <= 7, x = 2q)`. /// An assignment to symbolic and dimension variables is valid if there /// exists some assignment to the local variable `q` satisfying these /// constraints. For this example, the set is equivalent to {2, 4, 6}. @@ -136,7 +136,7 @@ private: /// of projection. In this example, `q` is existentially quantified. This can be /// thought of as the result of projecting out `q` from the previous example, /// i.e. we obtained {2, 4, 6} by projecting out the second dimension from -/// {(2, 1), (4, 2), (6, 2)}. +/// {(2, 1), (4, 2), (6, 3)}. /// /// Dimension variables are further divided into Domain and Range variables /// to support building relations. diff --git a/mlir/include/mlir/Bindings/Python/Nanobind.h b/mlir/include/mlir/Bindings/Python/Nanobind.h index ca942c8..8dc8a0d 100644 --- a/mlir/include/mlir/Bindings/Python/Nanobind.h +++ b/mlir/include/mlir/Bindings/Python/Nanobind.h @@ -30,6 +30,7 @@ #include <nanobind/stl/string_view.h> #include <nanobind/stl/tuple.h> #include <nanobind/stl/vector.h> +#include <nanobind/typing.h> #if defined(__clang__) || defined(__GNUC__) #pragma GCC diagnostic pop #endif diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h index 7ffc861..7020e24 100644 --- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h +++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h @@ -65,11 +65,8 @@ public: convertArithFastMathAttrToLLVM(arithFMFAttr)); } } - ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); } - LLVM::IntegerOverflowFlags getOverflowFlags() const { - return LLVM::IntegerOverflowFlags::none; - } + Attribute getPropAttr() const { return {}; } private: NamedAttrList convertedAttr; @@ -82,23 +79,36 @@ template <typename SourceOp, typename TargetOp> class AttrConvertOverflowToLLVM { public: AttrConvertOverflowToLLVM(SourceOp srcOp) { + using IntegerOverflowFlagsAttr = LLVM::IntegerOverflowFlagsAttr; + // Copy the source attributes. convertedAttr = NamedAttrList{srcOp->getAttrs()}; // Get the name of the arith overflow attribute. StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName(); - // Remove the source overflow attribute. + // Remove the source overflow attribute from the set that will be present + // in the target. if (auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>( convertedAttr.erase(arithAttrName))) { - overflowFlags = convertArithOverflowFlagsToLLVM(arithAttr.getValue()); + auto llvmFlag = convertArithOverflowFlagsToLLVM(arithAttr.getValue()); + // Create a dictionary attribute holding the overflow flags property. + // (In the LLVM dialect, the overflow flags are a property, not an + // attribute.) + MLIRContext *ctx = srcOp.getOperation()->getContext(); + Builder b(ctx); + auto llvmFlagAttr = IntegerOverflowFlagsAttr::get(ctx, llvmFlag); + StringRef llvmAttrName = TargetOp::getOverflowFlagsAttrName(); + NamedAttribute attr{llvmAttrName, llvmFlagAttr}; + // Set the properties attribute of the operation state so that the + // property can be updated when the operation is created. + propertiesAttr = b.getDictionaryAttr(ArrayRef(attr)); } } - ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); } - LLVM::IntegerOverflowFlags getOverflowFlags() const { return overflowFlags; } + Attribute getPropAttr() const { return propertiesAttr; } private: NamedAttrList convertedAttr; - LLVM::IntegerOverflowFlags overflowFlags = LLVM::IntegerOverflowFlags::none; + DictionaryAttr propertiesAttr; }; template <typename SourceOp, typename TargetOp> @@ -129,9 +139,7 @@ public: } ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); } - LLVM::IntegerOverflowFlags getOverflowFlags() const { - return LLVM::IntegerOverflowFlags::none; - } + Attribute getPropAttr() const { return {}; } private: NamedAttrList convertedAttr; diff --git a/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h b/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h new file mode 100644 index 0000000..64a42a2 --- /dev/null +++ b/mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h @@ -0,0 +1,21 @@ +//===- ArithToAPFloat.h - Arith to APFloat impl conversion ---*- C++ ----*-===// +// +// Part of the APFloat Project, under the Apache License v2.0 with APFloat +// Exceptions. See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH APFloat-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H +#define MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H + +#include <memory> + +namespace mlir { +class Pass; + +#define GEN_PASS_DECL_ARITHTOAPFLOATCONVERSIONPASS +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +#endif // MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h index 4c8abea..48982ac 100644 --- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h +++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h @@ -27,7 +27,7 @@ class MMAMatrixType; #define GEN_PASS_DECL_CONVERTGPUOPSTONVVMOPS #include "mlir/Conversion/Passes.h.inc" -LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type); +Type convertMMAToLLVMType(gpu::MMAMatrixType type); /// Configure target to convert from the GPU dialect to NVVM. void configureGpuToNVVMConversionLegality(ConversionTarget &target); diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h index c292e37..f8e0ccc 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -19,16 +19,14 @@ class CallOpInterface; namespace LLVM { namespace detail { -/// Handle generically setting flags as native properties on LLVM operations. -void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags); - /// Replaces the given operation "op" with a new operation of type "targetOp" /// and given operands. -LogicalResult oneToOneRewrite( - Operation *op, StringRef targetOp, ValueRange operands, - ArrayRef<NamedAttribute> targetAttrs, - const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, - IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none); +LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, + ValueRange operands, + ArrayRef<NamedAttribute> targetAttrs, + Attribute propertiesAttr, + const LLVMTypeConverter &typeConverter, + ConversionPatternRewriter &rewriter); /// Replaces the given operation "op" with a call to an LLVM intrinsic with the /// specified name "intrinsic" and operands. @@ -307,9 +305,9 @@ public: LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(), - adaptor.getOperands(), op->getAttrs(), - *this->getTypeConverter(), rewriter); + return LLVM::detail::oneToOneRewrite( + op, TargetOp::getOperationName(), adaptor.getOperands(), op->getAttrs(), + /*propertiesAttr=*/Attribute{}, *this->getTypeConverter(), rewriter); } }; diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h index e7ab63a..32dd8ba 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h @@ -54,25 +54,32 @@ LogicalResult handleMultidimensionalVectors( std::function<Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter); -LogicalResult vectorOneToOneRewrite( - Operation *op, StringRef targetOp, ValueRange operands, - ArrayRef<NamedAttribute> targetAttrs, - const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, - IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none); +LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp, + ValueRange operands, + ArrayRef<NamedAttribute> targetAttrs, + Attribute propertiesAttr, + const LLVMTypeConverter &typeConverter, + ConversionPatternRewriter &rewriter); + +/// Return "true" if the given type is an unsupported floating point type. In +/// case of a vector type, return "true" if the element type is an unsupported +/// floating point type. +bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter, + Type type); } // namespace detail } // namespace LLVM // Default attribute conversion class, which passes all source attributes -// through to the target op, unmodified. +// through to the target op, unmodified. The attribute to set properties of the +// target operation will be nullptr (i.e. any properties that exist in will have +// default values). template <typename SourceOp, typename TargetOp> class AttrConvertPassThrough { public: AttrConvertPassThrough(SourceOp srcOp) : srcAttrs(srcOp->getAttrs()) {} ArrayRef<NamedAttribute> getAttrs() const { return srcAttrs; } - LLVM::IntegerOverflowFlags getOverflowFlags() const { - return LLVM::IntegerOverflowFlags::none; - } + Attribute getPropAttr() const { return {}; } private: ArrayRef<NamedAttribute> srcAttrs; @@ -80,10 +87,13 @@ private: /// Basic lowering implementation to rewrite Ops with just one result to the /// LLVM Dialect. This supports higher-dimensional vector types. -/// The AttrConvert template template parameter should be a template class -/// with SourceOp and TargetOp type parameters, a constructor that takes -/// a SourceOp instance, and a getAttrs() method that returns -/// ArrayRef<NamedAttribute>. +/// The AttrConvert template template parameter should: +// - be a template class with SourceOp and TargetOp type parameters +// - have a constructor that takes a SourceOp instance +// - a getAttrs() method that returns ArrayRef<NamedAttribute> containing +// attributes that the target operation will have +// - a getPropAttr() method that returns either a NULL attribute or a +// DictionaryAttribute with properties that exist for the target operation template <typename SourceOp, typename TargetOp, template <typename, typename> typename AttrConvert = AttrConvertPassThrough, @@ -93,16 +103,6 @@ public: using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern; using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>; - /// Return the given type if it's a floating point type. If the given type is - /// a vector type, return its element type if it's a floating point type. - static FloatType getFloatingPointType(Type type) { - if (auto floatType = dyn_cast<FloatType>(type)) - return floatType; - if (auto vecType = dyn_cast<VectorType>(type)) - return dyn_cast<FloatType>(vecType.getElementType()); - return nullptr; - } - LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -110,26 +110,18 @@ public: std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value, "expected single result op"); - // The pattern should not apply if a floating-point operand is converted to - // a non-floating-point type. This indicates that the floating point type - // is not supported by the LLVM lowering. (Such types are converted to - // integers.) - auto checkType = [&](Value v) -> LogicalResult { - FloatType floatType = getFloatingPointType(v.getType()); - if (!floatType) - return success(); - Type convertedType = this->getTypeConverter()->convertType(floatType); - if (!isa_and_nonnull<FloatType>(convertedType)) - return rewriter.notifyMatchFailure(op, - "unsupported floating point type"); - return success(); - }; + // Bail on unsupported floating point types. (These are type-converted to + // integer types.) if (FailOnUnsupportedFP) { for (Value operand : op->getOperands()) - if (failed(checkType(operand))) - return failure(); - if (failed(checkType(op->getResult(0)))) - return failure(); + if (LLVM::detail::isUnsupportedFloatingPointType( + *this->getTypeConverter(), operand.getType())) + return rewriter.notifyMatchFailure(op, + "unsupported floating point type"); + if (LLVM::detail::isUnsupportedFloatingPointType( + *this->getTypeConverter(), op->getResult(0).getType())) + return rewriter.notifyMatchFailure(op, + "unsupported floating point type"); } // Determine attributes for the target op @@ -137,8 +129,8 @@ public: return LLVM::detail::vectorOneToOneRewrite( op, TargetOp::getOperationName(), adaptor.getOperands(), - attrConvert.getAttrs(), *this->getTypeConverter(), rewriter, - attrConvert.getOverflowFlags()); + attrConvert.getAttrs(), attrConvert.getPropAttr(), + *this->getTypeConverter(), rewriter); } }; } // namespace mlir diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index 40d866e..82bdfd0 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -12,6 +12,7 @@ #include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h" +#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h" #include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h" #include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 70e3e45..fcbaf3cc 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -187,6 +187,22 @@ def ArithToLLVMConversionPass : Pass<"convert-arith-to-llvm"> { } //===----------------------------------------------------------------------===// +// ArithToAPFloat +//===----------------------------------------------------------------------===// + +def ArithToAPFloatConversionPass + : Pass<"convert-arith-to-apfloat", "ModuleOp"> { + let summary = "Convert Arith ops to APFloat runtime library calls"; + let description = [{ + This pass converts supported Arith ops to APFloat-based runtime library + calls (APFloatWrappers.cpp). APFloat is a software implementation of + floating-point arithmetic operations. + }]; + let dependentDialects = ["arith::ArithDialect", "func::FuncDialect", + "vector::VectorDialect"]; +} + +//===----------------------------------------------------------------------===// // ArithToSPIRV //===----------------------------------------------------------------------===// @@ -613,6 +629,8 @@ def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> { /*default=*/"false", "Replace memref arguments in GPU functions with bare pointers. " "All memrefs must have static shape.">, + Option<"allowPatternRollback", "allow-pattern-rollback", "bool", "true", + "Experimental performance flag to disallow pattern rollback">, ListOption<"allowedDialects", "allowed-dialects", "std::string", "Run conversion patterns of only the specified dialects">, ]; @@ -1069,6 +1087,10 @@ def SCFToControlFlowPass : Pass<"convert-scf-to-cf"> { let summary = "Convert SCF dialect to ControlFlow dialect, replacing structured" " control flow with a CFG"; let dependentDialects = ["cf::ControlFlowDialect"]; + let options = [ + Option<"allowPatternRollback", "allow-pattern-rollback", "bool", "true", + "Experimental performance flag to disallow pattern rollback"> + ]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td index 45cb67f..56160d3 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td @@ -33,6 +33,7 @@ def AMDGPU_Dialect : Dialect { "gpu::GPUDialect" ]; let useDefaultAttributePrinterParser = 1; + let useDefaultTypePrinterParser = 1; } def AnyIntegerOrFloat : AnyTypeOf<[AnySignlessInteger, AnyFloat], "Integer or Float">; @@ -80,6 +81,39 @@ def AMDGPU_AddressSpaceAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_AddressSpace, } //===----------------------------------------------------------------------===// +// AMDGPU Type definitions +//===----------------------------------------------------------------------===// + +class AMDGPU_Type<string name, string typeMnemonic, list<Trait> traits = []> + : TypeDef<AMDGPU_Dialect, name, traits> { + let mnemonic = typeMnemonic; +} + +def AMDGPU_TDMBaseType : AMDGPU_Type<"TDMBase", "tdm_base"> { + let summary = "Pair of base addresses that move data between LDS and global storage."; + let description = [{ + This type is opaque and it is used to represent a struct of two addresses. + One address is in LDS while the other is in global memory. + }]; + let parameters = (ins "Type":$elementType); + let builders = [ + TypeBuilderWithInferredContext<(ins "Type":$elementType), [{ + return $_get(elementType.getContext(), elementType); + }]> + ]; + let assemblyFormat = "`<` $elementType `>`"; +} + +def AMDGPU_TDMDescriptorType : AMDGPU_Type<"TDMDescriptor", "tdm_descriptor"> { + let summary = "Descriptors used in tensor store/load operations."; + let description = [{ + This type is opaque and corresponds to the two or four descriptor groups + used in tensor_load_to_lds or tensor_store_from_lds. + }]; + +} + +//===----------------------------------------------------------------------===// // AMDGPU Op definitions //===----------------------------------------------------------------------===// @@ -112,12 +146,8 @@ def AMDGPU_ExtPackedFp8Op : }]; } -def IsValidBlockSize: AttrConstraint< - CPred<"::llvm::is_contained({16, 32}, ::llvm::cast<::mlir::IntegerAttr>($_self).getInt())">, - "whose value is 16 or 32">; - -def AMDGPU_ScaledExtPacked816Op - : AMDGPU_Op<"scaled_ext_packed816", [Pure, AllShapesMatch<["source", "res"]>]>, +def AMDGPU_ScaledExtPackedMatrixOp + : AMDGPU_Op<"scaled_ext_packed_matrix", [Pure, AllShapesMatch<["source", "res"]>]>, Arguments<( ins AnyTypeOf<[FixedVectorOfShapeAndType<[8], F4E2M1FN>, FixedVectorOfShapeAndType<[8], F8E4M3FN>, @@ -125,9 +155,9 @@ def AMDGPU_ScaledExtPacked816Op FixedVectorOfShapeAndType<[16], F6E2M3FN>, FixedVectorOfShapeAndType<[16], F6E3M2FN>]>:$source, FixedVectorOfShapeAndType<[4], F8E8M0FNU>:$scale, - ConfinedAttr<I32Attr, [IsValidBlockSize]>:$blockSize, - ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<1>]>:$firstScaleLane, - ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<2>]>:$firstScaleByte)>, + ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$blockSize, + ConfinedAttr<I32Attr, [IntIsOneOf<[0, 16]>]>:$firstScaleLane, + ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<3>]>:$firstScaleByte)>, Results<( outs AnyTypeOf<[FixedVectorOfShapeAndType<[8], F32>, FixedVectorOfShapeAndType<[8], F16>, @@ -136,57 +166,75 @@ def AMDGPU_ScaledExtPacked816Op FixedVectorOfShapeAndType<[16], F16>, FixedVectorOfShapeAndType<[16], BF16>]>:$res)> { - let summary = "Extend a vector of packed floating point values"; + let summary = "Extend a wave-wide matrix of packed floating point values"; let description = [{ - The scales applied to the input microfloats are stored in two bytes which + Extend matrix of microfloats (8 or 16 elements per lane) using a set of scales + that may be stored on other lanes. + + The scales applied to the input microfloats are stored in bytes which come from the `scales` input provided in a *half* of the wave identified - by `firstScaleLane`. The pair of bytes used is selected by - `firstScaleByte`. The 16 vectors in consecutive lanes starting from + by `firstScaleLane`. The bytes used is selected by `firstScaleByte` and depends + on the type of `source`. The 16 vectors in consecutive lanes starting from `firstScaleLane` (which we'll call the scale vectors) will be used by both - halves of the wave (with lane L reading from L % 16'th scale vector), but - each half will use a different byte. + halves of the wave (with lane L reading from L % 16'th scale vector). + + When `source` is either F4E2M1FN, F6E2M3FN, or F6E3M2FN each half of the + wave will use a different byte. The first one being `firstScaleByte` and + the second one being `firstScaleByte` + 1. When the block size is 32, + `firstScaleByte` can be either 0 or 2, selecting halves of the scale vectors. + Lanes 0-15 will read from `firstScaleByte` and lanes 16-31 will read + from `firstScaleByte` + 1. + - When the block size is 32, `firstScaleByte` can be either 0 or 2, - selecting halves of the scale vectors. Lanes 0-15 will read from - `firstScaleByte` and lanes 16-31 will read from `firstScaleByte` + 1. For example: ```mlir // Input: 8-element vector of F8E4M3FN, converting to F32 // Lanes 0-15 read from byte 0, lanes 16-31 read from byte 1 - %result = amdgpu.scaled_ext_packed816 %source scale(%scales) + %result = amdgpu.scaled_ext_packed_matrix %source scale(%scales) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf32> // Input: 16-element vector of F6E2M3FN, converting to F16 // Lanes 0-15 read from byte 2, lanes 16-31 read from byte 3 - %result = amdgpu.scaled_ext_packed816 %source scale(%scales) - blockSize(32) firstScaleLane(1) firstScaleByte(2) + %result = amdgpu.scaled_ext_packed_matrix %source scale(%scales) + blockSize(32) firstScaleLane(16) firstScaleByte(2) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16> ``` - However, when the block size is 16, `firstScaleByte` can be 0 or 1. + When `source` is either F4E2M1FN, F6E2M3FN, or F6E3M2FN and + the block size is 16, `firstScaleByte` can be 0 or 1. Lanes 0-15 read from the `firstScaleByte`th element of the scale vectors, while lanes 16-31 read from `firstScaleByte` + 2. For example: ```mlir // Input: 8-element vector of F8E5M2, converting to BF16 // Lanes 0-15 read from byte 0, lanes 16-31 read from byte 2 (0+2) - %result = amdgpu.scaled_ext_packed816 %source scale(%scales) + %result = amdgpu.scaled_ext_packed_matrix %source scale(%scales) blockSize(16) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16> // Input: 16-element vector of F6E3M2FN, converting to F32 // Lanes 0-15 read from byte 1, lanes 16-31 read from byte 3 (1+2) - %result = amdgpu.scaled_ext_packed816 %source scale(%scales) - blockSize(16) firstScaleLane(1) firstScaleByte(1) + %result = amdgpu.scaled_ext_packed_matrix %source scale(%scales) + blockSize(16) firstScaleLane(16) firstScaleByte(1) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf32> ``` Note: the layout for the scales generally mirrors how the WMMA - instructions use for matix scales. These selection operands allows + instructions use for matrix scales. These selection operands allows one to choose portions of the matrix to convert. + When `source` is either F8E4M3FN or F8E5M2 and `blockSize` is 32, + then the same byte will be used by both halves of the wave. + In this case, `firstScaleByte` can be any value from 0 to 3. + + When `source` is either F8E4M3FN or F8E5M2 and `blockSize` is 16, + following combinations are allowed: + * `firstScaleLane(0), firstScaleByte(0)` + * `firstScaleLane(16), firstScaleByte(2)` + all other combinations are reserved. + Available on gfx1250+. }]; @@ -858,7 +906,8 @@ def AMDGPU_MemoryCounterWaitOp : OptionalAttr<I32Attr>:$load, OptionalAttr<I32Attr>:$store, OptionalAttr<I32Attr>:$ds, - OptionalAttr<I32Attr>:$exp + OptionalAttr<I32Attr>:$exp, + OptionalAttr<I32Attr>:$tensor )> { let summary = "Wait for specified hardware counters"; @@ -871,8 +920,10 @@ def AMDGPU_MemoryCounterWaitOp : counters into one. }]; let assemblyFormat = [{ - oilist( `load` `(` $load `)` | `store` `(` $store `)` | `ds` `(` $ds `)` | `exp` `(` $exp `)` ) attr-dict + oilist( `load` `(` $load `)` | `store` `(` $store `)` | `ds` `(` $ds `)` | `exp` `(` $exp `)` | `tensor` `(` $tensor `)` ) attr-dict }]; + + let hasCanonicalizer = 1; } def AMDGPU_MFMAPermB : I32EnumAttr<"MFMAPermB", @@ -1177,4 +1228,165 @@ def AMDGPU_ScaledMFMAOp : }]; let hasCanonicalizer = 1; } + +def AMDGPU_MakeDmaBaseOp : + AMDGPU_Op<"make_dma_base", [Pure, AttrSizedOperandSegments, AllElementTypesMatch<["global", "lds"]>]>, + Arguments<(ins Arg<AnyMemRef>:$global, + Variadic<Index>:$global_indices, + Arg<AnyMemRef>:$lds, + Variadic<Index>:$lds_indices)>, + Results<(outs AMDGPU_TDMBaseType: $base)> { + + // TODO: + // * Add verifiers to make sure that the number of indices do not exceed the number of dimensions. + + let summary = "Pair of based addresses used when moving tiles between LDS and global memory."; + let description = [{ + This operation creates a pair of addresses that will be used by tensor_load_to_lds + and tensor_store_from_lds. + + This operation creates a value corresponding to the tensor descriptor (D#) group 0 + found in TensorLoadToLDSOp and TensorStoreFromLDSOp in the rocdl dialect. + + For example: + + ```mlir + %base = amdgpu.make_dma_base %global[%idx0, %idx1], %lds[%idx2, %idx3] : memref<64x64xi32>, memref<64x64xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32> + %descriptor = amdgpu.make_dma_descriptor %base globalSize [2, 2] globalStride [2, 1] sharedSize [2, 2] : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + amdgpu.tensor_load_to_lds %descriptor : !amdgpu.tdm_descriptor + ``` + + to + + ```mlir + // pseudo-code + %global_base = llvm.extractvalue %global_memref[1] + %global_address = llvm.get_element_ptr ... + + %lds_base = llvm.extractvalue %lds_memref[1] + %lds_address = llvm.get_element_ptr ... + + // Definition of %base + %undef = llvm.mlir.undef : vector<4xi32> + %v0 = llvm.insertelement %15, %undef[0] : vector<4xi32> + %v1 = llvm.insertelement %lds_address, %v0[1] : vector<4xi32> + %v2 = llvm.insertelement %global_address_low, %v1[2] : vector<4xi32> + %base = llvm.insertelement %global_address_high, %v2[3] : vector<4xi32> + + rocdl.tensor.load.to.lds %base, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32> + ``` + + These tensor DMA operations were introduced in gfx1250. + }]; + + let assemblyFormat = [{ + $global `[` $global_indices `]` `,` $lds `[` $lds_indices `]` attr-dict `:` type($global) `,` type($lds) `->` type(results) + }]; + + let hasVerifier = 1; +} + +def AMDGPU_MakeDmaDescriptorOp : + AMDGPU_Op<"make_dma_descriptor", [Pure, AttrSizedOperandSegments]>, + Arguments<(ins + AMDGPU_TDMBaseType: $base, + Variadic<Index>: $global_dynamic_sizes, + DenseI64ArrayAttr: $global_static_sizes, + Variadic<Index>: $global_dynamic_strides, + DenseI64ArrayAttr: $global_static_strides, + Variadic<Index>: $shared_dynamic_sizes, + DenseI64ArrayAttr: $shared_static_sizes, + Optional<I16>: $workgroup_mask, + Optional<I1>: $early_timeout, + Optional<Index>: $pad_amount, + Optional<Index>: $pad_interval, + Optional<AnyMemRef>: $atomic_barrier_address, + Variadic<Index>: $atomic_barrier_indices, + Optional<Index>: $global_increment, + Optional<Index>: $lds_increment, + Optional<Index>: $iteration_count)>, + Results<(outs AMDGPU_TDMDescriptorType: $desc)> { + + let summary = "Make all descriptor groups needed by TensorLoadToLDS/TensorStoreFromLDS."; + let description = [{ + Make all descriptor groups needed by tensor memory operations. + + The $base operand corresponds to the base pair addresses, one must be an address in LDS + while the other must be a global memory location. + + $global_{static/dynamic}_sizes determine the size of the tensor. + $global_{static/dynamic}_strides determine the strides of the tensor. + $shared_{static/dynamic}_sizes determines the size of the tile. + + $workgroup_mask broadcast load to workgroups inside of a workgroup cluster + (0 = do not broadcast result to workgroup, 1 = broadcast result to workgroup). Ignored for stores. + An all zeros mask is interpreted as a non-broadcasted load. + + $early_timeout return data to requesters as soon as cache supplies it. + + Padding can be applied to the LDS address when copying from memory to LDS, + but not when copying from LDS to memory. + The values in the padded target addresses remain the same as before the operation was applied. + $pad_interval must be a power of two contained in [2, 256]. + $pad_amount must be a value contained in [1, 128]. + + $atomic_barrier_address must be aligned to 8 bytes. + + 2D and 3D tensors may be iterated over by setting $global_increment, $lds_increment, and $iteration_count. + $global_increment determines how much to increment the starting global memory address per iteration in units of the $base's element type. + $lds_increment determines how much to increment the starting LDS address per iteration in units of the $base's element type. + $iterate_count determines how many times to iterate. + + ```mlir + // Example of moving a two-dimensional tensor to LDS. + %base = amdgpu.make_dma_base %global[0, 0], %lds[0, 0] : memref<64x64xi32>, memref<64x64xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32> + %descriptor = amdgpu.make_dma_descriptor %base globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + amdgpu.tensor_load_to_lds %descriptor : !amdgpu.tdm_descriptor + + // Example of moving a two dimension tensor to LDS where padding is applied after every integer. + %base = amdgpu.make_dma_base %global[0, 0], %lds[0, 0] : memref<32x32xi32>, memref<64x64xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32> + %descriptor = amdgpu.make_dma_descriptor %base globalSize [32, 32] globalStride [32, 1] sharedSize [64, 64] padding(%pad_amount pad_every %pad_interval) : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + amdgpu.tensor_load_to_lds %descriptor : !amdgpu.tdm_descriptor + ``` + }]; + + let assemblyFormat = [{ + $base + `globalSize` custom<DynamicIndexList>($global_dynamic_sizes, $global_static_sizes) + `globalStride` custom<DynamicIndexList>($global_dynamic_strides, $global_static_strides) + `sharedSize` custom<DynamicIndexList>($shared_dynamic_sizes, $shared_static_sizes) + ( `padShared` `(` $pad_amount^ `every` $pad_interval `)` )? + ( `workgroupMask` $workgroup_mask^ ( `earlyTimeout` $early_timeout^)?)? + ( `atomicBarrier` `(` $atomic_barrier_address^ `[` $atomic_barrier_indices `]` + `:` type($atomic_barrier_address) `)`)? + ( `iterate` $global_increment^ `,` $lds_increment `,` $iteration_count )? + attr-dict `:` qualified(type($base)) `->` type(results) + }]; + + let extraClassDeclaration = [{ + int64_t getRank() { + return getGlobalStaticSizes().size(); + } + + unsigned getElementTypeWidth() { + return getBase().getType().getElementType().getIntOrFloatBitWidth(); + } + + SmallVector<OpFoldResult> getMixedGlobalSizes() { + return getMixedValues(getGlobalStaticSizes(), getGlobalDynamicSizes(), getContext()); + } + + SmallVector<OpFoldResult> getMixedGlobalStrides() { + return getMixedValues(getGlobalStaticStrides(), getGlobalDynamicStrides(), getContext()); + } + + SmallVector<OpFoldResult> getMixedSharedSizes() { + return getMixedValues(getSharedStaticSizes(), getSharedDynamicSizes(), getContext()); + } + }]; + + let hasVerifier = 1; + let hasFolder = 1; +} + #endif // AMDGPU diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h index dcd9f95..a7680fb 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h @@ -25,6 +25,7 @@ #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h.inc" #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.h.inc" +#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.h.inc" namespace mlir::amdgpu { /// Parser for the `custom<MNKDimensionList>` custom assembly format used by @@ -52,6 +53,9 @@ inline void printMNKDimensionList(OpAsmPrinter &printer, Operation *, #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.h.inc" +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.h.inc" + #define GET_OP_CLASSES #include "mlir/Dialect/AMDGPU/IR/AMDGPU.h.inc" diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index a38cf41..77d7804 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -158,6 +158,18 @@ class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = [ attr-dict `:` type($result) }]; } +class Arith_IntBinaryOpWithExactFlag<string mnemonic, list<Trait> traits = []> : + Arith_BinaryOp<mnemonic, traits # + [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>, + Arguments<(ins SignlessIntegerOrIndexLike:$lhs, + SignlessIntegerOrIndexLike:$rhs, + UnitAttr:$isExact)>, + Results<(outs SignlessIntegerOrIndexLike:$result)> { + + let assemblyFormat = [{ $lhs `,` $rhs (`exact` $isExact^)? + attr-dict `:` type($result) }]; +} + //===----------------------------------------------------------------------===// // ConstantOp //===----------------------------------------------------------------------===// @@ -482,7 +494,8 @@ def Arith_MulUIExtendedOp : Arith_Op<"mului_extended", [Pure, Commutative, // DivUIOp //===----------------------------------------------------------------------===// -def Arith_DivUIOp : Arith_IntBinaryOp<"divui", [ConditionallySpeculatable]> { +def Arith_DivUIOp : Arith_IntBinaryOpWithExactFlag<"divui", + [ConditionallySpeculatable]> { let summary = "unsigned integer division operation"; let description = [{ Unsigned integer division. Rounds towards zero. Treats the leading bit as @@ -493,12 +506,18 @@ def Arith_DivUIOp : Arith_IntBinaryOp<"divui", [ConditionallySpeculatable]> { `tensor` values, the behavior is undefined if _any_ elements are divided by zero. + If the `exact` attribute is present, the result value is poison if `lhs` is + not a multiple of `rhs`. + Example: ```mlir // Scalar unsigned integer division. %a = arith.divui %b, %c : i64 + // Scalar unsigned integer division where %b is known to be a multiple of %c. + %a = arith.divui %b, %c exact : i64 + // SIMD vector element-wise division. %f = arith.divui %g, %h : vector<4xi32> @@ -519,7 +538,8 @@ def Arith_DivUIOp : Arith_IntBinaryOp<"divui", [ConditionallySpeculatable]> { // DivSIOp //===----------------------------------------------------------------------===// -def Arith_DivSIOp : Arith_IntBinaryOp<"divsi", [ConditionallySpeculatable]> { +def Arith_DivSIOp : Arith_IntBinaryOpWithExactFlag<"divsi", + [ConditionallySpeculatable]> { let summary = "signed integer division operation"; let description = [{ Signed integer division. Rounds towards zero. Treats the leading bit as @@ -530,12 +550,18 @@ def Arith_DivSIOp : Arith_IntBinaryOp<"divsi", [ConditionallySpeculatable]> { behavior is undefined if _any_ of its elements are divided by zero or has a signed division overflow. + If the `exact` attribute is present, the result value is poison if `lhs` is + not a multiple of `rhs`. + Example: ```mlir // Scalar signed integer division. %a = arith.divsi %b, %c : i64 + // Scalar signed integer division where %b is known to be a multiple of %c. + %a = arith.divsi %b, %c exact : i64 + // SIMD vector element-wise division. %f = arith.divsi %g, %h : vector<4xi32> @@ -821,7 +847,7 @@ def Arith_ShLIOp : Arith_IntBinaryOpWithOverflowFlags<"shli"> { // ShRUIOp //===----------------------------------------------------------------------===// -def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> { +def Arith_ShRUIOp : Arith_IntBinaryOpWithExactFlag<"shrui", [Pure]> { let summary = "unsigned integer right-shift"; let description = [{ The `shrui` operation shifts an integer value of the first operand to the right @@ -830,12 +856,17 @@ def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> { filled with zeros. If the value of the second operand is greater or equal than the bitwidth of the first operand, then the operation returns poison. + If the `exact` attribute is present, the result value of shrui is a poison + value if any of the bits shifted out are non-zero. + Example: ```mlir - %1 = arith.constant 160 : i8 // %1 is 0b10100000 + %1 = arith.constant 160 : i8 // %1 is 0b10100000 %2 = arith.constant 3 : i8 - %3 = arith.shrui %1, %2 : (i8, i8) -> i8 // %3 is 0b00010100 + %3 = arith.constant 6 : i8 + %4 = arith.shrui %1, %2 exact : i8 // %4 is 0b00010100 + %5 = arith.shrui %1, %3 : i8 // %3 is 0b00000010 ``` }]; let hasFolder = 1; @@ -845,7 +876,7 @@ def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> { // ShRSIOp //===----------------------------------------------------------------------===// -def Arith_ShRSIOp : Arith_TotalIntBinaryOp<"shrsi"> { +def Arith_ShRSIOp : Arith_IntBinaryOpWithExactFlag<"shrsi", [Pure]> { let summary = "signed integer right-shift"; let description = [{ The `shrsi` operation shifts an integer value of the first operand to the right @@ -856,14 +887,17 @@ def Arith_ShRSIOp : Arith_TotalIntBinaryOp<"shrsi"> { operand is greater or equal than bitwidth of the first operand, then the operation returns poison. + If the `exact` attribute is present, the result value of shrsi is a poison + value if any of the bits shifted out are non-zero. + Example: ```mlir - %1 = arith.constant 160 : i8 // %1 is 0b10100000 + %1 = arith.constant 160 : i8 // %1 is 0b10100000 %2 = arith.constant 3 : i8 - %3 = arith.shrsi %1, %2 : (i8, i8) -> i8 // %3 is 0b11110100 - %4 = arith.constant 96 : i8 // %4 is 0b01100000 - %5 = arith.shrsi %4, %2 : (i8, i8) -> i8 // %5 is 0b00001100 + %3 = arith.shrsi %1, %2 exact : i8 // %3 is 0b11110100 + %4 = arith.constant 98 : i8 // %4 is 0b01100010 + %5 = arith.shrsi %4, %2 : i8 // %5 is 0b00001100 ``` }]; let hasFolder = 1; diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td index 6724d4c..a9b2b9f 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -28,7 +28,8 @@ class Bufferization_Op<string mnemonic, list<Trait> traits = []> def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor", [AttrSizedOperandSegments, BufferizableOpInterface, - DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> { + DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [ + "reifyResultShapes"]>]> { let summary = "allocate buffer for a tensor"; let description = [{ @@ -219,7 +220,8 @@ def Bufferization_MaterializeInDestinationOp : Bufferization_Op<"materialize_in_destination", [AllElementTypesMatch<["source", "dest"]>, BufferizableOpInterface, DestinationStyleOpInterface, - DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>, + DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [ + "reifyResultShapes"]>, DeclareOpInterfaceMethods<SubsetOpInterface, ["operatesOnEquivalentSubset", "operatesOnDisjointSubset"]>, DeclareOpInterfaceMethods<SubsetInsertionOpInterface, diff --git a/mlir/include/mlir/Dialect/Bufferization/Pipelines/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Pipelines/Passes.h index f220d20..e4d5b81 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Pipelines/Passes.h +++ b/mlir/include/mlir/Dialect/Bufferization/Pipelines/Passes.h @@ -47,6 +47,7 @@ struct BufferDeallocationPipelineOptions /// One-Shot bufferization pass. void buildBufferDeallocationPipeline( OpPassManager &pm, const BufferDeallocationPipelineOptions &options); +void buildBufferDeallocationPipeline(OpPassManager &pm); /// Registers all pipelines for the `bufferization` dialect. Currently, /// this includes only the "buffer-deallocation-pipeline". diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 4c1db58..c182090 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -116,6 +116,37 @@ def EmitC_FileOp let skipDefaultBuilders = 1; } +def EmitC_AddressOfOp : EmitC_Op<"address_of", [ + CExpressionInterface, + TypesMatchWith<"input and result reference the same type", "reference", "result", + "emitc::PointerType::get(::llvm::cast<emitc::LValueType>($_self).getValueType())"> +]> { + let summary = "Address operation"; + let description = [{ + This operation models the C & (address of) operator for a single operand, + which must be an emitc.lvalue, and returns an emitc pointer to its location. + + Example: + + ```mlir + // Custom form of applying the & operator. + %0 = emitc.address_of %arg0 : (!emitc.lvalue<i32>) -> !emitc.ptr<i32> + ``` + }]; + let arguments = (ins EmitC_LValueType:$reference); + let results = (outs EmitC_PointerType:$result); + let assemblyFormat = [{ + $reference `:` qualified(type($reference)) attr-dict + }]; + let hasVerifier = 1; + + let extraClassDeclaration = [{ + bool hasSideEffects() { + return false; + } + }]; +} + def EmitC_AddOp : EmitC_BinaryOp<"add", []> { let summary = "Addition operation"; let description = [{ @@ -140,7 +171,7 @@ def EmitC_AddOp : EmitC_BinaryOp<"add", []> { } def EmitC_ApplyOp : EmitC_Op<"apply", [CExpressionInterface]> { - let summary = "Apply operation"; + let summary = "Deprecated (use address_of/dereference)"; let description = [{ With the `emitc.apply` operation the operators & (address of) and * (contents of) can be applied to a single operand. @@ -439,6 +470,31 @@ def EmitC_ConstantOp }]; } +def EmitC_DereferenceOp : EmitC_Op<"dereference", [ + TypesMatchWith<"input and result reference the same type", "pointer", "result", + "emitc::LValueType::get(::llvm::cast<emitc::PointerType>($_self).getPointee())"> +]> { + let summary = "Dereference operation"; + let description = [{ + This operation models the C * (dereference) operator, which must be of + !emitc.ptr<> type, returning an !emitc.lvalue<> the value pointed to by the + pointer. + + Example: + + ```mlir + // Custom form of the dereference operator. + %0 = emitc.dereference %arg0 : (!emitc.ptr<i32>) -> !emitc.lvalue<i32> + ``` + }]; + let arguments = (ins EmitC_PointerType:$pointer); + let results = (outs EmitC_LValueType:$result); + let assemblyFormat = [{ + $pointer `:` qualified(type($pointer)) attr-dict + }]; + let hasVerifier = 1; +} + def EmitC_DivOp : EmitC_BinaryOp<"div", []> { let summary = "Division operation"; let description = [{ diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h index 3576126..00d5087 100644 --- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h @@ -60,6 +60,13 @@ mlir::FailureOr<std::pair<mlir::func::FuncOp, mlir::func::CallOp>> deduplicateArgsOfFuncOp(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp, mlir::ModuleOp moduleOp); +/// Look up a FuncOp with signature `resultTypes`(`paramTypes`)` and name +/// `name`. Return a failure if the FuncOp is found but with a different +/// signature. +FailureOr<FuncOp> lookupFnDecl(SymbolOpInterface symTable, StringRef name, + FunctionType funcT, + SymbolTableCollection *symbolTables = nullptr); + } // namespace func } // namespace mlir diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td index 860f893..2c29bb8 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td @@ -114,7 +114,7 @@ def GPU_MMAMatrix : DialectType< GPU_Dialect, IsMMAMatrixTypePred, "MMAMatrix type">; // Memref type acceptable to gpu.subgroup_mma_{load|store}_matrix ops. -def GPU_MMAMemRef : MemRefOf<[I8, I32, F16, F32, VectorOfRankAndType<[1], [I8, I32, F16, F32]>]>; +def GPU_MMAMemRef : MemRefOf<[I8, I32, F16, F32, F64, VectorOfRankAndType<[1], [I8, I32, F16, F32, F64]>]>; class MMAMatrixOf<list<Type> allowedTypes> : ContainerType<AnyTypeOf<allowedTypes>, IsMMAMatrixTypePred, diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td index a6c6038..5c7df25 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -1872,7 +1872,7 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix", ``` }]; - let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, I32, F16, F32]>>:$src, + let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, I32, F16, F32, F64]>>:$src, Arg<GPU_MMAMemRef, "",[MemWriteAt<0, FullEffect>]>:$dstMemref, Variadic<Index>:$indices, IndexAttr:$leadDimension, @@ -1919,9 +1919,9 @@ def GPU_SubgroupMmaComputeOp ``` }]; - let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, F16, F32]>>:$opA, - Arg<MMAMatrixOf<[SI8, UI8, F16, F32]>>:$opB, - Arg<MMAMatrixOf<[I32, F16, F32]>>:$opC, + let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, F16, F32, F64]>>:$opA, + Arg<MMAMatrixOf<[SI8, UI8, F16, F32, F64]>>:$opB, + Arg<MMAMatrixOf<[I32, F16, F32, F64]>>:$opC, OptionalAttr<UnitAttr>:$a_transpose, OptionalAttr<UnitAttr>:$b_transpose); diff --git a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h index fccb49d..34c85de 100644 --- a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h +++ b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h @@ -58,6 +58,10 @@ struct GPUToNVVMPipelineOptions "Whether to use the bareptr calling convention on the host (warning " "this should be false until the GPU layering is fixed)"), llvm::cl::init(false)}; + PassOptions::Option<bool> allowPatternRollback{ + *this, "allow-pattern-rollback", + llvm::cl::desc("Allow pattern rollback during dialect conversion"), + llvm::cl::init(true)}; }; // Options for the gpu to xevm pipeline. diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt index c301e0b..25b56cc 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt @@ -63,7 +63,7 @@ mlir_tablegen(NVVMRequiresSMTraits.cpp.inc -gen-op-interface-defs) add_mlir_dialect_tablegen_target(MLIRNVVMRequiresSMTraitsIncGen) add_mlir_dialect(NVVMOps nvvm) -add_mlir_doc(NVVMOps NVVMDialect Dialects/ -gen-dialect-doc -dialect=nvvm) +add_mlir_doc(NVVMOps NVVMOps Dialects/ -gen-op-doc) set(LLVM_TARGET_DEFINITIONS NVVMOps.td) mlir_tablegen(NVVMConversions.inc -gen-llvmir-conversions) mlir_tablegen(NVVMFromLLVMIRConversions.inc -gen-intr-from-llvmir-conversions) diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h index 8ad9ed1..b09d320 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h +++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h @@ -52,6 +52,10 @@ lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp, FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables = nullptr); +FailureOr<LLVM::LLVMFuncOp> +lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables = nullptr); + /// Declares a function to print a C-string. /// If a custom runtime function is defined via `runtimeFunctionName`, it must /// have the signature void(char const*). The default function is `printString`. diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td index 147f8c2..ef16cec 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td @@ -931,11 +931,9 @@ def LLVM_DIStringTypeAttr : LLVM_Attr<"DIStringType", "di_string_type", //===----------------------------------------------------------------------===// def LLVM_MemoryEffectsAttr : LLVM_Attr<"MemoryEffects", "memory_effects"> { - let parameters = (ins - "ModRefInfo":$other, - "ModRefInfo":$argMem, - "ModRefInfo":$inaccessibleMem - ); + let parameters = (ins "ModRefInfo":$other, "ModRefInfo":$argMem, + "ModRefInfo":$inaccessibleMem, "ModRefInfo":$errnoMem, + "ModRefInfo":$targetMem0, "ModRefInfo":$targetMem1); let extraClassDeclaration = [{ bool isReadWrite(); }]; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td index e7b44fd..e2edab4 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td @@ -758,13 +758,16 @@ def FramePointerKindAll : LLVM_EnumAttrCase<"All", "all", "All", 2>; def FramePointerKindReserved : LLVM_EnumAttrCase<"Reserved", "reserved", "Reserved", 3>; +def FramePointerKindNonLeafNoReserve + : LLVM_EnumAttrCase<"NonLeafNoReserve", "non-leaf-no-reserve", "NonLeafNoReserve", 4>; def FramePointerKindEnum : LLVM_EnumAttr< "FramePointerKind", "::llvm::FramePointerKind", "LLVM FramePointerKind", [FramePointerKindNone, FramePointerKindNonLeaf, - FramePointerKindAll, FramePointerKindReserved]> { + FramePointerKindAll, FramePointerKindReserved, + FramePointerKindNonLeafNoReserve]> { let cppNamespace = "::mlir::LLVM::framePointerKind"; } diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td index 490130f..e31e461 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td @@ -181,6 +181,18 @@ def LLVM_SMinOp : LLVM_BinarySameArgsIntrOpI<"smin">; def LLVM_UMaxOp : LLVM_BinarySameArgsIntrOpI<"umax">; def LLVM_UMinOp : LLVM_BinarySameArgsIntrOpI<"umin">; +class LLVM_CmpIntrOp<string func> + : LLVM_OneResultIntrOp<func, [0], [0], [Pure, SameTypeOperands]> { + let arguments = (ins LLVM_ScalarOrVectorOf<AnySignlessInteger>:$a, + LLVM_ScalarOrVectorOf<AnySignlessInteger>:$b); + let results = (outs LLVM_ScalarOrVectorOf<AnySignlessInteger>:$res); + let assemblyFormat = "`(` operands `)` attr-dict `:` " + "functional-type(operands, results)"; +} + +def LLVM_SCmpOp : LLVM_CmpIntrOp<"scmp">; +def LLVM_UCmpOp : LLVM_CmpIntrOp<"ucmp">; + def LLVM_SinOp : LLVM_UnaryIntrOpF<"sin">; def LLVM_CosOp : LLVM_UnaryIntrOpF<"cos">; def LLVM_TanOp : LLVM_UnaryIntrOpF<"tan">; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index e425e16..971710f 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -39,7 +39,7 @@ class LLVM_TerminatorOp<string mnemonic, list<Trait> traits = []> : class LLVM_ArithmeticOpBase<Type type, string mnemonic, string instName, list<Trait> traits = []> : LLVM_Op<mnemonic, - !listconcat([Pure, SameOperandsAndResultType], traits)>, + !listconcat([SameOperandsAndResultType, NoMemoryEffect], traits)>, LLVM_Builder<"$res = builder.Create" # instName # "($lhs, $rhs);"> { dag commonArgs = (ins LLVM_ScalarOrVectorOf<type>:$lhs, LLVM_ScalarOrVectorOf<type>:$rhs); @@ -116,7 +116,8 @@ class LLVM_IntArithmeticOpWithDisjointFlag<string mnemonic, string instName, class LLVM_FloatArithmeticOp<string mnemonic, string instName, list<Trait> traits = []> : LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, instName, - !listconcat([DeclareOpInterfaceMethods<FastmathFlagsInterface>], traits)> { + !listconcat([DeclareOpInterfaceMethods<FastmathFlagsInterface>, Pure], + traits)> { dag fmfArg = ( ins DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags); let arguments = !con(commonArgs, fmfArg); @@ -149,24 +150,26 @@ class LLVM_UnaryFloatArithmeticOp<Type type, string mnemonic, // Integer binary operations. def LLVM_AddOp : LLVM_IntArithmeticOpWithOverflowFlag<"add", "Add", - [Commutative]>; -def LLVM_SubOp : LLVM_IntArithmeticOpWithOverflowFlag<"sub", "Sub", []>; + [Commutative, Pure]>; +def LLVM_SubOp : LLVM_IntArithmeticOpWithOverflowFlag<"sub", "Sub", [Pure]>; def LLVM_MulOp : LLVM_IntArithmeticOpWithOverflowFlag<"mul", "Mul", - [Commutative]>; -def LLVM_UDivOp : LLVM_IntArithmeticOpWithExactFlag<"udiv", "UDiv">; -def LLVM_SDivOp : LLVM_IntArithmeticOpWithExactFlag<"sdiv", "SDiv">; -def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">; -def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "SRem">; -def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "And">; -def LLVM_OrOp : LLVM_IntArithmeticOpWithDisjointFlag<"or", "Or"> { + [Commutative, Pure]>; +def LLVM_UDivOp : LLVM_IntArithmeticOpWithExactFlag<"udiv", "UDiv", + [DeclareOpInterfaceMethods<ConditionallySpeculatable>]>; +def LLVM_SDivOp : LLVM_IntArithmeticOpWithExactFlag<"sdiv", "SDiv", + [DeclareOpInterfaceMethods<ConditionallySpeculatable>]>; +def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem", [Pure]>; +def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "SRem", [Pure]>; +def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "And", [Pure]>; +def LLVM_OrOp : LLVM_IntArithmeticOpWithDisjointFlag<"or", "Or", [Pure]> { let hasFolder = 1; } -def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">; -def LLVM_ShlOp : LLVM_IntArithmeticOpWithOverflowFlag<"shl", "Shl", []> { +def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor", [Pure]>; +def LLVM_ShlOp : LLVM_IntArithmeticOpWithOverflowFlag<"shl", "Shl", [Pure]> { let hasFolder = 1; } -def LLVM_LShrOp : LLVM_IntArithmeticOpWithExactFlag<"lshr", "LShr">; -def LLVM_AShrOp : LLVM_IntArithmeticOpWithExactFlag<"ashr", "AShr">; +def LLVM_LShrOp : LLVM_IntArithmeticOpWithExactFlag<"lshr", "LShr", [Pure]>; +def LLVM_AShrOp : LLVM_IntArithmeticOpWithExactFlag<"ashr", "AShr", [Pure]>; // Base class for compare operations. A compare operation takes two operands // of the same type and returns a boolean result. If the operands are diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 1cc5b74..a0a0051 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -37,50 +37,6 @@ def LLVM_PointerSharedCluster : LLVM_PointerInAddressSpace<7>; //===----------------------------------------------------------------------===// def NVVM_Dialect : Dialect { - let summary = "The NVVM dialect that models NVIDIA's public ISA"; - - let description = [{ - The NVVM dialect is MLIR's LLVM-IR-based, NVIDIA-specific backend dialect. It - models NVVM intrinsics and public ISA functionality and introduces NVIDIA - extensions to the MLIR/LLVM type system and address spaces (e.g., global, - shared, and cluster memory), enabling faithful lowering of GPU kernels to the - NVPTX toolchain. While a NVVM op usually maps to a single LLVM IR intrinsic, - the NVVM dialect uses type polymorphism and other attributes so that a single - NVVM op can map to different LLVM intrinsics. - - **Scope and capabilities:** The dialect covers core GPU features such as - thread/block builtins, barriers and atomics, warp-level collectives (e.g., - shuffle/vote), matrix/tensor core operations (e.g., `mma.sync`, `wgmma`), - tensor memory accelerator (TMA) operations, asynchronous copies (`cp.async`, - bulk/tensor variants) with memory barriers, cache and prefetch controls, and - NVVM-specific attributes and enums (e.g., FP rounding modes, memory scopes, - and MMA types/layouts). - - **Non-goals:** NVVM is not a place for convenience or “wrapper” ops. It is - not intended to introduce high-level ops that expand into multiple unrelated - NVVM intrinsics or that lower to no intrinsic at all. Such abstractions belong - in higher-level dialects (e.g., `nvgpu`, `gpu`, or project-specific dialects). - The design intent is a thin, predictable, low-level surface with - near-mechanical lowering to NVVM/LLVM IR. - - **Placement in the lowering pipeline:** NVVM sits below target-agnostic - dialects like `gpu` and NVIDIA's `nvgpu`. Typical pipelines convert - `gpu`/`nvgpu` ops into NVVM using `-convert-gpu-to-nvvm` and - `-convert-nvgpu-to-nvvm`, then translate into LLVM for final code - generation via NVPTX backend. - - **Target configuration and serialization:** NVVM provides a `#nvvm.target` - attribute to describe the GPU target (SM, features, and flags). In - conjunction with `gpu` serialization (e.g., `gpu-module-to-binary`), this - enables producing architecture-specific GPU binaries (such as CUBIN) from - nested GPU modules. - - **Inline PTX:** When an intrinsic is unavailable or a performance-critical - sequence must be expressed directly, NVVM provides an `nvvm.inline_ptx` op to - embed PTX inline as a last-resort escape hatch, with explicit operands and - results. - }]; - let name = "nvvm"; let cppNamespace = "::mlir::NVVM"; let dependentDialects = ["LLVM::LLVMDialect"]; @@ -228,6 +184,54 @@ def NVVMMemorySpaceAttr : let assemblyFormat = "`<` $value `>`"; } +// Attrs describing the scope of the Memory Operation +def MemScopeKindCTA : I32EnumAttrCase<"CTA", 0, "cta">; +def MemScopeKindCluster : I32EnumAttrCase<"CLUSTER", 1, "cluster">; +def MemScopeKindGPU : I32EnumAttrCase<"GPU", 2, "gpu">; +def MemScopeKindSYS : I32EnumAttrCase<"SYS", 3, "sys">; + +def MemScopeKind : I32EnumAttr<"MemScopeKind", "NVVM Memory Scope kind", + [MemScopeKindCTA, MemScopeKindCluster, MemScopeKindGPU, MemScopeKindSYS]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::NVVM"; +} +def MemScopeKindAttr : EnumAttr<NVVM_Dialect, MemScopeKind, "mem_scope"> { + let assemblyFormat = "`<` $value `>`"; +} + +// Attrs to disambiguate the cta or cluster space within shared memory +def SharedSpaceCTA : I32EnumAttrCase<"shared_cta", 0, "cta">; +def SharedSpaceCluster : I32EnumAttrCase<"shared_cluster", 1, "cluster">; +def SharedSpace : I32EnumAttr<"SharedSpace", "Shared memory space", + [SharedSpaceCTA, SharedSpaceCluster]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::NVVM"; +} +def SharedSpaceAttr : EnumAttr<NVVM_Dialect, SharedSpace, "shared_space"> { + let assemblyFormat = "`<` $value `>`"; +} + +// Attrs describing the Memory Ordering Semantics +def MemOrderKindWeak : I32EnumAttrCase<"WEAK", 0, "weak">; +def MemOrderKindRelaxed : I32EnumAttrCase<"RELAXED", 1, "relaxed">; +def MemOrderKindAcquire : I32EnumAttrCase<"ACQUIRE", 2, "acquire">; +def MemOrderKindRelease : I32EnumAttrCase<"RELEASE", 3, "release">; +def MemOrderKindAcqRel : I32EnumAttrCase<"ACQ_REL", 4, "acq_rel">; +def MemOrderKindSC : I32EnumAttrCase<"SC", 5, "sc">; +def MemOrderKindMMIO : I32EnumAttrCase<"MMIO", 6, "mmio">; +def MemOrderKindVolatile : I32EnumAttrCase<"VOLATILE", 7, "volatile">; + +def MemOrderKind : I32EnumAttr<"MemOrderKind", "NVVM Memory Ordering kind", + [MemOrderKindWeak, MemOrderKindRelaxed, MemOrderKindAcquire, + MemOrderKindRelease, MemOrderKindAcqRel, MemOrderKindSC, + MemOrderKindMMIO, MemOrderKindVolatile]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::NVVM"; +} +def MemOrderKindAttr : EnumAttr<NVVM_Dialect, MemOrderKind, "mem_order"> { + let assemblyFormat = "`<` $value `>`"; +} + //===----------------------------------------------------------------------===// // NVVM intrinsic operations //===----------------------------------------------------------------------===// @@ -512,8 +516,7 @@ def NVVM_ReduxOp : //===----------------------------------------------------------------------===// def NVVM_NanosleepOp : NVVM_Op<"nanosleep">, - Arguments<(ins - ConfinedAttr<I32Attr, [IntMinValue<1>, IntMaxValue<1000000>]>:$duration)> + Arguments<(ins I32:$duration)> { let summary = "Suspends the thread for a specified duration."; @@ -531,8 +534,7 @@ def NVVM_NanosleepOp : NVVM_Op<"nanosleep">, string llvmBuilder = [{ createIntrinsicCall(builder, - llvm::Intrinsic::nvvm_nanosleep, - {builder.getInt32($duration)}); + llvm::Intrinsic::nvvm_nanosleep, {$duration}); }]; let assemblyFormat = "attr-dict $duration"; } @@ -657,9 +659,76 @@ def NVVM_MBarrierInvalOp : NVVM_Op<"mbarrier.inval">, }]; } -def NVVM_MBarrierArriveOp : NVVM_Op<"mbarrier.arrive">, - Results<(outs I64:$res)>, - Arguments<(ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr)> { +def NVVM_MBarrierExpectTxOp : NVVM_Op<"mbarrier.expect_tx"> { + let summary = "MBarrier expect-tx Operation"; + let description = [{ + The `nvvm.mbarrier.expect_tx` operation increases the transaction count + of the mbarrier located at `addr` by `txcount` amount. The `scope` + specifies the set of threads that can directly observe the memory + synchronizing effect of the `mbarrier.expect_tx` operation. `CTA` + and `CLUSTER` are the only allowed values for `scope`. + + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-expect-tx) + }]; + + let arguments = (ins + AnyTypeOf<[LLVM_PointerShared, LLVM_PointerSharedCluster]>:$addr, + I32:$txcount, + DefaultValuedAttr<MemScopeKindAttr, "MemScopeKind::CTA">:$scope); + + let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands)"; + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + static mlir::NVVM::IDArgPair + getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase& builder); + }]; + + string llvmBuilder = [{ + auto [id, args] = NVVM::MBarrierExpectTxOp::getIntrinsicIDAndArgs( + *op, moduleTranslation, builder); + createIntrinsicCall(builder, id, args); + }]; +} + +def NVVM_MBarrierCompleteTxOp : NVVM_Op<"mbarrier.complete_tx"> { + let summary = "MBarrier complete-tx Operation"; + let description = [{ + The `nvvm.mbarrier.complete_tx` operation decrements the transaction + count of the *mbarrier object* at `addr` by `txcount`. It also signals + the completion of asynchronous transactions that were tracked by the + current phase. The `scope` specifies the set of threads that can directly + observe the memory synchronizing effect of the `mbarrier.complete_tx` + operation. `CTA` and `CLUSTER` are the only allowed values for `scope`. + + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-complete-tx) + }]; + + let arguments = (ins + AnyTypeOf<[LLVM_PointerShared, LLVM_PointerSharedCluster]>:$addr, + I32:$txcount, + DefaultValuedAttr<MemScopeKindAttr, "MemScopeKind::CTA">:$scope); + + let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands)"; + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + static mlir::NVVM::IDArgPair + getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase& builder); + }]; + + string llvmBuilder = [{ + auto [id, args] = NVVM::MBarrierCompleteTxOp::getIntrinsicIDAndArgs( + *op, moduleTranslation, builder); + createIntrinsicCall(builder, id, args); + }]; +} + +def NVVM_MBarrierArriveOp : NVVM_Op<"mbarrier.arrive"> { let summary = "MBarrier Arrive Operation"; let description = [{ The `nvvm.mbarrier.arrive` operation performs an arrive-on operation on the @@ -671,19 +740,40 @@ def NVVM_MBarrierArriveOp : NVVM_Op<"mbarrier.arrive">, with this release pattern. This operation causes the executing thread to signal its arrival at the barrier. - The operation returns an opaque value that captures the phase of the - *mbarrier object* prior to the arrive-on operation. The contents of this state - value are implementation-specific. - The operation takes the following operand: + - `res`: When the `space` is not shared_cluster, this operation returns an + opaque 64-bit value capturing the phase of the *mbarrier object* prior to + the arrive-on operation. The contents of this return value are + implementation-specific. An *mbarrier object* located in the shared_cluster + space cannot return a value. + + The operation takes the following operands: - `addr`: A pointer to the memory location of the *mbarrier object*. The `addr` - must be a pointer to generic or shared::cta memory. When it is generic, the - underlying address must be within the shared::cta memory space; otherwise - the behavior is undefined. + must be a pointer to generic or shared_cta or shared_cluster memory. When it + is generic, the underlying address must be within the shared_cta memory space; + otherwise the behavior is undefined. + - `count`: This specifies the amount by which the pending arrival count is + decremented. If the `count` argument is not specified, the pending arrival + count is decremented by 1. + - `scope`: This specifies the set of threads that directly observe the memory + synchronizing effect of the `mbarrier.arrive` operation. + - `space`: This indicates the memory space where the mbarrier object resides. + - `relaxed`: When set to true, the `arrive` operation has relaxed memory semantics + and does not provide any ordering or visibility guarantees. [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive) }]; - let assemblyFormat = "$addr attr-dict `:` type($addr) `->` type($res)"; + + let results = (outs Optional<I64>:$res); + let arguments = (ins + AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared, LLVM_PointerSharedCluster]>:$addr, + Optional<I32>:$count, + DefaultValuedAttr<MemScopeKindAttr, "MemScopeKind::CTA">:$scope, + DefaultValuedAttr<BoolAttr, "false">:$relaxed); + + let assemblyFormat = "$addr (`,` $count^)? attr-dict `:` type($addr) (`->` type($res)^)?"; + + let hasVerifier = 1; let extraClassDeclaration = [{ static mlir::NVVM::IDArgPair @@ -694,7 +784,54 @@ def NVVM_MBarrierArriveOp : NVVM_Op<"mbarrier.arrive">, string llvmBuilder = [{ auto [id, args] = NVVM::MBarrierArriveOp::getIntrinsicIDAndArgs( *op, moduleTranslation, builder); - $res = createIntrinsicCall(builder, id, args); + + int addrSpace = llvm::cast<LLVMPointerType>(op.getAddr().getType()).getAddressSpace(); + if (addrSpace != static_cast<unsigned>(NVVM::NVVMMemorySpace::SharedCluster)) + $res = createIntrinsicCall(builder, id, args); + else + createIntrinsicCall(builder, id, args); + }]; +} + +def NVVM_MBarrierArriveDropOp : NVVM_Op<"mbarrier.arrive_drop"> { + let summary = "MBarrier Arrive-Drop Operation"; + let description = [{ + The `nvvm.mbarrier.arrive_drop` operation decrements the expected arrival + count of the *mbarrier object* by `count` and then performs an arrive-on + operation. When `count` is not specified, it defaults to 1. The decrement + of the expected arrival count applies to all the subsequent phases of the + *mbarrier object*. The remaining semantics are identical to those of the + `nvvm.mbarrier.arrive` operation. + + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive-drop) + }]; + + let results = (outs Optional<I64>:$res); + let arguments = (ins + AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared, LLVM_PointerSharedCluster]>:$addr, + Optional<I32>:$count, + DefaultValuedAttr<MemScopeKindAttr, "MemScopeKind::CTA">:$scope, + DefaultValuedAttr<BoolAttr, "false">:$relaxed); + + let assemblyFormat = "$addr (`,` $count^)? attr-dict `:` type($addr) (`->` type($res)^)?"; + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + static mlir::NVVM::IDArgPair + getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase& builder); + }]; + + string llvmBuilder = [{ + auto [id, args] = NVVM::MBarrierArriveDropOp::getIntrinsicIDAndArgs( + *op, moduleTranslation, builder); + + int addrSpace = llvm::cast<LLVMPointerType>(op.getAddr().getType()).getAddressSpace(); + if (addrSpace != static_cast<unsigned>(NVVM::NVVMMemorySpace::SharedCluster)) + $res = createIntrinsicCall(builder, id, args); + else + createIntrinsicCall(builder, id, args); }]; } @@ -744,8 +881,36 @@ def NVVM_MBarrierArriveNocompleteOp : NVVM_Op<"mbarrier.arrive.nocomplete">, }]; } -def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx">, - Arguments<(ins LLVM_AnyPointer:$addr, I32:$txcount, PtxPredicate:$predicate)> { +def NVVM_MBarrierArriveDropNocompleteOp : NVVM_Op<"mbarrier.arrive_drop.nocomplete">, + Results<(outs I64:$res)>, + Arguments<(ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr, + I32:$count)> { + let summary = "MBarrier Arrive-Drop No-Complete Operation"; + let description = [{ + The `nvvm.mbarrier.arrive_drop.nocomplete` operation decrements the expected + arrival count of the *mbarrier object* by the amount `count` and then performs + an arrive-on operation on the *mbarrier object* with the guarantee that it + will not cause the barrier to complete its current phase. + + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive-drop) + }]; + + let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands) `->` type($res)"; + + let extraClassDeclaration = [{ + static mlir::NVVM::IDArgPair + getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase& builder); + }]; + + string llvmBuilder = [{ + auto [id, args] = NVVM::MBarrierArriveDropNocompleteOp::getIntrinsicIDAndArgs( + *op, moduleTranslation, builder); + $res = createIntrinsicCall(builder, id, args); + }]; +} + +def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx"> { let summary = "MBarrier Arrive with Expected Transaction Count"; let description = [{ The `nvvm.mbarrier.arrive.expect_tx` operation performs an expect-tx operation @@ -756,11 +921,11 @@ def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_t threads within the CTA. When other threads perform corresponding acquire operations (like 'mbarrier.test.wait'), they synchronize with this release pattern. - This operation first performs an expect-tx operation with the specified transaction - count, then performs an arrive-on operation with an implicit count of 1. The - expect-tx operation increases the tx-count of the *mbarrier object* by the specified - expectCount value, setting the current phase to expect and tracks the completion - of additional asynchronous transactions. + This operation first performs an expect-tx operation with the specified transaction + count, then performs an arrive-on operation with an implicit count of 1. The + expect-tx operation increases the expect-count of the *mbarrier object* by the + specified value (i.e. `txcount`), setting the current phase to expect and track + the completion of additional asynchronous transactions. The operation takes the following operands: - `addr`: A pointer to the memory location of the *mbarrier object*. Uses generic @@ -768,33 +933,89 @@ def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_t - `txcount`: An unsigned integer specifying the expected transaction count for the expect-tx operation. This represents the number of asynchronous transactions expected to complete before the barrier phase completes. - - `predicate`: Optional predicate for conditional execution. + - `scope`: This specifies the set of threads that directly observe the memory + synchronizing effect of the `mbarrier.test.wait` operation. + - `relaxed`: When set to true, the `arrive` operation has relaxed memory semantics + and does not provide any ordering or visibility guarantees. + - `predicate`: Optional predicate for conditional execution used only when lowering to + inline-ptx. - [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive) + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive-drop) }]; - let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)"; - let extraClassDefinition = [{ - std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;"); } + + let results = (outs Optional<I64>:$res); + let arguments = (ins + AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared, LLVM_PointerSharedCluster]>:$addr, + I32:$txcount, + DefaultValuedAttr<MemScopeKindAttr, "MemScopeKind::CTA">:$scope, + DefaultValuedAttr<BoolAttr, "false">:$relaxed, + PtxPredicate:$predicate); + + let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands) (`->` type($res)^)?"; + let hasVerifier = 1; + + let extraClassDeclaration = [{ + bool hasIntrinsic() { return !getPredicate(); } + + bool getAsmValues(RewriterBase &rewriter, + llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &asmValues); + + static mlir::NVVM::IDArgPair + getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase& builder); + }]; + + string llvmBuilder = [{ + auto [id, args] = NVVM::MBarrierArriveExpectTxOp::getIntrinsicIDAndArgs( + *op, moduleTranslation, builder); + + if (op.getNumResults() > 0) + $res = createIntrinsicCall(builder, id, args); + else + createIntrinsicCall(builder, id, args); }]; } -def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx.shared">, - Arguments<(ins LLVM_PointerShared:$addr, I32:$txcount, PtxPredicate:$predicate)> { - let summary = "Shared MBarrier Arrive with Expected Transaction Count"; +def NVVM_MBarrierArriveDropExpectTxOp : NVVM_Op<"mbarrier.arrive_drop.expect_tx"> { + let summary = "MBarrier arrive_drop with expected transaction count"; let description = [{ - This Op is the same as `nvvm.mbarrier.arrive.expect_tx` except that the *mbarrier object* - should be accessed using a shared-memory pointer instead of a generic-memory pointer. + The `nvvm.mbarrier.arrive_drop.expect_tx` operation is similar to the + `nvvm.mbarrier.arrive.expect_tx` operation except that it performs an + `arrive_drop` operation instead of only an `arrive` operation. - [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive) - }]; - let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)"; - let extraClassDefinition = [{ - std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"); } + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive-drop) + }]; + + let results = (outs Optional<I64>:$res); + let arguments = (ins + AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared, LLVM_PointerSharedCluster]>:$addr, + I32:$txcount, + DefaultValuedAttr<MemScopeKindAttr, "MemScopeKind::CTA">:$scope, + DefaultValuedAttr<BoolAttr, "false">:$relaxed); + + let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands) (`->` type($res)^)?"; + let hasVerifier = 1; + + let extraClassDeclaration = [{ + static mlir::NVVM::IDArgPair + getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase& builder); + }]; + + string llvmBuilder = [{ + auto [id, args] = NVVM::MBarrierArriveDropExpectTxOp::getIntrinsicIDAndArgs( + *op, moduleTranslation, builder); + if (op.getNumResults() > 0) + $res = createIntrinsicCall(builder, id, args); + else + createIntrinsicCall(builder, id, args); }]; } def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">, - Arguments<(ins LLVM_AnyPointer:$addr, I32:$phase, I32:$ticks)> { + Arguments<(ins + AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr, + I32:$phase, I32:$ticks)> { let summary = "MBarrier Potentially-Blocking Try Wait with Phase Parity"; let description = [{ The `nvvm.mbarrier.try_wait.parity` operation performs a potentially-blocking @@ -847,73 +1068,37 @@ def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity" [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-try-wait) }]; let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)"; - let extraClassDefinition = [{ - std::string $cppClass::getPtx() { - return std::string( - "{\n\t" - ".reg .pred P1; \n\t" - "LAB_WAIT: \n\t" - "mbarrier.try_wait.parity.b64 P1, [%0], %1, %2; \n\t" - "@P1 bra.uni DONE; \n\t" - "bra.uni LAB_WAIT; \n\t" - "DONE: \n\t" - "}" - ); - } - }]; -} - -def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">, - Arguments<(ins LLVM_PointerShared:$addr, I32:$phase, I32:$ticks)> { - let summary = "Shared MBarrier Potentially-Blocking Try Wait with Phase Parity"; - let description = [{ - This Op is the same as `nvvm.mbarrier.try_wait.parity` except that the *mbarrier object* - should be accessed using a shared-memory pointer instead of a generic-memory pointer. - - [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-try-wait) - }]; - let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)"; - let extraClassDefinition = [{ - std::string $cppClass::getPtx() { - return std::string( - "{\n\t" - ".reg .pred P1; \n\t" - "LAB_WAIT: \n\t" - "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t" - "@P1 bra.uni DONE; \n\t" - "bra.uni LAB_WAIT; \n\t" - "DONE: \n\t" - "}" - ); - } - }]; } -def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">, - Results<(outs I1:$res)>, - Arguments<(ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr, - I64:$state)> { +def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait"> { let summary = "MBarrier Non-Blocking Test Wait Operation"; let description = [{ - The `nvvm.mbarrier.test.wait` operation performs a non-blocking test for the + The `nvvm.mbarrier.test.wait` operation performs a non-blocking test for the completion of a specific phase of an *mbarrier object*. It uses the default - `.acquire.cta` semantics. This acquire pattern establishes memory ordering for - operations occurring in program order after this wait instruction by making - operations from other threads in the CTA visible to subsequent operations in the current - thread. When this wait completes, it synchronizes with the corresponding release - pattern from the `mbarrier.arrive` operation, establishing memory ordering within + `.acquire.cta` semantics. This acquire pattern establishes memory ordering for + operations occurring in program order after this wait instruction by making + operations from other threads in the CTA visible to subsequent operations in the current + thread. When this wait completes, it synchronizes with the corresponding release + pattern from the `mbarrier.arrive` operation, establishing memory ordering within the CTA. - This operation tests whether the mbarrier phase specified by the state operand - has completed. It is a non-blocking instruction that immediately returns the + This operation tests whether the mbarrier phase specified by the state operand + has completed. It is a non-blocking instruction that immediately returns the completion status without suspending the executing thread. The operation takes the following operands: - - `addr`: A pointer to the memory location of the *mbarrier object*. Uses generic + - `addr`: A pointer to the memory location of the *mbarrier object*. Uses generic addressing, but the address must still be in the shared memory space. - - `state`: An opaque value returned by a previous `mbarrier.arrive` - operation on the same *mbarrier object* during the current or immediately - preceding phase. + - `stateOrPhase`: This argument represents a `state` when it is a 64-bit value + and represents a `phase` when it is a 32-bit value. The `state` is an opaque + value returned by a previous `mbarrier.arrive` operation on the same + *mbarrier object* during the current or immediately preceding phase. + The `phase` is an integer specifying the phase parity (0 or 1). + Even phases have parity 0, odd phases have parity 1. + - `scope`: This specifies the set of threads that directly observe the memory + synchronizing effect of the `mbarrier.test.wait` operation. + - `relaxed`: When set to true, the `arrive` operation has relaxed memory semantics + and does not provide any ordering or visibility guarantees. The operation returns a boolean value indicating whether the specified phase has completed: @@ -940,7 +1125,15 @@ def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">, [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-try-wait) }]; - let assemblyFormat = "$addr `,` $state attr-dict `:` type(operands) `->` type($res)"; + let results = (outs I1:$res); + let arguments = (ins + AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr, + AnyTypeOf<[I64, I32]>:$stateOrPhase, + DefaultValuedAttr<MemScopeKindAttr, "MemScopeKind::CTA">:$scope, + DefaultValuedAttr<BoolAttr, "false">:$relaxed); + + let assemblyFormat = "$addr `,` $stateOrPhase attr-dict `:` type(operands) `->` type($res)"; + let hasVerifier = 1; let extraClassDeclaration = [{ static mlir::NVVM::IDArgPair @@ -955,6 +1148,47 @@ def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">, }]; } +def NVVM_MBarrierTryWaitOp : NVVM_Op<"mbarrier.try_wait"> { + let summary = "MBarrier try wait on state or phase with an optional timelimit"; + let description = [{ + The `nvvm.mbarrier.try_wait` operation checks whether the specified + *mbarrier object* at `addr` has completed the given phase. Note that + unlike the `nvvm.mbarrier.test.wait` operation, the try_wait operation + is a potentially-blocking one. If the phase is not yet complete, the + calling thread may be suspended. A suspended thread resumes execution + once the phase completes or when a system-defined timeout occurs. + Optionally, the `ticks` operand can be used to provide a custom timeout + (in nanoseconds), overriding the system-defined one. The semantics of + this operation and its operands are otherwise similar to those of the + `nvvm.mbarrier.test.wait` Op. + + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-try-wait) + }]; + + let results = (outs I1:$res); + let arguments = (ins + AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr, + AnyTypeOf<[I64, I32]>:$stateOrPhase, + Optional<I32>:$ticks, + DefaultValuedAttr<MemScopeKindAttr, "MemScopeKind::CTA">:$scope, + DefaultValuedAttr<BoolAttr, "false">:$relaxed); + + let assemblyFormat = "$addr `,` $stateOrPhase (`,` $ticks^)? attr-dict `:` type(operands) `->` type($res)"; + let hasVerifier = 1; + + let extraClassDeclaration = [{ + static mlir::NVVM::IDArgPair + getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase& builder); + }]; + + string llvmBuilder = [{ + auto [id, args] = NVVM::MBarrierTryWaitOp::getIntrinsicIDAndArgs( + *op, moduleTranslation, builder); + $res = createIntrinsicCall(builder, id, args); + }]; +} + //===----------------------------------------------------------------------===// // NVVM synchronization op definitions //===----------------------------------------------------------------------===// @@ -977,6 +1211,23 @@ def NVVM_Barrier0Op : NVVM_Op<"barrier0"> { }]; } +// Attrs describing the reduction operations for the barrier operation. +def BarrierReductionPopc : I32EnumAttrCase<"POPC", 0, "popc">; +def BarrierReductionAnd : I32EnumAttrCase<"AND", 1, "and">; +def BarrierReductionOr : I32EnumAttrCase<"OR", 2, "or">; + +def BarrierReduction + : I32EnumAttr<"BarrierReduction", "NVVM barrier reduction operation", + [BarrierReductionPopc, BarrierReductionAnd, + BarrierReductionOr]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::NVVM"; +} +def BarrierReductionAttr + : EnumAttr<NVVM_Dialect, BarrierReduction, "reduction"> { + let assemblyFormat = "`<` $value `>`"; +} + def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> { let summary = "CTA Barrier Synchronization Op"; let description = [{ @@ -991,6 +1242,9 @@ def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> { - `numberOfThreads`: Specifies the number of threads participating in the barrier. When specified, the value must be a multiple of the warp size. If not specified, all threads in the CTA participate in the barrier. + - `reductionOp`: specifies the reduction operation (`popc`, `and`, `or`). + - `reductionPredicate`: specifies the predicate to be used with the + `reductionOp`. The barrier operation guarantees that when the barrier completes, prior memory accesses requested by participating threads are performed relative to all threads @@ -1007,31 +1261,37 @@ def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> { [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar) }]; - let arguments = (ins - Optional<I32>:$barrierId, - Optional<I32>:$numberOfThreads); + let extraClassDeclaration = [{ + static mlir::NVVM::IDArgPair + getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase& builder); + }]; + + let arguments = (ins Optional<I32>:$barrierId, Optional<I32>:$numberOfThreads, + OptionalAttr<BarrierReductionAttr>:$reductionOp, + Optional<I32>:$reductionPredicate); string llvmBuilder = [{ - llvm::Value *id = $barrierId ? $barrierId : builder.getInt32(0); - if ($numberOfThreads) - createIntrinsicCall( - builder, llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count, - {id, $numberOfThreads}); - else - createIntrinsicCall( - builder, llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all, {id}); + auto [id, args] = NVVM::BarrierOp::getIntrinsicIDAndArgs( + *op, moduleTranslation, builder); + if ($reductionOp) + $res = createIntrinsicCall(builder, id, args); + else + createIntrinsicCall(builder, id, args); }]; + let results = (outs Optional<I32>:$res); + let hasVerifier = 1; - let assemblyFormat = "(`id` `=` $barrierId^)? (`number_of_threads` `=` $numberOfThreads^)? attr-dict"; + let assemblyFormat = + "(`id` `=` $barrierId^)? (`number_of_threads` `=` $numberOfThreads^)? " + "(qualified($reductionOp)^ $reductionPredicate)? (`->` type($res)^)? attr-dict"; - let builders = [ - OpBuilder<(ins), [{ - return build($_builder, $_state, Value{}, Value{}); + let builders = [OpBuilder<(ins), [{ + return build($_builder, $_state, TypeRange{}, Value{}, Value{}, {}, Value{}); }]>, - OpBuilder<(ins "Value":$barrierId), [{ - return build($_builder, $_state, barrierId, Value{}); - }]> - ]; + OpBuilder<(ins "Value":$barrierId), [{ + return build($_builder, $_state, TypeRange{}, barrierId, Value{}, {}, Value{}); + }]>]; } def NVVM_BarrierArriveOp : NVVM_PTXBuilder_Op<"barrier.arrive"> @@ -1130,6 +1390,27 @@ def NVVM_ClusterWaitOp : NVVM_Op<"cluster.wait", [NVVMRequiresSM<90>]> { let assemblyFormat = "attr-dict"; } +//===----------------------------------------------------------------------===// +// NVVM Member/Fence +//===----------------------------------------------------------------------===// + +def NVVM_MembarOp : NVVM_Op<"memory.barrier">, + Arguments<(ins MemScopeKindAttr:$scope)> { + let summary = "Memory barrier operation"; + let description = [{ + `membar` operation guarantees that prior memory accesses requested by this + thread are performed at the specified `scope`, before later memory + operations requested by this thread following the membar instruction. + + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar) + }]; + + let assemblyFormat = "$scope attr-dict"; + let llvmBuilder = [{ + createIntrinsicCall(builder, getMembarIntrinsicID($scope)); + }]; +} + def NVVM_FenceScClusterOp : NVVM_Op<"fence.sc.cluster"> { string llvmBuilder = [{ createIntrinsicCall(builder, llvm::Intrinsic::nvvm_fence_sc_cluster); @@ -1137,15 +1418,36 @@ def NVVM_FenceScClusterOp : NVVM_Op<"fence.sc.cluster"> { let assemblyFormat = "attr-dict"; } -def SharedSpaceCTA : I32EnumAttrCase<"shared_cta", 0, "cta">; -def SharedSpaceCluster : I32EnumAttrCase<"shared_cluster", 1, "cluster">; -def SharedSpace : I32EnumAttr<"SharedSpace", "Shared memory space", - [SharedSpaceCTA, SharedSpaceCluster]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::NVVM"; +def NVVM_FenceSyncRestrictOp : NVVM_Op<"fence.sync_restrict">, + Arguments<(ins MemOrderKindAttr:$order)> { + let summary = "Uni-directional thread fence operation"; + let description = [{ + The `nvvm.fence.sync_restrict` Op restricts the class of memory + operations for which the fence instruction provides the memory ordering guarantees. + `sync_restrict` restricts `acquire` memory semantics to `shared_cluster` and + `release` memory semantics to `shared_cta` with cluster scope. + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar) + }]; + + let assemblyFormat = "attr-dict"; + let llvmBuilder = [{ + createIntrinsicCall(builder, getFenceSyncRestrictID($order)); + }]; + + let hasVerifier = 1; } -def SharedSpaceAttr : EnumAttr<NVVM_Dialect, SharedSpace, "shared_space"> { - let assemblyFormat = "`<` $value `>`"; + +def NVVM_FenceMbarrierInitOp : NVVM_Op<"fence.mbarrier.init"> { + let description = [{ + Fence operation that applies on the prior nvvm.mbarrier.init + + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar) + }]; + + let assemblyFormat = "attr-dict"; + let llvmBuilder = [{ + createIntrinsicCall(builder, llvm::Intrinsic::nvvm_fence_mbarrier_init_release_cluster); + }]; } def ProxyAlias : I32EnumAttrCase<"alias", 0, "alias">; @@ -1161,10 +1463,15 @@ def ProxyKind : I32EnumAttr<"ProxyKind", "Proxy kind", } def ProxyKindAttr : EnumAttr<NVVM_Dialect, ProxyKind, "proxy_kind"> { + let description = [{ + ProxyKind attribute represents a memory proxy which is an abstract label + applied to a method of memory access. When two memory operations use distinct + methods of memory access, they are said to be different proxies. + }]; let assemblyFormat = "`<` $value `>`"; } -def NVVM_FenceProxyOp : NVVM_PTXBuilder_Op<"fence.proxy">, +def NVVM_FenceProxyOp : NVVM_Op<"fence.proxy">, Arguments<(ins ProxyKindAttr:$kind, OptionalAttr<SharedSpaceAttr>:$space)> { let description = [{ @@ -1175,32 +1482,12 @@ def NVVM_FenceProxyOp : NVVM_PTXBuilder_Op<"fence.proxy">, }]; let assemblyFormat = "attr-dict"; - let extraClassDefinition = [{ - std::string $cppClass::getPtx() { - std::string ptx = "fence.proxy."; - ptx += stringifyProxyKind(getKind()); - if(getKind() == NVVM::ProxyKind::async_shared) - { ptx += "::"; ptx += stringifySharedSpace(getSpace().value()); } - ptx += ";"; - return ptx; - } - }]; - let hasVerifier = 1; -} -// Attrs describing the scope of the Memory Operation -def MemScopeKindCTA : I32EnumAttrCase<"CTA", 0, "cta">; -def MemScopeKindCluster : I32EnumAttrCase<"CLUSTER", 1, "cluster">; -def MemScopeKindGPU : I32EnumAttrCase<"GPU", 2, "gpu">; -def MemScopeKindSYS : I32EnumAttrCase<"SYS", 3, "sys">; + let llvmBuilder = [{ + createIntrinsicCall(builder, getFenceProxyID($kind, $space)); + }]; -def MemScopeKind : I32EnumAttr<"MemScopeKind", "NVVM Memory Scope kind", - [MemScopeKindCTA, MemScopeKindCluster, MemScopeKindGPU, MemScopeKindSYS]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::NVVM"; -} -def MemScopeKindAttr : EnumAttr<NVVM_Dialect, MemScopeKind, "mem_scope"> { - let assemblyFormat = "`<` $value `>`"; + let hasVerifier = 1; } def NVVM_FenceProxyAcquireOp : NVVM_Op<"fence.proxy.acquire">, @@ -1236,23 +1523,6 @@ def NVVM_FenceProxyAcquireOp : NVVM_Op<"fence.proxy.acquire">, let hasVerifier = 1; } -def NVVM_MembarOp : NVVM_Op<"memory.barrier">, - Arguments<(ins MemScopeKindAttr:$scope)> { - let summary = "Memory barrier operation"; - let description = [{ - `membar` operation guarantees that prior memory accesses requested by this - thread are performed at the specified `scope`, before later memory - operations requested by this thread following the membar instruction. - - [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar) - }]; - - let assemblyFormat = "$scope attr-dict"; - let llvmBuilder = [{ - createIntrinsicCall(builder, getMembarIntrinsicID($scope), {}); - }]; -} - def NVVM_FenceProxyReleaseOp : NVVM_Op<"fence.proxy.release">, Arguments<(ins MemScopeKindAttr:$scope, DefaultValuedAttr<ProxyKindAttr, @@ -1279,6 +1549,28 @@ def NVVM_FenceProxyReleaseOp : NVVM_Op<"fence.proxy.release">, let hasVerifier = 1; } +def NVVM_FenceProxySyncRestrictOp : NVVM_Op<"fence.proxy.sync_restrict">, + Arguments<(ins MemOrderKindAttr:$order, + DefaultValuedAttr<ProxyKindAttr, "ProxyKind::GENERIC">:$fromProxy, + DefaultValuedAttr<ProxyKindAttr, "ProxyKind::async">:$toProxy)> { + let summary = "Uni-directional proxy fence operation with sync_restrict"; + let description = [{ + The `nvvm.fence.proxy.sync_restrict` Op used to establish + ordering between a prior memory access performed between proxies. Currently, + the ordering is only supported between async and generic proxies. `sync_restrict` + restricts `acquire` memory semantics to `shared_cluster` and `release` memory + semantics to `shared_cta` with cluster scope. + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar) + }]; + + let assemblyFormat = "attr-dict"; + let llvmBuilder = [{ + createIntrinsicCall(builder, getFenceProxySyncRestrictID($order)); + }]; + + let hasVerifier = 1; +} + def SetMaxRegisterActionIncrease : I32EnumAttrCase<"increase", 0>; def SetMaxRegisterActionDecrease : I32EnumAttrCase<"decrease", 1>; def SetMaxRegisterAction : I32EnumAttr<"SetMaxRegisterAction", "NVVM set max register action", @@ -1301,22 +1593,6 @@ def NVVM_SetMaxRegisterOp : NVVM_Op<"setmaxregister"> { }]; } -def NVVM_FenceMbarrierInitOp : NVVM_PTXBuilder_Op<"fence.mbarrier.init"> { - let arguments = (ins ); - let description = [{ - Fence operation that applies on the prior nvvm.mbarrier.init - - [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar) - }]; - - let assemblyFormat = "attr-dict"; - let extraClassDefinition = [{ - std::string $cppClass::getPtx() { - return std::string("fence.mbarrier_init.release.cluster;"); - } - }]; -} - def ShflKindBfly : I32EnumAttrCase<"bfly", 0>; def ShflKindUp : I32EnumAttrCase<"up", 1>; def ShflKindDown : I32EnumAttrCase<"down", 2>; @@ -1417,7 +1693,7 @@ def NVVM_VoteSyncOp def NVVM_SyncWarpOp : NVVM_Op<"bar.warp.sync">, - Arguments<(ins LLVM_Type:$mask)> { + Arguments<(ins I32:$mask)> { let summary = "Warp Barrier Synchronization Op"; let description = [{ The `nvvm.bar.warp.sync` operation performs barrier synchronization for threads @@ -1476,6 +1752,133 @@ def NVVM_ElectSyncOp : NVVM_Op<"elect.sync"> }]; } +//===----------------------------------------------------------------------===// +// Permute Bytes (Prmt) +//===----------------------------------------------------------------------===// + +// Attributes for the permute operation modes supported by PTX. +def PermuteModeDefault : I32EnumAttrCase<"DEFAULT", 0, "default">; +def PermuteModeF4E : I32EnumAttrCase<"F4E", 1, "f4e">; +def PermuteModeB4E : I32EnumAttrCase<"B4E", 2, "b4e">; +def PermuteModeRC8 : I32EnumAttrCase<"RC8", 3, "rc8">; +def PermuteModeECL : I32EnumAttrCase<"ECL", 4, "ecl">; +def PermuteModeECR : I32EnumAttrCase<"ECR", 5, "ecr">; +def PermuteModeRC16 : I32EnumAttrCase<"RC16", 6, "rc16">; + +def PermuteMode : I32EnumAttr<"PermuteMode", "NVVM permute mode", + [PermuteModeDefault, PermuteModeF4E, + PermuteModeB4E, PermuteModeRC8, PermuteModeECL, + PermuteModeECR, PermuteModeRC16]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::NVVM"; +} + +def PermuteModeAttr : EnumAttr<NVVM_Dialect, PermuteMode, "permute_mode"> { + let assemblyFormat = "`<` $value `>`"; +} + +def NVVM_PermuteOp : NVVM_Op<"prmt", [Pure]>, + Results<(outs I32:$res)>, + Arguments<(ins I32:$lo, Optional<I32>:$hi, I32:$selector, + PermuteModeAttr:$mode)> { + let summary = "Permute bytes from two 32-bit registers"; + let description = [{ + The `nvvm.prmt` operation constructs a permutation of the + bytes of the first one or two operands, selecting based on + the 2 least significant bits of the final operand. + + The bytes in the first one or two source operands are numbered. + The first source operand (%lo) is numbered {b3, b2, b1, b0}, + in the case of the '``default``', '``f4e``' and '``b4e``' variants, + the second source operand (%hi) is numbered {b7, b6, b5, b4}. + + Modes: + - `default`: Index mode - each nibble in `selector` selects a byte from the 8-byte pool + - `f4e` : Forward 4 extract - extracts 4 contiguous bytes starting from position in `selector` + - `b4e` : Backward 4 extract - extracts 4 contiguous bytes in reverse order + - `rc8` : Replicate 8 - replicates the lower 8 bits across the 32-bit result + - `ecl` : Edge clamp left - clamps out-of-range indices to the leftmost valid byte + - `ecr` : Edge clamp right - clamps out-of-range indices to the rightmost valid byte + - `rc16` : Replicate 16 - replicates the lower 16 bits across the 32-bit result + + Depending on the 2 least significant bits of the %selector operand, the result + of the permutation is defined as follows: + + +------------+----------------+--------------+ + | Mode | %selector[1:0] | Output | + +------------+----------------+--------------+ + | '``f4e``' | 0 | {3, 2, 1, 0} | + | +----------------+--------------+ + | | 1 | {4, 3, 2, 1} | + | +----------------+--------------+ + | | 2 | {5, 4, 3, 2} | + | +----------------+--------------+ + | | 3 | {6, 5, 4, 3} | + +------------+----------------+--------------+ + | '``b4e``' | 0 | {5, 6, 7, 0} | + | +----------------+--------------+ + | | 1 | {6, 7, 0, 1} | + | +----------------+--------------+ + | | 2 | {7, 0, 1, 2} | + | +----------------+--------------+ + | | 3 | {0, 1, 2, 3} | + +------------+----------------+--------------+ + | '``rc8``' | 0 | {0, 0, 0, 0} | + | +----------------+--------------+ + | | 1 | {1, 1, 1, 1} | + | +----------------+--------------+ + | | 2 | {2, 2, 2, 2} | + | +----------------+--------------+ + | | 3 | {3, 3, 3, 3} | + +------------+----------------+--------------+ + | '``ecl``' | 0 | {3, 2, 1, 0} | + | +----------------+--------------+ + | | 1 | {3, 2, 1, 1} | + | +----------------+--------------+ + | | 2 | {3, 2, 2, 2} | + | +----------------+--------------+ + | | 3 | {3, 3, 3, 3} | + +------------+----------------+--------------+ + | '``ecr``' | 0 | {0, 0, 0, 0} | + | +----------------+--------------+ + | | 1 | {1, 1, 1, 0} | + | +----------------+--------------+ + | | 2 | {2, 2, 1, 0} | + | +----------------+--------------+ + | | 3 | {3, 2, 1, 0} | + +------------+----------------+--------------+ + | '``rc16``' | 0 | {1, 0, 1, 0} | + | +----------------+--------------+ + | | 1 | {3, 2, 3, 2} | + | +----------------+--------------+ + | | 2 | {1, 0, 1, 0} | + | +----------------+--------------+ + | | 3 | {3, 2, 3, 2} | + +------------+----------------+--------------+ + + [For more information, see PTX ISA] + (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prmt) + }]; + + let assemblyFormat = [{ + $mode $selector `,` $lo (`,` $hi^)? attr-dict `:` type($res) + }]; + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + static mlir::NVVM::IDArgPair + getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder); + }]; + + string llvmBuilder = [{ + auto [id, args] = NVVM::PermuteOp::getIntrinsicIDAndArgs( + *op, moduleTranslation, builder); + $res = createIntrinsicCall(builder, id, args); + }]; +} + def LoadCacheModifierCA : I32EnumAttrCase<"CA", 0, "ca">; def LoadCacheModifierCG : I32EnumAttrCase<"CG", 1, "cg">; def LoadCacheModifierCS : I32EnumAttrCase<"CS", 2, "cs">; @@ -1503,7 +1906,7 @@ def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global">, LLVM_PointerGlobal:$src, I32Attr:$size, LoadCacheModifierAttr:$modifier, - Optional<LLVM_Type>:$cpSize)> { + Optional<I32>:$cpSize)> { let assemblyFormat = "$dst `,` $src `,` $size `,` `cache` `=` $modifier (`,` $cpSize^)? attr-dict `:` type(operands)"; let hasVerifier = 1; let extraClassDeclaration = [{ @@ -1907,45 +2310,57 @@ def NVVM_ConvertF4x2ToF16x2Op : // Base class for conversions from F32x2 to FPx2 formats // (F16x2, BF16x2) -// TODO: In separate PR, add .rn and .rz rounding variants for this conversion -// as currently only support .rs rounding mode class NVVM_ConvertF32x2ToFPx2OpBase<string dstFormat, string mnemonic, Type dstType> : - NVVM_Op<mnemonic, [Pure, NVVMRequiresSMa<[100, 103]>]>, + NVVM_Op<mnemonic, [Pure]>, Results<(outs dstType:$dst)>, - Arguments<(ins F32:$src_hi, F32:$src_lo, I32:$rbits, - DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::RS">:$rnd, + Arguments<(ins F32:$src_hi, F32:$src_lo, + Optional<I32>:$random_bits, + DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd, DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat, DefaultValuedAttr<BoolAttr, "false">:$relu)> { - let summary = "Convert two F32 values to packed " # dstFormat # " with stochastic rounding (.rs)"; + let summary = "Convert two F32 values to packed " # !tolower(dstFormat) # "."; let description = [{ - Converts two F32 values to packed }] # dstFormat # [{ format using stochastic - rounding (.rs) mode with randomness provided by the `rbits` parameter. The - `relu` attribute clamps negative results to 0. The `sat` attribute determines - saturation behavior. The `src_hi` and `src_lo` parameters correspond to operands - `a` and `b` in the PTX ISA, respectively. + Converts two F32 values to packed }] # !tolower(dstFormat) # [{ format with + the specified rounding mode. The `src_hi` and `src_lo` parameters + correspond to operands `a` and `b` in the PTX ISA, respectively. + + The `random_bits` parameter is required for stochastic rounding and + provides the [random bits](}] # + !if(!eq(dstFormat, "F16x2"), + "https://docs.nvidia.com/cuda/parallel-thread-execution/#cvt-rs-rbits-layout-f16", + "https://docs.nvidia.com/cuda/parallel-thread-execution/#cvt-rs-rbits-layout-bf16") # + [{) to be used for the conversion. + + The `relu` attribute clamps negative results to 0. + + The `sat` attribute determines saturation behavior. [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) }]; - let assemblyFormat = "$src_hi `,` $src_lo `,` $rbits attr-dict `:` type($dst)"; + let assemblyFormat = "$src_hi `,` $src_lo (`,` $random_bits^)? attr-dict `:` type($dst)"; let hasVerifier = 1; let extraClassDeclaration = [{ - llvm::Intrinsic::ID getIntrinsicID(); + static NVVM::IDArgPair + getIntrinsicIDAndArgs( + NVVM::ConvertF32x2To}] # dstFormat # [{Op &op, + LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder); }]; string llvmBuilder = [{ - auto intId = op.getIntrinsicID(); - $dst = createIntrinsicCall(builder, intId, {$src_hi, $src_lo, $rbits}); + auto [intId, args] = mlir::NVVM::ConvertF32x2To}] # dstFormat # + [{Op::getIntrinsicIDAndArgs(op, moduleTranslation, builder); + $dst = createIntrinsicCall(builder, intId, args); }]; - } +} -// F32x2 -> F16x2 with stochastic rounding -def NVVM_ConvertF32x2ToF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"f16x2", "convert.f32x2.to.f16x2", VectorOfLengthAndType<[2], [F16]>>; +// F32x2 -> F16x2 +def NVVM_ConvertF32x2ToF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"F16x2", "convert.f32x2.to.f16x2", VectorOfLengthAndType<[2], [F16]>>; -// F32x2 -> BF16x2 with stochastic rounding -def NVVM_ConvertF32x2ToBF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"bf16x2", "convert.f32x2.to.bf16x2", VectorOfLengthAndType<[2], [BF16]>>; +// F32x2 -> BF16x2 +def NVVM_ConvertF32x2ToBF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"BF16x2", "convert.f32x2.to.bf16x2", VectorOfLengthAndType<[2], [BF16]>>; // Base class for stochastic rounding conversions from F32x4 to FPx4 formats // (E4M3x4, E5M2x4, E2M3x4, E3M2x4, E2M1x4) @@ -2028,6 +2443,12 @@ class WMMA_NAME_LDST<string Op, WMMA_REGS Frag, string Layout, int WithStride> { /// Generate the signature part of the mma intrinsic name. class MMA_SIGNATURE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> { list<WMMA_REGS> id_frags = !cond( + // FP8/F8F6F4 ops are identified by A,B inputs & accomulator & result type. + !or(!eq(A.ptx_elt_type, "e4m3"), + !eq(A.ptx_elt_type, "e5m2"), + !eq(A.ptx_elt_type, "e3m2"), + !eq(A.ptx_elt_type, "e2m3"), + !eq(A.ptx_elt_type, "e2m1")): [D, A, B, C], // FP16 ops are identified by accumulator & result type. !eq(A.ptx_elt_type, "f16") : [D, C], // other ops are identified by input types. @@ -2154,6 +2575,31 @@ class NVVM_MMA_OPS { list<list<WMMA_REGS>> all_mma_sync_ops = !listconcat( tf32_mma_ops, bf16_mma_ops, f64_mma_ops, fp_mma_ops, int_mma_ops, subint_mma_ops, bit_mma_ops); + + list<list<WMMA_REGS>> bf16_mma_sp_ops = MMA_OPS< + [GEOM<16,8,16>, GEOM<16,8,32>], + ["bf16"], [], ["f32"], []>.ret; + list<list<WMMA_REGS>> tf32_mma_sp_ops = MMA_OPS< + [GEOM<16,8,8>, GEOM<16,8,16>], + ["tf32"], [], ["f32"], []>.ret; + list<list<WMMA_REGS>> fp_mma_sp_ops = MMA_OPS< + [GEOM<16,8,16>, GEOM<16,8,32>], + ["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret; + list<list<WMMA_REGS>> fp8_mma_sp_ops = MMA_OPS< + [GEOM<16,8,64>], + ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"], + ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"], + ["f16", "f32"], ["f16", "f32"]>.ret; + list<list<WMMA_REGS>> subint_mma_sp_ops = MMA_OPS< + [GEOM<16,8,64>, GEOM<16,8,128>], + ["s4", "u4"], ["s4", "u4"], ["s32"], []>.ret; + list<list<WMMA_REGS>> int_mma_sp_ops = MMA_OPS< + [GEOM<16,8,32>, GEOM<16,8,64>], + ["s8", "u8"], ["s8", "u8"], ["s32"], []>.ret; + list<list<WMMA_REGS>> all_mma_sp_sync_ops = !listconcat( + bf16_mma_sp_ops, tf32_mma_sp_ops, fp_mma_sp_ops, fp8_mma_sp_ops, + subint_mma_sp_ops, int_mma_sp_ops); + } def NVVM_MMA_OPS : NVVM_MMA_OPS; @@ -2259,6 +2705,16 @@ def MMAIntOverflow : I32EnumAttr<"MMAIntOverflow", "MMA overflow options", def MMAIntOverflowAttr : EnumAttr<NVVM_Dialect, MMAIntOverflow, "mma_int_overflow"> { let assemblyFormat = "`<` $value `>`"; } +/// MMA kind types (for mixed-precision FP8 operations) +def MMAKindF8F6F4 : I32EnumAttrCase<"f8f6f4", 0>; +def MMAKind : I32EnumAttr<"MMAKind", "MMA operation kind", + [MMAKindF8F6F4]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::NVVM"; +} +def MMAKindAttr : EnumAttr<NVVM_Dialect, MMAKind, "mma_kind"> { + let assemblyFormat = "`<` $value `>`"; +} /// Attribute to hold the MMA shape def NVVM_MMAShapeAttr : NVVM_Attr<"MMAShape", "shape"> { @@ -2403,12 +2859,18 @@ def MMATypeU4 : I32EnumAttrCase<"u4", 7>; def MMATypeS4 : I32EnumAttrCase<"s4", 8>; def MMATypeBF16 : I32EnumAttrCase<"bf16", 9>; def MMATypeF64 : I32EnumAttrCase<"f64", 10>; +def MMATypeE4M3 : I32EnumAttrCase<"e4m3", 11>; +def MMATypeE5M2 : I32EnumAttrCase<"e5m2", 12>; +def MMATypeE3M2 : I32EnumAttrCase<"e3m2", 13>; +def MMATypeE2M3 : I32EnumAttrCase<"e2m3", 14>; +def MMATypeE2M1 : I32EnumAttrCase<"e2m1", 15>; def MMATypes : I32EnumAttr<"MMATypes", "NVVM MMA types", [MMATypeF16, MMATypeF32, MMATypeTF32, MMATypeBF16, MMATypeS8, MMATypeU8, MMATypeS32, MMATypeS4, MMATypeU4, - MMATypeB1, MMATypeF64]> { + MMATypeB1, MMATypeF64, + MMATypeE4M3, MMATypeE5M2, MMATypeE3M2, MMATypeE2M3, MMATypeE2M1]> { let genSpecializedAttr = 0; let cppNamespace = "::mlir::NVVM"; } @@ -2845,6 +3307,216 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> { let hasVerifier = 1; } +/// Generate enum value of the mma.sync intrinsic. +class MMA_SP_SYNC_NAME<string Metadata, string Kind, int Satfinite, + WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> { + string signature = MMA_SIGNATURE<A, B, C, D>.ret; + string id = "llvm::Intrinsic::nvvm_mma" + # "_" # !subst("::", "_", Metadata) + # "_" # A.geom + # "_row_col" + # !if(!ne(Kind, ""), !strconcat("_", !subst("::", "_", Kind)), "") + # !if(Satfinite, "_satfinite", "") + # signature; +} + +// Returns true if this combination of layout/kind/satf for MMA.SP ops is supported; +// false otherwise. +// E.g. +// if NVVM_MMA_SP_SUPPORTED<...>.ret then +// def : FOO<>; // The record will only be defined for supported ops. +// +class NVVM_MMA_SP_SUPPORTED<list<WMMA_REGS> frags, string metadata, + string kind, int satf> { + // MMA.SP ops check both layouts. + string a_type = frags[0].ptx_elt_type; + string b_type = frags[1].ptx_elt_type; + string c_type = frags[2].ptx_elt_type; + string d_type = frags[3].ptx_elt_type; + string geom = frags[0].geom; + + bit is_int = !or(!eq(a_type, "s8"), + !eq(a_type, "u8"), + !eq(a_type, "s4"), + !eq(a_type, "u4")); + + bit ret = !cond( + + // Limit satf to valid types + !and(!eq(satf, 1), + !eq(is_int, 0)): false, + + // f16/bf16/tf32 requires A and B to be the same type. + !and(!or(!eq(a_type, "f16"), + !eq(a_type, "bf16"), + !eq(a_type, "tf32")), + !ne(a_type, b_type)): false, + + // m16n8k16, m16n8k32 and m16n8k64 requires C and D to be the same type. + !and(!or(!eq(geom, "m16n8k16"), + !eq(geom, "m16n8k32"), + !eq(geom, "m16n8k64")), + !ne(c_type, d_type)): false, + + !and(!eq(kind, ""), + !or(!eq(a_type, "e3m2"), + !eq(a_type, "e2m3"), + !eq(a_type, "e2m1"), + !eq(b_type, "e3m2"), + !eq(b_type, "e2m3"), + !eq(b_type, "e2m1"))): false, + + !and(!eq(kind, ""), + !eq(geom, "m16n8k64"), + !or(!eq(c_type, "f16"), + !eq(d_type, "f16"))): false, + + !and(!ne(kind, ""), + !or(!eq(metadata, "sp"), + !ne(geom, "m16n8k64"), + !eq(is_int, 1))): false, + + // All other are OK. + true: true + ); +} + +/// Helper to create the mapping between the configuration and the mma.sp.sync +/// intrinsic enum value. +class MMA_SP_SYNC_INTR { + list<list<list<list<string>>>> cond0 = + !foreach(op, NVVM_MMA_OPS.all_mma_sp_sync_ops, + !foreach(metadata, ["sp", "sp::ordered_metadata"], + !foreach(kind, ["", "kind::f8f6f4"], + !foreach (satf, [0, 1], + !if(NVVM_MMA_SP_SUPPORTED<op, metadata, kind, satf>.ret, + "if (m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k + # " && \"" # op[0].ptx_elt_type # "\" == eltypeA" + # " && \"" # op[1].ptx_elt_type # "\" == eltypeB" + # " && \"" # op[2].ptx_elt_type # "\" == eltypeC" + # " && \"" # op[3].ptx_elt_type # "\" == eltypeD" + # " && (satf.has_value() ? " # satf # " == static_cast<int>(*satf) : true)" + # " && " # !if(!eq(metadata, "sp"), "!orderedMetadata", "orderedMetadata") # ")\n" + # " return " # + MMA_SP_SYNC_NAME<metadata, kind, satf, op[0], op[1], op[2], op[3]>.id # ";", + "") // if supported + ) // satf + ) // kind + ) // metadata + ); // all_mma_sp_sync_ops + list<list<list<string>>> f1 = !foldl([[[""]]], cond0, acc, el, + !listconcat(acc, el)); + list<list<string>> f2 = !foldl([[""]], f1, acc, el, !listconcat(acc, el)); + list<string> f3 = !foldl([""], f2, acc, el, !listconcat(acc, el)); + string id = !foldl("", f3, acc, el, acc # "\n" # el); +} + +def NVVM_MmaSpOp : NVVM_Op<"mma.sp.sync", [AttrSizedOperandSegments]> { + + let summary = "cooperative sparse matrix-multiply and accumulate"; + + let description = [{ + The `nvvm.mma.sp.sync` operation collectively performs the sparse operation + `D = matmul(A_sparse, B) + C` using all threads in a warp. + + This operation is similar to `nvvm.mma.sync` but with structured sparsity + in the A operand. The sparsity follows the 2:4 structured sparse pattern + where 2 out of every 4 elements are non-zero. + + All the threads in the warp must execute the same `mma.sp.sync` operation. + + The `sparseMetadata` operand provides the sparsity indices that indicate + which elements in the A operand are non-zero. The `sparsitySelector` + controls how the indices are distributed among threads in the warp and + should typically be 0 or 1. + + The optional `orderedMetadata` attribute specifies the metadata ordering: + - Absence (default): Uses standard sparse metadata ordering + - Presence: Uses ordered metadata (PTX ISA 8.5+, sm_90+) + + The optional `kind` attribute specifies mixed-precision modes for FP8 operations: + - `f8f6f4`: Enables e3m2, e2m3, e2m1 FP8 types and f16 accumulator (PTX ISA 8.7+, sm_90+) + - Only valid with ordered metadata and m16n8k64 shape + + The shapes, layouts, and data types follow the same constraints as the + regular `nvvm.mma.sync` operation, but the A operand contains only the + non-zero elements in compressed format. + + Example: + ```mlir + %d = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}} + : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + + // With ordered metadata: + %d = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}} + : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + ``` + }]; + + let results = (outs LLVM_AnyStruct:$res); + let arguments = (ins NVVM_MMAShapeAttr:$shape, + OptionalAttr<MMAIntOverflowAttr>:$intOverflowBehavior, + OptionalAttr<MMATypesAttr>:$multiplicandAPtxType, + OptionalAttr<MMATypesAttr>:$multiplicandBPtxType, + UnitAttr:$orderedMetadata, + OptionalAttr<MMAKindAttr>:$kind, + Variadic<LLVM_Type>:$operandA, + Variadic<LLVM_Type>:$operandB, + Variadic<LLVM_Type>:$operandC, + I32:$sparseMetadata, + I32:$sparsitySelector); + + let extraClassDeclaration = !strconcat([{ + static llvm::Intrinsic::ID getIntrinsicID( + int64_t m, int64_t n, uint64_t k, + std::optional<MMAIntOverflow> satf, + bool orderedMetadata, + std::optional<MMAKind> kind, + mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum, + mlir::NVVM::MMATypes eltypeCEnum, mlir::NVVM::MMATypes eltypeDEnum) { + llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum); + llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum); + llvm::StringRef eltypeC = stringifyEnum(eltypeCEnum); + llvm::StringRef eltypeD = stringifyEnum(eltypeDEnum); + }], + MMA_SP_SYNC_INTR<>.id, [{ + return 0; + } + + static std::optional<mlir::NVVM::MMATypes> inferOperandMMAType(Type operandElType, + bool isAccumulator); + + MMATypes accumPtxType(); + MMATypes resultPtxType(); + + static mlir::NVVM::IDArgPair + getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase& builder); + }]); + + let builders = [ + OpBuilder<(ins "Type":$resultType, "ValueRange":$operandA, + "ValueRange":$operandB, "ValueRange":$operandC, + "Value":$sparseMetadata, "Value":$sparsitySelector, + "ArrayRef<int64_t>":$shape, + "std::optional<MMAIntOverflow>":$intOverflow, + "std::optional<std::array<MMATypes, 2>>":$multiplicandPtxTypes)> + ]; + + string llvmBuilder = [{ + auto [id, args] = NVVM::MmaSpOp::getIntrinsicIDAndArgs( + *op, moduleTranslation, builder); + $res = createIntrinsicCall(builder, id, args); + }]; + + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // NVVM TMA Ops //===----------------------------------------------------------------------===// @@ -3206,12 +3878,7 @@ def NVVM_PrefetchOp : NVVM_Op<"prefetch", let llvmBuilder = [{ auto [id, args] = NVVM::PrefetchOp::getIntrinsicIDAndArgs(op, moduleTranslation, builder); - - if(op.getTensormap()) - // Overloaded intrinsic - createIntrinsicCall(builder, id, args, {args[0]->getType()}); - else - createIntrinsicCall(builder, id, args); + createIntrinsicCall(builder, id, builder.getVoidTy(), args); }]; } @@ -3372,16 +4039,17 @@ def NVVM_CpAsyncBulkTensorReduceOp : def NVVM_CpAsyncBulkGlobalToSharedClusterOp : NVVM_Op<"cp.async.bulk.shared.cluster.global", [AttrSizedOperandSegments]> { - let summary = "Async bulk copy from global memory to Shared cluster memory"; + let summary = "Async bulk copy from global to Shared {cta or cluster} memory"; let description = [{ - Initiates an asynchronous copy operation from global memory to cluster's - shared memory. + Initiates an asynchronous copy operation from global memory to shared + memory or shared_cluster memory. - The `multicastMask` operand is optional. When it is present, the Op copies + The `multicastMask` operand is optional and can be used only when the + destination is shared::cluster memory. When it is present, this Op copies data from global memory to shared memory of multiple CTAs in the cluster. Operand `multicastMask` specifies the destination CTAs in the cluster such that each bit position in the 16-bit `multicastMask` operand corresponds to - the `nvvm.read.ptx.sreg.ctaid` of the destination CTA. + the `nvvm.read.ptx.sreg.ctaid` of the destination CTA. The `l2CacheHint` operand is optional, and it is used to specify cache eviction policy that may be used during the memory access. @@ -3390,7 +4058,7 @@ def NVVM_CpAsyncBulkGlobalToSharedClusterOp : }]; let arguments = (ins - LLVM_PointerSharedCluster:$dstMem, + AnyTypeOf<[LLVM_PointerShared, LLVM_PointerSharedCluster]>:$dstMem, LLVM_PointerGlobal:$srcMem, LLVM_PointerShared:$mbar, I32:$size, @@ -3404,6 +4072,8 @@ def NVVM_CpAsyncBulkGlobalToSharedClusterOp : attr-dict `:` type($dstMem) `,` type($srcMem) }]; + let hasVerifier = 1; + let extraClassDeclaration = [{ static mlir::NVVM::IDArgPair getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, @@ -4669,6 +5339,551 @@ def NVVM_ClusterLaunchControlQueryCancelOp } //===----------------------------------------------------------------------===// +// NVVM tcgen05.mma Ops +//===----------------------------------------------------------------------===// + +def Tcgen05MMAKindF16 : I32EnumAttrCase<"F16", 0, "f16">; +def Tcgen05MMAKindTF32 : I32EnumAttrCase<"TF32", 1, "tf32">; +def Tcgen05MMAKindF8F6F4 : I32EnumAttrCase<"F8F6F4", 2, "f8f6f4">; +def Tcgen05MMAKindINT8 : I32EnumAttrCase<"I8", 3, "i8">; + +def Tcgen05MMAKind : I32EnumAttr< + "Tcgen05MMAKind", + "tcgen05 MMA Supported Types", + [Tcgen05MMAKindF8F6F4, Tcgen05MMAKindINT8, Tcgen05MMAKindF16, + Tcgen05MMAKindTF32]> { + let cppNamespace = "::mlir::NVVM"; + let genSpecializedAttr = 0; +} + +def Tcgen05MMAKindAttr : EnumAttr<NVVM_Dialect, Tcgen05MMAKind, "tcgen05_mma_kind"> { + let description = [{ + The Tcgen05MMAKind attribute describes the allowed set of types for matrix A and B in the tcgen05.mma.{sp} Op. The following are supported types for each kind: + + ``` + +-------------+--------------------------------------------+ + | Matrix Kind | supported types for A / B | + +-------------+--------------------------------------------+ + | f16 | f16, bf16 | + | tf32 | tf32 | + | f8f6f4 | e4m3, e5m2, e2m3, e3m2, e2m1 | + | i8 | unsigned 8b, signed 8b | + +-------------+--------------------------------------------+ + ``` + }]; + let assemblyFormat = "`<` $value `>`"; +} + +def Tcgen05MMACollectorOpDiscard : I32EnumAttrCase<"DISCARD", 0, "discard">; +def Tcgen05MMACollectorOpLastUse : I32EnumAttrCase<"LASTUSE", 1, "lastuse">; +def Tcgen05MMACollectorOpFill : I32EnumAttrCase<"FILL", 2, "fill">; +def Tcgen05MMACollectorOpUse : I32EnumAttrCase<"USE", 3, "use">; + +def Tcgen05MMACollectorOp : I32EnumAttr< + "Tcgen05MMACollectorOp", + "tcgen05.mma Collector Buffer Operation", + [Tcgen05MMACollectorOpDiscard, + Tcgen05MMACollectorOpLastUse, + Tcgen05MMACollectorOpFill, + Tcgen05MMACollectorOpUse]> { + let cppNamespace = "::mlir::NVVM"; + let genSpecializedAttr = 0; +} + +def Tcgen05MMACollectorOpAttr : EnumAttr<NVVM_Dialect, Tcgen05MMACollectorOp, "tcgen05_mma_collectorop"> { + let description = [{ + Tcgen05MMACollectorOp attribute specifies the collector buffer operations. + The following are the supported operations: + * discard : Release buffer after use (default) + * lastuse : Mark buffer for last use + * fill : Fill buffer + * use : Use buffer without modification + }]; + let assemblyFormat = "`<` $value `>`"; +} + +def NVVM_Tcgen05MMAOp : NVVM_Op<"tcgen05.mma", + [AttrSizedOperandSegments, + NVVMRequiresSMa<[100, 110]>]> { + let summary = "Performs MMA operation on 5th-gen tensor cores"; + + let description = [{ + The `tcgen05.mma` operation is an asynchronous tensor core instruction that + performs matrix multiplication, accumulation in a single fused operation. It + targets 5th-generation tensor cores, providing developers with fine-grained + control over execution and scheduling. + + ``` + D = A * B + (D * 2^ -scaleInputD) // if `scaleInputD` is provided + D = A * B // if `enableInputD` is false + D = A * B + D // otherwise + ``` + + where: + - A is an `M x K` matrix in tensor memory or described using shared memory descriptor + - B is a `K x N` matrix described using shared memory descriptor + - D is an `M x N` accumulator matrix in tensor memory + + The `shared memory descriptor` can be generated using `tcgen05.mma_smem_desc` Op + + - idesc is a 32-bit value representing the [Instruction Descriptor](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instruction-descriptor) + + Optional Operands: + - `scaleInputD` is an Immediate value operand used for scaling D matrix by 2 ^ (-scaleInputD). The valid range is [0, 15] + + - `disableOutputLane` is a vector mask for selective output + * vector<4 x i32> when ctaGroup is CTA_1 + * vector<8 x i32> when ctaGroup is CTA_2 + + Required Attributes: + - `kind` is a Tcgen05MMAKind attribute + + - `ctaGroup` specifies CTA group configuration + * cta_1: MMA will be performed on the current thread's CTA + * cta_2: MMA will be performed on the current thread and it's peer CTA + + Default Attributes: + - collectorOp is a Tcgen05MMACollectorOp attribute with matrix A as the collector buffer + + - `aShift` shifts the rows of the A matrix down by one row and can only be + applied if A is in tensor memory + + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma) + }]; + + let arguments = (ins + Tcgen05MMAKindAttr:$kind, + CTAGroupKindAttr:$ctaGroup, + DefaultValuedAttr<Tcgen05MMACollectorOpAttr, + "Tcgen05MMACollectorOp::DISCARD">:$collectorOp, + UnitAttr:$aShift, + LLVM_PointerTensor:$matrixD, + AnyTypeOf<[LLVM_PointerTensor, I64]>:$matrixA, + I64:$matrixB, + I32:$idesc, + I1:$enableInputD, + Optional<I64>:$scaleInputD, + Optional<FixedVectorOfLengthAndType<[4, 8], [I32]>>:$disableOutputLane + ); + + let assemblyFormat = [{ + $matrixD `,` $matrixA `,` $matrixB `,` $idesc `,` $enableInputD (`scale` `=` $scaleInputD^)? + (`mask` `=` $disableOutputLane^)? attr-dict `:` `(` type(operands) `)` + }]; + + let hasVerifier = true; + + let extraClassDeclaration = [{ + static mlir::NVVM::IDArgPair getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder); + }]; + + let llvmBuilder = [{ + auto [ID, args] = NVVM::Tcgen05MMAOp::getIntrinsicIDAndArgs( + *op, moduleTranslation, builder); + createIntrinsicCall(builder, ID, args); + }]; +} + +def NVVM_Tcgen05MMASparseOp : NVVM_Op<"tcgen05.mma.sp", + [AttrSizedOperandSegments, + NVVMRequiresSMa<[100, 110]>]> { + let summary = "Performs MMA operation with sparse A matrix on 5th-gen tensor cores"; + + let description = [{ + The `tcgen05.mma.sp` operation is an asynchronous tensor core instruction + that performs matrix multiplication, accumulation with sparse `A` matrix in + a single fused operation. It targets 5th-generation tensor cores, providing + developers with fine-grained control over execution and scheduling. + + ``` + D = A * B + (D * 2^ -scaleInputD) // if `scaleInputD` is provided + D = A * B // if `enableInputD` is false + D = A * B + D // otherwise + ``` + + where: + - A is an `M x (K / 2)` matrix in tensor memory or described using shared memory descriptor + - B is a `K x N` matrix described using shared memory descriptor + - D is an `M x N` accumulator matrix in tensor memory + - sparseMetadata located in tensor memory specifies the mapping of the `K / 2` + non-zero elements to the K elements before performing the MMA operation + + Other attributes and operands are similar to that of tcgen05.mma Op + + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma-sp) + }]; + + let arguments = (ins + Tcgen05MMAKindAttr:$kind, + CTAGroupKindAttr:$ctaGroup, + DefaultValuedAttr<Tcgen05MMACollectorOpAttr, + "Tcgen05MMACollectorOp::DISCARD">:$collectorOp, + UnitAttr:$aShift, + LLVM_PointerTensor:$matrixD, + AnyTypeOf<[LLVM_PointerTensor, I64]>:$matrixA, + I64:$matrixB, + I32:$idesc, + I1:$enableInputD, + LLVM_PointerTensor:$sparseMetadata, + Optional<I64>:$scaleInputD, + Optional<FixedVectorOfLengthAndType<[4, 8], [I32]>>:$disableOutputLane + ); + + let assemblyFormat = [{ + $matrixD `,` $matrixA `,` $matrixB `,` $idesc `,` $enableInputD `,` $sparseMetadata (`scale` `=` $scaleInputD^)? (`mask` `=` $disableOutputLane^)? attr-dict `:` `(` type(operands) `)` + }]; + + let hasVerifier = true; + + let extraClassDeclaration = [{ + static mlir::NVVM::IDArgPair getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder); + }]; + + let llvmBuilder = [{ + auto [ID, args] = NVVM::Tcgen05MMASparseOp::getIntrinsicIDAndArgs( + *op, moduleTranslation, builder); + createIntrinsicCall(builder, ID, args); + }]; +} + +def Tcgen05MMAKindMXF8F6F4 : I32EnumAttrCase<"MXF8F6F4", 0, "mxf8f6f4">; +def Tcgen05MMAKindMXF4 : I32EnumAttrCase<"MXF4", 1, "mxf4">; +def Tcgen05MMAKindMXF4NVF4 : I32EnumAttrCase<"MXF4NVF4", 2, "mxf4nvf4">; + +def Tcgen05MMABlockScaleKind : I32EnumAttr< + "Tcgen05MMABlockScaleKind", + "tcgen05.mma.block_scale supported types", + [Tcgen05MMAKindMXF8F6F4, Tcgen05MMAKindMXF4, Tcgen05MMAKindMXF4NVF4]> { + let cppNamespace = "::mlir::NVVM"; + let genSpecializedAttr = 0; +} + +def Tcgen05MMABlockScaleKindAttr : EnumAttr<NVVM_Dialect, Tcgen05MMABlockScaleKind, + "tcgen05_mma_block_scale_kind"> { + let description = [{ + The Tcgen05MMABlockScaleKind attribute describes the allowed set of types for matrix A and B in the tcgen05.mma.{sp}.block_scale Op. The following are supported types for each kind: + + ``` + +--------------+-------------------------------------------+ + | Matrix Kind | supported types for A / B | + +--------------+-------------------------------------------+ + | mxf8f6f4 | e4m3, e5m3, e2m3, e3m2, e2m1 | + | mxf4 | e2m1 | + | mxf4nvf4 | e2m1 | + +--------------+-------------------------------------------+ + ``` + }]; + let assemblyFormat = "`<` $value `>`"; +} + +def Tcgen05MMABlockScaleDefault : I32EnumAttrCase<"DEFAULT", 0, "default">; +def Tcgen05MMABlockScaleBlock16 : I32EnumAttrCase<"BLOCK16", 1, "block16">; +def Tcgen05MMABlockScaleBlock32 : I32EnumAttrCase<"BLOCK32", 2, "block32">; + +def Tcgen05MMABlockScale + : I32EnumAttr<"Tcgen05MMABlockScale", + "tcgen05.mma block scale attribute", + [Tcgen05MMABlockScaleDefault, Tcgen05MMABlockScaleBlock16, + Tcgen05MMABlockScaleBlock32]> { + let cppNamespace = "::mlir::NVVM"; + let genSpecializedAttr = 0; +} + +def Tcgen05MMABlockScaleAttr : EnumAttr<NVVM_Dialect, Tcgen05MMABlockScale, + "tcgen05_mma_block_scale"> { + let assemblyFormat = "`<` $value `>`"; +} + +def NVVM_Tcgen05MMABlockScaleOp : NVVM_Op<"tcgen05.mma.block_scale", + [NVVMRequiresSMa<[100, 110]>]> { + let summary = "Performs block scaled MMA operation on 5th-gen tensor cores"; + + let description = [{ + The `tcgen05.mma.block_scale` operation is an asynchronous tensor core instruction + that performs matrix multiplication, accumulation with block scaling in a + single fused operation. It targets 5th-generation tensor cores, providing + developers with fine-grained control over execution and scheduling. + + ``` + D = (A * scale_a) * (B * scale_b)` // if `enableInputD` is false + D = (A * scale_a) * (B * scale_b) + D` + ``` + + where: + - A is an M x (K / 2) matrix in tensor memory or described using shared memory descriptor + - B is a K x N matrix described using shared memory descriptor + - D is an M x N accumulator matrix in tensor memory + - `scale_a` and `scale_b` are matrices in tensor memory used to scale `A` and `B` respectively + + The `shared memory descriptor` can be generated using `tcgen05.mma_smem_desc` Op + + - `idesc` is a 32 bit value representing the [Instruction Descriptor](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instruction-descriptor) + + Required Attributes: + - `kind` is a Tcgen05MMABlockScaleKind attribute + + - `ctaGroup` specifies CTA group configuration + * cta_1: MMA will be performed on the current thread's CTA + * cta_2: MMA will be performed on the current thread and it's peer CTA + + Default Attributes: + - collectorOp is a Tcgen05MMACollectorOp attribute with matrix A as the collector buffer + + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma) + }]; + + let arguments = (ins + Tcgen05MMABlockScaleKindAttr:$kind, + CTAGroupKindAttr:$ctaGroup, + DefaultValuedAttr<Tcgen05MMABlockScaleAttr, + "Tcgen05MMABlockScale::DEFAULT">:$blockScale, + DefaultValuedAttr<Tcgen05MMACollectorOpAttr, + "Tcgen05MMACollectorOp::DISCARD">:$collectorOp, + LLVM_PointerTensor:$matrixD, + AnyTypeOf<[LLVM_PointerTensor, I64]>:$matrixA, + I64:$matrixB, + I32:$idesc, I1:$enableInputD, + LLVM_PointerTensor:$scaleA, + LLVM_PointerTensor:$scaleB + ); + + let assemblyFormat = [{ + $matrixD `,` $matrixA `,` $matrixB `,` $idesc `,` $enableInputD `,` $scaleA `,` $scaleB + attr-dict `:` `(` type(operands) `)` + }]; + + let hasVerifier = true; + + let extraClassDeclaration = [{ + static mlir::NVVM::IDArgPair + getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder); + }]; + + let llvmBuilder = [{ + auto [ID, args] = NVVM::Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs( + *op, moduleTranslation, builder); + createIntrinsicCall(builder, ID, args); + }]; +} + +def NVVM_Tcgen05MMASparseBlockScaleOp : NVVM_Op<"tcgen05.mma.sp.block_scale", + [NVVMRequiresSMa<[100, 110]>]> { + let summary = "Performs block scaled MMA operation with sparse A matrix on 5th-gen tensor cores"; + + let description = [{ + The `tcgen05.mma.sp.block_scale` operation is an asynchronous tensor core + instruction that performs matrix multiplication, accumulation with block + scaling, and sparse `A` matrix in a single fused operation. It targets + 5th-generation tensor cores, providing developers with fine-grained control + over execution, and scheduling. + + ``` + D = (A * scale_a) * (B * scale_b) // if `enableInputD` is specified + D = (A * scale_a) * (B * scale_b) + D // otherwise + ``` + + where: + - A is an M x (K / 2) matrix in tensor memory or described using shared memory descriptor + - B is a K x N matrix described using shared memory descriptor + - D is an M x N accumulator matrix in tensor memory + - `scale_a` and `scale_b` are matrices in tensor memory used to scale `A` and `B` respectively + + Other attributes and operands are similar to that of tcgen05.mma.block_scale Op + + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma-sp) + }]; + + let arguments = (ins + Tcgen05MMABlockScaleKindAttr:$kind, + CTAGroupKindAttr:$ctaGroup, + DefaultValuedAttr<Tcgen05MMABlockScaleAttr, + "Tcgen05MMABlockScale::DEFAULT">:$blockScale, + DefaultValuedAttr<Tcgen05MMACollectorOpAttr, + "Tcgen05MMACollectorOp::DISCARD">:$collectorOp, + LLVM_PointerTensor:$matrixD, + AnyTypeOf<[LLVM_PointerTensor, I64]>:$matrixA, + I64:$matrixB, + I32:$idesc, + I1:$enableInputD, + LLVM_PointerTensor:$sparseMetadata, + LLVM_PointerTensor:$scaleA, + LLVM_PointerTensor:$scaleB + ); + + let assemblyFormat = [{ + $matrixD `,` $matrixA `,` $matrixB `,` $idesc `,` $enableInputD `,` $sparseMetadata `,` $scaleA `,` $scaleB + attr-dict `:` `(` type(operands) `)` + }]; + + let hasVerifier = true; + + let extraClassDeclaration = [{ + static mlir::NVVM::IDArgPair + getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder); + }]; + + let llvmBuilder = [{ + auto [ID, args] = NVVM::Tcgen05MMASparseBlockScaleOp::getIntrinsicIDAndArgs( + *op, moduleTranslation, builder); + createIntrinsicCall(builder, ID, args); + }]; +} + +def Tcgen05MMACollectorBBuffer0 : I32EnumAttrCase<"B0", 0, "b0">; +def Tcgen05MMACollectorBBuffer1 : I32EnumAttrCase<"B1", 1, "b1">; +def Tcgen05MMACollectorBBuffer2 : I32EnumAttrCase<"B2", 2, "b2">; +def Tcgen05MMACollectorBBuffer3 : I32EnumAttrCase<"B3", 3, "b3">; + +def Tcgen05MMACollectorBBuffer : I32EnumAttr< + "Tcgen05MMACollectorBBuffer", + "tcgen05 MMA Collector Buffer B Attribute", + [Tcgen05MMACollectorBBuffer0, Tcgen05MMACollectorBBuffer1, Tcgen05MMACollectorBBuffer2, + Tcgen05MMACollectorBBuffer3]> { + let cppNamespace = "::mlir::NVVM"; + let genSpecializedAttr = 0; +} + +def Tcgen05MMACollectorBBufferAttr : EnumAttr<NVVM_Dialect, Tcgen05MMACollectorBBuffer, "tcgen05_mma_collectorb"> { + let assemblyFormat = "`<` $value `>`"; +} + +def NVVM_Tcgen05MMAWsOp : NVVM_Op<"tcgen05.mma.ws", + [NVVMRequiresSMa<[100, 110]>]> { + let summary = "Performs weight stationary convolution MMA operation on 5th-gen tensor cores"; + + let description = [{ + The `tcgen05.mma.ws` operation is an asynchronous tensor core instruction + that performs weight stationary convolution matrix multiplication, accumulation + in a single fused operation. It targets 5th-generation tensor cores, providing + developers with fine-grained control over execution, and scheduling. + + ``` + D = A * B` // if `enableInputD` is false + D = A * B + D` // otherwise + ``` + + where: + - A is an `M x K` matrix in tensor memory or described using shared memory descriptor + - B is a `K x N` matrix described using shared memory descriptor + - D is an `M x N` accumulator matrix in tensor memory + + The `shared memory descriptor` can be generated using `tcgen05.mma_smem_desc` Op + + - idesc is a 32-bit value representing the [Instruction Descriptor](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instruction-descriptor) + + Optional Operands: + - zeroColMask is a 64 bit value representing the [Zero-column mask descriptor](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-zero-column-mask-descriptor) + + Required Attributes: + - `kind` is a Tcgen05MMAKind attribute + + Default Valued Attributes: + - collectorBBuffer specifies collector buffer for matrix B: b0 (default), b1, b2, b3 + + - collectorOp is a Tcgen05MMACollectorOp attribute with matrix B as the collector buffer + + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma-ws) + }]; + + let arguments = (ins + Tcgen05MMAKindAttr:$kind, + DefaultValuedAttr<Tcgen05MMACollectorBBufferAttr, + "Tcgen05MMACollectorBBuffer::B0">:$collectorBBuffer, + DefaultValuedAttr<Tcgen05MMACollectorOpAttr, + "Tcgen05MMACollectorOp::DISCARD">:$collectorOp, + LLVM_PointerTensor:$matrixD, + AnyTypeOf<[LLVM_PointerTensor, I64]>:$matrixA, + I64:$matrixB, + I32:$idesc, + I1:$enableInputD, + Optional<I64>:$zeroColMask + ); + + let assemblyFormat = [{ + $matrixD `,` $matrixA `,` $matrixB `,` $idesc `,` $enableInputD (`,` $zeroColMask^)? + attr-dict `:` `(` type(operands) `)` + }]; + + let extraClassDeclaration = [{ + static mlir::NVVM::IDArgPair getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder); + }]; + + let llvmBuilder = [{ + auto [ID, args] = + NVVM::Tcgen05MMAWsOp::getIntrinsicIDAndArgs(*op, moduleTranslation, builder); + createIntrinsicCall(builder, ID, args); + }]; +} + +def NVVM_Tcgen05MMAWsSparseOp : NVVM_Op<"tcgen05.mma.ws.sp", + [NVVMRequiresSMa<[100, 110]>]> { + let summary = "Performs weight stationary convolution MMA with sparse A matrix on 5th-gen tensor cores"; + + let description = [{ + The `tcgen05.mma.ws.sp` operation is an asynchronous tensor core instruction + that performs weight stationary convolution matrix multiplication, accumulation + with sparse `A` matrix in a single fused operation. It targets 5th-generation + tensor cores, providing developers with fine-grained control over execution, + and scheduling. + + ``` + D = A * B` // if `enableInputD` is false + D = A * B + D` // otherwise + ``` + + where: + - A is an M x (K / 2) matrix in memory or descriptor format + - B is a K x N matrix + - D is an M x N accumulator matrix + - sparseMetadata located in tensor memory specifies the mapping of the `K / 2` + non-zero elements to the K elements before performing the MMA operation + + Other attributes and operands are similar to that of tcgen05.mma.ws Op + + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma-ws-sp) + }]; + + let arguments = (ins + Tcgen05MMAKindAttr:$kind, + DefaultValuedAttr<Tcgen05MMACollectorBBufferAttr, + "Tcgen05MMACollectorBBuffer::B0">:$collectorBBuffer, + DefaultValuedAttr<Tcgen05MMACollectorOpAttr, + "Tcgen05MMACollectorOp::DISCARD">:$collectorOp, + LLVM_PointerTensor:$matrixD, + AnyTypeOf<[LLVM_PointerTensor, I64]>:$matrixA, + I64:$matrixB, + I32:$idesc, + I1:$enableInputD, + LLVM_PointerTensor:$sparseMetadata, + Optional<I64>:$zeroColMask + ); + + let assemblyFormat = [{ + $matrixD `,` $matrixA `,` $matrixB `,` $idesc `,` $enableInputD `,` $sparseMetadata (`,` $zeroColMask^)? attr-dict `:` `(` type(operands) `)` + }]; + + let extraClassDeclaration = [{ + static mlir::NVVM::IDArgPair + getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder); + }]; + + let llvmBuilder = [{ + auto [ID, args] = NVVM::Tcgen05MMAWsSparseOp::getIntrinsicIDAndArgs( + *op, moduleTranslation, builder); + createIntrinsicCall(builder, ID, args); + }]; +} + +//===----------------------------------------------------------------------===// // NVVM target attribute. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 921fdf36..cd36300 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -254,7 +254,7 @@ def ROCDL_ReadlaneOp : ROCDL_IntrOp<"readlane", [], [0], [AllTypesMatch<["res", } //===----------------------------------------------------------------------===// -// Thread index and Block index +// Thread, Block and Cluster index //===----------------------------------------------------------------------===// def ROCDL_ThreadIdXOp : ROCDL_SpecialIdRegisterOp<"workitem.id.x">; @@ -265,6 +265,10 @@ def ROCDL_BlockIdXOp : ROCDL_SpecialIdRegisterOp<"workgroup.id.x">; def ROCDL_BlockIdYOp : ROCDL_SpecialIdRegisterOp<"workgroup.id.y">; def ROCDL_BlockIdZOp : ROCDL_SpecialIdRegisterOp<"workgroup.id.z">; +def ROCDL_ClusterIdXOp : ROCDL_SpecialIdRegisterOp<"cluster.id.x">; +def ROCDL_ClusterIdYOp : ROCDL_SpecialIdRegisterOp<"cluster.id.y">; +def ROCDL_ClusterIdZOp : ROCDL_SpecialIdRegisterOp<"cluster.id.z">; + def ROCDL_WavefrontSizeOp : ROCDL_SpecialIdRegisterOp<"wavefrontsize">; //===----------------------------------------------------------------------===// @@ -321,6 +325,7 @@ def ROCDL_BarrierOp : ROCDL_Op<"barrier"> { let assemblyFormat = "attr-dict"; } +def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>; def ROCDLBufferLDS : LLVM_PointerInAddressSpace<3>; def ROCDL_BarrierInitOp : ROCDL_IntrOp<"s.barrier.init", [], [], [], 0, 0, 0, 0, [1], ["id"]>, @@ -389,6 +394,15 @@ def ROCDL_GetBarrierStateOp : ROCDL_ConcreteNonMemIntrOp<"s.get.barrier.state", let assemblyFormat = "$id attr-dict `:` type($res)"; } +def ROCDL_GetNamedBarrierStateOp : ROCDL_ConcreteNonMemIntrOp<"s.get.named.barrier.state", [], 1, [], []>, + Arguments<(ins Arg<ROCDLBufferLDS, "", []>:$ptr)> { + let description = [{ + Available on gfx1250+. + }]; + let results = (outs I32:$res); + let assemblyFormat = "$ptr attr-dict `:` type($res)"; +} + def ROCDL_WaitDscntOp: ROCDL_ConcreteNonMemIntrOp<"s.wait.dscnt", [], 0, [0], ["count"]>, Arguments<(ins I16Attr:$count)> { let summary = "Wait until DSCNT is less than or equal to `count`"; @@ -582,57 +596,208 @@ def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.f //===---------------------------------------------------------------------===// // WMMA intrinsics -class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands, - list<Trait> traits = []> : - ROCDL_IntrOp<mnemonic, [0], overloadedOperands, traits, 1>, - Arguments<(ins Variadic<LLVM_Type>:$args)> { - let assemblyFormat = - "$args attr-dict `:` functional-type($args, $res)"; +class ROCDL_WMMA_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic, + [0], [0], [], 1, 0, 0, 0, [], []>, + Arguments<(ins + LLVM_ScalarOrVectorOf<AB>:$a, + LLVM_ScalarOrVectorOf<AB>:$b, + LLVM_ScalarOrVectorOf<CD>:$c)> { + let results = (outs LLVM_ScalarOrVectorOf<CD>:$res); + let assemblyFormat = [{ + $a `,` $b `,` $c attr-dict `:` functional-type(operands, $res) + }]; +} + +class ROCDL_WMMA_Opsel_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic, + [0], [1], [], 1, 0, 0, 0, [3], ["opsel"]>, + Arguments<(ins + LLVM_ScalarOrVectorOf<AB>:$a, + LLVM_ScalarOrVectorOf<AB>:$b, + LLVM_ScalarOrVectorOf<CD>:$c, + DefaultValuedAttr<I1Attr, "0">:$opsel)> { + let results = (outs LLVM_ScalarOrVectorOf<CD>:$res); + let assemblyFormat = [{ + $a `,` $b `,` $c attr-dict `:` functional-type(operands, $res) + }]; +} + +class ROCDL_WMMA_IU_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic, + [0], [1], [], 1, 0, 0, 0, [0, 2, 5], ["signA", "signB", "clamp"]>, + Arguments<(ins + DefaultValuedAttr<I1Attr, "0">:$signA, + LLVM_ScalarOrVectorOf<AB>:$a, + DefaultValuedAttr<I1Attr, "0">:$signB, + LLVM_ScalarOrVectorOf<AB>:$b, + LLVM_ScalarOrVectorOf<CD>:$c, + DefaultValuedAttr<I1Attr, "0">:$clamp)> { + let results = (outs LLVM_ScalarOrVectorOf<CD>:$res); + let assemblyFormat = [{ + $a `,` $b `,` $c attr-dict `:` functional-type(operands, $res) + }]; +} + +class ROCDL_WMMA_ModsAll_Reuse_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic, + [0], [1], [], 1, 0, 0, 0, [0, 2, 4, 6, 7], ["signA", "signB","modC","reuseA","reuseB"]>, + Arguments<(ins + DefaultValuedAttr<I1Attr, "0">:$signA, + LLVM_ScalarOrVectorOf<AB>:$a, + DefaultValuedAttr<I1Attr, "0">:$signB, + LLVM_ScalarOrVectorOf<AB>:$b, + DefaultValuedAttr<I16Attr, "0">:$modC, + LLVM_ScalarOrVectorOf<CD>:$c, + DefaultValuedAttr<I1Attr, "0">:$reuseA, + DefaultValuedAttr<I1Attr, "0">:$reuseB)> { + let results = (outs LLVM_ScalarOrVectorOf<CD>:$res); + let assemblyFormat = [{ + $a `,` $b `,` $c attr-dict `:` functional-type(operands, $res) + }]; +} + +class ROCDL_WMMA_ModsC_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic, + [0], [0], [], 1, 0, 0, 0, [2, 4, 5], ["modC","reuseA","reuseB"]>, + Arguments<(ins + LLVM_ScalarOrVectorOf<AB>:$a, + LLVM_ScalarOrVectorOf<AB>:$b, + DefaultValuedAttr<I16Attr, "0">:$modC, + LLVM_ScalarOrVectorOf<CD>:$c, + DefaultValuedAttr<I1Attr, "0">:$reuseA, + DefaultValuedAttr<I1Attr, "0">:$reuseB)> { + let results = (outs LLVM_ScalarOrVectorOf<CD>:$res); + let assemblyFormat = [{ + $a `,` $b `,` $c attr-dict `:` functional-type(operands, $res) + }]; +} + +class ROCDL_WMMA_ModsAll_Diff_IntrOp<string mnemonic, Type AB, Type C, Type D> : ROCDL_IntrOp<mnemonic, + [0], [1, 5], [], 1, 0, 0, 0, [0, 2, 4, 6, 7], ["signA", "signB","modC","reuseA","reuseB"]>, + Arguments<(ins + DefaultValuedAttr<I1Attr, "0">:$signA, + LLVM_ScalarOrVectorOf<AB>:$a, + DefaultValuedAttr<I1Attr, "0">:$signB, + LLVM_ScalarOrVectorOf<AB>:$b, + DefaultValuedAttr<I16Attr, "0">:$modC, + LLVM_ScalarOrVectorOf<C>:$c, + DefaultValuedAttr<I1Attr, "0">:$reuseA, + DefaultValuedAttr<I1Attr, "0">:$reuseB)> { + let results = (outs LLVM_ScalarOrVectorOf<D>:$res); + let assemblyFormat = [{ + $a `,` $b `,` $c attr-dict `:` functional-type(operands, $res) + }]; +} + +class ROCDL_WMMA_ModsAB_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic, + [0], [1], [], 1, 0, 0, 0, [0, 2, 5, 6], ["signA", "signB", "reuseA","reuseB"]>, + Arguments<(ins + DefaultValuedAttr<I1Attr, "0">:$signA, + LLVM_ScalarOrVectorOf<AB>:$a, + DefaultValuedAttr<I1Attr, "0">:$signB, + LLVM_ScalarOrVectorOf<AB>:$b, + LLVM_ScalarOrVectorOf<CD>:$c, + DefaultValuedAttr<I1Attr, "0">:$reuseA, + DefaultValuedAttr<I1Attr, "0">:$reuseB)> { + let results = (outs LLVM_ScalarOrVectorOf<CD>:$res); + let assemblyFormat = [{ + $a `,` $b `,` $c attr-dict `:` functional-type(operands, $res) + }]; +} + +// Overloaded operands: [1, 3] refers to LLVM intrinsic parameter positions where +// A is at position 1 and B is at position 3 (after format parameters). +class ROCDL_WMMA_Scale_IntrOp<string mnemonic, Type AB, Type CD, Type ScaleExpTy> : ROCDL_IntrOp<mnemonic, + [0], [1, 3], [], 1, 0, 0, 0, [0, 2, 4, 6, 7, 9, 10, 12, 13], + ["fmtA", "fmtB", "modC", "scaleAType", "fmtScaleA", + "scaleBType", "fmtScaleB", "reuseA", "reuseB"]>, + Arguments<(ins + DefaultValuedAttr<I32Attr, "0">:$fmtA, + LLVM_ScalarOrVectorOf<AB>:$a, + DefaultValuedAttr<I32Attr, "0">:$fmtB, + LLVM_ScalarOrVectorOf<AB>:$b, + DefaultValuedAttr<I16Attr, "0">:$modC, + LLVM_ScalarOrVectorOf<CD>:$c, + DefaultValuedAttr<I32Attr, "0">:$scaleAType, + DefaultValuedAttr<I32Attr, "0">:$fmtScaleA, + ScaleExpTy:$scaleA, + DefaultValuedAttr<I32Attr, "0">:$scaleBType, + DefaultValuedAttr<I32Attr, "0">:$fmtScaleB, + ScaleExpTy:$scaleB, + DefaultValuedAttr<I1Attr, "0">:$reuseA, + DefaultValuedAttr<I1Attr, "0">:$reuseB)> { + let results = (outs LLVM_ScalarOrVectorOf<CD>:$res); + let assemblyFormat = [{ + $a `,` $b `,` $c `,` $scaleA `,` $scaleB attr-dict `:` functional-type(operands, $res) + }]; +} + +class ROCDL_WMMA_Scale_F4_IntrOp<string mnemonic, Type AB, Type CD, Type ScaleExpTy> : ROCDL_IntrOp<mnemonic, + [0], [0, 1], [], 1, 0, 0, 0, [2, 4, 5, 7, 8, 10, 11], + ["modC", "scaleAType", "fmtScaleA", + "scaleBType", "fmtScaleB", "reuseA", "reuseB"]>, + Arguments<(ins + LLVM_ScalarOrVectorOf<AB>:$a, + LLVM_ScalarOrVectorOf<AB>:$b, + DefaultValuedAttr<I16Attr, "0">:$modC, + LLVM_ScalarOrVectorOf<CD>:$c, + DefaultValuedAttr<I32Attr, "0">:$scaleAType, + DefaultValuedAttr<I32Attr, "0">:$fmtScaleA, + ScaleExpTy:$scaleA, + DefaultValuedAttr<I32Attr, "0">:$scaleBType, + DefaultValuedAttr<I32Attr, "0">:$fmtScaleB, + ScaleExpTy:$scaleB, + DefaultValuedAttr<I1Attr, "0">:$reuseA, + DefaultValuedAttr<I1Attr, "0">:$reuseB)> { + let results = (outs LLVM_ScalarOrVectorOf<CD>:$res); + let assemblyFormat = [{ + $a `,` $b `,` $c `,` $scaleA `,` $scaleB attr-dict `:` functional-type(operands, $res) + }]; } // Available from gfx11 -def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.f16", [0]>; -def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf16", [0]>; -def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x16.f16", [0]>; -def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x16.bf16", [0]>; -def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu8", [1]>; -def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]>; +def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.f16", /*Type AB=*/F16, /*Type CD=*/F32>; +def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf16", AnyInteger, F32>; +def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_WMMA_Opsel_IntrOp<"wmma.f16.16x16x16.f16", F16, F16>; +def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_WMMA_Opsel_IntrOp<"wmma.bf16.16x16x16.bf16", AnyInteger, AnyInteger>; +def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_WMMA_IU_IntrOp<"wmma.i32.16x16x16.iu8", AnyInteger, AnyInteger>; +def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_WMMA_IU_IntrOp<"wmma.i32.16x16x16.iu4", AnyInteger, AnyInteger>; // Available from gfx12 -def ROCDL_wmma_f32_16x16x16_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>; -def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_bf8", [1]>; -def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>; -def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_fp8", [1]>; -def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x32.iu4", [1]>; +def ROCDL_wmma_f32_16x16x16_fp8_fp8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.fp8_fp8", AnyInteger, F32>; +def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.fp8_bf8", AnyInteger, F32>; +def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf8_bf8", AnyInteger, F32>; +def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf8_fp8", AnyInteger, F32>; +def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_WMMA_IU_IntrOp<"wmma.i32.16x16x32.iu4", AnyInteger, AnyInteger>; // Available from gfx1250 -def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x4.f32", [1]>; -def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.bf16", [1]>; -def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.f16", [1]>; -def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x32.f16", [1]>; -def ROCDL_wmma_bf16_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x32.bf16", [1]>; -def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16f32.16x16x32.bf16", [1,5]>; -def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_fp8", [0]>; -def ROCDL_wmma_f32_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_bf8", [0]>; -def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_fp8", [0]>; -def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_bf8", [0]>; -def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_fp8", [0]>; -def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_bf8", [0]>; -def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_fp8", [0]>; -def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_bf8", [0]>; -def ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_fp8", [0]>; -def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_bf8", [0]>; -def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_fp8", [0]>; -def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_bf8", [0]>; -def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_fp8", [0]>; -def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_bf8", [0]>; -def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_fp8", [0]>; -def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_bf8", [0]>; -def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x64.iu8", [1]>; +def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f32.16x16x4.f32", F32, F32>; +def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f32.16x16x32.bf16", BF16, F32>; +def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f32.16x16x32.f16", F16, F32>; +def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f16.16x16x32.f16", F16, F16>; +def ROCDL_wmma_bf16_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.bf16.16x16x32.bf16", BF16, BF16>; +def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Diff_IntrOp<"wmma.bf16f32.16x16x32.bf16", BF16, /*Type C=*/F32, /*Type D=*/BF16>; +def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.fp8_fp8", AnyInteger, F32>; +def ROCDL_wmma_f32_16x16x64_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.fp8_bf8", AnyInteger, F32>; +def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.bf8_fp8", AnyInteger, F32>; +def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.bf8_bf8", AnyInteger, F32>; +def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.fp8_fp8", AnyInteger, F16>; +def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.fp8_bf8", AnyInteger, F16>; +def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.bf8_fp8", AnyInteger, F16>; +def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.bf8_bf8", AnyInteger, F16>; +def ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.fp8_fp8", AnyInteger, F32>; +def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.fp8_bf8", AnyInteger, F32>; +def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.bf8_fp8", AnyInteger, F32>; +def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.bf8_bf8", AnyInteger, F32>; +def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.fp8_fp8", AnyInteger, F16>; +def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.fp8_bf8", AnyInteger, F16>; +def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.bf8_fp8", AnyInteger, F16>; +def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.bf8_bf8", AnyInteger, F16>; +def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_WMMA_ModsAB_IntrOp<"wmma.i32.16x16x64.iu8", AnyInteger, AnyInteger>; + +// Scaled wmma intrinsics (available from gfx1250) +def ROCDL_wmma_scale_f32_16x16x128_f8f6f4 : ROCDL_WMMA_Scale_IntrOp<"wmma.scale.f32.16x16x128.f8f6f4", AnyInteger, F32, I32>; +def ROCDL_wmma_scale16_f32_16x16x128_f8f6f4 : ROCDL_WMMA_Scale_IntrOp<"wmma.scale16.f32.16x16x128.f8f6f4", AnyInteger, F32, I64>; +def ROCDL_wmma_scale_f32_32x16x128_f4 : ROCDL_WMMA_Scale_F4_IntrOp<"wmma.scale.f32.32x16x128.f4", AnyInteger, F32, I32>; +def ROCDL_wmma_scale16_f32_32x16x128_f4 : ROCDL_WMMA_Scale_F4_IntrOp<"wmma.scale16.f32.32x16x128.f4", AnyInteger, F32, I64>; //===---------------------------------------------------------------------===// // LDS transpose intrinsics (available in GFX950) -def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>; - class ROCDL_LDS_Read_Tr_IntrOp<string mnemonic> : ROCDL_IntrOp<mnemonic, [1], [], [], 1, 0, 1> { dag args = (ins Arg<ROCDLBufferLDS, "", [MemRead]>:$ptr); @@ -650,6 +815,58 @@ def ROCDL_ds_read_tr8_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr8.b64">; def ROCDL_ds_read_tr6_b96 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr6.b96">; def ROCDL_ds_read_tr16_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr16.b64">; + + +//===---------------------------------------------------------------------===// +// Glb/DS load-transpose intrinsics (available in GFX1250+) + +class AddrKind<string n, int s> { + string name = n; + int space = s; +} +def GlobalAddrKind : AddrKind<"global", 1>; +def DSAddrKind : AddrKind<"ds", 3>; + +class ROCDL_TrLoadOpMeta<AddrKind kind, int inElemBits, int outElemBits> { + AddrKind addrKind = kind; + string inBits = !cast<string>(inElemBits); + string outBits = !cast<string>(outElemBits); + string inBitsEnc = !if(!eq(addrKind.space, 1), + !if(!or(!eq(inElemBits, 8), !eq(inElemBits, 16)), "", inBits), inBits); + string mnemonic = addrKind.name # ".load.tr" # inBitsEnc # ".b" # outBits; +} + +class ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta meta> : + ROCDL_IntrOp<meta.mnemonic, [1], [], [], 1, 0, 1> { + + dag args = (ins Arg<LLVM_PointerInAddressSpace<meta.addrKind.space>, "", [MemRead]>:$ptr); + let arguments = !con(args, baseArgs); + let summary = "Loads and transposes a matrix from " # meta.addrKind.name # " memory to registers (available in gfx1250+)."; + let description = [{ + Load a matrix of }] # meta.inBits # [{-bit data from the }] # meta.addrKind.name # [{ memory, + transpose data between row-major and column-major order, + and store the result into a }] # meta.outBits # [{-bit vector register. + + Available in gfx1250+. + }]; + let assemblyFormat = "$ptr attr-dict `:` qualified(type($ptr)) `->` type($res)"; + let extraClassDefinition = [{ + ::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() { + return {getPtr()}; + } + }]; +} + +def ROCDL_GlobalLoadTr4_B64 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 4, 64>>; +def ROCDL_GlobalLoadTr8_B64 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 8, 64>>; +def ROCDL_GlobalLoadTr6_B96 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 6, 96>>; +def ROCDL_GlobalLoadTr8_B128 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 16, 128>>; + +def ROCDL_DsLoadTr4_B64 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 4, 64>>; +def ROCDL_DsLoadTr8_B64 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 8, 64>>; +def ROCDL_DsLoadTr6_B96 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 6, 96>>; +def ROCDL_DsLoadTr16_B128 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 16, 128>>; + //===---------------------------------------------------------------------===// // Load to LDS intrinsic (available in GFX9 and GFX10) //===---------------------------------------------------------------------===// @@ -707,7 +924,7 @@ foreach bitsVal = [8, 32, 64, 128] in { let arguments = !con(args, baseArgs); let assemblyFormat = [{ $globalPtr `,` $ldsPtr `,` $offset `,` $aux - attr-dict `:` type($globalPtr) `,` type($ldsPtr) + attr-dict `:` qualified(type($globalPtr)) `,` qualified(type($ldsPtr)) }]; let description = [{ Asynchronously loads }] # !cast<string>(bitsVal) # [{ bits of data from a global memory pointer @@ -724,6 +941,34 @@ foreach bitsVal = [8, 32, 64, 128] in { } } +foreach bitsVal = [8, 32, 64, 128] in { + defvar bitsStr = "b" # !cast<string>(bitsVal); + def ROCDL_ClusterLoadAsyncToLDS # !toupper(bitsStr) # Op : + ROCDL_IntrOp<"cluster.load.async.to.lds." # bitsStr, [], [], [], 0, 0, 1, 0, [2, 3, 4], ["offset", "cpol", "mask"]> { + dag args = (ins Arg<ROCDLGlobalBuffer, "", [MemRead]>:$globalPtr, + Arg<ROCDLBufferLDS, "", [MemWrite]>:$ldsPtr, + I32Attr:$offset, + I32Attr:$cpol, + I32Attr:$mask); + let arguments = !con(args, baseArgs); + let assemblyFormat = [{ + $globalPtr `,` $ldsPtr `,` $offset `,` $cpol `,` $mask + attr-dict `:` qualified(type($globalPtr)) `,` qualified(type($ldsPtr)) + }]; + let description = [{ + Broadcasts memory load of }] # !cast<string>(bitsVal) # [{ bits of data for a cluster of workgroups. + + Available on gfx1250+. + }]; + + let extraClassDefinition = [{ + ::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() { + return {getGlobalPtr(), getLdsPtr()}; + } + }]; + } +} + //===---------------------------------------------------------------------===// // Tensor load/store intrinsics (available in GFX1250) //===---------------------------------------------------------------------===// @@ -1669,6 +1914,33 @@ def ROCDL_FMed3Op : ROCDL_IntrOp<"fmed3", [0], [], [Pure, AllTypesMatch<["res", } //===----------------------------------------------------------------------===// +// Math operations +//===----------------------------------------------------------------------===// + +class ROCDL_Math_IntrOp<string mnemonic, list<Trait> traits = [Pure]> : + ROCDL_IntrOp<mnemonic, [0], [], traits, 1>, + Arguments<(ins LLVM_AnyFloat:$arg)> { + let results = (outs LLVM_AnyFloat:$res); + let description = [{ + Note: In the general case, prefer the conventional `arith`, `math`, or `llvm` ops over this. + Use this ROCDL-specific operation only when you fully understand its implication and + when it is strictly necessary. This op is usually chosen when a small loss in precision is + acceptable in exchange for higher execution speed. + }]; + let assemblyFormat = + "$arg qualified(type($arg)) attr-dict `->` qualified(type($res))"; +} + +def ROCDLTanh : ROCDL_Math_IntrOp<"tanh">; +def ROCDLSin : ROCDL_Math_IntrOp<"sin">; +def ROCDLCos : ROCDL_Math_IntrOp<"cos">; +def ROCDLRcp : ROCDL_Math_IntrOp<"rcp">; +def ROCDLExp : ROCDL_Math_IntrOp<"exp">; +def ROCDLExp2 : ROCDL_Math_IntrOp<"exp2">; +def ROCDLLog : ROCDL_Math_IntrOp<"log">; +def ROCDLSqrt : ROCDL_Math_IntrOp<"sqrt">; + +//===----------------------------------------------------------------------===// // ROCDL target attribute. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml index 9aae1b8..521afc9 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -6054,9 +6054,9 @@ metadata: !LinalgOpMetadata doc: |- Fills the output tensor with the given value. - Works for arbitrary ranked output tensors since the operation performs scalar - accesses only and is thus rank polymorphic. Numeric casting is performed on - the value operand, promoting it to the same data type as the output. + Works for arbitrary ranked output tensors since the operation performs + scalar accesses only and is thus rank polymorphic. The value operand + type must match the element type of the output. implements: - LinalgFillOpInterface defines: @@ -6066,11 +6066,11 @@ structured_op: !LinalgStructuredOpConfig - !LinalgOperandDefConfig name: value kind: scalar - type_var: T1 + type_var: T - !LinalgOperandDefConfig name: O kind: output_tensor - type_var: U + type_var: T shape_map: affine_map<() -> ()> indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: @@ -6081,13 +6081,7 @@ structured_op: !LinalgStructuredOpConfig - !ScalarAssign arg: O value: !ScalarExpression - scalar_fn: - kind: type - fn_name: cast_signed - type_var: U - operands: - - !ScalarExpression - scalar_arg: value + scalar_arg: value --- !LinalgOpConfig metadata: !LinalgOpMetadata name: fill_rng_2d diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td index 7ff44c2..2754ee3 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -94,7 +94,8 @@ def Linalg_IndexOp : Linalg_Op<"index", [Pure]>, def Linalg_SoftmaxOp : Linalg_Op<"softmax", [DestinationStyleOpInterface, PredOpTrait<"input and output have same element type", TCopVTEtIsSameAs<0, 1>>, - DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>, + DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, + ["reifyResultShapes"]>, DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DeclareOpInterfaceMethods<TilingInterface, diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td index 6504ca8..784bdd8 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td @@ -35,7 +35,8 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> : DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>, DestinationStyleOpInterface, LinalgRelayoutOpInterface, ConditionallySpeculatable, NoMemoryEffect, - DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>, + DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [ + "reifyResultShapes"]>, TypesMatchWith<"result type matches type of dest", "dest", "result", "$_self">])> { @@ -108,7 +109,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [ within [0, n). - The tiled dimensions (of size `inner_tiles`) are added to the end of the result tensor in the order in which they appear, i.e. - `shape(result)[rank(result) + i] = inner_tiles[i]` for `0 <= i < k`. + `shape(result)[rank(source) + i] = inner_tiles[i]` for `0 <= i < k`. - The following relationship for the tiled dimensions holds: `shape(result)[inner_dims_pos[i]] = shape(source)[inner_dims_pos[i]] / inner_tiles[i]`, where (⌈/⌉ indicates CeilDiv). diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index dfb32a0..4948bff 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -222,7 +222,6 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [ let hasCanonicalizer = 1; let hasCustomAssemblyFormat = 1; let hasFolder = 1; - let hasVerifier = 1; } @@ -620,7 +619,6 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [ let hasCustomAssemblyFormat = 1; let hasFolder = 1; - let hasVerifier = 1; let extraClassDeclaration = structuredOpsBaseDecls # [{ /// Get the arity enum corresponding to the kind of op, e.g. if arg is diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index de07f50..9da01f3 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -103,6 +103,17 @@ std::optional<SmallVector<ReassociationIndices>> getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes); //===----------------------------------------------------------------------===// +// Convolution matcher utility +//===----------------------------------------------------------------------===// + +/// Given a linalg `op` this function returns true if it is a convolution op of +/// type `ConvOpTy` and populates `dilations` and `strides` with values inferred +/// from the indexing maps. +template <typename ConvOpTy> +bool isaConvolutionOpOfType(LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides); + +//===----------------------------------------------------------------------===// // Fusion / Tiling utilities //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td index 3be84ae..20dd452 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td @@ -19,7 +19,14 @@ def MemRef_Dialect : Dialect { manipulation ops, which are not strongly associated with any particular other dialect or domain abstraction. }]; - let dependentDialects = ["arith::ArithDialect"]; + let dependentDialects = [ + // `arith` is a dependency because it is used to materialize constants, + // and in some canonicalization patterns. + "arith::ArithDialect", + // `ub` is a dependency because `AllocaOp::getDefaultValue` can produce a + // `ub.poison` value. + "ub::UBDialect" + ]; let hasConstantMaterializer = 1; } diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index 8965302..0bf2292 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1783,7 +1783,8 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> : def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>, DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>, - DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> { + DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, + ["reifyResultShapes"]>]> { let summary = "operation to produce a memref with a higher rank."; let description = [{ The `memref.expand_shape` op produces a new view with a higher rank whose diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td index f3e40aa..c403386 100644 --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td @@ -164,6 +164,11 @@ def ResolveRankedShapeTypeResultDimsPass implement the `ReifyRankedShapedTypeOpInterface` in terms of shapes of its operands. }]; + let options = [ + Option<"errorOnPatternIterationLimit", "error-on-pattern-iteration-limit", "bool", + /*default=*/"true", + "Throw an error when pattern rewriter hits iteration limit">, + ]; let dependentDialects = [ "memref::MemRefDialect", "tensor::TensorDialect" ]; @@ -177,6 +182,11 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> { `ReifyRankedShapedTypeOpInterface` in terms of shapes of its operands. }]; + let options = [ + Option<"errorOnPatternIterationLimit", "error-on-pattern-iteration-limit", "bool", + /*default=*/"true", + "Throw an error when pattern rewriter hits iteration limit">, + ]; let dependentDialects = [ "affine::AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect" ]; diff --git a/mlir/include/mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h b/mlir/include/mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h index d9b2646..7be525e 100644 --- a/mlir/include/mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h +++ b/mlir/include/mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h @@ -58,8 +58,10 @@ namespace mlir { namespace acc { -// Forward declaration for RecipeKind enum +// Forward declarations enum class RecipeKind : uint32_t; +bool isValidSymbolUse(Operation *user, SymbolRefAttr symbol, + Operation **definingOpPtr); namespace detail { /// This class contains internal trait classes used by OpenACCSupport. @@ -79,11 +81,27 @@ struct OpenACCSupportTraits { // Used to report a case that is not supported by the implementation. virtual InFlightDiagnostic emitNYI(Location loc, const Twine &message) = 0; + + /// Check if a symbol use is valid for use in an OpenACC region. + virtual bool isValidSymbolUse(Operation *user, SymbolRefAttr symbol, + Operation **definingOpPtr) = 0; }; + /// SFINAE helpers to detect if implementation has optional methods + template <typename ImplT, typename... Args> + using isValidSymbolUse_t = + decltype(std::declval<ImplT>().isValidSymbolUse(std::declval<Args>()...)); + + template <typename ImplT> + using has_isValidSymbolUse = + llvm::is_detected<isValidSymbolUse_t, ImplT, Operation *, SymbolRefAttr, + Operation **>; + /// This class wraps a concrete OpenACCSupport implementation and forwards /// interface calls to it. This provides type erasure, allowing different /// implementation types to be used interchangeably without inheritance. + /// Methods can be optionally implemented; if not present, default behavior + /// is used. template <typename ImplT> class Model final : public Concept { public: @@ -102,6 +120,14 @@ struct OpenACCSupportTraits { return impl.emitNYI(loc, message); } + bool isValidSymbolUse(Operation *user, SymbolRefAttr symbol, + Operation **definingOpPtr) final { + if constexpr (has_isValidSymbolUse<ImplT>::value) + return impl.isValidSymbolUse(user, symbol, definingOpPtr); + else + return acc::isValidSymbolUse(user, symbol, definingOpPtr); + } + private: ImplT impl; }; @@ -154,6 +180,15 @@ public: /// unsupported case. InFlightDiagnostic emitNYI(Location loc, const Twine &message); + /// Check if a symbol use is valid for use in an OpenACC region. + /// + /// \param user The operation using the symbol. + /// \param symbol The symbol reference being used. + /// \param definingOpPtr Optional output parameter to receive the defining op. + /// \return true if the symbol use is valid, false otherwise. + bool isValidSymbolUse(Operation *user, SymbolRefAttr symbol, + Operation **definingOpPtr = nullptr); + /// Signal that this analysis should always be preserved so that /// underlying implementation registration is not lost. bool isInvalidated(const AnalysisManager::PreservedAnalyses &pa) { diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h index 05d2316..84fbf2c 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h @@ -152,6 +152,9 @@ mlir::ValueRange getDataOperands(mlir::Operation *accOp); /// Used to get a mutable range iterating over the data operands. mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp); +/// Used to get the recipe attribute from a data clause operation. +mlir::SymbolRefAttr getRecipe(mlir::Operation *accOp); + /// Used to check whether the provided `type` implements the `PointerLikeType` /// interface. inline bool isPointerLikeType(mlir::Type type) { @@ -174,7 +177,27 @@ static constexpr StringLiteral getDeclareActionAttrName() { } static constexpr StringLiteral getRoutineInfoAttrName() { - return StringLiteral("acc.routine_info"); + return RoutineInfoAttr::name; +} + +static constexpr StringLiteral getSpecializedRoutineAttrName() { + return SpecializedRoutineAttr::name; +} + +/// Used to check whether the current operation is marked with +/// `acc routine`. The operation passed in should be a function. +inline bool isAccRoutine(mlir::Operation *op) { + return op->hasAttr(mlir::acc::getRoutineInfoAttrName()); +} + +/// Used to check whether this is a specialized accelerator version of +/// `acc routine` function. +inline bool isSpecializedAccRoutine(mlir::Operation *op) { + return op->hasAttr(mlir::acc::getSpecializedRoutineAttrName()); +} + +static constexpr StringLiteral getFromDefaultClauseAttrName() { + return StringLiteral("acc.from_default"); } static constexpr StringLiteral getVarNameAttrName() { diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index 5b89f74..146dc5d 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -152,6 +152,32 @@ def OpenACC_LoopParMode : I32EnumAttr< let genSpecializedAttr = 0; } +// Parallelism level (gang/worker/vector/seq). +// GangDim1 is the default gang level (equivalent to just "gang"). +// GangDim2/GangDim3 are for gang(dim:2) and gang(dim:3). +def OpenACC_ParLevelSeq : I32EnumAttrCase<"seq", 0>; +def OpenACC_ParLevelGangDim1 : I32EnumAttrCase<"gang_dim1", 1>; +def OpenACC_ParLevelGangDim2 : I32EnumAttrCase<"gang_dim2", 2>; +def OpenACC_ParLevelGangDim3 : I32EnumAttrCase<"gang_dim3", 3>; +def OpenACC_ParLevelWorker : I32EnumAttrCase<"worker", 4>; +def OpenACC_ParLevelVector : I32EnumAttrCase<"vector", 5>; + +def OpenACC_ParLevel : I32EnumAttr<"ParLevel", + "Parallelism level (gang/worker/vector/seq)", + [OpenACC_ParLevelSeq, + OpenACC_ParLevelGangDim1, OpenACC_ParLevelGangDim2, + OpenACC_ParLevelGangDim3, + OpenACC_ParLevelWorker, OpenACC_ParLevelVector]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::acc"; +} + +def OpenACC_ParLevelAttr : EnumAttr<OpenACC_Dialect, + OpenACC_ParLevel, + "par_level"> { + let assemblyFormat = [{ ```<` $value `>` }]; +} + def OpenACC_PrivateRecipe : I32EnumAttrCase<"private_recipe", 0>; def OpenACC_FirstprivateRecipe : I32EnumAttrCase<"firstprivate_recipe", 1>; def OpenACC_ReductionRecipe : I32EnumAttrCase<"reduction_recipe", 2>; @@ -637,7 +663,8 @@ class OpenACC_DataEntryOp<string mnemonic, string clause, string extraDescriptio DefaultValuedAttr<BoolAttr, "false">:$implicit, DefaultValuedAttr<OpenACC_DataClauseModifierAttr, "mlir::acc::DataClauseModifier::none">:$modifiers, - OptionalAttr<StrAttr>:$name)); + OptionalAttr<StrAttr>:$name, + OptionalAttr<SymbolRefAttr>:$recipe)); let description = !strconcat(extraDescription, [{ Description of arguments: @@ -725,6 +752,7 @@ class OpenACC_DataEntryOp<string mnemonic, string clause, string extraDescriptio | `bounds` `(` $bounds `)` | `async` `` custom<DeviceTypeOperandsWithKeywordOnly>($asyncOperands, type($asyncOperands), $asyncOperandsDeviceType, $asyncOnly) + | `recipe` `(` custom<RecipeSym>($recipe) `)` ) `->` type($accVar) attr-dict }]; @@ -746,7 +774,7 @@ class OpenACC_DataEntryOp<string mnemonic, string clause, string extraDescriptio /*asyncOnly=*/nullptr, /*dataClause=*/nullptr, /*structured=*/$_builder.getBoolAttr(structured), /*implicit=*/$_builder.getBoolAttr(implicit), /*modifiers=*/nullptr, - /*name=*/nullptr); + /*name=*/nullptr, /*recipe=*/nullptr); }]>, OpBuilder<(ins "::mlir::Value":$var, "bool":$structured, "bool":$implicit, @@ -764,7 +792,7 @@ class OpenACC_DataEntryOp<string mnemonic, string clause, string extraDescriptio /*asyncOnly=*/nullptr, /*dataClause=*/nullptr, /*structured=*/$_builder.getBoolAttr(structured), /*implicit=*/$_builder.getBoolAttr(implicit), /*modifiers=*/nullptr, - /*name=*/$_builder.getStringAttr(name)); + /*name=*/$_builder.getStringAttr(name), /*recipe=*/nullptr); }]>, OpBuilder<(ins "::mlir::Type":$accVarType, "::mlir::Value":$var, "::mlir::Type":$varType, "::mlir::Value":$varPtrPtr, @@ -775,10 +803,27 @@ class OpenACC_DataEntryOp<string mnemonic, string clause, string extraDescriptio "::mlir::acc::DataClause":$dataClause, "bool":$structured, "bool":$implicit, "::mlir::StringAttr":$name), [{ + // Builder provided to ease transition for new data clause modifiers operand. build($_builder, $_state, accVarType, var, varType, varPtrPtr, bounds, asyncOperands, asyncOperandsDeviceType, asyncOnly, dataClause, structured, implicit, ::mlir::acc::DataClauseModifier::none, name); }]>, + OpBuilder<(ins "::mlir::Type":$accVarType, "::mlir::Value":$var, + "::mlir::Type":$varType, "::mlir::Value":$varPtrPtr, + "::mlir::ValueRange":$bounds, + "::mlir::ValueRange":$asyncOperands, + "::mlir::ArrayAttr":$asyncOperandsDeviceType, + "::mlir::ArrayAttr":$asyncOnly, + "::mlir::acc::DataClause":$dataClause, "bool":$structured, + "bool":$implicit, + "::mlir::acc::DataClauseModifier":$modifiers, + "::mlir::StringAttr":$name), + [{ + // Builder provided to simplify building after recipe operand was added. + build($_builder, $_state, accVarType, var, varType, varPtrPtr, bounds, + asyncOperands, asyncOperandsDeviceType, asyncOnly, dataClause, + structured, implicit, modifiers, name, /*recipe=*/nullptr); + }]>, ]; } @@ -1375,6 +1420,19 @@ def OpenACC_PrivateRecipeOp ::mlir::Type varType, ::llvm::StringRef varName = "", ::mlir::ValueRange bounds = {}); + + /// Creates a PrivateRecipeOp using the same variable type as an existing + /// FirstprivateRecipeOp. This is a useful in cases where we promote private variables to firstprivate by analysis + /// This function reuses the init region from a firstprivate recipe when building a private + /// recipe. Callers thus must ensure that this is semantically valid for the language + /// lowering (e.g. that private does not perform extra default initialization + /// that firstprivate intentionally omits, such as for C++ classes or Fortran + /// derived types with default initialization). + static std::optional<PrivateRecipeOp> createAndPopulate( + ::mlir::OpBuilder &builder, + ::mlir::Location loc, + ::llvm::StringRef recipeName, + ::mlir::acc::FirstprivateRecipeOp firstprivRecipe); }]; } @@ -1665,12 +1723,9 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel", Optional<I1>:$ifCond, Optional<I1>:$selfCond, UnitAttr:$selfAttr, - Variadic<AnyType>:$reductionOperands, - OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes, + Variadic<OpenACC_AnyPointerOrMappableType>:$reductionOperands, Variadic<OpenACC_AnyPointerOrMappableType>:$privateOperands, - OptionalAttr<SymbolRefArrayAttr>:$privatizationRecipes, Variadic<OpenACC_AnyPointerOrMappableType>:$firstprivateOperands, - OptionalAttr<SymbolRefArrayAttr>:$firstprivatizationRecipes, Variadic<OpenACC_AnyPointerOrMappableType>:$dataClauseOperands, OptionalAttr<DefaultValueAttr>:$defaultAttr, UnitAttr:$combined); @@ -1796,16 +1851,12 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel", `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)` | `async` `` custom<DeviceTypeOperandsWithKeywordOnly>($asyncOperands, type($asyncOperands), $asyncOperandsDeviceType, $asyncOnly) - | `firstprivate` `(` custom<SymOperandList>($firstprivateOperands, - type($firstprivateOperands), $firstprivatizationRecipes) - `)` + | `firstprivate` `(` $firstprivateOperands `:` type($firstprivateOperands) `)` | `num_gangs` `(` custom<NumGangs>($numGangs, type($numGangs), $numGangsDeviceType, $numGangsSegments) `)` | `num_workers` `(` custom<DeviceTypeOperands>($numWorkers, type($numWorkers), $numWorkersDeviceType) `)` - | `private` `(` custom<SymOperandList>( - $privateOperands, type($privateOperands), $privatizationRecipes) - `)` + | `private` `(` $privateOperands `:` type($privateOperands) `)` | `vector_length` `(` custom<DeviceTypeOperands>($vectorLength, type($vectorLength), $vectorLengthDeviceType) `)` | `wait` `` custom<WaitClause>($waitOperands, type($waitOperands), @@ -1813,9 +1864,7 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel", $waitOnly) | `self` `(` $selfCond `)` | `if` `(` $ifCond `)` - | `reduction` `(` custom<SymOperandList>( - $reductionOperands, type($reductionOperands), $reductionRecipes) - `)` + | `reduction` `(` $reductionOperands `:` type($reductionOperands) `)` ) $region attr-dict-with-keyword }]; @@ -1863,12 +1912,9 @@ def OpenACC_SerialOp : OpenACC_Op<"serial", Optional<I1>:$ifCond, Optional<I1>:$selfCond, UnitAttr:$selfAttr, - Variadic<AnyType>:$reductionOperands, - OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes, + Variadic<OpenACC_AnyPointerOrMappableType>:$reductionOperands, Variadic<OpenACC_AnyPointerOrMappableType>:$privateOperands, - OptionalAttr<SymbolRefArrayAttr>:$privatizationRecipes, Variadic<OpenACC_AnyPointerOrMappableType>:$firstprivateOperands, - OptionalAttr<SymbolRefArrayAttr>:$firstprivatizationRecipes, Variadic<OpenACC_AnyPointerOrMappableType>:$dataClauseOperands, OptionalAttr<DefaultValueAttr>:$defaultAttr, UnitAttr:$combined); @@ -1949,20 +1995,14 @@ def OpenACC_SerialOp : OpenACC_Op<"serial", `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)` | `async` `` custom<DeviceTypeOperandsWithKeywordOnly>($asyncOperands, type($asyncOperands), $asyncOperandsDeviceType, $asyncOnly) - | `firstprivate` `(` custom<SymOperandList>($firstprivateOperands, - type($firstprivateOperands), $firstprivatizationRecipes) - `)` - | `private` `(` custom<SymOperandList>( - $privateOperands, type($privateOperands), $privatizationRecipes) - `)` + | `firstprivate` `(` $firstprivateOperands `:` type($firstprivateOperands) `)` + | `private` `(` $privateOperands `:` type($privateOperands) `)` | `wait` `` custom<WaitClause>($waitOperands, type($waitOperands), $waitOperandsDeviceType, $waitOperandsSegments, $hasWaitDevnum, $waitOnly) | `self` `(` $selfCond `)` | `if` `(` $ifCond `)` - | `reduction` `(` custom<SymOperandList>( - $reductionOperands, type($reductionOperands), $reductionRecipes) - `)` + | `reduction` `(` $reductionOperands `:` type($reductionOperands) `)` ) $region attr-dict-with-keyword }]; @@ -2001,8 +2041,7 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels", corresponding `device_type` attributes must be modified as well. }]; - let arguments = (ins - Variadic<IntOrIndex>:$asyncOperands, + let arguments = (ins Variadic<IntOrIndex>:$asyncOperands, OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType, OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly, Variadic<IntOrIndex>:$waitOperands, @@ -2017,12 +2056,12 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels", OptionalAttr<DeviceTypeArrayAttr>:$numWorkersDeviceType, Variadic<IntOrIndex>:$vectorLength, OptionalAttr<DeviceTypeArrayAttr>:$vectorLengthDeviceType, - Optional<I1>:$ifCond, - Optional<I1>:$selfCond, - UnitAttr:$selfAttr, + Optional<I1>:$ifCond, Optional<I1>:$selfCond, UnitAttr:$selfAttr, + Variadic<OpenACC_AnyPointerOrMappableType>:$reductionOperands, + Variadic<OpenACC_AnyPointerOrMappableType>:$privateOperands, + Variadic<OpenACC_AnyPointerOrMappableType>:$firstprivateOperands, Variadic<OpenACC_AnyPointerOrMappableType>:$dataClauseOperands, - OptionalAttr<DefaultValueAttr>:$defaultAttr, - UnitAttr:$combined); + OptionalAttr<DefaultValueAttr>:$defaultAttr, UnitAttr:$combined); let regions = (region AnyRegion:$region); @@ -2110,6 +2149,18 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels", /// types. void addWaitOperands(MLIRContext *, bool hasDevnum, mlir::ValueRange, llvm::ArrayRef<DeviceType>); + + /// Adds a private clause variable to this operation, including its recipe. + void addPrivatization(MLIRContext *, mlir::acc::PrivateOp op, + mlir::acc::PrivateRecipeOp recipe); + /// Adds a firstprivate clause variable to this operation, including its + /// recipe. + void addFirstPrivatization(MLIRContext *, mlir::acc::FirstprivateOp op, + mlir::acc::FirstprivateRecipeOp recipe); + /// Adds a reduction clause variable to this operation, including its + /// recipe. + void addReduction(MLIRContext *, mlir::acc::ReductionOp op, + mlir::acc::ReductionRecipeOp recipe); }]; let assemblyFormat = [{ @@ -2118,10 +2169,12 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels", `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)` | `async` `` custom<DeviceTypeOperandsWithKeywordOnly>($asyncOperands, type($asyncOperands), $asyncOperandsDeviceType, $asyncOnly) + | `firstprivate` `(` $firstprivateOperands `:` type($firstprivateOperands) `)` | `num_gangs` `(` custom<NumGangs>($numGangs, type($numGangs), $numGangsDeviceType, $numGangsSegments) `)` | `num_workers` `(` custom<DeviceTypeOperands>($numWorkers, type($numWorkers), $numWorkersDeviceType) `)` + | `private` `(` $privateOperands `:` type($privateOperands) `)` | `vector_length` `(` custom<DeviceTypeOperands>($vectorLength, type($vectorLength), $vectorLengthDeviceType) `)` | `wait` `` custom<WaitClause>($waitOperands, type($waitOperands), @@ -2129,6 +2182,7 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels", $waitOnly) | `self` `(` $selfCond `)` | `if` `(` $ifCond `)` + | `reduction` `(` $reductionOperands `:` type($reductionOperands) `)` ) $region attr-dict-with-keyword }]; @@ -2593,11 +2647,8 @@ def OpenACC_LoopOp : OpenACC_Op<"loop", OptionalAttr<DeviceTypeArrayAttr>:$tileOperandsDeviceType, Variadic<OpenACC_AnyPointerOrMappableType>:$cacheOperands, Variadic<OpenACC_AnyPointerOrMappableType>:$privateOperands, - OptionalAttr<SymbolRefArrayAttr>:$privatizationRecipes, Variadic<OpenACC_AnyPointerOrMappableType>:$firstprivateOperands, - OptionalAttr<SymbolRefArrayAttr>:$firstprivatizationRecipes, Variadic<AnyType>:$reductionOperands, - OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes, OptionalAttr<OpenACC_CombinedConstructsAttr>:$combined, UnitAttr:$unstructured ); @@ -2775,16 +2826,12 @@ def OpenACC_LoopOp : OpenACC_Op<"loop", $workerNumOperandsDeviceType, $worker) | `vector` `` custom<DeviceTypeOperandsWithKeywordOnly>($vectorOperands, type($vectorOperands), $vectorOperandsDeviceType, $vector) - | `private` `(` custom<SymOperandList>( - $privateOperands, type($privateOperands), $privatizationRecipes) `)` - | `firstprivate` `(` custom<SymOperandList>($firstprivateOperands, - type($firstprivateOperands), $firstprivatizationRecipes) `)` + | `private` `(` $privateOperands `:` type($privateOperands) `)` + | `firstprivate` `(` $firstprivateOperands `:` type($firstprivateOperands) `)` | `tile` `(` custom<DeviceTypeOperandsWithSegment>($tileOperands, type($tileOperands), $tileOperandsDeviceType, $tileOperandsSegments) `)` - | `reduction` `(` custom<SymOperandList>( - $reductionOperands, type($reductionOperands), $reductionRecipes) - `)` + | `reduction` `(` $reductionOperands `:` type($reductionOperands) `)` | `cache` `(` $cacheOperands `:` type($cacheOperands) `)` ) custom<LoopControl>($region, $lowerbound, type($lowerbound), $upperbound, @@ -2834,11 +2881,8 @@ def OpenACC_LoopOp : OpenACC_Op<"loop", /*tileOperandsDeviceType=*/nullptr, /*cacheOperands=*/{}, /*privateOperands=*/{}, - /*privatizationRecipes=*/nullptr, /*firstprivateOperands=*/{}, - /*firstprivatizationRecipes=*/nullptr, /*reductionOperands=*/{}, - /*reductionRecipes=*/nullptr, /*combined=*/nullptr); }] > @@ -3241,6 +3285,18 @@ def OpenACC_RoutineOp : OpenACC_Op<"routine", [IsolatedFromAbove]> { OptionalAttr<DeviceTypeArrayAttr>:$gangDimDeviceType); let extraClassDeclaration = [{ + // 'create' function to generate an 'empty' routine. + static RoutineOp create(::mlir::OpBuilder & builder, + ::mlir::Location location, + ::llvm::StringRef sym_name, + mlir::SymbolRefAttr func_name, bool implicit) { + return create(builder, location, sym_name, func_name, /*bindIDName=*/{}, + /*bindStrName=*/{}, /*bindIdNameDeviceType=*/{}, + /*bindStrnameDeviceType=*/{}, /*worker=*/{}, /*vector=*/{}, + /*seq=*/{}, /*nohost=*/false, implicit, /*gang=*/{}, + /*gangDim=*/{}, /*gangDimDeviceType=*/{}); + } + static StringRef getGangDimKeyword() { return "dim"; } /// Return true if the op has the worker attribute for the @@ -3276,6 +3332,26 @@ def OpenACC_RoutineOp : OpenACC_Op<"routine", [IsolatedFromAbove]> { std::optional<::std::variant<mlir::SymbolRefAttr, mlir::StringAttr>> getBindNameValue(); std::optional<::std::variant<mlir::SymbolRefAttr, mlir::StringAttr>> getBindNameValue(mlir::acc::DeviceType deviceType); + + // Add an entry to the 'seq' attribute for each additional device types. + void addSeq(MLIRContext *, llvm::ArrayRef<DeviceType>); + // Add an entry to the 'vector' attribute for each additional device types. + void addVector(MLIRContext *, llvm::ArrayRef<DeviceType>); + // Add an entry to the 'worker' attribute for each additional device types. + void addWorker(MLIRContext *, llvm::ArrayRef<DeviceType>); + // Add an entry to the 'gang' attribute for each additional device type. + void addGang(MLIRContext *, llvm::ArrayRef<DeviceType>); + // Add an entry to the 'gang' attribute with a value for each additional + // device type. + void addGang(MLIRContext *, llvm::ArrayRef<DeviceType>, uint64_t); + // Add an entry to the 'bind' string-name attribute for each additional + // device_type. + void addBindStrName(MLIRContext *, llvm::ArrayRef<DeviceType>, + mlir::StringAttr); + // Add an entry to the 'bind' ID-name attribute for each additional + // device_type. + void addBindIDName(MLIRContext *, llvm::ArrayRef<DeviceType>, + mlir::SymbolRefAttr); }]; let assemblyFormat = [{ @@ -3307,6 +3383,58 @@ def RoutineInfoAttr : OpenACC_Attr<"RoutineInfo", "routine_info"> { let assemblyFormat = "`<` `[` `` $accRoutines `]` `>`"; } +def SpecializedRoutineAttr : OpenACC_Attr<"SpecializedRoutine", + "specialized_routine"> { + let summary = "Marks a specialized device version of an acc routine"; + + let description = [{ + This attribute is attached to a function that was specialized from a host + function marked with `acc.routine_info`. It captures the parallelism level, + a reference to the original `acc.routine` operation, and the original + function name (since the specialized function may be renamed). + + Example - before specialization: + ```mlir + acc.routine @routine_gang func(@foo) gang + acc.routine @routine_vector func(@foo) vector + + func.func @foo() attributes { + acc.routine_info = #acc.routine_info<[@routine_gang, @routine_vector]> + } { ... } + ``` + + After specialization, there are three functions: the original function and + two specialized versions (one per parallelism level): + ```mlir + acc.routine @routine_gang func(@foo) gang + acc.routine @routine_vector func(@foo) vector + + // Original function (unchanged) + func.func @foo() attributes { + acc.routine_info = #acc.routine_info<[@routine_gang, @routine_vector]> + } { ... } + + // Specialized for gang parallelism + func.func @foo_gang() attributes { + acc.specialized_routine = #acc.specialized_routine<@routine_gang, <gang_dim1>, "foo"> + } { ... } + + // Specialized for vector parallelism + func.func @foo_vector() attributes { + acc.specialized_routine = #acc.specialized_routine<@routine_vector, <vector>, "foo"> + } { ... } + ``` + }]; + + let parameters = (ins + "SymbolRefAttr":$routine, + "ParLevelAttr":$level, + "StringAttr":$funcName + ); + + let assemblyFormat = "`<` $routine `,` $level `,` $funcName `>`"; +} + //===----------------------------------------------------------------------===// // 2.14.1. Init Directive //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td index 054c13a..d958006 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td @@ -44,4 +44,60 @@ def PartialEntityAccessOpInterface : OpInterface<"PartialEntityAccessOpInterface ]; } +def AddressOfGlobalOpInterface : OpInterface<"AddressOfGlobalOpInterface"> { + let cppNamespace = "::mlir::acc"; + + let description = [{ + An interface for operations that compute the address of a global variable + or symbol. + }]; + + let methods = [ + InterfaceMethod<"Get the symbol reference to the global", "::mlir::SymbolRefAttr", + "getSymbol", (ins)>, + ]; +} + +def GlobalVariableOpInterface : OpInterface<"GlobalVariableOpInterface"> { + let cppNamespace = "::mlir::acc"; + + let description = [{ + An interface for operations that define global variables. This interface + provides a uniform way to query properties of global variables across + different dialects. + }]; + + let methods = [ + InterfaceMethod<"Check if the global variable is constant", "bool", + "isConstant", (ins), [{ + return false; + }]>, + InterfaceMethod<"Get the initialization region (returns nullptr if none)", + "::mlir::Region*", "getInitRegion", (ins)>, + ]; +} + +def IndirectGlobalAccessOpInterface : OpInterface<"IndirectGlobalAccessOpInterface"> { + let cppNamespace = "::mlir::acc"; + + let description = [{ + An interface for operations that indirectly access global symbols. + This interface provides a way to query which global symbols are referenced + by an operation, which is useful for tracking dependencies and performing + analysis on global variable usage. + + The symbolTable parameter is optional. If null, implementations will look up + their own symbol table. This allows callers to pass a pre-existing symbol + table for efficiency when querying multiple operations. + }]; + + let methods = [ + InterfaceMethod<"Get the symbols referenced by this operation", + "void", + "getReferencedSymbols", + (ins "::llvm::SmallVectorImpl<::mlir::SymbolRefAttr>&":$symbols, + "::mlir::SymbolTable *":$symbolTable)>, + ]; +} + #endif // OPENACC_OPS_INTERFACES diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td index d1bbc7f..3f11bf6 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td @@ -176,6 +176,50 @@ def OpenACC_PointerLikeTypeInterface : TypeInterface<"PointerLikeType"> { return false; }] >, + InterfaceMethod< + /*description=*/[{ + Generates a load operation from the pointer-like type. This dereferences + the pointer and returns the loaded value. + + The `srcPtr` parameter is the pointer to load from. If the current type is + represented in a way that it does not capture the pointee type, `valueType` + must be passed in to provide the necessary type information. + + Returns the loaded value, or an empty Value if load generation failed. + }], + /*retTy=*/"::mlir::Value", + /*methodName=*/"genLoad", + /*args=*/(ins "::mlir::OpBuilder &":$builder, + "::mlir::Location":$loc, + "::mlir::TypedValue<::mlir::acc::PointerLikeType>":$srcPtr, + "::mlir::Type":$valueType), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return {}; + }] + >, + InterfaceMethod< + /*description=*/[{ + Generates a store operation to the pointer-like type. This stores a value + to the memory location pointed to by the pointer. + + The `destPtr` parameter is the pointer to store to. The `valueToStore` + parameter is the value to be stored. The type information is derived from + the valueToStore parameter itself. + + Returns true if store was successfully generated, false otherwise. + }], + /*retTy=*/"bool", + /*methodName=*/"genStore", + /*args=*/(ins "::mlir::OpBuilder &":$builder, + "::mlir::Location":$loc, + "::mlir::Value":$valueToStore, + "::mlir::TypedValue<::mlir::acc::PointerLikeType>":$destPtr), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return false; + }] + >, ]; } diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h b/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h index 9647357..e9ce9b3 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h @@ -10,8 +10,11 @@ #define MLIR_DIALECT_OPENACC_OPENACCUTILS_H_ #include "mlir/Dialect/OpenACC/OpenACC.h" +#include "llvm/ADT/SmallVector.h" namespace mlir { +class DominanceInfo; +class PostDominanceInfo; namespace acc { /// Used to obtain the enclosing compute construct operation that contains @@ -52,6 +55,32 @@ std::string getRecipeName(mlir::acc::RecipeKind kind, mlir::Type type); // base `array` from an operation that only accesses a subarray. mlir::Value getBaseEntity(mlir::Value val); +/// Check if a symbol use is valid for use in an OpenACC region. +/// This includes looking for various attributes such as `acc.routine_info` +/// and `acc.declare` attributes. +/// \param user The operation using the symbol +/// \param symbol The symbol reference being used +/// \param definingOpPtr Optional output parameter to receive the defining op +/// \return true if the symbol use is valid, false otherwise +bool isValidSymbolUse(mlir::Operation *user, mlir::SymbolRefAttr symbol, + mlir::Operation **definingOpPtr = nullptr); + +/// Collects all data clauses that dominate the compute construct. +/// This includes data clauses from: +/// - The compute construct itself +/// - Enclosing data constructs +/// - Applicable declare directives (those that dominate and post-dominate) +/// This is used to determine if a variable is already covered by an existing +/// data clause. +/// \param computeConstructOp The compute construct operation +/// \param domInfo Dominance information +/// \param postDomInfo Post-dominance information +/// \return Vector of data clause values that dominate the compute construct +llvm::SmallVector<mlir::Value> +getDominatingDataClauses(mlir::Operation *computeConstructOp, + mlir::DominanceInfo &domInfo, + mlir::PostDominanceInfo &postDomInfo); + } // namespace acc } // namespace mlir diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td index 40ccd1f..b37cc28 100644 --- a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td @@ -63,4 +63,93 @@ def ACCImplicitData : Pass<"acc-implicit-data", "mlir::ModuleOp"> { ]; } +def ACCImplicitDeclare : Pass<"acc-implicit-declare", "mlir::ModuleOp"> { + let summary = "Applies implicit acc declare to globals referenced in compute and routine acc regions"; + let description = [{ + This pass applies implicit `acc declare` actions to global variables + referenced in OpenACC compute regions and routine functions. + + The pass performs the following actions: + + 1. Hoists address-of operations for non-constant globals out of OpenACC + regions when they can be implicitly mapped rather than declared. + + 2. Collects global symbols referenced in: + - OpenACC compute constructs (parallel, kernels, serial) + - Functions marked with acc routine + - Initialization regions of existing acc declare globals + - Private/firstprivate/reduction recipe operations + + 3. Marks collected globals with the acc.declare attribute using the + copyin data clause. + + The pass avoids unnecessary declare marking by: + - Skipping function symbols (which use acc routine instead) + - Hoisting non-constant global references that can use implicit mapping + - Only processing symbols that are not already valid in device regions + }]; + let dependentDialects = ["mlir::acc::OpenACCDialect"]; +} + +def ACCImplicitRoutine : Pass<"acc-implicit-routine", "mlir::ModuleOp"> { + let summary = "Generate implicit acc routine for functions in acc regions"; + let description = [{ + This pass implements the implicit rules described in OpenACC specification + for `Routine Directive` (OpenACC 3.4 spec, section 2.15.1). + + "If no explicit routine directive applies to a procedure whose definition + appears in the program unit being compiled, then the implementation applies + an implicit routine directive to that procedure if any of the following + conditions holds: + - The procedure is called or its address is accessed in a compute region." + + The specification further states: + "When the implementation applies an implicit routine directive to a procedure, + it must recursively apply implicit routine directives to other procedures for + which the above rules specify relevant dependencies. Such dependencies can + form a cycle, so the implementation must take care to avoid infinite recursion." + + This pass implements these requirements by: + 1. Walking through all OpenACC compute constructs and functions already + marked with `acc routine` in the module and identifying function calls + within these regions. + 2. Creating implicit `acc.routine` operations for functions that don't already + have routine declarations. + 3. Recursively walking through all existing `acc routine` and creating + implicit routine operations for function calls within these routines, + while avoiding infinite recursion through proper tracking. + }]; + let dependentDialects = ["mlir::acc::OpenACCDialect"]; + let options = [ + Option<"deviceType", "device-type", "mlir::acc::DeviceType", + "mlir::acc::DeviceType::None", + "Target device type for implicit routine generation. " + "Ensures that `acc routine` device_type clauses are " + "properly considered not just default clauses.", + [{::llvm::cl::values( + clEnumValN(mlir::acc::DeviceType::None, "none", "none"), + clEnumValN(mlir::acc::DeviceType::Host, "host", "host"), + clEnumValN(mlir::acc::DeviceType::Multicore, "multicore", "multicore"), + clEnumValN(mlir::acc::DeviceType::Nvidia, "nvidia", "nvidia"), + clEnumValN(mlir::acc::DeviceType::Radeon, "radeon", "radeon")) + }]> + ]; +} + +def ACCLegalizeSerial : Pass<"acc-legalize-serial", "mlir::func::FuncOp"> { + let summary = "Legalize OpenACC serial constructs"; + let description = [{ + This pass converts `acc.serial` constructs into `acc.parallel` constructs + with `num_gangs(1)`, `num_workers(1)`, and `vector_length(1)`. + + This transformation simplifies processing of acc regions by unifying the + handling of serial and parallel constructs. Since an OpenACC serial region + executes sequentially (like a parallel region with a single gang, worker, + and vector), this conversion is semantically equivalent while enabling code + reuse in later compilation stages. + }]; + let dependentDialects = ["mlir::acc::OpenACCDialect", + "mlir::arith::ArithDialect"]; +} + #endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index 8e43c42..05e2ee4 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -21,6 +21,7 @@ include "mlir/Dialect/OpenMP/OpenMPOpBase.td" include "mlir/IR/SymbolInterfaces.td" +include "mlir/IR/BuiltinAttributes.td" //===----------------------------------------------------------------------===// // V5.2: [6.3] `align` clause @@ -723,10 +724,9 @@ class OpenMP_LinearClauseSkip< bit description = false, bit extraClassDeclaration = false > : OpenMP_Clause<traits, arguments, assemblyFormat, description, extraClassDeclaration> { - let arguments = (ins - Variadic<AnyType>:$linear_vars, - Variadic<I32>:$linear_step_vars - ); + let arguments = (ins Variadic<AnyType>:$linear_vars, + Variadic<I32>:$linear_step_vars, + OptionalAttr<ArrayAttr>:$linear_var_types); let optAssemblyFormat = [{ `linear` `(` diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td index d9882cb..ea5489f 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td @@ -40,13 +40,15 @@ class OpenMP_EnumAttr<EnumAttrInfo enumInfo, string name> // capture_clause enum. //===----------------------------------------------------------------------===// -def CaptureClauseTo : I32EnumAttrCase<"to", 0>; -def CaptureClauseLink : I32EnumAttrCase<"link", 1>; -def CaptureClauseEnter : I32EnumAttrCase<"enter", 2>; +def CaptureClauseNone : I32EnumAttrCase<"none", 0>; +def CaptureClauseTo : I32EnumAttrCase<"to", 1>; +def CaptureClauseLink : I32EnumAttrCase<"link", 2>; +def CaptureClauseEnter : I32EnumAttrCase<"enter", 3>; def DeclareTargetCaptureClause : OpenMP_I32EnumAttr< "DeclareTargetCaptureClause", "capture clause", [ + CaptureClauseNone, CaptureClauseTo, CaptureClauseLink, CaptureClauseEnter @@ -126,6 +128,7 @@ def ClauseMapFlagsAttachAuto : I32BitEnumAttrCaseBit<"attach_auto", 15>; def ClauseMapFlagsRefPtr : I32BitEnumAttrCaseBit<"ref_ptr", 16>; def ClauseMapFlagsRefPtee : I32BitEnumAttrCaseBit<"ref_ptee", 17>; def ClauseMapFlagsRefPtrPtee : I32BitEnumAttrCaseBit<"ref_ptr_ptee", 18>; +def ClauseMapFlagsIsDevicePtr : I32BitEnumAttrCaseBit<"is_device_ptr", 19>; def ClauseMapFlags : OpenMP_BitEnumAttr< "ClauseMapFlags", @@ -149,7 +152,8 @@ def ClauseMapFlags : OpenMP_BitEnumAttr< ClauseMapFlagsAttachAuto, ClauseMapFlagsRefPtr, ClauseMapFlagsRefPtee, - ClauseMapFlagsRefPtrPtee + ClauseMapFlagsRefPtrPtee, + ClauseMapFlagsIsDevicePtr ]>; def ClauseMapFlagsAttr : OpenMP_EnumAttr<ClauseMapFlags, diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 377f1fe..bbfe805 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -1972,7 +1972,7 @@ def DeclareReductionOp : OpenMP_Op<"declare_reduction", [IsolatedFromAbove, Symbol]> { let summary = "declares a reduction kind"; let description = [{ - Declares an OpenMP reduction kind. This requires two mandatory and three + Declares an OpenMP reduction kind. This requires two mandatory and four optional regions. 1. The optional alloc region specifies how to allocate the thread-local @@ -2001,6 +2001,9 @@ def DeclareReductionOp : OpenMP_Op<"declare_reduction", [IsolatedFromAbove, allocated by the initializer region. The region has an argument that contains the value of the thread-local reduction accumulator. This will be executed after the reduction has completed. + 6. The DataPtrPtr region specifies how to access the base address of a + descriptor. This is used, in particular, for GPU reductions in order + know where partial reduction results are stored in remote lanes. Note that the MLIR type system does not allow for type-polymorphic reductions. Separate reduction declarations should be created for different @@ -2008,23 +2011,32 @@ def DeclareReductionOp : OpenMP_Op<"declare_reduction", [IsolatedFromAbove, For initializer and reduction regions, the operand to `omp.yield` must match the parent operation's results. + + * `$byref_element_type`: For by-ref reductions, we want to keep track of the + boxed/allocated type. For example, for a `real, allocatable` variable, + `real` should be stored in this attribute. + }]; let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttr:$type); + TypeAttr:$type, + OptionalAttr<TypeAttr>:$byref_element_type + ); let regions = (region MaxSizedRegion<1>:$allocRegion, AnyRegion:$initializerRegion, AnyRegion:$reductionRegion, AnyRegion:$atomicReductionRegion, - AnyRegion:$cleanupRegion); + AnyRegion:$cleanupRegion, + MaxSizedRegion<1>:$dataPtrPtrRegion); let assemblyFormat = "$sym_name `:` $type attr-dict-with-keyword " "( `alloc` $allocRegion^ )? " "`init` $initializerRegion " "`combiner` $reductionRegion " "( `atomic` $atomicReductionRegion^ )? " - "( `cleanup` $cleanupRegion^ )? "; + "( `cleanup` $cleanupRegion^ )? " + "( `data_ptr_ptr` $dataPtrPtrRegion^ )? "; let extraClassDeclaration = [{ BlockArgument getAllocMoldArg() { @@ -2056,6 +2068,10 @@ def DeclareReductionOp : OpenMP_Op<"declare_reduction", [IsolatedFromAbove, auto ®ion = getCleanupRegion(); return region.empty() ? nullptr : region.getArgument(0); } + BlockArgument getDataPtrPtrRegionArg() { + auto ®ion = getDataPtrPtrRegion(); + return region.empty() ? nullptr : region.getArgument(0); + } PointerLikeType getAccumulatorType() { if (getAtomicReductionRegion().empty()) diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index cd033c1..8bdf3e0 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -613,6 +613,11 @@ def ForallOp : SCF_Op<"forall", [ getNumDynamicControlOperands() + getRank()); } + BlockArgument getTiedBlockArgument(OpResult opResult) { + assert(opResult.getDefiningOp() == getOperation() && "invalid OpResult"); + return getBody()->getArgument(getRank() + opResult.getResultNumber()); + } + ::mlir::Value getInductionVar(int64_t idx) { return getInductionVars()[idx]; } diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index 7c735d8..0005fad 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -415,6 +415,10 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter, /// tiled in a manner that is consistent for all the passed slices. Note that /// the method replaces the uses of `candidateSlices` with the tiled and fused /// consumer value but does not delete the slice operations. +/// TODO(MaheshRavishankar): A more natural way of exposing the consumer fusion +/// is to take the consumer operation, and find the slices to use for fusion +/// by walking its operands to the `loops` and then into the body to get the +/// slices used for fusion. struct SCFFuseConsumerOfSliceResult { // Original untiled consumer operands. SmallVector<OpOperand *> origConsumerOperands; @@ -427,6 +431,14 @@ tileAndFuseConsumerOfSlices(RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices, MutableArrayRef<LoopLikeOpInterface> loops); +/// Fuse the `consumer` operation into the loop nest provided by `loops`. +/// The transformation looks for operands in the `consumer` that are defined +/// by the outermost loop of the loop nest in `loops`. The nested loop is +/// expected to have the structure of the loops generated through tiling. +FailureOr<scf::SCFFuseConsumerOfSliceResult> +tileAndFuseConsumer(RewriterBase &rewriter, Operation *consumer, + MutableArrayRef<LoopLikeOpInterface> loops); + /// Method to lower an `op` that implements the `TilingInterface` to /// loops/scalars. FailureOr<SmallVector<scf::ForOp>> diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index b628f1a..ecbbf39 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -792,7 +792,7 @@ def SPIRV_C_FPGABufferLocationINTEL : I32EnumAttrCase<"FPGAB Extension<[SPV_INTEL_fpga_buffer_location]> ]; } -def SPIRV_C_ArbitraryPrecisionFixedPointINTEL : I32EnumAttrCase<"ArbitraryPrecisionFixedPointINTEL", 5922> { +def SPIRV_C_ArbitraryPrecisionFixedPointINTEL : I32EnumAttrCase<"ArbitraryPrecisionFixedPointINTEL", 5922> { list<Availability> availability = [ Extension<[SPV_INTEL_arbitrary_precision_fixed_point]> ]; @@ -4531,6 +4531,7 @@ def SPIRV_OC_OpSelectionMerge : I32EnumAttrCase<"OpSelectionMerg def SPIRV_OC_OpLabel : I32EnumAttrCase<"OpLabel", 248>; def SPIRV_OC_OpBranch : I32EnumAttrCase<"OpBranch", 249>; def SPIRV_OC_OpBranchConditional : I32EnumAttrCase<"OpBranchConditional", 250>; +def SPIRV_OC_OpSwitch : I32EnumAttrCase<"OpSwitch", 251>; def SPIRV_OC_OpKill : I32EnumAttrCase<"OpKill", 252>; def SPIRV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>; def SPIRV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>; @@ -4681,7 +4682,7 @@ def SPIRV_OpcodeAttr : SPIRV_OC_OpAtomicAnd, SPIRV_OC_OpAtomicOr, SPIRV_OC_OpAtomicXor, SPIRV_OC_OpPhi, SPIRV_OC_OpLoopMerge, SPIRV_OC_OpSelectionMerge, SPIRV_OC_OpLabel, SPIRV_OC_OpBranch, SPIRV_OC_OpBranchConditional, - SPIRV_OC_OpKill, SPIRV_OC_OpReturn, SPIRV_OC_OpReturnValue, + SPIRV_OC_OpSwitch, SPIRV_OC_OpKill, SPIRV_OC_OpReturn, SPIRV_OC_OpReturnValue, SPIRV_OC_OpUnreachable, SPIRV_OC_OpGroupBroadcast, SPIRV_OC_OpGroupIAdd, SPIRV_OC_OpGroupFAdd, SPIRV_OC_OpGroupFMin, SPIRV_OC_OpGroupUMin, SPIRV_OC_OpGroupSMin, SPIRV_OC_OpGroupFMax, SPIRV_OC_OpGroupUMax, diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td index acb6467..27c9add 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td @@ -244,6 +244,112 @@ def SPIRV_FunctionCallOp : SPIRV_Op<"FunctionCall", [ // ----- +def SPIRV_SwitchOp : SPIRV_Op<"Switch", + [AttrSizedOperandSegments, InFunctionScope, + DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>, + Pure, Terminator]> { + let summary = [{ + Multi-way branch to one of the operand label <id>. + }]; + + let description = [{ + Selector must have a type of OpTypeInt. Selector is compared for equality to + the Target literals. + + Default must be the <id> of a label. If Selector does not equal any of the + Target literals, control flow branches to the Default label <id>. + + Target must be alternating scalar integer literals and the <id> of a label. + If Selector equals a literal, control flow branches to the following label + <id>. It is invalid for any two literal to be equal to each other. If Selector + does not equal any literal, control flow branches to the Default label <id>. + Each literal is interpreted with the type of Selector: The bit width of + Selector’s type is the width of each literal’s type. If this width is not a + multiple of 32-bits and the OpTypeInt Signedness is set to 1, the literal values + are interpreted as being sign extended. + + If Selector is an OpUndef, behavior is undefined. + + This instruction must be the last instruction in a block. + + #### Example: + + ```mlir + spirv.Switch %selector : si32, [ + default: ^bb1(%a : i32), + 0: ^bb1(%b : i32), + 1: ^bb3(%c : i32) + ] + ``` + }]; + + let arguments = (ins + SPIRV_Integer:$selector, + Variadic<AnyType>:$defaultOperands, + VariadicOfVariadic<AnyType, "case_operand_segments">:$targetOperands, + OptionalAttr<AnyIntElementsAttr>:$literals, + DenseI32ArrayAttr:$case_operand_segments + ); + + let results = (outs); + + let successors = (successor AnySuccessor:$defaultTarget, + VariadicSuccessor<AnySuccessor>:$targets); + + let builders = [ + OpBuilder<(ins "Value":$selector, + "Block *":$defaultTarget, + "ValueRange":$defaultOperands, + CArg<"ArrayRef<APInt>", "{}">:$literals, + CArg<"BlockRange", "{}">:$targets, + CArg<"ArrayRef<ValueRange>", "{}">:$targetOperands)>, + OpBuilder<(ins "Value":$selector, + "Block *":$defaultTarget, + "ValueRange":$defaultOperands, + CArg<"ArrayRef<int32_t>", "{}">:$literals, + CArg<"BlockRange", "{}">:$targets, + CArg<"ArrayRef<ValueRange>", "{}">:$targetOperands)>, + OpBuilder<(ins "Value":$selector, + "Block *":$defaultTarget, + "ValueRange":$defaultOperands, + CArg<"DenseIntElementsAttr", "{}">:$literals, + CArg<"BlockRange", "{}">:$targets, + CArg<"ArrayRef<ValueRange>", "{}">:$targetOperands)> + ]; + + let assemblyFormat = [{ + $selector `:` type($selector) `,` `[` `\n` + custom<SwitchOpCases>(ref(type($selector)),$defaultTarget, + $defaultOperands, + type($defaultOperands), + $literals, + $targets, + $targetOperands, + type($targetOperands)) + `]` + attr-dict + }]; + + let extraClassDeclaration = [{ + /// Return the operands for the target block at the given index. + OperandRange getTargetOperands(unsigned index) { + return getTargetOperands()[index]; + } + + /// Return a mutable range of operands for the target block at the + /// given index. + MutableOperandRange getTargetOperandsMutable(unsigned index) { + return getTargetOperandsMutable()[index]; + } + }]; + + let autogenSerialization = 0; + let hasVerifier = 1; +} + + +// ----- + def SPIRV_KillOp : SPIRV_Op<"Kill", [Terminator]> { let summary = [{ Deprecated (use OpTerminateInvocation or OpDemoteToHelperInvocation). diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td index 588b5eb..0b8c465 100644 --- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td @@ -16,7 +16,7 @@ def OutlineShapeComputationPass let summary = "Using shape.func to preserve shape computation"; let description = [{ This pass outlines the shape computation part in high level IR by adding - shape.func and populate corresponding mapping infoemation into + shape.func and populate corresponding mapping information into ShapeMappingAnalysis. The shape computation part is usually introduced by shape reification, and each single dynamic shape is denoted by shape.with_shape. @@ -80,12 +80,12 @@ def OutlineShapeComputationPass For the above example, the shape computation is inlined in the input IR, which is used for two values' (test.abs and test.concat) shape. And the shape - compuatation part is outlined in the output IR. + computation part is outlined in the output IR. - And the shape mapping infomation will be: + And the shape mapping information will be: ``` - // ---- Shape Mapping Infomation ----- + // ---- Shape Mapping Information ----- // - Shape for: %0 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32> :: @shape_cal_0(<block argument> of type 'tensor<?x4x?xf32>' at index: 0) // - Shape for: %1 = "test.concat"(%0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32> :: @shape_cal_1(<block argument> of type 'tensor<?x4x?xf32>' at index: 0) ``` diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h index af64370..419ecda 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h @@ -58,9 +58,10 @@ enum class SparseEmitStrategy { namespace sparse_tensor { /// Defines a strategy for loop ordering during sparse code generation. +/// See Passes.td for strategy descriptions. enum class LoopOrderingStrategy : unsigned { - kDefault, ///< Default strategy (eagerly selects last loop in topological - ///< sort). + kDefault, + kDenseOuter, }; } // namespace sparse_tensor diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td index 75e77d6..0b8562e 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td @@ -85,7 +85,9 @@ def SparseReinterpretMap : Pass<"sparse-reinterpret-map", "ModuleOp"> { "mlir::sparse_tensor::LoopOrderingStrategy::kDefault", "Set the loop ordering strategy for sparse code generation", [{llvm::cl::values( clEnumValN(mlir::sparse_tensor::LoopOrderingStrategy::kDefault, "default", - "Default strategy (eagerly selects last loop in topological sort)"))}]>, + "Default strategy (eagerly selects last loop in topological sort)"), + clEnumValN(mlir::sparse_tensor::LoopOrderingStrategy::kDenseOuter, "dense-outer", + "Prefer dense, then compressed, then singleton dimensions outermost"))}]>, ]; } diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index 2453cf5..35d2b60 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -131,7 +131,9 @@ def Tensor_CastOp : Tensor_Op<"cast", [ def Tensor_ConcatOp : Tensor_Op<"concat", [Pure, DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>, - DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> { + DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [ + "reifyResultShapes"]>, + ]> { let summary = "tensor concatenation operation"; let description = [{ The "concat" operation constructs a tensor out of a variadic list of input @@ -261,7 +263,8 @@ def Tensor_DimOp : Tensor_Op<"dim", [ def Tensor_EmptyOp : Tensor_Op<"empty", [Pure, - DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> { + DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [ + "reifyResultShapes"]>]> { let summary = "empty tensor operation"; let description = [{ @@ -358,7 +361,8 @@ def Tensor_ExtractOp : Tensor_Op<"extract", [ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice", [ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>, - DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>, + DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [ + "reifyResultShapes"]>, AttrSizedOperandSegments, Pure, OffsetSizeAndStrideOpInterface @@ -467,6 +471,11 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice", // a Range vector. OpBuilder<(ins "Value":$source, "ArrayRef<Range>":$ranges, CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>, + // Build an ExtractSliceOp with mixed static and dynamic sizes, inferred + // result type, offsets set to 0 and strides set to 1. + OpBuilder<(ins "RankedTensorType":$resultType, "Value":$source, + "ArrayRef<OpFoldResult>":$sizes, + CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>, ]; let extraClassDeclaration = extraBaseClassDeclaration # [{ @@ -486,17 +495,13 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice", /// An extract_slice result type can be inferred, when it is not /// rank-reduced, from the source type and the static representation of - /// offsets, sizes and strides. Special sentinels encode the dynamic case. + /// sizes. Special sentinels encode the dynamic case. static RankedTensorType inferResultType( RankedTensorType sourceTensorType, - ArrayRef<int64_t> staticOffsets, - ArrayRef<int64_t> staticSizes, - ArrayRef<int64_t> staticStrides); + ArrayRef<int64_t> staticSizes); static RankedTensorType inferResultType( RankedTensorType sourceTensorType, - ArrayRef<OpFoldResult> staticOffsets, - ArrayRef<OpFoldResult> staticSizes, - ArrayRef<OpFoldResult> staticStrides); + ArrayRef<OpFoldResult> staticSizes); /// If the rank is reduced (i.e. the desiredResultRank is smaller than the /// number of sizes), drop as many size 1 as needed to produce an inferred type @@ -509,15 +514,11 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice", static RankedTensorType inferCanonicalRankReducedResultType( unsigned resultRank, RankedTensorType sourceRankedTensorType, - ArrayRef<int64_t> staticOffsets, - ArrayRef<int64_t> staticSizes, - ArrayRef<int64_t> staticStrides); + ArrayRef<int64_t> staticSizes); static RankedTensorType inferCanonicalRankReducedResultType( unsigned resultRank, RankedTensorType sourceRankedTensorType, - ArrayRef<OpFoldResult> staticOffsets, - ArrayRef<OpFoldResult> staticSizes, - ArrayRef<OpFoldResult> staticStrides); + ArrayRef<OpFoldResult> staticSizes); /// Return the expected rank of each of the`static_offsets`, `static_sizes` /// and `static_strides` attributes. @@ -740,7 +741,8 @@ def Tensor_GatherOp : Tensor_Op<"gather", [ def Tensor_GenerateOp : Tensor_Op<"generate", [ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>, RecursiveMemoryEffects, - DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>, + DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [ + "reifyResultShapes"]>, SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> { let summary = "Creates a dynamically sized tensor from elements"; let description = [{ @@ -835,7 +837,8 @@ def Tensor_InsertOp : Tensor_Op<"insert", [ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>, - DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>, + DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [ + "reifyResultShapes"]>, AttrSizedOperandSegments, DestinationStyleOpInterface, Pure, @@ -932,7 +935,12 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [ // a Range vector and inferred result type. OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<Range>":$ranges, - CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)> + CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>, + // Build an InsertSliceOp with mixed static and dynamic sizes, offsets set + // to 0, strides set to 1 and inferred result type. + OpBuilder<(ins "Value":$source, "Value":$dest, + "ArrayRef<OpFoldResult>":$sizes, + CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>, ]; let extraClassDeclaration = extraBaseClassDeclaration # [{ @@ -1256,7 +1264,8 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> { def Tensor_PadOp : Tensor_Op<"pad", [ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>, - DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>, + DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [ + "reifyResultShapes"]>, AttrSizedOperandSegments, Pure, SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> { @@ -1764,7 +1773,8 @@ def Tensor_ScatterOp : Tensor_Op<"scatter", [ def Tensor_SplatOp : Tensor_Op<"splat", [ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>, - DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>, + DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, + ["reifyResultShapes"]>, Pure, TypesMatchWith<"operand type matches element type of result", "aggregate", "input", diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc index c774d87..e23827f 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc @@ -476,7 +476,16 @@ extensionComplianceMap = { {{fp32T, i64T}, SpecificationVersion::V_1_1_DRAFT}}}, {{Extension::fp8e4m3}, {{{fp8e4m3T, i32T}, SpecificationVersion::V_1_0}}}, {{Extension::fp8e5m2}, {{{fp8e5m2T, i32T}, SpecificationVersion::V_1_0}}}, - {{Extension::bf16}, {{{bf16T, i32T}, SpecificationVersion::V_1_0}}}}}, + {{Extension::bf16}, {{{bf16T, i32T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e4m3, Extension::int64}, + {{{fp8e4m3T, i64T}, SpecificationVersion::V_1_1_DRAFT}}, + allOf}, + {{Extension::fp8e5m2, Extension::int64}, + {{{fp8e5m2T, i64T}, SpecificationVersion::V_1_1_DRAFT}}, + allOf}, + {{Extension::bf16, Extension::int64}, + {{{bf16T, i64T}, SpecificationVersion::V_1_1_DRAFT}}, + allOf}}}, {"tosa.avg_pool2d", {{{Extension::int16}, {{{i16T, i16T, i16T, i32T, i16T}, SpecificationVersion::V_1_0}}}, @@ -857,15 +866,7 @@ extensionComplianceMap = { {{{fp8e5m2T, fp16T}, SpecificationVersion::V_1_0}, {{fp8e5m2T, fp32T}, SpecificationVersion::V_1_0}, {{fp16T, fp8e5m2T}, SpecificationVersion::V_1_0}, - {{fp32T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, - {{Extension::bf16, Extension::mxfp}, - {{{fp4e2m1T, bf16T}, SpecificationVersion::V_1_1_DRAFT}, - {{fp6e3m2T, bf16T}, SpecificationVersion::V_1_1_DRAFT}, - {{fp6e2m3T, bf16T}, SpecificationVersion::V_1_1_DRAFT}, - {{bf16T, fp4e2m1T}, SpecificationVersion::V_1_1_DRAFT}, - {{bf16T, fp6e3m2T}, SpecificationVersion::V_1_1_DRAFT}, - {{bf16T, fp6e2m3T}, SpecificationVersion::V_1_1_DRAFT}}, - allOf}}}, + {{fp32T, fp8e5m2T}, SpecificationVersion::V_1_0}}}}}, {"tosa.cast_from_block_scaled", {{{Extension::bf16, Extension::mxfp}, {{{fp4e2m1T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}, diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td index 5b595dd..cc23955 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -240,6 +240,7 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic, // DOUBLEROUND : Adds double rounding support to the RESCALE operator. // INEXACTROUND : Adds inexact rounding support to the RESCALE operator. // DYNAMIC : Removes all Compile Time Constant state for CTC inputs. +// MXFP : Microscaling formats. //===----------------------------------------------------------------------===// def Tosa_NONE : I32EnumAttrCase<"none", 0>; diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 467dba3..370ce8c 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -826,12 +826,12 @@ def Tosa_IntDivOp : Tosa_ElementwiseOp<"intdiv", [SameOperandsAndResultElementTy }]; let arguments = (ins - Tosa_Int32Tensor:$input1, - Tosa_Int32Tensor:$input2 + Tosa_Int32Or64Tensor:$input1, + Tosa_Int32Or64Tensor:$input2 ); let results = (outs - Tosa_Int32Tensor:$output + Tosa_Int32Or64Tensor:$output ); list<Availability> availability = [ @@ -2219,7 +2219,8 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> { // Operator: transpose //===----------------------------------------------------------------------===// def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose", - [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>, + [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface , + ["reifyResultShapes"]>, AllElementTypesMatch<["input1", "output"]>]> { let summary = "Transpose operator."; @@ -2270,7 +2271,7 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> { let arguments = (ins Tosa_Tensor3D:$values, - Tosa_Int32Tensor2D:$indices + Tosa_IndexTensor2D:$indices ); let results = (outs @@ -2307,7 +2308,7 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> { let arguments = (ins Tosa_Tensor3D:$values_in, - Tosa_Int32Tensor2D:$indices, + Tosa_IndexTensor2D:$indices, Tosa_Tensor3D:$input ); @@ -2463,7 +2464,7 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure, SameOperandsAndResultShape, list<Availability> availability = [ Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, - Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16, Tosa_EXT_MXFP, Tosa_EXT_INT64]>, + Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16, Tosa_EXT_INT64]>, ]; let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td index 414b51b..266a9e3 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -202,10 +202,8 @@ def Tosa_Tensor1Dto6D : AnyTypeOf<[ def Tosa_TensorUpto4D : AnyTypeOf<[ Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>; -def Tosa_Int32TensorUpto4D : AnyTypeOf<[ - Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>; -def Tosa_Int32Tensor2D : AnyTypeOf<[ - Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [2]>]>; +def Tosa_IndexTensor2D : AnyTypeOf<[ + Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32, Tosa_Int64], [2]>]>; def Tosa_TensorAtLeast1D : AnyTypeOf<[ Tosa_UnrankedTensor, TosaRankedTensorOf<[Tosa_AnyNumber], [AtLeastRankOne]>], "tosa-conformant tensor of at least rank 1", "::mlir::TensorType">; diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td index 14b00b0..12f5202 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -105,6 +105,15 @@ def TosaReduceTransposes : Pass<"tosa-reduce-transposes", "func::FuncOp"> { }]; } +def TosaArithConstantToTosaConstPass + : Pass<"tosa-arith-const-to-tosa-const", "func::FuncOp"> { + let summary = "Convert tensor arith.constant operations into tosa.const"; + let description = [{ + Normalizes tensor-valued arith.constant operations into tosa.const so that + subsequent TOSA passes operate on a consistent representation of constants. + }]; +} + def TosaConvertIntegerTypeToSignless : Pass<"tosa-convert-integer-type-to-signless", "func::FuncOp"> { let summary = "Convert integer types to signless"; let description = [{ @@ -166,4 +175,27 @@ def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> { ]; } +def TosaNarrowI64ToI32Pass : Pass<"tosa-narrow-i64-to-i32", "func::FuncOp"> { + let summary = "Narrow I64 TOSA operations to I32"; + let description = [{ + This pass narrows TOSA operations with 64-bit integer tensor types to + 32-bit integer tensor types. This can be useful for backends that do not + support the EXT-INT64 extension of TOSA. + }]; + + let options = [ + Option<"aggressiveRewrite", "aggressive-rewrite", "bool", "false", + "If enabled, all TOSA operations are rewritten, regardless or whether the narrowing" + "is safe. This option may lead to data loss if not used carefully.">, + Option<"convertFunctionBoundaries", "convert-function-boundaries", "bool", "false", + "If enabled, the pass will convert function I/O types as well. Otherwise casts will" + "be inserted at the I/O boundaries."> + ]; + + let dependentDialects = [ + "func::FuncDialect", + "tosa::TosaDialect", + ]; +} + #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h index 9d9a934..e9ad786 100644 --- a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h +++ b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h @@ -88,6 +88,8 @@ TypeAttr buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDType, IntegerAttr quantBits, int filterQuantDim, bool isSigned, BoolAttr narrowRange); +Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType); + } // namespace tosa } // namespace mlir diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.td b/mlir/include/mlir/Dialect/UB/IR/UBOps.td index c400a2e..1bff39a 100644 --- a/mlir/include/mlir/Dialect/UB/IR/UBOps.td +++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.td @@ -66,4 +66,24 @@ def PoisonOp : UB_Op<"poison", [ConstantLike, Pure]> { let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// UnreachableOp +//===----------------------------------------------------------------------===// + +def UnreachableOp : UB_Op<"unreachable", [Terminator]> { + let summary = "Unreachable operation."; + let description = [{ + The `unreachable` operation triggers immediate undefined behavior if + executed. + + Example: + + ``` + ub.unreachable + ``` + }]; + + let assemblyFormat = "attr-dict"; +} + #endif // MLIR_DIALECT_UB_IR_UBOPS_TD diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 43172ff..d8ed46c 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2160,25 +2160,25 @@ def Vector_GatherOp : ]; } -def Vector_ScatterOp : - Vector_Op<"scatter", [ - DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>, - DeclareOpInterfaceMethods<AlignmentAttrOpInterface> - ]>, - Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base, - Variadic<Index>:$offsets, - VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices, - VectorOfNonZeroRankOf<[I1]>:$mask, - AnyVectorOfNonZeroRank:$valueToStore, - OptionalAttr<IntValidAlignment<I64Attr>>: $alignment)> { +def Vector_ScatterOp + : Vector_Op<"scatter", + [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>, + DeclareOpInterfaceMethods<AlignmentAttrOpInterface>]>, + Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemWrite]>:$base, + Variadic<Index>:$offsets, + VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices, + VectorOfNonZeroRankOf<[I1]>:$mask, + AnyVectorOfNonZeroRank:$valueToStore, + OptionalAttr<IntValidAlignment<I64Attr>>:$alignment)>, + Results<(outs Optional<AnyRankedTensor>:$result)> { let summary = [{ - scatters elements from a vector into memory as defined by an index vector + scatters elements from a vector into memory or ranked tensor as defined by an index vector and a mask vector }]; let description = [{ - The scatter operation stores elements from a n-D vector into memory as + The scatter operation stores elements from a n-D vector into memory or ranked tensor as defined by a base with indices and an additional n-D index vector, but only if the corresponding bit in a n-D mask vector is set. Otherwise, no action is taken for that element. Informally the semantics are: @@ -2221,31 +2221,28 @@ def Vector_ScatterOp : }]; let extraClassDeclaration = [{ - MemRefType getMemRefType() { return getBase().getType(); } + ShapedType getBaseType() { return getBase().getType(); } VectorType getIndexVectorType() { return getIndices().getType(); } VectorType getMaskVectorType() { return getMask().getType(); } VectorType getVectorType() { return getValueToStore().getType(); } }]; - let assemblyFormat = - "$base `[` $offsets `]` `[` $indices `]` `,` " - "$mask `,` $valueToStore attr-dict `:` type($base) `,` " - "type($indices) `,` type($mask) `,` type($valueToStore)"; + let assemblyFormat = "$base `[` $offsets `]` `[` $indices `]` `,` " + "$mask `,` $valueToStore attr-dict `:` type($base) `,` " + "type($indices) `,` type($mask) `,` " + "type($valueToStore) (`->` type($result)^)?"; let hasCanonicalizer = 1; let hasVerifier = 1; - let builders = [ - OpBuilder<(ins "Value":$base, - "ValueRange":$indices, - "Value":$index_vec, - "Value":$mask, - "Value":$valueToStore, - CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">: $alignment), [{ - return build($_builder, $_state, base, indices, index_vec, mask, valueToStore, + let builders = [OpBuilder< + (ins "Type":$resultType, "Value":$base, "ValueRange":$indices, + "Value":$index_vec, "Value":$mask, "Value":$valueToStore, + CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), + [{ + return build($_builder, $_state, resultType, base, indices, index_vec, mask, valueToStore, alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) : nullptr); - }]> - ]; + }]>]; } def Vector_ExpandLoadOp : @@ -2427,6 +2424,7 @@ def Vector_CompressStoreOp : def Vector_ShapeCastOp : Vector_Op<"shape_cast", [Pure, + DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]> ]>, Arguments<(ins AnyVectorOfAnyRank:$source)>, @@ -2607,7 +2605,9 @@ def Vector_ConstantMaskOp : } def Vector_CreateMaskOp : - Vector_Op<"create_mask", [Pure]>, + Vector_Op<"create_mask", [Pure, + DeclareOpInterfaceMethods<VectorUnrollOpInterface> + ]>, Arguments<(ins Variadic<Index>:$operands)>, Results<(outs VectorOfAnyRankOf<[I1]>)> { let summary = "creates a vector mask"; diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h index a57aadc..45626aa 100644 --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -219,14 +219,18 @@ bool isLinearizableVector(VectorType type); /// Creates a TransferReadOp from `source`. /// -/// The shape of the vector to read is specified via `inputVectorSizes`. If the -/// shape of the output vector differs from the shape of the value being read, -/// masking is used to avoid out-of-bounds accesses. Set +/// If the shape of vector to read differs from the shape of the value being +/// read, masking is used to avoid out-of-bounds accesses. Set /// `useInBoundsInsteadOfMasking` to `true` to use the "in_bounds" attribute /// instead of explicit masks. /// /// Note: all read offsets are set to 0. Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, + const VectorType &vecToReadTy, + std::optional<Value> padValue = std::nullopt, + bool useInBoundsInsteadOfMasking = false); + +Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, ArrayRef<int64_t> inputVectorSizes, std::optional<Value> padValue = std::nullopt, bool useInBoundsInsteadOfMasking = false, diff --git a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt index 0fe0182..bbe8e4e 100644 --- a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt @@ -3,3 +3,5 @@ add_mlir_doc(X86Vector X86Vector Dialects/ -gen-dialect-doc -dialect=x86vector) add_mlir_interface(X86VectorInterfaces) add_dependencies(MLIRX86VectorIncGen MLIRX86VectorInterfacesIncGen) + +add_subdirectory(TransformOps) diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt new file mode 100644 index 0000000..6f377e1 --- /dev/null +++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS X86VectorTransformOps.td) +mlir_tablegen(X86VectorTransformOps.h.inc -gen-op-decls) +mlir_tablegen(X86VectorTransformOps.cpp.inc -gen-op-defs) +add_mlir_dialect_tablegen_target(MLIRX86VectorTransformOpsIncGen) diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h new file mode 100644 index 0000000..e1d8b87 --- /dev/null +++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h @@ -0,0 +1,31 @@ +//===- X86VectorTransformOps.h - X86Vector transform ops --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H +#define MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H + +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/IR/OpImplementation.h" + +//===----------------------------------------------------------------------===// +// X86Vector Transform Operations +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h.inc" + +namespace mlir { +class DialectRegistry; + +namespace x86vector { +void registerTransformDialectExtension(DialectRegistry ®istry); + +} // namespace x86vector +} // namespace mlir + +#endif // MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td new file mode 100644 index 0000000..3c5294f --- /dev/null +++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td @@ -0,0 +1,43 @@ +//===- X86VectorTransformOps.td - X86Vector transform ops --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef X86VECTOR_TRANSFORM_OPS +#define X86VECTOR_TRANSFORM_OPS + +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/Dialect/Transform/IR/TransformAttrs.td" +include "mlir/Dialect/Transform/IR/TransformTypes.td" +include "mlir/IR/RegionKindInterface.td" + +def ApplyVectorContractToFMAPatternsOp : Op<Transform_Dialect, + "apply_patterns.x86vector.vector_contract_to_fma", + [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> { + let description = [{ + Collect patterns to lower a F32 type vector.contract operation to a FMA. + }]; + + let assemblyFormat = "attr-dict"; +} + +def ApplyVectorContractToPackedTypeDotProductPatternsOp : Op<Transform_Dialect, + "apply_patterns.x86vector.vector_contract_to_packed_type_dot_product", + [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> { + let description = [{ + Collect patterns to lower a BF16/Int8 type vector.contract operation + to a BF16/Int8 dot-product. + }]; + + let assemblyFormat = "attr-dict"; +} + + +#endif // X86VECTOR_TRANSFORM_OPS + diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h index d54111c..fc46dff 100644 --- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h +++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h @@ -80,6 +80,18 @@ struct MaskHelper { }; //===----------------------------------------------------------------------===// + +// A set of patterns for specialized lowering of vector contraction +// operation to vector fused multiply and add (FMA) operation. +void populateVectorContractToFMAPatterns(RewritePatternSet &patterns); + +// A set of patterns for lowering 32-bit packed vector contraction operations +// to their corresponding packed-type dot-product operations, ultimately +// targeting the relevant x86 LLVM intrinsics (e.g., BF16 and Int8). +void populateVectorContractToPackedTypeDotProductPatterns( + RewritePatternSet &patterns); + +//===----------------------------------------------------------------------===// /// Helpers extracted from: /// - clang/lib/Headers/avxintrin.h /// - clang/test/CodeGen/X86/avx-builtins.c diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index 3f27d69..eae0bd4 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -223,6 +223,14 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> { InterfaceMethod<"Derive a new layout by dropping InstData", "xegpu::DistributeLayoutAttr", "dropInstData">, + InterfaceMethod<"Derive a new layout with sg_data, inst_data and lane_data set to 1 for the specified unit dims", + "xegpu::DistributeLayoutAttr", + "setUnitDimData", + /*args=*/(ins "const llvm::SetVector<int64_t>": $unitDims)>, + InterfaceMethod<"Derive a new layout with sg_lane and lane_layout set to 1 for the specified unit dims", + "xegpu::DistributeLayoutAttr", + "setUnitDimLayout", + /*args=*/(ins "const llvm::SetVector<int64_t>": $unitDims)>, InterfaceMethod<[{Delinearizes a linear ID into its multidimensional indices based on the effective layout level.}], "FailureOr<SmallVector<Value>>", @@ -283,9 +291,14 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> { } return true; }]>, - InterfaceMethod</*desc=*/[{Check if this layout is a slice of some other layout.}], + InterfaceMethod</*desc=*/[{Check if this layout is a slice of another layout.}], /*retTy=*/"bool", /*methodName=*/"isSliceOf", + /*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)>, + + InterfaceMethod</*desc=*/[{Check if this layout is identical to another layout.}], + /*retTy=*/"bool", + /*methodName=*/"isEqualTo", /*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)> ]; } @@ -487,6 +500,12 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> { return {}; } + //set the layout for the sepcified unit dims: sg_data, inst_data and lane_data to 1 + DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims); + + //set the layout for the sepcified unit dims: sg_lane and lane_layout to 1 + DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims); + /// Delinearizes a linear ID into its multidimensional indices /// based on the effective level of the layout. FailureOr<SmallVector<Value>> @@ -501,6 +520,9 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> { /// Check if this is slice of some other layout. bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; } + + /// Check if this is identical to some other layout. + bool isEqualTo(const xegpu::DistributeLayoutAttr &other); }]; @@ -635,6 +657,8 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> { SliceAttr attr = flatten(); auto parent = dyn_cast<LayoutAttr>(attr.getParent()); parent = parent.dropSgLayoutAndData(); + if (!parent) + return nullptr; return SliceAttr::get(getContext(), parent, attr.getDims()); } @@ -642,9 +666,17 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> { SliceAttr attr = flatten(); auto parent = dyn_cast<LayoutAttr>(attr.getParent()); parent = parent.dropInstData(); + if (!parent) + return nullptr; return SliceAttr::get(getContext(), parent, attr.getDims()); } + //set the layout for the sepcified unit dims: sg_data, inst_data and lane_data to 1 + DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims); + + //set the layout for the sepcified unit dims: sg_lane and lane_layout to 1 + DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims); + /// flatten a nested SliceAttr, e.g., for 2-level nested SliceAttr /// #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0]>, dims = [0]> /// it will coalese two slice operations and return a simplified SliceAttr @@ -666,7 +698,9 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> { /// Check if this is slice of some other layout. bool isSliceOf(const xegpu::DistributeLayoutAttr &other); - + + /// Check if this is identical to some other layout. + bool isEqualTo(const xegpu::DistributeLayoutAttr &other); }]; let assemblyFormat = "`<` qualified($parent) `,` `dims` `=` $dims `>`"; diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 689ebd0..b54d620 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -76,10 +76,10 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface For the case of dynamic memrefs or pointer, the shape and layout information of the memory region should be explicitly passed via `shape` and `strides` parameters. - - `offsets`: index values represents offsets from the "source" at the each dimension + - `offsets`: [optional] index values represents offsets from the "source" at the each dimension at which the subview of the target memory will be created. It is encoded via "offsets" and "const_offsets", such that it can accept various forms, such as, - operands (e.g., [%c0, %c]) and attributes (e.g., [2, 4]). + operands (e.g., [%c0, %c]) and attributes (e.g., [2, 4]). Offsets is optional and may be set at load_nd, store_nd, and prefetch_nd. - `shape`: the shape information of the memory region pointed by the "source". It is typically encoded via the MemRefType of the source, e.g., memref<4096x4096xf16>. @@ -236,7 +236,7 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface return static_cast<unsigned>(MemorySpace::Global); } - xegpu::DistributeLayoutAttr getLayoutAttr() { + xegpu::DistributeLayoutAttr getDescLayoutAttr() { return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getType().getLayout()); } @@ -253,12 +253,32 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> { It issues an instruction to prefetch a block of data from continuous memory regions to each level of the cache based on their cache policy. - Example: + This operation serves as an anchor through which users assign a layout attribute + to govern computation distribution. + + Arguments: + - `TensorDesc`: A tensor descriptor specifying the base nd-region of + memory and tensor tile to be prefetched. + + - `offsets`: [optional] index values representing per-dimension offsets from the + base position encoded in `TensorDesc`. It is encoded via "offsets" + and "const_offsets". + + - `l1_hint`, `l2_hint`, `l3_hint`: [optional] An cache-hint attribute + indicating the desired behavior at the L1, L2, and L3 cache levels. + + - `layout`: [optional] Describes the expected layout of the `tensor_desc` operand. + Only valid at the workgroup and subgroup levels. + + Example (Workgroup level): ```mlir - xegpu.prefetch_nd %tdesc {l1_hint = #xegpu.cache_hint<cached>, + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + xegpu.prefetch_nd %tdesc[%c0, %c1] {l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, - l3_hint = #xegpu.cache_hint<cached>} - : !xegpu.tensor_desc<8x16xf16> + l3_hint = #xegpu.cache_hint<cached>, + layout = #xegpu.layout<sg_layout = [4, 8], sg_data = [8, 32]> } + : !xegpu.tensor_desc<32x256xf16> ``` }]; @@ -268,7 +288,8 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> { OptionalAttr<DenseI64ArrayAttr>: $const_offsets, OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint, OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint, - OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint); + OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint, + OptionalAttr<DistributeLayoutAttr>:$layout); let extraClassDeclaration = extraBaseClassDeclaration # [{ xegpu::TensorDescType getTensorDescType() { @@ -283,7 +304,7 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> { return getMixedValues(statics, dynamics, getContext()); } - xegpu::DistributeLayoutAttr getLayoutAttr() { + xegpu::DistributeLayoutAttr getDescLayoutAttr() { return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getTensorDescType().getLayout()); } @@ -308,7 +329,8 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> { "ArrayRef<OpFoldResult>": $offsets, "xegpu::CachePolicyAttr": $l1_hint, "xegpu::CachePolicyAttr": $l2_hint, - "xegpu::CachePolicyAttr": $l3_hint)> + "xegpu::CachePolicyAttr": $l3_hint, + "xegpu::DistributeLayoutAttr": $layout)> ]; let hasVerifier = 1; @@ -325,25 +347,48 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [ a block of data from memory to register. It takes a set of optional cache hints for each level of cache, L1, L2 and L3. If hardware does not have a correspoding cache, Corresponding cache hint attribute will be masked. - VNNI transformation is an hardware feature for Intel GPU, which is used to - do data packing during the load for B operand of matrix operation, if - the bit width of the data type is less then 32 bits, e.g., fp16. And - transpose is another Intel hardware feature, which will do transpose - operation when loading the data if the bit width of the data type is - fp32 or fp64. It implies that vnni and transpose cannot exit at the - same time. It is only available to 1D or 2D blocked tensor_desc. - In SIMT mode, result vector represents the data to be loaded by each work-item. + On Intel GPUs, hardware-supported packing rearranges data elements during + the load of the B operand when the element bit-width is less than 32 bits + (for example, fp16). The transpose feature reorders data during the load + when the element type is fp32 or fp64. These two features are mutually + exclusive and shall not be enabled simultaneously. Both features support only + 2D blocked tensor_desc. + + At lane level, result vector represents the data to be loaded by each lane. + + This operation serves as an anchor through which users assign a layout attribute + to govern computation distribution. + + Arguments: + + - `TensorDesc`: A tensor descriptor specifying the base nd-region of memory + and the tensor tile to be loaded. + + - `offsets`: Index values representing per-dimension offsets from the base position + encoded in `TensorDesc`. They are encoded via `offsets` and `const_offsets`. - Example 1: + - `packed`: [optional] A unit attribute indicating that packing is applied + during the load when supported by the hardware. Only valid at lane level. + + - `transpose`: [optional] An attribute describing a hardware-supported transpose + to be applied during the load. Only valid at Lane level. + + - `l1_hint`, `l2_hint`, `l3_hint`: [optional] Cache-hint attributes indicating the + desired behavior at the L1, L2, and L3 cache levels. + + - `layout`: [optional] Describes the expected layout of the `tensor_desc` operand as well as the result of the load (they are identical). Only valid at workgroup and subgroup levels. + + Example 1 (Workgroup level): ```mlir xegpu.load_nd %1 {transpose = [1, 0], l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, - l3_hint = #xegpu.cache_hint<streaming>} - : !xegpu.tensor_desc<8x16xf32> -> vector<16x8xf32> + l3_hint = #xegpu.cache_hint<streaming>, + layout = #xegpu.layout<sg_layout = [4, 8], sg_data = [8, 32]>} + : !xegpu.tensor_desc<32x256xf32> -> vector<32x256xf32> ``` - Example 2 (SIMT mode): + Example 2 (lane level): ```mlir xegpu.load_nd %1 {l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> @@ -360,7 +405,8 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [ OptionalAttr<DenseI64ArrayAttr>: $transpose, OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint, OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint, - OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint); + OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint, + OptionalAttr<DistributeLayoutAttr>:$layout); let results = (outs XeGPU_ValueType: $value); @@ -381,7 +427,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [ return getMixedValues(statics, dynamics, getContext()); } - xegpu::DistributeLayoutAttr getLayoutAttr() { + xegpu::DistributeLayoutAttr getDescLayoutAttr() { return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getTensorDescType().getLayout()); } @@ -389,7 +435,6 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [ return getTensorDescType().getShape(); } - }]; let assemblyFormat = [{ @@ -409,7 +454,8 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [ "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose, "xegpu::CachePolicyAttr": $l1_hint, "xegpu::CachePolicyAttr": $l2_hint, - "xegpu::CachePolicyAttr": $l3_hint)> + "xegpu::CachePolicyAttr": $l3_hint, + "xegpu::DistributeLayoutAttr": $layout)> ]; let hasVerifier = 1; @@ -428,16 +474,36 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [ Corresponding cache hint attribute will be masked. It is only available to 1D or 2D blocked tensor_desc. - In SIMT mode, the input vector represents the data to be stored by each work-item. + At lane level, the input vector represents the data to be stored by each lane. + + This operation serves as an anchor through which users assign a layout attribute + to govern computation distribution. + + Arguments: + + - `value`: A vector value representing the tensor tile to be stored. + + - `TensorDesc`: A tensor descriptor specifying the base nd-region of memory and + the tensor tile to be stored. - Example 1: + - `offsets`: Index values representing per-dimension offsets from the base position + encoded in `TensorDesc`. They are encoded via `offsets` and `const_offsets`. + + - `l1_hint`, `l2_hint`, `l3_hint`: [optional] Cache-hint attributes indicating the + desired behavior at the L1, L2, and L3 cache levels. + + - `layout`: [optional] Describes the expected layout of the `tensor_desc` operand as well as + the value to be stored (they are identical). Only valid at workgroup and subgroup levels. + + Example 1 (Workgroup level): ```mlir xegpu.store_nd %3, %2 {l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<write_back>, - l3_hint = #xegpu.cache_hint<write_through>} - : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> + l3_hint = #xegpu.cache_hint<write_through>, + layout = #xegpu.layout<sg_layout = [4, 8], sg_data = [8, 32]>} + : vector<32x256xf16>, !xegpu.tensor_desc<32x256xf16> ``` - Example 2 (SIMT mode): + Example 2 (lane level): ```mlir xegpu.store_nd %3, %2 {l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<write_back>, @@ -454,7 +520,8 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [ OptionalAttr<DenseI64ArrayAttr>: $const_offsets, OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint, OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint, - OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint); + OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint, + OptionalAttr<DistributeLayoutAttr>:$layout); let extraClassDeclaration = extraBaseClassDeclaration # [{ VectorType getValueType() { @@ -473,7 +540,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [ return getMixedValues(statics, dynamics, getContext()); } - xegpu::DistributeLayoutAttr getLayoutAttr() { + xegpu::DistributeLayoutAttr getDescLayoutAttr() { return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getTensorDescType().getLayout()); } @@ -499,7 +566,8 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [ "ArrayRef<OpFoldResult>": $offsets, "xegpu::CachePolicyAttr": $l1_hint, "xegpu::CachePolicyAttr": $l2_hint, - "xegpu::CachePolicyAttr": $l3_hint)> + "xegpu::CachePolicyAttr": $l3_hint, + "xegpu::DistributeLayoutAttr": $layout)> ]; @@ -561,21 +629,22 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> { "create_tdesc" is similar to "create_nd_tdesc" in terms that it creates a Tensor Descriptor (TensorDescType) for a memory region. While "create_nd_tdesc" is for creating continuous subviews, "create_tdesc" is for creating non-continuous - (scattered) subviews, allowing each work-item in a subgroup specifying their own offset. + (scattered) subviews, allowing each lane in a subgroup specifying their own offset. It accepts the following parameters: Arguments: + - `source`: a 1D memref or pointer (i64, i32, ui64, ui32) represents the flattened memory object. + - `offsets`: a vector containing offsets of each access point. Its size is fixed to the hardware supportted subgroup size, e.g., 16 on PVC, - implying each element in the vector corresponds to a work-item (SIMT lane) - in the subgroup. + implying each element in the vector corresponds to a SIMT lane in the subgroup. Results: - `res`: scattered tensor descriptor - The first dimension of the result TensorDesc corresponds to work-items, so it should + The first dimension of the result TensorDesc corresponds to lanes, so it should match the dimension of offsets. It may also has a second dimension corresponding to the chunk_size if the chunk size is larger than 1. @@ -664,27 +733,39 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> { As compared to prefetch_nd, which works on non-scattered TensorDesc, it works on scattered TensorDesc instead. + This operation serves as an anchor through which users assign a layout attribute + to govern computation distribution. + Arguments: + - `source`: represents the memory region to be loaded from, which can be either a tensor_desc or a 1D memref or pointer (ui64, ui32, i64 or i32). In case of tensor_desc, offsets come from the producer create_tdesc op. - tensor_desc cannot be used in SIMT mode. + tensor_desc cannot be used at lane level. + - `offsets`: represents offsets from source. required if `source` in not a TensorDescType. offsets is a vector of `index` type and vector length is either the subgroup size - or 1 in SIMT mode. scalar offset is also valid for SIMT mode. - - `l1_hint`, `l2_hint`, `l3_hint`: are optional cache hints for each level of cache. - - `offset_align_byte`: required if `source` is a pointer. If `source` is not a pointer, + or 1 at lane level. scalar offset is also valid for lane level. + + - `l1_hint`, `l2_hint`, `l3_hint`: [optional] cache hints for each level of cache. + + - `offset_align_byte`: [optional] required if `source` is a pointer. If `source` is not a pointer, it is not allowed. Represents the alignment in bytes of each offset in offsets. - Example 1: + - `layout`: [optional] Describes the expected layout of the `tensor_desc` or `offsets` + operand. Only valid at workgroup and subgroup levels. + + Example 1 (Workgroup level): ```mlir xegpu.prefetch %tdesc {l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, - l3_hint = #xegpu.cache_hint<cached>} - : !xegpu.tensor_desc<16xf16> + l3_hint = #xegpu.cache_hint<cached>, + layout = #xegpu.layout<sg_layout = [8], sg_data = [32]> + } + : !xegpu.tensor_desc<256xf16> ``` - Example 2: + Example 2 (lane level): A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc. It combines "create scattered TensorTdesc" and "prefetch with scattered TensorTdesc". The source operand could be a raw pointer (ui64, ui32, i64, i32). @@ -698,8 +779,8 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> { : memref<1024xf32>, vector<4xindex> ``` - Example 3 (SIMT mode): - SIMT mode only accepts the offsets variant. + Example 3 (lane level): + lane level only accepts the offsets variant. ```mlir xegpu.prefetch %0[%1] {l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, @@ -707,8 +788,8 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> { : memref<256xf32>, vector<1xindex> ``` - Example 4 (SIMT mode): - SIMT mode only accepts the offsets variant. + Example 4 (lane level): + lane level only accepts the offsets variant. ```mlir xegpu.prefetch %0[%1] {l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, @@ -724,7 +805,8 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> { OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint, OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint, OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint, - OptionalAttr<I64Attr>:$offset_align_byte); + OptionalAttr<I64Attr>:$offset_align_byte, + OptionalAttr<DistributeLayoutAttr>:$layout); let extraClassDeclaration = extraBaseClassDeclaration # [{ Type getSourceType() { @@ -764,54 +846,67 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> { def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> { let summary = "load a set of scattered data points from memory."; - let description = [{ It (aka. load) load data per each work-item. The output + let description = [{ It (aka. load) load data per each lane. The output describes the data being loaded at the subgroup level, so its size is - consistent with the number of work-items in a subgroup. When the chunk size + consistent with the number of lanes in a subgroup. When the chunk size is larger than 2, the output vector is a 2D vector, with dim-0 correspoding - to work-items, and dim-1 corresponding to the chunk size loaded by each work-item. + to lanes, and dim-1 corresponding to the chunk size loaded by each lane. The mask operand masks out memory access so that it is safe to pass out-of-boundary - addresses/offsets as long as they are masked. It applies to slots of SIMD lanes. + addresses/offsets as long as they are masked. Each mask element applies to one lane. + + In lane level, the result is a 1D vector that represents the data to be loaded by + each lane. If size is not 1, size should be equal to the chunk size. - In SIMT mode, the result is a 1D vector that represents the data to be loaded by - each work-item. If size is not 1, size should be equal to the chunk size, + This operation serves as an anchor through which users assign a layout attribute + to govern computation distribution. Arguments: + - `source`: represents the memory region to be loaded from, which can be either a tensor_desc or a 1D memref or pointer (ui64, ui32, i64 or i32). In case of tensor_desc, offsets come from the producer create_tdesc op. - tensor_desc cannot be used in SIMT mode. + tensor_desc cannot be used at lane level. + - `offsets`: represents offsets from source. required if `source` in not a TensorDescType. offsets is a vector of `index` type and vector length is either the subgroup size - or 1 in SIMT mode. scalar offset is also valid for SIMT mode. + or 1 at lane level. scalar offset is also valid for lane level. + - `mask`: is a vector of `i1` type, which is used to mask out the memory access. - mask is a vector of size equal to the subgroup size, or 1 in SIMT mode. - scalar mask is also valid for SIMT mode. - - `chunk_size`: (optional) represents contiguous number of elements to load from per work item. - - `l1_hint`, `l2_hint`, `l3_hint`: are optional cache hints for each level of cache. + mask is a vector of size equal to the subgroup size, or 1 at lane level. + scalar mask is also valid for lane level. + + - `chunk_size`: [optional] represents contiguous number of elements to load from per work item. + + - `l1_hint`, `l2_hint`, `l3_hint`: [optional] cache hints for each level of cache. + + - `layout`: [optional] Describes the expected layout of the `tensor_desc` operand or the result + of load. Only valid at workgroup and subgroup levels. Results: - `res`: represents loaded data - Example 1: + Example 1 (Workgroup level): ```mlir %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, - l3_hint = #xegpu.cache_hint<uncached>}> - : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space=global>>, - vector<16xi1> -> vector<16xf32> + l3_hint = #xegpu.cache_hint<uncached>}, + layout = #xegpu.layout<sg_layout = [8], sg_data = [32]>> + : !xegpu.tensor_desc<256xf32, #xegpu.scatter_tdesc_attr<memory_space=global>>, + vector<256xi1> -> vector<256xf32> ``` - Example 2: + Example 2 (Subgroup level): ```mlir %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, - l3_hint = #xegpu.cache_hint<uncached>}> + l3_hint = #xegpu.cache_hint<uncached>}, + layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 8]>> : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>, vector<16xi1> -> vector<16x8xf32> ``` - Example 3: + Example 3 (Subgroup level): A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc. It combines "create scattered TensorTdesc" and "load with scattered TensorTdesc". The source operand could be a raw pointer (ui64, ui32, i64, i32). Please refer to create_tdesc @@ -822,12 +917,13 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> { %mask = vector.constant_mask [16]: vector<16xi1> %val = xegpu.load %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, - l3_hint = #xegpu.cache_hint<cached>} + l3_hint = #xegpu.cache_hint<cached>, + layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : memref<1024xf32>, vector<16xi1>, vector<16xindex> -> vector<16xf32> ``` - Example 4 (SIMT mode): - SIMT mode only accepts the offsets variant. chunk_size can be inferred from result + Example 4 (lane level): + lane level only accepts the offsets variant. chunk_size can be inferred from result type. In this example, chunk_size is 8. ```mlir %2 = xegpu.load %1[%2], %0 <{l1_hint = #xegpu.cache_hint<cached>, @@ -844,7 +940,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> { OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint, OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint, OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint, - OptionalAttr<XeGPU_LayoutAttr>:$layout); + OptionalAttr<DistributeLayoutAttr>:$layout); let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$value); let extraClassDeclaration = extraBaseClassDeclaration # [{ @@ -903,7 +999,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> { "xegpu::CachePolicyAttr": $l1_hint, "xegpu::CachePolicyAttr": $l2_hint, "xegpu::CachePolicyAttr": $l3_hint, - "xegpu::LayoutAttr": $layout)> + "xegpu::DistributeLayoutAttr": $layout)> ]; let hasVerifier = 1; @@ -919,41 +1015,56 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> { has transpose effect, which is similar to `load_gather`. Therefore, a transpose attribute is introduced on purpose, making sure users are aware of this implicit transformation. - In SIMT mode, the result is a 1D vector that represents the data to be stored by - each work-item. If size is not 1, size should be equal to the chunk size. + In lane level, the result is a 1D vector that represents the data to be stored by + each lane. If size is not 1, size should be equal to the chunk size. + + This operation serves as an anchor through which users assign a layout attribute + to govern computation distribution. Arguments: + - `value`: represents the data to be stored. + - `dest`: represents the memory region to be stored to, which can be either a tensor_desc or a 1D memref or pointer (ui64, ui32, i64 or i32). In case of tensor_desc, offsets come from the producer create_tdesc op. - tensor_desc cannot be used in SIMT mode. + tensor_desc cannot be used at lane level. + - `offsets`: represents offsets from dest. required if `source` in not a TensorDescType. offsets is a vector of `index` type and vector length is either the subgroup size - or 1 in SIMT mode. scalar offset is also valid for SIMT mode. + or 1 at lane level. scalar offset is also valid for lane level. + - `mask`: is a vector of `i1` type, which is used to mask out the memory access. - mask is a vector of size equal to the subgroup size, or 1 in SIMT mode. - scalar mask is also valid for SIMT mode. - - `chunk_size`: (optional) represents contiguous number of elements to store to per work item. - - `l1_hint`, `l2_hint`, `l3_hint`: are optional cache hints for each level of cache. + mask is a vector of size equal to the subgroup size, or 1 at lane level. + scalar mask is also valid for lane level. + + - `chunk_size`: [optional] represents contiguous number of elements to store to per work item. + + - `l1_hint`, `l2_hint`, `l3_hint`: [optional] cache hints for each level of cache. + + - `layout`: [optional] Describes the expected layout of the `tensor_desc` operand or the value + to be stored. Only valid at workgroup and subgroup levels. + - Example 1: + Example 1 (Workgroup level): ```mlir xegpu.store %0, %1, %2 <{l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<write_back>, - l3_hint = #xegpu.cache_hint<write_through>}> - : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scattered_tdesc_attr<>>, vector<16xi1> + l3_hint = #xegpu.cache_hint<write_through>, + layout = #xegpu.layout<sg_layout = [8], sg_data = [16]>}> + : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.scattered_tdesc_attr<>>, vector<256xi1> ``` - Example 2: + Example 2 (Subgroup level): ```mlir xegpu.store %0, %1, %2 <{l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<write_back>, - l3_hint = #xegpu.cache_hint<write_through>}> + l3_hint = #xegpu.cache_hint<write_through>, + layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 8]>}> : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>>, vector<16xi1> ``` - Example 3: + Example 3 (Subgroup level): A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc. It combines "create scattered TensorTdesc" and "store with scattered TensorTdesc". The dest operand could be a raw pointer (uint64_t). @@ -965,12 +1076,13 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> { %mask = vector.constant_mask [16]: vector<16xi1> xegpu.store %val, %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, - l3_hint = #xegpu.cache_hint<cached>} + l3_hint = #xegpu.cache_hint<cached>, + layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : memref<1024xf32>, vector<16xi1>, vector<16xindex> -> vector<16xf32> ``` - Example 4 (SIMT mode): - SIMT mode only accepts the offsets variant. chunk_size can be inferred from value + Example 4 (Lane level): + Lane level IR only accepts the offsets variant. chunk_size can be inferred from value type. In this example, chunk_size is 8. ```mlir xegpu.store %0, %1[%2], %3 <{l1_hint = #xegpu.cache_hint<uncached>, @@ -988,7 +1100,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> { OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint, OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint, OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint, - OptionalAttr<XeGPU_LayoutAttr>:$layout); + OptionalAttr<DistributeLayoutAttr>:$layout); let extraClassDeclaration = extraBaseClassDeclaration#[{ Type getDestType() { @@ -1046,7 +1158,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> { "xegpu::CachePolicyAttr": $l1_hint, "xegpu::CachePolicyAttr": $l2_hint, "xegpu::CachePolicyAttr": $l3_hint, - "xegpu::LayoutAttr": $layout)> + "xegpu::DistributeLayoutAttr": $layout)> ]; let hasVerifier = 1; @@ -1061,8 +1173,8 @@ def XeGPU_UpdateOffsetOp: XeGPU_Op<"update_offset", the current position in the number of elements. However, `update_nd_offset` is to update the start point of a 2D block, so its offset constains two elements representing the shift in each dimension. `update_offset` is to - update the offset per work-item, so its offsets contains values representing - shifts for each work-item. + update the offset per lane, so its offsets contains values representing + shifts for each lane. Example: ```mlir @@ -1112,28 +1224,57 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>] size, B of `kxn` size, and accumulate on matrix C of `mxn` to the same size matrix , `m=8`, `n=16` and `k=8 * 32/bit_width_of_elem_type`. So for fp16 data type, the matrices are `A: vector<8x16xf16>`, `B: vector<16x16xf16>`, - and `C/D: vector<8x16xf32>`. Besides the matrix size requirements, DPAS - also requires A and B to be loaded with the required data layout. Specially, - VNNI layout is required for B operand. It is achieved via adding `packed` - attribute to the `load_nd` operator. Due to the VNNI transformation, B operands - can be represented as a 3D vector, with the last dimension representing the VNNI - factor, which is computed as `32/bit_width_of_elem_type`. Thus, `B: vector<16x16xf16>` - can be represented as `B: vector<8x16x2xf16>`. - - In SIMT code, each work-item from a subgroup holds a data fragment for A, B, C and the result, + and `C/D: vector<8x16xf32>`. + + In lane level code, each lane from a subgroup holds a data fragment for A, B, C and the result, which are represented as 1D vectors. Please refer to [OpenCL Intel extentions] (https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html) for more details about the fragment distribution. - Note: on PVC, the hardware can perform load with VNNI transformation when data - element type is 16-bit or lower precision, taking 2 or 4 elements from - the first dimension and inserted into the newly added innermost dimension. + This operation serves as an anchor through which users assign a layout attribute + to govern computation distribution. + + Arguments: + + - `lhs`: A vector value representing the left-hand-side matrix tile (A) participating in the + matrix multiply. + + - `rhs`: A vector value representing the right-hand-side matrix tile (B). + + - `acc`: [optional] A vector value representing the accumulator matrix tile (C). When present, the + result is computed as `lhs * rhs + acc`; otherwise, the accumulator is implicitly assumed to be zero. + + - `layout_a`, `layout_b`, `layout_cd`: [optional] Attributes that identify this + operation as anchor for operands A, B, and the accumulator/result, enabling users to assign layouts + that govern distribution at the subgroup and/or lane level. Only valid at workgroup and subgroup + level. + + Example 1 (Workgroup level): + + ```mlir + %d = xegpu.dpas %a, %b, %c <{ + layout_a = #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 128]>, + layout_b = #xegpu.layout<sg_layout = [4, 8], sg_data = [128, 16]>, + layout_cd = #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16]>} + : vector<64x128xf16>, vector<128x128xf16>, vector<64x128xf32> -> vector<64x128xf32> + ``` + + Example 2 (Lane level): + + ```mlir + %d = xegpu.dpas %a, %b, %c + : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32> + ``` }]; let arguments = (ins XeGPU_DpasOprType : $lhs, XeGPU_DpasOprType : $rhs, - Optional<XeGPU_DpasResType>: $acc); + Optional<XeGPU_DpasResType>: $acc, + OptionalAttr<DistributeLayoutAttr>:$layout_a, + OptionalAttr<DistributeLayoutAttr>:$layout_b, + OptionalAttr<DistributeLayoutAttr>:$layout_cd + ); let results = (outs XeGPU_DpasResType: $result); let extraClassDeclaration = [{ @@ -1180,13 +1321,35 @@ def XeGPU_AtomicRMWOp: XeGPU_Op<"atomic_rmw", [Pure, has the same shape with `TensorDesc`, and is used to enable or disable specific data points of the `TensorDesc`. The `value` operand represents the new value to be applied during the modification. + + This operation serves as an anchor through which users assign a layout attribute + to govern computation distribution. + + Arguments: + - `kind`: An attribute that specifies the atomic operation to be performed + (e.g., add, min, max, exchange, etc.). + + - `tensorDesc`: A `TensorDesc` describing the memory region on which the atomic + read-modify-write is performed. + + - `mask`: A predicate mask with the same shape as `tensorDesc`. Only elements + with a true (non-zero) mask value participate in the atomic operation; + masked-out elements are not modified. + + - `value`: The input values used by the atomic operation. It must have the same + shape and element type as `tensorDesc` and `result`. + + - `layout`: [optional] An attribute that identifies the operation as an anchor, + enabling users to assign a layout that governs distribution at the subgroup + and/or lane level. Only valid at workgroup and subgroup levels. }]; let arguments = (ins AtomicRMWKindAttr:$kind, XeGPU_TensorDesc:$tensorDesc, XeGPU_MaskType:$mask, - XeGPU_ValueType:$value); + XeGPU_ValueType:$value, + OptionalAttr<DistributeLayoutAttr>:$layout); let results = (outs XeGPU_ValueType:$result); @@ -1264,10 +1427,29 @@ def XeGPU_FenceOp: XeGPU_Op<"fence", []> { def XeGPU_ConvertLayoutOp: XeGPU_Op<"convert_layout", [Pure, AllTypesMatch<["source", "result"]>]> { let summary = "Convert the layout of the input operand"; let description = [{ - `convert_layout` redistribute data across subgroups and/or work-items from the `input_layout` to + `convert_layout` redistribute data across subgroups and/or lanes from the `input_layout` to the `target_layout`. Both `input_layout` and `target_layout` must correspond to the same programming - scope, such as workgroup-level (wg) or subgroup-level (sg) code. This operation is not valid once + scope, such as workgroup level (wg) or subgroup level (sg) code. This operation is not valid once the IR is lowered to WI level because that is the end result of all distributions. + + This operation serves as an anchor through which users assign a layout attribute + to govern computation distribution. + + Arguments: + - `source`: The input vector whose data is to be redistributed. The source and + result types must match. + - `input_layout`: The layout attribute describing the current distribution of `source` + across subgroups and/or lanes. + - `target_layout`: The layout attribute describing the desired distribution of the result + across subgroups and/or lanes. + + Example (Subgroup level): + ```mlir + %coop_a = xegpu.convert_layout %a <{ + input_layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>, + target_layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>}> + : vector<128x128xf16> + ``` }]; let arguments = (ins XeGPU_VectorType: $source, DistributeLayoutAttr: $input_layout, @@ -1282,12 +1464,6 @@ def XeGPU_ConvertLayoutOp: XeGPU_Op<"convert_layout", [Pure, AllTypesMatch<["sou let hasCanonicalizer = 1; } -def isSharedPred : CPred<"isSharedMemory(llvm::cast<mlir::MemRefType>($_self))">; -class StaticShared1DMemRefOf<list<Type> allowedTypes> : - ConfinedType<MemRefRankOf<allowedTypes, [1]>, [HasStaticShapePred, isSharedPred], - "statically shaped " # MemRefOf<allowedTypes>.summary # " for shared memory", - "mlir::MemRefType">; - class SizeInBits<string name> : StrFunc<"llvm::cast<mlir::ShapedType>($" # name # ".getType()).getNumElements()" "*llvm::cast<mlir::ShapedType>($" # name # ".getType()).getElementTypeBitWidth()">; @@ -1304,11 +1480,20 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure, as the underlying shared local memory. Arguments: - - `source` : a 1D statically shaped memref with element type i8, representing the raw SLM buffer. + - `source` : 1D or 2D statically shape memref, representing the raw SLM buffer. The provided memref must be contiguous. + Results: - `mem_desc` : the memory descriptor. + + Example: + ```mlir + %mdesc = xegpu.create_mem_desc %mref + : memref<4096xi8, 3> + -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>> + ``` + }]; - let arguments = (ins StaticShared1DMemRefOf<[I8]>:$source); + let arguments = (ins AnyTypeOf<[StaticShared1DMemRefOf<[XeGPU_ScalarType]>, StaticShared2DMemRefOf<[XeGPU_ScalarType]>]>:$source); let results = (outs XeGPU_MemDesc:$mem_desc); let assemblyFormat = "$source prop-dict attr-dict `` `:` type($source) `->` qualified(type($mem_desc))"; } @@ -1332,17 +1517,30 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, by the provided 2D `mem_desc`. Only 2D memory descriptors are supported; use the subview operation to obtain a compatible 2D `mem_desc` from a higher-rank descriptor if needed. + This operation serves as an anchor through which users assign a layout attribute + to govern computation distribution. + Arguments: - `mem_desc`: the memory descriptor identifying the SLM region. - `offsets`: the coordinates within the matrix to read from. - - `subgroup_block_io`: [optional] An attribute indicating that the operation can be - lowered to a subgroup block load. When this attribute is present, - the offsets are subgroup-uniform across all lanes. - - `layout`: [optional] An attribute for guiding distributions among - subgroups and/or work-items. It currently can accept either - LayoutAttr or SliceAttr. + - `subgroup_block_io`: [optional] An attribute indicating that the operation can be lowered + to a subgroup block load. When this attribute is present, the offsets are subgroup-uniform + across all lanes. Only used on subgroup and lane level. + - `layout`: [optional] Describes the expected layout of the `mem_desc` operand as well as + the result of load (they are identical). + Only valid at workgroup and subgroup levels. + Results: - `res`: the matrix elements loaded from SLM. + + Example (Workgroup level): + ```mlir + %c0 = arith.constant 0 : index + %1 = xegpu.load_matrix %0[%c0, %c0] <{ + layout = #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 16]> }> + : !xegpu.mem_desc<128x128xf16, #xegpu.mem_layout<stride = [1, 128], block = [16, 16]>> + , index, index -> vector<128x128xf16> + ``` }]; let builders = [ @@ -1382,16 +1580,26 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, specified by a 2D `mem_desc`. Only 2D memory descriptors are supported; use the subview operation to obtain a 2D `mem_desc` from a higher-rank descriptor if needed. + This operation serves as an anchor through which users assign a layout attribute + to govern computation distribution. + Arguments: - `mem_desc`: the memory descriptor specifying the SLM region. - `offsets`: the coordinates within the matrix where the data will be written. - `data`: the values to be stored in the matrix. - - `subgroup_block_io`: [optional] An attribute indicating that the operation can be - lowered to a subgroup block store. When this attribute is present, - the offsets are subgroup-uniform across all lanes. - - `layout`: [optional] An attribute for guiding distributions among - subgroups and/or work-items. It currently can accept either - LayoutAttr or SliceAttr. + - `subgroup_block_io`: [optional] An attribute indicating that the operation can be lowered + to a subgroup block load. When this attribute is present, the offsets are subgroup-uniform + across all lanes. Only used on subgroup and lane level. + - `layout`: [optional] Describes the expected layout of the `tensor_desc` operand as well as + the value to be stored (they are identical). Only valid at workgroup and subgroup levels. + + Example (Workgroup level): + ```mlir + %c0 = arith.constant 0 : index + xegpu.store_matrix %1, %0[%c0, %c0] <{ + layout = #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 16]> }> + : vector<128x128xf16>, !xegpu.mem_desc<128x128xf16>>, index, index + ``` }]; let builders = [ OpBuilder<(ins "Value" : $data, "TypedValue<MemDescType>": $mem_desc, diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td index b1196fb..716681f 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td @@ -13,8 +13,9 @@ include "mlir/Dialect/XeGPU/IR/XeGPUAttrs.td" include "mlir/Dialect/XeGPU/IR/XeGPUDialect.td" include "mlir/IR/BuiltinTypes.td" -def XeGPU_IntType: AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64, UI1, UI8, UI16, UI32, UI64]>; -def XeGPU_FloatType: AnyTypeOf<[F16, F32, F64, BF16, TF32]>; +def XeGPU_IntType : AnyTypeOf<[I1, I<4>, I8, I16, I32, I64, SI1, SI8, SI16, + SI32, SI64, UI1, UI8, UI16, UI32, UI64]>; +def XeGPU_FloatType : AnyTypeOf<[F16, F32, F64, BF16, TF32]>; def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>; def XeGPU_PointerType : AnyTypeOf<[UI64, UI32, I64, I32]>; def XeGPU_BaseAddrType @@ -35,6 +36,17 @@ class XeGPUTypeDef<string name, string typeMnemonic, list<Trait> traits = [], let mnemonic = typeMnemonic; } +def isSharedPred : CPred<"isSharedMemory(llvm::cast<mlir::MemRefType>($_self))">; +class StaticShared1DMemRefOf<list<Type> allowedTypes> : + ConfinedType<MemRefRankOf<allowedTypes, [1]>, [HasStaticShapePred, isSharedPred], + "reside in share memory and statically 1d shaped " # MemRefOf<allowedTypes>.summary # " ", + "mlir::MemRefType">; + +class StaticShared2DMemRefOf<list<Type> allowedTypes>: + ConfinedType<MemRefRankOf<allowedTypes, [2]>, [HasStaticShapePred, isSharedPred], + "reside in share memory and statically 2d shaped " # MemRefOf<allowedTypes>.summary # " ", + "mlir::MemRefType">; + def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc", [ShapedTypeInterface], "::mlir::TensorType"> { let summary = "TensorDesc describing regions of interested data."; diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td index 34f333e..29579ac 100644 --- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td @@ -42,10 +42,12 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [ let summary = "Set xegpu.layout attribute to a xegpu.create_nd_desc op result."; let description = [{ - Given an `xegpu.create_nd_desc` operation, this transform adds `xegpu.layout` - attribute to the result tensor descriptor. The layout is defined by the - `sg_layout`, and `sg_data` and optional `inst_data` attributes. Returns a handle - to the transformed op. + Given an `xegpu.create_nd_desc` operation, this transform adds + `xegpu.layout` attribute to the result tensor descriptor. The layout is + defined by the `sg_layout`, and `sg_data` and optional `inst_data` + attributes. If `slice_dims` is provided, the `xegpu.layout` attribute is + wrapped in an `xegpu.slice<..., dims=slice_dims>` attribute. Returns a handle to + the transformed op. }]; let arguments = (ins @@ -55,7 +57,8 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [ Variadic<TransformAnyParamTypeOrAnyHandle>:$inst_data, DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout, DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data, - DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data + DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data, + DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$slice_dims ); let results = (outs TransformHandleTypeInterface:$transformed); @@ -63,7 +66,8 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [ OpBuilder<(ins "Value":$target, "ArrayRef<OpFoldResult>":$mixedSgLayout, "ArrayRef<OpFoldResult>":$mixedSgData, - "ArrayRef<OpFoldResult>":$mixedInstData + "ArrayRef<OpFoldResult>":$mixedInstData, + "ArrayRef<int64_t>":$sliceDims )>, ]; @@ -72,6 +76,7 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [ `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout) `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data) (`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)? + (`slice_dims` `=` $slice_dims^)? attr-dict `:` functional-type(operands, results) }]; @@ -107,7 +112,9 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [ Sets the `xegpu.layout` attribute of an op. If `result=true`, sets the `layout_result_{index}`, otherwise `layout_operand_{index}` attribute. The target operand/result value is defined by the `index` argument. The layout - is defined by the `sg_layout`, `sg_data` and optional `inst_data` attributes. + is defined by the `sg_layout`, `sg_data` and optional `inst_data` + attributes. If `slice_dims` is provided, the `xegpu.layout` attribute is + wrapped in an `xegpu.slice<..., dims=slice_dims>` attribute. }]; let arguments = (ins TransformHandleTypeInterface:$target, @@ -118,6 +125,7 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout, DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data, DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data, + DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$slice_dims, DefaultValuedAttr<UnitAttr, "false">:$result ); @@ -128,6 +136,7 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [ "ArrayRef<OpFoldResult>":$mixedSgLayout, "ArrayRef<OpFoldResult>":$mixedSgData, "ArrayRef<OpFoldResult>":$mixedInstData, + "ArrayRef<int64_t>":$sliceDims, CArg<"bool", "false">:$result )>, ]; @@ -137,6 +146,7 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [ `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout) `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data) (`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)? + (`slice_dims` `=` $slice_dims^)? attr-dict `:` qualified(type(operands)) }]; @@ -161,4 +171,173 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [ }]; } +def SetGPULaunchThreadsOp + : Op<Transform_Dialect, "xegpu.set_gpu_launch_threads", [ + DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, + TransformOpInterface + ]> { + + let summary = "Set number of threads for a given gpu.launch operation"; + let description = [{ + Overrides the x,y,z threads operands of a given `gpu.launch` operation in-place. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target, + Variadic<TransformAnyParamTypeOrAnyHandle>:$threads, + DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_threads + ); + let results = (outs); + let builders = [ + OpBuilder<(ins "Value":$target, "ArrayRef<OpFoldResult>":$mixedThreads)>, + ]; + + let assemblyFormat = [{ + $target + `threads` `=` custom<DynamicIndexList>($threads, $static_threads) + attr-dict `:` qualified(type(operands)) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure apply( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::transform::TransformResults &transformResults, + ::mlir::transform::TransformState &state); + + ::llvm::SmallVector<::mlir::OpFoldResult> getMixedThreads() { + Builder b(getContext()); + return getMixedValues(getStaticThreads(), getThreads(), b); + } + }]; +} + +def InsertPrefetchOp : Op<Transform_Dialect, "xegpu.insert_prefetch", [ + DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, + TransformOpInterface +]> { + + let summary = "Adds xegpu prefetch ops to matmul operand tiles."; + let description = [{ + Given a target value (e.g., `vector`) residing in a `scf.for` loop, this + transform finds the corresponding `xegpu.load_nd` op and inserts + `xegpu.prefetch_nd` operations for the tile. The load op must reside within + the `scf.for` loop. Number of prefetch steps is set by the `nb_prefetch` + argument (default value is 1). Returns a handle to the created + `xegpu.create_nd_desc` op. + }]; + + let arguments = (ins TransformValueHandleTypeInterface:$target, + Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_nb_prefetch, + DefaultValuedOptionalAttr<I64Attr, "1">:$static_nb_prefetch + ); + + let results = (outs TransformHandleTypeInterface:$desc_op); + + let assemblyFormat = [{ + $target + `nb_prefetch` `=` ($dynamic_nb_prefetch^):($static_nb_prefetch)? + attr-dict `:` functional-type(operands, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure apply( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::transform::TransformResults &transformResults, + ::mlir::transform::TransformState &state); + + OpFoldResult getNbPrefetch() { + auto cxt = getContext(); + if (getDynamicNbPrefetch()) + return OpFoldResult(getDynamicNbPrefetch()); + return OpFoldResult(IntegerAttr::get( + IntegerType::get(cxt, 64), getStaticNbPrefetch())); + } + }]; +} + +def ConvertLayoutOp : Op<Transform_Dialect, "xegpu.convert_layout", [ + AttrSizedOperandSegments, + DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, + TransformOpInterface +]> { + + let summary = "Convert xegpu.layout attribute for a value."; + let description = [{ + Adds an `xegpu.convert_layout` op to convert the `xegpu.layout` attribute + of a value. The input and target layouts are defined by the `*sg_layout`, + `*sg_data` and optional `*inst_data` attributes. Returns a handle to the + emitted `xegpu.convert_layout` op. + }]; + + let arguments = (ins TransformValueHandleTypeInterface:$target, + Variadic<TransformAnyParamTypeOrAnyHandle>:$input_sg_layout, + Variadic<TransformAnyParamTypeOrAnyHandle>:$input_sg_data, + Variadic<TransformAnyParamTypeOrAnyHandle>:$input_inst_data, + Variadic<TransformAnyParamTypeOrAnyHandle>:$target_sg_layout, + Variadic<TransformAnyParamTypeOrAnyHandle>:$target_sg_data, + Variadic<TransformAnyParamTypeOrAnyHandle>:$target_inst_data, + DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_input_sg_layout, + DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_input_sg_data, + DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_input_inst_data, + DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_target_sg_layout, + DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_target_sg_data, + DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_target_inst_data + ); + + let results = (outs TransformHandleTypeInterface:$newConvertOp); + let builders = [ + OpBuilder<(ins "Value":$target, + "ArrayRef<OpFoldResult>":$mixedInputSgLayout, + "ArrayRef<OpFoldResult>":$mixedInputSgData, + "ArrayRef<OpFoldResult>":$mixedInputInstData, + "ArrayRef<OpFoldResult>":$mixedTargetSgLayout, + "ArrayRef<OpFoldResult>":$mixedTargetSgData, + "ArrayRef<OpFoldResult>":$mixedTargetInstData + )>, + ]; + + let assemblyFormat = [{ + $target + `input_sg_layout` `=` custom<DynamicIndexList>($input_sg_layout, $static_input_sg_layout) + `input_sg_data` `=` custom<DynamicIndexList>($input_sg_data, $static_input_sg_data) + (`input_inst_data` `=` custom<DynamicIndexList>($input_inst_data, $static_input_inst_data)^)? + `target_sg_layout` `=` custom<DynamicIndexList>($target_sg_layout, $static_target_sg_layout) + `target_sg_data` `=` custom<DynamicIndexList>($target_sg_data, $static_target_sg_data) + (`target_inst_data` `=` custom<DynamicIndexList>($target_inst_data, $static_target_inst_data)^)? + attr-dict `:` functional-type(operands, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure apply( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::transform::TransformResults &transformResults, + ::mlir::transform::TransformState &state); + + ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInputSgLayout() { + Builder b(getContext()); + return getMixedValues(getStaticInputSgLayout(), getInputSgLayout(), b); + } + ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInputSgData() { + Builder b(getContext()); + return getMixedValues(getStaticInputSgData(), getInputSgData(), b); + } + ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInputInstData() { + Builder b(getContext()); + return getMixedValues(getStaticInputInstData(), getInputInstData(), b); + } + + ::llvm::SmallVector<::mlir::OpFoldResult> getMixedTargetSgLayout() { + Builder b(getContext()); + return getMixedValues(getStaticTargetSgLayout(), getTargetSgLayout(), b); + } + ::llvm::SmallVector<::mlir::OpFoldResult> getMixedTargetSgData() { + Builder b(getContext()); + return getMixedValues(getStaticTargetSgData(), getTargetSgData(), b); + } + ::llvm::SmallVector<::mlir::OpFoldResult> getMixedTargetInstData() { + Builder b(getContext()); + return getMixedValues(getStaticTargetInstData(), getTargetInstData(), b); + } + }]; +} + #endif // XEGPU_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td index 12270af..0ca5842 100644 --- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td @@ -37,6 +37,19 @@ def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> { propagate the layouts required for their operands to the producers. With this propagated layout information, pass will then update op result type with the layout information. + + `layout-kind` option values: + - `inst` + Propagate the `inst_data` field of the layout attribute. The default is chosen to + maximize instruction-level granularity so that the user shape can be processed + with the fewest instructions. For N-D operations, this granularity depends on + W (width) and H (height) of the instruction shape. + The B (block) dimension (or array length) is not included in the default + configuration and must be enabled via a separate optimization pass. + + - `lane` + Propagate the `lane_layout` and `lane_data` fields of the layout attribute. + Default values are selected to align with hardware. }]; let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect", "vector::VectorDialect"]; diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td index b7e168a..8ac1a2e 100644 --- a/mlir/include/mlir/IR/CommonAttrConstraints.td +++ b/mlir/include/mlir/IR/CommonAttrConstraints.td @@ -188,7 +188,12 @@ class AnyAttrOf<list<Attr> allowedAttrs, string summary = "", } def LocationAttr : Attr<CPred<"::llvm::isa<::mlir::LocationAttr>($_self)">, - "location attribute">; + "location attribute"> { + let storageType = [{ ::mlir::LocationAttr }]; + let returnType = [{ ::mlir::Location }]; + let convertFromStorage = "::mlir::Location($_self)"; + let constBuilderCall = "(::mlir::LocationAttr)$0"; +} def BoolAttr : Attr<CPred<"::llvm::isa<::mlir::BoolAttr>($_self)">, "bool attribute"> { let storageType = [{ ::mlir::BoolAttr }]; diff --git a/mlir/include/mlir/IR/Interfaces.td b/mlir/include/mlir/IR/Interfaces.td index 0cbe3fa..e51bbd5 100644 --- a/mlir/include/mlir/IR/Interfaces.td +++ b/mlir/include/mlir/IR/Interfaces.td @@ -147,6 +147,11 @@ class TypeInterface<string name, list<Interface> baseInterfaces = []> !if(!empty(cppNamespace),"", cppNamespace # "::") # name >; +// DialectInterface represents a Dialect Interface. +class DialectInterface<string name, list<Interface> baseInterfaces = []> + : Interface<name, baseInterfaces>, OpInterfaceTrait<name>; + + // Whether to declare the interface methods in the user entity's header. This // class simply wraps an Interface but is used to indicate that the method // declarations should be generated. This class takes an optional set of methods diff --git a/mlir/include/mlir/IR/PDLPatternMatch.h.inc b/mlir/include/mlir/IR/PDLPatternMatch.h.inc index d5fb57d..4afbcf2 100644 --- a/mlir/include/mlir/IR/PDLPatternMatch.h.inc +++ b/mlir/include/mlir/IR/PDLPatternMatch.h.inc @@ -152,9 +152,7 @@ public: void push_back(TypeRange value) { // The lifetime of a TypeRange can't be guaranteed, so we'll need to // allocate a storage for it. - llvm::OwningArrayRef<Type> storage(value.size()); - llvm::copy(value, storage.begin()); - allocatedTypeRanges.emplace_back(std::move(storage)); + allocatedTypeRanges.emplace_back(value.begin(), value.end()); typeRanges.push_back(allocatedTypeRanges.back()); results.push_back(&typeRanges.back()); } @@ -174,9 +172,7 @@ public: void push_back(ValueRange value) { // The lifetime of a ValueRange can't be guaranteed, so we'll need to // allocate a storage for it. - llvm::OwningArrayRef<Value> storage(value.size()); - llvm::copy(value, storage.begin()); - allocatedValueRanges.emplace_back(std::move(storage)); + allocatedValueRanges.emplace_back(value.begin(), value.end()); valueRanges.push_back(allocatedValueRanges.back()); results.push_back(&valueRanges.back()); } @@ -206,8 +202,8 @@ protected: SmallVector<ValueRange> valueRanges; /// Memory allocated to store ranges in the result list whose lifetime was /// generated in the native function. - SmallVector<llvm::OwningArrayRef<Type>> allocatedTypeRanges; - SmallVector<llvm::OwningArrayRef<Value>> allocatedValueRanges; + SmallVector<std::vector<Type>> allocatedTypeRanges; + SmallVector<std::vector<Value>> allocatedValueRanges; }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 576481a..35f7290 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -675,9 +675,9 @@ public: /// true. Also notify the listener about every in-place op modification (for /// every use that was replaced). The optional `allUsesReplaced` flag is set /// to "true" if all uses were replaced. - void replaceUsesWithIf(Value from, Value to, - function_ref<bool(OpOperand &)> functor, - bool *allUsesReplaced = nullptr); + virtual void replaceUsesWithIf(Value from, Value to, + function_ref<bool(OpOperand &)> functor, + bool *allUsesReplaced = nullptr); void replaceUsesWithIf(ValueRange from, ValueRange to, function_ref<bool(OpOperand &)> functor, bool *allUsesReplaced = nullptr); diff --git a/mlir/include/mlir/IR/Properties.td b/mlir/include/mlir/IR/Properties.td index a7ade06..2830ba9 100644 --- a/mlir/include/mlir/IR/Properties.td +++ b/mlir/include/mlir/IR/Properties.td @@ -468,7 +468,7 @@ class ArrayProp<Property elem = Property<>, string newSummary = ""> : return $_diag() << "expected array attribute"; for (::mlir::Attribute elemAttr : arrayAttr) { }] # _makePropStorage<elem, "elemVal">.ret # [{ - auto elemRes = [&](Attribute propAttr, }] # elem.storageType # [{& propStorage) -> ::mlir::LogicalResult { + auto elemRes = [&](::mlir::Attribute propAttr, }] # elem.storageType # [{& propStorage) -> ::mlir::LogicalResult { }] # !subst("$_attr", "propAttr", !subst("$_storage", "propStorage", elem.convertFromAttribute)) # [{ }(elemAttr, elemVal); @@ -480,7 +480,7 @@ class ArrayProp<Property elem = Property<>, string newSummary = ""> : }]; let convertToAttribute = [{ - SmallVector<Attribute> elems; + SmallVector<::mlir::Attribute> elems; for (const auto& elemVal : $_storage) { auto elemAttr = [&](const }] # elem.storageType #[{& propStorage) -> ::mlir::Attribute { }] # !subst("$_storage", "propStorage", elem.convertToAttribute) # [{ @@ -647,7 +647,7 @@ class OptionalProp<Property p, bit canDelegateParsing = 1> } ::mlir::Attribute presentAttr = arrayAttr[0]; }] # _makePropStorage<p, "presentVal">.ret # [{ - auto presentRes = [&](Attribute propAttr, }] # p.storageType # [{& propStorage) -> ::mlir::LogicalResult { + auto presentRes = [&](::mlir::Attribute propAttr, }] # p.storageType # [{& propStorage) -> ::mlir::LogicalResult { }] # !subst("$_storage", "propStorage", !subst("$_attr", "propAttr", p.convertFromAttribute)) # [{ }(presentAttr, presentVal); diff --git a/mlir/include/mlir/IR/Remarks.h b/mlir/include/mlir/IR/Remarks.h index 9877926..3102542 100644 --- a/mlir/include/mlir/IR/Remarks.h +++ b/mlir/include/mlir/IR/Remarks.h @@ -99,18 +99,30 @@ public: } // Remark argument that is a key-value pair that can be printed as machine - // parsable args. + // parsable args. For Attribute arguments, the original attribute is also + // stored to allow custom streamers to handle them specially. struct Arg { std::string key; std::string val; + /// Optional attribute storage for Attribute-based args. Allows streamers + /// to access the original attribute for custom handling. + std::optional<Attribute> attr; + Arg(llvm::StringRef m) : key("Remark"), val(m) {} Arg(llvm::StringRef k, llvm::StringRef v) : key(k), val(v) {} Arg(llvm::StringRef k, std::string v) : key(k), val(std::move(v)) {} Arg(llvm::StringRef k, const char *v) : Arg(k, llvm::StringRef(v)) {} Arg(llvm::StringRef k, Value v); Arg(llvm::StringRef k, Type t); + Arg(llvm::StringRef k, Attribute a); Arg(llvm::StringRef k, bool b) : key(k), val(b ? "true" : "false") {} + /// Check if this arg has an associated attribute. + bool hasAttribute() const { return attr.has_value(); } + + /// Get the attribute if present. + Attribute getAttribute() const { return attr.value_or(Attribute()); } + // One constructor for all arithmetic types except bool. template <typename T, typename = std::enable_if_t<std::is_arithmetic_v<T> && !std::is_same_v<T, bool>>> diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td index bbfa308..b3aafe0 100644 --- a/mlir/include/mlir/IR/SymbolInterfaces.td +++ b/mlir/include/mlir/IR/SymbolInterfaces.td @@ -171,6 +171,8 @@ def Symbol : OpInterface<"SymbolOpInterface"> { if (concreteOp.isDeclaration() && concreteOp.isPublic()) return concreteOp.emitOpError("symbol declaration cannot have public " "visibility"); + if ($_op->getNumResults() != 0) + return concreteOp.emitOpError("symbols must not have results"); auto parent = $_op->getParentOp(); if (parent && !parent->hasTrait<OpTrait::SymbolTable>() && parent->isRegistered()) { return concreteOp.emitOpError("symbol's parent must have the SymbolTable " diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h index 4fcbeff..1bfb66e 100644 --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h @@ -33,6 +33,10 @@ using ReifiedRankedShapedTypeDims = SmallVector<SmallVector<OpFoldResult>>; LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes); +FailureOr<SmallVector<OpFoldResult>> +reifyShapeOfResult(OpBuilder &b, Operation *op, int resultIndex); +FailureOr<OpFoldResult> reifyDimOfResult(OpBuilder &b, Operation *op, + int resultIndex, int dim); /// Adaptor class to abstract the differences between whether value is from /// a ShapedType or ShapedTypeComponents or DenseIntElementsAttribute. diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td index 1a2c05f..67568f7 100644 --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td @@ -361,20 +361,76 @@ def ReifyRankedShapedTypeOpInterface : let methods = [ InterfaceMethod< /*desc=*/[{ - Reify the shape of the result of an operation (typically in terms of the - shape of its operands). + Reify the shapes of all the result of an operation (typically in terms + of the shape of its operands). `reifiedReturnShapes` is populated with one vector per op result. Each of those vectors contains an OpFoldResult for each dimension of the shaped type. The given builder may be used to insert ops that compute result shapes. - If the shape of a particular result cannot be computed it must be empty. + If the shape of a particular result cannot be computed it in terms of + its operands it must be left empty. If any dimension of the result cannot + be computed it must be set to OpFoldResult(). }], /*retTy=*/"::llvm::LogicalResult", /*methodName=*/"reifyResultShapes", /*args=*/(ins "::mlir::OpBuilder &":$builder, - "::mlir::ReifiedRankedShapedTypeDims &":$reifiedReturnShapes) + "::mlir::ReifiedRankedShapedTypeDims &":$reifiedReturnShapes), + /*methodBody=*/"", + /*defaultImplementation=*/[{ return ::mlir::failure(); }] + >, + InterfaceMethod< + /*desc=*/[{ + Reify the shape of a single result of an operation (typically in terms + of the shape of its operands). + + Returns the shape of a single result of the operation as a + `SmallVector<OpFoldResult>`, one per dimension of the shaped type. The + given builder may be used to insert ops that compute result shapes. + + If any dimension of the result cannot be computed it must be set to + OpFoldResult(). + }], + /*retTy=*/"::llvm::FailureOr<::llvm::SmallVector<::mlir::OpFoldResult>>", + /*methodName=*/"reifyShapeOfResult", + /*args=*/(ins "::mlir::OpBuilder &":$builder, + "int":$resultIndex), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + ReifiedRankedShapedTypeDims reifiedShapes; + if (failed(cast<ReifyRankedShapedTypeOpInterface>($_op.getOperation()).reifyResultShapes(builder, reifiedShapes))) + return failure(); + if (resultIndex < 0 || resultIndex >= static_cast<int>(reifiedShapes.size())) + return $_op.emitOpError("invalid result index"); + return reifiedShapes[resultIndex]; + }] + >, + InterfaceMethod< + /*desc=*/[{ + Reify the shape of a dimension of a given result of an operation + (typically in terms of the shape of its operands). + + Returns the shape of a specific dimension of a result of the operation as + an OpFoldResult. The given builder may be used to insert ops that compute + the shapes. + + If the dimension of the result cannot be computed the method must return + `failure()`. + }], + /*retTy=*/"::llvm::FailureOr<::mlir::OpFoldResult>", + /*methodName=*/"reifyDimOfResult", + /*args=*/(ins "::mlir::OpBuilder &":$builder, + "int":$resultIndex, "int":$dim), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + auto shapes = cast<ReifyRankedShapedTypeOpInterface>($_op.getOperation()).reifyShapeOfResult(builder, resultIndex); + if (failed(shapes)) + return failure(); + if (dim < 0 || dim >= static_cast<int>((*shapes).size())) + return $_op.emitOpError("invalid dimension"); + return (*shapes)[dim]; + }] > ]; } diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td index e0516ab..c30782a 100644 --- a/mlir/include/mlir/Interfaces/TilingInterface.td +++ b/mlir/include/mlir/Interfaces/TilingInterface.td @@ -360,6 +360,43 @@ def TilingInterface : OpInterface<"TilingInterface"> { /*defaultImplementation=*/[{ return failure(); }] + >, + //===------------------------------------------------------------------===// + // Interface methods for querying fusability. + //===------------------------------------------------------------------===// + InterfaceMethod< + /*desc=*/[{ + Indicates whether it is possible to fuse this operation with the given + result slice. This method is not allowed to generate any IR. + }], + /*retTy=*/"bool", + /*methodName=*/"isOpFusableWithConsumerSlice", + /*args=*/(ins + "unsigned":$resultNumber, + "::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets, + "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes + ), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return false; + }] + >, + InterfaceMethod< + /*desc=*/[{ + Indicates whether it is possible to fuse this operation with the given + list of operand slices. This method is not allowed to generate any IR. + }], + /*retTy=*/"bool", + /*methodName=*/"isOpFusableWithProducerSlices", + /*args=*/(ins + "::mlir::ArrayRef<unsigned>":$operandNumbers, + "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allOffsets, + "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allSizes + ), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return false; + }] > ]; } diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h index 16893c6..448a688 100644 --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -193,6 +193,13 @@ protected: /// This is useful for generic operation passes to add restrictions on the /// operations they operate on. virtual bool canScheduleOn(RegisteredOperationName opName) const = 0; + virtual bool canScheduleOn(Operation *op) const { + std::optional<RegisteredOperationName> registeredInfo = + op->getName().getRegisteredInfo(); + if (!registeredInfo) + return false; + return canScheduleOn(*registeredInfo); + } /// Schedule an arbitrary pass pipeline on the provided operation. /// This can be invoke any time in a pass to dynamic schedule more passes. @@ -436,6 +443,7 @@ protected: /// Indicate if the current pass can be scheduled on the given operation type. /// For an InterfacePass, this checks if the operation implements the given /// interface. + bool canScheduleOn(Operation *op) const final { return isa<InterfaceT>(op); } bool canScheduleOn(RegisteredOperationName opName) const final { return opName.hasInterface<InterfaceT>(); } diff --git a/mlir/include/mlir/Reducer/ReductionPatternInterface.h b/mlir/include/mlir/Reducer/ReductionPatternInterface.h index a85562f..a33877d 100644 --- a/mlir/include/mlir/Reducer/ReductionPatternInterface.h +++ b/mlir/include/mlir/Reducer/ReductionPatternInterface.h @@ -10,6 +10,7 @@ #define MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H #include "mlir/IR/DialectInterface.h" +#include "mlir/Reducer/Tester.h" namespace mlir { @@ -47,10 +48,17 @@ public: /// replacing an operation with a constant. virtual void populateReductionPatterns(RewritePatternSet &patterns) const = 0; + /// This method extends `populateReductionPatterns` by allowing reduction + /// patterns to use a `Tester` instance. Some reduction patterns may need to + /// run tester to determine whether certain transformations preserve the + /// "interesting" behavior of the program. This is mostly useful when pattern + /// should choose between multiple modifications. + virtual void populateReductionPatternsWithTester(RewritePatternSet &patterns, + Tester &tester) const {} + protected: DialectReductionPatternInterface(Dialect *dialect) : Base(dialect) {} }; - } // namespace mlir #endif // MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H diff --git a/mlir/include/mlir/Reducer/Tester.h b/mlir/include/mlir/Reducer/Tester.h index eb44afc..bed4408 100644 --- a/mlir/include/mlir/Reducer/Tester.h +++ b/mlir/include/mlir/Reducer/Tester.h @@ -36,6 +36,9 @@ public: Untested, }; + Tester() = default; + Tester(const Tester &) = default; + Tester(StringRef testScript, ArrayRef<std::string> testScriptArgs); /// Runs the interestingness testing script on a MLIR test case file. Returns @@ -46,6 +49,9 @@ public: /// Return whether the file in the given path is interesting. Interestingness isInteresting(StringRef testCase) const; + void setTestScript(StringRef script) { testScript = script; } + void setTestScriptArgs(ArrayRef<std::string> args) { testScriptArgs = args; } + private: StringRef testScript; ArrayRef<std::string> testScriptArgs; diff --git a/mlir/include/mlir/TableGen/Interfaces.h b/mlir/include/mlir/TableGen/Interfaces.h index 7c36cbc..f62d21d 100644 --- a/mlir/include/mlir/TableGen/Interfaces.h +++ b/mlir/include/mlir/TableGen/Interfaces.h @@ -157,6 +157,13 @@ struct TypeInterface : public Interface { static bool classof(const Interface *interface); }; +// An interface that is registered to a Dialect. +struct DialectInterface : public Interface { + using Interface::Interface; + + static bool classof(const Interface *interface); +}; + } // namespace tblgen } // namespace mlir diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h index 49b2dae..d2610f0 100644 --- a/mlir/include/mlir/TableGen/Pattern.h +++ b/mlir/include/mlir/TableGen/Pattern.h @@ -643,8 +643,10 @@ public: using IdentifierLine = std::pair<StringRef, unsigned>; // Returns the file location of the pattern (buffer identifier + line number - // pair). - std::vector<IdentifierLine> getLocation() const; + // pair). If `forSourceOutput` is true, replace absolute paths in the buffer + // identifier with just their filename so that we don't leak build paths into + // the generated code. + std::vector<IdentifierLine> getLocation(bool forSourceOutput = false) const; // Recursively collects all bound symbols inside the DAG tree rooted // at `tree` and updates the given `infoMap`. diff --git a/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h index 6a42627..0e50fac7 100644 --- a/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h +++ b/mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h @@ -84,10 +84,14 @@ public: /// Hook for derived dialect interfaces to publish the supported metadata /// kinds. As every metadata kind has a unique integer identifier, the - /// function returns the list of supported metadata identifiers. `ctx` can be - /// used to obtain IDs of metadata kinds that do not have a fixed static one. - virtual ArrayRef<unsigned> - getSupportedMetadata(llvm::LLVMContext &ctx) const { + /// function returns the list of supported metadata identifiers. The + /// `llvmContext` parameter is used to obtain identifiers for metadata kinds + /// that do not have a fixed static identifier. Since different LLVM contexts + /// can assign different identifiers to these non-static metadata kinds, the + /// function must recompute the list of supported metadata identifiers on each + /// call. + virtual SmallVector<unsigned> + getSupportedMetadata(llvm::LLVMContext &llvmContext) const { return {}; } }; diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h index 09d819a..dba950c0 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h @@ -163,9 +163,10 @@ public: /// Converts `value` to a float attribute. Asserts if the matching fails. FloatAttr matchFloatAttr(llvm::Value *value); - /// Converts `value` to a local variable attribute. Asserts if the matching - /// fails. - DILocalVariableAttr matchLocalVariableAttr(llvm::Value *value); + /// Converts `valOrVariable` to a local variable attribute. Asserts if the + /// matching fails. + DILocalVariableAttr matchLocalVariableAttr( + llvm::PointerUnion<llvm::Value *, llvm::DILocalVariable *> valOrVariable); /// Converts `value` to a label attribute. Asserts if the matching fails. DILabelAttr matchLabelAttr(llvm::Value *value); @@ -281,6 +282,10 @@ public: /// after the function conversion has finished. void addDebugIntrinsic(llvm::CallInst *intrinsic); + /// Adds a debug record to the list of debug records that need to be imported + /// after the function conversion has finished. + void addDebugRecord(llvm::DbgVariableRecord *dbgRecord); + /// Converts the LLVM values for an intrinsic to mixed MLIR values and /// attributes for LLVM_IntrOpBase. Attributes correspond to LLVM immargs. The /// list `immArgPositions` contains the positions of immargs on the LLVM @@ -339,9 +344,26 @@ private: /// Converts all debug intrinsics in `debugIntrinsics`. Assumes that the /// function containing the intrinsics has been fully converted to MLIR. LogicalResult processDebugIntrinsics(); + /// Converts all debug records in `dbgRecords`. Assumes that the + /// function containing the record has been fully converted to MLIR. + LogicalResult processDebugRecords(); /// Converts a single debug intrinsic. LogicalResult processDebugIntrinsic(llvm::DbgVariableIntrinsic *dbgIntr, DominanceInfo &domInfo); + /// Converts a single debug record. + LogicalResult processDebugRecord(llvm::DbgVariableRecord &dbgRecord, + DominanceInfo &domInfo); + /// Process arguments for declare/value operation insertion. `localVarAttr` + /// and `localExprAttr` are the attained attributes after importing the debug + /// variable and expressions. This also sets the builder insertion point to be + /// used by these operations. + std::tuple<DILocalVariableAttr, DIExpressionAttr, Value> + processDebugOpArgumentsAndInsertionPt( + Location loc, + llvm::function_ref<FailureOr<Value>()> convertArgOperandToValue, + llvm::Value *address, + llvm::PointerUnion<llvm::Value *, llvm::DILocalVariable *> variable, + llvm::DIExpression *expression, DominanceInfo &domInfo); /// Converts LLMV IR asm inline call operand's attributes into an array of /// MLIR attributes to be utilized in `llvm.inline_asm`. ArrayAttr convertAsmInlineOperandAttrs(const llvm::CallBase &llvmCall); @@ -485,6 +507,9 @@ private: /// Function-local list of debug intrinsics that need to be imported after the /// function conversion has finished. SetVector<llvm::Instruction *> debugIntrinsics; + /// Function-local list of debug records that need to be imported after the + /// function conversion has finished. + SetVector<llvm::DbgVariableRecord *> dbgRecords; /// Mapping between LLVM alias scope and domain metadata nodes and /// attributes in the LLVM dialect corresponding to these nodes. DenseMap<const llvm::MDNode *, Attribute> aliasScopeMapping; diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index eb7dfa7..039ac8e 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -512,6 +512,15 @@ llvm::CallInst *createIntrinsicCall(llvm::IRBuilderBase &builder, ArrayRef<llvm::Value *> args = {}, ArrayRef<llvm::Type *> tys = {}); +/// Creates a call to an LLVM IR intrinsic function with the given return type +/// and arguments. If the intrinsic is overloaded, the function signature will +/// be automatically resolved based on the provided return type and argument +/// types. +llvm::CallInst *createIntrinsicCall(llvm::IRBuilderBase &builder, + llvm::Intrinsic::ID intrinsic, + llvm::Type *retTy, + ArrayRef<llvm::Value *> args); + /// Creates a call to a LLVM IR intrinsic defined by LLVM_IntrOpBase. This /// resolves the overloads, and maps mixed MLIR value and attribute arguments to /// LLVM values. diff --git a/mlir/include/mlir/Tools/PDLL/AST/Types.h b/mlir/include/mlir/Tools/PDLL/AST/Types.h index 538ea7c..da74c50 100644 --- a/mlir/include/mlir/Tools/PDLL/AST/Types.h +++ b/mlir/include/mlir/Tools/PDLL/AST/Types.h @@ -22,17 +22,6 @@ class Operation; namespace ast { class Context; -namespace detail { -struct AttributeTypeStorage; -struct ConstraintTypeStorage; -struct OperationTypeStorage; -struct RangeTypeStorage; -struct RewriteTypeStorage; -struct TupleTypeStorage; -struct TypeTypeStorage; -struct ValueTypeStorage; -} // namespace detail - //===----------------------------------------------------------------------===// // Type //===----------------------------------------------------------------------===// @@ -100,6 +89,127 @@ inline raw_ostream &operator<<(raw_ostream &os, Type type) { } //===----------------------------------------------------------------------===// +// Type::Storage +//===----------------------------------------------------------------------===// + +struct Type::Storage : public StorageUniquer::BaseStorage { + Storage(TypeID typeID) : typeID(typeID) {} + + /// The type identifier for the derived type class. + TypeID typeID; +}; + +namespace detail { + +/// A utility CRTP base class that defines many of the necessary utilities for +/// defining a PDLL AST Type. +template <typename ConcreteT, typename KeyT = void> +struct TypeStorageBase : public Type::Storage { + using KeyTy = KeyT; + using Base = TypeStorageBase<ConcreteT, KeyT>; + TypeStorageBase(KeyTy key) + : Type::Storage(TypeID::get<ConcreteT>()), key(key) {} + + /// Construct an instance with the given storage allocator. + static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc, + const KeyTy &key) { + return new (alloc.allocate<ConcreteT>()) ConcreteT(key); + } + + /// Utility methods required by the storage allocator. + bool operator==(const KeyTy &key) const { return this->key == key; } + + /// Return the key value of this storage class. + const KeyTy &getValue() const { return key; } + +protected: + KeyTy key; +}; +/// A specialization of the storage base for singleton types. +template <typename ConcreteT> +struct TypeStorageBase<ConcreteT, void> : public Type::Storage { + using Base = TypeStorageBase<ConcreteT, void>; + TypeStorageBase() : Type::Storage(TypeID::get<ConcreteT>()) {} +}; + +//===----------------------------------------------------------------------===// +// AttributeTypeStorage +//===----------------------------------------------------------------------===// + +struct AttributeTypeStorage : public TypeStorageBase<AttributeTypeStorage> {}; + +//===----------------------------------------------------------------------===// +// ConstraintTypeStorage +//===----------------------------------------------------------------------===// + +struct ConstraintTypeStorage : public TypeStorageBase<ConstraintTypeStorage> {}; + +//===----------------------------------------------------------------------===// +// OperationTypeStorage +//===----------------------------------------------------------------------===// + +struct OperationTypeStorage + : public TypeStorageBase<OperationTypeStorage, + std::pair<StringRef, const ods::Operation *>> { + using Base::Base; + + static OperationTypeStorage * + construct(StorageUniquer::StorageAllocator &alloc, + const std::pair<StringRef, const ods::Operation *> &key) { + return new (alloc.allocate<OperationTypeStorage>()) OperationTypeStorage( + std::make_pair(alloc.copyInto(key.first), key.second)); + } +}; + +//===----------------------------------------------------------------------===// +// RangeTypeStorage +//===----------------------------------------------------------------------===// + +struct RangeTypeStorage : public TypeStorageBase<RangeTypeStorage, Type> { + using Base::Base; +}; + +//===----------------------------------------------------------------------===// +// RewriteTypeStorage +//===----------------------------------------------------------------------===// + +struct RewriteTypeStorage : public TypeStorageBase<RewriteTypeStorage> {}; + +//===----------------------------------------------------------------------===// +// TupleTypeStorage +//===----------------------------------------------------------------------===// + +struct TupleTypeStorage + : public TypeStorageBase<TupleTypeStorage, + std::pair<ArrayRef<Type>, ArrayRef<StringRef>>> { + using Base::Base; + + static TupleTypeStorage * + construct(StorageUniquer::StorageAllocator &alloc, + std::pair<ArrayRef<Type>, ArrayRef<StringRef>> key) { + SmallVector<StringRef> names = llvm::to_vector(llvm::map_range( + key.second, [&](StringRef name) { return alloc.copyInto(name); })); + return new (alloc.allocate<TupleTypeStorage>()) + TupleTypeStorage(std::make_pair(alloc.copyInto(key.first), + alloc.copyInto(llvm::ArrayRef(names)))); + } +}; + +//===----------------------------------------------------------------------===// +// TypeTypeStorage +//===----------------------------------------------------------------------===// + +struct TypeTypeStorage : public TypeStorageBase<TypeTypeStorage> {}; + +//===----------------------------------------------------------------------===// +// ValueTypeStorage +//===----------------------------------------------------------------------===// + +struct ValueTypeStorage : public TypeStorageBase<ValueTypeStorage> {}; + +} // namespace detail + +//===----------------------------------------------------------------------===// // AttributeType //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h index b739438..79dfd7a 100644 --- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h +++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h @@ -359,6 +359,20 @@ protected: /// the loaded IR. using PassPipelineFn = llvm::function_ref<LogicalResult(PassManager &pm)>; +/// Register basic command line options. +/// - toolName is used for the header displayed by `--help`. +/// - registry should contain all the dialects that can be parsed in the source. +/// - return std::string for help header. +std::string registerCLIOptions(llvm::StringRef toolName, + DialectRegistry ®istry); + +/// Parse command line options. +/// - helpHeader is used for the header displayed by `--help`. +/// - return std::pair<std::string, std::string> for +/// inputFilename and outputFilename command line option values. +std::pair<std::string, std::string> parseCLIOptions(int argc, char **argv, + llvm::StringRef helpHeader); + /// Register and parse command line options. /// - toolName is used for the header displayed by `--help`. /// - registry should contain all the dialects that can be parsed in the source. diff --git a/mlir/include/mlir/Transforms/CMakeLists.txt b/mlir/include/mlir/Transforms/CMakeLists.txt index 5fa52b2..1b57a34 100644 --- a/mlir/include/mlir/Transforms/CMakeLists.txt +++ b/mlir/include/mlir/Transforms/CMakeLists.txt @@ -5,4 +5,8 @@ mlir_tablegen(Transforms.capi.h.inc -gen-pass-capi-header --prefix Transforms) mlir_tablegen(Transforms.capi.cpp.inc -gen-pass-capi-impl --prefix Transforms) add_mlir_dialect_tablegen_target(MLIRTransformsPassIncGen) +set(LLVM_TARGET_DEFINITIONS DialectInlinerInterface.td) +mlir_tablegen(DialectInlinerInterface.h.inc -gen-dialect-interface-decls) +add_mlir_dialect_tablegen_target(MLIRTransformsDialectInterfaceIncGen) + add_mlir_doc(Passes GeneralPasses ./ -gen-pass-doc) diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 5ac9e26..9f44908 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -903,6 +903,27 @@ public: replaceAllUsesWith(from, ValueRange{to}); } + /// Replace the uses of `from` with `to` for which the `functor` returns + /// "true". The conversion driver will try to reconcile all type mismatches + /// that still exist at the end of the conversion with materializations. + /// This function supports both 1:1 and 1:N replacements. + /// + /// Note: The functor is also applied to builtin.unrealized_conversion_cast + /// ops that may have been inserted by the conversion driver. Some uses may + /// have been wrapped in unrealized_conversion_cast ops due to type changes. + /// + /// Note: This function is not supported in rollback mode. Calling it in + /// rollback mode will trigger an assertion. Furthermore, the + /// `allUsesReplaced` flag is not supported yet. + void replaceUsesWithIf(Value from, Value to, + function_ref<bool(OpOperand &)> functor, + bool *allUsesReplaced = nullptr) override { + replaceUsesWithIf(from, ValueRange{to}, functor, allUsesReplaced); + } + void replaceUsesWithIf(Value from, ValueRange to, + function_ref<bool(OpOperand &)> functor, + bool *allUsesReplaced = nullptr); + /// Return the converted value of 'key' with a type defined by the type /// converter of the currently executing pattern. Return nullptr in the case /// of failure, the remapped value otherwise. diff --git a/mlir/include/mlir/Transforms/DialectInlinerInterface.td b/mlir/include/mlir/Transforms/DialectInlinerInterface.td new file mode 100644 index 0000000..0975b84 --- /dev/null +++ b/mlir/include/mlir/Transforms/DialectInlinerInterface.td @@ -0,0 +1,196 @@ +#ifndef MLIR_INTERFACES_DIALECTINLINERINTERFACE +#define MLIR_INTERFACES_DIALECTINLINERINTERFACE + +include "mlir/IR/Interfaces.td" + +def DialectInlinerInterface : DialectInterface<"DialectInlinerInterface"> { + let description = [{ + This is the interface that must be implemented by the dialects of operations + to be inlined. This interface should only handle the operations of the + given dialect. + }]; + let cppNamespace = "::mlir"; + + let methods = [ + InterfaceMethod<[{ + Returns true if the given operation 'callable', that implements the + 'CallableOpInterface', can be inlined into the position given call + operation 'call', that is registered to the current dialect and implements + the `CallOpInterface`. 'wouldBeCloned' is set to true if the region of the + given 'callable' is set to be cloned during the inlining process, or false + if the region is set to be moved in-place(i.e. no duplicates would be + created). + }], + "bool", "isLegalToInline", + (ins "::mlir::Operation *":$call, "::mlir::Operation *":$callable, + "bool":$wouldBeCloned), + [{ + return false; + }] + >, + InterfaceMethod<[{ + Returns true if the given region 'src' can be inlined into the region + 'dest' that is attached to an operation registered to the current dialect. + 'wouldBeCloned' is set to true if the given 'src' region is set to be + cloned during the inlining process, or false if the region is set to be + moved in-place (i.e. no duplicates would be created). 'valueMapping' + contains any remapped values from within the 'src' region. This can be + used to examine what values will replace entry arguments into the 'src' + region for example. + }], + "bool", "isLegalToInline", + (ins "::mlir::Region *":$dest, "::mlir::Region *":$src, "bool":$wouldBeCloned, + "::mlir::IRMapping &":$valueMapping), + [{ + return false; + }] + >, + InterfaceMethod<[{ + Returns true if the given region 'src' can be inlined into the region + 'dest' that is attached to an operation registered to the current dialect. + 'wouldBeCloned' is set to true if the given 'src' region is set to be + cloned during the inlining process, or false if the region is set to be + moved in-place(i.e. no duplicates would be created). 'valueMapping' + contains any remapped values from within the 'src' region. This can be + used to examine what values will replace entry arguments into the 'src' + region for example. + }], + "bool", "isLegalToInline", + (ins "::mlir::Operation *":$op, "::mlir::Region *":$dest, + "bool":$wouldBeCloned, "::mlir::IRMapping &":$valueMapping), + [{ + return false; + }] + >, + InterfaceMethod<[{ + This hook is invoked on an operation that contains regions. It should + return true if the analyzer should recurse within the regions of this + operation when computing legality and cost, false otherwise. The default + implementation returns true. + }], + "bool", "shouldAnalyzeRecursively", + (ins "::mlir::Operation *":$op), + [{ + return true; + }] + >, + InterfaceMethod<[{ + Handle the given inlined terminator by replacing it with a new operation + as necessary. This overload is called when the inlined region has more + than one block. The 'newDest' block represents the new final branching + destination of blocks within this region, i.e. operations that release + control to the parent operation will likely now branch to this block. + Its block arguments correspond to any values that need to be replaced by + terminators within the inlined region. + }], + "void", "handleTerminator", + (ins "::mlir::Operation *":$op, "::mlir::Block *":$newDest), + [{ + llvm_unreachable("must implement handleTerminator in the case of multiple " + "inlined blocks"); + }] + >, + InterfaceMethod<[{ + Handle the given inlined terminator by replacing it with a new operation + as necessary. This overload is called when the inlined region only + contains one block. 'valuesToReplace' contains the previously returned + values of the call site before inlining. These values must be replaced by + this callback if they had any users (for example for traditional function + calls, these are directly replaced with the operands of the `return` + operation). The given 'op' will be removed by the caller, after this + function has been called. + }], + "void", "handleTerminator", + (ins "::mlir::Operation *":$op, "::mlir::ValueRange":$valuesToReplace), + [{ + llvm_unreachable( + "must implement handleTerminator in the case of one inlined block"); + }] + >, + InterfaceMethod<[{ + Attempt to materialize a conversion for a type mismatch between a call + from this dialect, and a callable region. This method should generate an + operation that takes 'input' as the only operand, and produces a single + result of 'resultType'. If a conversion can not be generated, nullptr + should be returned. For example, this hook may be invoked in the following + scenarios: + + ```mlir + func @foo(i32) -> i32 { ... } + + // Mismatched input operand ... = foo.call @foo(%input : i16) -> i32 + + // Mismatched result type. + ... = foo.call @foo(%input : i32) -> i16 + ``` + + NOTE: This hook may be invoked before the 'isLegal' checks above. + }], + "::mlir::Operation *", "materializeCallConversion", + (ins "::mlir::OpBuilder &":$builder, "::mlir::Value":$input, + "::mlir::Type":$resultType, "::mlir::Location":$conversionLoc), + [{ + return nullptr; + }] + >, + InterfaceMethod<[{ + Hook to transform the call arguments before using them to replace the + callee arguments. Returns a value of the same type or the `argument` + itself if nothing changed. The `argumentAttrs` dictionary is non-null even + if no attribute is present. The hook is called after converting the + callsite argument types using the materializeCallConversion callback, and + right before inlining the callee region. Any operations created using the + provided `builder` are inserted right before the inlined callee region. An + example use case is the insertion of copies for by value arguments. + }], + "::mlir::Value", "handleArgument", + (ins "::mlir::OpBuilder &":$builder, "::mlir::Operation *":$call, + "::mlir::Operation *":$callable, "::mlir::Value":$argument, + "::mlir::DictionaryAttr":$argumentAttrs), + [{ + return argument; + }] + >, + InterfaceMethod<[{ + Hook to transform the callee results before using them to replace the call + results. Returns a value of the same type or the `result` itself if + nothing changed. The `resultAttrs` dictionary is non-null even if no + attribute is present. The hook is called right before handling + terminators, and obtains the callee result before converting its type + using the `materializeCallConversion` callback. Any operations created + using the provided `builder` are inserted right after the inlined callee + region. An example use case is the insertion of copies for by value + results. NOTE: This hook is invoked after inlining the `callable` region. + }], + "::mlir::Value", "handleResult", + (ins "::mlir::OpBuilder &":$builder, "::mlir::Operation *":$call, + "::mlir::Operation *":$callable, "::mlir::Value":$result, + "::mlir::DictionaryAttr":$resultAttrs), + [{ + return result; + }] + >, + InterfaceMethod<[{ + Process a set of blocks that have been inlined for a call. This callback + is invoked before inlined terminator operations have been processed. + }], + "void", "processInlinedCallBlocks", + (ins "::mlir::Operation *":$call, + "::mlir::iterator_range<::mlir::Region::iterator>":$inlinedBlocks), + [{}] + >, + InterfaceMethod<[{ + Returns true if the inliner can assume a fast path of not creating a new + block, if there is only one block. + }], + "bool", "allowSingleBlockOptimization", + (ins "::mlir::iterator_range<::mlir::Region::iterator>":$inlinedBlocks), + [{ + return true; + }] + > + ]; +} + + +#endif diff --git a/mlir/include/mlir/Transforms/InliningUtils.h b/mlir/include/mlir/Transforms/InliningUtils.h index ed6413d..b6c6da3 100644 --- a/mlir/include/mlir/Transforms/InliningUtils.h +++ b/mlir/include/mlir/Transforms/InliningUtils.h @@ -32,158 +32,7 @@ class Region; class TypeRange; class Value; class ValueRange; - -//===----------------------------------------------------------------------===// -// InlinerInterface -//===----------------------------------------------------------------------===// - -/// This is the interface that must be implemented by the dialects of operations -/// to be inlined. This interface should only handle the operations of the -/// given dialect. -class DialectInlinerInterface - : public DialectInterface::Base<DialectInlinerInterface> { -public: - DialectInlinerInterface(Dialect *dialect) : Base(dialect) {} - - //===--------------------------------------------------------------------===// - // Analysis Hooks - //===--------------------------------------------------------------------===// - - /// Returns true if the given operation 'callable', that implements the - /// 'CallableOpInterface', can be inlined into the position given call - /// operation 'call', that is registered to the current dialect and implements - /// the `CallOpInterface`. 'wouldBeCloned' is set to true if the region of the - /// given 'callable' is set to be cloned during the inlining process, or false - /// if the region is set to be moved in-place(i.e. no duplicates would be - /// created). - virtual bool isLegalToInline(Operation *call, Operation *callable, - bool wouldBeCloned) const { - return false; - } - - /// Returns true if the given region 'src' can be inlined into the region - /// 'dest' that is attached to an operation registered to the current dialect. - /// 'wouldBeCloned' is set to true if the given 'src' region is set to be - /// cloned during the inlining process, or false if the region is set to be - /// moved in-place(i.e. no duplicates would be created). 'valueMapping' - /// contains any remapped values from within the 'src' region. This can be - /// used to examine what values will replace entry arguments into the 'src' - /// region for example. - virtual bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, - IRMapping &valueMapping) const { - return false; - } - - /// Returns true if the given operation 'op', that is registered to this - /// dialect, can be inlined into the given region, false otherwise. - /// 'wouldBeCloned' is set to true if the given 'op' is set to be cloned - /// during the inlining process, or false if the operation is set to be moved - /// in-place(i.e. no duplicates would be created). 'valueMapping' contains any - /// remapped values from within the 'src' region. This can be used to examine - /// what values may potentially replace the operands to 'op'. - virtual bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, - IRMapping &valueMapping) const { - return false; - } - - /// This hook is invoked on an operation that contains regions. It should - /// return true if the analyzer should recurse within the regions of this - /// operation when computing legality and cost, false otherwise. The default - /// implementation returns true. - virtual bool shouldAnalyzeRecursively(Operation *op) const { return true; } - - //===--------------------------------------------------------------------===// - // Transformation Hooks - //===--------------------------------------------------------------------===// - - /// Handle the given inlined terminator by replacing it with a new operation - /// as necessary. This overload is called when the inlined region has more - /// than one block. The 'newDest' block represents the new final branching - /// destination of blocks within this region, i.e. operations that release - /// control to the parent operation will likely now branch to this block. - /// Its block arguments correspond to any values that need to be replaced by - /// terminators within the inlined region. - virtual void handleTerminator(Operation *op, Block *newDest) const { - llvm_unreachable("must implement handleTerminator in the case of multiple " - "inlined blocks"); - } - - /// Handle the given inlined terminator by replacing it with a new operation - /// as necessary. This overload is called when the inlined region only - /// contains one block. 'valuesToReplace' contains the previously returned - /// values of the call site before inlining. These values must be replaced by - /// this callback if they had any users (for example for traditional function - /// calls, these are directly replaced with the operands of the `return` - /// operation). The given 'op' will be removed by the caller, after this - /// function has been called. - virtual void handleTerminator(Operation *op, - ValueRange valuesToReplace) const { - llvm_unreachable( - "must implement handleTerminator in the case of one inlined block"); - } - - /// Attempt to materialize a conversion for a type mismatch between a call - /// from this dialect, and a callable region. This method should generate an - /// operation that takes 'input' as the only operand, and produces a single - /// result of 'resultType'. If a conversion can not be generated, nullptr - /// should be returned. For example, this hook may be invoked in the following - /// scenarios: - /// func @foo(i32) -> i32 { ... } - /// - /// // Mismatched input operand - /// ... = foo.call @foo(%input : i16) -> i32 - /// - /// // Mismatched result type. - /// ... = foo.call @foo(%input : i32) -> i16 - /// - /// NOTE: This hook may be invoked before the 'isLegal' checks above. - virtual Operation *materializeCallConversion(OpBuilder &builder, Value input, - Type resultType, - Location conversionLoc) const { - return nullptr; - } - - /// Hook to transform the call arguments before using them to replace the - /// callee arguments. Returns a value of the same type or the `argument` - /// itself if nothing changed. The `argumentAttrs` dictionary is non-null even - /// if no attribute is present. The hook is called after converting the - /// callsite argument types using the materializeCallConversion callback, and - /// right before inlining the callee region. Any operations created using the - /// provided `builder` are inserted right before the inlined callee region. An - /// example use case is the insertion of copies for by value arguments. - virtual Value handleArgument(OpBuilder &builder, Operation *call, - Operation *callable, Value argument, - DictionaryAttr argumentAttrs) const { - return argument; - } - - /// Hook to transform the callee results before using them to replace the call - /// results. Returns a value of the same type or the `result` itself if - /// nothing changed. The `resultAttrs` dictionary is non-null even if no - /// attribute is present. The hook is called right before handling - /// terminators, and obtains the callee result before converting its type - /// using the `materializeCallConversion` callback. Any operations created - /// using the provided `builder` are inserted right after the inlined callee - /// region. An example use case is the insertion of copies for by value - /// results. NOTE: This hook is invoked after inlining the `callable` region. - virtual Value handleResult(OpBuilder &builder, Operation *call, - Operation *callable, Value result, - DictionaryAttr resultAttrs) const { - return result; - } - - /// Process a set of blocks that have been inlined for a call. This callback - /// is invoked before inlined terminator operations have been processed. - virtual void processInlinedCallBlocks( - Operation *call, iterator_range<Region::iterator> inlinedBlocks) const {} - - /// Returns true if the inliner can assume a fast path of not creating a new - /// block, if there is only one block. - virtual bool allowSingleBlockOptimization( - iterator_range<Region::iterator> inlinedBlocks) const { - return true; - } -}; +class DialectInlinerInterface; /// This interface provides the hooks into the inlining interface. /// Note: this class automatically collects 'DialectInlinerInterface' objects @@ -307,4 +156,6 @@ inlineCall(InlinerInterface &interface, } // namespace mlir +#include "mlir/Transforms/DialectInlinerInterface.h.inc" + #endif // MLIR_TRANSFORMS_INLININGUTILS_H diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index 17c323a..724da00 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -31,23 +31,23 @@ class GreedyRewriteConfig; // Passes //===----------------------------------------------------------------------===// +#define GEN_PASS_DECL_BUBBLEDOWNMEMORYSPACECASTS +#define GEN_PASS_DECL_CSE #define GEN_PASS_DECL_CANONICALIZER +#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS #define GEN_PASS_DECL_CONTROLFLOWSINK -#define GEN_PASS_DECL_CSE -#define GEN_PASS_DECL_INLINER +#define GEN_PASS_DECL_GENERATERUNTIMEVERIFICATION #define GEN_PASS_DECL_LOOPINVARIANTCODEMOTION +#define GEN_PASS_DECL_INLINER #define GEN_PASS_DECL_MEM2REG #define GEN_PASS_DECL_PRINTIRPASS #define GEN_PASS_DECL_PRINTOPSTATS +#define GEN_PASS_DECL_SCCP #define GEN_PASS_DECL_SROA #define GEN_PASS_DECL_STRIPDEBUGINFO -#define GEN_PASS_DECL_SCCP #define GEN_PASS_DECL_SYMBOLDCE #define GEN_PASS_DECL_SYMBOLPRIVATIZE #define GEN_PASS_DECL_TOPOLOGICALSORT -#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS -#define GEN_PASS_DECL_BUBBLEDOWNMEMORYSPACECASTS -#define GEN_PASS_DECL_GENERATERUNTIMEVERIFICATION #include "mlir/Transforms/Passes.h.inc" /// Creates an instance of the Canonicalizer pass, configured with default diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index 28b4a01..55addfd 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -248,6 +248,7 @@ def RemoveDeadValues : Pass<"remove-dead-values"> { ``` }]; let constructor = "mlir::createRemoveDeadValuesPass()"; + let dependentDialects = ["ub::UBDialect"]; } def PrintIRPass : Pass<"print-ir"> { diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h index 2ed96af..daf4373 100644 --- a/mlir/include/mlir/Transforms/RegionUtils.h +++ b/mlir/include/mlir/Transforms/RegionUtils.h @@ -84,7 +84,8 @@ LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op, /// Move definitions of `values` before an insertion point. Current support is /// only for movement of definitions within the same basic block. Note that this /// is an all-or-nothing approach. Either definitions of all values are moved -/// before insertion point, or none of them are. +/// before insertion point, or none of them are. Any side-effecting operations +/// in the producer chain pessimistically blocks movement. LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values, Operation *insertionPoint, DominanceInfo &dominance); diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp index 70b56ca..a93e605 100644 --- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp @@ -180,23 +180,20 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments( return; } - /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep() - /// on a LoopLikeInterface return the lower/upper bound for that result if - /// possible. - auto getLoopBoundFromFold = [&](std::optional<OpFoldResult> loopBound, - Type boundType, Block *block, bool getUpper) { + /// Given a lower bound, upper bound, or step from a LoopLikeInterface return + /// the lower/upper bound for that result if possible. + auto getLoopBoundFromFold = [&](OpFoldResult loopBound, Type boundType, + Block *block, bool getUpper) { unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType); - if (loopBound.has_value()) { - if (auto attr = dyn_cast<Attribute>(*loopBound)) { - if (auto bound = dyn_cast_or_null<IntegerAttr>(attr)) - return bound.getValue(); - } else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) { - const IntegerValueRangeLattice *lattice = - getLatticeElementFor(getProgramPointBefore(block), value); - if (lattice != nullptr && !lattice->getValue().isUninitialized()) - return getUpper ? lattice->getValue().getValue().smax() - : lattice->getValue().getValue().smin(); - } + if (auto attr = dyn_cast<Attribute>(loopBound)) { + if (auto bound = dyn_cast<IntegerAttr>(attr)) + return bound.getValue(); + } else if (auto value = llvm::dyn_cast<Value>(loopBound)) { + const IntegerValueRangeLattice *lattice = + getLatticeElementFor(getProgramPointBefore(block), value); + if (lattice != nullptr && !lattice->getValue().isUninitialized()) + return getUpper ? lattice->getValue().getValue().smax() + : lattice->getValue().getValue().smin(); } // Given the results of getConstant{Lower,Upper}Bound() // or getConstantStep() on a LoopLikeInterface return the lower/upper @@ -207,38 +204,43 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments( // Infer bounds for loop arguments that have static bounds if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) { - std::optional<Value> iv = loop.getSingleInductionVar(); - if (!iv) { + std::optional<llvm::SmallVector<Value>> maybeIvs = + loop.getLoopInductionVars(); + if (!maybeIvs) { return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments( op, successor, argLattices, firstIndex); } - Block *block = iv->getParentBlock(); - std::optional<OpFoldResult> lowerBound = loop.getSingleLowerBound(); - std::optional<OpFoldResult> upperBound = loop.getSingleUpperBound(); - std::optional<OpFoldResult> step = loop.getSingleStep(); - APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), block, - /*getUpper=*/false); - APInt max = getLoopBoundFromFold(upperBound, iv->getType(), block, - /*getUpper=*/true); - // Assume positivity for uniscoverable steps by way of getUpper = true. - APInt stepVal = - getLoopBoundFromFold(step, iv->getType(), block, /*getUpper=*/true); - - if (stepVal.isNegative()) { - std::swap(min, max); - } else { - // Correct the upper bound by subtracting 1 so that it becomes a <= - // bound, because loops do not generally include their upper bound. - max -= 1; - } + // This shouldn't be returning nullopt if there are indunction variables. + SmallVector<OpFoldResult> lowerBounds = *loop.getLoopLowerBounds(); + SmallVector<OpFoldResult> upperBounds = *loop.getLoopUpperBounds(); + SmallVector<OpFoldResult> steps = *loop.getLoopSteps(); + for (auto [iv, lowerBound, upperBound, step] : + llvm::zip_equal(*maybeIvs, lowerBounds, upperBounds, steps)) { + Block *block = iv.getParentBlock(); + APInt min = getLoopBoundFromFold(lowerBound, iv.getType(), block, + /*getUpper=*/false); + APInt max = getLoopBoundFromFold(upperBound, iv.getType(), block, + /*getUpper=*/true); + // Assume positivity for uniscoverable steps by way of getUpper = true. + APInt stepVal = + getLoopBoundFromFold(step, iv.getType(), block, /*getUpper=*/true); + + if (stepVal.isNegative()) { + std::swap(min, max); + } else { + // Correct the upper bound by subtracting 1 so that it becomes a <= + // bound, because loops do not generally include their upper bound. + max -= 1; + } - // If we infer the lower bound to be larger than the upper bound, the - // resulting range is meaningless and should not be used in further - // inferences. - if (max.sge(min)) { - IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv); - auto ivRange = ConstantIntRanges::fromSigned(min, max); - propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange})); + // If we infer the lower bound to be larger than the upper bound, the + // resulting range is meaningless and should not be used in further + // inferences. + if (max.sge(min)) { + IntegerValueRangeLattice *ivEntry = getLatticeElement(iv); + auto ivRange = ConstantIntRanges::fromSigned(min, max); + propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange})); + } } return; } diff --git a/mlir/lib/Analysis/Presburger/Barvinok.cpp b/mlir/lib/Analysis/Presburger/Barvinok.cpp index 75d592e..c31b277 100644 --- a/mlir/lib/Analysis/Presburger/Barvinok.cpp +++ b/mlir/lib/Analysis/Presburger/Barvinok.cpp @@ -178,13 +178,13 @@ mlir::presburger::detail::solveParametricEquations(FracMatrix equations) { for (unsigned i = 0; i < d; ++i) { // First ensure that the diagonal element is nonzero, by swapping // it with a row that is non-zero at column i. - if (equations(i, i) != 0) - continue; - for (unsigned j = i + 1; j < d; ++j) { - if (equations(j, i) == 0) - continue; - equations.swapRows(j, i); - break; + if (equations(i, i) == 0) { + for (unsigned j = i + 1; j < d; ++j) { + if (equations(j, i) == 0) + continue; + equations.swapRows(j, i); + break; + } } Fraction diagElement = equations(i, i); diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp index 812043d..26197ce 100644 --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -21,6 +21,7 @@ #include "mlir/Analysis/Presburger/Simplex.h" #include "mlir/Analysis/Presburger/Utils.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallBitVector.h" @@ -442,6 +443,14 @@ void IntegerRelation::removeInequality(unsigned pos) { inequalities.removeRow(pos); } +void IntegerRelation::removeConstraint(unsigned pos) { + if (pos >= getNumInequalities()) { + removeEquality(pos - getNumInequalities()); + } else { + removeInequality(pos); + } +} + void IntegerRelation::removeEqualityRange(unsigned start, unsigned end) { if (start >= end) return; @@ -1112,15 +1121,29 @@ unsigned IntegerRelation::gaussianEliminateVars(unsigned posStart, return posLimit - posStart; } +static std::optional<unsigned> +findEqualityWithNonZeroAfterRow(IntegerRelation &rel, unsigned fromRow, + unsigned colIdx) { + assert(fromRow < rel.getNumEqualities() && colIdx < rel.getNumCols() && + "position out of bounds"); + for (unsigned rowIdx = fromRow, e = rel.getNumEqualities(); rowIdx < e; + ++rowIdx) { + if (rel.atEq(rowIdx, colIdx) != 0) + return rowIdx; + } + return std::nullopt; +} + bool IntegerRelation::gaussianEliminate() { gcdTightenInequalities(); unsigned firstVar = 0, vars = getNumVars(); unsigned nowDone, eqs; std::optional<unsigned> pivotRow; for (nowDone = 0, eqs = getNumEqualities(); nowDone < eqs; ++nowDone) { - // Finds the first non-empty column. + // Finds the first non-empty column that we haven't dealt with. for (; firstVar < vars; ++firstVar) { - if ((pivotRow = findConstraintWithNonZeroAt(firstVar, /*isEq=*/true))) + if ((pivotRow = + findEqualityWithNonZeroAfterRow(*this, nowDone, firstVar))) break; } // The matrix has been normalized to row echelon form. @@ -1143,6 +1166,10 @@ bool IntegerRelation::gaussianEliminate() { inequalities.normalizeRow(i); } gcdTightenInequalities(); + + // The column is finished. Tell the next iteration to start at the next + // column. + firstVar++; } // No redundant rows. @@ -1724,12 +1751,64 @@ std::optional<DynamicAPInt> IntegerRelation::getConstantBoundOnDimSize( return minDiff; } +void IntegerRelation::pruneOrthogonalConstraints(unsigned pos) { + llvm::DenseSet<unsigned> relatedCols({pos}), relatedRows; + + // Early exit if constraints is empty. + unsigned numConstraints = getNumConstraints(); + if (numConstraints == 0) + return; + + llvm::SmallVector<unsigned> rowStack, colStack({pos}); + // The following code performs a graph traversal, starting from the target + // variable, to identify all variables(recorded in relatedCols) and + // constraints (recorded in relatedRows) belonging to the same connected + // component. + while (!rowStack.empty() || !colStack.empty()) { + if (!rowStack.empty()) { + unsigned currentRow = rowStack.pop_back_val(); + // Push all variable that accociated to this constraints to relatedCols + // and colStack. + for (unsigned colIndex = 0; colIndex < getNumVars(); ++colIndex) { + if (atConstraint(currentRow, colIndex) != 0 && + relatedCols.insert(colIndex).second) { + colStack.push_back(colIndex); + } + } + } else { + unsigned currentCol = colStack.pop_back_val(); + // Push all constraints that are associated with this variable to related + // rows and the row stack. + for (unsigned rowIndex = 0; rowIndex < numConstraints; ++rowIndex) { + if (atConstraint(rowIndex, currentCol) != 0 && + relatedRows.insert(rowIndex).second) { + rowStack.push_back(rowIndex); + } + } + } + } + + // Prune all constraints not related to target variable. + for (int constraintId = numConstraints - 1; constraintId >= 0; + --constraintId) { + if (!relatedRows.contains(constraintId)) + removeConstraint((unsigned)constraintId); + } +} + template <bool isLower> std::optional<DynamicAPInt> IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) { assert(pos < getNumVars() && "invalid position"); // Project to 'pos'. + // Prune orthogonal constraints to reduce unnecessary computations and + // accelerate the bound computation. + pruneOrthogonalConstraints(pos); projectOut(0, pos); + + // After projecting out values, more orthogonal constraints may be exposed. + // Prune these orthogonal constraints again. + pruneOrthogonalConstraints(0); projectOut(1, getNumVars() - 1); // Check if there's an equality equating the '0'^th variable to a constant. int eqRowIdx = findEqualityToConstant(/*pos=*/0, /*symbolic=*/false); @@ -2265,11 +2344,11 @@ IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) { newLb[d] = lbFloorDivisor; newUb[d] = -lbFloorDivisor; // Copy over the symbolic part + constant term. - std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimVars()); + llvm::copy(minLb, newLb.begin() + getNumDimVars()); std::transform(newLb.begin() + getNumDimVars(), newLb.end(), newLb.begin() + getNumDimVars(), std::negate<DynamicAPInt>()); - std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimVars()); + llvm::copy(maxUb, newUb.begin() + getNumDimVars()); boundingLbs.emplace_back(newLb); boundingUbs.emplace_back(newUb); diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp index bb60564..83a2c28 100644 --- a/mlir/lib/Analysis/Presburger/Matrix.cpp +++ b/mlir/lib/Analysis/Presburger/Matrix.cpp @@ -255,20 +255,13 @@ void Matrix<T>::fillRow(unsigned row, const T &value) { } // moveColumns is implemented by moving the columns adjacent to the source range -// to their final position. When moving right (i.e. dstPos > srcPos), the range -// of the adjacent columns is [srcPos + num, dstPos + num). When moving left -// (i.e. dstPos < srcPos) the range of the adjacent columns is [dstPos, srcPos). -// First, zeroed out columns are inserted in the final positions of the adjacent -// columns. Then, the adjacent columns are moved to their final positions by -// swapping them with the zeroed columns. Finally, the now zeroed adjacent -// columns are deleted. +// to their final position. template <typename T> void Matrix<T>::moveColumns(unsigned srcPos, unsigned num, unsigned dstPos) { if (num == 0) return; - int offset = dstPos - srcPos; - if (offset == 0) + if (dstPos == srcPos) return; assert(srcPos + num <= getNumColumns() && @@ -276,23 +269,19 @@ void Matrix<T>::moveColumns(unsigned srcPos, unsigned num, unsigned dstPos) { assert(dstPos + num <= getNumColumns() && "move destination range exceeds matrix columns"); - unsigned insertCount = offset > 0 ? offset : -offset; - unsigned finalAdjStart = offset > 0 ? srcPos : srcPos + num; - unsigned curAdjStart = offset > 0 ? srcPos + num : dstPos; - // TODO: This can be done using std::rotate. - // Insert new zero columns in the positions where the adjacent columns are to - // be moved. - insertColumns(finalAdjStart, insertCount); - // Update curAdjStart if insertion of new columns invalidates it. - if (finalAdjStart < curAdjStart) - curAdjStart += insertCount; - - // Swap the adjacent columns with inserted zero columns. - for (unsigned i = 0; i < insertCount; ++i) - swapColumns(finalAdjStart + i, curAdjStart + i); - - // Delete the now redundant zero columns. - removeColumns(curAdjStart, insertCount); + unsigned numRows = getNumRows(); + // std::rotate(start, middle, end) permutes the elements of [start, end] to + // [middle, end) + [start, middle). NOTE: &at(i, srcPos + num) will trigger an + // assert. + if (dstPos > srcPos) { + for (unsigned i = 0; i < numRows; ++i) { + std::rotate(&at(i, srcPos), &at(i, srcPos) + num, &at(i, dstPos) + num); + } + return; + } + for (unsigned i = 0; i < numRows; ++i) { + std::rotate(&at(i, dstPos), &at(i, srcPos), &at(i, srcPos) + num); + } } template <typename T> diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp index 870a713..05681ce 100644 --- a/mlir/lib/Bindings/Python/DialectLLVM.cpp +++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp @@ -31,8 +31,8 @@ static void populateDialectLLVMSubmodule(nanobind::module_ &m) { // StructType //===--------------------------------------------------------------------===// - auto llvmStructType = - mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType); + auto llvmStructType = mlir_type_subclass( + m, "StructType", mlirTypeIsALLVMStructType, mlirLLVMStructTypeGetTypeID); llvmStructType .def_classmethod( @@ -137,7 +137,8 @@ static void populateDialectLLVMSubmodule(nanobind::module_ &m) { // PointerType //===--------------------------------------------------------------------===// - mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType) + mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType, + mlirLLVMPointerTypeGetTypeID) .def_classmethod( "get", [](const nb::object &cls, std::optional<unsigned> addressSpace, diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp index 0155023..0b079b4 100644 --- a/mlir/lib/Bindings/Python/DialectLinalg.cpp +++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp @@ -80,6 +80,28 @@ static void populateDialectLinalgSubmodule(nb::module_ m) { "op.", nb::arg("op")); + m.def( + "infer_contraction_dimensions_from_maps", + [](std::vector<MlirAffineMap> indexingMaps) + -> std::optional<MlirLinalgContractionDimensions> { + if (indexingMaps.empty()) + return std::nullopt; + + MlirLinalgContractionDimensions dims = + mlirLinalgInferContractionDimensionsFromMaps(indexingMaps.data(), + indexingMaps.size()); + + // Detect "empty" result from invalid input or failed inference. + if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) && + mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) { + return std::nullopt; + } + return dims; + }, + "Infers contraction dimensions (batch/m/n/k) from a list of affine " + "maps.", + nb::arg("indexing_maps")); + m.def("isa_convolution_op", &mlirLinalgIsAConvolutionOp, "Checks if the given operation is a Linalg convolution operation.", nb::arg("op")); diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp index 8bb493e..be0785b1 100644 --- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp +++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp @@ -75,13 +75,13 @@ NB_MODULE(_mlirExecutionEngine, m) { "__init__", [](PyExecutionEngine &self, MlirModule module, int optLevel, const std::vector<std::string> &sharedLibPaths, - bool enableObjectDump) { + bool enableObjectDump, bool enablePIC) { llvm::SmallVector<MlirStringRef, 4> libPaths; for (const std::string &path : sharedLibPaths) libPaths.push_back({path.c_str(), path.length()}); - MlirExecutionEngine executionEngine = - mlirExecutionEngineCreate(module, optLevel, libPaths.size(), - libPaths.data(), enableObjectDump); + MlirExecutionEngine executionEngine = mlirExecutionEngineCreate( + module, optLevel, libPaths.size(), libPaths.data(), + enableObjectDump, enablePIC); if (mlirExecutionEngineIsNull(executionEngine)) throw std::runtime_error( "Failure while creating the ExecutionEngine."); @@ -89,7 +89,7 @@ NB_MODULE(_mlirExecutionEngine, m) { }, nb::arg("module"), nb::arg("opt_level") = 2, nb::arg("shared_libs") = nb::list(), - nb::arg("enable_object_dump") = true, + nb::arg("enable_object_dump") = true, nb::arg("enable_pic") = false, "Create a new ExecutionEngine instance for the given Module. The " "module must contain only dialects that can be translated to LLVM. " "Perform transformations and code generation at the optimization " diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index cda4fe1..2e0c2b8 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -18,6 +18,7 @@ #include "mlir/Bindings/Python/Nanobind.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" #include "nanobind/nanobind.h" +#include "nanobind/typing.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" @@ -32,33 +33,6 @@ using llvm::SmallVector; using llvm::StringRef; using llvm::Twine; -//------------------------------------------------------------------------------ -// Docstrings (trivial, non-duplicated docstrings are included inline). -//------------------------------------------------------------------------------ - -static const char kContextParseTypeDocstring[] = - R"(Parses the assembly form of a type. - -Returns a Type object or raises an MLIRError if the type cannot be parsed. - -See also: https://mlir.llvm.org/docs/LangRef/#type-system -)"; - -static const char kContextGetCallSiteLocationDocstring[] = - R"(Gets a Location representing a caller and callsite)"; - -static const char kContextGetFileLocationDocstring[] = - R"(Gets a Location representing a file, line and column)"; - -static const char kContextGetFileRangeDocstring[] = - R"(Gets a Location representing a file, line and column range)"; - -static const char kContextGetFusedLocationDocstring[] = - R"(Gets a Location representing a fused location with optional metadata)"; - -static const char kContextGetNameLocationDocString[] = - R"(Gets a Location representing a named location with optional child location)"; - static const char kModuleParseDocstring[] = R"(Parses a module's assembly format from a string. @@ -67,132 +41,12 @@ Returns a new MlirModule or raises an MLIRError if the parsing fails. See also: https://mlir.llvm.org/docs/LangRef/ )"; -static const char kModuleCAPICreate[] = - R"(Creates a Module from a MlirModule wrapped by a capsule (i.e. module._CAPIPtr). -Note this returns a new object BUT _clear_mlir_module(module) must be called to -prevent double-frees (of the underlying mlir::Module). -)"; - -static const char kOperationCreateDocstring[] = - R"(Creates a new operation. - -Args: - name: Operation name (e.g. "dialect.operation"). - results: Sequence of Type representing op result types. - attributes: Dict of str:Attribute. - successors: List of Block for the operation's successors. - regions: Number of regions to create. - location: A Location object (defaults to resolve from context manager). - ip: An InsertionPoint (defaults to resolve from context manager or set to - False to disable insertion, even with an insertion point set in the - context manager). - infer_type: Whether to infer result types. -Returns: - A new "detached" Operation object. Detached operations can be added - to blocks, which causes them to become "attached." -)"; - -static const char kOperationPrintDocstring[] = - R"(Prints the assembly form of the operation to a file like object. - -Args: - file: The file like object to write to. Defaults to sys.stdout. - binary: Whether to write bytes (True) or str (False). Defaults to False. - large_elements_limit: Whether to elide elements attributes above this - number of elements. Defaults to None (no limit). - large_resource_limit: Whether to elide resource attributes above this - number of characters. Defaults to None (no limit). If large_elements_limit - is set and this is None, the behavior will be to use large_elements_limit - as large_resource_limit. - enable_debug_info: Whether to print debug/location information. Defaults - to False. - pretty_debug_info: Whether to format debug information for easier reading - by a human (warning: the result is unparseable). - print_generic_op_form: Whether to print the generic assembly forms of all - ops. Defaults to False. - use_local_Scope: Whether to print in a way that is more optimized for - multi-threaded access but may not be consistent with how the overall - module prints. - assume_verified: By default, if not printing generic form, the verifier - will be run and if it fails, generic form will be printed with a comment - about failed verification. While a reasonable default for interactive use, - for systematic use, it is often better for the caller to verify explicitly - and report failures in a more robust fashion. Set this to True if doing this - in order to avoid running a redundant verification. If the IR is actually - invalid, behavior is undefined. - skip_regions: Whether to skip printing regions. Defaults to False. -)"; - -static const char kOperationPrintStateDocstring[] = - R"(Prints the assembly form of the operation to a file like object. - -Args: - file: The file like object to write to. Defaults to sys.stdout. - binary: Whether to write bytes (True) or str (False). Defaults to False. - state: AsmState capturing the operation numbering and flags. -)"; - -static const char kOperationGetAsmDocstring[] = - R"(Gets the assembly form of the operation with all options available. - -Args: - binary: Whether to return a bytes (True) or str (False) object. Defaults to - False. - ... others ...: See the print() method for common keyword arguments for - configuring the printout. -Returns: - Either a bytes or str object, depending on the setting of the 'binary' - argument. -)"; - -static const char kOperationPrintBytecodeDocstring[] = - R"(Write the bytecode form of the operation to a file like object. - -Args: - file: The file like object to write to. - desired_version: The version of bytecode to emit. -Returns: - The bytecode writer status. -)"; - -static const char kOperationStrDunderDocstring[] = - R"(Gets the assembly form of the operation with default options. - -If more advanced control over the assembly formatting or I/O options is needed, -use the dedicated print or get_asm method, which supports keyword arguments to -customize behavior. -)"; - static const char kDumpDocstring[] = - R"(Dumps a debug representation of the object to stderr.)"; - -static const char kAppendBlockDocstring[] = - R"(Appends a new block, with argument types as positional args. - -Returns: - The created block. -)"; - -static const char kValueDunderStrDocstring[] = - R"(Returns the string form of the value. - -If the value is a block argument, this is the assembly form of its type and the -position in the argument list. If the value is an operation result, this is -equivalent to printing the operation that produced it. -)"; - -static const char kGetNameAsOperand[] = - R"(Returns the string form of value as an operand (i.e., the ValueID). -)"; - -static const char kValueReplaceAllUsesWithDocstring[] = - R"(Replace all uses of value with the new value, updating anything in -the IR that uses 'self' to use the other value instead. -)"; + "Dumps a debug representation of the object to stderr."; static const char kValueReplaceAllUsesExceptDocstring[] = - R"("Replace all uses of this value with the 'with' value, except for those -in 'exceptions'. 'exceptions' can be either a single operation or a list of + R"(Replace all uses of this value with the `with` value, except for those +in `exceptions`. `exceptions` can be either a single operation or a list of operations. )"; @@ -274,22 +128,26 @@ struct PyGlobalDebugFlag { // Debug flags. nb::class_<PyGlobalDebugFlag>(m, "_GlobalDebug") .def_prop_rw_static("flag", &PyGlobalDebugFlag::get, - &PyGlobalDebugFlag::set, "LLVM-wide debug flag") + &PyGlobalDebugFlag::set, "LLVM-wide debug flag.") .def_static( "set_types", [](const std::string &type) { nb::ft_lock_guard lock(mutex); mlirSetGlobalDebugType(type.c_str()); }, - "types"_a, "Sets specific debug types to be produced by LLVM") - .def_static("set_types", [](const std::vector<std::string> &types) { - std::vector<const char *> pointers; - pointers.reserve(types.size()); - for (const std::string &str : types) - pointers.push_back(str.c_str()); - nb::ft_lock_guard lock(mutex); - mlirSetGlobalDebugTypes(pointers.data(), pointers.size()); - }); + "types"_a, "Sets specific debug types to be produced by LLVM.") + .def_static( + "set_types", + [](const std::vector<std::string> &types) { + std::vector<const char *> pointers; + pointers.reserve(types.size()); + for (const std::string &str : types) + pointers.push_back(str.c_str()); + nb::ft_lock_guard lock(mutex); + mlirSetGlobalDebugTypes(pointers.data(), pointers.size()); + }, + "types"_a, + "Sets multiple specific debug types to be produced by LLVM."); } private: @@ -316,12 +174,18 @@ struct PyAttrBuilderMap { static void bind(nb::module_ &m) { nb::class_<PyAttrBuilderMap>(m, "AttrBuilder") - .def_static("contains", &PyAttrBuilderMap::dunderContains) - .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed) + .def_static("contains", &PyAttrBuilderMap::dunderContains, + "attribute_kind"_a, + "Checks whether an attribute builder is registered for the " + "given attribute kind.") + .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed, + "attribute_kind"_a, + "Gets the registered attribute builder for the given " + "attribute kind.") .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed, "attribute_kind"_a, "attr_builder"_a, "replace"_a = false, "Register an attribute builder for building MLIR " - "attributes from python values."); + "attributes from Python values."); } }; @@ -341,8 +205,8 @@ namespace { class PyRegionIterator { public: - PyRegionIterator(PyOperationRef operation) - : operation(std::move(operation)) {} + PyRegionIterator(PyOperationRef operation, int nextIndex) + : operation(std::move(operation)), nextIndex(nextIndex) {} PyRegionIterator &dunderIter() { return *this; } @@ -357,13 +221,15 @@ public: static void bind(nb::module_ &m) { nb::class_<PyRegionIterator>(m, "RegionIterator") - .def("__iter__", &PyRegionIterator::dunderIter) - .def("__next__", &PyRegionIterator::dunderNext); + .def("__iter__", &PyRegionIterator::dunderIter, + "Returns an iterator over the regions in the operation.") + .def("__next__", &PyRegionIterator::dunderNext, + "Returns the next region in the iteration."); } private: PyOperationRef operation; - int nextIndex = 0; + intptr_t nextIndex = 0; }; /// Regions of an op are fixed length and indexed numerically so are represented @@ -382,11 +248,12 @@ public: PyRegionIterator dunderIter() { operation->checkValid(); - return PyRegionIterator(operation); + return PyRegionIterator(operation, startIndex); } static void bindDerived(ClassTy &c) { - c.def("__iter__", &PyRegionList::dunderIter); + c.def("__iter__", &PyRegionList::dunderIter, + "Returns an iterator over the regions in the sequence."); } private: @@ -430,8 +297,10 @@ public: static void bind(nb::module_ &m) { nb::class_<PyBlockIterator>(m, "BlockIterator") - .def("__iter__", &PyBlockIterator::dunderIter) - .def("__next__", &PyBlockIterator::dunderNext); + .def("__iter__", &PyBlockIterator::dunderIter, + "Returns an iterator over the blocks in the operation's region.") + .def("__next__", &PyBlockIterator::dunderNext, + "Returns the next block in the iteration."); } private: @@ -493,10 +362,19 @@ public: static void bind(nb::module_ &m) { nb::class_<PyBlockList>(m, "BlockList") - .def("__getitem__", &PyBlockList::dunderGetItem) - .def("__iter__", &PyBlockList::dunderIter) - .def("__len__", &PyBlockList::dunderLen) - .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring, + .def("__getitem__", &PyBlockList::dunderGetItem, + "Returns the block at the specified index.") + .def("__iter__", &PyBlockList::dunderIter, + "Returns an iterator over blocks in the operation's region.") + .def("__len__", &PyBlockList::dunderLen, + "Returns the number of blocks in the operation's region.") + .def("append", &PyBlockList::appendBlock, + R"( + Appends a new block, with argument types as positional args. + + Returns: + The created block. + )", nb::arg("args"), nb::kw_only(), nb::arg("arg_locs") = std::nullopt); } @@ -527,8 +405,10 @@ public: static void bind(nb::module_ &m) { nb::class_<PyOperationIterator>(m, "OperationIterator") - .def("__iter__", &PyOperationIterator::dunderIter) - .def("__next__", &PyOperationIterator::dunderNext); + .def("__iter__", &PyOperationIterator::dunderIter, + "Returns an iterator over the operations in an operation's block.") + .def("__next__", &PyOperationIterator::dunderNext, + "Returns the next operation in the iteration."); } private: @@ -584,9 +464,12 @@ public: static void bind(nb::module_ &m) { nb::class_<PyOperationList>(m, "OperationList") - .def("__getitem__", &PyOperationList::dunderGetItem) - .def("__iter__", &PyOperationList::dunderIter) - .def("__len__", &PyOperationList::dunderLen); + .def("__getitem__", &PyOperationList::dunderGetItem, + "Returns the operation at the specified index.") + .def("__iter__", &PyOperationList::dunderIter, + "Returns an iterator over operations in the list.") + .def("__len__", &PyOperationList::dunderLen, + "Returns the number of operations in the list."); } private: @@ -609,8 +492,10 @@ public: static void bind(nb::module_ &m) { nb::class_<PyOpOperand>(m, "OpOperand") - .def_prop_ro("owner", &PyOpOperand::getOwner) - .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber); + .def_prop_ro("owner", &PyOpOperand::getOwner, + "Returns the operation that owns this operand.") + .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber, + "Returns the operand number in the owning operation."); } private: @@ -634,8 +519,10 @@ public: static void bind(nb::module_ &m) { nb::class_<PyOpOperandIterator>(m, "OpOperandIterator") - .def("__iter__", &PyOpOperandIterator::dunderIter) - .def("__next__", &PyOpOperandIterator::dunderNext); + .def("__iter__", &PyOpOperandIterator::dunderIter, + "Returns an iterator over operands.") + .def("__next__", &PyOpOperandIterator::dunderNext, + "Returns the next operand in the iteration."); } private: @@ -1524,9 +1411,10 @@ nb::object PyOperation::create(std::string_view name, } // Construct the operation. + PyMlirContext::ErrorCapture errors(location.getContext()); MlirOperation operation = mlirOperationCreate(&state); if (!operation.ptr) - throw nb::value_error("Operation creation failed"); + throw MLIRError("Operation creation failed", errors.take()); PyOperationRef created = PyOperation::createDetached(location.getContext(), operation); maybeInsertOperation(created, maybeIp); @@ -1596,7 +1484,11 @@ public: /// Binds the Python module objects to functions of this class. static void bind(nb::module_ &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName); + auto cls = ClassTy( + m, DerivedTy::pyClassName, nb::is_generic(), + nb::sig((Twine("class ") + DerivedTy::pyClassName + "(Value[_T])") + .str() + .c_str())); cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value")); cls.def_static( "isinstance", @@ -1626,16 +1518,21 @@ public: static void bindDerived(ClassTy &c) { c.def_prop_ro( - "owner", [](PyOpResult &self) -> nb::typed<nb::object, PyOperation> { + "owner", + [](PyOpResult &self) -> nb::typed<nb::object, PyOperation> { assert(mlirOperationEqual(self.getParentOperation()->get(), mlirOpResultGetOwner(self.get())) && "expected the owner of the value in Python to match that in " "the IR"); return self.getParentOperation().getObject(); - }); - c.def_prop_ro("result_number", [](PyOpResult &self) { - return mlirOpResultGetResultNumber(self.get()); - }); + }, + "Returns the operation that produces this result."); + c.def_prop_ro( + "result_number", + [](PyOpResult &self) { + return mlirOpResultGetResultNumber(self.get()); + }, + "Returns the position of this result in the operation's result list."); } }; @@ -1671,13 +1568,18 @@ public: operation(std::move(operation)) {} static void bindDerived(ClassTy &c) { - c.def_prop_ro("types", [](PyOpResultList &self) { - return getValueTypes(self, self.operation->getContext()); - }); - c.def_prop_ro("owner", - [](PyOpResultList &self) -> nb::typed<nb::object, PyOpView> { - return self.operation->createOpView(); - }); + c.def_prop_ro( + "types", + [](PyOpResultList &self) { + return getValueTypes(self, self.operation->getContext()); + }, + "Returns a list of types for all results in this result list."); + c.def_prop_ro( + "owner", + [](PyOpResultList &self) -> nb::typed<nb::object, PyOpView> { + return self.operation->createOpView(); + }, + "Returns the operation that owns this result list."); } PyOperationRef &getOperation() { return operation; } @@ -2427,19 +2329,31 @@ public: using PyConcreteValue::PyConcreteValue; static void bindDerived(ClassTy &c) { - c.def_prop_ro("owner", [](PyBlockArgument &self) { - return PyBlock(self.getParentOperation(), - mlirBlockArgumentGetOwner(self.get())); - }); - c.def_prop_ro("arg_number", [](PyBlockArgument &self) { - return mlirBlockArgumentGetArgNumber(self.get()); - }); + c.def_prop_ro( + "owner", + [](PyBlockArgument &self) { + return PyBlock(self.getParentOperation(), + mlirBlockArgumentGetOwner(self.get())); + }, + "Returns the block that owns this argument."); + c.def_prop_ro( + "arg_number", + [](PyBlockArgument &self) { + return mlirBlockArgumentGetArgNumber(self.get()); + }, + "Returns the position of this argument in the block's argument list."); c.def( "set_type", [](PyBlockArgument &self, PyType type) { return mlirBlockArgumentSetType(self.get(), type); }, - nb::arg("type")); + nb::arg("type"), "Sets the type of this block argument."); + c.def( + "set_location", + [](PyBlockArgument &self, PyLocation loc) { + return mlirBlockArgumentSetLocation(self.get(), loc); + }, + nb::arg("loc"), "Sets the location of this block argument."); } }; @@ -2462,9 +2376,12 @@ public: operation(std::move(operation)), block(block) {} static void bindDerived(ClassTy &c) { - c.def_prop_ro("types", [](PyBlockArgumentList &self) { - return getValueTypes(self, self.operation->getContext()); - }); + c.def_prop_ro( + "types", + [](PyBlockArgumentList &self) { + return getValueTypes(self, self.operation->getContext()); + }, + "Returns a list of types for all arguments in this argument list."); } private: @@ -2516,7 +2433,9 @@ public: } static void bindDerived(ClassTy &c) { - c.def("__setitem__", &PyOpOperandList::dunderSetItem); + c.def("__setitem__", &PyOpOperandList::dunderSetItem, nb::arg("index"), + nb::arg("value"), + "Sets the operand at the specified index to a new value."); } private: @@ -2571,7 +2490,8 @@ public: } static void bindDerived(ClassTy &c) { - c.def("__setitem__", &PyOpSuccessors::dunderSetItem); + c.def("__setitem__", &PyOpSuccessors::dunderSetItem, nb::arg("index"), + nb::arg("block"), "Sets the successor block at the specified index."); } private: @@ -2743,55 +2663,70 @@ public: static void bind(nb::module_ &m) { nb::class_<PyOpAttributeMap>(m, "OpAttributeMap") - .def("__contains__", &PyOpAttributeMap::dunderContains) - .def("__len__", &PyOpAttributeMap::dunderLen) - .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) - .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) - .def("__setitem__", &PyOpAttributeMap::dunderSetItem) - .def("__delitem__", &PyOpAttributeMap::dunderDelItem) - .def("__iter__", - [](PyOpAttributeMap &self) { - nb::list keys; - PyOpAttributeMap::forEachAttr( - self.operation->get(), - [&](MlirStringRef name, MlirAttribute) { - keys.append(nb::str(name.data, name.length)); - }); - return nb::iter(keys); - }) - .def("keys", - [](PyOpAttributeMap &self) { - nb::list out; - PyOpAttributeMap::forEachAttr( - self.operation->get(), - [&](MlirStringRef name, MlirAttribute) { - out.append(nb::str(name.data, name.length)); - }); - return out; - }) - .def("values", - [](PyOpAttributeMap &self) { - nb::list out; - PyOpAttributeMap::forEachAttr( - self.operation->get(), - [&](MlirStringRef, MlirAttribute attr) { - out.append(PyAttribute(self.operation->getContext(), attr) - .maybeDownCast()); - }); - return out; - }) - .def("items", [](PyOpAttributeMap &self) { - nb::list out; - PyOpAttributeMap::forEachAttr( - self.operation->get(), - [&](MlirStringRef name, MlirAttribute attr) { - out.append(nb::make_tuple( - nb::str(name.data, name.length), - PyAttribute(self.operation->getContext(), attr) - .maybeDownCast())); - }); - return out; - }); + .def("__contains__", &PyOpAttributeMap::dunderContains, nb::arg("name"), + "Checks if an attribute with the given name exists in the map.") + .def("__len__", &PyOpAttributeMap::dunderLen, + "Returns the number of attributes in the map.") + .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed, + nb::arg("name"), "Gets an attribute by name.") + .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed, + nb::arg("index"), "Gets a named attribute by index.") + .def("__setitem__", &PyOpAttributeMap::dunderSetItem, nb::arg("name"), + nb::arg("attr"), "Sets an attribute with the given name.") + .def("__delitem__", &PyOpAttributeMap::dunderDelItem, nb::arg("name"), + "Deletes an attribute with the given name.") + .def( + "__iter__", + [](PyOpAttributeMap &self) { + nb::list keys; + PyOpAttributeMap::forEachAttr( + self.operation->get(), + [&](MlirStringRef name, MlirAttribute) { + keys.append(nb::str(name.data, name.length)); + }); + return nb::iter(keys); + }, + "Iterates over attribute names.") + .def( + "keys", + [](PyOpAttributeMap &self) { + nb::list out; + PyOpAttributeMap::forEachAttr( + self.operation->get(), + [&](MlirStringRef name, MlirAttribute) { + out.append(nb::str(name.data, name.length)); + }); + return out; + }, + "Returns a list of attribute names.") + .def( + "values", + [](PyOpAttributeMap &self) { + nb::list out; + PyOpAttributeMap::forEachAttr( + self.operation->get(), + [&](MlirStringRef, MlirAttribute attr) { + out.append(PyAttribute(self.operation->getContext(), attr) + .maybeDownCast()); + }); + return out; + }, + "Returns a list of attribute values.") + .def( + "items", + [](PyOpAttributeMap &self) { + nb::list out; + PyOpAttributeMap::forEachAttr( + self.operation->get(), + [&](MlirStringRef name, MlirAttribute attr) { + out.append(nb::make_tuple( + nb::str(name.data, name.length), + PyAttribute(self.operation->getContext(), attr) + .maybeDownCast())); + }); + return out; + }, + "Returns a list of `(name, attribute)` tuples."); } private: @@ -2979,62 +2914,103 @@ void mlir::python::populateIRCore(nb::module_ &m) { // Mapping of Diagnostics. //---------------------------------------------------------------------------- nb::class_<PyDiagnostic>(m, "Diagnostic") - .def_prop_ro("severity", &PyDiagnostic::getSeverity) - .def_prop_ro("location", &PyDiagnostic::getLocation) - .def_prop_ro("message", &PyDiagnostic::getMessage) - .def_prop_ro("notes", &PyDiagnostic::getNotes) - .def("__str__", [](PyDiagnostic &self) -> nb::str { - if (!self.isValid()) - return nb::str("<Invalid Diagnostic>"); - return self.getMessage(); - }); + .def_prop_ro("severity", &PyDiagnostic::getSeverity, + "Returns the severity of the diagnostic.") + .def_prop_ro("location", &PyDiagnostic::getLocation, + "Returns the location associated with the diagnostic.") + .def_prop_ro("message", &PyDiagnostic::getMessage, + "Returns the message text of the diagnostic.") + .def_prop_ro("notes", &PyDiagnostic::getNotes, + "Returns a tuple of attached note diagnostics.") + .def( + "__str__", + [](PyDiagnostic &self) -> nb::str { + if (!self.isValid()) + return nb::str("<Invalid Diagnostic>"); + return self.getMessage(); + }, + "Returns the diagnostic message as a string."); nb::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo") - .def("__init__", - [](PyDiagnostic::DiagnosticInfo &self, PyDiagnostic diag) { - new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo()); - }) - .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity) - .def_ro("location", &PyDiagnostic::DiagnosticInfo::location) - .def_ro("message", &PyDiagnostic::DiagnosticInfo::message) - .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes) - .def("__str__", - [](PyDiagnostic::DiagnosticInfo &self) { return self.message; }); + .def( + "__init__", + [](PyDiagnostic::DiagnosticInfo &self, PyDiagnostic diag) { + new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo()); + }, + "diag"_a, "Creates a DiagnosticInfo from a Diagnostic.") + .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity, + "The severity level of the diagnostic.") + .def_ro("location", &PyDiagnostic::DiagnosticInfo::location, + "The location associated with the diagnostic.") + .def_ro("message", &PyDiagnostic::DiagnosticInfo::message, + "The message text of the diagnostic.") + .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes, + "List of attached note diagnostics.") + .def( + "__str__", + [](PyDiagnostic::DiagnosticInfo &self) { return self.message; }, + "Returns the diagnostic message as a string."); nb::class_<PyDiagnosticHandler>(m, "DiagnosticHandler") - .def("detach", &PyDiagnosticHandler::detach) - .def_prop_ro("attached", &PyDiagnosticHandler::isAttached) - .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError) - .def("__enter__", &PyDiagnosticHandler::contextEnter) + .def("detach", &PyDiagnosticHandler::detach, + "Detaches the diagnostic handler from the context.") + .def_prop_ro("attached", &PyDiagnosticHandler::isAttached, + "Returns True if the handler is attached to a context.") + .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError, + "Returns True if an error was encountered during diagnostic " + "handling.") + .def("__enter__", &PyDiagnosticHandler::contextEnter, + "Enters the diagnostic handler as a context manager.") .def("__exit__", &PyDiagnosticHandler::contextExit, nb::arg("exc_type").none(), nb::arg("exc_value").none(), - nb::arg("traceback").none()); + nb::arg("traceback").none(), + "Exits the diagnostic handler context manager."); // Expose DefaultThreadPool to python nb::class_<PyThreadPool>(m, "ThreadPool") - .def("__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); }) - .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency) - .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr); + .def( + "__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); }, + "Creates a new thread pool with default concurrency.") + .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency, + "Returns the maximum number of threads in the pool.") + .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr, + "Returns the raw pointer to the LLVM thread pool as a string."); nb::class_<PyMlirContext>(m, "Context") - .def("__init__", - [](PyMlirContext &self) { - MlirContext context = mlirContextCreateWithThreading(false); - new (&self) PyMlirContext(context); - }) - .def_static("_get_live_count", &PyMlirContext::getLiveCount) - .def("_get_context_again", - [](PyMlirContext &self) -> nb::typed<nb::object, PyMlirContext> { - PyMlirContextRef ref = PyMlirContext::forContext(self.get()); - return ref.releaseObject(); - }) - .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule) + .def( + "__init__", + [](PyMlirContext &self) { + MlirContext context = mlirContextCreateWithThreading(false); + new (&self) PyMlirContext(context); + }, + R"( + Creates a new MLIR context. + + The context is the top-level container for all MLIR objects. It owns the storage + for types, attributes, locations, and other core IR objects. A context can be + configured to allow or disallow unregistered dialects and can have dialects + loaded on-demand.)") + .def_static("_get_live_count", &PyMlirContext::getLiveCount, + "Gets the number of live Context objects.") + .def( + "_get_context_again", + [](PyMlirContext &self) -> nb::typed<nb::object, PyMlirContext> { + PyMlirContextRef ref = PyMlirContext::forContext(self.get()); + return ref.releaseObject(); + }, + "Gets another reference to the same context.") + .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount, + "Gets the number of live modules owned by this context.") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule, + "Gets a capsule wrapping the MlirContext.") .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, - &PyMlirContext::createFromCapsule) - .def("__enter__", &PyMlirContext::contextEnter) + &PyMlirContext::createFromCapsule, + "Creates a Context from a capsule wrapping MlirContext.") + .def("__enter__", &PyMlirContext::contextEnter, + "Enters the context as a context manager.") .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(), - nb::arg("exc_value").none(), nb::arg("traceback").none()) + nb::arg("exc_value").none(), nb::arg("traceback").none(), + "Exits the context manager.") .def_prop_ro_static( "current", [](nb::object & /*class*/) @@ -3045,14 +3021,15 @@ void mlir::python::populateIRCore(nb::module_ &m) { return nb::cast(context); }, nb::sig("def current(/) -> Context | None"), - "Gets the Context bound to the current thread or raises ValueError") + "Gets the Context bound to the current thread or returns None if no " + "context is set.") .def_prop_ro( "dialects", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, - "Gets a container for accessing dialects by name") + "Gets a container for accessing dialects by name.") .def_prop_ro( "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, - "Alias for 'dialect'") + "Alias for `dialects`.") .def( "get_dialect_descriptor", [=](PyMlirContext &self, std::string &name) { @@ -3065,7 +3042,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { return PyDialectDescriptor(self.getRef(), dialect); }, nb::arg("dialect_name"), - "Gets or loads a dialect by name, returning its descriptor object") + "Gets or loads a dialect by name, returning its descriptor object.") .def_prop_rw( "allow_unregistered_dialects", [](PyMlirContext &self) -> bool { @@ -3073,67 +3050,110 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, [](PyMlirContext &self, bool value) { mlirContextSetAllowUnregisteredDialects(self.get(), value); - }) + }, + "Controls whether unregistered dialects are allowed in this context.") .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler, nb::arg("callback"), - "Attaches a diagnostic handler that will receive callbacks") + "Attaches a diagnostic handler that will receive callbacks.") .def( "enable_multithreading", [](PyMlirContext &self, bool enable) { mlirContextEnableMultithreading(self.get(), enable); }, - nb::arg("enable")) - .def("set_thread_pool", - [](PyMlirContext &self, PyThreadPool &pool) { - // we should disable multi-threading first before setting - // new thread pool otherwise the assert in - // MLIRContext::setThreadPool will be raised. - mlirContextEnableMultithreading(self.get(), false); - mlirContextSetThreadPool(self.get(), pool.get()); - }) - .def("get_num_threads", - [](PyMlirContext &self) { - return mlirContextGetNumThreads(self.get()); - }) - .def("_mlir_thread_pool_ptr", - [](PyMlirContext &self) { - MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get()); - std::stringstream ss; - ss << pool.ptr; - return ss.str(); - }) + nb::arg("enable"), + R"( + Enables or disables multi-threading support in the context. + + Args: + enable: Whether to enable (True) or disable (False) multi-threading. + )") + .def( + "set_thread_pool", + [](PyMlirContext &self, PyThreadPool &pool) { + // we should disable multi-threading first before setting + // new thread pool otherwise the assert in + // MLIRContext::setThreadPool will be raised. + mlirContextEnableMultithreading(self.get(), false); + mlirContextSetThreadPool(self.get(), pool.get()); + }, + R"( + Sets a custom thread pool for the context to use. + + Args: + pool: A ThreadPool object to use for parallel operations. + + Note: + Multi-threading is automatically disabled before setting the thread pool.)") + .def( + "get_num_threads", + [](PyMlirContext &self) { + return mlirContextGetNumThreads(self.get()); + }, + "Gets the number of threads in the context's thread pool.") + .def( + "_mlir_thread_pool_ptr", + [](PyMlirContext &self) { + MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get()); + std::stringstream ss; + ss << pool.ptr; + return ss.str(); + }, + "Gets the raw pointer to the LLVM thread pool as a string.") .def( "is_registered_operation", [](PyMlirContext &self, std::string &name) { return mlirContextIsRegisteredOperation( self.get(), MlirStringRef{name.data(), name.size()}); }, - nb::arg("operation_name")) + nb::arg("operation_name"), + R"( + Checks whether an operation with the given name is registered. + + Args: + operation_name: The fully qualified name of the operation (e.g., `arith.addf`). + + Returns: + True if the operation is registered, False otherwise.)") .def( "append_dialect_registry", [](PyMlirContext &self, PyDialectRegistry ®istry) { mlirContextAppendDialectRegistry(self.get(), registry); }, - nb::arg("registry")) + nb::arg("registry"), + R"( + Appends the contents of a dialect registry to the context. + + Args: + registry: A DialectRegistry containing dialects to append.)") .def_prop_rw("emit_error_diagnostics", &PyMlirContext::getEmitErrorDiagnostics, &PyMlirContext::setEmitErrorDiagnostics, - "Emit error diagnostics to diagnostic handlers. By default " - "error diagnostics are captured and reported through " - "MLIRError exceptions.") - .def("load_all_available_dialects", [](PyMlirContext &self) { - mlirContextLoadAllAvailableDialects(self.get()); - }); + R"( + Controls whether error diagnostics are emitted to diagnostic handlers. + + By default, error diagnostics are captured and reported through MLIRError exceptions.)") + .def( + "load_all_available_dialects", + [](PyMlirContext &self) { + mlirContextLoadAllAvailableDialects(self.get()); + }, + R"( + Loads all dialects available in the registry into the context. + + This eagerly loads all dialects that have been registered, making them + immediately available for use.)"); //---------------------------------------------------------------------------- // Mapping of PyDialectDescriptor //---------------------------------------------------------------------------- nb::class_<PyDialectDescriptor>(m, "DialectDescriptor") - .def_prop_ro("namespace", - [](PyDialectDescriptor &self) { - MlirStringRef ns = mlirDialectGetNamespace(self.get()); - return nb::str(ns.data, ns.length); - }) + .def_prop_ro( + "namespace", + [](PyDialectDescriptor &self) { + MlirStringRef ns = mlirDialectGetNamespace(self.get()); + return nb::str(ns.data, ns.length); + }, + "Returns the namespace of the dialect.") .def( "__repr__", [](PyDialectDescriptor &self) { @@ -3143,35 +3163,43 @@ void mlir::python::populateIRCore(nb::module_ &m) { repr.append(">"); return repr; }, - nb::sig("def __repr__(self) -> str")); + nb::sig("def __repr__(self) -> str"), + "Returns a string representation of the dialect descriptor."); //---------------------------------------------------------------------------- // Mapping of PyDialects //---------------------------------------------------------------------------- nb::class_<PyDialects>(m, "Dialects") - .def("__getitem__", - [=](PyDialects &self, std::string keyName) { - MlirDialect dialect = - self.getDialectForKey(keyName, /*attrError=*/false); - nb::object descriptor = - nb::cast(PyDialectDescriptor{self.getContext(), dialect}); - return createCustomDialectWrapper(keyName, std::move(descriptor)); - }) - .def("__getattr__", [=](PyDialects &self, std::string attrName) { - MlirDialect dialect = - self.getDialectForKey(attrName, /*attrError=*/true); - nb::object descriptor = - nb::cast(PyDialectDescriptor{self.getContext(), dialect}); - return createCustomDialectWrapper(attrName, std::move(descriptor)); - }); + .def( + "__getitem__", + [=](PyDialects &self, std::string keyName) { + MlirDialect dialect = + self.getDialectForKey(keyName, /*attrError=*/false); + nb::object descriptor = + nb::cast(PyDialectDescriptor{self.getContext(), dialect}); + return createCustomDialectWrapper(keyName, std::move(descriptor)); + }, + "Gets a dialect by name using subscript notation.") + .def( + "__getattr__", + [=](PyDialects &self, std::string attrName) { + MlirDialect dialect = + self.getDialectForKey(attrName, /*attrError=*/true); + nb::object descriptor = + nb::cast(PyDialectDescriptor{self.getContext(), dialect}); + return createCustomDialectWrapper(attrName, std::move(descriptor)); + }, + "Gets a dialect by name using attribute notation."); //---------------------------------------------------------------------------- // Mapping of PyDialect //---------------------------------------------------------------------------- nb::class_<PyDialect>(m, "Dialect") - .def(nb::init<nb::object>(), nb::arg("descriptor")) - .def_prop_ro("descriptor", - [](PyDialect &self) { return self.getDescriptor(); }) + .def(nb::init<nb::object>(), nb::arg("descriptor"), + "Creates a Dialect from a DialectDescriptor.") + .def_prop_ro( + "descriptor", [](PyDialect &self) { return self.getDescriptor(); }, + "Returns the DialectDescriptor for this dialect.") .def( "__repr__", [](const nb::object &self) { @@ -3181,31 +3209,43 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::str(" (class ") + clazz.attr("__module__") + nb::str(".") + clazz.attr("__name__") + nb::str(")>"); }, - nb::sig("def __repr__(self) -> str")); + nb::sig("def __repr__(self) -> str"), + "Returns a string representation of the dialect."); //---------------------------------------------------------------------------- // Mapping of PyDialectRegistry //---------------------------------------------------------------------------- nb::class_<PyDialectRegistry>(m, "DialectRegistry") - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule, + "Gets a capsule wrapping the MlirDialectRegistry.") .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, - &PyDialectRegistry::createFromCapsule) - .def(nb::init<>()); + &PyDialectRegistry::createFromCapsule, + "Creates a DialectRegistry from a capsule wrapping " + "`MlirDialectRegistry`.") + .def(nb::init<>(), "Creates a new empty dialect registry."); //---------------------------------------------------------------------------- // Mapping of Location //---------------------------------------------------------------------------- nb::class_<PyLocation>(m, "Location") - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) - .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) - .def("__enter__", &PyLocation::contextEnter) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule, + "Gets a capsule wrapping the MlirLocation.") + .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule, + "Creates a Location from a capsule wrapping MlirLocation.") + .def("__enter__", &PyLocation::contextEnter, + "Enters the location as a context manager.") .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(), - nb::arg("exc_value").none(), nb::arg("traceback").none()) - .def("__eq__", - [](PyLocation &self, PyLocation &other) -> bool { - return mlirLocationEqual(self, other); - }) - .def("__eq__", [](PyLocation &self, nb::object other) { return false; }) + nb::arg("exc_value").none(), nb::arg("traceback").none(), + "Exits the location context manager.") + .def( + "__eq__", + [](PyLocation &self, PyLocation &other) -> bool { + return mlirLocationEqual(self, other); + }, + "Compares two locations for equality.") + .def( + "__eq__", [](PyLocation &self, nb::object other) { return false; }, + "Compares location with non-location object (always returns False).") .def_prop_ro_static( "current", [](nb::object & /*class*/) -> std::optional<PyLocation *> { @@ -3217,7 +3257,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { // clang-format off nb::sig("def current(/) -> Location | None"), // clang-format on - "Gets the Location bound to the current thread or raises ValueError") + "Gets the Location bound to the current thread or raises ValueError.") .def_static( "unknown", [](DefaultingPyMlirContext context) { @@ -3225,13 +3265,13 @@ void mlir::python::populateIRCore(nb::module_ &m) { mlirLocationUnknownGet(context->get())); }, nb::arg("context") = nb::none(), - "Gets a Location representing an unknown location") + "Gets a Location representing an unknown location.") .def_static( "callsite", [](PyLocation callee, const std::vector<PyLocation> &frames, DefaultingPyMlirContext context) { if (frames.empty()) - throw nb::value_error("No caller frames provided"); + throw nb::value_error("No caller frames provided."); MlirLocation caller = frames.back().get(); for (const PyLocation &frame : llvm::reverse(llvm::ArrayRef(frames).drop_back())) @@ -3240,18 +3280,23 @@ void mlir::python::populateIRCore(nb::module_ &m) { mlirLocationCallSiteGet(callee.get(), caller)); }, nb::arg("callee"), nb::arg("frames"), nb::arg("context") = nb::none(), - kContextGetCallSiteLocationDocstring) - .def("is_a_callsite", mlirLocationIsACallSite) - .def_prop_ro("callee", - [](PyLocation &self) { - return PyLocation(self.getContext(), - mlirLocationCallSiteGetCallee(self)); - }) - .def_prop_ro("caller", - [](PyLocation &self) { - return PyLocation(self.getContext(), - mlirLocationCallSiteGetCaller(self)); - }) + "Gets a Location representing a caller and callsite.") + .def("is_a_callsite", mlirLocationIsACallSite, + "Returns True if this location is a CallSiteLoc.") + .def_prop_ro( + "callee", + [](PyLocation &self) { + return PyLocation(self.getContext(), + mlirLocationCallSiteGetCallee(self)); + }, + "Gets the callee location from a CallSiteLoc.") + .def_prop_ro( + "caller", + [](PyLocation &self) { + return PyLocation(self.getContext(), + mlirLocationCallSiteGetCaller(self)); + }, + "Gets the caller location from a CallSiteLoc.") .def_static( "file", [](std::string filename, int line, int col, @@ -3262,7 +3307,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { context->get(), toMlirStringRef(filename), line, col)); }, nb::arg("filename"), nb::arg("line"), nb::arg("col"), - nb::arg("context") = nb::none(), kContextGetFileLocationDocstring) + nb::arg("context") = nb::none(), + "Gets a Location representing a file, line and column.") .def_static( "file", [](std::string filename, int startLine, int startCol, int endLine, @@ -3274,17 +3320,25 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"), nb::arg("end_line"), nb::arg("end_col"), - nb::arg("context") = nb::none(), kContextGetFileRangeDocstring) - .def("is_a_file", mlirLocationIsAFileLineColRange) - .def_prop_ro("filename", - [](MlirLocation loc) { - return mlirIdentifierStr( - mlirLocationFileLineColRangeGetFilename(loc)); - }) - .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine) - .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn) - .def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine) - .def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn) + nb::arg("context") = nb::none(), + "Gets a Location representing a file, line and column range.") + .def("is_a_file", mlirLocationIsAFileLineColRange, + "Returns True if this location is a FileLineColLoc.") + .def_prop_ro( + "filename", + [](MlirLocation loc) { + return mlirIdentifierStr( + mlirLocationFileLineColRangeGetFilename(loc)); + }, + "Gets the filename from a FileLineColLoc.") + .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine, + "Gets the start line number from a `FileLineColLoc`.") + .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn, + "Gets the start column number from a `FileLineColLoc`.") + .def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine, + "Gets the end line number from a `FileLineColLoc`.") + .def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn, + "Gets the end column number from a `FileLineColLoc`.") .def_static( "fused", [](const std::vector<PyLocation> &pyLocations, @@ -3300,8 +3354,11 @@ void mlir::python::populateIRCore(nb::module_ &m) { return PyLocation(context->getRef(), location); }, nb::arg("locations"), nb::arg("metadata") = nb::none(), - nb::arg("context") = nb::none(), kContextGetFusedLocationDocstring) - .def("is_a_fused", mlirLocationIsAFused) + nb::arg("context") = nb::none(), + "Gets a Location representing a fused location with optional " + "metadata.") + .def("is_a_fused", mlirLocationIsAFused, + "Returns True if this location is a `FusedLoc`.") .def_prop_ro( "locations", [](PyLocation &self) { @@ -3314,7 +3371,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { for (unsigned i = 0; i < numLocations; ++i) pyLocations.emplace_back(self.getContext(), locations[i]); return pyLocations; - }) + }, + "Gets the list of locations from a `FusedLoc`.") .def_static( "name", [](std::string name, std::optional<PyLocation> childLoc, @@ -3327,17 +3385,24 @@ void mlir::python::populateIRCore(nb::module_ &m) { : mlirLocationUnknownGet(context->get()))); }, nb::arg("name"), nb::arg("childLoc") = nb::none(), - nb::arg("context") = nb::none(), kContextGetNameLocationDocString) - .def("is_a_name", mlirLocationIsAName) - .def_prop_ro("name_str", - [](MlirLocation loc) { - return mlirIdentifierStr(mlirLocationNameGetName(loc)); - }) - .def_prop_ro("child_loc", - [](PyLocation &self) { - return PyLocation(self.getContext(), - mlirLocationNameGetChildLoc(self)); - }) + nb::arg("context") = nb::none(), + "Gets a Location representing a named location with optional child " + "location.") + .def("is_a_name", mlirLocationIsAName, + "Returns True if this location is a `NameLoc`.") + .def_prop_ro( + "name_str", + [](MlirLocation loc) { + return mlirIdentifierStr(mlirLocationNameGetName(loc)); + }, + "Gets the name string from a `NameLoc`.") + .def_prop_ro( + "child_loc", + [](PyLocation &self) { + return PyLocation(self.getContext(), + mlirLocationNameGetChildLoc(self)); + }, + "Gets the child location from a `NameLoc`.") .def_static( "from_attr", [](PyAttribute &attribute, DefaultingPyMlirContext context) { @@ -3345,41 +3410,59 @@ void mlir::python::populateIRCore(nb::module_ &m) { mlirLocationFromAttribute(attribute)); }, nb::arg("attribute"), nb::arg("context") = nb::none(), - "Gets a Location from a LocationAttr") + "Gets a Location from a `LocationAttr`.") .def_prop_ro( "context", [](PyLocation &self) -> nb::typed<nb::object, PyMlirContext> { return self.getContext().getObject(); }, - "Context that owns the Location") + "Context that owns the `Location`.") .def_prop_ro( "attr", [](PyLocation &self) { return PyAttribute(self.getContext(), mlirLocationGetAttribute(self)); }, - "Get the underlying LocationAttr") + "Get the underlying `LocationAttr`.") .def( "emit_error", [](PyLocation &self, std::string message) { mlirEmitError(self, message.c_str()); }, - nb::arg("message"), "Emits an error at this location") - .def("__repr__", [](PyLocation &self) { - PyPrintAccumulator printAccum; - mlirLocationPrint(self, printAccum.getCallback(), - printAccum.getUserData()); - return printAccum.join(); - }); + nb::arg("message"), + R"( + Emits an error diagnostic at this location. + + Args: + message: The error message to emit.)") + .def( + "__repr__", + [](PyLocation &self) { + PyPrintAccumulator printAccum; + mlirLocationPrint(self, printAccum.getCallback(), + printAccum.getUserData()); + return printAccum.join(); + }, + "Returns the assembly representation of the location."); //---------------------------------------------------------------------------- // Mapping of Module //---------------------------------------------------------------------------- nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable()) - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule, + "Gets a capsule wrapping the MlirModule.") .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule, - kModuleCAPICreate) - .def("_clear_mlir_module", &PyModule::clearMlirModule) + R"( + Creates a Module from a `MlirModule` wrapped by a capsule (i.e. `module._CAPIPtr`). + + This returns a new object **BUT** `_clear_mlir_module(module)` must be called to + prevent double-frees (of the underlying `mlir::Module`).)") + .def("_clear_mlir_module", &PyModule::clearMlirModule, + R"( + Clears the internal MLIR module reference. + + This is used internally to prevent double-free when ownership is transferred + via the C API capsule mechanism. Not intended for normal use.)") .def_static( "parse", [](const std::string &moduleAsm, DefaultingPyMlirContext context) @@ -3427,13 +3510,13 @@ void mlir::python::populateIRCore(nb::module_ &m) { MlirModule module = mlirModuleCreateEmpty(pyLoc.get()); return PyModule::forModule(module).releaseObject(); }, - nb::arg("loc") = nb::none(), "Creates an empty module") + nb::arg("loc") = nb::none(), "Creates an empty module.") .def_prop_ro( "context", [](PyModule &self) -> nb::typed<nb::object, PyMlirContext> { return self.getContext().getObject(); }, - "Context that created the Module") + "Context that created the `Module`.") .def_prop_ro( "operation", [](PyModule &self) -> nb::typed<nb::object, PyOperation> { @@ -3442,7 +3525,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { self.getRef().releaseObject()) .releaseObject(); }, - "Accesses the module as an operation") + "Accesses the module as an operation.") .def_prop_ro( "body", [](PyModule &self) { @@ -3452,7 +3535,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get())); return returnBlock; }, - "Return the block for this module") + "Return the block for this module.") .def( "dump", [](PyModule &self) { @@ -3465,39 +3548,59 @@ void mlir::python::populateIRCore(nb::module_ &m) { // Defer to the operation's __str__. return self.attr("operation").attr("__str__")(); }, - nb::sig("def __str__(self) -> str"), kOperationStrDunderDocstring) + nb::sig("def __str__(self) -> str"), + R"( + Gets the assembly form of the operation with default options. + + If more advanced control over the assembly formatting or I/O options is needed, + use the dedicated print or get_asm method, which supports keyword arguments to + customize behavior. + )") .def( "__eq__", [](PyModule &self, PyModule &other) { return mlirModuleEqual(self.get(), other.get()); }, - "other"_a) - .def("__hash__", - [](PyModule &self) { return mlirModuleHashValue(self.get()); }); + "other"_a, "Compares two modules for equality.") + .def( + "__hash__", + [](PyModule &self) { return mlirModuleHashValue(self.get()); }, + "Returns the hash value of the module."); //---------------------------------------------------------------------------- // Mapping of Operation. //---------------------------------------------------------------------------- nb::class_<PyOperationBase>(m, "_OperationBase") - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, - [](PyOperationBase &self) { - return self.getOperation().getCapsule(); - }) - .def("__eq__", - [](PyOperationBase &self, PyOperationBase &other) { - return mlirOperationEqual(self.getOperation().get(), - other.getOperation().get()); - }) - .def("__eq__", - [](PyOperationBase &self, nb::object other) { return false; }) - .def("__hash__", - [](PyOperationBase &self) { - return mlirOperationHashValue(self.getOperation().get()); - }) - .def_prop_ro("attributes", - [](PyOperationBase &self) { - return PyOpAttributeMap(self.getOperation().getRef()); - }) + .def_prop_ro( + MLIR_PYTHON_CAPI_PTR_ATTR, + [](PyOperationBase &self) { + return self.getOperation().getCapsule(); + }, + "Gets a capsule wrapping the `MlirOperation`.") + .def( + "__eq__", + [](PyOperationBase &self, PyOperationBase &other) { + return mlirOperationEqual(self.getOperation().get(), + other.getOperation().get()); + }, + "Compares two operations for equality.") + .def( + "__eq__", + [](PyOperationBase &self, nb::object other) { return false; }, + "Compares operation with non-operation object (always returns " + "False).") + .def( + "__hash__", + [](PyOperationBase &self) { + return mlirOperationHashValue(self.getOperation().get()); + }, + "Returns the hash value of the operation.") + .def_prop_ro( + "attributes", + [](PyOperationBase &self) { + return PyOpAttributeMap(self.getOperation().getRef()); + }, + "Returns a dictionary-like map of operation attributes.") .def_prop_ro( "context", [](PyOperationBase &self) -> nb::typed<nb::object, PyMlirContext> { @@ -3505,22 +3608,28 @@ void mlir::python::populateIRCore(nb::module_ &m) { concreteOperation.checkValid(); return concreteOperation.getContext().getObject(); }, - "Context that owns the Operation") - .def_prop_ro("name", - [](PyOperationBase &self) { - auto &concreteOperation = self.getOperation(); - concreteOperation.checkValid(); - MlirOperation operation = concreteOperation.get(); - return mlirIdentifierStr(mlirOperationGetName(operation)); - }) - .def_prop_ro("operands", - [](PyOperationBase &self) { - return PyOpOperandList(self.getOperation().getRef()); - }) - .def_prop_ro("regions", - [](PyOperationBase &self) { - return PyRegionList(self.getOperation().getRef()); - }) + "Context that owns the operation.") + .def_prop_ro( + "name", + [](PyOperationBase &self) { + auto &concreteOperation = self.getOperation(); + concreteOperation.checkValid(); + MlirOperation operation = concreteOperation.get(); + return mlirIdentifierStr(mlirOperationGetName(operation)); + }, + "Returns the fully qualified name of the operation.") + .def_prop_ro( + "operands", + [](PyOperationBase &self) { + return PyOpOperandList(self.getOperation().getRef()); + }, + "Returns the list of operation operands.") + .def_prop_ro( + "regions", + [](PyOperationBase &self) { + return PyRegionList(self.getOperation().getRef()); + }, + "Returns the list of operation regions.") .def_prop_ro( "results", [](PyOperationBase &self) { @@ -3551,14 +3660,16 @@ void mlir::python::populateIRCore(nb::module_ &m) { "defined or derived from."), nb::for_setter("Sets the source location the operation was defined " "or derived from.")) - .def_prop_ro("parent", - [](PyOperationBase &self) - -> std::optional<nb::typed<nb::object, PyOperation>> { - auto parent = self.getOperation().getParentOperation(); - if (parent) - return parent->getObject(); - return {}; - }) + .def_prop_ro( + "parent", + [](PyOperationBase &self) + -> std::optional<nb::typed<nb::object, PyOperation>> { + auto parent = self.getOperation().getParentOperation(); + if (parent) + return parent->getObject(); + return {}; + }, + "Returns the parent operation, or `None` if at top level.") .def( "__str__", [](PyOperationBase &self) { @@ -3579,7 +3690,14 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::overload_cast<PyAsmState &, nb::object, bool>( &PyOperationBase::print), nb::arg("state"), nb::arg("file") = nb::none(), - nb::arg("binary") = false, kOperationPrintStateDocstring) + nb::arg("binary") = false, + R"( + Prints the assembly form of the operation to a file like object. + + Args: + state: `AsmState` capturing the operation numbering and flags. + file: Optional file like object to write to. Defaults to sys.stdout. + binary: Whether to write `bytes` (True) or `str` (False). Defaults to False.)") .def("print", nb::overload_cast<std::optional<int64_t>, std::optional<int64_t>, bool, bool, bool, bool, bool, bool, nb::object, @@ -3594,10 +3712,47 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("use_name_loc_as_prefix") = false, nb::arg("assume_verified") = false, nb::arg("file") = nb::none(), nb::arg("binary") = false, nb::arg("skip_regions") = false, - kOperationPrintDocstring) + R"( + Prints the assembly form of the operation to a file like object. + + Args: + large_elements_limit: Whether to elide elements attributes above this + number of elements. Defaults to None (no limit). + large_resource_limit: Whether to elide resource attributes above this + number of characters. Defaults to None (no limit). If large_elements_limit + is set and this is None, the behavior will be to use large_elements_limit + as large_resource_limit. + enable_debug_info: Whether to print debug/location information. Defaults + to False. + pretty_debug_info: Whether to format debug information for easier reading + by a human (warning: the result is unparseable). Defaults to False. + print_generic_op_form: Whether to print the generic assembly forms of all + ops. Defaults to False. + use_local_scope: Whether to print in a way that is more optimized for + multi-threaded access but may not be consistent with how the overall + module prints. + use_name_loc_as_prefix: Whether to use location attributes (NameLoc) as + prefixes for the SSA identifiers. Defaults to False. + assume_verified: By default, if not printing generic form, the verifier + will be run and if it fails, generic form will be printed with a comment + about failed verification. While a reasonable default for interactive use, + for systematic use, it is often better for the caller to verify explicitly + and report failures in a more robust fashion. Set this to True if doing this + in order to avoid running a redundant verification. If the IR is actually + invalid, behavior is undefined. + file: The file like object to write to. Defaults to sys.stdout. + binary: Whether to write bytes (True) or str (False). Defaults to False. + skip_regions: Whether to skip printing regions. Defaults to False.)") .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"), nb::arg("desired_version") = nb::none(), - kOperationPrintBytecodeDocstring) + R"( + Write the bytecode form of the operation to a file like object. + + Args: + file: The file like object to write to. + desired_version: Optional version of bytecode to emit. + Returns: + The bytecode writer status.)") .def("get_asm", &PyOperationBase::getAsm, // Careful: Lots of arguments must match up with get_asm method. nb::arg("binary") = false, @@ -3609,7 +3764,17 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("use_local_scope") = false, nb::arg("use_name_loc_as_prefix") = false, nb::arg("assume_verified") = false, nb::arg("skip_regions") = false, - kOperationGetAsmDocstring) + R"( + Gets the assembly form of the operation with all options available. + + Args: + binary: Whether to return a bytes (True) or str (False) object. Defaults to + False. + ... others ...: See the print() method for common keyword arguments for + configuring the printout. + Returns: + Either a bytes or str object, depending on the setting of the `binary` + argument.)") .def("verify", &PyOperationBase::verify, "Verify the operation. Raises MLIRError if verification fails, and " "returns true otherwise.") @@ -3621,18 +3786,31 @@ void mlir::python::populateIRCore(nb::module_ &m) { "block.") .def("is_before_in_block", &PyOperationBase::isBeforeInBlock, nb::arg("other"), - "Given an operation 'other' that is within the same parent block, " - "return" - "whether the current operation is before 'other' in the operation " - "list" - "of the parent block.") + R"( + Checks if this operation is before another in the same block. + + Args: + other: Another operation in the same parent block. + + Returns: + True if this operation is before `other` in the operation list of the parent block.)") .def( "clone", [](PyOperationBase &self, const nb::object &ip) -> nb::typed<nb::object, PyOperation> { return self.getOperation().clone(ip); }, - nb::arg("ip") = nb::none()) + nb::arg("ip") = nb::none(), + R"( + Creates a deep copy of the operation. + + Args: + ip: Optional insertion point where the cloned operation should be inserted. + If None, the current insertion point is used. If False, the operation + remains detached. + + Returns: + A new Operation that is a clone of this operation.)") .def( "detach_from_parent", [](PyOperationBase &self) -> nb::typed<nb::object, PyOpView> { @@ -3653,13 +3831,24 @@ void mlir::python::populateIRCore(nb::module_ &m) { return operation.isAttached(); }, "Reports if the operation is attached to its parent block.") - .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); }) + .def( + "erase", [](PyOperationBase &self) { self.getOperation().erase(); }, + R"( + Erases the operation and frees its memory. + + Note: + After erasing, any Python references to the operation become invalid.)") .def("walk", &PyOperationBase::walk, nb::arg("callback"), nb::arg("walk_order") = MlirWalkPostOrder, // clang-format off - nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None") + nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None"), // clang-format on - ); + R"( + Walks the operation tree with a callback function. + + Args: + callback: A callable that takes an Operation and returns a WalkResult. + walk_order: The order of traversal (PRE_ORDER or POST_ORDER).)"); nb::class_<PyOperation, PyOperationBase>(m, "Operation") .def_static( @@ -3692,7 +3881,22 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(), nb::arg("successors") = nb::none(), nb::arg("regions") = 0, nb::arg("loc") = nb::none(), nb::arg("ip") = nb::none(), - nb::arg("infer_type") = false, kOperationCreateDocstring) + nb::arg("infer_type") = false, + R"( + Creates a new operation. + + Args: + name: Operation name (e.g. `dialect.operation`). + results: Optional sequence of Type representing op result types. + operands: Optional operands of the operation. + attributes: Optional Dict of {str: Attribute}. + successors: Optional List of Block for the operation's successors. + regions: Number of regions to create (default = 0). + location: Optional Location object (defaults to resolve from context manager). + ip: Optional InsertionPoint (defaults to resolve from context manager or set to False to disable insertion, even with an insertion point set in the context manager). + infer_type: Whether to infer result types (default = False). + Returns: + A new detached Operation object. Detached operations can be added to blocks, which causes them to become attached.)") .def_static( "parse", [](const std::string &sourceStr, const std::string &sourceName, @@ -3705,18 +3909,30 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("context") = nb::none(), "Parses an operation. Supports both text assembly format and binary " "bytecode format.") - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule, + "Gets a capsule wrapping the MlirOperation.") .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, - &PyOperation::createFromCapsule) - .def_prop_ro("operation", - [](nb::object self) -> nb::typed<nb::object, PyOperation> { - return self; - }) - .def_prop_ro("opview", - [](PyOperation &self) -> nb::typed<nb::object, PyOpView> { - return self.createOpView(); - }) - .def_prop_ro("block", &PyOperation::getBlock) + &PyOperation::createFromCapsule, + "Creates an Operation from a capsule wrapping MlirOperation.") + .def_prop_ro( + "operation", + [](nb::object self) -> nb::typed<nb::object, PyOperation> { + return self; + }, + "Returns self (the operation).") + .def_prop_ro( + "opview", + [](PyOperation &self) -> nb::typed<nb::object, PyOpView> { + return self.createOpView(); + }, + R"( + Returns an OpView of this operation. + + Note: + If the operation has a registered and loaded dialect then this OpView will + be concrete wrapper class.)") + .def_prop_ro("block", &PyOperation::getBlock, + "Returns the block containing this operation.") .def_prop_ro( "successors", [](PyOperationBase &self) { @@ -3830,7 +4046,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, nb::arg("cls"), nb::arg("source"), nb::kw_only(), nb::arg("source_name") = "", nb::arg("context") = nb::none(), - "Parses a specific, generated OpView based on class level attributes"); + "Parses a specific, generated OpView based on class level attributes."); //---------------------------------------------------------------------------- // Mapping of PyRegion. @@ -3856,17 +4072,22 @@ void mlir::python::populateIRCore(nb::module_ &m) { return PyBlockIterator(self.getParentOperation(), firstBlock); }, "Iterates over blocks in the region.") - .def("__eq__", - [](PyRegion &self, PyRegion &other) { - return self.get().ptr == other.get().ptr; - }) - .def("__eq__", [](PyRegion &self, nb::object &other) { return false; }); + .def( + "__eq__", + [](PyRegion &self, PyRegion &other) { + return self.get().ptr == other.get().ptr; + }, + "Compares two regions for pointer equality.") + .def( + "__eq__", [](PyRegion &self, nb::object &other) { return false; }, + "Compares region with non-region object (always returns False)."); //---------------------------------------------------------------------------- // Mapping of PyBlock. //---------------------------------------------------------------------------- nb::class_<PyBlock>(m, "Block") - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule, + "Gets a capsule wrapping the MlirBlock.") .def_prop_ro( "owner", [](PyBlock &self) -> nb::typed<nb::object, PyOpView> { @@ -3893,14 +4114,26 @@ void mlir::python::populateIRCore(nb::module_ &m) { mlirBlockAddArgument(self.get(), type, loc)); }, "type"_a, "loc"_a, - "Append an argument of the specified type to the block and returns " - "the newly added argument.") + R"( + Appends an argument of the specified type to the block. + + Args: + type: The type of the argument to add. + loc: The source location for the argument. + + Returns: + The newly added block argument.)") .def( "erase_argument", [](PyBlock &self, unsigned index) { return mlirBlockEraseArgument(self.get(), index); }, - "Erase the argument at 'index' and remove it from the argument list.") + nb::arg("index"), + R"( + Erases the argument at the specified index. + + Args: + index: The index of the argument to erase.)") .def_prop_ro( "operations", [](PyBlock &self) { @@ -3928,7 +4161,14 @@ void mlir::python::populateIRCore(nb::module_ &m) { mlirBlockDetach(b); mlirRegionAppendOwnedBlock(region.get(), b); }, - "Append this block to a region, transferring ownership if necessary") + nb::arg("region"), + R"( + Appends this block to a region. + + Transfers ownership if the block is currently owned by another region. + + Args: + region: The region to append the block to.)") .def( "create_before", [](PyBlock &self, const nb::args &pyArgTypes, @@ -3969,15 +4209,21 @@ void mlir::python::populateIRCore(nb::module_ &m) { firstOperation); }, "Iterates over operations in the block.") - .def("__eq__", - [](PyBlock &self, PyBlock &other) { - return self.get().ptr == other.get().ptr; - }) - .def("__eq__", [](PyBlock &self, nb::object &other) { return false; }) - .def("__hash__", - [](PyBlock &self) { - return static_cast<size_t>(llvm::hash_value(self.get().ptr)); - }) + .def( + "__eq__", + [](PyBlock &self, PyBlock &other) { + return self.get().ptr == other.get().ptr; + }, + "Compares two blocks for pointer equality.") + .def( + "__eq__", [](PyBlock &self, nb::object &other) { return false; }, + "Compares block with non-block object (always returns False).") + .def( + "__hash__", + [](PyBlock &self) { + return static_cast<size_t>(llvm::hash_value(self.get().ptr)); + }, + "Returns the hash value of the block.") .def( "__str__", [](PyBlock &self) { @@ -4000,8 +4246,13 @@ void mlir::python::populateIRCore(nb::module_ &m) { self.getParentOperation().getObject()); }, nb::arg("operation"), - "Appends an operation to this block. If the operation is currently " - "in another block, it will be moved.") + R"( + Appends an operation to this block. + + If the operation is currently in another block, it will be moved. + + Args: + operation: The operation to append to the block.)") .def_prop_ro( "successors", [](PyBlock &self) { @@ -4022,10 +4273,12 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::class_<PyInsertionPoint>(m, "InsertionPoint") .def(nb::init<PyBlock &>(), nb::arg("block"), "Inserts after the last operation but still inside the block.") - .def("__enter__", &PyInsertionPoint::contextEnter) + .def("__enter__", &PyInsertionPoint::contextEnter, + "Enters the insertion point as a context manager.") .def("__exit__", &PyInsertionPoint::contextExit, nb::arg("exc_type").none(), nb::arg("exc_value").none(), - nb::arg("traceback").none()) + nb::arg("traceback").none(), + "Exits the insertion point context manager.") .def_prop_ro_static( "current", [](nb::object & /*class*/) { @@ -4036,20 +4289,50 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, nb::sig("def current(/) -> InsertionPoint"), "Gets the InsertionPoint bound to the current thread or raises " - "ValueError if none has been set") + "ValueError if none has been set.") .def(nb::init<PyOperationBase &>(), nb::arg("beforeOperation"), "Inserts before a referenced operation.") .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, - nb::arg("block"), "Inserts at the beginning of the block.") + nb::arg("block"), + R"( + Creates an insertion point at the beginning of a block. + + Args: + block: The block at whose beginning operations should be inserted. + + Returns: + An InsertionPoint at the block's beginning.)") .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, - nb::arg("block"), "Inserts before the block terminator.") + nb::arg("block"), + R"( + Creates an insertion point before a block's terminator. + + Args: + block: The block whose terminator to insert before. + + Returns: + An InsertionPoint before the terminator. + + Raises: + ValueError: If the block has no terminator.)") .def_static("after", &PyInsertionPoint::after, nb::arg("operation"), - "Inserts after the operation.") + R"( + Creates an insertion point immediately after an operation. + + Args: + operation: The operation after which to insert. + + Returns: + An InsertionPoint after the operation.)") .def("insert", &PyInsertionPoint::insert, nb::arg("operation"), - "Inserts an operation.") + R"( + Inserts an operation at this insertion point. + + Args: + operation: The operation to insert.)") .def_prop_ro( "block", [](PyInsertionPoint &self) { return self.getBlock(); }, - "Returns the block that this InsertionPoint points to.") + "Returns the block that this `InsertionPoint` points to.") .def_prop_ro( "ref_operation", [](PyInsertionPoint &self) @@ -4061,7 +4344,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, "The reference operation before which new operations are " "inserted, or None if the insertion point is at the end of " - "the block"); + "the block."); //---------------------------------------------------------------------------- // Mapping of PyAttribute. @@ -4070,10 +4353,12 @@ void mlir::python::populateIRCore(nb::module_ &m) { // Delegate to the PyAttribute copy constructor, which will also lifetime // extend the backing context which owns the MlirAttribute. .def(nb::init<PyAttribute &>(), nb::arg("cast_from_type"), - "Casts the passed attribute to the generic Attribute") - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule) - .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, - &PyAttribute::createFromCapsule) + "Casts the passed attribute to the generic `Attribute`.") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule, + "Gets a capsule wrapping the MlirAttribute.") + .def_static( + MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule, + "Creates an Attribute from a capsule wrapping `MlirAttribute`.") .def_static( "parse", [](const std::string &attrSpec, DefaultingPyMlirContext context) @@ -4086,33 +4371,49 @@ void mlir::python::populateIRCore(nb::module_ &m) { return PyAttribute(context.get()->getRef(), attr).maybeDownCast(); }, nb::arg("asm"), nb::arg("context") = nb::none(), - "Parses an attribute from an assembly form. Raises an MLIRError on " + "Parses an attribute from an assembly form. Raises an `MLIRError` on " "failure.") .def_prop_ro( "context", [](PyAttribute &self) -> nb::typed<nb::object, PyMlirContext> { return self.getContext().getObject(); }, - "Context that owns the Attribute") - .def_prop_ro("type", - [](PyAttribute &self) -> nb::typed<nb::object, PyType> { - return PyType(self.getContext(), - mlirAttributeGetType(self)) - .maybeDownCast(); - }) + "Context that owns the `Attribute`.") + .def_prop_ro( + "type", + [](PyAttribute &self) -> nb::typed<nb::object, PyType> { + return PyType(self.getContext(), mlirAttributeGetType(self)) + .maybeDownCast(); + }, + "Returns the type of the `Attribute`.") .def( "get_named", [](PyAttribute &self, std::string name) { return PyNamedAttribute(self, std::move(name)); }, - nb::keep_alive<0, 1>(), "Binds a name to the attribute") - .def("__eq__", - [](PyAttribute &self, PyAttribute &other) { return self == other; }) - .def("__eq__", [](PyAttribute &self, nb::object &other) { return false; }) - .def("__hash__", - [](PyAttribute &self) { - return static_cast<size_t>(llvm::hash_value(self.get().ptr)); - }) + nb::keep_alive<0, 1>(), + R"( + Binds a name to the attribute, creating a `NamedAttribute`. + + Args: + name: The name to bind to the `Attribute`. + + Returns: + A `NamedAttribute` with the given name and this attribute.)") + .def( + "__eq__", + [](PyAttribute &self, PyAttribute &other) { return self == other; }, + "Compares two attributes for equality.") + .def( + "__eq__", [](PyAttribute &self, nb::object &other) { return false; }, + "Compares attribute with non-attribute object (always returns " + "False).") + .def( + "__hash__", + [](PyAttribute &self) { + return static_cast<size_t>(llvm::hash_value(self.get().ptr)); + }, + "Returns the hash value of the attribute.") .def( "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, kDumpDocstring) @@ -4125,61 +4426,69 @@ void mlir::python::populateIRCore(nb::module_ &m) { return printAccum.join(); }, "Returns the assembly form of the Attribute.") - .def("__repr__", - [](PyAttribute &self) { - // Generally, assembly formats are not printed for __repr__ because - // this can cause exceptionally long debug output and exceptions. - // However, attribute values are generally considered useful and - // are printed. This may need to be re-evaluated if debug dumps end - // up being excessive. - PyPrintAccumulator printAccum; - printAccum.parts.append("Attribute("); - mlirAttributePrint(self, printAccum.getCallback(), - printAccum.getUserData()); - printAccum.parts.append(")"); - return printAccum.join(); - }) - .def_prop_ro("typeid", - [](PyAttribute &self) { - MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); - assert(!mlirTypeIDIsNull(mlirTypeID) && - "mlirTypeID was expected to be non-null."); - return PyTypeID(mlirTypeID); - }) - .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, - [](PyAttribute &self) -> nb::typed<nb::object, PyAttribute> { - return self.maybeDownCast(); - }); + .def( + "__repr__", + [](PyAttribute &self) { + // Generally, assembly formats are not printed for __repr__ because + // this can cause exceptionally long debug output and exceptions. + // However, attribute values are generally considered useful and + // are printed. This may need to be re-evaluated if debug dumps end + // up being excessive. + PyPrintAccumulator printAccum; + printAccum.parts.append("Attribute("); + mlirAttributePrint(self, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }, + "Returns a string representation of the attribute.") + .def_prop_ro( + "typeid", + [](PyAttribute &self) { + MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); + assert(!mlirTypeIDIsNull(mlirTypeID) && + "mlirTypeID was expected to be non-null."); + return PyTypeID(mlirTypeID); + }, + "Returns the `TypeID` of the attribute.") + .def( + MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, + [](PyAttribute &self) -> nb::typed<nb::object, PyAttribute> { + return self.maybeDownCast(); + }, + "Downcasts the attribute to a more specific attribute if possible."); //---------------------------------------------------------------------------- // Mapping of PyNamedAttribute //---------------------------------------------------------------------------- nb::class_<PyNamedAttribute>(m, "NamedAttribute") - .def("__repr__", - [](PyNamedAttribute &self) { - PyPrintAccumulator printAccum; - printAccum.parts.append("NamedAttribute("); - printAccum.parts.append( - nb::str(mlirIdentifierStr(self.namedAttr.name).data, - mlirIdentifierStr(self.namedAttr.name).length)); - printAccum.parts.append("="); - mlirAttributePrint(self.namedAttr.attribute, - printAccum.getCallback(), - printAccum.getUserData()); - printAccum.parts.append(")"); - return printAccum.join(); - }) + .def( + "__repr__", + [](PyNamedAttribute &self) { + PyPrintAccumulator printAccum; + printAccum.parts.append("NamedAttribute("); + printAccum.parts.append( + nb::str(mlirIdentifierStr(self.namedAttr.name).data, + mlirIdentifierStr(self.namedAttr.name).length)); + printAccum.parts.append("="); + mlirAttributePrint(self.namedAttr.attribute, + printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }, + "Returns a string representation of the named attribute.") .def_prop_ro( "name", [](PyNamedAttribute &self) { return mlirIdentifierStr(self.namedAttr.name); }, - "The name of the NamedAttribute binding") + "The name of the `NamedAttribute` binding.") .def_prop_ro( "attr", [](PyNamedAttribute &self) { return self.namedAttr.attribute; }, nb::keep_alive<0, 1>(), nb::sig("def attr(self) -> Attribute"), - "The underlying generic attribute of the NamedAttribute binding"); + "The underlying generic attribute of the `NamedAttribute` binding."); //---------------------------------------------------------------------------- // Mapping of PyType. @@ -4188,9 +4497,11 @@ void mlir::python::populateIRCore(nb::module_ &m) { // Delegate to the PyType copy constructor, which will also lifetime // extend the backing context which owns the MlirType. .def(nb::init<PyType &>(), nb::arg("cast_from_type"), - "Casts the passed type to the generic Type") - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) - .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) + "Casts the passed type to the generic `Type`.") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule, + "Gets a capsule wrapping the `MlirType`.") + .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule, + "Creates a Type from a capsule wrapping `MlirType`.") .def_static( "parse", [](std::string typeSpec, @@ -4203,21 +4514,31 @@ void mlir::python::populateIRCore(nb::module_ &m) { return PyType(context.get()->getRef(), type).maybeDownCast(); }, nb::arg("asm"), nb::arg("context") = nb::none(), - kContextParseTypeDocstring) + R"( + Parses the assembly form of a type. + + Returns a Type object or raises an `MLIRError` if the type cannot be parsed. + + See also: https://mlir.llvm.org/docs/LangRef/#type-system)") .def_prop_ro( "context", [](PyType &self) -> nb::typed<nb::object, PyMlirContext> { return self.getContext().getObject(); }, - "Context that owns the Type") - .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) + "Context that owns the `Type`.") + .def( + "__eq__", [](PyType &self, PyType &other) { return self == other; }, + "Compares two types for equality.") .def( "__eq__", [](PyType &self, nb::object &other) { return false; }, - nb::arg("other").none()) - .def("__hash__", - [](PyType &self) { - return static_cast<size_t>(llvm::hash_value(self.get().ptr)); - }) + nb::arg("other").none(), + "Compares type with non-type object (always returns False).") + .def( + "__hash__", + [](PyType &self) { + return static_cast<size_t>(llvm::hash_value(self.get().ptr)); + }, + "Returns the hash value of the `Type`.") .def( "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) .def( @@ -4228,60 +4549,84 @@ void mlir::python::populateIRCore(nb::module_ &m) { printAccum.getUserData()); return printAccum.join(); }, - "Returns the assembly form of the type.") - .def("__repr__", - [](PyType &self) { - // Generally, assembly formats are not printed for __repr__ because - // this can cause exceptionally long debug output and exceptions. - // However, types are an exception as they typically have compact - // assembly forms and printing them is useful. - PyPrintAccumulator printAccum; - printAccum.parts.append("Type("); - mlirTypePrint(self, printAccum.getCallback(), - printAccum.getUserData()); - printAccum.parts.append(")"); - return printAccum.join(); - }) - .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, - [](PyType &self) -> nb::typed<nb::object, PyType> { - return self.maybeDownCast(); - }) - .def_prop_ro("typeid", [](PyType &self) { - MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); - if (!mlirTypeIDIsNull(mlirTypeID)) - return PyTypeID(mlirTypeID); - auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self))); - throw nb::value_error( - (origRepr + llvm::Twine(" has no typeid.")).str().c_str()); - }); + "Returns the assembly form of the `Type`.") + .def( + "__repr__", + [](PyType &self) { + // Generally, assembly formats are not printed for __repr__ because + // this can cause exceptionally long debug output and exceptions. + // However, types are an exception as they typically have compact + // assembly forms and printing them is useful. + PyPrintAccumulator printAccum; + printAccum.parts.append("Type("); + mlirTypePrint(self, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }, + "Returns a string representation of the `Type`.") + .def( + MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, + [](PyType &self) -> nb::typed<nb::object, PyType> { + return self.maybeDownCast(); + }, + "Downcasts the Type to a more specific `Type` if possible.") + .def_prop_ro( + "typeid", + [](PyType &self) { + MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); + if (!mlirTypeIDIsNull(mlirTypeID)) + return PyTypeID(mlirTypeID); + auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self))); + throw nb::value_error( + (origRepr + llvm::Twine(" has no typeid.")).str().c_str()); + }, + "Returns the `TypeID` of the `Type`, or raises `ValueError` if " + "`Type` has no " + "`TypeID`."); //---------------------------------------------------------------------------- // Mapping of PyTypeID. //---------------------------------------------------------------------------- nb::class_<PyTypeID>(m, "TypeID") - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule) - .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule, + "Gets a capsule wrapping the `MlirTypeID`.") + .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule, + "Creates a `TypeID` from a capsule wrapping `MlirTypeID`.") // Note, this tests whether the underlying TypeIDs are the same, // not whether the wrapper MlirTypeIDs are the same, nor whether // the Python objects are the same (i.e., PyTypeID is a value type). - .def("__eq__", - [](PyTypeID &self, PyTypeID &other) { return self == other; }) - .def("__eq__", - [](PyTypeID &self, const nb::object &other) { return false; }) + .def( + "__eq__", + [](PyTypeID &self, PyTypeID &other) { return self == other; }, + "Compares two `TypeID`s for equality.") + .def( + "__eq__", + [](PyTypeID &self, const nb::object &other) { return false; }, + "Compares TypeID with non-TypeID object (always returns False).") // Note, this gives the hash value of the underlying TypeID, not the // hash value of the Python object, nor the hash value of the // MlirTypeID wrapper. - .def("__hash__", [](PyTypeID &self) { - return static_cast<size_t>(mlirTypeIDHashValue(self)); - }); + .def( + "__hash__", + [](PyTypeID &self) { + return static_cast<size_t>(mlirTypeIDHashValue(self)); + }, + "Returns the hash value of the `TypeID`."); //---------------------------------------------------------------------------- // Mapping of Value. //---------------------------------------------------------------------------- - nb::class_<PyValue>(m, "Value") - .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value")) - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) - .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) + m.attr("_T") = nb::type_var("_T", nb::arg("bound") = m.attr("Type")); + + nb::class_<PyValue>(m, "Value", nb::is_generic(), + nb::sig("class Value(Generic[_T])")) + .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"), + "Creates a Value reference from another `Value`.") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule, + "Gets a capsule wrapping the `MlirValue`.") + .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule, + "Creates a `Value` from a capsule wrapping `MlirValue`.") .def_prop_ro( "context", [](PyValue &self) -> nb::typed<nb::object, PyMlirContext> { @@ -4312,23 +4657,30 @@ void mlir::python::populateIRCore(nb::module_ &m) { assert(false && "Value must be a block argument or an op result"); return nb::none(); }, - // clang-format off - nb::sig("def owner(self) -> Operation | Block | None")) - // clang-format on - .def_prop_ro("uses", - [](PyValue &self) { - return PyOpOperandIterator( - mlirValueGetFirstUse(self.get())); - }) - .def("__eq__", - [](PyValue &self, PyValue &other) { - return self.get().ptr == other.get().ptr; - }) - .def("__eq__", [](PyValue &self, nb::object other) { return false; }) - .def("__hash__", - [](PyValue &self) { - return static_cast<size_t>(llvm::hash_value(self.get().ptr)); - }) + "Returns the owner of the value (`Operation` for results, `Block` " + "for " + "arguments).") + .def_prop_ro( + "uses", + [](PyValue &self) { + return PyOpOperandIterator(mlirValueGetFirstUse(self.get())); + }, + "Returns an iterator over uses of this value.") + .def( + "__eq__", + [](PyValue &self, PyValue &other) { + return self.get().ptr == other.get().ptr; + }, + "Compares two values for pointer equality.") + .def( + "__eq__", [](PyValue &self, nb::object other) { return false; }, + "Compares value with non-value object (always returns False).") + .def( + "__hash__", + [](PyValue &self) { + return static_cast<size_t>(llvm::hash_value(self.get().ptr)); + }, + "Returns the hash value of the value.") .def( "__str__", [](PyValue &self) { @@ -4339,7 +4691,13 @@ void mlir::python::populateIRCore(nb::module_ &m) { printAccum.parts.append(")"); return printAccum.join(); }, - kValueDunderStrDocstring) + R"( + Returns the string form of the value. + + If the value is a block argument, this is the assembly form of its type and the + position in the argument list. If the value is an operation result, this is + equivalent to printing the operation that produced it. + )") .def( "get_name", [](PyValue &self, bool useLocalScope, bool useNameLocAsPrefix) { @@ -4359,7 +4717,16 @@ void mlir::python::populateIRCore(nb::module_ &m) { return printAccum.join(); }, nb::arg("use_local_scope") = false, - nb::arg("use_name_loc_as_prefix") = false) + nb::arg("use_name_loc_as_prefix") = false, + R"( + Returns the string form of value as an operand. + + Args: + use_local_scope: Whether to use local scope for naming. + use_name_loc_as_prefix: Whether to use the location attribute (NameLoc) as prefix. + + Returns: + The value's name as it appears in IR (e.g., `%0`, `%arg0`).)") .def( "get_name", [](PyValue &self, PyAsmState &state) { @@ -4370,25 +4737,30 @@ void mlir::python::populateIRCore(nb::module_ &m) { printAccum.getUserData()); return printAccum.join(); }, - nb::arg("state"), kGetNameAsOperand) - .def_prop_ro("type", - [](PyValue &self) -> nb::typed<nb::object, PyType> { - return PyType(self.getParentOperation()->getContext(), - mlirValueGetType(self.get())) - .maybeDownCast(); - }) + nb::arg("state"), + "Returns the string form of value as an operand (i.e., the ValueID).") + .def_prop_ro( + "type", + [](PyValue &self) -> nb::typed<nb::object, PyType> { + return PyType(self.getParentOperation()->getContext(), + mlirValueGetType(self.get())) + .maybeDownCast(); + }, + "Returns the type of the value.") .def( "set_type", [](PyValue &self, const PyType &type) { - return mlirValueSetType(self.get(), type); + mlirValueSetType(self.get(), type); }, - nb::arg("type")) + nb::arg("type"), "Sets the type of the value.", + nb::sig("def set_type(self, type: _T)")) .def( "replace_all_uses_with", [](PyValue &self, PyValue &with) { mlirValueReplaceAllUsesOfWith(self.get(), with.get()); }, - kValueReplaceAllUsesWithDocstring) + "Replace all uses of value with the new value, updating anything in " + "the IR that uses `self` to use the other value instead.") .def( "replace_all_uses_except", [](PyValue &self, PyValue &with, PyOperation &exception) { @@ -4434,10 +4806,12 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, nb::arg("with_"), nb::arg("exceptions"), kValueReplaceAllUsesExceptDocstring) - .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, - [](PyValue &self) -> nb::typed<nb::object, PyValue> { - return self.maybeDownCast(); - }) + .def( + MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, + [](PyValue &self) -> nb::typed<nb::object, PyValue> { + return self.maybeDownCast(); + }, + "Downcasts the `Value` to a more specific kind if possible.") .def_prop_ro( "location", [](MlirValue self) { @@ -4445,7 +4819,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { PyMlirContext::forContext(mlirValueGetContext(self)), mlirValueGetLocation(self)); }, - "Returns the source location the value"); + "Returns the source location of the value."); PyBlockArgument::bind(m); PyOpResult::bind(m); @@ -4453,43 +4827,105 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::class_<PyAsmState>(m, "AsmState") .def(nb::init<PyValue &, bool>(), nb::arg("value"), - nb::arg("use_local_scope") = false) + nb::arg("use_local_scope") = false, + R"( + Creates an `AsmState` for consistent SSA value naming. + + Args: + value: The value to create state for. + use_local_scope: Whether to use local scope for naming.)") .def(nb::init<PyOperationBase &, bool>(), nb::arg("op"), - nb::arg("use_local_scope") = false); + nb::arg("use_local_scope") = false, + R"( + Creates an AsmState for consistent SSA value naming. + + Args: + op: The operation to create state for. + use_local_scope: Whether to use local scope for naming.)"); //---------------------------------------------------------------------------- // Mapping of SymbolTable. //---------------------------------------------------------------------------- nb::class_<PySymbolTable>(m, "SymbolTable") - .def(nb::init<PyOperationBase &>()) - .def("__getitem__", - [](PySymbolTable &self, - const std::string &name) -> nb::typed<nb::object, PyOpView> { - return self.dunderGetItem(name); - }) - .def("insert", &PySymbolTable::insert, nb::arg("operation")) - .def("erase", &PySymbolTable::erase, nb::arg("operation")) - .def("__delitem__", &PySymbolTable::dunderDel) - .def("__contains__", - [](PySymbolTable &table, const std::string &name) { - return !mlirOperationIsNull(mlirSymbolTableLookup( - table, mlirStringRefCreate(name.data(), name.length()))); - }) + .def(nb::init<PyOperationBase &>(), + R"( + Creates a symbol table for an operation. + + Args: + operation: The `Operation` that defines a symbol table (e.g., a `ModuleOp`). + + Raises: + TypeError: If the operation is not a symbol table.)") + .def( + "__getitem__", + [](PySymbolTable &self, + const std::string &name) -> nb::typed<nb::object, PyOpView> { + return self.dunderGetItem(name); + }, + R"( + Looks up a symbol by name in the symbol table. + + Args: + name: The name of the symbol to look up. + + Returns: + The operation defining the symbol. + + Raises: + KeyError: If the symbol is not found.)") + .def("insert", &PySymbolTable::insert, nb::arg("operation"), + R"( + Inserts a symbol operation into the symbol table. + + Args: + operation: An operation with a symbol name to insert. + + Returns: + The symbol name attribute of the inserted operation. + + Raises: + ValueError: If the operation does not have a symbol name.)") + .def("erase", &PySymbolTable::erase, nb::arg("operation"), + R"( + Erases a symbol operation from the symbol table. + + Args: + operation: The symbol operation to erase. + + Note: + The operation is also erased from the IR and invalidated.)") + .def("__delitem__", &PySymbolTable::dunderDel, + "Deletes a symbol by name from the symbol table.") + .def( + "__contains__", + [](PySymbolTable &table, const std::string &name) { + return !mlirOperationIsNull(mlirSymbolTableLookup( + table, mlirStringRefCreate(name.data(), name.length()))); + }, + "Checks if a symbol with the given name exists in the table.") // Static helpers. .def_static("set_symbol_name", &PySymbolTable::setSymbolName, - nb::arg("symbol"), nb::arg("name")) + nb::arg("symbol"), nb::arg("name"), + "Sets the symbol name for a symbol operation.") .def_static("get_symbol_name", &PySymbolTable::getSymbolName, - nb::arg("symbol")) + nb::arg("symbol"), + "Gets the symbol name from a symbol operation.") .def_static("get_visibility", &PySymbolTable::getVisibility, - nb::arg("symbol")) + nb::arg("symbol"), + "Gets the visibility attribute of a symbol operation.") .def_static("set_visibility", &PySymbolTable::setVisibility, - nb::arg("symbol"), nb::arg("visibility")) + nb::arg("symbol"), nb::arg("visibility"), + "Sets the visibility attribute of a symbol operation.") .def_static("replace_all_symbol_uses", &PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"), - nb::arg("new_symbol"), nb::arg("from_op")) + nb::arg("new_symbol"), nb::arg("from_op"), + "Replaces all uses of a symbol with a new symbol name within " + "the given operation.") .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables, nb::arg("from_op"), nb::arg("all_sym_uses_visible"), - nb::arg("callback")); + nb::arg("callback"), + "Walks symbol tables starting from an operation with a " + "callback function."); // Container bindings. PyBlockArgumentList::bind(m); diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index a14f09f..ba767ad 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -24,6 +24,8 @@ using namespace mlir::python; NB_MODULE(_mlir, m) { m.doc() = "MLIR Python Native Extension"; + m.attr("T") = nb::type_var("T"); + m.attr("U") = nb::type_var("U"); nb::class_<PyGlobals>(m, "_Globals") .def_prop_rw("dialect_search_modules", @@ -102,6 +104,10 @@ NB_MODULE(_mlir, m) { return opClass; }); }, + // clang-format off + nb::sig("def register_operation(dialect_class: type, *, replace: bool = False) " + "-> typing.Callable[[type[T]], type[T]]"), + // clang-format on "dialect_class"_a, nb::kw_only(), "replace"_a = false, "Produce a class decorator for registering an Operation class as part of " "a dialect"); @@ -114,6 +120,10 @@ NB_MODULE(_mlir, m) { return typeCaster; }); }, + // clang-format off + nb::sig("def register_type_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) " + "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"), + // clang-format on "typeid"_a, nb::kw_only(), "replace"_a = false, "Register a type caster for casting MLIR types to custom user types."); m.def( @@ -126,6 +136,10 @@ NB_MODULE(_mlir, m) { return valueCaster; }); }, + // clang-format off + nb::sig("def register_value_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) " + "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"), + // clang-format on "typeid"_a, nb::kw_only(), "replace"_a = false, "Register a value caster for casting MLIR values to custom user values."); diff --git a/mlir/lib/Bindings/Python/NanobindUtils.h b/mlir/lib/Bindings/Python/NanobindUtils.h index 64ea4329..aea195f 100644 --- a/mlir/lib/Bindings/Python/NanobindUtils.h +++ b/mlir/lib/Bindings/Python/NanobindUtils.h @@ -19,6 +19,7 @@ #include "llvm/Support/raw_ostream.h" #include <string> +#include <typeinfo> #include <variant> template <> @@ -344,7 +345,16 @@ public: /// Binds the indexing and length methods in the Python class. static void bind(nanobind::module_ &m) { - auto clazz = nanobind::class_<Derived>(m, Derived::pyClassName) + const std::type_info &elemTy = typeid(ElementTy); + PyObject *elemTyInfo = nanobind::detail::nb_type_lookup(&elemTy); + assert(elemTyInfo && + "expected nb_type_lookup to succeed for Sliceable elemTy"); + nanobind::handle elemTyName = nanobind::detail::nb_type_name(elemTyInfo); + std::string sig = std::string("class ") + Derived::pyClassName + + "(collections.abc.Sequence[" + + nanobind::cast<std::string>(elemTyName) + "])"; + auto clazz = nanobind::class_<Derived>(m, Derived::pyClassName, + nanobind::sig(sig.c_str())) .def("__add__", &Sliceable::dunderAdd); Derived::bindDerived(clazz); @@ -395,7 +405,6 @@ public: /// Hook for derived classes willing to bind more methods. static void bindDerived(ClassTy &) {} -private: intptr_t startIndex; intptr_t length; intptr_t step; diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp index 1659437..0ac5fc5 100644 --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -27,6 +27,7 @@ #include <cstddef> #include <cstdint> +#include <deque> #include <list> #include <memory> #include <numeric> @@ -830,6 +831,23 @@ namespace { /// This class provides support for reading attribute and type entries from the /// bytecode. Attribute and Type entries are read lazily on demand, so we use /// this reader to manage when to actually parse them from the bytecode. +/// +/// The parsing of attributes & types are generally recursive, this can lead to +/// stack overflows for deeply nested structures, so we track a few extra pieces +/// of information to avoid this: +/// +/// - `depth`: The current depth while parsing nested attributes. We defer on +/// parsing deeply nested attributes to avoid potential stack overflows. The +/// deferred parsing is achieved by reporting a failure when parsing a nested +/// attribute/type and registering the index of the encountered attribute/type +/// in the deferred parsing worklist. Hence, a failure with deffered entry +/// does not constitute a failure, it also requires that folks return on +/// first failure rather than attempting additional parses. +/// - `deferredWorklist`: A list of attribute/type indices that we could not +/// parse due to hitting the depth limit. The worklist is used to capture the +/// indices of attributes/types that need to be parsed/reparsed when we hit +/// the depth limit. This enables moving the tracking of what needs to be +/// parsed to the heap. class AttrTypeReader { /// This class represents a single attribute or type entry. template <typename T> @@ -863,12 +881,34 @@ public: ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData); + LogicalResult readAttribute(uint64_t index, Attribute &result, + uint64_t depth = 0) { + return readEntry(attributes, index, result, "attribute", depth); + } + + LogicalResult readType(uint64_t index, Type &result, uint64_t depth = 0) { + return readEntry(types, index, result, "type", depth); + } + /// Resolve the attribute or type at the given index. Returns nullptr on /// failure. - Attribute resolveAttribute(size_t index) { - return resolveEntry(attributes, index, "Attribute"); + Attribute resolveAttribute(size_t index, uint64_t depth = 0) { + return resolveEntry(attributes, index, "Attribute", depth); + } + Type resolveType(size_t index, uint64_t depth = 0) { + return resolveEntry(types, index, "Type", depth); + } + + Attribute getAttributeOrSentinel(size_t index) { + if (index >= attributes.size()) + return nullptr; + return attributes[index].entry; + } + Type getTypeOrSentinel(size_t index) { + if (index >= types.size()) + return nullptr; + return types[index].entry; } - Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); } /// Parse a reference to an attribute or type using the given reader. LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) { @@ -909,23 +949,33 @@ public: llvm::getTypeName<T>(), ", but got: ", baseResult); } + /// Add an index to the deferred worklist for re-parsing. + void addDeferredParsing(uint64_t index) { deferredWorklist.push_back(index); } + private: /// Resolve the given entry at `index`. template <typename T> - T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index, - StringRef entryType); + T resolveEntry(SmallVectorImpl<Entry<T>> &entries, uint64_t index, + StringRef entryType, uint64_t depth = 0); - /// Parse an entry using the given reader that was encoded using the textual - /// assembly format. + /// Read the entry at the given index, returning failure if the entry is not + /// yet resolved. template <typename T> - LogicalResult parseAsmEntry(T &result, EncodingReader &reader, - StringRef entryType); + LogicalResult readEntry(SmallVectorImpl<Entry<T>> &entries, uint64_t index, + T &result, StringRef entryType, uint64_t depth); /// Parse an entry using the given reader that was encoded using a custom /// bytecode format. template <typename T> LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader, - StringRef entryType); + StringRef entryType, uint64_t index, + uint64_t depth); + + /// Parse an entry using the given reader that was encoded using the textual + /// assembly format. + template <typename T> + LogicalResult parseAsmEntry(T &result, EncodingReader &reader, + StringRef entryType); /// The string section reader used to resolve string references when parsing /// custom encoded attribute/type entries. @@ -951,6 +1001,10 @@ private: /// Reference to the parser configuration. const ParserConfig &parserConfig; + + /// Worklist for deferred attribute/type parsing. This is used to handle + /// deeply nested structures like CallSiteLoc iteratively. + std::vector<uint64_t> deferredWorklist; }; class DialectReader : public DialectBytecodeReader { @@ -959,10 +1013,11 @@ public: const StringSectionReader &stringReader, const ResourceSectionReader &resourceReader, const llvm::StringMap<BytecodeDialect *> &dialectsMap, - EncodingReader &reader, uint64_t &bytecodeVersion) + EncodingReader &reader, uint64_t &bytecodeVersion, + uint64_t depth = 0) : attrTypeReader(attrTypeReader), stringReader(stringReader), resourceReader(resourceReader), dialectsMap(dialectsMap), - reader(reader), bytecodeVersion(bytecodeVersion) {} + reader(reader), bytecodeVersion(bytecodeVersion), depth(depth) {} InFlightDiagnostic emitError(const Twine &msg) const override { return reader.emitError(msg); @@ -998,14 +1053,40 @@ public: // IR //===--------------------------------------------------------------------===// + /// The maximum depth to eagerly parse nested attributes/types before + /// deferring. + static constexpr uint64_t maxAttrTypeDepth = 5; + LogicalResult readAttribute(Attribute &result) override { - return attrTypeReader.parseAttribute(reader, result); + uint64_t index; + if (failed(reader.parseVarInt(index))) + return failure(); + if (depth > maxAttrTypeDepth) { + if (Attribute attr = attrTypeReader.getAttributeOrSentinel(index)) { + result = attr; + return success(); + } + attrTypeReader.addDeferredParsing(index); + return failure(); + } + return attrTypeReader.readAttribute(index, result, depth + 1); } LogicalResult readOptionalAttribute(Attribute &result) override { return attrTypeReader.parseOptionalAttribute(reader, result); } LogicalResult readType(Type &result) override { - return attrTypeReader.parseType(reader, result); + uint64_t index; + if (failed(reader.parseVarInt(index))) + return failure(); + if (depth > maxAttrTypeDepth) { + if (Type type = attrTypeReader.getTypeOrSentinel(index)) { + result = type; + return success(); + } + attrTypeReader.addDeferredParsing(index); + return failure(); + } + return attrTypeReader.readType(index, result, depth + 1); } FailureOr<AsmDialectResourceHandle> readResourceHandle() override { @@ -1095,6 +1176,7 @@ private: const llvm::StringMap<BytecodeDialect *> &dialectsMap; EncodingReader &reader; uint64_t &bytecodeVersion; + uint64_t depth; }; /// Wraps the properties section and handles reading properties out of it. @@ -1238,69 +1320,112 @@ LogicalResult AttrTypeReader::initialize( } template <typename T> -T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index, - StringRef entryType) { +T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, + uint64_t index, StringRef entryType, + uint64_t depth) { if (index >= entries.size()) { emitError(fileLoc) << "invalid " << entryType << " index: " << index; return {}; } - // If the entry has already been resolved, there is nothing left to do. - Entry<T> &entry = entries[index]; - if (entry.entry) - return entry.entry; + // Fast path: Try direct parsing without worklist overhead. This handles the + // common case where there are no deferred dependencies. + assert(deferredWorklist.empty()); + T result; + if (succeeded(readEntry(entries, index, result, entryType, depth))) { + assert(deferredWorklist.empty()); + return result; + } + if (deferredWorklist.empty()) { + // Failed with no deferred entries is error. + return T(); + } - // Parse the entry. - EncodingReader reader(entry.data, fileLoc); + // Slow path: Use worklist to handle deferred dependencies. Use a deque to + // iteratively resolve entries with dependencies. + // - Pop from front to process + // - Push new dependencies to front (depth-first) + // - Move failed entries to back (retry after dependencies) + std::deque<size_t> worklist; + llvm::DenseSet<size_t> inWorklist; - // Parse based on how the entry was encoded. - if (entry.hasCustomEncoding) { - if (failed(parseCustomEntry(entry, reader, entryType))) - return T(); - } else if (failed(parseAsmEntry(entry.entry, reader, entryType))) { - return T(); + // Add the original index and any dependencies from the fast path attempt. + worklist.push_back(index); + inWorklist.insert(index); + for (uint64_t idx : llvm::reverse(deferredWorklist)) { + if (inWorklist.insert(idx).second) + worklist.push_front(idx); } - if (!reader.empty()) { - reader.emitError("unexpected trailing bytes after " + entryType + " entry"); - return T(); + while (!worklist.empty()) { + size_t currentIndex = worklist.front(); + worklist.pop_front(); + + // Clear the deferred worklist before parsing to capture any new entries. + deferredWorklist.clear(); + + T result; + if (succeeded(readEntry(entries, currentIndex, result, entryType, depth))) { + inWorklist.erase(currentIndex); + continue; + } + + if (deferredWorklist.empty()) { + // Parsing failed with no deferred entries which implies an error. + return T(); + } + + // Move this entry to the back to retry after dependencies. + worklist.push_back(currentIndex); + + // Add dependencies to the front (in reverse so they maintain order). + for (uint64_t idx : llvm::reverse(deferredWorklist)) { + if (inWorklist.insert(idx).second) + worklist.push_front(idx); + } + deferredWorklist.clear(); } - return entry.entry; + return entries[index].entry; } template <typename T> -LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader, - StringRef entryType) { - StringRef asmStr; - if (failed(reader.parseNullTerminatedString(asmStr))) - return failure(); +LogicalResult AttrTypeReader::readEntry(SmallVectorImpl<Entry<T>> &entries, + uint64_t index, T &result, + StringRef entryType, uint64_t depth) { + if (index >= entries.size()) + return emitError(fileLoc) << "invalid " << entryType << " index: " << index; - // Invoke the MLIR assembly parser to parse the entry text. - size_t numRead = 0; - MLIRContext *context = fileLoc->getContext(); - if constexpr (std::is_same_v<T, Type>) - result = - ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true); - else - result = ::parseAttribute(asmStr, context, Type(), &numRead, - /*isKnownNullTerminated=*/true); - if (!result) + // If the entry has already been resolved, return it. + Entry<T> &entry = entries[index]; + if (entry.entry) { + result = entry.entry; + return success(); + } + + // If the entry hasn't been resolved, try to parse it. + EncodingReader reader(entry.data, fileLoc); + LogicalResult parseResult = + entry.hasCustomEncoding + ? parseCustomEntry(entry, reader, entryType, index, depth) + : parseAsmEntry(entry.entry, reader, entryType); + if (failed(parseResult)) return failure(); - // Ensure there weren't dangling characters after the entry. - if (numRead != asmStr.size()) { - return reader.emitError("trailing characters found after ", entryType, - " assembly format: ", asmStr.drop_front(numRead)); - } + if (!reader.empty()) + return reader.emitError("unexpected trailing bytes after " + entryType + + " entry"); + + result = entry.entry; return success(); } template <typename T> LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry, EncodingReader &reader, - StringRef entryType) { + StringRef entryType, + uint64_t index, uint64_t depth) { DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap, - reader, bytecodeVersion); + reader, bytecodeVersion, depth); if (failed(entry.dialect->load(dialectReader, fileLoc.getContext()))) return failure(); @@ -1350,6 +1475,33 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry, return success(!!entry.entry); } +template <typename T> +LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader, + StringRef entryType) { + StringRef asmStr; + if (failed(reader.parseNullTerminatedString(asmStr))) + return failure(); + + // Invoke the MLIR assembly parser to parse the entry text. + size_t numRead = 0; + MLIRContext *context = fileLoc->getContext(); + if constexpr (std::is_same_v<T, Type>) + result = + ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true); + else + result = ::parseAttribute(asmStr, context, Type(), &numRead, + /*isKnownNullTerminated=*/true); + if (!result) + return failure(); + + // Ensure there weren't dangling characters after the entry. + if (numRead != asmStr.size()) { + return reader.emitError("trailing characters found after ", entryType, + " assembly format: ", asmStr.drop_front(numRead)); + } + return success(); +} + //===----------------------------------------------------------------------===// // Bytecode Reader //===----------------------------------------------------------------------===// diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index eaad8a8..bf23176 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -27,6 +27,10 @@ MlirType mlirLLVMPointerTypeGet(MlirContext ctx, unsigned addressSpace) { return wrap(LLVMPointerType::get(unwrap(ctx), addressSpace)); } +MlirTypeID mlirLLVMPointerTypeGetTypeID() { + return wrap(LLVM::LLVMPointerType::getTypeID()); +} + bool mlirTypeIsALLVMPointerType(MlirType type) { return isa<LLVM::LLVMPointerType>(unwrap(type)); } @@ -73,6 +77,10 @@ bool mlirTypeIsALLVMStructType(MlirType type) { return isa<LLVM::LLVMStructType>(unwrap(type)); } +MlirTypeID mlirLLVMStructTypeGetTypeID() { + return wrap(LLVM::LLVMStructType::getTypeID()); +} + bool mlirLLVMStructTypeIsLiteral(MlirType type) { return !cast<LLVM::LLVMStructType>(unwrap(type)).isIdentified(); } @@ -159,9 +167,8 @@ MlirAttribute mlirLLVMDIExpressionAttrGet(MlirContext ctx, intptr_t nOperations, return wrap(DIExpressionAttr::get( unwrap(ctx), - llvm::map_to_vector( - unwrapList(nOperations, operations, attrStorage), - [](Attribute a) { return cast<DIExpressionElemAttr>(a); }))); + llvm::map_to_vector(unwrapList(nOperations, operations, attrStorage), + llvm::CastTo<DIExpressionElemAttr>))); } MlirAttribute mlirLLVMDINullTypeAttrGet(MlirContext ctx) { @@ -202,7 +209,7 @@ MlirAttribute mlirLLVMDICompositeTypeAttrGet( cast<DIExpressionAttr>(unwrap(allocated)), cast<DIExpressionAttr>(unwrap(associated)), llvm::map_to_vector(unwrapList(nElements, elements, elementsStorage), - [](Attribute a) { return cast<DINodeAttr>(a); }))); + llvm::CastTo<DINodeAttr>))); } MlirAttribute mlirLLVMDIDerivedTypeAttrGet( @@ -308,7 +315,7 @@ MlirAttribute mlirLLVMDISubroutineTypeAttrGet(MlirContext ctx, return wrap(DISubroutineTypeAttr::get( unwrap(ctx), callingConvention, llvm::map_to_vector(unwrapList(nTypes, types, attrStorage), - [](Attribute a) { return cast<DITypeAttr>(a); }))); + llvm::CastTo<DITypeAttr>))); } MlirAttribute mlirLLVMDISubprogramAttrGetRecSelf(MlirAttribute recId) { @@ -338,10 +345,10 @@ MlirAttribute mlirLLVMDISubprogramAttrGet( cast<DISubroutineTypeAttr>(unwrap(type)), llvm::map_to_vector( unwrapList(nRetainedNodes, retainedNodes, nodesStorage), - [](Attribute a) { return cast<DINodeAttr>(a); }), + llvm::CastTo<DINodeAttr>), llvm::map_to_vector( unwrapList(nAnnotations, annotations, annotationsStorage), - [](Attribute a) { return cast<DINodeAttr>(a); }))); + llvm::CastTo<DINodeAttr>))); } MlirAttribute mlirLLVMDISubprogramAttrGetScope(MlirAttribute diSubprogram) { @@ -398,7 +405,7 @@ MlirAttribute mlirLLVMDIImportedEntityAttrGet( cast<DINodeAttr>(unwrap(entity)), cast<DIFileAttr>(unwrap(file)), line, cast<StringAttr>(unwrap(name)), llvm::map_to_vector(unwrapList(nElements, elements, elementsStorage), - [](Attribute a) { return cast<DINodeAttr>(a); }))); + llvm::CastTo<DINodeAttr>))); } MlirAttribute mlirLLVMDIAnnotationAttrGet(MlirContext ctx, MlirAttribute name, diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp index 5c2a65d..75c811a 100644 --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir-c/Dialect/Linalg.h" +#include "mlir/CAPI/AffineMap.h" #include "mlir/CAPI/Registration.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -62,9 +63,8 @@ mlirLinalgInferContractionDimensions(MlirOperation op) { const linalg::ContractionDimensions &contractionDims = *maybeDims; MLIRContext *ctx = linalgOp.getContext(); - auto toAttr = [&ctx](const SmallVector<unsigned, 2> &vals) -> MlirAttribute { - return wrap( - DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t, 2>(vals))); + auto toAttr = [ctx](ArrayRef<unsigned> vals) -> MlirAttribute { + return wrap(DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t>(vals))); }; result.batch = toAttr(contractionDims.batch); @@ -75,6 +75,38 @@ mlirLinalgInferContractionDimensions(MlirOperation op) { return result; } +MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions +mlirLinalgInferContractionDimensionsFromMaps(const MlirAffineMap *indexingMaps, + size_t numMaps) { + MlirLinalgContractionDimensions result{}; + if (!indexingMaps || numMaps == 0) + return result; + + SmallVector<AffineMap, 3> maps; + maps.reserve(numMaps); + for (size_t i = 0; i < numMaps; ++i) { + maps.push_back(unwrap(indexingMaps[i])); + } + + FailureOr<linalg::ContractionDimensions> maybeDims = + linalg::inferContractionDims(maps); + if (failed(maybeDims)) + return result; + + MLIRContext *ctx = maps[0].getContext(); + + auto toAttr = [ctx](ArrayRef<unsigned> vals) -> MlirAttribute { + return wrap(DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t>(vals))); + }; + + result.batch = toAttr(maybeDims->batch); + result.m = toAttr(maybeDims->m); + result.n = toAttr(maybeDims->n); + result.k = toAttr(maybeDims->k); + + return result; +} + MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op) { auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op)); if (!linalgOp) diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp index 2dbb993..81d86ad 100644 --- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp @@ -22,7 +22,7 @@ using namespace mlir; extern "C" MlirExecutionEngine mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths, const MlirStringRef *sharedLibPaths, - bool enableObjectDump) { + bool enableObjectDump, bool enablePIC) { static bool initOnce = [] { llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmParser(); // needed for inline_asm @@ -38,12 +38,17 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths, auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); if (!tmBuilderOrError) { - llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n"; + llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host " + "because: \n"; + consumeError(tmBuilderOrError.takeError()); return MlirExecutionEngine{nullptr}; } + if (enablePIC) + tmBuilderOrError->setRelocationModel(llvm::Reloc::PIC_); auto tmOrError = tmBuilderOrError->createTargetMachine(); if (!tmOrError) { - llvm::errs() << "Failed to create a TargetMachine for the host\n"; + llvm::errs() << "Failed to create a TargetMachine for the host because: \n"; + consumeError(tmOrError.takeError()); return MlirExecutionEngine{nullptr}; } @@ -60,8 +65,10 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths, jitOptions.jitCodeGenOptLevel = static_cast<llvm::CodeGenOptLevel>(optLevel); jitOptions.sharedLibPaths = libPaths; jitOptions.enableObjectDump = enableObjectDump; - auto jitOrError = ExecutionEngine::create(unwrap(op), jitOptions); + auto jitOrError = ExecutionEngine::create(unwrap(op), jitOptions, + std::move(tmOrError.get())); if (!jitOrError) { + llvm::errs() << "Failed to create an ExecutionEngine because: \n"; consumeError(jitOrError.takeError()); return MlirExecutionEngine{nullptr}; } diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index f5f4ed3..e2e236a 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -536,7 +536,7 @@ MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(MlirType type, if (failed(memrefType.getStridesAndOffset(strides_, *offset))) return mlirLogicalResultFailure(); - (void)std::copy(strides_.begin(), strides_.end(), strides); + (void)llvm::copy(strides_, strides); return mlirLogicalResultSuccess(); } diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 1881865..ffcbed8 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -1129,6 +1129,11 @@ void mlirBlockArgumentSetType(MlirValue value, MlirType type) { blockArg.setType(unwrap(type)); } +void mlirBlockArgumentSetLocation(MlirValue value, MlirLocation loc) { + if (auto blockArg = llvm::dyn_cast<BlockArgument>(unwrap(value))) + blockArg.setLoc(unwrap(loc)); +} + MlirOperation mlirOpResultGetOwner(MlirValue value) { return wrap(llvm::dyn_cast<OpResult>(unwrap(value)).getOwner()); } diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 3a307a0..7584b17 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -16,8 +16,10 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" @@ -42,6 +44,7 @@ constexpr Chipset kGfx908 = Chipset(9, 0, 8); constexpr Chipset kGfx90a = Chipset(9, 0, 0xa); constexpr Chipset kGfx942 = Chipset(9, 4, 2); constexpr Chipset kGfx950 = Chipset(9, 5, 0); +constexpr Chipset kGfx1250 = Chipset(12, 5, 0); /// Convert an unsigned number `val` to i32. static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, @@ -79,12 +82,6 @@ static Value createI64Constant(ConversionPatternRewriter &rewriter, return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value); } -static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc, - bool value) { - Type llvmI1 = rewriter.getI1Type(); - return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value); -} - /// Returns the linear index used to access an element in the memref. static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, Location loc, MemRefDescriptor &memRefDescriptor, @@ -509,10 +506,16 @@ struct MemoryCounterWaitOpLowering if (std::optional<int> exp = adaptor.getExp()) ROCDL::WaitExpcntOp::create(rewriter, loc, *exp); + if (std::optional<int> tensor = adaptor.getTensor()) + ROCDL::WaitTensorcntOp::create(rewriter, loc, *tensor); + rewriter.eraseOp(op); return success(); } + if (adaptor.getTensor()) + return op.emitOpError("unsupported chipset"); + auto getVal = [](Attribute attr) -> unsigned { if (attr) return cast<IntegerAttr>(attr).getInt(); @@ -684,12 +687,11 @@ static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter, /// intrinsics having been defined before the AMD backend supported bfloat. We /// similarly need to pack 8-bit float types into integers as if they were i8 /// (which they are for the backend's purposes). -static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, - Location loc, - const TypeConverter *typeConverter, - bool isUnsigned, Value llvmInput, - Value mlirInput, - SmallVector<Value, 4> &operands) { +static void wmmaPushInputOperand( + ConversionPatternRewriter &rewriter, Location loc, + const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, + Value mlirInput, SmallVectorImpl<Value> &operands, + SmallVectorImpl<NamedAttribute> &attrs, StringRef attrName) { Type inputType = llvmInput.getType(); auto vectorType = dyn_cast<VectorType>(inputType); if (!vectorType) { @@ -697,10 +699,6 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, return; } Type elemType = vectorType.getElementType(); - - if (elemType.isBF16()) - llvmInput = LLVM::BitcastOp::create( - rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput); if (elemType.getIntOrFloatBitWidth() > 8) { operands.push_back(llvmInput); return; @@ -719,8 +717,8 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, } else if (elemType.isSignedInteger()) { localIsUnsigned = false; } - Value sign = createI1Constant(rewriter, loc, !localIsUnsigned); - operands.push_back(sign); + attrs.push_back( + NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned))); } int64_t numBits = @@ -751,18 +749,17 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, - bool clamp, SmallVector<Value, 4> &operands) { + bool clamp, SmallVectorImpl<Value> &operands, + SmallVectorImpl<NamedAttribute> &attrs) { Type inputType = output.getType(); auto vectorType = dyn_cast<VectorType>(inputType); Type elemType = vectorType.getElementType(); - if (elemType.isBF16()) - output = LLVM::BitcastOp::create( - rewriter, loc, vectorType.clone(rewriter.getI16Type()), output); operands.push_back(output); if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) { - operands.push_back(createI1Constant(rewriter, loc, subwordOffset)); + attrs.push_back( + NamedAttribute("opsel", rewriter.getBoolAttr(subwordOffset))); } else if (elemType.isInteger(32)) { - operands.push_back(createI1Constant(rewriter, loc, clamp)); + attrs.push_back(NamedAttribute("clamp", rewriter.getBoolAttr(clamp))); } } @@ -1160,7 +1157,7 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma, k, isRDNA3); // Handle gfx1250. - if (chipset == Chipset{12, 5, 0}) + if (chipset == kGfx1250) return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType, elemDestType, k); @@ -1311,11 +1308,33 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> { if (chipset.majorVersion != 11 && chipset.majorVersion != 12) return op->emitOpError("WMMA only supported on gfx11 and gfx12"); - // The WMMA operations represent vectors of bf16s as vectors of i16s, so we - // need to bitcast bfloats to i16 and then bitcast them back. + bool isGFX1250 = chipset >= kGfx1250; + + // The WMMA operations represent vectors of bf16s as vectors of i16s + // (except on gfx1250), so we need to bitcast bfloats to i16 and then + // bitcast them back. + auto aType = cast<VectorType>(adaptor.getSourceA().getType()); + auto bType = cast<VectorType>(adaptor.getSourceB().getType()); + auto destCType = cast<VectorType>(adaptor.getDestC().getType()); + bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250; + bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250; + bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250; + bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250; VectorType rawOutType = outType; - if (outType.getElementType().isBF16()) + if (castOutToI16) rawOutType = outType.clone(rewriter.getI16Type()); + Value a = adaptor.getSourceA(); + if (castAToI16) + a = LLVM::BitcastOp::create(rewriter, loc, + aType.clone(rewriter.getI16Type()), a); + Value b = adaptor.getSourceB(); + if (castBToI16) + b = LLVM::BitcastOp::create(rewriter, loc, + bType.clone(rewriter.getI16Type()), b); + Value destC = adaptor.getDestC(); + if (castDestCToI16) + destC = LLVM::BitcastOp::create( + rewriter, loc, destCType.clone(rewriter.getI16Type()), destC); std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset); @@ -1325,18 +1344,20 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> { if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0) return op.emitOpError("subwordOffset not supported on gfx12+"); - OperationState loweredOp(loc, *maybeIntrinsic); - loweredOp.addTypes(rawOutType); - SmallVector<Value, 4> operands; - wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), - adaptor.getSourceA(), op.getSourceA(), operands); - wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), - adaptor.getSourceB(), op.getSourceB(), operands); - wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(), - op.getSubwordOffset(), op.getClamp(), operands); + SmallVector<NamedAttribute, 4> attrs; + wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), a, + op.getSourceA(), operands, attrs, "signA"); + wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), b, + op.getSourceB(), operands, attrs, "signB"); + wmmaPushOutputOperand(rewriter, loc, typeConverter, destC, + op.getSubwordOffset(), op.getClamp(), operands, + attrs); + OperationState loweredOp(loc, *maybeIntrinsic); + loweredOp.addTypes(rawOutType); loweredOp.addOperands(operands); + loweredOp.addAttributes(attrs); Operation *lowered = rewriter.create(loweredOp); Operation *maybeCastBack = lowered; @@ -1492,6 +1513,20 @@ struct ExtPackedFp8OpLowering final ConversionPatternRewriter &rewriter) const override; }; +struct ScaledExtPackedMatrixOpLowering final + : public ConvertOpToLLVMPattern<ScaledExtPackedMatrixOp> { + ScaledExtPackedMatrixOpLowering(const LLVMTypeConverter &converter, + Chipset chipset) + : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp>(converter), + chipset(chipset) {} + Chipset chipset; + + LogicalResult + matchAndRewrite(ScaledExtPackedMatrixOp op, + ScaledExtPackedMatrixOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + struct PackedTrunc2xFp8OpLowering final : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> { PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter, @@ -1600,6 +1635,173 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( return success(); } +int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, int32_t scaleWaveHalf, + int32_t firstScaleByte) { + // When lowering amdgpu.scaled_ext_packed_matrix to rocdl.cvt.scale.pk*.f*.f* + // operations, the attributes blockSize, sourceType, scaleWaveHalf, and + // firstScaleByte are merged into a single attribute scaleSel. This is how + // those values are merged together. (Note: scaleWaveHalf isn't a high-level + // attribute but is derifed from firstScaleLane). + assert(llvm::is_contained({16, 32}, blockSize)); + assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth)); + + const bool isFp8 = bitWidth == 8; + const bool isBlock16 = blockSize == 16; + + if (!isFp8) { + int32_t bit0 = isBlock16; + assert(llvm::is_contained({0, 1, 2}, firstScaleByte)); + int32_t bit1 = (firstScaleByte == 2) << 1; + assert(llvm::is_contained({0, 1}, scaleWaveHalf)); + int32_t bit2 = scaleWaveHalf << 2; + return bit2 | bit1 | bit0; + } + + int32_t bit0 = isBlock16; + // firstScaleByte is guaranteed to be defined by two bits. + assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte)); + int32_t bits2and1 = firstScaleByte << 1; + assert(llvm::is_contained({0, 1}, scaleWaveHalf)); + int32_t bit3 = scaleWaveHalf << 3; + int32_t bits = bit3 | bits2and1 | bit0; + // These are invalid cases. + assert(!llvm::is_contained( + {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits)); + return bits; +} + +static std::optional<StringRef> +scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) { + using fp4 = Float4E2M1FNType; + using fp8 = Float8E4M3FNType; + using bf8 = Float8E5M2Type; + using fp6 = Float6E2M3FNType; + using bf6 = Float6E3M2FNType; + if (isa<fp4>(srcElemType)) { + if (destElemType.isF16()) + return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName(); + if (destElemType.isBF16()) + return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName(); + if (destElemType.isF32()) + return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName(); + return std::nullopt; + } + if (isa<fp8>(srcElemType)) { + if (destElemType.isF16()) + return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName(); + if (destElemType.isBF16()) + return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName(); + if (destElemType.isF32()) + return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName(); + return std::nullopt; + } + if (isa<bf8>(srcElemType)) { + if (destElemType.isF16()) + return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName(); + if (destElemType.isBF16()) + return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName(); + if (destElemType.isF32()) + return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName(); + return std::nullopt; + } + if (isa<fp6>(srcElemType)) { + if (destElemType.isF16()) + return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName(); + if (destElemType.isBF16()) + return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName(); + if (destElemType.isF32()) + return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName(); + return std::nullopt; + } + if (isa<bf6>(srcElemType)) { + if (destElemType.isF16()) + return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName(); + if (destElemType.isBF16()) + return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName(); + if (destElemType.isF32()) + return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName(); + return std::nullopt; + } + llvm_unreachable("invalid combination of element types for packed conversion " + "instructions"); +} + +LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite( + ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + using fp4 = Float4E2M1FNType; + using fp8 = Float8E4M3FNType; + using bf8 = Float8E5M2Type; + using fp6 = Float6E2M3FNType; + using bf6 = Float6E3M2FNType; + Location loc = op.getLoc(); + if (chipset != kGfx1250) { + return rewriter.notifyMatchFailure( + loc, + "Scaled fp packed conversion instructions are not available on target " + "architecture and their emulation is not implemented"); + } + // Convert user-facing firstScaleLane (0 or 16) to the half of the wave that + // is being selected. + int32_t scaleWaveHalf = op.getFirstScaleLane() / 16; + int32_t firstScaleByte = op.getFirstScaleByte(); + int32_t blockSize = op.getBlockSize(); + auto sourceType = cast<VectorType>(op.getSource().getType()); + auto srcElemType = cast<FloatType>(sourceType.getElementType()); + unsigned bitWidth = srcElemType.getWidth(); + + auto targetType = cast<VectorType>(op.getResult().getType()); + auto destElemType = cast<FloatType>(targetType.getElementType()); + + IntegerType i32 = rewriter.getI32Type(); + Value source = adaptor.getSource(); + Type llvmResultType = typeConverter->convertType(op.getResult().getType()); + Type packedType = nullptr; + if (isa<fp4>(srcElemType)) { + packedType = i32; + packedType = getTypeConverter()->convertType(packedType); + } else if (isa<fp8, bf8>(srcElemType)) { + packedType = VectorType::get(2, i32); + packedType = getTypeConverter()->convertType(packedType); + } else if (isa<fp6, bf6>(srcElemType)) { + packedType = VectorType::get(3, i32); + packedType = getTypeConverter()->convertType(packedType); + } else { + llvm_unreachable("invalid element type for packed scaled ext"); + } + + if (!packedType || !llvmResultType) { + return rewriter.notifyMatchFailure(op, "type conversion failed"); + } + + std::optional<StringRef> maybeIntrinsic = + scaledExtPacked816ToIntrinsic(srcElemType, destElemType); + if (!maybeIntrinsic.has_value()) + return op.emitOpError( + "no intrinsic matching packed scaled conversion on the given chipset"); + + int32_t scaleSel = + getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte); + Value castedScale = + LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale()); + Value castedSource = + LLVM::BitcastOp::create(rewriter, loc, packedType, source); + + OperationState loweredOp(loc, *maybeIntrinsic); + loweredOp.addTypes({llvmResultType}); + loweredOp.addOperands({castedSource, castedScale}); + + SmallVector<NamedAttribute, 1> attrs; + attrs.push_back( + NamedAttribute("scaleSel", rewriter.getI32IntegerAttr(scaleSel))); + + loweredOp.addAttributes(attrs); + Operation *lowered = rewriter.create(loweredOp); + rewriter.replaceOp(op, lowered); + + return success(); +} + LogicalResult ScaledExtPackedOpLowering::matchAndRewrite( ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -2073,6 +2275,441 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> { } }; +struct AMDGPUMakeDmaBaseLowering + : public ConvertOpToLLVMPattern<MakeDmaBaseOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + AMDGPUMakeDmaBaseLowering(const LLVMTypeConverter &converter, Chipset chipset) + : ConvertOpToLLVMPattern<MakeDmaBaseOp>(converter), chipset(chipset) {} + Chipset chipset; + + LogicalResult + matchAndRewrite(MakeDmaBaseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (chipset < kGfx1250) + return op->emitOpError("make_dma_base is only supported on gfx1250"); + + Location loc = op.getLoc(); + + ValueRange ldsIndices = adaptor.getLdsIndices(); + Value lds = adaptor.getLds(); + auto ldsMemRefType = cast<MemRefType>(op.getLds().getType()); + + Value ldsPtr = + getStridedElementPtr(rewriter, loc, ldsMemRefType, lds, ldsIndices); + + ValueRange globalIndices = adaptor.getGlobalIndices(); + Value global = adaptor.getGlobal(); + auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType()); + + Value globalPtr = getStridedElementPtr(rewriter, loc, globalMemRefType, + global, globalIndices); + + Type i32 = rewriter.getI32Type(); + Type i64 = rewriter.getI64Type(); + + Value castForLdsAddr = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr); + Value castForGlobalAddr = + LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr); + + Value lowHalf = + LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr); + + Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr, + createI64Constant(rewriter, loc, 32)); + + Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift); + + Value mask = createI32Constant(rewriter, loc, (1ull << 25) - 1); + Value validHighHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask); + + Value typeField = createI32Constant(rewriter, loc, 2 << 30); + Value highHalfPlusType = + LLVM::OrOp::create(rewriter, loc, validHighHalf, typeField); + + Value c0 = createI32Constant(rewriter, loc, 0); + Value c1 = createI32Constant(rewriter, loc, 1); + Value c2 = createI32Constant(rewriter, loc, 2); + Value c3 = createI32Constant(rewriter, loc, 3); + + Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32)); + assert(v4i32 && "expected type conversion to succeed"); + Value result = LLVM::PoisonOp::create(rewriter, loc, v4i32); + result = LLVM::InsertElementOp::create(rewriter, loc, result, c1, c0); + result = LLVM::InsertElementOp::create(rewriter, loc, result, + castForLdsAddr, c1); + result = LLVM::InsertElementOp::create(rewriter, loc, result, lowHalf, c2); + result = LLVM::InsertElementOp::create(rewriter, loc, result, + highHalfPlusType, c3); + + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct AMDGPUMakeDmaDescriptorLowering + : public ConvertOpToLLVMPattern<MakeDmaDescriptorOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + AMDGPUMakeDmaDescriptorLowering(const LLVMTypeConverter &converter, + Chipset chipset) + : ConvertOpToLLVMPattern<MakeDmaDescriptorOp>(converter), + chipset(chipset) {} + Chipset chipset; + + Value getDGroup0(OpAdaptor adaptor) const { return adaptor.getBase(); } + + Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc, + Value accumulator, Value value, int64_t shift) const { + shift = shift % 32; + Value shiftAmount; + if (shift != 0) { + shiftAmount = createI32Constant(rewriter, loc, shift % 32); + value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount); + } + + if (matchPattern(accumulator, mlir::m_Zero())) + return value; + + return LLVM::OrOp::create(rewriter, loc, accumulator, value); + } + + Value setWorkgroupMask(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr0) const { + Value mask = op.getWorkgroupMask(); + if (!mask) + return sgpr0; + + Type i32 = rewriter.getI32Type(); + Value extendedMask = LLVM::ZExtOp::create(rewriter, loc, i32, mask); + return setValueAtOffset(rewriter, loc, sgpr0, extendedMask, 0); + } + + Value setDataSize(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr0, ArrayRef<Value> consts) const { + // Compute data_size. + unsigned elementTypeWidthInBits = op.getElementTypeWidth(); + assert( + llvm::is_contained<unsigned>({8, 16, 32, 64}, elementTypeWidthInBits) && + "expected type width to be 8, 16, 32, or 64."); + int64_t dataSize = llvm::Log2_32(elementTypeWidthInBits / 8); + Value size = createI32Constant(rewriter, loc, dataSize); + return setValueAtOffset(rewriter, loc, sgpr0, size, 16); + } + + Value setAtomicBarrier(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr0, ArrayRef<Value> consts) const { + bool atomic_barrier_enable = adaptor.getAtomicBarrierAddress() != nullptr; + if (!atomic_barrier_enable) + return sgpr0; + + return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18); + } + + Value setIterateEnable(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr0, ArrayRef<Value> consts) const { + bool iterate_enable = adaptor.getGlobalIncrement() != nullptr; + if (!iterate_enable) + return sgpr0; + + // TODO: In future PR, add other required fields for iteration. + return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19); + } + + Value setPadEnable(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr0, ArrayRef<Value> consts) const { + bool pad_enable = op.getPadAmount() != nullptr; + if (!pad_enable) + return sgpr0; + + return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20); + } + + Value setEarlyTimeout(MakeDmaDescriptorOp op, OpAdaptor adaptorm, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr0, ArrayRef<Value> consts) const { + if (!op.getWorkgroupMask()) + return sgpr0; + + return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 21); + } + + Value setPadInterval(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr0, ArrayRef<Value> consts) const { + bool pad_enable = op.getPadAmount() != nullptr; + if (!pad_enable) + return sgpr0; + + IntegerType i32 = rewriter.getI32Type(); + Value padInterval = adaptor.getPadInterval(); + // pre-condition: padInterval can be a power of two between 2 and 256. + padInterval = LLVM::CountTrailingZerosOp::create(rewriter, loc, i32, + padInterval, false); + padInterval = LLVM::SubOp::create(rewriter, loc, padInterval, consts[1]); + // post-condition: padInterval can be a value between 0 and 7. + return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22); + } + + Value setPadAmount(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr0, ArrayRef<Value> consts) const { + bool pad_enable = op.getPadAmount() != nullptr; + if (!pad_enable) + return sgpr0; + + Value padAmount = adaptor.getPadAmount(); + // pre-condition: padAmount is a value between 1-128. + padAmount = LLVM::SubOp::create(rewriter, loc, padAmount, consts[1]); + // post-condition: padAmount is a value between 0-127. + return setValueAtOffset(rewriter, loc, sgpr0, padAmount, 25); + } + + Value setAtomicBarrierAddress(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Location loc, Value sgpr1, + ArrayRef<Value> consts) const { + bool atomic_barrier_enable = adaptor.getAtomicBarrierAddress() != nullptr; + if (!atomic_barrier_enable) + return sgpr1; + + Value atomicBarrierAddress = adaptor.getAtomicBarrierAddress(); + auto barrierAddressTy = + cast<MemRefType>(op.getAtomicBarrierAddress().getType()); + ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices(); + atomicBarrierAddress = + getStridedElementPtr(rewriter, loc, barrierAddressTy, + atomicBarrierAddress, atomicBarrierIndices); + IntegerType i32 = rewriter.getI32Type(); + // pre-condition: atomicBarrierAddress is aligned to 8 bytes which implies + // that the 3 LSBs are zero. + atomicBarrierAddress = + LLVM::PtrToIntOp::create(rewriter, loc, i32, atomicBarrierAddress); + atomicBarrierAddress = + LLVM::LShrOp::create(rewriter, loc, atomicBarrierAddress, consts[3]); + Value mask = createI32Constant(rewriter, loc, 0xFFFF); + atomicBarrierAddress = + LLVM::AndOp::create(rewriter, loc, atomicBarrierAddress, mask); + return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32); + } + + std::pair<Value, Value> setTensorDim0(MakeDmaDescriptorOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Location loc, Value sgpr1, Value sgpr2, + ArrayRef<Value> consts) const { + SmallVector<OpFoldResult> mixedGlobalSizes = op.getMixedGlobalSizes(); + OpFoldResult tensorDim0OpFoldResult = mixedGlobalSizes.back(); + Value tensorDim0; + if (auto attr = dyn_cast<Attribute>(tensorDim0OpFoldResult)) + tensorDim0 = + createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt()); + else + tensorDim0 = cast<Value>(tensorDim0OpFoldResult); + + Value c16 = createI32Constant(rewriter, loc, 16); + Value tensorDim0High = LLVM::LShrOp::create(rewriter, loc, tensorDim0, c16); + sgpr1 = setValueAtOffset(rewriter, loc, sgpr1, tensorDim0, 48); + sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDim0High, 48 + 16); + return {sgpr1, sgpr2}; + } + + std::pair<Value, Value> setTensorDim1(MakeDmaDescriptorOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Location loc, Value sgpr2, Value sgpr3, + ArrayRef<Value> consts) const { + // TODO: Generalize to setTensorDimX. + SmallVector<OpFoldResult> mixedGlobalSizes = op.getMixedGlobalSizes(); + OpFoldResult tensorDim1OpFoldResult = *(mixedGlobalSizes.rbegin() + 1); + Value tensorDim1; + if (auto attr = dyn_cast<Attribute>(tensorDim1OpFoldResult)) + tensorDim1 = + createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt()); + else + tensorDim1 = cast<Value>(tensorDim1OpFoldResult); + + Value c16 = createI32Constant(rewriter, loc, 16); + Value tensorDim1High = LLVM::LShrOp::create(rewriter, loc, tensorDim1, c16); + sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDim1, 80); + sgpr3 = setValueAtOffset(rewriter, loc, sgpr3, tensorDim1High, 80 + 16); + return {sgpr2, sgpr3}; + } + + Value setTileDimX(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr, ArrayRef<Value> consts, size_t dimX, + int64_t offset) const { + SmallVector<OpFoldResult> mixedSharedSizes = op.getMixedSharedSizes(); + + if (mixedSharedSizes.size() <= dimX) + return sgpr; + + OpFoldResult tileDimXOpFoldResult = *(mixedSharedSizes.rbegin() + dimX); + Value tileDimX; + if (auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult)) + tileDimX = + createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt()); + else + tileDimX = cast<Value>(tileDimXOpFoldResult); + + return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset); + } + + Value setTileDim0(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr3, ArrayRef<Value> consts) const { + return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112); + } + + Value setTileDim1(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr4, ArrayRef<Value> consts) const { + return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128); + } + + Value setTileDim2(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr4, ArrayRef<Value> consts) const { + return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144); + } + + std::pair<Value, Value> + setTensorDimXStride(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgprY, Value sgprZ, ArrayRef<Value> consts, + size_t dimX, int64_t offset) const { + SmallVector<OpFoldResult> mixedGlobalStrides = op.getMixedGlobalStrides(); + + if (mixedGlobalStrides.size() <= dimX) + return {sgprY, sgprZ}; + + OpFoldResult tensorDimXStrideOpFoldResult = + *(mixedGlobalStrides.rbegin() + dimX); + Value tensorDimXStride; + if (auto attr = dyn_cast<Attribute>(tensorDimXStrideOpFoldResult)) + tensorDimXStride = + createI64Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt()); + else + tensorDimXStride = cast<Value>(tensorDimXStrideOpFoldResult); + + constexpr int64_t first48bits = (1ll << 48) - 1; + Value mask = createI64Constant(rewriter, loc, first48bits); + tensorDimXStride = + LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride); + IntegerType i32 = rewriter.getI32Type(); + Value tensorDimXStrideLow = + LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride); + + int64_t shift = (offset % 32) == 0 ? 32 : offset % 32; + Value shiftVal = createI64Constant(rewriter, loc, shift); + Value tensorDimXStrideHigh = + LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal); + tensorDimXStrideHigh = + LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh); + + sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset); + sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh, + offset + shift); + return {sgprY, sgprZ}; + } + + std::pair<Value, Value> + setTensorDim0Stride(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const { + return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts, + 0, 160); + } + + std::pair<Value, Value> + setTensorDim1Stride(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const { + return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts, + 1, 208); + } + + Value getDGroup1(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + ArrayRef<Value> consts) const { + Value sgprs[8]; + for (int64_t i = 0; i < 8; i++) { + sgprs[i] = consts[0]; + } + + sgprs[0] = setWorkgroupMask(op, adaptor, rewriter, loc, sgprs[0]); + sgprs[0] = setDataSize(op, adaptor, rewriter, loc, sgprs[0], consts); + sgprs[0] = setAtomicBarrier(op, adaptor, rewriter, loc, sgprs[0], consts); + sgprs[0] = setIterateEnable(op, adaptor, rewriter, loc, sgprs[0], consts); + sgprs[0] = setPadEnable(op, adaptor, rewriter, loc, sgprs[0], consts); + sgprs[0] = setEarlyTimeout(op, adaptor, rewriter, loc, sgprs[0], consts); + sgprs[0] = setPadInterval(op, adaptor, rewriter, loc, sgprs[0], consts); + sgprs[0] = setPadAmount(op, adaptor, rewriter, loc, sgprs[0], consts); + + sgprs[1] = + setAtomicBarrierAddress(op, adaptor, rewriter, loc, sgprs[1], consts); + std::tie(sgprs[1], sgprs[2]) = + setTensorDim0(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts); + std::tie(sgprs[2], sgprs[3]) = + setTensorDim1(op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts); + + sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts); + sgprs[4] = setTileDim1(op, adaptor, rewriter, loc, sgprs[4], consts); + sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts); + std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride( + op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts); + std::tie(sgprs[6], sgprs[7]) = setTensorDim1Stride( + op, adaptor, rewriter, loc, sgprs[6], sgprs[7], consts); + + IntegerType i32 = rewriter.getI32Type(); + Type v8i32 = this->typeConverter->convertType(VectorType::get(8, i32)); + assert(v8i32 && "expected type conversion to succeed"); + Value dgroup1 = LLVM::PoisonOp::create(rewriter, loc, v8i32); + + for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts)) { + dgroup1 = + LLVM::InsertElementOp::create(rewriter, loc, dgroup1, sgpr, constant); + } + + return dgroup1; + } + + LogicalResult + matchAndRewrite(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (chipset < kGfx1250) + return op->emitOpError( + "make_dma_descriptor is only supported on gfx1250"); + + if (op.getRank() > 2) + return op->emitOpError("unimplemented"); + + Location loc = op.getLoc(); + + IntegerType i32 = rewriter.getI32Type(); + [[maybe_unused]] Type v4i32 = + this->typeConverter->convertType(VectorType::get(4, i32)); + assert(v4i32 && "expected type conversion to succeed"); + + SmallVector<Value> consts; + for (int64_t i = 0; i < 8; i++) + consts.push_back(createI32Constant(rewriter, loc, i)); + + Value dgroup0 = this->getDGroup0(adaptor); + Value dgroup1 = this->getDGroup1(op, adaptor, rewriter, loc, consts); + + SmallVector<Value> results = {dgroup0, dgroup1}; + rewriter.replaceOpWithMultiple(op, {results}); + return success(); + } +}; + struct ConvertAMDGPUToROCDLPass : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> { using Base::Base; @@ -2087,6 +2724,11 @@ struct ConvertAMDGPUToROCDLPass RewritePatternSet patterns(ctx); LLVMTypeConverter converter(ctx); + converter.addConversion([&](TDMBaseType type) -> Type { + Type i32 = IntegerType::get(type.getContext(), 32); + return converter.convertType(VectorType::get(4, i32)); + }); + populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset); LLVMConversionTarget target(getContext()); target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>(); @@ -2122,25 +2764,27 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, Chipset chipset) { populateAMDGPUMemorySpaceAttributeConversions(converter); - patterns - .add<FatRawBufferCastLowering, - RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>, - RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>, - RawBufferOpLowering<RawBufferAtomicFaddOp, - ROCDL::RawPtrBufferAtomicFaddOp>, - RawBufferOpLowering<RawBufferAtomicFmaxOp, - ROCDL::RawPtrBufferAtomicFmaxOp>, - RawBufferOpLowering<RawBufferAtomicSmaxOp, - ROCDL::RawPtrBufferAtomicSmaxOp>, - RawBufferOpLowering<RawBufferAtomicUminOp, - ROCDL::RawPtrBufferAtomicUminOp>, - RawBufferOpLowering<RawBufferAtomicCmpswapOp, - ROCDL::RawPtrBufferAtomicCmpSwap>, - AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering, - SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering, - WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering, - PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering, - PackedStochRoundFp8OpLowering, GatherToLDSOpLowering, - TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset); + patterns.add< + FatRawBufferCastLowering, + RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>, + RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>, + RawBufferOpLowering<RawBufferAtomicFaddOp, + ROCDL::RawPtrBufferAtomicFaddOp>, + RawBufferOpLowering<RawBufferAtomicFmaxOp, + ROCDL::RawPtrBufferAtomicFmaxOp>, + RawBufferOpLowering<RawBufferAtomicSmaxOp, + ROCDL::RawPtrBufferAtomicSmaxOp>, + RawBufferOpLowering<RawBufferAtomicUminOp, + ROCDL::RawPtrBufferAtomicUminOp>, + RawBufferOpLowering<RawBufferAtomicCmpswapOp, + ROCDL::RawPtrBufferAtomicCmpSwap>, + AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering, + SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering, + WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering, + ScaledExtPackedOpLowering, PackedScaledTruncOpLowering, + PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering, + GatherToLDSOpLowering, TransposeLoadOpLowering, AMDGPUPermlaneLowering, + AMDGPUMakeDmaBaseLowering, AMDGPUMakeDmaDescriptorLowering>(converter, + chipset); patterns.add<AMDGPUSwizzleBitModeLowering>(converter); } diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp new file mode 100644 index 0000000..79816fc --- /dev/null +++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp @@ -0,0 +1,665 @@ +//===- ArithToAPFloat.cpp - Arithmetic to APFloat Conversion --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Utils/Utils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Transforms/WalkPatternRewriteDriver.h" + +namespace mlir { +#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; +using namespace mlir::func; + +static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable, + StringRef name, FunctionType funcT, bool setPrivate, + SymbolTableCollection *symbolTables = nullptr) { + OpBuilder::InsertionGuard g(b); + assert(!symTable->getRegion(0).empty() && "expected non-empty region"); + b.setInsertionPointToStart(&symTable->getRegion(0).front()); + FuncOp funcOp = FuncOp::create(b, symTable->getLoc(), name, funcT); + if (setPrivate) + funcOp.setPrivate(); + if (symbolTables) { + SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable); + symbolTable.insert(funcOp, symTable->getRegion(0).front().begin()); + } + return funcOp; +} + +/// Helper function to look up or create the symbol for a runtime library +/// function with the given parameter types. Returns an int64_t, unless a +/// different result type is specified. +static FailureOr<FuncOp> +lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable, + StringRef name, TypeRange paramTypes, + SymbolTableCollection *symbolTables = nullptr, + Type resultType = {}) { + if (!resultType) + resultType = IntegerType::get(symTable->getContext(), 64); + std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str(); + auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType}); + FailureOr<FuncOp> func = + lookupFnDecl(symTable, funcName, funcT, symbolTables); + // Failed due to type mismatch. + if (failed(func)) + return func; + // Successfully matched existing decl. + if (*func) + return *func; + + return createFnDecl(b, symTable, funcName, funcT, + /*setPrivate=*/true, symbolTables); +} + +/// Helper function to look up or create the symbol for a runtime library +/// function for a binary arithmetic operation. +/// +/// Parameter 1: APFloat semantics +/// Parameter 2: Left-hand side operand +/// Parameter 3: Right-hand side operand +/// +/// This function will return a failure if the function is found but has an +/// unexpected signature. +/// +static FailureOr<FuncOp> +lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name, + SymbolTableCollection *symbolTables = nullptr) { + auto i32Type = IntegerType::get(symTable->getContext(), 32); + auto i64Type = IntegerType::get(symTable->getContext(), 64); + return lookupOrCreateApFloatFn(b, symTable, name, {i32Type, i64Type, i64Type}, + symbolTables); +} + +static Value getSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy) { + int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics()); + return arith::ConstantOp::create(b, loc, b.getI32Type(), + b.getIntegerAttr(b.getI32Type(), sem)); +} + +/// Given two operands of vector type and vector result type (with the same +/// shape), call the given function for each pair of scalar operands and +/// package the result into a vector. If the given operands and result type are +/// not vectors, call the function directly. The second operand is optional. +template <typename Fn, typename... Values> +static Value forEachScalarValue(RewriterBase &rewriter, Location loc, + Value operand1, Value operand2, Type resultType, + Fn fn) { + auto vecTy1 = dyn_cast<VectorType>(operand1.getType()); + if (operand2) { + // Sanity check: Operand types must match. + assert(vecTy1 == dyn_cast<VectorType>(operand2.getType()) && + "expected same vector types"); + } + if (!vecTy1) { + // Not a vector. Call the function directly. + return fn(operand1, operand2, resultType); + } + + // Prepare scalar operands. + ResultRange sclars1 = + vector::ToElementsOp::create(rewriter, loc, operand1)->getResults(); + SmallVector<Value> scalars2; + if (!operand2) { + // No second operand. Create a vector of empty values. + scalars2.assign(vecTy1.getNumElements(), Value()); + } else { + llvm::append_range( + scalars2, + vector::ToElementsOp::create(rewriter, loc, operand2)->getResults()); + } + + // Call the function for each pair of scalar operands. + auto resultVecType = cast<VectorType>(resultType); + SmallVector<Value> results; + for (auto [scalar1, scalar2] : llvm::zip_equal(sclars1, scalars2)) { + Value result = fn(scalar1, scalar2, resultVecType.getElementType()); + results.push_back(result); + } + + // Package the results into a vector. + return vector::FromElementsOp::create( + rewriter, loc, + vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()), + results); +} + +/// Check preconditions for the conversion: +/// 1. All operands / results must be integers or floats (or vectors thereof). +/// 2. The bitwidth of the operands / results must be <= 64. +static LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op) { + for (Value value : llvm::concat<Value>(op->getOperands(), op->getResults())) { + Type type = value.getType(); + if (auto vecTy = dyn_cast<VectorType>(type)) { + type = vecTy.getElementType(); + } + if (!type.isIntOrFloat()) { + return rewriter.notifyMatchFailure( + op, "only integers and floats (or vectors thereof) are supported"); + } + if (type.getIntOrFloatBitWidth() > 64) + return rewriter.notifyMatchFailure(op, + "bitwidth > 64 bits is not supported"); + } + return success(); +} + +/// Rewrite a binary arithmetic operation to an APFloat function call. +template <typename OpTy> +struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> { + BinaryArithOpToAPFloatConversion(MLIRContext *context, + const char *APFloatName, + SymbolOpInterface symTable, + PatternBenefit benefit = 1) + : OpRewritePattern<OpTy>(context, benefit), symTable(symTable), + APFloatName(APFloatName) {}; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + if (failed(checkPreconditions(rewriter, op))) + return failure(); + + // Get APFloat function from runtime library. + FailureOr<FuncOp> fn = + lookupOrCreateBinaryFn(rewriter, symTable, APFloatName); + if (failed(fn)) + return fn; + + // Scalarize and convert to APFloat runtime calls. + Location loc = op.getLoc(); + rewriter.setInsertionPoint(op); + Value repl = forEachScalarValue( + rewriter, loc, op.getLhs(), op.getRhs(), op.getType(), + [&](Value lhs, Value rhs, Type resultType) { + // Cast operands to 64-bit integers. + auto floatTy = cast<FloatType>(resultType); + auto intWType = rewriter.getIntegerType(floatTy.getWidth()); + auto int64Type = rewriter.getI64Type(); + Value lhsBits = arith::ExtUIOp::create( + rewriter, loc, int64Type, + arith::BitcastOp::create(rewriter, loc, intWType, lhs)); + Value rhsBits = arith::ExtUIOp::create( + rewriter, loc, int64Type, + arith::BitcastOp::create(rewriter, loc, intWType, rhs)); + + // Call APFloat function. + Value semValue = getSemanticsValue(rewriter, loc, floatTy); + SmallVector<Value> params = {semValue, lhsBits, rhsBits}; + auto resultOp = func::CallOp::create(rewriter, loc, + TypeRange(rewriter.getI64Type()), + SymbolRefAttr::get(*fn), params); + + // Truncate result to the original width. + Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType, + resultOp->getResult(0)); + return arith::BitcastOp::create(rewriter, loc, floatTy, + truncatedBits); + }); + rewriter.replaceOp(op, repl); + return success(); + } + + SymbolOpInterface symTable; + const char *APFloatName; +}; + +template <typename OpTy> +struct FpToFpConversion final : OpRewritePattern<OpTy> { + FpToFpConversion(MLIRContext *context, SymbolOpInterface symTable, + PatternBenefit benefit = 1) + : OpRewritePattern<OpTy>(context, benefit), symTable(symTable) {} + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + if (failed(checkPreconditions(rewriter, op))) + return failure(); + + // Get APFloat function from runtime library. + auto i32Type = IntegerType::get(symTable->getContext(), 32); + auto i64Type = IntegerType::get(symTable->getContext(), 64); + FailureOr<FuncOp> fn = lookupOrCreateApFloatFn( + rewriter, symTable, "convert", {i32Type, i32Type, i64Type}); + if (failed(fn)) + return fn; + + // Scalarize and convert to APFloat runtime calls. + Location loc = op.getLoc(); + rewriter.setInsertionPoint(op); + Value repl = forEachScalarValue( + rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(), + [&](Value operand1, Value operand2, Type resultType) { + // Cast operands to 64-bit integers. + auto inFloatTy = cast<FloatType>(operand1.getType()); + auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth()); + Value operandBits = arith::ExtUIOp::create( + rewriter, loc, i64Type, + arith::BitcastOp::create(rewriter, loc, inIntWType, operand1)); + + // Call APFloat function. + Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy); + auto outFloatTy = cast<FloatType>(resultType); + Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy); + std::array<Value, 3> params = {inSemValue, outSemValue, operandBits}; + auto resultOp = func::CallOp::create(rewriter, loc, + TypeRange(rewriter.getI64Type()), + SymbolRefAttr::get(*fn), params); + + // Truncate result to the original width. + auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth()); + Value truncatedBits = arith::TruncIOp::create( + rewriter, loc, outIntWType, resultOp->getResult(0)); + return arith::BitcastOp::create(rewriter, loc, outFloatTy, + truncatedBits); + }); + rewriter.replaceOp(op, repl); + return success(); + } + + SymbolOpInterface symTable; +}; + +template <typename OpTy> +struct FpToIntConversion final : OpRewritePattern<OpTy> { + FpToIntConversion(MLIRContext *context, SymbolOpInterface symTable, + bool isUnsigned, PatternBenefit benefit = 1) + : OpRewritePattern<OpTy>(context, benefit), symTable(symTable), + isUnsigned(isUnsigned) {} + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + if (failed(checkPreconditions(rewriter, op))) + return failure(); + + // Get APFloat function from runtime library. + auto i1Type = IntegerType::get(symTable->getContext(), 1); + auto i32Type = IntegerType::get(symTable->getContext(), 32); + auto i64Type = IntegerType::get(symTable->getContext(), 64); + FailureOr<FuncOp> fn = + lookupOrCreateApFloatFn(rewriter, symTable, "convert_to_int", + {i32Type, i32Type, i1Type, i64Type}); + if (failed(fn)) + return fn; + + // Scalarize and convert to APFloat runtime calls. + Location loc = op.getLoc(); + rewriter.setInsertionPoint(op); + Value repl = forEachScalarValue( + rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(), + [&](Value operand1, Value operand2, Type resultType) { + // Cast operands to 64-bit integers. + auto inFloatTy = cast<FloatType>(operand1.getType()); + auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth()); + Value operandBits = arith::ExtUIOp::create( + rewriter, loc, i64Type, + arith::BitcastOp::create(rewriter, loc, inIntWType, operand1)); + + // Call APFloat function. + Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy); + auto outIntTy = cast<IntegerType>(resultType); + Value outWidthValue = arith::ConstantOp::create( + rewriter, loc, i32Type, + rewriter.getIntegerAttr(i32Type, outIntTy.getWidth())); + Value isUnsignedValue = arith::ConstantOp::create( + rewriter, loc, i1Type, + rewriter.getIntegerAttr(i1Type, isUnsigned)); + SmallVector<Value> params = {inSemValue, outWidthValue, + isUnsignedValue, operandBits}; + auto resultOp = func::CallOp::create(rewriter, loc, + TypeRange(rewriter.getI64Type()), + SymbolRefAttr::get(*fn), params); + + // Truncate result to the original width. + return arith::TruncIOp::create(rewriter, loc, outIntTy, + resultOp->getResult(0)); + }); + rewriter.replaceOp(op, repl); + return success(); + } + + SymbolOpInterface symTable; + bool isUnsigned; +}; + +template <typename OpTy> +struct IntToFpConversion final : OpRewritePattern<OpTy> { + IntToFpConversion(MLIRContext *context, SymbolOpInterface symTable, + bool isUnsigned, PatternBenefit benefit = 1) + : OpRewritePattern<OpTy>(context, benefit), symTable(symTable), + isUnsigned(isUnsigned) {} + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + if (failed(checkPreconditions(rewriter, op))) + return failure(); + + // Get APFloat function from runtime library. + auto i1Type = IntegerType::get(symTable->getContext(), 1); + auto i32Type = IntegerType::get(symTable->getContext(), 32); + auto i64Type = IntegerType::get(symTable->getContext(), 64); + FailureOr<FuncOp> fn = + lookupOrCreateApFloatFn(rewriter, symTable, "convert_from_int", + {i32Type, i32Type, i1Type, i64Type}); + if (failed(fn)) + return fn; + + // Scalarize and convert to APFloat runtime calls. + Location loc = op.getLoc(); + rewriter.setInsertionPoint(op); + Value repl = forEachScalarValue( + rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(), + [&](Value operand1, Value operand2, Type resultType) { + // Cast operands to 64-bit integers. + auto inIntTy = cast<IntegerType>(operand1.getType()); + Value operandBits = operand1; + if (operandBits.getType().getIntOrFloatBitWidth() < 64) { + if (isUnsigned) { + operandBits = + arith::ExtUIOp::create(rewriter, loc, i64Type, operandBits); + } else { + operandBits = + arith::ExtSIOp::create(rewriter, loc, i64Type, operandBits); + } + } + + // Call APFloat function. + auto outFloatTy = cast<FloatType>(resultType); + Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy); + Value inWidthValue = arith::ConstantOp::create( + rewriter, loc, i32Type, + rewriter.getIntegerAttr(i32Type, inIntTy.getWidth())); + Value isUnsignedValue = arith::ConstantOp::create( + rewriter, loc, i1Type, + rewriter.getIntegerAttr(i1Type, isUnsigned)); + SmallVector<Value> params = {outSemValue, inWidthValue, + isUnsignedValue, operandBits}; + auto resultOp = func::CallOp::create(rewriter, loc, + TypeRange(rewriter.getI64Type()), + SymbolRefAttr::get(*fn), params); + + // Truncate result to the original width. + auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth()); + Value truncatedBits = arith::TruncIOp::create( + rewriter, loc, outIntWType, resultOp->getResult(0)); + return arith::BitcastOp::create(rewriter, loc, outFloatTy, + truncatedBits); + }); + rewriter.replaceOp(op, repl); + return success(); + } + + SymbolOpInterface symTable; + bool isUnsigned; +}; + +struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> { + CmpFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, + PatternBenefit benefit = 1) + : OpRewritePattern<arith::CmpFOp>(context, benefit), symTable(symTable) {} + + LogicalResult matchAndRewrite(arith::CmpFOp op, + PatternRewriter &rewriter) const override { + if (failed(checkPreconditions(rewriter, op))) + return failure(); + + // Get APFloat function from runtime library. + auto i1Type = IntegerType::get(symTable->getContext(), 1); + auto i8Type = IntegerType::get(symTable->getContext(), 8); + auto i32Type = IntegerType::get(symTable->getContext(), 32); + auto i64Type = IntegerType::get(symTable->getContext(), 64); + FailureOr<FuncOp> fn = + lookupOrCreateApFloatFn(rewriter, symTable, "compare", + {i32Type, i64Type, i64Type}, nullptr, i8Type); + if (failed(fn)) + return fn; + + // Scalarize and convert to APFloat runtime calls. + Location loc = op.getLoc(); + rewriter.setInsertionPoint(op); + Value repl = forEachScalarValue( + rewriter, loc, op.getLhs(), op.getRhs(), op.getType(), + [&](Value lhs, Value rhs, Type resultType) { + // Cast operands to 64-bit integers. + auto floatTy = cast<FloatType>(lhs.getType()); + auto intWType = rewriter.getIntegerType(floatTy.getWidth()); + Value lhsBits = arith::ExtUIOp::create( + rewriter, loc, i64Type, + arith::BitcastOp::create(rewriter, loc, intWType, lhs)); + Value rhsBits = arith::ExtUIOp::create( + rewriter, loc, i64Type, + arith::BitcastOp::create(rewriter, loc, intWType, rhs)); + + // Call APFloat function. + Value semValue = getSemanticsValue(rewriter, loc, floatTy); + SmallVector<Value> params = {semValue, lhsBits, rhsBits}; + Value comparisonResult = + func::CallOp::create(rewriter, loc, TypeRange(i8Type), + SymbolRefAttr::get(*fn), params) + ->getResult(0); + + // Generate an i1 SSA value that is "true" if the comparison result + // matches the given `val`. + auto checkResult = [&](llvm::APFloat::cmpResult val) { + return arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::eq, comparisonResult, + arith::ConstantOp::create( + rewriter, loc, i8Type, + rewriter.getIntegerAttr(i8Type, static_cast<int8_t>(val))) + .getResult()); + }; + // Generate an i1 SSA value that is "true" if the comparison result + // matches any of the given `vals`. + std::function<Value(ArrayRef<llvm::APFloat::cmpResult>)> + checkResults = [&](ArrayRef<llvm::APFloat::cmpResult> vals) { + Value first = checkResult(vals.front()); + if (vals.size() == 1) + return first; + Value rest = checkResults(vals.drop_front()); + return arith::OrIOp::create(rewriter, loc, first, rest) + .getResult(); + }; + + // This switch-case statement was taken from arith::applyCmpPredicate. + Value result; + switch (op.getPredicate()) { + case arith::CmpFPredicate::AlwaysFalse: + result = + arith::ConstantOp::create(rewriter, loc, i1Type, + rewriter.getIntegerAttr(i1Type, 0)) + .getResult(); + break; + case arith::CmpFPredicate::OEQ: + result = checkResult(llvm::APFloat::cmpEqual); + break; + case arith::CmpFPredicate::OGT: + result = checkResult(llvm::APFloat::cmpGreaterThan); + break; + case arith::CmpFPredicate::OGE: + result = checkResults( + {llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::OLT: + result = checkResult(llvm::APFloat::cmpLessThan); + break; + case arith::CmpFPredicate::OLE: + result = checkResults( + {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::ONE: + // Not cmpUnordered and not cmpUnordered. + result = checkResults( + {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpGreaterThan}); + break; + case arith::CmpFPredicate::ORD: + // Not cmpUnordered. + result = checkResults({llvm::APFloat::cmpLessThan, + llvm::APFloat::cmpGreaterThan, + llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::UEQ: + result = checkResults( + {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::UGT: + result = checkResults( + {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpGreaterThan}); + break; + case arith::CmpFPredicate::UGE: + result = checkResults({llvm::APFloat::cmpUnordered, + llvm::APFloat::cmpGreaterThan, + llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::ULT: + result = checkResults( + {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan}); + break; + case arith::CmpFPredicate::ULE: + result = checkResults({llvm::APFloat::cmpUnordered, + llvm::APFloat::cmpLessThan, + llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::UNE: + // Not cmpEqual. + result = checkResults({llvm::APFloat::cmpLessThan, + llvm::APFloat::cmpGreaterThan, + llvm::APFloat::cmpUnordered}); + break; + case arith::CmpFPredicate::UNO: + result = checkResult(llvm::APFloat::cmpUnordered); + break; + case arith::CmpFPredicate::AlwaysTrue: + result = + arith::ConstantOp::create(rewriter, loc, i1Type, + rewriter.getIntegerAttr(i1Type, 1)) + .getResult(); + break; + } + return result; + }); + rewriter.replaceOp(op, repl); + return success(); + } + + SymbolOpInterface symTable; +}; + +struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> { + NegFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, + PatternBenefit benefit = 1) + : OpRewritePattern<arith::NegFOp>(context, benefit), symTable(symTable) {} + + LogicalResult matchAndRewrite(arith::NegFOp op, + PatternRewriter &rewriter) const override { + if (failed(checkPreconditions(rewriter, op))) + return failure(); + + // Get APFloat function from runtime library. + auto i32Type = IntegerType::get(symTable->getContext(), 32); + auto i64Type = IntegerType::get(symTable->getContext(), 64); + FailureOr<FuncOp> fn = + lookupOrCreateApFloatFn(rewriter, symTable, "neg", {i32Type, i64Type}); + if (failed(fn)) + return fn; + + // Scalarize and convert to APFloat runtime calls. + Location loc = op.getLoc(); + rewriter.setInsertionPoint(op); + Value repl = forEachScalarValue( + rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(), + [&](Value operand1, Value operand2, Type resultType) { + // Cast operands to 64-bit integers. + auto floatTy = cast<FloatType>(operand1.getType()); + auto intWType = rewriter.getIntegerType(floatTy.getWidth()); + Value operandBits = arith::ExtUIOp::create( + rewriter, loc, i64Type, + arith::BitcastOp::create(rewriter, loc, intWType, operand1)); + + // Call APFloat function. + Value semValue = getSemanticsValue(rewriter, loc, floatTy); + SmallVector<Value> params = {semValue, operandBits}; + Value negatedBits = + func::CallOp::create(rewriter, loc, TypeRange(i64Type), + SymbolRefAttr::get(*fn), params) + ->getResult(0); + + // Truncate result to the original width. + Value truncatedBits = + arith::TruncIOp::create(rewriter, loc, intWType, negatedBits); + return arith::BitcastOp::create(rewriter, loc, floatTy, + truncatedBits); + }); + rewriter.replaceOp(op, repl); + return success(); + } + + SymbolOpInterface symTable; +}; + +namespace { +struct ArithToAPFloatConversionPass final + : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> { + using Base::Base; + + void runOnOperation() override; +}; + +void ArithToAPFloatConversionPass::runOnOperation() { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + patterns.add<BinaryArithOpToAPFloatConversion<arith::AddFOp>>(context, "add", + getOperation()); + patterns.add<BinaryArithOpToAPFloatConversion<arith::SubFOp>>( + context, "subtract", getOperation()); + patterns.add<BinaryArithOpToAPFloatConversion<arith::MulFOp>>( + context, "multiply", getOperation()); + patterns.add<BinaryArithOpToAPFloatConversion<arith::DivFOp>>( + context, "divide", getOperation()); + patterns.add<BinaryArithOpToAPFloatConversion<arith::RemFOp>>( + context, "remainder", getOperation()); + patterns.add<BinaryArithOpToAPFloatConversion<arith::MinNumFOp>>( + context, "minnum", getOperation()); + patterns.add<BinaryArithOpToAPFloatConversion<arith::MaxNumFOp>>( + context, "maxnum", getOperation()); + patterns.add<BinaryArithOpToAPFloatConversion<arith::MinimumFOp>>( + context, "minimum", getOperation()); + patterns.add<BinaryArithOpToAPFloatConversion<arith::MaximumFOp>>( + context, "maximum", getOperation()); + patterns + .add<FpToFpConversion<arith::ExtFOp>, FpToFpConversion<arith::TruncFOp>, + CmpFOpToAPFloatConversion, NegFOpToAPFloatConversion>( + context, getOperation()); + patterns.add<FpToIntConversion<arith::FPToSIOp>>(context, getOperation(), + /*isUnsigned=*/false); + patterns.add<FpToIntConversion<arith::FPToUIOp>>(context, getOperation(), + /*isUnsigned=*/true); + patterns.add<IntToFpConversion<arith::SIToFPOp>>(context, getOperation(), + /*isUnsigned=*/false); + patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(), + /*isUnsigned=*/true); + LogicalResult result = success(); + ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) { + if (diag.getSeverity() == DiagnosticSeverity::Error) { + result = failure(); + } + // NB: if you don't return failure, no other diag handlers will fire (see + // mlir/lib/IR/Diagnostics.cpp:DiagnosticEngineImpl::emit). + return failure(); + }); + walkAndApplyPatterns(getOperation(), std::move(patterns)); + if (failed(result)) + return signalPassFailure(); +} +} // namespace diff --git a/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt new file mode 100644 index 0000000..31fce7a --- /dev/null +++ b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_conversion_library(MLIRArithToAPFloat + ArithToAPFloat.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToLLVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRArithTransforms + MLIRFuncDialect + MLIRFuncUtils + MLIRVectorDialect + ) diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index b609990..220826d 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -14,6 +14,7 @@ #include "mlir/Conversion/LLVMCommon/VectorPattern.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/TypeUtilities.h" @@ -280,6 +281,7 @@ ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(), op->getAttrs(), + /*propAttr=*/Attribute{}, *getTypeConverter(), rewriter); } @@ -481,6 +483,10 @@ CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, LogicalResult CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + if (LLVM::detail::isUnsupportedFloatingPointType(*this->getTypeConverter(), + op.getLhs().getType())) + return rewriter.notifyMatchFailure(op, "unsupported floating point type"); + Type operandType = adaptor.getLhs().getType(); Type resultType = op.getResult().getType(); LLVM::FastmathFlags fmf = diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index bebf1b8..613dc6d 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -2,6 +2,7 @@ add_subdirectory(AffineToStandard) add_subdirectory(AMDGPUToROCDL) add_subdirectory(ArithCommon) add_subdirectory(ArithToAMDGPU) +add_subdirectory(ArithToAPFloat) add_subdirectory(ArithToArmSME) add_subdirectory(ArithToEmitC) add_subdirectory(ArithToLLVM) diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp index 86d02e6..6a0c211 100644 --- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp +++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp @@ -96,7 +96,8 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> { ConversionPatternRewriter &rewriter) const override { return LLVM::detail::oneToOneRewrite( op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(), - op->getAttrs(), *getTypeConverter(), rewriter); + op->getAttrs(), /*propAttr=*/Attribute{}, *getTypeConverter(), + rewriter); } }; diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp index 93fe2ed..2220f61 100644 --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -374,9 +374,12 @@ FailureOr<LLVM::LLVMFuncOp> mlir::convertFuncOpToLLVMFuncOp( // Create a memory effect attribute corresponding to readnone. if (funcOp->hasAttr(readnoneAttrName)) { auto memoryAttr = LLVM::MemoryEffectsAttr::get( - rewriter.getContext(), - {LLVM::ModRefInfo::NoModRef, LLVM::ModRefInfo::NoModRef, - LLVM::ModRefInfo::NoModRef}); + rewriter.getContext(), {/*other=*/LLVM::ModRefInfo::NoModRef, + /*argMem=*/LLVM::ModRefInfo::NoModRef, + /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef, + /*errnoMem=*/LLVM::ModRefInfo::NoModRef, + /*targetMem0=*/LLVM::ModRefInfo::NoModRef, + /*targetMem1=*/LLVM::ModRefInfo::NoModRef}); newFuncOp.setMemoryEffectsAttr(memoryAttr); } diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp index 425594b..f143a9e 100644 --- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp +++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp @@ -66,7 +66,10 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, constexpr auto noModRef = mlir::LLVM::ModRefInfo::NoModRef; auto memAttr = b.getAttr<LLVM::MemoryEffectsAttr>( /*other=*/noModRef, - /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef); + /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef, + /*errnoMem=*/noModRef, + /*targetMem0=*/noModRef, + /*targetMem1=*/noModRef); func.setMemoryEffectsAttr(memAttr); } diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index d64c4d6..5848489 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -419,7 +419,10 @@ struct LowerGpuOpsToNVVMOpsPass final if (this->hasRedux) populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns); configureGpuToNVVMConversionLegality(target); - if (failed(applyPartialConversion(m, target, std::move(llvmPatterns)))) + ConversionConfig config; + config.allowPatternRollback = allowPatternRollback; + if (failed( + applyPartialConversion(m, target, std::move(llvmPatterns), config))) signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp index 99c059c..6254de8 100644 --- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" using namespace mlir; @@ -57,7 +58,8 @@ static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) { if (type.getElementType().isF32()) return type.getOperand() == "COp" ? NVVM::MMATypes::f32 : NVVM::MMATypes::tf32; - + if (type.getElementType().isF64()) + return NVVM::MMATypes::f64; if (type.getElementType().isSignedInteger(8)) return NVVM::MMATypes::s8; if (type.getElementType().isUnsignedInteger(8)) @@ -212,8 +214,13 @@ struct WmmaMmaOpToNVVMLowering // then passed on to the intrinsic call. Emit llvm ops to extract individual // values form lowered memrefs. SmallVector<Value> unpackedOps; - auto unpackOp = [&](Value operand) { + // f64 a and b fragments are not structs but scalars. + if (!isa<LLVM::LLVMStructType>(operand.getType())) { + unpackedOps.push_back(operand); + return; + } + // every other type is lowered to an LLVM struct, extract the values. auto structType = cast<LLVM::LLVMStructType>(operand.getType()); for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) { Value toUse = LLVM::ExtractValueOp::create(rewriter, loc, operand, i); @@ -276,10 +283,16 @@ struct WmmaConstantOpToNVVMLowering return failure(); Location loc = subgroupMmaConstantOp.getLoc(); Value cst = adaptor.getOperands()[0]; - LLVM::LLVMStructType type = convertMMAToLLVMType( + Type type = convertMMAToLLVMType( cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType())); + // If the element is not a struct, it means it's a scalar f64. + auto structType = dyn_cast<LLVM::LLVMStructType>(type); + if (!structType) { + rewriter.replaceOp(subgroupMmaConstantOp, cst); + return success(); + } // If the element type is a vector create a vector from the operand. - if (auto vecType = dyn_cast<VectorType>(type.getBody()[0])) { + if (auto vecType = dyn_cast<VectorType>(structType.getBody()[0])) { Value vecCst = LLVM::PoisonOp::create(rewriter, loc, vecType); for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) { Value idx = LLVM::ConstantOp::create(rewriter, loc, @@ -289,8 +302,8 @@ struct WmmaConstantOpToNVVMLowering } cst = vecCst; } - Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, type); - for (size_t i : llvm::seq(size_t(0), type.getBody().size())) { + Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, structType); + for (size_t i : llvm::seq(size_t(0), structType.getBody().size())) { matrixStruct = LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, cst, i); } @@ -354,10 +367,24 @@ struct WmmaElementwiseOpToNVVMLowering return failure(); Location loc = subgroupMmaElementwiseOp.getLoc(); size_t numOperands = adaptor.getOperands().size(); - LLVM::LLVMStructType destType = convertMMAToLLVMType( + Type destType = convertMMAToLLVMType( cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType())); - Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, destType); - for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) { + + // If the element is not a struct, it means it's a scalar f64. + LLVM::LLVMStructType structDestTy = + dyn_cast<LLVM::LLVMStructType>(destType); + if (!structDestTy) { + SmallVector<Value> operands; + for (auto operand : adaptor.getOperands()) { + operands.push_back(operand); + } + Value element = createScalarOp( + rewriter, loc, subgroupMmaElementwiseOp.getOpType(), operands); + rewriter.replaceOp(subgroupMmaElementwiseOp, element); + return success(); + } + Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, structDestTy); + for (size_t i = 0, e = structDestTy.getBody().size(); i < e; ++i) { SmallVector<Value> extractedOperands; for (size_t opIdx = 0; opIdx < numOperands; opIdx++) { extractedOperands.push_back(LLVM::ExtractValueOp::create( @@ -377,13 +404,18 @@ struct WmmaElementwiseOpToNVVMLowering } // namespace /// Return the LLVMStructureType corresponding to the MMAMatrixType `type`. -LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) { +Type mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) { NVVM::MMAFrag frag = convertOperand(type.getOperand()); NVVM::MMATypes eltType = getElementType(type); auto nRow = type.getShape()[0]; auto nCol = type.getShape()[1]; std::pair<Type, unsigned> typeInfo = NVVM::inferMMAType(eltType, frag, nRow, nCol, type.getContext()); + // Special handling for f64 a and b fragments + Type f64Ty = Float64Type::get(type.getContext()); + if (typeInfo.first == f64Ty && typeInfo.second == 1) { + return f64Ty; + } return LLVM::LLVMStructType::getLiteral( type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first)); } diff --git a/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp index bc2f2f2..d4b4c46 100644 --- a/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp +++ b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp @@ -107,16 +107,16 @@ struct ConvertIndexCeilDivSPattern final : OpConversionPattern<CeilDivSOp> { ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value n = adaptor.getLhs(); - Type n_type = n.getType(); + Type nType = n.getType(); Value m = adaptor.getRhs(); // Define the constants - Value zero = spirv::ConstantOp::create(rewriter, loc, n_type, - IntegerAttr::get(n_type, 0)); - Value posOne = spirv::ConstantOp::create(rewriter, loc, n_type, - IntegerAttr::get(n_type, 1)); - Value negOne = spirv::ConstantOp::create(rewriter, loc, n_type, - IntegerAttr::get(n_type, -1)); + Value zero = spirv::ConstantOp::create(rewriter, loc, nType, + IntegerAttr::get(nType, 0)); + Value posOne = spirv::ConstantOp::create(rewriter, loc, nType, + IntegerAttr::get(nType, 1)); + Value negOne = spirv::ConstantOp::create(rewriter, loc, nType, + IntegerAttr::get(nType, -1)); // Compute `x`. Value mPos = spirv::SGreaterThanOp::create(rewriter, loc, m, zero); @@ -157,14 +157,14 @@ struct ConvertIndexCeilDivUPattern final : OpConversionPattern<CeilDivUOp> { ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value n = adaptor.getLhs(); - Type n_type = n.getType(); + Type nType = n.getType(); Value m = adaptor.getRhs(); // Define the constants - Value zero = spirv::ConstantOp::create(rewriter, loc, n_type, - IntegerAttr::get(n_type, 0)); - Value one = spirv::ConstantOp::create(rewriter, loc, n_type, - IntegerAttr::get(n_type, 1)); + Value zero = spirv::ConstantOp::create(rewriter, loc, nType, + IntegerAttr::get(nType, 0)); + Value one = spirv::ConstantOp::create(rewriter, loc, nType, + IntegerAttr::get(nType, 1)); // Compute the non-zero result. Value minusOne = spirv::ISubOp::create(rewriter, loc, n, one); @@ -193,16 +193,16 @@ struct ConvertIndexFloorDivSPattern final : OpConversionPattern<FloorDivSOp> { ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value n = adaptor.getLhs(); - Type n_type = n.getType(); + Type nType = n.getType(); Value m = adaptor.getRhs(); // Define the constants - Value zero = spirv::ConstantOp::create(rewriter, loc, n_type, - IntegerAttr::get(n_type, 0)); - Value posOne = spirv::ConstantOp::create(rewriter, loc, n_type, - IntegerAttr::get(n_type, 1)); - Value negOne = spirv::ConstantOp::create(rewriter, loc, n_type, - IntegerAttr::get(n_type, -1)); + Value zero = spirv::ConstantOp::create(rewriter, loc, nType, + IntegerAttr::get(nType, 0)); + Value posOne = spirv::ConstantOp::create(rewriter, loc, nType, + IntegerAttr::get(nType, 1)); + Value negOne = spirv::ConstantOp::create(rewriter, loc, nType, + IntegerAttr::get(nType, -1)); // Compute `x`. Value mNeg = spirv::SLessThanOp::create(rewriter, loc, m, zero); diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index 48a0319..f28a6cc 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -296,19 +296,13 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( // Detail methods //===----------------------------------------------------------------------===// -void LLVM::detail::setNativeProperties(Operation *op, - IntegerOverflowFlags overflowFlags) { - if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op)) - iface.setOverflowFlags(overflowFlags); -} - /// Replaces the given operation "op" with a new operation of type "targetOp" /// and given operands. LogicalResult LLVM::detail::oneToOneRewrite( Operation *op, StringRef targetOp, ValueRange operands, - ArrayRef<NamedAttribute> targetAttrs, - const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, - IntegerOverflowFlags overflowFlags) { + ArrayRef<NamedAttribute> targetAttrs, Attribute propertiesAttr, + const LLVMTypeConverter &typeConverter, + ConversionPatternRewriter &rewriter) { unsigned numResults = op->getNumResults(); SmallVector<Type> resultTypes; @@ -320,11 +314,10 @@ LogicalResult LLVM::detail::oneToOneRewrite( } // Create the operation through state since we don't know its C++ type. - Operation *newOp = - rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands, - resultTypes, targetAttrs); - - setNativeProperties(newOp, overflowFlags); + OperationState state(op->getLoc(), rewriter.getStringAttr(targetOp), operands, + resultTypes, targetAttrs); + state.propertiesAttr = propertiesAttr; + Operation *newOp = rewriter.create(state); // If the operation produced 0 or 1 result, return them immediately. if (numResults == 0) diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp index e7dd0b5..e5969c2 100644 --- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp @@ -105,9 +105,9 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors( LogicalResult LLVM::detail::vectorOneToOneRewrite( Operation *op, StringRef targetOp, ValueRange operands, - ArrayRef<NamedAttribute> targetAttrs, - const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, - IntegerOverflowFlags overflowFlags) { + ArrayRef<NamedAttribute> targetAttrs, Attribute propertiesAttr, + const LLVMTypeConverter &typeConverter, + ConversionPatternRewriter &rewriter) { assert(!operands.empty()); // Cannot convert ops if their operands are not of LLVM type. @@ -116,18 +116,38 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite( auto llvmNDVectorTy = operands[0].getType(); if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) - return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter, - rewriter, overflowFlags); - - auto callback = [op, targetOp, targetAttrs, overflowFlags, + return oneToOneRewrite(op, targetOp, operands, targetAttrs, propertiesAttr, + typeConverter, rewriter); + auto callback = [op, targetOp, targetAttrs, propertiesAttr, &rewriter](Type llvm1DVectorTy, ValueRange operands) { - Operation *newOp = - rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), - operands, llvm1DVectorTy, targetAttrs); - LLVM::detail::setNativeProperties(newOp, overflowFlags); + OperationState state(op->getLoc(), rewriter.getStringAttr(targetOp), + operands, llvm1DVectorTy, targetAttrs); + state.propertiesAttr = propertiesAttr; + Operation *newOp = rewriter.create(state); return newOp->getResult(0); }; return handleMultidimensionalVectors(op, operands, typeConverter, callback, rewriter); } + +/// Return the given type if it's a floating point type. If the given type is +/// a vector type, return its element type if it's a floating point type. +static FloatType getFloatingPointType(Type type) { + if (auto floatType = dyn_cast<FloatType>(type)) + return floatType; + if (auto vecType = dyn_cast<VectorType>(type)) + return dyn_cast<FloatType>(vecType.getElementType()); + return nullptr; +} + +bool LLVM::detail::isUnsupportedFloatingPointType( + const TypeConverter &typeConverter, Type type) { + FloatType floatType = getFloatingPointType(type); + if (!floatType) + return false; + Type convertedType = typeConverter.convertType(floatType); + if (!convertedType) + return true; + return !isa<FloatType>(convertedType); +} diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp index 16ef11a..59a16df 100644 --- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp +++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp @@ -93,13 +93,13 @@ public: /// Different MPI implementations have different communicator types. /// Using i64 as a portable, intermediate type. /// Appropriate cast needs to take place before calling MPI functions. - virtual Value getCommWorld(const Location loc, + virtual Value getCommWorld(Location loc, ConversionPatternRewriter &rewriter) = 0; /// Type converter provides i64 type for communicator type. /// Converts to native type, which might be ptr or int or whatever. - virtual Value castComm(const Location loc, - ConversionPatternRewriter &rewriter, Value comm) = 0; + virtual Value castComm(Location loc, ConversionPatternRewriter &rewriter, + Value comm) = 0; /// Get the MPI_STATUS_IGNORE value (typically a pointer type). virtual intptr_t getStatusIgnore() = 0; @@ -109,13 +109,12 @@ public: /// Gets or creates an MPI datatype as a value which corresponds to the given /// type. - virtual Value getDataType(const Location loc, - ConversionPatternRewriter &rewriter, Type type) = 0; + virtual Value getDataType(Location loc, ConversionPatternRewriter &rewriter, + Type type) = 0; /// Gets or creates an MPI_Op value which corresponds to the given /// enum value. - virtual Value getMPIOp(const Location loc, - ConversionPatternRewriter &rewriter, + virtual Value getMPIOp(Location loc, ConversionPatternRewriter &rewriter, mpi::MPI_ReductionOpEnum opAttr) = 0; }; diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index 11f866c..0a382d8 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -122,7 +122,7 @@ static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType, return totalSizeBytes.getResult(); } -static emitc::ApplyOp +static emitc::AddressOfOp createPointerFromEmitcArray(Location loc, OpBuilder &builder, TypedValue<emitc::ArrayType> arrayValue) { @@ -133,9 +133,9 @@ createPointerFromEmitcArray(Location loc, OpBuilder &builder, llvm::SmallVector<mlir::Value> indices(arrayType.getRank(), zeroIndex); emitc::SubscriptOp subPtr = emitc::SubscriptOp::create(builder, loc, arrayValue, ValueRange(indices)); - emitc::ApplyOp ptr = emitc::ApplyOp::create( + emitc::AddressOfOp ptr = emitc::AddressOfOp::create( builder, loc, emitc::PointerType::get(arrayType.getElementType()), - builder.getStringAttr("&"), subPtr); + subPtr); return ptr; } @@ -225,12 +225,12 @@ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> { auto srcArrayValue = cast<TypedValue<emitc::ArrayType>>(operands.getSource()); - emitc::ApplyOp srcPtr = + emitc::AddressOfOp srcPtr = createPointerFromEmitcArray(loc, rewriter, srcArrayValue); auto targetArrayValue = cast<TypedValue<emitc::ArrayType>>(operands.getTarget()); - emitc::ApplyOp targetPtr = + emitc::AddressOfOp targetPtr = createPointerFromEmitcArray(loc, rewriter, targetArrayValue); emitc::CallOpaqueOp memCpyCall = emitc::CallOpaqueOp::create( @@ -319,8 +319,8 @@ struct ConvertGetGlobal final emitc::GetGlobalOp globalLValue = emitc::GetGlobalOp::create( rewriter, op.getLoc(), lvalueType, operands.getNameAttr()); emitc::PointerType pointerType = emitc::PointerType::get(resultTy); - rewriter.replaceOpWithNewOp<emitc::ApplyOp>( - op, pointerType, rewriter.getStringAttr("&"), globalLValue); + rewriter.replaceOpWithNewOp<emitc::AddressOfOp>(op, pointerType, + globalLValue); return success(); } rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy, diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 9348d3c1..64a7f56 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -922,15 +922,12 @@ struct NVGPUMBarrierArriveExpectTxLowering getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), adaptor.getMbarId(), rewriter); Value txcount = truncToI32(b, adaptor.getTxcount()); - - if (isMbarrierShared(op.getBarriers().getType())) { - rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>( - op, barrier, txcount, adaptor.getPredicate()); - return success(); - } - rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>( - op, barrier, txcount, adaptor.getPredicate()); + op, Type{}, // return-value is optional and is void by default + barrier, txcount, // barrier and txcount + NVVM::MemScopeKind::CTA, // default scope is CTA + false, // relaxed-semantics is false + adaptor.getPredicate()); return success(); } }; @@ -949,13 +946,6 @@ struct NVGPUMBarrierTryWaitParityLowering Value ticks = truncToI32(b, adaptor.getTicks()); Value phase = LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity()); - - if (isMbarrierShared(op.getBarriers().getType())) { - rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>( - op, barrier, phase, ticks); - return success(); - } - rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier, phase, ticks); return success(); diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp index 021e31a..7fdc23a 100644 --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -66,6 +66,9 @@ struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> { for (NamedAttribute attr : op->getAttrs()) { if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) { Type convertedType = converter->convertType(typeAttr.getValue()); + if (!convertedType) + return rewriter.notifyMatchFailure( + op, "failed to convert type in attribute"); convertedAttrs.emplace_back(attr.getName(), TypeAttr::get(convertedType)); } else { diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp index 37cfc9f..03842cc 100644 --- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp +++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp @@ -36,6 +36,7 @@ namespace { struct SCFToControlFlowPass : public impl::SCFToControlFlowPassBase<SCFToControlFlowPass> { + using Base::Base; void runOnOperation() override; }; @@ -736,7 +737,9 @@ void SCFToControlFlowPass::runOnOperation() { target.addIllegalOp<scf::ForallOp, scf::ForOp, scf::IfOp, scf::IndexSwitchOp, scf::ParallelOp, scf::WhileOp, scf::ExecuteRegionOp>(); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); - if (failed( - applyPartialConversion(getOperation(), target, std::move(patterns)))) + ConversionConfig config; + config.allowPatternRollback = allowPatternRollback; + if (failed(applyPartialConversion(getOperation(), target, std::move(patterns), + config))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp index 76a822b..309121f 100644 --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp @@ -453,10 +453,24 @@ static LogicalResult processParallelLoop( 1, 2, rewriter.getAffineDimExpr(0) * rewriter.getAffineSymbolExpr(0) + rewriter.getAffineSymbolExpr(1)); + // Map through cloningMap first so we use values valid at the launch + // scope, then ensure they are launch-independent (or cloned constants). + Value mappedStep = cloningMap.lookupOrDefault(step); + Value mappedLowerBound = cloningMap.lookupOrDefault(lowerBound); + + mappedStep = ensureLaunchIndependent(mappedStep); + mappedLowerBound = ensureLaunchIndependent(mappedLowerBound); + + // If either cannot be made available above the launch, fail gracefully. + if (!mappedStep || !mappedLowerBound) { + return rewriter.notifyMatchFailure( + parallelOp, "lower bound / step must be constant or defined above " + "the gpu.launch"); + } + newIndex = AffineApplyOp::create( rewriter, loc, annotation.getMap().compose(lowerAndStep), - ValueRange{operand, ensureLaunchIndependent(step), - ensureLaunchIndependent(lowerBound)}); + ValueRange{operand, mappedStep, mappedLowerBound}); // If there was also a bound, insert that, too. // TODO: Check that we do not assign bounds twice. if (annotation.getBound()) { diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index 460595b..6423d49 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -188,7 +188,8 @@ createDecl(PatternRewriter &builder, SymbolTable &symbolTable, OpBuilder::InsertionGuard guard(builder); Type type = reduce.getOperands()[reductionIndex].getType(); auto decl = omp::DeclareReductionOp::create(builder, reduce.getLoc(), - "__scf_reduction", type); + "__scf_reduction", type, + /*byref_element_type=*/{}); symbolTable.insert(decl); builder.createBlock(&decl.getInitializerRegion(), diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index 50fca56..02b61bd 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -1520,20 +1520,12 @@ public: if (!dstType) return rewriter.notifyMatchFailure(tanOp, "type conversion failed"); - Location loc = tanOp.getLoc(); - Value sin = LLVM::SinOp::create(rewriter, loc, dstType, tanOp.getOperand()); - Value cos = LLVM::CosOp::create(rewriter, loc, dstType, tanOp.getOperand()); - rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos); + rewriter.replaceOpWithNewOp<LLVM::TanOp>(tanOp, dstType, + adaptor.getOperands()); return success(); } }; -/// Convert `spirv.Tanh` to -/// -/// exp(2x) - 1 -/// ----------- -/// exp(2x) + 1 -/// class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> { public: using SPIRVToLLVMConversion<spirv::GLTanhOp>::SPIRVToLLVMConversion; @@ -1546,18 +1538,8 @@ public: if (!dstType) return rewriter.notifyMatchFailure(tanhOp, "type conversion failed"); - Location loc = tanhOp.getLoc(); - Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0); - Value multiplied = - LLVM::FMulOp::create(rewriter, loc, dstType, two, tanhOp.getOperand()); - Value exponential = LLVM::ExpOp::create(rewriter, loc, dstType, multiplied); - Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); - Value numerator = - LLVM::FSubOp::create(rewriter, loc, dstType, exponential, one); - Value denominator = - LLVM::FAddOp::create(rewriter, loc, dstType, exponential, one); - rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator, - denominator); + rewriter.replaceOpWithNewOp<LLVM::TanhOp>(tanhOp, dstType, + adaptor.getOperands()); return success(); } }; diff --git a/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp b/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp index 9921a06..feb0489 100644 --- a/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp +++ b/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp @@ -23,8 +23,11 @@ namespace mlir { using namespace mlir; -namespace { +//===----------------------------------------------------------------------===// +// PoisonOpLowering +//===----------------------------------------------------------------------===// +namespace { struct PoisonOpLowering : public ConvertOpToLLVMPattern<ub::PoisonOp> { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -32,13 +35,8 @@ struct PoisonOpLowering : public ConvertOpToLLVMPattern<ub::PoisonOp> { matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; - } // namespace -//===----------------------------------------------------------------------===// -// PoisonOpLowering -//===----------------------------------------------------------------------===// - LogicalResult PoisonOpLowering::matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -61,6 +59,29 @@ PoisonOpLowering::matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor, } //===----------------------------------------------------------------------===// +// UnreachableOpLowering +//===----------------------------------------------------------------------===// + +namespace { +struct UnreachableOpLowering + : public ConvertOpToLLVMPattern<ub::UnreachableOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(ub::UnreachableOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; +} // namespace +LogicalResult + +UnreachableOpLowering::matchAndRewrite( + ub::UnreachableOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + rewriter.replaceOpWithNewOp<LLVM::UnreachableOp>(op); + return success(); +} + +//===----------------------------------------------------------------------===// // Pass Definition //===----------------------------------------------------------------------===// @@ -93,7 +114,7 @@ struct UBToLLVMConversionPass void mlir::ub::populateUBToLLVMConversionPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns) { - patterns.add<PoisonOpLowering>(converter); + patterns.add<PoisonOpLowering, UnreachableOpLowering>(converter); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp index 244d214..3831387 100644 --- a/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp +++ b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp @@ -40,6 +40,17 @@ struct PoisonOpLowering final : OpConversionPattern<ub::PoisonOp> { } }; +struct UnreachableOpLowering final : OpConversionPattern<ub::UnreachableOp> { + using Base::Base; + + LogicalResult + matchAndRewrite(ub::UnreachableOp op, OpAdaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<spirv::UnreachableOp>(op); + return success(); + } +}; + } // namespace //===----------------------------------------------------------------------===// @@ -75,5 +86,6 @@ struct UBToSPIRVConversionPass final void mlir::ub::populateUBToSPIRVConversionPatterns( const SPIRVTypeConverter &converter, RewritePatternSet &patterns) { - patterns.add<PoisonOpLowering>(converter, patterns.getContext()); + patterns.add<PoisonOpLowering, UnreachableOpLowering>(converter, + patterns.getContext()); } diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 69a317ec..05d541f 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -345,7 +345,8 @@ public: matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = scatter->getLoc(); - MemRefType memRefType = scatter.getMemRefType(); + auto memRefType = dyn_cast<MemRefType>(scatter.getBaseType()); + assert(memRefType && "The base should be bufferized"); if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter()))) return rewriter.notifyMatchFailure(scatter, "memref type not supported"); @@ -1654,6 +1655,20 @@ private: return failure(); } } + } else if (auto floatTy = dyn_cast<FloatType>(printType)) { + // Print other floating-point types using the APFloat runtime library. + int32_t sem = + llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics()); + Value semValue = LLVM::ConstantOp::create( + rewriter, loc, rewriter.getI32Type(), + rewriter.getIntegerAttr(rewriter.getI32Type(), sem)); + Value floatBits = + LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(), value); + printer = + LLVM::lookupOrCreateApFloatPrintFn(rewriter, parent, symbolTables); + emitCall(rewriter, loc, printer.value(), + ValueRange({semValue, floatBits})); + return success(); } else { return failure(); } diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp index 1b4d1a4..079e1e2 100644 --- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp +++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp @@ -519,8 +519,13 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> { return lowerToScatteredLoadOp(readOp, rewriter); } - // Perform common data transfer checks. VectorType vecTy = readOp.getVectorType(); + + // Lower using load.gather in 1D case + if (vecTy.getRank() == 1 && !readOp.hasOutOfBoundsDim()) + return lowerToScatteredLoadOp(readOp, rewriter); + + // Perform common data transfer checks. if (failed(storeLoadPreconditions(rewriter, readOp, vecTy))) return failure(); @@ -562,7 +567,8 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> { auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices, /*packed=*/nullptr, transposeAttr, /*l1_hint=*/hint, - /*l2_hint=*/hint, /*l3_hint=*/hint); + /*l2_hint=*/hint, /*l3_hint=*/hint, + /*layout=*/nullptr); rewriter.replaceOp(readOp, loadOp); return success(); @@ -616,7 +622,8 @@ struct TransferWriteLowering auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc, indices, /*l1_hint=*/hint, - /*l2_hint=*/hint, /*l3_hint=*/hint); + /*l2_hint=*/hint, /*l3_hint=*/hint, + /*layout=*/nullptr); rewriter.replaceOp(writeOp, storeOp); return success(); @@ -720,7 +727,8 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> { xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices, /*packed=*/nullptr, /*transpose=*/nullptr, /*l1_hint=*/hint, - /*l2_hint=*/hint, /*l3_hint=*/hint); + /*l2_hint=*/hint, /*l3_hint=*/hint, + /*layout=*/nullptr); rewriter.replaceOp(loadOp, loadNdOp); return success(); @@ -758,7 +766,8 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> { auto storeNdOp = xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices, /*l1_hint=*/hint, - /*l2_hint=*/hint, /*l3_hint=*/hint); + /*l2_hint=*/hint, /*l3_hint=*/hint, + /*layout=*/nullptr); rewriter.replaceOp(storeOp, storeNdOp); diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index de552ce..0ecb50e 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -50,11 +50,10 @@ static constexpr int32_t executionSize{16}; // Offsets to individual fields of the 8xi32 layout nd tensor descriptor. enum class NdTdescOffset : uint32_t { - BasePtr = 0, // Base pointer (i64) - BaseShapeW = 2, // Base shape width (i32) - BaseShapeH = 3, // Base shape height (i32) - TensorOffsetW = 4, // Tensor offset W (i32) - TensorOffsetH = 5 // Tensor offset H (i32) + BasePtr = 0, // Base pointer (i64) + BaseShapeW = 2, // Base shape width (i32) + BaseShapeH = 3, // Base shape height (i32) + BasePitch = 4, // Base pitch (i32) }; static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) { @@ -151,6 +150,14 @@ translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint, } } +// +// Note: +// Block operations for tile of sub byte element types are handled by +// emulating with larger element types. +// Tensor descriptor are keep intact and only ops consuming them are +// emulated +// + class CreateNdDescToXeVMPattern : public OpConversionPattern<xegpu::CreateNdDescOp> { using OpConversionPattern::OpConversionPattern; @@ -179,16 +186,12 @@ class CreateNdDescToXeVMPattern Value baseAddr; Value baseShapeW; Value baseShapeH; - Value offsetW; - Value offsetH; // Source can be a memref or a pointer (ui64, ui32, i64 or i32). SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes(); + SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides(); // Descriptor shape is expected to be 2D. int64_t rank = mixedSizes.size(); - if (rank != 2) - return rewriter.notifyMatchFailure(op, "Expected 2D shape."); - auto sourceTy = source.getType(); auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy); // If source is a memref, we need to extract the aligned pointer as index. @@ -197,10 +200,20 @@ class CreateNdDescToXeVMPattern if (!sourceMemrefTy.hasRank()) { return rewriter.notifyMatchFailure(op, "Expected ranked Memref."); } - baseAddr = - memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source); + // Access adaptor after failure check to avoid rolling back generated code + // for materialization cast. + baseAddr = adaptor.getSource(); } else { baseAddr = adaptor.getSource(); + if (baseAddr.getType() != i64Ty) { + // Pointer type may be i32. Cast to i64 if needed. + baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr); + } + } + // 1D tensor descriptor is just the base address. + if (rank == 1) { + rewriter.replaceOp(op, baseAddr); + return success(); } // Utility for creating offset values from op fold result. auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec, @@ -209,19 +222,11 @@ class CreateNdDescToXeVMPattern val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val); return val; }; - // Offsets are not supported (0 is used). - offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0); - offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0); // Get shape values from op fold results. baseShapeW = createOffset(mixedSizes, 1); baseShapeH = createOffset(mixedSizes, 0); - if (sourceMemrefTy) { - // Cast index to i64. - baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr); - } else if (baseAddr.getType() != i64Ty) { - // Pointer type may be i32. Cast to i64 if needed. - baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr); - } + // Get pitch value from op fold results. + Value basePitch = createOffset(mixedStrides, 0); // Populate payload. Value payLoadAsI64 = vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload); @@ -235,12 +240,9 @@ class CreateNdDescToXeVMPattern payload = vector::InsertOp::create(rewriter, loc, baseShapeH, payload, static_cast<int>(NdTdescOffset::BaseShapeH)); - payload = vector::InsertOp::create( - rewriter, loc, offsetW, payload, - static_cast<int>(NdTdescOffset::TensorOffsetW)); - payload = vector::InsertOp::create( - rewriter, loc, offsetH, payload, - static_cast<int>(NdTdescOffset::TensorOffsetH)); + payload = + vector::InsertOp::create(rewriter, loc, basePitch, payload, + static_cast<int>(NdTdescOffset::BasePitch)); rewriter.replaceOp(op, payload); return success(); } @@ -257,108 +259,240 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> { ConversionPatternRewriter &rewriter) const override { auto mixedOffsets = op.getMixedOffsets(); int64_t opOffsetsSize = mixedOffsets.size(); - if (opOffsetsSize != 2) - return rewriter.notifyMatchFailure(op, "Expected 2D offsets."); auto loc = op.getLoc(); auto ctxt = rewriter.getContext(); auto tdesc = adaptor.getTensorDesc(); auto tdescTy = op.getTensorDescType(); - if (tdescTy.getRank() != 2) - return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor."); + auto tileRank = tdescTy.getRank(); + if (opOffsetsSize != tileRank) + return rewriter.notifyMatchFailure( + op, "Expected offset rank to match descriptor rank."); auto elemType = tdescTy.getElementType(); auto elemBitSize = elemType.getIntOrFloatBitWidth(); - if (elemBitSize % 8 != 0) + bool isSubByte = elemBitSize < 8; + uint64_t wScaleFactor = 1; + + if (!isSubByte && (elemBitSize % 8 != 0)) return rewriter.notifyMatchFailure( op, "Expected element type bit width to be multiple of 8."); + auto tileW = tdescTy.getDimSize(tileRank - 1); + // For sub byte types, only 4bits are currently supported. + if (isSubByte) { + if (elemBitSize != 4) + return rewriter.notifyMatchFailure( + op, "Only sub byte types of 4bits are supported."); + if (tileRank != 2) + return rewriter.notifyMatchFailure( + op, "Sub byte types are only supported for 2D tensor descriptors."); + auto subByteFactor = 8 / elemBitSize; + auto tileH = tdescTy.getDimSize(0); + // Handle special case for packed load. + if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) { + if (op.getPacked().value_or(false)) { + // packed load is implemented as packed loads of 8bit elements. + if (tileH == systolicDepth * 4 && + tileW == executionSize * subByteFactor) { + // Usage case for loading as Matrix B with pack request. + // source is assumed to pre-packed into 8bit elements + // Emulate with 8bit loads with pack request. + // scaled_tileW = executionSize + elemType = rewriter.getIntegerType(8); + tileW = executionSize; + wScaleFactor = subByteFactor; + } + } + } + // If not handled by packed load case above, handle other cases. + if (wScaleFactor == 1) { + auto sub16BitFactor = subByteFactor * 2; + if (tileW == executionSize * sub16BitFactor) { + // Usage case for loading as Matrix A operand + // Emulate with 16bit loads/stores. + // scaled_tileW = executionSize + elemType = rewriter.getIntegerType(16); + tileW = executionSize; + wScaleFactor = sub16BitFactor; + } else { + return rewriter.notifyMatchFailure( + op, "Unsupported tile shape for sub byte types."); + } + } + // recompute element bit size for emulation. + elemBitSize = elemType.getIntOrFloatBitWidth(); + } - VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type()); - Value payLoadAsI64 = - vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc); - Value basePtr = vector::ExtractOp::create( - rewriter, loc, payLoadAsI64, static_cast<int>(NdTdescOffset::BasePtr)); - Value baseShapeW = vector::ExtractOp::create( - rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW)); - Value baseShapeH = vector::ExtractOp::create( - rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH)); - // Offsets are provided by the op. - // convert them to i32. - Value offsetW = - getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]); - offsetW = getValueOrCreateCastToIndexLike(rewriter, loc, - rewriter.getI32Type(), offsetW); - Value offsetH = - getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]); - offsetH = getValueOrCreateCastToIndexLike(rewriter, loc, - rewriter.getI32Type(), offsetH); // Get address space from tensor descriptor memory space. auto ptrTypeLLVM = LLVM::LLVMPointerType::get( ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace())); - // Convert base pointer (i64) to LLVM pointer type. - Value basePtrLLVM = - LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr); - // Compute element byte size and surface width in bytes. - Value elemByteSize = arith::ConstantIntOp::create( - rewriter, loc, rewriter.getI32Type(), elemBitSize / 8); - Value surfaceW = - arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize); - - // Get tile sizes and vblocks from the tensor descriptor type. - auto tileW = tdescTy.getDimSize(1); - auto tileH = tdescTy.getDimSize(0); - int32_t vblocks = tdescTy.getArrayLength(); - if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) { - Value src = adaptor.getValue(); - // If store value is a scalar, get value from op instead of adaptor. - // Adaptor might have optimized away single element vector - if (src.getType().isIntOrFloat()) { - src = op.getValue(); + if (tileRank == 2) { + // Compute element byte size. + Value elemByteSize = arith::ConstantIntOp::create( + rewriter, loc, rewriter.getI32Type(), elemBitSize / 8); + VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type()); + Value payLoadAsI64 = + vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc); + Value basePtr = + vector::ExtractOp::create(rewriter, loc, payLoadAsI64, + static_cast<int>(NdTdescOffset::BasePtr)); + Value baseShapeW = vector::ExtractOp::create( + rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW)); + Value baseShapeH = vector::ExtractOp::create( + rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH)); + Value basePitch = vector::ExtractOp::create( + rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BasePitch)); + // Offsets are provided by the op. + // convert them to i32. + Value offsetW = + getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]); + offsetW = getValueOrCreateCastToIndexLike(rewriter, loc, + rewriter.getI32Type(), offsetW); + Value offsetH = + getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]); + offsetH = getValueOrCreateCastToIndexLike(rewriter, loc, + rewriter.getI32Type(), offsetH); + // Convert base pointer (i64) to LLVM pointer type. + Value basePtrLLVM = + LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr); + // FIXME: width or pitch is not the same as baseShapeW it should be the + // stride of the second to last dimension in row major layout. + // Compute width in bytes. + Value baseShapeWInBytes = + arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize); + // Compute pitch in bytes. + Value basePitchBytes = + arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize); + + if (wScaleFactor > 1) { + // Scale offsetW, baseShapeWInBytes for sub byte emulation. + // Note: tileW is already scaled above. + Value wScaleFactorValLog2 = arith::ConstantIntOp::create( + rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(wScaleFactor)); + baseShapeWInBytes = arith::ShRSIOp::create( + rewriter, loc, baseShapeWInBytes, wScaleFactorValLog2); + basePitchBytes = arith::ShRSIOp::create(rewriter, loc, basePitchBytes, + wScaleFactorValLog2); + offsetW = + arith::ShRSIOp::create(rewriter, loc, offsetW, wScaleFactorValLog2); } - VectorType srcVecTy = dyn_cast<VectorType>(src.getType()); - if (!srcVecTy) - return rewriter.notifyMatchFailure( - op, "Expected store value to be a vector type."); - // Get flat vector type of integer type with matching element bit size. - VectorType newSrcVecTy = - encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize)); - if (srcVecTy != newSrcVecTy) - src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src); - auto storeCacheControl = - translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()); - xevm::BlockStore2dOp::create( - rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW, - offsetH, elemBitSize, tileW, tileH, src, - xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl)); - rewriter.eraseOp(op); - } else { - auto loadCacheControl = - translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()); - if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) { - xevm::BlockPrefetch2dOp::create( - rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW, - offsetH, elemBitSize, tileW, tileH, vblocks, - xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl)); + // Get tile height from the tensor descriptor type. + auto tileH = tdescTy.getDimSize(0); + // Get vblocks from the tensor descriptor type. + int32_t vblocks = tdescTy.getArrayLength(); + if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) { + Value src = adaptor.getValue(); + // If store value is a scalar, get value from op instead of adaptor. + // Adaptor might have optimized away single element vector + if (src.getType().isIntOrFloat()) { + src = op.getValue(); + } + VectorType srcVecTy = dyn_cast<VectorType>(src.getType()); + if (!srcVecTy) + return rewriter.notifyMatchFailure( + op, "Expected store value to be a vector type."); + // Get flat vector type of integer type with matching element bit size. + VectorType newSrcVecTy = + encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize)); + if (srcVecTy != newSrcVecTy) + src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src); + auto storeCacheControl = + translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()); + xevm::BlockStore2dOp::create( + rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH, + basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, src, + xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl)); rewriter.eraseOp(op); } else { - VectorType dstVecTy = cast<VectorType>(op.getValue().getType()); - const bool vnni = op.getPacked().value_or(false); - auto transposeValue = op.getTranspose(); - bool transpose = - transposeValue.has_value() && transposeValue.value()[0] == 1; - VectorType loadedTy = encodeVectorTypeTo( - dstVecTy, vnni ? rewriter.getI32Type() - : rewriter.getIntegerType(elemBitSize)); - - Value resultFlatVec = xevm::BlockLoad2dOp::create( - rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH, - surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks, - transpose, vnni, + auto loadCacheControl = + translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()); + if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) { + xevm::BlockPrefetch2dOp::create( + rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH, + basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, + vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl)); + rewriter.eraseOp(op); + } else { + VectorType dstVecTy = cast<VectorType>(op.getValue().getType()); + const bool vnni = op.getPacked().value_or(false); + auto transposeValue = op.getTranspose(); + bool transpose = + transposeValue.has_value() && transposeValue.value()[0] == 1; + VectorType loadedTy = encodeVectorTypeTo( + dstVecTy, vnni ? rewriter.getI32Type() + : rewriter.getIntegerType(elemBitSize)); + + Value resultFlatVec = xevm::BlockLoad2dOp::create( + rewriter, loc, loadedTy, basePtrLLVM, baseShapeWInBytes, + baseShapeH, basePitchBytes, offsetW, offsetH, elemBitSize, tileW, + tileH, vblocks, transpose, vnni, + xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl)); + resultFlatVec = vector::BitCastOp::create( + rewriter, loc, + encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()), + resultFlatVec); + rewriter.replaceOp(op, resultFlatVec); + } + } + } else { + // 1D tensor descriptor. + // `tdesc` represents base address as i64 + // Offset in number of elements, need to multiply by element byte size. + // Compute byte offset. + // byteOffset = offset * elementByteSize + Value offset = + getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]); + offset = getValueOrCreateCastToIndexLike(rewriter, loc, + rewriter.getI64Type(), offset); + // Compute element byte size. + Value elemByteSize = arith::ConstantIntOp::create( + rewriter, loc, rewriter.getI64Type(), elemBitSize / 8); + Value byteOffset = + rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize); + // Final address = basePtr + byteOffset + Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>( + loc, tdesc, + getValueOrCreateCastToIndexLike(rewriter, loc, rewriter.getI64Type(), + byteOffset)); + // Convert base pointer (i64) to LLVM pointer type. + Value finalPtrLLVM = + LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64); + if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) { + Value src = adaptor.getValue(); + // If store value is a scalar, get value from op instead of adaptor. + // Adaptor might have optimized away single element vector + if (src.getType().isIntOrFloat()) { + src = op.getValue(); + } + VectorType srcVecTy = dyn_cast<VectorType>(src.getType()); + if (!srcVecTy) + return rewriter.notifyMatchFailure( + op, "Expected store value to be a vector type."); + // Get flat vector type of integer type with matching element bit size. + VectorType newSrcVecTy = + encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize)); + if (srcVecTy != newSrcVecTy) + src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src); + auto storeCacheControl = + translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()); + rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>( + op, finalPtrLLVM, src, + xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl)); + } else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) { + auto loadCacheControl = + translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()); + VectorType resTy = cast<VectorType>(op.getValue().getType()); + VectorType loadedTy = + encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize)); + Value load = xevm::BlockLoadOp::create( + rewriter, loc, loadedTy, finalPtrLLVM, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl)); - resultFlatVec = vector::BitCastOp::create( - rewriter, loc, - encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()), - resultFlatVec); - rewriter.replaceOp(op, resultFlatVec); + if (loadedTy != resTy) + load = vector::BitCastOp::create(rewriter, loc, resTy, load); + rewriter.replaceOp(op, load); + } else { + return rewriter.notifyMatchFailure( + op, "Unsupported operation: xegpu.prefetch_nd with tensor " + "descriptor rank == 1"); } } return success(); @@ -511,9 +645,6 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> { } }; -// Lower xegpu::CreateMemDescOp to memref::ViewOp. Since SLM access instructions -// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than -// 32 bits will be converted to 32 bits. class CreateMemDescOpPattern final : public OpConversionPattern<xegpu::CreateMemDescOp> { public: @@ -522,16 +653,7 @@ public: matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto resTy = op.getMemDesc(); - - // Create the result MemRefType with the same shape, element type, and - // memory space - auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy); - - Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0); - auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy, - op.getSource(), zero, ValueRange()); - rewriter.replaceOp(op, viewOp); + rewriter.replaceOp(op, adaptor.getSource()); return success(); } }; @@ -551,19 +673,27 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> { auto loc = op.getLoc(); auto ctxt = rewriter.getContext(); - Value basePtrStruct = adaptor.getMemDesc(); + Value baseAddr32 = adaptor.getMemDesc(); Value mdescVal = op.getMemDesc(); // Load result or Store value Type can be vector or scalar. - Value data; - if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) - data = op.getResult(); - else - data = adaptor.getData(); - VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType()); + Type dataTy; + if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) { + Type resType = op.getResult().getType(); + // Some transforms may leave unit dimension in the 2D vector, adaptors do + // not catch it for results. + if (auto vecType = dyn_cast<VectorType>(resType)) { + assert(llvm::count_if(vecType.getShape(), + [](int64_t d) { return d != 1; }) <= 1 && + "Expected either 1D vector or nD with unit dimensions"); + resType = VectorType::get({vecType.getNumElements()}, + vecType.getElementType()); + } + dataTy = resType; + } else + dataTy = adaptor.getData().getType(); + VectorType valOrResVecTy = dyn_cast<VectorType>(dataTy); if (!valOrResVecTy) - valOrResVecTy = VectorType::get(1, data.getType()); - if (valOrResVecTy.getShape().size() != 1) - return rewriter.notifyMatchFailure(op, "Expected 1D data vector."); + valOrResVecTy = VectorType::get(1, dataTy); int64_t elemBitWidth = valOrResVecTy.getElementType().getIntOrFloatBitWidth(); @@ -579,21 +709,14 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> { auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType()); - Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create( - rewriter, loc, basePtrStruct); - - // Convert base pointer (ptr) to i32 - Value basePtrI32 = arith::IndexCastUIOp::create( - rewriter, loc, rewriter.getI32Type(), basePtrLLVM); - Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets); linearOffset = arith::IndexCastUIOp::create( rewriter, loc, rewriter.getI32Type(), linearOffset); - basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset, - elemByteSize); + Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32, + linearOffset, elemByteSize); // convert base pointer (i32) to LLVM pointer type - basePtrLLVM = + Value basePtrLLVM = LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32); if (op.getSubgroupBlockIoAttr()) { @@ -929,20 +1052,22 @@ struct ConvertXeGPUToXeVMPass return VectorType::get(sum, elemType); }); typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type { + // Scattered descriptors are not supported in XeVM lowering. if (type.isScattered()) + return {}; + if (type.getRank() == 1) return IntegerType::get(&getContext(), 64); auto i32Type = IntegerType::get(&getContext(), 32); return VectorType::get(8, i32Type); }); - // Convert MemDescType into flattened MemRefType for SLM + // Convert MemDescType into i32 for SLM typeConverter.addConversion([&](xegpu::MemDescType type) -> Type { - Type elemTy = type.getElementType(); - int numElems = type.getNumElements(); - return MemRefType::get(numElems, elemTy, AffineMap(), 3); + return IntegerType::get(&getContext(), 32); }); typeConverter.addConversion([&](MemRefType type) -> Type { - // Convert MemRefType to i64 type. + if (type.getMemorySpaceAsInt() == 3) + return IntegerType::get(&getContext(), 32); return IntegerType::get(&getContext(), 64); }); @@ -1059,6 +1184,7 @@ struct ConvertXeGPUToXeVMPass }; typeConverter.addSourceMaterialization( singleElementVectorMaterializationCast); + typeConverter.addSourceMaterialization(vectorMaterializationCast); typeConverter.addTargetMaterialization(memrefMaterializationCast); typeConverter.addTargetMaterialization(ui32MaterializationCast); typeConverter.addTargetMaterialization(ui64MaterializationCast); diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp index f276984..20a420d 100644 --- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp +++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp @@ -290,7 +290,7 @@ static LLVM::CallOp createDeviceFunctionCall( ArrayRef<Type> argTypes, ArrayRef<Value> args, mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs, LLVMFuncAttributeOptions funcAttributeOptions, Operation *op) { - auto moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>(); + auto *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>(); assert(moduleOp && "Expecting module"); Location loc = op->getLoc(); @@ -401,7 +401,10 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> { auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>( /*other=*/LLVM::ModRefInfo::NoModRef, /*argMem=*/LLVM::ModRefInfo::NoModRef, - /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef); + /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef, + /*errnoMem=*/LLVM::ModRefInfo::NoModRef, + /*targetMem0=*/LLVM::ModRefInfo::NoModRef, + /*targetMem1=*/LLVM::ModRefInfo::NoModRef); auto funcAttrs = convergentNoUnwindWillReturnAttrs; funcAttrs.memEffectsAttr = memAttr; Value result = @@ -450,7 +453,10 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> { auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>( /*other=*/LLVM::ModRefInfo::NoModRef, /*argMem=*/LLVM::ModRefInfo::Ref, - /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef); + /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef, + /*errnoMem=*/LLVM::ModRefInfo::NoModRef, + /*targetMem0=*/LLVM::ModRefInfo::NoModRef, + /*targetMem1=*/LLVM::ModRefInfo::NoModRef); funcAttr.memEffectsAttr = memAttr; LLVM::CallOp call = createDeviceFunctionCall( @@ -556,7 +562,10 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> { auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>( /*other=*/LLVM::ModRefInfo::NoModRef, /*argMem=*/LLVM::ModRefInfo::Ref, - /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef); + /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef, + /*errnoMem=*/LLVM::ModRefInfo::NoModRef, + /*targetMem0=*/LLVM::ModRefInfo::NoModRef, + /*targetMem1=*/LLVM::ModRefInfo::NoModRef); funcAttr = noUnwindAttrs; funcAttr.memEffectsAttr = memAttr; } else { @@ -798,7 +807,10 @@ class LaunchConfigOpToOCLPattern : public OpConversionPattern<OpType> { constexpr auto noModRef = LLVM::ModRefInfo::NoModRef; auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>( /*other=*/noModRef, - /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef); + /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef, + /*errnoMem=*/noModRef, + /*targetMem0=*/noModRef, + /*targetMem1=*/noModRef); call.setMemoryEffectsAttr(memAttr); rewriter.replaceOp(op, call); return success(); @@ -836,7 +848,10 @@ class SubgroupOpWorkitemOpToOCLPattern : public OpConversionPattern<OpType> { constexpr auto noModRef = LLVM::ModRefInfo::NoModRef; auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>( /*other=*/noModRef, - /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef); + /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef, + /*errnoMem=*/noModRef, + /*targetMem0=*/noModRef, + /*targetMem1=*/noModRef); call.setMemoryEffectsAttr(memAttr); rewriter.replaceOp(op, call); return success(); diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index df955fc..b7a665b 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -55,6 +55,10 @@ void AMDGPUDialect::initialize() { #define GET_OP_LIST #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc" >(); + addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc" + >(); addAttributes< #define GET_ATTRDEF_LIST #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc" @@ -339,19 +343,45 @@ void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns( } //===----------------------------------------------------------------------===// -// ScaledExtPacked816Op +// ScaledExtPackedMatrixOp //===----------------------------------------------------------------------===// -LogicalResult ScaledExtPacked816Op::verify() { +LogicalResult ScaledExtPackedMatrixOp::verify() { int blockSize = getBlockSize(); - assert((blockSize == 16 || blockSize == 32) && "invalid block size"); + assert(llvm::is_contained({16, 32}, blockSize) && "invalid block size"); + int firstScaleByte = getFirstScaleByte(); - if (blockSize == 16 && !llvm::is_contained({0, 1}, firstScaleByte)) { - return emitOpError( - "blockSize of 16 can only have firstScaleByte be 0 or 1."); - } - if (blockSize == 32 && !llvm::is_contained({0, 2}, firstScaleByte)) { - return emitOpError( - "blockSize of 32 can only have firstScaleByte be 0 or 2."); + int firstScaleLane = getFirstScaleLane(); + auto sourceType = cast<VectorType>(getSource().getType()); + Type elementType = sourceType.getElementType(); + auto floatType = cast<FloatType>(elementType); + unsigned bitWidth = floatType.getWidth(); + + assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth)); + + const bool is_fp8 = bitWidth == 8; + const bool is_block_16 = blockSize == 16; + + if (!is_fp8) { + if (is_block_16) { + if (!llvm::is_contained({0, 1}, firstScaleByte)) { + return emitOpError("blockSize of 16 can only have firstScaleByte be 0 " + "or 1 for f4 and f6."); + } + } else { + if (!llvm::is_contained({0, 2}, firstScaleByte)) { + return emitOpError("blockSize of 32 can only have firstScaleByte be 0 " + "or 2 for f4 and f6."); + } + } + } else { + if (is_block_16) { + bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) || + ((firstScaleLane == 16) && (firstScaleByte == 2)); + if (!is_valid) { + return emitOpError("blockSize of 16 can only have (firstScaleLane, " + "firstScaleByte) be (0, 0) or (16, 2) for f8."); + } + } } return success(); @@ -567,6 +597,53 @@ LogicalResult PermlaneSwapOp::verify() { } //===----------------------------------------------------------------------===// +// MemoryCounterWaitOp +//===----------------------------------------------------------------------===// + +namespace { +/// Fuse adjacent memory counter wait ops, taking the minimum value of the +/// counters. +struct FuseMemoryCounterWaitOp final : OpRewritePattern<MemoryCounterWaitOp> { + using Base::Base; + + LogicalResult matchAndRewrite(MemoryCounterWaitOp op, + PatternRewriter &rewriter) const override { + auto next = dyn_cast<MemoryCounterWaitOp>(op->getNextNode()); + if (!next) + return failure(); + + auto setters = {&MemoryCounterWaitOp::setLoad, + &MemoryCounterWaitOp::setStore, &MemoryCounterWaitOp::setDs, + &MemoryCounterWaitOp::setExp, + &MemoryCounterWaitOp::setTensor}; + auto lhsVals = {op.getLoad(), op.getStore(), op.getDs(), op.getExp(), + op.getTensor()}; + auto rhsVals = {next.getLoad(), next.getStore(), next.getDs(), + next.getExp(), next.getTensor()}; + rewriter.modifyOpInPlace(op, [&] { + for (auto [setter, lhs, rhs] : + llvm::zip_equal(setters, lhsVals, rhsVals)) { + if (lhs && rhs) { + (op.*setter)(std::min(*lhs, *rhs)); + } else if (lhs) { + (op.*setter)(*lhs); + } else if (rhs) { + (op.*setter)(*rhs); + } + } + }); + rewriter.eraseOp(next); + return success(); + } +}; +} // namespace + +void MemoryCounterWaitOp::getCanonicalizationPatterns( + RewritePatternSet &results, MLIRContext *context) { + results.add<FuseMemoryCounterWaitOp>(context); +} + +//===----------------------------------------------------------------------===// // GatherToLDSOp //===----------------------------------------------------------------------===// @@ -662,19 +739,123 @@ LogicalResult TransposeLoadOp::verify() { }; auto validNumElems = kValidLoadSizeMap.find(elementTypeSize); - if (validNumElems == kValidLoadSizeMap.end()) { + if (validNumElems == kValidLoadSizeMap.end()) return emitOpError("Unsupported element type size for transpose load: ") << elementTypeSize << " bits"; - } - if (numElements != validNumElems->second) { + + if (numElements != validNumElems->second) return emitOpError( "Transferring type size mismatch: expected num of elements: ") << validNumElems->second; + + return success(); +} + +//===----------------------------------------------------------------------===// +// MakeDmaBaseOp +//===----------------------------------------------------------------------===// + +LogicalResult MakeDmaBaseOp::verify() { + + auto ldsType = cast<MemRefType>(getLds().getType()); + auto globalType = cast<MemRefType>(getGlobal().getType()); + if (!hasWorkgroupMemorySpace(ldsType.getMemorySpace())) + return emitOpError( + "lds memref must have workgroup address space attribute."); + if (!hasGlobalMemorySpace(globalType.getMemorySpace())) + return emitOpError( + "global memref must have global address space attribute."); + + Type elementType = ldsType.getElementType(); + unsigned width = elementType.getIntOrFloatBitWidth(); + + if (!llvm::is_contained<unsigned>({8, 16, 32, 64}, width)) + return emitOpError( + "element type must be 1, 2, 4, or 8 bytes long but type was ") + << width << " bits long."; + + return success(); +} + +//===----------------------------------------------------------------------===// +// MakeDmaDescriptorOp +//===----------------------------------------------------------------------===// + +LogicalResult MakeDmaDescriptorOp::verify() { + ArrayRef<int64_t> globalStaticStrides = getGlobalStaticStrides(); + + if (globalStaticStrides.empty()) + return emitOpError("strides must not be empty."); + if (globalStaticStrides.back() != 1) + return emitOpError("strides for the innermost dimension must be 1."); + + ArrayRef<int64_t> globalStaticSizes = getGlobalStaticSizes(); + size_t rank = globalStaticSizes.size(); + if (rank > 5) + return emitOpError("tensor and tile must be at most of rank 5."); + if (rank != globalStaticStrides.size()) + return emitOpError("strides and sizes must have same rank."); + + ArrayRef<int64_t> sharedStaticSizes = getSharedStaticSizes(); + if (rank != sharedStaticSizes.size()) + return emitOpError("tensor must have same rank as tile."); + + unsigned elementTypeWidth = getElementTypeWidth(); + if (!llvm::is_contained<unsigned>({8, 16, 32, 64}, elementTypeWidth)) + return emitOpError( + "element type width must be 1, 2, 4 or 8 bytes, but was ") + << elementTypeWidth << " bits long"; + + if (Value atomicBarrierAddress = getAtomicBarrierAddress()) { + auto atomicBarrierAddressType = + cast<MemRefType>(atomicBarrierAddress.getType()); + bool barrierInLDS = + hasWorkgroupMemorySpace(atomicBarrierAddressType.getMemorySpace()); + if (!barrierInLDS) + return emitOpError("atomic barrier address must be in LDS."); } + if (getEarlyTimeout() && !getWorkgroupMask()) + return emitOpError( + "early timeout does not apply when workgroup_mask is not set."); return success(); } +OpFoldResult MakeDmaDescriptorOp::fold(FoldAdaptor adaptor) { + SmallVector<OpFoldResult> mixedGlobalSizes(getMixedGlobalSizes()); + SmallVector<OpFoldResult> mixedGlobalStrides(getMixedGlobalStrides()); + SmallVector<OpFoldResult> mixedSharedSizes(getMixedSharedSizes()); + + if (failed(foldDynamicIndexList(mixedGlobalSizes, /*onlyNonNegative=*/true, + /*onlyNonZero=*/true)) && + failed(foldDynamicIndexList(mixedGlobalStrides, /*onlyNonNegative=*/true, + /*onlyNonZero=*/true)) && + failed(foldDynamicIndexList(mixedSharedSizes, /*onlyNonNegative=*/true, + /*onlyNonZero=*/true))) + return nullptr; + + SmallVector<Value> dynamicGlobalSizes, dynamicGlobalStrides, + dynamicSharedSizes; + SmallVector<int64_t> staticGlobalSizes, staticGlobalStrides, + staticSharedSizes; + + dispatchIndexOpFoldResults(mixedGlobalSizes, dynamicGlobalSizes, + staticGlobalSizes); + setGlobalStaticSizes(staticGlobalSizes); + getGlobalDynamicSizesMutable().assign(dynamicGlobalSizes); + + dispatchIndexOpFoldResults(mixedGlobalStrides, dynamicGlobalStrides, + staticGlobalStrides); + setGlobalStaticStrides(staticGlobalStrides); + getGlobalDynamicStridesMutable().assign(dynamicGlobalStrides); + + dispatchIndexOpFoldResults(mixedSharedSizes, dynamicSharedSizes, + staticSharedSizes); + setSharedStaticSizes(staticSharedSizes); + getSharedDynamicSizesMutable().assign(dynamicSharedSizes); + return getResult(); +} + //===----------------------------------------------------------------------===// // ScaledMFMAOp //===----------------------------------------------------------------------===// @@ -813,5 +994,8 @@ void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results, #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc" +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc" + #define GET_OP_CLASSES #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc" diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp index f15c63c..89ef51f 100644 --- a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp +++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp @@ -33,19 +33,18 @@ using namespace mlir::amdgpu; /// This pattern supports lowering of: `vector.maskedload` to `vector.load` /// and `arith.select` if the memref is in buffer address space. -static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter, - vector::MaskedLoadOp maskedOp) { - auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType()); +static LogicalResult hasBufferAddressSpace(Type type) { + auto memRefType = dyn_cast<MemRefType>(type); if (!memRefType) - return rewriter.notifyMatchFailure(maskedOp, "not a memref source"); + return failure(); Attribute addrSpace = memRefType.getMemorySpace(); if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace)) - return rewriter.notifyMatchFailure(maskedOp, "no address space"); + return failure(); if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() != amdgpu::AddressSpace::FatRawBuffer) - return rewriter.notifyMatchFailure(maskedOp, "not in buffer address space"); + return failure(); return success(); } @@ -83,10 +82,11 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> { LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedOp, PatternRewriter &rewriter) const override { if (maskedOp->hasAttr(kMaskedloadNeedsMask)) - return failure(); + return rewriter.notifyMatchFailure(maskedOp, "already rewritten"); - if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) { - return failure(); + if (failed(hasBufferAddressSpace(maskedOp.getBase().getType()))) { + return rewriter.notifyMatchFailure( + maskedOp, "isn't a load from a fat buffer resource"); } // Check if this is either a full inbounds load or an empty, oob load. If @@ -176,9 +176,14 @@ struct FullMaskedLoadToConditionalLoad LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp, PatternRewriter &rewriter) const override { + if (succeeded(hasBufferAddressSpace(loadOp.getBase().getType()))) + return rewriter.notifyMatchFailure( + loadOp, "buffer loads are handled by a more specialized pattern"); + FailureOr<Value> maybeCond = matchFullMask(rewriter, loadOp.getMask()); if (failed(maybeCond)) { - return failure(); + return rewriter.notifyMatchFailure(loadOp, + "isn't loading a broadcasted scalar"); } Value cond = maybeCond.value(); @@ -203,6 +208,15 @@ struct FullMaskedStoreToConditionalStore LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp, PatternRewriter &rewriter) const override { + // A condition-free implementation of fully masked stores requires + // 1) an accessor for the num_records field on buffer resources/fat pointers + // 2) knowledge that said field will always be set accurately - that is, + // that writes to x < num_records of offset wouldn't trap, which is + // something a pattern user would need to assert or we'd need to prove. + // + // Therefore, conditional stores to buffers still go down this path at + // present. + FailureOr<Value> maybeCond = matchFullMask(rewriter, storeOp.getMask()); if (failed(maybeCond)) { return failure(); diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 0c35921..c6addfb 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -5421,7 +5421,7 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final return rewriter.notifyMatchFailure(op, "no unit basis entries to replace"); - if (newIndices.size() == 0) { + if (newIndices.empty()) { rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0); return success(); } diff --git a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp index c942c02..b04e2d6 100644 --- a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp @@ -82,7 +82,7 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) { ArrayRef<int64_t> oldShape = oldMemRefType.getShape(); SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank()); newShape[0] = 2; - std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1); + llvm::copy(oldShape, newShape.begin() + 1); return MemRefType::Builder(oldMemRefType).setShape(newShape).setLayout({}); }; diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp index 4743941..8f1249e 100644 --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -1711,6 +1711,12 @@ LogicalResult mlir::affine::coalesceLoops(MutableArrayRef<AffineForOp> loops) { outermost.getBody()->getOperations().splice( Block::iterator(secondOutermostLoop.getOperation()), innermost.getBody()->getOperations()); + for (auto [iter, init] : + llvm::zip_equal(secondOutermostLoop.getRegionIterArgs(), + secondOutermostLoop.getInits())) { + iter.replaceAllUsesWith(init); + iter.dropAllUses(); + } secondOutermostLoop.erase(); return success(); } diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td index de3efc9f..e256915 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td +++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td @@ -389,8 +389,8 @@ def TruncIExtUIToExtUI : // trunci(shrsi(x, c)) -> trunci(shrui(x, c)) def TruncIShrSIToTrunciShrUI : Pat<(Arith_TruncIOp:$tr - (Arith_ShRSIOp $x, (ConstantLikeMatcher TypedAttrInterface:$c0)), $overflow), - (Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0))), $overflow), + (Arith_ShRSIOp $x, (ConstantLikeMatcher TypedAttrInterface:$c0), $exact), $overflow), + (Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0)), $exact), $overflow), [(TruncationMatchesShiftAmount $x, $tr, $c0)]>; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp index adeb50b..c4e81e5 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp @@ -35,7 +35,7 @@ static Value createConst(Location loc, Type type, int value, } /// Create a float constant. -static Value createFloatConst(Location loc, Type type, APFloat value, +static Value createFloatConst(Location loc, Type type, const APFloat &value, PatternRewriter &rewriter) { auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value); if (auto shapedTy = dyn_cast<ShapedType>(type)) { diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp index 39e398b..cb7c3d7 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp @@ -150,7 +150,7 @@ public: rhsMask = packInputs(op1.getRhsMask(), op2.getRhsMask()); } - auto extOp = op.getLhs().getDefiningOp(); + auto *extOp = op.getLhs().getDefiningOp(); arm_sme::CombiningKind kind = op.getKind(); if (kind == arm_sme::CombiningKind::Add) { @@ -311,8 +311,8 @@ public: rhsMask = packInputs(rhs0Mask, rhs1Mask); } - auto lhsExtOp = op.getLhs().getDefiningOp(); - auto rhsExtOp = op.getRhs().getDefiningOp(); + auto *lhsExtOp = op.getLhs().getDefiningOp(); + auto *rhsExtOp = op.getRhs().getDefiningOp(); arm_sme::CombiningKind kind = op.getKind(); if (kind == arm_sme::CombiningKind::Add) { diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index e0cf353..9b11270 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -680,16 +680,6 @@ bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const { return false; } -// bufferization.to_buffer is not allowed to change the rank. -static void ensureToBufferOpIsValid(Value tensor, Type memrefType) { -#ifndef NDEBUG - auto rankedTensorType = llvm::dyn_cast<RankedTensorType>(tensor.getType()); - assert((!rankedTensorType || llvm::cast<MemRefType>(memrefType).getRank() == - rankedTensorType.getRank()) && - "to_buffer would be invalid: mismatching ranks"); -#endif -} - FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options, const BufferizationState &state) { @@ -708,7 +698,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value, FailureOr<BufferLikeType> bufferType = getBufferType(value, options, state); if (failed(bufferType)) return failure(); - ensureToBufferOpIsValid(value, *bufferType); + return bufferization::ToBufferOp::create(rewriter, value.getLoc(), *bufferType, value) .getResult(); diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp index d6c3cd6..bd177ba 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp @@ -54,9 +54,6 @@ struct BuiltinTensorExternalModel mlir::LogicalResult verifyCompatibleBufferType( mlir::Type tensor, BufferLikeType bufferType, llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const { - assert(isa<TensorType>(tensor) && "expected tensor type"); - assert(isa<BaseMemRefType>(bufferType) && "expected memref type"); - auto tensorType = cast<ShapedType>(tensor); auto memrefType = cast<ShapedType>(bufferType); diff --git a/mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp b/mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp index 51feec7..f8eb45c 100644 --- a/mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp +++ b/mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp @@ -17,6 +17,10 @@ // Pipeline implementation. //===----------------------------------------------------------------------===// +void mlir::bufferization::buildBufferDeallocationPipeline(OpPassManager &pm) { + buildBufferDeallocationPipeline(pm, BufferDeallocationPipelineOptions()); +} + void mlir::bufferization::buildBufferDeallocationPipeline( OpPassManager &pm, const BufferDeallocationPipelineOptions &options) { memref::ExpandReallocPassOptions expandAllocPassOptions{ @@ -44,5 +48,7 @@ void mlir::bufferization::registerBufferizationPipelines() { "The default pipeline for automatically inserting deallocation " "operations after one-shot bufferization. Deallocation operations " "(except `memref.realloc`) may not be present already.", - buildBufferDeallocationPipeline); + [](OpPassManager &pm, const BufferDeallocationPipelineOptions &options) { + buildBufferDeallocationPipeline(pm, options); + }); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp index 1784964..677c0ba 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Dominance.h" #include "mlir/Interfaces/SubsetOpInterface.h" +#include "mlir/Transforms/RegionUtils.h" namespace mlir { namespace bufferization { @@ -105,8 +106,13 @@ Value mlir::bufferization::buildSubsetExtraction(RewriterBase &rewriter, // this replacement. Operation *insertionPoint = findValidInsertionPoint(emptyTensorOp, user, neededValues); - if (!insertionPoint) - return {}; + if (!insertionPoint) { + // If no already suitable insertion point was found, attempt to move all + // needed values before the user. + if (failed(moveValueDefinitions(rewriter, neededValues, user))) + return {}; + insertionPoint = user; + } rewriter.setInsertionPoint(insertionPoint); Value replacement = diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp index 9ccbfd3..5dfe3e6 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -497,7 +497,7 @@ static bool matchesInsertDestination(const AnalysisState &state, // terminates. All of them must be equivalent subsets. SetVector<Value> backwardSlice = state.findValueInReverseUseDefChain(opOperand, matchingSubset); - return static_cast<bool>(llvm::all_of(backwardSlice, matchingSubset)); + return llvm::all_of(backwardSlice, matchingSubset); } /// Return "true" if the given "read" and potentially conflicting "write" are diff --git a/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt index 58551bb..05a787f 100644 --- a/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt @@ -12,4 +12,5 @@ add_mlir_dialect_library(MLIRControlFlowDialect MLIRControlFlowInterfaces MLIRIR MLIRSideEffectInterfaces + MLIRUBDialect ) diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp index f1da1a1..d2078d8 100644 --- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" @@ -445,6 +446,37 @@ struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> { return success(replaced); } }; + +/// If the destination block of a conditional branch contains only +/// ub.unreachable, unconditionally branch to the other destination. +struct DropUnreachableCondBranch : public OpRewritePattern<CondBranchOp> { + using OpRewritePattern<CondBranchOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(CondBranchOp condbr, + PatternRewriter &rewriter) const override { + // If the "true" destination is unreachable, branch to the "false" + // destination. + Block *trueDest = condbr.getTrueDest(); + Block *falseDest = condbr.getFalseDest(); + if (llvm::hasSingleElement(*trueDest) && + isa<ub::UnreachableOp>(trueDest->getTerminator())) { + rewriter.replaceOpWithNewOp<BranchOp>(condbr, falseDest, + condbr.getFalseOperands()); + return success(); + } + + // If the "false" destination is unreachable, branch to the "true" + // destination. + if (llvm::hasSingleElement(*falseDest) && + isa<ub::UnreachableOp>(falseDest->getTerminator())) { + rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, + condbr.getTrueOperands()); + return success(); + } + + return failure(); + } +}; } // namespace void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results, @@ -452,7 +484,7 @@ void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch, SimplifyCondBranchIdenticalSuccessors, SimplifyCondBranchFromCondBranchOnSameCondition, - CondBranchTruthPropagation>(context); + CondBranchTruthPropagation, DropUnreachableCondBranch>(context); } SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) { diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index d478220..b0566dd 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -226,6 +226,21 @@ FailureOr<SmallVector<ReplacementItem>> parseFormatString( } //===----------------------------------------------------------------------===// +// AddressOfOp +//===----------------------------------------------------------------------===// + +LogicalResult AddressOfOp::verify() { + emitc::LValueType referenceType = getReference().getType(); + emitc::PointerType resultType = getResult().getType(); + + if (referenceType.getValueType() != resultType.getPointee()) + return emitOpError("requires result to be a pointer to the type " + "referenced by operand"); + + return success(); +} + +//===----------------------------------------------------------------------===// // AddOp //===----------------------------------------------------------------------===// @@ -380,6 +395,20 @@ LogicalResult emitc::ConstantOp::verify() { OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } //===----------------------------------------------------------------------===// +// DereferenceOp +//===----------------------------------------------------------------------===// + +LogicalResult DereferenceOp::verify() { + emitc::PointerType pointerType = getPointer().getType(); + + if (pointerType.getPointee() != getResult().getType().getValueType()) + return emitOpError("requires result to be an lvalue of the type " + "pointed to by operand"); + + return success(); +} + +//===----------------------------------------------------------------------===// // ExpressionOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp index b4cb093..d6dfd02 100644 --- a/mlir/lib/Dialect/Func/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp @@ -254,3 +254,28 @@ func::deduplicateArgsOfFuncOp(RewriterBase &rewriter, func::FuncOp funcOp, return std::make_pair(*newFuncOpOrFailure, newCallOp); } + +FailureOr<func::FuncOp> +func::lookupFnDecl(SymbolOpInterface symTable, StringRef name, + FunctionType funcT, SymbolTableCollection *symbolTables) { + FuncOp func; + if (symbolTables) { + func = symbolTables->lookupSymbolIn<FuncOp>( + symTable, StringAttr::get(symTable->getContext(), name)); + } else { + func = llvm::dyn_cast_or_null<FuncOp>( + SymbolTable::lookupSymbolIn(symTable, name)); + } + + if (!func) + return func; + + mlir::FunctionType foundFuncT = func.getFunctionType(); + // Assert the signature of the found function is same as expected + if (funcT != foundFuncT) { + return func.emitError("matched function '") + << name << "' but with different type: " << foundFuncT + << " (expected " << funcT << ")"; + } + return func; +} diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 6c6d8d2..61a630a 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -208,7 +208,7 @@ Type MMAMatrixType::getElementType() const { return getImpl()->elementType; } StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); } bool MMAMatrixType::isValidElementType(Type elementType) { - return elementType.isF16() || elementType.isF32() || + return elementType.isF16() || elementType.isF32() || elementType.isF64() || elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) || elementType.isInteger(32); } @@ -225,7 +225,7 @@ MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError, if (!MMAMatrixType::isValidElementType(elementType)) return emitError() - << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32"; + << "MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64"; return success(); } diff --git a/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt b/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt index ec68acf..85b7b1ce 100644 --- a/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt +++ b/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt @@ -21,6 +21,7 @@ add_mlir_dialect_library(MLIRGPUPipelines MLIRNVVMToLLVM MLIRReconcileUnrealizedCasts MLIRSCFToControlFlow + MLIRVectorToLLVMPass MLIRVectorToSCF MLIRXeGPUTransforms MLIRXeGPUToXeVM diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp index 2c3e466..5462cdd 100644 --- a/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp +++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp @@ -72,6 +72,7 @@ void buildGpuPassPipeline(OpPassManager &pm, ConvertGpuOpsToNVVMOpsOptions opt; opt.useBarePtrCallConv = options.kernelUseBarePtrCallConv; opt.indexBitwidth = options.indexBitWidth; + opt.allowPatternRollback = options.allowPatternRollback; pm.addNestedPass<gpu::GPUModuleOp>(createConvertGpuOpsToNVVMOps(opt)); pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass()); pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass()); diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp index b097d3a..38313dc 100644 --- a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp +++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp @@ -111,8 +111,11 @@ void buildPostGPUCommonPassPipeline( pm.addPass(createGpuToLLVMConversionPass(gpuToLLVMOptions)); } pm.addPass(createLowerAffinePass()); + pm.addPass(createConvertVectorToLLVMPass()); pm.addPass(createConvertToLLVMPass()); pm.addPass(createReconcileUnrealizedCastsPass()); + pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass()); + pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass()); // gpu-module-to-binary { GpuModuleToBinaryPassOptions gpuToModuleBinOptions; diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp index cd13840..70d2e11 100644 --- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp @@ -143,8 +143,8 @@ private: }; /// Erases `executeOp` and returns a clone with additional `results`. -async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp, - ValueRange results) { +static async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp, + ValueRange results) { // Add values to async.yield op. Operation *yieldOp = executeOp.getBody()->getTerminator(); yieldOp->insertOperands(yieldOp->getNumOperands(), results); diff --git a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp index 212ccc9..8d10aac 100644 --- a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp +++ b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp @@ -169,7 +169,7 @@ LogicalResult getSegmentSizes(Operation *op, StringRef elemName, LogicalResult getOperandSegmentSizes(Operation *op, ArrayRef<Variadicity> variadicities, SmallVectorImpl<int> &segmentSizes) { - return getSegmentSizes(op, "operand", "operand_segment_sizes", + return getSegmentSizes(op, "operand", "operandSegmentSizes", op->getNumOperands(), variadicities, segmentSizes); } @@ -180,7 +180,7 @@ LogicalResult getOperandSegmentSizes(Operation *op, LogicalResult getResultSegmentSizes(Operation *op, ArrayRef<Variadicity> variadicities, SmallVectorImpl<int> &segmentSizes) { - return getSegmentSizes(op, "result", "result_segment_sizes", + return getSegmentSizes(op, "result", "resultSegmentSizes", op->getNumResults(), variadicities, segmentSizes); } diff --git a/mlir/lib/Dialect/Index/IR/IndexDialect.cpp b/mlir/lib/Dialect/Index/IR/IndexDialect.cpp index 183d0e3..887e8e1 100644 --- a/mlir/lib/Dialect/Index/IR/IndexDialect.cpp +++ b/mlir/lib/Dialect/Index/IR/IndexDialect.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" +#include "mlir/Transforms/InliningUtils.h" using namespace mlir; using namespace mlir::index; @@ -15,10 +16,23 @@ using namespace mlir::index; //===----------------------------------------------------------------------===// // IndexDialect //===----------------------------------------------------------------------===// +namespace { +/// This class defines the interface for handling inlining for index +/// dialect operations. +struct IndexInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + /// All index dialect ops can be inlined. + bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { + return true; + } +}; +} // namespace void IndexDialect::initialize() { registerAttributes(); registerOperations(); + addInterfaces<IndexInlinerInterface>(); declarePromisedInterface<ConvertToLLVMPatternInterface, IndexDialect>(); } diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt index cc66fac..a73f0c1 100644 --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -31,6 +31,7 @@ add_mlir_dialect_library(MLIRLLVMDialect MLIRControlFlowInterfaces MLIRDataLayoutInterfaces MLIRFunctionInterfaces + MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRIR MLIRMemorySlotInterfaces diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp index feaffa3..160b6ae 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -30,6 +30,7 @@ static constexpr llvm::StringRef kPrintF16 = "printF16"; static constexpr llvm::StringRef kPrintBF16 = "printBF16"; static constexpr llvm::StringRef kPrintF32 = "printF32"; static constexpr llvm::StringRef kPrintF64 = "printF64"; +static constexpr llvm::StringRef kPrintApFloat = "printApFloat"; static constexpr llvm::StringRef kPrintString = "printString"; static constexpr llvm::StringRef kPrintOpen = "printOpen"; static constexpr llvm::StringRef kPrintClose = "printClose"; @@ -160,6 +161,16 @@ mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp, LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables); } +FailureOr<LLVM::LLVMFuncOp> +mlir::LLVM::lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp, + SymbolTableCollection *symbolTables) { + return lookupOrCreateReservedFn( + b, moduleOp, kPrintApFloat, + {IntegerType::get(moduleOp->getContext(), 32), + IntegerType::get(moduleOp->getContext(), 64)}, + LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables); +} + static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) { return LLVM::LLVMPointerType::get(context); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp index b8331e0..9f87e50 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp @@ -219,11 +219,16 @@ bool TBAANodeAttr::classof(Attribute attr) { MemoryEffectsAttr MemoryEffectsAttr::get(MLIRContext *context, ArrayRef<ModRefInfo> memInfoArgs) { if (memInfoArgs.empty()) - return MemoryEffectsAttr::get(context, ModRefInfo::ModRef, - ModRefInfo::ModRef, ModRefInfo::ModRef); - if (memInfoArgs.size() == 3) + return MemoryEffectsAttr::get(context, /*other=*/ModRefInfo::ModRef, + /*argMem=*/ModRefInfo::ModRef, + /*inaccessibleMem=*/ModRefInfo::ModRef, + /*errnoMem=*/ModRefInfo::ModRef, + /*targetMem0=*/ModRefInfo::ModRef, + /*targetMem1=*/ModRefInfo::ModRef); + if (memInfoArgs.size() == 6) return MemoryEffectsAttr::get(context, memInfoArgs[0], memInfoArgs[1], - memInfoArgs[2]); + memInfoArgs[2], memInfoArgs[3], + memInfoArgs[4], memInfoArgs[5]); return {}; } @@ -234,6 +239,12 @@ bool MemoryEffectsAttr::isReadWrite() { return false; if (this->getOther() != ModRefInfo::ModRef) return false; + if (this->getErrnoMem() != ModRefInfo::ModRef) + return false; + if (this->getTargetMem0() != ModRefInfo::ModRef) + return false; + if (this->getTargetMem1() != ModRefInfo::ModRef) + return false; return true; } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 1bf4a1c..5b81948 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -4224,6 +4224,34 @@ LogicalResult InlineAsmOp::verify() { } //===----------------------------------------------------------------------===// +// UDivOp +//===----------------------------------------------------------------------===// +Speculation::Speculatability UDivOp::getSpeculatability() { + // X / 0 => UB + Value divisor = getRhs(); + if (matchPattern(divisor, m_IntRangeWithoutZeroU())) + return Speculation::Speculatable; + + return Speculation::NotSpeculatable; +} + +//===----------------------------------------------------------------------===// +// SDivOp +//===----------------------------------------------------------------------===// +Speculation::Speculatability SDivOp::getSpeculatability() { + // This function conservatively assumes that all signed division by -1 are + // not speculatable. + // X / 0 => UB + // INT_MIN / -1 => UB + Value divisor = getRhs(); + if (matchPattern(divisor, m_IntRangeWithoutZeroS()) && + matchPattern(divisor, m_IntRangeWithoutNegOneS())) + return Speculation::Speculatable; + + return Speculation::NotSpeculatable; +} + +//===----------------------------------------------------------------------===// // LLVMDialect initialization, type parsing, and registration. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index ce93d18..5dc4fa2 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -667,6 +667,7 @@ LogicalResult LLVMStructType::verifyEntries(DataLayoutEntryListRef entries, static constexpr llvm::StringRef kSpirvPrefix = "spirv."; static constexpr llvm::StringRef kArmSVCount = "aarch64.svcount"; +static constexpr llvm::StringRef kAMDGCNNamedBarrier = "amdgcn.named.barrier"; bool LLVM::LLVMTargetExtType::hasProperty(Property prop) const { // See llvm/lib/IR/Type.cpp for reference. @@ -676,6 +677,9 @@ bool LLVM::LLVMTargetExtType::hasProperty(Property prop) const { properties |= (LLVMTargetExtType::HasZeroInit | LLVM::LLVMTargetExtType::CanBeGlobal); + if (getExtTypeName() == kAMDGCNNamedBarrier) + properties |= LLVMTargetExtType::CanBeGlobal; + return (properties & prop) == prop; } diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index d43f881..5ce56e6 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -31,6 +31,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/NVVMIntrinsicUtils.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/NVPTXAddrSpace.h" @@ -48,6 +49,47 @@ using namespace NVVM; static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic; //===----------------------------------------------------------------------===// +// Helper/Utility methods +//===----------------------------------------------------------------------===// + +static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) { + auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType()); + return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS); +} + +static bool isPtrInGenericSpace(mlir::Value ptr) { + return isPtrInAddrSpace(ptr, NVVMMemorySpace::Generic); +} + +static bool isPtrInSharedCTASpace(mlir::Value ptr) { + return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared); +} + +static bool isPtrInSharedClusterSpace(mlir::Value ptr) { + return isPtrInAddrSpace(ptr, NVVMMemorySpace::SharedCluster); +} + +static llvm::Value *castPtrToAddrSpace(llvm::IRBuilderBase &builder, + llvm::Value *ptr, + NVVMMemorySpace targetAS) { + unsigned AS = static_cast<unsigned>(targetAS); + return builder.CreateAddrSpaceCast( + ptr, llvm::PointerType::get(builder.getContext(), AS)); +} + +// Helper method to convert CtaGroupKind in NVVM Dialect to CtaGroupKind in LLVM +static llvm::nvvm::CTAGroupKind +getNVVMCtaGroupKind(NVVM::CTAGroupKind ctaGroup) { + switch (ctaGroup) { + case NVVM::CTAGroupKind::CTA_1: + return llvm::nvvm::CTAGroupKind::CG_1; + case NVVM::CTAGroupKind::CTA_2: + return llvm::nvvm::CTAGroupKind::CG_2; + } + llvm_unreachable("unsupported cta_group value"); +} + +//===----------------------------------------------------------------------===// // Verifier methods //===----------------------------------------------------------------------===// @@ -199,6 +241,83 @@ LogicalResult CpAsyncBulkTensorReduceOp::verify() { return success(); } +LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() { + bool isSharedCTA = isPtrInSharedCTASpace(getDstMem()); + if (isSharedCTA && getMulticastMask()) + return emitError("Multicast is not supported with shared::cta mode."); + + return success(); +} + +static LogicalResult verifyMBarrierArriveLikeOp(Operation *op, Value addr, + NVVM::MemScopeKind scope, + Value retVal = nullptr) { + if (scope != NVVM::MemScopeKind::CTA && scope != NVVM::MemScopeKind::CLUSTER) + return op->emitError("mbarrier scope must be either CTA or Cluster"); + + bool isSharedCluster = isPtrInSharedClusterSpace(addr); + bool hasRetValue = static_cast<bool>(retVal); + if (isSharedCluster && hasRetValue) + return op->emitError( + "mbarrier in shared_cluster space cannot return any value"); + + return success(); +} + +LogicalResult MBarrierArriveOp::verify() { + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(), + getRes()); +} + +LogicalResult MBarrierArriveDropOp::verify() { + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(), + getRes()); +} + +LogicalResult MBarrierArriveExpectTxOp::verify() { + // The inline-ptx version of this Op does not support all features. + // With predicate, this Op lowers to inline-ptx. So, verify and + // error-out if there are unsupported features. + if (getPredicate()) { + if (getScope() != NVVM::MemScopeKind::CTA) + return emitError("mbarrier scope must be CTA when using predicate"); + + if (isPtrInSharedClusterSpace(getAddr())) + return emitError("mbarrier in shared_cluster space is not supported when " + "using predicate"); + + if (getRes()) + return emitError("return-value is not supported when using predicate"); + + if (getRelaxed() == true) + return emitError("mbarrier with relaxed semantics is not supported when " + "using predicate"); + } + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(), + getRes()); +} + +LogicalResult MBarrierArriveDropExpectTxOp::verify() { + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope(), + getRes()); +} + +LogicalResult MBarrierExpectTxOp::verify() { + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope()); +} + +LogicalResult MBarrierCompleteTxOp::verify() { + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope()); +} + +LogicalResult MBarrierTestWaitOp::verify() { + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope()); +} + +LogicalResult MBarrierTryWaitOp::verify() { + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope()); +} + LogicalResult ConvertFloatToTF32Op::verify() { using RndMode = NVVM::FPRoundingMode; switch (getRnd()) { @@ -365,22 +484,71 @@ LogicalResult ConvertF4x2ToF16x2Op::verify() { return success(); } +LogicalResult PermuteOp::verify() { + using Mode = NVVM::PermuteMode; + bool hasHi = static_cast<bool>(getHi()); + + switch (getMode()) { + case Mode::DEFAULT: + case Mode::F4E: + case Mode::B4E: + if (!hasHi) + return emitError("mode '") + << stringifyPermuteMode(getMode()) << "' requires 'hi' operand."; + break; + case Mode::RC8: + case Mode::ECL: + case Mode::ECR: + case Mode::RC16: + if (hasHi) + return emitError("mode '") << stringifyPermuteMode(getMode()) + << "' does not accept 'hi' operand."; + break; + } + + return success(); +} + //===----------------------------------------------------------------------===// // Stochastic Rounding Conversion Ops //===----------------------------------------------------------------------===// -LogicalResult ConvertF32x2ToF16x2Op::verify() { - if (getRnd() != FPRoundingMode::RS) - return emitOpError("Only RS rounding mode is supported for " - "conversions from f32x2 to f16x2."); +static LogicalResult verifyConvertF32x2ToFP16x2Op(Twine dstType, + FPRoundingMode rnd, + bool hasRandomBits, + Operation *op) { + static constexpr FPRoundingMode validRndModes[] = { + FPRoundingMode::RN, FPRoundingMode::RZ, FPRoundingMode::RS}; + + if (!llvm::is_contained(validRndModes, rnd)) { + return op->emitOpError( + "Only RN, RZ, and RS rounding modes are supported for " + "conversions from f32x2 to ") + << dstType << "."; + } + + if (rnd == FPRoundingMode::RS) { + if (!hasRandomBits) { + return op->emitOpError("random_bits is required for RS rounding mode."); + } + } else { + if (hasRandomBits) { + return op->emitOpError( + "random_bits not supported for RN and RZ rounding modes."); + } + } + return success(); } +LogicalResult ConvertF32x2ToF16x2Op::verify() { + return verifyConvertF32x2ToFP16x2Op("f16x2", getRnd(), + getRandomBits() ? true : false, *this); +} + LogicalResult ConvertF32x2ToBF16x2Op::verify() { - if (getRnd() != FPRoundingMode::RS) - return emitOpError("Only RS rounding mode is supported for " - "conversions from f32x2 to bf16x2."); - return success(); + return verifyConvertF32x2ToFP16x2Op("bf16x2", getRnd(), + getRandomBits() ? true : false, *this); } LogicalResult ConvertF32x4ToF8x4Op::verify() { @@ -919,6 +1087,482 @@ LogicalResult MmaOp::verify() { return success(); } +MMATypes MmaSpOp::accumPtxType() { + std::optional<mlir::NVVM::MMATypes> val = MmaOp::inferOperandMMAType( + getODSOperands(2).getTypes().front(), /*isAccumulator=*/true); + assert(val.has_value() && "accumulator PTX type should always be inferrable"); + return val.value(); +} + +MMATypes MmaSpOp::resultPtxType() { + std::optional<mlir::NVVM::MMATypes> val = + MmaOp::inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true); + assert(val.has_value() && "result PTX type should always be inferrable"); + return val.value(); +} + +mlir::NVVM::IDArgPair +MmaSpOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MmaSpOp>(op); + + // Get operands + llvm::SmallVector<llvm::Value *> args; + for (mlir::Value v : thisOp.getOperands()) + args.push_back(mt.lookupValue(v)); + + // Get intrinsic ID using the existing getIntrinsicID method + auto intId = MmaSpOp::getIntrinsicID( + thisOp.getShape().getM(), thisOp.getShape().getN(), + thisOp.getShape().getK(), thisOp.getIntOverflowBehavior(), + thisOp.getOrderedMetadata(), thisOp.getKind(), + *thisOp.getMultiplicandAPtxType(), *thisOp.getMultiplicandBPtxType(), + thisOp.accumPtxType(), thisOp.resultPtxType()); + + return {intId, args}; +} + +void MmaSpOp::print(OpAsmPrinter &p) { + SmallVector<Type, 4> regTypes; + struct OperandFragment { + StringRef operandName; + StringRef ptxTypeAttr; + SmallVector<Value, 4> regs; + explicit OperandFragment(StringRef name, StringRef ptxTypeName) + : operandName(name), ptxTypeAttr(ptxTypeName) {} + }; + + std::array<OperandFragment, 5> frags{ + OperandFragment("A", getMultiplicandAPtxTypeAttrName()), + OperandFragment("B", getMultiplicandBPtxTypeAttrName()), + OperandFragment("C", ""), OperandFragment("sparseMetadata", ""), + OperandFragment("selector", "")}; + SmallVector<StringRef, 4> ignoreAttrNames{ + mlir::NVVM::MmaSpOp::getOperandSegmentSizeAttr()}; + + // Handle variadic operands A, B, C + for (unsigned fragIdx = 0; fragIdx < 3; fragIdx++) { + auto &frag = frags[fragIdx]; + auto varOperandSpec = getODSOperandIndexAndLength(fragIdx); + for (auto operandIdx = varOperandSpec.first; + operandIdx < varOperandSpec.first + varOperandSpec.second; + operandIdx++) { + frag.regs.push_back(this->getOperand(operandIdx)); + if (operandIdx == varOperandSpec.first) { + regTypes.push_back(this->getOperand(operandIdx).getType()); + } + } + std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType( + regTypes.back(), /*isAccumulator=*/fragIdx >= 2); + if (inferredType) + ignoreAttrNames.push_back(frag.ptxTypeAttr); + } + + // Handle sparse metadata and selector (single operands) + frags[3].regs.push_back(getSparseMetadata()); + frags[4].regs.push_back(getSparsitySelector()); + + auto printMmaSpOperand = [&](const OperandFragment &frag) -> void { + p << " " << frag.operandName; + p << "["; + p.printOperands(frag.regs); + p << "]"; + }; + + for (const auto &frag : frags) + printMmaSpOperand(frag); + + p.printOptionalAttrDict((*this)->getAttrs(), ignoreAttrNames); + p << " : "; + p << "("; + for (int i = 0; i < 3; ++i) { + p << regTypes[i]; + if (i < 2) + p << ", "; + } + p << ") -> " << getResult().getType(); +} + +void MmaSpOp::build( + OpBuilder &builder, OperationState &result, Type resultType, + ValueRange operandA, ValueRange operandB, ValueRange operandC, + Value sparseMetadata, Value sparsitySelector, ArrayRef<int64_t> shape, + std::optional<MMAIntOverflow> intOverflow, + std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) { + + assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)"); + MLIRContext *ctx = builder.getContext(); + result.addAttribute( + "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2])); + + result.addOperands(operandA); + result.addOperands(operandB); + result.addOperands(operandC); + result.addOperands(sparseMetadata); + result.addOperands(sparsitySelector); + + if (multiplicandPtxTypes) { + result.addAttribute("multiplicandAPtxType", + MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0])); + result.addAttribute("multiplicandBPtxType", + MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1])); + } else { + if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false)) + result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res)); + if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false)) + result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res)); + } + + if (intOverflow.has_value()) + result.addAttribute("intOverflowBehavior", + MMAIntOverflowAttr::get(ctx, *intOverflow)); + + result.addTypes(resultType); + result.addAttribute( + MmaSpOp::getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()), + static_cast<int32_t>(operandB.size()), + static_cast<int32_t>(operandC.size()), 1, + 1})); // sparseMetadata and sparsitySelector +} + +ParseResult MmaSpOp::parse(OpAsmParser &parser, OperationState &result) { + struct OperandFragment { + std::optional<MMATypes> elemtype; + SmallVector<OpAsmParser::UnresolvedOperand, 4> regs; + SmallVector<Type> regTypes; + }; + + Builder &builder = parser.getBuilder(); + std::array<OperandFragment, 6> frags; // A, B, C, sparseMetadata, selector + + NamedAttrList namedAttributes; + + // A helper to parse the operand segments. + auto parseMmaSpOperand = [&](StringRef operandName, + OperandFragment &frag) -> LogicalResult { + if (parser.parseKeyword(operandName).failed()) + return failure(); + if (parser + .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare) + .failed()) + return failure(); + return success(); + }; + + // Parse the operand segments. + if (parseMmaSpOperand("A", frags[0]).failed()) + return failure(); + if (parseMmaSpOperand("B", frags[1]).failed()) + return failure(); + if (parseMmaSpOperand("C", frags[2]).failed()) + return failure(); + if (parseMmaSpOperand("sparseMetadata", frags[3]).failed()) + return failure(); + if (parseMmaSpOperand("selector", frags[4]).failed()) + return failure(); + + if (parser.parseOptionalAttrDict(namedAttributes).failed()) + return failure(); + + // Parse the type specification and resolve operands. + SmallVector<Type, 3> operandTypes; + if (failed(parser.parseColon())) + return failure(); + if (failed(parser.parseLParen())) + return failure(); + if (failed(parser.parseTypeList(operandTypes))) + return failure(); + if (failed(parser.parseRParen())) + return failure(); + if (operandTypes.size() != 3) + return parser.emitError( + parser.getNameLoc(), + "expected one type for each operand segment but got " + + Twine(operandTypes.size()) + " types"); + for (const auto &iter : llvm::enumerate(operandTypes)) { + auto &frag = frags[iter.index()]; + frag.regTypes.resize(frag.regs.size(), iter.value()); + if (failed(parser.resolveOperands(frag.regs, frag.regTypes, + parser.getNameLoc(), result.operands))) + return failure(); + frag.elemtype = + MmaOp::inferOperandMMAType(frag.regTypes[0], + /*isAccumulator*/ iter.index() >= 2); + } + + Type resultType; + if (parser.parseArrow() || parser.parseType(resultType)) + return failure(); + frags[5].elemtype = + MmaOp::inferOperandMMAType(resultType, /*isAccumulator*/ true); + + // Resolve sparse metadata and selector (assume i32 type) + Type i32Type = builder.getIntegerType(32); + if (parser + .resolveOperands(frags[3].regs, i32Type, parser.getCurrentLocation(), + result.operands) + .failed()) + return failure(); + if (parser + .resolveOperands(frags[4].regs, i32Type, parser.getCurrentLocation(), + result.operands) + .failed()) + return failure(); + + std::array<StringRef, 2> names{"multiplicandAPtxType", + "multiplicandBPtxType"}; + for (unsigned idx = 0; idx < names.size(); idx++) { + const auto &frag = frags[idx]; + std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]); + if (!frag.elemtype.has_value() && !attr.has_value()) { + return parser.emitError( + parser.getNameLoc(), + "attribute " + names[idx] + + " is not provided explicitly and cannot be inferred"); + } + if (!attr.has_value()) + result.addAttribute( + names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype)); + } + + result.addTypes(resultType); + if (!namedAttributes.empty()) + result.addAttributes(namedAttributes); + result.addAttribute(MmaSpOp::getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr({ + static_cast<int32_t>(frags[0].regs.size()), + static_cast<int32_t>(frags[1].regs.size()), + static_cast<int32_t>(frags[2].regs.size()), + 1, // sparseMetadata + 1 // sparsitySelector + })); + return success(); +} + +LogicalResult MmaSpOp::verify() { + MLIRContext *context = getContext(); + auto f16Ty = Float16Type::get(context); + auto i32Ty = IntegerType::get(context, 32); + auto f16x2Ty = VectorType::get(2, f16Ty); + auto f32Ty = Float32Type::get(context); + auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral( + context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty}); + + auto s32x4StructTy = + LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty}); + auto f32x8StructTy = + LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty)); + auto f16x2x2StructTy = + LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty}); + auto f32x4StructTy = + LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty}); + auto s32x2StructTy = + LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty}); + + std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(), + getShapeAttr().getK()}; + + // These variables define the set of allowed data types for matrices A, B, C, + // and result. + using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>; + using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>; + AllowedShapes allowedShapes; + AllowedTypes expectedA; + AllowedTypes expectedB; + AllowedTypes expectedC; + SmallVector<Type> expectedResult; + + // When M = 16, we just need to calculate the number of 8xk tiles, where + // k is a factor that depends on the data type. + if (mmaShape[0] == 16) { + int64_t kFactor; + Type multiplicandFragType; + switch (*getMultiplicandAPtxType()) { + case MMATypes::tf32: + kFactor = 4; + multiplicandFragType = i32Ty; + expectedResult.push_back(LLVM::LLVMStructType::getLiteral( + context, {f32Ty, f32Ty, f32Ty, f32Ty})); + // Sparse MMA supports m16n8k8 and m16n8k16 for tf32 + allowedShapes.push_back({16, 8, 8}); + allowedShapes.push_back({16, 8, 16}); + break; + case MMATypes::bf16: + kFactor = 8; + multiplicandFragType = i32Ty; + expectedResult.push_back(LLVM::LLVMStructType::getLiteral( + context, {f32Ty, f32Ty, f32Ty, f32Ty})); + // Sparse MMA supports m16n8k16 and m16n8k32 for bf16 + allowedShapes.push_back({16, 8, 16}); + allowedShapes.push_back({16, 8, 32}); + break; + case MMATypes::f16: + kFactor = 8; + multiplicandFragType = f16x2Ty; + expectedResult.push_back(f16x2x2StructTy); + expectedResult.push_back(f32x4StructTy); + // Sparse MMA supports m16n8k16 and m16n8k32 for f16 + allowedShapes.push_back({16, 8, 16}); + allowedShapes.push_back({16, 8, 32}); + break; + case MMATypes::s4: + case MMATypes::u4: + kFactor = 32; + // Sparse MMA supports m16n8k64 and m16n8k128 for s4/u4 + allowedShapes.push_back({16, 8, 64}); + allowedShapes.push_back({16, 8, 128}); + break; + case MMATypes::s8: + case MMATypes::u8: + kFactor = 16; + // Sparse MMA supports m16n8k32 and m16n8k64 for s8/u8 + allowedShapes.push_back({16, 8, 32}); + allowedShapes.push_back({16, 8, 64}); + break; + case MMATypes::e4m3: + case MMATypes::e5m2: + case MMATypes::e3m2: + case MMATypes::e2m3: + case MMATypes::e2m1: + kFactor = 32; + multiplicandFragType = i32Ty; + expectedResult.push_back(f16x2x2StructTy); + expectedResult.push_back(f32x4StructTy); + // Sparse MMA supports m16n8k64 for FP8 types + allowedShapes.push_back({16, 8, 64}); + break; + default: + return emitError("invalid shape or multiplicand type: " + + stringifyEnum(getMultiplicandAPtxType().value())); + } + + if (isIntegerPtxType(getMultiplicandAPtxType().value())) { + expectedResult.push_back(s32x4StructTy); + expectedC.emplace_back(4, i32Ty); + multiplicandFragType = i32Ty; + } else if (*getMultiplicandAPtxType() >= MMATypes::e4m3 && + *getMultiplicandAPtxType() <= MMATypes::e2m1) { + // FP8 types + expectedC.emplace_back(2, f16x2Ty); + expectedC.emplace_back(4, f32Ty); + } else { + expectedC.emplace_back(2, f16x2Ty); + expectedC.emplace_back(4, f32Ty); + } + + // For sparse MMA, A operand is compressed (2:4 sparsity means half the + // elements) + int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor) / 2; + int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor); + expectedA.emplace_back(unitA, multiplicandFragType); + expectedB.emplace_back(unitB, multiplicandFragType); + + if (resultPtxType() != accumPtxType()) + return emitOpError("ctype does not match dtype"); + } + + // In the M=8 case, there is only 1 possible case per data type. + if (mmaShape[0] == 8) { + if (*getMultiplicandAPtxType() == MMATypes::f16) { + expectedA.emplace_back(2, f16x2Ty); + expectedB.emplace_back(2, f16x2Ty); + expectedResult.push_back(f16x2x4StructTy); + expectedResult.push_back(f32x8StructTy); + expectedC.emplace_back(4, f16x2Ty); + expectedC.emplace_back(8, f32Ty); + allowedShapes.push_back({8, 8, 4}); + } + if (*getMultiplicandAPtxType() == MMATypes::f64) { + Type f64Ty = Float64Type::get(context); + expectedA.emplace_back(1, f64Ty); + expectedB.emplace_back(1, f64Ty); + expectedC.emplace_back(2, f64Ty); + expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral( + context, SmallVector<Type>(2, f64Ty))); + allowedShapes.push_back({8, 8, 4}); + } + if (isIntegerPtxType(getMultiplicandAPtxType().value())) { + expectedA.push_back({i32Ty}); + expectedB.push_back({i32Ty}); + expectedC.push_back({i32Ty, i32Ty}); + expectedResult.push_back(s32x2StructTy); + if (isInt4PtxType(getMultiplicandAPtxType().value())) + allowedShapes.push_back({8, 8, 32}); + if (isInt8PtxType(getMultiplicandAPtxType().value())) + allowedShapes.push_back({8, 8, 16}); + } + } + + std::string errorMessage; + llvm::raw_string_ostream errorStream(errorMessage); + + // Check that we matched an existing shape/dtype combination. + if (expectedA.empty() || expectedB.empty() || expectedC.empty() || + !llvm::is_contained(allowedShapes, mmaShape)) { + errorStream << "unimplemented variant for MMA shape <"; + llvm::interleaveComma(mmaShape, errorStream); + errorStream << ">"; + return emitOpError(errorMessage); + } + + // Verify the operand types for segments of A, B, and C operands. + std::array<StringRef, 3> operandNames{"A", "B", "C"}; + for (const auto &iter : llvm::enumerate( + SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) { + auto spec = this->getODSOperandIndexAndLength(iter.index()); + SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first, + operand_type_begin() + spec.first + + spec.second); + bool match = llvm::is_contained(iter.value(), operandTySeg); + + if (!match) { + errorStream << "Could not match types for the " + << operandNames[iter.index()] + << " operands; expected one of "; + for (const auto &x : iter.value()) { + errorStream << x.size() << "x" << x[0] << " "; + } + errorStream << "but got "; + llvm::interleaveComma(operandTySeg, errorStream); + return emitOpError(errorMessage); + } + } + + // Check the result type + if (!llvm::any_of(expectedResult, [&](Type expectedResultType) { + return expectedResultType == getResult().getType(); + })) { + errorStream + << "Could not match allowed types for the result; expected one of "; + llvm::interleaveComma(expectedResult, errorStream); + errorStream << " but got " << getResult().getType(); + return emitOpError(errorMessage); + } + + // Ensure int4/int8 MMA variants specify the accum overflow behavior + // attribute. + if (isInt4PtxType(*getMultiplicandAPtxType()) || + isInt8PtxType(*getMultiplicandAPtxType())) { + if (!getIntOverflowBehavior()) + return emitOpError("op requires " + + getIntOverflowBehaviorAttrName().strref() + + " attribute"); + } + + // Validate sparse metadata type (should be i32) + if (!getSparseMetadata().getType().isInteger(32)) { + return emitOpError() << "sparse metadata must be i32 type"; + } + + // Validate sparsity selector type (should be i32) + if (!getSparsitySelector().getType().isInteger(32)) { + return emitOpError() << "sparsity selector must be i32 type"; + } + + return success(); +} + LogicalResult ShflOp::verify() { auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(getType()); @@ -1454,6 +2098,13 @@ bool NVVM::WgmmaMmaAsyncOp::getAsmValues( return true; // Has manual mapping } +LogicalResult NVVM::FenceSyncRestrictOp::verify() { + if (getOrder() != NVVM::MemOrderKind::ACQUIRE && + getOrder() != NVVM::MemOrderKind::RELEASE) + return emitOpError("only acquire and release semantics are supported"); + return success(); +} + LogicalResult NVVM::FenceProxyOp::verify() { if (getKind() == NVVM::ProxyKind::TENSORMAP) return emitOpError() << "tensormap proxy is not a supported proxy kind"; @@ -1476,7 +2127,6 @@ LogicalResult NVVM::FenceProxyAcquireOp::verify() { if (getToProxy() != NVVM::ProxyKind::TENSORMAP) return emitOpError("uni-directional proxies only support tensormap " "for to_proxy attribute"); - return success(); } @@ -1488,7 +2138,19 @@ LogicalResult NVVM::FenceProxyReleaseOp::verify() { if (getToProxy() != NVVM::ProxyKind::TENSORMAP) return emitOpError("uni-directional proxies only support tensormap " "for to_proxy attribute"); + return success(); +} + +LogicalResult NVVM::FenceProxySyncRestrictOp::verify() { + if (getOrder() != NVVM::MemOrderKind::ACQUIRE && + getOrder() != NVVM::MemOrderKind::RELEASE) + return emitOpError("only acquire and release semantics are supported"); + + if (getFromProxy() != NVVM::ProxyKind::GENERIC) + return emitOpError("only generic is support for from_proxy attribute"); + if (getToProxy() != NVVM::ProxyKind::async) + return emitOpError("only async is supported for to_proxy attribute"); return success(); } @@ -1504,6 +2166,15 @@ LogicalResult NVVM::BarrierOp::verify() { if (getNumberOfThreads() && !getBarrierId()) return emitOpError( "barrier id is missing, it should be set between 0 to 15"); + + if (getBarrierId() && (getReductionOp() || getReductionPredicate())) + return emitOpError("reduction are only available when id is 0"); + + if ((getReductionOp() && !getReductionPredicate()) || + (!getReductionOp() && getReductionPredicate())) + return emitOpError("reduction predicate and reduction operation must be " + "specified together"); + return success(); } @@ -1741,24 +2412,68 @@ void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op, //===----------------------------------------------------------------------===// std::string NVVM::MBarrierInitOp::getPtx() { - unsigned addressSpace = - llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace(); - return (addressSpace == NVVMMemorySpace::Shared) - ? std::string("mbarrier.init.shared.b64 [%0], %1;") - : std::string("mbarrier.init.b64 [%0], %1;"); + bool isShared = isPtrInSharedCTASpace(getAddr()); + return isShared ? std::string("mbarrier.init.shared.b64 [%0], %1;") + : std::string("mbarrier.init.b64 [%0], %1;"); +} + +std::string NVVM::MBarrierArriveExpectTxOp::getPtx() { + bool isShared = isPtrInSharedCTASpace(getAddr()); + return isShared + ? std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;") + : std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;"); +} + +std::string NVVM::MBarrierTryWaitParityOp::getPtx() { + bool isShared = isPtrInSharedCTASpace(getAddr()); + llvm::StringRef space = isShared ? ".shared" : ""; + + return llvm::formatv("{\n\t" + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity{0}.b64 P1, [%0], %1, %2; \n\t" + "@P1 bra.uni DONE; \n\t" + "bra.uni LAB_WAIT; \n\t" + "DONE: \n\t" + "}", + space); } //===----------------------------------------------------------------------===// // getIntrinsicID/getIntrinsicIDAndArgs methods //===----------------------------------------------------------------------===// -static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) { - auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType()); - return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS); -} +mlir::NVVM::IDArgPair NVVM::BarrierOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::BarrierOp>(op); + llvm::Value *barrierId = thisOp.getBarrierId() + ? mt.lookupValue(thisOp.getBarrierId()) + : builder.getInt32(0); + llvm::Intrinsic::ID id; + llvm::SmallVector<llvm::Value *> args; + if (thisOp.getNumberOfThreads()) { + id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count; + args.push_back(barrierId); + args.push_back(mt.lookupValue(thisOp.getNumberOfThreads())); + } else if (thisOp.getReductionOp()) { + switch (*thisOp.getReductionOp()) { + case NVVM::BarrierReduction::AND: + id = llvm::Intrinsic::nvvm_barrier0_and; + break; + case NVVM::BarrierReduction::OR: + id = llvm::Intrinsic::nvvm_barrier0_or; + break; + case NVVM::BarrierReduction::POPC: + id = llvm::Intrinsic::nvvm_barrier0_popc; + break; + } + args.push_back(mt.lookupValue(thisOp.getReductionPredicate())); + } else { + id = llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all; + args.push_back(barrierId); + } -static bool isPtrInSharedCTASpace(mlir::Value ptr) { - return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared); + return {id, std::move(args)}; } mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs( @@ -1787,15 +2502,213 @@ mlir::NVVM::IDArgPair MBarrierInvalOp::getIntrinsicIDAndArgs( return {id, {mt.lookupValue(thisOp.getAddr())}}; } +mlir::NVVM::IDArgPair MBarrierExpectTxOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierExpectTxOp>(op); + + bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr()); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + // bit-0: Space + // bit-1: Scope + size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0); + + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_expect_tx_scope_cluster_space_cluster}; + + // Fill the Intrinsic Args + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(thisOp.getAddr())); + args.push_back(mt.lookupValue(thisOp.getTxcount())); + + return {IDs[index], std::move(args)}; +} + +mlir::NVVM::IDArgPair MBarrierCompleteTxOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierCompleteTxOp>(op); + + bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr()); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + // bit-0: Space + // bit-1: Scope + size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0); + + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_complete_tx_scope_cluster_space_cluster}; + + // Fill the Intrinsic Args + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(thisOp.getAddr())); + args.push_back(mt.lookupValue(thisOp.getTxcount())); + + return {IDs[index], std::move(args)}; +} + mlir::NVVM::IDArgPair MBarrierArriveOp::getIntrinsicIDAndArgs( Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { auto thisOp = cast<NVVM::MBarrierArriveOp>(op); - bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); - llvm::Intrinsic::ID id = isShared - ? llvm::Intrinsic::nvvm_mbarrier_arrive_shared - : llvm::Intrinsic::nvvm_mbarrier_arrive; - return {id, {mt.lookupValue(thisOp.getAddr())}}; + bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr()); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + // bit-0: Space + // bit-1: Scope + size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0); + + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_scope_cluster_space_cluster}; + static constexpr llvm::Intrinsic::ID relaxedIDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cta, + llvm::Intrinsic:: + nvvm_mbarrier_arrive_relaxed_scope_cluster_space_cluster}; + auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index]; + + // Tidy-up the Intrinsic Args + bool needCast = isPtrInGenericSpace(thisOp.getAddr()); + llvm::Value *mbar = mt.lookupValue(thisOp.getAddr()); + if (needCast) + mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared); + + // When count is not explicitly specified, the default is 1. + llvm::LLVMContext &ctx = mt.getLLVMContext(); + bool hasCount = static_cast<bool>(thisOp.getCount()); + llvm::Value *count = + hasCount ? mt.lookupValue(thisOp.getCount()) + : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1); + + return {id, {mbar, count}}; +} + +mlir::NVVM::IDArgPair MBarrierArriveDropOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierArriveDropOp>(op); + + bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr()); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + // bit-0: Space + // bit-1: Scope + size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0); + + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_scope_cluster_space_cluster}; + static constexpr llvm::Intrinsic::ID relaxedIDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cta, + llvm::Intrinsic:: + nvvm_mbarrier_arrive_drop_relaxed_scope_cta_space_cluster, + llvm::Intrinsic:: + nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cta, + llvm::Intrinsic:: + nvvm_mbarrier_arrive_drop_relaxed_scope_cluster_space_cluster}; + auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index]; + + // Tidy-up the Intrinsic Args + bool needCast = isPtrInGenericSpace(thisOp.getAddr()); + llvm::Value *mbar = mt.lookupValue(thisOp.getAddr()); + if (needCast) + mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared); + + // When count is not explicitly specified, the default is 1. + llvm::LLVMContext &ctx = mt.getLLVMContext(); + bool hasCount = static_cast<bool>(thisOp.getCount()); + llvm::Value *count = + hasCount ? mt.lookupValue(thisOp.getCount()) + : llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1); + + return {id, {mbar, count}}; +} + +bool MBarrierArriveExpectTxOp::getAsmValues( + RewriterBase &rewriter, + llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> + &asmValues) { + // Add all the operands but not the attrs to the asmValues list. + // The attrs here are used to generate the right variants for + // intrinsics-lowering. So, we ignore them while generating inline-PTX. + for (auto val : getOperands()) + asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read}); + + return false; +} + +mlir::NVVM::IDArgPair MBarrierArriveExpectTxOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierArriveExpectTxOp>(op); + + bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr()); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + // bit-0: Space + // bit-1: Scope + size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0); + + // clang-format off + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_scope_cluster_space_cluster}; + static constexpr llvm::Intrinsic::ID relaxedIDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_expect_tx_relaxed_scope_cluster_space_cluster}; + // clang-format on + auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index]; + + // Tidy-up the Intrinsic Args + llvm::Value *txcount = mt.lookupValue(thisOp.getTxcount()); + llvm::Value *mbar = mt.lookupValue(thisOp.getAddr()); + bool needCast = isPtrInGenericSpace(thisOp.getAddr()); + if (needCast) + mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared); + + return {id, {mbar, txcount}}; +} + +mlir::NVVM::IDArgPair MBarrierArriveDropExpectTxOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierArriveDropExpectTxOp>(op); + + bool isClusterSpace = isPtrInSharedClusterSpace(thisOp.getAddr()); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + // bit-0: Space + // bit-1: Scope + size_t index = ((isClusterScope ? 1 : 0) << 1) | (isClusterSpace ? 1 : 0); + + // clang-format off + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_scope_cluster_space_cluster}; + static constexpr llvm::Intrinsic::ID relaxedIDs[] = { + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cta_space_cluster, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_arrive_drop_expect_tx_relaxed_scope_cluster_space_cluster}; + // clang-format on + auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index]; + + // Tidy-up the Intrinsic Args + llvm::Value *txcount = mt.lookupValue(thisOp.getTxcount()); + llvm::Value *mbar = mt.lookupValue(thisOp.getAddr()); + bool needCast = isPtrInGenericSpace(thisOp.getAddr()); + if (needCast) + mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared); + + return {id, {mbar, txcount}}; } mlir::NVVM::IDArgPair MBarrierArriveNocompleteOp::getIntrinsicIDAndArgs( @@ -1813,17 +2726,100 @@ mlir::NVVM::IDArgPair MBarrierArriveNocompleteOp::getIntrinsicIDAndArgs( return {id, std::move(args)}; } -mlir::NVVM::IDArgPair MBarrierTestWaitOp::getIntrinsicIDAndArgs( +mlir::NVVM::IDArgPair MBarrierArriveDropNocompleteOp::getIntrinsicIDAndArgs( Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { - auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op); + auto thisOp = cast<NVVM::MBarrierArriveDropNocompleteOp>(op); bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); - llvm::Intrinsic::ID id = isShared - ? llvm::Intrinsic::nvvm_mbarrier_test_wait_shared - : llvm::Intrinsic::nvvm_mbarrier_test_wait; + llvm::Intrinsic::ID id = + isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete_shared + : llvm::Intrinsic::nvvm_mbarrier_arrive_drop_noComplete; // Fill the Intrinsic Args llvm::SmallVector<llvm::Value *> args; args.push_back(mt.lookupValue(thisOp.getAddr())); - args.push_back(mt.lookupValue(thisOp.getState())); + args.push_back(mt.lookupValue(thisOp.getCount())); + + return {id, std::move(args)}; +} + +mlir::NVVM::IDArgPair MBarrierTestWaitOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op); + bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + // bit-0: isPhaseParity + // bit-1: Scope + size_t index = ((isClusterScope ? 1 : 0) << 1) | (isPhaseParity ? 1 : 0); + + // clang-format off + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cluster_space_cta}; + static constexpr llvm::Intrinsic::ID relaxedIDs[] = { + llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cluster_space_cta}; + // clang-format on + auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index]; + + // Tidy-up the Intrinsic Args + llvm::Value *mbar = mt.lookupValue(thisOp.getAddr()); + llvm::Value *input = mt.lookupValue(thisOp.getStateOrPhase()); + bool needCast = isPtrInGenericSpace(thisOp.getAddr()); + if (needCast) + mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared); + + return {id, {mbar, input}}; +} + +mlir::NVVM::IDArgPair MBarrierTryWaitOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierTryWaitOp>(op); + bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + bool hasTicks = static_cast<bool>(thisOp.getTicks()); + // bit-0: isPhaseParity + // bit-1: Scope + // bit-2: hasTicks + size_t index = ((hasTicks ? 1 : 0) << 2) | ((isClusterScope ? 1 : 0) << 1) | + (isPhaseParity ? 1 : 0); + + // clang-format off + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cluster_space_cta}; + static constexpr llvm::Intrinsic::ID relaxedIDs[] = { + llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cluster_space_cta}; + // clang-format on + auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index]; + + // Tidy-up the mbarrier pointer + llvm::Value *mbar = mt.lookupValue(thisOp.getAddr()); + bool needCast = isPtrInGenericSpace(thisOp.getAddr()); + if (needCast) + mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared); + + // Fill the Intrinsic Args + llvm::SmallVector<llvm::Value *> args; + args.push_back(mbar); + args.push_back(mt.lookupValue(thisOp.getStateOrPhase())); + if (hasTicks) + args.push_back(mt.lookupValue(thisOp.getTicks())); return {id, std::move(args)}; } @@ -1914,11 +2910,15 @@ mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs( args.push_back(mt.lookupValue(thisOp.getSrcMem())); args.push_back(mt.lookupValue(thisOp.getSize())); - // Multicast mask, if available. + // Multicast mask for shared::cluster only, if available. mlir::Value multicastMask = thisOp.getMulticastMask(); const bool hasMulticastMask = static_cast<bool>(multicastMask); - llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0); - args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) : i16Unused); + const bool isSharedCTA = isPtrInSharedCTASpace(thisOp.getDstMem()); + if (!isSharedCTA) { + llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0); + args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) + : i16Unused); + } // Cache hint, if available. mlir::Value cacheHint = thisOp.getL2CacheHint(); @@ -1927,11 +2927,14 @@ mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs( args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused); // Flag arguments for multicast and cachehint. - args.push_back(builder.getInt1(hasMulticastMask)); + if (!isSharedCTA) + args.push_back(builder.getInt1(hasMulticastMask)); args.push_back(builder.getInt1(hasCacheHint)); llvm::Intrinsic::ID id = - llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster; + isSharedCTA + ? llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cta + : llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster; return {id, std::move(args)}; } @@ -2646,30 +3649,100 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op, return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \ }() -llvm::Intrinsic::ID ConvertF32x2ToF16x2Op::getIntrinsicID() { - bool hasRelu = getRelu(); - bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE); +NVVM::IDArgPair +ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op, + LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + static constexpr llvm::Intrinsic::ID rndRNIds[] = { + llvm::Intrinsic::nvvm_ff2f16x2_rn, + llvm::Intrinsic::nvvm_ff2f16x2_rn_relu, + llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite, + llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite, + }; + static constexpr llvm::Intrinsic::ID rndRZIds[] = { + llvm::Intrinsic::nvvm_ff2f16x2_rz, + llvm::Intrinsic::nvvm_ff2f16x2_rz_relu, + llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite, + llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite, + }; + static constexpr llvm::Intrinsic::ID rndRSIds[] = { + llvm::Intrinsic::nvvm_ff2f16x2_rs, + llvm::Intrinsic::nvvm_ff2f16x2_rs_relu, + llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite, + llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite, + }; + + unsigned hasRelu = op.getRelu() ? 1 : 0; + unsigned hasSatFinite = + (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0; + // idx: bit-0 - relu + // bit-1 - satfinite + unsigned idx = (hasSatFinite << 1) | hasRelu; - if (hasRelu && hasSatFinite) - return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite; - if (hasRelu) - return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu; - if (hasSatFinite) - return llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite; - return llvm::Intrinsic::nvvm_ff2f16x2_rs; + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(op.getSrcHi())); + args.push_back(mt.lookupValue(op.getSrcLo())); + if (op.getRandomBits()) + args.push_back(mt.lookupValue(op.getRandomBits())); + + switch (op.getRnd()) { + case FPRoundingMode::RN: + return {rndRNIds[idx], std::move(args)}; + case FPRoundingMode::RZ: + return {rndRZIds[idx], std::move(args)}; + case FPRoundingMode::RS: + return {rndRSIds[idx], std::move(args)}; + default: + llvm_unreachable("Invalid rounding mode for ConvertF32x2ToF16x2Op"); + } } -llvm::Intrinsic::ID ConvertF32x2ToBF16x2Op::getIntrinsicID() { - bool hasRelu = getRelu(); - bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE); +NVVM::IDArgPair +ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op, + LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + static constexpr llvm::Intrinsic::ID rndRNIds[] = { + llvm::Intrinsic::nvvm_ff2bf16x2_rn, + llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu, + llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite, + llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite, + }; + static constexpr llvm::Intrinsic::ID rndRZIds[] = { + llvm::Intrinsic::nvvm_ff2bf16x2_rz, + llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu, + llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite, + llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite, + }; + static constexpr llvm::Intrinsic::ID rndRSIds[] = { + llvm::Intrinsic::nvvm_ff2bf16x2_rs, + llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu, + llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite, + llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite, + }; - if (hasRelu && hasSatFinite) - return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite; - if (hasRelu) - return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu; - if (hasSatFinite) - return llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite; - return llvm::Intrinsic::nvvm_ff2bf16x2_rs; + unsigned hasRelu = op.getRelu() ? 1 : 0; + unsigned hasSatFinite = + (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0; + // idx: bit-0 - relu + // bit-1 - satfinite + unsigned idx = (hasSatFinite << 1) | hasRelu; + + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(op.getSrcHi())); + args.push_back(mt.lookupValue(op.getSrcLo())); + if (op.getRandomBits()) + args.push_back(mt.lookupValue(op.getRandomBits())); + + switch (op.getRnd()) { + case FPRoundingMode::RN: + return {rndRNIds[idx], std::move(args)}; + case FPRoundingMode::RZ: + return {rndRZIds[idx], std::move(args)}; + case FPRoundingMode::RS: + return {rndRSIds[idx], std::move(args)}; + default: + llvm_unreachable("Invalid rounding mode for ConvertF32x2ToBF16x2Op"); + } } llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() { @@ -3010,6 +4083,630 @@ NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs( return {intrinsicID, args}; } +mlir::NVVM::IDArgPair +PermuteOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::PermuteOp>(op); + NVVM::PermuteMode mode = thisOp.getMode(); + + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_prmt, llvm::Intrinsic::nvvm_prmt_f4e, + llvm::Intrinsic::nvvm_prmt_b4e, llvm::Intrinsic::nvvm_prmt_rc8, + llvm::Intrinsic::nvvm_prmt_ecl, llvm::Intrinsic::nvvm_prmt_ecr, + llvm::Intrinsic::nvvm_prmt_rc16}; + + unsigned modeIndex = static_cast<unsigned>(mode); + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(thisOp.getLo())); + + // Only first 3 modes (Default, f4e, b4e) need the hi operand. + if (modeIndex < 3) + args.push_back(mt.lookupValue(thisOp.getHi())); + + args.push_back(mt.lookupValue(thisOp.getSelector())); + + return {IDs[modeIndex], args}; +} + +//===----------------------------------------------------------------------===// +// NVVM tcgen05.mma functions +//===----------------------------------------------------------------------===// + +mlir::NVVM::IDArgPair +Tcgen05MMAOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + + auto thisOp = cast<NVVM::Tcgen05MMAOp>(op); + llvm::SmallVector<llvm::Value *> args; + + args.push_back(mt.lookupValue(thisOp.getMatrixD())); + + llvm::Value *A = mt.lookupValue(thisOp.getMatrixA()); + const bool isATensor = isa<llvm::PointerType>(A->getType()); + args.push_back(A); + + args.push_back(mt.lookupValue(thisOp.getMatrixB())); + args.push_back(mt.lookupValue(thisOp.getIdesc())); + args.push_back(mt.lookupValue(thisOp.getEnableInputD())); + + using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>; + using CtaGroupArray = std::array<EnableAShiftArray, 2>; + using IsATensorArray = std::array<CtaGroupArray, 2>; + using HasScaleInputDArray = std::array<IsATensorArray, 2>; + using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>; + + // [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift] + static constexpr HasDisableOutputLaneArray tcgen05MMAIDs = { + { // without diable output lane + {{// without scale input D + {{ + // shared + {{// cg1 + {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic}, + // cg2 + {llvm::Intrinsic::nvvm_tcgen05_mma_shared, notIntrinsic}}}, + {{// tensor + { + // cg1 + llvm::Intrinsic::nvvm_tcgen05_mma_tensor, + llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift, + }, + { + // cg2 + llvm::Intrinsic::nvvm_tcgen05_mma_tensor, + llvm::Intrinsic::nvvm_tcgen05_mma_tensor_ashift, + }}}, + }}, + // with scale input D + {{ // shared + {{// cg1 + {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic}, + // cg2 + {llvm::Intrinsic::nvvm_tcgen05_mma_shared_scale_d, notIntrinsic}}}, + {{// tensor + { + // cg1 + llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d, + llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift, + }, + { + // cg2 + llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d, + llvm::Intrinsic::nvvm_tcgen05_mma_tensor_scale_d_ashift, + }}}}}}}, + // with disable output lane + {{ // without scale input D + {{ // shared + {{// cg1 + {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1, + notIntrinsic}, + // cg2 + {llvm::Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2, + notIntrinsic}}}, + {{// cg1 + { + llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_disable_output_lane_cg1, + llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_ashift, + }, + // cg2 + { + llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_disable_output_lane_cg2, + llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_disable_output_lane_cg2_ashift, + }}}}}, + // with scale input D + {{ // shared + {{// cg1 + {llvm::Intrinsic:: + nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1, + notIntrinsic}, + // cg2 + {llvm::Intrinsic:: + nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg2, + notIntrinsic}}}, + // tensor + {{// cg1 + {llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1, + llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg1_ashift}, + // cg2 + { + llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2, + llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_scale_d_disable_output_lane_cg2_ashift, + }}}}}}}}}; + + llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD()); + bool hasScaleInputD = ScaleInputD != nullptr; + + llvm::Value *DisableOutputLane = + mt.lookupValue(thisOp.getDisableOutputLane()); + bool hasDisableOutputLane = DisableOutputLane != nullptr; + + const unsigned ctaGroup = + static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup())); + + llvm::Intrinsic::ID ID = + tcgen05MMAIDs[hasDisableOutputLane][hasScaleInputD][isATensor] + [ctaGroup - 1][thisOp.getAShift()]; + + assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMAOp."); + + if (hasScaleInputD) + args.push_back(ScaleInputD); + + if (hasDisableOutputLane) + args.push_back(DisableOutputLane); + + args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind()))); + + if (!hasDisableOutputLane) + args.push_back(builder.getInt32(ctaGroup)); + + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp()))); + + return {ID, args}; +} + +static LogicalResult +verifyTcgen05MMAOp(bool isATensor, mlir::Value disableOutputLane, + NVVM::CTAGroupKind ctaGroup, bool hasAShift, + NVVM::Tcgen05MMACollectorOp collectorOp, Location loc) { + + if (disableOutputLane) { + mlir::VectorType disableOutputLaneType = + cast<mlir::VectorType>(disableOutputLane.getType()); + if ((ctaGroup == NVVM::CTAGroupKind::CTA_1 && + disableOutputLaneType.getNumElements() != 4) || + (ctaGroup == NVVM::CTAGroupKind::CTA_2 && + disableOutputLaneType.getNumElements() != 8)) + return emitError(loc) << "Disable Output Lane of length " + << disableOutputLaneType.getNumElements() + << " is incompatible with CtaGroupAttr"; + } + + if (hasAShift && !isATensor) + return emitError( + loc, "A-shift can be applied only when matrix A is in tensor memory"); + + if (hasAShift == true && (collectorOp == Tcgen05MMACollectorOp::FILL || + collectorOp == Tcgen05MMACollectorOp::USE)) + return emitError( + loc, "Cannot use collector buffer operation fill or use with ashift"); + + return success(); +} + +LogicalResult Tcgen05MMAOp::verify() { + return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()), + getDisableOutputLane(), getCtaGroup(), getAShift(), + getCollectorOp(), getLoc()); +} + +//===----------------------------------------------------------------------===// +// NVVM tcgen05.mma.sp functions +//===----------------------------------------------------------------------===// + +mlir::NVVM::IDArgPair Tcgen05MMASparseOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + + auto thisOp = cast<NVVM::Tcgen05MMASparseOp>(op); + llvm::SmallVector<llvm::Value *> args; + + args.push_back(mt.lookupValue(thisOp.getMatrixD())); + + llvm::Value *A = mt.lookupValue(thisOp.getMatrixA()); + bool isATensor = isa<llvm::PointerType>(A->getType()); + args.push_back(A); + + args.push_back(mt.lookupValue(thisOp.getMatrixB())); + args.push_back(mt.lookupValue(thisOp.getIdesc())); + args.push_back(mt.lookupValue(thisOp.getEnableInputD())); + args.push_back(mt.lookupValue(thisOp.getSparseMetadata())); + + using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>; + using CtaGroupArray = std::array<EnableAShiftArray, 2>; + using IsATensorArray = std::array<CtaGroupArray, 2>; + using HasScaleInputDArray = std::array<IsATensorArray, 2>; + using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>; + + // [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift] + static constexpr HasDisableOutputLaneArray tcgen05MMASparseIDs = { + { // without diable output lane + {{// without scale input D + {{ + // shared + {{// cg1 + {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic}, + // cg2 + {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared, notIntrinsic}}}, + {{// tensor + { + // cg1 + llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor, + llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift, + }, + { + // cg2 + llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor, + llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_ashift, + }}}, + }}, + // with scale input D + {{ // shared + {{// cg1 + {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d, + notIntrinsic}, + // cg2 + {llvm::Intrinsic::nvvm_tcgen05_mma_sp_shared_scale_d, + notIntrinsic}}}, + {{// tensor + { + // cg1 + llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d, + llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift, + }, + { + // cg2 + llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d, + llvm::Intrinsic::nvvm_tcgen05_mma_sp_tensor_scale_d_ashift, + }}}}}}}, + // with disable output lane + {{ // without scale input D + {{ // shared + {{// cg1 + {llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg1, + notIntrinsic}, + // cg2 + {llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_disable_output_lane_cg2, + notIntrinsic}}}, + {{// cg1 + { + llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1, + llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg1_ashift, + }, + // cg2 + { + llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2, + llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_disable_output_lane_cg2_ashift, + }}}}}, + // with scale input D + {{ // shared + {{// cg1 + {llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg1, + notIntrinsic}, + // cg2 + {llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_scale_d_disable_output_lane_cg2, + notIntrinsic}}}, + // tensor + {{// cg1 + {llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1, + llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg1_ashift}, + // cg2 + { + llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2, + llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift, + }}}}}}}}}; + + llvm::Value *ScaleInputD = mt.lookupValue(thisOp.getScaleInputD()); + bool hasScaleInputD = ScaleInputD != nullptr; + + llvm::Value *DisableOutputLane = + mt.lookupValue(thisOp.getDisableOutputLane()); + bool hasDisableOutputLane = DisableOutputLane != nullptr; + + unsigned ctaGroup = + static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup())); + + llvm::Intrinsic::ID ID = + tcgen05MMASparseIDs[hasDisableOutputLane][hasScaleInputD][isATensor] + [ctaGroup - 1][thisOp.getAShift()]; + + assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMASparseOp."); + + if (hasScaleInputD) + args.push_back(ScaleInputD); + + if (hasDisableOutputLane) + args.push_back(DisableOutputLane); + + args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind()))); + + if (!hasDisableOutputLane) + args.push_back(builder.getInt32(ctaGroup)); + + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp()))); + + return {ID, args}; +} + +LogicalResult Tcgen05MMASparseOp::verify() { + return verifyTcgen05MMAOp(isa<LLVM::LLVMPointerType>(getMatrixA().getType()), + getDisableOutputLane(), getCtaGroup(), getAShift(), + getCollectorOp(), getLoc()); +} + +//===----------------------------------------------------------------------===// +// NVVM tcgen05.mma.block_scale functions +//===----------------------------------------------------------------------===// + +mlir::NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + + auto thisOp = cast<NVVM::Tcgen05MMABlockScaleOp>(op); + llvm::SmallVector<llvm::Value *> args; + + args.push_back(mt.lookupValue(thisOp.getMatrixD())); + + llvm::Value *A = mt.lookupValue(thisOp.getMatrixA()); + bool isATensor = isa<llvm::PointerType>(A->getType()); + args.push_back(A); + + args.push_back(mt.lookupValue(thisOp.getMatrixB())); + args.push_back(mt.lookupValue(thisOp.getIdesc())); + args.push_back(mt.lookupValue(thisOp.getEnableInputD())); + args.push_back(mt.lookupValue(thisOp.getScaleA())); + args.push_back(mt.lookupValue(thisOp.getScaleB())); + args.push_back(builder.getInt32( + static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup())))); + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp()))); + + auto kind = thisOp.getKind(); + auto blockScale = thisOp.getBlockScale(); + llvm::Intrinsic::ID ID = [&]() { + if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF8F6F4) { + if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) { + return isATensor ? llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale + : llvm::Intrinsic:: + nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale; + } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) { + return isATensor + ? llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale_block32 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale_block32; + } + } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4) { + if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) { + return isATensor + ? llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf4_block_scale + : llvm::Intrinsic::nvvm_tcgen05_mma_shared_mxf4_block_scale; + } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) { + return isATensor ? llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_mxf4_block_scale_block32 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_shared_mxf4_block_scale_block32; + } + } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4NVF4) { + if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) { + return isATensor + ? llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block32 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block32; + + } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) { + return isATensor + ? llvm::Intrinsic:: + nvvm_tcgen05_mma_tensor_mxf4nvf4_block_scale_block16 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_shared_mxf4nvf4_block_scale_block16; + } + } + llvm_unreachable("Invalid tcgen05.mma.block_scale attributes"); + }(); + + return {ID, args}; +} + +static LogicalResult +verifyTcgen05MMABlockScaleOp(NVVM::Tcgen05MMACollectorOp collectorOp, + NVVM::Tcgen05MMABlockScaleKind kind, + NVVM::Tcgen05MMABlockScale blockScale, + Location loc) { + + if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT && + kind == Tcgen05MMABlockScaleKind::MXF4NVF4) + return emitError(loc, "mxf4nvf4 requires block scale attribute"); + + if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16 && + kind != Tcgen05MMABlockScaleKind::MXF4NVF4) + return emitError(loc, + llvm::formatv("{} kind does not support block16 attribute", + stringifyEnum(kind))); + + return success(); +} + +LogicalResult Tcgen05MMABlockScaleOp::verify() { + return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(), + getBlockScale(), getLoc()); +} + +//===----------------------------------------------------------------------===// +// NVVM tcgen05.mma.sp.block_scale functions +//===----------------------------------------------------------------------===// + +mlir::NVVM::IDArgPair Tcgen05MMASparseBlockScaleOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + + auto thisOp = cast<NVVM::Tcgen05MMASparseBlockScaleOp>(op); + llvm::SmallVector<llvm::Value *> args; + + args.push_back(mt.lookupValue(thisOp.getMatrixD())); + + llvm::Value *A = mt.lookupValue(thisOp.getMatrixA()); + bool isATensor = isa<llvm::PointerType>(A->getType()); + args.push_back(A); + + args.push_back(mt.lookupValue(thisOp.getMatrixB())); + args.push_back(mt.lookupValue(thisOp.getIdesc())); + args.push_back(mt.lookupValue(thisOp.getEnableInputD())); + args.push_back(mt.lookupValue(thisOp.getSparseMetadata())); + args.push_back(mt.lookupValue(thisOp.getScaleA())); + args.push_back(mt.lookupValue(thisOp.getScaleB())); + args.push_back(builder.getInt32( + static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup())))); + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp()))); + + auto kind = thisOp.getKind(); + auto blockScale = thisOp.getBlockScale(); + llvm::Intrinsic::ID ID = [&]() { + if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF8F6F4) { + if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) { + return isATensor ? llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale + : llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale; + } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) { + return isATensor + ? llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale_block32 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale_block32; + } + } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4) { + if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) { + return isATensor ? llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale + : llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_mxf4_block_scale; + } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) { + return isATensor + ? llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale_block32 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_mxf4_block_scale_block32; + } + } else if (kind == NVVM::Tcgen05MMABlockScaleKind::MXF4NVF4) { + if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) { + return isATensor + ? llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block32 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block32; + + } else if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16) { + return isATensor + ? llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_tensor_mxf4nvf4_block_scale_block16 + : llvm::Intrinsic:: + nvvm_tcgen05_mma_sp_shared_mxf4nvf4_block_scale_block16; + } + } + llvm_unreachable("Invalid tcgen05.mma.sp.block_scale attributes"); + }(); + + return {ID, args}; +} + +LogicalResult Tcgen05MMASparseBlockScaleOp::verify() { + return verifyTcgen05MMABlockScaleOp(getCollectorOp(), getKind(), + getBlockScale(), getLoc()); +} + +//===----------------------------------------------------------------------===// +// NVVM tcgen05.mma.ws functions +//===----------------------------------------------------------------------===// + +mlir::NVVM::IDArgPair Tcgen05MMAWsOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + + auto thisOp = cast<NVVM::Tcgen05MMAWsOp>(op); + llvm::SmallVector<llvm::Value *> args; + + args.push_back(mt.lookupValue(thisOp.getMatrixD())); + + llvm::Value *A = mt.lookupValue(thisOp.getMatrixA()); + bool isATensor = isa<llvm::PointerType>(A->getType()); + args.push_back(A); + + args.push_back(mt.lookupValue(thisOp.getMatrixB())); + args.push_back(mt.lookupValue(thisOp.getIdesc())); + args.push_back(mt.lookupValue(thisOp.getEnableInputD())); + + mlir::Value ZeroColMask = thisOp.getZeroColMask(); + llvm::Intrinsic::ID ID = notIntrinsic; + if (ZeroColMask) { + args.push_back(mt.lookupValue(ZeroColMask)); + ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor_zero_col_mask + : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared_zero_col_mask; + } else + ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_tensor + : llvm::Intrinsic::nvvm_tcgen05_mma_ws_shared; + + args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind()))); + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer()))); + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp()))); + + return {ID, args}; +} + +//===----------------------------------------------------------------------===// +// NVVM tcgen05.mma.ws.sp functions +//===----------------------------------------------------------------------===// + +mlir::NVVM::IDArgPair Tcgen05MMAWsSparseOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + + auto thisOp = cast<NVVM::Tcgen05MMAWsSparseOp>(op); + llvm::SmallVector<llvm::Value *> args; + + args.push_back(mt.lookupValue(thisOp.getMatrixD())); + + llvm::Value *A = mt.lookupValue(thisOp.getMatrixA()); + bool isATensor = isa<llvm::PointerType>(A->getType()); + args.push_back(A); + + args.push_back(mt.lookupValue(thisOp.getMatrixB())); + args.push_back(mt.lookupValue(thisOp.getIdesc())); + args.push_back(mt.lookupValue(thisOp.getEnableInputD())); + args.push_back(mt.lookupValue(thisOp.getSparseMetadata())); + + mlir::Value ZeroColMask = thisOp.getZeroColMask(); + llvm::Intrinsic::ID ID = notIntrinsic; + if (ZeroColMask) { + args.push_back(mt.lookupValue(ZeroColMask)); + ID = isATensor + ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor_zero_col_mask + : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared_zero_col_mask; + } else + ID = isATensor ? llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_tensor + : llvm::Intrinsic::nvvm_tcgen05_mma_ws_sp_shared; + + args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind()))); + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorBBuffer()))); + args.push_back( + builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp()))); + + return {ID, args}; +} + //===----------------------------------------------------------------------===// // NVVMDialect initialization, type parsing, and registration. //===----------------------------------------------------------------------===// @@ -3213,16 +4910,20 @@ LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) { "Minimum NVVM target SM version is sm_20"); } - gpuModuleOp->walk([&](Operation *op) { - if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) { - const NVVMCheckSMVersion requirement = reqOp.getRequiredMinSMVersion(); - if (!requirement.isCompatibleWith(targetSMVersion)) { - op->emitOpError() << "is not supported on " << getChip(); - return WalkResult::interrupt(); - } - } - return WalkResult::advance(); - }); + if (gpuModuleOp + ->walk([&](Operation *op) { + if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) { + const NVVMCheckSMVersion requirement = + reqOp.getRequiredMinSMVersion(); + if (!requirement.isCompatibleWith(targetSMVersion)) { + op->emitOpError() << "is not supported on " << getChip(); + return WalkResult::interrupt(); + } + } + return WalkResult::advance(); + }) + .wasInterrupted()) + return failure(); return success(); } diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp index 67573c4..12dd225 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp @@ -109,8 +109,12 @@ static Location getNestedLoc(Operation *op, LLVM::DIScopeAttr scopeAttr, return FusedLoc::get(context, {loc}, lexicalBlockFileAttr); } +/// Adds DILexicalBlockFileAttr for operations with CallSiteLoc and operations +/// from different files than their containing function. static void setLexicalBlockFileAttr(Operation *op) { - if (auto callSiteLoc = dyn_cast<CallSiteLoc>(op->getLoc())) { + Location opLoc = op->getLoc(); + + if (auto callSiteLoc = dyn_cast<CallSiteLoc>(opLoc)) { auto callerLoc = callSiteLoc.getCaller(); auto calleeLoc = callSiteLoc.getCallee(); LLVM::DIScopeAttr scopeAttr; @@ -122,6 +126,45 @@ static void setLexicalBlockFileAttr(Operation *op) { op->setLoc( CallSiteLoc::get(getNestedLoc(op, scopeAttr, calleeLoc), callerLoc)); } + + return; + } + + auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>(); + if (!funcOp) + return; + + FileLineColLoc opFileLoc = extractFileLoc(opLoc); + if (!opFileLoc) + return; + + FileLineColLoc funcFileLoc = extractFileLoc(funcOp.getLoc()); + if (!funcFileLoc) + return; + + StringRef opFile = opFileLoc.getFilename().getValue(); + StringRef funcFile = funcFileLoc.getFilename().getValue(); + + // Handle cross-file operations: add DILexicalBlockFileAttr when the + // operation's source file differs from its containing function. + if (opFile != funcFile) { + auto funcOpLoc = llvm::dyn_cast_if_present<FusedLoc>(funcOp.getLoc()); + if (!funcOpLoc) + return; + auto scopeAttr = dyn_cast<LLVM::DISubprogramAttr>(funcOpLoc.getMetadata()); + if (!scopeAttr) + return; + + auto *context = op->getContext(); + LLVM::DIFileAttr opFileAttr = + LLVM::DIFileAttr::get(context, llvm::sys::path::filename(opFile), + llvm::sys::path::parent_path(opFile)); + + LLVM::DILexicalBlockFileAttr lexicalBlockFileAttr = + LLVM::DILexicalBlockFileAttr::get(context, scopeAttr, opFileAttr, 0); + + Location newLoc = FusedLoc::get(context, {opLoc}, lexicalBlockFileAttr); + op->setLoc(newLoc); } } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index dcc1ef9..b4b1347 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -1057,12 +1057,15 @@ LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) { // FillOpInterface implementation //===----------------------------------------------------------------------===// +namespace { enum class MatchFillResult { Success = 0, NotLinalgOp, WrongNumOperands, - NotScalarInput + NotScalarInput, + TypeMismatch }; +} // namespace static MatchFillResult isFillInterfaceImpl(Operation *op) { auto linalgOp = dyn_cast<linalg::LinalgOp>(op); @@ -1075,17 +1078,33 @@ static MatchFillResult isFillInterfaceImpl(Operation *op) { if (!linalgOp.isScalar(value)) return MatchFillResult::NotScalarInput; + // Check that the scalar input type matches the output element type. + OpOperand *output = linalgOp.getDpsInitOperand(0); + Type scalarType = value->get().getType(); + Type outputElementType = getElementTypeOrSelf(output->get().getType()); + if (scalarType != outputElementType) + return MatchFillResult::TypeMismatch; + return MatchFillResult::Success; } LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) { - auto res = isFillInterfaceImpl(op); + MatchFillResult res = isFillInterfaceImpl(op); if (res == MatchFillResult::NotLinalgOp) return op->emitError("expected a LinalgOp"); if (res == MatchFillResult::WrongNumOperands) return op->emitError("expected op with 1 input and 1 output"); if (res == MatchFillResult::NotScalarInput) return op->emitError("expected op with scalar input"); + if (res == MatchFillResult::TypeMismatch) { + auto linalgOp = cast<linalg::LinalgOp>(op); + Type scalarType = linalgOp.getDpsInputOperand(0)->get().getType(); + Type outputElementType = + getElementTypeOrSelf(linalgOp.getDpsInitOperand(0)->get().getType()); + return op->emitOpError("expected fill value type (") + << scalarType << ") to match output element type (" + << outputElementType << ")"; + } return success(); } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 3dc45ed..33ec79b 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1338,8 +1338,6 @@ Speculation::Speculatability GenericOp::getSpeculatability() { return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation())); } -LogicalResult GenericOp::verify() { return success(); } - namespace { /// Remove linalg operations that are just copying the values from inputs to @@ -2091,7 +2089,7 @@ LogicalResult TransposeOp::fold(FoldAdaptor adaptor, return failure(); // Single dimension transpose. - if (getPermutation().size() == 0) { + if (getPermutation().empty()) { result.push_back(getInput()); return success(); } @@ -4885,13 +4883,6 @@ void ElementwiseOp::print(OpAsmPrinter &p) { elidedAttrs); } -LogicalResult ElementwiseOp::verify() { - // All necessary checks are done either by - // - EnumAttr (e.g. unknown operation kind) - // - verifyStructuredOpInterface (incorrect map, sizes). - return success(); -} - /// Implements the block region builder for the ElementwiseOp. This is called by /// 'fillStructuredOpRegion'. void ElementwiseOp::regionBuilder( diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index aa82063..b8c1bad 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -176,7 +176,8 @@ static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults( if (auto attr = dyn_cast<Attribute>(paramOrHandle)) { reified.push_back(cast<IntegerAttr>(attr).getInt()); continue; - } else if (isa<ParamType>(cast<Value>(paramOrHandle).getType())) { + } + if (isa<ParamType>(cast<Value>(paramOrHandle).getType())) { ArrayRef<Attribute> params = state.getParams(cast<Value>(paramOrHandle)); if (params.size() != 1) return transformOp.emitSilenceableError() << "expected a single param"; diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index 22690da..9e6c1e6 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -747,8 +747,7 @@ struct RankReducedExtractSliceOp SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes(); auto rankReducedType = cast<RankedTensorType>( tensor::ExtractSliceOp::inferCanonicalRankReducedResultType( - reassociation->size(), sliceOp.getSourceType(), offsets, sizes, - strides)); + reassociation->size(), sliceOp.getSourceType(), sizes)); Location loc = sliceOp.getLoc(); Value newSlice = tensor::ExtractSliceOp::create( diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 05fc7cb..421ab5e 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1038,6 +1038,62 @@ private: ControlFusionFn controlFoldingReshapes; }; +/// Carries information about a padded dimension. +struct PadDimInfo { + // The resulting shape after padding each dimension. + SmallVector<int64_t> paddedShape; + + // Low and high padding amounts for each dimension. + SmallVector<OpFoldResult> lowPad; + SmallVector<OpFoldResult> highPad; +}; + +/// Computes the expanded padding information for the given pad operation based +/// on the provided expanded shape and reassociation indices. Returns a list of +/// PadDimInfo containing the low and high padding amounts and the padded +/// size for each dimension, or failure if the expansion is not possible. +static FailureOr<PadDimInfo> +computeExpandedPadding(tensor::PadOp padOp, ArrayRef<int64_t> expandedShape, + ArrayRef<ReassociationIndices> reassociations, + PatternRewriter &rewriter) { + // If the padding value depends on the index values of the pad operation, + // then it may not be valid to expand the dimensions, since it will change + // the index values on which the padding value depends. This is not currently + // supported by the pad expansion patterns, but it could be implemented + // similarly to the expansion of linalg.generic ops with linalg.index ops in + // the body, as is done in `updateExpandedGenericOpRegion`. + if (!padOp.getConstantPaddingValue()) + return failure(); + + // Expanded dimensions cannot have padding because the resulting padding may + // not be representable by a tensor.pad op. There are some special cases where + // it is possible (like expanding unit dims), but supporting these cases is + // NYI, so disallow it for now. + ArrayRef<int64_t> low = padOp.getStaticLow(); + ArrayRef<int64_t> high = padOp.getStaticHigh(); + for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) { + if (reInd.size() != 1 && (l != 0 || h != 0)) + return failure(); + } + + SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad()); + SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad()); + ArrayRef<int64_t> paddedShape = padOp.getResultType().getShape(); + PadDimInfo padDimInfo; + padDimInfo.paddedShape.assign(expandedShape); + padDimInfo.lowPad.assign(expandedShape.size(), rewriter.getIndexAttr(0)); + padDimInfo.highPad.assign(expandedShape.size(), rewriter.getIndexAttr(0)); + for (auto [idx, reInd] : llvm::enumerate(reassociations)) { + if (reInd.size() == 1) { + padDimInfo.paddedShape[reInd[0]] = paddedShape[idx]; + padDimInfo.lowPad[reInd[0]] = mixedLowPad[idx]; + padDimInfo.highPad[reInd[0]] = mixedHighPad[idx]; + } + } + + return padDimInfo; +} + class FoldPadWithProducerReshapeOpByExpansion : public OpRewritePattern<tensor::PadOp> { public: @@ -1053,46 +1109,96 @@ public: padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>(); if (!reshapeOp) return failure(); - if (!reshapeOp->hasOneUse()) - return failure(); if (!controlFoldingReshapes(&padOp.getSourceMutable())) { return rewriter.notifyMatchFailure(padOp, "fusion blocked by control function"); } - ArrayRef<int64_t> low = padOp.getStaticLow(); - ArrayRef<int64_t> high = padOp.getStaticHigh(); + RankedTensorType expandedType = reshapeOp.getSrcType(); SmallVector<ReassociationIndices> reassociations = reshapeOp.getReassociationIndices(); + FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding( + padOp, expandedType.getShape(), reassociations, rewriter); + if (failed(maybeExpandedPadding)) + return failure(); + PadDimInfo &expandedPadding = maybeExpandedPadding.value(); - for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) { - if (reInd.size() != 1 && (l != 0 || h != 0)) - return failure(); + Location loc = padOp->getLoc(); + RankedTensorType expandedPaddedType = + padOp.getResultType().clone(expandedPadding.paddedShape); + + auto newPadOp = tensor::PadOp::create( + rewriter, loc, expandedPaddedType, reshapeOp.getSrc(), + expandedPadding.lowPad, expandedPadding.highPad, + padOp.getConstantPaddingValue(), padOp.getNofold()); + + rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>( + padOp, padOp.getResultType(), newPadOp.getResult(), reassociations); + + return success(); + } + +private: + ControlFusionFn controlFoldingReshapes; +}; + +class FoldReshapeWithProducerPadOpByExpansion + : public OpRewritePattern<tensor::ExpandShapeOp> { +public: + FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context, + ControlFusionFn foldReshapes, + PatternBenefit benefit = 1) + : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit), + controlFoldingReshapes(std::move(foldReshapes)) {} + + LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp, + PatternRewriter &rewriter) const override { + tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>(); + if (!padOp) + return failure(); + + if (!controlFoldingReshapes(&expandOp.getSrcMutable())) { + return rewriter.notifyMatchFailure(expandOp, + "fusion blocked by control function"); } - SmallVector<OpFoldResult> newLow, newHigh; - RankedTensorType expandedType = reshapeOp.getSrcType(); - RankedTensorType paddedType = padOp.getResultType(); - SmallVector<int64_t> expandedPaddedShape(expandedType.getShape()); + RankedTensorType expandedType = expandOp.getResultType(); + SmallVector<ReassociationIndices> reassociations = + expandOp.getReassociationIndices(); + FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding( + padOp, expandedType.getShape(), reassociations, rewriter); + if (failed(maybeExpandedPadding)) + return failure(); + PadDimInfo &expandedPadding = maybeExpandedPadding.value(); + + Location loc = expandOp->getLoc(); + SmallVector<OpFoldResult> newExpandedSizes = expandOp.getMixedOutputShape(); + SmallVector<int64_t> newExpandedShape(expandedType.getShape()); + rewriter.setInsertionPointAfterValue(padOp.getSource()); + SmallVector<OpFoldResult> padSrcSizes = + tensor::getMixedSizes(rewriter, loc, padOp.getSource()); for (auto [idx, reInd] : llvm::enumerate(reassociations)) { + // We know that any reassociation with multiple dims is not padded because + // of the requirements of computeExpandedPadding. if (reInd.size() == 1) { - expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx]; - } - for (size_t i = 0; i < reInd.size(); ++i) { - newLow.push_back(padOp.getMixedLowPad()[idx]); - newHigh.push_back(padOp.getMixedHighPad()[idx]); + newExpandedShape[reInd[0]] = padOp.getSourceType().getDimSize(idx); + newExpandedSizes[reInd[0]] = padSrcSizes[idx]; } } - - Location loc = padOp->getLoc(); - RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape); + RankedTensorType newExpandedType = expandedType.clone(newExpandedShape); + auto newExpandOp = tensor::ExpandShapeOp::create( + rewriter, loc, newExpandedType, padOp.getSource(), reassociations, + newExpandedSizes); + RankedTensorType expandedPaddedType = + padOp.getResultType().clone(expandedPadding.paddedShape); + rewriter.setInsertionPoint(expandOp); auto newPadOp = tensor::PadOp::create( - rewriter, loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh, + rewriter, loc, expandedPaddedType, newExpandOp.getResult(), + expandedPadding.lowPad, expandedPadding.highPad, padOp.getConstantPaddingValue(), padOp.getNofold()); - rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>( - padOp, padOp.getResultType(), newPadOp.getResult(), reassociations); + rewriter.replaceOp(expandOp, newPadOp.getResult()); return success(); } @@ -1921,6 +2027,62 @@ private: ControlFusionFn controlFoldingReshapes; }; +/// Computes the collapsed padding information for the given pad operation based +/// on the provided collapsed shape and reassociation indices. Returns a +/// PadDimInfo containing the low and high padding amounts and the collapsed +/// shape for each dimension, or failure if the collapse is not possible. +static FailureOr<PadDimInfo> +computeCollapsedPadding(tensor::PadOp padOp, + ArrayRef<ReassociationIndices> reassociations, + PatternRewriter &rewriter) { + // If the padding value depends on the index values of the pad operation, + // then it may not be valid to collapse the dimensions, since it will change + // the index values on which the padding value depends. This is not currently + // supported by the pad collapsing patterns, but it could be implemented + // similarly to the collapsing of linalg.generic ops with linalg.index ops in + // the body, as is done in `generateCollapsedIndexingRegion`. + if (!padOp.getConstantPaddingValue()) + return failure(); + + // Collapsed dimensions cannot have padding because this can produce strided + // padding that isn't representable by a tensor.pad op. There are some special + // cases where it is possible (like collapsing unit dims), but supporting + // these cases is NYI, so disallow it for now. + ArrayRef<int64_t> low = padOp.getStaticLow(); + ArrayRef<int64_t> high = padOp.getStaticHigh(); + for (auto [idx, reInd] : llvm::enumerate(reassociations)) { + for (int64_t dim : reInd) { + if ((low[dim] != 0 || high[dim] != 0) && reInd.size() != 1) + return failure(); + } + } + + // Initialize padding values for collapsed tensors with zeros + ArrayRef<int64_t> expandedPaddedShape = padOp.getType().getShape(); + PadDimInfo padDimInfo; + padDimInfo.lowPad.assign(reassociations.size(), rewriter.getIndexAttr(0)); + padDimInfo.highPad.assign(reassociations.size(), rewriter.getIndexAttr(0)); + + // Update padding for dimensions that are not being collapsed, and compute + // the collapsed padded shape. + SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad()); + SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad()); + for (auto [idx, reInd] : llvm::enumerate(reassociations)) { + if (reInd.size() == 1) { + padDimInfo.lowPad[idx] = mixedLowPad[reInd[0]]; + padDimInfo.highPad[idx] = mixedHighPad[reInd[0]]; + } + SaturatedInteger collapsedSize = SaturatedInteger::wrap(1); + for (int64_t dim : reInd) { + collapsedSize = + collapsedSize * SaturatedInteger::wrap(expandedPaddedShape[dim]); + } + padDimInfo.paddedShape.push_back(collapsedSize.asInteger()); + } + + return padDimInfo; +} + class FoldPadWithProducerReshapeOpByCollapsing : public OpRewritePattern<tensor::PadOp> { public: @@ -1936,57 +2098,40 @@ public: padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>(); if (!reshapeOp) return failure(); - if (!reshapeOp->hasOneUse()) - return failure(); if (!controlFoldingReshapes(&padOp.getSourceMutable())) { return rewriter.notifyMatchFailure(padOp, "fusion blocked by control function"); } - ArrayRef<int64_t> low = padOp.getStaticLow(); - ArrayRef<int64_t> high = padOp.getStaticHigh(); SmallVector<ReassociationIndices> reassociations = reshapeOp.getReassociationIndices(); + FailureOr<PadDimInfo> maybeCollapsedPadding = + computeCollapsedPadding(padOp, reassociations, rewriter); + if (failed(maybeCollapsedPadding)) + return failure(); + PadDimInfo &collapsedPadding = maybeCollapsedPadding.value(); - for (auto reInd : reassociations) { - if (reInd.size() == 1) - continue; - if (llvm::any_of(reInd, [&](int64_t ind) { - return low[ind] != 0 || high[ind] != 0; - })) { - return failure(); - } - } - - SmallVector<OpFoldResult> newLow, newHigh; - RankedTensorType collapsedType = reshapeOp.getSrcType(); - RankedTensorType paddedType = padOp.getResultType(); - SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape()); - SmallVector<OpFoldResult> expandedPaddedSizes( - getMixedValues(reshapeOp.getStaticOutputShape(), - reshapeOp.getOutputShape(), rewriter)); + SmallVector<OpFoldResult> expandedPaddedSizes = + reshapeOp.getMixedOutputShape(); AffineExpr d0, d1, d2; bindDims(rewriter.getContext(), d0, d1, d2); auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2}); Location loc = reshapeOp->getLoc(); - for (auto [idx, reInd] : llvm::enumerate(reassociations)) { - OpFoldResult l = padOp.getMixedLowPad()[reInd[0]]; - OpFoldResult h = padOp.getMixedHighPad()[reInd[0]]; + for (auto [reInd, l, h] : + llvm::zip_equal(reassociations, collapsedPadding.lowPad, + collapsedPadding.highPad)) { if (reInd.size() == 1) { - collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]]; - OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply( + expandedPaddedSizes[reInd[0]] = affine::makeComposedFoldedAffineApply( rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]}); - expandedPaddedSizes[reInd[0]] = paddedSize; } - newLow.push_back(l); - newHigh.push_back(h); } RankedTensorType collapsedPaddedType = - paddedType.clone(collapsedPaddedShape); + padOp.getType().clone(collapsedPadding.paddedShape); auto newPadOp = tensor::PadOp::create( - rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh, + rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(), + collapsedPadding.lowPad, collapsedPadding.highPad, padOp.getConstantPaddingValue(), padOp.getNofold()); rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>( @@ -2000,6 +2145,52 @@ private: ControlFusionFn controlFoldingReshapes; }; +class FoldReshapeWithProducerPadOpByCollapsing + : public OpRewritePattern<tensor::CollapseShapeOp> { +public: + FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context, + ControlFusionFn foldReshapes, + PatternBenefit benefit = 1) + : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit), + controlFoldingReshapes(std::move(foldReshapes)) {} + + LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp, + PatternRewriter &rewriter) const override { + tensor::PadOp padOp = reshapeOp.getSrc().getDefiningOp<tensor::PadOp>(); + if (!padOp) + return failure(); + + if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) { + return rewriter.notifyMatchFailure(padOp, + "fusion blocked by control function"); + } + + SmallVector<ReassociationIndices> reassociations = + reshapeOp.getReassociationIndices(); + RankedTensorType collapsedPaddedType = reshapeOp.getResultType(); + FailureOr<PadDimInfo> maybeCollapsedPadding = + computeCollapsedPadding(padOp, reassociations, rewriter); + if (failed(maybeCollapsedPadding)) + return failure(); + PadDimInfo &collapsedPadding = maybeCollapsedPadding.value(); + + Location loc = reshapeOp->getLoc(); + auto newCollapseOp = tensor::CollapseShapeOp::create( + rewriter, loc, padOp.getSource(), reassociations); + + auto newPadOp = tensor::PadOp::create( + rewriter, loc, collapsedPaddedType, newCollapseOp.getResult(), + collapsedPadding.lowPad, collapsedPadding.highPad, + padOp.getConstantPaddingValue(), padOp.getNofold()); + + rewriter.replaceOp(reshapeOp, newPadOp.getResult()); + return success(); + } + +private: + ControlFusionFn controlFoldingReshapes; +}; + /// Pattern to collapse dimensions. template <typename LinalgType> class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> { @@ -2239,6 +2430,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns( controlFoldingReshapes); patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(), controlFoldingReshapes); + patterns.add<FoldReshapeWithProducerPadOpByExpansion>(patterns.getContext(), + controlFoldingReshapes); patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(), controlFoldingReshapes); } @@ -2250,6 +2443,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns( controlFoldingReshapes); patterns.add<FoldPadWithProducerReshapeOpByCollapsing>( patterns.getContext(), controlFoldingReshapes); + patterns.add<FoldReshapeWithProducerPadOpByCollapsing>( + patterns.getContext(), controlFoldingReshapes); patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(), controlFoldingReshapes); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp index 9974ccd..cbd6357 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp @@ -200,10 +200,10 @@ static void populateOpPayload( SmallVector<OpOperand *> newInputOperands = newOp.getDpsInputOperands(); updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos); - SmallVector<OpOperand *> origOutputOperands = llvm::to_vector(llvm::map_range( - genericOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; })); - SmallVector<OpOperand *> newOutputOperands = llvm::to_vector(llvm::map_range( - newOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; })); + SmallVector<OpOperand *> origOutputOperands = + llvm::to_vector(llvm::make_pointer_range(genericOp.getDpsInitsMutable())); + SmallVector<OpOperand *> newOutputOperands = + llvm::to_vector(llvm::make_pointer_range(newOp.getDpsInitsMutable())); updateReplacements(origOutputOperands, newOutputOperands, origOutsToNewOutsPos); diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp index 9436f1c..161d978 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -913,8 +913,7 @@ static Value replaceByPackingResult(RewriterBase &rewriter, llvm_unreachable("loop independence prerequisite not met"); // offsets = [maybe_leading_ivs = originalLoopIvs, 0 .. 0]. - std::copy(loopIterationCounts.begin(), loopIterationCounts.end(), - offsets.begin()); + llvm::copy(loopIterationCounts, offsets.begin()); hoistedPackedTensor = scf::getForInductionVarOwner(packingResult.clonedLoopIvs.front()) ->getResult(0); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index 40fc0d6..c2485a0 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -237,6 +237,69 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter, return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp); } +/// Utility to specialize a `genericOp` with a convolution op of type `ConvOpTy` +/// with `dilations` and `strides`. +template <typename ConvOpTy> +static FailureOr<LinalgOp> +specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp, + ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) { + SmallVector<Value> inputs = genericOp.getDpsInputs(); + ValueRange outputs = genericOp.getDpsInits(); + SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics() + ? TypeRange(ValueRange(outputs)) + : TypeRange{}; + LinalgOp namedOp; + // Ops with no dilations and no strides. + if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> || + std::is_same_v<ConvOpTy, linalg::Conv2DOp> || + std::is_same_v<ConvOpTy, linalg::Conv3DOp>) { + namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes, + inputs, outputs); + } else { + Attribute stridesAttr = rewriter.getI64TensorAttr(strides); + Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations); + namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>( + genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr); + } + return namedOp; +} + +/// Converts linalg.generic to named linalg.*conv/pooling* where possible. +static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter, + GenericOp genericOp) { + SmallVector<int64_t> dilations, strides; +#define CONV_OP_SPECIALIZER(ConvOpTy) \ + if (isaConvolutionOpOfType<ConvOpTy>(genericOp, &dilations, &strides)) \ + return specializeToConvOp<ConvOpTy>(rewriter, genericOp, dilations, \ + strides); \ + // ----------------------------- + // Convolution ops. + // ----------------------------- + CONV_OP_SPECIALIZER(linalg::Conv1DOp); + CONV_OP_SPECIALIZER(linalg::Conv1DNwcWcfOp); + CONV_OP_SPECIALIZER(linalg::Conv1DNcwFcwOp); + CONV_OP_SPECIALIZER(linalg::Conv2DOp); + CONV_OP_SPECIALIZER(linalg::Conv3DOp); + // ----------------------------- + // Depthwise Convolution ops. + // ----------------------------- + CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNcwCwOp); + CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcOp); + CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcmOp); + CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNchwChwOp); + CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcmOp); + // ----------------------------- + // Pooling ops. + // ----------------------------- + CONV_OP_SPECIALIZER(linalg::PoolingNhwcMaxOp); + CONV_OP_SPECIALIZER(linalg::PoolingNhwcMinOp); + CONV_OP_SPECIALIZER(linalg::PoolingNhwcSumOp); + CONV_OP_SPECIALIZER(linalg::PoolingNhwcMaxUnsignedOp); + CONV_OP_SPECIALIZER(linalg::PoolingNhwcMinUnsignedOp); +#undef CONV_OP_SPECIALIZER + return failure(); +} + } // namespace //===----------------------------------------------------------------------===// @@ -316,6 +379,11 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter, if (isaContractionOpInterface(genericOp)) { return specializeLinalgContractions(rewriter, genericOp); } + + // Convolution - e.g. *conv/pooling* + if (isaConvolutionOpInterface(genericOp)) { + return specializeLinalgConvolutions(rewriter, genericOp); + } return failure(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 705d6f2..8e14ef4 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -452,8 +452,7 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes, SmallVector<OpFoldResult> allShapeSizes = op.createFlatListOfOperandDims(b, op.getLoc()); AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap(); - if (!shapeSizesToLoopsMap) - return failure(); + assert(shapeSizesToLoopsMap && "invalid linalgOp with null ShapesToLoopsMap"); auto [loopRanges, loopIndexToRangeIndex] = makeTiledLoopRanges( b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes); diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index 8a0440b..50a84ac 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -167,7 +167,7 @@ struct LinalgOpTilingInterface llvm::zip_equal(indexingMap.getResults(), offsets, sizes)) { auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr); if (!dimExpr) - continue; + return failure(); unsigned position = dimExpr.getPosition(); auto it = mappedOffsets.find(position); if (it != mappedOffsets.end()) { @@ -357,6 +357,32 @@ struct LinalgOpTilingInterface /// Inline the op payload and store the result. return inlinePayload(builder, linalgOp, ivs, indexedValues); } + + bool isOpFusableWithConsumerSlice(Operation *op, unsigned resultNumber, + ArrayRef<OpFoldResult> offsets, + ArrayRef<OpFoldResult> sizes) const { + // The verifier gives all the necessary requirements for consumer fusion. + return true; + } + + bool isOpFusableWithProducerSlices( + Operation *op, ArrayRef<unsigned> operandNumbers, + ArrayRef<SmallVector<OpFoldResult>> allOffsets, + ArrayRef<SmallVector<OpFoldResult>> allSizes) const { + + auto linalgOp = cast<LinalgOp>(op); + SmallVector<AffineMap> indexingMaps = + llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) { + OpOperand &opOperand = linalgOp->getOpOperand(operandNumber); + return linalgOp.getMatchingIndexingMap(&opOperand); + }); + // Check that offsets/sizes are consistent across all operands. + OpBuilder b(op); + SmallVector<OpFoldResult> mappedOffsets, mappedSizes; + return succeeded(getMappedOffsetAndSize(linalgOp, b, indexingMaps, + allOffsets, allSizes, mappedOffsets, + mappedSizes)); + } }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 027268c..67e2b9f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -1167,12 +1167,9 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( "this is not supported ATM!"); } - Attribute zeroIdxAttr = rewriter.getIndexAttr(0); - Attribute oneIdxAttr = rewriter.getIndexAttr(1); Location loc = packOp.getLoc(); int64_t srcRank = packOp.getSourceRank(); - int64_t destRank = packOp.getDestRank(); // 1. Get the input that is going to be packed. If the input requires padding, // add a padding operation and return that as the input. @@ -1262,14 +1259,8 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( writeSizes.push_back(tileSizeOfr); } - // TODO: Add a constructor for tensor.insert_slice that doesn't require - // strides nor offsets. - SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr); - SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr); - auto insert = tensor::InsertSliceOp::create( - rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), - writeOffsets, writeSizes, writeStrides); + rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), writeSizes); // 4. Replace tensor.packOp with tensor.insert_slice created above rewriter.replaceOp(packOp, insert.getResult()); @@ -1279,7 +1270,6 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const { - int64_t srcRank = unpackOp.getSourceRank(); int64_t destRank = unpackOp.getDestRank(); ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape(); ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos(); @@ -1296,7 +1286,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( Value source = unpackOp.getSource(); DenseMap<int64_t, OpFoldResult> dimAndTileMapping = unpackOp.getDimAndTileMapping(); - Attribute zeroIdxAttr = rewriter.getIndexAttr(0); Attribute oneIdxAttr = rewriter.getIndexAttr(1); // The shape for ExtractSliceOp. Note that this will consist of 3 blocks of @@ -1307,9 +1296,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( // outer-tiled-dims being all 1), this will be // [ outer-untiled-dims, tile-sizes ] SmallVector<OpFoldResult> extractSliceSizes; - // The offset and strides attributes for ExtractSliceOp. - SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr); - SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr); // Shape for EmptyOp that's used as the init value for TransposeOp below. // This should be: @@ -1364,8 +1350,7 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( Type elemType = unpackOp.getSourceType().getElementType(); auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType); Value innerTile = tensor::ExtractSliceOp::create( - rewriter, loc, readType, unpackOp.getSource(), extractSliceOffsets, - extractSliceSizes, extractSliceStrides); + rewriter, loc, readType, unpackOp.getSource(), extractSliceSizes); // 2. Transpose the tile to match the outer corresponding tile order. SmallVector<int64_t> perm = getPackUnpackRankReducedPerm( @@ -1381,9 +1366,6 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( // 3. Handle in-complete tiles if needed. It truncates trailing data from the // transposed tile. - int numLoops = shapeForEmptyOp.size(); - SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr); - SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr); SmallVector<OpFoldResult> tileSizes; ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape(); for (auto i : llvm::seq<unsigned>(0, destRank)) { @@ -1393,13 +1375,11 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( } auto partialTile = - tensor::ExtractSliceOp::create(rewriter, loc, transposedOp.getResult()[0], - tileOffsets, tileSizes, tileStrides); + tensor::ExtractSliceOp::create(rewriter, loc, RankedTensorType(), + transposedOp.getResult()[0], tileSizes); // 4. Insert the result to the destination tensor. SmallVector<OpFoldResult> writeSizes; - SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr); - SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr); for (int i = 0, idx = 0; i < destRank; ++i) { if (dimAndTileMapping.count(i) || destShape[i] != 1) writeSizes.push_back(tileSizes[idx++]); @@ -1407,8 +1387,7 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( writeSizes.push_back(oneIdxAttr); } auto insert = tensor::InsertSliceOp::create(rewriter, loc, partialTile, - unpackOp.getDest(), writeOffsets, - writeSizes, writeStrides); + unpackOp.getDest(), writeSizes); rewriter.replaceOp(unpackOp, insert.getResult()); return success(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 19d2d85..bb3bccd 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -746,12 +746,12 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value, auto vectorType = state.getCanonicalVecType( getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap); + SmallVector<Value> indices(linalgOp.getRank(outputOperand), + arith::ConstantIndexOp::create(rewriter, loc, 0)); + Operation *write; if (vectorType.getRank() > 0) { AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap)); - SmallVector<Value> indices( - linalgOp.getRank(outputOperand), - arith::ConstantIndexOp::create(rewriter, loc, 0)); value = broadcastIfNeeded(rewriter, value, vectorType); assert(value.getType() == vectorType && "Incorrect type"); write = vector::TransferWriteOp::create( @@ -762,7 +762,7 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value, value = vector::BroadcastOp::create(rewriter, loc, vectorType, value); assert(value.getType() == vectorType && "Incorrect type"); write = vector::TransferWriteOp::create(rewriter, loc, value, - outputOperand->get(), ValueRange{}); + outputOperand->get(), indices); } write = state.maskOperation(rewriter, write, linalgOp, opOperandMap); @@ -1890,9 +1890,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp, // Create masked TransferReadOp. auto maskedRead = vector::createReadOrMaskedRead( - rewriter, loc, packOp.getSource(), readVecType.getShape(), padValue, - useInBoundsInsteadOfMasking, - /*inputScalableVecSizes=*/{}); + rewriter, loc, packOp.getSource(), readVecType, padValue, + useInBoundsInsteadOfMasking); // Create ShapeCastOp. auto shapeCastOp = vector::ShapeCastOp::create( @@ -1977,9 +1976,12 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, } // -- Generate the read operation -- + VectorType readVecType = + VectorType::get(readVectorSizes, unpackTensorType.getElementType(), + readScalableVectorFlags); Value readResult = vector::createReadOrMaskedRead( - rewriter, loc, unpackOp.getSource(), readVectorSizes, std::nullopt, - useInBoundsInsteadOfMasking, readScalableVectorFlags); + rewriter, loc, unpackOp.getSource(), readVecType, std::nullopt, + useInBoundsInsteadOfMasking); // -- Generate the transpose operation -- PackingMetadata packMetadata; @@ -2025,9 +2027,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, .reifyResultShapes(rewriter, reifiedReturnShapes); (void)status; // prevent unused variable warning on non-assert builds assert(succeeded(status) && "failed to reify result shapes"); + auto readType = VectorType::get(inputVectorSizes, padValue.getType()); auto maskedRead = vector::createReadOrMaskedRead( - rewriter, loc, padOp.getSource(), inputVectorSizes, padValue, - /*useInBoundsInsteadOfMasking=*/false, /*inputScalableVecSizes=*/{}); + rewriter, loc, padOp.getSource(), readType, padValue, + /*useInBoundsInsteadOfMasking=*/false); // Create Xfer write Op Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0], @@ -2222,9 +2225,9 @@ vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state, state.getCanonicalVecType(elemType, readMap.compose(indexingMap)); Value read = mlir::vector::createReadOrMaskedRead( - rewriter, loc, opOperand.get(), readType.getShape(), + rewriter, loc, opOperand.get(), readType, /*padding=*/arith::getZeroConstant(rewriter, loc, elemType), - /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims()); + /*useInBoundsInsteadOfMasking=*/false); vecOperands.push_back(read); } @@ -3165,9 +3168,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp, SmallVector<Value> readIndices( vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0)); Value read = mlir::vector::createReadOrMaskedRead( - rewriter, loc, source, vecType.getShape(), padValue, - /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(), - /*inputScalableVecSizes=*/{}); + rewriter, loc, source, vecType, padValue, + /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty()); // Create write auto writeIndices = diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 6eeb206..01e6e1e 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -235,6 +235,731 @@ bool isReductionIterator(utils::IteratorType iteratorType) { return iteratorType == utils::IteratorType::reduction; } +//===----------------------------------------------------------------------===// +// Convolution matcher utilities +//===----------------------------------------------------------------------===// + +/// Returns the BlockArgument that leads to `val`, if any. Traverses optional +/// ext* ops. +static BlockArgument getBlockArgumentWithOptionalExtOps(Value val) { + BlockArgument blockArg = dyn_cast<BlockArgument>(val); + if ((blockArg)) + return blockArg; + + Operation *defOp = val.getDefiningOp(); + if (!dyn_cast_if_present<arith::ExtFOp>(defOp) && + !dyn_cast_if_present<arith::ExtSIOp>(defOp) && + !dyn_cast_if_present<arith::ExtUIOp>(defOp)) { + return nullptr; + } + return dyn_cast<BlockArgument>(defOp->getOperand(0)); +} + +/// Utility to match block body for convolution ops. +/// The body is thus expected to yield :- +/// %out + (%lhs * %rhs) +/// where: %lhs, %rhs and %out are block arguments and +/// %lhs and %rhs can have optional upcast operation. +static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body) { + Operation *addOp = yieldVal.getDefiningOp(); + if (!isa_and_present<arith::AddIOp, arith::AddFOp>(addOp)) + return false; + + Operation *mulOp = addOp->getOperand(1).getDefiningOp(); + if (!isa_and_present<arith::MulIOp, arith::MulFOp>(mulOp)) + return false; + + BlockArgument lhsBlockArg = + getBlockArgumentWithOptionalExtOps(mulOp->getOperand(0)); + BlockArgument rhsBlockArg = + getBlockArgumentWithOptionalExtOps(mulOp->getOperand(1)); + BlockArgument outBlockArg = + getBlockArgumentWithOptionalExtOps(addOp->getOperand(0)); + if (!lhsBlockArg || !rhsBlockArg || !outBlockArg || + lhsBlockArg.getOwner() != body || rhsBlockArg.getOwner() != body || + outBlockArg.getOwner() != body || lhsBlockArg.getArgNumber() != 0 || + rhsBlockArg.getArgNumber() != 1 || outBlockArg.getArgNumber() != 2) + return false; + return true; +} + +/// Utility to match block body for linalg.pool* ops. +template <typename... OpTypes> +static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) { + Operation *defOp = yieldVal.getDefiningOp(); + if (!(isa_and_present<OpTypes>(defOp) || ...)) + return false; + + BlockArgument lhsArg = + getBlockArgumentWithOptionalExtOps(defOp->getOperand(0)); + BlockArgument rhsArg = + getBlockArgumentWithOptionalExtOps(defOp->getOperand(1)); + if (!lhsArg || !rhsArg || lhsArg.getOwner() != body || + rhsArg.getOwner() != body || lhsArg.getArgNumber() != 2 || + rhsArg.getArgNumber() != 0) + return false; + return true; +} + +static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxSIOp>(yieldVal, + body); +} + +// max_unsigned ops should not allow float data type. +// TODO(#164800): Retire OPDSL logic. +static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal, + body); +} + +static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinSIOp>(yieldVal, + body); +} + +// min_unsigned ops should not allow float data type. +// TODO(#164800): Retire OPDSL logic. +static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal, + body); +} + +static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps<arith::AddIOp, arith::AddFOp>(yieldVal, body); +} + +static AffineExpr getAffineMapDim(ArrayAttr indexingMaps, uint32_t mapIndex, + uint32_t dimIndex) { + auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue(); + if (dimIndex < affineMap.getNumResults()) + return affineMap.getResult(dimIndex); + return nullptr; +} + +/// Check if `expr` is either: +/// - a dimension expr alone (implying multiplication by 1), or +/// - a multiplication of dimension expr by any positive constant != 1 +/// In both cases we will capture the dimension expression into `dim` and +/// return the constant multiplier. Returns -1 in case of a match failure. +static int64_t isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim) { + if ((dim = dyn_cast<AffineDimExpr>(expr))) + return 1; + + auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr); + if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul) + return -1; + + AffineExpr lhs = mulExpr.getLHS(); + AffineExpr rhs = mulExpr.getRHS(); + + AffineConstantExpr cst = nullptr; + if (((dim = dyn_cast<AffineDimExpr>(lhs)) && + (cst = dyn_cast<AffineConstantExpr>(rhs))) || + ((dim = dyn_cast<AffineDimExpr>(rhs)) && + (cst = dyn_cast<AffineConstantExpr>(lhs)))) + return cst.getValue(); + return -1; +} + +/// Given an array of AffineMaps `indexingMaps` verify the following +/// commutatively:- +/// indexingMaps[0].getResult(iDim) == +/// indexingMaps[1].getResult(fDim) * <c0> + +/// indexingMaps[n-1].getResult(oDim) * <c1> +/// where, +/// - c0 and c1 can be any constant, +/// - n is the size of the indexingMaps' array, +/// - 0, 1 and n-1 are input, filter and output map indices respectively, +/// - iDim, fDim and oDim are the input, filter and output dimension +/// indices in their respective indexing maps +/// Example: +/// #inputMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) +/// -> (d0, d1 * 2 + d4 * 3, d2 + d5, d6)> +/// #filterMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> +/// #outputMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> +/// +/// Here, +/// #inputMap[1] = #outputMap[1] * 2 + #filterMap[0] * 3 +/// Therefore, +/// matchConvDimAddExprPattern(indexingMaps, 1, 0, 1, dilation, stride) +/// would return true and update dilation = 3 and stride = 2 +static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, + unsigned fDim, unsigned oDim, + int64_t &dilation, int64_t &stride) { + unsigned inputMapIdx = 0, filterMapIdx = 1, + outputMapIdx = indexingMaps.size() - 1; + AffineExpr inpExpr = getAffineMapDim(indexingMaps, inputMapIdx, iDim); + auto addExpr = dyn_cast_or_null<AffineBinaryOpExpr>(inpExpr); + if (!addExpr || addExpr.getKind() != AffineExprKind::Add) + return false; + + AffineExpr dim0, dim1; + int64_t c0 = isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0); + int64_t c1 = isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1); + + if (c0 == -1 || c1 == -1) + return false; + // Pattern matched with dims and constants extracted. + AffineExpr fExpr = getAffineMapDim(indexingMaps, filterMapIdx, fDim); + AffineExpr oExpr = getAffineMapDim(indexingMaps, outputMapIdx, oDim); + if (dim0 == fExpr && dim1 == oExpr) { + dilation = c0; + stride = c1; + return true; + } + if (dim1 == fExpr && dim0 == oExpr) { + dilation = c1; + stride = c0; + return true; + } + return false; +} + +/// Returns true if the given indexing maps matches with the expected indexing +/// maps. +static bool convLayoutMatches(ArrayRef<ArrayRef<AffineExpr>> mapListExpected, + ArrayAttr indexingMaps, MLIRContext *context) { + SmallVector<AffineMap, 4> expectedIndexingMaps = + AffineMap::inferFromExprList(mapListExpected, context); + return indexingMaps == + ArrayAttr::get( + context, llvm::to_vector<4>(llvm::map_range( + expectedIndexingMaps, [&](AffineMap m) -> Attribute { + return AffineMapAttr::get(m); + }))); +} + +/// Enum representing pooling operation types used by ConvMatcherBuilder. +enum class PoolingType { + None, + MaxSigned, + MaxUnsigned, + MinSigned, + MinUnsigned, + Sum +}; + +/// Helper class for building convolution op matchers with minimal boilerplate. +/// Reduces repetitive code across Conv1D/2D/3D and Depthwise variants as well +/// as Pooling ops. +/// +/// Usage: Create an instance with the op, spatial rank, and output pointers for +/// extracted dilations/strides. Then chain matchStride() calls for each spatial +/// dimension, followed by matchMaps() to verify indexing maps, and finally +/// matchBody() to verify the operation body pattern. +/// +/// The `matched` flag starts as `true` and is set to `false` if any match step +/// fails. This allows chaining multiple match calls; once any match fails, all +/// subsequent calls become no-ops and the final result is `false`. +/// +/// The `dilations` and `strides` pointers are output parameters that get +/// populated with the extracted dilation and stride values from the operation's +/// indexing maps during matchStride() calls. These values are initially set to +/// 1 for each spatial dimension and updated as patterns are matched. +class ConvMatcherBuilder { + LinalgOp op; + MLIRContext *ctx; + SmallVector<int64_t> *dilations, *strides; + ArrayAttr indexingMaps; + PoolingType poolingType; + bool matched = true; + +public: + ConvMatcherBuilder(LinalgOp op, unsigned spatialRank, SmallVector<int64_t> *d, + SmallVector<int64_t> *s, + PoolingType poolingType = PoolingType::None) + : op(op), ctx(op->getContext()), dilations(d), strides(s), + indexingMaps(op.getIndexingMaps()), poolingType(poolingType) { + *dilations = SmallVector<int64_t>(spatialRank, 1); + *strides = SmallVector<int64_t>(spatialRank, 1); + } + + /// Get affine dimension expression for dimension `i`. + AffineExpr dim(unsigned i) { return getAffineDimExpr(i, ctx); } + + /// Build strided expression: base * stride[idx] + kernel * dilation[idx]. + AffineExpr strided(AffineExpr base, AffineExpr kernel, unsigned idx) { + return base * (*strides)[idx] + kernel * (*dilations)[idx]; + } + + /// Match stride/dilation pattern for a spatial dimension. + /// Returns *this for method chaining. + ConvMatcherBuilder &matchStride(unsigned iDim, unsigned fDim, unsigned oDim, + unsigned idx) { + if (matched) { + matched &= matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim, + (*dilations)[idx], (*strides)[idx]); + } + return *this; + } + + /// Match expected indexing maps layout. Returns *this for method chaining. + ConvMatcherBuilder &matchMaps(ArrayRef<ArrayRef<AffineExpr>> maps) { + if (matched) + matched &= convLayoutMatches(maps, indexingMaps, ctx); + return *this; + } + + /// Match body pattern. This should be called last. + bool matchBody() { + if (!matched) + return false; + Block *body = op.getBlock(); + auto yieldOp = cast<linalg::YieldOp>(body->getTerminator()); + switch (poolingType) { + case PoolingType::None: + return bodyMatcherForConvolutionOps(yieldOp.getOperand(0), body); + case PoolingType::MaxSigned: + return bodyMatcherForMaxSignedPoolOps(yieldOp.getOperand(0), body); + case PoolingType::MaxUnsigned: + return bodyMatcherForMaxUnsignedPoolOps(yieldOp.getOperand(0), body); + case PoolingType::MinSigned: + return bodyMatcherForMinSignedPoolOps(yieldOp.getOperand(0), body); + case PoolingType::MinUnsigned: + return bodyMatcherForMinUnsignedPoolOps(yieldOp.getOperand(0), body); + case PoolingType::Sum: + return bodyMatcherForSumPoolOps(yieldOp.getOperand(0), body); + } + return false; + } +}; + +//===----------------------------------------------------------------------===// +// Matchers for specific convolution operation. +//===----------------------------------------------------------------------===// + +// #inputMap = affine_map<(W, w) -> (W + w)> +// #filterMap = affine_map<(W, w) -> (w)> +// #outputMap = affine_map<(W, w) -> (W)> +template <> +bool isaConvolutionOpOfType<linalg::Conv1DOp>(LinalgOp op, + SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::Conv1DOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides); + AffineExpr W = m.dim(0); + AffineExpr w = m.dim(1); + + return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0) + .matchMaps({/*inputMap=*/{m.strided(W, w, 0)}, + /*filterMap=*/{w}, + /*outputMap=*/{W}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, W, F, w, c) -> (N, W + w, c)> +// #filterMap = affine_map<(N, W, F, w, c) -> (w, c, F)> +// #outputMap = affine_map<(N, W, F, w, c) -> (N, W, F)> +template <> +bool isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::Conv1DNwcWcfOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr W = m.dim(1); + AffineExpr F = m.dim(2); + AffineExpr w = m.dim(3); + AffineExpr c = m.dim(4); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), c}, + /*filterMap=*/{w, c, F}, + /*outputMap=*/{N, W, F}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, F, W, c, w) -> (N, c, W + w)> +// #filterMap = affine_map<(N, F, W, c, w) -> (F, c, w)> +// #outputMap = affine_map<(N, F, W, c, w) -> (N, F, W)> +template <> +bool isaConvolutionOpOfType<linalg::Conv1DNcwFcwOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::Conv1DNcwFcwOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr F = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr c = m.dim(3); + AffineExpr w = m.dim(4); + + return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0) + .matchMaps({/*inputMap=*/{N, c, m.strided(W, w, 0)}, + /*filterMap=*/{F, c, w}, + /*outputMap=*/{N, F, W}}) + .matchBody(); +} + +// #inputMap = affine_map<(H, W, h, w) -> (H + h, W + w)> +// #filterMap = affine_map<(H, W, h, w) -> (h, w)> +// #outputMap = affine_map<(H, W, h, w) -> (H, W)> +template <> +bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op, + SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::Conv2DOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr H = m.dim(0); + AffineExpr W = m.dim(1); + AffineExpr h = m.dim(2); + AffineExpr w = m.dim(3); + + return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0) + .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1) + .matchMaps({/*inputMap=*/{m.strided(H, h, 0), m.strided(W, w, 1)}, + /*filterMap=*/{h, w}, + /*outputMap=*/{H, W}}) + .matchBody(); +} + +// #inputMap = affine_map<(D, H, W, d, h, w) -> (D + d, H + h, W + w)> +// #filterMap = affine_map<(D, H, W, d, h, w) -> (d, h, w)> +// #outputMap = affine_map<(D, H, W, d, h, w) -> (D, H, W)> +template <> +bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op, + SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::Conv3DOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides); + AffineExpr D = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr d = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0) + .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1) + .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/2) + .matchMaps({/*inputMap=*/{m.strided(D, d, 0), m.strided(H, h, 1), + m.strided(W, w, 2)}, + /*filterMap=*/{d, h, w}, + /*outputMap=*/{D, H, W}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, W, C, w) -> (N, C, W + w)> +// #filterMap = affine_map<(N, W, C, w) -> (C, w)> +// #outputMap = affine_map<(N, W, C, w) -> (N, C, W)> +template <> +bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::DepthwiseConv1DNcwCwOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr W = m.dim(1); + AffineExpr C = m.dim(2); + AffineExpr w = m.dim(3); + + return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0) + .matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)}, + /*filterMap=*/{C, w}, + /*outputMap=*/{N, C, W}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, W, C, w) -> (N, W + w, C)> +// #filterMap = affine_map<(N, W, C, w) -> (w, C)> +// #outputMap = affine_map<(N, W, C, w) -> (N, W, C)> +template <> +bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::DepthwiseConv1DNwcWcOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr W = m.dim(1); + AffineExpr C = m.dim(2); + AffineExpr w = m.dim(3); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C}, + /*filterMap=*/{w, C}, + /*outputMap=*/{N, W, C}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, W, C, CM, w) -> (N, W + w, C)> +// #filterMap = affine_map<(N, W, C, CM, w) -> (w, C, CM)> +// #outputMap = affine_map<(N, W, C, CM, w) -> (N, W, C, CM)> +template <> +bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr W = m.dim(1); + AffineExpr C = m.dim(2); + AffineExpr CM = m.dim(3); + AffineExpr w = m.dim(4); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C}, + /*filterMap=*/{w, C, CM}, + /*outputMap=*/{N, W, C, CM}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (C, h, w)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H, W)> +template <> +bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::DepthwiseConv2DNchwChwOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0) + .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/1) + .matchMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)}, + /*filterMap=*/{C, h, w}, + /*outputMap=*/{N, C, H, W}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, D, H, W, CM, d, h, w, C) +// -> (N, D + d, H + h, W + w, C)> +// #filterMap = affine_map<(N, D, H, W, CM, d, h, w, C) +// -> (d, h, w, C, CM)> +// #outputMap = affine_map<(N, D, H, W, CM, d, h, w, C) +// -> (N, D, H, W, C, CM)> +template <> +bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr D = m.dim(1); + AffineExpr H = m.dim(2); + AffineExpr W = m.dim(3); + AffineExpr CM = m.dim(4); + AffineExpr d = m.dim(5); + AffineExpr h = m.dim(6); + AffineExpr w = m.dim(7); + AffineExpr C = m.dim(8); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2) + .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1), + m.strided(W, w, 2), C}, + /*filterMap=*/{d, h, w, C, CM}, + /*outputMap=*/{N, D, H, W, C, CM}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> +template <> +bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::PoolingNhwcMaxOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides, + PoolingType::MaxSigned); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w}, + /*outputMap=*/{N, H, W, C}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> +template <> +bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::PoolingNhwcMinOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides, + PoolingType::MinSigned); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w}, + /*outputMap=*/{N, H, W, C}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> +template <> +bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::PoolingNhwcSumOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides, + PoolingType::Sum); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w}, + /*outputMap=*/{N, H, W, C}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> +template <> +bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::PoolingNhwcMaxUnsignedOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides, + PoolingType::MaxUnsigned); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w}, + /*outputMap=*/{N, H, W, C}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> +template <> +bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>( + LinalgOp op, SmallVector<int64_t> *dilations, + SmallVector<int64_t> *strides) { + if (isa<linalg::PoolingNhwcMinUnsignedOp>(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides, + PoolingType::MinUnsigned); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w}, + /*outputMap=*/{N, H, W, C}}) + .matchBody(); +} + Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, Value source, Value pad, bool nofold, ValueRange typeDynDims) { diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt index 1382c7ac..d358362 100644 --- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt @@ -25,6 +25,7 @@ add_mlir_dialect_library(MLIRMemRefDialect MLIRMemorySlotInterfaces MLIRShapedOpInterfaces MLIRSideEffectInterfaces + MLIRUBDialect MLIRValueBoundsOpInterface MLIRViewLikeInterface ) diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp index 6ff63df..a1e3f10 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp @@ -10,6 +10,7 @@ #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h" diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp index dfa2e4e..5404238 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" @@ -61,15 +62,8 @@ static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape, // Interfaces for AllocaOp //===----------------------------------------------------------------------===// -static bool isSupportedElementType(Type type) { - return llvm::isa<MemRefType>(type) || - OpBuilder(type.getContext()).getZeroAttr(type); -} - SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() { MemRefType type = getType(); - if (!isSupportedElementType(type.getElementType())) - return {}; if (!type.hasStaticShape()) return {}; // Make sure the memref contains only a single element. @@ -81,16 +75,7 @@ SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() { Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot, OpBuilder &builder) { - assert(isSupportedElementType(slot.elemType)); - // TODO: support more types. - return TypeSwitch<Type, Value>(slot.elemType) - .Case([&](MemRefType t) { - return memref::AllocaOp::create(builder, getLoc(), t); - }) - .Default([&](Type t) { - return arith::ConstantOp::create(builder, getLoc(), t, - builder.getZeroAttr(t)); - }); + return ub::PoisonOp::create(builder, getLoc(), slot.elemType); } std::optional<PromotableAllocationOpInterface> diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 1c21a2f..1035d7c 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1074,13 +1074,6 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) { return subview.getDynamicSize(sourceIndex); } - if (auto sizeInterface = - dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) { - assert(sizeInterface.isDynamicSize(unsignedIndex) && - "Expected dynamic subview size"); - return sizeInterface.getDynamicSize(unsignedIndex); - } - // dim(memrefcast) -> dim if (succeeded(foldMemRefCast(*this))) return getResult(); diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp index bd02516..c9352e8 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -959,7 +959,11 @@ class RewriteExtractAlignedPointerAsIndexOfViewLikeOp PatternRewriter &rewriter) const override { auto viewLikeOp = extractOp.getSource().getDefiningOp<ViewLikeOpInterface>(); - if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest()) + // ViewLikeOpInterface by itself doesn't guarantee to preserve the base + // pointer in general and `memref.view` is one such example, so just check + // for a few specific cases. + if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest() || + !isa<memref::SubViewOp, memref::ReinterpretCastOp>(viewLikeOp)) return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source"); rewriter.modifyOpInPlace(extractOp, [&]() { extractOp.getSourceMutable().assign(viewLikeOp.getViewSource()); diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp index 214410f..3667fdb 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp @@ -347,28 +347,55 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite( loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices, isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation())))) return failure(); - llvm::TypeSwitch<Operation *, void>(loadOp) + + return llvm::TypeSwitch<Operation *, LogicalResult>(loadOp) .Case([&](affine::AffineLoadOp op) { rewriter.replaceOpWithNewOp<affine::AffineLoadOp>( loadOp, expandShapeOp.getViewSource(), sourceIndices); + return success(); }) .Case([&](memref::LoadOp op) { rewriter.replaceOpWithNewOp<memref::LoadOp>( loadOp, expandShapeOp.getViewSource(), sourceIndices, op.getNontemporal()); + return success(); }) .Case([&](vector::LoadOp op) { rewriter.replaceOpWithNewOp<vector::LoadOp>( op, op.getType(), expandShapeOp.getViewSource(), sourceIndices, op.getNontemporal()); + return success(); }) .Case([&](vector::MaskedLoadOp op) { rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>( op, op.getType(), expandShapeOp.getViewSource(), sourceIndices, op.getMask(), op.getPassThru()); + return success(); + }) + .Case([&](vector::TransferReadOp op) { + // We only support minor identity maps in the permutation attribute. + if (!op.getPermutationMap().isMinorIdentity()) + return failure(); + + // We only support the case where the source of the expand shape has + // rank greater than or equal to the vector rank. + const int64_t sourceRank = sourceIndices.size(); + const int64_t vectorRank = op.getVectorType().getRank(); + if (sourceRank < vectorRank) + return failure(); + + // We need to construct a new minor identity map since we will have lost + // some dimensions in folding away the expand shape. + auto minorIdMap = AffineMap::getMinorIdentityMap(sourceRank, vectorRank, + op.getContext()); + + rewriter.replaceOpWithNewOp<vector::TransferReadOp>( + op, op.getVectorType(), expandShapeOp.getViewSource(), + sourceIndices, minorIdMap, op.getPadding(), op.getMask(), + op.getInBounds()); + return success(); }) .DefaultUnreachable("unexpected operation"); - return success(); } template <typename OpTy> @@ -659,6 +686,7 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) { LoadOpOfExpandShapeOpFolder<memref::LoadOp>, LoadOpOfExpandShapeOpFolder<vector::LoadOp>, LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>, + LoadOpOfExpandShapeOpFolder<vector::TransferReadOp>, StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>, StoreOpOfExpandShapeOpFolder<memref::StoreOp>, StoreOpOfExpandShapeOpFolder<vector::StoreOp>, diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp index 6a81a15..c498c8a 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp @@ -90,17 +90,16 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> { if (!dimIndex) return failure(); - ReifiedRankedShapedTypeDims reifiedResultShapes; - if (failed(reifyResultShapes(rewriter, dimValue.getOwner(), - reifiedResultShapes))) + FailureOr<OpFoldResult> replacement = reifyDimOfResult( + rewriter, dimValue.getOwner(), dimValue.getResultNumber(), *dimIndex); + if (failed(replacement)) return failure(); - unsigned resultNumber = dimValue.getResultNumber(); - // Do not apply pattern if the IR is invalid (dim out of bounds). - if ((size_t)(*dimIndex) >= reifiedResultShapes[resultNumber].size()) - return rewriter.notifyMatchFailure(dimOp, "dimension is out of bounds"); - Value replacement = getValueOrCreateConstantIndexOp( - rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]); - rewriter.replaceOp(dimOp, replacement); + // Check if the OpFoldResult is empty (unreifiable dimension). + if (!replacement.value()) + return failure(); + Value replacementVal = getValueOrCreateConstantIndexOp( + rewriter, dimOp.getLoc(), replacement.value()); + rewriter.replaceOp(dimOp, replacementVal); return success(); } }; @@ -166,12 +165,14 @@ namespace { struct ResolveRankedShapeTypeResultDimsPass final : public memref::impl::ResolveRankedShapeTypeResultDimsPassBase< ResolveRankedShapeTypeResultDimsPass> { + using Base::Base; void runOnOperation() override; }; struct ResolveShapedTypeResultDimsPass final : public memref::impl::ResolveShapedTypeResultDimsPassBase< ResolveShapedTypeResultDimsPass> { + using Base::Base; void runOnOperation() override; }; @@ -195,14 +196,22 @@ void memref::populateResolveShapedTypeResultDimsPatterns( void ResolveRankedShapeTypeResultDimsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); - if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + auto result = applyPatternsGreedily(getOperation(), std::move(patterns)); + if (errorOnPatternIterationLimit && failed(result)) { + getOperation()->emitOpError( + "dim operation resolution hit pattern iteration limit"); return signalPassFailure(); + } } void ResolveShapedTypeResultDimsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); memref::populateResolveShapedTypeResultDimsPatterns(patterns); - if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + auto result = applyPatternsGreedily(getOperation(), std::move(patterns)); + if (errorOnPatternIterationLimit && failed(result)) { + getOperation()->emitOpError( + "dim operation resolution hit pattern iteration limit"); return signalPassFailure(); + } } diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp index 14152c5..e5cc41e 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -268,61 +268,82 @@ struct SubViewOpInterface MemRefType sourceType = subView.getSource().getType(); // For each dimension, assert that: - // 0 <= offset < dim_size - // 0 <= offset + (size - 1) * stride < dim_size + // For empty slices (size == 0) : 0 <= offset <= dim_size + // For non-empty slices (size > 0): 0 <= offset < dim_size + // 0 <= offset + (size - 1) * stride + // dim_size Value zero = arith::ConstantIndexOp::create(builder, loc, 0); Value one = arith::ConstantIndexOp::create(builder, loc, 1); + auto metadataOp = ExtractStridedMetadataOp::create(builder, loc, subView.getSource()); + for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) { - // Reset insertion point to before the operation for each dimension + // Reset insertion point to before the operation for each dimension. builder.setInsertionPoint(subView); + Value offset = getValueOrCreateConstantIndexOp( builder, loc, subView.getMixedOffsets()[i]); Value size = getValueOrCreateConstantIndexOp(builder, loc, subView.getMixedSizes()[i]); Value stride = getValueOrCreateConstantIndexOp( builder, loc, subView.getMixedStrides()[i]); - - // Verify that offset is in-bounds. Value dimSize = metadataOp.getSizes()[i]; - Value offsetInBounds = - generateInBoundsCheck(builder, loc, offset, zero, dimSize); - cf::AssertOp::create(builder, loc, offsetInBounds, + + // Verify that offset is in-bounds (conditional on slice size). + Value sizeIsZero = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::eq, size, zero); + auto offsetCheckIf = scf::IfOp::create( + builder, loc, sizeIsZero, + [&](OpBuilder &b, Location loc) { + // For empty slices, offset can be at the boundary: 0 <= offset <= + // dimSize. + Value offsetGEZero = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::sge, offset, zero); + Value offsetLEDimSize = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::sle, offset, dimSize); + Value emptyOffsetValid = + arith::AndIOp::create(b, loc, offsetGEZero, offsetLEDimSize); + scf::YieldOp::create(b, loc, emptyOffsetValid); + }, + [&](OpBuilder &b, Location loc) { + // For non-empty slices, offset must be a valid index: 0 <= offset + // dimSize. + Value offsetInBounds = + generateInBoundsCheck(b, loc, offset, zero, dimSize); + scf::YieldOp::create(b, loc, offsetInBounds); + }); + + Value offsetCondition = offsetCheckIf.getResult(0); + cf::AssertOp::create(builder, loc, offsetCondition, generateErrorMessage(op, "offset " + std::to_string(i) + " is out-of-bounds")); - // Only verify if size > 0 + // Verify that the slice endpoint is in-bounds (only for non-empty + // slices). Value sizeIsNonZero = arith::CmpIOp::create( builder, loc, arith::CmpIPredicate::sgt, size, zero); + auto ifOp = scf::IfOp::create( + builder, loc, sizeIsNonZero, + [&](OpBuilder &b, Location loc) { + // Verify that slice does not run out-of-bounds. + Value sizeMinusOne = arith::SubIOp::create(b, loc, size, one); + Value sizeMinusOneTimesStride = + arith::MulIOp::create(b, loc, sizeMinusOne, stride); + Value lastPos = + arith::AddIOp::create(b, loc, offset, sizeMinusOneTimesStride); + Value lastPosInBounds = + generateInBoundsCheck(b, loc, lastPos, zero, dimSize); + scf::YieldOp::create(b, loc, lastPosInBounds); + }, + [&](OpBuilder &b, Location loc) { + Value trueVal = + arith::ConstantOp::create(b, loc, b.getBoolAttr(true)); + scf::YieldOp::create(b, loc, trueVal); + }); - auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(), - sizeIsNonZero, /*withElseRegion=*/true); - - // Populate the "then" region (for size > 0). - builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - - // Verify that slice does not run out-of-bounds. - Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one); - Value sizeMinusOneTimesStride = - arith::MulIOp::create(builder, loc, sizeMinusOne, stride); - Value lastPos = - arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride); - Value lastPosInBounds = - generateInBoundsCheck(builder, loc, lastPos, zero, dimSize); - - scf::YieldOp::create(builder, loc, lastPosInBounds); - - // Populate the "else" region (for size == 0). - builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); - Value trueVal = - arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true)); - scf::YieldOp::create(builder, loc, trueVal); - - builder.setInsertionPointAfter(ifOp); Value finalCondition = ifOp.getResult(0); - cf::AssertOp::create( builder, loc, finalCondition, generateErrorMessage(op, diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp index 6200366..e548698 100644 --- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp +++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp @@ -133,17 +133,20 @@ getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, } /// Returns true if all the uses of op are not read/load. -/// There can be SubviewOp users as long as all its users are also +/// There can be view-like-op users as long as all its users are also /// StoreOp/transfer_write. If return true it also fills out the uses, if it /// returns false uses is unchanged. static bool resultIsNotRead(Operation *op, std::vector<Operation *> &uses) { std::vector<Operation *> opUses; for (OpOperand &use : op->getUses()) { Operation *useOp = use.getOwner(); + // Use escaped the scope + if (useOp->mightHaveTrait<OpTrait::IsTerminator>()) + return false; if (isa<memref::DeallocOp>(useOp) || (useOp->getNumResults() == 0 && useOp->getNumRegions() == 0 && !mlir::hasEffect<MemoryEffects::Read>(useOp)) || - (isa<memref::SubViewOp>(useOp) && resultIsNotRead(useOp, opUses))) { + (isa<ViewLikeOpInterface>(useOp) && resultIsNotRead(useOp, opUses))) { opUses.push_back(useOp); continue; } diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp index 2a857ed..0d05313 100644 --- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp +++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp @@ -675,7 +675,7 @@ MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc, Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand( OpBuilder &b, Location loc, OpFoldResult laneId, Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) { - auto loads = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn)); + auto loads = buildMemRefLoads(b, loc, laneId, memref, indexFn); Type elementType = getElementTypeOrSelf(memref.getType()); auto vt = VectorType::get(vectorShape, elementType); @@ -727,7 +727,7 @@ SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand( [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) { toStore.push_back(v); }); - return buildMemRefStores(b, loc, toStore, laneId, memref, std::move(indexFn)); + return buildMemRefStores(b, loc, toStore, laneId, memref, indexFn); } static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, @@ -792,7 +792,7 @@ FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) { if (failed(maybeInfo)) return failure(); - MmaSyncInfo info = *maybeInfo; + const MmaSyncInfo &info = *maybeInfo; auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns; auto [lhsShape, rhsShape, resShape] = info.vectorShapes; Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef, diff --git a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp index 40e769e..1d775fb 100644 --- a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp +++ b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp @@ -41,5 +41,12 @@ InFlightDiagnostic OpenACCSupport::emitNYI(Location loc, const Twine &message) { return mlir::emitError(loc, "not yet implemented: " + message); } +bool OpenACCSupport::isValidSymbolUse(Operation *user, SymbolRefAttr symbol, + Operation **definingOpPtr) { + if (impl) + return impl->isValidSymbolUse(user, symbol, definingOpPtr); + return acc::isValidSymbolUse(user, symbol, definingOpPtr); +} + } // namespace acc } // namespace mlir diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 8c9c137..47f1222 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -15,6 +15,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/SymbolTable.h" @@ -203,12 +204,91 @@ struct MemRefPointerLikeModel return false; } + + mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc, + TypedValue<PointerLikeType> srcPtr, + Type valueType) const { + // Load from a memref - only valid for scalar memrefs (rank 0). + // This is because the address computation for memrefs is part of the load + // (and not computed separately), but the API does not have arguments for + // indexing. + auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(srcPtr); + if (!memrefValue) + return {}; + + auto memrefTy = memrefValue.getType(); + + // Only load from scalar memrefs (rank 0) + if (memrefTy.getRank() != 0) + return {}; + + return memref::LoadOp::create(builder, loc, memrefValue); + } + + bool genStore(Type pointer, OpBuilder &builder, Location loc, + Value valueToStore, TypedValue<PointerLikeType> destPtr) const { + // Store to a memref - only valid for scalar memrefs (rank 0) + // This is because the address computation for memrefs is part of the store + // (and not computed separately), but the API does not have arguments for + // indexing. + auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(destPtr); + if (!memrefValue) + return false; + + auto memrefTy = memrefValue.getType(); + + // Only store to scalar memrefs (rank 0) + if (memrefTy.getRank() != 0) + return false; + + memref::StoreOp::create(builder, loc, valueToStore, memrefValue); + return true; + } }; struct LLVMPointerPointerLikeModel : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel, LLVM::LLVMPointerType> { Type getElementType(Type pointer) const { return Type(); } + + mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc, + TypedValue<PointerLikeType> srcPtr, + Type valueType) const { + // For LLVM pointers, we need the valueType to determine what to load + if (!valueType) + return {}; + + return LLVM::LoadOp::create(builder, loc, valueType, srcPtr); + } + + bool genStore(Type pointer, OpBuilder &builder, Location loc, + Value valueToStore, TypedValue<PointerLikeType> destPtr) const { + LLVM::StoreOp::create(builder, loc, valueToStore, destPtr); + return true; + } +}; + +struct MemrefAddressOfGlobalModel + : public AddressOfGlobalOpInterface::ExternalModel< + MemrefAddressOfGlobalModel, memref::GetGlobalOp> { + SymbolRefAttr getSymbol(Operation *op) const { + auto getGlobalOp = cast<memref::GetGlobalOp>(op); + return getGlobalOp.getNameAttr(); + } +}; + +struct MemrefGlobalVariableModel + : public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel, + memref::GlobalOp> { + bool isConstant(Operation *op) const { + auto globalOp = cast<memref::GlobalOp>(op); + return globalOp.getConstant(); + } + + Region *getInitRegion(Operation *op) const { + // GlobalOp uses attributes for initialization, not regions + return nullptr; + } }; /// Helper function for any of the times we need to modify an ArrayAttr based on @@ -302,6 +382,11 @@ void OpenACCDialect::initialize() { MemRefPointerLikeModel<UnrankedMemRefType>>(*getContext()); LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>( *getContext()); + + // Attach operation interfaces + memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>( + *getContext()); + memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*getContext()); } //===----------------------------------------------------------------------===// @@ -467,6 +552,28 @@ checkValidModifier(Op op, acc::DataClauseModifier validModifiers) { return success(); } +template <typename OpT, typename RecipeOpT> +static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName) { + // Mappable types do not need a recipe because it is possible to generate one + // from its API. Reject reductions though because no API is available for them + // at this time. + if (mlir::acc::isMappableType(op.getVar().getType()) && + !std::is_same_v<OpT, acc::ReductionOp>) + return success(); + + mlir::SymbolRefAttr operandRecipe = op.getRecipeAttr(); + if (!operandRecipe) + return op->emitOpError() << "recipe expected for " << operandName; + + auto decl = + SymbolTable::lookupNearestSymbolFrom<RecipeOpT>(op, operandRecipe); + if (!decl) + return op->emitOpError() + << "expected symbol reference " << operandRecipe << " to point to a " + << operandName << " declaration"; + return success(); +} + static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var) { // Either `var` or `varPtr` keyword is required. @@ -573,6 +680,18 @@ static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, } } +static ParseResult parseRecipeSym(mlir::OpAsmParser &parser, + mlir::SymbolRefAttr &recipeAttr) { + if (failed(parser.parseAttribute(recipeAttr))) + return failure(); + return success(); +} + +static void printRecipeSym(mlir::OpAsmPrinter &p, mlir::Operation *op, + mlir::SymbolRefAttr recipeAttr) { + p << recipeAttr; +} + //===----------------------------------------------------------------------===// // DataBoundsOp //===----------------------------------------------------------------------===// @@ -595,6 +714,9 @@ LogicalResult acc::PrivateOp::verify() { return failure(); if (failed(checkNoModifier(*this))) return failure(); + if (failed( + checkRecipe<acc::PrivateOp, acc::PrivateRecipeOp>(*this, "private"))) + return failure(); return success(); } @@ -609,6 +731,9 @@ LogicalResult acc::FirstprivateOp::verify() { return failure(); if (failed(checkNoModifier(*this))) return failure(); + if (failed(checkRecipe<acc::FirstprivateOp, acc::FirstprivateRecipeOp>( + *this, "firstprivate"))) + return failure(); return success(); } @@ -637,6 +762,9 @@ LogicalResult acc::ReductionOp::verify() { return failure(); if (failed(checkNoModifier(*this))) return failure(); + if (failed(checkRecipe<acc::ReductionOp, acc::ReductionRecipeOp>( + *this, "reduction"))) + return failure(); return success(); } @@ -1322,6 +1450,28 @@ PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc, return recipe; } +std::optional<PrivateRecipeOp> +PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc, + StringRef recipeName, + FirstprivateRecipeOp firstprivRecipe) { + // Create the private.recipe op with the same type as the firstprivate.recipe. + OpBuilder::InsertionGuard guard(builder); + auto varType = firstprivRecipe.getType(); + auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType); + + // Clone the init region + IRMapping mapping; + firstprivRecipe.getInitRegion().cloneInto(&recipe.getInitRegion(), mapping); + + // Clone destroy region if the firstprivate.recipe has one. + if (!firstprivRecipe.getDestroyRegion().empty()) { + IRMapping mapping; + firstprivRecipe.getDestroyRegion().cloneInto(&recipe.getDestroyRegion(), + mapping); + } + return recipe; +} + //===----------------------------------------------------------------------===// // FirstprivateRecipeOp //===----------------------------------------------------------------------===// @@ -1432,40 +1582,6 @@ LogicalResult acc::ReductionRecipeOp::verifyRegions() { } //===----------------------------------------------------------------------===// -// Custom parser and printer verifier for private clause -//===----------------------------------------------------------------------===// - -static ParseResult parseSymOperandList( - mlir::OpAsmParser &parser, - llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands, - llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) { - llvm::SmallVector<SymbolRefAttr> attributes; - if (failed(parser.parseCommaSeparatedList([&]() { - if (parser.parseAttribute(attributes.emplace_back()) || - parser.parseArrow() || - parser.parseOperand(operands.emplace_back()) || - parser.parseColonType(types.emplace_back())) - return failure(); - return success(); - }))) - return failure(); - llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(), - attributes.end()); - symbols = ArrayAttr::get(parser.getContext(), arrayAttr); - return success(); -} - -static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, - mlir::OperandRange operands, - mlir::TypeRange types, - std::optional<mlir::ArrayAttr> attributes) { - llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](auto it) { - p << std::get<0>(it) << " -> " << std::get<1>(it) << " : " - << std::get<1>(it).getType(); - }); -} - -//===----------------------------------------------------------------------===// // ParallelOp //===----------------------------------------------------------------------===// @@ -1484,45 +1600,19 @@ static LogicalResult checkDataOperands(Op op, return success(); } -template <typename Op> -static LogicalResult -checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes, - mlir::OperandRange operands, llvm::StringRef operandName, - llvm::StringRef symbolName, bool checkOperandType = true) { - if (!operands.empty()) { - if (!attributes || attributes->size() != operands.size()) - return op->emitOpError() - << "expected as many " << symbolName << " symbol reference as " - << operandName << " operands"; - } else { - if (attributes) - return op->emitOpError() - << "unexpected " << symbolName << " symbol reference"; - return success(); - } - +template <typename OpT, typename RecipeOpT> +static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp, + const mlir::ValueRange &operands, + llvm::StringRef operandName) { llvm::DenseSet<Value> set; - for (auto args : llvm::zip(operands, *attributes)) { - mlir::Value operand = std::get<0>(args); - + for (mlir::Value operand : operands) { + if (!mlir::isa<OpT>(operand.getDefiningOp())) + return accConstructOp->emitOpError() + << "expected " << operandName << " as defining op"; if (!set.insert(operand).second) - return op->emitOpError() + return accConstructOp->emitOpError() << operandName << " operand appears more than once"; - - mlir::Type varType = operand.getType(); - auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args)); - auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef); - if (!decl) - return op->emitOpError() - << "expected symbol reference " << symbolRef << " to point to a " - << operandName << " declaration"; - - if (checkOperandType && decl.getType() && decl.getType() != varType) - return op->emitOpError() << "expected " << operandName << " (" << varType - << ") to be the same type as " << operandName - << " declaration (" << decl.getType() << ")"; } - return success(); } @@ -1579,17 +1669,17 @@ static LogicalResult verifyDeviceTypeAndSegmentCountMatch( } LogicalResult acc::ParallelOp::verify() { - if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>( - *this, getPrivatizationRecipes(), getPrivateOperands(), "private", - "privatizations", /*checkOperandType=*/false))) + if (failed(checkPrivateOperands<mlir::acc::PrivateOp, + mlir::acc::PrivateRecipeOp>( + *this, getPrivateOperands(), "private"))) return failure(); - if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>( - *this, getFirstprivatizationRecipes(), getFirstprivateOperands(), - "firstprivate", "firstprivatizations", /*checkOperandType=*/false))) + if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp, + mlir::acc::FirstprivateRecipeOp>( + *this, getFirstprivateOperands(), "firstprivate"))) return failure(); - if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>( - *this, getReductionRecipes(), getReductionOperands(), "reduction", - "reductions", false))) + if (failed(checkPrivateOperands<mlir::acc::ReductionOp, + mlir::acc::ReductionRecipeOp>( + *this, getReductionOperands(), "reduction"))) return failure(); if (failed(verifyDeviceTypeAndSegmentCountMatch( @@ -1720,7 +1810,6 @@ void ParallelOp::build(mlir::OpBuilder &odsBuilder, mlir::ValueRange gangPrivateOperands, mlir::ValueRange gangFirstPrivateOperands, mlir::ValueRange dataClauseOperands) { - ParallelOp::build( odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr, /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr, @@ -1729,9 +1818,8 @@ void ParallelOp::build(mlir::OpBuilder &odsBuilder, /*numGangsDeviceType=*/nullptr, numWorkers, /*numWorkersDeviceType=*/nullptr, vectorLength, /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond, - /*selfAttr=*/nullptr, reductionOperands, /*reductionRecipes=*/nullptr, - gangPrivateOperands, /*privatizations=*/nullptr, gangFirstPrivateOperands, - /*firstprivatizations=*/nullptr, dataClauseOperands, + /*selfAttr=*/nullptr, reductionOperands, gangPrivateOperands, + gangFirstPrivateOperands, dataClauseOperands, /*defaultAttr=*/nullptr, /*combined=*/nullptr); } @@ -1808,46 +1896,22 @@ void acc::ParallelOp::addWaitOperands( void acc::ParallelOp::addPrivatization(MLIRContext *context, mlir::acc::PrivateOp op, mlir::acc::PrivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getPrivateOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getPrivatizationRecipesAttr()) - llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } void acc::ParallelOp::addFirstPrivatization( MLIRContext *context, mlir::acc::FirstprivateOp op, mlir::acc::FirstprivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getFirstprivateOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getFirstprivatizationRecipesAttr()) - llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } void acc::ParallelOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op, mlir::acc::ReductionRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getReductionOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getReductionRecipesAttr()) - llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } static ParseResult parseNumGangs( @@ -2415,17 +2479,17 @@ mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) { } LogicalResult acc::SerialOp::verify() { - if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>( - *this, getPrivatizationRecipes(), getPrivateOperands(), "private", - "privatizations", /*checkOperandType=*/false))) + if (failed(checkPrivateOperands<mlir::acc::PrivateOp, + mlir::acc::PrivateRecipeOp>( + *this, getPrivateOperands(), "private"))) return failure(); - if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>( - *this, getFirstprivatizationRecipes(), getFirstprivateOperands(), - "firstprivate", "firstprivatizations", /*checkOperandType=*/false))) + if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp, + mlir::acc::FirstprivateRecipeOp>( + *this, getFirstprivateOperands(), "firstprivate"))) return failure(); - if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>( - *this, getReductionRecipes(), getReductionOperands(), "reduction", - "reductions", false))) + if (failed(checkPrivateOperands<mlir::acc::ReductionOp, + mlir::acc::ReductionRecipeOp>( + *this, getReductionOperands(), "reduction"))) return failure(); if (failed(verifyDeviceTypeAndSegmentCountMatch( @@ -2489,46 +2553,22 @@ void acc::SerialOp::addWaitOperands( void acc::SerialOp::addPrivatization(MLIRContext *context, mlir::acc::PrivateOp op, mlir::acc::PrivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getPrivateOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getPrivatizationRecipesAttr()) - llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } void acc::SerialOp::addFirstPrivatization( MLIRContext *context, mlir::acc::FirstprivateOp op, mlir::acc::FirstprivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getFirstprivateOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getFirstprivatizationRecipesAttr()) - llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } void acc::SerialOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op, mlir::acc::ReductionRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getReductionOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getReductionRecipesAttr()) - llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } //===----------------------------------------------------------------------===// @@ -2658,6 +2698,27 @@ LogicalResult acc::KernelsOp::verify() { return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands()); } +void acc::KernelsOp::addPrivatization(MLIRContext *context, + mlir::acc::PrivateOp op, + mlir::acc::PrivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); + getPrivateOperandsMutable().append(op.getResult()); +} + +void acc::KernelsOp::addFirstPrivatization( + MLIRContext *context, mlir::acc::FirstprivateOp op, + mlir::acc::FirstprivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); + getFirstprivateOperandsMutable().append(op.getResult()); +} + +void acc::KernelsOp::addReduction(MLIRContext *context, + mlir::acc::ReductionOp op, + mlir::acc::ReductionRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); + getReductionOperandsMutable().append(op.getResult()); +} + void acc::KernelsOp::addNumWorkersOperand( MLIRContext *context, mlir::Value newValue, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) { @@ -2967,19 +3028,21 @@ bool hasDuplicateDeviceTypes( } /// Check for duplicates in the DeviceType array attribute. -LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) { +/// Returns std::nullopt if no duplicates, or the duplicate DeviceType if found. +static std::optional<mlir::acc::DeviceType> +checkDeviceTypes(mlir::ArrayAttr deviceTypes) { llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes; if (!deviceTypes) - return success(); + return std::nullopt; for (auto attr : deviceTypes) { auto deviceTypeAttr = mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr); if (!deviceTypeAttr) - return failure(); + return mlir::acc::DeviceType::None; if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second) - return failure(); + return deviceTypeAttr.getValue(); } - return success(); + return std::nullopt; } LogicalResult acc::LoopOp::verify() { @@ -3006,9 +3069,10 @@ LogicalResult acc::LoopOp::verify() { getCollapseDeviceTypeAttr().getValue().size()) return emitOpError() << "collapse attribute count must match collapse" << " device_type count"; - if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr()))) - return emitOpError() - << "duplicate device_type found in collapseDeviceType attribute"; + if (auto duplicateDeviceType = checkDeviceTypes(getCollapseDeviceTypeAttr())) + return emitOpError() << "duplicate device_type `" + << acc::stringifyDeviceType(*duplicateDeviceType) + << "` found in collapseDeviceType attribute"; // Check gang if (!getGangOperands().empty()) { @@ -3021,8 +3085,12 @@ LogicalResult acc::LoopOp::verify() { return emitOpError() << "gangOperandsArgType attribute count must match" << " gangOperands count"; } - if (getGangAttr() && failed(checkDeviceTypes(getGangAttr()))) - return emitOpError() << "duplicate device_type found in gang attribute"; + if (getGangAttr()) { + if (auto duplicateDeviceType = checkDeviceTypes(getGangAttr())) + return emitOpError() << "duplicate device_type `" + << acc::stringifyDeviceType(*duplicateDeviceType) + << "` found in gang attribute"; + } if (failed(verifyDeviceTypeAndSegmentCountMatch( *this, getGangOperands(), getGangOperandsSegmentsAttr(), @@ -3030,22 +3098,30 @@ LogicalResult acc::LoopOp::verify() { return failure(); // Check worker - if (failed(checkDeviceTypes(getWorkerAttr()))) - return emitOpError() << "duplicate device_type found in worker attribute"; - if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr()))) - return emitOpError() << "duplicate device_type found in " - "workerNumOperandsDeviceType attribute"; + if (auto duplicateDeviceType = checkDeviceTypes(getWorkerAttr())) + return emitOpError() << "duplicate device_type `" + << acc::stringifyDeviceType(*duplicateDeviceType) + << "` found in worker attribute"; + if (auto duplicateDeviceType = + checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr())) + return emitOpError() << "duplicate device_type `" + << acc::stringifyDeviceType(*duplicateDeviceType) + << "` found in workerNumOperandsDeviceType attribute"; if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(), getWorkerNumOperandsDeviceTypeAttr(), "worker"))) return failure(); // Check vector - if (failed(checkDeviceTypes(getVectorAttr()))) - return emitOpError() << "duplicate device_type found in vector attribute"; - if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr()))) - return emitOpError() << "duplicate device_type found in " - "vectorOperandsDeviceType attribute"; + if (auto duplicateDeviceType = checkDeviceTypes(getVectorAttr())) + return emitOpError() << "duplicate device_type `" + << acc::stringifyDeviceType(*duplicateDeviceType) + << "` found in vector attribute"; + if (auto duplicateDeviceType = + checkDeviceTypes(getVectorOperandsDeviceTypeAttr())) + return emitOpError() << "duplicate device_type `" + << acc::stringifyDeviceType(*duplicateDeviceType) + << "` found in vectorOperandsDeviceType attribute"; if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(), getVectorOperandsDeviceTypeAttr(), "vector"))) @@ -3110,19 +3186,19 @@ LogicalResult acc::LoopOp::verify() { } } - if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>( - *this, getPrivatizationRecipes(), getPrivateOperands(), "private", - "privatizations", false))) + if (failed(checkPrivateOperands<mlir::acc::PrivateOp, + mlir::acc::PrivateRecipeOp>( + *this, getPrivateOperands(), "private"))) return failure(); - if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>( - *this, getFirstprivatizationRecipes(), getFirstprivateOperands(), - "firstprivate", "firstprivatizations", /*checkOperandType=*/false))) + if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp, + mlir::acc::FirstprivateRecipeOp>( + *this, getFirstprivateOperands(), "firstprivate"))) return failure(); - if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>( - *this, getReductionRecipes(), getReductionOperands(), "reduction", - "reductions", false))) + if (failed(checkPrivateOperands<mlir::acc::ReductionOp, + mlir::acc::ReductionRecipeOp>( + *this, getReductionOperands(), "reduction"))) return failure(); if (getCombined().has_value() && @@ -3556,45 +3632,21 @@ void acc::LoopOp::addGangOperands( void acc::LoopOp::addPrivatization(MLIRContext *context, mlir::acc::PrivateOp op, mlir::acc::PrivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getPrivateOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getPrivatizationRecipesAttr()) - llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } void acc::LoopOp::addFirstPrivatization( MLIRContext *context, mlir::acc::FirstprivateOp op, mlir::acc::FirstprivateRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getFirstprivateOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getFirstprivatizationRecipesAttr()) - llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } void acc::LoopOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op, mlir::acc::ReductionRecipeOp recipe) { + op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName())); getReductionOperandsMutable().append(op.getResult()); - - llvm::SmallVector<mlir::Attribute> recipes; - - if (getReductionRecipesAttr()) - llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes)); - - recipes.push_back( - mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); - setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } //===----------------------------------------------------------------------===// @@ -4059,7 +4111,8 @@ LogicalResult acc::RoutineOp::verify() { if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1)) return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can " - "be present at the same time"; + "be present at the same time for device_type `" + << acc::stringifyDeviceType(dtype) << "`"; } return success(); @@ -4356,6 +4409,100 @@ RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) { return std::nullopt; } +void RoutineOp::addSeq(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes) { + setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(), + effectiveDeviceTypes)); +} + +void RoutineOp::addVector(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes) { + setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(), + effectiveDeviceTypes)); +} + +void RoutineOp::addWorker(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes) { + setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(), + effectiveDeviceTypes)); +} + +void RoutineOp::addGang(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes) { + setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(), + effectiveDeviceTypes)); +} + +void RoutineOp::addGang(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes, + uint64_t val) { + llvm::SmallVector<mlir::Attribute> dimValues; + llvm::SmallVector<mlir::Attribute> deviceTypes; + + if (getGangDimAttr()) + llvm::copy(getGangDimAttr(), std::back_inserter(dimValues)); + if (getGangDimDeviceTypeAttr()) + llvm::copy(getGangDimDeviceTypeAttr(), std::back_inserter(deviceTypes)); + + assert(dimValues.size() == deviceTypes.size()); + + if (effectiveDeviceTypes.empty()) { + dimValues.push_back( + mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val)); + deviceTypes.push_back( + acc::DeviceTypeAttr::get(context, acc::DeviceType::None)); + } else { + for (DeviceType dt : effectiveDeviceTypes) { + dimValues.push_back( + mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val)); + deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt)); + } + } + assert(dimValues.size() == deviceTypes.size()); + + setGangDimAttr(mlir::ArrayAttr::get(context, dimValues)); + setGangDimDeviceTypeAttr(mlir::ArrayAttr::get(context, deviceTypes)); +} + +void RoutineOp::addBindStrName(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes, + mlir::StringAttr val) { + unsigned before = getBindStrNameDeviceTypeAttr() + ? getBindStrNameDeviceTypeAttr().size() + : 0; + + setBindStrNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper( + context, getBindStrNameDeviceTypeAttr(), effectiveDeviceTypes)); + unsigned after = getBindStrNameDeviceTypeAttr().size(); + + llvm::SmallVector<mlir::Attribute> vals; + if (getBindStrNameAttr()) + llvm::copy(getBindStrNameAttr(), std::back_inserter(vals)); + for (unsigned i = 0; i < after - before; ++i) + vals.push_back(val); + + setBindStrNameAttr(mlir::ArrayAttr::get(context, vals)); +} + +void RoutineOp::addBindIDName(MLIRContext *context, + llvm::ArrayRef<DeviceType> effectiveDeviceTypes, + mlir::SymbolRefAttr val) { + unsigned before = + getBindIdNameDeviceTypeAttr() ? getBindIdNameDeviceTypeAttr().size() : 0; + + setBindIdNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper( + context, getBindIdNameDeviceTypeAttr(), effectiveDeviceTypes)); + unsigned after = getBindIdNameDeviceTypeAttr().size(); + + llvm::SmallVector<mlir::Attribute> vals; + if (getBindIdNameAttr()) + llvm::copy(getBindIdNameAttr(), std::back_inserter(vals)); + for (unsigned i = 0; i < after - before; ++i) + vals.push_back(val); + + setBindIdNameAttr(mlir::ArrayAttr::get(context, vals)); +} + //===----------------------------------------------------------------------===// // InitOp //===----------------------------------------------------------------------===// @@ -4739,3 +4886,12 @@ mlir::acc::getMutableDataOperands(mlir::Operation *accOp) { .Default([&](mlir::Operation *) { return nullptr; })}; return dataOperands; } + +mlir::SymbolRefAttr mlir::acc::getRecipe(mlir::Operation *accOp) { + auto recipe{ + llvm::TypeSwitch<mlir::Operation *, mlir::SymbolRefAttr>(accOp) + .Case<ACC_DATA_ENTRY_OPS>( + [&](auto entry) { return entry.getRecipeAttr(); }) + .Default([&](mlir::Operation *) { return mlir::SymbolRefAttr{}; })}; + return recipe; +} diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp index 91262bd..67cdf10 100644 --- a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp +++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp @@ -237,11 +237,6 @@ public: void runOnOperation() override; private: - /// Collects all data clauses that dominate the compute construct. - /// Needed to determine if a variable is already covered by an existing data - /// clause. - SmallVector<Value> getDominatingDataClauses(Operation *computeConstructOp); - /// Looks through the `dominatingDataClauses` to find the original data clause /// op for an alias. Returns nullptr if no original data clause op is found. template <typename OpT> @@ -277,8 +272,7 @@ private: /// Generates recipes for a list of variables. void generateRecipes(ModuleOp &module, OpBuilder &builder, Operation *computeConstructOp, - const SmallVector<Value> &newOperands, - SmallVector<Attribute> &newRecipeSyms); + const SmallVector<Value> &newOperands); }; /// Determines if a variable is a candidate for implicit data mapping. @@ -301,62 +295,6 @@ static bool isCandidateForImplicitData(Value val, Region &accRegion) { return true; } -SmallVector<Value> -ACCImplicitData::getDominatingDataClauses(Operation *computeConstructOp) { - llvm::SmallSetVector<Value, 8> dominatingDataClauses; - - llvm::TypeSwitch<Operation *>(computeConstructOp) - .Case<acc::ParallelOp, acc::KernelsOp, acc::SerialOp>([&](auto op) { - for (auto dataClause : op.getDataClauseOperands()) { - dominatingDataClauses.insert(dataClause); - } - }) - .Default([](Operation *) {}); - - // Collect the data clauses from enclosing data constructs. - Operation *currParentOp = computeConstructOp->getParentOp(); - while (currParentOp) { - if (isa<acc::DataOp>(currParentOp)) { - for (auto dataClause : - dyn_cast<acc::DataOp>(currParentOp).getDataClauseOperands()) { - dominatingDataClauses.insert(dataClause); - } - } - currParentOp = currParentOp->getParentOp(); - } - - // Find the enclosing function/subroutine - auto funcOp = computeConstructOp->getParentOfType<FunctionOpInterface>(); - if (!funcOp) - return dominatingDataClauses.takeVector(); - - // Walk the function to find `acc.declare_enter`/`acc.declare_exit` pairs that - // dominate and post-dominate the compute construct and add their data - // clauses to the list. - auto &domInfo = this->getAnalysis<DominanceInfo>(); - auto &postDomInfo = this->getAnalysis<PostDominanceInfo>(); - funcOp->walk([&](acc::DeclareEnterOp declareEnterOp) { - if (domInfo.dominates(declareEnterOp.getOperation(), computeConstructOp)) { - // Collect all `acc.declare_exit` ops for this token. - SmallVector<acc::DeclareExitOp> exits; - for (auto *user : declareEnterOp.getToken().getUsers()) - if (auto declareExit = dyn_cast<acc::DeclareExitOp>(user)) - exits.push_back(declareExit); - - // Only add clauses if every `acc.declare_exit` op post-dominates the - // compute construct. - if (!exits.empty() && llvm::all_of(exits, [&](acc::DeclareExitOp exitOp) { - return postDomInfo.postDominates(exitOp, computeConstructOp); - })) { - for (auto dataClause : declareEnterOp.getDataClauseOperands()) - dominatingDataClauses.insert(dataClause); - } - } - }); - - return dominatingDataClauses.takeVector(); -} - template <typename OpT> Operation *ACCImplicitData::getOriginalDataClauseOpForAlias( Value var, OpBuilder &builder, OpT computeConstructOp, @@ -453,23 +391,23 @@ ACCImplicitData::generateFirstprivateRecipe(ModuleOp &module, Value var, void ACCImplicitData::generateRecipes(ModuleOp &module, OpBuilder &builder, Operation *computeConstructOp, - const SmallVector<Value> &newOperands, - SmallVector<Attribute> &newRecipeSyms) { + const SmallVector<Value> &newOperands) { auto &accSupport = this->getAnalysis<acc::OpenACCSupport>(); for (auto var : newOperands) { auto loc{var.getLoc()}; - if (isa<acc::PrivateOp>(var.getDefiningOp())) { + if (auto privateOp = dyn_cast<acc::PrivateOp>(var.getDefiningOp())) { auto recipe = generatePrivateRecipe( module, acc::getVar(var.getDefiningOp()), loc, builder, accSupport); if (recipe) - newRecipeSyms.push_back(SymbolRefAttr::get(module->getContext(), - recipe.getSymName().str())); - } else if (isa<acc::FirstprivateOp>(var.getDefiningOp())) { + privateOp.setRecipeAttr( + SymbolRefAttr::get(module->getContext(), recipe.getSymName())); + } else if (auto firstprivateOp = + dyn_cast<acc::FirstprivateOp>(var.getDefiningOp())) { auto recipe = generateFirstprivateRecipe( module, acc::getVar(var.getDefiningOp()), loc, builder, accSupport); if (recipe) - newRecipeSyms.push_back(SymbolRefAttr::get(module->getContext(), - recipe.getSymName().str())); + firstprivateOp.setRecipeAttr(SymbolRefAttr::get( + module->getContext(), recipe.getSymName().str())); } else { accSupport.emitNYI(var.getLoc(), "implicit reduction"); } @@ -570,6 +508,8 @@ Operation *ACCImplicitData::generateDataClauseOpForCandidate( newDataOp = acc::PresentOp::create(builder, loc, var, /*structured=*/true, /*implicit=*/true, accSupport.getVariableName(var)); + newDataOp->setAttr(acc::getFromDefaultClauseAttrName(), + builder.getUnitAttr()); } else { auto copyinOp = acc::CopyinOp::create(builder, loc, var, @@ -611,56 +551,22 @@ static void legalizeValuesInRegion(Region &accRegion, } } -// Adds the private operands and private recipes to the data construct -// operation in a valid way (ensures that the index in the privatizationRecipes -// array matches the position of the private operand). +// Adds the private operands to the compute construct operation. template <typename OpT> -static void -addNewPrivateOperands(OpT &accOp, const SmallVector<Value> &privateOperands, - const SmallVector<Attribute> &privateRecipeSyms) { - assert(privateOperands.size() == privateRecipeSyms.size()); +static void addNewPrivateOperands(OpT &accOp, + const SmallVector<Value> &privateOperands) { if (privateOperands.empty()) return; - SmallVector<Attribute> completePrivateRecipesSyms; - SmallVector<Attribute> completeFirstprivateRecipesSyms; - SmallVector<Value> newPrivateOperands; - SmallVector<Value> newFirstprivateOperands; - - // Collect all of the existing recipes since they are held in an attribute. - // To add to it, we need to create a brand new one. - if (accOp.getPrivatizationRecipes().has_value()) - for (auto privatization : accOp.getPrivatizationRecipesAttr()) - completePrivateRecipesSyms.push_back(privatization); - if (accOp.getFirstprivatizationRecipes().has_value()) - for (auto privatization : accOp.getFirstprivatizationRecipesAttr()) - completeFirstprivateRecipesSyms.push_back(privatization); - - // Now separate between private and firstprivate operands. - for (auto [priv, privateRecipeSym] : - llvm::zip(privateOperands, privateRecipeSyms)) { + for (auto priv : privateOperands) { if (isa<acc::PrivateOp>(priv.getDefiningOp())) { - newPrivateOperands.push_back(priv); - completePrivateRecipesSyms.push_back(privateRecipeSym); + accOp.getPrivateOperandsMutable().append(priv); } else if (isa<acc::FirstprivateOp>(priv.getDefiningOp())) { - newFirstprivateOperands.push_back(priv); - completeFirstprivateRecipesSyms.push_back(privateRecipeSym); + accOp.getFirstprivateOperandsMutable().append(priv); } else { - llvm_unreachable("unhandled private operand"); + llvm_unreachable("unhandled reduction operand"); } } - - // Append all of the new private operands to their appropriate list. - accOp.getPrivateOperandsMutable().append(newPrivateOperands); - accOp.getFirstprivateOperandsMutable().append(newFirstprivateOperands); - - // Update the privatizationRecipes attributes to hold all of the new recipes. - if (!completePrivateRecipesSyms.empty()) - accOp.setPrivatizationRecipesAttr( - ArrayAttr::get(accOp.getContext(), completePrivateRecipesSyms)); - if (!completeFirstprivateRecipesSyms.empty()) - accOp.setFirstprivatizationRecipesAttr( - ArrayAttr::get(accOp.getContext(), completeFirstprivateRecipesSyms)); } static Operation *findDataExitOp(Operation *dataEntryOp) { @@ -808,7 +714,10 @@ void ACCImplicitData::generateImplicitDataOps( LLVM_DEBUG(llvm::dbgs() << "== Generating clauses for ==\n" << computeConstructOp << "\n"); } - auto dominatingDataClauses = getDominatingDataClauses(computeConstructOp); + auto &domInfo = this->getAnalysis<DominanceInfo>(); + auto &postDomInfo = this->getAnalysis<PostDominanceInfo>(); + auto dominatingDataClauses = + acc::getDominatingDataClauses(computeConstructOp, domInfo, postDomInfo); for (auto var : candidateVars) { auto newDataClauseOp = generateDataClauseOpForCandidate( var, module, builder, computeConstructOp, dominatingDataClauses, @@ -829,13 +738,11 @@ void ACCImplicitData::generateImplicitDataOps( // of the data clause ops) legalizeValuesInRegion(accRegion, newPrivateOperands, newDataClauseOperands); - SmallVector<Attribute> newPrivateRecipeSyms; // 5) Generate private recipes which are required for properly attaching // private operands. if constexpr (!std::is_same_v<OpT, acc::KernelsOp> && !std::is_same_v<OpT, acc::KernelEnvironmentOp>) - generateRecipes(module, builder, computeConstructOp, newPrivateOperands, - newPrivateRecipeSyms); + generateRecipes(module, builder, computeConstructOp, newPrivateOperands); // 6) Figure out insertion order for the new data clause operands. SmallVector<Value> sortedDataClauseOperands( @@ -846,15 +753,10 @@ void ACCImplicitData::generateImplicitDataOps( // 7) Generate the data exit operations. generateDataExitOperations(builder, computeConstructOp, newDataClauseOperands, sortedDataClauseOperands); - // 8) Add all of the new operands to the compute construct op. - assert(newPrivateOperands.size() == newPrivateRecipeSyms.size() && - "sizes must match"); if constexpr (!std::is_same_v<OpT, acc::KernelsOp> && !std::is_same_v<OpT, acc::KernelEnvironmentOp>) - addNewPrivateOperands(computeConstructOp, newPrivateOperands, - newPrivateRecipeSyms); - + addNewPrivateOperands(computeConstructOp, newPrivateOperands); computeConstructOp.getDataClauseOperandsMutable().assign( sortedDataClauseOperands); } diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp new file mode 100644 index 0000000..8cab223 --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp @@ -0,0 +1,431 @@ +//===- ACCImplicitDeclare.cpp ---------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass applies implicit `acc declare` actions to global variables +// referenced in OpenACC compute regions and routine functions. +// +// Overview: +// --------- +// Global references in an acc regions (for globals not marked with `acc +// declare` by the user) can be handled in one of two ways: +// - Mapped through data clauses +// - Implicitly marked as `acc declare` (this pass) +// +// Thus, the OpenACC specification focuses solely on implicit data mapping rules +// whose implementation is captured in `ACCImplicitData` pass. +// +// However, it is both advantageous and required for certain cases to +// use implicit `acc declare` instead: +// - Any functions that are implicitly marked as `acc routine` through +// `ACCImplicitRoutine` may reference globals. Since data mapping +// is only possible for compute regions, such globals can only be +// made available on device through `acc declare`. +// - Compiler can generate and use globals for cases needed in IR +// representation such as type descriptors or various names needed for +// runtime calls and error reporting - such cases often are introduced +// after a frontend semantic checking is done since it is related to +// implementation detail. Thus, such compiler generated globals would +// not have been visible for a user to mark with `acc declare`. +// - Constant globals such as filename strings or data initialization values +// are values that do not get mutated but are still needed for appropriate +// runtime execution. If a kernel is launched 1000 times, it is not a +// good idea to map such a global 1000 times. Therefore, such globals +// benefit from being marked with `acc declare`. +// +// This pass automatically +// marks global variables with the `acc.declare` attribute when they are +// referenced in OpenACC compute constructs or routine functions and meet +// the criteria noted above, ensuring +// they are properly handled for device execution. +// +// The pass performs two main optimizations: +// +// 1. Hoisting: For non-constant globals referenced in compute regions, the +// pass hoists the address-of operation out of the region when possible, +// allowing them to be implicitly mapped through normal data clause +// mechanisms rather than requiring declare marking. +// +// 2. Declaration: For globals that must be available on the device (constants, +// globals in routines, globals in recipe operations), the pass adds the +// `acc.declare` attribute with the copyin data clause. +// +// Requirements: +// ------------- +// To use this pass in a pipeline, the following requirements must be met: +// +// 1. Operation Interface Implementation: Operations that compute addresses +// of global variables must implement the `acc::AddressOfGlobalOpInterface` +// and those that represent globals must implement the +// `acc::GlobalOpInterface`. Additionally, any operations that indirectly +// access globals must implement the `acc::IndirectGlobalAccessOpInterface`. +// +// 2. Analysis Registration (Optional): If custom behavior is needed for +// determining if a symbol use is valid within GPU regions, the dialect +// should pre-register the `acc::OpenACCSupport` analysis. +// +// Examples: +// --------- +// +// Example 1: Non-constant global in compute region (hoisted) +// +// Before: +// memref.global @g_scalar : memref<f32> = dense<0.0> +// func.func @test() { +// acc.serial { +// %addr = memref.get_global @g_scalar : memref<f32> +// %val = memref.load %addr[] : memref<f32> +// acc.yield +// } +// } +// +// After: +// memref.global @g_scalar : memref<f32> = dense<0.0> +// func.func @test() { +// %addr = memref.get_global @g_scalar : memref<f32> +// acc.serial { +// %val = memref.load %addr[] : memref<f32> +// acc.yield +// } +// } +// +// Example 2: Constant global in compute region (declared) +// +// Before: +// memref.global constant @g_const : memref<f32> = dense<1.0> +// func.func @test() { +// acc.serial { +// %addr = memref.get_global @g_const : memref<f32> +// %val = memref.load %addr[] : memref<f32> +// acc.yield +// } +// } +// +// After: +// memref.global constant @g_const : memref<f32> = dense<1.0> +// {acc.declare = #acc.declare<dataClause = acc_copyin>} +// func.func @test() { +// acc.serial { +// %addr = memref.get_global @g_const : memref<f32> +// %val = memref.load %addr[] : memref<f32> +// acc.yield +// } +// } +// +// Example 3: Global in acc routine (declared) +// +// Before: +// memref.global @g_data : memref<f32> = dense<0.0> +// acc.routine @routine_0 func(@device_func) +// func.func @device_func() attributes {acc.routine_info = ...} { +// %addr = memref.get_global @g_data : memref<f32> +// %val = memref.load %addr[] : memref<f32> +// } +// +// After: +// memref.global @g_data : memref<f32> = dense<0.0> +// {acc.declare = #acc.declare<dataClause = acc_copyin>} +// acc.routine @routine_0 func(@device_func) +// func.func @device_func() attributes {acc.routine_info = ...} { +// %addr = memref.get_global @g_data : memref<f32> +// %val = memref.load %addr[] : memref<f32> +// } +// +// Example 4: Global in private recipe (declared if recipe is used) +// +// Before: +// memref.global @g_init : memref<f32> = dense<0.0> +// acc.private.recipe @priv_recipe : memref<f32> init { +// ^bb0(%arg0: memref<f32>): +// %alloc = memref.alloc() : memref<f32> +// %global = memref.get_global @g_init : memref<f32> +// %val = memref.load %global[] : memref<f32> +// memref.store %val, %alloc[] : memref<f32> +// acc.yield %alloc : memref<f32> +// } destroy { ... } +// func.func @test() { +// %var = memref.alloc() : memref<f32> +// %priv = acc.private varPtr(%var : memref<f32>) +// recipe(@priv_recipe) -> memref<f32> +// acc.parallel private(%priv : memref<f32>) { ... } +// } +// +// After: +// memref.global @g_init : memref<f32> = dense<0.0> +// {acc.declare = #acc.declare<dataClause = acc_copyin>} +// acc.private.recipe @priv_recipe : memref<f32> init { +// ^bb0(%arg0: memref<f32>): +// %alloc = memref.alloc() : memref<f32> +// %global = memref.get_global @g_init : memref<f32> +// %val = memref.load %global[] : memref<f32> +// memref.store %val, %alloc[] : memref<f32> +// acc.yield %alloc : memref<f32> +// } destroy { ... } +// func.func @test() { +// %var = memref.alloc() : memref<f32> +// %priv = acc.private varPtr(%var : memref<f32>) +// recipe(@priv_recipe) -> memref<f32> +// acc.parallel private(%priv : memref<f32>) { ... } +// } +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/OpenACC/Transforms/Passes.h" + +#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" + +namespace mlir { +namespace acc { +#define GEN_PASS_DEF_ACCIMPLICITDECLARE +#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc" +} // namespace acc +} // namespace mlir + +#define DEBUG_TYPE "acc-implicit-declare" + +using namespace mlir; + +namespace { + +using GlobalOpSetT = llvm::SmallSetVector<Operation *, 16>; + +/// Checks whether a use of the requested `globalOp` should be considered +/// for hoisting out of acc region due to avoid `acc declare`ing something +/// that instead should be implicitly mapped. +static bool isGlobalUseCandidateForHoisting(Operation *globalOp, + Operation *user, + SymbolRefAttr symbol, + acc::OpenACCSupport &accSupport) { + // This symbol is valid in GPU region. This means semantics + // would change if moved to host - therefore it is not a candidate. + if (accSupport.isValidSymbolUse(user, symbol)) + return false; + + bool isConstant = false; + bool isFunction = false; + + if (auto globalVarOp = dyn_cast<acc::GlobalVariableOpInterface>(globalOp)) + isConstant = globalVarOp.isConstant(); + + if (isa<FunctionOpInterface>(globalOp)) + isFunction = true; + + // Constants should be kept in device code to ensure they are duplicated. + // Function references should be kept in device code to ensure their device + // addresses are computed. Everything else should be hoisted since we already + // proved they are not valid symbols in GPU region. + return !isConstant && !isFunction; +} + +/// Checks whether it is valid to use acc.declare marking on the global. +bool isValidForAccDeclare(Operation *globalOp) { + // For functions - we use acc.routine marking instead. + return !isa<FunctionOpInterface>(globalOp); +} + +/// Checks whether a recipe operation has meaningful use of its symbol that +/// justifies processing its regions for global references. Returns false if: +/// 1. The recipe has no symbol uses at all, or +/// 2. The only symbol use is the recipe's own symbol definition +template <typename RecipeOpT> +static bool hasRelevantRecipeUse(RecipeOpT &recipeOp, ModuleOp &mod) { + std::optional<SymbolTable::UseRange> symbolUses = recipeOp.getSymbolUses(mod); + + // No recipe symbol uses. + if (!symbolUses.has_value() || symbolUses->empty()) + return false; + + // If more than one use, assume it's used. + auto begin = symbolUses->begin(); + auto end = symbolUses->end(); + if (begin != end && std::next(begin) != end) + return true; + + // If single use, check if the use is the recipe itself. + const SymbolTable::SymbolUse &use = *symbolUses->begin(); + return use.getUser() != recipeOp.getOperation(); +} + +// Hoists addr_of operations for non-constant globals out of OpenACC regions. +// This way - they are implicitly mapped instead of being considered for +// implicit declare. +template <typename AccConstructT> +static void hoistNonConstantDirectUses(AccConstructT accOp, + acc::OpenACCSupport &accSupport) { + accOp.walk([&](acc::AddressOfGlobalOpInterface addrOfOp) { + SymbolRefAttr symRef = addrOfOp.getSymbol(); + if (symRef) { + Operation *globalOp = + SymbolTable::lookupNearestSymbolFrom(addrOfOp, symRef); + if (isGlobalUseCandidateForHoisting(globalOp, addrOfOp, symRef, + accSupport)) { + addrOfOp->moveBefore(accOp); + LLVM_DEBUG( + llvm::dbgs() << "Hoisted:\n\t" << addrOfOp << "\n\tfrom:\n\t"; + accOp->print(llvm::dbgs(), + OpPrintingFlags{}.skipRegions().enableDebugInfo()); + llvm::dbgs() << "\n"); + } + } + }); +} + +// Collects the globals referenced in a device region +static void collectGlobalsFromDeviceRegion(Region ®ion, + GlobalOpSetT &globals, + acc::OpenACCSupport &accSupport, + SymbolTable &symTab) { + region.walk([&](Operation *op) { + // 1) Only consider relevant operations which use symbols + auto addrOfOp = dyn_cast<acc::AddressOfGlobalOpInterface>(op); + if (addrOfOp) { + SymbolRefAttr symRef = addrOfOp.getSymbol(); + // 2) Found an operation which uses the symbol. Next determine if it + // is a candidate for `acc declare`. Some of the criteria considered + // is whether this symbol is not already a device one (either because + // acc declare is already used or this is a CUF global). + Operation *globalOp = nullptr; + bool isCandidate = !accSupport.isValidSymbolUse(op, symRef, &globalOp); + // 3) Add the candidate to the set of globals to be `acc declare`d. + if (isCandidate && globalOp && isValidForAccDeclare(globalOp)) + globals.insert(globalOp); + } else if (auto indirectAccessOp = + dyn_cast<acc::IndirectGlobalAccessOpInterface>(op)) { + // Process operations that indirectly access globals + llvm::SmallVector<SymbolRefAttr> symbols; + indirectAccessOp.getReferencedSymbols(symbols, &symTab); + for (SymbolRefAttr symRef : symbols) + if (Operation *globalOp = symTab.lookup(symRef.getLeafReference())) + if (isValidForAccDeclare(globalOp)) + globals.insert(globalOp); + } + }); +} + +// Adds the declare attribute to the operation `op`. +static void addDeclareAttr(MLIRContext *context, Operation *op, + acc::DataClause clause) { + op->setAttr(acc::getDeclareAttrName(), + acc::DeclareAttr::get(context, + acc::DataClauseAttr::get(context, clause))); +} + +// This pass applies implicit declare actions for globals referenced in +// OpenACC compute and routine regions. +class ACCImplicitDeclare + : public acc::impl::ACCImplicitDeclareBase<ACCImplicitDeclare> { +public: + using ACCImplicitDeclareBase<ACCImplicitDeclare>::ACCImplicitDeclareBase; + + void runOnOperation() override { + ModuleOp mod = getOperation(); + MLIRContext *context = &getContext(); + acc::OpenACCSupport &accSupport = getAnalysis<acc::OpenACCSupport>(); + + // 1) Start off by hoisting any AddressOf operations out of acc region + // for any cases we do not want to `acc declare`. This is because we can + // rely on implicit data mapping in majority of cases without uselessly + // polluting the device globals. + mod.walk([&](Operation *op) { + TypeSwitch<Operation *, void>(op) + .Case<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>( + [&](auto accOp) { + hoistNonConstantDirectUses(accOp, accSupport); + }); + }); + + // 2) Collect global symbols which need to be `acc declare`d. Do it for + // compute regions, acc routine, and existing globals with the declare + // attribute. + SymbolTable symTab(mod); + GlobalOpSetT globalsToAccDeclare; + mod.walk([&](Operation *op) { + TypeSwitch<Operation *, void>(op) + .Case<ACC_COMPUTE_CONSTRUCT_OPS, acc::KernelEnvironmentOp>( + [&](auto accOp) { + collectGlobalsFromDeviceRegion( + accOp.getRegion(), globalsToAccDeclare, accSupport, symTab); + }) + .Case<FunctionOpInterface>([&](auto func) { + if ((acc::isAccRoutine(func) || + acc::isSpecializedAccRoutine(func)) && + !func.isExternal()) + collectGlobalsFromDeviceRegion(func.getFunctionBody(), + globalsToAccDeclare, accSupport, + symTab); + }) + .Case<acc::GlobalVariableOpInterface>([&](auto globalVarOp) { + if (globalVarOp->getAttr(acc::getDeclareAttrName())) + if (Region *initRegion = globalVarOp.getInitRegion()) + collectGlobalsFromDeviceRegion(*initRegion, globalsToAccDeclare, + accSupport, symTab); + }) + .Case<acc::PrivateRecipeOp>([&](auto privateRecipe) { + if (hasRelevantRecipeUse(privateRecipe, mod)) { + collectGlobalsFromDeviceRegion(privateRecipe.getInitRegion(), + globalsToAccDeclare, accSupport, + symTab); + collectGlobalsFromDeviceRegion(privateRecipe.getDestroyRegion(), + globalsToAccDeclare, accSupport, + symTab); + } + }) + .Case<acc::FirstprivateRecipeOp>([&](auto firstprivateRecipe) { + if (hasRelevantRecipeUse(firstprivateRecipe, mod)) { + collectGlobalsFromDeviceRegion(firstprivateRecipe.getInitRegion(), + globalsToAccDeclare, accSupport, + symTab); + collectGlobalsFromDeviceRegion( + firstprivateRecipe.getDestroyRegion(), globalsToAccDeclare, + accSupport, symTab); + collectGlobalsFromDeviceRegion(firstprivateRecipe.getCopyRegion(), + globalsToAccDeclare, accSupport, + symTab); + } + }) + .Case<acc::ReductionRecipeOp>([&](auto reductionRecipe) { + if (hasRelevantRecipeUse(reductionRecipe, mod)) { + collectGlobalsFromDeviceRegion(reductionRecipe.getInitRegion(), + globalsToAccDeclare, accSupport, + symTab); + collectGlobalsFromDeviceRegion( + reductionRecipe.getCombinerRegion(), globalsToAccDeclare, + accSupport, symTab); + } + }); + }); + + // 3) Finally, generate the appropriate declare actions needed to ensure + // this is considered for device global. + for (Operation *globalOp : globalsToAccDeclare) { + LLVM_DEBUG( + llvm::dbgs() << "Global is being `acc declare copyin`d: "; + globalOp->print(llvm::dbgs(), + OpPrintingFlags{}.skipRegions().enableDebugInfo()); + llvm::dbgs() << "\n"); + + // Mark it as declare copyin. + addDeclareAttr(context, globalOp, acc::DataClause::acc_copyin); + + // TODO: May need to create the global constructor which does the mapping + // action. It is not yet clear if this is needed yet (since the globals + // might just end up in the GPU image without requiring mapping via + // runtime). + } + } +}; + +} // namespace diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitRoutine.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitRoutine.cpp new file mode 100644 index 0000000..12efaf4 --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitRoutine.cpp @@ -0,0 +1,237 @@ +//===- ACCImplicitRoutine.cpp - OpenACC Implicit Routine Transform -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass implements the implicit rules described in OpenACC specification +// for `Routine Directive` (OpenACC 3.4 spec, section 2.15.1). +// +// "If no explicit routine directive applies to a procedure whose definition +// appears in the program unit being compiled, then the implementation applies +// an implicit routine directive to that procedure if any of the following +// conditions holds: +// - The procedure is called or its address is accessed in a compute region." +// +// The specification further states: +// "When the implementation applies an implicit routine directive to a +// procedure, it must recursively apply implicit routine directives to other +// procedures for which the above rules specify relevant dependencies. Such +// dependencies can form a cycle, so the implementation must take care to avoid +// infinite recursion." +// +// This pass implements these requirements by: +// 1. Walking through all OpenACC compute constructs and functions already +// marked with `acc routine` in the module and identifying function calls +// within these regions. +// 2. Creating implicit `acc.routine` operations for functions that don't +// already have routine declarations. +// 3. Recursively walking through all existing `acc routine` and creating +// implicit routine operations for function calls within these routines, +// while avoiding infinite recursion through proper tracking. +// +// Requirements: +// ------------- +// To use this pass in a pipeline, the following requirements must be met: +// +// 1. Operation Interface Implementation: Operations that define functions +// or call functions should implement `mlir::FunctionOpInterface` and +// `mlir::CallOpInterface` respectively. +// +// 2. Analysis Registration (Optional): If custom behavior is needed for +// determining if a symbol use is valid within GPU regions, the dialect +// should pre-register the `acc::OpenACCSupport` analysis. +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/OpenACC/Transforms/Passes.h" + +#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include <queue> + +#define DEBUG_TYPE "acc-implicit-routine" + +namespace mlir { +namespace acc { +#define GEN_PASS_DEF_ACCIMPLICITROUTINE +#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc" +} // namespace acc +} // namespace mlir + +namespace { + +using namespace mlir; + +class ACCImplicitRoutine + : public acc::impl::ACCImplicitRoutineBase<ACCImplicitRoutine> { +private: + unsigned routineCounter = 0; + static constexpr llvm::StringRef accRoutinePrefix = "acc_routine_"; + + // Count existing routine operations and update counter + void initRoutineCounter(ModuleOp module) { + module.walk([&](acc::RoutineOp routineOp) { routineCounter++; }); + } + + // Check if routine has a default bind clause or a device-type specific bind + // clause. Returns true if `acc routine` has a default bind clause or + // a device-type specific bind clause. + bool isACCRoutineBindDefaultOrDeviceType(acc::RoutineOp op, + acc::DeviceType deviceType) { + // Fast check to avoid device-type specific lookups. + if (!op.getBindIdName() && !op.getBindStrName()) + return false; + return op.getBindNameValue().has_value() || + op.getBindNameValue(deviceType).has_value(); + } + + // Generate a unique name for the routine and create the routine operation + acc::RoutineOp createRoutineOp(OpBuilder &builder, Location loc, + FunctionOpInterface &callee) { + std::string routineName = + (accRoutinePrefix + std::to_string(routineCounter++)).str(); + auto routineOp = acc::RoutineOp::create( + builder, loc, + /* sym_name=*/builder.getStringAttr(routineName), + /* func_name=*/ + mlir::SymbolRefAttr::get(builder.getContext(), + builder.getStringAttr(callee.getName())), + /* bindIdName=*/nullptr, + /* bindStrName=*/nullptr, + /* bindIdNameDeviceType=*/nullptr, + /* bindStrNameDeviceType=*/nullptr, + /* worker=*/nullptr, + /* vector=*/nullptr, + /* seq=*/nullptr, + /* nohost=*/nullptr, + /* implicit=*/builder.getUnitAttr(), + /* gang=*/nullptr, + /* gangDim=*/nullptr, + /* gangDimDeviceType=*/nullptr); + + // Assert that the callee does not already have routine info attribute + assert(!callee->hasAttr(acc::getRoutineInfoAttrName()) && + "function is already associated with a routine"); + + callee->setAttr( + acc::getRoutineInfoAttrName(), + mlir::acc::RoutineInfoAttr::get( + builder.getContext(), + {mlir::SymbolRefAttr::get(builder.getContext(), + builder.getStringAttr(routineName))})); + return routineOp; + } + + // Used to walk through a compute region looking for function calls. + void + implicitRoutineForCallsInComputeRegions(Operation *op, SymbolTable &symTab, + mlir::OpBuilder &builder, + acc::OpenACCSupport &accSupport) { + op->walk([&](CallOpInterface callOp) { + if (!callOp.getCallableForCallee()) + return; + + auto calleeSymbolRef = + dyn_cast<SymbolRefAttr>(callOp.getCallableForCallee()); + // When call is done through ssa value, the callee is not a symbol. + // Skip it because we don't know the call target. + if (!calleeSymbolRef) + return; + + auto callee = symTab.lookup<FunctionOpInterface>( + calleeSymbolRef.getLeafReference().str()); + // If the callee does not exist or is already a valid symbol for GPU + // regions, skip it + + assert(callee && "callee function must be found in symbol table"); + if (accSupport.isValidSymbolUse(callOp.getOperation(), calleeSymbolRef)) + return; + builder.setInsertionPoint(callee); + createRoutineOp(builder, callee.getLoc(), callee); + }); + } + + // Recursively handle calls within a routine operation + void implicitRoutineForCallsInRoutine(acc::RoutineOp routineOp, + mlir::OpBuilder &builder, + acc::OpenACCSupport &accSupport, + acc::DeviceType targetDeviceType) { + // When bind clause is used, it means that the target is different than the + // function to which the `acc routine` is used with. Skip this case to + // avoid implicitly recursively marking calls that would not end up on + // device. + if (isACCRoutineBindDefaultOrDeviceType(routineOp, targetDeviceType)) + return; + + SymbolTable symTab(routineOp->getParentOfType<ModuleOp>()); + std::queue<acc::RoutineOp> routineQueue; + routineQueue.push(routineOp); + while (!routineQueue.empty()) { + auto currentRoutine = routineQueue.front(); + routineQueue.pop(); + auto func = symTab.lookup<FunctionOpInterface>( + currentRoutine.getFuncName().getLeafReference()); + func.walk([&](CallOpInterface callOp) { + if (!callOp.getCallableForCallee()) + return; + + auto calleeSymbolRef = + dyn_cast<SymbolRefAttr>(callOp.getCallableForCallee()); + // When call is done through ssa value, the callee is not a symbol. + // Skip it because we don't know the call target. + if (!calleeSymbolRef) + return; + + auto callee = symTab.lookup<FunctionOpInterface>( + calleeSymbolRef.getLeafReference().str()); + // If the callee does not exist or is already a valid symbol for GPU + // regions, skip it + assert(callee && "callee function must be found in symbol table"); + if (accSupport.isValidSymbolUse(callOp.getOperation(), calleeSymbolRef)) + return; + builder.setInsertionPoint(callee); + auto newRoutineOp = createRoutineOp(builder, callee.getLoc(), callee); + routineQueue.push(newRoutineOp); + }); + } + } + +public: + using ACCImplicitRoutineBase<ACCImplicitRoutine>::ACCImplicitRoutineBase; + + void runOnOperation() override { + auto module = getOperation(); + mlir::OpBuilder builder(module.getContext()); + SymbolTable symTab(module); + initRoutineCounter(module); + + acc::OpenACCSupport &accSupport = getAnalysis<acc::OpenACCSupport>(); + + // Handle compute regions + module.walk([&](Operation *op) { + if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(op)) + implicitRoutineForCallsInComputeRegions(op, symTab, builder, + accSupport); + }); + + // Use the device type option from the pass options. + acc::DeviceType targetDeviceType = deviceType; + + // Handle existing routines + module.walk([&](acc::RoutineOp routineOp) { + implicitRoutineForCallsInRoutine(routineOp, builder, accSupport, + targetDeviceType); + }); + } +}; + +} // namespace diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp new file mode 100644 index 0000000..f41ce276 --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp @@ -0,0 +1,117 @@ +//===- ACCLegalizeSerial.cpp - Legalize ACC Serial region -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass converts acc.serial into acc.parallel with num_gangs(1) +// num_workers(1) vector_length(1). +// +// This transformation simplifies processing of acc regions by unifying the +// handling of serial and parallel constructs. Since an OpenACC serial region +// executes sequentially (like a parallel region with a single gang, worker, and +// vector), this conversion is semantically equivalent while enabling code reuse +// in later compilation stages. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/OpenACC/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/Debug.h" + +namespace mlir { +namespace acc { +#define GEN_PASS_DEF_ACCLEGALIZESERIAL +#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc" +} // namespace acc +} // namespace mlir + +#define DEBUG_TYPE "acc-legalize-serial" + +namespace { +using namespace mlir; + +struct ACCSerialOpConversion : public OpRewritePattern<acc::SerialOp> { + using OpRewritePattern<acc::SerialOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(acc::SerialOp serialOp, + PatternRewriter &rewriter) const override { + + const Location loc = serialOp.getLoc(); + + // Create a container holding the constant value of 1 for use as the + // num_gangs, num_workers, and vector_length attributes. + llvm::SmallVector<mlir::Value> numValues; + auto value = arith::ConstantIntOp::create(rewriter, loc, 1, 32); + numValues.push_back(value); + + // Since num_gangs is specified as both attributes and values, create a + // segment attribute. + llvm::SmallVector<int32_t> numGangsSegments; + numGangsSegments.push_back(numValues.size()); + auto gangSegmentsAttr = rewriter.getDenseI32ArrayAttr(numGangsSegments); + + // Create a device_type attribute set to `none` which ensures that + // the parallel dimensions specification applies to the default clauses. + llvm::SmallVector<mlir::Attribute> crtDeviceTypes; + auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get( + rewriter.getContext(), mlir::acc::DeviceType::None); + crtDeviceTypes.push_back(crtDeviceTypeAttr); + auto devTypeAttr = + mlir::ArrayAttr::get(rewriter.getContext(), crtDeviceTypes); + + LLVM_DEBUG(llvm::dbgs() << "acc.serial OP: " << serialOp << "\n"); + + // Create a new acc.parallel op with the same operands - except include the + // num_gangs, num_workers, and vector_length attributes. + acc::ParallelOp parOp = acc::ParallelOp::create( + rewriter, loc, serialOp.getAsyncOperands(), + serialOp.getAsyncOperandsDeviceTypeAttr(), serialOp.getAsyncOnlyAttr(), + serialOp.getWaitOperands(), serialOp.getWaitOperandsSegmentsAttr(), + serialOp.getWaitOperandsDeviceTypeAttr(), + serialOp.getHasWaitDevnumAttr(), serialOp.getWaitOnlyAttr(), numValues, + gangSegmentsAttr, devTypeAttr, numValues, devTypeAttr, numValues, + devTypeAttr, serialOp.getIfCond(), serialOp.getSelfCond(), + serialOp.getSelfAttrAttr(), serialOp.getReductionOperands(), + serialOp.getPrivateOperands(), serialOp.getFirstprivateOperands(), + serialOp.getDataClauseOperands(), serialOp.getDefaultAttrAttr(), + serialOp.getCombinedAttr()); + + parOp.getRegion().takeBody(serialOp.getRegion()); + + LLVM_DEBUG(llvm::dbgs() << "acc.parallel OP: " << parOp << "\n"); + rewriter.replaceOp(serialOp, parOp); + + return success(); + } +}; + +class ACCLegalizeSerial + : public mlir::acc::impl::ACCLegalizeSerialBase<ACCLegalizeSerial> { +public: + using ACCLegalizeSerialBase<ACCLegalizeSerial>::ACCLegalizeSerialBase; + void runOnOperation() override { + func::FuncOp funcOp = getOperation(); + MLIRContext *context = funcOp.getContext(); + RewritePatternSet patterns(context); + patterns.insert<ACCSerialOpConversion>(context); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); + } +}; + +} // namespace diff --git a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt index f8fff59..10a1796 100644 --- a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt @@ -1,5 +1,8 @@ add_mlir_dialect_library(MLIROpenACCTransforms ACCImplicitData.cpp + ACCImplicitDeclare.cpp + ACCImplicitRoutine.cpp + ACCLegalizeSerial.cpp LegalizeDataValues.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp index fbac28e..7f27b44 100644 --- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp +++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp @@ -9,8 +9,13 @@ #include "mlir/Dialect/OpenACC/OpenACCUtils.h" #include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/Support/Casting.h" mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region ®ion) { @@ -155,3 +160,109 @@ mlir::Value mlir::acc::getBaseEntity(mlir::Value val) { return val; } + +bool mlir::acc::isValidSymbolUse(mlir::Operation *user, + mlir::SymbolRefAttr symbol, + mlir::Operation **definingOpPtr) { + mlir::Operation *definingOp = + mlir::SymbolTable::lookupNearestSymbolFrom(user, symbol); + + // If there are no defining ops, we have no way to ensure validity because + // we cannot check for any attributes. + if (!definingOp) + return false; + + if (definingOpPtr) + *definingOpPtr = definingOp; + + // Check if the defining op is a recipe (private, reduction, firstprivate). + // Recipes are valid as they get materialized before being offloaded to + // device. They are only instructions for how to materialize. + if (mlir::isa<mlir::acc::PrivateRecipeOp, mlir::acc::ReductionRecipeOp, + mlir::acc::FirstprivateRecipeOp>(definingOp)) + return true; + + // Check if the defining op is a function + if (auto func = + mlir::dyn_cast_if_present<mlir::FunctionOpInterface>(definingOp)) { + // If this symbol is actually an acc routine - then it is expected for it + // to be offloaded - therefore it is valid. + if (func->hasAttr(mlir::acc::getRoutineInfoAttrName())) + return true; + + // If this symbol is a call to an LLVM intrinsic, then it is likely valid. + // Check the following: + // 1. The function is private + // 2. The function has no body + // 3. Name starts with "llvm." + // 4. The function's name is a valid LLVM intrinsic name + if (func.getVisibility() == mlir::SymbolTable::Visibility::Private && + func.getFunctionBody().empty() && func.getName().starts_with("llvm.") && + llvm::Intrinsic::lookupIntrinsicID(func.getName()) != + llvm::Intrinsic::not_intrinsic) + return true; + } + + // A declare attribute is needed for symbol references. + bool hasDeclare = definingOp->hasAttr(mlir::acc::getDeclareAttrName()); + return hasDeclare; +} + +llvm::SmallVector<mlir::Value> +mlir::acc::getDominatingDataClauses(mlir::Operation *computeConstructOp, + mlir::DominanceInfo &domInfo, + mlir::PostDominanceInfo &postDomInfo) { + llvm::SmallSetVector<mlir::Value, 8> dominatingDataClauses; + + llvm::TypeSwitch<mlir::Operation *>(computeConstructOp) + .Case<mlir::acc::ParallelOp, mlir::acc::KernelsOp, mlir::acc::SerialOp>( + [&](auto op) { + for (auto dataClause : op.getDataClauseOperands()) { + dominatingDataClauses.insert(dataClause); + } + }) + .Default([](mlir::Operation *) {}); + + // Collect the data clauses from enclosing data constructs. + mlir::Operation *currParentOp = computeConstructOp->getParentOp(); + while (currParentOp) { + if (mlir::isa<mlir::acc::DataOp>(currParentOp)) { + for (auto dataClause : mlir::dyn_cast<mlir::acc::DataOp>(currParentOp) + .getDataClauseOperands()) { + dominatingDataClauses.insert(dataClause); + } + } + currParentOp = currParentOp->getParentOp(); + } + + // Find the enclosing function/subroutine + auto funcOp = + computeConstructOp->getParentOfType<mlir::FunctionOpInterface>(); + if (!funcOp) + return dominatingDataClauses.takeVector(); + + // Walk the function to find `acc.declare_enter`/`acc.declare_exit` pairs that + // dominate and post-dominate the compute construct and add their data + // clauses to the list. + funcOp->walk([&](mlir::acc::DeclareEnterOp declareEnterOp) { + if (domInfo.dominates(declareEnterOp.getOperation(), computeConstructOp)) { + // Collect all `acc.declare_exit` ops for this token. + llvm::SmallVector<mlir::acc::DeclareExitOp> exits; + for (auto *user : declareEnterOp.getToken().getUsers()) + if (auto declareExit = mlir::dyn_cast<mlir::acc::DeclareExitOp>(user)) + exits.push_back(declareExit); + + // Only add clauses if every `acc.declare_exit` op post-dominates the + // compute construct. + if (!exits.empty() && + llvm::all_of(exits, [&](mlir::acc::DeclareExitOp exitOp) { + return postDomInfo.postDominates(exitOp, computeConstructOp); + })) { + for (auto dataClause : declareEnterOp.getDataClauseOperands()) + dominatingDataClauses.insert(dataClause); + } + } + }); + + return dominatingDataClauses.takeVector(); +} diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 1b069c6..103295d 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -617,6 +617,7 @@ parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, break; case ClauseScheduleKind::Auto: case ClauseScheduleKind::Runtime: + case ClauseScheduleKind::Distribute: chunkSize = std::nullopt; } @@ -1817,6 +1818,9 @@ static ParseResult parseMapClause(OpAsmParser &parser, if (mapTypeMod == "ref_ptr_ptee") mapTypeBits |= ClauseMapFlags::ref_ptr_ptee; + if (mapTypeMod == "is_device_ptr") + mapTypeBits |= ClauseMapFlags::is_device_ptr; + return success(); }; @@ -1886,6 +1890,8 @@ static void printMapClause(OpAsmPrinter &p, Operation *op, mapTypeStrs.push_back("ref_ptee"); if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr_ptee)) mapTypeStrs.push_back("ref_ptr_ptee"); + if (mapTypeToBool(mapFlags, ClauseMapFlags::is_device_ptr)) + mapTypeStrs.push_back("is_device_ptr"); if (mapFlags == ClauseMapFlags::none) mapTypeStrs.push_back("none"); @@ -2824,6 +2830,7 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state, ArrayRef<NamedAttribute> attributes) { build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{}, /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(), + /*linear_var_types*/ nullptr, /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr, /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr, /*private_needs_barrier=*/false, @@ -2842,8 +2849,8 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state, WsloopOp::build( builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars, - clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod, - clauses.ordered, clauses.privateVars, + clauses.linearStepVars, clauses.linearVarTypes, clauses.nowait, + clauses.order, clauses.orderMod, clauses.ordered, clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars, makeDenseBoolArrayAttr(ctx, clauses.reductionByref), @@ -2888,17 +2895,16 @@ LogicalResult WsloopOp::verifyRegions() { void SimdOp::build(OpBuilder &builder, OperationState &state, const SimdOperands &clauses) { MLIRContext *ctx = builder.getContext(); - // TODO Store clauses in op: linearVars, linearStepVars - SimdOp::build(builder, state, clauses.alignedVars, - makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr, - /*linear_vars=*/{}, /*linear_step_vars=*/{}, - clauses.nontemporalVars, clauses.order, clauses.orderMod, - clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms), - clauses.privateNeedsBarrier, clauses.reductionMod, - clauses.reductionVars, - makeDenseBoolArrayAttr(ctx, clauses.reductionByref), - makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen, - clauses.simdlen); + SimdOp::build( + builder, state, clauses.alignedVars, + makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr, + clauses.linearVars, clauses.linearStepVars, clauses.linearVarTypes, + clauses.nontemporalVars, clauses.order, clauses.orderMod, + clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms), + clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars, + makeDenseBoolArrayAttr(ctx, clauses.reductionByref), + makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen, + clauses.simdlen); } LogicalResult SimdOp::verify() { diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt index 423e1c3..b111117 100644 --- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt @@ -19,5 +19,5 @@ add_mlir_dialect_library(MLIRSCFDialect MLIRSideEffectInterfaces MLIRTensorDialect MLIRValueBoundsOpInterface + MLIRTransformUtils ) - diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 881e256..c4bd31f 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -26,6 +26,7 @@ #include "mlir/Interfaces/ParallelCombiningOpInterface.h" #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Transforms/InliningUtils.h" +#include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" @@ -3687,6 +3688,133 @@ LogicalResult scf::WhileOp::verify() { } namespace { +/// Move a scf.if op that is directly before the scf.condition op in the while +/// before region, and whose condition matches the condition of the +/// scf.condition op, down into the while after region. +/// +/// scf.while (..) : (...) -> ... { +/// %additional_used_values = ... +/// %cond = ... +/// ... +/// %res = scf.if %cond -> (...) { +/// use(%additional_used_values) +/// ... // then block +/// scf.yield %then_value +/// } else { +/// scf.yield %else_value +/// } +/// scf.condition(%cond) %res, ... +/// } do { +/// ^bb0(%res_arg, ...): +/// use(%res_arg) +/// ... +/// +/// becomes +/// scf.while (..) : (...) -> ... { +/// %additional_used_values = ... +/// %cond = ... +/// ... +/// scf.condition(%cond) %else_value, ..., %additional_used_values +/// } do { +/// ^bb0(%res_arg ..., %additional_args): : +/// use(%additional_args) +/// ... // if then block +/// use(%then_value) +/// ... +struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> { + using OpRewritePattern<scf::WhileOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::WhileOp op, + PatternRewriter &rewriter) const override { + auto conditionOp = op.getConditionOp(); + + // Only support ifOp right before the condition at the moment. Relaxing this + // would require to: + // - check that the body does not have side-effects conflicting with + // operations between the if and the condition. + // - check that results of the if operation are only used as arguments to + // the condition. + auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode()); + + // Check that the ifOp is directly before the conditionOp and that it + // matches the condition of the conditionOp. Also ensure that the ifOp has + // no else block with content, as that would complicate the transformation. + // TODO: support else blocks with content. + if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() || + (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty())) + return failure(); + + assert((ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) && + *ifOp->user_begin() == conditionOp)) && + "ifOp has unexpected uses"); + + Location loc = op.getLoc(); + + // Replace uses of ifOp results in the conditionOp with the yielded values + // from the ifOp branches. + for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) { + auto it = llvm::find(ifOp->getResults(), arg); + if (it != ifOp->getResults().end()) { + size_t ifOpIdx = it.getIndex(); + Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx); + Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx); + + rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue); + rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue); + } + } + + // Collect additional used values from before region. + SetVector<Value> additionalUsedValuesSet; + visitUsedValuesDefinedAbove(ifOp.getThenRegion(), [&](OpOperand *operand) { + if (&op.getBefore() == operand->get().getParentRegion()) + additionalUsedValuesSet.insert(operand->get()); + }); + + // Create new whileOp with additional used values as results. + auto additionalUsedValues = additionalUsedValuesSet.getArrayRef(); + auto additionalValueTypes = llvm::map_to_vector( + additionalUsedValues, [](Value val) { return val.getType(); }); + size_t additionalValueSize = additionalUsedValues.size(); + SmallVector<Type> newResultTypes(op.getResultTypes()); + newResultTypes.append(additionalValueTypes); + + auto newWhileOp = + scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits()); + + rewriter.modifyOpInPlace(newWhileOp, [&] { + newWhileOp.getBefore().takeBody(op.getBefore()); + newWhileOp.getAfter().takeBody(op.getAfter()); + newWhileOp.getAfter().addArguments( + additionalValueTypes, + SmallVector<Location>(additionalValueSize, loc)); + }); + + rewriter.modifyOpInPlace(conditionOp, [&] { + conditionOp.getArgsMutable().append(additionalUsedValues); + }); + + // Replace uses of additional used values inside the ifOp then region with + // the whileOp after region arguments. + rewriter.replaceUsesWithIf( + additionalUsedValues, + newWhileOp.getAfterArguments().take_back(additionalValueSize), + [&](OpOperand &use) { + return ifOp.getThenRegion().isAncestor( + use.getOwner()->getParentRegion()); + }); + + // Inline ifOp then region into new whileOp after region. + rewriter.eraseOp(ifOp.thenYield()); + rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(), + newWhileOp.getAfterBody()->begin()); + rewriter.eraseOp(ifOp); + rewriter.replaceOp(op, + newWhileOp->getResults().drop_back(additionalValueSize)); + return success(); + } +}; + /// Replace uses of the condition within the do block with true, since otherwise /// the block would not be evaluated. /// @@ -4343,7 +4471,7 @@ struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> { LogicalResult matchAndRewrite(WhileOp loop, PatternRewriter &rewriter) const override { - auto oldBefore = loop.getBeforeBody(); + auto *oldBefore = loop.getBeforeBody(); ConditionOp oldTerm = loop.getConditionOp(); ValueRange beforeArgs = oldBefore->getArguments(); ValueRange termArgs = oldTerm.getArgs(); @@ -4364,7 +4492,7 @@ struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> { beforeArgs); } - auto oldAfter = loop.getAfterBody(); + auto *oldAfter = loop.getAfterBody(); SmallVector<Type> newResultTypes(beforeArgs.size()); for (auto &&[i, j] : llvm::enumerate(*mapping)) @@ -4373,8 +4501,8 @@ struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> { auto newLoop = WhileOp::create( rewriter, loop.getLoc(), newResultTypes, loop.getInits(), /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr); - auto newBefore = newLoop.getBeforeBody(); - auto newAfter = newLoop.getAfterBody(); + auto *newBefore = newLoop.getBeforeBody(); + auto *newAfter = newLoop.getAfterBody(); SmallVector<Value> newResults(beforeArgs.size()); SmallVector<Value> newAfterArgs(beforeArgs.size()); @@ -4399,7 +4527,8 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add<RemoveLoopInvariantArgsFromBeforeBlock, RemoveLoopInvariantValueYielded, WhileConditionTruth, WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults, - WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context); + WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>( + context); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 29b770f..009c2c3 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -1092,7 +1092,7 @@ static LogicalResult addInitOperandsToLoopNest( for (auto [outerLoop, innerLoop] : llvm::zip_equal(loops.drop_back(), loops.drop_front())) { // Again assume that all the outer loops are scf.for operations. - auto outerForLoop = cast<scf::ForOp>(outerLoop); + auto outerForLoop = cast<scf::ForOp>(outerLoop.getOperation()); auto outerLoopYield = cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator()); SmallVector<Value> newYields = @@ -2184,61 +2184,24 @@ cloneAsInsertSlices(RewriterBase &rewriter, return clonedSlices; } -/// Implementation of fusing consumer of a single slice by computing the -/// slice of the consumer in-place for scf loop. -FailureOr<scf::SCFFuseConsumerOfSliceResult> -mlir::scf::tileAndFuseConsumerOfSlices( - RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices, - MutableArrayRef<LoopLikeOpInterface> loops) { - if (candidateSlices.empty()) { - return rewriter.notifyMatchFailure( - rewriter.getUnknownLoc(), - "no candidate slices provided for consumer fusion"); - } - // Return if `loops` is empty, return an error for now. Caller is expected - // to handle this case. - if (loops.empty()) { - return rewriter.notifyMatchFailure( - candidateSlices.front(), - "cannot call tile and fuse consumer with an empty loop nest"); - } +static FailureOr<scf::SCFFuseConsumerOfSliceResult> +tileAndFuseConsumerOfSlicesImpl(RewriterBase &rewriter, Operation *consumerOp, + ArrayRef<OpOperand *> consumerOpOperands, + ArrayRef<Operation *> candidateSlices, + MutableArrayRef<LoopLikeOpInterface> loops) { + assert(!loops.empty() && "expected loops to be not empty"); - if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) || - llvm::all_of(candidateSlices, - llvm::IsaPred<tensor::ParallelInsertSliceOp>))) { + // 1. Check assumption for loop with `reorderOperations` disabled. + if (failed(checkAssumptionForLoop(loops.front(), consumerOp, false))) { return rewriter.notifyMatchFailure( - candidateSlices.front(), - "candidates slices need to be all `tensor.extract_slice`s or " - "`tensor.parallel_insert_slice`s"); - } - - // 1. Get the consumer of scf.for for the result yielded by - // tensor.insert_slice/parallel_insert_slice. - SmallVector<OpOperand *> consumerOpOperands; - Operation *consumerOp; - { - FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand = - getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops); - if (failed(maybeConsumerOpOperand)) { - return rewriter.notifyMatchFailure(candidateSlices.front(), - "could not fetch consumer to fuse"); - } - std::swap(consumerOpOperands, maybeConsumerOpOperand.value()); - consumerOp = consumerOpOperands.front()->getOwner(); + loops.front(), "the first user of loop should not dominate any define " + "of consumer operand(s)"); } LoopLikeOpInterface outerMostLoop = loops.front(); LoopLikeOpInterface innerMostLoop = loops.back(); - // Check assumption for loop with `reorderOperations` disabled. - if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) { - return rewriter.notifyMatchFailure( - outerMostLoop, "the first user of loop should not dominate any define " - "of consumer operand(s)"); - } - OpBuilder::InsertionGuard g(rewriter); - // 2. Check consumer is not using scf loop's output as init. auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp); if (!dstOp) @@ -2428,11 +2391,166 @@ mlir::scf::tileAndFuseConsumerOfSlices( llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) { return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum); }); + auto consumerOpOperandsVec = llvm::to_vector(consumerOpOperands); return scf::SCFFuseConsumerOfSliceResult{ - std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands), + std::move(consumerOpOperandsVec), std::move(tiledAndFusedOpOperands), std::move(tileAndFuseResult->tiledOps)}; } +/// Implementation of fusing consumer of a single slice by computing the +/// slice of the consumer in-place for scf loop. +FailureOr<scf::SCFFuseConsumerOfSliceResult> +mlir::scf::tileAndFuseConsumerOfSlices( + RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices, + MutableArrayRef<LoopLikeOpInterface> loops) { + if (candidateSlices.empty()) { + return rewriter.notifyMatchFailure( + rewriter.getUnknownLoc(), + "no candidate slices provided for consumer fusion"); + } + // Return if `loops` is empty, return an error for now. Caller is expected + // to handle this case. + if (loops.empty()) { + return rewriter.notifyMatchFailure( + candidateSlices.front(), + "cannot call tile and fuse consumer with an empty loop nest"); + } + + if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) || + llvm::all_of(candidateSlices, + llvm::IsaPred<tensor::ParallelInsertSliceOp>))) { + return rewriter.notifyMatchFailure( + candidateSlices.front(), + "candidates slices need to be all `tensor.extract_slice`s or " + "`tensor.parallel_insert_slice`s"); + } + + // Get the consumer of scf.for for the result yielded by + // tensor.insert_slice/parallel_insert_slice. + FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperands = + getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops); + if (failed(maybeConsumerOpOperands)) { + return rewriter.notifyMatchFailure(candidateSlices.front(), + "could not fetch consumer to fuse"); + } + Operation *consumerOp = maybeConsumerOpOperands->front()->getOwner(); + + return tileAndFuseConsumerOfSlicesImpl(rewriter, consumerOp, + maybeConsumerOpOperands.value(), + candidateSlices, loops); +} + +/// For a given `result` of a `forallOp` return the +/// `tensor.parallel_insert_slice` op (or combining op) that is used to +/// construct this result. +static std::optional<Operation *> +getProducingParallelInsertSlice(scf::ForallOp forallOp, OpResult result) { + if (result.getOwner() != forallOp) + return std::nullopt; + BlockArgument bbArg = forallOp.getTiedBlockArgument(result); + SmallVector<Operation *> combiningOps = forallOp.getCombiningOps(bbArg); + // If the number of combining ops is not 1, then this is unexpected. Return + // nullopt. + if (combiningOps.size() != 1) + return std::nullopt; + return combiningOps[0]; +} + +/// For a given result of the loop nest that is a tiled loop nest, return the +/// insert slice-like op that is used for consumer fusion +static std::optional<Operation *> +getProducingInsertSliceLikeOp(OpResult result, + ArrayRef<LoopLikeOpInterface> loops) { + assert(!loops.empty() && "Expected loops to be not empty"); + LoopLikeOpInterface outerMostLoop = loops.front(); + if (auto forallOp = dyn_cast<scf::ForallOp>(outerMostLoop.getOperation())) { + assert(loops.size() == 1 && + "expected only a single loop when tiling using scf.forall"); + return getProducingParallelInsertSlice(forallOp, result); + } + // Assume that the loop nest is a nested `scf.for` that is created through + // tiling and retrieve the `tensor.insert_slice` operation used to construct + // the result. + while (loops.size() != 1) { + LoopLikeOpInterface loop = loops.front(); + if (result.getOwner() != loop) + return std::nullopt; + auto forOp = dyn_cast<scf::ForOp>(loop.getOperation()); + if (!forOp) + return std::nullopt; + auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); + auto innerForResult = + dyn_cast<OpResult>(yieldOp.getOperand(result.getResultNumber())); + if (!innerForResult) + return std::nullopt; + result = innerForResult; + loops = loops.drop_front(); + } + LoopLikeOpInterface loop = loops.front(); + if (result.getOwner() != loop) + return std::nullopt; + auto forOp = dyn_cast<scf::ForOp>(loop.getOperation()); + if (!forOp) + return std::nullopt; + auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); + auto insertSliceOp = yieldOp.getOperand(result.getResultNumber()) + .getDefiningOp<tensor::InsertSliceOp>(); + if (!insertSliceOp) + return std::nullopt; + return insertSliceOp; +} + +FailureOr<scf::SCFFuseConsumerOfSliceResult> +mlir::scf::tileAndFuseConsumer(RewriterBase &rewriter, Operation *consumer, + MutableArrayRef<LoopLikeOpInterface> loops) { + if (!isa<TilingInterface>(consumer)) { + return rewriter.notifyMatchFailure( + consumer, "unhandled consumer that does not implement TilingInterface"); + } + + // Return if `loops` is empty, return an error for now. Caller is expected + // to handle this case. + if (loops.empty()) { + return rewriter.notifyMatchFailure( + consumer, "cannot call tile and fuse consumer with an empty loop nest"); + } + + LoopLikeOpInterface outermostLoop = loops.front(); + + // Collect the operands of the consumer that come from the outermost loop of + // the loop nest. + SmallVector<OpOperand *> consumerFusableOperands; + for (OpOperand &opOperand : consumer->getOpOperands()) { + if (opOperand.get().getDefiningOp() == outermostLoop) { + consumerFusableOperands.push_back(&opOperand); + } + } + + // Nothing to fuse. Just return an empty set. + if (consumerFusableOperands.empty()) { + return mlir::scf::SCFFuseConsumerOfSliceResult{consumerFusableOperands, + SmallVector<OpOperand *>{}, + SmallVector<Operation *>{}}; + } + + // Collect the relevant tensor.insert_slice/tensor.parallel_insert_slices + // for fusion. + SmallVector<Operation *> candidateSlices; + candidateSlices.reserve(consumerFusableOperands.size()); + for (OpOperand *opOperand : consumerFusableOperands) { + std::optional<Operation *> slice = + getProducingInsertSliceLikeOp(cast<OpResult>(opOperand->get()), loops); + if (!slice) { + return rewriter.notifyMatchFailure( + consumer, + "couldnt find producing insert-slice like operation for operand"); + } + candidateSlices.push_back(slice.value()); + } + return tileAndFuseConsumerOfSlicesImpl( + rewriter, consumer, consumerFusableOperands, candidateSlices, loops); +} + //===----------------------------------------------------------------------===// // lowerToLoopsUsingSCFForOp implementation. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp index f0b46e6..a846d7e 100644 --- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp @@ -220,6 +220,89 @@ MutableOperandRange FunctionCallOp::getArgOperandsMutable() { } //===----------------------------------------------------------------------===// +// spirv.Switch +//===----------------------------------------------------------------------===// + +void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector, + Block *defaultTarget, ValueRange defaultOperands, + DenseIntElementsAttr literals, BlockRange targets, + ArrayRef<ValueRange> targetOperands) { + build(builder, result, selector, defaultOperands, targetOperands, literals, + defaultTarget, targets); +} + +void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector, + Block *defaultTarget, ValueRange defaultOperands, + ArrayRef<APInt> literals, BlockRange targets, + ArrayRef<ValueRange> targetOperands) { + DenseIntElementsAttr literalsAttr; + if (!literals.empty()) { + ShapedType literalType = VectorType::get( + static_cast<int64_t>(literals.size()), selector.getType()); + literalsAttr = DenseIntElementsAttr::get(literalType, literals); + } + build(builder, result, selector, defaultTarget, defaultOperands, literalsAttr, + targets, targetOperands); +} + +void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector, + Block *defaultTarget, ValueRange defaultOperands, + ArrayRef<int32_t> literals, BlockRange targets, + ArrayRef<ValueRange> targetOperands) { + DenseIntElementsAttr literalsAttr; + if (!literals.empty()) { + ShapedType literalType = VectorType::get( + static_cast<int64_t>(literals.size()), selector.getType()); + literalsAttr = DenseIntElementsAttr::get(literalType, literals); + } + build(builder, result, selector, defaultTarget, defaultOperands, literalsAttr, + targets, targetOperands); +} + +LogicalResult SwitchOp::verify() { + std::optional<DenseIntElementsAttr> literals = getLiterals(); + BlockRange targets = getTargets(); + + if (!literals && targets.empty()) + return success(); + + Type selectorType = getSelector().getType(); + Type literalType = literals->getType().getElementType(); + if (literalType != selectorType) + return emitOpError() << "'selector' type (" << selectorType + << ") should match literals type (" << literalType + << ")"; + + if (literals && literals->size() != static_cast<int64_t>(targets.size())) + return emitOpError() << "number of literals (" << literals->size() + << ") should match number of targets (" + << targets.size() << ")"; + return success(); +} + +SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) { + assert(index < getNumSuccessors() && "invalid successor index"); + return SuccessorOperands(index == 0 ? getDefaultOperandsMutable() + : getTargetOperandsMutable(index - 1)); +} + +Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) { + std::optional<DenseIntElementsAttr> literals = getLiterals(); + + if (!literals) + return getDefaultTarget(); + + SuccessorRange targets = getTargets(); + if (auto value = dyn_cast_or_null<IntegerAttr>(operands.front())) { + for (auto [index, literal] : llvm::enumerate(literals->getValues<APInt>())) + if (literal == value.getValue()) + return targets[index]; + return getDefaultTarget(); + } + return nullptr; +} + +//===----------------------------------------------------------------------===// // spirv.mlir.loop //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp index 2f3a28f..8575487 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp @@ -81,6 +81,83 @@ static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp, } } +/// Adapted from the cf.switch implementation. +/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)? +/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )* +static ParseResult parseSwitchOpCases( + OpAsmParser &parser, Type &selectorType, Block *&defaultTarget, + SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands, + SmallVectorImpl<Type> &defaultOperandTypes, DenseIntElementsAttr &literals, + SmallVectorImpl<Block *> &targets, + SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> + &targetOperands, + SmallVectorImpl<SmallVector<Type>> &targetOperandTypes) { + if (parser.parseKeyword("default") || parser.parseColon() || + parser.parseSuccessor(defaultTarget)) + return failure(); + if (succeeded(parser.parseOptionalLParen())) { + if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None, + /*allowResultNumber=*/false) || + parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen()) + return failure(); + } + + SmallVector<APInt> values; + unsigned bitWidth = selectorType.getIntOrFloatBitWidth(); + while (succeeded(parser.parseOptionalComma())) { + int64_t value = 0; + if (failed(parser.parseInteger(value))) + return failure(); + values.push_back(APInt(bitWidth, value, /*isSigned=*/true)); + + Block *target; + SmallVector<OpAsmParser::UnresolvedOperand> operands; + SmallVector<Type> operandTypes; + if (failed(parser.parseColon()) || failed(parser.parseSuccessor(target))) + return failure(); + if (succeeded(parser.parseOptionalLParen())) { + if (failed(parser.parseOperandList(operands, + OpAsmParser::Delimiter::None)) || + failed(parser.parseColonTypeList(operandTypes)) || + failed(parser.parseRParen())) + return failure(); + } + targets.push_back(target); + targetOperands.emplace_back(operands); + targetOperandTypes.emplace_back(operandTypes); + } + + if (!values.empty()) { + ShapedType literalType = + VectorType::get(static_cast<int64_t>(values.size()), selectorType); + literals = DenseIntElementsAttr::get(literalType, values); + } + return success(); +} + +static void +printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type selectorType, + Block *defaultTarget, OperandRange defaultOperands, + TypeRange defaultOperandTypes, DenseIntElementsAttr literals, + SuccessorRange targets, OperandRangeRange targetOperands, + const TypeRangeRange &targetOperandTypes) { + p << " default: "; + p.printSuccessorAndUseList(defaultTarget, defaultOperands); + + if (!literals) + return; + + for (auto [index, literal] : llvm::enumerate(literals.getValues<APInt>())) { + p << ','; + p.printNewline(); + p << " "; + p << literal.getLimitedValue(); + p << ": "; + p.printSuccessorAndUseList(targets[index], targetOperands[index]); + } + p.printNewline(); +} + } // namespace mlir::spirv // TablenGen'erated operation definitions. diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index cb9b7f6..f07307f 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -502,6 +502,11 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv, << type << " illegal: cannot handle zero-element tensors\n"); return nullptr; } + if (arrayElemCount > std::numeric_limits<unsigned>::max()) { + LLVM_DEBUG(llvm::dbgs() + << type << " illegal: cannot fit tensor into target type\n"); + return nullptr; + } Type arrayElemType = convertScalarType(targetEnv, options, scalarType); if (!arrayElemType) diff --git a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp index 645cbff..5941f7d 100644 --- a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp +++ b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp @@ -476,38 +476,37 @@ void GridShapeOp::getAsmResultNames( //===----------------------------------------------------------------------===// void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, - FlatSymbolRefAttr grid, - ArrayRef<GridAxesAttr> split_axes, - ArrayRef<int64_t> static_halos, - ArrayRef<int64_t> static_offsets) { + FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> splitAxes, + ArrayRef<int64_t> staticHalos, + ArrayRef<int64_t> staticOffsets) { return build( - b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes), - ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {}, - ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {}); + b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), splitAxes), + ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), {}, + ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticOffsets), {}); } void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, - llvm::StringRef grid, ArrayRef<GridAxesAttr> split_axes, - ArrayRef<int64_t> static_halos, - ArrayRef<int64_t> static_offsets) { + llvm::StringRef grid, ArrayRef<GridAxesAttr> splitAxes, + ArrayRef<int64_t> staticHalos, + ArrayRef<int64_t> staticOffsets) { return build(b, odsState, FlatSymbolRefAttr::get(b.getContext(), grid), - GridAxesArrayAttr::get(b.getContext(), split_axes), - ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {}, - ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), + GridAxesArrayAttr::get(b.getContext(), splitAxes), + ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), {}, + ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticOffsets), {}); } void ShardingOp::build( ::mlir::OpBuilder &b, ::mlir::OperationState &odsState, - FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> split_axes, - ::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes, - ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) { + FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> splitAxes, + ::mlir::ArrayRef<::mlir::OpFoldResult> haloSizes, + ::mlir::ArrayRef<::mlir::OpFoldResult> shardedDimsOffsets) { mlir::SmallVector<int64_t> staticHalos, staticDims; mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims; - dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos); - dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims); + dispatchIndexOpFoldResults(haloSizes, dynamicHalos, staticHalos); + dispatchIndexOpFoldResults(shardedDimsOffsets, dynamicDims, staticDims); return build( - b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes), + b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), splitAxes), ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos, ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims); } @@ -576,7 +575,7 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return failure(); } if (mlir::ShapedType::isDynamicShape(grid->getShape()) && - getStaticShardedDimsOffsets().size() > 0) { + !getStaticShardedDimsOffsets().empty()) { return emitError() << "sharded dims offsets are not allowed for " "device grids with dynamic shape."; } @@ -650,14 +649,14 @@ public: if (dynamicOffs.empty() && !staticOffs.empty()) { assert(staticOffs.size() >= 2); auto diff = staticOffs[1] - staticOffs[0]; - bool all_same = staticOffs.size() > 2; + bool allSame = staticOffs.size() > 2; for (auto i = 2u; i < staticOffs.size(); ++i) { if (staticOffs[i] - staticOffs[i - 1] != diff) { - all_same = false; + allSame = false; break; } } - if (all_same) { + if (allSame) { staticOffs.clear(); modified = true; } @@ -749,7 +748,7 @@ bool Sharding::operator==(const Sharding &rhs) const { bool Sharding::operator!=(const Sharding &rhs) const { return !(*this == rhs); } -Sharding::Sharding(::mlir::FlatSymbolRefAttr grid_) : grid(grid_) {} +Sharding::Sharding(::mlir::FlatSymbolRefAttr grid) : grid(grid) {} Sharding::Sharding(Value rhs) { auto shardingOp = rhs.getDefiningOp<ShardingOp>(); @@ -767,21 +766,20 @@ Sharding::Sharding(Value rhs) { SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets())); } -Sharding Sharding::get(::mlir::FlatSymbolRefAttr grid_, - ArrayRef<GridAxesAttr> split_axes_, - ArrayRef<int64_t> static_halo_sizes_, - ArrayRef<int64_t> static_sharded_dims_offsets_, - ArrayRef<Value> dynamic_halo_sizes_, - ArrayRef<Value> dynamic_sharded_dims_offsets_) { - Sharding res(grid_); - if (split_axes_.empty()) { +Sharding Sharding::get(::mlir::FlatSymbolRefAttr grid, + ArrayRef<GridAxesAttr> splitAxes, + ArrayRef<int64_t> staticHaloSizes, + ArrayRef<int64_t> staticShardedDimsOffsets, + ArrayRef<Value> dynamicHaloSizes, + ArrayRef<Value> dynamicShardedDimsOffsets) { + Sharding res(grid); + if (splitAxes.empty()) { return res; } - res.split_axes.resize(split_axes_.size()); - for (auto [i, axis] : llvm::enumerate(split_axes_)) { - res.split_axes[i] = - GridAxesAttr::get(grid_.getContext(), axis.asArrayRef()); + res.split_axes.resize(splitAxes.size()); + for (auto [i, axis] : llvm::enumerate(splitAxes)) { + res.split_axes[i] = GridAxesAttr::get(grid.getContext(), axis.asArrayRef()); } auto clone = [](const auto src, auto &dst) { @@ -789,10 +787,10 @@ Sharding Sharding::get(::mlir::FlatSymbolRefAttr grid_, llvm::copy(src, dst.begin()); }; - clone(static_halo_sizes_, res.static_halo_sizes); - clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets); - clone(dynamic_halo_sizes_, res.dynamic_halo_sizes); - clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets); + clone(staticHaloSizes, res.static_halo_sizes); + clone(staticShardedDimsOffsets, res.static_sharded_dims_offsets); + clone(dynamicHaloSizes, res.dynamic_halo_sizes); + clone(dynamicShardedDimsOffsets, res.dynamic_sharded_dims_offsets); return res; } @@ -809,10 +807,10 @@ void ShardShapeOp::getAsmResultNames( void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<int64_t> dims, - ArrayRef<Value> dims_dyn, ::mlir::Value sharding, + ArrayRef<Value> dimsDyn, ::mlir::Value sharding, ::mlir::ValueRange device) { SmallVector<mlir::Type> resType(dims.size(), odsBuilder.getIndexType()); - build(odsBuilder, odsState, resType, dims, dims_dyn, sharding, + build(odsBuilder, odsState, resType, dims, dimsDyn, sharding, SmallVector<int64_t>(device.size(), ShapedType::kDynamic), device); } diff --git a/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp index 3bfbf373..f954131 100644 --- a/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp +++ b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp @@ -184,7 +184,7 @@ ReshardingRquirementKind getReshardingRquirementKind( for (auto [result, sharding] : llvm::zip_equal(op->getResults(), resultShardings)) { - for (auto user : result.getUsers()) { + for (auto *user : result.getUsers()) { ShardOp shardOp = llvm::dyn_cast<ShardOp>(user); if (!shardOp) { continue; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp index ae7eef2..9db9814 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp @@ -1365,8 +1365,8 @@ public: arith::SubIOp::create(rewriter, loc, capacity, newSize); Value fillValue = constantZero(rewriter, loc, value.getType()); Value subBuffer = memref::SubViewOp::create( - rewriter, loc, newBuffer, /*offset=*/ValueRange{newSize}, - /*size=*/ValueRange{fillSize}, + rewriter, loc, newBuffer, /*offsets=*/ValueRange{newSize}, + /*sizes=*/ValueRange{fillSize}, /*step=*/ValueRange{constantIndex(rewriter, loc, 1)}); linalg::FillOp::create(rewriter, loc, fillValue, subBuffer); } @@ -1386,8 +1386,8 @@ public: memref::StoreOp::create(rewriter, loc, value, buffer, size); } else { Value subBuffer = memref::SubViewOp::create( - rewriter, loc, buffer, /*offset=*/ValueRange{size}, - /*size=*/ValueRange{n}, + rewriter, loc, buffer, /*offsets=*/ValueRange{size}, + /*sizes=*/ValueRange{n}, /*step=*/ValueRange{constantIndex(rewriter, loc, 1)}); linalg::FillOp::create(rewriter, loc, value, subBuffer); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp index febec6d..23436a6 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp @@ -132,8 +132,8 @@ static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem, SmallVector<Value> scalarArgs(idxs); Value indexVec = idxs.back(); scalarArgs.back() = constantIndex(rewriter, loc, 0); - vector::ScatterOp::create(rewriter, loc, mem, scalarArgs, indexVec, vmask, - rhs); + vector::ScatterOp::create(rewriter, loc, /*resultType=*/nullptr, mem, + scalarArgs, indexVec, vmask, rhs); return; } vector::MaskedStoreOp::create(rewriter, loc, mem, idxs, vmask, rhs); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp index ffa8b40..9904803 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp @@ -80,6 +80,53 @@ inline static bool includesDenseOutput(SortMask mask) { return includesAny(mask, SortMask::kIncludeDenseOutput); } +/// Returns a sparsity rank for loop ordering: lower values indicate +/// dimensions that should be placed in outer loops. +/// 0 = Dense, 1 = Compressed, 2 = Singleton, 3 = Other/Unknown. +static unsigned getLoopSparsityRank(unsigned loop, ArrayRef<Value> allTensors, + ArrayRef<AffineMap> allMaps) { + // Start with highest rank. + unsigned minRank = 3; + + for (auto [tensor, map] : llvm::zip(allTensors, allMaps)) { + // Check if this loop accesses this tensor. + bool loopAccessesTensor = false; + unsigned tensorDim = 0; + for (AffineExpr expr : map.getResults()) { + if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) { + if (dimExpr.getPosition() == loop) { + loopAccessesTensor = true; + break; + } + } + tensorDim++; + } + + if (loopAccessesTensor) { + const auto enc = getSparseTensorEncoding(tensor.getType()); + if (!enc) { + // Dense tensor - lowest rank. + return 0; + } else { + // Sparse tensor - check the level type for this dimension. + auto lvlTypes = enc.getLvlTypes(); + if (tensorDim < lvlTypes.size()) { + auto lvlType = lvlTypes[tensorDim]; + if (isDenseLT(lvlType)) { + return 0; // Dense level. + } else if (isCompressedLT(lvlType)) { + minRank = std::min(minRank, 1u); // Compressed level. + } else if (isSingletonLT(lvlType)) { + minRank = std::min(minRank, 2u); // Singleton level. + } + } + } + } + } + + return minRank; +} + AffineMap IterationGraphSorter::topoSort() { // The sorted result will put the first Reduction iterator to the // latest possible position. @@ -107,10 +154,33 @@ AffineMap IterationGraphSorter::topoSort() { case sparse_tensor::LoopOrderingStrategy::kDefault: src = it.back(); break; + case sparse_tensor::LoopOrderingStrategy::kDenseOuter: { + // Prefer dense, then compressed, then singleton dimensions outermost. + // Create combined tensor and map lists for analysis. + SmallVector<Value> allTensors = ins; + allTensors.push_back(out); + SmallVector<AffineMap> allMaps = loop2InsLvl; + allMaps.push_back(loop2OutLvl); + + // Find loop with minimum (lowest) sparsity rank. + unsigned minLoop = it[0]; + unsigned minRank = getLoopSparsityRank(minLoop, allTensors, allMaps); + + for (auto candidateLoop : it) { + unsigned rank = getLoopSparsityRank(candidateLoop, allTensors, allMaps); + if (rank < minRank || (rank == minRank && candidateLoop < minLoop)) { + minLoop = candidateLoop; + minRank = rank; + } + } + src = minLoop; + break; + } } loopOrder.push_back(src); - it.pop_back(); + // Remove the selected loop from the worklist. + it.erase(std::find(it.begin(), it.end(), src)); // Update in-degree, and push 0-degree node into worklist. for (unsigned dst = 0; dst < numLoops; dst++) { if (itGraph[src][dst] && --inDegree[dst] == 0) { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h index 3636f3f..46378b9 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h @@ -197,7 +197,7 @@ public: // Sets the iterate to the specified position. void seek(ValueRange vals) { assert(vals.size() == cursorValsCnt); - std::copy(vals.begin(), vals.end(), cursorValsStorageRef.begin()); + llvm::copy(vals, cursorValsStorageRef.begin()); // Now that the iterator is re-positioned, the coordinate becomes invalid. crd = nullptr; } diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp index 4ec13e1..686f6ee 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp @@ -77,6 +77,9 @@ namespace { struct ReifyExpandShapeOp : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp, ExpandShapeOp> { + using Base = + ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp, + ExpandShapeOp>; LogicalResult reifyResultShapes(Operation *op, OpBuilder &b, ReifiedRankedShapedTypeDims &reifyResultShapes) const { diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 110bfdc..204e9bb 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -551,9 +551,7 @@ void CastOp::getCanonicalizationPatterns(RewritePatternSet &results, RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) { assert(!inputTypes.empty() && "cannot concatenate 0 tensors"); auto tensorTypes = - llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) { - return llvm::cast<RankedTensorType>(type); - })); + llvm::map_to_vector<4>(inputTypes, llvm::CastTo<RankedTensorType>); int64_t concatRank = tensorTypes[0].getRank(); // The concatenation dim must be in the range [0, rank). @@ -2293,9 +2291,9 @@ void ExtractSliceOp::getAsmResultNames( /// An extract_slice result type can be inferred, when it is not /// rank-reduced, from the source type and the static representation of /// offsets, sizes and strides. Special sentinels encode the dynamic case. -RankedTensorType ExtractSliceOp::inferResultType( - RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets, - ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) { +RankedTensorType +ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType, + ArrayRef<int64_t> staticSizes) { // An extract_slice op may specify only a leading subset of offset/sizes/ // strides in which case we complete with offset=0, sizes from memref type // and strides=1. @@ -2307,11 +2305,12 @@ RankedTensorType ExtractSliceOp::inferResultType( } // TODO: This uses neither offsets nor strides! -RankedTensorType ExtractSliceOp::inferResultType( - RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets, - ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) { +RankedTensorType +ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType, + ArrayRef<OpFoldResult> sizes) { SmallVector<int64_t> staticSizes; std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes); + assert(static_cast<int64_t>(staticSizes.size()) == sourceTensorType.getRank() && "unexpected staticSizes not equal to rank of source"); @@ -2329,11 +2328,10 @@ RankedTensorType ExtractSliceOp::inferResultType( /// To disambiguate, this function always drops the first 1 sizes occurrences. RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType( unsigned desiredResultRank, RankedTensorType sourceRankedTensorType, - ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, - ArrayRef<int64_t> strides) { + ArrayRef<int64_t> sizes) { // Type inferred in the absence of rank-reducing behavior. auto inferredType = llvm::cast<RankedTensorType>( - inferResultType(sourceRankedTensorType, offsets, sizes, strides)); + inferResultType(sourceRankedTensorType, sizes)); int rankDiff = inferredType.getRank() - desiredResultRank; if (rankDiff > 0) { auto shape = inferredType.getShape(); @@ -2352,16 +2350,12 @@ RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType( RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType( unsigned desiredResultRank, RankedTensorType sourceRankedTensorType, - ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, - ArrayRef<OpFoldResult> strides) { - SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; - SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; - dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + ArrayRef<OpFoldResult> sizes) { + SmallVector<int64_t> staticSizes; + SmallVector<Value> dynamicSizes; dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes); - dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); return ExtractSliceOp::inferCanonicalRankReducedResultType( - desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes, - staticStrides); + desiredResultRank, sourceRankedTensorType, staticSizes); } /// Build an ExtractSliceOp with mixed static and dynamic entries and custom @@ -2380,8 +2374,8 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result, auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType()); // Structuring implementation this way avoids duplication between builders. if (!resultType) { - resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType( - sourceRankedTensorType, staticOffsets, staticSizes, staticStrides)); + resultType = llvm::cast<RankedTensorType>( + ExtractSliceOp::inferResultType(sourceRankedTensorType, staticSizes)); } result.addAttributes(attrs); build(b, result, resultType, source, dynamicOffsets, dynamicSizes, @@ -2451,13 +2445,26 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, } } +/// Build an ExtractSliceOp with mixed static and dynamic sizes, inferred +/// result type, offsets set to 0 and strides set to 1. +void ExtractSliceOp::build(OpBuilder &b, OperationState &result, + RankedTensorType resultType, Value source, + ArrayRef<OpFoldResult> sizes, + ArrayRef<NamedAttribute> attrs) { + Attribute zeroIdxAttr = b.getIndexAttr(0); + Attribute oneIdxAttr = b.getIndexAttr(1); + SmallVector<OpFoldResult> readStrides(sizes.size(), oneIdxAttr); + SmallVector<OpFoldResult> readOffsets(sizes.size(), zeroIdxAttr); + build(b, result, resultType, source, readOffsets, sizes, readStrides, attrs); +} + /// Verifier for ExtractSliceOp. LogicalResult ExtractSliceOp::verify() { RankedTensorType sourceType = getSourceType(); // Verify result type against inferred type. - RankedTensorType expectedType = ExtractSliceOp::inferResultType( - sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides()); + RankedTensorType expectedType = + ExtractSliceOp::inferResultType(sourceType, getMixedSizes()); SliceVerificationResult result = isRankReducedType(expectedType, getType()); if (result != SliceVerificationResult::Success) return produceSliceErrorMsg(result, *this, expectedType); @@ -2697,8 +2704,7 @@ struct SliceReturnTypeCanonicalizer { ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) { return ExtractSliceOp::inferCanonicalRankReducedResultType( - op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes, - mixedStrides); + op.getType().getRank(), op.getSourceType(), mixedSizes); } }; @@ -2839,8 +2845,8 @@ static SliceVerificationResult verifyInsertSliceOp( ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) { // insert_slice is the inverse of extract_slice, use the same type // inference. - RankedTensorType expected = ExtractSliceOp::inferResultType( - dstType, staticOffsets, staticSizes, staticStrides); + RankedTensorType expected = + ExtractSliceOp::inferResultType(dstType, staticSizes); if (expectedType) *expectedType = expected; return isRankReducedType(expected, srcType); @@ -2968,7 +2974,7 @@ public: // Create the new op in canonical form. auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType( insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(), - mixedOffsets, mixedSizes, mixedStrides); + mixedSizes); Value toInsert = insertSliceOp.getSource(); if (sourceType != insertSliceOp.getSourceType()) { OpBuilder::InsertionGuard g(rewriter); @@ -3896,6 +3902,18 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result, build(b, result, source, dest, offsetValues, sizeValues, strideValues); } +// Build an InsertSliceOp with mixed static and dynamic sizes, offsets set +// to 0, strides set to 1 and inferred result type. +void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source, + Value dest, ArrayRef<OpFoldResult> sizes, + ArrayRef<NamedAttribute> attrs) { + Attribute zeroIdxAttr = b.getIndexAttr(0); + Attribute oneIdxAttr = b.getIndexAttr(1); + SmallVector<OpFoldResult> writeStrides(sizes.size(), oneIdxAttr); + SmallVector<OpFoldResult> writeOffsets(sizes.size(), zeroIdxAttr); + build(b, result, source, dest, writeOffsets, sizes, writeStrides, attrs); +} + LogicalResult ParallelInsertSliceOp::verify() { if (!isa<InParallelOpInterface>(getOperation()->getParentOp())) return this->emitError("expected InParallelOpInterface parent, got:") diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index c607ece..310e725 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -1132,35 +1132,22 @@ struct ConcatOpInterface // Extract the dimension for the concat op uint64_t concatDim = concatOp.getDim(); - bool dynamicConcatDim = false; SmallVector<OpFoldResult> offsets(tensorType.getRank(), rewriter.getIndexAttr(0)); SmallVector<OpFoldResult> strides(tensorType.getRank(), rewriter.getIndexAttr(1)); - SmallVector<OpFoldResult> sizes; - - for (const auto &[dimIdx, dimSize] : - llvm::enumerate(tensorType.getShape())) { - if (dimSize == ShapedType::kDynamic) { - auto dimOp = memref::DimOp::create(rewriter, loc, dstBuffer, dimIdx); - sizes.push_back(dimOp.getResult()); - if (dimIdx == concatDim) - dynamicConcatDim = true; - } else { - sizes.push_back(rewriter.getIndexAttr(dimSize)); - } - } - - int64_t concatDimOffset = 0; - std::optional<Value> dynamicOffset; - std::optional<Value> dynamicSize; - if (dynamicConcatDim) { - // One or more operands have dynamic size, so we must accumulate the - // offset with arith ops. - dynamicOffset = arith::ConstantIndexOp::create(rewriter, loc, 0); - } + SmallVector<OpFoldResult> sizes = + memref::getMixedSizes(rewriter, loc, dstBuffer); + + AffineExpr s0, s1; + bindSymbols(rewriter.getContext(), s0, s1); + auto sum = [&](OpFoldResult v1, OpFoldResult v2) { + return affine::makeComposedFoldedAffineApply(rewriter, loc, s0 + s1, + {v1, v2}); + }; + OpFoldResult concatDimOffset = rewriter.getIndexAttr(0); for (auto operand : concatOp.getInputs()) { // Get the buffer for the operand. FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options, state); @@ -1171,18 +1158,10 @@ struct ConcatOpInterface // so the offset on that axis must accumulate through the loop, and the // size must change to the size of the current operand. auto operandTensorType = cast<RankedTensorType>(operand.getType()); - int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim); - - if (dynamicConcatDim) { - offsets[concatDim] = dynamicOffset.value(); - dynamicSize = - memref::DimOp::create(rewriter, loc, *srcBuffer, concatDim) - .getResult(); - sizes[concatDim] = dynamicSize.value(); - } else { - sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize); - offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset); - } + offsets[concatDim] = concatDimOffset; + OpFoldResult concatDimSize = + memref::getMixedSize(rewriter, loc, *srcBuffer, concatDim); + sizes[concatDim] = concatDimSize; // Create a subview of the destination buffer. auto dstMemrefType = cast<MemRefType>(memrefType); @@ -1197,12 +1176,7 @@ struct ConcatOpInterface if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview))) return failure(); - if (dynamicConcatDim) { - dynamicOffset = arith::AddIOp::create( - rewriter, loc, dynamicOffset.value(), dynamicSize.value()); - } else { - concatDimOffset += operandConcatDimSize; - } + concatDimOffset = sum(concatDimOffset, concatDimSize); } replaceOpWithBufferizedValues(rewriter, op, dstBuffer); diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp index 7ec61c7..a53af98 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp @@ -37,8 +37,7 @@ struct FoldExpandOfRankReducingExtract // supported. Moreover, only simple cases where the resulting ExtractSliceOp // has no rank-reduction anymore are supported at the moment. RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType( - srcType, extractSliceOp.getStaticOffsets(), - extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides()); + srcType, extractSliceOp.getStaticSizes()); if (nonReducingExtractType != resultType) return failure(); @@ -533,8 +532,8 @@ LogicalResult mlir::tensor::getCollapsedExtractSliceInfo( getMixedSizes(b, loc, sliceOp.getSource()); // Helper variables and function for accumulating the size values. - AffineExpr d0, d1, d2; - bindDims(b.getContext(), d0, d1, d2); + AffineExpr d0, d1; + bindDims(b.getContext(), d0, d1); // Multiply two integers. auto mul = [&](OpFoldResult v1, OpFoldResult v2) { auto mulMap = AffineMap::get(2, 0, {d0 * d1}); diff --git a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp index 753cb95..d35f458 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp @@ -155,13 +155,15 @@ struct ExtractSliceOpInterface RankedTensorType sourceType = extractSliceOp.getSource().getType(); // For each dimension, assert that: - // 0 <= offset < dim_size - // 0 <= offset + (size - 1) * stride < dim_size + // For empty slices (size == 0) : 0 <= offset <= dim_size + // For non-empty slices (size > 0): 0 <= offset < dim_size + // 0 <= offset + (size - 1) * stride < + // dim_size Value zero = arith::ConstantIndexOp::create(builder, loc, 0); Value one = arith::ConstantIndexOp::create(builder, loc, 1); for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) { - // Reset insertion point to before the operation for each dimension + builder.setInsertionPoint(extractSliceOp); Value offset = getValueOrCreateConstantIndexOp( @@ -170,46 +172,63 @@ struct ExtractSliceOpInterface builder, loc, extractSliceOp.getMixedSizes()[i]); Value stride = getValueOrCreateConstantIndexOp( builder, loc, extractSliceOp.getMixedStrides()[i]); - - // Verify that offset is in-bounds. Value dimSize = builder.createOrFold<tensor::DimOp>( loc, extractSliceOp.getSource(), i); - Value offsetInBounds = - generateInBoundsCheck(builder, loc, offset, zero, dimSize); - cf::AssertOp::create(builder, loc, offsetInBounds, + + // Verify that offset is in-bounds (conditional on slice size). + Value sizeIsZero = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::eq, size, zero); + auto offsetCheckIf = scf::IfOp::create( + builder, loc, sizeIsZero, + [&](OpBuilder &b, Location loc) { + // For empty slices, offset can be at the boundary: 0 <= offset <= + // dimSize. + Value offsetGEZero = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::sge, offset, zero); + Value offsetLEDimSize = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::sle, offset, dimSize); + Value emptyOffsetValid = + arith::AndIOp::create(b, loc, offsetGEZero, offsetLEDimSize); + scf::YieldOp::create(b, loc, emptyOffsetValid); + }, + [&](OpBuilder &b, Location loc) { + // For non-empty slices, offset must be a valid index: 0 <= offset < + // dimSize. + Value offsetInBounds = + generateInBoundsCheck(b, loc, offset, zero, dimSize); + scf::YieldOp::create(b, loc, offsetInBounds); + }); + + Value offsetCondition = offsetCheckIf.getResult(0); + cf::AssertOp::create(builder, loc, offsetCondition, generateErrorMessage(op, "offset " + std::to_string(i) + " is out-of-bounds")); - // Only verify if size > 0 + // Verify that the slice endpoint is in-bounds (only for non-empty + // slices). Value sizeIsNonZero = arith::CmpIOp::create( builder, loc, arith::CmpIPredicate::sgt, size, zero); + auto ifOp = scf::IfOp::create( + builder, loc, sizeIsNonZero, + [&](OpBuilder &b, Location loc) { + // Verify that slice does not run out-of-bounds. + Value sizeMinusOne = arith::SubIOp::create(b, loc, size, one); + Value sizeMinusOneTimesStride = + arith::MulIOp::create(b, loc, sizeMinusOne, stride); + Value lastPos = + arith::AddIOp::create(b, loc, offset, sizeMinusOneTimesStride); + Value lastPosInBounds = + generateInBoundsCheck(b, loc, lastPos, zero, dimSize); + scf::YieldOp::create(b, loc, lastPosInBounds); + }, + [&](OpBuilder &b, Location loc) { + Value trueVal = + arith::ConstantOp::create(b, loc, b.getBoolAttr(true)); + scf::YieldOp::create(b, loc, trueVal); + }); - auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(), - sizeIsNonZero, /*withElseRegion=*/true); - - // Populate the "then" region (for size > 0). - builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - - // Verify that slice does not run out-of-bounds. - Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one); - Value sizeMinusOneTimesStride = - arith::MulIOp::create(builder, loc, sizeMinusOne, stride); - Value lastPos = - arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride); - Value lastPosInBounds = - generateInBoundsCheck(builder, loc, lastPos, zero, dimSize); - scf::YieldOp::create(builder, loc, lastPosInBounds); - - // Populate the "else" region (for size == 0). - builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); - Value trueVal = - arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true)); - scf::YieldOp::create(builder, loc, trueVal); - - builder.setInsertionPointAfter(ifOp); Value finalCondition = ifOp.getResult(0); - cf::AssertOp::create( builder, loc, finalCondition, generateErrorMessage( diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 293c6af..c420a4c 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" +#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" @@ -539,7 +540,7 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> { auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType(); if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) { - inputEType = quantType.getStorageType(); + inputEType = getStorageElementTypeFromQuantized(quantType); } Attribute newMinValAttr, newMaxValAttr; @@ -1485,7 +1486,24 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) { return {}; } +static bool +mayRequireBroadcast(ValueTypeRange<mlir::OperandRange> operandTypes) { + const auto isDynamic = [](Type ty) { + const auto shapedTy = llvm::dyn_cast<ShapedType>(ty); + return !shapedTy || !shapedTy.hasStaticShape(); + }; + + return llvm::any_of(operandTypes, isDynamic) || + failed(verifyCompatibleShapes(operandTypes)); +} + OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) { + // Select allows operand shapes to be broadcast to the output shape. For + // now, don't support folding when we cannot prove no broadcasting is + // involved. + if (mayRequireBroadcast(getOperandTypes())) + return {}; + if (getOnTrue() == getOnFalse()) return getOnTrue(); diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 65e0a59..1c175f9ab 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -563,7 +563,7 @@ static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) { static Type getStorageElementTypeOrSelf(Type type) { auto srcType = getElementTypeOrSelf(type); if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType)) - srcType = quantType.getStorageType(); + srcType = getStorageElementTypeFromQuantized(quantType); return srcType; } @@ -631,16 +631,16 @@ static LogicalResult verifyConvOp(T op) { bool resultIsFloat = llvm::isa<FloatType>(resultEType); if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType)) - inputEType = quantType.getStorageType(); + inputEType = getStorageElementTypeFromQuantized(quantType); if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType)) - weightEType = quantType.getStorageType(); + weightEType = getStorageElementTypeFromQuantized(quantType); if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType)) - biasEType = quantType.getStorageType(); + biasEType = getStorageElementTypeFromQuantized(quantType); if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType)) - resultEType = quantType.getStorageType(); + resultEType = getStorageElementTypeFromQuantized(quantType); if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) { // for now, only enforce bias element type == result element type for @@ -709,7 +709,7 @@ LogicalResult tosa::ConstOp::verify() { if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>( outputType.getElementType())) { - if (result.getStorageType() == attrType.getElementType()) + if (getStorageElementTypeFromQuantized(result) == attrType.getElementType()) return success(); } @@ -727,7 +727,7 @@ static LogicalResult verifyConvOpModes(T op) { llvm::cast<ShapedType>(op.getInput().getType()).getElementType(); if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType)) - inputEType = quantType.getStorageType(); + inputEType = getStorageElementTypeFromQuantized(quantType); auto accType = op.getAccType(); if (inputEType.isInteger(8) && !accType.isInteger(32)) @@ -752,7 +752,7 @@ static LogicalResult verifyConvOpModes(T op) { llvm::cast<ShapedType>(op.getResult().getType()).getElementType(); if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType)) - resultEType = quantType.getStorageType(); + resultEType = getStorageElementTypeFromQuantized(quantType); return success(); } @@ -1179,13 +1179,13 @@ LogicalResult tosa::ClampOp::verify() { llvm::cast<ShapedType>(getInput().getType()).getElementType(); if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) { - inputETy = quantType.getStorageType(); + inputETy = getStorageElementTypeFromQuantized(quantType); } mlir::Type outputETy = llvm::cast<ShapedType>(getOutput().getType()).getElementType(); if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) { - outputETy = quantType.getStorageType(); + outputETy = getStorageElementTypeFromQuantized(quantType); } if (inputETy != outputETy) return emitOpError("input/output element types are incompatible."); diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt index 41b338d..091b481 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRTosaTransforms TosaAttachTarget.cpp + TosaArithConstantToConst.cpp TosaConvertIntegerTypeToSignless.cpp TosaDecomposeTransposeConv.cpp TosaDecomposeDepthwise.cpp @@ -12,6 +13,7 @@ add_mlir_dialect_library(MLIRTosaTransforms TosaTypeConverters.cpp TosaProfileCompliance.cpp TosaValidation.cpp + TosaNarrowI64ToI32.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms @@ -21,7 +23,9 @@ add_mlir_dialect_library(MLIRTosaTransforms LINK_LIBS PUBLIC MLIRFuncDialect + MLIRFuncTransformOps MLIRPass MLIRTosaDialect MLIRTransformUtils + MLIRFuncTransforms ) diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp new file mode 100644 index 0000000..73e1e2b --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp @@ -0,0 +1,111 @@ +//===- TosaArithConstantToConst.cpp ---------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass that converts tensor-valued arith.constant ops +// into tosa.const so that TOSA pipelines operate on a uniform constant form. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace tosa { +#define GEN_PASS_DEF_TOSAARITHCONSTANTTOTOSACONSTPASS +#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" +} // namespace tosa +} // namespace mlir + +using namespace mlir; +using namespace mlir::tosa; + +namespace { + +// NOTE: TOSA pipelines already lower their constants through shared Arith +// folding passes, so tensor literals often come back as `arith.constant` even +// after the IR is otherwise TOSA-only. Keep this normalization with the rest of +// the TOSA transforms so any client can re-establish a canonical `tosa.const` +// representation without needing a full Arith->TOSA conversion library. + +/// Returns true when `elementType` is natively representable by tosa.const. +static bool isSupportedElementType(Type elementType) { + if (isa<FloatType>(elementType)) + return true; + + if (auto intType = dyn_cast<IntegerType>(elementType)) + return intType.isSignless() || intType.isUnsigned(); + + if (isa<quant::QuantizedType>(elementType)) + return true; + + if (isa<tosa::mxint8Type>(elementType)) + return true; + + return false; +} + +class ArithConstantToTosaConst : public OpRewritePattern<arith::ConstantOp> { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::ConstantOp constOp, + PatternRewriter &rewriter) const override { + // TOSA constant verification requires a ranked, statically shaped tensor. + auto resultType = dyn_cast<RankedTensorType>(constOp.getResult().getType()); + if (!resultType || !resultType.hasStaticShape()) + return failure(); + + if (!isSupportedElementType(resultType.getElementType())) + return failure(); + + Attribute attr = constOp.getValueAttr(); + auto elementsAttr = dyn_cast<ElementsAttr>(attr); + if (!elementsAttr) + return failure(); + + auto attrType = dyn_cast<RankedTensorType>(elementsAttr.getType()); + if (!attrType || !attrType.hasStaticShape()) + return failure(); + if (attrType != resultType) + return failure(); + + auto newConst = tosa::ConstOp::create(rewriter, constOp.getLoc(), + resultType, elementsAttr); + rewriter.replaceOp(constOp, newConst.getResult()); + return success(); + } +}; + +struct TosaArithConstantToTosaConstPass + : public tosa::impl::TosaArithConstantToTosaConstPassBase< + TosaArithConstantToTosaConstPass> { + using Base::Base; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<arith::ArithDialect, tosa::TosaDialect>(); + } + + void runOnOperation() override { + auto *ctx = &getContext(); + RewritePatternSet patterns(ctx); + patterns.add<ArithConstantToTosaConst>(ctx); + + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp index 0bec0da..022476a2 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp @@ -33,8 +33,13 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> { ShapedType weightType = cast<ShapedType>(weight.getType()); ShapedType resultType = cast<ShapedType>(op.getOutput().getType()); - if (!(inputType.hasStaticShape() && weightType.hasStaticShape() && - resultType.hasStaticShape())) { + // Any dimensions other than batchSize cannot be dynamic for input/output + for (unsigned int i = 1; i < 4; ++i) { + if (inputType.isDynamicDim(i) || resultType.isDynamicDim(i)) + return failure(); + } + + if (!weightType.hasStaticShape()) { return failure(); } diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp index dc5c51b..8b23fd1 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp @@ -49,8 +49,13 @@ public: if (llvm::any_of(stride, [](int64_t v) { return v != 1; })) return failure(); - if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() || - !biasTy.hasStaticShape() || !resultTy.hasStaticShape()) + // Any dimensions other than batchSize cannot be dynamic for input/output + for (unsigned int i = 1; i < 4; ++i) { + if (inputTy.isDynamicDim(i) || resultTy.isDynamicDim(i)) + return failure(); + } + + if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape()) return failure(); int64_t kernelHeight = weightTy.getDimSize(1); @@ -113,8 +118,13 @@ public: if (llvm::all_of(stride, [](int64_t v) { return v == 1; })) return rewriter.notifyMatchFailure(op, "non-one stride found."); - if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() || - !biasTy.hasStaticShape() || !resultTy.hasStaticShape()) + // Any dimensions other than batchSize cannot be dynamic for input/output + for (unsigned int i = 1; i < 4; ++i) { + if (inputTy.isDynamicDim(i) || resultTy.isDynamicDim(i)) + return failure(); + } + + if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape()) return failure(); int64_t batch = inputTy.getDimSize(0); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp new file mode 100644 index 0000000..ddaf7d8a --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp @@ -0,0 +1,310 @@ +//===- TosaNarrowI64ToI32.cpp ---------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass narrows TOSA operations with 64-bit integer tensor types to +// 32-bit integer tensor types. This can be useful for backends that do not +// support the EXT-INT64 extension of TOSA. The pass has two options: +// +// - aggressive-rewrite - If enabled, all TOSA operations are rewritten, +// regardless or whether the narrowing is safe. This option may lead to +// data loss if not used carefully. +// - convert-function-boundaries - If enabled, the pass will convert function +// I/O types as well. Otherwise casts will be inserted at the I/O +// boundaries. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/Transforms/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace tosa { +#define GEN_PASS_DEF_TOSANARROWI64TOI32PASS +#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" +} // namespace tosa +} // namespace mlir + +using namespace mlir; +using namespace mlir::tosa; + +namespace { + +LogicalResult convertGenericOp(Operation *op, ValueRange operands, + ConversionPatternRewriter &rewriter, + const TypeConverter *typeConverter) { + // Convert types of results + SmallVector<Type, 4> newResults; + if (failed(typeConverter->convertTypes(op->getResultTypes(), newResults))) + return failure(); + + // Create a new operation state + OperationState state(op->getLoc(), op->getName().getStringRef(), operands, + newResults, {}, op->getSuccessors()); + + for (const NamedAttribute &namedAttribute : op->getAttrs()) { + const Attribute attribute = namedAttribute.getValue(); + + // Convert integer attribute type + if (const auto intAttr = dyn_cast<IntegerAttr>(attribute)) { + const std::optional<Attribute> convertedAttribute = + typeConverter->convertTypeAttribute(intAttr.getType(), attribute); + state.addAttribute(namedAttribute.getName(), convertedAttribute.value()); + continue; + } + + if (const auto typeAttr = dyn_cast<TypeAttr>(attribute)) { + Type type = typeAttr.getValue(); + const std::optional<Attribute> convertedAttribute = + typeConverter->convertTypeAttribute(type, attribute); + if (!convertedAttribute) + return rewriter.notifyMatchFailure(op, + "Failed to convert type attribute."); + state.addAttribute(namedAttribute.getName(), convertedAttribute.value()); + continue; + } + + if (const auto denseElementsAttr = dyn_cast<DenseElementsAttr>(attribute)) { + const Type type = denseElementsAttr.getType(); + const std::optional<Attribute> convertedAttribute = + typeConverter->convertTypeAttribute(type, denseElementsAttr); + if (!convertedAttribute) + return rewriter.notifyMatchFailure( + op, "Failed to convert dense elements attribute."); + state.addAttribute(namedAttribute.getName(), convertedAttribute.value()); + continue; + } + + state.addAttribute(namedAttribute.getName(), attribute); + } + + for (Region ®ion : op->getRegions()) { + Region *newRegion = state.addRegion(); + rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin()); + if (failed(rewriter.convertRegionTypes(newRegion, *typeConverter))) + return failure(); + } + + Operation *newOp = rewriter.create(state); + rewriter.replaceOp(op, newOp->getResults()); + return success(); +} + +// =========================== +// Aggressive rewrite patterns +// =========================== + +class ConvertGenericOp : public ConversionPattern { +public: + ConvertGenericOp(TypeConverter &typeConverter, MLIRContext *context) + : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const final { + if (!isa<tosa::TosaOp>(op)) + return rewriter.notifyMatchFailure( + op, + "Support for operations other than TOSA has not been implemented."); + + return convertGenericOp(op, operands, rewriter, typeConverter); + } +}; + +// =============================== +// Bounds checked rewrite patterns +// =============================== + +class ConvertArgMaxOpWithBoundsChecking + : public OpConversionPattern<tosa::ArgMaxOp> { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(tosa::ArgMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + // Output type can be narrowed based on the size of the axis dimension + const int32_t axis = op.getAxis(); + const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType()); + if (!inputType || !inputType.isStaticDim(axis)) + return rewriter.notifyMatchFailure( + op, "Requires a static axis dimension for bounds checking."); + const int64_t axisDim = inputType.getDimSize(axis); + if (axisDim >= std::numeric_limits<int32_t>::max()) + return rewriter.notifyMatchFailure( + op, "Axis dimension is too large to narrow safely."); + + const Type resultType = op.getOutput().getType(); + const Type newResultType = typeConverter->convertType(resultType); + rewriter.replaceOpWithNewOp<tosa::ArgMaxOp>(op, newResultType, + adaptor.getInput(), axis); + return success(); + } +}; + +class ConvertCastOpWithBoundsChecking + : public OpConversionPattern<tosa::CastOp> { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(tosa::CastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType()); + const auto resultType = dyn_cast<ShapedType>(op.getResult().getType()); + if (!inputType || !resultType) + return failure(); + + const auto elementInputIntType = + dyn_cast<IntegerType>(inputType.getElementType()); + const auto elementResultIntType = + dyn_cast<IntegerType>(resultType.getElementType()); + if (elementInputIntType && elementResultIntType && + elementInputIntType.getWidth() > elementResultIntType.getWidth()) + return rewriter.notifyMatchFailure( + op, "Narrowing cast may lead to data loss."); + + rewriter.replaceOpWithNewOp<tosa::CastOp>( + op, typeConverter->convertType(resultType), adaptor.getInput()); + return success(); + } +}; + +template <typename OpTy> +class ConvertTypedOp : public OpConversionPattern<OpTy> { + using OpConversionPattern<OpTy>::OpConversionPattern; + + LogicalResult + matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + return convertGenericOp(op, adaptor.getOperands(), rewriter, + this->getTypeConverter()); + } +}; + +struct TosaNarrowI64ToI32 + : public tosa::impl::TosaNarrowI64ToI32PassBase<TosaNarrowI64ToI32> { +public: + explicit TosaNarrowI64ToI32() = default; + explicit TosaNarrowI64ToI32(const TosaNarrowI64ToI32PassOptions &options) + : TosaNarrowI64ToI32() { + this->aggressiveRewrite = options.aggressiveRewrite; + this->convertFunctionBoundaries = options.convertFunctionBoundaries; + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) -> Type { return type; }); + typeConverter.addConversion([](IntegerType type) -> Type { + if (!type.isInteger(64)) + return type; + return IntegerType::get(type.getContext(), 32); + }); + typeConverter.addConversion( + [&typeConverter](RankedTensorType type) -> Type { + const Type elementType = type.getElementType(); + if (!elementType.isInteger(64)) + return type; + return RankedTensorType::get(type.getShape(), + typeConverter.convertType(elementType)); + }); + + const auto materializeCast = [](OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) -> Value { + if (inputs.size() != 1) + return Value(); + return tosa::CastOp::create(builder, loc, resultType, inputs.front()); + }; + typeConverter.addSourceMaterialization(materializeCast); + typeConverter.addTargetMaterialization(materializeCast); + + typeConverter.addTypeAttributeConversion( + [](IntegerType type, IntegerAttr attribute) -> Attribute { + const APInt value = attribute.getValue().truncSSat(32); + return IntegerAttr::get(IntegerType::get(type.getContext(), 32), + value); + }); + typeConverter.addTypeAttributeConversion( + [&typeConverter](ShapedType type, + DenseIntElementsAttr attr) -> Attribute { + const ShapedType newType = + cast<ShapedType>(typeConverter.convertType(type)); + const auto oldElementType = cast<IntegerType>(type.getElementType()); + const auto newElementType = + cast<IntegerType>(newType.getElementType()); + if (oldElementType.getWidth() == newElementType.getWidth()) + return attr; + + DenseElementsAttr mapped = + attr.mapValues(newElementType, [&](const APInt &v) { + return v.truncSSat(newElementType.getWidth()); + }); + return mapped; + }); + + ConversionTarget target(*context); + target.addDynamicallyLegalDialect<tosa::TosaDialect>( + [&typeConverter](Operation *op) { + return typeConverter.isLegal(op->getResultTypes()) && + typeConverter.isLegal(op->getOperandTypes()); + }); + if (convertFunctionBoundaries) { + target.addDynamicallyLegalOp<func::FuncOp>( + [&typeConverter](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()) && + typeConverter.isLegal(&op.getBody()); + }); + target.addDynamicallyLegalOp<func::ReturnOp>([](func::ReturnOp op) { + const FunctionType funcType = + op->getParentOfType<func::FuncOp>().getFunctionType(); + return llvm::equal(op.getOperandTypes(), funcType.getResults()); + }); + } else { + target.addDynamicallyLegalOp<func::FuncOp>( + [](func::FuncOp op) { return true; }); + target.addDynamicallyLegalOp<func::ReturnOp>( + [](func::ReturnOp op) { return true; }); + } + + RewritePatternSet patterns(context); + if (convertFunctionBoundaries) { + populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>( + patterns, typeConverter); + populateReturnOpTypeConversionPattern(patterns, typeConverter); + } + if (aggressiveRewrite) { + patterns.add<ConvertGenericOp>(typeConverter, context); + } else { + // Tensor + patterns.add<ConvertArgMaxOpWithBoundsChecking>(typeConverter, context); + // Data layout + patterns.add<ConvertTypedOp<tosa::ConcatOp>>(typeConverter, context); + patterns.add<ConvertTypedOp<tosa::PadOp>>(typeConverter, context); + patterns.add<ConvertTypedOp<tosa::ReshapeOp>>(typeConverter, context); + patterns.add<ConvertTypedOp<tosa::ReverseOp>>(typeConverter, context); + patterns.add<ConvertTypedOp<tosa::SliceOp>>(typeConverter, context); + patterns.add<ConvertTypedOp<tosa::TileOp>>(typeConverter, context); + patterns.add<ConvertTypedOp<tosa::TransposeOp>>(typeConverter, context); + patterns.add<ConvertTypedOp<tosa::IdentityOp>>(typeConverter, context); + // Type conversion + patterns.add<ConvertCastOpWithBoundsChecking>(typeConverter, context); + // Controlflow + patterns.add<ConvertTypedOp<tosa::IfOp>>(typeConverter, context); + patterns.add<ConvertTypedOp<tosa::WhileOp>>(typeConverter, context); + } + + if (failed( + applyFullConversion(getOperation(), target, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp index ac5d620..36e8940 100644 --- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp @@ -70,6 +70,8 @@ namespace { // If lower=[a], higher=[a, a], [a] reshaped into [1, a]. // If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a]. // If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1]. +// If lower=[c], higher=[?, ?, c], [c] reshaped into [1, 1, c]. +// If lower=[?], higher=[?, ?, ?], [?] reshaped into [1, 1, ?]. LogicalResult computeReshapeOutput(ArrayRef<int64_t> higherRankShape, ArrayRef<int64_t> lowerRankShape, @@ -87,7 +89,12 @@ computeReshapeOutput(ArrayRef<int64_t> higherRankShape, higherRankDim = higherRankShape[i + rankDiff]; lowerRankDim = lowerRankShape[i]; - if (lowerRankDim != 1 && higherRankDim != 1 && + auto isStaticDimAndNotEqualToOne = [](int64_t dim) { + return dim != 1 && dim != ShapedType::kDynamic; + }; + + if (isStaticDimAndNotEqualToOne(lowerRankDim) && + isStaticDimAndNotEqualToOne(higherRankDim) && lowerRankDim != higherRankDim) return failure(); @@ -216,22 +223,23 @@ mlir::tosa::convertFromIntAttr(const DenseElementsAttr &attr, const int rank) { bool mlir::tosa::hasUniqueConstantScatterIndices( ShapedType indicesType, DenseIntElementsAttr indicesAttr) { - llvm::ArrayRef<int64_t> const indicesShape = indicesType.getShape(); + const llvm::ArrayRef<int64_t> indicesShape = indicesType.getShape(); const unsigned int indicesRank = indicesShape.size(); const unsigned int lastDimSize = indicesShape[indicesRank - 1]; // check each batch of indices from the flat indicesAttr values // for duplicates - auto const indicesValues = indicesAttr.getValues<int32_t>(); + auto const indicesValues = indicesAttr.getValues<APInt>(); assert( (indicesValues.size() % lastDimSize == 0) && "Constant indices data length should be a multiple of indicesShape[-1]"); - std::vector<uint64_t> indices(lastDimSize); + std::vector<APInt> indices(lastDimSize); for (auto beg = indicesValues.begin(); beg < indicesValues.end(); beg += lastDimSize) { std::copy(beg, beg + lastDimSize, indices.begin()); - std::sort(indices.begin(), indices.end()); + std::sort(indices.begin(), indices.end(), + [](const APInt &a, const APInt &b) { return a.slt(b); }); if (std::adjacent_find(indices.begin(), indices.end()) != indices.end()) { // found duplicate values in indices in batch return false; diff --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp index 02c86a0..c55b13d 100644 --- a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp @@ -395,3 +395,16 @@ mlir::tosa::buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype, maxAttr, quantBits, filterQuantDim, isSigned, narrowRange)); } + +Type mlir::tosa::getStorageElementTypeFromQuantized( + quant::QuantizedType quantType) { + auto quantEty = quantType.getStorageType(); + // StorageType doesn't capture the sign information + // Explicitly create unsigned type if needed + if (!quantType.isSigned()) { + quantEty = IntegerType::get(quantEty.getContext(), + quantEty.getIntOrFloatBitWidth(), + IntegerType::Unsigned); + } + return quantEty; +} diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 062606e..86233b0 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -2062,6 +2062,10 @@ transform::IncludeOp::apply(transform::TransformRewriter &rewriter, DiagnosedSilenceableFailure result = applySequenceBlock( callee.getBody().front(), getFailurePropagationMode(), state, results); + + if (!result.succeeded()) + return result; + mappings.clear(); detail::prepareValueMappings( mappings, callee.getBody().front().getTerminator()->getOperands(), state); diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp index 8859541..24b0487 100644 --- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp @@ -1495,8 +1495,7 @@ transform::detail::checkApplyToOne(Operation *transformOp, template <typename T> static SmallVector<T> castVector(ArrayRef<transform::MappedValue> range) { - return llvm::to_vector(llvm::map_range( - range, [](transform::MappedValue value) { return cast<T>(value); })); + return llvm::map_to_vector(range, llvm::CastTo<T>); } void transform::detail::setApplyToOneResults( diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp index f727118..2bd6205 100644 --- a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp +++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp @@ -156,7 +156,7 @@ DiagnosedSilenceableFailure transform::tune::AlternativesOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { - std::optional<size_t> selectedRegionIdx; + std::optional<int64_t> selectedRegionIdx; if (auto selectedRegionAttr = getSelectedRegionAttr()) selectedRegionIdx = selectedRegionAttr->getSExtValue(); @@ -232,7 +232,7 @@ LogicalResult transform::tune::AlternativesOp::verify() { } if (auto selectedRegionAttr = getSelectedRegionAttr()) { - size_t regionIdx = selectedRegionAttr->getSExtValue(); + int64_t regionIdx = selectedRegionAttr->getSExtValue(); if (regionIdx < 0 || regionIdx >= getNumRegions()) return emitOpError() << "'selected_region' attribute specifies region at index " diff --git a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp index a26edac..2986f4c 100644 --- a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp +++ b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp @@ -106,14 +106,12 @@ ScalableValueBoundsConstraintSet::computeScalableBound( AffineMap bound = [&] { if (boundType == BoundType::EQ && !invalidBound(lowerBound) && - lowerBound[0] == upperBound[0]) { + lowerBound[0] == upperBound[0]) return lowerBound[0]; - } - if (boundType == BoundType::LB && !invalidBound(lowerBound)) { + if (boundType == BoundType::LB && !invalidBound(lowerBound)) return lowerBound[0]; - } else if (boundType == BoundType::UB && !invalidBound(upperBound)) { + if (boundType == BoundType::UB && !invalidBound(upperBound)) return upperBound[0]; - } return AffineMap{}; }(); diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index daef0ba..2789f63 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -6066,19 +6066,21 @@ LogicalResult ScatterOp::verify() { VectorType indVType = getIndexVectorType(); VectorType maskVType = getMaskVectorType(); VectorType valueVType = getVectorType(); - MemRefType memType = getMemRefType(); + ShapedType baseType = getBaseType(); - if (valueVType.getElementType() != memType.getElementType()) + if (!llvm::isa<MemRefType, RankedTensorType>(baseType)) + return emitOpError("requires base to be a memref or ranked tensor type"); + + if (valueVType.getElementType() != baseType.getElementType()) return emitOpError("base and valueToStore element type should match"); - if (llvm::size(getOffsets()) != memType.getRank()) - return emitOpError("requires ") << memType.getRank() << " indices"; + if (llvm::size(getOffsets()) != baseType.getRank()) + return emitOpError("requires ") << baseType.getRank() << " indices"; if (valueVType.getShape() != indVType.getShape()) return emitOpError("expected valueToStore dim to match indices dim"); if (valueVType.getShape() != maskVType.getShape()) return emitOpError("expected valueToStore dim to match mask dim"); return success(); } - namespace { class ScatterFolder final : public OpRewritePattern<ScatterOp> { public: @@ -6241,6 +6243,10 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, setResultRanges(getResult(), argRanges.front()); } +std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() { + return llvm::to_vector<4>(getResultVectorType().getShape()); +} + LogicalResult ShapeCastOp::verify() { VectorType sourceType = getSourceVectorType(); diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp index 546099c..352f477 100644 --- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" using namespace mlir; using namespace mlir::bufferization; @@ -126,6 +127,54 @@ struct TransferWriteOpInterface } }; +/// Bufferization of vector.scatter. Replaced with a new vector.scatter that +/// operates on a memref. +struct ScatterOpInterface + : public BufferizableOpInterface::ExternalModel<ScatterOpInterface, + vector::ScatterOp> { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + assert(isa<RankedTensorType>(opOperand.get().getType()) && + "only tensor types expected"); + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + assert(isa<RankedTensorType>(opOperand.get().getType()) && + "only tensor types expected"); + return true; + } + + AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + assert(isa<RankedTensorType>(opOperand.get().getType()) && + "only tensor types expected"); + auto scatterOp = cast<vector::ScatterOp>(op); + if (&opOperand != &scatterOp.getBaseMutable()) + return {}; + return {{scatterOp.getResult(), BufferRelation::Equivalent}}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const BufferizationOptions &options, + BufferizationState &state) const { + auto scatterOp = cast<vector::ScatterOp>(op); + assert(isa<TensorType>(scatterOp.getBaseType()) && + "only tensor types expected"); + FailureOr<Value> buffer = + getBuffer(rewriter, scatterOp.getBase(), options, state); + if (failed(buffer)) + return failure(); + vector::ScatterOp::create(rewriter, scatterOp.getLoc(), + /*resultType=*/nullptr, *buffer, + scatterOp.getOffsets(), scatterOp.getIndices(), + scatterOp.getMask(), scatterOp.getValueToStore()); + replaceOpWithBufferizedValues(rewriter, op, *buffer); + return success(); + } +}; + /// Bufferization of vector.gather. Replaced with a new vector.gather that /// operates on a memref. struct GatherOpInterface @@ -335,5 +384,6 @@ void mlir::vector::registerBufferizableOpInterfaceExternalModels( GatherOp::attachInterface<GatherOpInterface>(*ctx); MaskOp::attachInterface<MaskOpInterface>(*ctx); YieldOp::attachInterface<YieldOpInterface>(*ctx); + ScatterOp::attachInterface<ScatterOpInterface>(*ctx); }); } diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp index 258f2cb..1af5523 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp @@ -111,7 +111,7 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> { if (!isValidKind(isInt, scanOp.getKind())) return failure(); - VectorType resType = VectorType::get(destShape, elType); + VectorType resType = destType; Value result = arith::ConstantOp::create(rewriter, loc, resType, rewriter.getZeroAttr(resType)); int64_t reductionDim = scanOp.getReductionDim(); @@ -121,8 +121,18 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> { int64_t initialValueRank = initialValueType.getRank(); SmallVector<int64_t> reductionShape(destShape); + SmallVector<bool> reductionScalableDims(destType.getScalableDims()); + + if (reductionScalableDims[reductionDim]) + return rewriter.notifyMatchFailure( + scanOp, "Trying to reduce scalable dimension - not yet supported!"); + + // The reduction dimension, after reducing, becomes 1. It's a fixed-width + // dimension - no need to touch the scalability flag. reductionShape[reductionDim] = 1; - VectorType reductionType = VectorType::get(reductionShape, elType); + VectorType reductionType = + VectorType::get(reductionShape, elType, reductionScalableDims); + SmallVector<int64_t> offsets(destRank, 0); SmallVector<int64_t> strides(destRank, 1); SmallVector<int64_t> sizes(destShape); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 726da1e..ad16b80 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -453,6 +453,8 @@ struct ReorderCastOpsOnBroadcast PatternRewriter &rewriter) const override { if (op->getNumOperands() != 1) return failure(); + if (!isa<VectorType>(op->getResult(0).getType())) + return failure(); auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>(); if (!bcastOp) return failure(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index fbae098..462bd8c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -1003,6 +1003,286 @@ private: vector::UnrollVectorOptions options; }; +/// This pattern unrolls `vector.create_mask` operations into smaller mask +/// operations based on the target unroll shape. Each unrolled slice computes +/// its local mask size in each dimension (d) as: +/// min(max(originalMaskSize[d] - offset[d], 0), unrolledDimSize[d]). +/// Example: +/// Given a create_mask operation: +/// %0 = vector.create_mask %c6, %c10 : vector<8x16xi1> // mask first 6x10 +/// elements +/// +/// and a target unroll shape of <4x8>, the pattern produces: +/// +/// %false = arith.constant dense<false> : vector<8x16xi1> +/// +/// Slice [0,0]: +/// mask size = min(max(6-0, 0), 4) x min(max(10-0, 0), 8) = 4x8 +/// %mask00 = vector.create_mask %c4, %c8 : vector<4x8xi1> +/// %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1] +/// : vector<4x8xi1> into vector<8x16xi1> +/// Slice [0,8]: +/// mask size = min(max(6-0, 0), 4) x min(max(10-8, 0), 8) = 4x2 +/// %mask01 = vector.create_mask %c4, %c2 : vector<4x8xi1> +/// %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1] +/// : vector<4x8xi1> into vector<8x16xi1> +/// Slice [4,0]: +/// mask size = min(max(6-4, 0), 4) x min(max(10-0, 0), 8) = 2x8 +/// %mask10 = vector.create_mask %c2, %c8 : vector<4x8xi1> +/// %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1] +/// : vector<4x8xi1> into vector<8x16xi1> +/// Slice [4,8]: +/// mask size = min(max(6-4, 0), 4) x min(max(10-8, 0), 8) = 2x2 +/// %mask11 = vector.create_mask %c2, %c2 : vector<4x8xi1> +/// %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1] +/// : vector<4x8xi1> into vector<8x16xi1> +struct UnrollCreateMaskPattern : public OpRewritePattern<vector::CreateMaskOp> { + UnrollCreateMaskPattern(MLIRContext *context, + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern<vector::CreateMaskOp>(context, benefit), + options(options) {} + + LogicalResult matchAndRewrite(vector::CreateMaskOp createMaskOp, + PatternRewriter &rewriter) const override { + auto targetShape = getTargetShape(options, createMaskOp); + if (!targetShape) + return failure(); + + VectorType resultType = createMaskOp.getVectorType(); + SmallVector<int64_t> originalSize = *createMaskOp.getShapeForUnroll(); + Location loc = createMaskOp.getLoc(); + + Value result = arith::ConstantOp::create(rewriter, loc, resultType, + rewriter.getZeroAttr(resultType)); + VectorType targetVectorType = + VectorType::get(*targetShape, rewriter.getI1Type()); + SmallVector<int64_t> strides(targetShape->size(), 1); + + // In each dimension (d), each unrolled vector computes its mask size as: + // min(max(originalMaskOperands[d] - offset[d], 0), unrolledDimSize[d]). + for (SmallVector<int64_t> offsets : + StaticTileOffsetRange(originalSize, *targetShape)) { + SmallVector<Value> unrolledOperands; + + for (auto [i, originalMaskOperand] : + llvm::enumerate(createMaskOp.getOperands())) { + Value offsetVal = + arith::ConstantIndexOp::create(rewriter, loc, offsets[i]); + Value adjustedMaskSize = rewriter.createOrFold<arith::SubIOp>( + loc, originalMaskOperand, offsetVal); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value unrolledDimSize = + arith::ConstantIndexOp::create(rewriter, loc, (*targetShape)[i]); + Value nonNegative = + rewriter.createOrFold<arith::MaxSIOp>(loc, adjustedMaskSize, zero); + Value unrolledOperand = rewriter.createOrFold<arith::MinSIOp>( + loc, nonNegative, unrolledDimSize); + unrolledOperands.push_back(unrolledOperand); + } + + auto unrolledMask = rewriter.createOrFold<vector::CreateMaskOp>( + loc, targetVectorType, unrolledOperands); + result = rewriter.createOrFold<vector::InsertStridedSliceOp>( + loc, unrolledMask, result, offsets, strides); + } + rewriter.replaceOp(createMaskOp, result); + return success(); + } + +private: + vector::UnrollVectorOptions options; +}; + +/// Checks whether extractShape is a contiguous slice of shape. +/// For extractShape to be contiguous in shape: +/// 1) All but the leading dimension of extractShape and shape must match +/// exactly. 2) The total number of elements in shape must be evenly divisible +/// by +/// the total number of elements in extractShape. +/// Examples: +/// isContiguous([4, 4], [8, 4]) == true +/// isContiguous([2, 4], [8, 4]) == true +/// isContiguous([2, 2], [8, 4]) == false +/// Removes leading unit dimensions to handle cases like: +/// isContiguous([1, 16], [1, 32]) == true +static bool isContiguous(ArrayRef<int64_t> extractShape, + ArrayRef<int64_t> shape) { + + if (extractShape.size() > shape.size()) + return false; + + while (!extractShape.empty() && extractShape.front() == 1) { + extractShape = extractShape.drop_front(); + } + + while (!shape.empty() && shape.front() == 1) { + shape = shape.drop_front(); + } + + size_t rankDiff = shape.size() - extractShape.size(); + if (!llvm::equal(extractShape.drop_front(), shape.drop_front(rankDiff + 1))) + return false; + + int64_t extractElements = ShapedType::getNumElements(extractShape); + int64_t shapeElements = ShapedType::getNumElements(shape); + return shapeElements % extractElements == 0; +} + +/// Determines what shape to use with `vector.extract_strided_slice` to extract +/// a contiguous memory region from a source vector. The extraction must be +/// contiguous and contain exactly the specified number of elements. If such an +/// extraction shape cannot be determined, returns std::nullopt. +/// EXAMPLE 1: +/// sourceShape = [16], targetElements = 8 +/// Working right-to-left: +/// - Take min(8, 16) = 8 from only dim → extractShape = [8], +/// remaining = 8/8 = 1 +/// Result: [8] +/// +/// EXAMPLE 2: +/// sourceShape = [4, 4], targetElements = 8 +/// Working right-to-left: +/// - Take min(8, 4) = 4 from last dim → extractShape = [4], +/// remaining = 8/4 = 2 +/// - Take min(2, 4) = 2 from first dim → extractShape = [2, 4], +/// remaining = 2/2 = 1 +/// Result: [2, 4] +static std::optional<SmallVector<int64_t>> +calculateSourceExtractShape(ArrayRef<int64_t> sourceShape, + int64_t targetElements) { + SmallVector<int64_t> extractShape; + int64_t remainingElements = targetElements; + + // Build extract shape from innermost dimension outward to ensure contiguity. + for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) { + int64_t takeFromDim = std::min(remainingElements, sourceShape[i]); + extractShape.insert(extractShape.begin(), takeFromDim); + + if (remainingElements % takeFromDim != 0) + return std::nullopt; // Not evenly divisible. + remainingElements /= takeFromDim; + } + + // Fill remaining dimensions with 1. + while (extractShape.size() < sourceShape.size()) + extractShape.insert(extractShape.begin(), 1); + + if (ShapedType::getNumElements(extractShape) != targetElements) + return std::nullopt; + + return extractShape; +} + +// Convert result offsets to source offsets via linear position. +static SmallVector<int64_t> +calculateSourceOffsets(ArrayRef<int64_t> resultOffsets, + ArrayRef<int64_t> sourceShape, + ArrayRef<int64_t> resultShape) { + // Convert result offsets to linear position. + int64_t linearIndex = linearize(resultOffsets, computeStrides(resultShape)); + // Convert linear position to source offsets. + return delinearize(linearIndex, computeStrides(sourceShape)); +} + +/// This pattern unrolls `vector.shape_cast` operations according to the +/// provided target unroll shape. It unrolls a large shape cast into smaller +/// shape casts by extracting contiguous slices from the source vector, casting +/// each slice to the target shape, and assembling the result by inserting each +/// computed segment into the appropriate offset of the result vector. +/// +/// This pattern only applies when contiguous slices can be extracted from the +/// source vector and inserted into the result vector such that each slice +/// remains a valid vector (and not decompose to scalars). In these cases, the +/// unrolling proceeds as: +/// vector.extract_strided_slice -> vector.shape_cast (on the slice) -> +/// vector.insert_strided_slice. +/// +/// Example: +/// Given a shape cast operation: +/// %0 = vector.shape_cast %src : vector<8x2xf32> to vector<4x4xf32> +/// +/// and a target unroll shape of <2x4>, the pattern produces: +/// +/// %zero = arith.constant dense<0.0> : vector<4x4xf32> +/// %s0 = vector.extract_strided_slice %src [0, 0], [4, 2], [1, 1] +/// : vector<8x2xf32> to vector<4x2xf32> +/// %sc0 = vector.shape_cast %s0 : vector<4x2xf32> to vector<2x4xf32> +/// %i0 = vector.insert_strided_slice %sc0, %zero [0, 0], [1, 1] +/// : vector<2x4xf32> into vector<4x4xf32> +/// %s1 = vector.extract_strided_slice %src [4, 0], [4, 2], [1, 1] +/// : vector<8x2xf32> to vector<4x2xf32> +/// %sc1 = vector.shape_cast %s1 : vector<4x2xf32> to vector<2x4xf32> +/// %i1 = vector.insert_strided_slice %sc1, %i0 [2, 0], [1, 1] +/// : vector<2x4xf32> into vector<4x4xf32> +/// +struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> { + UnrollShapeCastPattern(MLIRContext *context, + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern<vector::ShapeCastOp>(context, benefit), + options(options) {} + + LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp, + PatternRewriter &rewriter) const override { + std::optional<SmallVector<int64_t>> targetShape = + getTargetShape(options, shapeCastOp); + if (!targetShape) + return failure(); + + VectorType sourceType = shapeCastOp.getSourceVectorType(); + VectorType resultType = shapeCastOp.getResultVectorType(); + ArrayRef<int64_t> sourceShape = sourceType.getShape(); + ArrayRef<int64_t> resultShape = resultType.getShape(); + + if (!isContiguous(*targetShape, resultShape)) + return rewriter.notifyMatchFailure( + shapeCastOp, "Only supports cases where target shape is " + "contiguous in result vector shape"); + + int64_t targetElements = ShapedType::getNumElements(*targetShape); + + // Calculate the shape to extract from source. + std::optional<SmallVector<int64_t>> extractShape = + calculateSourceExtractShape(sourceShape, targetElements); + if (!extractShape) + return rewriter.notifyMatchFailure( + shapeCastOp, + "cannot extract target number of elements contiguously from source"); + + Location loc = shapeCastOp.getLoc(); + + // Create result vector initialized to zero. + Value result = arith::ConstantOp::create(rewriter, loc, resultType, + rewriter.getZeroAttr(resultType)); + + VectorType targetType = + VectorType::get(*targetShape, sourceType.getElementType()); + + SmallVector<int64_t> extractStrides(extractShape->size(), 1); + SmallVector<int64_t> insertStrides(targetShape->size(), 1); + + for (SmallVector<int64_t> resultOffsets : + StaticTileOffsetRange(resultShape, *targetShape)) { + SmallVector<int64_t> sourceOffsets = + calculateSourceOffsets(resultOffsets, sourceShape, resultShape); + Value sourceChunk = rewriter.createOrFold<vector::ExtractStridedSliceOp>( + loc, shapeCastOp.getSource(), sourceOffsets, *extractShape, + extractStrides); + Value targetChunk = rewriter.createOrFold<vector::ShapeCastOp>( + loc, targetType, sourceChunk); + result = rewriter.createOrFold<vector::InsertStridedSliceOp>( + loc, targetChunk, result, resultOffsets, insertStrides); + } + + rewriter.replaceOp(shapeCastOp, result); + return success(); + } + +private: + vector::UnrollVectorOptions options; +}; + } // namespace void mlir::vector::populateVectorUnrollPatterns( @@ -1013,8 +1293,9 @@ void mlir::vector::populateVectorUnrollPatterns( UnrollReductionPattern, UnrollMultiReductionPattern, UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern, UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements, - UnrollToElements, UnrollStepPattern>(patterns.getContext(), - options, benefit); + UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern, + UnrollCreateMaskPattern>(patterns.getContext(), options, + benefit); } void mlir::vector::populateVectorToElementsUnrollPatterns( diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index c809c502..c307fb4 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -322,46 +322,61 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc, std::optional<Value> padValue, bool useInBoundsInsteadOfMasking, ArrayRef<bool> inputScalableVecDims) { - assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) && + VectorType vecToReadTy = VectorType::get( + inputVectorSizes, cast<ShapedType>(source.getType()).getElementType(), + inputScalableVecDims); + + return createReadOrMaskedRead(builder, loc, source, vecToReadTy, padValue, + useInBoundsInsteadOfMasking); +} + +Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc, + Value source, + const VectorType &vecToReadTy, + std::optional<Value> padValue, + bool useInBoundsInsteadOfMasking) { + assert(!llvm::is_contained(vecToReadTy.getScalableDims(), + ShapedType::kDynamic) && "invalid input vector sizes"); auto sourceShapedType = cast<ShapedType>(source.getType()); auto sourceShape = sourceShapedType.getShape(); - assert(sourceShape.size() == inputVectorSizes.size() && + + int64_t vecToReadRank = vecToReadTy.getRank(); + auto vecToReadShape = vecToReadTy.getShape(); + + assert(sourceShape.size() == static_cast<size_t>(vecToReadRank) && "expected same ranks."); - auto vectorType = - VectorType::get(inputVectorSizes, sourceShapedType.getElementType(), - inputScalableVecDims); assert((!padValue.has_value() || padValue.value().getType() == sourceShapedType.getElementType()) && "expected same pad element type to match source element type"); - int64_t readRank = inputVectorSizes.size(); + auto zero = arith::ConstantIndexOp::create(builder, loc, 0); - SmallVector<bool> inBoundsVal(readRank, true); + SmallVector<bool> inBoundsVal(vecToReadRank, true); if (useInBoundsInsteadOfMasking) { // Update the inBounds attribute. // FIXME: This computation is too weak - it ignores the read indices. - for (unsigned i = 0; i < readRank; i++) - inBoundsVal[i] = (sourceShape[i] == inputVectorSizes[i]) && + for (unsigned i = 0; i < vecToReadRank; i++) + inBoundsVal[i] = (sourceShape[i] == vecToReadShape[i]) && ShapedType::isStatic(sourceShape[i]); } auto transferReadOp = vector::TransferReadOp::create( builder, loc, - /*vectorType=*/vectorType, + /*vectorType=*/vecToReadTy, /*source=*/source, - /*indices=*/SmallVector<Value>(readRank, zero), + /*indices=*/SmallVector<Value>(vecToReadRank, zero), /*padding=*/padValue, /*inBounds=*/inBoundsVal); - if (llvm::equal(inputVectorSizes, sourceShape) || useInBoundsInsteadOfMasking) + if (llvm::equal(vecToReadTy.getShape(), sourceShape) || + useInBoundsInsteadOfMasking) return transferReadOp; SmallVector<OpFoldResult> mixedSourceDims = isa<MemRefType>(source.getType()) ? memref::getMixedSizes(builder, loc, source) : tensor::getMixedSizes(builder, loc, source); - auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type(), - inputScalableVecDims); + auto maskType = vecToReadTy.cloneWith(/*shape=*/{}, builder.getI1Type()); Value mask = vector::CreateMaskOp::create(builder, loc, maskType, mixedSourceDims); return mlir::vector::maskOperation(builder, transferReadOp, mask) diff --git a/mlir/lib/Dialect/X86Vector/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/CMakeLists.txt index 9f57627..cb1e9d0 100644 --- a/mlir/lib/Dialect/X86Vector/CMakeLists.txt +++ b/mlir/lib/Dialect/X86Vector/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(TransformOps) diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt new file mode 100644 index 0000000..f4c9f8a --- /dev/null +++ b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_dialect_library(MLIRX86VectorTransformOps + X86VectorTransformOps.cpp + + DEPENDS + MLIRX86VectorTransformOpsIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMCommonConversion + MLIRLLVMDialect + MLIRVectorDialect + MLIRSideEffectInterfaces + MLIRTransformDialect + MLIRTransformDialectUtils + MLIRX86VectorDialect + MLIRX86VectorTransforms + ) diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp new file mode 100644 index 0000000..95db208 --- /dev/null +++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp @@ -0,0 +1,64 @@ +//===- X86VectorTransformOps.cpp ------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/X86Vector/Transforms.h" +#include "mlir/Dialect/X86Vector/X86VectorDialect.h" + +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/RegionKindInterface.h" + +using namespace mlir; +using namespace mlir::x86vector; +using namespace mlir::transform; + +void mlir::transform::ApplyVectorContractToFMAPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + x86vector::populateVectorContractToFMAPatterns(patterns); +} + +void mlir::transform::ApplyVectorContractToPackedTypeDotProductPatternsOp:: + populatePatterns(RewritePatternSet &patterns) { + x86vector::populateVectorContractToPackedTypeDotProductPatterns(patterns); +} + +//===----------------------------------------------------------------------===// +// Transform op registration +//===----------------------------------------------------------------------===// + +namespace { +class X86VectorTransformDialectExtension + : public transform::TransformDialectExtension< + X86VectorTransformDialectExtension> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + X86VectorTransformDialectExtension) + + X86VectorTransformDialectExtension() { + declareGeneratedDialect<x86vector::X86VectorDialect>(); + declareGeneratedDialect<LLVM::LLVMDialect>(); + registerTransformOps< +#define GET_OP_LIST +#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc" + >(); + } +}; +} // namespace + +#define GET_OP_CLASSES +#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc" + +void mlir::x86vector::registerTransformDialectExtension( + DialectRegistry ®istry) { + registry.addExtensions<X86VectorTransformDialectExtension>(); +} diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt index c51266a..2cab50f 100644 --- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt @@ -1,11 +1,14 @@ add_mlir_dialect_library(MLIRX86VectorTransforms AVXTranspose.cpp LegalizeForLLVMExport.cpp + VectorContractToFMA.cpp + VectorContractToPackedTypeDotProduct.cpp LINK_LIBS PUBLIC MLIRArithDialect MLIRX86VectorDialect MLIRIR + MLIRLinalgDialect MLIRLLVMCommonConversion MLIRLLVMDialect MLIRVectorDialect diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp new file mode 100644 index 0000000..f3af5ca --- /dev/null +++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp @@ -0,0 +1,143 @@ +//===- VectorContractToFMA.cpp --------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Utils/VectorUtils.h" +#include "mlir/Dialect/X86Vector/Transforms.h" +#include "mlir/Dialect/X86Vector/X86VectorDialect.h" + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/PatternMatch.h" + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace mlir::vector; +using namespace mlir::x86vector; + +namespace { + +// Implements outer product contraction as a sequence of broadcast and +// FMA operations. +// +// For example - for F32 type: +// ``` +// vector.contract <1x1xf32>, <1x16xf32> into <1x16xf32> +// ``` +// to +// ``` +// vector.broadcast %lhs to <16xf32> +// vector.fma vector<16xf32> +// ``` +struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> { + using OpRewritePattern<vector::ContractionOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ContractionOp contractOp, + PatternRewriter &rewriter) const override { + + if (contractOp.getKind() != vector::CombiningKind::ADD) + return rewriter.notifyMatchFailure(contractOp, + "Expects add combining kind."); + + VectorType lhsTy = contractOp.getLhsType(); + if (!lhsTy.getElementType().isF32()) + return rewriter.notifyMatchFailure(contractOp, + "Only F32 lowering is supported."); + + ArrayRef<int64_t> lhsShape = lhsTy.getShape(); + llvm::SmallVector<int64_t> nonUnitDimLhs; + llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs), + [](int64_t dim) { return dim != 1; }); + + VectorType rhsTy = contractOp.getRhsType(); + ArrayRef<int64_t> rhsShape = rhsTy.getShape(); + llvm::SmallVector<int64_t> nonUnitDimRhs; + llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs), + [](int64_t dim) { return dim != 1; }); + + if (nonUnitDimLhs.size() > 0 && nonUnitDimRhs.size() > 0) + return rewriter.notifyMatchFailure( + contractOp, "Excepts unit dimensions for either LHS or RHS shape."); + + if (nonUnitDimLhs.size() != 1 && nonUnitDimRhs.size() != 1) + return rewriter.notifyMatchFailure( + contractOp, + "Excepts a one non-unit A/B dimension for either LHS or RHS shape."); + + VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType()); + if (!accTy) + return rewriter.notifyMatchFailure(contractOp, + "Accmulator is not a vector type"); + + if (!accTy.getElementType().isF32()) + return rewriter.notifyMatchFailure(contractOp, + "Accmulator should be F32 type."); + + ArrayRef<int64_t> accShape = accTy.getShape(); + llvm::SmallVector<int64_t> nonUnitDimAcc; + llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc), + [](int64_t dim) { return dim != 1; }); + if (nonUnitDimAcc.size() != 1) + return rewriter.notifyMatchFailure( + contractOp, "A or B dimension should be non-unit."); + + // Lowers vector.contract into a broadcast+FMA sequence. + auto loc = contractOp.getLoc(); + auto castAcc = vector::ShapeCastOp::create( + rewriter, loc, + VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()), + contractOp.getAcc()); + + vector::FMAOp fma; + + // Broadcast the unit-dimension LHS or RHS to match the vector length of the + // corresponding non-unit dimension on the other operand. For example, + // if LHS has type vector<1x1xf32> and RHS has type vector<1x16xf32>, we + // broadcast the LHS to vector<1x16xf32>. In the opposite case (non-unit + // dimension on the LHS), we broadcast the RHS instead. + if (nonUnitDimRhs.size() > 0) { + auto castLhs = vector::ShapeCastOp::create( + rewriter, loc, VectorType::get(1, lhsTy.getElementType()), + contractOp.getLhs()); + auto castRhs = vector::ShapeCastOp::create( + rewriter, loc, + VectorType::get(nonUnitDimRhs.front(), rhsTy.getElementType()), + contractOp.getRhs()); + auto broadcastLhs = vector::BroadcastOp::create( + rewriter, loc, castRhs.getResult().getType(), castLhs); + fma = + vector::FMAOp::create(rewriter, loc, broadcastLhs, castRhs, castAcc); + } else { + auto castLhs = vector::ShapeCastOp::create( + rewriter, loc, + VectorType::get(nonUnitDimLhs.front(), lhsTy.getElementType()), + contractOp.getLhs()); + auto castRhs = vector::ShapeCastOp::create( + rewriter, loc, VectorType::get(1, rhsTy.getElementType()), + contractOp.getRhs()); + auto broadcastRhs = vector::BroadcastOp::create( + rewriter, loc, castLhs.getResult().getType(), castRhs); + fma = + vector::FMAOp::create(rewriter, loc, castLhs, broadcastRhs, castAcc); + } + + auto castFma = vector::ShapeCastOp::create(rewriter, loc, accTy, fma); + rewriter.replaceOp(contractOp, castFma); + + return success(); + } +}; + +} // namespace + +void x86vector::populateVectorContractToFMAPatterns( + RewritePatternSet &patterns) { + patterns.add<VectorContractToFMA>(patterns.getContext()); +} diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp new file mode 100644 index 0000000..1e64811 --- /dev/null +++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp @@ -0,0 +1,301 @@ +//===- VectorContractToPackedTypeDotProduct.cpp ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Utils/VectorUtils.h" +#include "mlir/Dialect/X86Vector/Transforms.h" +#include "mlir/Dialect/X86Vector/X86VectorDialect.h" + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/PatternMatch.h" + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace mlir::vector; +using namespace mlir::x86vector; + +namespace { + +static FailureOr<SmallVector<mlir::utils::IteratorType>> +inferIteratorsFromOutMap(AffineMap map) { + if (!map.isProjectedPermutation()) + return failure(); + SmallVector<mlir::utils::IteratorType> iterators( + map.getNumDims(), mlir::utils::IteratorType::reduction); + for (auto expr : map.getResults()) + if (auto dim = dyn_cast<AffineDimExpr>(expr)) + iterators[dim.getPosition()] = mlir::utils::IteratorType::parallel; + return iterators; +} + +// Returns true if the operation is in VNNI layout. +// Optionally, the check can be constrained to a specific VNNI blocking factor. +static bool isInVnniLayout(Operation *op, ArrayRef<AffineMap> indexingMaps, + std::optional<unsigned> blockingFactor) { + // Narrow down type operations - VNNI only applies to contractions. + FailureOr<linalg::ContractionDimensions> dims = + linalg::inferContractionDims(indexingMaps); + if (failed(dims)) + return false; + + auto matA = op->getOperand(0); + auto matB = op->getOperand(1); + auto typeA = dyn_cast<ShapedType>(matA.getType()); + auto typeB = dyn_cast<ShapedType>(matB.getType()); + unsigned rankA = typeA.getRank(); + unsigned rankB = typeB.getRank(); + // VNNI format requires at least 1 parallel and 2 reduction dimensions. + if (rankA < 3 || rankB < 3) + return false; + + // At least two reduction dimensions are expected: + // one for the VNNI factor and one for the K dimension + if (dims->k.size() < 2) + return false; + + // Validate affine maps - VNNI computation should be defined by the two + // innermost reduction iterators. + // The input matrix dimensions layout must match the following: + // - matrix A - [...][K/vnniFactor][vnniFactor] + // - matrix B - [...][K/vnniFactor][N][vnniFactor] + auto maybeIters = inferIteratorsFromOutMap(indexingMaps[2]); + if (failed(maybeIters)) + return false; + SmallVector<mlir::utils::IteratorType> iteratorTypes = *maybeIters; + AffineMap mapA = indexingMaps[0]; + AffineMap mapB = indexingMaps[1]; + + auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 1)); + auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 1)); + if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB || + iteratorTypes[vnniDimA.getPosition()] != + mlir::utils::IteratorType::reduction) + return false; + auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 2)); + auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 3)); + if (!redDimA || !redDimB || redDimA != redDimB || + iteratorTypes[redDimA.getPosition()] != + mlir::utils::IteratorType::reduction) + return false; + auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 2)); + if (!parallelDimB || iteratorTypes[parallelDimB.getPosition()] != + mlir::utils::IteratorType::parallel) + return false; + + // VNNI factor must be: + // - the innermost inputs' dimension + // - statically known + // - multiple of 2 or equal to the specified factor + auto vnniDimSize = typeB.getShape().back(); + if (vnniDimSize == ShapedType::kDynamic || vnniDimSize == 0 || + vnniDimSize % 2 != 0) + return false; + if (typeA.getShape().back() != vnniDimSize) + return false; + if (blockingFactor && vnniDimSize != *blockingFactor) + return false; + + // The split reduction dimension size should also match. + if (typeA.getShape().end()[-2] != typeB.getShape().end()[-3]) + return false; + + return true; +} + +// Implements packed type outer product contraction as a sequence +// of broadcast and packed dot-product operations. +// +// For example - for F32 type: +// ``` +// vector.contract <1x1x2xbf16>, <1x16x2xbf16> into <1x16xf32> +// ``` +// to +// ``` +// vector.broadcast %lhs to <32xbf16> +// x86vector.avx512.dot vector<32xbf16> -> vector<16xf32> +// ``` +struct VectorContractToPackedTypeDotProduct + : public OpRewritePattern<vector::ContractionOp> { + using OpRewritePattern<vector::ContractionOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ContractionOp contractOp, + PatternRewriter &rewriter) const override { + + if (contractOp.getKind() != vector::CombiningKind::ADD) + return rewriter.notifyMatchFailure(contractOp, + "Expects add combining kind."); + + VectorType lhsTy = contractOp.getLhsType(); + if (!lhsTy.getElementType().isBF16() && + !lhsTy.getElementType().isSignlessInteger(8)) + return rewriter.notifyMatchFailure( + contractOp, "Only BF16/Int8 lowering is supported."); + + unsigned int blockingFactor = lhsTy.getElementType().isBF16() ? 2 : 4; + if (!isInVnniLayout(contractOp.getOperation(), + contractOp.getIndexingMapsArray(), blockingFactor)) + return rewriter.notifyMatchFailure(contractOp, + "Input matrices not in VNNI format."); + + ArrayRef<int64_t> lhsShape = lhsTy.getShape(); + llvm::SmallVector<int64_t> nonUnitDimLhs; + llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs), + [](int64_t dim) { return dim != 1; }); + + VectorType rhsTy = contractOp.getRhsType(); + ArrayRef<int64_t> rhsShape = rhsTy.getShape(); + llvm::SmallVector<int64_t> nonUnitDimRhs; + llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs), + [](int64_t dim) { return dim != 1; }); + + if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0) + return rewriter.notifyMatchFailure(contractOp, + "Excepts unit dimensions for either " + "LHS or RHS shape other than VNNI."); + + if ((nonUnitDimLhs.size() - 1) != 1 && (nonUnitDimRhs.size() - 1) != 1) + return rewriter.notifyMatchFailure( + contractOp, + "Excepts a one non-unit A/B dimension for either LHS or RHS shape."); + + VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType()); + if (!accTy) + return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type."); + + if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) || + (lhsTy.getElementType().isSignlessInteger(8) && + !accTy.getElementType().isSignlessInteger(32))) + return rewriter.notifyMatchFailure(contractOp, + "Only F32 for BF16 or Int32 for Int8 " + "accumulation type is supported."); + + ArrayRef<int64_t> accShape = accTy.getShape(); + llvm::SmallVector<int64_t> nonUnitDimAcc; + llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc), + [](int64_t dim) { return dim != 1; }); + if (nonUnitDimAcc.size() != 1) + return rewriter.notifyMatchFailure( + contractOp, "A or B should be a non-unit dim in acc."); + + // Non-unit dimensions should match the vector length of BF16 or Int8 + // dot-product. + unsigned int nonUnitDim = nonUnitDimLhs.size() == 2 ? nonUnitDimLhs.front() + : nonUnitDimRhs.front(); + if (lhsTy.getElementType().isBF16() && nonUnitDim != 4 && nonUnitDim != 8 && + nonUnitDim != 16 && nonUnitDimAcc.front() == nonUnitDim) + return rewriter.notifyMatchFailure( + contractOp, "BF16 dot-product operation expects non-unit (LHR or " + "RHS) dim and acc dim of size 4/8/16."); + + if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDim != 4 && + nonUnitDim != 8 && nonUnitDimAcc.front() == nonUnitDim) + return rewriter.notifyMatchFailure( + contractOp, "Int8 dot-product operation expects non-unit (LHR or " + "RHS) dim and acc dim of size 4/8."); + + auto loc = contractOp.getLoc(); + auto castAcc = vector::ShapeCastOp::create( + rewriter, loc, + VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()), + contractOp.getAcc()); + + Value dp; + + // Broadcast the unit-dimension LHS or RHS to match the vector length of the + // corresponding non-unit dimension on the other operand. For example, + // if LHS has type vector<1x1x2xbf16> and RHS has type vector<1x16x2xbf16>, + // we broadcast the LHS to vector<16x2xbf16>. In the opposite case (non-unit + // dimension on the LHS), we broadcast the RHS instead. + if ((nonUnitDimRhs.size() - 1) > 0) { + auto castRhs = vector::ShapeCastOp::create( + rewriter, loc, + VectorType::get(nonUnitDimRhs.front() * nonUnitDimRhs.back(), + rhsTy.getElementType()), + contractOp.getRhs()); + auto castLhs = vector::ShapeCastOp::create( + rewriter, loc, + VectorType::get(nonUnitDimLhs.front(), lhsTy.getElementType()), + contractOp.getLhs()); + auto bitcastLhs = vector::BitCastOp::create( + rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)), + castLhs); + auto broadcastLhs = vector::BroadcastOp::create( + rewriter, loc, + VectorType::get({nonUnitDimRhs.front()}, rewriter.getIntegerType(32)), + bitcastLhs); + auto bitcastLhsPkType = vector::BitCastOp::create( + rewriter, loc, castRhs.getResult().getType(), broadcastLhs); + + if (lhsTy.getElementType().isBF16()) { + dp = x86vector::DotBF16Op::create( + rewriter, loc, + VectorType::get(nonUnitDimRhs.front(), rewriter.getF32Type()), + castAcc, bitcastLhsPkType, castRhs); + } + + if (lhsTy.getElementType().isSignlessInteger(8)) { + dp = x86vector::DotInt8Op::create( + rewriter, loc, + VectorType::get(nonUnitDimRhs.front(), rewriter.getIntegerType(32)), + castAcc, bitcastLhsPkType, castRhs); + } + } else { + auto castLhs = vector::ShapeCastOp::create( + rewriter, loc, + VectorType::get(nonUnitDimLhs.front() * nonUnitDimLhs.back(), + lhsTy.getElementType()), + contractOp.getLhs()); + auto castRhs = vector::ShapeCastOp::create( + rewriter, loc, + VectorType::get(nonUnitDimRhs.front(), rhsTy.getElementType()), + contractOp.getRhs()); + auto bitcastRhs = vector::BitCastOp::create( + rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)), + castRhs); + auto broadcastRhs = vector::BroadcastOp::create( + rewriter, loc, + VectorType::get({nonUnitDimLhs.front()}, rewriter.getIntegerType(32)), + bitcastRhs); + auto bitcastRhsPkType = vector::BitCastOp::create( + rewriter, loc, castLhs.getResult().getType(), broadcastRhs); + + if (lhsTy.getElementType().isBF16()) { + dp = x86vector::DotBF16Op::create( + rewriter, loc, + VectorType::get(nonUnitDimLhs.front(), rewriter.getF32Type()), + castAcc, castLhs, bitcastRhsPkType); + } + + if (lhsTy.getElementType().isSignlessInteger(8)) { + dp = x86vector::DotInt8Op::create( + rewriter, loc, + VectorType::get(nonUnitDimLhs.front(), rewriter.getIntegerType(32)), + castAcc, castLhs, bitcastRhsPkType); + } + } + + if (!dp) + return failure(); + + auto castDp = vector::ShapeCastOp::create(rewriter, loc, accTy, dp); + rewriter.replaceOp(contractOp, castDp); + return success(); + } +}; + +} // namespace + +void x86vector::populateVectorContractToPackedTypeDotProductPatterns( + RewritePatternSet &patterns) { + patterns.add<VectorContractToPackedTypeDotProduct>(patterns.getContext()); +} diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index fb5d1e7..1a19ab5 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -8,7 +8,6 @@ #include "mlir/Dialect/Affine/Utils.h" #include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" #include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h" @@ -61,7 +60,7 @@ genCoordinates(OpBuilder &builder, Location loc, // Get the offset of `subShape` within a distribution unit. SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector( llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value { - return builder.createOrFold<index::MulOp>( + return builder.createOrFold<arith::MulIOp>( loc, std::get<0>(t), builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t))); }); @@ -84,7 +83,7 @@ genCoordinates(OpBuilder &builder, Location loc, // Do not go beyond `srcShape` bounds. SmallVector<Value> mods = llvm::map_to_vector( llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value { - return builder.createOrFold<index::RemUOp>( + return builder.createOrFold<arith::RemUIOp>( loc, std::get<0>(t), arith::ConstantIndexOp::create(builder, loc, std::get<1>(t))); }); @@ -343,7 +342,7 @@ LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) { /// e.g., linearId=22, dimSize=4: 22 % 4 = 2 (we're at position 2 within /// this dimension) result[dimIdx] = - builder.createOrFold<index::RemUOp>(loc, remaining, dimSizeVal); + builder.createOrFold<arith::RemUIOp>(loc, remaining, dimSizeVal); /// Update remaining for the next dimension by removing what we've already /// processed. Division tells us "how many complete groups of this dimension @@ -352,7 +351,7 @@ LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) { /// no next dimension to process if (i < order.size() - 1) { remaining = - builder.createOrFold<index::DivUOp>(loc, remaining, dimSizeVal); + builder.createOrFold<arith::DivUIOp>(loc, remaining, dimSizeVal); } } return result; @@ -391,6 +390,86 @@ LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc, return genCoordinates(builder, loc, ids, layout, subShape, shape); } +bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) { + if (dyn_cast<xegpu::SliceAttr>(other)) + return false; + + return *this == dyn_cast<xegpu::LayoutAttr>(other); +} + +// set the layout for unit dims: sg_data, inst_data and lane_data to 1 +DistributeLayoutAttr LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) { + auto sgDataOpt = getSgData(); + auto instDataOpt = getInstData(); + auto laneDataOpt = getLaneData(); + + SmallVector<int32_t> sgData; + SmallVector<int32_t> instData; + SmallVector<int32_t> laneData; + + if (sgDataOpt) { + sgData = llvm::to_vector(sgDataOpt.asArrayRef()); + } + if (instDataOpt) { + instData = llvm::to_vector(instDataOpt.asArrayRef()); + } + if (laneDataOpt) { + laneData = llvm::to_vector(laneDataOpt.asArrayRef()); + } + + for (auto dim : unitDims) { + if (dim < static_cast<int64_t>(sgData.size())) + sgData[dim] = 1; + if (dim < static_cast<int64_t>(instData.size())) + instData[dim] = 1; + if (dim < static_cast<int64_t>(laneData.size())) + laneData[dim] = 1; + } + + return LayoutAttr::get( + getContext(), getSgLayout(), + sgData.empty() ? DenseI32ArrayAttr() + : DenseI32ArrayAttr::get(getContext(), sgData), + instData.empty() ? DenseI32ArrayAttr() + : DenseI32ArrayAttr::get(getContext(), instData), + getLaneLayout(), + laneData.empty() ? DenseI32ArrayAttr() + : DenseI32ArrayAttr::get(getContext(), laneData), + getOrder()); +} + +// set the layout for the sepcified unit dims: sg_lane and lane_layout to 1 +DistributeLayoutAttr LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) { + auto sgLayoutOpt = getSgLayout(); + auto laneLayoutOpt = getLaneLayout(); + + SmallVector<int32_t> sgLayout; + SmallVector<int32_t> laneLayout; + + if (sgLayoutOpt) { + sgLayout = llvm::to_vector(sgLayoutOpt.asArrayRef()); + } + if (laneLayoutOpt) { + laneLayout = llvm::to_vector(laneLayoutOpt.asArrayRef()); + } + + for (auto dim : unitDims) { + if (dim < static_cast<int64_t>(sgLayout.size())) + sgLayout[dim] = 1; + if (dim < static_cast<int64_t>(laneLayout.size())) + laneLayout[dim] = 1; + } + + return LayoutAttr::get( + getContext(), + sgLayout.empty() ? DenseI32ArrayAttr() + : DenseI32ArrayAttr::get(getContext(), sgLayout), + getSgData(), getInstData(), + laneLayout.empty() ? DenseI32ArrayAttr() + : DenseI32ArrayAttr::get(getContext(), laneLayout), + getLaneData(), getOrder()); +} + //===----------------------------------------------------------------------===// // XeGPU_SliceAttr //===----------------------------------------------------------------------===// @@ -511,6 +590,69 @@ bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) { [&](int64_t dim) { return thisDims.contains(dim); }); } +bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) { + if (dyn_cast<xegpu::LayoutAttr>(other)) + return false; + + auto flattenedThis = flatten(); + auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten(); + + return ((flattenedThis.getParent() == flattenedOther.getParent()) && + (flattenedThis.getDims() == flattenedOther.getDims())); +} + +// Helper function to adjust unit dimensions from sliced space to parent space +static SetVector<int64_t> +adjustUnitDimsWithSliceDims(const SetVector<int64_t> &unitDims, + ArrayRef<int64_t> sliceDims) { + // Reconstruct parent's non-sliced dimensions + + int64_t parentRank = sliceDims.size() + unitDims.size(); + llvm::SmallDenseSet<int64_t> slicedDimsSet(sliceDims.begin(), + sliceDims.end()); + SmallVector<int64_t> nonSlicedDims; + for (int64_t i = 0; i < parentRank; ++i) { + if (!slicedDimsSet.contains(i)) + nonSlicedDims.push_back(i); + } + + // Map unit dims from sliced space to parent space + SetVector<int64_t> adjustUnitDims; + for (auto dim : unitDims) { + if (dim < static_cast<int64_t>(nonSlicedDims.size())) { + adjustUnitDims.insert(nonSlicedDims[dim]); + } + } + + return adjustUnitDims; +} + +// set the layout for unit dims: sg_data, inst_data and lane_data to 1 +DistributeLayoutAttr SliceAttr::setUnitDimData(SetVector<int64_t> unitDims) { + SliceAttr attr = flatten(); + ArrayRef<int64_t> sliceDims = attr.getDims().asArrayRef(); + auto parent = dyn_cast<LayoutAttr>(attr.getParent()); + + SetVector<int64_t> adjustUnitDims = + adjustUnitDimsWithSliceDims(unitDims, sliceDims); + + return SliceAttr::get(getContext(), parent.setUnitDimData(adjustUnitDims), + attr.getDims()); +} + +// set the layout for the sepcified unit dims: sg_lane and lane_layout to 1 +DistributeLayoutAttr SliceAttr::setUnitDimLayout(SetVector<int64_t> unitDims) { + SliceAttr attr = flatten(); + ArrayRef<int64_t> sliceDims = attr.getDims().asArrayRef(); + auto parent = dyn_cast<LayoutAttr>(attr.getParent()); + + SetVector<int64_t> adjustUnitDims = + adjustUnitDimsWithSliceDims(unitDims, sliceDims); + + return SliceAttr::get(getContext(), parent.setUnitDimLayout(adjustUnitDims), + attr.getDims()); +} + //===----------------------------------------------------------------------===// // XeGPU_RangeAttr //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 4dd10be..91ba07a 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -465,14 +465,15 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state, xegpu::CachePolicyAttr l3_hint) { return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(), - l1_hint, l2_hint, l3_hint); + l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr); } void PrefetchNdOp::build(OpBuilder &builder, OperationState &state, Value tensorDesc, ArrayRef<OpFoldResult> offsets, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, - xegpu::CachePolicyAttr l3_hint) { + xegpu::CachePolicyAttr l3_hint, + xegpu::DistributeLayoutAttr layout) { SmallVector<Value> dynamicOffsets; SmallVector<int64_t> staticOffsets; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); @@ -480,7 +481,7 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state, auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint, - l2_hint, l3_hint); + l2_hint, l3_hint, /*anchor_layout=*/layout); } LogicalResult PrefetchNdOp::verify() { @@ -519,7 +520,7 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, return build(builder, state, retType, tensorDesc, ValueRange(), DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint, - l3_hint); + l3_hint, /*anchor_layout=*/nullptr); } void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, @@ -527,7 +528,8 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, UnitAttr packed, DenseI64ArrayAttr transpose, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, - xegpu::CachePolicyAttr l3_hint) { + xegpu::CachePolicyAttr l3_hint, + xegpu::DistributeLayoutAttr layout) { SmallVector<Value> dynamicOffsets; SmallVector<int64_t> staticOffsets; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); @@ -535,7 +537,8 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr, - packed, transpose, l1_hint, l2_hint, l3_hint); + packed, transpose, l1_hint, l2_hint, l3_hint, + /*anchor_layout=*/layout); } LogicalResult LoadNdOp::verify() { @@ -638,14 +641,16 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value, xegpu::CachePolicyAttr l3_hint) { return build(builder, state, value, tensorDesc, ValueRange(), - DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint); + DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint, + /*anchor_layout=*/nullptr); } void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value, Value tensorDesc, ArrayRef<OpFoldResult> offsets, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, - xegpu::CachePolicyAttr l3_hint) { + xegpu::CachePolicyAttr l3_hint, + xegpu::DistributeLayoutAttr layout) { SmallVector<Value> dynamicOffsets; SmallVector<int64_t> staticOffsets; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); @@ -653,7 +658,7 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value, auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr, - l1_hint, l2_hint, l3_hint); + l1_hint, l2_hint, l3_hint, /*anchor_layout=*/layout); } LogicalResult StoreNdOp::verify() { @@ -826,7 +831,7 @@ void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) { build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint, - IntegerAttr{}); + IntegerAttr{}, /*anchor_layout=*/nullptr); } //===----------------------------------------------------------------------===// @@ -876,7 +881,7 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) { build(builder, state, valueType, source, Value(), mask, IntegerAttr(), - l1_hint, l2_hint, l3_hint, /*layout=*/nullptr); + l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr); } void LoadGatherOp::build(OpBuilder &builder, OperationState &state, @@ -892,7 +897,7 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state, auto offset = vector::FromElementsOp::create(builder, loc, type, values); build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint, - l2_hint, l3_hint, /*layout=*/nullptr); + l2_hint, l3_hint, /*anchor_layout=*/nullptr); } void LoadGatherOp::build(OpBuilder &builder, OperationState &state, @@ -901,7 +906,7 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state, IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint, - xegpu::LayoutAttr layout) { + DistributeLayoutAttr layout) { auto loc = source.getLoc(); int64_t size = static_cast<int64_t>(offsets.size()); auto type = VectorType::get(size, builder.getIndexType()); @@ -960,7 +965,7 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) { build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint, - l2_hint, l3_hint, /*layout=*/nullptr); + l2_hint, l3_hint, /*anchor_layout=*/nullptr); } void StoreScatterOp::build(OpBuilder &builder, OperationState &state, @@ -978,14 +983,14 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state, // Call the correct builder overload that does not expect result types. build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint, - l3_hint, /*layout=*/nullptr); + l3_hint, /*anchor_layout=*/nullptr); } void StoreScatterOp::build( OpBuilder &builder, OperationState &state, Value value, Value dest, ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, - xegpu::CachePolicyAttr l3_hint, xegpu::LayoutAttr layout) { + xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) { auto loc = dest.getLoc(); int64_t size = static_cast<int64_t>(offsets.size()); auto type = VectorType::get(size, builder.getIndexType()); diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp index 5fdd853..e6009d5 100644 --- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp +++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp @@ -7,7 +7,9 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" @@ -165,7 +167,8 @@ getLayoutAttrFromOperands(MLIRContext *ctx, transform::TransformState &state, /// Replace xegpu.create_nd_desc op with a new one with the given layout. static xegpu::CreateNdDescOp setDescLayout(transform::TransformRewriter &rewriter, - xegpu::CreateNdDescOp descOp, xegpu::LayoutAttr layout) { + xegpu::CreateNdDescOp descOp, + xegpu::DistributeLayoutAttr layout) { assert(descOp.getMixedOffsets().size() == 0 && "create desc op with offsets is not supported"); auto oldTensorDesc = descOp.getType(); @@ -210,7 +213,8 @@ void transform::SetDescLayoutOp::build(OpBuilder &builder, OperationState &result, Value target, ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData, - ArrayRef<OpFoldResult> mixedInstData) { + ArrayRef<OpFoldResult> mixedInstData, + ArrayRef<int64_t> sliceDims) { SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData; SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData; dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout); @@ -223,7 +227,8 @@ void transform::SetDescLayoutOp::build(OpBuilder &builder, /*inst_data=*/dynamicInstData, /*static_sg_layout=*/staticSgLayout, /*static_sg_data=*/staticSgData, - /*static_inst_data=*/staticInstData); + /*static_inst_data=*/staticInstData, + /*slice_dims=*/sliceDims); } DiagnosedSilenceableFailure @@ -244,6 +249,14 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter, if (!status.succeeded()) return status; + xegpu::DistributeLayoutAttr layout = layoutAttr; + auto sliceDims = getSliceDims(); + if (sliceDims.size() > 0) { + // Wrap layoutAttr in a slice attribute. + layout = xegpu::SliceAttr::get( + getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims)); + } + // For now only create_nd_desc op is supported. auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target); if (!descOp) { @@ -255,7 +268,7 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter, } // Set layout attr in desc op's return type. Replaces old desc op. - auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr); + auto newdescOp = setDescLayout(rewriter, descOp, layout); // Map result handles. results.set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()}); @@ -276,7 +289,8 @@ void transform::SetDescLayoutOp::getEffects( void transform::SetOpLayoutAttrOp::build( OpBuilder &builder, OperationState &ostate, Value target, int64_t index, ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData, - ArrayRef<OpFoldResult> mixedInstData, bool result) { + ArrayRef<OpFoldResult> mixedInstData, ArrayRef<int64_t> sliceDims, + bool result) { SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData; SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData; dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout); @@ -291,6 +305,7 @@ void transform::SetOpLayoutAttrOp::build( /*static_sg_layout=*/staticSgLayout, /*static_sg_data=*/staticSgData, /*static_inst_data=*/staticInstData, + /*slice_dims=*/sliceDims, /*result=*/result); } @@ -324,11 +339,19 @@ transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter, if (!status.succeeded()) return status; + xegpu::DistributeLayoutAttr layout = layoutAttr; + auto sliceDims = getSliceDims(); + if (sliceDims.size() > 0) { + // Wrap layoutAttr in a slice attribute. + layout = xegpu::SliceAttr::get( + getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims)); + } + // Set layout attribute for the op result or operand if (resultTarget) - xegpu::setDistributeLayoutAttr(target->getResult(index), layoutAttr); + xegpu::setDistributeLayoutAttr(target->getResult(index), layout); else - xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layoutAttr); + xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layout); return DiagnosedSilenceableFailure::success(); } @@ -341,6 +364,305 @@ void transform::SetOpLayoutAttrOp::getEffects( modifiesPayload(effects); } +void transform::SetGPULaunchThreadsOp::build( + OpBuilder &builder, OperationState &ostate, Value target, + ArrayRef<OpFoldResult> mixedThreads) { + SmallVector<int64_t> staticThreads; + SmallVector<Value> dynamicThreads; + dispatchIndexOpFoldResults(mixedThreads, dynamicThreads, staticThreads); + build(builder, ostate, target.getType(), + /*target=*/target, + /*threads=*/dynamicThreads, + /*static_threads=*/staticThreads); +} + +DiagnosedSilenceableFailure +transform::SetGPULaunchThreadsOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto targetOps = state.getPayloadOps(getTarget()); + if (!llvm::hasSingleElement(targetOps)) { + return emitDefiniteFailure() << "Requires exactly one targetOp handle (got " + << llvm::range_size(targetOps) << ")"; + } + Operation *target = *targetOps.begin(); + + auto launchOp = dyn_cast<gpu::LaunchOp>(target); + if (!launchOp) { + auto diag = emitSilenceableFailure(getLoc()) + << "Expected a gpu.launch op, but got: " << target->getName(); + diag.attachNote(target->getLoc()) << "target op"; + return diag; + } + + SmallVector<int32_t> threads; + DiagnosedSilenceableFailure status = + convertMixedValuesToInt(state, (*this), threads, getMixedThreads()); + if (!status.succeeded()) + return status; + + if (threads.size() != 3) { + return emitSilenceableFailure(getLoc()) + << "Expected threads argument to consist of three values (got " + << threads.size() << ")"; + } + + rewriter.setInsertionPoint(launchOp); + auto createConstValue = [&](int value) { + return arith::ConstantIndexOp::create(rewriter, launchOp.getLoc(), value); + }; + + // Replace threads in-place. + launchOp.getBlockSizeXMutable().assign(createConstValue(threads[0])); + launchOp.getBlockSizeYMutable().assign(createConstValue(threads[1])); + launchOp.getBlockSizeZMutable().assign(createConstValue(threads[2])); + + return DiagnosedSilenceableFailure::success(); +} + +void transform::SetGPULaunchThreadsOp::getEffects( + ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + onlyReadsHandle(getTargetMutable(), effects); + onlyReadsHandle(getThreadsMutable(), effects); + modifiesPayload(effects); +} + +DiagnosedSilenceableFailure +transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto targetValues = state.getPayloadValues(getTarget()); + if (!llvm::hasSingleElement(targetValues)) + return emitDefiniteFailure() + << "requires exactly one target value handle (got " + << llvm::range_size(targetValues) << ")"; + auto value = *targetValues.begin(); + + int64_t nbPrefetch = getStaticNbPrefetch(); + if (getDynamicNbPrefetch()) { + // Get dynamic prefetch count from transform param or handle. + SmallVector<int32_t> dynamicNbPrefetch; + auto status = convertMixedValuesToInt(state, (*this), dynamicNbPrefetch, + {getDynamicNbPrefetch()}); + if (!status.succeeded()) + return status; + if (dynamicNbPrefetch.size() != 1) + return emitDefiniteFailure() + << "requires exactly one value for dynamic_nb_prefetch"; + nbPrefetch = dynamicNbPrefetch[0]; + } + if (nbPrefetch <= 0) + return emitSilenceableFailure(getLoc()) + << "nb_prefetch must be a positive integer."; + + // Find load operation of the operand. + auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value); + if (!maybeLoadOp) + return emitSilenceableFailure(getLoc()) << "Could not find load op."; + auto loadOp = *maybeLoadOp; + if (loadOp.getMixedOffsets().size() == 0) { + auto diag = emitSilenceableFailure(getLoc()) + << "Load op must have offsets."; + diag.attachNote(loadOp.getLoc()) << "load op"; + return diag; + } + + // Find the parent scf.for loop. + auto forOp = loadOp->getParentOfType<scf::ForOp>(); + if (!forOp) { + auto diag = emitSilenceableFailure(getLoc()) + << "Load op is not contained in a scf.for loop."; + diag.attachNote(loadOp.getLoc()) << "load op"; + return diag; + } + + // Find descriptor op. + auto maybeDescOp = findProducerOfType<xegpu::CreateNdDescOp>(value); + if (!maybeDescOp) + return emitSilenceableFailure(getLoc()) << "Could not find descriptor op."; + auto descOp = *maybeDescOp; + if (descOp.getMixedOffsets().size() > 0) { + auto diag = emitSilenceableFailure(getLoc()) + << "desc op with offsets is not supported."; + diag.attachNote(descOp.getLoc()) << "desc op"; + } + + // Clone desc op outside the loop. + rewriter.setInsertionPoint(forOp); + auto newDescOp = + cast<xegpu::CreateNdDescOp>(rewriter.clone(*descOp.getOperation())); + + // Clone reduction loop to emit initial prefetches. + // Compute upper bound of the init loop: start + nbPrefetch * step. + auto nbPrefetchCst = + arith::ConstantIndexOp::create(rewriter, forOp.getLoc(), nbPrefetch); + auto nbStep = rewriter.createOrFold<arith::MulIOp>( + forOp.getLoc(), nbPrefetchCst, forOp.getStep()); + auto initUpBound = rewriter.createOrFold<arith::AddIOp>( + forOp.getLoc(), forOp.getLowerBound(), nbStep); + auto initForOp = + scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(), + initUpBound, forOp.getStep()); + + auto ctx = rewriter.getContext(); + auto readCacheHint = + xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED); + + // Modify loadOp mixedOffsets by replacing the for loop induction variable + // with the given value. + auto getPrefetchOffsets = + [&](Value replacementVal) -> SmallVector<OpFoldResult> { + IRMapping mapping; + mapping.map(forOp.getInductionVar(), replacementVal); + SmallVector<Value> dynamicOffsets = + llvm::to_vector(llvm::map_range(loadOp.getOffsets(), [&](Value v) { + return mapping.lookupOrDefault(v); + })); + auto constOffsets = loadOp.getConstOffsets().value(); + return getMixedValues(constOffsets, dynamicOffsets, ctx); + }; + + // Insert prefetch op in init loop. + // Replace induction var with the init loop induction var. + rewriter.setInsertionPointToStart(initForOp.getBody()); + xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(), + newDescOp.getResult(), + getPrefetchOffsets(initForOp.getInductionVar()), + readCacheHint, readCacheHint, readCacheHint, + /*layout=*/nullptr); + + // Insert prefetch op in main loop. + // Calculate prefetch offset after the init prefetches have been issued. + rewriter.setInsertionPointToStart(forOp.getBody()); + auto prefetchOffset = arith::AddIOp::create(rewriter, forOp.getLoc(), + forOp.getInductionVar(), nbStep); + // Replace induction var with correct offset. + xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(), + newDescOp.getResult(), + getPrefetchOffsets(prefetchOffset), readCacheHint, + readCacheHint, readCacheHint, /*layout=*/nullptr); + + // Unroll the init loop. + if (failed(loopUnrollFull(initForOp))) + return emitSilenceableFailure(getLoc()) << "Failed to unroll the loop"; + + results.set(llvm::cast<OpResult>(getResult()), {newDescOp}); + + return DiagnosedSilenceableFailure::success(); +} + +void transform::InsertPrefetchOp::getEffects( + ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + onlyReadsHandle(getTargetMutable(), effects); + onlyReadsHandle(getDynamicNbPrefetchMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + modifiesPayload(effects); +} + +void transform::ConvertLayoutOp::build( + OpBuilder &builder, OperationState &ostate, Value target, + ArrayRef<OpFoldResult> mixedInputSgLayout, + ArrayRef<OpFoldResult> mixedInputSgData, + ArrayRef<OpFoldResult> mixedInputInstData, + ArrayRef<OpFoldResult> mixedTargetSgLayout, + ArrayRef<OpFoldResult> mixedTargetSgData, + ArrayRef<OpFoldResult> mixedTargetInstData) { + SmallVector<int64_t> staticInputSgLayout, staticInputSgData, + staticInputInstData; + SmallVector<Value> dynamicInputSgLayout, dynamicInputSgData, + dynamicInputInstData; + dispatchIndexOpFoldResults(mixedInputSgLayout, dynamicInputSgLayout, + staticInputSgLayout); + dispatchIndexOpFoldResults(mixedInputSgData, dynamicInputSgData, + staticInputSgData); + dispatchIndexOpFoldResults(mixedInputInstData, dynamicInputInstData, + staticInputInstData); + SmallVector<int64_t> staticTargetSgLayout, staticTargetSgData, + staticTargetInstData; + SmallVector<Value> dynamicTargetSgLayout, dynamicTargetSgData, + dynamicTargetInstData; + dispatchIndexOpFoldResults(mixedTargetSgLayout, dynamicTargetSgLayout, + staticTargetSgLayout); + dispatchIndexOpFoldResults(mixedTargetSgData, dynamicTargetSgData, + staticTargetSgData); + dispatchIndexOpFoldResults(mixedTargetInstData, dynamicTargetInstData, + staticTargetInstData); + build(builder, ostate, target.getType(), + /*target=*/target, + /*input_sg_layout=*/dynamicInputSgLayout, + /*input_sg_data=*/dynamicInputSgData, + /*input_inst_data=*/dynamicInputInstData, + /*target_sg_layout=*/dynamicTargetSgLayout, + /*target_sg_data=*/dynamicTargetSgData, + /*target_inst_data=*/dynamicTargetInstData, + /*static_input_sg_layout=*/staticInputSgLayout, + /*static_input_sg_data=*/staticInputSgData, + /*static_input_inst_data=*/staticInputInstData, + /*static_target_sg_layout=*/staticTargetSgLayout, + /*static_target_sg_data=*/staticTargetSgData, + /*static_target_inst_data=*/staticTargetInstData); +} + +DiagnosedSilenceableFailure +transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto targetValues = state.getPayloadValues(getTarget()); + if (!llvm::hasSingleElement(targetValues)) + return emitDefiniteFailure() + << "requires exactly one target value handle (got " + << llvm::range_size(targetValues) << ")"; + auto value = *targetValues.begin(); + + // Construct layout attributes. + xegpu::LayoutAttr inputLayoutAttr = nullptr; + auto status = getLayoutAttrFromOperands( + getContext(), state, (*this), getMixedInputSgLayout(), + getMixedInputSgData(), getMixedInputInstData(), inputLayoutAttr); + if (!status.succeeded()) + return status; + + xegpu::LayoutAttr targetLayoutAttr = nullptr; + status = getLayoutAttrFromOperands( + getContext(), state, (*this), getMixedTargetSgLayout(), + getMixedTargetSgData(), getMixedTargetInstData(), targetLayoutAttr); + if (!status.succeeded()) + return status; + + // Find first user op to define insertion point for layout conversion. + if (value.use_empty()) + return emitSilenceableFailure(getLoc()) + << "Value has no users to insert layout conversion."; + Operation *userOp = *value.getUsers().begin(); + + // Emit convert_layout op. + rewriter.setInsertionPoint(userOp); + auto convLayoutOp = + xegpu::ConvertLayoutOp::create(rewriter, value.getLoc(), value.getType(), + value, inputLayoutAttr, targetLayoutAttr); + // Replace load op result with the converted layout. + rewriter.replaceUsesWithIf( + value, convLayoutOp.getResult(), [&](OpOperand &use) { + return use.getOwner() != convLayoutOp.getOperation(); + }); + + results.set(llvm::cast<OpResult>(getResult()), {convLayoutOp}); + return DiagnosedSilenceableFailure::success(); +} + +void transform::ConvertLayoutOp::getEffects( + ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + onlyReadsHandle(getTargetMutable(), effects); + onlyReadsHandle(getInputSgLayoutMutable(), effects); + onlyReadsHandle(getInputSgDataMutable(), effects); + onlyReadsHandle(getInputInstDataMutable(), effects); + onlyReadsHandle(getTargetSgLayoutMutable(), effects); + onlyReadsHandle(getTargetSgDataMutable(), effects); + onlyReadsHandle(getTargetInstDataMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + modifiesPayload(effects); +} + namespace { class XeGPUTransformDialectExtension : public transform::TransformDialectExtension< diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp index 4dc5ea4..ab41fe4 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp @@ -214,7 +214,7 @@ static Value generateLoads(ConversionPatternRewriter &rewriter, newTensorDesc, ArrayRef<OpFoldResult>{loadOffsetX, loadOffsetY}, origLoadOp.getPackedAttr(), origLoadOp.getTransposeAttr(), origLoadOp.getL1HintAttr(), origLoadOp.getL2HintAttr(), - origLoadOp.getL3HintAttr()); + origLoadOp.getL3HintAttr(), origLoadOp.getLayoutAttr()); // Set the layout for the loadOp. auto layoutAttr = newTensorDesc.getType().getLayoutAttr(); xegpu::setDistributeLayoutAttr(loadOp->getOpResult(0), layoutAttr); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp index b3a780a..dc9eb96 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp @@ -387,6 +387,8 @@ private: ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results); + bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout); + public: LayoutInfoPropagation(DataFlowSolver &solver, SymbolTableCollection &symbolTable, @@ -475,49 +477,71 @@ LogicalResult LayoutInfoPropagation::visitOperation( return success(); } +bool LayoutInfoPropagation::hasParamsOfLayoutKind( + xegpu::DistributeLayoutAttr anchorLayout) { + if (anchorLayout == nullptr) { + return false; + } + if (layoutKind == LayoutKind::InstData) { + return !(anchorLayout.getEffectiveInstDataAsInt().empty()); + } else if (layoutKind == LayoutKind::Lane) { + return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() || + anchorLayout.getEffectiveLaneDataAsInt().empty()); + } + return false; +} + void LayoutInfoPropagation::visitPrefetchNdOp( xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results) { - // Here we assign the default layout to the tensor descriptor operand of - // prefetch. - auto tdescTy = prefetch.getTensorDescType(); - - auto uArch = getUArch(getChipStr(prefetch).value_or("")); - const auto *uArchInstruction = - dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>( - uArch->getInstruction( - xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch)); - - auto blockWHC = - uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType()); - if (!blockWHC) - prefetch.emitWarning("No known block params found for the element type."); - auto [bWidth, bHeight, bCount] = blockWHC.value(); - SmallVector<int> instData; - int instWidth = xegpu::getLargestDivisor( - static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth, - bCount); - if (instWidth == -1) - prefetch.emitWarning( - "No suitable instruction multiple found for the given shape."); - if (tdescTy.getRank() == 1) - instData = {instWidth}; - else { - int instHeight = xegpu::getLargestDivisor( - static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight); - if (instHeight == -1) + + LayoutInfo prefetchLayout; + xegpu::DistributeLayoutAttr anchorLayout = prefetch.getLayoutAttr(); + if (hasParamsOfLayoutKind(anchorLayout)) { + prefetchLayout = LayoutInfo(anchorLayout); + } else { + // Here we assign the default layout to the tensor descriptor operand of + // prefetch. + auto tdescTy = prefetch.getTensorDescType(); + + auto uArch = getUArch(getChipStr(prefetch).value_or("")); + const auto *uArchInstruction = + dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>( + uArch->getInstruction( + xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch)); + + auto blockWHC = + uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType()); + if (!blockWHC) + prefetch.emitWarning("No known block params found for the element type."); + auto [bWidth, bHeight, bCount] = blockWHC.value(); + SmallVector<int> instData; + int instWidth = xegpu::getLargestDivisor( + static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth); + if (instWidth == -1) prefetch.emitWarning( "No suitable instruction multiple found for the given shape."); - instData = {instHeight, instWidth}; - } - LayoutInfo prefetchLayout; - if (layoutKind == LayoutKind::InstData) - prefetchLayout = - LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData)); - else - prefetchLayout = getDefaultSIMTLayoutInfo( - tdescTy, uArch, uArchInstruction->getPackedFormatBitSize()); + if (tdescTy.getRank() == 1) + instData = {instWidth}; + else { + int instHeight = xegpu::getLargestDivisor( + static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight); + if (instHeight == -1) + prefetch.emitWarning( + "No suitable instruction multiple found for the given shape."); + instData = {instHeight, instWidth}; + } + + if (layoutKind == LayoutKind::InstData) + prefetchLayout = + LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData)); + else + prefetchLayout = getDefaultSIMTLayoutInfo( + tdescTy, uArch, uArchInstruction->getPackedFormatBitSize()); + prefetch.setLayoutAttr( + dyn_cast<xegpu::DistributeLayoutAttr>(prefetchLayout.get())); + } // Propagate the layout to the source tensor descriptor. propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout)); } @@ -556,23 +580,39 @@ void LayoutInfoPropagation::visitVectorBroadCastOp( // Only consider vector to vector broadcasts for now. VectorType resultTy = broadcast.getResultVectorType(); VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType()); - if (!sourceTy) { - broadcast.emitWarning("Expecting source type to be a vector type."); + // skip layout propagation for non-vector source operand. + if (!sourceTy) return; - } - // Only consider nD -> nD broadcast. + // Hanlding broadcast from low-rank to high-rank (e.g., 1D to 2D) case. if (sourceTy.getRank() != resultTy.getRank()) { - broadcast.emitWarning("Expecting source and result to have same rank."); + auto sourceDims = sourceTy.getShape(); + auto resultDims = resultTy.getShape(); + SmallVector<int64_t> bcastDims; + auto dimDiff = resultTy.getRank() - sourceTy.getRank(); + // adding the missing leading dims + for (int i = 0; i < dimDiff; i++) + bcastDims.push_back(i); + + // for the rest dims in the resultTy, if sourceTy dim is 1, then it's + // broadcasted dim + for (size_t i = 0; i < sourceDims.size(); i++) + if ((sourceDims[i] == 1) && (resultDims[i + dimDiff] != 1)) + bcastDims.push_back(i + dimDiff); + + // create a slice layout for the source + xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get( + broadcast->getContext(), + cast<xegpu::DistributeLayoutAttr>(resultLayout.get()), + DenseI64ArrayAttr::get(broadcast->getContext(), bcastDims)); + + propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout))); return; } + SetVector<int64_t> broadcastUnitDims = broadcast.computeBroadcastedUnitDims(); - if (broadcastUnitDims.size() != 1) { - broadcast.emitWarning("Expecting source type to be nD vector only with " - "one broadcasted dimension."); - return; - } - // Propagate the result layout to the source operand. + resultLayout = cast<xegpu::DistributeLayoutAttr>(resultLayout.get()) + .setUnitDimData(broadcastUnitDims); propagateIfChanged(operands[0], operands[0]->meet(resultLayout)); } @@ -617,70 +657,97 @@ void LayoutInfoPropagation::visitUpdateNdOffsetOp( void LayoutInfoPropagation::visitDpasOp( xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results) { - VectorType aTy = dpas.getLhsType(); - VectorType bTy = dpas.getRhsType(); - - auto uArch = getUArch(getChipStr(dpas).value_or("")); - const int subgroupSize = uArch->getSubgroupSize(); - const auto *uArchInstruction = - dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction( - xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)); - - const unsigned dataALen = aTy.getShape().front(); - auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType()); - const int maxALen = - xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen)); - if (maxALen == -1) - dpas.emitWarning( - "No suitable instruction multiple found for the given shape."); - - const unsigned dataBLen = bTy.getShape().back(); - auto supportedBLen = uArchInstruction->getSupportedK(bTy.getElementType()); - const int maxBLen = - xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen)); - if (maxBLen == -1) - dpas.emitWarning( - "No suitable instruction multiple found for the given shape."); - SmallVector<int> instDataA = {maxALen, subgroupSize}; - SmallVector<int> instDataB = {subgroupSize, maxBLen}; LayoutInfo dpasALayout; LayoutInfo dpasBLayout; - LayoutInfo dpasCLayout; + LayoutInfo dpasCDLayout; + + xegpu::DistributeLayoutAttr anchorLayoutCD = dpas.getLayoutCdAttr(); + if (hasParamsOfLayoutKind(anchorLayoutCD)) { + xegpu::DistributeLayoutAttr anchorLayoutA = dpas.getLayoutAAttr(); + xegpu::DistributeLayoutAttr anchorLayoutB = dpas.getLayoutBAttr(); + assert(hasParamsOfLayoutKind(anchorLayoutA) && + "Expected anchor layout for DPAS A operand."); + assert(hasParamsOfLayoutKind(anchorLayoutB) && + "Expected anchor layout for DPAS B operand."); + dpasALayout = LayoutInfo(anchorLayoutA); + dpasBLayout = LayoutInfo(anchorLayoutB); + dpasCDLayout = LayoutInfo(anchorLayoutCD); - if (layoutKind == LayoutKind::InstData) { - dpasALayout = - LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA)); - dpasBLayout = - LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB)); } else { - dpasALayout = getSIMTLayoutInfoForDPASOperand( - aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA()); - dpasBLayout = getSIMTLayoutInfoForDPASOperand( - bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB()); - } - propagateIfChanged(operands[0], operands[0]->meet(dpasALayout)); - propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout)); - if (operands.size() > 2) { - VectorType cTy = dpas.getAccType(); - const unsigned dataCLen = bTy.getShape().back(); - auto supportedCLen = uArchInstruction->getSupportedN(bTy.getElementType()); - const int maxCLen = - xegpu::getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen)); - if (maxCLen == -1) + VectorType aTy = dpas.getLhsType(); + VectorType bTy = dpas.getRhsType(); + + auto uArch = getUArch(getChipStr(dpas).value_or("")); + const int subgroupSize = uArch->getSubgroupSize(); + const auto *uArchInstruction = + dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction( + xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)); + + const unsigned dataALen = aTy.getShape().front(); + auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType()); + const int maxALen = + xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen)); + if (maxALen == -1) dpas.emitWarning( "No suitable instruction multiple found for the given shape."); - SmallVector<int> instDataC = {maxALen, maxCLen}; - if (layoutKind == LayoutKind::InstData) - dpasCLayout = - LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC)); - else - dpasCLayout = getSIMTLayoutInfoForDPASOperand( - cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB()); + const unsigned dataBLen = bTy.getShape().back(); + auto supportedBLen = uArchInstruction->getSupportedN(bTy.getElementType()); + + const int maxBLen = + xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen)); + + if (maxBLen == -1) + dpas.emitWarning( + "No suitable instruction multiple found for the given shape."); + SmallVector<int> instDataA = {maxALen, subgroupSize}; + SmallVector<int> instDataB = {subgroupSize, maxBLen}; + + if (layoutKind == LayoutKind::InstData) { + dpasALayout = + LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA)); + dpasBLayout = + LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB)); + } else { + dpasALayout = getSIMTLayoutInfoForDPASOperand( + aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA()); + dpasBLayout = getSIMTLayoutInfoForDPASOperand( + bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB()); + } + + if (operands.size() > 2) { + VectorType cTy = dpas.getAccType(); + if (layoutKind == LayoutKind::InstData) { + const unsigned dataCLen = bTy.getShape().back(); + auto supportedCLen = + uArchInstruction->getSupportedN(bTy.getElementType()); + const int maxCLen = xegpu::getLargestDivisor( + dataCLen, ArrayRef<unsigned>(supportedCLen)); + if (maxCLen == -1) + dpas.emitWarning( + "No suitable instruction multiple found for the given shape."); + SmallVector<int> instDataC = {maxALen, maxCLen}; + dpasCDLayout = + LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC)); + } else + dpasCDLayout = getSIMTLayoutInfoForDPASOperand( + cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB()); + + dpas.setLayoutCdAttr( + dyn_cast<xegpu::DistributeLayoutAttr>(dpasCDLayout.get())); + } + dpas.setLayoutAAttr( + dyn_cast<xegpu::DistributeLayoutAttr>(dpasALayout.get())); + dpas.setLayoutBAttr( + dyn_cast<xegpu::DistributeLayoutAttr>(dpasBLayout.get())); + } - propagateIfChanged(operands[2], operands[2]->meet(dpasCLayout)); + propagateIfChanged(operands[0], operands[0]->meet(dpasALayout)); + propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout)); + if (operands.size() > 2) { + propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout)); } } @@ -689,43 +756,50 @@ void LayoutInfoPropagation::visitStoreNdOp( xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results) { - auto uArch = getUArch(getChipStr(store).value_or("")); - const auto *uArchInstruction = - dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>( - uArch->getInstruction( - xegpu::uArch::InstructionKind::Subgroup2DBlockStore)); - VectorType dataTy = store.getValueType(); - auto blockWHC = uArchInstruction->getBlockWidthHeightCount( - store.getValueType().getElementType()); - if (!blockWHC) - store.emitWarning("No known block params found for the element type."); - auto [bWidth, bHeight, bCount] = blockWHC.value(); - SmallVector<int> instData; - int instWidth = xegpu::getLargestDivisor( - static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth, - bCount); - if (instWidth == -1) - store.emitWarning( - "No suitable instruction multiple found for the given shape."); - if (dataTy.getRank() == 1) - instData = {instWidth}; - else { - int instHeight = xegpu::getLargestDivisor( - static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight); - if (instHeight == -1) + LayoutInfo storeLayout; + xegpu::DistributeLayoutAttr anchorLayout = store.getLayoutAttr(); + if (hasParamsOfLayoutKind(anchorLayout)) { + storeLayout = LayoutInfo(anchorLayout); + } else { + auto uArch = getUArch(getChipStr(store).value_or("")); + const auto *uArchInstruction = + dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>( + uArch->getInstruction( + xegpu::uArch::InstructionKind::Subgroup2DBlockStore)); + VectorType dataTy = store.getValueType(); + auto blockWHC = uArchInstruction->getBlockWidthHeightCount( + store.getValueType().getElementType()); + if (!blockWHC) + store.emitWarning("No known block params found for the element type."); + auto [bWidth, bHeight, bCount] = blockWHC.value(); + SmallVector<int> instData; + int instWidth = xegpu::getLargestDivisor( + static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth); + if (instWidth == -1) store.emitWarning( "No suitable instruction multiple found for the given shape."); - instData = {instHeight, instWidth}; - } + if (dataTy.getRank() == 1) + instData = {instWidth}; + else { + int instHeight = xegpu::getLargestDivisor( + static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight); + if (instHeight == -1) + store.emitWarning( + "No suitable instruction multiple found for the given shape."); + instData = {instHeight, instWidth}; + } - LayoutInfo storeLayout; - if (layoutKind == LayoutKind::InstData) - storeLayout = - LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData)); - else - storeLayout = - getDefaultSIMTLayoutInfo(store.getValueType(), uArch, - uArchInstruction->getPackedFormatBitSize()); + if (layoutKind == LayoutKind::InstData) + storeLayout = + LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData)); + else + storeLayout = + getDefaultSIMTLayoutInfo(store.getValueType(), uArch, + uArchInstruction->getPackedFormatBitSize()); + store.setLayoutAttr( + dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get())); + } + // Propagate the layout to the value operand. // Both operands should have the same layout for (LayoutInfoLattice *operand : operands) propagateIfChanged(operand, operand->meet(storeLayout)); @@ -736,21 +810,30 @@ void LayoutInfoPropagation::visitStoreNdOp( void LayoutInfoPropagation::visitLoadNdOp( xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results) { - LayoutInfo valueLayout = results[0]->getValue(); - // Need the layout of the value to propagate to the tensor descriptor. - if (!valueLayout.isAssigned()) - return; - LayoutInfo tensorDescLayout = valueLayout; - // LoadNdOp has the transpose effect. However, at the stage of this analysis - // this effect is not expected and should be abstracted away. Emit a - // warning. - if (auto transpose = load.getTranspose()) { - load.emitWarning("Transpose effect is not expected for LoadNdOp at " - "LayoutInfoPropagation stage."); - tensorDescLayout = valueLayout.transpose(transpose.value()); + + LayoutInfo loadLayout; + xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr(); + if (hasParamsOfLayoutKind(anchorLayout)) { + loadLayout = LayoutInfo(anchorLayout); + } else { + + LayoutInfo valueLayout = results[0]->getValue(); + // Need the layout of the value to propagate to the tensor descriptor. + if (!valueLayout.isAssigned()) + return; + loadLayout = valueLayout; + // LoadNdOp has the transpose effect. However, at the stage of this analysis + // this effect is not expected and should be abstracted away. Emit a + // warning. + if (auto transpose = load.getTranspose()) { + load.emitWarning("Transpose effect is not expected for LoadNdOp at " + "LayoutInfoPropagation stage."); + loadLayout = valueLayout.transpose(transpose.value()); + } + load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get())); } // Propagate the new layout to the tensor descriptor operand. - propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout)); + propagateIfChanged(operands[0], operands[0]->meet(loadLayout)); } /// For vector::TransposeOp, the layout of the result is transposed and @@ -840,37 +923,48 @@ void LayoutInfoPropagation::visitVectorBitcastOp( void LayoutInfoPropagation::visitLoadGatherOp( xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results) { - // The layout is strictly determined by the payload type. - auto payloadTy = dyn_cast<VectorType>(load.getValueType()); - if (!payloadTy) { - load.emitWarning("Not propagating, non-vector payload supplied."); - return; - } - auto uArch = getUArch(getChipStr(load).value_or("")); - const int subgroupSize = uArch->getSubgroupSize(); - SmallVector<int> instData{subgroupSize}; - if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1) - instData.push_back(chunkSize); - else if (auto srcTdescTy = - dyn_cast<xegpu::TensorDescType>(load.getSourceType())) { - if (srcTdescTy.getChunkSizeAsInt() > 1) + + LayoutInfo loadLayout; + LayoutInfo maskLayout; + xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr(); + if (hasParamsOfLayoutKind(anchorLayout)) { + loadLayout = LayoutInfo(anchorLayout); + maskLayout = loadLayout; + } else { + + // The layout is strictly determined by the payload type. + VectorType payloadTy = load.getValueType(); + if (!payloadTy) { + load.emitWarning("Not propagating, non-vector payload supplied."); + return; + } + auto uArch = getUArch(getChipStr(load).value_or("")); + const int subgroupSize = uArch->getSubgroupSize(); + SmallVector<int> instData{subgroupSize}; + if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1) instData.push_back(chunkSize); - } - LayoutInfo layout; - if (layoutKind == LayoutKind::InstData) - layout = LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData)); - else - layout = getDefaultSIMTLayoutInfo(payloadTy, uArch, - uArch->getGeneralPackedFormatBitSize(), - /*scattered*/ true); - - // Mask operand should have 1D default layout. - LayoutInfo maskLayout = - getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize); + else if (auto srcTdescTy = + dyn_cast<xegpu::TensorDescType>(load.getSourceType())) { + if (srcTdescTy.getChunkSizeAsInt() > 1) + instData.push_back(chunkSize); + } + + if (layoutKind == LayoutKind::InstData) + loadLayout = + LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData)); + else + loadLayout = getDefaultSIMTLayoutInfo( + payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(), + /*scattered*/ true); + // Mask operand should have 1D default layout. + maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize); + + load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get())); + } // Propagate the new layout to the tensor descriptor operand. if (isa<xegpu::TensorDescType>(load.getSourceType())) - propagateIfChanged(operands[0], operands[0]->meet(layout)); + propagateIfChanged(operands[0], operands[0]->meet(loadLayout)); // Propagate the new layout to the mask and optional offset operand. propagateIfChanged(operands[1], operands[1]->meet(maskLayout)); if (load.getOffsets()) @@ -898,21 +992,26 @@ void LayoutInfoPropagation::visitCreateDescOp( void LayoutInfoPropagation::visitStoreScatterOp( xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results) { - // Currently, for 2D StoreScatterOp we expect that the height dimension of - // the tensor descriptor is equal to the subgroup size. This is ensured by - // the op verifier. - auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType()); - if (!payloadTy) { - storeScatter.emitWarning("Not propagating, non-vector payload supplied."); - return; - } - LayoutInfo payloadLayout; - auto uArch = getUArch(getChipStr(storeScatter).value_or("")); - const int subgroupSize = uArch->getSubgroupSize(); - if (auto layout = storeScatter.getLayoutAttr()) { - payloadLayout = LayoutInfo(layout); + LayoutInfo payloadLayout; + LayoutInfo maskLayout; + xegpu::DistributeLayoutAttr anchorLayout = storeScatter.getLayoutAttr(); + if (hasParamsOfLayoutKind(anchorLayout)) { + payloadLayout = LayoutInfo(anchorLayout); + maskLayout = payloadLayout; } else { + // Currently, for 2D StoreScatterOp we expect that the height dimension of + // the tensor descriptor is equal to the subgroup size. This is ensured by + // the op verifier. + VectorType payloadTy = storeScatter.getValueType(); + if (!payloadTy) { + storeScatter.emitWarning("Not propagating, non-vector payload supplied."); + return; + } + + auto uArch = getUArch(getChipStr(storeScatter).value_or("")); + const int subgroupSize = uArch->getSubgroupSize(); + if (layoutKind == LayoutKind::InstData) { SmallVector<int> instData{subgroupSize}; if (auto chunkSize = storeScatter.getChunkSize().value_or(0); @@ -936,10 +1035,13 @@ void LayoutInfoPropagation::visitStoreScatterOp( payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(), /*scattered=*/true); } - } - LayoutInfo maskLayout = - getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize); + maskLayout = + getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize); + + storeScatter.setLayoutAttr( + dyn_cast<xegpu::DistributeLayoutAttr>(payloadLayout.get())); + } // Propagate the payload operand layout propagateIfChanged(operands[0], operands[0]->meet(payloadLayout)); // Propagate the destination (if tdesc) operand layout diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index bbd7733..ca81c3c 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -99,7 +99,6 @@ getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout, for (auto [i, dim] : llvm::enumerate(originalType.getShape())) { if (i < distributionStart) continue; - // Check if the dimension can be distributed evenly. if (dim % effectiveLaneLayout[i - distributionStart] != 0) return failure(); @@ -174,6 +173,21 @@ static bool requireTranspose(const xegpu::LayoutAttr layout, return laneLayout[0] == uArch->getSubgroupSize() && laneLayout[1] == 1; } +/// Given a vector type and its distributed vector type, return the list of +/// dimensions that are distributed. +static SmallVector<int64_t> getDistributedDims(VectorType originalType, + VectorType distributedType) { + assert(originalType.getRank() == distributedType.getRank() && + "sequential and distributed vector types must have the same rank"); + SmallVector<int64_t> distributedDims; + for (int64_t i = 0; i < originalType.getRank(); ++i) { + if (distributedType.getDimSize(i) != originalType.getDimSize(i)) { + distributedDims.push_back(i); + } + } + return distributedDims; +} + /// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body /// of the original GPUFuncOp to the new GPUFuncOp such that entire body is /// contained within a WarpExecuteOnLane0Op. @@ -926,8 +940,7 @@ static SmallVector<Value> computeDistributedCoordinatesForMatrixOp( SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned( rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]), getAsOpFoldResult(origOffsets)); - newCoods = llvm::to_vector(llvm::map_range( - ofrVec, [&](OpFoldResult ofr) -> Value { return cast<Value>(ofr); })); + newCoods = llvm::map_to_vector(ofrVec, llvm::CastTo<Value>); return newCoods; } @@ -990,9 +1003,8 @@ struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern { SmallVector<Value> newOperands = llvm::map_to_vector( newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); }); - SmallVector<int64_t> newConstOffsets{matrixOp.getConstOffsets()}; - std::fill(newConstOffsets.begin(), newConstOffsets.end(), - ShapedType::kDynamic); + SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(), + ShapedType::kDynamic); DenseI64ArrayAttr newConstOffsetsAttr = rewriter.getDenseI64ArrayAttr(newConstOffsets); ValueRange currentOffsets = @@ -1067,9 +1079,8 @@ struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern { SmallVector<Value> newOperands = llvm::map_to_vector( newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); }); - SmallVector<int64_t> newConstOffsets{matrixOp.getConstOffsets()}; - std::fill(newConstOffsets.begin(), newConstOffsets.end(), - ShapedType::kDynamic); + SmallVector<int64_t> newConstOffsets(matrixOp.getConstOffsets().size(), + ShapedType::kDynamic); DenseI64ArrayAttr newConstOffsetsAttr = rewriter.getDenseI64ArrayAttr(newConstOffsets); ValueRange currentOffsets = @@ -1412,6 +1423,166 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern { } }; +/// This pattern distributes the `vector.broadcast` operation across lanes in a +/// warp. The pattern supports three use cases: +/// +/// 1) Broadcast a low-rank vector to high-rank vector: The low-rank input +/// vector +/// must have a slice layout of the result. If the distributed source and +/// target vector types are identical, this lowers to a no-op; otherwise, it +/// remains a broadcast but operates on distributed vectors. +/// +/// 2) Broadcast a same-rank vector with identical layouts for source and +/// target: +/// The source vector must have unit dimensions, and lane_data must be unit +/// size for those unit dims. This always lowers to a no-op. +/// +/// 3) Broadcast a scalar with no layout: This always lowers to a broadcast from +/// scalar to distributed result type. +/// +/// Example 1 (lowering to a broadcast with distributed types): +/// ``` +/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) { +/// %0 = "some_def"() {layout_result_0 = +/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>, +/// dims = [0]> } : () -> (vector<32xf32>) +/// %2 = vector.broadcast %0 {layout_result_0 = +/// #xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>} +/// : vector<32xf32> to vector<8x32xf32> +/// gpu.yield %1 : vector<8x32xf32> +/// } +/// ``` +/// is lowered to: +/// ``` +/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) { +/// %0 = "some_def"() {layout_result_0 = +/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>, +/// dims = [0]> } : () -> (vector<32xf32>) +/// gpu.yield %0 : vector<32xf32> +/// } +/// %2 = vector.broadcast %r#0 : vector<1xf32> to vector<8x1xf32> +/// +/// Example 2 (no-op): +/// ``` +/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x32xf32>) { +/// %0 = "some_def"() {layout_result_0 = +/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>, +/// dims = [1]> } : () -> (vector<8xf32>) +/// %1 = vector.shape_cast %0 +/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1, +/// 1]>}: vector<8xf32> to vector<8x1xf32> +/// %2 = vector.broadcast %1 +/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1, +/// 1]>}: vector<8x1xf32> to vector<8x32xf32> +/// gpu.yield %1 : vector<8x32xf32> +/// } +/// ``` +/// is lowered to: +/// ``` +/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) { +/// %0 = "some_def"() {layout_result_0 = +/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>, +/// dims = [1]> } : () -> (vector<8xf32>) +/// %1 = vector.shape_cast %0 +/// {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1, +/// 1]>}: vector<8xf32> to vector<8x1xf32> +/// gpu.yield %1 : vector<8x1xf32> +/// } +/// // The broadcast is implicit through layout transformation (no-op) +/// "some_use"(%r#0) +/// ``` +struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern { + using gpu::WarpDistributionPattern::WarpDistributionPattern; + LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + OpOperand *yieldOperand = + getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>); + if (!yieldOperand) + return failure(); + auto broadcastOp = + cast<vector::BroadcastOp>(yieldOperand->get().getDefiningOp()); + unsigned operandIdx = yieldOperand->getOperandNumber(); + + VectorType sourceType = dyn_cast<VectorType>(broadcastOp.getSourceType()); + VectorType destType = + dyn_cast<VectorType>(broadcastOp.getResult().getType()); + + xegpu::DistributeLayoutAttr sourceLayout = + xegpu::getDistributeLayoutAttr(broadcastOp->getOpOperand(0)); + xegpu::DistributeLayoutAttr resultLayout = + xegpu::getDistributeLayoutAttr(broadcastOp.getResult()); + + FailureOr<VectorType> sourceDistType; + Type sourceElemOrDistType; + if (sourceType) { + + // Case 1 and 2: source is a vector type. + int64_t rankDiff = destType.getRank() - sourceType.getRank(); + if (rankDiff > 0) { + // Case 1: source is lower-rank than result. + bool isSliceOf = sourceLayout.isSliceOf(resultLayout); + if (!isSliceOf) + return rewriter.notifyMatchFailure( + warpOp, + "Broadcast input layout must be a slice of result layout."); + } + // case 2: source and result have same rank + if (rankDiff == 0) { + SetVector<int64_t> broadcastUnitDims = + broadcastOp.computeBroadcastedUnitDims(); + resultLayout = resultLayout.setUnitDimData(broadcastUnitDims); + bool isEqualTo = sourceLayout.isEqualTo(resultLayout); + if (!isEqualTo) + return rewriter.notifyMatchFailure( + warpOp, "For same-rank broadcast, source must be identical to " + "adjusted result layouts with unit dims."); + sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims); + } + + sourceDistType = + getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType); + if (failed(sourceDistType)) { + return rewriter.notifyMatchFailure( + warpOp, "Failed to distribute the source vector type."); + } + sourceElemOrDistType = sourceDistType.value(); + + } else { + // Case 3: source is a scalar type. + if (sourceLayout) { + return rewriter.notifyMatchFailure( + warpOp, "Broadcast from scalar must not have a layout attribute."); + } + sourceElemOrDistType = broadcastOp.getSourceType(); + } + FailureOr<VectorType> destDistType = + getDistVecTypeBasedOnLaneLayout(resultLayout, destType); + if (failed(destDistType)) { + return rewriter.notifyMatchFailure( + warpOp, "Failed to distribute the dest vector type."); + } + + SmallVector<size_t> newRetIndices; + auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, {broadcastOp.getSource()}, sourceElemOrDistType, + newRetIndices); + + Value distributedSource = newWarpOp.getResult(newRetIndices[0]); + + Value newBroadcast = distributedSource; + + if (sourceElemOrDistType != destDistType.value()) { + rewriter.setInsertionPointAfter(newWarpOp); + newBroadcast = + vector::BroadcastOp::create(rewriter, newWarpOp.getLoc(), + destDistType.value(), distributedSource); + } + + rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newBroadcast); + return success(); + } +}; + /// Distribute a `vector.shape_cast` op feeding into yield op of an enclosing /// `gpu.warp_execute_on_lane_0` region. struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern { @@ -1472,6 +1643,226 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern { } }; +// Distribute a `vector.extract_strided_slice` op feeding into yield op of an +// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers +// advanced cases where the distributed dimension is partially extracted and +// currently not supported by the generic vector distribution patterns. +struct VectorExtractStridedSliceDistribution + : public gpu::WarpDistributionPattern { + using gpu::WarpDistributionPattern::WarpDistributionPattern; + LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + OpOperand *operand = + getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>); + if (!operand) + return failure(); + auto extractOp = + cast<vector::ExtractStridedSliceOp>(operand->get().getDefiningOp()); + unsigned operandIdx = operand->getOperandNumber(); + auto distributedType = + cast<VectorType>(warpOp.getResult(operandIdx).getType()); + // Find the distributed dimensions. + auto extractResultType = cast<VectorType>(operand->get().getType()); + auto distributedDims = + getDistributedDims(extractResultType, distributedType); + // Collect updated source type, sizes and offsets. They may be adjusted + // later if the data is distributed to lanes (as opposed to being owned by + // all lanes uniformly). + VectorType updatedSourceType = extractOp.getSourceVectorType(); + SmallVector<Attribute> updatedSizes = llvm::map_to_vector( + extractOp.getSizes(), [](Attribute attr) { return attr; }); + SmallVector<Attribute> updatedOffsets = llvm::map_to_vector( + extractOp.getOffsets(), [](Attribute attr) { return attr; }); + // If the result is distributed, it must be distributed in exactly one + // dimension. In this case, we adjust the sourceDistType, distributedSizes + // and distributedOffsets accordingly. + if (distributedDims.size() > 0) { + if (distributedDims.size() != 1) + return rewriter.notifyMatchFailure( + warpOp, "Source can not be distributed in multiple dimensions."); + int64_t distributedDim = distributedDims[0]; + int sourceDistrDimSize = + extractOp.getSourceVectorType().getShape()[distributedDim]; + auto sourceLayout = + xegpu::getDistributeLayoutAttr(extractOp->getOpOperand(0)); + if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty()) + return rewriter.notifyMatchFailure( + warpOp, "the source of extract_strided_slice op lacks distribution " + "layout"); + auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt(); + // Because only single dimension distribution is supported, lane layout + // size at the distributed dim must be the subgroup size. + int subgroupSize = sourceLaneLayout[distributedDim]; + // Check if the source size in the distributed dimension is a multiple of + // subgroup size. + if (sourceDistrDimSize % subgroupSize != 0) + return rewriter.notifyMatchFailure( + warpOp, + "Source size along distributed dimension is not a multiple of " + "subgroup size."); + auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt(); + // We expect lane data to be all ones in this case. + if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; })) + return rewriter.notifyMatchFailure( + warpOp, "Expecting unit lane data in source layout"); + // The offsets in the distributed dimention must be a multiple of subgroup + // size. + int64_t distrDimOffset = + cast<IntegerAttr>(extractOp.getOffsets()[distributedDim]).getInt(); + if (distrDimOffset % subgroupSize != 0) + return rewriter.notifyMatchFailure( + warpOp, "Offset along distributed dimension " + "is not a multiple of subgroup size."); + updatedSourceType = getDistVecTypeBasedOnLaneLayout( + sourceLayout, extractOp.getSourceVectorType()) + .value(); + // Update the distributed sizes to match the distributed type. + updatedSizes[distributedDim] = rewriter.getI64IntegerAttr( + distributedType.getDimSize(distributedDim)); + // Update the distributed offsets to match round robin distribution (i.e. + // each lane owns data at `subgroupSize` stride given unit lane data). + updatedOffsets[distributedDim] = + rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize); + } + // Do the distribution by yielding the source of the extract op from + // the warp op and creating a new extract op outside the warp op. + SmallVector<size_t> newRetIndices; + auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, {extractOp.getSource()}, {updatedSourceType}, + newRetIndices); + rewriter.setInsertionPointAfter(newWarpOp); + Value source = newWarpOp.getResult(newRetIndices[0]); + // Create a new extract op outside the warp op. + Value newExtractOp = vector::ExtractStridedSliceOp::create( + rewriter, extractOp.getLoc(), distributedType, source, + ArrayAttr::get(rewriter.getContext(), updatedOffsets), + ArrayAttr::get(rewriter.getContext(), updatedSizes), + extractOp.getStrides()); + rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newExtractOp); + return success(); + } +}; + +/// Distribute a `vector.insert_strided_slice` op feeding into yield op of an +/// enclosing `gpu.warp_execute_on_lane_0` region. This pattern covers +/// advanced cases where the distributed dimension is partially inserted and +/// currently not supported by the generic vector distribution patterns. +struct VectorInsertStridedSliceDistribution + : public gpu::WarpDistributionPattern { + using gpu::WarpDistributionPattern::WarpDistributionPattern; + LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + OpOperand *operand = + getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>); + if (!operand) + return failure(); + unsigned int operandNumber = operand->getOperandNumber(); + auto insertOp = + operand->get().getDefiningOp<vector::InsertStridedSliceOp>(); + auto distributedType = + cast<VectorType>(warpOp.getResult(operandNumber).getType()); + // Find the distributed dimensions of the dest vector. + auto insertResultType = cast<VectorType>(operand->get().getType()); + auto destDistributedDims = + getDistributedDims(insertResultType, distributedType); + // Collect updated offsets, source type and dest type. They may be adjusted + // later if the data is distributed to lanes (as opposed to being owned by + // all lanes uniformly). + SmallVector<Attribute> updatedOffsets = llvm::map_to_vector( + insertOp.getOffsets(), [](Attribute attr) { return attr; }); + VectorType updatedSourceType = insertOp.getSourceVectorType(); + VectorType updatedDestType = insertOp.getDestVectorType(); + if (destDistributedDims.size() > 0) { + // Only single dimension distribution is supported. + if (destDistributedDims.size() != 1) + return rewriter.notifyMatchFailure( + warpOp, + "Expecting source to be distributed in a single dimension."); + int64_t destDistributedDim = destDistributedDims[0]; + + VectorType srcType = insertOp.getSourceVectorType(); + VectorType destType = insertOp.getDestVectorType(); + // Currently we require that both source (kD) and dest (nD) vectors are + // distributed. This requires that distributedDim (d) is contained in the + // last k dims of the dest vector (d >= n - k). + int64_t sourceDistributedDim = + destDistributedDim - (destType.getRank() - srcType.getRank()); + if (sourceDistributedDim < 0) + return rewriter.notifyMatchFailure( + insertOp, + "distributed dimension must be in the last k (i.e. source " + "rank) dims of dest vector"); + int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim); + // Obtain the source and dest layouts. + auto destLayout = + xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(1)); + auto sourceLayout = + xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(0)); + if (!destLayout || !sourceLayout || + destLayout.getEffectiveLaneLayoutAsInt().empty() || + sourceLayout.getEffectiveLaneLayoutAsInt().empty()) + return rewriter.notifyMatchFailure( + warpOp, "the source or dest of insert_strided_slice op lacks " + "distribution layout"); + // Because only single dimension distribution is supported, lane layout + // size at the distributed dim must be the subgroup size. + int subgroupSize = + destLayout.getEffectiveLaneLayoutAsInt()[destDistributedDim]; + // We require that source and dest lane data are all ones to ensure + // uniform round robin distribution. + auto destLaneData = destLayout.getEffectiveLaneDataAsInt(); + auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt(); + if (!llvm::all_of(destLaneData, [](int64_t v) { return v == 1; }) || + !llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; })) + return rewriter.notifyMatchFailure( + warpOp, "Expecting unit lane data in source and dest layouts"); + // Source distributed dim size must be multiples of subgroup size. + if (srcDistrDimSize % subgroupSize != 0) + return rewriter.notifyMatchFailure( + warpOp, "Distributed dimension size in source is not a multiple of " + "subgroup size."); + // Offsets in the distributed dimension must be multiples of subgroup + // size. + int64_t destDistrDimOffset = + cast<IntegerAttr>(insertOp.getOffsets()[destDistributedDim]).getInt(); + if (destDistrDimOffset % subgroupSize != 0) + return rewriter.notifyMatchFailure( + warpOp, + "Offset along distributed dimension in dest is not a multiple of " + "subgroup size."); + // Update the source and dest types based on their layouts. + updatedSourceType = getDistVecTypeBasedOnLaneLayout( + sourceLayout, insertOp.getSourceVectorType()) + .value(); + updatedDestType = getDistVecTypeBasedOnLaneLayout( + destLayout, insertOp.getDestVectorType()) + .value(); + // Update the distributed offsets to match round robin distribution (i.e. + // each lane owns data at `subgroupSize` stride given unit lane data). + updatedOffsets[destDistributedDim] = + rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize); + } + // Do the distribution by yielding the source and dest of the insert op + // from the warp op and creating a new insert op outside the warp op. + SmallVector<size_t> newRetIndices; + auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()}, + {updatedSourceType, updatedDestType}, newRetIndices); + rewriter.setInsertionPointAfter(newWarpOp); + + Value valueToStore = newWarpOp.getResult(newRetIndices[0]); + Value dest = newWarpOp.getResult(newRetIndices[1]); + // Create a new insert op outside the warp op. + Value newInsertOp = vector::InsertStridedSliceOp::create( + rewriter, insertOp.getLoc(), updatedDestType, valueToStore, dest, + ArrayAttr::get(rewriter.getContext(), updatedOffsets), + insertOp.getStrides()); + rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), + newInsertOp); + return success(); + } +}; + /// Sink a memref::ExtractAlignedPointerAsIndex op feeding into yield op of an /// enclosing `gpu.warp_execute_on_lane_0` region. This will simply move the op /// outside of the warp op. @@ -1629,9 +2020,13 @@ void xegpu::populateXeGPUSubgroupDistributePatterns( MemrefExtractAlignedPointerAsIndexDistribution>( patterns.getContext(), /*pattern benefit=*/regularPatternBenefit); - patterns.add<VectorShapeCastDistribution>( - patterns.getContext(), - /*pattern benefit=*/highPatternBenefit); + // For following patterns, we need to override the regular vector distribution + // patterns. Therefore, assign higher benefit. + patterns + .add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution, + VectorInsertStridedSliceDistribution, VectorBroadcastDistribution>( + patterns.getContext(), + /*pattern benefit=*/highPatternBenefit); } void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns( diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index c3bf960..af63f09 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -238,6 +238,9 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> { if (!targetShape) return failure(); + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + if (layout) + layout = layout.dropInstData(); int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size()); bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr(); @@ -255,7 +258,7 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> { auto createPrefetch = [&](SmallVector<OpFoldResult> offsets) -> Value { xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets, op.getL1HintAttr(), op.getL2HintAttr(), - op.getL3HintAttr()); + op.getL3HintAttr(), layout); // return dummy Value to satisfy function's signature return nullptr; }; @@ -282,6 +285,9 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> { if (!targetShape) return failure(); + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + if (layout) + layout = layout.dropInstData(); int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size()); bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr(); @@ -306,7 +312,7 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> { return xegpu::LoadNdOp::create( rewriter, loc, newValueTy, convertedTdescs[0], offsets, op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(), - op.getL2HintAttr(), op.getL3HintAttr()); + op.getL2HintAttr(), op.getL3HintAttr(), layout); }; newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape, createLoad, loc, rewriter); @@ -331,6 +337,9 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> { if (!targetShape) return failure(); + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + if (layout) + layout = layout.dropInstData(); int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size()); bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr(); @@ -354,7 +363,7 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> { xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++], convertedTdescs[0], offsets, op.getL1HintAttr(), op.getL2HintAttr(), - op.getL3HintAttr()); + op.getL3HintAttr(), layout); // return dummy Value to satisfy function's signature return nullptr; }; @@ -678,7 +687,7 @@ struct UnrollLoadGatherOpWithOffset pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter); } - auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAttr()); + auto layout = op.getLayoutAttr(); if (layout) layout = layout.dropInstData(); @@ -778,7 +787,7 @@ struct UnrollStoreScatterOpWithOffsets SmallVector<Value> convertedValues = pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter); - auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAttr()); + auto layout = op.getLayoutAttr(); if (layout) layout = layout.dropInstData(); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 0a9ef0a..be82cda 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -86,8 +86,16 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op, if (origOffsets.empty()) return failure(); + // if op is xegpu::CreateNdDescOp, call op.getDescLayoutAttr() + xegpu::DistributeLayoutAttr layout; + if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp> || + std::is_same_v<OpType, xegpu::StoreMatrixOp>) { + layout = op.getLayoutAttr(); + } else { + layout = op.getDescLayoutAttr(); + } + // not applicable to ops without workgroup layout attributes - xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); if (!layout || !layout.isForWorkgroup()) return failure(); @@ -190,7 +198,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> { xegpu::TensorDescType tdescTy = op.getType(); ArrayRef<int64_t> wgShape = tdescTy.getShape(); Type elemTy = tdescTy.getElementType(); - xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr(); SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; auto newTdescTy = xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(), @@ -309,6 +317,9 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> { if (failed(genOffsetsList(rewriter, op, offsetsList))) return failure(); + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + if (layout) + layout = layout.dropSgLayoutAndData(); SmallVector<Value> newOps; for (auto [tdesc, offsets] : llvm::zip(adaptor.getTensorDesc(), offsetsList)) { @@ -318,7 +329,7 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> { auto newOp = xegpu::LoadNdOp::create( rewriter, op.getLoc(), newResTy, tdesc, offsets, /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(), - op.getL2HintAttr(), op.getL3HintAttr()); + op.getL2HintAttr(), op.getL3HintAttr(), layout); newOps.push_back(newOp); } rewriter.replaceOpWithMultiple(op, {newOps}); @@ -339,11 +350,14 @@ struct WgToSgStoreNdOpWithOffset if (failed(genOffsetsList(rewriter, op, offsetsList))) return failure(); + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + if (layout) + layout = layout.dropSgLayoutAndData(); for (auto [v, tdesc, offsets] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) { xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(), - op.getL3HintAttr()); + op.getL3HintAttr(), layout); } rewriter.eraseOp(op); @@ -363,11 +377,14 @@ struct WgToSgPrefetchNdOpWithOffset if (failed(genOffsetsList(rewriter, op, offsetsList))) return failure(); + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + if (layout) + layout = layout.dropSgLayoutAndData(); for (auto [tdesc, offsets] : llvm::zip(adaptor.getTensorDesc(), offsetsList)) { xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(), - op.getL3HintAttr()); + op.getL3HintAttr(), layout); } rewriter.eraseOp(op); @@ -489,10 +506,8 @@ struct WgToSgVectorBroadcastOp for (auto operand : adaptor.getOperands().front()) { auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(), newResultType, operand); - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) - xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0), - layout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0), + layout.dropSgLayoutAndData()); newBroadcastOps.push_back(newBroadcast.getResult()); } @@ -738,12 +753,9 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { Location loc = op.getLoc(); auto eltType = vecType.getElementType(); - auto setLayoutIfNeeded = [&](Value val) { - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) { - xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val), - layout.dropSgLayoutAndData()); - } + auto setLayout = [&](Value val) { + xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val), + layout.dropSgLayoutAndData()); }; if (vecAttr.isSplat()) { @@ -751,14 +763,14 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { Attribute singleVal = vecAttr.getSplatValue<Attribute>(); auto sgAttr = DenseElementsAttr::get(newType, singleVal); auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr); - setLayoutIfNeeded(cstOp->getResult(0)); + setLayout(cstOp->getResult(0)); rewriter.replaceOp(op, cstOp); return success(); } else if (sgShape == wgShape) { // if the entire vector is shared by all // subgroups, don't distribute auto newConstOp = arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr); - setLayoutIfNeeded(newConstOp->getResult(0)); + setLayout(newConstOp->getResult(0)); rewriter.replaceOp(op, newConstOp); return success(); } else { @@ -860,9 +872,9 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { rewriter, loc, baseConstVec.getType(), mulOffset); auto finalConst = arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset); - setLayoutIfNeeded(baseConstVec); - setLayoutIfNeeded(bcastOffset); - setLayoutIfNeeded(finalConst); + setLayout(baseConstVec); + setLayout(bcastOffset); + setLayout(finalConst); newConstOps.push_back(finalConst); } rewriter.replaceOpWithMultiple(op, {newConstOps}); @@ -889,8 +901,8 @@ struct WgToSgLoadGatherOpWithOffset return failure(); ArrayRef<int64_t> wgShape = resultType.getShape(); - xegpu::LayoutAttr layout = dyn_cast_if_present<xegpu::LayoutAttr>( - xegpu::getDistributeLayoutAttr(op.getResult())); + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op.getResult()); if (!layout || !layout.isForWorkgroup()) return failure(); @@ -913,10 +925,12 @@ struct WgToSgLoadGatherOpWithOffset VectorType newTy = VectorType::get(sgShape, resultType.getElementType()); for (auto [offsets, mask] : llvm::zip(adaptor.getOffsets(), adaptor.getMask())) { + auto newLayout = layout.dropSgLayoutAndData(); auto newLoadOp = xegpu::LoadGatherOp::create( rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr, op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(), - layout.dropSgLayoutAndData()); + newLayout); + xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0), newLayout); newLoadOps.push_back(newLoadOp); } rewriter.replaceOpWithMultiple(op, {newLoadOps}); @@ -941,8 +955,8 @@ struct WgToSgStoreScatterOpWithOffset if (!valueType) return failure(); - xegpu::LayoutAttr layout = dyn_cast_if_present<xegpu::LayoutAttr>( - xegpu::getDistributeLayoutAttr(op.getOperand(0))); + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op.getOperand(0)); if (!layout || !layout.isForWorkgroup()) return failure(); @@ -967,14 +981,11 @@ struct WgToSgStoreScatterOpWithOffset op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(), layout.dropSgLayoutAndData()); // Update the layout attribute to drop sg_layout and sg_data. - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) { - for (OpOperand &operand : store->getOpOperands()) { - // Skip for operand one (memref) - if (operand.getOperandNumber() == 1) - continue; - xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData()); - } + for (OpOperand &operand : store->getOpOperands()) { + // Skip for operand one (memref) + if (operand.getOperandNumber() == 1) + continue; + xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData()); } } rewriter.eraseOp(op); @@ -1067,15 +1078,12 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> { vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]); auto finalSteps = arith::AddIOp::create(rewriter, loc, steps, bcastOffset); - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) { - xegpu::setDistributeLayoutAttr(steps->getResult(0), - layout.dropSgLayoutAndData()); - xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0), - layout.dropSgLayoutAndData()); - xegpu::setDistributeLayoutAttr(finalSteps->getResult(0), - layout.dropSgLayoutAndData()); - } + xegpu::setDistributeLayoutAttr(steps->getResult(0), + layout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0), + layout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(finalSteps->getResult(0), + layout.dropSgLayoutAndData()); newOps.push_back(finalSteps); } @@ -1143,10 +1151,8 @@ struct WgToSgVectorShapeCastOp for (auto src : adaptor.getSource()) { auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(), newResultType, src); - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) - xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0), - layout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0), + layout.dropSgLayoutAndData()); newShapeCastOps.push_back(newShapeCast.getResult()); } @@ -1207,10 +1213,8 @@ struct WgToSgMultiDimReductionOp auto newOp = vector::MultiDimReductionOp::create( rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc, adaptor.getAcc()[0], op.getReductionDims()); - if (!layout.getEffectiveLaneLayoutAsInt().empty() || - !layout.getEffectiveInstDataAsInt().empty()) - xegpu::setDistributeLayoutAttr(newOp->getResult(0), - layout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(newOp->getResult(0), + layout.dropSgLayoutAndData()); newReductions.push_back(newOp.getResult()); } @@ -1283,6 +1287,78 @@ struct WgToSgVectorTransposeOp } }; +// Distribute vector mask ops to work at subgroup level. +template <typename MaskOpType> +struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> { + using OpConversionPattern<MaskOpType>::OpConversionPattern; + + LogicalResult matchAndRewrite( + MaskOpType op, + typename OpConversionPattern<MaskOpType>::OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op.getResult()); + if (!layout || !layout.isForWorkgroup()) + return failure(); + + Location loc = op.getLoc(); + VectorType type = op.getResult().getType(); + auto wgShape = type.getShape(); + + SmallVector<Value> wgMaskDimSizes; + if constexpr (std::is_same_v<MaskOpType, vector::ConstantMaskOp>) { + for (int64_t maskSize : op.getMaskDimSizes()) { + wgMaskDimSizes.push_back( + arith::ConstantIndexOp::create(rewriter, loc, maskSize)); + } + } else if constexpr (std::is_same_v<MaskOpType, vector::CreateMaskOp>) { + wgMaskDimSizes = llvm::to_vector(op.getOperands()); + } + + Value sgId = + gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); + auto sgOffsets = + layout.computeDistributedCoords(rewriter, loc, sgId, wgShape); + if (failed(sgOffsets)) + return failure(); + + SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; + VectorType resultType = VectorType::get(sgShape, type.getElementType()); + + // In each dimension, each subgroup computes its local mask size as: + // min(max(wgMaskDimSize[d] - offset[d], 0), sgDimSize[d]) + SmallVector<Value> newCreateMaskOps; + for (auto offsetSet : *sgOffsets) { + SmallVector<Value> maskOperands; + + for (auto [i, wgMaskDimSize] : llvm::enumerate(wgMaskDimSizes)) { + Value dimSizeVal = + arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]); + Value offset = offsetSet[i]; + Value adjustedMaskSize = + arith::SubIOp::create(rewriter, loc, wgMaskDimSize, offset); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value nonNegative = + arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero); + Value sgMaskSize = + arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal); + maskOperands.push_back(sgMaskSize); + } + + auto newCreateMaskOp = + vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands); + xegpu::setDistributeLayoutAttr(newCreateMaskOp->getResult(0), + layout.dropSgLayoutAndData()); + newCreateMaskOps.push_back(newCreateMaskOp.getResult()); + } + + rewriter.replaceOpWithMultiple(op, {newCreateMaskOps}); + return success(); + } +}; + +using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>; +using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>; } // namespace namespace mlir { @@ -1297,7 +1373,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) { WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset, WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp, - WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp>( + WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp, + WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>( patterns.getContext()); } } // namespace xegpu @@ -1427,7 +1504,8 @@ void XeGPUWgToSgDistributePass::runOnOperation() { target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp, vector::TransposeOp, vector::BroadcastOp, - vector::MultiDimReductionOp>( + vector::MultiDimReductionOp, + vector::ConstantMaskOp, vector::CreateMaskOp>( [=](Operation *op) -> bool { // Check for either a SliceAttr or LayoutAttr on the result. auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0)); diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp index de9e09d..9f126fe 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp +++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp @@ -12,7 +12,6 @@ #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/LLVMIR/XeVMDialect.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/Dialect/Utils/IndexingUtils.h" @@ -140,7 +139,6 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) { // for StoreMatrixOp, the layout is attached to the property of the op if (auto storeOp = dyn_cast<xegpu::StoreMatrixOp>(defOp)) return storeOp.getLayoutAttr(); - std::string layoutName = getLayoutName(result); if (defOp->hasAttr(layoutName)) return defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName); @@ -308,7 +306,7 @@ xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, int64_t rankDiff = srcShapeRank - targetShapeRank; std::fill(adjustedTargetShape.begin(), adjustedTargetShape.begin() + rankDiff, 1); - std::copy(shape.begin(), shape.end(), adjustedTargetShape.begin() + rankDiff); + llvm::copy(shape, adjustedTargetShape.begin() + rankDiff); SmallVector<Value> result; for (SmallVector<int64_t> offsets : @@ -528,7 +526,7 @@ SmallVector<OpFoldResult> xegpu::addElementwise(OpBuilder &builder, for (auto [l, r] : llvm::zip_equal(lhs, rhs)) { auto lval = getValueOrCreateConstantIndexOp(builder, loc, l); auto rval = getValueOrCreateConstantIndexOp(builder, loc, r); - results.push_back(builder.createOrFold<index::AddOp>(loc, lval, rval)); + results.push_back(builder.createOrFold<arith::AddIOp>(loc, lval, rval)); } return results; } diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp new file mode 100644 index 0000000..f3e38eb --- /dev/null +++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp @@ -0,0 +1,174 @@ +//===- APFloatWrappers.cpp - Software Implementation of FP Arithmetics --- ===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file exposes the APFloat infrastructure to MLIR programs as a runtime +// library. APFloat is a software implementation of floating point arithmetics. +// +// On the MLIR side, floating-point values must be bitcasted to 64-bit integers +// before calling a runtime function. If a floating-point type has less than +// 64 bits, it must be zero-extended to 64 bits after bitcasting it to an +// integer. +// +// Runtime functions receive the floating-point operands of the arithmeic +// operation in the form of 64-bit integers, along with the APFloat semantics +// in the form of a 32-bit integer, which will be interpreted as an +// APFloatBase::Semantics enum value. +// +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APSInt.h" + +#ifdef _WIN32 +#ifndef MLIR_APFLOAT_WRAPPERS_EXPORT +#ifdef mlir_apfloat_wrappers_EXPORTS +// We are building this library +#define MLIR_APFLOAT_WRAPPERS_EXPORT __declspec(dllexport) +#else +// We are using this library +#define MLIR_APFLOAT_WRAPPERS_EXPORT __declspec(dllimport) +#endif // mlir_apfloat_wrappers_EXPORTS +#endif // MLIR_APFLOAT_WRAPPERS_EXPORT +#else +// Non-windows: use visibility attributes. +#define MLIR_APFLOAT_WRAPPERS_EXPORT __attribute__((visibility("default"))) +#endif // _WIN32 + +/// Binary operations without rounding mode. +#define APFLOAT_BINARY_OP(OP) \ + MLIR_APFLOAT_WRAPPERS_EXPORT int64_t _mlir_apfloat_##OP( \ + int32_t semantics, uint64_t a, uint64_t b) { \ + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \ + static_cast<llvm::APFloatBase::Semantics>(semantics)); \ + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); \ + llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a)); \ + llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b)); \ + lhs.OP(rhs); \ + return lhs.bitcastToAPInt().getZExtValue(); \ + } + +/// Binary operations with rounding mode. +#define APFLOAT_BINARY_OP_ROUNDING_MODE(OP, ROUNDING_MODE) \ + MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_##OP( \ + int32_t semantics, uint64_t a, uint64_t b) { \ + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \ + static_cast<llvm::APFloatBase::Semantics>(semantics)); \ + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); \ + llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a)); \ + llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b)); \ + lhs.OP(rhs, ROUNDING_MODE); \ + return lhs.bitcastToAPInt().getZExtValue(); \ + } + +extern "C" { + +#define BIN_OPS_WITH_ROUNDING(X) \ + X(add, llvm::RoundingMode::NearestTiesToEven) \ + X(subtract, llvm::RoundingMode::NearestTiesToEven) \ + X(multiply, llvm::RoundingMode::NearestTiesToEven) \ + X(divide, llvm::RoundingMode::NearestTiesToEven) + +BIN_OPS_WITH_ROUNDING(APFLOAT_BINARY_OP_ROUNDING_MODE) +#undef BIN_OPS_WITH_ROUNDING +#undef APFLOAT_BINARY_OP_ROUNDING_MODE + +APFLOAT_BINARY_OP(remainder) + +#undef APFLOAT_BINARY_OP + +MLIR_APFLOAT_WRAPPERS_EXPORT void printApFloat(int32_t semantics, uint64_t a) { + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( + static_cast<llvm::APFloatBase::Semantics>(semantics)); + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); + llvm::APFloat x(sem, llvm::APInt(bitWidth, a)); + double d = x.convertToDouble(); + fprintf(stdout, "%lg", d); +} + +MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t +_mlir_apfloat_convert(int32_t inSemantics, int32_t outSemantics, uint64_t a) { + const llvm::fltSemantics &inSem = llvm::APFloatBase::EnumToSemantics( + static_cast<llvm::APFloatBase::Semantics>(inSemantics)); + const llvm::fltSemantics &outSem = llvm::APFloatBase::EnumToSemantics( + static_cast<llvm::APFloatBase::Semantics>(outSemantics)); + unsigned bitWidthIn = llvm::APFloatBase::semanticsSizeInBits(inSem); + llvm::APFloat val(inSem, llvm::APInt(bitWidthIn, a)); + // TODO: Custom rounding modes are not supported yet. + bool losesInfo; + val.convert(outSem, llvm::RoundingMode::NearestTiesToEven, &losesInfo); + llvm::APInt result = val.bitcastToAPInt(); + return result.getZExtValue(); +} + +MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_convert_to_int( + int32_t semantics, int32_t resultWidth, bool isUnsigned, uint64_t a) { + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( + static_cast<llvm::APFloatBase::Semantics>(semantics)); + unsigned inputWidth = llvm::APFloatBase::semanticsSizeInBits(sem); + llvm::APFloat val(sem, llvm::APInt(inputWidth, a)); + llvm::APSInt result(resultWidth, isUnsigned); + bool isExact; + // TODO: Custom rounding modes are not supported yet. + val.convertToInteger(result, llvm::RoundingMode::NearestTiesToEven, &isExact); + // This function always returns uint64_t, regardless of the desired result + // width. It does not matter whether we zero-extend or sign-extend the APSInt + // to 64 bits because the generated IR in arith-to-apfloat will truncate the + // result to the desired result width. + return result.getZExtValue(); +} + +MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_convert_from_int( + int32_t semantics, int32_t inputWidth, bool isUnsigned, uint64_t a) { + llvm::APInt val(inputWidth, a, /*isSigned=*/!isUnsigned); + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( + static_cast<llvm::APFloatBase::Semantics>(semantics)); + llvm::APFloat result(sem); + // TODO: Custom rounding modes are not supported yet. + result.convertFromAPInt(val, /*IsSigned=*/!isUnsigned, + llvm::RoundingMode::NearestTiesToEven); + return result.bitcastToAPInt().getZExtValue(); +} + +MLIR_APFLOAT_WRAPPERS_EXPORT int8_t _mlir_apfloat_compare(int32_t semantics, + uint64_t a, + uint64_t b) { + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( + static_cast<llvm::APFloatBase::Semantics>(semantics)); + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); + llvm::APFloat x(sem, llvm::APInt(bitWidth, a)); + llvm::APFloat y(sem, llvm::APInt(bitWidth, b)); + return static_cast<int8_t>(x.compare(y)); +} + +MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_neg(int32_t semantics, uint64_t a) { + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( + static_cast<llvm::APFloatBase::Semantics>(semantics)); + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); + llvm::APFloat x(sem, llvm::APInt(bitWidth, a)); + x.changeSign(); + return x.bitcastToAPInt().getZExtValue(); +} + +/// Min/max operations. +#define APFLOAT_MIN_MAX_OP(OP) \ + MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_##OP( \ + int32_t semantics, uint64_t a, uint64_t b) { \ + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \ + static_cast<llvm::APFloatBase::Semantics>(semantics)); \ + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); \ + llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a)); \ + llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b)); \ + llvm::APFloat result = llvm::OP(lhs, rhs); \ + return result.bitcastToAPInt().getZExtValue(); \ + } + +APFLOAT_MIN_MAX_OP(minimum) +APFLOAT_MIN_MAX_OP(maximum) +APFLOAT_MIN_MAX_OP(minnum) +APFLOAT_MIN_MAX_OP(maxnum) + +#undef APFLOAT_MIN_MAX_OP +} diff --git a/mlir/lib/ExecutionEngine/ArmRunnerUtils.cpp b/mlir/lib/ExecutionEngine/ArmRunnerUtils.cpp index 9868ffa..9b1c39e 100644 --- a/mlir/lib/ExecutionEngine/ArmRunnerUtils.cpp +++ b/mlir/lib/ExecutionEngine/ArmRunnerUtils.cpp @@ -49,7 +49,7 @@ extern "C" { /// The recommended strategy is to call `setArmVectorLength` only from functions /// that do not access SVE registers, either by themselves or by inlining other /// functions. -static void setArmVectorLength(std::string_view helper_name, int option, +static void setArmVectorLength(std::string_view helperName, int option, uint32_t bits) { #if defined(__linux__) && defined(__aarch64__) if (bits < 128 || bits > 2048 || !llvm::isPowerOf2_32(bits)) { @@ -63,7 +63,7 @@ static void setArmVectorLength(std::string_view helper_name, int option, abort(); } #else - std::cerr << "[error] " << helper_name << " is unsupported" << std::endl; + std::cerr << "[error] " << helperName << " is unsupported" << std::endl; abort(); #endif } diff --git a/mlir/lib/ExecutionEngine/CMakeLists.txt b/mlir/lib/ExecutionEngine/CMakeLists.txt index fdeb4dac..a615352 100644 --- a/mlir/lib/ExecutionEngine/CMakeLists.txt +++ b/mlir/lib/ExecutionEngine/CMakeLists.txt @@ -2,6 +2,7 @@ # is a big dependency which most don't need. set(LLVM_OPTIONAL_SOURCES + APFloatWrappers.cpp ArmRunnerUtils.cpp ArmSMEStubs.cpp AsyncRuntime.cpp @@ -167,6 +168,26 @@ if(LLVM_ENABLE_PIC) set_property(TARGET mlir_float16_utils PROPERTY CXX_STANDARD 17) target_compile_definitions(mlir_float16_utils PRIVATE mlir_float16_utils_EXPORTS) + if(CMAKE_SYSTEM_NAME STREQUAL "Linux") + # TODO: This support library is only used on Linux builds until we figure + # out how to hide LLVM symbols in a way that works for all platforms. + add_mlir_library(mlir_apfloat_wrappers + SHARED + APFloatWrappers.cpp + + EXCLUDE_FROM_LIBMLIR + ) + set_target_properties( + mlir_apfloat_wrappers + PROPERTIES CXX_STANDARD 17 + CXX_VISIBILITY_PRESET hidden + VISIBILITY_INLINES_HIDDEN ON + ) + target_compile_definitions(mlir_apfloat_wrappers PRIVATE mlir_apfloat_wrappers_EXPORTS) + # Hide LLVM symbols to avoid ODR violations. + target_link_options(mlir_apfloat_wrappers PRIVATE "-Wl,--exclude-libs,ALL") + endif() + add_subdirectory(SparseTensor) add_mlir_library(mlir_c_runner_utils @@ -184,6 +205,11 @@ if(LLVM_ENABLE_PIC) set_property(TARGET mlir_c_runner_utils PROPERTY CXX_STANDARD 17) target_compile_definitions(mlir_c_runner_utils PRIVATE mlir_c_runner_utils_EXPORTS) + # Conditionally link apfloat wrappers only on Linux. + if(CMAKE_SYSTEM_NAME STREQUAL "Linux") + target_link_libraries(mlir_c_runner_utils PUBLIC mlir_apfloat_wrappers) + endif() + add_mlir_library(mlir_runner_utils SHARED RunnerUtils.cpp @@ -195,6 +221,11 @@ if(LLVM_ENABLE_PIC) ) target_compile_definitions(mlir_runner_utils PRIVATE mlir_runner_utils_EXPORTS) + # Conditionally link apfloat wrappers only on Linux. + if(CMAKE_SYSTEM_NAME STREQUAL "Linux") + target_link_libraries(mlir_runner_utils PUBLIC mlir_apfloat_wrappers) + endif() + add_mlir_library(mlir_async_runtime SHARED AsyncRuntime.cpp @@ -323,7 +354,6 @@ if(LLVM_ENABLE_PIC) endif() string(STRIP AGENTS_STRING ${AGENTS_STRING}) string(REPLACE "\n" ";" AGENTS_LIST ${AGENTS_STRING}) - list(FILTER AGENTS_LIST EXCLUDE REGEX "gfx000") if (AGENTS_LIST STREQUAL "") message(SEND_ERROR "No non-CPU ROCm agents found on the system, and ROCM_TEST_CHIPSET is not defined") else() diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp index 6cc2b7fd..f203363 100644 --- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp +++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp @@ -57,7 +57,7 @@ thread_local static int32_t defaultDevice = 0; /// Helper method that checks environment value for debugging. -bool isDebugEnabled() { +static bool isDebugEnabled() { const char *kDebugEnvironmentVariable = "MLIR_CUDA_DEBUG"; static bool isEnabled = getenv(kDebugEnvironmentVariable) != nullptr; return isEnabled; @@ -71,7 +71,7 @@ bool isDebugEnabled() { } while (0) // Returns default CUdevice -CUdevice getDefaultCuDevice() { +static CUdevice getDefaultCuDevice() { CUdevice device; CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice)); return device; diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp index 2255633..287c52a 100644 --- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp @@ -146,12 +146,10 @@ static void packFunctionArguments(Module *module) { llvm::IRBuilder<> builder(ctx); DenseSet<llvm::Function *> interfaceFunctions; for (auto &func : module->getFunctionList()) { - if (func.isDeclaration()) { + if (func.isDeclaration() || func.hasLocalLinkage()) continue; - } - if (interfaceFunctions.count(&func)) { + if (interfaceFunctions.count(&func)) continue; - } // Given a function `foo(<...>)`, define the interface function // `mlir_foo(i8**)`. diff --git a/mlir/lib/ExecutionEngine/VulkanRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/VulkanRuntimeWrappers.cpp index ddea230..ff0dd54 100644 --- a/mlir/lib/ExecutionEngine/VulkanRuntimeWrappers.cpp +++ b/mlir/lib/ExecutionEngine/VulkanRuntimeWrappers.cpp @@ -156,7 +156,7 @@ mgpuLaunchKernel(void *vkKernel, size_t gridX, size_t gridY, size_t gridZ, size_t /*blockX*/, size_t /*blockY*/, size_t /*blockZ*/, size_t /*smem*/, void *vkRuntimeManager, void **params, void ** /*extra*/, size_t paramsCount) { - auto manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager); + auto *manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager); // GpuToLLVMConversionPass with the kernelBarePtrCallConv and // kernelIntersperseSizeCallConv options will set up the params array like: @@ -180,7 +180,7 @@ mgpuLaunchKernel(void *vkKernel, size_t gridX, size_t gridY, size_t gridZ, static_cast<uint32_t>(gridY), static_cast<uint32_t>(gridZ)}); - auto function = static_cast<VulkanFunction *>(vkKernel); + auto *function = static_cast<VulkanFunction *>(vkKernel); // Expected size should be in bytes. manager->setShaderModule( function->module->blobData(), diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 9b23dd6..fd846e4 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -2032,7 +2032,7 @@ private: }; template <typename Range> -void printDimensionList(raw_ostream &stream, Range &&shape) { +static void printDimensionList(raw_ostream &stream, Range &&shape) { llvm::interleave( shape, stream, [&stream](const auto &dimSize) { diff --git a/mlir/lib/IR/Remarks.cpp b/mlir/lib/IR/Remarks.cpp index 031eae2..4cce16b 100644 --- a/mlir/lib/IR/Remarks.cpp +++ b/mlir/lib/IR/Remarks.cpp @@ -31,6 +31,11 @@ Remark::Arg::Arg(llvm::StringRef k, Type t) : key(k) { os << t; } +Remark::Arg::Arg(llvm::StringRef k, Attribute a) : key(k), attr(a) { + llvm::raw_string_ostream os(val); + os << a; +} + void Remark::insert(llvm::StringRef s) { args.emplace_back(s); } void Remark::insert(Arg a) { args.push_back(std::move(a)); } diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp index e438631..199744d2 100644 --- a/mlir/lib/IR/TypeUtilities.cpp +++ b/mlir/lib/IR/TypeUtilities.cpp @@ -118,8 +118,7 @@ LogicalResult mlir::verifyCompatibleDims(ArrayRef<int64_t> dims) { /// have compatible dimensions. Dimensions are compatible if all non-dynamic /// dims are equal. The element type does not matter. LogicalResult mlir::verifyCompatibleShapes(TypeRange types) { - auto shapedTypes = llvm::map_to_vector<8>( - types, [](auto type) { return llvm::dyn_cast<ShapedType>(type); }); + auto shapedTypes = llvm::map_to_vector<8>(types, llvm::DynCastTo<ShapedType>); // Return failure if some, but not all are not shaped. Return early if none // are shaped also. if (llvm::none_of(shapedTypes, [](auto t) { return t; })) diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp index 9f4f672..c31e0ae7 100644 --- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp +++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp @@ -58,6 +58,22 @@ mlir::reifyResultShapes(OpBuilder &b, Operation *op, return status; } +FailureOr<SmallVector<OpFoldResult>> +mlir::reifyShapeOfResult(OpBuilder &b, Operation *op, int resultIndex) { + auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op); + if (!reifiableOp) + return failure(); + return reifiableOp.reifyShapeOfResult(b, resultIndex); +} + +FailureOr<OpFoldResult> mlir::reifyDimOfResult(OpBuilder &b, Operation *op, + int resultIndex, int dim) { + auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op); + if (!reifiableOp) + return failure(); + return reifiableOp.reifyDimOfResult(b, resultIndex, dim); +} + bool ShapeAdaptor::hasRank() const { if (val.isNull()) return false; diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp index a5bfde1..cfe808b 100644 --- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp +++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp @@ -129,7 +129,7 @@ ValueBoundsConstraintSet::Variable::Variable(AffineMap map, assert(var.map.getNumDims() == 0 && "expected only symbols"); SmallVector<AffineExpr> symReplacements; for (auto valueDim : var.mapOperands) { - auto it = llvm::find(this->mapOperands, valueDim); + auto *it = llvm::find(this->mapOperands, valueDim); if (it != this->mapOperands.end()) { // There is already a symbol for this operand. symReplacements.push_back(b.getAffineSymbolExpr( diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index 521c7c6..75f8826 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -559,9 +559,9 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op, return op->emitOpError() << "trying to schedule a pass on an operation not " "marked as 'IsolatedFromAbove'"; } - if (!pass->canScheduleOn(*op->getName().getRegisteredInfo())) { - return op->emitOpError() - << "trying to schedule a pass on an unsupported operation"; + if (!pass->canScheduleOn(op)) { + return op->emitOpError() << "trying to schedule pass '" << pass->getName() + << "' on an unsupported operation"; } // Initialize the pass state with a callback for the pass to dynamically diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp index e392a88..7bfe03d 100644 --- a/mlir/lib/Query/Matcher/Parser.cpp +++ b/mlir/lib/Query/Matcher/Parser.cpp @@ -27,7 +27,7 @@ struct Parser::TokenInfo { } // Known identifiers. - static const char *const ID_Extract; + static const char *const idExtract; llvm::StringRef text; TokenKind kind = TokenKind::Eof; @@ -35,7 +35,7 @@ struct Parser::TokenInfo { VariantValue value; }; -const char *const Parser::TokenInfo::ID_Extract = "extract"; +const char *const Parser::TokenInfo::idExtract = "extract"; class Parser::CodeTokenizer { public: @@ -452,13 +452,13 @@ bool Parser::parseMatcherExpressionImpl(const TokenInfo &nameToken, } if (chainCallToken.kind != TokenKind::Ident || - chainCallToken.text != TokenInfo::ID_Extract) { + chainCallToken.text != TokenInfo::idExtract) { error->addError(chainCallToken.range, ErrorType::ParserMalformedChainedExpr); return false; } - if (chainCallToken.text == TokenInfo::ID_Extract && + if (chainCallToken.text == TokenInfo::idExtract && !parseChainedExpression(functionName)) return false; } diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp index 5b49204..1e00ed6 100644 --- a/mlir/lib/Reducer/ReductionTreePass.cpp +++ b/mlir/lib/Reducer/ReductionTreePass.cpp @@ -175,9 +175,12 @@ public: using Base::Base; // Collect the reduce patterns defined by each dialect. - void populateReductionPatterns(RewritePatternSet &pattern) const { - for (const DialectReductionPatternInterface &interface : *this) + void populateReductionPatterns(RewritePatternSet &pattern, + Tester &tester) const { + for (const DialectReductionPatternInterface &interface : *this) { interface.populateReductionPatterns(pattern); + interface.populateReductionPatternsWithTester(pattern, tester); + } } }; @@ -201,15 +204,21 @@ public: private: LogicalResult reduceOp(ModuleOp module, Region ®ion); + Tester tester; FrozenRewritePatternSet reducerPatterns; }; } // namespace LogicalResult ReductionTreePass::initialize(MLIRContext *context) { + tester.setTestScript(testerName); + tester.setTestScriptArgs(testerArgs); + RewritePatternSet patterns(context); + ReductionPatternInterfaceCollection reducePatternCollection(context); - reducePatternCollection.populateReductionPatterns(patterns); + reducePatternCollection.populateReductionPatterns(patterns, tester); + reducerPatterns = std::move(patterns); return success(); } @@ -244,11 +253,10 @@ void ReductionTreePass::runOnOperation() { } LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region ®ion) { - Tester test(testerName, testerArgs); switch (traversalModeId) { case TraversalMode::SinglePath: return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>( - module, region, reducerPatterns, test); + module, region, reducerPatterns, tester); default: return module.emitError() << "unsupported traversal mode detected"; } diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp index c857c38..4312100 100644 --- a/mlir/lib/RegisterAllExtensions.cpp +++ b/mlir/lib/RegisterAllExtensions.cpp @@ -56,6 +56,7 @@ #include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h" #include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h" #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" +#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h" #include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" @@ -113,6 +114,7 @@ void mlir::registerAllExtensions(DialectRegistry ®istry) { transform::registerSMTExtension(registry); transform::registerTuneExtension(registry); vector::registerTransformDialectExtension(registry); + x86vector::registerTransformDialectExtension(registry); xegpu::registerTransformDialectExtension(registry); arm_neon::registerTransformDialectExtension(registry); arm_sve::registerTransformDialectExtension(registry); diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp index 42843ea..159aa54 100644 --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -1099,12 +1099,12 @@ public: MutableArrayRef<PDLValue> getResults() { return results; } /// Return the type ranges allocated by this list. - MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() { + MutableArrayRef<std::vector<Type>> getAllocatedTypeRanges() { return allocatedTypeRanges; } /// Return the value ranges allocated by this list. - MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() { + MutableArrayRef<std::vector<Value>> getAllocatedValueRanges() { return allocatedValueRanges; } }; @@ -1112,19 +1112,20 @@ public: /// This class provides support for executing a bytecode stream. class ByteCodeExecutor { public: - ByteCodeExecutor( - const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory, - MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory, - MutableArrayRef<TypeRange> typeRangeMemory, - std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory, - MutableArrayRef<ValueRange> valueRangeMemory, - std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory, - MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory, - ArrayRef<ByteCodeField> code, - ArrayRef<PatternBenefit> currentPatternBenefits, - ArrayRef<PDLByteCodePattern> patterns, - ArrayRef<PDLConstraintFunction> constraintFunctions, - ArrayRef<PDLRewriteFunction> rewriteFunctions) + ByteCodeExecutor(const ByteCodeField *curCodeIt, + MutableArrayRef<const void *> memory, + MutableArrayRef<std::vector<Operation *>> opRangeMemory, + MutableArrayRef<TypeRange> typeRangeMemory, + std::vector<std::vector<Type>> &allocatedTypeRangeMemory, + MutableArrayRef<ValueRange> valueRangeMemory, + std::vector<std::vector<Value>> &allocatedValueRangeMemory, + MutableArrayRef<unsigned> loopIndex, + ArrayRef<const void *> uniquedMemory, + ArrayRef<ByteCodeField> code, + ArrayRef<PatternBenefit> currentPatternBenefits, + ArrayRef<PDLByteCodePattern> patterns, + ArrayRef<PDLConstraintFunction> constraintFunctions, + ArrayRef<PDLRewriteFunction> rewriteFunctions) : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory), typeRangeMemory(typeRangeMemory), allocatedTypeRangeMemory(allocatedTypeRangeMemory), @@ -1367,13 +1368,9 @@ private: if (range.empty()) { rangeMemory[rangeIndex] = {}; } else { - // Allocate a buffer for this type range. - llvm::OwningArrayRef<T> storage(llvm::size(range)); - llvm::copy(range, storage.begin()); - // Assign this to the range slot and use the range as the value for the // memory index. - allocatedRangeMemory.emplace_back(std::move(storage)); + allocatedRangeMemory.emplace_back(range.begin(), range.end()); rangeMemory[rangeIndex] = allocatedRangeMemory.back(); } memory[memIndex] = &rangeMemory[rangeIndex]; @@ -1397,11 +1394,11 @@ private: /// The current execution memory. MutableArrayRef<const void *> memory; - MutableArrayRef<OwningOpRange> opRangeMemory; + MutableArrayRef<std::vector<Operation *>> opRangeMemory; MutableArrayRef<TypeRange> typeRangeMemory; - std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory; + std::vector<std::vector<Type>> &allocatedTypeRangeMemory; MutableArrayRef<ValueRange> valueRangeMemory; - std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory; + std::vector<std::vector<Value>> &allocatedValueRangeMemory; /// The current loop indices. MutableArrayRef<unsigned> loopIndex; @@ -1907,10 +1904,10 @@ void ByteCodeExecutor::executeGetUsers() { LDBG() << "Executing GetUsers:"; unsigned memIndex = read(); unsigned rangeIndex = read(); - OwningOpRange &range = opRangeMemory[rangeIndex]; + std::vector<Operation *> &range = opRangeMemory[rangeIndex]; memory[memIndex] = ⦥ - range = OwningOpRange(); + range.clear(); if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { // Read the value. Value value = read<Value>(); @@ -1918,9 +1915,7 @@ void ByteCodeExecutor::executeGetUsers() { return; LDBG() << " * Value: " << value; - // Extract the users of a single value. - range = OwningOpRange(std::distance(value.user_begin(), value.user_end())); - llvm::copy(value.getUsers(), range.begin()); + range.assign(value.user_begin(), value.user_end()); } else { // Read a range of values. ValueRange *values = read<ValueRange *>(); @@ -1929,12 +1924,8 @@ void ByteCodeExecutor::executeGetUsers() { LDBG() << " * Values (" << values->size() << "): " << llvm::interleaved(*values); - // Extract all the users of a range of values. - SmallVector<Operation *> users; for (Value value : *values) - users.append(value.user_begin(), value.user_end()); - range = OwningOpRange(users.size()); - llvm::copy(users, range.begin()); + range.insert(range.end(), value.user_begin(), value.user_end()); } LDBG() << " * Result: " << range.size() << " operations"; @@ -2174,7 +2165,8 @@ ByteCodeExecutor::execute(PatternRewriter &rewriter, executeEraseOp(rewriter); break; case ExtractOp: - executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>(); + executeExtract<Operation *, std::vector<Operation *>, + PDLValue::Kind::Operation>(); break; case ExtractType: executeExtract<Type, TypeRange, PDLValue::Kind::Type>(); diff --git a/mlir/lib/Rewrite/ByteCode.h b/mlir/lib/Rewrite/ByteCode.h index 4aceac7..566c1cb 100644 --- a/mlir/lib/Rewrite/ByteCode.h +++ b/mlir/lib/Rewrite/ByteCode.h @@ -30,7 +30,6 @@ class PDLByteCode; /// entries. ByteCodeAddr refers to size of indices into the bytecode. using ByteCodeField = uint16_t; using ByteCodeAddr = uint32_t; -using OwningOpRange = llvm::OwningArrayRef<Operation *>; //===----------------------------------------------------------------------===// // PDLByteCodePattern @@ -94,21 +93,21 @@ private: /// the bytecode to store ranges of operations. These are always stored by /// owning references, because at no point in the execution of the byte code /// we get an indexed range (view) of operations. - std::vector<OwningOpRange> opRangeMemory; + std::vector<std::vector<Operation *>> opRangeMemory; /// A mutable block of memory used during the matching and rewriting phase of /// the bytecode to store ranges of types. std::vector<TypeRange> typeRangeMemory; /// A set of type ranges that have been allocated by the byte code interpreter /// to provide a guaranteed lifetime. - std::vector<llvm::OwningArrayRef<Type>> allocatedTypeRangeMemory; + std::vector<std::vector<Type>> allocatedTypeRangeMemory; /// A mutable block of memory used during the matching and rewriting phase of /// the bytecode to store ranges of values. std::vector<ValueRange> valueRangeMemory; /// A set of value ranges that have been allocated by the byte code /// interpreter to provide a guaranteed lifetime. - std::vector<llvm::OwningArrayRef<Value>> allocatedValueRangeMemory; + std::vector<std::vector<Value>> allocatedValueRangeMemory; /// The current index of ranges being iterated over for each level of nesting. /// These are always maintained at 0 for the loops that are not active, so we diff --git a/mlir/lib/TableGen/Interfaces.cpp b/mlir/lib/TableGen/Interfaces.cpp index b0ad3ee..77a6cec 100644 --- a/mlir/lib/TableGen/Interfaces.cpp +++ b/mlir/lib/TableGen/Interfaces.cpp @@ -208,3 +208,11 @@ bool OpInterface::classof(const Interface *interface) { bool TypeInterface::classof(const Interface *interface) { return interface->getDef().isSubClassOf("TypeInterface"); } + +//===----------------------------------------------------------------------===// +// DialectInterface +//===----------------------------------------------------------------------===// + +bool DialectInterface::classof(const Interface *interface) { + return interface->getDef().isSubClassOf("DialectInterface"); +} diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index 1a1a58a..ce09f5c 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -18,6 +18,7 @@ #include "llvm/ADT/Twine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/Path.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" @@ -771,15 +772,27 @@ int Pattern::getBenefit() const { return initBenefit + dyn_cast<IntInit>(delta->getArg(0))->getValue(); } -std::vector<Pattern::IdentifierLine> Pattern::getLocation() const { +std::vector<Pattern::IdentifierLine> +Pattern::getLocation(bool forSourceOutput) const { std::vector<std::pair<StringRef, unsigned>> result; result.reserve(def.getLoc().size()); for (auto loc : def.getLoc()) { unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc); assert(buf && "invalid source location"); - result.emplace_back( - llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(), - llvm::SrcMgr.getLineAndColumn(loc, buf).first); + + StringRef bufferName = + llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(); + // If we're emitting a generated file, we'd like to have some indication of + // where our patterns came from. However, LLVM's build rules use absolute + // paths as arguments to TableGen, and naively echoing such paths makes the + // contents of the generated source file depend on the build location, + // making MLIR builds substantially less reproducable. As a compromise, we + // trim absolute paths back to only the filename component. + if (forSourceOutput && llvm::sys::path::is_absolute(bufferName)) + bufferName = llvm::sys::path::filename(bufferName); + + result.emplace_back(bufferName, + llvm::SrcMgr.getLineAndColumn(loc, buf).first); } return result; } diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 1243511..15c23c6 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -70,6 +70,7 @@ static inline LogicalResult interleaveCommaWithError(const Container &c, /// imply higher precedence. static FailureOr<int> getOperatorPrecedence(Operation *operation) { return llvm::TypeSwitch<Operation *, FailureOr<int>>(operation) + .Case<emitc::AddressOfOp>([&](auto op) { return 15; }) .Case<emitc::AddOp>([&](auto op) { return 12; }) .Case<emitc::ApplyOp>([&](auto op) { return 15; }) .Case<emitc::BitwiseAndOp>([&](auto op) { return 7; }) @@ -111,6 +112,8 @@ static FailureOr<int> getOperatorPrecedence(Operation *operation) { .Default([](auto op) { return op->emitError("unsupported operation"); }); } +static bool shouldBeInlined(Operation *op); + namespace { /// Emitter that uses dialect specific emitters to emit C++ code. struct CppEmitter { @@ -173,8 +176,11 @@ struct CppEmitter { /// Emits the operands of the operation. All operands are emitted in order. LogicalResult emitOperands(Operation &op); - /// Emits value as an operands of an operation - LogicalResult emitOperand(Value value); + /// Emits value as an operand of some operation. Unless \p isInBrackets is + /// true, operands emitted as sub-expressions will be parenthesized if needed + /// in order to enforce correct evaluation based on precedence and + /// associativity. + LogicalResult emitOperand(Value value, bool isInBrackets = false); /// Emit an expression as a C expression. LogicalResult emitExpression(ExpressionOp expressionOp); @@ -189,15 +195,6 @@ struct CppEmitter { /// emitc::ForOp. StringRef getOrCreateInductionVarName(Value val); - // Returns the textual representation of a subscript operation. - std::string getSubscriptName(emitc::SubscriptOp op); - - // Returns the textual representation of a member (of object) operation. - std::string createMemberAccess(emitc::MemberOp op); - - // Returns the textual representation of a member of pointer operation. - std::string createMemberAccess(emitc::MemberOfPtrOp op); - /// Return the existing or a new label of a Block. StringRef getOrCreateName(Block &block); @@ -259,25 +256,20 @@ struct CppEmitter { return !fileId.empty() && file.getId() == fileId; } - /// Get expression currently being emitted. - ExpressionOp getEmittedExpression() { return emittedExpression; } + /// Is expression currently being emitted. + bool isEmittingExpression() { return !emittedExpressionPrecedence.empty(); } /// Determine whether given value is part of the expression potentially being /// emitted. bool isPartOfCurrentExpression(Value value) { - if (!emittedExpression) - return false; Operation *def = value.getDefiningOp(); - if (!def) - return false; - return isPartOfCurrentExpression(def); + return def ? isPartOfCurrentExpression(def) : false; } /// Determine whether given operation is part of the expression potentially /// being emitted. bool isPartOfCurrentExpression(Operation *def) { - auto operandExpression = dyn_cast<ExpressionOp>(def->getParentOp()); - return operandExpression && operandExpression == emittedExpression; + return isEmittingExpression() && shouldBeInlined(def); }; // Resets the value counter to 0. @@ -324,7 +316,6 @@ private: unsigned int valueCount{0}; /// State of the current expression being emitted. - ExpressionOp emittedExpression; SmallVector<int> emittedExpressionPrecedence; void pushExpressionPrecedence(int precedence) { @@ -342,17 +333,28 @@ private: /// Determine whether expression \p op should be emitted in a deferred way. static bool hasDeferredEmission(Operation *op) { - return isa_and_nonnull<emitc::GetGlobalOp, emitc::LiteralOp, emitc::MemberOp, + return isa_and_nonnull<emitc::DereferenceOp, emitc::GetGlobalOp, + emitc::LiteralOp, emitc::MemberOp, emitc::MemberOfPtrOp, emitc::SubscriptOp, emitc::GetFieldOp>(op); } -/// Determine whether expression \p expressionOp should be emitted inline, i.e. +/// Determine whether operation \p op should be emitted inline, i.e. /// as part of its user. This function recommends inlining of any expressions /// that can be inlined unless it is used by another expression, under the /// assumption that any expression fusion/re-materialization was taken care of /// by transformations run by the backend. -static bool shouldBeInlined(ExpressionOp expressionOp) { +static bool shouldBeInlined(Operation *op) { + // CExpression operations are inlined if and only if they reside within an + // ExpressionOp. + if (isa<CExpressionInterface>(op)) + return isa<ExpressionOp>(op->getParentOp()); + + // Only other inlinable operation is ExpressionOp itself. + ExpressionOp expressionOp = dyn_cast<ExpressionOp>(op); + if (!expressionOp) + return false; + // Do not inline if expression is marked as such. if (expressionOp.getDoNotInline()) return false; @@ -402,6 +404,66 @@ static bool shouldBeInlined(ExpressionOp expressionOp) { return false; } +static LogicalResult printOperation(CppEmitter &emitter, + emitc::DereferenceOp dereferenceOp) { + std::string out; + llvm::raw_string_ostream ss(out); + ss << "*" << emitter.getOrCreateName(dereferenceOp.getPointer()); + emitter.cacheDeferredOpResult(dereferenceOp.getResult(), out); + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::GetFieldOp getFieldOp) { + emitter.cacheDeferredOpResult(getFieldOp.getResult(), + getFieldOp.getFieldName()); + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::GetGlobalOp getGlobalOp) { + emitter.cacheDeferredOpResult(getGlobalOp.getResult(), getGlobalOp.getName()); + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::LiteralOp literalOp) { + emitter.cacheDeferredOpResult(literalOp.getResult(), literalOp.getValue()); + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::MemberOp memberOp) { + std::string out; + llvm::raw_string_ostream ss(out); + ss << emitter.getOrCreateName(memberOp.getOperand()); + ss << "." << memberOp.getMember(); + emitter.cacheDeferredOpResult(memberOp.getResult(), out); + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::MemberOfPtrOp memberOfPtrOp) { + std::string out; + llvm::raw_string_ostream ss(out); + ss << emitter.getOrCreateName(memberOfPtrOp.getOperand()); + ss << "->" << memberOfPtrOp.getMember(); + emitter.cacheDeferredOpResult(memberOfPtrOp.getResult(), out); + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::SubscriptOp subscriptOp) { + std::string out; + llvm::raw_string_ostream ss(out); + ss << emitter.getOrCreateName(subscriptOp.getValue()); + for (auto index : subscriptOp.getIndices()) { + ss << "[" << emitter.getOrCreateName(index) << "]"; + } + emitter.cacheDeferredOpResult(subscriptOp.getResult(), out); + return success(); +} + static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, Attribute value) { OpResult result = operation->getResult(0); @@ -435,6 +497,17 @@ static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, } static LogicalResult printOperation(CppEmitter &emitter, + emitc::AddressOfOp addressOfOp) { + raw_ostream &os = emitter.ostream(); + Operation &op = *addressOfOp.getOperation(); + + if (failed(emitter.emitAssignPrefix(op))) + return failure(); + os << "&"; + return emitter.emitOperand(addressOfOp.getReference()); +} + +static LogicalResult printOperation(CppEmitter &emitter, emitc::ConstantOp constantOp) { Operation *operation = constantOp.getOperation(); Attribute value = constantOp.getValue(); @@ -1336,32 +1409,6 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop, labelInScopeCount.push(0); } -std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) { - std::string out; - llvm::raw_string_ostream ss(out); - ss << getOrCreateName(op.getValue()); - for (auto index : op.getIndices()) { - ss << "[" << getOrCreateName(index) << "]"; - } - return out; -} - -std::string CppEmitter::createMemberAccess(emitc::MemberOp op) { - std::string out; - llvm::raw_string_ostream ss(out); - ss << getOrCreateName(op.getOperand()); - ss << "." << op.getMember(); - return out; -} - -std::string CppEmitter::createMemberAccess(emitc::MemberOfPtrOp op) { - std::string out; - llvm::raw_string_ostream ss(out); - ss << getOrCreateName(op.getOperand()); - ss << "->" << op.getMember(); - return out; -} - void CppEmitter::cacheDeferredOpResult(Value value, StringRef str) { if (!valueMapper.count(value)) valueMapper.insert(value, str.str()); @@ -1545,7 +1592,6 @@ LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) { "Expected precedence stack to be empty"); Operation *rootOp = expressionOp.getRootOp(); - emittedExpression = expressionOp; FailureOr<int> precedence = getOperatorPrecedence(rootOp); if (failed(precedence)) return failure(); @@ -1557,12 +1603,11 @@ LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) { popExpressionPrecedence(); assert(emittedExpressionPrecedence.empty() && "Expected precedence stack to be empty"); - emittedExpression = nullptr; return success(); } -LogicalResult CppEmitter::emitOperand(Value value) { +LogicalResult CppEmitter::emitOperand(Value value, bool isInBrackets) { if (isPartOfCurrentExpression(value)) { Operation *def = value.getDefiningOp(); assert(def && "Expected operand to be defined by an operation"); @@ -1570,10 +1615,12 @@ LogicalResult CppEmitter::emitOperand(Value value) { if (failed(precedence)) return failure(); - // Sub-expressions with equal or lower precedence need to be parenthesized, - // as they might be evaluated in the wrong order depending on the shape of - // the expression tree. - bool encloseInParenthesis = precedence.value() <= getExpressionPrecedence(); + // Unless already in brackets, sub-expressions with equal or lower + // precedence need to be parenthesized as they might be evaluated in the + // wrong order depending on the shape of the expression tree. + bool encloseInParenthesis = + !isInBrackets && precedence.value() <= getExpressionPrecedence(); + if (encloseInParenthesis) os << "("; pushExpressionPrecedence(precedence.value()); @@ -1596,14 +1643,8 @@ LogicalResult CppEmitter::emitOperand(Value value) { // If this operand is a block argument of an expression, emit instead the // matching expression parameter. Operation *argOp = arg.getParentBlock()->getParentOp(); - if (auto expressionOp = dyn_cast<ExpressionOp>(argOp)) { - // This scenario is only expected when one of the operations within the - // expression being emitted references one of the expression's block - // arguments. - assert(expressionOp == emittedExpression && - "Expected expression being emitted"); - value = expressionOp->getOperand(arg.getArgNumber()); - } + if (auto expressionOp = dyn_cast<ExpressionOp>(argOp)) + return emitOperand(expressionOp->getOperand(arg.getArgNumber())); } os << getOrCreateName(value); @@ -1612,15 +1653,9 @@ LogicalResult CppEmitter::emitOperand(Value value) { LogicalResult CppEmitter::emitOperands(Operation &op) { return interleaveCommaWithError(op.getOperands(), os, [&](Value operand) { - // If an expression is being emitted, push lowest precedence as these - // operands are either wrapped by parenthesis. - if (getEmittedExpression()) - pushExpressionPrecedence(lowestPrecedence()); - if (failed(emitOperand(operand))) - return failure(); - if (getEmittedExpression()) - popExpressionPrecedence(); - return success(); + // Emit operand under guarantee that if it's part of an expression then it + // is being emitted within brackets. + return emitOperand(operand, /*isInBrackets=*/true); }); } @@ -1702,7 +1737,7 @@ LogicalResult CppEmitter::emitGlobalVariable(GlobalOp op) { LogicalResult CppEmitter::emitAssignPrefix(Operation &op) { // If op is being emitted as part of an expression, bail out. - if (getEmittedExpression()) + if (isEmittingExpression()) return success(); switch (op.getNumResults()) { @@ -1753,49 +1788,27 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { .Case<cf::BranchOp, cf::CondBranchOp>( [&](auto op) { return printOperation(*this, op); }) // EmitC ops. - .Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp, - emitc::BitwiseAndOp, emitc::BitwiseLeftShiftOp, + .Case<emitc::AddressOfOp, emitc::AddOp, emitc::ApplyOp, + emitc::AssignOp, emitc::BitwiseAndOp, emitc::BitwiseLeftShiftOp, emitc::BitwiseNotOp, emitc::BitwiseOrOp, emitc::BitwiseRightShiftOp, emitc::BitwiseXorOp, emitc::CallOp, emitc::CallOpaqueOp, emitc::CastOp, emitc::ClassOp, emitc::CmpOp, emitc::ConditionalOp, emitc::ConstantOp, - emitc::DeclareFuncOp, emitc::DivOp, emitc::DoOp, - emitc::ExpressionOp, emitc::FieldOp, emitc::FileOp, - emitc::ForOp, emitc::FuncOp, emitc::GlobalOp, emitc::IfOp, - emitc::IncludeOp, emitc::LoadOp, emitc::LogicalAndOp, - emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp, - emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SwitchOp, - emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp, - emitc::VerbatimOp>( + emitc::DeclareFuncOp, emitc::DereferenceOp, emitc::DivOp, + emitc::DoOp, emitc::ExpressionOp, emitc::FieldOp, emitc::FileOp, + emitc::ForOp, emitc::FuncOp, emitc::GetFieldOp, + emitc::GetGlobalOp, emitc::GlobalOp, emitc::IfOp, + emitc::IncludeOp, emitc::LiteralOp, emitc::LoadOp, + emitc::LogicalAndOp, emitc::LogicalNotOp, emitc::LogicalOrOp, + emitc::MemberOfPtrOp, emitc::MemberOp, emitc::MulOp, + emitc::RemOp, emitc::ReturnOp, emitc::SubscriptOp, emitc::SubOp, + emitc::SwitchOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp, + emitc::VariableOp, emitc::VerbatimOp>( [&](auto op) { return printOperation(*this, op); }) // Func ops. .Case<func::CallOp, func::FuncOp, func::ReturnOp>( [&](auto op) { return printOperation(*this, op); }) - .Case<emitc::GetGlobalOp>([&](auto op) { - cacheDeferredOpResult(op.getResult(), op.getName()); - return success(); - }) - .Case<emitc::GetFieldOp>([&](auto op) { - cacheDeferredOpResult(op.getResult(), op.getFieldName()); - return success(); - }) - .Case<emitc::LiteralOp>([&](auto op) { - cacheDeferredOpResult(op.getResult(), op.getValue()); - return success(); - }) - .Case<emitc::MemberOp>([&](auto op) { - cacheDeferredOpResult(op.getResult(), createMemberAccess(op)); - return success(); - }) - .Case<emitc::MemberOfPtrOp>([&](auto op) { - cacheDeferredOpResult(op.getResult(), createMemberAccess(op)); - return success(); - }) - .Case<emitc::SubscriptOp>([&](auto op) { - cacheDeferredOpResult(op.getResult(), getSubscriptName(op)); - return success(); - }) .Default([&](Operation *) { return op.emitOpError("unable to find printer for op"); }); @@ -1806,7 +1819,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { if (hasDeferredEmission(&op)) return success(); - if (getEmittedExpression() || + if (isEmittingExpression() || (isa<emitc::ExpressionOp>(op) && shouldBeInlined(cast<emitc::ExpressionOp>(op)))) return success(); diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp index 2dd0640..5be33c4 100644 --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -30,6 +30,14 @@ void registerFromLLVMIRTranslation() { llvm::cl::desc("Emit expensive warnings during LLVM IR import " "(discouraged: testing only!)"), llvm::cl::init(false)); + static llvm::cl::opt<bool> convertDebugRecToIntrinsics( + "convert-debug-rec-to-intrinsics", + llvm::cl::desc("Change the input LLVM module to use old debug intrinsics " + "instead of records " + "via convertFromNewDbgValues, this happens " + "before importing the debug information" + "(discouraged: to be removed soon!)"), + llvm::cl::init(false)); static llvm::cl::opt<bool> dropDICompositeTypeElements( "drop-di-composite-type-elements", llvm::cl::desc( @@ -69,8 +77,10 @@ void registerFromLLVMIRTranslation() { if (llvm::verifyModule(*llvmModule, &llvm::errs())) return nullptr; - // Debug records are not currently supported in the LLVM IR translator. - llvmModule->convertFromNewDbgValues(); + // Now that the translation supports importing debug records directly, + // make it the default, but allow the user to override to old behavior. + if (convertDebugRecToIntrinsics) + llvmModule->convertFromNewDbgValues(); return translateLLVMIRToModule( std::move(llvmModule), context, emitExpensiveWarnings, diff --git a/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp b/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp index d3216d9..d9bfe65 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp @@ -124,10 +124,10 @@ static LogicalResult embedBinaryImpl(StringRef moduleName, } IRBuilder<> builder(module.getContext()); - auto i32Ty = builder.getInt32Ty(); - auto i64Ty = builder.getInt64Ty(); - auto ptrTy = builder.getPtrTy(0); - auto voidTy = builder.getVoidTy(); + auto *i32Ty = builder.getInt32Ty(); + auto *i64Ty = builder.getInt64Ty(); + auto *ptrTy = builder.getPtrTy(0); + auto *voidTy = builder.getVoidTy(); // Embed the module as a global object. auto *modulePtr = new GlobalVariable( @@ -147,13 +147,12 @@ static LogicalResult embedBinaryImpl(StringRef moduleName, "mgpuModuleLoadJIT", FunctionType::get(ptrTy, {ptrTy, i32Ty}, false)); Constant *optValue = ConstantInt::get(i32Ty, optLevel); return builder.CreateCall(moduleLoadFn, {serializedObj, optValue}); - } else { - FunctionCallee moduleLoadFn = module.getOrInsertFunction( - "mgpuModuleLoad", FunctionType::get(ptrTy, {ptrTy, i64Ty}, false)); - Constant *binarySize = - ConstantInt::get(i64Ty, serializedStr.size() + (addNull ? 1 : 0)); - return builder.CreateCall(moduleLoadFn, {serializedObj, binarySize}); } + FunctionCallee moduleLoadFn = module.getOrInsertFunction( + "mgpuModuleLoad", FunctionType::get(ptrTy, {ptrTy, i64Ty}, false)); + Constant *binarySize = + ConstantInt::get(i64Ty, serializedStr.size() + (addNull ? 1 : 0)); + return builder.CreateCall(moduleLoadFn, {serializedObj, binarySize}); }(); builder.CreateStore(moduleObj, modulePtr); builder.CreateRetVoid(); diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp index 44732d5..2d4a18c 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp @@ -80,8 +80,9 @@ static LogicalResult convertIntrinsicImpl(OpBuilder &odsBuilder, /// Returns the list of LLVM IR metadata kinds that are convertible to MLIR LLVM /// dialect attributes. -static ArrayRef<unsigned> getSupportedMetadataImpl(llvm::LLVMContext &context) { - static const SmallVector<unsigned> convertibleMetadata = { +static SmallVector<unsigned> +getSupportedMetadataImpl(llvm::LLVMContext &llvmContext) { + SmallVector<unsigned> convertibleMetadata = { llvm::LLVMContext::MD_prof, llvm::LLVMContext::MD_tbaa, llvm::LLVMContext::MD_access_group, @@ -91,10 +92,10 @@ static ArrayRef<unsigned> getSupportedMetadataImpl(llvm::LLVMContext &context) { llvm::LLVMContext::MD_dereferenceable, llvm::LLVMContext::MD_dereferenceable_or_null, llvm::LLVMContext::MD_mmra, - context.getMDKindID(vecTypeHintMDName), - context.getMDKindID(workGroupSizeHintMDName), - context.getMDKindID(reqdWorkGroupSizeMDName), - context.getMDKindID(intelReqdSubGroupSizeMDName)}; + llvmContext.getMDKindID(vecTypeHintMDName), + llvmContext.getMDKindID(workGroupSizeHintMDName), + llvmContext.getMDKindID(reqdWorkGroupSizeMDName), + llvmContext.getMDKindID(intelReqdSubGroupSizeMDName)}; return convertibleMetadata; } @@ -113,7 +114,7 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node, return failure(); // Handle function entry count metadata. - if (name->getString() == "function_entry_count") { + if (name->getString() == llvm::MDProfLabels::FunctionEntryCount) { // TODO support function entry count metadata with GUID fields. if (node->getNumOperands() != 2) @@ -131,15 +132,28 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node, << "expected function_entry_count to be attached to a function"; } - if (name->getString() != "branch_weights") + if (name->getString() != llvm::MDProfLabels::BranchWeights) return failure(); + // The branch_weights metadata must have at least 2 operands. + if (node->getNumOperands() < 2) + return failure(); + + ArrayRef<llvm::MDOperand> branchWeightOperands = + node->operands().drop_front(); + if (auto *mdString = dyn_cast<llvm::MDString>(node->getOperand(1))) { + if (mdString->getString() != llvm::MDProfLabels::ExpectedBranchWeights) + return failure(); + // The MLIR WeightedBranchOpInterface does not support the + // ExpectedBranchWeights field, so it is dropped. + branchWeightOperands = branchWeightOperands.drop_front(); + } // Handle branch weights metadata. SmallVector<int32_t> branchWeights; - branchWeights.reserve(node->getNumOperands() - 1); - for (unsigned i = 1, e = node->getNumOperands(); i != e; ++i) { + branchWeights.reserve(branchWeightOperands.size()); + for (const llvm::MDOperand &operand : branchWeightOperands) { llvm::ConstantInt *branchWeight = - llvm::mdconst::dyn_extract<llvm::ConstantInt>(node->getOperand(i)); + llvm::mdconst::dyn_extract<llvm::ConstantInt>(operand); if (!branchWeight) return failure(); branchWeights.push_back(branchWeight->getZExtValue()); @@ -492,9 +506,9 @@ public: /// Returns the list of LLVM IR metadata kinds that are convertible to MLIR /// LLVM dialect attributes. - ArrayRef<unsigned> - getSupportedMetadata(llvm::LLVMContext &context) const final { - return getSupportedMetadataImpl(context); + SmallVector<unsigned> + getSupportedMetadata(llvm::LLVMContext &llvmContext) const final { + return getSupportedMetadataImpl(llvmContext); } }; } // namespace diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index eaf1d20..b6ea4ba 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -222,14 +222,14 @@ static void convertLinkerOptionsOp(ArrayAttr options, llvm::LLVMContext &context = llvmModule->getContext(); llvm::NamedMDNode *linkerMDNode = llvmModule->getOrInsertNamedMetadata("llvm.linker.options"); - SmallVector<llvm::Metadata *> MDNodes; - MDNodes.reserve(options.size()); + SmallVector<llvm::Metadata *> mdNodes; + mdNodes.reserve(options.size()); for (auto s : options.getAsRange<StringAttr>()) { - auto *MDNode = llvm::MDString::get(context, s.getValue()); - MDNodes.push_back(MDNode); + auto *mdNode = llvm::MDString::get(context, s.getValue()); + mdNodes.push_back(mdNode); } - auto *listMDNode = llvm::MDTuple::get(context, MDNodes); + auto *listMDNode = llvm::MDTuple::get(context, mdNodes); linkerMDNode->addOperand(listMDNode); } @@ -243,16 +243,16 @@ convertModuleFlagValue(StringRef key, ArrayAttr arrayAttr, if (key == LLVMDialect::getModuleFlagKeyCGProfileName()) { for (auto entry : arrayAttr.getAsRange<ModuleFlagCGProfileEntryAttr>()) { - llvm::Metadata *fromMetadata = - entry.getFrom() - ? llvm::ValueAsMetadata::get(moduleTranslation.lookupFunction( - entry.getFrom().getValue())) - : nullptr; - llvm::Metadata *toMetadata = - entry.getTo() - ? llvm::ValueAsMetadata::get( - moduleTranslation.lookupFunction(entry.getTo().getValue())) - : nullptr; + auto getFuncMetadata = [&](FlatSymbolRefAttr sym) -> llvm::Metadata * { + if (!sym) + return nullptr; + if (llvm::Function *fn = + moduleTranslation.lookupFunction(sym.getValue())) + return llvm::ValueAsMetadata::get(fn); + return nullptr; + }; + llvm::Metadata *fromMetadata = getFuncMetadata(entry.getFrom()); + llvm::Metadata *toMetadata = getFuncMetadata(entry.getTo()); llvm::Metadata *vals[] = { fromMetadata, toMetadata, @@ -439,7 +439,14 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, llvm::MemoryEffects::Location::InaccessibleMem, convertModRefInfoToLLVM(memAttr.getInaccessibleMem())) | llvm::MemoryEffects(llvm::MemoryEffects::Location::Other, - convertModRefInfoToLLVM(memAttr.getOther())); + convertModRefInfoToLLVM(memAttr.getOther())) | + llvm::MemoryEffects(llvm::MemoryEffects::Location::ErrnoMem, + convertModRefInfoToLLVM(memAttr.getErrnoMem())) | + llvm::MemoryEffects( + llvm::MemoryEffects::Location::TargetMem0, + convertModRefInfoToLLVM(memAttr.getTargetMem0())) | + llvm::MemoryEffects(llvm::MemoryEffects::Location::TargetMem1, + convertModRefInfoToLLVM(memAttr.getTargetMem1())); call->setMemoryEffects(memEffects); } diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp index cecff51..b7427a5 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -411,6 +411,41 @@ getTcgen05StIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num) { llvm_unreachable("unhandled tcgen05.st lowering"); } +static llvm::Intrinsic::ID getFenceSyncRestrictID(NVVM::MemOrderKind order) { + return order == NVVM::MemOrderKind::ACQUIRE + ? llvm::Intrinsic:: + nvvm_fence_acquire_sync_restrict_space_cluster_scope_cluster + : llvm::Intrinsic:: + nvvm_fence_release_sync_restrict_space_cta_scope_cluster; +} + +static llvm::Intrinsic::ID +getFenceProxyID(NVVM::ProxyKind kind, std::optional<NVVM::SharedSpace> space) { + switch (kind) { + case NVVM::ProxyKind::alias: + return llvm::Intrinsic::nvvm_fence_proxy_alias; + case NVVM::ProxyKind::async: + return llvm::Intrinsic::nvvm_fence_proxy_async; + case NVVM::ProxyKind::async_global: + return llvm::Intrinsic::nvvm_fence_proxy_async_global; + case NVVM::ProxyKind::async_shared: + return *space == NVVM::SharedSpace::shared_cta + ? llvm::Intrinsic::nvvm_fence_proxy_async_shared_cta + : llvm::Intrinsic::nvvm_fence_proxy_async_shared_cluster; + default: + llvm_unreachable("unsupported proxy kind"); + } +} + +static llvm::Intrinsic::ID +getFenceProxySyncRestrictID(NVVM::MemOrderKind order) { + return order == NVVM::MemOrderKind::ACQUIRE + ? llvm::Intrinsic:: + nvvm_fence_proxy_async_generic_acquire_sync_restrict_space_cluster_scope_cluster + : llvm::Intrinsic:: + nvvm_fence_proxy_async_generic_release_sync_restrict_space_cta_scope_cluster; +} + namespace { /// Implementation of the dialect interface that converts operations belonging /// to the NVVM dialect to LLVM IR. diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 8edec99..03d67a5 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -61,6 +61,8 @@ convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) { return llvm::omp::OMP_SCHEDULE_Auto; case omp::ClauseScheduleKind::Runtime: return llvm::omp::OMP_SCHEDULE_Runtime; + case omp::ClauseScheduleKind::Distribute: + return llvm::omp::OMP_SCHEDULE_Distribute; } llvm_unreachable("unhandled schedule clause argument"); } @@ -135,28 +137,31 @@ class LinearClauseProcessor { private: SmallVector<llvm::Value *> linearPreconditionVars; SmallVector<llvm::Value *> linearLoopBodyTemps; - SmallVector<llvm::AllocaInst *> linearOrigVars; SmallVector<llvm::Value *> linearOrigVal; SmallVector<llvm::Value *> linearSteps; + SmallVector<llvm::Type *> linearVarTypes; llvm::BasicBlock *linearFinalizationBB; llvm::BasicBlock *linearExitBB; llvm::BasicBlock *linearLastIterExitBB; public: + // Register type for the linear variables + void registerType(LLVM::ModuleTranslation &moduleTranslation, + mlir::Attribute &ty) { + linearVarTypes.push_back(moduleTranslation.convertType( + mlir::cast<mlir::TypeAttr>(ty).getValue())); + } + // Allocate space for linear variabes void createLinearVar(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, - mlir::Value &linearVar) { - if (llvm::AllocaInst *linearVarAlloca = dyn_cast<llvm::AllocaInst>( - moduleTranslation.lookupValue(linearVar))) { - linearPreconditionVars.push_back(builder.CreateAlloca( - linearVarAlloca->getAllocatedType(), nullptr, ".linear_var")); - llvm::Value *linearLoopBodyTemp = builder.CreateAlloca( - linearVarAlloca->getAllocatedType(), nullptr, ".linear_result"); - linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar)); - linearLoopBodyTemps.push_back(linearLoopBodyTemp); - linearOrigVars.push_back(linearVarAlloca); - } + mlir::Value &linearVar, int idx) { + linearPreconditionVars.push_back( + builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_var")); + llvm::Value *linearLoopBodyTemp = + builder.CreateAlloca(linearVarTypes[idx], nullptr, ".linear_result"); + linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar)); + linearLoopBodyTemps.push_back(linearLoopBodyTemp); } // Initialize linear step @@ -166,20 +171,15 @@ public: } // Emit IR for initialization of linear variables - llvm::OpenMPIRBuilder::InsertPointOrErrorTy - initLinearVar(llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation, - llvm::BasicBlock *loopPreHeader) { + void initLinearVar(llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, + llvm::BasicBlock *loopPreHeader) { builder.SetInsertPoint(loopPreHeader->getTerminator()); - for (size_t index = 0; index < linearOrigVars.size(); index++) { - llvm::LoadInst *linearVarLoad = builder.CreateLoad( - linearOrigVars[index]->getAllocatedType(), linearOrigVars[index]); + for (size_t index = 0; index < linearOrigVal.size(); index++) { + llvm::LoadInst *linearVarLoad = + builder.CreateLoad(linearVarTypes[index], linearOrigVal[index]); builder.CreateStore(linearVarLoad, linearPreconditionVars[index]); } - llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP = - moduleTranslation.getOpenMPBuilder()->createBarrier( - builder.saveIP(), llvm::omp::OMPD_barrier); - return afterBarrierIP; } // Emit IR for updating Linear variables @@ -188,20 +188,24 @@ public: builder.SetInsertPoint(loopBody->getTerminator()); for (size_t index = 0; index < linearPreconditionVars.size(); index++) { // Emit increments for linear vars - llvm::LoadInst *linearVarStart = - builder.CreateLoad(linearOrigVars[index]->getAllocatedType(), - - linearPreconditionVars[index]); + llvm::LoadInst *linearVarStart = builder.CreateLoad( + linearVarTypes[index], linearPreconditionVars[index]); auto mulInst = builder.CreateMul(loopInductionVar, linearSteps[index]); - auto addInst = builder.CreateAdd(linearVarStart, mulInst); - builder.CreateStore(addInst, linearLoopBodyTemps[index]); + if (linearVarTypes[index]->isIntegerTy()) { + auto addInst = builder.CreateAdd(linearVarStart, mulInst); + builder.CreateStore(addInst, linearLoopBodyTemps[index]); + } else if (linearVarTypes[index]->isFloatingPointTy()) { + auto cvt = builder.CreateSIToFP(mulInst, linearVarTypes[index]); + auto addInst = builder.CreateFAdd(linearVarStart, cvt); + builder.CreateStore(addInst, linearLoopBodyTemps[index]); + } } } // Linear variable finalization is conditional on the last logical iteration. // Create BB splits to manage the same. - void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder, - llvm::BasicBlock *loopExit) { + void splitLinearFiniBB(llvm::IRBuilderBase &builder, + llvm::BasicBlock *loopExit) { linearFinalizationBB = loopExit->splitBasicBlock( loopExit->getTerminator(), "omp_loop.linear_finalization"); linearExitBB = linearFinalizationBB->splitBasicBlock( @@ -225,11 +229,10 @@ public: llvm::Type::getInt32Ty(builder.getContext()), 0)); // Store the linear variable values to original variables. builder.SetInsertPoint(linearLastIterExitBB->getTerminator()); - for (size_t index = 0; index < linearOrigVars.size(); index++) { + for (size_t index = 0; index < linearOrigVal.size(); index++) { llvm::LoadInst *linearVarTemp = - builder.CreateLoad(linearOrigVars[index]->getAllocatedType(), - linearLoopBodyTemps[index]); - builder.CreateStore(linearVarTemp, linearOrigVars[index]); + builder.CreateLoad(linearVarTypes[index], linearLoopBodyTemps[index]); + builder.CreateStore(linearVarTemp, linearOrigVal[index]); } // Create conditional branch such that the linear variable @@ -253,7 +256,8 @@ public: users.push_back(user); for (auto *user : users) { if (auto *userInst = dyn_cast<llvm::Instruction>(user)) { - if (userInst->getParent()->getName().str() == BBName) + if (userInst->getParent()->getName().str().find(BBName) != + std::string::npos) user->replaceUsesOfWith(linearOrigVal[varIndex], linearLoopBodyTemps[varIndex]); } @@ -319,10 +323,6 @@ static LogicalResult checkImplementationStatus(Operation &op) { if (op.getDevice()) result = todo("device"); }; - auto checkDistSchedule = [&todo](auto op, LogicalResult &result) { - if (op.getDistScheduleChunkSize()) - result = todo("dist_schedule with chunk_size"); - }; auto checkHint = [](auto op, LogicalResult &) { if (op.getHint()) op.emitWarning("hint clause discarded"); @@ -332,14 +332,6 @@ static LogicalResult checkImplementationStatus(Operation &op) { op.getInReductionSyms()) result = todo("in_reduction"); }; - auto checkIsDevicePtr = [&todo](auto op, LogicalResult &result) { - if (!op.getIsDevicePtrVars().empty()) - result = todo("is_device_ptr"); - }; - auto checkLinear = [&todo](auto op, LogicalResult &result) { - if (!op.getLinearVars().empty() || !op.getLinearStepVars().empty()) - result = todo("linear"); - }; auto checkNowait = [&todo](auto op, LogicalResult &result) { if (op.getNowait()) result = todo("nowait"); @@ -387,7 +379,6 @@ static LogicalResult checkImplementationStatus(Operation &op) { }) .Case([&](omp::DistributeOp op) { checkAllocate(op, result); - checkDistSchedule(op, result); checkOrder(op, result); }) .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); }) @@ -423,7 +414,6 @@ static LogicalResult checkImplementationStatus(Operation &op) { }) .Case([&](omp::WsloopOp op) { checkAllocate(op, result); - checkLinear(op, result); checkOrder(op, result); checkReduction(op, result); }) @@ -431,10 +421,7 @@ static LogicalResult checkImplementationStatus(Operation &op) { checkAllocate(op, result); checkReduction(op, result); }) - .Case([&](omp::SimdOp op) { - checkLinear(op, result); - checkReduction(op, result); - }) + .Case([&](omp::SimdOp op) { checkReduction(op, result); }) .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp, omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); }) .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp, omp::TargetUpdateOp>( @@ -444,7 +431,6 @@ static LogicalResult checkImplementationStatus(Operation &op) { checkBare(op, result); checkDevice(op, result); checkInReduction(op, result); - checkIsDevicePtr(op, result); }) .Default([](Operation &) { // Assume all clauses for an operation can be translated unless they are @@ -953,6 +939,9 @@ using OwningAtomicReductionGen = std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy( llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *, llvm::Value *)>; +using OwningDataPtrPtrReductionGen = + std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy( + llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *&)>; } // namespace /// Create an OpenMPIRBuilder-compatible reduction generator for the given @@ -1017,6 +1006,35 @@ makeAtomicReductionGen(omp::DeclareReductionOp decl, return atomicGen; } +/// Create an OpenMPIRBuilder-compatible `data_ptr_ptr` reduction generator for +/// the given reduction declaration. The generator uses `builder` but ignores +/// its insertion point. Returns null if there is no `data_ptr_ptr` region +/// available in the reduction declaration. +static OwningDataPtrPtrReductionGen +makeRefDataPtrGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, bool isByRef) { + if (!isByRef) + return OwningDataPtrPtrReductionGen(); + + OwningDataPtrPtrReductionGen refDataPtrGen = + [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, + llvm::Value *byRefVal, llvm::Value *&result) mutable + -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy { + moduleTranslation.mapValue(decl.getDataPtrPtrRegionArg(), byRefVal); + builder.restoreIP(insertPoint); + SmallVector<llvm::Value *> phis; + if (failed(inlineConvertOmpRegions(decl.getDataPtrPtrRegion(), + "omp.data_ptr_ptr.body", builder, + moduleTranslation, &phis))) + return llvm::createStringError( + "failed to inline `data_ptr_ptr` region of `omp.declare_reduction`"); + result = llvm::getSingleElement(phis); + return builder.saveIP(); + }; + + return refDataPtrGen; +} + /// Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder. static LogicalResult convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder, @@ -1170,6 +1188,7 @@ allocReductionVars(T loop, ArrayRef<BlockArgument> reductionArgs, template <typename T> static void mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, + llvm::IRBuilderBase &builder, SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls, DenseMap<Value, llvm::Value *> &reductionVariableMap, unsigned i) { @@ -1180,8 +1199,17 @@ mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation, mlir::Value mlirSource = loop.getReductionVars()[i]; llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource); - assert(llvmSource && "lookup reduction var"); - moduleTranslation.mapValue(reduction.getInitializerMoldArg(), llvmSource); + llvm::Value *origVal = llvmSource; + // If a non-pointer value is expected, load the value from the source pointer. + if (!isa<LLVM::LLVMPointerType>( + reduction.getInitializerMoldArg().getType()) && + isa<LLVM::LLVMPointerType>(mlirSource.getType())) { + origVal = + builder.CreateLoad(moduleTranslation.convertType( + reduction.getInitializerMoldArg().getType()), + llvmSource, "omp_orig"); + } + moduleTranslation.mapValue(reduction.getInitializerMoldArg(), origVal); if (entry.getNumArguments() > 1) { llvm::Value *allocation = @@ -1254,7 +1282,7 @@ initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs, SmallVector<llvm::Value *, 1> phis; // map block argument to initializer region - mapInitializationArgs(op, moduleTranslation, reductionDecls, + mapInitializationArgs(op, moduleTranslation, builder, reductionDecls, reductionVariableMap, i); // TODO In some cases (specially on the GPU), the init regions may @@ -1310,8 +1338,10 @@ static void collectReductionInfo( SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls, SmallVectorImpl<OwningReductionGen> &owningReductionGens, SmallVectorImpl<OwningAtomicReductionGen> &owningAtomicReductionGens, + SmallVector<OwningDataPtrPtrReductionGen> &owningDataPtrPtrReductionGens, const ArrayRef<llvm::Value *> privateReductionVariables, - SmallVectorImpl<llvm::OpenMPIRBuilder::ReductionInfo> &reductionInfos) { + SmallVectorImpl<llvm::OpenMPIRBuilder::ReductionInfo> &reductionInfos, + ArrayRef<bool> isByRef) { unsigned numReductions = loop.getNumReductionVars(); for (unsigned i = 0; i < numReductions; ++i) { @@ -1319,6 +1349,8 @@ static void collectReductionInfo( makeReductionGen(reductionDecls[i], builder, moduleTranslation)); owningAtomicReductionGens.push_back( makeAtomicReductionGen(reductionDecls[i], builder, moduleTranslation)); + owningDataPtrPtrReductionGens.push_back(makeRefDataPtrGen( + reductionDecls[i], builder, moduleTranslation, isByRef[i])); } // Collect the reduction information. @@ -1329,12 +1361,28 @@ static void collectReductionInfo( atomicGen = owningAtomicReductionGens[i]; llvm::Value *variable = moduleTranslation.lookupValue(loop.getReductionVars()[i]); + mlir::Type allocatedType; + reductionDecls[i].getAllocRegion().walk([&](mlir::Operation *op) { + if (auto alloca = mlir::dyn_cast<LLVM::AllocaOp>(op)) { + allocatedType = alloca.getElemType(); + return mlir::WalkResult::interrupt(); + } + + return mlir::WalkResult::advance(); + }); + reductionInfos.push_back( {moduleTranslation.convertType(reductionDecls[i].getType()), variable, privateReductionVariables[i], /*EvaluationKind=*/llvm::OpenMPIRBuilder::EvalKind::Scalar, owningReductionGens[i], - /*ReductionGenClang=*/nullptr, atomicGen}); + /*ReductionGenClang=*/nullptr, atomicGen, + owningDataPtrPtrReductionGens[i], + allocatedType ? moduleTranslation.convertType(allocatedType) : nullptr, + reductionDecls[i].getByrefElementType() + ? moduleTranslation.convertType( + *reductionDecls[i].getByrefElementType()) + : nullptr}); } } @@ -1392,7 +1440,8 @@ static LogicalResult createReductionsAndCleanup( SmallVector<OwningReductionGen> owningReductionGens; SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens; - SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos; + SmallVector<OwningDataPtrPtrReductionGen> owningReductionGenRefDataPtrGens; + SmallVector<llvm::OpenMPIRBuilder::ReductionInfo, 2> reductionInfos; llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); @@ -1400,7 +1449,8 @@ static LogicalResult createReductionsAndCleanup( // ReductionInfo only accepts references to the generators. collectReductionInfo(op, builder, moduleTranslation, reductionDecls, owningReductionGens, owningAtomicReductionGens, - privateReductionVariables, reductionInfos); + owningReductionGenRefDataPtrGens, + privateReductionVariables, reductionInfos, isByRef); // The call to createReductions below expects the block to have a // terminator. Create an unreachable instruction to serve as terminator @@ -1907,7 +1957,7 @@ static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp) { // If we are going to use distribute reduction then remove any debug uses of // the reduction parameters in teamsOp. Otherwise they will be left without // any mapped value in moduleTranslation and will eventually error out. - for (auto use : debugUses) + for (auto *use : debugUses) use->erase(); return true; } @@ -2484,6 +2534,19 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, chunk = builder.CreateSExtOrTrunc(chunkVar, ivType); } + omp::DistributeOp distributeOp = nullptr; + llvm::Value *distScheduleChunk = nullptr; + bool hasDistSchedule = false; + if (llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())) { + distributeOp = cast<omp::DistributeOp>(opInst.getParentOp()); + hasDistSchedule = distributeOp.getDistScheduleStatic(); + if (distributeOp.getDistScheduleChunkSize()) { + llvm::Value *chunkVar = moduleTranslation.lookupValue( + distributeOp.getDistScheduleChunkSize()); + distScheduleChunk = builder.CreateSExtOrTrunc(chunkVar, ivType); + } + } + PrivateVarsInfo privateVarsInfo(wsloopOp); SmallVector<omp::DeclareReductionOp> reductionDecls; @@ -2553,10 +2616,15 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, // Initialize linear variables and linear step LinearClauseProcessor linearClauseProcessor; + if (!wsloopOp.getLinearVars().empty()) { - for (mlir::Value linearVar : wsloopOp.getLinearVars()) + auto linearVarTypes = wsloopOp.getLinearVarTypes().value(); + for (mlir::Attribute linearVarType : linearVarTypes) + linearClauseProcessor.registerType(moduleTranslation, linearVarType); + + for (auto [idx, linearVar] : llvm::enumerate(wsloopOp.getLinearVars())) linearClauseProcessor.createLinearVar(builder, moduleTranslation, - linearVar); + linearVar, idx); for (mlir::Value linearStep : wsloopOp.getLinearStepVars()) linearClauseProcessor.initLinearStep(moduleTranslation, linearStep); } @@ -2571,16 +2639,17 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, // Emit Initialization and Update IR for linear variables if (!wsloopOp.getLinearVars().empty()) { + linearClauseProcessor.initLinearVar(builder, moduleTranslation, + loopInfo->getPreheader()); llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP = - linearClauseProcessor.initLinearVar(builder, moduleTranslation, - loopInfo->getPreheader()); + moduleTranslation.getOpenMPBuilder()->createBarrier( + builder.saveIP(), llvm::omp::OMPD_barrier); if (failed(handleError(afterBarrierIP, *loopOp))) return failure(); builder.restoreIP(*afterBarrierIP); linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(), loopInfo->getIndVar()); - linearClauseProcessor.outlineLinearFinalizationBB(builder, - loopInfo->getExit()); + linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit()); } builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin()); @@ -2611,7 +2680,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, convertToScheduleKind(schedule), chunk, isSimd, scheduleMod == omp::ScheduleModifier::monotonic, scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered, - workshareLoopType, noLoopMode); + workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk); if (failed(handleError(wsloopIP, opInst))) return failure(); @@ -2655,6 +2724,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, ArrayRef<bool> isByRef = getIsByRef(opInst.getReductionByref()); assert(isByRef.size() == opInst.getNumReductionVars()); llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); + bool isCancellable = constructIsCancellable(opInst); if (failed(checkImplementationStatus(*opInst))) return failure(); @@ -2729,10 +2799,13 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, // Collect reduction info SmallVector<OwningReductionGen> owningReductionGens; SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens; - SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos; + SmallVector<OwningDataPtrPtrReductionGen> + owningReductionGenRefDataPtrGens; + SmallVector<llvm::OpenMPIRBuilder::ReductionInfo, 2> reductionInfos; collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls, owningReductionGens, owningAtomicReductionGens, - privateReductionVariables, reductionInfos); + owningReductionGenRefDataPtrGens, + privateReductionVariables, reductionInfos, isByRef); // Move to region cont block builder.SetInsertPoint((*regionBlock)->getTerminator()); @@ -2790,6 +2863,18 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, privateVarsInfo.privatizers))) return llvm::make_error<PreviouslyReportedError>(); + // If we could be performing cancellation, add the cancellation barrier on + // the way out of the outlined region. + if (isCancellable) { + auto IPOrErr = ompBuilder->createBarrier( + llvm::OpenMPIRBuilder::LocationDescription(builder), + llvm::omp::Directive::OMPD_unknown, + /* ForceSimpleCall */ false, + /* CheckCancelFlag */ false); + if (!IPOrErr) + return IPOrErr.takeError(); + } + builder.restoreIP(oldIP); return llvm::Error::success(); }; @@ -2803,7 +2888,6 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, auto pbKind = llvm::omp::OMP_PROC_BIND_default; if (auto bind = opInst.getProcBindKind()) pbKind = getProcBindKind(*bind); - bool isCancellable = constructIsCancellable(opInst); llvm::OpenMPIRBuilder::InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation); @@ -2858,6 +2942,20 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder::InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation); + // Initialize linear variables and linear step + LinearClauseProcessor linearClauseProcessor; + + if (!simdOp.getLinearVars().empty()) { + auto linearVarTypes = simdOp.getLinearVarTypes().value(); + for (mlir::Attribute linearVarType : linearVarTypes) + linearClauseProcessor.registerType(moduleTranslation, linearVarType); + for (auto [idx, linearVar] : llvm::enumerate(simdOp.getLinearVars())) + linearClauseProcessor.createLinearVar(builder, moduleTranslation, + linearVar, idx); + for (mlir::Value linearStep : simdOp.getLinearStepVars()) + linearClauseProcessor.initLinearStep(moduleTranslation, linearStep); + } + llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars( builder, moduleTranslation, privateVarsInfo, allocaIP); if (handleError(afterAllocas, opInst).failed()) @@ -2927,14 +3025,27 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, if (failed(handleError(regionBlock, opInst))) return failure(); - builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin()); llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation); + // Emit Initialization for linear variables + if (simdOp.getLinearVars().size()) { + linearClauseProcessor.initLinearVar(builder, moduleTranslation, + loopInfo->getPreheader()); + + linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(), + loopInfo->getIndVar()); + } + builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin()); + ompBuilder->applySimd(loopInfo, alignedVars, simdOp.getIfExpr() ? moduleTranslation.lookupValue(simdOp.getIfExpr()) : nullptr, order, simdlen, safelen); + for (size_t index = 0; index < simdOp.getLinearVars().size(); index++) + linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region", + index); + // We now need to reduce the per-simd-lane reduction variable into the // original variable. This works a bit differently to other reductions (e.g. // wsloop) because we don't need to call into the OpenMP runtime to handle @@ -3632,10 +3743,23 @@ convertToCaptureClauseKind( return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink; case mlir::omp::DeclareTargetCaptureClause::enter: return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter; + case mlir::omp::DeclareTargetCaptureClause::none: + return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryNone; } llvm_unreachable("unhandled capture clause"); } +static Operation *getGlobalOpFromValue(Value value) { + Operation *op = value.getDefiningOp(); + if (auto addrCast = dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op)) + op = addrCast->getOperand(0).getDefiningOp(); + if (auto addressOfOp = dyn_cast_if_present<LLVM::AddressOfOp>(op)) { + auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>(); + return modOp.lookupSymbol(addressOfOp.getGlobalName()); + } + return nullptr; +} + static llvm::SmallString<64> getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder) { @@ -3658,62 +3782,58 @@ getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, return suffix; } -static bool isDeclareTargetLink(mlir::Value value) { - if (auto addressOfOp = value.getDefiningOp<LLVM::AddressOfOp>()) { - auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>(); - Operation *gOp = modOp.lookupSymbol(addressOfOp.getGlobalName()); - if (auto declareTargetGlobal = - llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(gOp)) - if (declareTargetGlobal.getDeclareTargetCaptureClause() == - mlir::omp::DeclareTargetCaptureClause::link) - return true; - } +static bool isDeclareTargetLink(Value value) { + if (auto declareTargetGlobal = + dyn_cast_if_present<omp::DeclareTargetInterface>( + getGlobalOpFromValue(value))) + if (declareTargetGlobal.getDeclareTargetCaptureClause() == + omp::DeclareTargetCaptureClause::link) + return true; + return false; +} + +static bool isDeclareTargetTo(Value value) { + if (auto declareTargetGlobal = + dyn_cast_if_present<omp::DeclareTargetInterface>( + getGlobalOpFromValue(value))) + if (declareTargetGlobal.getDeclareTargetCaptureClause() == + omp::DeclareTargetCaptureClause::to || + declareTargetGlobal.getDeclareTargetCaptureClause() == + omp::DeclareTargetCaptureClause::enter) + return true; return false; } -// Returns the reference pointer generated by the lowering of the declare target -// operation in cases where the link clause is used or the to clause is used in -// USM mode. +// Returns the reference pointer generated by the lowering of the declare +// target operation in cases where the link clause is used or the to clause is +// used in USM mode. static llvm::Value * -getRefPtrIfDeclareTarget(mlir::Value value, +getRefPtrIfDeclareTarget(Value value, LLVM::ModuleTranslation &moduleTranslation) { llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); - Operation *op = value.getDefiningOp(); - if (auto addrCast = llvm::dyn_cast_if_present<LLVM::AddrSpaceCastOp>(op)) - op = addrCast->getOperand(0).getDefiningOp(); - - // An easier way to do this may just be to keep track of any pointer - // references and their mapping to their respective operation - if (auto addressOfOp = llvm::dyn_cast_if_present<LLVM::AddressOfOp>(op)) { - if (auto gOp = llvm::dyn_cast_or_null<LLVM::GlobalOp>( - addressOfOp->getParentOfType<mlir::ModuleOp>().lookupSymbol( - addressOfOp.getGlobalName()))) { - - if (auto declareTargetGlobal = - llvm::dyn_cast<mlir::omp::DeclareTargetInterface>( - gOp.getOperation())) { - - // In this case, we must utilise the reference pointer generated by the - // declare target operation, similar to Clang - if ((declareTargetGlobal.getDeclareTargetCaptureClause() == - mlir::omp::DeclareTargetCaptureClause::link) || - (declareTargetGlobal.getDeclareTargetCaptureClause() == - mlir::omp::DeclareTargetCaptureClause::to && - ompBuilder->Config.hasRequiresUnifiedSharedMemory())) { - llvm::SmallString<64> suffix = - getDeclareTargetRefPtrSuffix(gOp, *ompBuilder); - - if (gOp.getSymName().contains(suffix)) - return moduleTranslation.getLLVMModule()->getNamedValue( - gOp.getSymName()); + if (auto gOp = + dyn_cast_or_null<LLVM::GlobalOp>(getGlobalOpFromValue(value))) { + if (auto declareTargetGlobal = + dyn_cast<omp::DeclareTargetInterface>(gOp.getOperation())) { + // In this case, we must utilise the reference pointer generated by + // the declare target operation, similar to Clang + if ((declareTargetGlobal.getDeclareTargetCaptureClause() == + omp::DeclareTargetCaptureClause::link) || + (declareTargetGlobal.getDeclareTargetCaptureClause() == + omp::DeclareTargetCaptureClause::to && + ompBuilder->Config.hasRequiresUnifiedSharedMemory())) { + llvm::SmallString<64> suffix = + getDeclareTargetRefPtrSuffix(gOp, *ompBuilder); + if (gOp.getSymName().contains(suffix)) return moduleTranslation.getLLVMModule()->getNamedValue( - (gOp.getSymName().str() + suffix.str()).str()); - } + gOp.getSymName()); + + return moduleTranslation.getLLVMModule()->getNamedValue( + (gOp.getSymName().str() + suffix.str()).str()); } } } - return nullptr; } @@ -3756,6 +3876,32 @@ struct MapInfoData : MapInfosTy { MapInfosTy::append(CurInfo); } }; + +enum class TargetDirectiveEnumTy : uint32_t { + None = 0, + Target = 1, + TargetData = 2, + TargetEnterData = 3, + TargetExitData = 4, + TargetUpdate = 5 +}; + +static TargetDirectiveEnumTy getTargetDirectiveEnumTyFromOp(Operation *op) { + return llvm::TypeSwitch<Operation *, TargetDirectiveEnumTy>(op) + .Case([](omp::TargetDataOp) { return TargetDirectiveEnumTy::TargetData; }) + .Case([](omp::TargetEnterDataOp) { + return TargetDirectiveEnumTy::TargetEnterData; + }) + .Case([&](omp::TargetExitDataOp) { + return TargetDirectiveEnumTy::TargetExitData; + }) + .Case([&](omp::TargetUpdateOp) { + return TargetDirectiveEnumTy::TargetUpdate; + }) + .Case([&](omp::TargetOp) { return TargetDirectiveEnumTy::Target; }) + .Default([&](Operation *op) { return TargetDirectiveEnumTy::None; }); +} + } // namespace static uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, @@ -3787,7 +3933,7 @@ static llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type, // This calculates the size to transfer based on bounds and the underlying // element type, provided bounds have been specified (Fortran // pointers/allocatables/target and arrays that have sections specified fall - // into this as well). + // into this as well) if (!memberClause.getBounds().empty()) { llvm::Value *elementCount = builder.getInt64(1); for (auto bounds : memberClause.getBounds()) { @@ -3835,6 +3981,9 @@ convertClauseMapFlags(omp::ClauseMapFlags mlirFlags) { auto mapTypeToBool = [&mlirFlags](omp::ClauseMapFlags flag) { return (mlirFlags & flag) == flag; }; + const bool hasExplicitMap = + (mlirFlags & ~omp::ClauseMapFlags::is_device_ptr) != + omp::ClauseMapFlags::none; llvm::omp::OpenMPOffloadMappingFlags mapType = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; @@ -3875,6 +4024,12 @@ convertClauseMapFlags(omp::ClauseMapFlags mlirFlags) { if (mapTypeToBool(omp::ClauseMapFlags::attach)) mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH; + if (mapTypeToBool(omp::ClauseMapFlags::is_device_ptr)) { + mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM; + if (!hasExplicitMap) + mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL; + } + return mapType; } @@ -3910,10 +4065,12 @@ static void collectMapDataFromMapOperands( mapData.Pointers.push_back(mapData.OriginalValue.back()); if (llvm::Value *refPtr = - getRefPtrIfDeclareTarget(offloadPtr, - moduleTranslation)) { // declare target + getRefPtrIfDeclareTarget(offloadPtr, moduleTranslation)) { mapData.IsDeclareTarget.push_back(true); mapData.BasePointers.push_back(refPtr); + } else if (isDeclareTargetTo(offloadPtr)) { + mapData.IsDeclareTarget.push_back(true); + mapData.BasePointers.push_back(mapData.OriginalValue.back()); } else { // regular mapped variable mapData.IsDeclareTarget.push_back(false); mapData.BasePointers.push_back(mapData.OriginalValue.back()); @@ -3996,6 +4153,9 @@ static void collectMapDataFromMapOperands( llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr); auto mapType = convertClauseMapFlags(mapOp.getMapType()); auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS; + bool isDevicePtr = + (mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) != + omp::ClauseMapFlags::none; mapData.OriginalValue.push_back(origValue); mapData.BasePointers.push_back(origValue); @@ -4022,14 +4182,18 @@ static void collectMapDataFromMapOperands( mapData.Mappers.push_back(nullptr); } } else { + // For is_device_ptr we need the map type to propagate so the runtime + // can materialize the device-side copy of the pointer container. mapData.Types.push_back( - llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL); + isDevicePtr ? mapType + : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL); mapData.Mappers.push_back(nullptr); } mapData.Names.push_back(LLVM::createMappingInformation( mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder())); mapData.DevicePointers.push_back( - llvm::OpenMPIRBuilder::DeviceInfoTy::Address); + isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer + : llvm::OpenMPIRBuilder::DeviceInfoTy::Address); mapData.IsAMapping.push_back(false); mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp)); } @@ -4042,41 +4206,66 @@ static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) { return std::distance(mapData.MapClause.begin(), res); } +static void sortMapIndices(llvm::SmallVectorImpl<size_t> &indices, + omp::MapInfoOp mapInfo) { + ArrayAttr indexAttr = mapInfo.getMembersIndexAttr(); + llvm::SmallVector<size_t> occludedChildren; + llvm::sort( + indices.begin(), indices.end(), [&](const size_t a, const size_t b) { + // Bail early if we are asked to look at the same index. If we do not + // bail early, we can end up mistakenly adding indices to + // occludedChildren. This can occur with some types of libc++ hardening. + if (a == b) + return false; + + auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]); + auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]); + + for (auto it : llvm::zip(memberIndicesA, memberIndicesB)) { + int64_t aIndex = mlir::cast<IntegerAttr>(std::get<0>(it)).getInt(); + int64_t bIndex = mlir::cast<IntegerAttr>(std::get<1>(it)).getInt(); + + if (aIndex == bIndex) + continue; + + if (aIndex < bIndex) + return true; + + if (aIndex > bIndex) + return false; + } + + // Iterated up until the end of the smallest member and + // they were found to be equal up to that point, so select + // the member with the lowest index count, so the "parent" + bool memberAParent = memberIndicesA.size() < memberIndicesB.size(); + if (memberAParent) + occludedChildren.push_back(b); + else + occludedChildren.push_back(a); + return memberAParent; + }); + + // We remove children from the index list that are overshadowed by + // a parent, this prevents us retrieving these as the first or last + // element when the parent is the correct element in these cases. + for (auto v : occludedChildren) + indices.erase(std::remove(indices.begin(), indices.end(), v), + indices.end()); +} + static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, bool first) { ArrayAttr indexAttr = mapInfo.getMembersIndexAttr(); // Only 1 member has been mapped, we can return it. if (indexAttr.size() == 1) return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp()); - llvm::SmallVector<size_t> indices(indexAttr.size()); std::iota(indices.begin(), indices.end(), 0); - - llvm::sort(indices, [&](const size_t a, const size_t b) { - auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]); - auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]); - for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) { - int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt(); - int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt(); - - if (aIndex == bIndex) - continue; - - if (aIndex < bIndex) - return first; - - if (aIndex > bIndex) - return !first; - } - - // Iterated the up until the end of the smallest member and - // they were found to be equal up to that point, so select - // the member with the lowest index count, so the "parent" - return memberIndicesA.size() < memberIndicesB.size(); - }); - + sortMapIndices(indices, mapInfo); return llvm::cast<omp::MapInfoOp>( - mapInfo.getMembers()[indices.front()].getDefiningOp()); + mapInfo.getMembers()[first ? indices.front() : indices.back()] + .getDefiningOp()); } /// This function calculates the array/pointer offset for map data provided @@ -4155,6 +4344,86 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, return idx; } +static void getAsIntegers(ArrayAttr values, llvm::SmallVector<int64_t> &ints) { + llvm::transform(values, std::back_inserter(ints), [](Attribute value) { + return cast<IntegerAttr>(value).getInt(); + }); +} + +// Gathers members that are overlapping in the parent, excluding members that +// themselves overlap, keeping the top-most (closest to parents level) map. +static void +getOverlappedMembers(llvm::SmallVectorImpl<size_t> &overlapMapDataIdxs, + omp::MapInfoOp parentOp) { + // No members mapped, no overlaps. + if (parentOp.getMembers().empty()) + return; + + // Single member, we can insert and return early. + if (parentOp.getMembers().size() == 1) { + overlapMapDataIdxs.push_back(0); + return; + } + + // 1) collect list of top-level overlapping members from MemberOp + llvm::SmallVector<std::pair<int, ArrayAttr>> memberByIndex; + ArrayAttr indexAttr = parentOp.getMembersIndexAttr(); + for (auto [memIndex, indicesAttr] : llvm::enumerate(indexAttr)) + memberByIndex.push_back( + std::make_pair(memIndex, cast<ArrayAttr>(indicesAttr))); + + // Sort the smallest first (higher up the parent -> member chain), so that + // when we remove members, we remove as much as we can in the initial + // iterations, shortening the number of passes required. + llvm::sort(memberByIndex.begin(), memberByIndex.end(), + [&](auto a, auto b) { return a.second.size() < b.second.size(); }); + + // Remove elements from the vector if there is a parent element that + // supersedes it. i.e. if member [0] is mapped, we can remove members [0,1], + // [0,2].. etc. + llvm::SmallVector<std::pair<int, ArrayAttr>> skipList; + for (auto v : memberByIndex) { + llvm::SmallVector<int64_t> vArr(v.second.size()); + getAsIntegers(v.second, vArr); + skipList.push_back( + *std::find_if(memberByIndex.begin(), memberByIndex.end(), [&](auto x) { + if (v == x) + return false; + llvm::SmallVector<int64_t> xArr(x.second.size()); + getAsIntegers(x.second, xArr); + return std::equal(vArr.begin(), vArr.end(), xArr.begin()) && + xArr.size() >= vArr.size(); + })); + } + + // Collect the indices, as we need the base pointer etc. from the MapData + // structure which is primarily accessible via index at the moment. + for (auto v : memberByIndex) + if (find(skipList.begin(), skipList.end(), v) == skipList.end()) + overlapMapDataIdxs.push_back(v.first); +} + +// The intent is to verify if the mapped data being passed is a +// pointer -> pointee that requires special handling in certain cases, +// e.g. applying the OMP_MAP_PTR_AND_OBJ map type. +// +// There may be a better way to verify this, but unfortunately with +// opaque pointers we lose the ability to easily check if something is +// a pointer whilst maintaining access to the underlying type. +static bool checkIfPointerMap(omp::MapInfoOp mapOp) { + // If we have a varPtrPtr field assigned then the underlying type is a pointer + if (mapOp.getVarPtrPtr()) + return true; + + // If the map data is declare target with a link clause, then it's represented + // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has + // no relation to pointers. + if (isDeclareTargetLink(mapOp.getVarPtr())) + return true; + + return false; +} + // This creates two insertions into the MapInfosTy data structure for the // "parent" of a set of members, (usually a container e.g. // class/structure/derived type) when subsequent members have also been @@ -4173,7 +4442,8 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers( LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, - MapInfoData &mapData, uint64_t mapDataIndex, bool isTargetParams) { + MapInfoData &mapData, uint64_t mapDataIndex, + TargetDirectiveEnumTy targetDirective) { assert(!ompBuilder.Config.isTargetDevice() && "function only supported for host device codegen"); @@ -4182,7 +4452,8 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers( // base entry so the mapper receives correct copy semantics via its 'type' // parameter. Also keep TARGET_PARAM when required for kernel arguments. llvm::omp::OpenMPOffloadMappingFlags baseFlag = - isTargetParams + (targetDirective == TargetDirectiveEnumTy::Target && + !mapData.IsDeclareTarget[mapDataIndex]) ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; @@ -4217,7 +4488,6 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers( // runtime information on the dynamically allocated data). auto parentClause = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]); - llvm::Value *lowAddr, *highAddr; if (!parentClause.getPartialMap()) { lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex], @@ -4263,39 +4533,85 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers( // further case specific flag modifications). For the moment, it handles // what we support as expected. llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex]; + bool hasMapClose = (llvm::omp::OpenMPOffloadMappingFlags(mapFlag) & + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE) == + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE; ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag); - combinedInfo.Types.emplace_back(mapFlag); - combinedInfo.DevicePointers.emplace_back( - llvm::OpenMPIRBuilder::DeviceInfoTy::None); - combinedInfo.Mappers.emplace_back(nullptr); - combinedInfo.Names.emplace_back(LLVM::createMappingInformation( - mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder)); - combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]); - combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]); - combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]); - } - return memberOfFlag; -} - -// The intent is to verify if the mapped data being passed is a -// pointer -> pointee that requires special handling in certain cases, -// e.g. applying the OMP_MAP_PTR_AND_OBJ map type. -// -// There may be a better way to verify this, but unfortunately with -// opaque pointers we lose the ability to easily check if something is -// a pointer whilst maintaining access to the underlying type. -static bool checkIfPointerMap(omp::MapInfoOp mapOp) { - // If we have a varPtrPtr field assigned then the underlying type is a pointer - if (mapOp.getVarPtrPtr()) - return true; - // If the map data is declare target with a link clause, then it's represented - // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has - // no relation to pointers. - if (isDeclareTargetLink(mapOp.getVarPtr())) - return true; + if (targetDirective == TargetDirectiveEnumTy::TargetUpdate || hasMapClose) { + combinedInfo.Types.emplace_back(mapFlag); + combinedInfo.DevicePointers.emplace_back( + mapData.DevicePointers[mapDataIndex]); + combinedInfo.Names.emplace_back(LLVM::createMappingInformation( + mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder)); + combinedInfo.BasePointers.emplace_back( + mapData.BasePointers[mapDataIndex]); + combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]); + combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]); + combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]); + } else { + llvm::SmallVector<size_t> overlapIdxs; + // Find all of the members that "overlap", i.e. occlude other members that + // were mapped alongside the parent, e.g. member [0], occludes [0,1] and + // [0,2], but not [1,0]. + getOverlappedMembers(overlapIdxs, parentClause); + // We need to make sure the overlapped members are sorted in order of + // lowest address to highest address. + sortMapIndices(overlapIdxs, parentClause); + + lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex], + builder.getPtrTy()); + highAddr = builder.CreatePointerCast( + builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex], + mapData.Pointers[mapDataIndex], 1), + builder.getPtrTy()); + + // TODO: We may want to skip arrays/array sections in this as Clang does. + // It appears to be an optimisation rather than a necessity though, + // but this requires further investigation. However, we would have to make + // sure to not exclude maps with bounds that ARE pointers, as these are + // processed as separate components, i.e. pointer + data. + for (auto v : overlapIdxs) { + auto mapDataOverlapIdx = getMapDataMemberIdx( + mapData, + cast<omp::MapInfoOp>(parentClause.getMembers()[v].getDefiningOp())); + combinedInfo.Types.emplace_back(mapFlag); + combinedInfo.DevicePointers.emplace_back( + mapData.DevicePointers[mapDataOverlapIdx]); + combinedInfo.Names.emplace_back(LLVM::createMappingInformation( + mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder)); + combinedInfo.BasePointers.emplace_back( + mapData.BasePointers[mapDataIndex]); + combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]); + combinedInfo.Pointers.emplace_back(lowAddr); + combinedInfo.Sizes.emplace_back(builder.CreateIntCast( + builder.CreatePtrDiff(builder.getInt8Ty(), + mapData.OriginalValue[mapDataOverlapIdx], + lowAddr), + builder.getInt64Ty(), /*isSigned=*/true)); + lowAddr = builder.CreateConstGEP1_32( + checkIfPointerMap(llvm::cast<omp::MapInfoOp>( + mapData.MapClause[mapDataOverlapIdx])) + ? builder.getPtrTy() + : mapData.BaseType[mapDataOverlapIdx], + mapData.BasePointers[mapDataOverlapIdx], 1); + } - return false; + combinedInfo.Types.emplace_back(mapFlag); + combinedInfo.DevicePointers.emplace_back( + mapData.DevicePointers[mapDataIndex]); + combinedInfo.Names.emplace_back(LLVM::createMappingInformation( + mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder)); + combinedInfo.BasePointers.emplace_back( + mapData.BasePointers[mapDataIndex]); + combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]); + combinedInfo.Pointers.emplace_back(lowAddr); + combinedInfo.Sizes.emplace_back(builder.CreateIntCast( + builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr), + builder.getInt64Ty(), true)); + } + } + return memberOfFlag; } // This function is intended to add explicit mappings of members @@ -4303,7 +4619,8 @@ static void processMapMembersWithParent( LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, - llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) { + llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, + TargetDirectiveEnumTy targetDirective) { assert(!ompBuilder.Config.isTargetDevice() && "function only supported for host device codegen"); @@ -4348,8 +4665,15 @@ static void processMapMembersWithParent( mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM; mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF; ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag); - if (checkIfPointerMap(memberClause)) + bool isDeclTargetTo = isDeclareTargetTo(parentClause.getVarPtr() + ? parentClause.getVarPtr() + : parentClause.getVarPtrPtr()); + if (checkIfPointerMap(memberClause) && + (!isDeclTargetTo || + (targetDirective != TargetDirectiveEnumTy::TargetUpdate && + targetDirective != TargetDirectiveEnumTy::TargetData))) { mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ; + } combinedInfo.Types.emplace_back(mapFlag); combinedInfo.DevicePointers.emplace_back( @@ -4375,7 +4699,8 @@ static void processMapMembersWithParent( } static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, - MapInfosTy &combinedInfo, bool isTargetParams, + MapInfosTy &combinedInfo, + TargetDirectiveEnumTy targetDirective, int mapDataParentIdx = -1) { // Declare Target Mappings are excluded from being marked as // OMP_MAP_TARGET_PARAM as they are not passed as parameters, they're @@ -4387,7 +4712,8 @@ static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, if (isPtrTy) mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ; - if (isTargetParams && !mapData.IsDeclareTarget[mapDataIdx]) + if (targetDirective == TargetDirectiveEnumTy::Target && + !mapData.IsDeclareTarget[mapDataIdx]) mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM; if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy && @@ -4416,7 +4742,7 @@ static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, uint64_t mapDataIndex, - bool isTargetParams) { + TargetDirectiveEnumTy targetDirective) { assert(!ompBuilder.Config.isTargetDevice() && "function only supported for host device codegen"); @@ -4440,17 +4766,18 @@ static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, // Clang maps array without bounds as pointers (which we do not // currently do), whereas we treat them as arrays in all cases // currently. - processIndividualMap(mapData, memberDataIdx, combinedInfo, isTargetParams, + processIndividualMap(mapData, memberDataIdx, combinedInfo, targetDirective, mapDataIndex); return; } llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag = mapParentWithMembers(moduleTranslation, builder, ompBuilder, dl, - combinedInfo, mapData, mapDataIndex, isTargetParams); + combinedInfo, mapData, mapDataIndex, + targetDirective); processMapMembersWithParent(moduleTranslation, builder, ompBuilder, dl, combinedInfo, mapData, mapDataIndex, - memberOfParentFlag); + memberOfParentFlag, targetDirective); } // This is a variation on Clang's GenerateOpenMPCapturedVars, which @@ -4528,10 +4855,10 @@ createAlteredByCaptureMap(MapInfoData &mapData, static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, MapInfosTy &combinedInfo, - MapInfoData &mapData, bool isTargetParams = false) { + MapInfoData &mapData, + TargetDirectiveEnumTy targetDirective) { assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() && "function only supported for host device codegen"); - // We wish to modify some of the methods in which arguments are // passed based on their capture type by the target region, this can // involve generating new loads and stores, which changes the @@ -4561,22 +4888,24 @@ static void genMapInfos(llvm::IRBuilderBase &builder, auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]); if (!mapInfoOp.getMembers().empty()) { processMapWithMembersOf(moduleTranslation, builder, *ompBuilder, dl, - combinedInfo, mapData, i, isTargetParams); + combinedInfo, mapData, i, targetDirective); continue; } - processIndividualMap(mapData, i, combinedInfo, isTargetParams); + processIndividualMap(mapData, i, combinedInfo, targetDirective); } } static llvm::Expected<llvm::Function *> emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, - llvm::StringRef mapperFuncName); + llvm::StringRef mapperFuncName, + TargetDirectiveEnumTy targetDirective); static llvm::Expected<llvm::Function *> getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation) { + LLVM::ModuleTranslation &moduleTranslation, + TargetDirectiveEnumTy targetDirective) { assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() && "function only supported for host device codegen"); auto declMapperOp = cast<omp::DeclareMapperOp>(op); @@ -4588,13 +4917,14 @@ getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, return lookupFunc; return emitUserDefinedMapper(declMapperOp, builder, moduleTranslation, - mapperFuncName); + mapperFuncName, targetDirective); } static llvm::Expected<llvm::Function *> emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, - llvm::StringRef mapperFuncName) { + llvm::StringRef mapperFuncName, + TargetDirectiveEnumTy targetDirective) { assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() && "function only supported for host device codegen"); auto declMapperOp = cast<omp::DeclareMapperOp>(op); @@ -4622,10 +4952,11 @@ emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder, MapInfoData mapData; collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl, builder); - genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData); + genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData, + targetDirective); - // Drop the mapping that is no longer necessary so that the same region can - // be processed multiple times. + // Drop the mapping that is no longer necessary so that the same region + // can be processed multiple times. moduleTranslation.forgetMapping(declMapperOp.getRegion()); return combinedInfo; }; @@ -4634,7 +4965,7 @@ emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder, if (!combinedInfo.Mappers[i]) return nullptr; return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder, - moduleTranslation); + moduleTranslation, targetDirective); }; llvm::Expected<llvm::Function *> newFn = ompBuilder->emitUserDefinedMapper( @@ -4655,10 +4986,12 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, SmallVector<Value> useDeviceAddrVars; llvm::omp::RuntimeFunction RTLFn; DataLayout DL = DataLayout(op->getParentOfType<ModuleOp>()); + TargetDirectiveEnumTy targetDirective = getTargetDirectiveEnumTyFromOp(op); llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); - llvm::OpenMPIRBuilder::TargetDataInfo info(/*RequiresDevicePointerInfo=*/true, - /*SeparateBeginEndCalls=*/true); + llvm::OpenMPIRBuilder::TargetDataInfo info( + /*RequiresDevicePointerInfo=*/true, + /*SeparateBeginEndCalls=*/true); bool isTargetDevice = ompBuilder->Config.isTargetDevice(); bool isOffloadEntry = isTargetDevice || !ompBuilder->Config.TargetTriples.empty(); @@ -4757,7 +5090,8 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, MapInfosTy combinedInfo; auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & { builder.restoreIP(codeGenIP); - genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData); + genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData, + targetDirective); return combinedInfo; }; @@ -4873,7 +5207,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, return nullptr; info.HasMapper = true; return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder, - moduleTranslation); + moduleTranslation, targetDirective); }; llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); @@ -4980,15 +5314,18 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) { // TODO: Add support for clauses which are valid for DISTRIBUTE // constructs. Static schedule is the default. - auto schedule = omp::ClauseScheduleKind::Static; - bool isOrdered = false; + bool hasDistSchedule = distributeOp.getDistScheduleStatic(); + auto schedule = hasDistSchedule ? omp::ClauseScheduleKind::Distribute + : omp::ClauseScheduleKind::Static; + // dist_schedule clauses are ordered - otherise this should be false + bool isOrdered = hasDistSchedule; std::optional<omp::ScheduleModifier> scheduleMod; bool isSimd = false; llvm::omp::WorksharingLoopType workshareLoopType = llvm::omp::WorksharingLoopType::DistributeStaticLoop; bool loopNeedsBarrier = false; - llvm::Value *chunk = nullptr; - + llvm::Value *chunk = moduleTranslation.lookupValue( + distributeOp.getDistScheduleChunkSize()); llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation); llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP = @@ -4997,12 +5334,11 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, convertToScheduleKind(schedule), chunk, isSimd, scheduleMod == omp::ScheduleModifier::monotonic, scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered, - workshareLoopType); + workshareLoopType, false, hasDistSchedule, chunk); if (!wsloopIP) return wsloopIP.takeError(); } - if (failed(cleanupPrivateVars(builder, moduleTranslation, distributeOp.getLoc(), privVarsInfo.llvmVars, privVarsInfo.privatizers))) @@ -5135,11 +5471,16 @@ handleDeclareTargetMapVar(MapInfoData &mapData, for (llvm::User *user : userVec) { if (auto *insn = dyn_cast<llvm::Instruction>(user)) { if (insn->getFunction() == func) { - builder.SetCurrentDebugLocation(insn->getDebugLoc()); - auto *load = builder.CreateLoad(mapData.BasePointers[i]->getType(), - mapData.BasePointers[i]); - load->moveBefore(insn->getIterator()); - user->replaceUsesOfWith(mapData.OriginalValue[i], load); + auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]); + llvm::Value *substitute = mapData.BasePointers[i]; + if (isDeclareTargetLink(mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() + : mapOp.getVarPtr())) { + builder.SetCurrentDebugLocation(insn->getDebugLoc()); + substitute = builder.CreateLoad( + mapData.BasePointers[i]->getType(), mapData.BasePointers[i]); + cast<llvm::LoadInst>(substitute)->moveBefore(insn->getIterator()); + } + user->replaceUsesOfWith(mapData.OriginalValue[i], substitute); } } } @@ -5431,8 +5772,8 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, int32_t minTeamsVal = 1, maxTeamsVal = -1; if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) { - // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now, match - // clang and set min and max to the same value. + // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now, + // match clang and set min and max to the same value. if (numTeamsUpper) { if (auto val = extractConstInteger(numTeamsUpper)) minTeamsVal = maxTeamsVal = *val; @@ -5624,9 +5965,9 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>(); auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst); auto &targetRegion = targetOp.getRegion(); - // Holds the private vars that have been mapped along with the block argument - // that corresponds to the MapInfoOp corresponding to the private var in - // question. So, for instance: + // Holds the private vars that have been mapped along with the block + // argument that corresponds to the MapInfoOp corresponding to the private + // var in question. So, for instance: // // %10 = omp.map.info var_ptr(%6#0 : !fir.ref<!fir.box<!fir.heap<i32>>>, ..) // omp.target map_entries(%10 -> %arg0) private(@box.privatizer %6#0-> %arg1) @@ -5641,6 +5982,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, ArrayRef<BlockArgument> mapBlockArgs = argIface.getMapBlockArgs(); ArrayRef<BlockArgument> hdaBlockArgs = argIface.getHasDeviceAddrBlockArgs(); llvm::Function *llvmOutlinedFn = nullptr; + TargetDirectiveEnumTy targetDirective = + getTargetDirectiveEnumTyFromOp(&opInst); // TODO: It can also be false if a compile-time constant `false` IF clause is // specified. @@ -5802,7 +6145,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, auto genMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & { builder.restoreIP(codeGenIP); - genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData, true); + genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData, + targetDirective); return combinedInfos; }; @@ -5882,7 +6226,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, return nullptr; info.HasMapper = true; return getOrCreateUserDefinedMapperFunc(combinedInfos.Mappers[i], builder, - moduleTranslation); + moduleTranslation, targetDirective); }; llvm::Value *ifCond = nullptr; diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index d9891e3..d7d215b 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -34,12 +34,14 @@ #include "llvm/ADT/TypeSwitch.h" #include "llvm/IR/Comdat.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/DebugProgramInstruction.h" #include "llvm/IR/InlineAsm.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Operator.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/ModRef.h" #include <optional> @@ -522,6 +524,11 @@ void ModuleImport::addDebugIntrinsic(llvm::CallInst *intrinsic) { debugIntrinsics.insert(intrinsic); } +void ModuleImport::addDebugRecord(llvm::DbgVariableRecord *dbgRecord) { + if (!dbgRecords.contains(dbgRecord)) + dbgRecords.insert(dbgRecord); +} + static Attribute convertCGProfileModuleFlagValue(ModuleOp mlirModule, llvm::MDTuple *mdTuple) { auto getLLVMFunction = @@ -1214,7 +1221,7 @@ static TypedAttr getScalarConstantAsAttr(OpBuilder &builder, llvm::Constant *constScalar) { MLIRContext *context = builder.getContext(); - // Convert scalar intergers. + // Convert scalar integers. if (auto *constInt = dyn_cast<llvm::ConstantInt>(constScalar)) { return builder.getIntegerAttr( IntegerType::get(context, constInt->getBitWidth()), @@ -2003,9 +2010,15 @@ FloatAttr ModuleImport::matchFloatAttr(llvm::Value *value) { return floatAttr; } -DILocalVariableAttr ModuleImport::matchLocalVariableAttr(llvm::Value *value) { - auto *nodeAsVal = cast<llvm::MetadataAsValue>(value); - auto *node = cast<llvm::DILocalVariable>(nodeAsVal->getMetadata()); +DILocalVariableAttr ModuleImport::matchLocalVariableAttr( + llvm::PointerUnion<llvm::Value *, llvm::DILocalVariable *> valOrVariable) { + llvm::DILocalVariable *node = nullptr; + if (auto *value = dyn_cast<llvm::Value *>(valOrVariable)) { + auto *nodeAsVal = cast<llvm::MetadataAsValue>(value); + node = cast<llvm::DILocalVariable>(nodeAsVal->getMetadata()); + } else { + node = cast<llvm::DILocalVariable *>(valOrVariable); + } return debugImporter->translate(node); } @@ -2544,6 +2557,41 @@ LogicalResult ModuleImport::processInstruction(llvm::Instruction *inst) { if (auto *intrinsic = dyn_cast<llvm::IntrinsicInst>(inst)) return convertIntrinsic(intrinsic); + // Process debug records attached to this instruction. Debug variable records + // are stored for later processing after all SSA values are converted, while + // debug label records can be converted immediately. + if (inst->DebugMarker) { + for (llvm::DbgRecord &dbgRecord : inst->DebugMarker->getDbgRecordRange()) { + // Store debug variable records for later processing. + if (auto *dbgVariableRecord = + dyn_cast<llvm::DbgVariableRecord>(&dbgRecord)) { + addDebugRecord(dbgVariableRecord); + continue; + } + Location loc = translateLoc(dbgRecord.getDebugLoc()); + auto emitUnsupportedWarning = [&]() -> LogicalResult { + if (!emitExpensiveWarnings) + return success(); + std::string options; + llvm::raw_string_ostream optionsStream(options); + dbgRecord.print(optionsStream); + emitWarning(loc) << "unhandled debug record " << optionsStream.str(); + return success(); + }; + // Convert the debug label records in-place. + if (auto *dbgLabelRecord = dyn_cast<llvm::DbgLabelRecord>(&dbgRecord)) { + DILabelAttr labelAttr = + debugImporter->translate(dbgLabelRecord->getLabel()); + if (!labelAttr) + return emitUnsupportedWarning(); + LLVM::DbgLabelOp::create(builder, loc, labelAttr); + continue; + } + // Warn if an unsupported debug record is encountered. + return emitUnsupportedWarning(); + } + } + // Convert all remaining LLVM instructions to MLIR operations. return convertInstruction(inst); } @@ -2579,8 +2627,15 @@ static void processMemoryEffects(llvm::Function *func, LLVMFuncOp funcOp) { memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem)); auto inaccessibleMem = convertModRefInfoFromLLVM( memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem)); - auto memAttr = MemoryEffectsAttr::get(funcOp.getContext(), othermem, argMem, - inaccessibleMem); + auto errnoMem = convertModRefInfoFromLLVM( + memEffects.getModRef(llvm::MemoryEffects::Location::ErrnoMem)); + auto targetMem0 = convertModRefInfoFromLLVM( + memEffects.getModRef(llvm::MemoryEffects::Location::TargetMem0)); + auto targetMem1 = convertModRefInfoFromLLVM( + memEffects.getModRef(llvm::MemoryEffects::Location::TargetMem1)); + auto memAttr = + MemoryEffectsAttr::get(funcOp.getContext(), othermem, argMem, + inaccessibleMem, errnoMem, targetMem0, targetMem1); // Only set the attr when it does not match the default value. if (memAttr.isReadWrite()) return; @@ -2885,8 +2940,15 @@ LogicalResult ModuleImport::convertCallAttributes(llvm::CallInst *inst, memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem)); ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM( memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem)); - auto memAttr = MemoryEffectsAttr::get(op.getContext(), othermem, argMem, - inaccessibleMem); + ModRefInfo errnoMem = convertModRefInfoFromLLVM( + memEffects.getModRef(llvm::MemoryEffects::Location::ErrnoMem)); + ModRefInfo targetMem0 = convertModRefInfoFromLLVM( + memEffects.getModRef(llvm::MemoryEffects::Location::TargetMem0)); + ModRefInfo targetMem1 = convertModRefInfoFromLLVM( + memEffects.getModRef(llvm::MemoryEffects::Location::TargetMem1)); + auto memAttr = + MemoryEffectsAttr::get(op.getContext(), othermem, argMem, inaccessibleMem, + errnoMem, targetMem0, targetMem1); // Only set the attribute when it does not match the default value. if (!memAttr.isReadWrite()) op.setMemoryEffectsAttr(memAttr); @@ -3007,6 +3069,11 @@ LogicalResult ModuleImport::processFunction(llvm::Function *func) { if (failed(processDebugIntrinsics())) return failure(); + // Process the debug records that require a delayed conversion after + // everything else was converted. + if (failed(processDebugRecords())) + return failure(); + return success(); } @@ -3022,61 +3089,32 @@ static bool isMetadataKillLocation(llvm::DbgVariableIntrinsic *dbgIntr) { return !isa<llvm::ValueAsMetadata>(nodeAsVal->getMetadata()); } -LogicalResult -ModuleImport::processDebugIntrinsic(llvm::DbgVariableIntrinsic *dbgIntr, - DominanceInfo &domInfo) { - Location loc = translateLoc(dbgIntr->getDebugLoc()); - auto emitUnsupportedWarning = [&]() { - if (emitExpensiveWarnings) - emitWarning(loc) << "dropped intrinsic: " << diag(*dbgIntr); - return success(); - }; - // Drop debug intrinsics with arg lists. - // TODO: Support debug intrinsics that have arg lists. - if (dbgIntr->hasArgList()) - return emitUnsupportedWarning(); - // Kill locations can have metadata nodes as location operand. This - // cannot be converted to poison as the type cannot be reconstructed. - // TODO: find a way to support this case. - if (isMetadataKillLocation(dbgIntr)) - return emitUnsupportedWarning(); - // Drop debug intrinsics if the associated variable information cannot be - // translated due to cyclic debug metadata. - // TODO: Support cyclic debug metadata. - DILocalVariableAttr localVariableAttr = - matchLocalVariableAttr(dbgIntr->getArgOperand(1)); - if (!localVariableAttr) - return emitUnsupportedWarning(); - FailureOr<Value> argOperand = convertMetadataValue(dbgIntr->getArgOperand(0)); - if (failed(argOperand)) - return emitError(loc) << "failed to convert a debug intrinsic operand: " - << diag(*dbgIntr); - - // Ensure that the debug intrinsic is inserted right after its operand is - // defined. Otherwise, the operand might not necessarily dominate the - // intrinsic. If the defining operation is a terminator, insert the intrinsic - // into a dominated block. - OpBuilder::InsertionGuard guard(builder); - if (Operation *op = argOperand->getDefiningOp(); +/// Ensure that the debug intrinsic is inserted right after the operand +/// definition. Otherwise, the operand might not necessarily dominate the +/// intrinsic. If the defining operation is a terminator, insert the intrinsic +/// into a dominated block. +static LogicalResult setDebugIntrinsicBuilderInsertionPoint( + mlir::OpBuilder &builder, DominanceInfo &domInfo, Value argOperand) { + if (Operation *op = argOperand.getDefiningOp(); op && op->hasTrait<OpTrait::IsTerminator>()) { // Find a dominated block that can hold the debug intrinsic. auto dominatedBlocks = domInfo.getNode(op->getBlock())->children(); // If no block is dominated by the terminator, this intrinisc cannot be // converted. if (dominatedBlocks.empty()) - return emitUnsupportedWarning(); + return failure(); // Set insertion point before the terminator, to avoid inserting something // before landingpads. Block *dominatedBlock = (*dominatedBlocks.begin())->getBlock(); builder.setInsertionPoint(dominatedBlock->getTerminator()); } else { - Value insertPt = *argOperand; - if (auto blockArg = dyn_cast<BlockArgument>(*argOperand)) { + Value insertPt = argOperand; + if (auto blockArg = dyn_cast<BlockArgument>(argOperand)) { // The value might be coming from a phi node and is now a block argument, // which means the insertion point is set to the start of the block. If // this block is a target destination of an invoke, the insertion point // must happen after the landing pad operation. - Block *insertionBlock = argOperand->getParentBlock(); + Block *insertionBlock = argOperand.getParentBlock(); if (!insertionBlock->empty() && isa<LandingpadOp>(insertionBlock->front())) insertPt = cast<LandingpadOp>(insertionBlock->front()).getRes(); @@ -3084,23 +3122,152 @@ ModuleImport::processDebugIntrinsic(llvm::DbgVariableIntrinsic *dbgIntr, builder.setInsertionPointAfterValue(insertPt); } - auto locationExprAttr = - debugImporter->translateExpression(dbgIntr->getExpression()); - Operation *op = - llvm::TypeSwitch<llvm::DbgVariableIntrinsic *, Operation *>(dbgIntr) - .Case([&](llvm::DbgDeclareInst *) { - return LLVM::DbgDeclareOp::create( - builder, loc, *argOperand, localVariableAttr, locationExprAttr); - }) - .Case([&](llvm::DbgValueInst *) { - return LLVM::DbgValueOp::create( - builder, loc, *argOperand, localVariableAttr, locationExprAttr); - }); + return success(); +} + +std::tuple<DILocalVariableAttr, DIExpressionAttr, Value> +ModuleImport::processDebugOpArgumentsAndInsertionPt( + Location loc, + llvm::function_ref<FailureOr<Value>()> convertArgOperandToValue, + llvm::Value *address, + llvm::PointerUnion<llvm::Value *, llvm::DILocalVariable *> variable, + llvm::DIExpression *expression, DominanceInfo &domInfo) { + // Drop debug intrinsics if the associated debug information cannot be + // translated due to an unsupported construct. + DILocalVariableAttr localVarAttr = matchLocalVariableAttr(variable); + if (!localVarAttr) + return {}; + FailureOr<Value> argOperand = convertArgOperandToValue(); + if (failed(argOperand)) { + emitError(loc) << "failed to convert a debug operand: " << diag(*address); + return {}; + } + + if (setDebugIntrinsicBuilderInsertionPoint(builder, domInfo, *argOperand) + .failed()) + return {}; + + return {localVarAttr, debugImporter->translateExpression(expression), + *argOperand}; +} + +LogicalResult +ModuleImport::processDebugIntrinsic(llvm::DbgVariableIntrinsic *dbgIntr, + DominanceInfo &domInfo) { + Location loc = translateLoc(dbgIntr->getDebugLoc()); + auto emitUnsupportedWarning = [&]() { + if (emitExpensiveWarnings) + emitWarning(loc) << "dropped intrinsic: " << diag(*dbgIntr); + return success(); + }; + + OpBuilder::InsertionGuard guard(builder); + auto convertArgOperandToValue = [&]() { + return convertMetadataValue(dbgIntr->getArgOperand(0)); + }; + + // Drop debug intrinsics with an argument list. + // TODO: Support this case. + if (dbgIntr->hasArgList()) + return emitUnsupportedWarning(); + + // Drop debug intrinsics with kill locations that have metadata nodes as + // location operand, which cannot be converted to poison as the type cannot be + // reconstructed. + // TODO: Support this case. + if (isMetadataKillLocation(dbgIntr)) + return emitUnsupportedWarning(); + + auto [localVariableAttr, locationExprAttr, locVal] = + processDebugOpArgumentsAndInsertionPt( + loc, convertArgOperandToValue, dbgIntr->getArgOperand(0), + dbgIntr->getArgOperand(1), dbgIntr->getExpression(), domInfo); + + if (!localVariableAttr) + return emitUnsupportedWarning(); + + if (!locVal) // Expected if localVariableAttr is present. + return failure(); + + Operation *op = nullptr; + if (isa<llvm::DbgDeclareInst>(dbgIntr)) + op = LLVM::DbgDeclareOp::create(builder, loc, locVal, localVariableAttr, + locationExprAttr); + else if (isa<llvm::DbgValueInst>(dbgIntr)) + op = LLVM::DbgValueOp::create(builder, loc, locVal, localVariableAttr, + locationExprAttr); + else + return emitUnsupportedWarning(); + mapNoResultOp(dbgIntr, op); setNonDebugMetadataAttrs(dbgIntr, op); return success(); } +LogicalResult +ModuleImport::processDebugRecord(llvm::DbgVariableRecord &dbgRecord, + DominanceInfo &domInfo) { + OpBuilder::InsertionGuard guard(builder); + Location loc = translateLoc(dbgRecord.getDebugLoc()); + auto emitUnsupportedWarning = [&]() -> LogicalResult { + if (!emitExpensiveWarnings) + return success(); + std::string options; + llvm::raw_string_ostream optionsStream(options); + dbgRecord.print(optionsStream); + emitWarning(loc) << "unhandled debug variable record " + << optionsStream.str(); + return success(); + }; + + // Drop debug records with an argument list. + // TODO: Support this case. + if (dbgRecord.hasArgList()) + return emitUnsupportedWarning(); + + // Drop all other debug records with a address operand that cannot be + // converted to an SSA value such as an empty metadata node. + // TODO: Support this case. + if (!dbgRecord.getAddress()) + return emitUnsupportedWarning(); + + auto convertArgOperandToValue = [&]() -> FailureOr<Value> { + llvm::Value *value = dbgRecord.getAddress(); + + // Return the mapped value if it has been converted before. + auto it = valueMapping.find(value); + if (it != valueMapping.end()) + return it->getSecond(); + + // Convert constants such as immediate values that have no mapping yet. + if (auto *constant = dyn_cast<llvm::Constant>(value)) + return convertConstantExpr(constant); + return failure(); + }; + + auto [localVariableAttr, locationExprAttr, locVal] = + processDebugOpArgumentsAndInsertionPt( + loc, convertArgOperandToValue, dbgRecord.getAddress(), + dbgRecord.getVariable(), dbgRecord.getExpression(), domInfo); + + if (!localVariableAttr) + return emitUnsupportedWarning(); + + if (!locVal) // Expected if localVariableAttr is present. + return failure(); + + if (dbgRecord.isDbgDeclare()) + LLVM::DbgDeclareOp::create(builder, loc, locVal, localVariableAttr, + locationExprAttr); + else if (dbgRecord.isDbgValue()) + LLVM::DbgValueOp::create(builder, loc, locVal, localVariableAttr, + locationExprAttr); + else // isDbgAssign + return emitUnsupportedWarning(); + + return success(); +} + LogicalResult ModuleImport::processDebugIntrinsics() { DominanceInfo domInfo; for (llvm::Instruction *inst : debugIntrinsics) { @@ -3111,6 +3278,15 @@ LogicalResult ModuleImport::processDebugIntrinsics() { return success(); } +LogicalResult ModuleImport::processDebugRecords() { + DominanceInfo domInfo; + for (llvm::DbgVariableRecord *dbgRecord : dbgRecords) + if (failed(processDebugRecord(*dbgRecord, domInfo))) + return failure(); + dbgRecords.clear(); + return success(); +} + LogicalResult ModuleImport::processBasicBlock(llvm::BasicBlock *bb, Block *block) { builder.setInsertionPointToStart(block); diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 64e3c5f..fad9bd6b7 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -588,10 +588,17 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant( } // For integer types, we allow a mismatch in sizes as the index type in // MLIR might have a different size than the index type in the LLVM module. - if (auto intAttr = dyn_cast<IntegerAttr>(attr)) - return llvm::ConstantInt::get( - llvmType, - intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth())); + if (auto intAttr = dyn_cast<IntegerAttr>(attr)) { + // If the attribute is an unsigned integer or a 1-bit integer, zero-extend + // the value to the bit width of the LLVM type. Otherwise, sign-extend. + auto intTy = dyn_cast<IntegerType>(intAttr.getType()); + APInt value; + if (intTy && (intTy.isUnsigned() || intTy.getWidth() == 1)) + value = intAttr.getValue().zextOrTrunc(llvmType->getIntegerBitWidth()); + else + value = intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth()); + return llvm::ConstantInt::get(llvmType, value); + } if (auto floatAttr = dyn_cast<FloatAttr>(attr)) { const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics(); // Special case for 8-bit floats, which are represented by integers due to @@ -677,10 +684,10 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant( } } } - // std::vector is used here to accomodate large number of elements that - // exceed SmallVector capacity. - std::vector<llvm::Constant *> constants(numElements, child); - return llvm::ConstantArray::get(arrayType, constants); + // std::vector is used here to accomodate large number of elements that + // exceed SmallVector capacity. + std::vector<llvm::Constant *> constants(numElements, child); + return llvm::ConstantArray::get(arrayType, constants); } } @@ -892,10 +899,13 @@ void mlir::LLVM::detail::connectPHINodes(Region ®ion, llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall( llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, ArrayRef<llvm::Value *> args, ArrayRef<llvm::Type *> tys) { - llvm::Module *module = builder.GetInsertBlock()->getModule(); - llvm::Function *fn = - llvm::Intrinsic::getOrInsertDeclaration(module, intrinsic, tys); - return builder.CreateCall(fn, args); + return builder.CreateIntrinsic(intrinsic, tys, args); +} + +llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall( + llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, + llvm::Type *retTy, ArrayRef<llvm::Value *> args) { + return builder.CreateIntrinsic(retTy, intrinsic, args); } llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall( @@ -1637,6 +1647,15 @@ static void convertFunctionMemoryAttributes(LLVMFuncOp func, newMemEffects |= llvm::MemoryEffects(llvm::MemoryEffects::Location::Other, convertModRefInfoToLLVM(memEffects.getOther())); + newMemEffects |= + llvm::MemoryEffects(llvm::MemoryEffects::Location::ErrnoMem, + convertModRefInfoToLLVM(memEffects.getErrnoMem())); + newMemEffects |= + llvm::MemoryEffects(llvm::MemoryEffects::Location::TargetMem0, + convertModRefInfoToLLVM(memEffects.getTargetMem0())); + newMemEffects |= + llvm::MemoryEffects(llvm::MemoryEffects::Location::TargetMem1, + convertModRefInfoToLLVM(memEffects.getTargetMem1())); llvmFunc->setMemoryEffects(newMemEffects); } @@ -2122,8 +2141,16 @@ LogicalResult ModuleTranslation::createTBAAMetadata() { // LLVM metadata instances. AttrTypeWalker walker; walker.addWalk([&](TBAARootAttr root) { - tbaaMetadataMapping.insert( - {root, llvm::MDNode::get(ctx, llvm::MDString::get(ctx, root.getId()))}); + llvm::MDNode *node; + if (StringAttr id = root.getId()) { + node = llvm::MDNode::get(ctx, llvm::MDString::get(ctx, id)); + } else { + // Anonymous root nodes are self-referencing. + auto selfRef = llvm::MDNode::getTemporary(ctx, {}); + node = llvm::MDNode::get(ctx, {selfRef.get()}); + node->replaceOperandWith(0, node); + } + tbaaMetadataMapping.insert({root, node}); }); walker.addWalk([&](TBAATypeDescriptorAttr descriptor) { @@ -2254,8 +2281,11 @@ llvm::OpenMPIRBuilder *ModuleTranslation::getOpenMPBuilder() { /* HasRequiresUnifiedSharedMemory = */ false, /* HasRequiresDynamicAllocators = */ false); unsigned int defaultAS = - getLLVMModule()->getDataLayout().getProgramAddressSpace(); + llvmModule->getDataLayout().getProgramAddressSpace(); config.setDefaultTargetAS(defaultAS); + config.setRuntimeCC(llvmModule->getTargetTriple().isSPIRV() + ? llvm::CallingConv::SPIR_FUNC + : llvm::CallingConv::C); ompBuilder->setConfig(std::move(config)); ompBuilder->initialize(); } diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp index c27f9aa..5b04a14 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp @@ -248,6 +248,8 @@ LogicalResult spirv::Deserializer::processInstruction( return processLoopMerge(operands); case spirv::Opcode::OpPhi: return processPhi(operands); + case spirv::Opcode::OpSwitch: + return processSwitch(operands); case spirv::Opcode::OpUndef: return processUndef(operands); default: diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 6492708..50883d9 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -346,6 +346,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) { case spirv::Decoration::Constant: case spirv::Decoration::Invariant: case spirv::Decoration::Patch: + case spirv::Decoration::Coherent: if (words.size() != 2) { return emitError(unknownLoc, "OpDecoration with ") << decorationName << "needs a single target <id>"; @@ -2292,6 +2293,38 @@ LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) { return success(); } +LogicalResult spirv::Deserializer::processSwitch(ArrayRef<uint32_t> operands) { + if (!curBlock) + return emitError(unknownLoc, "OpSwitch must appear in a block"); + + if (operands.size() < 2) + return emitError(unknownLoc, "OpSwitch must at least specify selector and " + "a default target"); + + if (operands.size() % 2) + return emitError(unknownLoc, + "OpSwitch must at have an even number of operands: " + "selector, default target and any number of literal and " + "label <id> pairs"); + + Value selector = getValue(operands[0]); + Block *defaultBlock = getOrCreateBlock(operands[1]); + Location loc = createFileLineColLoc(opBuilder); + + SmallVector<int32_t> literals; + SmallVector<Block *> blocks; + for (unsigned i = 2, e = operands.size(); i < e; i += 2) { + literals.push_back(operands[i]); + blocks.push_back(getOrCreateBlock(operands[i + 1])); + } + + SmallVector<ValueRange> targetOperands(blocks.size(), {}); + spirv::SwitchOp::create(opBuilder, loc, selector, defaultBlock, + ArrayRef<Value>(), literals, blocks, targetOperands); + + return success(); +} + namespace { /// A class for putting all blocks in a structured selection/loop in a /// spirv.mlir.selection/spirv.mlir.loop op. @@ -2799,6 +2832,23 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() { branchCondOp.getFalseBlock()); branchCondOp.erase(); + } else if (auto switchOp = dyn_cast<spirv::SwitchOp>(op)) { + if (target == switchOp.getDefaultTarget()) { + SmallVector<ValueRange> targetOperands(switchOp.getTargetOperands()); + DenseIntElementsAttr literals = + switchOp.getLiterals().value_or(DenseIntElementsAttr()); + spirv::SwitchOp::create( + opBuilder, switchOp.getLoc(), switchOp.getSelector(), + switchOp.getDefaultTarget(), blockArgs, literals, + switchOp.getTargets(), targetOperands); + switchOp.erase(); + } else { + SuccessorRange targets = switchOp.getTargets(); + auto it = llvm::find(targets, target); + assert(it != targets.end()); + size_t index = std::distance(targets.begin(), it); + switchOp.getTargetOperandsMutable(index).assign(blockArgs); + } } else { return emitError(unknownLoc, "unimplemented terminator for Phi creation"); } @@ -2819,7 +2869,7 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() { return success(); } -LogicalResult spirv::Deserializer::splitConditionalBlocks() { +LogicalResult spirv::Deserializer::splitSelectionHeader() { // Create a copy, so we can modify keys in the original. BlockMergeInfoMap blockMergeInfoCopy = blockMergeInfo; for (auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end(); @@ -2836,7 +2886,7 @@ LogicalResult spirv::Deserializer::splitConditionalBlocks() { Operation *terminator = block->getTerminator(); assert(terminator); - if (!isa<spirv::BranchConditionalOp>(terminator)) + if (!isa<spirv::BranchConditionalOp, spirv::SwitchOp>(terminator)) continue; // Check if the current header block is a merge block of another construct. @@ -2846,10 +2896,10 @@ LogicalResult spirv::Deserializer::splitConditionalBlocks() { splitHeaderMergeBlock = true; } - // Do not split a block that only contains a conditional branch, unless it - // is also a merge block of another construct - in that case we want to - // split the block. We do not want two constructs to share header / merge - // block. + // Do not split a block that only contains a conditional branch / switch, + // unless it is also a merge block of another construct - in that case we + // want to split the block. We do not want two constructs to share header / + // merge block. if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) { Block *newBlock = block->splitBlock(terminator); OpBuilder builder(block, block->end()); @@ -2887,13 +2937,10 @@ LogicalResult spirv::Deserializer::structurizeControlFlow() { logger.startLine() << "\n"; }); - if (failed(splitConditionalBlocks())) { + if (failed(splitSelectionHeader())) { return failure(); } - // TODO: This loop is non-deterministic. Iteration order may vary between runs - // for the same shader as the key to the map is a pointer. See: - // https://github.com/llvm/llvm-project/issues/128547 while (!blockMergeInfo.empty()) { Block *headerBlock = blockMergeInfo.begin()->first; BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second; diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h index 6027f1a..50c9350 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h @@ -58,7 +58,9 @@ struct DebugLine { }; /// Map from a selection/loop's header block to its merge (and continue) target. -using BlockMergeInfoMap = DenseMap<Block *, BlockMergeInfo>; +/// Use `MapVector<>` to ensure a deterministic iteration order with a pointer +/// key. +using BlockMergeInfoMap = llvm::MapVector<Block *, BlockMergeInfo>; /// A "deferred struct type" is a struct type with one or more member types not /// known when the Deserializer first encounters the struct. This happens, for @@ -278,11 +280,11 @@ private: return opBuilder.getStringAttr(attrName); } - /// Move a conditional branch into a separate basic block to avoid unnecessary - /// sinking of defs that may be required outside a selection region. This - /// function also ensures that a single block cannot be a header block of one - /// selection construct and the merge block of another. - LogicalResult splitConditionalBlocks(); + /// Move a conditional branch or a switch into a separate basic block to avoid + /// unnecessary sinking of defs that may be required outside a selection + /// region. This function also ensures that a single block cannot be a header + /// block of one selection construct and the merge block of another. + LogicalResult splitSelectionHeader(); //===--------------------------------------------------------------------===// // Type @@ -472,6 +474,9 @@ private: /// Processes a SPIR-V OpPhi instruction with the given `operands`. LogicalResult processPhi(ArrayRef<uint32_t> operands); + /// Processes a SPIR-V OpSwitch instruction with the given `operands`. + LogicalResult processSwitch(ArrayRef<uint32_t> operands); + /// Creates block arguments on predecessors previously recorded when handling /// OpPhi instructions. LogicalResult wireUpBlockArgument(); diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp index 85e92c7..6397d2c 100644 --- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp @@ -775,6 +775,27 @@ LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) { return success(); } +LogicalResult Serializer::processSwitchOp(spirv::SwitchOp switchOp) { + uint32_t selectorID = getValueID(switchOp.getSelector()); + uint32_t defaultLabelID = getOrCreateBlockID(switchOp.getDefaultTarget()); + SmallVector<uint32_t> arguments{selectorID, defaultLabelID}; + + std::optional<mlir::DenseIntElementsAttr> literals = switchOp.getLiterals(); + BlockRange targets = switchOp.getTargets(); + if (literals) { + for (auto [literal, target] : llvm::zip_equal(*literals, targets)) { + arguments.push_back(literal.getLimitedValue()); + uint32_t targetLabelID = getOrCreateBlockID(target); + arguments.push_back(targetLabelID); + } + } + + if (failed(emitDebugLine(functionBody, switchOp.getLoc()))) + return failure(); + encodeInstructionInto(functionBody, spirv::Opcode::OpSwitch, arguments); + return success(); +} + LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) { auto varName = addressOfOp.getVariable(); auto variableID = getVariableID(varName); diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index 29ed5a4..c879a2b 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -373,6 +373,7 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID, case spirv::Decoration::Block: case spirv::Decoration::Invariant: case spirv::Decoration::Patch: + case spirv::Decoration::Coherent: // For unit attributes and decoration attributes, the args list // has no values so we do nothing. if (isa<UnitAttr, DecorationAttr>(attr)) @@ -1443,7 +1444,20 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) { assert(branchCondOp.getFalseTarget() == block); blockOperands = branchCondOp.getFalseTargetOperands(); } - + assert(!blockOperands->empty() && + "expected non-empty block operand range"); + predecessors.emplace_back(spirvPredecessor, *blockOperands); + } else if (auto switchOp = dyn_cast<spirv::SwitchOp>(terminator)) { + std::optional<OperandRange> blockOperands; + if (block == switchOp.getDefaultTarget()) { + blockOperands = switchOp.getDefaultOperands(); + } else { + SuccessorRange targets = switchOp.getTargets(); + auto it = llvm::find(targets, block); + assert(it != targets.end()); + size_t index = std::distance(targets.begin(), it); + blockOperands = switchOp.getTargetOperands(index); + } assert(!blockOperands->empty() && "expected non-empty block operand range"); predecessors.emplace_back(spirvPredecessor, *blockOperands); @@ -1579,6 +1593,7 @@ LogicalResult Serializer::processOperation(Operation *opInst) { .Case([&](spirv::SpecConstantOperationOp op) { return processSpecConstantOperationOp(op); }) + .Case([&](spirv::SwitchOp op) { return processSwitchOp(op); }) .Case([&](spirv::UndefOp op) { return processUndefOp(op); }) .Case([&](spirv::VariableOp op) { return processVariableOp(op); }) diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h index add372b..6e79c13 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h @@ -304,6 +304,8 @@ private: LogicalResult processBranchOp(spirv::BranchOp branchOp); + LogicalResult processSwitchOp(spirv::SwitchOp switchOp); + //===--------------------------------------------------------------------===// // Operations //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/PDLL/AST/Context.cpp b/mlir/lib/Tools/PDLL/AST/Context.cpp index 6f2e4cd..e82807f 100644 --- a/mlir/lib/Tools/PDLL/AST/Context.cpp +++ b/mlir/lib/Tools/PDLL/AST/Context.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Tools/PDLL/AST/Context.h" -#include "TypeDetail.h" +#include "mlir/Tools/PDLL/AST/Types.h" using namespace mlir; using namespace mlir::pdll::ast; diff --git a/mlir/lib/Tools/PDLL/AST/Nodes.cpp b/mlir/lib/Tools/PDLL/AST/Nodes.cpp index 5aa0937..4358ceb 100644 --- a/mlir/lib/Tools/PDLL/AST/Nodes.cpp +++ b/mlir/lib/Tools/PDLL/AST/Nodes.cpp @@ -21,7 +21,7 @@ static StringRef copyStringWithNull(Context &ctx, StringRef str) { return str; char *data = ctx.getAllocator().Allocate<char>(str.size() + 1); - std::copy(str.begin(), str.end(), data); + llvm::copy(str, data); data[str.size()] = 0; return StringRef(data, str.size()); } diff --git a/mlir/lib/Tools/PDLL/AST/TypeDetail.h b/mlir/lib/Tools/PDLL/AST/TypeDetail.h deleted file mode 100644 index a0bd84e..0000000 --- a/mlir/lib/Tools/PDLL/AST/TypeDetail.h +++ /dev/null @@ -1,141 +0,0 @@ -//===- TypeDetail.h ---------------------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef LIB_MLIR_TOOLS_PDLL_AST_TYPEDETAIL_H_ -#define LIB_MLIR_TOOLS_PDLL_AST_TYPEDETAIL_H_ - -#include "mlir/Tools/PDLL/AST/Types.h" - -namespace mlir { -namespace pdll { -namespace ast { -//===----------------------------------------------------------------------===// -// Type -//===----------------------------------------------------------------------===// - -struct Type::Storage : public StorageUniquer::BaseStorage { - Storage(TypeID typeID) : typeID(typeID) {} - - /// The type identifier for the derived type class. - TypeID typeID; -}; - -namespace detail { - -/// A utility CRTP base class that defines many of the necessary utilities for -/// defining a PDLL AST Type. -template <typename ConcreteT, typename KeyT = void> -struct TypeStorageBase : public Type::Storage { - using KeyTy = KeyT; - using Base = TypeStorageBase<ConcreteT, KeyT>; - TypeStorageBase(KeyTy key) - : Type::Storage(TypeID::get<ConcreteT>()), key(key) {} - - /// Construct an instance with the given storage allocator. - static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc, - const KeyTy &key) { - return new (alloc.allocate<ConcreteT>()) ConcreteT(key); - } - - /// Utility methods required by the storage allocator. - bool operator==(const KeyTy &key) const { return this->key == key; } - - /// Return the key value of this storage class. - const KeyTy &getValue() const { return key; } - -protected: - KeyTy key; -}; -/// A specialization of the storage base for singleton types. -template <typename ConcreteT> -struct TypeStorageBase<ConcreteT, void> : public Type::Storage { - using Base = TypeStorageBase<ConcreteT, void>; - TypeStorageBase() : Type::Storage(TypeID::get<ConcreteT>()) {} -}; - -//===----------------------------------------------------------------------===// -// AttributeType -//===----------------------------------------------------------------------===// - -struct AttributeTypeStorage : public TypeStorageBase<AttributeTypeStorage> {}; - -//===----------------------------------------------------------------------===// -// ConstraintType -//===----------------------------------------------------------------------===// - -struct ConstraintTypeStorage : public TypeStorageBase<ConstraintTypeStorage> {}; - -//===----------------------------------------------------------------------===// -// OperationType -//===----------------------------------------------------------------------===// - -struct OperationTypeStorage - : public TypeStorageBase<OperationTypeStorage, - std::pair<StringRef, const ods::Operation *>> { - using Base::Base; - - static OperationTypeStorage * - construct(StorageUniquer::StorageAllocator &alloc, - const std::pair<StringRef, const ods::Operation *> &key) { - return new (alloc.allocate<OperationTypeStorage>()) OperationTypeStorage( - std::make_pair(alloc.copyInto(key.first), key.second)); - } -}; - -//===----------------------------------------------------------------------===// -// RangeType -//===----------------------------------------------------------------------===// - -struct RangeTypeStorage : public TypeStorageBase<RangeTypeStorage, Type> { - using Base::Base; -}; - -//===----------------------------------------------------------------------===// -// RewriteType -//===----------------------------------------------------------------------===// - -struct RewriteTypeStorage : public TypeStorageBase<RewriteTypeStorage> {}; - -//===----------------------------------------------------------------------===// -// TupleType -//===----------------------------------------------------------------------===// - -struct TupleTypeStorage - : public TypeStorageBase<TupleTypeStorage, - std::pair<ArrayRef<Type>, ArrayRef<StringRef>>> { - using Base::Base; - - static TupleTypeStorage * - construct(StorageUniquer::StorageAllocator &alloc, - std::pair<ArrayRef<Type>, ArrayRef<StringRef>> key) { - SmallVector<StringRef> names = llvm::to_vector(llvm::map_range( - key.second, [&](StringRef name) { return alloc.copyInto(name); })); - return new (alloc.allocate<TupleTypeStorage>()) - TupleTypeStorage(std::make_pair(alloc.copyInto(key.first), - alloc.copyInto(llvm::ArrayRef(names)))); - } -}; - -//===----------------------------------------------------------------------===// -// TypeType -//===----------------------------------------------------------------------===// - -struct TypeTypeStorage : public TypeStorageBase<TypeTypeStorage> {}; - -//===----------------------------------------------------------------------===// -// ValueType -//===----------------------------------------------------------------------===// - -struct ValueTypeStorage : public TypeStorageBase<ValueTypeStorage> {}; - -} // namespace detail -} // namespace ast -} // namespace pdll -} // namespace mlir - -#endif // LIB_MLIR_TOOLS_PDLL_AST_TYPEDETAIL_H_ diff --git a/mlir/lib/Tools/PDLL/AST/Types.cpp b/mlir/lib/Tools/PDLL/AST/Types.cpp index 1468ac2..d5497b0 100644 --- a/mlir/lib/Tools/PDLL/AST/Types.cpp +++ b/mlir/lib/Tools/PDLL/AST/Types.cpp @@ -7,7 +7,6 @@ //===----------------------------------------------------------------------===// #include "mlir/Tools/PDLL/AST/Types.h" -#include "TypeDetail.h" #include "mlir/Tools/PDLL/AST/Context.h" #include <optional> diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp index 9ef405d..018a188 100644 --- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp @@ -681,17 +681,8 @@ processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer, return success(); } -std::pair<std::string, std::string> -mlir::registerAndParseCLIOptions(int argc, char **argv, - llvm::StringRef toolName, - DialectRegistry ®istry) { - static cl::opt<std::string> inputFilename( - cl::Positional, cl::desc("<input file>"), cl::init("-")); - - static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"), - cl::value_desc("filename"), - cl::init("-")); - // Register any command line options. +std::string mlir::registerCLIOptions(llvm::StringRef toolName, + DialectRegistry ®istry) { MlirOptMainConfig::registerCLOptions(registry); registerAsmPrinterCLOptions(); registerMLIRContextCLOptions(); @@ -706,11 +697,29 @@ mlir::registerAndParseCLIOptions(int argc, char **argv, interleaveComma(registry.getDialectNames(), os, [&](auto name) { os << name; }); } - // Parse pass names in main to ensure static initialization completed. + return helpHeader; +} + +std::pair<std::string, std::string> +mlir::parseCLIOptions(int argc, char **argv, llvm::StringRef helpHeader) { + static cl::opt<std::string> inputFilename( + cl::Positional, cl::desc("<input file>"), cl::init("-")); + + static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"), + cl::value_desc("filename"), + cl::init("-")); cl::ParseCommandLineOptions(argc, argv, helpHeader); return std::make_pair(inputFilename.getValue(), outputFilename.getValue()); } +std::pair<std::string, std::string> +mlir::registerAndParseCLIOptions(int argc, char **argv, + llvm::StringRef toolName, + DialectRegistry ®istry) { + auto helpHeader = registerCLIOptions(toolName, registry); + return parseCLIOptions(argc, argv, helpHeader); +} + static LogicalResult printRegisteredDialects(DialectRegistry ®istry) { llvm::outs() << "Available Dialects: "; interleave(registry.getDialectNames(), llvm::outs(), ","); diff --git a/mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp b/mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp index 685e794..64e86f2 100644 --- a/mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp +++ b/mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp @@ -153,5 +153,12 @@ int mlir::MlirTblgenMain(int argc, char **argv) { cl::ParseCommandLineOptions(argc, argv); - return TableGenMain(argv[0], &mlirTableGenMain); + return TableGenMain( + argv[0], [](TableGenOutputFiles &OutFiles, const RecordKeeper &RK) { + std::string S; + raw_string_ostream OS(S); + bool Res = mlirTableGenMain(OS, RK); + OutFiles = {S, {}}; + return Res; + }); } diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt index 54b67f5..8907724 100644 --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -27,6 +27,7 @@ add_mlir_library(MLIRTransforms DEPENDS MLIRTransformsPassIncGen + MLIRTransformsDialectInterfaceIncGen LINK_LIBS PUBLIC MLIRAnalysis @@ -39,4 +40,5 @@ add_mlir_library(MLIRTransforms MLIRSideEffectInterfaces MLIRSupport MLIRTransformUtils + MLIRUBDialect ) diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index 41f3f9d..e9ced064 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -33,6 +33,7 @@ #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "mlir/Analysis/DataFlow/LivenessAnalysis.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Dialect.h" @@ -260,6 +261,22 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) { static void processSimpleOp(Operation *op, RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, RDVFinalCleanupList &cl) { + // Operations that have dead operands can be erased regardless of their + // side effects. The liveness analysis would not have marked an SSA value as + // "dead" if it had a side-effecting user that is reachable. + bool hasDeadOperand = + markLives(op->getOperands(), nonLiveSet, la).flip().any(); + if (hasDeadOperand) { + LDBG() << "Simple op has dead operands, so the op must be dead: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); + assert(!hasLive(op->getResults(), nonLiveSet, la) && + "expected the op to have no live results"); + cl.operations.push_back(op); + collectNonLiveValues(nonLiveSet, op->getResults(), + BitVector(op->getNumResults(), true)); + return; + } + if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) { LDBG() << "Simple op is not memory effect free or has live results, " "preserving it: " @@ -361,6 +378,8 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module, // block other than the entry block, because every block has a terminator. for (Block &block : funcOp.getBlocks()) { Operation *returnOp = block.getTerminator(); + if (!returnOp->hasTrait<OpTrait::ReturnLike>()) + continue; if (returnOp && returnOp->getNumOperands() == numReturns) cl.operands.push_back({returnOp, nonLiveRets}); } @@ -700,7 +719,11 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, } /// Steps to process a `BranchOpInterface` operation: -/// Iterate through each successor block of `branchOp`. +/// +/// When a non-forwarded operand is dead (e.g., the condition value of a +/// conditional branch op), the entire operation is dead. +/// +/// Otherwise, iterate through each successor block of `branchOp`. /// (1) For each successor block, gather all operands from all successors. /// (2) Fetch their associated liveness analysis data and collect for future /// removal. @@ -711,7 +734,22 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, RDVFinalCleanupList &cl) { LDBG() << "Processing branch op: " << *branchOp; + + // Check for dead non-forwarded operands. + BitVector deadNonForwardedOperands = + markLives(branchOp->getOperands(), nonLiveSet, la).flip(); unsigned numSuccessors = branchOp->getNumSuccessors(); + for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) { + SuccessorOperands successorOperands = + branchOp.getSuccessorOperands(succIdx); + // Remove all non-forwarded operands from the bit vector. + for (OpOperand &opOperand : successorOperands.getMutableForwardedOperands()) + deadNonForwardedOperands[opOperand.getOperandNumber()] = false; + } + if (deadNonForwardedOperands.any()) { + cl.operations.push_back(branchOp.getOperation()); + return; + } for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) { Block *successorBlock = branchOp->getSuccessor(succIdx); @@ -742,23 +780,70 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la, static void cleanUpDeadVals(RDVFinalCleanupList &list) { LDBG() << "Starting cleanup of dead values..."; - // 1. Operations + // 1. Blocks, We must remove the block arguments and successor operands before + // deleting the operation, as they may reside in the region operation. + LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists"; + for (auto &b : list.blocks) { + // blocks that are accessed via multiple codepaths processed once + if (b.b->getNumArguments() != b.nonLiveArgs.size()) + continue; + LDBG() << "Erasing " << b.nonLiveArgs.count() + << " non-live arguments from block: " << b.b; + // it iterates backwards because erase invalidates all successor indexes + for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) { + if (!b.nonLiveArgs[i]) + continue; + LDBG() << " Erasing block argument " << i << ": " << b.b->getArgument(i); + b.b->getArgument(i).dropAllUses(); + b.b->eraseArgument(i); + } + } + + // 2. Successor Operands + LDBG() << "Cleaning up " << list.successorOperands.size() + << " successor operand lists"; + for (auto &op : list.successorOperands) { + SuccessorOperands successorOperands = + op.branch.getSuccessorOperands(op.successorIndex); + // blocks that are accessed via multiple codepaths processed once + if (successorOperands.size() != op.nonLiveOperands.size()) + continue; + LDBG() << "Erasing " << op.nonLiveOperands.count() + << " non-live successor operands from successor " + << op.successorIndex << " of branch: " + << OpWithFlags(op.branch, OpPrintingFlags().skipRegions()); + // it iterates backwards because erase invalidates all successor indexes + for (int i = successorOperands.size() - 1; i >= 0; --i) { + if (!op.nonLiveOperands[i]) + continue; + LDBG() << " Erasing successor operand " << i << ": " + << successorOperands[i]; + successorOperands.erase(i); + } + } + + // 3. Operations LDBG() << "Cleaning up " << list.operations.size() << " operations"; - for (auto &op : list.operations) { + for (Operation *op : list.operations) { LDBG() << "Erasing operation: " << OpWithFlags(op, OpPrintingFlags().skipRegions()); + if (op->hasTrait<OpTrait::IsTerminator>()) { + // When erasing a terminator, insert an unreachable op in its place. + OpBuilder b(op); + ub::UnreachableOp::create(b, op->getLoc()); + } op->dropAllUses(); op->erase(); } - // 2. Values + // 4. Values LDBG() << "Cleaning up " << list.values.size() << " values"; for (auto &v : list.values) { LDBG() << "Dropping all uses of value: " << v; v.dropAllUses(); } - // 3. Functions + // 5. Functions LDBG() << "Cleaning up " << list.functions.size() << " functions"; // Record which function arguments were erased so we can shrink call-site // argument segments for CallOpInterface operations (e.g. ops using @@ -780,7 +865,7 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { (void)f.funcOp.eraseResults(f.nonLiveRets); } - // 4. Operands + // 6. Operands LDBG() << "Cleaning up " << list.operands.size() << " operand lists"; for (OperationToCleanup &o : list.operands) { // Handle call-specific cleanup only when we have a cached callee reference. @@ -822,7 +907,7 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { } } - // 5. Results + // 7. Results LDBG() << "Cleaning up " << list.results.size() << " result lists"; for (auto &r : list.results) { LDBG() << "Erasing " << r.nonLive.count() @@ -830,48 +915,6 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { << OpWithFlags(r.op, OpPrintingFlags().skipRegions()); dropUsesAndEraseResults(r.op, r.nonLive); } - - // 6. Blocks - LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists"; - for (auto &b : list.blocks) { - // blocks that are accessed via multiple codepaths processed once - if (b.b->getNumArguments() != b.nonLiveArgs.size()) - continue; - LDBG() << "Erasing " << b.nonLiveArgs.count() - << " non-live arguments from block: " << b.b; - // it iterates backwards because erase invalidates all successor indexes - for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) { - if (!b.nonLiveArgs[i]) - continue; - LDBG() << " Erasing block argument " << i << ": " << b.b->getArgument(i); - b.b->getArgument(i).dropAllUses(); - b.b->eraseArgument(i); - } - } - - // 7. Successor Operands - LDBG() << "Cleaning up " << list.successorOperands.size() - << " successor operand lists"; - for (auto &op : list.successorOperands) { - SuccessorOperands successorOperands = - op.branch.getSuccessorOperands(op.successorIndex); - // blocks that are accessed via multiple codepaths processed once - if (successorOperands.size() != op.nonLiveOperands.size()) - continue; - LDBG() << "Erasing " << op.nonLiveOperands.count() - << " non-live successor operands from successor " - << op.successorIndex << " of branch: " - << OpWithFlags(op.branch, OpPrintingFlags().skipRegions()); - // it iterates backwards because erase invalidates all successor indexes - for (int i = successorOperands.size() - 1; i >= 0; --i) { - if (!op.nonLiveOperands[i]) - continue; - LDBG() << " Erasing successor operand " << i << ": " - << successorOperands[i]; - successorOperands.erase(i); - } - } - LDBG() << "Finished cleanup of dead values"; } diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 9945a71..09ad423 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -25,6 +25,7 @@ #include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/ScopedPrinter.h" #include <optional> +#include <utility> using namespace mlir; using namespace mlir::detail; @@ -975,9 +976,12 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues); /// Replace the uses of the given value with the given values. The specified - /// converter is used to build materializations (if necessary). - void replaceAllUsesWith(Value from, ValueRange to, - const TypeConverter *converter); + /// converter is used to build materializations (if necessary). If `functor` + /// is specified, only the uses that the functor returns "true" for are + /// replaced. + void replaceValueUses(Value from, ValueRange to, + const TypeConverter *converter, + function_ref<bool(OpOperand &)> functor = nullptr); /// Erase the given block and its contents. void eraseBlock(Block *block); @@ -1051,7 +1055,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { MLIRContext *context, std::function<void(Operation *)> opErasedCallback = nullptr) : RewriterBase(context, /*listener=*/this), - opErasedCallback(opErasedCallback) {} + opErasedCallback(std::move(opErasedCallback)) {} /// Erase the given op (unless it was already erased). void eraseOp(Operation *op) override { @@ -1202,11 +1206,16 @@ void BlockTypeConversionRewrite::rollback() { } /// Replace all uses of `from` with `repl`. -static void performReplaceValue(RewriterBase &rewriter, Value from, - Value repl) { +static void +performReplaceValue(RewriterBase &rewriter, Value from, Value repl, + function_ref<bool(OpOperand &)> functor = nullptr) { if (isa<BlockArgument>(repl)) { // `repl` is a block argument. Directly replace all uses. - rewriter.replaceAllUsesWith(from, repl); + if (functor) { + rewriter.replaceUsesWithIf(from, repl, functor); + } else { + rewriter.replaceAllUsesWith(from, repl); + } return; } @@ -1237,7 +1246,11 @@ static void performReplaceValue(RewriterBase &rewriter, Value from, Block *replBlock = replOp->getBlock(); rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) { Operation *user = operand.getOwner(); - return user->getBlock() != replBlock || replOp->isBeforeInBlock(user); + bool result = + user->getBlock() != replBlock || replOp->isBeforeInBlock(user); + if (result && functor) + result &= functor(operand); + return result; }); } @@ -1645,7 +1658,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( /*outputTypes=*/origArgType, /*originalType=*/Type(), converter, /*isPureTypeConversion=*/false) .front(); - replaceAllUsesWith(origArg, mat, converter); + replaceValueUses(origArg, mat, converter); continue; } @@ -1654,14 +1667,14 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( assert(inputMap->size == 0 && "invalid to provide a replacement value when the argument isn't " "dropped"); - replaceAllUsesWith(origArg, inputMap->replacementValues, converter); + replaceValueUses(origArg, inputMap->replacementValues, converter); continue; } // This is a 1->1+ mapping. auto replArgs = newBlock->getArguments().slice(inputMap->inputNo, inputMap->size); - replaceAllUsesWith(origArg, replArgs, converter); + replaceValueUses(origArg, replArgs, converter); } if (config.allowPatternRollback) @@ -1961,8 +1974,24 @@ void ConversionPatternRewriterImpl::replaceOp( op->walk([&](Operation *op) { replacedOps.insert(op); }); } -void ConversionPatternRewriterImpl::replaceAllUsesWith( - Value from, ValueRange to, const TypeConverter *converter) { +void ConversionPatternRewriterImpl::replaceValueUses( + Value from, ValueRange to, const TypeConverter *converter, + function_ref<bool(OpOperand &)> functor) { + LLVM_DEBUG({ + logger.startLine() << "** Replace Value : '" << from << "'"; + if (auto blockArg = dyn_cast<BlockArgument>(from)) { + if (Operation *parentOp = blockArg.getOwner()->getParentOp()) { + logger.getOStream() << " (in region of '" << parentOp->getName() + << "' (" << parentOp << ")"; + } else { + logger.getOStream() << " (unlinked block)"; + } + } + if (functor) { + logger.getOStream() << ", conditional replacement"; + } + }); + if (!config.allowPatternRollback) { SmallVector<Value> toConv = llvm::to_vector(to); SmallVector<Value> repls = @@ -1972,7 +2001,7 @@ void ConversionPatternRewriterImpl::replaceAllUsesWith( if (!repl) return; - performReplaceValue(r, from, repl); + performReplaceValue(r, from, repl, functor); return; } @@ -1991,6 +2020,9 @@ void ConversionPatternRewriterImpl::replaceAllUsesWith( replacedValues.insert(from); #endif // NDEBUG + if (functor) + llvm::report_fatal_error( + "conditional value replacement is not supported in rollback mode"); mapping.map(from, to); appendRewrite<ReplaceValueRewrite>(from, converter); } @@ -2189,18 +2221,15 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes( } void ConversionPatternRewriter::replaceAllUsesWith(Value from, ValueRange to) { - LLVM_DEBUG({ - impl->logger.startLine() << "** Replace Value : '" << from << "'"; - if (auto blockArg = dyn_cast<BlockArgument>(from)) { - if (Operation *parentOp = blockArg.getOwner()->getParentOp()) { - impl->logger.getOStream() << " (in region of '" << parentOp->getName() - << "' (" << parentOp << ")\n"; - } else { - impl->logger.getOStream() << " (unlinked block)\n"; - } - } - }); - impl->replaceAllUsesWith(from, to, impl->currentTypeConverter); + impl->replaceValueUses(from, to, impl->currentTypeConverter); +} + +void ConversionPatternRewriter::replaceUsesWithIf( + Value from, ValueRange to, function_ref<bool(OpOperand &)> functor, + bool *allUsesReplaced) { + assert(!allUsesReplaced && + "allUsesReplaced is not supported in a dialect conversion"); + impl->replaceValueUses(from, to, impl->currentTypeConverter, functor); } Value ConversionPatternRewriter::getRemappedValue(Value key) { @@ -2765,7 +2794,7 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) { rewriterImpl.patternMaterializations.clear(); #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS // Expensive pattern check that can detect API violations. - if (checkOp) { + if (checkOp && topLevelFingerPrint) { OperationFingerPrint fingerPrintAfterPattern(checkOp); if (fingerPrintAfterPattern != *topLevelFingerPrint) llvm::report_fatal_error("pattern '" + pattern.getDebugName() + diff --git a/mlir/lib/Transforms/Utils/Inliner.cpp b/mlir/lib/Transforms/Utils/Inliner.cpp index 26c965c..4095031 100644 --- a/mlir/lib/Transforms/Utils/Inliner.cpp +++ b/mlir/lib/Transforms/Utils/Inliner.cpp @@ -613,8 +613,8 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, LLVM_DEBUG({ LDBG() << "* Inliner: Initial calls in SCC are: {"; - for (unsigned i = 0, e = calls.size(); i < e; ++i) - LDBG() << " " << i << ". " << calls[i].call << ","; + for (unsigned I = 0, E = calls.size(); I < E; ++I) + LDBG() << " " << I << ". " << calls[I].call << ","; LDBG() << "}"; }); diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp index 31ae1d1..330a2d3 100644 --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -1149,9 +1149,8 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter, // Remove the values that already dominate the insertion point. SmallVector<Value> prunedValues; for (auto value : values) { - if (dominance.properlyDominates(value, insertionPoint)) { + if (dominance.properlyDominates(value, insertionPoint)) continue; - } // Block arguments are not supported. if (isa<BlockArgument>(value)) { return rewriter.notifyMatchFailure( @@ -1178,8 +1177,13 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter, // Since current support is to only move within a same basic block, // the slices dont need to look past block arguments. options.omitBlockArguments = true; + bool dependsOnSideEffectingOp = false; options.filter = [&](Operation *sliceBoundaryOp) { - return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint); + bool mustMove = + !dominance.properlyDominates(sliceBoundaryOp, insertionPoint); + if (mustMove && !isPure(sliceBoundaryOp)) + dependsOnSideEffectingOp = true; + return mustMove; }; llvm::SetVector<Operation *> slice; for (auto value : prunedValues) { @@ -1188,6 +1192,10 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter, (void)result; } + // Check if any operation in the slice is side-effecting. + if (dependsOnSideEffectingOp) + return failure(); + // If the slice contains `insertionPoint` cannot move the dependencies. if (slice.contains(insertionPoint)) { return rewriter.notifyMatchFailure( @@ -1198,9 +1206,8 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter, // Sort operations topologically before moving. mlir::topologicalSort(slice); - for (Operation *op : slice) { + for (Operation *op : slice) rewriter.moveOpBefore(op, insertionPoint); - } return success(); } diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 51c7576..2acb6ee 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -895,24 +895,8 @@ add_mlir_python_common_capi_library(MLIRPythonCAPI ################################################################################ _flatten_mlir_python_targets(mlir_python_sources_deps MLIRPythonSources) -add_custom_target("mlir-python-sources" DEPENDS ${mlir_python_sources_deps}) -if(NOT LLVM_ENABLE_IDE) - add_llvm_install_targets(install-mlir-python-sources - DEPENDS mlir-python-sources - COMPONENT mlir-python-sources - ) -endif() - -set(_mlir_python_stubgen_enabled ON) -# Stubgen doesn't work when cross-compiling (stubgen will run in the host interpreter and then fail -# to find the extension module for the host arch). -# Note: Stubgen requires some extra handling to work properly when sanitizers are enabled, -# so we skip running it in that case now. -if(CMAKE_CROSSCOMPILING OR (NOT LLVM_USE_SANITIZER STREQUAL "")) - set(_mlir_python_stubgen_enabled OFF) -endif() -if(_mlir_python_stubgen_enabled) +if(MLIR_PYTHON_STUBGEN_ENABLED) # _mlir stubgen # Note: All this needs to come before add_mlir_python_modules(MLIRPythonModules so that the install targets for the # generated type stubs get created. @@ -965,6 +949,7 @@ if(_mlir_python_stubgen_enabled) ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}/type_stubs" SOURCES "${_core_type_stub_sources}" ) + list(APPEND mlir_python_sources_deps MLIRPythonExtension.Core.type_stub_gen) # _mlirPythonTestNanobind stubgen @@ -995,13 +980,21 @@ if(_mlir_python_stubgen_enabled) endif() endif() +add_custom_target("mlir-python-sources" DEPENDS ${mlir_python_sources_deps}) +if(NOT LLVM_ENABLE_IDE) + add_llvm_install_targets(install-mlir-python-sources + DEPENDS mlir-python-sources + COMPONENT mlir-python-sources + ) +endif() + ################################################################################ # The fully assembled package of modules. # This must come last. ################################################################################ set(_declared_sources MLIRPythonSources MLIRPythonExtension.RegisterEverything) -if(_mlir_python_stubgen_enabled) +if(MLIR_PYTHON_STUBGEN_ENABLED) list(APPEND _declared_sources MLIRPythonExtension.Core.type_stub_gen) endif() @@ -1014,7 +1007,7 @@ add_mlir_python_modules(MLIRPythonModules COMMON_CAPI_LINK_LIBS MLIRPythonCAPI ) -if(_mlir_python_stubgen_enabled) +if(MLIR_PYTHON_STUBGEN_ENABLED) add_dependencies(MLIRPythonModules "${_mlir_typestub_gen_target}") if(MLIR_INCLUDE_TESTS) add_dependencies(MLIRPythonModules "${_mlirPythonTestNanobind_typestub_gen_target}") diff --git a/mlir/python/mlir/dialects/gpu/__init__.py b/mlir/python/mlir/dialects/gpu/__init__.py index 2fbcbb0..d15643c 100644 --- a/mlir/python/mlir/dialects/gpu/__init__.py +++ b/mlir/python/mlir/dialects/gpu/__init__.py @@ -49,13 +49,13 @@ class GPUFuncOp(GPUFuncOp): FUNCTION_TYPE_ATTR_NAME = "function_type" SYM_NAME_ATTR_NAME = "sym_name" - ARGUMENT_ATTR_NAME = "arg_attrs" - RESULT_ATTR_NAME = "res_attrs" def __init__( self, function_type: Union[FunctionType, TypeAttr], sym_name: Optional[Union[str, StringAttr]] = None, + arg_attrs: Optional[Sequence[dict]] = None, + res_attrs: Optional[Sequence[dict]] = None, kernel: Optional[bool] = None, workgroup_attrib_attrs: Optional[Sequence[dict]] = None, private_attrib_attrs: Optional[Sequence[dict]] = None, @@ -88,6 +88,8 @@ class GPUFuncOp(GPUFuncOp): ) super().__init__( function_type, + arg_attrs=arg_attrs, + res_attrs=res_attrs, workgroup_attrib_attrs=workgroup_attrib_attrs, private_attrib_attrs=private_attrib_attrs, loc=loc, diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py index d387c12..c92bda7 100644 --- a/mlir/python/mlir/dialects/linalg/__init__.py +++ b/mlir/python/mlir/dialects/linalg/__init__.py @@ -352,3 +352,7 @@ def unpack( ip=ip, ) ) + + +reduce = region_op(ReduceOp, terminator=YieldOp) +map = region_op(MapOp, terminator=YieldOp) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py index 1672656..2235bb2 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py @@ -5,22 +5,25 @@ import sys + +def multiline_str_representer(dumper, data): + if len(data.splitlines()) > 1: + return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") + else: + return dumper.represent_scalar("tag:yaml.org,2002:str", data) + + try: - import yaml + from yaml import YAMLObject as _YAMLObject, add_representer + + add_representer(str, multiline_str_representer) except ModuleNotFoundError as e: - raise ModuleNotFoundError( - f"This tool requires PyYAML but it was not installed. " - f"Recommend: {sys.executable} -m pip install PyYAML" - ) from e -__all__ = [ - "yaml_dump", - "yaml_dump_all", - "YAMLObject", -] + class _YAMLObject: + pass -class YAMLObject(yaml.YAMLObject): +class YAMLObject(_YAMLObject): @classmethod def to_yaml(cls, dumper, self): """Default to a custom dictionary mapping.""" @@ -33,21 +36,34 @@ class YAMLObject(yaml.YAMLObject): return yaml_dump(self) -def multiline_str_representer(dumper, data): - if len(data.splitlines()) > 1: - return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") - else: - return dumper.represent_scalar("tag:yaml.org,2002:str", data) +def yaml_dump(data, sort_keys=False, **kwargs): + try: + import yaml + return yaml.dump(data, sort_keys=sort_keys, **kwargs) + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + f"This tool requires PyYAML but it was not installed. " + f"Recommend: {sys.executable} -m pip install PyYAML" + ) from e -yaml.add_representer(str, multiline_str_representer) +def yaml_dump_all(data, sort_keys=False, explicit_start=True, **kwargs): + try: + import yaml -def yaml_dump(data, sort_keys=False, **kwargs): - return yaml.dump(data, sort_keys=sort_keys, **kwargs) + return yaml.dump_all( + data, sort_keys=sort_keys, explicit_start=explicit_start, **kwargs + ) + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + f"This tool requires PyYAML but it was not installed. " + f"Recommend: {sys.executable} -m pip install PyYAML" + ) from e -def yaml_dump_all(data, sort_keys=False, explicit_start=True, **kwargs): - return yaml.dump_all( - data, sort_keys=sort_keys, explicit_start=explicit_start, **kwargs - ) +__all__ = [ + "yaml_dump", + "yaml_dump_all", + "YAMLObject", +] diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index fd4a5a8..9c24f94 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -1729,16 +1729,16 @@ def pooling_ndhwc_min( @linalg_structured_op -def fill(value=ScalarDef(T1), O=TensorDef(U, output=True)): +def fill(value=ScalarDef(T), O=TensorDef(T, output=True)): """Fills the output tensor with the given value. Works for arbitrary ranked output tensors since the operation performs scalar - accesses only and is thus rank polymorphic. Numeric casting is performed on - the value operand, promoting it to the same data type as the output. + accesses only and is thus rank polymorphic. The value type must match the + element type of the output tensor or memref. """ implements(FillOpInterface) defines(Canonicalizer) - O[None] = TypeFn.cast_signed(U, value) + O[None] = value @linalg_structured_op diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py index 678ceee..9e22df3 100644 --- a/mlir/python/mlir/dialects/scf.py +++ b/mlir/python/mlir/dialects/scf.py @@ -12,6 +12,7 @@ try: from ._ods_common import ( get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, + get_op_result_or_op_results as _get_op_result_or_op_results, _cext as _ods_cext, ) except ImportError as e: @@ -254,3 +255,77 @@ def for_( yield iv, iter_args[0], for_op.results[0] else: yield iv + + +@_ods_cext.register_operation(_Dialect, replace=True) +class IndexSwitchOp(IndexSwitchOp): + __doc__ = IndexSwitchOp.__doc__ + + def __init__( + self, + results, + arg, + cases, + case_body_builder=None, + default_body_builder=None, + loc=None, + ip=None, + ): + cases = DenseI64ArrayAttr.get(cases) + super().__init__( + results, arg, cases, num_caseRegions=len(cases), loc=loc, ip=ip + ) + for region in self.regions: + region.blocks.append() + + if default_body_builder is not None: + with InsertionPoint(self.default_block): + default_body_builder(self) + + if case_body_builder is not None: + for i, case in enumerate(cases): + with InsertionPoint(self.case_block(i)): + case_body_builder(self, i, self.cases[i]) + + @property + def default_region(self) -> Region: + return self.regions[0] + + @property + def default_block(self) -> Block: + return self.default_region.blocks[0] + + @property + def case_regions(self) -> Sequence[Region]: + return self.regions[1:] + + def case_region(self, i: int) -> Region: + return self.case_regions[i] + + @property + def case_blocks(self) -> Sequence[Block]: + return [region.blocks[0] for region in self.case_regions] + + def case_block(self, i: int) -> Block: + return self.case_regions[i].blocks[0] + + +def index_switch( + results, + arg, + cases, + case_body_builder=None, + default_body_builder=None, + loc=None, + ip=None, +) -> Union[OpResult, OpResultList, IndexSwitchOp]: + op = IndexSwitchOp( + results=results, + arg=arg, + cases=cases, + case_body_builder=case_body_builder, + default_body_builder=default_body_builder, + loc=loc, + ip=ip, + ) + return _get_op_result_or_op_results(op) diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py index de414dc..b3dd79c 100644 --- a/mlir/python/mlir/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -7,6 +7,7 @@ from .._transform_ops_gen import * from .._transform_ops_gen import _Dialect from ..._mlir_libs._mlirDialectsTransform import * from ..._mlir_libs._mlirDialectsTransform import AnyOpType, OperationType +from . import interpreter try: from ...ir import * @@ -324,6 +325,25 @@ class NamedSequenceOp(NamedSequenceOp): def bodyExtraArgs(self) -> BlockArgumentList: return self.body.arguments[1:] + def apply( + self, + payload: Module, + transform_options: Optional[interpreter.TransformOptions] = None, + ) -> Module: + assert self.parent + assert "transform.with_named_sequence" in self.parent.attributes + assert isinstance( + self.parent.attributes["transform.with_named_sequence"], UnitAttr + ) + + interpreter.apply_named_sequence( + payload_root=payload, + transform_root=self, + transform_module=self.parent, + transform_options=transform_options, + ) + return payload # NB: was modified in-place (if any transformation happened) + def named_sequence( sym_name: Union[str, SymbolRefAttr], diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py index ce8015d..5aa6453 100644 --- a/mlir/python/mlir/dialects/transform/xegpu.py +++ b/mlir/python/mlir/dialects/transform/xegpu.py @@ -11,6 +11,7 @@ try: from .._ods_common import _cext as _ods_cext from .._ods_common import ( MixedValues, + MixedInt, get_op_result_or_value as _get_op_result_or_value, _dispatch_dynamic_index_list, ) @@ -41,6 +42,15 @@ class GetDescOp(GetDescOp): ) +def get_desc_op( + target: Value, + *, + loc=None, + ip=None, +) -> OpResult: + return GetDescOp(target, loc=loc, ip=ip).result + + @_ods_cext.register_operation(_Dialect, replace=True) class SetDescLayoutOp(SetDescLayoutOp): """Specialization for SetDescLayoutOp class.""" @@ -52,6 +62,7 @@ class SetDescLayoutOp(SetDescLayoutOp): sg_data: MixedValues, *, inst_data: Optional[MixedValues] = None, + slice_dims: Optional[MixedInt] = None, loc=None, ip=None, ): @@ -82,11 +93,33 @@ class SetDescLayoutOp(SetDescLayoutOp): static_sg_layout=static_sg_layout, static_sg_data=static_sg_data, static_inst_data=static_inst_data, + slice_dims=slice_dims, loc=loc, ip=ip, ) +def set_desc_layout( + target: Union[Operation, Value], + sg_layout: MixedValues, + sg_data: MixedValues, + *, + inst_data: Optional[MixedValues] = None, + slice_dims: Optional[MixedInt] = None, + loc=None, + ip=None, +) -> OpResult: + return SetDescLayoutOp( + target, + sg_layout, + sg_data, + inst_data=inst_data, + slice_dims=slice_dims, + loc=loc, + ip=ip, + ).result + + @_ods_cext.register_operation(_Dialect, replace=True) class SetOpLayoutAttrOp(SetOpLayoutAttrOp): """Specialization for SetOpLayoutAttrOp class.""" @@ -98,6 +131,7 @@ class SetOpLayoutAttrOp(SetOpLayoutAttrOp): sg_data: MixedValues, *, inst_data: Optional[MixedValues] = None, + slice_dims: Optional[MixedInt] = None, index: Optional[Union[int, Attribute]] = None, result: Optional[Union[bool, Attribute]] = None, loc=None, @@ -127,8 +161,206 @@ class SetOpLayoutAttrOp(SetOpLayoutAttrOp): static_sg_layout=static_sg_layout, static_sg_data=static_sg_data, static_inst_data=static_inst_data, + slice_dims=slice_dims, index=index, result=result, loc=loc, ip=ip, ) + + +def set_op_layout_attr( + target: Union[Operation, Value], + sg_layout: MixedValues, + sg_data: MixedValues, + *, + inst_data: Optional[MixedValues] = None, + slice_dims: Optional[MixedInt] = None, + index: Optional[Union[int, Attribute]] = None, + result: Optional[Union[bool, Attribute]] = None, + loc=None, + ip=None, +) -> SetOpLayoutAttrOp: + return SetOpLayoutAttrOp( + target, + sg_layout, + sg_data, + inst_data=inst_data, + slice_dims=slice_dims, + index=index, + result=result, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class SetGPULaunchThreadsOp(SetGPULaunchThreadsOp): + """Specialization for SetGPULaunchThreadsOp class.""" + + def __init__( + self, + launch_op: Union[Operation, Value], + threads: MixedValues, + *, + loc=None, + ip=None, + ): + ( + dynamic_threads, + static_threads, + _, + ) = _dispatch_dynamic_index_list(threads) + + super().__init__( + _get_op_result_or_value(launch_op), + dynamic_threads, + static_threads=static_threads, + loc=loc, + ip=ip, + ) + + +def set_gpu_launch_threads( + launch_op: Union[Operation, Value], + threads: MixedValues, + *, + loc=None, + ip=None, +) -> SetGPULaunchThreadsOp: + return SetGPULaunchThreadsOp(launch_op, threads, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class InsertPrefetchOp(InsertPrefetchOp): + """Specialization for InsertPrefetchOp class.""" + + def __init__( + self, + target: Value, + *, + nb_prefetch: Optional[MixedInt] = 1, + loc=None, + ip=None, + ): + static_nb_prefetch = 1 + dynamic_nb_prefetch = None + if isinstance(nb_prefetch, int): + static_nb_prefetch = nb_prefetch + elif isinstance(nb_prefetch, IntegerAttr): + static_nb_prefetch = nb_prefetch.value # pytype: disable=attribute-error + elif isinstance(nb_prefetch, (Operation, Value, OpView)): + dynamic_nb_prefetch = nb_prefetch + + super().__init__( + transform.AnyOpType.get(), + target, + dynamic_nb_prefetch=dynamic_nb_prefetch, + static_nb_prefetch=static_nb_prefetch, + loc=loc, + ip=ip, + ) + + +def insert_prefetch( + target: Value, + *, + nb_prefetch: Optional[MixedInt] = 1, + loc=None, + ip=None, +) -> OpResult: + return InsertPrefetchOp(target, nb_prefetch=nb_prefetch, loc=loc, ip=ip).result + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ConvertLayoutOp(ConvertLayoutOp): + """Specialization for ConvertLayoutOp class.""" + + def __init__( + self, + target: Value, + input_sg_layout: MixedValues, + input_sg_data: MixedValues, + target_sg_layout: MixedValues, + target_sg_data: MixedValues, + *, + input_inst_data: Optional[MixedValues] = None, + target_inst_data: Optional[MixedValues] = None, + loc=None, + ip=None, + ): + input_inst_data = [] if input_inst_data is None else input_inst_data + target_inst_data = [] if target_inst_data is None else target_inst_data + ( + dynamic_input_sg_layout, + static_input_sg_layout, + _, + ) = _dispatch_dynamic_index_list(input_sg_layout) + ( + dynamic_input_sg_data, + static_input_sg_data, + _, + ) = _dispatch_dynamic_index_list(input_sg_data) + ( + dynamic_input_inst_data, + static_input_inst_data, + _, + ) = _dispatch_dynamic_index_list(input_inst_data) + ( + dynamic_target_sg_layout, + static_target_sg_layout, + _, + ) = _dispatch_dynamic_index_list(target_sg_layout) + ( + dynamic_target_sg_data, + static_target_sg_data, + _, + ) = _dispatch_dynamic_index_list(target_sg_data) + ( + dynamic_target_inst_data, + static_target_inst_data, + _, + ) = _dispatch_dynamic_index_list(target_inst_data) + super().__init__( + transform.AnyOpType.get(), + target, + dynamic_input_sg_layout, + dynamic_input_sg_data, + dynamic_input_inst_data, + dynamic_target_sg_layout, + dynamic_target_sg_data, + dynamic_target_inst_data, + static_input_sg_layout=static_input_sg_layout, + static_input_sg_data=static_input_sg_data, + static_input_inst_data=static_input_inst_data, + static_target_sg_layout=static_target_sg_layout, + static_target_sg_data=static_target_sg_data, + static_target_inst_data=static_target_inst_data, + loc=loc, + ip=ip, + ) + + +def convert_layout( + target: Value, + input_sg_layout: MixedValues, + input_sg_data: MixedValues, + target_sg_layout: MixedValues, + target_sg_data: MixedValues, + *, + input_inst_data: Optional[MixedValues] = None, + target_inst_data: Optional[MixedValues] = None, + loc=None, + ip=None, +) -> ConvertLayoutOp: + return ConvertLayoutOp( + target, + input_sg_layout, + input_sg_data, + target_sg_layout, + target_sg_data, + input_inst_data=input_inst_data, + target_inst_data=target_inst_data, + loc=loc, + ip=ip, + ).result diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 11477d0..f4aa2d6 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -34,11 +34,12 @@ def loc_tracebacks(*, max_depth: int | None = None) -> Iterable[None]: """ old_enabled = _globals.loc_tracebacks_enabled() old_limit = _globals.loc_tracebacks_frame_limit() + max_depth = old_limit if max_depth is None else max_depth try: _globals.set_loc_tracebacks_frame_limit(max_depth) if not old_enabled: _globals.set_loc_tracebacks_enabled(True) - yield + yield finally: if not old_enabled: _globals.set_loc_tracebacks_enabled(False) diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index abe0925..a1ff6e8 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -1,7 +1,9 @@ +# BUILD dependencies nanobind>=2.9, <3.0 -numpy>=1.19.5, <=2.1.2 pybind11>=2.10.0, <=2.13.6 PyYAML>=5.4.0, <=6.0.1 +typing_extensions>=4.12.2 +# RUN dependencies +numpy>=1.19.5, <=2.1.2 ml_dtypes>=0.1.0, <=0.6.0; python_version<"3.13" # provides several NumPy dtype extensions, including the bf16 ml_dtypes>=0.5.0, <=0.6.0; python_version>="3.13" -typing_extensions>=4.12.2 diff --git a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir index 3748be7..768f1cf 100644 --- a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir +++ b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir @@ -184,6 +184,18 @@ func.func private @private0(%0 : i32) -> i32 { // CHECK-NEXT: result #0: live // CHECK-LABEL: test_tag: y: // CHECK-NEXT: result #0: not live +// CHECK-LABEL: test_tag: for: +// CHECK-NEXT: operand #0: live +// CHECK-NEXT: operand #1: live +// CHECK-NEXT: operand #2: live +// CHECK-NEXT: operand #3: live +// CHECK-NEXT: operand #4: not live +// CHECK-NEXT: result #0: live +// CHECK-NEXT: result #1: not live +// CHECK-NEXT: region: #0: +// CHECK-NEXT: argument: #0: live +// CHECK-NEXT: argument: #1: not live +// CHECK-NEXT: argument: #2: not live func.func @test_7_type_3(%arg0: memref<i32>) { %c0 = arith.constant {tag = "zero"} 0 : index %c10 = arith.constant {tag = "ten"} 10 : index @@ -194,7 +206,7 @@ func.func @test_7_type_3(%arg0: memref<i32>) { %1 = arith.addi %x, %x : i32 %2 = func.call @private0(%1) : (i32) -> i32 scf.yield %2, %arg3 : i32, i32 - } + } {tag = "for"} memref.store %0#0, %arg0[] : memref<i32> return } diff --git a/mlir/test/CAPI/execution_engine.c b/mlir/test/CAPI/execution_engine.c index 4751288..4df232f 100644 --- a/mlir/test/CAPI/execution_engine.c +++ b/mlir/test/CAPI/execution_engine.c @@ -69,7 +69,7 @@ void testSimpleExecution(void) { mlirRegisterAllLLVMTranslations(ctx); MlirExecutionEngine jit = mlirExecutionEngineCreate( module, /*optLevel=*/2, /*numPaths=*/0, /*sharedLibPaths=*/NULL, - /*enableObjectDump=*/false); + /*enableObjectDump=*/false, /*enablePIC=*/false); if (mlirExecutionEngineIsNull(jit)) { fprintf(stderr, "Execution engine creation failed"); exit(2); @@ -125,7 +125,7 @@ void testOmpCreation(void) { // against the OpenMP library. MlirExecutionEngine jit = mlirExecutionEngineCreate( module, /*optLevel=*/2, /*numPaths=*/0, /*sharedLibPaths=*/NULL, - /*enableObjectDump=*/false); + /*enableObjectDump=*/false, /*enablePIC=*/false); if (mlirExecutionEngineIsNull(jit)) { fprintf(stderr, "Engine creation failed with OpenMP"); exit(2); diff --git a/mlir/test/CAPI/global_constructors.c b/mlir/test/CAPI/global_constructors.c index bd2fe14..9aacaf2 100644 --- a/mlir/test/CAPI/global_constructors.c +++ b/mlir/test/CAPI/global_constructors.c @@ -79,7 +79,7 @@ void testGlobalCtorJitCallback(void) { // Create execution engine with initialization disabled MlirExecutionEngine jit = mlirExecutionEngineCreate( module, /*optLevel=*/2, /*numPaths=*/0, /*sharedLibPaths=*/NULL, - /*enableObjectDump=*/false); + /*enableObjectDump=*/false, /*enablePIC=*/false); if (mlirExecutionEngineIsNull(jit)) { fprintf(stderr, "Execution engine creation failed"); diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir index 2fd3df6d..432b887 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir @@ -456,3 +456,4 @@ func.func @sched_barrier() { amdgpu.sched_barrier allow = <valu|all_vmem> func.return } + diff --git a/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir new file mode 100644 index 0000000..a94e17a --- /dev/null +++ b/mlir/test/Conversion/AMDGPUToROCDL/gfx1250.mlir @@ -0,0 +1,445 @@ +// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 --split-input-file --verify-diagnostics \ +// RUN: | FileCheck %s + +// CHECK-LABEL: @scaled_ext_packed_matrix_fp4 +// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf4E2M1FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) +func.func @scaled_ext_packed_matrix_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) { + // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> + // CHECK: %[[SOURCE_8xi4:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf4E2M1FN> to vector<8xi4> + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32 + // CHECK: rocdl.cvt.scale.pk8.f16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xf16> + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32 + // CHECK: rocdl.cvt.scale.pk8.bf16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xbf16> + %ret1 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xbf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32 + // CHECK: rocdl.cvt.scale.pk8.f32.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xf32> + %ret2 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf32> + func.return %ret0, %ret1, %ret2: vector<8xf16>, vector<8xbf16>, vector<8xf32> +} + +// CHECK-LABEL: @scaled_ext_packed_matrix_fp8 +// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E4M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) +func.func @scaled_ext_packed_matrix_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) { + // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> + // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E4M3FN> to vector<8xi8> + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> + // CHECK: rocdl.cvt.scale.pk8.f16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16> + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> + // CHECK: rocdl.cvt.scale.pk8.bf16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16> + %ret1 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xbf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> + // CHECK: rocdl.cvt.scale.pk8.f32.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf32> + %ret2 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf32> + + func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32> +} + +// CHECK-LABEL: @scaled_ext_packed_matrix_bf8 +// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E5M2>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) +func.func @scaled_ext_packed_matrix_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) { + // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> + // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E5M2> to vector<8xi8> + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> + // CHECK: %[[RES:.+]] = rocdl.cvt.scale.pk8.f16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16> + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> + // CHECK: rocdl.cvt.scale.pk8.bf16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16> + %ret1 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> + // CHECK: rocdl.cvt.scale.pk8.f32.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf32> + %ret2 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf32> + func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32> +} + + +// CHECK-LABEL: @scaled_ext_packed_matrix_fp6 +// CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E2M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) +func.func @scaled_ext_packed_matrix_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) { + // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> + // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E2M3FN> to vector<16xi6> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> + // CHECK: rocdl.cvt.scale.pk16.f16.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf16> + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> + // CHECK: rocdl.cvt.scale.pk16.bf16.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xbf16> + %ret1 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xbf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> + // CHECK: rocdl.cvt.scale.pk16.f32.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf32> + %ret2 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf32> + return %ret0, %ret1, %ret2: vector<16xf16>, vector<16xbf16>, vector<16xf32> +} + +// CHECK-LABEL: @scaled_ext_packed_matrix_bf6 +// CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E3M2FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) +func.func @scaled_ext_packed_matrix_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) { + // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> + // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E3M2FN> to vector<16xi6> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> + // CHECK: rocdl.cvt.scale.pk16.f16.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf16> + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> + // CHECK: rocdl.cvt.scale.pk16.bf16.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xbf16> + %ret1 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xbf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> + // CHECK: rocdl.cvt.scale.pk16.f32.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf32> + %ret2 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf32> + return %ret0, %ret1, %ret2: vector<16xf16>, vector<16xbf16>, vector<16xf32> +} + +// ----- + +func.func @amdgpu.scaled_ext_packed_matrix_invalid_block_size_and_first_scale_byte_16(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) { + // expected-error@+1 {{'amdgpu.scaled_ext_packed_matrix' op blockSize of 16 can only have firstScaleByte be 0 or 1 for f4 and f6}} + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(2) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16> + func.return +} + +// ----- + +func.func @amdgpu.scaled_ext_packed_matrix_invalid_block_size_and_first_scale_byte_32(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) { + // expected-error@+1 {{'amdgpu.scaled_ext_packed_matrix' op blockSize of 32 can only have firstScaleByte be 0 or 2 for f4 and f6.}} + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(1) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16> + func.return +} + +// ----- + +func.func @amdgpu.scaled_ext_packed_matrix_invalid_attributes_for_f8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) { + // expected-error@+1 {{'amdgpu.scaled_ext_packed_matrix' op blockSize of 16 can only have (firstScaleLane, firstScaleByte) be (0, 0) or (16, 2) for f8.}} + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(1) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16> + func.return +} + +// ----- + +func.func @amdgpu.scaled_ext_packed_matrix_invalid_input_output_sizes(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) { + // expected-error@+1 {{'amdgpu.scaled_ext_packed_matrix' op failed to verify that all of {source, res} have same shape}} + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<16xf16> + func.return +} + +// ----- + +func.func @amdgpu.scaled_ext_packed_matrix_invalid_src_elem_type(%v: vector<16xf16>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>) { + // expected-error@+1 {{'amdgpu.scaled_ext_packed_matrix' op operand #0 must be}} + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf16>, vector<4xf8E8M0FNU> -> vector<16xf16> + return %ret0: vector<16xf16> +} + +// ----- + +func.func @amdgpu.scaled_ext_packed_matrix_invalid_dst_elem_type(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf64>) { + // expected-error@+1 {{'amdgpu.scaled_ext_packed_matrix' op result #0 must be vector}} + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf64> + return %ret0: vector<16xf64> +} + +// ----- + +#gpu_global_addrspace = 1 +#gpu_lds_addrspace = 3 +#amdgpu_fat_buffer_addrspace = 7 + +func.func @amdgpu.make_dma_base.invalid_element_types(%idx: index, %mem: memref<8xi32, #gpu_global_addrspace>, %smem: memref<8xf32,#gpu_lds_addrspace>) -> (!amdgpu.tdm_base<i32>) { + // expected-error@+1 {{'amdgpu.make_dma_base' op failed to verify that all of {global, lds} have same element type}} + %0 = amdgpu.make_dma_base %mem[%idx], %smem[%idx] : memref<8xi32, #gpu_global_addrspace>, memref<8xf32, #gpu_lds_addrspace> -> !amdgpu.tdm_base<i32> + return %0 : !amdgpu.tdm_base<i32> +} + +// ----- + +#gpu_global_addrspace = 1 +#gpu_lds_addrspace = 3 +#amdgpu_fat_buffer_addrspace = 7 + +func.func @amdgpu.make_dma_base.invalid_element_types(%idx: index, %mem: memref<8xi7, #gpu_global_addrspace>, %smem: memref<8xi7,#gpu_lds_addrspace>) -> (!amdgpu.tdm_base<i7>) { + // expected-error@+1 {{'amdgpu.make_dma_base' op element type must be 1, 2, 4, or 8 bytes long but type was 7 bits long.}} + %0 = amdgpu.make_dma_base %mem[%idx], %smem[%idx] : memref<8xi7, #gpu_global_addrspace>, memref<8xi7, #gpu_lds_addrspace> -> !amdgpu.tdm_base<i7> + return %0 : !amdgpu.tdm_base<i7> +} + +// ----- + +#gpu_global_addrspace = 1 +#gpu_lds_addrspace = 3 +#amdgpu_fat_buffer_addrspace = 7 + +// CHECK-LABEL: func @make_dma_base +// CHECK-SAME: (%[[IDX:.+]]: index, %[[MEM:.+]]: memref<8xi32, 1>, %[[SMEM:.+]]: memref<8xi32, 3>) +func.func @make_dma_base(%idx: index, %mem: memref<8xi32, #gpu_global_addrspace>, %smem: memref<8xi32,#gpu_lds_addrspace>) -> (!amdgpu.tdm_base<i32>) { + // CHECK-DAG: %[[INT:.+]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64 + // CHECK-DAG: %[[MEMREF_DESC_MEM:.+]] = builtin.unrealized_conversion_cast %[[MEM]] : memref<8xi32, 1> + // CHECK-DAG: %[[MEMREF_DESC_SMEM:.+]] = builtin.unrealized_conversion_cast %[[SMEM]] : memref<8xi32, 3> + + // CHECK-DAG: %[[MEM_BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_DESC_MEM]][1] : !llvm.struct<(ptr<1> + // CHECK-DAG: %[[SMEM_BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF_DESC_SMEM]][1] : !llvm.struct<(ptr<3> + + // CHECK-DAG: %[[MEM_BASE_OFFSET:.+]] = llvm.getelementptr %[[MEM_BASE_PTR]][%[[INT]]] + // CHECK-DAG: %[[SMEM_BASE_OFFSET:.+]] = llvm.getelementptr %[[SMEM_BASE_PTR]][%[[INT]]] + + // CHECK-DAG: %[[MEM_INT:.+]] = llvm.ptrtoint %[[MEM_BASE_OFFSET]] : !llvm.ptr<1> to i64 + // CHECK-DAG: %[[SMEM_INT:.+]] = llvm.ptrtoint %[[SMEM_BASE_OFFSET]] : !llvm.ptr<3> to i32 + + // CHECK: %[[MEM_INT_LOW:.+]] = llvm.trunc %[[MEM_INT]] : i64 to i32 + // CHECK-DAG: %[[SHIFT:.+]] = llvm.mlir.constant(32 : i64) + // CHECK: %[[SHIFTED_MEM_INT:.+]] = llvm.lshr %[[MEM_INT]], %[[SHIFT]] + // CHECK: %[[MEM_INT_HIGH:.+]] = llvm.trunc %[[SHIFTED_MEM_INT]] : i64 to i32 + // CHECK-DAG: %[[MASK:.+]] = llvm.mlir.constant(33554431 : i32) + // CHECK: %[[VALID_MEM_INT_HIGH:.+]] = llvm.and %[[MEM_INT_HIGH]], %[[MASK]] + + // CHECK-DAG: %[[TYPE_FIELD:.+]] = llvm.mlir.constant(-2147483648 : i32) + // CHECK: %[[MEM_INT_HIGH_TYPE:.+]] = llvm.or %[[VALID_MEM_INT_HIGH]], %[[TYPE_FIELD]] + + // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-DAG: %[[C2:.+]] = llvm.mlir.constant(2 : i32) : i32 + // CHECK-DAG: %[[C3:.+]] = llvm.mlir.constant(3 : i32) : i32 + + // CHECK: %[[V4I32_0_0:.+]] = llvm.mlir.poison : vector<4xi32> + // CHECK: %[[V4I32_0_1:.+]] = llvm.insertelement %[[C1]], %[[V4I32_0_0]][%[[C0]] : i32] + // CHECK: %[[V4I32_0_2:.+]] = llvm.insertelement %[[SMEM_INT]], %[[V4I32_0_1]][%[[C1]] : i32] + // CHECK: %[[V4I32_0_3:.+]] = llvm.insertelement %[[MEM_INT_LOW]], %[[V4I32_0_2]][%[[C2]] : i32] + // CHECK: %[[V4I32_0_4:.+]] = llvm.insertelement %[[MEM_INT_HIGH_TYPE]], %[[V4I32_0_3]][%[[C3]] : i32] + + %0 = amdgpu.make_dma_base %mem[%idx], %smem[%idx] : memref<8xi32, #gpu_global_addrspace>, memref<8xi32, #gpu_lds_addrspace> -> !amdgpu.tdm_base<i32> + + func.return %0 : !amdgpu.tdm_base<i32> +} + +// ----- + +// CHECK-LABEL: func @make_dma_descriptor +// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>) +func.func @make_dma_descriptor(%base: !amdgpu.tdm_base<i32>) -> !amdgpu.tdm_descriptor { + // CHECK-DAG: %[[DGROUP0:.+]] = builtin.unrealized_conversion_cast %[[BASE]] + + // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) + // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i32) + // CHECK-DAG: %[[C2:.+]] = llvm.mlir.constant(2 : i32) + // CHECK-DAG: %[[C3:.+]] = llvm.mlir.constant(3 : i32) + // CHECK-DAG: %[[C4:.+]] = llvm.mlir.constant(4 : i32) + // CHECK-DAG: %[[C5:.+]] = llvm.mlir.constant(5 : i32) + // CHECK-DAG: %[[C6:.+]] = llvm.mlir.constant(6 : i32) + // CHECK-DAG: %[[C7:.+]] = llvm.mlir.constant(7 : i32) + + // CHECK-DAG: %[[DATA_SIZE:.+]] = llvm.mlir.constant(2 : i32) + // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32) + // CHECK: %[[SGPR0:.+]] = llvm.shl %[[DATA_SIZE]], %[[C16]] + + // CHECK-DAG: %[[TENSOR_DIM_0:.+]] = llvm.mlir.constant(64 : i32) + // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32) + // CHECK: %[[SGPR2_0:.+]] = llvm.lshr %[[TENSOR_DIM_0]], %[[C16]] + // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32) + // CHECK: %[[SGPR1:.+]] = llvm.shl %[[TENSOR_DIM_0]], %[[C16]] + + // CHECK-DAG: %[[TENSOR_DIM_1:.+]] = llvm.mlir.constant(128 : i32) + // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32) + // CHECK: %[[SGPR3_0:.+]] = llvm.lshr %[[TENSOR_DIM_1]], %[[C16]] + // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32) + // CHECK: %[[TENSOR_DIM_1_SHIFTED:.+]] = llvm.shl %[[TENSOR_DIM_1]], %[[C16]] + // CHECK: %[[SGPR2:.+]] = llvm.or %[[SGPR2_0]], %[[TENSOR_DIM_1_SHIFTED]] + + // CHECK-DAG: %[[TILE_DIM_0:.+]] = llvm.mlir.constant(64 : i32) + // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32) + // CHECK: %[[TILE_DIM_0_SHIFTED:.+]] = llvm.shl %[[TILE_DIM_0:.+]], %[[C16]] + // CHECK: %[[SGPR3:.+]] = llvm.or %[[SGPR3_0]], %[[TILE_DIM_0_SHIFTED]] + + // CHECK-DAG: %[[SGPR4:.+]] = llvm.mlir.constant(128 : i32) + + // CHECK-DAG: %[[TENSOR_DIM_0_STRIDE:.+]] = llvm.mlir.constant(1 : i64) : i64 + // CHECK-DAG: %[[MASK:.+]] = llvm.mlir.constant(281474976710655 : i64) : i64 + // CHECK: %[[TENSOR_DIM_0_STRIDE_MASKED:.+]] = llvm.and %[[MASK]], %[[TENSOR_DIM_0_STRIDE]] + // CHECK-DAG: %[[SGPR5:.+]] = llvm.trunc %[[TENSOR_DIM_0_STRIDE_MASKED]] : i64 to i32 + // CHECK-DAG: %[[SHIFT:.+]] = llvm.mlir.constant(32 : i64) : i64 + // CHECK: %[[TENSOR_DIM_0_STRIDE_HIGH_64:.+]] = llvm.lshr %[[TENSOR_DIM_0_STRIDE_MASKED]], %[[SHIFT]] + // CHECK: %[[SGPR6_0:.+]] = llvm.trunc %[[TENSOR_DIM_0_STRIDE_HIGH_64]] : i64 to i32 + + // CHECK-DAG: %[[TENSOR_DIM_1_STRIDE:.+]] = llvm.mlir.constant(64 : i64) + // CHECK-DAG: %[[MASK:.+]] = llvm.mlir.constant(281474976710655 : i64) : i64 + // CHECK: %[[TENSOR_DIM_1_STRIDE_MASKED:.+]] = llvm.and %[[MASK]], %[[TENSOR_DIM_1_STRIDE]] + // CHECK-DAG: %[[TENSOR_DIM_1_STRIDE_LOW:.+]] = llvm.trunc %[[TENSOR_DIM_1_STRIDE_MASKED]] + // CHECK-DAG: %[[SHIFT:.+]] = llvm.mlir.constant(16 : i64) : i64 + // CHECK: %[[TENSOR_DIM_1_STRIDE_SHIFTED:.+]] = llvm.lshr %[[TENSOR_DIM_1_STRIDE_MASKED]], %[[SHIFT]] + // CHECK: %[[SGPR7:.+]] = llvm.trunc %[[TENSOR_DIM_1_STRIDE_SHIFTED]] : i64 to i32 + // CHECK-DAG: %[[SHIFT:.+]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: %[[TENSOR_DIM_1_STRIDE_LOW_SHIFTED:.+]] = llvm.shl %[[TENSOR_DIM_1_STRIDE_LOW]], %[[SHIFT]] + // CHECK-DAG: %[[SGPR6:.+]] = llvm.or %[[SGPR6_0]], %[[TENSOR_DIM_1_STRIDE_LOW_SHIFTED]] + + // CHECK: %[[V8I32:.+]] = llvm.mlir.poison : vector<8xi32> + // CHECK: %[[DGROUP1_0:.+]] = llvm.insertelement %[[SGPR0]], %[[V8I32]][%[[C0]] : i32] + // CHECK: %[[DGROUP1_1:.+]] = llvm.insertelement %[[SGPR1]], %[[DGROUP1_0]][%[[C1]] : i32] + // CHECK: %[[DGROUP1_2:.+]] = llvm.insertelement %[[SGPR2]], %[[DGROUP1_1]][%[[C2]] : i32] + // CHECK: %[[DGROUP1_3:.+]] = llvm.insertelement %[[SGPR3]], %[[DGROUP1_2]][%[[C3]] : i32] + // CHECK: %[[DGROUP1_4:.+]] = llvm.insertelement %[[SGPR4]], %[[DGROUP1_3]][%[[C4]] : i32] + // CHECK: %[[DGROUP1_5:.+]] = llvm.insertelement %[[SGPR5]], %[[DGROUP1_4]][%[[C5]] : i32] + // CHECK: %[[DGROUP1_6:.+]] = llvm.insertelement %[[SGPR6]], %[[DGROUP1_5]][%[[C6]] : i32] + // CHECK: %[[DGROUP1:.+]] = llvm.insertelement %[[SGPR7]], %[[DGROUP1_6]][%[[C7]] : i32] + + // CHECK: %[[DGROUPS:.+]] = builtin.unrealized_conversion_cast %[[DGROUP0]], %[[DGROUP1]] : vector<4xi32>, vector<8xi32> to !amdgpu.tdm_descriptor + %descriptor = amdgpu.make_dma_descriptor %base globalSize [128, 64] globalStride [64, 1] sharedSize [128, 64] : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + func.return %descriptor : !amdgpu.tdm_descriptor +} + +// ----- + +#gpu_global_addrspace = 1 +#gpu_lds_addrspace = 3 +#amdgpu_fat_buffer_addrspace = 7 + +// CHECK-LABEL: func @make_dma_descriptor_atomic_barrier +// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>, %[[BARRIER:.+]]: {{.*}}, %[[IDX:.+]]: index) +func.func @make_dma_descriptor_atomic_barrier(%base: !amdgpu.tdm_base<i32>, %barrier : memref<8xi32, #gpu_lds_addrspace>, %idx: index) -> !amdgpu.tdm_descriptor { + // CHECK-DAG: %[[INDEX:.+]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64 + // CHECK-DAG: %[[BARRIER_MEMREF_DESC:.+]] = builtin.unrealized_conversion_cast %[[BARRIER]] + // CHECK-DAG: %[[DGROUP0:.+]] = builtin.unrealized_conversion_cast %[[BASE]] + + // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) + // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i32) + // CHECK-DAG: %[[C2:.+]] = llvm.mlir.constant(2 : i32) + // CHECK-DAG: %[[C3:.+]] = llvm.mlir.constant(3 : i32) + // CHECK-DAG: %[[C4:.+]] = llvm.mlir.constant(4 : i32) + // CHECK-DAG: %[[C5:.+]] = llvm.mlir.constant(5 : i32) + // CHECK-DAG: %[[C6:.+]] = llvm.mlir.constant(6 : i32) + // CHECK-DAG: %[[C7:.+]] = llvm.mlir.constant(7 : i32) + + // CHECK-DAG: %[[DATA_SIZE:.+]] = llvm.mlir.constant(2 : i32) + // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32) + // CHECK: %[[SGPR0_0:.+]] = llvm.shl %[[DATA_SIZE]], %[[C16]] + + // CHECK-DAG: %[[ATOMIC_BARRIER_ENABLE_OFFSET:.+]] = llvm.mlir.constant(18 : i32) + // CHECK: %[[ATOMIC_BARRIER_ENABLE_FIELD:.+]] = llvm.shl %[[C1]], %[[ATOMIC_BARRIER_ENABLE_OFFSET]] + // CHECK: %[[SGPR0:.+]] = llvm.or %[[SGPR0_0]], %[[ATOMIC_BARRIER_ENABLE_FIELD]] + + // CHECK: %[[ATOMIC_BARRIER_ALIGNED_PTR:.+]] = llvm.extractvalue %[[BARRIER_MEMREF_DESC]][1] + // CHECK: %[[ATOMIC_BARRIER_ADDR:.+]] = llvm.getelementptr %[[ATOMIC_BARRIER_ALIGNED_PTR]][%[[INDEX]] + // CHECK: %[[ATOMIC_BARRIER_I32:.+]] = llvm.ptrtoint %[[ATOMIC_BARRIER_ADDR]] : !llvm.ptr<3> to i32 + // CHECK: %[[ATOMIC_BARRIER_NO_3_LSB:.+]] = llvm.lshr %[[ATOMIC_BARRIER_I32]], %[[C3]] + // CHECK: %[[MASK:.+]] = llvm.mlir.constant(65535 : i32) + // CHECK: %[[ATOMIC_BARRIER:.+]] = llvm.and %[[ATOMIC_BARRIER_NO_3_LSB]], %[[MASK]] + + // CHECK-DAG: %[[TENSOR_DIM_0:.+]] = llvm.mlir.constant(64 : i32) + // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32) + // CHECK: %[[SGPR2_0:.+]] = llvm.lshr %[[TENSOR_DIM_0]], %[[C16]] + // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32) + // CHECK: %[[SGPR1_0:.+]] = llvm.shl %[[TENSOR_DIM_0]], %[[C16]] + // CHECK: %[[SGPR1:.+]] = llvm.or %[[ATOMIC_BARRIER]], %[[SGPR1_0]] + + // CHECK: %[[V8I32:.+]] = llvm.mlir.poison : vector<8xi32> + // CHECK: %[[DGROUP1_0:.+]] = llvm.insertelement %[[SGPR0]], %[[V8I32]][%[[C0]] : i32] + // CHECK: %[[DGROUP1_1:.+]] = llvm.insertelement %[[SGPR1]], %[[DGROUP1_0]][%[[C1]] : i32] + + %descriptor = amdgpu.make_dma_descriptor %base globalSize [128, 64] + globalStride [64, 1] + sharedSize [128, 64] + atomicBarrier(%barrier[%idx] : memref<8xi32, #gpu_lds_addrspace>) + : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + func.return %descriptor : !amdgpu.tdm_descriptor +} + +// ----- + +// CHECK-LABEL: func @make_dma_descriptor_workgroup_mask +// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>, %[[WG_MASK:.+]]: i16, %[[TIMEOUT:.+]]: i1) +func.func @make_dma_descriptor_workgroup_mask(%base: !amdgpu.tdm_base<i32>, %wg_mask: i16, %timeout: i1) -> !amdgpu.tdm_descriptor { + // CHECK-DAG: %[[DGROUP0:.+]] = builtin.unrealized_conversion_cast %[[BASE]] + + // CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(0 : i32) + // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : i32) + // CHECK-DAG: %[[C2:.+]] = llvm.mlir.constant(2 : i32) + // CHECK-DAG: %[[C3:.+]] = llvm.mlir.constant(3 : i32) + // CHECK-DAG: %[[C4:.+]] = llvm.mlir.constant(4 : i32) + // CHECK-DAG: %[[C5:.+]] = llvm.mlir.constant(5 : i32) + // CHECK-DAG: %[[C6:.+]] = llvm.mlir.constant(6 : i32) + // CHECK-DAG: %[[C7:.+]] = llvm.mlir.constant(7 : i32) + + // CHECK-DAG: %[[WG_MASK_EXT:.+]] = llvm.zext %[[WG_MASK]] + // CHECK-DAG: %[[DATA_SIZE:.+]] = llvm.mlir.constant(2 : i32) + // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32) + // CHECK: %[[DATA_SIZE_SHIFTED:.+]] = llvm.shl %[[DATA_SIZE]], %[[C16]] + // CHECK: %[[SGPR0_BASE:.+]] = llvm.or %[[WG_MASK_EXT]], %[[DATA_SIZE_SHIFTED]] + // CHECK-DAG: %[[C21:.+]] = llvm.mlir.constant(21 : i32) + // CHECK: %[[TIMEOUT_SHIFTED:.+]] = llvm.shl %[[C1]], %[[C21]] + // CHECK: %[[SGPR0:.+]] = llvm.or %[[SGPR0_BASE]], %[[TIMEOUT_SHIFTED]] + + // CHECK-DAG: %[[TENSOR_DIM_0:.+]] = llvm.mlir.constant(64 : i32) + // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32) + // CHECK: %[[SGPR2_0:.+]] = llvm.lshr %[[TENSOR_DIM_0]], %[[C16]] + // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32) + // CHECK: %[[SGPR1:.+]] = llvm.shl %[[TENSOR_DIM_0]], %[[C16]] + + // CHECK-DAG: %[[TENSOR_DIM_1:.+]] = llvm.mlir.constant(128 : i32) + // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32) + // CHECK: %[[SGPR3_0:.+]] = llvm.lshr %[[TENSOR_DIM_1]], %[[C16]] + // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32) + // CHECK: %[[TENSOR_DIM_1_SHIFTED:.+]] = llvm.shl %[[TENSOR_DIM_1]], %[[C16]] + // CHECK: %[[SGPR2:.+]] = llvm.or %[[SGPR2_0]], %[[TENSOR_DIM_1_SHIFTED]] + + // CHECK-DAG: %[[TILE_DIM_0:.+]] = llvm.mlir.constant(64 : i32) + // CHECK-DAG: %[[C16:.+]] = llvm.mlir.constant(16 : i32) + // CHECK: %[[TILE_DIM_0_SHIFTED:.+]] = llvm.shl %[[TILE_DIM_0:.+]], %[[C16]] + // CHECK: %[[SGPR3:.+]] = llvm.or %[[SGPR3_0]], %[[TILE_DIM_0_SHIFTED]] + + // CHECK-DAG: %[[SGPR4:.+]] = llvm.mlir.constant(128 : i32) + + // CHECK-DAG: %[[TENSOR_DIM_0_STRIDE:.+]] = llvm.mlir.constant(1 : i64) : i64 + // CHECK-DAG: %[[MASK:.+]] = llvm.mlir.constant(281474976710655 : i64) : i64 + // CHECK: %[[TENSOR_DIM_0_STRIDE_MASKED:.+]] = llvm.and %[[MASK]], %[[TENSOR_DIM_0_STRIDE]] + // CHECK-DAG: %[[SGPR5:.+]] = llvm.trunc %[[TENSOR_DIM_0_STRIDE_MASKED]] : i64 to i32 + // CHECK-DAG: %[[SHIFT:.+]] = llvm.mlir.constant(32 : i64) : i64 + // CHECK: %[[TENSOR_DIM_0_STRIDE_HIGH_64:.+]] = llvm.lshr %[[TENSOR_DIM_0_STRIDE_MASKED]], %[[SHIFT]] + // CHECK: %[[SGPR6_0:.+]] = llvm.trunc %[[TENSOR_DIM_0_STRIDE_HIGH_64]] : i64 to i32 + + // CHECK-DAG: %[[TENSOR_DIM_1_STRIDE:.+]] = llvm.mlir.constant(64 : i64) + // CHECK-DAG: %[[MASK:.+]] = llvm.mlir.constant(281474976710655 : i64) : i64 + // CHECK: %[[TENSOR_DIM_1_STRIDE_MASKED:.+]] = llvm.and %[[MASK]], %[[TENSOR_DIM_1_STRIDE]] + // CHECK-DAG: %[[TENSOR_DIM_1_STRIDE_LOW:.+]] = llvm.trunc %[[TENSOR_DIM_1_STRIDE_MASKED]] + // CHECK-DAG: %[[SHIFT:.+]] = llvm.mlir.constant(16 : i64) : i64 + // CHECK: %[[TENSOR_DIM_1_STRIDE_SHIFTED:.+]] = llvm.lshr %[[TENSOR_DIM_1_STRIDE_MASKED]], %[[SHIFT]] + // CHECK: %[[SGPR7:.+]] = llvm.trunc %[[TENSOR_DIM_1_STRIDE_SHIFTED]] : i64 to i32 + // CHECK-DAG: %[[SHIFT:.+]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: %[[TENSOR_DIM_1_STRIDE_LOW_SHIFTED:.+]] = llvm.shl %[[TENSOR_DIM_1_STRIDE_LOW]], %[[SHIFT]] + // CHECK-DAG: %[[SGPR6:.+]] = llvm.or %[[SGPR6_0]], %[[TENSOR_DIM_1_STRIDE_LOW_SHIFTED]] + + // CHECK: %[[V8I32:.+]] = llvm.mlir.poison : vector<8xi32> + // CHECK: %[[DGROUP1_0:.+]] = llvm.insertelement %[[SGPR0]], %[[V8I32]][%[[C0]] : i32] + // CHECK: %[[DGROUP1_1:.+]] = llvm.insertelement %[[SGPR1]], %[[DGROUP1_0]][%[[C1]] : i32] + // CHECK: %[[DGROUP1_2:.+]] = llvm.insertelement %[[SGPR2]], %[[DGROUP1_1]][%[[C2]] : i32] + // CHECK: %[[DGROUP1_3:.+]] = llvm.insertelement %[[SGPR3]], %[[DGROUP1_2]][%[[C3]] : i32] + // CHECK: %[[DGROUP1_4:.+]] = llvm.insertelement %[[SGPR4]], %[[DGROUP1_3]][%[[C4]] : i32] + // CHECK: %[[DGROUP1_5:.+]] = llvm.insertelement %[[SGPR5]], %[[DGROUP1_4]][%[[C5]] : i32] + // CHECK: %[[DGROUP1_6:.+]] = llvm.insertelement %[[SGPR6]], %[[DGROUP1_5]][%[[C6]] : i32] + // CHECK: %[[DGROUP1:.+]] = llvm.insertelement %[[SGPR7]], %[[DGROUP1_6]][%[[C7]] : i32] + + // CHECK: %[[DGROUPS:.+]] = builtin.unrealized_conversion_cast %[[DGROUP0]], %[[DGROUP1]] : vector<4xi32>, vector<8xi32> to !amdgpu.tdm_descriptor + %descriptor = amdgpu.make_dma_descriptor %base globalSize [128, 64] globalStride [64, 1] sharedSize [128, 64] workgroupMask %wg_mask earlyTimeout %timeout : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + func.return %descriptor : !amdgpu.tdm_descriptor +} diff --git a/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait.mlir b/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait.mlir index 1016ee8..537ef59 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait.mlir @@ -1,7 +1,7 @@ -// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx942 | FileCheck %s --check-prefixes=CHECK,GFX9 -// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1030 | FileCheck %s --check-prefixes=CHECK,GFX10 -// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 | FileCheck %s --check-prefixes=CHECK,GFX11 -// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1201 | FileCheck %s --check-prefixes=CHECK,GFX12 +// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx942 | FileCheck %s --check-prefixes=CHECK,GFX9 +// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1030 | FileCheck %s --check-prefixes=CHECK,GFX10 +// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1100 | FileCheck %s --check-prefixes=CHECK,GFX11 +// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1201 | FileCheck %s --check-prefixes=CHECK,GFX12 // CHECK-LABEL: func @memory_counter_wait func.func @memory_counter_wait() { diff --git a/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait_tensor.mlir b/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait_tensor.mlir new file mode 100644 index 0000000..5b29e01 --- /dev/null +++ b/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait_tensor.mlir @@ -0,0 +1,9 @@ +// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 | FileCheck %s + +// CHECK-LABEL: func @memory_counter_wait_tensor +func.func @memory_counter_wait_tensor() { + // CHECK: rocdl.s.wait.tensorcnt 3 + amdgpu.memory_counter_wait tensor(3) + + return +} diff --git a/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait_unsupported.mlir b/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait_unsupported.mlir new file mode 100644 index 0000000..1d2f692 --- /dev/null +++ b/mlir/test/Conversion/AMDGPUToROCDL/memory_counter_wait_unsupported.mlir @@ -0,0 +1,11 @@ +// RUN: mlir-opt %s --verify-diagnostics --convert-amdgpu-to-rocdl=chipset=gfx942 +// RUN: mlir-opt %s --verify-diagnostics --convert-amdgpu-to-rocdl=chipset=gfx1030 +// RUN: mlir-opt %s --verify-diagnostics --convert-amdgpu-to-rocdl=chipset=gfx1100 + +func.func @memory_counter_wait_tensor() { + // expected-error @below{{failed to legalize operation 'amdgpu.memory_counter_wait'}} + // expected-error @below{{'amdgpu.memory_counter_wait' op unsupported chipset}} + amdgpu.memory_counter_wait tensor(0) + + return +} diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir index 9fcc147..4e6aa17 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir @@ -6,30 +6,30 @@ func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : %arg6 : vector<16xi8>, %arg7 : vector<8xi32>, %arg8 : vector<4xi32>, %arg9 : vector<16xui8>, %arg10 : vector<16xi4>, %arg11 : vector<8xi4>) { // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32> - amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32> + amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {subwordOffset = 0 : i32} : vector<16xf16>, vector<16xf16>, vector<8xf32> // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32> - amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32> + amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg2 {subwordOffset = 0 : i32} : vector<16xf16>, vector<16xf16>, vector<4xf32> // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xf32>) -> vector<8xf32> - amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32> + amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg1 {subwordOffset = 0 : i32} : vector<16xbf16>, vector<16xbf16>, vector<8xf32> // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32> - amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32> - // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16> + amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg2 {subwordOffset = 0 : i32} : vector<16xbf16>, vector<16xbf16>, vector<4xf32> + // CHECK: rocdl.wmma.f16.16x16x16.f16 {{.*}} {opsel = true} : (vector<16xf16>, vector<16xf16>, vector<16xf16>) -> vector<16xf16> amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16> - // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16> + // CHECK: rocdl.wmma.f16.16x16x16.f16 {{.*}} : (vector<16xf16>, vector<16xf16>, vector<8xf16>) -> vector<8xf16> amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16> - // CHECK: %[[raw_bf16x16:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16> + // CHECK: %[[raw_bf16x16:.+]] = rocdl.wmma.bf16.16x16x16.bf16 {{.*}} {opsel = true} : (vector<16xi16>, vector<16xi16>, vector<16xi16>) -> vector<16xi16> // CHECK-NEXT: llvm.bitcast %[[raw_bf16x16]] : vector<16xi16> to vector<16xbf16> amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16> - // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16> + // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16 {{.*}} : (vector<16xi16>, vector<16xi16>, vector<8xi16>) -> vector<8xi16> // CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16> amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16> - // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32> + // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}{clamp = true, signA = true, signB = true} : (vector<4xi32>, vector<4xi32>, vector<8xi32>) -> vector<8xi32> amdgpu.wmma 16x16x16 %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32> - // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32> + // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}{clamp = true} : (vector<4xi32>, vector<4xi32>, vector<4xi32>) -> vector<4xi32> amdgpu.wmma 16x16x16 %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32> - // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32> + // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}{clamp = true, signA = true, signB = true} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32> amdgpu.wmma 16x16x16 %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32> - // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32> + // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}{clamp = true, signA = true, signB = true} : (i32, i32, vector<4xi32>) -> vector<4xi32> amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32> return diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir index 5788347..978227b 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir @@ -20,15 +20,15 @@ func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>, // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xf32>) -> vector<4xf32> amdgpu.wmma 16x16x16 %arg5 * %arg5 + %arg3 : vector<4xbf16>, vector<4xbf16>, vector<4xf32> - // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<8xf16>, i1) -> vector<8xf16> + // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<8xf16>) -> vector<8xf16> amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg0 : vector<8xf16>, vector<8xf16>, vector<8xf16> - // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf16>, i1) -> vector<4xf16> + // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf16>) -> vector<4xf16> amdgpu.wmma 16x16x16 %arg1 * %arg1 + %arg1 : vector<4xf16>, vector<4xf16>, vector<4xf16> - // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<8xi16>, i1) -> vector<8xi16> + // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<8xi16>) -> vector<8xi16> // CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16> amdgpu.wmma 16x16x16 %arg4 * %arg4 + %arg4 : vector<8xbf16>, vector<8xbf16>, vector<8xbf16> - // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xi16>, i1) -> vector<4xi16> + // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xi16>) -> vector<4xi16> amdgpu.wmma 16x16x16 %arg5 * %arg5 + %arg5 : vector<4xbf16>, vector<4xbf16>, vector<4xbf16> // CHECK: rocdl.wmma.f32.16x16x16.fp8_fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32> @@ -51,19 +51,19 @@ func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>, // CHECK: rocdl.wmma.f32.16x16x16.bf8_fp8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32> amdgpu.wmma 16x16x16 %arg9 * %arg7 + %arg3 : vector<4xf8E5M2>, vector<4xf8E4M3FN>, vector<4xf32> - // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32> + // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}{clamp = true, signA = true, signB = true} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32> amdgpu.wmma 16x16x16 %arg10 * %arg10 + %arg12 {clamp} : vector<8xi8>, vector<8xi8>, vector<8xi32> - // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32> + // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}{clamp = true} : (i32, i32, vector<4xi32>) -> vector<4xi32> amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg13 {unsignedA, unsignedB, clamp}: vector<4xi8>, vector<4xi8>, vector<4xi32> - // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32> + // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}{clamp = true, signA = true, signB = true} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32> amdgpu.wmma 16x16x32 %arg14 * %arg14 + %arg12 {clamp} : vector<16xi4>, vector<16xi4>, vector<8xi32> - // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32> + // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}{clamp = true, signA = true, signB = true} : (i32, i32, vector<4xi32>) -> vector<4xi32> amdgpu.wmma 16x16x32 %arg15 * %arg15 + %arg13 {clamp} : vector<8xi4>, vector<8xi4>, vector<4xi32> - // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<8xi32>, i1) -> vector<8xi32> + // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}{clamp = true, signA = true, signB = true} : (i32, i32, vector<8xi32>) -> vector<8xi32> amdgpu.wmma 16x16x16 %arg15 * %arg15 + %arg12 {clamp} : vector<8xi4>, vector<8xi4>, vector<8xi32> - // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32> + // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}{clamp = true, signA = true, signB = true} : (i32, i32, vector<4xi32>) -> vector<4xi32> amdgpu.wmma 16x16x16 %arg16 * %arg16 + %arg13 {clamp} : vector<4xi4>, vector<4xi4>, vector<4xi32> func.return diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir index 5e77a3ad..37259f6 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir @@ -14,13 +14,13 @@ func.func @wmma_k32(%arg0 : vector<16xf16>, %arg1 : vector<16xbf16>, %arg2 : vec // CHECK: rocdl.wmma.f32.16x16x32.f16 %arg0, %arg0, %arg2 amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<8xf32> - // CHECK: rocdl.wmma.f16.16x16x32.f16 %arg0, %arg0, {{.*}} : (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) + // CHECK: rocdl.wmma.f16.16x16x32.f16 %arg0, %arg0, {{.*}} : (vector<16xf16>, vector<16xf16>, vector<8xf16>) amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg3 : vector<16xf16>, vector<16xf16>, vector<8xf16> // CHECK: rocdl.wmma.f32.16x16x32.bf16 {{.*}}, {{.*}}, %arg2 amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<8xf32> - // CHECK: rocdl.wmma.bf16.16x16x32.bf16 {{.*}}, {{.*}}, {{.*}}, {{.*}} : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) + // CHECK: rocdl.wmma.bf16.16x16x32.bf16 {{.*}}, {{.*}}, {{.*}} : (vector<16xbf16>, vector<16xbf16>, vector<8xbf16>) amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg4 : vector<16xbf16>, vector<16xbf16>, vector<8xbf16> return @@ -29,31 +29,31 @@ func.func @wmma_k32(%arg0 : vector<16xf16>, %arg1 : vector<16xbf16>, %arg2 : vec // CHECK-LABEL: @wmma_k64 func.func @wmma_k64(%arg0 : vector<32xi8>, %arg1 : vector<32xf8E4M3FN>, %arg2 : vector<32xf8E5M2>, %arg3 : vector<8xi32>, %arg4 : vector<8xf32>, %arg5 : vector<8xf16>) { - // CHECK: rocdl.wmma.i32.16x16x64.iu8 {{.*}}, {{.*}}, {{.*}}, {{.*}}, %arg3, {{.*}} + // CHECK: rocdl.wmma.i32.16x16x64.iu8 {{.*}}, {{.*}}, %arg3 {clamp = true, signA = true, signB = true} amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg3 {clamp} : vector<32xi8>, vector<32xi8>, vector<8xi32> // CHECK: rocdl.wmma.f32.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg4 amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf32> - // CHECK: rocdl.wmma.f16.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1) + // CHECK: rocdl.wmma.f16.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg5 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>) amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf16> // CHECK: rocdl.wmma.f32.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg4 amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf32> - // CHECK: rocdl.wmma.f16.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1) + // CHECK: rocdl.wmma.f16.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg5 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>) amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf16> // CHECK: rocdl.wmma.f32.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg4 amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg4 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf32> - // CHECK: rocdl.wmma.f16.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1) + // CHECK: rocdl.wmma.f16.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg5 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>) amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg5 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf16> // CHECK: rocdl.wmma.f32.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg4 amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg4 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf32> - // CHECK: rocdl.wmma.f16.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1) + // CHECK: rocdl.wmma.f16.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg5 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>) amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg5 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf16> return @@ -65,25 +65,25 @@ func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>, // CHECK: rocdl.wmma.f32.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg2 amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf32> - // CHECK: rocdl.wmma.f16.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1) + // CHECK: rocdl.wmma.f16.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg3 {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>) amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf16> // CHECK: rocdl.wmma.f32.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg2 amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf32> - // CHECK: rocdl.wmma.f16.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1) + // CHECK: rocdl.wmma.f16.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg3 {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>) amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf16> // CHECK: rocdl.wmma.f32.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg2 amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg2 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf32> - // CHECK: rocdl.wmma.f16.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1) + // CHECK: rocdl.wmma.f16.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg3 {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>) amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg3 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf16> // CHECK: rocdl.wmma.f32.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg2 amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg2 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf32> - // CHECK: rocdl.wmma.f16.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1) + // CHECK: rocdl.wmma.f16.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg3 {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>) amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg3 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf16> return diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir new file mode 100644 index 0000000..bd4a9da --- /dev/null +++ b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir @@ -0,0 +1,329 @@ +// RUN: mlir-opt %s --convert-arith-to-apfloat -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func private @_mlir_apfloat_add(i32, i64, i64) -> i64 + +// CHECK-LABEL: func.func @foo() -> f8E4M3FN { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 2.250000e+00 : f8E4M3FN +// CHECK: return %[[CONSTANT_0]] : f8E4M3FN +// CHECK: } + +// CHECK-LABEL: func.func @bar() -> f6E3M2FN { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 3.000000e+00 : f6E3M2FN +// CHECK: return %[[CONSTANT_0]] : f6E3M2FN +// CHECK: } + +// Illustrate that both f8E4M3FN and f6E3M2FN calling the same _mlir_apfloat_add is fine +// because each gets its own semantics enum and gets bitcast/extui/trunci to its own width. +// CHECK-LABEL: func.func @full_example() { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 1.375000e+00 : f8E4M3FN +// CHECK: %[[VAL_0:.*]] = call @foo() : () -> f8E4M3FN +// CHECK: %[[BITCAST_0:.*]] = arith.bitcast %[[CONSTANT_0]] : f8E4M3FN to i8 +// CHECK: %[[EXTUI_0:.*]] = arith.extui %[[BITCAST_0]] : i8 to i64 +// CHECK: %[[BITCAST_1:.*]] = arith.bitcast %[[VAL_0]] : f8E4M3FN to i8 +// CHECK: %[[EXTUI_1:.*]] = arith.extui %[[BITCAST_1]] : i8 to i64 +// // fltSemantics semantics for f8E4M3FN +// CHECK: %[[CONSTANT_1:.*]] = arith.constant 10 : i32 +// CHECK: %[[VAL_1:.*]] = call @_mlir_apfloat_add(%[[CONSTANT_1]], %[[EXTUI_0]], %[[EXTUI_1]]) : (i32, i64, i64) -> i64 +// CHECK: %[[TRUNCI_0:.*]] = arith.trunci %[[VAL_1]] : i64 to i8 +// CHECK: %[[BITCAST_2:.*]] = arith.bitcast %[[TRUNCI_0]] : i8 to f8E4M3FN +// CHECK: vector.print %[[BITCAST_2]] : f8E4M3FN + +// CHECK: %[[CONSTANT_2:.*]] = arith.constant 2.500000e+00 : f6E3M2FN +// CHECK: %[[VAL_2:.*]] = call @bar() : () -> f6E3M2FN +// CHECK: %[[BITCAST_3:.*]] = arith.bitcast %[[CONSTANT_2]] : f6E3M2FN to i6 +// CHECK: %[[EXTUI_2:.*]] = arith.extui %[[BITCAST_3]] : i6 to i64 +// CHECK: %[[BITCAST_4:.*]] = arith.bitcast %[[VAL_2]] : f6E3M2FN to i6 +// CHECK: %[[EXTUI_3:.*]] = arith.extui %[[BITCAST_4]] : i6 to i64 +// // fltSemantics semantics for f6E3M2FN +// CHECK: %[[CONSTANT_3:.*]] = arith.constant 16 : i32 +// CHECK: %[[VAL_3:.*]] = call @_mlir_apfloat_add(%[[CONSTANT_3]], %[[EXTUI_2]], %[[EXTUI_3]]) : (i32, i64, i64) -> i64 +// CHECK: %[[TRUNCI_1:.*]] = arith.trunci %[[VAL_3]] : i64 to i6 +// CHECK: %[[BITCAST_5:.*]] = arith.bitcast %[[TRUNCI_1]] : i6 to f6E3M2FN +// CHECK: vector.print %[[BITCAST_5]] : f6E3M2FN +// CHECK: return +// CHECK: } + +// Put rhs into separate function so that it won't be constant-folded. +func.func @foo() -> f8E4M3FN { + %cst = arith.constant 2.2 : f8E4M3FN + return %cst : f8E4M3FN +} + +func.func @bar() -> f6E3M2FN { + %cst = arith.constant 3.2 : f6E3M2FN + return %cst : f6E3M2FN +} + +func.func @full_example() { + %a = arith.constant 1.4 : f8E4M3FN + %b = func.call @foo() : () -> (f8E4M3FN) + %c = arith.addf %a, %b : f8E4M3FN + vector.print %c : f8E4M3FN + + %d = arith.constant 2.4 : f6E3M2FN + %e = func.call @bar() : () -> (f6E3M2FN) + %f = arith.addf %d, %e : f6E3M2FN + vector.print %f : f6E3M2FN + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_add(i32, i64, i64) -> i64 +// CHECK: %[[sem:.*]] = arith.constant 18 : i32 +// CHECK: call @_mlir_apfloat_add(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +func.func @addf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { + %0 = arith.addf %arg0, %arg1 : f4E2M1FN + return +} + +// ----- + +// Test decl collision (different type) +// expected-error@+1{{matched function '_mlir_apfloat_add' but with different type: '(i32, i32, f32) -> index' (expected '(i32, i64, i64) -> i64')}} +func.func private @_mlir_apfloat_add(i32, i32, f32) -> index +func.func @addf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { + %0 = arith.addf %arg0, %arg1 : f4E2M1FN + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_subtract(i32, i64, i64) -> i64 +// CHECK: %[[sem:.*]] = arith.constant 18 : i32 +// CHECK: call @_mlir_apfloat_subtract(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { + %0 = arith.subf %arg0, %arg1 : f4E2M1FN + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_multiply(i32, i64, i64) -> i64 +// CHECK: %[[sem:.*]] = arith.constant 18 : i32 +// CHECK: call @_mlir_apfloat_multiply(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { + %0 = arith.mulf %arg0, %arg1 : f4E2M1FN + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_divide(i32, i64, i64) -> i64 +// CHECK: %[[sem:.*]] = arith.constant 18 : i32 +// CHECK: call @_mlir_apfloat_divide(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { + %0 = arith.divf %arg0, %arg1 : f4E2M1FN + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_remainder(i32, i64, i64) -> i64 +// CHECK: %[[sem:.*]] = arith.constant 18 : i32 +// CHECK: call @_mlir_apfloat_remainder(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +func.func @remf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { + %0 = arith.remf %arg0, %arg1 : f4E2M1FN + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_convert(i32, i32, i64) -> i64 +// CHECK: %[[sem_in:.*]] = arith.constant 18 : i32 +// CHECK: %[[sem_out:.*]] = arith.constant 2 : i32 +// CHECK: call @_mlir_apfloat_convert(%[[sem_in]], %[[sem_out]], %{{.*}}) : (i32, i32, i64) -> i64 +func.func @extf(%arg0: f4E2M1FN) { + %0 = arith.extf %arg0 : f4E2M1FN to f32 + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_convert(i32, i32, i64) -> i64 +// CHECK: %[[sem_in:.*]] = arith.constant 1 : i32 +// CHECK: %[[sem_out:.*]] = arith.constant 18 : i32 +// CHECK: call @_mlir_apfloat_convert(%[[sem_in]], %[[sem_out]], %{{.*}}) : (i32, i32, i64) -> i64 +func.func @truncf(%arg0: bf16) { + %0 = arith.truncf %arg0 : bf16 to f4E2M1FN + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_convert_to_int(i32, i32, i1, i64) -> i64 +// CHECK: %[[sem_in:.*]] = arith.constant 0 : i32 +// CHECK: %[[out_width:.*]] = arith.constant 4 : i32 +// CHECK: %[[is_unsigned:.*]] = arith.constant false +// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert_to_int(%[[sem_in]], %[[out_width]], %[[is_unsigned]], %{{.*}}) : (i32, i32, i1, i64) -> i64 +// CHECK: arith.trunci %[[res]] : i64 to i4 +func.func @fptosi(%arg0: f16) { + %0 = arith.fptosi %arg0 : f16 to i4 + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_convert_to_int(i32, i32, i1, i64) -> i64 +// CHECK: %[[sem_in:.*]] = arith.constant 0 : i32 +// CHECK: %[[out_width:.*]] = arith.constant 4 : i32 +// CHECK: %[[is_unsigned:.*]] = arith.constant true +// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert_to_int(%[[sem_in]], %[[out_width]], %[[is_unsigned]], %{{.*}}) : (i32, i32, i1, i64) -> i64 +// CHECK: arith.trunci %[[res]] : i64 to i4 +func.func @fptoui(%arg0: f16) { + %0 = arith.fptoui %arg0 : f16 to i4 + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_convert_from_int(i32, i32, i1, i64) -> i64 +// CHECK: %[[sem_out:.*]] = arith.constant 18 : i32 +// CHECK: %[[in_width:.*]] = arith.constant 32 : i32 +// CHECK: %[[is_unsigned:.*]] = arith.constant false +// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert_from_int(%[[sem_out]], %[[in_width]], %[[is_unsigned]], %{{.*}}) : (i32, i32, i1, i64) -> i64 +func.func @sitofp(%arg0: i32) { + %0 = arith.sitofp %arg0 : i32 to f4E2M1FN + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_convert_from_int(i32, i32, i1, i64) -> i64 +// CHECK: %[[sem_out:.*]] = arith.constant 18 : i32 +// CHECK: %[[in_width:.*]] = arith.constant 32 : i32 +// CHECK: %[[is_unsigned:.*]] = arith.constant true +// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert_from_int(%[[sem_out]], %[[in_width]], %[[is_unsigned]], %{{.*}}) : (i32, i32, i1, i64) -> i64 +func.func @uitofp(%arg0: i32) { + %0 = arith.uitofp %arg0 : i32 to f4E2M1FN + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_compare(i32, i64, i64) -> i8 +// CHECK: %[[sem:.*]] = arith.constant 18 : i32 +// CHECK: %[[cmp:.*]] = call @_mlir_apfloat_compare(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i8 +// CHECK: %[[c3:.*]] = arith.constant 3 : i8 +// CHECK: %[[is_unordered:.*]] = arith.cmpi eq, %[[cmp]], %[[c3]] : i8 +// CHECK: %[[c0:.*]] = arith.constant 0 : i8 +// CHECK: %[[is_lt:.*]] = arith.cmpi eq, %[[cmp]], %[[c0]] : i8 +// CHECK: arith.ori %[[is_unordered]], %[[is_lt]] : i1 +func.func @cmpf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { + %0 = arith.cmpf "ult", %arg0, %arg1 : f4E2M1FN + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_neg(i32, i64) -> i64 +// CHECK: %[[sem:.*]] = arith.constant 2 : i32 +// CHECK: %[[res:.*]] = call @_mlir_apfloat_neg(%[[sem]], %{{.*}}) : (i32, i64) -> i64 +func.func @negf(%arg0: f32) { + %0 = arith.negf %arg0 : f32 + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_minimum(i32, i64, i64) -> i64 +// CHECK: %[[sem:.*]] = arith.constant 2 : i32 +// CHECK: %[[res:.*]] = call @_mlir_apfloat_minimum(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +func.func @minimumf(%arg0: f32, %arg1: f32) { + %0 = arith.minimumf %arg0, %arg1 : f32 + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_maximum(i32, i64, i64) -> i64 +// CHECK: %[[sem:.*]] = arith.constant 2 : i32 +// CHECK: %[[res:.*]] = call @_mlir_apfloat_maximum(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +func.func @maximumf(%arg0: f32, %arg1: f32) { + %0 = arith.maximumf %arg0, %arg1 : f32 + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_minnum(i32, i64, i64) -> i64 +// CHECK: %[[sem:.*]] = arith.constant 2 : i32 +// CHECK: %[[res:.*]] = call @_mlir_apfloat_minnum(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +func.func @minnumf(%arg0: f32, %arg1: f32) { + %0 = arith.minnumf %arg0, %arg1 : f32 + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_maxnum(i32, i64, i64) -> i64 +// CHECK: %[[sem:.*]] = arith.constant 2 : i32 +// CHECK: %[[res:.*]] = call @_mlir_apfloat_maxnum(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64 +func.func @maxnumf(%arg0: f32, %arg1: f32) { + %0 = arith.maxnumf %arg0, %arg1 : f32 + return +} + +// ----- + +// CHECK-LABEL: func.func @unsupported_bitwidth +// CHECK: arith.addf {{.*}} : f128 +// CHECK: arith.negf {{.*}} : f128 +// CHECK: arith.cmpf {{.*}} : f128 +// CHECK: arith.extf {{.*}} : f32 to f128 +// CHECK: arith.truncf {{.*}} : f128 to f32 +// CHECK: arith.fptosi {{.*}} : f128 to i32 +// CHECK: arith.fptosi {{.*}} : f32 to i92 +// CHECK: arith.sitofp {{.*}} : i1 to f128 +// CHECK: arith.sitofp {{.*}} : i92 to f32 +func.func @unsupported_bitwidth(%arg0: f128, %arg1: f128, %arg2: f32) { + %0 = arith.addf %arg0, %arg1 : f128 + %1 = arith.negf %arg0 : f128 + %2 = arith.cmpf "ult", %arg0, %arg1 : f128 + %3 = arith.extf %arg2 : f32 to f128 + %4 = arith.truncf %arg0 : f128 to f32 + %5 = arith.fptosi %arg0 : f128 to i32 + %6 = arith.fptosi %arg2 : f32 to i92 + %7 = arith.sitofp %2 : i1 to f128 + %8 = arith.sitofp %6 : i92 to f32 + return +} + +// ----- + +// CHECK-LABEL: func.func @addf_vector +// CHECK-2: vector.to_elements + +// CHECK: arith.bitcast +// CHECK: arith.extui +// CHECK: arith.bitcast +// CHECK: arith.extui +// CHECK: call +// CHECK: arith.trunci + +// CHECK: arith.bitcast +// CHECK: arith.extui +// CHECK: arith.bitcast +// CHECK: arith.extui +// CHECK: call +// CHECK: arith.trunci + +// CHECK: arith.bitcast +// CHECK: arith.extui +// CHECK: arith.bitcast +// CHECK: arith.extui +// CHECK: call +// CHECK: arith.trunci + +// CHECK: arith.bitcast +// CHECK: arith.extui +// CHECK: arith.bitcast +// CHECK: arith.extui +// CHECK: call +// CHECK: arith.trunci + +// CHECK: vector.from_elements +func.func @addf_vector(%arg0: vector<4xf4E2M1FN>, %arg1: vector<4xf4E2M1FN>) { + %0 = arith.addf %arg0, %arg1 : vector<4xf4E2M1FN> + return +} diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir index 5f1ec66..b53c52d 100644 --- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir +++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir @@ -738,6 +738,22 @@ func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) { // ----- +// CHECK-LABEL: @ops_supporting_exact +func.func @ops_supporting_exact(i32, i32) { +^bb0(%arg0: i32, %arg1: i32): +// CHECK: = llvm.ashr exact %arg0, %arg1 : i32 + %0 = arith.shrsi %arg0, %arg1 exact : i32 +// CHECK: = llvm.lshr exact %arg0, %arg1 : i32 + %1 = arith.shrui %arg0, %arg1 exact : i32 +// CHECK: = llvm.sdiv exact %arg0, %arg1 : i32 + %2 = arith.divsi %arg0, %arg1 exact : i32 +// CHECK: = llvm.udiv exact %arg0, %arg1 : i32 + %3 = arith.divui %arg0, %arg1 exact : i32 + return +} + +// ----- + // CHECK-LABEL: func @memref_bitcast // CHECK-SAME: (%[[ARG:.*]]: memref<?xi16>) // CHECK: %[[V1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<?xi16> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> @@ -754,12 +770,14 @@ func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> { // CHECK: arith.addf {{.*}} : f4E2M1FN // CHECK: arith.addf {{.*}} : vector<4xf4E2M1FN> // CHECK: arith.addf {{.*}} : vector<8x4xf4E2M1FN> +// CHECK: arith.cmpf {{.*}} : f4E2M1FN // CHECK: llvm.select {{.*}} : i1, i4 func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>, %arg3: f4E2M1FN, %arg4: i1) { %0 = arith.addf %arg0, %arg0 : f4E2M1FN %1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN> %2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN> - %3 = arith.select %arg4, %arg0, %arg3 : f4E2M1FN + %3 = arith.cmpf oeq, %arg0, %arg3 : f4E2M1FN + %4 = arith.select %arg4, %arg0, %arg3 : f4E2M1FN return } @@ -769,9 +787,11 @@ func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2 // CHECK: llvm.fadd {{.*}} : f32 // CHECK: llvm.fadd {{.*}} : vector<4xf32> // CHECK-COUNT-4: llvm.fadd {{.*}} : vector<8xf32> -func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>) -> (f32, vector<4xf32>, vector<4x8xf32>) { +// CHECK: llvm.fcmp {{.*}} : f32 +func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>, %arg3: f32) { %0 = arith.addf %arg0, %arg0 : f32 %1 = arith.addf %arg1, %arg1 : vector<4xf32> %2 = arith.addf %arg2, %arg2 : vector<4x8xf32> - return %0, %1, %2 : f32, vector<4xf32>, vector<4x8xf32> + %3 = arith.cmpf oeq, %arg0, %arg3 : f32 + return } diff --git a/mlir/test/Conversion/FuncToLLVM/convert-funcs.mlir b/mlir/test/Conversion/FuncToLLVM/convert-funcs.mlir index ae1dc70..bd28162 100644 --- a/mlir/test/Conversion/FuncToLLVM/convert-funcs.mlir +++ b/mlir/test/Conversion/FuncToLLVM/convert-funcs.mlir @@ -32,7 +32,7 @@ func.func @pass_through(%arg0: () -> ()) -> (() -> ()) { func.func private @llvmlinkage(i32) attributes { "llvm.linkage" = #llvm.linkage<extern_weak> } // CHECK-LABEL: llvm.func @llvmreadnone(i32) -// CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none> +// CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none> func.func private @llvmreadnone(i32) attributes { llvm.readnone } // CHECK-LABEL: llvm.func @body(i32) diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir index a4b5dde..f1cc1eb 100644 --- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir +++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1' -split-input-file | FileCheck %s +// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1 allow-pattern-rollback=0' -split-input-file | FileCheck %s // RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1 allowed-dialects=func,arith,cf' -split-input-file | FileCheck %s // RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1 use-bare-ptr-memref-call-conv=1' -split-input-file | FileCheck %s --check-prefix=CHECK-BARE // RUN: mlir-opt %s -transform-interpreter | FileCheck %s diff --git a/mlir/test/Conversion/GPUToNVVM/memref.mlir b/mlir/test/Conversion/GPUToNVVM/memref.mlir index e164ca9..a4e8ead 100644 --- a/mlir/test/Conversion/GPUToNVVM/memref.mlir +++ b/mlir/test/Conversion/GPUToNVVM/memref.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s -convert-gpu-to-nvvm | FileCheck %s +// RUN: mlir-opt %s -convert-gpu-to-nvvm="allow-pattern-rollback=0" | FileCheck %s // RUN: mlir-opt %s -convert-gpu-to-nvvm='use-bare-ptr-memref-call-conv=1' \ // RUN: | FileCheck %s --check-prefix=BARE diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir index b479467..a080144 100644 --- a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir +++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt --convert-gpu-to-nvvm --split-input-file %s | FileCheck %s +// RUN: mlir-opt --convert-gpu-to-nvvm="allow-pattern-rollback=0" --split-input-file %s | FileCheck %s // RUN: mlir-opt --convert-gpu-to-nvvm="index-bitwidth=32" --split-input-file %s | FileCheck --check-prefix=CHECK32 %s gpu.module @test_module { @@ -81,6 +82,28 @@ gpu.module @test_module { gpu.module @test_module { + // CHECK-LABEL: func @gpu_wmma_f64_load_op() -> + // CHECK-SAME: f64 + // CHECK32-LABEL: func @gpu_wmma_f64_load_op() -> + func.func @gpu_wmma_f64_load_op() -> (!gpu.mma_matrix<8x4xf64, "AOp">) { + %wg = memref.alloca() {alignment = 32} : memref<32x32xf64, 3> + %i = arith.constant 16 : index + %j = arith.constant 16 : index + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf64, 3> -> !gpu.mma_matrix<8x4xf64, "AOp"> + return %0 : !gpu.mma_matrix<8x4xf64, "AOp"> + // CHECK: %[[MUL:.*]] = llvm.mul %{{.*}}, %{{.*}} : i64 + // CHECK: %[[ADD:.*]] = llvm.add %[[MUL]], %{{.*}} : i64 + // CHECK: %[[GEP:.*]] = llvm.getelementptr %{{.*}}[%[[ADD]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f64 + // CHECK: %[[C32_I32:.*]] = llvm.mlir.constant(32 : index) : i32 + // CHECK: %[[LOAD:.*]] = nvvm.wmma.load %[[GEP]], %[[C32_I32]] {eltype = #nvvm.mma_type<f64>, frag = #nvvm.mma_frag<a>, k = 4 : i32, layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32} : (!llvm.ptr<3>) -> f64 + // CHECK: llvm.return %[[LOAD]] : f64 + } +} + +// ----- + +gpu.module @test_module { + // CHECK-LABEL: func @gpu_wmma_store_op // CHECK-SAME: (%[[D:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>) // CHECK32-LABEL: func @gpu_wmma_store_op diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir index c1627a0..19e1c7a 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-copy.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP -// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP +// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefixes=CPP,CHECK +// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefixes=NOCPP,CHECK func.func @alloc_copy(%arg0: memref<999xi32>) { %alloc = memref.alloc() : memref<999xi32> @@ -9,42 +9,46 @@ func.func @alloc_copy(%arg0: memref<999xi32>) { return } -// CHECK: module { // NOCPP: emitc.include <"stdlib.h"> // NOCPP-NEXT: emitc.include <"string.h"> // CPP: emitc.include <"cstdlib"> // CPP-NEXT: emitc.include <"cstring"> -// CHECK-LABEL: alloc_copy -// CHECK-SAME: %[[arg0:.*]]: memref<999xi32> -// CHECK-NEXT: builtin.unrealized_conversion_cast %arg0 : memref<999xi32> to !emitc.array<999xi32> -// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t -// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index -// CHECK-NEXT: emitc.mul %1, %2 : (!emitc.size_t, index) -> !emitc.size_t -// CHECK-NEXT: emitc.call_opaque "malloc"(%3) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> -// CHECK-NEXT: emitc.cast %4 : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32> -// CHECK-NEXT: builtin.unrealized_conversion_cast %5 : !emitc.ptr<i32> to !emitc.array<999xi32> -// CHECK-NEXT: "emitc.constant"() <{value = 0 : index}> : () -> index -// CHECK-NEXT: emitc.subscript %0[%7] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32> -// CHECK-NEXT: emitc.apply "&"(%8) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32> -// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t -// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index -// CHECK-NEXT: emitc.mul %12, %13 : (!emitc.size_t, index) -> !emitc.size_t -// CHECK-NEXT: emitc.call_opaque "memcpy"(%11, %9, %14) : (!emitc.ptr<i32>, !emitc.ptr<i32>, !emitc.size_t) -> () -// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t -// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index -// CHECK-NEXT: emitc.mul %15, %16 : (!emitc.size_t, index) -> !emitc.size_t -// CHECK-NEXT: emitc.call_opaque "malloc"(%17) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> -// CHECK-NEXT: emitc.cast %18 : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32> -// CHECK-NEXT: builtin.unrealized_conversion_cast %19 : !emitc.ptr<i32> to !emitc.array<999xi32> -// CHECK-NEXT: "emitc.constant"() <{value = 0 : index}> : () -> index -// CHECK-NEXT: emitc.subscript %0[%21] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32> -// CHECK-NEXT: emitc.apply "&"(%22) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32> -// CHECK-NEXT: emitc.subscript %20[%21] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32> -// CHECK-NEXT: emitc.apply "&"(%24) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32> -// CHECK-NEXT: emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t -// CHECK-NEXT: "emitc.constant"() <{value = 999 : index}> : () -> index -// CHECK-NEXT: emitc.mul %26, %27 : (!emitc.size_t, index) -> !emitc.size_t -// CHECK-NEXT: emitc.call_opaque "memcpy"(%25, %23, %28) : (!emitc.ptr<i32>, !emitc.ptr<i32>, !emitc.size_t) -> () -// CHECK-NEXT: return +// CHECK-LABEL: func.func @alloc_copy( +// CHECK-SAME: %[[ARG0:.*]]: memref<999xi32>) { +// CHECK: %[[UNREALIZED_CONVERSION_CAST_0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<999xi32> to !emitc.array<999xi32> +// CHECK: %[[CALL_OPAQUE_0:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t +// CHECK: %[[VAL_0:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index +// CHECK: %[[MUL_0:.*]] = emitc.mul %[[CALL_OPAQUE_0]], %[[VAL_0]] : (!emitc.size_t, index) -> !emitc.size_t +// CHECK: %[[CALL_OPAQUE_1:.*]] = emitc.call_opaque "malloc"(%[[MUL_0]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> +// CHECK: %[[CAST_0:.*]] = emitc.cast %[[CALL_OPAQUE_1]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32> +// CHECK: %[[UNREALIZED_CONVERSION_CAST_1:.*]] = builtin.unrealized_conversion_cast %[[CAST_0]] : !emitc.ptr<i32> to !emitc.array<999xi32> +// CHECK: %[[VAL_1:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index +// CHECK: %[[SUBSCRIPT_0:.*]] = emitc.subscript %[[UNREALIZED_CONVERSION_CAST_0]]{{\[}}%[[VAL_1]]] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32> +// CHECK: %[[ADDRESS_OF_0:.*]] = emitc.address_of %[[SUBSCRIPT_0]] : !emitc.lvalue<i32> +// CHECK: %[[VAL_2:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index +// CHECK: %[[SUBSCRIPT_1:.*]] = emitc.subscript %[[UNREALIZED_CONVERSION_CAST_1]]{{\[}}%[[VAL_2]]] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32> +// CHECK: %[[ADDRESS_OF_1:.*]] = emitc.address_of %[[SUBSCRIPT_1]] : !emitc.lvalue<i32> +// CHECK: %[[CALL_OPAQUE_2:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t +// CHECK: %[[VAL_3:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index +// CHECK: %[[MUL_1:.*]] = emitc.mul %[[CALL_OPAQUE_2]], %[[VAL_3]] : (!emitc.size_t, index) -> !emitc.size_t +// CHECK: emitc.call_opaque "memcpy"(%[[ADDRESS_OF_1]], %[[ADDRESS_OF_0]], %[[MUL_1]]) : (!emitc.ptr<i32>, !emitc.ptr<i32>, !emitc.size_t) -> () +// CHECK: %[[CALL_OPAQUE_3:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t +// CHECK: %[[VAL_4:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index +// CHECK: %[[MUL_2:.*]] = emitc.mul %[[CALL_OPAQUE_3]], %[[VAL_4]] : (!emitc.size_t, index) -> !emitc.size_t +// CHECK: %[[CALL_OPAQUE_4:.*]] = emitc.call_opaque "malloc"(%[[MUL_2]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> +// CHECK: %[[CAST_1:.*]] = emitc.cast %[[CALL_OPAQUE_4]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32> +// CHECK: %[[UNREALIZED_CONVERSION_CAST_2:.*]] = builtin.unrealized_conversion_cast %[[CAST_1]] : !emitc.ptr<i32> to !emitc.array<999xi32> +// CHECK: %[[VAL_5:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index +// CHECK: %[[SUBSCRIPT_2:.*]] = emitc.subscript %[[UNREALIZED_CONVERSION_CAST_0]]{{\[}}%[[VAL_5]]] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32> +// CHECK: %[[ADDRESS_OF_2:.*]] = emitc.address_of %[[SUBSCRIPT_2]] : !emitc.lvalue<i32> +// CHECK: %[[VAL_6:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index +// CHECK: %[[SUBSCRIPT_3:.*]] = emitc.subscript %[[UNREALIZED_CONVERSION_CAST_2]]{{\[}}%[[VAL_6]]] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32> +// CHECK: %[[ADDRESS_OF_3:.*]] = emitc.address_of %[[SUBSCRIPT_3]] : !emitc.lvalue<i32> +// CHECK: %[[CALL_OPAQUE_5:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t +// CHECK: %[[VAL_7:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index +// CHECK: %[[MUL_3:.*]] = emitc.mul %[[CALL_OPAQUE_5]], %[[VAL_7]] : (!emitc.size_t, index) -> !emitc.size_t +// CHECK: emitc.call_opaque "memcpy"(%[[ADDRESS_OF_3]], %[[ADDRESS_OF_2]], %[[MUL_3]]) : (!emitc.ptr<i32>, !emitc.ptr<i32>, !emitc.size_t) -> () +// CHECK: return +// CHECK: } diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir index d151d1b..3de2d25 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP -// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP +// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefixes=CPP,CHECK +// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefixes=NOCPP,CHECK func.func @copying(%arg0 : memref<9x4x5x7xf32>, %arg1 : memref<9x4x5x7xf32>) { memref.copy %arg0, %arg1 : memref<9x4x5x7xf32> to memref<9x4x5x7xf32> @@ -10,20 +10,21 @@ func.func @copying(%arg0 : memref<9x4x5x7xf32>, %arg1 : memref<9x4x5x7xf32>) { // NOCPP: emitc.include <"string.h"> // CPP: emitc.include <"cstring"> -// CHECK-LABEL: copying -// CHECK-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32> -// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32> -// CHECK-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32> -// CHECK-NEXT: %2 = "emitc.constant"() <{value = 0 : index}> : () -> index -// CHECK-NEXT: %3 = emitc.subscript %1[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32> -// CHECK-NEXT: %4 = emitc.apply "&"(%3) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32> -// CHECK-NEXT: %5 = emitc.subscript %0[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32> -// CHECK-NEXT: %6 = emitc.apply "&"(%5) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32> -// CHECK-NEXT: %7 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t -// CHECK-NEXT: %8 = "emitc.constant"() <{value = 1260 : index}> : () -> index -// CHECK-NEXT: %9 = emitc.mul %7, %8 : (!emitc.size_t, index) -> !emitc.size_t -// CHECK-NEXT: emitc.call_opaque "memcpy"(%6, %4, %9) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> () -// CHECK-NEXT: return -// CHECK-NEXT: } -// CHECK-NEXT:} +// CHECK-LABEL: func.func @copying( +// CHECK-SAME: %[[ARG0:.*]]: memref<9x4x5x7xf32>, +// CHECK-SAME: %[[ARG1:.*]]: memref<9x4x5x7xf32>) { +// CHECK: %[[UNREALIZED_CONVERSION_CAST_0:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32> +// CHECK: %[[UNREALIZED_CONVERSION_CAST_1:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32> +// CHECK: %[[VAL_0:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index +// CHECK: %[[SUBSCRIPT_0:.*]] = emitc.subscript %[[UNREALIZED_CONVERSION_CAST_1]]{{\[}}%[[VAL_0]], %[[VAL_0]], %[[VAL_0]], %[[VAL_0]]] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32> +// CHECK: %[[ADDRESS_OF_0:.*]] = emitc.address_of %[[SUBSCRIPT_0]] : !emitc.lvalue<f32> +// CHECK: %[[VAL_1:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index +// CHECK: %[[SUBSCRIPT_1:.*]] = emitc.subscript %[[UNREALIZED_CONVERSION_CAST_0]]{{\[}}%[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %[[VAL_1]]] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32> +// CHECK: %[[ADDRESS_OF_1:.*]] = emitc.address_of %[[SUBSCRIPT_1]] : !emitc.lvalue<f32> +// CHECK: %[[CALL_OPAQUE_0:.*]] = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t +// CHECK: %[[VAL_2:.*]] = "emitc.constant"() <{value = 1260 : index}> : () -> index +// CHECK: %[[MUL_0:.*]] = emitc.mul %[[CALL_OPAQUE_0]], %[[VAL_2]] : (!emitc.size_t, index) -> !emitc.size_t +// CHECK: emitc.call_opaque "memcpy"(%[[ADDRESS_OF_1]], %[[ADDRESS_OF_0]], %[[MUL_0]]) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> () +// CHECK: return +// CHECK: } diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir index 2b4eda3..c7b043b 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir @@ -53,7 +53,7 @@ module @globals { // CHECK-NEXT: emitc.get_global @public_global : !emitc.array<3x7xf32> %0 = memref.get_global @public_global : memref<3x7xf32> // CHECK-NEXT: emitc.get_global @__constant_xi32 : !emitc.lvalue<i32> - // CHECK-NEXT: emitc.apply "&"(%1) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32> + // CHECK-NEXT: emitc.address_of %1 : !emitc.lvalue<i32> %1 = memref.get_global @__constant_xi32 : memref<i32> return } diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir index dcf4ddb..0eb4478 100644 --- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -603,14 +603,14 @@ func.func @mbarrier_txcount() { %txcount = arith.constant 256 : index // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 - // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]] + // CHECK: nvvm.mbarrier.arrive.expect_tx %[[barPtr2]] nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType scf.yield } else { %txcount = arith.constant 0 : index // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 - // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]] + // CHECK: nvvm.mbarrier.arrive.expect_tx %[[barPtr2]] nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType scf.yield } @@ -620,7 +620,7 @@ func.func @mbarrier_txcount() { %ticks = arith.constant 10000000 : index // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 - // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]] + // CHECK: nvvm.mbarrier.try_wait.parity %[[barPtr3]] nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType func.return @@ -649,14 +649,14 @@ func.func @mbarrier_txcount_pred() { %txcount = arith.constant 256 : index // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 - // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]], {{.*}}, predicate = %[[P]] + // CHECK: nvvm.mbarrier.arrive.expect_tx %[[barPtr2]], {{.*}}, predicate = %[[P]] nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount, predicate = %pred : !barrierType %phase_c0 = arith.constant 0 : i1 %ticks = arith.constant 10000000 : index // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 - // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]] + // CHECK: nvvm.mbarrier.try_wait.parity %[[barPtr3]] nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType func.return diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir index a9356c5..8fb36ac 100644 --- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir @@ -16,17 +16,13 @@ llvm.func @init_mbarrier(%barrier_gen : !llvm.ptr, %barrier : !llvm.ptr<3>, %cou // CHECK-LABEL: @init_mbarrier_arrive_expect_tx llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32, %pred : i1) { - //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r" - nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32 //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r,b" - nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1 + nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1 llvm.return } // CHECK-LABEL: @init_mbarrier_arrive_expect_tx_generic llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount : i32, %pred : i1) { - // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 _, [$0], $1;", "l,r" - nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32 // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.arrive.expect_tx.b64 _, [$0], $1;", "l,r,b" nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr, i32, i1 llvm.return @@ -44,7 +40,7 @@ llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %ticks : i32, // CHECK-SAME: DONE: // CHECK-SAME: }", // CHECK-SAME: "r,r,r" - nvvm.mbarrier.try_wait.parity.shared %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32 + nvvm.mbarrier.try_wait.parity %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32 llvm.return } @@ -544,8 +540,8 @@ func.func @elect_one_leader_sync() { // ----- -// CHECK-LABEL: @init_mbarrier_arrive_expect_tx -llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) { +// CHECK-LABEL: @test_nvvm_prefetch +llvm.func @test_nvvm_prefetch(%desc : !llvm.ptr, %pred : i1) { //CHECK: nvvm.prefetch tensormap, %{{.*}} nvvm.prefetch tensormap, %desc : !llvm.ptr //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$1 prefetch.tensormap [$0];", "l,b" @@ -588,29 +584,6 @@ func.func @cp_async_bulk_wait_group() { // ----- -func.func @fence_mbarrier_init() { - //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.mbarrier_init.release.cluster;" - nvvm.fence.mbarrier.init - func.return -} -// ----- - -func.func @fence_proxy() { - //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.alias;", "" : () -> () - nvvm.fence.proxy { kind = #nvvm.proxy_kind<alias>} - //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.async;", "" : () -> () - nvvm.fence.proxy { kind = #nvvm.proxy_kind<async>} - //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.async.global;", "" : () -> () - nvvm.fence.proxy { kind = #nvvm.proxy_kind<async.global>} - //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.async.shared::cta;", "" : () -> () - nvvm.fence.proxy { kind = #nvvm.proxy_kind<async.shared>, space = #nvvm.shared_space<cta>} - //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.async.shared::cluster;", "" : () -> () - nvvm.fence.proxy { kind = #nvvm.proxy_kind<async.shared>, space = #nvvm.shared_space<cluster>} - func.return -} - -// ----- - // CHECK-LABEL: @llvm_nvvm_barrier_arrive // CHECK-SAME: (%[[barId:.*]]: i32, %[[numberOfThreads:.*]]: i32) llvm.func @llvm_nvvm_barrier_arrive(%barID : i32, %numberOfThreads : i32) { diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir index f2fbe91..b122f42 100644 --- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir @@ -615,3 +615,22 @@ omp.declare_mapper @my_mapper : !llvm.struct<"_QFdeclare_mapperTmy_type", (i32)> // CHECK: omp.declare_mapper.info map_entries(%{{.*}}, %{{.*}} : !llvm.ptr, !llvm.ptr) omp.declare_mapper.info map_entries(%3, %2 : !llvm.ptr, !llvm.ptr) } + +// CHECK-LABEL: llvm.func @omp_dist_schedule(%arg0: i32) { +func.func @omp_dist_schedule(%arg0: i32) { + %c1_i32 = arith.constant 1 : i32 + // CHECK: %1 = llvm.mlir.constant(1024 : i32) : i32 + %c1024_i32 = arith.constant 1024 : i32 + %c16_i32 = arith.constant 16 : i32 + %c8_i32 = arith.constant 8 : i32 + omp.teams num_teams( to %c8_i32 : i32) thread_limit(%c16_i32 : i32) { + // CHECK: omp.distribute dist_schedule_static dist_schedule_chunk_size(%1 : i32) { + omp.distribute dist_schedule_static dist_schedule_chunk_size(%c1024_i32 : i32) { + omp.loop_nest (%arg1) : i32 = (%c1_i32) to (%arg0) inclusive step (%c1_i32) { + omp.terminator + } + } + omp.terminator + } + return +} diff --git a/mlir/test/Conversion/OpenMPToLLVM/map-info-type-conversion-fail.mlir b/mlir/test/Conversion/OpenMPToLLVM/map-info-type-conversion-fail.mlir new file mode 100644 index 0000000..3bd9bb4 --- /dev/null +++ b/mlir/test/Conversion/OpenMPToLLVM/map-info-type-conversion-fail.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt -convert-openmp-to-llvm -split-input-file -verify-diagnostics %s + +// Indicates that the TypeConversion has failed for the MPMapInfoOp. +// In this specific case, the `tensor` type (used in a TypeAttr) cannot be converted +// to an LLVM type. This test ensures that the conversion fails gracefully with a +// legalization error instead of crashing. +func.func @fail_map_info_tensor_type(%arg0: memref<?xf32>) { + // expected-error@+1 {{failed to legalize operation 'omp.map.info' that was explicitly marked illegal}} + %map_info = omp.map.info var_ptr(%arg0: memref<?xf32>, tensor<?xf32>) map_clauses(to) capture(ByRef) -> memref<?xf32> + omp.target_update map_entries(%map_info: memref<?xf32>) { + omp.terminator + } + return +} diff --git a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir index 483c7b3..0c4f20e 100644 --- a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir +++ b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf -split-input-file %s | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf="allow-pattern-rollback=0" -split-input-file %s | FileCheck %s // CHECK-LABEL: func @simple_std_for_loop(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) { // CHECK-NEXT: cf.br ^bb1(%{{.*}} : index) diff --git a/mlir/test/Conversion/SCFToGPU/parallel_loop.mlir b/mlir/test/Conversion/SCFToGPU/parallel_loop.mlir index 26f5a3e..2f192df 100644 --- a/mlir/test/Conversion/SCFToGPU/parallel_loop.mlir +++ b/mlir/test/Conversion/SCFToGPU/parallel_loop.mlir @@ -673,3 +673,51 @@ func.func @nested_parallel_with_side_effect() { // CHECK: gpu.launch // CHECK-NOT: scf.parallel + +// ----- + +func.func @scf2gpu_index_creation_2d() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + + // Single 2-D scf.parallel mapped to block_x and thread_x. + // Use both IVs so the conversion must compute indices. + scf.parallel (%bx, %tx) = (%c0, %c0) to (%c32, %c32) step (%c1, %c1) { + %u = arith.addi %bx, %c0 : index + %v = arith.addi %tx, %c0 : index + } { + mapping = [ + #gpu.loop_dim_map<processor = block_x, map = (d0) -> (d0), bound = (d0) -> (d0)>, + #gpu.loop_dim_map<processor = thread_x, map = (d0) -> (d0), bound = (d0) -> (d0)> + ] + } + return +} + +// CHECK-LABEL: func @scf2gpu_index_creation_2d +// CHECK: gpu.launch +// CHECK: %[[IDX:.*]] = affine.apply +// CHECK: arith.addi %[[IDX]], + +// ----- + +func.func @scf2gpu_index_creation_1d() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + + scf.parallel (%t) = (%c0) to (%c64) step (%c1) { + %w = arith.addi %t, %c0 : index + } { + mapping = [ + #gpu.loop_dim_map<processor = thread_x, map = (d0) -> (d0), bound = (d0) -> (d0)> + ] + } + return +} + +// CHECK-LABEL: func @scf2gpu_index_creation_1d +// CHECK: gpu.launch +// CHECK: %[[IDX:.*]] = affine.apply +// CHECK: arith.addi %[[IDX]], diff --git a/mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir index e1936e2..b17e1c4 100644 --- a/mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir @@ -162,9 +162,7 @@ spirv.func @sqrt(%arg0: f32, %arg1: vector<3xf16>) "None" { // CHECK-LABEL: @tan spirv.func @tan(%arg0: f32) "None" { - // CHECK: %[[SIN:.*]] = llvm.intr.sin(%{{.*}}) : (f32) -> f32 - // CHECK: %[[COS:.*]] = llvm.intr.cos(%{{.*}}) : (f32) -> f32 - // CHECK: llvm.fdiv %[[SIN]], %[[COS]] : f32 + // CHECK: llvm.intr.tan(%{{.*}}) : (f32) -> f32 %0 = spirv.GL.Tan %arg0 : f32 spirv.Return } @@ -175,13 +173,7 @@ spirv.func @tan(%arg0: f32) "None" { // CHECK-LABEL: @tanh spirv.func @tanh(%arg0: f32) "None" { - // CHECK: %[[TWO:.*]] = llvm.mlir.constant(2.000000e+00 : f32) : f32 - // CHECK: %[[X2:.*]] = llvm.fmul %[[TWO]], %{{.*}} : f32 - // CHECK: %[[EXP:.*]] = llvm.intr.exp(%[[X2]]) : (f32) -> f32 - // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 - // CHECK: %[[T0:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : f32 - // CHECK: %[[T1:.*]] = llvm.fadd %[[EXP]], %[[ONE]] : f32 - // CHECK: llvm.fdiv %[[T0]], %[[T1]] : f32 + // CHECK: llvm.intr.tanh(%{{.*}}) : (f32) -> f32 %0 = spirv.GL.Tanh %arg0 : f32 spirv.Return } diff --git a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir index b69c2d0..65c6e05 100644 --- a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir +++ b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir @@ -79,3 +79,12 @@ func.func @tensor_2d_empty() -> () { %x = arith.constant dense<> : tensor<2x0xi32> return } + +// Tensors with more than UINT32_MAX elements cannnot fit in a spirv.array. +// Test that they are not lowered. +// CHECK-LABEL: func @very_large_tensor +// CHECK-NEXT: arith.constant dense<1> +func.func @very_large_tensor() -> () { + %x = arith.constant dense<1> : tensor<4294967296xi32> + return +} diff --git a/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir b/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir index 6c0b111..0fe63f5 100644 --- a/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir +++ b/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir @@ -17,3 +17,9 @@ func.func @check_poison() { %3 = ub.poison : !llvm.ptr return } + +// CHECK-LABEL: @check_unrechable +func.func @check_unrechable() { +// CHECK: llvm.unreachable + ub.unreachable +} diff --git a/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir b/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir index f497eb3..9c277cf 100644 --- a/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir +++ b/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -split-input-file -convert-ub-to-spirv -verify-diagnostics %s | FileCheck %s +// RUN: mlir-opt -split-input-file -convert-ub-to-spirv %s | FileCheck %s module attributes { spirv.target_env = #spirv.target_env< @@ -19,3 +19,20 @@ func.func @check_poison() { } } + +// ----- + +module attributes { + spirv.target_env = #spirv.target_env< + #spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, #spirv.resource_limits<>> +} { +// CHECK-LABEL: @check_unrechable +func.func @check_unrechable(%c: i1) { + cf.cond_br %c, ^bb1, ^bb2 +^bb1: +// CHECK: spirv.Unreachable + ub.unreachable +^bb2: + return +} +} diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir index c87a530..8bb272b 100644 --- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir +++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir @@ -11,14 +11,15 @@ gpu.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector // LOAD-ND-LABEL: @load_1D_vector( // LOAD-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>, -// LOAD-ND-SAME: %[[OFFSET:.+]]: index -// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0] -// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc -// LOAD-ND-SAME: %[[COLLAPSED]] -// LOAD-ND-SAME: memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32, -// LOAD-ND-SAME: boundary_check = false -// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]]]{{.*}}-> vector<8xf32> -// LOAD-ND: return %[[VEC]] +// LOAD-ND: %[[CST:.+]] = arith.constant dense<true> : vector<8xi1> +// LOAD-ND: %[[STEP:.+]] = vector.step : vector<8xindex> +// LOAD-ND-COUNT2: arith.muli {{.*}} : index +// LOAD-ND-COUNT2: arith.addi {{.*}} : index +// LOAD-ND: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex> +// LOAD-ND: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex> +// LOAD-ND: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index +// LOAD-ND: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// LOAD-ND: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32> // LOAD-GATHER-LABEL: @load_1D_vector( // LOAD-GATHER-SAME: %[[SRC:.+]]: memref<8x16x32xf32>, @@ -404,7 +405,7 @@ gpu.func @no_load_unsupported_map(%source: memref<16x32x64xf32>, // ----- gpu.module @xevm_module { -gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2: index) -> vector<8xf16> { +gpu.func @load_from_subview_1D(%source: memref<4096x4096xf16>, %off1: index, %off2: index) -> vector<8xf16> { %c0 = arith.constant 0.0 : f16 %subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> %0 = vector.transfer_read %subview[%off2, %off2], %c0 @@ -412,19 +413,23 @@ gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2: gpu.return %0 : vector<8xf16> } -// LOAD-ND-LABEL: @load_from_subview( +// LOAD-ND-LABEL: @load_from_subview_1D( // LOAD-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>, // LOAD-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index +// LOAD-ND: %[[CST:.+]] = arith.constant dense<true> : vector<8xi1> // LOAD-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> -// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SUBVIEW]][%[[OFF2]], 0] -// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc -// LOAD-ND-SAME: %[[COLLAPSED]] -// LOAD-ND-SAME: memref<256xf16, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf16, -// LOAD-ND-SAME: boundary_check = false -// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF2]]]{{.*}}-> vector<8xf16> -// LOAD-ND: return %[[VEC]] - -// LOAD-GATHER-LABEL: @load_from_subview( +// LOAD-ND: %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index +// LOAD-ND: %[[STEP:.+]] = vector.step : vector<8xindex> +// LOAD-ND: arith.muli {{.*}} : index +// LOAD-ND: arith.addi %[[OFFSET]]{{.*}} : index +// LOAD-ND: arith.addi {{.*}} : index +// LOAD-ND: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex> +// LOAD-ND: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex> +// LOAD-ND: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index +// LOAD-ND: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// LOAD-ND: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16> + +// LOAD-GATHER-LABEL: @load_from_subview_1D( // LOAD-GATHER-SAME: %[[SRC:.+]]: memref<4096x4096xf16>, // LOAD-GATHER-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index // LOAD-GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<8xi1> @@ -440,3 +445,42 @@ gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2: // LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 // LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16> } + +// ----- +gpu.module @xevm_module { +gpu.func @load_from_subview_2D(%source: memref<4096x4096xf16>, %off1: index, %off2: index) -> vector<8x16xf16> { + %c0 = arith.constant 0.0 : f16 + %subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> + %0 = vector.transfer_read %subview[%off2, %off2], %c0 + {in_bounds = [true, true]} : memref<256x256xf16, strided<[4096, 1], offset: ?>>, vector<8x16xf16> + gpu.return %0 : vector<8x16xf16> +} + +// LOAD-ND-LABEL: @load_from_subview_2D( +// LOAD-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>, +// LOAD-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index +// LOAD-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> +// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc +// LOAD-ND-SAME: %[[SUBVIEW]] +// LOAD-ND-SAME: memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf16, +// LOAD-ND-SAME: boundary_check = false +// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF2]], %[[OFF2]]]{{.*}}-> vector<8x16xf16> +// LOAD-ND: return %[[VEC]] + +// LOAD-GATHER-LABEL: @load_from_subview_2D( +// LOAD-GATHER-SAME: %[[SRC:.+]]: memref<4096x4096xf16>, +// LOAD-GATHER-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index +// LOAD-GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1> +// LOAD-GATHER: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> +// LOAD-GATHER: %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index +// LOAD-GATHER-COUNT2: vector.step +// LOAD-GATHER-COUNT2: vector.shape_cast +// LOAD-GATHER-COUNT2: vector.broadcast +// LOAD-GATHER-COUNT2: arith.muli {{.*}} : index +// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index +// LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex> +// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} : vector<8x16xindex> +// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index +// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64 +// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16> +} diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir index 09ef76c..9a1e2cb 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir @@ -7,42 +7,41 @@ gpu.module @create_nd_tdesc { // CHECK-SAME: %[[DYN:.*]]: memref<?x?xf16>) kernel { gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index, %stride1: index, %stride2: index, %offset1: index, %offset2: index, %dyn: memref<?x?xf16>) kernel { + // CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index + // CHECK: %[[DYN_ADDR:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64 // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index // CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64 // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32> - // CHECK: %[[OFFSET_W:.*]] = arith.constant 0 : i32 - // CHECK: %[[OFFSET_H:.*]] = arith.constant 0 : i32 // CHECK: %[[SHAPE_W:.*]] = arith.index_cast %[[ARG3]] : index to i32 // CHECK: %[[SHAPE_H:.*]] = arith.index_cast %[[ARG2]] : index to i32 + // CHECK: %[[PITCH:.*]] = arith.index_cast %[[ARG4]] : index to i32 // CHECK: %[[VAR6:.*]] = vector.bitcast %[[CST]] : vector<8xi32> to vector<4xi64> // CHECK: %[[VAR7:.*]] = vector.insert %[[BASE_ADDR]], %[[VAR6]] [0] : i64 into vector<4xi64> // CHECK: %[[VAR8:.*]] = vector.bitcast %[[VAR7]] : vector<4xi64> to vector<8xi32> // CHECK: %[[VAR9:.*]] = vector.insert %[[SHAPE_W]], %[[VAR8]] [2] : i32 into vector<8xi32> // CHECK: %[[VAR10:.*]] = vector.insert %[[SHAPE_H]], %[[VAR9]] [3] : i32 into vector<8xi32> - // CHECK: %[[VAR11:.*]] = vector.insert %[[OFFSET_W]], %[[VAR10]] [4] : i32 into vector<8xi32> - // CHECK: %[[VAR12:.*]] = vector.insert %[[OFFSET_H]], %[[VAR11]] [5] : i32 into vector<8xi32> + // CHECK: %[[VAR11:.*]] = vector.insert %[[PITCH]], %[[VAR10]] [4] : i32 into vector<8xi32> %ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1, %shape2], strides:[%stride1, %stride2] : ui64 -> !xegpu.tensor_desc<8x16xf32> // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<16x32xf32, 1> to memref<16x32xf32> %srcce = memref.memory_space_cast %src : memref<16x32xf32, 1> to memref<16x32xf32> - // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32> // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index - // CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32 - // CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32 + // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64 + // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32> // CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64 // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32 // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64 // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32 - // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64 + // CHECK: %[[C32_I64_2:.*]] = arith.constant 32 : i64 + // CHECK: %[[PITCH2:.*]] = arith.trunci %[[C32_I64_2]] : i64 to i32 // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64> // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64> // CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32> // CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32> // CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32> - // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32> - // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32> + // CHECK: %[[VAR19:.*]] = vector.insert %[[PITCH2]], %[[VAR18]] [4] : i32 into vector<8xi32> %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32> // CHECK: %[[C1:.*]] = arith.constant 1 : index @@ -51,20 +50,16 @@ gpu.module @create_nd_tdesc { %size_x = arith.constant 64 : index // CHECK: %[[C16:.*]] = arith.constant 16 : index %BLOCK_DMODEL = arith.constant 16 : index - // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32> - // CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index - // CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32 - // CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32 - // CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32 - // CHECK: %[[VAR22:.*]] = arith.index_cast %[[C64]] : index to i32 - // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64 - // CHECK: %[[VAR24:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64> - // CHECK: %[[VAR25:.*]] = vector.insert %[[VAR23]], %[[VAR24]] [0] : i64 into vector<4xi64> - // CHECK: %[[VAR26:.*]] = vector.bitcast %[[VAR25]] : vector<4xi64> to vector<8xi32> - // CHECK: %[[VAR27:.*]] = vector.insert %[[VAR21]], %[[VAR26]] [2] : i32 into vector<8xi32> - // CHECK: %[[VAR28:.*]] = vector.insert %[[VAR22]], %[[VAR27]] [3] : i32 into vector<8xi32> - // CHECK: %[[VAR29:.*]] = vector.insert %[[C0_I32_6]], %[[VAR28]] [4] : i32 into vector<8xi32> - // CHECK: %[[VAR30:.*]] = vector.insert %[[C0_I32_7]], %[[VAR29]] [5] : i32 into vector<8xi32> + // CHECK: %[[CST_3:.*]] = arith.constant dense<0> : vector<8xi32> + // CHECK: %[[SHAPE_W3:.*]] = arith.index_cast %[[C16]] : index to i32 + // CHECK: %[[SHAPE_H3:.*]] = arith.index_cast %[[C64]] : index to i32 + // CHECK: %[[PITCH3:.*]] = arith.index_cast %[[C16]] : index to i32 + // CHECK: %[[VAR25:.*]] = vector.bitcast %[[CST_3]] : vector<8xi32> to vector<4xi64> + // CHECK: %[[VAR26:.*]] = vector.insert %[[DYN_ADDR]], %[[VAR25]] [0] : i64 into vector<4xi64> + // CHECK: %[[VAR27:.*]] = vector.bitcast %[[VAR26]] : vector<4xi64> to vector<8xi32> + // CHECK: %[[VAR28:.*]] = vector.insert %[[SHAPE_W3]], %[[VAR27]] [2] : i32 into vector<8xi32> + // CHECK: %[[VAR29:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR28]] [3] : i32 into vector<8xi32> + // CHECK: %[[VAR30:.*]] = vector.insert %[[PITCH3]], %[[VAR29]] [4] : i32 into vector<8xi32> %dyn_tdesc = xegpu.create_nd_tdesc %dyn, shape: [%size_x, %BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<16x16xf16> gpu.return } diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir new file mode 100644 index 0000000..aebec7f --- /dev/null +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir @@ -0,0 +1,36 @@ +// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s + +gpu.module @load_store_check { + // CHECK-LABEL: @load_store( + // CHECK-SAME: %[[SRC:.*]]: memref<512xf32, 1>, %[[DST:.*]]: memref<256xf32, 1> + gpu.func @load_store(%src: memref<512xf32, 1>, %dst: memref<256xf32, 1>) kernel { + // CHECK: %[[C512:.*]] = arith.constant 512 : i64 + // CHECK: %[[C384:.*]] = arith.constant 384 : i64 + + // CHECK: %[[SRCCE:.*]] = memref.memory_space_cast %[[SRC]] : memref<512xf32, 1> to memref<512xf32> + %srcce = memref.memory_space_cast %src : memref<512xf32, 1> to memref<512xf32> + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[SRCCE]] : memref<512xf32> -> index + // CHECK: %[[INTPTR_I64:.*]] = arith.index_castui %[[INTPTR]] : index to i64 + // CHECK: %[[DSTTE:.*]] = memref.memory_space_cast %[[DST]] : memref<256xf32, 1> to memref<256xf32> + %dstte = memref.memory_space_cast %dst : memref<256xf32, 1> to memref<256xf32> + // CHECK: %[[INTPTR1:.*]] = memref.extract_aligned_pointer_as_index %[[DSTTE]] : memref<256xf32> -> index + // CHECK: %[[INTPTR1_I64:.*]] = arith.index_castui %[[INTPTR1]] : index to i64 + + %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<512xf32> -> !xegpu.tensor_desc<32xf32> + // CHECK: %[[ADDR:.*]] = arith.addi %[[INTPTR_I64]], %[[C384]] : i64 + // CHECK: %[[PTR:.*]] = llvm.inttoptr %[[ADDR]] : i64 to !llvm.ptr<1> + // CHECK: %[[LOAD:.*]] = xevm.blockload %[[PTR]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}> + // CHECK-SAME: : (!llvm.ptr<1>) -> vector<2xi32> + %loaded = xegpu.load_nd %src_tdesc[96] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> + : !xegpu.tensor_desc<32xf32> -> vector<2xf32> + + %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<256xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.block_tdesc_attr<memory_space = global>> + // CHECK: %[[ADDR1:.*]] = arith.addi %[[INTPTR1_I64]], %[[C512]] : i64 + // CHECK: %[[PTR1:.*]] = llvm.inttoptr %[[ADDR1]] : i64 to !llvm.ptr<1> + // CHECK: xevm.blockstore %[[PTR1]], %[[LOAD]] <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}> + // CHECK-SAME: : (!llvm.ptr<1>, vector<2xi32>) + xegpu.store_nd %loaded, %dst_tdesc[128] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> + : vector<2xf32>, !xegpu.tensor_desc<32xf32, #xegpu.block_tdesc_attr<memory_space = global>> + gpu.return + } +} diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir index d4cb493..3a3769f 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir @@ -4,8 +4,8 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] { // e.g. for mem_desc<32x32xf16, @strides=[1, 16]> // its memory layout tuple is (blocked shape = [1,1,32,32],strides=[1024,1024,32,1]) - //CHECK-LABEL: load_store_matrix_1 - gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> f32 { + //CHECK-LABEL: load_store_matrix_plain + gpu.func @load_store_matrix_plain(%arg0: memref<4096xi8, 3>) -> f32 { %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32> //CHECK: %[[TID:.*]] = gpu.thread_id x @@ -26,12 +26,40 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] { gpu.return %1: f32 } + //CHECK-LABEL: load_store_matrix_plain_2d_input + gpu.func @load_store_matrix_plain_2d_input(%arg0: memref<8192xi8, 3>) -> f32 { + %c0 = arith.constant 0 : index + %view = memref.view %arg0[%c0][]: memref<8192xi8, 3> to memref<64x32xf32, 3> + + %subview = memref.subview %view[32, 0] [32, 32] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> + + %0 = xegpu.create_mem_desc %subview : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> !xegpu.mem_desc<32x32xf32> + + //CHECK: %[[TID:.*]] = gpu.thread_id x + //CHECK: %[[C1:.*]] = arith.constant 1 : index + //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index + //CHECK: %[[C4:.*]] = arith.constant 4 : i32 + //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32 + //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32 + + %tid_x = gpu.thread_id x + + %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32 + + //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3> + + xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index + + gpu.return %1: f32 + } + + // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]> // its memory layout tuple is ([2,4,16,16],[256,512,1,16]) - //CHECK-LABEL: load_store_matrix_2 - gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> f16 { + //CHECK-LABEL: load_store_matrix_blocked_strided + gpu.func @load_store_matrix_blocked_strided(%arg0: memref<4096xi8, 3>) -> f16 { %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>> - //CHECK: %[[c0:.*]] = arith.constant 0 : index + //CHECK: %[[tid_x:.*]] = gpu.thread_id x //CHECK: %[[c13:.*]] = arith.constant 13 : index //CHECK: %[[c16:.*]] = arith.constant 16 : index @@ -39,7 +67,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] { //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index - + //CHECK: %[[c0:.*]] = arith.constant 0 : index //CHECK: %[[c256:.*]] = arith.constant 256 : index //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index @@ -53,39 +81,39 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] { //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> f16 - + %tid_x = gpu.thread_id x %c13 = arith.constant 13 : index %1 = xegpu.load_matrix %0[%c13, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> f16 //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3> - - xegpu.store_matrix %1, %0[%c13, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index + + xegpu.store_matrix %1, %0[%c13, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index gpu.return %1: f16 } // e.g. for mem_desc<32x64xf16, @block=[16, 16]> // its memory layout tuple is ([2,4,16,16],[1024,256,16,1]) - //CHECK-LABEL: load_store_matrix_3 - gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> f16 { - //CHECK: %[[c0:.*]] = arith.constant 0 : index - //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3> + //CHECK-LABEL: load_store_matrix_blocked_nostride + gpu.func @load_store_matrix_blocked_nostride(%arg0: memref<4096xi8, 3>) -> f16 { + + //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index + //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32 %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>> - + //CHECK: %[[tid_x:.*]] = gpu.thread_id x //CHECK: %[[c19:.*]] = arith.constant 19 : index %tid_x = gpu.thread_id x %c19 = arith.constant 19: index - - //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index - //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32 + //CHECK: %[[c16:.*]] = arith.constant 16 : index //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index + //CHECK: %[[c0:.*]] = arith.constant 0 : index //CHECK: %[[c1024:.*]] = arith.constant 1024 : index //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index @@ -97,32 +125,29 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] { //CHECK: %[[c1:.*]] = arith.constant 1 : index //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index - //CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16 %1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> f16 - + //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3> xegpu.store_matrix %1, %0[%c19, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index - + //CHECK: gpu.return %[[loaded]] : f16 gpu.return %1: f16 } // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 16]> // its memory layout tuple is ([2,4,16,16],[256,512,1,16]) - //CHECK-LABEL: load_store_matrix_4 - gpu.func @load_store_matrix_4(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { + //CHECK-LABEL: load_store_matrix_blocked_strided_return_vector + gpu.func @load_store_matrix_blocked_strided_return_vector(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>> - //CHECK: %[[c0:.*]] = arith.constant 0 : index //CHECK: %[[tid_x:.*]] = gpu.thread_id x - //CHECK: %[[c16:.*]] = arith.constant 16 : index //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index - + //CHECK: %[[c0:.*]] = arith.constant 0 : index //CHECK: %[[c256:.*]] = arith.constant 256 : index //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index @@ -136,7 +161,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] { //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> vector<8xf16> - + %tid_x = gpu.thread_id x %c16 = arith.constant 16 : index %1 = xegpu.load_matrix %0[%c16, %tid_x] : !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<8xf16> @@ -147,28 +172,26 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] { gpu.return %1: vector<8xf16> } - + // e.g. for mem_desc<32x64xf16, @block=[16, 16]> // its memory layout tuple is ([2,4,16,16],[1024,256,16,1]) - //CHECK-LABEL: load_store_matrix_5 - gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { - //CHECK: %[[c0:.*]] = arith.constant 0 : index - //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3> - - %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>> - + //CHECK-LABEL: load_store_matrix_blocked_subgroupblockio + gpu.func @load_store_matrix_blocked_subgroupblockio(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { + + //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index + //CHECK: %[[basePtrI32:.*]] = arith.index_castui %[[intptr]] : index to i32 + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>> + //CHECK: %[[c16:.*]] = arith.constant 16 : index //CHECK: %[[c48:.*]] = arith.constant 48 : index - %c16 = arith.constant 16 : index %c48 = arith.constant 48 : index - //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index - //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32 //CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index //CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index //CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index //CHECK: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index + //CHECK: %[[c0:.*]] = arith.constant 0 : index //CHECK: %[[c1024:.*]] = arith.constant 1024 : index //CHECK: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index @@ -183,7 +206,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] { //CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i32 //CHECK: %[[c2:.*]] = arith.constant 2 : i32 //CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i32 - //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI64]], %[[byteOffset]] : i32 + //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI32]], %[[byteOffset]] : i32 //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i32 to !llvm.ptr<3> //CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16> //CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16> @@ -191,11 +214,22 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] { %1 = xegpu.load_matrix %0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> vector<8xf16> //CHECK: %[[storeDataI16:.*]] = vector.bitcast %[[loaded]] : vector<8xf16> to vector<8xi16> - //CHECK: xevm.blockstore %[[ptr]], %[[storeDataI16]] : (!llvm.ptr<3>, vector<8xi16>) + //CHECK: xevm.blockstore %[[ptr]], %[[storeDataI16]] : (!llvm.ptr<3>, vector<8xi16>) xegpu.store_matrix %1, %0[%c16, %c48] {subgroup_block_io}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index gpu.return %1: vector<8xf16> } + gpu.func @matrix_vector_materialization(%matrixdesc : !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>) { + // CHECK: %[[XEVM_VECTOR:.*]] = llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xf16> + // CHECK: %[[SOURCE_MATERIALIZE:.*]] = vector.shape_cast %[[XEVM_VECTOR]] : vector<16xf16> to vector<1x16xf16> + // CHECK: %[[XEGPU_VECTOR:.*]] = arith.addf %[[SOURCE_MATERIALIZE]], %[[SOURCE_MATERIALIZE]] : vector<1x16xf16> + // CHECK: %[[TARGET_MATERIALIZE:.*]] = vector.shape_cast %[[XEGPU_VECTOR]] : vector<1x16xf16> to vector<16xf16> + // CHECK: llvm.store %[[TARGET_MATERIALIZE]], %{{.*}} : vector<16xf16>, !llvm.ptr<3> + %loaded = xegpu.load_matrix %matrixdesc[16,0] : !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>> -> vector<1x16xf16> + %loaded_2 = arith.addf %loaded, %loaded : vector<1x16xf16> + xegpu.store_matrix %loaded_2, %matrixdesc[16,0] : vector<1x16xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>> + gpu.return + } } diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir index 4c6bbf2..4c73c9c 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir @@ -1,72 +1,32 @@ -// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s +// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s gpu.module @load_store_check { + // CHECK-LABEL: gpu.func @load_store( gpu.func @load_store(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel { + // CHECK: %[[W_P_BYTES:.*]] = arith.constant 64 : i32 + // CHECK: %[[ZERO:.*]] = arith.constant 0 : i32 + // CHECK: %[[H:.*]] = arith.constant 8 : i32 %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32> %dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32> - // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64 - // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64> - // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64> - // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32> - // CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32> - // CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32> - // CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32> - // CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32> %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> - - //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64> - //CHECK: %[[LD_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64> - //CHECK: %[[LD_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32> - //CHECK: %[[LD_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32> - //CHECK: %[[LD_TILE_W64:.*]] = arith.constant 0 : i64 - //CHECK: %[[LD_TILE_W:.*]] = arith.trunci %[[LD_TILE_W64]] : i64 to i32 - //CHECK: %[[LD_TILE_H64:.*]] = arith.constant 0 : i64 - //CHECK: %[[LD_TILE_H:.*]] = arith.trunci %[[LD_TILE_H64]] : i64 to i32 - //CHECK: %[[LD_LLVMPTR:.*]] = llvm.inttoptr %[[LD_INTPTR]] : i64 to !llvm.ptr<1> - //CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32 - //CHECK: %[[LD_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[LD_BASE_W]], %[[LD_SIZEOF_F32]] : i32 - //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %[[LD_LLVMPTR]], %[[LD_BASE_ROW_IN_BYTES]], - //CHECK-SAME: %[[LD_BASE_H]], %[[LD_BASE_ROW_IN_BYTES]], %[[LD_TILE_W]], %[[LD_TILE_H]] + //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %{{.*}}, %[[W_P_BYTES]], %[[H]], %[[W_P_BYTES]], %[[ZERO]], %[[ZERO]] //CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32, //CHECK-SAME: pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false, //CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32> %loaded = xegpu.load_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32> - //CHECK: %[[LD_LOADED_F32:.*]] = vector.bitcast %[[LD_LOADED_I32]] : vector<8xi32> to vector<8xf32> %tid_x = gpu.thread_id x %tid_x_i32 = arith.index_cast %tid_x : index to i32 %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32 - //CHECK: %[[LOADED_F32_MODIFIED:.*]] = vector.insert %{{.*}}, %[[LD_LOADED_F32]] [0] : f32 into vector<8xf32> %loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32> - // CHECK: %[[PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64 - // CHECK: %[[CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64> - // CHECK: %[[DESC_0:.*]] = vector.insert %[[PTR_AS_I64]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64> - // CHECK: %[[DESC_1:.*]] = vector.bitcast %[[DESC_0]] : vector<4xi64> to vector<8xi32> - // CHECK: %[[DESC_2:.*]] = vector.insert {{.*}}, %[[DESC_1]] [2] : i32 into vector<8xi32> - // CHECK: %[[DESC_3:.*]] = vector.insert {{.*}}, %[[DESC_2]] [3] : i32 into vector<8xi32> - // CHECK: %[[DESC_4:.*]] = vector.insert {{.*}}, %[[DESC_3]] [4] : i32 into vector<8xi32> - // CHECK: %[[DESC:.*]] = vector.insert {{.*}}, %[[DESC_4]] [5] : i32 into vector<8xi32> %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>> - //CHECK: %[[DESC_I64:.*]] = vector.bitcast %[[DESC]] : vector<8xi32> to vector<4xi64> - //CHECK: %[[INTPTR:.*]] = vector.extract %[[DESC_I64]][0] : i64 from vector<4xi64> - //CHECK: %[[BASE_W:.*]] = vector.extract %[[DESC]][2] : i32 from vector<8xi32> - //CHECK: %[[BASE_H:.*]] = vector.extract %[[DESC]][3] : i32 from vector<8xi32> - //CHECK: %[[TILE_W64:.*]] = arith.constant 0 : i64 - //CHECK: %[[TILE_W:.*]] = arith.trunci %[[TILE_W64]] : i64 to i32 - //CHECK: %[[TILE_H64:.*]] = arith.constant 0 : i64 - //CHECK: %[[TILE_H:.*]] = arith.trunci %[[TILE_H64]] : i64 to i32 - //CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %[[INTPTR]] : i64 to !llvm.ptr<1> - //CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32 - //CHECK: %[[BASE_ROW_IN_BYTES:.*]] = arith.muli %[[BASE_W]], %[[SIZEOF_F32]] : i32 - //CHECK: %[[FLAT_VALUE_I32:.*]] = vector.bitcast %[[LOADED_F32_MODIFIED]] : vector<8xf32> to vector<8xi32> - //CHECK: xevm.blockstore2d %[[LLVMPTR]], %[[BASE_ROW_IN_BYTES]], %[[BASE_H]], %[[BASE_ROW_IN_BYTES]], - //CHECK-SAME: %[[TILE_W]], %[[TILE_H]], %[[FLAT_VALUE_I32]] - //CHECK-SAME: <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32, + //CHECK: xevm.blockstore2d %{{.*}}, %[[W_P_BYTES]], %[[H]], %[[W_P_BYTES]], %[[ZERO]], %[[ZERO]], %{{.*}} <{ + //CHECK-SAME: cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32, //CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>) xegpu.store_nd %loaded_modified, %dst_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>> diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_sub_byte.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_sub_byte.mlir new file mode 100644 index 0000000..97e5ce1 --- /dev/null +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_sub_byte.mlir @@ -0,0 +1,80 @@ +// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s + +gpu.module @load_store_check { + // CHECK-LABEL: gpu.func @load_store_matrix_a + // CHECK-SAME: %[[ARG0:.*]]: memref<16x128xi4, 1>, %[[ARG1:.*]]: memref<16x128xi4, 1> + gpu.func @load_store_matrix_a(%src: memref<16x128xi4, 1>, %dst: memref<16x128xi4, 1>) kernel { + // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32 + // CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32 + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4xi64> + // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32 + // CHECK: %[[C128_I32:.*]] = arith.constant 128 : i32 + // CHECK: %[[SRCCE:.*]] = memref.memory_space_cast %[[ARG0]] + // CHECK: %[[SRCINDEX:.*]] = memref.extract_aligned_pointer_as_index %[[SRCCE]] + // CHECK: %[[SRCPTR64:.*]] = arith.index_castui %[[SRCINDEX]] : index to i64 + %srcce = memref.memory_space_cast %src : memref<16x128xi4, 1> to memref<16x128xi4> + // CHECK: %[[DSTTE:.*]] = memref.memory_space_cast %[[ARG1]] + // CHECK: %[[DSTINDEX:.*]] = memref.extract_aligned_pointer_as_index %[[DSTTE]] + // CHECK: %[[DSTPTR64:.*]] = arith.index_castui %[[DSTINDEX]] : index to i64 + %dstte = memref.memory_space_cast %dst : memref<16x128xi4, 1> to memref<16x128xi4> + + // CHECK: %[[PAYLOAD_SRC:.*]] = vector.insert %[[SRCPTR64]], %[[CST]] [0] : i64 into vector<4xi64> + // CHECK: %[[BITCAST1_SRC:.*]] = vector.bitcast %[[PAYLOAD_SRC]] : vector<4xi64> to vector<8xi32> + // CHECK: %[[PAYLOAD1_SRC:.*]] = vector.insert %[[C128_I32]], %[[BITCAST1_SRC]] [2] : i32 into vector<8xi32> + // CHECK: %[[PAYLOAD2_SRC:.*]] = vector.insert %[[C16_I32]], %[[PAYLOAD1_SRC]] [3] : i32 into vector<8xi32> + // CHECK: %[[PAYLOAD3_SRC:.*]] = vector.insert %[[C128_I32]], %[[PAYLOAD2_SRC]] [4] : i32 into vector<8xi32> + %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x128xi4> -> !xegpu.tensor_desc<8x64xi4> + + // CHECK: %[[BITCAST2:.*]] = vector.bitcast %[[PAYLOAD3_SRC]] : vector<8xi32> to vector<4xi64> + // CHECK: %[[SRCPTR64:.*]] = vector.extract %[[BITCAST2]][0] : i64 from vector<4xi64> + // CHECK: %[[SRCLLVMPTR:.*]] = llvm.inttoptr %[[SRCPTR64]] : i64 to !llvm.ptr<1> + // CHECK: %[[LOADED:.*]] = xevm.blockload2d %[[SRCLLVMPTR]], %[[C64_I32]], + // CHECK-SAME: %[[C16_I32]], %[[C64_I32]], %[[C16_I32]], %[[C8_I32]] <{ + // CHECK-SAME: cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 16 : i32, + // CHECK-SAME: pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false, + // CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16> + %loaded = xegpu.load_nd %src_tdesc[8, 64] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> + : !xegpu.tensor_desc<8x64xi4> -> vector<32xi4> + + // CHECK: %[[PAYLOAD_DST:.*]] = vector.insert %[[DSTPTR64]], %[[CST]] [0] : i64 into vector<4xi64> + // CHECK: %[[BITCAST1_DST:.*]] = vector.bitcast %[[PAYLOAD_DST]] : vector<4xi64> to vector<8xi32> + // CHECK: %[[PAYLOAD1_DST:.*]] = vector.insert %[[C128_I32]], %[[BITCAST1_DST]] [2] : i32 into vector<8xi32> + // CHECK: %[[PAYLOAD2_DST:.*]] = vector.insert %[[C16_I32]], %[[PAYLOAD1_DST]] [3] : i32 into vector<8xi32> + // CHECK: %[[PAYLOAD3_DST:.*]] = vector.insert %[[C128_I32]], %[[PAYLOAD2_DST]] [4] : i32 into vector<8xi32> + %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<16x128xi4> -> !xegpu.tensor_desc<8x64xi4, #xegpu.block_tdesc_attr<memory_space = global>> + + // CHECK: %[[BITCAST2_DST:.*]] = vector.bitcast %[[PAYLOAD3_DST]] : vector<8xi32> to vector<4xi64> + // CHECK: %[[DSTPTR64:.*]] = vector.extract %[[BITCAST2_DST]][0] : i64 from vector<4xi64> + // CHECK: %[[DSTLLVMPTR:.*]] = llvm.inttoptr %[[DSTPTR64]] : i64 to !llvm.ptr<1> + // CHECK: xevm.blockstore2d %[[DSTLLVMPTR]], %[[C64_I32]], %[[C16_I32]], + // CHECK-SAME: %[[C64_I32]], %[[C16_I32]], %[[C8_I32]], %[[LOADED]] <{ + // CHECK-SAME: cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 16 : i32, + // CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>) + xegpu.store_nd %loaded, %dst_tdesc[8, 64] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> + : vector<32xi4>, !xegpu.tensor_desc<8x64xi4, #xegpu.block_tdesc_attr<memory_space = global>> + gpu.return + } + + // CHECK-LABEL: gpu.func @load_matrix_b_request_pack + gpu.func @load_matrix_b_request_pack(%src: memref<64x128xi4, 1>, %dst: memref<64x128xi4, 1>) kernel { + // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32 + // CHECK: %[[C32_I32:.*]] = arith.constant 32 : i32 + // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32 + %srcce = memref.memory_space_cast %src : memref<64x128xi4, 1> to memref<64x128xi4> + %dstte = memref.memory_space_cast %dst : memref<64x128xi4, 1> to memref<64x128xi4> + + %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<64x128xi4> -> !xegpu.tensor_desc<32x32xi4> + + // CHECK: xevm.blockload2d %{{.*}}, %[[C64_I32]], %[[C64_I32]], %[[C64_I32]], %[[C16_I32]], %[[C32_I32]] <{ + // CHECK-SAME: cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 8 : i32, + // CHECK-SAME: pack_register = true, tile_height = 32 : i32, tile_width = 16 : i32, transpose = false, + // CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32> + %loaded = xegpu.load_nd %src_tdesc[32, 32] <{packed, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> + : !xegpu.tensor_desc<32x32xi4> -> vector<64xi4> + + %c32 = arith.constant 32 : index + %c0 = arith.constant 0 : index + vector.store %loaded, %dstte[%c32, %c0] : memref<64x128xi4>, vector<64xi4> + gpu.return + } +} diff --git a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir index e4b3030..43df721 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir @@ -3,27 +3,16 @@ gpu.module @prefetch_nd_check { // CHECK-LABEL: gpu.func @prefetch_nd gpu.func @prefetch_nd(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel { - // CHECK: %[[PREF_BASE_ROW_IN_BYTES:.*]] = arith.constant 64 : i32 - // CHECK: %[[LD_CREATE_DESC_I64:.*]] = arith.constant dense<0> : vector<4xi64> - // CHECK: %[[PREF_BASE_H:.*]] = arith.constant 8 : i32 - // CHECK: %[[PREF_BASE_W:.*]] = arith.constant 16 : i32 + // CHECK: %[[BASE_WIDTH_PITCH_BYTES:.*]] = arith.constant 64 : i32 // CHECK: %[[OFFSET_ZERO:.*]] = arith.constant 0 : i32 + // CHECK: %[[BASE_H:.*]] = arith.constant 8 : i32 %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32> - // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64 - // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64> - // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32> - // CHECK: %[[LD_DESC_2:.*]] = vector.insert %[[PREF_BASE_W]], %[[LD_DESC_1]] [2] : i32 into vector<8xi32> - // CHECK: %[[LD_DESC_3:.*]] = vector.insert %[[PREF_BASE_H]], %[[LD_DESC_2]] [3] : i32 into vector<8xi32> - // CHECK: %[[LD_DESC_4:.*]] = vector.insert %[[OFFSET_ZERO]], %[[LD_DESC_3]] [4] : i32 into vector<8xi32> - // CHECK: %[[LD_DESC:.*]] = vector.insert %[[OFFSET_ZERO]], %[[LD_DESC_4]] [5] : i32 into vector<8xi32> %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> - //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64> - //CHECK: %[[PREF_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64> - //CHECK: %[[PREF_LLVMPTR:.*]] = llvm.inttoptr %[[PREF_INTPTR]] : i64 to !llvm.ptr<1> - //CHECK: xevm.blockprefetch2d %[[PREF_LLVMPTR]], %[[PREF_BASE_ROW_IN_BYTES]], %[[PREF_BASE_H]], - //CHECK-SAME: %[[PREF_BASE_ROW_IN_BYTES]], %[[OFFSET_ZERO]], %[[OFFSET_ZERO]] + //CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %{{.*}} : i64 to !llvm.ptr<1> + //CHECK: xevm.blockprefetch2d %[[LLVMPTR]], %[[BASE_WIDTH_PITCH_BYTES]], %[[BASE_H]], + //CHECK-SAME: %[[BASE_WIDTH_PITCH_BYTES]], %[[OFFSET_ZERO]], %[[OFFSET_ZERO]] //CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32, //CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32, v_blocks = 1 : i32}> //CHECK-SAME: : (!llvm.ptr<1>, i32, i32, i32, i32, i32) diff --git a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd_sub_byte.mlir b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd_sub_byte.mlir new file mode 100644 index 0000000..f925472 --- /dev/null +++ b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd_sub_byte.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s + +gpu.module @prefetch_check { + // CHECK-LABEL: gpu.func @prefetch_matrix_a + gpu.func @prefetch_matrix_a(%src: memref<16x128xi4, 1>) kernel { + // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32 + // CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32 + // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32 + %srcce = memref.memory_space_cast %src : memref<16x128xi4, 1> to memref<16x128xi4> + + %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x128xi4> -> !xegpu.tensor_desc<8x64xi4> + + // CHECK: xevm.blockprefetch2d %{{.*}}, %[[C64_I32]], %[[C16_I32]], %[[C64_I32]], %[[C16_I32]], %[[C8_I32]] + // CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 16 : i32, + // CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32, v_blocks = 1 : i32}> : (!llvm.ptr<1> + xegpu.prefetch_nd %src_tdesc[8, 64] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> + : !xegpu.tensor_desc<8x64xi4> + + gpu.return + } +} diff --git a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir index 72e70ff..7f01526 100644 --- a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir +++ b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir @@ -175,7 +175,7 @@ llvm.func @blockstore2d_cache_control(%c: !llvm.ptr<1>, %base_width_c: i32, %bas // ----- // CHECK-LABEL: llvm.func spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i( // CHECK-SAME: !llvm.ptr<1> {llvm.nonnull}, i32, i32, i32, vector<2xi32>) attributes -// CHECK-SAME: {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind} +// CHECK-SAME: {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, no_unwind} // CHECK: llvm.func @blockprefetch2d(%[[ARG0:.*]]: !llvm.ptr<1>, // CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) { llvm.func @blockprefetch2d(%ptr: !llvm.ptr<1>, %base_width: i32, %base_height: i32, %base_pitch: i32, %x: i32, %y: i32) { @@ -187,7 +187,7 @@ llvm.func @blockprefetch2d(%ptr: !llvm.ptr<1>, %base_width: i32, %base_height: i // CHECK: llvm.call spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i( // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]]) // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>)>, linkage = #llvm.linkage<external>, - // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind, + // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, no_unwind, // CHECK-SAME: sym_name = "_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i", visibility_ = 0 : i64 xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y <{elem_size_in_bits=8 : i32, tile_width=32 : i32, tile_height=8 : i32, v_blocks=1 : i32, @@ -200,13 +200,13 @@ llvm.func @blockprefetch2d(%ptr: !llvm.ptr<1>, %base_width: i32, %base_height: i // CHECK-LABEL: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f( // CHECK-SAME: vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes // CHECK-SAME: {convergent, memory_effects = #llvm.memory_effects<other = none, argMem = none, -// CHECK-SAME: inaccessibleMem = none>, no_unwind, will_return} +// CHECK-SAME: inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, no_unwind, will_return} // CHECK: llvm.func @mma(%[[ARG0:.*]]: vector<8xf32>, %[[ARG1:.*]]: vector<8xi16>, %[[ARG2:.*]]: vector<8xi32>) -> vector<8xf32> { llvm.func @mma(%loaded_c_casted: vector<8xf32>, %loaded_a: vector<8xi16>, %loaded_b_casted: vector<8xi32>) -> vector<8xf32> { // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f( // CHECK-SAME: %[[ARG1]], %[[ARG2]], %[[ARG0]]) {convergent, function_type = // CHECK-SAME: !llvm.func<vector<8xf32> (vector<8xi16>, vector<8xi32>, vector<8xf32>)>, linkage = #llvm.linkage<external>, - // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, + // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, no_unwind, // CHECK-SAME: sym_name = "_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f", visibility_ = 0 : i64, will_return} // CHECK-SAME: : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> %c_result = xevm.mma %loaded_a, %loaded_b_casted, %loaded_c_casted @@ -230,13 +230,13 @@ llvm.func @memfence() { // ----- // CHECK-LABEL: llvm.func spir_funccc @_Z8prefetchPU3AS1Kcm(!llvm.ptr<1>, i64) attributes -// CHECK-SAME: {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind} +// CHECK-SAME: {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, no_unwind} // CHECK: llvm.func @prefetch(%[[ARG0:.*]]: !llvm.ptr<1>) { llvm.func @prefetch(%ptr: !llvm.ptr<1>) { // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i64) : i64 // CHECK: llvm.call spir_funccc @_Z8prefetchPU3AS1Kcm(%[[ARG0]], %[[VAR0]]) // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i64)>, linkage = #llvm.linkage<external>, - // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, + // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, // CHECK-SAME: no_unwind, sym_name = "_Z8prefetchPU3AS1Kcm", visibility_ = 0 : i64 xevm.prefetch %ptr <{cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>) llvm.return @@ -352,7 +352,7 @@ llvm.func @local_id.x() -> i32 { // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z12get_local_idj(%[[VAR0]]) // CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>, - // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, + // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, // CHECK-SAME: no_unwind, sym_name = "_Z12get_local_idj", visibility_ = 0 : i64, will_return} : (i32) -> i32 %1 = xevm.local_id.x : i32 llvm.return %1 : i32 @@ -380,7 +380,7 @@ llvm.func @local_size.x() -> i32 { // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z14get_local_sizej(%[[VAR0]]) // CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>, - // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, + // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, // CHECK-SAME: no_unwind, sym_name = "_Z14get_local_sizej", visibility_ = 0 : i64, will_return} : (i32) -> i32 %1 = xevm.local_size.x : i32 llvm.return %1 : i32 @@ -408,7 +408,7 @@ llvm.func @group_id.x() -> i32 { // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z12get_group_idj(%[[VAR0]]) // CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>, - // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, + // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, // CHECK-SAME: no_unwind, sym_name = "_Z12get_group_idj", visibility_ = 0 : i64, will_return} : (i32) -> i32 %1 = xevm.group_id.x : i32 llvm.return %1 : i32 @@ -436,7 +436,7 @@ llvm.func @group_count.x() -> i32 { // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z14get_num_groupsj(%[[VAR0]]) // CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>, - // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, + // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, // CHECK-SAME: no_unwind, sym_name = "_Z14get_num_groupsj", visibility_ = 0 : i64, will_return} : (i32) -> i32 %1 = xevm.group_count.x : i32 llvm.return %1 : i32 @@ -463,7 +463,7 @@ llvm.func @group_count.z() -> i32 { llvm.func @lane_id() -> i32 { // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() // CHECK-SAME: {function_type = !llvm.func<i32 ()>, linkage = #llvm.linkage<external>, - // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, + // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, // CHECK-SAME: no_unwind, sym_name = "_Z22get_sub_group_local_id", visibility_ = 0 : i64, will_return} : () -> i32 %1 = xevm.lane_id : i32 llvm.return %1 : i32 @@ -474,7 +474,7 @@ llvm.func @lane_id() -> i32 { llvm.func @subgroup_size() -> i32 { // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z18get_sub_group_size() // CHECK-SAME: {function_type = !llvm.func<i32 ()>, linkage = #llvm.linkage<external>, - // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, + // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, // CHECK-SAME: no_unwind, sym_name = "_Z18get_sub_group_size", visibility_ = 0 : i64, will_return} : () -> i32 %1 = xevm.subgroup_size : i32 llvm.return %1 : i32 @@ -485,7 +485,7 @@ llvm.func @subgroup_size() -> i32 { llvm.func @subgroup_id() -> i32 { // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() // CHECK-SAME: {function_type = !llvm.func<i32 ()>, linkage = #llvm.linkage<external>, - // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, + // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>, // CHECK-SAME: no_unwind, sym_name = "_Z16get_sub_group_id", visibility_ = 0 : i64, will_return} : () -> i32 %1 = xevm.subgroup_id : i32 llvm.return %1 : i32 diff --git a/mlir/test/Dialect/AMDGPU/amdgpu-make-dma-descriptor-fold.mlir b/mlir/test/Dialect/AMDGPU/amdgpu-make-dma-descriptor-fold.mlir new file mode 100644 index 0000000..9d43c99 --- /dev/null +++ b/mlir/test/Dialect/AMDGPU/amdgpu-make-dma-descriptor-fold.mlir @@ -0,0 +1,19 @@ +// RUN: mlir-opt --canonicalize %s | FileCheck %s + +// CHECK-LABEL: @make_dma_descriptor_fold +// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>, %[[IDX:.+]]: index) +func.func @make_dma_descriptor_fold(%base: !amdgpu.tdm_base<i32>, %idx: index) -> !amdgpu.tdm_descriptor { + %c64 = arith.constant 64 : index + + // CHECK: amdgpu.make_dma_descriptor %[[BASE]] + %0 = amdgpu.make_dma_descriptor %base + // CHECK-SAME: globalSize [64, 64] + globalSize [%c64, %c64] + // CHECK-SAME: globalStride [64, 1] + globalStride [%c64, 1] + // CHECK-SAME: sharedSize [64, 64] + sharedSize [%c64, %c64] + iterate %idx, %idx, %idx + : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + func.return %0 : !amdgpu.tdm_descriptor +} diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir index fee0c00..cff1d3f 100644 --- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir +++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir @@ -244,3 +244,39 @@ func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4 %res_7 = amdgpu.scaled_mfma 16x16x128 (%sA_0_7[0] * %opA) * (%sB_6_19[0] * %opB) + %cst_0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32> return %res_4, %res_5, %res_6, %res_7 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32> } + +// ----- + +// CHECK-LABEL fuse_memory_counter_wait +func.func @fuse_memory_counter_wait() { + // CHECK: amdgpu.memory_counter_wait + // CHECK-SAME: load(1) store(2) ds(2) exp(1) tensor(0) + // CHECK-NEXT: return + amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4) tensor(5) + amdgpu.memory_counter_wait load(4) store(3) ds(2) exp(1) tensor(0) + return +} + +// CHECK-LABEL fuse_memory_counter_wait_different_counters +func.func @fuse_memory_counter_wait_different_counters() { + // CHECK: amdgpu.memory_counter_wait + // CHECK-SAME: load(1) store(2) ds(3) exp(4) + // CHECK-NEXT: return + amdgpu.memory_counter_wait load(1) store(2) + amdgpu.memory_counter_wait ds(3) exp(4) + return +} + +func.func private @use() + +// CHECK-LABEL fuse_memory_counter_wait_not_adjacent +func.func @fuse_memory_counter_wait_not_adjacent() { + // CHECK: amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4) + // CHECK-NEXT: call @use() + // CHECK-NEXT: amdgpu.memory_counter_wait load(4) store(3) ds(2) exp(1) + // CHECK-NEXT: return + amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4) + func.call @use() : () -> () + amdgpu.memory_counter_wait load(4) store(3) ds(2) exp(1) + return +} diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir index 4c6f62a..6308ea9 100644 --- a/mlir/test/Dialect/AMDGPU/invalid.mlir +++ b/mlir/test/Dialect/AMDGPU/invalid.mlir @@ -333,48 +333,86 @@ func.func @gather_to_lds_non_lds(%idx1 : index, %mem1 : memref<32xf16>, %mem2 : // ----- -func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_16(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) { - // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have firstScaleByte be 0 or 1.}} - %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(2) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16> - func.return +func.func @scaled_mfma_invalid_m(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32xf4E2M1FN>, %arg2 : vector<16xf32>) -> vector<16xf32> { + // expected-error@+1 {{'amdgpu.scaled_mfma' op attribute 'm' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}} + %0 = amdgpu.scaled_mfma 8x32x64 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<16xf32> + func.return %0 : vector<16xf32> } // ----- -func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_32(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) { - // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 32 can only have firstScaleByte be 0 or 2.}} - %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(1) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16> - func.return +func.func @scaled_mfma_invalid_n(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32xf4E2M1FN>, %arg2 : vector<16xf32>) -> vector<16xf32> { + // expected-error@+1 {{'amdgpu.scaled_mfma' op attribute 'n' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}} + %0 = amdgpu.scaled_mfma 32x8x64 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<16xf32> + func.return %0 : vector<16xf32> +} + +// ----- + +func.func @scaled_mfma_invalid_k(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32xf4E2M1FN>, %arg2 : vector<16xf32>) -> vector<16xf32> { + // expected-error@+1 {{'amdgpu.scaled_mfma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {64, 128}}} + %0 = amdgpu.scaled_mfma 32x32x32 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<16xf32> + func.return %0 : vector<16xf32> +} + +// ----- + +func.func @make_dma_base_invalid_addressspace(%idx: index, %mem: memref<8xi32>) { + // expected-error@+1 {{'amdgpu.make_dma_base' op lds memref must have workgroup address space attribute.}} + amdgpu.make_dma_base %mem[%idx], %mem[%idx] : memref<8xi32>, memref<8xi32> -> !amdgpu.tdm_base<i32> +} + +// ----- + +func.func @make_dma_base_invalid_addressspace(%idx: index, %smem : memref<8xi32, #gpu.address_space<workgroup>>) { + // expected-error@+1 {{'amdgpu.make_dma_base' op global memref must have global address space attribute.}} + amdgpu.make_dma_base %smem[%idx], %smem[%idx] : memref<8xi32, #gpu.address_space<workgroup>>, memref<8xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32> +} + +// ----- + +func.func @make_dma_base_invalid_barrier(%base: !amdgpu.tdm_base<i32>, %barrier: memref<8xi32>, %idx: index) { + // expected-error@+1 {{'amdgpu.make_dma_descriptor' op atomic barrier address must be in LDS.}} + amdgpu.make_dma_descriptor %base globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] atomicBarrier(%barrier[%idx] : memref<8xi32>) : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor } // ----- -func.func @amdgpu.scaled_ext_packed816_invalid_input_output_sizes(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) { - // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op failed to verify that all of {source, res} have same shape}} - %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<16xf16> +// CHECK-LABEL: func @make_dma_descriptor_invalid_empty_strides +// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>) +func.func @make_dma_descriptor_invalid_empty_strides(%base: !amdgpu.tdm_base<i32>) { + // expected-error@+1 {{'amdgpu.make_dma_descriptor' op strides must not be empty.}} + amdgpu.make_dma_descriptor %base globalSize [0, 1] globalStride [] sharedSize [1, 0] : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor func.return } // ----- -func.func @scaled_mfma_invalid_m(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32xf4E2M1FN>, %arg2 : vector<16xf32>) -> vector<16xf32> { - // expected-error@+1 {{'amdgpu.scaled_mfma' op attribute 'm' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}} - %0 = amdgpu.scaled_mfma 8x32x64 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<16xf32> - func.return %0 : vector<16xf32> +// CHECK-LABEL: func @make_dma_descriptor_invalid_innermost_stride +// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>) +func.func @make_dma_descriptor_invalid_innermost_stride(%base: !amdgpu.tdm_base<i32>) { + // expected-error@+1 {{'amdgpu.make_dma_descriptor' op strides for the innermost dimension must be 1.}} + amdgpu.make_dma_descriptor %base globalSize [2, 2] globalStride [1, 2] sharedSize [1, 0] : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + func.return } // ----- -func.func @scaled_mfma_invalid_n(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32xf4E2M1FN>, %arg2 : vector<16xf32>) -> vector<16xf32> { - // expected-error@+1 {{'amdgpu.scaled_mfma' op attribute 'n' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}} - %0 = amdgpu.scaled_mfma 32x8x64 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<16xf32> - func.return %0 : vector<16xf32> +// CHECK-LABEL: func @make_dma_descriptor_invalid_size_and_stride_sizes +// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>) +func.func @make_dma_descriptor_invalid_size_and_stride_sizes(%base: !amdgpu.tdm_base<i32>) { + // expected-error@+1 {{'amdgpu.make_dma_descriptor' op strides and sizes must have same rank.}} + amdgpu.make_dma_descriptor %base globalSize [1, 1, 1] globalStride [1, 1] sharedSize [1, 0] : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + func.return } // ----- -func.func @scaled_mfma_invalid_k(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32xf4E2M1FN>, %arg2 : vector<16xf32>) -> vector<16xf32> { - // expected-error@+1 {{'amdgpu.scaled_mfma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {64, 128}}} - %0 = amdgpu.scaled_mfma 32x32x32 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<16xf32> - func.return %0 : vector<16xf32> +// CHECK-LABEL: func @make_dma_descriptor_invalid_shared_and_global_rank +// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>) +func.func @make_dma_descriptor_invalid_shared_and_global_rank(%base: !amdgpu.tdm_base<i32>) { + // expected-error@+1 {{'amdgpu.make_dma_descriptor' op tensor must have same rank as tile.}} + amdgpu.make_dma_descriptor %base globalSize [4, 4] globalStride [1, 1] sharedSize [1, 2, 3] : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + func.return } + diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir index 09134cb..651aff4 100644 --- a/mlir/test/Dialect/AMDGPU/ops.mlir +++ b/mlir/test/Dialect/AMDGPU/ops.mlir @@ -221,58 +221,58 @@ func.func @scaled_ext_scalar_f4e2m1_bf16(%v: vector<2xf4E2M1FN>, %scale: f32) -> func.return %ret : vector<2xbf16> } -// CHECK-LABEL: func.func @scaled_ext_packed816_fp4 -func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) { - // CHECK: amdgpu.scaled_ext_packed816 - %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16> - // CHECK: amdgpu.scaled_ext_packed816 - %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xbf16> - // CHECK: amdgpu.scaled_ext_packed816 - %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf32> +// CHECK-LABEL: func.func @scaled_ext_packed_matrix_fp4 +func.func @scaled_ext_packed_matrix_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) { + // CHECK: amdgpu.scaled_ext_packed_matrix + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16> + // CHECK: amdgpu.scaled_ext_packed_matrix + %ret1 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xbf16> + // CHECK: amdgpu.scaled_ext_packed_matrix + %ret2 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf32> func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32> } -// CHECK-LABEL: func.func @scaled_ext_packed816_fp8 -func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) { - // CHECK: amdgpu.scaled_ext_packed816 - %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf16> - // CHECK: amdgpu.scaled_ext_packed816 - %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xbf16> - // CHECK: amdgpu.scaled_ext_packed816 - %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf32> +// CHECK-LABEL: func.func @scaled_ext_packed_matrix_fp8 +func.func @scaled_ext_packed_matrix_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) { + // CHECK: amdgpu.scaled_ext_packed_matrix + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf16> + // CHECK: amdgpu.scaled_ext_packed_matrix + %ret1 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xbf16> + // CHECK: amdgpu.scaled_ext_packed_matrix + %ret2 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf32> func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32> } -// CHECK-LABEL: func.func @scaled_ext_packed816_bf8 -func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) { - // CHECK: amdgpu.scaled_ext_packed816 - %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16> - // CHECK: amdgpu.scaled_ext_packed816 - %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16> - // CHECK: amdgpu.scaled_ext_packed816 - %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf32> +// CHECK-LABEL: func.func @scaled_ext_packed_matrix_bf8 +func.func @scaled_ext_packed_matrix_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) { + // CHECK: amdgpu.scaled_ext_packed_matrix + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16> + // CHECK: amdgpu.scaled_ext_packed_matrix + %ret1 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16> + // CHECK: amdgpu.scaled_ext_packed_matrix + %ret2 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf32> func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32> } -// CHECK-LABEL: func.func @scaled_ext_packed816_fp6 -func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) { - // CHECK: amdgpu.scaled_ext_packed816 - %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16> - // CHECK: amdgpu.scaled_ext_packed816 - %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xbf16> - // CHECK: amdgpu.scaled_ext_packed816 - %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf32> +// CHECK-LABEL: func.func @scaled_ext_packed_matrix_fp6 +func.func @scaled_ext_packed_matrix_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) { + // CHECK: amdgpu.scaled_ext_packed_matrix + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16> + // CHECK: amdgpu.scaled_ext_packed_matrix + %ret1 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xbf16> + // CHECK: amdgpu.scaled_ext_packed_matrix + %ret2 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf32> func.return %ret0, %ret1, %ret2 : vector<16xf16>, vector<16xbf16>, vector<16xf32> } -// CHECK-LABEL: func.func @scaled_ext_packed816_bf16 -func.func @scaled_ext_packed816_bf16(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) { - // CHECK: amdgpu.scaled_ext_packed816 - %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf16> - // CHECK: amdgpu.scaled_ext_packed816 - %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xbf16> - // CHECK: amdgpu.scaled_ext_packed816 - %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf32> +// CHECK-LABEL: func.func @scaled_ext_packed_matrix_bf6 +func.func @scaled_ext_packed_matrix_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) { + // CHECK: amdgpu.scaled_ext_packed_matrix + %ret0 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf16> + // CHECK: amdgpu.scaled_ext_packed_matrix + %ret1 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xbf16> + // CHECK: amdgpu.scaled_ext_packed_matrix + %ret2 = amdgpu.scaled_ext_packed_matrix %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf32> func.return %ret0, %ret1, %ret2 : vector<16xf16>, vector<16xbf16>, vector<16xf32> } @@ -671,17 +671,105 @@ func.func @gather_to_lds(%idx1 : index, %idx2 : index, %mem1 : memref<32xf16>, % // CHECK-LABEL: func @memory_counter_wait func.func @memory_counter_wait() { - // CHECK: amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4) - // CHECK: amdgpu.memory_counter_wait load(4) store(2) ds(3) exp(1) + // CHECK: amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4) tensor(5) + // CHECK: amdgpu.memory_counter_wait load(4) store(2) ds(3) exp(1) tensor(0) // CHECK: amdgpu.memory_counter_wait load(1) // CHECK: amdgpu.memory_counter_wait store(2) // CHECK: amdgpu.memory_counter_wait ds(3) // CHECK: amdgpu.memory_counter_wait exp(4) - amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4) - amdgpu.memory_counter_wait exp(1) store(2) ds(3) load(4) + // CHECK: amdgpu.memory_counter_wait tensor(5) + amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4) tensor(5) + amdgpu.memory_counter_wait tensor(0) exp(1) store(2) ds(3) load(4) amdgpu.memory_counter_wait load(1) amdgpu.memory_counter_wait store(2) amdgpu.memory_counter_wait ds(3) amdgpu.memory_counter_wait exp(4) + amdgpu.memory_counter_wait tensor(5) + func.return +} + +// CHECK-LABEL: func @make_dma_base +// CHECK-SAME: (%[[IDX:.+]]: index, %[[MEM:.+]]: memref<8xi32>, %[[SMEM:.+]]: memref<8xi32, #gpu.address_space<workgroup>>) +func.func @make_dma_base(%idx: index, %mem: memref<8xi32>, %smem: memref<8xi32, #gpu.address_space<workgroup>>) { + // CHECK: amdgpu.make_dma_base %[[MEM]][%[[IDX]]], %[[SMEM]][%[[IDX]]] : memref<8xi32>, memref<8xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32> + amdgpu.make_dma_base %mem[%idx], %smem[%idx] : memref<8xi32>, memref<8xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32> + func.return +} + +// CHECK-LABEL: func @make_dma_descriptor +// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>, %[[WG_MASK:.+]]: i16, %[[TIMEOUT:.+]]: i1, %[[BARRIER:.+]]: memref<8xi32, #gpu.address_space<workgroup>>, %[[IDX:.+]]: index) +func.func @make_dma_descriptor(%base: !amdgpu.tdm_base<i32>, %wg_mask: i16, %timeout: i1, %barrier: memref<8xi32, #gpu.address_space<workgroup>>, %idx: index) { + + // CHECK: amdgpu.make_dma_descriptor %[[BASE]] + amdgpu.make_dma_descriptor %base + // CHECK-SAME: globalSize [64, 64] + globalSize [64, 64] + // CHECK-SAME: globalStride [64, 1] + globalStride [64, 1] + // CHECK-SAME: sharedSize [64, 64] : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + sharedSize [64, 64] : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + + // CHECK: amdgpu.make_dma_descriptor %[[BASE]] + amdgpu.make_dma_descriptor %base + // CHECK-SAME: globalSize [64, 64] + globalSize [64, 64] + // CHECK-SAME: globalStride [64, 1] + globalStride [64, 1] + // CHECK-SAME: sharedSize [64, 64] + sharedSize [64, 64] + // CHECK-SAME: padShared(%[[IDX]] every %[[IDX]]) + padShared(%idx every %idx) + : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + + // CHECK: amdgpu.make_dma_descriptor %[[BASE]] + amdgpu.make_dma_descriptor %base + // CHECK-SAME: globalSize [64, 64] + globalSize [64, 64] + // CHECK-SAME: globalStride [64, 1] + globalStride [64, 1] + // CHECK-SAME: sharedSize [64, 64] + sharedSize [64, 64] + // CHECK-SAME: workgroupMask %[[WG_MASK]] + workgroupMask %wg_mask + : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + + // CHECK: amdgpu.make_dma_descriptor %[[BASE]] + amdgpu.make_dma_descriptor %base + // CHECK-SAME: globalSize [64, 64] + globalSize [64, 64] + // CHECK-SAME: globalStride [64, 1] + globalStride [64, 1] + // CHECK-SAME: sharedSize [64, 64] + sharedSize [64, 64] + // CHECK-SAME: workgroupMask %[[WG_MASK]] + workgroupMask %wg_mask + // CHECK-SAME: earlyTimeout %[[TIMEOUT]] + earlyTimeout %timeout + : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + + // CHECK: amdgpu.make_dma_descriptor %[[BASE]] + amdgpu.make_dma_descriptor %base + // CHECK-SAME: globalSize [64, 64] + globalSize [64, 64] + // CHECK-SAME: globalStride [64, 1] + globalStride [64, 1] + // CHECK-SAME: sharedSize [64, 64] + sharedSize [64, 64] + // CHECK-SAME: atomicBarrier(%[[BARRIER]][%[[IDX]]] : memref<8xi32, #gpu.address_space<workgroup>>) + atomicBarrier(%barrier[%idx] : memref<8xi32, #gpu.address_space<workgroup>>) + : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + + // CHECK: amdgpu.make_dma_descriptor %[[BASE]] + amdgpu.make_dma_descriptor %base + // CHECK-SAME: globalSize [64, 64] + globalSize [64, 64] + // CHECK-SAME: globalStride [64, 1] + globalStride [64, 1] + // CHECK-SAME: sharedSize [64, 64] + sharedSize [64, 64] + // CHECK-SAME: iterate %[[IDX]], %[[IDX]], %[[IDX]] + iterate %idx, %idx, %idx + : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + func.return } diff --git a/mlir/test/Dialect/Affine/loop-coalescing.mlir b/mlir/test/Dialect/Affine/loop-coalescing.mlir index 3be14ea..6a82532 100644 --- a/mlir/test/Dialect/Affine/loop-coalescing.mlir +++ b/mlir/test/Dialect/Affine/loop-coalescing.mlir @@ -416,3 +416,31 @@ func.func @test_loops_do_not_get_coalesced() { // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return + +// ----- + +// CHECK-LABEL: func @inner_loop_has_iter_args +// CHECK-SAME: %[[ALLOC:.*]]: memref<?xi64>) +func.func @inner_loop_has_iter_args(%alloc : memref<?xi64>) { + %c17 = arith.constant 17 : index + affine.for %arg0 = 0 to 79 { + %0 = affine.for %arg1 = 0 to 64 iter_args(%arg2 = %alloc) -> (memref<?xi64>) { + %1 = arith.remui %arg1, %c17 : index + %2 = arith.index_cast %arg1 : index to i64 + memref.store %2, %arg2[%1] : memref<?xi64> + affine.yield %arg2 : memref<?xi64> + } + } + return +} + +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 17 : index +// CHECK: %[[APPLY_0:.*]] = affine.apply affine_map<() -> (79)>() +// CHECK: %[[APPLY_1:.*]] = affine.apply affine_map<() -> (64)>() +// CHECK: %[[APPLY_2:.*]] = affine.apply affine_map<(d0)[s0] -> (d0 * s0)>(%[[APPLY_0]]){{\[}}%[[APPLY_1]]] +// CHECK: affine.for %[[IV:.*]] = 0 to %[[APPLY_2]] { +// CHECK: %[[APPLY_3:.*]] = affine.apply affine_map<(d0)[s0] -> (d0 mod s0)>(%[[IV]]){{\[}}%[[APPLY_1]]] +// CHECK: %[[REMUI_0:.*]] = arith.remui %[[APPLY_3]], %[[CONSTANT_0]] : index +// CHECK: %[[INDEX_CAST_0:.*]] = arith.index_cast %[[APPLY_3]] : index to i64 +// CHECK: memref.store %[[INDEX_CAST_0]], %[[ALLOC]]{{\[}}%[[REMUI_0]]] : memref<?xi64> +// CHECK: } diff --git a/mlir/test/Dialect/Affine/value-bounds-reification.mlir b/mlir/test/Dialect/Affine/value-bounds-reification.mlir index 817614b..2e80102 100644 --- a/mlir/test/Dialect/Affine/value-bounds-reification.mlir +++ b/mlir/test/Dialect/Affine/value-bounds-reification.mlir @@ -36,13 +36,13 @@ func.func @reify_through_chain(%sz0: index, %sz2: index) -> (index, index, index // CHECK: "test.some_use"(%[[c5]]) // CHECK: %[[c5:.*]] = arith.constant 5 : index // CHECK: "test.some_use"(%[[c5]]) -func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: f32) { +func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: i32) { %c0 = arith.constant 0 : index %c4 = arith.constant 4 : index scf.for %iv = %c0 to %ub step %c4 { %sz = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 4)>(%iv)[%ub] %slice = tensor.extract_slice %t[%idx, %iv] [1, %sz] [1, 1] : tensor<?x?xi32> to tensor<1x?xi32> - %filled = linalg.fill ins(%f : f32) outs(%slice : tensor<1x?xi32>) -> tensor<1x?xi32> + %filled = linalg.fill ins(%f : i32) outs(%slice : tensor<1x?xi32>) -> tensor<1x?xi32> %bound = "test.reify_bound"(%filled) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index) "test.some_use"(%bound) : (index) -> () diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index 2fe0995..3ad1530 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -2958,6 +2958,19 @@ func.func @truncIShrSIToTrunciShrUI(%a: i64) -> i32 { return %hi : i32 } +// CHECK-LABEL: @truncIShrSIExactToTrunciShrUIExact +// CHECK-SAME: (%[[A:.+]]: i64) +// CHECK-NEXT: %[[C32:.+]] = arith.constant 32 : i64 +// CHECK-NEXT: %[[SHR:.+]] = arith.shrui %[[A]], %[[C32]] exact : i64 +// CHECK-NEXT: %[[TRU:.+]] = arith.trunci %[[SHR]] : i64 to i32 +// CHECK-NEXT: return %[[TRU]] : i32 +func.func @truncIShrSIExactToTrunciShrUIExact(%a: i64) -> i32 { + %c32 = arith.constant 32: i64 + %sh = arith.shrsi %a, %c32 exact : i64 + %hi = arith.trunci %sh: i64 to i32 + return %hi : i32 +} + // CHECK-LABEL: @truncIShrSIToTrunciShrUIBadShiftAmt1 // CHECK: arith.shrsi func.func @truncIShrSIToTrunciShrUIBadShiftAmt1(%a: i64) -> i32 { diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir index 1e656e8..58eadfd 100644 --- a/mlir/test/Dialect/Arith/ops.mlir +++ b/mlir/test/Dialect/Arith/ops.mlir @@ -151,6 +151,12 @@ func.func @test_divui(%arg0 : i64, %arg1 : i64) -> i64 { return %0 : i64 } +// CHECK-LABEL: test_divui_exact +func.func @test_divui_exact(%arg0 : i64, %arg1 : i64) -> i64 { + %0 = arith.divui %arg0, %arg1 exact : i64 + return %0 : i64 +} + // CHECK-LABEL: test_divui_tensor func.func @test_divui_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> { %0 = arith.divui %arg0, %arg1 : tensor<8x8xi64> @@ -175,6 +181,12 @@ func.func @test_divsi(%arg0 : i64, %arg1 : i64) -> i64 { return %0 : i64 } +// CHECK-LABEL: test_divsi_exact +func.func @test_divsi_exact(%arg0 : i64, %arg1 : i64) -> i64 { + %0 = arith.divsi %arg0, %arg1 exact : i64 + return %0 : i64 +} + // CHECK-LABEL: test_divsi_tensor func.func @test_divsi_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> { %0 = arith.divsi %arg0, %arg1 : tensor<8x8xi64> @@ -391,6 +403,12 @@ func.func @test_shrui(%arg0 : i64, %arg1 : i64) -> i64 { return %0 : i64 } +// CHECK-LABEL: test_shrui_exact +func.func @test_shrui_exact(%arg0 : i64, %arg1 : i64) -> i64 { + %0 = arith.shrui %arg0, %arg1 exact : i64 + return %0 : i64 +} + // CHECK-LABEL: test_shrui_tensor func.func @test_shrui_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> { %0 = arith.shrui %arg0, %arg1 : tensor<8x8xi64> @@ -415,6 +433,12 @@ func.func @test_shrsi(%arg0 : i64, %arg1 : i64) -> i64 { return %0 : i64 } +// CHECK-LABEL: test_shrsi_exact +func.func @test_shrsi_exact(%arg0 : i64, %arg1 : i64) -> i64 { + %0 = arith.shrsi %arg0, %arg1 exact : i64 + return %0 : i64 +} + // CHECK-LABEL: test_shrsi_tensor func.func @test_shrsi_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> { %0 = arith.shrsi %arg0, %arg1 : tensor<8x8xi64> diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir index 8249d59..3929f5b 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir @@ -368,21 +368,18 @@ func.func @multiple_materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32 // ----- -// `EmptyTensorElimination` fails to find a valid insertion -// point for the new injected `SubsetExtraction`. -// CHECK-LABEL: func.func @fail_to_eliminate_any_empty_tensors -func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> { +// CHECK-LABEL: func.func @eliminate_all_empty_tensors +func.func @eliminate_all_empty_tensors() -> tensor<5x6x128xf32> { %cst_1 = arith.constant 1.0 : f32 %cst_2 = arith.constant 2.0 : f32 - // CHECK: memref.alloc - // CHECK: memref.alloc - // CHECK: memref.alloc + // CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32> + // CHECK-NOT: memref.alloc %empty_1 = tensor.empty() : tensor<5x6x64xf32> %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32> %empty_2 = tensor.empty() : tensor<5x6x64xf32> %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32> %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32> - // CHECK: memref.copy + // CHECK-NOT: memref.copy %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1] : tensor<5x6x64xf32> into tensor<5x6x128xf32> %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1] @@ -392,20 +389,19 @@ func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> { // ----- -// CHECK-LABEL: func.func @succeed_to_eliminate_one_empty_tensor -func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> { +// CHECK-LABEL: func.func @eliminate_concatenated_empty_tensors +func.func @eliminate_concatenated_empty_tensors() -> tensor<5x6x128xf32> { %cst_1 = arith.constant 1.0 : f32 %cst_2 = arith.constant 2.0 : f32 // CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32> - // CHECK: memref.alloc // CHECK-NOT: memref.alloc - %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32> + %concatenated_empty = tensor.empty() : tensor<5x6x128xf32> %empty_1 = tensor.empty() : tensor<5x6x64xf32> %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32> %empty_2 = tensor.empty() : tensor<5x6x64xf32> %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32> - // CHECK: memref.copy - %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1] + // CHECK-NOT: memref.copy + %inserted_slice_1 = tensor.insert_slice %res_1 into %concatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1] : tensor<5x6x64xf32> into tensor<5x6x128xf32> %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1] : tensor<5x6x64xf32> into tensor<5x6x128xf32> @@ -420,20 +416,22 @@ func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> { // CHECK-ELIM-LABEL: func.func @multi_use_of_the_same_tensor_empty // CHECK-LABEL: func.func @multi_use_of_the_same_tensor_empty +// CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32> +// CHECK-NOT: memref.alloc +// CHECK-NOT: memref.copy +// CHECK-ELIM: tensor.extract_slice {{.*}}[0, 0, 0] +// CHECK-ELIM: linalg.fill +// CHECK-ELIM: tensor.extract_slice {{.*}}[0, 0, 64] +// CHECK-ELIM: linalg.fill func.func @multi_use_of_the_same_tensor_empty() -> tensor<5x6x128xf32> { %cst_1 = arith.constant 1.0 : f32 %cst_2 = arith.constant 2.0 : f32 %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32> %empty_1 = tensor.empty() : tensor<5x6x64xf32> - // CHECK-ELIM: %[[VAL_3:.*]] = tensor.extract_slice - // CHECK-ELIM: linalg.fill ins(%[[VAL_0:.*]] : f32) outs(%[[VAL_3]] - // CHECK-ELIM-NOT: linalg.fill ins(%[[VAL_1:.*]] : f32) outs(%[[VAL_3]] %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32> %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32> - // CHECK: memref.copy %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1] : tensor<5x6x64xf32> into tensor<5x6x128xf32> - // CHECK-NOT: memref.copy %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1] : tensor<5x6x64xf32> into tensor<5x6x128xf32> return %inserted_slice_2 : tensor<5x6x128xf32> @@ -476,3 +474,66 @@ func.func @direct_use_of_tensor_empty(%arg0: tensor<5x6x128xf32>) -> tensor<5x6x : tensor<5x6x64xf32> into tensor<5x6x128xf32> return %inserted_slice_1 : tensor<5x6x128xf32> } + +// ----- + +// Test that dependent pure operations are moved before the +// insertion point to enable empty tensor elimination. + +// CHECK-LABEL: func.func @move_dependent_arith_op( +// CHECK-SAME: %[[ARG0:.*]]: memref<10xf32> +// CHECK-SAME: %[[ARG1:.*]]: index +// CHECK-NOT: memref.alloc +// CHECK: %[[C5:.*]] = arith.constant 5 : index +// CHECK: %[[OFFSET:.*]] = arith.addi %[[ARG1]], %[[C5]] +// CHECK: %[[SV:.*]] = memref.subview %[[ARG0]][%[[OFFSET]]] [5] [1] +// CHECK: linalg.fill {{.*}} outs(%[[SV]] +// CHECK: return %[[ARG0]] +// CHECK-ELIM-LABEL: func.func @move_dependent_arith_op( +// CHECK-ELIM-SAME: %[[ARG0:.*]]: tensor<10xf32> +// CHECK-ELIM-SAME: %[[ARG1:.*]]: index +// CHECK-ELIM: %[[C5:.*]] = arith.constant 5 : index +// CHECK-ELIM: %[[OFFSET:.*]] = arith.addi %[[ARG1]], %[[C5]] +// CHECK-ELIM: %[[SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[OFFSET]]] [5] [1] +// CHECK-ELIM: %[[FILL:.*]] = linalg.fill {{.*}} outs(%[[SLICE]] +// CHECK-ELIM: tensor.insert_slice %[[FILL]] into %[[ARG0]][%[[OFFSET]]] +func.func @move_dependent_arith_op( + %arg0: tensor<10xf32> {bufferization.buffer_layout = affine_map<(d0) -> (d0)>, bufferization.writable = true}, + %arg1: index, %f: f32) -> tensor<10xf32> +{ + %0 = tensor.empty() : tensor<5xf32> + %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32> + %c5 = arith.constant 5 : index + %offset = arith.addi %arg1, %c5 : index + %2 = tensor.insert_slice %1 into %arg0[%offset][5][1] + : tensor<5xf32> into tensor<10xf32> + return %2 : tensor<10xf32> +} + +// ----- + +// Test that side-effecting operations are not moved, preventing empty +// tensor elimination. + +// CHECK-LABEL: func.func @side_effecting_op_blocks_movement( +// CHECK: memref.alloc +// CHECK: linalg.fill +// CHECK: memref.load +// CHECK: memref.subview +// CHECK: memref.copy +// CHECK-ELIM-LABEL: func.func @side_effecting_op_blocks_movement( +// CHECK-ELIM: tensor.empty +// CHECK-ELIM: linalg.fill +// CHECK-ELIM: memref.load +// CHECK-ELIM: tensor.insert_slice +func.func @side_effecting_op_blocks_movement( + %arg0: tensor<10xf32> {bufferization.buffer_layout = affine_map<(d0) -> (d0)>, bufferization.writable = true}, + %mem: memref<index>, %f: f32) -> tensor<10xf32> +{ + %0 = tensor.empty() : tensor<5xf32> + %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32> + %offset = memref.load %mem[] : memref<index> + %2 = tensor.insert_slice %1 into %arg0[%offset][5][1] + : tensor<5xf32> into tensor<10xf32> + return %2 : tensor<10xf32> +} diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir index 2c8807b..9884b04 100644 --- a/mlir/test/Dialect/Bufferization/invalid.mlir +++ b/mlir/test/Dialect/Bufferization/invalid.mlir @@ -127,3 +127,63 @@ func.func @invalid_manual_deallocation() { // expected-error @below{{op attribute 'bufferization.manual_deallocation' can be used only on ops that have an allocation and/or free side effect}} arith.constant {bufferization.manual_deallocation} 0 : index } + +// ----- + +func.func @invalid_rank_to_buffer(%t: tensor<1x2x3x4xf32>) { + // expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}} + // expected-error @below{{shapes do not match}} + %b = bufferization.to_buffer %t + : tensor<1x2x3x4xf32> to memref<1x2x3xf32> + return +} + +// ----- + +func.func @invalid_rank_to_tensor(%b: memref<1x2x3xf32>) { + // expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}} + // expected-error @below{{shapes do not match}} + %t = bufferization.to_tensor %b + : memref<1x2x3xf32> to tensor<1x2x3x4xf32> + return +} + +// ----- + +func.func @invalid_shape_to_buffer(%t: tensor<1x2x3x4xf32>) { + // expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}} + // expected-error @below{{shapes do not match}} + %b = bufferization.to_buffer %t + : tensor<1x2x3x4xf32> to memref<1x2x4x3xf32> + return +} + +// ----- + +func.func @invalid_shape_to_tensor(%b: memref<1x2x4x3xf32>) { + // expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}} + // expected-error @below{{shapes do not match}} + %t = bufferization.to_tensor %b + : memref<1x2x4x3xf32> to tensor<1x2x3x4xf32> + return +} + +// ----- + +func.func @invalid_type_to_buffer(%t: tensor<1x2x3x4xf32>) { + // expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}} + // expected-error @below{{element types do not match}} + %b = bufferization.to_buffer %t + : tensor<1x2x3x4xf32> to memref<1x2x3x4xf16> + return +} + +// ----- + +func.func @invalid_type_to_tensor(%b: memref<1x2x3x4xf16>) { + // expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}} + // expected-error @below{{element types do not match}} + %t2 = bufferization.to_tensor %b + : memref<1x2x3x4xf16> to tensor<1x2x3x4xf32> + return +} diff --git a/mlir/test/Dialect/Bufferization/ops.mlir b/mlir/test/Dialect/Bufferization/ops.mlir index fc6df4a..b0db1bb 100644 --- a/mlir/test/Dialect/Bufferization/ops.mlir +++ b/mlir/test/Dialect/Bufferization/ops.mlir @@ -83,3 +83,40 @@ func.func @test_dealloc_op(%arg0: memref<2xf32>, %arg1: memref<4xi32>, bufferization.dealloc return %0#0, %0#1 : i1, i1 } + +// CHECK: func.func @test_builtin_custom_builtin_type_conversion +// CHECK-SAME: (%[[t:.*]]: tensor<42xf32>) -> tensor<42xf32> +func.func @test_builtin_custom_builtin_type_conversion(%t: tensor<42xf32>) + -> tensor<42xf32> { + // CHECK: %[[buffer:.*]] = bufferization.to_buffer %[[t]] + // CHECK-SAME: to !test.test_memref<[42], f32> + %buffer = bufferization.to_buffer %t + : tensor<42xf32> to !test.test_memref<[42], f32> + + // CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[buffer]] + // CHECK-SAME: to tensor<42xf32> + %tensor = bufferization.to_tensor %buffer + : !test.test_memref<[42], f32> to tensor<42xf32> + + // CHECK: return %[[tensor]] + return %tensor : tensor<42xf32> +} + +// CHECK: func.func @test_custom_builtin_custom_type_conversion +// CHECK-SAME: (%[[t:.*]]: !test.test_tensor<[42], f32>) +// CHECK-SAME: -> !test.test_tensor<[42], f32> +func.func @test_custom_builtin_custom_type_conversion(%t: !test.test_tensor<[42], f32>) + -> !test.test_tensor<[42], f32> { + // CHECK: %[[buffer:.*]] = bufferization.to_buffer %[[t]] + // CHECK-SAME: to memref<42xf32> + %buffer = bufferization.to_buffer %t + : !test.test_tensor<[42], f32> to memref<42xf32> + + // CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[buffer]] + // CHECK-SAME: to !test.test_tensor<[42], f32> + %tensor = bufferization.to_tensor %buffer + : memref<42xf32> to !test.test_tensor<[42], f32> + + // CHECK: return %[[tensor]] + return %tensor : !test.test_tensor<[42], f32> +} diff --git a/mlir/test/Dialect/ControlFlow/canonicalize.mlir b/mlir/test/Dialect/ControlFlow/canonicalize.mlir index 17f7d28..21a1678 100644 --- a/mlir/test/Dialect/ControlFlow/canonicalize.mlir +++ b/mlir/test/Dialect/ControlFlow/canonicalize.mlir @@ -634,3 +634,25 @@ func.func @unsimplified_cycle_2(%c : i1) { ^bb7: cf.br ^bb6 } + +// CHECK-LABEL: @drop_unreachable_branch_1 +// CHECK-NEXT: "test.foo"() : () -> () +// CHECK-NEXT: return +func.func @drop_unreachable_branch_1(%c: i1) { + cf.cond_br %c, ^bb1, ^bb2 +^bb1: + "test.foo"() : () -> () + return +^bb2: + ub.unreachable +} + +// CHECK-LABEL: @drop_unreachable_branch_2 +// CHECK-NEXT: ub.unreachable +func.func @drop_unreachable_branch_2(%c: i1) { + cf.cond_br %c, ^bb1, ^bb2 +^bb1: + ub.unreachable +^bb2: + ub.unreachable +} diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir index f285196..d1601be 100644 --- a/mlir/test/Dialect/EmitC/invalid_ops.mlir +++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir @@ -914,3 +914,19 @@ func.func @test_for_unmatch_type(%arg0: index) { ) : (index, index, index) -> () return } + +// ----- + +func.func @address_of(%arg0: !emitc.lvalue<i32>) { + // expected-error @+1 {{failed to verify that input and result reference the same type}} + %1 = "emitc.address_of"(%arg0) : (!emitc.lvalue<i32>) -> !emitc.ptr<i8> + return +} + +// ----- + +func.func @dereference(%arg0: !emitc.ptr<i32>) { + // expected-error @+1 {{failed to verify that input and result reference the same type}} + %1 = "emitc.dereference"(%arg0) : (!emitc.ptr<i32>) -> !emitc.lvalue<i8> + return +} diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir index 1259748..b2c8b84 100644 --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -355,3 +355,13 @@ func.func @do(%arg0 : !emitc.ptr<i32>) { return } + +func.func @address_of(%arg0: !emitc.lvalue<i32>) { + %1 = emitc.address_of %arg0 : !emitc.lvalue<i32> + return +} + +func.func @dereference(%arg0: !emitc.ptr<i32>) { + %1 = emitc.dereference %arg0 : !emitc.ptr<i32> + return +} diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir index 35381da..26bcf94 100644 --- a/mlir/test/Dialect/GPU/invalid.mlir +++ b/mlir/test/Dialect/GPU/invalid.mlir @@ -688,7 +688,7 @@ func.func @mmamatrix_operand_type(){ func.func @mmamatrix_invalid_element_type(){ %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3> %i = arith.constant 16 : index - // expected-error @+1 {{MMAMatrixType elements must be SI8, UI8, I32, F16, or F32}} + // expected-error @+1 {{MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64}} %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xbf16, "AOp"> return } @@ -708,7 +708,7 @@ func.func @mmaLoadOp_identity_layout(){ // ----- func.func @mma_invalid_memref_type(%src: memref<32x4xvector<4x8xf32>>, %i: index) { - // expected-error @+1 {{operand #0 must be memref of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float or vector of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float values of ranks 1 values}} + // expected-error @+1 {{operand #0 must be memref of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float or 64-bit float or vector of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float or 64-bit float values of ranks 1 values}} %0 = gpu.subgroup_mma_load_matrix %src[%i, %i] {leadDimension = 4 : index} : memref<32x4xvector<4x8xf32>> -> !gpu.mma_matrix<16x16xf16, "AOp"> return } diff --git a/mlir/test/Dialect/IRDL/variadics.mlir b/mlir/test/Dialect/IRDL/variadics.mlir index a8871fc..873f248 100644 --- a/mlir/test/Dialect/IRDL/variadics.mlir +++ b/mlir/test/Dialect/IRDL/variadics.mlir @@ -133,7 +133,7 @@ func.func @testOptOperandFail(%x: i16) { // Check that an operation with multiple variadics expects the segment size // attribute func.func @testMultOperandsMissingSegment(%x: i16, %z: i64) { - // expected-error@+1 {{'operand_segment_sizes' attribute is expected but not provided}} + // expected-error@+1 {{'operandSegmentSizes' attribute is expected but not provided}} "testvar.var_and_opt_operand"(%x, %x, %z) : (i16, i16, i64) -> () return } @@ -143,8 +143,8 @@ func.func @testMultOperandsMissingSegment(%x: i16, %z: i64) { // Check that an operation with multiple variadics expects the segment size // attribute of the right type func.func @testMultOperandsWrongSegmentType(%x: i16, %z: i64) { - // expected-error@+1 {{'operand_segment_sizes' attribute is expected to be a dense i32 array}} - "testvar.var_and_opt_operand"(%x, %x, %z) {operand_segment_sizes = i32} : (i16, i16, i64) -> () + // expected-error@+1 {{'operandSegmentSizes' attribute is expected to be a dense i32 array}} + "testvar.var_and_opt_operand"(%x, %x, %z) {operandSegmentSizes = i32} : (i16, i16, i64) -> () return } @@ -153,12 +153,12 @@ func.func @testMultOperandsWrongSegmentType(%x: i16, %z: i64) { // Check that an operation with multiple variadics with the right segment size // verifies. func.func @testMultOperands(%x: i16, %y: i32, %z: i64) { - "testvar.var_and_opt_operand"(%x, %x, %z) {operand_segment_sizes = array<i32: 2, 0, 1>} : (i16, i16, i64) -> () - // CHECK: "testvar.var_and_opt_operand"(%{{.*}}, %{{.*}}, %{{.*}}) {operand_segment_sizes = array<i32: 2, 0, 1>} : (i16, i16, i64) -> () - "testvar.var_and_opt_operand"(%x, %x, %y, %z) {operand_segment_sizes = array<i32: 2, 1, 1>} : (i16, i16, i32, i64) -> () - // CHECK-NEXT: "testvar.var_and_opt_operand"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {operand_segment_sizes = array<i32: 2, 1, 1>} : (i16, i16, i32, i64) -> () - "testvar.var_and_opt_operand"(%y, %z) {operand_segment_sizes = array<i32: 0, 1, 1>} : (i32, i64) -> () - // CHECK-NEXT: "testvar.var_and_opt_operand"(%{{.*}}, %{{.*}}) {operand_segment_sizes = array<i32: 0, 1, 1>} : (i32, i64) -> () + "testvar.var_and_opt_operand"(%x, %x, %z) {operandSegmentSizes = array<i32: 2, 0, 1>} : (i16, i16, i64) -> () + // CHECK: "testvar.var_and_opt_operand"(%{{.*}}, %{{.*}}, %{{.*}}) {operandSegmentSizes = array<i32: 2, 0, 1>} : (i16, i16, i64) -> () + "testvar.var_and_opt_operand"(%x, %x, %y, %z) {operandSegmentSizes = array<i32: 2, 1, 1>} : (i16, i16, i32, i64) -> () + // CHECK-NEXT: "testvar.var_and_opt_operand"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {operandSegmentSizes = array<i32: 2, 1, 1>} : (i16, i16, i32, i64) -> () + "testvar.var_and_opt_operand"(%y, %z) {operandSegmentSizes = array<i32: 0, 1, 1>} : (i32, i64) -> () + // CHECK-NEXT: "testvar.var_and_opt_operand"(%{{.*}}, %{{.*}}) {operandSegmentSizes = array<i32: 0, 1, 1>} : (i32, i64) -> () return } @@ -166,8 +166,8 @@ func.func @testMultOperands(%x: i16, %y: i32, %z: i64) { // Check that the segment sizes expects non-negative values func.func @testMultOperandsSegmentNegative() { - // expected-error@+1 {{'operand_segment_sizes' attribute for specifying operand segments must have non-negative values}} - "testvar.var_and_opt_operand"() {operand_segment_sizes = array<i32: 2, -1, 1>} : () -> () + // expected-error@+1 {{'operandSegmentSizes' attribute for specifying operand segments must have non-negative values}} + "testvar.var_and_opt_operand"() {operandSegmentSizes = array<i32: 2, -1, 1>} : () -> () return } @@ -175,8 +175,8 @@ func.func @testMultOperandsSegmentNegative() { // Check that the segment sizes expects 1 for single values func.func @testMultOperandsSegmentWrongSingle() { - // expected-error@+1 {{element 2 in 'operand_segment_sizes' attribute must be equal to 1}} - "testvar.var_and_opt_operand"() {operand_segment_sizes = array<i32: 0, 0, 0>} : () -> () + // expected-error@+1 {{element 2 in 'operandSegmentSizes' attribute must be equal to 1}} + "testvar.var_and_opt_operand"() {operandSegmentSizes = array<i32: 0, 0, 0>} : () -> () return } @@ -184,8 +184,8 @@ func.func @testMultOperandsSegmentWrongSingle() { // Check that the segment sizes expects not more than 1 for optional values func.func @testMultOperandsSegmentWrongOptional() { - // expected-error@+1 {{element 1 in 'operand_segment_sizes' attribute must be equal to 0 or 1}} - "testvar.var_and_opt_operand"() {operand_segment_sizes = array<i32: 0, 2, 0>} : () -> () + // expected-error@+1 {{element 1 in 'operandSegmentSizes' attribute must be equal to 0 or 1}} + "testvar.var_and_opt_operand"() {operandSegmentSizes = array<i32: 0, 2, 0>} : () -> () return } @@ -193,8 +193,8 @@ func.func @testMultOperandsSegmentWrongOptional() { // Check that the sum of the segment sizes should be equal to the number of operands func.func @testMultOperandsSegmentWrongOptional(%y: i32, %z: i64) { - // expected-error@+1 {{sum of elements in 'operand_segment_sizes' attribute must be equal to the number of operands}} - "testvar.var_and_opt_operand"(%y, %z) {operand_segment_sizes = array<i32: 0, 0, 1>} : (i32, i64) -> () + // expected-error@+1 {{sum of elements in 'operandSegmentSizes' attribute must be equal to the number of operands}} + "testvar.var_and_opt_operand"(%y, %z) {operandSegmentSizes = array<i32: 0, 0, 1>} : (i32, i64) -> () return } @@ -334,7 +334,7 @@ func.func @testOptResultFail() { // Check that an operation with multiple variadics expects the segment size // attribute func.func @testMultResultsMissingSegment() { - // expected-error@+1 {{'result_segment_sizes' attribute is expected but not provided}} + // expected-error@+1 {{'resultSegmentSizes' attribute is expected but not provided}} "testvar.var_and_opt_result"() : () -> (i16, i16, i64) return } @@ -344,8 +344,8 @@ func.func @testMultResultsMissingSegment() { // Check that an operation with multiple variadics expects the segment size // attribute of the right type func.func @testMultResultsWrongSegmentType() { - // expected-error@+1 {{'result_segment_sizes' attribute is expected to be a dense i32 array}} - "testvar.var_and_opt_result"() {result_segment_sizes = i32} : () -> (i16, i16, i64) + // expected-error@+1 {{'resultSegmentSizes' attribute is expected to be a dense i32 array}} + "testvar.var_and_opt_result"() {resultSegmentSizes = i32} : () -> (i16, i16, i64) return } @@ -354,12 +354,12 @@ func.func @testMultResultsWrongSegmentType() { // Check that an operation with multiple variadics with the right segment size // verifies. func.func @testMultResults() { - "testvar.var_and_opt_result"() {result_segment_sizes = array<i32: 2, 0, 1>} : () -> (i16, i16, i64) - // CHECK: "testvar.var_and_opt_result"() {result_segment_sizes = array<i32: 2, 0, 1>} : () -> (i16, i16, i64) - "testvar.var_and_opt_result"() {result_segment_sizes = array<i32: 2, 1, 1>} : () -> (i16, i16, i32, i64) - // CHECK-NEXT: "testvar.var_and_opt_result"() {result_segment_sizes = array<i32: 2, 1, 1>} : () -> (i16, i16, i32, i64) - "testvar.var_and_opt_result"() {result_segment_sizes = array<i32: 0, 1, 1>} : () -> (i32, i64) - // CHECK-NEXT: "testvar.var_and_opt_result"() {result_segment_sizes = array<i32: 0, 1, 1>} : () -> (i32, i64) + "testvar.var_and_opt_result"() {resultSegmentSizes = array<i32: 2, 0, 1>} : () -> (i16, i16, i64) + // CHECK: "testvar.var_and_opt_result"() {resultSegmentSizes = array<i32: 2, 0, 1>} : () -> (i16, i16, i64) + "testvar.var_and_opt_result"() {resultSegmentSizes = array<i32: 2, 1, 1>} : () -> (i16, i16, i32, i64) + // CHECK-NEXT: "testvar.var_and_opt_result"() {resultSegmentSizes = array<i32: 2, 1, 1>} : () -> (i16, i16, i32, i64) + "testvar.var_and_opt_result"() {resultSegmentSizes = array<i32: 0, 1, 1>} : () -> (i32, i64) + // CHECK-NEXT: "testvar.var_and_opt_result"() {resultSegmentSizes = array<i32: 0, 1, 1>} : () -> (i32, i64) return } @@ -367,8 +367,8 @@ func.func @testMultResults() { // Check that the segment sizes expects non-negative values func.func @testMultResultsSegmentNegative() { - // expected-error@+1 {{'result_segment_sizes' attribute for specifying result segments must have non-negative values}} - "testvar.var_and_opt_result"() {result_segment_sizes = array<i32: 2, -1, 1>} : () -> () + // expected-error@+1 {{'resultSegmentSizes' attribute for specifying result segments must have non-negative values}} + "testvar.var_and_opt_result"() {resultSegmentSizes = array<i32: 2, -1, 1>} : () -> () return } @@ -376,8 +376,8 @@ func.func @testMultResultsSegmentNegative() { // Check that the segment sizes expects 1 for single values func.func @testMultResultsSegmentWrongSingle() { - // expected-error@+1 {{element 2 in 'result_segment_sizes' attribute must be equal to 1}} - "testvar.var_and_opt_result"() {result_segment_sizes = array<i32: 0, 0, 0>} : () -> () + // expected-error@+1 {{element 2 in 'resultSegmentSizes' attribute must be equal to 1}} + "testvar.var_and_opt_result"() {resultSegmentSizes = array<i32: 0, 0, 0>} : () -> () return } @@ -385,8 +385,8 @@ func.func @testMultResultsSegmentWrongSingle() { // Check that the segment sizes expects not more than 1 for optional values func.func @testMultResultsSegmentWrongOptional() { - // expected-error@+1 {{element 1 in 'result_segment_sizes' attribute must be equal to 0 or 1}} - "testvar.var_and_opt_result"() {result_segment_sizes = array<i32: 0, 2, 0>} : () -> () + // expected-error@+1 {{element 1 in 'resultSegmentSizes' attribute must be equal to 0 or 1}} + "testvar.var_and_opt_result"() {resultSegmentSizes = array<i32: 0, 2, 0>} : () -> () return } @@ -394,7 +394,7 @@ func.func @testMultResultsSegmentWrongOptional() { // Check that the sum of the segment sizes should be equal to the number of results func.func @testMultResultsSegmentWrongOptional() { - // expected-error@+1 {{sum of elements in 'result_segment_sizes' attribute must be equal to the number of results}} - "testvar.var_and_opt_result"() {result_segment_sizes = array<i32: 0, 0, 1>} : () -> (i32, i64) + // expected-error@+1 {{sum of elements in 'resultSegmentSizes' attribute must be equal to the number of results}} + "testvar.var_and_opt_result"() {resultSegmentSizes = array<i32: 0, 0, 1>} : () -> (i32, i64) return } diff --git a/mlir/test/Dialect/Index/inliner-interface.mlir b/mlir/test/Dialect/Index/inliner-interface.mlir new file mode 100644 index 0000000..4c3d106 --- /dev/null +++ b/mlir/test/Dialect/Index/inliner-interface.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-opt %s -inline | FileCheck %s + +// CHECK-LABEL: @main +func.func @main(%arg0: i32) -> index { + // CHECK-NOT: call + // CHECK: index.castu + %0 = call @f(%arg0) : (i32) -> index + return %0 : index +} + +// CHECK-LABEL: @f +func.func @f(%arg0: i32) -> index { + %0 = index.castu %arg0 : i32 to index + return %0 : index +} diff --git a/mlir/test/Dialect/LLVMIR/add-debuginfo-func-scope.mlir b/mlir/test/Dialect/LLVMIR/add-debuginfo-func-scope.mlir index dfbf992..ffeb871 100644 --- a/mlir/test/Dialect/LLVMIR/add-debuginfo-func-scope.mlir +++ b/mlir/test/Dialect/LLVMIR/add-debuginfo-func-scope.mlir @@ -141,3 +141,22 @@ module { llvm.func @func_callsiteloc() loc(callsite("foo" at "mysource.cc":10:8)) } loc(unknown) +// ----- + +// CHECK-LABEL: llvm.func @func_cross_file_op() +// CHECK: #di_file = #llvm.di_file<"<unknown>" in ""> +// CHECK: #di_file1 = #llvm.di_file<"caller.py" in ""> +// CHECK: #di_file2 = #llvm.di_file<"callee.py" in ""> +// CHECK: #di_subroutine_type = #llvm.di_subroutine_type<callingConvention = DW_CC_normal> +// CHECK: #di_subprogram = #llvm.di_subprogram<id = distinct[1]<>, compileUnit = #di_compile_unit, scope = #di_file1, name = "func_cross_file_op", linkageName = "func_cross_file_op", file = #di_file1, line = 5, scopeLine = 5, subprogramFlags = "Definition|Optimized", type = #di_subroutine_type> +// CHECK: #di_lexical_block_file = #llvm.di_lexical_block_file<scope = #di_subprogram, file = #di_file2, discriminator = 0> + +#loc = loc("caller.py":5:1) +#loc1 = loc("callee.py":10:5) + +module { + llvm.func @func_cross_file_op() { + llvm.return loc(#loc1) + } loc(#loc) +} loc(unknown) + diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir index cec4586..094313c 100644 --- a/mlir/test/Dialect/LLVMIR/func.mlir +++ b/mlir/test/Dialect/LLVMIR/func.mlir @@ -210,8 +210,8 @@ module { } // CHECK-LABEL: llvm.func @memory_attr - // CHECK-SAME: attributes {memory = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = readwrite>} { - llvm.func @memory_attr() attributes {memory = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = readwrite>} { + // CHECK-SAME: attributes {memory = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = readwrite, errnoMem = none, targetMem0 = none, targetMem1 = none>} { + llvm.func @memory_attr() attributes {memory = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = readwrite, errnoMem = none, targetMem0 = none, targetMem1 = none>} { llvm.return } diff --git a/mlir/test/Dialect/LLVMIR/inlining.mlir b/mlir/test/Dialect/LLVMIR/inlining.mlir index 8e292f4..9a77c5e 100644 --- a/mlir/test/Dialect/LLVMIR/inlining.mlir +++ b/mlir/test/Dialect/LLVMIR/inlining.mlir @@ -422,7 +422,7 @@ llvm.func @test_byval(%ptr : !llvm.ptr) { // ----- -llvm.func @with_byval_arg(%ptr : !llvm.ptr { llvm.byval = f64 }) attributes {memory_effects = #llvm.memory_effects<other = readwrite, argMem = read, inaccessibleMem = readwrite>} { +llvm.func @with_byval_arg(%ptr : !llvm.ptr { llvm.byval = f64 }) attributes {memory_effects = #llvm.memory_effects<other = readwrite, argMem = read, inaccessibleMem = readwrite, errnoMem = none, targetMem0 = none, targetMem1 = none>} { llvm.return } @@ -436,7 +436,7 @@ llvm.func @test_byval_read_only(%ptr : !llvm.ptr) { // ----- -llvm.func @with_byval_arg(%ptr : !llvm.ptr { llvm.byval = f64 }) attributes {memory_effects = #llvm.memory_effects<other = readwrite, argMem = write, inaccessibleMem = readwrite>} { +llvm.func @with_byval_arg(%ptr : !llvm.ptr { llvm.byval = f64 }) attributes {memory_effects = #llvm.memory_effects<other = readwrite, argMem = write, inaccessibleMem = readwrite, errnoMem = none, targetMem0 = none, targetMem1 = none>} { llvm.return } @@ -451,7 +451,7 @@ llvm.func @test_byval_write_only(%ptr : !llvm.ptr) { // ----- -llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read>} { +llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read, errnoMem = none, targetMem0 = none, targetMem1 = none>} { llvm.return } @@ -472,7 +472,7 @@ llvm.func @test_byval_input_aligned(%unaligned : !llvm.ptr, %aligned : !llvm.ptr llvm.func @func_that_uses_ptr(%ptr : !llvm.ptr) -llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read>} { +llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read, errnoMem = none, targetMem0 = none, targetMem1 = none>} { llvm.call @func_that_uses_ptr(%ptr) : (!llvm.ptr) -> () llvm.return } @@ -496,7 +496,7 @@ module attributes { llvm.func @func_that_uses_ptr(%ptr : !llvm.ptr) -llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read>} { +llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read, errnoMem = none, targetMem0 = none, targetMem1 = none>} { llvm.call @func_that_uses_ptr(%ptr) : (!llvm.ptr) -> () llvm.return } @@ -524,7 +524,7 @@ module attributes { llvm.func @func_that_uses_ptr(%ptr : !llvm.ptr) -llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read>} { +llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read, errnoMem = none, targetMem0 = none, targetMem1 = none>} { llvm.call @func_that_uses_ptr(%ptr) : (!llvm.ptr) -> () llvm.return } @@ -550,7 +550,7 @@ llvm.func @test_alignment_exceeded_anyway() { llvm.mlir.global private @unaligned_global(42 : i64) : i64 llvm.mlir.global private @aligned_global(42 : i64) { alignment = 64 } : i64 -llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read>} { +llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read, errnoMem = none, targetMem0 = none, targetMem1 = none>} { llvm.return } diff --git a/mlir/test/Dialect/LLVMIR/invalid-cg-profile.mlir b/mlir/test/Dialect/LLVMIR/invalid-cg-profile.mlir new file mode 100644 index 0000000..bdc98ed --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/invalid-cg-profile.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-translate %s -mlir-to-llvmir | FileCheck %s +// CHECK: !llvm.module.flags = !{![[CG_FLAG:[0-9]+]], ![[DBG_FLAG:[0-9]+]]} +// CHECK: ![[CG_FLAG]] = !{i32 5, !"CG Profile", ![[CG_LIST:[0-9]+]]} +// CHECK: ![[CG_LIST]] = distinct !{![[CG_ENTRY:[0-9]+]], ![[CG_ENTRY]], ![[CG_ENTRY]]} +// CHECK: ![[CG_ENTRY]] = !{null, null, i64 222} +// CHECK: ![[DBG_FLAG]] = !{i32 2, !"Debug Info Version", i32 3} + +module { + llvm.module_flags [#llvm.mlir.module_flag<append, "CG Profile", [ + #llvm.cgprofile_entry<from = @from, to = @to, count = 222>, + #llvm.cgprofile_entry<from = @from, count = 222>, + #llvm.cgprofile_entry<from = @to, to = @from, count = 222> + ]>] +} diff --git a/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-kind.mlir b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-kind.mlir new file mode 100644 index 0000000..ff3e91b --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-kind.mlir @@ -0,0 +1,221 @@ +// RUN: mlir-opt %s -split-input-file | FileCheck %s + +// This file contains tests for sparse MMA (mma.sp.sync) operations with KIND variants. +// The kind::f8f6f4 variant was introduced in PTX ISA 8.7 for sm_90+ architectures. +// +// Based on PTX ISA documentation: +// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-for-sparse-mma +// +// KIND::F8F6F4 enables: +// - Additional FP8 types: e3m2, e2m3, e2m1 +// - F16 accumulator for m16n8k64 FP8 operations +// - Mixed-precision FP8 computations +// +// Requirements: +// - ONLY works with ordered metadata (sp::ordered_metadata) +// - ONLY for shape m16n8k64 +// - ONLY for FP8 types (not integers or other floats) + +// ============================================================================= +// FP8 e4m3 Sparse MMA with KIND (m16n8k64) +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e4m3_f16 +func.func @nvvm_mma_sp_kind_m16n8k64_e4m3_f16( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {kind = #nvvm.mma_kind<f8f6f4>, + orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<e4m3>, + multiplicandBPtxType = #nvvm.mma_type<e4m3>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e4m3_f32 +func.func @nvvm_mma_sp_kind_m16n8k64_e4m3_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {kind = #nvvm.mma_kind<f8f6f4>, + orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<e4m3>, + multiplicandBPtxType = #nvvm.mma_type<e4m3>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// FP8 e5m2 Sparse MMA with KIND (m16n8k64) +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e5m2_f16 +func.func @nvvm_mma_sp_kind_m16n8k64_e5m2_f16( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {kind = #nvvm.mma_kind<f8f6f4>, + orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<e5m2>, + multiplicandBPtxType = #nvvm.mma_type<e5m2>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e5m2_f32 +func.func @nvvm_mma_sp_kind_m16n8k64_e5m2_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {kind = #nvvm.mma_kind<f8f6f4>, + orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<e5m2>, + multiplicandBPtxType = #nvvm.mma_type<e5m2>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// FP8 e3m2 Sparse MMA with KIND (m16n8k64) +// NOTE: e3m2 is ONLY available with kind::f8f6f4 +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e3m2_f16 +func.func @nvvm_mma_sp_kind_m16n8k64_e3m2_f16( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e3m2>, multiplicandBPtxType = #nvvm.mma_type<e3m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {kind = #nvvm.mma_kind<f8f6f4>, + orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<e3m2>, + multiplicandBPtxType = #nvvm.mma_type<e3m2>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e3m2_f32 +func.func @nvvm_mma_sp_kind_m16n8k64_e3m2_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e3m2>, multiplicandBPtxType = #nvvm.mma_type<e3m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {kind = #nvvm.mma_kind<f8f6f4>, + orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<e3m2>, + multiplicandBPtxType = #nvvm.mma_type<e3m2>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// FP8 e2m3 Sparse MMA with KIND (m16n8k64) +// NOTE: e2m3 is ONLY available with kind::f8f6f4 +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e2m3_f16 +func.func @nvvm_mma_sp_kind_m16n8k64_e2m3_f16( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e2m3>, multiplicandBPtxType = #nvvm.mma_type<e2m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {kind = #nvvm.mma_kind<f8f6f4>, + orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<e2m3>, + multiplicandBPtxType = #nvvm.mma_type<e2m3>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e2m3_f32 +func.func @nvvm_mma_sp_kind_m16n8k64_e2m3_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e2m3>, multiplicandBPtxType = #nvvm.mma_type<e2m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {kind = #nvvm.mma_kind<f8f6f4>, + orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<e2m3>, + multiplicandBPtxType = #nvvm.mma_type<e2m3>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// FP8 e2m1 Sparse MMA with KIND (m16n8k64) +// NOTE: e2m1 is ONLY available with kind::f8f6f4 +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e2m1_f16 +func.func @nvvm_mma_sp_kind_m16n8k64_e2m1_f16( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e2m1>, multiplicandBPtxType = #nvvm.mma_type<e2m1>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {kind = #nvvm.mma_kind<f8f6f4>, + orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<e2m1>, + multiplicandBPtxType = #nvvm.mma_type<e2m1>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e2m1_f32 +func.func @nvvm_mma_sp_kind_m16n8k64_e2m1_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind<f8f6f4>, multiplicandAPtxType = #nvvm.mma_type<e2m1>, multiplicandBPtxType = #nvvm.mma_type<e2m1>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {kind = #nvvm.mma_kind<f8f6f4>, + orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<e2m1>, + multiplicandBPtxType = #nvvm.mma_type<e2m1>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + diff --git a/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-ordered.mlir b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-ordered.mlir new file mode 100644 index 0000000..a4e2812 --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-ordered.mlir @@ -0,0 +1,411 @@ +// RUN: mlir-opt %s -split-input-file | FileCheck %s + +// This file contains tests for sparse MMA (mma.sp.sync) operations with ORDERED metadata. +// The ordered metadata variant was introduced in PTX ISA 8.5 for sm_90+ architectures. +// +// Based on PTX ISA documentation: +// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-for-sparse-mma +// +// Ordered metadata provides an alternative metadata ordering for 2:4 structured sparsity +// that can offer better performance on newer architectures. + +// ============================================================================= +// F16 Sparse MMA Operations with Ordered Metadata (m16n8k16) +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k16_f16_f16 +func.func @nvvm_mma_sp_ordered_m16n8k16_f16_f16( + %a0 : vector<2xf16>, %a1 : vector<2xf16>, + %b0 : vector<2xf16>, %b1 : vector<2xf16>, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + shape = #nvvm.shape<m = 16, n = 8, k = 16>} + : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k16_f16_f32 +func.func @nvvm_mma_sp_ordered_m16n8k16_f16_f32( + %a0 : vector<2xf16>, %a1 : vector<2xf16>, + %b0 : vector<2xf16>, %b1 : vector<2xf16>, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + shape = #nvvm.shape<m = 16, n = 8, k = 16>} + : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// F16 Sparse MMA Operations with Ordered Metadata (m16n8k32) +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_f16_f16 +func.func @nvvm_mma_sp_ordered_m16n8k32_f16_f16( + %a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>, + %b0 : vector<2xf16>, %b1 : vector<2xf16>, %b2 : vector<2xf16>, %b3 : vector<2xf16>, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + shape = #nvvm.shape<m = 16, n = 8, k = 32>} + : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_f16_f32 +func.func @nvvm_mma_sp_ordered_m16n8k32_f16_f32( + %a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>, + %b0 : vector<2xf16>, %b1 : vector<2xf16>, %b2 : vector<2xf16>, %b3 : vector<2xf16>, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + shape = #nvvm.shape<m = 16, n = 8, k = 32>} + : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// BF16 Sparse MMA Operations with Ordered Metadata +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k16_bf16_f32 +func.func @nvvm_mma_sp_ordered_m16n8k16_bf16_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<bf16>, + multiplicandBPtxType = #nvvm.mma_type<bf16>, + shape = #nvvm.shape<m = 16, n = 8, k = 16>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_bf16_f32 +func.func @nvvm_mma_sp_ordered_m16n8k32_bf16_f32( + %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<bf16>, + multiplicandBPtxType = #nvvm.mma_type<bf16>, + shape = #nvvm.shape<m = 16, n = 8, k = 32>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// TF32 Sparse MMA Operations with Ordered Metadata +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k8_tf32_f32 +func.func @nvvm_mma_sp_ordered_m16n8k8_tf32_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<tf32>, + multiplicandBPtxType = #nvvm.mma_type<tf32>, + shape = #nvvm.shape<m = 16, n = 8, k = 8>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k16_tf32_f32 +func.func @nvvm_mma_sp_ordered_m16n8k16_tf32_f32( + %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<tf32>, + multiplicandBPtxType = #nvvm.mma_type<tf32>, + shape = #nvvm.shape<m = 16, n = 8, k = 16>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// Integer (s8) Sparse MMA Operations with Ordered Metadata +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_s8_s32 +func.func @nvvm_mma_sp_ordered_m16n8k32_s8_s32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<s8>, + multiplicandBPtxType = #nvvm.mma_type<s8>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 32>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_s8_s32_satfinite +func.func @nvvm_mma_sp_ordered_m16n8k32_s8_s32_satfinite( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<s8>, + multiplicandBPtxType = #nvvm.mma_type<s8>, + intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>, + shape = #nvvm.shape<m = 16, n = 8, k = 32>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_s8_s32 +func.func @nvvm_mma_sp_ordered_m16n8k64_s8_s32( + %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<s8>, + multiplicandBPtxType = #nvvm.mma_type<s8>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// ============================================================================= +// Integer (u8) Sparse MMA Operations with Ordered Metadata +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_u8_s32 +func.func @nvvm_mma_sp_ordered_m16n8k32_u8_s32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u8>, multiplicandBPtxType = #nvvm.mma_type<u8>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<u8>, + multiplicandBPtxType = #nvvm.mma_type<u8>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 32>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_u8_s32 +func.func @nvvm_mma_sp_ordered_m16n8k64_u8_s32( + %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u8>, multiplicandBPtxType = #nvvm.mma_type<u8>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<u8>, + multiplicandBPtxType = #nvvm.mma_type<u8>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// ============================================================================= +// Sub-byte Integer (s4) Sparse MMA Operations with Ordered Metadata +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_s4_s32 +func.func @nvvm_mma_sp_ordered_m16n8k64_s4_s32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<s4>, + multiplicandBPtxType = #nvvm.mma_type<s4>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k128_s4_s32 +func.func @nvvm_mma_sp_ordered_m16n8k128_s4_s32( + %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<s4>, + multiplicandBPtxType = #nvvm.mma_type<s4>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 128>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// ============================================================================= +// Sub-byte Integer (u4) Sparse MMA Operations with Ordered Metadata +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_u4_s32 +func.func @nvvm_mma_sp_ordered_m16n8k64_u4_s32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u4>, multiplicandBPtxType = #nvvm.mma_type<u4>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<u4>, + multiplicandBPtxType = #nvvm.mma_type<u4>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k128_u4_s32 +func.func @nvvm_mma_sp_ordered_m16n8k128_u4_s32( + %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u4>, multiplicandBPtxType = #nvvm.mma_type<u4>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<u4>, + multiplicandBPtxType = #nvvm.mma_type<u4>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 128>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// ============================================================================= +// FP8 (e4m3) Sparse MMA Operations with Ordered Metadata +// NOTE: FP8 ordered metadata requires PTX ISA 8.7+ and sm_90+ +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_e4m3_f16 +func.func @nvvm_mma_sp_ordered_m16n8k64_e4m3_f16( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<e4m3>, + multiplicandBPtxType = #nvvm.mma_type<e4m3>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_e4m3_f32 +func.func @nvvm_mma_sp_ordered_m16n8k64_e4m3_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<e4m3>, + multiplicandBPtxType = #nvvm.mma_type<e4m3>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// FP8 (e5m2) Sparse MMA Operations with Ordered Metadata +// NOTE: FP8 ordered metadata requires PTX ISA 8.7+ and sm_90+ +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_e5m2_f16 +func.func @nvvm_mma_sp_ordered_m16n8k64_e5m2_f16( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<e5m2>, + multiplicandBPtxType = #nvvm.mma_type<e5m2>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_e5m2_f32 +func.func @nvvm_mma_sp_ordered_m16n8k64_e5m2_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, orderedMetadata, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type<e5m2>, + multiplicandBPtxType = #nvvm.mma_type<e5m2>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + diff --git a/mlir/test/Dialect/LLVMIR/nvvm-mma-sp.mlir b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp.mlir new file mode 100644 index 0000000..e7122aa --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp.mlir @@ -0,0 +1,390 @@ +// RUN: mlir-opt %s -split-input-file | FileCheck %s + +// This file contains tests for all sparse MMA (mma.sp.sync) operations in the NVVM dialect +// Based on PTX ISA documentation: +// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-for-sparse-mma +// +// Sparse MMA operations follow 2:4 structured sparsity where 2 out of every 4 elements +// in the A operand are non-zero. The A operand is provided in compressed form, +// and sparseMetadata provides the sparsity indices. +// +// NOTE: These tests use the default (standard) metadata ordering. +// For ordered metadata tests (PTX ISA 8.5+, sm_90+), see nvvm-mma-sp-ordered.mlir. + +// ============================================================================= +// F16 Sparse MMA Operations (m16n8k16) +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k16_f16_f16 +func.func @nvvm_mma_sp_m16n8k16_f16_f16( + %a0 : vector<2xf16>, %a1 : vector<2xf16>, + %b0 : vector<2xf16>, %b1 : vector<2xf16>, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {shape = #nvvm.shape<m = 16, n = 8, k = 16>} + : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k16_f16_f32 +func.func @nvvm_mma_sp_m16n8k16_f16_f32( + %a0 : vector<2xf16>, %a1 : vector<2xf16>, + %b0 : vector<2xf16>, %b1 : vector<2xf16>, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {shape = #nvvm.shape<m = 16, n = 8, k = 16>} + : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// F16 Sparse MMA Operations (m16n8k32) +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_f16_f16 +func.func @nvvm_mma_sp_m16n8k32_f16_f16( + %a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>, + %b0 : vector<2xf16>, %b1 : vector<2xf16>, %b2 : vector<2xf16>, %b3 : vector<2xf16>, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {shape = #nvvm.shape<m = 16, n = 8, k = 32>} + : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_f16_f32 +func.func @nvvm_mma_sp_m16n8k32_f16_f32( + %a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>, + %b0 : vector<2xf16>, %b1 : vector<2xf16>, %b2 : vector<2xf16>, %b3 : vector<2xf16>, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {shape = #nvvm.shape<m = 16, n = 8, k = 32>} + : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// BF16 Sparse MMA Operations +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k16_bf16_f32 +func.func @nvvm_mma_sp_m16n8k16_bf16_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<bf16>, + multiplicandBPtxType = #nvvm.mma_type<bf16>, + shape = #nvvm.shape<m = 16, n = 8, k = 16>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_bf16_f32 +func.func @nvvm_mma_sp_m16n8k32_bf16_f32( + %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<bf16>, + multiplicandBPtxType = #nvvm.mma_type<bf16>, + shape = #nvvm.shape<m = 16, n = 8, k = 32>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// TF32 Sparse MMA Operations +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k8_tf32_f32 +func.func @nvvm_mma_sp_m16n8k8_tf32_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<tf32>, + multiplicandBPtxType = #nvvm.mma_type<tf32>, + shape = #nvvm.shape<m = 16, n = 8, k = 8>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k16_tf32_f32 +func.func @nvvm_mma_sp_m16n8k16_tf32_f32( + %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<tf32>, + multiplicandBPtxType = #nvvm.mma_type<tf32>, + shape = #nvvm.shape<m = 16, n = 8, k = 16>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// Integer (s8) Sparse MMA Operations +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_s8_s32 +func.func @nvvm_mma_sp_m16n8k32_s8_s32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<s8>, + multiplicandBPtxType = #nvvm.mma_type<s8>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 32>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_s8_s32_satfinite +func.func @nvvm_mma_sp_m16n8k32_s8_s32_satfinite( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<s8>, + multiplicandBPtxType = #nvvm.mma_type<s8>, + intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>, + shape = #nvvm.shape<m = 16, n = 8, k = 32>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_s8_s32 +func.func @nvvm_mma_sp_m16n8k64_s8_s32( + %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<s8>, + multiplicandBPtxType = #nvvm.mma_type<s8>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// ============================================================================= +// Integer (u8) Sparse MMA Operations +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_u8_s32 +func.func @nvvm_mma_sp_m16n8k32_u8_s32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u8>, multiplicandBPtxType = #nvvm.mma_type<u8>, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<u8>, + multiplicandBPtxType = #nvvm.mma_type<u8>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 32>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_u8_s32 +func.func @nvvm_mma_sp_m16n8k64_u8_s32( + %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u8>, multiplicandBPtxType = #nvvm.mma_type<u8>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<u8>, + multiplicandBPtxType = #nvvm.mma_type<u8>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// ============================================================================= +// Sub-byte Integer (s4) Sparse MMA Operations +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_s4_s32 +func.func @nvvm_mma_sp_m16n8k64_s4_s32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<s4>, + multiplicandBPtxType = #nvvm.mma_type<s4>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k128_s4_s32 +func.func @nvvm_mma_sp_m16n8k128_s4_s32( + %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<s4>, + multiplicandBPtxType = #nvvm.mma_type<s4>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 128>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// ============================================================================= +// Sub-byte Integer (u4) Sparse MMA Operations +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_u4_s32 +func.func @nvvm_mma_sp_m16n8k64_u4_s32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u4>, multiplicandBPtxType = #nvvm.mma_type<u4>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<u4>, + multiplicandBPtxType = #nvvm.mma_type<u4>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k128_u4_s32 +func.func @nvvm_mma_sp_m16n8k128_u4_s32( + %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u4>, multiplicandBPtxType = #nvvm.mma_type<u4>, shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<u4>, + multiplicandBPtxType = #nvvm.mma_type<u4>, + intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, + shape = #nvvm.shape<m = 16, n = 8, k = 128>} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// ============================================================================= +// FP8 (e4m3) Sparse MMA Operations +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_e4m3_f16 +func.func @nvvm_mma_sp_m16n8k64_e4m3_f16( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<e4m3>, + multiplicandBPtxType = #nvvm.mma_type<e4m3>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_e4m3_f32 +func.func @nvvm_mma_sp_m16n8k64_e4m3_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<e4m3>, + multiplicandBPtxType = #nvvm.mma_type<e4m3>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// FP8 (e5m2) Sparse MMA Operations +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_e5m2_f16 +func.func @nvvm_mma_sp_m16n8k64_e5m2_f16( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<e5m2>, + multiplicandBPtxType = #nvvm.mma_type<e5m2>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_e5m2_f32 +func.func @nvvm_mma_sp_m16n8k64_e5m2_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {multiplicandAPtxType = #nvvm.mma_type<e5m2>, + multiplicandBPtxType = #nvvm.mma_type<e5m2>, + shape = #nvvm.shape<m = 16, n = 8, k = 64>} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + diff --git a/mlir/test/Dialect/LLVMIR/nvvm-target-invalid.mlir b/mlir/test/Dialect/LLVMIR/nvvm-target-invalid.mlir new file mode 100644 index 0000000..c2cfa76 --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/nvvm-target-invalid.mlir @@ -0,0 +1,11 @@ +// RUN: not mlir-opt %s 2>&1 | FileCheck %s +// CHECK: 'nvvm.tcgen05.alloc' op is not supported on sm_90 + +module { + gpu.module @mod [#nvvm.target<chip = "sm_90">] { + func.func @tcgen05_alloc(%arg0: !llvm.ptr<7>, %arg1: i32) { + nvvm.tcgen05.alloc %arg0, %arg1 : !llvm.ptr<7>, i32 + return + } + } +} diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir index cd7bd37..579f0ac 100644 --- a/mlir/test/Dialect/LLVMIR/nvvm.mlir +++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir @@ -92,13 +92,6 @@ func.func @llvm_nvvm_cluster_wait() { llvm.return } -// CHECK-LABEL: @llvm_nvvm_fence_sc_cluster -func.func @llvm_nvvm_fence_sc_cluster() { - // CHECK: nvvm.fence.sc.cluster - nvvm.fence.sc.cluster - llvm.return -} - // CHECK-LABEL: @nvvm_shfl func.func @nvvm_shfl( %arg0 : i32, %arg1 : i32, %arg2 : i32, @@ -464,19 +457,6 @@ llvm.func private @mbarrier_arrive_nocomplete_shared(%barrier: !llvm.ptr<3>) { llvm.return } -llvm.func private @mbarrier_test_wait(%barrier: !llvm.ptr, %token : i64) -> i1 { - // CHECK: nvvm.mbarrier.test.wait %{{.*}} - %isComplete = nvvm.mbarrier.test.wait %barrier, %token : !llvm.ptr, i64 -> i1 - llvm.return %isComplete : i1 -} - -llvm.func private @mbarrier_test_wait_shared(%barrier: !llvm.ptr<3>, %token : i64) { - %count = nvvm.read.ptx.sreg.ntid.x : i32 - // CHECK: nvvm.mbarrier.test.wait %{{.*}} - %isComplete = nvvm.mbarrier.test.wait %barrier, %token : !llvm.ptr<3>, i64 -> i1 - llvm.return -} - // CHECK-LABEL: @wgmma_fence_aligned func.func @wgmma_fence_aligned() { // CHECK: nvvm.wgmma.fence.aligned diff --git a/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir b/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir index 35f5e1b..506b81e 100644 --- a/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir +++ b/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir @@ -2,35 +2,15 @@ // Test invalid target architecture (sm_100 instead of sm_100a) gpu.module @invalid_arch_sm_100 [#nvvm.target<chip = "sm_100">] { - func.func @convert_rs() { - %f1 = llvm.mlir.constant(1.0 : f32) : f32 - %f2 = llvm.mlir.constant(2.0 : f32) : f32 - %rbits = llvm.mlir.constant(0x12345678 : i32) : i32 - // expected-error@+1 {{'nvvm.convert.f32x2.to.f16x2' op is not supported on sm_100}} - %res = nvvm.convert.f32x2.to.f16x2 %f1, %f2, %rbits : vector<2xf16> + func.func @convert_rs(%src : vector<4xf32>, %rbits : i32) { + // expected-error@+1 {{'nvvm.convert.f32x4.to.f8x4' op is not supported on sm_100}} + %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E4M3FN) return } } // ----- -// Test that operations require stochastic rounding mode -llvm.func @invalid_rnd_mode_f16x2(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> { - // expected-error@+1 {{Only RS rounding mode is supported for conversions from f32x2 to f16x2.}} - %res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16> - llvm.return %res : vector<2xf16> -} - -// ----- - -llvm.func @invalid_rnd_mode_bf16x2(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> { - // expected-error@+1 {{Only RS rounding mode is supported for conversions from f32x2 to bf16x2.}} - %res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16> - llvm.return %res : vector<2xbf16> -} - -// ----- - // Test invalid destination types for f8x4 (should only accept f8E4M3FN, f8E5M2) llvm.func @invalid_dst_type_f8x4_e3m4(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> { // expected-error@+1 {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f32x4 to f8x4.}} diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir index 5e85759..40084bc 100644 --- a/mlir/test/Dialect/LLVMIR/rocdl.mlir +++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir @@ -14,18 +14,24 @@ func.func @rocdl_special_regs() -> i32 { %4 = rocdl.workgroup.id.y : i32 // CHECK: rocdl.workgroup.id.z : i32 %5 = rocdl.workgroup.id.z : i32 + // CHECK: rocdl.cluster.id.x : i32 + %6 = rocdl.cluster.id.x : i32 + // CHECK: rocdl.cluster.id.y : i32 + %7 = rocdl.cluster.id.y : i32 + // CHECK: rocdl.cluster.id.z : i32 + %8 = rocdl.cluster.id.z : i32 // CHECK: rocdl.workgroup.dim.x : i32 - %6 = rocdl.workgroup.dim.x : i32 + %9 = rocdl.workgroup.dim.x : i32 // CHECK: rocdl.workgroup.dim.y : i32 - %7 = rocdl.workgroup.dim.y : i32 + %10 = rocdl.workgroup.dim.y : i32 // CHECK: rocdl.workgroup.dim.z : i32 - %8 = rocdl.workgroup.dim.z : i32 + %11 = rocdl.workgroup.dim.z : i32 // CHECK: rocdl.grid.dim.x : i32 - %9 = rocdl.grid.dim.x : i32 + %12 = rocdl.grid.dim.x : i32 // CHECK: rocdl.grid.dim.y : i32 - %10 = rocdl.grid.dim.y : i32 + %13 = rocdl.grid.dim.y : i32 // CHECK: rocdl.grid.dim.z : i32 - %11 = rocdl.grid.dim.z : i32 + %14 = rocdl.grid.dim.z : i32 llvm.return %0 : i32 } @@ -43,6 +49,59 @@ func.func @rocdl.fmed3.vector(%a: vector<4xf16>, %b: vector<4xf16>, %c: vector<4 llvm.return %0 : vector<4xf16> } +func.func @rocdl.math.ops(%a: f32, %b: f16, %c: bf16) { + // CHECK-LABEL: rocdl.math.ops + // CHECK: %{{.*}} = rocdl.tanh %{{.*}} f32 -> f32 + // CHECK: %{{.*}} = rocdl.tanh %{{.*}} f16 -> f16 + // CHECK: %{{.*}} = rocdl.tanh %{{.*}} bf16 -> bf16 + %tanh0 = rocdl.tanh %a f32 -> f32 + %tanh1 = rocdl.tanh %b f16 -> f16 + %tanh2 = rocdl.tanh %c bf16 -> bf16 + + // CHECK: %{{.*}} = rocdl.sin %{{.*}} f32 -> f32 + // CHECK: %{{.*}} = rocdl.sin %{{.*}} f16 -> f16 + // CHECK: %{{.*}} = rocdl.sin %{{.*}} bf16 -> bf16 + %sin0 = rocdl.sin %a f32 -> f32 + %sin1 = rocdl.sin %b f16 -> f16 + %sin2 = rocdl.sin %c bf16 -> bf16 + + // CHECK: %{{.*}} = rocdl.cos %{{.*}} f32 -> f32 + // CHECK: %{{.*}} = rocdl.cos %{{.*}} f16 -> f16 + // CHECK: %{{.*}} = rocdl.cos %{{.*}} bf16 -> bf16 + %cos0 = rocdl.cos %a f32 -> f32 + %cos1 = rocdl.cos %b f16 -> f16 + %cos2 = rocdl.cos %c bf16 -> bf16 + + // CHECK: %{{.*}} = rocdl.rcp %{{.*}} f32 -> f32 + // CHECK: %{{.*}} = rocdl.rcp %{{.*}} f16 -> f16 + // CHECK: %{{.*}} = rocdl.rcp %{{.*}} bf16 -> bf16 + %rcp0 = rocdl.rcp %a f32 -> f32 + %rcp1 = rocdl.rcp %b f16 -> f16 + %rcp2 = rocdl.rcp %c bf16 -> bf16 + + // CHECK: %{{.*}} = rocdl.exp2 %{{.*}} f32 -> f32 + // CHECK: %{{.*}} = rocdl.exp2 %{{.*}} f16 -> f16 + // CHECK: %{{.*}} = rocdl.exp2 %{{.*}} bf16 -> bf16 + %exp2_0 = rocdl.exp2 %a f32 -> f32 + %exp2_1 = rocdl.exp2 %b f16 -> f16 + %exp2_2 = rocdl.exp2 %c bf16 -> bf16 + + // CHECK: %{{.*}} = rocdl.log %{{.*}} f32 -> f32 + // CHECK: %{{.*}} = rocdl.log %{{.*}} f16 -> f16 + // CHECK: %{{.*}} = rocdl.log %{{.*}} bf16 -> bf16 + %log0 = rocdl.log %a f32 -> f32 + %log1 = rocdl.log %b f16 -> f16 + %log2 = rocdl.log %c bf16 -> bf16 + + // CHECK: %{{.*}} = rocdl.sqrt %{{.*}} f32 -> f32 + // CHECK: %{{.*}} = rocdl.sqrt %{{.*}} f16 -> f16 + // CHECK: %{{.*}} = rocdl.sqrt %{{.*}} bf16 -> bf16 + %sqrt0 = rocdl.sqrt %a f32 -> f32 + %sqrt1 = rocdl.sqrt %b f16 -> f16 + %sqrt2 = rocdl.sqrt %c bf16 -> bf16 + llvm.return +} + func.func @rocdl.barrier() { // CHECK: rocdl.barrier rocdl.barrier @@ -650,6 +709,39 @@ llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> { llvm.return %r3 : vector<4xf16> } +llvm.func @rocdl.load.tr.ops(%gl_ptr : !llvm.ptr<1>, %ds_ptr : !llvm.ptr<3>) { + // CHECK-LABEL: @rocdl.load.tr.ops + // CHECK-SAME: (%[[GL_PTR:.+]]: !llvm.ptr<1>, %[[DS_OTR:.+]]: !llvm.ptr<3>) + // CHECK: rocdl.global.load.tr4.b64 %[[GL_PTR]] : !llvm.ptr<1> -> vector<2xi32> + // CHECK: rocdl.global.load.tr.b64 %[[GL_PTR]] : !llvm.ptr<1> -> vector<2xi32> + // CHECK: rocdl.global.load.tr6.b96 %[[GL_PTR]] : !llvm.ptr<1> -> vector<3xi32> + // CHECK: rocdl.global.load.tr.b128 %[[GL_PTR]] : !llvm.ptr<1> -> vector<8xi16> + // CHECK: rocdl.global.load.tr.b128 %[[GL_PTR]] : !llvm.ptr<1> -> vector<8xf16> + // CHECK: rocdl.global.load.tr.b128 %[[GL_PTR]] : !llvm.ptr<1> -> vector<8xbf16> + // CHECK: rocdl.ds.load.tr4.b64 %[[DS_OTR]] : !llvm.ptr<3> -> vector<2xi32> + // CHECK: rocdl.ds.load.tr8.b64 %[[DS_OTR]] : !llvm.ptr<3> -> vector<2xi32> + // CHECK: rocdl.ds.load.tr6.b96 %[[DS_OTR]] : !llvm.ptr<3> -> vector<3xi32> + // CHECK: rocdl.ds.load.tr16.b128 %[[DS_OTR]] : !llvm.ptr<3> -> vector<8xi16> + // CHECK: rocdl.ds.load.tr16.b128 %[[DS_OTR]] : !llvm.ptr<3> -> vector<8xf16> + // CHECK: rocdl.ds.load.tr16.b128 %[[DS_OTR]] : !llvm.ptr<3> -> vector<8xbf16> + // CHECK: llvm.return + + rocdl.global.load.tr4.b64 %gl_ptr : !llvm.ptr<1> -> vector<2xi32> + rocdl.global.load.tr.b64 %gl_ptr : !llvm.ptr<1> -> vector<2xi32> + rocdl.global.load.tr6.b96 %gl_ptr : !llvm.ptr<1> -> vector<3xi32> + rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xi16> + rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xf16> + rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xbf16> + + rocdl.ds.load.tr4.b64 %ds_ptr : !llvm.ptr<3> -> vector<2xi32> + rocdl.ds.load.tr8.b64 %ds_ptr : !llvm.ptr<3> -> vector<2xi32> + rocdl.ds.load.tr6.b96 %ds_ptr : !llvm.ptr<3> -> vector<3xi32> + rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xi16> + rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xf16> + rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xbf16> + llvm.return +} + llvm.func @rocdl.load.to.lds(%src : !llvm.ptr<7>, %dst: !llvm.ptr<3>) { // CHECK-LABEL @rocdl.load.to.lds //CHECK: rocdl.load.to.lds %{{.*}}, %{{.*}}, 4, 0, 0 : <7> @@ -670,13 +762,27 @@ llvm.func @rocdl.global.load.async.to.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3 // CHECK: rocdl.global.load.async.to.lds.b32 %{{.*}}, %{{.*}}, 0, 0 // CHECK: rocdl.global.load.async.to.lds.b64 %{{.*}}, %{{.*}}, 0, 0 // CHECK: rocdl.global.load.async.to.lds.b128 %{{.*}}, %{{.*}}, 0, 0 - rocdl.global.load.async.to.lds.b8 %src, %dst, 0, 0 : <1>, <3> - rocdl.global.load.async.to.lds.b32 %src, %dst, 0, 0 : <1>, <3> - rocdl.global.load.async.to.lds.b64 %src, %dst, 0, 0 : <1>, <3> - rocdl.global.load.async.to.lds.b128 %src, %dst, 0, 0 : <1>, <3> + rocdl.global.load.async.to.lds.b8 %src, %dst, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3> + rocdl.global.load.async.to.lds.b32 %src, %dst, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3> + rocdl.global.load.async.to.lds.b64 %src, %dst, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3> + rocdl.global.load.async.to.lds.b128 %src, %dst, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3> llvm.return } +llvm.func @rocdl.cluster.load.async.to.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) { + // CHECK-LABEL @rocdl.cluster.load.async.to.lds + // CHECK: rocdl.cluster.load.async.to.lds.b8 %{{.*}}, %{{.*}}, 0, 0, 0 + // CHECK: rocdl.cluster.load.async.to.lds.b32 %{{.*}}, %{{.*}}, 0, 0, 0 + // CHECK: rocdl.cluster.load.async.to.lds.b64 %{{.*}}, %{{.*}}, 0, 0, 0 + // CHECK: rocdl.cluster.load.async.to.lds.b128 %{{.*}}, %{{.*}}, 0, 0, 0 + rocdl.cluster.load.async.to.lds.b8 %src, %dst, 0, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3> + rocdl.cluster.load.async.to.lds.b32 %src, %dst, 0, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3> + rocdl.cluster.load.async.to.lds.b64 %src, %dst, 0, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3> + rocdl.cluster.load.async.to.lds.b128 %src, %dst, 0, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3> + llvm.return +} + + // CHECK-LABEL @rocdl.tensor.load.to.lds llvm.func @rocdl.tensor.load.to.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>, %dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) { @@ -1050,6 +1156,13 @@ llvm.func @rocdl.s.get.barrier.state() { llvm.return } +llvm.func @rocdl.s.get.named.barrier.state(%ptr : !llvm.ptr<3>) { + // CHECK-LABEL: rocdl.s.get.named.barrier.state + // CHECK: rocdl.s.get.named.barrier.state %[[PTR:.+]] + %0 = rocdl.s.get.named.barrier.state %ptr : i32 + llvm.return +} + llvm.func @rocdl.s.wait.dscnt() { // CHECK-LABEL: rocdl.s.wait.dscnt // CHECK: rocdl.s.wait.dscnt 0 @@ -1305,6 +1418,26 @@ llvm.func @rocdl.cvt.scalef32.sr.pk16(%v16xf32: vector<16xf32>, // ----- +// CHECK-LABEL: @rocdl_wmma_scale_ops +llvm.func @rocdl_wmma_scale_ops(%a_f8: vector<8xi32>, %a_f4: vector<4xi32>, %c_f32: vector<4xf32>, %c16_f32: vector<16xf32>, + %scale_i32: i32, %scale_i64: i64) { + // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32> + %r0 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %a_f8, %a_f8, %c_f32, %scale_i32, %scale_i32 : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32> + + // CHECK: rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i64, i64) -> vector<4xf32> + %r1 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %a_f8, %a_f8, %c_f32, %scale_i64, %scale_i64 : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i64, i64) -> vector<4xf32> + + // CHECK: rocdl.wmma.scale.f32.32x16x128.f4 %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32) -> vector<16xf32> + %r2 = rocdl.wmma.scale.f32.32x16x128.f4 %a_f4, %a_f4, %c16_f32, %scale_i32, %scale_i32 : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32) -> vector<16xf32> + + // CHECK: rocdl.wmma.scale16.f32.32x16x128.f4 %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i64, i64) -> vector<16xf32> + %r3 = rocdl.wmma.scale16.f32.32x16x128.f4 %a_f4, %a_f4, %c16_f32, %scale_i64, %scale_i64 : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i64, i64) -> vector<16xf32> + + llvm.return +} + +// ----- + // expected-error@below {{attribute attached to unexpected op}} func.func private @expected_llvm_func() attributes { rocdl.kernel } diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir index 00e763a..afbf47e 100644 --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -122,8 +122,8 @@ func.func @ops(%arg0: i32, %arg1: f32, // CHECK: llvm.call @baz() {will_return} : () -> () llvm.call @baz() {will_return} : () -> () -// CHECK: llvm.call @baz() {memory = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = write>} : () -> () - llvm.call @baz() {memory = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = write>} : () -> () +// CHECK: llvm.call @baz() {memory = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = write, errnoMem = none, targetMem0 = none, targetMem1 = none>} : () -> () + llvm.call @baz() {memory = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = write, errnoMem = none, targetMem0 = none, targetMem1 = none>} : () -> () // Terminator operations and their successors. // diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir new file mode 100644 index 0000000..4b2d42a --- /dev/null +++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir @@ -0,0 +1,214 @@ +// The following test examples of linalg convolution named ops lowered to linalg.generic and then +// lifted back up to named op. +// NOTE: Most tests in this file use dynamic shapes as the underlying transformations don't modify shapes. There's one exception that's added as a smoke test. + +// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s --implicit-check-not=linalg.generic + +// ----------------------------- +// Convolution ops. +// ----------------------------- +func.func @conv_1d(%in : tensor<?xf32>, %filter : tensor<?xf32>, %out : tensor<?xf32>) -> tensor<?xf32> { + %0 = linalg.conv_1d + ins(%in, %filter : tensor<?xf32>, tensor<?xf32>) + outs(%out : tensor<?xf32>) -> tensor<?xf32> + return %0 : tensor<?xf32> +} +// CHECK: @conv_1d +// CHECK: linalg.conv_1d + +// ----- + +func.func @conv_1d_nwc_wcf(%input: tensor<?x?x?xf32>, %filter: tensor<?x?x?xf32>, %output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> { + %0 = linalg.conv_1d_nwc_wcf + {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} + ins (%input, %filter: tensor<?x?x?xf32>, tensor<?x?x?xf32>) + outs (%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> + return %0 : tensor<?x?x?xf32> +} +// CHECK: @conv_1d_nwc_wcf +// CHECK: linalg.conv_1d_nwc_wcf +// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64> + +// ----- + +func.func @conv_1d_ncw_fcw(%input: tensor<?x?x?xf32>, %filter: tensor<?x?x?xf32>, %output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> { + %0 = linalg.conv_1d_ncw_fcw + {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} + ins (%input, %filter: tensor<?x?x?xf32>, tensor<?x?x?xf32>) + outs (%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> + return %0 : tensor<?x?x?xf32> +} +// CHECK: @conv_1d_ncw_fcw +// CHECK: linalg.conv_1d_ncw_fcw +// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64> + +// ----- + +func.func @conv_2d(%in : tensor<?x?xf32>, %filter : tensor<?x?xf32>, %out : tensor<?x?xf32>) -> tensor<?x?xf32> { + %0 = linalg.conv_2d + ins(%in, %filter : tensor<?x?xf32>, tensor<?x?xf32>) + outs(%out: tensor<?x?xf32>) -> tensor<?x?xf32> + return %0 : tensor<?x?xf32> +} +// CHECK: @conv_2d +// CHECK: linalg.conv_2d + +// ----- + +func.func @conv_3d(%in : tensor<?x?x?xf32>, %filter : tensor<?x?x?xf32>, %out : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> { + %0 = linalg.conv_3d + ins(%in, %filter : tensor<?x?x?xf32>, tensor<?x?x?xf32>) + outs(%out : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> + return %0 : tensor<?x?x?xf32> +} +// CHECK: @conv_3d +// CHECK: linalg.conv_3d + +// ----- + +// ----------------------------- +// Depthwise Convolution ops. +// ----------------------------- +func.func @depthwise_conv_1d_ncw_cw(%input: tensor<?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> { + %0 = linalg.depthwise_conv_1d_ncw_cw + {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} + ins (%input, %filter: tensor<?x?x?xf32>, tensor<?x?xf32>) + outs (%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> + return %0 : tensor<?x?x?xf32> +} +// CHECK: @depthwise_conv_1d_ncw_cw +// CHECK: linalg.depthwise_conv_1d_ncw_cw +// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64> + +// ----- + +func.func @depthwise_conv_1d_nwc_wc_static(%input: tensor<1x25x8xi8>, %filter: tensor<3x8xi8>, %output: tensor<1x10x8xi32>) -> tensor<1x10x8xi32> { + %0 = linalg.depthwise_conv_1d_nwc_wc + {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} + ins (%input, %filter: tensor<1x25x8xi8>, tensor<3x8xi8>) + outs (%output: tensor<1x10x8xi32>) -> tensor<1x10x8xi32> + return %0 : tensor<1x10x8xi32> +} +// CHECK: @depthwise_conv_1d_nwc_wc_static +// CHECK: linalg.depthwise_conv_1d_nwc_wc +// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64> + +// ----- + +func.func @depthwise_conv_1d_nwc_wcm(%input: tensor<?x?x?xf32>, %filter: tensor<?x?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> { + %0 = linalg.depthwise_conv_1d_nwc_wcm + {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + ins (%input, %filter: tensor<?x?x?xf32>, tensor<?x?x?xf32>) + outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> + return %0 : tensor<?x?x?x?xf32> +} +// CHECK: @depthwise_conv_1d_nwc_wcm +// CHECK: linalg.depthwise_conv_1d_nwc_wcm +// CHECK-SAME: dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64> + +// ----- + +func.func @depthwise_conv_2d_nchw_chw(%input: tensor<?x?x?x?xf16>, %filter: tensor<?x?x?xf16>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> { + %0 = linalg.depthwise_conv_2d_nchw_chw + {dilations = dense<[2,3]> : vector<2xi64>, strides = dense<[4,5]> : vector<2xi64>} + ins (%input, %filter: tensor<?x?x?x?xf16>, tensor<?x?x?xf16>) + outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> + return %0 : tensor<?x?x?x?xf32> +} +// CHECK: @depthwise_conv_2d_nchw_chw +// CHECK: linalg.depthwise_conv_2d_nchw_chw +// CHECK-SAME: dilations = dense<[2, 3]> : tensor<2xi64>, strides = dense<[4, 5]> : tensor<2xi64> + +// ----- + +func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %output: tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32> { + %0 = linalg.depthwise_conv_3d_ndhwc_dhwcm + {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} + ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) + outs (%output: tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32> + return %0 : tensor<?x?x?x?x?x?xf32> +} +// CHECK: @depthwise_conv_3d_ndhwc_dhwcm +// CHECK: linalg.depthwise_conv_3d_ndhwc_dhwcm +// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64> + +// ----- + +// ----------------------------- +// Pooling ops. +// ----------------------------- +func.func @pooling_nhwc_max(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> { + %0 = linalg.pooling_nhwc_max + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>) + outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> + return %0 : tensor<?x?x?x?xf32> +} +// CHECK: @pooling_nhwc_max +// CHECK: linalg.pooling_nhwc_max +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> + +// ----- + +func.func @pooling_nhwc_min(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> { + %0 = linalg.pooling_nhwc_min + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>) + outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> + return %0 : tensor<?x?x?x?xf32> +} +// CHECK: @pooling_nhwc_min +// CHECK: linalg.pooling_nhwc_min +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> + +// ----- + +func.func @pooling_nhwc_sum(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> { + %0 = linalg.pooling_nhwc_sum + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>) + outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> + return %0 : tensor<?x?x?x?xf32> +} +// CHECK: @pooling_nhwc_sum +// CHECK: linalg.pooling_nhwc_sum +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> + +// ----- + +func.func @pooling_nhwc_max_unsigned(%input: tensor<?x?x?x?xi8>, %filter: tensor<?x?xi8>, %output: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> { + %0 = linalg.pooling_nhwc_max_unsigned + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor<?x?x?x?xi8>, tensor<?x?xi8>) + outs (%output: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> + return %0 : tensor<?x?x?x?xi32> +} +// CHECK: @pooling_nhwc_max_unsigned +// CHECK: linalg.pooling_nhwc_max_unsigned +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> + +// ----- + +func.func @pooling_nhwc_min_unsigned_integer(%input: tensor<?x?x?x?xi32>, %filter: tensor<?x?xi32>, %output: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> { + %0 = linalg.pooling_nhwc_min_unsigned + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor<?x?x?x?xi32>, tensor<?x?xi32>) + outs (%output: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> + return %0 : tensor<?x?x?x?xi32> +} +// CHECK: @pooling_nhwc_min_unsigned_integer +// CHECK: linalg.pooling_nhwc_min_unsigned +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> + +// ----- + +func.func @pooling_nhwc_min_unsigned_float(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> { + %0 = linalg.pooling_nhwc_min_unsigned + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>) + outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> + return %0 : tensor<?x?x?x?xf32> +} +// CHECK: @pooling_nhwc_min_unsigned_float +// CHECK: linalg.pooling_nhwc_min +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir index 2bf3d21..77c7d7d 100644 --- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir +++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir @@ -594,6 +594,24 @@ func.func @fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x // ----- +func.func @no_fuse_by_collapsing_pad_non_constant_padding(%arg0 : tensor<2x12xi32>) -> tensor<8x3x4xi32> { + %expand = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [2, 3, 4] : tensor<2x12xi32> into tensor<2x3x4xi32> + %cst = arith.constant 0 : i32 + %padded_0 = tensor.pad %expand low[1, 0, 0] high[5, 0, 0] { + ^bb0(%arg1: index, %arg2: index, %arg3: index): + %pad_val = arith.index_cast %arg1 : index to i32 + tensor.yield %pad_val : i32 + } : tensor<2x3x4xi32> to tensor<8x3x4xi32> + return %padded_0 : tensor<8x3x4xi32> +} +// CHECK: func @no_fuse_by_collapsing_pad_non_constant_padding( +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12xi32>) +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] +// CHECK: %[[PAD:.+]] = tensor.pad %[[EXPAND]] +// CHECK: return %[[PAD]] + +// ----- + func.func @no_fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x5x4x17x6x7x8x14xi32> { %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32> %cst = arith.constant 0 : i32 @@ -640,6 +658,63 @@ func.func @fuse_by_collapsing_dynamic_pad(%arg0 : tensor<?x?x?x?xf32>, // CHECK: return %[[EXPAND]] // ----- + +func.func @collapse_shape_with_producer_pad(%arg0: tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> { + %cst = arith.constant 0 : i32 + %padded = tensor.pad %arg0 low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] { + ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, + %arg5: index, %arg6: index, %arg7: index, %arg8: index): + tensor.yield %cst : i32 + } : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32> + %collapsed = tensor.collapse_shape %padded [[0], [1, 2], [3], [4, 5, 6], [7]] + : tensor<8x3x4x17x6x7x8x14xi32> into tensor<8x12x17x336x14xi32> + return %collapsed : tensor<8x12x17x336x14xi32> +} +// CHECK: func @collapse_shape_with_producer_pad +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32> +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]] +// CHECK: %[[PAD:.+]] = tensor.pad %[[COLLAPSE]] low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2] +// CHECK: return %[[PAD]] + +// ----- + +func.func @collapse_shape_with_producer_pad_dynamic(%arg0: tensor<?x?x?x?x?x?xf32>, + %l0 : index, %l1 : index, %h0 : index, %h1 : index) -> tensor<?x?x?x?xf32> { + %cst = arith.constant 0.0 : f32 + %padded = tensor.pad %arg0 low[%l0, 0, 0, %l1, 0, 0] high[%h0, 0, 0, %h1, 0, 0] { + ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index): + tensor.yield %cst : f32 + } : tensor<?x?x?x?x?x?xf32> to tensor<?x?x?x?x?x?xf32> + %collapsed = tensor.collapse_shape %padded [[0], [1, 2], [3], [4, 5]] + : tensor<?x?x?x?x?x?xf32> into tensor<?x?x?x?xf32> + return %collapsed : tensor<?x?x?x?xf32> +} +// CHECK: func @collapse_shape_with_producer_pad_dynamic +// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?x?x?xf32> +// CHECK-SAME: %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5]] +// CHECK: %[[PAD:.+]] = tensor.pad %[[COLLAPSE]] low[%[[L0]], 0, %[[L1]], 0] high[%[[H0]], 0, %[[H1]], 0] +// CHECK: return %[[PAD]] + +// ----- + +func.func @collapse_shape_with_producer_pad_non_constant_padding(%arg0 : tensor<2x3x4xi32>) -> tensor<8x12xi32> { + %cst = arith.constant 0 : i32 + %padded_0 = tensor.pad %arg0 low[1, 0, 0] high[5, 0, 0] { + ^bb0(%arg1: index, %arg2: index, %arg3: index): + %pad_val = arith.index_cast %arg1 : index to i32 + tensor.yield %pad_val : i32 + } : tensor<2x3x4xi32> to tensor<8x3x4xi32> + %collapsed = tensor.collapse_shape %padded_0 [[0], [1, 2]] : tensor<8x3x4xi32> into tensor<8x12xi32> + return %collapsed : tensor<8x12xi32> +} +// CHECK: func @collapse_shape_with_producer_pad_non_constant_padding( +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4xi32>) +// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]] +// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PAD]] +// CHECK: return %[[COLLAPSED]] + +// ----- // Static problem sizes. Checks all aspects of fusion by collapsing with bubbling up collapse shapes. #map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2)> diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir index bc55c12..6f1a422 100644 --- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir +++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir @@ -921,30 +921,6 @@ func.func @fold_fill_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) { // ----- -// CHECK-LABEL: func @fold_fill_generic_different_dtype -// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf16>) -> tensor<?xf16> { -// CHECK-NOT: linalg.fill -// CHECK: %[[GENERIC_OP:.*]] = linalg.generic -// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf16>) -// CHECK-SAME: outs({{.*}} : tensor<?xf16>) { -#map0 = affine_map<(d0) -> (d0)> -func.func @fold_fill_generic_different_dtype(%arg0: tensor<?xf16>) -> (tensor<?xf16>) { - %c0 = arith.constant 0 : index - %cst = arith.constant 7.0 : f32 - %0 = tensor.dim %arg0, %c0 : tensor<?xf16> - %1 = tensor.empty(%0) : tensor<?xf16> - %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf16>) -> tensor<?xf16> - %3 = tensor.empty(%0) : tensor<?xf16> - %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf16>, tensor<?xf16>) outs (%3:tensor<?xf16>) { - ^bb0(%arg1: f16, %arg2: f16, %arg3: f16): - %5 = arith.addf %arg1, %arg2 : f16 - linalg.yield %5 : f16 - } -> tensor<?xf16> - return %4 : tensor<?xf16> -} - -// ----- - // CHECK-LABEL: func @fold_fill_generic_mixedaccess // CHECK-NOT: linalg.fill // CHECK: %[[GENERIC_OP:.*]] = linalg.generic @@ -1079,4 +1055,4 @@ module { // CHECK-NOT: linalg.generic // CHECK: tensor.expand_shape // CHECK: linalg.generic {{.*}}, iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]} -// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
\ No newline at end of file +// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>) diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir index 290c6c7..4526dc9 100644 --- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir +++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir @@ -380,8 +380,8 @@ func.func @generalize_pooling_nwc_sum_i32(%input : tensor<1x16x1xi32>, %shape: t // ----- -func.func @generalize_fill_0d(%value: f64, %O: tensor<f32>) -> tensor<f32> { - %0 = linalg.fill ins(%value: f64) outs(%O : tensor<f32>) -> tensor<f32> +func.func @generalize_fill_0d(%value: f32, %O: tensor<f32>) -> tensor<f32> { + %0 = linalg.fill ins(%value: f32) outs(%O : tensor<f32>) -> tensor<f32> return %0: tensor<f32> } @@ -394,8 +394,8 @@ func.func @generalize_fill_0d(%value: f64, %O: tensor<f32>) -> tensor<f32> { // ----- -func.func @generalize_fill_2d(%value: f64, %O: memref<16x32xf32>) { - linalg.fill ins(%value: f64) outs(%O : memref<16x32xf32>) +func.func @generalize_fill_2d(%value: f32, %O: memref<16x32xf32>) { + linalg.fill ins(%value: f32) outs(%O : memref<16x32xf32>) return } diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index fabc8e6..1f554e6 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -352,6 +352,24 @@ func.func @illegal_fill_tensor_with_memref_return // ----- +func.func @illegal_fill_element_type_truncation(%arg0 : tensor<2xf32>, %arg1 : f64) -> tensor<2xf32> +{ + // expected-error @+1 {{'linalg.fill' op expected fill value type ('f64') to match output element type ('f32')}} + %0 = linalg.fill ins(%arg1 : f64) outs(%arg0 : tensor<2xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// ----- + +func.func @illegal_fill_element_type_extension(%arg0 : tensor<2xi32>, %arg1 : i16) -> tensor<2xi32> +{ + // expected-error @+1 {{'linalg.fill' op expected fill value type ('i16') to match output element type ('i32')}} + %0 = linalg.fill ins(%arg1 : i16) outs(%arg0 : tensor<2xi32>) -> tensor<2xi32> + return %0 : tensor<2xi32> +} + +// ----- + func.func @illegal_fill_value_type(%arg0 : tensor<2x2xf32>, %arg1 : tensor<2xf32>) -> tensor<2x2xf32> { // expected-error @+1 {{expected op with scalar input}} diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir index 67b4f2b..3fb7225 100644 --- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir +++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir @@ -822,6 +822,23 @@ func.func @fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor< // ----- +func.func @no_fuse_by_expanding_pad_non_constant_padding(%arg0 : tensor<2x3x4xi32>) -> tensor<8x12xi32> { + %collapse = tensor.collapse_shape %arg0 [[0], [1, 2]] : tensor<2x3x4xi32> into tensor<2x12xi32> + %padded_0 = tensor.pad %collapse low[1, 0] high[5, 0] { + ^bb0(%arg1: index, %arg2: index): + %pad_val = arith.index_cast %arg1 : index to i32 + tensor.yield %pad_val : i32 + } : tensor<2x12xi32> to tensor<8x12xi32> + return %padded_0 : tensor<8x12xi32> +} +// CHECK: func @no_fuse_by_expanding_pad_non_constant_padding( +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4xi32>) +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] +// CHECK: %[[PAD:.+]] = tensor.pad %[[COLLAPSE]] +// CHECK: return %[[PAD]] + +// ----- + func.func @no_fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x339x14xi32> { %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32> %cst = arith.constant 0 : i32 @@ -863,6 +880,64 @@ func.func @fuse_by_expanding_dynamic_pad(%arg0 : tensor<?x?x?x?x?x?xi32>, %l0: i // ----- +func.func @expand_shape_with_producer_pad(%arg0: tensor<2x12x5x336x9xi32>) -> tensor<8x3x4x17x6x7x8x14xi32> { + %cst = arith.constant 0 : i32 + %padded = tensor.pad %arg0 low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2] { + ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index): + tensor.yield %cst : i32 + } : tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32> + %expanded = tensor.expand_shape %padded [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [8, 3, 4, 17, 6, 7, 8, 14] + : tensor<8x12x17x336x14xi32> into tensor<8x3x4x17x6x7x8x14xi32> + return %expanded : tensor<8x3x4x17x6x7x8x14xi32> +} +// CHECK: func @expand_shape_with_producer_pad +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12x5x336x9xi32> +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] +// CHECK: %[[PAD:.+]] = tensor.pad %[[EXPAND]] low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] +// CHECK: return %[[PAD]] + +// ----- + +func.func @expand_shape_with_producer_pad_dynamic(%arg0: tensor<?x?x?x?xf32>, + %s0: index, %s1: index, %s2: index, %s3: index, %s4: index, %s5: index, + %l0: index, %l1: index, %h0: index, %h1: index) -> tensor<?x?x?x?x?x?xf32> { + %cst = arith.constant 0.0 : f32 + %padded = tensor.pad %arg0 low[%l0, 0, %l1, 0] high[%h0, 0, %h1, 0] { + ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index): + tensor.yield %cst : f32 + } : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32> + %expanded = tensor.expand_shape %padded [[0], [1, 2], [3], [4, 5]] output_shape [%s0, %s1, %s2, %s3, %s4, %s5] + : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32> + return %expanded : tensor<?x?x?x?x?x?xf32> +} +// CHECK: func @expand_shape_with_producer_pad_dynamic +// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32> +// CHECK-SAME: %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index, %[[S4:.+]]: index, %[[S5:.+]]: index, %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index +// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0:.+]] : tensor<?x?x?x?xf32> +// CHECK: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2:.+]] : tensor<?x?x?x?xf32> +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5]] output_shape [%[[DIM0]], %[[S1]], %[[S2]], %[[DIM2]], %[[S4]], %[[S5]]] +// CHECK: %[[PAD:.+]] = tensor.pad %[[EXPAND]] low[%[[L0]], 0, 0, %[[L1]], 0, 0] high[%[[H0]], 0, 0, %[[H1]], 0, 0] +// CHECK: return %[[PAD]] + +// ----- + +func.func @expand_shape_with_producer_pad_non_constant_padding(%arg0 : tensor<2x12xi32>) -> tensor<8x3x4xi32> { + %padded_0 = tensor.pad %arg0 low[1, 0] high[5, 0] { + ^bb0(%arg1: index, %arg2: index): + %pad_val = arith.index_cast %arg1 : index to i32 + tensor.yield %pad_val : i32 + } : tensor<2x12xi32> to tensor<8x12xi32> + %expand = tensor.expand_shape %padded_0 [[0], [1, 2]] output_shape [8, 3, 4] : tensor<8x12xi32> into tensor<8x3x4xi32> + return %expand : tensor<8x3x4xi32> +} +// CHECK: func @expand_shape_with_producer_pad_non_constant_padding( +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12xi32>) +// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]] +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] +// CHECK: return %[[EXPAND]] + +// ----- + func.func @move_operand_deps(%arg0 : tensor<?x128xf16>, %arg1 : tensor<4x?x32x128xf16>, %empty : tensor<4x?x32x128xf16>) -> tensor<4x?x32x8x16xf16> { %c0 = arith.constant 0 : index diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir index 185fb9b..d72ab08 100644 --- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir +++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir @@ -170,7 +170,7 @@ module { // Fuse the consumer operation into the tiled loop. %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice"> - transform.test.fuse_consumer %slice_op in (%forall_op) + transform.test.fuse_consumer_using_slice %slice_op in (%forall_op) : (!transform.op<"tensor.parallel_insert_slice">, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } @@ -231,7 +231,7 @@ module { // Fuse the consumer operation into the tiled loop. %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice"> - // Note that we cannot apply transform.test.fuse_consumer here because the extract_slice + // Note that we cannot apply transform.test.fuse_consumer_using_slice here because the extract_slice // is not qualified consumer operation. Forcing this will yeild "could not fetch consumer // to fuse" error. transform.yield diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir index 9a14ab7..95959fc 100644 --- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir +++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir @@ -1481,23 +1481,23 @@ module attributes {transform.with_named_sequence} { // ----- -// CHECK-LABEL: func @reduce_1d( -// CHECK-SAME: %[[A:.*]]: tensor<32xf32> -func.func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> { +// CHECK-LABEL: func @reduce_to_rank_0( +// CHECK-SAME: %[[SRC:.*]]: tensor<32xf32> +func.func @reduce_to_rank_0(%arg0: tensor<32xf32>) -> tensor<f32> { // CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index %f0 = arith.constant 0.000000e+00 : f32 - // CHECK: %[[init:.*]] = tensor.empty() : tensor<f32> + // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<f32> %0 = tensor.empty() : tensor<f32> %1 = linalg.fill ins(%f0 : f32) outs(%0 : tensor<f32>) -> tensor<f32> - // CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]] + // CHECK: %[[R:.*]] = vector.transfer_read %[[SRC]][%[[C0]]] // CHECK-SAME: : tensor<32xf32>, vector<32xf32> - // CHECK: %[[red:.*]] = vector.multi_reduction <add>, %[[r]], %[[F0]] [0] + // CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[R]], %[[F0]] [0] // CHECK-SAME: : vector<32xf32> to f32 - // CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector<f32> - // CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[init]][] + // CHECK: %[[RED_V1:.*]] = vector.broadcast %[[RED]] : f32 to vector<f32> + // CHECK: %[[RES:.*]] = vector.transfer_write %[[RED_V1]], %[[INIT]][] // CHECK-SAME: : vector<f32>, tensor<f32> %2 = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>, @@ -1525,6 +1525,58 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-LABEL: func @reduce_to_rank_1( +// CHECK-SAME: %[[SRC:.*]]: tensor<32xf32> +func.func @reduce_to_rank_1(%arg0: tensor<32xf32>) -> tensor<1xf32> { + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[F0:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32> + %f0 = arith.constant 0.000000e+00 : f32 + + // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<1xf32> + %0 = tensor.empty() : tensor<1xf32> + + // CHECK: %[[INIT_ZERO:.*]] = vector.transfer_write %[[F0]], %[[INIT]][%[[C0]]] + // CHECK-SAME: : vector<1xf32>, tensor<1xf32> + %1 = linalg.fill ins(%f0 : f32) outs(%0 : tensor<1xf32>) -> tensor<1xf32> + + // CHECK: %[[R:.*]] = vector.transfer_read %[[SRC]][%[[C0]]] + // CHECK-SAME: : tensor<32xf32>, vector<32xf32> + // CHECK: %[[INIT_ZERO_VEC:.*]] = vector.transfer_read %[[INIT_ZERO]][%[[C0]]] + // CHECK-SAME: : tensor<1xf32>, vector<f32> + // CHECK: %[[INIT_ZERO_SCL:.*]] = vector.extract %[[INIT_ZERO_VEC]][] + // CHECK-SAME: : f32 from vector<f32> + // CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[R]], %[[INIT_ZERO_SCL]] [0] + // CHECK-SAME: : vector<32xf32> to f32 + // CHECK: %[[RED_V1:.*]] = vector.broadcast %[[RED]] : f32 to vector<f32> + // CHECK: vector.transfer_write %[[RED_V1]], %[[INIT_ZERO]][%[[C0]]] + // CHECK-SAME: : vector<f32>, tensor<1xf32> + + %2 = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, + affine_map<(d0) -> (0)>], + iterator_types = ["reduction"]} + ins(%arg0 : tensor<32xf32>) + outs(%1 : tensor<1xf32>) { + ^bb0(%a: f32, %b: f32): + %3 = arith.addf %a, %b : f32 + linalg.yield %3 : f32 + } -> tensor<1xf32> + + return %2 : tensor<1xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + + +// ----- + // This test checks that vectorization does not occur when an input indexing map // is not a projected permutation. In the future, this can be converted to a // positive test when support is added. diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index 3130902..e02717a 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -208,19 +208,6 @@ func.func @subview_negative_stride2(%arg0 : memref<7xf32>) -> memref<?xf32, stri // ----- -// CHECK-LABEL: func @dim_of_sized_view -// CHECK-SAME: %{{[a-z0-9A-Z_]+}}: memref<?xi8> -// CHECK-SAME: %[[SIZE:.[a-z0-9A-Z_]+]]: index -// CHECK: return %[[SIZE]] : index -func.func @dim_of_sized_view(%arg : memref<?xi8>, %size: index) -> index { - %c0 = arith.constant 0 : index - %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [%size], strides: [1] : memref<?xi8> to memref<?xi8> - %1 = memref.dim %0, %c0 : memref<?xi8> - return %1 : index -} - -// ----- - // CHECK-LABEL: func @no_fold_subview_negative_size // CHECK: %[[SUBVIEW:.+]] = memref.subview // CHECK: return %[[SUBVIEW]] diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir index 18cdfb7..4ed8d4b 100644 --- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir +++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir @@ -1455,3 +1455,20 @@ func.func @extract_strided_metadata_of_memory_space_cast_no_base(%base: memref<2 // CHECK-LABEL: func @extract_strided_metadata_of_memory_space_cast_no_base // CHECK-NOT: memref.memory_space_cast + +// ----- + +func.func @negative_memref_view_extract_aligned_pointer(%arg0: memref<?xi8>) -> index { + // `extract_aligned_pointer_as_index` must not be folded as `memref.view` can change the base pointer + // CHECK-LABEL: func @negative_memref_view_extract_aligned_pointer + // CHECK-SAME: (%[[ARG0:.*]]: memref<?xi8>) + // CHECK: %[[C10:.*]] = arith.constant 10 : index + // CHECK: %[[VIEW:.*]] = memref.view %[[ARG0]][%[[C10]]][] : memref<?xi8> to memref<f32> + // CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[VIEW]] : memref<f32> -> index + // CHECK: return %[[PTR]] : index + + %c10 = arith.constant 10 : index + %0 = memref.view %arg0[%c10][] : memref<?xi8> to memref<f32> + %1 = memref.extract_aligned_pointer_as_index %0: memref<f32> -> index + return %1 : index +} diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir index 1066526..ca91b01 100644 --- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir +++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir @@ -992,6 +992,55 @@ func.func @fold_vector_maskedstore_expand_shape( // ----- +func.func @fold_vector_transfer_read_expand_shape( + %arg0 : memref<32xf32>, %arg1 : index) -> vector<8xf32> { + %c0 = arith.constant 0 : index + %pad = ub.poison : f32 + %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32> + %1 = vector.transfer_read %0[%arg1, %c0], %pad {in_bounds = [true]} : memref<4x8xf32>, vector<8xf32> + return %1 : vector<8xf32> +} + +// CHECK-LABEL: func @fold_vector_transfer_read_expand_shape +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK: %[[C0:.*]] = arith.constant 0 +// CHECK: %[[PAD:.*]] = ub.poison : f32 +// CHECK: %[[IDX:.*]] = affine.linearize_index [%[[ARG1]], %[[C0]]] by (4, 8) +// CHECK: vector.transfer_read %[[ARG0]][%[[IDX]]], %[[PAD]] {in_bounds = [true]} + +// ----- + +func.func @fold_vector_transfer_read_with_perm_map( + %arg0 : memref<32xf32>, %arg1 : index) -> vector<4x4xf32> { + %c0 = arith.constant 0 : index + %pad = ub.poison : f32 + %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32> + %1 = vector.transfer_read %0[%arg1, %c0], %pad { permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<4x8xf32>, vector<4x4xf32> + return %1 : vector<4x4xf32> +} + +// CHECK-LABEL: func @fold_vector_transfer_read_with_perm_map +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32> +// CHECK: memref.expand_shape %[[ARG0]] {{\[}}[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32> + +// ----- + +func.func @fold_vector_transfer_read_rank_mismatch( + %arg0 : memref<32xf32>, %arg1 : index) -> vector<4x4xf32> { + %c0 = arith.constant 0 : index + %pad = ub.poison : f32 + %0 = memref.expand_shape %arg0 [[0, 1, 2]] output_shape [2, 4, 4] : memref<32xf32> into memref<2x4x4xf32> + %1 = vector.transfer_read %0[%arg1, %c0, %c0], %pad {in_bounds = [true, true]} : memref<2x4x4xf32>, vector<4x4xf32> + return %1 : vector<4x4xf32> +} + +// CHECK-LABEL: func @fold_vector_transfer_read_rank_mismatch +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32> +// CHECK: memref.expand_shape %[[ARG0]] {{\[}}[0, 1, 2]] output_shape [2, 4, 4] : memref<32xf32> into memref<2x4x4xf32> + +// ----- + func.func @fold_vector_load_collapse_shape( %arg0 : memref<4x8xf32>, %arg1 : index) -> vector<8xf32> { %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32> diff --git a/mlir/test/Dialect/MemRef/mem2reg.mlir b/mlir/test/Dialect/MemRef/mem2reg.mlir index d300699..dd68675 100644 --- a/mlir/test/Dialect/MemRef/mem2reg.mlir +++ b/mlir/test/Dialect/MemRef/mem2reg.mlir @@ -18,7 +18,7 @@ func.func @basic() -> i32 { // CHECK-LABEL: func.func @basic_default func.func @basic_default() -> i32 { // CHECK-NOT: = memref.alloca - // CHECK: %[[RES:.*]] = arith.constant 0 : i32 + // CHECK: %[[RES:.*]] = ub.poison : i32 // CHECK-NOT: = memref.alloca %0 = arith.constant 5 : i32 %1 = memref.alloca() : memref<i32> diff --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir index 3b37c62..7fc84d4 100644 --- a/mlir/test/Dialect/MemRef/transform-ops.mlir +++ b/mlir/test/Dialect/MemRef/transform-ops.mlir @@ -306,6 +306,23 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-LABEL: func.func @dead_alloc_escaped +func.func @dead_alloc_escaped() -> memref<8x64xf32, 3> { + // CHECK: %{{.+}} = memref.alloc + %0 = memref.alloc() : memref<8x64xf32, 3> + return %0 : memref<8x64xf32, 3> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> () + transform.yield + } +} + +// ----- + // CHECK-LABEL: func.func @dead_alloc func.func @dead_alloc() { // CHECK-NOT: %{{.+}} = memref.alloc @@ -378,6 +395,73 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-LABEL: @dead_store_through_subview +// CHECK-SAME: (%[[ARG:.+]]: vector<4xf32>) +// CHECK-NOT: memref.alloc() +// CHECK-NOT: vector.transfer_write +func.func @dead_store_through_subview(%arg: vector<4xf32>) { + %c0 = arith.constant 0 : index + %alloc = memref.alloc() {alignment = 64 : i64} : memref<64xf32> + %subview = memref.subview %alloc[%c0] [4] [1] : memref<64xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> + vector.transfer_write %arg, %subview[%c0] {in_bounds = [true]} + : vector<4xf32>, memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> () + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @dead_store_through_expand +// CHECK-SAME: (%[[ARG:.+]]: vector<4xf32>) +// CHECK-NOT: memref.alloc() +// CHECK-NOT: vector.transfer_write +func.func @dead_store_through_expand(%arg: vector<4xf32>) { + %c0 = arith.constant 0 : index + %alloc = memref.alloc() {alignment = 64 : i64} : memref<64xf32> + %expand = memref.expand_shape %alloc [[0, 1]] output_shape [16, 4] : memref<64xf32> into memref<16x4xf32> + vector.transfer_write %arg, %expand[%c0, %c0] {in_bounds = [true]} : vector<4xf32>, memref<16x4xf32> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> () + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @dead_store_through_collapse +// CHECK-SAME: (%[[ARG:.+]]: vector<4xf32>) +// CHECK-NOT: memref.alloc() +// CHECK-NOT: vector.transfer_write +func.func @dead_store_through_collapse(%arg: vector<4xf32>) { + %c0 = arith.constant 0 : index + %alloc = memref.alloc() {alignment = 64 : i64} : memref<16x4xf32> + %collapse = memref.collapse_shape %alloc [[0, 1]] : memref<16x4xf32> into memref<64xf32> + vector.transfer_write %arg, %collapse[%c0] {in_bounds = [true]} : vector<4xf32>, memref<64xf32> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> () + transform.yield + } +} + +// ----- + // CHECK-LABEL: func @lower_to_llvm // CHECK-NOT: memref.alloc // CHECK: llvm.call @malloc diff --git a/mlir/test/Dialect/OpenACC/acc-implicit-data-reduction.mlir b/mlir/test/Dialect/OpenACC/acc-implicit-data-reduction.mlir index cff118b7..fed0a4b 100644 --- a/mlir/test/Dialect/OpenACC/acc-implicit-data-reduction.mlir +++ b/mlir/test/Dialect/OpenACC/acc-implicit-data-reduction.mlir @@ -28,8 +28,8 @@ func.func @test_reduction_implicit_copy() { memref.store %c0_i32, %r[] : memref<i32> acc.parallel { - %red_var = acc.reduction varPtr(%r : memref<i32>) -> memref<i32> {name = "r"} - acc.loop reduction(@reduction_add_memref_i32 -> %red_var : memref<i32>) control(%iv : i32) = (%c1_i32 : i32) to (%c100_i32 : i32) step (%c1_i32 : i32) { + %red_var = acc.reduction varPtr(%r : memref<i32>) recipe(@reduction_add_memref_i32) -> memref<i32> {name = "r"} + acc.loop reduction(%red_var : memref<i32>) control(%iv : i32) = (%c1_i32 : i32) to (%c100_i32 : i32) step (%c1_i32 : i32) { %load = memref.load %red_var[] : memref<i32> %add = arith.addi %load, %c1_i32 : i32 memref.store %add, %red_var[] : memref<i32> @@ -47,7 +47,7 @@ func.func @test_reduction_implicit_copy() { // When enable-implicit-reduction-copy=false: expect firstprivate for reduction variable // FIRSTPRIVATE-LABEL: func.func @test_reduction_implicit_copy -// FIRSTPRIVATE: acc.firstprivate varPtr({{.*}} : memref<i32>) -> memref<i32> {implicit = true, name = ""} +// FIRSTPRIVATE: acc.firstprivate varPtr({{.*}} : memref<i32>) recipe({{.*}}) -> memref<i32> {implicit = true, name = ""} // FIRSTPRIVATE-NOT: acc.copyin // FIRSTPRIVATE-NOT: acc.copyout @@ -81,8 +81,8 @@ func.func @test_reduction_with_usage_outside_loop() { %out_create = acc.create varPtr(%out : memref<i32>) -> memref<i32> {dataClause = #acc<data_clause acc_copyout>, name = "out"} acc.parallel dataOperands(%out_create : memref<i32>) { - %red_var = acc.reduction varPtr(%r : memref<i32>) -> memref<i32> {name = "r"} - acc.loop reduction(@reduction_add_memref_i32_2 -> %red_var : memref<i32>) control(%iv : i32) = (%c1_i32 : i32) to (%c100_i32 : i32) step (%c1_i32 : i32) { + %red_var = acc.reduction varPtr(%r : memref<i32>) recipe(@reduction_add_memref_i32_2) -> memref<i32> {name = "r"} + acc.loop reduction(%red_var : memref<i32>) control(%iv : i32) = (%c1_i32 : i32) to (%c100_i32 : i32) step (%c1_i32 : i32) { %load = memref.load %red_var[] : memref<i32> %add = arith.addi %load, %c1_i32 : i32 memref.store %add, %red_var[] : memref<i32> @@ -100,10 +100,10 @@ func.func @test_reduction_with_usage_outside_loop() { // In this case, r should be firstprivate regardless of the flag setting // because it's used outside the reduction context // COPY-LABEL: func.func @test_reduction_with_usage_outside_loop -// COPY: acc.firstprivate varPtr({{.*}} : memref<i32>) -> memref<i32> {implicit = true, name = ""} +// COPY: acc.firstprivate varPtr({{.*}} : memref<i32>) recipe({{.*}}) -> memref<i32> {implicit = true, name = ""} // COPY-NOT: acc.copyin varPtr({{.*}} : memref<i32>) -> memref<i32> {{.*}} name = "" // FIRSTPRIVATE-LABEL: func.func @test_reduction_with_usage_outside_loop -// FIRSTPRIVATE: acc.firstprivate varPtr({{.*}} : memref<i32>) -> memref<i32> {implicit = true, name = ""} +// FIRSTPRIVATE: acc.firstprivate varPtr({{.*}} : memref<i32>) recipe({{.*}}) -> memref<i32> {implicit = true, name = ""} // FIRSTPRIVATE-NOT: acc.copyin varPtr({{.*}} : memref<i32>) -> memref<i32> {{.*}} name = "" diff --git a/mlir/test/Dialect/OpenACC/acc-implicit-data.mlir b/mlir/test/Dialect/OpenACC/acc-implicit-data.mlir index cf09c33..6909fe6 100644 --- a/mlir/test/Dialect/OpenACC/acc-implicit-data.mlir +++ b/mlir/test/Dialect/OpenACC/acc-implicit-data.mlir @@ -13,7 +13,7 @@ func.func @test_scalar_in_serial() { } // CHECK-LABEL: func.func @test_scalar_in_serial -// CHECK: acc.firstprivate varPtr({{.*}} : memref<i64>) -> memref<i64> {implicit = true, name = ""} +// CHECK: acc.firstprivate varPtr({{.*}} : memref<i64>) recipe({{.*}}) -> memref<i64> {implicit = true, name = ""} // ----- @@ -28,7 +28,7 @@ func.func @test_scalar_in_parallel() { } // CHECK-LABEL: func.func @test_scalar_in_parallel -// CHECK: acc.firstprivate varPtr({{.*}} : memref<f32>) -> memref<f32> {implicit = true, name = ""} +// CHECK: acc.firstprivate varPtr({{.*}} : memref<f32>) recipe({{.*}}) -> memref<f32> {implicit = true, name = ""} // ----- @@ -110,7 +110,7 @@ func.func @test_array_parallel_defaultpresent() { } // CHECK-LABEL: func.func @test_array_parallel_defaultpresent -// CHECK: %[[PRESENT:.*]] = acc.present varPtr({{.*}} : memref<10xf32>) -> memref<10xf32> {implicit = true, name = ""} +// CHECK: %[[PRESENT:.*]] = acc.present varPtr({{.*}} : memref<10xf32>) -> memref<10xf32> {acc.from_default, implicit = true, name = ""} // CHECK: acc.delete accPtr(%[[PRESENT]] : memref<10xf32>) {dataClause = #acc<data_clause acc_present>, implicit = true, name = ""} // ----- @@ -126,7 +126,7 @@ func.func @test_scalar_parallel_defaultpresent() { } // CHECK-LABEL: func.func @test_scalar_parallel_defaultpresent -// CHECK: acc.firstprivate varPtr({{.*}} : memref<f32>) -> memref<f32> {implicit = true, name = ""} +// CHECK: acc.firstprivate varPtr({{.*}} : memref<f32>) recipe({{.*}}) -> memref<f32> {implicit = true, name = ""} // ----- @@ -197,7 +197,7 @@ func.func @test_multiple_variables() { } // CHECK-LABEL: func.func @test_multiple_variables -// CHECK: acc.firstprivate varPtr({{.*}} : memref<f32>) -> memref<f32> {implicit = true, name = ""} +// CHECK: acc.firstprivate varPtr({{.*}} : memref<f32>) recipe({{.*}}) -> memref<f32> {implicit = true, name = ""} // CHECK: %[[COPYIN:.*]] = acc.copyin varPtr({{.*}} : memref<10xi32>) -> memref<10xi32> {dataClause = #acc<data_clause acc_copy>, implicit = true, name = ""} // CHECK: acc.copyout accPtr(%[[COPYIN]] : memref<10xi32>) to varPtr({{.*}} : memref<10xi32>) {dataClause = #acc<data_clause acc_copy>, implicit = true, name = ""} diff --git a/mlir/test/Dialect/OpenACC/acc-implicit-declare.mlir b/mlir/test/Dialect/OpenACC/acc-implicit-declare.mlir new file mode 100644 index 0000000..74ff338 --- /dev/null +++ b/mlir/test/Dialect/OpenACC/acc-implicit-declare.mlir @@ -0,0 +1,175 @@ +// RUN: mlir-opt %s --pass-pipeline="builtin.module(acc-implicit-declare)" -split-input-file 2>&1 | FileCheck %s + +// ----- + +// Test that non-constant scalar globals in compute regions are hoisted +// instead of being marked with acc declare + +memref.global @gscalar : memref<f32> = dense<0.0> + +func.func @test_scalar_in_serial() { + acc.serial { + %addr = memref.get_global @gscalar : memref<f32> + %load = memref.load %addr[] : memref<f32> + acc.yield + } + return +} + +// Expected to hoist this global access out of acc region instead of marking +// with `acc declare`. +// CHECK-LABEL: func.func @test_scalar_in_serial +// CHECK: memref.get_global @gscalar +// CHECK: acc.serial +// CHECK-NOT: acc.declare + +// ----- + +// Test that constant globals are marked with acc declare + +memref.global constant @gscalarconst : memref<f32> = dense<1.0> + +func.func @test_constant_in_serial() { + acc.serial { + %addr = memref.get_global @gscalarconst : memref<f32> + %load = memref.load %addr[] : memref<f32> + acc.yield + } + return +} + +// This is expected to be `acc declare`'d since it is a constant. +// CHECK: memref.global constant @gscalarconst {{.*}} {acc.declare = #acc.declare<dataClause = acc_copyin>} + +// ----- + +// Test globals referenced in acc routine functions + +memref.global @gscalar_routine : memref<f32> = dense<0.0> + +acc.routine @acc_routine_0 func(@test_scalar_in_accroutine) +func.func @test_scalar_in_accroutine() attributes {acc.routine_info = #acc.routine_info<[@acc_routine_0]>} { + %addr = memref.get_global @gscalar_routine : memref<f32> + %load = memref.load %addr[] : memref<f32> + return +} + +// Global should be acc declare'd because it's in an acc routine +// CHECK: memref.global @gscalar_routine {{.*}} {acc.declare = #acc.declare<dataClause = acc_copyin>} + +// ----- + +// Test constant globals in acc routine + +memref.global constant @gscalarconst_routine : memref<f32> = dense<1.0> + +acc.routine @acc_routine_0 func(@test_constant_in_accroutine) +func.func @test_constant_in_accroutine() attributes {acc.routine_info = #acc.routine_info<[@acc_routine_0]>} { + %addr = memref.get_global @gscalarconst_routine : memref<f32> + %load = memref.load %addr[] : memref<f32> + return +} + +// CHECK: memref.global constant @gscalarconst_routine {{.*}} {acc.declare = #acc.declare<dataClause = acc_copyin>} + +// ----- + +// Test acc.private.recipe with global reference - referenced variant + +memref.global @global_for_private : memref<f32> = dense<0.0> + +acc.private.recipe @private_recipe_with_global : memref<f32> init { +^bb0(%arg0: memref<f32>): + %0 = memref.alloc() : memref<f32> + %global_addr = memref.get_global @global_for_private : memref<f32> + %global_val = memref.load %global_addr[] : memref<f32> + memref.store %global_val, %0[] : memref<f32> + acc.yield %0 : memref<f32> +} destroy { +^bb0(%arg0: memref<f32>): + memref.dealloc %arg0 : memref<f32> + acc.terminator +} + +func.func @test_private_recipe_referenced() { + %var = memref.alloc() : memref<f32> + %priv = acc.private varPtr(%var : memref<f32>) recipe(@private_recipe_with_global) -> memref<f32> + acc.parallel private(%priv : memref<f32>) { + %load = memref.load %var[] : memref<f32> + acc.yield + } + memref.dealloc %var : memref<f32> + return +} + +// Global should be acc declare'd because the recipe is referenced +// CHECK: memref.global @global_for_private {{.*}} {acc.declare = #acc.declare<dataClause = acc_copyin>} + +// ----- + +// Test acc.private.recipe with global reference - unreferenced variant + +memref.global @global_for_private_unused : memref<f32> = dense<0.0> + +acc.private.recipe @private_recipe_unused : memref<f32> init { +^bb0(%arg0: memref<f32>): + %0 = memref.alloc() : memref<f32> + %global_addr = memref.get_global @global_for_private_unused : memref<f32> + %global_val = memref.load %global_addr[] : memref<f32> + memref.store %global_val, %0[] : memref<f32> + acc.yield %0 : memref<f32> +} destroy { +^bb0(%arg0: memref<f32>): + memref.dealloc %arg0 : memref<f32> + acc.terminator +} + +func.func @test_private_recipe_not_referenced() { + %var = memref.alloc() : memref<f32> + acc.parallel { + %load = memref.load %var[] : memref<f32> + acc.yield + } + memref.dealloc %var : memref<f32> + return +} + +// Global should NOT be acc declare'd because the recipe is not referenced +// CHECK-NOT: memref.global @global_for_private_unused {{.*}} {acc.declare + +// ----- + +// Test globals in different compute constructs (parallel, kernels, serial) + +memref.global @global_parallel : memref<f32> = dense<0.0> +memref.global @global_kernels : memref<f32> = dense<0.0> +memref.global constant @global_serial_const : memref<f32> = dense<1.0> + +func.func @test_multiple_constructs() { + acc.parallel { + %addr = memref.get_global @global_parallel : memref<f32> + %load = memref.load %addr[] : memref<f32> + acc.yield + } + acc.kernels { + %addr = memref.get_global @global_kernels : memref<f32> + %load = memref.load %addr[] : memref<f32> + acc.terminator + } + acc.serial { + %addr = memref.get_global @global_serial_const : memref<f32> + %load = memref.load %addr[] : memref<f32> + acc.yield + } + return +} + +// Non-constant globals ARE hoisted before their compute regions +// Constant global should be marked with acc.declare +// CHECK: memref.global constant @global_serial_const {{.*}} {acc.declare = #acc.declare<dataClause = acc_copyin>} +// CHECK-LABEL: func.func @test_multiple_constructs +// CHECK: memref.get_global @global_parallel +// CHECK-NEXT: acc.parallel +// CHECK: memref.get_global @global_kernels +// CHECK-NEXT: acc.kernels + diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir index 0e75894..d1a1c93 100644 --- a/mlir/test/Dialect/OpenACC/invalid.mlir +++ b/mlir/test/Dialect/OpenACC/invalid.mlir @@ -76,27 +76,65 @@ acc.loop { // ----- -// expected-error@+1 {{'acc.loop' op duplicate device_type found in gang attribute}} +// expected-error@+1 {{'acc.loop' op duplicate device_type `none` found in gang attribute}} acc.loop { acc.yield } attributes {gang = [#acc.device_type<none>, #acc.device_type<none>]} // ----- -// expected-error@+1 {{'acc.loop' op duplicate device_type found in worker attribute}} +// expected-error@+1 {{'acc.loop' op duplicate device_type `none` found in worker attribute}} acc.loop { acc.yield } attributes {worker = [#acc.device_type<none>, #acc.device_type<none>]} // ----- -// expected-error@+1 {{'acc.loop' op duplicate device_type found in vector attribute}} +// expected-error@+1 {{'acc.loop' op duplicate device_type `none` found in vector attribute}} acc.loop { acc.yield } attributes {vector = [#acc.device_type<none>, #acc.device_type<none>]} // ----- +// expected-error@+1 {{'acc.loop' op duplicate device_type `nvidia` found in gang attribute}} +acc.loop { + acc.yield +} attributes {gang = [#acc.device_type<nvidia>, #acc.device_type<nvidia>]} + +// ----- + +// expected-error@+1 {{'acc.loop' op duplicate device_type `none` found in collapseDeviceType attribute}} +acc.loop { + acc.yield +} attributes {collapse = [1, 1], collapseDeviceType = [#acc.device_type<none>, #acc.device_type<none>], independent = [#acc.device_type<none>]} + +// ----- + +%i64value = arith.constant 1 : i64 +// expected-error@+1 {{'acc.loop' op duplicate device_type `none` found in workerNumOperandsDeviceType attribute}} +acc.loop worker(%i64value: i64, %i64value: i64) { + acc.yield +} attributes {workerNumOperandsDeviceType = [#acc.device_type<none>, #acc.device_type<none>], independent = [#acc.device_type<none>]} + +// ----- + +%i64value = arith.constant 1 : i64 +// expected-error@+1 {{'acc.loop' op duplicate device_type `none` found in vectorOperandsDeviceType attribute}} +acc.loop vector(%i64value: i64, %i64value: i64) { + acc.yield +} attributes {vectorOperandsDeviceType = [#acc.device_type<none>, #acc.device_type<none>], independent = [#acc.device_type<none>]} + +// ----- + +func.func @acc_routine_parallelism() -> () { + return +} +// expected-error@+1 {{only one of `gang`, `worker`, `vector`, `seq` can be present at the same time for device_type `nvidia`}} +"acc.routine"() <{func_name = @acc_routine_parallelism, sym_name = "acc_routine_parallelism_rout", gang = [#acc.device_type<nvidia>], worker = [#acc.device_type<nvidia>]}> : () -> () + +// ----- + %1 = arith.constant 1 : i32 %2 = arith.constant 10 : i32 // expected-error@+1 {{only one of auto, independent, seq can be present at the same time}} @@ -483,15 +521,6 @@ acc.loop gang({static=%i64Value: i64, ) control(%iv : i32) = (%1 : i32) to (%2 : // ----- -func.func @fct1(%0 : !llvm.ptr) -> () { - // expected-error@+1 {{expected symbol reference @privatization_i32 to point to a private declaration}} - acc.serial private(@privatization_i32 -> %0 : !llvm.ptr) { - } - return -} - -// ----- - %i1 = arith.constant 1 : i32 %i2 = arith.constant 10 : i32 // expected-error@+1 {{unstructured acc.loop must not have induction variables}} @@ -843,6 +872,76 @@ func.func @acc_loop_container() { // ----- +func.func @fct1(%0 : !llvm.ptr) -> () { + // expected-error @below {{expected symbol reference @privatization_i32 to point to a private declaration}} + %priv = acc.private varPtr(%0 : !llvm.ptr) varType(i32) recipe(@privatization_i32) -> !llvm.ptr + return +} + +// ----- + +acc.private.recipe @privatization_i32 : !llvm.ptr init { +^bb0(%arg0: !llvm.ptr): + %c1 = arith.constant 1 : i32 + %c0 = arith.constant 0 : i32 + %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr + llvm.store %c0, %0 : i32, !llvm.ptr + acc.yield %0 : !llvm.ptr +} + +func.func @fct1(%0 : !llvm.ptr) -> () { + %priv = acc.private varPtr(%0 : !llvm.ptr) varType(i32) recipe(@privatization_i32) -> !llvm.ptr + // expected-error @below {{expected firstprivate as defining op}} + acc.serial firstprivate(%priv : !llvm.ptr) { + } + return +} + +// ----- + +acc.private.recipe @privatization_i32 : !llvm.ptr init { +^bb0(%arg0: !llvm.ptr): + %c1 = arith.constant 1 : i32 + %c0 = arith.constant 0 : i32 + %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr + llvm.store %c0, %0 : i32, !llvm.ptr + acc.yield %0 : !llvm.ptr +} + +func.func @fct1(%0 : !llvm.ptr) -> () { + %priv = acc.private varPtr(%0 : !llvm.ptr) varType(i32) recipe(@privatization_i32) -> !llvm.ptr + // expected-error @below {{op private operand appears more than once}} + acc.serial private(%priv, %priv : !llvm.ptr, !llvm.ptr) { + } + return +} + +// ----- + +func.func @fct1(%0 : !llvm.ptr) -> () { + // expected-error @below {{op recipe expected for private}} + %priv = acc.private varPtr(%0 : !llvm.ptr) varType(i32) -> !llvm.ptr + return +} + +// ----- + +func.func @fct1(%0 : !llvm.ptr) -> () { + // expected-error @below {{op recipe expected for firstprivate}} + %priv = acc.firstprivate varPtr(%0 : !llvm.ptr) varType(i32) -> !llvm.ptr + return +} + +// ----- + +func.func @fct1(%0 : !llvm.ptr) -> () { + // expected-error @below {{op recipe expected for reduction}} + %priv = acc.reduction varPtr(%0 : !llvm.ptr) varType(i32) -> !llvm.ptr + return +} + +// ----- + func.func @verify_declare_enter(%arg0 : memref<i32>) { // expected-error @below {{expect valid declare data entry operation or acc.getdeviceptr as defining op}} %0 = acc.declare_enter dataOperands(%arg0 : memref<i32>) diff --git a/mlir/test/Dialect/OpenACC/legalize-data.mlir b/mlir/test/Dialect/OpenACC/legalize-data.mlir index 40604dc..c7ef47c 100644 --- a/mlir/test/Dialect/OpenACC/legalize-data.mlir +++ b/mlir/test/Dialect/OpenACC/legalize-data.mlir @@ -129,8 +129,8 @@ func.func @test(%a: memref<10xf32>) { %lb = arith.constant 0 : index %st = arith.constant 1 : index %c10 = arith.constant 10 : index - %p1 = acc.private varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> - acc.parallel private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) { + %p1 = acc.private varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32> + acc.parallel private(%p1 : memref<10xf32>) { acc.loop control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) { %ci = memref.load %a[%i] : memref<10xf32> acc.yield @@ -142,8 +142,8 @@ func.func @test(%a: memref<10xf32>) { // CHECK-LABEL: func.func @test // CHECK-SAME: (%[[A:.*]]: memref<10xf32>) -// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> -// CHECK: acc.parallel private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) { +// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32> +// CHECK: acc.parallel private(%[[PRIVATE]] : memref<10xf32>) { // CHECK: acc.loop control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) { // DEVICE: %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32> // CHECK: acc.yield @@ -167,9 +167,9 @@ func.func @test(%a: memref<10xf32>) { %lb = arith.constant 0 : index %st = arith.constant 1 : index %c10 = arith.constant 10 : index - %p1 = acc.private varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> + %p1 = acc.private varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32> acc.parallel { - acc.loop private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) { + acc.loop private(%p1 : memref<10xf32>) control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) { %ci = memref.load %a[%i] : memref<10xf32> acc.yield } attributes {independent = [#acc.device_type<none>]} @@ -180,9 +180,9 @@ func.func @test(%a: memref<10xf32>) { // CHECK-LABEL: func.func @test // CHECK-SAME: (%[[A:.*]]: memref<10xf32>) -// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> +// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32> // CHECK: acc.parallel { -// CHECK: acc.loop private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) { +// CHECK: acc.loop private(%[[PRIVATE]] : memref<10xf32>) control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) { // DEVICE: %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32> // CHECK: acc.yield // CHECK: } attributes {independent = [#acc.device_type<none>]} @@ -205,8 +205,8 @@ func.func @test(%a: memref<10xf32>) { %lb = arith.constant 0 : index %st = arith.constant 1 : index %c10 = arith.constant 10 : index - %p1 = acc.private varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> - acc.serial private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) { + %p1 = acc.private varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32> + acc.serial private(%p1 : memref<10xf32>) { acc.loop control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) { %ci = memref.load %a[%i] : memref<10xf32> acc.yield @@ -218,8 +218,8 @@ func.func @test(%a: memref<10xf32>) { // CHECK-LABEL: func.func @test // CHECK-SAME: (%[[A:.*]]: memref<10xf32>) -// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> -// CHECK: acc.serial private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) { +// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32> +// CHECK: acc.serial private(%[[PRIVATE]] : memref<10xf32>) { // CHECK: acc.loop control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) { // DEVICE: %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32> // CHECK: acc.yield diff --git a/mlir/test/Dialect/OpenACC/legalize-serial.mlir b/mlir/test/Dialect/OpenACC/legalize-serial.mlir new file mode 100644 index 0000000..774c6b6 --- /dev/null +++ b/mlir/test/Dialect/OpenACC/legalize-serial.mlir @@ -0,0 +1,164 @@ +// RUN: mlir-opt %s -acc-legalize-serial | FileCheck %s + +acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init { +^bb0(%arg0: memref<10xf32>): + %0 = memref.alloc() : memref<10xf32> + acc.yield %0 : memref<10xf32> +} destroy { +^bb0(%arg0: memref<10xf32>): + memref.dealloc %arg0 : memref<10xf32> + acc.terminator +} + +acc.private.recipe @privatization_memref_10_10_f32 : memref<10x10xf32> init { +^bb0(%arg0: memref<10x10xf32>): + %0 = memref.alloc() : memref<10x10xf32> + acc.yield %0 : memref<10x10xf32> +} destroy { +^bb0(%arg0: memref<10x10xf32>): + memref.dealloc %arg0 : memref<10x10xf32> + acc.terminator +} + +acc.firstprivate.recipe @firstprivatization_memref_10xf32 : memref<10xf32> init { +^bb0(%arg0: memref<10xf32>): + %0 = memref.alloc() : memref<10xf32> + acc.yield %0 : memref<10xf32> +} copy { +^bb0(%arg0: memref<10xf32>, %arg1: memref<10xf32>): + acc.terminator +} destroy { +^bb0(%arg0: memref<10xf32>): + memref.dealloc %arg0 : memref<10xf32> + acc.terminator +} + +acc.reduction.recipe @reduction_add_i64 : i64 reduction_operator<add> init { +^bb0(%0: i64): + %1 = arith.constant 0 : i64 + acc.yield %1 : i64 +} combiner { +^bb0(%0: i64, %1: i64): + %2 = arith.addi %0, %1 : i64 + acc.yield %2 : i64 +} + +acc.reduction.recipe @reduction_add_memref_i64 : memref<i64> reduction_operator<add> init { +^bb0(%arg0: memref<i64>): + %0 = memref.alloca() : memref<i64> + %c0 = arith.constant 0 : i64 + memref.store %c0, %0[] : memref<i64> + acc.yield %0 : memref<i64> +} combiner { +^bb0(%arg0: memref<i64>, %arg1: memref<i64>): + %0 = memref.load %arg0[] : memref<i64> + %1 = memref.load %arg1[] : memref<i64> + %2 = arith.addi %0, %1 : i64 + memref.store %2, %arg0[] : memref<i64> + acc.terminator +} + +// CHECK: func.func @testserialop(%[[VAL_0:.*]]: memref<10xf32>, %[[VAL_1:.*]]: memref<10xf32>, %[[VAL_2:.*]]: memref<10x10xf32>) { +// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i64 +// CHECK: %[[VAL_4:.*]] = arith.constant 1 : i32 +// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index +// CHECK: acc.parallel async(%[[VAL_3]] : i64) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } +// CHECK: acc.parallel async(%[[VAL_4]] : i32) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } +// CHECK: acc.parallel async(%[[VAL_5]] : index) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_3]] : i64}) { +// CHECK: } +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_4]] : i32}) { +// CHECK: } +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_5]] : index}) { +// CHECK: } +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_3]] : i64, %[[VAL_4]] : i32, %[[VAL_5]] : index}) { +// CHECK: } +// CHECK: %[[VAL_6:.*]] = acc.firstprivate varPtr(%[[VAL_1]] : memref<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32> +// CHECK: %[[VAL_9:.*]] = acc.private varPtr(%[[VAL_2]] : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32> +// CHECK: acc.parallel firstprivate(%[[VAL_6]] : memref<10xf32>) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) private(%[[VAL_9]] : memref<10x10xf32>) vector_length(%[[VAL_4]] : i32) { +// CHECK: } +// CHECK: %[[VAL_7:.*]] = acc.copyin varPtr(%[[VAL_0]] : memref<10xf32>) -> memref<10xf32> {dataClause = #acc<data_clause acc_copy>} +// CHECK: acc.parallel dataOperands(%[[VAL_7]] : memref<10xf32>) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } +// CHECK: %[[I64MEM:.*]] = memref.alloca() : memref<i64> +// CHECK: memref.store %[[VAL_3]], %[[I64MEM]][] : memref<i64> +// CHECK: %[[VAL_10:.*]] = acc.reduction varPtr(%[[I64MEM]] : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64> +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) reduction(%[[VAL_10]] : memref<i64>) { +// CHECK: } +// CHECK: acc.parallel combined(loop) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: acc.loop combined(serial) control(%{{.*}} : index) = (%[[VAL_5]] : index) to (%[[VAL_5]] : index) step (%[[VAL_5]] : index) { +// CHECK: acc.yield +// CHECK: } attributes {seq = [#acc.device_type<none>]} +// CHECK: acc.terminator +// CHECK: } +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } attributes {defaultAttr = #acc<defaultvalue none>} +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } attributes {defaultAttr = #acc<defaultvalue present>} +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } attributes {selfAttr} +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: acc.yield +// CHECK: } attributes {selfAttr} +// CHECK: return +// CHECK: } + +func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () { + %i64value = arith.constant 1 : i64 + %i32value = arith.constant 1 : i32 + %idxValue = arith.constant 1 : index + acc.serial async(%i64value: i64) { + } + acc.serial async(%i32value: i32) { + } + acc.serial async(%idxValue: index) { + } + acc.serial wait({%i64value: i64}) { + } + acc.serial wait({%i32value: i32}) { + } + acc.serial wait({%idxValue: index}) { + } + acc.serial wait({%i64value : i64, %i32value : i32, %idxValue : index}) { + } + %firstprivate = acc.firstprivate varPtr(%b : memref<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32> + %c_private = acc.private varPtr(%c : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32> + acc.serial private(%c_private : memref<10x10xf32>) firstprivate(%firstprivate : memref<10xf32>) { + } + %copyinfromcopy = acc.copyin varPtr(%a : memref<10xf32>) -> memref<10xf32> {dataClause = #acc<data_clause acc_copy>} + acc.serial dataOperands(%copyinfromcopy : memref<10xf32>) { + } + %i64mem = memref.alloca() : memref<i64> + memref.store %i64value, %i64mem[] : memref<i64> + %i64reduction = acc.reduction varPtr(%i64mem : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64> + acc.serial reduction(%i64reduction : memref<i64>) { + } + acc.serial combined(loop) { + acc.loop combined(serial) control(%arg3 : index) = (%idxValue : index) to (%idxValue : index) step (%idxValue : index) { + acc.yield + } attributes {seq = [#acc.device_type<none>]} + acc.terminator + } + acc.serial { + } attributes {defaultAttr = #acc<defaultvalue none>} + acc.serial { + } attributes {defaultAttr = #acc<defaultvalue present>} + acc.serial { + } attributes {asyncAttr} + acc.serial { + } attributes {waitAttr} + acc.serial { + } attributes {selfAttr} + acc.serial { + acc.yield + } attributes {selfAttr} + return +} + diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir index fc11bae..d31397c 100644 --- a/mlir/test/Dialect/OpenACC/ops.mlir +++ b/mlir/test/Dialect/OpenACC/ops.mlir @@ -120,8 +120,8 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x %pc = acc.present varPtr(%c : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> %pd = acc.present varPtr(%d : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> acc.data dataOperands(%pa, %pb, %pc, %pd: memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) { - %private = acc.private varPtr(%c : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> - acc.parallel num_gangs({%numGangs: i64}) num_workers(%numWorkers: i64 [#acc.device_type<nvidia>]) private(@privatization_memref_10_f32 -> %private : memref<10xf32>) { + %private = acc.private varPtr(%c : memref<10xf32>) varType(tensor<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32> + acc.parallel num_gangs({%numGangs: i64}) num_workers(%numWorkers: i64 [#acc.device_type<nvidia>]) private(%private : memref<10xf32>) { acc.loop gang control(%x : index) = (%lb : index) to (%c10 : index) step (%st : index) { acc.loop worker control(%y : index) = (%lb : index) to (%c10 : index) step (%st : index) { %axy = memref.load %a[%x, %y] : memref<10x10xf32> @@ -157,8 +157,8 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x // CHECK-NEXT: [[NUMGANG:%.*]] = arith.constant 10 : i64 // CHECK-NEXT: [[NUMWORKERS:%.*]] = arith.constant 10 : i64 // CHECK: acc.data dataOperands(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) { -// CHECK-NEXT: %[[P_ARG2:.*]] = acc.private varPtr([[ARG2]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> -// CHECK-NEXT: acc.parallel num_gangs({[[NUMGANG]] : i64}) num_workers([[NUMWORKERS]] : i64 [#acc.device_type<nvidia>]) private(@privatization_memref_10_f32 -> %[[P_ARG2]] : memref<10xf32>) { +// CHECK-NEXT: %[[P_ARG2:.*]] = acc.private varPtr([[ARG2]] : memref<10xf32>) varType(tensor<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32> +// CHECK-NEXT: acc.parallel num_gangs({[[NUMGANG]] : i64}) num_workers([[NUMWORKERS]] : i64 [#acc.device_type<nvidia>]) private(%[[P_ARG2]] : memref<10xf32>) { // CHECK-NEXT: acc.loop gang control(%{{.*}}) = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) { // CHECK-NEXT: acc.loop worker control(%{{.*}}) = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) { // CHECK-NEXT: %{{.*}} = memref.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32> @@ -375,8 +375,8 @@ func.func @testloopfirstprivate(%a: memref<10xf32>, %b: memref<10xf32>) -> () { %c0 = arith.constant 0 : index %c10 = arith.constant 10 : index %c1 = arith.constant 1 : index - %firstprivate = acc.firstprivate varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> - acc.loop firstprivate(@firstprivatization_memref_10xf32 -> %firstprivate : memref<10xf32>) control(%iv : index) = (%c0 : index) to (%c10 : index) step (%c1 : index) { + %firstprivate = acc.firstprivate varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32> + acc.loop firstprivate(%firstprivate : memref<10xf32>) control(%iv : index) = (%c0 : index) to (%c10 : index) step (%c1 : index) { "test.openacc_dummy_op"() : () -> () acc.yield } attributes {inclusiveUpperbound = array<i1: true>, independent = [#acc.device_type<none>]} @@ -385,8 +385,8 @@ func.func @testloopfirstprivate(%a: memref<10xf32>, %b: memref<10xf32>) -> () { // CHECK-LABEL: func.func @testloopfirstprivate( // CHECK-SAME: %[[ARG0:.*]]: memref<10xf32>, %[[ARG1:.*]]: memref<10xf32>) -// CHECK: %[[FIRSTPRIVATE:.*]] = acc.firstprivate varPtr(%[[ARG0]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> -// CHECK: acc.loop firstprivate(@firstprivatization_memref_10xf32 -> %[[FIRSTPRIVATE]] : memref<10xf32>) control(%{{.*}}) = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) { +// CHECK: %[[FIRSTPRIVATE:.*]] = acc.firstprivate varPtr(%[[ARG0]] : memref<10xf32>) varType(tensor<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32> +// CHECK: acc.loop firstprivate(%[[FIRSTPRIVATE]] : memref<10xf32>) control(%{{.*}}) = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) { // CHECK: "test.openacc_dummy_op"() : () -> () // CHECK: acc.yield // CHECK: } attributes {inclusiveUpperbound = array<i1: true>, independent = [#acc.device_type<none>]} @@ -464,7 +464,10 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x } acc.parallel vector_length(%idxValue: index) { } - acc.parallel private(@privatization_memref_10_f32 -> %a : memref<10xf32>, @privatization_memref_10_10_f32 -> %c : memref<10x10xf32>) firstprivate(@privatization_memref_10xf32 -> %b: memref<10xf32>) { + %private_a = acc.private varPtr(%a : memref<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32> + %private_c = acc.private varPtr(%c : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32> + %firstprivate_b = acc.firstprivate varPtr(%b : memref<10xf32>) recipe(@privatization_memref_10xf32) -> memref<10xf32> + acc.parallel private(%private_a, %private_c : memref<10xf32>, memref<10x10xf32>) firstprivate(%firstprivate_b : memref<10xf32>) { } acc.parallel { } attributes {defaultAttr = #acc<defaultvalue none>} @@ -517,7 +520,10 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x // CHECK-NEXT: } // CHECK: acc.parallel vector_length([[IDXVALUE]] : index) { // CHECK-NEXT: } -// CHECK: acc.parallel firstprivate(@privatization_memref_10xf32 -> [[ARGB]] : memref<10xf32>) private(@privatization_memref_10_f32 -> [[ARGA]] : memref<10xf32>, @privatization_memref_10_10_f32 -> [[ARGC]] : memref<10x10xf32>) { +// CHECK: %[[PRIVATE_A:.*]] = acc.private varPtr([[ARGA]] : memref<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32> +// CHECK-NEXT: %[[PRIVATE_C:.*]] = acc.private varPtr([[ARGC]] : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32> +// CHECK-NEXT: %[[FIRSTPRIVATE_B:.*]] = acc.firstprivate varPtr([[ARGB]] : memref<10xf32>) recipe(@privatization_memref_10xf32) -> memref<10xf32> +// CHECK-NEXT: acc.parallel firstprivate(%[[FIRSTPRIVATE_B]] : memref<10xf32>) private(%[[PRIVATE_A]], %[[PRIVATE_C]] : memref<10xf32>, memref<10x10xf32>) { // CHECK-NEXT: } // CHECK: acc.parallel { // CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>} @@ -596,8 +602,10 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10 } acc.serial wait({%i64value : i64, %i32value : i32, %idxValue : index}) { } - %firstprivate = acc.firstprivate varPtr(%b : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> - acc.serial private(@privatization_memref_10_f32 -> %a : memref<10xf32>, @privatization_memref_10_10_f32 -> %c : memref<10x10xf32>) firstprivate(@firstprivatization_memref_10xf32 -> %firstprivate : memref<10xf32>) { + %private_a = acc.private varPtr(%a : memref<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32> + %private_c = acc.private varPtr(%c : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32> + %firstprivate = acc.firstprivate varPtr(%b : memref<10xf32>) varType(tensor<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32> + acc.serial private(%private_a, %private_c : memref<10xf32>, memref<10x10xf32>) firstprivate(%firstprivate : memref<10xf32>) { } acc.serial { } attributes {defaultAttr = #acc<defaultvalue none>} @@ -633,8 +641,10 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10 // CHECK-NEXT: } // CHECK: acc.serial wait({[[I64VALUE]] : i64, [[I32VALUE]] : i32, [[IDXVALUE]] : index}) { // CHECK-NEXT: } -// CHECK: %[[FIRSTP:.*]] = acc.firstprivate varPtr([[ARGB]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> -// CHECK: acc.serial firstprivate(@firstprivatization_memref_10xf32 -> %[[FIRSTP]] : memref<10xf32>) private(@privatization_memref_10_f32 -> [[ARGA]] : memref<10xf32>, @privatization_memref_10_10_f32 -> [[ARGC]] : memref<10x10xf32>) { +// CHECK: %[[PRIVATE_A:.*]] = acc.private varPtr([[ARGA]] : memref<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32> +// CHECK-NEXT: %[[PRIVATE_C:.*]] = acc.private varPtr([[ARGC]] : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32> +// CHECK-NEXT: %[[FIRSTP:.*]] = acc.firstprivate varPtr([[ARGB]] : memref<10xf32>) varType(tensor<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32> +// CHECK-NEXT: acc.serial firstprivate(%[[FIRSTP]] : memref<10xf32>) private(%[[PRIVATE_A]], %[[PRIVATE_C]] : memref<10xf32>, memref<10x10xf32>) { // CHECK-NEXT: } // CHECK: acc.serial { // CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>} @@ -721,6 +731,59 @@ func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10 // ----- +// Test acc.kernels with private and firstprivate operands, similar to acc.serial. + +acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init { +^bb0(%arg0: memref<10xf32>): + %0 = memref.alloc() : memref<10xf32> + acc.yield %0 : memref<10xf32> +} destroy { +^bb0(%arg0: memref<10xf32>): + memref.dealloc %arg0 : memref<10xf32> + acc.terminator +} + +acc.private.recipe @privatization_memref_10_10_f32 : memref<10x10xf32> init { +^bb0(%arg0: memref<10x10xf32>): + %1 = memref.alloc() : memref<10x10xf32> + acc.yield %1 : memref<10x10xf32> +} destroy { +^bb0(%arg0: memref<10x10xf32>): + memref.dealloc %arg0 : memref<10x10xf32> + acc.terminator +} + +acc.firstprivate.recipe @firstprivatization_memref_10xf32 : memref<10xf32> init { +^bb0(%arg0: memref<10xf32>): + %2 = memref.alloca() : memref<10xf32> + acc.yield %2 : memref<10xf32> +} copy { +^bb0(%arg0: memref<10xf32>, %arg1: memref<10xf32>): + memref.copy %arg0, %arg1 : memref<10xf32> to memref<10xf32> + acc.terminator +} destroy { +^bb0(%arg0: memref<10xf32>): + acc.terminator +} + +func.func @testkernelspriv(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () { + %priv_a = acc.private varPtr(%a : memref<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32> + %priv_c = acc.private varPtr(%c : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32> + %firstp = acc.firstprivate varPtr(%b : memref<10xf32>) varType(tensor<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32> + acc.kernels firstprivate(%firstp : memref<10xf32>) private(%priv_a, %priv_c : memref<10xf32>, memref<10x10xf32>) { + } + return +} + +// CHECK-LABEL: func.func @testkernelspriv( +// CHECK: %[[PRIV_A:.*]] = acc.private varPtr(%{{.*}} : memref<10xf32>) recipe(@privatization_memref_10_f32) -> memref<10xf32> +// CHECK: %[[PRIV_C:.*]] = acc.private varPtr(%{{.*}} : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32> +// CHECK: %[[FIRSTP:.*]] = acc.firstprivate varPtr(%{{.*}} : memref<10xf32>) varType(tensor<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32> +// CHECK: acc.kernels firstprivate(%[[FIRSTP]] : memref<10xf32>) private(%[[PRIV_A]], %[[PRIV_C]] : memref<10xf32>, memref<10x10xf32>) { +// CHECK-NEXT: } + +// ----- + func.func @testdataop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> () { %ifCond = arith.constant true @@ -1511,32 +1574,43 @@ acc.private.recipe @privatization_struct_i32_i64 : !llvm.struct<(i32, i32)> init // ----- -acc.reduction.recipe @reduction_add_i64 : i64 reduction_operator<add> init { -^bb0(%arg0: i64): - %0 = arith.constant 0 : i64 - acc.yield %0 : i64 +acc.reduction.recipe @reduction_add_memref_i64 : memref<i64> reduction_operator <add> init { +^bb0(%arg0: memref<i64>): + %c0_i64 = arith.constant 0 : i64 + %alloca = memref.alloca() : memref<i64> + memref.store %c0_i64, %alloca[] : memref<i64> + acc.yield %alloca : memref<i64> } combiner { -^bb0(%arg0: i64, %arg1: i64): - %0 = arith.addi %arg0, %arg1 : i64 - acc.yield %0 : i64 +^bb0(%arg0: memref<i64>, %arg1: memref<i64>): + %0 = memref.load %arg0[] : memref<i64> + %1 = memref.load %arg1[] : memref<i64> + %2 = arith.addi %0, %1 : i64 + memref.store %2, %arg0[] : memref<i64> + acc.yield %arg0 : memref<i64> } -// CHECK-LABEL: acc.reduction.recipe @reduction_add_i64 : i64 reduction_operator <add> init { -// CHECK: ^bb0(%{{.*}}: i64): +// CHECK-LABEL: acc.reduction.recipe @reduction_add_memref_i64 : memref<i64> reduction_operator <add> init { +// CHECK: ^bb0(%{{.*}}: memref<i64>): // CHECK: %[[C0:.*]] = arith.constant 0 : i64 -// CHECK: acc.yield %[[C0]] : i64 +// CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<i64> +// CHECK: memref.store %[[C0]], %[[ALLOCA]][] : memref<i64> +// CHECK: acc.yield %[[ALLOCA]] : memref<i64> // CHECK: } combiner { -// CHECK: ^bb0(%[[ARG0:.*]]: i64, %[[ARG1:.*]]: i64): -// CHECK: %[[RES:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i64 -// CHECK: acc.yield %[[RES]] : i64 +// CHECK: ^bb0(%[[ARG0:.*]]: memref<i64>, %[[ARG1:.*]]: memref<i64>): +// CHECK: %[[LOAD0:.*]] = memref.load %[[ARG0]][] : memref<i64> +// CHECK: %[[LOAD1:.*]] = memref.load %[[ARG1]][] : memref<i64> +// CHECK: %[[RES:.*]] = arith.addi %[[LOAD0]], %[[LOAD1]] : i64 +// CHECK: memref.store %[[RES]], %[[ARG0]][] : memref<i64> +// CHECK: acc.yield %[[ARG0]] : memref<i64> // CHECK: } -func.func @acc_reduc_test(%a : i64) -> () { +func.func @acc_reduc_test(%a : memref<i64>) -> () { %c0 = arith.constant 0 : index %c10 = arith.constant 10 : index %c1 = arith.constant 1 : index - acc.parallel reduction(@reduction_add_i64 -> %a : i64) { - acc.loop reduction(@reduction_add_i64 -> %a : i64) control(%iv : index) = (%c0 : index) to (%c10 : index) step (%c1 : index) { + %reduction_a = acc.reduction varPtr(%a : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64> + acc.parallel reduction(%reduction_a : memref<i64>) { + acc.loop reduction(%reduction_a : memref<i64>) control(%iv : index) = (%c0 : index) to (%c10 : index) step (%c1 : index) { acc.yield } attributes {inclusiveUpperbound = array<i1: true>, independent = [#acc.device_type<none>]} acc.yield @@ -1545,31 +1619,68 @@ func.func @acc_reduc_test(%a : i64) -> () { } // CHECK-LABEL: func.func @acc_reduc_test( -// CHECK-SAME: %[[ARG0:.*]]: i64) -// CHECK: acc.parallel reduction(@reduction_add_i64 -> %[[ARG0]] : i64) -// CHECK: acc.loop reduction(@reduction_add_i64 -> %[[ARG0]] : i64) +// CHECK-SAME: %[[ARG0:.*]]: memref<i64>) +// CHECK: %[[REDUCTION_A:.*]] = acc.reduction varPtr(%[[ARG0]] : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64> +// CHECK-NEXT: acc.parallel reduction(%[[REDUCTION_A]] : memref<i64>) +// CHECK: acc.loop reduction(%[[REDUCTION_A]] : memref<i64>) // ----- -acc.reduction.recipe @reduction_add_i64 : i64 reduction_operator<add> init { -^bb0(%0: i64): - %1 = arith.constant 0 : i64 - acc.yield %1 : i64 +acc.reduction.recipe @reduction_add_memref_i64 : memref<i64> reduction_operator <add> init { +^bb0(%arg0: memref<i64>): + %c0_i64 = arith.constant 0 : i64 + %alloca = memref.alloca() : memref<i64> + memref.store %c0_i64, %alloca[] : memref<i64> + acc.yield %alloca : memref<i64> } combiner { -^bb0(%0: i64, %1: i64): +^bb0(%arg0: memref<i64>, %arg1: memref<i64>): + %0 = memref.load %arg0[] : memref<i64> + %1 = memref.load %arg1[] : memref<i64> %2 = arith.addi %0, %1 : i64 - acc.yield %2 : i64 + memref.store %2, %arg0[] : memref<i64> + acc.yield %arg0 : memref<i64> } -func.func @acc_reduc_test(%a : i64) -> () { - acc.serial reduction(@reduction_add_i64 -> %a : i64) { +func.func @acc_reduc_test(%a : memref<i64>) -> () { + %reduction_a = acc.reduction varPtr(%a : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64> + acc.serial reduction(%reduction_a : memref<i64>) { } return } // CHECK-LABEL: func.func @acc_reduc_test( -// CHECK-SAME: %[[ARG0:.*]]: i64) -// CHECK: acc.serial reduction(@reduction_add_i64 -> %[[ARG0]] : i64) +// CHECK-SAME: %[[ARG0:.*]]: memref<i64>) +// CHECK: %[[REDUCTION_A:.*]] = acc.reduction varPtr(%[[ARG0]] : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64> +// CHECK-NEXT: acc.serial reduction(%[[REDUCTION_A]] : memref<i64>) + +// ----- + +acc.reduction.recipe @reduction_add_memref_i64 : memref<i64> reduction_operator <add> init { +^bb0(%arg0: memref<i64>): + %c0_i64 = arith.constant 0 : i64 + %alloca = memref.alloca() : memref<i64> + memref.store %c0_i64, %alloca[] : memref<i64> + acc.yield %alloca : memref<i64> +} combiner { +^bb0(%arg0: memref<i64>, %arg1: memref<i64>): + %0 = memref.load %arg0[] : memref<i64> + %1 = memref.load %arg1[] : memref<i64> + %2 = arith.addi %0, %1 : i64 + memref.store %2, %arg0[] : memref<i64> + acc.yield %arg0 : memref<i64> +} + +func.func @acc_kernels_reduc_test(%a : memref<i64>) -> () { + %reduction_a = acc.reduction varPtr(%a : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64> + acc.kernels reduction(%reduction_a : memref<i64>) { + } + return +} + +// CHECK-LABEL: func.func @acc_kernels_reduc_test( +// CHECK-SAME: %[[ARG0:.*]]: memref<i64>) +// CHECK: %[[REDUCTION_A:.*]] = acc.reduction varPtr(%[[ARG0]] : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64> +// CHECK-NEXT: acc.kernels reduction(%[[REDUCTION_A]] : memref<i64>) // ----- @@ -1699,6 +1810,59 @@ acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang( // ----- +// Test acc.specialized_routine attribute for specialized device functions +acc.routine @routine_seq func(@device_func_seq) seq +acc.routine @routine_gang func(@device_func_gang) gang +acc.routine @routine_gang_dim2 func(@device_func_gang_dim2) gang(dim: 2 : i64) +acc.routine @routine_gang_dim3 func(@device_func_gang_dim3) gang(dim: 3 : i64) +acc.routine @routine_worker func(@device_func_worker) worker +acc.routine @routine_vector func(@device_func_vector) vector + +func.func @device_func_seq() attributes {acc.specialized_routine = #acc.specialized_routine<@routine_seq, <seq>, "host_func_seq">} { + return +} + +func.func @device_func_gang() attributes {acc.specialized_routine = #acc.specialized_routine<@routine_gang, <gang_dim1>, "host_func_gang">} { + return +} + +func.func @device_func_gang_dim2() attributes {acc.specialized_routine = #acc.specialized_routine<@routine_gang_dim2, <gang_dim2>, "host_func_gang_dim2">} { + return +} + +func.func @device_func_gang_dim3() attributes {acc.specialized_routine = #acc.specialized_routine<@routine_gang_dim3, <gang_dim3>, "host_func_gang_dim3">} { + return +} + +func.func @device_func_worker() attributes {acc.specialized_routine = #acc.specialized_routine<@routine_worker, <worker>, "host_func_worker">} { + return +} + +func.func @device_func_vector() attributes {acc.specialized_routine = #acc.specialized_routine<@routine_vector, <vector>, "host_func_vector">} { + return +} + +// CHECK: acc.routine @routine_seq func(@device_func_seq) seq +// CHECK: acc.routine @routine_gang func(@device_func_gang) gang +// CHECK: acc.routine @routine_gang_dim2 func(@device_func_gang_dim2) gang(dim: 2 : i64) +// CHECK: acc.routine @routine_gang_dim3 func(@device_func_gang_dim3) gang(dim: 3 : i64) +// CHECK: acc.routine @routine_worker func(@device_func_worker) worker +// CHECK: acc.routine @routine_vector func(@device_func_vector) vector +// CHECK-LABEL: func.func @device_func_seq() +// CHECK: attributes {acc.specialized_routine = #acc.specialized_routine<@routine_seq, <seq>, "host_func_seq">} +// CHECK-LABEL: func.func @device_func_gang() +// CHECK: attributes {acc.specialized_routine = #acc.specialized_routine<@routine_gang, <gang_dim1>, "host_func_gang">} +// CHECK-LABEL: func.func @device_func_gang_dim2() +// CHECK: attributes {acc.specialized_routine = #acc.specialized_routine<@routine_gang_dim2, <gang_dim2>, "host_func_gang_dim2">} +// CHECK-LABEL: func.func @device_func_gang_dim3() +// CHECK: attributes {acc.specialized_routine = #acc.specialized_routine<@routine_gang_dim3, <gang_dim3>, "host_func_gang_dim3">} +// CHECK-LABEL: func.func @device_func_worker() +// CHECK: attributes {acc.specialized_routine = #acc.specialized_routine<@routine_worker, <worker>, "host_func_worker">} +// CHECK-LABEL: func.func @device_func_vector() +// CHECK: attributes {acc.specialized_routine = #acc.specialized_routine<@routine_vector, <vector>, "host_func_vector">} + +// ----- + func.func @acc_func() -> () { "test.openacc_dummy_op"() {acc.declare_action = #acc.declare_action<postAlloc = @_QMacc_declareFacc_declare_allocateEa_acc_declare_update_desc_post_alloc>} : () -> () return diff --git a/mlir/test/Dialect/OpenACC/pointer-like-interface-load.mlir b/mlir/test/Dialect/OpenACC/pointer-like-interface-load.mlir new file mode 100644 index 0000000..36df6a1 --- /dev/null +++ b/mlir/test/Dialect/OpenACC/pointer-like-interface-load.mlir @@ -0,0 +1,29 @@ +// RUN: mlir-opt %s --split-input-file --pass-pipeline="builtin.module(func.func(test-acc-pointer-like-interface{test-mode=load}))" 2>&1 | FileCheck %s + +func.func @test_memref_load_scalar() { + %ptr = memref.alloca() {test.ptr} : memref<f32> + // CHECK: Successfully generated load for operation: %[[PTR:.*]] = memref.alloca() {test.ptr} : memref<f32> + // CHECK: Loaded value type: f32 + // CHECK: Generated: %{{.*}} = memref.load %[[PTR]][] : memref<f32> + return +} + +// ----- + +func.func @test_memref_load_int() { + %ptr = memref.alloca() {test.ptr} : memref<i64> + // CHECK: Successfully generated load for operation: %[[PTR:.*]] = memref.alloca() {test.ptr} : memref<i64> + // CHECK: Loaded value type: i64 + // CHECK: Generated: %{{.*}} = memref.load %[[PTR]][] : memref<i64> + return +} + +// ----- + +func.func @test_memref_load_dynamic() { + %c10 = arith.constant 10 : index + %ptr = memref.alloc(%c10) {test.ptr} : memref<?xf32> + // CHECK: Failed to generate load for operation: %[[PTR:.*]] = memref.alloc(%{{.*}}) {test.ptr} : memref<?xf32> + return +} + diff --git a/mlir/test/Dialect/OpenACC/pointer-like-interface-store.mlir b/mlir/test/Dialect/OpenACC/pointer-like-interface-store.mlir new file mode 100644 index 0000000..0fee431 --- /dev/null +++ b/mlir/test/Dialect/OpenACC/pointer-like-interface-store.mlir @@ -0,0 +1,39 @@ +// RUN: mlir-opt %s --split-input-file --pass-pipeline="builtin.module(func.func(test-acc-pointer-like-interface{test-mode=store}))" 2>&1 | FileCheck %s + +func.func @test_memref_store_scalar() { + %ptr = memref.alloca() {test.ptr} : memref<f32> + // CHECK: Successfully generated store for operation: %[[PTR:.*]] = memref.alloca() {test.ptr} : memref<f32> + // CHECK: Generated: %[[VAL:.*]] = arith.constant 4.200000e+01 : f32 + // CHECK: Generated: memref.store %[[VAL]], %[[PTR]][] : memref<f32> + return +} + +// ----- + +func.func @test_memref_store_int() { + %ptr = memref.alloca() {test.ptr} : memref<i32> + // CHECK: Successfully generated store for operation: %[[PTR:.*]] = memref.alloca() {test.ptr} : memref<i32> + // CHECK: Generated: %[[VAL:.*]] = arith.constant 42 : i32 + // CHECK: Generated: memref.store %[[VAL]], %[[PTR]][] : memref<i32> + return +} + +// ----- + +func.func @test_memref_store_i64() { + %ptr = memref.alloca() {test.ptr} : memref<i64> + // CHECK: Successfully generated store for operation: %[[PTR:.*]] = memref.alloca() {test.ptr} : memref<i64> + // CHECK: Generated: %[[VAL:.*]] = arith.constant 42 : i64 + // CHECK: Generated: memref.store %[[VAL]], %[[PTR]][] : memref<i64> + return +} + +// ----- + +func.func @test_memref_store_dynamic() { + %c10 = arith.constant 10 : index + %ptr = memref.alloc(%c10) {test.ptr} : memref<?xf32> + // CHECK: Failed to generate store for operation: %[[PTR:.*]] = memref.alloc(%{{.*}}) {test.ptr} : memref<?xf32> + return +} + diff --git a/mlir/test/Dialect/OpenACC/recipe-populate-private-from-firstprivate.mlir b/mlir/test/Dialect/OpenACC/recipe-populate-private-from-firstprivate.mlir new file mode 100644 index 0000000..154d44e --- /dev/null +++ b/mlir/test/Dialect/OpenACC/recipe-populate-private-from-firstprivate.mlir @@ -0,0 +1,38 @@ +// RUN: mlir-opt %s --split-input-file --pass-pipeline="builtin.module(test-acc-recipe-populate{recipe-type=private_from_firstprivate})" | FileCheck %s + +// Verify that we can create a private recipe using the convenience overload +// that takes an existing firstprivate recipe as input. For a simple scalar +// alloca-backed memref, only an init region is expected (no destroy). +// CHECK: acc.private.recipe @private_from_firstprivate_scalar : memref<f32> init { +// CHECK: ^bb0(%{{.*}}: memref<f32>): +// CHECK: %[[ALLOC:.*]] = memref.alloca() {acc.var_name = #acc.var_name<"scalar">} : memref<f32> +// CHECK: acc.yield %[[ALLOC]] : memref<f32> +// CHECK: } + +func.func @test_scalar_from_firstprivate() { + %0 = memref.alloca() {test.var = "scalar"} : memref<f32> + return +} + +// ----- + +// Verify that destroy regions are also present when creating a private recipe +// from a firstprivate recipe that requires dynamic deallocation. +// CHECK: acc.private.recipe @private_from_firstprivate_dynamic_d2 : memref<?x?xf32> init { +// CHECK: ^bb0(%[[ARG:.*]]: memref<?x?xf32>): +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM0:.*]] = memref.dim %[[ARG]], %[[C0]] : memref<?x?xf32> +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<?x?xf32> +// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) {acc.var_name = #acc.var_name<"dynamic_d2">} : memref<?x?xf32> +// CHECK: acc.yield %[[ALLOC]] : memref<?x?xf32> +// CHECK: } destroy { +// CHECK: ^bb0(%{{.*}}: memref<?x?xf32>, %[[VAL:.*]]: memref<?x?xf32>): +// CHECK: memref.dealloc %[[VAL]] : memref<?x?xf32> +// CHECK: acc.terminator +// CHECK: } + +func.func @test_dynamic_from_firstprivate(%arg0: index, %arg1: index) { + %0 = memref.alloc(%arg0, %arg1) {test.var = "dynamic_d2"} : memref<?x?xf32> + return +} diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 084c3fc..ac590fc 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -974,6 +974,56 @@ func.func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) { // ----- +// CHECK-LABEL: @while_move_if_down +func.func @while_move_if_down() -> i32 { + %defined_outside = "test.get_some_value0" () : () -> (i32) + %0 = scf.while () : () -> (i32) { + %used_value = "test.get_some_value1" () : () -> (i32) + %used_by_subregion = "test.get_some_value2" () : () -> (i32) + %else_value = "test.get_some_value3" () : () -> (i32) + %condition = "test.condition"() : () -> i1 + %res = scf.if %condition -> (i32) { + "test.use0" (%defined_outside) : (i32) -> () + "test.use1" (%used_value) : (i32) -> () + test.alloca_scope_region { + "test.use2" (%used_by_subregion) : (i32) -> () + } + %then_value = "test.get_some_value4" () : () -> (i32) + scf.yield %then_value : i32 + } else { + scf.yield %else_value : i32 + } + scf.condition(%condition) %res : i32 + } do { + ^bb0(%res_arg: i32): + "test.use3" (%res_arg) : (i32) -> () + scf.yield + } + return %0 : i32 +} +// CHECK: %[[defined_outside:.*]] = "test.get_some_value0"() : () -> i32 +// CHECK: %[[WHILE_RES:.*]]:3 = scf.while : () -> (i32, i32, i32) { +// CHECK: %[[used_value:.*]] = "test.get_some_value1"() : () -> i32 +// CHECK: %[[used_by_subregion:.*]] = "test.get_some_value2"() : () -> i32 +// CHECK: %[[else_value:.*]] = "test.get_some_value3"() : () -> i32 +// CHECK: %[[condition:.*]] = "test.condition"() : () -> i1 +// CHECK: scf.condition(%[[condition]]) %[[else_value]], %[[used_value]], %[[used_by_subregion]] : i32, i32, i32 +// CHECK: } do { +// CHECK: ^bb0(%[[res_arg:.*]]: i32, %[[used_value_arg:.*]]: i32, %[[used_by_subregion_arg:.*]]: i32): +// CHECK: "test.use0"(%[[defined_outside]]) : (i32) -> () +// CHECK: "test.use1"(%[[used_value_arg]]) : (i32) -> () +// CHECK: test.alloca_scope_region { +// CHECK: "test.use2"(%[[used_by_subregion_arg]]) : (i32) -> () +// CHECK: } +// CHECK: %[[then_value:.*]] = "test.get_some_value4"() : () -> i32 +// CHECK: "test.use3"(%[[then_value]]) : (i32) -> () +// CHECK: scf.yield +// CHECK: } +// CHECK: return %[[WHILE_RES]]#0 : i32 +// CHECK: } + +// ----- + // CHECK-LABEL: @while_cond_true func.func @while_cond_true() -> i1 { %0 = scf.while () : () -> i1 { diff --git a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir index 8e29ff6..b70bb40 100644 --- a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir @@ -795,6 +795,53 @@ func.func @selection(%cond: i1) -> () { // ----- +func.func @selection_switch(%selector: i32) -> () { + %zero = spirv.Constant 0: i32 + %one = spirv.Constant 1: i32 + %two = spirv.Constant 2: i32 + %three = spirv.Constant 3: i32 + %var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function> + + // CHECK: spirv.mlir.selection { + spirv.mlir.selection { + // CHECK-NEXT: spirv.Switch {{%.*}} : i32, [ + // CHECK-NEXT: default: ^bb1, + // CHECK-NEXT: 0: ^bb2, + // CHECK-NEXT: 1: ^bb3 + spirv.Switch %selector : i32, [ + default: ^default, + 0: ^case0, + 1: ^case1 + ] + // CHECK: ^bb1 + ^default: + spirv.Store "Function" %var, %one : i32 + // CHECK: spirv.Branch ^bb4 + spirv.Branch ^merge + + // CHECK: ^bb2 + ^case0: + spirv.Store "Function" %var, %two : i32 + // CHECK: spirv.Branch ^bb4 + spirv.Branch ^merge + + // CHECK: ^bb3 + ^case1: + spirv.Store "Function" %var, %three : i32 + // CHECK: spirv.Branch ^bb4 + spirv.Branch ^merge + + // CHECK: ^bb4 + ^merge: + // CHECK-NEXT: spirv.mlir.merge + spirv.mlir.merge + } + + spirv.Return +} + +// ----- + // CHECK-LABEL: @empty_region func.func @empty_region() -> () { // CHECK: spirv.mlir.selection @@ -918,3 +965,171 @@ func.func @kill() { // CHECK: spirv.Kill spirv.Kill } + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.Switch +//===----------------------------------------------------------------------===// + +func.func @switch(%selector: i32) -> () { + // CHECK: spirv.Switch {{%.*}} : i32, [ + // CHECK-NEXT: default: ^bb1, + // CHECK-NEXT: 0: ^bb2, + // CHECK-NEXT: 1: ^bb3, + // CHECK-NEXT: 2: ^bb4 + spirv.Switch %selector : i32, [ + default: ^default, + 0: ^case0, + 1: ^case1, + 2: ^case2 + ] +^default: + spirv.Branch ^merge + +^case0: + spirv.Branch ^merge + +^case1: + spirv.Branch ^merge + +^case2: + spirv.Branch ^merge + +^merge: + spirv.Return +} + +func.func @switch_only_default(%selector: i32) -> () { + // CHECK: spirv.Switch {{%.*}} : i32, [ + // CHECK-NEXT: default: ^bb1 + spirv.Switch %selector : i32, [ + default: ^default + ] +^default: + spirv.Branch ^merge + +^merge: + spirv.Return +} + +func.func @switch_operands(%selector : i32, %operand : i32) { + // CHECK: spirv.Switch {{%.*}} : i32, [ + // CHECK-NEXT: default: ^bb1({{%.*}} : i32), + // CHECK-NEXT: 0: ^bb2({{%.*}} : i32), + // CHECK-NEXT: 1: ^bb3({{%.*}} : i32) + spirv.Switch %selector : i32, [ + default: ^default(%operand : i32), + 0: ^case0(%operand : i32), + 1: ^case1(%operand : i32) + ] +^default(%argd : i32): + spirv.Branch ^merge + +^case0(%arg0 : i32): + spirv.Branch ^merge + +^case1(%arg1 : i32): + spirv.Branch ^merge + +^merge: + spirv.Return +} + +// ----- + +func.func @switch_float_selector(%selector: f32) -> () { + // expected-error@+1 {{expected builtin.integer, but found 'f32'}} + spirv.Switch %selector : f32, [ + default: ^default + ] +^default: + spirv.Branch ^merge + +^merge: + spirv.Return +} + +// ----- + +func.func @switch_float_selector(%selector: i32) -> () { + // expected-error@+3 {{expected integer value}} + spirv.Switch %selector : i32, [ + default: ^default, + 0.0: ^case0 + ] +^default: + spirv.Branch ^merge + +^case 0: + spirv.Branch ^merge + +^merge: + spirv.Return +} + +// ----- + +func.func @switch_missing_default(%selector: i32) -> () { + // expected-error@+2 {{expected 'default'}} + spirv.Switch %selector : i32, [ + 0: ^case0 + ] +^case 0: + spirv.Branch ^merge + +^merge: + spirv.Return +} + +// ----- + +func.func @switch_default_no_target(%selector: i32) -> () { + // expected-error@+2 {{expected block name}} + spirv.Switch %selector : i32, [ + default: + ] +^default: + spirv.Branch ^merge + +^merge: + spirv.Return +} + +// ----- + +func.func @switch_case_no_target(%selector: i32) -> () { + // expected-error@+3 {{expected block name}} + spirv.Switch %selector : i32, [ + default: ^default, + 0: + ] +^default: + spirv.Branch ^merge + +^case 0: + spirv.Branch ^merge + +^merge: + spirv.Return +} + +// ----- + +func.func @switch_missing_operand_type(%selector: i32) -> () { + %0 = spirv.Constant 0 : i32 + // expected-error@+2 {{expected ':'}} + spirv.Switch %selector : i32, [ + default: ^default (%0), + 0.0: ^case0 + ] +^default(%argd : i32): + spirv.Branch ^merge + +^case 0: + spirv.Branch ^merge + +^merge: + spirv.Return +} + diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir index 5eb2360..be8ce20 100644 --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -678,11 +678,9 @@ func.func @tensor.concat_different_shapes(%f: tensor<8x4xf32>, %g: tensor<8x5xf3 // CHECK-DAG: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]] // CHECK: %[[ALLOC:.*]] = memref.alloc // CHECK-SAME: memref<8x?xf32> -// CHECK-DAG: %[[OFFSET:.*]] = arith.constant 0 : index -// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, %[[F_DIM]]] [1, 1] +// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, %[[F_DIM]]] [1, 1] // CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]] -// CHECK: %[[OFFSET_2:.*]] = arith.addi %[[OFFSET]], %[[F_DIM]] : index -// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [8, %[[G_DIM]]] [1, 1] +// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [8, %[[G_DIM]]] [1, 1] // CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]] // CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]] // CHECK: return %[[RET]] @@ -706,10 +704,9 @@ func.func @tensor.concat_dynamic(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>) -> te // CHECK: %[[ALLOC:.*]] = memref.alloc // CHECK-SAME: memref<?x?xf32> // CHECK-DAG: %[[NON_CONCAT_DIM:.*]] = memref.dim %[[ALLOC]], %[[c0]] -// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[c0]]] [%[[NON_CONCAT_DIM]], %[[F_DIM]]] [1, 1] +// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [%[[NON_CONCAT_DIM]], %[[F_DIM]]] [1, 1] // CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]] -// CHECK: %[[OFFSET_2:.*]] = arith.addi %[[c0]], %[[F_DIM]] : index -// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [%[[NON_CONCAT_DIM]], %[[G_DIM]]] [1, 1] +// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [%[[NON_CONCAT_DIM]], %[[G_DIM]]] [1, 1] // CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]] // CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]] // CHECK: return %[[RET]] @@ -721,6 +718,35 @@ func.func @tensor.concat_dynamic_nonconcat_dim(%f: tensor<?x?xf32>, %g: tensor<? // ----- +// CHECK: #[[$sum_map:.+]] = affine_map<()[s0, s1] -> (s0 + s1)> + +// CHECK-LABEL: func @tensor.concat_mixed_dynamic_static( +// CHECK-SAME: %[[F:.*]]: tensor<8x?xf32>, %[[G:.*]]: tensor<8x?xf32>, +// CHECK-SAME: %[[H:.*]]: tensor<8x2xf32>) +// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]] +// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_buffer %[[G]] +// CHECK-DAG: %[[H_MEMREF:.*]] = bufferization.to_buffer %[[H]] +// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x10xf32> +// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index +// CHECK: %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]] +// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, %[[F_DIM]]] [1, 1] +// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]] +// CHECK: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]] +// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [8, %[[G_DIM]]] [1, 1] +// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]] +// CHECK: %[[OFFSET:.*]] = affine.apply #[[$sum_map]]()[%[[F_DIM]], %[[G_DIM]]] +// CHECK: %[[SUBVIEW3:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, 2] [1, 1] +// CHECK: memref.copy %[[H_MEMREF]], %[[SUBVIEW3]] +// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]] +// CHECK: return %[[RET]] +// CHECK: } +func.func @tensor.concat_mixed_dynamic_static(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>, %h: tensor<8x2xf32>) -> tensor<8x10xf32> { + %0 = tensor.concat dim(1) %f, %g, %h : (tensor<8x?xf32>, tensor<8x?xf32>, tensor<8x2xf32>) -> tensor<8x10xf32> + return %0 : tensor<8x10xf32> +} + +// ----- + // CHECK-LABEL: func @tensor.splat_dynamic( // CHECK-SAME: %[[F:[a-zA-Z0-9_]+]]: f32 // CHECK-SAME: %[[M:[a-zA-Z0-9_]+]]: index diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir index a05f423..6ef8b3e 100644 --- a/mlir/test/Dialect/Tosa/availability.mlir +++ b/mlir/test/Dialect/Tosa/availability.mlir @@ -606,7 +606,7 @@ func.func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> { // CHECK-LABEL: cast func.func @test_cast1(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> { // CHECK: profiles: [ [pro_int, pro_fp] ] - // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16, mxfp, int64] ] + // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16, int64] ] %0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xf32> return %0 : tensor<13x21x3xf32> } diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 5a40f3f..84776c4 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -362,6 +362,36 @@ func.func @clamp_twice_with_nan_propagate_ignore_is_not_single_clamp(%arg0: tens // ----- +// CHECK-LABEL: @clamp_twice_with_unsigned_quantized_is_single_clamp +// CHECK: tosa.clamp %arg0 {max_val = 230 : ui8, min_val = 10 : ui8} +func.func @clamp_twice_with_unsigned_quantized_is_single_clamp(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) { + %0 = tosa.clamp %arg0 {max_val = 240 : ui8, min_val = 10 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>> + %1 = tosa.clamp %0 {max_val = 230 : ui8, min_val = 5 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>> + return %1 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>> +} + +// ----- + +// CHECK-LABEL: @clamp_twice_with_signed_quantized_is_single_clamp +// CHECK: tosa.clamp %arg0 {max_val = 110 : i8, min_val = -5 : i8} +func.func @clamp_twice_with_signed_quantized_is_single_clamp(%arg0:tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) { + %0 = tosa.clamp %arg0 {max_val = 110 : i8, min_val = -10 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>> + %1 = tosa.clamp %0 {max_val = 120 : i8, min_val = -5 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>> + return %1 : tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>> +} + +// CHECK-LABEL: @clamp_twice_with_signed_quantized_non_overlap_is_not_single_clamp +// CHECK: %[[CLAMP_1:.*]] = tosa.clamp %arg0 {max_val = 50 : i8, min_val = -10 : i8} +// CHECK-NEXT: tosa.clamp %[[CLAMP_1]] {max_val = 120 : i8, min_val = 60 : i8} +func.func @clamp_twice_with_signed_quantized_non_overlap_is_not_single_clamp(%arg0:tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) { + %0 = tosa.clamp %arg0 {max_val = 50 : i8, min_val = -10 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>> + %1 = tosa.clamp %0 {max_val = 120 : i8, min_val = 60 : i8} : (tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>> + return %1 : tensor<?x112x112x32x!quant.uniform<i8:f32, 0.023529412224888802:-128>> +} + + +// ----- + // CHECK-LABEL: @concat_fold func.func @concat_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> { // CHECK: return %arg0 @@ -643,6 +673,48 @@ func.func @select_not_pred(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: // ----- +// CHECK-LABEL: @select_broadcast_same_value_no_fold +func.func @select_broadcast_same_value_no_fold(%arg0: tensor<2x2xi1>, %arg1: tensor<1x1xf32>) -> tensor<2x2xf32> { + // CHECK: tosa.select %arg0, %arg1, %arg1 + %0 = tosa.select %arg0, %arg1, %arg1 : (tensor<2x2xi1>, tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: @select_broadcast_true_value_no_fold +func.func @select_broadcast_true_value_no_fold(%arg0: tensor<1x1xf32>, %arg1: tensor<2x2xf32>) -> tensor<?x?xf32> { + // CHECK: %[[CONST:.*]] = "tosa.const" + %0 = "tosa.const"() {values = dense<1> : tensor<2x2xi1>} : () -> tensor<2x2xi1> + // CHECK: tosa.select %[[CONST]], %arg0, %arg1 + %1 = tosa.select %0, %arg0, %arg1 : (tensor<2x2xi1>, tensor<1x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32> + return %1 : tensor<?x?xf32> +} + +// ----- + +// CHECK-LABEL: @select_broadcast_false_value_no_fold +func.func @select_broadcast_false_value_no_fold(%arg0: tensor<2x2xf32>, %arg1: tensor<1x1xf32>) -> tensor<2x2xf32> { + // CHECK: %[[CONST:.*]] = "tosa.const" + %0 = "tosa.const"() {values = dense<0> : tensor<2x2xi1>} : () -> tensor<2x2xi1> + // CHECK: tosa.select %[[CONST]], %arg0, %arg1 + %1 = tosa.select %0, %arg0, %arg1 : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<1x1xf32>) -> tensor<2x2xf32> + return %1 : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: @select_broadcast_false_value_dynamic_operand_no_fold +func.func @select_broadcast_false_value_dynamic_operand_no_fold(%arg0: tensor<2x?xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: %[[CONST:.*]] = "tosa.const" + %0 = "tosa.const"() {values = dense<0> : tensor<2x2xi1>} : () -> tensor<2x2xi1> + // CHECK: tosa.select %[[CONST]], %arg0, %arg1 + %1 = tosa.select %0, %arg0, %arg1 : (tensor<2x2xi1>, tensor<2x?xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %1 : tensor<2x2xf32> +} + +// ----- + // CHECK-LABEL: @reduce_all_fold func.func @reduce_all_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> { // CHECK: return %arg0 diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index c9e03ca..3d24928 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -4,7 +4,7 @@ // validation flow. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,int64,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" func.func @test_cast(%arg0: tensor<i1>) -> tensor<5xi32> { @@ -2044,6 +2044,16 @@ func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tens // ----- +// CHECK-LABEL: test_scatter_duplicate_indices_int64 +func.func @test_scatter_duplicate_indices_int64(%arg0: tensor<2x52x3xf32>, %arg2: tensor<2x12x3xf32>) -> tensor<2x52x3xf32> { + %indices = "tosa.const"() { values = dense<[[1, 2, 3, 4, 5, 6, 7, 8, 9, 3, 11, 12], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]> : tensor<2x12xi64> } : () -> tensor<2x12xi64> + // expected-error@+1 {{'tosa.scatter' op indices values contain duplicates}} + %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi64>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32> + return %0 : tensor<2x52x3xf32> +} + +// ----- + func.func @test_reduce_all_unsupported_data_types(%arg0: tensor<2x12x11xf32>) -> tensor<1x12x11xf32> { // expected-error@+1 {{'tosa.reduce_all' op illegal: operation operand/result data types did not align with any profile or extension, got (f32,f32), did you mean (i1,i1)?}} %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<2x12x11xf32>) -> tensor<1x12x11xf32> diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir index 68a9578..177192b 100644 --- a/mlir/test/Dialect/Tosa/invalid_extension.mlir +++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir @@ -563,13 +563,6 @@ func.func @test_const_fp6e3m2(%arg0 : index) -> tensor<4xf6E3M2FN> { } // ----- -func.func @test_cast_f4e2m1(%arg0: tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16> { - // expected-error@+1 {{'tosa.cast' op illegal: requires all of [bf16, mxfp] but not enabled in target}} - %0 = tosa.cast %arg0 : (tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16> - return %0 : tensor<13x21x3xbf16> -} - -// ----- func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf8E5M2>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> { // expected-error@+1 {{'tosa.cast_from_block_scaled' op illegal: requires [mxfp] but not enabled in target}} diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index 22fde3b..652447bd 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -280,6 +280,13 @@ func.func @test_clamp_quantized(%arg0: tensor<13x21x3x!quant.uniform<i8:f32, 1.0 } // ----- +// CHECK-LABEL: clamp_quantized_unsigned +func.func @clamp_quantized_unsigned(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) { + %0 = tosa.clamp %arg0 {max_val = 255 : ui8, min_val = 0 : ui8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>> + return %0 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>> +} + +// ----- // CHECK-LABEL: sigmoid func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %0 = tosa.sigmoid %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> @@ -343,6 +350,13 @@ func.func @test_intdiv(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>) - } // ----- +// CHECK-LABEL: intdiv_i64 +func.func @test_intdiv_i64(%arg0: tensor<13x21x1xi64>, %arg1: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> { + %0 = tosa.intdiv %arg0, %arg1 : (tensor<13x21x1xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi64> + return %0 : tensor<13x21x3xi64> +} + +// ----- // CHECK-LABEL: logical_and func.func @test_logical_and(%arg0: tensor<13x21x3xi1>, %arg1: tensor<13x21x1xi1>) -> tensor<13x21x3xi1> { %0 = tosa.logical_and %arg0, %arg1 : (tensor<13x21x3xi1>, tensor<13x21x1xi1>) -> tensor<13x21x3xi1> @@ -750,10 +764,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> } // ----- -// CHECK-LABEL: scatter -func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> { - %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32> - return %0 : tensor<13x52x3xf32> +// CHECK-LABEL: gather_int64 +func.func @test_gather_int64(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi64>) -> tensor<13x26x3xf32> { + %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi64>) -> tensor<13x26x3xf32> + return %0 : tensor<13x26x3xf32> } // ----- @@ -764,6 +778,20 @@ func.func @test_gather_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tenso } // ----- +// CHECK-LABEL: scatter +func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> { + %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32> + return %0 : tensor<13x52x3xf32> +} + +// ----- +// CHECK-LABEL: scatter_int64 +func.func @test_scatter_int64(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi64>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> { + %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi64>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32> + return %0 : tensor<13x52x3xf32> +} + +// ----- // CHECK-LABEL: scatter_unranked_indices func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<*xi32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> @@ -1277,6 +1305,42 @@ func.func @test_matmul_t_block_scaled_mxint8(%arg0: tensor<4x8x32x!tosa.mxint8>, } // ----- +// CHECK-LABEL: test_matmul_t_block_scaled_fp6e3m2_e2e +func.func @test_matmul_t_block_scaled_fp6e3m2_e2e(%arg0: tensor<6x2x32xf32>, %arg1: tensor<6x64x32xf32>) -> tensor<6x2x64xf32> { + %a, %sa = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32xf32>) -> (tensor<6x2x32xf6E3M2FN>, tensor<6x2x1xf8E8M0FNU>) + %b, %sb = tosa.cast_to_block_scaled %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x64x32xf32>) -> (tensor<6x64x32xf6E3M2FN>, tensor<6x64x1xf8E8M0FNU>) + %res = tosa.matmul_t_block_scaled %a, %sa, %b, %sb {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32xf6E3M2FN>, tensor<6x2x1xf8E8M0FNU>, tensor<6x64x32xf6E3M2FN>, tensor<6x64x1xf8E8M0FNU>) -> tensor<6x2x64xf32> + return %res : tensor<6x2x64xf32> +} + +// ----- +// CHECK-LABEL: test_matmul_t_block_scaled_fp6e2m3_e2e +func.func @test_matmul_t_block_scaled_fp6e2m3_e2e(%arg0: tensor<6x2x32xf32>, %arg1: tensor<6x64x32xf32>) -> tensor<6x2x64xf32> { + %a, %sa = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32xf32>) -> (tensor<6x2x32xf6E2M3FN>, tensor<6x2x1xf8E8M0FNU>) + %b, %sb = tosa.cast_to_block_scaled %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x64x32xf32>) -> (tensor<6x64x32xf6E2M3FN>, tensor<6x64x1xf8E8M0FNU>) + %res = tosa.matmul_t_block_scaled %a, %sa, %b, %sb {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32xf6E2M3FN>, tensor<6x2x1xf8E8M0FNU>, tensor<6x64x32xf6E2M3FN>, tensor<6x64x1xf8E8M0FNU>) -> tensor<6x2x64xf32> + return %res : tensor<6x2x64xf32> +} + +// ----- +// CHECK-LABEL: test_matmul_t_block_scaled_fp4e2m1_e2e +func.func @test_matmul_t_block_scaled_fp4e2m1_e2e(%arg0: tensor<6x2x32xf32>, %arg1: tensor<6x64x32xf32>) -> tensor<6x2x64xf32> { + %a, %sa = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32xf32>) -> (tensor<6x2x32xf4E2M1FN>, tensor<6x2x1xf8E8M0FNU>) + %b, %sb = tosa.cast_to_block_scaled %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x64x32xf32>) -> (tensor<6x64x32xf4E2M1FN>, tensor<6x64x1xf8E8M0FNU>) + %res = tosa.matmul_t_block_scaled %a, %sa, %b, %sb {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32xf4E2M1FN>, tensor<6x2x1xf8E8M0FNU>, tensor<6x64x32xf4E2M1FN>, tensor<6x64x1xf8E8M0FNU>) -> tensor<6x2x64xf32> + return %res : tensor<6x2x64xf32> +} + +// ----- +// CHECK-LABEL: test_matmul_t_block_scaled_mxint8_e2e +func.func @test_matmul_t_block_scaled_mxint8_e2e(%arg0: tensor<6x2x32xf32>, %arg1: tensor<6x64x32xf32>) -> tensor<6x2x64xf32> { + %a, %sa = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32xf32>) -> (tensor<6x2x32x!tosa.mxint8>, tensor<6x2x1xf8E8M0FNU>) + %b, %sb = tosa.cast_to_block_scaled %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x64x32xf32>) -> (tensor<6x64x32x!tosa.mxint8>, tensor<6x64x1xf8E8M0FNU>) + %res = tosa.matmul_t_block_scaled %a, %sa, %b, %sb {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32x!tosa.mxint8>, tensor<6x2x1xf8E8M0FNU>, tensor<6x64x32x!tosa.mxint8>, tensor<6x64x1xf8E8M0FNU>) -> tensor<6x2x64xf32> + return %res : tensor<6x2x64xf32> +} + +// ----- // CHECK-LABEL: test_cast_from_block_scaled_static func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> { %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> @@ -1307,7 +1371,7 @@ func.func @test_cast_to_block_scaled_unranked(%arg0: tensor<*xf32>) -> (tensor<* // ----- // CHECK-LABEL: test_cast_to_block_scaled_mxint8 func.func @test_cast_to_block_scaled_mxint8(%arg0: tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) { - %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) + %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) return %0#0, %0#1 : tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU> } diff --git a/mlir/test/Dialect/Tosa/quant-test.mlir b/mlir/test/Dialect/Tosa/quant-test.mlir index f0ad4eb..88dffe7 100644 --- a/mlir/test/Dialect/Tosa/quant-test.mlir +++ b/mlir/test/Dialect/Tosa/quant-test.mlir @@ -1,14 +1,22 @@ // RUN: mlir-opt --tosa-test-quant-utils %s | FileCheck %s // ----- -// CHECK-LABEL: test_build_qtype -func.func @test_build_qtype(%arg0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> { +// CHECK-LABEL: test_build_qtype_unsigned +func.func @test_build_qtype_unsigned(%arg0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, %arg1: tensor<1xui8>, %arg2: tensor<1xui8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> { // CHECK: tosa.negate - %0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, tensor<1xi8>, tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> + %0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>, tensor<1xui8>, tensor<1xui8>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> return %0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> } // ----- +// CHECK-LABEL: test_build_qtype_signed +func.func @test_build_qtype_signed(%arg0 : tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>> { + // CHECK: tosa.negate + %0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>>, tensor<1xi8>, tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>> + return %0 : tensor<16x1x1x8x!quant.uniform<i8<1:127>:f32, 0.015680249780416489:128>> +} + +// ----- // CHECK-LABEL: test_build_mult_and_shift func.func @test_build_mult_and_shift(%arg0: tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684768557548523>>, %arg1 : tensor<16x1x1x8x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, %arg2 : tensor<16xi32>) -> tensor<1x34x36x16x!quant.uniform<i32:f32, 0.078431375324726104>> { // CHECK: tosa.conv2d diff --git a/mlir/test/Dialect/Tosa/tosa-arith-const-to-tosa-const.mlir b/mlir/test/Dialect/Tosa/tosa-arith-const-to-tosa-const.mlir new file mode 100644 index 0000000..fc2d77ef --- /dev/null +++ b/mlir/test/Dialect/Tosa/tosa-arith-const-to-tosa-const.mlir @@ -0,0 +1,100 @@ +// RUN: mlir-opt %s --tosa-arith-const-to-tosa-const --split-input-file | FileCheck %s + +// CHECK-LABEL: func.func @rewrite_f32_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK: return %[[CST]] +func.func @rewrite_f32_tensor() -> tensor<2xf32> { + %c = arith.constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32> + return %c : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_i32_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: return %[[CST]] +func.func @rewrite_i32_tensor() -> tensor<3xi32> { + %c = arith.constant dense<[1, 0, -1]> : tensor<3xi32> + return %c : tensor<3xi32> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_i1_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[true, false]> : tensor<2xi1>}> : () -> tensor<2xi1> +func.func @rewrite_i1_tensor() -> tensor<2xi1> { + %c = arith.constant dense<[true, false]> : tensor<2xi1> + return %c : tensor<2xi1> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_rank0_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<1.234500e+00> : tensor<f32>}> : () -> tensor<f32> +func.func @rewrite_rank0_tensor() -> tensor<f32> { + %c = arith.constant dense<1.234500e+00> : tensor<f32> + return %c : tensor<f32> +} + +// ----- + +// CHECK-LABEL: func.func @preserve_scalar_i32 +// CHECK: %[[CST:.*]] = arith.constant 42 : i32 +func.func @preserve_scalar_i32() -> i32 { + %c = arith.constant 42 : i32 + return %c : i32 +} + +// ----- + +// CHECK-LABEL: func.func @preserve_index_tensor +// CHECK: %[[CST:.*]] = arith.constant dense<[0, 1]> : tensor<2xindex> +func.func @preserve_index_tensor() -> tensor<2xindex> { + %c = arith.constant dense<[0, 1]> : tensor<2xindex> + return %c : tensor<2xindex> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_resource_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense_resource<blob1> : tensor<4xf32>}> : () -> tensor<4xf32> +func.func @rewrite_resource_tensor() -> tensor<4xf32> { + %c = arith.constant dense_resource<"blob1"> : tensor<4xf32> + return %c : tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_quant_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[10, 20]> : tensor<2xui8>}> : () -> tensor<2xui8> +func.func @rewrite_quant_tensor() -> tensor<2xui8> { + %c = arith.constant dense<[10, 20]> : tensor<2xui8> + return %c : tensor<2xui8> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_quant_uniform_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<["10", "20"]> : tensor<2x!quant.uniform<i8:f32, 5.000000e-01>>}> : () -> tensor<2x!quant.uniform<i8:f32, 5.000000e-01>> +func.func @rewrite_quant_uniform_tensor() -> tensor<2x!quant.uniform<i8:f32, 0.5:0>> { + %c = arith.constant dense<["10", "20"]> : tensor<2x!quant.uniform<i8:f32, 0.5:0>> + return %c : tensor<2x!quant.uniform<i8:f32, 0.5:0>> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_fp8_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[1.000000e+00, -5.000000e-01]> : tensor<2xf8E4M3FN>}> : () -> tensor<2xf8E4M3FN> +func.func @rewrite_fp8_tensor() -> tensor<2xf8E4M3FN> { + %c = arith.constant dense<[1.0, -0.5]> : tensor<2xf8E4M3FN> + return %c : tensor<2xf8E4M3FN> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_mxint8_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<["0x00", "0x7F"]> : tensor<2x!tosa.mxint8>}> : () -> tensor<2x!tosa.mxint8> +func.func @rewrite_mxint8_tensor() -> tensor<2x!tosa.mxint8> { + %c = arith.constant dense<["0x00", "0x7F"]> : tensor<2x!tosa.mxint8> + return %c : tensor<2x!tosa.mxint8> +} diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir index c7eeb52..d4c4595 100644 --- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir +++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir @@ -98,3 +98,26 @@ func.func @depthwise_conv2d_no_const_zero_point(%arg0: tensor<4x10x10x2xi8>, %ar %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x10x10x6xi32> return %0 : tensor<4x10x10x6xi32> } + +// ----- +// CHECK-LABEL: func.func @depthwise_conv2d_as_mul_dynamic_batch_bias( +// CHECK-SAME: %[[INP:.*]]: tensor<?x10x10x2xf32>, +// CHECK-SAME: %[[WTS:.*]]: tensor<1x1x2x3xf32>, +// CHECK-SAME: %[[BIAS:.*]]: tensor<?xf32>) -> tensor<?x10x10x6xf32> { +// CHECK: %[[BIAS_EXPANDED_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 1, 1, -1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[RES_EXPANDED_SHAPE:.*]] = tosa.const_shape {values = dense<[-1, 10, 10, 6]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[MUL_SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[WTS_EXPANDED_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 1, 1, 2, 3]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[INP_EXPANDED_SHAPE:.*]] = tosa.const_shape {values = dense<[-1, 10, 10, 2, 1]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[INP_RESHAPED:.*]] = tosa.reshape %[[INP]], %[[INP_EXPANDED_SHAPE]] : (tensor<?x10x10x2xf32>, !tosa.shape<5>) -> tensor<?x10x10x2x1xf32> +// CHECK: %[[WTS_RESHAPED:.*]] = tosa.reshape %[[WTS]], %[[WTS_EXPANDED_SHAPE]] : (tensor<1x1x2x3xf32>, !tosa.shape<5>) -> tensor<1x1x1x2x3xf32> +// CHECK: %[[MUL:.*]] = tosa.mul %[[INP_RESHAPED]], %[[WTS_RESHAPED]], %[[MUL_SHIFT]] : (tensor<?x10x10x2x1xf32>, tensor<1x1x1x2x3xf32>, tensor<1xi8>) -> tensor<?x10x10x2x3xf32> +// CHECK: %[[RES_RESHAPED:.*]] = tosa.reshape %[[MUL]], %[[RES_EXPANDED_SHAPE]] : (tensor<?x10x10x2x3xf32>, !tosa.shape<4>) -> tensor<?x10x10x6xf32> +// CHECK: %[[BIAS_RESHAPED:.*]] = tosa.reshape %[[BIAS]], %[[BIAS_EXPANDED_SHAPE]] : (tensor<?xf32>, !tosa.shape<4>) -> tensor<1x1x1x?xf32> +// CHECK: %[[RES:.*]] = tosa.add %[[RES_RESHAPED]], %[[BIAS_RESHAPED]] : (tensor<?x10x10x6xf32>, tensor<1x1x1x?xf32>) -> tensor<?x10x10x6xf32> +// CHECK: return %[[RES]] +func.func @depthwise_conv2d_as_mul_dynamic_batch_bias(%arg0: tensor<?x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<?xf32>) -> tensor<?x10x10x6xf32> { + %zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<?x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<?xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x10x10x6xf32> + return %0 : tensor<?x10x10x6xf32> +} diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir index 810135f..61ca0ae 100644 --- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir +++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir @@ -181,3 +181,24 @@ func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 : (tensor<1x16x1x1xi8>, tensor<1x2x1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x19x2x1xi32> "func.return" (%2) : (tensor<1x19x2x1xi32>) -> () } + + +// ----- +// CHECK-LABEL: @transpose_conv2d_non_strided_dynamic_batch +// CHECK: tosa.conv2d +// CHECK-NOT: tosa.transpose_conv2d +func.func @transpose_conv2d_non_strided_dynamic_batch(%arg0: tensor<?x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) -> tensor<?x18x19x5xf32> { + %zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<?x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x18x19x5xf32> + return %0 : tensor<?x18x19x5xf32> +} + +// ----- +// CHECK-LABEL: @transpose_conv2d_strided_dynamic_batch +// CHECK: tosa.conv2d +// CHECK-NOT: tosa.transpose_conv2d +func.func @transpose_conv2d_strided_dynamic_batch(%arg0: tensor<?x17x15x3xf32>, %arg1: tensor<5x3x5x3xf32>, %arg2: tensor<5xf32>) -> tensor<?x35x47x5xf32> { + %zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 3>} : (tensor<?x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x35x47x5xf32> + return %0 : tensor<?x35x47x5xf32> +} diff --git a/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32-aggressive.mlir b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32-aggressive.mlir new file mode 100644 index 0000000..1a36177 --- /dev/null +++ b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32-aggressive.mlir @@ -0,0 +1,81 @@ +// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-i64-to-i32="aggressive-rewrite=1" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,DEFAULT +// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-i64-to-i32="aggressive-rewrite=1 convert-function-boundaries=1" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,FUNCBOUND + +// CHECK-LABEL: test_i64_argmax_large_axis_dim +func.func @test_i64_argmax_large_axis_dim(%arg0: tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi64> { + // DEFAULT: tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi32> + %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi64> + return %0 : tensor<1x513x513xi64> +} + +// ----- + +// CHECK-LABEL: test_convert_input_parameters +// DEFAULT: %[[IN:.*]]: tensor<1x513x513x3xi64> +// FUNCBOUND: %[[IN:.*]]: tensor<1x513x513x3xi32> +func.func @test_convert_input_parameters(%arg0: tensor<1x513x513x3xi64>) -> tensor<1x513x513x3xf32> { + // DEFAULT: %[[FUNC_BOUND_CAST:.*]] = tosa.cast %[[IN]] : (tensor<1x513x513x3xi64>) -> tensor<1x513x513x3xi32> + // DEFAULT: %[[CAST1:.*]] = tosa.cast %[[FUNC_BOUND_CAST]] : (tensor<1x513x513x3xi32>) -> tensor<1x513x513x3xi32> + // FUNCBOUND: %[[CAST1:.*]] = tosa.cast %[[IN]] : (tensor<1x513x513x3xi32>) -> tensor<1x513x513x3xi32> + %0 = tosa.cast %arg0 : (tensor<1x513x513x3xi64>) -> tensor<1x513x513x3xi32> + + // COMMON: %[[CAST2:.*]] = tosa.cast %[[CAST1]] : (tensor<1x513x513x3xi32>) -> tensor<1x513x513x3xf32> + %1 = tosa.cast %0 : (tensor<1x513x513x3xi32>) -> tensor<1x513x513x3xf32> + return %1 : tensor<1x513x513x3xf32> +} + +// ----- + +// CHECK-LABEL: test_add +// DEFAULT: %[[IN0:.*]]: tensor<13x21x1xi64>, %[[IN1:.*]]: tensor<13x21x3xi64> +// FUNCBOUND: %[[IN0:.*]]: tensor<13x21x1xi32>, %[[IN1:.*]]: tensor<13x21x3xi32> +func.func @test_add(%arg0: tensor<13x21x1xi64>, %arg1: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> { + // DEFAULT-DAG: %[[FUNC_BOUND_CAST0:.*]] = tosa.cast %[[IN0]] : (tensor<13x21x1xi64>) -> tensor<13x21x1xi32> + // DEFAULT-DAG: %[[FUNC_BOUND_CAST1:.*]] = tosa.cast %[[IN1]] : (tensor<13x21x3xi64>) -> tensor<13x21x3xi32> + // DEFAULT: %[[ADD:.*]] = tosa.add %[[FUNC_BOUND_CAST0]], %[[FUNC_BOUND_CAST1]] : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32> + // DEFAULT: %[[CAST:.*]] = tosa.cast %[[ADD]] : (tensor<13x21x3xi32>) -> tensor<13x21x3xi64> + // DEFAULT: return %[[CAST]] : tensor<13x21x3xi64> + // FUNCBOUND: %[[ADD:.*]] = tosa.add %[[IN0]], %[[IN1]] : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32> + // FUNCBOUND: return %[[ADD]] : tensor<13x21x3xi32> + %0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi64> + return %0 : tensor<13x21x3xi64> +} + +// ----- + +// CHECK-LABEL: test_regions +// DEFAULT: %[[IN0:.*]]: tensor<i64>, %[[IN1:.*]]: tensor<i64> +func.func @test_regions(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i1>) -> tensor<i64> { + // DEFAULT-DAG: %[[CAST0:.*]] = tosa.cast %[[IN0]] : (tensor<i64>) -> tensor<i32> + // DEFAULT-DAG: %[[CAST1:.*]] = tosa.cast %[[IN1]] : (tensor<i64>) -> tensor<i32> + // COMMON: %[[IF_RESULT:.*]] = tosa.cond_if + %0 = tosa.cond_if %arg2 : tensor<i1> -> (tensor<i64>) { + // DEFAULT: %[[ADD:.*]] = tosa.add %[[CAST0]], %[[CAST1]] : (tensor<i32>, tensor<i32>) -> tensor<i32> + // FUNCBOUND: %[[ADD:.*]] = tosa.add %[[IN0]], %[[IN1]] : (tensor<i32>, tensor<i32>) -> tensor<i32> + %1 = tosa.add %arg0, %arg1 : (tensor<i64>, tensor<i64>) -> tensor<i64> + // COMMON: tosa.yield %[[ADD]] : tensor<i32> + tosa.yield %1 : tensor<i64> + } else { + // DEFAULT: %[[SUB:.*]] = tosa.sub %[[CAST0]], %[[CAST1]] : (tensor<i32>, tensor<i32>) -> tensor<i32> + // FUNCBOUND: %[[SUB:.*]] = tosa.sub %[[IN0]], %[[IN1]] : (tensor<i32>, tensor<i32>) -> tensor<i32> + %1 = tosa.sub %arg0, %arg1 : (tensor<i64>, tensor<i64>) -> tensor<i64> + // COMMON: tosa.yield %[[SUB]] : tensor<i32> + tosa.yield %1 : tensor<i64> + } + // DEFAULT: %[[OUT:.*]] = tosa.cast %[[IF_RESULT]] : (tensor<i32>) -> tensor<i64> + // DEFAULT: return %[[OUT]] : tensor<i64> + // FUNCBOUND: return %[[IF_RESULT]] : tensor<i32> + return %0 : tensor<i64> +} + +// ----- + +// CHECK-LABEL: test_const +func.func @test_const() -> tensor<2xi64> { + // COMMON: %[[CONST:.*]] = "tosa.const"() <{values = dense<[1, 2]> : tensor<2xi32>}> : () -> tensor<2xi32> + %0 = "tosa.const"() <{values = dense<[1, 2]> : tensor<2xi64>}> : () -> tensor<2xi64> + // DEFAULT: %[[OUT:.*]] = tosa.cast %[[CONST]] : (tensor<2xi32>) -> tensor<2xi64> + // DEFAULT: return %[[OUT]] : tensor<2xi64> + // FUNCBOUND: return %[[CONST]] : tensor<2xi32> + return %0 : tensor<2xi64> +} diff --git a/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir new file mode 100644 index 0000000..a14483f --- /dev/null +++ b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir @@ -0,0 +1,162 @@ +// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-i64-to-i32="convert-function-boundaries=0" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,DEFAULT +// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-i64-to-i32="convert-function-boundaries=1" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,FUNCBOUND + +// ----- + +// CHECK-LABEL: test_i64_argmax +func.func @test_i64_argmax(%arg0: tensor<1x513x513x19xi8>) -> tensor<1x513x513xi64> { + // COMMON: %[[ARGMAX:.*]] = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x19xi8>) -> tensor<1x513x513xi32> + %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x19xi8>) -> tensor<1x513x513xi64> + + // DEFAULT: %[[CAST:.*]] = tosa.cast %[[ARGMAX]] : (tensor<1x513x513xi32>) -> tensor<1x513x513xi64> + // FUNCBOUND: return %[[ARGMAX]] : tensor<1x513x513xi32> + return %0 : tensor<1x513x513xi64> +} + +// ----- + +// CHECK-LABEL: test_i64_argmax_cast +func.func @test_i64_argmax_cast(%arg0: tensor<1x513x513x19xi8>) -> tensor<1x513x513xf32> { + // COMMON: %[[ARGMAX:.*]] = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x19xi8>) -> tensor<1x513x513xi32> + %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x19xi8>) -> tensor<1x513x513xi64> + // COMMON: tosa.cast %[[ARGMAX]] : (tensor<1x513x513xi32>) -> tensor<1x513x513xf32> + %1 = tosa.cast %0 : (tensor<1x513x513xi64>) -> tensor<1x513x513xf32> + return %1 : tensor<1x513x513xf32> +} + +// ----- + +// CHECK-LABEL: test_i64_argmax_large_axis_dim +func.func @test_i64_argmax_large_axis_dim(%arg0: tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi64> { + // expected-error @+1 {{failed to legalize operation 'tosa.argmax'}} + %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi64> + return %0 : tensor<1x513x513xi64> +} + +// ----- + +// CHECK-LABEL: test_add +func.func @test_add(%arg0: tensor<13x21x1xi64>, %arg1: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> { + // expected-error @+1 {{failed to legalize operation 'tosa.add'}} + %0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi64> + return %0 : tensor<13x21x3xi64> +} + +// ----- + +// CHECK-LABEL: test_regions +func.func @test_regions(%arg0: tensor<1x2xi32>, %arg1: tensor<1xi32>, %arg2: tensor<i1>) -> tensor<1xi32> { + // COMMON: %[[IF_RESULT:.*]] = tosa.cond_if %arg2 : tensor<i1> -> tensor<1xi32> + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<1xi32> { + // COMMON: %[[ARGMAX:.*]] = tosa.argmax %arg0 {axis = 1 : i32} : (tensor<1x2xi32>) -> tensor<1xi32> + %1 = tosa.argmax %arg0 {axis = 1 : i32} : (tensor<1x2xi32>) -> tensor<1xi64> + // COMMON: %[[CAST:.*]] = tosa.cast %[[ARGMAX]] : (tensor<1xi32>) -> tensor<1xi32> + %2 = tosa.cast %1 : (tensor<1xi64>) -> tensor<1xi32> + // COMMON: tosa.yield %[[CAST]] : tensor<1xi32> + tosa.yield %2 : tensor<1xi32> + } else { + tosa.yield %arg1 : tensor<1xi32> + } + // COMMON: return %[[IF_RESULT]] : tensor<1xi32> + return %0 : tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: test_concat +func.func @test_concat(%arg0: tensor<13x21x3xi64>, %arg1: tensor<13x21x3xi64>) -> tensor<26x21x3xi64> { + // COMMON: tosa.concat %{{.*}}, %{{.*}} {axis = 0 : i32} : (tensor<13x21x3xi32>, tensor<13x21x3xi32>) -> tensor<26x21x3xi32> + %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xi64>, tensor<13x21x3xi64>) -> tensor<26x21x3xi64> + return %0 : tensor<26x21x3xi64> +} + +// ----- + +// CHECK-LABEL: test_pad +func.func @test_pad(%arg0: tensor<13x21x3xi64>, %arg1: tensor<1xi64>) -> tensor<15x23x5xi64> { + %padding = tosa.const_shape {values = dense<1> : tensor<6xindex>} : () -> !tosa.shape<6> + // COMMON: tosa.pad %{{.*}}, %{{.*}}, %{{.*}} : (tensor<13x21x3xi32>, !tosa.shape<6>, tensor<1xi32>) -> tensor<15x23x5xi32> + %1 = tosa.pad %arg0, %padding, %arg1 : (tensor<13x21x3xi64>, !tosa.shape<6>, tensor<1xi64>) -> tensor<15x23x5xi64> + return %1 : tensor<15x23x5xi64> +} + +// ----- + +// CHECK-LABEL: test_reshape +func.func @test_reshape(%arg0: tensor<13x21x3xi64>) -> tensor<1x819xi64> { + %1 = tosa.const_shape {values = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2> + // COMMON: tosa.reshape %{{.*}}, %{{.*}} : (tensor<13x21x3xi32>, !tosa.shape<2>) -> tensor<1x819xi32> + %0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xi64>, !tosa.shape<2>) -> tensor<1x819xi64> + return %0 : tensor<1x819xi64> +} + +// ----- + +// CHECK-LABEL: test_reverse +func.func @test_reverse(%arg0: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> { + // COMMON: tosa.reverse %{{.*}} {axis = 0 : i32} : (tensor<13x21x3xi32>) -> tensor<13x21x3xi32> + %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xi64>) -> tensor<13x21x3xi64> + return %0 : tensor<13x21x3xi64> +} + +// ----- + +// CHECK-LABEL: test_slice +func.func @test_slice(%arg0: tensor<13x21x3xi64>) -> tensor<4x11x1xi64> { + %0 = tosa.const_shape {values = dense<[4, 11, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> + %1 = tosa.const_shape {values = dense<[6, 8, 0]> : tensor<3xindex>} : () -> !tosa.shape<3> + // COMMON: tosa.slice %{{.*}}, %{{.*}}, %{{.*}} : (tensor<13x21x3xi32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xi32> + %2 = tosa.slice %arg0, %0, %1 : (tensor<13x21x3xi64>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xi64> + return %2 : tensor<4x11x1xi64> +} + +// ----- + +// CHECK-LABEL: test_tile +func.func @test_tile(%arg0: tensor<13x21x3xi64>) -> tensor<39x21x6xi64> { + %cst = tosa.const_shape { values = dense<[3, 1, 2]> : tensor<3xindex> } : () -> !tosa.shape<3> + // COMMON: tosa.tile %{{.*}}, %{{.*}} : (tensor<13x21x3xi32>, !tosa.shape<3>) -> tensor<39x21x6xi32> + %0 = tosa.tile %arg0, %cst: (tensor<13x21x3xi64>, !tosa.shape<3>) -> tensor<39x21x6xi64> + return %0 : tensor<39x21x6xi64> +} + +// ----- + +// CHECK-LABEL: transpose +func.func @test_transpose(%arg0: tensor<13x21x3xi64>) -> tensor<3x13x21xi64> { + // COMMON: tosa.transpose %{{.*}} {perms = array<i32: 2, 0, 1>} : (tensor<13x21x3xi32>) -> tensor<3x13x21xi32> + %1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<13x21x3xi64>) -> tensor<3x13x21xi64> + return %1 : tensor<3x13x21xi64> +} + +// ----- + +// CHECK-LABEL: test_transition_to_i64 +func.func @test_transition_to_i64(%arg0: tensor<1xi32>) -> tensor<1xi64> { + // COMMON: %[[CAST:.*]] = tosa.cast %arg0 : (tensor<1xi32>) -> tensor<1xi32> + %0 = tosa.cast %arg0 : (tensor<1xi32>) -> tensor<1xi64> + // COMMON: %[[IDENTITY1:.*]] = tosa.identity %[[CAST]] : (tensor<1xi32>) -> tensor<1xi32> + %1 = tosa.identity %0 : (tensor<1xi64>) -> tensor<1xi64> + // COMMON: %[[IDENTITY2:.*]] = tosa.identity %[[IDENTITY1]] : (tensor<1xi32>) -> tensor<1xi32> + %2 = tosa.identity %1 : (tensor<1xi64>) -> tensor<1xi64> + // DEFAULT: %[[OUT_CAST:.*]] = tosa.cast %[[IDENTITY2]] : (tensor<1xi32>) -> tensor<1xi64> + // DEFAULT: return %[[OUT_CAST]] : tensor<1xi64> + // FUNCBOUND: return %[[IDENTITY2]] : tensor<1xi32> + return %2 : tensor<1xi64> +} + +// ----- + +// CHECK-LABEL: test_transition_from_i64 +func.func @test_transition_from_i64(%arg0: tensor<1xi64>) -> tensor<1xi32> { + // DEFAULT: %[[CAST:.*]] = tosa.cast %arg0 : (tensor<1xi64>) -> tensor<1xi32> + // DEFAULT: %[[IDENTITY1:.*]] = tosa.identity %[[CAST]] : (tensor<1xi32>) -> tensor<1xi32> + // FUNCBOUND: %[[IDENTITY1:.*]] = tosa.identity %arg0 : (tensor<1xi32>) -> tensor<1xi32> + %0 = tosa.identity %arg0 : (tensor<1xi64>) -> tensor<1xi64> + // COMMON: %[[IDENTITY2:.*]] = tosa.identity %[[IDENTITY1]] : (tensor<1xi32>) -> tensor<1xi32> + %1 = tosa.identity %0 : (tensor<1xi64>) -> tensor<1xi64> + // COMMON: %[[OUT_CAST:.*]] = tosa.cast %[[IDENTITY2]] : (tensor<1xi32>) -> tensor<1xi32> + %2 = tosa.cast %1 : (tensor<1xi64>) -> tensor<1xi32> + // COMMON: return %[[OUT_CAST]] : tensor<1xi32> + return %2 : tensor<1xi32> +} diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir index 9bd7aa8..f6b1edc 100644 --- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir +++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir @@ -2,6 +2,7 @@ // ----- +// CHECK-LABEL: test_matmul_fp8_mixed_precision_operands func.func @test_matmul_fp8_mixed_precision_operands(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> { %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN> %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2> @@ -53,14 +54,6 @@ func.func @test_const_fp6e3m2() -> tensor<4xf6E3M2FN> { // ----- -// CHECK-LABEL: test_cast_f4e2m1 -func.func @test_cast_f4e2m1(%arg0: tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16> { - %0 = tosa.cast %arg0 : (tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16> - return %0 : tensor<13x21x3xbf16> -} - -// ----- - // CHECK-LABEL: test_cast_from_block_scaled_fp8e5m2_fp32 func.func @test_cast_from_block_scaled_fp8e5m2_fp32(%arg0: tensor<4x32xf8E5M2>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> { %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf8E5M2>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> @@ -109,14 +102,6 @@ func.func @test_const_mxint8() -> tensor<2x!tosa.mxint8> { // ----- -// CHECK-LABEL: test_cast_f4e2m1 -func.func @test_cast_f4e2m1(%arg0: tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16> { - %0 = tosa.cast %arg0 : (tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16> - return %0 : tensor<13x21x3xbf16> -} - -// ----- - // CHECK-LABEL: test_matmul_t_block_scaled_mxint8 func.func @test_matmul_t_block_scaled_mxint8(%arg0: tensor<4x8x32x!tosa.mxint8>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32x!tosa.mxint8>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> { %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32x!tosa.mxint8>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32x!tosa.mxint8>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> @@ -130,3 +115,28 @@ func.func @test_cast_to_block_scaled_mxint8(%arg0: tensor<4x32xf32>) -> (tensor< %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) return %0#0, %0#1 : tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU> } + +// ----- + +// CHECK-LABEL: test_argmax_fp8_i64 +func.func @test_argmax_fp8_i64(%arg0: tensor<12x8x16xf8E5M2>) -> tensor<12x16xi64> { + %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<12x8x16xf8E5M2>) -> tensor<12x16xi64> + return %0 : tensor<12x16xi64> +} + +// ----- + +// CHECK-LABEL: test_argmax_bf16_i64 +func.func @test_argmax_bf16_i64(%arg0: tensor<12x8x16xbf16>) -> tensor<12x16xi64> { + %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<12x8x16xbf16>) -> tensor<12x16xi64> + return %0 : tensor<12x16xi64> +} + +// ----- + +// CHECK-LABEL: test_scatter_const_indices_int64 +func.func @test_scatter_const_indices_int64(%arg0: tensor<2x52x3xf32>, %arg2: tensor<2x12x3xf32>) -> tensor<2x52x3xf32> { + %indices = "tosa.const"() { values = dense<[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]> : tensor<2x12xi64> } : () -> tensor<2x12xi64> + %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi64>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32> + return %0 : tensor<2x52x3xf32> +} diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir index 6cf76cd..ea64d46 100644 --- a/mlir/test/Dialect/Tosa/verifier.mlir +++ b/mlir/test/Dialect/Tosa/verifier.mlir @@ -1222,3 +1222,11 @@ func.func @test_cast_to_block_scaled_data_scale_channel_mismatch(%arg0: tensor<4 %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>) return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU> } + +// ----- + +func.func @test_clamp_quantized(%arg0:tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) { + // expected-error@+1 {{'tosa.clamp' op min/max attributes types are incompatible with input/output element types.}} + %0 = tosa.clamp %arg0 {max_val = 127 : i8, min_val = -128 : i8} : (tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>>) -> tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>> + return %0 : tensor<?x112x112x32x!quant.uniform<u8:f32, 0.023529412224888802:-128>> +} diff --git a/mlir/test/Dialect/Transform/include-failure-propagation.mlir b/mlir/test/Dialect/Transform/include-failure-propagation.mlir new file mode 100644 index 0000000..94e9d8f --- /dev/null +++ b/mlir/test/Dialect/Transform/include-failure-propagation.mlir @@ -0,0 +1,38 @@ +// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --verify-diagnostics + +module attributes { transform.with_named_sequence } { + // Callee returns a silenceable failure when given a module instead of func.func. + transform.named_sequence @callee(%root: !transform.any_op {transform.consumed}) -> (!transform.any_op) { + transform.test_consume_operand_of_op_kind_or_fail %root, "func.func" : !transform.any_op + transform.yield %root : !transform.any_op + } + + transform.named_sequence @__transform_main(%root: !transform.any_op) { + %res = transform.sequence %root : !transform.any_op -> !transform.any_op failures(suppress) { + ^bb0(%arg0: !transform.any_op): + // This include returns a silenceable failure; it must not remap results. + %included = transform.include @callee failures(propagate) (%arg0) : (!transform.any_op) -> (!transform.any_op) + transform.yield %included : !transform.any_op + } + + %count = transform.num_associations %res : (!transform.any_op) -> !transform.param<i64> + // expected-remark @below {{0}} + transform.debug.emit_param_as_remark %count : !transform.param<i64> + + // If the include incorrectly forwarded mappings on failure, this would run + // and produce an unexpected remark under --verify-diagnostics. + transform.foreach %res : !transform.any_op { + ^bb0(%it: !transform.any_op): + transform.debug.emit_remark_at %it, "include result unexpectedly populated" : !transform.any_op + } + transform.yield + } +} + +// ----- + +module { + func.func @payload() { + return + } +} diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir index ce8f69c..4806daf7 100644 --- a/mlir/test/Dialect/Transform/test-pass-application.mlir +++ b/mlir/test/Dialect/Transform/test-pass-application.mlir @@ -386,7 +386,7 @@ module attributes {transform.with_named_sequence} { // ----- module attributes {transform.with_named_sequence} { - // expected-error @below {{trying to schedule a pass on an unsupported operation}} + // expected-error @below {{trying to schedule pass 'DuplicateFunctionEliminationPass' on an unsupported operation}} // expected-note @below {{target op}} func.func @invalid_target_op_type() { return diff --git a/mlir/test/Dialect/UB/ops.mlir b/mlir/test/Dialect/UB/ops.mlir index 724b6b4..730c1bd 100644 --- a/mlir/test/Dialect/UB/ops.mlir +++ b/mlir/test/Dialect/UB/ops.mlir @@ -38,3 +38,9 @@ func.func @poison_tensor() -> tensor<8x?xf64> { %0 = ub.poison : tensor<8x?xf64> return %0 : tensor<8x?xf64> } + +// CHECK-LABEL: func @unreachable() +// CHECK: ub.unreachable +func.func @unreachable() { + ub.unreachable +} diff --git a/mlir/test/Dialect/Vector/bufferize.mlir b/mlir/test/Dialect/Vector/bufferize.mlir index 887fb94..70adefd 100644 --- a/mlir/test/Dialect/Vector/bufferize.mlir +++ b/mlir/test/Dialect/Vector/bufferize.mlir @@ -32,6 +32,26 @@ func.func @transfer_write(%t: tensor<?x?xf32>, %o1: index, // ----- +// CHECK-LABEL: func @scatter( +// CHECK-SAME: %[[base:.*]]: tensor<16x16xf32>, %[[v:.*]]: vector<16xi32>, +// CHECK-SAME: %[[mask:.*]]: vector<16xi1>, %[[value:.*]]: vector<16xf32>) -> tensor<16x16xf32> +// CHECK: %[[buf:.*]] = bufferization.to_buffer %[[base]] : tensor<16x16xf32> to memref<16x16xf32> +// CHECK: %[[c0:.*]] = arith.constant 0 : index +// CHECK: %[[alloc:.*]] = memref.alloc() {alignment = 64 : i64} : memref<16x16xf32> +// CHECK: memref.copy %[[buf]], %[[alloc]] : memref<16x16xf32> to memref<16x16xf32> +// CHECK: vector.scatter %[[alloc]][%[[c0]], %[[c0]]] [%[[v]]], %[[mask]], %[[value]] : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> +// CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[alloc]] : memref<16x16xf32> to tensor<16x16xf32> +// CHECK: return %[[tensor]] : tensor<16x16xf32> +func.func @scatter(%base: tensor<16x16xf32>, %v: vector<16xi32>, + %mask: vector<16xi1>, %value: vector<16xf32>) -> tensor<16x16xf32> { + %c0 = arith.constant 0 : index + %0 = vector.scatter %base[%c0, %c0][%v], %mask, %value + : tensor<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<16x16xf32> + return %0 : tensor<16x16xf32> +} + +// ----- + // CHECK-LABEL: func @gather( // CHECK-SAME: %[[base:.*]]: tensor<?x?xf32>, %[[v:.*]]: vector<16xi32>, // CHECK-SAME: %[[mask:.*]]: vector<16xi1>, %[[pass_thru:.*]]: vector<16xf32>) diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 5f035e3..79b09e1 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1491,9 +1491,9 @@ func.func @gather_non_power_of_two_alignment(%base: memref<16xf32>, %indices: ve func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) { %c0 = arith.constant 0 : index - // expected-error@+2 {{custom op 'vector.scatter' invalid kind of type specified}} + // expected-error@+1 {{'vector.scatter' op operand #0 must be Tensor or MemRef of any type values, but got 'vector<16xf32>'}} vector.scatter %base[%c0][%indices], %mask, %pass_thru - : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> + : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> } // ----- diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index da9a1a8..de62022 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -1160,3 +1160,17 @@ func.func @step() { %1 = vector.step : vector<[4]xindex> return } + +// CHECK-LABEL: func @scatter_tensor( +// CHECK-SAME: %[[BASE:.*]]: tensor<16x16xf32>, %[[V:.*]]: vector<16xi32>, +// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>) -> tensor<16x16xf32> +func.func @scatter_tensor(%base: tensor<16x16xf32>, %v: vector<16xi32>, + %mask: vector<16xi1>, %value: vector<16xf32>) -> tensor<16x16xf32> { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + %c0 = arith.constant 0 : index + // CHECK: %[[RESULT:.*]] = vector.scatter %[[BASE]][%[[C0]], %[[C0]]] [%[[V]]], %[[MASK]], %[[VALUE]] + %0 = vector.scatter %base[%c0, %c0] [%v], %mask, %value + : tensor<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<16x16xf32> + // CHECK: return %[[RESULT]] : tensor<16x16xf32> + return %0 : tensor<16x16xf32> +} diff --git a/mlir/test/Dialect/Vector/vector-scan-transforms.mlir b/mlir/test/Dialect/Vector/vector-scan-transforms.mlir index 1d8f440..27a3653 100644 --- a/mlir/test/Dialect/Vector/vector-scan-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-scan-transforms.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --test-vector-scan-lowering | FileCheck %s +// RUN: mlir-opt %s -split-input-file --test-vector-scan-lowering | FileCheck %s // CHECK-LABEL: func @scan1d_inc // CHECK-SAME: %[[ARG0:.*]]: vector<2xi32>, @@ -18,6 +18,20 @@ func.func @scan1d_inc(%arg0 : vector<2xi32>, %arg1 : vector<i32>) -> (vector<2xi return %0#0, %0#1 : vector<2xi32>, vector<i32> } +// ----- + +// Reducing scalable dims is not yet supported! + +// CHECK-LABEL: func @scan1d_inc_scalable +// CHECK: vector.scan +func.func @scan1d_inc_scalable(%arg0 : vector<[2]xi32>, %arg1 : vector<i32>) -> (vector<[2]xi32>, vector<i32>) { + %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 0} : + vector<[2]xi32>, vector<i32> + return %0#0, %0#1 : vector<[2]xi32>, vector<i32> +} + +// ----- + // CHECK-LABEL: func @scan1d_exc // CHECK-SAME: %[[ARG0:.*]]: vector<2xi32>, // CHECK-SAME: %[[ARG1:.*]]: vector<i32> @@ -36,6 +50,20 @@ func.func @scan1d_exc(%arg0 : vector<2xi32>, %arg1 : vector<i32>) -> (vector<2xi return %0#0, %0#1 : vector<2xi32>, vector<i32> } +// ----- + +// Rducing scalable dims is not yet supported! + +// CHECK-LABEL: func @scan1d_exc_scalable +// CHECK: vector.scan +func.func @scan1d_exc_scalable(%arg0 : vector<[2]xi32>, %arg1 : vector<i32>) -> (vector<[2]xi32>, vector<i32>) { + %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = false, reduction_dim = 0} : + vector<[2]xi32>, vector<i32> + return %0#0, %0#1 : vector<[2]xi32>, vector<i32> +} + +// ----- + // CHECK-LABEL: func @scan2d_mul_dim0 // CHECK-SAME: %[[ARG0:.*]]: vector<2x3xi32>, // CHECK-SAME: %[[ARG1:.*]]: vector<3xi32> @@ -53,6 +81,27 @@ func.func @scan2d_mul_dim0(%arg0 : vector<2x3xi32>, %arg1 : vector<3xi32>) -> (v return %0#0, %0#1 : vector<2x3xi32>, vector<3xi32> } +// ----- + +// CHECK-LABEL: func @scan2d_mul_dim0_scalable +// CHECK-SAME: %[[ARG0:.*]]: vector<2x[3]xi32>, +// CHECK-SAME: %[[ARG1:.*]]: vector<[3]xi32> +// CHECK: %[[A:.*]] = arith.constant dense<0> : vector<2x[3]xi32> +// CHECK: %[[B:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x[3]xi32> to vector<1x[3]xi32> +// CHECK: %[[C:.*]] = vector.insert_strided_slice %[[B]], %[[A]] {offsets = [0, 0], strides = [1, 1]} : vector<1x[3]xi32> into vector<2x[3]xi32> +// CHECK: %[[D:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [1, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x[3]xi32> to vector<1x[3]xi32> +// CHECK: %[[E:.*]] = arith.muli %[[B]], %[[D]] : vector<1x[3]xi32> +// CHECK: %[[F:.*]] = vector.insert_strided_slice %[[E]], %[[C]] {offsets = [1, 0], strides = [1, 1]} : vector<1x[3]xi32> into vector<2x[3]xi32> +// CHECK: %[[G:.*]] = vector.shape_cast %[[E]] : vector<1x[3]xi32> to vector<[3]xi32> +// CHECK: return %[[F]], %[[G]] : vector<2x[3]xi32>, vector<[3]xi32> +func.func @scan2d_mul_dim0_scalable(%arg0 : vector<2x[3]xi32>, %arg1 : vector<[3]xi32>) -> (vector<2x[3]xi32>, vector<[3]xi32>) { + %0:2 = vector.scan <mul>, %arg0, %arg1 {inclusive = true, reduction_dim = 0} : + vector<2x[3]xi32>, vector<[3]xi32> + return %0#0, %0#1 : vector<2x[3]xi32>, vector<[3]xi32> +} + +// ----- + // CHECK-LABEL: func @scan2d_mul_dim1 // CHECK-SAME: %[[ARG0:.*]]: vector<2x3xi32>, // CHECK-SAME: %[[ARG1:.*]]: vector<2xi32> @@ -73,6 +122,30 @@ func.func @scan2d_mul_dim1(%arg0 : vector<2x3xi32>, %arg1 : vector<2xi32>) -> (v return %0#0, %0#1 : vector<2x3xi32>, vector<2xi32> } +// ----- + +// CHECK-LABEL: func @scan2d_mul_dim1_scalable +// CHECK-SAME: %[[ARG0:.*]]: vector<[2]x3xi32>, +// CHECK-SAME: %[[ARG1:.*]]: vector<[2]xi32> +// CHECK: %[[A:.*]] = arith.constant dense<0> : vector<[2]x3xi32> +// CHECK: %[[B:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]} : vector<[2]x3xi32> to vector<[2]x1xi32> +// CHECK: %[[C:.*]] = vector.insert_strided_slice %[[B]], %[[A]] {offsets = [0, 0], strides = [1, 1]} : vector<[2]x1xi32> into vector<[2]x3xi32> +// CHECK: %[[D:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]} : vector<[2]x3xi32> to vector<[2]x1xi32> +// CHECK: %[[E:.*]] = arith.muli %[[B]], %[[D]] : vector<[2]x1xi32> +// CHECK: %[[F:.*]] = vector.insert_strided_slice %[[E]], %[[C]] {offsets = [0, 1], strides = [1, 1]} : vector<[2]x1xi32> into vector<[2]x3xi32> +// CHECK: %[[G:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]} : vector<[2]x3xi32> to vector<[2]x1xi32> +// CHECK: %[[H:.*]] = arith.muli %[[E]], %[[G]] : vector<[2]x1xi32> +// CHECK: %[[I:.*]] = vector.insert_strided_slice %[[H]], %[[F]] {offsets = [0, 2], strides = [1, 1]} : vector<[2]x1xi32> into vector<[2]x3xi32> +// CHECK: %[[J:.*]] = vector.shape_cast %[[H]] : vector<[2]x1xi32> to vector<[2]xi32> +// CHECK: return %[[I]], %[[J]] : vector<[2]x3xi32>, vector<[2]xi32> +func.func @scan2d_mul_dim1_scalable(%arg0 : vector<[2]x3xi32>, %arg1 : vector<[2]xi32>) -> (vector<[2]x3xi32>, vector<[2]xi32>) { + %0:2 = vector.scan <mul>, %arg0, %arg1 {inclusive = true, reduction_dim = 1} : + vector<[2]x3xi32>, vector<[2]xi32> + return %0#0, %0#1 : vector<[2]x3xi32>, vector<[2]xi32> +} + +// ----- + // CHECK-LABEL: func @scan3d_mul_dim1 // CHECK-SAME: %[[ARG0:.*]]: vector<4x2x3xf32>, // CHECK-SAME: %[[ARG1:.*]]: vector<4x3xf32> @@ -89,3 +162,22 @@ func.func @scan3d_mul_dim1(%arg0 : vector<4x2x3xf32>, %arg1 : vector<4x3xf32>) - vector<4x2x3xf32>, vector<4x3xf32> return %0#0, %0#1 : vector<4x2x3xf32>, vector<4x3xf32> } + +// ----- + +// CHECK-LABEL: func @scan3d_mul_dim1_scalable +// CHECK-SAME: %[[ARG0:.*]]: vector<4x2x[3]xf32>, +// CHECK-SAME: %[[ARG1:.*]]: vector<4x[3]xf32> +// CHECK: %[[A:.*]] = arith.constant dense<0.000000e+00> : vector<4x2x[3]xf32> +// CHECK: %[[B:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x[3]xf32> to vector<4x1x[3]xf32> +// CHECK: %[[C:.*]] = vector.shape_cast %[[ARG1]] : vector<4x[3]xf32> to vector<4x1x[3]xf32> +// CHECK: %[[D:.*]] = vector.insert_strided_slice %[[C]], %[[A]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x[3]xf32> into vector<4x2x[3]xf32> +// CHECK: %[[E:.*]] = arith.mulf %[[C]], %[[B]] : vector<4x1x[3]xf32> +// CHECK: %[[F:.*]] = vector.insert_strided_slice %[[E]], %[[D]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x[3]xf32> into vector<4x2x[3]xf32> +// CHECK: %[[G:.*]] = vector.shape_cast %[[E]] : vector<4x1x[3]xf32> to vector<4x[3]xf32> +// CHECK: return %[[F]], %[[G]] : vector<4x2x[3]xf32>, vector<4x[3]xf32> +func.func @scan3d_mul_dim1_scalable(%arg0 : vector<4x2x[3]xf32>, %arg1 : vector<4x[3]xf32>) -> (vector<4x2x[3]xf32>, vector<4x[3]xf32>) { + %0:2 = vector.scan <mul>, %arg0, %arg1 {inclusive = false, reduction_dim = 1} : + vector<4x2x[3]xf32>, vector<4x[3]xf32> + return %0#0, %0#1 : vector<4x2x[3]xf32>, vector<4x[3]xf32> +} diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir index 577b06d..69fba88 100644 --- a/mlir/test/Dialect/Vector/vector-sink.mlir +++ b/mlir/test/Dialect/Vector/vector-sink.mlir @@ -382,6 +382,21 @@ func.func @broadcast_scalar_extsi_scalable(%a : i8) -> vector<2x[4]xi32> { return %r : vector<2x[4]xi32> } +// ----- + +// CHECK-LABEL: func.func @negative_broadcast_cast_non_vector_result +// CHECK-SAME: (%[[ARG:.*]]: i64) +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG]] : i64 to vector<26x7xi64> +// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[BCAST]] : vector<26x7xi64> to !llvm.array<26 x vector<7xi64>> +// CHECK: return %[[CAST]] : !llvm.array<26 x vector<7xi64>> +/// This test ensures that the `ReorderCastOpsOnBroadcast` pattern does not +/// attempt to reorder a cast operation that produces a non-vector result type. +func.func @negative_broadcast_cast_non_vector_result(%arg0: i64) -> !llvm.array<26 x vector<7xi64>> { + %0 = vector.broadcast %arg0 : i64 to vector<26x7xi64> + %1 = builtin.unrealized_conversion_cast %0 : vector<26x7xi64> to !llvm.array<26 x vector<7xi64>> + return %1 : !llvm.array<26 x vector<7xi64>> +} + //===----------------------------------------------------------------------===// // [Pattern: ReorderElementwiseOpsOnTranspose] //===----------------------------------------------------------------------===// @@ -780,7 +795,7 @@ func.func @negative_extract_load_scalable(%arg0: memref<?xf32>, %arg1: index) -> } //----------------------------------------------------------------------------- -// [Pattern: StoreOpFromSplatOrBroadcast] +// [Pattern: StoreOpFromBroadcast] //----------------------------------------------------------------------------- // CHECK-LABEL: @store_splat diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir index e5a98b5..805e66f 100644 --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -496,3 +496,137 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3 // CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32> // CHECK-NOT: arith.addf // CHECK: return + +func.func @vector_create_mask(%size1: index, %size2: index) -> vector<16x16xi1> { + %0 = vector.create_mask %size1, %size2 : vector<16x16xi1> + return %0 : vector<16x16xi1> +} + +// CHECK-LABEL: func @vector_create_mask +// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<16x16xi1> +// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<16x16xi1> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C8:.*]] = arith.constant 8 : index +// CHECK: %[[MAX0:.*]] = arith.maxsi %[[ARG0]], %[[C0]] : index +// CHECK: %[[MIN0:.*]] = arith.minsi %[[MAX0]], %[[C8]] : index +// CHECK: %[[MAX1:.*]] = arith.maxsi %[[ARG1]], %[[C0]] : index +// CHECK: %[[MIN1:.*]] = arith.minsi %[[MAX1]], %[[C8]] : index +// CHECK: %[[MASK00:.*]] = vector.create_mask %[[MIN0]], %[[MIN1]] : vector<8x8xi1> +// CHECK: %[[INS00:.*]] = vector.insert_strided_slice %[[MASK00]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: %[[MAX0_2:.*]] = arith.maxsi %[[ARG0]], %[[C0]] : index +// CHECK: %[[MIN0_2:.*]] = arith.minsi %[[MAX0_2]], %[[C8]] : index +// CHECK: %[[SUB1:.*]] = arith.subi %[[ARG1]], %[[C8]] : index +// CHECK: %[[MAX1_2:.*]] = arith.maxsi %[[SUB1]], %[[C0]] : index +// CHECK: %[[MIN1_2:.*]] = arith.minsi %[[MAX1_2]], %[[C8]] : index +// CHECK: %[[MASK01:.*]] = vector.create_mask %[[MIN0_2]], %[[MIN1_2]] : vector<8x8xi1> +// CHECK: %[[INS01:.*]] = vector.insert_strided_slice %[[MASK01]], %[[INS00]] {offsets = [0, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: %[[SUB0:.*]] = arith.subi %[[ARG0]], %[[C8]] : index +// CHECK: %[[MAX0_3:.*]] = arith.maxsi %[[SUB0]], %[[C0]] : index +// CHECK: %[[MIN0_3:.*]] = arith.minsi %[[MAX0_3]], %[[C8]] : index +// CHECK: %[[MAX1_3:.*]] = arith.maxsi %[[ARG1]], %[[C0]] : index +// CHECK: %[[MIN1_3:.*]] = arith.minsi %[[MAX1_3]], %[[C8]] : index +// CHECK: %[[MASK10:.*]] = vector.create_mask %[[MIN0_3]], %[[MIN1_3]] : vector<8x8xi1> +// CHECK: %[[INS10:.*]] = vector.insert_strided_slice %[[MASK10]], %[[INS01]] {offsets = [8, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: %[[SUB0_2:.*]] = arith.subi %[[ARG0]], %[[C8]] : index +// CHECK: %[[MAX0_4:.*]] = arith.maxsi %[[SUB0_2]], %[[C0]] : index +// CHECK: %[[MIN0_4:.*]] = arith.minsi %[[MAX0_4]], %[[C8]] : index +// CHECK: %[[SUB1_2:.*]] = arith.subi %[[ARG1]], %[[C8]] : index +// CHECK: %[[MAX1_4:.*]] = arith.maxsi %[[SUB1_2]], %[[C0]] : index +// CHECK: %[[MIN1_4:.*]] = arith.minsi %[[MAX1_4]], %[[C8]] : index +// CHECK: %[[MASK11:.*]] = vector.create_mask %[[MIN0_4]], %[[MIN1_4]] : vector<8x8xi1> +// CHECK: %[[INS11:.*]] = vector.insert_strided_slice %[[MASK11]], %[[INS10]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: return %[[INS11]] : vector<16x16xi1> + +func.func @vector_create_mask_constant_dim_sizes() -> vector<16x16xi1> { + %cst16 = arith.constant 16 : index + %0 = vector.create_mask %cst16, %cst16 : vector<16x16xi1> + return %0 : vector<16x16xi1> +} + +// CHECK-LABEL: func @vector_create_mask_constant_dim_sizes() -> vector<16x16xi1> { +// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<16x16xi1> +// CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<8x8xi1> +// CHECK: %[[S0:.*]] = vector.insert_strided_slice %[[CST_0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: %[[S1:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S0]] {offsets = [0, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: %[[S2:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S1]] {offsets = [8, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: %[[S3:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S2]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: return %[[S3]] : vector<16x16xi1> + + +func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> { + %0 = vector.shape_cast %v : vector<16xf32> to vector<2x2x4xf32> + return %0 : vector<2x2x4xf32> +} + +// CHECK-LABEL: func @shape_cast_1D +// CHECK-SAME: (%[[V:.*]]: vector<16xf32>) -> vector<2x2x4xf32> { +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x4xf32> +// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32> +// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<8xf32> to vector<2x4xf32> +// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32> +// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [8], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32> +// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<8xf32> to vector<2x4xf32> +// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32> +// CHECK: return %[[I1]] : vector<2x2x4xf32> + + +func.func @shape_cast_2D(%v: vector<8x2xf32>) -> vector<4x4xf32> { + %0 = vector.shape_cast %v : vector<8x2xf32> to vector<4x4xf32> + return %0 : vector<4x4xf32> +} + +// CHECK-LABEL: func @shape_cast_2D +// CHECK-SAME: (%[[V:.*]]: vector<8x2xf32>) -> vector<4x4xf32> { +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf32> +// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32> +// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<4x2xf32> to vector<2x4xf32> +// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32> +// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [4, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32> +// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<4x2xf32> to vector<2x4xf32> +// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32> +// CHECK: return %[[I1]] : vector<4x4xf32> + + +// This is a negative test case to ensure that such shape casts are not unrolled +// because the targetShape (2x4) is not contiguous in result vector +func.func @negative_shape_cast_target_shape_not_contiguous(%v: vector<64xf32>) -> vector<8x8xf32> { + %0 = vector.shape_cast %v : vector<64xf32> to vector<8x8xf32> + return %0 : vector<8x8xf32> +} + +// CHECK-LABEL: func @negative_shape_cast_target_shape_not_contiguous +// CHECK-SAME: (%[[V:.*]]: vector<64xf32>) -> vector<8x8xf32> { +// CHECK: %[[SC:.*]] = vector.shape_cast %[[V]] : vector<64xf32> to vector<8x8xf32> +// CHECK: return %[[SC]] : vector<8x8xf32> + + +// This is negative test case to ensure that such shape casts are not unrolled +// because it cannot determine the extractShape from source vector (8x3) +// to extract conitguous targetShape (2x4) +func.func @negative_shape_cast_source_shape_not_determinable(%v: vector<8x3xf32>) -> vector<6x4xf32> { + %0 = vector.shape_cast %v : vector<8x3xf32> to vector<6x4xf32> + return %0 : vector<6x4xf32> +} + +// CHECK-LABEL: func @negative_shape_cast_source_shape_not_determinable +// CHECK-SAME: (%[[V:.*]]: vector<8x3xf32>) -> vector<6x4xf32> { +// CHECK: %[[SC:.*]] = vector.shape_cast %[[V]] : vector<8x3xf32> to vector<6x4xf32> +// CHECK: return %[[SC]] : vector<6x4xf32> + + +// TargetShape is [1x16] +func.func @shape_cast_leading_unit_dim(%v: vector<32xf32>) -> vector<1x32xf32> { + %0 = vector.shape_cast %v : vector<32xf32> to vector<1x32xf32> + return %0 : vector<1x32xf32> +} + +// CHECK-LABEL: func @shape_cast_leading_unit_dim +// CHECK-SAME: (%[[V:.*]]: vector<32xf32>) -> vector<1x32xf32> { +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x32xf32> +// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0], sizes = [16], strides = [1]} : vector<32xf32> to vector<16xf32> +// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<16xf32> to vector<1x16xf32> +// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<1x16xf32> into vector<1x32xf32> +// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [16], sizes = [16], strides = [1]} : vector<32xf32> to vector<16xf32> +// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<16xf32> to vector<1x16xf32> +// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [0, 16], strides = [1, 1]} : vector<1x16xf32> into vector<1x32xf32> +// CHECK: return %[[I1]] : vector<1x32xf32> diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir new file mode 100644 index 0000000..e506b16 --- /dev/null +++ b/mlir/test/Dialect/X86Vector/vector-contract-to-fma.mlir @@ -0,0 +1,344 @@ +// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s + +!vecA = vector<1x1xf32> +!vecB = vector<1x64xf32> +!vecC = vector<1x64xf32> +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +func.func @matmul_outer_product_to_fma( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @matmul_outer_product_to_fma +// CHECK: vector.broadcast{{.*}}vector<1xf32> to vector<64xf32> +// CHECK: vector.fma{{.*}}vector<64xf32> +// CHECK: vector.shape_cast{{.*}}vector<64xf32> to vector<1x64xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<64x1xf32> +!vecB = vector<1x1xf32> +!vecC = vector<64x1xf32> +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +func.func @matmul_outer_product_to_fma_bcst_B( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @matmul_outer_product_to_fma_bcst_B +// CHECK: vector.broadcast +// CHECK: vector.fma{{.*}}vector<64xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1xf32> +!vecB = vector<1x1x64xf32> +!vecC = vector<1x1x64xf32> +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> +func.func @batch_matmul_to_fma( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @batch_matmul_to_fma +// CHECK: vector.broadcast +// CHECK: vector.fma{{.*}}vector<64xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x64x1xf32> +!vecB = vector<1x1x1xf32> +!vecC = vector<1x64x1xf32> +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> +func.func @batch_matmul_to_fma_bcst_B( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @batch_matmul_to_fma_bcst_B +// CHECK: vector.broadcast +// CHECK: vector.fma{{.*}}vector<64xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1xf32> +!vecB = vector<1x1x64xf32> +!vecC = vector<1x64xf32> +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> +func.func @brgemm_to_fma( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @brgemm_to_fma +// CHECK: vector.broadcast +// CHECK: vector.fma{{.*}}vector<64xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x64x1xf32> +!vecB = vector<1x1x1xf32> +!vecC = vector<64x1xf32> +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> +func.func @brgemm_to_fma_bcst_B( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @brgemm_to_fma_bcst_B +// CHECK: vector.broadcast +// CHECK: vector.fma{{.*}}vector<64xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<3x1x1xf32> +!vecB = vector<3x1x64xf32> +!vecC = vector<3x1x64xf32> +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> +func.func @negative_non_unit_batch_dim( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// Batch dimension should've been simplified earlier. + +// CHECK-LABEL: @negative_non_unit_batch_dim +// CHECK-NOT: vector.fma +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<3x1x1xf32> +!vecB = vector<3x1x64xf32> +!vecC = vector<1x64xf32> +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> +func.func @negative_non_unit_batch_reduce_dim( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// Batch-reduce dimension should've been simplified earlier. + +// CHECK-LABEL: @negative_non_unit_batch_reduce_dim +// CHECK-NOT: vector.fma +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1xf32> +!vecB = vector<1x64xf32> +!vecC = vector<1x64xf32> +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +func.func @negative_invalid_kind( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "reduction"], + kind = #vector.kind<mul>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_invalid_kind +// CHECK-NOT: vector.fma +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1xf32> +!vecB = vector<1x1x64xf32> +!vecC = vector<1x1x64xi32> +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> +func.func @negative_accumulator_type( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_accumulator_type +// CHECK-NOT: vector.fma +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_fma + } : !transform.any_op + transform.yield + } +} diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir new file mode 100644 index 0000000..65676cb --- /dev/null +++ b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir @@ -0,0 +1,681 @@ +// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s + +!vecA = vector<1x1x1x2xbf16> +!vecB = vector<1x1x16x2xbf16> +!vecC = vector<1x16xf32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)> +func.func @brgemm_to_bf16dp( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @brgemm_to_bf16dp +// CHECK: vector.broadcast +// CHECK: x86vector.avx512.dot + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x16x1x2xbf16> +!vecB = vector<1x1x1x2xbf16> +!vecC = vector<16x1xf32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)> +func.func @brgemm_to_bf16dp_bcst_B( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @brgemm_to_bf16dp_bcst_B +// CHECK: vector.broadcast +// CHECK: x86vector.avx512.dot + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x4xi8> +!vecB = vector<1x1x8x4xi8> +!vecC = vector<1x8xi32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)> +func.func @brgemm_to_int8dp( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @brgemm_to_int8dp +// CHECK: vector.broadcast +// CHECK: x86vector.avx.dot.i8 + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x2xbf16> +!vecB = vector<1x1x16x2xbf16> +!vecC = vector<1x1x16xf32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)> +func.func @batch_matmul_bf16dp( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @batch_matmul_bf16dp +// CHECK: vector.broadcast +// CHECK: x86vector.avx512.dot + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x4xi8> +!vecB = vector<1x1x8x4xi8> +!vecC = vector<1x1x8xi32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)> +func.func @batch_matmul_int8dp( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + + +// CHECK-LABEL: @batch_matmul_int8dp +// CHECK: vector.broadcast +// CHECK: x86vector.avx.dot.i8 + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x8x1x4xi8> +!vecB = vector<1x1x1x4xi8> +!vecC = vector<1x8x1xi32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)> +func.func @batch_matmul_int8dp_bcst_B( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + + +// CHECK-LABEL: @batch_matmul_int8dp_bcst_B +// CHECK: vector.broadcast +// CHECK: x86vector.avx.dot.i8 + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x2xbf16> +!vecB = vector<1x16x2xbf16> +!vecC = vector<1x16xf32> +#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)> +#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)> +#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)> +func.func @matmul_outer_product_to_bf16dp( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @matmul_outer_product_to_bf16dp +// CHECK: vector.broadcast +// CHECK: x86vector.avx512.dot + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<16x1x2xbf16> +!vecB = vector<1x1x2xbf16> +!vecC = vector<16x1xf32> +#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)> +#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)> +#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)> +func.func @matmul_outer_product_to_bf16dp_bcst_B( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @matmul_outer_product_to_bf16dp_bcst_B +// CHECK: vector.broadcast +// CHECK: x86vector.avx512.dot + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x4xi8> +!vecB = vector<1x8x4xi8> +!vecC = vector<1x8xi32> +#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)> +#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)> +#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)> +func.func @matmul_outer_product_to_int8dp( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @matmul_outer_product_to_int8dp +// CHECK: vector.broadcast +// CHECK: x86vector.avx.dot.i8 + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x2xbf16> +!vecB = vector<1x16x2xbf16> +!vecC = vector<1x16xf32> +#map = affine_map<(d4, d1, d2, d3) -> (d1, d3, d4)> +#map1 = affine_map<(d4, d1, d2, d3) -> (d3, d2, d4)> +#map2 = affine_map<(d4, d1, d2, d3) -> (d1, d2)> +func.func @negative_invalid_vc_kind( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<mul>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_invalid_vc_kind +// CHECK-NOT: x86vector.avx512.dot +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x4xbf16> +!vecB = vector<1x1x16x4xbf16> +!vecC = vector<1x16xf32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)> +func.func @negative_false_vnni_bf16( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_false_vnni_bf16 +// CHECK-NOT: x86vector.avx512.dot +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x2xi8> +!vecB = vector<1x1x8x2xi8> +!vecC = vector<1x8xi32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)> +func.func @negative_false_vnni_int8( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_false_vnni_int8 +// CHECK-NOT: x86vector.avx.dot.i8 +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<3x1x1x2xbf16> +!vecB = vector<3x1x16x2xbf16> +!vecC = vector<3x1x16xf32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)> +func.func @negative_batch_dimension( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_batch_dimension +// CHECK-NOT: x86vector.avx512.dot +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<2x1x1x4xi8> +!vecB = vector<2x1x8x4xi8> +!vecC = vector<1x8xi32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)> +func.func @negative_brgemm_dimension( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_brgemm_dimension +// CHECK-NOT: x86vector.avx.dot.i8 +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x2xbf16> +!vecB = vector<1x1x16x2xbf16> +!vecC = vector<1x1x16xbf16> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)> +func.func @negative_float_acc_type( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_float_acc_type +// CHECK-NOT: x86vector.avx512.dot +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x4xi8> +!vecB = vector<1x1x8x4xi8> +!vecC = vector<1x1x8xi8> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)> +func.func @negative_int_acc_type( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_int_acc_type +// CHECK-NOT: x86vector.avx.dot.i8 +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x4xbf16> +!vecB = vector<1x1x16x4xbf16> +!vecC = vector<1x1x16xbf16> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)> +func.func @negative_wrong_vnni_blocking_factor_bf16( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_wrong_vnni_blocking_factor_bf16 +// CHECK-NOT: x86vector.avx512.dot +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1xbf16> +!vecB = vector<1x1x32xbf16> +!vecC = vector<1x32xf32> +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> +func.func @negative_brgemm_not_vnni( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_brgemm_not_vnni +// CHECK-NOT: x86vector.avx512.dot +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x4xi8> +!vecB = vector<1x1x16x4xi8> +!vecC = vector<1x1x16xi32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)> +func.func @negative_wrong_vector_shape_int8( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_wrong_vector_shape_int8 +// CHECK-NOT: x86vector.avx.dot.i8 +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vecA = vector<1x1x1x2xbf16> +!vecB = vector<1x1x32x2xbf16> +!vecC = vector<1x1x32xf32> +#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)> +#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)> +#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)> +func.func @negative_wrong_vector_shape_bf16( + %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC +{ + %0 = vector.contract { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"], + kind = #vector.kind<add>} + %arg0, %arg1, %arg2 + : !vecA, !vecB into !vecC + return %0 : !vecC +} + +// CHECK-LABEL: @negative_wrong_vector_shape_bf16 +// CHECK-NOT: x86vector.avx512.dot +// CHECK: vector.contract + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product + } : !transform.any_op + transform.yield + } +} diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index 92f3537..67faa60 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -836,7 +836,7 @@ func.func @slice_attr_repeat_dim() { // ----- func.func @create_mem_desc_non_slm() { %m = memref.alloca() {alignment = 1024} : memref<2048xi8, 1> - // expected-error@+1 {{operand #0 must be statically shaped memref of 8-bit signless integer values for shared memory}} + // expected-error@+1 {{operand #0 must be reside in share memory and statically 1d shaped memref }} %mem_desc = xegpu.create_mem_desc %m : memref<2048xi8, 1> -> !xegpu.mem_desc<16x64xf16> return } diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir index 9b38296..1e9738f 100644 --- a/mlir/test/Dialect/XeGPU/ops.mlir +++ b/mlir/test/Dialect/XeGPU/ops.mlir @@ -834,6 +834,27 @@ gpu.func @create_mem_desc_with_stride() { gpu.return } + +// CHECK-LABEL: gpu.func @create_mem_desc_from_2d_memref({{.*}}) { +gpu.func @create_mem_desc_from_2d_memref() { + //CHECK: [[alloc:%.+]] = memref.alloca() {alignment = 1024 : i64} : memref<16x64xf16, 3> + //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[alloc]] : memref<16x64xf16, 3> -> !xegpu.mem_desc<16x64xf16> + %m = memref.alloca() {alignment = 1024} : memref<16x64xf16, 3> + %mem_desc = xegpu.create_mem_desc %m : memref<16x64xf16, 3> -> !xegpu.mem_desc<16x64xf16> + gpu.return +} + +// CHECK-LABEL: gpu.func @create_mem_desc_with_stride_from_2d_memref({{.*}}) { +gpu.func @create_mem_desc_with_stride_from_2d_memref() { + //CHECK: %[[ALLOC:.+]] = memref.alloca() {alignment = 1024 : i64} : memref<32x64xf16, 3> + //CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][16, 0] [16, 64] [1, 1] : memref<32x64xf16, 3> to memref<16x64xf16, strided<[64, 1], offset: 1024>, 3> + //CHECK: %{{.+}} = xegpu.create_mem_desc %[[SUBVIEW]] : memref<16x64xf16, strided<[64, 1], offset: 1024>, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> + %m = memref.alloca() {alignment = 1024} : memref<32x64xf16, 3> + %m_sub = memref.subview %m[16, 0][16, 64][1,1] : memref<32x64xf16, 3> to memref<16x64xf16, strided<[64, 1], offset: 1024>, 3> + %mem_desc = xegpu.create_mem_desc %m_sub : memref<16x64xf16, strided<[64, 1], offset: 1024>, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> + gpu.return +} + // CHECK: gpu.func @load_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>) gpu.func @load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) { // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16> diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir index c31ef32..32fb317 100644 --- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir +++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir @@ -1,18 +1,45 @@ // RUN: mlir-opt -xevm-attach-target='chip=pvc' -xegpu-propagate-layout="layout-kind=inst" -split-input-file %s | FileCheck %s + +// CHECK-LABEL: func.func @load_store_no_array_len( +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x32xf32>, %[[ARG1:[0-9a-zA-Z]+]]: memref<8x32xf32>) { +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32> +// CHECK: %[[TDESC_SRC:.*]] = xegpu.create_nd_tdesc %[[ARG0]] : memref<8x32xf32> -> !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>> +// CHECK: %[[TDESC_DST:.*]] = xegpu.create_nd_tdesc %[[ARG1]] : memref<8x32xf32> -> !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>> +// CHECK: xegpu.prefetch_nd %[[TDESC_SRC]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, layout = #xegpu.layout<inst_data = [8, 16]>}> : +// CHECK-SAME: !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>> +// CHECK: %[[LOADED:.*]] = xegpu.load_nd %0 <{layout = #xegpu.layout<inst_data = [8, 16]>}> {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : +// CHECK-SAME: !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<8x32xf32> +// CHECK: xegpu.store_nd %[[LOADED]], %[[TDESC_DST]] <{layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>> +gpu.module @test { +// Although the uArch allows 8x32 inst data using block count (or array_len), +// it is up to optimization passes to decide on the block count usage. +func.func @load_store_no_array_len(%arg0: memref<8x32xf32>, %arg1: memref<8x32xf32>) { + %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32> + %0 = xegpu.create_nd_tdesc %arg0 : memref<8x32xf32> -> !xegpu.tensor_desc<8x32xf32> + %1 = xegpu.create_nd_tdesc %arg1 : memref<8x32xf32> -> !xegpu.tensor_desc<8x32xf32> + xegpu.prefetch_nd %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<8x32xf32> + %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x32xf32> -> vector<8x32xf32> + xegpu.store_nd %2, %1 : vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32> + return +} +} + +// ----- + // CHECK-LABEL: func.func @dpas_f16( // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) { // CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} dense<0.000000e+00> : vector<8x16xf32> // CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<inst_data = [8, 16]> // CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [16, 16]>> -// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : +// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]] <{layout = #xegpu.layout<inst_data = [8, 16]>}> {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : // CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<inst_data = [8, 16]>> -> vector<8x16xf16> -// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T1]] {layout_result_0 = #xegpu.layout<inst_data = [16, 16]>} : +// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T1]] <{layout = #xegpu.layout<inst_data = [16, 16]>}> {layout_result_0 = #xegpu.layout<inst_data = [16, 16]>} : // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [16, 16]>> -> vector<16x16xf16> -// CHECK: %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[CST]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : +// CHECK: %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[CST]] {layout_a = #xegpu.layout<inst_data = [8, 16]>, layout_b = #xegpu.layout<inst_data = [16, 16]>, layout_cd = #xegpu.layout<inst_data = [8, 16]>, layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : // CHECK-SAME: vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> // CHECK: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<inst_data = [8, 16]> -// CHECK: xegpu.store_nd %[[T4]], %[[T5]] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<inst_data = [8, 16]>> +// CHECK: xegpu.store_nd %[[T4]], %[[T5]] <{layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<inst_data = [8, 16]>> gpu.module @test { func.func @dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) { @@ -46,7 +73,7 @@ gpu.module @test_kernel { %out:3 = scf.for %k = %c0 to %c1024 step %c32 iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_tdesc) -> (!xegpu.tensor_desc<16x32xf16>, !xegpu.tensor_desc<16x32xf16>, !xegpu.tensor_desc<16x32xf16>) { - //CHECK: xegpu.load_nd {{.*}} {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : + //CHECK: xegpu.load_nd {{.*}} <{layout = #xegpu.layout<inst_data = [8, 16]>}> {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : //CHECK-SAME: !xegpu.tensor_desc<16x32xf16, #xegpu.layout<inst_data = [8, 16]>> -> vector<16x32xf16> %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x32xf16> -> vector<16x32xf16> %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x32xf16> -> vector<16x32xf16> @@ -85,7 +112,7 @@ gpu.module @test_kernel { %out:3 = scf.for %k = %c0 to %c1024 step %c32 iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_tdesc) -> (!xegpu.tensor_desc<12x32xf16>, !xegpu.tensor_desc<12x32xf16>, !xegpu.tensor_desc<12x32xf16>) { - //CHECK: xegpu.load_nd {{.*}} {layout_result_0 = #xegpu.layout<inst_data = [4, 16]>} : + //CHECK: xegpu.load_nd {{.*}} <{layout = #xegpu.layout<inst_data = [4, 16]>}> {layout_result_0 = #xegpu.layout<inst_data = [4, 16]>} : //CHECK-SAME: !xegpu.tensor_desc<12x32xf16, #xegpu.layout<inst_data = [4, 16]>> -> vector<12x32xf16> %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<12x32xf16> -> vector<12x32xf16> %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<12x32xf16> -> vector<12x32xf16> @@ -113,9 +140,9 @@ gpu.module @test { // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) { // CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1> // CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex> -// CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 8 : i64}> +// CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 8 : i64, layout = #xegpu.layout<inst_data = [16, 8]>}> // CHECK-SAME: {layout_result_0 = #xegpu.layout<inst_data = [16, 8]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16> -// CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 8 : i64}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> +// CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 8 : i64, layout = #xegpu.layout<inst_data = [16, 8]>}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> func.func @scatter_ops_chunksize(%src: memref<256xf16>) { %1 = arith.constant dense<1>: vector<16xi1> %offset = arith.constant dense<12> : vector<16xindex> diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir index eb004932..48e77d8 100644 --- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir +++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir @@ -6,14 +6,14 @@ gpu.module @test { // CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<8x16xf32> // CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : +// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : // CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16> -// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T1]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : +// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T1]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16> -// CHECK: %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[CST]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : +// CHECK: %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[CST]] {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : // CHECK-SAME: vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> // CHECK: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -// CHECK: xegpu.store_nd %[[T4]], %[[T5]] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> +// CHECK: xegpu.store_nd %[[T4]], %[[T5]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> func.func @dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) { %c0 = arith.constant 0 : index %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32> @@ -32,7 +32,8 @@ func.func @dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: me gpu.module @test { // CHECK-LABEL: func.func @dpas_i8( // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: vector<8x32xi8>, %[[ARG1:[0-9a-zA-Z]+]]: vector<32x16xi8>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xi32>) { -// CHECK: %[[T0:.*]] = xegpu.dpas %[[ARG0]], %[[ARG1]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], +// CHECK: %[[T0:.*]] = xegpu.dpas %[[ARG0]], %[[ARG1]] {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} + func.func @dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memref<8x16xi32>) { %c0 = arith.constant 0 : index %0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32> @@ -46,8 +47,8 @@ func.func @dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memre gpu.module @test { // CHECK-LABEL: func.func @load_with_transpose_effect( // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG0:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf32>) { -// CHECK: %{{.*}} = xegpu.load_nd %{{.*}} <{transpose = array<i64: 1, 0>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : -// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>> -> vector<16x16xf16> +// CHECK: %{{.*}} = xegpu.load_nd %{{.*}} <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : +// CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16> func.func @load_with_transpose_effect(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) { %c0 = arith.constant 0 : index %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32> @@ -108,7 +109,7 @@ gpu.module @test { // CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1> // CHECK-NEXT: %[[T2:.*]] = xegpu.create_tdesc %[[ARG1]], %[[CST]] : memref<256xf16>, vector<16xindex> -> // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>> -// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} +// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}> {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x16xf16> func.func @load_gather_with_chunksize(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) { %c0 = arith.constant 0 : index @@ -135,7 +136,7 @@ gpu.module @test { // CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1> // CHECK-NEXT: %[[T0:.*]] = xegpu.create_tdesc %[[ARG0]], %[[CST]] : memref<256xf32>, vector<16xindex> -> // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -// CHECK-NEXT: %{{.*}} = xegpu.load %[[T0]], %[[CST0]] {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : +// CHECK-NEXT: %{{.*}} = xegpu.load %[[T0]], %[[CST0]] <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [16], lane_data = [1]>>, vector<16xi1> -> vector<16xf32> func.func @load_gather_1d(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>) { %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex> @@ -183,9 +184,9 @@ gpu.module @test { // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) { // CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1> // CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex> -// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}> +// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}> // CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16> -// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> +// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> func.func @scatter_ops_chunksize(%src: memref<256xf16>) { %1 = arith.constant dense<1>: vector<16xi1> %offset = arith.constant dense<12> : vector<16xindex> @@ -204,7 +205,7 @@ gpu.module @test { // CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex> // CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] // CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16> -// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> +// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}> : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> func.func @scatter_ops(%src: memref<256xf16>) { %1 = arith.constant dense<1>: vector<16xi1> %offset = arith.constant dense<12> : vector<16xindex> @@ -217,13 +218,13 @@ func.func @scatter_ops(%src: memref<256xf16>) { gpu.module @test { // CHECK-LABEL: func.func @scatter_ops_custom_perm_layout( // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) { -// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1> -// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex> +// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<true> : vector<16xi1> +// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<12> : vector<16xindex> // CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] -// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16> +// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16> // CHECK: %[[ADD_RES:.*]] = arith.addf %[[LOAD_VEC]], %[[LOAD_VEC]] {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} : vector<16xf16> // CHECK: xegpu.store %[[ADD_RES]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] -// CHECK-SAME <{layout = #xegpu.layout<lane_layout = [8], lane_data = [1]>}> : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> +// CHECK-SAME <{layout = #xegpu.layout<lane_layout = [8], lane_data = [1]>}> : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> func.func @scatter_ops_custom_perm_layout(%src: memref<256xf16>) { %1 = arith.constant dense<1>: vector<16xi1> %offset = arith.constant dense<12> : vector<16xindex> @@ -237,9 +238,9 @@ func.func @scatter_ops_custom_perm_layout(%src: memref<256xf16>) { gpu.module @test { // CHECK-LABEL: func.func @scatter_ops_preserve_load_perm_layout( // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) { -// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1> -// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex> -// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}> +// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<true> : vector<16xi1> +// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<12> : vector<16xindex> +// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] // CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16> // CHECK: %[[ADD_RES:.*]] = arith.addf %[[LOAD_VEC]], %[[LOAD_VEC]] {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} : vector<16xf16> // CHECK: xegpu.store %[[ADD_RES]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] @@ -256,9 +257,9 @@ func.func @scatter_ops_preserve_load_perm_layout(%src: memref<256xf16>) { // ----- gpu.module @test { // CHECK-LABEL: func.func @vector_bitcast_i16_to_f16( -// CHECK: %[[LOAD0:.*]] = xegpu.load_nd %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} +// CHECK: %[[LOAD0:.*]] = xegpu.load_nd %{{.*}} <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} // CHECK-SAME: !xegpu.tensor_desc<8x16xi16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xi16> -// CHECK: %[[LOAD1:.*]] = xegpu.load_nd %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} +// CHECK: %[[LOAD1:.*]] = xegpu.load_nd %{{.*}} <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} // CHECK-SAME: !xegpu.tensor_desc<16x16xi16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xi16> // CHECK: %{{.*}} = vector.bitcast %[[LOAD0]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} // CHECK-SAME: vector<8x16xi16> to vector<8x16xf16> @@ -281,7 +282,7 @@ func.func @vector_bitcast_i16_to_f16(%arg0: memref<8x16xi16>, %arg1: memref<16x1 // ----- gpu.module @test { // CHECK-LABEL: func.func @vector_bitcast_i32_to_f16( -// CHECK: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} +// CHECK: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} // CHECK-SAME: !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32> // CHECK-NEXT: %{{.*}} = vector.bitcast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} // CHECK-SAME: vector<16x8xi32> to vector<16x16xf16> @@ -302,7 +303,7 @@ func.func @vector_bitcast_i32_to_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x8 // ----- gpu.module @test { // CHECK-LABEL: func.func @vector_bitcast_i16_to_i32( -// CHECK: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>} +// CHECK: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>} // CHECK-SAME: !xegpu.tensor_desc<8x32xi16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>> -> vector<8x32xi16> // CHECK-NEXT: %{{.*}} = vector.bitcast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} // CHECK-SAME: vector<8x32xi16> to vector<8x16xi32> @@ -339,9 +340,9 @@ gpu.module @test { // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, // CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>, // CHECK-SAME: %[[ARG2:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) { -// CHECK: %[[T1:.*]] = xegpu.load_nd %[[ARG1]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : +// CHECK: %[[T1:.*]] = xegpu.load_nd %[[ARG1]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16> -// CHECK-NEXT: %[[T2:.*]] = xegpu.load_nd %[[ARG1]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : +// CHECK-NEXT: %[[T2:.*]] = xegpu.load_nd %[[ARG1]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16> // CHECK-NEXT: %{{.*}} = arith.addf %[[T1]], %[[T2]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x16xf16> func.func @binary_op_one_use(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>) { @@ -362,9 +363,9 @@ gpu.module @test { // CHECK-SAME: %[[ARG2:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, // CHECK-SAME: %[[ARG3:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) { // CHECK: %[[T2:.*]] = arith.addf %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf16> -// CHECK: %[[T3:.*]] = xegpu.dpas %{{.*}}, %[[T2]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> -// CHECK-NEXT: xegpu.store_nd %[[T3]], %[[ARG2]] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -// CHECK-NEXT: xegpu.store_nd %[[T2]], %[[ARG3]] : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> +// CHECK: %[[T3:.*]] = xegpu.dpas %{{.*}}, %[[T2]] {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> +// CHECK-NEXT: xegpu.store_nd %[[T3]], %[[ARG2]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> +// CHECK-NEXT: xegpu.store_nd %[[T2]], %[[ARG3]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> func.func @binary_op_multiple_uses(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>, %arg3: !xegpu.tensor_desc<16x16xf16>) { %0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> %1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> @@ -385,11 +386,11 @@ gpu.module @test { // CHECK-NEXT: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<8x16xf32> // CHECK-NEXT: %[[T2:.*]]:3 = scf.for %{{.*}} iter_args(%[[ARG4:.*]] = %[[T0]], %[[ARG5:.*]] = %[[T1]], %[[ARG6:.*]] = %[[CST]]) -> // CHECK-SAME: (!xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>, vector<8x16xf32>) { -// CHECK-NEXT: %[[T4:.*]] = xegpu.load_nd %[[ARG4]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : +// CHECK-NEXT: %[[T4:.*]] = xegpu.load_nd %[[ARG4]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : // CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16> -// CHECK-NEXT: %[[T5:.*]] = xegpu.load_nd %[[ARG5]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : +// CHECK-NEXT: %[[T5:.*]] = xegpu.load_nd %[[ARG5]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16> -// CHECK-NEXT: %[[T6:.*]] = xegpu.dpas %[[T4]], %[[T5]], %[[ARG6]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : +// CHECK-NEXT: %[[T6:.*]] = xegpu.dpas %[[T4]], %[[T5]], %[[ARG6]] {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : // CHECK-SAME: vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> // CHECK-NEXT: %[[T7:.*]] = xegpu.update_nd_offset %[[ARG4]], [{{.*}}] : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK-NEXT: %[[T8:.*]] = xegpu.update_nd_offset %[[ARG5]], [{{.*}}] : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> @@ -397,7 +398,7 @@ gpu.module @test { // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>, vector<8x16xf32> // CHECK-NEXT: } {layout_result_2 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} // CHECK-NEXT: %[[T3:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -// CHECK-NEXT: xegpu.store_nd %[[T2]]#2, %[[T3]] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> +// CHECK-NEXT: xegpu.store_nd %[[T2]]#2, %[[T3]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> func.func @for_op(%arg0: memref<8x128xf16>, %arg1: memref<128x16xf16>, %arg2: memref<8x16xf32>) { %c0 = arith.constant 0 : index %c128 = arith.constant 128 : index @@ -425,11 +426,11 @@ gpu.module @test { // CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>, // CHECK-SAME: %[[ARG2:[0-9a-zA-Z]+]]: i1, %[[ARG3:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) { // CHECK: %{{.*}} = scf.if %[[ARG2]] -> (vector<16x16xf16>) { -// CHECK-NEXT: %[[T3:.*]] = xegpu.load_nd %[[ARG1]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : +// CHECK-NEXT: %[[T3:.*]] = xegpu.load_nd %[[ARG1]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16> // CHECK-NEXT: scf.yield %[[T3]] : vector<16x16xf16> // CHECK-NEXT: } else { -// CHECK-NEXT: %[[T4:.*]] = xegpu.load_nd %[[ARG1]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : +// CHECK-NEXT: %[[T4:.*]] = xegpu.load_nd %[[ARG1]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16> // CHECK-NEXT: scf.yield %[[T4]] : vector<16x16xf16> // CHECK-NEXT: } {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} @@ -455,11 +456,11 @@ gpu.module @test { // CHECK-SAME: %[[ARG2:[0-9a-zA-Z]+]]: i1, %[[ARG3:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, // CHECK-SAME: %[[ARG4:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) { // CHECK: %[[T1:.*]] = scf.if %[[ARG2]] -> (vector<16x16xf16>) { -// CHECK-NEXT: %[[T3:.*]] = xegpu.load_nd %[[ARG1]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : +// CHECK-NEXT: %[[T3:.*]] = xegpu.load_nd %[[ARG1]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16> // CHECK-NEXT: scf.yield %[[T3]] : vector<16x16xf16> // CHECK-NEXT: } else { -// CHECK-NEXT: %[[T4:.*]] = xegpu.load_nd %[[ARG1]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : +// CHECK-NEXT: %[[T4:.*]] = xegpu.load_nd %[[ARG1]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16> // CHECK-NEXT: scf.yield %[[T4]] : vector<16x16xf16> // CHECK-NEXT: } {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} @@ -539,7 +540,7 @@ gpu.module @test { // CHECK-LABEL: func.func @prefetch_2d( // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256x256xf16>) { // CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}, %{{.*}}] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -// CHECK-NEXT: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> +// CHECK-NEXT: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> func.func @prefetch_2d(%arg0: memref<256x256xf16>){ %c0 = arith.constant 0 : index %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16> @@ -552,7 +553,7 @@ gpu.module @test { // CHECK-LABEL: func.func @prefetch_1d( // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) { // CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256xf16> -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -// CHECK-NEXT: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> +// CHECK-NEXT: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}> : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> func.func @prefetch_1d(%arg0: memref<256xf16>){ %c0 = arith.constant 0 : index %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf16> -> !xegpu.tensor_desc<16xf16> @@ -599,7 +600,7 @@ gpu.module @test { // CHECK-LABEL: func.func @vector_shape_cast_1d_to_2d_dim1_distributed( // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, // CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) { -// CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[ARG0]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} +// CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[ARG0]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16> // CHECK-NEXT: %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %{{[0-9a-zA-Z]+}} // CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} [0] : vector<16x16xf16> to vector<16xf16> @@ -621,7 +622,7 @@ gpu.module @test { // CHECK-LABEL: func.func @vector_shape_cast_1d_to_2d_dim0_broadcasted( // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, // CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) { -// CHECK: %[[LOAD:.*]] = xegpu.load_nd %arg0 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} +// CHECK: %[[LOAD:.*]] = xegpu.load_nd %arg0 <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16> // CHECK-NEXT: %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %{{[0-9a-zA-Z]+}} // CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>} [1] @@ -639,3 +640,61 @@ func.func @vector_shape_cast_1d_to_2d_dim0_broadcasted(%arg0: !xegpu.tensor_desc return } } +// ----- +gpu.module @test { +// CHECK-LABEL: func.func @vector_broadcast_1d_to_2d_broadcast_along_row( +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, +// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) { +// CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[ARG0]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} +// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16> +// CHECK-NEXT: %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %{{[0-9a-zA-Z]+}} +// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} [0] : vector<16x16xf16> to vector<16xf16> +// CHECK-NEXT: %[[BROADCAST:.*]] = vector.broadcast %[[REDUCE]] +// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x16xf16> +func.func @vector_broadcast_1d_to_2d_broadcast_along_row(%arg0: !xegpu.tensor_desc<16x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) { + %c0 = arith.constant 0 : index + %cst = arith.constant dense<0.0000> : vector<16xf16> + %3 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> + %4 = vector.multi_reduction <add>, %3, %cst [0] : vector<16x16xf16> to vector<16xf16> + %5 = vector.broadcast %4 : vector<16xf16> to vector<16x16xf16> + xegpu.store_nd %5, %arg1 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16> + return +} +} + +// ----- +gpu.module @test { +// CHECK-LABEL: func.func @vector_broadcast_2d_to_2d_along_column( +// CHECK: %[[REDUCE:.*]] = vector.multi_reduction <add> +// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>} [1] : vector<16x16xf16> to vector<16xf16> +// CHECK-NEXT: %[[SHAPECAST:.*]] = vector.shape_cast %[[REDUCE]] +// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x1xf16> +// CHECK-NEXT: vector.broadcast %[[SHAPECAST]] +// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x1xf16> to vector<16x16xf16> + +func.func @vector_broadcast_2d_to_2d_along_column(%arg0: !xegpu.tensor_desc<16x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) { + %c0 = arith.constant 0 : index + %cst = arith.constant dense<0.0000> : vector<16xf16> + %3 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> + %4 = vector.multi_reduction <add>, %3, %cst [1] : vector<16x16xf16> to vector<16xf16> + %5 = vector.shape_cast %4 : vector<16xf16> to vector<16x1xf16> + %6 = vector.broadcast %5 : vector<16x1xf16> to vector<16x16xf16> + xegpu.store_nd %6, %arg1 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16> + return +} +} + +// ----- +gpu.module @test { +// CHECK-LABEL: func.func @vector_broadcast_scalar_to_vector( +// CHECK: %[[CST:.*]] = arith.constant 0.{{.*}} : f16 +// CHECK-NEXT: %[[BROADCAST:.*]] = vector.broadcast %[[CST]] +// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : f16 to vector<16x16xf16> + +func.func @vector_broadcast_scalar_to_vector(%arg0: !xegpu.tensor_desc<16x16xf16>) { + %cst = arith.constant 0.0000 : f16 + %6 = vector.broadcast %cst : f16 to vector<16x16xf16> + xegpu.store_nd %6, %arg0 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16> + return +} +}
\ No newline at end of file diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir index f233dff..216f3d1 100644 --- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir +++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir @@ -1,6 +1,6 @@ -// RUN: mlir-opt --xevm-attach-target='module=xevm_* chip=pvc' -test-xegpu-sg-distribute -allow-unregistered-dialect \ -// RUN: -canonicalize -cse -split-input-file %s | FileCheck %s - +// RUN: mlir-opt --xevm-attach-target='module=xevm_* chip=pvc' -test-xegpu-sg-distribute \ +// RUN: -allow-unregistered-dialect -canonicalize -cse %s | FileCheck %s +gpu.module @xevm_module{ // CHECK-LABEL: gpu.func @store_nd_1d // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) { // CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] @@ -11,20 +11,17 @@ // CHECK-NEXT: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[W]]#1 : !xegpu.tensor_desc<16xf32, // CHECK-SAME: #xegpu.layout<lane_layout = [16], lane_data = [1]>> to !xegpu.tensor_desc<16xf32> {resolve_simt_type_mismatch} // CHECK-NEXT: xegpu.store_nd %[[W]]#0, %[[T1]][%[[W]]#2] : vector<1xf32>, !xegpu.tensor_desc<16xf32> -gpu.module @xevm_module{ - gpu.func @store_nd_1d(%laneid: index) { - %c0 = arith.constant 0 : index - gpu.warp_execute_on_lane_0(%laneid)[16] { - %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> - %cst = "some_op"() : () -> vector<16xf32> - xegpu.store_nd %cst, %0 [%c0] {layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} - : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> - } - gpu.return +gpu.func @store_nd_1d(%laneid: index) { + %c0 = arith.constant 0 : index + gpu.warp_execute_on_lane_0(%laneid)[16] { + %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> + %cst = "some_op"() : () -> vector<16xf32> + xegpu.store_nd %cst, %0 [%c0] {layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} + : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> } + gpu.return } -// ----- // CHECK-LABEL: gpu.func @store_nd_2d // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) { // CHECK: %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] @@ -37,22 +34,18 @@ gpu.module @xevm_module{ // CHECK-NEXT: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[W]]#1 : !xegpu.tensor_desc<16x16xf16, // CHECK-SAME: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> to !xegpu.tensor_desc<16x16xf16> {resolve_simt_type_mismatch} // CHECK-NEXT: xegpu.store_nd %[[CAST]], %[[T1]][%[[W]]#2, %[[W]]#3] : vector<16xf16>, !xegpu.tensor_desc<16x16xf16> -gpu.module @xevm_module{ - gpu.func @store_nd_2d(%laneid : index) { - %c0 = arith.constant 0 : index - gpu.warp_execute_on_lane_0(%laneid)[16] { - %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> - %cst = "some_op"() : () -> vector<16x16xf16> - xegpu.store_nd %cst, %0 [%c0, %c0] {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} - : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> - } - gpu.return +gpu.func @store_nd_2d(%laneid : index) { + %c0 = arith.constant 0 : index + gpu.warp_execute_on_lane_0(%laneid)[16] { + %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + %cst = "some_op"() : () -> vector<16x16xf16> + xegpu.store_nd %cst, %0 [%c0, %c0] {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} + : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> } + gpu.return } - -// ----- // CHECK-LABEL: gpu.func @load_nd_1d // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) { // CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<1xf32>, @@ -63,21 +56,19 @@ gpu.module @xevm_module{ // CHECK-NEXT: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[W]]#1 : !xegpu.tensor_desc<16xf32, // CHECK-SAME: #xegpu.layout<lane_layout = [16], lane_data = [1]>> to !xegpu.tensor_desc<16xf32> {resolve_simt_type_mismatch} // CHECK-NEXT: xegpu.load_nd %[[T1]][%[[W]]#2] : !xegpu.tensor_desc<16xf32> -> vector<1xf32> -gpu.module @xevm_module{ - gpu.func @load_nd_1d(%laneid: index) { - %c0 = arith.constant 0 : index - %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) { - %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> - %1 = xegpu.load_nd %0 [%c0] {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : - !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf32> - gpu.yield %1 : vector<16xf32> - } - "some_user_op"(%r) : (vector<1xf32>) -> () - gpu.return +gpu.func @load_nd_1d(%laneid: index) { + %c0 = arith.constant 0 : index + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) { + %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> + %1 = xegpu.load_nd %0 [%c0] {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : + !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf32> + gpu.yield %1 : vector<16xf32> } + "some_user_op"(%r) : (vector<1xf32>) -> () + gpu.return } -// ----- + // CHECK-LABEL: gpu.func @load_nd_2d // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) { // CHECK: %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<16x1xf16>, !xegpu.tensor_desc<16x16xf16, @@ -89,21 +80,19 @@ gpu.module @xevm_module{ // CHECK-SAME: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> to !xegpu.tensor_desc<16x16xf16> {resolve_simt_type_mismatch} // CHECK-NEXT: %[[T2:.*]] = xegpu.load_nd %[[T1]][%[[W]]#2, %[[W]]#3] : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16> // CHECK: vector.shape_cast %[[T2]] : vector<16xf16> to vector<16x1xf16> -gpu.module @xevm_module{ - gpu.func @load_nd_2d(%laneid: index) { - %c0 = arith.constant 0 : index - %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<16x1xf16>) { - %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> - %1 = xegpu.load_nd %0[%c0, %c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} - : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16> - gpu.yield %1 : vector<16x16xf16> - } - "some_user_op"(%r) : (vector<16x1xf16>) -> () - gpu.return +gpu.func @load_nd_2d(%laneid: index) { + %c0 = arith.constant 0 : index + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<16x1xf16>) { + %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + %1 = xegpu.load_nd %0[%c0, %c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} + : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16> + gpu.yield %1 : vector<16x16xf16> } + "some_user_op"(%r) : (vector<16x1xf16>) -> () + gpu.return } -// ----- + // CHECK-LABEL: gpu.func @load_nd_array_length // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) { // CHECK: %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<2x16x1xf16>, @@ -118,23 +107,21 @@ gpu.module @xevm_module{ // CHECK-NEXT: %[[T2:.*]] = xegpu.load_nd %[[T1]][%[[W]]#2, %[[W]]#3] : !xegpu.tensor_desc<16x16xf16, // CHECK-SAME: #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<32xf16> // CHECK-NEXT: vector.shape_cast %[[T2]] : vector<32xf16> to vector<2x16x1xf16> -gpu.module @xevm_module{ - gpu.func @load_nd_array_length(%laneid: index) { - %c0 = arith.constant 0 : index - %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x16x1xf16>) { - %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>, - #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> - %1 = xegpu.load_nd %0[%c0, %c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} - : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>, - #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<2x16x16xf16> - gpu.yield %1 : vector<2x16x16xf16> - } - "some_user_op"(%r) : (vector<2x16x1xf16>) -> () - gpu.return +gpu.func @load_nd_array_length(%laneid: index) { + %c0 = arith.constant 0 : index + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x16x1xf16>) { + %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>, + #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + %1 = xegpu.load_nd %0[%c0, %c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} + : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>, + #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<2x16x16xf16> + gpu.yield %1 : vector<2x16x16xf16> } + "some_user_op"(%r) : (vector<2x16x1xf16>) -> () + gpu.return } -// ----- + // CHECK-LABEL: gpu.func @dpas // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) { // CHECK: %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> @@ -146,29 +133,27 @@ gpu.module @xevm_module{ // CHECK-DAG: %[[T3:.*]] = vector.shape_cast %[[W]]#3 : vector<8x1xf32> to vector<8xf32> // CHECK-NEXT: %[[T4:.*]] = xegpu.dpas %[[T1]], %[[T2]], %[[T3]] : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32> // CHECK-NEXT: vector.shape_cast %[[T4]] : vector<8xf32> to vector<8x1xf32> -gpu.module @xevm_module{ - gpu.func @dpas(%laneid: index) { - %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) { - %0 = "some_op"() : () -> vector<8x16xf16> - %1 = "some_op"() : () -> vector<16x16xf16> - %2 = "some_op"() : () -> vector<8x16xf32> - %3 = xegpu.dpas %0, %1, %2 - { - layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, - layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, - layout_operand_2 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, - layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> - } - : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - gpu.yield %3 : vector<8x16xf32> - } - "some_user_op"(%r) : (vector<8x1xf32>) -> () - gpu.return +gpu.func @dpas(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) { + %0 = "some_op"() : () -> vector<8x16xf16> + %1 = "some_op"() : () -> vector<16x16xf16> + %2 = "some_op"() : () -> vector<8x16xf32> + %3 = xegpu.dpas %0, %1, %2 + { + layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, + layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, + layout_operand_2 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, + layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> + } + : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + gpu.yield %3 : vector<8x16xf32> } + "some_user_op"(%r) : (vector<8x1xf32>) -> () + gpu.return } -// ----- + // CHECK-LABEL: gpu.func @create_nd_tdesc_non_memref // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: ui64, %[[ARG1:[0-9a-zA-Z]+]]: index) { // CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG1]])[16] -> (!xegpu.tensor_desc<16x16xf16, @@ -178,21 +163,19 @@ gpu.module @xevm_module{ // CHECK-NEXT: %[[T1:.*]] = xegpu.create_nd_tdesc %[[W]]#1, shape : [64, 128], strides : [128, 1] : ui64 -> !xegpu.tensor_desc<16x16xf16> // CHECK-NEXT: builtin.unrealized_conversion_cast %[[T1]] : !xegpu.tensor_desc<16x16xf16> to !xegpu.tensor_desc<16x16xf16, // CHECK-SAME: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> {resolve_simt_type_mismatch} -gpu.module @xevm_module{ - gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %laneid: index) { - %c0 = arith.constant 0 : index - %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (!xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) { - %0 = xegpu.create_nd_tdesc %arg0, shape:[64, 128], strides:[128, 1] : ui64 -> - !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> - gpu.yield %0 : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> - } - "some_user_op"(%r) - : (!xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) -> () - gpu.return +gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %laneid: index) { + %c0 = arith.constant 0 : index + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (!xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) { + %0 = xegpu.create_nd_tdesc %arg0, shape:[64, 128], strides:[128, 1] : ui64 -> + !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + gpu.yield %0 : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> } + "some_user_op"(%r) + : (!xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) -> () + gpu.return } -// ----- + // CHECK-LABEL: gpu.func @prefetch_2d // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) { // CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (!xegpu.tensor_desc<16x16xf16, @@ -204,21 +187,19 @@ gpu.module @xevm_module{ // CHECK-SAME: #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> to !xegpu.tensor_desc<16x16xf16> {resolve_simt_type_mismatch} // CHECK-NEXT: xegpu.prefetch_nd %[[T1]][%[[W]]#1, %[[W]]#2] // CHECK-SAME: <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16> -gpu.module @xevm_module{ - gpu.func @prefetch_2d(%laneid: index) { - %c0 = arith.constant 0 : index - gpu.warp_execute_on_lane_0(%laneid)[16] { - %0 = "some_op"() : () - -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> - xegpu.prefetch_nd %0[%c0, %c0] - <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> - : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> - } - gpu.return +gpu.func @prefetch_2d(%laneid: index) { + %c0 = arith.constant 0 : index + gpu.warp_execute_on_lane_0(%laneid)[16] { + %0 = "some_op"() : () + -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + xegpu.prefetch_nd %0[%c0, %c0] + <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> + : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> } + gpu.return } -// ----- + // CHECK-LABEL: gpu.func @prefetch_1d // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: index) { // CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (!xegpu.tensor_desc<16xf16, @@ -229,44 +210,40 @@ gpu.module @xevm_module{ // CHECK-SAME: #xegpu.layout<lane_layout = [16], lane_data = [1]>> to !xegpu.tensor_desc<16xf16> {resolve_simt_type_mismatch} // CHECK-NEXT: xegpu.prefetch_nd %[[T1]][%[[W]]#1] <{l1_hint = #xegpu.cache_hint<cached>, // CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16> -gpu.module @xevm_module{ - gpu.func @prefetch_1d(%laneid: index) { - %c0 = arith.constant 0 : index - gpu.warp_execute_on_lane_0(%laneid)[16] { - %0 = "some_op"() : () - -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> - xegpu.prefetch_nd %0[%c0] - <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> - : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> - } - gpu.return +gpu.func @prefetch_1d(%laneid: index) { + %c0 = arith.constant 0 : index + gpu.warp_execute_on_lane_0(%laneid)[16] { + %0 = "some_op"() : () + -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> + xegpu.prefetch_nd %0[%c0] + <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> + : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> } + gpu.return } -// ----- + // CHECK-LABEL: gpu.func @gpu_barrier({{.*}}) { // CHECK: gpu.warp_execute_on_lane_0(%{{.*}})[16] -> ({{.*}}) { // CHECK: gpu.yield %{{.*}} // CHECK: } // CHECK: %{{.*}} = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16xf16> -> vector<1xf16> // CHECK: gpu.barrier -gpu.module @xevm_module{ - gpu.func @gpu_barrier(%laneid: index) { - %c0 = arith.constant 0 : index - %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf16>) { - %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> - %1 = xegpu.load_nd %0[%c0] - {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} - : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf16> - gpu.barrier - gpu.yield %1 : vector<16xf16> - } - "some_user_op"(%r) : (vector<1xf16>) -> () - gpu.return +gpu.func @gpu_barrier(%laneid: index) { + %c0 = arith.constant 0 : index + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf16>) { + %0 = "some_op"() : () -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> + %1 = xegpu.load_nd %0[%c0] + {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} + : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf16> + gpu.barrier + gpu.yield %1 : vector<16xf16> } + "some_user_op"(%r) : (vector<1xf16>) -> () + gpu.return } -// ----- + // CHECK-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction // CHECK: %[[ACC:.*]] = arith.constant {{.*}} dense<0.000000e+00> : vector<32xf32> // CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] @@ -285,7 +262,6 @@ gpu.module @xevm_module{ // CHECK: %[[T7:.*]] = vector.extract %[[W]]#2[1] : f32 from vector<2xf32> // CHECK: %[[T8:.*]] = vector.reduction <add>, %[[T6]], %[[T7]] : vector<16xf32> into f32 // CHECK: %[[T9:.*]] = vector.from_elements %[[T4]], %[[T8]] : vector<2xf32> -gpu.module @xevm_module{ gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index) { %c0 = arith.constant 0 : index %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) { @@ -307,9 +283,8 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index) "some_user_op"(%r) : (vector<2xf32>) -> () gpu.return } -} -// ----- + // CHECK-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction // CHECK: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) { // CHECK-NEXT: %[[SRC:.*]] = "some_def"() {{.*}} : () -> vector<2x16xf32> @@ -320,7 +295,6 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index) // CHECK-NEXT: %[[T6:.*]] = vector.from_elements %[[T3]], %[[T5]] : vector<2xf32> // CHECK-NEXT: gpu.yield %[[T6]] : vector<2xf32> // CHECK-NEXT: } -gpu.module @xevm_module{ gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index) { %c0 = arith.constant 0 : index %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) { @@ -342,9 +316,8 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index) "some_user_op"(%r) : (vector<2xf32>) -> () gpu.return } -} -// ----- + // CHECK-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction // CHECK: %[[ACC:.*]] = arith.constant {{.*}} dense<0.000000e+00> : vector<32xf32> // CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>, vector<2x16xf32>, vector<2xf32>) { @@ -358,7 +331,6 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index) // CHECK: %[[T5:.*]] = vector.extract %[[W]]#2[1] : f32 from vector<2xf32> // CHECK: %[[T6:.*]] = vector.reduction <add>, %[[T4]], %[[T5]] : vector<16xf32> into f32 // CHECK: %[[T7:.*]] = vector.from_elements %[[T3]], %[[T6]] : vector<2xf32> -gpu.module @xevm_module{ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index) { %c0 = arith.constant 0 : index %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) { @@ -380,9 +352,8 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index) "some_user_op"(%r) : (vector<2xf32>) -> () gpu.return } -} -// ----- + // CHECK-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction // CHECK: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) { // CHECK: %[[SRC:.*]] = "some_def"() {{.*}} : () -> vector<16x2xf32> @@ -397,7 +368,6 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index) // CHECK: %[[T7:.*]] = vector.from_elements %[[T3]], %[[T6]] : vector<2xf32> // CHECK: gpu.yield %[[T7]] : vector<2xf32> // CHECK: } -gpu.module @xevm_module{ gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index) { %c0 = arith.constant 0 : index %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) { @@ -419,9 +389,8 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index) "some_user_op"(%r) : (vector<2xf32>) -> () gpu.return } -} -// ----- + // CHECK-LABEL: gpu.func @scatter_ops_chunksize({{.*}}) { // CHECK: %[[OFFSETS:.*]] = arith.constant {{.*}} dense<12> : vector<16xindex> // CHECK: %[[MASKS:.*]] = arith.constant {{.*}} dense<true> : vector<16xi1> @@ -434,35 +403,33 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index) // CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16> // CHECK-NEXT: xegpu.store %[[T1]], %[[W]]#1[%[[W]]#2], %[[W]]#3 <{chunk_size = 8 : i64}> // CHECK-SAME: : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1> -gpu.module @xevm_module{ - gpu.func @scatter_ops_chunksize(%laneid: index, %src: memref<256xf16>) { - gpu.warp_execute_on_lane_0(%laneid)[16] { - %1 = arith.constant - {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} - dense<1>: vector<16xi1> - %offset = arith.constant - {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} - dense<12> : vector<16xindex> - %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> - { - layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, - layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, - layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]> - } - : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16> - xegpu.store %3, %src[%offset], %1 <{chunk_size=8}> - { - layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>, - layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, - layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]> - } - : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> - } - gpu.return +gpu.func @scatter_ops_chunksize(%laneid: index, %src: memref<256xf16>) { + gpu.warp_execute_on_lane_0(%laneid)[16] { + %1 = arith.constant + {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} + dense<1>: vector<16xi1> + %offset = arith.constant + {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} + dense<12> : vector<16xindex> + %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> + { + layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]> + } + : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16> + xegpu.store %3, %src[%offset], %1 <{chunk_size=8}> + { + layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>, + layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]> + } + : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> } + gpu.return } -// ----- + // CHECK-LABEL: gpu.func @scatter_ops({{.*}}) { // CHECK: %[[OFFSETS:.*]] = arith.constant {{.*}} dense<12> : vector<16xindex> // CHECK: %[[MASKS:.*]] = arith.constant {{.*}} dense<true> : vector<16xi1> @@ -475,156 +442,144 @@ gpu.module @xevm_module{ // CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16> // CHECK-NEXT: xegpu.store %[[T1]], %[[W]]#1[%[[W]]#2], %[[W]]#3 // CHECK-SAME: : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1> -gpu.module @xevm_module{ - gpu.func @scatter_ops(%src: memref<256xf16>, %laneid: index) { - gpu.warp_execute_on_lane_0(%laneid)[16] { - %1 = arith.constant - {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} - dense<1> : vector<16xi1> - %offset = arith.constant - {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} - dense<12> : vector<16xindex> - %3 = xegpu.load %src[%offset], %1 - { - layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, - layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, - layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> - } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16> - xegpu.store %3, %src[%offset], %1 - { - layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, - layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, - layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]> - } - : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> +gpu.func @scatter_ops(%src: memref<256xf16>, %laneid: index) { + gpu.warp_execute_on_lane_0(%laneid)[16] { + %1 = arith.constant + {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} + dense<1> : vector<16xi1> + %offset = arith.constant + {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} + dense<12> : vector<16xindex> + %3 = xegpu.load %src[%offset], %1 + { + layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> + } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16> + xegpu.store %3, %src[%offset], %1 + { + layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]> } - gpu.return + : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> } + gpu.return } -// ----- + // CHECK-LABEL: gpu.func @memref_extract_aligned_pointer_as_index( // CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (index, memref<256x256xf16>) { // CHECK: gpu.yield %{{.*}}, %{{.*}} : index, memref<256x256xf16> // CHECK-NEXT: } // CHECK-NEXT: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[W]]#1 : memref<256x256xf16> -> index // CHECK-NEXT: arith.index_cast %[[INTPTR]] : index to i64 -gpu.module @xevm_module{ - gpu.func @memref_extract_aligned_pointer_as_index(%arg0 : memref<256x256xf16>, %laneid: index) { - %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (index) { - %ptr = memref.extract_aligned_pointer_as_index %arg0 : memref<256x256xf16> -> index - gpu.yield %ptr : index - } - %ptr_i64 = arith.index_cast %r : index to i64 - "some_user_op"(%ptr_i64) : (i64) -> () - gpu.return +gpu.func @memref_extract_aligned_pointer_as_index(%arg0 : memref<256x256xf16>, %laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (index) { + %ptr = memref.extract_aligned_pointer_as_index %arg0 : memref<256x256xf16> -> index + gpu.yield %ptr : index } + %ptr_i64 = arith.index_cast %r : index to i64 + "some_user_op"(%ptr_i64) : (i64) -> () + gpu.return } -// ----- + // CHECK-LABEL: gpu.func @vector_transpose( // CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2x1xf32>, vector<1x2xf32>) { // CHECK: %[[SRC:.*]] = "some_op"() {{.*}} : () -> vector<16x2xf32> // CHECK: gpu.yield %{{.*}}, %[[SRC]] : vector<2x16xf32>, vector<16x2xf32> // CHECK-NEXT: } // CHECK-NEXT: %[[T1:.*]] = vector.transpose %[[W]]#1, [1, 0] : vector<1x2xf32> to vector<2x1xf32> -gpu.module @xevm_module{ - gpu.func @vector_transpose(%laneid: index) { - %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x1xf32>) { - %cst = "some_op"() - {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} - : () -> (vector<16x2xf32>) - %transpose = vector.transpose %cst, [1, 0] - { - layout_operand_0 = #xegpu.layout<lane_layout = [16 , 1], lane_data = [1, 1]>, - layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> - } - : vector<16x2xf32> to vector<2x16xf32> - gpu.yield %transpose : vector<2x16xf32> - } - "some_user_op"(%r) : (vector<2x1xf32>) -> () - gpu.return +gpu.func @vector_transpose(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x1xf32>) { + %cst = "some_op"() + {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} + : () -> (vector<16x2xf32>) + %transpose = vector.transpose %cst, [1, 0] + { + layout_operand_0 = #xegpu.layout<lane_layout = [16 , 1], lane_data = [1, 1]>, + layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> + } + : vector<16x2xf32> to vector<2x16xf32> + gpu.yield %transpose : vector<2x16xf32> } + "some_user_op"(%r) : (vector<2x1xf32>) -> () + gpu.return } -// ----- + // CHECK-LABEL: gpu.func @vector_bitcast( // CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<4x1xi16>, vector<4x2xi8>) { // CHECK: %[[SRC:.*]] = "some_op"() {{.*}} : () -> vector<4x32xi8> // CHECK: gpu.yield %{{.*}}, %[[SRC]] : vector<4x16xi16>, vector<4x32xi8> // CHECK: } // CHECK: vector.bitcast %[[W]]#1 : vector<4x2xi8> to vector<4x1xi16> -gpu.module @xevm_module{ - gpu.func @vector_bitcast(%laneid: index) { - %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<4x1xi16>) { - %cst = "some_op"() - {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>} - : () -> (vector<4x32xi8>) - %bitcast = vector.bitcast %cst - { - layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, - layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> - } - : vector<4x32xi8> to vector<4x16xi16> - gpu.yield %bitcast : vector<4x16xi16> - } - "some_user_op"(%r) : (vector<4x1xi16>) -> () - gpu.return +gpu.func @vector_bitcast(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<4x1xi16>) { + %cst = "some_op"() + {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>} + : () -> (vector<4x32xi8>) + %bitcast = vector.bitcast %cst + { + layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, + layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> + } + : vector<4x32xi8> to vector<4x16xi16> + gpu.yield %bitcast : vector<4x16xi16> } + "some_user_op"(%r) : (vector<4x1xi16>) -> () + gpu.return } -// ----- + // CHECK-LABEL: gpu.func @vector_shapecast_rank_increasing // CHECK: %{{.*}}:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x1xf32>, vector<1xf32>) { // CHECK: gpu.yield %{{.*}} : vector<1x16xf32>, vector<16xf32> // CHECK: } // CHECK: %{{.*}} = vector.shape_cast %{{.*}}#1 : vector<1xf32> to vector<1x1xf32> -gpu.module @xevm_module { - gpu.func @vector_shapecast_rank_increasing(%laneid: index) { - %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) { - %cst = "some_op"() - {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} - : () -> (vector<16xf32>) - %cast = vector.shape_cast %cst - { - layout_operand_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>, - layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> - } - : vector<16xf32> to vector<1x16xf32> - gpu.yield %cast : vector<1x16xf32> - } - "some_user_op"(%r) : (vector<1x1xf32>) -> () - gpu.return +gpu.func @vector_shapecast_rank_increasing(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) { + %cst = "some_op"() + {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} + : () -> (vector<16xf32>) + %cast = vector.shape_cast %cst + { + layout_operand_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>, + layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> + } + : vector<16xf32> to vector<1x16xf32> + gpu.yield %cast : vector<1x16xf32> } + "some_user_op"(%r) : (vector<1x1xf32>) -> () + gpu.return } -// ----- + // CHECK-LABEL: gpu.func @vector_shapecast_rank_reducing( // CHECK: %{{.*}}:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1xf32>, vector<1x1xf32>) { // CHECK: gpu.yield %{{.*}} : vector<16xf32>, vector<1x16xf32> // CHECK: } // CHECK: %{{.*}} = vector.shape_cast %{{.*}}#1 : vector<1x1xf32> to vector<1xf32> -gpu.module @xevm_module { - gpu.func @vector_shapecast_rank_reducing(%laneid: index) { - %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) { - %cst = "some_op"() - {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} - : () -> (vector<1x16xf32>) - %cast = vector.shape_cast %cst - { - layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, - layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]> - } - : vector<1x16xf32> to vector<16xf32> - gpu.yield %cast : vector<16xf32> - } - "some_user_op"(%r) : (vector<1xf32>) -> () - gpu.return +gpu.func @vector_shapecast_rank_reducing(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) { + %cst = "some_op"() + {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} + : () -> (vector<1x16xf32>) + %cast = vector.shape_cast %cst + { + layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, + layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]> + } + : vector<1x16xf32> to vector<16xf32> + gpu.yield %cast : vector<16xf32> } + "some_user_op"(%r) : (vector<1xf32>) -> () + gpu.return } -// ----- + // NOTE: Layouts are still valid, but distribution still requires a slice layout for the operand. // // CHECK-LABEL: gpu.func @vector_shapecast_unsupported @@ -634,21 +589,400 @@ gpu.module @xevm_module { // CHECK: } // CHECK: "some_user_op"(%[[W]]) : (vector<1x1xf32>) -> () // CHECK: gpu.return -gpu.module @xevm_module { - gpu.func @vector_shapecast_unsupported(%laneid: index) { - %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) { - %cst = "some_op"() - {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> } - : () -> (vector<16xf32>) - %cast = vector.shape_cast %cst - { - layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, - layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> - } - : vector<16xf32> to vector<1x16xf32> - gpu.yield %cast : vector<1x16xf32> +gpu.func @vector_shapecast_unsupported(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) { + %cst = "some_op"() + {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> } + : () -> (vector<16xf32>) + %cast = vector.shape_cast %cst + { + layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> + } + : vector<16xf32> to vector<1x16xf32> + gpu.yield %cast : vector<1x16xf32> + } + "some_user_op"(%r) : (vector<1x1xf32>) -> () + gpu.return +} + + +// CHECK-LABEL: gpu.func @vector_extract_strided_slice_distributed_dim_fully_extracted +// CHECK-NEXT: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<8x1xf32>, vector<24x1xf32>) { +// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<24x16xf32> +// CHECK: gpu.yield %{{.*}}, %[[S]] : vector<8x16xf32>, vector<24x16xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1 +// CHECK-SAME: {offsets = [8, 0], sizes = [8, 1], strides = [1, 1]} : vector<24x1xf32> to vector<8x1xf32> +// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<8x1xf32>) -> () +gpu.func @vector_extract_strided_slice_distributed_dim_fully_extracted(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) { + %0 = "some_def"() : () -> (vector<24x16xf32>) + %1 = vector.extract_strided_slice %0 { offsets = [8, 0], sizes = [8, 16], strides = [1, 1], + layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, + layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> + } + : vector<24x16xf32> to vector<8x16xf32> + gpu.yield %1 : vector<8x16xf32> + } + "some_use"(%r) : (vector<8x1xf32>) -> () + gpu.return +} + +// CHECK-LABEL: gpu.func @vector_extract_strided_slice_non_distributed +// CHECK-NEXT: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<8x1xf32>, vector<24x1xf32>) { +// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<24x1xf32> +// CHECK: gpu.yield %{{.*}}, %[[S]] : vector<8x1xf32>, vector<24x1xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1 +// CHECK-SAME: {offsets = [8, 0], sizes = [8, 1], strides = [1, 1]} : vector<24x1xf32> to vector<8x1xf32> +// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<8x1xf32>) -> () +gpu.func @vector_extract_strided_slice_non_distributed(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) { + %0 = "some_def"() : () -> (vector<24x1xf32>) + %1 = vector.extract_strided_slice %0 { offsets = [8, 0], sizes = [8, 1], strides = [1, 1], + layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, + layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> + } + : vector<24x1xf32> to vector<8x1xf32> + gpu.yield %1 : vector<8x1xf32> + } + "some_use"(%r) : (vector<8x1xf32>) -> () + gpu.return +} + +// CHECK-LABEL: gpu.func @vector_extract_strided_slice_inner_distributed +// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<8x1xf32>, vector<24x4xf32>) { +// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<24x64xf32> +// CHECK: gpu.yield %{{.*}}, %[[S]] : vector<8x16xf32>, vector<24x64xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1 +// CHECK-SAME: {offsets = [8, 3], sizes = [8, 1], strides = [1, 1]} : vector<24x4xf32> to vector<8x1xf32> +// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<8x1xf32>) -> () +gpu.func @vector_extract_strided_slice_inner_distributed(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) { + %0 = "some_def"() : () -> (vector<24x64xf32>) + %1 = vector.extract_strided_slice %0 { offsets = [8, 48], sizes = [8, 16], strides = [1, 1], + layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, + layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> + } + : vector<24x64xf32> to vector<8x16xf32> + gpu.yield %1 : vector<8x16xf32> + } + "some_use"(%r) : (vector<8x1xf32>) -> () + gpu.return +} + +// CHECK-LABEL: gpu.func @vector_extract_strided_slice_outer_distributed +// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x16xf32>, vector<2x16xf32>) { +// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<32x16xf32> +// CHECK: gpu.yield %{{.*}}, %[[S]] : vector<16x16xf32>, vector<32x16xf32> +// CHECK: } +// CHECK-NEXT: %[[T1:.*]] = vector.extract %[[W]]#1[1] : vector<16xf32> from vector<2x16xf32> +// CHECK-NEXT: %[[T2:.*]] = vector.shape_cast %[[T1]] : vector<16xf32> to vector<1x16xf32> +// CHECK-NEXT: "some_use"(%[[T2]]) : (vector<1x16xf32>) -> () +gpu.func @vector_extract_strided_slice_outer_distributed(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x16xf32>) { + %0 = "some_def"() : () -> (vector<32x16xf32>) + %1 = vector.extract_strided_slice %0 { offsets = [16], sizes = [16], strides = [1], + layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, + layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]> + } + : vector<32x16xf32> to vector<16x16xf32> + gpu.yield %1 : vector<16x16xf32> + } + "some_use"(%r) : (vector<1x16xf32>) -> () + gpu.return +} + +// CHECK-LABEL: gpu.func @vector_extract_strided_slice_1d +// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>, vector<4xf32>) { +// CHECK: %[[S:.*]] = "some_def"() : () -> vector<64xf32> +// CHECK: gpu.yield %{{.*}}, %[[S]] : vector<32xf32>, vector<64xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1 +// CHECK-SAME: {offsets = [1], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32> +// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<2xf32>) -> () +gpu.func @vector_extract_strided_slice_1d(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) { + %0 = "some_def"() : () -> (vector<64xf32>) + %1 = vector.extract_strided_slice %0 { offsets = [16], sizes = [32], strides = [1], + layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> + } + : vector<64xf32> to vector<32xf32> + gpu.yield %1 : vector<32xf32> + } + "some_use"(%r) : (vector<2xf32>) -> () + gpu.return +} + +// CHECK-LABEL: gpu.func @vector_extract_strided_slice_unsopported_offset +// CHECK: %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) { +// CHECK: } +// CHECK-NOT: %{{.*}} = vector.extract_strided_slice +gpu.func @vector_extract_strided_slice_unsopported_offset(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) { + %0 = "some_def"() : () -> (vector<64xf32>) + %1 = vector.extract_strided_slice %0 { offsets = [3], sizes = [32], strides = [1], + layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> + } + : vector<64xf32> to vector<32xf32> + gpu.yield %1 : vector<32xf32> + } + "some_use"(%r) : (vector<2xf32>) -> () + gpu.return +} + +// CHECK-LABEL: gpu.func @vector_extract_strided_slice_unsopported_source +// CHECK: %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) { +// CHECK: } +// CHECK-NOT: %{{.*}} = vector.extract_strided_slice +gpu.func @vector_extract_strided_slice_unsopported_source(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) { + %0 = "some_def"() : () -> (vector<54xf32>) + %1 = vector.extract_strided_slice %0 { offsets = [0], sizes = [32], strides = [1], + layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> + } + : vector<54xf32> to vector<32xf32> + gpu.yield %1 : vector<32xf32> + } + "some_use"(%r) : (vector<2xf32>) -> () + gpu.return +} + + +// CHECK-LABEL: gpu.func @vector_insert_strided_slice_distributed_dim_fully_inserted +// CHECK-NEXT: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x1xf32>, vector<16x1xf32>, vector<64x1xf32>) { +// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<16x16xf32> +// CHECK-NEXT: %[[D:.*]] = "some_def"() : () -> vector<64x16xf32> +// CHECK: gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<64x16xf32>, vector<16x16xf32>, vector<64x16xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2 +// CHECK-SAME: {offsets = [24, 0], strides = [1, 1]} : vector<16x1xf32> into vector<64x1xf32> +// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<64x1xf32>) -> () +gpu.func @vector_insert_strided_slice_distributed_dim_fully_inserted(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<64x1xf32>) { + %0 = "some_def"() : () -> (vector<16x16xf32>) + %1 = "some_def"() : () -> (vector<64x16xf32>) + %2 = vector.insert_strided_slice %0, %1 { offsets = [24, 0], strides = [1, 1], + layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, + layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, + layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> + } + : vector<16x16xf32> into vector<64x16xf32> + gpu.yield %2 : vector<64x16xf32> + } + "some_use"(%r) : (vector<64x1xf32>) -> () + gpu.return +} + + +// CHECK-LABEL: gpu.func @vector_insert_strided_slice_non_distributed +// CHECK-NEXT: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x1xf32>, vector<16x1xf32>, vector<64x1xf32>) { +// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<16x1xf32> +// CHECK-NEXT: %[[D:.*]] = "some_def"() : () -> vector<64x1xf32> +// CHECK: gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<64x1xf32>, vector<16x1xf32>, vector<64x1xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2 +// CHECK-SAME: {offsets = [24, 0], strides = [1, 1]} : vector<16x1xf32> into vector<64x1xf32> +// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<64x1xf32>) -> () +gpu.func @vector_insert_strided_slice_non_distributed(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<64x1xf32>) { + %0 = "some_def"() : () -> (vector<16x1xf32>) + %1 = "some_def"() : () -> (vector<64x1xf32>) + %2 = vector.insert_strided_slice %0, %1 { offsets = [24, 0], strides = [1, 1], + layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, + layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, + layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> } - "some_user_op"(%r) : (vector<1x1xf32>) -> () - gpu.return + : vector<16x1xf32> into vector<64x1xf32> + gpu.yield %2 : vector<64x1xf32> } + "some_use"(%r) : (vector<64x1xf32>) -> () + gpu.return +} + +// CHECK-LABEL: gpu.func @vector_insert_strided_slice_inner_distributed +// CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x2xf32>, vector<16x1xf32>, vector<64x2xf32>) { +// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<16x16xf32> +// CHECK-NEXT: %[[D:.*]] = "some_def"() : () -> vector<64x32xf32> +// CHECK: gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<64x32xf32>, vector<16x16xf32>, vector<64x32xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2 +// CHECK-SAME: {offsets = [24, 1], strides = [1, 1]} : vector<16x1xf32> into vector<64x2xf32> +// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<64x2xf32>) -> () +gpu.func @vector_insert_strided_slice_inner_distributed(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<64x2xf32>) { + %0 = "some_def"() : () -> (vector<16x16xf32>) + %1 = "some_def"() : () -> (vector<64x32xf32>) + %2 = vector.insert_strided_slice %0, %1 { offsets = [24, 16], strides = [1, 1], + layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, + layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, + layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> + } + : vector<16x16xf32> into vector<64x32xf32> + gpu.yield %2 : vector<64x32xf32> + } + "some_use"(%r) : (vector<64x2xf32>) -> () + gpu.return +} + +// CHECK-LABEL: gpu.func @vector_insert_strided_slice_outer_distributed +// CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<3x32xf32>, vector<1x16xf32>, vector<3x32xf32>) { +// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<16x16xf32> +// CHECK-NEXT: %[[D:.*]] = "some_def"() : () -> vector<48x32xf32> +// CHECK: gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<48x32xf32>, vector<16x16xf32>, vector<48x32xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2 +// CHECK-SAME: {offsets = [2, 4], strides = [1, 1]} : vector<1x16xf32> into vector<3x32xf32> +// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<3x32xf32>) -> () +gpu.func @vector_insert_strided_slice_outer_distributed(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<3x32xf32>) { + %0 = "some_def"() : () -> (vector<16x16xf32>) + %1 = "some_def"() : () -> (vector<48x32xf32>) + %2 = vector.insert_strided_slice %0, %1 { offsets = [32, 4], strides = [1, 1], + layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, + layout_operand_1 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, + layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]> + } + : vector<16x16xf32> into vector<48x32xf32> + gpu.yield %2 : vector<48x32xf32> + } + "some_use"(%r) : (vector<3x32xf32>) -> () + gpu.return +} + +// CHECK-LABEL: gpu.func @vector_insert_strided_slice_1d +// CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<3xf32>, vector<1xf32>, vector<3xf32>) { +// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<16xf32> +// CHECK-NEXT: %[[D:.*]] = "some_def"() : () -> vector<48xf32> +// CHECK: gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<48xf32>, vector<16xf32>, vector<48xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2 +// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xf32> into vector<3xf32> +// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<3xf32>) -> () +gpu.func @vector_insert_strided_slice_1d(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<3xf32>) { + %0 = "some_def"() : () -> (vector<16xf32>) + %1 = "some_def"() : () -> (vector<48xf32>) + %2 = vector.insert_strided_slice %0, %1 { offsets = [16], strides = [1], + layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> + } + : vector<16xf32> into vector<48xf32> + gpu.yield %2 : vector<48xf32> + } + "some_use"(%r) : (vector<3xf32>) -> () + gpu.return +} + +// CHECK-LABEL: gpu.func @vector_insert_strided_slice_unsupported_source +// CHECK: %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<3xf32>) { +// CHECK: } +// CHECK-NOT: %{{.*}} = vector.insert_strided_slice +gpu.func @vector_insert_strided_slice_unsupported_source(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<3xf32>) { + %0 = "some_def"() : () -> (vector<8xf32>) + %1 = "some_def"() : () -> (vector<48xf32>) + %2 = vector.insert_strided_slice %0, %1 { offsets = [16], strides = [1], + layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> + } + : vector<8xf32> into vector<48xf32> + gpu.yield %2 : vector<48xf32> + } + "some_use"(%r) : (vector<3xf32>) -> () + gpu.return +} + +// CHECK-LABEL: gpu.func @vector_insert_strided_slice_unsupported_offset +// CHECK: %{{.*}} = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<3xf32>) { +// CHECK: } +// CHECK-NOT: %{{.*}} = vector.insert_strided_slice +gpu.func @vector_insert_strided_slice_unsupported_offset(%laneid: index) { + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<3xf32>) { + %0 = "some_def"() : () -> (vector<16xf32>) + %1 = "some_def"() : () -> (vector<48xf32>) + %2 = vector.insert_strided_slice %0, %1 { offsets = [3], strides = [1], + layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, + layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> + } + : vector<16xf32> into vector<48xf32> + gpu.yield %2 : vector<48xf32> + } + "some_use"(%r) : (vector<3xf32>) -> () + gpu.return +} + +// CHECK-LABEL: gpu.func @vector_broadcast_1d_to_2d_broadcast_within_lane +// CHECK-SAME: (%[[ARG0:.*]]: index) { +// CHECK: %[[R:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<16x1xf16>, vector<1xf16>) +// CHECK: %[[DEF:.*]] = "some_def"() +// CHECK: %[[BCAST_INNER:.*]] = vector.broadcast %[[DEF]] +// CHECK: gpu.yield %[[BCAST_INNER]], %[[DEF]] +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[R]]#1 : vector<1xf16> to vector<16x1xf16> +// CHECK: "some_use"(%[[BCAST]]) +gpu.func @vector_broadcast_1d_to_2d_broadcast_within_lane(%laneid: index) { + + %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<16x1xf16>) { + + %1 = "some_def"() : () -> vector<16xf16> + + %2 = vector.broadcast %1 { + layout_operand_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>, + layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> + } : vector<16xf16> to vector<16x16xf16> + + gpu.yield %2 : vector<16x16xf16> + } + "some_use"(%r) : (vector<16x1xf16>) -> () + gpu.return +} + +// CHECK-LABEL: gpu.func @vector_broadcast_2d_to_2d_across_lane_lower_to_noop_case +// CHECK-SAME: (%[[ARG0:.*]]: index) +// CHECK: %[[R:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<16x1xf16>, vector<16x1xf16>) +// CHECK: %[[DEF:.*]] = "some_def"() : () -> vector<16x1xf16> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[DEF]] +// CHECK-SAME: : vector<16x1xf16> to vector<16x16xf16> +// CHECK: gpu.yield %[[BCAST]], %[[DEF]] : vector<16x16xf16>, vector<16x1xf16> +// CHECK: "some_use"(%[[R]]#1) : (vector<16x1xf16>) -> () +gpu.func @vector_broadcast_2d_to_2d_across_lane_lower_to_noop_case(%arg0: index) { + %0 = gpu.warp_execute_on_lane_0(%arg0)[16] -> (vector<16x1xf16>) { + %1 = "some_def"() : () -> vector<16x1xf16> + %2 = vector.broadcast %1 { + layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, + layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> + } : vector<16x1xf16> to vector<16x16xf16> + gpu.yield %2: vector<16x16xf16> + } + "some_use"(%0) : (vector<16x1xf16>) -> () + gpu.return +} + +// CHECK-LABEL: gpu.func @vector_shape_cast_scalar_to_vector +// CHECK-SAME: (%[[ARG0:.*]]: index) +// CHECK: %[[R:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<16x1xf16>, f16) +// CHECK: %[[DEF:.*]] = "some_def"() +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[DEF]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : f16 to vector<16x16xf16> +// CHECK: gpu.yield %[[BCAST]], %[[DEF]] : vector<16x16xf16>, f16 +// CHECK: %[[RESULT:.*]] = vector.broadcast %[[R]]#1 : f16 to vector<16x1xf16> +// CHECK: "some_use"(%[[RESULT]]) +gpu.func +@vector_shape_cast_scalar_to_vector(%arg0: index) { + %0 = gpu.warp_execute_on_lane_0(%arg0)[16] -> (vector<16x1xf16>) { + %1 = "some_def"() : () -> f16 + %2 = vector.broadcast %1 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : f16 to vector<16x16xf16> + gpu.yield %2 : vector<16x16xf16> + } + "some_use"(%0) : (vector<16x1xf16>) -> () + gpu.return +} + } diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir index 8fd3cca..e5e3d2a 100644 --- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir +++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir @@ -271,11 +271,11 @@ gpu.module @xevm_module{ // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[C8:.*]] = arith.constant 8 : index // CHECK: %[[LANE_ID:.*]] = gpu.lane_id -// CHECK: %[[REMU1:.*]] = index.remu %[[LANE_ID]], %[[C8]] -// CHECK: %[[DIVU:.*]] = index.divu %[[LANE_ID]], %[[C8]] -// CHECK: %[[REMU2:.*]] = index.remu %[[DIVU]], %[[C2]] -// CHECK: %[[REMU3:.*]] = index.remu %[[REMU2]], %[[C2]] -// CHECK: %[[REMU4:.*]] = index.remu %[[REMU1]], %[[C8]] +// CHECK: %[[REMU1:.*]] = arith.remui %[[LANE_ID]], %[[C8]] +// CHECK: %[[DIVU:.*]] = arith.divui %[[LANE_ID]], %[[C8]] +// CHECK: %[[REMU2:.*]] = arith.remui %[[DIVU]], %[[C2]] +// CHECK: %[[REMU3:.*]] = arith.remui %[[REMU2]], %[[C2]] +// CHECK: %[[REMU4:.*]] = arith.remui %[[REMU1]], %[[C8]] // CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[REMU3]], %[[REMU4]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32> // CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[REMU3]], %[[REMU4]]] : vector<1x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index gpu.module @xevm_module{ @@ -294,13 +294,13 @@ gpu.module @xevm_module{ // CHECK: %[[C4:.*]] = arith.constant 4 : index // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[LANE_ID:.*]] = gpu.lane_id -// CHECK: %[[REMU1:.*]] = index.remu %[[LANE_ID]], %[[C4]] -// CHECK: %[[DIVU:.*]] = index.divu %[[LANE_ID]], %[[C4]] -// CHECK: %[[REMU2:.*]] = index.remu %[[DIVU]], %[[C4]] -// CHECK: %[[MUL:.*]] = index.mul %[[REMU2]], %[[C2]] -// CHECK: %[[REMU3:.*]] = index.remu %[[MUL]], %[[C8]] -// CHECK: %[[REMU4:.*]] = index.remu %[[REMU1]], %[[C4]] -// CHECK: %[[ADD:.*]] = index.add %[[REMU4]], %[[C1]] +// CHECK: %[[REMU1:.*]] = arith.remui %[[LANE_ID]], %[[C4]] +// CHECK: %[[DIVU:.*]] = arith.divui %[[LANE_ID]], %[[C4]] +// CHECK: %[[REMU2:.*]] = arith.remui %[[DIVU]], %[[C4]] +// CHECK: %[[MUL:.*]] = arith.muli %[[REMU2]], %[[C2]] +// CHECK: %[[REMU3:.*]] = arith.remui %[[MUL]], %[[C8]] +// CHECK: %[[REMU4:.*]] = arith.remui %[[REMU1]], %[[C4]] +// CHECK: %[[ADD:.*]] = arith.addi %[[REMU4]], %[[C1]] // CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[REMU3]], %[[ADD]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x1xf32> // CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[REMU3]], %[[ADD]]] : vector<2x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index gpu.module @xevm_module{ @@ -330,3 +330,64 @@ gpu.module @xevm_module{ gpu.return } } + +// ----- +// CHECK-LABEL: gpu.func @vector_broadcast_1d_to_2d_broadcast_within_lane({{.*}}) { +gpu.module @xevm_module{ + gpu.func @vector_broadcast_1d_to_2d_broadcast_within_lane(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>) { + %c0 = arith.constant 0 : index + %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} dense<0.000000e+00> : vector<16xf16> + %tdesc0 = xegpu.create_nd_tdesc %arg0 : memref<16x16xf16> + -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + %tdesc1 = xegpu.create_nd_tdesc %arg1 : memref<16x16xf16> + -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + %0 = xegpu.load_nd %tdesc0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16> + %1 = vector.multi_reduction <add>, %0, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} [0] : vector<16x16xf16> to vector<16xf16> + // CHECK: %[[BCAST:.*]] = vector.broadcast %{{.*}} : f16 to vector<16xf16> + %2 = vector.broadcast %1 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x16xf16> + xegpu.store_nd %2, %tdesc1[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + gpu.return + } +} + +// ----- +// CHECK-LABEL: gpu.func @vector_broadcast_2d_to_2d_across_lane_lower_to_noop_case({{.*}}) { +gpu.module @xevm_module{ + gpu.func @vector_broadcast_2d_to_2d_across_lane_lower_to_noop_case(%arg0: memref<16xf16>, %arg1: memref<16x16xf16>) { + %c0 = arith.constant 0 : index + %mask = vector.constant_mask [16] {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}: vector<16xi1> + %1 = xegpu.load %arg0[%c0], %mask {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}: memref<16xf16>, index, vector<16xi1> -> vector<16xf16> + + %11 = vector.shape_cast %1 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x1xf16> + %2 = vector.broadcast %11 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x1xf16> to vector<16x16xf16> + // CHECK-NOT: vector.broadcast + // CHECK-NOT: vector.shape_cast + + %tdesc1 = xegpu.create_nd_tdesc %arg1 : memref<16x16xf16> + -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK: xegpu.store_nd {{.*}}, {{.*}}[{{.*}}, {{.*}}] + // CHECK-SAME: : vector<16xf16>, !xegpu.tensor_desc<16x16xf16> + + xegpu.store_nd %2, %tdesc1[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + gpu.return + } +} + +// ----- +// CHECK-LABEL: gpu.func @vector_shape_cast_scalar_to_vector({{.*}}) { +gpu.module @xevm_module{ + gpu.func @vector_shape_cast_scalar_to_vector(%arg0: memref<16xf16>, %arg1: memref<16x16xf16>) { + %c0 = arith.constant 0 : index + %9 = gpu.block_id x + %10 = arith.index_cast %9 : index to i16 + %11 = arith.bitcast %10 : i16 to f16 + // CHECK: vector.broadcast {{.*}} : f16 to vector<16xf16> + %2 = vector.broadcast %11 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : f16 to vector<16x16xf16> + %tdesc1 = xegpu.create_nd_tdesc %arg1 : memref<16x16xf16> + -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + xegpu.store_nd %2, %tdesc1[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + gpu.return + } +} + + diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir index 726b674..dce4a41 100644 --- a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir +++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir @@ -71,3 +71,87 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +func.func @set_gpu_launch_threads_bad_handle(%arg0: memref<4096x4096xf16>) { + %c32 = arith.constant 32 : index // expected-note {{target op}} + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // expected-error@below {{Expected a gpu.launch op, but got: arith.constant}} + transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op + transform.yield + } +} + +// ----- + +func.func @set_gpu_launch_threads_many_handles(%arg0: memref<4096x4096xf16>) { + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // expected-error@below {{Requires exactly one targetOp handle (got 2)}} + transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op + transform.yield + } +} + +// ----- + +func.func @set_gpu_launch_threads_bad_threads(%arg0: memref<4096x4096xf16>) { + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) { + gpu.terminator + } + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // expected-error@below {{Expected threads argument to consist of three values (got 2)}} + transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4] : !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @insert_prefetch_dpas_c +func.func @insert_prefetch_dpas_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { + %c32 = arith.constant 32 : index + %c4096 = arith.constant 4096 : index + %c0 = arith.constant 0 : index + %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> + // expected-note@below {{load op}} + %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> + %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> + %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) { + %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> + %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> + scf.yield %7 : vector<256x256xf16> + } + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_operand %0[2] : (!transform.any_op) -> !transform.any_value + // expected-error@below {{Load op is not contained in a scf.for loop.}} + %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op + transform.yield + } +} diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir index bd6a792..561034f 100644 --- a/mlir/test/Dialect/XeGPU/transform-ops.mlir +++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir @@ -121,6 +121,25 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-LABEL: @set_desc_layout_slice +func.func @set_desc_layout_slice(%arg0: memref<4096xf16>) { + // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0 + // CHECK-SAME: #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]> + %0 = xegpu.create_nd_tdesc %arg0 : memref<4096xf16> -> !xegpu.tensor_desc<256xf16> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_desc_layout %{{.*}} + %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] slice_dims = [0] : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + // CHECK-LABEL: @set_op_layout_attr_result_default_index func.func @set_op_layout_attr_result_default_index(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> @@ -212,6 +231,25 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-LABEL: @set_op_layout_attr_result_slice +func.func @set_op_layout_attr_result_slice(%arg0: vector<256xf16>) { + // CHECK: = arith.extf + // CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>, dims = [0]>} + %2 = arith.extf %arg0 : vector<256xf16> to vector<256xf32> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_op_layout_attr %{{.*}} + transform.xegpu.set_op_layout_attr %0 result index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] slice_dims = [0] : !transform.any_op + transform.yield + } +} + +// ----- + // CHECK-LABEL: @set_op_layout_attr_operand_minimal func.func @set_op_layout_attr_operand_minimal(%arg0: memref<4096x4096xf16>) { %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> @@ -230,6 +268,7 @@ module attributes {transform.with_named_sequence} { transform.yield } } + // ----- // CHECK-LABEL: @set_op_layout_attr_operand1 @@ -252,3 +291,219 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +// CHECK-LABEL: @set_gpu_launch_threads +func.func @set_gpu_launch_threads(%arg0: memref<4096x4096xf16>) { + // CHECK: %[[C1:.+]] = arith.constant 1 : index + %c1 = arith.constant 1 : index + // CHECK: %[[C16:.+]] = arith.constant 16 : index + %c16 = arith.constant 16 : index + // CHECK: %[[C8:.+]] = arith.constant 8 : index + // CHECK: %[[C4:.+]] = arith.constant 4 : index + // CHECK: %[[C1_0:.+]] = arith.constant 1 : index + // CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C16]], %{{.*}} = %[[C16]], %{{.*}} = %[[C1]]) + // CHECK-SAME: threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C8]], %{{.*}} = %[[C4]], %{{.*}} = %[[C1_0]]) + gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) { + gpu.terminator + } + return +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_gpu_launch_threads %{{.*}} + transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @set_gpu_launch_threads_param +func.func @set_gpu_launch_threads_param(%arg0: memref<4096x4096xf16>) { + // CHECK: %[[C1:.+]] = arith.constant 1 : index + %c1 = arith.constant 1 : index + // CHECK: %[[C16:.+]] = arith.constant 16 : index + %c16 = arith.constant 16 : index + // CHECK: %[[C8:.+]] = arith.constant 8 : index + // CHECK: %[[C4:.+]] = arith.constant 4 : index + // CHECK: %[[C1_0:.+]] = arith.constant 1 : index + // CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C16]], %{{.*}} = %[[C16]], %{{.*}} = %[[C1]]) + // CHECK-SAME: threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C8]], %{{.*}} = %[[C4]], %{{.*}} = %[[C1_0]]) + gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) { + gpu.terminator + } + return +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_gpu_launch_threads %{{.*}} + %th1 = transform.param.constant 4 : i64 -> !transform.param<i64> + transform.xegpu.set_gpu_launch_threads %0 threads = [8, %th1, 1] : !transform.any_op, !transform.param<i64> + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @insert_prefetch_dpas_a +func.func @insert_prefetch_dpas_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { + // CHECK: %[[C32:.+]] = arith.constant 32 : index + %c32 = arith.constant 32 : index + %c4096 = arith.constant 4096 : index + // CHECK: %[[C0:.+]] = arith.constant 0 : index + %c0 = arith.constant 0 : index + %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> + %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> + // CHECK: xegpu.create_nd_tdesc %arg0 + // CHECK: xegpu.create_nd_tdesc %arg1 + // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0 + // CHECK-SAME: !xegpu.tensor_desc<256x32xf16 + // CHECK: xegpu.prefetch_nd %[[V0]][%[[C0]], %[[C0]]] + %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> + // CHECK: scf.for %[[ARG3:.+]] = %[[C0]] + %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) { + // CHECK: %[[ADD:.+]] = arith.addi %[[ARG3]], %[[C32]] + // CHECK: xegpu.prefetch_nd %[[V0]][%[[C0]], %[[ADD]]] + %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> + %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> + scf.yield %7 : vector<256x256xf16> + } + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %0 = transform.structured.match ops{["xegpu.dpas"]} in %func : (!transform.any_op) -> !transform.any_op + %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value + // CHECK: transform.xegpu.insert_prefetch %{{.*}} + %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.canonicalization + } : !transform.any_op + + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @insert_prefetch_dpas_a_nb_param2 +func.func @insert_prefetch_dpas_a_nb_param2(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { + // CHECK: %[[C64:.+]] = arith.constant 64 : index + // CHECK: %[[C32:.+]] = arith.constant 32 : index + %c32 = arith.constant 32 : index + %c4096 = arith.constant 4096 : index + // CHECK: %[[C0:.+]] = arith.constant 0 : index + %c0 = arith.constant 0 : index + %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> + %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> + // CHECK: xegpu.create_nd_tdesc %arg0 + // CHECK: xegpu.create_nd_tdesc %arg1 + // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0 + // CHECK-SAME: !xegpu.tensor_desc<256x32xf16 + // CHECK: xegpu.prefetch_nd %[[V0]][0, %[[C0]]] + // CHECK: xegpu.prefetch_nd %[[V0]][0, %[[C32]]] + %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> + // CHECK: scf.for %[[ARG3:.+]] = %[[C0]] + %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) { + // CHECK: %[[ADD:.+]] = arith.addi %[[ARG3]], %[[C64]] + // CHECK: xegpu.prefetch_nd %[[V0]][0, %[[ADD]]] + %5 = xegpu.load_nd %3[0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + %6 = xegpu.load_nd %4[%arg3, 0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> + %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> + scf.yield %7 : vector<256x256xf16> + } + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %0 = transform.structured.match ops{["xegpu.dpas"]} in %func : (!transform.any_op) -> !transform.any_op + %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value + %nb = transform.param.constant 2 : i64 -> !transform.param<i64> + // CHECK: transform.xegpu.insert_prefetch %{{.*}} + %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = %nb : (!transform.any_value, !transform.param<i64>) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @convert_layout_a +func.func @convert_layout_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { + %c0 = arith.constant 0 : index + // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0 + %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>> + // CHECK: %[[V1:.+]] = xegpu.load_nd %[[V0]] + %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>> -> vector<256x32xf16> + // CHECK: %[[V2:.+]] = xegpu.convert_layout %[[V1]] + // CHECK: input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]> + // CHECK: target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]> + %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> + %3 = xegpu.load_nd %2[%c0, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> + %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> + %5 = xegpu.load_nd %4[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> + // CHECK: = xegpu.dpas %[[V2]] + %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value + // CHECK: transform.xegpu.convert_layout %{{.*}} + transform.xegpu.convert_layout %1 + input_sg_layout = [8, 4] input_sg_data = [32, 32] input_inst_data = [32, 16] + target_sg_layout = [8, 4] target_sg_data = [32, 32] target_inst_data = [8, 16] + : (!transform.any_value) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @convert_layout_a_sg_param +func.func @convert_layout_a_sg_param(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { + %c0 = arith.constant 0 : index + // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0 + %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>> + // CHECK: %[[V1:.+]] = xegpu.load_nd %[[V0]] + %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>> -> vector<256x32xf16> + // CHECK: %[[V2:.+]] = xegpu.convert_layout %[[V1]] + // CHECK: input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]> + // CHECK: target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]> + %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> + %3 = xegpu.load_nd %2[%c0, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> + %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> + %5 = xegpu.load_nd %4[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> + // CHECK: = xegpu.dpas %[[V2]] + %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value + %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64> + // CHECK: transform.xegpu.convert_layout %{{.*}} + transform.xegpu.convert_layout %1 + input_sg_layout = [%layout0, 4] input_sg_data = [32, 32] input_inst_data = [32, 16] + target_sg_layout = [%layout0, 4] target_sg_data = [32, 32] target_inst_data = [8, 16] + : (!transform.any_value, !transform.param<i64>, !transform.param<i64>) -> !transform.any_op + transform.yield + } +} diff --git a/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir b/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir index 02c5f71..8ce6d4d 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir @@ -3,10 +3,10 @@ gpu.module @test { gpu.func @slice_attr() -> vector<128xindex> { // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index - // CHECK-DAG: %[[DIVU:.*]] = index.divu %[[SGID]], %[[C8:.*]] - // CHECK-DAG: %[[REMU:.*]] = index.remu %[[DIVU]], %[[C4:.*]] - // CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU]], %[[C32:.*]] - // CHECK-DAG: %[[MOD:.*]] = index.remu %[[MUL]], %[[C128:.*]] + // CHECK-DAG: %[[DIVU:.*]] = arith.divui %[[SGID]], %[[C8:.*]] + // CHECK-DAG: %[[REMU:.*]] = arith.remui %[[DIVU]], %[[C4:.*]] + // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU]], %[[C32:.*]] + // CHECK-DAG: %[[MOD:.*]] = arith.remui %[[MUL]], %[[C128:.*]] // CHECK-DAG: %[[BASE:.*]] = vector.step : vector<32xindex> // CHECK-DAG: %[[CAST:.*]] = vector.broadcast %[[MOD]] : index to vector<32xindex> // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[BASE]], %[[CAST]] : vector<32xindex> @@ -16,11 +16,10 @@ gpu.module @test { gpu.func @nested_slice_attr() -> vector<128xindex> { // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index - // CHECK-DAG: %[[DIVU1:.*]] = index.divu %[[SGID]], %[[C1:.*]] - // CHECK-DAG: %[[DIVU2:.*]] = index.divu %[[DIVU1]], %[[C8:.*]] - // CHECK-DAG: %[[REMU:.*]] = index.remu %[[DIVU2]], %[[C4:.*]] - // CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU]], %[[C32:.*]] - // CHECK-DAG: %[[MOD:.*]] = index.remu %[[MUL]], %[[C128:.*]] + // CHECK-DAG: %[[DIVU2:.*]] = arith.divui %[[SGID]], %[[C8:.*]] + // CHECK-DAG: %[[REMU:.*]] = arith.remui %[[DIVU2]], %[[C4:.*]] + // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU]], %[[C32:.*]] + // CHECK-DAG: %[[MOD:.*]] = arith.remui %[[MUL]], %[[C128:.*]] // CHECK-DAG: %[[BASE:.*]] = vector.step : vector<32xindex> // CHECK-DAG: %[[CAST:.*]] = vector.broadcast %[[MOD]] : index to vector<32xindex> // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[BASE]], %[[CAST]] : vector<32xindex> @@ -29,4 +28,3 @@ gpu.module @test { } } - diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir index 01134d8e..4829af3 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir @@ -16,18 +16,18 @@ gpu.module @test_round_robin_assignment { gpu.func @create_nd_tdesc_with_shared_data(%src: memref<256x128xf32>) { // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index // CHECK: %[[C4:.*]] = arith.constant 4 : index - // CHECK: %[[IDX:.*]] = index.remu %[[SGID]], %[[C4]] - // CHECK: %[[IDY_DIV:.*]] = index.divu %[[SGID]], %[[C4]] + // CHECK: %[[IDX:.*]] = arith.remui %[[SGID]], %[[C4]] + // CHECK: %[[IDY_DIV:.*]] = arith.divui %[[SGID]], %[[C4]] // CHECK: %[[C8:.*]] = arith.constant 8 : index - // CHECK: %[[IDY:.*]] = index.remu %[[IDY_DIV]], %[[C8]] + // CHECK: %[[IDY:.*]] = arith.remui %[[IDY_DIV]], %[[C8]] // CHECK: %[[C16:.*]] = arith.constant 16 : index - // CHECK: %[[LY:.*]] = index.mul %[[IDY]], %[[C16]] + // CHECK: %[[LY:.*]] = arith.muli %[[IDY]], %[[C16]] // CHECK: %[[C64:.*]] = arith.constant 64 : index - // CHECK: %[[LX:.*]] = index.mul %[[IDX]], %[[C64]] + // CHECK: %[[LX:.*]] = arith.muli %[[IDX]], %[[C64]] // CHECK: %[[C128:.*]] = arith.constant 128 : index - // CHECK: %[[OFFY:.*]] = index.remu %[[LY]], %[[C128]] + // CHECK: %[[OFFY:.*]] = arith.remui %[[LY]], %[[C128]] // CHECK: %[[C64_1:.*]] = arith.constant 64 : index - // CHECK: %[[OFFX:.*]] = index.remu %[[LX]], %[[C64_1]] + // CHECK: %[[OFFX:.*]] = arith.remui %[[LX]], %[[C64_1]] // CHECK: xegpu.create_nd_tdesc %[[ARG_0]][%[[OFFY]], %[[OFFX]]] : memref<256x128xf32> -> !xegpu.tensor_desc<16x64xf32> %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>> diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir index 84ce80f4..c95c640 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir @@ -90,30 +90,27 @@ gpu.module @test_distribution { gpu.return } + // CHECK-LABEL: non_splat_constant gpu.func @non_splat_constant() { - // CHECK-DAG: %[[BASECST:.*]] = arith.constant dense<{{.*}}> : vector<2x1xindex> + // CHECK-DAG: %[[CST:.*]] = arith.constant dense<{{.*}}0{{.*}}, {{.*}}16{{.*}}> : vector<2x1xindex> // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index - // CHECK-DAG: %[[REMU1:.*]] = index.remu %[[SGID]], %[[C1:.*]] - // CHECK-DAG: %[[DIVU:.*]] = index.divu %[[SGID]], %[[C1:.*]] - // CHECK-DAG: %[[REMU2:.*]] = index.remu %[[DIVU]], %[[C8:.*]] - // CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU2]], %[[C2:.*]] - // CHECK-DAG: %[[REMU3:.*]] = index.remu %[[MUL]], %[[C32:.*]] - // CHECK-DAG: %[[REMU4:.*]] = index.remu %[[REMU1]], %[[C1:.*]] - // CHECK-DAG: %[[ADD16:.*]] = arith.addi %[[MUL]], %[[C16:.*]] : index - // CHECK-DAG: %[[REMU5:.*]] = index.remu %[[ADD16]], %[[C32:.*]] - // CHECK-DAG: %[[REMU6:.*]] = index.remu %[[REMU1]], %[[C1:.*]] - // CHECK-DAG: %[[STRIDE1:.*]] = arith.muli %[[REMU3]], %[[C16:.*]] : index - // CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[C0:.*]], %[[STRIDE1]] : index - // CHECK-DAG: %[[STRIDE2:.*]] = arith.muli %[[REMU4]], %[[C0:.*]] : index - // CHECK-DAG: %[[ADDSTRIDES1:.*]] = arith.addi %[[ADDSTRIDES]], %[[STRIDE2]] : index - // CHECK-DAG: %[[BCAST1:.*]] = vector.broadcast %[[ADDSTRIDES1]] : index to vector<2x1xindex> - // CHECK-DAG: %[[RESULT1:.*]] = arith.addi %[[BASECST]], %[[BCAST1]] : vector<2x1xindex> - // CHECK-DAG: %[[STRIDE3:.*]] = arith.muli %[[REMU5]], %[[C16:.*]] : index - // CHECK-DAG: %[[ADDSTRIDES2:.*]] = arith.addi %[[C0:.*]], %[[STRIDE3]] : index - // CHECK-DAG: %[[STRIDE4:.*]] = arith.muli %[[REMU6]], %[[C0:.*]] : index - // CHECK-DAG: %[[ADDSTRIDES3:.*]] = arith.addi %[[ADDSTRIDES2]], %[[STRIDE4]] : index - // CHECK-DAG: %[[BCAST2:.*]] = vector.broadcast %[[ADDSTRIDES3]] : index to vector<2x1xindex> - // CHECK-DAG: %[[RESULT2:.*]] = arith.addi %[[BASECST]], %[[BCAST2]] : vector<2x1xindex> + // CHECK-DAG: %[[T1:.*]] = arith.remui %[[SGID]], %[[C8:.*]] : index + // CHECK-DAG: %[[T2:.*]] = arith.muli %[[T1]], %[[C2:.*]] : index + // CHECK-DAG: %[[T3:.*]] = arith.remui %[[T2]], %[[C32:.*]] : index + // CHECK-DAG: %[[T4:.*]] = arith.addi %[[T2]], %[[C16:.*]] : index + // CHECK-DAG: %[[T5:.*]] = arith.remui %[[T4]], %[[C32_6:.*]] : index + // CHECK-DAG: %[[T6:.*]] = arith.muli %[[T3]], %[[C16_10:.*]] : index + // CHECK-DAG: %[[T7:.*]] = arith.addi %[[C0_11:.*]], %[[T6]] : index + // CHECK-DAG: %[[T8:.*]] = arith.muli %[[C0_4:.*]], %[[C0_9:.*]] : index + // CHECK-DAG: %[[T9:.*]] = arith.addi %[[T7]], %[[T8]] : index + // CHECK-DAG: %[[T10:.*]] = vector.broadcast %[[T9]] : index to vector<2x1xindex> + // CHECK-DAG: %[[T11:.*]] = arith.addi %[[CST]], %[[T10]] : vector<2x1xindex> + // CHECK-DAG: %[[T12:.*]] = arith.muli %[[T5]], %[[C16_10:.*]] : index + // CHECK-DAG: %[[T13:.*]] = arith.addi %[[C0_12:.*]], %[[T12]] : index + // CHECK-DAG: %[[T14:.*]] = arith.muli %[[C0_8:.*]], %[[C0_9:.*]] : index + // CHECK-DAG: %[[T15:.*]] = arith.addi %[[T13]], %[[T14]] : index + // CHECK-DAG: %[[T16:.*]] = vector.broadcast %[[T15]] : index to vector<2x1xindex> + // CHECK-DAG: %[[T17:.*]] = arith.addi %[[CST]], %[[T16]] : vector<2x1xindex> %cst_2 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [2, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex> gpu.return } @@ -130,5 +127,20 @@ gpu.module @test_distribution { %trans = vector.transpose %load, [1, 0] {layout_result_0 = #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 32], lane_layout = [1, 16], lane_data = [1, 1], order =[1, 0]>} : vector<256x128xf32> to vector<128x256xf32> gpu.return } -} + // CHECK-LABEL: vector_mask_2D + gpu.func @vector_mask_2D() { + // CHECK-COUNT-4: vector.create_mask {{.*}}, {{.*}} : vector<16x16xi1> + // CHECK-NOT: vector.create_mask + %constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1> + gpu.return + } + + gpu.func @vector_create_mask_2D() { + // CHECK-COUNT-4: vector.create_mask {{.*}}, {{.*}} : vector<16x16xi1> + // CHECK-NOT: vector.create_mask + %cst16 = arith.constant 16 : index + %constant_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1> + gpu.return + } +} diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index 4fbb566c..69eb8ce 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -27,17 +27,17 @@ gpu.module @test_distribution { //CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> //CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index //CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index - //CHECK-DAG: %[[SGIDX:.*]] = index.remu %[[SGID]], %[[C4]] - //CHECK-DAG: %[[SGIDY_TMP:.*]] = index.divu %[[SGID]], %[[C4]] + //CHECK-DAG: %[[SGIDX:.*]] = arith.remui %[[SGID]], %[[C4]] + //CHECK-DAG: %[[SGIDY_TMP:.*]] = arith.divui %[[SGID]], %[[C4]] //CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index - //CHECK-DAG: %[[SGIDY:.*]] = index.remu %[[SGIDY_TMP]], %[[C8]] + //CHECK-DAG: %[[SGIDY:.*]] = arith.remui %[[SGIDY_TMP]], %[[C8]] //CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index - //CHECK-DAG: %[[L_OFF_Y:.*]] = index.mul %[[SGIDY]], %[[C32]] - //CHECK-DAG: %[[L_OFF_X:.*]] = index.mul %[[SGIDX]], %[[C32]] + //CHECK-DAG: %[[L_OFF_Y:.*]] = arith.muli %[[SGIDY]], %[[C32]] : index + //CHECK-DAG: %[[L_OFF_X:.*]] = arith.muli %[[SGIDX]], %[[C32_1:.*]] : index //CHECK-DAG: %[[C256:.*]] = arith.constant 256 : index - //CHECK-DAG: %[[OFF_Y:.*]] = index.remu %[[L_OFF_Y]], %[[C256]] + //CHECK-DAG: %[[OFF_Y:.*]] = arith.remui %[[L_OFF_Y]], %[[C256]] : index //CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index - //CHECK-DAG: %[[OFF_X:.*]] = index.remu %[[L_OFF_X]], %[[C128]] + //CHECK-DAG: %[[OFF_X:.*]] = arith.remui %[[L_OFF_X]], %[[C128]] : index //CHECK-DAG: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]][{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<32x32xf32> %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> @@ -293,7 +293,7 @@ gpu.module @test_distribution { %val = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<25.5> : vector<256xf16> %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<0> : vector<256xindex> %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<1> : vector<256xi1> - xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout_operand_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>, + xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout_operand_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>, layout_operand_2 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>, layout_operand_3 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>, l1_hint = #xegpu.cache_hint<cached>} @@ -321,18 +321,18 @@ gpu.module @test_distribution { //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[arg0]] : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32> //CHECK: [[sgid:%.+]] = gpu.subgroup_id : index //CHECK: [[c4:%.+]] = arith.constant 4 : index - //CHECK: [[sgidx:%.+]] = index.remu [[sgid]], [[c4]] - //CHECK: [[sgidy_tmp:%.+]] = index.divu [[sgid]], [[c4]] + //CHECK: [[sgidx:%.+]] = arith.remui [[sgid]], [[c4]] : index + //CHECK: [[sgidy_tmp:%.+]] = arith.divui [[sgid]], [[c4]] : index //CHECK: [[c2:%.+]] = arith.constant 2 : index - //CHECK: [[sgidy:%.+]] = index.remu [[sgidy_tmp]], [[c2]] + //CHECK: [[sgidy:%.+]] = arith.remui [[sgidy_tmp]], [[c2]] : index //CHECK: [[c32:%.+]] = arith.constant 32 : index - //CHECK: [[l_off_y:%.+]] = index.mul [[sgidy]], [[c32]] + //CHECK: [[l_off_y:%.+]] = arith.muli [[sgidy]], [[c32]] : index //CHECK: [[c32_0:%.+]] = arith.constant 32 : index - //CHECK: [[l_off_x:%.+]] = index.mul [[sgidx]], [[c32_0]] + //CHECK: [[l_off_x:%.+]] = arith.muli [[sgidx]], [[c32_0]] : index //CHECK: [[c64:%.+]] = arith.constant 64 : index - //CHECK: [[off_y:%.+]] = index.remu [[l_off_y]], [[c64]] + //CHECK: [[off_y:%.+]] = arith.remui [[l_off_y]], [[c64]] : index //CHECK: [[c128:%.+]] = arith.constant 128 : index - //CHECK: [[off_x:%.+]] = index.remu [[l_off_x]], [[c128]] + //CHECK: [[off_x:%.+]] = arith.remui [[l_off_x]], [[c128]] : index //CHECK: xegpu.load_matrix [[mdesc]][[[off_y]], [[off_x]]] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}>: !xegpu.mem_desc<64x128xf32>, index, index -> vector<32x32xf32> %0 = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32> %1 = xegpu.load_matrix %0[0, 0] <{layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32], lane_layout = [2, 8], lane_data = [1, 1]>}>: !xegpu.mem_desc<64x128xf32> -> vector<64x128xf32> @@ -346,18 +346,18 @@ gpu.module @test_distribution { //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[arg0]] : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32> //CHECK: [[sgid:%.+]] = gpu.subgroup_id : index //CHECK: [[c4:%.+]] = arith.constant 4 : index - //CHECK: [[sgidx:%.+]] = index.remu [[sgid]], [[c4]] - //CHECK: [[sgidy_tmp:%.+]] = index.divu [[sgid]], [[c4]] + //CHECK: [[sgidx:%.+]] = arith.remui [[sgid]], [[c4]] : index + //CHECK: [[sgidy_tmp:%.+]] = arith.divui [[sgid]], [[c4]] : index //CHECK: [[c2:%.+]] = arith.constant 2 : index - //CHECK: [[sgidy:%.+]] = index.remu [[sgidy_tmp]], [[c2]] + //CHECK: [[sgidy:%.+]] = arith.remui [[sgidy_tmp]], [[c2]] : index //CHECK: [[c32:%.+]] = arith.constant 32 : index - //CHECK: [[l_off_y:%.+]] = index.mul [[sgidy]], [[c32]] + //CHECK: [[l_off_y:%.+]] = arith.muli [[sgidy]], [[c32]] : index //CHECK: [[c32_0:%.+]] = arith.constant 32 : index - //CHECK: [[l_off_x:%.+]] = index.mul [[sgidx]], [[c32_0]] + //CHECK: [[l_off_x:%.+]] = arith.muli [[sgidx]], [[c32_0]] : index //CHECK: [[c64:%.+]] = arith.constant 64 : index - //CHECK: [[off_y:%.+]] = index.remu [[l_off_y]], [[c64]] + //CHECK: [[off_y:%.+]] = arith.remui [[l_off_y]], [[c64]] : index //CHECK: [[c128:%.+]] = arith.constant 128 : index - //CHECK: [[off_x:%.+]] = index.remu [[l_off_x]], [[c128]] + //CHECK: [[off_x:%.+]] = arith.remui [[l_off_x]], [[c128]] : index //CHECK: xegpu.store_matrix [[cst]], [[mdesc]][[[off_y]], [[off_x]]] : vector<32x32xf32>, !xegpu.mem_desc<64x128xf32>, index, index %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} dense<1.0> : vector<64x128xf32> %mdesc = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32> @@ -409,14 +409,14 @@ gpu.module @test_distribution { gpu.func @vector_step_op_slice_attr() { //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index //CHECK: [[c8:%.+]] = arith.constant 8 : index - //CHECK: [[sgidx:%.+]] = index.remu [[sgId]], [[c8]] - //CHECK: [[sgidy_tmp:%.+]] = index.divu [[sgId]], [[c8]] + //CHECK: [[sgidx:%.+]] = arith.remui [[sgId]], [[c8]] : index + //CHECK: [[sgidy_tmp:%.+]] = arith.divui [[sgId]], [[c8]] : index //CHECK: [[c4:%.+]] = arith.constant 4 : index - //CHECK: [[sgidy:%.+]] = index.remu [[sgidy_tmp]], [[c4]] + //CHECK: [[sgidy:%.+]] = arith.remui [[sgidy_tmp]], [[c4]] : index //CHECK: [[c32:%.+]] = arith.constant 32 : index - //CHECK: [[LY:%.+]] = index.mul [[sgidy]], [[c32]] + //CHECK: [[LY:%.+]] = arith.muli [[sgidy]], [[c32]] : index //CHECK: [[c128:%.+]] = arith.constant 128 : index - //CHECK: [[MODY:%.+]] = index.remu [[LY]], [[c128]] + //CHECK: [[MODY:%.+]] = arith.remui [[LY]], [[c128]] : index //CHECK: [[BASE:%.+]] = vector.step : vector<32xindex> //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex> //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex> @@ -427,11 +427,11 @@ gpu.module @test_distribution { gpu.func @vector_step_op_layout_attr() { //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index //CHECK: [[c16:%.+]] = arith.constant 16 : index - //CHECK: [[sgidx:%.+]] = index.remu [[sgId]], [[c16]] + //CHECK: [[sgidx:%.+]] = arith.remui [[sgId]], [[c16]] : index //CHECK: [[c8:%.+]] = arith.constant 8 : index - //CHECK: [[LOCALY:%.+]] = index.mul [[sgidx]], [[c8]] + //CHECK: [[LOCALY:%.+]] = arith.muli [[sgidx]], [[c8]] : index //CHECK: [[c128:%.+]] = arith.constant 128 : index - //CHECK: [[MODY:%.+]] = index.remu [[LOCALY]], [[c128]] + //CHECK: [[MODY:%.+]] = arith.remui [[LOCALY]], [[c128]] : index //CHECK: [[BASE:%.+]] = vector.step : vector<8xindex> //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<8xindex> //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<8xindex> @@ -479,18 +479,15 @@ gpu.module @test_distribution { // CHECK-LABEL: non_splat_constant_2D gpu.func @non_splat_constant_2D() { // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1x1xindex> - // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index - // CHECK-DAG: %[[SGIDX:.*]] = index.remu %[[SGID]], %{{.*}} - // CHECK-DAG: %[[SGIDY_TMP:.*]] = index.divu %[[SGID]], %{{.*}} - // CHECK-DAG: %[[SGIDY:.*]] = index.remu %[[SGIDY_TMP]], %{{.*}} - // CHECK-DAG: %[[IDY:.*]] = index.remu %[[SGIDY]], %{{.*}} - // CHECK-DAG: %[[IDX:.*]] = index.remu %[[SGIDX]], %{{.*}} - // CHECK-DAG: %[[STRIDECOL:.*]] = arith.muli %[[IDY]], %[[C16:.*]] : index - // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[C0:.*]], %[[STRIDECOL]] : index - // CHECK-DAG: %[[STRIDEROW:.*]] = arith.muli %[[IDX]], %[[C0:.*]] : index - // CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[ADD]], %[[STRIDEROW]] : index - // CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ADDSTRIDES]] : index to vector<1x1xindex> - // CHECK-DAG: arith.addi %[[CST]], %[[BCAST]] : vector<1x1xindex> + // CHECK-DAG: %[[T0:.*]] = gpu.subgroup_id : index + // CHECK-DAG: %[[T1:.*]] = arith.remui %[[T0]], %[[C32:.*]] : index + // CHECK-DAG: %[[T2:.*]] = arith.remui %[[T1]], %[[C32_4:.*]] : index + // CHECK-DAG: %[[T3:.*]] = arith.muli %[[T2]], %[[C16:.*]] : index + // CHECK-DAG: %[[T4:.*]] = arith.addi %[[C0_8:.*]], %[[T3]] : index + // CHECK-DAG: %[[T5:.*]] = arith.muli %[[C0_6:.*]], %[[C0_7:.*]] : index + // CHECK-DAG: %[[T6:.*]] = arith.addi %[[T4]], %[[T5]] : index + // CHECK-DAG: %[[T7:.*]] = vector.broadcast %[[T6]] : index to vector<1x1xindex> + // CHECK-DAG: %[[T8:.*]] = arith.addi %[[CST]], %[[T7]] : vector<1x1xindex> %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [1, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex> gpu.return } @@ -499,13 +496,13 @@ gpu.module @test_distribution { gpu.func @non_splat_constant_2D_non_unit_dim() { // CHECK-DAG: %[[BASECST:.*]] = arith.constant dense<{{\[}}{{\[}}0, 16{{\]}}, {{\[}}8, 24{{\]}}{{\]}}> : vector<2x2xindex> // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index - // CHECK-DAG: %[[SGIDX:.*]] = index.remu %[[SGID]], %{{.*}} - // CHECK-DAG: %[[SGIDY_TMP:.*]] = index.divu %[[SGID]], %{{.*}} - // CHECK-DAG: %[[SGIDY:.*]] = index.remu %[[SGIDY_TMP]], %{{.*}} - // CHECK-DAG: %[[MULY:.*]] = index.mul %[[SGIDY]], %[[C2:.*]] - // CHECK-DAG: %[[MULX:.*]] = index.mul %[[SGIDX]], %{{.*}} - // CHECK-DAG: %[[REMU_Y:.*]] = index.remu %[[MULY]], %[[C8:.*]] - // CHECK-DAG: %[[REMU_X:.*]] = index.remu %[[MULX]], %{{.*}} + // CHECK-DAG: %[[SGIDX:.*]] = arith.remui %[[SGID]], %{{.*}} + // CHECK-DAG: %[[SGIDY_TMP:.*]] = arith.divui %[[SGID]], %{{.*}} + // CHECK-DAG: %[[SGIDY:.*]] = arith.remui %[[SGIDY_TMP]], %{{.*}} + // CHECK-DAG: %[[MULY:.*]] = arith.muli %[[SGIDY]], %[[C2:.*]] : index + // CHECK-DAG: %[[MULX:.*]] = arith.muli %[[SGIDX]], %{{.*}} : index + // CHECK-DAG: %[[REMU_Y:.*]] = arith.remui %[[MULY]], %[[C8:.*]] : index + // CHECK-DAG: %[[REMU_X:.*]] = arith.remui %[[MULX]], %{{.*}} : index // CHECK-DAG: %[[MUL5:.*]] = arith.muli %[[REMU_Y]], %{{.*}} : index // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[C0:.*]], %[[MUL5]] : index // CHECK-DAG: %[[MUL6:.*]] = arith.muli %[[REMU_X]], %[[C16:.*]] : index @@ -529,8 +526,8 @@ gpu.module @test_distribution { gpu.func @non_splat_constant() { // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index - // CHECK-DAG: %[[REMU:.*]] = index.remu %[[SGID]], %{{.*}} - // CHECK-DAG: %[[REMU2:.*]] = index.remu %[[REMU]], %{{.*}} + // CHECK-DAG: %[[REMU:.*]] = arith.remui %[[SGID]], %{{.*}} + // CHECK-DAG: %[[REMU2:.*]] = arith.remui %[[REMU]], %{{.*}} // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU2]], %[[C16:.*]] : index // CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[C0:.*]], %[[MUL]] : index // CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ADDSTRIDES]] : index to vector<1xindex> @@ -547,4 +544,106 @@ gpu.module @test_distribution { %broadcast = vector.broadcast %arg0 {layout_result_0 = #xegpu.layout<sg_layout = [4, 8, 1], sg_data = [1, 1, 1]>} : index to vector<4x1x1xindex> gpu.return } + + // CHECK-LABEL: vector_mask_1D + gpu.func @vector_mask_1D() { + // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK-DAG: %[[REMU:.*]] = arith.remui %[[SGID]], %[[C2:.*]] + // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU]], %[[C16:.*]] : index + // CHECK-DAG: %[[REMU2:.*]] = arith.remui %[[MUL]], %[[C32:.*]] : index + // CHECK-DAG: %[[SUB:.*]] = arith.subi %[[C8:.*]], %[[REMU2]] : index + // CHECK-DAG: %[[MAX:.*]] = arith.maxsi %[[SUB]], %[[C0:.*]] : index + // CHECK-DAG: %[[MIN:.*]] = arith.minsi %[[MAX]], %[[C16:.*]] : index + // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MIN]] : vector<16xi1> + %constant_mask = vector.constant_mask [8] {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>} : vector<32xi1> + gpu.return + } + + // CHECK-LABEL: vector_mask_2D + gpu.func @vector_mask_2D() { + // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK-DAG: %[[SGIDX:.*]] = arith.remui %[[SGID]], %[[C4:.*]] + // CHECK-DAG: %[[SGIDY_TMP:.*]] = arith.divui %[[SGID]], %[[C4:.*]] + // CHECK-DAG: %[[SGIDY:.*]] = arith.remui %[[SGIDY_TMP]], %[[C8:.*]] + // CHECK-DAG: %[[ROW:.*]] = arith.muli %[[SGIDY]], %[[C32:.*]] : index + // CHECK-DAG: %[[COL:.*]] = arith.muli %[[SGIDX]], %[[C32:.*]] : index + // CHECK-DAG: %[[MODROW:.*]] = arith.remui %[[ROW]], %[[C256:.*]] : index + // CHECK-DAG: %[[MODCOL:.*]] = arith.remui %[[COL]], %[[C128:.*]] : index + // CHECK-DAG: %[[SUBROW:.*]] = arith.subi %[[C16:.*]], %[[MODROW]] : index + // CHECK-DAG: %[[MAXROW:.*]] = arith.maxsi %[[SUBROW]], %[[C4:.*]] : index + // CHECK-DAG: %[[MINROW:.*]] = arith.minsi %[[MAXROW]], %[[C32:.*]] : index + // CHECK-DAG: %[[SUBCOL:.*]] = arith.subi %[[C16:.*]], %[[MODCOL]] : index + // CHECK-DAG: %[[MAXCOL:.*]] = arith.maxsi %[[SUBCOL]], %[[C7:.*]] : index + // CHECK-DAG: %[[MINCOL:.*]] = arith.minsi %[[MAXCOL]], %[[C32:.*]] : index + // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MINROW]], %[[MINCOL]] : vector<32x32xi1> + %constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xi1> + gpu.return + } + + // CHECK-LABEL: vector_create_mask_1D + gpu.func @vector_create_mask_1D() { + // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK-DAG: %[[REMU:.*]] = arith.remui %[[SGID]], %[[C2:.*]] + // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU]], %[[C16:.*]] + // CHECK-DAG: %[[REMU2:.*]] = arith.remui %[[MUL]], %[[C32:.*]] + // CHECK-DAG: %[[SUB:.*]] = arith.subi %[[C8:.*]], %[[REMU2]] : index + // CHECK-DAG: %[[MAX:.*]] = arith.maxsi %[[SUB]], %[[C0:.*]] : index + // CHECK-DAG: %[[MIN:.*]] = arith.minsi %[[MAX]], %[[C16:.*]] : index + // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MIN]] : vector<16xi1> + %cst8 = arith.constant 8 : index + %constant_mask = vector.create_mask %cst8 {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>} : vector<32xi1> + gpu.return + } + + // CHECK-LABEL: vector_create_mask_2D + gpu.func @vector_create_mask_2D() { + // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK-DAG: %[[SGIDX:.*]] = arith.remui %[[SGID]], %[[C4:.*]] + // CHECK-DAG: %[[SGIDY_TMP:.*]] = arith.divui %[[SGID]], %[[C4:.*]] + // CHECK-DAG: %[[SGIDY:.*]] = arith.remui %[[SGIDY_TMP]], %[[C8:.*]] + // CHECK-DAG: %[[ROW:.*]] = arith.muli %[[SGIDY]], %[[C32:.*]] + // CHECK-DAG: %[[COL:.*]] = arith.muli %[[SGIDX]], %[[C32:.*]] + // CHECK-DAG: %[[MODROW:.*]] = arith.remui %[[ROW]], %[[C256:.*]] + // CHECK-DAG: %[[MODCOL:.*]] = arith.remui %[[COL]], %[[C128:.*]] + // CHECK-DAG: %[[SUBROW:.*]] = arith.subi %[[C16:.*]], %[[MODROW]] : index + // CHECK-DAG: %[[MAXROW:.*]] = arith.maxsi %[[SUBROW]], %[[C0:.*]] : index + // CHECK-DAG: %[[MINROW:.*]] = arith.minsi %[[MAXROW]], %[[C32:.*]] : index + // CHECK-DAG: %[[SUBCOL:.*]] = arith.subi %[[C16:.*]], %[[MODCOL]] : index + // CHECK-DAG: %[[MAXCOL:.*]] = arith.maxsi %[[SUBCOL]], %[[C0:.*]] : index + // CHECK-DAG: %[[MINCOL:.*]] = arith.minsi %[[MAXCOL]], %[[C32:.*]] : index + // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MINROW]], %[[MINCOL]] : vector<32x32xi1> + %cst16 = arith.constant 16 : index + %constant_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xi1> + gpu.return + } + + // CHECK-LABEL: distribute_load_slice_attr + gpu.func @distribute_load_slice_attr() { + %2 = memref.alloca() {alignment = 1024} : memref<4096xf32> + %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8], sg_data = [32], inst_data = [16]> } dense<0> : vector<256xindex> + %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8], sg_data = [32], inst_data = [16]> } dense<1> : vector<256xi1> + + // CHECK: %[[LOAD:.*]] = xegpu.load {{.*}} <{chunk_size = 1 : i64, layout = #xegpu.slice<#xegpu.layout<inst_data = [8, 16]>, dims = [0]>}> + // CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<inst_data = [8, 16]>, dims = [0]>} : + // CHECK-SAME: memref<4096xf32>, vector<32xindex>, vector<32xi1> -> vector<32xf32> + %3 = xegpu.load %2[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>, dims = [0]> } : memref<4096xf32>, vector<256xindex>, vector<256xi1> -> vector<256xf32> + + // CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[LOAD]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : vector<32xf32> to vector<32x32xf32> + %4 = vector.broadcast %3 {layout_result_0 = + #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>} : vector<256xf32> to vector<256x256xf32> + gpu.return + } + + // CHECK-LABEL: load_nd_tdesc_with_anchor_layout + gpu.func @load_nd_tdesc_with_anchor_layout(%src: memref<256x128xf32>) { + //CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK: xegpu.load_nd %[[TDESC]][{{%.*}}, {{%.*}}] <{layout = #xegpu.layout<inst_data = [32, 16], lane_layout = [1, 16], lane_data = [1, 1]>}> + // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<32x32xf32> + %load = xegpu.load_nd %tdesc[0, 0] <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16],lane_layout = [1, 16], lane_data = [1, 1]>}> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> + -> vector<256x128xf32> + gpu.return + } } diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir index 5ce3d1d..a8015cc 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir @@ -5,13 +5,13 @@ gpu.module @test_1_1_assignment { // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) { // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index - // CHECK-DAG: %[[REMUX:.*]] = index.remu %[[SGID]], %[[C4:.*]] - // CHECK-DAG: %[[DIVU:.*]] = index.divu %[[SGID]], %[[C4:.*]] - // CHECK-DAG: %[[REMUY:.*]] = index.remu %[[DIVU]], %[[C8:.*]] - // CHECK-DAG: %[[MULY:.*]] = index.mul %[[REMUY]], %[[C32:.*]] - // CHECK-DAG: %[[MULX:.*]] = index.mul %[[REMUX]], %[[C32:.*]] - // CHECK-DAG: %[[MODY:.*]] = index.remu %[[MULY]], %[[C256:.*]] - // CHECK-DAG: %[[MODX:.*]] = index.remu %[[MULX]], %[[C128:.*]] + // CHECK-DAG: %[[REMUX:.*]] = arith.remui %[[SGID]], %[[C4:.*]] + // CHECK-DAG: %[[DIVU:.*]] = arith.divui %[[SGID]], %[[C4:.*]] + // CHECK-DAG: %[[REMUY:.*]] = arith.remui %[[DIVU]], %[[C8:.*]] + // CHECK-DAG: %[[MULY:.*]] = arith.muli %[[REMUY]], %[[C32:.*]] + // CHECK-DAG: %[[MULX:.*]] = arith.muli %[[REMUX]], %[[C32:.*]] + // CHECK-DAG: %[[MODY:.*]] = arith.remui %[[MULY]], %[[C256:.*]] + // CHECK-DAG: %[[MODX:.*]] = arith.remui %[[MULX]], %[[C128:.*]] // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[MODY]], %[[MODX]]] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> @@ -22,13 +22,13 @@ gpu.module @test_1_1_assignment { // CHECK-SAME: %[[ARG_0:.*]]: memref<3x256x128xf32> gpu.func @create_nd_tdesc_from_higher_rank_memref(%src: memref<3x256x128xf32>) { // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index - // CHECK-DAG: %[[REMUX:.*]] = index.remu %[[SGID]], %[[C4:.*]] - // CHECK-DAG: %[[DIVU:.*]] = index.divu %[[SGID]], %[[C4:.*]] - // CHECK-DAG: %[[REMUY:.*]] = index.remu %[[DIVU]], %[[C8:.*]] - // CHECK-DAG: %[[MULY:.*]] = index.mul %[[REMUY]], %[[C32:.*]] - // CHECK-DAG: %[[MULX:.*]] = index.mul %[[REMUX]], %[[C32:.*]] - // CHECK-DAG: %[[MODY:.*]] = index.remu %[[MULY]], %[[C256:.*]] - // CHECK-DAG: %[[MODX:.*]] = index.remu %[[MULX]], %[[C128:.*]] + // CHECK-DAG: %[[REMUX:.*]] = arith.remui %[[SGID]], %[[C4:.*]] + // CHECK-DAG: %[[DIVU:.*]] = arith.divui %[[SGID]], %[[C4:.*]] + // CHECK-DAG: %[[REMUY:.*]] = arith.remui %[[DIVU]], %[[C8:.*]] + // CHECK-DAG: %[[MULY:.*]] = arith.muli %[[REMUY]], %[[C32:.*]] + // CHECK-DAG: %[[MULX:.*]] = arith.muli %[[REMUX]], %[[C32:.*]] + // CHECK-DAG: %[[MODY:.*]] = arith.remui %[[MULY]], %[[C256:.*]] + // CHECK-DAG: %[[MODX:.*]] = arith.remui %[[MULX]], %[[C128:.*]] // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][1, %[[MODY]], %[[MODX]]] : memref<3x256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> %tdesc = xegpu.create_nd_tdesc %src[1, 0, 0] : memref<3x256x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> diff --git a/mlir/test/Examples/NVGPU/Ch0.py b/mlir/test/Examples/NVGPU/Ch0.py index 8f60088..e09720a 100644 --- a/mlir/test/Examples/NVGPU/Ch0.py +++ b/mlir/test/Examples/NVGPU/Ch0.py @@ -1,5 +1,9 @@ # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \ -# RUN: %PYTHON %s | FileCheck %s +# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \ +# RUN: then %PYTHON %s | FileCheck %s; \ +# RUN: else export MLIR_NVDSL_PRINT_IR=1; \ +# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi' + # ===----------------------------------------------------------------------===// # Chapter 0 : Hello World @@ -33,7 +37,7 @@ def main(alpha): # + operator generates arith.addi myValue = alpha + tidx # Print from a GPU thread - gpu.printf("GPU thread %llu has %llu\n", [tidx, myValue]) + gpu.printf("GPU thread %llu has %llu\n", tidx, myValue) # 3. Call the GPU kernel kernel() @@ -43,8 +47,24 @@ alpha = 100 # 4. The `mlir_func` decorator JIT compiles the IR and executes the MLIR function. main(alpha) - # CHECK: GPU thread 0 has 100 # CHECK: GPU thread 1 has 101 # CHECK: GPU thread 2 has 102 # CHECK: GPU thread 3 has 103 + +# DUMPIR: func.func @main(%arg0: index) attributes {llvm.emit_c_interface} { +# DUMPIR: %[[C0_I32:.*]] = arith.constant 0 : i32 +# DUMPIR: %[[C1:.*]] = arith.constant 1 : index +# DUMPIR: %[[C1_0:.*]] = arith.constant 1 : index +# DUMPIR: %[[C1_1:.*]] = arith.constant 1 : index +# DUMPIR: %[[C4:.*]] = arith.constant 4 : index +# DUMPIR: %[[C1_2:.*]] = arith.constant 1 : index +# DUMPIR: %[[C1_3:.*]] = arith.constant 1 : index +# DUMPIR: gpu.launch blocks(%arg1, %arg2, %arg3) in (%arg7 = %[[C1]], %arg8 = %[[C1_0]], %arg9 = %[[C1_1]]) threads(%arg4, %arg5, %arg6) in (%arg10 = %[[C4]], %arg11 = %[[C1_2]], %arg12 = %[[C1_3]]) dynamic_shared_memory_size %[[C0_I32]] { +# DUMPIR: %[[TIDX:.*]] = gpu.thread_id x +# DUMPIR: %[[MYVAL:.*]] = arith.addi %arg0, %[[TIDX]] : index +# DUMPIR: gpu.printf "GPU thread %llu has %llu\0A", %[[TIDX]], %[[MYVAL]] : index, index +# DUMPIR: gpu.terminator +# DUMPIR: } +# DUMPIR: return +# DUMPIR: } diff --git a/mlir/test/Examples/NVGPU/Ch1.py b/mlir/test/Examples/NVGPU/Ch1.py index cfb48d5..6e44e4d 100644 --- a/mlir/test/Examples/NVGPU/Ch1.py +++ b/mlir/test/Examples/NVGPU/Ch1.py @@ -1,5 +1,9 @@ # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \ -# RUN: %PYTHON %s | FileCheck %s +# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \ +# RUN: then %PYTHON %s | FileCheck %s; \ +# RUN: else export MLIR_NVDSL_PRINT_IR=1; \ +# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi' + # ===----------------------------------------------------------------------===// # Chapter 1 : 2D Saxpy @@ -24,12 +28,12 @@ import numpy as np def saxpy(x, y, alpha): # 1. Use MLIR GPU dialect to allocate and copy memory token_ty = gpu.AsyncTokenType.get() - t1 = gpu.wait(token_ty, []) + t1 = gpu.wait([]) x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], []) y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], []) t4 = gpu.memcpy(token_ty, [t3], x_dev, x) t5 = gpu.memcpy(token_ty, [t4], y_dev, y) - t6 = gpu.wait(token_ty, [t5]) + t6 = gpu.wait([t5]) # 2. Compute 2D SAXPY kernel @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1)) @@ -47,7 +51,7 @@ def saxpy(x, y, alpha): saxpy_kernel() t7 = gpu.memcpy(token_ty, [t6], y, y_dev) - gpu.wait(token_ty, [t7]) + gpu.wait([t7]) # 3. Pass numpy arrays to MLIR @@ -56,11 +60,32 @@ N = 32 alpha = 2.0 x = np.random.randn(M, N).astype(np.float32) y = np.ones((M, N), np.float32) + saxpy(x, y, alpha) -# 4. Verify MLIR with reference computation -ref = np.ones((M, N), np.float32) -ref += x * alpha -np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01) -print("PASS") +if os.getenv("MLIR_NVDSL_PRINT_IR") != "1": + # 4. Verify MLIR with reference computation + ref = np.ones((M, N), np.float32) + ref += x * alpha + np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01) + print("PASS") # CHECK-NOT: Mismatched elements +# CHECK: PASS + +# DUMPIR: func.func @saxpy(%[[ARG0:.*]]: memref<256x32xf32>, %[[ARG1:.*]]: memref<256x32xf32>, %[[ARG2:.*]]: f32) attributes {llvm.emit_c_interface} { +# DUMPIR: %[[WAIT0:.*]] = gpu.wait async +# DUMPIR: %[[MEMREF:.*]], %[[ASYNC0:.*]] = gpu.alloc async [%[[WAIT0]]] () : memref<256x32xf32> +# DUMPIR: %[[MEMREF0:.*]], %[[ASYNC1:.*]] = gpu.alloc async [%[[ASYNC0]]] () : memref<256x32xf32> +# DUMPIR: %[[MEMCPY1:.*]] = gpu.memcpy async [%[[ASYNC1]]] %[[MEMREF]], %[[ARG0]] : memref<256x32xf32>, memref<256x32xf32> +# DUMPIR: %[[MEMCPY2:.*]] = gpu.memcpy async [%[[MEMCPY1]]] %[[MEMREF0]], %[[ARG1]] : memref<256x32xf32>, memref<256x32xf32> +# DUMPIR: %[[WAIT1:.*]] = gpu.wait async [%[[MEMCPY2]]] +# DUMPIR: %[[LD0:.*]] = memref.load %[[MEMREF]][%{{.*}}, %{{.*}}] : memref<256x32xf32> +# DUMPIR: %[[LD1:.*]] = memref.load %[[MEMREF0]][%{{.*}}, %{{.*}}] : memref<256x32xf32> +# DUMPIR: %[[MUL:.*]] = arith.mulf %[[LD0]], %[[ARG2]] : f32 +# DUMPIR: %[[ADD:.*]] = arith.addf %[[LD1]], %[[MUL]] : f32 +# DUMPIR: memref.store %[[ADD]], %[[MEMREF0]][%{{.*}}, %{{.*}}] : memref<256x32xf32> +# DUMPIR: gpu.terminator +# DUMPIR: %[[MEMCPY3:.*]] = gpu.memcpy async [%[[WAIT1]]] %[[ARG1]], %[[MEMREF0]] : memref<256x32xf32>, memref<256x32xf32> +# DUMPIR: %[[WAIT2:.*]] = gpu.wait async [%[[MEMCPY3]]] +# DUMPIR: return +# DUMPIR: } diff --git a/mlir/test/Examples/NVGPU/Ch2.py b/mlir/test/Examples/NVGPU/Ch2.py index 729913c..aba610c 100644 --- a/mlir/test/Examples/NVGPU/Ch2.py +++ b/mlir/test/Examples/NVGPU/Ch2.py @@ -1,5 +1,9 @@ # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \ -# RUN: %PYTHON %s | FileCheck %s +# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \ +# RUN: then %PYTHON %s | FileCheck %s; \ +# RUN: else export MLIR_NVDSL_PRINT_IR=1; \ +# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi' + # ===----------------------------------------------------------------------===// # Chapter 2 : 2D Saxpy with TMA @@ -28,12 +32,12 @@ import numpy as np @NVDSL.mlir_func def saxpy(x, y, alpha): token_ty = gpu.AsyncTokenType.get() - t1 = gpu.wait(token_ty, []) + t1 = gpu.wait([]) x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], []) y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], []) t4 = gpu.memcpy(token_ty, [t3], x_dev, x) t5 = gpu.memcpy(token_ty, [t4], y_dev, y) - t6 = gpu.wait(token_ty, [t5]) + t6 = gpu.wait([t5]) x_tma = TMA([1, N], x.type) y_tma = TMA([1, N], y.type) @@ -74,7 +78,7 @@ def saxpy(x, y, alpha): saxpy_tma_kernel() t7 = gpu.memcpy(token_ty, [t6], y, y_dev) - gpu.wait(token_ty, [t7]) + gpu.wait([t7]) # 3. Pass numpy arrays to MLIR @@ -85,9 +89,46 @@ x = np.random.randn(M, N).astype(np.float32) y = np.ones((M, N), np.float32) saxpy(x, y, alpha) -# 4. Verify MLIR with reference computation -ref = np.ones((M, N), np.float32) -ref += x * alpha -np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01) -print("PASS") +if os.getenv("MLIR_NVDSL_PRINT_IR") != "1": + # 4. Verify MLIR with reference computation + ref = np.ones((M, N), np.float32) + ref += x * alpha + np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01) + print("PASS") # CHECK-NOT: Mismatched elements +# CHECK: PASS + +# DUMPIR: func.func @saxpy(%{{.*}}: memref<256x32xf32>, %[[ARG1:.*]]: memref<256x32xf32>, %[[ARG2:.*]]: f32) attributes {llvm.emit_c_interface} { +# DUMPIR: %[[WAIT0:.*]] = gpu.wait async +# DUMPIR: %[[MEMREF:.*]], %[[ASYNC0:.*]] = gpu.alloc async [%[[WAIT0]]] () : memref<256x32xf32> +# DUMPIR: %[[CAST:.*]] = memref.cast %[[MEMREF]] : memref<256x32xf32> to memref<*xf32> +# DUMPIR: %[[C1:.*]] = arith.constant 1 : index +# DUMPIR: %[[C32:.*]] = arith.constant 32 : index +# DUMPIR: %[[TMA0:.*]] = nvgpu.tma.create.descriptor %[[CAST]] box[%[[C1]], %[[C32]]] : memref<*xf32> -> <tensor = memref<1x32xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none> +# DUMPIR: %[[C0:.*]] = arith.constant 0 : index +# DUMPIR: %[[EQ:.*]] = arith.cmpi eq, %{{.*}}, %[[C0]] : index +# DUMPIR: %[[MB:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>> +# DUMPIR: %[[C0_10:.*]] = arith.constant 0 : index +# DUMPIR: %[[C1_11:.*]] = arith.constant 1 : index +# DUMPIR: nvgpu.mbarrier.init %[[MB]][%[[C0_10]]], %[[C1_11]], predicate = %[[EQ]] : <memorySpace = #gpu.address_space<workgroup>> +# DUMPIR: %[[DSM0:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[C0_12:.*]] = arith.constant 0 : index +# DUMPIR: %[[VIEW:.*]] = memref.view %[[DSM0]][%[[C0_12]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<1x32xf32, #gpu.address_space<workgroup>> +# DUMPIR: %[[DSM1:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[C128:.*]] = arith.constant 128 : index +# DUMPIR: %[[VIEW_13:.*]] = memref.view %[[DSM1]][%[[C128]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<1x32xf32, #gpu.address_space<workgroup>> +# DUMPIR: nvgpu.tma.async.load %[[TMA0]][%{{.*}}, %{{.*}}], %[[MB]][%{{.*}}] to %[[VIEW]], predicate = %[[EQ]] : <tensor = memref<1x32xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<1x32xf32, #gpu.address_space<workgroup>> +# DUMPIR: nvgpu.mbarrier.arrive.expect_tx %[[MB]][%{{.*}}], %{{.*}}, predicate = %[[EQ]] : <memorySpace = #gpu.address_space<workgroup>> +# DUMPIR: %[[C0_20:.*]] = arith.constant 0 : index +# DUMPIR: %[[C10000000:.*]] = arith.constant 10000000 : index +# DUMPIR: %[[FALSE:.*]] = arith.constant false +# DUMPIR: nvgpu.mbarrier.try_wait.parity %[[MB]][%[[C0_20]]], %[[FALSE]], %[[C10000000]] : <memorySpace = #gpu.address_space<workgroup>> +# DUMPIR: %[[C0_21:.*]] = arith.constant 0 : index +# DUMPIR: %[[LD0:.*]] = memref.load %[[VIEW]][%[[C0_21]], %{{.*}}] : memref<1x32xf32, #gpu.address_space<workgroup>> +# DUMPIR: %[[C0_22:.*]] = arith.constant 0 : index +# DUMPIR: %[[LD1:.*]] = memref.load %[[VIEW_13]][%[[C0_22]], %{{.*}}] : memref<1x32xf32, #gpu.address_space<workgroup>> +# DUMPIR: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<256x32xf32> +# DUMPIR: %[[MEMCPY3:.*]] = gpu.memcpy async [%{{.*}}] %[[ARG1]], %{{.*}} : memref<256x32xf32>, memref<256x32xf32> +# DUMPIR: %{{.*}} = gpu.wait async [%[[MEMCPY3]]] +# DUMPIR: return +# DUMPIR: } diff --git a/mlir/test/Examples/NVGPU/Ch3.py b/mlir/test/Examples/NVGPU/Ch3.py index eb96b11..fe11575 100644 --- a/mlir/test/Examples/NVGPU/Ch3.py +++ b/mlir/test/Examples/NVGPU/Ch3.py @@ -1,5 +1,9 @@ # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \ -# RUN: %PYTHON %s | FileCheck %s +# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \ +# RUN: then %PYTHON %s | FileCheck %s; \ +# RUN: else export MLIR_NVDSL_PRINT_IR=1; \ +# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi' + # ===----------------------------------------------------------------------===// # Chapter 3 : GEMM 128x128x64 with Tensor Core @@ -60,13 +64,13 @@ def tma_load( @NVDSL.mlir_func def gemm_128_128_64(a, b, d): token_ty = gpu.AsyncTokenType.get() - t1 = gpu.wait(token_ty, []) + t1 = gpu.wait([]) a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], []) b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], []) d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], []) t5 = gpu.memcpy(token_ty, [t4], a_dev, a) t6 = gpu.memcpy(token_ty, [t5], b_dev, b) - t7 = gpu.wait(token_ty, [t6]) + t7 = gpu.wait([t6]) sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B a_tma = TMA([128, 64], a.type, swizzle=sw) @@ -111,7 +115,7 @@ def gemm_128_128_64(a, b, d): gemm_tma_kernel() t8 = gpu.memcpy(token_ty, [t7], d, d_dev) - gpu.wait(None, [t8]) + gpu.wait([t8]) # Python pass arguments to MLIR @@ -123,7 +127,73 @@ b = np.random.randn(K, N).astype(np.float16) d = np.zeros((M, N), np.float32) gemm_128_128_64(a, b, d) -ref_d = a.astype(np.float16) @ b.astype(np.float16) -np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01) -print("PASS") +if os.getenv("MLIR_NVDSL_PRINT_IR") != "1": + # Verify MLIR program with reference computation in python + ref_d = a.astype(np.float16) @ b.astype(np.float16) + np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01) + print("PASS") # CHECK-NOT: Mismatched elements +# CHECK: PASS + +# DUMPIR: func.func @gemm_128_128_64(%{{.*}}: memref<128x64xf16>, %{{.*}}: memref<64x128xf16>, %[[ARG2:.*]]: memref<128x128xf32>) attributes {llvm.emit_c_interface} { +# DUMPIR: %[[C128:.*]] = arith.constant 128 : index +# DUMPIR: %[[C64:.*]] = arith.constant 64 : index +# DUMPIR: %[[TMA0:.*]] = nvgpu.tma.create.descriptor %{{.*}} box[%[[C128]], %[[C64]]] : memref<*xf16> -> <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> +# DUMPIR: %[[CAST1:.*]] = memref.cast %{{.*}} : memref<64x128xf16> to memref<*xf16> +# DUMPIR: %[[C64_5:.*]] = arith.constant 64 : index +# DUMPIR: %[[C64_6:.*]] = arith.constant 64 : index +# DUMPIR: %[[TMA1:.*]] = nvgpu.tma.create.descriptor %[[CAST1]] box[%[[C64_5]], %[[C64_6]]] : memref<*xf16> -> <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> +# DUMPIR: %[[THREADID:.*]] = gpu.thread_id x +# DUMPIR: %[[MB:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>> +# DUMPIR: %[[C0:.*]] = arith.constant 0 : index +# DUMPIR: %[[EQ:.*]] = arith.cmpi eq, %[[THREADID]], %[[C0]] : index +# DUMPIR: %[[C0_12:.*]] = arith.constant 0 : index +# DUMPIR: %[[C1_13:.*]] = arith.constant 1 : index +# DUMPIR: nvgpu.mbarrier.init %[[MB]][%[[C0_12]]], %[[C1_13]], predicate = %[[EQ]] : <memorySpace = #gpu.address_space<workgroup>> +# DUMPIR: nvgpu.tma.prefetch.descriptor %[[TMA0]], predicate = %[[EQ]] : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> +# DUMPIR: nvgpu.tma.prefetch.descriptor %[[TMA1]], predicate = %[[EQ]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> +# DUMPIR: %[[DSM0:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[C0_14:.*]] = arith.constant 0 : index +# DUMPIR: %[[VIEW:.*]] = memref.view %[[DSM0]][%[[C0_14]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[DSM1:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[C16384:.*]] = arith.constant 16384 : index +# DUMPIR: %[[VIEW_15:.*]] = memref.view %[[DSM1]][%[[C16384]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x128xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[DSM2:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[C0_16:.*]] = arith.constant 0 : index +# DUMPIR: %[[VIEW_17:.*]] = memref.view %[[DSM2]][%[[C0_16]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[DSM3:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[C16384_18:.*]] = arith.constant 16384 : index +# DUMPIR: %[[VIEW_19:.*]] = memref.view %[[DSM3]][%[[C16384_18]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[DSM4:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[C24576:.*]] = arith.constant 24576 : index +# DUMPIR: %[[VIEW_20:.*]] = memref.view %[[DSM4]][%[[C24576]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[C0_21:.*]] = arith.constant 0 : index +# DUMPIR: %[[C32768:.*]] = arith.constant 32768 : index +# DUMPIR: nvgpu.mbarrier.arrive.expect_tx %[[MB]][%[[C0_21]]], %[[C32768]], predicate = %[[EQ]] : <memorySpace = #gpu.address_space<workgroup>> +# DUMPIR: %[[C0_22:.*]] = arith.constant 0 : index +# DUMPIR: %[[C0_23:.*]] = arith.constant 0 : index +# DUMPIR: %[[C0_24:.*]] = arith.constant 0 : index +# DUMPIR: nvgpu.tma.async.load %[[TMA0]][%[[C0_23]], %[[C0_24]]], %[[MB]][%[[C0_22]]] to %[[VIEW_17]], predicate = %[[EQ]] : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<128x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[C0_25:.*]] = arith.constant 0 : index +# DUMPIR: %[[C0_26:.*]] = arith.constant 0 : index +# DUMPIR: %[[C0_27:.*]] = arith.constant 0 : index +# DUMPIR: nvgpu.tma.async.load %[[TMA1]][%[[C0_26]], %[[C0_27]]], %[[MB]][%[[C0_25]]] to %[[VIEW_19]], predicate = %[[EQ]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[C0_28:.*]] = arith.constant 0 : index +# DUMPIR: %[[C64_29:.*]] = arith.constant 64 : index +# DUMPIR: %[[C0_30:.*]] = arith.constant 0 : index +# DUMPIR: nvgpu.tma.async.load %[[TMA1]][%[[C64_29]], %[[C0_30]]], %[[MB]][%[[C0_28]]] to %[[VIEW_20]], predicate = %[[EQ]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[C0_31:.*]] = arith.constant 0 : index +# DUMPIR: %[[C10000000:.*]] = arith.constant 10000000 : index +# DUMPIR: %[[FALSE:.*]] = arith.constant false +# DUMPIR: nvgpu.mbarrier.try_wait.parity %[[MB]][%[[C0_31]]], %[[FALSE]], %[[C10000000]] : <memorySpace = #gpu.address_space<workgroup>> +# DUMPIR: %[[WG_ACC:.*]] = nvgpu.warpgroup.mma.init.accumulator -> <fragmented = vector<128x128xf32>> +# DUMPIR: %[[GEN0:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW]], %[[TMA0]] : memref<128x64xf16, #gpu.address_space<workgroup>>, <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>> +# DUMPIR: %[[GEN1:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW_15]], %[[TMA1]] : memref<64x128xf16, #gpu.address_space<workgroup>>, <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>> +# DUMPIR: %[[MMA:.*]] = nvgpu.warpgroup.mma %[[GEN0]], %[[GEN1]], %[[WG_ACC]] {transposeB} : <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>, <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>> +# DUMPIR: nvgpu.warpgroup.mma.store %[[MMA]], %{{.*}} : <fragmented = vector<128x128xf32>> to memref<128x128xf32> +# DUMPIR: gpu.terminator +# DUMPIR: } +# DUMPIR: %[[CPY3:.*]] = gpu.memcpy async [%{{.*}}] %[[ARG2]], %{{.*}} : memref<128x128xf32>, memref<128x128xf32> +# DUMPIR: gpu.wait async [%[[CPY3]]] +# DUMPIR: return +# DUMPIR: } diff --git a/mlir/test/Examples/NVGPU/Ch4.py b/mlir/test/Examples/NVGPU/Ch4.py index 0e3460f..dffafda 100644 --- a/mlir/test/Examples/NVGPU/Ch4.py +++ b/mlir/test/Examples/NVGPU/Ch4.py @@ -1,5 +1,9 @@ # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \ -# RUN: %PYTHON %s | FileCheck %s +# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \ +# RUN: then %PYTHON %s | FileCheck %s; \ +# RUN: else export MLIR_NVDSL_PRINT_IR=1; \ +# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi' + # ===----------------------------------------------------------------------===// # Chapter 4 : Multistage GEMM with Tensor Core @@ -259,13 +263,13 @@ def epilogue(D: WGMMAMatrix, d_dev): @NVDSL.mlir_func def gemm_multistage(a, b, d, num_stages): token_ty = gpu.AsyncTokenType.get() - t1 = gpu.wait(token_ty, []) + t1 = gpu.wait([]) a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], []) b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], []) d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], []) t5 = gpu.memcpy(token_ty, [t4], a_dev, a) t6 = gpu.memcpy(token_ty, [t5], b_dev, b) - t7 = gpu.wait(token_ty, [t6]) + t7 = gpu.wait([t6]) sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B a_tma = TMA([128, 64], a.type, swizzle=sw) @@ -297,7 +301,7 @@ def gemm_multistage(a, b, d, num_stages): gemm_multistage_kernel() t8 = gpu.memcpy(token_ty, [t7], d, d_dev) - gpu.wait(None, [t8]) + gpu.wait([t8]) # Python pass arguments to MLIR @@ -313,11 +317,153 @@ d = np.zeros((M, N), np.float32) gemm_multistage(a, b, d, num_stages=7) +if os.getenv("MLIR_NVDSL_PRINT_IR") != "1": + # Verify MLIR with reference computation + ref_d = a.astype(np.float16) @ b.astype(np.float16) + np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01) -# Verify MLIR with reference computation -ref_d = a.astype(np.float16) @ b.astype(np.float16) -np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01) - - -print("PASS") + print("PASS") # CHECK-NOT: Mismatched elements +# CHECK: PASS + +# DUMPIR: func.func @gemm_multistage(%{{.*}}: memref<512x1024xf16>, %{{.*}}: memref<1024x256xf16>, %{{.*}}: memref<512x256xf32>) attributes {llvm.emit_c_interface} { +# DUMPIR: scf.if %{{.*}} { +# DUMPIR: %[[C0_INIT:.*]] = arith.constant 0 : index +# DUMPIR: %[[C7:.*]] = arith.constant 7 : index +# DUMPIR: %[[C1_INIT:.*]] = arith.constant 1 : index +# DUMPIR: scf.for %arg15 = %[[C0_INIT]] to %[[C7]] step %[[C1_INIT]] { +# DUMPIR: %[[C1_MBAR:.*]] = arith.constant 1 : index +# DUMPIR: nvgpu.mbarrier.init %{{.*}}[%arg15], %[[C1_MBAR]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> +# DUMPIR: } +# DUMPIR: nvgpu.tma.prefetch.descriptor %{{.*}} : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> +# DUMPIR: nvgpu.tma.prefetch.descriptor %{{.*}} : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> +# DUMPIR: } +# DUMPIR: %[[C0_PROLOGUE:.*]] = arith.constant 0 : index +# DUMPIR: %[[C6:.*]] = arith.constant 6 : index +# DUMPIR: %[[C1_PROLOGUE:.*]] = arith.constant 1 : index +# DUMPIR: scf.for %arg15 = %[[C0_PROLOGUE]] to %[[C6]] step %[[C1_PROLOGUE]] { +# DUMPIR: %[[BID_X_P:.*]] = gpu.block_id x +# DUMPIR: %[[BID_Y_P:.*]] = gpu.block_id y +# DUMPIR: %[[C128_P1:.*]] = arith.constant 128 : index +# DUMPIR: %[[DIMX_P:.*]] = arith.muli %[[BID_X_P]], %[[C128_P1]] : index +# DUMPIR: %[[C128_P2:.*]] = arith.constant 128 : index +# DUMPIR: %[[DIMY_P:.*]] = arith.muli %[[BID_Y_P]], %[[C128_P2]] : index +# DUMPIR: %{{.*}} = gpu.thread_id x +# DUMPIR: %[[TID_X_P:.*]] = gpu.thread_id x +# DUMPIR: %[[C0_P:.*]] = arith.constant 0 : index +# DUMPIR: %[[PRED_P:.*]] = arith.cmpi eq, %[[TID_X_P]], %[[C0_P]] : index +# DUMPIR: %[[C16384_P1:.*]] = arith.constant 16384 : index +# DUMPIR: %[[OFF_A_P:.*]] = arith.muli %arg15, %[[C16384_P1]] : index +# DUMPIR: %[[C16384_P2:.*]] = arith.constant 16384 : index +# DUMPIR: %[[OFF_B_BASE_P:.*]] = arith.muli %arg15, %[[C16384_P2]] : index +# DUMPIR: %[[C114688:.*]] = arith.constant 114688 : index +# DUMPIR: %[[OFF_B1_P:.*]] = arith.addi %[[OFF_B_BASE_P]], %[[C114688]] : index +# DUMPIR: %[[C8192:.*]] = arith.constant 8192 : index +# DUMPIR: %[[OFF_B2_P:.*]] = arith.addi %[[OFF_B1_P]], %[[C8192]] : index +# DUMPIR: %[[SMEM_A_P:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[VIEW_A_P:.*]] = memref.view %[[SMEM_A_P]][%[[OFF_A_P]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[SMEM_B1_P:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[VIEW_B1_P:.*]] = memref.view %[[SMEM_B1_P]][%[[OFF_B1_P]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[SMEM_B2_P:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[VIEW_B2_P:.*]] = memref.view %[[SMEM_B2_P]][%[[OFF_B2_P]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[C32768:.*]] = arith.constant 32768 : index +# DUMPIR: nvgpu.mbarrier.arrive.expect_tx %{{.*}}[%arg15], %[[C32768]], predicate = %[[PRED_P]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> +# DUMPIR: %[[C64_K_P:.*]] = arith.constant 64 : index +# DUMPIR: %[[K_COORD_P:.*]] = arith.muli %arg15, %[[C64_K_P]] : index +# DUMPIR: nvgpu.tma.async.load %{{.*}}[%[[K_COORD_P]], %[[DIMX_P]]], %{{.*}}[%arg15] to %[[VIEW_A_P]], predicate = %[[PRED_P]] : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<128x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: nvgpu.tma.async.load %{{.*}}[%[[DIMY_P]], %[[K_COORD_P]]], %{{.*}}[%arg15] to %[[VIEW_B1_P]], predicate = %[[PRED_P]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[C64_OFF:.*]] = arith.constant 64 : index +# DUMPIR: %[[DIMY_P_OFF:.*]] = arith.addi %[[DIMY_P]], %[[C64_OFF]] : index +# DUMPIR: nvgpu.tma.async.load %{{.*}}[%[[DIMY_P_OFF]], %[[K_COORD_P]]], %{{.*}}[%arg15] to %[[VIEW_B2_P]], predicate = %[[PRED_P]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: } +# DUMPIR: %[[TID_X_LOOP:.*]] = gpu.thread_id x +# DUMPIR: %[[ACC_INIT:.*]] = nvgpu.warpgroup.mma.init.accumulator -> <fragmented = vector<128x128xf32>> +# DUMPIR: %[[FALSE_LOOP:.*]] = arith.constant false +# DUMPIR: %[[C0_LOOP:.*]] = arith.constant 0 : index +# DUMPIR: %[[C16_LOOP:.*]] = arith.constant 16 : index +# DUMPIR: %[[C1_LOOP:.*]] = arith.constant 1 : index +# DUMPIR: %[[LOOP_RES:.*]]:2 = scf.for %arg15 = %[[C0_LOOP]] to %[[C16_LOOP]] step %[[C1_LOOP]] iter_args(%arg16 = %[[ACC_INIT]], %arg17 = %[[FALSE_LOOP]]) -> (!nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>, i1) { +# DUMPIR: %[[C7_L:.*]] = arith.constant 7 : index +# DUMPIR: %[[STAGE_L:.*]] = arith.remui %arg15, %[[C7_L]] : index +# DUMPIR: %[[C10M:.*]] = arith.constant 10000000 : index +# DUMPIR: nvgpu.mbarrier.try_wait.parity %{{.*}}[%[[STAGE_L]]], %arg17, %[[C10M]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> +# DUMPIR: %[[C16384_L:.*]] = arith.constant 16384 : index +# DUMPIR: %[[OFF_A_L:.*]] = arith.muli %[[STAGE_L]], %[[C16384_L]] : index +# DUMPIR: %[[C114688_L:.*]] = arith.constant 114688 : index +# DUMPIR: %[[OFF_B_L:.*]] = arith.addi %[[OFF_A_L]], %[[C114688_L]] : index +# DUMPIR: %[[SMEM_A_L:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[VIEW_A_L:.*]] = memref.view %[[SMEM_A_L]][%[[OFF_A_L]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[SMEM_B_L:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[VIEW_B_L:.*]] = memref.view %[[SMEM_B_L]][%[[OFF_B_L]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x128xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[DESC_A_L:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW_A_L]], %{{.*}} : memref<128x64xf16, #gpu.address_space<workgroup>>, <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>> +# DUMPIR: %[[DESC_B_L:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW_B_L]], %{{.*}} : memref<64x128xf16, #gpu.address_space<workgroup>>, <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>> +# DUMPIR: %[[ACC_L:.*]] = nvgpu.warpgroup.mma %[[DESC_A_L]], %[[DESC_B_L]], %arg16 {transposeB} : <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>, <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>> +# DUMPIR: %[[C6_NEXT:.*]] = arith.constant 6 : index +# DUMPIR: %[[ITER_NEXT:.*]] = arith.addi %arg15, %[[C6_NEXT]] : index +# DUMPIR: %[[C16_CMP:.*]] = arith.constant 16 : index +# DUMPIR: %[[IN_RANGE:.*]] = arith.cmpi ult, %[[ITER_NEXT]], %[[C16_CMP]] : index +# DUMPIR: %[[C0_CMP:.*]] = arith.constant 0 : index +# DUMPIR: %[[IS_THREAD0_L:.*]] = arith.cmpi eq, %[[TID_X_LOOP]], %[[C0_CMP]] : index +# DUMPIR: %[[DO_LOAD:.*]] = arith.andi %[[IN_RANGE]], %[[IS_THREAD0_L]] : i1 +# DUMPIR: %[[C6_STAGE:.*]] = arith.constant 6 : index +# DUMPIR: %[[STAGE_NEXT_L:.*]] = arith.addi %arg15, %[[C6_STAGE]] : index +# DUMPIR: %[[C7_MOD:.*]] = arith.constant 7 : index +# DUMPIR: %[[STAGE_LOAD:.*]] = arith.remui %[[STAGE_NEXT_L]], %[[C7_MOD]] : index +# DUMPIR: %[[BID_X_L:.*]] = gpu.block_id x +# DUMPIR: %[[BID_Y_L:.*]] = gpu.block_id y +# DUMPIR: %[[C128_L1:.*]] = arith.constant 128 : index +# DUMPIR: %[[DIMX_L:.*]] = arith.muli %[[BID_X_L]], %[[C128_L1]] : index +# DUMPIR: %[[C128_L2:.*]] = arith.constant 128 : index +# DUMPIR: %[[DIMY_L:.*]] = arith.muli %[[BID_Y_L]], %[[C128_L2]] : index +# DUMPIR: %[[TID_X_L1:.*]] = gpu.thread_id x +# DUMPIR: %[[TID_X_L2:.*]] = gpu.thread_id x +# DUMPIR: %[[C16384_LA1:.*]] = arith.constant 16384 : index +# DUMPIR: %[[OFF_A_LOAD:.*]] = arith.muli %[[STAGE_LOAD]], %[[C16384_LA1]] : index +# DUMPIR: %[[C16384_LA2:.*]] = arith.constant 16384 : index +# DUMPIR: %[[OFF_B_BASE_LOAD:.*]] = arith.muli %[[STAGE_LOAD]], %[[C16384_LA2]] : index +# DUMPIR: %[[C114688_LOAD:.*]] = arith.constant 114688 : index +# DUMPIR: %[[OFF_B1_LOAD:.*]] = arith.addi %[[OFF_B_BASE_LOAD]], %[[C114688_LOAD]] : index +# DUMPIR: %[[C8192_LOAD:.*]] = arith.constant 8192 : index +# DUMPIR: %[[OFF_B2_LOAD:.*]] = arith.addi %[[OFF_B1_LOAD]], %[[C8192_LOAD]] : index +# DUMPIR: %[[SMEM_A_LOAD:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[VIEW_A_LOAD:.*]] = memref.view %[[SMEM_A_LOAD]][%[[OFF_A_LOAD]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[SMEM_B1_LOAD:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[VIEW_B1_LOAD:.*]] = memref.view %[[SMEM_B1_LOAD]][%[[OFF_B1_LOAD]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[SMEM_B2_LOAD:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[VIEW_B2_LOAD:.*]] = memref.view %[[SMEM_B2_LOAD]][%[[OFF_B2_LOAD]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[C32768_LOAD:.*]] = arith.constant 32768 : index +# DUMPIR: nvgpu.mbarrier.arrive.expect_tx %{{.*}}[%[[STAGE_LOAD]]], %[[C32768_LOAD]], predicate = %[[DO_LOAD]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> +# DUMPIR: %[[C64_K_LOAD:.*]] = arith.constant 64 : index +# DUMPIR: %[[K_COORD_LOAD:.*]] = arith.muli %[[STAGE_NEXT_L]], %[[C64_K_LOAD]] : index +# DUMPIR: nvgpu.tma.async.load %{{.*}}[%[[K_COORD_LOAD]], %[[DIMX_L]]], %{{.*}}[%[[STAGE_LOAD]]] to %[[VIEW_A_LOAD]], predicate = %[[DO_LOAD]] : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<128x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: nvgpu.tma.async.load %{{.*}}[%[[DIMY_L]], %[[K_COORD_LOAD]]], %{{.*}}[%[[STAGE_LOAD]]] to %[[VIEW_B1_LOAD]], predicate = %[[DO_LOAD]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[C64_OFF_LOAD:.*]] = arith.constant 64 : index +# DUMPIR: %[[DIMY_L_OFF:.*]] = arith.addi %[[DIMY_L]], %[[C64_OFF_LOAD]] : index +# DUMPIR: nvgpu.tma.async.load %{{.*}}[%[[DIMY_L_OFF]], %[[K_COORD_LOAD]]], %{{.*}}[%[[STAGE_LOAD]]] to %[[VIEW_B2_LOAD]], predicate = %[[DO_LOAD]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[C6_FLIP:.*]] = arith.constant 6 : index +# DUMPIR: %[[IS_STAGE6:.*]] = arith.cmpi eq, %[[STAGE_L]], %[[C6_FLIP]] : index +# DUMPIR: %[[TRUE:.*]] = arith.constant true +# DUMPIR: %[[PARITY_FLIP:.*]] = arith.xori %arg17, %[[TRUE]] : i1 +# DUMPIR: %[[NEW_PARITY:.*]] = arith.select %[[IS_STAGE6]], %[[PARITY_FLIP]], %arg17 : i1 +# DUMPIR: scf.yield %[[ACC_L]], %[[NEW_PARITY]] : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>, i1 +# DUMPIR: } +# DUMPIR: nvvm.wgmma.wait.group.sync.aligned 0 +# DUMPIR: %[[TID_X_EPI:.*]] = gpu.thread_id x +# DUMPIR: %[[BID_X_EPI:.*]] = gpu.block_id x +# DUMPIR: %[[BID_Y_EPI:.*]] = gpu.block_id y +# DUMPIR: %[[C128_EPI1:.*]] = arith.constant 128 : index +# DUMPIR: %[[DIMX_EPI:.*]] = arith.muli %[[BID_X_EPI]], %[[C128_EPI1]] : index +# DUMPIR: %[[C128_EPI2:.*]] = arith.constant 128 : index +# DUMPIR: %[[DIMY_EPI:.*]] = arith.muli %[[BID_Y_EPI]], %[[C128_EPI2]] : index +# DUMPIR: %[[SMEM_EPI:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[C0_VIEW:.*]] = arith.constant 0 : index +# DUMPIR: %[[VIEW_EPI:.*]] = memref.view %[[SMEM_EPI]][%[[C0_VIEW]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x128xf32, #gpu.address_space<workgroup>> +# DUMPIR: %[[SUBVIEW_EPI:.*]] = memref.subview %{{.*}}[%[[DIMX_EPI]], %[[DIMY_EPI]]] [128, 128] [1, 1] : memref<512x256xf32> to memref<128x128xf32, strided<[256, 1], offset: ?>> +# DUMPIR: nvgpu.warpgroup.mma.store %[[LOOP_RES]]#0, %[[VIEW_EPI]] : <fragmented = vector<128x128xf32>> to memref<128x128xf32, #gpu.address_space<workgroup>> +# DUMPIR: gpu.barrier +# DUMPIR: %[[C0_STORE:.*]] = arith.constant 0 : index +# DUMPIR: %[[C128_STORE:.*]] = arith.constant 128 : index +# DUMPIR: %[[C1_STORE:.*]] = arith.constant 1 : index +# DUMPIR: scf.for %arg15 = %[[C0_STORE]] to %[[C128_STORE]] step %[[C1_STORE]] { +# DUMPIR: %[[VAL_LOAD:.*]] = memref.load %[[VIEW_EPI]][%arg15, %[[TID_X_EPI]]] : memref<128x128xf32, #gpu.address_space<workgroup>> +# DUMPIR: memref.store %[[VAL_LOAD]], %[[SUBVIEW_EPI]][%arg15, %[[TID_X_EPI]]] : memref<128x128xf32, strided<[256, 1], offset: ?>> diff --git a/mlir/test/Examples/NVGPU/Ch5.py b/mlir/test/Examples/NVGPU/Ch5.py index f98cfd7..b725e50 100644 --- a/mlir/test/Examples/NVGPU/Ch5.py +++ b/mlir/test/Examples/NVGPU/Ch5.py @@ -1,5 +1,9 @@ # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \ -# RUN: %PYTHON %s | FileCheck %s +# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \ +# RUN: then %PYTHON %s | FileCheck %s; \ +# RUN: else export MLIR_NVDSL_PRINT_IR=1; \ +# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi' + # ===----------------------------------------------------------------------===// # Chapter 5 : Warp Specialized GEMM with Tensor Core @@ -156,7 +160,7 @@ def producer_loop( ): phase = const(True, ty=T.bool()) - for iv, phase in scf.for_(0, (K // TILE_K), 1, [phase]): + for iv, phase, _ in scf.for_(0, (K // TILE_K), 1, [phase]): stage = iv % num_stages # Wait MMA to be done mbar_mma[stage].try_wait(phase) @@ -253,13 +257,13 @@ def epilogue(D: WGMMAMatrix, d_dev): @NVDSL.mlir_func def gemm_warp_specialized(a, b, d, num_stages): token_ty = gpu.AsyncTokenType.get() - t1 = gpu.wait(token_ty, []) + t1 = gpu.wait([]) a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], []) b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], []) d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], []) t5 = gpu.memcpy(token_ty, [t4], a_dev, a) t6 = gpu.memcpy(token_ty, [t5], b_dev, b) - t7 = gpu.wait(token_ty, [t6]) + t7 = gpu.wait([t6]) sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B a_tma = TMA([128, 64], a.type, swizzle=sw) @@ -295,7 +299,7 @@ def gemm_warp_specialized(a, b, d, num_stages): gemm_warp_specialized_kernel() t8 = gpu.memcpy(token_ty, [t7], d, d_dev) - gpu.wait(None, [t8]) + gpu.wait([t8]) # Python pass arguments to MLIR @@ -311,11 +315,166 @@ d = np.zeros((M, N), np.float32) gemm_warp_specialized(a, b, d, num_stages=7) +if os.getenv("MLIR_NVDSL_PRINT_IR") != "1": + # Verify MLIR with reference computation + ref_d = a.astype(np.float16) @ b.astype(np.float16) + np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01) -# Verify MLIR with reference computation -ref_d = a.astype(np.float16) @ b.astype(np.float16) -np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01) - - -print("PASS") + print("PASS") # CHECK-NOT: Mismatched elements +# CHECK: PASS + +# DUMPIR: %[[TID_X:.*]] = gpu.thread_id x +# DUMPIR: %[[C128:.*]] = arith.constant 128 : index +# DUMPIR: %[[REM1:.*]] = arith.remui %[[TID_X]], %[[C128]] : index +# DUMPIR: %[[C0:.*]] = arith.constant 0 : index +# DUMPIR: %[[IS_PRIMARY:.*]] = arith.cmpi eq, %[[REM1]], %[[C0]] : index +# DUMPIR: %[[C128_1:.*]] = arith.constant 128 : index +# DUMPIR: %[[DIV1:.*]] = arith.divui %[[TID_X]], %[[C128_1]] : index +# DUMPIR: %[[C1:.*]] = arith.constant 1 : index +# DUMPIR: %[[IS_PRODUCER:.*]] = arith.cmpi eq, %[[DIV1]], %[[C1]] : index +# DUMPIR: %[[TID_X_2:.*]] = gpu.thread_id x +# DUMPIR: %[[C128_2:.*]] = arith.constant 128 : index +# DUMPIR: %[[REM2:.*]] = arith.remui %[[TID_X_2]], %[[C128_2]] : index +# DUMPIR: %[[C0_2:.*]] = arith.constant 0 : index +# DUMPIR: %[[IS_PRIMARY_2:.*]] = arith.cmpi eq, %[[REM2]], %[[C0_2]] : index +# DUMPIR: %[[C128_3:.*]] = arith.constant 128 : index +# DUMPIR: %[[DIV2:.*]] = arith.divui %[[TID_X_2]], %[[C128_3]] : index +# DUMPIR: %[[C0_3:.*]] = arith.constant 0 : index +# DUMPIR: %[[IS_CONSUMER:.*]] = arith.cmpi eq, %[[DIV2]], %[[C0_3]] : index +# DUMPIR: %[[TID_X_3:.*]] = gpu.thread_id x +# DUMPIR: %[[MBAR_MMA:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> +# DUMPIR: %[[MBAR_TMA:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> +# DUMPIR: %[[C0_4:.*]] = arith.constant 0 : index +# DUMPIR: %[[IS_THREAD0:.*]] = arith.cmpi eq, %[[TID_X_3]], %[[C0_4]] : index +# DUMPIR: scf.if %[[IS_THREAD0]] { +# DUMPIR: %[[C0_INIT:.*]] = arith.constant 0 : index +# DUMPIR: %[[C7:.*]] = arith.constant 7 : index +# DUMPIR: %[[C1_INIT:.*]] = arith.constant 1 : index +# DUMPIR: scf.for %arg15 = %[[C0_INIT]] to %[[C7]] step %[[C1_INIT]] { +# DUMPIR: %[[C1_INIT_VAL:.*]] = arith.constant 1 : index +# DUMPIR: nvgpu.mbarrier.init %[[MBAR_MMA]][%arg15], %[[C1_INIT_VAL]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> +# DUMPIR: %[[C1_INIT_VAL_2:.*]] = arith.constant 1 : index +# DUMPIR: nvgpu.mbarrier.init %[[MBAR_TMA]][%arg15], %[[C1_INIT_VAL_2]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> +# DUMPIR: } +# DUMPIR: nvgpu.tma.prefetch.descriptor %{{.*}} : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> +# DUMPIR: nvgpu.tma.prefetch.descriptor %{{.*}} : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> +# DUMPIR: } +# DUMPIR: scf.if %[[IS_PRODUCER]] { +# DUMPIR: nvvm.setmaxregister decrease 40 +# DUMPIR: %[[TRUE:.*]] = arith.constant true +# DUMPIR: %[[C0_PROD:.*]] = arith.constant 0 : index +# DUMPIR: %[[C16:.*]] = arith.constant 16 : index +# DUMPIR: %[[C1_PROD:.*]] = arith.constant 1 : index +# DUMPIR: %[[PROD_LOOP:.*]] = scf.for %arg15 = %[[C0_PROD]] to %[[C16]] step %[[C1_PROD]] iter_args(%arg16 = %[[TRUE]]) -> (i1) { +# DUMPIR: %[[C7_PROD:.*]] = arith.constant 7 : index +# DUMPIR: %[[SLOT:.*]] = arith.remui %arg15, %[[C7_PROD]] : index +# DUMPIR: %[[TIMEOUT:.*]] = arith.constant 10000000 : index +# DUMPIR: nvgpu.mbarrier.try_wait.parity %[[MBAR_MMA]][%[[SLOT]]], %arg16, %[[TIMEOUT]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> +# DUMPIR: %[[C6:.*]] = arith.constant 6 : index +# DUMPIR: %[[IS_LAST:.*]] = arith.cmpi eq, %[[SLOT]], %[[C6]] : index +# DUMPIR: %[[TRUE_2:.*]] = arith.constant true +# DUMPIR: %[[FLIP:.*]] = arith.xori %arg16, %[[TRUE_2]] : i1 +# DUMPIR: %[[PHASE:.*]] = arith.select %[[IS_LAST]], %[[FLIP]], %arg16 : i1 +# DUMPIR: %[[BID_X:.*]] = gpu.block_id x +# DUMPIR: %[[BID_Y:.*]] = gpu.block_id y +# DUMPIR: %[[C128_TILE:.*]] = arith.constant 128 : index +# DUMPIR: %[[DIM_X:.*]] = arith.muli %[[BID_X]], %[[C128_TILE]] : index +# DUMPIR: %[[C128_TILE_2:.*]] = arith.constant 128 : index +# DUMPIR: %[[DIM_Y:.*]] = arith.muli %[[BID_Y]], %[[C128_TILE_2]] : index +# DUMPIR: %[[TID_PROD:.*]] = gpu.thread_id x +# DUMPIR: %[[C16384:.*]] = arith.constant 16384 : index +# DUMPIR: %[[OFF_A:.*]] = arith.muli %[[SLOT]], %[[C16384]] : index +# DUMPIR: %[[C16384_2:.*]] = arith.constant 16384 : index +# DUMPIR: %[[OFF_B_BASE:.*]] = arith.muli %[[SLOT]], %[[C16384_2]] : index +# DUMPIR: %[[C114688:.*]] = arith.constant 114688 : index +# DUMPIR: %[[OFF_B1:.*]] = arith.addi %[[OFF_B_BASE]], %[[C114688]] : index +# DUMPIR: %[[C8192:.*]] = arith.constant 8192 : index +# DUMPIR: %[[OFF_B2:.*]] = arith.addi %[[OFF_B1]], %[[C8192]] : index +# DUMPIR: %[[SMEM:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[VIEW_A:.*]] = memref.view %[[SMEM]][%[[OFF_A]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[SMEM_2:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[VIEW_B1:.*]] = memref.view %[[SMEM_2]][%[[OFF_B1]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[SMEM_3:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[VIEW_B2:.*]] = memref.view %[[SMEM_3]][%[[OFF_B2]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[TX_COUNT:.*]] = arith.constant 32768 : index +# DUMPIR: nvgpu.mbarrier.arrive.expect_tx %[[MBAR_TMA]][%[[SLOT]]], %[[TX_COUNT]], predicate = %[[IS_PRIMARY]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> +# DUMPIR: %[[C128_WG:.*]] = arith.constant 128 : index +# DUMPIR: %[[TID_MOD:.*]] = arith.remui %[[TID_PROD]], %[[C128_WG]] : index +# DUMPIR: %[[C0_TMA:.*]] = arith.constant 0 : index +# DUMPIR: %[[IS_TMA_THREAD:.*]] = arith.cmpi eq, %[[TID_MOD]], %[[C0_TMA]] : index +# DUMPIR: %[[C64:.*]] = arith.constant 64 : index +# DUMPIR: %[[K_COORD:.*]] = arith.muli %arg15, %[[C64]] : index +# DUMPIR: nvgpu.tma.async.load %{{.*}}[%[[K_COORD]], %[[DIM_X]]], %[[MBAR_TMA]][%[[SLOT]]] to %[[VIEW_A]], predicate = %[[IS_TMA_THREAD]] : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<128x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: nvgpu.tma.async.load %{{.*}}[%[[DIM_Y]], %[[K_COORD]]], %[[MBAR_TMA]][%[[SLOT]]] to %[[VIEW_B1]], predicate = %[[IS_TMA_THREAD]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[C64_OFF:.*]] = arith.constant 64 : index +# DUMPIR: %[[DIM_Y_OFF:.*]] = arith.addi %[[DIM_Y]], %[[C64_OFF]] : index +# DUMPIR: nvgpu.tma.async.load %{{.*}}[%[[DIM_Y_OFF]], %[[K_COORD]]], %[[MBAR_TMA]][%[[SLOT]]] to %[[VIEW_B2]], predicate = %[[IS_TMA_THREAD]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: scf.yield %[[PHASE]] : i1 +# DUMPIR: } +# DUMPIR: } +# DUMPIR: scf.if %[[IS_CONSUMER]] { +# DUMPIR: nvvm.setmaxregister increase 232 +# DUMPIR: %[[FALSE:.*]] = arith.constant false +# DUMPIR: %[[ACC_INIT:.*]] = nvgpu.warpgroup.mma.init.accumulator -> <fragmented = vector<128x128xf32>> +# DUMPIR: %[[C0_CONS:.*]] = arith.constant 0 : index +# DUMPIR: %[[C16_CONS:.*]] = arith.constant 16 : index +# DUMPIR: %[[C1_CONS:.*]] = arith.constant 1 : index +# DUMPIR: %[[CONS_LOOP:.*]]:2 = scf.for %arg15 = %[[C0_CONS]] to %[[C16_CONS]] step %[[C1_CONS]] iter_args(%arg16 = %[[ACC_INIT]], %arg17 = %[[FALSE]]) -> (!nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>, i1) { +# DUMPIR: %[[C7_CONS:.*]] = arith.constant 7 : index +# DUMPIR: %[[SLOT_CONS:.*]] = arith.remui %arg15, %[[C7_CONS]] : index +# DUMPIR: %[[TIMEOUT_CONS:.*]] = arith.constant 10000000 : index +# DUMPIR: nvgpu.mbarrier.try_wait.parity %[[MBAR_TMA]][%[[SLOT_CONS]]], %arg17, %[[TIMEOUT_CONS]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> +# DUMPIR: %[[C16384_CONS:.*]] = arith.constant 16384 : index +# DUMPIR: %[[OFF_A_CONS:.*]] = arith.muli %[[SLOT_CONS]], %[[C16384_CONS]] : index +# DUMPIR: %[[C114688_CONS:.*]] = arith.constant 114688 : index +# DUMPIR: %[[OFF_B_CONS:.*]] = arith.addi %[[OFF_A_CONS]], %[[C114688_CONS]] : index +# DUMPIR: %[[SMEM_CONS:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[VIEW_A_CONS:.*]] = memref.view %[[SMEM_CONS]][%[[OFF_A_CONS]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[SMEM_CONS_2:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[VIEW_B_CONS:.*]] = memref.view %[[SMEM_CONS_2]][%[[OFF_B_CONS]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x128xf16, #gpu.address_space<workgroup>> +# DUMPIR: %[[DESC_A:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW_A_CONS]], %{{.*}} : memref<128x64xf16, #gpu.address_space<workgroup>>, <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>> +# DUMPIR: %[[DESC_B:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW_B_CONS]], %{{.*}} : memref<64x128xf16, #gpu.address_space<workgroup>>, <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>> +# DUMPIR: %[[ACC:.*]] = nvgpu.warpgroup.mma %[[DESC_A]], %[[DESC_B]], %arg16 {transposeB} : <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>, <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>> +# DUMPIR: %[[C0_CMP:.*]] = arith.constant 0 : index +# DUMPIR: %[[IS_NOT_FIRST:.*]] = arith.cmpi ugt, %arg15, %[[C0_CMP]] : index +# DUMPIR: %[[ARRIVE_PRED:.*]] = arith.andi %[[IS_NOT_FIRST]], %[[IS_PRIMARY_2]] : i1 +# DUMPIR: scf.if %[[ARRIVE_PRED]] { +# DUMPIR: %[[C0_ARR:.*]] = arith.constant 0 : index +# DUMPIR: %[[IS_ZERO:.*]] = arith.cmpi eq, %[[SLOT_CONS]], %[[C0_ARR]] : index +# DUMPIR: %[[C6_WRAP:.*]] = arith.constant 6 : index +# DUMPIR: %[[C1_SUB:.*]] = arith.constant 1 : index +# DUMPIR: %[[PREV_SLOT:.*]] = arith.subi %[[SLOT_CONS]], %[[C1_SUB]] : index +# DUMPIR: %[[BARR_ID:.*]] = arith.select %[[IS_ZERO]], %[[C6_WRAP]], %[[PREV_SLOT]] : index +# DUMPIR: %{{.*}} = nvgpu.mbarrier.arrive %[[MBAR_MMA]][%[[BARR_ID]]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> !nvgpu.mbarrier.token +# DUMPIR: } +# DUMPIR: %[[C6_LAST:.*]] = arith.constant 6 : index +# DUMPIR: %[[IS_LAST_CONS:.*]] = arith.cmpi eq, %[[SLOT_CONS]], %[[C6_LAST]] : index +# DUMPIR: %[[TRUE_CONS:.*]] = arith.constant true +# DUMPIR: %[[FLIP_CONS:.*]] = arith.xori %arg17, %[[TRUE_CONS]] : i1 +# DUMPIR: %[[PHASE_CONS:.*]] = arith.select %[[IS_LAST_CONS]], %[[FLIP_CONS]], %arg17 : i1 +# DUMPIR: scf.yield %[[ACC]], %[[PHASE_CONS]] : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>, i1 +# DUMPIR: } +# DUMPIR: nvvm.wgmma.wait.group.sync.aligned 0 +# DUMPIR: %[[TID_EPI:.*]] = gpu.thread_id x +# DUMPIR: %[[BID_X_EPI:.*]] = gpu.block_id x +# DUMPIR: %[[BID_Y_EPI:.*]] = gpu.block_id y +# DUMPIR: %[[C128_EPI:.*]] = arith.constant 128 : index +# DUMPIR: %[[DIM_X_EPI:.*]] = arith.muli %[[BID_X_EPI]], %[[C128_EPI]] : index +# DUMPIR: %[[C128_EPI_2:.*]] = arith.constant 128 : index +# DUMPIR: %[[DIM_Y_EPI:.*]] = arith.muli %[[BID_Y_EPI]], %[[C128_EPI_2]] : index +# DUMPIR: %[[SMEM_EPI:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>> +# DUMPIR: %[[C0_EPI:.*]] = arith.constant 0 : index +# DUMPIR: %[[VIEW_EPI:.*]] = memref.view %[[SMEM_EPI]][%[[C0_EPI]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x128xf32, #gpu.address_space<workgroup>> +# DUMPIR: %[[SUBVIEW:.*]] = memref.subview %{{.*}}[%[[DIM_X_EPI]], %[[DIM_Y_EPI]]] [128, 128] [1, 1] : memref<512x256xf32> to memref<128x128xf32, strided<[256, 1], offset: ?>> +# DUMPIR: nvgpu.warpgroup.mma.store %[[CONS_LOOP]]#0, %[[VIEW_EPI]] : <fragmented = vector<128x128xf32>> to memref<128x128xf32, #gpu.address_space<workgroup>> +# DUMPIR: gpu.barrier +# DUMPIR: %[[C0_STORE:.*]] = arith.constant 0 : index +# DUMPIR: %[[C128_STORE:.*]] = arith.constant 128 : index +# DUMPIR: %[[C1_STORE:.*]] = arith.constant 1 : index +# DUMPIR: scf.for %arg15 = %[[C0_STORE]] to %[[C128_STORE]] step %[[C1_STORE]] { +# DUMPIR: %{{.*}} = memref.load %[[VIEW_EPI]][%arg15, %[[TID_EPI]]] : memref<128x128xf32, #gpu.address_space<workgroup>> +# DUMPIR: memref.store %{{.*}}, %[[SUBVIEW]][%arg15, %[[TID_EPI]]] : memref<128x128xf32, strided<[256, 1], offset: ?>> +# DUMPIR: } +# DUMPIR: } +# DUMPIR: gpu.terminator diff --git a/mlir/test/Examples/NVGPU/lit.local.cfg b/mlir/test/Examples/NVGPU/lit.local.cfg index 689cd25..af44b2e 100644 --- a/mlir/test/Examples/NVGPU/lit.local.cfg +++ b/mlir/test/Examples/NVGPU/lit.local.cfg @@ -1,4 +1,4 @@ config.unsupported = False -if not config.enable_cuda_runner or not config.mlir_run_cuda_sm90_tests: +if not config.enable_cuda_runner or not config.enable_bindings_python: config.unsupported = True
\ No newline at end of file diff --git a/mlir/test/Examples/NVGPU/tools/nvdsl.py b/mlir/test/Examples/NVGPU/tools/nvdsl.py index 90dbb23..8561072 100644 --- a/mlir/test/Examples/NVGPU/tools/nvdsl.py +++ b/mlir/test/Examples/NVGPU/tools/nvdsl.py @@ -9,6 +9,7 @@ from mlir import runtime as rt from tools import nvgpucompiler MLIR_DYNAMIC = -9223372036854775808 +DUMP_ONLY = os.getenv("MLIR_NVDSL_PRINT_IR") == "1" def const(value: int, ty=None): @@ -84,9 +85,7 @@ class Mbarriers: self.mbar_group_op, txcount_op, self.id_op, predicate=predicate ) else: - nvgpu.mbarrier_arrive( - ir.Type.parse("!nvgpu.mbarrier.token"), self.mbar_group_op, self.id_op - ) + nvgpu.mbarrier_arrive(self.mbar_group_op, self.id_op) def try_wait(self, phase: bool = False, ticks: int = 10000000): ticks_op = const(ticks) @@ -144,7 +143,9 @@ class TMA: device_ptr, ) self.tma_descriptor = nvgpu.TmaCreateDescriptorOp( - tma_descriptor_ty, device_unranked_memref, map(const, self.tma_box_shape) + tma_descriptor_ty, + device_unranked_memref, + list(map(const, self.tma_box_shape)), ) return self.tma_descriptor.result @@ -156,7 +157,7 @@ class TMA: dest, mbarrier.mbar_group_op, self.tma_descriptor, - coordinates=map(const, coords), + coordinates=list(map(const, coords)), mbarId=mbarrier.id_op, predicate=predicate, ) @@ -310,13 +311,10 @@ class NVDSL: @functools.wraps(func) def wrapper(*args, **kwargs): launch_op = gpu.LaunchOp( - None, - [], - *map(const, grid), - *map(const, block), - dynamicSharedMemorySize=arith.constant(T.i32(), smem), + grid_size=grid, + block_size=block, + dynamic_shared_memory_size=arith.constant(T.i32(), smem), ) - launch_op.body.blocks.append(*([T.index()] * 12)) with ir.InsertionPoint(launch_op.body.blocks[0]): result = func(*args, **kwargs) gpu.terminator() @@ -334,13 +332,11 @@ class NVDSL: def saveIR(module): """Save generated IR""" - if True: # self.saveIR: - # print(mlir_nvgpu_module) - original_stdout = sys.stdout - with open("nvdsl.mlir", "w") as f: - sys.stdout = f - print(module) - sys.stdout = original_stdout + original_stdout = sys.stdout + with open("nvdsl.mlir", "w") as f: + sys.stdout = f + print(module) + sys.stdout = original_stdout def _binary_op(lhs, rhs, op: str, predAtt="") -> "ArithValue": """Generate MLIR's Arith dialects binary operations.""" @@ -429,6 +425,9 @@ class NVDSL: # Save IR in a file # saveIR(module) + if DUMP_ONLY: + print(module) + return 0 # Verify the module module.operation.verify() diff --git a/mlir/test/Examples/NVGPU/tools/nvgpucompiler.py b/mlir/test/Examples/NVGPU/tools/nvgpucompiler.py index 1c9cc74..4b661f8 100644 --- a/mlir/test/Examples/NVGPU/tools/nvgpucompiler.py +++ b/mlir/test/Examples/NVGPU/tools/nvgpucompiler.py @@ -35,9 +35,11 @@ class NvgpuCompiler: def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine: """Wraps the module in a JIT execution engine.""" - return execution_engine.ExecutionEngine( + ee = execution_engine.ExecutionEngine( module, opt_level=self.opt_level, shared_libs=self.shared_libs ) + ee.initialize() + return ee def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine: """Compiles and jits the module.""" diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir index 0c5fec8c..2f5dd28 100644 --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -145,3 +145,11 @@ func.func @verify_fail_3() { %r = "arith.constant"() {value = -3 : si32} : () -> si32 return } + +// ----- + +// Verify that symbols with results are rejected +module { + // expected-error@+1 {{'test.symbol_with_result' op symbols must not have results}} + %0 = "test.symbol_with_result"() <{sym_name = "test_symbol"}> : () -> i32 +} diff --git a/mlir/test/IR/locations.mlir b/mlir/test/IR/locations.mlir index b725307..2e23746 100644 --- a/mlir/test/IR/locations.mlir +++ b/mlir/test/IR/locations.mlir @@ -105,3 +105,10 @@ func.func @dialect_location() { test.attr_with_loc("dialectLoc" loc(#test.custom_location<"foo.mlir"*32>)) return } + +// CHECK-LABEL: @location_attr +// CHECK: test.op_with_loc_attr loc("loc1":10:20) {foo.discardable_loc_attr = loc("loc2":20:30)} loc({{.*}}locations.mlir":[[# @LINE+2]]:3) +func.func @location_attr() { + test.op_with_loc_attr loc("loc1":10:20) {foo.discardable_loc_attr = loc("loc2":20:30)} + return +} diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation-vector.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation-vector.mlir new file mode 100644 index 0000000..5f8b2f4 --- /dev/null +++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation-vector.mlir @@ -0,0 +1,26 @@ +// REQUIRES: system-linux +// TODO: Run only on Linux until we figure out how to build +// mlir_apfloat_wrappers in a platform-independent way. + +// All floating-point arithmetics is lowered through APFloat. +// RUN: mlir-opt %s --convert-arith-to-apfloat --convert-vector-to-scf \ +// RUN: --convert-scf-to-cf --convert-to-llvm | \ +// RUN: mlir-runner -e entry --entry-point-result=void \ +// RUN: --shared-libs=%mlir_c_runner_utils \ +// RUN: --shared-libs=%mlir_apfloat_wrappers | FileCheck %s + +// Put rhs into separate function so that it won't be constant-folded. +func.func @foo_vec() -> (vector<4xf8E4M3FN>, vector<4xf32>) { + %cst1 = arith.constant dense<[2.2, 2.2, 2.2, 2.2]> : vector<4xf8E4M3FN> + %cst2 = arith.constant dense<[2.2, 2.2, 2.2, 2.2]> : vector<4xf32> + return %cst1, %cst2 : vector<4xf8E4M3FN>, vector<4xf32> +} + +func.func @entry() { + // CHECK: ( 3.5, 3.5, 3.5, 3.5 ) + %a1_vec = arith.constant dense<[1.4, 1.4, 1.4, 1.4]> : vector<4xf8E4M3FN> + %b1_vec, %b2_vec = func.call @foo_vec() : () -> (vector<4xf8E4M3FN>, vector<4xf32>) + %c1_vec = arith.addf %a1_vec, %b1_vec : vector<4xf8E4M3FN> // not supported by LLVM + vector.print %c1_vec : vector<4xf8E4M3FN> + return +} diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir new file mode 100644 index 0000000..7f72dd5 --- /dev/null +++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir @@ -0,0 +1,82 @@ +// REQUIRES: system-linux +// TODO: Run only on Linux until we figure out how to build +// mlir_apfloat_wrappers in a platform-independent way. + +// Case 1: All floating-point arithmetics is lowered through APFloat. +// RUN: mlir-opt %s --convert-arith-to-apfloat --convert-to-llvm | \ +// RUN: mlir-runner -e entry --entry-point-result=void \ +// RUN: --shared-libs=%mlir_c_runner_utils \ +// RUN: --shared-libs=%mlir_apfloat_wrappers | FileCheck %s + +// Case 2: Only unsupported arithmetics (f8E4M3FN) is lowered through APFloat. +// Arithmetics on f32 is lowered directly to LLVM. +// RUN: mlir-opt %s --convert-to-llvm --convert-arith-to-apfloat \ +// RUN: --convert-to-llvm --reconcile-unrealized-casts | \ +// RUN: mlir-runner -e entry --entry-point-result=void \ +// RUN: --shared-libs=%mlir_c_runner_utils \ +// RUN: --shared-libs=%mlir_apfloat_wrappers | FileCheck %s + +// Put rhs into separate function so that it won't be constant-folded. +func.func @foo() -> (f8E4M3FN, f32) { + %cst1 = arith.constant 2.2 : f8E4M3FN + %cst2 = arith.constant 2.2 : f32 + return %cst1, %cst2 : f8E4M3FN, f32 +} + +func.func @entry() { + %a1 = arith.constant 1.4 : f8E4M3FN + %a2 = arith.constant 1.4 : f32 + %b1, %b2 = func.call @foo() : () -> (f8E4M3FN, f32) + + // CHECK: 2.2 + vector.print %b2 : f32 + + // CHECK-NEXT: 3.5 + %c1 = arith.addf %a1, %b1 : f8E4M3FN // not supported by LLVM + vector.print %c1 : f8E4M3FN + + // CHECK-NEXT: 3.6 + %c2 = arith.addf %a2, %b2 : f32 // supported by LLVM + vector.print %c2 : f32 + + // CHECK-NEXT: 2.25 + %cvt = arith.truncf %b2 : f32 to f8E4M3FN + vector.print %cvt : f8E4M3FN + + // CHECK-NEXT: -2.25 + %negated = arith.negf %cvt : f8E4M3FN + vector.print %negated : f8E4M3FN + + // CHECK-NEXT: -2.25 + %min = arith.minimumf %cvt, %negated : f8E4M3FN + vector.print %min : f8E4M3FN + + // CHECK-NEXT: 1 + %cmp1 = arith.cmpf "olt", %cvt, %c1 : f8E4M3FN + vector.print %cmp1 : i1 + + // CHECK-NEXT: 1 + // Bit pattern: 01, interpreted as signed integer: 1 + %cvt_int_signed = arith.fptosi %cvt : f8E4M3FN to i2 + vector.print %cvt_int_signed : i2 + + // CHECK-NEXT: -2 + // Bit pattern: 10, interpreted as signed integer: -2 + %cvt_int_unsigned = arith.fptoui %cvt : f8E4M3FN to i2 + vector.print %cvt_int_unsigned : i2 + + // CHECK-NEXT: -6 + // Bit pattern: 1...11110111, interpreted as signed: -9 + // Closest f4E2M1FN value: -6.0 + %c9 = arith.constant -9 : i16 + %cvt_from_signed_int = arith.sitofp %c9 : i16 to f4E2M1FN + vector.print %cvt_from_signed_int : f4E2M1FN + + // CHECK-NEXT: 6 + // Bit pattern: 1...11110111, interpreted as unsigned: 65527 + // Closest f4E2M1FN value: 6.0 + %cvt_from_unsigned_int = arith.uitofp %c9 : i16 to f4E2M1FN + vector.print %cvt_from_unsigned_int : f4E2M1FN + + return +} diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir index 9d04357..d26853d 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir @@ -22,7 +22,7 @@ func.func @matmul_transpose_a(%A : tensor<?x?xf32>, %B : tensor<?x?xf32>, %C : t } func.func @main() { - %c0 = arith.constant 0 : i32 + %c0 = arith.constant 0.0 : f32 %c7 = arith.constant 7 : index %A = arith.constant dense<[ @@ -44,7 +44,7 @@ func.func @main() { %A_dyn = tensor.cast %A : tensor<13x7xf32> to tensor<?x?xf32> %C_init = bufferization.alloc_tensor(%c7, %c7) : tensor<?x?xf32> - %C = linalg.fill ins(%c0 : i32) outs(%C_init : tensor<?x?xf32>) -> tensor<?x?xf32> + %C = linalg.fill ins(%c0 : f32) outs(%C_init : tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: Unranked Memref {{.*}} rank = 2 offset = 0 sizes = [7, 7] strides = [7, 1] data = // CHECK: [32955, 33514, 34073, 34632, 35191, 35750, 36309] diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir index ad7dbb9..e2c0f1d2 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir @@ -16,7 +16,7 @@ func.func @matmul(%A : tensor<?x?xf32>, %B : tensor<?x?xf32>, %C : tensor<?x?xf3 } func.func @main() { - %c0 = arith.constant 0 : i32 + %c0 = arith.constant 0.0 : f32 %c7 = arith.constant 7 : index %A = arith.constant dense<[ @@ -37,7 +37,7 @@ func.func @main() { %B_dyn = tensor.cast %B : tensor<13x7xf32> to tensor<?x?xf32> %C_init = bufferization.alloc_tensor(%c7, %c7) : tensor<?x?xf32> - %C = linalg.fill ins(%c0 : i32) outs(%C_init : tensor<?x?xf32>) -> tensor<?x?xf32> + %C = linalg.fill ins(%c0 : f32) outs(%C_init : tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: Unranked Memref {{.*}} rank = 2 offset = 0 sizes = [7, 7] strides = [7, 1] data = // CHECK: [32955, 33514, 34073, 34632, 35191, 35750, 36309] diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir index 243f9e5..007189a 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir @@ -29,7 +29,7 @@ func.func @main() { %c128 = arith.constant 128 : i32 func.call @setArmSVLBits(%c128) : (i32) -> () - %c0 = arith.constant 0 : i32 + %c0 = arith.constant 0.0 : f32 %c7 = arith.constant 7 : index %A = arith.constant dense<[ @@ -50,7 +50,7 @@ func.func @main() { %B_dyn = tensor.cast %B : tensor<13x7xf32> to tensor<?x?xf32> %C_init = bufferization.alloc_tensor(%c7, %c7) : tensor<?x?xf32> - %C = linalg.fill ins(%c0 : i32) outs(%C_init : tensor<?x?xf32>) -> tensor<?x?xf32> + %C = linalg.fill ins(%c0 : f32) outs(%C_init : tensor<?x?xf32>) -> tensor<?x?xf32> // CHECK: Unranked Memref {{.*}} rank = 2 offset = 0 sizes = [7, 7] strides = [7, 1] data = // CHECK: [32955, 33514, 34073, 34632, 35191, 35750, 36309] diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir index 127ab70..c90476e 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir @@ -24,17 +24,14 @@ func.func @main() { %d5x = tensor.cast %c5x : tensor<5xf32> to tensor<?xf32> %d4x = tensor.cast %c4x : tensor<4xf32> to tensor<?xf32> - // CHECK-NOT: ERROR: Runtime op verification failed - func.call @simple_add(%d5x, %d5x) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>) - // CHECK: ERROR: Runtime op verification failed - // CHECK: linalg.generic - // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size + // CHECK-NEXT: linalg.generic + // CHECK-NEXT: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size func.call @simple_add(%d5x, %d4x) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>) // CHECK: ERROR: Runtime op verification failed - // CHECK: linalg.generic - // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size + // CHECK-NEXT: linalg.generic + // CHECK-NEXT: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size func.call @simple_add(%d4x, %d5x) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>) %c1x1 = arith.constant dense<0.0> : tensor<1x1xf32> @@ -48,72 +45,82 @@ func.func @main() { %d4x5 = tensor.cast %c4x5 : tensor<4x5xf32> to tensor<?x?xf32> %d5x4 = tensor.cast %c5x4 : tensor<5x4xf32> to tensor<?x?xf32> - // CHECK-NOT: ERROR: Runtime op verification failed - func.call @broadcast_add(%d1x1, %d1x1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) - - // CHECK-NOT: ERROR: Runtime op verification failed - func.call @broadcast_add(%d1x1, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) - - // CHECK-NOT: ERROR: Runtime op verification failed - func.call @broadcast_add(%d4x4, %d1x4) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) + // CHECK: ERROR: Runtime op verification failed + // CHECK-NEXT: linalg.generic + // CHECK-NEXT: ^ dimension #1 of input/output operand #1 is incompatible with inferred dimension size // CHECK: ERROR: Runtime op verification failed - // CHECK: linalg.generic - // CHECK: ^ dimension #1 of input/output operand #1 is incompatible with inferred dimension size + // CHECK-NEXT: linalg.generic + // CHECK-NEXT: ^ dimension #1 of input/output operand #2 is incompatible with inferred dimension size func.call @broadcast_add(%d1x4, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) // CHECK: ERROR: Runtime op verification failed - // CHECK: linalg.generic - // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size + // CHECK-NEXT: linalg.generic + // CHECK-NEXT: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size + // CHECK: ERROR: Runtime op verification failed - // CHECK: linalg.generic - // CHECK: ^ dimension #1 of input/output operand #1 is incompatible with inferred dimension size + // CHECK-NEXT: linalg.generic + // CHECK-NEXT: ^ dimension #1 of input/output operand #1 is incompatible with inferred dimension size + // CHECK: ERROR: Runtime op verification failed - // CHECK: linalg.generic - // CHECK: ^ dimension #1 of input/output operand #2 is incompatible with inferred dimension size + // CHECK-NEXT: linalg.generic + // CHECK-NEXT: ^ dimension #1 of input/output operand #2 is incompatible with inferred dimension size func.call @broadcast_add(%d5x4, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) - // CHECK-NOT: ERROR: Runtime op verification failed - func.call @matmul_generic(%d5x4, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) - // CHECK: ERROR: Runtime op verification failed - // CHECK: linalg.generic - // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size + // CHECK-NEXT: linalg.generic + // CHECK-NEXT: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size func.call @matmul_generic(%d4x5, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) - // CHECK-NOT: ERROR: Runtime op verification failed - func.call @matmul_named(%d5x4, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) - // CHECK: ERROR: Runtime op verification failed - // CHECK: linalg.matmul - // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size + // CHECK-NEXT: linalg.matmul + // CHECK-NEXT: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size func.call @matmul_named(%d4x5, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) %c64x57 = arith.constant dense<0.0> : tensor<16x29xf32> %c3x4 = arith.constant dense<0.0> : tensor<3x4xf32> - // CHECK-NOT: ERROR: Runtime op verification failed - func.call @conv(%c64x57, %c3x4) : (tensor<16x29xf32>, tensor<3x4xf32>) -> (tensor<5x7xf32>) - - // CHECK-NOT: ERROR: Runtime op verification failed - func.call @reverse_from_3(%d4x) : (tensor<?xf32>) -> (tensor<?xf32>) - // CHECK: ERROR: Runtime op verification failed - // CHECK: linalg.generic - // CHECK: unexpected negative result on dimension #0 of input/output operand #0 + // CHECK-NEXT: linalg.generic + // CHECK-NEXT: unexpected negative result on dimension #0 of input/output operand #0 func.call @reverse_from_3(%d5x) : (tensor<?xf32>) -> (tensor<?xf32>) %c0x = arith.constant dense<1.0> : tensor<0xf32> %d0x = tensor.cast %c0x : tensor<0xf32> to tensor<?xf32> - // CHECK-NOT: ERROR: Runtime op verification failed - func.call @fill_empty_1d(%d0x) : (tensor<?xf32>) -> (tensor<?xf32>) %c0x5 = arith.constant dense<0.0> : tensor<0x5xf32> %d0x5 = tensor.cast %c0x5 : tensor<0x5xf32> to tensor<?x?xf32> // CHECK-NOT: ERROR: Runtime op verification failed + func.call @fill_empty_1d(%d0x) : (tensor<?xf32>) -> (tensor<?xf32>) + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @simple_add(%d5x, %d5x) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>) + + // CHECK-NOT: ERROR: Runtime op verification failed func.call @fill_empty_2d(%d0x5) : (tensor<?x?xf32>) -> (tensor<?x?xf32>) + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @conv(%c64x57, %c3x4) : (tensor<16x29xf32>, tensor<3x4xf32>) -> (tensor<5x7xf32>) + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @reverse_from_3(%d4x) : (tensor<?xf32>) -> (tensor<?xf32>) + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @matmul_named(%d5x4, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @matmul_generic(%d5x4, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @broadcast_add(%d1x1, %d1x1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @broadcast_add(%d1x1, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @broadcast_add(%d4x4, %d1x4) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) + return } diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir index 8fa32d7..bbda8d4e 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir @@ -27,8 +27,8 @@ func.func @main() { %A_dyn = tensor.cast %A : tensor<8x2xf32> to tensor<?x?xf32> %B_dyn = tensor.cast %B : tensor<2x4xf32> to tensor<?x?xf32> - %c0_i32 = arith.constant 0 : i32 - %C_init = linalg.fill ins(%c0_i32 : i32) outs(%C_dyn : tensor<?x?xf32>) -> tensor<?x?xf32> + %c0_f32 = arith.constant 0.0 : f32 + %C_init = linalg.fill ins(%c0_f32 : f32) outs(%C_dyn : tensor<?x?xf32>) -> tensor<?x?xf32> %res = linalg.matmul ins(%A_dyn, %B_dyn: tensor<?x?xf32>, tensor<?x?xf32>) outs(%C_init: tensor<?x?xf32>) -> tensor<?x?xf32> diff --git a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir index 8487567..09cfee1 100644 --- a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir @@ -50,6 +50,17 @@ func.func @subview_zero_size_dim(%memref: memref<10x4x1xf32, strided<[?, ?, ?], return } +func.func @subview_with_empty_slice(%memref: memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>, + %dim_0: index, + %dim_1: index, + %dim_2: index, + %offset: index) { + %subview = memref.subview %memref[%offset, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1] : + memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>> to + memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> + return +} + func.func @main() { %0 = arith.constant 0 : index @@ -127,5 +138,9 @@ func.func @main() { func.call @subview_zero_size_dim(%alloca_10x4x1_dyn_stride, %dim_0, %dim_1, %dim_2) : (memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>, index, index, index) -> () + // CHECK-NOT: ERROR: Runtime op verification failed + %offset = arith.constant 10 : index + func.call @subview_with_empty_slice(%alloca_10x4x1_dyn_stride, %dim_0, %dim_1, %dim_2, %offset) + : (memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>, index, index, index, index) -> () return } diff --git a/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir b/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir index a77fa31..745eea3 100644 --- a/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir @@ -39,6 +39,11 @@ func.func @extract_slice_zero_size_dim(%arg0: tensor<10x4x1xf32>, %dim_0: index, return } +func.func @extract_slice_empty_tensor(%arg0: tensor<10x4x1xf32>, %dim_0: index, %dim_1: index, %dim_2: index, %offset: index) { + tensor.extract_slice %arg0[%offset, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1] : tensor<10x4x1xf32> to tensor<?x?x?xf32> + return +} + func.func @main() { %0 = arith.constant 0 : index @@ -115,5 +120,9 @@ func.func @main() { %dim_2 = arith.constant 1 : index func.call @extract_slice_zero_size_dim(%cst10x4x1xf32, %dim_0, %dim_1, %dim_2) : (tensor<10x4x1xf32>, index, index, index) -> () + // CHECK-NOT: ERROR: Runtime op verification failed + %offset = arith.constant 10 : index + func.call @extract_slice_empty_tensor(%cst10x4x1xf32, %dim_0, %dim_1, %dim_2, %offset) : (tensor<10x4x1xf32>, index, index, index, index) -> () + return } diff --git a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir index a374d9a..e3fee91 100644 --- a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir +++ b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir @@ -63,11 +63,11 @@ func.func @matmul_simple(%lhs: tensor<10x20xf16>, %rhs: tensor<20x15xf32>) -> te } func.func @matmul_with_extra_ops_in_func(%lhs: tensor<10x20xf32>, %rhs: tensor<20x15xf32>) -> tensor<10x15xf32> { - %cst = arith.constant 0.0 : f64 + %cst = arith.constant 0.0 : f32 %empty = tensor.empty() : tensor<10x15xf32> // expected-remark @below {{fill}} - %fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<10x15xf32>) -> tensor<10x15xf32> + %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<10x15xf32>) -> tensor<10x15xf32> %real_lhs = linalg.mul ins(%lhs, %lhs : tensor<10x20xf32>, tensor<10x20xf32>) outs(%lhs : tensor<10x20xf32>) -> tensor<10x20xf32> diff --git a/mlir/test/Integration/Dialect/XeGPU/LANE/load_store_subview.mlir b/mlir/test/Integration/Dialect/XeGPU/LANE/load_store_subview.mlir new file mode 100644 index 0000000..c4608ac --- /dev/null +++ b/mlir/test/Integration/Dialect/XeGPU/LANE/load_store_subview.mlir @@ -0,0 +1,63 @@ +// RUN: mlir-opt %s --gpu-lower-to-xevm-pipeline="xegpu-op-level=lane" \ +// RUN: | mlir-runner \ +// RUN: --shared-libs=%mlir_levelzero_runtime \ +// RUN: --shared-libs=%mlir_runner_utils \ +// RUN: --entry-point-result=void \ +// RUN: | FileCheck %s + +module @subview attributes {gpu.container_module} { + gpu.module @kernel { + gpu.func @subview(%src: memref<256xf32>, %dst: memref<256xf32>) kernel { + %src_subview = memref.subview %src[5] [251] [1] : memref<256xf32> to memref<251xf32, strided<[1], offset: 5>> + %dst_subview = memref.subview %dst[10] [246] [1] : memref<256xf32> to memref<246xf32, strided<[1], offset: 10>> + %lane_id = gpu.lane_id + %mask = arith.constant 1 : i1 + %loaded = xegpu.load %src_subview[%lane_id], %mask : memref<251xf32, strided<[1], offset: 5>>, index, i1 -> f32 + xegpu.store %loaded, %dst_subview[%lane_id], %mask : f32, memref<246xf32, strided<[1], offset: 10>>, index, i1 + gpu.return + } + } + func.func @test(%src: memref<256xf32>, %dst: memref<256xf32>) -> memref<256xf32> { + %memref_src = gpu.alloc () : memref<256xf32> + gpu.memcpy %memref_src, %src : memref<256xf32>, memref<256xf32> + %memref_dst = gpu.alloc () : memref<256xf32> + gpu.memcpy %memref_dst, %dst : memref<256xf32>, memref<256xf32> + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + gpu.launch_func @kernel::@subview blocks in (%c1, %c1, %c1) threads in (%c16, %c1, %c1) args(%memref_src : memref<256xf32>, %memref_dst : memref<256xf32>) + gpu.wait // Wait for the kernel to finish. + gpu.memcpy %dst, %memref_dst : memref<256xf32>, memref<256xf32> + gpu.dealloc %memref_src : memref<256xf32> + gpu.dealloc %memref_dst : memref<256xf32> + return %dst : memref<256xf32> + } + func.func @main() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %memref_src = memref.alloc() : memref<256xf32> + %memref_dst = memref.alloc() : memref<256xf32> + // Initialize source memref + scf.for %i = %c0 to %c256 step %c1 { + %val = arith.index_cast %i : index to i32 + %val_float = arith.sitofp %val : i32 to f32 + memref.store %val_float, %memref_src[%i] : memref<256xf32> + } + // Initialize destination memref to zero + scf.for %i = %c0 to %c256 step %c1 { + %zero = arith.constant 0.0 : f32 + memref.store %zero, %memref_dst[%i] : memref<256xf32> + } + // Call test function + %gpu_result = call @test(%memref_src, %memref_dst) : (memref<256xf32>, memref<256xf32>) -> memref<256xf32> + %gpu_result_casted = memref.cast %gpu_result : memref<256xf32> to memref<*xf32> + // CHECK: Unranked Memref base@ = 0x{{[0-9a-f]+}} + // CHECK: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + call @printMemrefF32(%gpu_result_casted) : (memref<*xf32>) -> () + // Deallocate memrefs + memref.dealloc %memref_src : memref<256xf32> + memref.dealloc %memref_dst : memref<256xf32> + return + } + func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} +} diff --git a/mlir/test/Integration/Dialect/XeVM/GPU/gpu_printf.mlir b/mlir/test/Integration/Dialect/XeVM/GPU/gpu_printf.mlir index edf8775..5ed2148 100644 --- a/mlir/test/Integration/Dialect/XeVM/GPU/gpu_printf.mlir +++ b/mlir/test/Integration/Dialect/XeVM/GPU/gpu_printf.mlir @@ -3,7 +3,7 @@ // RUN: | mlir-opt -convert-scf-to-cf -convert-cf-to-llvm -convert-vector-to-llvm -convert-arith-to-llvm \ // RUN: | mlir-opt -gpu-to-llvm -reconcile-unrealized-casts -cse -gpu-module-to-binary \ // RUN: | mlir-runner \ -// RUN: --shared-libs=%mlir_sycl_runtime \ +// RUN: --shared-libs=%mlir_levelzero_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ // RUN: --shared-libs=%mlir_c_runner_utils \ // RUN: --entry-point-result=void \ diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/wmma-matmul-f64.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/wmma-matmul-f64.mlir new file mode 100644 index 0000000..c3dd35b --- /dev/null +++ b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/wmma-matmul-f64.mlir @@ -0,0 +1,73 @@ +// RUN: mlir-opt %s \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_80 cubin-format=%gpu_compilation_format" \ +// RUN: | mlir-runner \ +// RUN: --shared-libs=%mlir_cuda_runtime \ +// RUN: --shared-libs=%mlir_runner_utils \ +// RUN: --entry-point-result=void \ +// RUN: | FileCheck %s + +#map0 = affine_map<(d0, d1) -> (d1, d0)> + +func.func @main() { + %a = memref.alloc() : memref<8x4xf64> + %b = memref.alloc() : memref<4x8xf64> + %c = memref.alloc() : memref<8x8xf64> + %d = memref.alloc() : memref<8x8xf64> + + %f1 = arith.constant 1.0e+00 : f64 + %fcst = arith.constant 3.14e+00 : f64 + %c0 = arith.constant 0 : index + %c8 = arith.constant 8 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + + // Initialize the Input matrixes with ones. + scf.for %arg0 = %c0 to %c8 step %c1 { + scf.for %arg1 = %c0 to %c4 step %c1 { + memref.store %f1, %a[%arg0, %arg1] : memref<8x4xf64> + memref.store %f1, %b[%arg1, %arg0] : memref<4x8xf64> + } + } + // Initialize the accumulator matrix with a constant. + scf.for %arg0 = %c0 to %c8 step %c1 { + scf.for %arg1 = %c0 to %c8 step %c1 { + memref.store %fcst, %c[%arg0, %arg1] : memref<8x8xf64> + } + } + + %2 = memref.cast %a : memref<8x4xf64> to memref<*xf64> + %20 = memref.cast %b : memref<4x8xf64> to memref<*xf64> + %33 = memref.cast %c : memref<8x8xf64> to memref<*xf64> + %34 = memref.cast %d : memref<8x8xf64> to memref<*xf64> + + gpu.host_register %2 : memref<*xf64> + gpu.host_register %20 : memref<*xf64> + gpu.host_register %33 : memref<*xf64> + gpu.host_register %34 : memref<*xf64> + + gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1) + threads(%tx, %ty, %tz) in (%block_x = %c32, %block_y = %c1, %block_z = %c1) { + %A = gpu.subgroup_mma_load_matrix %a[%c0, %c0] {leadDimension = 4 : index} : memref<8x4xf64> -> !gpu.mma_matrix<8x4xf64, "AOp"> + %B = gpu.subgroup_mma_load_matrix %b[%c0, %c0] {leadDimension = 8 : index} : memref<4x8xf64> -> !gpu.mma_matrix<4x8xf64, "BOp"> + %C = gpu.subgroup_mma_load_matrix %c[%c0, %c0] {leadDimension = 8 : index} : memref<8x8xf64> -> !gpu.mma_matrix<8x8xf64, "COp"> + + %R = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<8x4xf64, "AOp">, !gpu.mma_matrix<4x8xf64, "BOp"> -> !gpu.mma_matrix<8x8xf64, "COp"> + + gpu.subgroup_mma_store_matrix %R, %d[%c0, %c0] {leadDimension = 8 : index}: !gpu.mma_matrix<8x8xf64, "COp">, memref<8x8xf64> + gpu.terminator + } + // Print the memref after computation. + call @printMemrefF64(%34) : (memref<*xf64>) -> () + // CHECK: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14], + // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14], + // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14], + // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14], + // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14], + // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14], + // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14], + // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14] + return +} + +func.func private @printMemrefF64(memref<*xf64>) diff --git a/mlir/test/Integration/GPU/CUDA/all-reduce-and.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-and.mlir index 5585d98..d0001f6 100644 --- a/mlir/test/Integration/GPU/CUDA/all-reduce-and.mlir +++ b/mlir/test/Integration/GPU/CUDA/all-reduce-and.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="allow-pattern-rollback=0" \ // RUN: | mlir-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Integration/GPU/CUDA/all-reduce-maxsi.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-maxsi.mlir index cd90ce3..fcff5f4 100644 --- a/mlir/test/Integration/GPU/CUDA/all-reduce-maxsi.mlir +++ b/mlir/test/Integration/GPU/CUDA/all-reduce-maxsi.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \ // RUN: | mlir-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Integration/GPU/CUDA/all-reduce-minsi.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-minsi.mlir index fec2567..4718ac9 100644 --- a/mlir/test/Integration/GPU/CUDA/all-reduce-minsi.mlir +++ b/mlir/test/Integration/GPU/CUDA/all-reduce-minsi.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \ // RUN: | mlir-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Integration/GPU/CUDA/all-reduce-op.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-op.mlir index d5633b0..5e3a7e7e 100644 --- a/mlir/test/Integration/GPU/CUDA/all-reduce-op.mlir +++ b/mlir/test/Integration/GPU/CUDA/all-reduce-op.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \ // RUN: | mlir-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Integration/GPU/CUDA/all-reduce-or.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-or.mlir index db297b0..f1a48ae 100644 --- a/mlir/test/Integration/GPU/CUDA/all-reduce-or.mlir +++ b/mlir/test/Integration/GPU/CUDA/all-reduce-or.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \ // RUN: | mlir-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Integration/GPU/CUDA/all-reduce-region.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-region.mlir index 65cbc79..f0a46ce 100644 --- a/mlir/test/Integration/GPU/CUDA/all-reduce-region.mlir +++ b/mlir/test/Integration/GPU/CUDA/all-reduce-region.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \ // RUN: | mlir-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Integration/GPU/CUDA/all-reduce-xor.mlir b/mlir/test/Integration/GPU/CUDA/all-reduce-xor.mlir index a0c955e..ddbabd4 100644 --- a/mlir/test/Integration/GPU/CUDA/all-reduce-xor.mlir +++ b/mlir/test/Integration/GPU/CUDA/all-reduce-xor.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \ // RUN: | mlir-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Integration/GPU/CUDA/alloc-host-shared.mlir b/mlir/test/Integration/GPU/CUDA/alloc-host-shared.mlir index f041df8..5c56e2d 100644 --- a/mlir/test/Integration/GPU/CUDA/alloc-host-shared.mlir +++ b/mlir/test/Integration/GPU/CUDA/alloc-host-shared.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \ // RUN: | mlir-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Integration/GPU/CUDA/assert.mlir b/mlir/test/Integration/GPU/CUDA/assert.mlir index 71a21cf..83cf70c 100644 --- a/mlir/test/Integration/GPU/CUDA/assert.mlir +++ b/mlir/test/Integration/GPU/CUDA/assert.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \ +// RUN: mlir-opt %s -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \ // RUN: | mlir-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Integration/GPU/CUDA/async.mlir b/mlir/test/Integration/GPU/CUDA/async.mlir index 5acadd6..3e45b5a 100644 --- a/mlir/test/Integration/GPU/CUDA/async.mlir +++ b/mlir/test/Integration/GPU/CUDA/async.mlir @@ -8,8 +8,11 @@ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_async_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ -// RUN: --entry-point-result=void -O0 \ -// RUN: | FileCheck %s +// RUN: --entry-point-result=void -O0 +// RUN: +// This test is overly flaky right now and needs investigation, skipping FileCheck. +// See: https://github.com/llvm/llvm-project/issues/170833 +// DISABLED: | FileCheck %s func.func @main() { %c0 = arith.constant 0 : index diff --git a/mlir/test/Integration/GPU/CUDA/command-line-arg.mlir b/mlir/test/Integration/GPU/CUDA/command-line-arg.mlir index 34dde6e..77a4fa0 100644 --- a/mlir/test/Integration/GPU/CUDA/command-line-arg.mlir +++ b/mlir/test/Integration/GPU/CUDA/command-line-arg.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_80 ptxas-cmd-options='-v --register-usage-level=8'" -debug-only=serialize-to-binary \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_80 ptxas-cmd-options='-v --register-usage-level=8' allow-pattern-rollback=0" -debug-only=serialize-to-binary \ // RUN: 2>&1 | FileCheck %s func.func @host_function(%arg0 : f32, %arg1 : memref<?xf32>) { diff --git a/mlir/test/Integration/GPU/CUDA/concurrent-kernels.mlir b/mlir/test/Integration/GPU/CUDA/concurrent-kernels.mlir index ed01416..51f6e36 100644 --- a/mlir/test/Integration/GPU/CUDA/concurrent-kernels.mlir +++ b/mlir/test/Integration/GPU/CUDA/concurrent-kernels.mlir @@ -2,7 +2,7 @@ // increment a global atomic counter and wait for the counter to reach 2. // // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \ // RUN: | env CUDA_MODULE_LOADING=EAGER mlir-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Integration/GPU/CUDA/dump-ptx.mlir b/mlir/test/Integration/GPU/CUDA/dump-ptx.mlir index 27ec1ec..efffcaa 100644 --- a/mlir/test/Integration/GPU/CUDA/dump-ptx.mlir +++ b/mlir/test/Integration/GPU/CUDA/dump-ptx.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline -debug-only=serialize-to-isa \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="allow-pattern-rollback=0" -debug-only=serialize-to-isa \ // RUN: 2>&1 | FileCheck %s // CHECK-LABEL: Generated by LLVM NVPTX Back-End diff --git a/mlir/test/Integration/GPU/CUDA/dump-sass.mlir b/mlir/test/Integration/GPU/CUDA/dump-sass.mlir index d32f5ef..f810678 100644 --- a/mlir/test/Integration/GPU/CUDA/dump-sass.mlir +++ b/mlir/test/Integration/GPU/CUDA/dump-sass.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline -debug-only=dump-sass \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="allow-pattern-rollback=0" -debug-only=dump-sass \ // RUN: 2>&1 | FileCheck %s // CHECK: MOV diff --git a/mlir/test/Integration/GPU/CUDA/gpu-to-cubin.mlir b/mlir/test/Integration/GPU/CUDA/gpu-to-cubin.mlir index 07f3218..fe3c2b1 100644 --- a/mlir/test/Integration/GPU/CUDA/gpu-to-cubin.mlir +++ b/mlir/test/Integration/GPU/CUDA/gpu-to-cubin.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \ // RUN: | mlir-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Integration/GPU/CUDA/multiple-all-reduce.mlir b/mlir/test/Integration/GPU/CUDA/multiple-all-reduce.mlir index b2ac90a..f8f1aa8 100644 --- a/mlir/test/Integration/GPU/CUDA/multiple-all-reduce.mlir +++ b/mlir/test/Integration/GPU/CUDA/multiple-all-reduce.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \ // RUN: | mlir-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Integration/GPU/CUDA/printf.mlir b/mlir/test/Integration/GPU/CUDA/printf.mlir index fd664f2..ef11676 100644 --- a/mlir/test/Integration/GPU/CUDA/printf.mlir +++ b/mlir/test/Integration/GPU/CUDA/printf.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \ // RUN: | mlir-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Integration/GPU/CUDA/shuffle.mlir b/mlir/test/Integration/GPU/CUDA/shuffle.mlir index a6207d6..a4be5223 100644 --- a/mlir/test/Integration/GPU/CUDA/shuffle.mlir +++ b/mlir/test/Integration/GPU/CUDA/shuffle.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \ // RUN: | mlir-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Integration/GPU/CUDA/two-modules.mlir b/mlir/test/Integration/GPU/CUDA/two-modules.mlir index c3cee2f..3490003 100644 --- a/mlir/test/Integration/GPU/CUDA/two-modules.mlir +++ b/mlir/test/Integration/GPU/CUDA/two-modules.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s \ -// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format allow-pattern-rollback=0" \ // RUN: | mlir-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir index b98e8b0..c634444 100644 --- a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir +++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir @@ -184,3 +184,19 @@ func.func @propagate_from_block_to_iterarg(%arg0: index, %arg1: i1) { } return } + +// CHECK-LABEL: func @multiple_loop_ivs +func.func @multiple_loop_ivs(%arg0: memref<?x64xi32>) { + %ub1 = test.with_bounds { umin = 1 : index, umax = 32 : index, + smin = 1 : index, smax = 32 : index } : index + %c0_i32 = arith.constant 0 : i32 + // CHECK: scf.forall + scf.forall (%arg1, %arg2) in (%ub1, 64) { + // CHECK: test.reflect_bounds {smax = 31 : index, smin = 0 : index, umax = 31 : index, umin = 0 : index} + %1 = test.reflect_bounds %arg1 : index + // CHECK-NEXT: test.reflect_bounds {smax = 63 : index, smin = 0 : index, umax = 63 : index, umin = 0 : index} + %2 = test.reflect_bounds %arg2 : index + memref.store %c0_i32, %arg0[%1, %2] : memref<?x64xi32> + } + return +} diff --git a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir index 4fa7406..624e099 100644 --- a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir +++ b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-opt %s -resolve-shaped-type-result-dims -split-input-file | FileCheck %s +// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(resolve-shaped-type-result-dims{error-on-pattern-iteration-limit=false}))" -split-input-file | FileCheck %s +// See %test_unreifiable_result_shape below for why `error-on-partition-iteration-limit` is set to false. func.func @result_shape(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>) -> (index, index, index, index, index) { @@ -27,12 +28,14 @@ func.func @result_shape(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>) // ----- -func.func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>) +// Test result shape reification for an operation that implements only +// `reifyResultShapes` method of the `InferShapedTypeOpInterface`. +func.func @reify_shaped_type_using_reify_result_shapes(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>) -> (index, index, index, index, index) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index - %0:2 = "test.op_with_result_shape_per_dim_interface"(%arg0, %arg1) + %0:2 = "test.reify_shaped_type_using_reify_result_shapes"(%arg0, %arg1) : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>) %1 = tensor.dim %0#0, %c0 : tensor<?x5xf32> %2 = tensor.dim %0#0, %c1 : tensor<?x5xf32> @@ -41,7 +44,7 @@ func.func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf3 %5 = tensor.dim %0#1, %c2 : tensor<2x3x?xf32> return %1, %2, %3, %4, %5 : index, index, index, index, index } -// CHECK-LABEL: func @result_shape_per_dim( +// CHECK-LABEL: func @reify_shaped_type_using_reify_result_shapes( // CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32> // CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>) // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index @@ -51,3 +54,127 @@ func.func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf3 // CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]] // CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]] // CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]] + +// ----- + +// Test result shape reification for an operation that implements only +// `reifyShapeOfResult` method of the `InferShapedTypeOpInterface`. +func.func @reify_shaped_type_using_reify_shape_of_result(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>) + -> (index, index, index, index, index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %0:2 = "test.reify_shaped_type_using_reify_result_shapes"(%arg0, %arg1) + : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>) + %1 = tensor.dim %0#0, %c0 : tensor<?x5xf32> + %2 = tensor.dim %0#0, %c1 : tensor<?x5xf32> + %3 = tensor.dim %0#1, %c0 : tensor<2x3x?xf32> + %4 = tensor.dim %0#1, %c1 : tensor<2x3x?xf32> + %5 = tensor.dim %0#1, %c2 : tensor<2x3x?xf32> + return %1, %2, %3, %4, %5 : index, index, index, index, index +} +// CHECK-LABEL: func @reify_shaped_type_using_reify_shape_of_result( +// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32> +// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>) +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]] +// CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]] + +// ----- + +// Test result shape reification for an operation that implements only +// `reifyDimOfResult` method of the `InferShapedTypeOpInterface`. +func.func @reify_shaped_type_using_reify_dim_of_result(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>) + -> (index, index, index, index, index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %0:2 = "test.reify_shaped_type_using_reify_result_shapes"(%arg0, %arg1) + : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>) + %1 = tensor.dim %0#0, %c0 : tensor<?x5xf32> + %2 = tensor.dim %0#0, %c1 : tensor<?x5xf32> + %3 = tensor.dim %0#1, %c0 : tensor<2x3x?xf32> + %4 = tensor.dim %0#1, %c1 : tensor<2x3x?xf32> + %5 = tensor.dim %0#1, %c2 : tensor<2x3x?xf32> + return %1, %2, %3, %4, %5 : index, index, index, index, index +} +// CHECK-LABEL: func @reify_shaped_type_using_reify_dim_of_result( +// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32> +// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>) +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]] +// CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]] + +// ----- + +// This tests also indicates a problem with the approach of just using `reifyShapes` +// without being specific about {result, dim} that needs to be resolved. The +// `reifyShapes` implementations introduces `dim` operations that are effectively +// dead, but it creates an infinite loop on pattern application (which eventually +// bails on hitting the iteration limit). This is the pitfall of this legacy +// mechanism. + +func.func @test_unreifiable_result_shapes(%arg0 : tensor<?x?xf32>) + -> (index, index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = "test.unreifiable_result_shapes"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32> + %d0 = tensor.dim %0, %c0 : tensor<?x?xf32> + %d1 = tensor.dim %0, %c1 : tensor<?x?xf32> + return %d0, %d1 : index, index +} +// CHECK-LABEL: func @test_unreifiable_result_shapes( +// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>) +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[OP:.+]] = "test.unreifiable_result_shapes"(%[[ARG0]]) +// CHECK: %[[D1:.+]] = tensor.dim %[[OP]], %[[C1]] +// CHECK: return %[[D0]], %[[D1]] +// ----- + +func.func @test_unreifiable_result_shape(%arg0 : tensor<?x?xf32>) + -> (index, index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = "test.unreifiable_result_shape"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32> + %d0 = tensor.dim %0, %c0 : tensor<?x?xf32> + %d1 = tensor.dim %0, %c1 : tensor<?x?xf32> + return %d0, %d1 : index, index +} +// CHECK-LABEL: func @test_unreifiable_result_shape( +// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>) +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[OP:.+]] = "test.unreifiable_result_shape"(%[[ARG0]]) +// CHECK: %[[D1:.+]] = tensor.dim %[[OP]], %[[C1]] +// CHECK: return %[[D0]], %[[D1]] + +// ----- + +func.func @test_unreifiable_dim_of_result_shape(%arg0 : tensor<?x?xf32>) + -> (index, index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = "test.unreifiable_dim_of_result_shape"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32> + %d0 = tensor.dim %0, %c0 : tensor<?x?xf32> + %d1 = tensor.dim %0, %c1 : tensor<?x?xf32> + return %d0, %d1 : index, index +} +// CHECK-LABEL: func @test_unreifiable_dim_of_result_shape( +// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>) +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[OP:.+]] = "test.unreifiable_dim_of_result_shape"(%[[ARG0]]) +// CHECK: %[[D1:.+]] = tensor.dim %[[OP]], %[[C1]] +// CHECK: return %[[D0]], %[[D1]] diff --git a/mlir/test/Interfaces/TilingInterface/query-fusability.mlir b/mlir/test/Interfaces/TilingInterface/query-fusability.mlir new file mode 100644 index 0000000..d7b0528 --- /dev/null +++ b/mlir/test/Interfaces/TilingInterface/query-fusability.mlir @@ -0,0 +1,70 @@ +// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics + +func.func @fusable_with_matching_offsets(%arg0: tensor<10x20xf32>, %arg1: tensor<10x20xf32>, %dest: tensor<100x200xf32>) -> tensor<100x200xf32> { + %c0 = arith.constant 0 : index + %c10 = arith.constant 10 : index + %c20 = arith.constant 20 : index + + %slice0 = tensor.insert_slice %arg0 into %dest[%c0, %c0] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32> + %slice1 = tensor.insert_slice %arg1 into %dest[%c0, %c0] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32> + + // expected-remark @+1 {{can be fused with producer tensor.insert_slice ops}} + %result = linalg.add ins(%slice0, %slice1 : tensor<100x200xf32>, tensor<100x200xf32>) + outs(%dest : tensor<100x200xf32>) -> tensor<100x200xf32> + + return %result : tensor<100x200xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg: !transform.any_op) { + %add = transform.structured.match ops{["linalg.add"]} in %arg : (!transform.any_op) -> !transform.any_op + transform.test.query_producer_fusability %add : !transform.any_op + transform.yield + } +} + +// ----- + +func.func @not_fusable_with_different_offsets(%arg0: tensor<10x20xf32>, %arg1: tensor<10x20xf32>, %dest: tensor<100x200xf32>) -> tensor<100x200xf32> { + %c0 = arith.constant 0 : index + %c10 = arith.constant 10 : index + %c20 = arith.constant 20 : index + + %slice0 = tensor.insert_slice %arg0 into %dest[%c0, %c0] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32> + %slice1 = tensor.insert_slice %arg1 into %dest[%c10, %c20] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32> + + // expected-remark @+1 {{cannot be fused with producer tensor.insert_slice ops}} + %result = linalg.add ins(%slice0, %slice1 : tensor<100x200xf32>, tensor<100x200xf32>) + outs(%dest : tensor<100x200xf32>) -> tensor<100x200xf32> + + return %result : tensor<100x200xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg: !transform.any_op) { + %add = transform.structured.match ops{["linalg.add"]} in %arg : (!transform.any_op) -> !transform.any_op + transform.test.query_producer_fusability %add : !transform.any_op + transform.yield + } +} + +// ----- + +func.func @fusable_with_consumer_extract_slice(%arg0: tensor<100x200xf32>, %arg1: tensor<100x200xf32>, %dest: tensor<100x200xf32>) -> tensor<10x20xf32> { + // expected-remark @+1 {{can be fused with consumer tensor.extract_slice op}} + %add = linalg.add ins(%arg0, %arg1 : tensor<100x200xf32>, tensor<100x200xf32>) + outs(%dest : tensor<100x200xf32>) -> tensor<100x200xf32> + + %c0 = arith.constant 0 : index + %slice = tensor.extract_slice %add[%c0, %c0] [10, 20] [1, 1] : tensor<100x200xf32> to tensor<10x20xf32> + + return %slice : tensor<10x20xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg: !transform.any_op) { + %add = transform.structured.match ops{["linalg.add"]} in %arg : (!transform.any_op) -> !transform.any_op + transform.test.query_consumer_fusability %add : !transform.any_op + transform.yield + } +} diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir new file mode 100644 index 0000000..62dd7fa --- /dev/null +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir @@ -0,0 +1,1156 @@ +// RUN: mlir-opt --transform-interpreter --cse --split-input-file --verify-diagnostics %s | FileCheck %s + +#map = affine_map<(d0) -> (d0)> +module { + func.func @fuse_tileable_consumer_scf_for(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) { + %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32> + %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) { + ^bb0(%in: f32, %in_16: f32, %out: f32): + %13 = arith.mulf %in, %in_16 : f32 + %14 = arith.addf %out, %13 : f32 + linalg.yield %14 : f32 + } -> tensor<32xf32> + %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32> + scf.yield %arg5, %4 : tensor<64xf32>, tensor<64xf32> + } + %in_operand_2 = tensor.empty() : tensor<64xf32> + %out_operand_3 = tensor.empty() : tensor<64xf32> + %2 = linalg.add ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3 : tensor<64xf32>) -> tensor<64xf32> + return %2 : tensor<64xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %loop = transform.structured.match ops{["scf.for"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer_using_slice %yield in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: func.func @fuse_tileable_consumer_scf_for( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>) +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %0 = tensor.empty() : tensor<64xf32> +// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]] +// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %0) +// CHECK-SAME: { +// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1] +// CHECK: %[[MAT_OUT:.*]] = linalg.generic +// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>) +// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1] +// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1] +// CHECK: %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1] +// CHECK: %[[ELEM_OUT:.*]] = linalg.add +// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] : +// CHECK-SAME: outs(%[[SLICE_OUT]] : +// CHECK: %[[INSERT_ELEM:.*]] = tensor.insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1] +// CHECK: scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM]] : +// CHECK: } +// CHECK: return %[[FINAL_RESULT]]#2 : + +// ----- + +module { + func.func @fuse_tileable_consumer_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) { + %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32> + %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32> + %3 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32> + tensor.parallel_insert_slice %extracted_slice_1 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32> + } + } + %in_operand_2 = tensor.empty() : tensor<64x64xf32> + %out_operand_3 = tensor.empty() : tensor<64x64xf32> + %2 = linalg.add ins(%1#1, %in_operand_2 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%out_operand_3 : tensor<64x64xf32>) -> tensor<64x64xf32> + return %2 : tensor<64x64xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %first_slice_op, %second_slice_op = transform.split_handle %slice_ops + : (!transform.any_op) + -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer_using_slice %first_slice_op in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: func.func @fuse_tileable_consumer_scf_forall( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32>) +// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<64x64xf32> +// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2) +// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %[[OUT_INIT]]) +// CHECK-SAME: { +// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[SECOND_ARG_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[MAT_OUT:.*]] = linalg.matmul +// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : +// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[ELEM_OUT:.*]] = linalg.add +// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] : +// CHECK-SAME: outs(%[[SLICE_OUT]] : +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: } +// CHECK: } +// CHECK: return %[[FINAL_RESULT]]#2 : + +// ----- + +#map = affine_map<(d0) -> (d0)> +module { + func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) { + %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32> + %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) { + ^bb0(%in: f32, %in_16: f32, %out: f32): + %13 = arith.mulf %in, %in_16 : f32 + %14 = arith.addf %out, %13 : f32 + linalg.yield %14 : f32 + } -> tensor<32xf32> + %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32> + scf.yield %arg5, %4 : tensor<64xf32>, tensor<64xf32> + } + %in_operand_2 = tensor.empty() : tensor<64xf32> + %out_operand_3 = tensor.empty() : tensor<64xf32> + %out_operand_4 = tensor.empty() : tensor<64xf32> + %2:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3, %out_operand_4 : tensor<64xf32>, tensor<64xf32>) { + ^bb0(%in: f32, %in_16: f32, %out_0: f32, %out_1: f32): + %13 = arith.mulf %in, %in_16 : f32 + %14 = arith.subf %out_0, %13 : f32 + %15 = arith.addf %out_1, %in : f32 + linalg.yield %14, %15 : f32, f32 + } -> (tensor<64xf32>, tensor<64xf32>) + return %2#1 : tensor<64xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.for"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer_using_slice %yield in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>) +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %0 = tensor.empty() : tensor<64xf32> +// CHECK: %[[FINAL_RESULT:.*]]:4 = scf.for %[[IV:.*]] = %[[C0]] +// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG_0:.*]] = %0, %[[ELEM_OUT_ARG_1:.*]] = %0) +// CHECK-SAME: { +// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1] +// CHECK: %[[MAT_OUT:.*]] = linalg.generic +// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>) +// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1] +// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1] +// CHECK: %[[SLICE_OUT_0:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1] +// CHECK: %[[SLICE_OUT_1:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1] +// CHECK: %[[ELEM_OUT:.*]]:2 = linalg.generic +// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] : +// CHECK-SAME: outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] : +// CHECK: %[[INSERT_ELEM_0:.*]] = tensor.insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1] +// CHECK: %[[INSERT_ELEM_1:.*]] = tensor.insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1] +// CHECK: scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM_0]], %[[INSERT_ELEM_1]] : +// CHECK: } +// CHECK: return %[[FINAL_RESULT]]#3 : + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64x32xf32>) -> (tensor<64x64xf32>, tensor<2048xf32>) { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %0:2 = scf.forall (%arg4, %arg5) in (2, 2) shared_outs(%arg6 = %arg3, %arg7 = %arg2) -> (tensor<64x32xf32>, tensor<64x64xf32>) { + %extracted_slice = tensor.extract_slice %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32> + %extracted_slice_0 = tensor.extract_slice %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32> + %6 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %6 into %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32> + tensor.parallel_insert_slice %extracted_slice_0 into %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32> + } + } + %1 = tensor.empty() : tensor<64x64xf32> + %2 = tensor.empty() : tensor<64x64xf32> + %3 = tensor.empty() : tensor<64x64xf32> + %4:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%0#1, %1 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%2, %3 : tensor<64x64xf32>, tensor<64x64xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32, %out_1: f32): + %6 = arith.mulf %in, %in_0 : f32 + %7 = arith.subf %out, %6 : f32 + %8 = arith.addf %out_1, %in : f32 + linalg.yield %7, %8 : f32, f32 + } -> (tensor<64x64xf32>, tensor<64x64xf32>) + %5 = tensor.empty() : tensor<2048xf32> + %unpack = linalg.unpack %0#0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %5 : tensor<64x32xf32> -> tensor<2048xf32> + return %4#1, %unpack : tensor<64x64xf32>, tensor<2048xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %first_slice_op, %second_slice_op = transform.split_handle %slice_ops + : (!transform.any_op) + -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer_using_slice %first_slice_op in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32> +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: tensor<64x32xf32>) +// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<64x64xf32> +// CHECK: %[[FINAL_RESULT:.*]]:4 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2) +// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG3]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG_0:.*]] = %[[OUT_INIT]], %[[ELEM_OUT_ARG_1:.*]] = %[[OUT_INIT]]) +// CHECK-SAME: { +// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[SECOND_ARG_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[MAT_OUT:.*]] = linalg.matmul +// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : +// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[SLICE_OUT_0:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_0]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[SLICE_OUT_1:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_1]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[ELEM_OUT:.*]]:2 = linalg.generic +// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] : +// CHECK-SAME: outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] : +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: } +// CHECK: } +// CHECK: %[[UNPACK:.*]] = linalg.unpack %[[FINAL_RESULT]]#0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %{{.*}} : tensor<64x32xf32> -> tensor<2048xf32> +// CHECK: return %[[FINAL_RESULT]]#3, %[[UNPACK]] : + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @fuse_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2048xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) { + %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32> + %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) { + ^bb0(%in: f32, %in_16: f32, %out: f32): + %13 = arith.mulf %in, %in_16 : f32 + %14 = arith.addf %out, %13 : f32 + linalg.yield %14 : f32 + } -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32> + } + } + %output = tensor.empty() : tensor<2048xf32> + %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2048xf32> + return %unpack : tensor<2048xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer_using_slice %slice_op in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)> +// CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2048)> +// CHECK: func.func @fuse_unpack_consumer_into_scf_forall( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>) +// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<2048xf32> +// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) = (0, 0) to (64, 32) step (32, 32) +// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[UNPACK_OUT_ARG:.*]] = %[[OUT_INIT]]) +// CHECK-SAME: { +// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[GENERIC_OUT:.*]] = linalg.generic +// CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] : +// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]]) +// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]]) +// CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1] +// CHECK: %[[TILED_UNPACK_OUT:.*]] = linalg.unpack %[[GENERIC_OUT]] +// CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] +// CHECK-SAME: into %[[TILED_UNPACK_DEST]] +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1] +// CHECK: } +// CHECK: } +// CHECK: return %[[FINAL_RESULT]]#1 : + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @fuse_unaligned_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2047xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) { + %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32> + %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) { + ^bb0(%in: f32, %in_16: f32, %out: f32): + %13 = arith.mulf %in, %in_16 : f32 + %14 = arith.addf %out, %13 : f32 + linalg.yield %14 : f32 + } -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32> + } + } + %output = tensor.empty() : tensor<2047xf32> + %unpack = linalg.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2047xf32> + return %unpack : tensor<2047xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer_using_slice %slice_op in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)> +// CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2047)> +// CHECK: func.func @fuse_unaligned_unpack_consumer_into_scf_forall( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>) +// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<2047xf32> +// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) = (0, 0) to (64, 32) step (32, 32) +// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[UNPACK_OUT_ARG:.*]] = %[[OUT_INIT]]) +// CHECK-SAME: { +// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[GENERIC_OUT:.*]] = linalg.generic +// CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] : +// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]]) +// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]]) +// CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1] +// CHECK: %[[TILED_UNPACK_OUT:.*]] = linalg.unpack %[[GENERIC_OUT]] +// CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] +// CHECK-SAME: into %[[TILED_UNPACK_DEST]] +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1] +// CHECK: } +// CHECK: } +// CHECK: return %[[FINAL_RESULT]]#1 : + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @fuse_perfect_tiling_pack_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1 = scf.forall (%arg3, %arg4) in (2, 1) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) { + %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32> + %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) { + ^bb0(%in: f32, %in_16: f32, %out: f32): + %13 = arith.mulf %in, %in_16 : f32 + %14 = arith.addf %out, %13 : f32 + linalg.yield %14 : f32 + } -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32> + } + } + %output = tensor.empty() : tensor<4x32x16xf32> + %pack = linalg.pack %1 inner_dims_pos = [0] inner_tiles = [16] into %output : tensor<64x32xf32> -> tensor<4x32x16xf32> + return %pack : tensor<4x32x16xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer_using_slice %slice_op in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> +// CHECK: func.func @fuse_perfect_tiling_pack_consumer( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>) +// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<4x32x16xf32> +// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 1) +// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]]) +// CHECK-SAME: { +// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[GENERIC_OUT:.*]] = linalg.generic +// CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] : +// CHECK: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV1]]) +// CHECK: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1] +// CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[GENERIC_OUT]] +// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16] +// CHECK-SAME: into %[[TILED_PACK_DEST]] +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1] + +// ----- + +#map = affine_map<(d0) -> (-d0 + 4, 16)> +func.func @fuse_pack_consumer_if_single_iteration(%arg0: tensor<4x4xf32>) -> tensor<1x4x16x1xf32> { + %0 = tensor.empty() : tensor<1x4x16x1xf32> + %1 = tensor.empty() : tensor<4x4xf32> + %2 = scf.forall (%arg1) = (0) to (4) step (16) shared_outs(%arg2 = %1) -> (tensor<4x4xf32>) { + %3 = affine.min #map(%arg1) + %extracted_slice = tensor.extract_slice %arg0[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32> + %extracted_slice_0 = tensor.extract_slice %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32> + %4 = linalg.exp ins(%extracted_slice : tensor<?x4xf32>) outs(%extracted_slice_0 : tensor<?x4xf32>) -> tensor<?x4xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %4 into %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<?x4xf32> into tensor<4x4xf32> + } + } + %cst = arith.constant 0.000000e+00 : f32 + %pack = linalg.pack %2 padding_value(%cst : f32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1] into %0 : tensor<4x4xf32> -> tensor<1x4x16x1xf32> + return %pack : tensor<1x4x16x1xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (-d0 + 4, 16)> +// CHECK: func.func @fuse_pack_consumer_if_single_iteration( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-DAG: %[[PACK_INIT:.*]] = tensor.empty() : tensor<1x4x16x1xf32> +// CHECK-DAG: %[[ELEM_INIT:.*]] = tensor.empty() : tensor<4x4xf32> +// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (4) step (16) +// CHECK-SAME: shared_outs(%[[ELEM_OUT_ARG:.*]] = %[[ELEM_INIT]], %[[PACK_OUT_ARG:.*]] = %[[PACK_INIT]]) +// CHECK-DAG: %[[SIZE:.+]] = affine.min #[[MAP]](%[[IV]]) +// CHECK-DAG: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1] +// CHECK-DAG: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1] +// CHECK: %[[ELEM:.*]] = linalg.exp +// CHECK-SAME: ins(%[[ELEM_SRC]] +// CHECK-SAME: outs(%[[ELEM_DEST]] +// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [1, 4, 16, 1] [1, 1, 1, 1] +// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]] +// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32) +// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1] +// CHECK-SAME: into %[[TILED_PACK_DEST]] +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [1, 4, 16, 1] [1, 1, 1, 1] + +// ----- + +func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>, %arg2: tensor<2x64x16x1xf32>) -> tensor<2x64x16x1xf32> { + %0 = scf.forall (%arg3) = (0) to (32) step (16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) { + %src = tensor.extract_slice %arg0[0, %arg3] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> + %dest = tensor.extract_slice %arg4[0, %arg3] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> + %1 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %1 into %arg4[0, %arg3] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32> + } + } + %pack = linalg.pack %0 outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 1] into %arg2 : tensor<64x32xf32> -> tensor<2x64x16x1xf32> + return %pack : tensor<2x64x16x1xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> +// CHECK: func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] +// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16) +// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[ARG2]]) +// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1] +// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] +// CHECK: %[[ELEM:.*]] = linalg.exp +// CHECK-SAME: ins(%[[ELEM_SRC]] +// CHECK-SAME: outs(%[[ELEM_DEST]] +// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]]) +// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1] +// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]] +// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 1] +// CHECK-SAME: into %[[TILED_PACK_DEST]] +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1] + +// ----- + +// It is valid to fuse the pack op in perfect tiling scenario when the dimension +// is dynamic and padding is not needed. + +func.func @fuse_pack_consumer_with_no_pad_dynamic_dim(%arg0: tensor<64x?xf32>, %arg1: tensor<64x?xf32>, %1: tensor<64x?x16xf32>) -> tensor<64x?x16xf32> { + %c1 = arith.constant 1 : index + %d1 = tensor.dim %arg0, %c1 : tensor<64x?xf32> + %0 = scf.forall (%arg2) = (0) to (%d1) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x?xf32>) { + %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x?xf32> to tensor<64x16xf32> + %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x?xf32> to tensor<64x16xf32> + %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x?xf32> + } + } + %pack = linalg.pack %0 inner_dims_pos = [1] inner_tiles = [16] into %1 : tensor<64x?xf32> -> tensor<64x?x16xf32> + return %pack : tensor<64x?x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> +// CHECK: func.func @fuse_pack_consumer_with_no_pad_dynamic_dim( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] +// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (%{{.+}}) step (16) +// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[ARG2]]) +// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1] +// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] +// CHECK: %[[ELEM:.*]] = linalg.exp +// CHECK-SAME: ins(%[[ELEM_SRC]] +// CHECK-SAME: outs(%[[ELEM_DEST]] +// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]]) +// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0] [64, 1, 16] [1, 1, 1] +// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]] +// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [16] +// CHECK-SAME: into %[[TILED_PACK_DEST]] +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0] [64, 1, 16] [1, 1, 1] + +// ----- + +// It is valid to fuse the pack op with padding semantics if it is a perfect +// tiling case. + +func.func @fuse_pack_consumer_with_padding_semantics(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<22x2x3x16xf32> { + %0 = scf.forall (%arg2, %arg3) = (0, 0) to (64, 32) step (15, 16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) { + %size = affine.min affine_map<(d0) -> (-d0 + 64, 15)>(%arg2) + %src = tensor.extract_slice %arg0[%arg2, %arg3] [%size, 16] [1, 1] : tensor<64x32xf32> to tensor<?x16xf32> + %dest = tensor.extract_slice %arg4[%arg2, %arg3] [%size, 16] [1, 1] : tensor<64x32xf32> to tensor<?x16xf32> + %2 = linalg.exp ins(%src : tensor<?x16xf32>) outs(%dest : tensor<?x16xf32>) -> tensor<?x16xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %2 into %arg4[%arg2, %arg3] [%size, 16] [1, 1] : tensor<?x16xf32> into tensor<64x32xf32> + } + } + %1 = tensor.empty() : tensor<22x2x3x16xf32> + %cst = arith.constant 0.000000e+00 : f32 + %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<22x2x3x16xf32> + return %pack : tensor<22x2x3x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (-d0 + 64, 15)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (d0 floordiv 3)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0) -> (d0 ceildiv 3)> +// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0) -> (d0 floordiv 16)> +// CHECK: func.func @fuse_pack_consumer_with_padding_semantics( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<22x2x3x16xf32> +// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %{{.*}}:2 = scf.forall (%[[I:.*]], %[[J:.*]]) = (0, 0) to (64, 32) step (15, 16) +// CHECK-SAME: shared_outs(%[[ELEM_OUT:.*]] = %[[ARG1]], %[[PACK_OUT:.*]] = %[[OUT_INIT]]) +// CHECK: %[[SIZE:.+]] = affine.min #[[MAP0]](%[[I]]) +// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]] +// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1] +// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT]] +// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1] +// CHECK: %[[ELEM:.*]] = linalg.exp +// CHECK-SAME: ins(%[[ELEM_SRC]] +// CHECK-SAME: outs(%[[ELEM_DEST]] +// CHECK-DAG: %[[D0_OFFSET:.*]] = affine.apply #[[MAP1]](%[[I]]) +// CHECK-DAG: %[[D0_SIZE:.*]] = affine.apply #[[MAP2]](%[[SIZE]]) +// CHECK-DAG: %[[D1_OFFSET:.*]] = affine.apply #[[MAP3]](%[[J]]) +// CHECK-DAG: %[[PACK_INIT:.*]] = tensor.extract_slice %[[PACK_OUT]] +// CHECK-SAME: [%[[D0_OFFSET]], %[[D1_OFFSET]], 0, 0] [%[[D0_SIZE]], 1, 3, 16] [1, 1, 1, 1] +// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]] +// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32) +// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16] +// CHECK-SAME: into %[[TILED_PACK_DEST]] +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT]] +// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT]] +// CHECK-SAME: [%[[D0_OFFSET]], %[[D1_OFFSET]], 0, 0] [%[[D0_SIZE]], 1, 3, 16] [1, 1, 1, 1] + +// ----- + +// Imperfect tiling is not supported in pack op consumer fusion. + +#map = affine_map<(d0) -> (d0 * 5)> +#map1 = affine_map<(d0) -> (d0)> +func.func @nofuse_pack_with_imperfect_tiling(%arg0: tensor<30xf32>) -> tensor<5x6xf32> { + %0 = tensor.empty() : tensor<30xf32> + %1 = scf.forall (%arg1) in (6) shared_outs(%arg2 = %0) -> (tensor<30xf32>) { + %3 = affine.apply #map(%arg1) + %extracted_slice = tensor.extract_slice %arg0[%3] [5] [1] : tensor<30xf32> to tensor<5xf32> + %extracted_slice_0 = tensor.extract_slice %arg2[%3] [5] [1] : tensor<30xf32> to tensor<5xf32> + %4 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel"]} ins(%extracted_slice : tensor<5xf32>) outs(%extracted_slice_0 : tensor<5xf32>) { + ^bb0(%in: f32, %out: f32): + %5 = arith.addf %in, %in : f32 + linalg.yield %5 : f32 + } -> tensor<5xf32> + scf.forall.in_parallel { + // expected-error @below {{failed to fuse consumer of slice}} + tensor.parallel_insert_slice %4 into %arg2[%3] [5] [1] : tensor<5xf32> into tensor<30xf32> + } + } + %2 = tensor.empty() : tensor<5x6xf32> + %pack = linalg.pack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [6] into %2 : tensor<30xf32> -> tensor<5x6xf32> + return %pack : tensor<5x6xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +module { + func.func @fuse_add_multiple_tilable_consumers(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %cst = arith.constant 0.000000e+00 : f32 + %dest0 = tensor.empty() : tensor<256x256xf32> + %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) { + %extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> + %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> + %extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> + %3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : tensor<64x256xf32>) -> tensor<64x256xf32> + %insert_slice = tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32> + scf.yield %insert_slice : tensor<256x256xf32> + } + %4 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32> + %5 = linalg.exp ins(%1 : tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32> + return %4, %5 : tensor<256x256xf32>, tensor<256x256xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.for"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer_using_slice %slice_op in (%loop) num_consumer_to_fuse = 2 + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: func.func @fuse_add_multiple_tilable_consumers( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x256xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<256x256xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32> +// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32> +// CHECK: %[[LOOP_RESULT:.*]]:3 = scf.for %[[IV1:.*]] = %[[C0]] +// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]]) +// CHECK-SAME: { +// CHECK: %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[ADD_INS0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[ADD_INS1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[TILED_ADD_OUT:.*]] = linalg.add +// CHECK-SAME: ins(%[[ADD_INS0_SLICE]], %[[ADD_INS1_SLICE]] : +// CHECK-SAME: outs(%[[ADD_OUT_SLICE]] : +// CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[EXP_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[TILED_EXP_OUT:.*]] = linalg.exp +// CHECK-SAME: ins(%[[TILED_ADD_OUT]] : +// CHECK-SAME: outs(%[[EXP_OUT_SLICE]] : +// CHECK: %[[MUL_INS2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[MUL_OUT_SLICE:.*]] = tensor.extract_slice %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[TILED_MUL_OUT:.*]] = linalg.mul +// CHECK-SAME: ins(%[[TILED_ADD_OUT]], %[[MUL_INS2_SLICE]] : +// CHECK-SAME: outs(%[[MUL_OUT_SLICE]] : +// CHECK: %[[INSERT_EXP:.*]] = tensor.insert_slice %[[TILED_EXP_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[INSERT_MUL:.*]] = tensor.insert_slice %[[TILED_MUL_OUT]] into %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: scf.yield %[[INSERT_ADD]], %[[INSERT_EXP]], %[[INSERT_MUL]] : +// CHECK: } +// CHECK: return %[[LOOP_RESULT]]#2, %[[LOOP_RESULT]]#1 : + +// ----- + +module { + func.func @no_fuse_only_dps_consumer(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<258x258xf32>) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %cst = arith.constant 0.000000e+00 : f32 + %dest0 = tensor.empty() : tensor<256x256xf32> + %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) { + %extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> + %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> + %extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> + %3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : tensor<64x256xf32>) -> tensor<64x256xf32> + %insert_slice = tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32> + scf.yield %insert_slice : tensor<256x256xf32> + } + %dest1 = tensor.empty() : tensor<258x258xf32> + %4 = tensor.insert_slice %1 into %dest1[0, 0] [256, 256] [1, 1] : tensor<256x256xf32> into tensor<258x258xf32> + %5 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32> + return %5, %4 : tensor<256x256xf32>, tensor<258x258xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %slice_ops = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %slice_op, %other_slice = transform.split_handle %slice_ops : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer_using_slice %slice_op in (%loop) num_consumer_to_fuse = 1 + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: func.func @no_fuse_only_dps_consumer( +// CHECK: %[[LOOP_RESULT:.*]]:2 = scf.for {{.*}} { +// CHECK: linalg.add +// CHECK: linalg.mul +// CHECK: scf.yield +// CHECK: } +// CHECK: %[[RES_SLICE:.+]] = tensor.insert_slice +// CHECK: return %[[LOOP_RESULT]]#1, %[[RES_SLICE]] + +// ----- + +#map = affine_map<(d0, d1, d2) -> (d0, d1)> +#map1 = affine_map<(d0, d1, d2) -> (d2)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +module { + func.func @fuse_with_tilable_consumer_with_projected_permutations(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<24xf32>) -> tensor<256x256x24xf32> { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %0 = tensor.empty() : tensor<256x256xf32> + %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %0) -> (tensor<256x256xf32>) { + %extracted_slice = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> + %extracted_slice_0 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> + %extracted_slice_1 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> + %4 = linalg.add ins(%extracted_slice_0, %extracted_slice_1 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice : tensor<64x256xf32>) -> tensor<64x256xf32> + %inserted_slice = tensor.insert_slice %4 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32> + scf.yield %inserted_slice : tensor<256x256xf32> + } + %2 = tensor.empty() : tensor<256x256x24xf32> + %3 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1, %arg2 : tensor<256x256xf32>, tensor<24xf32>) outs(%2 : tensor<256x256x24xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %4 = arith.addf %in, %in_0 : f32 + linalg.yield %4 : f32 + } -> tensor<256x256x24xf32> + return %3 : tensor<256x256x24xf32> + } +} + +// CHECK: func.func @fuse_with_tilable_consumer_with_projected_permutations(%[[VAL_0:.*]]: tensor<256x256xf32>, %[[VAL_1:.*]]: tensor<256x256xf32>, %[[VAL_2:.*]]: tensor<24xf32>) -> tensor<256x256x24xf32> { +// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_4:.*]] = arith.constant 64 : index +// CHECK: %[[VAL_5:.*]] = arith.constant 256 : index +// CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<256x256xf32> +// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<256x256x24xf32> +// CHECK: %[[VAL_8:.*]]:2 = scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_10:.*]] = %[[VAL_6]], %[[VAL_11:.*]] = %[[VAL_7]]) -> (tensor<256x256xf32>, tensor<256x256x24xf32>) { +// CHECK: %[[VAL_12:.*]] = tensor.extract_slice %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] +// CHECK: %[[VAL_13:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] +// CHECK: %[[VAL_14:.*]] = tensor.extract_slice %[[VAL_1]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] +// CHECK: %[[VAL_15:.*]] = linalg.add ins(%[[VAL_13]], %[[VAL_14]] : tensor<64x256xf32>, tensor<64x256xf32>) outs(%[[VAL_12]] : tensor<64x256xf32>) -> tensor<64x256xf32> +// CHECK: %[[VAL_16:.*]] = tensor.insert_slice %[[VAL_15]] into %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] +// CHECK: %[[VAL_17:.*]] = tensor.extract_slice %[[VAL_2]][0] [24] [1] : tensor<24xf32> to tensor<24xf32> +// CHECK: %[[VAL_18:.*]] = tensor.extract_slice %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1] +// CHECK: %[[VAL_19:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_15]], %[[VAL_17]] : tensor<64x256xf32>, tensor<24xf32>) outs(%[[VAL_18]] : tensor<64x256x24xf32>) { +// CHECK: ^bb0(%[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32, %[[VAL_22:.*]]: f32): +// CHECK: %[[VAL_23:.*]] = arith.addf %[[VAL_20]], %[[VAL_21]] : f32 +// CHECK: linalg.yield %[[VAL_23]] : f32 +// CHECK: } -> tensor<64x256x24xf32> +// CHECK: %[[VAL_24:.*]] = tensor.insert_slice %[[VAL_25:.*]] into %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1] +// CHECK: scf.yield %[[VAL_16]], %[[VAL_24]] : tensor<256x256xf32>, tensor<256x256x24xf32> +// CHECK: } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.for"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer_using_slice %slice_op in (%loop) num_consumer_to_fuse = 1 + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +func.func @multi_slice_fusion1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32> + %dim1 = tensor.dim %arg0, %c1 : tensor<?x?xf32> + %loop:2 = scf.forall (%iv0) = (%c0) to (%dim0) step (%arg3) shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) { + %tilesize = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3] + %arg0_slice = tensor.extract_slice %arg0[%iv0, 0] [%tilesize, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32> + %init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32> + %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32> + %generic:2 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice, %init1_slice : tensor<?xf32>, tensor<?xf32>) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): + %0 = arith.mulf %b0, %b1 : f32 + %1 = arith.addf %b0, %b2 : f32 + linalg.yield %0, %1 : f32, f32 + } -> (tensor<?xf32>, tensor<?xf32>) + scf.forall.in_parallel { + tensor.parallel_insert_slice %generic#0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32> + tensor.parallel_insert_slice %generic#1 into %init1[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32> + } + } + %empty = tensor.empty(%dim0) : tensor<?xf32> + %result = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"]} + ins(%loop#0, %loop#1 : tensor<?xf32>, tensor<?xf32>) outs(%empty : tensor<?xf32>) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): + %0 = arith.addf %b0, %b1 : f32 + linalg.yield %0 : f32 + } -> tensor<?xf32> + return %result : tensor<?xf32> +} +// CHECK-LABEL: func @multi_slice_fusion1( +// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32> +// CHECK: %[[C0:.+]] = arith.constant 0 +// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]]) +// CHECK: %[[RESULT:.+]]:3 = scf.forall (%[[IV:.+]]) = +// CHECK-SAME: , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]]) +// CHECK: %[[TILESIZE:.+]] = affine.min +// CHECK-DAG: %[[GENERIC:.+]]:2 = linalg.generic +// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]] [%[[TILESIZE]]] +// CHECK: %[[FUSED:.+]] = linalg.generic +// CHECK-SAME: ins(%[[GENERIC]]#0, %[[GENERIC]]#1 : +// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]] +// CHECK: return %[[RESULT]]#2 + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer_using_slice %yield0, %yield1 in (%loop) + : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// Check that when the given operand tiles are inconsistent, tiling fails. + +func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32> + %dim1 = tensor.dim %arg0, %c1 : tensor<?x?xf32> + %loop:2 = scf.forall (%iv0) = (%c0) to (%dim0) step (%arg3) shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) { + %tilesize = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3] + %arg0_slice = tensor.extract_slice %arg0[%iv0, 0] [%tilesize, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32> + %init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32> + %generic0 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice : tensor<?xf32>) { + ^bb0(%b0 : f32, %b1 : f32): + %0 = arith.mulf %b0, %b1 : f32 + linalg.yield %0 : f32 + } -> tensor<?xf32> + %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32> + %generic1 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%arg0_slice : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) { + ^bb0(%b0 : f32, %b1 : f32): + %0 = arith.addf %b0, %b1 : f32 + linalg.yield %0: f32 + } -> tensor<?xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %generic0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32> + tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32> + } + } + %empty = tensor.empty(%dim0) : tensor<?xf32> + %result = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"]} + ins(%loop#0, %loop#1 : tensor<?xf32>, tensor<?xf32>) outs(%empty : tensor<?xf32>) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): + %0 = arith.addf %b0, %b1 : f32 + linalg.yield %0 : f32 + } -> tensor<?xf32> + return %result : tensor<?xf32> +} +// CHECK-LABEL: func @multi_slice_fusion2( +// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32> +// CHECK: %[[C0:.+]] = arith.constant 0 +// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]]) +// CHECK: %[[RESULT:.+]]:3 = scf.forall (%[[IV:.+]]) = +// CHECK-SAME: , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]]) +// CHECK: %[[TILESIZE:.+]] = affine.min +// CHECK: %[[GENERIC0:.+]] = linalg.generic +// CHECK: %[[GENERIC1:.+]] = linalg.generic +// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]] [%[[TILESIZE]]] +// CHECK: %[[FUSED:.+]] = linalg.generic +// CHECK-SAME: ins(%[[GENERIC0]], %[[GENERIC1]] : +// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]] +// CHECK: return %[[RESULT]]#2 + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer_using_slice %yield0, %yield1 in (%loop) + : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +func.func @multi_slice_fusion_with_broadcast(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?xf32>, + %arg3 : index, %arg4 : index) -> tensor<?x?xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32> + %dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32> + %dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32> + %loop:2 = scf.forall (%iv0, %iv1) = (%c0, %c0) to (%dim0, %dim1) step (%arg3, %arg4) + shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?x?xf32>, tensor<?xf32>) { + %tilesize0 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3] + %tilesize1 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv1)[%dim1, %arg4] + %arg0_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [%tilesize0, %tilesize1, %dim2] [1, 1, 1] + : tensor<?x?x?xf32> to tensor<?x?x?xf32> + %init0_slice = tensor.extract_slice %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1] + : tensor<?x?xf32> to tensor<?x?xf32> + %generic0 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init0_slice : tensor<?x?xf32>) { + ^bb0(%b0 : f32, %b1 : f32): + %0 = arith.mulf %b0, %b1 : f32 + linalg.yield %0 : f32 + } -> tensor<?x?xf32> + %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize0] [1] : tensor<?xf32> to tensor<?xf32> + %generic1 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%generic0 : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) { + ^bb0(%b0 : f32, %b1 : f32): + %0 = arith.addf %b0, %b1 : f32 + linalg.yield %0: f32 + } -> tensor<?xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1] + : tensor<?x?xf32> into tensor<?x?xf32> + tensor.parallel_insert_slice %generic1 into %init1[%iv0] [%tilesize0] [1] : tensor<?xf32> into tensor<?xf32> + } + } + %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32> + %result = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%loop#0, %loop#1 : tensor<?x?xf32>, tensor<?xf32>) outs(%empty : tensor<?x?xf32>) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): + %0 = arith.addf %b0, %b1 : f32 + linalg.yield %0 : f32 + } -> tensor<?x?xf32> + return %result : tensor<?x?xf32> +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer_using_slice %yield0, %yield1 in (%loop) + : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK-LABEL: func @multi_slice_fusion_with_broadcast( +// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32> +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 +// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0]], %[[DIM1]]) +// CHECK: %[[RESULT:.+]]:3 = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) = +// CHECK-SAME: , %[[INIT:[a-zA-Z0-9]+]] = %[[EMPTY]]) +// CHECK-DAG: %[[TILESIZE0:.+]] = affine.min {{.+}}(%[[IV0]]) +// CHECK-DAG: %[[TILESIZE1:.+]] = affine.min {{.+}}(%[[IV1]]) +// CHECK: %[[GENERIC0:.+]] = linalg.generic +// CHECK: %[[GENERIC1:.+]] = linalg.generic +// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE0]], %[[TILESIZE1]]] +// CHECK: %[[FUSED:.+]] = linalg.generic +// CHECK-SAME: ins(%[[GENERIC0]], %[[GENERIC1]] : +// CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV0]], %[[IV1]]] [%[[TILESIZE0]], %[[TILESIZE1]]] +// CHECK: return %[[RESULT]]#2 + +// ----- + +func.func @multi_slice_fusion_invalid(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>, + %arg3 : index, %arg4 : index) -> tensor<?x?xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %dim0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32> + %dim1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32> + %dim2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32> + %loop:2 = scf.forall (%iv0, %iv1) = (%c0, %c0) to (%dim0, %dim1) step (%arg3, %arg4) + shared_outs(%init0 = %arg1, %init1 = %arg2) -> (tensor<?x?xf32>, tensor<?x?xf32>) { + %tilesize0 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv0)[%dim0, %arg3] + %tilesize1 = affine.min affine_map<(d0)[s0, s1] -> (s1, s0 - d0)>(%iv1)[%dim1, %arg4] + %arg0_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [%tilesize0, %tilesize1, %dim2] [1, 1, 1] + : tensor<?x?x?xf32> to tensor<?x?x?xf32> + %init0_slice = tensor.extract_slice %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1] + : tensor<?x?xf32> to tensor<?x?xf32> + %generic0 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init0_slice : tensor<?x?xf32>) { + ^bb0(%b0 : f32, %b1 : f32): + %0 = arith.mulf %b0, %b1 : f32 + linalg.yield %0 : f32 + } -> tensor<?x?xf32> + %init1_slice = tensor.extract_slice %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1] + : tensor<?x?xf32> to tensor<?x?xf32> + %generic1 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%arg0_slice : tensor<?x?x?xf32>) outs(%init1_slice: tensor<?x?xf32>) { + ^bb0(%b0 : f32, %b1 : f32): + %0 = arith.addf %b0, %b1 : f32 + linalg.yield %0: f32 + } -> tensor<?x?xf32> + scf.forall.in_parallel { + // expected-error @below {{failed to fuse consumer of slice}} + tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1] + : tensor<?x?xf32> into tensor<?x?xf32> + tensor.parallel_insert_slice %generic1 into %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1] + : tensor<?x?xf32> into tensor<?x?xf32> + } + } + %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32> + %result = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%loop#0, %loop#1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): + %0 = arith.addf %b0, %b1 : f32 + linalg.yield %0 : f32 + } -> tensor<?x?xf32> + return %result : tensor<?x?xf32> +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer_using_slice %yield0, %yield1 in (%loop) + : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir index 7888462..0137e2a 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir @@ -1,8 +1,8 @@ -// RUN: mlir-opt --transform-interpreter --cse --split-input-file --verify-diagnostics %s | FileCheck %s +// RUN: mlir-opt --transform-interpreter --cse --split-input-file --verify-diagnostics --mlir-print-local-scope %s | FileCheck %s #map = affine_map<(d0) -> (d0)> module { - func.func @fuse_tileable_consumer_scf_for(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> { + func.func @fuse_tilable_consumer_scf_for(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> { %c4 = arith.constant 4 : index %c64 = arith.constant 64 : index %c0 = arith.constant 0 : index @@ -28,14 +28,14 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %loop = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 + %add = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %yield in (%loop) + %a, %new_loop = transform.test.fuse_consumer %add into (%loop) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } -// CHECK: func.func @fuse_tileable_consumer_scf_for( +// CHECK: func.func @fuse_tilable_consumer_scf_for( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32> // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>) @@ -60,8 +60,61 @@ module attributes {transform.with_named_sequence} { // ----- +#map = affine_map<(d0) -> (d0)> module { - func.func @fuse_tileable_consumer_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> { + func.func @fuse_tilable_consumer_nested_scf_for(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2 : tensor<?x?xf32>, + %lb0 : index, %ub0 : index, %step0 : index, + %lb1 : index, %ub1 : index, %step1 : index) -> tensor<?x?xf32> { + %0 = scf.for %arg3 = %lb0 to %ub0 step %step0 iter_args(%init0 = %arg0) -> tensor<?x?xf32> { + %1 = scf.for %arg4 = %lb1 to %ub1 step %step1 iter_args(%init1 = %init0) -> tensor<?x?xf32> { + %extracted_slice = tensor.extract_slice %init1[%arg3, %arg4] [%step0, %step1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32> + %2 = tensor.insert_slice %extracted_slice into %init1[%arg3, %arg4] [%step0, %step1] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32> + scf.yield %2 : tensor<?x?xf32> + } + scf.yield %1 : tensor<?x?xf32> + } + %2 = linalg.add ins(%0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> + return %2 : tensor<?x?xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %loops = transform.structured.match ops{["scf.for"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop0, %loop1 = transform.split_handle %loops + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %add = transform.structured.match ops{["linalg.add"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %new_loop0, %new_loop1 = transform.test.fuse_consumer %add into (%loop0, %loop1) + : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: func @fuse_tilable_consumer_nested_scf_for( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32> +// CHECK: %[[OUTER_RESULT:.+]]:2 = scf.for +// CHECK-SAME: iter_args(%[[INIT00:[a-zA-Z0-9_]+]] = %[[ARG0]], %[[INIT01:[a-zA-Z0-9_]+]] = %[[ARG2]]) +// CHECK: %[[INNER_RESULT:.+]]:2 = scf.for +// CHECK-SAME: iter_args(%[[INIT10:[a-zA-Z0-9_]+]] = %[[INIT00]], %[[INIT11:[a-zA-Z0-9_]+]] = %[[INIT01]]) +// CHECK-DAG: %[[OPERAND1:.+]] = tensor.extract_slice %[[INIT10]] +// CHECK-DAG: %[[OLD_INSERT_SLICE:.+]] = tensor.insert_slice %[[OPERAND1]] into %[[INIT10]] +// CHECK-DAG: %[[OPERAND2:.+]] = tensor.extract_slice %[[ARG1]] +// CHECK-DAG: %[[INIT:.+]] = tensor.extract_slice %[[INIT11]] +// CHECK: %[[ADD:.+]] = linalg.add +// CHECK-SAME: ins(%[[OPERAND1]], %[[OPERAND2]] : +// CHECK-SAME: outs(%[[INIT]] : +// CHECK: %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[ADD]] into %[[INIT11]] +// CHECK: scf.yield %[[OLD_INSERT_SLICE]], %[[INSERT_SLICE]] +// CHECK: scf.yield %[[INNER_RESULT]]#0, %[[INNER_RESULT]]#1 +// CHECK: return %[[OUTER_RESULT]]#1 + +// ----- + +module { + func.func @fuse_tilable_consumer_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> { %c4 = arith.constant 4 : index %c64 = arith.constant 64 : index %c0 = arith.constant 0 : index @@ -83,19 +136,16 @@ module { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + %add = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op %loop = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %first_slice_op, %second_slice_op = transform.split_handle %slice_ops - : (!transform.any_op) - -> (!transform.any_op, !transform.any_op) - %a, %b = transform.test.fuse_consumer %first_slice_op in (%loop) + %a, %new_loop = transform.test.fuse_consumer %add into (%loop) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } -// CHECK: func.func @fuse_tileable_consumer_scf_forall( +// CHECK: func.func @fuse_tilable_consumer_scf_forall( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32>) @@ -124,7 +174,7 @@ module attributes {transform.with_named_sequence} { #map = affine_map<(d0) -> (d0)> module { - func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> { + func.func @fuse_tilable_consumer_scf_for_multi_yielding_consumer(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> { %c4 = arith.constant 4 : index %c64 = arith.constant 64 : index %c0 = arith.constant 0 : index @@ -155,16 +205,18 @@ module { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 + %generics = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %producer, %consumer = transform.split_handle %generics + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) %loop = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %yield in (%loop) + %a, %new_loop = transform.test.fuse_consumer %consumer into (%loop) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } -// CHECK: func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer( +// CHECK: func.func @fuse_tilable_consumer_scf_for_multi_yielding_consumer( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32> // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>) @@ -193,7 +245,7 @@ module attributes {transform.with_named_sequence} { #map = affine_map<(d0, d1) -> (d0, d1)> module { - func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64x32xf32>) -> (tensor<64x64xf32>, tensor<2048xf32>) { + func.func @fuse_tilable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64x32xf32>) -> (tensor<64x64xf32>, tensor<2048xf32>) { %c4 = arith.constant 4 : index %c64 = arith.constant 64 : index %c0 = arith.constant 0 : index @@ -224,19 +276,16 @@ module { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + %generic = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op %loop = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %first_slice_op, %second_slice_op = transform.split_handle %slice_ops - : (!transform.any_op) - -> (!transform.any_op, !transform.any_op) - %a, %b = transform.test.fuse_consumer %first_slice_op in (%loop) + %a, %new_loops = transform.test.fuse_consumer %generic into (%loop) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } -// CHECK: func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer( +// CHECK: func.func @fuse_tilable_consumer_scf_forall_multi_yielding_consumer( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32> @@ -293,17 +342,15 @@ module { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + %consumer = transform.structured.match ops{["linalg.unpack"]} in %arg1 : (!transform.any_op) -> !transform.any_op %loop = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %slice_op in (%loop) + %a, %new_loop = transform.test.fuse_consumer %consumer into (%loop) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } -// CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)> -// CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2048)> // CHECK: func.func @fuse_unpack_consumer_into_scf_forall( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> @@ -315,8 +362,8 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] // CHECK: %[[GENERIC_OUT:.*]] = linalg.generic // CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] : -// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]]) -// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]]) +// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 * 32)>(%[[IV1]]) +// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min affine_map<(d0) -> (1024, d0 * -32 + 2048)>(%[[IV1]]) // CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1] // CHECK: %[[TILED_UNPACK_OUT:.*]] = linalg.unpack %[[GENERIC_OUT]] // CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] @@ -356,17 +403,15 @@ module { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + %consumer = transform.structured.match ops{["linalg.unpack"]} in %arg1 : (!transform.any_op) -> !transform.any_op %loop = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %slice_op in (%loop) + %a, %new_loop = transform.test.fuse_consumer %consumer into (%loop) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } -// CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)> -// CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2047)> // CHECK: func.func @fuse_unaligned_unpack_consumer_into_scf_forall( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> @@ -378,8 +423,8 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] // CHECK: %[[GENERIC_OUT:.*]] = linalg.generic // CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] : -// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]]) -// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]]) +// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 * 32)>(%[[IV1]]) +// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min affine_map<(d0) -> (1024, d0 * -32 + 2047)>(%[[IV1]]) // CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1] // CHECK: %[[TILED_UNPACK_OUT:.*]] = linalg.unpack %[[GENERIC_OUT]] // CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] @@ -419,16 +464,15 @@ module { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + %consumer = transform.structured.match ops{["linalg.pack"]} in %arg1 : (!transform.any_op) -> !transform.any_op %loop = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %slice_op in (%loop) + %a, %new_loop = transform.test.fuse_consumer %consumer into (%loop) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } -// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> // CHECK: func.func @fuse_perfect_tiling_pack_consumer( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> @@ -440,7 +484,7 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] // CHECK: %[[GENERIC_OUT:.*]] = linalg.generic // CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] : -// CHECK: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV1]]) +// CHECK: %[[PACK_RESULT_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%[[IV1]]) // CHECK: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1] // CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[GENERIC_OUT]] // CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16] @@ -471,13 +515,12 @@ func.func @fuse_pack_consumer_if_single_iteration(%arg0: tensor<4x4xf32>) -> ten module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %consumer = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused_consumer, %new_loop = transform.test.fuse_consumer %consumer into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } -// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (-d0 + 4, 16)> // CHECK: func.func @fuse_pack_consumer_if_single_iteration( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-DAG: %[[PACK_INIT:.*]] = tensor.empty() : tensor<1x4x16x1xf32> @@ -485,7 +528,7 @@ module attributes {transform.with_named_sequence} { // CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (4) step (16) // CHECK-SAME: shared_outs(%[[ELEM_OUT_ARG:.*]] = %[[ELEM_INIT]], %[[PACK_OUT_ARG:.*]] = %[[PACK_INIT]]) -// CHECK-DAG: %[[SIZE:.+]] = affine.min #[[MAP]](%[[IV]]) +// CHECK-DAG: %[[SIZE:.+]] = affine.min affine_map<(d0) -> (-d0 + 4, 16)>(%[[IV]]) // CHECK-DAG: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1] // CHECK-DAG: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1] // CHECK: %[[ELEM:.*]] = linalg.exp @@ -517,13 +560,12 @@ func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(%arg0: tensor< module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused_consumer, %new_loop = transform.test.fuse_consumer %0 into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } -// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> // CHECK: func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] @@ -535,7 +577,7 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[ELEM:.*]] = linalg.exp // CHECK-SAME: ins(%[[ELEM_SRC]] // CHECK-SAME: outs(%[[ELEM_DEST]] -// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]]) +// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%[[IV]]) // CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1] // CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]] // CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 1] @@ -566,13 +608,12 @@ func.func @fuse_pack_consumer_with_no_pad_dynamic_dim(%arg0: tensor<64x?xf32>, % module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused_consumer, %new_loop = transform.test.fuse_consumer %0 into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } -// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> // CHECK: func.func @fuse_pack_consumer_with_no_pad_dynamic_dim( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] @@ -584,7 +625,7 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[ELEM:.*]] = linalg.exp // CHECK-SAME: ins(%[[ELEM_SRC]] // CHECK-SAME: outs(%[[ELEM_DEST]] -// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]]) +// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%[[IV]]) // CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0] [64, 1, 16] [1, 1, 1] // CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]] // CHECK-SAME: inner_dims_pos = [1] inner_tiles = [16] @@ -616,16 +657,12 @@ func.func @fuse_pack_consumer_with_padding_semantics(%arg0: tensor<64x32xf32>, % module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused_consumer, %new_loop = transform.test.fuse_consumer %0 into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } -// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (-d0 + 64, 15)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (d0 floordiv 3)> -// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0) -> (d0 ceildiv 3)> -// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0) -> (d0 floordiv 16)> // CHECK: func.func @fuse_pack_consumer_with_padding_semantics( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] @@ -633,7 +670,7 @@ module attributes {transform.with_named_sequence} { // CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %{{.*}}:2 = scf.forall (%[[I:.*]], %[[J:.*]]) = (0, 0) to (64, 32) step (15, 16) // CHECK-SAME: shared_outs(%[[ELEM_OUT:.*]] = %[[ARG1]], %[[PACK_OUT:.*]] = %[[OUT_INIT]]) -// CHECK: %[[SIZE:.+]] = affine.min #[[MAP0]](%[[I]]) +// CHECK: %[[SIZE:.+]] = affine.min affine_map<(d0) -> (-d0 + 64, 15)>(%[[I]]) // CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]] // CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1] // CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT]] @@ -641,9 +678,9 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[ELEM:.*]] = linalg.exp // CHECK-SAME: ins(%[[ELEM_SRC]] // CHECK-SAME: outs(%[[ELEM_DEST]] -// CHECK-DAG: %[[D0_OFFSET:.*]] = affine.apply #[[MAP1]](%[[I]]) -// CHECK-DAG: %[[D0_SIZE:.*]] = affine.apply #[[MAP2]](%[[SIZE]]) -// CHECK-DAG: %[[D1_OFFSET:.*]] = affine.apply #[[MAP3]](%[[J]]) +// CHECK-DAG: %[[D0_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 floordiv 3)>(%[[I]]) +// CHECK-DAG: %[[D0_SIZE:.*]] = affine.apply affine_map<(d0) -> (d0 ceildiv 3)>(%[[SIZE]]) +// CHECK-DAG: %[[D1_OFFSET:.*]] = affine.apply affine_map<(d0) -> (d0 floordiv 16)>(%[[J]]) // CHECK-DAG: %[[PACK_INIT:.*]] = tensor.extract_slice %[[PACK_OUT]] // CHECK-SAME: [%[[D0_OFFSET]], %[[D1_OFFSET]], 0, 0] [%[[D0_SIZE]], 1, 3, 16] [1, 1, 1, 1] // CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]] @@ -674,20 +711,21 @@ func.func @nofuse_pack_with_imperfect_tiling(%arg0: tensor<30xf32>) -> tensor<5x linalg.yield %5 : f32 } -> tensor<5xf32> scf.forall.in_parallel { - // expected-error @below {{failed to fuse consumer of slice}} + tensor.parallel_insert_slice %4 into %arg2[%3] [5] [1] : tensor<5xf32> into tensor<30xf32> } } %2 = tensor.empty() : tensor<5x6xf32> + // expected-error @below {{failed to fuse consumer of slice}} %pack = linalg.pack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [6] into %2 : tensor<30xf32> -> tensor<5x6xf32> return %pack : tensor<5x6xf32> } module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused_consumer, %new_loop = transform.test.fuse_consumer %0 into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } @@ -717,11 +755,15 @@ module { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 + %mulop = transform.structured.match ops{["linalg.mul"]} in %arg1 : (!transform.any_op) -> !transform.any_op %loop = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %slice_op in (%loop) num_consumer_to_fuse = 2 + %fused_consumer, %new_loop = transform.test.fuse_consumer %mulop into (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %expop = transform.structured.match ops{["linalg.exp"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %fused_consumer_2, %new_loop_2 = transform.test.fuse_consumer %expop into (%new_loop) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } @@ -741,64 +783,20 @@ module attributes {transform.with_named_sequence} { // CHECK-SAME: ins(%[[ADD_INS0_SLICE]], %[[ADD_INS1_SLICE]] : // CHECK-SAME: outs(%[[ADD_OUT_SLICE]] : // CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] -// CHECK: %[[EXP_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] -// CHECK: %[[TILED_EXP_OUT:.*]] = linalg.exp -// CHECK-SAME: ins(%[[TILED_ADD_OUT]] : -// CHECK-SAME: outs(%[[EXP_OUT_SLICE]] : // CHECK: %[[MUL_INS2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[IV1]], 0] [64, 256] [1, 1] -// CHECK: %[[MUL_OUT_SLICE:.*]] = tensor.extract_slice %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[MUL_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] // CHECK: %[[TILED_MUL_OUT:.*]] = linalg.mul // CHECK-SAME: ins(%[[TILED_ADD_OUT]], %[[MUL_INS2_SLICE]] : // CHECK-SAME: outs(%[[MUL_OUT_SLICE]] : -// CHECK: %[[INSERT_EXP:.*]] = tensor.insert_slice %[[TILED_EXP_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] -// CHECK: %[[INSERT_MUL:.*]] = tensor.insert_slice %[[TILED_MUL_OUT]] into %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] -// CHECK: scf.yield %[[INSERT_ADD]], %[[INSERT_EXP]], %[[INSERT_MUL]] : -// CHECK: } -// CHECK: return %[[LOOP_RESULT]]#2, %[[LOOP_RESULT]]#1 : - -// ----- - -module { - func.func @no_fuse_only_dps_consumer(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<258x258xf32>) { - %c0 = arith.constant 0 : index - %c64 = arith.constant 64 : index - %c256 = arith.constant 256 : index - %cst = arith.constant 0.000000e+00 : f32 - %dest0 = tensor.empty() : tensor<256x256xf32> - %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) { - %extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> - %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> - %extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32> - %3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : tensor<64x256xf32>) -> tensor<64x256xf32> - %insert_slice = tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32> - scf.yield %insert_slice : tensor<256x256xf32> - } - %dest1 = tensor.empty() : tensor<258x258xf32> - %4 = tensor.insert_slice %1 into %dest1[0, 0] [256, 256] [1, 1] : tensor<256x256xf32> into tensor<258x258xf32> - %5 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32> - return %5, %4 : tensor<256x256xf32>, tensor<258x258xf32> - } -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %slice_ops = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 - : (!transform.any_op) -> !transform.any_op - %loop = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %slice_op, %other_slice = transform.split_handle %slice_ops : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %a, %b = transform.test.fuse_consumer %slice_op in (%loop) num_consumer_to_fuse = 1 - : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.yield - } -} -// CHECK: func.func @no_fuse_only_dps_consumer( -// CHECK: %[[LOOP_RESULT:.*]]:2 = scf.for {{.*}} { -// CHECK: linalg.add -// CHECK: linalg.mul -// CHECK: scf.yield +// CHECK: %[[EXP_OUT_SLICE:.*]] = tensor.extract_slice %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[TILED_EXP_OUT:.*]] = linalg.exp +// CHECK-SAME: ins(%[[TILED_ADD_OUT]] : +// CHECK-SAME: outs(%[[EXP_OUT_SLICE]] : +// CHECK: %[[INSERT_MUL:.*]] = tensor.insert_slice %[[TILED_MUL_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: %[[INSERT_EXP:.*]] = tensor.insert_slice %[[TILED_EXP_OUT]] into %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1] +// CHECK: scf.yield %[[INSERT_ADD]], %[[INSERT_MUL]], %[[INSERT_EXP]] : // CHECK: } -// CHECK: %[[RES_SLICE:.+]] = tensor.insert_slice -// CHECK: return %[[LOOP_RESULT]]#1, %[[RES_SLICE]] +// CHECK: return %[[LOOP_RESULT]]#1, %[[LOOP_RESULT]]#2 : // ----- @@ -829,40 +827,41 @@ module { } } -// CHECK: func.func @fuse_with_tilable_consumer_with_projected_permutations(%[[VAL_0:.*]]: tensor<256x256xf32>, %[[VAL_1:.*]]: tensor<256x256xf32>, %[[VAL_2:.*]]: tensor<24xf32>) -> tensor<256x256x24xf32> { -// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_4:.*]] = arith.constant 64 : index -// CHECK: %[[VAL_5:.*]] = arith.constant 256 : index -// CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<256x256xf32> -// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<256x256x24xf32> -// CHECK: %[[VAL_8:.*]]:2 = scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_10:.*]] = %[[VAL_6]], %[[VAL_11:.*]] = %[[VAL_7]]) -> (tensor<256x256xf32>, tensor<256x256x24xf32>) { -// CHECK: %[[VAL_12:.*]] = tensor.extract_slice %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] -// CHECK: %[[VAL_13:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] -// CHECK: %[[VAL_14:.*]] = tensor.extract_slice %[[VAL_1]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] -// CHECK: %[[VAL_15:.*]] = linalg.add ins(%[[VAL_13]], %[[VAL_14]] : tensor<64x256xf32>, tensor<64x256xf32>) outs(%[[VAL_12]] : tensor<64x256xf32>) -> tensor<64x256xf32> -// CHECK: %[[VAL_16:.*]] = tensor.insert_slice %[[VAL_15]] into %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] -// CHECK: %[[VAL_17:.*]] = tensor.extract_slice %[[VAL_2]][0] [24] [1] : tensor<24xf32> to tensor<24xf32> -// CHECK: %[[VAL_18:.*]] = tensor.extract_slice %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1] -// CHECK: %[[VAL_19:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_15]], %[[VAL_17]] : tensor<64x256xf32>, tensor<24xf32>) outs(%[[VAL_18]] : tensor<64x256x24xf32>) { -// CHECK: ^bb0(%[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32, %[[VAL_22:.*]]: f32): -// CHECK: %[[VAL_23:.*]] = arith.addf %[[VAL_20]], %[[VAL_21]] : f32 -// CHECK: linalg.yield %[[VAL_23]] : f32 -// CHECK: } -> tensor<64x256x24xf32> -// CHECK: %[[VAL_24:.*]] = tensor.insert_slice %[[VAL_25:.*]] into %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1] -// CHECK: scf.yield %[[VAL_16]], %[[VAL_24]] : tensor<256x256xf32>, tensor<256x256x24xf32> -// CHECK: } - module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 + %consumer = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op %loop = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.test.fuse_consumer %slice_op in (%loop) num_consumer_to_fuse = 1 + %a, %b = transform.test.fuse_consumer %consumer into (%loop) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } +// CHECK: func.func @fuse_with_tilable_consumer_with_projected_permutations( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<256x256xf32>, %[[VAL_1:.*]]: tensor<256x256xf32>, %[[VAL_2:.*]]: tensor<24xf32>) -> tensor<256x256x24xf32> { +// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_4:.*]] = arith.constant 64 : index +// CHECK: %[[VAL_5:.*]] = arith.constant 256 : index +// CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<256x256xf32> +// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<256x256x24xf32> +// CHECK: %[[VAL_8:.*]]:2 = scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_10:.*]] = %[[VAL_6]], %[[VAL_11:.*]] = %[[VAL_7]]) -> (tensor<256x256xf32>, tensor<256x256x24xf32>) { +// CHECK: %[[VAL_12:.*]] = tensor.extract_slice %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] +// CHECK: %[[VAL_13:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] +// CHECK: %[[VAL_14:.*]] = tensor.extract_slice %[[VAL_1]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] +// CHECK: %[[VAL_15:.*]] = linalg.add ins(%[[VAL_13]], %[[VAL_14]] : tensor<64x256xf32>, tensor<64x256xf32>) outs(%[[VAL_12]] : tensor<64x256xf32>) -> tensor<64x256xf32> +// CHECK: %[[VAL_16:.*]] = tensor.insert_slice %[[VAL_15]] into %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1] +// CHECK: %[[VAL_17:.*]] = tensor.extract_slice %[[VAL_2]][0] [24] [1] : tensor<24xf32> to tensor<24xf32> +// CHECK: %[[VAL_18:.*]] = tensor.extract_slice %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1] +// CHECK: %[[VAL_19:.*]] = linalg.generic +// CHECK-SAME: ins(%[[VAL_15]], %[[VAL_17]] : tensor<64x256xf32>, tensor<24xf32>) outs(%[[VAL_18]] : tensor<64x256x24xf32>) { +// CHECK: ^bb0(%[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32, %[[VAL_22:.*]]: f32): +// CHECK: %[[VAL_23:.*]] = arith.addf %[[VAL_20]], %[[VAL_21]] : f32 +// CHECK: linalg.yield %[[VAL_23]] : f32 +// CHECK: } -> tensor<64x256x24xf32> +// CHECK: %[[VAL_24:.*]] = tensor.insert_slice %[[VAL_25:.*]] into %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1] +// CHECK: scf.yield %[[VAL_16]], %[[VAL_24]] : tensor<256x256xf32>, tensor<256x256x24xf32> +// CHECK: } // ----- @@ -878,12 +877,12 @@ func.func @multi_slice_fusion1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, % %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32> %generic:2 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], - iterator_types = ["parallel", "reduction"]} - ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice, %init1_slice : tensor<?xf32>, tensor<?xf32>) { + iterator_types = ["parallel", "reduction"]} + ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice, %init1_slice : tensor<?xf32>, tensor<?xf32>) { ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): %0 = arith.mulf %b0, %b1 : f32 - %1 = arith.addf %b0, %b2 : f32 - linalg.yield %0, %1 : f32, f32 + %1 = arith.addf %b0, %b2 : f32 + linalg.yield %0, %1 : f32, f32 } -> (tensor<?xf32>, tensor<?xf32>) scf.forall.in_parallel { tensor.parallel_insert_slice %generic#0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32> @@ -901,6 +900,19 @@ func.func @multi_slice_fusion1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, % } -> tensor<?xf32> return %result : tensor<?xf32> } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %generics = transform.structured.match ops{["linalg.generic"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %producer, %consumer = transform.split_handle %generics : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer %consumer into (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} // CHECK-LABEL: func @multi_slice_fusion1( // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32> // CHECK: %[[C0:.+]] = arith.constant 0 @@ -916,23 +928,9 @@ func.func @multi_slice_fusion1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, % // CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]] // CHECK: return %[[RESULT]]#2 -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %loop = transform.structured.match ops{["scf.forall"]} in %arg1 - : (!transform.any_op) -> !transform.any_op - %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 - : (!transform.any_op) -> !transform.any_op - %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop) - : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.yield - } -} // ----- -// Check that when the given operand tiles are inconsistent, tiling fails. - func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -944,20 +942,20 @@ func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, % %init0_slice = tensor.extract_slice %init0[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32> %generic0 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], - iterator_types = ["parallel", "reduction"]} - ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice : tensor<?xf32>) { + iterator_types = ["parallel", "reduction"]} + ins(%arg0_slice : tensor<?x?xf32>) outs(%init0_slice : tensor<?xf32>) { ^bb0(%b0 : f32, %b1 : f32): %0 = arith.mulf %b0, %b1 : f32 - linalg.yield %0 : f32 + linalg.yield %0 : f32 } -> tensor<?xf32> %init1_slice = tensor.extract_slice %init1[%iv0] [%tilesize] [1] : tensor<?xf32> to tensor<?xf32> %generic1 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], - iterator_types = ["parallel", "reduction"]} - ins(%arg0_slice : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) { + iterator_types = ["parallel", "reduction"]} + ins(%arg0_slice : tensor<?x?xf32>) outs(%init1_slice: tensor<?xf32>) { ^bb0(%b0 : f32, %b1 : f32): - %0 = arith.addf %b0, %b1 : f32 - linalg.yield %0: f32 + %0 = arith.addf %b0, %b1 : f32 + linalg.yield %0: f32 } -> tensor<?xf32> scf.forall.in_parallel { tensor.parallel_insert_slice %generic0 into %init0[%iv0] [%tilesize] [1] : tensor<?xf32> into tensor<?xf32> @@ -975,6 +973,19 @@ func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, % } -> tensor<?xf32> return %result : tensor<?xf32> } +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %loop = transform.structured.match ops{["scf.forall"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %generics = transform.structured.match ops{["linalg.generic"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %producer1, %producer2, %consumer = transform.split_handle %generics : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer %consumer into (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + // CHECK-LABEL: func @multi_slice_fusion2( // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32> // CHECK: %[[C0:.+]] = arith.constant 0 @@ -991,19 +1002,6 @@ func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, % // CHECK: tensor.parallel_insert_slice %[[FUSED]] into %[[INIT]][%[[IV]]] [%[[TILESIZE]]] // CHECK: return %[[RESULT]]#2 -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { - %loop = transform.structured.match ops{["scf.forall"]} in %arg1 - : (!transform.any_op) -> !transform.any_op - %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 - : (!transform.any_op) -> !transform.any_op - %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop) - : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) - transform.yield - } -} - // ----- func.func @multi_slice_fusion_with_broadcast(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?xf32>, @@ -1060,11 +1058,11 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %loop = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + %generics = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop) - : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %producer_1, %producer_2, %consumer = transform.split_handle %generics : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer %consumer into (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } @@ -1124,7 +1122,6 @@ func.func @multi_slice_fusion_invalid(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor< linalg.yield %0: f32 } -> tensor<?x?xf32> scf.forall.in_parallel { - // expected-error @below {{failed to fuse consumer of slice}} tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32> tensor.parallel_insert_slice %generic1 into %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1] @@ -1132,6 +1129,7 @@ func.func @multi_slice_fusion_invalid(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor< } } %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32> + // expected-error @below {{failed to fuse consumer of slice}} %result = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} @@ -1146,11 +1144,11 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %loop = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %yield = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + %generics = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %yield0, %yield1 = transform.split_handle %yield : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %a, %b = transform.test.fuse_consumer %yield0, %yield1 in (%loop) - : (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %producer_1, %producer_2, %consumer = transform.split_handle %generics : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_consumer %consumer into (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } diff --git a/mlir/test/Pass/invalid-unsupported-operation.mlir b/mlir/test/Pass/invalid-unsupported-operation.mlir new file mode 100644 index 0000000..1ee4584 --- /dev/null +++ b/mlir/test/Pass/invalid-unsupported-operation.mlir @@ -0,0 +1,10 @@ +// RUN: mlir-opt %s -test-print-liveness -split-input-file -verify-diagnostics + +// Unnamed modules do not implement SymbolOpInterface. +// expected-error-re @+1 {{trying to schedule pass '{{.*}}TestLivenessPass' on an unsupported operation}} +module {} + +// ----- + +// Named modules implement SymbolOpInterface. +module @named_module {} diff --git a/mlir/test/Pass/pipeline-invalid.mlir b/mlir/test/Pass/pipeline-invalid.mlir index 948a133..bff2b1c 100644 --- a/mlir/test/Pass/pipeline-invalid.mlir +++ b/mlir/test/Pass/pipeline-invalid.mlir @@ -15,5 +15,5 @@ arith.constant 0 // ----- -// expected-error@below {{trying to schedule a pass on an unsupported operation}} +// expected-error-re@below {{trying to schedule pass '{{.*}}TestFunctionPass' on an unsupported operation}} module {} diff --git a/mlir/test/Target/Cpp/common-cpp.mlir b/mlir/test/Target/Cpp/common-cpp.mlir index 294e6af6..f397a4a 100644 --- a/mlir/test/Target/Cpp/common-cpp.mlir +++ b/mlir/test/Target/Cpp/common-cpp.mlir @@ -105,6 +105,25 @@ func.func @apply() -> !emitc.ptr<i32> { return %1 : !emitc.ptr<i32> } + +// CHECK-LABEL: void address_of() { +func.func @address_of() { + // CHECK-NEXT: int32_t [[V1:[^ ]*]]; + %0 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue<i32> + // CHECK-NEXT: int32_t* [[V2:[^ ]*]] = &[[V1]]; + %1 = emitc.address_of %0 : !emitc.lvalue<i32> + return +} + +// CHECK-LABEL: void dereference +// CHECK-SAME: (int32_t* [[ARG0:[^ ]*]]) { +func.func @dereference(%arg0: !emitc.ptr<i32>) { + // CHECK-NEXT: int32_t [[V1:[^ ]*]] = *[[ARG0]]; + %2 = emitc.dereference %arg0 : !emitc.ptr<i32> + emitc.load %2 : !emitc.lvalue<i32> + return +} + // CHECK: void array_type(int32_t v1[3], float v2[10][20]) func.func @array_type(%arg0: !emitc.array<3xi32>, %arg1: !emitc.array<10x20xf32>) { return diff --git a/mlir/test/Target/Cpp/expressions.mlir b/mlir/test/Target/Cpp/expressions.mlir index 9f1c816..2de94d0 100644 --- a/mlir/test/Target/Cpp/expressions.mlir +++ b/mlir/test/Target/Cpp/expressions.mlir @@ -314,14 +314,14 @@ func.func @different_expressions(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) return %v_load : i32 } -// CPP-DEFAULT: int32_t expression_with_dereference(int32_t [[VAL_1:v[0-9]+]], int32_t* [[VAL_2]]) { +// CPP-DEFAULT: int32_t expression_with_dereference_apply(int32_t [[VAL_1:v[0-9]+]], int32_t* [[VAL_2]]) { // CPP-DEFAULT-NEXT: return *([[VAL_2]] - [[VAL_1]]); // CPP-DEFAULT-NEXT: } -// CPP-DECLTOP: int32_t expression_with_dereference(int32_t [[VAL_1:v[0-9]+]], int32_t* [[VAL_2]]) { +// CPP-DECLTOP: int32_t expression_with_dereference_apply(int32_t [[VAL_1:v[0-9]+]], int32_t* [[VAL_2]]) { // CPP-DECLTOP-NEXT: return *([[VAL_2]] - [[VAL_1]]); // CPP-DECLTOP-NEXT: } -emitc.func @expression_with_dereference(%arg1: i32, %arg2: !emitc.ptr<i32>) -> i32 { +emitc.func @expression_with_dereference_apply(%arg1: i32, %arg2: !emitc.ptr<i32>) -> i32 { %c = emitc.expression %arg1, %arg2 : (i32, !emitc.ptr<i32>) -> i32 { %e = emitc.sub %arg2, %arg1 : (!emitc.ptr<i32>, i32) -> !emitc.ptr<i32> %d = emitc.apply "*"(%e) : (!emitc.ptr<i32>) -> i32 @@ -330,6 +330,28 @@ emitc.func @expression_with_dereference(%arg1: i32, %arg2: !emitc.ptr<i32>) -> i return %c : i32 } +// CPP-DEFAULT: bool expression_with_address_taken_apply(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t* [[VAL_3]]) { +// CPP-DEFAULT-NEXT: int32_t [[VAL_4:v[0-9]+]] = 42; +// CPP-DEFAULT-NEXT: return &[[VAL_4]] - [[VAL_2]] < [[VAL_3]]; +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: bool expression_with_address_taken_apply(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t* [[VAL_3]]) { +// CPP-DECLTOP-NEXT: int32_t [[VAL_4:v[0-9]+]]; +// CPP-DECLTOP-NEXT: [[VAL_4]] = 42; +// CPP-DECLTOP-NEXT: return &[[VAL_4]] - [[VAL_2]] < [[VAL_3]]; +// CPP-DECLTOP-NEXT: } + +func.func @expression_with_address_taken_apply(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr<i32>) -> i1 { + %a = "emitc.variable"(){value = 42 : i32} : () -> !emitc.lvalue<i32> + %c = emitc.expression %arg1, %arg2, %a : (i32, !emitc.ptr<i32>, !emitc.lvalue<i32>) -> i1 { + %d = emitc.apply "&"(%a) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32> + %e = emitc.sub %d, %arg1 : (!emitc.ptr<i32>, i32) -> !emitc.ptr<i32> + %f = emitc.cmp lt, %e, %arg2 : (!emitc.ptr<i32>, !emitc.ptr<i32>) -> i1 + emitc.yield %f : i1 + } + return %c : i1 +} + // CPP-DEFAULT: bool expression_with_address_taken(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t* [[VAL_3]]) { // CPP-DEFAULT-NEXT: int32_t [[VAL_4:v[0-9]+]] = 42; // CPP-DEFAULT-NEXT: return &[[VAL_4]] - [[VAL_2]] < [[VAL_3]]; @@ -344,7 +366,7 @@ emitc.func @expression_with_dereference(%arg1: i32, %arg2: !emitc.ptr<i32>) -> i func.func @expression_with_address_taken(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr<i32>) -> i1 { %a = "emitc.variable"(){value = 42 : i32} : () -> !emitc.lvalue<i32> %c = emitc.expression %arg1, %arg2, %a : (i32, !emitc.ptr<i32>, !emitc.lvalue<i32>) -> i1 { - %d = emitc.apply "&"(%a) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32> + %d = emitc.address_of %a : !emitc.lvalue<i32> %e = emitc.sub %d, %arg1 : (!emitc.ptr<i32>, i32) -> !emitc.ptr<i32> %f = emitc.cmp lt, %e, %arg2 : (!emitc.ptr<i32>, !emitc.ptr<i32>) -> i1 emitc.yield %f : i1 diff --git a/mlir/test/Target/LLVMIR/Import/debug-info-records.ll b/mlir/test/Target/LLVMIR/Import/debug-info-records.ll new file mode 100644 index 0000000..077871e --- /dev/null +++ b/mlir/test/Target/LLVMIR/Import/debug-info-records.ll @@ -0,0 +1,87 @@ +; RUN: mlir-translate -import-llvm -mlir-print-debuginfo -convert-debug-rec-to-intrinsics -emit-expensive-warnings -split-input-file %s 2>&1 | FileCheck %s +; RUN: mlir-translate -import-llvm -mlir-print-debuginfo -emit-expensive-warnings -split-input-file %s 2>&1 | FileCheck %s + +; CHECK: #[[LOCAL_VAR0:.*]] = #llvm.di_local_variable<scope = #di_lexical_block> +; CHECK: #[[LOCAL_VAR1:.*]] = #llvm.di_local_variable<scope = #di_lexical_block_file, name = "arg" +; CHECK: #[[LOCAL_VAR2:.*]] = #llvm.di_local_variable<scope = #di_lexical_block, name = "alloc" + +; CHECK: @callee() +define void @callee() { + ret void +} + +define void @func_with_empty_named_info() { + call void @callee() + ret void +} + +define void @func_no_debug() { + ret void +} + +; CHECK: llvm.func @func_with_debug(%[[ARG0:.*]]: i64 +define void @func_with_debug(i64 %0) !dbg !3 { + + ; CHECK: llvm.intr.dbg.value #[[LOCAL_VAR0]] = %[[ARG0]] : i64 + ; CHECK: llvm.intr.dbg.value #[[LOCAL_VAR1]] #llvm.di_expression<[DW_OP_LLVM_fragment(0, 1)]> = %[[ARG0]] : i64 + ; CHECK: %[[CST:.*]] = llvm.mlir.constant(1 : i32) : i32 + ; CHECK: %[[ADDR:.*]] = llvm.alloca %[[CST]] x i64 + ; CHECK: llvm.intr.dbg.declare #[[LOCAL_VAR2]] #llvm.di_expression<[DW_OP_deref, DW_OP_LLVM_convert(4, DW_ATE_signed)]> = %[[ADDR]] : !llvm.ptr + %2 = alloca i64, align 8, !dbg !19 + #dbg_value(i64 %0, !20, !DIExpression(DW_OP_LLVM_fragment, 0, 1), !22) + #dbg_declare(ptr %2, !23, !DIExpression(DW_OP_deref, DW_OP_LLVM_convert, 4, DW_ATE_signed), !25) + #dbg_value(i64 %0, !26, !DIExpression(), !27) + call void @func_no_debug(), !dbg !28 + %3 = add i64 %0, %0, !dbg !32 + ret void, !dbg !37 +} + +define void @empty_types() !dbg !38 { + ret void, !dbg !44 +} + +!llvm.dbg.cu = !{!0} +!llvm.module.flags = !{!2} + +!0 = distinct !DICompileUnit(language: DW_LANG_C, file: !1, producer: "MLIR", isOptimized: true, runtimeVersion: 0, splitDebugFilename: "test.dwo", emissionKind: FullDebug, nameTableKind: None) +!1 = !DIFile(filename: "foo.mlir", directory: "/test/") +!2 = !{i32 2, !"Debug Info Version", i32 3} +!3 = distinct !DISubprogram(name: "func_with_debug", linkageName: "func_with_debug", scope: !4, file: !1, line: 3, type: !6, scopeLine: 3, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0) +!4 = !DINamespace(name: "nested", scope: !5) +!5 = !DINamespace(name: "toplevel", scope: null, exportSymbols: true) +!6 = !DISubroutineType(cc: DW_CC_normal, types: !7) +!7 = !{null, !8, !9, !11, !12, !13, !16} +!8 = !DIBasicType(name: "si64") +!9 = !DIDerivedType(tag: DW_TAG_pointer_type, baseType: !10, size: 64, align: 32, offset: 8, extraData: !10) +!10 = !DIBasicType(name: "si32", size: 32, encoding: DW_ATE_signed) +!11 = !DIDerivedType(tag: DW_TAG_pointer_type, name: "named", baseType: !10) +!12 = !DIDerivedType(tag: DW_TAG_pointer_type, baseType: !10, size: 64, align: 32, offset: 8, dwarfAddressSpace: 3) +!13 = distinct !DICompositeType(tag: DW_TAG_structure_type, name: "composite", file: !1, line: 42, size: 64, align: 32, elements: !14) +!14 = !{!15} +!15 = !DISubrange(count: 4) +!16 = !DICompositeType(tag: DW_TAG_array_type, name: "array", file: !1, baseType: !8, flags: DIFlagVector, elements: !17) +!17 = !{!18} +!18 = !DISubrange(lowerBound: 0, upperBound: 4, stride: 1) +!19 = !DILocation(line: 100, column: 12, scope: !3) +!20 = !DILocalVariable(name: "arg", arg: 1, scope: !21, file: !1, line: 6, type: !8, align: 32) +!21 = distinct !DILexicalBlockFile(scope: !3, file: !1, discriminator: 0) +!22 = !DILocation(line: 103, column: 3, scope: !3) +!23 = !DILocalVariable(name: "alloc", scope: !24) +!24 = distinct !DILexicalBlock(scope: !3) +!25 = !DILocation(line: 106, column: 3, scope: !3) +!26 = !DILocalVariable(scope: !24) +!27 = !DILocation(line: 109, column: 3, scope: !3) +!28 = !DILocation(line: 1, column: 2, scope: !3) +!32 = !DILocation(line: 2, column: 4, scope: !33, inlinedAt: !36) +!33 = distinct !DISubprogram(name: "callee", scope: !13, file: !1, type: !34, spFlags: DISPFlagDefinition, unit: !0) +!34 = !DISubroutineType(types: !35) +!35 = !{!8, !8} +!36 = !DILocation(line: 28, column: 5, scope: !3) +!37 = !DILocation(line: 135, column: 3, scope: !3) +!38 = distinct !DISubprogram(name: "empty_types", scope: !39, file: !1, type: !40, spFlags: DISPFlagDefinition, unit: !0, annotations: !42) +!39 = !DIModule(scope: !1, name: "module", configMacros: "bar", includePath: "/", apinotes: "/", file: !1, line: 42, isDecl: true) +!40 = !DISubroutineType(cc: DW_CC_normal, types: !41) +!41 = !{} +!42 = !{!43} +!43 = !{!"foo", !"bar"} +!44 = !DILocation(line: 140, column: 3, scope: !38) diff --git a/mlir/test/Target/LLVMIR/Import/function-attributes.ll b/mlir/test/Target/LLVMIR/Import/function-attributes.ll index 83c0438..023b012 100644 --- a/mlir/test/Target/LLVMIR/Import/function-attributes.ll +++ b/mlir/test/Target/LLVMIR/Import/function-attributes.ll @@ -22,14 +22,14 @@ define dso_local void @dsolocal_func() { ; // ----- ; CHECK-LABEL: @func_readnone -; CHECK-SAME: attributes {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>} +; CHECK-SAME: attributes {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>} ; CHECK: llvm.return define void @func_readnone() readnone { ret void } ; CHECK-LABEL: @func_readnone_indirect -; CHECK-SAME: attributes {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>} +; CHECK-SAME: attributes {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>} declare void @func_readnone_indirect() #0 attributes #0 = { readnone } @@ -169,7 +169,7 @@ define void @entry_count() !prof !1 { ; // ----- ; CHECK-LABEL: @func_memory -; CHECK-SAME: attributes {memory_effects = #llvm.memory_effects<other = readwrite, argMem = none, inaccessibleMem = readwrite>} +; CHECK-SAME: attributes {memory_effects = #llvm.memory_effects<other = readwrite, argMem = none, inaccessibleMem = readwrite, errnoMem = readwrite, targetMem0 = readwrite, targetMem1 = readwrite>} ; CHECK: llvm.return define void @func_memory() memory(readwrite, argmem: none) { ret void diff --git a/mlir/test/Target/LLVMIR/Import/import-failure.ll b/mlir/test/Target/LLVMIR/Import/import-failure.ll index d48be66..32f730b 100644 --- a/mlir/test/Target/LLVMIR/Import/import-failure.ll +++ b/mlir/test/Target/LLVMIR/Import/import-failure.ll @@ -1,16 +1,14 @@ ; RUN: not mlir-translate -import-llvm -emit-expensive-warnings -split-input-file %s 2>&1 -o /dev/null | FileCheck %s -; Check that debug intrinsics with an unsupported argument are dropped. - -declare void @llvm.dbg.value(metadata, metadata, metadata) +; Check that debug records with an unsupported argument are dropped. ; CHECK: import-failure.ll -; CHECK-SAME: warning: dropped intrinsic: tail call void @llvm.dbg.value(metadata !DIArgList(i64 %{{.*}}, i64 undef), metadata !3, metadata !DIExpression(DW_OP_LLVM_arg, 0, DW_OP_LLVM_arg, 1, DW_OP_constu, 1, DW_OP_mul, DW_OP_plus, DW_OP_stack_value)) +; CHECK-SAME: warning: unhandled debug variable record #dbg_value(!DIArgList(i64 %{{.*}}, i64 undef), !{{.*}}, !DIExpression(DW_OP_LLVM_arg, 0, DW_OP_LLVM_arg, 1, DW_OP_constu, 1, DW_OP_mul, DW_OP_plus, DW_OP_stack_value), !{{.*}}) ; CHECK: import-failure.ll -; CHECK-SAME: warning: dropped intrinsic: tail call void @llvm.dbg.value(metadata !6, metadata !3, metadata !DIExpression()) +; CHECK-SAME: warning: unhandled debug variable record #dbg_value(!{{.*}}, !{{.*}}, !DIExpression(), !{{.*}}) define void @unsupported_argument(i64 %arg1) { - tail call void @llvm.dbg.value(metadata !DIArgList(i64 %arg1, i64 undef), metadata !3, metadata !DIExpression(DW_OP_LLVM_arg, 0, DW_OP_LLVM_arg, 1, DW_OP_constu, 1, DW_OP_mul, DW_OP_plus, DW_OP_stack_value)), !dbg !5 - tail call void @llvm.dbg.value(metadata !6, metadata !3, metadata !DIExpression()), !dbg !5 + #dbg_value(!DIArgList(i64 %arg1, i64 undef), !3, !DIExpression(DW_OP_LLVM_arg, 0, DW_OP_LLVM_arg, 1, DW_OP_constu, 1, DW_OP_mul, DW_OP_plus, DW_OP_stack_value), !5) + #dbg_value(!6, !3, !DIExpression(), !5) ret void } diff --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll index be245e3..7f9c511 100644 --- a/mlir/test/Target/LLVMIR/Import/instructions.ll +++ b/mlir/test/Target/LLVMIR/Import/instructions.ll @@ -703,13 +703,13 @@ declare void @f() ; CHECK-LABEL: @call_memory_effects define void @call_memory_effects() { -; CHECK: llvm.call @f() {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>} +; CHECK: llvm.call @f() {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>} call void @f() memory(none) -; CHECK: llvm.call @f() {memory_effects = #llvm.memory_effects<other = none, argMem = write, inaccessibleMem = read>} +; CHECK: llvm.call @f() {memory_effects = #llvm.memory_effects<other = none, argMem = write, inaccessibleMem = read, errnoMem = none, targetMem0 = none, targetMem1 = none>} call void @f() memory(none, argmem: write, inaccessiblemem: read) -; CHECK: llvm.call @f() {memory_effects = #llvm.memory_effects<other = write, argMem = none, inaccessibleMem = write>} +; CHECK: llvm.call @f() {memory_effects = #llvm.memory_effects<other = write, argMem = none, inaccessibleMem = write, errnoMem = write, targetMem0 = write, targetMem1 = write>} call void @f() memory(write, argmem: none) -; CHECK: llvm.call @f() {memory_effects = #llvm.memory_effects<other = readwrite, argMem = readwrite, inaccessibleMem = read>} +; CHECK: llvm.call @f() {memory_effects = #llvm.memory_effects<other = readwrite, argMem = readwrite, inaccessibleMem = read, errnoMem = readwrite, targetMem0 = readwrite, targetMem1 = readwrite>} call void @f() memory(readwrite, inaccessiblemem: read) ; CHECK: llvm.call @f() ; CHECK-NOT: #llvm.memory_effects diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll index d2bb809..2381d7a 100644 --- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll +++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll @@ -1128,6 +1128,34 @@ define void @experimental_constrained_fpext(float %s, <4 x float> %v) { ret void } +; CHECK-LABEL: llvm.func @ucmp +define i2 @ucmp(i32 %a, i32 %b) { + ; CHECK: %{{.*}} = llvm.intr.ucmp(%{{.*}}, %{{.*}}) : (i32, i32) -> i2 + %r = call i2 @llvm.ucmp.i2.i32(i32 %a, i32 %b) + ret i2 %r +} + +; CHECK-LABEL: llvm.func @vector_ucmp +define <4 x i32> @vector_ucmp(<4 x i32> %a, <4 x i32> %b) { + ; CHECK: %{{.*}} = llvm.intr.ucmp(%{{.*}}, %{{.*}}) : (vector<4xi32>, vector<4xi32>) -> vector<4xi32> + %r = call <4 x i32> @llvm.ucmp.v4i32.v4i32(<4 x i32> %a, <4 x i32> %b) + ret <4 x i32> %r +} + +; CHECK-LABEL: llvm.func @scmp +define i2 @scmp(i32 %a, i32 %b) { + ; CHECK: %{{.*}} = llvm.intr.scmp(%{{.*}}, %{{.*}}) : (i32, i32) -> i2 + %r = call i2 @llvm.scmp.i2.i32(i32 %a, i32 %b) + ret i2 %r +} + +; CHECK-LABEL: llvm.func @vector_scmp +define <4 x i32> @vector_scmp(<4 x i32> %a, <4 x i32> %b) { + ; CHECK: %{{.*}} = llvm.intr.scmp(%{{.*}}, %{{.*}}) : (vector<4xi32>, vector<4xi32>) -> vector<4xi32> + %r = call <4 x i32> @llvm.scmp.v4i32.v4i32(<4 x i32> %a, <4 x i32> %b) + ret <4 x i32> %r +} + declare float @llvm.fmuladd.f32(float, float, float) declare <8 x float> @llvm.fmuladd.v8f32(<8 x float>, <8 x float>, <8 x float>) declare float @llvm.fma.f32(float, float, float) @@ -1382,3 +1410,7 @@ declare <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f64(<4 x doubl declare float @llvm.experimental.constrained.fptrunc.f32.f64(double, metadata, metadata) declare <4 x double> @llvm.experimental.constrained.fpext.v4f64.v4f32(<4 x float>, metadata) declare double @llvm.experimental.constrained.fpext.f64.f32(float, metadata) +declare i2 @llvm.ucmp.i2.i32(i32, i32) +declare <4 x i32> @llvm.ucmp.v4i32.v4i32(<4 x i32>, <4 x i32>) +declare i2 @llvm.scmp.i2.i32(i32, i32) +declare <4 x i32> @llvm.scmp.v4i32.v4i32(<4 x i32>, <4 x i32>) diff --git a/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll b/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll index c623df0..3280625 100644 --- a/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll +++ b/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll @@ -16,6 +16,22 @@ bb2: ; // ----- +; CHECK-LABEL: @cond_br_expected +define i64 @cond_br_expected(i1 %arg1, i64 %arg2) { +entry: + ; CHECK: llvm.cond_br + ; CHECK-SAME: weights([1, 2000]) + br i1 %arg1, label %bb1, label %bb2, !prof !0 +bb1: + ret i64 %arg2 +bb2: + ret i64 %arg2 +} + +!0 = !{!"branch_weights", !"expected", i32 1, i32 2000} + +; // ----- + ; CHECK-LABEL: @simple_switch( define i32 @simple_switch(i32 %arg1) { ; CHECK: llvm.switch @@ -36,6 +52,26 @@ bbd: ; // ----- +; CHECK-LABEL: @simple_switch_expected( +define i32 @simple_switch_expected(i32 %arg1) { + ; CHECK: llvm.switch + ; CHECK: {branch_weights = array<i32: 1, 1, 2000>} + switch i32 %arg1, label %bbd [ + i32 0, label %bb1 + i32 9, label %bb2 + ], !prof !0 +bb1: + ret i32 %arg1 +bb2: + ret i32 %arg1 +bbd: + ret i32 %arg1 +} + +!0 = !{!"branch_weights", !"expected", i32 1, i32 1, i32 2000} + +; // ----- + ; Verify that a single weight attached to a call is not translated. ; The MLIR WeightedBranchOpInterface does not support this case. diff --git a/mlir/test/Target/LLVMIR/allocatable_gpu_reduction.mlir b/mlir/test/Target/LLVMIR/allocatable_gpu_reduction.mlir new file mode 100644 index 0000000..95d12f3 --- /dev/null +++ b/mlir/test/Target/LLVMIR/allocatable_gpu_reduction.mlir @@ -0,0 +1,99 @@ +// Tests single-team by-ref GPU reductions. + +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +module attributes {dlti.dl_spec = #dlti.dl_spec<"dlti.alloca_memory_space" = 5 : ui64, "dlti.global_memory_space" = 1 : ui64>, llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} { + omp.private {type = private} @_QFfooEi_private_i32 : i32 + omp.declare_reduction @add_reduction_byref_box_heap_f32 : !llvm.ptr attributes {byref_element_type = f32} alloc { + %0 = llvm.mlir.constant(1 : i64) : i64 + %1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> : (i64) -> !llvm.ptr<5> + %2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr + omp.yield(%2 : !llvm.ptr) + } init { + ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr): + omp.yield(%arg1 : !llvm.ptr) + } combiner { + ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr): + %0 = llvm.mlir.constant(1 : i32) : i32 + %1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> {alignment = 8 : i64} : (i32) -> !llvm.ptr<5> + %2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr + %3 = llvm.mlir.constant(1 : i32) : i32 + %4 = llvm.alloca %3 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> {alignment = 8 : i64} : (i32) -> !llvm.ptr<5> + %5 = llvm.addrspacecast %4 : !llvm.ptr<5> to !llvm.ptr + %6 = llvm.mlir.constant(24 : i32) : i32 + "llvm.intr.memcpy"(%5, %arg0, %6) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () + %7 = llvm.mlir.constant(24 : i32) : i32 + "llvm.intr.memcpy"(%2, %arg1, %7) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () + %8 = llvm.getelementptr %5[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> + %9 = llvm.load %8 : !llvm.ptr -> !llvm.ptr + %10 = llvm.getelementptr %2[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> + %11 = llvm.load %10 : !llvm.ptr -> !llvm.ptr + %12 = llvm.load %9 : !llvm.ptr -> f32 + %13 = llvm.load %11 : !llvm.ptr -> f32 + %14 = llvm.fadd %12, %13 {fastmathFlags = #llvm.fastmath<contract>} : f32 + llvm.store %14, %9 : f32, !llvm.ptr + omp.yield(%arg0 : !llvm.ptr) + } data_ptr_ptr { + ^bb0(%arg0: !llvm.ptr): + %0 = llvm.getelementptr %arg0[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> + omp.yield(%0 : !llvm.ptr) + } + + llvm.func @foo_() { + %0 = llvm.mlir.constant(1 : i64) : i64 + %4 = llvm.alloca %0 x i1 : (i64) -> !llvm.ptr<5> + %5 = llvm.addrspacecast %4 : !llvm.ptr<5> to !llvm.ptr + %8 = llvm.getelementptr %5[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> + %9 = omp.map.info var_ptr(%5 : !llvm.ptr, f32) map_clauses(implicit, tofrom) capture(ByRef) var_ptr_ptr(%8 : !llvm.ptr) -> !llvm.ptr {name = ""} + %10 = omp.map.info var_ptr(%5 : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>) map_clauses(always, implicit, descriptor, to) capture(ByRef) members(%9 : [0] : !llvm.ptr) -> !llvm.ptr {name = "scalar_alloc"} + omp.target map_entries(%10 -> %arg0 : !llvm.ptr) { + %13 = llvm.mlir.constant(1000 : i32) : i32 + %14 = llvm.mlir.constant(1 : i32) : i32 + omp.parallel { + omp.wsloop reduction(byref @add_reduction_byref_box_heap_f32 %arg0 -> %arg4 : !llvm.ptr) { + omp.loop_nest (%arg5) : i32 = (%14) to (%13) inclusive step (%14) { + omp.yield + } + } + omp.terminator + } + omp.terminator + } + llvm.return + } +} + +// CHECK: define {{.*}} @_omp_reduction_shuffle_and_reduce_func({{.*}}) {{.*}} { +// CHECK: %[[REMOTE_RED_LIST:.omp.reduction.remote_reduce_list]] = alloca [1 x ptr], align 8, addrspace(5) +// CHECK: %[[RED_ELEM:.omp.reduction.element]] = alloca { ptr, i64, i32, i8, i8, i8, i8 }, align 8, addrspace(5) +// CHECK: %[[RED_ELEM_1:.*]] = addrspacecast ptr addrspace(5) %[[RED_ELEM]] to ptr + +// CHECK: %[[SHUFFLE_ELEM:.*]] = alloca float, align 4, addrspace(5) +// CHECK: %[[REMOTE_RED_LIST_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[REMOTE_RED_LIST]] to ptr + +// CHECK: %[[REMOTE_RED_LIST_ELEM0:.*]] = getelementptr inbounds [1 x ptr], ptr %[[REMOTE_RED_LIST_ASCAST]], i64 0, i64 0 + +// CHECK: %[[SHUFFLE_ELEM_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[SHUFFLE_ELEM]] to ptr +// CHECK: %[[SHUFFLE_RES:.*]] = call i32 @__kmpc_shuffle_int32({{.*}}) +// CHECK: store i32 %[[SHUFFLE_RES]], ptr %[[SHUFFLE_ELEM_ASCAST]], align 4 + +// CHECK: %[[RED_ELEM_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[RED_ELEM]] to ptr +// CHECK: %[[RED_ALLOC_PTR:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[RED_ELEM_ASCAST]], i32 0, i32 0 +// CHECK: %[[SHUFFLE_ELEM_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[SHUFFLE_ELEM]] to ptr +// CHECK: store ptr %[[SHUFFLE_ELEM_ASCAST]], ptr %[[RED_ALLOC_PTR]], align 8 +// CHECK: store ptr %[[RED_ELEM_1]], ptr %[[REMOTE_RED_LIST_ELEM0]], align 8 +// CHECK: } + +// CHECK: define {{.*}} @_omp_reduction_inter_warp_copy_func({{.*}}) {{.*}} { +// CHECK: %[[WARP_MASTER_CMP:.*]] = icmp eq i32 %nvptx_lane_id, 0 +// CHECK: br i1 %[[WARP_MASTER_CMP]], label %[[WARP_MASTER_BB:.*]], label %{{.*}} + +// CHECK: [[WARP_MASTER_BB]]: +// CHECK: %[[WARP_RESULT_PTR:.*]] = getelementptr inbounds [1 x ptr], ptr %{{.*}}, i64 0, i64 0 +// CHECK: %[[WARP_RESULT:.*]] = load ptr, ptr %[[WARP_RESULT_PTR]], align 8 +// CHECK: %[[ALLOC_MEM_PTR:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[WARP_RESULT]], i32 0, i32 0 +// CHECK: %[[ALLOC_MEM:.*]] = load ptr, ptr %[[ALLOC_MEM_PTR]], align 8 +// CHECK: %[[WARP_TRANSFER_SLOT:.*]] = getelementptr inbounds [32 x i32], ptr addrspace(3) @__openmp_nvptx_data_transfer_temporary_storage, i64 0, i32 %nvptx_warp_id +// CHECK: %[[WARP_RED_RES:.*]] = load i32, ptr %[[ALLOC_MEM]], align 4 +// CHECK: store volatile i32 %[[WARP_RED_RES]], ptr addrspace(3) %[[WARP_TRANSFER_SLOT]], align 4 +// CHECK: } diff --git a/mlir/test/Target/LLVMIR/allocatable_gpu_reduction_teams.mlir b/mlir/test/Target/LLVMIR/allocatable_gpu_reduction_teams.mlir new file mode 100644 index 0000000..1c73a49 --- /dev/null +++ b/mlir/test/Target/LLVMIR/allocatable_gpu_reduction_teams.mlir @@ -0,0 +1,121 @@ +// Tests cross-teams by-ref GPU reductions. + +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +module attributes {dlti.dl_spec = #dlti.dl_spec<"dlti.alloca_memory_space" = 5 : ui64, "dlti.global_memory_space" = 1 : ui64>, llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} { + omp.private {type = private} @_QFfooEi_private_i32 : i32 + omp.declare_reduction @add_reduction_byref_box_heap_f32 : !llvm.ptr attributes {byref_element_type = f32} alloc { + %0 = llvm.mlir.constant(1 : i64) : i64 + %1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> : (i64) -> !llvm.ptr<5> + %2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr + omp.yield(%2 : !llvm.ptr) + } init { + ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr): + omp.yield(%arg1 : !llvm.ptr) + } combiner { + ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr): + %0 = llvm.mlir.constant(1 : i32) : i32 + %1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> {alignment = 8 : i64} : (i32) -> !llvm.ptr<5> + %2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr + %3 = llvm.mlir.constant(1 : i32) : i32 + %4 = llvm.alloca %3 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> {alignment = 8 : i64} : (i32) -> !llvm.ptr<5> + %5 = llvm.addrspacecast %4 : !llvm.ptr<5> to !llvm.ptr + %6 = llvm.mlir.constant(24 : i32) : i32 + "llvm.intr.memcpy"(%5, %arg0, %6) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () + %7 = llvm.mlir.constant(24 : i32) : i32 + "llvm.intr.memcpy"(%2, %arg1, %7) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () + %8 = llvm.getelementptr %5[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> + %9 = llvm.load %8 : !llvm.ptr -> !llvm.ptr + %10 = llvm.getelementptr %2[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> + %11 = llvm.load %10 : !llvm.ptr -> !llvm.ptr + %12 = llvm.load %9 : !llvm.ptr -> f32 + %13 = llvm.load %11 : !llvm.ptr -> f32 + %14 = llvm.fadd %12, %13 {fastmathFlags = #llvm.fastmath<contract>} : f32 + llvm.store %14, %9 : f32, !llvm.ptr + omp.yield(%arg0 : !llvm.ptr) + } data_ptr_ptr { + ^bb0(%arg0: !llvm.ptr): + %0 = llvm.getelementptr %arg0[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> + omp.yield(%0 : !llvm.ptr) + } + + llvm.func @foo_() { + %0 = llvm.mlir.constant(1 : i64) : i64 + %4 = llvm.alloca %0 x i1 : (i64) -> !llvm.ptr<5> + %5 = llvm.addrspacecast %4 : !llvm.ptr<5> to !llvm.ptr + %8 = llvm.getelementptr %5[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> + %9 = omp.map.info var_ptr(%5 : !llvm.ptr, f32) map_clauses(tofrom) capture(ByRef) var_ptr_ptr(%8 : !llvm.ptr) -> !llvm.ptr {name = ""} + %10 = omp.map.info var_ptr(%5 : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>) map_clauses(always, descriptor, to, attach) capture(ByRef) members(%9 : [0] : !llvm.ptr) -> !llvm.ptr {name = "scalar_alloc"} + omp.target map_entries(%10 -> %arg0 : !llvm.ptr) { + %14 = llvm.mlir.constant(1000000 : i32) : i32 + %15 = llvm.mlir.constant(1 : i32) : i32 + omp.teams reduction(byref @add_reduction_byref_box_heap_f32 %arg0 -> %arg3 : !llvm.ptr) { + omp.parallel { + omp.distribute { + omp.wsloop reduction(byref @add_reduction_byref_box_heap_f32 %arg3 -> %arg5 : !llvm.ptr) { + omp.loop_nest (%arg6) : i32 = (%15) to (%14) inclusive step (%15) { + omp.yield + } + } {omp.composite} + } {omp.composite} + omp.terminator + } {omp.composite} + omp.terminator + } + omp.terminator + } + llvm.return + } +} + +// CHECK: %[[GLOBALIZED_LOCALS:.*]] = type { float } + +// CHECK: define internal void @_omp_reduction_list_to_global_copy_func({{.*}}) {{.*}} { +// CHECK: %[[RED_ARR_LIST:.*]] = getelementptr inbounds [1 x ptr], ptr %{{.*}}, i64 0, i64 0 +// CHECK: %[[RED_ELEM_PTR:.*]] = load ptr, ptr %[[RED_ARR_LIST]], align 8 +// CHECK: %[[GLOB_ELEM_PTR:.*]] = getelementptr inbounds %[[GLOBALIZED_LOCALS]], ptr %{{.*}}, i32 0, i32 0 +// CHECK: %[[ALLOC_PTR_PTR:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[RED_ELEM_PTR]], i32 0, i32 0 +// CHECK: %[[ALLOC_PTR:.*]] = load ptr, ptr %[[ALLOC_PTR_PTR]], align 8 +// CHECK: %[[ALLOC_VAL:.*]] = load float, ptr %[[ALLOC_PTR]], align 4 +// Verify that the actual value managed by the descriptor is stored in the globalized +// locals arrays; rather than a pointer to the descriptor or a pointer to the value. +// CHECK: store float %[[ALLOC_VAL]], ptr %[[GLOB_ELEM_PTR]], align 4 +// CHECK: } + +// CHECK: define internal void @_omp_reduction_list_to_global_reduce_func({{.*}}) {{.*}} { +// Allocate a descriptor to manage the element retrieved from the globalized local array. +// CHECK: %[[ALLOC_DESC:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8 }, align 8, addrspace(5) +// CHECK: %[[ALLOC_DESC_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ALLOC_DESC]] to ptr + +// CHECK: %[[RED_ARR_LIST:.*]] = getelementptr inbounds [1 x ptr], ptr %{{.*}}, i64 0, i64 0 +// CHECK: %[[GLOB_ELEM_PTR:.*]] = getelementptr inbounds %[[GLOBALIZED_LOCALS]], ptr %{{.*}}, i32 0, i32 0 +// CHECK: %[[ALLOC_PTR_PTR:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[ALLOC_DESC_ASCAST]], i32 0, i32 0 +// Store the pointer to the gloalized local element into the locally allocated descriptor. +// CHECK: store ptr %[[GLOB_ELEM_PTR]], ptr %[[ALLOC_PTR_PTR]], align 8 +// CHECK: store ptr %[[ALLOC_DESC_ASCAST]], ptr %[[RED_ARR_LIST]], align 8 +// CHECK: } + +// CHECK: define internal void @_omp_reduction_global_to_list_copy_func({{.*}}) {{.*}} { +// CHECK: %[[RED_ARR_LIST:.*]] = getelementptr inbounds [1 x ptr], ptr %{{.*}}, i64 0, i64 0 +// CHECK: %[[RED_ELEM_PTR:.*]] = load ptr, ptr %[[RED_ARR_LIST]], align 8 +// CHECK: %[[GLOB_ELEM_PTR:.*]] = getelementptr inbounds %[[GLOBALIZED_LOCALS]], ptr %{{.*}}, i32 0, i32 0 +// CHECK: %[[ALLOC_PTR_PTR:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[RED_ELEM_PTR]], i32 0, i32 0 +// Similar to _omp_reduction_list_to_global_copy_func(...) but in the reverse direction; i.e. +// the globalized local array is copied from rather than copied to. +// CHECK: %[[ALLOC_PTR:.*]] = load ptr, ptr %[[ALLOC_PTR_PTR]], align 8 +// CHECK: %[[ALLOC_VAL:.*]] = load float, ptr %[[GLOB_ELEM_PTR]], align 4 +// CHECK: store float %[[ALLOC_VAL]], ptr %[[ALLOC_PTR]], align 4 +// CHECK: } + +// CHECK: define internal void @_omp_reduction_global_to_list_reduce_func({{.*}}) {{.*}} { +// Allocate a descriptor to manage the element retrieved from the globalized local array. +// CHECK: %[[ALLOC_DESC:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8 }, align 8, addrspace(5) +// CHECK: %[[ALLOC_DESC_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ALLOC_DESC]] to ptr + +// CHECK: %[[RED_ARR_LIST:.*]] = getelementptr inbounds [1 x ptr], ptr %{{.*}}, i64 0, i64 0 +// CHECK: %[[GLOB_ELEM_PTR:.*]] = getelementptr inbounds %[[GLOBALIZED_LOCALS]], ptr %{{.*}}, i32 0, i32 0 +// CHECK: %[[ALLOC_PTR_PTR:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[ALLOC_DESC_ASCAST]], i32 0, i32 0 +// Store the pointer to the gloalized local element into the locally allocated descriptor. +// CHECK: store ptr %[[GLOB_ELEM_PTR]], ptr %[[ALLOC_PTR_PTR]], align 8 +// CHECK: store ptr %[[ALLOC_DESC_ASCAST]], ptr %[[RED_ARR_LIST]], align 8 +// CHECK: } diff --git a/mlir/test/Target/LLVMIR/anonymous-tbaa.mlir b/mlir/test/Target/LLVMIR/anonymous-tbaa.mlir new file mode 100644 index 0000000..b54bfe4 --- /dev/null +++ b/mlir/test/Target/LLVMIR/anonymous-tbaa.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s + +#tbaa_root_0 = #llvm.tbaa_root<> +#tbaa_type_desc_1 = #llvm.tbaa_type_desc<id = "omnipotent char", members = {<#tbaa_root_0, 0>}> +#tbaa_type_desc_2 = #llvm.tbaa_type_desc<id = "long long", members = {<#tbaa_type_desc_1, 0>}> +#tbaa_tag_3 = #llvm.tbaa_tag<access_type = #tbaa_type_desc_2, base_type = #tbaa_type_desc_2, offset = 0> + +// CHECK: define void @tbaa_anonymous_root(ptr %{{.*}}) { +// CHECK: %{{.*}} = load i64, ptr %{{.*}}, align 4, !tbaa ![[TAG:[0-9]+]] +// CHECK: ret void +// CHECK: } +// CHECK: !llvm.module.flags = !{![[FLAGS:[0-9]+]]} +// CHECK: ![[FLAGS]] = !{i32 2, !"Debug Info Version", i32 3} +// CHECK: ![[TAG]] = !{![[TYPE:[0-9]+]], ![[TYPE]], i64 0} +// CHECK: ![[TYPE]] = !{!"long long", ![[BASE:[0-9]+]], i64 0} +// CHECK: ![[BASE]] = !{!"omnipotent char", ![[ROOT:[0-9]+]], i64 0} +// CHECK: ![[ROOT]] = distinct !{![[ROOT]]} +llvm.func @tbaa_anonymous_root(%arg0: !llvm.ptr) { + %0 = llvm.load %arg0 {tbaa = [#tbaa_tag_3]} : !llvm.ptr -> i64 + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir index 1e4cf8d..403c73f 100644 --- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir @@ -1276,6 +1276,34 @@ llvm.func @experimental_constrained_fpext(%s: f32, %v: vector<4xf32>) { llvm.return } +// CHECK-LABEL: @ucmp +llvm.func @ucmp(%a: i32, %b: i32) -> i2 { + // CHECK: call i2 @llvm.ucmp.i2.i32 + %r = llvm.intr.ucmp(%a, %b) : (i32, i32) -> i2 + llvm.return %r : i2 +} + +// CHECK-LABEL: @vector_ucmp +llvm.func @vector_ucmp(%a: vector<4 x i32>, %b: vector<4 x i32>) -> vector<4 x i32> { + // CHECK: call <4 x i32> @llvm.ucmp.v4i32.v4i32 + %0 = llvm.intr.ucmp(%a, %b) : (vector<4 x i32>, vector<4 x i32>) -> vector<4 x i32> + llvm.return %0 : vector<4 x i32> +} + +// CHECK-LABEL: @scmp +llvm.func @scmp(%a: i32, %b: i32) -> i2 { + // CHECK: call i2 @llvm.scmp.i2.i32 + %r = llvm.intr.scmp(%a, %b) : (i32, i32) -> i2 + llvm.return %r : i2 +} + +// CHECK-LABEL: @vector_scmp +llvm.func @vector_scmp(%a: vector<4 x i32>, %b: vector<4 x i32>) -> vector<4 x i32> { + // CHECK: call <4 x i32> @llvm.scmp.v4i32.v4i32 + %0 = llvm.intr.scmp(%a, %b) : (vector<4 x i32>, vector<4 x i32>) -> vector<4 x i32> + llvm.return %0 : vector<4 x i32> +} + // Check that intrinsics are declared with appropriate types. // CHECK-DAG: declare float @llvm.fma.f32(float, float, float) // CHECK-DAG: declare <8 x float> @llvm.fma.v8f32(<8 x float>, <8 x float>, <8 x float>) #0 @@ -1464,3 +1492,7 @@ llvm.func @experimental_constrained_fpext(%s: f32, %v: vector<4xf32>) { // CHECK-DAG: declare <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f32(<4 x float>, metadata, metadata) // CHECK-DAG: declare double @llvm.experimental.constrained.fpext.f64.f32(float, metadata) // CHECK-DAG: declare <4 x double> @llvm.experimental.constrained.fpext.v4f64.v4f32(<4 x float>, metadata) +// CHECK-DAG: declare range(i2 -1, -2) i2 @llvm.ucmp.i2.i32(i32, i32) +// CHECK-DAG: declare range(i32 -1, 2) <4 x i32> @llvm.ucmp.v4i32.v4i32(<4 x i32>, <4 x i32>) +// CHECK-DAG: declare range(i2 -1, -2) i2 @llvm.scmp.i2.i32(i32, i32) +// CHECK-DAG: declare range(i32 -1, 2) <4 x i32> @llvm.scmp.v4i32.v4i32(<4 x i32>, <4 x i32>) diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir index cc243c8..819a514 100644 --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -78,6 +78,9 @@ llvm.mlir.global internal @f8E8M0FNU_global_as_i8(1.0 : f8E8M0FNU) : i8 // CHECK: @bf16_global_as_i16 = internal global i16 16320 llvm.mlir.global internal @bf16_global_as_i16(1.5 : bf16) : i16 +// CHECK: @bool_global_as_i8 = internal global i8 1 +llvm.mlir.global internal @bool_global_as_i8(true) : i8 + // CHECK: @explicit_undef = global i32 undef llvm.mlir.global external @explicit_undef() : i32 { %0 = llvm.mlir.undef : i32 @@ -2371,17 +2374,17 @@ llvm.func @readonly_function(%arg0: !llvm.ptr {llvm.readonly}) // CHECK: declare void @arg_mem_none_func() #[[ATTR:[0-9]+]] llvm.func @arg_mem_none_func() attributes { - memory_effects = #llvm.memory_effects<other = readwrite, argMem = none, inaccessibleMem = readwrite>} + memory_effects = #llvm.memory_effects<other = readwrite, argMem = none, inaccessibleMem = readwrite, errnoMem = none, targetMem0 = none, targetMem1 = none>} -// CHECK: attributes #[[ATTR]] = { memory(readwrite, argmem: none, errnomem: none) } +// CHECK: attributes #[[ATTR]] = { memory(readwrite, argmem: none, errnomem: none, target_mem0: none, target_mem1: none) } // ----- // CHECK: declare void @readwrite_func() #[[ATTR:[0-9]+]] llvm.func @readwrite_func() attributes { - memory_effects = #llvm.memory_effects<other = readwrite, argMem = readwrite, inaccessibleMem = readwrite>} + memory_effects = #llvm.memory_effects<other = readwrite, argMem = readwrite, inaccessibleMem = readwrite, errnoMem = none, targetMem0 = none, targetMem1 = none>} -// CHECK: attributes #[[ATTR]] = { memory(readwrite, errnomem: none) } +// CHECK: attributes #[[ATTR]] = { memory(readwrite, errnomem: none, target_mem0: none, target_mem1: none) } // ----- @@ -2723,10 +2726,10 @@ llvm.func @fd() // CHECK: call void @fc() #[[ATTRS_2:[0-9]+]] // CHECK: call void @fd() #[[ATTRS_3:[0-9]+]] llvm.func @mem_effects_call() { - llvm.call @fa() {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>} : () -> () - llvm.call @fb() {memory_effects = #llvm.memory_effects<other = read, argMem = none, inaccessibleMem = write>} : () -> () - llvm.call @fc() {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = write>} : () -> () - llvm.call @fd() {memory_effects = #llvm.memory_effects<other = readwrite, argMem = read, inaccessibleMem = readwrite>} : () -> () + llvm.call @fa() {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none, errnoMem = none, targetMem0 = none, targetMem1 = none>} : () -> () + llvm.call @fb() {memory_effects = #llvm.memory_effects<other = read, argMem = none, inaccessibleMem = write, errnoMem = none, targetMem0 = none, targetMem1 = none>} : () -> () + llvm.call @fc() {memory_effects = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = write, errnoMem = none, targetMem0 = none, targetMem1 = none>} : () -> () + llvm.call @fd() {memory_effects = #llvm.memory_effects<other = readwrite, argMem = read, inaccessibleMem = readwrite, errnoMem = none, targetMem0 = none, targetMem1 = none>} : () -> () llvm.return } @@ -2734,11 +2737,11 @@ llvm.func @mem_effects_call() { // CHECK: #[[ATTRS_0]] // CHECK-SAME: memory(none) // CHECK: #[[ATTRS_1]] -// CHECK-SAME: memory(read, argmem: none, inaccessiblemem: write, errnomem: none) +// CHECK-SAME: memory(read, argmem: none, inaccessiblemem: write, errnomem: none, target_mem0: none, target_mem1: none) // CHECK: #[[ATTRS_2]] -// CHECK-SAME: memory(read, inaccessiblemem: write, errnomem: none) +// CHECK-SAME: memory(read, inaccessiblemem: write, errnomem: none, target_mem0: none, target_mem1: none) // CHECK: #[[ATTRS_3]] -// CHECK-SAME: memory(readwrite, argmem: read, errnomem: none) +// CHECK-SAME: memory(readwrite, argmem: read, errnomem: none, target_mem0: none, target_mem1: none) // ----- diff --git a/mlir/test/Target/LLVMIR/nvvm/barrier.mlir b/mlir/test/Target/LLVMIR/nvvm/barrier.mlir new file mode 100644 index 0000000..a18633e --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/barrier.mlir @@ -0,0 +1,27 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s --check-prefix=LLVM +// RUN: mlir-opt %s | mlir-opt | FileCheck %s + +// LLVM-LABEL: @llvm_nvvm_barrier( +// LLVM-SAME: i32 %[[barId:.*]], i32 %[[numThreads:.*]], i32 %[[redOperand:.*]]) +llvm.func @llvm_nvvm_barrier(%barID : i32, %numberOfThreads : i32, %redOperand : i32) { + // LLVM: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0) + // CHECK: nvvm.barrier + nvvm.barrier + // LLVM: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 %[[barId]]) + // CHECK: nvvm.barrier id = %{{.*}} + nvvm.barrier id = %barID + // LLVM: call void @llvm.nvvm.barrier.cta.sync.aligned.count(i32 %[[barId]], i32 %[[numThreads]]) + // CHECK: nvvm.barrier id = %{{.*}} number_of_threads = %{{.*}} + nvvm.barrier id = %barID number_of_threads = %numberOfThreads + // LLVM: %{{.*}} = call i32 @llvm.nvvm.barrier0.and(i32 %[[redOperand]]) + // CHECK: %{{.*}} = nvvm.barrier #nvvm.reduction<and> %{{.*}} -> i32 + %0 = nvvm.barrier #nvvm.reduction<and> %redOperand -> i32 + // LLVM: %{{.*}} = call i32 @llvm.nvvm.barrier0.or(i32 %[[redOperand]]) + // CHECK: %{{.*}} = nvvm.barrier #nvvm.reduction<or> %{{.*}} -> i32 + %1 = nvvm.barrier #nvvm.reduction<or> %redOperand -> i32 + // LLVM: %{{.*}} = call i32 @llvm.nvvm.barrier0.popc(i32 %[[redOperand]]) + // CHECK: %{{.*}} = nvvm.barrier #nvvm.reduction<popc> %{{.*}} -> i32 + %2 = nvvm.barrier #nvvm.reduction<popc> %redOperand -> i32 + + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir new file mode 100644 index 0000000..a4bece8 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir @@ -0,0 +1,87 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @convert_f32x2_to_f16x2_rn +llvm.func @convert_f32x2_to_f16x2_rn(%srcA : f32, %srcB : f32) { + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn(float %{{.*}}, float %{{.*}}) + %res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16> + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.satfinite(float %{{.*}}, float %{{.*}}) + %res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16> + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float %{{.*}}, float %{{.*}}) + %res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xf16> + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}}) + %res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16> + + llvm.return +} + +// CHECK-LABEL: @convert_f32x2_to_f16x2_rz +llvm.func @convert_f32x2_to_f16x2_rz(%srcA : f32, %srcB : f32) { + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz(float %{{.*}}, float %{{.*}}) + %res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xf16> + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.satfinite(float %{{.*}}, float %{{.*}}) + %res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16> + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float %{{.*}}, float %{{.*}}) + %res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true} : vector<2xf16> + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu.satfinite(float %{{.*}}, float %{{.*}}) + %res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16> + + llvm.return +} + +// CHECK-LABEL: @convert_f32x2_to_f16x2_rs_stochastic +llvm.func @convert_f32x2_to_f16x2_rs_stochastic(%srcA : f32, %srcB : f32, %rbits : i32) { + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16> + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16> + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16> + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16> + + llvm.return +} + +// ----- + +// CHECK-LABEL: @convert_f32x2_to_bf16x2_rn +llvm.func @convert_f32x2_to_bf16x2_rn(%srcA : f32, %srcB : f32) { + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float %{{.*}}, float %{{.*}}) + %res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.satfinite(float %{{.*}}, float %{{.*}}) + %res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu(float %{{.*}}, float %{{.*}}) + %res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xbf16> + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}}) + %res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> + + llvm.return +} + +// CHECK-LABEL: @convert_f32x2_to_bf16x2_rz +llvm.func @convert_f32x2_to_bf16x2_rz(%srcA : f32, %srcB : f32) { + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float %{{.*}}, float %{{.*}}) + %res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16> + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.satfinite(float %{{.*}}, float %{{.*}}) + %res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float %{{.*}}, float %{{.*}}) + %res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true} : vector<2xbf16> + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu.satfinite(float %{{.*}}, float %{{.*}}) + %res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> + + llvm.return +} + +// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs_stochastic +llvm.func @convert_f32x2_to_bf16x2_rs_stochastic(%srcA : f32, %srcB : f32, %rbits : i32) { + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16> + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16> + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> + + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir index b5bb223..03abcdd 100644 --- a/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir +++ b/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir @@ -10,7 +10,7 @@ gpu.module @valid_f16x2_rs_sm_100a [#nvvm.target<chip = "sm_100a">] { %f1 = llvm.mlir.constant(1.0 : f32) : f32 %f2 = llvm.mlir.constant(2.0 : f32) : f32 %rbits = llvm.mlir.constant(0x12345678 : i32) : i32 - %res = nvvm.convert.f32x2.to.f16x2 %f1, %f2, %rbits : vector<2xf16> + %res = nvvm.convert.f32x2.to.f16x2 %f1, %f2, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16> return } } @@ -21,77 +21,13 @@ gpu.module @valid_bf16x2_rs_sm_103a [#nvvm.target<chip = "sm_103a">] { %f1 = llvm.mlir.constant(1.0 : f32) : f32 %f2 = llvm.mlir.constant(2.0 : f32) : f32 %rbits = llvm.mlir.constant(0 : i32) : i32 - %res = nvvm.convert.f32x2.to.bf16x2 %f1, %f2, %rbits : vector<2xbf16> + %res = nvvm.convert.f32x2.to.bf16x2 %f1, %f2, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16> return } } // ----- -// Test F32x2 -> F16x2 with stochastic rounding (.rs) - -// CHECK-LABEL: @convert_f32x2_to_f16x2_rs -llvm.func @convert_f32x2_to_f16x2_rs(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> { - // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) - %res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits : vector<2xf16> - llvm.return %res : vector<2xf16> -} - -// CHECK-LABEL: @convert_f32x2_to_f16x2_rs_satfinite -llvm.func @convert_f32x2_to_f16x2_rs_satfinite(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> { - // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) - %res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {sat = #nvvm.sat_mode<satfinite>} : vector<2xf16> - llvm.return %res : vector<2xf16> -} - -// CHECK-LABEL: @convert_f32x2_to_f16x2_rs_relu -llvm.func @convert_f32x2_to_f16x2_rs_relu(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> { - // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) - %res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true} : vector<2xf16> - llvm.return %res : vector<2xf16> -} - -// CHECK-LABEL: @convert_f32x2_to_f16x2_rs_relu_satfinite -llvm.func @convert_f32x2_to_f16x2_rs_relu_satfinite(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> { - // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) - %res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16> - llvm.return %res : vector<2xf16> -} - -// ----- - -// Test F32x2 -> BF16x2 with stochastic rounding (.rs) - -// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs -llvm.func @convert_f32x2_to_bf16x2_rs(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> { - // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) - %res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits : vector<2xbf16> - llvm.return %res : vector<2xbf16> -} - -// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs_satfinite -llvm.func @convert_f32x2_to_bf16x2_rs_satfinite(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> { - // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) - %res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> - llvm.return %res : vector<2xbf16> -} - -// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs_relu -llvm.func @convert_f32x2_to_bf16x2_rs_relu(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> { - // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) - %res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true} : vector<2xbf16> - llvm.return %res : vector<2xbf16> -} - -// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs_relu_satfinite -llvm.func @convert_f32x2_to_bf16x2_rs_relu_satfinite(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> { - // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) - %res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> - llvm.return %res : vector<2xbf16> -} - -// ----- - // Test F32x4 -> F8x4 (E4M3) with stochastic rounding (.rs) // CHECK-LABEL: @convert_f32x4_to_f8x4_e4m3_rs diff --git a/mlir/test/Target/LLVMIR/nvvm/fence-invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/fence-invalid.mlir new file mode 100644 index 0000000..22578b5 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/fence-invalid.mlir @@ -0,0 +1,89 @@ +// RUN: mlir-translate --mlir-to-llvmir -verify-diagnostics -split-input-file %s + +llvm.func @fence_sync_restrict() { + // expected-error @below {{only acquire and release semantics are supported}} + nvvm.fence.sync_restrict {order = #nvvm.mem_order<weak>} + llvm.return +} + +// ----- + +llvm.func @fence_sync_restrict() { + // expected-error @below {{only acquire and release semantics are supported}} + nvvm.fence.sync_restrict {order = #nvvm.mem_order<mmio>} + llvm.return +} + +// ----- + +llvm.func @fence_proxy() { + // expected-error @below {{tensormap proxy is not a supported proxy kind}} + nvvm.fence.proxy {kind = #nvvm.proxy_kind<tensormap>} + llvm.return +} + +// ----- + +llvm.func @fence_proxy() { + // expected-error @below {{generic proxy not a supported proxy kind}} + nvvm.fence.proxy {kind = #nvvm.proxy_kind<generic>} + llvm.return +} + +// ----- + +llvm.func @fence_proxy() { + // expected-error @below {{async_shared fence requires space attribute}} + nvvm.fence.proxy {kind = #nvvm.proxy_kind<async.shared>} + llvm.return +} + +// ----- + +llvm.func @fence_proxy() { + // expected-error @below {{only async_shared fence can have space attribute}} + nvvm.fence.proxy {kind = #nvvm.proxy_kind<alias>, space = #nvvm.shared_space<cta>} + llvm.return +} + +// ----- + +llvm.func @fence_proxy_release() { + // expected-error @below {{uni-directional proxies only support generic for from_proxy attribute}} + nvvm.fence.proxy.release #nvvm.mem_scope<cta> from_proxy = #nvvm.proxy_kind<alias> to_proxy = #nvvm.proxy_kind<tensormap> + llvm.return +} + +// ----- + +llvm.func @fence_proxy_release() { + // expected-error @below {{uni-directional proxies only support tensormap for to_proxy attribute}} + nvvm.fence.proxy.release #nvvm.mem_scope<cta> from_proxy = #nvvm.proxy_kind<generic> to_proxy = #nvvm.proxy_kind<async> + llvm.return +} + +// ----- + +llvm.func @fence_proxy_sync_restrict() { + // expected-error @below {{only acquire and release semantics are supported}} + nvvm.fence.proxy.sync_restrict {order = #nvvm.mem_order<mmio>} + llvm.return +} + +// ----- + +llvm.func @fence_proxy_sync_restrict() { + // expected-error @below {{only async is supported for to_proxy attribute}} + nvvm.fence.proxy.sync_restrict {order = #nvvm.mem_order<acquire>, toProxy = #nvvm.proxy_kind<alias>, + fromProxy = #nvvm.proxy_kind<generic>} + llvm.return +} + +// ----- + +llvm.func @fence_proxy_sync_restrict() { + // expected-error @below {{only generic is support for from_proxy attribute}} + nvvm.fence.proxy.sync_restrict {order = #nvvm.mem_order<acquire>, toProxy = #nvvm.proxy_kind<async>, + fromProxy = #nvvm.proxy_kind<tensormap>} + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/fence.mlir b/mlir/test/Target/LLVMIR/nvvm/fence.mlir new file mode 100644 index 0000000..0ab4cb7 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/fence.mlir @@ -0,0 +1,85 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @llvm_nvvm_fence_sc_cluster +llvm.func @llvm_nvvm_fence_sc_cluster() { + // CHECK: nvvm.fence.sc.cluster + nvvm.fence.sc.cluster + llvm.return +} + +// CHECK-LABEL: @nvvm_fence_sync_restrict +llvm.func @nvvm_fence_sync_restrict() { + // CHECK: call void @llvm.nvvm.fence.acquire.sync_restrict.space.cluster.scope.cluster() + nvvm.fence.sync_restrict {order = #nvvm.mem_order<acquire>} + // CHECK: call void @llvm.nvvm.fence.release.sync_restrict.space.cta.scope.cluster() + nvvm.fence.sync_restrict {order = #nvvm.mem_order<release>} + llvm.return +} + +// CHECK-LABEL: @fence_mbarrier_init +llvm.func @fence_mbarrier_init() { + // CHECK: call void @llvm.nvvm.fence.mbarrier_init.release.cluster() + nvvm.fence.mbarrier.init + llvm.return +} + +// CHECK-LABEL: @nvvm_fence_proxy +llvm.func @nvvm_fence_proxy() { + // CHECK: call void @llvm.nvvm.fence.proxy.alias() + nvvm.fence.proxy {kind = #nvvm.proxy_kind<alias>} + + // CHECK: call void @llvm.nvvm.fence.proxy.async() + nvvm.fence.proxy {kind = #nvvm.proxy_kind<async>} + + // CHECK: call void @llvm.nvvm.fence.proxy.async.global() + nvvm.fence.proxy {kind = #nvvm.proxy_kind<async.global>} + + // CHECK: call void @llvm.nvvm.fence.proxy.async.shared_cta() + nvvm.fence.proxy {kind = #nvvm.proxy_kind<async.shared>, space = #nvvm.shared_space<cta>} + + // CHECK: call void @llvm.nvvm.fence.proxy.async.shared_cluster() + nvvm.fence.proxy {kind = #nvvm.proxy_kind<async.shared>, space = #nvvm.shared_space<cluster>} + llvm.return +} + +// CHECK-LABEL: @nvvm_fence_proxy_sync_restrict +llvm.func @nvvm_fence_proxy_sync_restrict() { + // CHECK: call void @llvm.nvvm.fence.proxy.async_generic.acquire.sync_restrict.space.cluster.scope.cluster() + nvvm.fence.proxy.sync_restrict {order = #nvvm.mem_order<acquire>} + // CHECK: call void @llvm.nvvm.fence.proxy.async_generic.release.sync_restrict.space.cta.scope.cluster() + nvvm.fence.proxy.sync_restrict {order = #nvvm.mem_order<release>} + llvm.return +} + +// CHECK-LABEL: @nvvm_fence_proxy_tensormap_generic_release +llvm.func @nvvm_fence_proxy_tensormap_generic_release() { + // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.release.cta() + nvvm.fence.proxy.release #nvvm.mem_scope<cta> + + // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.release.cluster() + nvvm.fence.proxy.release #nvvm.mem_scope<cluster> + + // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.release.gpu() + nvvm.fence.proxy.release #nvvm.mem_scope<gpu> + + // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.release.sys() + nvvm.fence.proxy.release #nvvm.mem_scope<sys> + llvm.return +} + +// CHECK-LABEL: @nvvm_fence_proxy_tensormap_generic_acquire +llvm.func @nvvm_fence_proxy_tensormap_generic_acquire(%addr : !llvm.ptr) { + %c128 = llvm.mlir.constant(128) : i32 + // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.acquire.cta(ptr {{%[0-9]+}}, i32 128) + nvvm.fence.proxy.acquire #nvvm.mem_scope<cta> %addr, %c128 + + // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.acquire.cluster(ptr {{%[0-9]+}}, i32 128) + nvvm.fence.proxy.acquire #nvvm.mem_scope<cluster> %addr, %c128 + + // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.acquire.gpu(ptr {{%[0-9]+}}, i32 128) + nvvm.fence.proxy.acquire #nvvm.mem_scope<gpu> %addr, %c128 + + // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.acquire.sys(ptr {{%[0-9]+}}, i32 128) + nvvm.fence.proxy.acquire #nvvm.mem_scope<sys> %addr, %c128 + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/invalid_convert_fp16x2.mlir b/mlir/test/Target/LLVMIR/nvvm/invalid_convert_fp16x2.mlir new file mode 100644 index 0000000..37756c8 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/invalid_convert_fp16x2.mlir @@ -0,0 +1,47 @@ +// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s + +// ----- + +llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rounding(%a : f32, %b : f32) { + // expected-error @below {{Only RN, RZ, and RS rounding modes are supported for conversions from f32x2 to f16x2.}} + %res = nvvm.convert.f32x2.to.f16x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rm>} : vector<2xf16> + llvm.return +} + +// ----- + +llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rbits_1(%a : f32, %b : f32) { + // expected-error @below {{random_bits is required for RS rounding mode.}} + %res = nvvm.convert.f32x2.to.f16x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16> + llvm.return +} + +// ----- + +llvm.func @nvvm_cvt_f32x2_to_f16x2_invalid_rbits_2(%a : f32, %b : f32, %rbits : i32) { + // expected-error @below {{random_bits not supported for RN and RZ rounding modes.}} + %res = nvvm.convert.f32x2.to.f16x2 %a, %b, %rbits {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16> + llvm.return +} + +// ----- + +llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rounding(%a : f32, %b : f32) { + // expected-error @below {{Only RN, RZ, and RS rounding modes are supported for conversions from f32x2 to bf16x2.}} + %res = nvvm.convert.f32x2.to.bf16x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rm>} : vector<2xbf16> + llvm.return +} + +// ----- + +llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rbits_1(%a : f32, %b : f32) { + // expected-error @below {{random_bits is required for RS rounding mode.}} + %res = nvvm.convert.f32x2.to.bf16x2 %a, %b {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16> + llvm.return +} + +llvm.func @nvvm_cvt_f32x2_to_bf16x2_invalid_rbits_2(%a : f32, %b : f32, %rbits : i32) { + // expected-error @below {{random_bits not supported for RN and RZ rounding modes.}} + %res = nvvm.convert.f32x2.to.bf16x2 %a, %b, %rbits {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_arr_drop_expect_tx.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_arr_drop_expect_tx.mlir new file mode 100644 index 0000000..4b3cafe --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/mbar_arr_drop_expect_tx.mlir @@ -0,0 +1,68 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +llvm.func @mbarrier_arrive_drop_expect_tx_generic(%barrier: !llvm.ptr, %txcount : i32) { + // CHECK-LABEL: define void @mbarrier_arrive_drop_expect_tx_generic(ptr %0, i32 %1) { + // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %4 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.scope.cta.space.cta(ptr addrspace(3) %3, i32 %1) + // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %6 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.scope.cta.space.cta(ptr addrspace(3) %5, i32 %1) + // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %8 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.scope.cluster.space.cta(ptr addrspace(3) %7, i32 %1) + // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %10 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.relaxed.scope.cta.space.cta(ptr addrspace(3) %9, i32 %1) + // CHECK-NEXT: %11 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %12 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.relaxed.scope.cta.space.cta(ptr addrspace(3) %11, i32 %1) + // CHECK-NEXT: %13 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %14 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.relaxed.scope.cluster.space.cta(ptr addrspace(3) %13, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount : !llvm.ptr, i32 -> i64 + %1 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>} : !llvm.ptr, i32 -> i64 + %2 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32 -> i64 + + %3 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {relaxed = true} : !llvm.ptr, i32 -> i64 + %4 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr, i32 -> i64 + %5 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr, i32 -> i64 + llvm.return +} + +llvm.func @mbarrier_arrive_drop_expect_tx_shared(%barrier: !llvm.ptr<3>, %txcount : i32) { + // CHECK-LABEL: define void @mbarrier_arrive_drop_expect_tx_shared(ptr addrspace(3) %0, i32 %1) { + // CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %4 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %5 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %6 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %7 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %8 = call i64 @llvm.nvvm.mbarrier.arrive.drop.expect.tx.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount : !llvm.ptr<3>, i32 -> i64 + %1 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<3>, i32 -> i64 + %2 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 -> i64 + + %3 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {relaxed = true} : !llvm.ptr<3>, i32 -> i64 + %4 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr<3>, i32 -> i64 + %5 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr<3>, i32 -> i64 + llvm.return +} + +llvm.func @mbarrier_arrive_drop_expect_tx_shared_cluster(%barrier: !llvm.ptr<7>, %txcount : i32) { + // CHECK-LABEL: define void @mbarrier_arrive_drop_expect_tx_shared_cluster(ptr addrspace(7) %0, i32 %1) { + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.expect.tx.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.expect.tx.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.expect.tx.scope.cluster.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.expect.tx.relaxed.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.expect.tx.relaxed.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.expect.tx.relaxed.scope.cluster.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount : !llvm.ptr<7>, i32 + nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<7>, i32 + nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<7>, i32 + + nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {relaxed = true} : !llvm.ptr<7>, i32 + nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr<7>, i32 + nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr<7>, i32 + llvm.return +} + diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_arr_expect_tx.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_arr_expect_tx.mlir new file mode 100644 index 0000000..b5389bd --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/mbar_arr_expect_tx.mlir @@ -0,0 +1,68 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +llvm.func @mbarrier_arrive_expect_tx_generic(%barrier: !llvm.ptr, %txcount : i32) { + // CHECK-LABEL: define void @mbarrier_arrive_expect_tx_generic(ptr %0, i32 %1) { + // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %4 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.scope.cta.space.cta(ptr addrspace(3) %3, i32 %1) + // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %6 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.scope.cta.space.cta(ptr addrspace(3) %5, i32 %1) + // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %8 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.scope.cluster.space.cta(ptr addrspace(3) %7, i32 %1) + // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %10 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.relaxed.scope.cta.space.cta(ptr addrspace(3) %9, i32 %1) + // CHECK-NEXT: %11 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %12 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.relaxed.scope.cta.space.cta(ptr addrspace(3) %11, i32 %1) + // CHECK-NEXT: %13 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %14 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.relaxed.scope.cluster.space.cta(ptr addrspace(3) %13, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32 -> i64 + %1 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>} : !llvm.ptr, i32 -> i64 + %2 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32 -> i64 + + %3 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {relaxed = true} : !llvm.ptr, i32 -> i64 + %4 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr, i32 -> i64 + %5 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr, i32 -> i64 + llvm.return +} + +llvm.func @mbarrier_arrive_expect_tx_shared(%barrier: !llvm.ptr<3>, %txcount : i32) { + // CHECK-LABEL: define void @mbarrier_arrive_expect_tx_shared(ptr addrspace(3) %0, i32 %1) { + // CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %4 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %5 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %6 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %7 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %8 = call i64 @llvm.nvvm.mbarrier.arrive.expect.tx.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr<3>, i32 -> i64 + %1 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<3>, i32 -> i64 + %2 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 -> i64 + + %3 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {relaxed = true} : !llvm.ptr<3>, i32 -> i64 + %4 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr<3>, i32 -> i64 + %5 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr<3>, i32 -> i64 + llvm.return +} + +llvm.func @mbarrier_arrive_expect_tx_shared_cluster(%barrier: !llvm.ptr<7>, %txcount : i32) { + // CHECK-LABEL: define void @mbarrier_arrive_expect_tx_shared_cluster(ptr addrspace(7) %0, i32 %1) { + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.expect.tx.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.expect.tx.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.expect.tx.scope.cluster.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.expect.tx.relaxed.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.expect.tx.relaxed.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.expect.tx.relaxed.scope.cluster.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr<7>, i32 + nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<7>, i32 + nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<7>, i32 + + nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {relaxed = true} : !llvm.ptr<7>, i32 + nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr<7>, i32 + nvvm.mbarrier.arrive.expect_tx %barrier, %txcount {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr<7>, i32 + llvm.return +} + diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir new file mode 100644 index 0000000..6e7e163 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir @@ -0,0 +1,103 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +llvm.func @mbarrier_arrive_generic(%barrier: !llvm.ptr, %count : i32) { + // CHECK-LABEL: define void @mbarrier_arrive_generic(ptr %0, i32 %1) { + // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %4 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cta.space.cta(ptr addrspace(3) %3, i32 1) + // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %6 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cta.space.cta(ptr addrspace(3) %5, i32 %1) + // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %8 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cta.space.cta(ptr addrspace(3) %7, i32 %1) + // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %10 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cluster.space.cta(ptr addrspace(3) %9, i32 %1) + // CHECK-NEXT: %11 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %12 = call i64 @llvm.nvvm.mbarrier.arrive.relaxed.scope.cta.space.cta(ptr addrspace(3) %11, i32 1) + // CHECK-NEXT: %13 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %14 = call i64 @llvm.nvvm.mbarrier.arrive.relaxed.scope.cta.space.cta(ptr addrspace(3) %13, i32 %1) + // CHECK-NEXT: %15 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %16 = call i64 @llvm.nvvm.mbarrier.arrive.relaxed.scope.cta.space.cta(ptr addrspace(3) %15, i32 %1) + // CHECK-NEXT: %17 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %18 = call i64 @llvm.nvvm.mbarrier.arrive.relaxed.scope.cluster.space.cta(ptr addrspace(3) %17, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.arrive %barrier : !llvm.ptr -> i64 + %1 = nvvm.mbarrier.arrive %barrier, %count : !llvm.ptr -> i64 + %2 = nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr -> i64 + %3 = nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr -> i64 + + %4 = nvvm.mbarrier.arrive %barrier {relaxed = true} : !llvm.ptr -> i64 + %5 = nvvm.mbarrier.arrive %barrier, %count {relaxed = true} : !llvm.ptr -> i64 + %6 = nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr -> i64 + %7 = nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr -> i64 + llvm.return +} + +llvm.func @mbarrier_arrive_shared(%barrier: !llvm.ptr<3>, %count : i32) { + // CHECK-LABEL: define void @mbarrier_arrive_shared(ptr addrspace(3) %0, i32 %1) { + // CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cta.space.cta(ptr addrspace(3) %0, i32 1) + // CHECK-NEXT: %4 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %5 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %6 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %7 = call i64 @llvm.nvvm.mbarrier.arrive.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 1) + // CHECK-NEXT: %8 = call i64 @llvm.nvvm.mbarrier.arrive.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %9 = call i64 @llvm.nvvm.mbarrier.arrive.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %10 = call i64 @llvm.nvvm.mbarrier.arrive.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.arrive %barrier : !llvm.ptr<3> -> i64 + %1 = nvvm.mbarrier.arrive %barrier, %count : !llvm.ptr<3> -> i64 + %2 = nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<3> -> i64 + %3 = nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3> -> i64 + + %4 = nvvm.mbarrier.arrive %barrier {relaxed = true} : !llvm.ptr<3> -> i64 + %5 = nvvm.mbarrier.arrive %barrier, %count {relaxed = true} : !llvm.ptr<3> -> i64 + %6 = nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr<3> -> i64 + %7 = nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr<3> -> i64 + llvm.return +} + +llvm.func @mbarrier_arrive_shared_cluster(%barrier: !llvm.ptr<7>, %count : i32) { + // CHECK-LABEL: define void @mbarrier_arrive_shared_cluster(ptr addrspace(7) %0, i32 %1) { + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.scope.cta.space.cluster(ptr addrspace(7) %0, i32 1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.scope.cluster.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.relaxed.scope.cta.space.cluster(ptr addrspace(7) %0, i32 1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.relaxed.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.relaxed.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.relaxed.scope.cluster.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + nvvm.mbarrier.arrive %barrier : !llvm.ptr<7> + nvvm.mbarrier.arrive %barrier, %count : !llvm.ptr<7> + nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<7> + nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<7> + + nvvm.mbarrier.arrive %barrier {relaxed = true} : !llvm.ptr<7> + nvvm.mbarrier.arrive %barrier, %count {relaxed = true} : !llvm.ptr<7> + nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr<7> + nvvm.mbarrier.arrive %barrier, %count {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr<7> + llvm.return +} + +llvm.func @mbarrier_arrive_nocomplete(%barrier: !llvm.ptr) { + // CHECK-LABEL: define void @mbarrier_arrive_nocomplete(ptr %0) { + // CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x() + // CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.noComplete(ptr %0, i32 %2) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %count = nvvm.read.ptx.sreg.ntid.x : i32 + %0 = nvvm.mbarrier.arrive.nocomplete %barrier, %count : !llvm.ptr, i32 -> i64 + llvm.return +} + +llvm.func @mbarrier_arrive_nocomplete_shared(%barrier: !llvm.ptr<3>) { + // CHECK-LABEL: define void @mbarrier_arrive_nocomplete_shared(ptr addrspace(3) %0) { + // CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x() + // CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.noComplete.shared(ptr addrspace(3) %0, i32 %2) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %count = nvvm.read.ptx.sreg.ntid.x : i32 + %0 = nvvm.mbarrier.arrive.nocomplete %barrier, %count : !llvm.ptr<3>, i32 -> i64 + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_arrive_drop.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_arrive_drop.mlir new file mode 100644 index 0000000..c345c5d --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/mbar_arrive_drop.mlir @@ -0,0 +1,103 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +llvm.func @mbarrier_arrive_drop_generic(%barrier: !llvm.ptr, %count : i32) { + // CHECK-LABEL: define void @mbarrier_arrive_drop_generic(ptr %0, i32 %1) { + // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %4 = call i64 @llvm.nvvm.mbarrier.arrive.drop.scope.cta.space.cta(ptr addrspace(3) %3, i32 1) + // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %6 = call i64 @llvm.nvvm.mbarrier.arrive.drop.scope.cta.space.cta(ptr addrspace(3) %5, i32 %1) + // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %8 = call i64 @llvm.nvvm.mbarrier.arrive.drop.scope.cta.space.cta(ptr addrspace(3) %7, i32 %1) + // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %10 = call i64 @llvm.nvvm.mbarrier.arrive.drop.scope.cluster.space.cta(ptr addrspace(3) %9, i32 %1) + // CHECK-NEXT: %11 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %12 = call i64 @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cta.space.cta(ptr addrspace(3) %11, i32 1) + // CHECK-NEXT: %13 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %14 = call i64 @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cta.space.cta(ptr addrspace(3) %13, i32 %1) + // CHECK-NEXT: %15 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %16 = call i64 @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cta.space.cta(ptr addrspace(3) %15, i32 %1) + // CHECK-NEXT: %17 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %18 = call i64 @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cluster.space.cta(ptr addrspace(3) %17, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.arrive_drop %barrier : !llvm.ptr -> i64 + %1 = nvvm.mbarrier.arrive_drop %barrier, %count : !llvm.ptr -> i64 + %2 = nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr -> i64 + %3 = nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr -> i64 + + %4 = nvvm.mbarrier.arrive_drop %barrier {relaxed = true} : !llvm.ptr -> i64 + %5 = nvvm.mbarrier.arrive_drop %barrier, %count {relaxed = true} : !llvm.ptr -> i64 + %6 = nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr -> i64 + %7 = nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr -> i64 + llvm.return +} + +llvm.func @mbarrier_arrive_drop_shared(%barrier: !llvm.ptr<3>, %count : i32) { + // CHECK-LABEL: define void @mbarrier_arrive_drop_shared(ptr addrspace(3) %0, i32 %1) { + // CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.drop.scope.cta.space.cta(ptr addrspace(3) %0, i32 1) + // CHECK-NEXT: %4 = call i64 @llvm.nvvm.mbarrier.arrive.drop.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %5 = call i64 @llvm.nvvm.mbarrier.arrive.drop.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %6 = call i64 @llvm.nvvm.mbarrier.arrive.drop.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %7 = call i64 @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 1) + // CHECK-NEXT: %8 = call i64 @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %9 = call i64 @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %10 = call i64 @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.arrive_drop %barrier : !llvm.ptr<3> -> i64 + %1 = nvvm.mbarrier.arrive_drop %barrier, %count : !llvm.ptr<3> -> i64 + %2 = nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<3> -> i64 + %3 = nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3> -> i64 + + %4 = nvvm.mbarrier.arrive_drop %barrier {relaxed = true} : !llvm.ptr<3> -> i64 + %5 = nvvm.mbarrier.arrive_drop %barrier, %count {relaxed = true} : !llvm.ptr<3> -> i64 + %6 = nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr<3> -> i64 + %7 = nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr<3> -> i64 + llvm.return +} + +llvm.func @mbarrier_arrive_drop_shared_cluster(%barrier: !llvm.ptr<7>, %count : i32) { + // CHECK-LABEL: define void @mbarrier_arrive_drop_shared_cluster(ptr addrspace(7) %0, i32 %1) { + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.scope.cta.space.cluster(ptr addrspace(7) %0, i32 1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.scope.cluster.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cta.space.cluster(ptr addrspace(7) %0, i32 1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.arrive.drop.relaxed.scope.cluster.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + nvvm.mbarrier.arrive_drop %barrier : !llvm.ptr<7> + nvvm.mbarrier.arrive_drop %barrier, %count : !llvm.ptr<7> + nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<7> + nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<7> + + nvvm.mbarrier.arrive_drop %barrier {relaxed = true} : !llvm.ptr<7> + nvvm.mbarrier.arrive_drop %barrier, %count {relaxed = true} : !llvm.ptr<7> + nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cta>, relaxed = true} : !llvm.ptr<7> + nvvm.mbarrier.arrive_drop %barrier, %count {scope = #nvvm.mem_scope<cluster>, relaxed = true} : !llvm.ptr<7> + llvm.return +} + +llvm.func @mbarrier_arrive_drop_nocomplete(%barrier: !llvm.ptr) { + // CHECK-LABEL: define void @mbarrier_arrive_drop_nocomplete(ptr %0) { + // CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x() + // CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.drop.noComplete(ptr %0, i32 %2) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %count = nvvm.read.ptx.sreg.ntid.x : i32 + %0 = nvvm.mbarrier.arrive_drop.nocomplete %barrier, %count : !llvm.ptr, i32 -> i64 + llvm.return +} + +llvm.func @mbarrier_arrive_drop_nocomplete_shared(%barrier: !llvm.ptr<3>) { + // CHECK-LABEL: define void @mbarrier_arrive_drop_nocomplete_shared(ptr addrspace(3) %0) { + // CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x() + // CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.drop.noComplete.shared(ptr addrspace(3) %0, i32 %2) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %count = nvvm.read.ptx.sreg.ntid.x : i32 + %0 = nvvm.mbarrier.arrive_drop.nocomplete %barrier, %count : !llvm.ptr<3>, i32 -> i64 + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_complete_tx.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_complete_tx.mlir new file mode 100644 index 0000000..99289fa --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/mbar_complete_tx.mlir @@ -0,0 +1,29 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +llvm.func @mbarrier_complete_tx_shared(%barrier: !llvm.ptr<3>, %tx_count : i32) { + // CHECK-LABEL: define void @mbarrier_complete_tx_shared(ptr addrspace(3) %0, i32 %1) { + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.complete.tx.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.complete.tx.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.complete.tx.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + nvvm.mbarrier.complete_tx %barrier, %tx_count : !llvm.ptr<3>, i32 + nvvm.mbarrier.complete_tx %barrier, %tx_count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<3>, i32 + nvvm.mbarrier.complete_tx %barrier, %tx_count {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 + + llvm.return +} + +llvm.func @mbarrier_complete_tx_shared_cluster(%barrier: !llvm.ptr<7>, %tx_count : i32) { + // CHECK-LABEL: define void @mbarrier_complete_tx_shared_cluster(ptr addrspace(7) %0, i32 %1) { + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.complete.tx.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.complete.tx.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.complete.tx.scope.cluster.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + nvvm.mbarrier.complete_tx %barrier, %tx_count : !llvm.ptr<7>, i32 + nvvm.mbarrier.complete_tx %barrier, %tx_count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<7>, i32 + nvvm.mbarrier.complete_tx %barrier, %tx_count {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<7>, i32 + + llvm.return +}
\ No newline at end of file diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_expect_tx.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_expect_tx.mlir new file mode 100644 index 0000000..dad7237 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/mbar_expect_tx.mlir @@ -0,0 +1,29 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +llvm.func @mbarrier_expect_tx_shared(%barrier: !llvm.ptr<3>, %tx_count : i32) { + // CHECK-LABEL: define void @mbarrier_expect_tx_shared(ptr addrspace(3) %0, i32 %1) { + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.expect.tx.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.expect.tx.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.expect.tx.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + nvvm.mbarrier.expect_tx %barrier, %tx_count : !llvm.ptr<3>, i32 + nvvm.mbarrier.expect_tx %barrier, %tx_count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<3>, i32 + nvvm.mbarrier.expect_tx %barrier, %tx_count {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 + + llvm.return +} + +llvm.func @mbarrier_expect_tx_shared_cluster(%barrier: !llvm.ptr<7>, %tx_count : i32) { + // CHECK-LABEL: define void @mbarrier_expect_tx_shared_cluster(ptr addrspace(7) %0, i32 %1) { + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.expect.tx.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.expect.tx.scope.cta.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.expect.tx.scope.cluster.space.cluster(ptr addrspace(7) %0, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + nvvm.mbarrier.expect_tx %barrier, %tx_count : !llvm.ptr<7>, i32 + nvvm.mbarrier.expect_tx %barrier, %tx_count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<7>, i32 + nvvm.mbarrier.expect_tx %barrier, %tx_count {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<7>, i32 + + llvm.return +}
\ No newline at end of file diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_init.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_init.mlir new file mode 100644 index 0000000..9c1d1cc --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/mbar_init.mlir @@ -0,0 +1,56 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +llvm.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.ptr) { + // CHECK-LABEL: define void @cp_async_mbarrier_arrive(ptr addrspace(3) %0, ptr %1) { + // CHECK-NEXT: call void @llvm.nvvm.cp.async.mbarrier.arrive(ptr %1) + // CHECK-NEXT: call void @llvm.nvvm.cp.async.mbarrier.arrive.noinc(ptr %1) + // CHECK-NEXT: call void @llvm.nvvm.cp.async.mbarrier.arrive.shared(ptr addrspace(3) %0) + // CHECK-NEXT: call void @llvm.nvvm.cp.async.mbarrier.arrive.noinc.shared(ptr addrspace(3) %0) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + nvvm.cp.async.mbarrier.arrive %bar_gen : !llvm.ptr + nvvm.cp.async.mbarrier.arrive %bar_gen {noinc = true} : !llvm.ptr + nvvm.cp.async.mbarrier.arrive %bar_shared : !llvm.ptr<3> + nvvm.cp.async.mbarrier.arrive %bar_shared {noinc = true} : !llvm.ptr<3> + llvm.return +} + +llvm.func @mbarrier_init_generic(%barrier: !llvm.ptr) { + // CHECK-LABEL: define void @mbarrier_init_generic(ptr %0) { + // CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x() + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.init(ptr %0, i32 %2) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %count = nvvm.read.ptx.sreg.ntid.x : i32 + nvvm.mbarrier.init %barrier, %count : !llvm.ptr, i32 + llvm.return +} + +llvm.func @mbarrier_init_shared(%barrier: !llvm.ptr<3>) { + // CHECK-LABEL: define void @mbarrier_init_shared(ptr addrspace(3) %0) { + // CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x() + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.init.shared(ptr addrspace(3) %0, i32 %2) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %count = nvvm.read.ptx.sreg.ntid.x : i32 + nvvm.mbarrier.init %barrier, %count : !llvm.ptr<3>, i32 + llvm.return +} + +llvm.func @mbarrier_inval_generic(%barrier: !llvm.ptr) { + // CHECK-LABEL: define void @mbarrier_inval_generic(ptr %0) { + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.inval(ptr %0) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + nvvm.mbarrier.inval %barrier : !llvm.ptr + llvm.return +} + +llvm.func @mbarrier_inval_shared(%barrier: !llvm.ptr<3>) { + // CHECK-LABEL: define void @mbarrier_inval_shared(ptr addrspace(3) %0) { + // CHECK-NEXT: call void @llvm.nvvm.mbarrier.inval.shared(ptr addrspace(3) %0) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + nvvm.mbarrier.inval %barrier : !llvm.ptr<3> + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir new file mode 100644 index 0000000..4a7776d --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir @@ -0,0 +1,138 @@ +// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s + +// ----- + +llvm.func @mbarrier_arrive_ret_check(%barrier: !llvm.ptr<7>) { + // expected-error @below {{mbarrier in shared_cluster space cannot return any value}} + %0 = nvvm.mbarrier.arrive %barrier : !llvm.ptr<7> -> i64 + llvm.return +} + +// ----- + +llvm.func @mbarrier_arrive_invalid_scope(%barrier: !llvm.ptr<7>) { + // expected-error @below {{mbarrier scope must be either CTA or Cluster}} + %0 = nvvm.mbarrier.arrive %barrier {scope = #nvvm.mem_scope<gpu>} : !llvm.ptr<7> -> i64 + llvm.return +} + +// ----- + +llvm.func @mbarrier_arrive_drop_ret_check(%barrier: !llvm.ptr<7>) { + // expected-error @below {{mbarrier in shared_cluster space cannot return any value}} + %0 = nvvm.mbarrier.arrive_drop %barrier : !llvm.ptr<7> -> i64 + llvm.return +} + +// ----- + +llvm.func @mbarrier_arrive_drop_invalid_scope(%barrier: !llvm.ptr<7>) { + // expected-error @below {{mbarrier scope must be either CTA or Cluster}} + %0 = nvvm.mbarrier.arrive_drop %barrier {scope = #nvvm.mem_scope<gpu>} : !llvm.ptr<7> -> i64 + llvm.return +} + +// ----- + +llvm.func @mbarrier_expect_tx_scope(%barrier: !llvm.ptr<7>, %tx_count: i32) { + // expected-error @below {{mbarrier scope must be either CTA or Cluster}} + nvvm.mbarrier.expect_tx %barrier, %tx_count {scope = #nvvm.mem_scope<gpu>} : !llvm.ptr<7>, i32 + llvm.return +} + +// ----- + +llvm.func @mbarrier_complete_tx_scope(%barrier: !llvm.ptr<3>, %tx_count: i32) { + // expected-error @below {{mbarrier scope must be either CTA or Cluster}} + nvvm.mbarrier.complete_tx %barrier, %tx_count {scope = #nvvm.mem_scope<sys>} : !llvm.ptr<3>, i32 + llvm.return +} + +// ----- + +llvm.func @mbarrier_arr_expect_tx(%barrier: !llvm.ptr<3>, %tx_count: i32) { + // expected-error @below {{mbarrier scope must be either CTA or Cluster}} + %1 = nvvm.mbarrier.arrive.expect_tx %barrier, %tx_count {scope = #nvvm.mem_scope<gpu>} : !llvm.ptr<3>, i32 -> i64 + llvm.return +} + +// ----- + +llvm.func @mbarrier_arr_expect_tx_cluster(%barrier: !llvm.ptr<7>, %tx_count: i32) { + // expected-error @below {{mbarrier in shared_cluster space cannot return any value}} + %1 = nvvm.mbarrier.arrive.expect_tx %barrier, %tx_count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<7>, i32 -> i64 + llvm.return +} + +// ----- + +llvm.func @init_mbarrier_arrive_expect_tx_asm_ret(%barrier : !llvm.ptr<3>, %txcount : i32, %pred : i1) { + // expected-error @below {{return-value is not supported when using predicate}} + %1 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1 -> i64 + llvm.return +} + +// ----- + +llvm.func @init_mbarrier_arrive_expect_tx_asm_relaxed(%barrier : !llvm.ptr<3>, %txcount : i32, %pred : i1) { + // expected-error @below {{mbarrier with relaxed semantics is not supported when using predicate}} + nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred {relaxed = true} : !llvm.ptr<3>, i32, i1 + llvm.return +} + +// ----- + +llvm.func @init_mbarrier_arrive_expect_tx_asm_cta(%barrier : !llvm.ptr<3>, %txcount : i32, %pred : i1) { + // expected-error @below {{mbarrier scope must be CTA when using predicate}} + nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32, i1 + llvm.return +} + +// ----- + +llvm.func @init_mbarrier_arrive_expect_tx_asm_cluster(%barrier : !llvm.ptr<7>, %txcount : i32, %pred : i1) { + // expected-error @below {{mbarrier in shared_cluster space is not supported when using predicate}} + nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr<7>, i32, i1 + llvm.return +} + +// ----- + +llvm.func @mbarrier_arr_drop_expect_tx(%barrier: !llvm.ptr<3>, %tx_count: i32) { + // expected-error @below {{mbarrier scope must be either CTA or Cluster}} + %1 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %tx_count {scope = #nvvm.mem_scope<gpu>} : !llvm.ptr<3>, i32 -> i64 + llvm.return +} + +// ----- + +llvm.func @mbarrier_arr_drop_expect_tx_cluster(%barrier: !llvm.ptr<7>, %tx_count: i32) { + // expected-error @below {{mbarrier in shared_cluster space cannot return any value}} + %1 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %tx_count {scope = #nvvm.mem_scope<cta>} : !llvm.ptr<7>, i32 -> i64 + llvm.return +} + +// ----- + +llvm.func @mbarrier_test_wait(%barrier: !llvm.ptr<3>, %phase: i32) { + // expected-error @below {{mbarrier scope must be either CTA or Cluster}} + %1 = nvvm.mbarrier.test.wait %barrier, %phase {scope = #nvvm.mem_scope<gpu>} : !llvm.ptr<3>, i32 -> i1 + llvm.return +} + +// ----- + +llvm.func @mbarrier_try_wait(%barrier: !llvm.ptr<3>, %phase: i32) { + // expected-error @below {{mbarrier scope must be either CTA or Cluster}} + %1 = nvvm.mbarrier.try_wait %barrier, %phase {scope = #nvvm.mem_scope<sys>} : !llvm.ptr<3>, i32 -> i1 + llvm.return +} + +// ----- + +llvm.func @mbarrier_try_wait_with_timelimit(%barrier: !llvm.ptr<3>, %phase: i32, %ticks: i32) { + // expected-error @below {{mbarrier scope must be either CTA or Cluster}} + %1 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks {scope = #nvvm.mem_scope<gpu>} : !llvm.ptr<3>, i32, i32 -> i1 + llvm.return +} + diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_test_wait.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_test_wait.mlir new file mode 100644 index 0000000..21ab72e --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/mbar_test_wait.mlir @@ -0,0 +1,73 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +llvm.func @mbarrier_test_wait_state(%barrier: !llvm.ptr, %state : i64) { + // CHECK-LABEL: define void @mbarrier_test_wait_state(ptr %0, i64 %1) { + // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cta.space.cta(ptr addrspace(3) %3, i64 %1) + // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cluster.space.cta(ptr addrspace(3) %5, i64 %1) + // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %8 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cta.space.cta(ptr addrspace(3) %7, i64 %1) + // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %10 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cluster.space.cta(ptr addrspace(3) %9, i64 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.test.wait %barrier, %state : !llvm.ptr, i64 -> i1 + %1 = nvvm.mbarrier.test.wait %barrier, %state {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64 -> i1 + + %2 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true} : !llvm.ptr, i64 -> i1 + %3 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64 -> i1 + llvm.return +} + +llvm.func @mbarrier_test_wait_shared_state(%barrier: !llvm.ptr<3>, %state : i64) { + // CHECK-LABEL: define void @mbarrier_test_wait_shared_state(ptr addrspace(3) %0, i64 %1) { + // CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1) + // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1) + // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1) + // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.test.wait %barrier, %state : !llvm.ptr<3>, i64 -> i1 + %1 = nvvm.mbarrier.test.wait %barrier, %state {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64 -> i1 + + %2 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true} : !llvm.ptr<3>, i64 -> i1 + %3 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64 -> i1 + llvm.return +} + +llvm.func @mbarrier_test_wait_phase(%barrier: !llvm.ptr, %phase : i32) { + // CHECK-LABEL: define void @mbarrier_test_wait_phase(ptr %0, i32 %1) { + // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cta.space.cta(ptr addrspace(3) %3, i32 %1) + // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cluster.space.cta(ptr addrspace(3) %5, i32 %1) + // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %8 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cta.space.cta(ptr addrspace(3) %7, i32 %1) + // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %10 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cluster.space.cta(ptr addrspace(3) %9, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.test.wait %barrier, %phase : !llvm.ptr, i32 -> i1 + %1 = nvvm.mbarrier.test.wait %barrier, %phase {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32 -> i1 + + %2 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true} : !llvm.ptr, i32 -> i1 + %3 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32 -> i1 + llvm.return +} + +llvm.func @mbarrier_test_wait_shared_phase(%barrier: !llvm.ptr<3>, %phase : i32) { + // CHECK-LABEL: define void @mbarrier_test_wait_shared_phase(ptr addrspace(3) %0, i32 %1) { + // CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.test.wait %barrier, %phase : !llvm.ptr<3>, i32 -> i1 + %1 = nvvm.mbarrier.test.wait %barrier, %phase {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 -> i1 + + %2 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true} : !llvm.ptr<3>, i32 -> i1 + %3 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 -> i1 + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_try_wait.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_try_wait.mlir new file mode 100644 index 0000000..18aaf0e --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/mbar_try_wait.mlir @@ -0,0 +1,147 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +llvm.func @mbarrier_try_wait_state(%barrier: !llvm.ptr, %state : i64) { + // CHECK-LABEL: define void @mbarrier_try_wait_state(ptr %0, i64 %1) { + // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.try.wait.scope.cta.space.cta(ptr addrspace(3) %3, i64 %1) + // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.try.wait.scope.cluster.space.cta(ptr addrspace(3) %5, i64 %1) + // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %8 = call i1 @llvm.nvvm.mbarrier.try.wait.relaxed.scope.cta.space.cta(ptr addrspace(3) %7, i64 %1) + // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %10 = call i1 @llvm.nvvm.mbarrier.try.wait.relaxed.scope.cluster.space.cta(ptr addrspace(3) %9, i64 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.try_wait %barrier, %state : !llvm.ptr, i64 -> i1 + %1 = nvvm.mbarrier.try_wait %barrier, %state {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64 -> i1 + + %2 = nvvm.mbarrier.try_wait %barrier, %state {relaxed = true} : !llvm.ptr, i64 -> i1 + %3 = nvvm.mbarrier.try_wait %barrier, %state {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64 -> i1 + + llvm.return +} + +llvm.func @mbarrier_try_wait_state_with_timelimit(%barrier: !llvm.ptr, %state : i64, %ticks : i32) { + // CHECK-LABEL: define void @mbarrier_try_wait_state_with_timelimit(ptr %0, i64 %1, i32 %2) { + // CHECK-NEXT: %4 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.scope.cta.space.cta(ptr addrspace(3) %4, i64 %1, i32 %2) + // CHECK-NEXT: %6 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %7 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.scope.cluster.space.cta(ptr addrspace(3) %6, i64 %1, i32 %2) + // CHECK-NEXT: %8 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %9 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.relaxed.scope.cta.space.cta(ptr addrspace(3) %8, i64 %1, i32 %2) + // CHECK-NEXT: %10 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %11 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.relaxed.scope.cluster.space.cta(ptr addrspace(3) %10, i64 %1, i32 %2) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.try_wait %barrier, %state, %ticks : !llvm.ptr, i64, i32 -> i1 + %1 = nvvm.mbarrier.try_wait %barrier, %state, %ticks {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64, i32 -> i1 + + %2 = nvvm.mbarrier.try_wait %barrier, %state, %ticks {relaxed = true} : !llvm.ptr, i64, i32 -> i1 + %3 = nvvm.mbarrier.try_wait %barrier, %state, %ticks {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64, i32 -> i1 + + llvm.return +} + +llvm.func @mbarrier_try_wait_shared_state(%barrier: !llvm.ptr<3>, %state : i64) { + // CHECK-LABEL: define void @mbarrier_try_wait_shared_state(ptr addrspace(3) %0, i64 %1) { + // CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.try.wait.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1) + // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.try.wait.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1) + // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.try.wait.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1) + // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.try.wait.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.try_wait %barrier, %state : !llvm.ptr<3>, i64 -> i1 + %1 = nvvm.mbarrier.try_wait %barrier, %state {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64 -> i1 + + %2 = nvvm.mbarrier.try_wait %barrier, %state {relaxed = true} : !llvm.ptr<3>, i64 -> i1 + %3 = nvvm.mbarrier.try_wait %barrier, %state {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64 -> i1 + llvm.return +} + +llvm.func @mbarrier_try_wait_shared_state_with_timelimit(%barrier: !llvm.ptr<3>, %state : i64, %ticks : i32) { + // CHECK-LABEL: define void @mbarrier_try_wait_shared_state_with_timelimit(ptr addrspace(3) %0, i64 %1, i32 %2) { + // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1, i32 %2) + // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1, i32 %2) + // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1, i32 %2) + // CHECK-NEXT: %7 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1, i32 %2) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.try_wait %barrier, %state, %ticks : !llvm.ptr<3>, i64, i32 -> i1 + %1 = nvvm.mbarrier.try_wait %barrier, %state, %ticks {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64, i32 -> i1 + + %2 = nvvm.mbarrier.try_wait %barrier, %state, %ticks {relaxed = true} : !llvm.ptr<3>, i64, i32 -> i1 + %3 = nvvm.mbarrier.try_wait %barrier, %state, %ticks {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64, i32 -> i1 + llvm.return +} + +llvm.func @mbarrier_try_wait_phase(%barrier: !llvm.ptr, %phase : i32) { + // CHECK-LABEL: define void @mbarrier_try_wait_phase(ptr %0, i32 %1) { + // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.scope.cta.space.cta(ptr addrspace(3) %3, i32 %1) + // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.scope.cluster.space.cta(ptr addrspace(3) %5, i32 %1) + // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %8 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.relaxed.scope.cta.space.cta(ptr addrspace(3) %7, i32 %1) + // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %10 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.relaxed.scope.cluster.space.cta(ptr addrspace(3) %9, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.try_wait %barrier, %phase : !llvm.ptr, i32 -> i1 + %1 = nvvm.mbarrier.try_wait %barrier, %phase {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32 -> i1 + + %2 = nvvm.mbarrier.try_wait %barrier, %phase {relaxed = true} : !llvm.ptr, i32 -> i1 + %3 = nvvm.mbarrier.try_wait %barrier, %phase {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32 -> i1 + llvm.return +} + +llvm.func @mbarrier_try_wait_phase_with_timelimit(%barrier: !llvm.ptr, %phase : i32, %ticks : i32) { + // CHECK-LABEL: define void @mbarrier_try_wait_phase_with_timelimit(ptr %0, i32 %1, i32 %2) { + // CHECK-NEXT: %4 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.scope.cta.space.cta(ptr addrspace(3) %4, i32 %1, i32 %2) + // CHECK-NEXT: %6 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %7 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.scope.cluster.space.cta(ptr addrspace(3) %6, i32 %1, i32 %2) + // CHECK-NEXT: %8 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %9 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.relaxed.scope.cta.space.cta(ptr addrspace(3) %8, i32 %1, i32 %2) + // CHECK-NEXT: %10 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %11 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.relaxed.scope.cluster.space.cta(ptr addrspace(3) %10, i32 %1, i32 %2) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks : !llvm.ptr, i32, i32 -> i1 + %1 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32, i32 -> i1 + + %2 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks {relaxed = true} : !llvm.ptr, i32, i32 -> i1 + %3 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32, i32 -> i1 + llvm.return +} + +llvm.func @mbarrier_try_wait_shared_phase(%barrier: !llvm.ptr<3>, %phase : i32) { + // CHECK-LABEL: define void @mbarrier_try_wait_shared_phase(ptr addrspace(3) %0, i32 %1) { + // CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.try_wait %barrier, %phase : !llvm.ptr<3>, i32 -> i1 + %1 = nvvm.mbarrier.try_wait %barrier, %phase {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 -> i1 + + %2 = nvvm.mbarrier.try_wait %barrier, %phase {relaxed = true} : !llvm.ptr<3>, i32 -> i1 + %3 = nvvm.mbarrier.try_wait %barrier, %phase {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 -> i1 + llvm.return +} + +llvm.func @mbarrier_try_wait_shared_phase_with_timelimit(%barrier: !llvm.ptr<3>, %phase : i32, %ticks : i32) { + // CHECK-LABEL: define void @mbarrier_try_wait_shared_phase_with_timelimit(ptr addrspace(3) %0, i32 %1, i32 %2) { + // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1, i32 %2) + // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1, i32 %2) + // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1, i32 %2) + // CHECK-NEXT: %7 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1, i32 %2) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32 -> i1 + %1 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32, i32 -> i1 + + %2 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks {relaxed = true} : !llvm.ptr<3>, i32, i32 -> i1 + %3 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32, i32 -> i1 + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/mbarriers.mlir b/mlir/test/Target/LLVMIR/nvvm/mbarriers.mlir deleted file mode 100644 index 9bb3b08..0000000 --- a/mlir/test/Target/LLVMIR/nvvm/mbarriers.mlir +++ /dev/null @@ -1,116 +0,0 @@ -// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s - -llvm.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.ptr) { - // CHECK-LABEL: define void @cp_async_mbarrier_arrive(ptr addrspace(3) %0, ptr %1) { - // CHECK-NEXT: call void @llvm.nvvm.cp.async.mbarrier.arrive(ptr %1) - // CHECK-NEXT: call void @llvm.nvvm.cp.async.mbarrier.arrive.noinc(ptr %1) - // CHECK-NEXT: call void @llvm.nvvm.cp.async.mbarrier.arrive.shared(ptr addrspace(3) %0) - // CHECK-NEXT: call void @llvm.nvvm.cp.async.mbarrier.arrive.noinc.shared(ptr addrspace(3) %0) - // CHECK-NEXT: ret void - // CHECK-NEXT: } - nvvm.cp.async.mbarrier.arrive %bar_gen : !llvm.ptr - nvvm.cp.async.mbarrier.arrive %bar_gen {noinc = true} : !llvm.ptr - nvvm.cp.async.mbarrier.arrive %bar_shared : !llvm.ptr<3> - nvvm.cp.async.mbarrier.arrive %bar_shared {noinc = true} : !llvm.ptr<3> - llvm.return -} - -llvm.func @mbarrier_init_generic(%barrier: !llvm.ptr) { - // CHECK-LABEL: define void @mbarrier_init_generic(ptr %0) { - // CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x() - // CHECK-NEXT: call void @llvm.nvvm.mbarrier.init(ptr %0, i32 %2) - // CHECK-NEXT: ret void - // CHECK-NEXT: } - %count = nvvm.read.ptx.sreg.ntid.x : i32 - nvvm.mbarrier.init %barrier, %count : !llvm.ptr, i32 - llvm.return -} - -llvm.func @mbarrier_init_shared(%barrier: !llvm.ptr<3>) { - // CHECK-LABEL: define void @mbarrier_init_shared(ptr addrspace(3) %0) { - // CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x() - // CHECK-NEXT: call void @llvm.nvvm.mbarrier.init.shared(ptr addrspace(3) %0, i32 %2) - // CHECK-NEXT: ret void - // CHECK-NEXT: } - %count = nvvm.read.ptx.sreg.ntid.x : i32 - nvvm.mbarrier.init %barrier, %count : !llvm.ptr<3>, i32 - llvm.return -} - -llvm.func @mbarrier_inval_generic(%barrier: !llvm.ptr) { - // CHECK-LABEL: define void @mbarrier_inval_generic(ptr %0) { - // CHECK-NEXT: call void @llvm.nvvm.mbarrier.inval(ptr %0) - // CHECK-NEXT: ret void - // CHECK-NEXT: } - nvvm.mbarrier.inval %barrier : !llvm.ptr - llvm.return -} - -llvm.func @mbarrier_inval_shared(%barrier: !llvm.ptr<3>) { - // CHECK-LABEL: define void @mbarrier_inval_shared(ptr addrspace(3) %0) { - // CHECK-NEXT: call void @llvm.nvvm.mbarrier.inval.shared(ptr addrspace(3) %0) - // CHECK-NEXT: ret void - // CHECK-NEXT: } - nvvm.mbarrier.inval %barrier : !llvm.ptr<3> - llvm.return -} - -llvm.func @mbarrier_arrive(%barrier: !llvm.ptr) { - // CHECK-LABEL: define void @mbarrier_arrive(ptr %0) { - // CHECK-NEXT: %2 = call i64 @llvm.nvvm.mbarrier.arrive(ptr %0) - // CHECK-NEXT: ret void - // CHECK-NEXT: } - %0 = nvvm.mbarrier.arrive %barrier : !llvm.ptr -> i64 - llvm.return -} - -llvm.func @mbarrier_arrive_shared(%barrier: !llvm.ptr<3>) { - // CHECK-LABEL: define void @mbarrier_arrive_shared(ptr addrspace(3) %0) { - // CHECK-NEXT: %2 = call i64 @llvm.nvvm.mbarrier.arrive.shared(ptr addrspace(3) %0) - // CHECK-NEXT: ret void - // CHECK-NEXT: } - %0 = nvvm.mbarrier.arrive %barrier : !llvm.ptr<3> -> i64 - llvm.return -} - -llvm.func @mbarrier_arrive_nocomplete(%barrier: !llvm.ptr) { - // CHECK-LABEL: define void @mbarrier_arrive_nocomplete(ptr %0) { - // CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x() - // CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.noComplete(ptr %0, i32 %2) - // CHECK-NEXT: ret void - // CHECK-NEXT: } - %count = nvvm.read.ptx.sreg.ntid.x : i32 - %0 = nvvm.mbarrier.arrive.nocomplete %barrier, %count : !llvm.ptr, i32 -> i64 - llvm.return -} - -llvm.func @mbarrier_arrive_nocomplete_shared(%barrier: !llvm.ptr<3>) { - // CHECK-LABEL: define void @mbarrier_arrive_nocomplete_shared(ptr addrspace(3) %0) { - // CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x() - // CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.noComplete.shared(ptr addrspace(3) %0, i32 %2) - // CHECK-NEXT: ret void - // CHECK-NEXT: } - %count = nvvm.read.ptx.sreg.ntid.x : i32 - %0 = nvvm.mbarrier.arrive.nocomplete %barrier, %count : !llvm.ptr<3>, i32 -> i64 - llvm.return -} - -llvm.func @mbarrier_test_wait(%barrier: !llvm.ptr, %token : i64) -> i1 { - // CHECK-LABEL: define i1 @mbarrier_test_wait(ptr %0, i64 %1) { - // CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.test.wait(ptr %0, i64 %1) - // CHECK-NEXT: ret i1 %3 - // CHECK-NEXT: } - %isComplete = nvvm.mbarrier.test.wait %barrier, %token : !llvm.ptr, i64 -> i1 - llvm.return %isComplete : i1 -} - -llvm.func @mbarrier_test_wait_shared(%barrier: !llvm.ptr<3>, %token : i64) { - // CHECK-LABEL: define void @mbarrier_test_wait_shared(ptr addrspace(3) %0, i64 %1) { - // CHECK-NEXT: %3 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x() - // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.shared(ptr addrspace(3) %0, i64 %1) - // CHECK-NEXT: ret void - // CHECK-NEXT: } - %count = nvvm.read.ptx.sreg.ntid.x : i32 - %isComplete = nvvm.mbarrier.test.wait %barrier, %token : !llvm.ptr<3>, i64 -> i1 - llvm.return -} diff --git a/mlir/test/Target/LLVMIR/nvvm/permute_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/permute_invalid.mlir new file mode 100644 index 0000000..1d6c23c --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/permute_invalid.mlir @@ -0,0 +1,43 @@ +// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s + +llvm.func @invalid_default_missing_hi(%sel: i32, %lo: i32) -> i32 { + // expected-error @below {{mode 'default' requires 'hi' operand.}} + %r = nvvm.prmt #nvvm.permute_mode<default> %sel, %lo : i32 + llvm.return %r : i32 +} + +llvm.func @invalid_f4e_missing_hi(%sel: i32, %lo: i32) -> i32 { + // expected-error @below {{mode 'f4e' requires 'hi' operand.}} + %r = nvvm.prmt #nvvm.permute_mode<f4e> %sel, %lo : i32 + llvm.return %r : i32 +} + +llvm.func @invalid_b4e_missing_hi(%sel: i32, %lo: i32) -> i32 { + // expected-error @below {{mode 'b4e' requires 'hi' operand.}} + %r = nvvm.prmt #nvvm.permute_mode<b4e> %sel, %lo : i32 + llvm.return %r : i32 +} + +llvm.func @invalid_rc8_with_hi(%sel: i32, %lo: i32, %hi: i32) -> i32 { + // expected-error @below {{mode 'rc8' does not accept 'hi' operand.}} + %r = nvvm.prmt #nvvm.permute_mode<rc8> %sel, %lo, %hi : i32 + llvm.return %r : i32 +} + +llvm.func @invalid_ecl_with_hi(%sel: i32, %lo: i32, %hi: i32) -> i32 { + // expected-error @below {{mode 'ecl' does not accept 'hi' operand.}} + %r = nvvm.prmt #nvvm.permute_mode<ecl> %sel, %lo, %hi : i32 + llvm.return %r : i32 +} + +llvm.func @invalid_ecr_with_hi(%sel: i32, %lo: i32, %hi: i32) -> i32 { + // expected-error @below {{mode 'ecr' does not accept 'hi' operand.}} + %r = nvvm.prmt #nvvm.permute_mode<ecr> %sel, %lo, %hi : i32 + llvm.return %r : i32 +} + +llvm.func @invalid_rc16_with_hi(%sel: i32, %lo: i32, %hi: i32) -> i32 { + // expected-error @below {{mode 'rc16' does not accept 'hi' operand.}} + %r = nvvm.prmt #nvvm.permute_mode<rc16> %sel, %lo, %hi : i32 + llvm.return %r : i32 +} diff --git a/mlir/test/Target/LLVMIR/nvvm/permute_valid.mlir b/mlir/test/Target/LLVMIR/nvvm/permute_valid.mlir new file mode 100644 index 0000000..d2baae7 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/permute_valid.mlir @@ -0,0 +1,64 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @test_prmt_default +llvm.func @test_prmt_default(%sel: i32, %lo: i32, %hi: i32) -> i32 { + // CHECK: call i32 @llvm.nvvm.prmt(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + %result = nvvm.prmt #nvvm.permute_mode<default> %sel, %lo, %hi : i32 + llvm.return %result : i32 +} + +// CHECK-LABEL: @test_prmt_f4e +llvm.func @test_prmt_f4e(%pos: i32, %lo: i32, %hi: i32) -> i32 { + // CHECK: call i32 @llvm.nvvm.prmt.f4e(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + %result = nvvm.prmt #nvvm.permute_mode<f4e> %pos, %lo, %hi : i32 + llvm.return %result : i32 +} + +// CHECK-LABEL: @test_prmt_b4e +llvm.func @test_prmt_b4e(%pos: i32, %lo: i32, %hi: i32) -> i32 { + // CHECK: call i32 @llvm.nvvm.prmt.b4e(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + %result = nvvm.prmt #nvvm.permute_mode<b4e> %pos, %lo, %hi : i32 + llvm.return %result : i32 +} + +// CHECK-LABEL: @test_prmt_rc8 +llvm.func @test_prmt_rc8(%sel: i32, %val: i32) -> i32 { + // CHECK: call i32 @llvm.nvvm.prmt.rc8(i32 %{{.*}}, i32 %{{.*}}) + %result = nvvm.prmt #nvvm.permute_mode<rc8> %sel, %val : i32 + llvm.return %result : i32 +} + +// CHECK-LABEL: @test_prmt_ecl +llvm.func @test_prmt_ecl(%sel: i32, %val: i32) -> i32 { + // CHECK: call i32 @llvm.nvvm.prmt.ecl(i32 %{{.*}}, i32 %{{.*}}) + %result = nvvm.prmt #nvvm.permute_mode<ecl> %sel, %val : i32 + llvm.return %result : i32 +} + +// CHECK-LABEL: @test_prmt_ecr +llvm.func @test_prmt_ecr(%sel: i32, %val: i32) -> i32 { + // CHECK: call i32 @llvm.nvvm.prmt.ecr(i32 %{{.*}}, i32 %{{.*}}) + %result = nvvm.prmt #nvvm.permute_mode<ecr> %sel, %val : i32 + llvm.return %result : i32 +} + +// CHECK-LABEL: @test_prmt_rc16 +llvm.func @test_prmt_rc16(%sel: i32, %val: i32) -> i32 { + // CHECK: call i32 @llvm.nvvm.prmt.rc16(i32 %{{.*}}, i32 %{{.*}}) + %result = nvvm.prmt #nvvm.permute_mode<rc16> %sel, %val : i32 + llvm.return %result : i32 +} + +// CHECK-LABEL: @test_prmt_mixed +llvm.func @test_prmt_mixed(%sel: i32, %lo: i32, %hi: i32) -> i32 { + // CHECK: call i32 @llvm.nvvm.prmt(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + %r1 = nvvm.prmt #nvvm.permute_mode<default> %sel, %lo, %hi : i32 + + // CHECK: call i32 @llvm.nvvm.prmt.rc8(i32 %{{.*}}, i32 %{{.*}}) + %r2 = nvvm.prmt #nvvm.permute_mode<rc8> %sel, %r1 : i32 + + // CHECK: call i32 @llvm.nvvm.prmt.f4e(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + %r3 = nvvm.prmt #nvvm.permute_mode<f4e> %lo, %r2, %sel : i32 + + llvm.return %r3 : i32 +} diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-shared.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-shared.mlir new file mode 100644 index 0000000..db4574b --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-shared.mlir @@ -0,0 +1,229 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_1 +llvm.func @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_2 +llvm.func @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_mxf4_block_scale_cta_1 +llvm.func @nvvm_tcgen05_mma_mxf4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_mxf4_block_scale_cta_2 +llvm.func @nvvm_tcgen05_mma_mxf4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_mxf4nvf4_block_scale_cta_1 +llvm.func @nvvm_tcgen05_mma_mxf4nvf4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_mxf4nvf4_block_scale_cta_2 +llvm.func @nvvm_tcgen05_mma_mxf4nvf4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-tensor.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-tensor.mlir new file mode 100644 index 0000000..a15c3fb --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-tensor.mlir @@ -0,0 +1,229 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_1 +llvm.func @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_2 +llvm.func @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_mxf4_block_scale_cta_1 +llvm.func @nvvm_tcgen05_mma_mxf4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_mxf4_block_scale_cta_2 +llvm.func @nvvm_tcgen05_mma_mxf4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_mxf4nvf4_block_scale_cta_1 +llvm.func @nvvm_tcgen05_mma_mxf4nvf4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_mxf4nvf4_block_scale_cta_2 +llvm.func @nvvm_tcgen05_mma_mxf4nvf4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-invalid.mlir new file mode 100644 index 0000000..f46b35a --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-invalid.mlir @@ -0,0 +1,119 @@ +// RUN: mlir-translate --mlir-to-llvmir -verify-diagnostics -split-input-file %s + +// CHECK-LABEL: @nvvm_tcgen05_mma_disable_output_lane_cta_1 +llvm.func @nvvm_tcgen05_mma_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLanev4: vector<4 x i32>, %disableOutputLanev8: vector<8 x i32>) { + // expected-error @below {{Disable Output Lane of length 8 is incompatible with CtaGroupAttr}} + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLanev8 + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + llvm.return +} + +// ----- + +// CHECK-LABEL: @nvvm_tcgen05_mma_disable_output_lane_cta_2 +llvm.func @nvvm_tcgen05_mma_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLanev4: vector<4 x i32>, %disableOutputLanev8: vector<8 x i32>) { + // expected-error @below {{Disable Output Lane of length 8 is incompatible with CtaGroupAttr}} + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLanev8 + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + llvm.return +} + +// ----- + +// CHECK-LABEL: @nvvm_tcgen05_mma_shared_ashift +llvm.func @nvvm_tcgen05_mma_shared_ashift(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) { + // expected-error @below {{A-shift can be applied only when matrix A is in tensor memory}} + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, i64, i64, i32, i1) + llvm.return +} + +// ----- + +// CHECK-LABEL: @nvvm_tcgen05_mma_ashift +llvm.func @nvvm_tcgen05_mma_ashift(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) { + // expected-error @below {{Cannot use collector buffer operation fill or use with ashift}} + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + llvm.return +} + +// ----- + +// CHECK-LABEL: @nvvm_tcgen05_mma_mxf4nvf4_block_scale_default +llvm.func @nvvm_tcgen05_mma_mxf4nvf4_block_scale_default(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scalea: !llvm.ptr<6>, %scaleb: !llvm.ptr<6>) { + // expected-error @below {{mxf4nvf4 requires block scale attribute}} + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scalea, %scaleb + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + llvm.return +} + +// ----- + +// CHECK-LABEL: @nvvm_tcgen05_mma_mxf4_block_scale_default +llvm.func @nvvm_tcgen05_mma_mxf4_block_scale_default(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scalea: !llvm.ptr<6>, %scaleb: !llvm.ptr<6>) { + // expected-error @below {{mxf4 kind does not support block16 attribute}} + nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scalea, %scaleb + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, ashift, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>) + llvm.return +} + +// ----- + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_disable_output_lane_cta_1 +llvm.func @nvvm_tcgen05_mma_sp_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLanev4: vector<4 x i32>, %disableOutputLanev8: vector<8 x i32>, %spmetadata: !llvm.ptr<6>) { + // expected-error @below {{Disable Output Lane of length 8 is incompatible with CtaGroupAttr}} + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLanev8 + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + llvm.return +} + +// ----- + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_disable_output_lane_cta_2 +llvm.func @nvvm_tcgen05_mma_sp_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLanev4: vector<4 x i32>, %disableOutputLanev8: vector<8 x i32>, %spmetadata: !llvm.ptr<6>) { + // expected-error @below {{Disable Output Lane of length 8 is incompatible with CtaGroupAttr}} + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLanev8 + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + llvm.return +} + +// ----- + +// CHECK-LABEL: @nvvm_tcgen05_sp_mma_shared_ashift +llvm.func @nvvm_tcgen05_sp_mma_shared_ashift(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) { + // expected-error @below {{A-shift can be applied only when matrix A is in tensor memory}} + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + llvm.return +} + +// ----- + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_ashift +llvm.func @nvvm_tcgen05_mma_sp_ashift(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) { + // expected-error @below {{Cannot use collector buffer operation fill or use with ashift}} + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + llvm.return +} + +// ----- + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_default +llvm.func @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_default(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scalea: !llvm.ptr<6>, %scaleb: !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) { + // expected-error @below {{mxf4nvf4 requires block scale attribute}} + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scalea, %scaleb + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + llvm.return +} + +// ----- + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf4_block_scale_default +llvm.func @nvvm_tcgen05_mma_sp_mxf4_block_scale_default(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scalea: !llvm.ptr<6>, %scaleb: !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) { + // expected-error @below {{mxf4 kind does not support block16 attribute}} + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scalea, %scaleb + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, ashift, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-shared.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-shared.mlir new file mode 100644 index 0000000..286df36 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-shared.mlir @@ -0,0 +1,442 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @nvvm_tcgen05_mma_cta_1 +llvm.func @nvvm_tcgen05_mma_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_cta_2 +llvm.func @nvvm_tcgen05_mma_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 2, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 2, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 2, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 2, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 2, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 2, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 2, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 2, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + llvm.return +} + + +// CHECK-LABEL: @nvvm_tcgen05_mma_scale_d_imm_cta_1 +llvm.func @nvvm_tcgen05_mma_scale_d_imm_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_scale_d_imm_cta_2 +llvm.func @nvvm_tcgen05_mma_scale_d_imm_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_disable_output_lane_cta_1 +llvm.func @nvvm_tcgen05_mma_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane : vector<4 x i32>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<4 x i32>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_disable_output_lane_cta_2 +llvm.func @nvvm_tcgen05_mma_disable_output_lane_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<8 x i32>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, vector<8 x i32>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_scale_d_imm_disable_output_lane_cta_1 +llvm.func @nvvm_tcgen05_mma_scale_d_imm_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<4 x i32>) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<4 x i32>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_scale_d_imm_disable_output_lane_cta_2 +llvm.func @nvvm_tcgen05_mma_scale_d_imm_disable_output_lane_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<8 x i32>) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64, vector<8 x i32>) + + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-shared.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-shared.mlir new file mode 100644 index 0000000..5c7eabe --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-shared.mlir @@ -0,0 +1,229 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf8f6f4_block_scale_cta_1 +llvm.func @nvvm_tcgen05_mma_sp_mxf8f6f4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf8f6f4_block_scale_cta_2 +llvm.func @nvvm_tcgen05_mma_sp_mxf8f6f4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf4_block_scale_cta_1 +llvm.func @nvvm_tcgen05_mma_sp_mxf4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf4_block_scale_cta_2 +llvm.func @nvvm_tcgen05_mma_sp_mxf4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_cta_1 +llvm.func @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_cta_2 +llvm.func @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-tensor.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-tensor.mlir new file mode 100644 index 0000000..3200411 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-tensor.mlir @@ -0,0 +1,229 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf8f6f4_block_scale_cta_1 +llvm.func @nvvm_tcgen05_mma_sp_mxf8f6f4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf8f6f4_block_scale_cta_2 +llvm.func @nvvm_tcgen05_mma_sp_mxf8f6f4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf4_block_scale_cta_1 +llvm.func @nvvm_tcgen05_mma_sp_mxf4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf4_block_scale_cta_2 +llvm.func @nvvm_tcgen05_mma_sp_mxf4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_cta_1 +llvm.func @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_cta_2 +llvm.func @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block16(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scale_a, %scale_b + {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>) + + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-shared.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-shared.mlir new file mode 100644 index 0000000..96044cf --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-shared.mlir @@ -0,0 +1,442 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_cta_1 +llvm.func @nvvm_tcgen05_mma_sp_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_cta_2 +llvm.func @nvvm_tcgen05_mma_sp_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + llvm.return +} + + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_scale_d_imm_cta_1 +llvm.func @nvvm_tcgen05_mma_sp_scale_d_imm_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_scale_d_imm_cta_2 +llvm.func @nvvm_tcgen05_mma_sp_scale_d_imm_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_disable_output_lane_cta_1 +llvm.func @nvvm_tcgen05_mma_sp_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane : vector<4 x i32>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_disable_output_lane_cta_2 +llvm.func @nvvm_tcgen05_mma_sp_disable_output_lane_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<8 x i32>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_scale_d_imm_disable_output_lane_cta_1 +llvm.func @nvvm_tcgen05_mma_sp_scale_d_imm_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<4 x i32>, %spmetadata: !llvm.ptr<6>) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_scale_d_imm_disable_output_lane_cta_2 +llvm.func @nvvm_tcgen05_mma_sp_scale_d_imm_disable_output_lane_cta_2(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<8 x i32>, %spmetadata: !llvm.ptr<6>) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.shared.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-tensor.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-tensor.mlir new file mode 100644 index 0000000..709beb0 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-tensor.mlir @@ -0,0 +1,634 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_cta_1 +llvm.func @nvvm_tcgen05_mma_sp_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_cta_2 +llvm.func @nvvm_tcgen05_mma_sp_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + llvm.return +} + + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_scale_d_imm_cta_1 +llvm.func @nvvm_tcgen05_mma_sp_scale_d_imm_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_scale_d_imm_cta_2 +llvm.func @nvvm_tcgen05_mma_sp_scale_d_imm_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_disable_output_lane_cta_1 +llvm.func @nvvm_tcgen05_mma_sp_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane : vector<4 x i32>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<4 x i32>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_disable_output_lane_cta_2 +llvm.func @nvvm_tcgen05_mma_sp_disable_output_lane_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<8 x i32>, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, vector<8 x i32>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_scale_d_imm_disable_output_lane_cta_1 +llvm.func @nvvm_tcgen05_mma_sp_scale_d_imm_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<4 x i32>, %spmetadata: !llvm.ptr<6>) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<4 x i32>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_sp_scale_d_imm_disable_output_lane_cta_2 +llvm.func @nvvm_tcgen05_mma_sp_scale_d_imm_disable_output_lane_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<8 x i32>, %spmetadata: !llvm.ptr<6>) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64, vector<8 x i32>) + + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-tensor.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-tensor.mlir new file mode 100644 index 0000000..798e311 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-tensor.mlir @@ -0,0 +1,633 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @nvvm_tcgen05_mma_cta_1 +llvm.func @nvvm_tcgen05_mma_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f16 */ i32 0, /* cta_group= */ i32 1, /* collector=discard */ i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=tf32 */ i32 1, /* cta_group= */ i32 1, /* collector=discard */ i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f8f6f4 */ i32 2, /* cta_group= */ i32 1, /* collector=discard */ i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=i8 */ i32 3, /* cta_group= */ i32 1, /* collector=discard */ i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f16 */ i32 0, /* cta_group= */ i32 1, /* collector=lastuse */ i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=tf32 */ i32 1, /* cta_group= */ i32 1, /* collector=lastuse */ i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f8f6f4 */ i32 2, /* cta_group= */ i32 1, /* collector=lastuse */ i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=i8 */ i32 3, /* cta_group= */ i32 1, /* collector=lastuse */ i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f16 */ i32 0, /* cta_group= */ i32 1, /* collector=fill */ i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=tf32 */ i32 1, /* cta_group= */ i32 1, /* collector=fill */ i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f8f6f4 */ i32 2, /* cta_group= */ i32 1, /* collector=fill */ i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=i8 */ i32 3, /* cta_group= */ i32 1, /* collector=fill */ i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f16 */ i32 0, /* cta_group= */ i32 1, /* collector=use */ i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=tf32 */ i32 1, /* cta_group= */ i32 1, /* collector=use */ i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f8f6f4 */ i32 2, /* cta_group= */ i32 1, /* collector=use */ i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=i8 */ i32 3, /* cta_group= */ i32 1, /* collector=use */ i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_cta_2 +llvm.func @nvvm_tcgen05_mma_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f16 */ i32 0, /* cta_group= */ i32 2, /* collector=discard */ i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=tf32 */ i32 1, /* cta_group= */ i32 2, /* collector=discard */ i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f8f6f4 */ i32 2, /* cta_group= */ i32 2, /* collector=discard */ i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=i8 */ i32 3, /* cta_group= */ i32 2, /* collector=discard */ i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f16 */ i32 0, /* cta_group= */ i32 2, /* collector=lastuse */ i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=tf32 */ i32 1, /* cta_group= */ i32 2, /* collector=lastuse */ i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f8f6f4 */ i32 2, /* cta_group= */ i32 2, /* collector=lastuse */ i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=i8 */ i32 3, /* cta_group= */ i32 2, /* collector=lastuse */ i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f16 */ i32 0, /* cta_group= */ i32 2, /* collector=fill */ i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=tf32 */ i32 1, /* cta_group= */ i32 2, /* collector=fill */ i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f8f6f4 */ i32 2, /* cta_group= */ i32 2, /* collector=fill */ i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=i8 */ i32 3, /* cta_group= */ i32 2, /* collector=fill */ i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f16 */ i32 0, /* cta_group= */ i32 2, /* collector=use */ i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=tf32 */ i32 1, /* cta_group= */ i32 2, /* collector=use */ i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=f8f6f4 */ i32 2, /* cta_group= */ i32 2, /* collector=use */ i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, /* kind=i8 */ i32 3, /* cta_group= */ i32 2, /* collector=use */ i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_scale_d_imm_cta_1 +llvm.func @nvvm_tcgen05_mma_scale_d_imm_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_scale_d_imm_cta_2 +llvm.func @nvvm_tcgen05_mma_scale_d_imm_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 0, i32 2, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, i32 1, i32 2, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_disable_output_lane_cta_1 +llvm.func @nvvm_tcgen05_mma_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane : vector<4 x i32>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 0, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <4 x i32> {{%[0-9]+}}, i32 3, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<4 x i32>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_disable_output_lane_cta_2 +llvm.func @nvvm_tcgen05_mma_disable_output_lane_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<8 x i32>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 0, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 2, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, <8 x i32> {{%[0-9]+}}, i32 3, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<i8>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, vector<8 x i32>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_scale_d_imm_disable_output_lane_cta_1 +llvm.func @nvvm_tcgen05_mma_scale_d_imm_disable_output_lane_cta_1(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<4 x i32>) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 0, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg1(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <4 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<4 x i32>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_scale_d_imm_disable_output_lane_cta_2 +llvm.func @nvvm_tcgen05_mma_scale_d_imm_disable_output_lane_cta_2(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %disableOutputLane: vector<8 x i32>) { + + %scale_d_imm = llvm.mlir.constant(0:i64) : i64 + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 0) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2.ashift(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 1) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 2) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 0, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 0, <8 x i32> {{%[0-9]+}}, i32 1, i32 3) + nvvm.tcgen05.mma %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d scale = %scale_d_imm mask = %disableOutputLane + {kind = #nvvm.tcgen05_mma_kind<tf32>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64, vector<8 x i32>) + + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-shared.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-shared.mlir new file mode 100644 index 0000000..5f1aeb0 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-shared.mlir @@ -0,0 +1,133 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @nvvm_tcgen05_mma_ws +llvm.func @nvvm_tcgen05_mma_ws(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_ws_zero_col_mask +llvm.func @nvvm_tcgen05_mma_ws_zero_col_mask(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %zero_col_mask: i64) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f16>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<tf32>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<i8>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, i64) + + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-sp-shared.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-sp-shared.mlir new file mode 100644 index 0000000..e390e35 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-sp-shared.mlir @@ -0,0 +1,133 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @nvvm_tcgen05_mma_ws_sp +llvm.func @nvvm_tcgen05_mma_ws_sp(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_ws_sp_zero_col_mask +llvm.func @nvvm_tcgen05_mma_ws_sp_zero_col_mask(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>, %zero_col_mask: i64) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f16>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<tf32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<i8>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, i64) + + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-sp-tensor.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-sp-tensor.mlir new file mode 100644 index 0000000..f7ce548 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-sp-tensor.mlir @@ -0,0 +1,133 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @nvvm_tcgen05_mma_ws_sp +llvm.func @nvvm_tcgen05_mma_ws_sp(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 3, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_ws_sp_zero_col_mask +llvm.func @nvvm_tcgen05_mma_ws_sp_zero_col_mask(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %spmetadata: !llvm.ptr<6>, %zero_col_mask: i64) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<tf32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 0, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<i8>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 1, i32 0) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 1, i32 1) + nvvm.tcgen05.mma.ws.sp %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, i64) + + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-tensor.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-tensor.mlir new file mode 100644 index 0000000..cecbb3f --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-tensor.mlir @@ -0,0 +1,133 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @nvvm_tcgen05_mma_ws +llvm.func @nvvm_tcgen05_mma_ws(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %b_desc: i64, %idesc: i32, %enable_input_d: i1) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 2, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i32 3, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1) + + llvm.return +} + +// CHECK-LABEL: @nvvm_tcgen05_mma_ws_zero_col_mask +llvm.func @nvvm_tcgen05_mma_ws_zero_col_mask(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %zero_col_mask: i64) { + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<tf32>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 0, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<i8>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 1, i32 0) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 0, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f16>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 1, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<tf32>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 2, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<f8f6f4>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + // CHECK: call void @llvm.nvvm.tcgen05.mma.ws.tensor.zero_col_mask(ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 3, i32 1, i32 1) + nvvm.tcgen05.mma.ws %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %zero_col_mask + {kind = #nvvm.tcgen05_mma_kind<i8>, + collectorBBuffer = #nvvm.tcgen05_mma_collectorb<b1>, + collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, i64) + + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir index 0daf245..240fab5 100644 --- a/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir +++ b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir @@ -16,6 +16,17 @@ llvm.func @llvm_nvvm_cp_async_bulk_global_to_shared_cluster(%dst : !llvm.ptr<7>, llvm.return } +// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_global_to_shared_cta +llvm.func @llvm_nvvm_cp_async_bulk_global_to_shared_cta(%dst : !llvm.ptr<3>, %src : !llvm.ptr<1>, %mbar : !llvm.ptr<3>, %size : i32, %ch : i64) { + // CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cta(ptr addrspace(3) %[[DST:.*]], ptr addrspace(3) %[[MBAR:.*]], ptr addrspace(1) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 0, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cta(ptr addrspace(3) %[[DST]], ptr addrspace(3) %[[MBAR]], ptr addrspace(1) %[[SRC]], i32 %[[SIZE]], i64 %[[CH:.*]], i1 true) + nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size : !llvm.ptr<3>, !llvm.ptr<1> + + nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size l2_cache_hint = %ch : !llvm.ptr<3>, !llvm.ptr<1> + + llvm.return +} + // CHECK-LABEL: @llvm_nvvm_cp_async_bulk_shared_cta_to_shared_cluster llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_shared_cluster(%dst : !llvm.ptr<7>, %src : !llvm.ptr<3>, %mbar : !llvm.ptr<3>, %size : i32) { // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.cluster(ptr addrspace(7) %0, ptr addrspace(3) %2, ptr addrspace(3) %1, i32 %3) diff --git a/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy_invalid.mlir new file mode 100644 index 0000000..d762ff3 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy_invalid.mlir @@ -0,0 +1,8 @@ +// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s + +llvm.func @tma_bulk_copy_g2s_mc(%src : !llvm.ptr<1>, %dest : !llvm.ptr<3>, %bar : !llvm.ptr<3>, %size : i32, %ctamask : i16) { + // expected-error @below {{Multicast is not supported with shared::cta mode.}} + nvvm.cp.async.bulk.shared.cluster.global %dest, %src, %bar, %size multicast_mask = %ctamask : !llvm.ptr<3>, !llvm.ptr<1> + + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir index 42aa221..d5868ee 100644 --- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir @@ -578,14 +578,6 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { // ----- -llvm.func @nanosleep() { - // expected-error@+1 {{integer constant out of range for attribute}} - nvvm.nanosleep 100000000000000 - llvm.return -} - -// ----- - llvm.func @clusterlaunchcontrol_query_cancel_is_canceled_invalid_return_type(%try_cancel_response: i128) { // expected-error@+1 {{'nvvm.clusterlaunchcontrol.query.cancel' op is_canceled query type returns an i1}} %res = nvvm.clusterlaunchcontrol.query.cancel query = is_canceled, %try_cancel_response : i32 diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index 1ec5540..c4a6909 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -166,25 +166,6 @@ llvm.func @nvvm_rcp(%0: f32) -> f32 { llvm.return %1 : f32 } -// CHECK-LABEL: @llvm_nvvm_barrier0 -llvm.func @llvm_nvvm_barrier0() { - // CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0) - nvvm.barrier0 - llvm.return -} - -// CHECK-LABEL: @llvm_nvvm_barrier( -// CHECK-SAME: i32 %[[barId:.*]], i32 %[[numThreads:.*]]) -llvm.func @llvm_nvvm_barrier(%barID : i32, %numberOfThreads : i32) { - // CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0) - nvvm.barrier - // CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 %[[barId]]) - nvvm.barrier id = %barID - // CHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.count(i32 %[[barId]], i32 %[[numThreads]]) - nvvm.barrier id = %barID number_of_threads = %numberOfThreads - llvm.return -} - // CHECK-LABEL: @llvm_nvvm_cluster_arrive llvm.func @llvm_nvvm_cluster_arrive() { // CHECK: call void @llvm.nvvm.barrier.cluster.arrive() @@ -718,42 +699,6 @@ llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}, llvm.return } - -// ----- -// CHECK-LABEL: @nvvm_fence_proxy_tensormap_generic_release -llvm.func @nvvm_fence_proxy_tensormap_generic_release() { - %c128 = llvm.mlir.constant(128) : i32 - // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.release.cta() - nvvm.fence.proxy.release #nvvm.mem_scope<cta> - - // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.release.cluster() - nvvm.fence.proxy.release #nvvm.mem_scope<cluster> - - // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.release.gpu() - nvvm.fence.proxy.release #nvvm.mem_scope<gpu> - - // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.release.sys() - nvvm.fence.proxy.release #nvvm.mem_scope<sys> - llvm.return -} - -// ----- -// CHECK-LABEL: @nvvm_fence_proxy_tensormap_generic_acquire -llvm.func @nvvm_fence_proxy_tensormap_generic_acquire(%addr : !llvm.ptr) { - %c128 = llvm.mlir.constant(128) : i32 - // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.acquire.cta(ptr {{%[0-9]+}}, i32 128) - nvvm.fence.proxy.acquire #nvvm.mem_scope<cta> %addr, %c128 - - // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.acquire.cluster(ptr {{%[0-9]+}}, i32 128) - nvvm.fence.proxy.acquire #nvvm.mem_scope<cluster> %addr, %c128 - - // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.acquire.gpu(ptr {{%[0-9]+}}, i32 128) - nvvm.fence.proxy.acquire #nvvm.mem_scope<gpu> %addr, %c128 - - // CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.acquire.sys(ptr {{%[0-9]+}}, i32 128) - nvvm.fence.proxy.acquire #nvvm.mem_scope<sys> %addr, %c128 - llvm.return -} // ----- // CHECK-LABEL: @nvvm_exit @@ -970,8 +915,8 @@ llvm.func @nvvm_pmevent() { // ----- // CHECK-LABEL: @nanosleep -llvm.func @nanosleep() { - // CHECK: call void @llvm.nvvm.nanosleep(i32 4000) - nvvm.nanosleep 4000 +llvm.func @nanosleep(%duration: i32) { + // CHECK: call void @llvm.nvvm.nanosleep(i32 %{{.*}}) + nvvm.nanosleep %duration llvm.return } diff --git a/mlir/test/Target/LLVMIR/omptarget-data-use-dev-ordering.mlir b/mlir/test/Target/LLVMIR/omptarget-data-use-dev-ordering.mlir index f6860e5..d9be6d1 100644 --- a/mlir/test/Target/LLVMIR/omptarget-data-use-dev-ordering.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-data-use-dev-ordering.mlir @@ -67,18 +67,18 @@ module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-a // CHECK: define void @mix_use_device_ptr_and_addr_and_map_(ptr %[[ARG_0:.*]], ptr %[[ARG_1:.*]], ptr %[[ARG_2:.*]], ptr %[[ARG_3:.*]], ptr %[[ARG_4:.*]], ptr %[[ARG_5:.*]], ptr %[[ARG_6:.*]], ptr %[[ARG_7:.*]]) { // CHECK: %[[ALLOCA:.*]] = alloca ptr, align 8 -// CHECK: %[[BASEPTR_0_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 +// CHECK: %[[BASEPTR_0_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 // CHECK: store ptr %[[ARG_0]], ptr %[[BASEPTR_0_GEP]], align 8 -// CHECK: %[[BASEPTR_2_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 2 +// CHECK: %[[BASEPTR_2_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 4 // CHECK: store ptr %[[ARG_2]], ptr %[[BASEPTR_2_GEP]], align 8 -// CHECK: %[[BASEPTR_6_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 6 -// CHECK: store ptr %[[ARG_4]], ptr %[[BASEPTR_6_GEP]], align 8 +// CHECK: %[[BASEPTR_3_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 9 +// CHECK: store ptr %[[ARG_4]], ptr %[[BASEPTR_3_GEP]], align 8 // CHECK: call void @__tgt_target_data_begin_mapper({{.*}}) // CHECK: %[[LOAD_BASEPTR_0:.*]] = load ptr, ptr %[[BASEPTR_0_GEP]], align 8 // store ptr %[[LOAD_BASEPTR_0]], ptr %[[ALLOCA]], align 8 // CHECK: %[[LOAD_BASEPTR_2:.*]] = load ptr, ptr %[[BASEPTR_2_GEP]], align 8 -// CHECK: %[[LOAD_BASEPTR_6:.*]] = load ptr, ptr %[[BASEPTR_6_GEP]], align 8 +// CHECK: %[[LOAD_BASEPTR_3:.*]] = load ptr, ptr %[[BASEPTR_3_GEP]], align 8 // CHECK: %[[GEP_A4:.*]] = getelementptr { i64 }, ptr %[[ARG_4]], i32 0, i32 0 // CHECK: %[[GEP_A7:.*]] = getelementptr { i64 }, ptr %[[ARG_7]], i32 0, i32 0 // CHECK: %[[LOAD_A4:.*]] = load i64, ptr %[[GEP_A4]], align 4 @@ -93,17 +93,17 @@ module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-a // CHECK: define void @mix_use_device_ptr_and_addr_and_map_2(ptr %[[ARG_0:.*]], ptr %[[ARG_1:.*]], ptr %[[ARG_2:.*]], ptr %[[ARG_3:.*]], ptr %[[ARG_4:.*]], ptr %[[ARG_5:.*]], ptr %[[ARG_6:.*]], ptr %[[ARG_7:.*]]) { // CHECK: %[[ALLOCA:.*]] = alloca ptr, align 8 -// CHECK: %[[BASEPTR_1_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 1 +// CHECK: %[[BASEPTR_1_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 1 // CHECK: store ptr %[[ARG_0]], ptr %[[BASEPTR_1_GEP]], align 8 -// CHECK: %[[BASEPTR_2_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 2 +// CHECK: %[[BASEPTR_2_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 4 // CHECK: store ptr %[[ARG_2]], ptr %[[BASEPTR_2_GEP]], align 8 -// CHECK: %[[BASEPTR_6_GEP:.*]] = getelementptr inbounds [10 x ptr], ptr %.offload_baseptrs, i32 0, i32 6 -// CHECK: store ptr %[[ARG_4]], ptr %[[BASEPTR_6_GEP]], align 8 +// CHECK: %[[BASEPTR_3_GEP:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 9 +// CHECK: store ptr %[[ARG_4]], ptr %[[BASEPTR_3_GEP]], align 8 // CHECK: call void @__tgt_target_data_begin_mapper({{.*}}) // CHECK: %[[LOAD_BASEPTR_1:.*]] = load ptr, ptr %[[BASEPTR_1_GEP]], align 8 // store ptr %[[LOAD_BASEPTR_1]], ptr %[[ALLOCA]], align 8 // CHECK: %[[LOAD_BASEPTR_2:.*]] = load ptr, ptr %[[BASEPTR_2_GEP]], align 8 -// CHECK: %[[LOAD_BASEPTR_6:.*]] = load ptr, ptr %[[BASEPTR_6_GEP]], align 8 +// CHECK: %[[LOAD_BASEPTR_3:.*]] = load ptr, ptr %[[BASEPTR_3_GEP]], align 8 // CHECK: %[[GEP_A4:.*]] = getelementptr { i64 }, ptr %[[ARG_4]], i32 0, i32 0 // CHECK: %[[GEP_A7:.*]] = getelementptr { i64 }, ptr %[[ARG_7]], i32 0, i32 0 // CHECK: %[[LOAD_A4:.*]] = load i64, ptr %[[GEP_A4]], align 4 diff --git a/mlir/test/Target/LLVMIR/omptarget-declare-target-to-device.mlir b/mlir/test/Target/LLVMIR/omptarget-declare-target-to-device.mlir new file mode 100644 index 0000000..fa330b6 --- /dev/null +++ b/mlir/test/Target/LLVMIR/omptarget-declare-target-to-device.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +// This tests the replacement of operations for `declare target to` with the +// generated `declare target to` global variable inside of target op regions when +// lowering to IR for device. Unfortunately, as the host file is not passed as a +// module attribute, we miss out on the metadata and entry info. + +module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} { + // CHECK-DAG: @_QMtest_0Ezii = global [11 x float] zeroinitializer + llvm.mlir.global external @_QMtest_0Ezii() {addr_space = 0 : i32, omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>} : !llvm.array<11 x f32> { + %0 = llvm.mlir.zero : !llvm.array<11 x f32> + llvm.return %0 : !llvm.array<11 x f32> + } + + // CHECK-LABEL: define weak_odr protected amdgpu_kernel void @{{.*}}(ptr %{{.*}}) {{.*}} { + // CHECK-DAG: omp.target: + // CHECK-DAG: store float 1.000000e+00, ptr @_QMtest_0Ezii, align 4 + // CHECK-DAG: br label %omp.region.cont + llvm.func @_QQmain() { + %0 = llvm.mlir.constant(1 : index) : i64 + %1 = llvm.mlir.constant(0 : index) : i64 + %2 = llvm.mlir.constant(11 : index) : i64 + %3 = llvm.mlir.addressof @_QMtest_0Ezii : !llvm.ptr + %4 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%2 : i64) extent(%2 : i64) stride(%0 : i64) start_idx(%1 : i64) {stride_in_bytes = true} + %5 = omp.map.info var_ptr(%3 : !llvm.ptr, !llvm.array<11 x f32>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr + omp.target map_entries(%5 -> %arg0 : !llvm.ptr) { + %6 = llvm.mlir.constant(1.0 : f32) : f32 + %7 = llvm.mlir.constant(0 : i64) : i64 + %8 = llvm.getelementptr %arg0[%7] : (!llvm.ptr, i64) -> !llvm.ptr, f32 + llvm.store %6, %8 : f32, !llvm.ptr + omp.terminator + } + llvm.return + } +} diff --git a/mlir/test/Target/LLVMIR/omptarget-declare-target-to-host.mlir b/mlir/test/Target/LLVMIR/omptarget-declare-target-to-host.mlir new file mode 100644 index 0000000..4202421 --- /dev/null +++ b/mlir/test/Target/LLVMIR/omptarget-declare-target-to-host.mlir @@ -0,0 +1,34 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu", omp.is_gpu = false, omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} { + // CHECK-DAG: @_QMtest_0Ezii = global [11 x float] zeroinitializer + // CHECK-DAG: @.offload_sizes = private unnamed_addr constant [1 x i64] [i64 48] + // CHECK-DAG: @.offload_maptypes = private unnamed_addr constant [1 x i64] [i64 3] + // CHECK-DAG: @.offloading.entry._QMtest_0Ezii = weak constant %struct.__tgt_offload_entry {{.*}} ptr @_QMtest_0Ezii, {{.*}}, i64 44,{{.*}} + llvm.mlir.global external @_QMtest_0Ezii() {addr_space = 0 : i32, omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>} : !llvm.array<11 x f32> { + %0 = llvm.mlir.zero : !llvm.array<11 x f32> + llvm.return %0 : !llvm.array<11 x f32> + } + + // CHECK-DAG: %[[BASEPTR:.*]] = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 + // CHECK-DAG: store ptr @_QMtest_0Ezii, ptr %[[BASEPTR]], align 8 + // CHECK-DAG: %[[OFFLOADPTR:.*]] = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0 + // CHECK-DAG: store ptr @_QMtest_0Ezii, ptr %[[OFFLOADPTR]], align 8 + llvm.func @_QQmain() { + %0 = llvm.mlir.constant(1 : index) : i64 + %1 = llvm.mlir.constant(0 : index) : i64 + %2 = llvm.mlir.constant(11 : index) : i64 + %3 = llvm.mlir.addressof @_QMtest_0Ezii : !llvm.ptr + %4 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%2 : i64) extent(%2 : i64) stride(%0 : i64) start_idx(%1 : i64) {stride_in_bytes = true} + %5 = omp.map.info var_ptr(%3 : !llvm.ptr, !llvm.array<11 x f32>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr + omp.target map_entries(%5 -> %arg0 : !llvm.ptr) { + %6 = llvm.mlir.constant(1.0 : f32) : f32 + %7 = llvm.mlir.constant(0 : i64) : i64 + %8 = llvm.getelementptr %arg0[%7] : (!llvm.ptr, i64) -> !llvm.ptr, f32 + llvm.store %6, %8 : f32, !llvm.ptr + omp.terminator + } + llvm.return + } + // CHEKC-DAG: !{{.*}} = !{i32 {{.*}}, !"_QMtest_0Ezii", i32 {{.*}}, i32 {{.*}}} +} diff --git a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir index e6ea3aa..e289d5d 100644 --- a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir @@ -622,3 +622,20 @@ module attributes {omp.target_triples = ["amdgcn-amd-amdhsa"]} { // CHECK: br label %[[VAL_40]] // CHECK: omp.done: ; preds = %[[VAL_68]], %[[VAL_63]], %[[VAL_32]] // CHECK: ret void + +// ----- + +module attributes {omp.target_triples = ["amdgcn-amd-amdhsa"]} { + llvm.func @_QPomp_target_is_device_ptr(%arg0 : !llvm.ptr) { + %map = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.ptr) + map_clauses(is_device_ptr) capture(ByRef) -> !llvm.ptr {name = ""} + omp.target map_entries(%map -> %ptr_arg : !llvm.ptr) { + omp.terminator + } + llvm.return + } +} + +// CHECK: @.offload_sizes = private unnamed_addr constant [1 x i64] [i64 8] +// CHECK: @.offload_maptypes = private unnamed_addr constant [1 x i64] [i64 288] +// CHECK-LABEL: define void @_QPomp_target_is_device_ptr diff --git a/mlir/test/Target/LLVMIR/omptarget-multi-block-reduction.mlir b/mlir/test/Target/LLVMIR/omptarget-multi-block-reduction.mlir index 87ff0ba..fac61e05 100644 --- a/mlir/test/Target/LLVMIR/omptarget-multi-block-reduction.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-multi-block-reduction.mlir @@ -7,7 +7,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<"dlti.alloca_memory_space" = 5 : llvm.func @bar() {} llvm.func @baz() {} - omp.declare_reduction @add_reduction_byref_box_5xf32 : !llvm.ptr alloc { + omp.declare_reduction @add_reduction_byref_box_5xf32 : !llvm.ptr attributes {byref_element_type = !llvm.array<5 x f32>} alloc { %0 = llvm.mlir.constant(1 : i64) : i64 %1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> : (i64) -> !llvm.ptr<5> %2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr @@ -23,7 +23,12 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<"dlti.alloca_memory_space" = 5 : ^bb3: // pred: ^bb1 llvm.call @baz() : () -> () omp.yield(%arg0 : !llvm.ptr) + } data_ptr_ptr { + ^bb0(%arg0: !llvm.ptr): + %0 = llvm.getelementptr %arg0[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> + omp.yield(%0 : !llvm.ptr) } + llvm.func @foo_() { %c1 = llvm.mlir.constant(1 : i64) : i64 %10 = llvm.alloca %c1 x !llvm.array<5 x f32> {bindc_name = "x"} : (i64) -> !llvm.ptr<5> @@ -51,8 +56,8 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<"dlti.alloca_memory_space" = 5 : } } -// CHECK: call void @__kmpc_parallel_51({{.*}}, i32 1, i32 -1, i32 -1, -// CHECK-SAME: ptr @[[PAR_OUTLINED:.*]], ptr null, ptr %2, i64 1) +// CHECK: call void @__kmpc_parallel_60({{.*}}, i32 1, i32 -1, i32 -1, +// CHECK-SAME: ptr @[[PAR_OUTLINED:.*]], ptr null, ptr %2, i64 1, i32 0) // CHECK: define internal void @[[PAR_OUTLINED]]{{.*}} { // CHECK: .omp.reduction.then: @@ -67,9 +72,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<"dlti.alloca_memory_space" = 5 : // CHECK: br label %[[CONT_BB:.*]] // CHECK: [[CONT_BB]]: -// CHECK-NEXT: %[[RED_RHS:.*]] = phi ptr [ %final.rhs, %{{.*}} ] -// CHECK-NEXT: store ptr %[[RED_RHS]], ptr %{{.*}}, align 8 -// CHECK-NEXT: br label %.omp.reduction.done +// CHECK-NEXT: %[[RED_RHS:.*]] = phi ptr [ %{{.*}}, %{{.*}} ] // CHECK: } // CHECK: define internal void @"{{.*}}$reduction$reduction_func"(ptr noundef %0, ptr noundef %1) #0 { diff --git a/mlir/test/Target/LLVMIR/omptarget-multi-reduction.mlir b/mlir/test/Target/LLVMIR/omptarget-multi-reduction.mlir index b8b7c78..8950db3 100644 --- a/mlir/test/Target/LLVMIR/omptarget-multi-reduction.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-multi-reduction.mlir @@ -109,19 +109,19 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo // CHECK: icmp eq i32 %[[MASTER]], 1 // CHECK: i1 %{{.+}}, label %[[THEN:[A-Za-z0-9_.]*]], label %[[DONE:[A-Za-z0-9_.]*]] // CHECK: [[THEN]]: -// CHECK-NEXT: %[[FINAL_RHS0:[A-Za-z0-9_.]*]] = load double // CHECK-NEXT: %[[FINAL_LHS0:[A-Za-z0-9_.]*]] = load double +// CHECK-NEXT: %[[FINAL_RHS0:[A-Za-z0-9_.]*]] = load double // CHECK-NEXT: %[[FINAL_RESULT0:[A-Za-z0-9_.]*]] = fadd contract double %[[FINAL_LHS0]], %[[FINAL_RHS0]] // CHECK-NEXT: store double %[[FINAL_RESULT0]] -// CHECK-NEXT: %[[FINAL_RHS1:[A-Za-z0-9_.]*]] = load double // CHECK-NEXT: %[[FINAL_LHS1:[A-Za-z0-9_.]*]] = load double +// CHECK-NEXT: %[[FINAL_RHS1:[A-Za-z0-9_.]*]] = load double // CHECK-NEXT: %[[FINAL_RESULT1:[A-Za-z0-9_.]*]] = fadd contract double %[[FINAL_LHS1]], %[[FINAL_RHS1]] // CHECK-NEXT: store double %[[FINAL_RESULT1]] -// CHECK-NEXT: %[[FINAL_RHS2:[A-Za-z0-9_.]*]] = load float // CHECK-NEXT: %[[FINAL_LHS2:[A-Za-z0-9_.]*]] = load float +// CHECK-NEXT: %[[FINAL_RHS2:[A-Za-z0-9_.]*]] = load float // CHECK-NEXT: %[[FINAL_RESULT2:[A-Za-z0-9_.]*]] = fadd contract float %[[FINAL_LHS2]], %[[FINAL_RHS2]] // CHECK-NEXT: store float %[[FINAL_RESULT2]] -// CHECK-NEXT: %[[FINAL_RHS3:[A-Za-z0-9_.]*]] = load float // CHECK-NEXT: %[[FINAL_LHS3:[A-Za-z0-9_.]*]] = load float +// CHECK-NEXT: %[[FINAL_RHS3:[A-Za-z0-9_.]*]] = load float // CHECK-NEXT: %[[FINAL_RESULT3:[A-Za-z0-9_.]*]] = fadd contract float %[[FINAL_LHS3]], %[[FINAL_RHS3]] // CHECK-NEXT: store float %[[FINAL_RESULT3]] diff --git a/mlir/test/Target/LLVMIR/omptarget-nowait.mlir b/mlir/test/Target/LLVMIR/omptarget-nowait.mlir index 19333c4..a96756f46 100644 --- a/mlir/test/Target/LLVMIR/omptarget-nowait.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-nowait.mlir @@ -25,34 +25,33 @@ module attributes {omp.target_triples = ["amdgcn-amd-amdhsa"]} { // CHECK: %struct.[[TSK_WTH_PRVTS:.*]] = type { %struct.kmp_task_ompbuilder_t, %struct.[[PRVTS:.*]] } // CHECK: %struct.kmp_task_ompbuilder_t = type { ptr, ptr, i32, ptr, ptr } -// CHECK: %struct.[[PRVTS]] = type { [5 x ptr], [5 x ptr], [5 x i64] } +// CHECK: %struct.[[PRVTS]] = type { [6 x ptr], [6 x ptr], [6 x i64] } // CHECK: define void @launch_(ptr captures(none) %0) // CHECK: %[[STRUCTARG:.*]] = alloca { ptr, ptr }, align 8 -// CHECK: %[[BASEPTRS:.*]] = alloca [5 x ptr], align 8 -// CHECK: %[[PTRS:.*]] = alloca [5 x ptr], align 8 -// CHECK: %[[MAPPERS:.*]] = alloca [5 x ptr], align 8 -// CHECK: %[[SIZES:.*]] = alloca [5 x i64], align 4 +// CHECK: %[[BASEPTRS:.*]] = alloca [6 x ptr], align 8 +// CHECK: %[[PTRS:.*]] = alloca [6 x ptr], align 8 +// CHECK: %[[MAPPERS:.*]] = alloca [6 x ptr], align 8 +// CHECK: %[[SIZES:.*]] = alloca [6 x i64], align 4 - -// CHECK: %[[VAL_20:.*]] = getelementptr inbounds [5 x ptr], ptr %[[BASEPTRS]], i32 0, i32 0 -// CHECK: %[[BASEPTRS_GEP:.*]] = getelementptr inbounds [5 x ptr], ptr %[[BASEPTRS]], i32 0, i32 0 -// CHECK: %[[PTRS_GEP:.*]] = getelementptr inbounds [5 x ptr], ptr %[[PTRS]], i32 0, i32 0 -// CHECK: %[[SIZES_GEP:.*]] = getelementptr inbounds [5 x i64], ptr %[[SIZES]], i32 0, i32 0 +// CHECK: %[[VAL_20:.*]] = getelementptr inbounds [6 x ptr], ptr %[[BASEPTRS]], i32 0, i32 0 +// CHECK: %[[BASEPTRS_GEP:.*]] = getelementptr inbounds [6 x ptr], ptr %[[BASEPTRS]], i32 0, i32 0 +// CHECK: %[[PTRS_GEP:.*]] = getelementptr inbounds [6 x ptr], ptr %[[PTRS]], i32 0, i32 0 +// CHECK: %[[SIZES_GEP:.*]] = getelementptr inbounds [6 x i64], ptr %[[SIZES]], i32 0, i32 0 // CHECK: %[[GL_THRD_NUM:.*]] = call i32 @__kmpc_global_thread_num -// CHECK: %[[TASK_DESC:.*]] = call ptr @__kmpc_omp_target_task_alloc(ptr @4, i32 {{.*}}, i32 0, i64 160, i64 16, ptr [[TGT_TSK_PRXY_FNC:.*]], i64 -1) +// CHECK: %[[TASK_DESC:.*]] = call ptr @__kmpc_omp_target_task_alloc(ptr @4, i32 {{.*}}, i32 0, i64 184, i64 16, ptr [[TGT_TSK_PRXY_FNC:.*]], i64 -1) // CHECK: %[[TSK_PTR:.*]] = getelementptr inbounds nuw %struct.[[TSK_WTH_PRVTS]], ptr %[[TASK_DESC]], i32 0, i32 0 // CHECK: %[[SHAREDS:.*]] = getelementptr inbounds nuw %struct.kmp_task_ompbuilder_t, ptr %[[TSK_PTR]], i32 0, i32 0 // CHECK: %[[SHAREDS_PTR:.*]] = load ptr, ptr %[[SHAREDS]], align 8 // CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[SHAREDS_PTR]], ptr align 1 %[[STRUCTARG]], i64 16, i1 false) // CHECK: %[[VAL_50:.*]] = getelementptr inbounds nuw %struct.[[TSK_WTH_PRVTS]], ptr %[[TASK_DESC]], i32 0, i32 1 // CHECK: %[[VAL_51:.*]] = getelementptr inbounds nuw %struct.[[PRVTS]], ptr %[[VAL_50]], i32 0, i32 0 -// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[VAL_51]], ptr align 1 %[[BASEPTRS_GEP]], i64 40, i1 false) +// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[VAL_51]], ptr align 1 %[[BASEPTRS_GEP]], i64 48, i1 false) // CHECK: %[[VAL_53:.*]] = getelementptr inbounds nuw %struct.[[PRVTS]], ptr %[[VAL_50]], i32 0, i32 1 -// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[VAL_53]], ptr align 1 %[[PTRS_GEP]], i64 40, i1 false) +// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[VAL_53]], ptr align 1 %[[PTRS_GEP]], i64 48, i1 false) // CHECK: %[[VAL_54:.*]] = getelementptr inbounds nuw %struct.[[PRVTS]], ptr %[[VAL_50]], i32 0, i32 2 -// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[VAL_54]], ptr align 1 %[[SIZES_GEP]], i64 40, i1 false) +// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[VAL_54]], ptr align 1 %[[SIZES_GEP]], i64 48, i1 false) // CHECK: %[[VAL_55:.*]] = call i32 @__kmpc_omp_task(ptr @4, i32 %[[GL_THRD_NUM]], ptr %[[TASK_DESC]]) // CHECK: define internal void @[[WORKER:.*]](i32 {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}) { diff --git a/mlir/test/Target/LLVMIR/omptarget-overlapping-record-member-map.mlir b/mlir/test/Target/LLVMIR/omptarget-overlapping-record-member-map.mlir new file mode 100644 index 0000000..1e9369f --- /dev/null +++ b/mlir/test/Target/LLVMIR/omptarget-overlapping-record-member-map.mlir @@ -0,0 +1,65 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu", omp.is_gpu = false, omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} { + llvm.func @_QQmain() attributes {fir.bindc_name = "main"} { + %0 = llvm.mlir.constant(1 : i64) : i64 + %1 = llvm.alloca %0 x !llvm.struct<"_QFTdtype", (f32, i32)> {bindc_name = "dtypev"} : (i64) -> !llvm.ptr + %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"_QFTdtype", (f32, i32)> + %3 = omp.map.info var_ptr(%2 : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "dtypev%value2"} + %4 = omp.map.info var_ptr(%1 : !llvm.ptr, !llvm.struct<"_QFTdtype", (f32, i32)>) map_clauses(to) capture(ByRef) members(%3 : [1] : !llvm.ptr) -> !llvm.ptr {name = "dtypev"} + omp.target map_entries(%4 -> %arg0, %3 -> %arg1 : !llvm.ptr, !llvm.ptr) { + omp.terminator + } + llvm.return + } +} + +// CHECK: @.offload_sizes = private unnamed_addr constant [4 x i64] [i64 0, i64 0, i64 0, i64 4] +// CHECK: @.offload_maptypes = private unnamed_addr constant [4 x i64] [i64 32, i64 281474976710657, i64 281474976710657, i64 281474976710659] + +// CHECK: %[[ALLOCA:.*]] = alloca %_QFTdtype, i64 1, align 8 +// CHECK: %[[ELEMENT_ACC:.*]] = getelementptr %_QFTdtype, ptr %[[ALLOCA]], i32 0, i32 1 + +// CHECK: %[[SIZE1_CALC_1:.*]] = getelementptr %_QFTdtype, ptr %[[ALLOCA]], i32 1 +// CHECK: %[[SIZE1_CALC_2:.*]] = ptrtoint ptr %[[SIZE1_CALC_1]] to i64 +// CHECK: %[[SIZE1_CALC_3:.*]] = ptrtoint ptr %[[ALLOCA]] to i64 +// CHECK: %[[SIZE1_CALC_4:.*]] = sub i64 %[[SIZE1_CALC_2]], %[[SIZE1_CALC_3]] +// CHECK: %[[SIZE1_CALC_5:.*]] = sdiv exact i64 %[[SIZE1_CALC_4]], ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64) + +// CHECK: %[[SIZE2_CALC_1:.*]] = getelementptr %_QFTdtype, ptr %[[ALLOCA]], i32 1 +// CHECK: %[[SIZE2_CALC_2:.*]] = ptrtoint ptr %[[ELEMENT_ACC]] to i64 +// CHECK: %[[SIZE2_CALC_3:.*]] = ptrtoint ptr %[[ALLOCA]] to i64 +// CHECK: %[[SIZE2_CALC_4:.*]] = sub i64 %[[SIZE2_CALC_2]], %[[SIZE2_CALC_3]] +// CHECK: %[[SIZE2_CALC_5:.*]] = sdiv exact i64 %[[SIZE2_CALC_4]], ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64) + +// CHECK: %[[SIZE3_CALC_1:.*]] = getelementptr i32, ptr %[[ELEMENT_ACC]], i32 1 +// CHECK: %[[SIZE3_CALC_2:.*]] = ptrtoint ptr %[[SIZE2_CALC_1]] to i64 +// CHECK: %[[SIZE3_CALC_3:.*]] = ptrtoint ptr %[[SIZE3_CALC_1]] to i64 +// CHECK: %[[SIZE3_CALC_4:.*]] = sub i64 %[[SIZE3_CALC_2]], %[[SIZE3_CALC_3]] +// CHECK: %[[SIZE3_CALC_5:.*]] = sdiv exact i64 %[[SIZE3_CALC_4]], ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64) + +// CHECK: %[[BASEPTR:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 +// CHECK: store ptr %[[ALLOCA]], ptr %[[BASEPTR]], align 8 +// CHECK: %[[PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_ptrs, i32 0, i32 0 +// CHECK: store ptr %[[ALLOCA]], ptr %[[PTRS]], align 8 +// CHECK: %[[SIZES:.*]] = getelementptr inbounds [4 x i64], ptr %.offload_sizes, i32 0, i32 0 +// CHECK: store i64 %[[SIZE1_CALC_5]], ptr %[[SIZES]], align 8 + +// CHECK: %[[BASEPTR:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_baseptrs, i32 0, i32 1 +// CHECK: store ptr %[[ALLOCA]], ptr %[[BASEPTR]], align 8 +// CHECK: %[[PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_ptrs, i32 0, i32 1 +// CHECK: store ptr %[[ALLOCA]], ptr %[[PTRS]], align 8 +// CHECK: %[[SIZES:.*]] = getelementptr inbounds [4 x i64], ptr %.offload_sizes, i32 0, i32 1 +// CHECK: store i64 %[[SIZE2_CALC_5]], ptr %[[SIZES]], align 8 + +// CHECK: %[[BASEPTR:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_baseptrs, i32 0, i32 2 +// CHECK: store ptr %[[ALLOCA]], ptr %[[BASEPTR]], align 8 +// CHECK: %[[PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_ptrs, i32 0, i32 2 +// CHECK: store ptr %13, ptr %[[PTRS]], align 8 +// CHECK: %[[SIZES:.*]] = getelementptr inbounds [4 x i64], ptr %.offload_sizes, i32 0, i32 2 +// CHECK: store i64 %[[SIZE3_CALC_5]], ptr %[[SIZES]], align 8 + +// CHECK: %[[BASEPTR:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_baseptrs, i32 0, i32 3 +// CHECK: store ptr %[[ALLOCA]], ptr %[[BASEPTR]], align 8 +// CHECK: %[[PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_ptrs, i32 0, i32 3 +// CHECK: store ptr %[[ELEMENT_ACC]], ptr %[[PTRS]], align 8 diff --git a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir index 60c6fa4..cdb8dbb 100644 --- a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir @@ -70,31 +70,31 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo // CHECK: store ptr %[[TMP6]], ptr addrspace(5) %[[GEP_]], align 8 // CHECK: %[[TMP7:.*]] = getelementptr inbounds [1 x ptr], ptr %[[TMP2]], i64 0, i64 0 // CHECK: store ptr %[[STRUCTARG_ASCAST]], ptr %[[TMP7]], align 8 -// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr null, ptr %[[TMP2]], i64 1) +// CHECK: call void @__kmpc_parallel_60(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr null, ptr %[[TMP2]], i64 1, i32 0) // CHECK: call void @__kmpc_target_deinit() // CHECK: define internal void @[[FUNC1]]( // CHECK-SAME: ptr noalias noundef {{.*}}, ptr noalias noundef {{.*}}, ptr {{.*}}) #{{[0-9]+}} { // Test if num_threads OpenMP clause for target region is correctly lowered -// and passed as a param to kmpc_parallel_51 function +// and passed as a param to kmpc_parallel_60 function // CHECK: define weak_odr protected amdgpu_kernel void [[FUNC_NUM_THREADS0:@.*]]( // CHECK-NOT: call void @__kmpc_push_num_threads( -// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast ( +// CHECK: call void @__kmpc_parallel_60(ptr addrspacecast ( // CHECK-SAME: ptr addrspace(1) @[[NUM_THREADS_GLOB:[0-9]+]] to ptr), // CHECK-SAME: i32 [[NUM_THREADS_TMP0:%.*]], i32 1, i32 156, -// CHECK-SAME: i32 -1, ptr [[FUNC_NUM_THREADS1:@.*]], ptr null, ptr [[NUM_THREADS_TMP1:%.*]], i64 1) +// CHECK-SAME: i32 -1, ptr [[FUNC_NUM_THREADS1:@.*]], ptr null, ptr [[NUM_THREADS_TMP1:%.*]], i64 1, i32 0) -// One of the arguments of kmpc_parallel_51 function is responsible for handling if clause +// One of the arguments of kmpc_parallel_60 function is responsible for handling if clause // of omp parallel construct for target region. If this argument is nonzero, -// then kmpc_parallel_51 launches multiple threads for parallel region. +// then kmpc_parallel_60 launches multiple threads for parallel region. // // This test checks if MLIR expression: // %7 = llvm.icmp "ne" %5, %6 : i32 // omp.parallel if(%7) // is correctly lowered to LLVM IR code and the if condition variable -// is passed as a param to kmpc_parallel_51 function +// is passed as a param to kmpc_parallel_60 function // CHECK: define weak_odr protected amdgpu_kernel void @{{.*}}( // CHECK-SAME: ptr {{.*}}, ptr {{.*}}, ptr %[[IFCOND_ARG2:.*]]) #{{[0-9]+}} { @@ -102,7 +102,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo // CHECK: %[[IFCOND_TMP2:.*]] = load i32, ptr %[[IFCOND_TMP1]], align 4 // CHECK: %[[IFCOND_TMP3:.*]] = icmp ne i32 %[[IFCOND_TMP2]], 0 // CHECK: %[[IFCOND_TMP4:.*]] = sext i1 %[[IFCOND_TMP3]] to i32 -// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast ( +// CHECK: call void @__kmpc_parallel_60(ptr addrspacecast ( // CHECK-SAME: ptr addrspace(1) {{.*}} to ptr), // CHECK-SAME: i32 {{.*}}, i32 %[[IFCOND_TMP4]], i32 -1, -// CHECK-SAME: i32 -1, ptr {{.*}}, ptr null, ptr {{.*}}, i64 1) +// CHECK-SAME: i32 -1, ptr {{.*}}, ptr null, ptr {{.*}}, i64 1, i32 0) diff --git a/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir b/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir index 5d2861a..917eaa0 100644 --- a/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir @@ -26,10 +26,10 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo } } -// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast +// CHECK: call void @__kmpc_parallel_60(ptr addrspacecast // CHECK-SAME: (ptr addrspace(1) @[[GLOB:[0-9]+]] to ptr), // CHECK-SAME: i32 %[[THREAD_NUM:.*]], i32 1, i32 -1, i32 -1, -// CHECK-SAME: ptr @[[PARALLEL_FUNC:.*]], ptr null, ptr %[[PARALLEL_ARGS:.*]], i64 1) +// CHECK-SAME: ptr @[[PARALLEL_FUNC:.*]], ptr null, ptr %[[PARALLEL_ARGS:.*]], i64 1, i32 0) // CHECK: define internal void @[[PARALLEL_FUNC]] // CHECK-SAME: (ptr noalias noundef %[[TID_ADDR:.*]], ptr noalias noundef %[[ZERO_ADDR:.*]], diff --git a/mlir/test/Target/LLVMIR/omptarget-record-type-with-ptr-member-host.mlir b/mlir/test/Target/LLVMIR/omptarget-record-type-with-ptr-member-host.mlir index 9640f03..711b50a 100644 --- a/mlir/test/Target/LLVMIR/omptarget-record-type-with-ptr-member-host.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-record-type-with-ptr-member-host.mlir @@ -59,9 +59,9 @@ module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-a // CHECK: @[[FULL_ARR_GLOB:.*]] = internal global { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] } undef // CHECK: @[[ARR_SECT_GLOB:.*]] = internal global { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] } undef -// CHECK: @.offload_sizes = private unnamed_addr constant [12 x i64] [i64 0, i64 48, i64 8, i64 0, i64 0, i64 48, i64 8, i64 0, i64 0, i64 24, i64 8, i64 0] -// CHECK: @.offload_maptypes = private unnamed_addr constant [12 x i64] [i64 32, i64 281474976710659, i64 281474976710659, i64 281474976710675, i64 32, i64 1407374883553283, i64 1407374883553283, i64 1407374883553299, i64 32, i64 2533274790395907, i64 2533274790395907, i64 2533274790395923] -// CHECK: @.offload_mapnames = private constant [12 x ptr] [ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}] +// CHECK: @.offload_sizes = private unnamed_addr constant [15 x i64] [i64 0, i64 0, i64 0, i64 8, i64 0, i64 0, i64 0, i64 0, i64 8, i64 0, i64 0, i64 0, i64 0, i64 8, i64 0] +// CHECK: @.offload_maptypes = private unnamed_addr constant [15 x i64] [i64 32, i64 281474976710659, i64 281474976710659, i64 281474976710659, i64 281474976710675, i64 32, i64 1688849860263939, i64 1688849860263939, i64 1688849860263939, i64 1688849860263955, i64 32, i64 3096224743817219, i64 3096224743817219, i64 3096224743817219, i64 3096224743817235] +// CHECK: @.offload_mapnames = private constant [15 x ptr] [ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}] // CHECK: define void @main() // CHECK: %[[SCALAR_ALLOCA:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8 }, i64 1, align 8 @@ -85,74 +85,97 @@ module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-a // CHECK: %[[ARR_SECT_PTR:.*]] = getelementptr inbounds i32, ptr %[[LARR_SECT]], i64 %[[ARR_SECT_OFFSET2]] // CHECK: %[[SCALAR_PTR_LOAD:.*]] = load ptr, ptr %[[SCALAR_BASE]], align 8 // CHECK: %[[FULL_ARR_DESC_SIZE:.*]] = sdiv exact i64 48, ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64) -// CHECK: %[[FULL_ARR_SIZE_CMP:.*]] = icmp eq ptr %[[FULL_ARR_PTR]], null -// CHECK: %[[FULL_ARR_SIZE_SEL:.*]] = select i1 %[[FULL_ARR_SIZE_CMP]], i64 0, i64 %[[FULL_ARR_SIZE]] +// CHECK: %[[FULL_ARR_SZ:.*]] = sdiv exact i64 40, ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64) +// CHECK: %[[NULL_CMP:.*]] = icmp eq ptr %[[FULL_ARR_PTR]], null +// CHECK: %[[IS_NULL:.*]] = select i1 %[[NULL_CMP]], i64 0, i64 %[[FULL_ARR_SIZE]] // CHECK: %[[ARR_SECT_DESC_SIZE:.*]] = sdiv exact i64 48, ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64) -// CHECK: %[[ARR_SECT_SIZE_CMP:.*]] = icmp eq ptr %[[ARR_SECT_PTR]], null -// CHECK: %[[ARR_SECT_SIZE_SEL:.*]] = select i1 %[[ARR_SECT_SIZE_CMP]], i64 0, i64 %[[ARR_SECT_SIZE]] +// CHECK: %[[ARR_SECT_SZ:.*]] = sdiv exact i64 40, ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64) +// CHECK: %[[NULL_CMP2:.*]] = icmp eq ptr %[[ARR_SECT_PTR]], null +// CHECK: %[[IS_NULL2:.*]] = select i1 %[[NULL_CMP2]], i64 0, i64 %[[ARR_SECT_SIZE]] // CHECK: %[[SCALAR_DESC_SZ4:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[SCALAR_ALLOCA]], i32 1 // CHECK: %[[SCALAR_DESC_SZ3:.*]] = ptrtoint ptr %[[SCALAR_DESC_SZ4]] to i64 // CHECK: %[[SCALAR_DESC_SZ2:.*]] = ptrtoint ptr %[[SCALAR_ALLOCA]] to i64 // CHECK: %[[SCALAR_DESC_SZ1:.*]] = sub i64 %[[SCALAR_DESC_SZ3]], %[[SCALAR_DESC_SZ2]] // CHECK: %[[SCALAR_DESC_SZ:.*]] = sdiv exact i64 %[[SCALAR_DESC_SZ1]], ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64) - -// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 +// CHECK: %[[SCALAR_BASE_2:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[SCALAR_ALLOCA]], i32 1 +// CHECK: %[[SCALAR_BASE_OFF:.*]] = getelementptr ptr, ptr %[[SCALAR_BASE]], i32 1 +// CHECK: %[[SCALAR_BASE_OFF_SZ1:.*]] = ptrtoint ptr %[[SCALAR_BASE_2]] to i64 +// CHECK: %[[SCALAR_BASE_OFF_SZ2:.*]] = ptrtoint ptr %[[SCALAR_BASE_OFF]] to i64 +// CHECK: %[[SCALAR_BASE_OFF_SZ3:.*]] = sub i64 %[[SCALAR_BASE_OFF_SZ1]], %[[SCALAR_BASE_OFF_SZ2]] +// CHECK: %[[SCALAR_BASE_OFF_SZ4:.*]] = sdiv exact i64 %[[SCALAR_BASE_OFF_SZ3]], ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64) +// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 // CHECK: store ptr @full_arr, ptr %[[OFFLOADBASEPTRS]], align 8 -// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 0 +// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 0 // CHECK: store ptr @full_arr, ptr %[[OFFLOADPTRS]], align 8 -// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [12 x i64], ptr %.offload_sizes, i32 0, i32 0 +// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [15 x i64], ptr %.offload_sizes, i32 0, i32 0 // CHECK: store i64 %[[FULL_ARR_DESC_SIZE]], ptr %[[OFFLOADSIZES]], align 8 -// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 1 +// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 1 // CHECK: store ptr @full_arr, ptr %[[OFFLOADBASEPTRS]], align 8 -// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 1 +// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 1 // CHECK: store ptr @full_arr, ptr %[[OFFLOADPTRS]], align 8 -// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 2 +// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 2 +// CHECK: store ptr @full_arr, ptr %[[OFFLOADBASEPTRS]], align 8 +// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 2 +// CHECK: store ptr getelementptr inbounds nuw (i8, ptr @full_arr, i64 8), ptr %[[OFFLOADPTRS]], align 8 +// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [15 x i64], ptr %.offload_sizes, i32 0, i32 2 +// CHECK: store i64 %[[FULL_ARR_SZ]], ptr %[[OFFLOADSIZES]], align 8 +// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 3 // CHECK: store ptr @full_arr, ptr %[[OFFLOADBASEPTRS]], align 8 -// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 2 +// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 3 // CHECK: store ptr @full_arr, ptr %[[OFFLOADPTRS]], align 8 -// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 3 +// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 4 // CHECK: store ptr @full_arr, ptr %[[OFFLOADBASEPTRS]], align 8 -// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 3 +// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 4 // CHECK: store ptr %[[FULL_ARR_PTR]], ptr %[[OFFLOADPTRS]], align 8 -// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [12 x i64], ptr %.offload_sizes, i32 0, i32 3 -// CHECK: store i64 %[[FULL_ARR_SIZE_SEL]], ptr %[[OFFLOADSIZES]], align 8 - -// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 4 +// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [15 x i64], ptr %.offload_sizes, i32 0, i32 4 +// CHECK: store i64 %[[IS_NULL]], ptr %[[OFFLOADSIZES]], align 8 +// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 5 // CHECK: store ptr @sect_arr, ptr %[[OFFLOADBASEPTRS]], align 8 -// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 4 +// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 5 // CHECK: store ptr @sect_arr, ptr %[[OFFLOADPTRS]], align 8 -// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [12 x i64], ptr %.offload_sizes, i32 0, i32 4 +// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [15 x i64], ptr %.offload_sizes, i32 0, i32 5 // CHECK: store i64 %[[ARR_SECT_DESC_SIZE]], ptr %[[OFFLOADSIZES]], align 8 -// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 5 +// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 6 // CHECK: store ptr @sect_arr, ptr %[[OFFLOADBASEPTRS]], align 8 -// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 5 +// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 6 // CHECK: store ptr @sect_arr, ptr %[[OFFLOADPTRS]], align 8 -// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 6 +// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 7 // CHECK: store ptr @sect_arr, ptr %[[OFFLOADBASEPTRS]], align 8 -// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 6 +// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 7 +// CHECK: store ptr getelementptr inbounds nuw (i8, ptr @sect_arr, i64 8), ptr %[[OFFLOADPTRS]], align 8 +// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [15 x i64], ptr %.offload_sizes, i32 0, i32 7 +// CHECK: store i64 %[[ARR_SECT_SZ]], ptr %[[OFFLOADSIZES]], align 8 +// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 8 +// CHECK: store ptr @sect_arr, ptr %[[OFFLOADBASEPTRS]], align 8 +// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 8 // CHECK: store ptr @sect_arr, ptr %[[OFFLOADPTRS]], align 8 -// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 7 +// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 9 // CHECK: store ptr @sect_arr, ptr %[[OFFLOADBASEPTRS]], align 8 -// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 7 +// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 9 // CHECK: store ptr %[[ARR_SECT_PTR]], ptr %[[OFFLOADPTRS]], align 8 -// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [12 x i64], ptr %.offload_sizes, i32 0, i32 7 -// CHECK: store i64 %[[ARR_SECT_SIZE_SEL]], ptr %[[OFFLOADSIZES]], align 8 - -// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 8 +// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [15 x i64], ptr %.offload_sizes, i32 0, i32 9 +// CHECK: store i64 %[[IS_NULL2]], ptr %[[OFFLOADSIZES]], align 8 +// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 10 // CHECK: store ptr %[[SCALAR_ALLOCA]], ptr %[[OFFLOADBASEPTRS]], align 8 -// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 8 +// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 10 // CHECK: store ptr %[[SCALAR_ALLOCA]], ptr %[[OFFLOADPTRS]], align 8 -// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [12 x i64], ptr %.offload_sizes, i32 0, i32 8 +// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [15 x i64], ptr %.offload_sizes, i32 0, i32 10 // CHECK: store i64 %[[SCALAR_DESC_SZ]], ptr %[[OFFLOADSIZES]], align 8 -// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 9 +// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 11 // CHECK: store ptr %[[SCALAR_ALLOCA]], ptr %[[OFFLOADBASEPTRS]], align 8 -// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 9 +// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 11 // CHECK: store ptr %[[SCALAR_ALLOCA]], ptr %[[OFFLOADPTRS]], align 8 -// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 10 +// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 12 +// CHECK: store ptr %[[SCALAR_ALLOCA]], ptr %[[OFFLOADBASEPTRS]], align 8 +// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 12 +// CHECK: store ptr %[[SCALAR_BASE_OFF]], ptr %[[OFFLOADPTRS]], align 8 +// CHECK: %[[OFFLOADSIZES:.*]] = getelementptr inbounds [15 x i64], ptr %.offload_sizes, i32 0, i32 12 +// CHECK: store i64 %[[SCALAR_BASE_OFF_SZ4]], ptr %[[OFFLOADSIZES]], align 8 +// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 13 // CHECK: store ptr %[[SCALAR_ALLOCA]], ptr %[[OFFLOADBASEPTRS]], align 8 -// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 10 +// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 13 // CHECK: store ptr %[[SCALAR_BASE]], ptr %[[OFFLOADPTRS]], align 8 -// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_baseptrs, i32 0, i32 11 +// CHECK: %[[OFFLOADBASEPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_baseptrs, i32 0, i32 14 // CHECK: store ptr %[[SCALAR_BASE]], ptr %[[OFFLOADBASEPTRS]], align 8 -// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [12 x ptr], ptr %.offload_ptrs, i32 0, i32 11 +// CHECK: %[[OFFLOADPTRS:.*]] = getelementptr inbounds [15 x ptr], ptr %.offload_ptrs, i32 0, i32 14 // CHECK: store ptr %[[SCALAR_PTR_LOAD]], ptr %[[OFFLOADPTRS]], align 8 diff --git a/mlir/test/Target/LLVMIR/omptarget-runtimecc.mlir b/mlir/test/Target/LLVMIR/omptarget-runtimecc.mlir new file mode 100644 index 0000000..a232bd7 --- /dev/null +++ b/mlir/test/Target/LLVMIR/omptarget-runtimecc.mlir @@ -0,0 +1,12 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +module attributes {omp.is_target_device = true, omp.is_gpu = true, omp.target_triples = ["spirv64-intel"], llvm.target_triple = "spirv64-intel"} { +// CHECK: call spir_func i32 @__kmpc_target_init +// CHECK: call spir_func void @__kmpc_target_deinit + llvm.func @target_if_variable(%x : i1) { + omp.target if(%x) { + omp.terminator + } + llvm.return + } + } diff --git a/mlir/test/Target/LLVMIR/omptarget-teams-distribute-reduction.mlir b/mlir/test/Target/LLVMIR/omptarget-teams-distribute-reduction.mlir index 9aba72d..b7cb102 100644 --- a/mlir/test/Target/LLVMIR/omptarget-teams-distribute-reduction.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-teams-distribute-reduction.mlir @@ -59,8 +59,8 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo // CHECK: call void @__kmpc_barrier // CHECK: [[THEN]]: -// CHECK-NEXT: %[[FINAL_RHS:[A-Za-z0-9_.]*]] = load i32 // CHECK-NEXT: %[[FINAL_LHS:[A-Za-z0-9_.]*]] = load i32 +// CHECK-NEXT: %[[FINAL_RHS:[A-Za-z0-9_.]*]] = load i32 // CHECK-NEXT: %[[FINAL_RESULT:[A-Za-z0-9_.]*]] = add i32 %[[FINAL_LHS]], %[[FINAL_RHS]] // CHECK-NEXT: store i32 %[[FINAL_RESULT]] diff --git a/mlir/test/Target/LLVMIR/omptarget-teams-reduction.mlir b/mlir/test/Target/LLVMIR/omptarget-teams-reduction.mlir index dc22fe1..36eb280 100644 --- a/mlir/test/Target/LLVMIR/omptarget-teams-reduction.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-teams-reduction.mlir @@ -62,8 +62,8 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo // CHECK: icmp eq i32 %[[MASTER]], 1 // CHECK: i1 %{{.+}}, label %[[THEN:[A-Za-z0-9_.]*]], label %[[DONE:[A-Za-z0-9_.]*]] // CHECK: [[THEN]]: -// CHECK-NEXT: %[[FINAL_RHS:[A-Za-z0-9_.]*]] = load i32 // CHECK-NEXT: %[[FINAL_LHS:[A-Za-z0-9_.]*]] = load i32 +// CHECK-NEXT: %[[FINAL_RHS:[A-Za-z0-9_.]*]] = load i32 // CHECK-NEXT: %[[FINAL_RESULT:[A-Za-z0-9_.]*]] = add i32 %[[FINAL_LHS]], %[[FINAL_RHS]] // CHECK-NEXT: store i32 %[[FINAL_RESULT]] diff --git a/mlir/test/Target/LLVMIR/openmp-barrier-cancel.mlir b/mlir/test/Target/LLVMIR/openmp-barrier-cancel.mlir index c4b2456..6585549 100644 --- a/mlir/test/Target/LLVMIR/openmp-barrier-cancel.mlir +++ b/mlir/test/Target/LLVMIR/openmp-barrier-cancel.mlir @@ -29,22 +29,24 @@ llvm.func @test() { // CHECK: %[[VAL_14:.*]] = icmp eq i32 %[[VAL_13]], 0 // CHECK: br i1 %[[VAL_14]], label %[[VAL_15:.*]], label %[[VAL_16:.*]] // CHECK: omp.par.region1.cncl: ; preds = %[[VAL_11]] -// CHECK: %[[VAL_17:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) -// CHECK: %[[VAL_18:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_17]]) -// CHECK: br label %[[VAL_19:.*]] +// CHECK: br label %[[FINI:.*]] +// CHECK: .fini: +// CHECK: %[[TID:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) +// CHECK: %[[CNCL_BARRIER:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[TID]]) +// CHECK: br label %[[EXIT_STUB:.*]] // CHECK: omp.par.region1.split: ; preds = %[[VAL_11]] // CHECK: %[[VAL_20:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) // CHECK: %[[VAL_21:.*]] = call i32 @__kmpc_cancel_barrier(ptr @3, i32 %[[VAL_20]]) // CHECK: %[[VAL_22:.*]] = icmp eq i32 %[[VAL_21]], 0 // CHECK: br i1 %[[VAL_22]], label %[[VAL_23:.*]], label %[[VAL_24:.*]] // CHECK: omp.par.region1.split.cncl: ; preds = %[[VAL_15]] -// CHECK: br label %[[VAL_19]] +// CHECK: br label %[[FINI]] // CHECK: omp.par.region1.split.cont: ; preds = %[[VAL_15]] // CHECK: br label %[[VAL_25:.*]] // CHECK: omp.region.cont: ; preds = %[[VAL_23]] // CHECK: br label %[[VAL_26:.*]] // CHECK: omp.par.pre_finalize: ; preds = %[[VAL_25]] -// CHECK: br label %[[VAL_19]] -// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_26]], %[[VAL_24]], %[[VAL_16]] +// CHECK: br label %[[FINI]] +// CHECK: omp.par.exit.exitStub: // CHECK: ret void diff --git a/mlir/test/Target/LLVMIR/openmp-cancel.mlir b/mlir/test/Target/LLVMIR/openmp-cancel.mlir index 2124170..a6911f8 100644 --- a/mlir/test/Target/LLVMIR/openmp-cancel.mlir +++ b/mlir/test/Target/LLVMIR/openmp-cancel.mlir @@ -24,16 +24,18 @@ llvm.func @cancel_parallel() { // CHECK: %[[VAL_15:.*]] = icmp eq i32 %[[VAL_14]], 0 // CHECK: br i1 %[[VAL_15]], label %[[VAL_16:.*]], label %[[VAL_17:.*]] // CHECK: omp.par.region1.cncl: ; preds = %[[VAL_12]] +// CHECK: br label %[[VAL_20:.*]] +// CHECK: .fini: // CHECK: %[[VAL_18:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) // CHECK: %[[VAL_19:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_18]]) -// CHECK: br label %[[VAL_20:.*]] +// CHECK: br label %[[EXIT_STUB:.*]] // CHECK: omp.par.region1.split: ; preds = %[[VAL_12]] // CHECK: br label %[[VAL_21:.*]] // CHECK: omp.region.cont: ; preds = %[[VAL_16]] // CHECK: br label %[[VAL_22:.*]] // CHECK: omp.par.pre_finalize: ; preds = %[[VAL_21]] // CHECK: br label %[[VAL_20]] -// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_22]], %[[VAL_17]] +// CHECK: omp.par.exit.exitStub: // CHECK: ret void llvm.func @cancel_parallel_if(%arg0 : i1) { @@ -58,27 +60,36 @@ llvm.func @cancel_parallel_if(%arg0 : i1) { // CHECK: omp.par.region: ; preds = %[[VAL_17]] // CHECK: br label %[[VAL_20:.*]] // CHECK: omp.par.region1: ; preds = %[[VAL_19]] -// CHECK: br i1 %[[VAL_16]], label %[[VAL_21:.*]], label %[[VAL_22:.*]] +// CHECK: br i1 %[[VAL_16]], label %[[SPLIT:.*]], label %[[VAL_22:.*]] // CHECK: 3: ; preds = %[[VAL_20]] -// CHECK: br label %[[VAL_23:.*]] -// CHECK: 4: ; preds = %[[VAL_22]], %[[VAL_24:.*]] +// CHECK: %[[GTN:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) +// CHECK: %[[NOT_CANCELLED:.*]] = call i32 @__kmpc_cancellationpoint(ptr @1, i32 %[[GTN]], i32 1) +// CHECK: %[[COND:.*]] = icmp eq i32 %[[NOT_CANCELLED]], 0 +// CHECK: br i1 %[[COND]], label %[[VAL_23:.*]], label %[[CNCL:.*]] +// CHECK: .cncl: +// CHECK: br label %[[FINI:.*]] +// CHECK: .fini: +// CHECK: %[[VAL_32:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) +// CHECK: %[[VAL_33:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_32]]) +// CHECK: br label %[[EXIT_STUB:.*]] +// CHECK: .split: +// CHECK: br label %[[SEVEN:.*]] +// CHECK: 7: // CHECK: br label %[[VAL_25:.*]] -// CHECK: omp.region.cont: ; preds = %[[VAL_23]] +// CHECK: omp.region.cont: // CHECK: br label %[[VAL_26:.*]] // CHECK: omp.par.pre_finalize: ; preds = %[[VAL_25]] // CHECK: br label %[[VAL_27:.*]] -// CHECK: 5: ; preds = %[[VAL_20]] +// CHECK: 8: ; preds = %[[VAL_20]] // CHECK: %[[VAL_28:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) // CHECK: %[[VAL_29:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_28]], i32 1) // CHECK: %[[VAL_30:.*]] = icmp eq i32 %[[VAL_29]], 0 -// CHECK: br i1 %[[VAL_30]], label %[[VAL_24]], label %[[VAL_31:.*]] -// CHECK: .cncl: ; preds = %[[VAL_21]] -// CHECK: %[[VAL_32:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) -// CHECK: %[[VAL_33:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_32]]) -// CHECK: br label %[[VAL_27]] -// CHECK: .split: ; preds = %[[VAL_21]] -// CHECK: br label %[[VAL_23]] -// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_31]], %[[VAL_26]] +// CHECK: br i1 %[[VAL_30]], label %[[SPLIT5:.*]], label %[[VAL_31:.*]] +// CHECK: .cncl{{.*}}: +// CHECK: br label %[[FINI]] +// CHECK: .split{{.*}}: +// CHECK: br label %[[SEVEN]] +// CHECK: omp.par.exit.exitStub: // CHECK: ret void llvm.func @cancel_sections_if(%cond : i1) { @@ -132,11 +143,16 @@ llvm.func @cancel_sections_if(%cond : i1) { // CHECK: %[[VAL_30:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_29]], i32 3) // CHECK: %[[VAL_31:.*]] = icmp eq i32 %[[VAL_30]], 0 // CHECK: br i1 %[[VAL_31]], label %[[VAL_32:.*]], label %[[VAL_33:.*]] -// CHECK: .split: ; preds = %[[VAL_27]] +// CHECK: .split{{.*}}: ; preds = %[[VAL_27]] // CHECK: br label %[[VAL_34:.*]] // CHECK: 12: ; preds = %[[VAL_25]] +// CHECK: %[[GTN:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) +// CHECK: %[[CANCEL_POINT:.*]] = call i32 @__kmpc_cancellationpoint(ptr @1, i32 %[[GTN]], i32 3) +// CHECK: %[[COND:.*]] = icmp eq i32 %13, 0 +// CHECK: br i1 %[[COND]], label %[[SPLIT:.*]], label %[[CNCL:.*]] +// CHECK: .split{{.*}}: // CHECK: br label %[[VAL_34]] -// CHECK: 13: ; preds = %[[VAL_28]], %[[VAL_32]] +// CHECK: 15: // CHECK: br label %[[VAL_35:.*]] // CHECK: omp.region.cont: ; preds = %[[VAL_34]] // CHECK: br label %[[VAL_23]] @@ -145,17 +161,17 @@ llvm.func @cancel_sections_if(%cond : i1) { // CHECK: omp_section_loop.inc: ; preds = %[[VAL_23]] // CHECK: %[[VAL_15]] = add nuw i32 %[[VAL_14]], 1 // CHECK: br label %[[VAL_12]] -// CHECK: omp_section_loop.exit: ; preds = %[[VAL_33]], %[[VAL_16]] +// CHECK: omp_section_loop.exit: // CHECK: call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_7]]) // CHECK: %[[VAL_36:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) // CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_36]]) // CHECK: br label %[[VAL_37:.*]] // CHECK: omp_section_loop.after: ; preds = %[[VAL_19]] -// CHECK: br label %[[VAL_38:.*]] -// CHECK: omp_section_loop.aftersections.fini: ; preds = %[[VAL_37]] // CHECK: ret void -// CHECK: .cncl: ; preds = %[[VAL_27]] -// CHECK: br label %[[VAL_19]] +// CHECK: .cncl: +// CHECK: br label %[[OMP_SECTION_LOOP_EXIT:.*]] +// CHECK: .cncl{{.*}}: +// CHECK: br label %[[OMP_SECTION_LOOP_EXIT:.*]] llvm.func @cancel_wsloop_if(%lb : i32, %ub : i32, %step : i32, %cond : i1) { omp.wsloop { @@ -221,18 +237,23 @@ llvm.func @cancel_wsloop_if(%lb : i32, %ub : i32, %step : i32, %cond : i1) { // CHECK: %[[VAL_47:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_46]], i32 2) // CHECK: %[[VAL_48:.*]] = icmp eq i32 %[[VAL_47]], 0 // CHECK: br i1 %[[VAL_48]], label %[[VAL_49:.*]], label %[[VAL_50:.*]] -// CHECK: .split: ; preds = %[[VAL_44]] +// CHECK: .split{{.*}}: // CHECK: br label %[[VAL_51:.*]] -// CHECK: 28: ; preds = %[[VAL_42]] +// CHECK: 28: +// CHECK: %[[GTN:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) +// CHECK: %[[CANCEL_POINT:.*]] = call i32 @__kmpc_cancellationpoint(ptr @1, i32 %[[GTN]], i32 2) +// CHECK: %[[COND:.*]] = icmp eq i32 %[[CANCEL_POINT]], 0 +// CHECK: br i1 %[[COND]], label %[[SPLIT3:.*]], label %[[CNCL4:.*]] +// CHECK: .split{{.*}}: // CHECK: br label %[[VAL_51]] -// CHECK: 29: ; preds = %[[VAL_45]], %[[VAL_49]] +// CHECK: 31: // CHECK: br label %[[VAL_52:.*]] // CHECK: omp.region.cont1: ; preds = %[[VAL_51]] // CHECK: br label %[[VAL_32]] // CHECK: omp_loop.inc: ; preds = %[[VAL_52]] // CHECK: %[[VAL_34]] = add nuw i32 %[[VAL_33]], 1 // CHECK: br label %[[VAL_31]] -// CHECK: omp_loop.exit: ; preds = %[[VAL_50]], %[[VAL_35]] +// CHECK: omp_loop.exit: // CHECK: call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_26]]) // CHECK: %[[VAL_53:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) // CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_53]]) @@ -241,8 +262,12 @@ llvm.func @cancel_wsloop_if(%lb : i32, %ub : i32, %step : i32, %cond : i1) { // CHECK: br label %[[VAL_55:.*]] // CHECK: omp.region.cont: ; preds = %[[VAL_54]] // CHECK: ret void -// CHECK: .cncl: ; preds = %[[VAL_44]] -// CHECK: br label %[[VAL_38]] +// CHECK: .cncl{{.*}}: +// CHECK: br label %[[FINI:.*]] +// CHECK: .fini: +// CHECK: br label %[[OMP_LOOP_EXIT:.*]] +// CHECK: .cncl{{.*}}: +// CHECK: br label %[[FINI:.*]] omp.private {type = firstprivate} @i32_priv : i32 copy { ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr): diff --git a/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir b/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir index 5e0d3f9..93fa2064 100644 --- a/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir +++ b/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir @@ -24,16 +24,18 @@ llvm.func @cancellation_point_parallel() { // CHECK: %[[VAL_15:.*]] = icmp eq i32 %[[VAL_14]], 0 // CHECK: br i1 %[[VAL_15]], label %[[VAL_16:.*]], label %[[VAL_17:.*]] // CHECK: omp.par.region1.cncl: ; preds = %[[VAL_12]] +// CHECK: br label %[[FINI:.*]] +// CHECK: .fini: // CHECK: %[[VAL_18:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) // CHECK: %[[VAL_19:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_18]]) -// CHECK: br label %[[VAL_20:.*]] +// CHECK: br label %[[EXIT_STUB:.*]] // CHECK: omp.par.region1.split: ; preds = %[[VAL_12]] // CHECK: br label %[[VAL_21:.*]] // CHECK: omp.region.cont: ; preds = %[[VAL_16]] // CHECK: br label %[[VAL_22:.*]] // CHECK: omp.par.pre_finalize: ; preds = %[[VAL_21]] -// CHECK: br label %[[VAL_20]] -// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_22]], %[[VAL_17]] +// CHECK: br label %[[FINI]] +// CHECK: omp.par.exit.exitStub: // CHECK: ret void llvm.func @cancellation_point_sections() { @@ -94,14 +96,12 @@ llvm.func @cancellation_point_sections() { // CHECK: omp_section_loop.inc: ; preds = %[[VAL_46]] // CHECK: %[[VAL_38]] = add nuw i32 %[[VAL_37]], 1 // CHECK: br label %[[VAL_35]] -// CHECK: omp_section_loop.exit: ; preds = %[[VAL_53]], %[[VAL_39]] +// CHECK: omp_section_loop.exit: // CHECK: call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_30]]) // CHECK: %[[VAL_55:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) // CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_55]]) // CHECK: br label %[[VAL_56:.*]] // CHECK: omp_section_loop.after: ; preds = %[[VAL_42]] -// CHECK: br label %[[VAL_57:.*]] -// CHECK: omp_section_loop.aftersections.fini: ; preds = %[[VAL_56]] // CHECK: ret void // CHECK: omp.section.region.cncl: ; preds = %[[VAL_48]] // CHECK: br label %[[VAL_42]] @@ -175,7 +175,7 @@ llvm.func @cancellation_point_wsloop(%lb : i32, %ub : i32, %step : i32) { // CHECK: omp_loop.inc: ; preds = %[[VAL_106]] // CHECK: %[[VAL_92]] = add nuw i32 %[[VAL_91]], 1 // CHECK: br label %[[VAL_89]] -// CHECK: omp_loop.exit: ; preds = %[[VAL_105]], %[[VAL_93]] +// CHECK: omp_loop.exit: // CHECK: call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_84]]) // CHECK: %[[VAL_107:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) // CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_107]]) diff --git a/mlir/test/Target/LLVMIR/openmp-dist_schedule.mlir b/mlir/test/Target/LLVMIR/openmp-dist_schedule.mlir new file mode 100644 index 0000000..a0dd556 --- /dev/null +++ b/mlir/test/Target/LLVMIR/openmp-dist_schedule.mlir @@ -0,0 +1,34 @@ +// Test that dist_schedule gets correctly translated with the correct schedule type and chunk size where appropriate + +// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s + +llvm.func @distribute_dist_schedule_chunk_size(%lb : i32, %ub : i32, %step : i32, %x : i32) { + // CHECK: call void @[[RUNTIME_FUNC:__kmpc_for_static_init_4u]](ptr @1, i32 %omp_global_thread_num, i32 91, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 1024) + // We want to make sure that the next call is not another init builder. + // CHECK-NOT: call void @[[RUNTIME_FUNC]] + %1 = llvm.mlir.constant(1024: i32) : i32 + omp.distribute dist_schedule_static dist_schedule_chunk_size(%1 : i32) { + omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { + omp.yield + } + } + llvm.return +} + +// When a chunk size is present, we need to make sure the correct parallel accesses metadata is added +// CHECK: !2 = !{!"llvm.loop.parallel_accesses", !3} +// CHECK-NEXT: !3 = distinct !{} + +// ----- + +llvm.func @distribute_dist_schedule(%lb : i32, %ub : i32, %step : i32, %x : i32) { + // CHECK: call void @[[RUNTIME_FUNC:__kmpc_for_static_init_4u]](ptr @1, i32 %omp_global_thread_num, i32 92, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 0) + // We want to make sure that the next call is not another init builder. + // CHECK-NOT: call void @[[RUNTIME_FUNC]] + omp.distribute dist_schedule_static { + omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { + omp.yield + } + } + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/openmp-dist_schedule_with_wsloop.mlir b/mlir/test/Target/LLVMIR/openmp-dist_schedule_with_wsloop.mlir new file mode 100644 index 0000000..dad32b4 --- /dev/null +++ b/mlir/test/Target/LLVMIR/openmp-dist_schedule_with_wsloop.mlir @@ -0,0 +1,205 @@ +// Test that dist_schedule gets correctly translated with the correct schedule type and chunk size where appropriate while using workshare loops. + +// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s + +llvm.func @distribute_wsloop_dist_schedule_chunked_schedule_chunked(%n: i32, %teams: i32, %threads: i32, %dcs: i32) { + %0 = llvm.mlir.constant(0 : i32) : i32 + %1 = llvm.mlir.constant(1 : i32) : i32 + %scs = llvm.mlir.constant(64 : i32) : i32 + + omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) { + omp.parallel { + omp.distribute dist_schedule_static dist_schedule_chunk_size(%dcs : i32) { + omp.wsloop schedule(static = %scs : i32) { + omp.loop_nest (%i) : i32 = (%0) to (%n) step (%1) { + omp.yield + } + } {omp.composite} + } {omp.composite} + omp.terminator + } {omp.composite} + omp.terminator + } + llvm.return +} +// CHECK: define internal void @distribute_wsloop_dist_schedule_chunked_schedule_chunked..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 { +// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 33, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 64) +// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 91, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 %3) + +llvm.func @distribute_wsloop_dist_schedule_chunked_schedule_chunked_i64(%n: i32, %teams: i32, %threads: i32) { + %0 = llvm.mlir.constant(0 : i64) : i64 + %1 = llvm.mlir.constant(1 : i64) : i64 + %dcs = llvm.mlir.constant(1024 : i64) : i64 + %scs = llvm.mlir.constant(64 : i64) : i64 + %n64 = llvm.zext %n : i32 to i64 + + omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) { + omp.parallel { + omp.distribute dist_schedule_static dist_schedule_chunk_size(%dcs : i64) { + omp.wsloop schedule(static = %scs : i64) { + omp.loop_nest (%i) : i64 = (%0) to (%n64) step (%1) { + omp.yield + } + } {omp.composite} + } {omp.composite} + omp.terminator + } {omp.composite} + omp.terminator + } + llvm.return +} +// CHECK: define internal void @distribute_wsloop_dist_schedule_chunked_schedule_chunked_i64..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 { +// CHECK: call void @__kmpc_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 33, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i64 1, i64 64) +// call void @__kmpc_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 91, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i64 1, i64 1024) + +// ----- + +llvm.func @distribute_wsloop_dist_schedule_chunked(%n: i32, %teams: i32, %threads: i32) { + %0 = llvm.mlir.constant(0 : i32) : i32 + %1 = llvm.mlir.constant(1 : i32) : i32 + %dcs = llvm.mlir.constant(1024 : i32) : i32 + + omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) { + omp.parallel { + omp.distribute dist_schedule_static dist_schedule_chunk_size(%dcs : i32) { + omp.wsloop schedule(static) { + omp.loop_nest (%i) : i32 = (%0) to (%n) step (%1) { + omp.yield + } + } {omp.composite} + } {omp.composite} + omp.terminator + } {omp.composite} + omp.terminator + } + llvm.return +} +// CHECK: define internal void @distribute_wsloop_dist_schedule_chunked..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 { +// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 34, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 0) +// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 91, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 1024) + +llvm.func @distribute_wsloop_dist_schedule_chunked_i64(%n: i32, %teams: i32, %threads: i32) { + %0 = llvm.mlir.constant(0 : i64) : i64 + %1 = llvm.mlir.constant(1 : i64) : i64 + %dcs = llvm.mlir.constant(1024 : i64) : i64 + %n64 = llvm.zext %n : i32 to i64 + + omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) { + omp.parallel { + omp.distribute dist_schedule_static dist_schedule_chunk_size(%dcs : i64) { + omp.wsloop schedule(static) { + omp.loop_nest (%i) : i64 = (%0) to (%n64) step (%1) { + omp.yield + } + } {omp.composite} + } {omp.composite} + omp.terminator + } {omp.composite} + omp.terminator + } + llvm.return +} +// CHECK: define internal void @distribute_wsloop_dist_schedule_chunked_i64..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 { +// CHECK: call void @__kmpc_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 34, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i64 1, i64 0) +// CHECK: call void @__kmpc_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 91, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i64 1, i64 1024) + +// ----- + +llvm.func @distribute_wsloop_schedule_chunked(%n: i32, %teams: i32, %threads: i32) { + %0 = llvm.mlir.constant(0 : i32) : i32 + %1 = llvm.mlir.constant(1 : i32) : i32 + %scs = llvm.mlir.constant(64 : i32) : i32 + + omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) { + omp.parallel { + omp.distribute dist_schedule_static { + omp.wsloop schedule(static = %scs : i32) { + omp.loop_nest (%i) : i32 = (%0) to (%n) step (%1) { + omp.yield + } + } {omp.composite} + } {omp.composite} + omp.terminator + } {omp.composite} + omp.terminator + } + llvm.return +} +// CHECK: define internal void @distribute_wsloop_schedule_chunked..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 { +// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 33, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 64) +// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 92, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 0) + +llvm.func @distribute_wsloop_schedule_chunked_i64(%n: i32, %teams: i32, %threads: i32) { + %0 = llvm.mlir.constant(0 : i64) : i64 + %1 = llvm.mlir.constant(1 : i64) : i64 + %scs = llvm.mlir.constant(64 : i64) : i64 + %n64 = llvm.zext %n : i32 to i64 + + omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) { + omp.parallel { + omp.distribute dist_schedule_static { + omp.wsloop schedule(static = %scs : i64) { + omp.loop_nest (%i) : i64 = (%0) to (%n64) step (%1) { + omp.yield + } + } {omp.composite} + } {omp.composite} + omp.terminator + } {omp.composite} + omp.terminator + } + llvm.return +} + +// CHECK: define internal void @distribute_wsloop_schedule_chunked_i64..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 { +// CHECK: call void @__kmpc_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 33, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i64 1, i64 64) +// CHECK: call void @__kmpc_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 92, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i64 1, i64 0) + +// ----- + +llvm.func @distribute_wsloop_no_chunks(%n: i32, %teams: i32, %threads: i32) { + %0 = llvm.mlir.constant(0 : i32) : i32 + %1 = llvm.mlir.constant(1 : i32) : i32 + + omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) { + omp.parallel { + omp.distribute dist_schedule_static { + omp.wsloop schedule(static) { + omp.loop_nest (%i) : i32 = (%0) to (%n) step (%1) { + omp.yield + } + } {omp.composite} + } {omp.composite} + omp.terminator + } {omp.composite} + omp.terminator + } + llvm.return +} +// CHECK: define internal void @distribute_wsloop_no_chunks..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 { +// CHECK: call void @__kmpc_dist_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 34, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.distupperbound, ptr %p.stride, i32 1, i32 0) +// CHECK: call void @__kmpc_dist_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 92, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.distupperbound10, ptr %p.stride, i32 1, i32 0) + +llvm.func @distribute_wsloop_no_chunks_i64(%n: i32, %teams: i32, %threads: i32) { + %0 = llvm.mlir.constant(0 : i64) : i64 + %1 = llvm.mlir.constant(1 : i64) : i64 + %n64 = llvm.zext %n : i32 to i64 + + omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) { + omp.parallel { + omp.distribute dist_schedule_static { + omp.wsloop schedule(static) { + omp.loop_nest (%i) : i64 = (%0) to (%n64) step (%1) { + omp.yield + } + } {omp.composite} + } {omp.composite} + omp.terminator + } {omp.composite} + omp.terminator + } + llvm.return +} +// CHECK: define internal void @distribute_wsloop_no_chunks_i64..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 { +// CHECK: call void @__kmpc_dist_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 34, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.distupperbound, ptr %p.stride, i64 1, i64 0) +// CHECK: call void @__kmpc_dist_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 92, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.distupperbound10, ptr %p.stride, i64 1, i64 0)
\ No newline at end of file diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir index 8bd33a3..1eb501c 100644 --- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -328,6 +328,52 @@ llvm.func @test_omp_masked(%arg0: i32)-> () { // ----- +llvm.func @wsloop_linear(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) { +// CHECK-LABEL: @wsloop_linear + +// CHECK: %p.lastiter = alloca i32, align 4 +// CHECK: %p.lowerbound = alloca i32, align 4 +// CHECK: %p.upperbound = alloca i32, align 4 +// CHECK: %p.stride = alloca i32, align 4 +// CHECK: %[[LINEAR_VAR:.*]] = alloca i32, align 4 +// CHECK: %[[LINEAR_RESULT:.*]] = alloca i32, align 4 + +// CHECK: omp_loop.preheader: +// CHECK: %[[LOAD:.*]] = load i32, ptr %{{.*}}, align 4 +// CHECK: store i32 %[[LOAD]], ptr %[[LINEAR_VAR]], align 4 + +// CHECK: omp_loop.body: +// CHECK: %[[LOOP_IV_CALC:.*]] = add i32 %omp_loop.iv, {{.*}} +// CHECK: %[[LINEAR_VAR_LOAD:.*]] = load i32, ptr %[[LINEAR_VAR]], align 4 +// CHECK: %[[MUL:.*]] = mul i32 %[[LOOP_IV_CALC]], {{.*}} +// CHECK: %[[ADD:.*]] = add i32 %[[LINEAR_VAR_LOAD]], %[[MUL]] +// CHECK: store i32 %[[ADD]], ptr %[[LINEAR_RESULT]], align 4 + +// CHECK: omp_loop.linear_finalization: +// CHECK: %[[ITER:.*]] = load i32, ptr %p.lastiter, align 4 +// CHECK: %[[CMP:.*]] = icmp ne i32 %[[ITER]], 0 +// CHECK: br i1 %[[CMP]], label %omp_loop.linear_lastiter_exit, label %omp_loop.linear_exit + +// CHECK: omp_loop.linear_lastiter_exit: +// CHECK: %[[LOAD:.*]] = load i32, ptr %[[LINEAR_RESULT]], align 4 +// CHECK: store i32 %[[LOAD]], ptr {{.*}}, align 4 +// CHECK: br label %omp_loop.linear_exit + +// CHECK: omp_loop.linear_exit: +// CHECK: %[[THREAD_ID:.*]] = call i32 @__kmpc_global_thread_num(ptr {{.*}}) +// CHECK: call void @__kmpc_barrier(ptr {{.*}}, i32 %[[THREAD_ID]]) +// CHECK: br label %omp_loop.after + + omp.wsloop linear(%x = %step : !llvm.ptr) { + omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { + omp.yield + } + } {linear_var_types = [i32]} + llvm.return +} + +// ----- + // CHECK: %struct.ident_t = type // CHECK: @[[$loc:.*]] = private unnamed_addr constant {{.*}} c";unknown;unknown;{{[0-9]+}};{{[0-9]+}};;\00" // CHECK: @[[$loc_struct:.*]] = private unnamed_addr constant %struct.ident_t {{.*}} @[[$loc]] {{.*}} @@ -695,6 +741,34 @@ llvm.func @simd_simple(%lb : i64, %ub : i64, %step : i64, %arg0: !llvm.ptr) { // ----- +llvm.func @simd_linear(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) { + +// CHECK-LABEL: @simd_linear + +// CHECK: %[[LINEAR_VAR:.*]] = alloca i32, align 4 +// CHECK: %[[LINEAR_RESULT:.*]] = alloca i32, align 4 + +// CHECK: omp_loop.preheader: +// CHECK: %[[LOAD:.*]] = load i32, ptr {{.*}}, align 4 +// CHECK: store i32 %[[LOAD]], ptr %[[LINEAR_VAR]], align 4 + +// CHECK: omp_loop.body: +// CHECK: %[[LOOP_IV_CALC:.*]] = mul i32 %omp_loop.iv, {{.*}} +// CHECK: %[[ADD:.*]] = add i32 %[[LOOP_IV_CALC]], {{.*}} +// CHECK: %[[LOAD:.*]] = load i32, ptr %[[LINEAR_VAR]], align 4, !llvm.access.group !1 +// CHECK: %[[MUL:.*]] = mul i32 %omp_loop.iv, {{.*}} +// CHECK: %[[ADD:.*]] = add i32 %[[LOAD]], %[[MUL]] +// CHECK: store i32 %[[ADD]], ptr %[[LINEAR_RESULT]], align 4, !llvm.access.group !1 + omp.simd linear(%x = %step : !llvm.ptr) { + omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { + omp.yield + } + } {linear_var_types = [i32]} + llvm.return +} + +// ----- + // CHECK-LABEL: @simd_simple_multiple llvm.func @simd_simple_multiple(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64, %ub2 : i64, %step2 : i64, %arg0: !llvm.ptr, %arg1: !llvm.ptr) { omp.simd { diff --git a/mlir/test/Target/LLVMIR/openmp-outline-infinite-loop.mlir b/mlir/test/Target/LLVMIR/openmp-outline-infinite-loop.mlir index faccfc6..99f37c7 100644 --- a/mlir/test/Target/LLVMIR/openmp-outline-infinite-loop.mlir +++ b/mlir/test/Target/LLVMIR/openmp-outline-infinite-loop.mlir @@ -21,9 +21,11 @@ llvm.func @parallel_infinite_loop() -> () { // CHECK: omp.region.cont: ; No predecessors! // CHECK: br label %[[VAL_4:.*]] // CHECK: omp.par.pre_finalize: ; preds = %[[VAL_5:.*]] -// CHECK: br label %[[VAL_6:.*]] -// CHECK: omp.par.exit: ; preds = %[[VAL_4]] +// CHECK: br label %[[FINI:.*]] +// CHECK: [[OMP_PAR_EXIT:omp.par.exit]]: ; preds = %[[FINI]] // CHECK: ret void +// CHECK: [[FINI]]: +// CHECK: br label %[[OMP_PAR_EXIT]] // CHECK: } // CHECK-LABEL: define internal void @parallel_infinite_loop..omp_par( diff --git a/mlir/test/Target/LLVMIR/openmp-parallel-reduction-multiblock.mlir b/mlir/test/Target/LLVMIR/openmp-parallel-reduction-multiblock.mlir index 887d297..c79c369 100644 --- a/mlir/test/Target/LLVMIR/openmp-parallel-reduction-multiblock.mlir +++ b/mlir/test/Target/LLVMIR/openmp-parallel-reduction-multiblock.mlir @@ -108,6 +108,8 @@ llvm.func @missordered_blocks_(%arg0: !llvm.ptr {fir.bindc_name = "x"}, %arg1: ! // CHECK: reduce.finalize: ; preds = %[[VAL_49]], %[[VAL_43]] // CHECK: br label %[[VAL_53:.*]] // CHECK: omp.par.pre_finalize: ; preds = %[[VAL_48]] +// CHECK: br label %[[FINI:.*]] +// CHECK: .fini: // CHECK: %[[VAL_54:.*]] = load ptr, ptr %[[VAL_20]], align 8 // CHECK: %[[VAL_55:.*]] = load ptr, ptr %[[VAL_21]], align 8 // CHECK: br label %[[VAL_56:.*]] @@ -115,5 +117,5 @@ llvm.func @missordered_blocks_(%arg0: !llvm.ptr {fir.bindc_name = "x"}, %arg1: ! // CHECK: br label %[[VAL_38]] // CHECK: omp.reduction.neutral1: ; preds = %[[VAL_25]] // CHECK: br label %[[VAL_30]] -// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_53]] +// CHECK: omp.par.exit.exitStub: ; preds = %[[FINI]] // CHECK: ret void diff --git a/mlir/test/Target/LLVMIR/openmp-reduction-array-sections.mlir b/mlir/test/Target/LLVMIR/openmp-reduction-array-sections.mlir index b302b4b..13f52f0 100644 --- a/mlir/test/Target/LLVMIR/openmp-reduction-array-sections.mlir +++ b/mlir/test/Target/LLVMIR/openmp-reduction-array-sections.mlir @@ -127,8 +127,6 @@ llvm.func @sectionsreduction_(%arg0: !llvm.ptr {fir.bindc_name = "x"}) attribute // CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_36]]) // CHECK: br label %[[VAL_37:.*]] // CHECK: omp_section_loop.after: ; preds = %[[VAL_35]] -// CHECK: br label %[[VAL_38:.*]] -// CHECK: omp_section_loop.aftersections.fini: ; preds = %[[VAL_37]] // CHECK: %[[VAL_39:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_14]], i64 0, i64 0 // CHECK: store ptr %[[VAL_21]], ptr %[[VAL_39]], align 8 // CHECK: %[[VAL_40:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) @@ -137,9 +135,9 @@ llvm.func @sectionsreduction_(%arg0: !llvm.ptr {fir.bindc_name = "x"}) attribute // CHECK: i32 1, label %[[VAL_43:.*]] // CHECK: i32 2, label %[[VAL_44:.*]] // CHECK: ] -// CHECK: reduce.switch.atomic: ; preds = %[[VAL_38]] +// CHECK: reduce.switch.atomic: ; preds = %[[VAL_37]] // CHECK: unreachable -// CHECK: reduce.switch.nonatomic: ; preds = %[[VAL_38]] +// CHECK: reduce.switch.nonatomic: ; preds = %[[VAL_37]] // CHECK: %[[VAL_45:.*]] = load ptr, ptr %[[VAL_21]], align 8 // CHECK: br label %[[VAL_46:.*]] // CHECK: omp.reduction.nonatomic.body: ; preds = %[[VAL_43]] @@ -157,7 +155,7 @@ llvm.func @sectionsreduction_(%arg0: !llvm.ptr {fir.bindc_name = "x"}) attribute // CHECK: omp.reduction.nonatomic.body17: ; preds = %[[VAL_47]] // CHECK: %[[VAL_50]] = sub i64 %[[VAL_49]], 1 // CHECK: br label %[[VAL_47]] -// CHECK: reduce.finalize: ; preds = %[[VAL_53]], %[[VAL_38]] +// CHECK: reduce.finalize: ; preds = %[[VAL_53]], %[[VAL_37]] // CHECK: %[[VAL_55:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) // CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_55]]) // CHECK: %[[VAL_56:.*]] = load ptr, ptr %[[VAL_21]], align 8 @@ -173,7 +171,9 @@ llvm.func @sectionsreduction_(%arg0: !llvm.ptr {fir.bindc_name = "x"}) attribute // CHECK: omp.region.cont: ; preds = %[[VAL_62]] // CHECK: br label %[[VAL_64:.*]] // CHECK: omp.par.pre_finalize: ; preds = %[[VAL_63]] -// CHECK: br label %[[VAL_65:.*]] +// CHECK: br label %[[FINI:.fini.*]] +// CHECK: [[FINI]]: +// CHECK: br label %[[EXIT:.*]] // CHECK: omp.reduction.cleanup21: ; preds = %[[VAL_57]] // CHECK: br label %[[VAL_61]] // CHECK: omp_section_loop.body: ; preds = %[[VAL_32]] @@ -219,5 +219,5 @@ llvm.func @sectionsreduction_(%arg0: !llvm.ptr {fir.bindc_name = "x"}) attribute // CHECK: omp_section_loop.inc: ; preds = %[[VAL_69]] // CHECK: %[[VAL_31]] = add nuw i32 %[[VAL_30]], 1 // CHECK: br label %[[VAL_28]] -// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_64]] +// CHECK: omp.par.exit.exitStub: ; preds = %[[FINI]] // CHECK: ret void diff --git a/mlir/test/Target/LLVMIR/openmp-reduction-init-arg.mlir b/mlir/test/Target/LLVMIR/openmp-reduction-init-arg.mlir index a714ca6..cb30d3b 100644 --- a/mlir/test/Target/LLVMIR/openmp-reduction-init-arg.mlir +++ b/mlir/test/Target/LLVMIR/openmp-reduction-init-arg.mlir @@ -96,8 +96,10 @@ module { // CHECK: reduce.finalize: ; preds = %[[VAL_34]], %[[VAL_28]] // CHECK: br label %[[VAL_38:.*]] // CHECK: omp.par.pre_finalize: ; preds = %[[VAL_33]] +// CHECK: br label %[[FINI:.*]] +// CHECK: [[FINI]]: // CHECK: br label %[[VAL_39:.*]] -// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_38]] +// CHECK: omp.par.exit.exitStub: ; preds = %[[FINI]] // CHECK: ret void // CHECK: %[[VAL_40:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_41:.*]], i64 0, i64 0 // CHECK: %[[VAL_42:.*]] = load ptr, ptr %[[VAL_40]], align 8 diff --git a/mlir/test/Target/LLVMIR/openmp-reduction-sections.mlir b/mlir/test/Target/LLVMIR/openmp-reduction-sections.mlir index 19da6f8..00f6c1b 100644 --- a/mlir/test/Target/LLVMIR/openmp-reduction-sections.mlir +++ b/mlir/test/Target/LLVMIR/openmp-reduction-sections.mlir @@ -86,8 +86,6 @@ llvm.func @sections_(%arg0: !llvm.ptr {fir.bindc_name = "x"}) attributes {fir.in // CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_40]]) // CHECK: br label %[[VAL_41:.*]] // CHECK: omp_section_loop.after: ; preds = %[[VAL_39]] -// CHECK: br label %[[VAL_42:.*]] -// CHECK: omp_section_loop.aftersections.fini: ; preds = %[[VAL_41]] // CHECK: %[[VAL_43:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_21]], i64 0, i64 0 // CHECK: store ptr %[[VAL_20]], ptr %[[VAL_43]], align 8 // CHECK: %[[VAL_44:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) @@ -96,23 +94,25 @@ llvm.func @sections_(%arg0: !llvm.ptr {fir.bindc_name = "x"}) attributes {fir.in // CHECK: i32 1, label %[[VAL_47:.*]] // CHECK: i32 2, label %[[VAL_48:.*]] // CHECK: ] -// CHECK: reduce.switch.atomic: ; preds = %[[VAL_42]] +// CHECK: reduce.switch.atomic: ; preds = %[[VAL_41]] // CHECK: unreachable -// CHECK: reduce.switch.nonatomic: ; preds = %[[VAL_42]] +// CHECK: reduce.switch.nonatomic: ; preds = %[[VAL_41]] // CHECK: %[[VAL_49:.*]] = load float, ptr %[[VAL_11]], align 4 // CHECK: %[[VAL_50:.*]] = load float, ptr %[[VAL_20]], align 4 // CHECK: %[[VAL_51:.*]] = fadd contract float %[[VAL_49]], %[[VAL_50]] // CHECK: store float %[[VAL_51]], ptr %[[VAL_11]], align 4 // CHECK: call void @__kmpc_end_reduce(ptr @1, i32 %[[VAL_44]], ptr @.gomp_critical_user_.reduction.var) // CHECK: br label %[[VAL_46]] -// CHECK: reduce.finalize: ; preds = %[[VAL_47]], %[[VAL_42]] +// CHECK: reduce.finalize: ; preds = %[[VAL_47]], %[[VAL_41]] // CHECK: %[[VAL_52:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) // CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_52]]) // CHECK: br label %[[VAL_53:.*]] // CHECK: omp.region.cont: ; preds = %[[VAL_46]] // CHECK: br label %[[VAL_54:.*]] // CHECK: omp.par.pre_finalize: ; preds = %[[VAL_53]] -// CHECK: br label %[[VAL_55:.*]] +// CHECK: br label %[[FINI:.fini.*]] +// CHECK: [[FINI]]: +// CHECK: br label %[[EXIT:.*]] // CHECK: omp_section_loop.body: ; preds = %[[VAL_36]] // CHECK: %[[VAL_56:.*]] = add i32 %[[VAL_34]], %[[VAL_28]] // CHECK: %[[VAL_57:.*]] = mul i32 %[[VAL_56]], 1 @@ -144,8 +144,10 @@ llvm.func @sections_(%arg0: !llvm.ptr {fir.bindc_name = "x"}) attributes {fir.in // CHECK: omp_section_loop.inc: ; preds = %[[VAL_59]] // CHECK: %[[VAL_35]] = add nuw i32 %[[VAL_34]], 1 // CHECK: br label %[[VAL_32]] -// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_54]] +// CHECK: omp.par.exit.exitStub: ; preds = %[[FINI]] // CHECK: ret void + +// CHECK-LABEL: define internal void @.omp.reduction.func // CHECK: %[[VAL_70:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_71:.*]], i64 0, i64 0 // CHECK: %[[VAL_72:.*]] = load ptr, ptr %[[VAL_70]], align 8 // CHECK: %[[VAL_73:.*]] = load float, ptr %[[VAL_72]], align 4 diff --git a/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir index 504d91b..5c37817 100644 --- a/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir +++ b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir @@ -102,7 +102,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo // DEVICE: call void @__kmpc_distribute_static_loop{{.*}}({{.*}}, ptr @[[DISTRIBUTE_OUTLINE:[^,]*]], {{.*}}) // DEVICE: define internal void @[[DISTRIBUTE_OUTLINE]]({{.*}}) -// DEVICE: call void @__kmpc_parallel_51(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}}) +// DEVICE: call void @__kmpc_parallel_60(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}}, i32 {{.*}}) // DEVICE: define internal void @[[PARALLEL_OUTLINE]]({{.*}}) // DEVICE: call void @__kmpc_for_static_loop{{.*}}({{.*}}) diff --git a/mlir/test/Target/LLVMIR/openmp-target-spmd.mlir b/mlir/test/Target/LLVMIR/openmp-target-spmd.mlir index 20202fc..dae80ba 100644 --- a/mlir/test/Target/LLVMIR/openmp-target-spmd.mlir +++ b/mlir/test/Target/LLVMIR/openmp-target-spmd.mlir @@ -84,7 +84,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo // DEVICE: call void @__kmpc_target_deinit() // DEVICE: define internal void @[[TARGET_OUTLINE]]({{.*}}) -// DEVICE: call void @__kmpc_parallel_51(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}}) +// DEVICE: call void @__kmpc_parallel_60(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}}, i32 {{.*}}) // DEVICE: define internal void @[[PARALLEL_OUTLINE]]({{.*}}) // DEVICE: call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}}) diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir index af6d254..396c57a 100644 --- a/mlir/test/Target/LLVMIR/openmp-todo.mlir +++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir @@ -39,19 +39,6 @@ llvm.func @distribute_allocate(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr // ----- -llvm.func @distribute_dist_schedule(%lb : i32, %ub : i32, %step : i32, %x : i32) { - // expected-error@below {{not yet implemented: Unhandled clause dist_schedule with chunk_size in omp.distribute operation}} - // expected-error@below {{LLVM Translation failed for operation: omp.distribute}} - omp.distribute dist_schedule_static dist_schedule_chunk_size(%x : i32) { - omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { - omp.yield - } - } - llvm.return -} - -// ----- - llvm.func @distribute_order(%lb : i32, %ub : i32, %step : i32) { // expected-error@below {{not yet implemented: Unhandled clause order in omp.distribute operation}} // expected-error@below {{LLVM Translation failed for operation: omp.distribute}} @@ -116,19 +103,6 @@ llvm.func @sections_private(%x : !llvm.ptr) { // ----- -llvm.func @simd_linear(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) { - // expected-error@below {{not yet implemented: Unhandled clause linear in omp.simd operation}} - // expected-error@below {{LLVM Translation failed for operation: omp.simd}} - omp.simd linear(%x = %step : !llvm.ptr) { - omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { - omp.yield - } - } - llvm.return -} - -// ----- - omp.declare_reduction @add_f32 : f32 init { ^bb0(%arg: f32): @@ -238,17 +212,6 @@ llvm.func @target_in_reduction(%x : !llvm.ptr) { // ----- -llvm.func @target_is_device_ptr(%x : !llvm.ptr) { - // expected-error@below {{not yet implemented: Unhandled clause is_device_ptr in omp.target operation}} - // expected-error@below {{LLVM Translation failed for operation: omp.target}} - omp.target is_device_ptr(%x : !llvm.ptr) { - omp.terminator - } - llvm.return -} - -// ----- - llvm.func @target_enter_data_depend(%x: !llvm.ptr) { // expected-error@below {{not yet implemented: Unhandled clause depend in omp.target_enter_data operation}} // expected-error@below {{LLVM Translation failed for operation: omp.target_enter_data}} @@ -448,19 +411,6 @@ llvm.func @wsloop_allocate(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) { } // ----- - -llvm.func @wsloop_linear(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) { - // expected-error@below {{not yet implemented: Unhandled clause linear in omp.wsloop operation}} - // expected-error@below {{LLVM Translation failed for operation: omp.wsloop}} - omp.wsloop linear(%x = %step : !llvm.ptr) { - omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { - omp.yield - } - } - llvm.return -} - -// ----- llvm.func @wsloop_order(%lb : i32, %ub : i32, %step : i32) { // expected-error@below {{not yet implemented: Unhandled clause order in omp.wsloop operation}} // expected-error@below {{LLVM Translation failed for operation: omp.wsloop}} diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir index 3fbd9e0..2c748ad5 100644 --- a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir @@ -14,30 +14,36 @@ llvm.func @rocdl_special_regs() -> i32 { %5 = rocdl.workgroup.id.y : i32 // CHECK: call i32 @llvm.amdgcn.workgroup.id.z() %6 = rocdl.workgroup.id.z : i32 + // CHECK: call i32 @llvm.amdgcn.cluster.id.x() + %7 = rocdl.cluster.id.x : i32 + // CHECK: call i32 @llvm.amdgcn.cluster.id.y() + %8 = rocdl.cluster.id.y : i32 + // CHECK: call i32 @llvm.amdgcn.cluster.id.z() + %9 = rocdl.cluster.id.z : i32 // CHECK: call i64 @__ockl_get_local_size(i32 0) - %7 = rocdl.workgroup.dim.x : i64 + %10 = rocdl.workgroup.dim.x : i64 // CHECK: call i64 @__ockl_get_local_size(i32 1) - %8 = rocdl.workgroup.dim.y : i64 + %11 = rocdl.workgroup.dim.y : i64 // CHECK: call i64 @__ockl_get_local_size(i32 2) - %9 = rocdl.workgroup.dim.z : i64 + %12 = rocdl.workgroup.dim.z : i64 // CHECK: call i64 @__ockl_get_num_groups(i32 0) - %10 = rocdl.grid.dim.x : i64 + %13 = rocdl.grid.dim.x : i64 // CHECK: call i64 @__ockl_get_num_groups(i32 1) - %11 = rocdl.grid.dim.y : i64 + %14 = rocdl.grid.dim.y : i64 // CHECK: call i64 @__ockl_get_num_groups(i32 2) - %12 = rocdl.grid.dim.z : i64 + %15 = rocdl.grid.dim.z : i64 // CHECK: call range(i32 0, 64) i32 @llvm.amdgcn.workitem.id.x() - %13 = rocdl.workitem.id.x range <i32, 0, 64> : i32 + %16 = rocdl.workitem.id.x range <i32, 0, 64> : i32 // CHECK: call range(i64 1, 65) i64 @__ockl_get_local_size(i32 0) - %14 = rocdl.workgroup.dim.x range <i32, 1, 65> : i64 + %17 = rocdl.workgroup.dim.x range <i32, 1, 65> : i64 // CHECK: call i32 @llvm.amdgcn.wavefrontsize() - %15 = rocdl.wavefrontsize : i32 + %18 = rocdl.wavefrontsize : i32 // CHECK: call range(i32 32, 65) i32 @llvm.amdgcn.wavefrontsize() - %16 = rocdl.wavefrontsize range <i32, 32, 65> : i32 + %19 = rocdl.wavefrontsize range <i32, 32, 65> : i32 llvm.return %1 : i32 } @@ -55,6 +61,59 @@ llvm.func @kernel_func_workgroups() llvm.return } +llvm.func @kernel_math_ops(%a: f32, %b: f16, %c: bf16) { + // CHECK-LABEL: kernel_math_ops + // CHECK: call float @llvm.amdgcn.tanh.f32(float %{{.*}}) + // CHECK: call half @llvm.amdgcn.tanh.f16(half %{{.*}}) + // CHECK: call bfloat @llvm.amdgcn.tanh.bf16(bfloat %{{.*}}) + %tanh0 = rocdl.tanh %a f32 -> f32 + %tanh1 = rocdl.tanh %b f16 -> f16 + %tanh2 = rocdl.tanh %c bf16 -> bf16 + + // CHECK: call float @llvm.amdgcn.sin.f32(float %{{.*}}) + // CHECK: call half @llvm.amdgcn.sin.f16(half %{{.*}}) + // CHECK: call bfloat @llvm.amdgcn.sin.bf16(bfloat %{{.*}}) + %sin0 = rocdl.sin %a f32 -> f32 + %sin1 = rocdl.sin %b f16 -> f16 + %sin2 = rocdl.sin %c bf16 -> bf16 + + // CHECK: call float @llvm.amdgcn.cos.f32(float %{{.*}}) + // CHECK: call half @llvm.amdgcn.cos.f16(half %{{.*}}) + // CHECK: call bfloat @llvm.amdgcn.cos.bf16(bfloat %{{.*}}) + %cos0 = rocdl.cos %a f32 -> f32 + %cos1 = rocdl.cos %b f16 -> f16 + %cos2 = rocdl.cos %c bf16 -> bf16 + + // CHECK: call float @llvm.amdgcn.rcp.f32(float %{{.*}}) + // CHECK: call half @llvm.amdgcn.rcp.f16(half %{{.*}}) + // CHECK: call bfloat @llvm.amdgcn.rcp.bf16(bfloat %{{.*}}) + %rcp0 = rocdl.rcp %a f32 -> f32 + %rcp1 = rocdl.rcp %b f16 -> f16 + %rcp2 = rocdl.rcp %c bf16 -> bf16 + + // CHECK: call float @llvm.amdgcn.exp2.f32(float %{{.*}}) + // CHECK: call half @llvm.amdgcn.exp2.f16(half %{{.*}}) + // CHECK: call bfloat @llvm.amdgcn.exp2.bf16(bfloat %{{.*}}) + %exp2_0 = rocdl.exp2 %a f32 -> f32 + %exp2_1 = rocdl.exp2 %b f16 -> f16 + %exp2_2 = rocdl.exp2 %c bf16 -> bf16 + + // CHECK: call float @llvm.amdgcn.log.f32(float %{{.*}}) + // CHECK: call half @llvm.amdgcn.log.f16(half %{{.*}}) + // CHECK: call bfloat @llvm.amdgcn.log.bf16(bfloat %{{.*}}) + %log0 = rocdl.log %a f32 -> f32 + %log1 = rocdl.log %b f16 -> f16 + %log2 = rocdl.log %c bf16 -> bf16 + + // CHECK: call float @llvm.amdgcn.sqrt.f32(float %{{.*}}) + // CHECK: call half @llvm.amdgcn.sqrt.f16(half %{{.*}}) + // CHECK: call bfloat @llvm.amdgcn.sqrt.bf16(bfloat %{{.*}}) + %sqrt0 = rocdl.sqrt %a f32 -> f32 + %sqrt1 = rocdl.sqrt %b f16 -> f16 + %sqrt2 = rocdl.sqrt %c bf16 -> bf16 + llvm.return +} + llvm.func @known_block_sizes() attributes {rocdl.kernel, rocdl.flat_work_group_size = "128,128", @@ -248,6 +307,13 @@ llvm.func @rocdl.s.get.barrier.state() { llvm.return } +llvm.func @rocdl.s.get.named.barrier.state(%ptr : !llvm.ptr<3>) { + // CHECK-LABEL: rocdl.s.get.named.barrier.state + // CHECK: %[[STATE:.+]] = call i32 @llvm.amdgcn.s.get.named.barrier.state(ptr addrspace(3) %[[PTR:.+]]) + %0 = rocdl.s.get.named.barrier.state %ptr : i32 + llvm.return +} + llvm.func @rocdl.s.wait.dscnt() { // CHECK-LABEL: rocdl.s.wait.dscnt // CHECK-NEXT: call void @llvm.amdgcn.s.wait.dscnt(i16 0) @@ -875,140 +941,182 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v %arg4 : vector<2xi32>, %arg5 : vector<4xi32>, %arg6 : vector<4xf32>, %arg7 : vector<8xf16>, %arg8 : vector<8xi16>, %arg9 : vector<32xf16>, %arg10 : vector<16xf32>, %arg11 : vector<4xf32>, %arg12 : vector<32xf32>, %arg13 : vector<64xf32>, %arg14 : vector<64xi32>, %arg15 : vector<64xf16>, %arg16 : vector<16xbf16>, %arg17 : vector<32xbf16>) -> vector<8xf32> { - %zero = llvm.mlir.constant(false) : i1 - %zero_i16 = llvm.mlir.constant(0 : i16) : i16 - // ---- Wave32 ----- + // ---- Wave32 ----- // f16 -> f32 - // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16.v8f32.v16f16(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <8 x float> %{{.*}}) + // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16.v8f32.v16f16(<16 x half> %{{.*}} <16 x half> %{{.*}} <8 x float> %{{.*}}) %r0 = rocdl.wmma.f32.16x16x16.f16 %arg1, %arg1, %arg0 : (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32> // bf16 -> f32 - // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16.v8f32.v16i16(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <8 x float> %{{.*}}) + // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16.v8f32.v16i16(<16 x i16> %{{.*}} <16 x i16> %{{.*}} <8 x float> %{{.*}}) %r1 = rocdl.wmma.f32.16x16x16.bf16 %arg2, %arg2, %arg0 : (vector<16xi16>, vector<16xi16>, vector<8xf32>) -> vector<8xf32> // f16 -> f16 (OPSEL = {0,1}) - // CHECK: call <16 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16.v16f16.v16f16(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <16 x half> %{{.*}}, i1 {{.*}}) - %r2 = rocdl.wmma.f16.16x16x16.f16 %arg1, %arg1, %arg1, %zero : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16> + // CHECK: call <16 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16.v16f16.v16f16(<16 x half> %{{.*}} <16 x half> %{{.*}} <16 x half> %{{.*}} i1 false) + %r2 = rocdl.wmma.f16.16x16x16.f16 %arg1, %arg1, %arg1 {opsel = false} : (vector<16xf16>, vector<16xf16>, vector<16xf16>) -> vector<16xf16> // bf16 -> bf16 (OPSEL = {0,1}) - // CHECK: call <16 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16.v16i16.v16i16(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <16 x i16> %{{.*}}, i1 {{.*}}) - %r4 = rocdl.wmma.bf16.16x16x16.bf16 %arg2, %arg2, %arg2, %zero : (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16> + // CHECK: call <16 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16.v16i16.v16i16(<16 x i16> %{{.*}} <16 x i16> %{{.*}} <16 x i16> %{{.*}} i1 false) + %r4 = rocdl.wmma.bf16.16x16x16.bf16 %arg2, %arg2, %arg2 {opsel = false} : (vector<16xi16>, vector<16xi16>, vector<16xi16>) -> vector<16xi16> // int8 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1}) - // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v8i32.v4i32(i1 {{.*}}, <4 x i32> %{{.*}}, i1 {{.*}}, <4 x i32> %{{.*}}, <8 x i32> %{{.*}}, i1 {{.*}}) - %r5 = rocdl.wmma.i32.16x16x16.iu8 %zero, %arg5, %zero, %arg5, %arg3, %zero : (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32> + // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v8i32.v4i32(i1 false, <4 x i32> %{{.*}} i1 false, <4 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false) + %r5 = rocdl.wmma.i32.16x16x16.iu8 %arg5, %arg5, %arg3 {signA = false, signB = false, clamp = false} : (vector<4xi32>, vector<4xi32>, vector<8xi32>) -> vector<8xi32> // int4 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1}) - // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v8i32.v2i32(i1 {{.*}}, <2 x i32> %{{.*}}, i1 {{.*}}, <2 x i32> %{{.*}}, <8 x i32> %{{.*}}, i1 {{.*}}) - %r6 = rocdl.wmma.i32.16x16x16.iu4 %zero, %arg4, %zero, %arg4, %arg3, %zero : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32> + // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v8i32.v2i32(i1 false, <2 x i32> %{{.*}} i1 false, <2 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false) + %r6 = rocdl.wmma.i32.16x16x16.iu4 %arg4, %arg4, %arg3 {signA = false, signB = false, clamp = false} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32> // int4 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1}) - // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x32.iu4.v8i32.v2i32(i1 {{.*}}, <2 x i32> %{{.*}}, i1 {{.*}}, <2 x i32> %{{.*}}, <8 x i32> %{{.*}}, i1 {{.*}}) - %r6.gfx12 = rocdl.wmma.i32.16x16x32.iu4 %zero, %arg4, %zero, %arg4, %arg3, %zero : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32> + // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x32.iu4.v8i32.v2i32(i1 false, <2 x i32> %{{.*}} i1 false, <2 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false) + %r6.gfx12 = rocdl.wmma.i32.16x16x32.iu4 %arg4, %arg4, %arg3 {signA = false, signB = false, clamp = false} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32> + + // Test signA=true, signB=false for iu8 + // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v8i32.v4i32(i1 true, <4 x i32> %{{.*}} i1 false, <4 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false) + %r5a = rocdl.wmma.i32.16x16x16.iu8 %arg5, %arg5, %arg3 {signA = true, signB = false, clamp = false} : (vector<4xi32>, vector<4xi32>, vector<8xi32>) -> vector<8xi32> + + // Test signA=false, signB=true for iu8 + // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v8i32.v4i32(i1 false, <4 x i32> %{{.*}} i1 true, <4 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false) + %r5b = rocdl.wmma.i32.16x16x16.iu8 %arg5, %arg5, %arg3 {signA = false, signB = true, clamp = false} : (vector<4xi32>, vector<4xi32>, vector<8xi32>) -> vector<8xi32> + + // Test signA=true, signB=true, clamp=true for iu8 + // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v8i32.v4i32(i1 true, <4 x i32> %{{.*}} i1 true, <4 x i32> %{{.*}} <8 x i32> %{{.*}} i1 true) + %r5c = rocdl.wmma.i32.16x16x16.iu8 %arg5, %arg5, %arg3 {signA = true, signB = true, clamp = true} : (vector<4xi32>, vector<4xi32>, vector<8xi32>) -> vector<8xi32> + + // Test signA=true, signB=false for iu4 + // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v8i32.v2i32(i1 true, <2 x i32> %{{.*}} i1 false, <2 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false) + %r6a = rocdl.wmma.i32.16x16x16.iu4 %arg4, %arg4, %arg3 {signA = true, signB = false, clamp = false} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32> + + // Test signA=false, signB=true, clamp=true for iu4 + // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v8i32.v2i32(i1 false, <2 x i32> %{{.*}} i1 true, <2 x i32> %{{.*}} <8 x i32> %{{.*}} i1 true) + %r6b = rocdl.wmma.i32.16x16x16.iu4 %arg4, %arg4, %arg3 {signA = false, signB = true, clamp = true} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32> + + // Test signA=true, signB=true for iu4 gfx12 + // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x32.iu4.v8i32.v2i32(i1 true, <2 x i32> %{{.*}} i1 true, <2 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false) + %r6c = rocdl.wmma.i32.16x16x32.iu4 %arg4, %arg4, %arg3 {signA = true, signB = true, clamp = false} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32> // f32 -> f32 - // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x4.f32.v4f32.v16f32(i1 {{.*}}, <16 x float> %{{.*}}, i1 {{.*}}, <16 x float> %{{.*}}, i16 0, <4 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r1.gfx1250 = rocdl.wmma.f32.16x16x4.f32 %zero, %arg10, %zero, %arg10, %zero_i16, %arg11, %zero, %zero : (i1, vector<16xf32>, i1, vector<16xf32>, i16, vector<4xf32>, i1, i1) -> vector<4xf32> + // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x4.f32.v4f32.v16f32(i1 false, <16 x float> %{{.*}} i1 false, <16 x float> %{{.*}} i16 0, <4 x float> %{{.*}} i1 false, i1 false) + %r1.gfx1250 = rocdl.wmma.f32.16x16x4.f32 %arg10, %arg10, %arg11 {signA = false, signB = false, modC = 0 : i16} : (vector<16xf32>, vector<16xf32>, vector<4xf32>) -> vector<4xf32> // f16 -> f32 - // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v32f32.v16f16(i1 {{.*}}, <16 x half> %{{.*}}, i1 {{.*}}, <16 x half> %{{.*}}, i16 0, <32 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r2.gfx1250 = rocdl.wmma.f32.16x16x32.f16 %zero, %arg1, %zero, %arg1, %zero_i16, %arg12, %zero, %zero : (i1, vector<16xf16>, i1, vector<16xf16>, i16, vector<32xf32>, i1, i1) -> vector<32xf32> + // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v32f32.v16f16(i1 false, <16 x half> %{{.*}} i1 false, <16 x half> %{{.*}} i16 0, <32 x float> %{{.*}} i1 false, i1 false) + %r2.gfx1250 = rocdl.wmma.f32.16x16x32.f16 %arg1, %arg1, %arg12 {signA = false, signB = false, modC = 0 : i16} : (vector<16xf16>, vector<16xf16>, vector<32xf32>) -> vector<32xf32> // bf16 -> f32 - // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.bf16.v32f32.v16bf16(i1 {{.*}}, <16 x bfloat> %{{.*}}, i1 {{.*}}, <16 x bfloat> %{{.*}}, i16 0, <32 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r3.gfx1250 = rocdl.wmma.f32.16x16x32.bf16 %zero, %arg16, %zero, %arg16, %zero_i16, %arg12, %zero, %zero : (i1, vector<16xbf16>, i1, vector<16xbf16>, i16, vector<32xf32>, i1, i1) -> vector<32xf32> + // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.bf16.v32f32.v16bf16(i1 false, <16 x bfloat> %{{.*}} i1 false, <16 x bfloat> %{{.*}} i16 0, <32 x float> %{{.*}} i1 false, i1 false) + %r3.gfx1250 = rocdl.wmma.f32.16x16x32.bf16 %arg16, %arg16, %arg12 {signA = false, signB = false, modC = 0 : i16} : (vector<16xbf16>, vector<16xbf16>, vector<32xf32>) -> vector<32xf32> // f16 -> f16 - // CHECK: call <32 x half> @llvm.amdgcn.wmma.f16.16x16x32.f16.v32f16.v16f16(i1 {{.*}}, <16 x half> %{{.*}}, i1 {{.*}}, <16 x half> %{{.*}}, i16 0, <32 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r4.gfx1250 = rocdl.wmma.f16.16x16x32.f16 %zero, %arg1, %zero, %arg1, %zero_i16, %arg9, %zero, %zero : (i1, vector<16xf16>, i1, vector<16xf16>, i16, vector<32xf16>, i1, i1) -> vector<32xf16> + // CHECK: call <32 x half> @llvm.amdgcn.wmma.f16.16x16x32.f16.v32f16.v16f16(i1 false, <16 x half> %{{.*}} i1 false, <16 x half> %{{.*}} i16 0, <32 x half> %{{.*}} i1 false, i1 false) + %r4.gfx1250 = rocdl.wmma.f16.16x16x32.f16 %arg1, %arg1, %arg9 {signA = false, signB = false, modC = 0 : i16} : (vector<16xf16>, vector<16xf16>, vector<32xf16>) -> vector<32xf16> // bf16 -> bf16 - // CHECK: call <32 x bfloat> @llvm.amdgcn.wmma.bf16.16x16x32.bf16.v32bf16.v16bf16(i1 {{.*}}, <16 x bfloat> %{{.*}}, i1 {{.*}}, <16 x bfloat> %{{.*}}, i16 0, <32 x bfloat> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r5.gfx1250 = rocdl.wmma.bf16.16x16x32.bf16 %zero, %arg16, %zero, %arg16, %zero_i16, %arg17, %zero, %zero : (i1, vector<16xbf16>, i1, vector<16xbf16>, i16, vector<32xbf16>, i1, i1) -> vector<32xbf16> + // CHECK: call <32 x bfloat> @llvm.amdgcn.wmma.bf16.16x16x32.bf16.v32bf16.v16bf16(i1 false, <16 x bfloat> %{{.*}} i1 false, <16 x bfloat> %{{.*}} i16 0, <32 x bfloat> %{{.*}} i1 false, i1 false) + %r5.gfx1250 = rocdl.wmma.bf16.16x16x32.bf16 %arg16, %arg16, %arg17 {signA = false, signB = false, modC = 0 : i16} : (vector<16xbf16>, vector<16xbf16>, vector<32xbf16>) -> vector<32xbf16> // bf16 -> bf16 / f32 - // CHECK: call <32 x bfloat> @llvm.amdgcn.wmma.bf16f32.16x16x32.bf16.v32bf16.v16bf16.v32f32(i1 {{.*}}, <16 x bfloat> %{{.*}}, i1 {{.*}}, <16 x bfloat> %{{.*}}, i16 0, <32 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r6.gfx1250 = rocdl.wmma.bf16f32.16x16x32.bf16 %zero, %arg16, %zero, %arg16, %zero_i16, %arg12, %zero, %zero : (i1, vector<16xbf16>, i1, vector<16xbf16>, i16, vector<32xf32>, i1, i1) -> vector<32xbf16> + // CHECK: call <32 x bfloat> @llvm.amdgcn.wmma.bf16f32.16x16x32.bf16.v32bf16.v16bf16.v32f32(i1 false, <16 x bfloat> %{{.*}} i1 false, <16 x bfloat> %{{.*}} i16 0, <32 x float> %{{.*}} i1 false, i1 false) + %r6.gfx1250 = rocdl.wmma.bf16f32.16x16x32.bf16 %arg16, %arg16, %arg12 {signA = false, signB = false, modC = 0 : i16} : (vector<16xbf16>, vector<16xbf16>, vector<32xf32>) -> vector<32xbf16> // f8/bf8 -> f16/f32 - // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r7.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_fp8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32> + // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.fp8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false) + %r7.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_fp8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32> - // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r8.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_bf8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32> + // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.bf8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false) + %r8.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_bf8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32> - // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.bf8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r9.gfx1250 = rocdl.wmma.f32.16x16x64.bf8_fp8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32> + // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.bf8.fp8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false) + %r9.gfx1250 = rocdl.wmma.f32.16x16x64.bf8_fp8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32> - // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.bf8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r10.gfx1250 = rocdl.wmma.f32.16x16x64.bf8_bf8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32> + // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.bf8.bf8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false) + %r10.gfx1250 = rocdl.wmma.f32.16x16x64.bf8_bf8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32> - // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.fp8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r11.gfx1250 = rocdl.wmma.f16.16x16x64.fp8_fp8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16> + // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.fp8.fp8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false) + %r11.gfx1250 = rocdl.wmma.f16.16x16x64.fp8_fp8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16> - // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.fp8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r12.gfx1250 = rocdl.wmma.f16.16x16x64.fp8_bf8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16> + // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.fp8.bf8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false) + %r12.gfx1250 = rocdl.wmma.f16.16x16x64.fp8_bf8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16> - // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.bf8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r13.gfx1250 = rocdl.wmma.f16.16x16x64.bf8_fp8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16> + // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.bf8.fp8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false) + %r13.gfx1250 = rocdl.wmma.f16.16x16x64.bf8_fp8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16> - // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.bf8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r14.gfx1250 = rocdl.wmma.f16.16x16x64.bf8_bf8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16> + // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.bf8.bf8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false) + %r14.gfx1250 = rocdl.wmma.f16.16x16x64.bf8_bf8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16> - // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.fp8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r15.gfx1250 = rocdl.wmma.f32.16x16x128.fp8_fp8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32> + // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.fp8.fp8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false) + %r15.gfx1250 = rocdl.wmma.f32.16x16x128.fp8_fp8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32> - // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.fp8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r16.gfx1250 = rocdl.wmma.f32.16x16x128.fp8_bf8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32> + // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.fp8.bf8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false) + %r16.gfx1250 = rocdl.wmma.f32.16x16x128.fp8_bf8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32> - // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.bf8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r17.gfx1250 = rocdl.wmma.f32.16x16x128.bf8_fp8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32> + // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.bf8.fp8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false) + %r17.gfx1250 = rocdl.wmma.f32.16x16x128.bf8_fp8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32> - // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.bf8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r18.gfx1250 = rocdl.wmma.f32.16x16x128.bf8_bf8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32> + // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.bf8.bf8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false) + %r18.gfx1250 = rocdl.wmma.f32.16x16x128.bf8_bf8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32> - // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.fp8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r19.gfx1250 = rocdl.wmma.f16.16x16x128.fp8_fp8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16> + // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.fp8.fp8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false) + %r19.gfx1250 = rocdl.wmma.f16.16x16x128.fp8_fp8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16> - // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.fp8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r20.gfx1250 = rocdl.wmma.f16.16x16x128.fp8_bf8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16> + // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.fp8.bf8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false) + %r20.gfx1250 = rocdl.wmma.f16.16x16x128.fp8_bf8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16> - // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.bf8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r21.gfx1250 = rocdl.wmma.f16.16x16x128.bf8_fp8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16> + // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.bf8.fp8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false) + %r21.gfx1250 = rocdl.wmma.f16.16x16x128.bf8_fp8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16> - // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.bf8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r22.gfx1250 = rocdl.wmma.f16.16x16x128.bf8_bf8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16> + // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.bf8.bf8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false) + %r22.gfx1250 = rocdl.wmma.f16.16x16x128.bf8_bf8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16> // iu8 -> i32 - // CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 {{.*}}, <4 x i32> %{{.*}}, i1 {{.*}}, <4 x i32> %{{.*}}, <64 x i32> %{{.*}}, i1 {{.*}}, i1 {{.*}}) - %r23.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %zero, %arg5, %zero, %arg5, %arg14, %zero, %zero : (i1, vector<4xi32>, i1, vector<4xi32>, vector<64xi32>, i1, i1) -> vector<64xi32> + // CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 false, <4 x i32> %{{.*}} i1 false, <4 x i32> %{{.*}} <64 x i32> %{{.*}} i1 false, i1 false) + %r23.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %arg5, %arg5, %arg14 {signA = false, signB = false} : (vector<4xi32>, vector<4xi32>, vector<64xi32>) -> vector<64xi32> + + // Test signA=true, signB=true for iu8 gfx1250 + // CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 true, <4 x i32> %{{.*}} i1 true, <4 x i32> %{{.*}} <64 x i32> %{{.*}} i1 false, i1 false) + %r23a.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %arg5, %arg5, %arg14 {signA = true, signB = true} : (vector<4xi32>, vector<4xi32>, vector<64xi32>) -> vector<64xi32> + + // Test signA=true, signB=false, reuseA=true, reuseB=true for iu8 gfx1250 + // CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 true, <4 x i32> %{{.*}} i1 false, <4 x i32> %{{.*}} <64 x i32> %{{.*}} i1 true, i1 true) + %r23b.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %arg5, %arg5, %arg14 {signA = true, signB = false, reuseA = true, reuseB = true} : (vector<4xi32>, vector<4xi32>, vector<64xi32>) -> vector<64xi32> + + // Test signA=true, signB=true with modC=1 for f32 gfx1250 + // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x4.f32.v4f32.v16f32(i1 true, <16 x float> %{{.*}} i1 true, <16 x float> %{{.*}} i16 1, <4 x float> %{{.*}} i1 false, i1 false) + %r1a.gfx1250 = rocdl.wmma.f32.16x16x4.f32 %arg10, %arg10, %arg11 {signA = true, signB = true, modC = 1 : i16, reuseA = false, reuseB = false} : (vector<16xf32>, vector<16xf32>, vector<4xf32>) -> vector<4xf32> + + // Test with modC=2 and signA=false, signB=true, reuseA=true for f16 gfx1250 + // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v32f32.v16f16(i1 false, <16 x half> %{{.*}} i1 true, <16 x half> %{{.*}} i16 2, <32 x float> %{{.*}} i1 true, i1 false) + %r2a.gfx1250 = rocdl.wmma.f32.16x16x32.f16 %arg1, %arg1, %arg12 {signA = false, signB = true, modC = 2 : i16, reuseA = true, reuseB = false} : (vector<16xf16>, vector<16xf16>, vector<32xf32>) -> vector<32xf32> + + // Test with modC=3 and signA=true, signB=true, reuseB=true for bf16 gfx1250 + // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.bf16.v32f32.v16bf16(i1 true, <16 x bfloat> %{{.*}} i1 true, <16 x bfloat> %{{.*}} i16 3, <32 x float> %{{.*}} i1 false, i1 true) + %r3a.gfx1250 = rocdl.wmma.f32.16x16x32.bf16 %arg16, %arg16, %arg12 {signA = true, signB = true, modC = 3 : i16, reuseA = false, reuseB = true} : (vector<16xbf16>, vector<16xbf16>, vector<32xf32>) -> vector<32xf32> // ---- Wave64 ----- // f16 -> f32 - // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16.v4f32.v16f16(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <4 x float> %{{.*}}) + // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16.v4f32.v16f16(<16 x half> %{{.*}} <16 x half> %{{.*}} <4 x float> %{{.*}}) %r7 = rocdl.wmma.f32.16x16x16.f16 %arg1, %arg1, %arg6 : (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32> // bf16 -> f32 - // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16.v4f32.v16i16(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <4 x float> %{{.*}}) + // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16.v4f32.v16i16(<16 x i16> %{{.*}} <16 x i16> %{{.*}} <4 x float> %{{.*}}) %r8 = rocdl.wmma.f32.16x16x16.bf16 %arg2, %arg2, %arg6 : (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32> // f16 -> f16 (OPSEL = {0,1}) - // CHECK: call <8 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16.v8f16.v16f16(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <8 x half> %{{.*}}, i1 {{.*}}) - %r9 = rocdl.wmma.f16.16x16x16.f16 %arg1, %arg1, %arg7, %zero : (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16> + // CHECK: call <8 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16.v8f16.v16f16(<16 x half> %{{.*}} <16 x half> %{{.*}} <8 x half> %{{.*}} i1 false) + %r9 = rocdl.wmma.f16.16x16x16.f16 %arg1, %arg1, %arg7 {opsel = false} : (vector<16xf16>, vector<16xf16>, vector<8xf16>) -> vector<8xf16> // bf16 -> bf16 (OPSEL = {0,1}) - // CHECK: call <8 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16.v8i16.v16i16(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <8 x i16> %{{.*}}, i1 {{.*}}) - %r11 = rocdl.wmma.bf16.16x16x16.bf16 %arg2, %arg2, %arg8, %zero : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16> + // CHECK: call <8 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16.v8i16.v16i16(<16 x i16> %{{.*}} <16 x i16> %{{.*}} <8 x i16> %{{.*}} i1 false) + %r11 = rocdl.wmma.bf16.16x16x16.bf16 %arg2, %arg2, %arg8 {opsel = false} : (vector<16xi16>, vector<16xi16>, vector<8xi16>) -> vector<8xi16> // int8 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1}) - // CHECK: call <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v4i32.v4i32(i1 {{.*}}, <4 x i32> %{{.*}}, i1 {{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i1 {{.*}}) - %r12 = rocdl.wmma.i32.16x16x16.iu8 %zero, %arg5, %zero, %arg5, %arg5, %zero : (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32> + // CHECK: call <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v4i32.v4i32(i1 false, <4 x i32> %{{.*}} i1 false, <4 x i32> %{{.*}} <4 x i32> %{{.*}} i1 true) + %r12 = rocdl.wmma.i32.16x16x16.iu8 %arg5, %arg5, %arg5 {signA = false, signB = false, clamp = true} : (vector<4xi32>, vector<4xi32>, vector<4xi32>) -> vector<4xi32> // int4 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1}) - // CHECK: call <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v4i32.v2i32(i1 {{.*}}, <2 x i32> %{{.*}}, i1 {{.*}}, <2 x i32> %{{.*}}, <4 x i32> %{{.*}}, i1 {{.*}}) - %r13 = rocdl.wmma.i32.16x16x16.iu4 %zero, %arg4, %zero, %arg4, %arg5, %zero : (i1, vector<2xi32>, i1, vector<2xi32>, vector<4xi32>, i1) -> vector<4xi32> + // CHECK: call <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v4i32.v2i32(i1 false, <2 x i32> %{{.*}} i1 false, <2 x i32> %{{.*}} <4 x i32> %{{.*}} i1 true) + %r13 = rocdl.wmma.i32.16x16x16.iu4 %arg4, %arg4, %arg5 {signA = false, signB = false, clamp = true} : (vector<2xi32>, vector<2xi32>, vector<4xi32>) -> vector<4xi32> llvm.return %r0 : vector<8xf32> } @@ -1028,6 +1136,39 @@ llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> { llvm.return %r3 : vector<4xf16> } +llvm.func @rocdl.load.tr.ops(%gl_ptr : !llvm.ptr<1>, %ds_ptr : !llvm.ptr<3>) { + // CHECK-LABEL: rocdl.load.tr.ops + // CHECK-SAME: (ptr addrspace(1) %[[GL_PTR:.+]], ptr addrspace(3) %[[DS_PTR:.+]]) + // CHECK: call <2 x i32> @llvm.amdgcn.global.load.tr4.b64.v2i32(ptr addrspace(1) %[[GL_PTR]]) + // CHECK: call <2 x i32> @llvm.amdgcn.global.load.tr.b64.v2i32(ptr addrspace(1) %[[GL_PTR]]) + // CHECK: call <3 x i32> @llvm.amdgcn.global.load.tr6.b96.v3i32(ptr addrspace(1) %[[GL_PTR]]) + // CHECK: call <8 x i16> @llvm.amdgcn.global.load.tr.b128.v8i16(ptr addrspace(1) %[[GL_PTR]]) + // CHECK: call <8 x half> @llvm.amdgcn.global.load.tr.b128.v8f16(ptr addrspace(1) %[[GL_PTR]]) + // CHECK: call <8 x bfloat> @llvm.amdgcn.global.load.tr.b128.v8bf16(ptr addrspace(1) %[[GL_PTR]]) + + // CHECK: call <2 x i32> @llvm.amdgcn.ds.load.tr4.b64.v2i32(ptr addrspace(3) %[[DS_PTR]]) + // CHECK: call <2 x i32> @llvm.amdgcn.ds.load.tr8.b64.v2i32(ptr addrspace(3) %[[DS_PTR]]) + // CHECK: call <3 x i32> @llvm.amdgcn.ds.load.tr6.b96.v3i32(ptr addrspace(3) %[[DS_PTR]]) + // CHECK: call <8 x i16> @llvm.amdgcn.ds.load.tr16.b128.v8i16(ptr addrspace(3) %[[DS_PTR]]) + // CHECK: call <8 x half> @llvm.amdgcn.ds.load.tr16.b128.v8f16(ptr addrspace(3) %[[DS_PTR]]) + // CHECK: call <8 x bfloat> @llvm.amdgcn.ds.load.tr16.b128.v8bf16(ptr addrspace(3) %[[DS_PTR]]) + + rocdl.global.load.tr4.b64 %gl_ptr : !llvm.ptr<1> -> vector<2xi32> + rocdl.global.load.tr.b64 %gl_ptr : !llvm.ptr<1> -> vector<2xi32> + rocdl.global.load.tr6.b96 %gl_ptr : !llvm.ptr<1> -> vector<3xi32> + rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xi16> + rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xf16> + rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8xbf16> + + rocdl.ds.load.tr4.b64 %ds_ptr : !llvm.ptr<3> -> vector<2xi32> + rocdl.ds.load.tr8.b64 %ds_ptr : !llvm.ptr<3> -> vector<2xi32> + rocdl.ds.load.tr6.b96 %ds_ptr : !llvm.ptr<3> -> vector<3xi32> + rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xi16> + rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xf16> + rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8xbf16> + llvm.return +} + llvm.func @rocdl.load.to.lds(%src : !llvm.ptr<7>, %dst: !llvm.ptr<3>) { //CHECK: call void @llvm.amdgcn.load.to.lds.p7 rocdl.load.to.lds %src, %dst, 4, 0, 0 : !llvm.ptr<7> @@ -1053,6 +1194,19 @@ llvm.func @rocdl.global.load.async.to.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3 llvm.return } +// CHECK-LABEL: rocdl.cluster.load.async.to.lds +llvm.func @rocdl.cluster.load.async.to.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) { + // CHECK: call void @llvm.amdgcn.cluster.load.async.to.lds.b8 + rocdl.cluster.load.async.to.lds.b8 %src, %dst, 0, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3> + // CHECK: call void @llvm.amdgcn.cluster.load.async.to.lds.b32 + rocdl.cluster.load.async.to.lds.b32 %src, %dst, 0, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3> + // CHECK: call void @llvm.amdgcn.cluster.load.async.to.lds.b64 + rocdl.cluster.load.async.to.lds.b64 %src, %dst, 0, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3> + // CHECK: call void @llvm.amdgcn.cluster.load.async.to.lds.b128 + rocdl.cluster.load.async.to.lds.b128 %src, %dst, 0, 0, 0 : !llvm.ptr<1>, !llvm.ptr<3> + llvm.return +} + // CHECK-LABEL: rocdl.tensor.load.to.lds llvm.func @rocdl.tensor.load.to.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>, %dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) { @@ -1187,6 +1341,113 @@ llvm.func @rocdl.raw.ptr.buffer.load.lds(%rsrc : !llvm.ptr<8>, %dstLds : !llvm.p llvm.return } +llvm.func @rocdl.wmma.scale(%arg0: i32, %arg1: vector<4xf32>, %arg2: vector<8xi32>, + %arg3: vector<12xi32>, %arg5: vector<16xi32>, + %arg8: i64, %arg9: vector<8xf32>) -> vector<4xf32> { + // CHECK-LABEL: rocdl.wmma.scale + + // Test with default attributes (all zeros/false) + // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 0, <16 x i32> %{{.*}}, i32 0, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i1 false, i1 false) + %r00 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg0, %arg0 + {fmtA = 0 : i32, fmtB = 0 : i32, modC = 0 : i16, + scaleAType = 0 : i32, fmtScaleA = 0 : i32, + scaleBType = 0 : i32, fmtScaleB = 0 : i32, + reuseA = false, reuseB = false} : + (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32> + + // Test with different matrix formats (FP8 x BF8) + // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 0, <16 x i32> %{{.*}}, i32 1, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i1 false, i1 false) + %r01 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg0, %arg0 + {fmtA = 0 : i32, fmtB = 1 : i32, modC = 0 : i16, + scaleAType = 1 : i32, fmtScaleA = 1 : i32, + scaleBType = 1 : i32, fmtScaleB = 1 : i32, + reuseA = false, reuseB = false} : + (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32> + + // Test with FP8 x FP6 (different vector sizes) and modC = 1 (negate) + // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v12i32(i32 0, <16 x i32> %{{.*}}, i32 2, <12 x i32> %{{.*}}, i16 1, <4 x float> %{{.*}}, i32 2, i32 2, i32 %{{.*}}, i32 2, i32 2, i32 %{{.*}}, i1 false, i1 false) + %r02 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg3, %arg1, %arg0, %arg0 + {fmtA = 0 : i32, fmtB = 2 : i32, modC = 1 : i16, + scaleAType = 2 : i32, fmtScaleA = 2 : i32, + scaleBType = 2 : i32, fmtScaleB = 2 : i32, + reuseA = false, reuseB = false} : + (vector<16xi32>, vector<12xi32>, vector<4xf32>, i32, i32) -> vector<4xf32> + + // Test with BF8 x BF6 and modC = 2 (abs) + // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v12i32(i32 1, <16 x i32> %{{.*}}, i32 3, <12 x i32> %{{.*}}, i16 2, <4 x float> %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i1 false, i1 false) + %r03 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg3, %arg1, %arg0, %arg0 + {fmtA = 1 : i32, fmtB = 3 : i32, modC = 2 : i16, + scaleAType = 0 : i32, fmtScaleA = 0 : i32, + scaleBType = 0 : i32, fmtScaleB = 0 : i32, + reuseA = false, reuseB = false} : + (vector<16xi32>, vector<12xi32>, vector<4xf32>, i32, i32) -> vector<4xf32> + + // Test with FP8 x FP4 and modC = 3 (negate(abs)) + // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v8i32(i32 0, <16 x i32> %{{.*}}, i32 4, <8 x i32> %{{.*}}, i16 3, <4 x float> %{{.*}}, i32 3, i32 3, i32 %{{.*}}, i32 3, i32 3, i32 %{{.*}}, i1 false, i1 false) + %r04 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg2, %arg1, %arg0, %arg0 + {fmtA = 0 : i32, fmtB = 4 : i32, modC = 3 : i16, + scaleAType = 3 : i32, fmtScaleA = 3 : i32, + scaleBType = 3 : i32, fmtScaleB = 3 : i32, + reuseA = false, reuseB = false} : + (vector<16xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32> + + // Test with reuseA = true + // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 2, <16 x i32> %{{.*}}, i32 2, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i1 true, i1 false) + %r10 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg0, %arg0 + {fmtA = 2 : i32, fmtB = 2 : i32, modC = 0 : i16, + scaleAType = 0 : i32, fmtScaleA = 0 : i32, + scaleBType = 0 : i32, fmtScaleB = 0 : i32, + reuseA = true, reuseB = false} : + (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32> + + // Test with reuseB = true + // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 3, <16 x i32> %{{.*}}, i32 3, <16 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i1 false, i1 true) + %r11 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg0, %arg0 + {fmtA = 3 : i32, fmtB = 3 : i32, modC = 0 : i16, + scaleAType = 0 : i32, fmtScaleA = 0 : i32, + scaleBType = 0 : i32, fmtScaleB = 0 : i32, + reuseA = false, reuseB = true} : + (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32> + + // Test with both reuseA and reuseB = true + // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 4, <16 x i32> %{{.*}}, i32 4, <16 x i32> %{{.*}}, i16 1, <4 x float> %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i1 true, i1 true) + %r12 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg0, %arg0 + {fmtA = 4 : i32, fmtB = 4 : i32, modC = 1 : i16, + scaleAType = 1 : i32, fmtScaleA = 1 : i32, + scaleBType = 1 : i32, fmtScaleB = 1 : i32, + reuseA = true, reuseB = true} : + (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32> + + // Test scale16 variant with i64 scale exponents + // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v16i32.v16i32(i32 0, <16 x i32> %{{.*}}, i32 1, <16 x i32> %{{.*}}, i16 2, <4 x float> %{{.*}}, i32 2, i32 2, i64 %{{.*}}, i32 2, i32 2, i64 %{{.*}}, i1 false, i1 false) + %r_scale16 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %arg5, %arg5, %arg1, %arg8, %arg8 + {fmtA = 0 : i32, fmtB = 1 : i32, modC = 2 : i16, + scaleAType = 2 : i32, fmtScaleA = 2 : i32, + scaleBType = 2 : i32, fmtScaleB = 2 : i32, + reuseA = false, reuseB = false} : + (vector<16xi32>, vector<16xi32>, vector<4xf32>, i64, i64) -> vector<4xf32> + + // Test f4 variant (no matrix format parameters) + // CHECK: call <8 x float> @llvm.amdgcn.wmma.scale.f32.32x16x128.f4.v8f32.v16i32.v8i32(<16 x i32> %{{.*}}, <8 x i32> %{{.*}}, i16 0, <8 x float> %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i32 1, i32 1, i32 %{{.*}}, i1 false, i1 false) + %r_f4 = rocdl.wmma.scale.f32.32x16x128.f4 %arg5, %arg2, %arg9, %arg0, %arg0 + {modC = 0 : i16, + scaleAType = 1 : i32, fmtScaleA = 1 : i32, + scaleBType = 1 : i32, fmtScaleB = 1 : i32, + reuseA = false, reuseB = false} : + (vector<16xi32>, vector<8xi32>, vector<8xf32>, i32, i32) -> vector<8xf32> + + // Test f4 scale16 variant with varied attributes + // CHECK: call <8 x float> @llvm.amdgcn.wmma.scale16.f32.32x16x128.f4.v8f32.v16i32.v8i32(<16 x i32> %{{.*}}, <8 x i32> %{{.*}}, i16 3, <8 x float> %{{.*}}, i32 2, i32 3, i64 %{{.*}}, i32 3, i32 2, i64 %{{.*}}, i1 true, i1 true) + %r_f4_scale16 = rocdl.wmma.scale16.f32.32x16x128.f4 %arg5, %arg2, %arg9, %arg8, %arg8 + {modC = 3 : i16, + scaleAType = 2 : i32, fmtScaleA = 3 : i32, + scaleBType = 3 : i32, fmtScaleB = 2 : i32, + reuseA = true, reuseB = true} : + (vector<16xi32>, vector<8xi32>, vector<8xf32>, i64, i64) -> vector<8xf32> + + llvm.return %r00 : vector<4xf32> +} + llvm.func @rocdl.raw.ptr.buffer.atomic.f32(%rsrc : !llvm.ptr<8>, %offset : i32, %soffset : i32, %vdata1 : f32) { diff --git a/mlir/test/Target/LLVMIR/target-ext-type.mlir b/mlir/test/Target/LLVMIR/target-ext-type.mlir index 6b2d2ea..cee6301 100644 --- a/mlir/test/Target/LLVMIR/target-ext-type.mlir +++ b/mlir/test/Target/LLVMIR/target-ext-type.mlir @@ -6,6 +6,12 @@ llvm.mlir.global external @global() {addr_space = 0 : i32} : !llvm.target<"spirv llvm.return %0 : !llvm.target<"spirv.DeviceEvent"> } +// CHECK: @amdgcn_named_barrier = internal addrspace(3) global target("amdgcn.named.barrier", 0) poison +llvm.mlir.global internal @amdgcn_named_barrier() {addr_space = 3 : i32} : !llvm.target<"amdgcn.named.barrier", 0> { + %0 = llvm.mlir.poison : !llvm.target<"amdgcn.named.barrier", 0> + llvm.return %0 : !llvm.target<"amdgcn.named.barrier", 0> +} + // CHECK-LABEL: define target("spirv.Event") @func2() { // CHECK-NEXT: %1 = alloca target("spirv.Event"), align 8 // CHECK-NEXT: %2 = load target("spirv.Event"), ptr %1, align 8 diff --git a/mlir/test/Target/SPIRV/consecutive-selection.spv b/mlir/test/Target/SPIRV/consecutive-selection.spvasm index 3752058..3752058 100644 --- a/mlir/test/Target/SPIRV/consecutive-selection.spv +++ b/mlir/test/Target/SPIRV/consecutive-selection.spvasm diff --git a/mlir/test/Target/SPIRV/decorations.mlir b/mlir/test/Target/SPIRV/decorations.mlir index 712fd17..29b5d4f 100644 --- a/mlir/test/Target/SPIRV/decorations.mlir +++ b/mlir/test/Target/SPIRV/decorations.mlir @@ -78,6 +78,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> { // ----- spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> { + // CHECK: coherent + spirv.GlobalVariable @var {coherent} : !spirv.ptr<vector<2xf32>, Output> +} + +// ----- + +spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> { // CHECK: linkage_attributes = #spirv.linkage_attributes<linkage_name = "outSideGlobalVar1", linkage_type = <Import>> spirv.GlobalVariable @var1 { linkage_attributes=#spirv.linkage_attributes< diff --git a/mlir/test/Target/SPIRV/mlir-translate.mlir b/mlir/test/Target/SPIRV/mlir-translate.mlir index cbce351..b1966fe 100644 --- a/mlir/test/Target/SPIRV/mlir-translate.mlir +++ b/mlir/test/Target/SPIRV/mlir-translate.mlir @@ -1,7 +1,6 @@ // Check that `--spirv-save-validation-files-with-prefix` generates // a correct number of files. -// REQUIRES: shell // RUN: rm -rf %t // RUN: mkdir %t && mlir-translate --serialize-spirv --no-implicit-module \ // RUN: --split-input-file --spirv-save-validation-files-with-prefix=%t/foo %s \ diff --git a/mlir/test/Target/SPIRV/module.mlir b/mlir/test/Target/SPIRV/module.mlir index 7e52e54..fb4d9bc 100644 --- a/mlir/test/Target/SPIRV/module.mlir +++ b/mlir/test/Target/SPIRV/module.mlir @@ -1,6 +1,5 @@ // RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip --split-input-file %s | FileCheck %s -// REQUIRES: shell // RUN: %if spirv-tools %{ rm -rf %t %} // RUN: %if spirv-tools %{ mkdir %t %} // RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --split-input-file --spirv-save-validation-files-with-prefix=%t/module %s %} diff --git a/mlir/test/Target/SPIRV/selection.mlir b/mlir/test/Target/SPIRV/selection.mlir index 12daf68..d0ad118 100644 --- a/mlir/test/Target/SPIRV/selection.mlir +++ b/mlir/test/Target/SPIRV/selection.mlir @@ -220,3 +220,129 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { spirv.EntryPoint "GLCompute" @main spirv.ExecutionMode @main "LocalSize", 1, 1, 1 } + +// ----- + +// Selection with switch + +spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { +// CHECK-LABEL: @selection_switch + spirv.func @selection_switch(%selector: i32) -> () "None" { + %zero = spirv.Constant 0: i32 + %one = spirv.Constant 1: i32 + %two = spirv.Constant 2: i32 + %three = spirv.Constant 3: i32 + %four = spirv.Constant 4: i32 +// CHECK: {{%.*}} = spirv.Variable init({{%.*}}) : !spirv.ptr<i32, Function> + %var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function> +// CHECK: spirv.mlir.selection { + spirv.mlir.selection { +// CHECK-NEXT: spirv.Switch {{%.*}} : i32, [ +// CHECK-NEXT: default: ^[[DEFAULT:.+]], +// CHECK-NEXT: 0: ^[[CASE0:.+]], +// CHECK-NEXT: 1: ^[[CASE1:.+]], +// CHECK-NEXT: 2: ^[[CASE2:.+]] + spirv.Switch %selector : i32, [ + default: ^default, + 0: ^case0, + 1: ^case1, + 2: ^case2 + ] +// CHECK: ^[[DEFAULT]] + ^default: +// CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32 + spirv.Store "Function" %var, %one : i32 +// CHECK-NEXT: spirv.Branch ^[[MERGE:.+]] + spirv.Branch ^merge +// CHECK-NEXT: ^[[CASE0]] + ^case0: +// CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32 + spirv.Store "Function" %var, %two : i32 +// CHECK-NEXT: spirv.Branch ^[[MERGE:.+]] + spirv.Branch ^merge +// CHECK-NEXT: ^[[CASE1]] + ^case1: +// CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32 + spirv.Store "Function" %var, %three : i32 +// CHECK-NEXT: spirv.Branch ^[[MERGE:.+]] + spirv.Branch ^merge +// CHECK-NEXT: ^[[CASE2]] + ^case2: +// CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32 + spirv.Store "Function" %var, %four : i32 +// CHECK-NEXT: spirv.Branch ^[[MERGE:.+]] + spirv.Branch ^merge +// CHECK-NEXT: ^[[MERGE]] + ^merge: +// CHECK-NEXT: spirv.mlir.merge + spirv.mlir.merge +// CHECK-NEXT: } + } +// CHECK-NEXT: spirv.Return + spirv.Return + } + + spirv.func @main() -> () "None" { + spirv.Return + } + spirv.EntryPoint "GLCompute" @main + spirv.ExecutionMode @main "LocalSize", 1, 1, 1 +} + +// ----- + +// Selection with switch and block operands + +spirv.module Logical GLSL450 requires #spirv.vce<v1.5, [Shader], []> { +// CHECK-LABEL: @selection_switch_operands + spirv.func @selection_switch_operands(%selector : si32) "None" { + %cst1 = spirv.Constant 1.000000e+00 : f32 + %vec0 = spirv.Undef : vector<3xf32> +// CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[0 : i32] : f32 into vector<3xf32> + %vec1 = spirv.CompositeInsert %cst1, %vec0[0 : i32] : f32 into vector<3xf32> + spirv.Branch ^bb1 + ^bb1: +// CHECK: {{%.*}} = spirv.mlir.selection -> vector<3xf32> { + %vec4 = spirv.mlir.selection -> vector<3xf32> { +// CHECK-NEXT: spirv.Switch {{%.*}} : si32, [ +// CHECK-NEXT: default: ^[[DEFAULT:.+]]({{%.*}} : vector<3xf32>), +// CHECK-NEXT: 0: ^[[CASE0:.+]]({{%.*}} : vector<3xf32>), +// CHECK-NEXT: 1: ^[[CASE1:.+]]({{%.*}} : vector<3xf32>) + spirv.Switch %selector : si32, [ + default: ^bb3(%vec1 : vector<3xf32>), + 0: ^bb1(%vec1 : vector<3xf32>), + 1: ^bb2(%vec1 : vector<3xf32>) + ] +// CHECK: ^[[CASE0]]({{%.*}}: vector<3xf32>) + ^bb1(%vecbb1: vector<3xf32>): + %cst3 = spirv.Constant 3.000000e+00 : f32 +// CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32] : f32 into vector<3xf32> + %vec2 = spirv.CompositeInsert %cst3, %vecbb1[1 : i32] : f32 into vector<3xf32> +// CHECK-NEXT: spirv.Branch ^[[DEFAULT]]({{%.*}} : vector<3xf32>) + spirv.Branch ^bb3(%vec2 : vector<3xf32>) +// CHECK-NEXT: ^[[CASE1]]({{%.*}}: vector<3xf32>) + ^bb2(%vecbb2: vector<3xf32>): + %cst4 = spirv.Constant 4.000000e+00 : f32 +// CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32] : f32 into vector<3xf32> + %vec3 = spirv.CompositeInsert %cst4, %vecbb2[1 : i32] : f32 into vector<3xf32> +// CHECK-NEXT: spirv.Branch ^[[DEFAULT]]({{%.*}} : vector<3xf32>) + spirv.Branch ^bb3(%vec3 : vector<3xf32>) +// CHECK-NEXT: ^[[DEFAULT]]({{%.*}}: vector<3xf32>) + ^bb3(%vecbb3: vector<3xf32>): +// CHECK-NEXT: spirv.mlir.merge {{%.*}} : vector<3xf32> + spirv.mlir.merge %vecbb3 : vector<3xf32> +// CHECK-NEXT: } + } + %cst2 = spirv.Constant 2.000000e+00 : f32 +// CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[2 : i32] : f32 into vector<3xf32> + %vec5 = spirv.CompositeInsert %cst2, %vec4[2 : i32] : f32 into vector<3xf32> + spirv.Return + } + + spirv.func @main() -> () "None" { + spirv.Return + } + + spirv.EntryPoint "GLCompute" @main + spirv.ExecutionMode @main "LocalSize", 1, 1, 1 +} diff --git a/mlir/test/Target/SPIRV/selection.spv b/mlir/test/Target/SPIRV/selection.spvasm index 9642d0a..9642d0a 100644 --- a/mlir/test/Target/SPIRV/selection.spv +++ b/mlir/test/Target/SPIRV/selection.spvasm diff --git a/mlir/test/Target/SPIRV/selection_switch.spvasm b/mlir/test/Target/SPIRV/selection_switch.spvasm new file mode 100644 index 0000000..81fecf3 --- /dev/null +++ b/mlir/test/Target/SPIRV/selection_switch.spvasm @@ -0,0 +1,69 @@ +; RUN: %if spirv-tools %{ spirv-as --target-env spv1.0 %s -o - | mlir-translate --deserialize-spirv - -o - | FileCheck %s %} + +; This test is analogous to selection.spv but tests switch op. + +; CHECK: spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { +; CHECK-NEXT: spirv.func @switch({{%.*}}: si32) "None" { +; CHECK: {{%.*}} = spirv.Constant 1.000000e+00 : f32 +; CHECK-NEXT: {{%.*}} = spirv.Undef : vector<3xf32> +; CHECK-NEXT: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[0 : i32] : f32 into vector<3xf32> +; CHECK-NEXT: spirv.Branch ^[[bb:.+]] +; CHECK-NEXT: ^[[bb:.+]]: +; CHECK-NEXT: {{%.*}} = spirv.mlir.selection -> vector<3xf32> { +; CHECK-NEXT: spirv.Switch {{%.*}} : si32, [ +; CHECK-NEXT: default: ^[[bb:.+]]({{%.*}}: vector<3xf32>), +; CHECK-NEXT: 0: ^[[bb:.+]]({{%.*}}: vector<3xf32>), +; CHECK-NEXT: 1: ^[[bb:.+]]({{%.*}}: vector<3xf32>) +; CHECK: ^[[bb:.+]]({{%.*}}: vector<3xf32>): +; CHECK: spirv.Branch ^[[bb:.+]]({{%.*}}: vector<3xf32>) +; CHECK-NEXT: ^[[bb:.+]]({{%.*}}: vector<3xf32>): +; CHECK: spirv.Branch ^[[bb:.+]]({{%.*}}: vector<3xf32>) +; CHECK-NEXT: ^[[bb:.+]]({{%.*}}: vector<3xf32>): +; CHECK-NEXT: spirv.mlir.merge %8 : vector<3xf32> +; CHECK-NEXT } +; CHECK: spirv.Return +; CHECK-NEXT: } +; CHECK: } + + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpName %switch "switch" + OpName %main "main" + %void = OpTypeVoid + %int = OpTypeInt 32 1 + %1 = OpTypeFunction %void %int + %float = OpTypeFloat 32 + %float_1 = OpConstant %float 1 + %v3float = OpTypeVector %float 3 + %9 = OpUndef %v3float + %float_3 = OpConstant %float 3 + %float_4 = OpConstant %float 4 + %float_2 = OpConstant %float 2 + %25 = OpTypeFunction %void + %switch = OpFunction %void None %1 + %5 = OpFunctionParameter %int + %6 = OpLabel + OpBranch %12 + %12 = OpLabel + %11 = OpCompositeInsert %v3float %float_1 %9 0 + OpSelectionMerge %15 None + OpSwitch %5 %15 0 %13 1 %14 + %13 = OpLabel + %16 = OpPhi %v3float %11 %12 + %18 = OpCompositeInsert %v3float %float_3 %16 1 + OpBranch %15 + %14 = OpLabel + %19 = OpPhi %v3float %11 %12 + %21 = OpCompositeInsert %v3float %float_4 %19 1 + OpBranch %15 + %15 = OpLabel + %22 = OpPhi %v3float %21 %14 %18 %13 %11 %12 + %24 = OpCompositeInsert %v3float %float_2 %22 2 + OpReturn + OpFunctionEnd + %main = OpFunction %void None %25 + %27 = OpLabel + OpReturn + OpFunctionEnd diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir index c1604e2..31a4f64d 100644 --- a/mlir/test/Transforms/loop-invariant-code-motion.mlir +++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir @@ -880,6 +880,18 @@ func.func @no_speculate_divui( return } +func.func @no_speculate_udiv( +// CHECK-LABEL: @no_speculate_udiv( + %num: i32, %denom: i32, %lb: index, %ub: index, %step: index) { + scf.for %i = %lb to %ub step %step { +// CHECK: scf.for +// CHECK: llvm.udiv + %val = llvm.udiv %num, %denom : i32 + } + + return +} + func.func @no_speculate_divsi( // CHECK-LABEL: @no_speculate_divsi( %num: i32, %denom: i32, %lb: index, %ub: index, %step: index) { @@ -892,6 +904,18 @@ func.func @no_speculate_divsi( return } +func.func @no_speculate_sdiv( +// CHECK-LABEL: @no_speculate_sdiv( + %num: i32, %denom: i32, %lb: index, %ub: index, %step: index) { + scf.for %i = %lb to %ub step %step { +// CHECK: scf.for +// CHECK: llvm.sdiv + %val = llvm.sdiv %num, %denom : i32 + } + + return +} + func.func @no_speculate_ceildivui( // CHECK-LABEL: @no_speculate_ceildivui( %num: i32, %denom: i32, %lb: index, %ub: index, %step: index) { @@ -928,6 +952,18 @@ func.func @no_speculate_divui_const(%num: i32, %lb: index, %ub: index, %step: in return } +func.func @no_speculate_udiv_const(%num: i32, %lb: index, %ub: index, %step: index) { +// CHECK-LABEL: @no_speculate_udiv_const( + %c0 = arith.constant 0 : i32 + scf.for %i = %lb to %ub step %step { +// CHECK: scf.for +// CHECK: llvm.udiv + %val = llvm.udiv %num, %c0 : i32 + } + + return +} + func.func @speculate_divui_const( // CHECK-LABEL: @speculate_divui_const( %num: i32, %lb: index, %ub: index, %step: index) { @@ -941,6 +977,19 @@ func.func @speculate_divui_const( return } +func.func @speculate_udiv_const( +// CHECK-LABEL: @speculate_udiv_const( + %num: i32, %lb: index, %ub: index, %step: index) { + %c5 = llvm.mlir.constant(5 : i32) : i32 +// CHECK: llvm.udiv +// CHECK: scf.for + scf.for %i = %lb to %ub step %step { + %val = llvm.udiv %num, %c5 : i32 + } + + return +} + func.func @no_speculate_ceildivui_const(%num: i32, %lb: index, %ub: index, %step: index) { // CHECK-LABEL: @no_speculate_ceildivui_const( %c0 = arith.constant 0 : i32 @@ -979,6 +1028,19 @@ func.func @no_speculate_divsi_const0( return } +func.func @no_speculate_sdiv_const0( +// CHECK-LABEL: @no_speculate_sdiv_const0( + %num: i32, %denom: i32, %lb: index, %ub: index, %step: index) { + %c0 = arith.constant 0 : i32 + scf.for %i = %lb to %ub step %step { +// CHECK: scf.for +// CHECK: llvm.sdiv + %val = llvm.sdiv %num, %c0 : i32 + } + + return +} + func.func @no_speculate_divsi_const_minus1( // CHECK-LABEL: @no_speculate_divsi_const_minus1( %num: i32, %denom: i32, %lb: index, %ub: index, %step: index) { @@ -992,6 +1054,19 @@ func.func @no_speculate_divsi_const_minus1( return } +func.func @no_speculate_sdiv_const_minus1( +// CHECK-LABEL: @no_speculate_sdiv_const_minus1( + %num: i32, %denom: i32, %lb: index, %ub: index, %step: index) { + %cm1 = arith.constant -1 : i32 + scf.for %i = %lb to %ub step %step { +// CHECK: scf.for +// CHECK: llvm.sdiv + %val = llvm.sdiv %num, %cm1 : i32 + } + + return +} + func.func @speculate_divsi_const( // CHECK-LABEL: @speculate_divsi_const( %num: i32, %denom: i32, %lb: index, %ub: index, %step: index) { @@ -1005,6 +1080,19 @@ func.func @speculate_divsi_const( return } +func.func @speculate_sdiv_const( +// CHECK-LABEL: @speculate_sdiv_const( + %num: i32, %denom: i32, %lb: index, %ub: index, %step: index) { + %c5 = arith.constant 5 : i32 + scf.for %i = %lb to %ub step %step { +// CHECK: llvm.sdiv +// CHECK: scf.for + %val = llvm.sdiv %num, %c5 : i32 + } + + return +} + func.func @no_speculate_ceildivsi_const0( // CHECK-LABEL: @no_speculate_ceildivsi_const0( %num: i32, %denom: i32, %lb: index, %ub: index, %step: index) { @@ -1057,6 +1145,19 @@ func.func @no_speculate_divui_range( return } +func.func @no_speculate_udiv_range( +// CHECK-LABEL: @no_speculate_udiv_range( + %num: i8, %lb: index, %ub: index, %step: index) { + %denom = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8 + scf.for %i = %lb to %ub step %step { +// CHECK: scf.for +// CHECK: llvm.udiv + %val = llvm.udiv %num, %denom : i8 + } + + return +} + func.func @no_speculate_divsi_range( // CHECK-LABEL: @no_speculate_divsi_range( %num: i8, %lb: index, %ub: index, %step: index) { @@ -1072,6 +1173,21 @@ func.func @no_speculate_divsi_range( return } +func.func @no_speculate_sdiv_range( +// CHECK-LABEL: @no_speculate_sdiv_range( + %num: i8, %lb: index, %ub: index, %step: index) { + %denom0 = test.with_bounds {smax = -1: i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8 + %denom1 = test.with_bounds {smax = 127 : i8, smin = 0 : i8, umax = 255 : i8, umin = 0 : i8} : i8 + scf.for %i = %lb to %ub step %step { +// CHECK: scf.for +// CHECK-COUNT-2: llvm.sdiv + %val0 = llvm.sdiv %num, %denom0 : i8 + %val1 = llvm.sdiv %num, %denom1 : i8 + } + + return +} + func.func @no_speculate_ceildivui_range( // CHECK-LABEL: @no_speculate_ceildivui_range( %num: i8, %lb: index, %ub: index, %step: index) { @@ -1113,6 +1229,19 @@ func.func @speculate_divui_range( return } +func.func @speculate_udiv_range( +// CHECK-LABEL: @speculate_udiv_range( + %num: i8, %lb: index, %ub: index, %step: index) { + %denom = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 1 : i8} : i8 + scf.for %i = %lb to %ub step %step { +// CHECK: llvm.udiv +// CHECK: scf.for + %val = llvm.udiv %num, %denom : i8 + } + + return +} + func.func @speculate_divsi_range( // CHECK-LABEL: @speculate_divsi_range( %num: i8, %lb: index, %ub: index, %step: index) { @@ -1129,6 +1258,22 @@ func.func @speculate_divsi_range( return } +func.func @speculate_sdiv_range( +// CHECK-LABEL: @speculate_sdiv_range( + %num: i8, %lb: index, %ub: index, %step: index) { + %denom0 = test.with_bounds {smax = 127 : i8, smin = 1 : i8, umax = 255 : i8, umin = 0 : i8} : i8 + %denom1 = test.with_bounds {smax = -2 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8 + scf.for %i = %lb to %ub step %step { +// CHECK-COUNT-2: llvm.sdiv +// CHECK: scf.for + %val0 = llvm.sdiv %num, %denom0 : i8 + %val1 = llvm.sdiv %num, %denom1 : i8 + + } + + return +} + func.func @speculate_ceildivui_range( // CHECK-LABEL: @speculate_ceildivui_range( %num: i8, %lb: index, %ub: index, %step: index) { diff --git a/mlir/test/Transforms/move-operation-deps.mlir b/mlir/test/Transforms/move-operation-deps.mlir index 75d8386..3119fd3 100644 --- a/mlir/test/Transforms/move-operation-deps.mlir +++ b/mlir/test/Transforms/move-operation-deps.mlir @@ -238,25 +238,26 @@ module attributes {transform.with_named_sequence} { // ----- // Check simple move value definitions before insertion operation. -func.func @simple_move_values() -> f32 { - %0 = "before"() : () -> (f32) - %1 = "moved_op_1"() : () -> (f32) - %2 = "moved_op_2"() : () -> (f32) - %3 = "foo"(%1, %2) : (f32, f32) -> (f32) - return %3 : f32 -} -// CHECK-LABEL: func @simple_move_values() -// CHECK: %[[MOVED1:.+]] = "moved_op_1" -// CHECK: %[[MOVED2:.+]] = "moved_op_2" +func.func @simple_move_values(%arg0 : index) -> index { + %c0 = arith.constant 0 : index + %0 = "before"() : () -> (index) + %1 = arith.addi %arg0, %c0 {"moved_op_1"} : index + %2 = arith.subi %arg0, %c0 {"moved_op_2"} : index + %3 = "foo"(%1, %2) : (index, index) -> (index) + return %3 : index +} +// CHECK-LABEL: func @simple_move_values( +// CHECK: %[[MOVED1:.+]] = arith.addi {{.*}} {moved_op_1} +// CHECK: %[[MOVED2:.+]] = arith.subi {{.*}} {moved_op_2} // CHECK: %[[BEFORE:.+]] = "before" // CHECK: %[[FOO:.+]] = "foo"(%[[MOVED1]], %[[MOVED2]]) // CHECK: return %[[FOO]] module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { - %op1 = transform.structured.match ops{["moved_op_1"]} in %arg0 + %op1 = transform.structured.match ops{["arith.addi"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %op2 = transform.structured.match ops{["moved_op_2"]} in %arg0 + %op2 = transform.structured.match ops{["arith.subi"]} in %arg0 : (!transform.any_op) -> !transform.any_op %op3 = transform.structured.match ops{["before"]} in %arg0 : (!transform.any_op) -> !transform.any_op @@ -271,23 +272,26 @@ module attributes {transform.with_named_sequence} { // ----- // Compute slice including the implicitly captured values. -func.func @move_region_dependencies_values() -> f32 { - %0 = "before"() : () -> (f32) - %1 = "moved_op_1"() : () -> (f32) - %2 = "moved_op_2"() ({ - %3 = "inner_op"(%1) : (f32) -> (f32) - "yield"(%3) : (f32) -> () - }) : () -> (f32) - return %2 : f32 +func.func @move_region_dependencies_values(%arg0 : index, %cond : i1) -> index { + %0 = "before"() : () -> (index) + %1 = arith.addi %arg0, %arg0 {moved_op_1} : index + %2 = scf.if %cond -> index { + %3 = arith.muli %1, %1 {inner_op} : index + scf.yield %3 : index + } else { + scf.yield %1 : index + } + return %2 : index } -// CHECK-LABEL: func @move_region_dependencies_values() -// CHECK: %[[MOVED1:.+]] = "moved_op_1" -// CHECK: %[[MOVED2:.+]] = "moved_op_2" +// CHECK-LABEL: func @move_region_dependencies_values( +// CHECK: %[[MOVED1:.+]] = arith.addi {{.*}} {moved_op_1} +// CHECK: scf.if +// CHECK: arith.muli %[[MOVED1]], %[[MOVED1]] {inner_op} // CHECK: %[[BEFORE:.+]] = "before" module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { - %op1 = transform.structured.match ops{["moved_op_2"]} in %arg0 + %op1 = transform.structured.match ops{["scf.if"]} in %arg0 : (!transform.any_op) -> !transform.any_op %op2 = transform.structured.match ops{["before"]} in %arg0 : (!transform.any_op) -> !transform.any_op @@ -301,31 +305,31 @@ module attributes {transform.with_named_sequence} { // ----- // Move operations in toplogical sort order -func.func @move_values_in_topological_sort_order() -> f32 { - %0 = "before"() : () -> (f32) - %1 = "moved_op_1"() : () -> (f32) - %2 = "moved_op_2"() : () -> (f32) - %3 = "moved_op_3"(%1) : (f32) -> (f32) - %4 = "moved_op_4"(%1, %3) : (f32, f32) -> (f32) - %5 = "moved_op_5"(%2) : (f32) -> (f32) - %6 = "foo"(%4, %5) : (f32, f32) -> (f32) - return %6 : f32 -} -// CHECK-LABEL: func @move_values_in_topological_sort_order() -// CHECK: %[[MOVED_1:.+]] = "moved_op_1" -// CHECK-DAG: %[[MOVED_2:.+]] = "moved_op_3"(%[[MOVED_1]]) -// CHECK-DAG: %[[MOVED_3:.+]] = "moved_op_4"(%[[MOVED_1]], %[[MOVED_2]]) -// CHECK-DAG: %[[MOVED_4:.+]] = "moved_op_2" -// CHECK-DAG: %[[MOVED_5:.+]] = "moved_op_5"(%[[MOVED_4]]) +func.func @move_values_in_topological_sort_order(%arg0 : index, %arg1 : index) -> index { + %0 = "before"() : () -> (index) + %1 = arith.addi %arg0, %arg0 {moved_op_1} : index + %2 = arith.addi %arg1, %arg1 {moved_op_2} : index + %3 = arith.muli %1, %1 {moved_op_3} : index + %4 = arith.andi %1, %3 {moved_op_4} : index + %5 = arith.subi %2, %2 {moved_op_5} : index + %6 = "foo"(%4, %5) : (index, index) -> (index) + return %6 : index +} +// CHECK-LABEL: func @move_values_in_topological_sort_order( +// CHECK: %[[MOVED_1:.+]] = arith.addi {{.*}} {moved_op_1} +// CHECK-DAG: %[[MOVED_2:.+]] = arith.muli %[[MOVED_1]], %[[MOVED_1]] {moved_op_3} +// CHECK-DAG: %[[MOVED_3:.+]] = arith.andi %[[MOVED_1]], %[[MOVED_2]] {moved_op_4} +// CHECK-DAG: %[[MOVED_4:.+]] = arith.addi {{.*}} {moved_op_2} +// CHECK-DAG: %[[MOVED_5:.+]] = arith.subi %[[MOVED_4]], %[[MOVED_4]] {moved_op_5} // CHECK: %[[BEFORE:.+]] = "before" // CHECK: %[[FOO:.+]] = "foo"(%[[MOVED_3]], %[[MOVED_5]]) // CHECK: return %[[FOO]] module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { - %op1 = transform.structured.match ops{["moved_op_4"]} in %arg0 + %op1 = transform.structured.match ops{["arith.andi"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %op2 = transform.structured.match ops{["moved_op_5"]} in %arg0 + %op2 = transform.structured.match ops{["arith.subi"]} in %arg0 : (!transform.any_op) -> !transform.any_op %op3 = transform.structured.match ops{["before"]} in %arg0 : (!transform.any_op) -> !transform.any_op @@ -341,17 +345,17 @@ module attributes {transform.with_named_sequence} { // Move only those value definitions that are not dominated by insertion point -func.func @move_only_required_defns() -> (f32, f32, f32, f32) { - %0 = "unmoved_op"() : () -> (f32) - %1 = "dummy_op"() : () -> (f32) - %2 = "before"() : () -> (f32) - %3 = "moved_op"() : () -> (f32) - return %0, %1, %2, %3 : f32, f32, f32, f32 +func.func @move_only_required_defns(%arg0 : index) -> (index, index, index, index) { + %0 = "unmoved_op"() : () -> (index) + %1 = "dummy_op"() : () -> (index) + %2 = "before"() : () -> (index) + %3 = arith.addi %arg0, %arg0 {moved_op} : index + return %0, %1, %2, %3 : index, index, index, index } -// CHECK-LABEL: func @move_only_required_defns() +// CHECK-LABEL: func @move_only_required_defns( // CHECK: %[[UNMOVED:.+]] = "unmoved_op" // CHECK: %[[DUMMY:.+]] = "dummy_op" -// CHECK: %[[MOVED:.+]] = "moved_op" +// CHECK: %[[MOVED:.+]] = arith.addi {{.*}} {moved_op} // CHECK: %[[BEFORE:.+]] = "before" module attributes {transform.with_named_sequence} { @@ -362,7 +366,7 @@ module attributes {transform.with_named_sequence} { : (!transform.any_op) -> !transform.any_op %op3 = transform.structured.match ops{["before"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %op4 = transform.structured.match ops{["moved_op"]} in %arg0 + %op4 = transform.structured.match ops{["arith.addi"]} in %arg0 : (!transform.any_op) -> !transform.any_op %v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value %v2 = transform.get_result %op4[0] : (!transform.any_op) -> !transform.any_value @@ -374,19 +378,19 @@ module attributes {transform.with_named_sequence} { // ----- -// Move only those value definitions that are not dominated by insertion point +// Move only those value definitions that are not dominated by insertion point (duplicate test) -func.func @move_only_required_defns() -> (f32, f32, f32, f32) { - %0 = "unmoved_op"() : () -> (f32) - %1 = "dummy_op"() : () -> (f32) - %2 = "before"() : () -> (f32) - %3 = "moved_op"() : () -> (f32) - return %0, %1, %2, %3 : f32, f32, f32, f32 +func.func @move_only_required_defns_2(%arg0 : index) -> (index, index, index, index) { + %0 = "unmoved_op"() : () -> (index) + %1 = "dummy_op"() : () -> (index) + %2 = "before"() : () -> (index) + %3 = arith.subi %arg0, %arg0 {moved_op} : index + return %0, %1, %2, %3 : index, index, index, index } -// CHECK-LABEL: func @move_only_required_defns() +// CHECK-LABEL: func @move_only_required_defns_2( // CHECK: %[[UNMOVED:.+]] = "unmoved_op" // CHECK: %[[DUMMY:.+]] = "dummy_op" -// CHECK: %[[MOVED:.+]] = "moved_op" +// CHECK: %[[MOVED:.+]] = arith.subi {{.*}} {moved_op} // CHECK: %[[BEFORE:.+]] = "before" module attributes {transform.with_named_sequence} { @@ -397,7 +401,7 @@ module attributes {transform.with_named_sequence} { : (!transform.any_op) -> !transform.any_op %op3 = transform.structured.match ops{["before"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %op4 = transform.structured.match ops{["moved_op"]} in %arg0 + %op4 = transform.structured.match ops{["arith.subi"]} in %arg0 : (!transform.any_op) -> !transform.any_op %v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value %v2 = transform.get_result %op4[0] : (!transform.any_op) -> !transform.any_value @@ -410,23 +414,23 @@ module attributes {transform.with_named_sequence} { // ----- // Check handling of block arguments -func.func @move_only_required_defns() -> (f32, f32) { - %0 = "unmoved_op"() : () -> (f32) - cf.br ^bb0(%0 : f32) - ^bb0(%arg0 : f32) : - %1 = "before"() : () -> (f32) - %2 = "moved_op"(%arg0) : (f32) -> (f32) - return %1, %2 : f32, f32 -} -// CHECK-LABEL: func @move_only_required_defns() -// CHECK: %[[MOVED:.+]] = "moved_op" +func.func @move_with_block_arguments() -> (index, index) { + %0 = "unmoved_op"() : () -> (index) + cf.br ^bb0(%0 : index) + ^bb0(%arg0 : index) : + %1 = "before"() : () -> (index) + %2 = arith.addi %arg0, %arg0 {moved_op} : index + return %1, %2 : index, index +} +// CHECK-LABEL: func @move_with_block_arguments() +// CHECK: %[[MOVED:.+]] = arith.addi {{.*}} {moved_op} // CHECK: %[[BEFORE:.+]] = "before" module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { %op1 = transform.structured.match ops{["before"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %op2 = transform.structured.match ops{["moved_op"]} in %arg0 + %op2 = transform.structured.match ops{["arith.addi"]} in %arg0 : (!transform.any_op) -> !transform.any_op %v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value transform.test.move_value_defns %v1 before %op1 @@ -438,20 +442,20 @@ module attributes {transform.with_named_sequence} { // ----- // Do not move across basic blocks -func.func @no_move_across_basic_blocks() -> (f32, f32) { - %0 = "unmoved_op"() : () -> (f32) - %1 = "before"() : () -> (f32) - cf.br ^bb0(%0 : f32) - ^bb0(%arg0 : f32) : - %2 = "moved_op"(%arg0) : (f32) -> (f32) - return %1, %2 : f32, f32 +func.func @no_move_across_basic_blocks() -> (index, index) { + %0 = "unmoved_op"() : () -> (index) + %1 = "before"() : () -> (index) + cf.br ^bb0(%0 : index) + ^bb0(%arg0 : index) : + %2 = arith.addi %arg0, %arg0 {moved_op} : index + return %1, %2 : index, index } module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { %op1 = transform.structured.match ops{["before"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %op2 = transform.structured.match ops{["moved_op"]} in %arg0 + %op2 = transform.structured.match ops{["arith.addi"]} in %arg0 : (!transform.any_op) -> !transform.any_op %v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value // expected-remark@+1{{unsupported case of moving definition of value before an insertion point in a different basic block}} @@ -463,24 +467,22 @@ module attributes {transform.with_named_sequence} { // ----- -func.func @move_isolated_from_above() -> () { - %1 = "before"() : () -> (f32) - %2 = "moved0"() : () -> (f32) - %3 = test.isolated_one_region_op %2 {} : f32 -> f32 - %4 = "moved1"(%3) : (f32) -> (f32) +func.func @move_isolated_from_above(%arg0 : index) -> () { + %1 = "before"() : () -> (index) + %2 = arith.addi %arg0, %arg0 {moved0} : index + %3 = arith.muli %2, %2 {moved1} : index return } -// CHECK-LABEL: func @move_isolated_from_above() -// CHECK: %[[MOVED0:.+]] = "moved0" -// CHECK: %[[ISOLATED:.+]] = test.isolated_one_region_op %[[MOVED0]] -// CHECK: %[[MOVED1:.+]] = "moved1"(%[[ISOLATED]]) +// CHECK-LABEL: func @move_isolated_from_above( +// CHECK: %[[MOVED0:.+]] = arith.addi {{.*}} {moved0} +// CHECK: %[[MOVED1:.+]] = arith.muli %[[MOVED0]], %[[MOVED0]] {moved1} // CHECK: %[[BEFORE:.+]] = "before" module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { %op1 = transform.structured.match ops{["before"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %op2 = transform.structured.match ops{["moved1"]} in %arg0 + %op2 = transform.structured.match ops{["arith.muli"]} in %arg0 : (!transform.any_op) -> !transform.any_op %v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value transform.test.move_value_defns %v1 before %op1 diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir index e730450..7130667 100644 --- a/mlir/test/Transforms/remove-dead-values.mlir +++ b/mlir/test/Transforms/remove-dead-values.mlir @@ -118,6 +118,17 @@ func.func @main(%arg0 : i32) { // ----- +// CHECK-LABEL: func.func private @clean_func_op_remove_side_effecting_op() { +// CHECK-NEXT: return +// CHECK-NEXT: } +func.func private @clean_func_op_remove_side_effecting_op(%arg0: i32) -> (i32) { + // vector.print has a side effect but the op is dead. + vector.print %arg0 : i32 + return %arg0 : i32 +} + +// ----- + // %arg0 is not live because it is never used. %arg1 is not live because its // user `arith.addi` doesn't have any uses and the value that it is forwarded to // (%non_live_0) also doesn't have any uses. @@ -674,3 +685,32 @@ func.func @dead_value_loop_ivs_no_result(%lb: index, %ub: index, %step: index, % } return } + +// ----- + +// CHECK-LABEL: func @op_block_have_dead_arg +func.func @op_block_have_dead_arg(%arg0: index, %arg1: index, %arg2: i1) { + scf.execute_region { + cf.cond_br %arg2, ^bb1(%arg0 : index), ^bb1(%arg1 : index) + ^bb1(%0: index): + scf.yield + } + // CHECK-NEXT: return + return +} + +// ----- + +// CHECK-LABEL: func private @remove_dead_branch_op() +// CHECK-NEXT: ub.unreachable +// CHECK-NEXT: ^{{.*}}: +// CHECK-NEXT: return +// CHECK-NEXT: ^{{.*}}: +// CHECK-NEXT: return +func.func private @remove_dead_branch_op(%c: i1, %arg0: i64, %arg1: i64) -> (i64) { + cf.cond_br %c, ^bb1, ^bb2 +^bb1: + return %arg0 : i64 +^bb2: + return %arg1 : i64 +} diff --git a/mlir/test/Transforms/test-legalizer-no-rollback.mlir b/mlir/test/Transforms/test-legalizer-no-rollback.mlir new file mode 100644 index 0000000..5f421a3 --- /dev/null +++ b/mlir/test/Transforms/test-legalizer-no-rollback.mlir @@ -0,0 +1,23 @@ +// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0" -verify-diagnostics %s | FileCheck %s + +// CHECK-LABEL: @conditional_replacement( +// CHECK-SAME: %[[arg0:.*]]: i43) +// CHECK: %[[cast1:.*]] = "test.cast"(%[[arg0]]) : (i43) -> i42 +// CHECK: %[[legal:.*]] = "test.legal_op"() : () -> i42 +// CHECK: %[[cast2:.*]] = "test.cast"(%[[legal]], %[[legal]]) : (i42, i42) -> i42 +// Uses were replaced for dummy_user_1. +// CHECK: "test.dummy_user_1"(%[[cast2]]) {replace_uses} : (i42) -> () +// Uses were also replaced for dummy_user_2, but not by value_replace. The uses +// were replaced due to the block signature conversion. +// CHECK: "test.dummy_user_2"(%[[cast1]]) : (i42) -> () +// CHECK: "test.value_replace"(%[[cast1]], %[[legal]]) {conditional, is_legal} : (i42, i42) -> () +func.func @conditional_replacement(%arg0: i42) { + %repl = "test.legal_op"() : () -> (i42) + // expected-remark @+1 {{is not legalizable}} + "test.dummy_user_1"(%arg0) {replace_uses} : (i42) -> () + // expected-remark @+1 {{is not legalizable}} + "test.dummy_user_2"(%arg0) {} : (i42) -> () + // Perform a conditional 1:N replacement. + "test.value_replace"(%arg0, %repl) {conditional} : (i42, i42) -> () + "test.return"() : () -> () +} diff --git a/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp index 8e2f03b..99f72c6 100644 --- a/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp @@ -56,6 +56,17 @@ struct TestLivenessAnalysisPass liveness->print(os); os << "\n"; } + for (auto [regionIndex, region] : llvm::enumerate(op->getRegions())) { + os << " region: #" << regionIndex << ":\n"; + for (auto [argumntIndex, argument] : + llvm::enumerate(region.getArguments())) { + const Liveness *liveness = livenessAnalysis.getLiveness(argument); + assert(liveness && "expected a sparse lattice"); + os << " argument: #" << argumntIndex << ": "; + liveness->print(os); + os << "\n"; + } + } }); } }; diff --git a/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp b/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp index 027b0a1..3ff0dc8 100644 --- a/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp +++ b/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp @@ -46,7 +46,7 @@ struct TestPointerLikeTypeInterfacePass Pass::Option<std::string> testMode{ *this, "test-mode", - llvm::cl::desc("Test mode: walk, alloc, copy, or free"), + llvm::cl::desc("Test mode: walk, alloc, copy, free, load, or store"), llvm::cl::init("walk")}; StringRef getArgument() const override { @@ -75,6 +75,10 @@ private: void testGenCopy(Operation *srcOp, Operation *destOp, Value srcResult, Value destResult, PointerLikeType pointerType, OpBuilder &builder); + void testGenLoad(Operation *op, Value result, PointerLikeType pointerType, + OpBuilder &builder); + void testGenStore(Operation *op, Value result, PointerLikeType pointerType, + OpBuilder &builder, Value providedValue = {}); struct PointerCandidate { Operation *op; @@ -92,9 +96,12 @@ void TestPointerLikeTypeInterfacePass::runOnOperation() { auto func = getOperation(); OpBuilder builder(&getContext()); - if (testMode == "alloc" || testMode == "free") { + if (testMode == "alloc" || testMode == "free" || testMode == "load" || + testMode == "store") { // Collect all candidates first SmallVector<PointerCandidate> candidates; + // For store mode, also look for a test value to use + Value testValue; func.walk([&](Operation *op) { if (op->hasAttr("test.ptr")) { for (auto result : op->getResults()) { @@ -105,6 +112,11 @@ void TestPointerLikeTypeInterfacePass::runOnOperation() { } } } + // Collect value marked with test.value for store tests + if (testMode == "store" && op->hasAttr("test.value")) { + if (op->getNumResults() > 0) + testValue = op->getResult(0); + } }); // Now test all candidates @@ -115,6 +127,12 @@ void TestPointerLikeTypeInterfacePass::runOnOperation() { else if (testMode == "free") testGenFree(candidate.op, candidate.result, candidate.pointerType, builder); + else if (testMode == "load") + testGenLoad(candidate.op, candidate.result, candidate.pointerType, + builder); + else if (testMode == "store") + testGenStore(candidate.op, candidate.result, candidate.pointerType, + builder, testValue); } } else if (testMode == "copy") { // Collect all source and destination candidates @@ -292,6 +310,105 @@ void TestPointerLikeTypeInterfacePass::testGenCopy( } } +void TestPointerLikeTypeInterfacePass::testGenLoad(Operation *op, Value result, + PointerLikeType pointerType, + OpBuilder &builder) { + Location loc = op->getLoc(); + + // Create a new builder with the listener and set insertion point + OperationTracker tracker; + OpBuilder newBuilder(op->getContext()); + newBuilder.setListener(&tracker); + newBuilder.setInsertionPointAfter(op); + + // Call the genLoad API + auto typedResult = cast<TypedValue<PointerLikeType>>(result); + Value loadRes = pointerType.genLoad(newBuilder, loc, typedResult, Type()); + + if (loadRes) { + llvm::errs() << "Successfully generated load for operation: "; + op->print(llvm::errs()); + llvm::errs() << "\n"; + llvm::errs() << "\tLoaded value type: "; + loadRes.getType().print(llvm::errs()); + llvm::errs() << "\n"; + + // Print all operations that were inserted + for (Operation *insertedOp : tracker.insertedOps) { + llvm::errs() << "\tGenerated: "; + insertedOp->print(llvm::errs()); + llvm::errs() << "\n"; + } + } else { + llvm::errs() << "Failed to generate load for operation: "; + op->print(llvm::errs()); + llvm::errs() << "\n"; + } +} + +void TestPointerLikeTypeInterfacePass::testGenStore(Operation *op, Value result, + PointerLikeType pointerType, + OpBuilder &builder, + Value providedValue) { + Location loc = op->getLoc(); + + // Create a new builder with the listener and set insertion point + OperationTracker tracker; + OpBuilder newBuilder(op->getContext()); + newBuilder.setListener(&tracker); + newBuilder.setInsertionPointAfter(op); + + // Use provided value if available, otherwise create a constant + Value valueToStore = providedValue; + if (!valueToStore) { + // Create a test value to store - use a constant matching the element type + Type elementType = pointerType.getElementType(); + if (!elementType) { + llvm::errs() << "Failed to generate store for operation: "; + op->print(llvm::errs()); + llvm::errs() << "\n"; + return; + } + + if (elementType.isIntOrIndex()) { + auto attr = newBuilder.getIntegerAttr(elementType, 42); + valueToStore = + arith::ConstantOp::create(newBuilder, loc, elementType, attr); + } else if (auto floatType = dyn_cast<FloatType>(elementType)) { + auto attr = newBuilder.getFloatAttr(floatType, 42.0); + valueToStore = + arith::ConstantOp::create(newBuilder, loc, floatType, attr); + } else { + llvm::errs() << "Failed to generate store for operation: "; + op->print(llvm::errs()); + llvm::errs() << "\n"; + return; + } + } + + // Call the genStore API + auto typedResult = cast<TypedValue<PointerLikeType>>(result); + bool success = + pointerType.genStore(newBuilder, loc, valueToStore, typedResult); + + if (success) { + llvm::errs() << "Successfully generated store for operation: "; + op->print(llvm::errs()); + llvm::errs() << "\n"; + + // Print all operations that were inserted + for (Operation *insertedOp : tracker.insertedOps) { + llvm::errs() << "\tGenerated: "; + insertedOp->print(llvm::errs()); + llvm::errs() << "\n"; + } + } else { + llvm::errs() << "Failed to generate store for operation: "; + op->print(llvm::errs()); + llvm::errs() << "\n"; + } +} + } // namespace //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/OpenACC/TestRecipePopulate.cpp b/mlir/test/lib/Dialect/OpenACC/TestRecipePopulate.cpp index 35f092c..2506ca4 100644 --- a/mlir/test/lib/Dialect/OpenACC/TestRecipePopulate.cpp +++ b/mlir/test/lib/Dialect/OpenACC/TestRecipePopulate.cpp @@ -93,6 +93,29 @@ void TestRecipePopulatePass::runOnOperation() { if (!recipe) { op->emitError("Failed to create firstprivate recipe for ") << varName; } + } else if (recipeType == "private_from_firstprivate") { + // First create a firstprivate recipe, then use it to drive creation of a + // matching private recipe via the convenience overload. Give each recipe + // a stable, predictable name so tests can check both. + std::string firstprivName = "first_firstprivate_" + varName; + std::string privName = "private_from_firstprivate_" + varName; + + auto firstpriv = FirstprivateRecipeOp::createAndPopulate( + builder, loc, firstprivName, var.getType(), varName, bounds); + + if (!firstpriv) { + op->emitError("Failed to create firstprivate recipe for ") << varName; + return; + } + + auto priv = PrivateRecipeOp::createAndPopulate(builder, loc, privName, + *firstpriv); + + if (!priv) { + op->emitError( + "Failed to create private recipe (from firstprivate) for ") + << varName; + } } } } diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp index e21cf94..8689265 100644 --- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp @@ -320,10 +320,10 @@ LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( } //===----------------------------------------------------------------------===// -// OpWithResultShapePerDimInterfaceOp +// ReifyShapedTypeUsingReifyResultShapesOp //===----------------------------------------------------------------------===// -LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( +LogicalResult ReifyShapedTypeUsingReifyResultShapesOp::reifyResultShapes( OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { Location loc = getLoc(); shapes.reserve(getNumOperands()); @@ -345,6 +345,103 @@ LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( } //===----------------------------------------------------------------------===// +// ReifyShapedTypeUsingReifyShapeOfResultOp +//===----------------------------------------------------------------------===// + +LogicalResult ReifyShapedTypeUsingReifyShapeOfResultOp::reifyResultShapes( + OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { + return failure(); +} + +FailureOr<SmallVector<OpFoldResult>> +ReifyShapedTypeUsingReifyShapeOfResultOp::reifyShapeOfResult(OpBuilder &builder, + int resultIndex) { + Location loc = getLoc(); + Value sourceOperand = getOperand(getNumOperands() - 1 - resultIndex); + SmallVector<OpFoldResult> shape = + tensor::getMixedSizes(builder, loc, sourceOperand); + return shape; +} + +//===----------------------------------------------------------------------===// +// ReifyShapedTypeUsingReifyDimOfResultOp +//===----------------------------------------------------------------------===// + +LogicalResult ReifyShapedTypeUsingReifyDimOfResultOp::reifyResultShapes( + OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { + return failure(); +} + +FailureOr<SmallVector<OpFoldResult>> +ReifyShapedTypeUsingReifyDimOfResultOp::reifyShapeOfResult(OpBuilder &builder, + int resultIndex) { + return failure(); +} + +FailureOr<OpFoldResult> +ReifyShapedTypeUsingReifyDimOfResultOp::reifyDimOfResult(OpBuilder &builder, + int resultIndex, + int dim) { + Location loc = getLoc(); + Value sourceOperand = getOperand(getNumOperands() - 1 - resultIndex); + OpFoldResult shape = tensor::getMixedSize(builder, loc, sourceOperand, dim); + return shape; +} + +//===----------------------------------------------------------------------===// +// UnreifableResultShapesOp +//===----------------------------------------------------------------------===// + +LogicalResult UnreifiableResultShapesOp::reifyResultShapes( + OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { + Location loc = getLoc(); + shapes.resize(1); + shapes[0] = {tensor::getMixedSize(builder, loc, getOperand(), 0), + OpFoldResult()}; + return success(); +} + +//===----------------------------------------------------------------------===// +// UnreifableResultShapeOp +//===----------------------------------------------------------------------===// + +LogicalResult UnreifiableResultShapeOp::reifyResultShapes( + OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { + return failure(); +} + +FailureOr<SmallVector<OpFoldResult>> +UnreifiableResultShapeOp::reifyShapeOfResult(OpBuilder &builder, + int resultIndex) { + SmallVector<OpFoldResult> shape = { + tensor::getMixedSize(builder, getLoc(), getOperand(), 0), OpFoldResult()}; + return shape; +} + +//===----------------------------------------------------------------------===// +// UnreifableResultShapeOp +//===----------------------------------------------------------------------===// + +LogicalResult UnreifiableDimOfResultShapeOp::reifyResultShapes( + OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { + return failure(); +} + +FailureOr<SmallVector<OpFoldResult>> +UnreifiableDimOfResultShapeOp::reifyShapeOfResult(OpBuilder &builder, + int resultIndex) { + return failure(); +} + +FailureOr<OpFoldResult> +UnreifiableDimOfResultShapeOp::reifyDimOfResult(OpBuilder &builder, + int resultIndex, int dim) { + if (dim == 0) + return tensor::getMixedSize(builder, getLoc(), getOperand(), 0); + return failure(); +} + +//===----------------------------------------------------------------------===// // SideEffectOp //===----------------------------------------------------------------------===// @@ -1540,3 +1637,14 @@ test::TestCreateTensorOp::getBufferType( return convertTensorToBuffer(getOperation(), options, type); } + +// Define a custom builder for ManyRegionsOp declared in TestOps.td. +// OpBuilder<(ins "::std::unique_ptr<::mlir::Region>":$firstRegion, +// "::std::unique_ptr<::mlir::Region>":$secondRegion)> +void test::ManyRegionsOp::build( + mlir::OpBuilder &builder, mlir::OperationState &state, + llvm::SmallVectorImpl<std::unique_ptr<mlir::Region>> &®ions) { + for (auto &®ionPtr : std::move(regions)) + state.addRegion(std::move(regionPtr)); + ManyRegionsOp::build(builder, state, {}, regions.size()); +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.h b/mlir/test/lib/Dialect/Test/TestOps.h index 4201ade..6792743 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.h +++ b/mlir/test/lib/Dialect/Test/TestOps.h @@ -42,6 +42,7 @@ #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Interfaces/ViewLikeInterface.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" namespace test { class TestDialect; diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 620d950..5417ae9 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -120,6 +120,13 @@ def SymbolOp : TEST_Op<"symbol", [NoMemoryEffect, Symbol]> { OptionalAttr<StrAttr>:$sym_visibility); } +def SymbolWithResultOp : TEST_Op<"symbol_with_result", [Symbol]> { + let summary = "invalid symbol operation that produces an SSA result"; + let arguments = (ins StrAttr:$sym_name, + OptionalAttr<StrAttr>:$sym_visibility); + let results = (outs AnyType:$result); +} + def OverriddenSymbolVisibilityOp : TEST_Op<"overridden_symbol_visibility", [ DeclareOpInterfaceMethods<Symbol, ["getVisibility", "setVisibility"]>, ]> { @@ -915,13 +922,97 @@ def OpWithResultShapeInterfaceOp : TEST_Op<"op_with_result_shape_interface", let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2); } -def OpWithResultShapePerDimInterfaceOp : - TEST_Op<"op_with_result_shape_per_dim_interface", - [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> { +def ReifyShapedTypeUsingReifyResultShapesOp : + TEST_Op<"reify_shaped_type_using_reify_result_shapes", + [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, + ["reifyResultShapes"]>]> { + let description = [{ + Test that when resolving a single dimension of a result for an operation + that doesnt implement `reifyShapeOfResult` nor implements `reifyDimOfResult` + calls into the implementation of `reifyResultShapes` to get the required value. + The op semantics is that the first result has the same shape as the second operand + and the second result has the same shape as the first operand. + }]; let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2); let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2); } +def ReifyShapedTypeUsingReifyShapeOfResultOp : + TEST_Op<"reify_shaped_type_using_reify_shape_of_result", + [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, + ["reifyResultShapes", "reifyShapeOfResult"]>]> { + let description = [{ + Test that when resolving a single dimension of a result for an operation + that doesnt implement `reifyDimOfResult` but implements `reifyShapeOfResult`, which + is used to get the required value. `reifyResultShapes` is implemented as a failure + (which is also the default implementation) to ensure it is not called. + The op semantics is that the first result has the same shape as the second operand + and the second result has the same shape as the first operand. + }]; + let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2); + let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2); +} + +def ReifyShapedTypeUsingReifyDimOfResultOp : + TEST_Op<"reify_shaped_type_using_reify_dim_of_result", + [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, + ["reifyResultShapes", "reifyShapeOfResult", "reifyDimOfResult"]>]> { + let description = [{ + Test that when resolving a single dimension of a result for an operation + that implements `reifyDimOfResult`, which is used to get the required value. + `reifyResultShapes` and `reifyShapeOfResult` are implemented as failures + to ensure they are not called. The op semantics is that the first result has + the same shape as the second operand and the second result has the same shape + as the first operand. + }]; + let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2); + let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2); +} + +def UnreifiableResultShapesOp : TEST_Op<"unreifiable_result_shapes", + [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, + ["reifyResultShapes"]>]> { + let description = [{ + Test handling of case where some dimension of the result cannot be + reified. This tests the path when `reifyResultShapes` is implemented. + + Expected that dim 0 of `result` is reifable as dim 0 of `operand`, but + dim 1 of `result` is not reifiable. + }]; + let arguments = (ins 2DTensorOf<[AnyType]>:$operand); + let results = (outs 2DTensorOf<[AnyType]>:$result); +} + +def UnreifiableResultShapeOp : TEST_Op<"unreifiable_result_shape", + [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, + ["reifyResultShapes", "reifyShapeOfResult"]>]> { + let description = [{ + Test handling of case where some dimension of the result cannot be + reified. This tests the path when `reifyShapeOfResult` is implemented, + but not `reifyDimOfResult` with `reifyResultShapes` implemented as a failure. + + Expected that dim 0 of `result` is reifable as dim 0 of `operand`, but + dim 1 of `result` is not reifiable. + }]; + let arguments = (ins 2DTensorOf<[AnyType]>:$operand); + let results = (outs 2DTensorOf<[AnyType]>:$result); +} + +def UnreifiableDimOfResultShapeOp : TEST_Op<"unreifiable_dim_of_result_shape", + [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, + ["reifyResultShapes", "reifyShapeOfResult", "reifyDimOfResult"]>]> { + let description = [{ + Test handling of case where some dimension of the result cannot be + reified. This tests the path when `reifyDimOfResult` is implemented, + and `reifyDimOfResult` with `reifyResultShapes` are implemented as a failure. + + Expected that dim 0 of `result` is reifable as dim 0 of `operand`, but + dim 1 of `result` is not reifiable. + }]; + let arguments = (ins 2DTensorOf<[AnyType]>:$operand); + let results = (outs 2DTensorOf<[AnyType]>:$result); +} + def IsNotScalar : Constraint<CPred<"$0.getType().getRank() != 0">>; def UpdateAttr : Pat<(I32ElementsAttrOp $attr), @@ -1108,6 +1199,12 @@ def TestLocationDstNoResOp : TEST_Op<"loc_dst_no_res"> { let results = (outs); } +def TestLocationAttrOp : TEST_Op<"op_with_loc_attr"> { + let arguments = (ins LocationAttr:$loc_attr); + let results = (outs ); + let assemblyFormat = "$loc_attr attr-dict"; +} + //===----------------------------------------------------------------------===// // Test Patterns //===----------------------------------------------------------------------===// @@ -2255,6 +2352,24 @@ def IsolatedGraphRegionOp : TEST_Op<"isolated_graph_region", [ let assemblyFormat = "attr-dict-with-keyword $region"; } +def ManyRegionsOp : TEST_Op<"many_regions", []> { + let summary = "operation created with move-only objects"; + let description = [{ + Test op with multiple regions with a `create` function that + takes parameters containing move-only objects. + }]; + + let regions = (region VariadicRegion<AnyRegion>:$regions); + let builders = + [OpBuilder<(ins "::std::unique_ptr<::mlir::Region>":$singleRegion), [{ + $_state.addRegion(std::move(singleRegion)); + build($_builder, $_state, {}, /*regionsCount=*/1); + }]>, + // Define in TestOps.cpp. + OpBuilder<(ins "::llvm::SmallVectorImpl<::std::unique_ptr<::mlir::" + "Region>>&&":$regions)>]; +} + def AffineScopeOp : TEST_Op<"affine_scope", [AffineScope]> { let summary = "affine scope operation"; let description = [{ diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 9b64bc6..7eabaae 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -977,7 +977,13 @@ struct TestValueReplace : public ConversionPattern { // Replace the first operand with 2x the second operand. Value from = op->getOperand(0); Value repl = op->getOperand(1); - rewriter.replaceAllUsesWith(from, {repl, repl}); + if (op->hasAttr("conditional")) { + rewriter.replaceUsesWithIf(from, {repl, repl}, [=](OpOperand &use) { + return use.getOwner()->hasAttr("replace_uses"); + }); + } else { + rewriter.replaceAllUsesWith(from, {repl, repl}); + } rewriter.modifyOpInPlace(op, [&] { // If the "trigger_rollback" attribute is set, keep the op illegal, so // that a rollback is triggered. diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp index 614121f..9cf64a8 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -569,11 +569,17 @@ TestTensorType::getBufferType( ::mlir::LogicalResult TestTensorType::verifyCompatibleBufferType( ::mlir::bufferization::BufferLikeType bufferType, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { - auto testMemref = dyn_cast<TestMemrefType>(bufferType); - if (!testMemref) - return emitError() << "expected TestMemrefType"; + if (auto testMemref = dyn_cast<TestMemrefType>(bufferType)) { + const bool valid = getShape() == testMemref.getShape() && + getElementType() == testMemref.getElementType(); + return mlir::success(valid); + } + + if (auto builtinMemref = dyn_cast<MemRefType>(bufferType)) { + const bool valid = getShape() == builtinMemref.getShape() && + getElementType() == builtinMemref.getElementType(); + return mlir::success(valid); + } - const bool valid = getShape() == testMemref.getShape() && - getElementType() == testMemref.getElementType(); - return mlir::success(valid); + return emitError() << "expected MemRefType or TestMemrefType"; } diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 79bfc9bb..f834d0c 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -180,6 +180,34 @@ struct TestVectorUnrollingPatterns })); populateVectorUnrollPatterns( patterns, UnrollVectorOptions() + .setNativeShape(ArrayRef<int64_t>{8, 8}) + .setFilterConstraint([](Operation *op) { + return success(isa<vector::CreateMaskOp>(op)); + })); + populateVectorUnrollPatterns( + patterns, + UnrollVectorOptions() + .setNativeShapeFn( + [](Operation *op) -> std::optional<SmallVector<int64_t>> { + auto shapeCast = dyn_cast<vector::ShapeCastOp>(op); + if (!shapeCast) + return std::nullopt; + + auto resultShape = shapeCast.getResultVectorType().getShape(); + // Special case with leading unit dims and different inner dim + // for result and target shape. + if (resultShape.size() == 2 && resultShape[0] == 1 && + resultShape[1] == 32) { + return SmallVector<int64_t>{1, 16}; + } + // Default case: [2,4] for all tests. + return SmallVector<int64_t>{2, 4}; + }) + .setFilterConstraint([](Operation *op) { + return success(isa<vector::ShapeCastOp>(op)); + })); + populateVectorUnrollPatterns( + patterns, UnrollVectorOptions() .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2}) .setFilterConstraint([](Operation *op) { return success(isa<vector::TransposeOp>(op)); diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp index 326fec3..583d68b 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Transform/IR/TransformAttrs.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" @@ -170,9 +171,71 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter, // TestFuseConsumerOp //===----------------------------------------------------------------------===// +/// Fuse the consumer and store both the original consumer operation as well as +/// the fused consumer operation. +static LogicalResult +applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp, + Operation *consumer, + MutableArrayRef<LoopLikeOpInterface> loops, + TransformResults &transformResults) { + SmallVector<Operation *> fusedConsumerOps; + rewriter.setInsertionPoint(consumer); + + FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults = + scf::tileAndFuseConsumer(rewriter, consumer, loops); + if (failed(fuseConsumerResults)) + return consumer->emitOpError("failed to fuse consumer of slice"); + + // Report back the relevant handles to the transform op. + for (OpOperand *tiledAndFusedConsumerOperand : + fuseConsumerResults->tiledAndFusedConsumerOperands) { + fusedConsumerOps.push_back(tiledAndFusedConsumerOperand->getOwner()); + } + transformResults.set(transformOp->getOpResult(0), fusedConsumerOps); + for (auto [index, loop] : llvm::enumerate(loops)) { + transformResults.set(transformOp->getOpResult(index + 1), {loop}); + } + return success(); +} + +DiagnosedSilenceableFailure +transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter, + TransformResults &transformResults, + TransformState &state) { + Operation *consumer = *state.getPayloadOps(getConsumer()).begin(); + + SmallVector<LoopLikeOpInterface> loops; + // Since the matcher works inside-out, we need to iterate the loops in + // reverse. + for (auto loop : llvm::reverse(getLoops())) { + auto loopLikeOp = + dyn_cast<LoopLikeOpInterface>(*state.getPayloadOps(loop).begin()); + if (!loopLikeOp) { + return DiagnosedSilenceableFailure::definiteFailure(); + } + loops.push_back(loopLikeOp); + } + LogicalResult result = applyFuseConsumer(rewriter, getOperation(), consumer, + loops, transformResults); + return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() + : DiagnosedSilenceableFailure::success(); +} + +void transform::TestFuseConsumerOp::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + consumesHandle(getConsumerMutable(), effects); + consumesHandle(getLoopsMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + modifiesPayload(effects); +} + +//===----------------------------------------------------------------------===// +// TestFuseConsumerUsingSliceOp +//===----------------------------------------------------------------------===// + /// Apply fusing of consumer transformation to all payload ops and store both /// the original consumer operation as well as the fused consumer operation. -static LogicalResult applyFuseConsumer( +static LogicalResult applyFuseConsumerUsingSlices( RewriterBase &rewriter, Operation *transformOp, ArrayRef<Operation *> slices, MutableArrayRef<LoopLikeOpInterface> loops, uint32_t numConsumerToFuse, TransformResults &transformResults) { @@ -204,10 +267,9 @@ static LogicalResult applyFuseConsumer( return success(); } -DiagnosedSilenceableFailure -transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter, - TransformResults &transformResults, - TransformState &state) { +DiagnosedSilenceableFailure transform::TestFuseConsumerUsingSliceOp::apply( + TransformRewriter &rewriter, TransformResults &transformResults, + TransformState &state) { SmallVector<Operation *> slices; for (auto op : getTargets()) { auto sliceOp = *state.getPayloadOps(op).begin(); @@ -224,13 +286,13 @@ transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter, loops.push_back(loopLikeOp); } LogicalResult result = - applyFuseConsumer(rewriter, getOperation(), slices, loops, - getNumConsumerToFuse(), transformResults); + applyFuseConsumerUsingSlices(rewriter, getOperation(), slices, loops, + getNumConsumerToFuse(), transformResults); return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() : DiagnosedSilenceableFailure::success(); } -void transform::TestFuseConsumerOp::getEffects( +void transform::TestFuseConsumerUsingSliceOp::getEffects( SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { consumesHandle(getTargetsMutable(), effects); consumesHandle(getLoopsMutable(), effects); @@ -622,6 +684,110 @@ DiagnosedSilenceableFailure transform::TestTileUsingCustomLoopOp::apply( return DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// TestQueryProducerFusability +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::TestQueryProducerFusability::apply( + TransformRewriter &rewriter, TransformResults &transformResults, + TransformState &state) { + for (Operation *target : state.getPayloadOps(getTarget())) { + auto tilingInterfaceOp = dyn_cast<TilingInterface>(target); + if (!tilingInterfaceOp) { + return emitSilenceableError() + << "target operation does not implement TilingInterface"; + } + + // Collect operand numbers and their corresponding producer insert_slice + // offsets and sizes. + SmallVector<unsigned> operandNumbers; + SmallVector<SmallVector<OpFoldResult>> allOffsets; + SmallVector<SmallVector<OpFoldResult>> allSizes; + + for (OpOperand &operand : target->getOpOperands()) { + Value operandValue = operand.get(); + Operation *definingOp = operandValue.getDefiningOp(); + + // Look for a producer tensor.insert_slice. This is only for testing + // purposes and otherwise is not a useful transformation. + if (auto insertSliceOp = + dyn_cast_or_null<tensor::InsertSliceOp>(definingOp)) { + operandNumbers.push_back(operand.getOperandNumber()); + allOffsets.push_back(insertSliceOp.getMixedOffsets()); + allSizes.push_back(insertSliceOp.getMixedSizes()); + } + } + + if (!operandNumbers.empty()) { + bool isFusable = tilingInterfaceOp.isOpFusableWithProducerSlices( + operandNumbers, allOffsets, allSizes); + + if (isFusable) { + target->emitRemark() + << "can be fused with producer tensor.insert_slice ops"; + } else { + target->emitRemark() + << "cannot be fused with producer tensor.insert_slice ops"; + } + } + } + + return DiagnosedSilenceableFailure::success(); +} + +void transform::TestQueryProducerFusability::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + onlyReadsHandle(getTargetMutable(), effects); + onlyReadsPayload(effects); +} + +//===----------------------------------------------------------------------===// +// TestQueryConsumerFusability +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::TestQueryConsumerFusability::apply( + TransformRewriter &rewriter, TransformResults &transformResults, + TransformState &state) { + for (Operation *target : state.getPayloadOps(getTarget())) { + auto tilingInterfaceOp = dyn_cast<TilingInterface>(target); + if (!tilingInterfaceOp) { + return emitSilenceableError() + << "target operation does not implement TilingInterface"; + } + + // Look for tensor.extract_slice ops that consume results of the tilable op. + for (OpResult result : target->getResults()) { + for (OpOperand &use : result.getUses()) { + Operation *user = use.getOwner(); + + // Look for a consumer tensor.extract_slice. This is only for testing + // purposes and otherwise is not a useful transformation. + if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user)) { + bool isFusable = tilingInterfaceOp.isOpFusableWithConsumerSlice( + result.getResultNumber(), extractSliceOp.getMixedOffsets(), + extractSliceOp.getMixedSizes()); + + if (isFusable) { + target->emitRemark() + << "can be fused with consumer tensor.extract_slice op"; + } else { + target->emitRemark() + << "cannot be fused with consumer tensor.extract_slice op"; + } + } + } + } + } + + return DiagnosedSilenceableFailure::success(); +} + +void transform::TestQueryConsumerFusability::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + onlyReadsHandle(getTargetMutable(), effects); + onlyReadsPayload(effects); +} + #define GET_OP_CLASSES #include "TestTilingInterfaceTransformOps.cpp.inc" diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td index 694c422..8c4f64d 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td @@ -49,14 +49,19 @@ def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield", }]; } -def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer", +def TestFuseConsumerUsingSliceOp : Op<Transform_Dialect, "test.fuse_consumer_using_slice", [AttrSizedOperandSegments, DeclareOpInterfaceMethods<TransformOpInterface>, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, ReportTrackingListenerFailuresOpTrait]> { let description = [{ - Fuses the consumer of the operation pointed to by the target handle - using the options provided as attributes. + For the `insert_slice`-like operations (that are typically generated through tiling), + within the loop nests passed in as `loops` (that are typically generated through tiling), + find the consumer that these slices map to (have to be the same consumer) and fuse + the consumer into the loop. + + Returns a handle to the original consumer operation and the consumer operation after + fusion. }]; let arguments = (ins @@ -73,6 +78,32 @@ def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer", }]; } +def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer", + [DeclareOpInterfaceMethods<TransformOpInterface>, + DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, + ReportTrackingListenerFailuresOpTrait]> { + let description = [{ + For the `consumer` that uses the result of the outer-most loop of a loop nest passed in + as `loops` (that are typically generated through tiling), fuse the consumer into the + loop. + + Returns a handle to the consumer operation after fusion and the loops that might be + modified. + }]; + + let arguments = (ins + TransformHandleTypeInterface:$consumer, + Variadic<TransformHandleTypeInterface>:$loops); + let results = (outs TransformHandleTypeInterface:$fused_consumer, + Variadic<TransformHandleTypeInterface>:$result_loops); + + let assemblyFormat = [{ + $consumer `into` `(` $loops `)` + attr-dict `:` functional-type(operands, results) + }]; +} + + def TestTileUsingForallOp : Op<Transform_Dialect, "test.tile_using_forall", [DeclareOpInterfaceMethods<TransformOpInterface>, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, @@ -166,11 +197,55 @@ def TestTileUsingCustomLoopOp : Op< DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes); let results = (outs TransformHandleTypeInterface:$tiled_ops, Variadic<TransformHandleTypeInterface>:$loops); - + let assemblyFormat = [{ $root_op `tile_sizes` `=` $tile_sizes attr-dict `:` functional-type(operands, results) }]; } +def TestQueryProducerFusability : Op< + Transform_Dialect, "test.query_producer_fusability", + [DeclareOpInterfaceMethods<TransformOpInterface>, + DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> { + let description = [{ + Test operation for the producer fusability query method in the + TilingInterface. + + For each operation in the target handle, this looks for tensor.insert_slice + ops that produce operands to the tilable op. The offset/sizes from those + inserts is used as the arguments to `isOpFusableWithProducerSlices` and + emits a remark with the result of the query. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs); + + let assemblyFormat = [{ + $target attr-dict `:` type($target) + }]; +} + +def TestQueryConsumerFusability + : Op<Transform_Dialect, "test.query_consumer_fusability", + [DeclareOpInterfaceMethods<TransformOpInterface>, + DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> { + let description = [{ + Test operation for the consumer fusability query method in the + TilingInterface. + + For each operation in the target handle, this looks for tensor.extract_slice + ops that consume results of the tilable op. The offset/sizes from those + extracts is used as the arguments to `isOpFusableWithConsumerSlice` and + emits a remark with the result of the query. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs); + + let assemblyFormat = [{ + $target attr-dict `:` type($target) + }]; +} + #endif // TEST_TILINGINTERFACE_TRANSFORM_OPS diff --git a/mlir/test/lib/Transforms/TestTransformsOps.td b/mlir/test/lib/Transforms/TestTransformsOps.td index 9b0a260..bc53b23 100644 --- a/mlir/test/lib/Transforms/TestTransformsOps.td +++ b/mlir/test/lib/Transforms/TestTransformsOps.td @@ -44,7 +44,9 @@ def TestMoveValueDefns : DeclareOpInterfaceMethods<TransformOpInterface>, ReportTrackingListenerFailuresOpTrait]> { let description = [{ - Moves all dependencies of on operation before another operation. + Moves all dependencies of a list of values before another operation. + Only pure operations are moved. If there is a side effecting op in the + dependency chain no operations are moved. }]; let arguments = diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py index 6ff12d6..675ded3 100644 --- a/mlir/test/lit.cfg.py +++ b/mlir/test/lit.cfg.py @@ -44,7 +44,7 @@ config.suffixes = [ ".test", ".pdll", ".c", - ".spv", + ".spvasm", ] # test_source_root: The root path where tests are located. @@ -214,6 +214,11 @@ tools = [ "not", ] +if "Linux" in config.host_os: + # TODO: Run only on Linux until we figure out how to build + # mlir_apfloat_wrappers in a platform-independent way. + tools.extend([add_runtime("mlir_apfloat_wrappers")]) + if config.enable_vulkan_runner: tools.extend([add_runtime("mlir_vulkan_runtime")]) diff --git a/mlir/test/mlir-tblgen/dialect-interface.td b/mlir/test/mlir-tblgen/dialect-interface.td new file mode 100644 index 0000000..ff39fd9 --- /dev/null +++ b/mlir/test/mlir-tblgen/dialect-interface.td @@ -0,0 +1,65 @@ +// RUN: mlir-tblgen -gen-dialect-interface-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL + +include "mlir/IR/Interfaces.td" + +def NoDefaultMethod : DialectInterface<"NoDefaultMethod"> { + let description = [{ + This is an example dialect interface without default method body. + }]; + + let cppNamespace = "::mlir::example"; + + let methods = [ + InterfaceMethod< + /*desc=*/ "Check if it's an example dialect", + /*returnType=*/ "bool", + /*methodName=*/ "isExampleDialect", + /*args=*/ (ins) + >, + InterfaceMethod< + /*desc=*/ "second method to check if multiple methods supported", + /*returnType=*/ "unsigned", + /*methodName=*/ "supportSecondMethod", + /*args=*/ (ins "::mlir::Type":$type) + > + + ]; +} + +// DECL: class NoDefaultMethod : public {{.*}}DialectInterface::Base<NoDefaultMethod> +// DECL: public: +// DECL-NEXT: NoDefaultMethod(::mlir::Dialect *dialect) : Base(dialect) {} +// DECL: virtual bool isExampleDialect() const {} +// DECL: virtual unsigned supportSecondMethod(::mlir::Type type) const {} + +def WithDefaultMethodInterface : DialectInterface<"WithDefaultMethodInterface"> { + let description = [{ + This is an example dialect interface with default method bodies. + }]; + + let cppNamespace = "::mlir::example"; + + let methods = [ + InterfaceMethod< + /*desc=*/ "Check if it's an example dialect", + /*returnType=*/ "bool", + /*methodName=*/ "isExampleDialect", + /*args=*/ (ins), + /*methodBody=*/ [{ + return true; + }] + >, + InterfaceMethod< + /*desc=*/ "second method to check if multiple methods supported", + /*returnType=*/ "unsigned", + /*methodName=*/ "supportSecondMethod", + /*args=*/ (ins "::mlir::Type":$type) + > + + ]; +} + +// DECL: virtual bool isExampleDialect() const { +// DECL-NEXT: return true; +// DECL-NEXT: } + diff --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td index 0e87373..80dedb84 100644 --- a/mlir/test/mlir-tblgen/op-decl-and-defs.td +++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td @@ -235,14 +235,14 @@ def NS_FOp : NS_Op<"op_with_all_types_constraint", // DEFS: FOp FOp::create(::mlir::OpBuilder &builder, ::mlir::Location location, ::mlir::Value a) { // DEFS: ::mlir::OperationState __state__(location, getOperationName()); -// DEFS: build(builder, __state__, a); +// DEFS: build(builder, __state__, std::forward<decltype(a)>(a)); // DEFS: auto __res__ = ::llvm::dyn_cast<FOp>(builder.create(__state__)); // DEFS: assert(__res__ && "builder didn't return the right type"); // DEFS: return __res__; // DEFS: } // DEFS: FOp FOp::create(::mlir::ImplicitLocOpBuilder &builder, ::mlir::Value a) { -// DEFS: return create(builder, builder.getLoc(), a); +// DEFS: return create(builder, builder.getLoc(), std::forward<decltype(a)>(a)); // DEFS: } def NS_GOp : NS_Op<"op_with_fixed_return_type", []> { diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td index 42de7e4..ff16ad8 100644 --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -350,16 +350,16 @@ def MissingNamesOp : TestOp<"missing_names"> { // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip) // CHECK: @builtins.property - // CHECK: def f32(self) -> _ods_ir.Value: + // CHECK: def f32(self) -> _ods_ir.Value[_ods_ir.FloatType]: // CHECK: return self.operation.operands[1] let arguments = (ins I32, F32:$f32, I64); // CHECK: @builtins.property - // CHECK: def i32(self) -> _ods_ir.OpResult: + // CHECK: def i32(self) -> _ods_ir.OpResult[_ods_ir.IntegerType]: // CHECK: return self.operation.results[0] // // CHECK: @builtins.property - // CHECK: def i64(self) -> _ods_ir.OpResult: + // CHECK: def i64(self) -> _ods_ir.OpResult[_ods_ir.IntegerType]: // CHECK: return self.operation.results[2] let results = (outs I32:$i32, AnyFloat, I64:$i64); } @@ -590,20 +590,20 @@ def SimpleOp : TestOp<"simple"> { // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip) // CHECK: @builtins.property - // CHECK: def i32(self) -> _ods_ir.Value: + // CHECK: def i32(self) -> _ods_ir.Value[_ods_ir.IntegerType]: // CHECK: return self.operation.operands[0] // // CHECK: @builtins.property - // CHECK: def f32(self) -> _ods_ir.Value: + // CHECK: def f32(self) -> _ods_ir.Value[_ods_ir.FloatType]: // CHECK: return self.operation.operands[1] let arguments = (ins I32:$i32, F32:$f32); // CHECK: @builtins.property - // CHECK: def i64(self) -> _ods_ir.OpResult: + // CHECK: def i64(self) -> _ods_ir.OpResult[_ods_ir.IntegerType]: // CHECK: return self.operation.results[0] // // CHECK: @builtins.property - // CHECK: def f64(self) -> _ods_ir.OpResult: + // CHECK: def f64(self) -> _ods_ir.OpResult[_ods_ir.FloatType]: // CHECK: return self.operation.results[1] let results = (outs I64:$i64, AnyFloat:$f64); } diff --git a/mlir/test/python/dialects/gpu/dialect.py b/mlir/test/python/dialects/gpu/dialect.py index 3945c99..1a009b7 100644 --- a/mlir/test/python/dialects/gpu/dialect.py +++ b/mlir/test/python/dialects/gpu/dialect.py @@ -133,9 +133,10 @@ def testGPUFuncOp(): ), func.known_grid_size func = gpu.GPUFuncOp( - func_type, + ir.FunctionType.get(inputs=[T.index()], results=[]), sym_name="non_kernel_func", body_builder=builder, + arg_attrs=[{"gpu.some_attribute": ir.StringAttr.get("foo")}], ) assert not func.is_kernel assert func.known_block_size is None @@ -154,10 +155,11 @@ def testGPUFuncOp(): # CHECK: %[[VAL_0:.*]] = gpu.global_id x # CHECK: gpu.return # CHECK: } - # CHECK: gpu.func @non_kernel_func() { - # CHECK: %[[VAL_0:.*]] = gpu.global_id x - # CHECK: gpu.return - # CHECK: } + # CHECK: gpu.func @non_kernel_func( + # CHECK-SAME: %[[ARG0:.*]]: index {gpu.some_attribute = "foo"}) { + # CHECK: %[[GLOBAL_ID_0:.*]] = gpu.global_id x + # CHECK: gpu.return + # CHECK: } # CHECK-LABEL: testGPULaunchFuncOp diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py index 709a1d2..92591cd 100644 --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -1,7 +1,8 @@ # RUN: %PYTHON %s | FileCheck %s -from mlir.dialects import arith, func, linalg, tensor, memref +from mlir.dialects import arith, func, linalg, tensor, memref, builtin from mlir.dialects.linalg.opdsl.lang import * +from mlir.extras import types as T from mlir.ir import * @@ -857,3 +858,76 @@ def testElementwiseOp(): ) print(module) + + +@run +def testReduceOp(): + with Context(), Location.unknown(): + f32 = T.f32() + tensor_type = T.tensor(10, f32) + + @builtin.module + def module(): + @func.func(tensor_type) + def reduce_op(input): + c1 = arith.constant(f32, 1.0) + single_result = ir.RankedTensorType.get((), f32) + dims = ir.DenseI64ArrayAttr.get([0]) + init = tensor.splat(single_result, c1, []) + + @linalg.reduce( + result=[single_result], + inputs=[input], + inits=[init], + dimensions=dims, + ) + def reduced(element: f32, acc: f32): + return arith.mulf(acc, element) + + return tensor.extract(reduced, []) + + print(module) + + +# CHECK-LABEL: func.func @reduce_op( +# CHECK-SAME: %[[ARG0:.*]]: tensor<10xf32>) -> f32 { +# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1.000000e+00 : f32 +# CHECK: %[[SPLAT_0:.*]] = tensor.splat %[[CONSTANT_0]] : tensor<f32> +# CHECK: %[[REDUCE_0:.*]] = linalg.reduce { arith.mulf } ins(%[[ARG0]] : tensor<10xf32>) outs(%[[SPLAT_0]] : tensor<f32>) dimensions = [0] +# CHECK: %[[EXTRACT_0:.*]] = tensor.extract %[[REDUCE_0]][] : tensor<f32> +# CHECK: return %[[EXTRACT_0]] : f32 +# CHECK: } + + +@run +def testMapOp(): + with Context(), Location.unknown(): + f32 = T.f32() + tensor_type = T.tensor(10, f32) + + @builtin.module + def module(): + @func.func(tensor_type) + def map_op(input): + empty = tensor.empty(tensor_type.shape, f32) + + @linalg.map( + result=[tensor_type], + inputs=[input, input], + init=empty, + ) + def add(element: f32, acc: f32, init: f32): + return arith.addf(element, acc) + + return add + + module.verify() + print(module) + + +# CHECK-LABEL: func.func @map_op( +# CHECK-SAME: %[[ARG0:.*]]: tensor<10xf32>) -> tensor<10xf32> { +# CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<10xf32> +# CHECK: %[[MAP_0:.*]] = linalg.map { arith.addf } ins(%[[ARG0]], %[[ARG0]] : tensor<10xf32>, tensor<10xf32>) outs(%[[EMPTY_0]] : tensor<10xf32>) +# CHECK: return %[[MAP_0]] : tensor<10xf32> +# CHECK: } diff --git a/mlir/test/python/dialects/linalg/utils.py b/mlir/test/python/dialects/linalg/utils.py index 5f7cb6a..8ab53b4 100644 --- a/mlir/test/python/dialects/linalg/utils.py +++ b/mlir/test/python/dialects/linalg/utils.py @@ -208,3 +208,43 @@ def test_get_indexing_maps_attr(): assert maps[0] == a_map assert maps[1] == b_map assert maps[2] == c_map + + +@run +def test_infer_contraction_dimensions_from_maps(): + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + # === Test valid contraction (matmul) === + dim_m = AffineDimExpr.get(0) + dim_n = AffineDimExpr.get(1) + dim_k = AffineDimExpr.get(2) + a_map = AffineMap.get(3, 0, [dim_m, dim_k]) + b_map = AffineMap.get(3, 0, [dim_k, dim_n]) + c_map = AffineMap.get(3, 0, [dim_m, dim_n]) + + dims = linalg.infer_contraction_dimensions_from_maps([a_map, b_map, c_map]) + assert dims is not None + + # Expect m=[0], n=[1], k=[2] as per standard matmul. + assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}" + assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}" + assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}" + assert list(dims.batch) == [], f"Expected batch=[], got {list(dims.batch)}" + + # === Test invalid input (wrong number of maps) === + invalid_dims = linalg.infer_contraction_dimensions_from_maps([a_map, b_map]) + assert invalid_dims is None + + # === Test element-wise operation === + dim_i = AffineDimExpr.get(0) + dim_j = AffineDimExpr.get(1) + elementwise_map = AffineMap.get(2, 0, [dim_i, dim_j]) + elementwise_dims = linalg.infer_contraction_dimensions_from_maps( + [elementwise_map, elementwise_map, elementwise_map] + ) + assert elementwise_dims is not None + assert len(elementwise_dims.m) == 0 + assert len(elementwise_dims.n) == 0 + assert len(elementwise_dims.k) == 0 + assert list(elementwise_dims.batch) == [0, 1] diff --git a/mlir/test/python/dialects/llvm.py b/mlir/test/python/dialects/llvm.py index 8ea0fdd..305ed9a 100644 --- a/mlir/test/python/dialects/llvm.py +++ b/mlir/test/python/dialects/llvm.py @@ -98,6 +98,9 @@ def testStructType(): assert opaque.opaque # CHECK: !llvm.struct<"opaque", opaque> + typ = Type.parse('!llvm.struct<"zoo", (i32, i64)>') + assert isinstance(typ, llvm.StructType) + # CHECK-LABEL: testSmoke @constructAndPrintInModule @@ -120,6 +123,9 @@ def testPointerType(): # CHECK: !llvm.ptr<1> print(ptr_with_addr) + typ = Type.parse("!llvm.ptr<1>") + assert isinstance(typ, llvm.PointerType) + # CHECK-LABEL: testConstant @constructAndPrintInModule diff --git a/mlir/test/python/dialects/nvvm.py b/mlir/test/python/dialects/nvvm.py index 3eb62be..d795524 100644 --- a/mlir/test/python/dialects/nvvm.py +++ b/mlir/test/python/dialects/nvvm.py @@ -15,7 +15,9 @@ def constructAndPrintInModule(f): module = Module.create() with InsertionPoint(module.body): f() + print(module) + module.operation.verify() return f @@ -89,3 +91,133 @@ def test_inline_ptx(): arith.addf(a, b) arith.addi(c, d) arith.addf(wo0, wo1) + + +@constructAndPrintInModule +def test_barriers(): + i32 = T.i32() + f32 = T.f32() + + @func.FuncOp.from_py_func(i32, i32, f32) + def barriers(mask, vi32, vf32): + c0 = arith.constant(T.i32(), 0) + cffff = arith.constant(T.i32(), 0xFFFF) + res = nvvm.barrier( + res=i32, + barrier_id=c0, + number_of_threads=cffff, + ) + + for reduction in ( + nvvm.BarrierReduction.AND, + nvvm.BarrierReduction.OR, + nvvm.BarrierReduction.POPC, + ): + res = nvvm.barrier( + res=i32, + reduction_op=reduction, + reduction_predicate=res, + ) + + nvvm.barrier0() + nvvm.bar_warp_sync(mask) + nvvm.cluster_arrive() + nvvm.cluster_arrive(aligned=True) + nvvm.cluster_arrive_relaxed() + nvvm.cluster_arrive_relaxed(aligned=True) + nvvm.cluster_wait() + nvvm.cluster_wait(aligned=True) + nvvm.fence_mbarrier_init() + nvvm.bar_warp_sync(mask) + return res + + +# CHECK-LABEL: func.func @barriers( +# CHECK: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: f32) -> i32 { +# CHECK: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 +# CHECK: %[[CONSTANT_1:.*]] = arith.constant 65535 : i32 +# CHECK: %[[BARRIER_0:.*]] = nvvm.barrier id = %[[CONSTANT_0]] number_of_threads = %[[CONSTANT_1]] -> i32 +# CHECK: %[[BARRIER_1:.*]] = nvvm.barrier #nvvm.reduction<and> %[[BARRIER_0]] -> i32 +# CHECK: %[[BARRIER_2:.*]] = nvvm.barrier #nvvm.reduction<or> %[[BARRIER_1]] -> i32 +# CHECK: %[[BARRIER_3:.*]] = nvvm.barrier #nvvm.reduction<popc> %[[BARRIER_2]] -> i32 +# CHECK: nvvm.barrier0 +# CHECK: nvvm.bar.warp.sync %[[ARG0]] : i32 +# CHECK: nvvm.cluster.arrive +# CHECK: nvvm.cluster.arrive {aligned} +# CHECK: nvvm.cluster.arrive.relaxed +# CHECK: nvvm.cluster.arrive.relaxed {aligned} +# CHECK: nvvm.cluster.wait +# CHECK: nvvm.cluster.wait {aligned} +# CHECK: nvvm.fence.mbarrier.init +# CHECK: nvvm.bar.warp.sync %[[ARG0]] : i32 +# CHECK: return %[[BARRIER_3]] : i32 +# CHECK: } + + +@constructAndPrintInModule +def test_reductions(): + i32 = T.i32() + f32 = T.f32() + + @func.FuncOp.from_py_func(i32, i32, f32) + def reductions(mask, vi32, vf32): + for abs in (True, False): + for nan in (True, False): + for kind in ( + nvvm.ReduxKind.AND, + nvvm.ReduxKind.MAX, + nvvm.ReduxKind.MIN, + nvvm.ReduxKind.OR, + nvvm.ReduxKind.UMAX, + nvvm.ReduxKind.UMIN, + nvvm.ReduxKind.XOR, + ): + nvvm.redux_sync(i32, vi32, kind, vi32) + + for kind in ( + nvvm.ReduxKind.FMIN, + nvvm.ReduxKind.FMAX, + ): + nvvm.redux_sync(f32, vf32, kind, vi32, abs=abs, nan=nan) + + +# CHECK-LABEL: func.func @reductions( +# CHECK: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: f32) { +# CHECK: %[[REDUX_0:.*]] = nvvm.redux.sync and %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_1:.*]] = nvvm.redux.sync max %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_2:.*]] = nvvm.redux.sync min %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_3:.*]] = nvvm.redux.sync or %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_4:.*]] = nvvm.redux.sync umax %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_5:.*]] = nvvm.redux.sync umin %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_6:.*]] = nvvm.redux.sync xor %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_7:.*]] = nvvm.redux.sync fmin %[[ARG2]], %[[ARG1]] {abs = true, nan = true} : f32 -> f32 +# CHECK: %[[REDUX_8:.*]] = nvvm.redux.sync fmax %[[ARG2]], %[[ARG1]] {abs = true, nan = true} : f32 -> f32 +# CHECK: %[[REDUX_9:.*]] = nvvm.redux.sync and %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_10:.*]] = nvvm.redux.sync max %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_11:.*]] = nvvm.redux.sync min %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_12:.*]] = nvvm.redux.sync or %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_13:.*]] = nvvm.redux.sync umax %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_14:.*]] = nvvm.redux.sync umin %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_15:.*]] = nvvm.redux.sync xor %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_16:.*]] = nvvm.redux.sync fmin %[[ARG2]], %[[ARG1]] {abs = true} : f32 -> f32 +# CHECK: %[[REDUX_17:.*]] = nvvm.redux.sync fmax %[[ARG2]], %[[ARG1]] {abs = true} : f32 -> f32 +# CHECK: %[[REDUX_18:.*]] = nvvm.redux.sync and %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_19:.*]] = nvvm.redux.sync max %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_20:.*]] = nvvm.redux.sync min %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_21:.*]] = nvvm.redux.sync or %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_22:.*]] = nvvm.redux.sync umax %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_23:.*]] = nvvm.redux.sync umin %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_24:.*]] = nvvm.redux.sync xor %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_25:.*]] = nvvm.redux.sync fmin %[[ARG2]], %[[ARG1]] {nan = true} : f32 -> f32 +# CHECK: %[[REDUX_26:.*]] = nvvm.redux.sync fmax %[[ARG2]], %[[ARG1]] {nan = true} : f32 -> f32 +# CHECK: %[[REDUX_27:.*]] = nvvm.redux.sync and %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_28:.*]] = nvvm.redux.sync max %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_29:.*]] = nvvm.redux.sync min %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_30:.*]] = nvvm.redux.sync or %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_31:.*]] = nvvm.redux.sync umax %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_32:.*]] = nvvm.redux.sync umin %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_33:.*]] = nvvm.redux.sync xor %[[ARG1]], %[[ARG1]] : i32 -> i32 +# CHECK: %[[REDUX_34:.*]] = nvvm.redux.sync fmin %[[ARG2]], %[[ARG1]] : f32 -> f32 +# CHECK: %[[REDUX_35:.*]] = nvvm.redux.sync fmax %[[ARG2]], %[[ARG1]] : f32 -> f32 +# CHECK: return +# CHECK: } diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py index 1194e32..f0f74eb 100644 --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -554,7 +554,7 @@ def testOptionalOperandOp(): ) assert ( typing.get_type_hints(test.OptionalOperandOp.result.fget)["return"] - is OpResult + == OpResult[IntegerType] ) assert type(op1.result) is OpResult @@ -663,6 +663,13 @@ def testCustomType(): @run +# CHECK-LABEL: TEST: testValue +def testValue(): + # Check that Value is a generic class at runtime. + assert hasattr(Value, "__class_getitem__") + + +@run # CHECK-LABEL: TEST: testTensorValue def testTensorValue(): with Context() as ctx, Location.unknown(): diff --git a/mlir/test/python/dialects/rocdl.py b/mlir/test/python/dialects/rocdl.py index a4a50af..c73a536 100644 --- a/mlir/test/python/dialects/rocdl.py +++ b/mlir/test/python/dialects/rocdl.py @@ -29,13 +29,12 @@ def testSmoke(): a_frag = arith.constant(v16f32, f32_array) b_frag = arith.constant(v16f32, f32_array) c_frag = arith.constant(v16f32, f32_array) - false = arith.constant(T.bool(), False) - c_frag = rocdl.wmma_f16_16x16x16_f16(v16f32, [a_frag, b_frag, c_frag, false]) - # CHECK: %{{.*}} = rocdl.wmma.f16.16x16x16.f16 + c_frag = rocdl.wmma_f16_16x16x16_f16(v16f32, a_frag, b_frag, c_frag, opsel=False) + # CHECK: %{{.*}} = "rocdl.wmma.f16.16x16x16.f16" print(c_frag) assert isinstance(c_frag, OpView) - # CHECK: Value(%{{.*}} = rocdl.wmma.f16.16x16x16.f16 - c_frag = rocdl.wmma_f16_16x16x16_f16_(v16f32, [a_frag, b_frag, c_frag, false]) + # CHECK: Value(%{{.*}} = "rocdl.wmma.f16.16x16x16.f16" + c_frag = rocdl.wmma_f16_16x16x16_f16_(v16f32, a_frag, b_frag, c_frag, opsel=False) print(c_frag) assert isinstance(c_frag, Value) diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py index 62d11d5..0c0c9b9 100644 --- a/mlir/test/python/dialects/scf.py +++ b/mlir/test/python/dialects/scf.py @@ -1,10 +1,14 @@ # RUN: %PYTHON %s | FileCheck %s from mlir.ir import * -from mlir.dialects import arith -from mlir.dialects import func -from mlir.dialects import memref -from mlir.dialects import scf +from mlir.extras import types as T +from mlir.dialects import ( + arith, + func, + memref, + scf, + cf, +) from mlir.passmanager import PassManager @@ -355,3 +359,117 @@ def testIfWithElse(): # CHECK: scf.yield %[[TWO]], %[[THREE]] # CHECK: arith.addi %[[RET]]#0, %[[RET]]#1 # CHECK: return + + +@constructAndPrintInModule +def testIndexSwitch(): + i32 = T.i32() + + @func.FuncOp.from_py_func(T.index(), results=[i32]) + def index_switch(index): + c1 = arith.constant(i32, 1) + c0 = arith.constant(i32, 0) + value = arith.constant(i32, 5) + switch_op = scf.IndexSwitchOp([i32], index, range(3)) + + assert switch_op.regions[0] == switch_op.default_region + assert switch_op.regions[1] == switch_op.case_regions[0] + assert switch_op.regions[1] == switch_op.case_region(0) + assert len(switch_op.case_regions) == 3 + assert len(switch_op.regions) == 4 + + with InsertionPoint(switch_op.default_block): + cf.assert_(arith.constant(T.bool(), 0), "Whoops!") + scf.yield_([c1]) + + for i, block in enumerate(switch_op.case_blocks): + with InsertionPoint(block): + scf.yield_([arith.constant(i32, i)]) + + func.return_([switch_op.results[0]]) + + return index_switch + + +# CHECK-LABEL: func.func @index_switch( +# CHECK-SAME: %[[ARG0:.*]]: index) -> i32 { +# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : i32 +# CHECK: %[[CONSTANT_1:.*]] = arith.constant 0 : i32 +# CHECK: %[[CONSTANT_2:.*]] = arith.constant 5 : i32 +# CHECK: %[[INDEX_SWITCH_0:.*]] = scf.index_switch %[[ARG0]] -> i32 +# CHECK: case 0 { +# CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : i32 +# CHECK: scf.yield %[[CONSTANT_3]] : i32 +# CHECK: } +# CHECK: case 1 { +# CHECK: %[[CONSTANT_4:.*]] = arith.constant 1 : i32 +# CHECK: scf.yield %[[CONSTANT_4]] : i32 +# CHECK: } +# CHECK: case 2 { +# CHECK: %[[CONSTANT_5:.*]] = arith.constant 2 : i32 +# CHECK: scf.yield %[[CONSTANT_5]] : i32 +# CHECK: } +# CHECK: default { +# CHECK: %[[CONSTANT_6:.*]] = arith.constant false +# CHECK: cf.assert %[[CONSTANT_6]], "Whoops!" +# CHECK: scf.yield %[[CONSTANT_0]] : i32 +# CHECK: } +# CHECK: return %[[INDEX_SWITCH_0]] : i32 +# CHECK: } + + +@constructAndPrintInModule +def testIndexSwitchWithBodyBuilders(): + i32 = T.i32() + + @func.FuncOp.from_py_func(T.index(), results=[i32]) + def index_switch(index): + c1 = arith.constant(i32, 1) + c0 = arith.constant(i32, 0) + value = arith.constant(i32, 5) + + def default_body_builder(switch_op): + cf.assert_(arith.constant(T.bool(), 0), "Whoops!") + scf.yield_([c1]) + + def case_body_builder(switch_op, case_index: int, case_value: int): + scf.yield_([arith.constant(i32, case_value)]) + + result = scf.index_switch( + results=[i32], + arg=index, + cases=range(3), + case_body_builder=case_body_builder, + default_body_builder=default_body_builder, + ) + + func.return_([result]) + + return index_switch + + +# CHECK-LABEL: func.func @index_switch( +# CHECK-SAME: %[[ARG0:.*]]: index) -> i32 { +# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : i32 +# CHECK: %[[CONSTANT_1:.*]] = arith.constant 0 : i32 +# CHECK: %[[CONSTANT_2:.*]] = arith.constant 5 : i32 +# CHECK: %[[INDEX_SWITCH_0:.*]] = scf.index_switch %[[ARG0]] -> i32 +# CHECK: case 0 { +# CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : i32 +# CHECK: scf.yield %[[CONSTANT_3]] : i32 +# CHECK: } +# CHECK: case 1 { +# CHECK: %[[CONSTANT_4:.*]] = arith.constant 1 : i32 +# CHECK: scf.yield %[[CONSTANT_4]] : i32 +# CHECK: } +# CHECK: case 2 { +# CHECK: %[[CONSTANT_5:.*]] = arith.constant 2 : i32 +# CHECK: scf.yield %[[CONSTANT_5]] : i32 +# CHECK: } +# CHECK: default { +# CHECK: %[[CONSTANT_6:.*]] = arith.constant false +# CHECK: cf.assert %[[CONSTANT_6]], "Whoops!" +# CHECK: scf.yield %[[CONSTANT_0]] : i32 +# CHECK: } +# CHECK: return %[[INDEX_SWITCH_0]] : i32 +# CHECK: } diff --git a/mlir/test/python/dialects/transform_interpreter.py b/mlir/test/python/dialects/transform_interpreter.py index 819a3be..ca9ce5d 100644 --- a/mlir/test/python/dialects/transform_interpreter.py +++ b/mlir/test/python/dialects/transform_interpreter.py @@ -32,6 +32,20 @@ def print_self(): @test_in_context +def print_self_via_apply_method(): + m = ir.Module.parse( + print_root_module.replace("from interpreter", "print_self_via_apply_method") + ) + m.body.operations[0].apply(m) + + +# CHECK-LABEL: print_self_via_apply_method +# CHECK: transform.named_sequence @__transform_main +# CHECK: transform.print +# CHECK: transform.yield + + +@test_in_context def print_other(): transform = ir.Module.parse( print_root_module.replace("from interpreter", "print_other") diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py index 0b587d2..2b11acb0 100644 --- a/mlir/test/python/dialects/transform_xegpu_ext.py +++ b/mlir/test/python/dialects/transform_xegpu_ext.py @@ -3,7 +3,7 @@ from mlir.ir import * from mlir.dialects import transform from mlir.dialects.transform import xegpu -from mlir.dialects.transform import AnyValueType +from mlir.dialects.transform import structured, AnyValueType def run(f): @@ -25,7 +25,7 @@ def getDescOpDefaultIndex(): ) with InsertionPoint(sequence.body): operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0]) - desc_handle = xegpu.GetDescOp(operand) + desc_handle = xegpu.get_desc_op(operand) transform.YieldOp() # CHECK-LABEL: TEST: getDescOpDefaultIndex # CHECK: transform.xegpu.get_desc_op % @@ -39,7 +39,7 @@ def setDescLayoutMinimal(): transform.OperationType.get("xegpu.create_nd_tdesc"), ) with InsertionPoint(sequence.body): - xegpu.SetDescLayoutOp(sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16]) + xegpu.set_desc_layout(sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16]) transform.YieldOp() # CHECK-LABEL: TEST: setDescLayoutMinimal # CHECK: %0 = transform.xegpu.set_desc_layout % @@ -55,7 +55,7 @@ def setDescLayoutInstData(): transform.OperationType.get("xegpu.create_nd_tdesc"), ) with InsertionPoint(sequence.body): - xegpu.SetDescLayoutOp( + xegpu.set_desc_layout( sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], inst_data=[8, 16] ) transform.YieldOp() @@ -67,6 +67,25 @@ def setDescLayoutInstData(): @run +def setDescLayoutSlice(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.create_nd_tdesc"), + ) + with InsertionPoint(sequence.body): + xegpu.set_desc_layout( + sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], slice_dims=[0] + ) + transform.YieldOp() + # CHECK-LABEL: TEST: setDescLayoutSlice + # CHECK: %0 = transform.xegpu.set_desc_layout % + # CHECK: sg_layout = [6, 4] + # CHECK: sg_data = [32, 16] + # CHECK: slice_dims = [0] + + +@run def setOpLayoutAttrOperandMinimal(): sequence = transform.SequenceOp( transform.FailurePropagationMode.Propagate, @@ -74,7 +93,7 @@ def setOpLayoutAttrOperandMinimal(): transform.OperationType.get("xegpu.dpas"), ) with InsertionPoint(sequence.body): - xegpu.SetOpLayoutAttrOp( + xegpu.set_op_layout_attr( sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], @@ -97,7 +116,7 @@ def setOpLayoutAttrResult(): transform.OperationType.get("xegpu.dpas"), ) with InsertionPoint(sequence.body): - xegpu.SetOpLayoutAttrOp( + xegpu.set_op_layout_attr( sequence.bodyTarget, index=0, sg_layout=[6, 4], @@ -106,10 +125,172 @@ def setOpLayoutAttrResult(): result=True, ) transform.YieldOp() - # CHECK-LABEL: TEST: setOpLayoutAttr + # CHECK-LABEL: TEST: setOpLayoutAttrResult + # CHECK: transform.xegpu.set_op_layout_attr % + # NO-CHECK: index = 0 + # CHECK: result + # CHECK: sg_layout = [6, 4] + # CHECK: sg_data = [32, 16] + # CHECK: inst_data = [8, 16] + + +@run +def setOpLayoutAttrResultSlice(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.dpas"), + ) + with InsertionPoint(sequence.body): + xegpu.set_op_layout_attr( + sequence.bodyTarget, + index=0, + sg_layout=[6, 4], + sg_data=[32, 16], + inst_data=[8, 16], + slice_dims=[0], + result=True, + ) + transform.YieldOp() + # CHECK-LABEL: TEST: setOpLayoutAttrResultSlice # CHECK: transform.xegpu.set_op_layout_attr % # NO-CHECK: index = 0 # CHECK: result # CHECK: sg_layout = [6, 4] # CHECK: sg_data = [32, 16] # CHECK: inst_data = [8, 16] + # CHECK: slice_dims = [0] + + +@run +def setGPULaunchThreadsOp(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("gpu.launch"), + ) + with InsertionPoint(sequence.body): + xegpu.set_gpu_launch_threads(sequence.bodyTarget, threads=[8, 4, 1]) + transform.YieldOp() + # CHECK-LABEL: TEST: setGPULaunchThreadsOp + # CHECK: transform.xegpu.set_gpu_launch_threads + # CHECK: threads = [8, 4, 1] + + +@run +def insertPrefetch0(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.dpas"), + ) + with InsertionPoint(sequence.body): + operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0]) + xegpu.insert_prefetch( + operand, + ) + transform.YieldOp() + # CHECK-LABEL: TEST: insertPrefetch0 + # CHECK: %[[OPR:.*]] = get_operand + # CHECK: transform.xegpu.insert_prefetch %[[OPR]] + + +@run +def insertPrefetchNbPrefetch(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.dpas"), + ) + with InsertionPoint(sequence.body): + operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0]) + xegpu.insert_prefetch( + operand, + nb_prefetch=2, + ) + transform.YieldOp() + # CHECK-LABEL: TEST: insertPrefetchNbPrefetch + # CHECK: %[[OPR:.*]] = get_operand + # CHECK: transform.xegpu.insert_prefetch %[[OPR]] + # CHECK-SAME: nb_prefetch = 2 + + +@run +def insertPrefetchNbPrefetchParam(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.dpas"), + ) + with InsertionPoint(sequence.body): + operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0]) + int32_t = IntegerType.get_signless(32) + param_int32_t = transform.ParamType.get(int32_t) + nb_param = transform.ParamConstantOp( + param_int32_t, + IntegerAttr.get(int32_t, 2), + ) + xegpu.insert_prefetch( + operand, + nb_prefetch=nb_param, + ) + transform.YieldOp() + # CHECK-LABEL: TEST: insertPrefetchNbPrefetchParam + # CHECK: %[[OPR:.*]] = get_operand + # CHECK: %[[PARAM_OP:.*]] = transform.param.constant 2 + # CHECK: transform.xegpu.insert_prefetch %[[OPR]] + # CHECK-SAME: nb_prefetch = %[[PARAM_OP]] + + +@run +def ConvertLayoutMinimal(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.dpas"), + ) + with InsertionPoint(sequence.body): + operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0]) + xegpu.convert_layout( + operand, + input_sg_layout=[6, 4], + input_sg_data=[32, 16], + target_sg_layout=[6, 4], + target_sg_data=[8, 16], + ) + transform.YieldOp() + # CHECK-LABEL: TEST: ConvertLayoutMinimal + # CHECK: transform.xegpu.convert_layout % + # CHECK: input_sg_layout = [6, 4] + # CHECK: input_sg_data = [32, 16] + # CHECK: target_sg_layout = [6, 4] + # CHECK: target_sg_data = [8, 16] + + +@run +def ConvertLayout(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.dpas"), + ) + with InsertionPoint(sequence.body): + operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [1]) + xegpu.convert_layout( + operand, + input_sg_layout=[6, 4], + input_sg_data=[32, 32], + input_inst_data=[32, 16], + target_sg_layout=[6, 4], + target_sg_data=[32, 32], + target_inst_data=[8, 16], + ) + transform.YieldOp() + # CHECK-LABEL: TEST: ConvertLayout + # CHECK: transform.xegpu.convert_layout % + # CHECK: input_sg_layout = [6, 4] + # CHECK: input_sg_data = [32, 32] + # CHECK: input_inst_data = [32, 16] + # CHECK: target_sg_layout = [6, 4] + # CHECK: target_sg_data = [32, 32] + # CHECK: target_inst_data = [8, 16] diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py index 146e213a..b11340f 100644 --- a/mlir/test/python/execution_engine.py +++ b/mlir/test/python/execution_engine.py @@ -71,6 +71,7 @@ def testInvalidModule(): func.func @foo() { return } """ ) + # CHECK: error: cannot be converted to LLVM IR: missing `LLVMTranslationDialectInterface` registration for dialect for op: func.func # CHECK: Got RuntimeError: Failure while creating the ExecutionEngine. try: execution_engine = ExecutionEngine(module) @@ -806,6 +807,7 @@ def testDumpToObjectFile(): # because RTDyldObjectLinkingLayer::emit will try to resolve symbols before dumping # (see the jitLinkForORC call at the bottom there). shared_libs=[MLIR_C_RUNNER_UTILS], + enable_pic=True, ) # CHECK: Object file exists: True diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py index 8f20231..8eff573 100644 --- a/mlir/test/python/integration/dialects/linalg/opsrun.py +++ b/mlir/test/python/integration/dialects/linalg/opsrun.py @@ -25,13 +25,13 @@ func.func @main() -> i32 attributes {llvm.emit_c_interface} { %O1 = memref.alloc() : memref<16xi32> %O2 = memref.alloc() : memref<4x16xi32> - %val0 = arith.constant 1.0 : f32 - %val1 = arith.constant 2.0 : f32 - %val2 = arith.constant 3.0 : f32 + %val0 = arith.constant 1 : i32 + %val1 = arith.constant 2 : i32 + %val2 = arith.constant 3 : i32 - call @fill_0d_on_buffers(%val0, %O0) : (f32, memref<i32>) -> () - call @fill_1d_on_buffers(%val1, %O1) : (f32, memref<16xi32>) -> () - call @fill_2d_on_buffers(%val2, %O2) : (f32, memref<4x16xi32>) -> () + call @fill_0d_on_buffers(%val0, %O0) : (i32, memref<i32>) -> () + call @fill_1d_on_buffers(%val1, %O1) : (i32, memref<16xi32>) -> () + call @fill_2d_on_buffers(%val2, %O2) : (i32, memref<4x16xi32>) -> () %c0 = arith.constant 0 : index %res0 = memref.load %O0[] : memref<i32> @@ -149,19 +149,18 @@ def transform(module, boilerplate): def test_fill_builtin(): with Context() as ctx, Location.unknown(): module = Module.create() - f32 = F32Type.get() i32 = IntegerType.get_signless(32) with InsertionPoint(module.body): - @func.FuncOp.from_py_func(f32, MemRefType.get([], i32)) + @func.FuncOp.from_py_func(i32, MemRefType.get([], i32)) def fill_0d_on_buffers(value, out): linalg.fill(value, outs=[out]) - @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32)) + @func.FuncOp.from_py_func(i32, MemRefType.get([16], i32)) def fill_1d_on_buffers(value, out): linalg.fill(value, outs=[out]) - @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32)) + @func.FuncOp.from_py_func(i32, MemRefType.get([4, 16], i32)) def fill_2d_on_buffers(value, out): linalg.fill(value, outs=[out]) @@ -184,19 +183,18 @@ test_fill_builtin() def test_fill_generic(): with Context() as ctx, Location.unknown(): module = Module.create() - f32 = F32Type.get() i32 = IntegerType.get_signless(32) with InsertionPoint(module.body): - @func.FuncOp.from_py_func(f32, MemRefType.get([], i32)) + @func.FuncOp.from_py_func(i32, MemRefType.get([], i32)) def fill_0d_on_buffers(value, out): linalg.fill(value, outs=[out], emit_generic=True) - @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32)) + @func.FuncOp.from_py_func(i32, MemRefType.get([16], i32)) def fill_1d_on_buffers(value, out): linalg.fill(value, outs=[out], emit_generic=True) - @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32)) + @func.FuncOp.from_py_func(i32, MemRefType.get([4, 16], i32)) def fill_2d_on_buffers(value, out): linalg.fill(value, outs=[out], emit_generic=True) diff --git a/mlir/test/python/ir/auto_location.py b/mlir/test/python/ir/auto_location.py index 8316890..1747c66a 100644 --- a/mlir/test/python/ir/auto_location.py +++ b/mlir/test/python/ir/auto_location.py @@ -15,17 +15,10 @@ def run(f): assert Context._get_live_count() == 0 -@contextmanager -def with_infer_location(): - _cext.globals.set_loc_tracebacks_enabled(True) - yield - _cext.globals.set_loc_tracebacks_enabled(False) - - # CHECK-LABEL: TEST: testInferLocations @run def testInferLocations(): - with Context() as ctx, with_infer_location(): + with Context() as ctx, loc_tracebacks(): ctx.allow_unregistered_dialects = True op = Operation.create("custom.op1") @@ -34,24 +27,26 @@ def testInferLocations(): two = arith.constant(IndexType.get(), 2) # fmt: off - # CHECK: loc(callsite("testInferLocations"("{{.*}}[[SEP:[/\\]+]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":31:13 to :43) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4)))) + # CHECK: loc(callsite("testInferLocations"("{{.*}}[[SEP:[/\\]+]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:13 to :43) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:1 to :4)))) # fmt: on print(op.location) - # fmt: off - # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":65:12 to :76) at callsite("constant"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":110:40 to :81) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":32:14 to :48) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4)))))) - # fmt: on - print(one.location) + # Test nesting of loc_tracebacks(). + with loc_tracebacks(): + # fmt: off + # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":65:12 to :76) at callsite("constant"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":110:40 to :81) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:14 to :48) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:1 to :4)))))) + # fmt: on + print(one.location) # fmt: off - # CHECK: loc(callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":34:14 to :48) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4)))) + # CHECK: loc(callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:14 to :48) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:1 to :4)))) # fmt: on print(two.location) _cext.globals.register_traceback_file_inclusion(_arith_ops_gen.__file__) three = arith.constant(IndexType.get(), 3) # fmt: off - # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":{{[0-9]+}}:4 to :235) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":52:16 to :50) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4))))) + # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":{{[0-9]+}}:4 to :235) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:16 to :50) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:1 to :4))))) # fmt: on print(three.location) @@ -60,7 +55,7 @@ def testInferLocations(): print(four.location) # fmt: off - # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":{{[0-9]+}}:4 to :235) at callsite("testInferLocations.<locals>.foo"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":59:19 to :53) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":65:8 to :13) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4)))))) + # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":{{[0-9]+}}:4 to :235) at callsite("testInferLocations.<locals>.foo"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:19 to :53) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:8 to :13) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:1 to :4)))))) # fmt: on foo() @@ -86,13 +81,13 @@ def testInferLocations(): _cext.globals.set_loc_tracebacks_frame_limit(2) # fmt: off - # CHECK: loc(callsite("testInferLocations.<locals>.bar1.<locals>.bar2.<locals>.bar3"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":77:27 to :61) at "testInferLocations.<locals>.bar1.<locals>.bar2"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":80:16 to :22))) + # CHECK: loc(callsite("testInferLocations.<locals>.bar1.<locals>.bar2.<locals>.bar3"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:27 to :61) at "testInferLocations.<locals>.bar1.<locals>.bar2"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:16 to :22))) # fmt: on bar1() _cext.globals.set_loc_tracebacks_frame_limit(1) # fmt: off - # CHECK: loc("testInferLocations.<locals>.bar1.<locals>.bar2.<locals>.bar3"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":77:27 to :61)) + # CHECK: loc("testInferLocations.<locals>.bar1.<locals>.bar2.<locals>.bar3"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:27 to :61)) # fmt: on bar1() diff --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py index ced5fce..e876c00 100644 --- a/mlir/test/python/ir/blocks.py +++ b/mlir/test/python/ir/blocks.py @@ -191,3 +191,18 @@ def testBlockEraseArgs(): blocks[0].erase_argument(0) # CHECK: ^bb0: op.print(enable_debug_info=True) + + +# CHECK-LABEL: TEST: testBlockArgSetLocation +# CHECK: ^bb0(%{{.+}}: f32 loc("new_loc")): +@run +def testBlockArgSetLocation(): + with Context() as ctx, Location.unknown(ctx) as loc: + ctx.allow_unregistered_dialects = True + f32 = F32Type.get() + op = Operation.create("test", regions=1, loc=Location.unknown()) + blocks = op.regions[0].blocks + blocks.append(f32) + arg = blocks[0].arguments[0] + arg.set_location(Location.name("new_loc")) + op.print(enable_debug_info=True) diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index f5fa4da..d124c28 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -2,12 +2,12 @@ import gc import io -import itertools from tempfile import NamedTemporaryFile from mlir.ir import * from mlir.dialects.builtin import ModuleOp -from mlir.dialects import arith +from mlir.dialects import arith, func, scf, shape from mlir.dialects._ods_common import _cext +from mlir.extras import types as T def run(f): @@ -43,6 +43,10 @@ def testTraverseOpRegionBlockIterators(): ) op = module.operation assert op.context is ctx + # Note, __nb_signature__ stores the fully-qualified signature - the actual type stub emitted is + # class RegionSequence(Sequence[Region]) + # CHECK: class RegionSequence(collections.abc.Sequence[mlir._mlir_libs._mlir.ir.Region]) + print(RegionSequence.__nb_signature__) # Get the block using iterators off of the named collections. regions = list(op.regions[:]) blocks = list(regions[0].blocks) @@ -774,6 +778,21 @@ def testKnownOpView(): print(repr(constant)) +# CHECK-LABEL: TEST: testFailedGenericOperationCreationReportsError +@run +def testFailedGenericOperationCreationReportsError(): + with Context(), Location.unknown(): + c0 = shape.const_shape([]) + c1 = shape.const_shape([1, 2, 3]) + try: + shape.MeetOp.build_generic(operands=[c0, c1]) + except MLIRError as e: + # CHECK: unequal shape cardinality + print(e) + else: + assert False, "Expected exception" + + # CHECK-LABEL: TEST: testSingleResultProperty @run def testSingleResultProperty(): @@ -1199,3 +1218,25 @@ def testGetOwnerConcreteOpview(): r = arith.AddIOp(a, a, overflowFlags=arith.IntegerOverflowFlags.nsw) for u in a.result.uses: assert isinstance(u.owner, arith.AddIOp) + + +# CHECK-LABEL: TEST: testIndexSwitch +@run +def testIndexSwitch(): + with Context() as ctx, Location.unknown(): + i32 = T.i32() + module = Module.create() + with InsertionPoint(module.body): + + @func.FuncOp.from_py_func(T.index()) + def index_switch(index): + c1 = arith.constant(i32, 1) + switch_op = scf.IndexSwitchOp(results=[i32], arg=index, cases=range(3)) + + assert len(switch_op.regions) == 4 + assert len(switch_op.regions[2:]) == 2 + assert len([i for i in switch_op.regions[2:]]) == 2 + assert len(switch_op.caseRegions) == 3 + assert len([i for i in switch_op.caseRegions]) == 3 + assert len(switch_op.caseRegions[1:]) == 2 + assert len([i for i in switch_op.caseRegions[1:]]) == 2 diff --git a/mlir/tools/mlir-irdl-to-cpp/mlir-irdl-to-cpp.cpp b/mlir/tools/mlir-irdl-to-cpp/mlir-irdl-to-cpp.cpp index a63b289..4a512bd 100644 --- a/mlir/tools/mlir-irdl-to-cpp/mlir-irdl-to-cpp.cpp +++ b/mlir/tools/mlir-irdl-to-cpp/mlir-irdl-to-cpp.cpp @@ -124,7 +124,7 @@ static LogicalResult translateIRDLToCpp(int argc, char **argv) { }; auto &splitInputFileDelimiter = splitInputFile.getValue(); - if (splitInputFileDelimiter.size()) + if (!splitInputFileDelimiter.empty()) return splitAndProcessBuffer(std::move(input), chunkFn, output->os(), splitInputFileDelimiter, splitInputFileDelimiter); diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index ac739be..a427132 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -182,7 +182,7 @@ void registerTestTransformsTransformDialectExtension(DialectRegistry &); } // namespace test #ifdef MLIR_INCLUDE_TESTS -void registerTestPasses() { +static void registerTestPasses() { registerCloneTestPasses(); registerConvertToTargetEnvPass(); registerPrintTosaAvailabilityPass(); diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt index 2a7ef7e..d7087cb 100644 --- a/mlir/tools/mlir-tblgen/CMakeLists.txt +++ b/mlir/tools/mlir-tblgen/CMakeLists.txt @@ -12,6 +12,7 @@ add_tablegen(mlir-tblgen MLIR AttrOrTypeFormatGen.cpp BytecodeDialectGen.cpp DialectGen.cpp + DialectInterfacesGen.cpp DirectiveCommonGen.cpp EnumsGen.cpp EnumPythonBindingGen.cpp diff --git a/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp b/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp new file mode 100644 index 0000000..1d3b24a --- /dev/null +++ b/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp @@ -0,0 +1,164 @@ +//===- DialectInterfacesGen.cpp - MLIR dialect interface utility generator ===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// DialectInterfaceGen generates definitions for Dialect interfaces. +// +//===----------------------------------------------------------------------===// + +#include "CppGenUtilities.h" +#include "DocGenUtilities.h" +#include "mlir/Support/IndentedOstream.h" +#include "mlir/TableGen/GenInfo.h" +#include "mlir/TableGen/Interfaces.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/TableGen/CodeGenHelpers.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" +#include "llvm/TableGen/TableGenBackend.h" + +using namespace mlir; +using llvm::Record; +using llvm::RecordKeeper; +using mlir::tblgen::Interface; +using mlir::tblgen::InterfaceMethod; + +/// Emit a string corresponding to a C++ type, followed by a space if necessary. +static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) { + type = type.trim(); + os << type; + if (type.back() != '&' && type.back() != '*') + os << " "; + return os; +} + +/// Emit the method name and argument list for the given method. +static void emitMethodNameAndArgs(const InterfaceMethod &method, StringRef name, + raw_ostream &os) { + os << name << '('; + llvm::interleaveComma(method.getArguments(), os, + [&](const InterfaceMethod::Argument &arg) { + os << arg.type << " " << arg.name; + }); + os << ") const"; +} + +/// Get an array of all Dialect Interface definitions +static std::vector<const Record *> +getAllInterfaceDefinitions(const RecordKeeper &records) { + std::vector<const Record *> defs = + records.getAllDerivedDefinitions("DialectInterface"); + + llvm::erase_if(defs, [&](const Record *def) { + // Ignore interfaces defined outside of the top-level file. + return llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) != + llvm::SrcMgr.getMainFileID(); + }); + return defs; +} + +namespace { +/// This struct is the generator used when processing tablegen dialect +/// interfaces. +class DialectInterfaceGenerator { +public: + DialectInterfaceGenerator(const RecordKeeper &records, raw_ostream &os) + : defs(getAllInterfaceDefinitions(records)), os(os) {} + + bool emitInterfaceDecls(); + +protected: + void emitInterfaceDecl(const Interface &interface); + + /// The set of interface records to emit. + std::vector<const Record *> defs; + // The stream to emit to. + raw_ostream &os; +}; +} // namespace + +//===----------------------------------------------------------------------===// +// GEN: Interface declarations +//===----------------------------------------------------------------------===// + +static void emitInterfaceMethodDoc(const InterfaceMethod &method, + raw_ostream &os, StringRef prefix = "") { + if (std::optional<StringRef> description = method.getDescription()) + tblgen::emitDescriptionComment(*description, os, prefix); +} + +static void emitInterfaceMethodsDef(const Interface &interface, + raw_ostream &os) { + + raw_indented_ostream ios(os); + ios.indent(2); + + for (auto &method : interface.getMethods()) { + emitInterfaceMethodDoc(method, ios); + ios << "virtual "; + emitCPPType(method.getReturnType(), ios); + emitMethodNameAndArgs(method, method.getName(), ios); + ios << " {"; + + if (auto body = method.getBody()) { + ios << "\n"; + ios.indent(4); + ios << body << "\n"; + ios.indent(2); + } + os << "}\n"; + } +} + +void DialectInterfaceGenerator::emitInterfaceDecl(const Interface &interface) { + llvm::NamespaceEmitter ns(os, interface.getCppNamespace()); + + StringRef interfaceName = interface.getName(); + + tblgen::emitSummaryAndDescComments(os, "", + interface.getDescription().value_or("")); + + // Emit the main interface class declaration. + os << llvm::formatv( + "class {0} : public ::mlir::DialectInterface::Base<{0}> {\n" + "public:\n" + " {0}(::mlir::Dialect *dialect) : Base(dialect) {{}\n", + interfaceName); + + emitInterfaceMethodsDef(interface, os); + + os << "};\n"; +} + +bool DialectInterfaceGenerator::emitInterfaceDecls() { + + llvm::emitSourceFileHeader("Dialect Interface Declarations", os); + + // Sort according to ID, so defs are emitted in the order in which they appear + // in the Tablegen file. + std::vector<const Record *> sortedDefs(defs); + llvm::sort(sortedDefs, [](const Record *lhs, const Record *rhs) { + return lhs->getID() < rhs->getID(); + }); + + for (const Record *def : sortedDefs) + emitInterfaceDecl(Interface(def)); + + return false; +} + +//===----------------------------------------------------------------------===// +// GEN: Interface registration hooks +//===----------------------------------------------------------------------===// + +static mlir::GenRegistration genDecls( + "gen-dialect-interface-decls", "Generate dialect interface declarations.", + [](const RecordKeeper &records, raw_ostream &os) { + return DialectInterfaceGenerator(records, os).emitInterfaceDecls(); + }); diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp index 11bf9ce..8c7f9f7 100644 --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -702,41 +702,45 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { StringRef underlyingToSymFnName = enumInfo.getUnderlyingToSymbolFnName(); auto enumerants = enumInfo.getAllCases(); - llvm::NamespaceEmitter ns(os, cppNamespace); - - // Emit the enum class definition - emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, os); - - // Emit conversion function declarations - if (llvm::all_of(enumerants, [](EnumCase enumerant) { - return enumerant.getValue() >= 0; - })) { - os << formatv( - "::std::optional<{0}> {1}({2});\n", enumName, underlyingToSymFnName, - underlyingType.empty() ? std::string("unsigned") : underlyingType); - } - os << formatv("{2} {1}({0});\n", enumName, symToStrFnName, symToStrFnRetType); - os << formatv("::std::optional<{0}> {1}(::llvm::StringRef);\n", enumName, - strToSymFnName); - - if (enumInfo.isBitEnum()) { - emitOperators(enumDef, os); - } else { - emitMaxValueFn(enumDef, os); - } + { + llvm::NamespaceEmitter ns(os, cppNamespace); + + // Emit the enum class definition + emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, + os); + + // Emit conversion function declarations + if (llvm::all_of(enumerants, [](EnumCase enumerant) { + return enumerant.getValue() >= 0; + })) { + os << formatv( + "::std::optional<{0}> {1}({2});\n", enumName, underlyingToSymFnName, + underlyingType.empty() ? std::string("unsigned") : underlyingType); + } + os << formatv("{2} {1}({0});\n", enumName, symToStrFnName, + symToStrFnRetType); + os << formatv("::std::optional<{0}> {1}(::llvm::StringRef);\n", enumName, + strToSymFnName); + + if (enumInfo.isBitEnum()) { + emitOperators(enumDef, os); + } else { + emitMaxValueFn(enumDef, os); + } - // Generate a generic `stringifyEnum` function that forwards to the method - // specified by the user. - const char *const stringifyEnumStr = R"( + // Generate a generic `stringifyEnum` function that forwards to the method + // specified by the user. + const char *const stringifyEnumStr = R"( inline {0} stringifyEnum({1} enumValue) {{ return {2}(enumValue); } )"; - os << formatv(stringifyEnumStr, symToStrFnRetType, enumName, symToStrFnName); + os << formatv(stringifyEnumStr, symToStrFnRetType, enumName, + symToStrFnName); - // Generate a generic `symbolizeEnum` function that forwards to the method - // specified by the user. - const char *const symbolizeEnumStr = R"( + // Generate a generic `symbolizeEnum` function that forwards to the method + // specified by the user. + const char *const symbolizeEnumStr = R"( template <typename EnumType> ::std::optional<EnumType> symbolizeEnum(::llvm::StringRef); @@ -745,9 +749,9 @@ inline ::std::optional<{0}> symbolizeEnum<{0}>(::llvm::StringRef str) { return {1}(str); } )"; - os << formatv(symbolizeEnumStr, enumName, strToSymFnName); + os << formatv(symbolizeEnumStr, enumName, strToSymFnName); - const char *const attrClassDecl = R"( + const char *const attrClassDecl = R"( class {1} : public ::mlir::{2} { public: using ValueType = {0}; @@ -757,13 +761,12 @@ public: {0} getValue() const; }; )"; - if (enumInfo.genSpecializedAttr()) { - StringRef attrClassName = enumInfo.getSpecializedAttrClassName(); - StringRef baseAttrClassName = "IntegerAttr"; - os << formatv(attrClassDecl, enumName, attrClassName, baseAttrClassName); - } - - ns.close(); + if (enumInfo.genSpecializedAttr()) { + StringRef attrClassName = enumInfo.getSpecializedAttrClassName(); + StringRef baseAttrClassName = "IntegerAttr"; + os << formatv(attrClassDecl, enumName, attrClassName, baseAttrClassName); + } + } // close `ns`. // Generate a generic parser and printer for the enum. std::string qualName = diff --git a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp index 525c8d6..54cc4b7 100644 --- a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp @@ -14,6 +14,7 @@ #include "mlir/TableGen/GenInfo.h" #include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/StringSwitch.h" #include "llvm/CodeGenTypes/MachineValueType.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/PrettyStackTrace.h" @@ -60,8 +61,13 @@ using IndicesTy = llvm::SmallBitVector; /// Return a CodeGen value type entry from a type record. static llvm::MVT::SimpleValueType getValueType(const Record *rec) { - return (llvm::MVT::SimpleValueType)rec->getValueAsDef("VT")->getValueAsInt( - "Value"); + return StringSwitch<llvm::MVT::SimpleValueType>( + rec->getValueAsDef("VT")->getValueAsString("LLVMName")) +#define GET_VT_ATTR(Ty, Sz, Any, Int, FP, Vec, Sc, Tup, NF, NElem, EltTy) \ + .Case(#Ty, llvm::MVT::Ty) +#include "llvm/CodeGen/GenVT.inc" +#undef GET_VT_ATTR + .Case("INVALID_SIMPLE_VALUE_TYPE", llvm::MVT::INVALID_SIMPLE_VALUE_TYPE); } /// Return the indices of the definitions in a list of definitions that @@ -191,7 +197,7 @@ private: /// Prints the elements in "range" separated by commas and surrounded by "[]". template <typename Range> -void printBracketedRange(const Range &range, llvm::raw_ostream &os) { +static void printBracketedRange(const Range &range, llvm::raw_ostream &os) { os << '['; llvm::interleaveComma(range, os); os << ']'; diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 3b10842..dbae5d92 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -2641,7 +2641,14 @@ void OpEmitter::genInlineCreateBody( std::string nonBuilderStateArgs = ""; if (!nonBuilderStateArgsList.empty()) { llvm::raw_string_ostream nonBuilderStateArgsOS(nonBuilderStateArgs); - interleaveComma(nonBuilderStateArgsList, nonBuilderStateArgsOS); + interleave( + nonBuilderStateArgsList, + [&](StringRef name) { + nonBuilderStateArgsOS << "std::forward<decltype(" << name << ")>(" + << name << ')'; + }, + [&] { nonBuilderStateArgsOS << ", "; }); + nonBuilderStateArgs = ", " + nonBuilderStateArgs; } if (cWithLoc) diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp index 0172b3f..2c33f4e 100644 --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -341,6 +341,22 @@ static std::string attrSizedTraitForKind(const char *kind) { StringRef(kind).drop_front()); } +static StringRef getPythonType(StringRef cppType) { + return llvm::StringSwitch<StringRef>(cppType) + .Case("::mlir::MemRefType", "_ods_ir.MemRefType") + .Case("::mlir::UnrankedMemRefType", "_ods_ir.UnrankedMemRefType") + .Case("::mlir::RankedTensorType", "_ods_ir.RankedTensorType") + .Case("::mlir::UnrankedTensorType", "_ods_ir.UnrankedTensorType") + .Case("::mlir::VectorType", "_ods_ir.VectorType") + .Case("::mlir::IntegerType", "_ods_ir.IntegerType") + .Case("::mlir::FloatType", "_ods_ir.FloatType") + .Case("::mlir::IndexType", "_ods_ir.IndexType") + .Case("::mlir::ComplexType", "_ods_ir.ComplexType") + .Case("::mlir::TupleType", "_ods_ir.TupleType") + .Case("::mlir::NoneType", "_ods_ir.NoneType") + .Default(StringRef()); +} + /// Emits accessors to "elements" of an Op definition. Currently, the supported /// elements are operands and results, indicated by `kind`, which must be either /// `operand` or `result` and is used verbatim in the emitted code. @@ -370,8 +386,11 @@ static void emitElementAccessors( seenVariableLength = true; if (element.name.empty()) continue; - const char *type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value" + std::string type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value" : "_ods_ir.OpResult"; + if (StringRef pythonType = getPythonType(element.constraint.getCppType()); + !pythonType.empty()) + type = llvm::formatv("{0}[{1}]", type, pythonType); if (element.isVariableLength()) { if (element.isOptional()) { os << formatv(opOneOptionalTemplate, sanitizeName(element.name), kind, @@ -418,6 +437,12 @@ static void emitElementAccessors( type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value" : "_ods_ir.OpResult"; } + if (std::strcmp(type.c_str(), "_ods_ir.Value") == 0 || + std::strcmp(type.c_str(), "_ods_ir.OpResult") == 0) { + StringRef pythonType = getPythonType(element.constraint.getCppType()); + if (!pythonType.empty()) + type += "[" + pythonType.str() + "]"; + } os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name), kind, numSimpleLength, numVariadicGroups, numPrecedingSimple, numPrecedingVariadic, type); @@ -449,6 +474,12 @@ static void emitElementAccessors( if (!element.isVariableLength() || element.isOptional()) { type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value" : "_ods_ir.OpResult"; + if (std::strcmp(type.c_str(), "_ods_ir.Value") == 0 || + std::strcmp(type.c_str(), "_ods_ir.OpResult") == 0) { + StringRef pythonType = getPythonType(element.constraint.getCppType()); + if (!pythonType.empty()) + type += "[" + pythonType.str() + "]"; + } if (!element.isVariableLength()) { trailing = "[0]"; } else if (element.isOptional()) { diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp index f4b8eb4..e4ae78f 100644 --- a/mlir/tools/mlir-tblgen/PassGen.cpp +++ b/mlir/tools/mlir-tblgen/PassGen.cpp @@ -387,81 +387,6 @@ static void emitPass(const Pass &pass, raw_ostream &os) { emitPassDefs(pass, os); } -// TODO: Drop old pass declarations. -// The old pass base class is being kept until all the passes have switched to -// the new decls/defs design. -const char *const oldPassDeclBegin = R"( -template <typename DerivedT> -class {0}Base : public {1} { -public: - using Base = {0}Base; - - {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{} - {0}Base(const {0}Base &other) : {1}(other) {{} - {0}Base& operator=(const {0}Base &) = delete; - {0}Base({0}Base &&) = delete; - {0}Base& operator=({0}Base &&) = delete; - ~{0}Base() = default; - - /// Returns the command-line argument attached to this pass. - static constexpr ::llvm::StringLiteral getArgumentName() { - return ::llvm::StringLiteral("{2}"); - } - ::llvm::StringRef getArgument() const override { return "{2}"; } - - ::llvm::StringRef getDescription() const override { return R"PD({3})PD"; } - - /// Returns the derived pass name. - static constexpr ::llvm::StringLiteral getPassName() { - return ::llvm::StringLiteral("{0}"); - } - ::llvm::StringRef getName() const override { return "{0}"; } - - /// Support isa/dyn_cast functionality for the derived pass class. - static bool classof(const ::mlir::Pass *pass) {{ - return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); - } - - /// A clone method to create a copy of this pass. - std::unique_ptr<::mlir::Pass> clonePass() const override {{ - return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); - } - - /// Register the dialects that must be loaded in the context before this pass. - void getDependentDialects(::mlir::DialectRegistry ®istry) const override { - {4} - } - - /// Explicitly declare the TypeID for this class. We declare an explicit private - /// instantiation because Pass classes should only be visible by the current - /// library. - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>) - -protected: -)"; - -// TODO: Drop old pass declarations. -/// Emit a backward-compatible declaration of the pass base class. -static void emitOldPassDecl(const Pass &pass, raw_ostream &os) { - StringRef defName = pass.getDef()->getName(); - std::string dependentDialectRegistrations; - { - llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations); - llvm::interleave( - pass.getDependentDialects(), dialectsOs, - [&](StringRef dependentDialect) { - dialectsOs << formatv(dialectRegistrationTemplate, dependentDialect); - }, - "\n "); - } - os << formatv(oldPassDeclBegin, defName, pass.getBaseClass(), - pass.getArgument(), pass.getSummary().trim(), - dependentDialectRegistrations); - emitPassOptionDecls(pass, os); - emitPassStatisticDecls(pass, os); - os << "};\n"; -} - static void emitPasses(const RecordKeeper &records, raw_ostream &os) { std::vector<Pass> passes = getPasses(records); os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n"; @@ -479,12 +404,10 @@ static void emitPasses(const RecordKeeper &records, raw_ostream &os) { emitRegistrations(passes, os); - // TODO: Drop old pass declarations. + // TODO: Remove warning, kept in to make error understandable. // Emit the old code until all the passes have switched to the new design. - os << "// Deprecated. Please use the new per-pass macros.\n"; os << "#ifdef GEN_PASS_CLASSES\n"; - for (const Pass &pass : passes) - emitOldPassDecl(pass, os); + os << "#error \"GEN_PASS_CLASSES is deprecated; use per-pass macros\"\n"; os << "#undef GEN_PASS_CLASSES\n"; os << "#endif // GEN_PASS_CLASSES\n"; } diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index c3034bb8..08d6483 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -1129,7 +1129,7 @@ void PatternEmitter::emit(StringRef rewriteName) { LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n"); // Emit RewritePattern for Pattern. - auto locs = pattern.getLocation(); + auto locs = pattern.getLocation(/*forSourceOutput=*/true); os << formatv("/* Generated from:\n {0:$[ instantiating\n ]}\n*/\n", llvm::reverse(locs)); os << formatv(R"(struct {0} : public ::mlir::RewritePattern { diff --git a/mlir/tools/tblgen-to-irdl/tblgen-to-irdl.cpp b/mlir/tools/tblgen-to-irdl/tblgen-to-irdl.cpp index 092ec2e..33421b4 100644 --- a/mlir/tools/tblgen-to-irdl/tblgen-to-irdl.cpp +++ b/mlir/tools/tblgen-to-irdl/tblgen-to-irdl.cpp @@ -18,10 +18,11 @@ using namespace llvm; using namespace mlir; // Generator that prints records. -GenRegistration printRecords("print-records", "Print all records to stdout", - [](const RecordKeeper &records, raw_ostream &os) { - os << records; - return false; - }); +static GenRegistration + printRecords("print-records", "Print all records to stdout", + [](const RecordKeeper &records, raw_ostream &os) { + os << records; + return false; + }); int main(int argc, char **argv) { return MlirTblgenMain(argc, argv); } diff --git a/mlir/unittests/Analysis/Presburger/BarvinokTest.cpp b/mlir/unittests/Analysis/Presburger/BarvinokTest.cpp index eaf0437..4ca9998 100644 --- a/mlir/unittests/Analysis/Presburger/BarvinokTest.cpp +++ b/mlir/unittests/Analysis/Presburger/BarvinokTest.cpp @@ -301,3 +301,12 @@ TEST(BarvinokTest, computeNumTermsPolytope) { gf = count[0].second; EXPECT_EQ(gf.getNumerators().size(), 24u); } + +TEST(BarvinokTest, solveParametricEquations) { + FracMatrix equations = makeFracMatrix(2, 3, {{2, 3, -4}, {2, 6, -7}}); + auto maybeSolution = solveParametricEquations(equations); + ASSERT_TRUE(maybeSolution.has_value()); + FracMatrix solution = *maybeSolution; + EXPECT_EQ(solution.at(0, 0), Fraction(1, 2)); + EXPECT_EQ(solution.at(1, 0), 1); +} diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp index 9ae90a4..599db4c 100644 --- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp +++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp @@ -725,3 +725,18 @@ TEST(IntegerRelationTest, addLocalModulo) { EXPECT_TRUE(rel.containsPointNoLocal({x, x % 32})); } } + +TEST(IntegerRelationTest, simplify) { + IntegerRelation rel = + parseRelationFromSet("(x, y, z): (2*x + y - 4*z - 3 == 0, " + "3*x - y - 3*z + 2 == 0, x + 3*y - 5*z - 8 == 0," + "x - y + z >= 0)", + 2); + IntegerRelation copy = rel; + rel.simplify(); + + EXPECT_TRUE(rel.isEqual(copy)); + // The third equality is redundant and should be removed. + // It can be obtained from 2 times the first equality minus the second. + EXPECT_TRUE(rel.getNumEqualities() == 2); +} diff --git a/mlir/unittests/Bytecode/BytecodeTest.cpp b/mlir/unittests/Bytecode/BytecodeTest.cpp index d7b442f..30e7ed9 100644 --- a/mlir/unittests/Bytecode/BytecodeTest.cpp +++ b/mlir/unittests/Bytecode/BytecodeTest.cpp @@ -15,6 +15,7 @@ #include "mlir/IR/OwningOpRef.h" #include "mlir/Parser/Parser.h" +#include "mlir/IR/BuiltinOps.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Alignment.h" #include "llvm/Support/Endian.h" @@ -228,3 +229,39 @@ TEST(Bytecode, OpWithoutProperties) { EXPECT_TRUE(OperationEquivalence::computeHash(op.get()) == OperationEquivalence::computeHash(roundtripped)); } + +TEST(Bytecode, DeepCallSiteLoc) { + MLIRContext context; + ParserConfig config(&context); + + // Create a deep CallSiteLoc chain to test iterative parsing. + Location baseLoc = FileLineColLoc::get(&context, "test.mlir", 1, 1); + Location loc = baseLoc; + constexpr int kDepth = 1000; + for (int i = 0; i < kDepth; ++i) { + loc = CallSiteLoc::get(loc, baseLoc); + } + + // Create a simple module with the deep location. + Builder builder(&context); + OwningOpRef<ModuleOp> module = + ModuleOp::create(loc, /*attributes=*/std::nullopt); + ASSERT_TRUE(module); + + // Write to bytecode. + std::string bytecode; + llvm::raw_string_ostream os(bytecode); + ASSERT_TRUE(succeeded(writeBytecodeToFile(module.get(), os))); + + // Parse it back using the bytecode reader. + std::unique_ptr<Block> block = std::make_unique<Block>(); + ASSERT_TRUE(succeeded(readBytecodeFile( + llvm::MemoryBufferRef(bytecode, "string-buffer"), block.get(), config))); + + // Verify we got the roundtripped module. + ASSERT_FALSE(block->empty()); + Operation *roundTripped = &block->front(); + + // Verify the location matches. + EXPECT_EQ(module.get()->getLoc(), roundTripped->getLoc()); +} diff --git a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt index 177c868..c8c2bb9 100644 --- a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt +++ b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_unittest(MLIROpenACCTests OpenACCOpsTest.cpp + OpenACCOpsInterfacesTest.cpp OpenACCUtilsTest.cpp ) mlir_target_link_libraries(MLIROpenACCTests diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCOpsInterfacesTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCOpsInterfacesTest.cpp new file mode 100644 index 0000000..7d52ef31 --- /dev/null +++ b/mlir/unittests/Dialect/OpenACC/OpenACCOpsInterfacesTest.cpp @@ -0,0 +1,117 @@ +//===- OpenACCOpsInterfacesTest.cpp - Unit tests for OpenACC interfaces --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "gtest/gtest.h" + +using namespace mlir; +using namespace mlir::acc; + +//===----------------------------------------------------------------------===// +// Test Fixture +//===----------------------------------------------------------------------===// + +class OpenACCOpsInterfacesTest : public ::testing::Test { +protected: + OpenACCOpsInterfacesTest() + : context(), builder(&context), loc(UnknownLoc::get(&context)) { + context.loadDialect<acc::OpenACCDialect, memref::MemRefDialect>(); + } + + MLIRContext context; + OpBuilder builder; + Location loc; +}; + +//===----------------------------------------------------------------------===// +// GlobalVariableOpInterface Tests +//===----------------------------------------------------------------------===// + +TEST_F(OpenACCOpsInterfacesTest, GlobalVariableOpInterfaceNonConstant) { + // Test that a non-constant global returns false for isConstant() + + auto memrefType = MemRefType::get({10}, builder.getF32Type()); + OwningOpRef<memref::GlobalOp> globalOp = memref::GlobalOp::create( + builder, loc, + /*sym_name=*/builder.getStringAttr("mutable_global"), + /*sym_visibility=*/builder.getStringAttr("private"), + /*type=*/TypeAttr::get(memrefType), + /*initial_value=*/Attribute(), + /*constant=*/UnitAttr(), + /*alignment=*/IntegerAttr()); + + auto globalVarIface = + dyn_cast<GlobalVariableOpInterface>(globalOp->getOperation()); + ASSERT_TRUE(globalVarIface != nullptr); + EXPECT_FALSE(globalVarIface.isConstant()); +} + +TEST_F(OpenACCOpsInterfacesTest, GlobalVariableOpInterfaceConstant) { + // Test that a constant global returns true for isConstant() + + auto memrefType = MemRefType::get({5}, builder.getI32Type()); + OwningOpRef<memref::GlobalOp> constantGlobalOp = memref::GlobalOp::create( + builder, loc, + /*sym_name=*/builder.getStringAttr("constant_global"), + /*sym_visibility=*/builder.getStringAttr("public"), + /*type=*/TypeAttr::get(memrefType), + /*initial_value=*/Attribute(), + /*constant=*/builder.getUnitAttr(), + /*alignment=*/IntegerAttr()); + + auto globalVarIface = + dyn_cast<GlobalVariableOpInterface>(constantGlobalOp->getOperation()); + ASSERT_TRUE(globalVarIface != nullptr); + EXPECT_TRUE(globalVarIface.isConstant()); +} + +TEST_F(OpenACCOpsInterfacesTest, GlobalVariableOpInterfaceInitRegion) { + // Test that memref::GlobalOp returns nullptr for getInitRegion() + // since it uses attributes for initialization, not regions + + auto memrefType = MemRefType::get({10}, builder.getF32Type()); + OwningOpRef<memref::GlobalOp> globalOp = memref::GlobalOp::create( + builder, loc, + /*sym_name=*/builder.getStringAttr("test_global"), + /*sym_visibility=*/builder.getStringAttr("private"), + /*type=*/TypeAttr::get(memrefType), + /*initial_value=*/Attribute(), + /*constant=*/UnitAttr(), + /*alignment=*/IntegerAttr()); + + auto globalVarIface = + dyn_cast<GlobalVariableOpInterface>(globalOp->getOperation()); + ASSERT_TRUE(globalVarIface != nullptr); + + // memref::GlobalOp doesn't have regions for initialization + EXPECT_EQ(globalVarIface.getInitRegion(), nullptr); +} + +//===----------------------------------------------------------------------===// +// AddressOfGlobalOpInterface Tests +//===----------------------------------------------------------------------===// + +TEST_F(OpenACCOpsInterfacesTest, AddressOfGlobalOpInterfaceGetSymbol) { + // Test that getSymbol() returns the correct symbol reference + + auto memrefType = MemRefType::get({5}, builder.getI32Type()); + const auto *symbolName = "test_global_symbol"; + + OwningOpRef<memref::GetGlobalOp> getGlobalOp = memref::GetGlobalOp::create( + builder, loc, memrefType, FlatSymbolRefAttr::get(&context, symbolName)); + + auto addrOfGlobalIface = + dyn_cast<AddressOfGlobalOpInterface>(getGlobalOp->getOperation()); + ASSERT_TRUE(addrOfGlobalIface != nullptr); + EXPECT_EQ(addrOfGlobalIface.getSymbol().getLeafReference(), symbolName); +} diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp index 6f4e305..60d8732 100644 --- a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp +++ b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp @@ -674,3 +674,696 @@ TEST_F(OpenACCUtilsTest, getBaseEntityChainedSubviews) { Value ultimateBase = getBaseEntity(baseEntity); EXPECT_EQ(ultimateBase, baseMemref); } + +//===----------------------------------------------------------------------===// +// isValidSymbolUse Tests +//===----------------------------------------------------------------------===// + +TEST_F(OpenACCUtilsTest, isValidSymbolUseNoDefiningOp) { + // Create a memref.get_global that references a non-existent global + auto memrefType = MemRefType::get({10}, b.getI32Type()); + llvm::StringRef globalName = "nonexistent_global"; + SymbolRefAttr nonExistentSymbol = SymbolRefAttr::get(&context, globalName); + + OwningOpRef<memref::GetGlobalOp> getGlobalOp = + memref::GetGlobalOp::create(b, loc, memrefType, globalName); + + Operation *definingOp = nullptr; + bool result = + isValidSymbolUse(getGlobalOp.get(), nonExistentSymbol, &definingOp); + + EXPECT_FALSE(result); + EXPECT_EQ(definingOp, nullptr); +} + +TEST_F(OpenACCUtilsTest, isValidSymbolUseRecipe) { + // Create a module to hold the recipe + OwningOpRef<ModuleOp> module = ModuleOp::create(loc); + Block *moduleBlock = module->getBody(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleBlock); + + // Create a private recipe (any recipe type would work) + auto i32Type = b.getI32Type(); + llvm::StringRef recipeName = "test_recipe"; + OwningOpRef<PrivateRecipeOp> recipeOp = + PrivateRecipeOp::create(b, loc, recipeName, i32Type); + + // Create a value to privatize + auto memrefTy = MemRefType::get({10}, b.getI32Type()); + OwningOpRef<memref::AllocaOp> allocOp = + memref::AllocaOp::create(b, loc, memrefTy); + TypedValue<PointerLikeType> varPtr = + cast<TypedValue<PointerLikeType>>(allocOp->getResult()); + + // Create a private op as the user operation + OwningOpRef<PrivateOp> privateOp = PrivateOp::create( + b, loc, varPtr, /*structured=*/true, /*implicit=*/false); + + // Create a symbol reference to the recipe + SymbolRefAttr recipeSymbol = SymbolRefAttr::get(&context, recipeName); + + Operation *definingOp = nullptr; + bool result = isValidSymbolUse(privateOp.get(), recipeSymbol, &definingOp); + + EXPECT_TRUE(result); + EXPECT_EQ(definingOp, recipeOp.get()); +} + +TEST_F(OpenACCUtilsTest, isValidSymbolUseFunctionWithRoutineInfo) { + // Create a module to hold the function + OwningOpRef<ModuleOp> module = ModuleOp::create(loc); + Block *moduleBlock = module->getBody(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleBlock); + + // Create a function with routine_info attribute + auto funcType = b.getFunctionType({}, {}); + llvm::StringRef funcName = "routine_func"; + OwningOpRef<func::FuncOp> funcOp = + func::FuncOp::create(b, loc, funcName, funcType); + + // Add routine_info attribute with a reference to a routine + SmallVector<SymbolRefAttr> routineRefs = { + SymbolRefAttr::get(&context, "acc_routine")}; + funcOp.get()->setAttr(getRoutineInfoAttrName(), + RoutineInfoAttr::get(&context, routineRefs)); + + // Create a call operation that uses the function symbol + SymbolRefAttr funcSymbol = SymbolRefAttr::get(&context, funcName); + OwningOpRef<func::CallOp> callOp = func::CallOp::create( + b, loc, funcSymbol, funcType.getResults(), ValueRange{}); + + Operation *definingOp = nullptr; + bool result = isValidSymbolUse(callOp.get(), funcSymbol, &definingOp); + + EXPECT_TRUE(result); + EXPECT_NE(definingOp, nullptr); +} + +TEST_F(OpenACCUtilsTest, isValidSymbolUseLLVMIntrinsic) { + // Create a module to hold the function + OwningOpRef<ModuleOp> module = ModuleOp::create(loc); + Block *moduleBlock = module->getBody(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleBlock); + + // Create a private function with LLVM intrinsic name + auto funcType = b.getFunctionType({b.getF32Type()}, {b.getF32Type()}); + llvm::StringRef intrinsicName = "llvm.sqrt.f32"; + OwningOpRef<func::FuncOp> funcOp = + func::FuncOp::create(b, loc, intrinsicName, funcType); + + // Set visibility to private (required for intrinsics) + funcOp->setPrivate(); + + // Create a call operation that uses the intrinsic + SymbolRefAttr funcSymbol = SymbolRefAttr::get(&context, intrinsicName); + OwningOpRef<func::CallOp> callOp = func::CallOp::create( + b, loc, funcSymbol, funcType.getResults(), ValueRange{}); + + Operation *definingOp = nullptr; + bool result = isValidSymbolUse(callOp.get(), funcSymbol, &definingOp); + + EXPECT_TRUE(result); + EXPECT_NE(definingOp, nullptr); +} + +TEST_F(OpenACCUtilsTest, isValidSymbolUseFunctionNotIntrinsic) { + // Create a module to hold the function + OwningOpRef<ModuleOp> module = ModuleOp::create(loc); + Block *moduleBlock = module->getBody(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleBlock); + + // Create a private function that looks like intrinsic but isn't + auto funcType = b.getFunctionType({}, {}); + llvm::StringRef funcName = "llvm.not_a_real_intrinsic"; + OwningOpRef<func::FuncOp> funcOp = + func::FuncOp::create(b, loc, funcName, funcType); + funcOp->setPrivate(); + + // Create a call operation that uses the function + SymbolRefAttr funcSymbol = SymbolRefAttr::get(&context, funcName); + OwningOpRef<func::CallOp> callOp = func::CallOp::create( + b, loc, funcSymbol, funcType.getResults(), ValueRange{}); + + Operation *definingOp = nullptr; + bool result = isValidSymbolUse(callOp.get(), funcSymbol, &definingOp); + + // Should be false because it's not a valid intrinsic and has no + // acc.routine_info attr + EXPECT_FALSE(result); + EXPECT_NE(definingOp, nullptr); +} + +TEST_F(OpenACCUtilsTest, isValidSymbolUseWithDeclareAttr) { + // Create a module to hold a function + OwningOpRef<ModuleOp> module = ModuleOp::create(loc); + Block *moduleBlock = module->getBody(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleBlock); + + // Create a function with declare attribute + auto funcType = b.getFunctionType({}, {}); + llvm::StringRef funcName = "declared_func"; + OwningOpRef<func::FuncOp> funcOp = + func::FuncOp::create(b, loc, funcName, funcType); + + // Add declare attribute + funcOp.get()->setAttr( + getDeclareAttrName(), + DeclareAttr::get(&context, + DataClauseAttr::get(&context, DataClause::acc_copy))); + + // Create a call operation that uses the function + SymbolRefAttr funcSymbol = SymbolRefAttr::get(&context, funcName); + OwningOpRef<func::CallOp> callOp = func::CallOp::create( + b, loc, funcSymbol, funcType.getResults(), ValueRange{}); + + Operation *definingOp = nullptr; + bool result = isValidSymbolUse(callOp.get(), funcSymbol, &definingOp); + + EXPECT_TRUE(result); + EXPECT_NE(definingOp, nullptr); +} + +TEST_F(OpenACCUtilsTest, isValidSymbolUseWithoutValidAttributes) { + // Create a module to hold a function + OwningOpRef<ModuleOp> module = ModuleOp::create(loc); + Block *moduleBlock = module->getBody(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleBlock); + + // Create a function without any special attributes + auto funcType = b.getFunctionType({}, {}); + llvm::StringRef funcName = "regular_func"; + OwningOpRef<func::FuncOp> funcOp = + func::FuncOp::create(b, loc, funcName, funcType); + + // Create a call operation that uses the function + SymbolRefAttr funcSymbol = SymbolRefAttr::get(&context, funcName); + OwningOpRef<func::CallOp> callOp = func::CallOp::create( + b, loc, funcSymbol, funcType.getResults(), ValueRange{}); + + Operation *definingOp = nullptr; + bool result = isValidSymbolUse(callOp.get(), funcSymbol, &definingOp); + + // Should be false - no routine_info, not an intrinsic, no declare attribute + EXPECT_FALSE(result); + EXPECT_NE(definingOp, nullptr); +} + +TEST_F(OpenACCUtilsTest, isValidSymbolUseNullDefiningOpPtr) { + // Create a module to hold a recipe + OwningOpRef<ModuleOp> module = ModuleOp::create(loc); + Block *moduleBlock = module->getBody(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleBlock); + + // Create a private recipe + auto i32Type = b.getI32Type(); + llvm::StringRef recipeName = "test_recipe"; + OwningOpRef<PrivateRecipeOp> recipeOp = + PrivateRecipeOp::create(b, loc, recipeName, i32Type); + + // Create a value to privatize + auto memrefTy = MemRefType::get({10}, b.getI32Type()); + OwningOpRef<memref::AllocaOp> allocOp = + memref::AllocaOp::create(b, loc, memrefTy); + TypedValue<PointerLikeType> varPtr = + cast<TypedValue<PointerLikeType>>(allocOp->getResult()); + + // Create a private op as the user operation + OwningOpRef<PrivateOp> privateOp = PrivateOp::create( + b, loc, varPtr, /*structured=*/true, /*implicit=*/false); + + // Create a symbol reference to the recipe + SymbolRefAttr recipeSymbol = SymbolRefAttr::get(&context, recipeName); + + // Call without definingOpPtr (nullptr) + bool result = isValidSymbolUse(privateOp.get(), recipeSymbol, nullptr); + + EXPECT_TRUE(result); +} + +//===----------------------------------------------------------------------===// +// getDominatingDataClauses Tests +//===----------------------------------------------------------------------===// + +TEST_F(OpenACCUtilsTest, getDominatingDataClausesFromComputeConstruct) { + // Create a module to hold a function + OwningOpRef<ModuleOp> module = ModuleOp::create(loc); + Block *moduleBlock = module->getBody(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleBlock); + + // Create a function + auto funcType = b.getFunctionType({}, {}); + OwningOpRef<func::FuncOp> funcOp = + func::FuncOp::create(b, loc, "test_func", funcType); + Block *funcBlock = funcOp->addEntryBlock(); + + b.setInsertionPointToStart(funcBlock); + + // Create a memref for the data clause + auto memrefTy = MemRefType::get({10}, b.getI32Type()); + OwningOpRef<memref::AllocaOp> allocOp = + memref::AllocaOp::create(b, loc, memrefTy); + TypedValue<PointerLikeType> varPtr = + cast<TypedValue<PointerLikeType>>(allocOp->getResult()); + + // Create a copyin op to represent a data clause + OwningOpRef<CopyinOp> copyinOp = + CopyinOp::create(b, loc, varPtr, /*structured=*/true, /*implicit=*/false, + /*name=*/"test_var"); + + // Create a parallel op + OwningOpRef<ParallelOp> parallelOp = + ParallelOp::create(b, loc, TypeRange{}, ValueRange{}); + + // Set the data clause operands + parallelOp->getDataClauseOperandsMutable().append(copyinOp->getAccVar()); + + // Create dominance info + DominanceInfo domInfo(funcOp.get()); + PostDominanceInfo postDomInfo(funcOp.get()); + + // Get dominating data clauses + auto dataClauses = + getDominatingDataClauses(parallelOp.get(), domInfo, postDomInfo); + + // Should contain the copyin from the parallel op + EXPECT_EQ(dataClauses.size(), 1ul); + EXPECT_EQ(dataClauses[0], copyinOp->getAccVar()); +} + +TEST_F(OpenACCUtilsTest, getDominatingDataClausesFromEnclosingDataOp) { + // Create a module to hold a function + OwningOpRef<ModuleOp> module = ModuleOp::create(loc); + Block *moduleBlock = module->getBody(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleBlock); + + // Create a function + auto funcType = b.getFunctionType({}, {}); + OwningOpRef<func::FuncOp> funcOp = + func::FuncOp::create(b, loc, "test_func", funcType); + Block *funcBlock = funcOp->addEntryBlock(); + + b.setInsertionPointToStart(funcBlock); + + // Create a memref for the data clause + auto memrefTy = MemRefType::get({10}, b.getI32Type()); + OwningOpRef<memref::AllocaOp> allocOp = + memref::AllocaOp::create(b, loc, memrefTy); + TypedValue<PointerLikeType> varPtr = + cast<TypedValue<PointerLikeType>>(allocOp->getResult()); + + // Create a copyin op for the data construct + OwningOpRef<CopyinOp> copyinOp = + CopyinOp::create(b, loc, varPtr, /*structured=*/true, /*implicit=*/false, + /*name=*/"test_var"); + + // Create a data op + OwningOpRef<DataOp> dataOp = + DataOp::create(b, loc, TypeRange{}, ValueRange{}); + + // Set the data clause operands + dataOp->getDataClauseOperandsMutable().append(copyinOp->getAccVar()); + + Region &dataRegion = dataOp->getRegion(); + Block *dataBlock = &dataRegion.emplaceBlock(); + + b.setInsertionPointToStart(dataBlock); + + // Create a parallel op inside the data region (no data clauses on parallel) + OwningOpRef<ParallelOp> parallelOp = + ParallelOp::create(b, loc, TypeRange{}, ValueRange{}); + + // Create dominance info + DominanceInfo domInfo(funcOp.get()); + PostDominanceInfo postDomInfo(funcOp.get()); + + // Get dominating data clauses + auto dataClauses = + getDominatingDataClauses(parallelOp.get(), domInfo, postDomInfo); + + // Should contain the copyin from the enclosing data op + EXPECT_EQ(dataClauses.size(), 1ul); + EXPECT_EQ(dataClauses[0], copyinOp->getAccVar()); +} + +TEST_F(OpenACCUtilsTest, getDominatingDataClausesFromComputeAndEnclosingData) { + // Create a module to hold a function + OwningOpRef<ModuleOp> module = ModuleOp::create(loc); + Block *moduleBlock = module->getBody(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleBlock); + + // Create a function + auto funcType = b.getFunctionType({}, {}); + OwningOpRef<func::FuncOp> funcOp = + func::FuncOp::create(b, loc, "test_func", funcType); + Block *funcBlock = funcOp->addEntryBlock(); + + b.setInsertionPointToStart(funcBlock); + + // Create two memrefs for different data clauses + auto memrefTy = MemRefType::get({10}, b.getI32Type()); + OwningOpRef<memref::AllocaOp> allocOp1 = + memref::AllocaOp::create(b, loc, memrefTy); + TypedValue<PointerLikeType> varPtr1 = + cast<TypedValue<PointerLikeType>>(allocOp1->getResult()); + + OwningOpRef<memref::AllocaOp> allocOp2 = + memref::AllocaOp::create(b, loc, memrefTy); + TypedValue<PointerLikeType> varPtr2 = + cast<TypedValue<PointerLikeType>>(allocOp2->getResult()); + + // Create copyin ops + OwningOpRef<CopyinOp> copyinOp1 = + CopyinOp::create(b, loc, varPtr1, /*structured=*/true, /*implicit=*/false, + /*name=*/"var1"); + OwningOpRef<CopyinOp> copyinOp2 = + CopyinOp::create(b, loc, varPtr2, /*structured=*/true, /*implicit=*/false, + /*name=*/"var2"); + + // Create a data op + OwningOpRef<DataOp> dataOp = + DataOp::create(b, loc, TypeRange{}, ValueRange{}); + + // Set the data clause operands for data op + dataOp->getDataClauseOperandsMutable().append(copyinOp1->getAccVar()); + + Region &dataRegion = dataOp->getRegion(); + Block *dataBlock = &dataRegion.emplaceBlock(); + + b.setInsertionPointToStart(dataBlock); + + // Create a parallel op inside the data region + OwningOpRef<ParallelOp> parallelOp = + ParallelOp::create(b, loc, TypeRange{}, ValueRange{}); + + // Set the data clause operands for parallel op + parallelOp->getDataClauseOperandsMutable().append(copyinOp2->getAccVar()); + + // Create dominance info + DominanceInfo domInfo(funcOp.get()); + PostDominanceInfo postDomInfo(funcOp.get()); + + // Get dominating data clauses + auto dataClauses = + getDominatingDataClauses(parallelOp.get(), domInfo, postDomInfo); + + // Should contain both copyins (from data op and parallel op) + EXPECT_EQ(dataClauses.size(), 2ul); + // Note: Order might not be guaranteed, so check both are present + EXPECT_TRUE(llvm::is_contained(dataClauses, copyinOp1->getAccVar())); + EXPECT_TRUE(llvm::is_contained(dataClauses, copyinOp2->getAccVar())); +} + +TEST_F(OpenACCUtilsTest, getDominatingDataClausesWithDeclareDirectives) { + // Create a module to hold a function + OwningOpRef<ModuleOp> module = ModuleOp::create(loc); + Block *moduleBlock = module->getBody(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleBlock); + + // Create a function + auto funcType = b.getFunctionType({}, {}); + OwningOpRef<func::FuncOp> funcOp = + func::FuncOp::create(b, loc, "test_func", funcType); + Block *funcBlock = funcOp->addEntryBlock(); + + b.setInsertionPointToStart(funcBlock); + + // Create a memref for the declare directive + auto memrefTy = MemRefType::get({10}, b.getI32Type()); + OwningOpRef<memref::AllocaOp> allocOp = + memref::AllocaOp::create(b, loc, memrefTy); + TypedValue<PointerLikeType> varPtr = + cast<TypedValue<PointerLikeType>>(allocOp->getResult()); + + // Create a copyin op for declare + OwningOpRef<CopyinOp> copyinOp = + CopyinOp::create(b, loc, varPtr, /*structured=*/false, /*implicit=*/false, + /*name=*/"declare_var"); + + // Create a declare_enter op + OwningOpRef<DeclareEnterOp> declareEnterOp = DeclareEnterOp::create( + b, loc, TypeRange{b.getType<acc::DeclareTokenType>()}, + ValueRange{copyinOp->getAccVar()}); + + // Create a parallel op + OwningOpRef<ParallelOp> parallelOp = + ParallelOp::create(b, loc, TypeRange{}, ValueRange{}); + + // Create a declare_exit op that post-dominates the parallel + OwningOpRef<DeclareExitOp> declareExitOp = DeclareExitOp::create( + b, loc, declareEnterOp->getToken(), ValueRange{copyinOp->getAccVar()}); + + // Add a return to complete the function + OwningOpRef<func::ReturnOp> returnOp = func::ReturnOp::create(b, loc); + + // Create dominance info + DominanceInfo domInfo(funcOp.get()); + PostDominanceInfo postDomInfo(funcOp.get()); + + // Get dominating data clauses + auto dataClauses = + getDominatingDataClauses(parallelOp.get(), domInfo, postDomInfo); + + // Should contain the copyin from the declare directive + EXPECT_EQ(dataClauses.size(), 1ul); + EXPECT_EQ(dataClauses[0], copyinOp->getAccVar()); +} + +TEST_F(OpenACCUtilsTest, getDominatingDataClausesMultipleDataConstructs) { + // Create a module to hold a function + OwningOpRef<ModuleOp> module = ModuleOp::create(loc); + Block *moduleBlock = module->getBody(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleBlock); + + // Create a function + auto funcType = b.getFunctionType({}, {}); + OwningOpRef<func::FuncOp> funcOp = + func::FuncOp::create(b, loc, "test_func", funcType); + Block *funcBlock = funcOp->addEntryBlock(); + + b.setInsertionPointToStart(funcBlock); + + // Create three memrefs + auto memrefTy = MemRefType::get({10}, b.getI32Type()); + OwningOpRef<memref::AllocaOp> allocOp1 = + memref::AllocaOp::create(b, loc, memrefTy); + TypedValue<PointerLikeType> varPtr1 = + cast<TypedValue<PointerLikeType>>(allocOp1->getResult()); + + OwningOpRef<memref::AllocaOp> allocOp2 = + memref::AllocaOp::create(b, loc, memrefTy); + TypedValue<PointerLikeType> varPtr2 = + cast<TypedValue<PointerLikeType>>(allocOp2->getResult()); + + OwningOpRef<memref::AllocaOp> allocOp3 = + memref::AllocaOp::create(b, loc, memrefTy); + TypedValue<PointerLikeType> varPtr3 = + cast<TypedValue<PointerLikeType>>(allocOp3->getResult()); + + // Create copyin ops + OwningOpRef<CopyinOp> copyinOp1 = + CopyinOp::create(b, loc, varPtr1, /*structured=*/true, /*implicit=*/false, + /*name=*/"var1"); + OwningOpRef<CopyinOp> copyinOp2 = + CopyinOp::create(b, loc, varPtr2, /*structured=*/true, /*implicit=*/false, + /*name=*/"var2"); + OwningOpRef<CopyinOp> copyinOp3 = + CopyinOp::create(b, loc, varPtr3, /*structured=*/true, /*implicit=*/false, + /*name=*/"var3"); + + // Create outer data op + OwningOpRef<DataOp> outerDataOp = + DataOp::create(b, loc, TypeRange{}, ValueRange{}); + + // Set the data clause operands for outer data op + outerDataOp->getDataClauseOperandsMutable().append(copyinOp1->getAccVar()); + + Region &outerDataRegion = outerDataOp->getRegion(); + Block *outerDataBlock = &outerDataRegion.emplaceBlock(); + + b.setInsertionPointToStart(outerDataBlock); + + // Create inner data op + OwningOpRef<DataOp> innerDataOp = + DataOp::create(b, loc, TypeRange{}, ValueRange{}); + + // Set the data clause operands for inner data op + innerDataOp->getDataClauseOperandsMutable().append(copyinOp2->getAccVar()); + + Region &innerDataRegion = innerDataOp->getRegion(); + Block *innerDataBlock = &innerDataRegion.emplaceBlock(); + + b.setInsertionPointToStart(innerDataBlock); + + // Create a parallel op + OwningOpRef<ParallelOp> parallelOp = + ParallelOp::create(b, loc, TypeRange{}, ValueRange{}); + + // Set the data clause operands for parallel op + parallelOp->getDataClauseOperandsMutable().append(copyinOp3->getAccVar()); + + // Create dominance info + DominanceInfo domInfo(funcOp.get()); + PostDominanceInfo postDomInfo(funcOp.get()); + + // Get dominating data clauses + auto dataClauses = + getDominatingDataClauses(parallelOp.get(), domInfo, postDomInfo); + + // Should contain all three copyins + EXPECT_EQ(dataClauses.size(), 3ul); + EXPECT_TRUE(llvm::is_contained(dataClauses, copyinOp1->getAccVar())); + EXPECT_TRUE(llvm::is_contained(dataClauses, copyinOp2->getAccVar())); + EXPECT_TRUE(llvm::is_contained(dataClauses, copyinOp3->getAccVar())); +} + +TEST_F(OpenACCUtilsTest, getDominatingDataClausesKernelsOp) { + // Test with KernelsOp instead of ParallelOp + OwningOpRef<ModuleOp> module = ModuleOp::create(loc); + Block *moduleBlock = module->getBody(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleBlock); + + // Create a function + auto funcType = b.getFunctionType({}, {}); + OwningOpRef<func::FuncOp> funcOp = + func::FuncOp::create(b, loc, "test_func", funcType); + Block *funcBlock = funcOp->addEntryBlock(); + + b.setInsertionPointToStart(funcBlock); + + // Create a memref + auto memrefTy = MemRefType::get({10}, b.getI32Type()); + OwningOpRef<memref::AllocaOp> allocOp = + memref::AllocaOp::create(b, loc, memrefTy); + TypedValue<PointerLikeType> varPtr = + cast<TypedValue<PointerLikeType>>(allocOp->getResult()); + + // Create a copyin op + OwningOpRef<CopyinOp> copyinOp = + CopyinOp::create(b, loc, varPtr, /*structured=*/true, /*implicit=*/false, + /*name=*/"test_var"); + + // Create a kernels op + OwningOpRef<KernelsOp> kernelsOp = + KernelsOp::create(b, loc, TypeRange{}, ValueRange{}); + + // Set the data clause operands + kernelsOp->getDataClauseOperandsMutable().append(copyinOp->getAccVar()); + + // Create dominance info + DominanceInfo domInfo(funcOp.get()); + PostDominanceInfo postDomInfo(funcOp.get()); + + // Get dominating data clauses + auto dataClauses = + getDominatingDataClauses(kernelsOp.get(), domInfo, postDomInfo); + + // Should contain the copyin from the kernels op + EXPECT_EQ(dataClauses.size(), 1ul); + EXPECT_EQ(dataClauses[0], copyinOp->getAccVar()); +} + +TEST_F(OpenACCUtilsTest, getDominatingDataClausesSerialOp) { + // Test with SerialOp + OwningOpRef<ModuleOp> module = ModuleOp::create(loc); + Block *moduleBlock = module->getBody(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleBlock); + + // Create a function + auto funcType = b.getFunctionType({}, {}); + OwningOpRef<func::FuncOp> funcOp = + func::FuncOp::create(b, loc, "test_func", funcType); + Block *funcBlock = funcOp->addEntryBlock(); + + b.setInsertionPointToStart(funcBlock); + + // Create a memref + auto memrefTy = MemRefType::get({10}, b.getI32Type()); + OwningOpRef<memref::AllocaOp> allocOp = + memref::AllocaOp::create(b, loc, memrefTy); + TypedValue<PointerLikeType> varPtr = + cast<TypedValue<PointerLikeType>>(allocOp->getResult()); + + // Create a copyin op + OwningOpRef<CopyinOp> copyinOp = + CopyinOp::create(b, loc, varPtr, /*structured=*/true, /*implicit=*/false, + /*name=*/"test_var"); + + // Create a serial op + OwningOpRef<SerialOp> serialOp = + SerialOp::create(b, loc, TypeRange{}, ValueRange{}); + + // Set the data clause operands + serialOp->getDataClauseOperandsMutable().append(copyinOp->getAccVar()); + + // Create dominance info + DominanceInfo domInfo(funcOp.get()); + PostDominanceInfo postDomInfo(funcOp.get()); + + // Get dominating data clauses + auto dataClauses = + getDominatingDataClauses(serialOp.get(), domInfo, postDomInfo); + + // Should contain the copyin from the serial op + EXPECT_EQ(dataClauses.size(), 1ul); + EXPECT_EQ(dataClauses[0], copyinOp->getAccVar()); +} + +TEST_F(OpenACCUtilsTest, getDominatingDataClausesEmpty) { + // Test with no data clauses at all + OwningOpRef<ModuleOp> module = ModuleOp::create(loc); + Block *moduleBlock = module->getBody(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleBlock); + + // Create a function + auto funcType = b.getFunctionType({}, {}); + OwningOpRef<func::FuncOp> funcOp = + func::FuncOp::create(b, loc, "test_func", funcType); + Block *funcBlock = funcOp->addEntryBlock(); + + b.setInsertionPointToStart(funcBlock); + + // Create a parallel op with no data clauses + OwningOpRef<ParallelOp> parallelOp = + ParallelOp::create(b, loc, TypeRange{}, ValueRange{}); + + // Create dominance info + DominanceInfo domInfo(funcOp.get()); + PostDominanceInfo postDomInfo(funcOp.get()); + + // Get dominating data clauses + auto dataClauses = + getDominatingDataClauses(parallelOp.get(), domInfo, postDomInfo); + + // Should be empty + EXPECT_EQ(dataClauses.size(), 0ul); +} diff --git a/mlir/unittests/IR/RemarkTest.cpp b/mlir/unittests/IR/RemarkTest.cpp index 94753c1..f33d3ca 100644 --- a/mlir/unittests/IR/RemarkTest.cpp +++ b/mlir/unittests/IR/RemarkTest.cpp @@ -6,6 +6,8 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Remarks.h" @@ -377,4 +379,35 @@ TEST(Remark, TestRemarkFinal) { EXPECT_NE(errOut.find(pass3Msg), std::string::npos); // shown EXPECT_NE(errOut.find(pass4Msg), std::string::npos); // shown } + +TEST(Remark, TestArgWithAttribute) { + MLIRContext context; + + SmallVector<Attribute> elements; + elements.push_back(IntegerAttr::get(IntegerType::get(&context, 32), 1)); + elements.push_back(IntegerAttr::get(IntegerType::get(&context, 32), 2)); + elements.push_back(IntegerAttr::get(IntegerType::get(&context, 32), 3)); + ArrayAttr arrayAttr = ArrayAttr::get(&context, elements); + remark::detail::Remark::Arg argWithArray("Values", arrayAttr); + + // Verify the attribute is stored + EXPECT_TRUE(argWithArray.hasAttribute()); + EXPECT_EQ(argWithArray.getAttribute(), arrayAttr); + + // Ensure it can be retrieved as an ArrayAttr. + auto retrievedAttr = dyn_cast<ArrayAttr>(argWithArray.getAttribute()); + EXPECT_TRUE(retrievedAttr); + EXPECT_EQ(retrievedAttr.size(), 3u); + EXPECT_EQ(cast<IntegerAttr>(retrievedAttr[0]).getInt(), 1); + EXPECT_EQ(cast<IntegerAttr>(retrievedAttr[1]).getInt(), 2); + EXPECT_EQ(cast<IntegerAttr>(retrievedAttr[2]).getInt(), 3); + + // Create an Arg without an Attribute (string-based) + remark::detail::Remark::Arg argWithoutAttr("Key", "Value"); + + // Verify no attribute is stored + EXPECT_FALSE(argWithoutAttr.hasAttribute()); + EXPECT_FALSE(argWithoutAttr.getAttribute()); // Returns null Attribute + EXPECT_EQ(argWithoutAttr.val, "Value"); +} } // namespace diff --git a/mlir/unittests/IR/SymbolTableTest.cpp b/mlir/unittests/IR/SymbolTableTest.cpp index 4b3545b..864eb40 100644 --- a/mlir/unittests/IR/SymbolTableTest.cpp +++ b/mlir/unittests/IR/SymbolTableTest.cpp @@ -77,7 +77,7 @@ namespace { TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleOp) { // Symbol as `Operation *`, rename within module. - testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, + testReplaceAllSymbolUses([&](const auto &symbolTable, auto module, auto fooOp, auto barOp) -> LogicalResult { return symbolTable.replaceAllSymbolUses( barOp, StringAttr::get(context.get(), "baz"), module); @@ -86,7 +86,7 @@ TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleOp) { TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleOp) { // Symbol as `StringAttr`, rename within module. - testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, + testReplaceAllSymbolUses([&](const auto &symbolTable, auto module, auto fooOp, auto barOp) -> LogicalResult { return symbolTable.replaceAllSymbolUses( StringAttr::get(context.get(), "bar"), @@ -96,7 +96,7 @@ TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleOp) { TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleBody) { // Symbol as `Operation *`, rename within module body. - testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, + testReplaceAllSymbolUses([&](const auto &symbolTable, auto module, auto fooOp, auto barOp) -> LogicalResult { return symbolTable.replaceAllSymbolUses( barOp, StringAttr::get(context.get(), "baz"), &module->getRegion(0)); @@ -105,7 +105,7 @@ TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleBody) { TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleBody) { // Symbol as `StringAttr`, rename within module body. - testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, + testReplaceAllSymbolUses([&](const auto &symbolTable, auto module, auto fooOp, auto barOp) -> LogicalResult { return symbolTable.replaceAllSymbolUses( StringAttr::get(context.get(), "bar"), @@ -115,7 +115,7 @@ TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleBody) { TEST_F(ReplaceAllSymbolUsesTest, OperationInFuncOp) { // Symbol as `Operation *`, rename within function. - testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, + testReplaceAllSymbolUses([&](const auto &symbolTable, auto module, auto fooOp, auto barOp) -> LogicalResult { return symbolTable.replaceAllSymbolUses( barOp, StringAttr::get(context.get(), "baz"), fooOp); @@ -124,7 +124,7 @@ TEST_F(ReplaceAllSymbolUsesTest, OperationInFuncOp) { TEST_F(ReplaceAllSymbolUsesTest, StringAttrInFuncOp) { // Symbol as `StringAttr`, rename within function. - testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, + testReplaceAllSymbolUses([&](const auto &symbolTable, auto module, auto fooOp, auto barOp) -> LogicalResult { return symbolTable.replaceAllSymbolUses( StringAttr::get(context.get(), "bar"), |
