diff options
Diffstat (limited to 'mlir/include')
45 files changed, 2415 insertions, 558 deletions
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h index 2db1d84..fe42a20 100644 --- a/mlir/include/mlir-c/Rewrite.h +++ b/mlir/include/mlir-c/Rewrite.h @@ -352,7 +352,7 @@ typedef struct { /// Create a rewrite pattern that matches the operation /// with the given rootName, corresponding to mlir::OpRewritePattern. -MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate( +MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePatternCreate( MlirStringRef rootName, unsigned benefit, MlirContext context, MlirRewritePatternCallbacks callbacks, void *userData, size_t nGeneratedNames, MlirStringRef *generatedNames); diff --git a/mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h new file mode 100644 index 0000000..72ac247 --- /dev/null +++ b/mlir/include/mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h @@ -0,0 +1,54 @@ +//===- StridedMetadataRangeAnalysis.h - Strided metadata ranges -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H +#define MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H + +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "mlir/Interfaces/InferStridedMetadataInterface.h" + +namespace mlir { +namespace dataflow { + +/// This lattice element represents the strided metadata of an SSA value. +class StridedMetadataRangeLattice : public Lattice<StridedMetadataRange> { +public: + using Lattice::Lattice; +}; + +/// Strided metadata range analysis determines the strided metadata ranges of +/// SSA values using operations that define `InferStridedMetadataInterface`. 
+/// +/// This analysis depends on DeadCodeAnalysis, SparseConstantPropagation, and +/// IntegerRangeAnalysis, and will be a silent no-op if the analyses are not +/// loaded in the same solver context. +class StridedMetadataRangeAnalysis + : public SparseForwardDataFlowAnalysis<StridedMetadataRangeLattice> { +public: + StridedMetadataRangeAnalysis(DataFlowSolver &solver, + int32_t indexBitwidth = 64); + + /// At an entry point, we cannot reason about strided metadata ranges unless + /// the type also encodes the data. For example, a memref with static layout. + void setToEntryState(StridedMetadataRangeLattice *lattice) override; + + /// Visit an operation. Invoke the transfer function on each operation that + /// implements `InferStridedMetadataInterface`. + LogicalResult + visitOperation(Operation *op, + ArrayRef<const StridedMetadataRangeLattice *> operands, + ArrayRef<StridedMetadataRangeLattice *> results) override; + +private: + /// Index bitwidth to use when operating with the int-ranges. + int32_t indexBitwidth = 64; +}; +} // namespace dataflow +} // end namespace mlir + +#endif // MLIR_ANALYSIS_DATAFLOW_STRIDEDMETADATARANGE_H diff --git a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h index 46573e79..60f1888 100644 --- a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h +++ b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h @@ -9,6 +9,7 @@ #define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_ #include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/IR/PatternMatch.h" #include <memory> @@ -19,8 +20,11 @@ class Pass; #include "mlir/Conversion/Passes.h.inc" /// Populate the given list with patterns that convert from Math to ROCDL calls. -void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter, - RewritePatternSet &patterns); +// `chipset` specifies the AMDGPU chipset to target. 
If `std::nullopt`, +// none of the chipset dependent patterns are added. +void populateMathToROCDLConversionPatterns( + const LLVMTypeConverter &converter, RewritePatternSet &patterns, + std::optional<amdgpu::Chipset> chipset); } // namespace mlir #endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_ diff --git a/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h new file mode 100644 index 0000000..91d3c92 --- /dev/null +++ b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h @@ -0,0 +1,27 @@ +//===- MathToXeVM.h - Utils for converting Math to XeVM -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_ +#define MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_ + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/XeVMDialect.h" +#include "mlir/IR/PatternMatch.h" +#include <memory> + +namespace mlir { +class Pass; + +#define GEN_PASS_DECL_CONVERTMATHTOXEVM +#include "mlir/Conversion/Passes.h.inc" + +/// Populate the given list with patterns that convert from Math to XeVM calls. 
+void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns, + bool convertArith); +} // namespace mlir + +#endif // MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_ diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index da061b2..40d866e 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -49,6 +49,7 @@ #include "mlir/Conversion/MathToLibm/MathToLibm.h" #include "mlir/Conversion/MathToROCDL/MathToROCDL.h" #include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h" +#include "mlir/Conversion/MathToXeVM/MathToXeVM.h" #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 3c18ecc..70e3e45 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -778,6 +778,10 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> { let summary = "Convert Math dialect to ROCDL library calls"; let description = [{ This pass converts supported Math ops to ROCDL library calls. + + The chipset option specifies the target AMDGPU architecture. If the chipset + is empty, none of the chipset-dependent patterns are added, and the pass + will not attempt to parse the chipset. 
}]; let dependentDialects = [ "arith::ArithDialect", @@ -785,6 +789,9 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> { "ROCDL::ROCDLDialect", "vector::VectorDialect", ]; + let options = [Option<"chipset", "chipset", "std::string", + /*default=*/"\"\"", + "Chipset that these operations will run on">]; } //===----------------------------------------------------------------------===// @@ -797,6 +804,31 @@ def ConvertMathToSPIRVPass : Pass<"convert-math-to-spirv"> { } //===----------------------------------------------------------------------===// +// MathToXeVM +//===----------------------------------------------------------------------===// + +def ConvertMathToXeVM : Pass<"convert-math-to-xevm"> { + let summary = + "Convert (fast) math operations to native XeVM/SPIRV equivalents"; + let description = [{ + This pass converts supported math ops marked with the `afn` fastmath flag + to function calls for OpenCL `native_` math intrinsics: These intrinsics + are typically mapped directly to native device instructions, often resulting + in better performance. However, the precision/error of these intrinsics + are implementation-defined, and thus math ops are only converted when they + have the `afn` fastmath flag enabled. + }]; + let options = [Option< + "convertArith", "convert-arith", "bool", /*default=*/"true", + "Convert supported Arith ops (e.g. 
arith.divf) as well.">]; + let dependentDialects = [ + "arith::ArithDialect", + "xevm::XeVMDialect", + "LLVM::LLVMDialect", + ]; +} + +//===----------------------------------------------------------------------===// // MathToEmitC //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td index 8370d35..7184de9 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td @@ -112,6 +112,97 @@ def AMDGPU_ExtPackedFp8Op : }]; } +def IsValidBlockSize: AttrConstraint< + CPred<"::llvm::is_contained({16, 32}, ::llvm::cast<::mlir::IntegerAttr>($_self).getInt())">, + "whose value is 16 or 32">; + +def AMDGPU_ScaledExtPacked816Op + : AMDGPU_Op<"scaled_ext_packed816", [Pure, AllShapesMatch<["source", "res"]>]>, + Arguments<( + ins AnyTypeOf<[FixedVectorOfShapeAndType<[8], F4E2M1FN>, + FixedVectorOfShapeAndType<[8], F8E4M3FN>, + FixedVectorOfShapeAndType<[8], F8E5M2>, + FixedVectorOfShapeAndType<[16], F6E2M3FN>, + FixedVectorOfShapeAndType<[16], F6E3M2FN>]>:$source, + FixedVectorOfShapeAndType<[4], F8E8M0FNU>:$scale, + ConfinedAttr<I32Attr, [IsValidBlockSize]>:$blockSize, + ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<1>]>:$firstScaleLane, + ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<2>]>:$firstScaleByte)>, + Results<( + outs AnyTypeOf<[FixedVectorOfShapeAndType<[8], F32>, + FixedVectorOfShapeAndType<[8], F16>, + FixedVectorOfShapeAndType<[8], BF16>, + FixedVectorOfShapeAndType<[16], F32>, + FixedVectorOfShapeAndType<[16], F16>, + FixedVectorOfShapeAndType<[16], BF16>]>:$res)> { + + let summary = "Extend a vector of packed floating point values"; + + let description = [{ + The scales applied to the input microfloats are stored in two bytes which + come from the `scales` input provided in a *half* of the wave identified + by `firstScaleLane`. 
The pair of bytes used is selected by + `firstScaleByte`. The 16 vectors in consecutive lanes starting from + `firstScaleLane` (which we'll call the scale vectors) will be used by both + halves of the wave (with lane L reading from L % 16'th scale vector), but + each half will use a different byte. + + When the block size is 32, `firstScaleByte` can be either 0 or 2, + selecting halves of the scale vectors. Lanes 0-15 will read from + `firstScaleByte` and lanes 16-31 will read from `firstScaleByte` + 1. + For example: + ```mlir + // Input: 8-element vector of F8E4M3FN, converting to F32 + // Lanes 0-15 read from byte 0, lanes 16-31 read from byte 1 + %result = amdgpu.scaled_ext_packed816 %source scale(%scales) + blockSize(32) firstScaleLane(0) firstScaleByte(0) + : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf32> + + // Input: 16-element vector of F6E2M3FN, converting to F16 + // Lanes 0-15 read from byte 2, lanes 16-31 read from byte 3 + %result = amdgpu.scaled_ext_packed816 %source scale(%scales) + blockSize(32) firstScaleLane(1) firstScaleByte(2) + : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16> + ``` + + However, when the block size is 16, `firstScaleByte` can be 0 or 1. + Lanes 0-15 read from the `firstScaleByte`th element of the scale vectors, + while lanes 16-31 read from `firstScaleByte` + 2. 
+ For example: + ```mlir + // Input: 8-element vector of F8E5M2, converting to BF16 + // Lanes 0-15 read from byte 0, lanes 16-31 read from byte 2 (0+2) + %result = amdgpu.scaled_ext_packed816 %source scale(%scales) + blockSize(16) firstScaleLane(0) firstScaleByte(0) + : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16> + + // Input: 16-element vector of F6E3M2FN, converting to F32 + // Lanes 0-15 read from byte 1, lanes 16-31 read from byte 3 (1+2) + %result = amdgpu.scaled_ext_packed816 %source scale(%scales) + blockSize(16) firstScaleLane(1) firstScaleByte(1) + : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf32> + ``` + + Note: the layout for the scales generally mirrors the layout that the WMMA + instructions use for matrix scales. These selection operands allow + one to choose portions of the matrix to convert. + + Available on gfx1250+. + }]; + + let assemblyFormat = [{ + attr-dict $source + `scale` `(` $scale `)` + `blockSize` `(` $blockSize `)` + `firstScaleLane` `(` $firstScaleLane`)` + `firstScaleByte` `(` $firstScaleByte `)` + `:` type($source) `,` type($scale) `->` type($res) + }]; + + let hasVerifier = 1; + +} + def AMDGPU_ScaledExtPackedOp : AMDGPU_Op<"scaled_ext_packed", [Pure]>, Arguments<( @@ -860,7 +951,7 @@ def AMDGPU_MFMAOp : based on the provided `m`, `k`, `n`, and `nBlks` attributes, along with the types of the source and destination arguments. - For information on the layouts of the input and output matrces (which are stored + For information on the layouts of the input and output matrices (which are stored
The `cbsz`, `abid`, and `blgp` parameters control how the lanes of the wave diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td index e52b7d2..12a7935 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -330,7 +330,6 @@ def AffineForOp : Affine_Op<"for", Speculation::Speculatability getSpeculatability(); }]; - let hasCanonicalizer = 1; let hasCustomAssemblyFormat = 1; let hasFolder = 1; let hasRegionVerifier = 1; diff --git a/mlir/include/mlir/Dialect/Affine/LoopUtils.h b/mlir/include/mlir/Dialect/Affine/LoopUtils.h index 9b59af7..830c394 100644 --- a/mlir/include/mlir/Dialect/Affine/LoopUtils.h +++ b/mlir/include/mlir/Dialect/Affine/LoopUtils.h @@ -61,7 +61,7 @@ LogicalResult loopUnrollUpToFactor(AffineForOp forOp, uint64_t unrollFactor); /// Returns true if `loops` is a perfectly nested loop nest, where loops appear /// in it from outermost to innermost. -bool LLVM_ATTRIBUTE_UNUSED isPerfectlyNested(ArrayRef<AffineForOp> loops); +[[maybe_unused]] bool isPerfectlyNested(ArrayRef<AffineForOp> loops); /// Get perfectly nested sequence of loops starting at root of loop nest /// (the first op being another AffineFor, and the second op - a terminator). diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index 20c9097..a38cf41 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -1229,37 +1229,50 @@ def Arith_ScalingExtFOp let summary = "Upcasts input floats using provided scales values following " "OCP MXFP Spec"; let description = [{ - This operation upcasts input floating-point values using provided scale - values. It expects both scales and the input operand to be of the same shape, - making the operation elementwise. 
Scales are usually calculated per block - following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537. - - If scales are calculated per block where blockSize != 1, then scales may - require broadcasting to make this operation elementwise. For example, let's - say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and - assuming quantization happens on the last axis, the input can be reshaped to - `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated - per block on the last axis. Therefore, scales will be of shape - `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other - shape as long as it is broadcast compatible with the input, e.g., - `<1 x 1 x ... (dimN/blockSize) x 1>`. - - In this example, before calling into `arith.scaling_extf`, scales must be - broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note - that there could be multiple quantization axes. Internally, - `arith.scaling_extf` would perform the following: + This operation upcasts input floating-point values using provided scale + values. It expects both scales and the input operand to be of the same shape, + making the operation elementwise. Scales are usually calculated per block + following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537. - ``` - resultTy = get_type(result) - scaleTy = get_type(scale) - inputTy = get_type(input) - scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0 - scale.extf = arith.extf(scale.exponent) : f8E8M0 to resultTy - input.extf = arith.extf(input) : inputTy to resultTy - result = arith.mulf(scale.extf, input.extf) + If scales are calculated per block where blockSize != 1, then scales may + require broadcasting to make this operation elementwise. For example, let's + say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and + assuming quantization happens on the last axis, the input can be reshaped to + `<dim1 x dim2 x ... 
(dimN/blockSize) x blockSize>`. Scales will be calculated + per block on the last axis. Therefore, scales will be of shape + `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other + shape as long as it is broadcast compatible with the input, e.g., + `<1 x 1 x ... (dimN/blockSize) x 1>`. + + In this example, before calling into `arith.scaling_extf`, scales must be + broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note + that there could be multiple quantization axes. Internally, + `arith.scaling_extf` would perform the following: + + ```mlir + // Cast scale to result type. + %0 = arith.truncf %1 : f32 to f8E8M0FNU + %1 = arith.extf %0 : f8E8M0FNU to f16 + + // Cast input to result type. + %2 = arith.extf %3 : f4E2M1FN to f16 + + // Perform scaling + %3 = arith.mulf %2, %1 : f16 ``` It propagates NaN values. Therefore, if either scale or the input element contains NaN, then the output element value will also be a NaN. + + Example: + + ```mlir + // Upcast from f4E2M1FN to f32. + %a = arith.scaling_extf %b, %c : f4E2M1FN, f8E8M0FNU to f32 + + // Element-wise upcast with broadcast (blockSize = 32). + %f = vector.broadcast %g : vector<1xf8E8M0FNU> to vector<32xf8E8M0FNU> + %h = arith.scaling_extf %i, %f : vector<32xf4E2M1FN>, vector<32xf8E8M0FNU> to vector<32xbf16> + ``` }]; let hasVerifier = 1; let assemblyFormat = @@ -1397,14 +1410,27 @@ def Arith_ScalingTruncFOp that there could be multiple quantization axes. Internally, `arith.scaling_truncf` would perform the following: + ```mlir + // Cast scale to input type. + %0 = arith.truncf %1 : f32 to f8E8M0FNU + %1 = arith.extf %0 : f8E8M0FNU to f16 + + // Perform scaling. + %3 = arith.divf %2, %1 : f16 + + // Cast to result type. 
+ %4 = arith.truncf %3 : f16 to f4E2M1FN ``` - scaleTy = get_type(scale) - inputTy = get_type(input) - resultTy = get_type(result) - scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0 - scale.extf = arith.extf(scale.exponent) : f8E8M0 to inputTy - result = arith.divf(input, scale.extf) - result.cast = arith.truncf(result, resultTy) + + Example: + + ```mlir + // Downcast from f32 to f4E2M1FN. + %a = arith.scaling_truncf %b, %c : f32, f8E8M0FNU to f4E2M1FN + + // Element-wise downcast with broadcast (blockSize = 32). + %f = vector.broadcast %g : vector<1xf8E8M0FNU> to vector<32xf8E8M0FNU> + %h = arith.scaling_truncf %i, %f : vector<32xbf16>, vector<32xf8E8M0FNU> to vector<32xf4E2M1FN> ``` }]; let hasVerifier = 1; diff --git a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h index 035235f..fccb49d 100644 --- a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h +++ b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h @@ -1,4 +1,4 @@ -//===- Passes.h - GPU NVVM pipeline entry points --------------------------===// +//===- Passes.h - GPU pipeline entry points--------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -60,6 +60,52 @@ struct GPUToNVVMPipelineOptions llvm::cl::init(false)}; }; +// Options for the gpu to xevm pipeline. +struct GPUToXeVMPipelineOptions + : public PassPipelineOptions<GPUToXeVMPipelineOptions> { + PassOptions::Option<std::string> xegpuOpLevel{ + *this, "xegpu-op-level", + llvm::cl::desc("Granularity of XeGPU operations to target: workgroup | " + "subgroup | lane"), + llvm::cl::init("workgroup")}; + // General lowering controls. 
+ PassOptions::Option<bool> use64bitIndex{ + *this, "use-64bit-index", + llvm::cl::desc("Bitwidth of the index type (host & device)"), + llvm::cl::init(true)}; + PassOptions::Option<bool> kernelBarePtrCallConv{ + *this, "kernel-bare-ptr-calling-convention", + llvm::cl::desc("Use bare pointer calling convention for device kernels"), + llvm::cl::init(false)}; + PassOptions::Option<bool> hostBarePtrCallConv{ + *this, "host-bare-ptr-calling-convention", + llvm::cl::desc("Use bare pointer calling convention for host launches"), + llvm::cl::init(false)}; + PassOptions::Option<std::string> binaryFormat{ + *this, "binary-format", + llvm::cl::desc("Final GPU binary emission format (e.g. fatbin)"), + llvm::cl::init("fatbin")}; + // Options mirroring xevm-attach-target (GpuXeVMAttachTarget). + PassOptions::Option<std::string> xevmModuleMatcher{ + *this, "xevm-module-matcher", + llvm::cl::desc("Regex to match gpu.module names for XeVM target attach"), + llvm::cl::init("")}; + PassOptions::Option<std::string> zebinTriple{ + *this, "zebin-triple", llvm::cl::desc("Target triple for XeVM codegen"), + llvm::cl::init("spirv64-unknown-unknown")}; + PassOptions::Option<std::string> zebinChip{ + *this, "zebin-chip", llvm::cl::desc("Target chip (e.g. pvc, bmg)"), + llvm::cl::init("bmg")}; + PassOptions::Option<unsigned> optLevel{ + *this, "opt-level", + llvm::cl::desc("Optimization level for attached target/codegen"), + llvm::cl::init(2)}; + PassOptions::Option<std::string> cmdOptions{ + *this, "igc-cmd-options", + llvm::cl::desc("Additional downstream compiler command line options"), + llvm::cl::init("")}; +}; + //===----------------------------------------------------------------------===// // Building and Registering. 
//===----------------------------------------------------------------------===// @@ -70,8 +116,15 @@ struct GPUToNVVMPipelineOptions void buildLowerToNVVMPassPipeline(OpPassManager &pm, const GPUToNVVMPipelineOptions &options); -/// Register all pipeleines for the `gpu` dialect. +/// Adds the GPU to XeVM pipeline to the given pass manager. Transforms main +/// dialects into XeVM targets. Begins with GPU code regions, then handles host +/// code. +void buildLowerToXeVMPassPipeline(OpPassManager &pm, + const GPUToXeVMPipelineOptions &options); + +/// Register all pipelines for the `gpu` dialect. void registerGPUToNVVMPipeline(); +void registerGPUToXeVMPipeline(); } // namespace gpu } // namespace mlir diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt index 8d9474b..c301e0b 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt @@ -48,6 +48,10 @@ mlir_tablegen(LLVMIntrinsicFromLLVMIRConversions.inc -gen-intr-from-llvmir-conve mlir_tablegen(LLVMConvertibleLLVMIRIntrinsics.inc -gen-convertible-llvmir-intrinsics) add_mlir_dialect_tablegen_target(MLIRLLVMIntrinsicConversionsIncGen) +set(LLVM_TARGET_DEFINITIONS LLVMDialectBytecode.td) +mlir_tablegen(LLVMDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="LLVM") +add_public_tablegen_target(MLIRLLVMDialectBytecodeIncGen) + set(LLVM_TARGET_DEFINITIONS BasicPtxBuilderInterface.td) mlir_tablegen(BasicPtxBuilderInterface.h.inc -gen-op-interface-decls) mlir_tablegen(BasicPtxBuilderInterface.cpp.inc -gen-op-interface-defs) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialectBytecode.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialectBytecode.td new file mode 100644 index 0000000..e7b202c --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialectBytecode.td @@ -0,0 +1,353 @@ +//===-- LLVMDialectBytecode.td - LLVM bytecode defs --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the 
Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the LLVM bytecode reader/writer definition file. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_DIALECT_BYTECODE +#define LLVM_DIALECT_BYTECODE + +include "mlir/IR/BytecodeBase.td" + +//===----------------------------------------------------------------------===// +// Bytecode classes for attributes and types. +//===----------------------------------------------------------------------===// + +def String : + WithParser <"succeeded($_reader.readString($_var))", + WithBuilder<"$_args", + WithPrinter<"$_writer.writeOwnedString($_getter)", + WithType <"StringRef">>>>; + +class Attr<string type> : WithType<type, Attribute>; + +class OptionalAttribute<string type> : + WithParser <"succeeded($_reader.readOptionalAttribute($_var))", + WithPrinter<"$_writer.writeOptionalAttribute($_getter)", + WithType<type, Attribute>>>; + +class OptionalInt<string type> : + WithParser <"succeeded(readOptionalInt($_reader, $_var))", + WithPrinter<"writeOptionalInt($_writer, $_getter)", + WithType<"std::optional<" # type # ">", VarInt>>>; + +class OptionalArrayRef<string eltType> : + WithParser <"succeeded(readOptionalArrayRef<" + # eltType # ">($_reader, $_var))", + WithPrinter<"writeOptionalArrayRef<" + # eltType # ">($_writer, $_getter)", + WithType<"SmallVector<" + # eltType # ">", Attribute>>>; + +class EnumClassFlag<string flag, string getter> : + WithParser<"succeeded($_reader.readVarInt($_var))", + WithBuilder<"(" # flag # ")$_args", + WithPrinter<"$_writer.writeVarInt((uint64_t)$_name." 
# getter # ")", + WithType<"uint64_t", VarInt>>>>; + +//===----------------------------------------------------------------------===// +// General notes +// - For each attribute or type entry, the argument names should match +// LLVMAttrDefs.td +// - The mnemonics are either LLVM or builtin MLIR attributes and types, but +// regular C++ types are also allowed to match builders and parsers. +// - DIScopeAttr and DINodeAttr are empty base classes, custom encoding not +// needed. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// DIBasicTypeAttr +//===----------------------------------------------------------------------===// + +def DIBasicTypeAttr : DialectAttribute<(attr + VarInt:$tag, + String:$name, + VarInt:$sizeInBits, + VarInt:$encoding +)>; + +//===----------------------------------------------------------------------===// +// DIExpressionAttr, DIExpressionElemAttr +//===----------------------------------------------------------------------===// + +def DIExpressionElemAttr : DialectAttribute<(attr + VarInt:$opcode, + OptionalArrayRef<"uint64_t">:$arguments +)>; + +def DIExpressionAttr : DialectAttribute<(attr + OptionalArrayRef<"DIExpressionElemAttr">:$operations +)>; + +//===----------------------------------------------------------------------===// +// DIFileAttr +//===----------------------------------------------------------------------===// + +def DIFileAttr : DialectAttribute<(attr + String:$name, + String:$directory +)>; + +//===----------------------------------------------------------------------===// +// DILocalVariableAttr +//===----------------------------------------------------------------------===// + +def DILocalVariableAttr : DialectAttribute<(attr + Attr<"DIScopeAttr">:$scope, + OptionalAttribute<"StringAttr">:$name, + OptionalAttribute<"DIFileAttr">:$file, + VarInt:$line, + VarInt:$arg, + VarInt:$alignInBits, + 
OptionalAttribute<"DITypeAttr">:$type, + EnumClassFlag<"DIFlags", "getFlags()">:$_rawflags, + LocalVar<"DIFlags", "(DIFlags)_rawflags">:$flags +)> { + // DILocalVariableAttr direct getter uses a `StringRef` for `name`. Since the + // more direct getter is preferred during bytecode reading, force the base one + // and prevent crashes for empty `StringAttr`. + let cBuilder = "$_resultType::get(context, $_args)"; +} + +//===----------------------------------------------------------------------===// +// DISubroutineTypeAttr +//===----------------------------------------------------------------------===// + +def DISubroutineTypeAttr : DialectAttribute<(attr + VarInt:$callingConvention, + OptionalArrayRef<"DITypeAttr">:$types +)>; + +//===----------------------------------------------------------------------===// +// DICompileUnitAttr +//===----------------------------------------------------------------------===// + +def DICompileUnitAttr : DialectAttribute<(attr + Attr<"DistinctAttr">:$id, + VarInt:$sourceLanguage, + Attr<"DIFileAttr">:$file, + OptionalAttribute<"StringAttr">:$producer, + Bool:$isOptimized, + EnumClassFlag<"DIEmissionKind", "getEmissionKind()">:$_rawEmissionKind, + LocalVar<"DIEmissionKind", "(DIEmissionKind)_rawEmissionKind">:$emissionKind, + EnumClassFlag<"DINameTableKind", "getNameTableKind()">:$_rawNameTableKind, + LocalVar<"DINameTableKind", + "(DINameTableKind)_rawNameTableKind">:$nameTableKind +)>; + +//===----------------------------------------------------------------------===// +// DISubprogramAttr +//===----------------------------------------------------------------------===// + +def DISubprogramAttr : DialectAttribute<(attr + OptionalAttribute<"DistinctAttr">:$recId, + Bool:$isRecSelf, + OptionalAttribute<"DistinctAttr">:$id, + OptionalAttribute<"DICompileUnitAttr">:$compileUnit, + OptionalAttribute<"DIScopeAttr">:$scope, + OptionalAttribute<"StringAttr">:$name, + OptionalAttribute<"StringAttr">:$linkageName, + 
OptionalAttribute<"DIFileAttr">:$file, + VarInt:$line, + VarInt:$scopeLine, + EnumClassFlag<"DISubprogramFlags", "getSubprogramFlags()">:$_rawflags, + LocalVar<"DISubprogramFlags", "(DISubprogramFlags)_rawflags">:$subprogramFlags, + OptionalAttribute<"DISubroutineTypeAttr">:$type, + OptionalArrayRef<"DINodeAttr">:$retainedNodes, + OptionalArrayRef<"DINodeAttr">:$annotations +)>; + +//===----------------------------------------------------------------------===// +// DICompositeTypeAttr +//===----------------------------------------------------------------------===// + +def DICompositeTypeAttr : DialectAttribute<(attr + OptionalAttribute<"DistinctAttr">:$recId, + Bool:$isRecSelf, + VarInt:$tag, + OptionalAttribute<"StringAttr">:$name, + OptionalAttribute<"DIFileAttr">:$file, + VarInt:$line, + OptionalAttribute<"DIScopeAttr">:$scope, + OptionalAttribute<"DITypeAttr">:$baseType, + EnumClassFlag<"DIFlags", "getFlags()">:$_rawflags, + LocalVar<"DIFlags", "(DIFlags)_rawflags">:$flags, + VarInt:$sizeInBits, + VarInt:$alignInBits, + OptionalAttribute<"DIExpressionAttr">:$dataLocation, + OptionalAttribute<"DIExpressionAttr">:$rank, + OptionalAttribute<"DIExpressionAttr">:$allocated, + OptionalAttribute<"DIExpressionAttr">:$associated, + OptionalArrayRef<"DINodeAttr">:$elements +)>; + +//===----------------------------------------------------------------------===// +// DIDerivedTypeAttr +//===----------------------------------------------------------------------===// + +def DIDerivedTypeAttr : DialectAttribute<(attr + VarInt:$tag, + OptionalAttribute<"StringAttr">:$name, + OptionalAttribute<"DITypeAttr">:$baseType, + VarInt:$sizeInBits, + VarInt:$alignInBits, + VarInt:$offsetInBits, + OptionalInt<"unsigned">:$dwarfAddressSpace, + OptionalAttribute<"DINodeAttr">:$extraData +)>; + +//===----------------------------------------------------------------------===// +// DIImportedEntityAttr +//===----------------------------------------------------------------------===// + +def 
DIImportedEntityAttr : DialectAttribute<(attr + VarInt:$tag, + Attr<"DIScopeAttr">:$scope, + Attr<"DINodeAttr">:$entity, + OptionalAttribute<"DIFileAttr">:$file, + VarInt:$line, + OptionalAttribute<"StringAttr">:$name, + OptionalArrayRef<"DINodeAttr">:$elements +)>; + +//===----------------------------------------------------------------------===// +// DIGlobalVariableAttr, DIGlobalVariableExpressionAttr +//===----------------------------------------------------------------------===// + +def DIGlobalVariableAttr : DialectAttribute<(attr + OptionalAttribute<"DIScopeAttr">:$scope, + OptionalAttribute<"StringAttr">:$name, + OptionalAttribute<"StringAttr">:$linkageName, + Attr<"DIFileAttr">:$file, + VarInt:$line, + Attr<"DITypeAttr">:$type, + Bool:$isLocalToUnit, + Bool:$isDefined, + VarInt:$alignInBits +)>; + +def DIGlobalVariableExpressionAttr : DialectAttribute<(attr + Attr<"DIGlobalVariableAttr">:$var, + OptionalAttribute<"DIExpressionAttr">:$expr +)>; + +//===----------------------------------------------------------------------===// +// DILabelAttr +//===----------------------------------------------------------------------===// + +def DILabelAttr : DialectAttribute<(attr + Attr<"DIScopeAttr">:$scope, + OptionalAttribute<"StringAttr">:$name, + OptionalAttribute<"DIFileAttr">:$file, + VarInt:$line +)> { + // DILabelAttr direct getter uses a `StringRef` for `name`. Since the + // more direct getter is preferred during bytecode reading, force the base one + // and prevent crashes for empty `StringAttr`. 
+ let cBuilder = "$_resultType::get(context, $_args)"; +} + +//===----------------------------------------------------------------------===// +// DILexicalBlockAttr, DILexicalBlockFileAttr +//===----------------------------------------------------------------------===// + +def DILexicalBlockAttr : DialectAttribute<(attr + Attr<"DIScopeAttr">:$scope, + OptionalAttribute<"DIFileAttr">:$file, + VarInt:$line, + VarInt:$column +)>; + +def DILexicalBlockFileAttr : DialectAttribute<(attr + Attr<"DIScopeAttr">:$scope, + OptionalAttribute<"DIFileAttr">:$file, + VarInt:$discriminator +)>; + +//===----------------------------------------------------------------------===// +// DINamespaceAttr +//===----------------------------------------------------------------------===// + +def DINamespaceAttr : DialectAttribute<(attr + OptionalAttribute<"StringAttr">:$name, + OptionalAttribute<"DIScopeAttr">:$scope, + Bool:$exportSymbols +)>; + +//===----------------------------------------------------------------------===// +// DISubrangeAttr +//===----------------------------------------------------------------------===// + +def DISubrangeAttr : DialectAttribute<(attr + OptionalAttribute<"Attribute">:$count, + OptionalAttribute<"Attribute">:$lowerBound, + OptionalAttribute<"Attribute">:$upperBound, + OptionalAttribute<"Attribute">:$stride +)>; + +//===----------------------------------------------------------------------===// +// LoopAnnotationAttr +//===----------------------------------------------------------------------===// + +def LoopAnnotationAttr : DialectAttribute<(attr + OptionalAttribute<"BoolAttr">:$disableNonforced, + OptionalAttribute<"LoopVectorizeAttr">:$vectorize, + OptionalAttribute<"LoopInterleaveAttr">:$interleave, + OptionalAttribute<"LoopUnrollAttr">:$unroll, + OptionalAttribute<"LoopUnrollAndJamAttr">:$unrollAndJam, + OptionalAttribute<"LoopLICMAttr">:$licm, + OptionalAttribute<"LoopDistributeAttr">:$distribute, + OptionalAttribute<"LoopPipelineAttr">:$pipeline, + 
OptionalAttribute<"LoopPeeledAttr">:$peeled, + OptionalAttribute<"LoopUnswitchAttr">:$unswitch, + OptionalAttribute<"BoolAttr">:$mustProgress, + OptionalAttribute<"BoolAttr">:$isVectorized, + OptionalAttribute<"FusedLoc">:$startLoc, + OptionalAttribute<"FusedLoc">:$endLoc, + OptionalArrayRef<"AccessGroupAttr">:$parallelAccesses +)>; + +//===----------------------------------------------------------------------===// +// Attributes & Types with custom bytecode handling. +//===----------------------------------------------------------------------===// + +// All the attributes with custom bytecode handling. +def LLVMDialectAttributes : DialectAttributes<"LLVM"> { + let elems = [ + DIBasicTypeAttr, + DICompileUnitAttr, + DICompositeTypeAttr, + DIDerivedTypeAttr, + DIExpressionElemAttr, + DIExpressionAttr, + DIFileAttr, + DIGlobalVariableAttr, + DIGlobalVariableExpressionAttr, + DIImportedEntityAttr, + DILabelAttr, + DILexicalBlockAttr, + DILexicalBlockFileAttr, + DILocalVariableAttr, + DINamespaceAttr, + DISubprogramAttr, + DISubrangeAttr, + DISubroutineTypeAttr, + LoopAnnotationAttr + // Referenced attributes currently missing support: + // AccessGroupAttr, LoopVectorizeAttr, LoopInterleaveAttr, LoopUnrollAttr, + // LoopUnrollAndJamAttr, LoopLICMAttr, LoopDistributeAttr, LoopPipelineAttr, + // LoopPeeledAttr, LoopUnswitchAttr + ]; +} + +def LLVMDialectTypes : DialectTypes<"LLVM"> { + let elems = []; +} + +#endif // LLVM_DIALECT_BYTECODE diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 9753dca..d0811a2 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -973,6 +973,7 @@ def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector", custom<ShuffleType>(ref(type($v1)), type($res), ref($mask)) }]; + let hasFolder = 1; let hasVerifier = 1; string llvmInstName = "ShuffleVector"; @@ -1985,6 +1986,7 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [ 
OptionalAttr<StrAttr>:$instrument_function_exit, OptionalAttr<UnitAttr>:$no_inline, OptionalAttr<UnitAttr>:$always_inline, + OptionalAttr<UnitAttr>:$inline_hint, OptionalAttr<UnitAttr>:$no_unwind, OptionalAttr<UnitAttr>:$will_return, OptionalAttr<UnitAttr>:$optimize_none, @@ -2037,6 +2039,9 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [ /// Returns true if the `always_inline` attribute is set, false otherwise. bool isAlwaysInline() { return bool(getAlwaysInlineAttr()); } + /// Returns true if the `inline_hint` attribute is set, false otherwise. + bool isInlineHint() { return bool(getInlineHintAttr()); } + /// Returns true if the `optimize_none` attribute is set, false otherwise. bool isOptimizeNone() { return bool(getOptimizeNoneAttr()); } }]; diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 89fbeb7..d959464 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -263,6 +263,7 @@ class NVVM_PureSpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)"; let llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda; let mlirBuilder = baseMlirBuilder # importRangeRetAttrCode # baseMlirBuilderCoda; + let hasVerifier = 1; // Backwards-compatibility builder for an unspecified range. let builders = [ @@ -279,6 +280,11 @@ class NVVM_PureSpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = SetIntRangeFn setResultRanges) { nvvmInferResultRanges(getOperation(), getResult(), argRanges, setResultRanges); } + + // Verify the range attribute satisfies LLVM ConstantRange constructor requirements. 
+ ::llvm::LogicalResult $cppClass::verify() { + return verifyConstantRangeAttr(getOperation(), getRange()); + } }]; } @@ -1655,6 +1661,40 @@ def NVVM_ConvertFloatToTF32Op : NVVM_Op<"convert.float.to.tf32"> { }]; } +def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> { + let summary = "Convert a pair of float inputs to f4x2"; + let description = [{ + This Op converts each of the given float inputs to the specified fp4 type. + The result `dst` is returned as an i8 type where the converted values are + packed such that the value converted from `a` is stored in the upper 4 bits + of `dst` and the value converted from `b` is stored in the lower 4 bits of + `dst`. + The `relu` attribute, when set, lowers to the '.relu' variant of + the cvt instruction. + + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) + }]; + + let results = (outs I8:$dst); + let arguments = (ins F32:$a, F32:$b, + DefaultValuedAttr<BoolAttr, "false">:$relu, + TypeAttr:$dstTy); + let assemblyFormat = "$a `,` $b attr-dict `:` type($dst) `(` $dstTy `)`"; + let hasVerifier = 1; + + let extraClassDeclaration = [{ + static mlir::NVVM::IDArgPair + getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op, + LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder); + }]; + + string llvmBuilder = [{ + auto [intId, args] = NVVM::ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(op, moduleTranslation, builder); + llvm::Value *packedI16 = createIntrinsicCall(builder, intId, args); + $dst = builder.CreateTruncOrBitCast(packedI16, llvm::Type::getInt8Ty(builder.getContext())); + }]; +} + def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> { let summary = "Convert a pair of float inputs to f6x2"; let description = [{ diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 6925cec..d2df244 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ 
b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -412,6 +412,32 @@ def ROCDL_WaitExpcntOp: ROCDL_ConcreteNonMemIntrOp<"s.wait.expcnt", [], 0, [0], let assemblyFormat = "$count attr-dict"; } +def ROCDL_WaitAsynccntOp: ROCDL_ConcreteNonMemIntrOp<"s.wait.asynccnt", [], 0, [0], ["count"]>, + Arguments<(ins I16Attr:$count)> { + let summary = "Wait until ASYNCCNT is less than or equal to `count`"; + let description = [{ + Wait for the counter specified to be less-than or equal-to the `count` + before continuing. + + Available on gfx1250+. + }]; + let results = (outs); + let assemblyFormat = "$count attr-dict"; +} + +def ROCDL_WaitTensorcntOp: ROCDL_ConcreteNonMemIntrOp<"s.wait.tensorcnt", [], 0, [0], ["count"]>, + Arguments<(ins I16Attr:$count)> { + let summary = "Wait until TENSORCNT is less than or equal to `count`"; + let description = [{ + Wait for the counter specified to be less-than or equal-to the `count` + before continuing. + + Available on gfx1250+. + }]; + let results = (outs); + let assemblyFormat = "$count attr-dict"; +} + def ROCDL_SetPrioOp : ROCDL_ConcreteNonMemIntrOp<"s.setprio", [], 0, [0], ["priority"]>, Arguments<(ins I16Attr:$priority)> { let assemblyFormat = "$priority attr-dict"; @@ -548,6 +574,30 @@ def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_b def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>; def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_fp8", [1]>; def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x32.iu4", [1]>; +// Available from gfx1250 +def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x4.f32", [1]>; +def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.bf16", [1]>; +def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.f16", [1]>; +def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x32.f16", [1]>; +def ROCDL_wmma_bf16_16x16x32_bf16 : 
ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x32.bf16", [1]>; +def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16f32.16x16x32.bf16", [1,5]>; +def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_fp8", [0]>; +def ROCDL_wmma_f32_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_bf8", [0]>; +def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_fp8", [0]>; +def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_bf8", [0]>; +def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_fp8", [0]>; +def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_bf8", [0]>; +def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_fp8", [0]>; +def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_bf8", [0]>; +def ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_fp8", [0]>; +def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_bf8", [0]>; +def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_fp8", [0]>; +def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_bf8", [0]>; +def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_fp8", [0]>; +def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_bf8", [0]>; +def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_fp8", [0]>; +def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_bf8", [0]>; +def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x64.iu8", [1]>; //===---------------------------------------------------------------------===// // LDS transpose intrinsics (available in GFX950) @@ -1117,6 +1167,7 @@ foreach smallT = [ ScaleArgInfo<ROCDL_V16BF16Type, "Bf16">, ScaleArgInfo<ROCDL_V16F32Type, "F32">, ] in { + // Up-scaling def 
ROCDL_CvtPkScalePk16 # largeT.nameForOp # smallT.nameForOp # Op : ROCDL_ConcreteNonMemIntrOp<"cvt.scale.pk16." # largeT.name # "." # smallT.name, [Pure], 1, [2], ["scaleSel"]>, @@ -1132,6 +1183,42 @@ foreach smallT = [ }]; } + + // Down-scaling + def ROCDL_CvtScaleF32Pk16 # smallT.nameForOp # largeT.nameForOp # Op : + ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk16." # smallT.name # "." # largeT.name, + [Pure], 1>, + Arguments<(ins largeT.type:$src, F32:$scale)> { + let results = (outs smallT.type:$res); + let summary = "Scale and convert packed " + # largeT.name # " to packed " # smallT.name ; + let description = [{ + Convert 8 packed }] # largeT.name # [{ values to packed }] + # smallT.name # [{, multiplying by the exponent part of `scale` + before doing so. This op is for gfx1250+ arch. + }]; + let assemblyFormat = [{ + attr-dict $src `,` $scale `:` type($res) + }]; + } + + def ROCDL_CvtScaleF32SrPk16 # smallT.nameForOp # largeT.nameForOp # Op : + ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.sr.pk16." # smallT.name # "." # largeT.name, + [Pure], 1>, + Arguments<(ins largeT.type:$src, I32:$seed, F32:$scale)> { + let results = (outs smallT.type:$res); + let summary = "Scale and convert packed " + # largeT.name # " to packed " # smallT.name # " with stochastic rounding"; + let description = [{ + Convert 8 packed }] # largeT.name # [{ values to packed }] + # smallT.name # [{, multiplying by the exponent part of `scale` + before doing so and apply stochastic rounding. This op is for gfx1250+ arch. 
+ }]; + let assemblyFormat = [{ + attr-dict $src `,` $seed `,` $scale `:` type($res) + }]; + } + } // foreach largeT } // foreach smallTOp diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index ae7a085..c89fc59 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -25,7 +25,6 @@ #include "mlir/Interfaces/TilingInterface.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SmallBitVector.h" -#include "llvm/ADT/SmallSet.h" namespace mlir { namespace bufferization { @@ -621,35 +620,43 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, /// In the future, more general interfaces can be devised to encode similar /// shape evolutions and map between an op and its operands. SmallVector<OpFoldResult> -computePaddedShape(RewriterBase &rewriter, TypedValue<RankedTensorType> v, +computePaddedShape(OpBuilder &, TypedValue<RankedTensorType> v, AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes, const PadTilingInterfaceOptions &options); using PadSizeComputationFunction = std::function<FailureOr<SmallVector<OpFoldResult>>( - RewriterBase &, OpOperand &, ArrayRef<Range>, + OpBuilder &, OpOperand &, ArrayRef<Range>, const PadTilingInterfaceOptions &)>; /// Specific helper for Linalg ops. -FailureOr<SmallVector<OpFoldResult>> computeIndexingMapOpInterfacePaddedShape( - RewriterBase &rewriter, OpOperand &operandToPad, - ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options); +FailureOr<SmallVector<OpFoldResult>> +computeIndexingMapOpInterfacePaddedShape(OpBuilder &, OpOperand &operandToPad, + ArrayRef<Range> iterationDomain, + const PadTilingInterfaceOptions &); + +/// Operations and values created in the process of padding a TilingInterface +/// operation. +struct PadTilingInterfaceResult { + /// The operands of the padded op. 
+ SmallVector<tensor::PadOp> padOps; + /// The padded op, a clone of `toPad` with padded operands. + TilingInterface paddedOp; + /// Slices of the padded op's results, same types as `toPad`. + SmallVector<Value> replacements; +}; -/// Pad the iterator dimensions `options.paddingDimensions` of `opToPad`. -/// +/// Pad the iterator dimensions of `toPad`. /// * "options.paddingSizes" indicates that each padding dimension should be /// padded to the specified padding size. /// * "options.padToMultipleOf" indicates that the paddingSizes should be // interpreted as the bounding box (dynamic) value to pad to. /// * Use "options.paddingValues" to set the padding value of the created // tensor::PadOp. -/// * The tensor::PadOp is returned on success. - -FailureOr<TilingInterface> -rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad, - const PadTilingInterfaceOptions &constOptions, - SmallVector<tensor::PadOp> &padOps, - const PadSizeComputationFunction &computePaddingSizeFun = +FailureOr<PadTilingInterfaceResult> +rewriteAsPaddedOp(OpBuilder &, TilingInterface toPad, + PadTilingInterfaceOptions options, + const PadSizeComputationFunction & = &computeIndexingMapOpInterfacePaddedShape); namespace detail { diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h index 30f33ed..69447f7 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h @@ -17,6 +17,7 @@ #include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/InferIntRangeInterface.h" +#include "mlir/Interfaces/InferStridedMetadataInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/MemOpInterfaces.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index 40b7d7e..b39207f 100644 --- 
a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -14,6 +14,7 @@ include "mlir/Dialect/MemRef/IR/MemRefBase.td" include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferIntRangeInterface.td" +include "mlir/Interfaces/InferStridedMetadataInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/MemOpInterfaces.td" include "mlir/Interfaces/MemorySlotInterfaces.td" @@ -184,6 +185,7 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [ def DistinctObjectsOp : MemRef_Op<"distinct_objects", [ Pure, + DistinctObjectsTrait, DeclareOpInterfaceMethods<InferTypeOpInterface> // ViewLikeOpInterface TODO: ViewLikeOpInterface only supports a single argument ]> { @@ -2084,6 +2086,7 @@ def MemRef_StoreOp : MemRef_Op<"store", def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>, + DeclareOpInterfaceMethods<InferStridedMetadataOpInterface>, DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>, DeclareOpInterfaceMethods<ViewLikeOpInterface>, AttrSizedOperandSegments, diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h index 8f87235..b8aa497 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h @@ -183,6 +183,10 @@ static constexpr StringLiteral getRoutineInfoAttrName() { return StringLiteral("acc.routine_info"); } +static constexpr StringLiteral getVarNameAttrName() { + return VarNameAttr::name; +} + static constexpr StringLiteral getCombinedConstructsAttrName() { return CombinedConstructsTypeAttr::name; } diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index 77e833f..1eaa21b4 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ 
b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -415,6 +415,13 @@ def OpenACC_ConstructResource : Resource<"::mlir::acc::ConstructResource">; // Define a resource for the OpenACC current device setting. def OpenACC_CurrentDeviceIdResource : Resource<"::mlir::acc::CurrentDeviceIdResource">; +// Attribute for saving variable names - this can be attached to non-acc-dialect +// operations in order to ensure the name is preserved. +def OpenACC_VarNameAttr : OpenACC_Attr<"VarName", "var_name"> { + let parameters = (ins StringRefParameter<"">:$name); + let assemblyFormat = "`<` $name `>`"; +} + // Used for data specification in data clauses (2.7.1). // Either (or both) extent and upperbound must be specified. def OpenACC_DataBoundsOp : OpenACC_Op<"bounds", @@ -1316,6 +1323,24 @@ def OpenACC_PrivateRecipeOp }]; let hasRegionVerifier = 1; + + let extraClassDeclaration = [{ + /// Creates a PrivateRecipeOp and populates its regions based on the + /// variable type as long as the type implements MappableType or + /// PointerLikeType interface. If a type implements both, the MappableType + /// API will be preferred. Returns std::nullopt if the recipe cannot be + /// created or populated. The builder's current insertion point will be used + /// and it must be a valid place for this operation to be inserted. The + /// `recipeName` must be a unique name to prevent "redefinition of symbol" + /// IR errors. 
+ static std::optional<PrivateRecipeOp> createAndPopulate( + ::mlir::OpBuilder &builder, + ::mlir::Location loc, + ::llvm::StringRef recipeName, + ::mlir::Type varType, + ::llvm::StringRef varName = "", + ::mlir::ValueRange bounds = {}); + }]; } //===----------------------------------------------------------------------===// @@ -1410,6 +1435,24 @@ def OpenACC_FirstprivateRecipeOp }]; let hasRegionVerifier = 1; + + let extraClassDeclaration = [{ + /// Creates a FirstprivateRecipeOp and populates its regions based on the + /// variable type as long as the type implements MappableType or + /// PointerLikeType interface. If a type implements both, the MappableType + /// API will be preferred. Returns std::nullopt if the recipe cannot be + /// created or populated. The builder's current insertion point will be used + /// and it must be a valid place for this operation to be inserted. The + /// `recipeName` must be a unique name to prevent "redefinition of symbol" + /// IR errors. + static std::optional<FirstprivateRecipeOp> createAndPopulate( + ::mlir::OpBuilder &builder, + ::mlir::Location loc, + ::llvm::StringRef recipeName, + ::mlir::Type varType, + ::llvm::StringRef varName = "", + ::mlir::ValueRange bounds = {}); + }]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td index 0d16255..93e9e3d 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td @@ -73,17 +73,31 @@ def OpenACC_PointerLikeTypeInterface : TypeInterface<"PointerLikeType"> { InterfaceMethod< /*description=*/[{ Generates allocation operations for the pointer-like type. It will create - an allocate that produces memory space for an instance of the current type. + an allocate operation that produces memory space for an instance of the + current type. 
The `varName` parameter is optional and can be used to provide a name - for the allocated variable. If the current type is represented - in a way that it does not capture the pointee type, `varType` must be - passed in to provide the necessary type information. + for the allocated variable. When provided, it must be used by the + implementation; and if the implementing dialect does not have its own + way to save it, the discardable `acc.var_name` attribute from the acc + dialect will be used. + + If the current type is represented in a way that it does not capture + the pointee type, `varType` must be passed in to provide the necessary + type information. The `originalVar` parameter is optional but enables support for dynamic types (e.g., dynamic memrefs). When provided, implementations can extract runtime dimension information from the original variable to create - allocations with matching dynamic sizes. + allocations with matching dynamic sizes. When generating recipe bodies, + `originalVar` should be the block argument representing the original + variable in the recipe region. + + The `needsFree` output parameter indicates whether the allocated memory + requires explicit deallocation. Implementations should set this to true + for heap allocations that need a matching deallocation operation (e.g., + alloc) and false for stack-based allocations (e.g., alloca). During + recipe generation, this determines whether a destroy region is created. Returns a Value representing the result of the allocation. If no value is returned, it means the allocation was not successfully generated. 
@@ -94,7 +108,8 @@ def OpenACC_PointerLikeTypeInterface : TypeInterface<"PointerLikeType"> { "::mlir::Location":$loc, "::llvm::StringRef":$varName, "::mlir::Type":$varType, - "::mlir::Value":$originalVar), + "::mlir::Value":$originalVar, + "bool &":$needsFree), /*methodBody=*/"", /*defaultImplementation=*/[{ return {}; @@ -102,23 +117,34 @@ def OpenACC_PointerLikeTypeInterface : TypeInterface<"PointerLikeType"> { >, InterfaceMethod< /*description=*/[{ - Generates deallocation operations for the pointer-like type. It deallocates - the instance provided. + Generates deallocation operations for the pointer-like type. - The `varPtr` parameter is required and must represent an instance that was - previously allocated. If the current type is represented in a way that it - does not capture the pointee type, `varType` must be passed in to provide - the necessary type information. Nothing is generated in case the allocate - is `alloca`-like. + The `varToFree` parameter is required and must represent an instance + that was previously allocated. When generating recipe bodies, this + should be the block argument representing the private variable in the + destroy region. + + The `allocRes` parameter is optional and provides the result of the + corresponding allocation from the init region. This allows implementations + to inspect the allocation operation to determine the appropriate + deallocation strategy. This is necessary because in recipe generation, + the allocation and deallocation occur in separate regions. Dialects that + use only one allocation type or can determine deallocation from type + information alone may ignore this parameter. + + The `varType` parameter must be provided if the current type does not + capture the pointee type information. No deallocation is generated for + stack-based allocations (e.g., alloca). - Returns true if deallocation was successfully generated or successfully - deemed as not needed to be generated, false otherwise. 
+ Returns true if deallocation was successfully generated or determined to + be unnecessary, false otherwise. }], /*retTy=*/"bool", /*methodName=*/"genFree", /*args=*/(ins "::mlir::OpBuilder &":$builder, "::mlir::Location":$loc, - "::mlir::TypedValue<::mlir::acc::PointerLikeType>":$varPtr, + "::mlir::TypedValue<::mlir::acc::PointerLikeType>":$varToFree, + "::mlir::Value":$allocRes, "::mlir::Type":$varType), /*methodBody=*/"", /*defaultImplementation=*/[{ @@ -274,6 +300,14 @@ def OpenACC_MappableTypeInterface : TypeInterface<"MappableType"> { The `initVal` can be empty - it is primarily needed for reductions to ensure the variable is also initialized with appropriate value. + The `needsDestroy` out-parameter is set by implementations to indicate + that destruction code must be generated after the returned private + variable usages, typically in the destroy region of recipe operations + (for example, when heap allocations or temporaries requiring cleanup + are created during initialization). When `needsDestroy` is set, callers + should invoke `generatePrivateDestroy` in the recipe's destroy region + with the privatized value returned by this method. + If the return value is empty, it means that recipe body was not successfully generated. }], @@ -284,12 +318,38 @@ def OpenACC_MappableTypeInterface : TypeInterface<"MappableType"> { "::mlir::TypedValue<::mlir::acc::MappableType>":$var, "::llvm::StringRef":$varName, "::mlir::ValueRange":$extents, - "::mlir::Value":$initVal), + "::mlir::Value":$initVal, + "bool &":$needsDestroy), /*methodBody=*/"", /*defaultImplementation=*/[{ return {}; }] >, + InterfaceMethod< + /*description=*/[{ + Generates destruction operations for a privatized value previously + produced by `generatePrivateInit`. This is typically inserted in a + recipe's destroy region, after all uses of the privatized value. + + The `privatized` value is the SSA value yielded by the init region + (and passed as the privatized argument to the destroy region). 
+ Implementations should free heap-allocated storage or perform any + cleanup required for the given type. If no destruction is required, + this function should be a no-op and return `true`. + + Returns true if destruction was successfully generated or deemed not + necessary, false otherwise. + }], + /*retTy=*/"bool", + /*methodName=*/"generatePrivateDestroy", + /*args=*/(ins "::mlir::OpBuilder &":$builder, + "::mlir::Location":$loc, + "::mlir::Value":$privatized), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return true; + }] + >, ]; } diff --git a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td index 29b384f..5e68f75e 100644 --- a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td +++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td @@ -174,7 +174,7 @@ def Shard_NeighborsLinearIndicesOp : Shard_Op<"neighbors_linear_indices", [ ``` The above returns two indices, `633` and `693`, which correspond to the index of the previous process `(1, 1, 3)`, and the next process - `(1, 3, 3) along the split axis `1`. + `(1, 3, 3)` along the split axis `1`. A negative value is returned if there is no neighbor in the respective direction along the given `split_axes`. @@ -494,7 +494,9 @@ def Shard_AllGatherOp : Shard_CollectiveCommunicationOpBase<"all_gather", [ ]> { let summary = "All-gather over a device grid."; let description = [{ - Gathers along the `gather_axis` tensor axis. + Concatenates all tensor slices from a device group defined by `grid_axes` along + the tensor dimension `gather_axis` and replicates the result across all devices + in the group. Example: ```mlir @@ -546,10 +548,13 @@ def Shard_AllReduceOp : Shard_CollectiveCommunicationOpBase<"all_reduce", [ SameOperandsAndResultShape]> { let summary = "All-reduce over a device grid."; let description = [{ - The accumulation element type is specified by the result type and - it does not need to match the input element type. 
- The input element is converted to the result element type before - performing the reduction. + Reduces the input tensor across all devices within the groups defined by + `grid_axes`, using the specified reduction method. The operation performs an + element-wise reduction over the tensor slices from all devices in each group. + Each device in a group receives a replicated copy of the reduction result. + The accumulation element type is determined by the result type and does not + need to match the input element type. Before performing the reduction, each + input element is converted to the result element type. Attributes: `reduction`: Indicates the reduction method. @@ -583,13 +588,15 @@ def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [ SameOperandsAndResultElementType, SameOperandsAndResultRank ]> { - let summary = "All-slice over a device grid. This is the inverse of all-gather."; + let summary = "All-slice over a device grid."; let description = [{ - Slice along the `slice_axis` tensor axis. - This operation can be thought of as the inverse of all-gather. - Technically, it is not required that all processes have the same input tensor. - Each process will slice a piece of its local tensor based on its in-group device index. - The operation does not communicate data between devices. + Within each device group defined by `grid_axes`, slices the input tensor along + the `slice_axis` dimension. It can be viewed as the inverse of an all-gather if + the input data is replicated along the `slice_axis`. + Each process simply crops its local data to the slice corresponding to its + in-group device index. + Notice: `AllSliceOp` does not involve any communication between devices and + devices within a group may not have replicated input data. 
Example: ```mlir @@ -610,7 +617,7 @@ def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [ ``` Result: ``` - gather tensor + slice tensor axis 1 ------------> +-------+-------+ @@ -646,8 +653,10 @@ def Shard_AllToAllOp : Shard_CollectiveCommunicationOpBase<"all_to_all", [ SameOperandsAndResultRank]> { let summary = "All-to-all over a device grid."; let description = [{ - Performs an all-to-all on tensor pieces split along `split_axis`. - The resulting pieces are concatenated along `concat_axis` on ech device. + Each participant logically splits its input along split_axis, + then scatters the resulting pieces across the group defined by `grid_axes`. + After receiving data pieces from other participants' scatters, + it concatenates them along concat_axis to produce the final result. Example: ``` @@ -702,10 +711,9 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [ ]> { let summary = "Broadcast over a device grid."; let description = [{ - Broadcast the tensor on `root` to all devices in each respective group. - The operation broadcasts along grid axes `grid_axes`. - The `root` device specifies the in-group multi-index that is broadcast to - all other devices in the group. + Copies the input tensor on `root` to all devices in each group defined by + `grid_axes`. The `root` device is defined by its in-group multi-index. + The contents of input tensors on non-root devices are ignored. 
Example: ``` @@ -722,7 +730,7 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [ +-------+-------+ | broadcast device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1) | along axis 0 +-------+-------+ ↓ - device (1, 0) -> | | | <- device (1, 1) + device (1, 0) -> | * * | * * | <- device (1, 1) +-------+-------+ ``` @@ -758,11 +766,10 @@ def Shard_GatherOp : Shard_CollectiveCommunicationOpBase<"gather", [ ]> { let summary = "Gather over a device grid."; let description = [{ - Gathers on device `root` along the `gather_axis` tensor axis. - `root` specifies the coordinates of a device along `grid_axes`. - It uniquely identifies the root device for each device group. - The result tensor on non-root devices is undefined. - Using it will result in undefined behavior. + Concatenates all tensor slices from a device group defined by `grid_axes` along + the tensor dimension `gather_axis` and returns the resulting tensor on each + `root` device. The result on all other (non-root) devices is undefined. + The `root` device is defined by its in-group multi-index. Example: ```mlir @@ -821,7 +828,9 @@ def Shard_RecvOp : Shard_CollectiveCommunicationOpBase<"recv", [ ]> { let summary = "Send over a device grid."; let description = [{ - Receive from a device within a device group. + Receive tensor from device `source`, which is defined by its in-group + multi-index. The groups are defined by `grid_axes`. + The content of input tensor is ignored. }]; let arguments = !con(commonArgs, (ins AnyNon0RankedTensor:$input, @@ -845,13 +854,15 @@ def Shard_ReduceOp : Shard_CollectiveCommunicationOpBase<"reduce", [ ]> { let summary = "Reduce over a device grid."; let description = [{ - Reduces on device `root` within each device group. - `root` specifies the coordinates of a device along `grid_axes`. - It uniquely identifies the root device within its device group. 
- The accumulation element type is specified by the result type and - it does not need to match the input element type. - The input element is converted to the result element type before - performing the reduction. + Reduces the input tensor across all devices within the groups defined by + `grid_axes`, using the specified reduction method. The operation performs an + element-wise reduction over the tensor slices from all devices in each group. + The reduction result will be returned on the `root` device of each group. + It is undefined on all other (non-root) devices. + The `root` device is defined by its in-group multi-index. + The accumulation element type is determined by the result type and does not + need to match the input element type. Before performing the reduction, each + input element is converted to the result element type. Attributes: `reduction`: Indicates the reduction method. @@ -886,16 +897,18 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter" SameOperandsAndResultRank]> { let summary = "Reduce-scatter over a device grid."; let description = [{ - After the reduction, the result is scattered within each device group. - The tensor is split along `scatter_axis` and the pieces distributed - across the device group. + Reduces the input tensor across all devices within the groups defined by + `grid_axes` using the specified reduction method. The reduction is performed + element-wise across the tensor pieces from all devices in the group. + After reduction, the reduction result is scattered (split and distributed) + across the device group along `scatter_axis`. Example: ``` shard.grid @grid0(shape = 2x2) ... 
%1 = shard.reduce_scatter %0 on @grid0 grid_axes = [1] reduction = <max> scatter_axis = 0 - : tensor<3x4xf32> -> tensor<1x4xf64> + : tensor<2x2xf32> -> tensor<1x2xf64> ``` Input: ``` @@ -916,13 +929,13 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter" Result: ``` +-------+ - | 6 8 | <- devices (0, 0) + | 5 6 | <- devices (0, 0) +-------+ - | 10 12 | <- devices (0, 1) + | 7 8 | <- devices (0, 1) +-------+ - | 22 24 | <- devices (1, 0) + | 13 14 | <- devices (1, 0) +-------+ - | 26 28 | <- devices (1, 1) + | 15 16 | <- devices (1, 1) +-------+ ``` }]; @@ -950,8 +963,10 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [ ]> { let summary = "Scatter over a device grid."; let description = [{ - For each device group split the input tensor on the `root` device along - axis `scatter_axis` and scatter the parts across the group devices. + For each device group defined by `grid_axes`, the input tensor on the `root` + device is split along axis `scatter_axis` and distributed across the group. + The content of the input on all other (non-root) devices is ignored. + The `root` device is defined by its in-group multi-index. Example: ``` @@ -968,8 +983,8 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [ (0, 1) ↓ +-------+-------+ | scatter tensor - device (0, 0) -> | | | | axis 0 - | | | ↓ + device (0, 0) -> | * * | * * | | axis 0 + | * * | * * | ↓ +-------+-------+ device (1, 0) -> | 1 2 | 5 6 | | 3 4 | 7 8 | @@ -1018,7 +1033,8 @@ def Shard_SendOp : Shard_CollectiveCommunicationOpBase<"send", [ ]> { let summary = "Send over a device grid."; let description = [{ - Send from one device to another within a device group. + Send input tensor to device `destination`, which is defined by its in-group + multi-index. The groups are defined by `grid_axes`. 
}]; let arguments = !con(commonArgs, (ins AnyNon0RankedTensor:$input, @@ -1043,12 +1059,11 @@ def Shard_ShiftOp : Shard_CollectiveCommunicationOpBase<"shift", [ ]> { let summary = "Shift over a device grid."; let description = [{ - Within each device group shift along grid axis `shift_axis` by an offset - `offset`. - The result on devices that do not have a corresponding source is undefined. - `shift_axis` must be one of `grid_axes`. - If the `rotate` attribute is present, - instead of a shift a rotation is done. + Within each device group defined by `grid_axes`, shifts input tensors along the + device grid's axis `shift_axis` by the specified offset. The `shift_axis` must + be one of the `grid_axes`. If the `rotate` attribute is set, the shift is circular. + That is, the offset wraps around according to the group size along `shift_axis`. + Otherwise, the results on devices without a corresponding source are undefined. Example: ``` diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h index 10491f6..4ecf03c 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h @@ -50,28 +50,63 @@ TargetEnvAttr getDefaultTargetEnv(MLIRContext *context); /// returned by getDefaultTargetEnv() if not provided. TargetEnvAttr lookupTargetEnvOrDefault(Operation *op); +/// A thin wrapper around the SpecificationVersion enum to represent +/// and provide utilities around the TOSA specification version. 
+class TosaSpecificationVersion { +public: + TosaSpecificationVersion(uint32_t major, uint32_t minor) + : majorVersion(major), minorVersion(minor) {} + TosaSpecificationVersion(SpecificationVersion version) + : TosaSpecificationVersion(fromVersionEnum(version)) {} + + bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const { + return this->majorVersion == baseVersion.majorVersion && + this->minorVersion >= baseVersion.minorVersion; + } + + uint32_t getMajor() const { return majorVersion; } + uint32_t getMinor() const { return minorVersion; } + +private: + uint32_t majorVersion = 0; + uint32_t minorVersion = 0; + + static TosaSpecificationVersion + fromVersionEnum(SpecificationVersion version) { + switch (version) { + case SpecificationVersion::V_1_0: + return TosaSpecificationVersion(1, 0); + case SpecificationVersion::V_1_1_DRAFT: + return TosaSpecificationVersion(1, 1); + } + llvm_unreachable("Unknown TOSA version"); + } +}; + +llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version); + /// This class represents the capability enabled in the target implementation /// such as profile, extension, and level. It's a wrapper class around /// tosa::TargetEnvAttr. 
class TargetEnv { public: TargetEnv() {} - explicit TargetEnv(Level level, const ArrayRef<Profile> &profiles, + explicit TargetEnv(SpecificationVersion specificationVersion, Level level, + const ArrayRef<Profile> &profiles, const ArrayRef<Extension> &extensions) - : level(level) { + : specificationVersion(specificationVersion), level(level) { enabledProfiles.insert_range(profiles); enabledExtensions.insert_range(extensions); } explicit TargetEnv(TargetEnvAttr targetAttr) - : TargetEnv(targetAttr.getLevel(), targetAttr.getProfiles(), - targetAttr.getExtensions()) {} + : TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(), + targetAttr.getProfiles(), targetAttr.getExtensions()) {} void addProfile(Profile p) { enabledProfiles.insert(p); } void addExtension(Extension e) { enabledExtensions.insert(e); } - // TODO implement the following utilities. - // Version getSpecVersion() const; + SpecificationVersion getSpecVersion() const { return specificationVersion; } TosaLevel getLevel() const { if (level == Level::eightK) @@ -105,6 +140,7 @@ public: } private: + SpecificationVersion specificationVersion; Level level; llvm::SmallSet<Profile, 3> enabledProfiles; llvm::SmallSet<Extension, 13> enabledExtensions; diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc index 1f718ac..c1b5e78 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc @@ -2,441 +2,779 @@ // `tools/genspec.py` in https://git.mlplatform.org/tosa/specification.git profileComplianceMap = { {"tosa.argmax", - {{{Profile::pro_int}, {{i8T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, i32T}, {fp32T, i32T}}}}}, + {{{Profile::pro_int}, {{{i8T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, i32T}, SpecificationVersion::V_1_0}, + {{fp32T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.avg_pool2d", - {{{Profile::pro_int}, 
{{i8T, i8T, i8T, i32T, i8T}}}, + {{{Profile::pro_int}, + {{{i8T, i8T, i8T, i32T, i8T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, fp16T, fp16T, fp16T, fp16T}, - {fp16T, fp16T, fp16T, fp32T, fp16T}, - {fp32T, fp32T, fp32T, fp32T, fp32T}}}}}, + {{{fp16T, fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp16T, fp16T, fp16T, fp32T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.conv2d", - {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}}, + {{{Profile::pro_int}, + {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, - {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, - {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}}, + {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}, + SpecificationVersion::V_1_0}}}}}, {"tosa.conv3d", - {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}}, + {{{Profile::pro_int}, + {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, - {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, - {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}}, + {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}, + SpecificationVersion::V_1_0}}}}}, {"tosa.depthwise_conv2d", - {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}}, + {{{Profile::pro_int}, + {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, 
fp16T}, - {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, - {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}}, + {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}, + SpecificationVersion::V_1_0}}}}}, {"tosa.matmul", - {{{Profile::pro_int}, {{i8T, i8T, i8T, i8T, i32T}}}, + {{{Profile::pro_int}, + {{{i8T, i8T, i8T, i8T, i32T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, fp16T, fp16T, fp16T, fp16T}, - {fp16T, fp16T, fp16T, fp16T, fp32T}, - {fp32T, fp32T, fp32T, fp32T, fp32T}}}}}, + {{{fp16T, fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp16T, fp16T, fp16T, fp16T, fp32T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.max_pool2d", - {{{Profile::pro_int}, {{i8T, i8T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, {{{i8T, i8T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.transpose_conv2d", - {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}}, + {{{Profile::pro_int}, + {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, - {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, - {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}}, + {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}, + SpecificationVersion::V_1_0}}}}}, {"tosa.clamp", - {{{Profile::pro_int}, {{i8T, i8T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.erf", {{{Profile::pro_fp}, 
{{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.sigmoid", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.tanh", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, {{{i8T, i8T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.erf", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.sigmoid", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.tanh", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.add", - {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf}, - {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.arithmetic_right_shift", {{{Profile::pro_int}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.bitwise_and", {{{Profile::pro_int}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.bitwise_or", {{{Profile::pro_int}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, 
SpecificationVersion::V_1_0}}}}}, {"tosa.bitwise_xor", {{{Profile::pro_int}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.intdiv", - {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}, + anyOf}}}, {"tosa.logical_and", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}}}, {"tosa.logical_left_shift", {{{Profile::pro_int, Profile::pro_fp}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}, anyOf}}}, {"tosa.logical_right_shift", {{{Profile::pro_int, Profile::pro_fp}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}, anyOf}}}, {"tosa.logical_or", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}}}, {"tosa.logical_xor", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}}}, {"tosa.maximum", - {{{Profile::pro_int}, {{i32T, i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {{{Profile::pro_int}, {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, 
SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.minimum", - {{{Profile::pro_int}, {{i32T, i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {{{Profile::pro_int}, {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.mul", - {{{Profile::pro_int}, {{i8T, i8T, i32T}, {i16T, i16T, i32T}}}, - {{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf}, - {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {{{Profile::pro_int}, + {{{i8T, i8T, i32T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_int, Profile::pro_fp}, + {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.pow", - {{{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {{{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.sub", - {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf}, - {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, - {"tosa.table", {{{Profile::pro_int}, {{i8T, i8T, i8T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.table", + {{{Profile::pro_int}, {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}}}}}, {"tosa.abs", - {{{Profile::pro_int}, {{i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, {{{i32T, i32T}, 
SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.bitwise_not", - {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}}}, - {"tosa.ceil", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.clz", {{{Profile::pro_int}, {{i32T, i32T}}}}}, - {"tosa.cos", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.exp", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.floor", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.log", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.ceil", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.clz", + {{{Profile::pro_int}, {{{i32T, i32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.cos", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.exp", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.floor", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.log", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.logical_not", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}}}, {"tosa.negate", {{{Profile::pro_int}, - {{i8T, i8T, i8T, i8T}, - {i16T, i16T, i16T, i16T}, - {i32T, i32T, i32T, i32T}}}, + {{{i8T, i8T, i8T, i8T}, 
SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T, fp32T}}}}}, + {{{fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reciprocal", - {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.rsqrt", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.sin", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.rsqrt", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.sin", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.select", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, {{Profile::pro_int}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.equal", - {{{Profile::pro_int}, {{i32T, i32T, boolT}}}, - {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}}, + {{{Profile::pro_int}, + {{{i32T, i32T, boolT}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, boolT}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, boolT}, 
SpecificationVersion::V_1_0}}}}}, {"tosa.greater", - {{{Profile::pro_int}, {{i32T, i32T, boolT}}}, - {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}}, + {{{Profile::pro_int}, + {{{i32T, i32T, boolT}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, boolT}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, boolT}, SpecificationVersion::V_1_0}}}}}, {"tosa.greater_equal", - {{{Profile::pro_int}, {{i32T, i32T, boolT}}}, - {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}}, + {{{Profile::pro_int}, + {{{i32T, i32T, boolT}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, boolT}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, boolT}, SpecificationVersion::V_1_0}}}}}, {"tosa.reduce_all", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}}}, {"tosa.reduce_any", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}}}, {"tosa.reduce_max", - {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reduce_min", - {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, 
SpecificationVersion::V_1_0}}}}}, {"tosa.reduce_product", - {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reduce_sum", - {{{Profile::pro_int}, {{i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, {{{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.concat", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}, - {{Profile::pro_int}, {{i8T, i8T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.pad", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, {{Profile::pro_int}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reshape", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}, - {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, 
+ {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reverse", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}, - {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.slice", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}, - {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.tile", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}, - {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, 
SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.transpose", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}, - {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.gather", {{{Profile::pro_int}, - {{i8T, i32T, i8T}, {i16T, i32T, i16T}, {i32T, i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, i32T, fp16T}, {fp32T, i32T, fp32T}}}}}, + {{{i8T, i32T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i32T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, i32T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, i32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.scatter", {{{Profile::pro_int}, - {{i8T, i32T, i8T, i8T}, - {i16T, i32T, i16T, i16T}, - {i32T, i32T, i32T, i32T}}}, + {{{i8T, i32T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i32T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, i32T, fp16T, fp16T}, {fp32T, i32T, fp32T, fp32T}}}}}, + {{{fp16T, i32T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, i32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.resize", - {{{Profile::pro_int}, {{i8T, i32T}, {i8T, i8T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, + {{{i8T, i32T}, SpecificationVersion::V_1_0}, + {{i8T, i8T}, SpecificationVersion::V_1_0}}}, + 
{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.cast", {{{Profile::pro_int}, - {{boolT, i8T}, - {boolT, i16T}, - {boolT, i32T}, - {i8T, boolT}, - {i8T, i16T}, - {i8T, i32T}, - {i16T, boolT}, - {i16T, i8T}, - {i16T, i32T}, - {i32T, boolT}, - {i32T, i8T}, - {i32T, i16T}}}, - {{Profile::pro_fp}, - {{i8T, fp16T}, - {i8T, fp32T}, - {i16T, fp16T}, - {i16T, fp32T}, - {i32T, fp16T}, - {i32T, fp32T}, - {fp16T, i8T}, - {fp16T, i16T}, - {fp16T, i32T}, - {fp16T, fp32T}, - {fp32T, i8T}, - {fp32T, i16T}, - {fp32T, i32T}, - {fp32T, fp16T}}}}}, + {{{boolT, i8T}, SpecificationVersion::V_1_0}, + {{boolT, i16T}, SpecificationVersion::V_1_0}, + {{boolT, i32T}, SpecificationVersion::V_1_0}, + {{i8T, boolT}, SpecificationVersion::V_1_0}, + {{i8T, i16T}, SpecificationVersion::V_1_0}, + {{i8T, i32T}, SpecificationVersion::V_1_0}, + {{i16T, boolT}, SpecificationVersion::V_1_0}, + {{i16T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i32T}, SpecificationVersion::V_1_0}, + {{i32T, boolT}, SpecificationVersion::V_1_0}, + {{i32T, i8T}, SpecificationVersion::V_1_0}, + {{i32T, i16T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{i8T, fp16T}, SpecificationVersion::V_1_0}, + {{i8T, fp32T}, SpecificationVersion::V_1_0}, + {{i16T, fp16T}, SpecificationVersion::V_1_0}, + {{i16T, fp32T}, SpecificationVersion::V_1_0}, + {{i32T, fp16T}, SpecificationVersion::V_1_0}, + {{i32T, fp32T}, SpecificationVersion::V_1_0}, + {{fp16T, i8T}, SpecificationVersion::V_1_0}, + {{fp16T, i16T}, SpecificationVersion::V_1_0}, + {{fp16T, i32T}, SpecificationVersion::V_1_0}, + {{fp16T, fp32T}, SpecificationVersion::V_1_0}, + {{fp32T, i8T}, SpecificationVersion::V_1_0}, + {{fp32T, i16T}, SpecificationVersion::V_1_0}, + {{fp32T, i32T}, SpecificationVersion::V_1_0}, + {{fp32T, fp16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.rescale", {{{Profile::pro_int}, - {{i8T, i8T, i8T, i8T}, - {i8T, i8T, i16T, i16T}, - {i8T, i8T, i32T, i32T}, - 
{i16T, i16T, i8T, i8T}, - {i16T, i16T, i16T, i16T}, - {i16T, i16T, i32T, i32T}, - {i32T, i32T, i8T, i8T}, - {i32T, i32T, i16T, i16T}, - {i32T, i32T, i32T, i32T}}}}}, + {{{i8T, i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i8T, i8T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i32T, i32T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.const", {{{Profile::pro_int, Profile::pro_fp}, - {{boolT}, {i8T}, {i16T}, {i32T}}, + {{{boolT}, SpecificationVersion::V_1_0}, + {{i8T}, SpecificationVersion::V_1_0}, + {{i16T}, SpecificationVersion::V_1_0}, + {{i32T}, SpecificationVersion::V_1_0}}, anyOf}, - {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}}, + {{Profile::pro_fp}, + {{{fp16T}, SpecificationVersion::V_1_0}, + {{fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.identity", {{{Profile::pro_int, Profile::pro_fp}, - {{boolT, boolT}, {i8T, i8T}, {i16T, i16T}, {i32T, i32T}}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}, + {{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}, anyOf}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.variable", - {{{Profile::pro_int}, {{i8T}}}, {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}}, + {{{Profile::pro_int}, {{{i8T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T}, SpecificationVersion::V_1_0}, + {{fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.variable_write", - {{{Profile::pro_int}, {{i8T}}}, {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}}, + 
{{{Profile::pro_int}, {{{i8T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T}, SpecificationVersion::V_1_0}, + {{fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.variable_read", - {{{Profile::pro_int}, {{i8T}}}, {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}}, + {{{Profile::pro_int}, {{{i8T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T}, SpecificationVersion::V_1_0}, + {{fp32T}, SpecificationVersion::V_1_0}}}}}, }; extensionComplianceMap = { {"tosa.argmax", - {{{Extension::int16}, {{i16T, i32T}}}, - {{Extension::fp8e4m3}, {{fp8e4m3T, i32T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, i32T}}}, - {{Extension::bf16}, {{bf16T, i32T}}}}}, + {{{Extension::int16}, {{{i16T, i32T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e4m3}, {{{fp8e4m3T, i32T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, {{{fp8e5m2T, i32T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.avg_pool2d", - {{{Extension::int16}, {{i16T, i16T, i16T, i32T, i16T}}}, - {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}}, + {{{Extension::int16}, + {{{i16T, i16T, i16T, i32T, i16T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T}, + SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T}, + SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, + {{{bf16T, bf16T, bf16T, fp32T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.conv2d", - {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}}, - {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}}, + {{{Extension::int4}, + {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Extension::int16}, + {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T}, + 
SpecificationVersion::V_1_0}}}, {{Extension::fp8e4m3}, - {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}}, + {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e5m2}, - {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}}, + {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::bf16}, - {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}}, + {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}, + SpecificationVersion::V_1_0}}}}}, {"tosa.conv3d", - {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}}, - {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}}, + {{{Extension::int4}, + {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Extension::int16}, + {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e4m3}, - {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}}, + {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e5m2}, - {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}}, + {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::bf16}, - {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}}, + {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}, + SpecificationVersion::V_1_0}}}}}, {"tosa.depthwise_conv2d", - {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}}, - {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}}, + {{{Extension::int4}, + {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Extension::int16}, + {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e4m3}, - {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}}, + {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, 
fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e5m2}, - {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}}, + {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::bf16}, - {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}}, - {"tosa.fft2d", {{{Extension::fft}, {{fp32T, fp32T, fp32T, fp32T}}}}}, + {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}, + SpecificationVersion::V_1_0}}}}}, + {"tosa.fft2d", + {{{Extension::fft}, + {{{fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.matmul", - {{{Extension::int16}, {{i16T, i16T, i16T, i16T, i48T}}}, + {{{Extension::int16}, + {{{i16T, i16T, i16T, i16T, i48T}, SpecificationVersion::V_1_0}}}, {{Extension::fp8e4m3}, - {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T}, - {fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp32T}}}, + {{{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp32T}, + SpecificationVersion::V_1_1_DRAFT}}}, {{Extension::fp8e5m2}, - {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T}, - {fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp32T}}}, + {{{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp32T}, + SpecificationVersion::V_1_1_DRAFT}}}, {{Extension::fp8e4m3, Extension::fp8e5m2}, - {{fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp16T}, - {fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp32T}, - {fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp16T}, - {fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp32T}}, + {{{fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp16T}, + SpecificationVersion::V_1_1_DRAFT}, + {{fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp32T}, + SpecificationVersion::V_1_1_DRAFT}, + {{fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp16T}, + SpecificationVersion::V_1_1_DRAFT}, + {{fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp32T}, + SpecificationVersion::V_1_1_DRAFT}}, allOf}, - 
{{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T, fp32T}}}}}, + {{Extension::bf16}, + {{{bf16T, bf16T, bf16T, bf16T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.max_pool2d", - {{{Extension::int16}, {{i16T, i16T}}}, - {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.rfft2d", {{{Extension::fft}, {{fp32T, fp32T, fp32T}}}}}, + {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.rfft2d", + {{{Extension::fft}, + {{{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.transpose_conv2d", - {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}}, - {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}}, + {{{Extension::int4}, + {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Extension::int16}, + {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e4m3}, - {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}}, + {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e5m2}, - {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}}, + {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::bf16}, - {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}}, + {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}, + SpecificationVersion::V_1_0}}}}}, {"tosa.clamp", - {{{Extension::int16}, {{i16T, i16T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.erf", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.sigmoid", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.tanh", 
{{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.add", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, - {"tosa.maximum", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, - {"tosa.minimum", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, - {"tosa.mul", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, - {"tosa.pow", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, - {"tosa.sub", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, - {"tosa.table", {{{Extension::int16}, {{i16T, i16T, i32T}}}}}, - {"tosa.abs", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.ceil", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.cos", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.exp", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.floor", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.log", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T}}}}}, - {"tosa.reciprocal", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.rsqrt", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.sin", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.select", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, - {"tosa.equal", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}}, - {"tosa.greater", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}}, - {"tosa.greater_equal", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}}, - {"tosa.reduce_max", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.reduce_min", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.reduce_product", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.reduce_sum", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.erf", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.sigmoid", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.tanh", + 
{{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.add", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.maximum", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.minimum", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.mul", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.pow", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.sub", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.table", + {{{Extension::int16}, + {{{i16T, i16T, i32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.abs", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.ceil", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.cos", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.exp", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.floor", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.log", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.negate", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.reciprocal", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.rsqrt", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.sin", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.select", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.equal", + {{{Extension::bf16}, + {{{bf16T, bf16T, boolT}, SpecificationVersion::V_1_0}}}}}, + {"tosa.greater", + {{{Extension::bf16}, + {{{bf16T, bf16T, boolT}, 
SpecificationVersion::V_1_0}}}}}, + {"tosa.greater_equal", + {{{Extension::bf16}, + {{{bf16T, bf16T, boolT}, SpecificationVersion::V_1_0}}}}}, + {"tosa.reduce_max", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.reduce_min", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.reduce_product", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.reduce_sum", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.concat", - {{{Extension::int16}, {{i16T, i16T}}}, - {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.pad", - {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reshape", - {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reverse", - {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, 
{{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.slice", - {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.tile", - {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.transpose", - {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.gather", - {{{Extension::fp8e4m3}, {{fp8e4m3T, i32T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, i32T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, i32T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, i32T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, i32T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, + {{{bf16T, i32T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.scatter", - {{{Extension::fp8e4m3}, 
{{fp8e4m3T, i32T, fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, i32T, fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, i32T, bf16T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, i32T, fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, i32T, fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, + {{{bf16T, i32T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.resize", - {{{Extension::int16}, {{i16T, i48T}, {i16T, i16T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::int16}, + {{{i16T, i48T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.cast", {{{Extension::bf16}, - {{i8T, bf16T}, - {i16T, bf16T}, - {i32T, bf16T}, - {bf16T, i8T}, - {bf16T, i16T}, - {bf16T, i32T}, - {bf16T, fp32T}, - {fp32T, bf16T}}}, + {{{i8T, bf16T}, SpecificationVersion::V_1_0}, + {{i16T, bf16T}, SpecificationVersion::V_1_0}, + {{i32T, bf16T}, SpecificationVersion::V_1_0}, + {{bf16T, i8T}, SpecificationVersion::V_1_0}, + {{bf16T, i16T}, SpecificationVersion::V_1_0}, + {{bf16T, i32T}, SpecificationVersion::V_1_0}, + {{bf16T, fp32T}, SpecificationVersion::V_1_0}, + {{fp32T, bf16T}, SpecificationVersion::V_1_0}}}, {{Extension::bf16, Extension::fp8e4m3}, - {{bf16T, fp8e4m3T}, {fp8e4m3T, bf16T}}, + {{{bf16T, fp8e4m3T}, SpecificationVersion::V_1_0}, + {{fp8e4m3T, bf16T}, SpecificationVersion::V_1_0}}, allOf}, {{Extension::bf16, Extension::fp8e5m2}, - {{bf16T, fp8e5m2T}, {fp8e5m2T, bf16T}}, + {{{bf16T, fp8e5m2T}, SpecificationVersion::V_1_0}, + {{fp8e5m2T, bf16T}, SpecificationVersion::V_1_0}}, allOf}, {{Extension::fp8e4m3}, - {{fp8e4m3T, fp16T}, - {fp8e4m3T, fp32T}, - {fp16T, fp8e4m3T}, - {fp32T, fp8e4m3T}}}, + {{{fp8e4m3T, fp16T}, SpecificationVersion::V_1_0}, + {{fp8e4m3T, fp32T}, SpecificationVersion::V_1_0}, + {{fp16T, fp8e4m3T}, SpecificationVersion::V_1_0}, + {{fp32T, 
fp8e4m3T}, SpecificationVersion::V_1_0}}}, {{Extension::fp8e5m2}, - {{fp8e5m2T, fp16T}, - {fp8e5m2T, fp32T}, - {fp16T, fp8e5m2T}, - {fp32T, fp8e5m2T}}}}}, + {{{fp8e5m2T, fp16T}, SpecificationVersion::V_1_0}, + {{fp8e5m2T, fp32T}, SpecificationVersion::V_1_0}, + {{fp16T, fp8e5m2T}, SpecificationVersion::V_1_0}, + {{fp32T, fp8e5m2T}, SpecificationVersion::V_1_0}}}}}, {"tosa.rescale", {{{Extension::int16}, - {{i48T, i48T, i8T, i8T}, - {i48T, i48T, i16T, i16T}, - {i48T, i48T, i32T, i32T}}}}}, + {{{i48T, i48T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i48T, i48T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i48T, i48T, i32T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.const", - {{{Extension::int4}, {{i4T}}}, - {{Extension::int16}, {{i48T}}}, - {{Extension::fp8e4m3}, {{fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T}}}}}, + {{{Extension::int4}, {{{i4T}, SpecificationVersion::V_1_0}}}, + {{Extension::int16}, {{{i48T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e4m3}, {{{fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, {{{fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.identity", - {{{Extension::int4}, {{i4T, i4T}}}, - {{Extension::int16}, {{i48T, i48T}}}, - {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.variable", {{{Extension::variable}, {{i8T}, {fp16T}, {fp32T}}}}}, + {{{Extension::int4}, {{{i4T, i4T}, SpecificationVersion::V_1_0}}}, + {{Extension::int16}, {{{i48T, i48T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.variable", + {{{Extension::variable}, + {{{i8T}, SpecificationVersion::V_1_0}, + 
{{fp16T}, SpecificationVersion::V_1_0}, + {{fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.variable_write", - {{{Extension::variable}, {{i8T}, {fp16T}, {fp32T}}}}}, + {{{Extension::variable}, + {{{i8T}, SpecificationVersion::V_1_0}, + {{fp16T}, SpecificationVersion::V_1_0}, + {{fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.variable_read", - {{{Extension::variable}, {{i8T}, {fp16T}, {fp32T}}}}}, + {{{Extension::variable}, + {{{i8T}, SpecificationVersion::V_1_0}, + {{fp16T}, SpecificationVersion::V_1_0}, + {{fp32T}, SpecificationVersion::V_1_0}}}}}, }; + // End of auto-generated metadata diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td index 38cb293..8376a4c 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -221,7 +221,7 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic, } //===----------------------------------------------------------------------===// -// TOSA Spec Section 1.5. +// TOSA Profiles and extensions // // Profile: // INT : Integer Inference. Integer operations, primarily 8 and 32-bit values. @@ -293,12 +293,6 @@ def Tosa_ExtensionAttr def Tosa_ExtensionArrayAttr : TypedArrayAttrBase<Tosa_ExtensionAttr, "TOSA extension array attribute">; -def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>; -def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">; - -def Tosa_LevelAttr - : Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>; - // The base class for defining op availability dimensions. class Availability { // The following are fields for controlling the generated C++ OpInterface. 
@@ -405,17 +399,40 @@ class Extension<list<I32EnumAttrCase> extensions> : Availability { } //===----------------------------------------------------------------------===// +// TOSA Levels +//===----------------------------------------------------------------------===// + +def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>; +def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">; + +def Tosa_LevelAttr + : Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>; + +//===----------------------------------------------------------------------===// +// TOSA Specification versions +//===----------------------------------------------------------------------===// + +def Tosa_V_1_0 : I32EnumAttrCase<"V_1_0", 0, "1.0">; +def Tosa_V_1_1_DRAFT : I32EnumAttrCase<"V_1_1_DRAFT", 1, "1.1.draft">; + +def Tosa_SpecificationVersion : Tosa_I32EnumAttr< + "SpecificationVersion", "TOSA specification version", "specification_version", + [Tosa_V_1_0, Tosa_V_1_1_DRAFT]>; + +//===----------------------------------------------------------------------===// // TOSA target environment. 
//===----------------------------------------------------------------------===// def Tosa_TargetEnv : Tosa_Attr<"TargetEnv", "target_env"> { let summary = "Target environment information."; let parameters = ( ins + "SpecificationVersion": $specification_version, "Level": $level, ArrayRefParameter<"Profile">: $profiles, ArrayRefParameter<"Extension">: $extensions ); - let assemblyFormat = "`<` `level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` " + let assemblyFormat = "`<` `specification_version` `=` $specification_version `,` " + "`level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` " "`extensions` `=` `[` $extensions `]` `>`"; } diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h index 8f5c72b..7b946ad 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h @@ -36,12 +36,15 @@ enum CheckCondition { allOf }; +using VersionedTypeInfo = + std::pair<SmallVector<TypeInfo>, SpecificationVersion>; + template <typename T> struct OpComplianceInfo { // Certain operations require multiple modes enabled. // e.g. cast bf16 to fp8e4m3 requires EXT-BF16 and EXT-FP8E4M3. SmallVector<T> mode; - SmallVector<SmallVector<TypeInfo>> operandTypeInfoSet; + SmallVector<VersionedTypeInfo> operandTypeInfoSet; CheckCondition condition = CheckCondition::anyOf; }; @@ -130,9 +133,8 @@ public: // Find the required profiles or extensions from the compliance info according // to the operand type combination. 
template <typename T> - SmallVector<T> findMatchedProfile(Operation *op, - SmallVector<OpComplianceInfo<T>> compInfo, - CheckCondition &condition); + OpComplianceInfo<T> + findMatchedEntry(Operation *op, SmallVector<OpComplianceInfo<T>> compInfo); SmallVector<Profile> getCooperativeProfiles(Extension ext) { switch (ext) { @@ -168,8 +170,7 @@ public: private: template <typename T> - FailureOr<SmallVector<T>> getOperatorDefinition(Operation *op, - CheckCondition &condition); + FailureOr<OpComplianceInfo<T>> getOperatorDefinition(Operation *op); OperationProfileComplianceMap profileComplianceMap; OperationExtensionComplianceMap extensionComplianceMap; diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td index 6ae19d8..14b00b0 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -137,6 +137,13 @@ def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> { ]; let options = [ + Option<"specificationVersion", "specification_version", "mlir::tosa::SpecificationVersion", + /*default=*/"mlir::tosa::SpecificationVersion::V_1_0", + "The specification version that TOSA operators should conform to.", + [{::llvm::cl::values( + clEnumValN(mlir::tosa::SpecificationVersion::V_1_0, "1.0", "TOSA Specification version 1.0"), + clEnumValN(mlir::tosa::SpecificationVersion::V_1_1_DRAFT, "1.1.draft", "TOSA Specification version 1.1.draft") + )}]>, Option<"level", "level", "mlir::tosa::Level", /*default=*/"mlir::tosa::Level::eightK", "The TOSA level that operators should conform to. 
A TOSA level defines " diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 6e79085..6e15b1e 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2999,6 +2999,7 @@ def Vector_StepOp : Vector_Op<"step", [ }]; let results = (outs VectorOfRankAndType<[1], [Index]>:$result); let assemblyFormat = "attr-dict `:` type($result)"; + let hasCanonicalizer = 1; } def Vector_YieldOp : Vector_Op<"yield", [ diff --git a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td index b80ee2c..e9425e8 100644 --- a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td +++ b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td @@ -43,9 +43,41 @@ class WasmSSA_BlockLikeOp<string mnemonic, string summaryStr> : let assemblyFormat = "(`(`$inputs^`)` `:` type($inputs))? attr-dict `:` $body `>` $target"; } -def WasmSSA_BlockOp : WasmSSA_BlockLikeOp<"block", "Create a nesting level"> {} +def WasmSSA_BlockOp : WasmSSA_BlockLikeOp< + "block", + "Create a nesting level with a label at its exit."> { + let description = [{ + Defines a Wasm block, creating a new nested scope. + A block contains a body region and an optional list of input values. + Control can enter the block and later branch out to the block target. + Example: + + ```mlir + + wasmssa.block { + + // instructions + + } > ^successor + }]; +} + +def WasmSSA_LoopOp : WasmSSA_BlockLikeOp< + "loop", + "Create a nesting level that define its entry as jump target."> { + let description = [{ + Represents a Wasm loop construct. This defines a nesting level with + a label at the entry of the region. 
-def WasmSSA_LoopOp : WasmSSA_BlockLikeOp<"loop", "Create a nesting level similar to Block Op, except that it has itself as a successor."> {}
+    Example:
+
+    ```mlir
+
+    wasmssa.loop {
+
+    } > ^successor
+  }];
+}
 
 def WasmSSA_BlockReturnOp : WasmSSA_Op<"block_return", [Terminator,
     DeclareOpInterfaceMethods<LabelBranchingOpInterface>]> {
@@ -55,9 +87,16 @@ def WasmSSA_BlockReturnOp : WasmSSA_Op<"block_return", [Terminator,
     ::mlir::Block* getTarget();
   }];
   let description = [{
-    Marks a return from the current block.
+    Escape from the current nesting level and return the control flow to its successor.
+    Optionally, mark the arguments that should be transferred to the successor block.
 
-    Example:
+    This shouldn't be confused with branch operations that target the label defined
+    by the nesting level operation.
+
+    For instance, a `wasmssa.block_return` in a loop will give back control to the
+    successor of the loop, whereas a `branch` targeting the loop will flow back to the entry block of the loop.
+
+    Example:
 
     ```mlir
     wasmssa.block_return
@@ -127,12 +166,18 @@ def WasmSSA_FuncOp : WasmSSA_Op<"func", [
     - Arguments of the entry block of type `!wasm<local T>`, with T the
       corresponding type in the function type.
 
+    By default, `wasmssa.func` has nested visibility. Functions exported by the module
+    are marked with the exported attribute. This gives them public visibility.
+ Example: ```mlir - // A simple function with no arguments that returns a float32 + // Internal function with no arguments that returns a float32 wasmssa.func @my_f32_func() -> f32 + // Exported function with no arguments that returns a float32 + wasmssa.func exported @my_f32_func() -> f32 + // A function that takes a local ref argument wasmssa.func @i64_wrap(%a: !wasmssa<local ref to i64>) -> i32 ``` @@ -141,7 +186,7 @@ def WasmSSA_FuncOp : WasmSSA_Op<"func", [ WasmSSA_FuncTypeAttr: $functionType, OptionalAttr<DictArrayAttr>:$arg_attrs, OptionalAttr<DictArrayAttr>:$res_attrs, - DefaultValuedAttr<StrAttr, "\"nested\"">:$sym_visibility); + UnitAttr: $exported); let regions = (region AnyRegion: $body); let extraClassDeclaration = [{ @@ -162,6 +207,12 @@ def WasmSSA_FuncOp : WasmSSA_Op<"func", [ /// Returns the result types of this function. ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); } + + ::mlir::SymbolTable::Visibility getVisibility() { + return getExported() ? 
+ ::mlir::SymbolTable::Visibility::Public : + ::mlir::SymbolTable::Visibility::Nested; + }; }]; let builders = [ @@ -207,8 +258,7 @@ def WasmSSA_FuncImportOp : WasmSSA_Op<"import_func", [ StrAttr: $importName, WasmSSA_FuncTypeAttr: $type, OptionalAttr<DictArrayAttr>:$arg_attrs, - OptionalAttr<DictArrayAttr>:$res_attrs, - OptionalAttr<StrAttr>:$sym_visibility); + OptionalAttr<DictArrayAttr>:$res_attrs); let extraClassDeclaration = [{ bool isDeclaration() const { return true; } @@ -221,6 +271,10 @@ def WasmSSA_FuncImportOp : WasmSSA_Op<"import_func", [ ::llvm::ArrayRef<Type> getResultTypes() { return getType().getResults(); } + + ::mlir::SymbolTable::Visibility getVisibility() { + return ::mlir::SymbolTable::Visibility::Nested; + }; }]; let builders = [ OpBuilder<(ins "StringRef":$symbol, @@ -238,30 +292,41 @@ def WasmSSA_GlobalOp : WasmSSA_Op<"global", [ let arguments = (ins SymbolNameAttr: $sym_name, WasmSSA_ValTypeAttr: $type, UnitAttr: $isMutable, - OptionalAttr<StrAttr>:$sym_visibility); + UnitAttr: $exported); let description = [{ WebAssembly global variable. Body contains the initialization instructions for the variable value. The body must contain only instructions considered `const` in a webassembly context, such as `wasmssa.const` or `global.get`. + By default, `wasmssa.global` have nested visibility. Global exported by the module + are marked with the exported attribute. This gives them public visibility. + Example: ```mlir - // Define a global_var, a mutable i32 global variable equal to 10. - wasmssa.global @global_var i32 mutable nested : { + // Define module_global_var, an internal mutable i32 global variable equal to 10. + wasmssa.global @module_global_var i32 mutable : { %[[VAL_0:.*]] = wasmssa.const 10 : i32 wasmssa.return %[[VAL_0]] : i32 } + + // Define global_var, an exported constant i32 global variable equal to 42. 
+ wasmssa.global @global_var i32 : { + %[[VAL_0:.*]] = wasmssa.const 42 : i32 + wasmssa.return %[[VAL_0]] : i32 + } ``` }]; let regions = (region AnyRegion: $initializer); - let builders = [ - OpBuilder<(ins "StringRef":$symbol, - "Type": $type, - "bool": $isMutable)> - ]; + let extraClassDeclaration = [{ + ::mlir::SymbolTable::Visibility getVisibility() { + return getExported() ? + ::mlir::SymbolTable::Visibility::Public : + ::mlir::SymbolTable::Visibility::Nested; + }; + }]; let hasCustomAssemblyFormat = 1; } @@ -283,18 +348,14 @@ def WasmSSA_GlobalImportOp : WasmSSA_Op<"import_global", [ StrAttr: $moduleName, StrAttr: $importName, WasmSSA_ValTypeAttr: $type, - UnitAttr: $isMutable, - OptionalAttr<StrAttr>:$sym_visibility); + UnitAttr: $isMutable); let extraClassDeclaration = [{ bool isDeclaration() const { return true; } + + ::mlir::SymbolTable::Visibility getVisibility() { + return ::mlir::SymbolTable::Visibility::Nested; + }; }]; - let builders = [ - OpBuilder<(ins "StringRef":$symbol, - "StringRef":$moduleName, - "StringRef":$importName, - "Type": $type, - "bool": $isMutable)> - ]; let hasCustomAssemblyFormat = 1; } @@ -442,23 +503,33 @@ def WasmSSA_MemOp : WasmSSA_Op<"memory", [Symbol]> { Define a memory to be used by the program. Multiple memories can be defined in the same module. + By default, `wasmssa.memory` have nested visibility. Memory exported by + the module are marked with the exported attribute. This gives them public + visibility. 
+ Example: ```mlir - // Define the `mem_0` memory with defined bounds of 0 -> 65536 + // Define the `mem_0` (internal) memory with defined size bounds of [0:65536] wasmssa.memory @mem_0 !wasmssa<limit[0:65536]> + + // Define the `mem_1` exported memory with minimal size of 512 + wasmssa.memory exported @mem_1 !wasmssa<limit[512:]> ``` }]; let arguments = (ins SymbolNameAttr: $sym_name, WasmSSA_LimitTypeAttr: $limits, - OptionalAttr<StrAttr>:$sym_visibility); - let builders = [ - OpBuilder<(ins - "::llvm::StringRef":$symbol, - "wasmssa::LimitType":$limit)> - ]; + UnitAttr: $exported); - let assemblyFormat = "$sym_name custom<WasmVisibility>($sym_visibility) $limits attr-dict"; + let extraClassDeclaration = [{ + ::mlir::SymbolTable::Visibility getVisibility() { + return getExported() ? + ::mlir::SymbolTable::Visibility::Public : + ::mlir::SymbolTable::Visibility::Nested; + }; + }]; + + let assemblyFormat = "(`exported` $exported^)? $sym_name $limits attr-dict"; } def WasmSSA_MemImportOp : WasmSSA_Op<"import_mem", [Symbol, ImportOpInterface]> { @@ -476,16 +547,13 @@ def WasmSSA_MemImportOp : WasmSSA_Op<"import_mem", [Symbol, ImportOpInterface]> let arguments = (ins SymbolNameAttr: $sym_name, StrAttr: $moduleName, StrAttr: $importName, - WasmSSA_LimitTypeAttr: $limits, - OptionalAttr<StrAttr>:$sym_visibility); + WasmSSA_LimitTypeAttr: $limits); let extraClassDeclaration = [{ - bool isDeclaration() const { return true; } + bool isDeclaration() const { return true; } + ::mlir::SymbolTable::Visibility getVisibility() { + return ::mlir::SymbolTable::Visibility::Nested; + }; }]; - let builders = [OpBuilder<(ins - "::llvm::StringRef":$symbol, - "::llvm::StringRef":$moduleName, - "::llvm::StringRef":$importName, - "wasmssa::LimitType":$limits)>]; let assemblyFormat = "$importName `from` $moduleName `as` $sym_name attr-dict"; } @@ -493,11 +561,15 @@ def WasmSSA_TableOp : WasmSSA_Op<"table", [Symbol]> { let summary= "WebAssembly table value"; let arguments = (ins 
SymbolNameAttr: $sym_name, WasmSSA_TableTypeAttr: $type, - OptionalAttr<StrAttr>:$sym_visibility); - let builders = [OpBuilder<(ins - "::llvm::StringRef":$symbol, - "wasmssa::TableType":$type)>]; - let assemblyFormat = "$sym_name custom<WasmVisibility>($sym_visibility) $type attr-dict"; + UnitAttr: $exported); + let extraClassDeclaration = [{ + ::mlir::SymbolTable::Visibility getVisibility() { + return getExported() ? + ::mlir::SymbolTable::Visibility::Public : + ::mlir::SymbolTable::Visibility::Nested; + }; + }]; + let assemblyFormat = "(`exported` $exported^)? $sym_name $type attr-dict"; } def WasmSSA_TableImportOp : WasmSSA_Op<"import_table", [Symbol, ImportOpInterface]> { @@ -515,17 +587,14 @@ def WasmSSA_TableImportOp : WasmSSA_Op<"import_table", [Symbol, ImportOpInterfac let arguments = (ins SymbolNameAttr: $sym_name, StrAttr: $moduleName, StrAttr: $importName, - WasmSSA_TableTypeAttr: $type, - OptionalAttr<StrAttr>:$sym_visibility); + WasmSSA_TableTypeAttr: $type); let extraClassDeclaration = [{ bool isDeclaration() const { return true; } + ::mlir::SymbolTable::Visibility getVisibility() { + return ::mlir::SymbolTable::Visibility::Nested; + }; }]; let assemblyFormat = "$importName `from` $moduleName `as` $sym_name attr-dict"; - let builders = [OpBuilder<(ins - "::llvm::StringRef":$symbol, - "::llvm::StringRef":$moduleName, - "::llvm::StringRef":$importName, - "wasmssa::TableType":$type)>]; } def WasmSSA_ReturnOp : WasmSSA_Op<"return", [Terminator]> { diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index 5695d5d..19a5231 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -712,10 +712,14 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> { return getAttrs().contains(name); } - ArrayAttr getStrides() { + ArrayAttr getStrideAttr() { return getAttrs().getAs<ArrayAttr>("stride"); } + ArrayAttr getBlockAttr() { + 
return getAttrs().getAs<ArrayAttr>("block"); + } + }]; } diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 73f9061..426377f 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -1298,14 +1298,14 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure, } def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, - AllElementTypesMatch<["mem_desc", "res"]>, - AllRanksMatch<["mem_desc", "res"]>]> { + AllElementTypesMatch<["mem_desc", "res"]>]> { let arguments = (ins XeGPU_MemDesc:$mem_desc, Variadic<Index>: $offsets, DenseI64ArrayAttr: $const_offsets, + OptionalAttr<UnitAttr>:$subgroup_block_io, OptionalAttr<DistributeLayoutAttr>:$layout ); - let results = (outs XeGPU_ValueType:$res); + let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$res); let assemblyFormat = [{ $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets) prop-dict attr-dict `` `:` type(operands) `->` type(results) @@ -1319,6 +1319,9 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, Arguments: - `mem_desc`: the memory descriptor identifying the SLM region. - `offsets`: the coordinates within the matrix to read from. + - `subgroup_block_io`: [optional] An attribute indicating that the operation can be + lowered to a subgroup block load. When this attribute is present, + the offsets are subgroup-uniform across all lanes. - `layout`: [optional] An attribute for guiding distributions among subgroups and/or work-items. It currently can accept either LayoutAttr or SliceAttr. 
@@ -1336,7 +1339,10 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, } ArrayRef<int64_t> getDataShape() { - return getRes().getType().getShape(); + auto resTy = getRes().getType(); + if (auto vecTy = llvm::dyn_cast<VectorType>(resTy)) + return vecTy.getShape(); + return {}; } }]; @@ -1344,13 +1350,13 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, } def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, - AllElementTypesMatch<["mem_desc", "data"]>, - AllRanksMatch<["mem_desc", "data"]>]> { + AllElementTypesMatch<["mem_desc", "data"]>]> { let arguments = (ins - XeGPU_ValueType:$data, + AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$data, XeGPU_MemDesc:$mem_desc, Variadic<Index>: $offsets, DenseI64ArrayAttr: $const_offsets, + OptionalAttr<UnitAttr>:$subgroup_block_io, OptionalAttr<DistributeLayoutAttr>:$layout ); let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets) @@ -1364,6 +1370,9 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, - `mem_desc`: the memory descriptor specifying the SLM region. - `offsets`: the coordinates within the matrix where the data will be written. - `data`: the values to be stored in the matrix. + - `subgroup_block_io`: [optional] An attribute indicating that the operation can be + lowered to a subgroup block store. When this attribute is present, + the offsets are subgroup-uniform across all lanes. - `layout`: [optional] An attribute for guiding distributions among subgroups and/or work-items. It currently can accept either LayoutAttr or SliceAttr. 
@@ -1378,7 +1387,10 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, } ArrayRef<int64_t> getDataShape() { - return getData().getType().getShape(); + auto DataTy = getData().getType(); + if (auto vecTy = llvm::dyn_cast<VectorType>(DataTy)) + return vecTy.getShape(); + return {}; } }]; @@ -1386,41 +1398,4 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, let hasVerifier = 1; } -def XeGPU_MemDescSubviewOp: XeGPU_Op<"mem_desc_subview", - [Pure, ViewLikeOpInterface, AllElementTypesMatch<["src", "res"]>]> { - let description = [{ - Creates a subview of a memory descriptor. The resulting memory descriptor can have - a lower rank than the source; in this case, the result dimensions correspond to the - higher-order dimensions of the source memory descriptor. - - Arguments: - - `src` : a memory descriptor. - - `offsets` : the coordinates within the matrix the subview will be created from. - - Results: - - `res` : a memory descriptor with smaller size. 
- - }]; - let arguments = (ins XeGPU_MemDesc:$src, - Variadic<Index>:$offsets, - DenseI64ArrayAttr:$const_offsets); - let results = (outs XeGPU_MemDesc:$res); - let assemblyFormat = [{$src `` custom<DynamicIndexList>($offsets, $const_offsets) prop-dict - attr-dict `` `:` qualified(type($src)) `->` qualified(type($res))}]; - let builders = [ - OpBuilder<(ins "Type": $res, "Value":$src, "llvm::ArrayRef<OpFoldResult>": $offsets)> - ]; - - let extraClassDeclaration = [{ - mlir::Value getViewSource() { return getSrc(); } - - SmallVector<OpFoldResult> getMixedOffsets() { - return getMixedValues(getConstOffsets(), getOffsets(), getContext()); - } - }]; - - let hasVerifier = 1; -} - - #endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td index 84902b2..b1196fb 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td @@ -237,12 +237,11 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout()); } - ArrayAttr getStrides() { + ArrayAttr getStrideAttr() { auto layout = getMemLayout(); if (layout && layout.hasAttr("stride")) { - return layout.getStrides(); + return layout.getStrideAttr(); } - // derive and return default strides SmallVector<int64_t> defaultStrides; llvm::append_range(defaultStrides, getShape().drop_front()); @@ -250,6 +249,63 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m Builder builder(getContext()); return builder.getI64ArrayAttr(defaultStrides); } + + ArrayAttr getBlockAttr() { + auto layout = getMemLayout(); + if (layout && layout.hasAttr("block")) { + return layout.getBlockAttr(); + } + Builder builder(getContext()); + return builder.getI64ArrayAttr({}); + } + + /// Heuristic to determine if the MemDesc uses column-major layout, + /// based 
on the rank and the value of the first stride dimension. + bool isColMajor() { + auto dim0 = dyn_cast<IntegerAttr>(getStrideAttr()[0]); + return getRank() == 2 && dim0.getInt() == 1; + } + + // Get the Blocking shape for a MemDescType, Which is represented + // as an attribute in MemDescType. By default it is the shape + // of the mdescTy + SmallVector<int64_t> getBlockShape() { + SmallVector<int64_t> size(getShape()); + ArrayAttr blockAttr = getBlockAttr(); + if (!blockAttr.empty()) { + size.clear(); + for (auto attr : blockAttr.getValue()) { + size.push_back(cast<IntegerAttr>(attr).getInt()); + } + } + return size; + } + + // Get strides as vector of integer. + // If it contains block attribute, the strides are blocked strides. + // + // The blocking is applied to the base matrix shape derived from the + // memory descriptor's stride information. If the matrix described by + // the memory descriptor is not contiguous, it is assumed that the base + // matrix is contiguous and follows the same memory layout. + // + // It first computes the original matrix shape using the stride info, + // then computes the number of blocks in each dimension of original shape, + // then compute the outer block shape and stride, + // then combines the inner and outer block shape and stride + // e.g. 
for `mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>` + // its memory layout tuple is ([2,32,16,8],[128,256,1,16]) + // for `mem_desc<256x32xf16, @block=[8, 16]>` with default @stride[32, 1] + // its memory layout tuple is ([32,2,8,16],[256,128,16,1]) + SmallVector<int64_t> getStrideShape(); + + /// Generates instructions to compute the linearize offset + // if the memory descriptor is blocked, it returns linearize offset based on the blocked layout + // the strides of memory descriptor is always considered regardless of blocked or not + Value getLinearOffsets(OpBuilder &builder, + Location loc, ArrayRef<OpFoldResult> offsets); + + }]; let hasCustomAssemblyFormat = true; diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index 6b4e3dd..8427ba5 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -623,6 +623,14 @@ class VectorOfLengthAndType<list<int> allowedLengths, VectorOfNonZeroRankOf<allowedTypes>.summary # VectorOfLength<allowedLengths>.summary, "::mlir::VectorType">; +class FixedVectorOfShapeAndType<list<int> shape, Type elType>: ShapedContainerType< + [elType], + And<[IsVectorOfShape<shape>, IsFixedVectorOfAnyRankTypePred]>, + "vector<" # !interleave(shape, "x") # "x" # elType # ">", + "::mlir::VectorType">, + BuildableType<"::mlir::VectorType::get({" # !interleave(shape, " ,") # "} , " # elType.builderCall # " );">; + + // Any fixed-length vector where the number of elements is from the given // `allowedLengths` list and the type is from the given `allowedTypes` list class FixedVectorOfLengthAndType<list<int> allowedLengths, diff --git a/mlir/include/mlir/IR/Remarks.h b/mlir/include/mlir/IR/Remarks.h index 20e84ec..9877926 100644 --- a/mlir/include/mlir/IR/Remarks.h +++ b/mlir/include/mlir/IR/Remarks.h @@ -18,7 +18,6 @@ #include "llvm/Remarks/Remark.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/Regex.h" -#include 
<optional> #include "mlir/IR/Diagnostics.h" #include "mlir/IR/MLIRContext.h" @@ -144,7 +143,7 @@ public: llvm::StringRef getCategoryName() const { return categoryName; } - llvm::StringRef getFullCategoryName() const { + llvm::StringRef getCombinedCategoryName() const { if (categoryName.empty() && subCategoryName.empty()) return {}; if (subCategoryName.empty()) @@ -318,7 +317,7 @@ private: }; //===----------------------------------------------------------------------===// -// MLIR Remark Streamer +// Pluggable Remark Utilities //===----------------------------------------------------------------------===// /// Base class for MLIR remark streamers that is used to stream @@ -338,6 +337,26 @@ public: virtual void finalize() {} // optional }; +using ReportFn = llvm::unique_function<void(const Remark &)>; + +/// Base class for MLIR remark emitting policies that is used to emit +/// optimization remarks to the underlying remark streamer. The derived classes +/// should implement the `reportRemark` method to provide the actual emitting +/// implementation. +class RemarkEmittingPolicyBase { +protected: + ReportFn reportImpl; + +public: + RemarkEmittingPolicyBase() = default; + virtual ~RemarkEmittingPolicyBase() = default; + + void initialize(ReportFn fn) { reportImpl = std::move(fn); } + + virtual void reportRemark(const Remark &remark) = 0; + virtual void finalize() = 0; +}; + //===----------------------------------------------------------------------===// // Remark Engine (MLIR Context will own this class) //===----------------------------------------------------------------------===// @@ -355,6 +374,8 @@ private: std::optional<llvm::Regex> failedFilter; /// The MLIR remark streamer that will be used to emit the remarks. std::unique_ptr<MLIRRemarkStreamerBase> remarkStreamer; + /// The MLIR remark policy that will be used to emit the remarks. 
+ std::unique_ptr<RemarkEmittingPolicyBase> remarkEmittingPolicy; /// When is enabled, engine also prints remarks as mlir::emitRemarks. bool printAsEmitRemarks = false; @@ -392,6 +413,8 @@ private: InFlightRemark emitIfEnabled(Location loc, RemarkOpts opts, bool (RemarkEngine::*isEnabled)(StringRef) const); + /// Report a remark. + void reportImpl(const Remark &remark); public: /// Default constructor is deleted, use the other constructor. @@ -407,8 +430,15 @@ public: ~RemarkEngine(); /// Setup the remark engine with the given output path and format. - LogicalResult initialize(std::unique_ptr<MLIRRemarkStreamerBase> streamer, - std::string *errMsg); + LogicalResult + initialize(std::unique_ptr<MLIRRemarkStreamerBase> streamer, + std::unique_ptr<RemarkEmittingPolicyBase> remarkEmittingPolicy, + std::string *errMsg); + + /// Get the remark emitting policy. + RemarkEmittingPolicyBase *getRemarkEmittingPolicy() const { + return remarkEmittingPolicy.get(); + } /// Report a remark. void report(const Remark &&remark); @@ -446,6 +476,46 @@ inline InFlightRemark withEngine(Fn fn, Location loc, Args &&...args) { namespace mlir::remark { +//===----------------------------------------------------------------------===// +// Remark Emitting Policies +//===----------------------------------------------------------------------===// + +/// Policy that emits all remarks. +class RemarkEmittingPolicyAll : public detail::RemarkEmittingPolicyBase { +public: + RemarkEmittingPolicyAll(); + + void reportRemark(const detail::Remark &remark) override { + assert(reportImpl && "reportImpl is not set"); + reportImpl(remark); + } + void finalize() override {} +}; + +/// Policy that emits final remarks. +class RemarkEmittingPolicyFinal : public detail::RemarkEmittingPolicyBase { +private: + /// user can intercept them for custom processing via a registered callback, + /// otherwise they will be reported on engine destruction. 
+ llvm::DenseSet<detail::Remark> postponedRemarks; + +public: + RemarkEmittingPolicyFinal(); + + void reportRemark(const detail::Remark &remark) override { + postponedRemarks.erase(remark); + postponedRemarks.insert(remark); + } + + void finalize() override { + assert(reportImpl && "reportImpl is not set"); + for (auto &remark : postponedRemarks) { + if (reportImpl) + reportImpl(remark); + } + } +}; + /// Create a Reason with llvm::formatv formatting. template <class... Ts> inline detail::LazyTextBuild reason(const char *fmt, Ts &&...ts) { @@ -505,16 +575,72 @@ inline detail::InFlightRemark analysis(Location loc, RemarkOpts opts) { /// Setup remarks for the context. This function will enable the remark engine /// and set the streamer to be used for optimization remarks. The remark -/// categories are used to filter the remarks that will be emitted by the remark -/// engine. If a category is not specified, it will not be emitted. If +/// categories are used to filter the remarks that will be emitted by the +/// remark engine. If a category is not specified, it will not be emitted. If /// `printAsEmitRemarks` is true, the remarks will be printed as /// mlir::emitRemarks. 'streamer' must inherit from MLIRRemarkStreamerBase and /// will be used to stream the remarks. LogicalResult enableOptimizationRemarks( MLIRContext &ctx, std::unique_ptr<remark::detail::MLIRRemarkStreamerBase> streamer, + std::unique_ptr<remark::detail::RemarkEmittingPolicyBase> + remarkEmittingPolicy, const remark::RemarkCategories &cats, bool printAsEmitRemarks = false); } // namespace mlir::remark +// DenseMapInfo specialization for Remark +namespace llvm { +template <> +struct DenseMapInfo<mlir::remark::detail::Remark> { + static constexpr StringRef kEmptyKey = "<EMPTY_KEY>"; + static constexpr StringRef kTombstoneKey = "<TOMBSTONE_KEY>"; + + /// Helper to provide a static dummy context for sentinel keys. 
+ static mlir::MLIRContext *getStaticDummyContext() { + static mlir::MLIRContext dummyContext; + return &dummyContext; + } + + /// Create an empty remark + static inline mlir::remark::detail::Remark getEmptyKey() { + return mlir::remark::detail::Remark( + mlir::remark::RemarkKind::RemarkUnknown, mlir::DiagnosticSeverity::Note, + mlir::UnknownLoc::get(getStaticDummyContext()), + mlir::remark::RemarkOpts::name(kEmptyKey)); + } + + /// Create a dead remark + static inline mlir::remark::detail::Remark getTombstoneKey() { + return mlir::remark::detail::Remark( + mlir::remark::RemarkKind::RemarkUnknown, mlir::DiagnosticSeverity::Note, + mlir::UnknownLoc::get(getStaticDummyContext()), + mlir::remark::RemarkOpts::name(kTombstoneKey)); + } + + /// Compute the hash value of the remark + static unsigned getHashValue(const mlir::remark::detail::Remark &remark) { + return llvm::hash_combine( + remark.getLocation().getAsOpaquePointer(), + llvm::hash_value(remark.getRemarkName()), + llvm::hash_value(remark.getCombinedCategoryName())); + } + + static bool isEqual(const mlir::remark::detail::Remark &lhs, + const mlir::remark::detail::Remark &rhs) { + // Check for empty/tombstone keys first + if (lhs.getRemarkName() == kEmptyKey || + lhs.getRemarkName() == kTombstoneKey || + rhs.getRemarkName() == kEmptyKey || + rhs.getRemarkName() == kTombstoneKey) { + return lhs.getRemarkName() == rhs.getRemarkName(); + } + + // For regular remarks, compare key identifying fields + return lhs.getLocation() == rhs.getLocation() && + lhs.getRemarkName() == rhs.getRemarkName() && + lhs.getCombinedCategoryName() == rhs.getCombinedCategoryName(); + } +}; +} // namespace llvm #endif // MLIR_IR_REMARKS_H diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt index a5feb59..72ed046 100644 --- a/mlir/include/mlir/Interfaces/CMakeLists.txt +++ b/mlir/include/mlir/Interfaces/CMakeLists.txt @@ -6,6 +6,7 @@ add_mlir_interface(DestinationStyleOpInterface) 
add_mlir_interface(FunctionInterfaces) add_mlir_interface(IndexingMapOpInterface) add_mlir_interface(InferIntRangeInterface) +add_mlir_interface(InferStridedMetadataInterface) add_mlir_interface(InferTypeOpInterface) add_mlir_interface(LoopLikeInterface) add_mlir_interface(MemOpInterfaces) diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h index 0e107e8..a6de3d1 100644 --- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h +++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h @@ -117,7 +117,8 @@ public: IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {} /// Create an integer value range lattice value. - IntegerValueRange(std::optional<ConstantIntRanges> value = std::nullopt) + explicit IntegerValueRange( + std::optional<ConstantIntRanges> value = std::nullopt) : value(std::move(value)) {} /// Whether the range is uninitialized. This happens when the state hasn't @@ -167,6 +168,15 @@ using SetIntRangeFn = using SetIntLatticeFn = llvm::function_ref<void(Value, const IntegerValueRange &)>; +/// Helper callback type to get the integer range of a value. +using GetIntRangeFn = function_ref<IntegerValueRange(Value)>; + +/// Helper function to collect the integer range values of an array of op fold +/// results. +SmallVector<IntegerValueRange> getIntValueRanges(ArrayRef<OpFoldResult> values, + GetIntRangeFn getIntRange, + int32_t indexBitwidth); + class InferIntRangeInterface; namespace intrange::detail { diff --git a/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h new file mode 100644 index 0000000..0c572e0 --- /dev/null +++ b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.h @@ -0,0 +1,145 @@ +//===- InferStridedMetadataInterface.h - Strided Metadata Inference -C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains definitions of the strided metadata inference interface +// defined in `InferStridedMetadataInterface.td` +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H +#define MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H + +#include "mlir/Interfaces/InferIntRangeInterface.h" + +namespace mlir { +/// A class that represents the strided metadata range information, including +/// offsets, sizes, and strides as integer ranges. +class StridedMetadataRange { +public: + /// Default constructor creates uninitialized ranges. + StridedMetadataRange() = default; + + /// Returns a ranked strided metadata range. + static StridedMetadataRange + getRanked(SmallVectorImpl<ConstantIntRanges> &&offsets, + SmallVectorImpl<ConstantIntRanges> &&sizes, + SmallVectorImpl<ConstantIntRanges> &&strides) { + return StridedMetadataRange(std::move(offsets), std::move(sizes), + std::move(strides)); + } + + /// Returns a strided metadata range with maximum ranges. + static StridedMetadataRange getMaxRanges(int32_t indexBitwidth, + int32_t offsetsRank, + int32_t sizeRank, + int32_t stridedRank) { + return StridedMetadataRange( + SmallVector<ConstantIntRanges>( + offsetsRank, ConstantIntRanges::maxRange(indexBitwidth)), + SmallVector<ConstantIntRanges>( + sizeRank, ConstantIntRanges::maxRange(indexBitwidth)), + SmallVector<ConstantIntRanges>( + stridedRank, ConstantIntRanges::maxRange(indexBitwidth))); + } + + static StridedMetadataRange getMaxRanges(int32_t indexBitwidth, + int32_t rank) { + return getMaxRanges(indexBitwidth, 1, rank, rank); + } + + /// Returns whether the metadata is uninitialized. 
+  bool isUninitialized() const { return !offsets.has_value(); }
+
+  /// Get the offsets range.
+  ArrayRef<ConstantIntRanges> getOffsets() const {
+    return offsets ? *offsets : ArrayRef<ConstantIntRanges>();
+  }
+  MutableArrayRef<ConstantIntRanges> getOffsets() {
+    return offsets ? *offsets : MutableArrayRef<ConstantIntRanges>();
+  }
+
+  /// Get the sizes ranges.
+  ArrayRef<ConstantIntRanges> getSizes() const { return sizes; }
+  MutableArrayRef<ConstantIntRanges> getSizes() { return sizes; }
+
+  /// Get the strides ranges.
+  ArrayRef<ConstantIntRanges> getStrides() const { return strides; }
+  MutableArrayRef<ConstantIntRanges> getStrides() { return strides; }
+
+  /// Compare two strided metadata ranges.
+  bool operator==(const StridedMetadataRange &other) const {
+    return offsets == other.offsets && sizes == other.sizes &&
+           strides == other.strides;
+  }
+
+  /// Print the strided metadata range.
+  void print(raw_ostream &os) const;
+
+  /// Join two strided metadata ranges, by taking the element-wise union of the
+  /// metadata.
+  static StridedMetadataRange join(const StridedMetadataRange &lhs,
+                                   const StridedMetadataRange &rhs) {
+    if (lhs.isUninitialized())
+      return rhs;
+    if (rhs.isUninitialized())
+      return lhs;
+
+    // Helper function to compute the range union of constant ranges.
+    auto rangeUnion =
+        +[](const std::tuple<ConstantIntRanges, ConstantIntRanges> &lhsRhs)
+        -> ConstantIntRanges {
+      return std::get<0>(lhsRhs).rangeUnion(std::get<1>(lhsRhs));
+    };
+
+    // Get the elementwise range union. Note that `zip_equal` will assert if
+    // sizes are not equal.
+    SmallVector<ConstantIntRanges> offsets = llvm::map_to_vector(
+        llvm::zip_equal(*lhs.offsets, *rhs.offsets), rangeUnion);
+    SmallVector<ConstantIntRanges> sizes =
+        llvm::map_to_vector(llvm::zip_equal(lhs.sizes, rhs.sizes), rangeUnion);
+    SmallVector<ConstantIntRanges> strides = llvm::map_to_vector(
+        llvm::zip_equal(lhs.strides, rhs.strides), rangeUnion);
+
+    // Return the joined metadata.
+ return StridedMetadataRange(std::move(offsets), std::move(sizes), + std::move(strides)); + } + +private: + /// Create a strided metadata range with the given offset, sizes, and strides. + StridedMetadataRange(SmallVectorImpl<ConstantIntRanges> &&offsets, + SmallVectorImpl<ConstantIntRanges> &&sizes, + SmallVectorImpl<ConstantIntRanges> &&strides) + : offsets(std::move(offsets)), sizes(std::move(sizes)), + strides(std::move(strides)) {} + + /// The offsets range. + std::optional<SmallVector<ConstantIntRanges>> offsets; + + /// The sizes ranges. + SmallVector<ConstantIntRanges> sizes; + + /// The strides ranges. + SmallVector<ConstantIntRanges> strides; +}; + +/// Print the strided metadata to `os`. +inline raw_ostream &operator<<(raw_ostream &os, + const StridedMetadataRange &range) { + range.print(os); + return os; +} + +/// Callback function type for setting the strided metadata of a value. +using SetStridedMetadataRangeFn = + function_ref<void(Value, const StridedMetadataRange &)>; +} // end namespace mlir + +#include "mlir/Interfaces/InferStridedMetadataInterface.h.inc" + +#endif // MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE_H diff --git a/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td new file mode 100644 index 0000000..ee5b094 --- /dev/null +++ b/mlir/include/mlir/Interfaces/InferStridedMetadataInterface.td @@ -0,0 +1,45 @@ +//===- InferStridedMetadataInterface.td - Strided MD Inference ----------*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the interface for strided metadata range analysis +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE +#define MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE + +include "mlir/IR/OpBase.td" + +def InferStridedMetadataOpInterface : + OpInterface<"InferStridedMetadataOpInterface"> { + let description = [{ + Allows operations to participate in strided metadata analysis by providing + methods that allow them to specify bounds on offsets, sizes, and strides + of their result(s) given bounds on their input(s) if known. + }]; + let cppNamespace = "::mlir"; + + let methods = [ + InterfaceMethod<[{ + Infer the strided metadata bounds on the results of this op given + the bounds on its operands. + For each result value or block argument of interest, the method should + call `setMetadata` with that `Value` as an argument. + The `operands` parameter contains the strided metadata ranges for all the + operands of the operation in order. + The `getIntRange` callback is provided for obtaining the int-range + analysis result for a given value. 
+    }],
+    "void", "inferStridedMetadataRanges",
+    (ins "::llvm::ArrayRef<::mlir::StridedMetadataRange>":$operands,
+         "::mlir::GetIntRangeFn":$getIntRange,
+         "::mlir::SetStridedMetadataRangeFn":$setMetadata,
+         "int32_t":$indexBitwidth)>
+  ];
+}
+#endif // MLIR_INTERFACES_INFERSTRIDEDMETADATAINTERFACE
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index db9c37f..c1c2269 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -230,6 +230,22 @@ LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name,
                                              ArrayRef<int64_t> attr,
                                              ValueRange values);
 
+namespace OpTrait {
+/// This trait indicates that pointer-like objects (such as memrefs) returned
+/// from this operation will never alias with each other. This provides a
+/// guarantee to optimization passes that accesses through different results
+/// of this operation can be safely reordered, as they will never reference
+/// overlapping memory locations.
+///
+/// Operations with this trait take multiple pointer-like operands
+/// and return the same operands with additional non-aliasing guarantees.
+/// If the access to the results of this operation aliases at runtime, the
+/// behavior of such access is undefined.
+template <typename ConcreteType>
+class DistinctObjectsTrait
+    : public TraitBase<ConcreteType, DistinctObjectsTrait> {};
+} // namespace OpTrait
+
 } // namespace mlir
 
 #endif // MLIR_INTERFACES_VIEWLIKEINTERFACE_H_
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
index ed213bf..131c1a0 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -414,4 +414,16 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
   }];
 }
 
+// This trait indicates that pointer-like objects (such as memrefs) returned
+// from this operation will never alias with each other. This provides a
+// guarantee to optimization passes that accesses through different results
+// of this operation can be safely reordered, as they will never reference
+// overlapping memory locations.
+//
+// Operations with this trait take multiple pointer-like operands
+// and return the same operands with additional non-aliasing guarantees.
+// If the access to the results of this operation aliases at runtime, the
+// behavior of such access is undefined.
+def DistinctObjectsTrait : NativeOpTrait<"DistinctObjectsTrait">;
+
 #endif // MLIR_INTERFACES_VIEWLIKEINTERFACE
diff --git a/mlir/include/mlir/Remark/RemarkStreamer.h b/mlir/include/mlir/Remark/RemarkStreamer.h
index 170d6b4..19a70fa 100644
--- a/mlir/include/mlir/Remark/RemarkStreamer.h
+++ b/mlir/include/mlir/Remark/RemarkStreamer.h
@@ -45,6 +45,7 @@ namespace mlir::remark {
 /// mlir::emitRemarks.
LogicalResult enableOptimizationRemarksWithLLVMStreamer( MLIRContext &ctx, StringRef filePath, llvm::remarks::Format fmt, + std::unique_ptr<detail::RemarkEmittingPolicyBase> remarkEmittingPolicy, const RemarkCategories &cat, bool printAsEmitRemarks = false); } // namespace mlir::remark diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h index 252da21..997aef2 100644 --- a/mlir/include/mlir/TableGen/CodeGenHelpers.h +++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h @@ -88,7 +88,7 @@ public: /// /// Constraints that do not meet the restriction that they can only reference /// `$_self` and `$_op` are not uniqued. - void emitOpConstraints(ArrayRef<const llvm::Record *> opDefs); + void emitOpConstraints(); /// Unique all compatible type and attribute constraints from a pattern file /// and emit them at the top of the generated file. diff --git a/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h index 21adde8..cd9ef5b 100644 --- a/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h +++ b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h @@ -19,6 +19,14 @@ namespace mlir { struct WasmBinaryEncoding { /// Byte encodings for Wasm instructions. struct OpCode { + // Control instructions. + static constexpr std::byte block{0x02}; + static constexpr std::byte loop{0x03}; + static constexpr std::byte ifOpCode{0x04}; + static constexpr std::byte elseOpCode{0x05}; + static constexpr std::byte branchIf{0x0D}; + static constexpr std::byte call{0x10}; + // Locals, globals, constants. static constexpr std::byte localGet{0x20}; static constexpr std::byte localSet{0x21}; @@ -29,6 +37,42 @@ struct WasmBinaryEncoding { static constexpr std::byte constFP32{0x43}; static constexpr std::byte constFP64{0x44}; + // Comparisons. 
+ static constexpr std::byte eqzI32{0x45}; + static constexpr std::byte eqI32{0x46}; + static constexpr std::byte neI32{0x47}; + static constexpr std::byte ltSI32{0x48}; + static constexpr std::byte ltUI32{0x49}; + static constexpr std::byte gtSI32{0x4A}; + static constexpr std::byte gtUI32{0x4B}; + static constexpr std::byte leSI32{0x4C}; + static constexpr std::byte leUI32{0x4D}; + static constexpr std::byte geSI32{0x4E}; + static constexpr std::byte geUI32{0x4F}; + static constexpr std::byte eqzI64{0x50}; + static constexpr std::byte eqI64{0x51}; + static constexpr std::byte neI64{0x52}; + static constexpr std::byte ltSI64{0x53}; + static constexpr std::byte ltUI64{0x54}; + static constexpr std::byte gtSI64{0x55}; + static constexpr std::byte gtUI64{0x56}; + static constexpr std::byte leSI64{0x57}; + static constexpr std::byte leUI64{0x58}; + static constexpr std::byte geSI64{0x59}; + static constexpr std::byte geUI64{0x5A}; + static constexpr std::byte eqF32{0x5B}; + static constexpr std::byte neF32{0x5C}; + static constexpr std::byte ltF32{0x5D}; + static constexpr std::byte gtF32{0x5E}; + static constexpr std::byte leF32{0x5F}; + static constexpr std::byte geF32{0x60}; + static constexpr std::byte eqF64{0x61}; + static constexpr std::byte neF64{0x62}; + static constexpr std::byte ltF64{0x63}; + static constexpr std::byte gtF64{0x64}; + static constexpr std::byte leF64{0x65}; + static constexpr std::byte geF64{0x66}; + // Numeric operations. 
static constexpr std::byte clzI32{0x67}; static constexpr std::byte ctzI32{0x68}; @@ -93,6 +137,33 @@ struct WasmBinaryEncoding { static constexpr std::byte maxF64{0xA5}; static constexpr std::byte copysignF64{0xA6}; static constexpr std::byte wrap{0xA7}; + + // Conversion operations + static constexpr std::byte extendS{0xAC}; + static constexpr std::byte extendU{0xAD}; + static constexpr std::byte convertSI32F32{0xB2}; + static constexpr std::byte convertUI32F32{0xB3}; + static constexpr std::byte convertSI64F32{0xB4}; + static constexpr std::byte convertUI64F32{0xB5}; + + static constexpr std::byte demoteF64ToF32{0xB6}; + + static constexpr std::byte convertSI32F64{0xB7}; + static constexpr std::byte convertUI32F64{0xB8}; + static constexpr std::byte convertSI64F64{0xB9}; + static constexpr std::byte convertUI64F64{0xBA}; + + static constexpr std::byte promoteF32ToF64{0xBB}; + static constexpr std::byte reinterpretF32AsI32{0xBC}; + static constexpr std::byte reinterpretF64AsI64{0xBD}; + static constexpr std::byte reinterpretI32AsF32{0xBE}; + static constexpr std::byte reinterpretI64AsF64{0xBF}; + + static constexpr std::byte extendI328S{0xC0}; + static constexpr std::byte extendI3216S{0xC1}; + static constexpr std::byte extendI648S{0xC2}; + static constexpr std::byte extendI6416S{0xC3}; + static constexpr std::byte extendI6432S{0xC4}; }; /// Byte encodings of types in Wasm binaries diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h index 0fbe15f..b739438 100644 --- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h +++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h @@ -44,6 +44,11 @@ enum class RemarkFormat { REMARK_FORMAT_BITSTREAM, }; +enum class RemarkPolicy { + REMARK_POLICY_ALL, + REMARK_POLICY_FINAL, +}; + /// Configuration options for the mlir-opt tool. /// This is intended to help building tools like mlir-opt by collecting the /// supported options. 
@@ -242,6 +247,8 @@ public: /// Set the reproducer output filename RemarkFormat getRemarkFormat() const { return remarkFormatFlag; } + /// Set the remark policy to use. + RemarkPolicy getRemarkPolicy() const { return remarkPolicyFlag; } /// Set the remark format to use. std::string getRemarksAllFilter() const { return remarksAllFilterFlag; } /// Set the remark output file. @@ -265,6 +272,8 @@ protected: /// Remark format RemarkFormat remarkFormatFlag = RemarkFormat::REMARK_FORMAT_STDOUT; + /// Remark policy + RemarkPolicy remarkPolicyFlag = RemarkPolicy::REMARK_POLICY_ALL; /// Remark file to output to std::string remarksOutputFileFlag = ""; /// Remark filters |