Diffstat (limited to 'mlir/include')
-rw-r--r-- | mlir/include/mlir-c/Rewrite.h | 2
-rw-r--r-- | mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h | 8
-rw-r--r-- | mlir/include/mlir/Conversion/Passes.td | 9
-rw-r--r-- | mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 93
-rw-r--r-- | mlir/include/mlir/Dialect/Affine/IR/AffineOps.td | 1
-rw-r--r-- | mlir/include/mlir/Dialect/Arith/IR/ArithOps.td | 94
-rw-r--r-- | mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h | 57
-rw-r--r-- | mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 61
-rw-r--r-- | mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h | 37
-rw-r--r-- | mlir/include/mlir/Dialect/SMT/IR/SMTOps.td | 2
-rw-r--r-- | mlir/include/mlir/Dialect/Shard/IR/ShardOps.td | 117
-rw-r--r-- | mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h | 1
-rw-r--r-- | mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td | 17
-rw-r--r-- | mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td | 175
-rw-r--r-- | mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h | 71
16 files changed, 585 insertions, 168 deletions
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h index 2db1d84..fe42a20 100644 --- a/mlir/include/mlir-c/Rewrite.h +++ b/mlir/include/mlir-c/Rewrite.h @@ -352,7 +352,7 @@ typedef struct { /// Create a rewrite pattern that matches the operation /// with the given rootName, corresponding to mlir::OpRewritePattern. -MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate( +MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePatternCreate( MlirStringRef rootName, unsigned benefit, MlirContext context, MlirRewritePatternCallbacks callbacks, void *userData, size_t nGeneratedNames, MlirStringRef *generatedNames); diff --git a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h index 46573e79..60f1888 100644 --- a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h +++ b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h @@ -9,6 +9,7 @@ #define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_ #include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/IR/PatternMatch.h" #include <memory> @@ -19,8 +20,11 @@ class Pass; #include "mlir/Conversion/Passes.h.inc" /// Populate the given list with patterns that convert from Math to ROCDL calls. -void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter, - RewritePatternSet &patterns); +// `chipset` specifies the AMDGPU chipset to target. If `std::nullopt`, +// none of the chipset dependent patterns are added. +void populateMathToROCDLConversionPatterns( + const LLVMTypeConverter &converter, RewritePatternSet &patterns, + std::optional<amdgpu::Chipset> chipset); } // namespace mlir #endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_ diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 25e9d34..70e3e45 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -778,6 +778,10 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> { let summary = "Convert Math dialect to ROCDL library calls"; let description = [{ This pass converts supported Math ops to ROCDL library calls. + + The chipset option specifies the target AMDGPU architecture. If the chipset + is empty, none of the chipset-dependent patterns are added, and the pass + will not attempt to parse the chipset. 
}]; let dependentDialects = [ "arith::ArithDialect", @@ -785,6 +789,9 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> { "ROCDL::ROCDLDialect", "vector::VectorDialect", ]; + let options = [Option<"chipset", "chipset", "std::string", + /*default=*/"\"\"", + "Chipset that these operations will run on">]; } //===----------------------------------------------------------------------===// @@ -800,7 +807,7 @@ def ConvertMathToSPIRVPass : Pass<"convert-math-to-spirv"> { // MathToXeVM //===----------------------------------------------------------------------===// -def ConvertMathToXeVM : Pass<"convert-math-to-xevm", "ModuleOp"> { +def ConvertMathToXeVM : Pass<"convert-math-to-xevm"> { let summary = "Convert (fast) math operations to native XeVM/SPIRV equivalents"; let description = [{ diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td index 8370d35..7184de9 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td @@ -112,6 +112,97 @@ def AMDGPU_ExtPackedFp8Op : }]; } +def IsValidBlockSize: AttrConstraint< + CPred<"::llvm::is_contained({16, 32}, ::llvm::cast<::mlir::IntegerAttr>($_self).getInt())">, + "whose value is 16 or 32">; + +def AMDGPU_ScaledExtPacked816Op + : AMDGPU_Op<"scaled_ext_packed816", [Pure, AllShapesMatch<["source", "res"]>]>, + Arguments<( + ins AnyTypeOf<[FixedVectorOfShapeAndType<[8], F4E2M1FN>, + FixedVectorOfShapeAndType<[8], F8E4M3FN>, + FixedVectorOfShapeAndType<[8], F8E5M2>, + FixedVectorOfShapeAndType<[16], F6E2M3FN>, + FixedVectorOfShapeAndType<[16], F6E3M2FN>]>:$source, + FixedVectorOfShapeAndType<[4], F8E8M0FNU>:$scale, + ConfinedAttr<I32Attr, [IsValidBlockSize]>:$blockSize, + ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<1>]>:$firstScaleLane, + ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<2>]>:$firstScaleByte)>, + Results<( + outs AnyTypeOf<[FixedVectorOfShapeAndType<[8], F32>, + FixedVectorOfShapeAndType<[8], F16>, + FixedVectorOfShapeAndType<[8], BF16>, + FixedVectorOfShapeAndType<[16], F32>, + FixedVectorOfShapeAndType<[16], F16>, + FixedVectorOfShapeAndType<[16], BF16>]>:$res)> { + + let summary = "Extend a vector of packed floating point values"; + + let description = [{ + The scales applied to the input microfloats are stored in two bytes which + come from the `scales` input provided in a *half* of the wave identified + by `firstScaleLane`. The pair of bytes used is selected by + `firstScaleByte`. The 16 vectors in consecutive lanes starting from + `firstScaleLane` (which we'll call the scale vectors) will be used by both + halves of the wave (with lane L reading from L % 16'th scale vector), but + each half will use a different byte. + + When the block size is 32, `firstScaleByte` can be either 0 or 2, + selecting halves of the scale vectors. Lanes 0-15 will read from + `firstScaleByte` and lanes 16-31 will read from `firstScaleByte` + 1. 
+ For example: + ```mlir + // Input: 8-element vector of F8E4M3FN, converting to F32 + // Lanes 0-15 read from byte 0, lanes 16-31 read from byte 1 + %result = amdgpu.scaled_ext_packed816 %source scale(%scales) + blockSize(32) firstScaleLane(0) firstScaleByte(0) + : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf32> + + // Input: 16-element vector of F6E2M3FN, converting to F16 + // Lanes 0-15 read from byte 2, lanes 16-31 read from byte 3 + %result = amdgpu.scaled_ext_packed816 %source scale(%scales) + blockSize(32) firstScaleLane(1) firstScaleByte(2) + : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16> + ``` + + However, when the block size is 16, `firstScaleByte` can be 0 or 1. + Lanes 0-15 read from the `firstScaleByte`th element of the scale vectors, + while lanes 16-31 read from `firstScaleByte` + 2. + For example: + ```mlir + // Input: 8-element vector of F8E5M2, converting to BF16 + // Lanes 0-15 read from byte 0, lanes 16-31 read from byte 2 (0+2) + %result = amdgpu.scaled_ext_packed816 %source scale(%scales) + blockSize(16) firstScaleLane(0) firstScaleByte(0) + : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16> + + // Input: 16-element vector of F6E3M2FN, converting to F32 + // Lanes 0-15 read from byte 1, lanes 16-31 read from byte 3 (1+2) + %result = amdgpu.scaled_ext_packed816 %source scale(%scales) + blockSize(16) firstScaleLane(1) firstScaleByte(1) + : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf32> + ``` + + Note: the layout for the scales generally mirrors how the WMMA + instructions use for matix scales. These selection operands allows + one to choose portions of the matrix to convert. + + Available on gfx1250+. + }]; + + let assemblyFormat = [{ + attr-dict $source + `scale` `(` $scale `)` + `blockSize` `(` $blockSize `)` + `firstScaleLane` `(` $firstScaleLane`)` + `firstScaleByte` `(` $firstScaleByte `)` + `:` type($source) `,` type($scale) `->` type($res) + }]; + + let hasVerifier = 1; + +} + def AMDGPU_ScaledExtPackedOp : AMDGPU_Op<"scaled_ext_packed", [Pure]>, Arguments<( @@ -860,7 +951,7 @@ def AMDGPU_MFMAOp : based on the provided `m`, `k`, `n`, and `nBlks` attributes, along with the types of the source and destination arguments. - For information on the layouts of the input and output matrces (which are stored + For information on the layouts of the input and output matrices (which are stored in `sourceA`, `sourceB`, `destC`, and `destD`), see the CDNA ISA documentation. The `cbsz`, `abid`, and `blgp` parameters control how the lanes of the wave diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td index e52b7d2..12a7935 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -330,7 +330,6 @@ def AffineForOp : Affine_Op<"for", Speculation::Speculatability getSpeculatability(); }]; - let hasCanonicalizer = 1; let hasCustomAssemblyFormat = 1; let hasFolder = 1; let hasRegionVerifier = 1; diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index 20c9097..a38cf41 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -1229,37 +1229,50 @@ def Arith_ScalingExtFOp let summary = "Upcasts input floats using provided scales values following " "OCP MXFP Spec"; let description = [{ - This operation upcasts input floating-point values using provided scale - values. 
It expects both scales and the input operand to be of the same shape, - making the operation elementwise. Scales are usually calculated per block - following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537. - - If scales are calculated per block where blockSize != 1, then scales may - require broadcasting to make this operation elementwise. For example, let's - say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and - assuming quantization happens on the last axis, the input can be reshaped to - `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated - per block on the last axis. Therefore, scales will be of shape - `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other - shape as long as it is broadcast compatible with the input, e.g., - `<1 x 1 x ... (dimN/blockSize) x 1>`. - - In this example, before calling into `arith.scaling_extf`, scales must be - broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note - that there could be multiple quantization axes. Internally, - `arith.scaling_extf` would perform the following: + This operation upcasts input floating-point values using provided scale + values. It expects both scales and the input operand to be of the same shape, + making the operation elementwise. Scales are usually calculated per block + following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537. - ``` - resultTy = get_type(result) - scaleTy = get_type(scale) - inputTy = get_type(input) - scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0 - scale.extf = arith.extf(scale.exponent) : f8E8M0 to resultTy - input.extf = arith.extf(input) : inputTy to resultTy - result = arith.mulf(scale.extf, input.extf) + If scales are calculated per block where blockSize != 1, then scales may + require broadcasting to make this operation elementwise. For example, let's + say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and + assuming quantization happens on the last axis, the input can be reshaped to + `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated + per block on the last axis. Therefore, scales will be of shape + `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other + shape as long as it is broadcast compatible with the input, e.g., + `<1 x 1 x ... (dimN/blockSize) x 1>`. + + In this example, before calling into `arith.scaling_extf`, scales must be + broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note + that there could be multiple quantization axes. Internally, + `arith.scaling_extf` would perform the following: + + ```mlir + // Cast scale to result type. + %0 = arith.truncf %1 : f32 to f8E8M0FNU + %1 = arith.extf %0 : f8E8M0FNU to f16 + + // Cast input to result type. + %2 = arith.extf %3 : f4E2M1FN to f16 + + // Perform scaling + %3 = arith.mulf %2, %1 : f16 ``` It propagates NaN values. Therefore, if either scale or the input element contains NaN, then the output element value will also be a NaN. + + Example: + + ```mlir + // Upcast from f4E2M1FN to f32. + %a = arith.scaling_extf %b, %c : f4E2M1FN, f8E8M0FNU to f32 + + // Element-wise upcast with broadcast (blockSize = 32). 
+ %f = vector.broadcast %g : vector<1xf8E8M0FNU> to vector<32xf8E8M0FNU> + %h = arith.scaling_extf %i, %f : vector<32xf4E2M1FN>, vector<32xf8E8M0FNU> to vector<32xbf16> + ``` }]; let hasVerifier = 1; let assemblyFormat = @@ -1397,14 +1410,27 @@ def Arith_ScalingTruncFOp that there could be multiple quantization axes. Internally, `arith.scaling_truncf` would perform the following: + ```mlir + // Cast scale to input type. + %0 = arith.truncf %1 : f32 to f8E8M0FNU + %1 = arith.extf %0 : f8E8M0FNU to f16 + + // Perform scaling. + %3 = arith.divf %2, %1 : f16 + + // Cast to result type. + %4 = arith.truncf %3 : f16 to f4E2M1FN ``` - scaleTy = get_type(scale) - inputTy = get_type(input) - resultTy = get_type(result) - scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0 - scale.extf = arith.extf(scale.exponent) : f8E8M0 to inputTy - result = arith.divf(input, scale.extf) - result.cast = arith.truncf(result, resultTy) + + Example: + + ```mlir + // Downcast from f32 to f4E2M1FN. + %a = arith.scaling_truncf %b, %c : f32, f8E8M0FNU to f4E2M1FN + + // Element-wise downcast with broadcast (blockSize = 32). + %f = vector.broadcast %g : vector<1xf8E8M0FNU> to vector<32xf8E8M0FNU> + %h = arith.scaling_truncf %i, %f : vector<32xbf16>, vector<32xf8E8M0FNU> to vector<32xf4E2M1FN> ``` }]; let hasVerifier = 1; diff --git a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h index 035235f..fccb49d 100644 --- a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h +++ b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h @@ -1,4 +1,4 @@ -//===- Passes.h - GPU NVVM pipeline entry points --------------------------===// +//===- Passes.h - GPU pipeline entry points--------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -60,6 +60,52 @@ struct GPUToNVVMPipelineOptions llvm::cl::init(false)}; }; +// Options for the gpu to xevm pipeline. +struct GPUToXeVMPipelineOptions + : public PassPipelineOptions<GPUToXeVMPipelineOptions> { + PassOptions::Option<std::string> xegpuOpLevel{ + *this, "xegpu-op-level", + llvm::cl::desc("Granularity of XeGPU operations to target: workgroup | " + "subgroup | lane"), + llvm::cl::init("workgroup")}; + // General lowering controls. + PassOptions::Option<bool> use64bitIndex{ + *this, "use-64bit-index", + llvm::cl::desc("Bitwidth of the index type (host & device)"), + llvm::cl::init(true)}; + PassOptions::Option<bool> kernelBarePtrCallConv{ + *this, "kernel-bare-ptr-calling-convention", + llvm::cl::desc("Use bare pointer calling convention for device kernels"), + llvm::cl::init(false)}; + PassOptions::Option<bool> hostBarePtrCallConv{ + *this, "host-bare-ptr-calling-convention", + llvm::cl::desc("Use bare pointer calling convention for host launches"), + llvm::cl::init(false)}; + PassOptions::Option<std::string> binaryFormat{ + *this, "binary-format", + llvm::cl::desc("Final GPU binary emission format (e.g. fatbin)"), + llvm::cl::init("fatbin")}; + // Options mirroring xevm-attach-target (GpuXeVMAttachTarget). 
+ PassOptions::Option<std::string> xevmModuleMatcher{ + *this, "xevm-module-matcher", + llvm::cl::desc("Regex to match gpu.module names for XeVM target attach"), + llvm::cl::init("")}; + PassOptions::Option<std::string> zebinTriple{ + *this, "zebin-triple", llvm::cl::desc("Target triple for XeVM codegen"), + llvm::cl::init("spirv64-unknown-unknown")}; + PassOptions::Option<std::string> zebinChip{ + *this, "zebin-chip", llvm::cl::desc("Target chip (e.g. pvc, bmg)"), + llvm::cl::init("bmg")}; + PassOptions::Option<unsigned> optLevel{ + *this, "opt-level", + llvm::cl::desc("Optimization level for attached target/codegen"), + llvm::cl::init(2)}; + PassOptions::Option<std::string> cmdOptions{ + *this, "igc-cmd-options", + llvm::cl::desc("Additional downstream compiler command line options"), + llvm::cl::init("")}; +}; + //===----------------------------------------------------------------------===// // Building and Registering. //===----------------------------------------------------------------------===// @@ -70,8 +116,15 @@ struct GPUToNVVMPipelineOptions void buildLowerToNVVMPassPipeline(OpPassManager &pm, const GPUToNVVMPipelineOptions &options); -/// Register all pipeleines for the `gpu` dialect. +/// Adds the GPU to XeVM pipeline to the given pass manager. Transforms main +/// dialects into XeVM targets. Begins with GPU code regions, then handles host +/// code. +void buildLowerToXeVMPassPipeline(OpPassManager &pm, + const GPUToXeVMPipelineOptions &options); + +/// Register all pipelines for the `gpu` dialect. void registerGPUToNVVMPipeline(); +void registerGPUToXeVMPipeline(); } // namespace gpu } // namespace mlir diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 68f31e6..d2df244 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -574,6 +574,30 @@ def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_b def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>; def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_fp8", [1]>; def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x32.iu4", [1]>; +// Available from gfx1250 +def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x4.f32", [1]>; +def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.bf16", [1]>; +def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.f16", [1]>; +def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x32.f16", [1]>; +def ROCDL_wmma_bf16_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x32.bf16", [1]>; +def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16f32.16x16x32.bf16", [1,5]>; +def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_fp8", [0]>; +def ROCDL_wmma_f32_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_bf8", [0]>; +def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_fp8", [0]>; +def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_bf8", [0]>; +def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_fp8", [0]>; +def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_bf8", [0]>; +def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_fp8", [0]>; +def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_bf8", [0]>; +def 
ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_fp8", [0]>; +def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_bf8", [0]>; +def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_fp8", [0]>; +def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_bf8", [0]>; +def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_fp8", [0]>; +def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_bf8", [0]>; +def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_fp8", [0]>; +def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_bf8", [0]>; +def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x64.iu8", [1]>; //===---------------------------------------------------------------------===// // LDS transpose intrinsics (available in GFX950) @@ -1143,6 +1167,7 @@ foreach smallT = [ ScaleArgInfo<ROCDL_V16BF16Type, "Bf16">, ScaleArgInfo<ROCDL_V16F32Type, "F32">, ] in { + // Up-scaling def ROCDL_CvtPkScalePk16 # largeT.nameForOp # smallT.nameForOp # Op : ROCDL_ConcreteNonMemIntrOp<"cvt.scale.pk16." # largeT.name # "." # smallT.name, [Pure], 1, [2], ["scaleSel"]>, @@ -1158,6 +1183,42 @@ foreach smallT = [ }]; } + + // Down-scaling + def ROCDL_CvtScaleF32Pk16 # smallT.nameForOp # largeT.nameForOp # Op : + ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk16." # smallT.name # "." # largeT.name, + [Pure], 1>, + Arguments<(ins largeT.type:$src, F32:$scale)> { + let results = (outs smallT.type:$res); + let summary = "Scale and convert packed " + # largeT.name # " to packed " # smallT.name ; + let description = [{ + Convert 8 packed }] # largeT.name # [{ values to packed }] + # smallT.name # [{, multiplying by the exponent part of `scale` + before doing so. This op is for gfx1250+ arch. + }]; + let assemblyFormat = [{ + attr-dict $src `,` $scale `:` type($res) + }]; + } + + def ROCDL_CvtScaleF32SrPk16 # smallT.nameForOp # largeT.nameForOp # Op : + ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.sr.pk16." # smallT.name # "." # largeT.name, + [Pure], 1>, + Arguments<(ins largeT.type:$src, I32:$seed, F32:$scale)> { + let results = (outs smallT.type:$res); + let summary = "Scale and convert packed " + # largeT.name # " to packed " # smallT.name # " with stochastic rounding"; + let description = [{ + Convert 8 packed }] # largeT.name # [{ values to packed }] + # smallT.name # [{, multiplying by the exponent part of `scale` + before doing so and apply stochastic rounding. This op is for gfx1250+ arch. + }]; + let assemblyFormat = [{ + attr-dict $src `,` $seed `,` $scale `:` type($res) + }]; + } + } // foreach largeT } // foreach smallTOp diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index ae7a085..c89fc59 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -25,7 +25,6 @@ #include "mlir/Interfaces/TilingInterface.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SmallBitVector.h" -#include "llvm/ADT/SmallSet.h" namespace mlir { namespace bufferization { @@ -621,35 +620,43 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, /// In the future, more general interfaces can be devised to encode similar /// shape evolutions and map between an op and its operands. 
SmallVector<OpFoldResult> -computePaddedShape(RewriterBase &rewriter, TypedValue<RankedTensorType> v, +computePaddedShape(OpBuilder &, TypedValue<RankedTensorType> v, AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes, const PadTilingInterfaceOptions &options); using PadSizeComputationFunction = std::function<FailureOr<SmallVector<OpFoldResult>>( - RewriterBase &, OpOperand &, ArrayRef<Range>, + OpBuilder &, OpOperand &, ArrayRef<Range>, const PadTilingInterfaceOptions &)>; /// Specific helper for Linalg ops. -FailureOr<SmallVector<OpFoldResult>> computeIndexingMapOpInterfacePaddedShape( - RewriterBase &rewriter, OpOperand &operandToPad, - ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options); +FailureOr<SmallVector<OpFoldResult>> +computeIndexingMapOpInterfacePaddedShape(OpBuilder &, OpOperand &operandToPad, + ArrayRef<Range> iterationDomain, + const PadTilingInterfaceOptions &); + +/// Operations and values created in the process of padding a TilingInterface +/// operation. +struct PadTilingInterfaceResult { + /// The operands of the padded op. + SmallVector<tensor::PadOp> padOps; + /// The padded op, a clone of `toPad` with padded operands. + TilingInterface paddedOp; + /// Slices of the padded op's results, same types as `toPad`. + SmallVector<Value> replacements; +}; -/// Pad the iterator dimensions `options.paddingDimensions` of `opToPad`. -/// +/// Pad the iterator dimensions of `toPad`. /// * "options.paddingSizes" indicates that each padding dimension should be /// padded to the specified padding size. /// * "options.padToMultipleOf" indicates that the paddingSizes should be // interpreted as the bounding box (dynamic) value to pad to. /// * Use "options.paddingValues" to set the padding value of the created // tensor::PadOp. -/// * The tensor::PadOp is returned on success. - -FailureOr<TilingInterface> -rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad, - const PadTilingInterfaceOptions &constOptions, - SmallVector<tensor::PadOp> &padOps, - const PadSizeComputationFunction &computePaddingSizeFun = +FailureOr<PadTilingInterfaceResult> +rewriteAsPaddedOp(OpBuilder &, TilingInterface toPad, + PadTilingInterfaceOptions options, + const PadSizeComputationFunction & = &computeIndexingMapOpInterfacePaddedShape); namespace detail { diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td index 3143ab7..99b22e5 100644 --- a/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td +++ b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td @@ -220,8 +220,6 @@ def YieldOp : SMTOp<"yield", [ Pure, Terminator, ReturnLike, - ParentOneOf<["smt::SolverOp", "smt::CheckOp", - "smt::ForallOp", "smt::ExistsOp"]>, ]> { let summary = "terminator operation for various regions of SMT operations"; let arguments = (ins Variadic<AnyType>:$values); diff --git a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td index b9d7163..5e68f75e 100644 --- a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td +++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td @@ -494,7 +494,9 @@ def Shard_AllGatherOp : Shard_CollectiveCommunicationOpBase<"all_gather", [ ]> { let summary = "All-gather over a device grid."; let description = [{ - Gathers along the `gather_axis` tensor axis. + Concatenates all tensor slices from a device group defined by `grid_axes` along + the tensor dimension `gather_axis` and replicates the result across all devices + in the group. 
Example: ```mlir @@ -546,10 +548,13 @@ def Shard_AllReduceOp : Shard_CollectiveCommunicationOpBase<"all_reduce", [ SameOperandsAndResultShape]> { let summary = "All-reduce over a device grid."; let description = [{ - The accumulation element type is specified by the result type and - it does not need to match the input element type. - The input element is converted to the result element type before - performing the reduction. + Reduces the input tensor across all devices within the groups defined by + `grid_axes`, using the specified reduction method. The operation performs an + element-wise reduction over the tensor slices from all devices in each group. + Each device in a group receives a replicated copy of the reduction result. + The accumulation element type is determined by the result type and does not + need to match the input element type. Before performing the reduction, each + input element is converted to the result element type. Attributes: `reduction`: Indicates the reduction method. @@ -583,13 +588,15 @@ def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [ SameOperandsAndResultElementType, SameOperandsAndResultRank ]> { - let summary = "All-slice over a device grid. This is the inverse of all-gather."; + let summary = "All-slice over a device grid."; let description = [{ - Slice along the `slice_axis` tensor axis. - This operation can be thought of as the inverse of all-gather. - Technically, it is not required that all processes have the same input tensor. - Each process will slice a piece of its local tensor based on its in-group device index. - The operation does not communicate data between devices. + Within each device group defined by `grid_axes`, slices the input tensor along + the `slice_axis` dimension. It can be viewed as the inverse of an all-gather if + the input data is replicated along the `slice_axis`. + Each process simply crops its local data to the slice corresponding to its + in-group device index. + Notice: `AllSliceOp` does not involve any communication between devices and + devices within a group may not have replicated input data. Example: ```mlir @@ -610,7 +617,7 @@ def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [ ``` Result: ``` - gather tensor + slice tensor axis 1 ------------> +-------+-------+ @@ -646,8 +653,10 @@ def Shard_AllToAllOp : Shard_CollectiveCommunicationOpBase<"all_to_all", [ SameOperandsAndResultRank]> { let summary = "All-to-all over a device grid."; let description = [{ - Performs an all-to-all on tensor pieces split along `split_axis`. - The resulting pieces are concatenated along `concat_axis` on ech device. + Each participant logically splits its input along split_axis, + then scatters the resulting pieces across the group defined by `grid_axes`. + After receiving data pieces from other participants' scatters, + it concatenates them along concat_axis to produce the final result. Example: ``` @@ -702,10 +711,9 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [ ]> { let summary = "Broadcast over a device grid."; let description = [{ - Broadcast the tensor on `root` to all devices in each respective group. - The operation broadcasts along grid axes `grid_axes`. - The `root` device specifies the in-group multi-index that is broadcast to - all other devices in the group. + Copies the input tensor on `root` to all devices in each group defined by + `grid_axes`. The `root` device is defined by its in-group multi-index. 
+ The contents of input tensors on non-root devices are ignored. Example: ``` @@ -722,7 +730,7 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [ +-------+-------+ | broadcast device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1) | along axis 0 +-------+-------+ ↓ - device (1, 0) -> | | | <- device (1, 1) + device (1, 0) -> | * * | * * | <- device (1, 1) +-------+-------+ ``` @@ -758,11 +766,10 @@ def Shard_GatherOp : Shard_CollectiveCommunicationOpBase<"gather", [ ]> { let summary = "Gather over a device grid."; let description = [{ - Gathers on device `root` along the `gather_axis` tensor axis. - `root` specifies the coordinates of a device along `grid_axes`. - It uniquely identifies the root device for each device group. - The result tensor on non-root devices is undefined. - Using it will result in undefined behavior. + Concatenates all tensor slices from a device group defined by `grid_axes` along + the tensor dimension `gather_axis` and returns the resulting tensor on each + `root` device. The result on all other (non-root) devices is undefined. + The `root` device is defined by its in-group multi-index. Example: ```mlir @@ -821,7 +828,9 @@ def Shard_RecvOp : Shard_CollectiveCommunicationOpBase<"recv", [ ]> { let summary = "Send over a device grid."; let description = [{ - Receive from a device within a device group. + Receive tensor from device `source`, which is defined by its in-group + multi-index. The groups are defined by `grid_axes`. + The content of input tensor is ignored. }]; let arguments = !con(commonArgs, (ins AnyNon0RankedTensor:$input, @@ -845,13 +854,15 @@ def Shard_ReduceOp : Shard_CollectiveCommunicationOpBase<"reduce", [ ]> { let summary = "Reduce over a device grid."; let description = [{ - Reduces on device `root` within each device group. - `root` specifies the coordinates of a device along `grid_axes`. - It uniquely identifies the root device within its device group. - The accumulation element type is specified by the result type and - it does not need to match the input element type. - The input element is converted to the result element type before - performing the reduction. + Reduces the input tensor across all devices within the groups defined by + `grid_axes`, using the specified reduction method. The operation performs an + element-wise reduction over the tensor slices from all devices in each group. + The reduction result will be returned on the `root` device of each group. + It is undefined on all other (non-root) devices. + The `root` device is defined by its in-group multi-index. + The accumulation element type is determined by the result type and does not + need to match the input element type. Before performing the reduction, each + input element is converted to the result element type. Attributes: `reduction`: Indicates the reduction method. @@ -886,16 +897,18 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter" SameOperandsAndResultRank]> { let summary = "Reduce-scatter over a device grid."; let description = [{ - After the reduction, the result is scattered within each device group. - The tensor is split along `scatter_axis` and the pieces distributed - across the device group. + Reduces the input tensor across all devices within the groups defined by + `grid_axes` using the specified reduction method. The reduction is performed + element-wise across the tensor pieces from all devices in the group. 
+ After reduction, the reduction result is scattered (split and distributed) + across the device group along `scatter_axis`. Example: ``` shard.grid @grid0(shape = 2x2) ... %1 = shard.reduce_scatter %0 on @grid0 grid_axes = [1] reduction = <max> scatter_axis = 0 - : tensor<3x4xf32> -> tensor<1x4xf64> + : tensor<2x2xf32> -> tensor<1x2xf64> ``` Input: ``` @@ -916,13 +929,13 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter" Result: ``` +-------+ - | 6 8 | <- devices (0, 0) + | 5 6 | <- devices (0, 0) +-------+ - | 10 12 | <- devices (0, 1) + | 7 8 | <- devices (0, 1) +-------+ - | 22 24 | <- devices (1, 0) + | 13 14 | <- devices (1, 0) +-------+ - | 26 28 | <- devices (1, 1) + | 15 16 | <- devices (1, 1) +-------+ ``` }]; @@ -950,8 +963,10 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [ ]> { let summary = "Scatter over a device grid."; let description = [{ - For each device group split the input tensor on the `root` device along - axis `scatter_axis` and scatter the parts across the group devices. + For each device group defined by `grid_axes`, the input tensor on the `root` + device is split along axis `scatter_axis` and distributed across the group. + The content of the input on all other (non-root) devices is ignored. + The `root` device is defined by its in-group multi-index. Example: ``` @@ -968,8 +983,8 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [ (0, 1) ↓ +-------+-------+ | scatter tensor - device (0, 0) -> | | | | axis 0 - | | | ↓ + device (0, 0) -> | * * | * * | | axis 0 + | * * | * * | ↓ +-------+-------+ device (1, 0) -> | 1 2 | 5 6 | | 3 4 | 7 8 | @@ -1018,7 +1033,8 @@ def Shard_SendOp : Shard_CollectiveCommunicationOpBase<"send", [ ]> { let summary = "Send over a device grid."; let description = [{ - Send from one device to another within a device group. + Send input tensor to device `destination`, which is defined by its in-group + multi-index. The groups are defined by `grid_axes`. }]; let arguments = !con(commonArgs, (ins AnyNon0RankedTensor:$input, @@ -1043,12 +1059,11 @@ def Shard_ShiftOp : Shard_CollectiveCommunicationOpBase<"shift", [ ]> { let summary = "Shift over a device grid."; let description = [{ - Within each device group shift along grid axis `shift_axis` by an offset - `offset`. - The result on devices that do not have a corresponding source is undefined. - `shift_axis` must be one of `grid_axes`. - If the `rotate` attribute is present, - instead of a shift a rotation is done. + Within each device group defined by `grid_axes`, shifts input tensors along the + device grid's axis `shift_axis` by the specified offset. The `shift_axis` must + be one of the `grid_axes`. If the `rotate` attribute is set, the shift is circular. + That is, the offset wraps around according to the group size along `shift_axis`. + Otherwise, the results on devices without a corresponding source are undefined. 
Example: ``` diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h index fc69b03..f6353a9 100644 --- a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h +++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h @@ -10,6 +10,7 @@ #define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H #include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/SMT/IR/SMTOps.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td index b987cb3..9d9783a 100644 --- a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td +++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td @@ -16,7 +16,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td" def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [ DeclareOpInterfaceMethods<TransformOpInterface>, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, - NoTerminator + SingleBlockImplicitTerminator<"::mlir::smt::YieldOp"> ]> { let cppNamespace = [{ mlir::transform::smt }]; @@ -24,14 +24,20 @@ def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [ let description = [{ Allows expressing constraints on params using the SMT dialect. - Each Transform dialect param provided as an operand has a corresponding + Each Transform-dialect param provided as an operand has a corresponding argument of SMT-type in the region. The SMT-Dialect ops in the region use - these arguments as operands. + these params-as-SMT-vars as operands, thereby expressing relevant + constraints on their allowed values. + + Computations w.r.t. passed-in params can also be expressed through the + region's SMT-ops. Namely, the constraints express relationships to other + SMT-variables which can then be yielded from the region (with `smt.yield`). The semantics of this op is that all the ops in the region together express a constraint on the params-interpreted-as-smt-vars. The op fails in case the expressed constraint is not satisfiable per SMTLIB semantics. Otherwise the - op succeeds. + op succeeds and any one satisfying assignment is used to map the + SMT-variables yielded in the region to `transform.param`s. --- @@ -42,9 +48,10 @@ def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [ }]; let arguments = (ins Variadic<TransformParamTypeInterface>:$params); + let results = (outs Variadic<TransformParamTypeInterface>:$results); let regions = (region SizedRegion<1>:$body); let assemblyFormat = - "`(` $params `)` attr-dict `:` type(operands) $body"; + "`(` $params `)` attr-dict `:` functional-type(operands, results) $body"; let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td index b80ee2c..e9425e8 100644 --- a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td +++ b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td @@ -43,9 +43,41 @@ class WasmSSA_BlockLikeOp<string mnemonic, string summaryStr> : let assemblyFormat = "(`(`$inputs^`)` `:` type($inputs))? 
attr-dict `:` $body `>` $target"; } -def WasmSSA_BlockOp : WasmSSA_BlockLikeOp<"block", "Create a nesting level"> {} +def WasmSSA_BlockOp : WasmSSA_BlockLikeOp< + "block", + "Create a nesting level with a label at its exit."> { + let description = [{ + Defines a Wasm block, creating a new nested scope. + A block contains a body region and an optional list of input values. + Control can enter the block and later branch out to the block target. + Example: + + ```mlir + + wasmssa.block { + + // instructions + + } > ^successor + }]; +} + +def WasmSSA_LoopOp : WasmSSA_BlockLikeOp< + "loop", + "Create a nesting level that define its entry as jump target."> { + let description = [{ + Represents a Wasm loop construct. This defines a nesting level with + a label at the entry of the region. -def WasmSSA_LoopOp : WasmSSA_BlockLikeOp<"loop", "Create a nesting level similar to Block Op, except that it has itself as a successor."> {} + Example: + + ```mlir + + wasmssa.loop { + + } > ^successor + }]; +} def WasmSSA_BlockReturnOp : WasmSSA_Op<"block_return", [Terminator, DeclareOpInterfaceMethods<LabelBranchingOpInterface>]> { @@ -55,9 +87,16 @@ def WasmSSA_BlockReturnOp : WasmSSA_Op<"block_return", [Terminator, ::mlir::Block* getTarget(); }]; let description = [{ - Marks a return from the current block. + Escape from the current nesting level and return the control flow to its successor. + Optionally, mark the arguments that should be transfered to the successor block. - Example: + This shouldn't be confused with branch operations that targets the label defined + by the nesting level operation. + + For instance, a `wasmssa.block_return` in a loop will give back control to the + successor of the loop, where a `branch` targeting the loop will flow back to the entry block of the loop. + + Example: ```mlir wasmssa.block_return @@ -127,12 +166,18 @@ def WasmSSA_FuncOp : WasmSSA_Op<"func", [ - Arguments of the entry block of type `!wasm<local T>`, with T the corresponding type in the function type. + By default, `wasmssa.func` have nested visibility. Functions exported by the module + are marked with the exported attribute. This gives them public visibility. + Example: ```mlir - // A simple function with no arguments that returns a float32 + // Internal function with no arguments that returns a float32 wasmssa.func @my_f32_func() -> f32 + // Exported function with no arguments that returns a float32 + wasmssa.func exported @my_f32_func() -> f32 + // A function that takes a local ref argument wasmssa.func @i64_wrap(%a: !wasmssa<local ref to i64>) -> i32 ``` @@ -141,7 +186,7 @@ def WasmSSA_FuncOp : WasmSSA_Op<"func", [ WasmSSA_FuncTypeAttr: $functionType, OptionalAttr<DictArrayAttr>:$arg_attrs, OptionalAttr<DictArrayAttr>:$res_attrs, - DefaultValuedAttr<StrAttr, "\"nested\"">:$sym_visibility); + UnitAttr: $exported); let regions = (region AnyRegion: $body); let extraClassDeclaration = [{ @@ -162,6 +207,12 @@ def WasmSSA_FuncOp : WasmSSA_Op<"func", [ /// Returns the result types of this function. ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); } + + ::mlir::SymbolTable::Visibility getVisibility() { + return getExported() ? 
+ ::mlir::SymbolTable::Visibility::Public : + ::mlir::SymbolTable::Visibility::Nested; + }; }]; let builders = [ @@ -207,8 +258,7 @@ def WasmSSA_FuncImportOp : WasmSSA_Op<"import_func", [ StrAttr: $importName, WasmSSA_FuncTypeAttr: $type, OptionalAttr<DictArrayAttr>:$arg_attrs, - OptionalAttr<DictArrayAttr>:$res_attrs, - OptionalAttr<StrAttr>:$sym_visibility); + OptionalAttr<DictArrayAttr>:$res_attrs); let extraClassDeclaration = [{ bool isDeclaration() const { return true; } @@ -221,6 +271,10 @@ def WasmSSA_FuncImportOp : WasmSSA_Op<"import_func", [ ::llvm::ArrayRef<Type> getResultTypes() { return getType().getResults(); } + + ::mlir::SymbolTable::Visibility getVisibility() { + return ::mlir::SymbolTable::Visibility::Nested; + }; }]; let builders = [ OpBuilder<(ins "StringRef":$symbol, @@ -238,30 +292,41 @@ def WasmSSA_GlobalOp : WasmSSA_Op<"global", [ let arguments = (ins SymbolNameAttr: $sym_name, WasmSSA_ValTypeAttr: $type, UnitAttr: $isMutable, - OptionalAttr<StrAttr>:$sym_visibility); + UnitAttr: $exported); let description = [{ WebAssembly global variable. Body contains the initialization instructions for the variable value. The body must contain only instructions considered `const` in a webassembly context, such as `wasmssa.const` or `global.get`. + By default, `wasmssa.global` have nested visibility. Global exported by the module + are marked with the exported attribute. This gives them public visibility. + Example: ```mlir - // Define a global_var, a mutable i32 global variable equal to 10. - wasmssa.global @global_var i32 mutable nested : { + // Define module_global_var, an internal mutable i32 global variable equal to 10. + wasmssa.global @module_global_var i32 mutable : { %[[VAL_0:.*]] = wasmssa.const 10 : i32 wasmssa.return %[[VAL_0]] : i32 } + + // Define global_var, an exported constant i32 global variable equal to 42. + wasmssa.global @global_var i32 : { + %[[VAL_0:.*]] = wasmssa.const 42 : i32 + wasmssa.return %[[VAL_0]] : i32 + } ``` }]; let regions = (region AnyRegion: $initializer); - let builders = [ - OpBuilder<(ins "StringRef":$symbol, - "Type": $type, - "bool": $isMutable)> - ]; + let extraClassDeclaration = [{ + ::mlir::SymbolTable::Visibility getVisibility() { + return getExported() ? + ::mlir::SymbolTable::Visibility::Public : + ::mlir::SymbolTable::Visibility::Nested; + }; + }]; let hasCustomAssemblyFormat = 1; } @@ -283,18 +348,14 @@ def WasmSSA_GlobalImportOp : WasmSSA_Op<"import_global", [ StrAttr: $moduleName, StrAttr: $importName, WasmSSA_ValTypeAttr: $type, - UnitAttr: $isMutable, - OptionalAttr<StrAttr>:$sym_visibility); + UnitAttr: $isMutable); let extraClassDeclaration = [{ bool isDeclaration() const { return true; } + + ::mlir::SymbolTable::Visibility getVisibility() { + return ::mlir::SymbolTable::Visibility::Nested; + }; }]; - let builders = [ - OpBuilder<(ins "StringRef":$symbol, - "StringRef":$moduleName, - "StringRef":$importName, - "Type": $type, - "bool": $isMutable)> - ]; let hasCustomAssemblyFormat = 1; } @@ -442,23 +503,33 @@ def WasmSSA_MemOp : WasmSSA_Op<"memory", [Symbol]> { Define a memory to be used by the program. Multiple memories can be defined in the same module. + By default, `wasmssa.memory` have nested visibility. Memory exported by + the module are marked with the exported attribute. This gives them public + visibility. 
+ Example: ```mlir - // Define the `mem_0` memory with defined bounds of 0 -> 65536 + // Define the `mem_0` (internal) memory with defined size bounds of [0:65536] wasmssa.memory @mem_0 !wasmssa<limit[0:65536]> + + // Define the `mem_1` exported memory with minimal size of 512 + wasmssa.memory exported @mem_1 !wasmssa<limit[512:]> ``` }]; let arguments = (ins SymbolNameAttr: $sym_name, WasmSSA_LimitTypeAttr: $limits, - OptionalAttr<StrAttr>:$sym_visibility); - let builders = [ - OpBuilder<(ins - "::llvm::StringRef":$symbol, - "wasmssa::LimitType":$limit)> - ]; + UnitAttr: $exported); - let assemblyFormat = "$sym_name custom<WasmVisibility>($sym_visibility) $limits attr-dict"; + let extraClassDeclaration = [{ + ::mlir::SymbolTable::Visibility getVisibility() { + return getExported() ? + ::mlir::SymbolTable::Visibility::Public : + ::mlir::SymbolTable::Visibility::Nested; + }; + }]; + + let assemblyFormat = "(`exported` $exported^)? $sym_name $limits attr-dict"; } def WasmSSA_MemImportOp : WasmSSA_Op<"import_mem", [Symbol, ImportOpInterface]> { @@ -476,16 +547,13 @@ def WasmSSA_MemImportOp : WasmSSA_Op<"import_mem", [Symbol, ImportOpInterface]> let arguments = (ins SymbolNameAttr: $sym_name, StrAttr: $moduleName, StrAttr: $importName, - WasmSSA_LimitTypeAttr: $limits, - OptionalAttr<StrAttr>:$sym_visibility); + WasmSSA_LimitTypeAttr: $limits); let extraClassDeclaration = [{ - bool isDeclaration() const { return true; } + bool isDeclaration() const { return true; } + ::mlir::SymbolTable::Visibility getVisibility() { + return ::mlir::SymbolTable::Visibility::Nested; + }; }]; - let builders = [OpBuilder<(ins - "::llvm::StringRef":$symbol, - "::llvm::StringRef":$moduleName, - "::llvm::StringRef":$importName, - "wasmssa::LimitType":$limits)>]; let assemblyFormat = "$importName `from` $moduleName `as` $sym_name attr-dict"; } @@ -493,11 +561,15 @@ def WasmSSA_TableOp : WasmSSA_Op<"table", [Symbol]> { let summary= "WebAssembly table value"; let arguments = (ins SymbolNameAttr: $sym_name, WasmSSA_TableTypeAttr: $type, - OptionalAttr<StrAttr>:$sym_visibility); - let builders = [OpBuilder<(ins - "::llvm::StringRef":$symbol, - "wasmssa::TableType":$type)>]; - let assemblyFormat = "$sym_name custom<WasmVisibility>($sym_visibility) $type attr-dict"; + UnitAttr: $exported); + let extraClassDeclaration = [{ + ::mlir::SymbolTable::Visibility getVisibility() { + return getExported() ? + ::mlir::SymbolTable::Visibility::Public : + ::mlir::SymbolTable::Visibility::Nested; + }; + }]; + let assemblyFormat = "(`exported` $exported^)? 
$sym_name $type attr-dict"; } def WasmSSA_TableImportOp : WasmSSA_Op<"import_table", [Symbol, ImportOpInterface]> { @@ -515,17 +587,14 @@ def WasmSSA_TableImportOp : WasmSSA_Op<"import_table", [Symbol, ImportOpInterfac let arguments = (ins SymbolNameAttr: $sym_name, StrAttr: $moduleName, StrAttr: $importName, - WasmSSA_TableTypeAttr: $type, - OptionalAttr<StrAttr>:$sym_visibility); + WasmSSA_TableTypeAttr: $type); let extraClassDeclaration = [{ bool isDeclaration() const { return true; } + ::mlir::SymbolTable::Visibility getVisibility() { + return ::mlir::SymbolTable::Visibility::Nested; + }; }]; let assemblyFormat = "$importName `from` $moduleName `as` $sym_name attr-dict"; - let builders = [OpBuilder<(ins - "::llvm::StringRef":$symbol, - "::llvm::StringRef":$moduleName, - "::llvm::StringRef":$importName, - "wasmssa::TableType":$type)>]; } def WasmSSA_ReturnOp : WasmSSA_Op<"return", [Terminator]> { diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index 6b4e3dd..8427ba5 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -623,6 +623,14 @@ class VectorOfLengthAndType<list<int> allowedLengths, VectorOfNonZeroRankOf<allowedTypes>.summary # VectorOfLength<allowedLengths>.summary, "::mlir::VectorType">; +class FixedVectorOfShapeAndType<list<int> shape, Type elType>: ShapedContainerType< + [elType], + And<[IsVectorOfShape<shape>, IsFixedVectorOfAnyRankTypePred]>, + "vector<" # !interleave(shape, "x") # "x" # elType # ">", + "::mlir::VectorType">, + BuildableType<"::mlir::VectorType::get({" # !interleave(shape, " ,") # "} , " # elType.builderCall # " );">; + + // Any fixed-length vector where the number of elements is from the given // `allowedLengths` list and the type is from the given `allowedTypes` list class FixedVectorOfLengthAndType<list<int> allowedLengths, diff --git a/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h index 21adde8..cd9ef5b 100644 --- a/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h +++ b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h @@ -19,6 +19,14 @@ namespace mlir { struct WasmBinaryEncoding { /// Byte encodings for Wasm instructions. struct OpCode { + // Control instructions. + static constexpr std::byte block{0x02}; + static constexpr std::byte loop{0x03}; + static constexpr std::byte ifOpCode{0x04}; + static constexpr std::byte elseOpCode{0x05}; + static constexpr std::byte branchIf{0x0D}; + static constexpr std::byte call{0x10}; + // Locals, globals, constants. static constexpr std::byte localGet{0x20}; static constexpr std::byte localSet{0x21}; @@ -29,6 +37,42 @@ struct WasmBinaryEncoding { static constexpr std::byte constFP32{0x43}; static constexpr std::byte constFP64{0x44}; + // Comparisons. 
+ static constexpr std::byte eqzI32{0x45}; + static constexpr std::byte eqI32{0x46}; + static constexpr std::byte neI32{0x47}; + static constexpr std::byte ltSI32{0x48}; + static constexpr std::byte ltUI32{0x49}; + static constexpr std::byte gtSI32{0x4A}; + static constexpr std::byte gtUI32{0x4B}; + static constexpr std::byte leSI32{0x4C}; + static constexpr std::byte leUI32{0x4D}; + static constexpr std::byte geSI32{0x4E}; + static constexpr std::byte geUI32{0x4F}; + static constexpr std::byte eqzI64{0x50}; + static constexpr std::byte eqI64{0x51}; + static constexpr std::byte neI64{0x52}; + static constexpr std::byte ltSI64{0x53}; + static constexpr std::byte ltUI64{0x54}; + static constexpr std::byte gtSI64{0x55}; + static constexpr std::byte gtUI64{0x56}; + static constexpr std::byte leSI64{0x57}; + static constexpr std::byte leUI64{0x58}; + static constexpr std::byte geSI64{0x59}; + static constexpr std::byte geUI64{0x5A}; + static constexpr std::byte eqF32{0x5B}; + static constexpr std::byte neF32{0x5C}; + static constexpr std::byte ltF32{0x5D}; + static constexpr std::byte gtF32{0x5E}; + static constexpr std::byte leF32{0x5F}; + static constexpr std::byte geF32{0x60}; + static constexpr std::byte eqF64{0x61}; + static constexpr std::byte neF64{0x62}; + static constexpr std::byte ltF64{0x63}; + static constexpr std::byte gtF64{0x64}; + static constexpr std::byte leF64{0x65}; + static constexpr std::byte geF64{0x66}; + // Numeric operations. static constexpr std::byte clzI32{0x67}; static constexpr std::byte ctzI32{0x68}; @@ -93,6 +137,33 @@ struct WasmBinaryEncoding { static constexpr std::byte maxF64{0xA5}; static constexpr std::byte copysignF64{0xA6}; static constexpr std::byte wrap{0xA7}; + + // Conversion operations + static constexpr std::byte extendS{0xAC}; + static constexpr std::byte extendU{0xAD}; + static constexpr std::byte convertSI32F32{0xB2}; + static constexpr std::byte convertUI32F32{0xB3}; + static constexpr std::byte convertSI64F32{0xB4}; + static constexpr std::byte convertUI64F32{0xB5}; + + static constexpr std::byte demoteF64ToF32{0xB6}; + + static constexpr std::byte convertSI32F64{0xB7}; + static constexpr std::byte convertUI32F64{0xB8}; + static constexpr std::byte convertSI64F64{0xB9}; + static constexpr std::byte convertUI64F64{0xBA}; + + static constexpr std::byte promoteF32ToF64{0xBB}; + static constexpr std::byte reinterpretF32AsI32{0xBC}; + static constexpr std::byte reinterpretF64AsI64{0xBD}; + static constexpr std::byte reinterpretI32AsF32{0xBE}; + static constexpr std::byte reinterpretI64AsF64{0xBF}; + + static constexpr std::byte extendI328S{0xC0}; + static constexpr std::byte extendI3216S{0xC1}; + static constexpr std::byte extendI648S{0xC2}; + static constexpr std::byte extendI6416S{0xC3}; + static constexpr std::byte extendI6432S{0xC4}; }; /// Byte encodings of types in Wasm binaries |
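The last hunk above adds raw opcode byte constants to `WasmBinaryEncoding::OpCode`. As a rough sketch of how such constants are typically consumed when decoding a Wasm instruction stream — the `opcodeName` helper and the re-declared subset of constants below are illustrative assumptions, not part of this patch:

```cpp
#include <cstddef>
#include <iostream>
#include <optional>
#include <string_view>

// Illustrative subset of the opcode constants added in
// mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h, re-declared here only
// to keep the sketch self-contained.
struct OpCode {
  static constexpr std::byte block{0x02};
  static constexpr std::byte loop{0x03};
  static constexpr std::byte call{0x10};
  static constexpr std::byte eqI32{0x46};
  static constexpr std::byte promoteF32ToF64{0xBB};
};

// Map a raw opcode byte to a human-readable mnemonic; bytes this sketch does
// not know about yield std::nullopt.
std::optional<std::string_view> opcodeName(std::byte b) {
  if (b == OpCode::block)
    return "block";
  if (b == OpCode::loop)
    return "loop";
  if (b == OpCode::call)
    return "call";
  if (b == OpCode::eqI32)
    return "i32.eq";
  if (b == OpCode::promoteF32ToF64)
    return "f64.promote_f32";
  return std::nullopt;
}

int main() {
  std::byte raw{0x46}; // First byte of an encoded `i32.eq` instruction.
  if (std::optional<std::string_view> name = opcodeName(raw))
    std::cout << "opcode 0x46 decodes to " << *name << "\n";
  return 0;
}
```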