Diffstat (limited to 'mlir/include')
-rw-r--r--  mlir/include/mlir-c/Rewrite.h | 2
-rw-r--r--  mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h | 54
-rw-r--r--  mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h | 8
-rw-r--r--  mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h | 27
-rw-r--r--  mlir/include/mlir/Conversion/Passes.h | 1
-rw-r--r--  mlir/include/mlir/Conversion/Passes.td | 32
-rw-r--r--  mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 93
-rw-r--r--  mlir/include/mlir/Dialect/Affine/IR/AffineOps.td | 1
-rw-r--r--  mlir/include/mlir/Dialect/Affine/LoopUtils.h | 2
-rw-r--r--  mlir/include/mlir/Dialect/Arith/IR/ArithOps.td | 94
-rw-r--r--  mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h | 57
-rw-r--r--  mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt | 4
-rw-r--r--  mlir/include/mlir/Dialect/LLVMIR/LLVMDialectBytecode.td | 353
-rw-r--r--  mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 5
-rw-r--r--  mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 40
-rw-r--r--  mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 87
-rw-r--r--  mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h | 37
-rw-r--r--  mlir/include/mlir/Dialect/MemRef/IR/MemRef.h | 1
-rw-r--r--  mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td | 3
-rw-r--r--  mlir/include/mlir/Dialect/OpenACC/OpenACC.h | 4
-rw-r--r--  mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td | 43
-rw-r--r--  mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td | 94
-rw-r--r--  mlir/include/mlir/Dialect/Shard/IR/ShardOps.td | 119
-rw-r--r--  mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h | 48
-rw-r--r--  mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc | 940
-rw-r--r--  mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td | 33
-rw-r--r--  mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h | 13
-rw-r--r--  mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td | 7
-rw-r--r--  mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 1
-rw-r--r--  mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td | 175
-rw-r--r--  mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 6
-rw-r--r--  mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 65
-rw-r--r--  mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 62
-rw-r--r--  mlir/include/mlir/IR/CommonTypeConstraints.td | 8
-rw-r--r--  mlir/include/mlir/IR/Remarks.h | 140
-rw-r--r--  mlir/include/mlir/Interfaces/CMakeLists.txt | 1
-rw-r--r--  mlir/include/mlir/Interfaces/InferIntRangeInterface.h | 12
-rw-r--r--  mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h | 145
-rw-r--r--  mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td | 45
-rw-r--r--  mlir/include/mlir/Interfaces/ViewLikeInterface.h | 16
-rw-r--r--  mlir/include/mlir/Interfaces/ViewLikeInterface.td | 12
-rw-r--r--  mlir/include/mlir/Remark/RemarkStreamer.h | 1
-rw-r--r--  mlir/include/mlir/TableGen/CodeGenHelpers.h | 2
-rw-r--r--  mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h | 71
-rw-r--r--  mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h | 9
45 files changed, 2415 insertions, 558 deletions
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 2db1d84..fe42a20 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -352,7 +352,7 @@ typedef struct {
/// Create a rewrite pattern that matches the operation
/// with the given rootName, corresponding to mlir::OpRewritePattern.
-MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate(
+MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePatternCreate(
MlirStringRef rootName, unsigned benefit, MlirContext context,
MlirRewritePatternCallbacks callbacks, void *userData,
size_t nGeneratedNames, MlirStringRef *generatedNames);
diff --git a/mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h
new file mode 100644
index 0000000..72ac247
--- /dev/null
+++ b/mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h
@@ -0,0 +1,54 @@
+//===- StridedMetadataRange.h - Strided metadata range analysis -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H
+#define MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H
+
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Interfaces/InferStridedMetadataInterface.h"
+
+namespace mlir {
+namespace dataflow {
+
+/// This lattice element represents the strided metadata of an SSA value.
+class StridedMetadataRangeLattice : public Lattice<StridedMetadataRange> {
+public:
+ using Lattice::Lattice;
+};
+
+/// Strided metadata range analysis determines the strided metadata ranges of
+/// SSA values using operations that define `InferStridedMetadataInterface`.
+///
+/// This analysis depends on DeadCodeAnalysis, SparseConstantPropagation, and
+/// IntegerRangeAnalysis, and will be a silent no-op if the analyses are not
+/// loaded in the same solver context.
+class StridedMetadataRangeAnalysis
+ : public SparseForwardDataFlowAnalysis<StridedMetadataRangeLattice> {
+public:
+ StridedMetadataRangeAnalysis(DataFlowSolver &solver,
+ int32_t indexBitwidth = 64);
+
+ /// At an entry point, we cannot reason about strided metadata ranges unless
+ /// the type also encodes the data, for example a memref with a static layout.
+ void setToEntryState(StridedMetadataRangeLattice *lattice) override;
+
+ /// Visit an operation. Invoke the transfer function on each operation that
+ /// implements `InferStridedMetadataInterface`.
+ LogicalResult
+ visitOperation(Operation *op,
+ ArrayRef<const StridedMetadataRangeLattice *> operands,
+ ArrayRef<StridedMetadataRangeLattice *> results) override;
+
+private:
+ /// Index bitwidth to use when operating with the int-ranges.
+ int32_t indexBitwidth = 64;
+};
+} // namespace dataflow
+} // end namespace mlir
+
+#endif // MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H
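The following is a usage sketch, not part of the patch: it shows how a client might load the new analysis together with the prerequisite analyses named in the comment above. The wrapper function name is an assumption.

```cpp
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"

using namespace mlir;
using namespace mlir::dataflow;

// Hypothetical driver: the strided-metadata analysis is a silent no-op unless
// the dead-code, constant-propagation, and integer-range analyses are loaded
// into the same solver.
static LogicalResult runStridedMetadataAnalysis(Operation *root) {
  DataFlowSolver solver;
  solver.load<DeadCodeAnalysis>();
  solver.load<SparseConstantPropagation>();
  solver.load<IntegerRangeAnalysis>();
  solver.load<StridedMetadataRangeAnalysis>(/*indexBitwidth=*/64);
  return solver.initializeAndRun(root);
}
```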
diff --git a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
index 46573e79..60f1888 100644
--- a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
+++ b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
@@ -9,6 +9,7 @@
#define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/IR/PatternMatch.h"
#include <memory>
@@ -19,8 +20,11 @@ class Pass;
#include "mlir/Conversion/Passes.h.inc"
/// Populate the given list with patterns that convert from Math to ROCDL calls.
-void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter,
- RewritePatternSet &patterns);
+/// `chipset` specifies the AMDGPU chipset to target. If `std::nullopt`,
+/// none of the chipset-dependent patterns are added.
+void populateMathToROCDLConversionPatterns(
+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ std::optional<amdgpu::Chipset> chipset);
} // namespace mlir
#endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
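A possible caller for the new signature is sketched below; the helper name `addMathToROCDLPatterns` is hypothetical, and chipset parsing is assumed to go through `amdgpu::Chipset::parse`.

```cpp
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include <optional>

using namespace mlir;

// Hypothetical helper: only the chipset-dependent patterns need a parsed
// chipset; an empty chipset string leaves them out.
static LogicalResult addMathToROCDLPatterns(const LLVMTypeConverter &converter,
                                            RewritePatternSet &patterns,
                                            StringRef chipsetName) {
  std::optional<amdgpu::Chipset> maybeChipset;
  if (!chipsetName.empty()) {
    FailureOr<amdgpu::Chipset> parsed = amdgpu::Chipset::parse(chipsetName);
    if (failed(parsed))
      return failure();
    maybeChipset = *parsed;
  }
  populateMathToROCDLConversionPatterns(converter, patterns, maybeChipset);
  return success();
}
```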
diff --git a/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
new file mode 100644
index 0000000..91d3c92
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
@@ -0,0 +1,27 @@
+//===- MathToXeVM.h - Utils for converting Math to XeVM -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
+#define MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMATHTOXEVM
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate the given list with patterns that convert from Math to XeVM calls.
+void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns,
+ bool convertArith);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index da061b2..40d866e 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -49,6 +49,7 @@
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
+#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 3c18ecc..70e3e45 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -778,6 +778,10 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
let summary = "Convert Math dialect to ROCDL library calls";
let description = [{
This pass converts supported Math ops to ROCDL library calls.
+
+ The chipset option specifies the target AMDGPU architecture. If the chipset
+ is empty, none of the chipset-dependent patterns are added, and the pass
+ will not attempt to parse the chipset.
}];
let dependentDialects = [
"arith::ArithDialect",
@@ -785,6 +789,9 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
"ROCDL::ROCDLDialect",
"vector::VectorDialect",
];
+ let options = [Option<"chipset", "chipset", "std::string",
+ /*default=*/"\"\"",
+ "Chipset that these operations will run on">];
}
//===----------------------------------------------------------------------===//
@@ -797,6 +804,31 @@ def ConvertMathToSPIRVPass : Pass<"convert-math-to-spirv"> {
}
//===----------------------------------------------------------------------===//
+// MathToXeVM
+//===----------------------------------------------------------------------===//
+
+def ConvertMathToXeVM : Pass<"convert-math-to-xevm"> {
+ let summary =
+ "Convert (fast) math operations to native XeVM/SPIRV equivalents";
+ let description = [{
+ This pass converts supported math ops to function calls for OpenCL
+ `native_` math intrinsics. These intrinsics are typically mapped directly
+ to native device instructions, often resulting in better performance.
+ However, the precision/error of these intrinsics is implementation-defined,
+ so math ops are only converted when they carry the `afn` fastmath flag,
+ which indicates that approximate functions are acceptable.
+ }];
+ let options = [Option<
+ "convertArith", "convert-arith", "bool", /*default=*/"true",
+ "Convert supported Arith ops (e.g. arith.divf) as well.">];
+ let dependentDialects = [
+ "arith::ArithDialect",
+ "xevm::XeVMDialect",
+ "LLVM::LLVMDialect",
+ ];
+}
+
+//===----------------------------------------------------------------------===//
// MathToEmitC
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 8370d35..7184de9 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -112,6 +112,97 @@ def AMDGPU_ExtPackedFp8Op :
}];
}
+def IsValidBlockSize: AttrConstraint<
+ CPred<"::llvm::is_contained({16, 32}, ::llvm::cast<::mlir::IntegerAttr>($_self).getInt())">,
+ "whose value is 16 or 32">;
+
+def AMDGPU_ScaledExtPacked816Op
+ : AMDGPU_Op<"scaled_ext_packed816", [Pure, AllShapesMatch<["source", "res"]>]>,
+ Arguments<(
+ ins AnyTypeOf<[FixedVectorOfShapeAndType<[8], F4E2M1FN>,
+ FixedVectorOfShapeAndType<[8], F8E4M3FN>,
+ FixedVectorOfShapeAndType<[8], F8E5M2>,
+ FixedVectorOfShapeAndType<[16], F6E2M3FN>,
+ FixedVectorOfShapeAndType<[16], F6E3M2FN>]>:$source,
+ FixedVectorOfShapeAndType<[4], F8E8M0FNU>:$scale,
+ ConfinedAttr<I32Attr, [IsValidBlockSize]>:$blockSize,
+ ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<1>]>:$firstScaleLane,
+ ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<2>]>:$firstScaleByte)>,
+ Results<(
+ outs AnyTypeOf<[FixedVectorOfShapeAndType<[8], F32>,
+ FixedVectorOfShapeAndType<[8], F16>,
+ FixedVectorOfShapeAndType<[8], BF16>,
+ FixedVectorOfShapeAndType<[16], F32>,
+ FixedVectorOfShapeAndType<[16], F16>,
+ FixedVectorOfShapeAndType<[16], BF16>]>:$res)> {
+
+ let summary = "Extend a vector of packed floating point values";
+
+ let description = [{
+ The scales applied to the input microfloats are stored in two bytes which
+ come from the `scales` input provided in a *half* of the wave identified
+ by `firstScaleLane`. The pair of bytes used is selected by
+ `firstScaleByte`. The 16 vectors in consecutive lanes starting from
+ `firstScaleLane` (which we'll call the scale vectors) will be used by both
+ halves of the wave (with lane L reading from the (L % 16)'th scale vector), but
+ each half will use a different byte.
+
+ When the block size is 32, `firstScaleByte` can be either 0 or 2,
+ selecting halves of the scale vectors. Lanes 0-15 will read from
+ `firstScaleByte` and lanes 16-31 will read from `firstScaleByte` + 1.
+ For example:
+ ```mlir
+ // Input: 8-element vector of F8E4M3FN, converting to F32
+ // Lanes 0-15 read from byte 0, lanes 16-31 read from byte 1
+ %result = amdgpu.scaled_ext_packed816 %source scale(%scales)
+ blockSize(32) firstScaleLane(0) firstScaleByte(0)
+ : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf32>
+
+ // Input: 16-element vector of F6E2M3FN, converting to F16
+ // Lanes 0-15 read from byte 2, lanes 16-31 read from byte 3
+ %result = amdgpu.scaled_ext_packed816 %source scale(%scales)
+ blockSize(32) firstScaleLane(1) firstScaleByte(2)
+ : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
+ ```
+
+ However, when the block size is 16, `firstScaleByte` can be 0 or 1.
+ Lanes 0-15 read from the `firstScaleByte`th element of the scale vectors,
+ while lanes 16-31 read from `firstScaleByte` + 2.
+ For example:
+ ```mlir
+ // Input: 8-element vector of F8E5M2, converting to BF16
+ // Lanes 0-15 read from byte 0, lanes 16-31 read from byte 2 (0+2)
+ %result = amdgpu.scaled_ext_packed816 %source scale(%scales)
+ blockSize(16) firstScaleLane(0) firstScaleByte(0)
+ : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16>
+
+ // Input: 16-element vector of F6E3M2FN, converting to F32
+ // Lanes 0-15 read from byte 1, lanes 16-31 read from byte 3 (1+2)
+ %result = amdgpu.scaled_ext_packed816 %source scale(%scales)
+ blockSize(16) firstScaleLane(1) firstScaleByte(1)
+ : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf32>
+ ```
+
+ Note: the layout for the scales generally mirrors the one the WMMA
+ instructions use for matrix scales. These selection operands allow
+ one to choose which portions of the matrix to convert.
+
+ Available on gfx1250+.
+ }];
+
+ let assemblyFormat = [{
+ attr-dict $source
+ `scale` `(` $scale `)`
+ `blockSize` `(` $blockSize `)`
+ `firstScaleLane` `(` $firstScaleLane`)`
+ `firstScaleByte` `(` $firstScaleByte `)`
+ `:` type($source) `,` type($scale) `->` type($res)
+ }];
+
+ let hasVerifier = 1;
+
+}
+
def AMDGPU_ScaledExtPackedOp
: AMDGPU_Op<"scaled_ext_packed", [Pure]>,
Arguments<(
@@ -860,7 +951,7 @@ def AMDGPU_MFMAOp :
based on the provided `m`, `k`, `n`, and `nBlks` attributes, along with the
types of the source and destination arguments.
- For information on the layouts of the input and output matrces (which are stored
+ For information on the layouts of the input and output matrices (which are stored
in `sourceA`, `sourceB`, `destC`, and `destD`), see the CDNA ISA documentation.
The `cbsz`, `abid`, and `blgp` parameters control how the lanes of the wave
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index e52b7d2..12a7935 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -330,7 +330,6 @@ def AffineForOp : Affine_Op<"for",
Speculation::Speculatability getSpeculatability();
}];
- let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasRegionVerifier = 1;
diff --git a/mlir/include/mlir/Dialect/Affine/LoopUtils.h b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
index 9b59af7..830c394 100644
--- a/mlir/include/mlir/Dialect/Affine/LoopUtils.h
+++ b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
@@ -61,7 +61,7 @@ LogicalResult loopUnrollUpToFactor(AffineForOp forOp, uint64_t unrollFactor);
/// Returns true if `loops` is a perfectly nested loop nest, where loops appear
/// in it from outermost to innermost.
-bool LLVM_ATTRIBUTE_UNUSED isPerfectlyNested(ArrayRef<AffineForOp> loops);
+[[maybe_unused]] bool isPerfectlyNested(ArrayRef<AffineForOp> loops);
/// Get perfectly nested sequence of loops starting at root of loop nest
/// (the first op being another AffineFor, and the second op - a terminator).
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 20c9097..a38cf41 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1229,37 +1229,50 @@ def Arith_ScalingExtFOp
let summary = "Upcasts input floats using provided scales values following "
"OCP MXFP Spec";
let description = [{
- This operation upcasts input floating-point values using provided scale
- values. It expects both scales and the input operand to be of the same shape,
- making the operation elementwise. Scales are usually calculated per block
- following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
-
- If scales are calculated per block where blockSize != 1, then scales may
- require broadcasting to make this operation elementwise. For example, let's
- say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
- assuming quantization happens on the last axis, the input can be reshaped to
- `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
- per block on the last axis. Therefore, scales will be of shape
- `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
- shape as long as it is broadcast compatible with the input, e.g.,
- `<1 x 1 x ... (dimN/blockSize) x 1>`.
-
- In this example, before calling into `arith.scaling_extf`, scales must be
- broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
- that there could be multiple quantization axes. Internally,
- `arith.scaling_extf` would perform the following:
+ This operation upcasts input floating-point values using provided scale
+ values. It expects both scales and the input operand to be of the same shape,
+ making the operation elementwise. Scales are usually calculated per block
+ following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
- ```
- resultTy = get_type(result)
- scaleTy = get_type(scale)
- inputTy = get_type(input)
- scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
- scale.extf = arith.extf(scale.exponent) : f8E8M0 to resultTy
- input.extf = arith.extf(input) : inputTy to resultTy
- result = arith.mulf(scale.extf, input.extf)
+ If scales are calculated per block where blockSize != 1, then scales may
+ require broadcasting to make this operation elementwise. For example, let's
+ say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
+ assuming quantization happens on the last axis, the input can be reshaped to
+ `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
+ per block on the last axis. Therefore, scales will be of shape
+ `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
+ shape as long as it is broadcast compatible with the input, e.g.,
+ `<1 x 1 x ... (dimN/blockSize) x 1>`.
+
+ In this example, before calling into `arith.scaling_extf`, scales must be
+ broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
+ that there could be multiple quantization axes. Internally,
+ `arith.scaling_extf` would perform the following:
+
+ ```mlir
+ // Cast scale to result type.
+ %scale_f8 = arith.truncf %scale : f32 to f8E8M0FNU
+ %scale_ext = arith.extf %scale_f8 : f8E8M0FNU to f16
+
+ // Cast input to result type.
+ %input_ext = arith.extf %input : f4E2M1FN to f16
+
+ // Perform scaling.
+ %result = arith.mulf %input_ext, %scale_ext : f16
```
It propagates NaN values. Therefore, if either scale or the input element
contains NaN, then the output element value will also be a NaN.
+
+ Example:
+
+ ```mlir
+ // Upcast from f4E2M1FN to f32.
+ %a = arith.scaling_extf %b, %c : f4E2M1FN, f8E8M0FNU to f32
+
+ // Element-wise upcast with broadcast (blockSize = 32).
+ %f = vector.broadcast %g : vector<1xf8E8M0FNU> to vector<32xf8E8M0FNU>
+ %h = arith.scaling_extf %i, %f : vector<32xf4E2M1FN>, vector<32xf8E8M0FNU> to vector<32xbf16>
+ ```
}];
let hasVerifier = 1;
let assemblyFormat =
@@ -1397,14 +1410,27 @@ def Arith_ScalingTruncFOp
that there could be multiple quantization axes. Internally,
`arith.scaling_truncf` would perform the following:
+ ```mlir
+ // Cast scale to input type.
+ %scale_f8 = arith.truncf %scale : f32 to f8E8M0FNU
+ %scale_ext = arith.extf %scale_f8 : f8E8M0FNU to f16
+
+ // Perform scaling.
+ %scaled = arith.divf %input, %scale_ext : f16
+
+ // Cast to result type.
+ %result = arith.truncf %scaled : f16 to f4E2M1FN
```
- scaleTy = get_type(scale)
- inputTy = get_type(input)
- resultTy = get_type(result)
- scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
- scale.extf = arith.extf(scale.exponent) : f8E8M0 to inputTy
- result = arith.divf(input, scale.extf)
- result.cast = arith.truncf(result, resultTy)
+
+ Example:
+
+ ```mlir
+ // Downcast from f32 to f4E2M1FN.
+ %a = arith.scaling_truncf %b, %c : f32, f8E8M0FNU to f4E2M1FN
+
+ // Element-wise downcast with broadcast (blockSize = 32).
+ %f = vector.broadcast %g : vector<1xf8E8M0FNU> to vector<32xf8E8M0FNU>
+ %h = arith.scaling_truncf %i, %f : vector<32xbf16>, vector<32xf8E8M0FNU> to vector<32xf4E2M1FN>
```
}];
let hasVerifier = 1;
diff --git a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
index 035235f..fccb49d 100644
--- a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
@@ -1,4 +1,4 @@
-//===- Passes.h - GPU NVVM pipeline entry points --------------------------===//
+//===- Passes.h - GPU pipeline entry points--------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -60,6 +60,52 @@ struct GPUToNVVMPipelineOptions
llvm::cl::init(false)};
};
+// Options for the gpu to xevm pipeline.
+struct GPUToXeVMPipelineOptions
+ : public PassPipelineOptions<GPUToXeVMPipelineOptions> {
+ PassOptions::Option<std::string> xegpuOpLevel{
+ *this, "xegpu-op-level",
+ llvm::cl::desc("Granularity of XeGPU operations to target: workgroup | "
+ "subgroup | lane"),
+ llvm::cl::init("workgroup")};
+ // General lowering controls.
+ PassOptions::Option<bool> use64bitIndex{
+ *this, "use-64bit-index",
+ llvm::cl::desc("Bitwidth of the index type (host & device)"),
+ llvm::cl::init(true)};
+ PassOptions::Option<bool> kernelBarePtrCallConv{
+ *this, "kernel-bare-ptr-calling-convention",
+ llvm::cl::desc("Use bare pointer calling convention for device kernels"),
+ llvm::cl::init(false)};
+ PassOptions::Option<bool> hostBarePtrCallConv{
+ *this, "host-bare-ptr-calling-convention",
+ llvm::cl::desc("Use bare pointer calling convention for host launches"),
+ llvm::cl::init(false)};
+ PassOptions::Option<std::string> binaryFormat{
+ *this, "binary-format",
+ llvm::cl::desc("Final GPU binary emission format (e.g. fatbin)"),
+ llvm::cl::init("fatbin")};
+ // Options mirroring xevm-attach-target (GpuXeVMAttachTarget).
+ PassOptions::Option<std::string> xevmModuleMatcher{
+ *this, "xevm-module-matcher",
+ llvm::cl::desc("Regex to match gpu.module names for XeVM target attach"),
+ llvm::cl::init("")};
+ PassOptions::Option<std::string> zebinTriple{
+ *this, "zebin-triple", llvm::cl::desc("Target triple for XeVM codegen"),
+ llvm::cl::init("spirv64-unknown-unknown")};
+ PassOptions::Option<std::string> zebinChip{
+ *this, "zebin-chip", llvm::cl::desc("Target chip (e.g. pvc, bmg)"),
+ llvm::cl::init("bmg")};
+ PassOptions::Option<unsigned> optLevel{
+ *this, "opt-level",
+ llvm::cl::desc("Optimization level for attached target/codegen"),
+ llvm::cl::init(2)};
+ PassOptions::Option<std::string> cmdOptions{
+ *this, "igc-cmd-options",
+ llvm::cl::desc("Additional downstream compiler command line options"),
+ llvm::cl::init("")};
+};
+
//===----------------------------------------------------------------------===//
// Building and Registering.
//===----------------------------------------------------------------------===//
@@ -70,8 +116,15 @@ struct GPUToNVVMPipelineOptions
void buildLowerToNVVMPassPipeline(OpPassManager &pm,
const GPUToNVVMPipelineOptions &options);
-/// Register all pipeleines for the `gpu` dialect.
+/// Adds the GPU to XeVM pipeline to the given pass manager. Transforms main
+/// dialects into XeVM targets. Begins with GPU code regions, then handles host
+/// code.
+void buildLowerToXeVMPassPipeline(OpPassManager &pm,
+ const GPUToXeVMPipelineOptions &options);
+
+/// Register all pipelines for the `gpu` dialect.
void registerGPUToNVVMPipeline();
+void registerGPUToXeVMPipeline();
} // namespace gpu
} // namespace mlir
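A sketch of driving the new pipeline from C++ follows; the wrapper and the chosen option values (`subgroup`, `pvc`) are illustrative assumptions, not part of the patch.

```cpp
#include "mlir/Dialect/GPU/Pipelines/Passes.h"
#include "mlir/Pass/PassManager.h"

using namespace mlir;

// Hypothetical wrapper: configure a few non-default options and build the
// GPU-to-XeVM lowering pipeline on an existing pass manager.
static void buildMyXeVMPipeline(OpPassManager &pm) {
  gpu::GPUToXeVMPipelineOptions options;
  options.xegpuOpLevel = "subgroup"; // workgroup | subgroup | lane
  options.zebinChip = "pvc";         // target chip, e.g. pvc or bmg
  gpu::buildLowerToXeVMPassPipeline(pm, options);
}
```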
diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
index 8d9474b..c301e0b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -48,6 +48,10 @@ mlir_tablegen(LLVMIntrinsicFromLLVMIRConversions.inc -gen-intr-from-llvmir-conve
mlir_tablegen(LLVMConvertibleLLVMIRIntrinsics.inc -gen-convertible-llvmir-intrinsics)
add_mlir_dialect_tablegen_target(MLIRLLVMIntrinsicConversionsIncGen)
+set(LLVM_TARGET_DEFINITIONS LLVMDialectBytecode.td)
+mlir_tablegen(LLVMDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="LLVM")
+add_public_tablegen_target(MLIRLLVMDialectBytecodeIncGen)
+
set(LLVM_TARGET_DEFINITIONS BasicPtxBuilderInterface.td)
mlir_tablegen(BasicPtxBuilderInterface.h.inc -gen-op-interface-decls)
mlir_tablegen(BasicPtxBuilderInterface.cpp.inc -gen-op-interface-defs)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialectBytecode.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialectBytecode.td
new file mode 100644
index 0000000..e7b202c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialectBytecode.td
@@ -0,0 +1,353 @@
+//===-- LLVMDialectBytecode.td - LLVM bytecode defs --------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the LLVM bytecode reader/writer definition file.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_DIALECT_BYTECODE
+#define LLVM_DIALECT_BYTECODE
+
+include "mlir/IR/BytecodeBase.td"
+
+//===----------------------------------------------------------------------===//
+// Bytecode classes for attributes and types.
+//===----------------------------------------------------------------------===//
+
+def String :
+ WithParser <"succeeded($_reader.readString($_var))",
+ WithBuilder<"$_args",
+ WithPrinter<"$_writer.writeOwnedString($_getter)",
+ WithType <"StringRef">>>>;
+
+class Attr<string type> : WithType<type, Attribute>;
+
+class OptionalAttribute<string type> :
+ WithParser <"succeeded($_reader.readOptionalAttribute($_var))",
+ WithPrinter<"$_writer.writeOptionalAttribute($_getter)",
+ WithType<type, Attribute>>>;
+
+class OptionalInt<string type> :
+ WithParser <"succeeded(readOptionalInt($_reader, $_var))",
+ WithPrinter<"writeOptionalInt($_writer, $_getter)",
+ WithType<"std::optional<" # type # ">", VarInt>>>;
+
+class OptionalArrayRef<string eltType> :
+ WithParser <"succeeded(readOptionalArrayRef<"
+ # eltType # ">($_reader, $_var))",
+ WithPrinter<"writeOptionalArrayRef<"
+ # eltType # ">($_writer, $_getter)",
+ WithType<"SmallVector<"
+ # eltType # ">", Attribute>>>;
+
+class EnumClassFlag<string flag, string getter> :
+ WithParser<"succeeded($_reader.readVarInt($_var))",
+ WithBuilder<"(" # flag # ")$_args",
+ WithPrinter<"$_writer.writeVarInt((uint64_t)$_name." # getter # ")",
+ WithType<"uint64_t", VarInt>>>>;
+
+//===----------------------------------------------------------------------===//
+// General notes
+// - For each attribute or type entry, the argument names should match
+// LLVMAttrDefs.td
+// - The mnemonics are either LLVM or builtin MLIR attributes and types, but
+// regular C++ types are also allowed to match builders and parsers.
+// - DIScopeAttr and DINodeAttr are empty base classes, custom encoding not
+// needed.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// DIBasicTypeAttr
+//===----------------------------------------------------------------------===//
+
+def DIBasicTypeAttr : DialectAttribute<(attr
+ VarInt:$tag,
+ String:$name,
+ VarInt:$sizeInBits,
+ VarInt:$encoding
+)>;
+
+//===----------------------------------------------------------------------===//
+// DIExpressionAttr, DIExpressionElemAttr
+//===----------------------------------------------------------------------===//
+
+def DIExpressionElemAttr : DialectAttribute<(attr
+ VarInt:$opcode,
+ OptionalArrayRef<"uint64_t">:$arguments
+)>;
+
+def DIExpressionAttr : DialectAttribute<(attr
+ OptionalArrayRef<"DIExpressionElemAttr">:$operations
+)>;
+
+//===----------------------------------------------------------------------===//
+// DIFileAttr
+//===----------------------------------------------------------------------===//
+
+def DIFileAttr : DialectAttribute<(attr
+ String:$name,
+ String:$directory
+)>;
+
+//===----------------------------------------------------------------------===//
+// DILocalVariableAttr
+//===----------------------------------------------------------------------===//
+
+def DILocalVariableAttr : DialectAttribute<(attr
+ Attr<"DIScopeAttr">:$scope,
+ OptionalAttribute<"StringAttr">:$name,
+ OptionalAttribute<"DIFileAttr">:$file,
+ VarInt:$line,
+ VarInt:$arg,
+ VarInt:$alignInBits,
+ OptionalAttribute<"DITypeAttr">:$type,
+ EnumClassFlag<"DIFlags", "getFlags()">:$_rawflags,
+ LocalVar<"DIFlags", "(DIFlags)_rawflags">:$flags
+)> {
+ // DILocalVariableAttr direct getter uses a `StringRef` for `name`. Since the
+ // more direct getter is prefered during bytecode reading, force the base one
+ // and prevent crashes for empty `StringAttr`.
+ let cBuilder = "$_resultType::get(context, $_args)";
+}
+
+//===----------------------------------------------------------------------===//
+// DISubroutineTypeAttr
+//===----------------------------------------------------------------------===//
+
+def DISubroutineTypeAttr : DialectAttribute<(attr
+ VarInt:$callingConvention,
+ OptionalArrayRef<"DITypeAttr">:$types
+)>;
+
+//===----------------------------------------------------------------------===//
+// DICompileUnitAttr
+//===----------------------------------------------------------------------===//
+
+def DICompileUnitAttr : DialectAttribute<(attr
+ Attr<"DistinctAttr">:$id,
+ VarInt:$sourceLanguage,
+ Attr<"DIFileAttr">:$file,
+ OptionalAttribute<"StringAttr">:$producer,
+ Bool:$isOptimized,
+ EnumClassFlag<"DIEmissionKind", "getEmissionKind()">:$_rawEmissionKind,
+ LocalVar<"DIEmissionKind", "(DIEmissionKind)_rawEmissionKind">:$emissionKind,
+ EnumClassFlag<"DINameTableKind", "getNameTableKind()">:$_rawNameTableKind,
+ LocalVar<"DINameTableKind",
+ "(DINameTableKind)_rawNameTableKind">:$nameTableKind
+)>;
+
+//===----------------------------------------------------------------------===//
+// DISubprogramAttr
+//===----------------------------------------------------------------------===//
+
+def DISubprogramAttr : DialectAttribute<(attr
+ OptionalAttribute<"DistinctAttr">:$recId,
+ Bool:$isRecSelf,
+ OptionalAttribute<"DistinctAttr">:$id,
+ OptionalAttribute<"DICompileUnitAttr">:$compileUnit,
+ OptionalAttribute<"DIScopeAttr">:$scope,
+ OptionalAttribute<"StringAttr">:$name,
+ OptionalAttribute<"StringAttr">:$linkageName,
+ OptionalAttribute<"DIFileAttr">:$file,
+ VarInt:$line,
+ VarInt:$scopeLine,
+ EnumClassFlag<"DISubprogramFlags", "getSubprogramFlags()">:$_rawflags,
+ LocalVar<"DISubprogramFlags", "(DISubprogramFlags)_rawflags">:$subprogramFlags,
+ OptionalAttribute<"DISubroutineTypeAttr">:$type,
+ OptionalArrayRef<"DINodeAttr">:$retainedNodes,
+ OptionalArrayRef<"DINodeAttr">:$annotations
+)>;
+
+//===----------------------------------------------------------------------===//
+// DICompositeTypeAttr
+//===----------------------------------------------------------------------===//
+
+def DICompositeTypeAttr : DialectAttribute<(attr
+ OptionalAttribute<"DistinctAttr">:$recId,
+ Bool:$isRecSelf,
+ VarInt:$tag,
+ OptionalAttribute<"StringAttr">:$name,
+ OptionalAttribute<"DIFileAttr">:$file,
+ VarInt:$line,
+ OptionalAttribute<"DIScopeAttr">:$scope,
+ OptionalAttribute<"DITypeAttr">:$baseType,
+ EnumClassFlag<"DIFlags", "getFlags()">:$_rawflags,
+ LocalVar<"DIFlags", "(DIFlags)_rawflags">:$flags,
+ VarInt:$sizeInBits,
+ VarInt:$alignInBits,
+ OptionalAttribute<"DIExpressionAttr">:$dataLocation,
+ OptionalAttribute<"DIExpressionAttr">:$rank,
+ OptionalAttribute<"DIExpressionAttr">:$allocated,
+ OptionalAttribute<"DIExpressionAttr">:$associated,
+ OptionalArrayRef<"DINodeAttr">:$elements
+)>;
+
+//===----------------------------------------------------------------------===//
+// DIDerivedTypeAttr
+//===----------------------------------------------------------------------===//
+
+def DIDerivedTypeAttr : DialectAttribute<(attr
+ VarInt:$tag,
+ OptionalAttribute<"StringAttr">:$name,
+ OptionalAttribute<"DITypeAttr">:$baseType,
+ VarInt:$sizeInBits,
+ VarInt:$alignInBits,
+ VarInt:$offsetInBits,
+ OptionalInt<"unsigned">:$dwarfAddressSpace,
+ OptionalAttribute<"DINodeAttr">:$extraData
+)>;
+
+//===----------------------------------------------------------------------===//
+// DIImportedEntityAttr
+//===----------------------------------------------------------------------===//
+
+def DIImportedEntityAttr : DialectAttribute<(attr
+ VarInt:$tag,
+ Attr<"DIScopeAttr">:$scope,
+ Attr<"DINodeAttr">:$entity,
+ OptionalAttribute<"DIFileAttr">:$file,
+ VarInt:$line,
+ OptionalAttribute<"StringAttr">:$name,
+ OptionalArrayRef<"DINodeAttr">:$elements
+)>;
+
+//===----------------------------------------------------------------------===//
+// DIGlobalVariableAttr, DIGlobalVariableExpressionAttr
+//===----------------------------------------------------------------------===//
+
+def DIGlobalVariableAttr : DialectAttribute<(attr
+ OptionalAttribute<"DIScopeAttr">:$scope,
+ OptionalAttribute<"StringAttr">:$name,
+ OptionalAttribute<"StringAttr">:$linkageName,
+ Attr<"DIFileAttr">:$file,
+ VarInt:$line,
+ Attr<"DITypeAttr">:$type,
+ Bool:$isLocalToUnit,
+ Bool:$isDefined,
+ VarInt:$alignInBits
+)>;
+
+def DIGlobalVariableExpressionAttr : DialectAttribute<(attr
+ Attr<"DIGlobalVariableAttr">:$var,
+ OptionalAttribute<"DIExpressionAttr">:$expr
+)>;
+
+//===----------------------------------------------------------------------===//
+// DILabelAttr
+//===----------------------------------------------------------------------===//
+
+def DILabelAttr : DialectAttribute<(attr
+ Attr<"DIScopeAttr">:$scope,
+ OptionalAttribute<"StringAttr">:$name,
+ OptionalAttribute<"DIFileAttr">:$file,
+ VarInt:$line
+)> {
+ // DILabelAttr direct getter uses a `StringRef` for `name`. Since the
+ // more direct getter is preferred during bytecode reading, force the base one
+ // and prevent crashes for empty `StringAttr`.
+ let cBuilder = "$_resultType::get(context, $_args)";
+}
+
+//===----------------------------------------------------------------------===//
+// DILexicalBlockAttr, DILexicalBlockFileAttr
+//===----------------------------------------------------------------------===//
+
+def DILexicalBlockAttr : DialectAttribute<(attr
+ Attr<"DIScopeAttr">:$scope,
+ OptionalAttribute<"DIFileAttr">:$file,
+ VarInt:$line,
+ VarInt:$column
+)>;
+
+def DILexicalBlockFileAttr : DialectAttribute<(attr
+ Attr<"DIScopeAttr">:$scope,
+ OptionalAttribute<"DIFileAttr">:$file,
+ VarInt:$discriminator
+)>;
+
+//===----------------------------------------------------------------------===//
+// DINamespaceAttr
+//===----------------------------------------------------------------------===//
+
+def DINamespaceAttr : DialectAttribute<(attr
+ OptionalAttribute<"StringAttr">:$name,
+ OptionalAttribute<"DIScopeAttr">:$scope,
+ Bool:$exportSymbols
+)>;
+
+//===----------------------------------------------------------------------===//
+// DISubrangeAttr
+//===----------------------------------------------------------------------===//
+
+def DISubrangeAttr : DialectAttribute<(attr
+ OptionalAttribute<"Attribute">:$count,
+ OptionalAttribute<"Attribute">:$lowerBound,
+ OptionalAttribute<"Attribute">:$upperBound,
+ OptionalAttribute<"Attribute">:$stride
+)>;
+
+//===----------------------------------------------------------------------===//
+// LoopAnnotationAttr
+//===----------------------------------------------------------------------===//
+
+def LoopAnnotationAttr : DialectAttribute<(attr
+ OptionalAttribute<"BoolAttr">:$disableNonforced,
+ OptionalAttribute<"LoopVectorizeAttr">:$vectorize,
+ OptionalAttribute<"LoopInterleaveAttr">:$interleave,
+ OptionalAttribute<"LoopUnrollAttr">:$unroll,
+ OptionalAttribute<"LoopUnrollAndJamAttr">:$unrollAndJam,
+ OptionalAttribute<"LoopLICMAttr">:$licm,
+ OptionalAttribute<"LoopDistributeAttr">:$distribute,
+ OptionalAttribute<"LoopPipelineAttr">:$pipeline,
+ OptionalAttribute<"LoopPeeledAttr">:$peeled,
+ OptionalAttribute<"LoopUnswitchAttr">:$unswitch,
+ OptionalAttribute<"BoolAttr">:$mustProgress,
+ OptionalAttribute<"BoolAttr">:$isVectorized,
+ OptionalAttribute<"FusedLoc">:$startLoc,
+ OptionalAttribute<"FusedLoc">:$endLoc,
+ OptionalArrayRef<"AccessGroupAttr">:$parallelAccesses
+)>;
+
+//===----------------------------------------------------------------------===//
+// Attributes & Types with custom bytecode handling.
+//===----------------------------------------------------------------------===//
+
+// All the attributes with custom bytecode handling.
+def LLVMDialectAttributes : DialectAttributes<"LLVM"> {
+ let elems = [
+ DIBasicTypeAttr,
+ DICompileUnitAttr,
+ DICompositeTypeAttr,
+ DIDerivedTypeAttr,
+ DIExpressionElemAttr,
+ DIExpressionAttr,
+ DIFileAttr,
+ DIGlobalVariableAttr,
+ DIGlobalVariableExpressionAttr,
+ DIImportedEntityAttr,
+ DILabelAttr,
+ DILexicalBlockAttr,
+ DILexicalBlockFileAttr,
+ DILocalVariableAttr,
+ DINamespaceAttr,
+ DISubprogramAttr,
+ DISubrangeAttr,
+ DISubroutineTypeAttr,
+ LoopAnnotationAttr
+ // Referenced attributes currently missing support:
+ // AccessGroupAttr, LoopVectorizeAttr, LoopInterleaveAttr, LoopUnrollAttr,
+ // LoopUnrollAndJamAttr, LoopLICMAttr, LoopDistributeAttr, LoopPipelineAttr,
+ // LoopPeeledAttr, LoopUnswitchAttr
+ ];
+}
+
+def LLVMDialectTypes : DialectTypes<"LLVM"> {
+ let elems = [];
+}
+
+#endif // LLVM_DIALECT_BYTECODE
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 9753dca..d0811a2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -973,6 +973,7 @@ def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector",
custom<ShuffleType>(ref(type($v1)), type($res), ref($mask))
}];
+ let hasFolder = 1;
let hasVerifier = 1;
string llvmInstName = "ShuffleVector";
@@ -1985,6 +1986,7 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
OptionalAttr<StrAttr>:$instrument_function_exit,
OptionalAttr<UnitAttr>:$no_inline,
OptionalAttr<UnitAttr>:$always_inline,
+ OptionalAttr<UnitAttr>:$inline_hint,
OptionalAttr<UnitAttr>:$no_unwind,
OptionalAttr<UnitAttr>:$will_return,
OptionalAttr<UnitAttr>:$optimize_none,
@@ -2037,6 +2039,9 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
/// Returns true if the `always_inline` attribute is set, false otherwise.
bool isAlwaysInline() { return bool(getAlwaysInlineAttr()); }
+ /// Returns true if the `inline_hint` attribute is set, false otherwise.
+ bool isInlineHint() { return bool(getInlineHintAttr()); }
+
/// Returns true if the `optimize_none` attribute is set, false otherwise.
bool isOptimizeNone() { return bool(getOptimizeNoneAttr()); }
}];
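As an illustration of the new accessor next to the existing ones, a hypothetical inlining-policy helper might look like this (the policy itself is an assumption):

```cpp
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"

using namespace mlir;

// Hypothetical inlining policy: never inline optnone/noinline functions,
// prefer functions marked always_inline or inline_hint.
static bool preferToInline(Operation *callee) {
  auto func = dyn_cast<LLVM::LLVMFuncOp>(callee);
  if (!func)
    return false;
  if (func.isOptimizeNone() || bool(func.getNoInlineAttr()))
    return false;
  return func.isAlwaysInline() || func.isInlineHint();
}
```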
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 89fbeb7..d959464 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -263,6 +263,7 @@ class NVVM_PureSpecialRangeableRegisterOp<string mnemonic, list<Trait> traits =
let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)";
let llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda;
let mlirBuilder = baseMlirBuilder # importRangeRetAttrCode # baseMlirBuilderCoda;
+ let hasVerifier = 1;
// Backwards-compatibility builder for an unspecified range.
let builders = [
@@ -279,6 +280,11 @@ class NVVM_PureSpecialRangeableRegisterOp<string mnemonic, list<Trait> traits =
SetIntRangeFn setResultRanges) {
nvvmInferResultRanges(getOperation(), getResult(), argRanges, setResultRanges);
}
+
+ // Verify the range attribute satisfies LLVM ConstantRange constructor requirements.
+ ::llvm::LogicalResult $cppClass::verify() {
+ return verifyConstantRangeAttr(getOperation(), getRange());
+ }
}];
}
@@ -1655,6 +1661,40 @@ def NVVM_ConvertFloatToTF32Op : NVVM_Op<"convert.float.to.tf32"> {
}];
}
+def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> {
+ let summary = "Convert a pair of float inputs to f4x2";
+ let description = [{
+ This Op converts each of the given float inputs to the specified fp4 type.
+ The result `dst` is returned as an i8 type where the converted values are
+ packed such that the value converted from `a` is stored in the upper 4 bits
+ of `dst` and the value converted from `b` is stored in the lower 4 bits of
+ `dst`.
+ The `relu` attribute, when set, lowers to the '.relu' variant of
+ the cvt instruction.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+ }];
+
+ let results = (outs I8:$dst);
+ let arguments = (ins F32:$a, F32:$b,
+ DefaultValuedAttr<BoolAttr, "false">:$relu,
+ TypeAttr:$dstTy);
+ let assemblyFormat = "$a `,` $b attr-dict `:` type($dst) `(` $dstTy `)`";
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ static mlir::NVVM::IDArgPair
+ getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
+ LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
+ }];
+
+ string llvmBuilder = [{
+ auto [intId, args] = NVVM::ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(op, moduleTranslation, builder);
+ llvm::Value *packedI16 = createIntrinsicCall(builder, intId, args);
+ $dst = builder.CreateTruncOrBitCast(packedI16, llvm::Type::getInt8Ty(builder.getContext()));
+ }];
+}
+
def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
let summary = "Convert a pair of float inputs to f6x2";
let description = [{
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 6925cec..d2df244 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -412,6 +412,32 @@ def ROCDL_WaitExpcntOp: ROCDL_ConcreteNonMemIntrOp<"s.wait.expcnt", [], 0, [0],
let assemblyFormat = "$count attr-dict";
}
+def ROCDL_WaitAsynccntOp: ROCDL_ConcreteNonMemIntrOp<"s.wait.asynccnt", [], 0, [0], ["count"]>,
+ Arguments<(ins I16Attr:$count)> {
+ let summary = "Wait until ASYNCCNT is less than or equal to `count`";
+ let description = [{
+ Wait for the ASYNCCNT counter to be less than or equal to `count`
+ before continuing.
+
+ Available on gfx1250+.
+ }];
+ let results = (outs);
+ let assemblyFormat = "$count attr-dict";
+}
+
+def ROCDL_WaitTensorcntOp: ROCDL_ConcreteNonMemIntrOp<"s.wait.tensorcnt", [], 0, [0], ["count"]>,
+ Arguments<(ins I16Attr:$count)> {
+ let summary = "Wait until TENSORCNT is less than or equal to `count`";
+ let description = [{
+ Wait for the TENSORCNT counter to be less than or equal to `count`
+ before continuing.
+
+ Available on gfx1250+.
+ }];
+ let results = (outs);
+ let assemblyFormat = "$count attr-dict";
+}
+
def ROCDL_SetPrioOp : ROCDL_ConcreteNonMemIntrOp<"s.setprio", [], 0, [0], ["priority"]>,
Arguments<(ins I16Attr:$priority)> {
let assemblyFormat = "$priority attr-dict";
@@ -548,6 +574,30 @@ def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_b
def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_fp8", [1]>;
def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x32.iu4", [1]>;
+// Available from gfx1250
+def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x4.f32", [1]>;
+def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.bf16", [1]>;
+def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.f16", [1]>;
+def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x32.f16", [1]>;
+def ROCDL_wmma_bf16_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x32.bf16", [1]>;
+def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16f32.16x16x32.bf16", [1,5]>;
+def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_fp8", [0]>;
+def ROCDL_wmma_f32_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_bf8", [0]>;
+def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_fp8", [0]>;
+def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_bf8", [0]>;
+def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_fp8", [0]>;
+def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_bf8", [0]>;
+def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_fp8", [0]>;
+def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_bf8", [0]>;
+def ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_fp8", [0]>;
+def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_bf8", [0]>;
+def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_fp8", [0]>;
+def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_bf8", [0]>;
+def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_fp8", [0]>;
+def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_bf8", [0]>;
+def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_fp8", [0]>;
+def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_bf8", [0]>;
+def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x64.iu8", [1]>;
//===---------------------------------------------------------------------===//
// LDS transpose intrinsics (available in GFX950)
@@ -1117,6 +1167,7 @@ foreach smallT = [
ScaleArgInfo<ROCDL_V16BF16Type, "Bf16">,
ScaleArgInfo<ROCDL_V16F32Type, "F32">,
] in {
+ // Up-scaling
def ROCDL_CvtPkScalePk16 # largeT.nameForOp # smallT.nameForOp # Op :
ROCDL_ConcreteNonMemIntrOp<"cvt.scale.pk16." # largeT.name # "." # smallT.name,
[Pure], 1, [2], ["scaleSel"]>,
@@ -1132,6 +1183,42 @@ foreach smallT = [
}];
}
+
+ // Down-scaling
+ def ROCDL_CvtScaleF32Pk16 # smallT.nameForOp # largeT.nameForOp # Op :
+ ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk16." # smallT.name # "." # largeT.name,
+ [Pure], 1>,
+ Arguments<(ins largeT.type:$src, F32:$scale)> {
+ let results = (outs smallT.type:$res);
+ let summary = "Scale and convert packed "
+ # largeT.name # " to packed " # smallT.name ;
+ let description = [{
+ Convert 16 packed }] # largeT.name # [{ values to packed }]
+ # smallT.name # [{, multiplying by the exponent part of `scale`
+ before doing so. This op is for gfx1250+ arch.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `,` $scale `:` type($res)
+ }];
+ }
+
+ def ROCDL_CvtScaleF32SrPk16 # smallT.nameForOp # largeT.nameForOp # Op :
+ ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.sr.pk16." # smallT.name # "." # largeT.name,
+ [Pure], 1>,
+ Arguments<(ins largeT.type:$src, I32:$seed, F32:$scale)> {
+ let results = (outs smallT.type:$res);
+ let summary = "Scale and convert packed "
+ # largeT.name # " to packed " # smallT.name # " with stochastic rounding";
+ let description = [{
+ Convert 16 packed }] # largeT.name # [{ values to packed }]
+ # smallT.name # [{, multiplying by the exponent part of `scale`
+ before doing so and apply stochastic rounding. This op is for gfx1250+ arch.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `,` $seed `,` $scale `:` type($res)
+ }];
+ }
+
} // foreach largeT
} // foreach smallTOp
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index ae7a085..c89fc59 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -25,7 +25,6 @@
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallBitVector.h"
-#include "llvm/ADT/SmallSet.h"
namespace mlir {
namespace bufferization {
@@ -621,35 +620,43 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
/// In the future, more general interfaces can be devised to encode similar
/// shape evolutions and map between an op and its operands.
SmallVector<OpFoldResult>
-computePaddedShape(RewriterBase &rewriter, TypedValue<RankedTensorType> v,
+computePaddedShape(OpBuilder &, TypedValue<RankedTensorType> v,
AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
const PadTilingInterfaceOptions &options);
using PadSizeComputationFunction =
std::function<FailureOr<SmallVector<OpFoldResult>>(
- RewriterBase &, OpOperand &, ArrayRef<Range>,
+ OpBuilder &, OpOperand &, ArrayRef<Range>,
const PadTilingInterfaceOptions &)>;
/// Specific helper for Linalg ops.
-FailureOr<SmallVector<OpFoldResult>> computeIndexingMapOpInterfacePaddedShape(
- RewriterBase &rewriter, OpOperand &operandToPad,
- ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options);
+FailureOr<SmallVector<OpFoldResult>>
+computeIndexingMapOpInterfacePaddedShape(OpBuilder &, OpOperand &operandToPad,
+ ArrayRef<Range> iterationDomain,
+ const PadTilingInterfaceOptions &);
+
+/// Operations and values created in the process of padding a TilingInterface
+/// operation.
+struct PadTilingInterfaceResult {
+ /// The operands of the padded op.
+ SmallVector<tensor::PadOp> padOps;
+ /// The padded op, a clone of `toPad` with padded operands.
+ TilingInterface paddedOp;
+ /// Slices of the padded op's results, same types as `toPad`.
+ SmallVector<Value> replacements;
+};
-/// Pad the iterator dimensions `options.paddingDimensions` of `opToPad`.
-///
+/// Pad the iterator dimensions of `toPad`.
/// * "options.paddingSizes" indicates that each padding dimension should be
/// padded to the specified padding size.
/// * "options.padToMultipleOf" indicates that the paddingSizes should be
// interpreted as the bounding box (dynamic) value to pad to.
/// * Use "options.paddingValues" to set the padding value of the created
// tensor::PadOp.
-/// * The tensor::PadOp is returned on success.
-
-FailureOr<TilingInterface>
-rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
- const PadTilingInterfaceOptions &constOptions,
- SmallVector<tensor::PadOp> &padOps,
- const PadSizeComputationFunction &computePaddingSizeFun =
+FailureOr<PadTilingInterfaceResult>
+rewriteAsPaddedOp(OpBuilder &, TilingInterface toPad,
+ PadTilingInterfaceOptions options,
+ const PadSizeComputationFunction & =
&computeIndexingMapOpInterfacePaddedShape);
namespace detail {
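A sketch of a caller migrating to the new result-struct API (the `padAndReplace` helper is hypothetical; the rewriter doubles as the builder argument):

```cpp
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

using namespace mlir;

// Hypothetical migration helper: pad `toPad`, then replace its results with
// the slices returned in PadTilingInterfaceResult::replacements.
static LogicalResult padAndReplace(RewriterBase &rewriter, TilingInterface toPad,
                                   linalg::PadTilingInterfaceOptions options) {
  FailureOr<linalg::PadTilingInterfaceResult> padded =
      linalg::rewriteAsPaddedOp(rewriter, toPad, options);
  if (failed(padded))
    return failure();
  rewriter.replaceOp(toPad, padded->replacements);
  return success();
}
```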
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index 30f33ed..69447f7 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -17,6 +17,7 @@
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Interfaces/InferStridedMetadataInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/MemOpInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 40b7d7e..b39207f 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -14,6 +14,7 @@ include "mlir/Dialect/MemRef/IR/MemRefBase.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
+include "mlir/Interfaces/InferStridedMetadataInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/MemOpInterfaces.td"
include "mlir/Interfaces/MemorySlotInterfaces.td"
@@ -184,6 +185,7 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
def DistinctObjectsOp : MemRef_Op<"distinct_objects", [
Pure,
+ DistinctObjectsTrait,
DeclareOpInterfaceMethods<InferTypeOpInterface>
// ViewLikeOpInterface TODO: ViewLikeOpInterface only supports a single argument
]> {
@@ -2084,6 +2086,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<InferStridedMetadataOpInterface>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
AttrSizedOperandSegments,
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
index 8f87235..b8aa497 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
@@ -183,6 +183,10 @@ static constexpr StringLiteral getRoutineInfoAttrName() {
return StringLiteral("acc.routine_info");
}
+static constexpr StringLiteral getVarNameAttrName() {
+ return VarNameAttr::name;
+}
+
static constexpr StringLiteral getCombinedConstructsAttrName() {
return CombinedConstructsTypeAttr::name;
}
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 77e833f..1eaa21b4 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -415,6 +415,13 @@ def OpenACC_ConstructResource : Resource<"::mlir::acc::ConstructResource">;
// Define a resource for the OpenACC current device setting.
def OpenACC_CurrentDeviceIdResource : Resource<"::mlir::acc::CurrentDeviceIdResource">;
+// Attribute for saving variable names - this can be attached to non-acc-dialect
+// operations in order to ensure the name is preserved.
+def OpenACC_VarNameAttr : OpenACC_Attr<"VarName", "var_name"> {
+ let parameters = (ins StringRefParameter<"">:$name);
+ let assemblyFormat = "`<` $name `>`";
+}
+
// Used for data specification in data clauses (2.7.1).
// Either (or both) extent and upperbound must be specified.
def OpenACC_DataBoundsOp : OpenACC_Op<"bounds",
@@ -1316,6 +1323,24 @@ def OpenACC_PrivateRecipeOp
}];
let hasRegionVerifier = 1;
+
+ let extraClassDeclaration = [{
+ /// Creates a PrivateRecipeOp and populates its regions based on the
+ /// variable type as long as the type implements MappableType or
+ /// PointerLikeType interface. If a type implements both, the MappableType
+ /// API will be preferred. Returns std::nullopt if the recipe cannot be
+ /// created or populated. The builder's current insertion point will be used
+ /// and it must be a valid place for this operation to be inserted. The
+ /// `recipeName` must be a unique name to prevent "redefinition of symbol"
+ /// IR errors.
+ static std::optional<PrivateRecipeOp> createAndPopulate(
+ ::mlir::OpBuilder &builder,
+ ::mlir::Location loc,
+ ::llvm::StringRef recipeName,
+ ::mlir::Type varType,
+ ::llvm::StringRef varName = "",
+ ::mlir::ValueRange bounds = {});
+ }];
}
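A hedged usage sketch of the new helper, based only on the signature declared above; the module-level insertion point and the recipe symbol name are illustrative assumptions, and `FirstprivateRecipeOp::createAndPopulate` below mirrors the same shape:

```cpp
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include <optional>

// Hedged sketch: build a private recipe for `varType` at the top of the
// module, or return std::nullopt if the type implements neither MappableType
// nor PointerLikeType. The symbol name is a placeholder and must be unique
// within the module.
static std::optional<mlir::acc::PrivateRecipeOp>
buildPrivateRecipe(mlir::ModuleOp module, mlir::Location loc,
                   mlir::Type varType) {
  mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
  return mlir::acc::PrivateRecipeOp::createAndPopulate(
      builder, loc, "privatization_example", varType, /*varName=*/"x");
}
```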
//===----------------------------------------------------------------------===//
@@ -1410,6 +1435,24 @@ def OpenACC_FirstprivateRecipeOp
}];
let hasRegionVerifier = 1;
+
+ let extraClassDeclaration = [{
+ /// Creates a FirstprivateRecipeOp and populates its regions based on the
+ /// variable type as long as the type implements the MappableType or
+ /// PointerLikeType interface. If a type implements both, the MappableType
+ /// API will be preferred. Returns std::nullopt if the recipe cannot be
+ /// created or populated. The builder's current insertion point will be used
+ /// and it must be a valid place for this operation to be inserted. The
+ /// `recipeName` must be a unique name to prevent "redefinition of symbol"
+ /// IR errors.
+ static std::optional<FirstprivateRecipeOp> createAndPopulate(
+ ::mlir::OpBuilder &builder,
+ ::mlir::Location loc,
+ ::llvm::StringRef recipeName,
+ ::mlir::Type varType,
+ ::llvm::StringRef varName = "",
+ ::mlir::ValueRange bounds = {});
+ }];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td
index 0d16255..93e9e3d 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td
@@ -73,17 +73,31 @@ def OpenACC_PointerLikeTypeInterface : TypeInterface<"PointerLikeType"> {
InterfaceMethod<
/*description=*/[{
Generates allocation operations for the pointer-like type. It will create
- an allocate that produces memory space for an instance of the current type.
+ an allocate operation that produces memory space for an instance of the
+ current type.
The `varName` parameter is optional and can be used to provide a name
- for the allocated variable. If the current type is represented
- in a way that it does not capture the pointee type, `varType` must be
- passed in to provide the necessary type information.
+ for the allocated variable. When provided, it must be used by the
+ implementation; if the implementing dialect does not have its own
+ way to preserve the name, the discardable `acc.var_name` attribute
+ from the acc dialect will be used.
+
+ If the current type is represented in a way that does not capture
+ the pointee type, `varType` must be passed in to provide the necessary
+ type information.
The `originalVar` parameter is optional but enables support for dynamic
types (e.g., dynamic memrefs). When provided, implementations can extract
runtime dimension information from the original variable to create
- allocations with matching dynamic sizes.
+ allocations with matching dynamic sizes. When generating recipe bodies,
+ `originalVar` should be the block argument representing the original
+ variable in the recipe region.
+
+ The `needsFree` output parameter indicates whether the allocated memory
+ requires explicit deallocation. Implementations should set this to true
+ for heap allocations that need a matching deallocation operation (e.g.,
+ alloc) and false for stack-based allocations (e.g., alloca). During
+ recipe generation, this determines whether a destroy region is created.
Returns a Value representing the result of the allocation. If no value
is returned, it means the allocation was not successfully generated.
@@ -94,7 +108,8 @@ def OpenACC_PointerLikeTypeInterface : TypeInterface<"PointerLikeType"> {
"::mlir::Location":$loc,
"::llvm::StringRef":$varName,
"::mlir::Type":$varType,
- "::mlir::Value":$originalVar),
+ "::mlir::Value":$originalVar,
+ "bool &":$needsFree),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return {};
@@ -102,23 +117,34 @@ def OpenACC_PointerLikeTypeInterface : TypeInterface<"PointerLikeType"> {
>,
InterfaceMethod<
/*description=*/[{
- Generates deallocation operations for the pointer-like type. It deallocates
- the instance provided.
+ Generates deallocation operations for the pointer-like type.
- The `varPtr` parameter is required and must represent an instance that was
- previously allocated. If the current type is represented in a way that it
- does not capture the pointee type, `varType` must be passed in to provide
- the necessary type information. Nothing is generated in case the allocate
- is `alloca`-like.
+ The `varToFree` parameter is required and must represent an instance
+ that was previously allocated. When generating recipe bodies, this
+ should be the block argument representing the private variable in the
+ destroy region.
+
+ The `allocRes` parameter is optional and provides the result of the
+ corresponding allocation from the init region. This allows implementations
+ to inspect the allocation operation to determine the appropriate
+ deallocation strategy. This is necessary because in recipe generation,
+ the allocation and deallocation occur in separate regions. Dialects that
+ use only one allocation type or can determine deallocation from type
+ information alone may ignore this parameter.
+
+ The `varType` parameter must be provided if the current type does not
+ capture the pointee type information. No deallocation is generated for
+ stack-based allocations (e.g., alloca).
- Returns true if deallocation was successfully generated or successfully
- deemed as not needed to be generated, false otherwise.
+ Returns true if deallocation was successfully generated or determined to
+ be unnecessary, false otherwise.
}],
/*retTy=*/"bool",
/*methodName=*/"genFree",
/*args=*/(ins "::mlir::OpBuilder &":$builder,
"::mlir::Location":$loc,
- "::mlir::TypedValue<::mlir::acc::PointerLikeType>":$varPtr,
+ "::mlir::TypedValue<::mlir::acc::PointerLikeType>":$varToFree,
+ "::mlir::Value":$allocRes,
"::mlir::Type":$varType),
/*methodBody=*/"",
/*defaultImplementation=*/[{
@@ -274,6 +300,14 @@ def OpenACC_MappableTypeInterface : TypeInterface<"MappableType"> {
The `initVal` can be empty - it is primarily needed for reductions
to ensure the variable is also initialized with appropriate value.
+ The `needsDestroy` out-parameter is set by implementations to indicate
+ that destruction code must be generated after all uses of the returned
+ private variable, typically in the destroy region of recipe operations
+ (for example, when heap allocations or temporaries requiring cleanup
+ are created during initialization). When `needsDestroy` is set, callers
+ should invoke `generatePrivateDestroy` in the recipe's destroy region
+ with the privatized value returned by this method.
+
If the return value is empty, it means that recipe body was not
successfully generated.
}],
@@ -284,12 +318,38 @@ def OpenACC_MappableTypeInterface : TypeInterface<"MappableType"> {
"::mlir::TypedValue<::mlir::acc::MappableType>":$var,
"::llvm::StringRef":$varName,
"::mlir::ValueRange":$extents,
- "::mlir::Value":$initVal),
+ "::mlir::Value":$initVal,
+ "bool &":$needsDestroy),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return {};
}]
>,
+ InterfaceMethod<
+ /*description=*/[{
+ Generates destruction operations for a privatized value previously
+ produced by `generatePrivateInit`. This is typically inserted in a
+ recipe's destroy region, after all uses of the privatized value.
+
+ The `privatized` value is the SSA value yielded by the init region
+ (and passed as the privatized argument to the destroy region).
+ Implementations should free heap-allocated storage or perform any
+ cleanup required for the given type. If no destruction is required,
+ this function should be a no-op and return `true`.
+
+ Returns true if destruction was successfully generated or deemed not
+ necessary, false otherwise.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"generatePrivateDestroy",
+ /*args=*/(ins "::mlir::OpBuilder &":$builder,
+ "::mlir::Location":$loc,
+ "::mlir::Value":$privatized),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return true;
+ }]
+ >,
];
}
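Taken together, the `needsDestroy` out-parameter and the new `generatePrivateDestroy` method imply a two-phase caller flow. A hedged sketch of that flow, using only the signatures declared above and assuming `generatePrivateInit` returns the privatized `Value` as its description states; block handling is simplified, since real recipes pass the privatized value into the destroy region as a block argument. The PointerLikeType `genAllocate`/`genFree` pair follows the same pattern via `needsFree`.

```cpp
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/IR/Builders.h"

// Hedged sketch of the caller-side flow: run the init in the init region and
// emit cleanup in the destroy region only when the implementation asked for
// it via `needsDestroy`.
static mlir::LogicalResult
emitPrivateRegions(mlir::OpBuilder &builder, mlir::Location loc,
                   mlir::TypedValue<mlir::acc::MappableType> var,
                   mlir::Block &initBlock, mlir::Block &destroyBlock) {
  mlir::acc::MappableType mappableTy = var.getType();

  builder.setInsertionPointToStart(&initBlock);
  bool needsDestroy = false;
  mlir::Value privatized = mappableTy.generatePrivateInit(
      builder, loc, var, /*varName=*/"x", /*extents=*/mlir::ValueRange(),
      /*initVal=*/mlir::Value(), needsDestroy);
  if (!privatized)
    return mlir::failure(); // init body could not be generated

  if (needsDestroy) {
    // Real recipes hand the privatized value to the destroy region as a
    // block argument; reusing `privatized` directly keeps the sketch short.
    builder.setInsertionPointToStart(&destroyBlock);
    if (!mappableTy.generatePrivateDestroy(builder, loc, privatized))
      return mlir::failure();
  }
  return mlir::success();
}
```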
diff --git a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
index 29b384f..5e68f75e 100644
--- a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
+++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
@@ -174,7 +174,7 @@ def Shard_NeighborsLinearIndicesOp : Shard_Op<"neighbors_linear_indices", [
```
The above returns two indices, `633` and `693`, which correspond to the
index of the previous process `(1, 1, 3)`, and the next process
- `(1, 3, 3) along the split axis `1`.
+ `(1, 3, 3)` along the split axis `1`.
A negative value is returned if there is no neighbor in the respective
direction along the given `split_axes`.
@@ -494,7 +494,9 @@ def Shard_AllGatherOp : Shard_CollectiveCommunicationOpBase<"all_gather", [
]> {
let summary = "All-gather over a device grid.";
let description = [{
- Gathers along the `gather_axis` tensor axis.
+ Concatenates all tensor slices from a device group defined by `grid_axes` along
+ the tensor dimension `gather_axis` and replicates the result across all devices
+ in the group.
Example:
```mlir
@@ -546,10 +548,13 @@ def Shard_AllReduceOp : Shard_CollectiveCommunicationOpBase<"all_reduce", [
SameOperandsAndResultShape]> {
let summary = "All-reduce over a device grid.";
let description = [{
- The accumulation element type is specified by the result type and
- it does not need to match the input element type.
- The input element is converted to the result element type before
- performing the reduction.
+ Reduces the input tensor across all devices within the groups defined by
+ `grid_axes`, using the specified reduction method. The operation performs an
+ element-wise reduction over the tensor slices from all devices in each group.
+ Each device in a group receives a replicated copy of the reduction result.
+ The accumulation element type is determined by the result type and does not
+ need to match the input element type. Before performing the reduction, each
+ input element is converted to the result element type.
Attributes:
`reduction`: Indicates the reduction method.
@@ -583,13 +588,15 @@ def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [
SameOperandsAndResultElementType,
SameOperandsAndResultRank
]> {
- let summary = "All-slice over a device grid. This is the inverse of all-gather.";
+ let summary = "All-slice over a device grid.";
let description = [{
- Slice along the `slice_axis` tensor axis.
- This operation can be thought of as the inverse of all-gather.
- Technically, it is not required that all processes have the same input tensor.
- Each process will slice a piece of its local tensor based on its in-group device index.
- The operation does not communicate data between devices.
+ Within each device group defined by `grid_axes`, slices the input tensor along
+ the `slice_axis` dimension. It can be viewed as the inverse of an all-gather if
+ the input data is replicated along the `slice_axis`.
+ Each process simply crops its local data to the slice corresponding to its
+ in-group device index.
+ Note that `AllSliceOp` does not involve any communication between devices,
+ and devices within a group need not have replicated input data.
Example:
```mlir
@@ -610,7 +617,7 @@ def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [
```
Result:
```
- gather tensor
+ slice tensor
axis 1
------------>
+-------+-------+
@@ -646,8 +653,10 @@ def Shard_AllToAllOp : Shard_CollectiveCommunicationOpBase<"all_to_all", [
SameOperandsAndResultRank]> {
let summary = "All-to-all over a device grid.";
let description = [{
- Performs an all-to-all on tensor pieces split along `split_axis`.
- The resulting pieces are concatenated along `concat_axis` on ech device.
+ Each participant logically splits its input along `split_axis`,
+ then scatters the resulting pieces across the group defined by `grid_axes`.
+ After receiving data pieces from other participants' scatters,
+ it concatenates them along `concat_axis` to produce the final result.
Example:
```
@@ -702,10 +711,9 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
]> {
let summary = "Broadcast over a device grid.";
let description = [{
- Broadcast the tensor on `root` to all devices in each respective group.
- The operation broadcasts along grid axes `grid_axes`.
- The `root` device specifies the in-group multi-index that is broadcast to
- all other devices in the group.
+ Copies the input tensor on `root` to all devices in each group defined by
+ `grid_axes`. The `root` device is defined by its in-group multi-index.
+ The contents of input tensors on non-root devices are ignored.
Example:
```
@@ -722,7 +730,7 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
+-------+-------+ | broadcast
device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1) | along axis 0
+-------+-------+ ↓
- device (1, 0) -> | | | <- device (1, 1)
+ device (1, 0) -> | * * | * * | <- device (1, 1)
+-------+-------+
```
@@ -758,11 +766,10 @@ def Shard_GatherOp : Shard_CollectiveCommunicationOpBase<"gather", [
]> {
let summary = "Gather over a device grid.";
let description = [{
- Gathers on device `root` along the `gather_axis` tensor axis.
- `root` specifies the coordinates of a device along `grid_axes`.
- It uniquely identifies the root device for each device group.
- The result tensor on non-root devices is undefined.
- Using it will result in undefined behavior.
+ Concatenates all tensor slices from a device group defined by `grid_axes` along
+ the tensor dimension `gather_axis` and returns the resulting tensor on each
+ `root` device. The result on all other (non-root) devices is undefined.
+ The `root` device is defined by its in-group multi-index.
Example:
```mlir
@@ -821,7 +828,9 @@ def Shard_RecvOp : Shard_CollectiveCommunicationOpBase<"recv", [
]> {
let summary = "Send over a device grid.";
let description = [{
- Receive from a device within a device group.
+ Receive a tensor from the device `source`, which is defined by its in-group
+ multi-index. The groups are defined by `grid_axes`.
+ The content of the input tensor is ignored.
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
@@ -845,13 +854,15 @@ def Shard_ReduceOp : Shard_CollectiveCommunicationOpBase<"reduce", [
]> {
let summary = "Reduce over a device grid.";
let description = [{
- Reduces on device `root` within each device group.
- `root` specifies the coordinates of a device along `grid_axes`.
- It uniquely identifies the root device within its device group.
- The accumulation element type is specified by the result type and
- it does not need to match the input element type.
- The input element is converted to the result element type before
- performing the reduction.
+ Reduces the input tensor across all devices within the groups defined by
+ `grid_axes`, using the specified reduction method. The operation performs an
+ element-wise reduction over the tensor slices from all devices in each group.
+ The reduction result will be returned on the `root` device of each group.
+ It is undefined on all other (non-root) devices.
+ The `root` device is defined by its in-group multi-index.
+ The accumulation element type is determined by the result type and does not
+ need to match the input element type. Before performing the reduction, each
+ input element is converted to the result element type.
Attributes:
`reduction`: Indicates the reduction method.
@@ -886,16 +897,18 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter"
SameOperandsAndResultRank]> {
let summary = "Reduce-scatter over a device grid.";
let description = [{
- After the reduction, the result is scattered within each device group.
- The tensor is split along `scatter_axis` and the pieces distributed
- across the device group.
+ Reduces the input tensor across all devices within the groups defined by
+ `grid_axes` using the specified reduction method. The reduction is performed
+ element-wise across the tensor pieces from all devices in the group.
+ After reduction, the reduction result is scattered (split and distributed)
+ across the device group along `scatter_axis`.
Example:
```
shard.grid @grid0(shape = 2x2)
...
%1 = shard.reduce_scatter %0 on @grid0 grid_axes = [1]
reduction = <max> scatter_axis = 0
- : tensor<3x4xf32> -> tensor<1x4xf64>
+ : tensor<2x2xf32> -> tensor<1x2xf64>
```
Input:
```
@@ -916,13 +929,13 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter"
Result:
```
+-------+
- | 6 8 | <- devices (0, 0)
+ | 5 6 | <- devices (0, 0)
+-------+
- | 10 12 | <- devices (0, 1)
+ | 7 8 | <- devices (0, 1)
+-------+
- | 22 24 | <- devices (1, 0)
+ | 13 14 | <- devices (1, 0)
+-------+
- | 26 28 | <- devices (1, 1)
+ | 15 16 | <- devices (1, 1)
+-------+
```
}];
@@ -950,8 +963,10 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
]> {
let summary = "Scatter over a device grid.";
let description = [{
- For each device group split the input tensor on the `root` device along
- axis `scatter_axis` and scatter the parts across the group devices.
+ For each device group defined by `grid_axes`, the input tensor on the `root`
+ device is split along axis `scatter_axis` and distributed across the group.
+ The content of the input on all other (non-root) devices is ignored.
+ The `root` device is defined by its in-group multi-index.
Example:
```
@@ -968,8 +983,8 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
(0, 1)
↓
+-------+-------+ | scatter tensor
- device (0, 0) -> | | | | axis 0
- | | | ↓
+ device (0, 0) -> | * * | * * | | axis 0
+ | * * | * * | ↓
+-------+-------+
device (1, 0) -> | 1 2 | 5 6 |
| 3 4 | 7 8 |
@@ -1018,7 +1033,8 @@ def Shard_SendOp : Shard_CollectiveCommunicationOpBase<"send", [
]> {
let summary = "Send over a device grid.";
let description = [{
- Send from one device to another within a device group.
+ Send the input tensor to the device `destination`, which is defined by its
+ in-group multi-index. The groups are defined by `grid_axes`.
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
@@ -1043,12 +1059,11 @@ def Shard_ShiftOp : Shard_CollectiveCommunicationOpBase<"shift", [
]> {
let summary = "Shift over a device grid.";
let description = [{
- Within each device group shift along grid axis `shift_axis` by an offset
- `offset`.
- The result on devices that do not have a corresponding source is undefined.
- `shift_axis` must be one of `grid_axes`.
- If the `rotate` attribute is present,
- instead of a shift a rotation is done.
+ Within each device group defined by `grid_axes`, shifts input tensors along the
+ device grid's axis `shift_axis` by the specified offset. The `shift_axis` must
+ be one of the `grid_axes`. If the `rotate` attribute is set, the shift is circular.
+ That is, the offset wraps around according to the group size along `shift_axis`.
+ Otherwise, the results on devices without a corresponding source are undefined.
Example:
```
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
index 10491f6..4ecf03c 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
@@ -50,28 +50,63 @@ TargetEnvAttr getDefaultTargetEnv(MLIRContext *context);
/// returned by getDefaultTargetEnv() if not provided.
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op);
+/// A thin wrapper around the SpecificationVersion enum that represents the
+/// TOSA specification version and provides utilities around it.
+class TosaSpecificationVersion {
+public:
+ TosaSpecificationVersion(uint32_t major, uint32_t minor)
+ : majorVersion(major), minorVersion(minor) {}
+ TosaSpecificationVersion(SpecificationVersion version)
+ : TosaSpecificationVersion(fromVersionEnum(version)) {}
+
+ bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const {
+ return this->majorVersion == baseVersion.majorVersion &&
+ this->minorVersion >= baseVersion.minorVersion;
+ }
+
+ uint32_t getMajor() const { return majorVersion; }
+ uint32_t getMinor() const { return minorVersion; }
+
+private:
+ uint32_t majorVersion = 0;
+ uint32_t minorVersion = 0;
+
+ static TosaSpecificationVersion
+ fromVersionEnum(SpecificationVersion version) {
+ switch (version) {
+ case SpecificationVersion::V_1_0:
+ return TosaSpecificationVersion(1, 0);
+ case SpecificationVersion::V_1_1_DRAFT:
+ return TosaSpecificationVersion(1, 1);
+ }
+ llvm_unreachable("Unknown TOSA version");
+ }
+};
+
+llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version);
+
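A small hedged example of the intended compatibility check, based only on the wrapper above: a target at a newer minor version of the same major version accepts requirements introduced at an older minor version, but not vice versa.

```cpp
#include "mlir/Dialect/Tosa/IR/TargetEnv.h"

// Illustrative: does a target at `targetVersion` satisfy a requirement
// introduced at `requiredVersion`? True when the major versions match and
// the target's minor version is at least the required one.
static bool satisfies(mlir::tosa::SpecificationVersion targetVersion,
                      mlir::tosa::TosaSpecificationVersion requiredVersion) {
  return mlir::tosa::TosaSpecificationVersion(targetVersion)
      .isBackwardsCompatibleWith(requiredVersion);
}

// satisfies(SpecificationVersion::V_1_1_DRAFT, {1, 0}) -> true
// satisfies(SpecificationVersion::V_1_0, {1, 1})       -> false
```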
/// This class represents the capability enabled in the target implementation
/// such as profile, extension, and level. It's a wrapper class around
/// tosa::TargetEnvAttr.
class TargetEnv {
public:
TargetEnv() {}
- explicit TargetEnv(Level level, const ArrayRef<Profile> &profiles,
+ explicit TargetEnv(SpecificationVersion specificationVersion, Level level,
+ const ArrayRef<Profile> &profiles,
const ArrayRef<Extension> &extensions)
- : level(level) {
+ : specificationVersion(specificationVersion), level(level) {
enabledProfiles.insert_range(profiles);
enabledExtensions.insert_range(extensions);
}
explicit TargetEnv(TargetEnvAttr targetAttr)
- : TargetEnv(targetAttr.getLevel(), targetAttr.getProfiles(),
- targetAttr.getExtensions()) {}
+ : TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(),
+ targetAttr.getProfiles(), targetAttr.getExtensions()) {}
void addProfile(Profile p) { enabledProfiles.insert(p); }
void addExtension(Extension e) { enabledExtensions.insert(e); }
- // TODO implement the following utilities.
- // Version getSpecVersion() const;
+ SpecificationVersion getSpecVersion() const { return specificationVersion; }
TosaLevel getLevel() const {
if (level == Level::eightK)
@@ -105,6 +140,7 @@ public:
}
private:
+ SpecificationVersion specificationVersion;
Level level;
llvm::SmallSet<Profile, 3> enabledProfiles;
llvm::SmallSet<Extension, 13> enabledExtensions;
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index 1f718ac..c1b5e78 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -2,441 +2,779 @@
// `tools/genspec.py` in https://git.mlplatform.org/tosa/specification.git
profileComplianceMap = {
{"tosa.argmax",
- {{{Profile::pro_int}, {{i8T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, i32T}, {fp32T, i32T}}}}},
+ {{{Profile::pro_int}, {{{i8T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, i32T}, SpecificationVersion::V_1_0},
+ {{fp32T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.avg_pool2d",
- {{{Profile::pro_int}, {{i8T, i8T, i8T, i32T, i8T}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i8T, i32T, i8T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp32T, fp16T},
- {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{fp16T, fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp16T, fp16T, fp16T, fp32T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.conv2d",
- {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
- {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T},
+ SpecificationVersion::V_1_0}}}}},
{"tosa.conv3d",
- {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
- {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T},
+ SpecificationVersion::V_1_0}}}}},
{"tosa.depthwise_conv2d",
- {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
- {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T},
+ SpecificationVersion::V_1_0}}}}},
{"tosa.matmul",
- {{{Profile::pro_int}, {{i8T, i8T, i8T, i8T, i32T}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i8T, i8T, i32T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp16T, fp32T},
- {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{fp16T, fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp16T, fp16T, fp16T, fp16T, fp32T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.max_pool2d",
- {{{Profile::pro_int}, {{i8T, i8T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int}, {{{i8T, i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.transpose_conv2d",
- {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
- {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
- {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T},
+ SpecificationVersion::V_1_0}}}}},
{"tosa.clamp",
- {{{Profile::pro_int}, {{i8T, i8T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.erf", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.sigmoid", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.tanh", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int}, {{{i8T, i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.erf",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.sigmoid",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.tanh",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.add",
- {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.arithmetic_right_shift",
{{{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.bitwise_and",
{{{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.bitwise_or",
{{{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.bitwise_xor",
{{{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.intdiv",
- {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}},
+ anyOf}}},
{"tosa.logical_and",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf}}},
{"tosa.logical_left_shift",
{{{Profile::pro_int, Profile::pro_fp},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}},
anyOf}}},
{"tosa.logical_right_shift",
{{{Profile::pro_int, Profile::pro_fp},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}},
anyOf}}},
{"tosa.logical_or",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf}}},
{"tosa.logical_xor",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf}}},
{"tosa.maximum",
- {{{Profile::pro_int}, {{i32T, i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+ {{{Profile::pro_int}, {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.minimum",
- {{{Profile::pro_int}, {{i32T, i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+ {{{Profile::pro_int}, {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.mul",
- {{{Profile::pro_int}, {{i8T, i8T, i32T}, {i16T, i16T, i32T}}},
- {{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T, i32T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_int, Profile::pro_fp},
+ {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.pow",
- {{{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.sub",
- {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
- {"tosa.table", {{{Profile::pro_int}, {{i8T, i8T, i8T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.table",
+ {{{Profile::pro_int}, {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}}}}},
{"tosa.abs",
- {{{Profile::pro_int}, {{i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int}, {{{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.bitwise_not",
- {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}}},
- {"tosa.ceil", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.clz", {{{Profile::pro_int}, {{i32T, i32T}}}}},
- {"tosa.cos", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.exp", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.floor", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.log", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.ceil",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.clz",
+ {{{Profile::pro_int}, {{{i32T, i32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.cos",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.exp",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.floor",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.log",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.logical_not",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf}}},
{"tosa.negate",
{{{Profile::pro_int},
- {{i8T, i8T, i8T, i8T},
- {i16T, i16T, i16T, i16T},
- {i32T, i32T, i32T, i32T}}},
+ {{{i8T, i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reciprocal",
- {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.rsqrt", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
- {"tosa.sin", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.rsqrt",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.sin",
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.select",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
{{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.equal",
- {{{Profile::pro_int}, {{i32T, i32T, boolT}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}},
+ {{{Profile::pro_int},
+ {{{i32T, i32T, boolT}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, boolT}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, boolT}, SpecificationVersion::V_1_0}}}}},
{"tosa.greater",
- {{{Profile::pro_int}, {{i32T, i32T, boolT}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}},
+ {{{Profile::pro_int},
+ {{{i32T, i32T, boolT}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, boolT}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, boolT}, SpecificationVersion::V_1_0}}}}},
{"tosa.greater_equal",
- {{{Profile::pro_int}, {{i32T, i32T, boolT}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}},
+ {{{Profile::pro_int},
+ {{{i32T, i32T, boolT}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, boolT}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, boolT}, SpecificationVersion::V_1_0}}}}},
{"tosa.reduce_all",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf}}},
{"tosa.reduce_any",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf}}},
{"tosa.reduce_max",
- {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reduce_min",
- {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reduce_product",
- {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reduce_sum",
- {{{Profile::pro_int}, {{i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int}, {{{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.concat",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
- {{Profile::pro_int}, {{i8T, i8T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.pad",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
{{Profile::pro_int},
- {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+ {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reshape",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
- {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reverse",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
- {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.slice",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
- {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.tile",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
- {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.transpose",
- {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf},
- {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int, Profile::pro_fp},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}},
+ anyOf},
+ {{Profile::pro_int},
+ {{{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.gather",
{{{Profile::pro_int},
- {{i8T, i32T, i8T}, {i16T, i32T, i16T}, {i32T, i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, i32T, fp16T}, {fp32T, i32T, fp32T}}}}},
+ {{{i8T, i32T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i32T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, i32T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, i32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.scatter",
{{{Profile::pro_int},
- {{i8T, i32T, i8T, i8T},
- {i16T, i32T, i16T, i16T},
- {i32T, i32T, i32T, i32T}}},
+ {{{i8T, i32T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i32T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_0}}},
{{Profile::pro_fp},
- {{fp16T, i32T, fp16T, fp16T}, {fp32T, i32T, fp32T, fp32T}}}}},
+ {{{fp16T, i32T, fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, i32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.resize",
- {{{Profile::pro_int}, {{i8T, i32T}, {i8T, i8T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int},
+ {{{i8T, i32T}, SpecificationVersion::V_1_0},
+ {{i8T, i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.cast",
{{{Profile::pro_int},
- {{boolT, i8T},
- {boolT, i16T},
- {boolT, i32T},
- {i8T, boolT},
- {i8T, i16T},
- {i8T, i32T},
- {i16T, boolT},
- {i16T, i8T},
- {i16T, i32T},
- {i32T, boolT},
- {i32T, i8T},
- {i32T, i16T}}},
- {{Profile::pro_fp},
- {{i8T, fp16T},
- {i8T, fp32T},
- {i16T, fp16T},
- {i16T, fp32T},
- {i32T, fp16T},
- {i32T, fp32T},
- {fp16T, i8T},
- {fp16T, i16T},
- {fp16T, i32T},
- {fp16T, fp32T},
- {fp32T, i8T},
- {fp32T, i16T},
- {fp32T, i32T},
- {fp32T, fp16T}}}}},
+ {{{boolT, i8T}, SpecificationVersion::V_1_0},
+ {{boolT, i16T}, SpecificationVersion::V_1_0},
+ {{boolT, i32T}, SpecificationVersion::V_1_0},
+ {{i8T, boolT}, SpecificationVersion::V_1_0},
+ {{i8T, i16T}, SpecificationVersion::V_1_0},
+ {{i8T, i32T}, SpecificationVersion::V_1_0},
+ {{i16T, boolT}, SpecificationVersion::V_1_0},
+ {{i16T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i32T}, SpecificationVersion::V_1_0},
+ {{i32T, boolT}, SpecificationVersion::V_1_0},
+ {{i32T, i8T}, SpecificationVersion::V_1_0},
+ {{i32T, i16T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{i8T, fp16T}, SpecificationVersion::V_1_0},
+ {{i8T, fp32T}, SpecificationVersion::V_1_0},
+ {{i16T, fp16T}, SpecificationVersion::V_1_0},
+ {{i16T, fp32T}, SpecificationVersion::V_1_0},
+ {{i32T, fp16T}, SpecificationVersion::V_1_0},
+ {{i32T, fp32T}, SpecificationVersion::V_1_0},
+ {{fp16T, i8T}, SpecificationVersion::V_1_0},
+ {{fp16T, i16T}, SpecificationVersion::V_1_0},
+ {{fp16T, i32T}, SpecificationVersion::V_1_0},
+ {{fp16T, fp32T}, SpecificationVersion::V_1_0},
+ {{fp32T, i8T}, SpecificationVersion::V_1_0},
+ {{fp32T, i16T}, SpecificationVersion::V_1_0},
+ {{fp32T, i32T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.rescale",
{{{Profile::pro_int},
- {{i8T, i8T, i8T, i8T},
- {i8T, i8T, i16T, i16T},
- {i8T, i8T, i32T, i32T},
- {i16T, i16T, i8T, i8T},
- {i16T, i16T, i16T, i16T},
- {i16T, i16T, i32T, i32T},
- {i32T, i32T, i8T, i8T},
- {i32T, i32T, i16T, i16T},
- {i32T, i32T, i32T, i32T}}}}},
+ {{{i8T, i8T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i8T, i8T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T, i32T, i32T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.const",
{{{Profile::pro_int, Profile::pro_fp},
- {{boolT}, {i8T}, {i16T}, {i32T}},
+ {{{boolT}, SpecificationVersion::V_1_0},
+ {{i8T}, SpecificationVersion::V_1_0},
+ {{i16T}, SpecificationVersion::V_1_0},
+ {{i32T}, SpecificationVersion::V_1_0}},
anyOf},
- {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
+ {{Profile::pro_fp},
+ {{{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.identity",
{{{Profile::pro_int, Profile::pro_fp},
- {{boolT, boolT}, {i8T, i8T}, {i16T, i16T}, {i32T, i32T}},
+ {{{boolT, boolT}, SpecificationVersion::V_1_0},
+ {{i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i32T, i32T}, SpecificationVersion::V_1_0}},
anyOf},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.variable",
- {{{Profile::pro_int}, {{i8T}}}, {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
+ {{{Profile::pro_int}, {{{i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.variable_write",
- {{{Profile::pro_int}, {{i8T}}}, {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
+ {{{Profile::pro_int}, {{{i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.variable_read",
- {{{Profile::pro_int}, {{i8T}}}, {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
+ {{{Profile::pro_int}, {{{i8T}, SpecificationVersion::V_1_0}}},
+ {{Profile::pro_fp},
+ {{{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
};
extensionComplianceMap = {
{"tosa.argmax",
- {{{Extension::int16}, {{i16T, i32T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, i32T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, i32T}}},
- {{Extension::bf16}, {{bf16T, i32T}}}}},
+ {{{Extension::int16}, {{{i16T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3}, {{{fp8e4m3T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2}, {{{fp8e5m2T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.avg_pool2d",
- {{{Extension::int16}, {{i16T, i16T, i16T, i32T, i16T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+ {{{Extension::int16},
+ {{{i16T, i16T, i16T, i32T, i16T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T},
+ SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T},
+ SpecificationVersion::V_1_0}}},
+ {{Extension::bf16},
+ {{{bf16T, bf16T, bf16T, fp32T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.conv2d",
- {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}},
- {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}},
+ {{{Extension::int4},
+ {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::int16},
+ {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e4m3},
- {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}},
+ {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e5m2},
- {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}},
+ {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::bf16},
- {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+ {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
+ SpecificationVersion::V_1_0}}}}},
{"tosa.conv3d",
- {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}},
- {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}},
+ {{{Extension::int4},
+ {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::int16},
+ {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e4m3},
- {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}},
+ {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e5m2},
- {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}},
+ {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::bf16},
- {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+ {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
+ SpecificationVersion::V_1_0}}}}},
{"tosa.depthwise_conv2d",
- {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}},
- {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}},
+ {{{Extension::int4},
+ {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::int16},
+ {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e4m3},
- {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}},
+ {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e5m2},
- {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}},
+ {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::bf16},
- {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
- {"tosa.fft2d", {{{Extension::fft}, {{fp32T, fp32T, fp32T, fp32T}}}}},
+ {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
+ SpecificationVersion::V_1_0}}}}},
+ {"tosa.fft2d",
+ {{{Extension::fft},
+ {{{fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.matmul",
- {{{Extension::int16}, {{i16T, i16T, i16T, i16T, i48T}}},
+ {{{Extension::int16},
+ {{{i16T, i16T, i16T, i16T, i48T}, SpecificationVersion::V_1_0}}},
{{Extension::fp8e4m3},
- {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T},
- {fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp32T}}},
+ {{{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp32T},
+ SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::fp8e5m2},
- {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T},
- {fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp32T}}},
+ {{{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T},
+ SpecificationVersion::V_1_0},
+ {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp32T},
+ SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::fp8e4m3, Extension::fp8e5m2},
- {{fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp16T},
- {fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp32T},
- {fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp16T},
- {fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp32T}},
+ {{{fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp16T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp32T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp16T},
+ SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp32T},
+ SpecificationVersion::V_1_1_DRAFT}},
allOf},
- {{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T, fp32T}}}}},
+ {{Extension::bf16},
+ {{{bf16T, bf16T, bf16T, bf16T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.max_pool2d",
- {{{Extension::int16}, {{i16T, i16T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.rfft2d", {{{Extension::fft}, {{fp32T, fp32T, fp32T}}}}},
+ {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.rfft2d",
+ {{{Extension::fft},
+ {{{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.transpose_conv2d",
- {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}},
- {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}},
+ {{{Extension::int4},
+ {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}},
+ {{Extension::int16},
+ {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e4m3},
- {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}},
+ {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::fp8e5m2},
- {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}},
+ {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T},
+ SpecificationVersion::V_1_0}}},
{{Extension::bf16},
- {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
+ {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
+ SpecificationVersion::V_1_0}}}}},
{"tosa.clamp",
- {{{Extension::int16}, {{i16T, i16T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.erf", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.sigmoid", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.tanh", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.add", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.maximum", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.minimum", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.mul", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.pow", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.sub", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.table", {{{Extension::int16}, {{i16T, i16T, i32T}}}}},
- {"tosa.abs", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.ceil", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.cos", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.exp", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.floor", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.log", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T}}}}},
- {"tosa.reciprocal", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.rsqrt", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.sin", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.select", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
- {"tosa.equal", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}},
- {"tosa.greater", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}},
- {"tosa.greater_equal", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}},
- {"tosa.reduce_max", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.reduce_min", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.reduce_product", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.reduce_sum", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.erf",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.sigmoid",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.tanh",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.add",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.maximum",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.minimum",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.mul",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.pow",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.sub",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.table",
+ {{{Extension::int16},
+ {{{i16T, i16T, i32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.abs",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.ceil",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.cos",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.exp",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.floor",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.log",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.negate",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reciprocal",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.rsqrt",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.sin",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.select",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.equal",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, boolT}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.greater",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, boolT}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.greater_equal",
+ {{{Extension::bf16},
+ {{{bf16T, bf16T, boolT}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reduce_max",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reduce_min",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reduce_product",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.reduce_sum",
+ {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.concat",
- {{{Extension::int16}, {{i16T, i16T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.pad",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16},
+ {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reshape",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.reverse",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.slice",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.tile",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.transpose",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.gather",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, i32T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, i32T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, i32T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, i32T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, i32T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16},
+ {{{bf16T, i32T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.scatter",
- {{{Extension::fp8e4m3}, {{fp8e4m3T, i32T, fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, i32T, fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, i32T, bf16T, bf16T}}}}},
+ {{{Extension::fp8e4m3},
+ {{{fp8e4m3T, i32T, fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, i32T, fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16},
+ {{{bf16T, i32T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.resize",
- {{{Extension::int16}, {{i16T, i48T}, {i16T, i16T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {{{Extension::int16},
+ {{{i16T, i48T}, SpecificationVersion::V_1_0},
+ {{i16T, i16T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.cast",
{{{Extension::bf16},
- {{i8T, bf16T},
- {i16T, bf16T},
- {i32T, bf16T},
- {bf16T, i8T},
- {bf16T, i16T},
- {bf16T, i32T},
- {bf16T, fp32T},
- {fp32T, bf16T}}},
+ {{{i8T, bf16T}, SpecificationVersion::V_1_0},
+ {{i16T, bf16T}, SpecificationVersion::V_1_0},
+ {{i32T, bf16T}, SpecificationVersion::V_1_0},
+ {{bf16T, i8T}, SpecificationVersion::V_1_0},
+ {{bf16T, i16T}, SpecificationVersion::V_1_0},
+ {{bf16T, i32T}, SpecificationVersion::V_1_0},
+ {{bf16T, fp32T}, SpecificationVersion::V_1_0},
+ {{fp32T, bf16T}, SpecificationVersion::V_1_0}}},
{{Extension::bf16, Extension::fp8e4m3},
- {{bf16T, fp8e4m3T}, {fp8e4m3T, bf16T}},
+ {{{bf16T, fp8e4m3T}, SpecificationVersion::V_1_0},
+ {{fp8e4m3T, bf16T}, SpecificationVersion::V_1_0}},
allOf},
{{Extension::bf16, Extension::fp8e5m2},
- {{bf16T, fp8e5m2T}, {fp8e5m2T, bf16T}},
+ {{{bf16T, fp8e5m2T}, SpecificationVersion::V_1_0},
+ {{fp8e5m2T, bf16T}, SpecificationVersion::V_1_0}},
allOf},
{{Extension::fp8e4m3},
- {{fp8e4m3T, fp16T},
- {fp8e4m3T, fp32T},
- {fp16T, fp8e4m3T},
- {fp32T, fp8e4m3T}}},
+ {{{fp8e4m3T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp8e4m3T, fp32T}, SpecificationVersion::V_1_0},
+ {{fp16T, fp8e4m3T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
{{Extension::fp8e5m2},
- {{fp8e5m2T, fp16T},
- {fp8e5m2T, fp32T},
- {fp16T, fp8e5m2T},
- {fp32T, fp8e5m2T}}}}},
+ {{{fp8e5m2T, fp16T}, SpecificationVersion::V_1_0},
+ {{fp8e5m2T, fp32T}, SpecificationVersion::V_1_0},
+ {{fp16T, fp8e5m2T}, SpecificationVersion::V_1_0},
+ {{fp32T, fp8e5m2T}, SpecificationVersion::V_1_0}}}}},
{"tosa.rescale",
{{{Extension::int16},
- {{i48T, i48T, i8T, i8T},
- {i48T, i48T, i16T, i16T},
- {i48T, i48T, i32T, i32T}}}}},
+ {{{i48T, i48T, i8T, i8T}, SpecificationVersion::V_1_0},
+ {{i48T, i48T, i16T, i16T}, SpecificationVersion::V_1_0},
+ {{i48T, i48T, i32T, i32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.const",
- {{{Extension::int4}, {{i4T}}},
- {{Extension::int16}, {{i48T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T}}}}},
+ {{{Extension::int4}, {{{i4T}, SpecificationVersion::V_1_0}}},
+ {{Extension::int16}, {{{i48T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3}, {{{fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2}, {{{fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T}, SpecificationVersion::V_1_0}}}}},
{"tosa.identity",
- {{{Extension::int4}, {{i4T, i4T}}},
- {{Extension::int16}, {{i48T, i48T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}},
- {{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.variable", {{{Extension::variable}, {{i8T}, {fp16T}, {fp32T}}}}},
+ {{{Extension::int4}, {{{i4T, i4T}, SpecificationVersion::V_1_0}}},
+ {{Extension::int16}, {{{i48T, i48T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.variable",
+ {{{Extension::variable},
+ {{{i8T}, SpecificationVersion::V_1_0},
+ {{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.variable_write",
- {{{Extension::variable}, {{i8T}, {fp16T}, {fp32T}}}}},
+ {{{Extension::variable},
+ {{{i8T}, SpecificationVersion::V_1_0},
+ {{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.variable_read",
- {{{Extension::variable}, {{i8T}, {fp16T}, {fp32T}}}}},
+ {{{Extension::variable},
+ {{{i8T}, SpecificationVersion::V_1_0},
+ {{fp16T}, SpecificationVersion::V_1_0},
+ {{fp32T}, SpecificationVersion::V_1_0}}}}},
};
+
// End of auto-generated metadata
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 38cb293..8376a4c 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -221,7 +221,7 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
}
//===----------------------------------------------------------------------===//
-// TOSA Spec Section 1.5.
+// TOSA Profiles and extensions
//
// Profile:
// INT : Integer Inference. Integer operations, primarily 8 and 32-bit values.
@@ -293,12 +293,6 @@ def Tosa_ExtensionAttr
def Tosa_ExtensionArrayAttr
: TypedArrayAttrBase<Tosa_ExtensionAttr, "TOSA extension array attribute">;
-def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>;
-def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">;
-
-def Tosa_LevelAttr
- : Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>;
-
// The base class for defining op availability dimensions.
class Availability {
// The following are fields for controlling the generated C++ OpInterface.
@@ -405,17 +399,40 @@ class Extension<list<I32EnumAttrCase> extensions> : Availability {
}
//===----------------------------------------------------------------------===//
+// TOSA Levels
+//===----------------------------------------------------------------------===//
+
+def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>;
+def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">;
+
+def Tosa_LevelAttr
+ : Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>;
+
+//===----------------------------------------------------------------------===//
+// TOSA Specification versions
+//===----------------------------------------------------------------------===//
+
+def Tosa_V_1_0 : I32EnumAttrCase<"V_1_0", 0, "1.0">;
+def Tosa_V_1_1_DRAFT : I32EnumAttrCase<"V_1_1_DRAFT", 1, "1.1.draft">;
+
+def Tosa_SpecificationVersion : Tosa_I32EnumAttr<
+ "SpecificationVersion", "TOSA specification version", "specification_version",
+ [Tosa_V_1_0, Tosa_V_1_1_DRAFT]>;
+
+//===----------------------------------------------------------------------===//
// TOSA target environment.
//===----------------------------------------------------------------------===//
def Tosa_TargetEnv : Tosa_Attr<"TargetEnv", "target_env"> {
let summary = "Target environment information.";
let parameters = ( ins
+ "SpecificationVersion": $specification_version,
"Level": $level,
ArrayRefParameter<"Profile">: $profiles,
ArrayRefParameter<"Extension">: $extensions
);
- let assemblyFormat = "`<` `level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` "
+ let assemblyFormat = "`<` `specification_version` `=` $specification_version `,` "
+ "`level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` "
"`extensions` `=` `[` $extensions `]` `>`";
}
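
For illustration, here is a minimal C++ sketch of constructing the extended target environment attribute with the new leading `specification_version` parameter; the enum case names and header path are assumptions based on the declarations in this patch and the usual TOSA naming, not verified against the generated code.

```cpp
// Sketch only: parameter order mirrors the updated TargetEnvAttr definition
// (specification_version, level, profiles, extensions). Profile::pro_int and
// Extension::int16 are assumed enum case names.
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // assumed header for the attribute/enums
#include "mlir/IR/MLIRContext.h"

mlir::tosa::TargetEnvAttr buildTargetEnv(mlir::MLIRContext &ctx) {
  return mlir::tosa::TargetEnvAttr::get(
      &ctx,
      mlir::tosa::SpecificationVersion::V_1_0, // new leading parameter
      mlir::tosa::Level::eightK,
      /*profiles=*/{mlir::tosa::Profile::pro_int},
      /*extensions=*/{mlir::tosa::Extension::int16});
  // Printed form, following the assemblyFormat above, roughly:
  //   #tosa.target_env<specification_version = 1.0, level = 8k,
  //                    profiles = [pro_int], extensions = [int16]>
}
```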
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index 8f5c72b..7b946ad 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -36,12 +36,15 @@ enum CheckCondition {
allOf
};
+using VersionedTypeInfo =
+ std::pair<SmallVector<TypeInfo>, SpecificationVersion>;
+
template <typename T>
struct OpComplianceInfo {
// Certain operations require multiple modes enabled.
// e.g. cast bf16 to fp8e4m3 requires EXT-BF16 and EXT-FP8E4M3.
SmallVector<T> mode;
- SmallVector<SmallVector<TypeInfo>> operandTypeInfoSet;
+ SmallVector<VersionedTypeInfo> operandTypeInfoSet;
CheckCondition condition = CheckCondition::anyOf;
};
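
A small, hedged illustration of what the new pairing buys: each operand/result type combination now carries the specification version it belongs to, so a checker can filter entries against a targeted version. The helper below is hypothetical and assumes the enumerators are ordered by release.

```cpp
// Illustrative only: consumes the VersionedTypeInfo alias declared above.
// isAvailableIn is a hypothetical helper, not part of this patch.
bool isAvailableIn(const VersionedTypeInfo &entry,
                   SpecificationVersion target) {
  // Assumes SpecificationVersion enumerators are ordered by release
  // (V_1_0 < V_1_1_DRAFT).
  return static_cast<unsigned>(entry.second) <= static_cast<unsigned>(target);
}
```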
@@ -130,9 +133,8 @@ public:
// Find the required profiles or extensions from the compliance info according
// to the operand type combination.
template <typename T>
- SmallVector<T> findMatchedProfile(Operation *op,
- SmallVector<OpComplianceInfo<T>> compInfo,
- CheckCondition &condition);
+ OpComplianceInfo<T>
+ findMatchedEntry(Operation *op, SmallVector<OpComplianceInfo<T>> compInfo);
SmallVector<Profile> getCooperativeProfiles(Extension ext) {
switch (ext) {
@@ -168,8 +170,7 @@ public:
private:
template <typename T>
- FailureOr<SmallVector<T>> getOperatorDefinition(Operation *op,
- CheckCondition &condition);
+ FailureOr<OpComplianceInfo<T>> getOperatorDefinition(Operation *op);
OperationProfileComplianceMap profileComplianceMap;
OperationExtensionComplianceMap extensionComplianceMap;
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 6ae19d8..14b00b0 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -137,6 +137,13 @@ def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> {
];
let options = [
+ Option<"specificationVersion", "specification_version", "mlir::tosa::SpecificationVersion",
+ /*default=*/"mlir::tosa::SpecificationVersion::V_1_0",
+ "The specification version that TOSA operators should conform to.",
+ [{::llvm::cl::values(
+ clEnumValN(mlir::tosa::SpecificationVersion::V_1_0, "1.0", "TOSA Specification version 1.0"),
+ clEnumValN(mlir::tosa::SpecificationVersion::V_1_1_DRAFT, "1.1.draft", "TOSA Specification version 1.1.draft")
+ )}]>,
Option<"level", "level", "mlir::tosa::Level",
/*default=*/"mlir::tosa::Level::eightK",
"The TOSA level that operators should conform to. A TOSA level defines "
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 6e79085..6e15b1e 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2999,6 +2999,7 @@ def Vector_StepOp : Vector_Op<"step", [
}];
let results = (outs VectorOfRankAndType<[1], [Index]>:$result);
let assemblyFormat = "attr-dict `:` type($result)";
+ let hasCanonicalizer = 1;
}
def Vector_YieldOp : Vector_Op<"yield", [
diff --git a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td
index b80ee2c..e9425e8 100644
--- a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td
+++ b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td
@@ -43,9 +43,41 @@ class WasmSSA_BlockLikeOp<string mnemonic, string summaryStr> :
let assemblyFormat = "(`(`$inputs^`)` `:` type($inputs))? attr-dict `:` $body `>` $target";
}
-def WasmSSA_BlockOp : WasmSSA_BlockLikeOp<"block", "Create a nesting level"> {}
+def WasmSSA_BlockOp : WasmSSA_BlockLikeOp<
+ "block",
+ "Create a nesting level with a label at its exit."> {
+ let description = [{
+ Defines a Wasm block, creating a new nested scope.
+ A block contains a body region and an optional list of input values.
+ Control can enter the block and later branch out to the block target.
+ Example:
+
+ ```mlir
+
+ wasmssa.block {
+
+ // instructions
+
+ } > ^successor
+ ```
+ }];
+}
+
+def WasmSSA_LoopOp : WasmSSA_BlockLikeOp<
+ "loop",
+ "Create a nesting level that define its entry as jump target."> {
+ let description = [{
+ Represents a Wasm loop construct. This defines a nesting level with
+ a label at the entry of the region.
-def WasmSSA_LoopOp : WasmSSA_BlockLikeOp<"loop", "Create a nesting level similar to Block Op, except that it has itself as a successor."> {}
+ Example:
+
+ ```mlir
+
+ wasmssa.loop {
+
+ } > ^successor
+ ```
+ }];
+}
def WasmSSA_BlockReturnOp : WasmSSA_Op<"block_return", [Terminator,
DeclareOpInterfaceMethods<LabelBranchingOpInterface>]> {
@@ -55,9 +87,16 @@ def WasmSSA_BlockReturnOp : WasmSSA_Op<"block_return", [Terminator,
::mlir::Block* getTarget();
}];
let description = [{
- Marks a return from the current block.
+ Escapes from the current nesting level and returns control flow to its successor.
+ Optionally, marks the arguments that should be transferred to the successor block.
- Example:
+ This shouldn't be confused with branch operations that target the label defined
+ by the nesting level operation.
+
+ For instance, a `wasmssa.block_return` in a loop gives control back to the
+ successor of the loop, whereas a `branch` targeting the loop flows back to the entry block of the loop.
+
+ Example:
```mlir
wasmssa.block_return
@@ -127,12 +166,18 @@ def WasmSSA_FuncOp : WasmSSA_Op<"func", [
- Arguments of the entry block of type `!wasm<local T>`, with T the corresponding type
in the function type.
+ By default, `wasmssa.func` has nested visibility. Functions exported by the module
+ are marked with the `exported` attribute, which gives them public visibility.
+
Example:
```mlir
- // A simple function with no arguments that returns a float32
+ // Internal function with no arguments that returns a float32
wasmssa.func @my_f32_func() -> f32
+ // Exported function with no arguments that returns a float32
+ wasmssa.func exported @my_f32_func() -> f32
+
// A function that takes a local ref argument
wasmssa.func @i64_wrap(%a: !wasmssa<local ref to i64>) -> i32
```
@@ -141,7 +186,7 @@ def WasmSSA_FuncOp : WasmSSA_Op<"func", [
WasmSSA_FuncTypeAttr: $functionType,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs,
- DefaultValuedAttr<StrAttr, "\"nested\"">:$sym_visibility);
+ UnitAttr: $exported);
let regions = (region AnyRegion: $body);
let extraClassDeclaration = [{
@@ -162,6 +207,12 @@ def WasmSSA_FuncOp : WasmSSA_Op<"func", [
/// Returns the result types of this function.
ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); }
+
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return getExported() ?
+ ::mlir::SymbolTable::Visibility::Public :
+ ::mlir::SymbolTable::Visibility::Nested;
+ };
}];
let builders = [
@@ -207,8 +258,7 @@ def WasmSSA_FuncImportOp : WasmSSA_Op<"import_func", [
StrAttr: $importName,
WasmSSA_FuncTypeAttr: $type,
OptionalAttr<DictArrayAttr>:$arg_attrs,
- OptionalAttr<DictArrayAttr>:$res_attrs,
- OptionalAttr<StrAttr>:$sym_visibility);
+ OptionalAttr<DictArrayAttr>:$res_attrs);
let extraClassDeclaration = [{
bool isDeclaration() const { return true; }
@@ -221,6 +271,10 @@ def WasmSSA_FuncImportOp : WasmSSA_Op<"import_func", [
::llvm::ArrayRef<Type> getResultTypes() {
return getType().getResults();
}
+
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return ::mlir::SymbolTable::Visibility::Nested;
+ };
}];
let builders = [
OpBuilder<(ins "StringRef":$symbol,
@@ -238,30 +292,41 @@ def WasmSSA_GlobalOp : WasmSSA_Op<"global", [
let arguments = (ins SymbolNameAttr: $sym_name,
WasmSSA_ValTypeAttr: $type,
UnitAttr: $isMutable,
- OptionalAttr<StrAttr>:$sym_visibility);
+ UnitAttr: $exported);
let description = [{
WebAssembly global variable.
Body contains the initialization instructions for the variable value.
The body must contain only instructions considered `const` in a webassembly context,
such as `wasmssa.const` or `global.get`.
+ By default, `wasmssa.global` has nested visibility. Globals exported by the module
+ are marked with the `exported` attribute, which gives them public visibility.
+
Example:
```mlir
- // Define a global_var, a mutable i32 global variable equal to 10.
- wasmssa.global @global_var i32 mutable nested : {
+ // Define module_global_var, an internal mutable i32 global variable equal to 10.
+ wasmssa.global @module_global_var i32 mutable : {
%[[VAL_0:.*]] = wasmssa.const 10 : i32
wasmssa.return %[[VAL_0]] : i32
}
+
+ // Define global_var, an exported constant i32 global variable equal to 42.
+ wasmssa.global exported @global_var i32 : {
+ %[[VAL_0:.*]] = wasmssa.const 42 : i32
+ wasmssa.return %[[VAL_0]] : i32
+ }
```
}];
let regions = (region AnyRegion: $initializer);
- let builders = [
- OpBuilder<(ins "StringRef":$symbol,
- "Type": $type,
- "bool": $isMutable)>
- ];
+ let extraClassDeclaration = [{
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return getExported() ?
+ ::mlir::SymbolTable::Visibility::Public :
+ ::mlir::SymbolTable::Visibility::Nested;
+ };
+ }];
let hasCustomAssemblyFormat = 1;
}
@@ -283,18 +348,14 @@ def WasmSSA_GlobalImportOp : WasmSSA_Op<"import_global", [
StrAttr: $moduleName,
StrAttr: $importName,
WasmSSA_ValTypeAttr: $type,
- UnitAttr: $isMutable,
- OptionalAttr<StrAttr>:$sym_visibility);
+ UnitAttr: $isMutable);
let extraClassDeclaration = [{
bool isDeclaration() const { return true; }
+
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return ::mlir::SymbolTable::Visibility::Nested;
+ };
}];
- let builders = [
- OpBuilder<(ins "StringRef":$symbol,
- "StringRef":$moduleName,
- "StringRef":$importName,
- "Type": $type,
- "bool": $isMutable)>
- ];
let hasCustomAssemblyFormat = 1;
}
@@ -442,23 +503,33 @@ def WasmSSA_MemOp : WasmSSA_Op<"memory", [Symbol]> {
Define a memory to be used by the program.
Multiple memories can be defined in the same module.
+ By default, `wasmssa.memory` has nested visibility. Memories exported by
+ the module are marked with the `exported` attribute, which gives them public
+ visibility.
+
Example:
```mlir
- // Define the `mem_0` memory with defined bounds of 0 -> 65536
+ // Define the `mem_0` (internal) memory with defined size bounds of [0:65536]
wasmssa.memory @mem_0 !wasmssa<limit[0:65536]>
+
+ // Define the `mem_1` exported memory with a minimum size of 512
+ wasmssa.memory exported @mem_1 !wasmssa<limit[512:]>
```
}];
let arguments = (ins SymbolNameAttr: $sym_name,
WasmSSA_LimitTypeAttr: $limits,
- OptionalAttr<StrAttr>:$sym_visibility);
- let builders = [
- OpBuilder<(ins
- "::llvm::StringRef":$symbol,
- "wasmssa::LimitType":$limit)>
- ];
+ UnitAttr: $exported);
- let assemblyFormat = "$sym_name custom<WasmVisibility>($sym_visibility) $limits attr-dict";
+ let extraClassDeclaration = [{
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return getExported() ?
+ ::mlir::SymbolTable::Visibility::Public :
+ ::mlir::SymbolTable::Visibility::Nested;
+ };
+ }];
+
+ let assemblyFormat = "(`exported` $exported^)? $sym_name $limits attr-dict";
}
def WasmSSA_MemImportOp : WasmSSA_Op<"import_mem", [Symbol, ImportOpInterface]> {
@@ -476,16 +547,13 @@ def WasmSSA_MemImportOp : WasmSSA_Op<"import_mem", [Symbol, ImportOpInterface]>
let arguments = (ins SymbolNameAttr: $sym_name,
StrAttr: $moduleName,
StrAttr: $importName,
- WasmSSA_LimitTypeAttr: $limits,
- OptionalAttr<StrAttr>:$sym_visibility);
+ WasmSSA_LimitTypeAttr: $limits);
let extraClassDeclaration = [{
- bool isDeclaration() const { return true; }
+ bool isDeclaration() const { return true; }
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return ::mlir::SymbolTable::Visibility::Nested;
+ };
}];
- let builders = [OpBuilder<(ins
- "::llvm::StringRef":$symbol,
- "::llvm::StringRef":$moduleName,
- "::llvm::StringRef":$importName,
- "wasmssa::LimitType":$limits)>];
let assemblyFormat = "$importName `from` $moduleName `as` $sym_name attr-dict";
}
@@ -493,11 +561,15 @@ def WasmSSA_TableOp : WasmSSA_Op<"table", [Symbol]> {
let summary= "WebAssembly table value";
let arguments = (ins SymbolNameAttr: $sym_name,
WasmSSA_TableTypeAttr: $type,
- OptionalAttr<StrAttr>:$sym_visibility);
- let builders = [OpBuilder<(ins
- "::llvm::StringRef":$symbol,
- "wasmssa::TableType":$type)>];
- let assemblyFormat = "$sym_name custom<WasmVisibility>($sym_visibility) $type attr-dict";
+ UnitAttr: $exported);
+ let extraClassDeclaration = [{
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return getExported() ?
+ ::mlir::SymbolTable::Visibility::Public :
+ ::mlir::SymbolTable::Visibility::Nested;
+ };
+ }];
+ let assemblyFormat = "(`exported` $exported^)? $sym_name $type attr-dict";
}
def WasmSSA_TableImportOp : WasmSSA_Op<"import_table", [Symbol, ImportOpInterface]> {
@@ -515,17 +587,14 @@ def WasmSSA_TableImportOp : WasmSSA_Op<"import_table", [Symbol, ImportOpInterfac
let arguments = (ins SymbolNameAttr: $sym_name,
StrAttr: $moduleName,
StrAttr: $importName,
- WasmSSA_TableTypeAttr: $type,
- OptionalAttr<StrAttr>:$sym_visibility);
+ WasmSSA_TableTypeAttr: $type);
let extraClassDeclaration = [{
bool isDeclaration() const { return true; }
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return ::mlir::SymbolTable::Visibility::Nested;
+ };
}];
let assemblyFormat = "$importName `from` $moduleName `as` $sym_name attr-dict";
- let builders = [OpBuilder<(ins
- "::llvm::StringRef":$symbol,
- "::llvm::StringRef":$moduleName,
- "::llvm::StringRef":$importName,
- "wasmssa::TableType":$type)>];
}
def WasmSSA_ReturnOp : WasmSSA_Op<"return", [Terminator]> {
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 5695d5d..19a5231 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -712,10 +712,14 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> {
return getAttrs().contains(name);
}
- ArrayAttr getStrides() {
+ ArrayAttr getStrideAttr() {
return getAttrs().getAs<ArrayAttr>("stride");
}
+ ArrayAttr getBlockAttr() {
+ return getAttrs().getAs<ArrayAttr>("block");
+ }
+
}];
}
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 73f9061..426377f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1298,14 +1298,14 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure,
}
def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
- AllElementTypesMatch<["mem_desc", "res"]>,
- AllRanksMatch<["mem_desc", "res"]>]> {
+ AllElementTypesMatch<["mem_desc", "res"]>]> {
let arguments = (ins XeGPU_MemDesc:$mem_desc,
Variadic<Index>: $offsets,
DenseI64ArrayAttr: $const_offsets,
+ OptionalAttr<UnitAttr>:$subgroup_block_io,
OptionalAttr<DistributeLayoutAttr>:$layout
);
- let results = (outs XeGPU_ValueType:$res);
+ let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$res);
let assemblyFormat = [{
$mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `` `:` type(operands) `->` type(results)
@@ -1319,6 +1319,9 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
Arguments:
- `mem_desc`: the memory descriptor identifying the SLM region.
- `offsets`: the coordinates within the matrix to read from.
+ - `subgroup_block_io`: [optional] An attribute indicating that the operation can be
+ lowered to a subgroup block load. When this attribute is present,
+ the offsets are subgroup-uniform across all lanes.
- `layout`: [optional] An attribute for guiding distributions among
subgroups and/or work-items. It currently can accept either
LayoutAttr or SliceAttr.
@@ -1336,7 +1339,10 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
}
ArrayRef<int64_t> getDataShape() {
- return getRes().getType().getShape();
+ auto resTy = getRes().getType();
+ if (auto vecTy = llvm::dyn_cast<VectorType>(resTy))
+ return vecTy.getShape();
+ return {};
}
}];
@@ -1344,13 +1350,13 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
}
def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
- AllElementTypesMatch<["mem_desc", "data"]>,
- AllRanksMatch<["mem_desc", "data"]>]> {
+ AllElementTypesMatch<["mem_desc", "data"]>]> {
let arguments = (ins
- XeGPU_ValueType:$data,
+ AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$data,
XeGPU_MemDesc:$mem_desc,
Variadic<Index>: $offsets,
DenseI64ArrayAttr: $const_offsets,
+ OptionalAttr<UnitAttr>:$subgroup_block_io,
OptionalAttr<DistributeLayoutAttr>:$layout
);
let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
@@ -1364,6 +1370,9 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
- `mem_desc`: the memory descriptor specifying the SLM region.
- `offsets`: the coordinates within the matrix where the data will be written.
- `data`: the values to be stored in the matrix.
+ - `subgroup_block_io`: [optional] An attribute indicating that the operation can be
+ lowered to a subgroup block store. When this attribute is present,
+ the offsets are subgroup-uniform across all lanes.
- `layout`: [optional] An attribute for guiding distributions among
subgroups and/or work-items. It currently can accept either
LayoutAttr or SliceAttr.
@@ -1378,7 +1387,10 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
}
ArrayRef<int64_t> getDataShape() {
- return getData().getType().getShape();
+ auto dataTy = getData().getType();
+ if (auto vecTy = llvm::dyn_cast<VectorType>(dataTy))
+ return vecTy.getShape();
+ return {};
}
}];
@@ -1386,41 +1398,4 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
let hasVerifier = 1;
}
-def XeGPU_MemDescSubviewOp: XeGPU_Op<"mem_desc_subview",
- [Pure, ViewLikeOpInterface, AllElementTypesMatch<["src", "res"]>]> {
- let description = [{
- Creates a subview of a memory descriptor. The resulting memory descriptor can have
- a lower rank than the source; in this case, the result dimensions correspond to the
- higher-order dimensions of the source memory descriptor.
-
- Arguments:
- - `src` : a memory descriptor.
- - `offsets` : the coordinates within the matrix the subview will be created from.
-
- Results:
- - `res` : a memory descriptor with smaller size.
-
- }];
- let arguments = (ins XeGPU_MemDesc:$src,
- Variadic<Index>:$offsets,
- DenseI64ArrayAttr:$const_offsets);
- let results = (outs XeGPU_MemDesc:$res);
- let assemblyFormat = [{$src `` custom<DynamicIndexList>($offsets, $const_offsets) prop-dict
- attr-dict `` `:` qualified(type($src)) `->` qualified(type($res))}];
- let builders = [
- OpBuilder<(ins "Type": $res, "Value":$src, "llvm::ArrayRef<OpFoldResult>": $offsets)>
- ];
-
- let extraClassDeclaration = [{
- mlir::Value getViewSource() { return getSrc(); }
-
- SmallVector<OpFoldResult> getMixedOffsets() {
- return getMixedValues(getConstOffsets(), getOffsets(), getContext());
- }
- }];
-
- let hasVerifier = 1;
-}
-
-
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 84902b2..b1196fb 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -237,12 +237,11 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout());
}
- ArrayAttr getStrides() {
+ ArrayAttr getStrideAttr() {
auto layout = getMemLayout();
if (layout && layout.hasAttr("stride")) {
- return layout.getStrides();
+ return layout.getStrideAttr();
}
-
// derive and return default strides
SmallVector<int64_t> defaultStrides;
llvm::append_range(defaultStrides, getShape().drop_front());
@@ -250,6 +249,63 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
Builder builder(getContext());
return builder.getI64ArrayAttr(defaultStrides);
}
+
+ ArrayAttr getBlockAttr() {
+ auto layout = getMemLayout();
+ if (layout && layout.hasAttr("block")) {
+ return layout.getBlockAttr();
+ }
+ Builder builder(getContext());
+ return builder.getI64ArrayAttr({});
+ }
+
+ /// Heuristic to determine if the MemDesc uses column-major layout,
+ /// based on the rank and the value of the first stride dimension.
+ bool isColMajor() {
+ auto dim0 = dyn_cast<IntegerAttr>(getStrideAttr()[0]);
+ return getRank() == 2 && dim0.getInt() == 1;
+ }
+
+ // Get the blocking shape for a MemDescType, which is represented
+ // as an attribute in MemDescType. By default it is the shape
+ // of the MemDescType.
+ SmallVector<int64_t> getBlockShape() {
+ SmallVector<int64_t> size(getShape());
+ ArrayAttr blockAttr = getBlockAttr();
+ if (!blockAttr.empty()) {
+ size.clear();
+ for (auto attr : blockAttr.getValue()) {
+ size.push_back(cast<IntegerAttr>(attr).getInt());
+ }
+ }
+ return size;
+ }
+
+ // Get the strides as a vector of integers.
+ // If the layout contains a block attribute, the returned strides are blocked strides.
+ //
+ // The blocking is applied to the base matrix shape derived from the
+ // memory descriptor's stride information. If the matrix described by
+ // the memory descriptor is not contiguous, it is assumed that the base
+ // matrix is contiguous and follows the same memory layout.
+ //
+ // It first computes the original matrix shape using the stride info,
+ // then computes the number of blocks in each dimension of the original shape,
+ // then computes the outer block shape and stride,
+ // and finally combines the inner and outer block shape and stride.
+ // e.g. for `mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>`
+ // its memory layout tuple is ([2,32,16,8],[128,256,1,16])
+ // for `mem_desc<256x32xf16, @block=[8, 16]>` with default @stride[32, 1]
+ // its memory layout tuple is ([32,2,8,16],[256,128,16,1])
+ SmallVector<int64_t> getStrideShape();
+
+ /// Generates instructions to compute the linearized offset.
+ /// If the memory descriptor is blocked, the linearized offset is computed based on the blocked layout.
+ /// The strides of the memory descriptor are always taken into account, blocked or not.
+ Value getLinearOffsets(OpBuilder &builder,
+ Location loc, ArrayRef<OpFoldResult> offsets);
+
+
}];
let hasCustomAssemblyFormat = true;
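
The blocked layout computation documented on `getStrideShape` above can be checked with a small standalone sketch. This is an illustrative 2-D reimplementation under the comment's assumptions (block evenly divides the shape, row- or column-major base matrix), not the dialect's actual code; it reproduces both worked examples from the comment.

```cpp
// Illustrative 2-D sketch of the blocked (shape, strides) tuple described in
// the getStrideShape comment; not the actual XeGPU implementation.
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

using Vec = std::vector<int64_t>;

// Given a 2-D shape, a block shape, and the matrix strides, return the
// combined {outerCounts ++ block, outerStrides ++ innerStrides} tuple.
std::pair<Vec, Vec> blockedLayout(const Vec &shape, const Vec &block,
                                  const Vec &strides) {
  bool colMajor = strides[0] == 1; // heuristic mirroring isColMajor()
  Vec counts = {shape[0] / block[0], shape[1] / block[1]};
  int64_t blockSize = block[0] * block[1];

  // Strides within one block and across blocks follow the same major order
  // as the base matrix.
  Vec inner = colMajor ? Vec{1, block[0]} : Vec{block[1], 1};
  Vec outer = colMajor ? Vec{blockSize, blockSize * counts[0]}
                       : Vec{blockSize * counts[1], blockSize};

  Vec combinedShape = {counts[0], counts[1], block[0], block[1]};
  Vec combinedStrides = {outer[0], outer[1], inner[0], inner[1]};
  return {combinedShape, combinedStrides};
}

int main() {
  // mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>
  auto [s1, d1] = blockedLayout({32, 256}, {16, 8}, {1, 32});
  assert((s1 == Vec{2, 32, 16, 8}) && (d1 == Vec{128, 256, 1, 16}));

  // mem_desc<256x32xf16, @block=[8, 16]> with default @strides=[32, 1]
  auto [s2, d2] = blockedLayout({256, 32}, {8, 16}, {32, 1});
  assert((s2 == Vec{32, 2, 8, 16}) && (d2 == Vec{256, 128, 16, 1}));
  return 0;
}
```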
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 6b4e3dd..8427ba5 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -623,6 +623,14 @@ class VectorOfLengthAndType<list<int> allowedLengths,
VectorOfNonZeroRankOf<allowedTypes>.summary # VectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;
+class FixedVectorOfShapeAndType<list<int> shape, Type elType>: ShapedContainerType<
+ [elType],
+ And<[IsVectorOfShape<shape>, IsFixedVectorOfAnyRankTypePred]>,
+ "vector<" # !interleave(shape, "x") # "x" # elType # ">",
+ "::mlir::VectorType">,
+ BuildableType<"::mlir::VectorType::get({" # !interleave(shape, " ,") # "} , " # elType.builderCall # " );">;
+
+
// Any fixed-length vector where the number of elements is from the given
// `allowedLengths` list and the type is from the given `allowedTypes` list
class FixedVectorOfLengthAndType<list<int> allowedLengths,
diff --git a/mlir/include/mlir/IR/Remarks.h b/mlir/include/mlir/IR/Remarks.h
index 20e84ec..9877926 100644
--- a/mlir/include/mlir/IR/Remarks.h
+++ b/mlir/include/mlir/IR/Remarks.h
@@ -18,7 +18,6 @@
#include "llvm/Remarks/Remark.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/Regex.h"
-#include <optional>
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
@@ -144,7 +143,7 @@ public:
llvm::StringRef getCategoryName() const { return categoryName; }
- llvm::StringRef getFullCategoryName() const {
+ llvm::StringRef getCombinedCategoryName() const {
if (categoryName.empty() && subCategoryName.empty())
return {};
if (subCategoryName.empty())
@@ -318,7 +317,7 @@ private:
};
//===----------------------------------------------------------------------===//
-// MLIR Remark Streamer
+// Pluggable Remark Utilities
//===----------------------------------------------------------------------===//
/// Base class for MLIR remark streamers that is used to stream
@@ -338,6 +337,26 @@ public:
virtual void finalize() {} // optional
};
+using ReportFn = llvm::unique_function<void(const Remark &)>;
+
+/// Base class for MLIR remark emitting policies, which decide how optimization
+/// remarks are forwarded to the underlying remark streamer. Derived classes
+/// should implement the `reportRemark` method to provide the actual emitting
+/// implementation.
+class RemarkEmittingPolicyBase {
+protected:
+ ReportFn reportImpl;
+
+public:
+ RemarkEmittingPolicyBase() = default;
+ virtual ~RemarkEmittingPolicyBase() = default;
+
+ void initialize(ReportFn fn) { reportImpl = std::move(fn); }
+
+ virtual void reportRemark(const Remark &remark) = 0;
+ virtual void finalize() = 0;
+};
+
//===----------------------------------------------------------------------===//
// Remark Engine (MLIR Context will own this class)
//===----------------------------------------------------------------------===//
@@ -355,6 +374,8 @@ private:
std::optional<llvm::Regex> failedFilter;
/// The MLIR remark streamer that will be used to emit the remarks.
std::unique_ptr<MLIRRemarkStreamerBase> remarkStreamer;
+ /// The MLIR remark policy that will be used to emit the remarks.
+ std::unique_ptr<RemarkEmittingPolicyBase> remarkEmittingPolicy;
/// When is enabled, engine also prints remarks as mlir::emitRemarks.
bool printAsEmitRemarks = false;
@@ -392,6 +413,8 @@ private:
InFlightRemark emitIfEnabled(Location loc, RemarkOpts opts,
bool (RemarkEngine::*isEnabled)(StringRef)
const);
+ /// Report a remark.
+ void reportImpl(const Remark &remark);
public:
/// Default constructor is deleted, use the other constructor.
@@ -407,8 +430,15 @@ public:
~RemarkEngine();
/// Setup the remark engine with the given output path and format.
- LogicalResult initialize(std::unique_ptr<MLIRRemarkStreamerBase> streamer,
- std::string *errMsg);
+ LogicalResult
+ initialize(std::unique_ptr<MLIRRemarkStreamerBase> streamer,
+ std::unique_ptr<RemarkEmittingPolicyBase> remarkEmittingPolicy,
+ std::string *errMsg);
+
+ /// Get the remark emitting policy.
+ RemarkEmittingPolicyBase *getRemarkEmittingPolicy() const {
+ return remarkEmittingPolicy.get();
+ }
/// Report a remark.
void report(const Remark &&remark);
@@ -446,6 +476,46 @@ inline InFlightRemark withEngine(Fn fn, Location loc, Args &&...args) {
namespace mlir::remark {
+//===----------------------------------------------------------------------===//
+// Remark Emitting Policies
+//===----------------------------------------------------------------------===//
+
+/// Policy that emits all remarks.
+class RemarkEmittingPolicyAll : public detail::RemarkEmittingPolicyBase {
+public:
+ RemarkEmittingPolicyAll();
+
+ void reportRemark(const detail::Remark &remark) override {
+ assert(reportImpl && "reportImpl is not set");
+ reportImpl(remark);
+ }
+ void finalize() override {}
+};
+
+/// Policy that defers reporting and, at finalization, emits only the last
+/// remark recorded for each (location, name, category) key.
+class RemarkEmittingPolicyFinal : public detail::RemarkEmittingPolicyBase {
+private:
+ /// Remarks postponed until finalization. The user can intercept them for
+ /// custom processing via a registered callback; otherwise they are reported
+ /// on engine destruction.
+ llvm::DenseSet<detail::Remark> postponedRemarks;
+
+public:
+ RemarkEmittingPolicyFinal();
+
+ void reportRemark(const detail::Remark &remark) override {
+ postponedRemarks.erase(remark);
+ postponedRemarks.insert(remark);
+ }
+
+ void finalize() override {
+ assert(reportImpl && "reportImpl is not set");
+ for (auto &remark : postponedRemarks) {
+ if (reportImpl)
+ reportImpl(remark);
+ }
+ }
+};
+
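
As a hedged illustration of the policy hook: a user-defined policy only needs to derive from `RemarkEmittingPolicyBase` and implement `reportRemark`/`finalize`. The prefix filter below is made up for the example.

```cpp
// Sketch of a custom remark emitting policy (illustrative only): forwards
// remarks whose name starts with a given prefix and drops the rest.
#include "mlir/IR/Remarks.h" // assumed: the header shown in this diff
#include <string>
#include <utility>

class RemarkEmittingPolicyPrefixFilter
    : public mlir::remark::detail::RemarkEmittingPolicyBase {
  std::string prefix;

public:
  explicit RemarkEmittingPolicyPrefixFilter(std::string p)
      : prefix(std::move(p)) {}

  void reportRemark(const mlir::remark::detail::Remark &remark) override {
    if (reportImpl && remark.getRemarkName().starts_with(prefix))
      reportImpl(remark);
  }
  void finalize() override {}
};
```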
/// Create a Reason with llvm::formatv formatting.
template <class... Ts>
inline detail::LazyTextBuild reason(const char *fmt, Ts &&...ts) {
@@ -505,16 +575,72 @@ inline detail::InFlightRemark analysis(Location loc, RemarkOpts opts) {
/// Setup remarks for the context. This function will enable the remark engine
/// and set the streamer to be used for optimization remarks. The remark
-/// categories are used to filter the remarks that will be emitted by the remark
-/// engine. If a category is not specified, it will not be emitted. If
+/// categories are used to filter the remarks that will be emitted by the
+/// remark engine. If a category is not specified, it will not be emitted. If
/// `printAsEmitRemarks` is true, the remarks will be printed as
/// mlir::emitRemarks. 'streamer' must inherit from MLIRRemarkStreamerBase and
/// will be used to stream the remarks.
LogicalResult enableOptimizationRemarks(
MLIRContext &ctx,
std::unique_ptr<remark::detail::MLIRRemarkStreamerBase> streamer,
+ std::unique_ptr<remark::detail::RemarkEmittingPolicyBase>
+ remarkEmittingPolicy,
const remark::RemarkCategories &cats, bool printAsEmitRemarks = false);
} // namespace mlir::remark
+// DenseMapInfo specialization for Remark
+namespace llvm {
+template <>
+struct DenseMapInfo<mlir::remark::detail::Remark> {
+ static constexpr StringRef kEmptyKey = "<EMPTY_KEY>";
+ static constexpr StringRef kTombstoneKey = "<TOMBSTONE_KEY>";
+
+ /// Helper to provide a static dummy context for sentinel keys.
+ static mlir::MLIRContext *getStaticDummyContext() {
+ static mlir::MLIRContext dummyContext;
+ return &dummyContext;
+ }
+
+ /// Create an empty remark
+ static inline mlir::remark::detail::Remark getEmptyKey() {
+ return mlir::remark::detail::Remark(
+ mlir::remark::RemarkKind::RemarkUnknown, mlir::DiagnosticSeverity::Note,
+ mlir::UnknownLoc::get(getStaticDummyContext()),
+ mlir::remark::RemarkOpts::name(kEmptyKey));
+ }
+
+ /// Create a dead remark
+ static inline mlir::remark::detail::Remark getTombstoneKey() {
+ return mlir::remark::detail::Remark(
+ mlir::remark::RemarkKind::RemarkUnknown, mlir::DiagnosticSeverity::Note,
+ mlir::UnknownLoc::get(getStaticDummyContext()),
+ mlir::remark::RemarkOpts::name(kTombstoneKey));
+ }
+
+ /// Compute the hash value of the remark
+ static unsigned getHashValue(const mlir::remark::detail::Remark &remark) {
+ return llvm::hash_combine(
+ remark.getLocation().getAsOpaquePointer(),
+ llvm::hash_value(remark.getRemarkName()),
+ llvm::hash_value(remark.getCombinedCategoryName()));
+ }
+
+ static bool isEqual(const mlir::remark::detail::Remark &lhs,
+ const mlir::remark::detail::Remark &rhs) {
+ // Check for empty/tombstone keys first
+ if (lhs.getRemarkName() == kEmptyKey ||
+ lhs.getRemarkName() == kTombstoneKey ||
+ rhs.getRemarkName() == kEmptyKey ||
+ rhs.getRemarkName() == kTombstoneKey) {
+ return lhs.getRemarkName() == rhs.getRemarkName();
+ }
+
+ // For regular remarks, compare key identifying fields
+ return lhs.getLocation() == rhs.getLocation() &&
+ lhs.getRemarkName() == rhs.getRemarkName() &&
+ lhs.getCombinedCategoryName() == rhs.getCombinedCategoryName();
+ }
+};
+} // namespace llvm
#endif // MLIR_IR_REMARKS_H
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index a5feb59..72ed046 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_interface(DestinationStyleOpInterface)
add_mlir_interface(FunctionInterfaces)
add_mlir_interface(IndexingMapOpInterface)
add_mlir_interface(InferIntRangeInterface)
+add_mlir_interface(InferStridedMetadataInterface)
add_mlir_interface(InferTypeOpInterface)
add_mlir_interface(LoopLikeInterface)
add_mlir_interface(MemOpInterfaces)
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
index 0e107e8..a6de3d1 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -117,7 +117,8 @@ public:
IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {}
/// Create an integer value range lattice value.
- IntegerValueRange(std::optional<ConstantIntRanges> value = std::nullopt)
+ explicit IntegerValueRange(
+ std::optional<ConstantIntRanges> value = std::nullopt)
: value(std::move(value)) {}
/// Whether the range is uninitialized. This happens when the state hasn't
@@ -167,6 +168,15 @@ using SetIntRangeFn =
using SetIntLatticeFn =
llvm::function_ref<void(Value, const IntegerValueRange &)>;
+/// Helper callback type to get the integer range of a value.
+using GetIntRangeFn = function_ref<IntegerValueRange(Value)>;
+
+/// Helper function to collect the integer range values of an array of op fold
+/// results.
+SmallVector<IntegerValueRange> getIntValueRanges(ArrayRef<OpFoldResult> values,
+ GetIntRangeFn getIntRange,
+ int32_t indexBitwidth);
+
class InferIntRangeInterface;
namespace intrange::detail {
diff --git a/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h
new file mode 100644
index 0000000..0c572e0
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h
@@ -0,0 +1,145 @@
+//===- InferStridedMetadataInterface.h - Strided Metadata Inference -C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions of the strided metadata inference interface
+// defined in `InferStridedMetadataInterface.td`
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H
+#define MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H
+
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+
+namespace mlir {
+/// A class that represents the strided metadata range information, including
+/// offsets, sizes, and strides as integer ranges.
+class StridedMetadataRange {
+public:
+ /// Default constructor creates uninitialized ranges.
+ StridedMetadataRange() = default;
+
+ /// Returns a ranked strided metadata range.
+ static StridedMetadataRange
+ getRanked(SmallVectorImpl<ConstantIntRanges> &&offsets,
+ SmallVectorImpl<ConstantIntRanges> &&sizes,
+ SmallVectorImpl<ConstantIntRanges> &&strides) {
+ return StridedMetadataRange(std::move(offsets), std::move(sizes),
+ std::move(strides));
+ }
+
+ /// Returns a strided metadata range with maximum ranges.
+ static StridedMetadataRange getMaxRanges(int32_t indexBitwidth,
+ int32_t offsetsRank,
+ int32_t sizeRank,
+ int32_t stridedRank) {
+ return StridedMetadataRange(
+ SmallVector<ConstantIntRanges>(
+ offsetsRank, ConstantIntRanges::maxRange(indexBitwidth)),
+ SmallVector<ConstantIntRanges>(
+ sizeRank, ConstantIntRanges::maxRange(indexBitwidth)),
+ SmallVector<ConstantIntRanges>(
+ stridedRank, ConstantIntRanges::maxRange(indexBitwidth)));
+ }
+
+ static StridedMetadataRange getMaxRanges(int32_t indexBitwidth,
+ int32_t rank) {
+ return getMaxRanges(indexBitwidth, 1, rank, rank);
+ }
+
+ /// Returns whether the metadata is uninitialized.
+ bool isUninitialized() const { return !offsets.has_value(); }
+
+ /// Get the offsets range.
+ ArrayRef<ConstantIntRanges> getOffsets() const {
+ return offsets ? *offsets : ArrayRef<ConstantIntRanges>();
+ }
+ MutableArrayRef<ConstantIntRanges> getOffsets() {
+ return offsets ? *offsets : MutableArrayRef<ConstantIntRanges>();
+ }
+
+ /// Get the sizes ranges.
+ ArrayRef<ConstantIntRanges> getSizes() const { return sizes; }
+ MutableArrayRef<ConstantIntRanges> getSizes() { return sizes; }
+
+ /// Get the strides ranges.
+ ArrayRef<ConstantIntRanges> getStrides() const { return strides; }
+ MutableArrayRef<ConstantIntRanges> getStrides() { return strides; }
+
+ /// Compare two strided metadata ranges.
+ bool operator==(const StridedMetadataRange &other) const {
+ return offsets == other.offsets && sizes == other.sizes &&
+ strides == other.strides;
+ }
+
+ /// Print the strided metadata range.
+ void print(raw_ostream &os) const;
+
+ /// Join two strided metadata ranges, by taking the element-wise union of the
+ /// metadata.
+ static StridedMetadataRange join(const StridedMetadataRange &lhs,
+ const StridedMetadataRange &rhs) {
+ if (lhs.isUninitialized())
+ return rhs;
+ if (rhs.isUninitialized())
+ return lhs;
+
+ // Helper function to compute the range union of constant ranges.
+ auto rangeUnion =
+ +[](const std::tuple<ConstantIntRanges, ConstantIntRanges> &lhsRhs)
+ -> ConstantIntRanges {
+ return std::get<0>(lhsRhs).rangeUnion(std::get<1>(lhsRhs));
+ };
+
+ // Get the elementwise range union. Note that `zip_equal` will assert if
+ // sizes are not equal.
+ SmallVector<ConstantIntRanges> offsets = llvm::map_to_vector(
+ llvm::zip_equal(*lhs.offsets, *rhs.offsets), rangeUnion);
+ SmallVector<ConstantIntRanges> sizes =
+ llvm::map_to_vector(llvm::zip_equal(lhs.sizes, rhs.sizes), rangeUnion);
+ SmallVector<ConstantIntRanges> strides = llvm::map_to_vector(
+ llvm::zip_equal(lhs.strides, rhs.strides), rangeUnion);
+
+ // Return the joined metadata.
+ return StridedMetadataRange(std::move(offsets), std::move(sizes),
+ std::move(strides));
+ }
+
+private:
+ /// Create a strided metadata range with the given offsets, sizes, and strides.
+ StridedMetadataRange(SmallVectorImpl<ConstantIntRanges> &&offsets,
+ SmallVectorImpl<ConstantIntRanges> &&sizes,
+ SmallVectorImpl<ConstantIntRanges> &&strides)
+ : offsets(std::move(offsets)), sizes(std::move(sizes)),
+ strides(std::move(strides)) {}
+
+ /// The offsets range.
+ std::optional<SmallVector<ConstantIntRanges>> offsets;
+
+ /// The sizes ranges.
+ SmallVector<ConstantIntRanges> sizes;
+
+ /// The strides ranges.
+ SmallVector<ConstantIntRanges> strides;
+};
+
+/// Print the strided metadata to `os`.
+inline raw_ostream &operator<<(raw_ostream &os,
+ const StridedMetadataRange &range) {
+ range.print(os);
+ return os;
+}
+
+/// Callback function type for setting the strided metadata of a value.
+using SetStridedMetadataRangeFn =
+ function_ref<void(Value, const StridedMetadataRange &)>;
+} // end namespace mlir
+
+#include "mlir/Interfaces/InferStridedMetadataInterface.h.inc"
+
+#endif // MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H
diff --git a/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td
new file mode 100644
index 0000000..ee5b094
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td
@@ -0,0 +1,45 @@
+//===- InferStridedMetadataInterface.td - Strided MD Inference ----------*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the interface for strided metadata range analysis
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE
+#define MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def InferStridedMetadataOpInterface :
+ OpInterface<"InferStridedMetadataOpInterface"> {
+ let description = [{
+ Allows operations to participate in strided metadata analysis by providing
+ methods that specify bounds on the offsets, sizes, and strides of their
+ result(s), given bounds on their input(s) when known.
+ }];
+ let cppNamespace = "::mlir";
+
+ let methods = [
+ InterfaceMethod<[{
+ Infer the strided metadata bounds on the results of this op given
+ the bounds on its operands.
+ For each result value or block argument of interest, the method should
+ call `setMetadata` with that `Value` as an argument.
+ The `operands` parameter contains the strided metadata ranges for all the
+ operands of the operation in order.
+ The `getIntRange` callback is provided for obtaining the int-range
+ analysis result for a given value.
+ }],
+ "void", "inferStridedMetadataRanges",
+ (ins "::llvm::ArrayRef<::mlir::StridedMetadataRange>":$operands,
+ "::mlir::GetIntRangeFn":$getIntRange,
+ "::mlir::SetStridedMetadataRangeFn":$setMetadata,
+ "int32_t":$indexBitwidth)>
+ ];
+}
+#endif // MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE
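
To make the interface contract concrete, here is a hedged C++ sketch of an op implementing the method; `MyPassThroughViewOp` is hypothetical, and the only claim is that a pure pass-through view may forward its operand's metadata unchanged.

```cpp
// Hypothetical op implementation (sketch): a view-like op whose single result
// has exactly the metadata bounds of its single memref operand.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/InferStridedMetadataInterface.h"

void MyPassThroughViewOp::inferStridedMetadataRanges(
    llvm::ArrayRef<mlir::StridedMetadataRange> operands,
    mlir::GetIntRangeFn getIntRange,
    mlir::SetStridedMetadataRangeFn setMetadata, int32_t indexBitwidth) {
  // A real op would typically use getIntRange to bound its dynamic offset,
  // size, or stride operands; this pass-through op does not need it.
  const mlir::StridedMetadataRange &src = operands[0];
  if (src.isUninitialized()) {
    // Fall back to maximally conservative ranges early in the fixpoint.
    auto resultType = llvm::cast<mlir::MemRefType>(getResult().getType());
    setMetadata(getResult(),
                mlir::StridedMetadataRange::getMaxRanges(
                    indexBitwidth, static_cast<int32_t>(resultType.getRank())));
    return;
  }
  setMetadata(getResult(), src);
}
```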
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index db9c37f..c1c2269 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -230,6 +230,22 @@ LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name,
ArrayRef<int64_t> attr,
ValueRange values);
+namespace OpTrait {
+/// This trait indicates that pointer-like objects (such as memrefs) returned
+/// from this operation will never alias with each other. This provides a
+/// guarantee to optimization passes that accesses through different results
+/// of this operation can be safely reordered, as they will never reference
+/// overlapping memory locations.
+///
+/// Operations with this trait take multiple pointer-like operands
+/// and return the same operands with additional non-aliasing guarantees.
+/// If accesses to the results of this operation alias at runtime, the
+/// behavior of such accesses is undefined.
+template <typename ConcreteType>
+class DistinctObjectsTrait
+ : public TraitBase<ConcreteType, DistinctObjectsTrait> {};
+} // namespace OpTrait
+
} // namespace mlir
#endif // MLIR_INTERFACES_VIEWLIKEINTERFACE_H_
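
A short usage sketch (not from this patch): a transformation can query the trait through the standard `hasTrait` mechanism before assuming two results do not alias.

```cpp
// Illustrative check: distinct results of the same DistinctObjectsTrait op
// are guaranteed never to alias, so accesses through them may be reordered.
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ViewLikeInterface.h"

static bool knownDistinctObjects(mlir::Value a, mlir::Value b) {
  mlir::Operation *defA = a.getDefiningOp();
  mlir::Operation *defB = b.getDefiningOp();
  return defA && defA == defB && a != b &&
         defA->hasTrait<mlir::OpTrait::DistinctObjectsTrait>();
}
```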
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
index ed213bf..131c1a0 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -414,4 +414,16 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
}];
}
+// This trait indicates that pointer-like objects (such as memrefs) returned
+// from this operation will never alias with each other. This provides a
+// guarantee to optimization passes that accesses through different results
+// of this operation can be safely reordered, as they will never reference
+// overlapping memory locations.
+//
+// Operations with this trait take multiple pointer-like operands
+// and return the same operands with additional non-aliasing guarantees.
+// If accesses through the results of this operation do alias at runtime,
+// the behavior of those accesses is undefined.
+def DistinctObjectsTrait : NativeOpTrait<"DistinctObjectsTrait">;
+
#endif // MLIR_INTERFACES_VIEWLIKEINTERFACE
diff --git a/mlir/include/mlir/Remark/RemarkStreamer.h b/mlir/include/mlir/Remark/RemarkStreamer.h
index 170d6b4..19a70fa 100644
--- a/mlir/include/mlir/Remark/RemarkStreamer.h
+++ b/mlir/include/mlir/Remark/RemarkStreamer.h
@@ -45,6 +45,7 @@ namespace mlir::remark {
/// mlir::emitRemarks.
LogicalResult enableOptimizationRemarksWithLLVMStreamer(
MLIRContext &ctx, StringRef filePath, llvm::remarks::Format fmt,
+ std::unique_ptr<detail::RemarkEmittingPolicyBase> remarkEmittingPolicy,
const RemarkCategories &cat, bool printAsEmitRemarks = false);
} // namespace mlir::remark
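A hedged usage sketch of the extended signature; `MyEmitAllPolicy` is a stand-in name for a concrete `RemarkEmittingPolicyBase` subclass (not an API from this patch), and `cats` is a `RemarkCategories` filter configured elsewhere.

    // Hypothetical call site; only the parameter order follows the
    // declaration above.
    if (failed(mlir::remark::enableOptimizationRemarksWithLLVMStreamer(
            ctx, "remarks.yaml", llvm::remarks::Format::YAML,
            std::make_unique<MyEmitAllPolicy>(), cats,
            /*printAsEmitRemarks=*/false)))
      return failure();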
diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h
index 252da21..997aef2 100644
--- a/mlir/include/mlir/TableGen/CodeGenHelpers.h
+++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h
@@ -88,7 +88,7 @@ public:
///
/// Constraints that do not meet the restriction that they can only reference
/// `$_self` and `$_op` are not uniqued.
- void emitOpConstraints(ArrayRef<const llvm::Record *> opDefs);
+ void emitOpConstraints();
/// Unique all compatible type and attribute constraints from a pattern file
/// and emit them at the top of the generated file.
diff --git a/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h
index 21adde8..cd9ef5b 100644
--- a/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h
+++ b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h
@@ -19,6 +19,14 @@ namespace mlir {
struct WasmBinaryEncoding {
/// Byte encodings for Wasm instructions.
struct OpCode {
+ // Control instructions.
+ static constexpr std::byte block{0x02};
+ static constexpr std::byte loop{0x03};
+ static constexpr std::byte ifOpCode{0x04};
+ static constexpr std::byte elseOpCode{0x05};
+ static constexpr std::byte branchIf{0x0D};
+ static constexpr std::byte call{0x10};
+
// Locals, globals, constants.
static constexpr std::byte localGet{0x20};
static constexpr std::byte localSet{0x21};
@@ -29,6 +37,42 @@ struct WasmBinaryEncoding {
static constexpr std::byte constFP32{0x43};
static constexpr std::byte constFP64{0x44};
+ // Comparisons.
+ static constexpr std::byte eqzI32{0x45};
+ static constexpr std::byte eqI32{0x46};
+ static constexpr std::byte neI32{0x47};
+ static constexpr std::byte ltSI32{0x48};
+ static constexpr std::byte ltUI32{0x49};
+ static constexpr std::byte gtSI32{0x4A};
+ static constexpr std::byte gtUI32{0x4B};
+ static constexpr std::byte leSI32{0x4C};
+ static constexpr std::byte leUI32{0x4D};
+ static constexpr std::byte geSI32{0x4E};
+ static constexpr std::byte geUI32{0x4F};
+ static constexpr std::byte eqzI64{0x50};
+ static constexpr std::byte eqI64{0x51};
+ static constexpr std::byte neI64{0x52};
+ static constexpr std::byte ltSI64{0x53};
+ static constexpr std::byte ltUI64{0x54};
+ static constexpr std::byte gtSI64{0x55};
+ static constexpr std::byte gtUI64{0x56};
+ static constexpr std::byte leSI64{0x57};
+ static constexpr std::byte leUI64{0x58};
+ static constexpr std::byte geSI64{0x59};
+ static constexpr std::byte geUI64{0x5A};
+ static constexpr std::byte eqF32{0x5B};
+ static constexpr std::byte neF32{0x5C};
+ static constexpr std::byte ltF32{0x5D};
+ static constexpr std::byte gtF32{0x5E};
+ static constexpr std::byte leF32{0x5F};
+ static constexpr std::byte geF32{0x60};
+ static constexpr std::byte eqF64{0x61};
+ static constexpr std::byte neF64{0x62};
+ static constexpr std::byte ltF64{0x63};
+ static constexpr std::byte gtF64{0x64};
+ static constexpr std::byte leF64{0x65};
+ static constexpr std::byte geF64{0x66};
+
// Numeric operations.
static constexpr std::byte clzI32{0x67};
static constexpr std::byte ctzI32{0x68};
@@ -93,6 +137,33 @@ struct WasmBinaryEncoding {
static constexpr std::byte maxF64{0xA5};
static constexpr std::byte copysignF64{0xA6};
static constexpr std::byte wrap{0xA7};
+
+ // Conversion operations.
+ static constexpr std::byte extendS{0xAC};
+ static constexpr std::byte extendU{0xAD};
+ static constexpr std::byte convertSI32F32{0xB2};
+ static constexpr std::byte convertUI32F32{0xB3};
+ static constexpr std::byte convertSI64F32{0xB4};
+ static constexpr std::byte convertUI64F32{0xB5};
+
+ static constexpr std::byte demoteF64ToF32{0xB6};
+
+ static constexpr std::byte convertSI32F64{0xB7};
+ static constexpr std::byte convertUI32F64{0xB8};
+ static constexpr std::byte convertSI64F64{0xB9};
+ static constexpr std::byte convertUI64F64{0xBA};
+
+ static constexpr std::byte promoteF32ToF64{0xBB};
+ static constexpr std::byte reinterpretF32AsI32{0xBC};
+ static constexpr std::byte reinterpretF64AsI64{0xBD};
+ static constexpr std::byte reinterpretI32AsF32{0xBE};
+ static constexpr std::byte reinterpretI64AsF64{0xBF};
+
+ static constexpr std::byte extendI328S{0xC0};
+ static constexpr std::byte extendI3216S{0xC1};
+ static constexpr std::byte extendI648S{0xC2};
+ static constexpr std::byte extendI6416S{0xC3};
+ static constexpr std::byte extendI6432S{0xC4};
};
/// Byte encodings of types in Wasm binaries
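As a small illustration of how a binary reader might dispatch on these constants (a sketch; `readByte` is a hypothetical helper that returns the next byte of the function body):

    std::byte opcode = readByte();
    if (opcode == WasmBinaryEncoding::OpCode::call) {
      // A call is followed by a LEB128-encoded function index.
    } else if (opcode == WasmBinaryEncoding::OpCode::eqI32) {
      // i32.eq pops two i32 values and pushes an i32 boolean result.
    }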
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
index 0fbe15f..b739438 100644
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -44,6 +44,11 @@ enum class RemarkFormat {
REMARK_FORMAT_BITSTREAM,
};
+enum class RemarkPolicy {
+ REMARK_POLICY_ALL,
+ REMARK_POLICY_FINAL,
+};
+
/// Configuration options for the mlir-opt tool.
/// This is intended to help building tools like mlir-opt by collecting the
/// supported options.
@@ -242,6 +247,8 @@ public:
/// Set the reproducer output filename
RemarkFormat getRemarkFormat() const { return remarkFormatFlag; }
+ /// Get the remark policy to use.
+ RemarkPolicy getRemarkPolicy() const { return remarkPolicyFlag; }
/// Set the remark format to use.
std::string getRemarksAllFilter() const { return remarksAllFilterFlag; }
/// Set the remark output file.
@@ -265,6 +272,8 @@ protected:
/// Remark format
RemarkFormat remarkFormatFlag = RemarkFormat::REMARK_FORMAT_STDOUT;
+ /// Remark policy
+ RemarkPolicy remarkPolicyFlag = RemarkPolicy::REMARK_POLICY_ALL;
/// Remark file to output to
std::string remarksOutputFileFlag = "";
/// Remark filters
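For orientation, a driver could act on the new flag roughly as follows; `config` is an MlirOptMainConfig, and the behavior attached to each enumerator is inferred from its name rather than stated by this patch.

    switch (config.getRemarkPolicy()) {
    case RemarkPolicy::REMARK_POLICY_ALL:
      // Stream every remark as it is emitted.
      break;
    case RemarkPolicy::REMARK_POLICY_FINAL:
      // Defer emission and report only the final remarks.
      break;
    }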