diff options
Diffstat (limited to 'mlir')
208 files changed, 8135 insertions, 1364 deletions
diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md index 7e6a466a..6f778b0 100644 --- a/mlir/docs/Bindings/Python.md +++ b/mlir/docs/Bindings/Python.md @@ -1188,6 +1188,19 @@ which can be `import`ed from the main dialect file, i.e. `python/mlir/dialects/<dialect-namespace>/passes.py` if it is undesirable to make the passes available along with the dialect. +### Other functionality + +Dialect functionality other than IR objects or passes, such as helper functions, +can be exposed to Python similarly to attributes and types. C API is expected to +exist for this functionality, which can then be wrapped using pybind11 and +[`include/mlir/Bindings/Python/PybindAdaptors.h`](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/PybindAdaptors.h), +or nanobind and +[`include/mlir/Bindings/Python/NanobindAdaptors.h`](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h) +utilities to connect to the rest of Python API. The bindings can be located in a +separate module or in the same module as attributes and types, and +loaded along with the dialect. + + ## Extending MLIR in Python The MLIR Python bindings provide support for defining custom components in Python, @@ -1262,17 +1275,6 @@ This frozen set can then be applied to an operation using the greedy rewrite pattern driver via `apply_patterns_and_fold_greedily`. For further information, see [the PDL dialect documentation](/docs/Dialects/PDLOps/). -### Other functionality - -Dialect functionality other than IR objects or passes, such as helper functions, -can be exposed to Python similarly to attributes and types. C API is expected to -exist for this functionality, which can then be wrapped using pybind11 and -[`include/mlir/Bindings/Python/PybindAdaptors.h`](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/PybindAdaptors.h), -or nanobind and -[`include/mlir/Bindings/Python/NanobindAdaptors.h`](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h) -utilities to connect to the rest of Python API. The bindings can be located in a -separate module or in the same module as attributes and types, and -loaded along with the dialect. ## Free-threading (No-GIL) support diff --git a/mlir/docs/Canonicalization.md b/mlir/docs/Canonicalization.md index 686e500..2622c08 100644 --- a/mlir/docs/Canonicalization.md +++ b/mlir/docs/Canonicalization.md @@ -55,7 +55,7 @@ Some important things to think about w.r.t. canonicalization patterns: * It is always good to eliminate operations entirely when possible, e.g. by folding known identities (like "x + 0 = x"). -* Pattens with expensive running time (i.e. have O(n) complexity) or +* Patterns with expensive running time (i.e. have O(n) complexity) or complicated cost models don't belong to canonicalization: since the algorithm is executed iteratively until fixed-point we want patterns that execute quickly (in particular their matching phase). diff --git a/mlir/docs/Rationale/RationaleLinalgDialect.md b/mlir/docs/Rationale/RationaleLinalgDialect.md index 8975b0a..fbe2217 100644 --- a/mlir/docs/Rationale/RationaleLinalgDialect.md +++ b/mlir/docs/Rationale/RationaleLinalgDialect.md @@ -506,6 +506,72 @@ potential by introducing lower-level IR ops and *smaller* Linalg ops. This gradually reduces the potential, all the way to Loops + VectorOps and LLVMIR. +### Interchangeability of Forms<a name="forms"></a> + +#### The Linalg Forms + +The core Linalg operation set has four forms: +* **Generic:** Represented by `linalg.generic` and can encode all perfectly-nested +loop operations. +* **Category:** For example, `linalg.contract` and `linalg.elementwise`, that encode +higher-level semantics of a `linalg.generic` while still representing multiple _named_ +operations via attributes and syntax. In the future, other category operations are +planned (e.g.: `linalg.convolution` and `linalg.pooling`). +* **Named:** For example, `linalg.matmul`, `linalg.add`, etc. All _named_ forms that +can be converted to either a single _category_ or _generic_ forms, ie. are _perfectly nested_. +* **Composite:** For example `linalg.softmax` and the `winograd` variations. These +operations are not perfectly nested, and are converted to a list of other operations +(of various dialects). + +The forms correlate in the following manner: +``` ++ generic + \__ + category + \__ + named ++ composite +``` + +The `category` and `named` forms are derived from `linalg.generic` and are *equivalent*. +It should always be possible to convert a `named` operation into a `category` and that +into a `generic` and back to `named`. However, it may not be possible to convert a +`generic` into a `named` if there is no such `named` form. + +`Composite` operations cannot be converted to the other three classes and forms a +sub-set on its own. But they can use other Linalg forms when expanding. There can be +a pattern-matching transform to detect a graph of operations and convert into a +`composite` operation. + +The various forms in the Linalg dialect are meant to facilitate +pattern matching (single operations or DAGs) and to be able to consider +different forms as *canonical* for different transforms. + +Linalg's various forms also carry information, and that +information should be preserved as much as possible during the progressive +lowering. A `matmul` operation is a special case of a `contract` operation, +which in turn is a special case of a `generic` operation. Transformations on +Linalg operations (in any form) should avoid breaking down into +loops + arithmetic if they can still be represented as a Linalg operation, +preferably in their original form. + +#### Canonical Forms<a name="canonical_forms"></a> + +With multiple (often exchangeable) forms, and with transformation simplicity +in mind, compilers should aim for reducing matching and replacing complexity +as much as possible. When matching a single operation with a complex pattern, +having all the information in a `generic` Op is useful to iteratively match +different patterns in turn. However, when assembling a DAG of operations to +form a pattern, it's much simpler to match against named operations (like +`max` + `div` + `reduce` + `broadcast`) than their generic counterparts. + +This is where the interchangeability of forms comes in handy. Linalg has the +ability to specialize and generalize in order to convert the IR to a form that +is easier for a particular type of transform. With forms being semantically +equivalent, one can convert back-and-forth throughout the various transforms +to match the needs of each transform. For that particular transform, such +form can be considered _canonical_ and therefore "expected" for the pattern +to _match_. This reduces complexity of pattern matchers and simplifies compiler +pipelines. + ### Composable and Declarative Transformations<a name="declarative_transformations"></a> Complex and impactful transformations need not be hard to manipulate, write or maintain. Mixing XLA-style high-level op semantics knowledge with generic diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h index 2db1d84..fe42a20 100644 --- a/mlir/include/mlir-c/Rewrite.h +++ b/mlir/include/mlir-c/Rewrite.h @@ -352,7 +352,7 @@ typedef struct { /// Create a rewrite pattern that matches the operation /// with the given rootName, corresponding to mlir::OpRewritePattern. -MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate( +MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePatternCreate( MlirStringRef rootName, unsigned benefit, MlirContext context, MlirRewritePatternCallbacks callbacks, void *userData, size_t nGeneratedNames, MlirStringRef *generatedNames); diff --git a/mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h new file mode 100644 index 0000000..72ac247 --- /dev/null +++ b/mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h @@ -0,0 +1,54 @@ +//===- StridedMetadataRange.h - Strided metadata range analysis -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H +#define MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H + +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "mlir/Interfaces/InferStridedMetadataInterface.h" + +namespace mlir { +namespace dataflow { + +/// This lattice element represents the strided metadata of an SSA value. +class StridedMetadataRangeLattice : public Lattice<StridedMetadataRange> { +public: + using Lattice::Lattice; +}; + +/// Strided metadata range analysis determines the strided metadata ranges of +/// SSA values using operations that define `InferStridedMetadataInterface`. +/// +/// This analysis depends on DeadCodeAnalysis, SparseConstantPropagation, and +/// IntegerRangeAnalysis, and will be a silent no-op if the analyses are not +/// loaded in the same solver context. +class StridedMetadataRangeAnalysis + : public SparseForwardDataFlowAnalysis<StridedMetadataRangeLattice> { +public: + StridedMetadataRangeAnalysis(DataFlowSolver &solver, + int32_t indexBitwidth = 64); + + /// At an entry point, we cannot reason about strided metadata ranges unless + /// the type also encodes the data. For example, a memref with static layout. + void setToEntryState(StridedMetadataRangeLattice *lattice) override; + + /// Visit an operation. Invoke the transfer function on each operation that + /// implements `InferStridedMetadataInterface`. + LogicalResult + visitOperation(Operation *op, + ArrayRef<const StridedMetadataRangeLattice *> operands, + ArrayRef<StridedMetadataRangeLattice *> results) override; + +private: + /// Index bitwidth to use when operating with the int-ranges. + int32_t indexBitwidth = 64; +}; +} // namespace dataflow +} // end namespace mlir + +#endif // MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H diff --git a/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h new file mode 100644 index 0000000..91d3c92 --- /dev/null +++ b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h @@ -0,0 +1,27 @@ +//===- MathToXeVM.h - Utils for converting Math to XeVM -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_ +#define MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_ + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/XeVMDialect.h" +#include "mlir/IR/PatternMatch.h" +#include <memory> + +namespace mlir { +class Pass; + +#define GEN_PASS_DECL_CONVERTMATHTOXEVM +#include "mlir/Conversion/Passes.h.inc" + +/// Populate the given list with patterns that convert from Math to XeVM calls. +void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns, + bool convertArith); +} // namespace mlir + +#endif // MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_ diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index da061b2..40d866e 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -49,6 +49,7 @@ #include "mlir/Conversion/MathToLibm/MathToLibm.h" #include "mlir/Conversion/MathToROCDL/MathToROCDL.h" #include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h" +#include "mlir/Conversion/MathToXeVM/MathToXeVM.h" #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 3c18ecc..25e9d34 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -797,6 +797,31 @@ def ConvertMathToSPIRVPass : Pass<"convert-math-to-spirv"> { } //===----------------------------------------------------------------------===// +// MathToXeVM +//===----------------------------------------------------------------------===// + +def ConvertMathToXeVM : Pass<"convert-math-to-xevm", "ModuleOp"> { + let summary = + "Convert (fast) math operations to native XeVM/SPIRV equivalents"; + let description = [{ + This pass converts supported math ops marked with the `afn` fastmath flag + to function calls for OpenCL `native_` math intrinsics: These intrinsics + are typically mapped directly to native device instructions, often resulting + in better performance. However, the precision/error of these intrinsics + are implementation-defined, and thus math ops are only converted when they + have the `afn` fastmath flag enabled. + }]; + let options = [Option< + "convertArith", "convert-arith", "bool", /*default=*/"true", + "Convert supported Arith ops (e.g. arith.divf) as well.">]; + let dependentDialects = [ + "arith::ArithDialect", + "xevm::XeVMDialect", + "LLVM::LLVMDialect", + ]; +} + +//===----------------------------------------------------------------------===// // MathToEmitC //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td index 1236fed..cace63d 100644 --- a/mlir/include/mlir/Dialect/AMX/AMX.td +++ b/mlir/include/mlir/Dialect/AMX/AMX.td @@ -149,10 +149,13 @@ def TileZeroOp : AMX_Op<"tile_zero", [ let summary = "tile zero operation"; let description = [{ Zeroes the destination tile, with the shape defined by the 2-dim - vector type of the result. This is eventually lowered into the - "tilezero" instruction with the corresponding tile configuration. - With memory-effects, each "tilezero" operation serves as a compilation - hint to use a separate tile register. + vector type of the result. + + The operation is eventually lowered into the "tilezero" instruction + with the corresponding tile configuration. + + With the write memory effect, each `amx.tile_zero` operation serves as + a compilation hint to use a separate tile register. Example: @@ -184,25 +187,53 @@ def TileZeroOp : AMX_Op<"tile_zero", [ def TileLoadOp : AMX_Op<"tile_load", [ AMXIntrinsicOpInterface, - MemoryEffects<[MemWrite]> + MemoryEffects<[MemWrite]>, + AttrSizedOperandSegments ]> { let summary = "tile load operation"; let description = [{ - Loads a tile from memory defined by a base and indices, with the - shape defined by the 2-dim vector type of the result. This is - eventually lowered into the "tileloadd" instruction with the - corresponding tile configuration. With memory-effects, each "tileload" - operation serves as a compilation hint to use a separate tile register. + Loads a tile from memory defined by a `base` and `indices`, with the + shape defined by the 2-dim vector type of the result. + The tile's rows are populated by reading contiguous elements starting + at the `base`. For each tile row, the `base` is incremented by `stride` + number of elements. + + The tile is loaded using the following indexing scheme: + + ``` + for row in enumerate(tile_rows): + mem_row = base[i0, i1, ..., iN + row * stride] + for col in enumerate(tile_cols): + tile[row, col] = mem_row[col] + ``` + + If the `stride` is not provided, then the `base` buffer must be at least + 2-dimensional, and the `stride` is automatically inferred and corresponds + to the stride of the buffer's second innermost dimension. + + The operation is eventually lowered into the "tileloadd" instruction + with the corresponding tile configuration. + + With the write memory effect, each `amx.tile_load` operation serves as + a compilation hint to use a separate tile register. Example: ```mlir + // Tile load from a 2-D memref with implicit stride. %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !amx.tile<16x64xi8> + + // Tile load from a 1-D memref with explicit stride. + %0 = amx.tile_load %arg0[%c0], %stride : memref<?xi8> into !amx.tile<16x64xi8> ``` }]; let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base, - Variadic<Index>:$indices); + Variadic<Index>:$indices, + Optional<Index>:$stride); let results = (outs AnyAMXTile:$res); + let builders = [ + OpBuilder<(ins "Type":$res, "Value":$base, "ValueRange":$indices)> + ]; let extraClassDeclaration = [{ MemRefType getMemRefType() { return ::llvm::cast<MemRefType>(getBase().getType()); @@ -219,30 +250,56 @@ def TileLoadOp : AMX_Op<"tile_load", [ const ::mlir::LLVMTypeConverter &typeConverter, ::mlir::RewriterBase &rewriter); }]; - let assemblyFormat = "$base `[` $indices `]` attr-dict `:` " - "type($base) `into` qualified(type($res))"; + let assemblyFormat = "$base `[` $indices `]` (`,` $stride^ )? attr-dict" + "`:` type($base) `into` qualified(type($res))"; let hasVerifier = 1; } def TileStoreOp : AMX_Op<"tile_store", [ - AMXIntrinsicOpInterface + AMXIntrinsicOpInterface, + AttrSizedOperandSegments ]> { let summary = "tile store operation"; let description = [{ - Stores a tile to memory defined by a base and indices, with the - shape defined by the 2-dim vector type of the value. This is - eventually lowered into the "tilestored" instruction with the - corresponding tile configuration. + Stores a tile to memory defined by a `base` and `indices`, with the + shape defined by the 2-dim vector type of the value. + The tile's rows are written contiguously to the buffer starting at + the `base`. For each tile row, the `base` is incremented by `stride` + number of elements. + + The tile is stored using the following indexing scheme: + + ``` + for row in enumerate(tile_rows): + mem_row = base[i0, i1, ..., iN + row * stride] + for col in enumerate(tile_cols): + mem_row[col] = tile[row, col] + ``` + + If the `stride` is not provided, then the `base` buffer must be at least + 2-dimensional, and the `stride` is automatically inferred and corresponds + to the stride of the buffer's second innermost dimension. + + The operation is eventually lowered into the "tilestored" instruction + with the corresponding tile configuration. Example: ```mlir + // Tile store to a 2-D memref with implicit stride. amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !amx.tile<16x64xi8> + + // Tile store to a 1-D memref with explicit stride. + amx.tile_store %arg1[%c0], %0, %stride : memref<?xi8>, !amx.tile<16x64xi8> ``` }]; let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base, Variadic<Index>:$indices, - AnyAMXTile:$val); + AnyAMXTile:$val, + Optional<Index>:$stride); + let builders = [ + OpBuilder<(ins "Value":$base, "ValueRange":$indices, "Value":$val)> + ]; let extraClassDeclaration = [{ MemRefType getMemRefType() { return ::llvm::cast<MemRefType>(getBase().getType()); @@ -259,8 +316,8 @@ def TileStoreOp : AMX_Op<"tile_store", [ const ::mlir::LLVMTypeConverter &typeConverter, ::mlir::RewriterBase &rewriter); }]; - let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` " - "type($base) `,` qualified(type($val))"; + let assemblyFormat = "$base `[` $indices `]` `,` $val (`,` $stride^ )?" + "attr-dict `:` type($base) `,` qualified(type($val))"; let hasVerifier = 1; } @@ -276,8 +333,10 @@ def TileMulFOp : AMX_Op<"tile_mulf", [Pure, let description = [{ Multiplies a "m x k" tile with a "k x n" tile and accumulates the results into a "m x n" destination tile. Supports "f32 <- bf16 x bf16" (with - pairs of "bf16"). The operation is eventually lowered into the - "tdpbf16ps" instruction with the corresponding tile configuration. + pairs of "bf16"). + + The operation is eventually lowered into the "tdpbf16ps" instruction with + the corresponding tile configuration. Example: @@ -330,9 +389,11 @@ def TileMulIOp : AMX_Op<"tile_muli", [Pure, into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8" combinations (4 bytes packed into dwords in the columns of both the source operand tiles; the zero or sign extension is specified with - the attributes and default to sign extended). The operation is eventually - lowered into one of the "tdpbssd", "tdpbsud", "tdpbusd", or "tdpbuud" - instructions with the corresponding tile configuration. + the attributes and default to sign extended). + + The operation is eventually lowered into one of the "tdpbssd", + "tdpbsud", "tdpbusd", or "tdpbuud" instructions with the corresponding + tile configuration. Example: diff --git a/mlir/include/mlir/Dialect/Affine/LoopUtils.h b/mlir/include/mlir/Dialect/Affine/LoopUtils.h index 9b59af7..830c394 100644 --- a/mlir/include/mlir/Dialect/Affine/LoopUtils.h +++ b/mlir/include/mlir/Dialect/Affine/LoopUtils.h @@ -61,7 +61,7 @@ LogicalResult loopUnrollUpToFactor(AffineForOp forOp, uint64_t unrollFactor); /// Returns true if `loops` is a perfectly nested loop nest, where loops appear /// in it from outermost to innermost. -bool LLVM_ATTRIBUTE_UNUSED isPerfectlyNested(ArrayRef<AffineForOp> loops); +[[maybe_unused]] bool isPerfectlyNested(ArrayRef<AffineForOp> loops); /// Get perfectly nested sequence of loops starting at root of loop nest /// (the first op being another AffineFor, and the second op - a terminator). diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt index 8d9474b..c301e0b 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt @@ -48,6 +48,10 @@ mlir_tablegen(LLVMIntrinsicFromLLVMIRConversions.inc -gen-intr-from-llvmir-conve mlir_tablegen(LLVMConvertibleLLVMIRIntrinsics.inc -gen-convertible-llvmir-intrinsics) add_mlir_dialect_tablegen_target(MLIRLLVMIntrinsicConversionsIncGen) +set(LLVM_TARGET_DEFINITIONS LLVMDialectBytecode.td) +mlir_tablegen(LLVMDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="LLVM") +add_public_tablegen_target(MLIRLLVMDialectBytecodeIncGen) + set(LLVM_TARGET_DEFINITIONS BasicPtxBuilderInterface.td) mlir_tablegen(BasicPtxBuilderInterface.h.inc -gen-op-interface-decls) mlir_tablegen(BasicPtxBuilderInterface.cpp.inc -gen-op-interface-defs) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialectBytecode.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialectBytecode.td new file mode 100644 index 0000000..e7b202c --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialectBytecode.td @@ -0,0 +1,353 @@ +//===-- LLVMDialectBytecode.td - LLVM bytecode defs --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the LLVM bytecode reader/writer definition file. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_DIALECT_BYTECODE +#define LLVM_DIALECT_BYTECODE + +include "mlir/IR/BytecodeBase.td" + +//===----------------------------------------------------------------------===// +// Bytecode classes for attributes and types. +//===----------------------------------------------------------------------===// + +def String : + WithParser <"succeeded($_reader.readString($_var))", + WithBuilder<"$_args", + WithPrinter<"$_writer.writeOwnedString($_getter)", + WithType <"StringRef">>>>; + +class Attr<string type> : WithType<type, Attribute>; + +class OptionalAttribute<string type> : + WithParser <"succeeded($_reader.readOptionalAttribute($_var))", + WithPrinter<"$_writer.writeOptionalAttribute($_getter)", + WithType<type, Attribute>>>; + +class OptionalInt<string type> : + WithParser <"succeeded(readOptionalInt($_reader, $_var))", + WithPrinter<"writeOptionalInt($_writer, $_getter)", + WithType<"std::optional<" # type # ">", VarInt>>>; + +class OptionalArrayRef<string eltType> : + WithParser <"succeeded(readOptionalArrayRef<" + # eltType # ">($_reader, $_var))", + WithPrinter<"writeOptionalArrayRef<" + # eltType # ">($_writer, $_getter)", + WithType<"SmallVector<" + # eltType # ">", Attribute>>>; + +class EnumClassFlag<string flag, string getter> : + WithParser<"succeeded($_reader.readVarInt($_var))", + WithBuilder<"(" # flag # ")$_args", + WithPrinter<"$_writer.writeVarInt((uint64_t)$_name." # getter # ")", + WithType<"uint64_t", VarInt>>>>; + +//===----------------------------------------------------------------------===// +// General notes +// - For each attribute or type entry, the argument names should match +// LLVMAttrDefs.td +// - The mnemonics are either LLVM or builtin MLIR attributes and types, but +// regular C++ types are also allowed to match builders and parsers. +// - DIScopeAttr and DINodeAttr are empty base classes, custom encoding not +// needed. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// DIBasicTypeAttr +//===----------------------------------------------------------------------===// + +def DIBasicTypeAttr : DialectAttribute<(attr + VarInt:$tag, + String:$name, + VarInt:$sizeInBits, + VarInt:$encoding +)>; + +//===----------------------------------------------------------------------===// +// DIExpressionAttr, DIExpressionElemAttr +//===----------------------------------------------------------------------===// + +def DIExpressionElemAttr : DialectAttribute<(attr + VarInt:$opcode, + OptionalArrayRef<"uint64_t">:$arguments +)>; + +def DIExpressionAttr : DialectAttribute<(attr + OptionalArrayRef<"DIExpressionElemAttr">:$operations +)>; + +//===----------------------------------------------------------------------===// +// DIFileAttr +//===----------------------------------------------------------------------===// + +def DIFileAttr : DialectAttribute<(attr + String:$name, + String:$directory +)>; + +//===----------------------------------------------------------------------===// +// DILocalVariableAttr +//===----------------------------------------------------------------------===// + +def DILocalVariableAttr : DialectAttribute<(attr + Attr<"DIScopeAttr">:$scope, + OptionalAttribute<"StringAttr">:$name, + OptionalAttribute<"DIFileAttr">:$file, + VarInt:$line, + VarInt:$arg, + VarInt:$alignInBits, + OptionalAttribute<"DITypeAttr">:$type, + EnumClassFlag<"DIFlags", "getFlags()">:$_rawflags, + LocalVar<"DIFlags", "(DIFlags)_rawflags">:$flags +)> { + // DILocalVariableAttr direct getter uses a `StringRef` for `name`. Since the + // more direct getter is prefered during bytecode reading, force the base one + // and prevent crashes for empty `StringAttr`. + let cBuilder = "$_resultType::get(context, $_args)"; +} + +//===----------------------------------------------------------------------===// +// DISubroutineTypeAttr +//===----------------------------------------------------------------------===// + +def DISubroutineTypeAttr : DialectAttribute<(attr + VarInt:$callingConvention, + OptionalArrayRef<"DITypeAttr">:$types +)>; + +//===----------------------------------------------------------------------===// +// DICompileUnitAttr +//===----------------------------------------------------------------------===// + +def DICompileUnitAttr : DialectAttribute<(attr + Attr<"DistinctAttr">:$id, + VarInt:$sourceLanguage, + Attr<"DIFileAttr">:$file, + OptionalAttribute<"StringAttr">:$producer, + Bool:$isOptimized, + EnumClassFlag<"DIEmissionKind", "getEmissionKind()">:$_rawEmissionKind, + LocalVar<"DIEmissionKind", "(DIEmissionKind)_rawEmissionKind">:$emissionKind, + EnumClassFlag<"DINameTableKind", "getNameTableKind()">:$_rawNameTableKind, + LocalVar<"DINameTableKind", + "(DINameTableKind)_rawNameTableKind">:$nameTableKind +)>; + +//===----------------------------------------------------------------------===// +// DISubprogramAttr +//===----------------------------------------------------------------------===// + +def DISubprogramAttr : DialectAttribute<(attr + OptionalAttribute<"DistinctAttr">:$recId, + Bool:$isRecSelf, + OptionalAttribute<"DistinctAttr">:$id, + OptionalAttribute<"DICompileUnitAttr">:$compileUnit, + OptionalAttribute<"DIScopeAttr">:$scope, + OptionalAttribute<"StringAttr">:$name, + OptionalAttribute<"StringAttr">:$linkageName, + OptionalAttribute<"DIFileAttr">:$file, + VarInt:$line, + VarInt:$scopeLine, + EnumClassFlag<"DISubprogramFlags", "getSubprogramFlags()">:$_rawflags, + LocalVar<"DISubprogramFlags", "(DISubprogramFlags)_rawflags">:$subprogramFlags, + OptionalAttribute<"DISubroutineTypeAttr">:$type, + OptionalArrayRef<"DINodeAttr">:$retainedNodes, + OptionalArrayRef<"DINodeAttr">:$annotations +)>; + +//===----------------------------------------------------------------------===// +// DICompositeTypeAttr +//===----------------------------------------------------------------------===// + +def DICompositeTypeAttr : DialectAttribute<(attr + OptionalAttribute<"DistinctAttr">:$recId, + Bool:$isRecSelf, + VarInt:$tag, + OptionalAttribute<"StringAttr">:$name, + OptionalAttribute<"DIFileAttr">:$file, + VarInt:$line, + OptionalAttribute<"DIScopeAttr">:$scope, + OptionalAttribute<"DITypeAttr">:$baseType, + EnumClassFlag<"DIFlags", "getFlags()">:$_rawflags, + LocalVar<"DIFlags", "(DIFlags)_rawflags">:$flags, + VarInt:$sizeInBits, + VarInt:$alignInBits, + OptionalAttribute<"DIExpressionAttr">:$dataLocation, + OptionalAttribute<"DIExpressionAttr">:$rank, + OptionalAttribute<"DIExpressionAttr">:$allocated, + OptionalAttribute<"DIExpressionAttr">:$associated, + OptionalArrayRef<"DINodeAttr">:$elements +)>; + +//===----------------------------------------------------------------------===// +// DIDerivedTypeAttr +//===----------------------------------------------------------------------===// + +def DIDerivedTypeAttr : DialectAttribute<(attr + VarInt:$tag, + OptionalAttribute<"StringAttr">:$name, + OptionalAttribute<"DITypeAttr">:$baseType, + VarInt:$sizeInBits, + VarInt:$alignInBits, + VarInt:$offsetInBits, + OptionalInt<"unsigned">:$dwarfAddressSpace, + OptionalAttribute<"DINodeAttr">:$extraData +)>; + +//===----------------------------------------------------------------------===// +// DIImportedEntityAttr +//===----------------------------------------------------------------------===// + +def DIImportedEntityAttr : DialectAttribute<(attr + VarInt:$tag, + Attr<"DIScopeAttr">:$scope, + Attr<"DINodeAttr">:$entity, + OptionalAttribute<"DIFileAttr">:$file, + VarInt:$line, + OptionalAttribute<"StringAttr">:$name, + OptionalArrayRef<"DINodeAttr">:$elements +)>; + +//===----------------------------------------------------------------------===// +// DIGlobalVariableAttr, DIGlobalVariableExpressionAttr +//===----------------------------------------------------------------------===// + +def DIGlobalVariableAttr : DialectAttribute<(attr + OptionalAttribute<"DIScopeAttr">:$scope, + OptionalAttribute<"StringAttr">:$name, + OptionalAttribute<"StringAttr">:$linkageName, + Attr<"DIFileAttr">:$file, + VarInt:$line, + Attr<"DITypeAttr">:$type, + Bool:$isLocalToUnit, + Bool:$isDefined, + VarInt:$alignInBits +)>; + +def DIGlobalVariableExpressionAttr : DialectAttribute<(attr + Attr<"DIGlobalVariableAttr">:$var, + OptionalAttribute<"DIExpressionAttr">:$expr +)>; + +//===----------------------------------------------------------------------===// +// DILabelAttr +//===----------------------------------------------------------------------===// + +def DILabelAttr : DialectAttribute<(attr + Attr<"DIScopeAttr">:$scope, + OptionalAttribute<"StringAttr">:$name, + OptionalAttribute<"DIFileAttr">:$file, + VarInt:$line +)> { + // DILabelAttr direct getter uses a `StringRef` for `name`. Since the + // more direct getter is prefered during bytecode reading, force the base one + // and prevent crashes for empty `StringAttr`. + let cBuilder = "$_resultType::get(context, $_args)"; +} + +//===----------------------------------------------------------------------===// +// DILexicalBlockAttr, DILexicalBlockFileAttr +//===----------------------------------------------------------------------===// + +def DILexicalBlockAttr : DialectAttribute<(attr + Attr<"DIScopeAttr">:$scope, + OptionalAttribute<"DIFileAttr">:$file, + VarInt:$line, + VarInt:$column +)>; + +def DILexicalBlockFileAttr : DialectAttribute<(attr + Attr<"DIScopeAttr">:$scope, + OptionalAttribute<"DIFileAttr">:$file, + VarInt:$discriminator +)>; + +//===----------------------------------------------------------------------===// +// DINamespaceAttr +//===----------------------------------------------------------------------===// + +def DINamespaceAttr : DialectAttribute<(attr + OptionalAttribute<"StringAttr">:$name, + OptionalAttribute<"DIScopeAttr">:$scope, + Bool:$exportSymbols +)>; + +//===----------------------------------------------------------------------===// +// DISubrangeAttr +//===----------------------------------------------------------------------===// + +def DISubrangeAttr : DialectAttribute<(attr + OptionalAttribute<"Attribute">:$count, + OptionalAttribute<"Attribute">:$lowerBound, + OptionalAttribute<"Attribute">:$upperBound, + OptionalAttribute<"Attribute">:$stride +)>; + +//===----------------------------------------------------------------------===// +// LoopAnnotationAttr +//===----------------------------------------------------------------------===// + +def LoopAnnotationAttr : DialectAttribute<(attr + OptionalAttribute<"BoolAttr">:$disableNonforced, + OptionalAttribute<"LoopVectorizeAttr">:$vectorize, + OptionalAttribute<"LoopInterleaveAttr">:$interleave, + OptionalAttribute<"LoopUnrollAttr">:$unroll, + OptionalAttribute<"LoopUnrollAndJamAttr">:$unrollAndJam, + OptionalAttribute<"LoopLICMAttr">:$licm, + OptionalAttribute<"LoopDistributeAttr">:$distribute, + OptionalAttribute<"LoopPipelineAttr">:$pipeline, + OptionalAttribute<"LoopPeeledAttr">:$peeled, + OptionalAttribute<"LoopUnswitchAttr">:$unswitch, + OptionalAttribute<"BoolAttr">:$mustProgress, + OptionalAttribute<"BoolAttr">:$isVectorized, + OptionalAttribute<"FusedLoc">:$startLoc, + OptionalAttribute<"FusedLoc">:$endLoc, + OptionalArrayRef<"AccessGroupAttr">:$parallelAccesses +)>; + +//===----------------------------------------------------------------------===// +// Attributes & Types with custom bytecode handling. +//===----------------------------------------------------------------------===// + +// All the attributes with custom bytecode handling. +def LLVMDialectAttributes : DialectAttributes<"LLVM"> { + let elems = [ + DIBasicTypeAttr, + DICompileUnitAttr, + DICompositeTypeAttr, + DIDerivedTypeAttr, + DIExpressionElemAttr, + DIExpressionAttr, + DIFileAttr, + DIGlobalVariableAttr, + DIGlobalVariableExpressionAttr, + DIImportedEntityAttr, + DILabelAttr, + DILexicalBlockAttr, + DILexicalBlockFileAttr, + DILocalVariableAttr, + DINamespaceAttr, + DISubprogramAttr, + DISubrangeAttr, + DISubroutineTypeAttr, + LoopAnnotationAttr + // Referenced attributes currently missing support: + // AccessGroupAttr, LoopVectorizeAttr, LoopInterleaveAttr, LoopUnrollAttr, + // LoopUnrollAndJamAttr, LoopLICMAttr, LoopDistributeAttr, LoopPipelineAttr, + // LoopPeeledAttr, LoopUnswitchAttr + ]; +} + +def LLVMDialectTypes : DialectTypes<"LLVM"> { + let elems = []; +} + +#endif // LLVM_DIALECT_BYTECODE diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 9753dca..d0811a2 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -973,6 +973,7 @@ def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector", custom<ShuffleType>(ref(type($v1)), type($res), ref($mask)) }]; + let hasFolder = 1; let hasVerifier = 1; string llvmInstName = "ShuffleVector"; @@ -1985,6 +1986,7 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [ OptionalAttr<StrAttr>:$instrument_function_exit, OptionalAttr<UnitAttr>:$no_inline, OptionalAttr<UnitAttr>:$always_inline, + OptionalAttr<UnitAttr>:$inline_hint, OptionalAttr<UnitAttr>:$no_unwind, OptionalAttr<UnitAttr>:$will_return, OptionalAttr<UnitAttr>:$optimize_none, @@ -2037,6 +2039,9 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [ /// Returns true if the `always_inline` attribute is set, false otherwise. bool isAlwaysInline() { return bool(getAlwaysInlineAttr()); } + /// Returns true if the `inline_hint` attribute is set, false otherwise. + bool isInlineHint() { return bool(getInlineHintAttr()); } + /// Returns true if the `optimize_none` attribute is set, false otherwise. bool isOptimizeNone() { return bool(getOptimizeNoneAttr()); } }]; diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 89fbeb7..d959464 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -263,6 +263,7 @@ class NVVM_PureSpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)"; let llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda; let mlirBuilder = baseMlirBuilder # importRangeRetAttrCode # baseMlirBuilderCoda; + let hasVerifier = 1; // Backwards-compatibility builder for an unspecified range. let builders = [ @@ -279,6 +280,11 @@ class NVVM_PureSpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = SetIntRangeFn setResultRanges) { nvvmInferResultRanges(getOperation(), getResult(), argRanges, setResultRanges); } + + // Verify the range attribute satisfies LLVM ConstantRange constructor requirements. + ::llvm::LogicalResult $cppClass::verify() { + return verifyConstantRangeAttr(getOperation(), getRange()); + } }]; } @@ -1655,6 +1661,40 @@ def NVVM_ConvertFloatToTF32Op : NVVM_Op<"convert.float.to.tf32"> { }]; } +def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> { + let summary = "Convert a pair of float inputs to f4x2"; + let description = [{ + This Op converts each of the given float inputs to the specified fp4 type. + The result `dst` is returned as an i8 type where the converted values are + packed such that the value converted from `a` is stored in the upper 4 bits + of `dst` and the value converted from `b` is stored in the lower 4 bits of + `dst`. + The `relu` attribute, when set, lowers to the '.relu' variant of + the cvt instruction. + + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) + }]; + + let results = (outs I8:$dst); + let arguments = (ins F32:$a, F32:$b, + DefaultValuedAttr<BoolAttr, "false">:$relu, + TypeAttr:$dstTy); + let assemblyFormat = "$a `,` $b attr-dict `:` type($dst) `(` $dstTy `)`"; + let hasVerifier = 1; + + let extraClassDeclaration = [{ + static mlir::NVVM::IDArgPair + getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op, + LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder); + }]; + + string llvmBuilder = [{ + auto [intId, args] = NVVM::ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(op, moduleTranslation, builder); + llvm::Value *packedI16 = createIntrinsicCall(builder, intId, args); + $dst = builder.CreateTruncOrBitCast(packedI16, llvm::Type::getInt8Ty(builder.getContext())); + }]; +} + def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> { let summary = "Convert a pair of float inputs to f6x2"; let description = [{ diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 6925cec..68f31e6 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -412,6 +412,32 @@ def ROCDL_WaitExpcntOp: ROCDL_ConcreteNonMemIntrOp<"s.wait.expcnt", [], 0, [0], let assemblyFormat = "$count attr-dict"; } +def ROCDL_WaitAsynccntOp: ROCDL_ConcreteNonMemIntrOp<"s.wait.asynccnt", [], 0, [0], ["count"]>, + Arguments<(ins I16Attr:$count)> { + let summary = "Wait until ASYNCCNT is less than or equal to `count`"; + let description = [{ + Wait for the counter specified to be less-than or equal-to the `count` + before continuing. + + Available on gfx1250+. + }]; + let results = (outs); + let assemblyFormat = "$count attr-dict"; +} + +def ROCDL_WaitTensorcntOp: ROCDL_ConcreteNonMemIntrOp<"s.wait.tensorcnt", [], 0, [0], ["count"]>, + Arguments<(ins I16Attr:$count)> { + let summary = "Wait until TENSORCNT is less than or equal to `count`"; + let description = [{ + Wait for the counter specified to be less-than or equal-to the `count` + before continuing. + + Available on gfx1250+. + }]; + let results = (outs); + let assemblyFormat = "$count attr-dict"; +} + def ROCDL_SetPrioOp : ROCDL_ConcreteNonMemIntrOp<"s.setprio", [], 0, [0], ["priority"]>, Arguments<(ins I16Attr:$priority)> { let assemblyFormat = "$priority attr-dict"; diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 0d6ebc0..8728e66 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -395,31 +395,73 @@ def EliminateLinalgOpAnchoredEmptyTensorsOp //===----------------------------------------------------------------------===// def FuseOp : Op<Transform_Dialect, "structured.fuse", - [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface, - DeclareOpInterfaceMethods<TransformOpInterface>, - ReportTrackingListenerFailuresOpTrait]> { + [AttrSizedOperandSegments, + DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, + TransformOpInterface, ReportTrackingListenerFailuresOpTrait]> { let description = [{ Tiles the operations pointed to by the target handle and fuses their - producers greedily using the options provided as attributes. + producers greedily using the options provided as attributes. Tile sizes + and loop interchange permutation can be provided as either static + attributes or dynamic values (transform parameters or payload handles). If `apply_cleanup` is true then slice canonicalization is applied between - fusion steps. + fusion steps. If `use_forall` is true then tiling method generates a + `scf.forall` loop instead of `scf.for` loops. }]; let arguments = (ins TransformHandleTypeInterface:$target, - DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes, - DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange, - DefaultValuedAttr<BoolAttr, "false">:$apply_cleanup); + Variadic<TransformAnyParamTypeOrAnyHandle> : $tile_sizes, + Variadic<TransformAnyParamTypeOrAnyHandle> : $tile_interchange, + DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_sizes, + DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_interchange, + UnitAttr:$apply_cleanup, + UnitAttr:$use_forall); let results = (outs TransformHandleTypeInterface:$transformed, Variadic<TransformHandleTypeInterface>:$loops); + let builders = [ + OpBuilder<(ins "TypeRange":$loopTypes, + "Value":$target, + "ArrayRef<int64_t>":$staticTileSizes, + "ArrayRef<int64_t>":$staticTileInterchange, + CArg<"bool", "false">:$applyCleanup, + CArg<"bool", "false">:$useForall)>, + OpBuilder<(ins "TypeRange":$loopTypes, + "Value":$target, + "ArrayRef<OpFoldResult>":$mixedTileSizes, + "ArrayRef<OpFoldResult>":$mixedTileInterchange, + CArg<"bool", "false">:$applyCleanup, + CArg<"bool", "false">:$useForall)>, + OpBuilder<(ins "Value":$target, + "ArrayRef<int64_t>":$staticTileSizes, + "ArrayRef<int64_t>":$staticTileInterchange, + CArg<"bool", "false">:$applyCleanup, + CArg<"bool", "false">:$useForall)>, + OpBuilder<(ins "Value":$target, + "ArrayRef<OpFoldResult>":$mixedTileSizes, + "ArrayRef<OpFoldResult>":$mixedTileInterchange, + CArg<"bool", "false">:$applyCleanup, + CArg<"bool", "false">:$useForall)>, + ]; let assemblyFormat = [{ - $target ($tile_sizes^)? (`interchange` $tile_interchange^)? - (`apply_cleanup` `=` $apply_cleanup^)? attr-dict - `:` functional-type(operands, results) + $target oilist( + `tile_sizes` custom<DynamicIndexList>($tile_sizes, $static_tile_sizes) | + `interchange` custom<DynamicIndexList>($tile_interchange, $static_tile_interchange) + ) + attr-dict `:` functional-type(operands, results) }]; let hasVerifier = 1; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure apply( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::transform::TransformResults &transformResults, + ::mlir::transform::TransformState &state); + + ::mlir::SmallVector<::mlir::OpFoldResult> getMixedTileSizes(); + ::mlir::SmallVector<::mlir::OpFoldResult> getMixedTileInterchange(); + }]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 7266687..ae7a085 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1650,8 +1650,12 @@ protected: /// Rewrites a linalg::PackOp into a sequence of: /// * tensor::PadOp + linalg::TransposeOp + tensor::EmptyOp + /// tensor::InsertSliceOp ops. +/// (InsertSliceOp is rank-expanding). /// -/// Requires that all the outer dims of the input linalg::PackOp are 1. +/// Requires that all the tiled-outer-dims of the input linalg::PackOp are 1. +/// Note that this constraint means that effectively exactly one tile is packed. +/// +/// In addition, assumes that the un-tiled-outer-dims are not permuted. /// /// Before: /// ``` @@ -1687,10 +1691,13 @@ struct DecomposeOuterUnitDimsPackOpPattern PatternRewriter &rewriter) const override; }; -/// Rewrites a linalg::UnPackOp into a sequence of rank-reduced +/// Rewrites a linalg::UnPackOp into a sequence of: /// * tensor::ExtractSliceOp + linalg::TransposeOp + tensor::InsertSliceOp +/// (ExtractSliceOp is rank-reducing). /// -/// Requires that all the tiled outer dims of the input linalg::PackOp are 1. +/// Requires that all the tiled-outer-dims of the input linalg::UnPackOp are 1. +/// Note that this constraint means that effectively exactly one tile is +/// unpacked. /// /// Before: /// ``` diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h index 30f33ed..69447f7 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h @@ -17,6 +17,7 @@ #include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/InferIntRangeInterface.h" +#include "mlir/Interfaces/InferStridedMetadataInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/MemOpInterfaces.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index 40b7d7e..b39207f 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -14,6 +14,7 @@ include "mlir/Dialect/MemRef/IR/MemRefBase.td" include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferIntRangeInterface.td" +include "mlir/Interfaces/InferStridedMetadataInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/MemOpInterfaces.td" include "mlir/Interfaces/MemorySlotInterfaces.td" @@ -184,6 +185,7 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [ def DistinctObjectsOp : MemRef_Op<"distinct_objects", [ Pure, + DistinctObjectsTrait, DeclareOpInterfaceMethods<InferTypeOpInterface> // ViewLikeOpInterface TODO: ViewLikeOpInterface only supports a single argument ]> { @@ -2084,6 +2086,7 @@ def MemRef_StoreOp : MemRef_Op<"store", def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>, + DeclareOpInterfaceMethods<InferStridedMetadataOpInterface>, DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>, DeclareOpInterfaceMethods<ViewLikeOpInterface>, AttrSizedOperandSegments, diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h index 8f87235..b8aa497 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h @@ -183,6 +183,10 @@ static constexpr StringLiteral getRoutineInfoAttrName() { return StringLiteral("acc.routine_info"); } +static constexpr StringLiteral getVarNameAttrName() { + return VarNameAttr::name; +} + static constexpr StringLiteral getCombinedConstructsAttrName() { return CombinedConstructsTypeAttr::name; } diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index 77e833f..1eaa21b4 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -415,6 +415,13 @@ def OpenACC_ConstructResource : Resource<"::mlir::acc::ConstructResource">; // Define a resource for the OpenACC current device setting. def OpenACC_CurrentDeviceIdResource : Resource<"::mlir::acc::CurrentDeviceIdResource">; +// Attribute for saving variable names - this can be attached to non-acc-dialect +// operations in order to ensure the name is preserved. +def OpenACC_VarNameAttr : OpenACC_Attr<"VarName", "var_name"> { + let parameters = (ins StringRefParameter<"">:$name); + let assemblyFormat = "`<` $name `>`"; +} + // Used for data specification in data clauses (2.7.1). // Either (or both) extent and upperbound must be specified. def OpenACC_DataBoundsOp : OpenACC_Op<"bounds", @@ -1316,6 +1323,24 @@ def OpenACC_PrivateRecipeOp }]; let hasRegionVerifier = 1; + + let extraClassDeclaration = [{ + /// Creates a PrivateRecipeOp and populates its regions based on the + /// variable type as long as the type implements MappableType or + /// PointerLikeType interface. If a type implements both, the MappableType + /// API will be preferred. Returns std::nullopt if the recipe cannot be + /// created or populated. The builder's current insertion point will be used + /// and it must be a valid place for this operation to be inserted. The + /// `recipeName` must be a unique name to prevent "redefinition of symbol" + /// IR errors. + static std::optional<PrivateRecipeOp> createAndPopulate( + ::mlir::OpBuilder &builder, + ::mlir::Location loc, + ::llvm::StringRef recipeName, + ::mlir::Type varType, + ::llvm::StringRef varName = "", + ::mlir::ValueRange bounds = {}); + }]; } //===----------------------------------------------------------------------===// @@ -1410,6 +1435,24 @@ def OpenACC_FirstprivateRecipeOp }]; let hasRegionVerifier = 1; + + let extraClassDeclaration = [{ + /// Creates a FirstprivateRecipeOp and populates its regions based on the + /// variable type as long as the type implements MappableType or + /// PointerLikeType interface. If a type implements both, the MappableType + /// API will be preferred. Returns std::nullopt if the recipe cannot be + /// created or populated. The builder's current insertion point will be used + /// and it must be a valid place for this operation to be inserted. The + /// `recipeName` must be a unique name to prevent "redefinition of symbol" + /// IR errors. + static std::optional<FirstprivateRecipeOp> createAndPopulate( + ::mlir::OpBuilder &builder, + ::mlir::Location loc, + ::llvm::StringRef recipeName, + ::mlir::Type varType, + ::llvm::StringRef varName = "", + ::mlir::ValueRange bounds = {}); + }]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td index 0d16255..93e9e3d 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td @@ -73,17 +73,31 @@ def OpenACC_PointerLikeTypeInterface : TypeInterface<"PointerLikeType"> { InterfaceMethod< /*description=*/[{ Generates allocation operations for the pointer-like type. It will create - an allocate that produces memory space for an instance of the current type. + an allocate operation that produces memory space for an instance of the + current type. The `varName` parameter is optional and can be used to provide a name - for the allocated variable. If the current type is represented - in a way that it does not capture the pointee type, `varType` must be - passed in to provide the necessary type information. + for the allocated variable. When provided, it must be used by the + implementation; and if the implementing dialect does not have its own + way to save it, the discardable `acc.var_name` attribute from the acc + dialect will be used. + + If the current type is represented in a way that it does not capture + the pointee type, `varType` must be passed in to provide the necessary + type information. The `originalVar` parameter is optional but enables support for dynamic types (e.g., dynamic memrefs). When provided, implementations can extract runtime dimension information from the original variable to create - allocations with matching dynamic sizes. + allocations with matching dynamic sizes. When generating recipe bodies, + `originalVar` should be the block argument representing the original + variable in the recipe region. + + The `needsFree` output parameter indicates whether the allocated memory + requires explicit deallocation. Implementations should set this to true + for heap allocations that need a matching deallocation operation (e.g., + alloc) and false for stack-based allocations (e.g., alloca). During + recipe generation, this determines whether a destroy region is created. Returns a Value representing the result of the allocation. If no value is returned, it means the allocation was not successfully generated. @@ -94,7 +108,8 @@ def OpenACC_PointerLikeTypeInterface : TypeInterface<"PointerLikeType"> { "::mlir::Location":$loc, "::llvm::StringRef":$varName, "::mlir::Type":$varType, - "::mlir::Value":$originalVar), + "::mlir::Value":$originalVar, + "bool &":$needsFree), /*methodBody=*/"", /*defaultImplementation=*/[{ return {}; @@ -102,23 +117,34 @@ def OpenACC_PointerLikeTypeInterface : TypeInterface<"PointerLikeType"> { >, InterfaceMethod< /*description=*/[{ - Generates deallocation operations for the pointer-like type. It deallocates - the instance provided. + Generates deallocation operations for the pointer-like type. - The `varPtr` parameter is required and must represent an instance that was - previously allocated. If the current type is represented in a way that it - does not capture the pointee type, `varType` must be passed in to provide - the necessary type information. Nothing is generated in case the allocate - is `alloca`-like. + The `varToFree` parameter is required and must represent an instance + that was previously allocated. When generating recipe bodies, this + should be the block argument representing the private variable in the + destroy region. + + The `allocRes` parameter is optional and provides the result of the + corresponding allocation from the init region. This allows implementations + to inspect the allocation operation to determine the appropriate + deallocation strategy. This is necessary because in recipe generation, + the allocation and deallocation occur in separate regions. Dialects that + use only one allocation type or can determine deallocation from type + information alone may ignore this parameter. + + The `varType` parameter must be provided if the current type does not + capture the pointee type information. No deallocation is generated for + stack-based allocations (e.g., alloca). - Returns true if deallocation was successfully generated or successfully - deemed as not needed to be generated, false otherwise. + Returns true if deallocation was successfully generated or determined to + be unnecessary, false otherwise. }], /*retTy=*/"bool", /*methodName=*/"genFree", /*args=*/(ins "::mlir::OpBuilder &":$builder, "::mlir::Location":$loc, - "::mlir::TypedValue<::mlir::acc::PointerLikeType>":$varPtr, + "::mlir::TypedValue<::mlir::acc::PointerLikeType>":$varToFree, + "::mlir::Value":$allocRes, "::mlir::Type":$varType), /*methodBody=*/"", /*defaultImplementation=*/[{ @@ -274,6 +300,14 @@ def OpenACC_MappableTypeInterface : TypeInterface<"MappableType"> { The `initVal` can be empty - it is primarily needed for reductions to ensure the variable is also initialized with appropriate value. + The `needsDestroy` out-parameter is set by implementations to indicate + that destruction code must be generated after the returned private + variable usages, typically in the destroy region of recipe operations + (for example, when heap allocations or temporaries requiring cleanup + are created during initialization). When `needsDestroy` is set, callers + should invoke `generatePrivateDestroy` in the recipe's destroy region + with the privatized value returned by this method. + If the return value is empty, it means that recipe body was not successfully generated. }], @@ -284,12 +318,38 @@ def OpenACC_MappableTypeInterface : TypeInterface<"MappableType"> { "::mlir::TypedValue<::mlir::acc::MappableType>":$var, "::llvm::StringRef":$varName, "::mlir::ValueRange":$extents, - "::mlir::Value":$initVal), + "::mlir::Value":$initVal, + "bool &":$needsDestroy), /*methodBody=*/"", /*defaultImplementation=*/[{ return {}; }] >, + InterfaceMethod< + /*description=*/[{ + Generates destruction operations for a privatized value previously + produced by `generatePrivateInit`. This is typically inserted in a + recipe's destroy region, after all uses of the privatized value. + + The `privatized` value is the SSA value yielded by the init region + (and passed as the privatized argument to the destroy region). + Implementations should free heap-allocated storage or perform any + cleanup required for the given type. If no destruction is required, + this function should be a no-op and return `true`. + + Returns true if destruction was successfully generated or deemed not + necessary, false otherwise. + }], + /*retTy=*/"bool", + /*methodName=*/"generatePrivateDestroy", + /*args=*/(ins "::mlir::OpBuilder &":$builder, + "::mlir::Location":$loc, + "::mlir::Value":$privatized), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return true; + }] + >, ]; } diff --git a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td index 29b384f..b9d7163 100644 --- a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td +++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td @@ -174,7 +174,7 @@ def Shard_NeighborsLinearIndicesOp : Shard_Op<"neighbors_linear_indices", [ ``` The above returns two indices, `633` and `693`, which correspond to the index of the previous process `(1, 1, 3)`, and the next process - `(1, 3, 3) along the split axis `1`. + `(1, 3, 3)` along the split axis `1`. A negative value is returned if there is no neighbor in the respective direction along the given `split_axes`. diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h index 10491f6..4ecf03c 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h @@ -50,28 +50,63 @@ TargetEnvAttr getDefaultTargetEnv(MLIRContext *context); /// returned by getDefaultTargetEnv() if not provided. TargetEnvAttr lookupTargetEnvOrDefault(Operation *op); +/// A thin wrapper around the SpecificationVersion enum to represent +/// and provide utilities around the TOSA specification version. +class TosaSpecificationVersion { +public: + TosaSpecificationVersion(uint32_t major, uint32_t minor) + : majorVersion(major), minorVersion(minor) {} + TosaSpecificationVersion(SpecificationVersion version) + : TosaSpecificationVersion(fromVersionEnum(version)) {} + + bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const { + return this->majorVersion == baseVersion.majorVersion && + this->minorVersion >= baseVersion.minorVersion; + } + + uint32_t getMajor() const { return majorVersion; } + uint32_t getMinor() const { return minorVersion; } + +private: + uint32_t majorVersion = 0; + uint32_t minorVersion = 0; + + static TosaSpecificationVersion + fromVersionEnum(SpecificationVersion version) { + switch (version) { + case SpecificationVersion::V_1_0: + return TosaSpecificationVersion(1, 0); + case SpecificationVersion::V_1_1_DRAFT: + return TosaSpecificationVersion(1, 1); + } + llvm_unreachable("Unknown TOSA version"); + } +}; + +llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version); + /// This class represents the capability enabled in the target implementation /// such as profile, extension, and level. It's a wrapper class around /// tosa::TargetEnvAttr. class TargetEnv { public: TargetEnv() {} - explicit TargetEnv(Level level, const ArrayRef<Profile> &profiles, + explicit TargetEnv(SpecificationVersion specificationVersion, Level level, + const ArrayRef<Profile> &profiles, const ArrayRef<Extension> &extensions) - : level(level) { + : specificationVersion(specificationVersion), level(level) { enabledProfiles.insert_range(profiles); enabledExtensions.insert_range(extensions); } explicit TargetEnv(TargetEnvAttr targetAttr) - : TargetEnv(targetAttr.getLevel(), targetAttr.getProfiles(), - targetAttr.getExtensions()) {} + : TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(), + targetAttr.getProfiles(), targetAttr.getExtensions()) {} void addProfile(Profile p) { enabledProfiles.insert(p); } void addExtension(Extension e) { enabledExtensions.insert(e); } - // TODO implement the following utilities. - // Version getSpecVersion() const; + SpecificationVersion getSpecVersion() const { return specificationVersion; } TosaLevel getLevel() const { if (level == Level::eightK) @@ -105,6 +140,7 @@ public: } private: + SpecificationVersion specificationVersion; Level level; llvm::SmallSet<Profile, 3> enabledProfiles; llvm::SmallSet<Extension, 13> enabledExtensions; diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc index 1f718ac..c1b5e78 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc @@ -2,441 +2,779 @@ // `tools/genspec.py` in https://git.mlplatform.org/tosa/specification.git profileComplianceMap = { {"tosa.argmax", - {{{Profile::pro_int}, {{i8T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, i32T}, {fp32T, i32T}}}}}, + {{{Profile::pro_int}, {{{i8T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, i32T}, SpecificationVersion::V_1_0}, + {{fp32T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.avg_pool2d", - {{{Profile::pro_int}, {{i8T, i8T, i8T, i32T, i8T}}}, + {{{Profile::pro_int}, + {{{i8T, i8T, i8T, i32T, i8T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, fp16T, fp16T, fp16T, fp16T}, - {fp16T, fp16T, fp16T, fp32T, fp16T}, - {fp32T, fp32T, fp32T, fp32T, fp32T}}}}}, + {{{fp16T, fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp16T, fp16T, fp16T, fp32T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.conv2d", - {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}}, + {{{Profile::pro_int}, + {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, - {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, - {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}}, + {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}, + SpecificationVersion::V_1_0}}}}}, {"tosa.conv3d", - {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}}, + {{{Profile::pro_int}, + {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, - {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, - {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}}, + {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}, + SpecificationVersion::V_1_0}}}}}, {"tosa.depthwise_conv2d", - {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}}, + {{{Profile::pro_int}, + {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, - {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, - {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}}, + {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}, + SpecificationVersion::V_1_0}}}}}, {"tosa.matmul", - {{{Profile::pro_int}, {{i8T, i8T, i8T, i8T, i32T}}}, + {{{Profile::pro_int}, + {{{i8T, i8T, i8T, i8T, i32T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, fp16T, fp16T, fp16T, fp16T}, - {fp16T, fp16T, fp16T, fp16T, fp32T}, - {fp32T, fp32T, fp32T, fp32T, fp32T}}}}}, + {{{fp16T, fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp16T, fp16T, fp16T, fp16T, fp32T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.max_pool2d", - {{{Profile::pro_int}, {{i8T, i8T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, {{{i8T, i8T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.transpose_conv2d", - {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}}, + {{{Profile::pro_int}, + {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, - {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, - {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}}, + {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}, + SpecificationVersion::V_1_0}}}}}, {"tosa.clamp", - {{{Profile::pro_int}, {{i8T, i8T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.erf", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.sigmoid", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.tanh", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, {{{i8T, i8T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.erf", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.sigmoid", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.tanh", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.add", - {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf}, - {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.arithmetic_right_shift", {{{Profile::pro_int}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.bitwise_and", {{{Profile::pro_int}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.bitwise_or", {{{Profile::pro_int}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.bitwise_xor", {{{Profile::pro_int}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.intdiv", - {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}, + anyOf}}}, {"tosa.logical_and", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}}}, {"tosa.logical_left_shift", {{{Profile::pro_int, Profile::pro_fp}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}, anyOf}}}, {"tosa.logical_right_shift", {{{Profile::pro_int, Profile::pro_fp}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}, anyOf}}}, {"tosa.logical_or", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}}}, {"tosa.logical_xor", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}}}, {"tosa.maximum", - {{{Profile::pro_int}, {{i32T, i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {{{Profile::pro_int}, {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.minimum", - {{{Profile::pro_int}, {{i32T, i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {{{Profile::pro_int}, {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.mul", - {{{Profile::pro_int}, {{i8T, i8T, i32T}, {i16T, i16T, i32T}}}, - {{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf}, - {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {{{Profile::pro_int}, + {{{i8T, i8T, i32T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_int, Profile::pro_fp}, + {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.pow", - {{{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {{{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.sub", - {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf}, - {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, - {"tosa.table", {{{Profile::pro_int}, {{i8T, i8T, i8T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.table", + {{{Profile::pro_int}, {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}}}}}, {"tosa.abs", - {{{Profile::pro_int}, {{i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, {{{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.bitwise_not", - {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}}}, - {"tosa.ceil", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.clz", {{{Profile::pro_int}, {{i32T, i32T}}}}}, - {"tosa.cos", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.exp", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.floor", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.log", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.ceil", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.clz", + {{{Profile::pro_int}, {{{i32T, i32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.cos", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.exp", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.floor", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.log", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.logical_not", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}}}, {"tosa.negate", {{{Profile::pro_int}, - {{i8T, i8T, i8T, i8T}, - {i16T, i16T, i16T, i16T}, - {i32T, i32T, i32T, i32T}}}, + {{{i8T, i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T, fp32T}}}}}, + {{{fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reciprocal", - {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.rsqrt", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.sin", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.rsqrt", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.sin", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.select", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, {{Profile::pro_int}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.equal", - {{{Profile::pro_int}, {{i32T, i32T, boolT}}}, - {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}}, + {{{Profile::pro_int}, + {{{i32T, i32T, boolT}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, boolT}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, boolT}, SpecificationVersion::V_1_0}}}}}, {"tosa.greater", - {{{Profile::pro_int}, {{i32T, i32T, boolT}}}, - {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}}, + {{{Profile::pro_int}, + {{{i32T, i32T, boolT}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, boolT}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, boolT}, SpecificationVersion::V_1_0}}}}}, {"tosa.greater_equal", - {{{Profile::pro_int}, {{i32T, i32T, boolT}}}, - {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}}, + {{{Profile::pro_int}, + {{{i32T, i32T, boolT}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, boolT}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, boolT}, SpecificationVersion::V_1_0}}}}}, {"tosa.reduce_all", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}}}, {"tosa.reduce_any", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}}}, {"tosa.reduce_max", - {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reduce_min", - {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reduce_product", - {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reduce_sum", - {{{Profile::pro_int}, {{i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, {{{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.concat", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}, - {{Profile::pro_int}, {{i8T, i8T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.pad", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, {{Profile::pro_int}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reshape", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}, - {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reverse", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}, - {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.slice", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}, - {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.tile", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}, - {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.transpose", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}, - {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.gather", {{{Profile::pro_int}, - {{i8T, i32T, i8T}, {i16T, i32T, i16T}, {i32T, i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, i32T, fp16T}, {fp32T, i32T, fp32T}}}}}, + {{{i8T, i32T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i32T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, i32T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, i32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.scatter", {{{Profile::pro_int}, - {{i8T, i32T, i8T, i8T}, - {i16T, i32T, i16T, i16T}, - {i32T, i32T, i32T, i32T}}}, + {{{i8T, i32T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i32T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, i32T, fp16T, fp16T}, {fp32T, i32T, fp32T, fp32T}}}}}, + {{{fp16T, i32T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, i32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.resize", - {{{Profile::pro_int}, {{i8T, i32T}, {i8T, i8T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, + {{{i8T, i32T}, SpecificationVersion::V_1_0}, + {{i8T, i8T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.cast", {{{Profile::pro_int}, - {{boolT, i8T}, - {boolT, i16T}, - {boolT, i32T}, - {i8T, boolT}, - {i8T, i16T}, - {i8T, i32T}, - {i16T, boolT}, - {i16T, i8T}, - {i16T, i32T}, - {i32T, boolT}, - {i32T, i8T}, - {i32T, i16T}}}, - {{Profile::pro_fp}, - {{i8T, fp16T}, - {i8T, fp32T}, - {i16T, fp16T}, - {i16T, fp32T}, - {i32T, fp16T}, - {i32T, fp32T}, - {fp16T, i8T}, - {fp16T, i16T}, - {fp16T, i32T}, - {fp16T, fp32T}, - {fp32T, i8T}, - {fp32T, i16T}, - {fp32T, i32T}, - {fp32T, fp16T}}}}}, + {{{boolT, i8T}, SpecificationVersion::V_1_0}, + {{boolT, i16T}, SpecificationVersion::V_1_0}, + {{boolT, i32T}, SpecificationVersion::V_1_0}, + {{i8T, boolT}, SpecificationVersion::V_1_0}, + {{i8T, i16T}, SpecificationVersion::V_1_0}, + {{i8T, i32T}, SpecificationVersion::V_1_0}, + {{i16T, boolT}, SpecificationVersion::V_1_0}, + {{i16T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i32T}, SpecificationVersion::V_1_0}, + {{i32T, boolT}, SpecificationVersion::V_1_0}, + {{i32T, i8T}, SpecificationVersion::V_1_0}, + {{i32T, i16T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{i8T, fp16T}, SpecificationVersion::V_1_0}, + {{i8T, fp32T}, SpecificationVersion::V_1_0}, + {{i16T, fp16T}, SpecificationVersion::V_1_0}, + {{i16T, fp32T}, SpecificationVersion::V_1_0}, + {{i32T, fp16T}, SpecificationVersion::V_1_0}, + {{i32T, fp32T}, SpecificationVersion::V_1_0}, + {{fp16T, i8T}, SpecificationVersion::V_1_0}, + {{fp16T, i16T}, SpecificationVersion::V_1_0}, + {{fp16T, i32T}, SpecificationVersion::V_1_0}, + {{fp16T, fp32T}, SpecificationVersion::V_1_0}, + {{fp32T, i8T}, SpecificationVersion::V_1_0}, + {{fp32T, i16T}, SpecificationVersion::V_1_0}, + {{fp32T, i32T}, SpecificationVersion::V_1_0}, + {{fp32T, fp16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.rescale", {{{Profile::pro_int}, - {{i8T, i8T, i8T, i8T}, - {i8T, i8T, i16T, i16T}, - {i8T, i8T, i32T, i32T}, - {i16T, i16T, i8T, i8T}, - {i16T, i16T, i16T, i16T}, - {i16T, i16T, i32T, i32T}, - {i32T, i32T, i8T, i8T}, - {i32T, i32T, i16T, i16T}, - {i32T, i32T, i32T, i32T}}}}}, + {{{i8T, i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i8T, i8T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i32T, i32T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.const", {{{Profile::pro_int, Profile::pro_fp}, - {{boolT}, {i8T}, {i16T}, {i32T}}, + {{{boolT}, SpecificationVersion::V_1_0}, + {{i8T}, SpecificationVersion::V_1_0}, + {{i16T}, SpecificationVersion::V_1_0}, + {{i32T}, SpecificationVersion::V_1_0}}, anyOf}, - {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}}, + {{Profile::pro_fp}, + {{{fp16T}, SpecificationVersion::V_1_0}, + {{fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.identity", {{{Profile::pro_int, Profile::pro_fp}, - {{boolT, boolT}, {i8T, i8T}, {i16T, i16T}, {i32T, i32T}}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}, + {{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}, anyOf}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.variable", - {{{Profile::pro_int}, {{i8T}}}, {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}}, + {{{Profile::pro_int}, {{{i8T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T}, SpecificationVersion::V_1_0}, + {{fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.variable_write", - {{{Profile::pro_int}, {{i8T}}}, {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}}, + {{{Profile::pro_int}, {{{i8T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T}, SpecificationVersion::V_1_0}, + {{fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.variable_read", - {{{Profile::pro_int}, {{i8T}}}, {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}}, + {{{Profile::pro_int}, {{{i8T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T}, SpecificationVersion::V_1_0}, + {{fp32T}, SpecificationVersion::V_1_0}}}}}, }; extensionComplianceMap = { {"tosa.argmax", - {{{Extension::int16}, {{i16T, i32T}}}, - {{Extension::fp8e4m3}, {{fp8e4m3T, i32T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, i32T}}}, - {{Extension::bf16}, {{bf16T, i32T}}}}}, + {{{Extension::int16}, {{{i16T, i32T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e4m3}, {{{fp8e4m3T, i32T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, {{{fp8e5m2T, i32T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.avg_pool2d", - {{{Extension::int16}, {{i16T, i16T, i16T, i32T, i16T}}}, - {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}}, + {{{Extension::int16}, + {{{i16T, i16T, i16T, i32T, i16T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T}, + SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T}, + SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, + {{{bf16T, bf16T, bf16T, fp32T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.conv2d", - {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}}, - {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}}, + {{{Extension::int4}, + {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Extension::int16}, + {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e4m3}, - {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}}, + {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e5m2}, - {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}}, + {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::bf16}, - {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}}, + {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}, + SpecificationVersion::V_1_0}}}}}, {"tosa.conv3d", - {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}}, - {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}}, + {{{Extension::int4}, + {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Extension::int16}, + {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e4m3}, - {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}}, + {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e5m2}, - {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}}, + {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::bf16}, - {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}}, + {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}, + SpecificationVersion::V_1_0}}}}}, {"tosa.depthwise_conv2d", - {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}}, - {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}}, + {{{Extension::int4}, + {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Extension::int16}, + {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e4m3}, - {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}}, + {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e5m2}, - {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}}, + {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::bf16}, - {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}}, - {"tosa.fft2d", {{{Extension::fft}, {{fp32T, fp32T, fp32T, fp32T}}}}}, + {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}, + SpecificationVersion::V_1_0}}}}}, + {"tosa.fft2d", + {{{Extension::fft}, + {{{fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.matmul", - {{{Extension::int16}, {{i16T, i16T, i16T, i16T, i48T}}}, + {{{Extension::int16}, + {{{i16T, i16T, i16T, i16T, i48T}, SpecificationVersion::V_1_0}}}, {{Extension::fp8e4m3}, - {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T}, - {fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp32T}}}, + {{{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp32T}, + SpecificationVersion::V_1_1_DRAFT}}}, {{Extension::fp8e5m2}, - {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T}, - {fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp32T}}}, + {{{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp32T}, + SpecificationVersion::V_1_1_DRAFT}}}, {{Extension::fp8e4m3, Extension::fp8e5m2}, - {{fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp16T}, - {fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp32T}, - {fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp16T}, - {fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp32T}}, + {{{fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp16T}, + SpecificationVersion::V_1_1_DRAFT}, + {{fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp32T}, + SpecificationVersion::V_1_1_DRAFT}, + {{fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp16T}, + SpecificationVersion::V_1_1_DRAFT}, + {{fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp32T}, + SpecificationVersion::V_1_1_DRAFT}}, allOf}, - {{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T, fp32T}}}}}, + {{Extension::bf16}, + {{{bf16T, bf16T, bf16T, bf16T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.max_pool2d", - {{{Extension::int16}, {{i16T, i16T}}}, - {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.rfft2d", {{{Extension::fft}, {{fp32T, fp32T, fp32T}}}}}, + {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.rfft2d", + {{{Extension::fft}, + {{{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.transpose_conv2d", - {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}}, - {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}}, + {{{Extension::int4}, + {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Extension::int16}, + {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e4m3}, - {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}}, + {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e5m2}, - {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}}, + {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::bf16}, - {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}}, + {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}, + SpecificationVersion::V_1_0}}}}}, {"tosa.clamp", - {{{Extension::int16}, {{i16T, i16T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.erf", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.sigmoid", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.tanh", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.add", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, - {"tosa.maximum", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, - {"tosa.minimum", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, - {"tosa.mul", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, - {"tosa.pow", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, - {"tosa.sub", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, - {"tosa.table", {{{Extension::int16}, {{i16T, i16T, i32T}}}}}, - {"tosa.abs", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.ceil", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.cos", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.exp", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.floor", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.log", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T}}}}}, - {"tosa.reciprocal", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.rsqrt", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.sin", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.select", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, - {"tosa.equal", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}}, - {"tosa.greater", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}}, - {"tosa.greater_equal", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}}, - {"tosa.reduce_max", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.reduce_min", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.reduce_product", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.reduce_sum", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.erf", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.sigmoid", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.tanh", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.add", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.maximum", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.minimum", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.mul", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.pow", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.sub", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.table", + {{{Extension::int16}, + {{{i16T, i16T, i32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.abs", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.ceil", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.cos", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.exp", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.floor", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.log", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.negate", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.reciprocal", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.rsqrt", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.sin", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.select", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.equal", + {{{Extension::bf16}, + {{{bf16T, bf16T, boolT}, SpecificationVersion::V_1_0}}}}}, + {"tosa.greater", + {{{Extension::bf16}, + {{{bf16T, bf16T, boolT}, SpecificationVersion::V_1_0}}}}}, + {"tosa.greater_equal", + {{{Extension::bf16}, + {{{bf16T, bf16T, boolT}, SpecificationVersion::V_1_0}}}}}, + {"tosa.reduce_max", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.reduce_min", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.reduce_product", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.reduce_sum", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.concat", - {{{Extension::int16}, {{i16T, i16T}}}, - {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.pad", - {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reshape", - {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reverse", - {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.slice", - {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.tile", - {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.transpose", - {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.gather", - {{{Extension::fp8e4m3}, {{fp8e4m3T, i32T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, i32T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, i32T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, i32T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, i32T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, + {{{bf16T, i32T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.scatter", - {{{Extension::fp8e4m3}, {{fp8e4m3T, i32T, fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, i32T, fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, i32T, bf16T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, i32T, fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, i32T, fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, + {{{bf16T, i32T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.resize", - {{{Extension::int16}, {{i16T, i48T}, {i16T, i16T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::int16}, + {{{i16T, i48T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.cast", {{{Extension::bf16}, - {{i8T, bf16T}, - {i16T, bf16T}, - {i32T, bf16T}, - {bf16T, i8T}, - {bf16T, i16T}, - {bf16T, i32T}, - {bf16T, fp32T}, - {fp32T, bf16T}}}, + {{{i8T, bf16T}, SpecificationVersion::V_1_0}, + {{i16T, bf16T}, SpecificationVersion::V_1_0}, + {{i32T, bf16T}, SpecificationVersion::V_1_0}, + {{bf16T, i8T}, SpecificationVersion::V_1_0}, + {{bf16T, i16T}, SpecificationVersion::V_1_0}, + {{bf16T, i32T}, SpecificationVersion::V_1_0}, + {{bf16T, fp32T}, SpecificationVersion::V_1_0}, + {{fp32T, bf16T}, SpecificationVersion::V_1_0}}}, {{Extension::bf16, Extension::fp8e4m3}, - {{bf16T, fp8e4m3T}, {fp8e4m3T, bf16T}}, + {{{bf16T, fp8e4m3T}, SpecificationVersion::V_1_0}, + {{fp8e4m3T, bf16T}, SpecificationVersion::V_1_0}}, allOf}, {{Extension::bf16, Extension::fp8e5m2}, - {{bf16T, fp8e5m2T}, {fp8e5m2T, bf16T}}, + {{{bf16T, fp8e5m2T}, SpecificationVersion::V_1_0}, + {{fp8e5m2T, bf16T}, SpecificationVersion::V_1_0}}, allOf}, {{Extension::fp8e4m3}, - {{fp8e4m3T, fp16T}, - {fp8e4m3T, fp32T}, - {fp16T, fp8e4m3T}, - {fp32T, fp8e4m3T}}}, + {{{fp8e4m3T, fp16T}, SpecificationVersion::V_1_0}, + {{fp8e4m3T, fp32T}, SpecificationVersion::V_1_0}, + {{fp16T, fp8e4m3T}, SpecificationVersion::V_1_0}, + {{fp32T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, {{Extension::fp8e5m2}, - {{fp8e5m2T, fp16T}, - {fp8e5m2T, fp32T}, - {fp16T, fp8e5m2T}, - {fp32T, fp8e5m2T}}}}}, + {{{fp8e5m2T, fp16T}, SpecificationVersion::V_1_0}, + {{fp8e5m2T, fp32T}, SpecificationVersion::V_1_0}, + {{fp16T, fp8e5m2T}, SpecificationVersion::V_1_0}, + {{fp32T, fp8e5m2T}, SpecificationVersion::V_1_0}}}}}, {"tosa.rescale", {{{Extension::int16}, - {{i48T, i48T, i8T, i8T}, - {i48T, i48T, i16T, i16T}, - {i48T, i48T, i32T, i32T}}}}}, + {{{i48T, i48T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i48T, i48T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i48T, i48T, i32T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.const", - {{{Extension::int4}, {{i4T}}}, - {{Extension::int16}, {{i48T}}}, - {{Extension::fp8e4m3}, {{fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T}}}}}, + {{{Extension::int4}, {{{i4T}, SpecificationVersion::V_1_0}}}, + {{Extension::int16}, {{{i48T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e4m3}, {{{fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, {{{fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.identity", - {{{Extension::int4}, {{i4T, i4T}}}, - {{Extension::int16}, {{i48T, i48T}}}, - {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.variable", {{{Extension::variable}, {{i8T}, {fp16T}, {fp32T}}}}}, + {{{Extension::int4}, {{{i4T, i4T}, SpecificationVersion::V_1_0}}}, + {{Extension::int16}, {{{i48T, i48T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.variable", + {{{Extension::variable}, + {{{i8T}, SpecificationVersion::V_1_0}, + {{fp16T}, SpecificationVersion::V_1_0}, + {{fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.variable_write", - {{{Extension::variable}, {{i8T}, {fp16T}, {fp32T}}}}}, + {{{Extension::variable}, + {{{i8T}, SpecificationVersion::V_1_0}, + {{fp16T}, SpecificationVersion::V_1_0}, + {{fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.variable_read", - {{{Extension::variable}, {{i8T}, {fp16T}, {fp32T}}}}}, + {{{Extension::variable}, + {{{i8T}, SpecificationVersion::V_1_0}, + {{fp16T}, SpecificationVersion::V_1_0}, + {{fp32T}, SpecificationVersion::V_1_0}}}}}, }; + // End of auto-generated metadata diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td index 38cb293..8376a4c 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -221,7 +221,7 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic, } //===----------------------------------------------------------------------===// -// TOSA Spec Section 1.5. +// TOSA Profiles and extensions // // Profile: // INT : Integer Inference. Integer operations, primarily 8 and 32-bit values. @@ -293,12 +293,6 @@ def Tosa_ExtensionAttr def Tosa_ExtensionArrayAttr : TypedArrayAttrBase<Tosa_ExtensionAttr, "TOSA extension array attribute">; -def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>; -def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">; - -def Tosa_LevelAttr - : Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>; - // The base class for defining op availability dimensions. class Availability { // The following are fields for controlling the generated C++ OpInterface. @@ -405,17 +399,40 @@ class Extension<list<I32EnumAttrCase> extensions> : Availability { } //===----------------------------------------------------------------------===// +// TOSA Levels +//===----------------------------------------------------------------------===// + +def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>; +def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">; + +def Tosa_LevelAttr + : Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>; + +//===----------------------------------------------------------------------===// +// TOSA Specification versions +//===----------------------------------------------------------------------===// + +def Tosa_V_1_0 : I32EnumAttrCase<"V_1_0", 0, "1.0">; +def Tosa_V_1_1_DRAFT : I32EnumAttrCase<"V_1_1_DRAFT", 1, "1.1.draft">; + +def Tosa_SpecificationVersion : Tosa_I32EnumAttr< + "SpecificationVersion", "TOSA specification version", "specification_version", + [Tosa_V_1_0, Tosa_V_1_1_DRAFT]>; + +//===----------------------------------------------------------------------===// // TOSA target environment. //===----------------------------------------------------------------------===// def Tosa_TargetEnv : Tosa_Attr<"TargetEnv", "target_env"> { let summary = "Target environment information."; let parameters = ( ins + "SpecificationVersion": $specification_version, "Level": $level, ArrayRefParameter<"Profile">: $profiles, ArrayRefParameter<"Extension">: $extensions ); - let assemblyFormat = "`<` `level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` " + let assemblyFormat = "`<` `specification_version` `=` $specification_version `,` " + "`level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` " "`extensions` `=` `[` $extensions `]` `>`"; } diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h index 8f5c72b..7b946ad 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h @@ -36,12 +36,15 @@ enum CheckCondition { allOf }; +using VersionedTypeInfo = + std::pair<SmallVector<TypeInfo>, SpecificationVersion>; + template <typename T> struct OpComplianceInfo { // Certain operations require multiple modes enabled. // e.g. cast bf16 to fp8e4m3 requires EXT-BF16 and EXT-FP8E4M3. SmallVector<T> mode; - SmallVector<SmallVector<TypeInfo>> operandTypeInfoSet; + SmallVector<VersionedTypeInfo> operandTypeInfoSet; CheckCondition condition = CheckCondition::anyOf; }; @@ -130,9 +133,8 @@ public: // Find the required profiles or extensions from the compliance info according // to the operand type combination. template <typename T> - SmallVector<T> findMatchedProfile(Operation *op, - SmallVector<OpComplianceInfo<T>> compInfo, - CheckCondition &condition); + OpComplianceInfo<T> + findMatchedEntry(Operation *op, SmallVector<OpComplianceInfo<T>> compInfo); SmallVector<Profile> getCooperativeProfiles(Extension ext) { switch (ext) { @@ -168,8 +170,7 @@ public: private: template <typename T> - FailureOr<SmallVector<T>> getOperatorDefinition(Operation *op, - CheckCondition &condition); + FailureOr<OpComplianceInfo<T>> getOperatorDefinition(Operation *op); OperationProfileComplianceMap profileComplianceMap; OperationExtensionComplianceMap extensionComplianceMap; diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td index 6ae19d8..14b00b0 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -137,6 +137,13 @@ def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> { ]; let options = [ + Option<"specificationVersion", "specification_version", "mlir::tosa::SpecificationVersion", + /*default=*/"mlir::tosa::SpecificationVersion::V_1_0", + "The specification version that TOSA operators should conform to.", + [{::llvm::cl::values( + clEnumValN(mlir::tosa::SpecificationVersion::V_1_0, "1.0", "TOSA Specification version 1.0"), + clEnumValN(mlir::tosa::SpecificationVersion::V_1_1_DRAFT, "1.1.draft", "TOSA Specification version 1.1.draft") + )}]>, Option<"level", "level", "mlir::tosa::Level", /*default=*/"mlir::tosa::Level::eightK", "The TOSA level that operators should conform to. A TOSA level defines " diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 6e79085..6e15b1e 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2999,6 +2999,7 @@ def Vector_StepOp : Vector_Op<"step", [ }]; let results = (outs VectorOfRankAndType<[1], [Index]>:$result); let assemblyFormat = "attr-dict `:` type($result)"; + let hasCanonicalizer = 1; } def Vector_YieldOp : Vector_Op<"yield", [ diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index 5695d5d..19a5231 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -712,10 +712,14 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> { return getAttrs().contains(name); } - ArrayAttr getStrides() { + ArrayAttr getStrideAttr() { return getAttrs().getAs<ArrayAttr>("stride"); } + ArrayAttr getBlockAttr() { + return getAttrs().getAs<ArrayAttr>("block"); + } + }]; } diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 73f9061..426377f 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -1298,14 +1298,14 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure, } def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, - AllElementTypesMatch<["mem_desc", "res"]>, - AllRanksMatch<["mem_desc", "res"]>]> { + AllElementTypesMatch<["mem_desc", "res"]>]> { let arguments = (ins XeGPU_MemDesc:$mem_desc, Variadic<Index>: $offsets, DenseI64ArrayAttr: $const_offsets, + OptionalAttr<UnitAttr>:$subgroup_block_io, OptionalAttr<DistributeLayoutAttr>:$layout ); - let results = (outs XeGPU_ValueType:$res); + let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$res); let assemblyFormat = [{ $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets) prop-dict attr-dict `` `:` type(operands) `->` type(results) @@ -1319,6 +1319,9 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, Arguments: - `mem_desc`: the memory descriptor identifying the SLM region. - `offsets`: the coordinates within the matrix to read from. + - `subgroup_block_io`: [optional] An attribute indicating that the operation can be + lowered to a subgroup block load. When this attribute is present, + the offsets are subgroup-uniform across all lanes. - `layout`: [optional] An attribute for guiding distributions among subgroups and/or work-items. It currently can accept either LayoutAttr or SliceAttr. @@ -1336,7 +1339,10 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, } ArrayRef<int64_t> getDataShape() { - return getRes().getType().getShape(); + auto resTy = getRes().getType(); + if (auto vecTy = llvm::dyn_cast<VectorType>(resTy)) + return vecTy.getShape(); + return {}; } }]; @@ -1344,13 +1350,13 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, } def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, - AllElementTypesMatch<["mem_desc", "data"]>, - AllRanksMatch<["mem_desc", "data"]>]> { + AllElementTypesMatch<["mem_desc", "data"]>]> { let arguments = (ins - XeGPU_ValueType:$data, + AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$data, XeGPU_MemDesc:$mem_desc, Variadic<Index>: $offsets, DenseI64ArrayAttr: $const_offsets, + OptionalAttr<UnitAttr>:$subgroup_block_io, OptionalAttr<DistributeLayoutAttr>:$layout ); let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets) @@ -1364,6 +1370,9 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, - `mem_desc`: the memory descriptor specifying the SLM region. - `offsets`: the coordinates within the matrix where the data will be written. - `data`: the values to be stored in the matrix. + - `subgroup_block_io`: [optional] An attribute indicating that the operation can be + lowered to a subgroup block store. When this attribute is present, + the offsets are subgroup-uniform across all lanes. - `layout`: [optional] An attribute for guiding distributions among subgroups and/or work-items. It currently can accept either LayoutAttr or SliceAttr. @@ -1378,7 +1387,10 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, } ArrayRef<int64_t> getDataShape() { - return getData().getType().getShape(); + auto DataTy = getData().getType(); + if (auto vecTy = llvm::dyn_cast<VectorType>(DataTy)) + return vecTy.getShape(); + return {}; } }]; @@ -1386,41 +1398,4 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, let hasVerifier = 1; } -def XeGPU_MemDescSubviewOp: XeGPU_Op<"mem_desc_subview", - [Pure, ViewLikeOpInterface, AllElementTypesMatch<["src", "res"]>]> { - let description = [{ - Creates a subview of a memory descriptor. The resulting memory descriptor can have - a lower rank than the source; in this case, the result dimensions correspond to the - higher-order dimensions of the source memory descriptor. - - Arguments: - - `src` : a memory descriptor. - - `offsets` : the coordinates within the matrix the subview will be created from. - - Results: - - `res` : a memory descriptor with smaller size. - - }]; - let arguments = (ins XeGPU_MemDesc:$src, - Variadic<Index>:$offsets, - DenseI64ArrayAttr:$const_offsets); - let results = (outs XeGPU_MemDesc:$res); - let assemblyFormat = [{$src `` custom<DynamicIndexList>($offsets, $const_offsets) prop-dict - attr-dict `` `:` qualified(type($src)) `->` qualified(type($res))}]; - let builders = [ - OpBuilder<(ins "Type": $res, "Value":$src, "llvm::ArrayRef<OpFoldResult>": $offsets)> - ]; - - let extraClassDeclaration = [{ - mlir::Value getViewSource() { return getSrc(); } - - SmallVector<OpFoldResult> getMixedOffsets() { - return getMixedValues(getConstOffsets(), getOffsets(), getContext()); - } - }]; - - let hasVerifier = 1; -} - - #endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td index 84902b2..b1196fb 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td @@ -237,12 +237,11 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout()); } - ArrayAttr getStrides() { + ArrayAttr getStrideAttr() { auto layout = getMemLayout(); if (layout && layout.hasAttr("stride")) { - return layout.getStrides(); + return layout.getStrideAttr(); } - // derive and return default strides SmallVector<int64_t> defaultStrides; llvm::append_range(defaultStrides, getShape().drop_front()); @@ -250,6 +249,63 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m Builder builder(getContext()); return builder.getI64ArrayAttr(defaultStrides); } + + ArrayAttr getBlockAttr() { + auto layout = getMemLayout(); + if (layout && layout.hasAttr("block")) { + return layout.getBlockAttr(); + } + Builder builder(getContext()); + return builder.getI64ArrayAttr({}); + } + + /// Heuristic to determine if the MemDesc uses column-major layout, + /// based on the rank and the value of the first stride dimension. + bool isColMajor() { + auto dim0 = dyn_cast<IntegerAttr>(getStrideAttr()[0]); + return getRank() == 2 && dim0.getInt() == 1; + } + + // Get the Blocking shape for a MemDescType, Which is represented + // as an attribute in MemDescType. By default it is the shape + // of the mdescTy + SmallVector<int64_t> getBlockShape() { + SmallVector<int64_t> size(getShape()); + ArrayAttr blockAttr = getBlockAttr(); + if (!blockAttr.empty()) { + size.clear(); + for (auto attr : blockAttr.getValue()) { + size.push_back(cast<IntegerAttr>(attr).getInt()); + } + } + return size; + } + + // Get strides as vector of integer. + // If it contains block attribute, the strides are blocked strides. + // + // The blocking is applied to the base matrix shape derived from the + // memory descriptor's stride information. If the matrix described by + // the memory descriptor is not contiguous, it is assumed that the base + // matrix is contiguous and follows the same memory layout. + // + // It first computes the original matrix shape using the stride info, + // then computes the number of blocks in each dimension of original shape, + // then compute the outer block shape and stride, + // then combines the inner and outer block shape and stride + // e.g. for `mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>` + // its memory layout tuple is ([2,32,16,8],[128,256,1,16]) + // for `mem_desc<256x32xf16, @block=[8, 16]>` with default @stride[32, 1] + // its memory layout tuple is ([32,2,8,16],[256,128,16,1]) + SmallVector<int64_t> getStrideShape(); + + /// Generates instructions to compute the linearize offset + // if the memory descriptor is blocked, it returns linearize offset based on the blocked layout + // the strides of memory descriptor is always considered regardless of blocked or not + Value getLinearOffsets(OpBuilder &builder, + Location loc, ArrayRef<OpFoldResult> offsets); + + }]; let hasCustomAssemblyFormat = true; diff --git a/mlir/include/mlir/IR/Remarks.h b/mlir/include/mlir/IR/Remarks.h index 20e84ec..9877926 100644 --- a/mlir/include/mlir/IR/Remarks.h +++ b/mlir/include/mlir/IR/Remarks.h @@ -18,7 +18,6 @@ #include "llvm/Remarks/Remark.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/Regex.h" -#include <optional> #include "mlir/IR/Diagnostics.h" #include "mlir/IR/MLIRContext.h" @@ -144,7 +143,7 @@ public: llvm::StringRef getCategoryName() const { return categoryName; } - llvm::StringRef getFullCategoryName() const { + llvm::StringRef getCombinedCategoryName() const { if (categoryName.empty() && subCategoryName.empty()) return {}; if (subCategoryName.empty()) @@ -318,7 +317,7 @@ private: }; //===----------------------------------------------------------------------===// -// MLIR Remark Streamer +// Pluggable Remark Utilities //===----------------------------------------------------------------------===// /// Base class for MLIR remark streamers that is used to stream @@ -338,6 +337,26 @@ public: virtual void finalize() {} // optional }; +using ReportFn = llvm::unique_function<void(const Remark &)>; + +/// Base class for MLIR remark emitting policies that is used to emit +/// optimization remarks to the underlying remark streamer. The derived classes +/// should implement the `reportRemark` method to provide the actual emitting +/// implementation. +class RemarkEmittingPolicyBase { +protected: + ReportFn reportImpl; + +public: + RemarkEmittingPolicyBase() = default; + virtual ~RemarkEmittingPolicyBase() = default; + + void initialize(ReportFn fn) { reportImpl = std::move(fn); } + + virtual void reportRemark(const Remark &remark) = 0; + virtual void finalize() = 0; +}; + //===----------------------------------------------------------------------===// // Remark Engine (MLIR Context will own this class) //===----------------------------------------------------------------------===// @@ -355,6 +374,8 @@ private: std::optional<llvm::Regex> failedFilter; /// The MLIR remark streamer that will be used to emit the remarks. std::unique_ptr<MLIRRemarkStreamerBase> remarkStreamer; + /// The MLIR remark policy that will be used to emit the remarks. + std::unique_ptr<RemarkEmittingPolicyBase> remarkEmittingPolicy; /// When is enabled, engine also prints remarks as mlir::emitRemarks. bool printAsEmitRemarks = false; @@ -392,6 +413,8 @@ private: InFlightRemark emitIfEnabled(Location loc, RemarkOpts opts, bool (RemarkEngine::*isEnabled)(StringRef) const); + /// Report a remark. + void reportImpl(const Remark &remark); public: /// Default constructor is deleted, use the other constructor. @@ -407,8 +430,15 @@ public: ~RemarkEngine(); /// Setup the remark engine with the given output path and format. - LogicalResult initialize(std::unique_ptr<MLIRRemarkStreamerBase> streamer, - std::string *errMsg); + LogicalResult + initialize(std::unique_ptr<MLIRRemarkStreamerBase> streamer, + std::unique_ptr<RemarkEmittingPolicyBase> remarkEmittingPolicy, + std::string *errMsg); + + /// Get the remark emitting policy. + RemarkEmittingPolicyBase *getRemarkEmittingPolicy() const { + return remarkEmittingPolicy.get(); + } /// Report a remark. void report(const Remark &&remark); @@ -446,6 +476,46 @@ inline InFlightRemark withEngine(Fn fn, Location loc, Args &&...args) { namespace mlir::remark { +//===----------------------------------------------------------------------===// +// Remark Emitting Policies +//===----------------------------------------------------------------------===// + +/// Policy that emits all remarks. +class RemarkEmittingPolicyAll : public detail::RemarkEmittingPolicyBase { +public: + RemarkEmittingPolicyAll(); + + void reportRemark(const detail::Remark &remark) override { + assert(reportImpl && "reportImpl is not set"); + reportImpl(remark); + } + void finalize() override {} +}; + +/// Policy that emits final remarks. +class RemarkEmittingPolicyFinal : public detail::RemarkEmittingPolicyBase { +private: + /// user can intercept them for custom processing via a registered callback, + /// otherwise they will be reported on engine destruction. + llvm::DenseSet<detail::Remark> postponedRemarks; + +public: + RemarkEmittingPolicyFinal(); + + void reportRemark(const detail::Remark &remark) override { + postponedRemarks.erase(remark); + postponedRemarks.insert(remark); + } + + void finalize() override { + assert(reportImpl && "reportImpl is not set"); + for (auto &remark : postponedRemarks) { + if (reportImpl) + reportImpl(remark); + } + } +}; + /// Create a Reason with llvm::formatv formatting. template <class... Ts> inline detail::LazyTextBuild reason(const char *fmt, Ts &&...ts) { @@ -505,16 +575,72 @@ inline detail::InFlightRemark analysis(Location loc, RemarkOpts opts) { /// Setup remarks for the context. This function will enable the remark engine /// and set the streamer to be used for optimization remarks. The remark -/// categories are used to filter the remarks that will be emitted by the remark -/// engine. If a category is not specified, it will not be emitted. If +/// categories are used to filter the remarks that will be emitted by the +/// remark engine. If a category is not specified, it will not be emitted. If /// `printAsEmitRemarks` is true, the remarks will be printed as /// mlir::emitRemarks. 'streamer' must inherit from MLIRRemarkStreamerBase and /// will be used to stream the remarks. LogicalResult enableOptimizationRemarks( MLIRContext &ctx, std::unique_ptr<remark::detail::MLIRRemarkStreamerBase> streamer, + std::unique_ptr<remark::detail::RemarkEmittingPolicyBase> + remarkEmittingPolicy, const remark::RemarkCategories &cats, bool printAsEmitRemarks = false); } // namespace mlir::remark +// DenseMapInfo specialization for Remark +namespace llvm { +template <> +struct DenseMapInfo<mlir::remark::detail::Remark> { + static constexpr StringRef kEmptyKey = "<EMPTY_KEY>"; + static constexpr StringRef kTombstoneKey = "<TOMBSTONE_KEY>"; + + /// Helper to provide a static dummy context for sentinel keys. + static mlir::MLIRContext *getStaticDummyContext() { + static mlir::MLIRContext dummyContext; + return &dummyContext; + } + + /// Create an empty remark + static inline mlir::remark::detail::Remark getEmptyKey() { + return mlir::remark::detail::Remark( + mlir::remark::RemarkKind::RemarkUnknown, mlir::DiagnosticSeverity::Note, + mlir::UnknownLoc::get(getStaticDummyContext()), + mlir::remark::RemarkOpts::name(kEmptyKey)); + } + + /// Create a dead remark + static inline mlir::remark::detail::Remark getTombstoneKey() { + return mlir::remark::detail::Remark( + mlir::remark::RemarkKind::RemarkUnknown, mlir::DiagnosticSeverity::Note, + mlir::UnknownLoc::get(getStaticDummyContext()), + mlir::remark::RemarkOpts::name(kTombstoneKey)); + } + + /// Compute the hash value of the remark + static unsigned getHashValue(const mlir::remark::detail::Remark &remark) { + return llvm::hash_combine( + remark.getLocation().getAsOpaquePointer(), + llvm::hash_value(remark.getRemarkName()), + llvm::hash_value(remark.getCombinedCategoryName())); + } + + static bool isEqual(const mlir::remark::detail::Remark &lhs, + const mlir::remark::detail::Remark &rhs) { + // Check for empty/tombstone keys first + if (lhs.getRemarkName() == kEmptyKey || + lhs.getRemarkName() == kTombstoneKey || + rhs.getRemarkName() == kEmptyKey || + rhs.getRemarkName() == kTombstoneKey) { + return lhs.getRemarkName() == rhs.getRemarkName(); + } + + // For regular remarks, compare key identifying fields + return lhs.getLocation() == rhs.getLocation() && + lhs.getRemarkName() == rhs.getRemarkName() && + lhs.getCombinedCategoryName() == rhs.getCombinedCategoryName(); + } +}; +} // namespace llvm #endif // MLIR_IR_REMARKS_H diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt index a5feb59..72ed046 100644 --- a/mlir/include/mlir/Interfaces/CMakeLists.txt +++ b/mlir/include/mlir/Interfaces/CMakeLists.txt @@ -6,6 +6,7 @@ add_mlir_interface(DestinationStyleOpInterface) add_mlir_interface(FunctionInterfaces) add_mlir_interface(IndexingMapOpInterface) add_mlir_interface(InferIntRangeInterface) +add_mlir_interface(InferStridedMetadataInterface) add_mlir_interface(InferTypeOpInterface) add_mlir_interface(LoopLikeInterface) add_mlir_interface(MemOpInterfaces) diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h index 0e107e8..a6de3d1 100644 --- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h +++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h @@ -117,7 +117,8 @@ public: IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {} /// Create an integer value range lattice value. - IntegerValueRange(std::optional<ConstantIntRanges> value = std::nullopt) + explicit IntegerValueRange( + std::optional<ConstantIntRanges> value = std::nullopt) : value(std::move(value)) {} /// Whether the range is uninitialized. This happens when the state hasn't @@ -167,6 +168,15 @@ using SetIntRangeFn = using SetIntLatticeFn = llvm::function_ref<void(Value, const IntegerValueRange &)>; +/// Helper callback type to get the integer range of a value. +using GetIntRangeFn = function_ref<IntegerValueRange(Value)>; + +/// Helper function to collect the integer range values of an array of op fold +/// results. +SmallVector<IntegerValueRange> getIntValueRanges(ArrayRef<OpFoldResult> values, + GetIntRangeFn getIntRange, + int32_t indexBitwidth); + class InferIntRangeInterface; namespace intrange::detail { diff --git a/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h new file mode 100644 index 0000000..0c572e0 --- /dev/null +++ b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h @@ -0,0 +1,145 @@ +//===- InferStridedMetadataInterface.h - Strided Metadata Inference -C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains definitions of the strided metadata inference interface +// defined in `InferStridedMetadataInterface.td` +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H +#define MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H + +#include "mlir/Interfaces/InferIntRangeInterface.h" + +namespace mlir { +/// A class that represents the strided metadata range information, including +/// offsets, sizes, and strides as integer ranges. +class StridedMetadataRange { +public: + /// Default constructor creates uninitialized ranges. + StridedMetadataRange() = default; + + /// Returns a ranked strided metadata range. + static StridedMetadataRange + getRanked(SmallVectorImpl<ConstantIntRanges> &&offsets, + SmallVectorImpl<ConstantIntRanges> &&sizes, + SmallVectorImpl<ConstantIntRanges> &&strides) { + return StridedMetadataRange(std::move(offsets), std::move(sizes), + std::move(strides)); + } + + /// Returns a strided metadata range with maximum ranges. + static StridedMetadataRange getMaxRanges(int32_t indexBitwidth, + int32_t offsetsRank, + int32_t sizeRank, + int32_t stridedRank) { + return StridedMetadataRange( + SmallVector<ConstantIntRanges>( + offsetsRank, ConstantIntRanges::maxRange(indexBitwidth)), + SmallVector<ConstantIntRanges>( + sizeRank, ConstantIntRanges::maxRange(indexBitwidth)), + SmallVector<ConstantIntRanges>( + stridedRank, ConstantIntRanges::maxRange(indexBitwidth))); + } + + static StridedMetadataRange getMaxRanges(int32_t indexBitwidth, + int32_t rank) { + return getMaxRanges(indexBitwidth, 1, rank, rank); + } + + /// Returns whether the metadata is uninitialized. + bool isUninitialized() const { return !offsets.has_value(); } + + /// Get the offsets range. + ArrayRef<ConstantIntRanges> getOffsets() const { + return offsets ? *offsets : ArrayRef<ConstantIntRanges>(); + } + MutableArrayRef<ConstantIntRanges> getOffsets() { + return offsets ? *offsets : MutableArrayRef<ConstantIntRanges>(); + } + + /// Get the sizes ranges. + ArrayRef<ConstantIntRanges> getSizes() const { return sizes; } + MutableArrayRef<ConstantIntRanges> getSizes() { return sizes; } + + /// Get the strides ranges. + ArrayRef<ConstantIntRanges> getStrides() const { return strides; } + MutableArrayRef<ConstantIntRanges> getStrides() { return strides; } + + /// Compare two strided metadata ranges. + bool operator==(const StridedMetadataRange &other) const { + return offsets == other.offsets && sizes == other.sizes && + strides == other.strides; + } + + /// Print the strided metadata range. + void print(raw_ostream &os) const; + + /// Join two strided metadata ranges, by taking the element-wise union of the + /// metadata. + static StridedMetadataRange join(const StridedMetadataRange &lhs, + const StridedMetadataRange &rhs) { + if (lhs.isUninitialized()) + return rhs; + if (rhs.isUninitialized()) + return lhs; + + // Helper fuction to compute the range union of constant ranges. + auto rangeUnion = + +[](const std::tuple<ConstantIntRanges, ConstantIntRanges> &lhsRhs) + -> ConstantIntRanges { + return std::get<0>(lhsRhs).rangeUnion(std::get<1>(lhsRhs)); + }; + + // Get the elementwise range union. Note, that `zip_equal` will assert if + // sizes are not equal. + SmallVector<ConstantIntRanges> offsets = llvm::map_to_vector( + llvm::zip_equal(*lhs.offsets, *rhs.offsets), rangeUnion); + SmallVector<ConstantIntRanges> sizes = + llvm::map_to_vector(llvm::zip_equal(lhs.sizes, rhs.sizes), rangeUnion); + SmallVector<ConstantIntRanges> strides = llvm::map_to_vector( + llvm::zip_equal(lhs.strides, rhs.strides), rangeUnion); + + // Return the joined metadata. + return StridedMetadataRange(std::move(offsets), std::move(sizes), + std::move(strides)); + } + +private: + /// Create a strided metadata range with the given offset, sizes, and strides. + StridedMetadataRange(SmallVectorImpl<ConstantIntRanges> &&offsets, + SmallVectorImpl<ConstantIntRanges> &&sizes, + SmallVectorImpl<ConstantIntRanges> &&strides) + : offsets(std::move(offsets)), sizes(std::move(sizes)), + strides(std::move(strides)) {} + + /// The offsets range. + std::optional<SmallVector<ConstantIntRanges>> offsets; + + /// The sizes ranges. + SmallVector<ConstantIntRanges> sizes; + + /// The strides ranges. + SmallVector<ConstantIntRanges> strides; +}; + +/// Print the strided metadata to `os`. +inline raw_ostream &operator<<(raw_ostream &os, + const StridedMetadataRange &range) { + range.print(os); + return os; +} + +/// Callback function type for setting the strided metadata of a value. +using SetStridedMetadataRangeFn = + function_ref<void(Value, const StridedMetadataRange &)>; +} // end namespace mlir + +#include "mlir/Interfaces/InferStridedMetadataInterface.h.inc" + +#endif // MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H diff --git a/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td new file mode 100644 index 0000000..ee5b094 --- /dev/null +++ b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td @@ -0,0 +1,45 @@ +//===- InferStridedMetadataInterface.td - Strided MD Inference ----------*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the interface for strided metadata range analysis +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE +#define MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE + +include "mlir/IR/OpBase.td" + +def InferStridedMetadataOpInterface : + OpInterface<"InferStridedMetadataOpInterface"> { + let description = [{ + Allows operations to participate in strided metadata analysis by providing + methods that allow them to specify bounds on offsets, sizes, and strides + of their result(s) given bounds on their input(s) if known. + }]; + let cppNamespace = "::mlir"; + + let methods = [ + InterfaceMethod<[{ + Infer the strided metadata bounds on the results of this op given + the bounds on its operands. + For each result value or block argument of interest, the method should + call `setMetadata` with that `Value` as an argument. + The `operands` parameter contains the strided metadata ranges for all the + operands of the operation in order. + The `getIntRange` callback is provided for obtaining the int-range + analysis result for a given value. + }], + "void", "inferStridedMetadataRanges", + (ins "::llvm::ArrayRef<::mlir::StridedMetadataRange>":$operands, + "::mlir::GetIntRangeFn":$getIntRange, + "::mlir::SetStridedMetadataRangeFn":$setMetadata, + "int32_t":$indexBitwidth)> + ]; +} +#endif // MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h index db9c37f..c1c2269 100644 --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h @@ -230,6 +230,22 @@ LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, ArrayRef<int64_t> attr, ValueRange values); +namespace OpTrait { +/// This trai indicates that pointer-like objects (such as memrefs) returned +/// from this operation will never alias with each other. This provides a +/// guarantee to optimization passes that accesses through different results +/// of this operation can be safely reordered, as they will never reference +/// overlapping memory locations. +/// +/// Operations with this trait take multiple pointer-like operands +/// and return the same operands with additional non-aliasing guarantees. +/// If the access to the results of this operation aliases at runtime, the +/// behavior of such access is undefined. +template <typename ConcreteType> +class DistinctObjectsTrait + : public TraitBase<ConcreteType, DistinctObjectsTrait> {}; +} // namespace OpTrait + } // namespace mlir #endif // MLIR_INTERFACES_VIEWLIKEINTERFACE_H_ diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td index ed213bf..131c1a0 100644 --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td @@ -414,4 +414,16 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface }]; } +// This trai indicates that pointer-like objects (such as memrefs) returned +// from this operation will never alias with each other. This provides a +// guarantee to optimization passes that accesses through different results +// of this operation can be safely reordered, as they will never reference +// overlapping memory locations. +// +// Operations with this trait take multiple pointer-like operands +// and return the same operands with additional non-aliasing guarantees. +// If the access to the results of this operation aliases at runtime, the +// behavior of such access is undefined. +def DistinctObjectsTrait : NativeOpTrait<"DistinctObjectsTrait">; + #endif // MLIR_INTERFACES_VIEWLIKEINTERFACE diff --git a/mlir/include/mlir/Remark/RemarkStreamer.h b/mlir/include/mlir/Remark/RemarkStreamer.h index 170d6b4..19a70fa 100644 --- a/mlir/include/mlir/Remark/RemarkStreamer.h +++ b/mlir/include/mlir/Remark/RemarkStreamer.h @@ -45,6 +45,7 @@ namespace mlir::remark { /// mlir::emitRemarks. LogicalResult enableOptimizationRemarksWithLLVMStreamer( MLIRContext &ctx, StringRef filePath, llvm::remarks::Format fmt, + std::unique_ptr<detail::RemarkEmittingPolicyBase> remarkEmittingPolicy, const RemarkCategories &cat, bool printAsEmitRemarks = false); } // namespace mlir::remark diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h index 252da21..997aef2 100644 --- a/mlir/include/mlir/TableGen/CodeGenHelpers.h +++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h @@ -88,7 +88,7 @@ public: /// /// Constraints that do not meet the restriction that they can only reference /// `$_self` and `$_op` are not uniqued. - void emitOpConstraints(ArrayRef<const llvm::Record *> opDefs); + void emitOpConstraints(); /// Unique all compatible type and attribute constraints from a pattern file /// and emit them at the top of the generated file. diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h index 0fbe15f..b739438 100644 --- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h +++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h @@ -44,6 +44,11 @@ enum class RemarkFormat { REMARK_FORMAT_BITSTREAM, }; +enum class RemarkPolicy { + REMARK_POLICY_ALL, + REMARK_POLICY_FINAL, +}; + /// Configuration options for the mlir-opt tool. /// This is intended to help building tools like mlir-opt by collecting the /// supported options. @@ -242,6 +247,8 @@ public: /// Set the reproducer output filename RemarkFormat getRemarkFormat() const { return remarkFormatFlag; } + /// Set the remark policy to use. + RemarkPolicy getRemarkPolicy() const { return remarkPolicyFlag; } /// Set the remark format to use. std::string getRemarksAllFilter() const { return remarksAllFilterFlag; } /// Set the remark output file. @@ -265,6 +272,8 @@ protected: /// Remark format RemarkFormat remarkFormatFlag = RemarkFormat::REMARK_FORMAT_STDOUT; + /// Remark policy + RemarkPolicy remarkPolicyFlag = RemarkPolicy::REMARK_POLICY_ALL; /// Remark file to output to std::string remarksOutputFileFlag = ""; /// Remark filters diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp index 8062b474..a84d10d 100644 --- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp +++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp @@ -258,6 +258,39 @@ getAllocEffectFor(Value value, return success(); } +static Operation *isDistinctObjectsOp(Operation *op) { + if (op && op->hasTrait<OpTrait::DistinctObjectsTrait>()) + return op; + + return nullptr; +} + +static Value getDistinctObjectsOperand(Operation *op, Value value) { + unsigned argNumber = cast<OpResult>(value).getResultNumber(); + return op->getOperand(argNumber); +} + +static std::optional<AliasResult> checkDistinctObjects(Value lhs, Value rhs) { + // We should already checked that lhs and rhs are different. + assert(lhs != rhs && "lhs and rhs must be different"); + + // Result and corresponding operand must alias. + auto lhsOp = isDistinctObjectsOp(lhs.getDefiningOp()); + if (lhsOp && getDistinctObjectsOperand(lhsOp, lhs) == rhs) + return AliasResult::MustAlias; + + auto rhsOp = isDistinctObjectsOp(rhs.getDefiningOp()); + if (rhsOp && getDistinctObjectsOperand(rhsOp, rhs) == lhs) + return AliasResult::MustAlias; + + // If two different values come from the same `DistinctObjects` operation, + // they don't alias. + if (lhsOp && lhsOp == rhsOp) + return AliasResult::NoAlias; + + return std::nullopt; +} + /// Given the two values, return their aliasing behavior. AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) { if (lhs == rhs) @@ -289,6 +322,9 @@ AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) { : AliasResult::MayAlias; } + if (std::optional<AliasResult> result = checkDistinctObjects(lhs, rhs)) + return *result; + // Otherwise, neither of the values are constant so check to see if either has // an allocation effect. bool lhsHasAlloc = succeeded(getAllocEffectFor(lhs, lhsAlloc, lhsAllocScope)); diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt index 609cb34..db10ebc 100644 --- a/mlir/lib/Analysis/CMakeLists.txt +++ b/mlir/lib/Analysis/CMakeLists.txt @@ -40,6 +40,7 @@ add_mlir_library(MLIRAnalysis DataFlow/IntegerRangeAnalysis.cpp DataFlow/LivenessAnalysis.cpp DataFlow/SparseAnalysis.cpp + DataFlow/StridedMetadataRangeAnalysis.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Analysis @@ -53,6 +54,7 @@ add_mlir_library(MLIRAnalysis MLIRDataLayoutInterfaces MLIRFunctionInterfaces MLIRInferIntRangeInterface + MLIRInferStridedMetadataInterface MLIRInferTypeOpInterface MLIRLoopLikeInterface MLIRPresburger diff --git a/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp new file mode 100644 index 0000000..01c9daf --- /dev/null +++ b/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp @@ -0,0 +1,127 @@ +//===- StridedMetadataRangeAnalysis.cpp - Integer range analysis --------*- C++ +//-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the dataflow analysis class for integer range inference +// which is used in transformations over the `arith` dialect such as +// branch elimination or signed->unsigned rewriting +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h" +#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/DebugStringHelper.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" + +#define DEBUG_TYPE "strided-metadata-range-analysis" + +using namespace mlir; +using namespace mlir::dataflow; + +/// Get the entry state for a value. For any value that is not a ranked memref, +/// this function sets the metadata to a top state with no offsets, sizes, or +/// strides. For `memref` types, this function will use the metadata in the type +/// to try to deduce as much informaiton as possible. +static StridedMetadataRange getEntryStateImpl(Value v, int32_t indexBitwidth) { + // TODO: generalize this method with a type interface. + auto mTy = dyn_cast<BaseMemRefType>(v.getType()); + + // If not a memref or it's un-ranked, don't infer any metadata. + if (!mTy || !mTy.hasRank()) + return StridedMetadataRange::getMaxRanges(indexBitwidth, 0, 0, 0); + + // Get the top state. + auto metadata = + StridedMetadataRange::getMaxRanges(indexBitwidth, mTy.getRank()); + + // Compute the offset and strides. + int64_t offset; + SmallVector<int64_t> strides; + if (failed(cast<MemRefType>(mTy).getStridesAndOffset(strides, offset))) + return metadata; + + // Refine the metadata if we know it from the type. + if (!ShapedType::isDynamic(offset)) { + metadata.getOffsets()[0] = + ConstantIntRanges::constant(APInt(indexBitwidth, offset)); + } + for (auto &&[size, range] : + llvm::zip_equal(mTy.getShape(), metadata.getSizes())) { + if (ShapedType::isDynamic(size)) + continue; + range = ConstantIntRanges::constant(APInt(indexBitwidth, size)); + } + for (auto &&[stride, range] : + llvm::zip_equal(strides, metadata.getStrides())) { + if (ShapedType::isDynamic(stride)) + continue; + range = ConstantIntRanges::constant(APInt(indexBitwidth, stride)); + } + + return metadata; +} + +StridedMetadataRangeAnalysis::StridedMetadataRangeAnalysis( + DataFlowSolver &solver, int32_t indexBitwidth) + : SparseForwardDataFlowAnalysis(solver), indexBitwidth(indexBitwidth) { + assert(indexBitwidth > 0 && "invalid bitwidth"); +} + +void StridedMetadataRangeAnalysis::setToEntryState( + StridedMetadataRangeLattice *lattice) { + propagateIfChanged(lattice, lattice->join(getEntryStateImpl( + lattice->getAnchor(), indexBitwidth))); +} + +LogicalResult StridedMetadataRangeAnalysis::visitOperation( + Operation *op, ArrayRef<const StridedMetadataRangeLattice *> operands, + ArrayRef<StridedMetadataRangeLattice *> results) { + auto inferrable = dyn_cast<InferStridedMetadataOpInterface>(op); + + // Bail if we cannot reason about the op. + if (!inferrable) { + setAllToEntryStates(results); + return success(); + } + + LDBG() << "Inferring metadata for: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); + + // Helper function to retrieve int range values. + auto getIntRange = [&](Value value) -> IntegerValueRange { + auto lattice = getOrCreateFor<IntegerValueRangeLattice>( + getProgramPointAfter(op), value); + return lattice ? lattice->getValue() : IntegerValueRange(); + }; + + // Convert the arguments lattices to a vector. + SmallVector<StridedMetadataRange> argRanges = llvm::map_to_vector( + operands, [](const StridedMetadataRangeLattice *lattice) { + return lattice->getValue(); + }); + + // Callback to set metadata on a result. + auto joinCallback = [&](Value v, const StridedMetadataRange &md) { + auto result = cast<OpResult>(v); + assert(llvm::is_contained(op->getResults(), result)); + LDBG() << "- Inferred metadata: " << md; + StridedMetadataRangeLattice *lattice = results[result.getResultNumber()]; + ChangeResult changed = lattice->join(md); + LDBG() << "- Joined metadata: " << lattice->getValue(); + propagateIfChanged(lattice, changed); + }; + + // Infer the metadata. + inferrable.inferStridedMetadataRanges(argRanges, getIntRange, joinCallback, + indexBitwidth); + return success(); +} diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp index 30ce1fb..6588b53 100644 --- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp +++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp @@ -1244,8 +1244,9 @@ bool FlatLinearValueConstraints::areVarsAlignedWithOther( /// Checks if the SSA values associated with `cst`'s variables in range /// [start, end) are unique. -static bool LLVM_ATTRIBUTE_UNUSED areVarsUnique( - const FlatLinearValueConstraints &cst, unsigned start, unsigned end) { +[[maybe_unused]] static bool +areVarsUnique(const FlatLinearValueConstraints &cst, unsigned start, + unsigned end) { assert(start <= cst.getNumDimAndSymbolVars() && "Start position out of bounds"); @@ -1267,14 +1268,14 @@ static bool LLVM_ATTRIBUTE_UNUSED areVarsUnique( } /// Checks if the SSA values associated with `cst`'s variables are unique. -static bool LLVM_ATTRIBUTE_UNUSED +[[maybe_unused]] static bool areVarsUnique(const FlatLinearValueConstraints &cst) { return areVarsUnique(cst, 0, cst.getNumDimAndSymbolVars()); } /// Checks if the SSA values associated with `cst`'s variables of kind `kind` /// are unique. -static bool LLVM_ATTRIBUTE_UNUSED +[[maybe_unused]] static bool areVarsUnique(const FlatLinearValueConstraints &cst, VarKind kind) { if (kind == VarKind::SetDim) diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp index a1cbe29..547a4c2 100644 --- a/mlir/lib/Analysis/Presburger/Simplex.cpp +++ b/mlir/lib/Analysis/Presburger/Simplex.cpp @@ -34,7 +34,7 @@ using Direction = Simplex::Direction; const int nullIndex = std::numeric_limits<int>::max(); // Return a + scale*b; -LLVM_ATTRIBUTE_UNUSED +[[maybe_unused]] static SmallVector<DynamicAPInt, 8> scaleAndAddForAssert(ArrayRef<DynamicAPInt> a, const DynamicAPInt &scale, ArrayRef<DynamicAPInt> b) { diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 7b17106..06d0256 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2730,6 +2730,17 @@ public: operation->get(), toMlirStringRef(name))); } + static void + forEachAttr(MlirOperation op, + llvm::function_ref<void(MlirStringRef, MlirAttribute)> fn) { + intptr_t n = mlirOperationGetNumAttributes(op); + for (intptr_t i = 0; i < n; ++i) { + MlirNamedAttribute na = mlirOperationGetAttribute(op, i); + MlirStringRef name = mlirIdentifierStr(na.name); + fn(name, na.attribute); + } + } + static void bind(nb::module_ &m) { nb::class_<PyOpAttributeMap>(m, "OpAttributeMap") .def("__contains__", &PyOpAttributeMap::dunderContains) @@ -2737,7 +2748,50 @@ public: .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) .def("__setitem__", &PyOpAttributeMap::dunderSetItem) - .def("__delitem__", &PyOpAttributeMap::dunderDelItem); + .def("__delitem__", &PyOpAttributeMap::dunderDelItem) + .def("__iter__", + [](PyOpAttributeMap &self) { + nb::list keys; + PyOpAttributeMap::forEachAttr( + self.operation->get(), + [&](MlirStringRef name, MlirAttribute) { + keys.append(nb::str(name.data, name.length)); + }); + return nb::iter(keys); + }) + .def("keys", + [](PyOpAttributeMap &self) { + nb::list out; + PyOpAttributeMap::forEachAttr( + self.operation->get(), + [&](MlirStringRef name, MlirAttribute) { + out.append(nb::str(name.data, name.length)); + }); + return out; + }) + .def("values", + [](PyOpAttributeMap &self) { + nb::list out; + PyOpAttributeMap::forEachAttr( + self.operation->get(), + [&](MlirStringRef, MlirAttribute attr) { + out.append(PyAttribute(self.operation->getContext(), attr) + .maybeDownCast()); + }); + return out; + }) + .def("items", [](PyOpAttributeMap &self) { + nb::list out; + PyOpAttributeMap::forEachAttr( + self.operation->get(), + [&](MlirStringRef name, MlirAttribute attr) { + out.append(nb::make_tuple( + nb::str(name.data, name.length), + PyAttribute(self.operation->getContext(), attr) + .maybeDownCast())); + }); + return out; + }); } private: diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 5ddb3fb..0f0ed22 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -205,7 +205,7 @@ public: nb::object res = f(opView, PyPatternRewriter(rewriter)); return logicalResultFromObject(res); }; - MlirRewritePattern pattern = mlirOpRewritePattenCreate( + MlirRewritePattern pattern = mlirOpRewritePatternCreate( rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(), /* nGeneratedNames */ 0, /* generatedNames */ nullptr); diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index 46c329d..41ceb15 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -341,7 +341,7 @@ private: } // namespace mlir -MlirRewritePattern mlirOpRewritePattenCreate( +MlirRewritePattern mlirOpRewritePatternCreate( MlirStringRef rootName, unsigned benefit, MlirContext context, MlirRewritePatternCallbacks callbacks, void *userData, size_t nGeneratedNames, MlirStringRef *generatedNames) { diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 71986f8..bebf1b8 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -40,6 +40,7 @@ add_subdirectory(MathToLibm) add_subdirectory(MathToLLVM) add_subdirectory(MathToROCDL) add_subdirectory(MathToSPIRV) +add_subdirectory(MathToXeVM) add_subdirectory(MemRefToEmitC) add_subdirectory(MemRefToLLVM) add_subdirectory(MemRefToSPIRV) diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp index f0d8b78..610ce1f 100644 --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -407,11 +407,11 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> { if (auto vectorType = dyn_cast<VectorType>(operandType)) nanAttr = DenseElementsAttr::get(vectorType, nan); - Value NanValue = + Value nanValue = spirv::ConstantOp::create(rewriter, loc, operandType, nanAttr); Value lhs = spirv::SelectOp::create(rewriter, loc, cmpNegativeWithFractionalExp, - NanValue, adaptor.getLhs()); + nanValue, adaptor.getLhs()); Value abs = spirv::GLFAbsOp::create(rewriter, loc, lhs); // TODO: The following just forcefully casts y into an integer value in diff --git a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt new file mode 100644 index 0000000..050c0ed --- /dev/null +++ b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt @@ -0,0 +1,22 @@ +add_mlir_conversion_library(MLIRMathToXeVM + MathToXeVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToXeVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRArithAttrToLLVMConversion + MLIRArithDialect + MLIRLLVMCommonConversion + MLIRLLVMDialect + MLIRMathDialect + MLIRXeVMDialect + MLIRPass + MLIRTransforms + ) diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp new file mode 100644 index 0000000..0fe31d0 --- /dev/null +++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp @@ -0,0 +1,167 @@ +//===-- MathToXeVM.cpp - conversion from Math to XeVM ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/MathToXeVM/MathToXeVM.h" +#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" +#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/Pass/Pass.h" +#include "llvm/Support/FormatVariadic.h" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTMATHTOXEVM +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +#define DEBUG_TYPE "math-to-xevm" + +/// Convert math ops marked with `fast` (`afn`) to native OpenCL intrinsics. +template <typename Op> +struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> { + + ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc, + PatternBenefit benefit = 1) + : OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {} + + LogicalResult + matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isSPIRVCompatibleFloatOrVec(op.getType())) + return failure(); + + arith::FastMathFlags fastFlags = op.getFastmath(); + if (!arith::bitEnumContainsAll(fastFlags, arith::FastMathFlags::afn)) + return rewriter.notifyMatchFailure(op, "not a fastmath `afn` operation"); + + SmallVector<Type, 1> operandTypes; + for (auto operand : adaptor.getOperands()) { + Type opTy = operand.getType(); + // This pass only supports operations on vectors that are already in SPIRV + // supported vector sizes: Distributing unsupported vector sizes to SPIRV + // supported vector sizes are done in other blocking optimization passes. + if (!isSPIRVCompatibleFloatOrVec(opTy)) + return rewriter.notifyMatchFailure( + op, llvm::formatv("incompatible operand type: '{0}'", opTy)); + operandTypes.push_back(opTy); + } + + auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>(); + auto funcOpRes = LLVM::lookupOrCreateFn( + rewriter, moduleOp, getMangledNativeFuncName(operandTypes), + operandTypes, op.getType()); + assert(!failed(funcOpRes)); + LLVM::LLVMFuncOp funcOp = funcOpRes.value(); + + auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>( + op, funcOp, adaptor.getOperands()); + // Preserve fastmath flags in our MLIR op when converting to llvm function + // calls, in order to allow further fastmath optimizations: We thus need to + // convert arith fastmath attrs into attrs recognized by llvm. + arith::AttrConvertFastMathToLLVM<Op, LLVM::CallOp> fastAttrConverter(op); + mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0]; + callOp->setAttr(fastAttr.getName(), fastAttr.getValue()); + return success(); + } + + inline bool isSPIRVCompatibleFloatOrVec(Type type) const { + if (type.isFloat()) + return true; + if (auto vecType = dyn_cast<VectorType>(type)) { + if (!vecType.getElementType().isFloat()) + return false; + // SPIRV distinguishes between vectors and matrices: OpenCL native math + // intrsinics are not compatible with matrices. + ArrayRef<int64_t> shape = vecType.getShape(); + if (shape.size() != 1) + return false; + // SPIRV only allows vectors of size 2, 3, 4, 8, 16. + if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 || + shape[0] == 16) + return true; + } + return false; + } + + inline std::string + getMangledNativeFuncName(const ArrayRef<Type> operandTypes) const { + std::string mangledFuncName = + "_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str(); + + auto appendFloatToMangledFunc = [&mangledFuncName](Type type) { + if (type.isF32()) + mangledFuncName += "f"; + else if (type.isF16()) + mangledFuncName += "Dh"; + else if (type.isF64()) + mangledFuncName += "d"; + }; + + for (auto type : operandTypes) { + if (auto vecType = dyn_cast<VectorType>(type)) { + mangledFuncName += "Dv" + std::to_string(vecType.getShape()[0]) + "_"; + appendFloatToMangledFunc(vecType.getElementType()); + } else + appendFloatToMangledFunc(type); + } + + return mangledFuncName; + } + + const StringRef nativeFunc; +}; + +void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns, + bool convertArith) { + patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(patterns.getContext(), + "__spirv_ocl_native_exp"); + patterns.add<ConvertNativeFuncPattern<math::CosOp>>(patterns.getContext(), + "__spirv_ocl_native_cos"); + patterns.add<ConvertNativeFuncPattern<math::Exp2Op>>( + patterns.getContext(), "__spirv_ocl_native_exp2"); + patterns.add<ConvertNativeFuncPattern<math::LogOp>>(patterns.getContext(), + "__spirv_ocl_native_log"); + patterns.add<ConvertNativeFuncPattern<math::Log2Op>>( + patterns.getContext(), "__spirv_ocl_native_log2"); + patterns.add<ConvertNativeFuncPattern<math::Log10Op>>( + patterns.getContext(), "__spirv_ocl_native_log10"); + patterns.add<ConvertNativeFuncPattern<math::PowFOp>>( + patterns.getContext(), "__spirv_ocl_native_powr"); + patterns.add<ConvertNativeFuncPattern<math::RsqrtOp>>( + patterns.getContext(), "__spirv_ocl_native_rsqrt"); + patterns.add<ConvertNativeFuncPattern<math::SinOp>>(patterns.getContext(), + "__spirv_ocl_native_sin"); + patterns.add<ConvertNativeFuncPattern<math::SqrtOp>>( + patterns.getContext(), "__spirv_ocl_native_sqrt"); + patterns.add<ConvertNativeFuncPattern<math::TanOp>>(patterns.getContext(), + "__spirv_ocl_native_tan"); + if (convertArith) + patterns.add<ConvertNativeFuncPattern<arith::DivFOp>>( + patterns.getContext(), "__spirv_ocl_native_divide"); +} + +namespace { +struct ConvertMathToXeVMPass + : public impl::ConvertMathToXeVMBase<ConvertMathToXeVMPass> { + using Base::Base; + void runOnOperation() override; +}; +} // namespace + +void ConvertMathToXeVMPass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + populateMathToXeVMConversionPatterns(patterns, convertArith); + ConversionTarget target(getContext()); + target.addLegalDialect<BuiltinDialect, LLVM::LLVMDialect>(); + if (failed( + applyPartialConversion(getOperation(), target, std::move(patterns)))) + signalPassFailure(); +} diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index a5336ed..00df14b1 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1392,6 +1392,137 @@ public: } }; +// Collapse tensor<1xiN> into tensor<iN> +// E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16> +static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input, + Location loc) { + SmallVector<ReassociationExprs, 1> reassociation; + // Create the collapsed type + auto inputType = cast<RankedTensorType>(input.getType()); + auto elemType = inputType.getElementType(); + auto collapsedType = RankedTensorType::get({}, elemType); + // Emit the collapse op + return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, input, + reassociation); +} + +static llvm::SmallVector<int8_t> +convertToI8(const llvm::SmallVector<int32_t> &input) { + llvm::SmallVector<int8_t> output; + output.reserve(input.size()); + + for (auto v : llvm::map_range( + input, [](int32_t val) { return static_cast<int8_t>(val); })) { + output.push_back(v); + } + return output; +} + +// The shift or multiplier may be either constant or non-constant, depending on +// whether dynamic extension is enabled. +// - If the shift or multiplier is non-constant, add it as an input to +// linalg::GenericOp by: +// 1. Pushing it into 'genericInputs'. +// 2. Appending a corresponding affine map to 'indexingMaps'. +// - If the shift or multiplier is constant, set 'constant' instead. +static void setupLinalgGenericOpInputAndIndexingMap( + PatternRewriter &rewriter, llvm::SmallVector<int32_t> &values, + SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps, + bool isConstant, tosa::RescaleOp op, Value &constant, int64_t &arg, + bool isShift = false) { + + auto loc = op.getLoc(); + auto inputTy = cast<ShapedType>(op.getInput().getType()); + unsigned rank = inputTy.getRank(); + SmallVector<AffineExpr, 2> exprs = {rewriter.getAffineDimExpr(rank - 1)}; + + if (isConstant) { + // If we are rescaling per-channel then we need to store the + // values in a buffer. + if (values.size() == 1) { + IntegerAttr intAttr = isShift + ? rewriter.getI8IntegerAttr(values.front()) + : rewriter.getI32IntegerAttr(values.front()); + constant = rewriter.create<arith::ConstantOp>(loc, intAttr); + } else { + auto elementType = + isShift ? rewriter.getIntegerType(8) : rewriter.getI32Type(); + auto tensorType = RankedTensorType::get( + {static_cast<int64_t>(values.size())}, elementType); + DenseIntElementsAttr EltAttr; + if (isShift) + EltAttr = DenseIntElementsAttr::get(tensorType, convertToI8(values)); + else + EltAttr = DenseIntElementsAttr::get(tensorType, values); + genericInputs.push_back( + arith::ConstantOp::create(rewriter, loc, EltAttr)); + indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, + /*symbolCount=*/0, exprs, + rewriter.getContext())); + } + } else { + // If we are not rescaling per-channel then we need to collapse 1xN to N + // and push broadcastMap. + auto operand = isShift ? op.getShift() : op.getMultiplier(); + auto tensorType = dyn_cast<RankedTensorType>(operand.getType()); + if (tensorType && tensorType.hasStaticShape() && + tensorType.getShape()[0] == 1) { + // broadcastMap = affine_map<(d0, d1) -> ()> + // It would affect as broadcast for scalar values in linalg::GenericOp. + AffineMap broadcastMap = + AffineMap::get(rank, 0, {}, rewriter.getContext()); + genericInputs.push_back(collapse1xNTensorToN(rewriter, operand, loc)); + indexingMaps.push_back(broadcastMap); + } else { + genericInputs.push_back(operand); + indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, + /*symbolCount=*/0, exprs, + rewriter.getContext())); + } + } + arg = indexingMaps.size() - 1; +} + +// Return the extended Zp to be used in subsequent arithmetic operations. +static Value getExtendZp(OpBuilder &builder, Type valueTy, + FailureOr<int64_t> maybeZp, Location loc, + ValueRange blockArgs, int64_t zpArg, + bool isOutputZp = false) { + Value result; + const int32_t bitwidth = valueTy.getIntOrFloatBitWidth(); + const uint32_t attrBitwidth = + isOutputZp ? 32 : (bitwidth > 32 ? bitwidth : 32); + auto extendType = builder.getIntegerType(attrBitwidth); + // The Zp value can be either constant or non-constant, depending on + // whether dynamic extension is enabled. + // If 'maybeZp' fails, it indicates that Zp is non-constant and will + // be passed as an input to linalg::GenericOp. + if (failed(maybeZp)) { + result = blockArgs[zpArg]; + auto zpTy = result.getType(); + if (zpTy.getIntOrFloatBitWidth() < attrBitwidth) { + // For ExtUIOp, the input must be signless. + // UnrealizedConversionCastOp will cast the input to signless type. + if (zpTy.isUnsignedInteger()) { + result = + UnrealizedConversionCastOp::create( + builder, loc, + builder.getIntegerType(zpTy.getIntOrFloatBitWidth()), result) + .getResult(0); + } + if (zpTy.isUnsignedInteger()) { + return builder.create<arith::ExtUIOp>(loc, extendType, result); + } else { + return builder.create<arith::ExtSIOp>(loc, extendType, result); + } + } + } else { + return builder.create<arith::ConstantOp>( + loc, IntegerAttr::get(extendType, *maybeZp)); + } + return result; +} + class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> { public: using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern; @@ -1423,40 +1554,46 @@ public: } } - // The shift and multiplier values. DenseElementsAttr shiftElems; - if (!matchPattern(op.getShift(), m_Constant(&shiftElems))) - return rewriter.notifyMatchFailure( - op, "tosa.rescale requires constant shift input values"); + bool isShiftConstant = false; + if (matchPattern(op.getShift(), m_Constant(&shiftElems))) + isShiftConstant = true; DenseElementsAttr multiplierElems; - if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems))) - return rewriter.notifyMatchFailure( - op, "tosa.rescale requires constant multiplier input values"); - - llvm::SmallVector<int8_t> shiftValues = - llvm::to_vector(shiftElems.getValues<int8_t>()); - // explicit cast is required here - llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector( - llvm::map_range(multiplierElems.getValues<IntegerAttr>(), - [](IntegerAttr attr) -> int32_t { - return static_cast<int32_t>(attr.getInt()); - })); - - // If we shift by more than the bitwidth, this just sets to 0. - for (int i = 0, s = multiplierValues.size(); i < s; i++) { - if (shiftValues[i] > 63) { - shiftValues[i] = 0; - multiplierValues[i] = 0; + bool isMultiplierConstant = false; + if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems))) + isMultiplierConstant = true; + + llvm::SmallVector<int32_t> shiftValues; + llvm::SmallVector<int32_t> multiplierValues; + bool doubleRound; + + if (isMultiplierConstant && isShiftConstant) { + // explicit cast is required here + shiftValues = llvm::to_vector(llvm::map_range( + shiftElems.getValues<IntegerAttr>(), [](IntegerAttr attr) -> int32_t { + return static_cast<int32_t>(attr.getInt()); + })); + multiplierValues = llvm::to_vector( + llvm::map_range(multiplierElems.getValues<IntegerAttr>(), + [](IntegerAttr attr) -> int32_t { + return static_cast<int32_t>(attr.getInt()); + })); + + // If we shift by more than the bitwidth, this just sets to 0. + for (int i = 0, s = multiplierValues.size(); i < s; i++) { + if (shiftValues[i] > 63) { + shiftValues[i] = 0; + multiplierValues[i] = 0; + } } - } + // Double round only occurs if shift is greater than 31, check that this + // is ever true. + doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND && + llvm::any_of(shiftValues, [](int32_t v) { return v > 31; }); + } else + doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND; - // Double round only occurs if shift is greater than 31, check that this - // is ever true. - - bool doubleRound = - op.getRoundingMode() == RoundingMode::DOUBLE_ROUND && - llvm::any_of(shiftValues, [](int32_t v) { return v > 31; }); RoundingMode roundingMode = doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND; @@ -1468,45 +1605,43 @@ public: // values in a buffer. Value multiplierConstant; int64_t multiplierArg = 0; - if (multiplierValues.size() == 1) { - multiplierConstant = arith::ConstantOp::create( - rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front())); - } else { - SmallVector<AffineExpr, 2> multiplierExprs{ - rewriter.getAffineDimExpr(rank - 1)}; - auto multiplierType = - RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())}, - rewriter.getI32Type()); - genericInputs.push_back(arith::ConstantOp::create( - rewriter, loc, - DenseIntElementsAttr::get(multiplierType, multiplierValues))); - - indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, - /*symbolCount=*/0, multiplierExprs, - rewriter.getContext())); - - multiplierArg = indexingMaps.size() - 1; - } + setupLinalgGenericOpInputAndIndexingMap( + rewriter, multiplierValues, genericInputs, indexingMaps, + isMultiplierConstant, op, multiplierConstant, multiplierArg); // If we are rescaling per-channel then we need to store the shift // values in a buffer. Value shiftConstant; int64_t shiftArg = 0; - if (shiftValues.size() == 1) { - shiftConstant = arith::ConstantOp::create( - rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front())); - } else { - SmallVector<AffineExpr, 2> shiftExprs = { - rewriter.getAffineDimExpr(rank - 1)}; - auto shiftType = - RankedTensorType::get({static_cast<int64_t>(shiftValues.size())}, - rewriter.getIntegerType(8)); - genericInputs.push_back(arith::ConstantOp::create( - rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues))); - indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, - /*symbolCount=*/0, shiftExprs, - rewriter.getContext())); - shiftArg = indexingMaps.size() - 1; + setupLinalgGenericOpInputAndIndexingMap( + rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op, + shiftConstant, shiftArg, true); + + // broadcastMap = affine_map<(d0, d1) -> ()> + // It would affect as broadcast for scalar values in linalg::GenericOp. + AffineMap broadcastMap = AffineMap::get(rank, 0, {}, rewriter.getContext()); + FailureOr<int64_t> maybeIZp = op.getInputZeroPoint(); + FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint(); + // The inputZp and outputZp may be either constant or non-constant, + // depending on whether dynamic extension is enabled. + // - If the zp's are non-constant, add them as an inputs to + // linalg::GenericOp by: + // 1. Pushing it into 'genericInputs'. + // 2. Appending a corresponding affine map to 'indexingMaps'. + // - If the zp's are constant, they would be generated as arith.constant. + int64_t iZpArg = 0; + if (failed(maybeIZp)) { + genericInputs.push_back( + collapse1xNTensorToN(rewriter, op->getOperand(3), loc)); + indexingMaps.push_back(broadcastMap); + iZpArg = indexingMaps.size() - 1; + } + int64_t oZpArg = 0; + if (failed(maybeOZp)) { + genericInputs.push_back( + collapse1xNTensorToN(rewriter, op->getOperand(4), loc)); + indexingMaps.push_back(broadcastMap); + oZpArg = indexingMaps.size() - 1; } // Indexing maps for output values. @@ -1526,36 +1661,17 @@ public: Type valueTy = value.getType(); FailureOr<int64_t> maybeIZp = op.getInputZeroPoint(); - if (failed(maybeIZp)) { - (void)rewriter.notifyMatchFailure( - op, "input zero point cannot be statically determined"); - return; - } - - const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth(); - // Extend zeropoint for sub-32bits widths. - const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32; - auto inputZp = arith::ConstantOp::create( - nestedBuilder, loc, - IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth), - *maybeIZp)); + auto inputZp = getExtendZp(nestedBuilder, valueTy, maybeIZp, + nestedLoc, blockArgs, iZpArg); FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint(); - if (failed(maybeOZp)) { - (void)rewriter.notifyMatchFailure( - op, "output zero point cannot be statically determined"); - return; - }; + auto outputZp = getExtendZp(nestedBuilder, valueTy, maybeOZp, + nestedLoc, blockArgs, oZpArg, true); IntegerType outIntType = cast<IntegerType>(blockArgs.back().getType()); unsigned outBitWidth = outIntType.getWidth(); - const int32_t outAttrBitwidth = 32; assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth"); - auto outputZp = arith::ConstantOp::create( - nestedBuilder, loc, - IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth), - *maybeOZp)); Value multiplier = multiplierConstant ? multiplierConstant : blockArgs[multiplierArg]; diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 5355909..41d8d53 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1723,17 +1723,18 @@ struct VectorBroadcastScalarToLowRankLowering return success(); } - // For 1-d vector, we additionally do a `vectorshuffle`. auto v = LLVM::InsertElementOp::create(rewriter, broadcast.getLoc(), vectorType, poison, adaptor.getSource(), zero); + // For 1-d vector, we additionally do a `shufflevector`. int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0); SmallVector<int32_t> zeroValues(width, 0); // Shuffle the value across the desired number of elements. - rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(broadcast, v, poison, - zeroValues); + auto shuffle = rewriter.createOrFold<LLVM::ShuffleVectorOp>( + broadcast.getLoc(), v, poison, zeroValues); + rewriter.replaceOp(broadcast, shuffle); return success(); } }; diff --git a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt index 84b2580..dd9edc4 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt +++ b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt @@ -21,6 +21,7 @@ add_mlir_conversion_library(MLIRXeGPUToXeVM MLIRIndexDialect MLIRSCFDialect MLIRXeGPUDialect + MLIRXeGPUUtils MLIRPass MLIRTransforms MLIRSCFTransforms diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 71687b1..fcbf66d 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -20,7 +20,9 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/STLExtras.h" @@ -62,6 +64,7 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) { case xegpu::MemorySpace::SLM: return static_cast<int>(xevm::AddrSpace::SHARED); } + llvm_unreachable("Unknown XeGPU memory space"); } // Get same bitwidth flat vector type of new element type. @@ -185,6 +188,7 @@ class CreateNdDescToXeVMPattern int64_t rank = mixedSizes.size(); if (rank != 2) return rewriter.notifyMatchFailure(op, "Expected 2D shape."); + auto sourceTy = source.getType(); auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy); // If source is a memref, we need to extract the aligned pointer as index. @@ -363,10 +367,11 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> { // Add a builder that creates // offset * elemByteSize + baseAddr -static Value addOffset(ConversionPatternRewriter &rewriter, Location loc, - Value baseAddr, Value offset, int64_t elemByteSize) { +static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter, + Location loc, Value baseAddr, Value offset, + int64_t elemByteSize) { Value byteSize = arith::ConstantIntOp::create( - rewriter, loc, rewriter.getI64Type(), elemByteSize); + rewriter, loc, baseAddr.getType(), elemByteSize); Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize); Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset); return newAddr; @@ -390,7 +395,8 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> { // Load result or Store valye Type can be vector or scalar. Type valOrResTy; if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) - valOrResTy = op.getResult().getType(); + valOrResTy = + this->getTypeConverter()->convertType(op.getResult().getType()); else valOrResTy = adaptor.getValue().getType(); VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy); @@ -441,7 +447,8 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> { // If offset is provided, we add them to the base pointer. // Offset is in number of elements, we need to multiply by // element byte size. - basePtrI64 = addOffset(rewriter, loc, basePtrI64, offset, elemByteSize); + basePtrI64 = + addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize); } // Convert base pointer (i64) to LLVM pointer type. Value basePtrLLVM = @@ -504,6 +511,147 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> { } }; +// Lower xegpu::CreateMemDescOp to memref::ViewOp. Since SLM access instructions +// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than +// 32 bits will be converted to 32 bits. +class CreateMemDescOpPattern final + : public OpConversionPattern<xegpu::CreateMemDescOp> { +public: + using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto resTy = op.getMemDesc(); + + // Create the result MemRefType with the same shape, element type, and + // memory space + auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy); + + Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0); + auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy, + op.getSource(), zero, ValueRange()); + rewriter.replaceOp(op, viewOp); + return success(); + } +}; + +template <typename OpType, + typename = std::enable_if_t<llvm::is_one_of< + OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>> +class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> { + using OpConversionPattern<OpType>::OpConversionPattern; + LogicalResult + matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + SmallVector<OpFoldResult> offsets = op.getMixedOffsets(); + if (offsets.empty()) + return rewriter.notifyMatchFailure(op, "Expected offset to be provided."); + + auto loc = op.getLoc(); + auto ctxt = rewriter.getContext(); + Value basePtrStruct = adaptor.getMemDesc(); + Value mdescVal = op.getMemDesc(); + // Load result or Store value Type can be vector or scalar. + Value data; + if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) + data = op.getResult(); + else + data = adaptor.getData(); + VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType()); + if (!valOrResVecTy) + valOrResVecTy = VectorType::get(1, data.getType()); + + int64_t elemBitWidth = + valOrResVecTy.getElementType().getIntOrFloatBitWidth(); + // Element type must be multiple of 8 bits. + if (elemBitWidth % 8 != 0) + return rewriter.notifyMatchFailure( + op, "Expected element type bit width to be multiple of 8."); + int64_t elemByteSize = elemBitWidth / 8; + + // Default memory space is SLM. + LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get( + ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM)); + + auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType()); + + Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create( + rewriter, loc, basePtrStruct); + + // Convert base pointer (ptr) to i32 + Value basePtrI32 = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI32Type(), basePtrLLVM); + + Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets); + linearOffset = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI32Type(), linearOffset); + basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset, + elemByteSize); + + // convert base pointer (i32) to LLVM pointer type + basePtrLLVM = + LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32); + + if (op.getSubgroupBlockIoAttr()) { + // if the attribute 'subgroup_block_io' is set to true, it lowers to + // xevm.blockload + + Type intElemTy = rewriter.getIntegerType(elemBitWidth); + VectorType intVecTy = + VectorType::get(valOrResVecTy.getShape(), intElemTy); + + if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) { + Value loadOp = + xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM); + if (intVecTy != valOrResVecTy) { + loadOp = + vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp); + } + rewriter.replaceOp(op, loadOp); + } else { + Value dataToStore = adaptor.getData(); + if (valOrResVecTy != intVecTy) { + dataToStore = + vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore); + } + xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore, + nullptr); + rewriter.eraseOp(op); + } + return success(); + } + + if (valOrResVecTy.getNumElements() >= 1) { + auto chipOpt = xegpu::getChipStr(op); + if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) { + // the lowering for chunk load only works for pvc and bmg + return rewriter.notifyMatchFailure( + op, "The lowering is specific to pvc or bmg."); + } + } + + if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) { + // if the size of valOrResVecTy is 1, it lowers to a scalar load/store + // operation. LLVM load/store does not support vector of size 1, so we + // need to handle this case separately. + auto scalarTy = valOrResVecTy.getElementType(); + LLVM::LoadOp loadOp; + if (valOrResVecTy.getNumElements() == 1) + loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM); + else + loadOp = + LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM); + rewriter.replaceOp(op, loadOp); + } else { + LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM); + rewriter.eraseOp(op); + } + return success(); + } +}; + class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -546,8 +694,8 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> { op, "Expected element type bit width to be multiple of 8."); elemByteSize = elemBitWidth / 8; } - basePtrI64 = - addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize); + basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets, + elemByteSize); } } // Default memory space is global. @@ -784,6 +932,13 @@ struct ConvertXeGPUToXeVMPass auto i32Type = IntegerType::get(&getContext(), 32); return VectorType::get(8, i32Type); }); + // Convert MemDescType into flattened MemRefType for SLM + typeConverter.addConversion([&](xegpu::MemDescType type) -> Type { + Type elemTy = type.getElementType(); + int numElems = type.getNumElements(); + return MemRefType::get(numElems, elemTy, AffineMap(), 3); + }); + typeConverter.addConversion([&](MemRefType type) -> Type { // Convert MemRefType to i64 type. return IntegerType::get(&getContext(), 64); @@ -878,10 +1033,30 @@ struct ConvertXeGPUToXeVMPass } return {}; }; - typeConverter.addSourceMaterialization(memrefMaterializationCast); - typeConverter.addSourceMaterialization(ui64MaterializationCast); - typeConverter.addSourceMaterialization(ui32MaterializationCast); - typeConverter.addSourceMaterialization(vectorMaterializationCast); + + // If result type of original op is single element vector and lowered type + // is scalar. This materialization cast creates a single element vector by + // broadcasting the scalar value. + auto singleElementVectorMaterializationCast = + [](OpBuilder &builder, Type type, ValueRange inputs, + Location loc) -> Value { + if (inputs.size() != 1) + return {}; + auto input = inputs.front(); + if (input.getType().isIntOrIndexOrFloat()) { + // If the input is a scalar, and the target type is a vector of single + // element, create a single element vector by broadcasting. + if (auto vecTy = dyn_cast<VectorType>(type)) { + if (vecTy.getNumElements() == 1) { + return vector::BroadcastOp::create(builder, loc, vecTy, input) + .getResult(); + } + } + } + return {}; + }; + typeConverter.addSourceMaterialization( + singleElementVectorMaterializationCast); typeConverter.addTargetMaterialization(memrefMaterializationCast); typeConverter.addTargetMaterialization(ui32MaterializationCast); typeConverter.addTargetMaterialization(ui64MaterializationCast); @@ -918,6 +1093,9 @@ void mlir::populateXeGPUToXeVMConversionPatterns( LoadStoreToXeVMPattern<xegpu::LoadGatherOp>, LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>( typeConverter, patterns.getContext()); + patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>, + LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>, + CreateMemDescOpPattern>(typeConverter, patterns.getContext()); patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter, patterns.getContext()); } diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index f405d0c..c798adb 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -757,13 +757,13 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> { offset = numElements - 4l; } Type scaleSrcElemType = scaleSrcType.getElementType(); - auto newSrcType = VectorType::get(SmallVector<int64_t>({numElements}), - scaleSrcElemType); + auto newSrcType = + VectorType::get(ArrayRef{numElements}, scaleSrcElemType); Value newScaleSrc = vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc); auto extract = vector::ExtractStridedSliceOp::create( - rewriter, loc, newScaleSrc, ArrayRef<int64_t>{offset}, - ArrayRef<int64_t>{size}, ArrayRef<int64_t>{1}); + rewriter, loc, newScaleSrc, ArrayRef{offset}, ArrayRef{size}, + ArrayRef{int64_t(1)}); rewriter.modifyOpInPlace(op, [&] { op->setOperand(opIdx, extract); setOpsel(opIdx, opsel); diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp index 68990ef..d9c097c 100644 --- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp +++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp @@ -80,10 +80,22 @@ static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType, LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)}; } +/// Returns stride expressed in number of bytes for the given `elementStride` +/// stride encoded in number of elements of the type `mType`. +static Value computeStrideInBytes(Location loc, MemRefType mType, + Value elementStride, RewriterBase &rewriter) { + Type llvmInt64Type = rewriter.getIntegerType(64); + unsigned bytes = mType.getElementType().getIntOrFloatBitWidth() / 8; + auto attr = rewriter.getI64IntegerAttr(bytes); + Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr); + return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, elementStride) + .getResult(); +} + /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer /// shape may "envelop" the actual tile shape, and may be dynamically sized. -static Value getStride(Location loc, MemRefType mType, Value base, - RewriterBase &rewriter) { +static Value inferStride(Location loc, MemRefType mType, Value base, + RewriterBase &rewriter) { assert(mType.getRank() >= 2 && "Invalid shape for AMX strides"); int64_t preLast = mType.getRank() - 2; Type llvmInt64Type = rewriter.getIntegerType(64); @@ -94,11 +106,8 @@ static Value getStride(Location loc, MemRefType mType, Value base, if (strides[preLast] == ShapedType::kDynamic) { // Dynamic stride needs code to compute the stride at runtime. MemRefDescriptor memrefDescriptor(base); - auto attr = rewriter.getI64IntegerAttr(bytes); - Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr); - return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, - memrefDescriptor.stride(rewriter, loc, preLast)) - .getResult(); + return computeStrideInBytes( + loc, mType, memrefDescriptor.stride(rewriter, loc, preLast), rewriter); } // Use direct constant for static stride. auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes); @@ -117,21 +126,39 @@ amx::TileZeroOp::getIntrinsicOperands(ArrayRef<Value> operands, return getTileSizes(getLoc(), getTileType(), rewriter); } -LogicalResult amx::TileLoadOp::verify() { - MemRefType memrefTy = getMemRefType(); +template <typename OpTy, + typename = std::enable_if_t<std::is_same_v<OpTy, amx::TileLoadOp> || + std::is_same_v<OpTy, amx::TileStoreOp>>> +static LogicalResult tileTransferVerifier(OpTy op) { + MemRefType memrefTy = op.getMemRefType(); unsigned rank = memrefTy.getRank(); - if (rank < 2) - return emitOpError("requires at least 2D memref"); - if (getIndices().size() != rank) - return emitOpError("requires ") << rank << " indices"; - SmallVector<int64_t> strides; - int64_t offset; - if (failed(memrefTy.getStridesAndOffset(strides, offset)) || - strides.back() != 1) - return emitOpError("requires memref with unit innermost stride"); - return verifyTileSize(*this, getTileType()); + if (op.getIndices().size() != rank) + return op.emitOpError("requires ") << rank << " indices"; + + if (failed(verifyTileSize(op, op.getTileType()))) + return failure(); + + // Validate basic buffer properties when the stride is implicit. + if (!op.getStride()) { + if (rank < 2) + return op.emitOpError("requires at least 2D memref"); + SmallVector<int64_t> strides; + int64_t offset; + if (failed(memrefTy.getStridesAndOffset(strides, offset)) || + strides.back() != 1) + return op.emitOpError("requires memref with unit innermost stride"); + } + + return success(); +} + +void amx::TileLoadOp::build(OpBuilder &builder, OperationState &state, Type res, + Value base, ValueRange indices) { + build(builder, state, res, base, indices, /*stride=*/nullptr); } +LogicalResult amx::TileLoadOp::verify() { return tileTransferVerifier(*this); } + SmallVector<Value> amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter, @@ -144,27 +171,23 @@ amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands, intrinsicOperands.push_back( LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(), adaptor.getBase(), adaptor.getIndices())); - intrinsicOperands.push_back( - getStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); + if (Value stride = adaptor.getStride()) + intrinsicOperands.push_back( + computeStrideInBytes(loc, getMemRefType(), stride, rewriter)); + else + intrinsicOperands.push_back( + inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); return intrinsicOperands; } -LogicalResult amx::TileStoreOp::verify() { - MemRefType memrefTy = getMemRefType(); - unsigned rank = memrefTy.getRank(); - if (rank < 2) - return emitOpError("requires at least 2D memref"); - if (getIndices().size() != rank) - return emitOpError("requires ") << rank << " indices"; - SmallVector<int64_t> strides; - int64_t offset; - if (failed(memrefTy.getStridesAndOffset(strides, offset)) || - strides.back() != 1) - return emitOpError("requires memref with unit innermost stride"); - return verifyTileSize(*this, getTileType()); +void amx::TileStoreOp::build(OpBuilder &builder, OperationState &state, + Value base, ValueRange indices, Value val) { + build(builder, state, base, indices, val, /*stride=*/nullptr); } +LogicalResult amx::TileStoreOp::verify() { return tileTransferVerifier(*this); } + SmallVector<Value> amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter, @@ -177,8 +200,12 @@ amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands, intrinsicOperands.push_back( LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(), adaptor.getBase(), adaptor.getIndices())); - intrinsicOperands.push_back( - getStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); + if (Value stride = adaptor.getStride()) + intrinsicOperands.push_back( + computeStrideInBytes(loc, getMemRefType(), stride, rewriter)); + else + intrinsicOperands.push_back( + inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); intrinsicOperands.push_back(adaptor.getVal()); return intrinsicOperands; diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 7e5ce26..749e2ba 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -125,9 +125,9 @@ static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest, // Use "unused attribute" marker to silence clang-tidy warning stemming from // the inability to see through "llvm::TypeSwitch". template <> -bool LLVM_ATTRIBUTE_UNUSED remainsLegalAfterInline(AffineApplyOp op, - Region *src, Region *dest, - const IRMapping &mapping) { +[[maybe_unused]] bool remainsLegalAfterInline(AffineApplyOp op, Region *src, + Region *dest, + const IRMapping &mapping) { // If it's a valid dimension, we need to check that it remains so. if (isValidDim(op.getResult(), src)) return remainsLegalAfterInline( @@ -1032,8 +1032,8 @@ static void simplifyMinOrMaxExprWithOperands(AffineMap &map, /// Simplify the map while exploiting information on the values in `operands`. // Use "unused attribute" marker to silence warning stemming from the inability // to see through the template expansion. -static void LLVM_ATTRIBUTE_UNUSED -simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) { +[[maybe_unused]] static void simplifyMapWithOperands(AffineMap &map, + ArrayRef<Value> operands) { assert(map.getNumInputs() == operands.size() && "invalid operands for map"); SmallVector<AffineExpr> newResults; newResults.reserve(map.getNumResults()); @@ -1125,6 +1125,141 @@ static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp, return success(*map != initialMap); } +/// Recursively traverse `e`. If `e` or one of its sub-expressions has the form +/// e1 + e2 + ... + eK, where the e_i are a super(multi)set of `exprsToRemove`, +/// place a map between e and `newVal` + sum({e1, e2, .. eK} - exprsToRemove) +/// into `replacementsMap`. If no entries were added to `replacementsMap`, +/// nothing was found. +static void shortenAddChainsContainingAll( + AffineExpr e, const llvm::SmallDenseSet<AffineExpr, 4> &exprsToRemove, + AffineExpr newVal, DenseMap<AffineExpr, AffineExpr> &replacementsMap) { + auto binOp = dyn_cast<AffineBinaryOpExpr>(e); + if (!binOp) + return; + AffineExpr lhs = binOp.getLHS(); + AffineExpr rhs = binOp.getRHS(); + if (binOp.getKind() != AffineExprKind::Add) { + shortenAddChainsContainingAll(lhs, exprsToRemove, newVal, replacementsMap); + shortenAddChainsContainingAll(rhs, exprsToRemove, newVal, replacementsMap); + return; + } + SmallVector<AffineExpr> toPreserve; + llvm::SmallDenseSet<AffineExpr, 4> ourTracker(exprsToRemove); + AffineExpr thisTerm = rhs; + AffineExpr nextTerm = lhs; + + while (thisTerm) { + if (!ourTracker.erase(thisTerm)) { + toPreserve.push_back(thisTerm); + shortenAddChainsContainingAll(thisTerm, exprsToRemove, newVal, + replacementsMap); + } + auto nextBinOp = dyn_cast_if_present<AffineBinaryOpExpr>(nextTerm); + if (!nextBinOp || nextBinOp.getKind() != AffineExprKind::Add) { + thisTerm = nextTerm; + nextTerm = AffineExpr(); + } else { + thisTerm = nextBinOp.getRHS(); + nextTerm = nextBinOp.getLHS(); + } + } + if (!ourTracker.empty()) + return; + // We reverse the terms to be preserved here in order to preserve + // associativity between them. + AffineExpr newExpr = newVal; + for (AffineExpr preserved : llvm::reverse(toPreserve)) + newExpr = newExpr + preserved; + replacementsMap.insert({e, newExpr}); +} + +/// If this map contains of the expression `x_1 + x_1 * C_1 + ... x_n * C_N + +/// ...` (not necessarily in order) where the set of the `x_i` is the set of +/// outputs of an `affine.delinearize_index` whos inverse is that expression, +/// replace that expression with the input of that delinearize_index op. +/// +/// `unitDimInput` is the input that was detected as the potential start to this +/// replacement chain - if it isn't the rightmost result of the delinearization, +/// this method fails. (This is intended to ensure we don't have redundant scans +/// over the same expression). +/// +/// While this currently only handles delinearizations with a constant basis, +/// that isn't a fundamental limitation. +/// +/// This is a utility function for `replaceDimOrSym` below. +static LogicalResult replaceAffineDelinearizeIndexInverseExpression( + AffineDelinearizeIndexOp delinOp, Value resultToReplace, AffineMap *map, + SmallVectorImpl<Value> &dims, SmallVectorImpl<Value> &syms) { + if (!delinOp.getDynamicBasis().empty()) + return failure(); + if (resultToReplace != delinOp.getMultiIndex().back()) + return failure(); + + MLIRContext *ctx = delinOp.getContext(); + SmallVector<AffineExpr> resToExpr(delinOp.getNumResults(), AffineExpr()); + for (auto [pos, dim] : llvm::enumerate(dims)) { + auto asResult = dyn_cast_if_present<OpResult>(dim); + if (!asResult) + continue; + if (asResult.getOwner() == delinOp.getOperation()) + resToExpr[asResult.getResultNumber()] = getAffineDimExpr(pos, ctx); + } + for (auto [pos, sym] : llvm::enumerate(syms)) { + auto asResult = dyn_cast_if_present<OpResult>(sym); + if (!asResult) + continue; + if (asResult.getOwner() == delinOp.getOperation()) + resToExpr[asResult.getResultNumber()] = getAffineSymbolExpr(pos, ctx); + } + if (llvm::is_contained(resToExpr, AffineExpr())) + return failure(); + + bool isDimReplacement = llvm::all_of(resToExpr, llvm::IsaPred<AffineDimExpr>); + int64_t stride = 1; + llvm::SmallDenseSet<AffineExpr, 4> expectedExprs; + // This isn't zip_equal since sometimes the delinearize basis is missing a + // size for the first result. + for (auto [binding, size] : llvm::zip( + llvm::reverse(resToExpr), llvm::reverse(delinOp.getStaticBasis()))) { + expectedExprs.insert(binding * getAffineConstantExpr(stride, ctx)); + stride *= size; + } + if (resToExpr.size() != delinOp.getStaticBasis().size()) + expectedExprs.insert(resToExpr[0] * stride); + + DenseMap<AffineExpr, AffineExpr> replacements; + AffineExpr delinInExpr = isDimReplacement + ? getAffineDimExpr(dims.size(), ctx) + : getAffineSymbolExpr(syms.size(), ctx); + + for (AffineExpr e : map->getResults()) + shortenAddChainsContainingAll(e, expectedExprs, delinInExpr, replacements); + if (replacements.empty()) + return failure(); + + AffineMap origMap = *map; + if (isDimReplacement) + dims.push_back(delinOp.getLinearIndex()); + else + syms.push_back(delinOp.getLinearIndex()); + *map = origMap.replace(replacements, dims.size(), syms.size()); + + // Blank out dead dimensions and symbols + for (AffineExpr e : resToExpr) { + if (auto d = dyn_cast<AffineDimExpr>(e)) { + unsigned pos = d.getPosition(); + if (!map->isFunctionOfDim(pos)) + dims[pos] = nullptr; + } + if (auto s = dyn_cast<AffineSymbolExpr>(e)) { + unsigned pos = s.getPosition(); + if (!map->isFunctionOfSymbol(pos)) + syms[pos] = nullptr; + } + } + return success(); +} + /// Replace all occurrences of AffineExpr at position `pos` in `map` by the /// defining AffineApplyOp expression and operands. /// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced. @@ -1157,6 +1292,11 @@ static LogicalResult replaceDimOrSym(AffineMap *map, syms); } + if (auto delinOp = v.getDefiningOp<affine::AffineDelinearizeIndexOp>()) { + return replaceAffineDelinearizeIndexInverseExpression(delinOp, v, map, dims, + syms); + } + auto affineApply = v.getDefiningOp<AffineApplyOp>(); if (!affineApply) return failure(); diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp index cd216ef..4743941 100644 --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -1357,7 +1357,7 @@ bool mlir::affine::isValidLoopInterchangePermutation( /// Returns true if `loops` is a perfectly nested loop nest, where loops appear /// in it from outermost to innermost. -bool LLVM_ATTRIBUTE_UNUSED +[[maybe_unused]] bool mlir::affine::isPerfectlyNested(ArrayRef<AffineForOp> loops) { assert(!loops.empty() && "no loops provided"); @@ -1920,8 +1920,7 @@ generatePointWiseCopy(Location loc, Value memref, Value fastMemRef, return copyNestRoot; } -static InFlightDiagnostic LLVM_ATTRIBUTE_UNUSED -emitRemarkForBlock(Block &block) { +[[maybe_unused]] static InFlightDiagnostic emitRemarkForBlock(Block &block) { return block.getParentOp()->emitRemark(); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp index 624519f..70faa71 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp @@ -64,12 +64,13 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) { module.walk([&](func::CallOp callOp) { if (func::FuncOp calledFunc = dyn_cast_or_null<func::FuncOp>(callOp.resolveCallable())) { - callerMap[calledFunc].insert(callOp); + if (!calledFunc.isPublic() && !calledFunc.isExternal()) + callerMap[calledFunc].insert(callOp); } }); for (auto funcOp : module.getOps<func::FuncOp>()) { - if (funcOp.isExternal()) + if (funcOp.isExternal() || funcOp.isPublic()) continue; func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); // TODO: Support functions with multiple blocks. diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt index ec581ac..cc66fac 100644 --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -8,11 +8,13 @@ add_mlir_dialect_library(MLIRLLVMDialect IR/LLVMMemorySlot.cpp IR/LLVMTypes.cpp IR/LLVMTypeSyntax.cpp + IR/LLVMDialectBytecode.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR DEPENDS + MLIRLLVMDialectBytecodeIncGen MLIRLLVMOpsIncGen MLIRLLVMTypesIncGen MLIRLLVMIntrinsicOpsIncGen diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 5d08ccc..3eae67f 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -29,6 +29,8 @@ #include "llvm/IR/DataLayout.h" #include "llvm/Support/Error.h" +#include "LLVMDialectBytecode.h" + #include <numeric> #include <optional> @@ -2824,6 +2826,20 @@ LogicalResult ShuffleVectorOp::verify() { return success(); } +// Folding for shufflevector op when v1 is single element 1D vector +// and the mask is a single zero. OpFoldResult will be v1 in this case. +OpFoldResult ShuffleVectorOp::fold(FoldAdaptor adaptor) { + // Check if operand 0 is a single element vector. + auto vecType = llvm::dyn_cast<VectorType>(getV1().getType()); + if (!vecType || vecType.getRank() != 1 || vecType.getNumElements() != 1) + return {}; + // Check if the mask is a single zero. + // Note: The mask is guaranteed to be non-empty. + if (getMask().size() != 1 || getMask()[0] != 0) + return {}; + return getV1(); +} + //===----------------------------------------------------------------------===// // Implementations for LLVM::LLVMFuncOp. //===----------------------------------------------------------------------===// @@ -4237,6 +4253,7 @@ void LLVMDialect::initialize() { // Support unknown operations because not all LLVM operations are registered. allowUnknownOperations(); declarePromisedInterface<DialectInlinerInterface, LLVMDialect>(); + detail::addBytecodeInterface(this); } #define GET_OP_CLASSES diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp new file mode 100644 index 0000000..41d1f80 --- /dev/null +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp @@ -0,0 +1,154 @@ +//===- LLVMDialectBytecode.cpp - LLVM Bytecode Implementation -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "LLVMDialectBytecode.h" +#include "mlir/Bytecode/BytecodeImplementation.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include <type_traits> + +using namespace mlir; +using namespace mlir::LLVM; + +namespace { + +// Provide some forward declarations of the functions that will be generated by +// the include below. +static void write(DIExpressionElemAttr attribute, + DialectBytecodeWriter &writer); +static LogicalResult writeAttribute(Attribute attribute, + DialectBytecodeWriter &writer); + +//===--------------------------------------------------------------------===// +// Optional ArrayRefs +// +// Note that both the writer and reader functions consider attributes to be +// optional. This is because the attribute may be present or empty. +//===--------------------------------------------------------------------===// + +template <class EntryTy> +static void writeOptionalArrayRef(DialectBytecodeWriter &writer, + ArrayRef<EntryTy> storage) { + if (storage.empty()) { + writer.writeOwnedBool(false); + return; + } + + writer.writeOwnedBool(true); + writer.writeList(storage, [&](EntryTy val) { + if constexpr (std::is_base_of_v<Attribute, EntryTy>) { + (void)writer.writeOptionalAttribute(val); + } else if constexpr (std::is_integral_v<EntryTy>) { + (void)writer.writeVarInt(val); + } else { + static_assert(true, "EntryTy not supported"); + } + }); +} + +template <class EntryTy> +static LogicalResult readOptionalArrayRef(DialectBytecodeReader &reader, + SmallVectorImpl<EntryTy> &storage) { + bool isPresent = false; + if (failed(reader.readBool(isPresent))) + return failure(); + // Nothing to do here, the array is empty. + if (!isPresent) + return success(); + + auto readEntry = [&]() -> FailureOr<EntryTy> { + EntryTy temp; + if constexpr (std::is_base_of_v<Attribute, EntryTy>) { + if (succeeded(reader.readOptionalAttribute(temp))) + return temp; + } else if constexpr (std::is_integral_v<EntryTy>) { + if (succeeded(reader.readVarInt(temp))) + return temp; + } else { + static_assert(true, "EntryTy not supported"); + } + return failure(); + }; + + return reader.readList(storage, readEntry); +} + +//===--------------------------------------------------------------------===// +// Optional integral types +//===--------------------------------------------------------------------===// + +template <class EntryTy> +static void writeOptionalInt(DialectBytecodeWriter &writer, + std::optional<EntryTy> storage) { + static_assert(std::is_integral_v<EntryTy>, + "EntryTy must be an integral type"); + EntryTy val = storage.value_or(0); + writer.writeVarIntWithFlag(val, storage.has_value()); +} + +template <class EntryTy> +static LogicalResult readOptionalInt(DialectBytecodeReader &reader, + std::optional<EntryTy> &storage) { + static_assert(std::is_integral_v<EntryTy>, + "EntryTy must be an integral type"); + uint64_t result = 0; + bool flag = false; + if (failed(reader.readVarIntWithFlag(result, flag))) + return failure(); + if (flag) + storage = static_cast<EntryTy>(result); + else + storage = std::nullopt; + return success(); +} + +//===--------------------------------------------------------------------===// +// Tablegen generated bytecode functions +//===--------------------------------------------------------------------===// + +#include "mlir/Dialect/LLVMIR/LLVMDialectBytecode.cpp.inc" + +//===--------------------------------------------------------------------===// +// LLVMDialectBytecodeInterface +//===--------------------------------------------------------------------===// + +/// This class implements the bytecode interface for the LLVM dialect. +struct LLVMDialectBytecodeInterface : public BytecodeDialectInterface { + LLVMDialectBytecodeInterface(Dialect *dialect) + : BytecodeDialectInterface(dialect) {} + + // Attributes + Attribute readAttribute(DialectBytecodeReader &reader) const override { + return ::readAttribute(getContext(), reader); + } + + LogicalResult writeAttribute(Attribute attr, + DialectBytecodeWriter &writer) const override { + return ::writeAttribute(attr, writer); + } + + // Types + Type readType(DialectBytecodeReader &reader) const override { + return ::readType(getContext(), reader); + } + + LogicalResult writeType(Type type, + DialectBytecodeWriter &writer) const override { + return ::writeType(type, writer); + } +}; +} // namespace + +void LLVM::detail::addBytecodeInterface(LLVMDialect *dialect) { + dialect->addInterfaces<LLVMDialectBytecodeInterface>(); +} diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h new file mode 100644 index 0000000..1a17cb4 --- /dev/null +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h @@ -0,0 +1,27 @@ +//===- LLVMDialectBytecode.h - LLVM Bytecode Implementation -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header defines hooks into the LLVM dialect bytecode +// implementation. +// +//===----------------------------------------------------------------------===// + +#ifndef LIB_MLIR_DIALECT_LLVM_IR_LLVMDIALECTBYTECODE_H +#define LIB_MLIR_DIALECT_LLVM_IR_LLVMDIALECTBYTECODE_H + +namespace mlir::LLVM { +class LLVMDialect; + +namespace detail { +/// Add the interfaces necessary for encoding the LLVM dialect components in +/// bytecode. +void addBytecodeInterface(LLVMDialect *dialect); +} // namespace detail +} // namespace mlir::LLVM + +#endif // LIB_MLIR_DIALECT_LLVM_IR_LLVMDIALECTBYTECODE_H diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index 01a16ce..ac35eea 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -134,10 +134,10 @@ static void printExtTypeParams(AsmPrinter &p, ArrayRef<Type> typeParams, /// These are unused for now. /// TODO: Move over to these once more types have been migrated to TypeDef. -LLVM_ATTRIBUTE_UNUSED static OptionalParseResult +[[maybe_unused]] static OptionalParseResult generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value); -LLVM_ATTRIBUTE_UNUSED static LogicalResult -generatedTypePrinter(Type def, AsmPrinter &printer); +[[maybe_unused]] static LogicalResult generatedTypePrinter(Type def, + AsmPrinter &printer); #include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc" diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 5edcc40b..2a8c330 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -309,6 +309,17 @@ LogicalResult ConvertBF16x2ToF8x2Op::verify() { return success(); } +LogicalResult ConvertF32x2ToF4x2Op::verify() { + mlir::MLIRContext *ctx = getContext(); + + if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy())) + return emitOpError("Only ") + << mlir::Float4E2M1FNType::get(ctx) + << " type is supported for conversions from f32x2 to f4x2."; + + return success(); +} + LogicalResult BulkStoreOp::verify() { if (getInitVal() != 0) return emitOpError("only 0 is supported for initVal, got ") << getInitVal(); @@ -787,6 +798,26 @@ LogicalResult MmaOp::verify() { " attribute"); } + // Validate layout combinations. According to the operation description, most + // MMA operations require layoutA=row and layoutB=col. Only m8n8k4 with f16 + // can use other layout combinations. + bool isM8N8K4_F16 = + (mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 && + getMultiplicandAPtxType() == MMATypes::f16); + + if (!isM8N8K4_F16) { + // For all other shapes/types, layoutA must be row and layoutB must be col + if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) { + return emitOpError("requires layoutA = #nvvm.mma_layout<row> and " + "layoutB = #nvvm.mma_layout<col> for shape <") + << mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2] + << "> with element types " + << stringifyEnum(*getMultiplicandAPtxType()) << " and " + << stringifyEnum(*getMultiplicandBPtxType()) + << ". Only m8n8k4 with f16 supports other layouts."; + } + } + return success(); } @@ -2047,6 +2078,23 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd, } } +NVVM::IDArgPair +ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op, + LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(op.getA())); + args.push_back(mt.lookupValue(op.getB())); + + bool hasRelu = op.getRelu(); + + llvm::Intrinsic::ID intId = + hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite + : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite; + + return {intId, std::move(args)}; +} + #define GET_F32x2_TO_F6x2_ID(type, has_relu) \ has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \ : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite @@ -2306,6 +2354,32 @@ static void nvvmInferResultRanges(Operation *op, Value result, } } +/// Verify the range attribute satisfies LLVM ConstantRange constructor +/// requirements for NVVM SpecialRangeableRegisterOp. +static LogicalResult +verifyConstantRangeAttr(Operation *op, + std::optional<LLVM::ConstantRangeAttr> rangeAttr) { + if (!rangeAttr) + return success(); + + const llvm::APInt &lower = rangeAttr->getLower(); + const llvm::APInt &upper = rangeAttr->getUpper(); + + // Check LLVM ConstantRange constructor condition + if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) { + unsigned bitWidth = lower.getBitWidth(); + llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth); + llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth); + return op->emitOpError( + "invalid range attribute: Lower == Upper, but they aren't min (") + << llvm::toString(minVal, 10, false) << ") or max (" + << llvm::toString(maxVal, 10, false) + << ") value! This is an invalid constant range."; + } + + return success(); +} + static llvm::Value *getAsPackedI32(llvm::Value *arg, llvm::IRBuilderBase &builder) { return builder.CreateBitCast(arg, diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index c477c6c..dcc1ef9 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -315,7 +315,8 @@ bool mlir::linalg::detail::isContractionBody( Value yielded = getSourceSkipUnary(terminator->getOperand(0)); Operation *reductionOp = yielded.getDefiningOp(); - if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) { + if (!reductionOp || reductionOp->getNumResults() != 1 || + reductionOp->getNumOperands() != 2) { errs << "expected reduction op to be binary"; return false; } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 59013a2..cbc565b 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -5272,11 +5272,18 @@ ArrayRef<int64_t> PackOp::getAllOuterDims() { SmallVector<int64_t> PackOp::getTiledOuterDims() { auto innerDimsPos = getInnerDimsPos(); - auto packedShape = getDestType().getShape(); + SmallVector<int64_t> outerDims(getAllOuterDims()); SmallVector<int64_t> res; + // Recover the original order of the outer dims. + SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm()); + invertPermutationVector(outerDimPermInv); + if (!outerDimPermInv.empty()) + applyPermutationToVector(outerDims, outerDimPermInv); + + // Collect the outer dims corresponding to the tilled inner dims. for (auto index : innerDimsPos) - res.push_back(packedShape[index]); + res.push_back(outerDims[index]); return res; } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index dd9b4c2..6192d79 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -576,6 +576,86 @@ transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply( // FuseOp //===----------------------------------------------------------------------===// +void transform::FuseOp::build(OpBuilder &builder, OperationState &result, + TypeRange loopTypes, Value target, + ArrayRef<int64_t> staticTileSizes, + ArrayRef<int64_t> staticTileInterchange, + bool applyCleanup, bool useForall) { + return build( + builder, result, loopTypes, + /*target=*/target, + /*mixedTileSizes=*/ + getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)), + /*mixedTileInterchange=*/ + getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)), + applyCleanup, useForall); +} + +void transform::FuseOp::build(OpBuilder &builder, OperationState &result, + Value target, ArrayRef<int64_t> staticTileSizes, + ArrayRef<int64_t> staticTileInterchange, + bool applyCleanup, bool useForall) { + return build( + builder, result, + /*target=*/target, + /*mixedTileSizes=*/ + getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)), + /*mixedTileInterchange=*/ + getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)), + applyCleanup, useForall); +} + +void transform::FuseOp::build(OpBuilder &builder, OperationState &result, + Value target, + ArrayRef<OpFoldResult> mixedTileSizes, + ArrayRef<OpFoldResult> mixedTileInterchange, + bool applyCleanup, bool useForall) { + // Loop types are automaticaly splat by the callee, setting up one is + // enough. + SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>()); + build(builder, result, loopTypes, target, mixedTileSizes, + mixedTileInterchange, applyCleanup, useForall); +} + +void transform::FuseOp::build(OpBuilder &builder, OperationState &result, + TypeRange loopTypes, Value target, + ArrayRef<OpFoldResult> mixedTileSizes, + ArrayRef<OpFoldResult> mixedTileInterchange, + bool applyCleanup, bool useForall) { + SmallVector<int64_t> staticTileSizes; + SmallVector<Value> dynamicTileSizes; + dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes); + SmallVector<int64_t> staticTileInterchange; + SmallVector<Value> dynamicTileInterchange; + dispatchIndexOpFoldResults(mixedTileInterchange, dynamicTileInterchange, + staticTileInterchange); + // Call the default builder which sets up the proper operands segment sizes + // attributes for multiple variadic operands. In the absence of this, + // horrible bugs ensue. + auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes); + auto staticTileInterchangeAttr = + builder.getDenseI64ArrayAttr(staticTileInterchange); + unsigned numExpectedLoops = + useForall ? 1 : staticTileSizes.size() - llvm::count(staticTileSizes, 0); + SmallVector<Type> resultTypes; + resultTypes.reserve(numExpectedLoops); + assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) && + "expected one loop type or as many as loops"); + if (loopTypes.size() == 1) + resultTypes.append(numExpectedLoops, loopTypes[0]); + else + llvm::append_range(resultTypes, loopTypes); + build(builder, result, /*transformed=*/target.getType(), + /*loops=*/resultTypes, + /*target=*/target, + /*tile_sizes=*/dynamicTileSizes, + /*tile_interchange=*/dynamicTileInterchange, + /*static_tile_sizes=*/staticTileSizesAttr, + /*static_tile_interchange=*/staticTileInterchangeAttr, + /*apply_cleanup=*/applyCleanup, + /*use_forall=*/useForall); +} + /// Apply a tiling transformation to all payload ops and store both the /// tiled operation as well as the created tile loops. template <typename Range> @@ -630,13 +710,25 @@ DiagnosedSilenceableFailure transform::FuseOp::apply(transform::TransformRewriter &rewriter, mlir::transform::TransformResults &transformResults, mlir::transform::TransformState &state) { - SmallVector<int64_t> tileSizes = - extractFromIntegerArrayAttr<int64_t>(getTileSizes()); - SmallVector<int64_t> tileInterchange = - extractFromIntegerArrayAttr<int64_t>(getTileInterchange()); + auto transformOp = cast<TransformOpInterface>(getOperation()); + + SmallVector<int64_t> tileSizes; + DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults( + state, transformOp, getMixedTileSizes(), tileSizes); + if (!status.succeeded()) + return status; + SmallVector<int64_t> tileInterchange; + status = reifyMixedParamAndHandleResults( + state, transformOp, getMixedTileInterchange(), tileInterchange); + if (!status.succeeded()) + return status; scf::SCFTilingOptions tilingOptions; tilingOptions.interchangeVector = tileInterchange; + bool useForall = getUseForall(); + tilingOptions.setLoopType(useForall + ? scf::SCFTilingOptions::LoopType::ForallOp + : scf::SCFTilingOptions::LoopType::ForOp); SmallVector<OpFoldResult> tileSizesOfr = getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); tilingOptions = tilingOptions.setTileSizes(tileSizesOfr); @@ -652,9 +744,11 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter, tileAndFuseOptions.cleanupPatterns = std::move(patterns); } + size_t numLoops = + useForall ? 1 : tileSizes.size() - llvm::count(tileSizes, 0); LogicalResult result = applyTilingToAll( - rewriter, getOperation(), state.getPayloadOps(getTarget()), - tileSizes.size() - llvm::count(tileSizes, 0), transformResults, + rewriter, getOperation(), state.getPayloadOps(getTarget()), numLoops, + transformResults, [&](TilingInterface tilingInterfaceOp) -> FailureOr<scf::SCFTileAndFuseResult> { return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp, @@ -665,24 +759,51 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter, } LogicalResult transform::FuseOp::verify() { - SmallVector<int64_t> permutation = - extractFromIntegerArrayAttr<int64_t>(getTileInterchange()); - auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size())); - if (!std::is_permutation(sequence.begin(), sequence.end(), - permutation.begin(), permutation.end())) { - return emitOpError() << "expects interchange to be a permutation, found " - << getTileInterchange(); + auto iterspace_rank = getStaticTileSizes().size(); + ArrayRef<int64_t> permutation = getStaticTileInterchange(); + if (permutation.size() > iterspace_rank) + return emitOpError() + << "interchange length exceeds iteration space dimensions (" + << iterspace_rank << "), found " << getTileInterchange(); + SmallVector<bool> seen(iterspace_rank, false); + for (int64_t v : permutation) { + if (!ShapedType::isDynamic(v)) { + if (v < 0 || v >= static_cast<int64_t>(iterspace_rank)) + return emitOpError() << "expects interchange values to be in range [0, " + << iterspace_rank << "), found: " << v; + if (seen[v]) + return emitOpError() << "found duplicate interchange value: " << v; + seen[v] = true; + } } - SmallVector<int64_t> sizes = - extractFromIntegerArrayAttr<int64_t>(getTileSizes()); - size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0); + ArrayRef<int64_t> sizes = getStaticTileSizes(); + size_t numExpectedLoops = + getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0); if (numExpectedLoops != getNumResults() - 1) return emitOpError() << "expects " << numExpectedLoops << " loop results"; return success(); } +SmallVector<OpFoldResult> transform::FuseOp::getMixedTileSizes() { + return getMixedValues(getStaticTileSizes(), getTileSizes(), getContext()); +} + +SmallVector<OpFoldResult> transform::FuseOp::getMixedTileInterchange() { + return getMixedValues(getStaticTileInterchange(), getTileInterchange(), + getContext()); +} + +void transform::FuseOp::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + consumesHandle(getTargetMutable(), effects); + onlyReadsHandle(getTileSizesMutable(), effects); + onlyReadsHandle(getTileInterchangeMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + modifiesPayload(effects); +} + //===----------------------------------------------------------------------===// // FuseIntoContainingOp //===----------------------------------------------------------------------===// @@ -2903,10 +3024,10 @@ ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); } if (dynamicPointParseResult.has_value()) { - Type ChunkSizesType; + Type chunkSizesType; if (failed(*dynamicPointParseResult) || parser.parseComma() || - parser.parseType(ChunkSizesType) || - parser.resolveOperand(dynamicChunkSizes, ChunkSizesType, + parser.parseType(chunkSizesType) || + parser.resolveOperand(dynamicChunkSizes, chunkSizesType, result.operands)) { return failure(); } @@ -3278,9 +3399,9 @@ void transform::ContinuousTileSizesOp::getEffects( } static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op, - Type targetType, Type tile_sizes, + Type targetType, Type tileSizes, Type) { - printer.printFunctionalType(TypeRange{targetType}, TypeRange{tile_sizes}); + printer.printFunctionalType(TypeRange{targetType}, TypeRange{tileSizes}); } static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 0dac688..eb2d825 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -1134,22 +1134,45 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape, LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( linalg::PackOp packOp, PatternRewriter &rewriter) const { - // TODO: support the case that outer dimensions are not all 1s. A - // tensor.expand_shape will be generated in this case. - if (llvm::any_of(packOp.getAllOuterDims(), + if (llvm::any_of(packOp.getTiledOuterDims(), [](int64_t dim) { return dim != 1; })) { return rewriter.notifyMatchFailure( packOp, "not all outer dimensions of the result are 1s"); } + ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos(); + auto outerDimsPerm = packOp.getOuterDimsPerm(); + + // Verify that there are no: + // * non-unit + un-tiled-outer-dims, + // that are permuted. Supporting such cases would require refining the logic + // that generates the Transpose Op. + if (!llvm::all_of(outerDimsPerm, [&innerDimsPos, &packOp](int64_t dim) { + static int prev = 0; + // Skip tiled dims - these can be permuted. + if (llvm::is_contained(innerDimsPos, dim)) + return true; + + // Check whether this dim has been permuted. Permuting unit dims is fine + // as that's effectively a no-op. + if (dim < prev && (packOp.getType().getShape()[prev] != 1 || + packOp.getType().getShape()[dim] != 1)) + return false; + + prev = dim; + return true; + })) { + return rewriter.notifyMatchFailure( + packOp, "At least one non-unit and un-tiled outer dim is permuted, " + "this is not supported ATM!"); + } + Attribute zeroIdxAttr = rewriter.getIndexAttr(0); Attribute oneIdxAttr = rewriter.getIndexAttr(1); Location loc = packOp.getLoc(); int64_t srcRank = packOp.getSourceRank(); int64_t destRank = packOp.getDestRank(); - ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos(); - int64_t numberOfTiles = innerDimsPos.size(); // 1. Get the input that is going to be packed. If the input requires padding, // add a padding operation and return that as the input. @@ -1160,10 +1183,13 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( // %transposed_tile = linalg.transpose ins(%source_or_padded_source), // outs(%init) // Assumptions made: - // - All outer dims are 1 - the corresponding transposition order doesn't - // matter, but requires all dim indices to be present. + // - All tiled outer dims are 1 - the corresponding transposition order + // doesn't matter, but requires all dim indices to be present. + // - Un-tiled outer dims remain un-permuted. - // 2.1 Get the permutation for linalg.transpose + // 2.1 Get the permutation for linalg.transpose: + // [ untiled-dims, inner-dims-pos ] + // Note, this logic assumes that the untiled dims are not permuted. SmallVector<int64_t> srcPermForTranspose; for (int64_t i = 0; i < srcRank; i++) { // We assume the `k` dimensions of the inner dim position, where `k` is the @@ -1179,9 +1205,21 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( } srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end()); - // 2.2 Create the init tensor for linalg.transpose with the correct shape - SmallVector<OpFoldResult> shapeForEmptyOp(srcRank - numberOfTiles, - oneIdxAttr); + // 2.2 Create the init tensor for linalg.transpose with the correct shape: + // [ untiled-dims, tiled-dims ] + ShapedType inputTy = cast<ShapedType>(input.getType()); + SmallVector<OpFoldResult> shapeForEmptyOp; + for (int64_t i = 0; i < srcRank; i++) { + if (llvm::is_contained(innerDimsPos, i)) { + // The tiled dims are appended after this loop. + continue; + } + if (inputTy.isStaticDim(i)) + shapeForEmptyOp.push_back(rewriter.getIndexAttr(inputTy.getShape()[i])); + else + shapeForEmptyOp.emplace_back( + tensor::DimOp::create(rewriter, loc, input, i).getResult()); + } shapeForEmptyOp.append(packOp.getMixedTiles()); // getMixedTiles() may contain Values pointing to constant ops, not the @@ -1204,25 +1242,36 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty, srcPermForTranspose); - // 3. Insert the inner tile to the destination: + // 3. Insert the inner tile into the destination tensor: // %inserted_tile = tensor.insert_slice(%transposed_tile) - SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr); - SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr); - // Outer dims are all 1s! - SmallVector<OpFoldResult> writeSizes(destRank - numberOfTiles, oneIdxAttr); - SmallVector<int64_t> writeShape; + + // Compute the sizes attribute: + // [ outer-dims, tile-sizes ] + // Note that the output from the transpose Op excludes the tiled outer dims. + // However, given the assumption that: + // * all tiled outer dims == 1, + // we can just use a rank-expanding tensor.insert_slice. + SmallVector<OpFoldResult> writeSizes; + for (auto size : packOp.getAllOuterDims()) { + writeSizes.push_back(rewriter.getIndexAttr(size)); + } for (auto tileSize : packOp.getMixedTiles()) { - auto [tileSizeStatic, tileSizeOfr] = + auto [_, tileSizeOfr] = getSimplifiedOfrAndStaticSizePair(tileSize, rewriter); writeSizes.push_back(tileSizeOfr); - writeShape.push_back(tileSizeStatic); } - // 4. Replace tensor.packOp with tensor.insert_slice created above + // TODO: Add a constructor for tensor.insert_slice that doesn't require + // strides nor offsets. + SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr); + SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr); + auto insert = tensor::InsertSliceOp::create( rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets, writeSizes, writeStrides); + + // 4. Replace tensor.packOp with tensor.insert_slice created above rewriter.replaceOp(packOp, insert.getResult()); return success(); diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt index e25a012..1382c7ac 100644 --- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt @@ -5,7 +5,7 @@ add_mlir_dialect_library(MLIRMemRefDialect ValueBoundsOpInterfaceImpl.cpp ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect + ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRef/IR DEPENDS MLIRMemRefOpsIncGen @@ -18,6 +18,7 @@ add_mlir_dialect_library(MLIRMemRefDialect MLIRDialectUtils MLIRInferIntRangeCommon MLIRInferIntRangeInterface + MLIRInferStridedMetadataInterface MLIRInferTypeOpInterface MLIRIR MLIRMemOpInterfaces diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index e9bdcda..507597b 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -3437,6 +3437,65 @@ SubViewOp::bubbleDownCasts(OpBuilder &builder) { return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable()); } +void SubViewOp::inferStridedMetadataRanges( + ArrayRef<StridedMetadataRange> ranges, GetIntRangeFn getIntRange, + SetStridedMetadataRangeFn setMetadata, int32_t indexBitwidth) { + auto isUninitialized = + +[](IntegerValueRange range) { return range.isUninitialized(); }; + + // Bail early if any of the operands metadata is not ready: + SmallVector<IntegerValueRange> offsetOperands = + getIntValueRanges(getMixedOffsets(), getIntRange, indexBitwidth); + if (llvm::any_of(offsetOperands, isUninitialized)) + return; + + SmallVector<IntegerValueRange> sizeOperands = + getIntValueRanges(getMixedSizes(), getIntRange, indexBitwidth); + if (llvm::any_of(sizeOperands, isUninitialized)) + return; + + SmallVector<IntegerValueRange> stridesOperands = + getIntValueRanges(getMixedStrides(), getIntRange, indexBitwidth); + if (llvm::any_of(stridesOperands, isUninitialized)) + return; + + StridedMetadataRange sourceRange = + ranges[getSourceMutable().getOperandNumber()]; + if (sourceRange.isUninitialized()) + return; + + ArrayRef<ConstantIntRanges> srcStrides = sourceRange.getStrides(); + + // Get the dropped dims. + llvm::SmallBitVector droppedDims = getDroppedDims(); + + // Compute the new offset, strides and sizes. + ConstantIntRanges offset = sourceRange.getOffsets()[0]; + SmallVector<ConstantIntRanges> strides, sizes; + + for (size_t i = 0, e = droppedDims.size(); i < e; ++i) { + bool dropped = droppedDims.test(i); + // Compute the new offset. + ConstantIntRanges off = + intrange::inferMul({offsetOperands[i].getValue(), srcStrides[i]}); + offset = intrange::inferAdd({offset, off}); + + // Skip dropped dimensions. + if (dropped) + continue; + // Multiply the strides. + strides.push_back( + intrange::inferMul({stridesOperands[i].getValue(), srcStrides[i]})); + // Get the sizes. + sizes.push_back(sizeOperands[i].getValue()); + } + + setMetadata(getResult(), + StridedMetadataRange::getRanked( + SmallVector<ConstantIntRanges>({std::move(offset)}), + std::move(sizes), std::move(strides))); +} + //===----------------------------------------------------------------------===// // TransposeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp index 49b7162..6f815ae 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp @@ -121,7 +121,7 @@ struct EmulateWideIntPass final [&typeConverter](Operation *op) { return typeConverter.isLegal(op); }); RewritePatternSet patterns(ctx); - // Add common pattenrs to support contants, functions, etc. + // Add common patterns to support contants, functions, etc. arith::populateArithWideIntEmulationPatterns(typeConverter, patterns); memref::populateMemRefWideIntEmulationPatterns(typeConverter, patterns); diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 6564a4e..90cbbd8 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SmallSet.h" @@ -39,6 +40,16 @@ static bool isScalarLikeType(Type type) { return type.isIntOrIndexOrFloat() || isa<ComplexType>(type); } +/// Helper function to attach the `VarName` attribute to an operation +/// if a variable name is provided. +static void attachVarNameAttr(Operation *op, OpBuilder &builder, + StringRef varName) { + if (!varName.empty()) { + auto varNameAttr = acc::VarNameAttr::get(builder.getContext(), varName); + op->setAttr(acc::getVarNameAttrName(), varNameAttr); + } +} + struct MemRefPointerLikeModel : public PointerLikeType::ExternalModel<MemRefPointerLikeModel, MemRefType> { @@ -74,14 +85,18 @@ struct MemRefPointerLikeModel } mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc, - StringRef varName, Type varType, - Value originalVar) const { + StringRef varName, Type varType, Value originalVar, + bool &needsFree) const { auto memrefTy = cast<MemRefType>(pointer); // Check if this is a static memref (all dimensions are known) - if yes // then we can generate an alloca operation. - if (memrefTy.hasStaticShape()) - return memref::AllocaOp::create(builder, loc, memrefTy).getResult(); + if (memrefTy.hasStaticShape()) { + needsFree = false; // alloca doesn't need deallocation + auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy); + attachVarNameAttr(allocaOp, builder, varName); + return allocaOp.getResult(); + } // For dynamic memrefs, extract sizes from the original variable if // provided. Otherwise they cannot be handled. @@ -99,8 +114,11 @@ struct MemRefPointerLikeModel // Note: We only add dynamic sizes to the dynamicSizes array // Static dimensions are handled automatically by AllocOp } - return memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes) - .getResult(); + needsFree = true; // alloc needs deallocation + auto allocOp = + memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes); + attachVarNameAttr(allocOp, builder, varName); + return allocOp.getResult(); } // TODO: Unranked not yet supported. @@ -108,10 +126,14 @@ struct MemRefPointerLikeModel } bool genFree(Type pointer, OpBuilder &builder, Location loc, - TypedValue<PointerLikeType> varPtr, Type varType) const { - if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varPtr)) { + TypedValue<PointerLikeType> varToFree, Value allocRes, + Type varType) const { + if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varToFree)) { + // Use allocRes if provided to determine the allocation type + Value valueToInspect = allocRes ? allocRes : memrefValue; + // Walk through casts to find the original allocation - Value currentValue = memrefValue; + Value currentValue = valueToInspect; Operation *originalAlloc = nullptr; // Follow the chain of operations to find the original allocation @@ -150,7 +172,7 @@ struct MemRefPointerLikeModel return true; } if (isa<memref::AllocOp>(originalAlloc)) { - // This is an alloc - generate dealloc + // This is an alloc - generate dealloc on varToFree memref::DeallocOp::create(builder, loc, memrefValue); return true; } @@ -1003,6 +1025,142 @@ struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> { } }; +//===----------------------------------------------------------------------===// +// Recipe Region Helpers +//===----------------------------------------------------------------------===// + +/// Create and populate an init region for privatization recipes. +/// Returns the init block on success, or nullptr on failure. +/// Sets needsFree to indicate if the allocated memory requires deallocation. +static std::unique_ptr<Block> createInitRegion(OpBuilder &builder, Location loc, + Type varType, StringRef varName, + ValueRange bounds, + bool &needsFree) { + // Create init block with arguments: original value + bounds + SmallVector<Type> argTypes{varType}; + SmallVector<Location> argLocs{loc}; + for (Value bound : bounds) { + argTypes.push_back(bound.getType()); + argLocs.push_back(loc); + } + + auto initBlock = std::make_unique<Block>(); + initBlock->addArguments(argTypes, argLocs); + builder.setInsertionPointToStart(initBlock.get()); + + Value privatizedValue; + + // Get the block argument that represents the original variable + Value blockArgVar = initBlock->getArgument(0); + + // Generate init region body based on variable type + if (isa<MappableType>(varType)) { + auto mappableTy = cast<MappableType>(varType); + auto typedVar = cast<TypedValue<MappableType>>(blockArgVar); + privatizedValue = mappableTy.generatePrivateInit( + builder, loc, typedVar, varName, bounds, {}, needsFree); + if (!privatizedValue) + return nullptr; + } else { + assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType"); + auto pointerLikeTy = cast<PointerLikeType>(varType); + // Use PointerLikeType's allocation API with the block argument + privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType, + blockArgVar, needsFree); + if (!privatizedValue) + return nullptr; + } + + // Add yield operation to init block + acc::YieldOp::create(builder, loc, privatizedValue); + + return initBlock; +} + +/// Create and populate a copy region for firstprivate recipes. +/// Returns the copy block on success, or nullptr on failure. +/// TODO: Handle MappableType - it does not yet have a copy API. +static std::unique_ptr<Block> createCopyRegion(OpBuilder &builder, Location loc, + Type varType, + ValueRange bounds) { + // Create copy block with arguments: original value + privatized value + + // bounds + SmallVector<Type> copyArgTypes{varType, varType}; + SmallVector<Location> copyArgLocs{loc, loc}; + for (Value bound : bounds) { + copyArgTypes.push_back(bound.getType()); + copyArgLocs.push_back(loc); + } + + auto copyBlock = std::make_unique<Block>(); + copyBlock->addArguments(copyArgTypes, copyArgLocs); + builder.setInsertionPointToStart(copyBlock.get()); + + bool isMappable = isa<MappableType>(varType); + bool isPointerLike = isa<PointerLikeType>(varType); + // TODO: Handle MappableType - it does not yet have a copy API. + // Otherwise, for now just fallback to pointer-like behavior. + if (isMappable && !isPointerLike) + return nullptr; + + // Generate copy region body based on variable type + if (isPointerLike) { + auto pointerLikeTy = cast<PointerLikeType>(varType); + Value originalArg = copyBlock->getArgument(0); + Value privatizedArg = copyBlock->getArgument(1); + + // Generate copy operation using PointerLikeType interface + if (!pointerLikeTy.genCopy( + builder, loc, cast<TypedValue<PointerLikeType>>(privatizedArg), + cast<TypedValue<PointerLikeType>>(originalArg), varType)) + return nullptr; + } + + // Add terminator to copy block + acc::TerminatorOp::create(builder, loc); + + return copyBlock; +} + +/// Create and populate a destroy region for privatization recipes. +/// Returns the destroy block on success, or nullptr if not needed. +static std::unique_ptr<Block> createDestroyRegion(OpBuilder &builder, + Location loc, Type varType, + Value allocRes, + ValueRange bounds) { + // Create destroy block with arguments: original value + privatized value + + // bounds + SmallVector<Type> destroyArgTypes{varType, varType}; + SmallVector<Location> destroyArgLocs{loc, loc}; + for (Value bound : bounds) { + destroyArgTypes.push_back(bound.getType()); + destroyArgLocs.push_back(loc); + } + + auto destroyBlock = std::make_unique<Block>(); + destroyBlock->addArguments(destroyArgTypes, destroyArgLocs); + builder.setInsertionPointToStart(destroyBlock.get()); + + bool isMappable = isa<MappableType>(varType); + bool isPointerLike = isa<PointerLikeType>(varType); + // TODO: Handle MappableType - it does not yet have a deallocation API. + // Otherwise, for now just fallback to pointer-like behavior. + if (isMappable && !isPointerLike) + return nullptr; + + assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType"); + auto pointerLikeTy = cast<PointerLikeType>(varType); + auto privatizedArg = + cast<TypedValue<PointerLikeType>>(destroyBlock->getArgument(1)); + // Pass allocRes to help determine the allocation type + if (!pointerLikeTy.genFree(builder, loc, privatizedArg, allocRes, varType)) + return nullptr; + + acc::TerminatorOp::create(builder, loc); + + return destroyBlock; +} + } // namespace //===----------------------------------------------------------------------===// @@ -1050,6 +1208,55 @@ LogicalResult acc::PrivateRecipeOp::verifyRegions() { return success(); } +std::optional<PrivateRecipeOp> +PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc, + StringRef recipeName, Type varType, + StringRef varName, ValueRange bounds) { + // First, validate that we can handle this variable type + bool isMappable = isa<MappableType>(varType); + bool isPointerLike = isa<PointerLikeType>(varType); + + // Unsupported type + if (!isMappable && !isPointerLike) + return std::nullopt; + + // Create init and destroy blocks using shared helpers + OpBuilder::InsertionGuard guard(builder); + + // Save the original insertion point for creating the recipe operation later + auto originalInsertionPoint = builder.saveInsertionPoint(); + + bool needsFree = false; + auto initBlock = + createInitRegion(builder, loc, varType, varName, bounds, needsFree); + if (!initBlock) + return std::nullopt; + + // Only create destroy region if the allocation needs deallocation + std::unique_ptr<Block> destroyBlock; + if (needsFree) { + // Extract the allocated value from the init block's yield operation + auto yieldOp = cast<acc::YieldOp>(initBlock->getTerminator()); + Value allocRes = yieldOp.getOperand(0); + + destroyBlock = createDestroyRegion(builder, loc, varType, allocRes, bounds); + if (!destroyBlock) + return std::nullopt; + } + + // Now create the recipe operation at the original insertion point and attach + // the blocks + builder.restoreInsertionPoint(originalInsertionPoint); + auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType); + + // Move the blocks into the recipe's regions + recipe.getInitRegion().push_back(initBlock.release()); + if (destroyBlock) + recipe.getDestroyRegion().push_back(destroyBlock.release()); + + return recipe; +} + //===----------------------------------------------------------------------===// // FirstprivateRecipeOp //===----------------------------------------------------------------------===// @@ -1080,6 +1287,60 @@ LogicalResult acc::FirstprivateRecipeOp::verifyRegions() { return success(); } +std::optional<FirstprivateRecipeOp> +FirstprivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc, + StringRef recipeName, Type varType, + StringRef varName, ValueRange bounds) { + // First, validate that we can handle this variable type + bool isMappable = isa<MappableType>(varType); + bool isPointerLike = isa<PointerLikeType>(varType); + + // Unsupported type + if (!isMappable && !isPointerLike) + return std::nullopt; + + // Create init, copy, and destroy blocks using shared helpers + OpBuilder::InsertionGuard guard(builder); + + // Save the original insertion point for creating the recipe operation later + auto originalInsertionPoint = builder.saveInsertionPoint(); + + bool needsFree = false; + auto initBlock = + createInitRegion(builder, loc, varType, varName, bounds, needsFree); + if (!initBlock) + return std::nullopt; + + auto copyBlock = createCopyRegion(builder, loc, varType, bounds); + if (!copyBlock) + return std::nullopt; + + // Only create destroy region if the allocation needs deallocation + std::unique_ptr<Block> destroyBlock; + if (needsFree) { + // Extract the allocated value from the init block's yield operation + auto yieldOp = cast<acc::YieldOp>(initBlock->getTerminator()); + Value allocRes = yieldOp.getOperand(0); + + destroyBlock = createDestroyRegion(builder, loc, varType, allocRes, bounds); + if (!destroyBlock) + return std::nullopt; + } + + // Now create the recipe operation at the original insertion point and attach + // the blocks + builder.restoreInsertionPoint(originalInsertionPoint); + auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType); + + // Move the blocks into the recipe's regions + recipe.getInitRegion().push_back(initBlock.release()); + recipe.getCopyRegion().push_back(copyBlock.release()); + if (destroyBlock) + recipe.getDestroyRegion().push_back(destroyBlock.release()); + + return recipe; +} + //===----------------------------------------------------------------------===// // ReductionRecipeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp index 135c033..645cbff 100644 --- a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp +++ b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp @@ -158,7 +158,7 @@ static FailureOr<GridOp> getGridAndVerify(Operation *op, } template <typename It> -bool isUnique(It begin, It end) { +static bool isUnique(It begin, It end) { if (begin == end) { return true; } diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp index a1711a6..069191c 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp @@ -143,8 +143,8 @@ void VarInfo::setNum(Var::Num n) { /// Helper function for `assertUsageConsistency` to better handle SMLoc /// mismatches. -LLVM_ATTRIBUTE_UNUSED static llvm::SMLoc -minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) { +[[maybe_unused]] static llvm::SMLoc minSMLoc(AsmParser &parser, llvm::SMLoc sm1, + llvm::SMLoc sm2) { const auto loc1 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm1)); assert(loc1 && "Could not get `FileLineColLoc` for first `SMLoc`"); const auto loc2 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm2)); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp index f539502..684c088 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp @@ -43,8 +43,8 @@ using namespace mlir::sparse_tensor; //===----------------------------------------------------------------------===// #ifndef NDEBUG -LLVM_ATTRIBUTE_UNUSED static void dumpIndexMemRef(OpBuilder &builder, - Location loc, Value memref) { +[[maybe_unused]] static void dumpIndexMemRef(OpBuilder &builder, Location loc, + Value memref) { memref = memref::CastOp::create( builder, loc, UnrankedMemRefType::get(builder.getIndexType(), 0), memref); createFuncCall(builder, loc, "printMemrefInd", TypeRange{}, diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index fa97b49..ac72002 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -2310,6 +2310,7 @@ RankedTensorType ExtractSliceOp::inferResultType( sourceTensorType.getEncoding()); } +// TODO: This uses neither offsets nor strides! RankedTensorType ExtractSliceOp::inferResultType( RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) { diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp index 5aad671..1cba1bb 100644 --- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tosa/IR/TargetEnv.h" +#include "llvm/Support/FormatVariadic.h" namespace mlir { namespace tosa { @@ -27,7 +28,7 @@ TargetEnvAttr lookupTargetEnv(Operation *op) { } TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) { - return TargetEnvAttr::get(context, Level::eightK, + return TargetEnvAttr::get(context, SpecificationVersion::V_1_0, Level::eightK, {Profile::pro_int, Profile::pro_fp}, {}); } @@ -38,5 +39,9 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) { return getDefaultTargetEnv(op->getContext()); } +llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) { + return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor()); +} + } // namespace tosa } // namespace mlir diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp index bcb880a..a0661e4 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp @@ -61,8 +61,8 @@ public: ModuleOp mod = getOperation(); MLIRContext *ctx = &getContext(); - const auto targetEnvAttr = - TargetEnvAttr::get(ctx, level, selectedProfiles, selectedExtensions); + const auto targetEnvAttr = TargetEnvAttr::get( + ctx, specificationVersion, level, selectedProfiles, selectedExtensions); mod->setAttr(TargetEnvAttr::name, targetEnvAttr); } diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp index 20f9333..f072e3e 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp @@ -335,16 +335,15 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) { //===----------------------------------------------------------------------===// template <typename T> -FailureOr<SmallVector<T>> -TosaProfileCompliance::getOperatorDefinition(Operation *op, - CheckCondition &condition) { +FailureOr<OpComplianceInfo<T>> +TosaProfileCompliance::getOperatorDefinition(Operation *op) { const std::string opName = op->getName().getStringRef().str(); const auto complianceMap = getProfileComplianceMap<T>(); const auto it = complianceMap.find(opName); if (it == complianceMap.end()) return {}; - return findMatchedProfile<T>(op, it->second, condition); + return findMatchedEntry<T>(op, it->second); } template <typename T> @@ -356,22 +355,21 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension( if (specRequiredModeSet.size() == 0) return success(); - CheckCondition condition = CheckCondition::invalid; - const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition); - if (failed(maybeOpRequiredMode)) { + const auto maybeOpDefinition = getOperatorDefinition<T>(op); + if (failed(maybeOpDefinition)) { // Operators such as control-flow and shape ops do not have an operand type // restriction. When the profile compliance information of operation is not // found, confirm if the target have enabled the profile required from the // specification. - int mode_count = 0; + int modeCount = 0; for (const auto &cands : specRequiredModeSet) { if (targetEnv.allowsAnyOf(cands)) return success(); - mode_count += cands.size(); + modeCount += cands.size(); } op->emitOpError() << "illegal: requires" - << (mode_count > 1 ? " any of " : " ") << "[" + << (modeCount > 1 ? " any of " : " ") << "[" << llvm::join(stringifyProfile<T>(specRequiredModeSet), ", ") << "] but not enabled in target\n"; @@ -381,7 +379,10 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension( // Find the required profiles or extensions according to the operand type // combination. - const auto opRequiredMode = maybeOpRequiredMode.value(); + const auto opDefinition = maybeOpDefinition.value(); + const SmallVector<T> opRequiredMode = opDefinition.mode; + const CheckCondition condition = opDefinition.condition; + if (opRequiredMode.size() == 0) { // No matched restriction found. return success(); @@ -437,6 +438,21 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension( } } + // Ensure the matched op compliance version does not exceed the target + // specification version. + const VersionedTypeInfo versionedTypeInfo = + opDefinition.operandTypeInfoSet[0]; + const TosaSpecificationVersion complianceVersion{versionedTypeInfo.second}; + const TosaSpecificationVersion targetVersion{targetEnv.getSpecVersion()}; + if (!targetVersion.isBackwardsCompatibleWith(complianceVersion)) { + op->emitOpError() << "illegal: the target specification version (" + << stringifyVersion(targetVersion) + << ") is not backwards compatible with the op compliance " + "specification version (" + << stringifyVersion(complianceVersion) << ")\n"; + return failure(); + } + return success(); } @@ -461,14 +477,14 @@ TosaProfileCompliance::checkExtension(Operation *op, } LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) { - CheckCondition condition = CheckCondition::invalid; - const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition); - const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition); + const auto maybeProfDef = getOperatorDefinition<Profile>(op); + const auto maybeExtDef = getOperatorDefinition<Extension>(op); if (failed(maybeProfDef) && failed(maybeExtDef)) return success(); - const bool hasEntry = (succeeded(maybeProfDef) && !maybeProfDef->empty()) || - (succeeded(maybeExtDef) && !maybeExtDef->empty()); + const bool hasEntry = + (succeeded(maybeProfDef) && !maybeProfDef->mode.empty()) || + (succeeded(maybeExtDef) && !maybeExtDef->mode.empty()); if (!hasEntry) { std::string message; llvm::raw_string_ostream os(message); @@ -488,7 +504,9 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) { SmallVector<TypeInfo> bestTypeInfo; const auto searchBestMatch = [&](auto map) { for (const auto &complianceInfos : map[opName]) { - for (const auto &typeInfos : complianceInfos.operandTypeInfoSet) { + for (const auto &versionedTypeInfos : + complianceInfos.operandTypeInfoSet) { + const SmallVector<TypeInfo> typeInfos = versionedTypeInfos.first; const int matches = llvm::count_if( llvm::zip_equal(current, typeInfos), [&](const auto zipType) { return isSameTypeInfo(std::get<0>(zipType), @@ -520,9 +538,8 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) { // Find the profiles or extensions requirement according to the signature of // type of the operand list. template <typename T> -SmallVector<T> TosaProfileCompliance::findMatchedProfile( - Operation *op, SmallVector<OpComplianceInfo<T>> compInfo, - CheckCondition &condition) { +OpComplianceInfo<T> TosaProfileCompliance::findMatchedEntry( + Operation *op, SmallVector<OpComplianceInfo<T>> compInfo) { assert(compInfo.size() != 0 && "profile-based compliance information is empty"); @@ -533,27 +550,30 @@ SmallVector<T> TosaProfileCompliance::findMatchedProfile( return {}; for (size_t i = 0; i < compInfo.size(); i++) { - SmallVector<SmallVector<TypeInfo>> sets = compInfo[i].operandTypeInfoSet; - for (SmallVector<TypeInfo> expected : sets) { + SmallVector<VersionedTypeInfo> sets = compInfo[i].operandTypeInfoSet; + for (const auto &set : sets) { + SmallVector<TypeInfo> expected = set.first; assert(present.size() == expected.size() && "the entries for profile-based compliance do not match between " "the generated metadata and the type definition retrieved from " " the operation"); - bool is_found = true; + bool isFound = true; // Compare the type signature between the given operation and the // compliance metadata. for (size_t j = 0; j < expected.size(); j++) { if (!isSameTypeInfo(present[j], expected[j])) { // Verify the next mode set from the list. - is_found = false; + isFound = false; break; } } - if (is_found == true) { - condition = compInfo[i].condition; - return compInfo[i].mode; + if (isFound == true) { + SmallVector<VersionedTypeInfo> typeInfoSet{set}; + OpComplianceInfo<T> info{compInfo[i].mode, typeInfoSet, + compInfo[i].condition}; + return info; } } } diff --git a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp index 9a24c2b..a2cff6a 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp @@ -21,10 +21,10 @@ using namespace mlir; // These are automatically generated by ODS but are not used as the Transform // dialect uses a different dispatch mechanism to support dialect extensions. -LLVM_ATTRIBUTE_UNUSED static OptionalParseResult +[[maybe_unused]] static OptionalParseResult generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value); -LLVM_ATTRIBUTE_UNUSED static LogicalResult -generatedTypePrinter(Type def, AsmPrinter &printer); +[[maybe_unused]] static LogicalResult generatedTypePrinter(Type def, + AsmPrinter &printer); #define GET_TYPEDEF_CLASSES #include "mlir/Dialect/Transform/IR/TransformTypes.cpp.inc" diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 58256b0..45c54c7 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -7601,6 +7601,111 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, setResultRanges(getResult(), result); } +namespace { + +/// Fold `vector.step -> arith.cmpi` when the step value is compared to a +/// constant large enough such that the result is the same at all indices. +/// +/// For example, rewrite the 'greater than' comparison below, +/// +/// ```mlir +/// %cst = arith.constant dense<7> : vector<3xindex> +/// %stp = vector.step : vector<3xindex> +/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex> +/// ``` +/// +/// as, +/// +/// ```mlir +/// %out = arith.constant dense<false> : vector<3xi1>. +/// ``` +/// +/// Above `[0, 1, 2] > [7, 7, 7]` => `[false, false, false]`. Because the result +/// is false at ALL indices we fold. If the constant was 1, then +/// `[0, 1, 2] > [1, 1, 1]` => `[false, false, true]` and we do fold, +/// conservatively preferring the 'compact' vector.step representation. +/// +/// Note: this folder only works for the case where the constant (`%cst` above) +/// is the second operand of the comparison. The arith.cmpi canonicalizer will +/// ensure that constants are always second (on the right). +struct StepCompareFolder : public OpRewritePattern<StepOp> { + using Base::Base; + + LogicalResult matchAndRewrite(StepOp stepOp, + PatternRewriter &rewriter) const override { + const int64_t stepSize = stepOp.getResult().getType().getNumElements(); + + for (OpOperand &use : stepOp.getResult().getUses()) { + auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner()); + if (!cmpiOp) + continue; + + // arith.cmpi canonicalizer makes constants final operands. + const unsigned stepOperandNumber = use.getOperandNumber(); + if (stepOperandNumber != 0) + continue; + + // Check that operand 1 is a constant. + unsigned constOperandNumber = 1; + Value otherOperand = cmpiOp.getOperand(constOperandNumber); + std::optional<int64_t> maybeConstValue = + getConstantIntValue(otherOperand); + if (!maybeConstValue.has_value()) + continue; + + int64_t constValue = maybeConstValue.value(); + arith::CmpIPredicate pred = cmpiOp.getPredicate(); + + auto maybeSplat = [&]() -> std::optional<bool> { + // Handle ult (unsigned less than) and uge (unsigned greater equal). + if ((pred == arith::CmpIPredicate::ult || + pred == arith::CmpIPredicate::uge) && + stepSize <= constValue) + return pred == arith::CmpIPredicate::ult; + + // Handle ule and ugt. + if ((pred == arith::CmpIPredicate::ule || + pred == arith::CmpIPredicate::ugt) && + stepSize - 1 <= constValue) { + return pred == arith::CmpIPredicate::ule; + } + + // Handle eq and ne. + if ((pred == arith::CmpIPredicate::eq || + pred == arith::CmpIPredicate::ne) && + stepSize <= constValue) + return pred == arith::CmpIPredicate::ne; + + return std::nullopt; + }(); + + if (!maybeSplat.has_value()) + continue; + + rewriter.setInsertionPointAfter(cmpiOp); + + auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType()); + if (!type) + continue; + + auto boolAttr = DenseElementsAttr::get(type, maybeSplat.value()); + Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(), + type, boolAttr); + + rewriter.replaceOp(cmpiOp, splat); + return success(); + } + + return failure(); + } +}; +} // namespace + +void StepOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add<StepCompareFolder>(context); +} + //===----------------------------------------------------------------------===// // Vector Masking Utilities //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index e95338f..12e6475 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -928,17 +928,20 @@ struct WarpOpDeadResult : public WarpDistributionPattern { // Some values may be yielded multiple times and correspond to multiple // results. Deduplicating occurs by taking each result with its matching // yielded value, and: - // 1. recording the unique first position at which the value is yielded. + // 1. recording the unique first position at which the value with uses is + // yielded. // 2. recording for the result, the first position at which the dedup'ed // value is yielded. // 3. skipping from the new result types / new yielded values any result // that has no use or whose yielded value has already been seen. for (OpResult result : warpOp.getResults()) { + if (result.use_empty()) + continue; Value yieldOperand = yield.getOperand(result.getResultNumber()); auto it = dedupYieldOperandPositionMap.insert( std::make_pair(yieldOperand, newResultTypes.size())); dedupResultPositionMap.insert(std::make_pair(result, it.first->second)); - if (result.use_empty() || !it.second) + if (!it.second) continue; newResultTypes.push_back(result.getType()); newYieldValues.push_back(yieldOperand); @@ -1843,16 +1846,16 @@ struct WarpOpScfIfOp : public WarpDistributionPattern { newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(), escapingValueDistTypesElse.end()); - llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx; for (auto [idx, val] : llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) { - origToNewYieldIdx[idx] = newWarpOpYieldValues.size(); newWarpOpYieldValues.push_back(val); newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType()); } - // Create the new `WarpOp` with the updated yield values and types. - WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( - rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes); + // Replace the old `WarpOp` with the new one that has additional yield + // values and types. + SmallVector<size_t> newIndices; + WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices); // `ifOp` returns the result of the inner warp op. SmallVector<Type> newIfOpDistResTypes; for (auto [i, res] : llvm::enumerate(ifOp.getResults())) { @@ -1870,8 +1873,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern { OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(newWarpOp); auto newIfOp = scf::IfOp::create( - rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0), - static_cast<bool>(ifOp.thenBlock()), + rewriter, ifOp.getLoc(), newIfOpDistResTypes, + newWarpOp.getResult(newIndices[0]), static_cast<bool>(ifOp.thenBlock()), static_cast<bool>(ifOp.elseBlock())); auto encloseRegionInWarpOp = [&](Block *oldIfBranch, Block *newIfBranch, @@ -1888,7 +1891,7 @@ struct WarpOpScfIfOp : public WarpDistributionPattern { for (size_t i = 0; i < escapingValues.size(); ++i, ++warpResRangeStart) { innerWarpInputVals.push_back( - newWarpOp.getResult(warpResRangeStart)); + newWarpOp.getResult(newIndices[warpResRangeStart])); escapeValToBlockArgIndex[escapingValues[i]] = innerWarpInputTypes.size(); innerWarpInputTypes.push_back(escapingValueInputTypes[i]); @@ -1936,17 +1939,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern { // Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp` // result. for (auto [origIdx, newIdx] : ifResultMapping) - rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx), + rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx), newIfOp.getResult(newIdx), newIfOp); - // Similarly, update any users of the `WarpOp` results that were not - // results of the `IfOp`. - for (auto [origIdx, newIdx] : origToNewYieldIdx) - rewriter.replaceAllUsesWith(warpOp.getResult(origIdx), - newWarpOp.getResult(newIdx)); - // Remove the original `WarpOp` and `IfOp`, they should not have any uses - // at this point. - rewriter.eraseOp(ifOp); - rewriter.eraseOp(warpOp); return success(); } @@ -2065,19 +2059,16 @@ struct WarpOpScfForOp : public WarpDistributionPattern { escapingValueDistTypes.begin(), escapingValueDistTypes.end()); // Next, we insert all non-`ForOp` yielded values and their distributed - // types. We also create a mapping between the non-`ForOp` yielded value - // index and the corresponding new `WarpOp` yield value index (needed to - // update users later). - llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping; + // types. for (auto [i, v] : llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) { - nonForResultMapping[i] = newWarpOpYieldValues.size(); newWarpOpYieldValues.push_back(v); newWarpOpDistTypes.push_back(warpOp.getResult(i).getType()); } // Create the new `WarpOp` with the updated yield values and types. - WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( - rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes); + SmallVector<size_t> newIndices; + WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices); // Next, we create a new `ForOp` with the init args yielded by the new // `WarpOp`. @@ -2086,7 +2077,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern { // escaping values in the new `WarpOp`. SmallVector<Value> newForOpOperands; for (size_t i = 0; i < escapingValuesStartIdx; ++i) - newForOpOperands.push_back(newWarpOp.getResult(i)); + newForOpOperands.push_back(newWarpOp.getResult(newIndices[i])); // Create a new `ForOp` outside the new `WarpOp` region. OpBuilder::InsertionGuard g(rewriter); @@ -2110,7 +2101,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern { llvm::SmallDenseMap<Value, int64_t> argIndexMapping; for (size_t i = escapingValuesStartIdx; i < escapingValuesStartIdx + escapingValues.size(); ++i) { - innerWarpInput.push_back(newWarpOp.getResult(i)); + innerWarpInput.push_back(newWarpOp.getResult(newIndices[i])); argIndexMapping[escapingValues[i - escapingValuesStartIdx]] = innerWarpInputType.size(); innerWarpInputType.push_back( @@ -2146,20 +2137,11 @@ struct WarpOpScfForOp : public WarpDistributionPattern { if (!innerWarp.getResults().empty()) scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults()); - // Update the users of original `WarpOp` results that were coming from the + // Update the users of the new `WarpOp` results that were coming from the // original `ForOp` to the corresponding new `ForOp` result. for (auto [origIdx, newIdx] : forResultMapping) - rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx), + rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx), newForOp.getResult(newIdx), newForOp); - // Similarly, update any users of the `WarpOp` results that were not - // results of the `ForOp`. - for (auto [origIdx, newIdx] : nonForResultMapping) - rewriter.replaceAllUsesWith(warpOp.getResult(origIdx), - newWarpOp.getResult(newIdx)); - // Remove the original `WarpOp` and `ForOp`, they should not have any uses - // at this point. - rewriter.eraseOp(forOp); - rewriter.eraseOp(warpOp); // Update any users of escaping values that were forwarded to the // inner `WarpOp`. These values are now arguments of the inner `WarpOp`. newForOp.walk([&](Operation *op) { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index 14639c5..fbae098 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -465,26 +465,33 @@ struct UnrollElementwisePattern : public RewritePattern { auto targetShape = getTargetShape(options, op); if (!targetShape) return failure(); + int64_t targetShapeRank = targetShape->size(); auto dstVecType = cast<VectorType>(op->getResult(0).getType()); SmallVector<int64_t> originalSize = *cast<VectorUnrollOpInterface>(op).getShapeForUnroll(); - // Bail-out if rank(source) != rank(target). The main limitation here is the - // fact that `ExtractStridedSlice` requires the rank for the input and - // output to match. If needed, we can relax this later. - if (originalSize.size() != targetShape->size()) - return rewriter.notifyMatchFailure( - op, "expected input vector rank to match target shape rank"); + int64_t originalShapeRank = originalSize.size(); + Location loc = op->getLoc(); + + // Handle rank mismatch by adding leading unit dimensions to targetShape + SmallVector<int64_t> adjustedTargetShape(originalShapeRank); + int64_t rankDiff = originalShapeRank - targetShapeRank; + std::fill(adjustedTargetShape.begin(), + adjustedTargetShape.begin() + rankDiff, 1); + std::copy(targetShape->begin(), targetShape->end(), + adjustedTargetShape.begin() + rankDiff); + + int64_t adjustedTargetShapeRank = adjustedTargetShape.size(); // Prepare the result vector. Value result = arith::ConstantOp::create(rewriter, loc, dstVecType, rewriter.getZeroAttr(dstVecType)); - SmallVector<int64_t> strides(targetShape->size(), 1); - VectorType newVecType = + SmallVector<int64_t> strides(adjustedTargetShapeRank, 1); + VectorType unrolledVecType = VectorType::get(*targetShape, dstVecType.getElementType()); // Create the unrolled computation. for (SmallVector<int64_t> offsets : - StaticTileOffsetRange(originalSize, *targetShape)) { + StaticTileOffsetRange(originalSize, adjustedTargetShape)) { SmallVector<Value> extractOperands; for (OpOperand &operand : op->getOpOperands()) { auto vecType = dyn_cast<VectorType>(operand.get().getType()); @@ -492,14 +499,31 @@ struct UnrollElementwisePattern : public RewritePattern { extractOperands.push_back(operand.get()); continue; } - extractOperands.push_back( - rewriter.createOrFold<vector::ExtractStridedSliceOp>( - loc, operand.get(), offsets, *targetShape, strides)); + Value extracted = rewriter.createOrFold<vector::ExtractStridedSliceOp>( + loc, operand.get(), offsets, adjustedTargetShape, strides); + + // Reshape to remove leading unit dims if needed + if (adjustedTargetShapeRank > targetShapeRank) { + extracted = rewriter.createOrFold<vector::ShapeCastOp>( + loc, VectorType::get(*targetShape, vecType.getElementType()), + extracted); + } + extractOperands.push_back(extracted); } + Operation *newOp = cloneOpWithOperandsAndTypes( - rewriter, loc, op, extractOperands, newVecType); + rewriter, loc, op, extractOperands, unrolledVecType); + + Value computeResult = newOp->getResult(0); + + // Use strides sized to targetShape for proper insertion + SmallVector<int64_t> insertStrides = + (adjustedTargetShapeRank > targetShapeRank) + ? SmallVector<int64_t>(targetShapeRank, 1) + : strides; + result = rewriter.createOrFold<vector::InsertStridedSliceOp>( - loc, newOp->getResult(0), result, offsets, strides); + loc, computeResult, result, offsets, insertStrides); } rewriter.replaceOp(op, result); return success(); diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index 025ee9a..c809c502 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -91,7 +91,7 @@ mlir::vector::isTranspose2DSlice(vector::TransposeOp op) { // Check whether the two source vector dimensions that are greater than one // must be transposed with each other so that we can apply one of the 2-D - // transpose pattens. Otherwise, these patterns are not applicable. + // transpose patterns. Otherwise, these patterns are not applicable. if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1], op.getPermutation())) return failure(); diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 9beb22d..1599ae9 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -727,6 +727,152 @@ void MemLayoutAttr::print(AsmPrinter &printer) const { } printer << ">"; } +// a helper utility to perform binary operation on OpFoldResult. +// If both a and b are attributes, it will simply return the result. +// Otherwise, the corresponding arith op will be generated, and an +// contant op will be created if one of them is an attribute. +template <typename ArithOp> +OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc, + OpBuilder &builder) { + auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a); + auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b); + return builder.create<ArithOp>(loc, aVal, bVal).getResult(); +} + +// a helper utility to perform division operation on OpFoldResult and int64_t. +#define div(a, b) \ + genBinOp<arith::DivSIOp>(a, builder.getIndexAttr(b), loc, builder) + +// a helper utility to perform reminder operation on OpFoldResult and int64_t. +#define rem(a, b) \ + genBinOp<arith::RemSIOp>(a, builder.getIndexAttr(b), loc, builder) + +// a helper utility to perform multiply operation on OpFoldResult and int64_t. +#define mul(a, b) \ + genBinOp<arith::MulIOp>(a, builder.getIndexAttr(b), loc, builder) + +// a helper utility to perform addition operation on two OpFoldResult. +#define add(a, b) genBinOp<arith::AddIOp>(a, b, loc, builder) + +// block the given offsets according to the block shape +// say the original offset is [y, x], and the block shape is [By, Bx], +// then the blocked offset is [y/By, x/Bx, y%By, x%Bx] +SmallVector<OpFoldResult> getBlockedOffsets(OpBuilder &builder, Location loc, + ArrayRef<OpFoldResult> offsets, + ArrayRef<int64_t> blockShape) { + + assert(offsets.size() == blockShape.size() && + "offsets and blockShape must have the same size"); + SmallVector<OpFoldResult> blockedOffsets; + SmallVector<OpFoldResult> divs, rems; + + for (auto [offset, block] : llvm::zip(offsets, blockShape)) { + divs.push_back(div(offset, block)); + rems.push_back(rem(offset, block)); + } + blockedOffsets.append(divs.begin(), divs.end()); + blockedOffsets.append(rems.begin(), rems.end()); + + return blockedOffsets; +} + +// Get strides as vector of integer for MemDesc. +SmallVector<int64_t> MemDescType::getStrideShape() { + + SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end()); + + ArrayAttr strideAttr = getStrideAttr(); + SmallVector<int64_t> strides; + for (Attribute attr : strideAttr.getValue()) { + strides.push_back(cast<IntegerAttr>(attr).getInt()); + } + + SmallVector<int64_t> innerBlkShape = getBlockShape(); + + // get perm from FCD to LCD + // perm[i] = the dim with i-th smallest stride + SmallVector<int, 4> perm = + llvm::to_vector<4>(llvm::seq<int>(0, strides.size())); + llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; }); + + assert(strides[perm[0]] == 1 && "inner most dim must have stride 1"); + + SmallVector<int64_t> innerBlkStride(innerBlkShape.size()); + innerBlkStride[perm[0]] = 1; + for (size_t i = 1; i < perm.size(); ++i) + innerBlkStride[perm[i]] = + innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]]; + + // compute the original matrix shape using the stride info + // and compute the number of blocks in each dimension + // The shape of highest dim can't be derived from stride info, + // but doesn't impact the stride computation for blocked layout. + SmallVector<int64_t> matrixShapeOrig(matrixShape.size()); + SmallVector<int64_t> BlkShapeOrig(matrixShape.size()); + for (size_t i = 0; i < perm.size() - 1; ++i) { + matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]]; + BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]]; + } + + int64_t innerBlkSize = 1; + for (auto s : innerBlkShape) + innerBlkSize *= s; + + SmallVector<int64_t> outerBlkStride(matrixShape.size()); + outerBlkStride[perm[0]] = innerBlkSize; + for (size_t i = 0; i < perm.size() - 1; ++i) { + outerBlkStride[perm[i + 1]] = + outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]]; + } + + // combine the inner and outer strides + SmallVector<int64_t> blockedStrides; + blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end()); + blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end()); + + return blockedStrides; +} + +// Calculate the linear offset using the blocked offsets and stride +Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc, + ArrayRef<OpFoldResult> offsets) { + + SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end()); + SmallVector<int64_t> blockShape = getBlockShape(); + SmallVector<int64_t> strides = getStrideShape(); + SmallVector<OpFoldResult> blockedOffsets; + + // blockshape equal to matrixshape means no blocking + if (llvm::equal(blockShape, matrixShape)) { + // remove the outer dims from strides + strides.erase(strides.begin(), strides.begin() + matrixShape.size()); + } else { + assert(offsets.size() == blockShape.size() && + "offsets and blockShape must have the same size"); + // say the original offset is [y, x], and the block shape is [By, Bx], + // then the blocked offset is [y/By, x/Bx, y%By, x%Bx] + + SmallVector<OpFoldResult> divs, rems; + + for (auto [offset, block] : llvm::zip(offsets, blockShape)) { + divs.push_back(div(offset, block)); + rems.push_back(rem(offset, block)); + } + blockedOffsets.append(divs.begin(), divs.end()); + blockedOffsets.append(rems.begin(), rems.end()); + offsets = blockedOffsets; + } + + // Start with initial value as matrix descriptor's base offset. + Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0); + for (size_t i = 0; i < offsets.size(); ++i) { + OpFoldResult mulResult = mul(offsets[i], strides[i]); + Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult); + linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset); + } + + return linearOffset; +} } // namespace xegpu } // namespace mlir diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 81b5788..abd12e2 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -20,8 +20,8 @@ #define DEBUG_TYPE "xegpu" -namespace mlir { -namespace xegpu { +using namespace mlir; +using namespace mlir::xegpu; static bool isSharedMemory(const MemRefType &memrefTy) { Attribute attr = memrefTy.getMemorySpace(); @@ -173,6 +173,49 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy, return success(); } +LogicalResult +IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy, + UnitAttr subgroup_block_io, + function_ref<InFlightDiagnostic()> emitError) { + + if (!dataTy) { + if (subgroup_block_io) + return emitError() << "subgroup_block_io " + "are only allowed when result is a 1D VectorType."; + else + return success(); + } + + if (mdescTy.getRank() != 2) + return emitError() << "mem_desc must be 2D."; + + ArrayRef<int64_t> dataShape = dataTy.getShape(); + ArrayRef<int64_t> mdescShape = mdescTy.getShape(); + + if (dataShape.size() == 2) { + if (subgroup_block_io) + return emitError() << "subgroup_block_io " + "are only allowed when result is a 1D VectorType."; + if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape), + [](auto p) { return std::get<0>(p) > std::get<1>(p); })) + return emitError() << "data shape must not exceed mem_desc shape."; + } else { + SmallVector<int64_t> blockShape = mdescTy.getBlockShape(); + // if the subgroup_block_io attribute is set, mdescTy must have block + // attribute + if (subgroup_block_io && !blockShape.size()) + return emitError() << "mem_desc must have block attribute when " + "subgroup_block_io is set."; + // if the subgroup_block_io attribute is set, the memdesc should be row + // major + if (subgroup_block_io && mdescTy.isColMajor()) + return emitError() << "mem_desc should be row major when " + "subgroup_block_io is set."; + } + + return success(); +} + //===----------------------------------------------------------------------===// // XeGPU_CreateNdDescOp //===----------------------------------------------------------------------===// @@ -1049,23 +1092,20 @@ void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res, llvm::SmallVector<int64_t> staticOffsets; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); + // Call the generated builder with all parameters (including optional ones as + // nullptr/empty) build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr, - layout); + /*subgroup_block_io=*/nullptr, layout); } LogicalResult LoadMatrixOp::verify() { - VectorType resTy = getRes().getType(); - MemDescType mdescTy = getMemDesc().getType(); - if (mdescTy.getRank() != 2) - return emitOpError("mem_desc must be 2D."); + auto resTy = dyn_cast<VectorType>(getRes().getType()); + UnitAttr subgroup_block_io = getSubgroupBlockIoAttr(); + MemDescType mdescTy = getMemDesc().getType(); - ArrayRef<int64_t> valueShape = resTy.getShape(); - ArrayRef<int64_t> mdescShape = mdescTy.getShape(); - if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape), - [](auto p) { return std::get<0>(p) > std::get<1>(p); })) - return emitOpError("result shape must not exceed mem_desc shape."); - return success(); + return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io, + [&]() { return emitError(); }); } //===----------------------------------------------------------------------===// @@ -1080,62 +1120,18 @@ void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data, dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr, - layout); + /*subgroup_block_io=*/nullptr, layout); } LogicalResult StoreMatrixOp::verify() { - VectorType dataTy = getData().getType(); - MemDescType mdescTy = getMemDesc().getType(); - - if (mdescTy.getRank() != 2) - return emitOpError("mem_desc must be 2D."); - - ArrayRef<int64_t> dataShape = dataTy.getShape(); - ArrayRef<int64_t> mdescShape = mdescTy.getShape(); - if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape), - [](auto p) { return std::get<0>(p) > std::get<1>(p); })) - return emitOpError("data shape must not exceed mem_desc shape."); - - return success(); -} - -//===----------------------------------------------------------------------===// -// XeGPU_MemDescSubviewOp -//===----------------------------------------------------------------------===// - -void MemDescSubviewOp::build(OpBuilder &builder, OperationState &state, - Type resTy, Value src, - llvm::ArrayRef<OpFoldResult> offsets) { - llvm::SmallVector<Value> dynamicOffsets; - llvm::SmallVector<int64_t> staticOffsets; - dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); - auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); - build(builder, state, resTy, src, dynamicOffsets, staticOffsetsAttr); -} - -LogicalResult MemDescSubviewOp::verify() { - MemDescType srcTy = getSrc().getType(); - MemDescType resTy = getRes().getType(); - ArrayRef<int64_t> srcShape = srcTy.getShape(); - ArrayRef<int64_t> resShape = resTy.getShape(); - - if (srcTy.getRank() < resTy.getRank()) - return emitOpError("result rank must not exceed source rank."); - if (llvm::any_of( - llvm::zip_equal(resShape, srcShape.take_back(resShape.size())), - [](auto p) { return std::get<0>(p) > std::get<1>(p); })) - return emitOpError("result shape must not exceed source shape."); - - if (srcTy.getStrides() != resTy.getStrides()) - return emitOpError("result must inherit the source strides."); - - return success(); + auto dataTy = dyn_cast<VectorType>(getData().getType()); + UnitAttr subgroup_block_io = getSubgroupBlockIoAttr(); + MemDescType mdescTy = getMemDesc().getType(); + return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io, + [&]() { return emitError(); }); } -} // namespace xegpu -} // namespace mlir - namespace mlir { #include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc> } // namespace mlir diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp index 36c498e..f77784a 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp @@ -161,11 +161,24 @@ XeGPUBlockingPass::getTileShape(Operation *op) const { xegpu::UpdateOffsetOp, xegpu::LoadMatrixOp>(op)) return getTileShape(op->getOpResult(0)); if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp, - xegpu::LoadGatherOp, xegpu::StoreMatrixOp>(op)) + xegpu::StoreMatrixOp>(op)) return getTileShape(op->getOpOperand(0)); - if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(op)) + if (isa<xegpu::StoreNdOp>(op)) return getTileShape(op->getOpOperand(1)); + // Handle LoadGatherOp and StoreScatterOp (with and without offset) + if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(op)) { + if (loadGatherOp.getOffsets()) + return getTileShape(loadGatherOp->getOpResult(0)); + else + return getTileShape(loadGatherOp->getOpOperand(0)); + } + + if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op)) + return getTileShape(storeScatterOp.getOffsets() + ? storeScatterOp->getOpOperand(0) + : storeScatterOp->getOpOperand(1)); + if (isa<xegpu::DpasOp>(op)) { std::optional<SmallVector<int64_t>> aTile = getTileShape(op->getOpOperand(0)); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index a178d0f..aafa1b7 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -941,7 +941,9 @@ struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> { LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - VectorType valueTy = op.getType(); + VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType()); + assert(valueTy && "the value type must be vector type!"); + std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op); if (!targetShape || targetShape->size() != (size_t)valueTy.getRank()) return failure(); @@ -984,7 +986,8 @@ struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> { return failure(); Location loc = op.getLoc(); - VectorType valueTy = op.getData().getType(); + VectorType valueTy = llvm::dyn_cast<VectorType>(op.getData().getType()); + assert(valueTy && "the value type must be vector type!"); ArrayRef<int64_t> shape = valueTy.getShape(); auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr()); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index c28d2fc..31a967d 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -991,7 +991,8 @@ struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> { return failure(); ArrayRef<int64_t> wgShape = op.getDataShape(); - VectorType valueTy = op.getRes().getType(); + VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType()); + assert(valueTy && "the value type must be vector type!"); Type elemTy = valueTy.getElementType(); xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 3d19c5a..9b23dd6 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -2200,10 +2200,9 @@ void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty, os << '>'; } os << '['; - interleave( - loc.getLocations(), - [&](Location loc) { printLocationInternal(loc, pretty); }, - [&]() { os << ", "; }); + interleaveComma(loc.getLocations(), [&](Location loc) { + printLocationInternal(loc, pretty); + }); os << ']'; }) .Default([&](LocationAttr loc) { diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp index 776b5c6..4d81918 100644 --- a/mlir/lib/IR/Diagnostics.cpp +++ b/mlir/lib/IR/Diagnostics.cpp @@ -378,8 +378,10 @@ struct SourceMgrDiagnosticHandlerImpl { } // Otherwise, try to load the source file. - std::string ignored; - unsigned id = mgr.AddIncludeFile(std::string(filename), SMLoc(), ignored); + auto bufferOrErr = llvm::MemoryBuffer::getFile(filename); + if (!bufferOrErr) + return 0; + unsigned id = mgr.AddNewSourceBuffer(std::move(*bufferOrErr), SMLoc()); filenameToBufId[filename] = id; return id; } diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 1fa04ed..5f63fe6 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -121,6 +121,11 @@ namespace mlir { class MLIRContextImpl { public: //===--------------------------------------------------------------------===// + // Remark + //===--------------------------------------------------------------------===// + std::unique_ptr<remark::detail::RemarkEngine> remarkEngine; + + //===--------------------------------------------------------------------===// // Debugging //===--------------------------------------------------------------------===// @@ -135,11 +140,6 @@ public: DiagnosticEngine diagEngine; //===--------------------------------------------------------------------===// - // Remark - //===--------------------------------------------------------------------===// - std::unique_ptr<remark::detail::RemarkEngine> remarkEngine; - - //===--------------------------------------------------------------------===// // Options //===--------------------------------------------------------------------===// @@ -357,7 +357,10 @@ MLIRContext::MLIRContext(const DialectRegistry ®istry, Threading setting) impl->affineUniquer.registerParametricStorageType<IntegerSetStorage>(); } -MLIRContext::~MLIRContext() = default; +MLIRContext::~MLIRContext() { + // finalize remark engine before destroying anything else. + impl->remarkEngine.reset(); +} /// Copy the specified array of elements into memory managed by the provided /// bump pointer allocator. This assumes the elements are all PODs. @@ -1201,7 +1204,7 @@ AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount, /// present in result expressions is less than `dimCount` and the highest index /// of symbolic identifier present in result expressions is less than /// `symbolCount`. -LLVM_ATTRIBUTE_UNUSED static bool +[[maybe_unused]] static bool willBeValidAffineMap(unsigned dimCount, unsigned symbolCount, ArrayRef<AffineExpr> results) { int64_t maxDimPosition = -1; diff --git a/mlir/lib/IR/Remarks.cpp b/mlir/lib/IR/Remarks.cpp index a55f61a..031eae2 100644 --- a/mlir/lib/IR/Remarks.cpp +++ b/mlir/lib/IR/Remarks.cpp @@ -16,7 +16,7 @@ #include "llvm/ADT/StringRef.h" using namespace mlir::remark::detail; - +using namespace mlir::remark; //------------------------------------------------------------------------------ // Remark //------------------------------------------------------------------------------ @@ -70,7 +70,7 @@ static void printArgs(llvm::raw_ostream &os, llvm::ArrayRef<Remark::Arg> args) { void Remark::print(llvm::raw_ostream &os, bool printLocation) const { // Header: [Type] pass:remarkName StringRef type = getRemarkTypeString(); - StringRef categoryName = getFullCategoryName(); + StringRef categoryName = getCombinedCategoryName(); StringRef name = remarkName; os << '[' << type << "] "; @@ -81,9 +81,10 @@ void Remark::print(llvm::raw_ostream &os, bool printLocation) const { os << "Function=" << getFunction() << " | "; if (printLocation) { - if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(getLocation())) + if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(getLocation())) { os << " @" << flc.getFilename() << ":" << flc.getLine() << ":" << flc.getColumn(); + } } printArgs(os, getArgs()); @@ -140,7 +141,7 @@ llvm::remarks::Remark Remark::generateRemark() const { r.RemarkType = getRemarkType(); r.RemarkName = getRemarkName(); // MLIR does not use passes; instead, it has categories and sub-categories. - r.PassName = getFullCategoryName(); + r.PassName = getCombinedCategoryName(); r.FunctionName = getFunction(); r.Loc = locLambda(); for (const Remark::Arg &arg : getArgs()) { @@ -225,26 +226,42 @@ InFlightRemark RemarkEngine::emitOptimizationRemarkAnalysis(Location loc, // RemarkEngine //===----------------------------------------------------------------------===// -void RemarkEngine::report(const Remark &&remark) { +void RemarkEngine::reportImpl(const Remark &remark) { // Stream the remark - if (remarkStreamer) + if (remarkStreamer) { remarkStreamer->streamOptimizationRemark(remark); + } // Print using MLIR's diagnostic if (printAsEmitRemarks) emitRemark(remark.getLocation(), remark.getMsg()); } +void RemarkEngine::report(const Remark &&remark) { + if (remarkEmittingPolicy) + remarkEmittingPolicy->reportRemark(remark); +} + RemarkEngine::~RemarkEngine() { + if (remarkEmittingPolicy) + remarkEmittingPolicy->finalize(); + if (remarkStreamer) remarkStreamer->finalize(); } -llvm::LogicalResult -RemarkEngine::initialize(std::unique_ptr<MLIRRemarkStreamerBase> streamer, - std::string *errMsg) { - // If you need to validate categories/filters, do so here and set errMsg. +llvm::LogicalResult RemarkEngine::initialize( + std::unique_ptr<MLIRRemarkStreamerBase> streamer, + std::unique_ptr<RemarkEmittingPolicyBase> remarkEmittingPolicy, + std::string *errMsg) { + remarkStreamer = std::move(streamer); + + auto reportFunc = + std::bind(&RemarkEngine::reportImpl, this, std::placeholders::_1); + remarkEmittingPolicy->initialize(ReportFn(std::move(reportFunc))); + + this->remarkEmittingPolicy = std::move(remarkEmittingPolicy); return success(); } @@ -301,14 +318,15 @@ RemarkEngine::RemarkEngine(bool printAsEmitRemarks, } llvm::LogicalResult mlir::remark::enableOptimizationRemarks( - MLIRContext &ctx, - std::unique_ptr<remark::detail::MLIRRemarkStreamerBase> streamer, - const remark::RemarkCategories &cats, bool printAsEmitRemarks) { + MLIRContext &ctx, std::unique_ptr<detail::MLIRRemarkStreamerBase> streamer, + std::unique_ptr<detail::RemarkEmittingPolicyBase> remarkEmittingPolicy, + const RemarkCategories &cats, bool printAsEmitRemarks) { auto engine = - std::make_unique<remark::detail::RemarkEngine>(printAsEmitRemarks, cats); + std::make_unique<detail::RemarkEngine>(printAsEmitRemarks, cats); std::string errMsg; - if (failed(engine->initialize(std::move(streamer), &errMsg))) { + if (failed(engine->initialize(std::move(streamer), + std::move(remarkEmittingPolicy), &errMsg))) { llvm::report_fatal_error( llvm::Twine("Failed to initialize remark engine. Error: ") + errMsg); } @@ -316,3 +334,12 @@ llvm::LogicalResult mlir::remark::enableOptimizationRemarks( return success(); } + +//===----------------------------------------------------------------------===// +// Remark emitting policies +//===----------------------------------------------------------------------===// + +namespace mlir::remark { +RemarkEmittingPolicyAll::RemarkEmittingPolicyAll() = default; +RemarkEmittingPolicyFinal::RemarkEmittingPolicyFinal() = default; +} // namespace mlir::remark diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt index 388de1c..f96af02 100644 --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -9,6 +9,7 @@ set(LLVM_OPTIONAL_SOURCES FunctionInterfaces.cpp IndexingMapOpInterface.cpp InferIntRangeInterface.cpp + InferStridedMetadataInterface.cpp InferTypeOpInterface.cpp LoopLikeInterface.cpp MemOpInterfaces.cpp @@ -64,6 +65,21 @@ add_mlir_library(MLIRFunctionInterfaces add_mlir_interface_library(IndexingMapOpInterface) add_mlir_interface_library(InferIntRangeInterface) + +add_mlir_library(MLIRInferStridedMetadataInterface + InferStridedMetadataInterface.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces + + DEPENDS + MLIRInferStridedMetadataInterfaceIncGen + + LINK_LIBS PUBLIC + MLIRInferIntRangeInterface + MLIRIR +) + add_mlir_interface_library(InferTypeOpInterface) add_mlir_library(MLIRLoopLikeInterface diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp index 9f3e97d..84fc9b8 100644 --- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp +++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp @@ -146,6 +146,25 @@ raw_ostream &mlir::operator<<(raw_ostream &os, const IntegerValueRange &range) { return os; } +SmallVector<IntegerValueRange> +mlir::getIntValueRanges(ArrayRef<OpFoldResult> values, + GetIntRangeFn getIntRange, int32_t indexBitwidth) { + SmallVector<IntegerValueRange> ranges; + ranges.reserve(values.size()); + for (OpFoldResult ofr : values) { + if (auto value = dyn_cast<Value>(ofr)) { + ranges.push_back(getIntRange(value)); + continue; + } + + // Create a constant range. + auto attr = cast<IntegerAttr>(cast<Attribute>(ofr)); + ranges.emplace_back(ConstantIntRanges::constant( + attr.getValue().sextOrTrunc(indexBitwidth))); + } + return ranges; +} + void mlir::intrange::detail::defaultInferResultRanges( InferIntRangeInterface interface, ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRanges) { diff --git a/mlir/lib/Interfaces/InferStridedMetadataInterface.cpp b/mlir/lib/Interfaces/InferStridedMetadataInterface.cpp new file mode 100644 index 0000000..483e9f1 --- /dev/null +++ b/mlir/lib/Interfaces/InferStridedMetadataInterface.cpp @@ -0,0 +1,36 @@ +//===- InferStridedMetadataInterface.cpp - Strided md inference interface -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/InferStridedMetadataInterface.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include <optional> + +using namespace mlir; + +#include "mlir/Interfaces/InferStridedMetadataInterface.cpp.inc" + +void StridedMetadataRange::print(raw_ostream &os) const { + if (isUninitialized()) { + os << "strided_metadata<None>"; + return; + } + os << "strided_metadata<offset = ["; + llvm::interleaveComma(*offsets, os, [&](const ConstantIntRanges &range) { + os << "{" << range << "}"; + }); + os << "], sizes = ["; + llvm::interleaveComma(sizes, os, [&](const ConstantIntRanges &range) { + os << "{" << range << "}"; + }); + os << "], strides = ["; + llvm::interleaveComma(strides, os, [&](const ConstantIntRanges &range) { + os << "{" << range << "}"; + }); + os << "]>"; +} diff --git a/mlir/lib/Remark/RemarkStreamer.cpp b/mlir/lib/Remark/RemarkStreamer.cpp index d213a1a..bf36286 100644 --- a/mlir/lib/Remark/RemarkStreamer.cpp +++ b/mlir/lib/Remark/RemarkStreamer.cpp @@ -60,6 +60,7 @@ void LLVMRemarkStreamer::finalize() { namespace mlir::remark { LogicalResult enableOptimizationRemarksWithLLVMStreamer( MLIRContext &ctx, StringRef path, llvm::remarks::Format fmt, + std::unique_ptr<detail::RemarkEmittingPolicyBase> remarkEmittingPolicy, const RemarkCategories &cat, bool printAsEmitRemarks) { FailureOr<std::unique_ptr<detail::MLIRRemarkStreamerBase>> sOr = @@ -67,7 +68,8 @@ LogicalResult enableOptimizationRemarksWithLLVMStreamer( if (failed(sOr)) return failure(); - return remark::enableOptimizationRemarks(ctx, std::move(*sOr), cat, + return remark::enableOptimizationRemarks(ctx, std::move(*sOr), + std::move(remarkEmittingPolicy), cat, printAsEmitRemarks); } diff --git a/mlir/lib/TableGen/CodeGenHelpers.cpp b/mlir/lib/TableGen/CodeGenHelpers.cpp index cb90ef8..d52d5e7 100644 --- a/mlir/lib/TableGen/CodeGenHelpers.cpp +++ b/mlir/lib/TableGen/CodeGenHelpers.cpp @@ -49,9 +49,7 @@ StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter( raw_ostream &os, const RecordKeeper &records, StringRef tag) : os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {} -void StaticVerifierFunctionEmitter::emitOpConstraints( - ArrayRef<const Record *> opDefs) { - NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace()); +void StaticVerifierFunctionEmitter::emitOpConstraints() { emitTypeConstraints(); emitAttrConstraints(); emitPropConstraints(); diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 5fe5f41..1243511 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -357,11 +357,6 @@ static bool shouldBeInlined(ExpressionOp expressionOp) { if (expressionOp.getDoNotInline()) return false; - // Do not inline expressions with side effects to prevent side-effect - // reordering. - if (expressionOp.hasSideEffects()) - return false; - // Do not inline expressions with multiple uses. Value result = expressionOp.getResult(); if (!result.hasOneUse()) @@ -377,7 +372,34 @@ static bool shouldBeInlined(ExpressionOp expressionOp) { // Do not inline expressions used by other expressions or by ops with the // CExpressionInterface. If this was intended, the user could have been merged // into the expression op. - return !isa<emitc::ExpressionOp, emitc::CExpressionInterface>(*user); + if (isa<emitc::ExpressionOp, emitc::CExpressionInterface>(*user)) + return false; + + // Expressions with no side-effects can safely be inlined. + if (!expressionOp.hasSideEffects()) + return true; + + // Expressions with side-effects can be only inlined if side-effect ordering + // in the program is provably retained. + + // Require the user to immediately follow the expression. + if (++Block::iterator(expressionOp) != Block::iterator(user)) + return false; + + // These single-operand ops are safe. + if (isa<emitc::IfOp, emitc::SwitchOp, emitc::ReturnOp>(user)) + return true; + + // For assignment look for specific cases to inline as evaluation order of + // its lvalue and rvalue is undefined in C. + if (auto assignOp = dyn_cast<emitc::AssignOp>(user)) { + // Inline if this assignment is of the form `<var> = <expression>`. + if (expressionOp.getResult() == assignOp.getValue() && + isa_and_present<VariableOp>(assignOp.getVar().getDefiningOp())) + return true; + } + + return false; } static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp index 4bbcd8e..db39c70 100644 --- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp +++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp @@ -34,11 +34,9 @@ Location DebugImporter::translateFuncLocation(llvm::Function *func) { return UnknownLoc::get(context); // Add a fused location to link the subprogram information. - StringAttr funcName = StringAttr::get(context, subprogram->getName()); StringAttr fileName = StringAttr::get(context, subprogram->getFilename()); return FusedLocWith<DISubprogramAttr>::get( - {NameLoc::get(funcName), - FileLineColLoc::get(fileName, subprogram->getLine(), /*column=*/0)}, + {FileLineColLoc::get(fileName, subprogram->getLine(), /*column=*/0)}, translate(subprogram), context); } diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 9603813..857e31b 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -2604,6 +2604,7 @@ static constexpr std::array kExplicitLLVMFuncOpAttributes{ StringLiteral("denormal-fp-math-f32"), StringLiteral("fp-contract"), StringLiteral("frame-pointer"), + StringLiteral("inlinehint"), StringLiteral("instrument-function-entry"), StringLiteral("instrument-function-exit"), StringLiteral("memory"), @@ -2643,6 +2644,8 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func, funcOp.setNoInline(true); if (func->hasFnAttribute(llvm::Attribute::AlwaysInline)) funcOp.setAlwaysInline(true); + if (func->hasFnAttribute(llvm::Attribute::InlineHint)) + funcOp.setInlineHint(true); if (func->hasFnAttribute(llvm::Attribute::OptimizeNone)) funcOp.setOptimizeNone(true); if (func->hasFnAttribute(llvm::Attribute::Convergent)) diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 845a14f..147613f 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -1652,6 +1652,8 @@ static void convertFunctionAttributes(LLVMFuncOp func, llvmFunc->addFnAttr(llvm::Attribute::NoInline); if (func.getAlwaysInlineAttr()) llvmFunc->addFnAttr(llvm::Attribute::AlwaysInline); + if (func.getInlineHintAttr()) + llvmFunc->addFnAttr(llvm::Attribute::InlineHint); if (func.getOptimizeNoneAttr()) llvmFunc->addFnAttr(llvm::Attribute::OptimizeNone); if (func.getConvergentAttr()) diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp index 132be4e..51c6077 100644 --- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp +++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp @@ -956,7 +956,7 @@ inline parsed_inst_t ExpressionParser::buildNumericOp( << ", type = " << ty << " ***"; auto tysToPop = SmallVector<Type, numOperands>(); tysToPop.resize(numOperands); - std::fill(tysToPop.begin(), tysToPop.end(), ty); + llvm::fill(tysToPop, ty); auto operands = popOperands(tysToPop); if (failed(operands)) return failure(); diff --git a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp index 9670285..3fda5a7 100644 --- a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp +++ b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp @@ -93,7 +93,7 @@ void CodeGen::generate(const ast::Module &astModule, ModuleOp module) { // Emit function to add the generated matchers to the pattern list. os << "template <typename... ConfigsT>\n" - "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns(" + "[[maybe_unused]] static void populateGeneratedPDLLPatterns(" "::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {\n"; for (const auto &name : patternNames) os << " patterns.add<" << name diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp index c883baa..3236b4f 100644 --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -27,6 +27,7 @@ #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/ScopedPrinter.h" +#include "llvm/Support/VirtualFileSystem.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Parser.h" #include <optional> @@ -828,6 +829,7 @@ LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc, llvm::SourceMgr tdSrcMgr; tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc()); tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs()); + tdSrcMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem()); // This class provides a context argument for the llvm::SourceMgr diagnostic // handler. diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp index 30fd384..9ef405d 100644 --- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp @@ -37,6 +37,7 @@ #include "llvm/ADT/StringRef.h" #include "llvm/Remarks/RemarkFormat.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/LogicalResult.h" #include "llvm/Support/ManagedStatic.h" @@ -226,6 +227,18 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig { "bitstream", "Print bitstream file")), llvm::cl::cat(remarkCategory)}; + static llvm::cl::opt<RemarkPolicy, /*ExternalStorage=*/true> remarkPolicy{ + "remark-policy", + llvm::cl::desc("Specify the policy for remark output."), + cl::location(remarkPolicyFlag), + llvm::cl::value_desc("format"), + llvm::cl::init(RemarkPolicy::REMARK_POLICY_ALL), + llvm::cl::values(clEnumValN(RemarkPolicy::REMARK_POLICY_ALL, "all", + "Print all remarks"), + clEnumValN(RemarkPolicy::REMARK_POLICY_FINAL, "final", + "Print final remarks")), + llvm::cl::cat(remarkCategory)}; + static cl::opt<std::string, /*ExternalStorage=*/true> remarksAll( "remarks-filter", cl::desc("Show all remarks: passed, missed, failed, analysis"), @@ -517,18 +530,28 @@ performActions(raw_ostream &os, return failure(); context->enableMultithreading(wasThreadingEnabled); - + // Set the remark categories and policy. remark::RemarkCategories cats{ config.getRemarksAllFilter(), config.getRemarksPassedFilter(), config.getRemarksMissedFilter(), config.getRemarksAnalyseFilter(), config.getRemarksFailedFilter()}; mlir::MLIRContext &ctx = *context; + // Helper to create the appropriate policy based on configuration + auto createPolicy = [&config]() + -> std::unique_ptr<mlir::remark::detail::RemarkEmittingPolicyBase> { + if (config.getRemarkPolicy() == RemarkPolicy::REMARK_POLICY_ALL) + return std::make_unique<mlir::remark::RemarkEmittingPolicyAll>(); + if (config.getRemarkPolicy() == RemarkPolicy::REMARK_POLICY_FINAL) + return std::make_unique<mlir::remark::RemarkEmittingPolicyFinal>(); + + llvm_unreachable("Invalid remark policy"); + }; switch (config.getRemarkFormat()) { case RemarkFormat::REMARK_FORMAT_STDOUT: if (failed(mlir::remark::enableOptimizationRemarks( - ctx, nullptr, cats, true /*printAsEmitRemarks*/))) + ctx, nullptr, createPolicy(), cats, true /*printAsEmitRemarks*/))) return failure(); break; @@ -537,7 +560,7 @@ performActions(raw_ostream &os, ? "mlir-remarks.yaml" : config.getRemarksOutputFile(); if (failed(mlir::remark::enableOptimizationRemarksWithLLVMStreamer( - ctx, file, llvm::remarks::Format::YAML, cats))) + ctx, file, llvm::remarks::Format::YAML, createPolicy(), cats))) return failure(); break; } @@ -547,7 +570,7 @@ performActions(raw_ostream &os, ? "mlir-remarks.bitstream" : config.getRemarksOutputFile(); if (failed(mlir::remark::enableOptimizationRemarksWithLLVMStreamer( - ctx, file, llvm::remarks::Format::Bitstream, cats))) + ctx, file, llvm::remarks::Format::Bitstream, createPolicy(), cats))) return failure(); break; } @@ -593,6 +616,12 @@ performActions(raw_ostream &os, AsmState asmState(op.get(), OpPrintingFlags(), /*locationMap=*/nullptr, &fallbackResourceMap); os << OpWithState(op.get(), asmState) << '\n'; + + // This is required if the remark policy is final. Otherwise, the remarks are + // not emitted. + if (remark::detail::RemarkEngine *engine = ctx.getRemarkEngine()) + engine->getRemarkEmittingPolicy()->finalize(); + return success(); } diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp index 60b9567..1dbe7eca 100644 --- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp @@ -31,6 +31,7 @@ #include "llvm/Support/FileSystem.h" #include "llvm/Support/LSP/Logging.h" #include "llvm/Support/Path.h" +#include "llvm/Support/VirtualFileSystem.h" #include <optional> using namespace mlir; @@ -402,6 +403,7 @@ PDLDocument::PDLDocument(const llvm::lsp::URIForFile &uri, StringRef contents, llvm::append_range(includeDirs, extraDirs); sourceMgr.setIncludeDirs(includeDirs); + sourceMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem()); sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); astContext.getDiagEngine().setHandlerFn([&](const ast::Diagnostic &diag) { diff --git a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp index 3080b78..2d817be 100644 --- a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp +++ b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp @@ -17,6 +17,7 @@ #include "llvm/Support/LSP/Logging.h" #include "llvm/Support/LSP/Protocol.h" #include "llvm/Support/Path.h" +#include "llvm/Support/VirtualFileSystem.h" #include "llvm/TableGen/Parser.h" #include "llvm/TableGen/Record.h" #include <optional> @@ -448,6 +449,7 @@ void TableGenTextFile::initialize( return; } sourceMgr.setIncludeDirs(includeDirs); + sourceMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem()); sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); // This class provides a context argument for the SourceMgr diagnostic diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp index 111f58e..5f3b04a 100644 --- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp @@ -66,7 +66,9 @@ size_t mlir::moveLoopInvariantCode( size_t numMoved = 0; for (Region *region : regions) { - LDBG() << "Original loop:\n" << *region->getParentOp(); + LDBG() << "Original loop:\n" + << OpWithFlags(region->getParentOp(), + OpPrintingFlags().skipRegions()); std::queue<Operation *> worklist; // Add top-level operations in the loop body to the worklist. @@ -90,7 +92,8 @@ size_t mlir::moveLoopInvariantCode( !canBeHoisted(op, definedOutside)) continue; - LDBG() << "Moving loop-invariant op: " << *op; + LDBG() << "Moving loop-invariant op: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); moveOutOfRegion(op, region); ++numMoved; @@ -111,9 +114,7 @@ size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) { [&](Value value, Region *) { return loopLike.isDefinedOutsideOfLoop(value); }, - [&](Operation *op, Region *) { - return isMemoryEffectFree(op) && isSpeculatable(op); - }, + [&](Operation *op, Region *) { return isPure(op); }, [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); }); } diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 9f5246d..ffa96ad 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -137,6 +137,16 @@ declare_mlir_dialect_python_bindings( declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/OpenACCOps.td + SOURCES + dialects/openacc.py + DIALECT_NAME acc + DEPENDS acc_common_td + ) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/GPUOps.td SOURCES_GLOB dialects/gpu/*.py DIALECT_NAME gpu diff --git a/mlir/python/mlir/dialects/OpenACCOps.td b/mlir/python/mlir/dialects/OpenACCOps.td new file mode 100644 index 0000000..69a3002 --- /dev/null +++ b/mlir/python/mlir/dialects/OpenACCOps.td @@ -0,0 +1,14 @@ +//===-- OpenACCOps.td - Entry point for OpenACCOps bind ------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_OPENACC_OPS +#define PYTHON_BINDINGS_OPENACC_OPS + +include "mlir/Dialect/OpenACC/OpenACCOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/gpu/__init__.py b/mlir/python/mlir/dialects/gpu/__init__.py index 4cd80aa..b14ea68 100644 --- a/mlir/python/mlir/dialects/gpu/__init__.py +++ b/mlir/python/mlir/dialects/gpu/__init__.py @@ -3,5 +3,151 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .._gpu_ops_gen import * +from .._gpu_ops_gen import _Dialect from .._gpu_enum_gen import * from ..._mlir_libs._mlirDialectsGPU import * +from typing import Callable, Sequence, Union, Optional, List + +try: + from ...ir import ( + FunctionType, + TypeAttr, + StringAttr, + UnitAttr, + Block, + InsertionPoint, + ArrayAttr, + Type, + DictAttr, + Attribute, + DenseI32ArrayAttr, + ) + from .._ods_common import ( + get_default_loc_context as _get_default_loc_context, + _cext as _ods_cext, + ) +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + + +@_ods_cext.register_operation(_Dialect, replace=True) +class GPUFuncOp(GPUFuncOp): + __doc__ = GPUFuncOp.__doc__ + + KERNEL_ATTR_NAME = "gpu.kernel" + KNOWN_BLOCK_SIZE_ATTR_NAME = "known_block_size" + KNOWN_GRID_SIZE_ATTR_NAME = "known_grid_size" + + FUNCTION_TYPE_ATTR_NAME = "function_type" + SYM_NAME_ATTR_NAME = "sym_name" + ARGUMENT_ATTR_NAME = "arg_attrs" + RESULT_ATTR_NAME = "res_attrs" + + def __init__( + self, + function_type: Union[FunctionType, TypeAttr], + sym_name: Optional[Union[str, StringAttr]] = None, + kernel: Optional[bool] = None, + workgroup_attrib_attrs: Optional[Sequence[dict]] = None, + private_attrib_attrs: Optional[Sequence[dict]] = None, + known_block_size: Optional[Union[Sequence[int], DenseI32ArrayAttr]] = None, + known_grid_size: Optional[Union[Sequence[int], DenseI32ArrayAttr]] = None, + loc=None, + ip=None, + body_builder: Optional[Callable[[GPUFuncOp], None]] = None, + ): + """ + Create a GPUFuncOp with the provided `function_type`, `sym_name`, + `kernel`, `workgroup_attrib_attrs`, `private_attrib_attrs`, `known_block_size`, + `known_grid_size`, and `body_builder`. + - `function_type` is a FunctionType or a TypeAttr. + - `sym_name` is a string or a StringAttr representing the function name. + - `kernel` is a boolean representing whether the function is a kernel. + - `workgroup_attrib_attrs` is an optional list of dictionaries. + - `private_attrib_attrs` is an optional list of dictionaries. + - `known_block_size` is an optional list of integers or a DenseI32ArrayAttr representing the known block size. + - `known_grid_size` is an optional list of integers or a DenseI32ArrayAttr representing the known grid size. + - `body_builder` is an optional callback. When provided, a new entry block + is created and the callback is invoked with the new op as argument within + an InsertionPoint context already set for the block. The callback is + expected to insert a terminator in the block. + """ + function_type = ( + TypeAttr.get(function_type) + if not isinstance(function_type, TypeAttr) + else function_type + ) + super().__init__( + function_type, + workgroup_attrib_attrs=workgroup_attrib_attrs, + private_attrib_attrs=private_attrib_attrs, + loc=loc, + ip=ip, + ) + + if isinstance(sym_name, str): + self.attributes[self.SYM_NAME_ATTR_NAME] = StringAttr.get(sym_name) + elif isinstance(sym_name, StringAttr): + self.attributes[self.SYM_NAME_ATTR_NAME] = sym_name + else: + raise ValueError("sym_name must be a string or a StringAttr") + + if kernel: + self.attributes[self.KERNEL_ATTR_NAME] = UnitAttr.get() + + if known_block_size is not None: + if isinstance(known_block_size, Sequence): + block_size = DenseI32ArrayAttr.get(known_block_size) + self.attributes[self.KNOWN_BLOCK_SIZE_ATTR_NAME] = block_size + elif isinstance(known_block_size, DenseI32ArrayAttr): + self.attributes[self.KNOWN_BLOCK_SIZE_ATTR_NAME] = known_block_size + else: + raise ValueError( + "known_block_size must be a list of integers or a DenseI32ArrayAttr" + ) + + if known_grid_size is not None: + if isinstance(known_grid_size, Sequence): + grid_size = DenseI32ArrayAttr.get(known_grid_size) + self.attributes[self.KNOWN_GRID_SIZE_ATTR_NAME] = grid_size + elif isinstance(known_grid_size, DenseI32ArrayAttr): + self.attributes[self.KNOWN_GRID_SIZE_ATTR_NAME] = known_grid_size + else: + raise ValueError( + "known_grid_size must be a list of integers or a DenseI32ArrayAttr" + ) + + if body_builder is not None: + with InsertionPoint(self.add_entry_block()): + body_builder(self) + + @property + def name(self) -> StringAttr: + return StringAttr(self.attributes[self.SYM_NAME_ATTR_NAME]) + + @property + def is_kernel(self) -> bool: + return self.KERNEL_ATTR_NAME in self.attributes + + def add_entry_block(self) -> Block: + if len(self.body.blocks) > 0: + raise RuntimeError(f"Entry block already exists for {self.name.value}") + + function_type = self.function_type.value + return self.body.blocks.append( + *function_type.inputs, + arg_locs=[self.location for _ in function_type.inputs], + ) + + @property + def entry_block(self) -> Block: + if len(self.body.blocks) == 0: + raise RuntimeError( + f"Entry block does not exist for {self.name.value}." + + " Do you need to call the add_entry_block() method on this GPUFuncOp?" + ) + return self.body.blocks[0] + + @property + def arguments(self) -> Sequence[Type]: + return self.function_type.value.inputs diff --git a/mlir/python/mlir/dialects/openacc.py b/mlir/python/mlir/dialects/openacc.py new file mode 100644 index 0000000..057f71a --- /dev/null +++ b/mlir/python/mlir/dialects/openacc.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._acc_ops_gen import * diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py index e3bacb5..14c7380 100644 --- a/mlir/python/mlir/dialects/transform/structured.py +++ b/mlir/python/mlir/dialects/transform/structured.py @@ -144,9 +144,10 @@ class FuseOp(FuseOp): loop_types: Union[Type, Sequence[Type]], target: Union[Operation, Value, OpView], *, - tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, - tile_interchange: OptionalIntList = None, - apply_cleanup: Optional[bool] = False, + tile_sizes: Optional[MixedValues] = None, + tile_interchange: Optional[MixedValues] = None, + apply_cleanup: bool = False, + use_forall: bool = False, loc=None, ip=None, ): @@ -157,9 +158,10 @@ class FuseOp(FuseOp): self, target: Union[Operation, Value, OpView], *, - tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, - tile_interchange: OptionalIntList = None, - apply_cleanup: Optional[bool] = False, + tile_sizes: Optional[MixedValues] = None, + tile_interchange: Optional[MixedValues] = None, + apply_cleanup: bool = False, + use_forall: bool = False, loc=None, ip=None, ): @@ -170,17 +172,26 @@ class FuseOp(FuseOp): loop_types_or_target: Union[Type, Sequence[Type], Operation, OpView, Value], target_or_none: Optional[Union[Operation, Value, OpView]] = None, *, - tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, - tile_interchange: OptionalIntList = None, - apply_cleanup: Optional[bool] = False, + tile_sizes: Optional[MixedValues] = None, + tile_interchange: Optional[MixedValues] = None, + apply_cleanup: bool = False, + use_forall: bool = False, loc=None, ip=None, ): tile_sizes = tile_sizes if tile_sizes else [] tile_interchange = tile_interchange if tile_interchange else [] - _, tile_sizes, _ = _dispatch_dynamic_index_list(tile_sizes) - _, tile_interchange, _ = _dispatch_dynamic_index_list(tile_interchange) - num_loops = sum(0 if v == 0 else 1 for v in tile_sizes) + ( + dynamic_tile_sizes, + static_tile_sizes, + _, + ) = _dispatch_dynamic_index_list(tile_sizes) + ( + dynamic_tile_interchange, + static_tile_interchange, + _, + ) = _dispatch_dynamic_index_list(tile_interchange) + num_loops = 1 if use_forall else sum(1 for v in static_tile_sizes if v != 0) if isinstance(loop_types_or_target, (Operation, Value, OpView)): loop_types = [transform.AnyOpType.get()] * num_loops @@ -197,9 +208,12 @@ class FuseOp(FuseOp): target.type, loop_types, target, - tile_sizes=tile_sizes, - tile_interchange=tile_interchange, + tile_sizes=dynamic_tile_sizes, + tile_interchange=dynamic_tile_interchange, + static_tile_sizes=static_tile_sizes, + static_tile_interchange=static_tile_interchange, apply_cleanup=apply_cleanup, + use_forall=use_forall, loc=loc, ip=ip, ) diff --git a/mlir/test/Analysis/DataFlow/test-strided-metadata-range-analysis.mlir b/mlir/test/Analysis/DataFlow/test-strided-metadata-range-analysis.mlir new file mode 100644 index 0000000..808c1c2 --- /dev/null +++ b/mlir/test/Analysis/DataFlow/test-strided-metadata-range-analysis.mlir @@ -0,0 +1,67 @@ +// RUN: mlir-opt -test-strided-metadata-range-analysis %s 2>&1 | FileCheck %s + +func.func @memref_subview(%arg0: memref<8x16x4xf32, strided<[64, 4, 1]>>, %arg1: memref<1x128x1x32x1xf32, strided<[4096, 32, 32, 1, 1]>>, %arg2: memref<8x16x4xf32, strided<[1, 64, 8], offset: 16>>, %arg3: index, %arg4: index, %arg5: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %0 = test.with_bounds {smax = 13 : index, smin = 11 : index, umax = 13 : index, umin = 11 : index} : index + %1 = test.with_bounds {smax = 7 : index, smin = 5 : index, umax = 7 : index, umin = 5 : index} : index + + // Test subview with unknown sizes, and constant offsets and strides. + // CHECK: Op: %[[SV0:.*]] = memref.subview + // CHECK-NEXT: result[0]: strided_metadata< + // CHECK-SAME: offset = [{unsigned : [1, 1] signed : [1, 1]}] + // CHECK-SAME: sizes = [{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}, {unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}, {unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}] + // CHECK-SAME: strides = [{unsigned : [64, 64] signed : [64, 64]}, {unsigned : [4, 4] signed : [4, 4]}, {unsigned : [1, 1] signed : [1, 1]}] + %subview = memref.subview %arg0[%c0, %c0, %c1] [%arg3, %arg4, %arg5] [%c1, %c1, %c1] : memref<8x16x4xf32, strided<[64, 4, 1]>> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> + + // Test a subview of a subview, with bounded dynamic offsets. + // CHECK: Op: %[[SV1:.*]] = memref.subview + // CHECK-NEXT: result[0]: strided_metadata< + // CHECK-SAME: offset = [{unsigned : [346, 484] signed : [346, 484]}] + // CHECK-SAME: sizes = [{unsigned : [2, 2] signed : [2, 2]}, {unsigned : [2, 2] signed : [2, 2]}, {unsigned : [2, 2] signed : [2, 2]}] + // CHECK-SAME: strides = [{unsigned : [704, 832] signed : [704, 832]}, {unsigned : [44, 52] signed : [44, 52]}, {unsigned : [11, 13] signed : [11, 13]}] + %subview_0 = memref.subview %subview[%1, %1, %1] [%c2, %c2, %c2] [%0, %0, %0] : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> + + // Test a subview of a subview, with constant operands. + // CHECK: Op: %[[SV2:.*]] = memref.subview + // CHECK-NEXT: result[0]: strided_metadata< + // CHECK-SAME: offset = [{unsigned : [368, 510] signed : [368, 510]}] + // CHECK-SAME: sizes = [{unsigned : [2, 2] signed : [2, 2]}, {unsigned : [2, 2] signed : [2, 2]}, {unsigned : [2, 2] signed : [2, 2]}] + // CHECK-SAME: strides = [{unsigned : [704, 832] signed : [704, 832]}, {unsigned : [44, 52] signed : [44, 52]}, {unsigned : [11, 13] signed : [11, 13]}] + %subview_1 = memref.subview %subview_0[%c0, %c0, %c2] [%c2, %c2, %c2] [%c1, %c1, %c1] : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> + + // Test a rank-reducing subview. + // CHECK: Op: %[[SV3:.*]] = memref.subview + // CHECK-NEXT: result[0]: strided_metadata< + // CHECK-SAME: offset = [{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}] + // CHECK-SAME: sizes = [{unsigned : [64, 64] signed : [64, 64]}, {unsigned : [16, 16] signed : [16, 16]}] + // CHECK-SAME: strides = [{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}, {unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}] + %subview_2 = memref.subview %arg1[%arg4, %arg4, %arg4, %arg4, %arg4] [1, 64, 1, 16, 1] [%arg5, %arg5, %arg5, %arg5, %arg5] : memref<1x128x1x32x1xf32, strided<[4096, 32, 32, 1, 1]>> to memref<64x16xf32, strided<[?, ?], offset: ?>> + + // Test a subview of a rank-reducing subview + // CHECK: Op: %[[SV4:.*]] = memref.subview + // CHECK-NEXT: result[0]: strided_metadata< + // CHECK-SAME: offset = [{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}] + // CHECK-SAME: sizes = [{unsigned : [5, 7] signed : [5, 7]}] + // CHECK-SAME: strides = [{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}] + %subview_3 = memref.subview %subview_2[%c0, %0] [1, %1] [%c1, %c2] : memref<64x16xf32, strided<[?, ?], offset: ?>> to memref<?xf32, strided<[?], offset: ?>> + + // Test a subview with mixed bounded and unbound dynamic sizes. + // CHECK: Op: %[[SV5:.*]] = memref.subview + // CHECK-NEXT: result[0]: strided_metadata< + // CHECK-SAME: offset = [{unsigned : [32, 32] signed : [32, 32]}] + // CHECK-SAME: sizes = [{unsigned : [11, 13] signed : [11, 13]}, {unsigned : [5, 7] signed : [5, 7]}, {unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}] + // CHECK-SAME: strides = [{unsigned : [1, 1] signed : [1, 1]}, {unsigned : [64, 64] signed : [64, 64]}, {unsigned : [8, 8] signed : [8, 8]}] + %subview_4 = memref.subview %arg2[%c0, %c0, %c2] [%0, %1, %arg5] [%c1, %c1, %c1] : memref<8x16x4xf32, strided<[1, 64, 8], offset: 16>> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> + return +} + +// CHECK: func.func @memref_subview +// CHECK: %[[A0:.*]]: memref<8x16x4xf32, strided<[64, 4, 1]>> +// CHECK: %[[SV0]] = memref.subview %[[A0]] +// CHECK-NEXT: %[[SV1]] = memref.subview +// CHECK-NEXT: %[[SV2]] = memref.subview +// CHECK-NEXT: %[[SV3]] = memref.subview +// CHECK-NEXT: %[[SV4]] = memref.subview +// CHECK-NEXT: %[[SV5]] = memref.subview diff --git a/mlir/test/Analysis/test-alias-analysis.mlir b/mlir/test/Analysis/test-alias-analysis.mlir index 8cbee61..d71adee 100644 --- a/mlir/test/Analysis/test-alias-analysis.mlir +++ b/mlir/test/Analysis/test-alias-analysis.mlir @@ -256,3 +256,19 @@ func.func @constants(%arg: memref<2xf32>) attributes {test.ptr = "func"} { return } + +// ----- + +// CHECK-LABEL: Testing : "distinct_objects" +// CHECK-DAG: func.region0#0 <-> func.region0#1: MayAlias + +// CHECK-DAG: distinct#0 <-> distinct#1: NoAlias +// CHECK-DAG: distinct#0 <-> func.region0#0: MustAlias +// CHECK-DAG: distinct#1 <-> func.region0#0: MayAlias +// CHECK-DAG: distinct#0 <-> func.region0#1: MayAlias +// CHECK-DAG: distinct#1 <-> func.region0#1: MustAlias + +func.func @distinct_objects(%arg: memref<?xf32>, %arg1: memref<?xf32>) attributes {test.ptr = "func"} { + %0, %1 = memref.distinct_objects %arg, %arg1 {test.ptr = "distinct"} : memref<?xf32>, memref<?xf32> + return +} diff --git a/mlir/test/Conversion/MathToXeVM/lit.local.cfg b/mlir/test/Conversion/MathToXeVM/lit.local.cfg new file mode 100644 index 0000000..cc1ce35 --- /dev/null +++ b/mlir/test/Conversion/MathToXeVM/lit.local.cfg @@ -0,0 +1,7 @@ +spirv_backend_tests = [ + 'native-spirv-builtins.mlir', +] + +# Exclude SPIRV backend tests if SPIRV target is disabled: +if(not config.run_xevm_tests): + config.excludes.update(spirv_backend_tests) diff --git a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir new file mode 100644 index 0000000..d76627b --- /dev/null +++ b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir @@ -0,0 +1,155 @@ +// RUN: mlir-opt %s -convert-math-to-xevm \ +// RUN: | FileCheck %s -check-prefixes='CHECK,CHECK-ARITH' +// RUN: mlir-opt %s -convert-math-to-xevm='convert-arith=false' \ +// RUN: | FileCheck %s -check-prefixes='CHECK,CHECK-NO-ARITH' + +module @test_module { + // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16 + // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32 + // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64 + // + // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv2_d(vector<2xf64>) -> vector<2xf64> + // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv3_d(vector<3xf64>) -> vector<3xf64> + // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv4_d(vector<4xf64>) -> vector<4xf64> + // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv8_d(vector<8xf64>) -> vector<8xf64> + // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv16_d(vector<16xf64>) -> vector<16xf64> + // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv16_f(vector<16xf32>) -> vector<16xf32> + // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv4_Dh(vector<4xf16>) -> vector<4xf16> + // + // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_cosDh(f16) -> f16 + // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_exp2f(f32) -> f32 + // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_logDh(f16) -> f16 + // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_log2f(f32) -> f32 + // CHECK-DAG: llvm.func @_Z24__spirv_ocl_native_log10d(f64) -> f64 + // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_powrDhDh(f16, f16) -> f16 + // CHECK-DAG: llvm.func @_Z24__spirv_ocl_native_rsqrtd(f64) -> f64 + // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_sinDh(f16) -> f16 + // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_sqrtf(f32) -> f32 + // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_tand(f64) -> f64 + // CHECK-ARITH-DAG: llvm.func @_Z25__spirv_ocl_native_divideff(f32, f32) -> f32 + + // CHECK-LABEL: func @math_ops + func.func @math_ops() { + + %c1_f16 = arith.constant 1. : f16 + %c1_f32 = arith.constant 1. : f32 + %c1_f64 = arith.constant 1. : f64 + + // CHECK: math.exp + %exp_normal_f16 = math.exp %c1_f16 : f16 + // CHECK: math.exp + %exp_normal_f32 = math.exp %c1_f32 : f32 + // CHECK: math.exp + %exp_normal_f64 = math.exp %c1_f64 : f64 + + // Check float operations are converted properly: + + // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f16) -> f16 + %exp_fast_f16 = math.exp %c1_f16 fastmath<fast> : f16 + // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f32) -> f32 + %exp_fast_f32 = math.exp %c1_f32 fastmath<fast> : f32 + // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f64) -> f64 + %exp_fast_f64 = math.exp %c1_f64 fastmath<fast> : f64 + + // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16 + %exp_afn_f16 = math.exp %c1_f16 fastmath<afn> : f16 + // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32 + %exp_afn_f32 = math.exp %c1_f32 fastmath<afn> : f32 + // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64 + %exp_afn_f64 = math.exp %c1_f64 fastmath<afn> : f64 + + // CHECK: math.exp + %exp_none_f16 = math.exp %c1_f16 fastmath<none> : f16 + // CHECK: math.exp + %exp_none_f32 = math.exp %c1_f32 fastmath<none> : f32 + // CHECK: math.exp + %exp_none_f64 = math.exp %c1_f64 fastmath<none> : f64 + + // Check vector operations: + + %v2_c1_f64 = arith.constant dense<1.> : vector<2xf64> + %v3_c1_f64 = arith.constant dense<1.> : vector<3xf64> + %v4_c1_f64 = arith.constant dense<1.> : vector<4xf64> + %v8_c1_f64 = arith.constant dense<1.> : vector<8xf64> + %v16_c1_f64 = arith.constant dense<1.> : vector<16xf64> + + // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv2_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<2xf64>) -> vector<2xf64> + %exp_v2_f64 = math.exp %v2_c1_f64 fastmath<afn> : vector<2xf64> + // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv3_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<3xf64>) -> vector<3xf64> + %exp_v3_f64 = math.exp %v3_c1_f64 fastmath<afn> : vector<3xf64> + // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<4xf64>) -> vector<4xf64> + %exp_v4_f64 = math.exp %v4_c1_f64 fastmath<afn> : vector<4xf64> + // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv8_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<8xf64>) -> vector<8xf64> + %exp_v8_f64 = math.exp %v8_c1_f64 fastmath<afn> : vector<8xf64> + // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<16xf64>) -> vector<16xf64> + %exp_v16_f64 = math.exp %v16_c1_f64 fastmath<afn> : vector<16xf64> + + %v16_c1_f32 = arith.constant dense<1.> : vector<16xf32> + %v4_c1_f16 = arith.constant dense<1.> : vector<4xf16> + + // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_f(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (vector<16xf32>) -> vector<16xf32> + %exp_v16_f32 = math.exp %v16_c1_f32 fastmath<fast> : vector<16xf32> + // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_Dh(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (vector<4xf16>) -> vector<4xf16> + %exp_v4_f16 = math.exp %v4_c1_f16 fastmath<fast> : vector<4xf16> + + // Check unsupported vector sizes are not converted: + + %v5_c1_f64 = arith.constant dense<1.> : vector<5xf64> + %v32_c1_f64 = arith.constant dense<1.> : vector<32xf64> + + // CHECK: math.exp + %exp_v5_f64 = math.exp %v5_c1_f64 fastmath<afn> : vector<5xf64> + // CHECK: math.exp + %exp_v32_f64 = math.exp %v32_c1_f64 fastmath<afn> : vector<32xf64> + + // Check fastmath flags propagate properly: + + // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f16) -> f16 + %exp_fastmath_all_f16 = math.exp %c1_f16 fastmath<reassoc,nnan,ninf,nsz,arcp,contract,afn> : f16 + // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<nnan, ninf, nsz, arcp, contract, afn>} : (f32) -> f32 + %exp_fastmath_most_f32 = math.exp %c1_f32 fastmath<nnan,ninf,nsz,arcp,contract,afn> : f32 + // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<nnan, afn, reassoc>} : (f32) -> f32 + %exp_afn_reassoc_nnan_f32 = math.exp %c1_f32 fastmath<afn,reassoc,nnan> : f32 + + // Check all other math operations: + + // CHECK: llvm.call @_Z22__spirv_ocl_native_cosDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16 + %cos_afn_f16 = math.cos %c1_f16 fastmath<afn> : f16 + + // CHECK: llvm.call @_Z23__spirv_ocl_native_exp2f(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32 + %exp2_afn_f32 = math.exp2 %c1_f32 fastmath<afn> : f32 + + // CHECK: llvm.call @_Z22__spirv_ocl_native_logDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16 + %log_afn_f16 = math.log %c1_f16 fastmath<afn> : f16 + + // CHECK: llvm.call @_Z23__spirv_ocl_native_log2f(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32 + %log2_afn_f32 = math.log2 %c1_f32 fastmath<afn> : f32 + + // CHECK: llvm.call @_Z24__spirv_ocl_native_log10d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64 + %log10_afn_f64 = math.log10 %c1_f64 fastmath<afn> : f64 + + // CHECK: llvm.call @_Z23__spirv_ocl_native_powrDhDh(%{{.*}}, %{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16, f16) -> f16 + %powr_afn_f16 = math.powf %c1_f16, %c1_f16 fastmath<afn> : f16 + + // CHECK: llvm.call @_Z24__spirv_ocl_native_rsqrtd(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64 + %rsqrt_afn_f64 = math.rsqrt %c1_f64 fastmath<afn> : f64 + + // CHECK: llvm.call @_Z22__spirv_ocl_native_sinDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16 + %sin_afn_f16 = math.sin %c1_f16 fastmath<afn> : f16 + + // CHECK: llvm.call @_Z23__spirv_ocl_native_sqrtf(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32 + %sqrt_afn_f32 = math.sqrt %c1_f32 fastmath<afn> : f32 + + // CHECK: llvm.call @_Z22__spirv_ocl_native_tand(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64 + %tan_afn_f64 = math.tan %c1_f64 fastmath<afn> : f64 + + %c6_9_f32 = arith.constant 6.9 : f32 + %c7_f32 = arith.constant 7. : f32 + + // CHECK-ARITH: llvm.call @_Z25__spirv_ocl_native_divideff(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32, f32) -> f32 + // CHECK-NO-ARITH: arith.divf + %divf_afn_f32 = arith.divf %c6_9_f32, %c7_f32 fastmath<afn> : f32 + + return + } +} diff --git a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir new file mode 100644 index 0000000..82426c4 --- /dev/null +++ b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir @@ -0,0 +1,119 @@ +// RUN: mlir-opt %s -gpu-module-to-binary="format=isa" \ +// RUN: -debug-only=serialize-to-isa 2> %t +// RUN: FileCheck --input-file=%t %s +// REQUIRES: asserts +// +// MathToXeVM pass generates OpenCL intrinsics function calls when converting +// Math ops with `fastmath` attr to native function calls. It is assumed that +// the SPIRV backend would correctly convert these intrinsics calls to OpenCL +// ExtInst instructions in SPIRV (See llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp). +// +// To ensure this assumption holds, this test verifies that the SPIRV backend +// behaves as expected. + +module @test_ocl_intrinsics attributes {gpu.container_module} { + gpu.module @kernel [#xevm.target] { + llvm.func spir_kernelcc @native_fcns() attributes {gpu.kernel} { + // CHECK-DAG: %[[F16T:.+]] = OpTypeFloat 16 + // CHECK-DAG: %[[ZERO_F16:.+]] = OpConstantNull %[[F16T]] + %c0_f16 = llvm.mlir.constant(0. : f16) : f16 + // CHECK-DAG: %[[F32T:.+]] = OpTypeFloat 32 + // CHECK-DAG: %[[ZERO_F32:.+]] = OpConstantNull %[[F32T]] + %c0_f32 = llvm.mlir.constant(0. : f32) : f32 + // CHECK-DAG: %[[F64T:.+]] = OpTypeFloat 64 + // CHECK-DAG: %[[ZERO_F64:.+]] = OpConstantNull %[[F64T]] + %c0_f64 = llvm.mlir.constant(0. : f64) : f64 + + // CHECK-DAG: %[[V2F64T:.+]] = OpTypeVector %[[F64T]] 2 + // CHECK-DAG: %[[V2_ZERO_F64:.+]] = OpConstantNull %[[V2F64T]] + %v2_c0_f64 = llvm.mlir.constant(dense<0.> : vector<2xf64>) : vector<2xf64> + // CHECK-DAG: %[[V3F32T:.+]] = OpTypeVector %[[F32T]] 3 + // CHECK-DAG: %[[V3_ZERO_F32:.+]] = OpConstantNull %[[V3F32T]] + %v3_c0_f32 = llvm.mlir.constant(dense<0.> : vector<3xf32>) : vector<3xf32> + // CHECK-DAG: %[[V4F64T:.+]] = OpTypeVector %[[F64T]] 4 + // CHECK-DAG: %[[V4_ZERO_F64:.+]] = OpConstantNull %[[V4F64T]] + %v4_c0_f64 = llvm.mlir.constant(dense<0.> : vector<4xf64>) : vector<4xf64> + // CHECK-DAG: %[[V8F64T:.+]] = OpTypeVector %[[F64T]] 8 + // CHECK-DAG: %[[V8_ZERO_F64:.+]] = OpConstantNull %[[V8F64T]] + %v8_c0_f64 = llvm.mlir.constant(dense<0.> : vector<8xf64>) : vector<8xf64> + // CHECK-DAG: %[[V16F16T:.+]] = OpTypeVector %[[F16T]] 16 + // CHECK-DAG: %[[V16_ZERO_F16:.+]] = OpConstantNull %[[V16F16T]] + %v16_c0_f16 = llvm.mlir.constant(dense<0.> : vector<16xf16>) : vector<16xf16> + + // CHECK: OpExtInst %[[F16T]] %{{.+}} native_exp %[[ZERO_F16]] + %exp_f16 = llvm.call @_Z22__spirv_ocl_native_expDh(%c0_f16) : (f16) -> f16 + // CHECK: OpExtInst %[[F32T]] %{{.+}} native_exp %[[ZERO_F32]] + %exp_f32 = llvm.call @_Z22__spirv_ocl_native_expf(%c0_f32) : (f32) -> f32 + // CHECK: OpExtInst %[[F64T]] %{{.+}} native_exp %[[ZERO_F64]] + %exp_f64 = llvm.call @_Z22__spirv_ocl_native_expd(%c0_f64) : (f64) -> f64 + + // CHECK: OpExtInst %[[V2F64T]] %{{.+}} native_exp %[[V2_ZERO_F64]] + %exp_v2_f64 = llvm.call @_Z22__spirv_ocl_native_expDv2_f64(%v2_c0_f64) : (vector<2xf64>) -> vector<2xf64> + // CHECK: OpExtInst %[[V3F32T]] %{{.+}} native_exp %[[V3_ZERO_F32]] + %exp_v3_f32 = llvm.call @_Z22__spirv_ocl_native_expDv3_f32(%v3_c0_f32) : (vector<3xf32>) -> vector<3xf32> + // CHECK: OpExtInst %[[V4F64T]] %{{.+}} native_exp %[[V4_ZERO_F64]] + %exp_v4_f64 = llvm.call @_Z22__spirv_ocl_native_expDv4_f64(%v4_c0_f64) : (vector<4xf64>) -> vector<4xf64> + // CHECK: OpExtInst %[[V8F64T]] %{{.+}} native_exp %[[V8_ZERO_F64]] + %exp_v8_f64 = llvm.call @_Z22__spirv_ocl_native_expDv8_f64(%v8_c0_f64) : (vector<8xf64>) -> vector<8xf64> + // CHECK: OpExtInst %[[V16F16T]] %{{.+}} native_exp %[[V16_ZERO_F16]] + %exp_v16_f16 = llvm.call @_Z22__spirv_ocl_native_expDv16_f16(%v16_c0_f16) : (vector<16xf16>) -> vector<16xf16> + + // SPIRV backend does not currently handle fastmath flags: The SPIRV + // backend would need to generate OpDecorate calls to decorate math ops + // with FPFastMathMode/FPFastMathModeINTEL decorations. + // + // FIXME: When support for fastmath flags in the SPIRV backend is added, + // add tests here to ensure fastmath flags are converted to the correct + // OpDecorate calls. + // + // See: + // - https://registry.khronos.org/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html#_math_extended_instructions + // - https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpDecorate + + // CHECK: OpExtInst %[[F16T]] %{{.+}} native_cos %[[ZERO_F16]] + %cos_afn_f16 = llvm.call @_Z22__spirv_ocl_native_cosDh(%c0_f16) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16 + // CHECK: OpExtInst %[[F32T]] %{{.+}} native_exp2 %[[ZERO_F32]] + %exp2_afn_f32 = llvm.call @_Z23__spirv_ocl_native_exp2f(%c0_f32) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32 + // CHECK: OpExtInst %[[F16T]] %{{.+}} native_log %[[ZERO_F16]] + %log_afn_f16 = llvm.call @_Z22__spirv_ocl_native_logDh(%c0_f16) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16 + // CHECK: OpExtInst %[[F32T]] %{{.+}} native_log2 %[[ZERO_F32]] + %log2_afn_f32 = llvm.call @_Z23__spirv_ocl_native_log2f(%c0_f32) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32 + // CHECK: OpExtInst %[[V8F64T]] %{{.+}} native_log10 %[[V8_ZERO_F64]] + %log10_afn_f64 = llvm.call @_Z24__spirv_ocl_native_log10Dv8_d(%v8_c0_f64) {fastmathFlags = #llvm.fastmath<afn>} : (vector<8xf64>) -> vector<8xf64> + // CHECK: OpExtInst %[[V16F16T]] %{{.+}} native_powr %[[V16_ZERO_F16]] %[[V16_ZERO_F16]] + %powr_afn_f16 = llvm.call @_Z23__spirv_ocl_native_powrDv16_DhS_(%v16_c0_f16, %v16_c0_f16) {fastmathFlags = #llvm.fastmath<afn>} : (vector<16xf16>, vector<16xf16>) -> vector<16xf16> + // CHECK: OpExtInst %[[F64T]] %{{.+}} native_rsqrt %[[ZERO_F64]] + %rsqrt_afn_f64 = llvm.call @_Z24__spirv_ocl_native_rsqrtd(%c0_f64) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64 + // CHECK: OpExtInst %[[F16T]] %{{.+}} native_sin %[[ZERO_F16]] + %sin_afn_f16 = llvm.call @_Z22__spirv_ocl_native_sinDh(%c0_f16) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16 + // CHECK: OpExtInst %[[F32T]] %{{.+}} native_sqrt %[[ZERO_F32]] + %sqrt_afn_f32 = llvm.call @_Z23__spirv_ocl_native_sqrtf(%c0_f32) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32 + // CHECK: OpExtInst %[[F64T]] %{{.+}} native_tan %[[ZERO_F64]] + %tan_afn_f64 = llvm.call @_Z22__spirv_ocl_native_tand(%c0_f64) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64 + // CHECK: OpExtInst %[[F32T]] %{{.+}} native_divide %[[ZERO_F32]] %[[ZERO_F32]] + %divide_afn_f32 = llvm.call @_Z25__spirv_ocl_native_divideff(%c0_f32, %c0_f32) {fastmathFlags = #llvm.fastmath<afn>} : (f32, f32) -> f32 + + llvm.return + } + + llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16 + llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32 + llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64 + llvm.func @_Z22__spirv_ocl_native_expDv2_f64(vector<2xf64>) -> vector<2xf64> + llvm.func @_Z22__spirv_ocl_native_expDv3_f32(vector<3xf32>) -> vector<3xf32> + llvm.func @_Z22__spirv_ocl_native_expDv4_f64(vector<4xf64>) -> vector<4xf64> + llvm.func @_Z22__spirv_ocl_native_expDv8_f64(vector<8xf64>) -> vector<8xf64> + llvm.func @_Z22__spirv_ocl_native_expDv16_f16(vector<16xf16>) -> vector<16xf16> + llvm.func @_Z22__spirv_ocl_native_cosDh(f16) -> f16 + llvm.func @_Z23__spirv_ocl_native_exp2f(f32) -> f32 + llvm.func @_Z22__spirv_ocl_native_logDh(f16) -> f16 + llvm.func @_Z23__spirv_ocl_native_log2f(f32) -> f32 + llvm.func @_Z24__spirv_ocl_native_log10Dv8_d(vector<8xf64>) -> vector<8xf64> + llvm.func @_Z23__spirv_ocl_native_powrDv16_DhS_(vector<16xf16>, vector<16xf16>) -> vector<16xf16> + llvm.func @_Z24__spirv_ocl_native_rsqrtd(f64) -> f64 + llvm.func @_Z22__spirv_ocl_native_sinDh(f16) -> f16 + llvm.func @_Z23__spirv_ocl_native_sqrtf(f32) -> f32 + llvm.func @_Z22__spirv_ocl_native_tand(f64) -> f64 + llvm.func @_Z25__spirv_ocl_native_divideff(f32, f32) -> f32 + } +} diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index a7a73ae..780c25a 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -1538,6 +1538,92 @@ func.func @unsupportedRescaleInexactRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8> // ----- +// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> ()> +// CHECK-LABEL: @rescale_no_const +// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]] +func.func @rescale_no_const(%arg0 : tensor<2xi8>, %multiplier : tensor<1xi32>, %shift : tensor<1xi8>, %input_zp : tensor<1xi8>, %output_zp : tensor<1xi8>) -> (tensor<2xi8>) { + // CHECK: [[MULTIPLIER:%.+]] = tensor.collapse_shape %arg1 [] : tensor<1xi32> into tensor<i32> + // CHECK: [[SHIFT:%.+]] = tensor.collapse_shape %arg2 [] : tensor<1xi8> into tensor<i8> + // CHECK: [[INPUT_ZP:%.+]] = tensor.collapse_shape %arg3 [] : tensor<1xi8> into tensor<i8> + // CHECK: [[OUTPUT_ZP:%.+]] = tensor.collapse_shape %arg4 [] : tensor<1xi8> into tensor<i8> + // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xi8> + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]], [[MULTIPLIER]], [[SHIFT]], [[INPUT_ZP]], [[OUTPUT_ZP]] : tensor<2xi8>, tensor<i32>, tensor<i8>, tensor<i8>, tensor<i8>) outs([[INIT]] : tensor<2xi8>) { + // CHECK: ^bb0([[ARG0:%.*]]: i8, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i8, [[ARG3:%.*]]: i8, [[ARG4:%.*]]: i8, [[OUT:%.*]]: i8): + // CHECK: [[INPUT_ZP_I32:%.+]] = arith.extsi [[ARG3]] : i8 to i32 + // CHECK: [[OUTPUT_ZP_I32:%.+]] = arith.extsi [[ARG4]] : i8 to i32 + // CHECK: [[ARG0_I32:%.+]] = arith.extsi [[ARG0]] : i8 to i32 + // CHECK: [[TMP1:%.+]] = arith.subi [[ARG0_I32]], [[INPUT_ZP_I32]] : i32 + // CHECK: [[TMP2:%.+]] = tosa.apply_scale [[TMP1]], [[ARG1]], [[ARG2]] {rounding_mode = DOUBLE_ROUND} : (i32, i32, i8) -> i32 + // CHECK: [[TMP3:%.+]] = arith.addi [[TMP2]], [[OUTPUT_ZP_I32]] : i32 + // CHECK: %c-128_i32 = arith.constant -128 : i32 + // CHECK: %c127_i32 = arith.constant 127 : i32 + // CHECK: [[MAX:%.+]] = arith.maxsi %c-128_i32, [[TMP3]] : i32 + // CHECK: [[MIN:%.+]] = arith.minsi %c127_i32, [[MAX]] : i32 + %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = DOUBLE_ROUND, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8> + return %0 : tensor<2xi8> +} + +// ----- + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> ()> +// CHECK-LABEL: @rescale_no_const_per_channel +// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]] +// CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]] +// CHECK-SAME: [[ARG2:%[0-9a-zA-Z_]*]] +func.func @rescale_no_const_per_channel(%arg0 : tensor<2xi8>, %arg1 : tensor<2xi32>, %arg2 : tensor<2xi8>, %input_zp : tensor<1xi8>, %output_zp : tensor<1xi8>) -> (tensor<2xi8>) { + // CHECK: [[INPUT_ZP:%.+]] = tensor.collapse_shape %arg3 [] : tensor<1xi8> into tensor<i8> + // CHECK: [[OUTPUT_ZP:%.+]] = tensor.collapse_shape %arg4 [] : tensor<1xi8> into tensor<i8> + // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xi8> + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]], [[ARG1]], [[ARG2]], [[INPUT_ZP]], [[OUTPUT_ZP]] : tensor<2xi8>, tensor<2xi32>, tensor<2xi8>, tensor<i8>, tensor<i8>) outs([[INIT]] : tensor<2xi8>) { + // CHECK: ^bb0([[ARG0:%.*]]: i8, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i8, [[ARG3:%.*]]: i8, [[ARG4:%.*]]: i8, [[OUT:%.*]]: i8): + // CHECK: [[INPUT_ZP_I32:%.+]] = arith.extsi [[ARG3]] : i8 to i32 + // CHECK: [[OUTPUT_ZP_I32:%.+]] = arith.extsi [[ARG4]] : i8 to i32 + // CHECK: [[ARG0_I32:%.+]] = arith.extsi [[ARG0]] : i8 to i32 + // CHECK: [[TMP1:%.+]] = arith.subi [[ARG0_I32]], [[INPUT_ZP_I32]] : i32 + // CHECK: [[TMP2:%.+]] = tosa.apply_scale [[TMP1]], [[ARG1]], [[ARG2]] {rounding_mode = DOUBLE_ROUND} : (i32, i32, i8) -> i32 + // CHECK: [[TMP3:%.+]] = arith.addi [[TMP2]], [[OUTPUT_ZP_I32]] : i32 + // CHECK: %c-128_i32 = arith.constant -128 : i32 + // CHECK: %c127_i32 = arith.constant 127 : i32 + // CHECK: [[MAX:%.+]] = arith.maxsi %c-128_i32, [[TMP3]] : i32 + // CHECK: [[MIN:%.+]] = arith.minsi %c127_i32, [[MAX]] : i32 + %0 = tosa.rescale %arg0, %arg1, %arg2, %input_zp, %output_zp {scale32 = true, rounding_mode = DOUBLE_ROUND, per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<2xi32>, tensor<2xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8> + return %0 : tensor<2xi8> +} + +// ----- + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> ()> +// CHECK-LABEL: @rescale_no_const_per_channel_input_output_zp_ui8 +// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]] +// CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]] +// CHECK-SAME: [[ARG2:%[0-9a-zA-Z_]*]] +func.func @rescale_no_const_per_channel_input_output_zp_ui8(%arg0 : tensor<2xi8>, %arg1 : tensor<2xi32>, %arg2 : tensor<2xi8>, %input_zp : tensor<1xui8>, %output_zp : tensor<1xui8>) -> (tensor<2xui8>) { + // CHECK: [[INPUT_ZP:%.+]] = tensor.collapse_shape %arg3 [] : tensor<1xui8> into tensor<ui8> + // CHECK: [[OUTPUT_ZP:%.+]] = tensor.collapse_shape %arg4 [] : tensor<1xui8> into tensor<ui8> + // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xui8> + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]], [[ARG1]], [[ARG2]], [[INPUT_ZP]], [[OUTPUT_ZP]] : tensor<2xi8>, tensor<2xi32>, tensor<2xi8>, tensor<ui8>, tensor<ui8>) outs([[INIT]] : tensor<2xui8>) { + // CHECK: ^bb0([[ARG0:%.*]]: i8, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i8, [[ARG3:%.*]]: ui8, [[ARG4:%.*]]: ui8, [[OUT:%.*]]: ui8): + // CHECK: [[INPUT_ZP_I8:%.+]] = builtin.unrealized_conversion_cast [[ARG3]] : ui8 to i8 + // CHECK: [[INPUT_ZP_I32:%.+]] = arith.extui [[INPUT_ZP_I8]] : i8 to i32 + // CHECK: [[OUTPUT_ZP_I8:%.+]] = builtin.unrealized_conversion_cast [[ARG4]] : ui8 to i8 + // CHECK: [[OUTPUT_ZP_I32:%.+]] = arith.extui [[OUTPUT_ZP_I8]] : i8 to i32 + // CHECK: [[ARG0_I32:%.+]] = arith.extsi [[ARG0]] : i8 to i32 + // CHECK: [[TMP1:%.+]] = arith.subi [[ARG0_I32]], [[INPUT_ZP_I32]] : i32 + // CHECK: [[TMP2:%.+]] = tosa.apply_scale [[TMP1]], [[ARG1]], [[ARG2]] {rounding_mode = DOUBLE_ROUND} : (i32, i32, i8) -> i32 + // CHECK: [[TMP3:%.+]] = arith.addi [[TMP2]], [[OUTPUT_ZP_I32]] : i32 + // CHECK: %c0_i32 = arith.constant 0 : i32 + // CHECK: %c255_i32 = arith.constant 255 : i32 + // CHECK: [[MAX:%.+]] = arith.maxsi %c0_i32, [[TMP3]] : i32 + // CHECK: [[MIN:%.+]] = arith.minsi %c255_i32, [[MAX]] : i32 + %0 = tosa.rescale %arg0, %arg1, %arg2, %input_zp, %output_zp {scale32 = true, rounding_mode = DOUBLE_ROUND, per_channel = true, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<2xi32>, tensor<2xi8>, tensor<1xui8>, tensor<1xui8>) -> tensor<2xui8> + return %0 : tensor<2xui8> +} + +// ----- + // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: @reverse diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 2d33888..d669a3b 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -76,6 +76,18 @@ func.func @broadcast_vec1d_from_f32(%arg0: f32) -> vector<2xf32> { // ----- +func.func @broadcast_single_elem_vec1d_from_f32(%arg0: f32) -> vector<1xf32> { + %0 = vector.broadcast %arg0 : f32 to vector<1xf32> + return %0 : vector<1xf32> +} +// CHECK-LABEL: @broadcast_single_elem_vec1d_from_f32 +// CHECK-SAME: %[[A:.*]]: f32) +// CHECK: %[[T0:.*]] = llvm.insertelement %[[A]] +// CHECK-NOT: llvm.shufflevector +// CHECK: return %[[T0]] : vector<1xf32> + +// ----- + func.func @broadcast_vec1d_from_f32_scalable(%arg0: f32) -> vector<[2]xf32> { %0 = vector.broadcast %arg0 : f32 to vector<[2]xf32> return %0 : vector<[2]xf32> diff --git a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir index e6f22f0..a9ab0be 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir @@ -1,17 +1,13 @@ // RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s -#sg_map_a_f16 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> -#sg_map_b_f16 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]> -#sg_map_c_f32 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> - -gpu.module @load_store_check { +gpu.module @test_kernel { // CHECK-LABEL: func.func @dpas( // CHECK-SAME: %[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32> func.func @dpas(%a_loaded: vector<8xf16>, %b_loaded: vector<16xf16>, %c_loaded: vector<8xf32>) -> vector<8xf32> { // Loads are checked in a separate test. // CHECK: %[[D:.*]] = xevm.mma %[[ARG0]], %[[ARG1]], %[[ARG2]] {shape = <m = 8, n = 16, k = 16>, types = <d = f32, a = f16, b = f16, c = f32>} // CHECK-SAME: : (vector<8xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32> - %d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded {a_layout = #sg_map_a_f16, b_layout = #sg_map_b_f16, c_layout = #sg_map_c_f32} + %d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32> return %d : vector<8xf32> } diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir new file mode 100644 index 0000000..d4cb493 --- /dev/null +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir @@ -0,0 +1,201 @@ +// RUN: mlir-opt -split-input-file -convert-xegpu-to-xevm -cse %s | FileCheck %s + +gpu.module @test_kernel [#xevm.target<chip = "pvc">] { + + // e.g. for mem_desc<32x32xf16, @strides=[1, 16]> + // its memory layout tuple is (blocked shape = [1,1,32,32],strides=[1024,1024,32,1]) + //CHECK-LABEL: load_store_matrix_1 + gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> f32 { + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32> + + //CHECK: %[[TID:.*]] = gpu.thread_id x + //CHECK: %[[C1:.*]] = arith.constant 1 : index + //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index + //CHECK: %[[C4:.*]] = arith.constant 4 : i32 + //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32 + //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32 + + %tid_x = gpu.thread_id x + %c0 = arith.constant 0 : index + %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32 + + //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3> + + xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index + + gpu.return %1: f32 + } + +// e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]> + // its memory layout tuple is ([2,4,16,16],[256,512,1,16]) + //CHECK-LABEL: load_store_matrix_2 + gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> f16 { + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>> + //CHECK: %[[c0:.*]] = arith.constant 0 : index + //CHECK: %[[tid_x:.*]] = gpu.thread_id x + //CHECK: %[[c13:.*]] = arith.constant 13 : index + //CHECK: %[[c16:.*]] = arith.constant 16 : index + //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c13]], %[[c16]] : index + //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index + //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index + //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index + + //CHECK: %[[c256:.*]] = arith.constant 256 : index + //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index + //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index + //CHECK: %[[c512:.*]] = arith.constant 512 : index + //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index + //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index + //CHECK: %[[c1:.*]] = arith.constant 1 : index + //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index + //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index + //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index + //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index + + //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> f16 + + + %tid_x = gpu.thread_id x + %c13 = arith.constant 13 : index + %1 = xegpu.load_matrix %0[%c13, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> f16 + + //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3> + + xegpu.store_matrix %1, %0[%c13, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index + gpu.return %1: f16 + } + + + // e.g. for mem_desc<32x64xf16, @block=[16, 16]> + // its memory layout tuple is ([2,4,16,16],[1024,256,16,1]) + //CHECK-LABEL: load_store_matrix_3 + gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> f16 { + //CHECK: %[[c0:.*]] = arith.constant 0 : index + //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3> + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>> + + //CHECK: %[[tid_x:.*]] = gpu.thread_id x + //CHECK: %[[c19:.*]] = arith.constant 19 : index + %tid_x = gpu.thread_id x + %c19 = arith.constant 19: index + + //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index + //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32 + //CHECK: %[[c16:.*]] = arith.constant 16 : index + //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index + //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index + //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index + //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index + //CHECK: %[[c1024:.*]] = arith.constant 1024 : index + //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index + //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index + //CHECK: %[[c256:.*]] = arith.constant 256 : index + //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c256]] : index + //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index + //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c16]] : index + //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index + //CHECK: %[[c1:.*]] = arith.constant 1 : index + //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index + //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index + + //CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16 + %1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> f16 + + //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3> + xegpu.store_matrix %1, %0[%c19, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index + + //CHECK: gpu.return %[[loaded]] : f16 + gpu.return %1: f16 + } + + // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 16]> + // its memory layout tuple is ([2,4,16,16],[256,512,1,16]) + //CHECK-LABEL: load_store_matrix_4 + gpu.func @load_store_matrix_4(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>> + + //CHECK: %[[c0:.*]] = arith.constant 0 : index + //CHECK: %[[tid_x:.*]] = gpu.thread_id x + + //CHECK: %[[c16:.*]] = arith.constant 16 : index + //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index + //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index + //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index + //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index + + //CHECK: %[[c256:.*]] = arith.constant 256 : index + //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index + //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index + //CHECK: %[[c512:.*]] = arith.constant 512 : index + //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index + //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index + //CHECK: %[[c1:.*]] = arith.constant 1 : index + //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index + //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index + //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index + //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index + + //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> vector<8xf16> + + %tid_x = gpu.thread_id x + %c16 = arith.constant 16 : index + %1 = xegpu.load_matrix %0[%c16, %tid_x] : !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<8xf16> + + //CHECK: llvm.store %[[loaded]], {{.*}} : vector<8xf16>, !llvm.ptr<3> + xegpu.store_matrix %1, %0[%c16, %tid_x] : vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index + + gpu.return %1: vector<8xf16> + } + + + // e.g. for mem_desc<32x64xf16, @block=[16, 16]> + // its memory layout tuple is ([2,4,16,16],[1024,256,16,1]) + //CHECK-LABEL: load_store_matrix_5 + gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { + //CHECK: %[[c0:.*]] = arith.constant 0 : index + //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3> + + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>> + + //CHECK: %[[c16:.*]] = arith.constant 16 : index + //CHECK: %[[c48:.*]] = arith.constant 48 : index + + %c16 = arith.constant 16 : index + %c48 = arith.constant 48 : index + + //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index + //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32 + //CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index + //CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index + //CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index + //CHECK: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index + //CHECK: %[[c1024:.*]] = arith.constant 1024 : index + //CHECK: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index + //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index + //CHECK: %[[c256:.*]] = arith.constant 256 : index + //CHECK: %[[mul1:.*]] = arith.muli %[[offset2]], %[[c256]] : index + //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index + //CHECK: %[[mul2:.*]] = arith.muli %[[offset1]], %[[c16]] : index + //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index + //CHECK: %[[c1:.*]] = arith.constant 1 : index + //CHECK: %[[mul3:.*]] = arith.muli %[[offset3]], %[[c1]] : index + //CHECK: %[[linearOffset:.*]] = arith.addi %[[mul3]], %[[add2]] : index + //CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i32 + //CHECK: %[[c2:.*]] = arith.constant 2 : i32 + //CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i32 + //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI64]], %[[byteOffset]] : i32 + //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i32 to !llvm.ptr<3> + //CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16> + //CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16> + + %1 = xegpu.load_matrix %0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> vector<8xf16> + + //CHECK: %[[storeDataI16:.*]] = vector.bitcast %[[loaded]] : vector<8xf16> to vector<8xi16> + //CHECK: xevm.blockstore %[[ptr]], %[[storeDataI16]] : (!llvm.ptr<3>, vector<8xi16>) + + xegpu.store_matrix %1, %0[%c16, %c48] {subgroup_block_io}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index + + gpu.return %1: vector<8xf16> + } + +} diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir index 0b150e9..9c552d8 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir @@ -14,19 +14,36 @@ gpu.func @load_gather_i64_src_value_offset(%src: i64, %offset: vector<1xindex>) // CHECK: %[[VAR4:.*]] = arith.addi %[[ARG0]], %[[VAR3]] : i64 // CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1> // CHECK: %[[VAR6:.*]] = scf.if %[[VAR2]] -> (f16) { - // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR5]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>} : !llvm.ptr<1> -> vector<1xf16> - // CHECK: %[[VAR8:.*]] = vector.extract %[[VAR7]][0] : f16 from vector<1xf16> - // CHECK: scf.yield %[[VAR8]] : f16 - // CHECK: } else { - // CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1xf16> - // CHECK: %[[VAR7:.*]] = vector.extract %[[CST_0]][0] : f16 from vector<1xf16> + // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR5]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>} : !llvm.ptr<1> -> f16 // CHECK: scf.yield %[[VAR7]] : f16 + // CHECK: } else { + // CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f16 + // CHECK: scf.yield %[[CST_0]] : f16 // CHECK: } %3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : i64, vector<1xindex>, vector<1xi1> -> vector<1xf16> gpu.return } } + +// ----- +gpu.module @test { +// CHECK-LABEL: @source_materialize_single_elem_vec +// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: memref<1xf16> +gpu.func @source_materialize_single_elem_vec(%src: i64, %offset: vector<1xindex>, %dst: memref<1xf16>) { + %1 = arith.constant dense<1>: vector<1xi1> + %3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> + : i64, vector<1xindex>, vector<1xi1> -> vector<1xf16> + // CHECK: %[[VAR_IF:.*]] = scf.if + // CHECK: %[[VAR_RET:.*]] = vector.broadcast %[[VAR_IF]] : f16 to vector<1xf16> + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: vector.store %[[VAR_RET]], %[[ARG2]][%[[C0]]] : memref<1xf16>, vector<1xf16> + %c0 = arith.constant 0 : index + vector.store %3, %dst[%c0] : memref<1xf16>, vector<1xf16> + gpu.return +} +} + // ----- gpu.module @test { diff --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir index 7e562b00..a109f42 100644 --- a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir +++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir @@ -60,30 +60,74 @@ func.func @mulfp16(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) { return } -// CHECK-LABEL: strides( -// CHECK: %[[CST_64_1:.+]] = llvm.mlir.constant(64 : i64) : i64 -// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_1]] -// CHECK: %[[CST_128_1:.+]] = llvm.mlir.constant(128 : i64) : i64 -// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_1]] -// CHECK: llvm.mlir.constant(2 : i64) : i64 +/// Intrinsics require stride in number of bytes. +// CHECK-LABEL: strides_implicit( +// CHECK: %[[LOAD_STRIDE_1:.+]] = llvm.mlir.constant(32 : i64) : i64 +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[LOAD_STRIDE_1]] +// CHECK: %[[LOAD_STRIDE_2:.+]] = llvm.mlir.constant(128 : i64) : i64 +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[LOAD_STRIDE_2]] // CHECK: llvm.extractvalue %{{.+}}[4, 0] -// CHECK: %[[STRIDE_1:.+]] = llvm.mul -// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_1]] -// CHECK: %[[CST_64_2:.+]] = llvm.mlir.constant(64 : i64) : i64 -// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_2]] -// CHECK: %[[CST_128_2:.+]] = llvm.mlir.constant(128 : i64) : i64 -// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_2]] -// CHECK: llvm.mlir.constant(2 : i64) : i64 +// CHECK: %[[LOAD_BUF_STRIDE:.+]] = llvm.extractvalue %{{.+}}[4, 0] +// CHECK: %[[LOAD_STRIDE_SCALE:.+]] = llvm.mlir.constant(4 : i64) : i64 +// CHECK: %[[LOAD_STRIDE_3:.+]] = llvm.mul %[[LOAD_STRIDE_SCALE]], %[[LOAD_BUF_STRIDE]] +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[LOAD_STRIDE_3]] +// CHECK: %[[STORE_STRIDE_1:.+]] = llvm.mlir.constant(32 : i64) : i64 +// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STORE_STRIDE_1]] +// CHECK: %[[STORE_STRIDE_2:.+]] = llvm.mlir.constant(128 : i64) : i64 +// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STORE_STRIDE_2]] // CHECK: llvm.extractvalue %{{.+}}[4, 0] -// CHECK: %[[STRIDE_2:.+]] = llvm.mul -// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_2]] -func.func @strides(%arg0: memref<16x32xbf16>, %arg1: memref<16x32xbf16, strided<[64, 1]>>, %arg2: memref<16x32xbf16, strided<[?, 1]>>) { +// CHECK: %[[STORE_BUF_STRIDE:.+]] = llvm.extractvalue %{{.+}}[4, 0] +// CHECK: %[[STORE_STRIDE_SCALE:.+]] = llvm.mlir.constant(4 : i64) : i64 +// CHECK: %[[STORE_STRIDE_3:.+]] = llvm.mul %[[STORE_STRIDE_SCALE]], %[[STORE_BUF_STRIDE]] +// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STORE_STRIDE_3]] +func.func @strides_implicit(%arg0: memref<16x32xi8>, + %arg1: memref<32x32xbf16, strided<[64, 1]>>, + %arg2: memref<16x32xf32, strided<[?, 1]>>) { %0 = arith.constant 0 : index - %1 = amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into !amx.tile<16x32xbf16> - %2 = amx.tile_load %arg1[%0, %0] : memref<16x32xbf16, strided<[64, 1]>> into !amx.tile<16x32xbf16> - %3 = amx.tile_load %arg2[%0, %0] : memref<16x32xbf16, strided<[?, 1]>> into !amx.tile<16x32xbf16> - amx.tile_store %arg0[%0, %0], %3 : memref<16x32xbf16>, !amx.tile<16x32xbf16> - amx.tile_store %arg1[%0, %0], %1 : memref<16x32xbf16, strided<[64, 1]>>, !amx.tile<16x32xbf16> - amx.tile_store %arg2[%0, %0], %2 : memref<16x32xbf16, strided<[?, 1]>>, !amx.tile<16x32xbf16> + %1 = amx.tile_load %arg0[%0, %0] : memref<16x32xi8> into !amx.tile<16x32xi8> + %2 = amx.tile_load %arg1[%0, %0] : memref<32x32xbf16, strided<[64, 1]>> into !amx.tile<16x32xbf16> + %3 = amx.tile_load %arg2[%0, %0] : memref<16x32xf32, strided<[?, 1]>> into !amx.tile<16x16xf32> + amx.tile_store %arg0[%0, %0], %1 : memref<16x32xi8>, !amx.tile<16x32xi8> + amx.tile_store %arg1[%0, %0], %2 : memref<32x32xbf16, strided<[64, 1]>>, !amx.tile<16x32xbf16> + amx.tile_store %arg2[%0, %0], %3 : memref<16x32xf32, strided<[?, 1]>>, !amx.tile<16x16xf32> + return +} + +/// Intrinsics require stride in number of bytes. +// CHECK-LABEL: strides_explicit( +// CHECK-SAME: %[[STRIDE:.+]]: index +// CHECK-DAG: %[[STRIDE_I64:.+]] = builtin.unrealized_conversion_cast %[[STRIDE]] : index to i64 +// CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index +// CHECK-DAG: %[[C64_I64:.+]] = builtin.unrealized_conversion_cast %[[C64]] : index to i64 +// CHECK: %[[LOAD_STRIDE_SCALE_1:.+]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK: %[[LOAD_STRIDE_1:.+]] = llvm.mul %[[LOAD_STRIDE_SCALE_1]], %[[STRIDE_I64]] +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[LOAD_STRIDE_1]] +// CHECK: %[[LOAD_STRIDE_SCALE_2:.+]] = llvm.mlir.constant(2 : i64) : i64 +// CHECK: %[[LOAD_STRIDE_2:.+]] = llvm.mul %[[LOAD_STRIDE_SCALE_2]], %[[STRIDE_I64]] +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[LOAD_STRIDE_2]] +// CHECK: %[[LOAD_STRIDE_SCALE_3:.+]] = llvm.mlir.constant(4 : i64) : i64 +// CHECK: %[[LOAD_STRIDE_3:.+]] = llvm.mul %[[LOAD_STRIDE_SCALE_3]], %[[C64_I64]] +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[LOAD_STRIDE_3]] +// CHECK: %[[STORE_STRIDE_SCALE_1:.+]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK: %[[STORE_STRIDE_1:.+]] = llvm.mul %[[STORE_STRIDE_SCALE_1]], %[[STRIDE_I64]] +// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STORE_STRIDE_1]] +// CHECK: %[[STORE_STRIDE_SCALE_2:.+]] = llvm.mlir.constant(2 : i64) : i64 +// CHECK: %[[STORE_STRIDE_2:.+]] = llvm.mul %[[STORE_STRIDE_SCALE_2]], %[[STRIDE_I64]] +// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STORE_STRIDE_2]] +// CHECK: %[[STORE_STRIDE_SCALE_3:.+]] = llvm.mlir.constant(4 : i64) : i64 +// CHECK: %[[STORE_STRIDE_3:.+]] = llvm.mul %[[STORE_STRIDE_SCALE_3]], %[[C64_I64]] +// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STORE_STRIDE_3]] +func.func @strides_explicit(%stride: index, + %arg0: memref<?xi8>, + %arg1: memref<16x32xbf16>, + %arg2: memref<32x32xf32, strided<[64, 1]>>) { + %0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %1 = amx.tile_load %arg0[%0], %stride : memref<?xi8> into !amx.tile<16x32xi8> + %2 = amx.tile_load %arg1[%0, %0], %stride : memref<16x32xbf16> into !amx.tile<16x32xbf16> + %3 = amx.tile_load %arg2[%0, %0], %c64 : memref<32x32xf32, strided<[64, 1]>> into !amx.tile<16x16xf32> + amx.tile_store %arg0[%0], %1, %stride : memref<?xi8>, !amx.tile<16x32xi8> + amx.tile_store %arg1[%0, %0], %2, %stride : memref<16x32xbf16>, !amx.tile<16x32xbf16> + amx.tile_store %arg2[%0, %0], %3, %c64 : memref<32x32xf32, strided<[64, 1]>>, !amx.tile<16x16xf32> return } diff --git a/mlir/test/Dialect/AMX/roundtrip.mlir b/mlir/test/Dialect/AMX/roundtrip.mlir index 1b7f781..3d0f276 100644 --- a/mlir/test/Dialect/AMX/roundtrip.mlir +++ b/mlir/test/Dialect/AMX/roundtrip.mlir @@ -1,5 +1,33 @@ // RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s +// CHECK-LABEL: tloadstore +// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}], %{{.*}} : +// CHECK-SAME: memref<?xbf16> into !amx.tile<16x32xbf16> +// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : +// CHECK-SAME: memref<?x?xbf16> into !amx.tile<16x32xbf16> +// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : +// CHECK-SAME: memref<?x?xbf16, strided<[64, 1]>> into !amx.tile<16x32xbf16> +// CHECK: amx.tile_store %{{.*}}[%{{.*}}], %[[z]], %{{.*}} : +// CHECK-SAME: memref<?xbf16>, !amx.tile<16x32xbf16> +// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[x]], %{{.*}} : +// CHECK-SAME: memref<?x?xbf16>, !amx.tile<16x32xbf16> +// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[y]] : +// CHECK-SAME: memref<?x?xbf16, strided<[64, 1]>>, !amx.tile<16x32xbf16> +func.func @tloadstore(%stride: index, + %arg0: memref<?xbf16>, + %arg1: memref<?x?xbf16>, + %arg2: memref<?x?xbf16, strided<[64, 1]>>) { + %0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %1 = amx.tile_load %arg0[%0], %stride : memref<?xbf16> into !amx.tile<16x32xbf16> + %2 = amx.tile_load %arg1[%0, %0], %stride : memref<?x?xbf16> into !amx.tile<16x32xbf16> + %3 = amx.tile_load %arg2[%0, %0] : memref<?x?xbf16, strided<[64, 1]>> into !amx.tile<16x32xbf16> + amx.tile_store %arg0[%0], %3, %stride : memref<?xbf16>, !amx.tile<16x32xbf16> + amx.tile_store %arg1[%0, %0], %1, %stride : memref<?x?xbf16>, !amx.tile<16x32xbf16> + amx.tile_store %arg2[%0, %0], %2 : memref<?x?xbf16, strided<[64, 1]>>, !amx.tile<16x32xbf16> + return +} + // CHECK-LABEL: tzero // CHECK: amx.tile_zero : !amx.tile<16x16xbf16> // CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, !amx.tile<16x16xbf16> diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir index e56079c..1169cd1 100644 --- a/mlir/test/Dialect/Affine/canonicalize.mlir +++ b/mlir/test/Dialect/Affine/canonicalize.mlir @@ -2235,6 +2235,136 @@ func.func @affine_leading_zero_no_outer_bound(%arg0: index, %arg1: index) -> ind // ----- +// CHECK-LABEL: func @delin_apply_cancel_exact +// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>) +// CHECK-COUNT-6: memref.store %[[ARG0]], %[[ARG1]][%[[ARG0]]] +// CHECK-NOT: memref.store +// CHECK: return +func.func @delin_apply_cancel_exact(%arg0: index, %arg1: memref<?xindex>) { + %a:3 = affine.delinearize_index %arg0 into (4, 5) : index, index, index + %b:3 = affine.delinearize_index %arg0 into (3, 4, 5) : index, index, index + %c:2 = affine.delinearize_index %arg0 into (20) : index, index + + %t1 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 5 + s2 * 20)>()[%a#2, %a#1, %a#0] + memref.store %t1, %arg1[%t1] : memref<?xindex> + + %t2 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s2 * 20 + s1 * 5)>()[%a#2, %a#1, %a#0] + memref.store %t2, %arg1[%t2] : memref<?xindex> + + %t3 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 20 + s2 * 5 + s0)>()[%a#2, %a#0, %a#1] + memref.store %t3, %arg1[%t3] : memref<?xindex> + + %t4 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 5 + s2 * 20)>()[%b#2, %b#1, %b#0] + memref.store %t4, %arg1[%t4] : memref<?xindex> + + %t5 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 20)>()[%c#1, %c#0] + memref.store %t5, %arg1[%t5] : memref<?xindex> + + %t6 = affine.apply affine_map<()[s0, s1] -> (s1 * 20 + s0)>()[%c#1, %c#0] + memref.store %t6, %arg1[%t5] : memref<?xindex> + + return +} + +// ----- + +// CHECK-LABEL: func @delin_apply_cancel_exact_dim +// CHECK: affine.for %[[arg1:.+]] = 0 to 256 +// CHECK: memref.store %[[arg1]] +// CHECK: return +func.func @delin_apply_cancel_exact_dim(%arg0: memref<?xindex>) { + affine.for %arg1 = 0 to 256 { + %a:3 = affine.delinearize_index %arg1 into (2, 2, 64) : index, index, index + %i = affine.apply affine_map<(d0, d1, d2) -> (d0 + d1 * 128 + d2 * 64)>(%a#2, %a#0, %a#1) + memref.store %i, %arg0[%i] : memref<?xindex> + } + return +} + +// ----- + +// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 512)> +// CHECK-LABEL: func @delin_apply_cancel_const_term +// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>) +// CHECK: affine.apply #[[$MAP]]()[%[[ARG0]]] +// CHECK: return +func.func @delin_apply_cancel_const_term(%arg0: index, %arg1: memref<?xindex>) { + %a:3 = affine.delinearize_index %arg0 into (2, 2, 64) : index, index, index + + %t1 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 128 + s2 * 64 + 512)>()[%a#2, %a#0, %a#1] + memref.store %t1, %arg1[%t1] : memref<?xindex> + + return +} + +// ----- + +// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 512)> +// CHECK-LABEL: func @delin_apply_cancel_var_term +// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>, %[[ARG2:.+]]: index) +// CHECK: affine.apply #[[$MAP]]()[%[[ARG2]], %[[ARG0]]] +// CHECK: return +func.func @delin_apply_cancel_var_term(%arg0: index, %arg1: memref<?xindex>, %arg2: index) { + %a:3 = affine.delinearize_index %arg0 into (2, 2, 64) : index, index, index + + %t1 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 + s1 * 128 + s2 * 64 + s3 + 512)>()[%a#2, %a#0, %a#1, %arg2] + memref.store %t1, %arg1[%t1] : memref<?xindex> + + return +} + +// ----- + +// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2 + s0 ceildiv 4)> +// CHECK-LABEL: func @delin_apply_cancel_nested_exprs +// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>) +// CHECK: affine.apply #[[$MAP]]()[%[[ARG0]]] +// CHECK: return +func.func @delin_apply_cancel_nested_exprs(%arg0: index, %arg1: memref<?xindex>) { + %a:2 = affine.delinearize_index %arg0 into (20) : index, index + + %t1 = affine.apply affine_map<()[s0, s1] -> ((s0 + s1 * 20) ceildiv 4 + (s1 * 20 + s0) * 2)>()[%a#1, %a#0] + memref.store %t1, %arg1[%t1] : memref<?xindex> + + return +} + +// ----- + +// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)> +// CHECK-LABEL: func @delin_apply_cancel_preserve_rotation +// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>) +// CHECK: %[[A:.+]]:2 = affine.delinearize_index %[[ARG0]] into (20) +// CHECK: affine.apply #[[$MAP]]()[%[[A]]#1, %[[ARG0]]] +// CHECK: return +func.func @delin_apply_cancel_preserve_rotation(%arg0: index, %arg1: memref<?xindex>) { + %a:2 = affine.delinearize_index %arg0 into (20) : index, index + + %t1 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 20 + s0)>()[%a#1, %a#0] + memref.store %t1, %arg1[%t1] : memref<?xindex> + + return +} + +// ----- + +// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 5)> +// CHECK-LABEL: func @delin_apply_dont_cancel_partial +// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>) +// CHECK: %[[A:.+]]:3 = affine.delinearize_index %[[ARG0]] into (3, 4, 5) +// CHECK: affine.apply #[[$MAP]]()[%[[A]]#2, %[[A]]#1] +// CHECK: return +func.func @delin_apply_dont_cancel_partial(%arg0: index, %arg1: memref<?xindex>) { + %a:3 = affine.delinearize_index %arg0 into (3, 4, 5) : index, index, index + + %t1 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 5)>()[%a#2, %a#1] + memref.store %t1, %arg1[%t1] : memref<?xindex> + + return +} + +// ----- + // CHECK-LABEL: @cst_value_to_cst_attr_basis_delinearize_index // CHECK-SAME: (%[[ARG0:.*]]: index) // CHECK: %[[RET:.*]]:3 = affine.delinearize_index %[[ARG0]] into (3, 4, 2) : index, index diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-allow-return-allocs.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-allow-return-allocs.mlir index c58b153..21b508e 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-allow-return-allocs.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-allow-return-allocs.mlir @@ -65,13 +65,13 @@ func.func @main(%t: tensor<?xf32>, %sz: index, %idx: index) -> (f32, f32) { // ----- -func.func @return_arg(%A: tensor<?xf32>) -> tensor<?xf32> { +func.func private @return_arg(%A: tensor<?xf32>) -> tensor<?xf32> { func.return %A : tensor<?xf32> } -// CHECK-LABEL: func @return_arg +// CHECK-LABEL: func private @return_arg // CHECK-SAME: %[[A:.*]]: memref<?xf32 // CHECK-NOT: return %[[A]] -// NO-DROP-LABEL: func @return_arg +// NO-DROP-LABEL: func private @return_arg // NO-DROP-SAME: %[[A:.*]]: memref<?xf32 // NO-DROP: return %[[A]] diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir index 6054a61..d5f834b 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir @@ -171,9 +171,9 @@ func.func @func_without_tensor_args(%v : vector<10xf32>) -> () { // Bufferization of a function that is reading and writing. %t0 is writable, so // no copy should be inserted. -// CHECK-LABEL: func @inner_func( +// CHECK-LABEL: func private @inner_func( // CHECK-SAME: %[[arg0:.*]]: memref<?xf32 -func.func @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) { +func.func private @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) { // CHECK-NOT: copy %f = arith.constant 1.0 : f32 %c0 = arith.constant 0 : index @@ -186,9 +186,9 @@ func.func @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) { return %0, %1 : tensor<?xf32>, f32 } -// CHECK-LABEL: func @call_func_with_non_tensor_return( +// CHECK-LABEL: func private @call_func_with_non_tensor_return( // CHECK-SAME: %[[arg0:.*]]: memref<?xf32 -func.func @call_func_with_non_tensor_return( +func.func private @call_func_with_non_tensor_return( %t0: tensor<?xf32> {bufferization.writable = true}) -> (f32, tensor<?xf32>) { // CHECK-NOT: alloc // CHECK-NOT: copy @@ -203,9 +203,9 @@ func.func @call_func_with_non_tensor_return( // Bufferization of a function that is reading and writing. %t0 is not writable, // so a copy is needed. -// CHECK-LABEL: func @inner_func( +// CHECK-LABEL: func private @inner_func( // CHECK-SAME: %[[arg0:.*]]: memref<?xf32 -func.func @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) { +func.func private @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) { // CHECK-NOT: copy %f = arith.constant 1.0 : f32 %c0 = arith.constant 0 : index @@ -276,10 +276,10 @@ func.func @main(%t: tensor<?xf32> {bufferization.writable = false}) -> (f32) { // This function does not read, just write. We need an alloc, but no copy. -// CHECK-LABEL: func @does_not_read( +// CHECK-LABEL: func private @does_not_read( // CHECK-NOT: alloc // CHECK-NOT: copy -func.func @does_not_read(%t: tensor<?xf32>) -> tensor<?xf32> { +func.func private @does_not_read(%t: tensor<?xf32>) -> tensor<?xf32> { %f0 = arith.constant 0.0 : f32 %r = linalg.fill ins(%f0 : f32) outs(%t : tensor<?xf32>) -> tensor<?xf32> return %r : tensor<?xf32> @@ -354,9 +354,9 @@ func.func @main() { // A write inside an scf.execute_region. An equivalent tensor is yielded. -// CHECK-LABEL: func @execute_region_test( +// CHECK-LABEL: func private @execute_region_test( // CHECK-SAME: %[[m1:.*]]: memref<?xf32 -func.func @execute_region_test(%t1 : tensor<?xf32>) +func.func private @execute_region_test(%t1 : tensor<?xf32>) -> (f32, tensor<?xf32>, f32) { %f1 = arith.constant 0.0 : f32 @@ -397,11 +397,11 @@ func.func @no_inline_execute_region_not_canonicalized() { // CHECK: func private @some_external_func(memref<?xf32, strided<[?], offset: ?>>) func.func private @some_external_func(tensor<?xf32>) -// CHECK: func @scf_for_with_tensor_insert_slice( +// CHECK: func private @scf_for_with_tensor_insert_slice( // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>> // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>> // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: memref<4xf32, strided<[?], offset: ?>> -func.func @scf_for_with_tensor_insert_slice( +func.func private @scf_for_with_tensor_insert_slice( %A : tensor<?xf32>, %B : tensor<?xf32>, %C : tensor<4xf32>, %lb : index, %ub : index, %step : index) -> (tensor<?xf32>, tensor<?xf32>) @@ -456,11 +456,11 @@ func.func @bar( // ----- -// CHECK: func @init_and_dot( +// CHECK: func private @init_and_dot( // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<64xf32, strided<[?], offset: ?>> // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: memref<64xf32, strided<[?], offset: ?>> // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: memref<f32, strided<[], offset: ?>> -func.func @init_and_dot(%a: tensor<64xf32>, %b: tensor<64xf32>, %c: tensor<f32>) -> tensor<f32> { +func.func private @init_and_dot(%a: tensor<64xf32>, %b: tensor<64xf32>, %c: tensor<f32>) -> tensor<f32> { // CHECK-NEXT: %[[C0:.*]] = arith.constant 0{{.*}} : f32 %v0 = arith.constant 0.0 : f32 @@ -574,9 +574,9 @@ func.func @entry(%A : tensor<?xf32> {bufferization.buffer_layout = affine_map<(i // No alloc or copy inside of the loop. -// CHECK-LABEL: func @inner_func( +// CHECK-LABEL: func private @inner_func( // CHECK-SAME: %[[arg0:.*]]: memref<?xf32 -func.func @inner_func(%t: tensor<?xf32>) -> tensor<?xf32> { +func.func private @inner_func(%t: tensor<?xf32>) -> tensor<?xf32> { %f = arith.constant 1.0 : f32 %c0 = arith.constant 0 : index // CHECK: memref.store %{{.*}}, %[[arg0]] diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir index e2ab876..b52612d 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir @@ -24,10 +24,46 @@ // CHECK-NOT: copy // CHECK: %[[call:.*]]:2 = call @inner_func(%[[arg0]]) %0, %1 = call @inner_func(%t0) : (tensor<?xf32>) -> (tensor<?xf32>, f32) - // CHECK: return %[[call]]#1, %[[call]]#0 : f32, memref<?xf32,{{.*}}> + // CHECK: return %[[call]]#1, %[[call]]#0 : f32, memref<?xf32{{.*}}> return %1, %0 : f32, tensor<?xf32> } "test.finish" () : () -> () }) : () -> () +// ----- +#enc1 = #test.tensor_encoding<"hello"> +#enc2 = #test.tensor_encoding<"not hello"> + +"test.symbol_scope_isolated"() ({ + // CHECK: func @inner_func( + // CHECK-SAME: %[[arg0:.*]]: memref<?xf32, #test.memref_layout<"hello">>) + // CHECK-SAME: -> memref<?xf32, #test.memref_layout<"hello">> + func.func @inner_func(%t: tensor<?xf32, #enc1>) + -> tensor<?xf32, #enc1> { + // CHECK: return %[[arg0]] + return %t : tensor<?xf32, #enc1> + } + + // CHECK: func @outer_func( + // CHECK-SAME: %[[arg0:.*]]: memref<?xf32, #test.memref_layout<"hello">>) + // CHECK-SAME: -> (memref<?xf32, #test.memref_layout<"hello">>, + // CHECK-SAME: memref<?xf32, #test.memref_layout<"not hello">>) + func.func @outer_func(%t0: tensor<?xf32, #enc1>) + -> (tensor<?xf32, #enc1>, tensor<?xf32, #enc2>) { + // CHECK: %[[call:.*]] = call @inner_func(%[[arg0]]) + %0 = call @inner_func(%t0) + : (tensor<?xf32, #enc1>) -> (tensor<?xf32, #enc1>) + + // CHECK: %[[local:.*]] = "test.create_memref_op"() : () + // CHECK-SAME: -> memref<?xf32, #test.memref_layout<"not hello">> + %local = "test.create_tensor_op"() : () -> tensor<?xf32, #enc2> + // CHECK: %[[dummy:.*]] = "test.dummy_memref_op"(%[[local]]) + %1 = "test.dummy_tensor_op"(%local) : (tensor<?xf32, #enc2>) + -> tensor<?xf32, #enc2> + + // CHECK: return %[[call]], %[[dummy]] + return %0, %1 : tensor<?xf32, #enc1>, tensor<?xf32, #enc2> + } + "test.finish" () : () -> () +}) : () -> () diff --git a/mlir/test/Dialect/LLVMIR/bytecode.mlir b/mlir/test/Dialect/LLVMIR/bytecode.mlir new file mode 100644 index 0000000..821b0ac --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/bytecode.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-opt -verify-roundtrip %s + +#access_group = #llvm.access_group<id = distinct[0]<>> +#access_group1 = #llvm.access_group<id = distinct[1]<>> +#di_subprogram = #llvm.di_subprogram<recId = distinct[2]<>> +#loc1 = loc("test.f90":12:14) +#loc2 = loc("test":4:3) +#loc6 = loc(fused<#di_subprogram>[#loc1]) +#loc7 = loc(fused<#di_subprogram>[#loc2]) +#loop_annotation = #llvm.loop_annotation<disableNonforced = false, mustProgress = true, startLoc = #loc6, endLoc = #loc7, parallelAccesses = #access_group, #access_group1> +module { + llvm.func @imp_fn() { + llvm.return loc(#loc2) + } loc(#loc8) + llvm.func @loop_annotation_with_locs() { + llvm.br ^bb1 {loop_annotation = #loop_annotation} loc(#loc4) + ^bb1: // pred: ^bb0 + llvm.return loc(#loc5) + } loc(#loc3) +} loc(#loc) +#di_file = #llvm.di_file<"test.f90" in ""> +#di_subroutine_type = #llvm.di_subroutine_type<callingConvention = DW_CC_program> +#loc = loc("test":0:0) +#loc3 = loc("test-path":36:3) +#loc4 = loc("test-path":37:5) +#loc5 = loc("test-path":39:5) +#di_compile_unit = #llvm.di_compile_unit<id = distinct[3]<>, sourceLanguage = DW_LANG_Fortran95, file = #di_file, isOptimized = false, emissionKind = Full> +#di_compile_unit1 = #llvm.di_compile_unit<id = distinct[4]<>, sourceLanguage = DW_LANG_Fortran95, file = #di_file, isOptimized = false, emissionKind = Full> +#di_compile_unit2 = #llvm.di_compile_unit<id = distinct[5]<>, sourceLanguage = DW_LANG_Fortran95, file = #di_file, isOptimized = false, emissionKind = Full> +#di_module = #llvm.di_module<file = #di_file, scope = #di_compile_unit1, name = "mod1"> +#di_module1 = #llvm.di_module<file = #di_file, scope = #di_compile_unit2, name = "mod2"> +#di_imported_entity = #llvm.di_imported_entity<tag = DW_TAG_imported_module, scope = #di_subprogram, entity = #di_module, file = #di_file, line = 1> +#di_imported_entity1 = #llvm.di_imported_entity<tag = DW_TAG_imported_module, scope = #di_subprogram, entity = #di_module1, file = #di_file, line = 1> +#di_subprogram1 = #llvm.di_subprogram<recId = distinct[2]<>, id = distinct[6]<>, compileUnit = #di_compile_unit, scope = #di_file, name = "imp_fn", file = #di_file, subprogramFlags = Definition, type = #di_subroutine_type, retainedNodes = #di_imported_entity, #di_imported_entity1> +#loc8 = loc(fused<#di_subprogram1>[#loc1]) diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir index 8accf6e..755e3a3 100644 --- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir +++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir @@ -235,6 +235,17 @@ llvm.func @fold_gep_canon(%x : !llvm.ptr) -> !llvm.ptr { // ----- +// CHECK-LABEL: fold_shufflevector +// CHECK-SAME: %[[ARG1:[[:alnum:]]+]]: vector<1xf32>, %[[ARG2:[[:alnum:]]+]]: vector<1xf32> +llvm.func @fold_shufflevector(%v1 : vector<1xf32>, %v2 : vector<1xf32>) -> vector<1xf32> { + // CHECK-NOT: llvm.shufflevector + %c = llvm.shufflevector %v1, %v2 [0] : vector<1xf32> + // CHECK: llvm.return %[[ARG1]] + llvm.return %c : vector<1xf32> +} + +// ----- + // Check that LLVM constants participate in cross-dialect constant folding. The // resulting constant is created in the arith dialect because the last folded // operation belongs to it. diff --git a/mlir/test/Dialect/LLVMIR/debuginfo.mlir b/mlir/test/Dialect/LLVMIR/debuginfo.mlir index 1834b0a..d7bf99b 100644 --- a/mlir/test/Dialect/LLVMIR/debuginfo.mlir +++ b/mlir/test/Dialect/LLVMIR/debuginfo.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s | mlir-opt | FileCheck %s +// RUN: mlir-opt -emit-bytecode %s | mlir-opt | FileCheck %s // CHECK-DAG: #[[FILE:.*]] = #llvm.di_file<"debuginfo.mlir" in "/test/"> #file = #llvm.di_file<"debuginfo.mlir" in "/test/"> diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir index 358bd33..242c04f 100644 --- a/mlir/test/Dialect/LLVMIR/rocdl.mlir +++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir @@ -1035,6 +1035,20 @@ llvm.func @rocdl.s.wait.expcnt() { llvm.return } +llvm.func @rocdl.s.wait.asynccnt() { + // CHECK-LABEL: rocdl.s.wait.asynccnt + // CHECK: rocdl.s.wait.asynccnt 0 + rocdl.s.wait.asynccnt 0 + llvm.return +} + +llvm.func @rocdl.s.wait.tensorcnt() { + // CHECK-LABEL: rocdl.s.wait.tensorcnt + // CHECK: rocdl.s.wait.tensorcnt 0 + rocdl.s.wait.tensorcnt 0 + llvm.return +} + // ----- llvm.func @rocdl.readfirstlane(%src : f32) -> f32 { diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir index 7344797..00e763a 100644 --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s | mlir-opt | FileCheck %s +// RUN: mlir-opt -verify-roundtrip %s // CHECK-LABEL: func @baz @@ -757,7 +757,7 @@ llvm.func @stackrestore(%arg0: !llvm.ptr) { // CHECK-LABEL: @experimental_noalias_scope_decl llvm.func @experimental_noalias_scope_decl() { - // CHECK: llvm.intr.experimental.noalias.scope.decl #{{.*}} + // CHECK: llvm.intr.experimental.noalias.scope.decl #alias_scope{{.*}} llvm.intr.experimental.noalias.scope.decl #alias_scope llvm.return } @@ -767,7 +767,7 @@ llvm.func @experimental_noalias_scope_decl() { // CHECK-LABEL: @experimental_noalias_scope_with_string_id llvm.func @experimental_noalias_scope_with_string_id() { - // CHECK: llvm.intr.experimental.noalias.scope.decl #{{.*}} + // CHECK: llvm.intr.experimental.noalias.scope.decl #alias_scope{{.*}} llvm.intr.experimental.noalias.scope.decl #alias_scope2 llvm.return } diff --git a/mlir/test/Dialect/Linalg/decompose-pack.mlir b/mlir/test/Dialect/Linalg/decompose-pack.mlir index 18a09f4..12292ee 100644 --- a/mlir/test/Dialect/Linalg/decompose-pack.mlir +++ b/mlir/test/Dialect/Linalg/decompose-pack.mlir @@ -31,6 +31,25 @@ func.func @simple_KCRS_to_KCRSsr(%arg0: tensor<?x?xi32>, %arg1: tensor<1x1x?x1xi // ----- +func.func @NCHW_to_NCHWc(%src: tensor<2x32x16x8xf32>, %dest: tensor<2x1x16x8x32xf32>) -> tensor<2x1x16x8x32xf32> { + %pack = linalg.pack %src + inner_dims_pos = [1] + inner_tiles = [32] into %dest + : tensor<2x32x16x8xf32> -> tensor<2x1x16x8x32xf32> + return %pack : tensor<2x1x16x8x32xf32> +} +// CHECK-LABEL: func.func @NCHW_to_NCHWc( +// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<2x16x8x32xf32> +// CHECK: %[[TR:.*]] = linalg.transpose ins(%[[SRC]] : tensor<2x32x16x8xf32>) outs(%[[INIT]] : tensor<2x16x8x32xf32>) permutation = [0, 2, 3, 1] +// CHECK: %[[RES:.*]] = tensor.insert_slice %[[TR]] into %[[DEST]] +// CHECK-SAME: [0, 0, 0, 0, 0] [2, 1, 16, 8, 32] [1, 1, 1, 1, 1] +// CHECK-SAME: : tensor<2x16x8x32xf32> into tensor<2x1x16x8x32xf32> +// CHECK: return %[[RES]] : tensor<2x1x16x8x32xf32> + +// ----- + func.func @simple_pad_and_pack_static_tiles(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2xf32>, %pad: f32) -> tensor<1x1x8x2xf32> { %0 = linalg.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<5x1xf32> -> tensor<1x1x8x2xf32> return %0 : tensor<1x1x8x2xf32> @@ -157,6 +176,8 @@ func.func @simple_pad_and_pack_dynamic_tiles(%input: tensor<5x1xf32>, %output: t // ----- +// Note - un-tiled outer dims are permueted. However, these are unit dims, which is supported. + func.func @simple_pad_and_pack_dynamic_tile_not_all_dims_tiled(%input: tensor<1x1x5x1xf32>, %output: tensor<1x1x1x1x2x?xf32>, %pad: f32, %high: index) -> tensor<1x1x1x1x2x?xf32> { %0 = linalg.pack %input padding_value(%pad : f32) outer_dims_perm = [1, 0, 2, 3] inner_dims_pos = [3, 2] inner_tiles = [2, %high] into %output : tensor<1x1x5x1xf32> -> tensor<1x1x1x1x2x?xf32> return %0 : tensor<1x1x1x1x2x?xf32> @@ -182,6 +203,28 @@ func.func @simple_pad_and_pack_dynamic_tile_not_all_dims_tiled(%input: tensor<1x // ----- +// Similar as the example above, but one of the un-tiled outer dims that are permuted is non-unit: (7,1) -> (1, 7) + +func.func @negative_not_all_dims_tiled_outer_dim_0_permuted(%input: tensor<7x1x5x1xf32>, %output: tensor<1x7x1x1x2x?xf32>, %pad: f32, %high: index) -> tensor<1x7x1x1x2x?xf32> { + %0 = linalg.pack %input padding_value(%pad : f32) outer_dims_perm = [1, 0, 2, 3] inner_dims_pos = [3, 2] inner_tiles = [2, %high] into %output : tensor<7x1x5x1xf32> -> tensor<1x7x1x1x2x?xf32> + return %0 : tensor<1x7x1x1x2x?xf32> +} +// CHECK-LABEL: func.func @negative_not_all_dims_tiled_outer_dim_0_permuted +// CHECK: linalg.pack + +// ----- + +// Similar as the example above, but one of the un-tiled outer dims that are permuted is non-unit: (1, 7) -> (7, 1). + +func.func @negative_not_all_dims_tiled_outer_dim_1_permuted(%input: tensor<1x7x5x1xf32>, %output: tensor<7x1x1x1x2x?xf32>, %pad: f32, %high: index) -> tensor<7x1x1x1x2x?xf32> { + %0 = linalg.pack %input padding_value(%pad : f32) outer_dims_perm = [1, 0, 2, 3] inner_dims_pos = [3, 2] inner_tiles = [2, %high] into %output : tensor<1x7x5x1xf32> -> tensor<7x1x1x1x2x?xf32> + return %0 : tensor<7x1x1x1x2x?xf32> +} +// CHECK-LABEL: func.func @negative_not_all_dims_tiled_outer_dim_1_permuted +// CHECK: linalg.pack + +// ----- + func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32>{ %0 = linalg.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<32x8xf32> -> tensor<1x1x32x8xf32> return %0 : tensor<1x1x32x8xf32> @@ -295,3 +338,21 @@ func.func @pack_with_non_adjacent_and_non_permuted_inner_dims(%arg0: tensor<8x1x // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]] // CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x1xf32> into tensor<1x1x1x1x8x1xf32> // CHECK: return %[[INSERT]] + +// ----- + +/// Note "126", which is a non-unit tiled-outer-dim. This is not supported. + +func.func @negative_non_unit_tiled_outer_dim(%dest: tensor<1x126x1x1x8xf32>, %src: tensor<1x1x1x1001xf32>, %pad: f32) -> tensor<1x126x1x1x8xf32> { + %pack = linalg.pack %src + padding_value(%pad : f32) + outer_dims_perm = [0, 3, 2, 1] + inner_dims_pos = [3] + inner_tiles = [8] + into %dest + : tensor<1x1x1x1001xf32> -> tensor<1x126x1x1x8xf32> + + return %pack : tensor<1x126x1x1x8xf32> +} +// CHECK-LABEL: @negative_non_unit_tiled_outer_dim( +// CHECK: linalg.pack diff --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir index 618ba34..66cae5c 100644 --- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir +++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir @@ -1011,6 +1011,20 @@ module attributes { transform.target_tag = "start_here" } { } -> tensor<1x1x4xf32> return } + + func.func @generic_none(%arg0: tensor<128x128xi32>, %arg1: tensor<128x128xi32>, %arg2: tensor<128x128xi32>) { + %0 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%arg0, %arg1 : tensor<128x128xi32>, tensor<128x128xi32>) + outs(%arg2 : tensor<128x128xi32>) { + ^bb0(%in: i32, %in_0: i32, %out: i32): + linalg.yield %out : i32 + } -> tensor<128x128xi32> + return + } } // ----- diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir index 9616a3e..1df15e8 100644 --- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir @@ -10,10 +10,10 @@ // TODO: Some test cases from this file should be moved to other dialects. -// CHECK-LABEL: func @fill_inplace( +// CHECK-LABEL: func private @fill_inplace( // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>> -// CHECK-NO-LAYOUT-MAP-LABEL: func @fill_inplace(%{{.*}}: memref<?xf32>) { -func.func @fill_inplace( +// CHECK-NO-LAYOUT-MAP-LABEL: func private @fill_inplace(%{{.*}}: memref<?xf32>) { +func.func private @fill_inplace( %A : tensor<?xf32> {bufferization.writable = true}) -> tensor<?xf32> { @@ -56,10 +56,10 @@ func.func @not_inplace( // ----- -// CHECK-LABEL: func @not_inplace +// CHECK-LABEL: func private @not_inplace // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<?x?xf32, strided<[?, ?], offset: ?>>) { -// CHECK-NO-LAYOUT-MAP-LABEL: func @not_inplace(%{{.*}}: memref<?x?xf32>) { -func.func @not_inplace( +// CHECK-NO-LAYOUT-MAP-LABEL: func private @not_inplace(%{{.*}}: memref<?x?xf32>) { +func.func private @not_inplace( %A : tensor<?x?xf32> {bufferization.writable = true}) -> tensor<?x?xf32> { @@ -235,7 +235,7 @@ func.func @dominance_violation_bug_1( // ----- -func.func @gather_like( +func.func private @gather_like( %arg0 : tensor<?x?xf32> {bufferization.writable = false}, %arg1 : tensor<?xi32> {bufferization.writable = false}, %arg2 : tensor<?x?xf32> {bufferization.writable = true}) @@ -254,7 +254,7 @@ func.func @gather_like( } -> tensor<?x?xf32> return %0 : tensor<?x?xf32> } -// CHECK-LABEL: func @gather_like( +// CHECK-LABEL: func private @gather_like( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32, // CHECK-SAME: %[[ARG1:.+]]: memref<?xi32 // CHECK-SAME: %[[ARG2:.+]]: memref<?x?xf32 diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir index 9a44f95..7dc0a87b 100644 --- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir @@ -18,7 +18,7 @@ func.func @fuse_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor< module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]} + %1, %loops:2 = transform.structured.fuse %0 tile_sizes [32, 32] interchange [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) transform.yield } @@ -48,7 +48,7 @@ func.func @fuse_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor< module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]} + %1, %loops:2 = transform.structured.fuse %0 tile_sizes [32, 32] interchange [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op) transform.loop.peel %loops#0 : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op) transform.yield @@ -57,6 +57,60 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-LABEL: func.func @fuse_unary_param +func.func @fuse_unary_param(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> { + + // CHECK: %[[RES:.*]] = scf.for + // CHECK: scf.for + // CHECK: linalg.exp + // CHECK: linalg.add + // CHECK: return %[[RES]] + %0 = linalg.exp ins(%arg0 : tensor<?x?xf32>) + outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32> + %1 = linalg.add ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>) + outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32> + return %1 : tensor<?x?xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.param.constant 32 : i32 -> !transform.param<i32> + %2 = transform.param.constant 1 : i32 -> !transform.param<i32> + %3, %loops:2 = transform.structured.fuse %0 tile_sizes [%1, 32] interchange [0, %2] + : (!transform.any_op, !transform.param<i32>, !transform.param<i32>) -> + (!transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func.func @fuse_unary_forall +func.func @fuse_unary_forall(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> { + + // CHECK: %[[RES:.*]] = scf.forall + // CHECK: linalg.exp + // CHECK: linalg.add + // CHECK: return %[[RES]] + %0 = linalg.exp ins(%arg0 : tensor<?x?xf32>) + outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32> + %1 = linalg.add ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>) + outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32> + return %1 : tensor<?x?xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %loop = transform.structured.fuse %0 tile_sizes [32, 32] {use_forall} + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + // CHECK-LABEL: func.func @interchange_reduction // CHECK-SAME: (%[[INPUT:.+]]: tensor<12x7x25xf32>) func.func @interchange_reduction(%input: tensor<12x7x25xf32>) -> tensor<12x25xf32> { @@ -93,7 +147,7 @@ func.func @interchange_reduction(%input: tensor<12x7x25xf32>) -> tensor<12x25xf3 module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [5, 0, 7], tile_interchange = [0, 2, 1]} + %1, %loops:2 = transform.structured.fuse %0 tile_sizes [5, 0, 7] interchange [0, 2, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) %2, %loops_2 = transform.structured.tile_using_for %1 tile_sizes [0, 4] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) @@ -121,7 +175,7 @@ func.func @unpack_elemwise(%arg0: tensor<16x48x8x8xf32>, %arg1: tensor<128x384xf module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.exp"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [16, 32], tile_interchange = [0, 1]} + %1, %loops:2 = transform.structured.fuse %0 tile_sizes [16, 32] interchange [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) transform.yield } @@ -147,7 +201,7 @@ func.func @pack_elemwise(%arg0: tensor<128x384xf32>, %arg1: tensor<16x48x8x8xf32 module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.exp"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [3, 5, 0, 0]} + %1, %loops:2 = transform.structured.fuse %0 tile_sizes [3, 5, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) transform.yield } @@ -173,7 +227,7 @@ func.func @nofuse_pack_elemwise(%arg0: tensor<128x384xf32>, %arg1: tensor<16x48x module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.exp"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1, %loops:3 = transform.structured.fuse %0 {tile_sizes = [3, 5, 2, 0]} + %1, %loops:3 = transform.structured.fuse %0 tile_sizes [3, 5, 2, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) transform.yield } @@ -204,7 +258,7 @@ func.func @fuse_through_slice(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true} + %1, %loops:2 = transform.structured.fuse %0 tile_sizes [32, 32] interchange [0, 1] {apply_cleanup} : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op) transform.yield } @@ -238,7 +292,7 @@ func.func @fuse_through_slice_and_cast_chain(%arg0: tensor<100x100xf32>, %arg1: module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true} + %1, %loops:2 = transform.structured.fuse %0 tile_sizes [32, 32] interchange [0, 1] {apply_cleanup} : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op) transform.yield } @@ -273,7 +327,7 @@ func.func @fuse_unrelated_slices(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true} + %1, %loops:2 = transform.structured.fuse %0 tile_sizes [32, 32] interchange [0, 1] {apply_cleanup} : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op) transform.yield } @@ -299,7 +353,7 @@ func.func @bubble_up_extract_slice_through_expand_shape(%0: tensor<60xf32>) -> t module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %transformed, %loops:3 = transform.structured.fuse %0 [1, 1, 5] interchange [0, 1, 2] apply_cleanup = true : + %transformed, %loops:3 = transform.structured.fuse %0 tile_sizes [1, 1, 5] interchange [0, 1, 2] {apply_cleanup} : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op) transform.yield } @@ -324,7 +378,7 @@ func.func @bubble_up_extract_slice_through_expand_shape_full_inner_dim(%0: tenso module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %transformed, %loops:2 = transform.structured.fuse %0 [1, 2, 0] interchange [0, 1, 2] apply_cleanup = true : + %transformed, %loops:2 = transform.structured.fuse %0 tile_sizes [1, 2, 0] interchange [0, 1, 2] {apply_cleanup} : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op) transform.yield } @@ -348,7 +402,7 @@ func.func @no_bubble_up_extract_slice_through_expand_shape_non_contiguous(%0: te module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %transformed, %loops:3 = transform.structured.fuse %0 [1, 2, 5] interchange [0, 1, 2] apply_cleanup = true : + %transformed, %loops:3 = transform.structured.fuse %0 tile_sizes [1, 2, 5] interchange [0, 1, 2] {apply_cleanup} : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op) transform.yield } @@ -379,7 +433,7 @@ module { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %transformed, %loops:4 = transform.structured.fuse %0 [1, 2, 0, 1, 4] interchange [0, 1, 2, 3, 4] apply_cleanup = true : + %transformed, %loops:4 = transform.structured.fuse %0 tile_sizes [1, 2, 0, 1, 4] interchange [0, 1, 2, 3, 4] {apply_cleanup} : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op, !transform.any_op) transform.yield } @@ -408,7 +462,7 @@ module { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %transformed, %loops:1 = transform.structured.fuse %0 [0, 0, 1, 0] interchange [0, 1, 2, 3] apply_cleanup = true : + %transformed, %loops:1 = transform.structured.fuse %0 tile_sizes [0, 0, 1, 0] interchange [0, 1, 2, 3] {apply_cleanup} : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">) transform.yield } @@ -433,7 +487,7 @@ func.func @no_bubble_up_extract_slice_through_expand_shape_on_cleanup_false(%0: module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %transformed, %loops:3 = transform.structured.fuse %0 [1, 1, 5] interchange [0, 1, 2] apply_cleanup = false : + %transformed, %loops:3 = transform.structured.fuse %0 tile_sizes [1, 1, 5] interchange [0, 1, 2] : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op) transform.yield } @@ -456,7 +510,7 @@ func.func @bubble_up_extract_slice_through_collapse_shape(%0: tensor<1x8x1800x32 module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %transformed, %loops:1 = transform.structured.fuse %0 [1, 0, 0] interchange [0, 1, 2] apply_cleanup = true : + %transformed, %loops:1 = transform.structured.fuse %0 tile_sizes [1, 0, 0] interchange [0, 1, 2] {apply_cleanup} : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">) transform.yield } @@ -482,7 +536,7 @@ func.func @bubble_up_extract_slice_through_collapse_shape_with_collapse_producer module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %transformed, %loops:1 = transform.structured.fuse %0 [1, 0, 0] interchange [0, 1, 2] apply_cleanup = true : + %transformed, %loops:1 = transform.structured.fuse %0 tile_sizes [1, 0, 0] interchange [0, 1, 2] {apply_cleanup} : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">) transform.yield } diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir index 35f520a..93a0336 100644 --- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir +++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir @@ -1,5 +1,9 @@ // RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s +///---------------------------------------------------------------------------------------- +/// Tests for linalg.dot +///---------------------------------------------------------------------------------------- + // CHECK-LABEL: contraction_dot func.func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) { @@ -20,6 +24,10 @@ module attributes {transform.with_named_sequence} { // ----- +///---------------------------------------------------------------------------------------- +/// Tests for linalg.matvec +///---------------------------------------------------------------------------------------- + // CHECK-LABEL: contraction_matvec func.func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) { @@ -41,6 +49,10 @@ module attributes {transform.with_named_sequence} { // ----- +///---------------------------------------------------------------------------------------- +/// Tests for linalg.matmul +///---------------------------------------------------------------------------------------- + // CHECK-LABEL: contraction_matmul func.func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) { // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32> @@ -138,6 +150,10 @@ module attributes {transform.with_named_sequence} { // ----- +///---------------------------------------------------------------------------------------- +/// Tests for linalg.batch_matmul +///---------------------------------------------------------------------------------------- + // CHECK-LABEL: contraction_batch_matmul func.func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) { // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584x1584xf32> @@ -159,6 +175,10 @@ module attributes {transform.with_named_sequence} { // ----- +///---------------------------------------------------------------------------------------- +/// Tests for linalg.cantract +///---------------------------------------------------------------------------------------- + // CHECK-LABEL: @matmul_as_contract // CHECK-SAME: %[[A:.*]]: tensor<24x12xf32> // CHECK-SAME: %[[B:.*]]: tensor<12x25xf32> @@ -220,6 +240,10 @@ module attributes {transform.with_named_sequence} { // ----- +///---------------------------------------------------------------------------------------- +/// Tests for linalg.fill +///---------------------------------------------------------------------------------------- + // CHECK-LABEL: func @test_vectorize_fill func.func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) { // CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32> @@ -259,70 +283,14 @@ module attributes {transform.with_named_sequence} { // ----- -// CHECK-LABEL: func @test_vectorize_copy -func.func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) { - // CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32> - // CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32> - memref.copy %A, %B : memref<8x16xf32> to memref<8x16xf32> - return -} +///---------------------------------------------------------------------------------------- +/// Tests for linalg.pack +///---------------------------------------------------------------------------------------- -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op - %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op - transform.yield - } -} +// Note, see a similar test in: +// * vectorization.mlir. -// ----- - -// CHECK-LABEL: func @test_vectorize_copy_0d -func.func @test_vectorize_copy_0d(%A : memref<f32>, %B : memref<f32>) { - // CHECK-SAME: (%[[A:.*]]: memref<f32>, %[[B:.*]]: memref<f32>) - // CHECK: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref<f32>, vector<f32> - // CHECK: %[[val:.*]] = vector.extract %[[V]][] : f32 from vector<f32> - // CHECK: %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector<f32> - // CHECK: vector.transfer_write %[[VV]], %[[B]][] : vector<f32>, memref<f32> - memref.copy %A, %B : memref<f32> to memref<f32> - return -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op - %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op - transform.yield - } -} - -// ----- - -// CHECK-LABEL: func @test_vectorize_copy_complex -// CHECK-NOT: vector< -func.func @test_vectorize_copy_complex(%A : memref<8x16xcomplex<f32>>, %B : memref<8x16xcomplex<f32>>) { - memref.copy %A, %B : memref<8x16xcomplex<f32>> to memref<8x16xcomplex<f32>> - return -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op - %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op - transform.yield - } -} - -// ----- - -// Input identical as the test in vectorization.mlir. Output is different - -// vector sizes are inferred (rather than user-specified) and hence _no_ -// masking was used. - -func.func @test_vectorize_pack(%arg0: tensor<32x8x16xf32>, %arg1: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> { +func.func @pack_no_padding(%arg0: tensor<32x8x16xf32>, %arg1: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> { %pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x8x16xf32> -> tensor<4x1x32x16x2xf32> return %pack : tensor<4x1x32x16x2xf32> } @@ -336,7 +304,7 @@ module attributes {transform.with_named_sequence} { } } -// CHECK-LABEL: func.func @test_vectorize_pack( +// CHECK-LABEL: func.func @pack_no_padding( // CHECK-SAME: %[[VAL_0:.*]]: tensor<32x8x16xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> { // CHECK-DAG: %[[VAL_2:.*]] = ub.poison : f32 @@ -349,13 +317,16 @@ module attributes {transform.with_named_sequence} { // ----- -func.func @test_vectorize_padded_pack(%arg0: tensor<32x7x15xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> { +// Note, see a similar test in: +// * vectorization.mlir. + +func.func @pack_with_padding(%arg0: tensor<32x7x15xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> { %pad = arith.constant 0.000000e+00 : f32 %pack = linalg.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32> return %pack : tensor<32x4x1x16x2xf32> } -// CHECK-LABEL: func.func @test_vectorize_padded_pack( +// CHECK-LABEL: func.func @pack_with_padding( // CHECK-SAME: %[[VAL_0:.*]]: tensor<32x7x15xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> { // CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32 @@ -377,6 +348,10 @@ module attributes {transform.with_named_sequence} { // ----- +///---------------------------------------------------------------------------------------- +/// Tests for linalg.map +///---------------------------------------------------------------------------------------- + func.func @vectorize_map(%arg0: memref<64xf32>, %arg1: memref<64xf32>, %arg2: memref<64xf32>) { linalg.map ins(%arg0, %arg1 : memref<64xf32>, memref<64xf32>) @@ -403,6 +378,10 @@ module attributes {transform.with_named_sequence} { // ----- +///---------------------------------------------------------------------------------------- +/// Tests for linalg.transpose +///---------------------------------------------------------------------------------------- + func.func @vectorize_transpose(%arg0: memref<16x32x64xf32>, %arg1: memref<32x64x16xf32>) { linalg.transpose ins(%arg0 : memref<16x32x64xf32>) @@ -424,6 +403,10 @@ module attributes {transform.with_named_sequence} { // ----- +///---------------------------------------------------------------------------------------- +/// Tests for linalg.reduce +///---------------------------------------------------------------------------------------- + func.func @vectorize_reduce(%arg0: memref<16x32x64xf32>, %arg1: memref<16x64xf32>) { linalg.reduce ins(%arg0 : memref<16x32x64xf32>) @@ -449,6 +432,10 @@ module attributes {transform.with_named_sequence} { // ----- +///---------------------------------------------------------------------------------------- +/// Tests for linalg.generic +///---------------------------------------------------------------------------------------- + #matmul_trait = { indexing_maps = [ affine_map<(m, n, k) -> (m, k)>, @@ -1446,6 +1433,8 @@ module attributes {transform.with_named_sequence} { // ----- +// TODO: Two Linalg Ops in one tests - either split or document "why". + // CHECK-DAG: #[[$M6:.*]] = affine_map<(d0, d1) -> (d0, 0)> // CHECK-LABEL: func @fused_broadcast_red_2d @@ -1896,3 +1885,65 @@ module attributes {transform.with_named_sequence} { } } +// ----- + +///---------------------------------------------------------------------------------------- +/// Tests for memref.copy +///---------------------------------------------------------------------------------------- + +// CHECK-LABEL: func @test_vectorize_copy +func.func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) { + // CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32> + // CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32> + memref.copy %A, %B : memref<8x16xf32> to memref<8x16xf32> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func @test_vectorize_copy_0d +func.func @test_vectorize_copy_0d(%A : memref<f32>, %B : memref<f32>) { + // CHECK-SAME: (%[[A:.*]]: memref<f32>, %[[B:.*]]: memref<f32>) + // CHECK: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref<f32>, vector<f32> + // CHECK: %[[val:.*]] = vector.extract %[[V]][] : f32 from vector<f32> + // CHECK: %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector<f32> + // CHECK: vector.transfer_write %[[VV]], %[[B]][] : vector<f32>, memref<f32> + memref.copy %A, %B : memref<f32> to memref<f32> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func @test_vectorize_copy_complex +// CHECK-NOT: vector< +func.func @test_vectorize_copy_complex(%A : memref<8x16xcomplex<f32>>, %B : memref<8x16xcomplex<f32>>) { + memref.copy %A, %B : memref<8x16xcomplex<f32>> to memref<8x16xcomplex<f32>> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op + transform.yield + } +} diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir index 11bea8d..1304a90 100644 --- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir +++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir @@ -1307,14 +1307,17 @@ func.func @test_vectorize_unpack_no_vector_sizes_permute(%source: tensor<4x7x4xf /// Tests for linalg.pack ///---------------------------------------------------------------------------------------- -// Input identical as the test in vectorization-with-patterns.mlir. Output is -// different - vector sizes are inferred (rather than user-specified) and hence -// masking was used. +// This packing requires no padding, so no out-of-bounds read/write vector Ops. -// CHECK-LABEL: func @test_vectorize_pack +// Note, see a similar test in: +// * vectorization-with-patterns.mlir +// The output is identical (the input vector sizes == the inferred vector +// sizes, i.e. the tensor sizes). + +// CHECK-LABEL: func @pack_no_padding // CHECK-SAME: %[[SRC:.*]]: tensor<32x8x16xf32>, // CHECK-SAME: %[[DEST:.*]]: tensor<4x1x32x16x2xf32> -func.func @test_vectorize_pack(%src: tensor<32x8x16xf32>, %dest: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> { +func.func @pack_no_padding(%src: tensor<32x8x16xf32>, %dest: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> { %pack = linalg.pack %src outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %dest : tensor<32x8x16xf32> -> tensor<4x1x32x16x2xf32> return %pack : tensor<4x1x32x16x2xf32> } @@ -1325,9 +1328,9 @@ func.func @test_vectorize_pack(%src: tensor<32x8x16xf32>, %dest: tensor<4x1x32x1 // CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32> // CHECK: %[[TR:.*]] = vector.transpose %[[SC]], [1, 3, 0, 4, 2] : vector<32x4x2x1x16xf32> to vector<4x1x32x16x2xf32> // CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index -// CHECK: %[[write:.*]] = vector.transfer_write %[[TR]], %[[DEST]][%[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]]] +// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[TR]], %[[DEST]][%[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]]] // CHECK-SAME: {in_bounds = [true, true, true, true, true]} : vector<4x1x32x16x2xf32>, tensor<4x1x32x16x2xf32> -// CHECK: return %[[write]] : tensor<4x1x32x16x2xf32> +// CHECK: return %[[WRITE]] : tensor<4x1x32x16x2xf32> module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%src: !transform.any_op {transform.readonly}) { @@ -1339,14 +1342,18 @@ module attributes {transform.with_named_sequence} { // ----- -// Input identical as the test in vectorization-with-patterns.mlir. Output is -// different - vector sizes are inferred (rather than user-specified) and hence -// masking was used. +// This packing does require padding, so there are out-of-bounds read/write +// vector Ops. + +// Note, see a similar test in: +// * vectorization-with-patterns.mlir. +// The output is different (the input vector sizes != inferred vector sizes, +// i.e. the tensor sizes). -// CHECK-LABEL: func @test_vectorize_padded_pack +// CHECK-LABEL: func @pack_with_padding // CHECK-SAME: %[[SRC:.*]]: tensor<32x7x15xf32>, // CHECK-SAME: %[[DEST:.*]]: tensor<32x4x1x16x2xf32> -func.func @test_vectorize_padded_pack(%src: tensor<32x7x15xf32>, %dest: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> { +func.func @pack_with_padding(%src: tensor<32x7x15xf32>, %dest: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> { %pad = arith.constant 0.000000e+00 : f32 %pack = linalg.pack %src padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %dest : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32> return %pack : tensor<32x4x1x16x2xf32> @@ -1364,9 +1371,9 @@ func.func @test_vectorize_padded_pack(%src: tensor<32x7x15xf32>, %dest: tensor<3 // CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32> // CHECK: %[[TR:.*]] = vector.transpose %[[SC]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32> // CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index -// CHECK: %[[write:.*]] = vector.transfer_write %[[TR]], %[[DEST]][%[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]]] +// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[TR]], %[[DEST]][%[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]]] // CHECK-SAME: {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32> -// CHECK: return %[[write]] : tensor<32x4x1x16x2xf32> +// CHECK: return %[[WRITE]] : tensor<32x4x1x16x2xf32> module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { @@ -1378,10 +1385,46 @@ module attributes {transform.with_named_sequence} { // ----- -// CHECK-LABEL: func @test_vectorize_dynamic_pack +// This packing does require padding, so there are out-of-bounds read/write +// vector Ops. + +// Note, see a similar test in: +// * vectorization-with-patterns.mlir. +// The output is identical (in both cases the vector sizes are inferred). + +// CHECK-LABEL: func @pack_with_padding_no_vector_sizes +// CHECK-SAME: %[[SRC:.*]]: tensor<32x7x15xf32>, +// CHECK-SAME: %[[DEST:.*]]: tensor<32x4x1x16x2xf32> +func.func @pack_with_padding_no_vector_sizes(%src: tensor<32x7x15xf32>, %dest: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> { + %pad = arith.constant 0.000000e+00 : f32 + %pack = linalg.pack %src padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %dest : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32> + return %pack : tensor<32x4x1x16x2xf32> +} +// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[READ:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[CST]] +// CHECK-SAME: {in_bounds = [true, false, false]} : tensor<32x7x15xf32>, vector<32x8x16xf32> +// CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32> +// CHECK: %[[TR:.*]] = vector.transpose %[[SC]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32> +// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[TR]], %[[DEST]][%[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]]] +// CHECK-SAME: {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32> +// CHECK: return %[[WRITE]] : tensor<32x4x1x16x2xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.structured.vectorize %0 : !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func @pack_with_dynamic_dims // CHECK-SAME: %[[SRC:.*]]: tensor<?x?xf32>, // CHECK-SAME: %[[DEST:.*]]: tensor<?x?x16x2xf32> -func.func @test_vectorize_dynamic_pack(%src: tensor<?x?xf32>, %dest: tensor<?x?x16x2xf32>) -> tensor<?x?x16x2xf32> { +func.func @pack_with_dynamic_dims(%src: tensor<?x?xf32>, %dest: tensor<?x?x16x2xf32>) -> tensor<?x?x16x2xf32> { %pack = linalg.pack %src inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %dest : tensor<?x?xf32> -> tensor<?x?x16x2xf32> return %pack : tensor<?x?x16x2xf32> } @@ -1418,64 +1461,6 @@ module attributes {transform.with_named_sequence} { } } -// ----- - -// CHECK-LABEL: func @test_vectorize_pack_no_vector_sizes -// CHECK-SAME: %[[SRC:.*]]: tensor<64x4xf32>, -// CHECK-SAME: %[[DEST:.*]]: tensor<2x4x16x2xf32> -func.func @test_vectorize_pack_no_vector_sizes(%src: tensor<64x4xf32>, %dest: tensor<2x4x16x2xf32>) -> tensor<2x4x16x2xf32> { - %pack = linalg.pack %src outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [16, 2] into %dest : tensor<64x4xf32> -> tensor<2x4x16x2xf32> - return %pack : tensor<2x4x16x2xf32> -} -// CHECK-DAG: %[[CST:.*]] = ub.poison : f32 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[READ:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CST]] -// CHECK-SAME: {in_bounds = [true, true]} : tensor<64x4xf32>, vector<64x4xf32> -// CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<64x4xf32> to vector<4x16x2x2xf32> -// CHECK: %[[TR:.*]] = vector.transpose %[[SC]], [2, 0, 1, 3] : vector<4x16x2x2xf32> to vector<2x4x16x2xf32> -// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index -// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[TR]], %[[DEST]][%[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]]] -// CHECK-SAME: {in_bounds = [true, true, true, true]} : vector<2x4x16x2xf32>, tensor<2x4x16x2xf32> -// CHECK: return %[[WRITE]] : tensor<2x4x16x2xf32> - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op - transform.structured.vectorize %0 : !transform.any_op - transform.yield - } -} - -// ----- - -// CHECK-LABEL: test_vectorize_padded_pack_no_vector_sizes -// CHECK-SAME: %[[SRC:.*]]: tensor<32x7x15xf32>, -// CHECK-SAME: %[[DEST:.*]]: tensor<32x4x1x16x2xf32> -func.func @test_vectorize_padded_pack_no_vector_sizes(%src: tensor<32x7x15xf32>, %dest: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> { - %pad = arith.constant 0.000000e+00 : f32 - %pack = linalg.pack %src padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %dest : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32> - return %pack : tensor<32x4x1x16x2xf32> -} -// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[READ:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[CST]] -// CHECK-SAME: {in_bounds = [true, false, false]} : tensor<32x7x15xf32>, vector<32x8x16xf32> -// CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32> -// CHECK: %[[TR:.*]] = vector.transpose %[[SC]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32> -// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index -// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[TR]], %[[DEST]][%[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]]] -// CHECK-SAME: {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32> -// CHECK: return %[[WRITE]] : tensor<32x4x1x16x2xf32> - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op - transform.structured.vectorize %0 : !transform.any_op - transform.yield - } -} - - ///---------------------------------------------------------------------------------------- /// Tests for other Ops ///---------------------------------------------------------------------------------------- diff --git a/mlir/test/Dialect/OpenACC/pointer-like-interface-alloc.mlir b/mlir/test/Dialect/OpenACC/pointer-like-interface-alloc.mlir index 603ace8..3d4bec7 100644 --- a/mlir/test/Dialect/OpenACC/pointer-like-interface-alloc.mlir +++ b/mlir/test/Dialect/OpenACC/pointer-like-interface-alloc.mlir @@ -3,7 +3,7 @@ func.func @test_static_memref_alloc() { %0 = memref.alloca() {test.ptr} : memref<10x20xf32> // CHECK: Successfully generated alloc for operation: %[[ORIG:.*]] = memref.alloca() {test.ptr} : memref<10x20xf32> - // CHECK: Generated: %{{.*}} = memref.alloca() : memref<10x20xf32> + // CHECK: Generated: %{{.*}} = memref.alloca() {acc.var_name = #acc.var_name<"test_alloc">} : memref<10x20xf32> return } @@ -19,6 +19,6 @@ func.func @test_dynamic_memref_alloc() { // CHECK: Generated: %[[DIM0:.*]] = memref.dim %[[ORIG]], %[[C0]] : memref<?x?xf32> // CHECK: Generated: %[[C1:.*]] = arith.constant 1 : index // CHECK: Generated: %[[DIM1:.*]] = memref.dim %[[ORIG]], %[[C1]] : memref<?x?xf32> - // CHECK: Generated: %{{.*}} = memref.alloc(%[[DIM0]], %[[DIM1]]) : memref<?x?xf32> + // CHECK: Generated: %{{.*}} = memref.alloc(%[[DIM0]], %[[DIM1]]) {acc.var_name = #acc.var_name<"test_alloc">} : memref<?x?xf32> return } diff --git a/mlir/test/Dialect/OpenACC/recipe-populate-firstprivate.mlir b/mlir/test/Dialect/OpenACC/recipe-populate-firstprivate.mlir new file mode 100644 index 0000000..8846c9e --- /dev/null +++ b/mlir/test/Dialect/OpenACC/recipe-populate-firstprivate.mlir @@ -0,0 +1,102 @@ +// RUN: mlir-opt %s --split-input-file --pass-pipeline="builtin.module(test-acc-recipe-populate{recipe-type=firstprivate})" | FileCheck %s + +// CHECK: acc.firstprivate.recipe @firstprivate_scalar : memref<f32> init { +// CHECK: ^bb0(%{{.*}}: memref<f32>): +// CHECK: %[[ALLOC:.*]] = memref.alloca() {acc.var_name = #acc.var_name<"scalar">} : memref<f32> +// CHECK: acc.yield %[[ALLOC]] : memref<f32> +// CHECK: } copy { +// CHECK: ^bb0(%[[SRC:.*]]: memref<f32>, %[[DST:.*]]: memref<f32>): +// CHECK: memref.copy %[[SRC]], %[[DST]] : memref<f32> to memref<f32> +// CHECK: acc.terminator +// CHECK: } +// CHECK-NOT: destroy + +func.func @test_scalar() { + %0 = memref.alloca() {test.var = "scalar"} : memref<f32> + return +} + +// ----- + +// CHECK: acc.firstprivate.recipe @firstprivate_static_2d : memref<10x20xf32> init { +// CHECK: ^bb0(%{{.*}}: memref<10x20xf32>): +// CHECK: %[[ALLOC:.*]] = memref.alloca() {acc.var_name = #acc.var_name<"static_2d">} : memref<10x20xf32> +// CHECK: acc.yield %[[ALLOC]] : memref<10x20xf32> +// CHECK: } copy { +// CHECK: ^bb0(%[[SRC:.*]]: memref<10x20xf32>, %[[DST:.*]]: memref<10x20xf32>): +// CHECK: memref.copy %[[SRC]], %[[DST]] : memref<10x20xf32> to memref<10x20xf32> +// CHECK: acc.terminator +// CHECK: } +// CHECK-NOT: destroy + +func.func @test_static_2d() { + %0 = memref.alloca() {test.var = "static_2d"} : memref<10x20xf32> + return +} + +// ----- + +// CHECK: acc.firstprivate.recipe @firstprivate_dynamic_2d : memref<?x?xf32> init { +// CHECK: ^bb0(%[[ARG:.*]]: memref<?x?xf32>): +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM0:.*]] = memref.dim %[[ARG]], %[[C0]] : memref<?x?xf32> +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<?x?xf32> +// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) {acc.var_name = #acc.var_name<"dynamic_2d">} : memref<?x?xf32> +// CHECK: acc.yield %[[ALLOC]] : memref<?x?xf32> +// CHECK: } copy { +// CHECK: ^bb0(%[[SRC:.*]]: memref<?x?xf32>, %[[DST:.*]]: memref<?x?xf32>): +// CHECK: memref.copy %[[SRC]], %[[DST]] : memref<?x?xf32> to memref<?x?xf32> +// CHECK: acc.terminator +// CHECK: } destroy { +// CHECK: ^bb0(%{{.*}}: memref<?x?xf32>, %[[VAL:.*]]: memref<?x?xf32>): +// CHECK: memref.dealloc %[[VAL]] : memref<?x?xf32> +// CHECK: acc.terminator +// CHECK: } + +func.func @test_dynamic_2d(%arg0: index, %arg1: index) { + %0 = memref.alloc(%arg0, %arg1) {test.var = "dynamic_2d"} : memref<?x?xf32> + return +} + +// ----- + +// CHECK: acc.firstprivate.recipe @firstprivate_mixed_dims : memref<10x?xf32> init { +// CHECK: ^bb0(%[[ARG:.*]]: memref<10x?xf32>): +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<10x?xf32> +// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM1]]) {acc.var_name = #acc.var_name<"mixed_dims">} : memref<10x?xf32> +// CHECK: acc.yield %[[ALLOC]] : memref<10x?xf32> +// CHECK: } copy { +// CHECK: ^bb0(%[[SRC:.*]]: memref<10x?xf32>, %[[DST:.*]]: memref<10x?xf32>): +// CHECK: memref.copy %[[SRC]], %[[DST]] : memref<10x?xf32> to memref<10x?xf32> +// CHECK: acc.terminator +// CHECK: } destroy { +// CHECK: ^bb0(%{{.*}}: memref<10x?xf32>, %[[VAL:.*]]: memref<10x?xf32>): +// CHECK: memref.dealloc %[[VAL]] : memref<10x?xf32> +// CHECK: acc.terminator +// CHECK: } + +func.func @test_mixed_dims(%arg0: index) { + %0 = memref.alloc(%arg0) {test.var = "mixed_dims"} : memref<10x?xf32> + return +} + +// ----- + +// CHECK: acc.firstprivate.recipe @firstprivate_scalar_int : memref<i32> init { +// CHECK: ^bb0(%{{.*}}: memref<i32>): +// CHECK: %[[ALLOC:.*]] = memref.alloca() {acc.var_name = #acc.var_name<"scalar_int">} : memref<i32> +// CHECK: acc.yield %[[ALLOC]] : memref<i32> +// CHECK: } copy { +// CHECK: ^bb0(%[[SRC:.*]]: memref<i32>, %[[DST:.*]]: memref<i32>): +// CHECK: memref.copy %[[SRC]], %[[DST]] : memref<i32> to memref<i32> +// CHECK: acc.terminator +// CHECK: } +// CHECK-NOT: destroy + +func.func @test_scalar_int() { + %0 = memref.alloca() {test.var = "scalar_int"} : memref<i32> + return +} + diff --git a/mlir/test/Dialect/OpenACC/recipe-populate-private.mlir b/mlir/test/Dialect/OpenACC/recipe-populate-private.mlir new file mode 100644 index 0000000..3d5a918 --- /dev/null +++ b/mlir/test/Dialect/OpenACC/recipe-populate-private.mlir @@ -0,0 +1,82 @@ +// RUN: mlir-opt %s --split-input-file --pass-pipeline="builtin.module(test-acc-recipe-populate{recipe-type=private})" | FileCheck %s + +// CHECK: acc.private.recipe @private_scalar : memref<f32> init { +// CHECK: ^bb0(%{{.*}}: memref<f32>): +// CHECK: %[[ALLOC:.*]] = memref.alloca() {acc.var_name = #acc.var_name<"scalar">} : memref<f32> +// CHECK: acc.yield %[[ALLOC]] : memref<f32> +// CHECK: } +// CHECK-NOT: destroy + +func.func @test_scalar() { + %0 = memref.alloca() {test.var = "scalar"} : memref<f32> + return +} + +// ----- + +// CHECK: acc.private.recipe @private_static_2d : memref<10x20xf32> init { +// CHECK: ^bb0(%{{.*}}: memref<10x20xf32>): +// CHECK: %[[ALLOC:.*]] = memref.alloca() {acc.var_name = #acc.var_name<"static_2d">} : memref<10x20xf32> +// CHECK: acc.yield %[[ALLOC]] : memref<10x20xf32> +// CHECK: } +// CHECK-NOT: destroy + +func.func @test_static_2d() { + %0 = memref.alloca() {test.var = "static_2d"} : memref<10x20xf32> + return +} + +// ----- + +// CHECK: acc.private.recipe @private_dynamic_2d : memref<?x?xf32> init { +// CHECK: ^bb0(%[[ARG:.*]]: memref<?x?xf32>): +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM0:.*]] = memref.dim %[[ARG]], %[[C0]] : memref<?x?xf32> +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<?x?xf32> +// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) {acc.var_name = #acc.var_name<"dynamic_2d">} : memref<?x?xf32> +// CHECK: acc.yield %[[ALLOC]] : memref<?x?xf32> +// CHECK: } destroy { +// CHECK: ^bb0(%{{.*}}: memref<?x?xf32>, %[[VAL:.*]]: memref<?x?xf32>): +// CHECK: memref.dealloc %[[VAL]] : memref<?x?xf32> +// CHECK: acc.terminator +// CHECK: } + +func.func @test_dynamic_2d(%arg0: index, %arg1: index) { + %0 = memref.alloc(%arg0, %arg1) {test.var = "dynamic_2d"} : memref<?x?xf32> + return +} + +// ----- + +// CHECK: acc.private.recipe @private_mixed_dims : memref<10x?xf32> init { +// CHECK: ^bb0(%[[ARG:.*]]: memref<10x?xf32>): +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<10x?xf32> +// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM1]]) {acc.var_name = #acc.var_name<"mixed_dims">} : memref<10x?xf32> +// CHECK: acc.yield %[[ALLOC]] : memref<10x?xf32> +// CHECK: } destroy { +// CHECK: ^bb0(%{{.*}}: memref<10x?xf32>, %[[VAL:.*]]: memref<10x?xf32>): +// CHECK: memref.dealloc %[[VAL]] : memref<10x?xf32> +// CHECK: acc.terminator +// CHECK: } + +func.func @test_mixed_dims(%arg0: index) { + %0 = memref.alloc(%arg0) {test.var = "mixed_dims"} : memref<10x?xf32> + return +} + +// ----- + +// CHECK: acc.private.recipe @private_scalar_int : memref<i32> init { +// CHECK: ^bb0(%{{.*}}: memref<i32>): +// CHECK: %[[ALLOC:.*]] = memref.alloca() {acc.var_name = #acc.var_name<"scalar_int">} : memref<i32> +// CHECK: acc.yield %[[ALLOC]] : memref<i32> +// CHECK: } +// CHECK-NOT: destroy + +func.func @test_scalar_int() { + %0 = memref.alloca() {test.var = "scalar_int"} : memref<i32> + return +} + diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir index a1067ec..af09dc8 100644 --- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir @@ -8,11 +8,11 @@ // Test bufferization using memref types that have no layout map. // RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs-from-loops unknown-type-conversion=identity-layout-map function-boundary-type-conversion=identity-layout-map bufferize-function-boundaries" -split-input-file -o /dev/null -// CHECK-LABEL: func @scf_for_yield_only( +// CHECK-LABEL: func private @scf_for_yield_only( // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>>, // CHECK-SAME: %[[t:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>> // CHECK-SAME: ) -> memref<?xf32> { -func.func @scf_for_yield_only( +func.func private @scf_for_yield_only( %A : tensor<?xf32> {bufferization.writable = false}, %B : tensor<?xf32> {bufferization.writable = true}, %lb : index, %ub : index, %step : index) @@ -85,11 +85,11 @@ func.func @nested_scf_for(%A : tensor<?xf32> {bufferization.writable = true}, // ----- -// CHECK-LABEL: func @scf_for_with_tensor.insert_slice +// CHECK-LABEL: func private @scf_for_with_tensor.insert_slice // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>> // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>> // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: memref<4xf32, strided<[?], offset: ?>> -func.func @scf_for_with_tensor.insert_slice( +func.func private @scf_for_with_tensor.insert_slice( %A : tensor<?xf32> {bufferization.writable = false}, %B : tensor<?xf32> {bufferization.writable = true}, %C : tensor<4xf32> {bufferization.writable = false}, @@ -471,11 +471,11 @@ func.func @scf_while_iter_arg_result_mismatch(%arg0: tensor<5xi1>, // ----- -// CHECK-LABEL: func.func @parallel_insert_slice_no_conflict( +// CHECK-LABEL: func private @parallel_insert_slice_no_conflict( // CHECK-SAME: %[[idx:.*]]: index, %[[idx2:.*]]: index, // CHECK-SAME: %[[arg1:.*]]: memref<?xf32, strided{{.*}}>, // CHECK-SAME: %[[arg2:.*]]: memref<?xf32, strided{{.*}}> -func.func @parallel_insert_slice_no_conflict( +func.func private @parallel_insert_slice_no_conflict( %idx: index, %idx2: index, %arg1: tensor<?xf32> {bufferization.writable = true}, diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir index 5f95da2..b6c72be 100644 --- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir @@ -8,12 +8,12 @@ // Test bufferization using memref types that have no layout map. // RUN: mlir-opt %s -one-shot-bufferize="unknown-type-conversion=identity-layout-map bufferize-function-boundaries" -split-input-file -o /dev/null -// CHECK-LABEL: func @insert_slice_fun +// CHECK-LABEL: func private @insert_slice_fun // CHECK-SAME: %[[A0:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>>, // CHECK-SAME: %[[A1:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>>, // CHECK-SAME: %[[t0:[a-zA-Z0-9]*]]: memref<4xf32, strided<[?], offset: ?>>, // CHECK-SAME: %[[t1:[a-zA-Z0-9]*]]: memref<4xf32, strided<[?], offset: ?>> -func.func @insert_slice_fun( +func.func private @insert_slice_fun( %A0 : tensor<?xf32> {bufferization.writable = false}, %A1 : tensor<?xf32> {bufferization.writable = true}, %t0 : tensor<4xf32> {bufferization.writable = false}, @@ -331,12 +331,12 @@ func.func @dim_not_reading(%t: tensor<?xf32>, %f: f32, %pos: index) // ----- // CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0 + 5)> -// CHECK-LABEL: func.func @cast_retains_buffer_layout( +// CHECK-LABEL: func.func private @cast_retains_buffer_layout( // CHECK-SAME: %[[t:.*]]: memref<?xf32, #[[$map]]>, %[[sz:.*]]: index) -> memref<?xf32, strided<[1], offset: 7>> { // CHECK: %[[casted:.*]] = memref.cast %[[t]] : memref<?xf32, #[[$map]]> to memref<10xf32, #[[$map]]> // CHECK: %[[slice:.*]] = memref.subview %[[casted]][2] [%[[sz]]] [1] : memref<10xf32, #[[$map]]> to memref<?xf32, strided<[1], offset: 7>> // CHECK: return %[[slice]] -func.func @cast_retains_buffer_layout( +func.func private @cast_retains_buffer_layout( %t: tensor<?xf32> {bufferization.buffer_layout = affine_map<(d0) -> (d0 + 5)>}, %sz: index) @@ -353,12 +353,12 @@ func.func @cast_retains_buffer_layout( // ----- -// CHECK-LABEL: func.func @cast_retains_buffer_layout_strided( +// CHECK-LABEL: func private @cast_retains_buffer_layout_strided( // CHECK-SAME: %[[t:.*]]: memref<?xf32, strided<[1], offset: 5>>, %[[sz:.*]]: index) -> memref<?xf32, strided<[1], offset: 7>> { // CHECK: %[[casted:.*]] = memref.cast %[[t]] : memref<?xf32, strided<[1], offset: 5>> to memref<10xf32, strided<[1], offset: 5>> // CHECK: %[[slice:.*]] = memref.subview %[[casted]][2] [%[[sz]]] [1] : memref<10xf32, strided<[1], offset: 5>> to memref<?xf32, strided<[1], offset: 7>> // CHECK: return %[[slice]] -func.func @cast_retains_buffer_layout_strided( +func.func private @cast_retains_buffer_layout_strided( %t: tensor<?xf32> {bufferization.buffer_layout = strided<[1], offset: 5>}, %sz: index) diff --git a/mlir/test/Dialect/Tensor/tiling.mlir b/mlir/test/Dialect/Tensor/tiling.mlir index 04a99b5..32fb0c9 100644 --- a/mlir/test/Dialect/Tensor/tiling.mlir +++ b/mlir/test/Dialect/Tensor/tiling.mlir @@ -149,7 +149,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %copy = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b, %c = transform.structured.fuse %copy [2, 3] + %a, %b, %c = transform.structured.fuse %copy tile_sizes [2, 3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) transform.yield } diff --git a/mlir/test/Dialect/Tosa/tosa-attach-target.mlir b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir index d6c886c..a0c59c0 100644 --- a/mlir/test/Dialect/Tosa/tosa-attach-target.mlir +++ b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir @@ -1,12 +1,14 @@ // RUN: mlir-opt %s -split-input-file -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,dynamic level=none" | FileCheck %s --check-prefix=CHECK-ALL // RUN: mlir-opt %s -split-input-file -tosa-attach-target="level=8k" | FileCheck %s --check-prefix=CHECK-LVL-8K // RUN: mlir-opt %s -split-input-file -tosa-attach-target | FileCheck %s --check-prefix=CHECK-DEFAULT +// RUN: mlir-opt %s -split-input-file -tosa-attach-target="specification_version=1.1.draft" | FileCheck %s --check-prefix=CHECK-VERSION-1P1 // ----- -// CHECK-ALL: module attributes {tosa.target_env = #tosa.target_env<level = none, profiles = [pro_int, pro_fp], extensions = [int16, int4, bf16, fp8e4m3, fp8e5m2, fft, variable, controlflow, doubleround, inexactround, dynamic]>} -// CHECK-LVL-8K: module attributes {tosa.target_env = #tosa.target_env<level = "8k", profiles = [], extensions = []>} -// CHECK-DEFAULT: module attributes {tosa.target_env = #tosa.target_env<level = "8k", profiles = [], extensions = []>} +// CHECK-ALL: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.0", level = none, profiles = [pro_int, pro_fp], extensions = [int16, int4, bf16, fp8e4m3, fp8e5m2, fft, variable, controlflow, doubleround, inexactround, dynamic]>} +// CHECK-LVL-8K: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.0", level = "8k", profiles = [], extensions = []>} +// CHECK-DEFAULT: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.0", level = "8k", profiles = [], extensions = []>} +// CHECK-VERSION-1P1: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.1.draft", level = "8k", profiles = [], extensions = []>} // CHECK-LABEL: test_simple func.func @test_simple(%arg0 : tensor<1x1x1x1xf32>, %arg1 : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> { %1 = tosa.add %arg0, %arg1 : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir new file mode 100644 index 0000000..51089df --- /dev/null +++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.0 profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" + +// ----- + +func.func @test_matmul_fp8_mixed_precision_operands(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> { + %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN> + %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2> + // expected-error@+1 {{'tosa.matmul' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}} + %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E5M2>, tensor<1xf8E4M3FN>, tensor<1xf8E5M2>) -> tensor<1x14x28xf16> + return %0 : tensor<1x14x28xf16> +} + +// ----- + +func.func @test_matmul_fp8_input_fp32_acc_type(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf32> { + %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN> + %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN> + // expected-error@+1 {{'tosa.matmul' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}} + %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E4M3FN>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x14x28xf32> + return %0 : tensor<1x14x28xf32> +} diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir new file mode 100644 index 0000000..8164509 --- /dev/null +++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir @@ -0,0 +1,20 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" | FileCheck %s + +// ----- + +func.func @test_matmul_fp8_mixed_precision_operands(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> { + %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN> + %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2> + %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E5M2>, tensor<1xf8E4M3FN>, tensor<1xf8E5M2>) -> tensor<1x14x28xf16> + return %0 : tensor<1x14x28xf16> +} + +// ----- + +// CHECK-LABEL: test_matmul_fp8_input_fp32_acc_type +func.func @test_matmul_fp8_input_fp32_acc_type(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf32> { + %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN> + %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN> + %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E4M3FN>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x14x28xf32> + return %0 : tensor<1x14x28xf32> +} diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir new file mode 100644 index 0000000..023a0e5 --- /dev/null +++ b/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir @@ -0,0 +1,311 @@ +// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file | FileCheck %s + +///===----------------------------------------------===// +/// Tests of `StepCompareFolder` +///===----------------------------------------------===// + + +///===------------------------------------===// +/// Tests of `ugt` (unsigned greater than) +///===------------------------------------===// + +// CHECK-LABEL: @ugt_constant_3_lhs +// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @ugt_constant_3_lhs() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // 3 > [0, 1, 2] => [true, true, true] => true for all indices => fold + %1 = arith.cmpi ugt, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @negative_ugt_constant_2_lhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @negative_ugt_constant_2_lhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // 2 > [0, 1, 2] => [true, true, false] => not same for all indices => don't fold + %1 = arith.cmpi ugt, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @ugt_constant_3_rhs +// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @ugt_constant_3_rhs() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // [0, 1, 2] > 3 => [false, false, false] => false for all indices => fold + %1 = arith.cmpi ugt, %0, %cst : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @ugt_constant_max_rhs +// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @ugt_constant_max_rhs() -> vector<3xi1> { + // The largest i64 possible: + %cst = arith.constant dense<0x7fffffffffffffff> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ugt, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + + +// ----- + +// CHECK-LABEL: @ugt_constant_2_rhs +// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @ugt_constant_2_rhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // [0, 1, 2] > 2 => [false, false, false] => false for all indices => fold + %1 = arith.cmpi ugt, %0, %cst : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @negative_ugt_constant_1_rhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @negative_ugt_constant_1_rhs() -> vector<3xi1> { + %cst = arith.constant dense<1> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // [0, 1, 2] > 1 => [false, false, true] => not same for all indices => don't fold + %1 = arith.cmpi ugt, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +///===------------------------------------===// +/// Tests of `uge` (unsigned greater than or equal) +///===------------------------------------===// + + +// CHECK-LABEL: @uge_constant_2_lhs +// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @uge_constant_2_lhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // 2 >= [0, 1, 2] => [true, true, true] => true for all indices => fold + %1 = arith.cmpi uge, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @negative_uge_constant_1_lhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @negative_uge_constant_1_lhs() -> vector<3xi1> { + %cst = arith.constant dense<1> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // 1 >= [0, 1, 2] => [true, false, false] => not same for all indices => don't fold + %1 = arith.cmpi uge, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @uge_constant_3_rhs +// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @uge_constant_3_rhs() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // [0, 1, 2] >= 3 => [false, false, false] => false for all indices => fold + %1 = arith.cmpi uge, %0, %cst : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @negative_uge_constant_2_rhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @negative_uge_constant_2_rhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // [0, 1, 2] >= 2 => [false, false, true] => not same for all indices => don't fold + %1 = arith.cmpi uge, %0, %cst : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + + +///===------------------------------------===// +/// Tests of `ult` (unsigned less than) +///===------------------------------------===// + + +// CHECK-LABEL: @ult_constant_2_lhs +// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @ult_constant_2_lhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // 2 < [0, 1, 2] => [false, false, false] => false for all indices => fold + %1 = arith.cmpi ult, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @negative_ult_constant_1_lhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @negative_ult_constant_1_lhs() -> vector<3xi1> { + %cst = arith.constant dense<1> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // 1 < [0, 1, 2] => [false, false, true] => not same for all indices => don't fold + %1 = arith.cmpi ult, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @ult_constant_3_rhs +// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @ult_constant_3_rhs() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // [0, 1, 2] < 3 => [true, true, true] => true for all indices => fold + %1 = arith.cmpi ult, %0, %cst : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @negative_ult_constant_2_rhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @negative_ult_constant_2_rhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // [0, 1, 2] < 2 => [true, true, false] => not same for all indices => don't fold + %1 = arith.cmpi ult, %0, %cst : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +///===------------------------------------===// +/// Tests of `ule` (unsigned less than or equal) +///===------------------------------------===// + +// CHECK-LABEL: @ule_constant_3_lhs +// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @ule_constant_3_lhs() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ule, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @negative_ule_constant_2_lhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @negative_ule_constant_2_lhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ule, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @ule_constant_2_rhs +// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @ule_constant_2_rhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ule, %0, %cst : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @negative_ule_constant_1_rhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @negative_ule_constant_1_rhs() -> vector<3xi1> { + %cst = arith.constant dense<1> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ule, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +///===------------------------------------===// +/// Tests of `eq` (equal) +///===------------------------------------===// + +// CHECK-LABEL: @eq_constant_3 +// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @eq_constant_3() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi eq, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @negative_eq_constant_2 +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @negative_eq_constant_2() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi eq, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +///===------------------------------------===// +/// Tests of `ne` (not equal) +///===------------------------------------===// + +// CHECK-LABEL: @ne_constant_3 +// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @ne_constant_3() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ne, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @negative_ne_constant_2 +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @negative_ne_constant_2() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ne, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir index 35db14e..e5a98b5 100644 --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -188,15 +188,38 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf // CHECK-LABEL: func @vector_fma // CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32> -// TODO: We should be able to unroll this like the example above - this will require extending UnrollElementwisePattern. -func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{ +func.func @vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{ %0 = vector.fma %a, %a, %a : vector<3x2x2xf32> return %0 : vector<3x2x2xf32> } -// CHECK-LABEL: func @negative_vector_fma_3d -// CHECK-NOT: vector.extract_strided_slice -// CHECK: %[[R0:.*]] = vector.fma %{{.+}} : vector<3x2x2xf32> -// CHECK: return +// CHECK-LABEL: func @vector_fma_3d +// CHECK-SAME: (%[[SRC:.*]]: vector<3x2x2xf32>) -> vector<3x2x2xf32> { +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<3x2x2xf32> +// CHECK: %[[E_LHS_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32> +// CHECK: %[[S_LHS_0:.*]] = vector.shape_cast %[[E_LHS_0]] : vector<1x2x2xf32> to vector<2x2xf32> +// CHECK: %[[E_RHS_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32> +// CHECK: %[[S_RHS_0:.*]] = vector.shape_cast %[[E_RHS_0]] : vector<1x2x2xf32> to vector<2x2xf32> +// CHECK: %[[E_OUT_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32> +// CHECK: %[[S_OUT_0:.*]] = vector.shape_cast %[[E_OUT_0]] : vector<1x2x2xf32> to vector<2x2xf32> +// CHECK: %[[FMA0:.*]] = vector.fma %[[S_LHS_0]], %[[S_RHS_0]], %[[S_OUT_0]] : vector<2x2xf32> +// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[FMA0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32> +// CHECK: %[[E_LHS_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32> +// CHECK: %[[S_LHS_1:.*]] = vector.shape_cast %[[E_LHS_1]] : vector<1x2x2xf32> to vector<2x2xf32> +// CHECK: %[[E_RHS_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32> +// CHECK: %[[S_RHS_1:.*]] = vector.shape_cast %[[E_RHS_1]] : vector<1x2x2xf32> to vector<2x2xf32> +// CHECK: %[[E_OUT_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32> +// CHECK: %[[S_OUT_1:.*]] = vector.shape_cast %[[E_OUT_1]] : vector<1x2x2xf32> to vector<2x2xf32> +// CHECK: %[[FMA1:.*]] = vector.fma %[[S_LHS_1]], %[[S_RHS_1]], %[[S_OUT_1]] : vector<2x2xf32> +// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[FMA1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32> +// CHECK: %[[E_LHS_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32> +// CHECK: %[[S_LHS_2:.*]] = vector.shape_cast %[[E_LHS_2]] : vector<1x2x2xf32> to vector<2x2xf32> +// CHECK: %[[E_RHS_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32> +// CHECK: %[[S_RHS_2:.*]] = vector.shape_cast %[[E_RHS_2]] : vector<1x2x2xf32> to vector<2x2xf32> +// CHECK: %[[E_OUT_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32> +// CHECK: %[[S_OUT_2:.*]] = vector.shape_cast %[[E_OUT_2]] : vector<1x2x2xf32> to vector<2x2xf32> +// CHECK: %[[FMA2:.*]] = vector.fma %[[S_LHS_2]], %[[S_RHS_2]], %[[S_OUT_2]] : vector<2x2xf32> +// CHECK: %[[I2:.*]] = vector.insert_strided_slice %[[FMA2]], %[[I1]] {offsets = [2, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32> +// CHECK: return %[[I2]] : vector<3x2x2xf32> func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> { %0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32> @@ -440,3 +463,36 @@ func.func @vector_step() -> vector<32xindex> { // CHECK: %[[ADD3:.*]] = arith.addi %[[STEP]], %[[CST]] : vector<8xindex> // CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex> // CHECK: return %[[INS3]] : vector<32xindex> + + +func.func @elementwise_3D_to_2D(%v1: vector<2x2x2xf32>, %v2: vector<2x2x2xf32>) -> vector<2x2x2xf32> { + %0 = arith.addf %v1, %v2 : vector<2x2x2xf32> + return %0 : vector<2x2x2xf32> +} +// CHECK-LABEL: func @elementwise_3D_to_2D +// CHECK-SAME: (%[[ARG0:.*]]: vector<2x2x2xf32>, %[[ARG1:.*]]: vector<2x2x2xf32>) -> vector<2x2x2xf32> { +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x2xf32> +// CHECK: %[[E_LHS_0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32> +// CHECK: %[[S_LHS_0:.*]] = vector.shape_cast %[[E_LHS_0]] : vector<1x2x2xf32> to vector<2x2xf32> +// CHECK: %[[E_RHS_0:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32> +// CHECK: %[[S_RHS_0:.*]] = vector.shape_cast %[[E_RHS_0]] : vector<1x2x2xf32> to vector<2x2xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[S_LHS_0]], %[[S_RHS_0]] : vector<2x2xf32> +// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[ADD0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2xf32> +// CHECK: %[[E_LHS_1:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32> +// CHECK: %[[S_LHS_1:.*]] = vector.shape_cast %[[E_LHS_1]] : vector<1x2x2xf32> to vector<2x2xf32> +// CHECK: %[[E_RHS_1:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32> +// CHECK: %[[S_RHS_1:.*]] = vector.shape_cast %[[E_RHS_1]] : vector<1x2x2xf32> to vector<2x2xf32> +// CHECK: %[[ADD1:.*]] = arith.addf %[[S_LHS_1]], %[[S_RHS_1]] : vector<2x2xf32> +// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2xf32> +// CHECK: return %[[I1]] : vector<2x2x2xf32> + + +func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf32>) -> vector<2x2x2x2xf32> { + %0 = arith.addf %v1, %v2 : vector<2x2x2x2xf32> + return %0 : vector<2x2x2x2xf32> +} + +// CHECK-LABEL: func @elementwise_4D_to_2D +// CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32> +// CHECK-NOT: arith.addf +// CHECK: return diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir index bb76392..401cdd29 100644 --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -1925,3 +1925,22 @@ func.func @warp_scf_if_distribute(%pred : i1) { // CHECK-PROP: "some_use"(%[[IF_YIELD_DIST]]) : (vector<1xf32>) -> () // CHECK-PROP: return // CHECK-PROP: } + +// ----- +func.func @dedup_unused_result(%laneid : index) -> (vector<1xf32>) { + %r:3 = gpu.warp_execute_on_lane_0(%laneid)[32] -> + (vector<1xf32>, vector<2xf32>, vector<1xf32>) { + %2 = "some_def"() : () -> (vector<32xf32>) + %3 = "some_def"() : () -> (vector<64xf32>) + gpu.yield %2, %3, %2 : vector<32xf32>, vector<64xf32>, vector<32xf32> + } + %r0 = "some_use"(%r#2, %r#2) : (vector<1xf32>, vector<1xf32>) -> (vector<1xf32>) + return %r0 : vector<1xf32> +} + +// CHECK-PROP: func @dedup_unused_result +// CHECK-PROP: %[[R:.*]] = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<1xf32>) +// CHECK-PROP: %[[Y0:.*]] = "some_def"() : () -> vector<32xf32> +// CHECK-PROP: %[[Y1:.*]] = "some_def"() : () -> vector<64xf32> +// CHECK-PROP: gpu.yield %[[Y0]] : vector<32xf32> +// CHECK-PROP: "some_use"(%[[R]], %[[R]]) : (vector<1xf32>, vector<1xf32>) -> vector<1xf32> diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index 228ef69d..ebbe3ce 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -858,7 +858,7 @@ func.func @load_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16> // ----- func.func @load_mem_desc_invalid_result_size(%arg0: !xegpu.mem_desc<16x64xf16>) { - // expected-error@+1 {{result shape must not exceed mem_desc shape}} + // expected-error@+1 {{data shape must not exceed mem_desc shape}} %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<32x16xf16> return } @@ -871,6 +871,14 @@ func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) { } // ----- +func.func @load_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>) { + // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}} + %data2 = xegpu.load_matrix %arg0[8, 8] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16> + return +} + + +// ----- func.func @store_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf32>) { // expected-error@+1 {{failed to verify that all of {mem_desc, data} have same element type}} xegpu.store_matrix %arg1, %arg0[8, 8] : vector<16x16xf32>, !xegpu.mem_desc<16x64xf16> @@ -892,30 +900,16 @@ func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: ve } // ----- -func.func @mem_desc_subview_size_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) { - // expected-error@+1 {{result shape must not exceed source shape}} - %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<32x16xf16> - return -} - -// ----- -func.func @mem_desc_subview_layout_mismatch(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride =[1, 16]>>) { - // expected-error@+1 {{result must inherit the source strides}} - %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride =[1, 16]>> -> !xegpu.mem_desc<8x16xf16> - return -} - -// ----- -func.func @mem_desc_subview_element_type_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) { - // expected-error@+1 {{failed to verify that all of {src, res} have same element type}} - %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf32, #xegpu.mem_layout<stride =[64, 1]>> +func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) { + // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}} + xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> return } // ----- -func.func @mem_desc_subview_rank_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) { - // expected-error@+1 {{result rank must not exceed source rank}} - %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<4x8x16xf16> +func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) { + // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}} + xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> return } diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir index bb37902..0a10f68 100644 --- a/mlir/test/Dialect/XeGPU/ops.mlir +++ b/mlir/test/Dialect/XeGPU/ops.mlir @@ -825,53 +825,73 @@ gpu.func @create_mem_desc_with_stride() { gpu.return } -// CHECK: gpu.func @load_mem_desc([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>) -gpu.func @load_mem_desc(%arg0: !xegpu.mem_desc<16x64xf16>) { +// CHECK: gpu.func @load_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>) +gpu.func @load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) { // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16> %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16> gpu.return } -// CHECK: gpu.func @load_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) -gpu.func @load_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) { +// CHECK: gpu.func @load_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) +gpu.func @load_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) { // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8x16xf16> %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8x16xf16> gpu.return } +// CHECK: gpu.func @simt_load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) +gpu.func @simt_load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) { + // CHECK: xegpu.load_matrix [[ARG0]][8, 16] : !xegpu.mem_desc<16x64xf16> -> vector<1xf16> + %data = xegpu.load_matrix %arg0[8, 16]: !xegpu.mem_desc<16x64xf16> -> vector<1xf16> + gpu.return +} + +// CHECK: gpu.func @simt_load_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>) +gpu.func @simt_load_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>) { + // CHECK: xegpu.load_matrix [[ARG0]][8, 16] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>> -> vector<8xf16> + %data = xegpu.load_matrix %arg0[8, 16] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>> -> vector<8xf16> + gpu.return +} + +// CHECK: gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) +gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) { + // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16> + %data = xegpu.load_matrix %arg0[8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16> + gpu.return +} -// CHECK: gpu.func @store_mem_desc([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>, [[ARG1:%.+]]: vector<16x16xf16>) -gpu.func @store_mem_desc(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf16>) { +// CHECK: gpu.func @store_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>, [[ARG1:%.+]]: vector<16x16xf16>) +gpu.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf16>) { // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> xegpu.store_matrix %arg1, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> gpu.return } -// CHECK: gpu.func @store_mem_desc_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, [[ARG1:%.+]]: vector<16x16xf16>) -gpu.func @store_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<16x16xf16>) { +// CHECK: gpu.func @store_matrix_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, [[ARG1:%.+]]: vector<16x16xf16>) +gpu.func @store_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<16x16xf16>) { // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][0, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> xegpu.store_matrix %arg1, %arg0[0, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> gpu.return } -// CHECK: gpu.func @mem_desc_subview([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>) -gpu.func @mem_desc_subview(%arg0: !xegpu.mem_desc<16x64xf16>) { - //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>> - %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>> +// CHECK: gpu.func @simt_store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<1xf16>) { +gpu.func @simt_store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<1xf16>) { + // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 16] : vector<1xf16>, !xegpu.mem_desc<16x64xf16> + xegpu.store_matrix %arg1, %arg0[8, 16]: vector<1xf16>, !xegpu.mem_desc<16x64xf16> gpu.return } -// CHECK: gpu.func @mem_desc_subview_lower_rank([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>) -gpu.func @mem_desc_subview_lower_rank(%arg0: !xegpu.mem_desc<16x64xf16>) { - //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<16xf16, #xegpu.mem_layout<stride = [64, 1]>> - %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<16xf16, #xegpu.mem_layout<stride = [64, 1]>> +// CHECK: gpu.func @simt_store_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>, %arg1: vector<8xf16>) +gpu.func @simt_store_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>, %arg1: vector<8xf16>) { + // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 16] <{subgroup_block_io}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>> + xegpu.store_matrix %arg1, %arg0[8, 16] <{subgroup_block_io}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>> gpu.return } -// CHECK: gpu.func @mem_desc_subview_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) -gpu.func @mem_desc_subview_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) { - //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [1, 16]>> - %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [1, 16]>> +// CHECK: gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<8xf16>) { +gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<8xf16>) { + // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] : vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> + xegpu.store_matrix %arg1, %arg0[8, 8] : vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> gpu.return } diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir index 9d63c2d..fe4f44c 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir @@ -584,3 +584,101 @@ gpu.module @test_kernel { gpu.return } } + +// ----- +gpu.module @test_kernel { + // CHECK-LABEL: load_with_offsets + // CHECK-COUNT-2: xegpu.load {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16xf32> + gpu.func @load_with_offsets(%src: ui64) -> vector<32xf32> { + %cst = arith.constant dense<[ + 0, 8, 16, 24, 32, 40, 48, 56, + 64, 72, 80, 88, 96, 104, 112, 120, + 128, 136, 144, 152, 160, 168, 176, 184, + 192, 200, 208, 216, 224, 232, 240, 248 + ]> : vector<32xindex> + + %c17 = arith.constant 17: index + %mask = vector.create_mask %c17: vector<32xi1> + %ld = xegpu.load %src[%cst], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<inst_data = [16]>, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<32xindex>, vector<32xi1> -> vector<32xf32> + + gpu.return %ld : vector<32xf32> + } +} + +// ----- +gpu.module @test_kernel { + // CHECK-LABEL: store_with_offsets + // CHECK-COUNT-2: xegpu.store {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16xf32>, ui64, vector<16xindex>, vector<16xi1> + gpu.func @store_with_offsets(%src: ui64) { + %cst = arith.constant dense<[ + 0, 8, 16, 24, 32, 40, 48, 56, + 64, 72, 80, 88, 96, 104, 112, 120, + 128, 136, 144, 152, 160, 168, 176, 184, + 192, 200, 208, 216, 224, 232, 240, 248 + ]> : vector<32xindex> + + %c17 = arith.constant 17: index + %mask = vector.create_mask %c17: vector<32xi1> + + %st_vec = arith.constant dense<1023.0>: vector<32xf32> + xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 1, layout_operand_0 = #xegpu.layout<inst_data = [16]>, + layout_operand_2 = #xegpu.layout<inst_data = [16]>, + layout_operand_3 = #xegpu.layout<inst_data = [16]>, + l1_hint = #xegpu.cache_hint<cached>} : vector<32xf32>, ui64, vector<32xindex>, vector<32xi1> + + gpu.return + } +} + +// ----- +gpu.module @test_kernel { + // CHECK-LABEL: load_with_offsets_chunk + // CHECK: [[cst:%.+]] = arith.constant dense<0.000000e+00> : vector<32x4xf32> + // CHECK: [[cst0:%.+]] = arith.constant dense<[130, 138, 146, 154, 162, 170, 178, 186, 194, 202, 210, 218, 226, 234, 242, 250]> : vector<16xindex> + // CHECK: [[cst1:%.+]] = arith.constant dense<[2, 10, 18, 26, 34, 42, 50, 58, 66, 74, 82, 90, 98, 106, 114, 122]> : vector<16xindex> + // CHECK: [[cst2:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex> + // CHECK: [[cst3:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex> + // CHECK-COUNT-4: xegpu.load {{.*}}[{{.*}}], {{.*}} <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16x2xf32> + gpu.func @load_with_offsets_chunk(%src: ui64) -> vector<32x4xf32> { + %cst = arith.constant dense<[ + 0, 8, 16, 24, 32, 40, 48, 56, + 64, 72, 80, 88, 96, 104, 112, 120, + 128, 136, 144, 152, 160, 168, 176, 184, + 192, 200, 208, 216, 224, 232, 240, 248 + ]> : vector<32xindex> + + %c17 = arith.constant 17: index + %mask = vector.create_mask %c17: vector<32xi1> + %ld = xegpu.load %src[%cst], %mask {chunk_size = 4, layout_result_0 = #xegpu.layout<inst_data = [16, 2]>, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<32xindex>, vector<32xi1> -> vector<32x4xf32> + gpu.return %ld : vector<32x4xf32> + } +} + +// ----- +gpu.module @test_kernel { + // CHECK-LABEL: store_with_offsets_chunk + // CHECK: [[cst:%.+]] = arith.constant dense<1.023000e+03> : vector<16x2xf32 + // CHECK: [[cst0:%.+]] = arith.constant dense<[130, 138, 146, 154, 162, 170, 178, 186, 194, 202, 210, 218, 226, 234, 242, 250]> : vector<16xindex> + // CHECK: [[cst1:%.+]] = arith.constant dense<[2, 10, 18, 26, 34, 42, 50, 58, 66, 74, 82, 90, 98, 106, 114, 122]> : vector<16xindex> + // CHECK: [[cst2:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex> + // CHECK: [[cst3:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex> + // CHECK-COUNT-4: xegpu.store {{.*}}[{{.*}}], {{.*}} <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16x2xf32>, ui64, vector<16xindex>, vector<16xi1> + gpu.func @store_with_offsets_chunk(%src: ui64) { + %cst = arith.constant dense<[ + 0, 8, 16, 24, 32, 40, 48, 56, + 64, 72, 80, 88, 96, 104, 112, 120, + 128, 136, 144, 152, 160, 168, 176, 184, + 192, 200, 208, 216, 224, 232, 240, 248 + ]> : vector<32xindex> + + %c17 = arith.constant 17: index + %mask = vector.create_mask %c17: vector<32xi1> + + %st_vec = arith.constant dense<1023.>: vector<32x4xf32> + xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 4, layout_operand_0 = #xegpu.layout<inst_data = [16, 2]>, + layout_operand_2 = #xegpu.layout<inst_data = [16, 2]>, + layout_operand_3 = #xegpu.layout<inst_data = [16, 2]>, + l1_hint = #xegpu.cache_hint<cached>} : vector<32x4xf32>, ui64, vector<32xindex>, vector<32xi1> + gpu.return + } +} diff --git a/mlir/test/Integration/GPU/SPIRV/simple_add.mlir b/mlir/test/Integration/GPU/SPIRV/simple_add.mlir index cb16c37..b3154d4 100644 --- a/mlir/test/Integration/GPU/SPIRV/simple_add.mlir +++ b/mlir/test/Integration/GPU/SPIRV/simple_add.mlir @@ -3,7 +3,16 @@ // RUN: | FileCheck %s // CHECK: data = -// CHECK-RAW: [[[7.7, 0, 0], [7.7, 0, 0], [7.7, 0, 0]], [[0, 7.7, 0], [0, 7.7, 0], [0, 7.7, 0]], [[0, 0, 7.7], [0, 0, 7.7], [0, 0, 7.7]]] +// CHECK{LITERAL}: [[[7.7, 0, 0], +// CHECK{LITERAL}: [7.7, 0, 0], +// CHECK{LITERAL}: [7.7, 0, 0]], +// CHECK{LITERAL}: [[0, 7.7, 0], +// CHECK{LITERAL}: [0, 7.7, 0], +// CHECK{LITERAL}: [0, 7.7, 0]], +// CHECK{LITERAL}: [[0, 0, 7.7], +// CHECK{LITERAL}: [0, 0, 7.7], +// CHECK{LITERAL}: [0, 0, 7.7]]] + module attributes { gpu.container_module, spirv.target_env = #spirv.target_env< diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir index 8a0390a..8116044 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir @@ -17,7 +17,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b, %c = transform.structured.fuse %matmul [10, 20] + %a, %b, %c = transform.structured.fuse %matmul tile_sizes [10, 20] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) transform.yield } @@ -69,7 +69,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %generic = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b, %c = transform.structured.fuse %generic [10, 20] + %a, %b, %c = transform.structured.fuse %generic tile_sizes [10, 20] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) transform.yield } @@ -125,7 +125,7 @@ module attributes {transform.with_named_sequence} { : (!transform.any_op) -> !transform.any_op %mm1, %mm2 = transform.split_handle %matmuls : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %a, %b = transform.structured.fuse %mm2 [10] + %a, %b = transform.structured.fuse %mm2 tile_sizes [10] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } @@ -188,7 +188,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %generic = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b, %c = transform.structured.fuse %generic [10, 20] + %a, %b, %c = transform.structured.fuse %generic tile_sizes [10, 20] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) transform.yield } @@ -248,7 +248,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %generic = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b, %c = transform.structured.fuse %generic [10, 20] interchange[1, 0] + %a, %b, %c = transform.structured.fuse %generic tile_sizes [10, 20] interchange[1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) transform.yield } @@ -307,7 +307,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %generic = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b, %c = transform.structured.fuse %generic [10, 20] + %a, %b, %c = transform.structured.fuse %generic tile_sizes [10, 20] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) transform.yield } @@ -367,7 +367,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %generic = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b, %c = transform.structured.fuse %generic [10, 20] + %a, %b, %c = transform.structured.fuse %generic tile_sizes [10, 20] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) transform.yield } @@ -423,7 +423,7 @@ module attributes {transform.with_named_sequence} { : (!transform.any_op) -> !transform.any_op %mm1, %mm2, %mm3 = transform.split_handle %matmuls : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - %a, %b = transform.structured.fuse %mm3 [10] + %a, %b = transform.structured.fuse %mm3 tile_sizes [10] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } @@ -512,7 +512,7 @@ module attributes {transform.with_named_sequence} { : (!transform.any_op) -> !transform.any_op %generic1, %generic2, %generic3 = transform.split_handle %generics : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - %a, %b = transform.structured.fuse %generic3 [10] + %a, %b = transform.structured.fuse %generic3 tile_sizes [10] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } @@ -568,7 +568,7 @@ module attributes {transform.with_named_sequence} { : (!transform.any_op) -> !transform.any_op %pad = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.structured.fuse %pad [8] + %a, %b = transform.structured.fuse %pad tile_sizes [8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } @@ -614,7 +614,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %matmul = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %b = transform.structured.fuse %matmul [0, 1, 0] + %a, %b = transform.structured.fuse %matmul tile_sizes [0, 1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } @@ -652,7 +652,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { %generic = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %a, %loops:4 = transform.structured.fuse %generic {tile_sizes = [1, 16, 16, 16], tile_interchange = [0, 1, 2, 3], apply_cleanup = false} + %a, %loops:4 = transform.structured.fuse %generic tile_sizes [1, 16, 16, 16] interchange [0, 1, 2, 3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) transform.yield } diff --git a/mlir/test/Pass/remark-final.mlir b/mlir/test/Pass/remark-final.mlir new file mode 100644 index 0000000..325271e --- /dev/null +++ b/mlir/test/Pass/remark-final.mlir @@ -0,0 +1,17 @@ +// RUN: mlir-opt %s --test-remark --remarks-filter="category.*" --remark-policy=final 2>&1 | FileCheck %s +// RUN: mlir-opt %s --test-remark --remarks-filter="category.*" --remark-policy=final --remark-format=yaml --remarks-output-file=%t.yaml +// RUN: FileCheck --check-prefix=CHECK-YAML %s < %t.yaml +module @foo { + "test.op"() : () -> () + +} + +// CHECK-YAML-NOT: This is a test passed remark (should be dropped) +// CHECK-YAML-DAG: !Analysis +// CHECK-YAML-DAG: !Failure +// CHECK-YAML-DAG: !Passed + +// CHECK-NOT: This is a test passed remark (should be dropped) +// CHECK-DAG: remark: [Analysis] test-remark +// CHECK-DAG: remark: [Failure] test-remark | Category:category-2-failed +// CHECK-DAG: remark: [Passed] test-remark | Category:category-1-passed diff --git a/mlir/test/Target/Cpp/expressions.mlir b/mlir/test/Target/Cpp/expressions.mlir index 4281f41..9f1c816 100644 --- a/mlir/test/Target/Cpp/expressions.mlir +++ b/mlir/test/Target/Cpp/expressions.mlir @@ -315,16 +315,13 @@ func.func @different_expressions(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) } // CPP-DEFAULT: int32_t expression_with_dereference(int32_t [[VAL_1:v[0-9]+]], int32_t* [[VAL_2]]) { -// CPP-DEFAULT-NEXT: int32_t [[VAL_3:v[0-9]+]] = *([[VAL_2]] - [[VAL_1]]); -// CPP-DEFAULT-NEXT: return [[VAL_3]]; +// CPP-DEFAULT-NEXT: return *([[VAL_2]] - [[VAL_1]]); // CPP-DEFAULT-NEXT: } // CPP-DECLTOP: int32_t expression_with_dereference(int32_t [[VAL_1:v[0-9]+]], int32_t* [[VAL_2]]) { -// CPP-DECLTOP-NEXT: int32_t [[VAL_3:v[0-9]+]]; -// CPP-DECLTOP-NEXT: [[VAL_3]] = *([[VAL_2]] - [[VAL_1]]); -// CPP-DECLTOP-NEXT: return [[VAL_3]]; +// CPP-DECLTOP-NEXT: return *([[VAL_2]] - [[VAL_1]]); // CPP-DECLTOP-NEXT: } -func.func @expression_with_dereference(%arg1: i32, %arg2: !emitc.ptr<i32>) -> i32 { +emitc.func @expression_with_dereference(%arg1: i32, %arg2: !emitc.ptr<i32>) -> i32 { %c = emitc.expression %arg1, %arg2 : (i32, !emitc.ptr<i32>) -> i32 { %e = emitc.sub %arg2, %arg1 : (!emitc.ptr<i32>, i32) -> !emitc.ptr<i32> %d = emitc.apply "*"(%e) : (!emitc.ptr<i32>) -> i32 @@ -384,19 +381,16 @@ func.func @expression_with_subscript_user(%arg0: !emitc.ptr<!emitc.opaque<"void" // CPP-DEFAULT: bool expression_with_load(int32_t [[VAL_1:v.+]], int32_t [[VAL_2:v.+]], int32_t* [[VAL_3:v.+]]) { // CPP-DEFAULT-NEXT: int64_t [[VAL_4:v.+]] = 0; // CPP-DEFAULT-NEXT: int32_t [[VAL_5:v.+]] = 42; -// CPP-DEFAULT-NEXT: bool [[VAL_6:v.+]] = [[VAL_5]] + [[VAL_2]] < [[VAL_3]][[[VAL_4]]] + [[VAL_1]]; -// CPP-DEFAULT-NEXT: return [[VAL_6]]; +// CPP-DEFAULT-NEXT: return [[VAL_5]] + [[VAL_2]] < [[VAL_3]][[[VAL_4]]] + [[VAL_1]]; // CPP-DECLTOP: bool expression_with_load(int32_t [[VAL_1:v.+]], int32_t [[VAL_2:v.+]], int32_t* [[VAL_3:v.+]]) { // CPP-DECLTOP-NEXT: int64_t [[VAL_4:v.+]]; // CPP-DECLTOP-NEXT: int32_t [[VAL_5:v.+]]; -// CPP-DECLTOP-NEXT: bool [[VAL_6:v.+]]; // CPP-DECLTOP-NEXT: [[VAL_4]] = 0; // CPP-DECLTOP-NEXT: [[VAL_5]] = 42; -// CPP-DECLTOP-NEXT: [[VAL_6]] = [[VAL_5]] + [[VAL_2]] < [[VAL_3]][[[VAL_4]]] + [[VAL_1]]; -// CPP-DECLTOP-NEXT: return [[VAL_6]]; +// CPP-DECLTOP-NEXT: return [[VAL_5]] + [[VAL_2]] < [[VAL_3]][[[VAL_4]]] + [[VAL_1]]; -func.func @expression_with_load(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr<i32>) -> i1 { +emitc.func @expression_with_load(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr<i32>) -> i1 { %c0 = "emitc.constant"() {value = 0 : i64} : () -> i64 %0 = "emitc.variable"() <{value = #emitc.opaque<"42">}> : () -> !emitc.lvalue<i32> %ptr = emitc.subscript %arg2[%c0] : (!emitc.ptr<i32>, i64) -> !emitc.lvalue<i32> @@ -408,22 +402,19 @@ func.func @expression_with_load(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr<i32>) %e = emitc.cmp lt, %b, %d :(i32, i32) -> i1 yield %e : i1 } - return %result : i1 + emitc.return %result : i1 } // CPP-DEFAULT: bool expression_with_load_and_call(int32_t* [[VAL_1:v.+]]) { // CPP-DEFAULT-NEXT: int64_t [[VAL_2:v.+]] = 0; -// CPP-DEFAULT-NEXT: bool [[VAL_3:v.+]] = [[VAL_1]][[[VAL_2]]] + bar([[VAL_1]][[[VAL_2]]]) < [[VAL_1]][[[VAL_2]]]; -// CPP-DEFAULT-NEXT: return [[VAL_3]]; +// CPP-DEFAULT-NEXT: return [[VAL_1]][[[VAL_2]]] + bar([[VAL_1]][[[VAL_2]]]) < [[VAL_1]][[[VAL_2]]]; // CPP-DECLTOP: bool expression_with_load_and_call(int32_t* [[VAL_1:v.+]]) { // CPP-DECLTOP-NEXT: int64_t [[VAL_2:v.+]]; -// CPP-DECLTOP-NEXT: bool [[VAL_3:v.+]]; // CPP-DECLTOP-NEXT: [[VAL_2]] = 0; -// CPP-DECLTOP-NEXT: [[VAL_3]] = [[VAL_1]][[[VAL_2]]] + bar([[VAL_1]][[[VAL_2]]]) < [[VAL_1]][[[VAL_2]]]; -// CPP-DECLTOP-NEXT: return [[VAL_3]]; +// CPP-DECLTOP-NEXT: return [[VAL_1]][[[VAL_2]]] + bar([[VAL_1]][[[VAL_2]]]) < [[VAL_1]][[[VAL_2]]]; -func.func @expression_with_load_and_call(%arg0: !emitc.ptr<i32>) -> i1 { +emitc.func @expression_with_load_and_call(%arg0: !emitc.ptr<i32>) -> i1 { %c0 = "emitc.constant"() {value = 0 : i64} : () -> i64 %ptr = emitc.subscript %arg0[%c0] : (!emitc.ptr<i32>, i64) -> !emitc.lvalue<i32> %result = emitc.expression %ptr : (!emitc.lvalue<i32>) -> i1 { @@ -435,7 +426,7 @@ func.func @expression_with_load_and_call(%arg0: !emitc.ptr<i32>) -> i1 { %f = emitc.cmp lt, %e, %b :(i32, i32) -> i1 yield %f : i1 } - return %result : i1 + emitc.return %result : i1 } @@ -458,3 +449,204 @@ emitc.func @expression_with_call_opaque_with_args_array(%0 : i32, %1 : i32) { } return } + +// CPP-DEFAULT: void inline_side_effects_into_assign(int32_t [[VAL_1:v[0-9]+]], int32_t* [[VAL_2:v[0-9]+]]) { +// CPP-DEFAULT-NEXT: int64_t [[VAL_3:v[0-9]+]] = 0; +// CPP-DEFAULT-NEXT: int32_t [[VAL_4:v[0-9]+]] = 42; +// CPP-DEFAULT-NEXT: [[VAL_4]] = [[VAL_4]] * [[VAL_1]] + [[VAL_2]][[[VAL_3]]]; +// CPP-DEFAULT-NEXT: return; +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: void inline_side_effects_into_assign(int32_t [[VAL_1:v[0-9]+]], int32_t* [[VAL_2:v[0-9]+]]) { +// CPP-DECLTOP-NEXT: int64_t [[VAL_3:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_4:v[0-9]+]]; +// CPP-DECLTOP-NEXT: [[VAL_3]] = 0; +// CPP-DECLTOP-NEXT: [[VAL_4]] = 42; +// CPP-DECLTOP-NEXT: [[VAL_4]] = [[VAL_4]] * [[VAL_1]] + [[VAL_2]][[[VAL_3]]]; +// CPP-DECLTOP-NEXT: return; +// CPP-DECLTOP-NEXT: } + +emitc.func @inline_side_effects_into_assign(%arg0: i32, %arg1: !emitc.ptr<i32>) { + %c0 = "emitc.constant"() {value = 0 : i64} : () -> i64 + %0 = "emitc.variable"() <{value = #emitc.opaque<"42">}> : () -> !emitc.lvalue<i32> + %ptr = emitc.subscript %arg1[%c0] : (!emitc.ptr<i32>, i64) -> !emitc.lvalue<i32> + %result = emitc.expression %arg0, %0, %ptr : (i32, !emitc.lvalue<i32>, !emitc.lvalue<i32>) -> i32 { + %a = emitc.load %0 : !emitc.lvalue<i32> + %b = emitc.mul %a, %arg0 : (i32, i32) -> i32 + %c = emitc.load %ptr : !emitc.lvalue<i32> + %d = emitc.add %b, %c : (i32, i32) -> i32 + yield %d : i32 + } + emitc.assign %result : i32 to %0 : !emitc.lvalue<i32> + emitc.return +} + +// CPP-DEFAULT: void do_not_inline_side_effects_into_assign(int32_t [[VAL_1:v[0-9]+]], int32_t* [[VAL_2:v[0-9]+]]) { +// CPP-DEFAULT-NEXT: int64_t [[VAL_3:v[0-9]+]] = 0; +// CPP-DEFAULT-NEXT: int32_t [[VAL_4:v[0-9]+]] = 42; +// CPP-DEFAULT-NEXT: int32_t [[VAL_5:v[0-9]+]] = [[VAL_4]] * [[VAL_1]]; +// CPP-DEFAULT-NEXT: [[VAL_2]][[[VAL_3]]] = [[VAL_5]]; +// CPP-DEFAULT-NEXT: return; +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: void do_not_inline_side_effects_into_assign(int32_t [[VAL_1:v[0-9]+]], int32_t* [[VAL_2:v[0-9]+]]) { +// CPP-DECLTOP-NEXT: int64_t [[VAL_3:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_4:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_5:v[0-9]+]]; +// CPP-DECLTOP-NEXT: [[VAL_3]] = 0; +// CPP-DECLTOP-NEXT: [[VAL_4]] = 42; +// CPP-DECLTOP-NEXT: [[VAL_5:v[0-9]+]] = [[VAL_4]] * [[VAL_1]]; +// CPP-DECLTOP-NEXT: [[VAL_2]][[[VAL_3]]] = [[VAL_5]]; +// CPP-DECLTOP-NEXT: return; +// CPP-DECLTOP-NEXT: } + +emitc.func @do_not_inline_side_effects_into_assign(%arg0: i32, %arg1: !emitc.ptr<i32>) { + %c0 = "emitc.constant"() {value = 0 : i64} : () -> i64 + %0 = "emitc.variable"() <{value = #emitc.opaque<"42">}> : () -> !emitc.lvalue<i32> + %ptr = emitc.subscript %arg1[%c0] : (!emitc.ptr<i32>, i64) -> !emitc.lvalue<i32> + %result = emitc.expression %arg0, %0 : (i32, !emitc.lvalue<i32>) -> i32 { + %a = emitc.load %0 : !emitc.lvalue<i32> + %b = emitc.mul %a, %arg0 : (i32, i32) -> i32 + yield %b : i32 + } + emitc.assign %result : i32 to %ptr : !emitc.lvalue<i32> + emitc.return +} + +// CPP-DEFAULT: int32_t do_not_inline_non_preceding_side_effects(int32_t [[VAL_1:v[0-9]+]], int32_t* [[VAL_2:v[0-9]+]]) { +// CPP-DEFAULT-NEXT: int64_t [[VAL_3:v[0-9]+]] = 0; +// CPP-DEFAULT-NEXT: int32_t [[VAL_4:v[0-9]+]] = 42; +// CPP-DEFAULT-NEXT: int32_t [[VAL_5:v[0-9]+]] = [[VAL_4]] * [[VAL_1]]; +// CPP-DEFAULT-NEXT: [[VAL_2]][[[VAL_3]]] = [[VAL_1]]; +// CPP-DEFAULT-NEXT: return [[VAL_5]]; +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: int32_t do_not_inline_non_preceding_side_effects(int32_t [[VAL_1:v[0-9]+]], int32_t* [[VAL_2:v[0-9]+]]) { +// CPP-DECLTOP-NEXT: int64_t [[VAL_3:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_4:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_5:v[0-9]+]]; +// CPP-DECLTOP-NEXT: [[VAL_3:v[0-9]+]] = 0; +// CPP-DECLTOP-NEXT: [[VAL_4:v[0-9]+]] = 42; +// CPP-DECLTOP-NEXT: [[VAL_5:v[0-9]+]] = [[VAL_4]] * [[VAL_1]]; +// CPP-DECLTOP-NEXT: [[VAL_2]][[[VAL_3]]] = [[VAL_1]]; +// CPP-DECLTOP-NEXT: return [[VAL_5]]; +// CPP-DECLTOP-NEXT: } + +emitc.func @do_not_inline_non_preceding_side_effects(%arg0: i32, %arg1: !emitc.ptr<i32>) -> i32 { + %c0 = "emitc.constant"() {value = 0 : i64} : () -> i64 + %0 = "emitc.variable"() <{value = #emitc.opaque<"42">}> : () -> !emitc.lvalue<i32> + %ptr = emitc.subscript %arg1[%c0] : (!emitc.ptr<i32>, i64) -> !emitc.lvalue<i32> + %result = emitc.expression %arg0, %0 : (i32, !emitc.lvalue<i32>) -> i32 { + %a = emitc.load %0 : !emitc.lvalue<i32> + %b = emitc.mul %a, %arg0 : (i32, i32) -> i32 + yield %b : i32 + } + emitc.assign %arg0 : i32 to %ptr : !emitc.lvalue<i32> + emitc.return %result : i32 +} + +// CPP-DEFAULT: int32_t inline_side_effects_into_if(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { +// CPP-DEFAULT-NEXT: int32_t [[VAL_4:v[0-9]+]]; +// CPP-DEFAULT-NEXT: if (bar([[VAL_1]], [[VAL_2]]) < [[VAL_3]]) { +// CPP-DEFAULT-NEXT: [[VAL_4]] = [[VAL_1]]; +// CPP-DEFAULT-NEXT: } else { +// CPP-DEFAULT-NEXT: [[VAL_4]] = [[VAL_2]]; +// CPP-DEFAULT-NEXT: } +// CPP-DEFAULT-NEXT: int32_t [[VAL_5:v[0-9]+]] = [[VAL_4]]; +// CPP-DEFAULT-NEXT: return [[VAL_5]]; +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: int32_t inline_side_effects_into_if(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { +// CPP-DECLTOP-NEXT: int32_t [[VAL_4:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_5:v[0-9]+]]; +// CPP-DECLTOP-NEXT: ; +// CPP-DECLTOP-NEXT: if (bar([[VAL_1]], [[VAL_2]]) < [[VAL_3]]) { +// CPP-DECLTOP-NEXT: [[VAL_4]] = [[VAL_1]]; +// CPP-DECLTOP-NEXT: } else { +// CPP-DECLTOP-NEXT: [[VAL_4]] = [[VAL_2]]; +// CPP-DECLTOP-NEXT: } +// CPP-DECLTOP-NEXT: [[VAL_5]] = [[VAL_4]]; +// CPP-DECLTOP-NEXT: return [[VAL_5]]; +// CPP-DECLTOP-NEXT: } + +func.func @inline_side_effects_into_if(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 { + %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> !emitc.lvalue<i32> + %cond = emitc.expression %arg0, %arg1, %arg2 : (i32, i32, i32) -> i1 { + %a = emitc.call_opaque "bar" (%arg0, %arg1) : (i32, i32) -> (i32) + %b = emitc.cmp lt, %a, %arg2 :(i32, i32) -> i1 + emitc.yield %b : i1 + } + emitc.if %cond { + emitc.assign %arg0 : i32 to %v : !emitc.lvalue<i32> + emitc.yield + } else { + emitc.assign %arg1 : i32 to %v : !emitc.lvalue<i32> + emitc.yield + } + %v_load = emitc.load %v : !emitc.lvalue<i32> + return %v_load : i32 +} + +// CPP-DEFAULT: void inline_side_effects_into_switch(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { +// CPP-DEFAULT-NEXT: switch (bar([[VAL_1]], [[VAL_2]]) + [[VAL_3]]) { +// CPP-DEFAULT-NEXT: case 2: { +// CPP-DEFAULT-NEXT: int32_t [[VAL_4:v[0-9]+]] = func_b(); +// CPP-DEFAULT-NEXT: break; +// CPP-DEFAULT-NEXT: } +// CPP-DEFAULT-NEXT: case 5: { +// CPP-DEFAULT-NEXT: int32_t [[VAL_5:v[0-9]+]] = func_a(); +// CPP-DEFAULT-NEXT: break; +// CPP-DEFAULT-NEXT: } +// CPP-DEFAULT-NEXT: default: { +// CPP-DEFAULT-NEXT: float [[VAL_6:v[0-9]+]] = 4.200000000e+01f; +// CPP-DEFAULT-NEXT: func2([[VAL_6]]); +// CPP-DEFAULT-NEXT: break; +// CPP-DEFAULT-NEXT: } +// CPP-DEFAULT-NEXT: } +// CPP-DEFAULT-NEXT: return; +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: void inline_side_effects_into_switch(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { +// CPP-DECLTOP-NEXT: float [[VAL_6:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_4:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_5:v[0-9]+]]; +// CPP-DECLTOP-NEXT: switch (bar([[VAL_1]], [[VAL_2]]) + [[VAL_3]]) { +// CPP-DECLTOP-NEXT: case 2: { +// CPP-DECLTOP-NEXT: [[VAL_4]] = func_b(); +// CPP-DECLTOP-NEXT: break; +// CPP-DECLTOP-NEXT: } +// CPP-DECLTOP-NEXT: case 5: { +// CPP-DECLTOP-NEXT: [[VAL_5]] = func_a(); +// CPP-DECLTOP-NEXT: break; +// CPP-DECLTOP-NEXT: } +// CPP-DECLTOP-NEXT: default: { +// CPP-DECLTOP-NEXT: [[VAL_6]] = 4.200000000e+01f; +// CPP-DECLTOP-NEXT: func2([[VAL_6]]); +// CPP-DECLTOP-NEXT: break; +// CPP-DECLTOP-NEXT: } +// CPP-DECLTOP-NEXT: } +// CPP-DECLTOP-NEXT: return; +// CPP-DECLTOP-NEXT: } + +func.func @inline_side_effects_into_switch(%arg0: i32, %arg1: i32, %arg2: i32) { + %0 = emitc.expression %arg0, %arg1, %arg2 : (i32, i32, i32) -> i32 { + %a = emitc.call_opaque "bar" (%arg0, %arg1) : (i32, i32) -> (i32) + %b = emitc.add %a, %arg2 :(i32, i32) -> i32 + emitc.yield %b : i32 + } + emitc.switch %0 : i32 + case 2 { + %1 = emitc.call_opaque "func_b" () : () -> i32 + emitc.yield + } + case 5 { + %2 = emitc.call_opaque "func_a" () : () -> i32 + emitc.yield + } + default { + %3 = "emitc.constant"(){value = 42.0 : f32} : () -> f32 + emitc.call_opaque "func2" (%3) : (f32) -> () + emitc.yield + } + return +} diff --git a/mlir/test/Target/LLVMIR/Import/debug-info.ll b/mlir/test/Target/LLVMIR/Import/debug-info.ll index e056e43..61376b8 100644 --- a/mlir/test/Target/LLVMIR/Import/debug-info.ll +++ b/mlir/test/Target/LLVMIR/Import/debug-info.ll @@ -240,11 +240,10 @@ define void @subprogram() !dbg !3 { define void @func_loc() !dbg !3 { ret void } -; CHECK-DAG: #[[NAME_LOC:.+]] = loc("func_loc") ; CHECK-DAG: #[[FILE_LOC:.+]] = loc("debug-info.ll":42:0) ; CHECK-DAG: #[[SP:.+]] = #llvm.di_subprogram<id = distinct[{{.*}}]<>, compileUnit = #{{.*}}, scope = #{{.*}}, name = "func_loc", file = #{{.*}}, line = 42, subprogramFlags = Definition> -; CHECK: loc(fused<#[[SP]]>[#[[NAME_LOC]], #[[FILE_LOC]]] +; CHECK: loc(fused<#[[SP]]>[#[[FILE_LOC]]] !llvm.dbg.cu = !{!1} !llvm.module.flags = !{!0} diff --git a/mlir/test/Target/LLVMIR/Import/function-attributes.ll b/mlir/test/Target/LLVMIR/Import/function-attributes.ll index cc3d799..00d09ba 100644 --- a/mlir/test/Target/LLVMIR/Import/function-attributes.ll +++ b/mlir/test/Target/LLVMIR/Import/function-attributes.ll @@ -393,6 +393,12 @@ declare void @alwaysinline_attribute() alwaysinline // ----- +; CHECK-LABEL: @inlinehint_attribute +; CHECK-SAME: attributes {inline_hint} +declare void @inlinehint_attribute() inlinehint + +// ----- + ; CHECK-LABEL: @optnone_attribute ; CHECK-SAME: attributes {no_inline, optimize_none} declare void @optnone_attribute() noinline optnone diff --git a/mlir/test/Target/LLVMIR/amx.mlir b/mlir/test/Target/LLVMIR/amx.mlir index abdf2fe..160a9ce 100644 --- a/mlir/test/Target/LLVMIR/amx.mlir +++ b/mlir/test/Target/LLVMIR/amx.mlir @@ -23,6 +23,19 @@ func.func @amx_tile_load_store(%base: memref<?x?xi8>, %out: memref<?x?xi8>, return } +// CHECK-LABEL: define void @amx_tile_load_store_strided +func.func @amx_tile_load_store_strided(%base: memref<?xi8>, %out: memref<?xi8>, + %idx: index, %stride: index) +{ + // CHECK: call x86_amx @llvm.x86.tileloadd64.internal + // CHECK: call void @llvm.x86.tilestored64.internal + %val = amx.tile_load %base[%idx], %stride + : memref<?xi8> into !amx.tile<16x64xi8> + amx.tile_store %out[%idx], %val, %stride + : memref<?xi8>, !amx.tile<16x64xi8> + return +} + // CHECK-LABEL: define void @amx_tile_mulf_bf16 func.func @amx_tile_mulf_bf16( %matA: memref<?x?xbf16>, %matB: memref<?x?xbf16>, %idx: index, diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir index 69814f2..cc243c8 100644 --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -2555,6 +2555,17 @@ llvm.func @always_inline() attributes { always_inline } { // ----- +// CHECK-LABEL: @inline_hint +// CHECK-SAME: #[[ATTRS:[0-9]+]] +llvm.func @inline_hint() attributes { inline_hint } { + llvm.return +} + +// CHECK: #[[ATTRS]] +// CHECK-SAME: inlinehint + +// ----- + // CHECK-LABEL: @optimize_none // CHECK-SAME: #[[ATTRS:[0-9]+]] llvm.func @optimize_none() attributes { no_inline, optimize_none } { diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir new file mode 100644 index 0000000..04e2ddf --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir @@ -0,0 +1,12 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @convert_f32x2_to_f4x2_e2m1 +llvm.func @convert_f32x2_to_f4x2_e2m1(%srcA : f32, %srcB : f32) { + // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float %{{.*}}, float %{{.*}}) + // CHECK-NEXT: %{{.*}} = trunc i16 %[[res1]] to i8 + %res1 = nvvm.convert.f32x2.to.f4x2 %srcA, %srcB : i8 (f4E2M1FN) + // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}}) + // CHECK-NEXT: %{{.*}} = trunc i16 %[[res2]] to i8 + %res2 = nvvm.convert.f32x2.to.f4x2 %srcA, %srcB {relu = true} : i8 (f4E2M1FN) + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir index 0b36154..6cccfe4 100644 --- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir @@ -254,6 +254,14 @@ llvm.func @nvvm_cvt_f32x2_to_f6x2_invalid_type(%a : f32, %b : f32) { // ----- +llvm.func @nvvm_cvt_f32x2_to_f4x2_invalid_type(%a : f32, %b : f32) { + // expected-error @below {{Only 'f4E2M1FN' type is supported for conversions from f32x2 to f4x2.}} + %res = nvvm.convert.f32x2.to.f4x2 %a, %b : i8 (f8E4M3FN) + llvm.return +} + +// ----- + llvm.func @nvvm_prefetch_L1_with_evict_priority(%global_ptr: !llvm.ptr<1>) { // expected-error @below {{cache eviction priority supported only for cache level L2}} nvvm.prefetch level = L1, evict_priority = evict_last, %global_ptr : !llvm.ptr<1> @@ -559,3 +567,25 @@ llvm.func @clusterlaunchcontrol_query_cancel_get_first_cta_id_invalid_return_typ %res = nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_x, %try_cancel_response : i1 llvm.return } + +// ----- + +// Test that ensures invalid row/col layouts for matrices A and B are not accepted +llvm.func @nvvm_mma_m16n8k32_s4_s4(%a0 : i32, %a1 : i32, %b0 : i32, %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32,i32,i32,i32)> { + // expected-error@+1 {{Only m8n8k4 with f16 supports other layouts.}} + %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3] + {layoutA = #nvvm.mma_layout<col>, layoutB = #nvvm.mma_layout<col>, + multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, + intOverflowBehavior=#nvvm.mma_int_overflow<satfinite>, + shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> + llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> +} + +// ----- + +// Test for range validation - invalid range where lower == upper but not at extremes +func.func @invalid_range_equal_bounds() { + // expected-error @below {{invalid range attribute: Lower == Upper, but they aren't min (0) or max (4294967295) value! This is an invalid constant range.}} + %0 = nvvm.read.ptx.sreg.warpsize range <i32, 32, 32> : i32 + return +} diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index 00a479d..594ae48 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -152,6 +152,10 @@ llvm.func @nvvm_special_regs() -> i32 { %74 = nvvm.read.ptx.sreg.lanemask.ge : i32 //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.lanemask.gt %75 = nvvm.read.ptx.sreg.lanemask.gt : i32 + // CHECK: %76 = call range(i32 0, 0) i32 @llvm.nvvm.read.ptx.sreg.tid.x() + %76 = nvvm.read.ptx.sreg.tid.x range <i32, 0, 0> : i32 + // CHECK: %77 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x() + %77 = nvvm.read.ptx.sreg.tid.x range <i32, 4294967295, 4294967295> : i32 llvm.return %1 : i32 } diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir index fdd2c91..6536fac 100644 --- a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir @@ -276,6 +276,20 @@ llvm.func @rocdl.s.wait.expcnt() { llvm.return } +llvm.func @rocdl.s.wait.asynccnt() { + // CHECK-LABEL: rocdl.s.wait.asynccnt + // CHECK-NEXT: call void @llvm.amdgcn.s.wait.asynccnt(i16 0) + rocdl.s.wait.asynccnt 0 + llvm.return +} + +llvm.func @rocdl.s.wait.tensorcnt() { + // CHECK-LABEL: rocdl.s.wait.tensorcnt + // CHECK-NEXT: call void @llvm.amdgcn.s.wait.tensorcnt(i16 0) + rocdl.s.wait.tensorcnt 0 + llvm.return +} + llvm.func @rocdl.setprio() { // CHECK: call void @llvm.amdgcn.s.setprio(i16 0) rocdl.s.setprio 0 diff --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt index 9187998..c37671a 100644 --- a/mlir/test/lib/Analysis/CMakeLists.txt +++ b/mlir/test/lib/Analysis/CMakeLists.txt @@ -17,6 +17,7 @@ add_mlir_library(MLIRTestAnalysis DataFlow/TestDenseForwardDataFlowAnalysis.cpp DataFlow/TestLivenessAnalysis.cpp DataFlow/TestSparseBackwardDataFlowAnalysis.cpp + DataFlow/TestStridedMetadataRangeAnalysis.cpp EXCLUDE_FROM_LIBMLIR diff --git a/mlir/test/lib/Analysis/DataFlow/TestStridedMetadataRangeAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestStridedMetadataRangeAnalysis.cpp new file mode 100644 index 0000000..6ac09fd --- /dev/null +++ b/mlir/test/lib/Analysis/DataFlow/TestStridedMetadataRangeAnalysis.cpp @@ -0,0 +1,86 @@ +//===- TestStridedMetadataRangeAnalysis.cpp - Test strided md analysis ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" +#include "mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h" +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace mlir::dataflow; + +static void printAnalysisResults(DataFlowSolver &solver, Operation *op, + raw_ostream &os) { + // Collect the strided metadata of the op results. + SmallVector<std::pair<unsigned, const StridedMetadataRangeLattice *>> results; + for (OpResult result : op->getResults()) { + const auto *state = solver.lookupState<StridedMetadataRangeLattice>(result); + // Skip the result if it's uninitialized. + if (!state || state->getValue().isUninitialized()) + continue; + + // Skip the result if the range is empty. + const mlir::StridedMetadataRange &md = state->getValue(); + if (md.getOffsets().empty() && md.getSizes().empty() && + md.getStrides().empty()) + continue; + results.push_back({result.getResultNumber(), state}); + } + + // Early exit if there's no metadata to print. + if (results.empty()) + return; + + // Print the metadata. + os << "Op: " << OpWithFlags(op, OpPrintingFlags().skipRegions()) << "\n"; + for (auto [idx, state] : results) + os << " result[" << idx << "]: " << state->getValue() << "\n"; + os << "\n"; +} + +namespace { +struct TestStridedMetadataRangeAnalysisPass + : public PassWrapper<TestStridedMetadataRangeAnalysisPass, + OperationPass<>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + TestStridedMetadataRangeAnalysisPass) + + StringRef getArgument() const override { + return "test-strided-metadata-range-analysis"; + } + void runOnOperation() override { + Operation *op = getOperation(); + + DataFlowSolver solver; + solver.load<DeadCodeAnalysis>(); + solver.load<SparseConstantPropagation>(); + solver.load<IntegerRangeAnalysis>(); + solver.load<StridedMetadataRangeAnalysis>(); + if (failed(solver.initializeAndRun(op))) + return signalPassFailure(); + + op->walk( + [&](Operation *op) { printAnalysisResults(solver, op, llvm::errs()); }); + } +}; +} // end anonymous namespace + +namespace mlir { +namespace test { +void registerTestStridedMetadataRangeAnalysisPass() { + PassRegistration<TestStridedMetadataRangeAnalysisPass>(); +} +} // end namespace test +} // end namespace mlir diff --git a/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp b/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp index 1e2d4a7..4069a74 100644 --- a/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp +++ b/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp @@ -11,11 +11,25 @@ #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" +#include "TestAttributes.h" // TestTensorEncodingAttr, TestMemRefLayoutAttr +#include "TestDialect.h" + using namespace mlir; namespace { +MemRefLayoutAttrInterface +getMemRefLayoutForTensorEncoding(RankedTensorType tensorType) { + if (auto encoding = dyn_cast_if_present<test::TestTensorEncodingAttr>( + tensorType.getEncoding())) { + return cast<MemRefLayoutAttrInterface>(test::TestMemRefLayoutAttr::get( + tensorType.getContext(), encoding.getDummy())); + } + return {}; +} + struct TestOneShotModuleBufferizePass : public PassWrapper<TestOneShotModuleBufferizePass, OperationPass<>> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneShotModuleBufferizePass) @@ -25,6 +39,7 @@ struct TestOneShotModuleBufferizePass : PassWrapper(pass) {} void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<test::TestDialect>(); registry.insert<bufferization::BufferizationDialect>(); } StringRef getArgument() const final { @@ -41,6 +56,17 @@ struct TestOneShotModuleBufferizePass bufferization::OneShotBufferizationOptions opt; opt.bufferizeFunctionBoundaries = true; + opt.functionArgTypeConverterFn = + [&](bufferization::TensorLikeType tensor, Attribute memSpace, + func::FuncOp, const bufferization::BufferizationOptions &) { + assert(isa<RankedTensorType>(tensor) && "tests only builtin tensors"); + auto tensorType = cast<RankedTensorType>(tensor); + auto layout = getMemRefLayoutForTensorEncoding(tensorType); + return cast<bufferization::BufferLikeType>( + MemRefType::get(tensorType.getShape(), + tensorType.getElementType(), layout, memSpace)); + }; + bufferization::BufferizationState bufferizationState; if (failed(bufferization::runOneShotModuleBufferize(getOperation(), opt, diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp index 727c84c..8c5c8e8 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -276,10 +276,8 @@ void TestLinalgTransforms::runOnOperation() { Operation *consumer = opOperand->getOwner(); // If we have a pack/unpack consumer and a producer that has multiple // uses, do not apply the folding patterns. - if (isa<linalg::PackOp, linalg::UnPackOp>(consumer) && - isa<TilingInterface>(producer) && !producer->hasOneUse()) - return false; - return true; + return !(isa<linalg::PackOp, linalg::UnPackOp>(consumer) && + isa<TilingInterface>(producer) && !producer->hasOneUse()); }; applyFoldIntoPackAndUnpackPatterns(rootOp, controlFn); } diff --git a/mlir/test/lib/Dialect/OpenACC/CMakeLists.txt b/mlir/test/lib/Dialect/OpenACC/CMakeLists.txt index f84055d..1e59338 100644 --- a/mlir/test/lib/Dialect/OpenACC/CMakeLists.txt +++ b/mlir/test/lib/Dialect/OpenACC/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_library(MLIROpenACCTestPasses TestOpenACC.cpp TestPointerLikeTypeInterface.cpp + TestRecipePopulate.cpp EXCLUDE_FROM_LIBMLIR ) diff --git a/mlir/test/lib/Dialect/OpenACC/TestOpenACC.cpp b/mlir/test/lib/Dialect/OpenACC/TestOpenACC.cpp index 9886240..bea21b9 100644 --- a/mlir/test/lib/Dialect/OpenACC/TestOpenACC.cpp +++ b/mlir/test/lib/Dialect/OpenACC/TestOpenACC.cpp @@ -15,9 +15,13 @@ namespace test { // Forward declarations of individual test pass registration functions void registerTestPointerLikeTypeInterfacePass(); +void registerTestRecipePopulatePass(); // Unified registration function for all OpenACC tests -void registerTestOpenACC() { registerTestPointerLikeTypeInterfacePass(); } +void registerTestOpenACC() { + registerTestPointerLikeTypeInterfacePass(); + registerTestRecipePopulatePass(); +} } // namespace test } // namespace mlir diff --git a/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp b/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp index 85f9283..027b0a1 100644 --- a/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp +++ b/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp @@ -196,13 +196,15 @@ void TestPointerLikeTypeInterfacePass::testGenAllocate( newBuilder.setInsertionPointAfter(op); // Call the genAllocate API + bool needsFree = false; Value allocRes = pointerType.genAllocate(newBuilder, loc, "test_alloc", - result.getType(), result); + result.getType(), result, needsFree); if (allocRes) { llvm::errs() << "Successfully generated alloc for operation: "; op->print(llvm::errs()); llvm::errs() << "\n"; + llvm::errs() << "\tneeds free: " << (needsFree ? "true" : "false") << "\n"; // Print all operations that were inserted for (Operation *insertedOp : tracker.insertedOps) { @@ -230,8 +232,8 @@ void TestPointerLikeTypeInterfacePass::testGenFree(Operation *op, Value result, // Call the genFree API auto typedResult = cast<TypedValue<PointerLikeType>>(result); - bool success = - pointerType.genFree(newBuilder, loc, typedResult, result.getType()); + bool success = pointerType.genFree(newBuilder, loc, typedResult, result, + result.getType()); if (success) { llvm::errs() << "Successfully generated free for operation: "; diff --git a/mlir/test/lib/Dialect/OpenACC/TestRecipePopulate.cpp b/mlir/test/lib/Dialect/OpenACC/TestRecipePopulate.cpp new file mode 100644 index 0000000..35f092c --- /dev/null +++ b/mlir/test/lib/Dialect/OpenACC/TestRecipePopulate.cpp @@ -0,0 +1,110 @@ +//===- TestRecipePopulate.cpp - Test Recipe Population -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains test passes for testing the createAndPopulate methods +// of the recipe operations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/Builders.h" +#include "mlir/Pass/Pass.h" +#include "llvm/Support/CommandLine.h" + +using namespace mlir; +using namespace mlir::acc; + +namespace { + +struct TestRecipePopulatePass + : public PassWrapper<TestRecipePopulatePass, OperationPass<ModuleOp>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRecipePopulatePass) + + TestRecipePopulatePass() = default; + TestRecipePopulatePass(const TestRecipePopulatePass &pass) + : PassWrapper(pass) { + recipeType = pass.recipeType; + } + + Pass::Option<std::string> recipeType{ + *this, "recipe-type", + llvm::cl::desc("Recipe type: private or firstprivate"), + llvm::cl::init("private")}; + + StringRef getArgument() const override { return "test-acc-recipe-populate"; } + + StringRef getDescription() const override { + return "Test OpenACC recipe population"; + } + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<acc::OpenACCDialect>(); + registry.insert<arith::ArithDialect>(); + registry.insert<memref::MemRefDialect>(); + } +}; + +void TestRecipePopulatePass::runOnOperation() { + auto module = getOperation(); + OpBuilder builder(&getContext()); + + // Collect all test variables + SmallVector<std::tuple<Operation *, Value, std::string>> testVars; + + module.walk([&](Operation *op) { + if (auto varName = op->getAttrOfType<StringAttr>("test.var")) { + for (auto result : op->getResults()) { + testVars.push_back({op, result, varName.str()}); + } + } + }); + + // Generate recipes at module level + builder.setInsertionPoint(&module.getBodyRegion().front(), + module.getBodyRegion().front().begin()); + + for (auto [op, var, varName] : testVars) { + Location loc = op->getLoc(); + + std::string recipeName = recipeType.getValue() + "_" + varName; + ValueRange bounds; // No bounds for memref tests + + if (recipeType == "private") { + auto recipe = PrivateRecipeOp::createAndPopulate( + builder, loc, recipeName, var.getType(), varName, bounds); + + if (!recipe) { + op->emitError("Failed to create private recipe for ") << varName; + } + } else if (recipeType == "firstprivate") { + auto recipe = FirstprivateRecipeOp::createAndPopulate( + builder, loc, recipeName, var.getType(), varName, bounds); + + if (!recipe) { + op->emitError("Failed to create firstprivate recipe for ") << varName; + } + } + } +} + +} // namespace + +namespace mlir { +namespace test { + +void registerTestRecipePopulatePass() { + PassRegistration<TestRecipePopulatePass>(); +} + +} // namespace test +} // namespace mlir diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td index 5685004..9e7e4f8 100644 --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -22,6 +22,7 @@ include "mlir/IR/AttrTypeBase.td" include "mlir/IR/BuiltinAttributeInterfaces.td" include "mlir/IR/EnumAttr.td" include "mlir/IR/OpAsmInterface.td" +include "mlir/IR/TensorEncoding.td" // All of the attributes will extend this class. class Test_Attr<string name, list<Trait> traits = []> @@ -439,4 +440,20 @@ def TestCustomStorageCtorAttr : Test_Attr<"TestCustomStorageCtorAttr"> { let hasStorageCustomConstructor = 1; } +def TestTensorEncodingAttr : Test_Attr<"TestTensorEncoding", + [DeclareAttrInterfaceMethods<VerifiableTensorEncoding>]> { + let mnemonic = "tensor_encoding"; + + let parameters = (ins "mlir::StringAttr":$dummy); + let assemblyFormat = "`<` $dummy `>`"; +} + +def TestMemRefLayoutAttr : Test_Attr<"TestMemRefLayout", + [DeclareAttrInterfaceMethods<MemRefLayoutAttrInterface>]> { + let mnemonic = "memref_layout"; + + let parameters = (ins "mlir::StringAttr":$dummy); + let assemblyFormat = "`<` $dummy `>`"; +} + #endif // TEST_ATTRDEFS diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp index fe1e916..9db7b01 100644 --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -542,6 +542,24 @@ test::detail::TestCustomStorageCtorAttrAttrStorage::construct( } //===----------------------------------------------------------------------===// +// TestTensorEncodingAttr +//===----------------------------------------------------------------------===// + +::llvm::LogicalResult TestTensorEncodingAttr::verifyEncoding( + mlir::ArrayRef<int64_t> shape, mlir::Type elementType, + llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) const { + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// TestMemRefLayoutAttr +//===----------------------------------------------------------------------===// + +mlir::AffineMap TestMemRefLayoutAttr::getAffineMap() const { + return mlir::AffineMap::getMultiDimIdentityMap(1, getContext()); +} + +//===----------------------------------------------------------------------===// // TestDialect //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h index 778d84fa..0ad5ab6 100644 --- a/mlir/test/lib/Dialect/Test/TestAttributes.h +++ b/mlir/test/lib/Dialect/Test/TestAttributes.h @@ -24,6 +24,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/DialectResourceBlobManager.h" +#include "mlir/IR/TensorEncoding.h" // generated files require above includes to come first #include "TestAttrInterfaces.h.inc" diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h index f2adca6..bcf3b55d 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -18,6 +18,7 @@ #include "TestInterfaces.h" #include "TestTypes.h" #include "mlir/Bytecode/BytecodeImplementation.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/DLTI/Traits.h" #include "mlir/Dialect/Func/IR/FuncOps.h" diff --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td index 2b5491f..37a263f 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.td +++ b/mlir/test/lib/Dialect/Test/TestDialect.td @@ -24,7 +24,10 @@ def Test_Dialect : Dialect { let useDefaultTypePrinterParser = 0; let useDefaultAttributePrinterParser = 1; let isExtensible = 1; - let dependentDialects = ["::mlir::DLTIDialect"]; + let dependentDialects = [ + "::mlir::DLTIDialect", + "::mlir::bufferization::BufferizationDialect" + ]; let discardableAttrs = (ins "mlir::IntegerAttr":$discardable_attr_key, "SimpleAAttr":$other_discardable_attr_key diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp index 53055fe..b211e24 100644 --- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp @@ -1425,6 +1425,39 @@ TestMultiSlotAlloca::handleDestructuringComplete( return createNewMultiAllocaWithoutSlot(slot, builder, *this); } +namespace { +/// Returns test dialect's memref layout for test dialect's tensor encoding when +/// applicable. +MemRefLayoutAttrInterface +getMemRefLayoutForTensorEncoding(RankedTensorType tensorType) { + if (auto encoding = + dyn_cast<test::TestTensorEncodingAttr>(tensorType.getEncoding())) { + return cast<MemRefLayoutAttrInterface>(test::TestMemRefLayoutAttr::get( + tensorType.getContext(), encoding.getDummy())); + } + return {}; +} + +/// Auxiliary bufferization function for test and builtin tensors. +bufferization::BufferLikeType +convertTensorToBuffer(mlir::Operation *op, + const bufferization::BufferizationOptions &options, + bufferization::TensorLikeType tensorLike) { + auto buffer = + *tensorLike.getBufferType(options, [&]() { return op->emitError(); }); + if (auto memref = dyn_cast<MemRefType>(buffer)) { + // Note: For the sake of testing, we want to ensure that encoding -> layout + // bufferization happens. This is currently achieved manually. + auto layout = + getMemRefLayoutForTensorEncoding(cast<RankedTensorType>(tensorLike)); + return cast<bufferization::BufferLikeType>( + MemRefType::get(memref.getShape(), memref.getElementType(), layout, + memref.getMemorySpace())); + } + return buffer; +} +} // namespace + ::mlir::LogicalResult test::TestDummyTensorOp::bufferize( ::mlir::RewriterBase &rewriter, const ::mlir::bufferization::BufferizationOptions &options, @@ -1435,8 +1468,8 @@ TestMultiSlotAlloca::handleDestructuringComplete( return failure(); const auto outType = getOutput().getType(); - const auto bufferizedOutType = test::TestMemrefType::get( - getContext(), outType.getShape(), outType.getElementType(), nullptr); + const auto bufferizedOutType = + convertTensorToBuffer(getOperation(), options, outType); // replace op with memref analogy auto dummyMemrefOp = test::TestDummyMemrefOp::create( rewriter, getLoc(), bufferizedOutType, *buffer); @@ -1470,13 +1503,12 @@ TestMultiSlotAlloca::handleDestructuringComplete( mlir::FailureOr<mlir::bufferization::BufferLikeType> test::TestCreateTensorOp::getBufferType( - mlir::Value value, const mlir::bufferization::BufferizationOptions &, + mlir::Value value, const mlir::bufferization::BufferizationOptions &options, const mlir::bufferization::BufferizationState &, llvm::SmallVector<::mlir::Value> &) { - const auto type = dyn_cast<test::TestTensorType>(value.getType()); + const auto type = dyn_cast<bufferization::TensorLikeType>(value.getType()); if (type == nullptr) return failure(); - return cast<mlir::bufferization::BufferLikeType>(test::TestMemrefType::get( - getContext(), type.getShape(), type.getElementType(), nullptr)); + return convertTensorToBuffer(getOperation(), options, type); } diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 6329d61..05a33cf 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -32,6 +32,7 @@ include "mlir/Interfaces/MemorySlotInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ValueBoundsOpInterface.td" include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td" +include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td" // Include the attribute definitions. include "TestAttrDefs.td" @@ -2335,7 +2336,7 @@ def SideEffectWithRegionOp : TEST_Op<"side_effect_with_region_op", } //===----------------------------------------------------------------------===// -// Copy Operation Test +// Copy Operation Test //===----------------------------------------------------------------------===// def CopyOp : TEST_Op<"copy", []> { @@ -3676,10 +3677,10 @@ def TestDummyTensorOp : TEST_Op<"dummy_tensor_op", ["bufferize", "bufferizesToMemoryRead", "bufferizesToMemoryWrite", "getAliasingValues"]>]> { let arguments = (ins - Arg<TestTensorType>:$input + Arg<Bufferization_TensorLikeTypeInterface>:$input ); let results = (outs - Arg<TestTensorType>:$output + Arg<Bufferization_TensorLikeTypeInterface>:$output ); let extraClassDefinition = [{ @@ -3701,10 +3702,10 @@ def TestDummyTensorOp : TEST_Op<"dummy_tensor_op", def TestDummyMemrefOp : TEST_Op<"dummy_memref_op", []> { let arguments = (ins - Arg<TestMemrefType>:$input + Arg<Bufferization_BufferLikeTypeInterface>:$input ); let results = (outs - Arg<TestMemrefType>:$output + Arg<Bufferization_BufferLikeTypeInterface>:$output ); } @@ -3714,7 +3715,7 @@ def TestCreateTensorOp : TEST_Op<"create_tensor_op", "bufferizesToMemoryWrite", "getAliasingValues", "bufferizesToAllocation"]>]> { let arguments = (ins); - let results = (outs Arg<TestTensorType>:$output); + let results = (outs Arg<Bufferization_TensorLikeTypeInterface>:$output); let extraClassDefinition = [{ bool test::TestCreateTensorOp::bufferizesToMemoryRead(::mlir::OpOperand&, const ::mlir::bufferization::AnalysisState&) { @@ -3738,7 +3739,7 @@ def TestCreateTensorOp : TEST_Op<"create_tensor_op", def TestCreateMemrefOp : TEST_Op<"create_memref_op"> { let arguments = (ins); - let results = (outs Arg<TestMemrefType>:$output); + let results = (outs Arg<Bufferization_BufferLikeTypeInterface>:$output); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp index 97fc699..496f18b 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -938,10 +938,10 @@ public: // These are automatically generated by ODS but are not used as the Transform // dialect uses a different dispatch mechanism to support dialect extensions. -LLVM_ATTRIBUTE_UNUSED static OptionalParseResult +[[maybe_unused]] static OptionalParseResult generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value); -LLVM_ATTRIBUTE_UNUSED static LogicalResult -generatedTypePrinter(Type def, AsmPrinter &printer); +[[maybe_unused]] static LogicalResult generatedTypePrinter(Type def, + AsmPrinter &printer); #define GET_TYPEDEF_CLASSES #include "TestTransformDialectExtensionTypes.cpp.inc" diff --git a/mlir/test/lib/Pass/TestRemarksPass.cpp b/mlir/test/lib/Pass/TestRemarksPass.cpp index 3b25686..5ca2d1a 100644 --- a/mlir/test/lib/Pass/TestRemarksPass.cpp +++ b/mlir/test/lib/Pass/TestRemarksPass.cpp @@ -43,7 +43,12 @@ public: << remark::add("This is a test missed remark") << remark::reason("because we are testing the remark pipeline") << remark::suggest("try using the remark pipeline feature"); - + mlir::remark::passed( + loc, + remark::RemarkOpts::name("test-remark").category("category-1-passed")) + << remark::add("This is a test passed remark (should be dropped)") + << remark::reason("because we are testing the remark pipeline") + << remark::suggest("try using the remark pipeline feature"); mlir::remark::passed( loc, remark::RemarkOpts::name("test-remark").category("category-1-passed")) diff --git a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll index 4e869e5..4be30d8 100644 --- a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll +++ b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll @@ -28,7 +28,7 @@ // CHECK: operation "test.op3" // CHECK: )mlir", context), std::forward<ConfigsT>(configs)...) -// CHECK: static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns(::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) { +// CHECK{LITERAL}: [[maybe_unused]] static void populateGeneratedPDLLPatterns(::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) { // CHECK-NEXT: patterns.add<GeneratedPDLLPattern0>(patterns.getContext(), configs...); // CHECK-NEXT: patterns.add<NamedPattern>(patterns.getContext(), configs...); // CHECK-NEXT: patterns.add<GeneratedPDLLPattern1>(patterns.getContext(), configs...); diff --git a/mlir/test/mlir-tblgen/cpp-class-comments.td b/mlir/test/mlir-tblgen/cpp-class-comments.td index a896888..9dcf975 100644 --- a/mlir/test/mlir-tblgen/cpp-class-comments.td +++ b/mlir/test/mlir-tblgen/cpp-class-comments.td @@ -96,17 +96,14 @@ def EncodingTrait : AttrInterface<"EncodingTrait"> { }]; let methods = [ ]; -// ATTR-INTERFACE: namespace mlir -// ATTR-INTERFACE-NEXT: namespace a -// ATTR-INTERFACE-NEXT: namespace traits +// ATTR-INTERFACE: namespace mlir::a::traits { // ATTR-INTERFACE-NEXT: /// Common trait for all layouts. // ATTR-INTERFACE-NEXT: class EncodingTrait; } def SimpleEncodingTrait : AttrInterface<"SimpleEncodingTrait"> { let cppNamespace = "a::traits"; -// ATTR-INTERFACE: namespace a { -// ATTR-INTERFACE-NEXT: namespace traits { +// ATTR-INTERFACE: namespace a::traits { // ATTR-INTERFACE-NEXT: class SimpleEncodingTrait; } @@ -116,8 +113,7 @@ def SimpleOpInterface : OpInterface<"SimpleOpInterface"> { Simple Op Interface description }]; -// OP-INTERFACE: namespace a { -// OP-INTERFACE-NEXT: namespace traits { +// OP-INTERFACE: namespace a::traits { // OP-INTERFACE-NEXT: /// Simple Op Interface description // OP-INTERFACE-NEXT: class SimpleOpInterface; } diff --git a/mlir/test/mlir-tblgen/dialect.td b/mlir/test/mlir-tblgen/dialect.td index f35ce34..9b45495 100644 --- a/mlir/test/mlir-tblgen/dialect.td +++ b/mlir/test/mlir-tblgen/dialect.td @@ -62,9 +62,14 @@ def E_SpecialNSOp : Op<E_Dialect, "special_ns_op", []> { // DEF: ::E::SPECIAL_NS::SpecialNSOp definitions // DECL-LABEL: GET_OP_CLASSES +// DECL: namespace a { // DECL: a::SomeOp declarations +// DECL: namespace BNS { // DECL: BNS::SomeOp declarations +// DECL: namespace C::CC { // DECL: ::C::CC::SomeOp declarations // DECL: DSomeOp declarations +// DECL: namespace ENS { // DECL: ENS::SomeOp declarations +// DECL: namespace E::SPECIAL_NS { // DECL: ::E::SPECIAL_NS::SpecialNSOp declarations diff --git a/mlir/test/python/dialects/gpu/dialect.py b/mlir/test/python/dialects/gpu/dialect.py index 26ee9f3..66c4018 100644 --- a/mlir/test/python/dialects/gpu/dialect.py +++ b/mlir/test/python/dialects/gpu/dialect.py @@ -1,6 +1,7 @@ # RUN: %PYTHON %s | FileCheck %s from mlir.ir import * +import mlir.ir as ir import mlir.dialects.gpu as gpu import mlir.dialects.gpu.passes from mlir.passmanager import * @@ -64,3 +65,95 @@ def testObjectAttr(): # CHECK: #gpu.object<#nvvm.target, kernels = <[#gpu.kernel_metadata<"kernel", () -> ()>]>, "BC\C0\DE5\14\00\00\05\00\00\00b\0C0$MY\BEf"> print(o) assert o.kernels == kernelTable + + +# CHECK-LABEL: testGPUFuncOp +@run +def testGPUFuncOp(): + assert gpu.GPUFuncOp.__doc__ is not None + module = Module.create() + with InsertionPoint(module.body): + gpu_module_name = StringAttr.get("gpu_module") + gpumodule = gpu.GPUModuleOp(gpu_module_name) + block = gpumodule.bodyRegion.blocks.append() + + def builder(func: gpu.GPUFuncOp) -> None: + gpu.GlobalIdOp(gpu.Dimension.x) + gpu.ReturnOp([]) + + with InsertionPoint(block): + name = StringAttr.get("kernel0") + func_type = ir.FunctionType.get(inputs=[], results=[]) + type_attr = TypeAttr.get(func_type) + func = gpu.GPUFuncOp(type_attr, name) + func.attributes["sym_name"] = name + func.attributes["gpu.kernel"] = UnitAttr.get() + + try: + func.entry_block + assert False, "Expected RuntimeError" + except RuntimeError as e: + assert ( + str(e) + == "Entry block does not exist for kernel0. Do you need to call the add_entry_block() method on this GPUFuncOp?" + ) + + block = func.add_entry_block() + with InsertionPoint(block): + builder(func) + + try: + func.add_entry_block() + assert False, "Expected RuntimeError" + except RuntimeError as e: + assert str(e) == "Entry block already exists for kernel0" + + func = gpu.GPUFuncOp( + func_type, + sym_name="kernel1", + kernel=True, + body_builder=builder, + known_block_size=[1, 2, 3], + known_grid_size=DenseI32ArrayAttr.get([4, 5, 6]), + ) + + assert func.name.value == "kernel1" + assert func.function_type.value == func_type + assert func.arg_attrs == None + assert func.res_attrs == None + assert func.arguments == [] + assert func.entry_block == func.body.blocks[0] + assert func.is_kernel + assert func.known_block_size == DenseI32ArrayAttr.get( + [1, 2, 3] + ), func.known_block_size + assert func.known_grid_size == DenseI32ArrayAttr.get( + [4, 5, 6] + ), func.known_grid_size + + func = gpu.GPUFuncOp( + func_type, + sym_name="non_kernel_func", + body_builder=builder, + ) + assert not func.is_kernel + assert func.known_block_size is None + assert func.known_grid_size is None + + print(module) + + # CHECK: gpu.module @gpu_module + # CHECK: gpu.func @kernel0() kernel { + # CHECK: %[[VAL_0:.*]] = gpu.global_id x + # CHECK: gpu.return + # CHECK: } + # CHECK: gpu.func @kernel1() kernel attributes + # CHECK-SAME: known_block_size = array<i32: 1, 2, 3> + # CHECK-SAME: known_grid_size = array<i32: 4, 5, 6> + # CHECK: %[[VAL_0:.*]] = gpu.global_id x + # CHECK: gpu.return + # CHECK: } + # CHECK: gpu.func @non_kernel_func() { + # CHECK: %[[VAL_0:.*]] = gpu.global_id x + # CHECK: gpu.return + # CHECK: } diff --git a/mlir/test/python/dialects/openacc.py b/mlir/test/python/dialects/openacc.py new file mode 100644 index 0000000..8f2142a --- /dev/null +++ b/mlir/test/python/dialects/openacc.py @@ -0,0 +1,171 @@ +# RUN: %PYTHON %s | FileCheck %s +from unittest import result +from mlir.ir import ( + Context, + FunctionType, + Location, + Module, + InsertionPoint, + IntegerType, + IndexType, + MemRefType, + F32Type, + Block, + ArrayAttr, + Attribute, + UnitAttr, + StringAttr, + DenseI32ArrayAttr, + ShapedType, +) +from mlir.dialects import openacc, func, arith, memref +from mlir.extras import types + + +def run(f): + print("\n// TEST:", f.__name__) + with Context(), Location.unknown(): + f() + return f + + +@run +def testParallelMemcpy(): + module = Module.create() + + dynamic = ShapedType.get_dynamic_size() + memref_f32_1d_any = MemRefType.get([dynamic], types.f32()) + + with InsertionPoint(module.body): + function_type = FunctionType.get( + [memref_f32_1d_any, memref_f32_1d_any, types.i64()], [] + ) + f = func.FuncOp( + type=function_type, + name="memcpy_idiom", + ) + f.attributes["sym_visibility"] = StringAttr.get("public") + + with InsertionPoint(f.add_entry_block()): + c1024 = arith.ConstantOp(types.i32(), 1024) + c128 = arith.ConstantOp(types.i32(), 128) + + arg0, arg1, arg2 = f.arguments + + copied = openacc.copyin( + acc_var=arg0.type, + var=arg0, + var_type=types.f32(), + bounds=[], + async_operands=[], + implicit=False, + structured=True, + ) + created = openacc.create_( + acc_var=arg1.type, + var=arg1, + var_type=types.f32(), + bounds=[], + async_operands=[], + implicit=False, + structured=True, + ) + + parallel_op = openacc.ParallelOp( + asyncOperands=[], + waitOperands=[], + numGangs=[c1024], + numWorkers=[], + vectorLength=[c128], + reductionOperands=[], + privateOperands=[], + firstprivateOperands=[], + dataClauseOperands=[], + ) + + # Set required device_type and segment attributes to satisfy verifier + acc_device_none = ArrayAttr.get([Attribute.parse("#acc.device_type<none>")]) + parallel_op.numGangsDeviceType = acc_device_none + parallel_op.numGangsSegments = DenseI32ArrayAttr.get([1]) + parallel_op.vectorLengthDeviceType = acc_device_none + + parallel_block = Block.create_at_start(parent=parallel_op.region, arg_types=[]) + + with InsertionPoint(parallel_block): + c0 = arith.ConstantOp(types.i64(), 0) + c1 = arith.ConstantOp(types.i64(), 1) + + loop_op = openacc.LoopOp( + results_=[], + lowerbound=[c0], + upperbound=[f.arguments[2]], + step=[c1], + gangOperands=[], + workerNumOperands=[], + vectorOperands=[], + tileOperands=[], + cacheOperands=[], + privateOperands=[], + reductionOperands=[], + firstprivateOperands=[], + ) + + # Set loop attributes: gang and independent on device_type<none> + acc_device_none = ArrayAttr.get([Attribute.parse("#acc.device_type<none>")]) + loop_op.gang = acc_device_none + loop_op.independent = acc_device_none + + loop_block = Block.create_at_start( + parent=loop_op.region, arg_types=[types.i64()] + ) + + with InsertionPoint(loop_block): + idx = arith.index_cast(out=IndexType.get(), in_=loop_block.arguments[0]) + val = memref.load(memref=copied, indices=[idx]) + memref.store(value=val, memref=created, indices=[idx]) + openacc.YieldOp([]) + + openacc.YieldOp([]) + + deleted = openacc.delete( + acc_var=copied, + bounds=[], + async_operands=[], + implicit=False, + structured=True, + ) + copied = openacc.copyout( + acc_var=created, + var=arg1, + var_type=types.f32(), + bounds=[], + async_operands=[], + implicit=False, + structured=True, + ) + func.ReturnOp([]) + + print(module) + + # CHECK: TEST: testParallelMemcpy + # CHECK-LABEL: func.func public @memcpy_idiom( + # CHECK-SAME: %[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: i64) { + # CHECK: %[[CONSTANT_0:.*]] = arith.constant 1024 : i32 + # CHECK: %[[CONSTANT_1:.*]] = arith.constant 128 : i32 + # CHECK: %[[COPYIN_0:.*]] = acc.copyin varPtr(%[[ARG0]] : memref<?xf32>) -> memref<?xf32> + # CHECK: %[[CREATE_0:.*]] = acc.create varPtr(%[[ARG1]] : memref<?xf32>) -> memref<?xf32> + # CHECK: acc.parallel num_gangs({%[[CONSTANT_0]] : i32}) vector_length(%[[CONSTANT_1]] : i32) { + # CHECK: %[[CONSTANT_2:.*]] = arith.constant 0 : i64 + # CHECK: %[[CONSTANT_3:.*]] = arith.constant 1 : i64 + # CHECK: acc.loop gang control(%[[VAL_0:.*]] : i64) = (%[[CONSTANT_2]] : i64) to (%[[ARG2]] : i64) step (%[[CONSTANT_3]] : i64) { + # CHECK: %[[INDEX_CAST_0:.*]] = arith.index_cast %[[VAL_0]] : i64 to index + # CHECK: %[[LOAD_0:.*]] = memref.load %[[COPYIN_0]]{{\[}}%[[INDEX_CAST_0]]] : memref<?xf32> + # CHECK: memref.store %[[LOAD_0]], %[[CREATE_0]]{{\[}}%[[INDEX_CAST_0]]] : memref<?xf32> + # CHECK: acc.yield + # CHECK: } attributes {independent = [#acc.device_type<none>]} + # CHECK: acc.yield + # CHECK: } + # CHECK: acc.delete accPtr(%[[COPYIN_0]] : memref<?xf32>) + # CHECK: acc.copyout accPtr(%[[CREATE_0]] : memref<?xf32>) to varPtr(%[[ARG1]] : memref<?xf32>) + # CHECK: return + # CHECK: } diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py index 8785d6d..d6b70dc 100644 --- a/mlir/test/python/dialects/transform_structured_ext.py +++ b/mlir/test/python/dialects/transform_structured_ext.py @@ -109,13 +109,29 @@ def testFuseOpCompact(target): ) # CHECK-LABEL: TEST: testFuseOpCompact # CHECK: transform.sequence - # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8] - # CHECK-SAME: interchange [0, 1] apply_cleanup = true + # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}} tile_sizes [4, 8] + # CHECK-SAME: interchange [0, 1] {apply_cleanup} # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) @run @create_sequence +def testFuseOpCompactForall(target): + structured.FuseOp( + target, + tile_sizes=[4, 8], + apply_cleanup=True, + use_forall=True, + ) + # CHECK-LABEL: TEST: testFuseOpCompact + # CHECK: transform.sequence + # CHECK: %{{.+}}, %{{.+}} = transform.structured.fuse %{{.*}} tile_sizes [4, 8] + # CHECK-SAME: {apply_cleanup, use_forall} + # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + +@run +@create_sequence def testFuseOpNoArg(target): structured.FuseOp(target) # CHECK-LABEL: TEST: testFuseOpNoArg @@ -126,13 +142,51 @@ def testFuseOpNoArg(target): @run @create_sequence +def testFuseOpParams(target): + structured.FuseOp( + target, + tile_sizes=[constant_param(4), Attribute.parse("8")], + tile_interchange=[constant_param(0), Attribute.parse("1")], + ) + # CHECK-LABEL: TEST: testFuseOpParams + # CHECK: transform.sequence + # CHECK-DAG: %[[P:.*]] = transform.param.constant 4 + # CHECK-DAG: %[[I:.*]] = transform.param.constant 0 + # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse + # CHECK-SAME: tile_sizes [%[[P]], 8] + # CHECK-SAME: interchange [%[[I]], 1] + # CHECK-SAME: (!transform.any_op, !transform.param<i64>, !transform.param<i64>) -> (!transform.any_op, !transform.any_op, !transform.any_op) + + +@run +@create_sequence +def testFuseOpHandles(target): + size1 = structured.MatchOp.match_op_names(target, ["arith.constant"]) + ichange1 = structured.MatchOp.match_op_names(target, ["arith.constant"]) + structured.FuseOp( + target, + tile_sizes=[size1, 8], + tile_interchange=[ichange1, 1], + ) + # CHECK-LABEL: TEST: testFuseOpHandles + # CHECK: transform.sequence + # CHECK: %[[H:.*]] = transform.structured.match + # CHECK: %[[I:.*]] = transform.structured.match + # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse + # CHECK-SAME: tile_sizes [%[[H]], 8] + # CHECK-SAME: interchange [%[[I]], 1] + # CHECK-SAME: (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + + +@run +@create_sequence def testFuseOpAttributes(target): attr = DenseI64ArrayAttr.get([4, 8]) ichange = DenseI64ArrayAttr.get([0, 1]) structured.FuseOp(target, tile_sizes=attr, tile_interchange=ichange) # CHECK-LABEL: TEST: testFuseOpAttributes # CHECK: transform.sequence - # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8] + # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}} tile_sizes [4, 8] # CHECK-SAME: interchange [0, 1] # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index cb4cfc8c..1d4ede1 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -569,12 +569,30 @@ def testOperationAttributes(): # CHECK: Attribute value b'text' print(f"Attribute value {sattr.value_bytes}") + # Python dict-style iteration # We don't know in which order the attributes are stored. - # CHECK-DAG: NamedAttribute(dependent="text") - # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64) - # CHECK-DAG: NamedAttribute(some.attribute=1 : i8) - for attr in op.attributes: - print(str(attr)) + # CHECK-DAG: dependent + # CHECK-DAG: other.attribute + # CHECK-DAG: some.attribute + for name in op.attributes: + print(name) + + # Basic dict-like introspection + # CHECK: True + print("some.attribute" in op.attributes) + # CHECK: False + print("missing" in op.attributes) + # CHECK: Keys: ['dependent', 'other.attribute', 'some.attribute'] + print("Keys:", sorted(op.attributes.keys())) + # CHECK: Values count 3 + print("Values count", len(op.attributes.values())) + # CHECK: Items count 3 + print("Items count", len(op.attributes.items())) + + # Dict() conversion test + d = {k: v.value for k, v in dict(op.attributes).items()} + # CHECK: Dict mapping {'dependent': 'text', 'other.attribute': 3.0, 'some.attribute': 1} + print("Dict mapping", d) # Check that exceptions are raised as expected. try: diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 6432fae..8842180 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -151,6 +151,7 @@ void registerTestSliceAnalysisPass(); void registerTestSPIRVCPURunnerPipeline(); void registerTestSPIRVFuncSignatureConversion(); void registerTestSPIRVVectorUnrolling(); +void registerTestStridedMetadataRangeAnalysisPass(); void registerTestTensorCopyInsertionPass(); void registerTestTensorLikeAndBufferLikePass(); void registerTestTensorTransforms(); @@ -299,6 +300,7 @@ void registerTestPasses() { mlir::test::registerTestSPIRVCPURunnerPipeline(); mlir::test::registerTestSPIRVFuncSignatureConversion(); mlir::test::registerTestSPIRVVectorUnrolling(); + mlir::test::registerTestStridedMetadataRangeAnalysisPass(); mlir::test::registerTestTensorCopyInsertionPass(); mlir::test::registerTestTensorLikeAndBufferLikePass(); mlir::test::registerTestTensorTransforms(); diff --git a/mlir/tools/mlir-pdll/mlir-pdll.cpp b/mlir/tools/mlir-pdll/mlir-pdll.cpp index f99dcdb..76122a0 100644 --- a/mlir/tools/mlir-pdll/mlir-pdll.cpp +++ b/mlir/tools/mlir-pdll/mlir-pdll.cpp @@ -19,6 +19,7 @@ #include "llvm/Support/InitLLVM.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ToolOutputFile.h" +#include "llvm/Support/VirtualFileSystem.h" #include <set> using namespace mlir; @@ -41,6 +42,7 @@ processBuffer(raw_ostream &os, std::unique_ptr<llvm::MemoryBuffer> chunkBuffer, bool dumpODS, std::set<std::string> *includedFiles) { llvm::SourceMgr sourceMgr; sourceMgr.setIncludeDirs(includeDirs); + sourceMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem()); sourceMgr.AddNewSourceBuffer(std::move(chunkBuffer), SMLoc()); // If we are dumping ODS information, also enable documentation to ensure the diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp index d55ad482..11bf9ce 100644 --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -20,6 +20,7 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/TableGen/CodeGenHelpers.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" @@ -701,11 +702,7 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { StringRef underlyingToSymFnName = enumInfo.getUnderlyingToSymbolFnName(); auto enumerants = enumInfo.getAllCases(); - SmallVector<StringRef, 2> namespaces; - llvm::SplitString(cppNamespace, namespaces, "::"); - - for (auto ns : namespaces) - os << "namespace " << ns << " {\n"; + llvm::NamespaceEmitter ns(os, cppNamespace); // Emit the enum class definition emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, os); @@ -766,8 +763,7 @@ public: os << formatv(attrClassDecl, enumName, attrClassName, baseAttrClassName); } - for (auto ns : llvm::reverse(namespaces)) - os << "} // namespace " << ns << "\n"; + ns.close(); // Generate a generic parser and printer for the enum. std::string qualName = @@ -790,13 +786,8 @@ static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) { static void emitEnumDef(const Record &enumDef, raw_ostream &os) { EnumInfo enumInfo(enumDef); - StringRef cppNamespace = enumInfo.getCppNamespace(); - SmallVector<StringRef, 2> namespaces; - llvm::SplitString(cppNamespace, namespaces, "::"); - - for (auto ns : namespaces) - os << "namespace " << ns << " {\n"; + llvm::NamespaceEmitter ns(os, enumInfo.getCppNamespace()); if (enumInfo.isBitEnum()) { emitSymToStrFnForBitEnum(enumDef, os); @@ -810,10 +801,6 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) { if (enumInfo.genSpecializedAttr()) emitSpecializedAttrDef(enumDef, os); - - for (auto ns : llvm::reverse(namespaces)) - os << "} // namespace " << ns << "\n"; - os << "\n"; } static bool emitEnumDefs(const RecordKeeper &records, raw_ostream &os) { diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp index 96af14d..11a2db4 100644 --- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp @@ -416,7 +416,7 @@ static void emitOneEnumToConversion(const Record *record, raw_ostream &os) { // Emit the function converting the enum attribute to its LLVM counterpart. os << formatv( - "static LLVM_ATTRIBUTE_UNUSED {0} convert{1}ToLLVM({2}::{1} value) {{\n", + "[[maybe_unused]] static {0} convert{1}ToLLVM({2}::{1} value) {{\n", llvmClass, cppClassName, cppNamespace); os << " switch (value) {\n"; @@ -444,7 +444,7 @@ static void emitOneCEnumToConversion(const Record *record, raw_ostream &os) { StringRef cppNamespace = enumAttr.getCppNamespace(); // Emit the function converting the enum attribute to its LLVM counterpart. - os << formatv("static LLVM_ATTRIBUTE_UNUSED int64_t " + os << formatv("[[maybe_unused]] static int64_t " "convert{0}ToLLVM({1}::{0} value) {{\n", cppClassName, cppNamespace); os << " switch (value) {\n"; @@ -474,7 +474,7 @@ static void emitOneEnumFromConversion(const Record *record, raw_ostream &os) { StringRef cppNamespace = enumInfo.getCppNamespace(); // Emit the function converting the enum attribute from its LLVM counterpart. - os << formatv("inline LLVM_ATTRIBUTE_UNUSED {0}::{1} convert{1}FromLLVM({2} " + os << formatv("[[maybe_unused]] inline {0}::{1} convert{1}FromLLVM({2} " "value) {{\n", cppNamespace, cppClassName, llvmClass); os << " switch (value) {\n"; @@ -509,10 +509,9 @@ static void emitOneCEnumFromConversion(const Record *record, raw_ostream &os) { StringRef cppNamespace = enumInfo.getCppNamespace(); // Emit the function converting the enum attribute from its LLVM counterpart. - os << formatv( - "inline LLVM_ATTRIBUTE_UNUSED {0}::{1} convert{1}FromLLVM(int64_t " - "value) {{\n", - cppNamespace, cppClassName); + os << formatv("[[maybe_unused]] inline {0}::{1} convert{1}FromLLVM(int64_t " + "value) {{\n", + cppNamespace, cppClassName); os << " switch (value) {\n"; for (const auto &enumerant : enumInfo.getAllCases()) { diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index daae3c7..3718648 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -4896,7 +4896,7 @@ static void emitOpClassDefs(const RecordKeeper &records, constraintPrefix); os << formatv(opCommentHeader, "Local Utility Method", "Definitions"); staticVerifierEmitter.collectOpConstraints(defs); - staticVerifierEmitter.emitOpConstraints(defs); + staticVerifierEmitter.emitOpConstraints(); // Emit the classes. emitOpClasses(records, defs, os, staticVerifierEmitter, diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp index 730b5b2..ab8d534 100644 --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -19,6 +19,7 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/TableGen/CodeGenHelpers.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" @@ -342,11 +343,7 @@ void InterfaceGenerator::emitModelDecl(const Interface &interface) { } void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) { - llvm::SmallVector<StringRef, 2> namespaces; - llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); - for (StringRef ns : namespaces) - os << "namespace " << ns << " {\n"; - + llvm::NamespaceEmitter ns(os, interface.getCppNamespace()); for (auto &method : interface.getMethods()) { os << "template<typename " << valueTemplate << ">\n"; emitCPPType(method.getReturnType(), os); @@ -442,18 +439,11 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) { method.isStatic() ? &ctx : &nonStaticMethodFmt); os << "\n}\n"; } - - for (StringRef ns : llvm::reverse(namespaces)) - os << "} // namespace " << ns << "\n"; } void InterfaceGenerator::emitInterfaceTraitDecl(const Interface &interface) { - llvm::SmallVector<StringRef, 2> namespaces; - llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); - for (StringRef ns : namespaces) - os << "namespace " << ns << " {\n"; - - os << "namespace detail {\n"; + auto cppNamespace = (interface.getCppNamespace() + "::detail").str(); + llvm::NamespaceEmitter ns(os, cppNamespace); StringRef interfaceName = interface.getName(); auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str(); @@ -504,10 +494,6 @@ void InterfaceGenerator::emitInterfaceTraitDecl(const Interface &interface) { os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n"; os << " };\n"; - os << "}// namespace detail\n"; - - for (StringRef ns : llvm::reverse(namespaces)) - os << "} // namespace " << ns << "\n"; } static void emitInterfaceDeclMethods(const Interface &interface, @@ -533,10 +519,7 @@ static void emitInterfaceDeclMethods(const Interface &interface, } void InterfaceGenerator::forwardDeclareInterface(const Interface &interface) { - llvm::SmallVector<StringRef, 2> namespaces; - llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); - for (StringRef ns : namespaces) - os << "namespace " << ns << " {\n"; + llvm::NamespaceEmitter ns(os, interface.getCppNamespace()); // Emit a forward declaration of the interface class so that it becomes usable // in the signature of its methods. @@ -545,16 +528,10 @@ void InterfaceGenerator::forwardDeclareInterface(const Interface &interface) { StringRef interfaceName = interface.getName(); os << "class " << interfaceName << ";\n"; - - for (StringRef ns : llvm::reverse(namespaces)) - os << "} // namespace " << ns << "\n"; } void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) { - llvm::SmallVector<StringRef, 2> namespaces; - llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); - for (StringRef ns : namespaces) - os << "namespace " << ns << " {\n"; + llvm::NamespaceEmitter ns(os, interface.getCppNamespace()); StringRef interfaceName = interface.getName(); auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str(); @@ -631,9 +608,6 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) { } os << "};\n"; - - for (StringRef ns : llvm::reverse(namespaces)) - os << "} // namespace " << ns << "\n"; } bool InterfaceGenerator::emitInterfaceDecls() { diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 40bc1a9..c3034bb8 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -2120,7 +2120,7 @@ static void emitRewriters(const RecordKeeper &records, raw_ostream &os) { } // Emit function to add the generated matchers to the pattern list. - os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(" + os << "[[maybe_unused]] void populateWithGenerated(" "::mlir::RewritePatternSet &patterns) {\n"; for (const auto &name : rewriterNames) { os << " patterns.add<" << name << ">(patterns.getContext());\n"; diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index 3ead2f0..ca291b5 100644 --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -259,8 +259,8 @@ static void emitInterfaceDecl(const Availability &availability, std::string interfaceTraitsName = std::string(formatv("{0}Traits", interfaceName)); - StringRef cppNamespace = availability.getInterfaceClassNamespace(); - llvm::NamespaceEmitter nsEmitter(os, cppNamespace); + llvm::NamespaceEmitter nsEmitter(os, + availability.getInterfaceClassNamespace()); os << "class " << interfaceName << ";\n\n"; // Emit the traits struct containing the concept and model declarations. @@ -418,15 +418,9 @@ static void emitAvailabilityQueryForBitEnum(const Record &enumDef, static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { EnumInfo enumInfo(enumDef); StringRef enumName = enumInfo.getEnumClassName(); - StringRef cppNamespace = enumInfo.getCppNamespace(); auto enumerants = enumInfo.getAllCases(); - llvm::SmallVector<StringRef, 2> namespaces; - llvm::SplitString(cppNamespace, namespaces, "::"); - - for (auto ns : namespaces) - os << "namespace " << ns << " {\n"; - + llvm::NamespaceEmitter ns(os, enumInfo.getCppNamespace()); llvm::StringSet<> handledClasses; // Place all availability specifications to their corresponding @@ -441,9 +435,6 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { enumName); handledClasses.insert(className); } - - for (auto ns : llvm::reverse(namespaces)) - os << "} // namespace " << ns << "\n"; } static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) { @@ -459,31 +450,19 @@ static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) { static void emitEnumDef(const Record &enumDef, raw_ostream &os) { EnumInfo enumInfo(enumDef); - StringRef cppNamespace = enumInfo.getCppNamespace(); - - llvm::SmallVector<StringRef, 2> namespaces; - llvm::SplitString(cppNamespace, namespaces, "::"); - - for (auto ns : namespaces) - os << "namespace " << ns << " {\n"; + llvm::NamespaceEmitter ns(os, enumInfo.getCppNamespace()); - if (enumInfo.isBitEnum()) { + if (enumInfo.isBitEnum()) emitAvailabilityQueryForBitEnum(enumDef, os); - } else { + else emitAvailabilityQueryForIntEnum(enumDef, os); - } - - for (auto ns : llvm::reverse(namespaces)) - os << "} // namespace " << ns << "\n"; - os << "\n"; } static bool emitEnumDefs(const RecordKeeper &records, raw_ostream &os) { llvm::emitSourceFileHeader("SPIR-V Enum Availability Definitions", os, records); - auto defs = records.getAllDerivedDefinitions("EnumInfo"); - for (const auto *def : defs) + for (const Record *def : records.getAllDerivedDefinitions("EnumInfo")) emitEnumDef(*def, os); return false; diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp index fe26fc1..2a58305 100644 --- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp +++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp @@ -113,8 +113,7 @@ static Match tensorMatch(TensorId tid) { return Match(tid); } static Match synZeroMatch() { return Match(); } #define IMPL_BINOP_PATTERN(OP, KIND) \ - LLVM_ATTRIBUTE_UNUSED static Match OP##Match(const Match &e0, \ - const Match &e1) { \ + [[maybe_unused]] static Match OP##Match(const Match &e0, const Match &e1) { \ return Match(KIND, e0, e1); \ } FOREVERY_BINOP(IMPL_BINOP_PATTERN) diff --git a/mlir/unittests/IR/RemarkTest.cpp b/mlir/unittests/IR/RemarkTest.cpp index bcbda90..09c576c 100644 --- a/mlir/unittests/IR/RemarkTest.cpp +++ b/mlir/unittests/IR/RemarkTest.cpp @@ -53,10 +53,12 @@ TEST(Remark, TestOutputOptimizationRemark) { /*missed=*/categoryUnroll, /*analysis=*/categoryRegister, /*failed=*/categoryInliner}; - + std::unique_ptr<remark::RemarkEmittingPolicyAll> policy = + std::make_unique<remark::RemarkEmittingPolicyAll>(); LogicalResult isEnabled = mlir::remark::enableOptimizationRemarksWithLLVMStreamer( - context, yamlFile, llvm::remarks::Format::YAML, cats); + context, yamlFile, llvm::remarks::Format::YAML, std::move(policy), + cats); ASSERT_TRUE(succeeded(isEnabled)) << "Failed to enable remark engine"; // PASS: something succeeded @@ -202,9 +204,10 @@ TEST(Remark, TestOutputOptimizationRemarkDiagnostic) { /*missed=*/categoryUnroll, /*analysis=*/categoryRegister, /*failed=*/categoryUnroll}; - - LogicalResult isEnabled = - remark::enableOptimizationRemarks(context, nullptr, cats, true); + std::unique_ptr<remark::RemarkEmittingPolicyAll> policy = + std::make_unique<remark::RemarkEmittingPolicyAll>(); + LogicalResult isEnabled = remark::enableOptimizationRemarks( + context, nullptr, std::move(policy), cats, true); ASSERT_TRUE(succeeded(isEnabled)) << "Failed to enable remark engine"; @@ -282,8 +285,11 @@ TEST(Remark, TestCustomOptimizationRemarkDiagnostic) { /*analysis=*/std::nullopt, /*failed=*/categoryLoopunroll}; + std::unique_ptr<remark::RemarkEmittingPolicyAll> policy = + std::make_unique<remark::RemarkEmittingPolicyAll>(); LogicalResult isEnabled = remark::enableOptimizationRemarks( - context, std::make_unique<MyCustomStreamer>(), cats, true); + context, std::make_unique<MyCustomStreamer>(), std::move(policy), cats, + true); ASSERT_TRUE(succeeded(isEnabled)) << "Failed to enable remark engine"; // Remark 1: pass, category LoopUnroll @@ -311,4 +317,66 @@ TEST(Remark, TestCustomOptimizationRemarkDiagnostic) { EXPECT_NE(errOut.find(pass2Msg), std::string::npos); // printed EXPECT_EQ(errOut.find(pass3Msg), std::string::npos); // filtered out } + +TEST(Remark, TestRemarkFinal) { + testing::internal::CaptureStderr(); + const auto *pass1Msg = "I failed"; + const auto *pass2Msg = "I failed too"; + const auto *pass3Msg = "I succeeded"; + const auto *pass4Msg = "I succeeded too"; + + std::string categoryLoopunroll("LoopUnroll"); + + std::string seenMsg = ""; + + { + MLIRContext context; + Location loc = FileLineColLoc::get(&context, "test.cpp", 1, 5); + Location locOther = FileLineColLoc::get(&context, "test.cpp", 55, 5); + + // Setup the remark engine + mlir::remark::RemarkCategories cats{/*all=*/"", + /*passed=*/categoryLoopunroll, + /*missed=*/categoryLoopunroll, + /*analysis=*/categoryLoopunroll, + /*failed=*/categoryLoopunroll}; + + std::unique_ptr<remark::RemarkEmittingPolicyFinal> policy = + std::make_unique<remark::RemarkEmittingPolicyFinal>(); + LogicalResult isEnabled = remark::enableOptimizationRemarks( + context, std::make_unique<MyCustomStreamer>(), std::move(policy), cats, + true); + ASSERT_TRUE(succeeded(isEnabled)) << "Failed to enable remark engine"; + + // Remark 1: failure + remark::failed( + loc, remark::RemarkOpts::name("Unroller").category(categoryLoopunroll)) + << pass1Msg; + + // Remark 2: failure + remark::missed( + loc, remark::RemarkOpts::name("Unroller").category(categoryLoopunroll)) + << remark::reason(pass2Msg); + + // Remark 3: pass + remark::passed( + loc, remark::RemarkOpts::name("Unroller").category(categoryLoopunroll)) + << pass3Msg; + + // Remark 4: pass + remark::passed( + locOther, + remark::RemarkOpts::name("Unroller").category(categoryLoopunroll)) + << pass4Msg; + } + + llvm::errs().flush(); + std::string errOut = ::testing::internal::GetCapturedStderr(); + + // Containment checks for messages. + EXPECT_EQ(errOut.find(pass1Msg), std::string::npos); // dropped + EXPECT_EQ(errOut.find(pass2Msg), std::string::npos); // dropped + EXPECT_NE(errOut.find(pass3Msg), std::string::npos); // shown + EXPECT_NE(errOut.find(pass4Msg), std::string::npos); // shown +} } // namespace diff --git a/mlir/utils/generate-test-checks.py b/mlir/utils/generate-test-checks.py index f80a181..3712a6b 100755 --- a/mlir/utils/generate-test-checks.py +++ b/mlir/utils/generate-test-checks.py @@ -31,13 +31,16 @@ import argparse import os # Used to advertise this file's name ("autogenerated_note"). import re import sys +from collections import Counter ADVERT_BEGIN = "// NOTE: Assertions have been autogenerated by " ADVERT_END = """ -// The script is designed to make adding checks to -// a test case fast, it is *not* designed to be authoritative -// about what constitutes a good test! The CHECK should be -// minimized and named to reflect the test intent. +// This script is intended to make adding checks to a test case quick and easy. +// It is *not* authoritative about what constitutes a good test. After using the +// script, be sure to review and refine the generated checks. For example, +// CHECK lines should be minimized and named to reflect the test’s intent. +// For comprehensive guidelines, see: +// * https://mlir.llvm.org/getting_started/TestingGuide/ """ @@ -45,6 +48,9 @@ ADVERT_END = """ SSA_RE_STR = "[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*" SSA_RE = re.compile(SSA_RE_STR) +# Regex matching `dialect.op_name` (e.g. `vector.transfer_read`). +SSA_OP_NAME_RE = re.compile(r"\b(?:\s=\s[a-z_]+)[.]([a-z_]+)\b") + # Regex matching the left-hand side of an assignment SSA_RESULTS_STR = r'\s*(%' + SSA_RE_STR + r')(\s*,\s*(%' + SSA_RE_STR + r'))*\s*=' SSA_RESULTS_RE = re.compile(SSA_RESULTS_STR) @@ -63,7 +69,12 @@ ATTR_DEF_RE = re.compile(ATTR_DEF_RE_STR) class VariableNamer: def __init__(self, variable_names): self.scopes = [] + # Counter for generic FileCHeck names, e.g. VAL_#N self.name_counter = 0 + # Counters for FileCheck names derived from Op names, e.g. + # TRANSFER_READ_#N (based on `vector.transfer_read`). Note, there's a + # dedicated counter for every Op type present in the input. + self.op_name_counter = Counter() # Number of variable names to still generate in parent scope self.generate_in_parent_scope_left = 0 @@ -77,17 +88,29 @@ class VariableNamer: self.generate_in_parent_scope_left = n # Generate a substitution name for the given ssa value name. - def generate_name(self, source_variable_name, use_ssa_name): + def generate_name(self, source_variable_name, use_ssa_name, op_name=""): # Compute variable name - variable_name = self.variable_names.pop(0) if len(self.variable_names) > 0 else '' - if variable_name == '': + variable_name = ( + self.variable_names.pop(0) if len(self.variable_names) > 0 else "" + ) + if variable_name == "": # If `use_ssa_name` is set, use the MLIR SSA value name to generate # a FileCHeck substation string. As FileCheck requires these # strings to start with a character, skip MLIR variables starting # with a digit (e.g. `%0`). + # + # The next fallback option is to use the op name, if the + # corresponding match succeeds. + # + # If neither worked, use a generic name: `VAL_#N`. if use_ssa_name and source_variable_name[0].isalpha(): variable_name = source_variable_name.upper() + elif op_name != "": + variable_name = ( + op_name.upper() + "_" + str(self.op_name_counter[op_name]) + ) + self.op_name_counter[op_name] += 1 else: variable_name = "VAL_" + str(self.name_counter) self.name_counter += 1 @@ -123,6 +146,7 @@ class VariableNamer: def clear_names(self): self.name_counter = 0 self.used_variable_names = set() + self.op_name_counter.clear() class AttributeNamer: @@ -170,8 +194,12 @@ def process_line(line_chunks, variable_namer, use_ssa_name=False, strict_name_re # Process the rest that contained an SSA value name. for chunk in line_chunks: - m = SSA_RE.match(chunk) - ssa_name = m.group(0) if m is not None else '' + ssa = SSA_RE.match(chunk) + op_name_with_dialect = SSA_OP_NAME_RE.search(chunk) + ssa_name = ssa.group(0) if ssa is not None else "" + op_name = ( + op_name_with_dialect.group(1) if op_name_with_dialect is not None else "" + ) # Check if an existing variable exists for this name. variable = None @@ -185,7 +213,7 @@ def process_line(line_chunks, variable_namer, use_ssa_name=False, strict_name_re output_line += "%[[" + variable + "]]" else: # Otherwise, generate a new variable. - variable = variable_namer.generate_name(ssa_name, use_ssa_name) + variable = variable_namer.generate_name(ssa_name, use_ssa_name, op_name) if strict_name_re: # Use stricter regexp for the variable name, if requested. # Greedy matching may cause issues with the generic '.*' |