Diffstat (limited to 'mlir/include')
-rw-r--r--  mlir/include/mlir-c/Rewrite.h                                          2
-rw-r--r--  mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h                 8
-rw-r--r--  mlir/include/mlir/Conversion/Passes.td                                 9
-rw-r--r--  mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td                         93
-rw-r--r--  mlir/include/mlir/Dialect/Affine/IR/AffineOps.td                       1
-rw-r--r--  mlir/include/mlir/Dialect/Arith/IR/ArithOps.td                        94
-rw-r--r--  mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h                      57
-rw-r--r--  mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td                          61
-rw-r--r--  mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h              37
-rw-r--r--  mlir/include/mlir/Dialect/SMT/IR/SMTOps.td                             2
-rw-r--r--  mlir/include/mlir/Dialect/Shard/IR/ShardOps.td                       117
-rw-r--r--  mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h     1
-rw-r--r--  mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td   17
-rw-r--r--  mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td                   175
-rw-r--r--  mlir/include/mlir/IR/CommonTypeConstraints.td                          8
-rw-r--r--  mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h                    71
16 files changed, 585 insertions, 168 deletions
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 2db1d84..fe42a20 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -352,7 +352,7 @@ typedef struct {
/// Create a rewrite pattern that matches the operation
/// with the given rootName, corresponding to mlir::OpRewritePattern.
-MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate(
+MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePatternCreate(
MlirStringRef rootName, unsigned benefit, MlirContext context,
MlirRewritePatternCallbacks callbacks, void *userData,
size_t nGeneratedNames, MlirStringRef *generatedNames);
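For C API users migrating call sites to the corrected spelling, a minimal sketch follows; it assumes only the signature above, and the root operation name "arith.addi" as well as the caller-provided callbacks are illustrative.

```cpp
// Hedged sketch of a call site after the rename; the callbacks struct is
// assumed to be populated by the caller.
static MlirRewritePattern
createMyPattern(MlirContext ctx, MlirRewritePatternCallbacks callbacks) {
  return mlirOpRewritePatternCreate(
      mlirStringRefCreateFromCString("arith.addi"), /*benefit=*/1, ctx,
      callbacks, /*userData=*/nullptr, /*nGeneratedNames=*/0,
      /*generatedNames=*/nullptr);
}
```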
diff --git a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
index 46573e79..60f1888 100644
--- a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
+++ b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
@@ -9,6 +9,7 @@
#define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/IR/PatternMatch.h"
#include <memory>
@@ -19,8 +20,11 @@ class Pass;
#include "mlir/Conversion/Passes.h.inc"
/// Populate the given list with patterns that convert from Math to ROCDL calls.
-void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter,
- RewritePatternSet &patterns);
+/// `chipset` specifies the AMDGPU chipset to target. If `std::nullopt`,
+/// none of the chipset-dependent patterns are added.
+void populateMathToROCDLConversionPatterns(
+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ std::optional<amdgpu::Chipset> chipset);
} // namespace mlir
#endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
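A minimal usage sketch for the updated signature is shown below; the helper name and the chipset string "gfx942" are illustrative, and only the declaration above plus `amdgpu::Chipset::parse` are assumed.

```cpp
// Hypothetical helper demonstrating the new optional chipset argument.
static void addMathToROCDLPatterns(const mlir::LLVMTypeConverter &typeConverter,
                                   mlir::RewritePatternSet &patterns) {
  // Parse the target chipset. Passing std::nullopt instead skips all
  // chipset-dependent patterns.
  mlir::FailureOr<mlir::amdgpu::Chipset> maybeChipset =
      mlir::amdgpu::Chipset::parse("gfx942");
  std::optional<mlir::amdgpu::Chipset> chipset = std::nullopt;
  if (mlir::succeeded(maybeChipset))
    chipset = *maybeChipset;
  mlir::populateMathToROCDLConversionPatterns(typeConverter, patterns, chipset);
}
```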
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 25e9d34..70e3e45 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -778,6 +778,10 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
let summary = "Convert Math dialect to ROCDL library calls";
let description = [{
This pass converts supported Math ops to ROCDL library calls.
+
+ The chipset option specifies the target AMDGPU architecture. If the chipset
+ is empty, none of the chipset-dependent patterns are added, and the pass
+ will not attempt to parse the chipset.
}];
let dependentDialects = [
"arith::ArithDialect",
@@ -785,6 +789,9 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
"ROCDL::ROCDLDialect",
"vector::VectorDialect",
];
+ let options = [Option<"chipset", "chipset", "std::string",
+ /*default=*/"\"\"",
+ "Chipset that these operations will run on">];
}
//===----------------------------------------------------------------------===//
@@ -800,7 +807,7 @@ def ConvertMathToSPIRVPass : Pass<"convert-math-to-spirv"> {
// MathToXeVM
//===----------------------------------------------------------------------===//
-def ConvertMathToXeVM : Pass<"convert-math-to-xevm", "ModuleOp"> {
+def ConvertMathToXeVM : Pass<"convert-math-to-xevm"> {
let summary =
"Convert (fast) math operations to native XeVM/SPIRV equivalents";
let description = [{
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 8370d35..7184de9 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -112,6 +112,97 @@ def AMDGPU_ExtPackedFp8Op :
}];
}
+def IsValidBlockSize: AttrConstraint<
+ CPred<"::llvm::is_contained({16, 32}, ::llvm::cast<::mlir::IntegerAttr>($_self).getInt())">,
+ "whose value is 16 or 32">;
+
+def AMDGPU_ScaledExtPacked816Op
+ : AMDGPU_Op<"scaled_ext_packed816", [Pure, AllShapesMatch<["source", "res"]>]>,
+ Arguments<(
+ ins AnyTypeOf<[FixedVectorOfShapeAndType<[8], F4E2M1FN>,
+ FixedVectorOfShapeAndType<[8], F8E4M3FN>,
+ FixedVectorOfShapeAndType<[8], F8E5M2>,
+ FixedVectorOfShapeAndType<[16], F6E2M3FN>,
+ FixedVectorOfShapeAndType<[16], F6E3M2FN>]>:$source,
+ FixedVectorOfShapeAndType<[4], F8E8M0FNU>:$scale,
+ ConfinedAttr<I32Attr, [IsValidBlockSize]>:$blockSize,
+ ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<1>]>:$firstScaleLane,
+ ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<2>]>:$firstScaleByte)>,
+ Results<(
+ outs AnyTypeOf<[FixedVectorOfShapeAndType<[8], F32>,
+ FixedVectorOfShapeAndType<[8], F16>,
+ FixedVectorOfShapeAndType<[8], BF16>,
+ FixedVectorOfShapeAndType<[16], F32>,
+ FixedVectorOfShapeAndType<[16], F16>,
+ FixedVectorOfShapeAndType<[16], BF16>]>:$res)> {
+
+ let summary = "Extend a vector of packed floating point values";
+
+ let description = [{
+ The scales applied to the input microfloats are stored in two bytes that
+ come from the `scale` input, provided in the *half* of the wave identified
+ by `firstScaleLane`. The pair of bytes used is selected by
+ `firstScaleByte`. The 16 vectors in consecutive lanes starting from
+ `firstScaleLane` (which we'll call the scale vectors) are used by both
+ halves of the wave (with lane L reading from the (L % 16)'th scale vector),
+ but each half uses a different byte.
+
+ When the block size is 32, `firstScaleByte` can be either 0 or 2,
+ selecting halves of the scale vectors. Lanes 0-15 will read from
+ `firstScaleByte` and lanes 16-31 will read from `firstScaleByte` + 1.
+ For example:
+ ```mlir
+ // Input: 8-element vector of F8E4M3FN, converting to F32
+ // Lanes 0-15 read from byte 0, lanes 16-31 read from byte 1
+ %result = amdgpu.scaled_ext_packed816 %source scale(%scales)
+ blockSize(32) firstScaleLane(0) firstScaleByte(0)
+ : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf32>
+
+ // Input: 16-element vector of F6E2M3FN, converting to F16
+ // Lanes 0-15 read from byte 2, lanes 16-31 read from byte 3
+ %result = amdgpu.scaled_ext_packed816 %source scale(%scales)
+ blockSize(32) firstScaleLane(1) firstScaleByte(2)
+ : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
+ ```
+
+ However, when the block size is 16, `firstScaleByte` can be 0 or 1.
+ Lanes 0-15 read from the `firstScaleByte`th element of the scale vectors,
+ while lanes 16-31 read from `firstScaleByte` + 2.
+ For example:
+ ```mlir
+ // Input: 8-element vector of F8E5M2, converting to BF16
+ // Lanes 0-15 read from byte 0, lanes 16-31 read from byte 2 (0+2)
+ %result = amdgpu.scaled_ext_packed816 %source scale(%scales)
+ blockSize(16) firstScaleLane(0) firstScaleByte(0)
+ : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16>
+
+ // Input: 16-element vector of F6E3M2FN, converting to F32
+ // Lanes 0-15 read from byte 1, lanes 16-31 read from byte 3 (1+2)
+ %result = amdgpu.scaled_ext_packed816 %source scale(%scales)
+ blockSize(16) firstScaleLane(1) firstScaleByte(1)
+ : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf32>
+ ```
+
+ Note: the layout of the scales generally mirrors the one the WMMA
+ instructions use for matrix scales. These selection operands allow
+ one to choose portions of the matrix to convert.
+
+ Available on gfx1250+.
+ }];
+
+ let assemblyFormat = [{
+ attr-dict $source
+ `scale` `(` $scale `)`
+ `blockSize` `(` $blockSize `)`
+ `firstScaleLane` `(` $firstScaleLane`)`
+ `firstScaleByte` `(` $firstScaleByte `)`
+ `:` type($source) `,` type($scale) `->` type($res)
+ }];
+
+ let hasVerifier = 1;
+
+}
+
def AMDGPU_ScaledExtPackedOp
: AMDGPU_Op<"scaled_ext_packed", [Pure]>,
Arguments<(
@@ -860,7 +951,7 @@ def AMDGPU_MFMAOp :
based on the provided `m`, `k`, `n`, and `nBlks` attributes, along with the
types of the source and destination arguments.
- For information on the layouts of the input and output matrces (which are stored
+ For information on the layouts of the input and output matrices (which are stored
in `sourceA`, `sourceB`, `destC`, and `destD`), see the CDNA ISA documentation.
The `cbsz`, `abid`, and `blgp` parameters control how the lanes of the wave
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index e52b7d2..12a7935 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -330,7 +330,6 @@ def AffineForOp : Affine_Op<"for",
Speculation::Speculatability getSpeculatability();
}];
- let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasRegionVerifier = 1;
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 20c9097..a38cf41 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1229,37 +1229,50 @@ def Arith_ScalingExtFOp
let summary = "Upcasts input floats using provided scales values following "
"OCP MXFP Spec";
let description = [{
- This operation upcasts input floating-point values using provided scale
- values. It expects both scales and the input operand to be of the same shape,
- making the operation elementwise. Scales are usually calculated per block
- following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
-
- If scales are calculated per block where blockSize != 1, then scales may
- require broadcasting to make this operation elementwise. For example, let's
- say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
- assuming quantization happens on the last axis, the input can be reshaped to
- `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
- per block on the last axis. Therefore, scales will be of shape
- `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
- shape as long as it is broadcast compatible with the input, e.g.,
- `<1 x 1 x ... (dimN/blockSize) x 1>`.
-
- In this example, before calling into `arith.scaling_extf`, scales must be
- broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
- that there could be multiple quantization axes. Internally,
- `arith.scaling_extf` would perform the following:
+ This operation upcasts input floating-point values using provided scale
+ values. It expects both scales and the input operand to be of the same shape,
+ making the operation elementwise. Scales are usually calculated per block
+ following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
- ```
- resultTy = get_type(result)
- scaleTy = get_type(scale)
- inputTy = get_type(input)
- scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
- scale.extf = arith.extf(scale.exponent) : f8E8M0 to resultTy
- input.extf = arith.extf(input) : inputTy to resultTy
- result = arith.mulf(scale.extf, input.extf)
+ If scales are calculated per block where blockSize != 1, then scales may
+ require broadcasting to make this operation elementwise. For example, let's
+ say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
+ assuming quantization happens on the last axis, the input can be reshaped to
+ `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
+ per block on the last axis. Therefore, scales will be of shape
+ `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
+ shape as long as it is broadcast compatible with the input, e.g.,
+ `<1 x 1 x ... (dimN/blockSize) x 1>`.
+
+ In this example, before calling into `arith.scaling_extf`, scales must be
+ broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
+ that there could be multiple quantization axes. Internally,
+ `arith.scaling_extf` would perform the following:
+
+ ```mlir
+ // Cast the scale to the result type.
+ %scale_exp = arith.truncf %scale : f32 to f8E8M0FNU
+ %scale_ext = arith.extf %scale_exp : f8E8M0FNU to f16
+
+ // Cast the input to the result type.
+ %input_ext = arith.extf %input : f4E2M1FN to f16
+
+ // Perform the scaling.
+ %result = arith.mulf %input_ext, %scale_ext : f16
```
It propagates NaN values. Therefore, if either scale or the input element
contains NaN, then the output element value will also be a NaN.
+
+ Example:
+
+ ```mlir
+ // Upcast from f4E2M1FN to f32.
+ %a = arith.scaling_extf %b, %c : f4E2M1FN, f8E8M0FNU to f32
+
+ // Element-wise upcast with broadcast (blockSize = 32).
+ %f = vector.broadcast %g : vector<1xf8E8M0FNU> to vector<32xf8E8M0FNU>
+ %h = arith.scaling_extf %i, %f : vector<32xf4E2M1FN>, vector<32xf8E8M0FNU> to vector<32xbf16>
+ ```
}];
let hasVerifier = 1;
let assemblyFormat =
@@ -1397,14 +1410,27 @@ def Arith_ScalingTruncFOp
that there could be multiple quantization axes. Internally,
`arith.scaling_truncf` would perform the following:
+ ```mlir
+ // Cast the scale to the input type.
+ %scale_exp = arith.truncf %scale : f32 to f8E8M0FNU
+ %scale_ext = arith.extf %scale_exp : f8E8M0FNU to f16
+
+ // Perform the scaling.
+ %scaled = arith.divf %input, %scale_ext : f16
+
+ // Cast to the result type.
+ %result = arith.truncf %scaled : f16 to f4E2M1FN
```
- scaleTy = get_type(scale)
- inputTy = get_type(input)
- resultTy = get_type(result)
- scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
- scale.extf = arith.extf(scale.exponent) : f8E8M0 to inputTy
- result = arith.divf(input, scale.extf)
- result.cast = arith.truncf(result, resultTy)
+
+ Example:
+
+ ```mlir
+ // Downcast from f32 to f4E2M1FN.
+ %a = arith.scaling_truncf %b, %c : f32, f8E8M0FNU to f4E2M1FN
+
+ // Element-wise downcast with broadcast (blockSize = 32).
+ %f = vector.broadcast %g : vector<1xf8E8M0FNU> to vector<32xf8E8M0FNU>
+ %h = arith.scaling_truncf %i, %f : vector<32xbf16>, vector<32xf8E8M0FNU> to vector<32xf4E2M1FN>
```
}];
let hasVerifier = 1;
diff --git a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
index 035235f..fccb49d 100644
--- a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
@@ -1,4 +1,4 @@
-//===- Passes.h - GPU NVVM pipeline entry points --------------------------===//
+//===- Passes.h - GPU pipeline entry points -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -60,6 +60,52 @@ struct GPUToNVVMPipelineOptions
llvm::cl::init(false)};
};
+// Options for the gpu to xevm pipeline.
+struct GPUToXeVMPipelineOptions
+ : public PassPipelineOptions<GPUToXeVMPipelineOptions> {
+ PassOptions::Option<std::string> xegpuOpLevel{
+ *this, "xegpu-op-level",
+ llvm::cl::desc("Granularity of XeGPU operations to target: workgroup | "
+ "subgroup | lane"),
+ llvm::cl::init("workgroup")};
+ // General lowering controls.
+ PassOptions::Option<bool> use64bitIndex{
+ *this, "use-64bit-index",
+ llvm::cl::desc("Bitwidth of the index type (host & device)"),
+ llvm::cl::init(true)};
+ PassOptions::Option<bool> kernelBarePtrCallConv{
+ *this, "kernel-bare-ptr-calling-convention",
+ llvm::cl::desc("Use bare pointer calling convention for device kernels"),
+ llvm::cl::init(false)};
+ PassOptions::Option<bool> hostBarePtrCallConv{
+ *this, "host-bare-ptr-calling-convention",
+ llvm::cl::desc("Use bare pointer calling convention for host launches"),
+ llvm::cl::init(false)};
+ PassOptions::Option<std::string> binaryFormat{
+ *this, "binary-format",
+ llvm::cl::desc("Final GPU binary emission format (e.g. fatbin)"),
+ llvm::cl::init("fatbin")};
+ // Options mirroring xevm-attach-target (GpuXeVMAttachTarget).
+ PassOptions::Option<std::string> xevmModuleMatcher{
+ *this, "xevm-module-matcher",
+ llvm::cl::desc("Regex to match gpu.module names for XeVM target attach"),
+ llvm::cl::init("")};
+ PassOptions::Option<std::string> zebinTriple{
+ *this, "zebin-triple", llvm::cl::desc("Target triple for XeVM codegen"),
+ llvm::cl::init("spirv64-unknown-unknown")};
+ PassOptions::Option<std::string> zebinChip{
+ *this, "zebin-chip", llvm::cl::desc("Target chip (e.g. pvc, bmg)"),
+ llvm::cl::init("bmg")};
+ PassOptions::Option<unsigned> optLevel{
+ *this, "opt-level",
+ llvm::cl::desc("Optimization level for attached target/codegen"),
+ llvm::cl::init(2)};
+ PassOptions::Option<std::string> cmdOptions{
+ *this, "igc-cmd-options",
+ llvm::cl::desc("Additional downstream compiler command line options"),
+ llvm::cl::init("")};
+};
+
//===----------------------------------------------------------------------===//
// Building and Registering.
//===----------------------------------------------------------------------===//
@@ -70,8 +116,15 @@ struct GPUToNVVMPipelineOptions
void buildLowerToNVVMPassPipeline(OpPassManager &pm,
const GPUToNVVMPipelineOptions &options);
-/// Register all pipeleines for the `gpu` dialect.
+/// Adds the GPU to XeVM pipeline to the given pass manager. Transforms main
+/// dialects into XeVM targets. Begins with GPU code regions, then handles host
+/// code.
+void buildLowerToXeVMPassPipeline(OpPassManager &pm,
+ const GPUToXeVMPipelineOptions &options);
+
+/// Register all pipelines for the `gpu` dialect.
void registerGPUToNVVMPipeline();
+void registerGPUToXeVMPipeline();
} // namespace gpu
} // namespace mlir
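A minimal sketch of driving the new XeVM pipeline from C++ follows; it assumes a `ModuleOp module`, and the function name and option values are illustrative only (register `registerGPUToXeVMPipeline()` separately if the textual form is needed).

```cpp
// Hypothetical driver that assembles and runs the GPU-to-XeVM pipeline.
static mlir::LogicalResult lowerToXeVM(mlir::ModuleOp module) {
  mlir::gpu::GPUToXeVMPipelineOptions options;
  options.xegpuOpLevel = "subgroup"; // workgroup | subgroup | lane
  options.zebinChip = "pvc";

  mlir::PassManager pm(module.getContext());
  mlir::gpu::buildLowerToXeVMPassPipeline(pm, options);
  return pm.run(module);
}
```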
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 68f31e6..d2df244 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -574,6 +574,30 @@ def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_b
def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_fp8", [1]>;
def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x32.iu4", [1]>;
+// Available from gfx1250
+def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x4.f32", [1]>;
+def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.bf16", [1]>;
+def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.f16", [1]>;
+def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x32.f16", [1]>;
+def ROCDL_wmma_bf16_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x32.bf16", [1]>;
+def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16f32.16x16x32.bf16", [1,5]>;
+def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_fp8", [0]>;
+def ROCDL_wmma_f32_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_bf8", [0]>;
+def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_fp8", [0]>;
+def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_bf8", [0]>;
+def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_fp8", [0]>;
+def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_bf8", [0]>;
+def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_fp8", [0]>;
+def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_bf8", [0]>;
+def ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_fp8", [0]>;
+def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_bf8", [0]>;
+def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_fp8", [0]>;
+def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_bf8", [0]>;
+def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_fp8", [0]>;
+def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_bf8", [0]>;
+def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_fp8", [0]>;
+def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_bf8", [0]>;
+def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x64.iu8", [1]>;
//===---------------------------------------------------------------------===//
// LDS transpose intrinsics (available in GFX950)
@@ -1143,6 +1167,7 @@ foreach smallT = [
ScaleArgInfo<ROCDL_V16BF16Type, "Bf16">,
ScaleArgInfo<ROCDL_V16F32Type, "F32">,
] in {
+ // Up-scaling
def ROCDL_CvtPkScalePk16 # largeT.nameForOp # smallT.nameForOp # Op :
ROCDL_ConcreteNonMemIntrOp<"cvt.scale.pk16." # largeT.name # "." # smallT.name,
[Pure], 1, [2], ["scaleSel"]>,
@@ -1158,6 +1183,42 @@ foreach smallT = [
}];
}
+
+ // Down-scaling
+ def ROCDL_CvtScaleF32Pk16 # smallT.nameForOp # largeT.nameForOp # Op :
+ ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk16." # smallT.name # "." # largeT.name,
+ [Pure], 1>,
+ Arguments<(ins largeT.type:$src, F32:$scale)> {
+ let results = (outs smallT.type:$res);
+ let summary = "Scale and convert packed "
+ # largeT.name # " to packed " # smallT.name ;
+ let description = [{
+ Convert 8 packed }] # largeT.name # [{ values to packed }]
+ # smallT.name # [{, multiplying by the exponent part of `scale`
+ before doing so. Available on gfx1250+.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `,` $scale `:` type($res)
+ }];
+ }
+
+ def ROCDL_CvtScaleF32SrPk16 # smallT.nameForOp # largeT.nameForOp # Op :
+ ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.sr.pk16." # smallT.name # "." # largeT.name,
+ [Pure], 1>,
+ Arguments<(ins largeT.type:$src, I32:$seed, F32:$scale)> {
+ let results = (outs smallT.type:$res);
+ let summary = "Scale and convert packed "
+ # largeT.name # " to packed " # smallT.name # " with stochastic rounding";
+ let description = [{
+ Convert 8 packed }] # largeT.name # [{ values to packed }]
+ # smallT.name # [{, multiplying by the exponent part of `scale`
+ before doing so, applying stochastic rounding. Available on gfx1250+.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `,` $seed `,` $scale `:` type($res)
+ }];
+ }
+
} // foreach largeT
} // foreach smallTOp
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index ae7a085..c89fc59 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -25,7 +25,6 @@
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallBitVector.h"
-#include "llvm/ADT/SmallSet.h"
namespace mlir {
namespace bufferization {
@@ -621,35 +620,43 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
/// In the future, more general interfaces can be devised to encode similar
/// shape evolutions and map between an op and its operands.
SmallVector<OpFoldResult>
-computePaddedShape(RewriterBase &rewriter, TypedValue<RankedTensorType> v,
+computePaddedShape(OpBuilder &, TypedValue<RankedTensorType> v,
AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
const PadTilingInterfaceOptions &options);
using PadSizeComputationFunction =
std::function<FailureOr<SmallVector<OpFoldResult>>(
- RewriterBase &, OpOperand &, ArrayRef<Range>,
+ OpBuilder &, OpOperand &, ArrayRef<Range>,
const PadTilingInterfaceOptions &)>;
/// Specific helper for Linalg ops.
-FailureOr<SmallVector<OpFoldResult>> computeIndexingMapOpInterfacePaddedShape(
- RewriterBase &rewriter, OpOperand &operandToPad,
- ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options);
+FailureOr<SmallVector<OpFoldResult>>
+computeIndexingMapOpInterfacePaddedShape(OpBuilder &, OpOperand &operandToPad,
+ ArrayRef<Range> iterationDomain,
+ const PadTilingInterfaceOptions &);
+
+/// Operations and values created in the process of padding a TilingInterface
+/// operation.
+struct PadTilingInterfaceResult {
+ /// The operands of the padded op.
+ SmallVector<tensor::PadOp> padOps;
+ /// The padded op, a clone of `toPad` with padded operands.
+ TilingInterface paddedOp;
+ /// Slices of the padded op's results, same types as `toPad`.
+ SmallVector<Value> replacements;
+};
-/// Pad the iterator dimensions `options.paddingDimensions` of `opToPad`.
-///
+/// Pad the iterator dimensions of `toPad`.
/// * "options.paddingSizes" indicates that each padding dimension should be
/// padded to the specified padding size.
/// * "options.padToMultipleOf" indicates that the paddingSizes should be
// interpreted as the bounding box (dynamic) value to pad to.
/// * Use "options.paddingValues" to set the padding value of the created
// tensor::PadOp.
-/// * The tensor::PadOp is returned on success.
-
-FailureOr<TilingInterface>
-rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
- const PadTilingInterfaceOptions &constOptions,
- SmallVector<tensor::PadOp> &padOps,
- const PadSizeComputationFunction &computePaddingSizeFun =
+FailureOr<PadTilingInterfaceResult>
+rewriteAsPaddedOp(OpBuilder &, TilingInterface toPad,
+ PadTilingInterfaceOptions options,
+ const PadSizeComputationFunction & =
&computeIndexingMapOpInterfacePaddedShape);
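A minimal sketch of the updated entry point, assuming a `RewriterBase rewriter` (any `OpBuilder` satisfies the builder parameter), a `TilingInterface` op `toPad`, and configured `PadTilingInterfaceOptions options`; the wrapper name and variables are illustrative.

```cpp
// Hypothetical wrapper showing how the new result struct is consumed.
static mlir::LogicalResult
padAndReplace(mlir::RewriterBase &rewriter, mlir::TilingInterface toPad,
              mlir::linalg::PadTilingInterfaceOptions options) {
  mlir::FailureOr<mlir::linalg::PadTilingInterfaceResult> padded =
      mlir::linalg::rewriteAsPaddedOp(rewriter, toPad, options);
  if (mlir::failed(padded))
    return mlir::failure();
  // `padded->padOps` holds the created tensor.pad ops and `padded->paddedOp`
  // is the padded clone; replace the original op with slices of its results.
  rewriter.replaceOp(toPad, padded->replacements);
  return mlir::success();
}
```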
namespace detail {
diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
index 3143ab7..99b22e5 100644
--- a/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
+++ b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
@@ -220,8 +220,6 @@ def YieldOp : SMTOp<"yield", [
Pure,
Terminator,
ReturnLike,
- ParentOneOf<["smt::SolverOp", "smt::CheckOp",
- "smt::ForallOp", "smt::ExistsOp"]>,
]> {
let summary = "terminator operation for various regions of SMT operations";
let arguments = (ins Variadic<AnyType>:$values);
diff --git a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
index b9d7163..5e68f75e 100644
--- a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
+++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
@@ -494,7 +494,9 @@ def Shard_AllGatherOp : Shard_CollectiveCommunicationOpBase<"all_gather", [
]> {
let summary = "All-gather over a device grid.";
let description = [{
- Gathers along the `gather_axis` tensor axis.
+ Concatenates all tensor slices from a device group defined by `grid_axes` along
+ the tensor dimension `gather_axis` and replicates the result across all devices
+ in the group.
Example:
```mlir
@@ -546,10 +548,13 @@ def Shard_AllReduceOp : Shard_CollectiveCommunicationOpBase<"all_reduce", [
SameOperandsAndResultShape]> {
let summary = "All-reduce over a device grid.";
let description = [{
- The accumulation element type is specified by the result type and
- it does not need to match the input element type.
- The input element is converted to the result element type before
- performing the reduction.
+ Reduces the input tensor across all devices within the groups defined by
+ `grid_axes`, using the specified reduction method. The operation performs an
+ element-wise reduction over the tensor slices from all devices in each group.
+ Each device in a group receives a replicated copy of the reduction result.
+ The accumulation element type is determined by the result type and does not
+ need to match the input element type. Before performing the reduction, each
+ input element is converted to the result element type.
Attributes:
`reduction`: Indicates the reduction method.
@@ -583,13 +588,15 @@ def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [
SameOperandsAndResultElementType,
SameOperandsAndResultRank
]> {
- let summary = "All-slice over a device grid. This is the inverse of all-gather.";
+ let summary = "All-slice over a device grid.";
let description = [{
- Slice along the `slice_axis` tensor axis.
- This operation can be thought of as the inverse of all-gather.
- Technically, it is not required that all processes have the same input tensor.
- Each process will slice a piece of its local tensor based on its in-group device index.
- The operation does not communicate data between devices.
+ Within each device group defined by `grid_axes`, slices the input tensor along
+ the `slice_axis` dimension. It can be viewed as the inverse of an all-gather if
+ the input data is replicated along the `slice_axis`.
+ Each process simply crops its local data to the slice corresponding to its
+ in-group device index.
+ Note that `AllSliceOp` does not involve any communication between devices;
+ devices within a group need not have replicated input data.
Example:
```mlir
@@ -610,7 +617,7 @@ def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [
```
Result:
```
- gather tensor
+ slice tensor
axis 1
------------>
+-------+-------+
@@ -646,8 +653,10 @@ def Shard_AllToAllOp : Shard_CollectiveCommunicationOpBase<"all_to_all", [
SameOperandsAndResultRank]> {
let summary = "All-to-all over a device grid.";
let description = [{
- Performs an all-to-all on tensor pieces split along `split_axis`.
- The resulting pieces are concatenated along `concat_axis` on ech device.
+ Each participant logically splits its input along `split_axis`,
+ then scatters the resulting pieces across the group defined by `grid_axes`.
+ After receiving data pieces from the other participants' scatters,
+ it concatenates them along `concat_axis` to produce the final result.
Example:
```
@@ -702,10 +711,9 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
]> {
let summary = "Broadcast over a device grid.";
let description = [{
- Broadcast the tensor on `root` to all devices in each respective group.
- The operation broadcasts along grid axes `grid_axes`.
- The `root` device specifies the in-group multi-index that is broadcast to
- all other devices in the group.
+ Copies the input tensor on `root` to all devices in each group defined by
+ `grid_axes`. The `root` device is defined by its in-group multi-index.
+ The contents of input tensors on non-root devices are ignored.
Example:
```
@@ -722,7 +730,7 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
+-------+-------+ | broadcast
device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1) | along axis 0
+-------+-------+ ↓
- device (1, 0) -> | | | <- device (1, 1)
+ device (1, 0) -> | * * | * * | <- device (1, 1)
+-------+-------+
```
@@ -758,11 +766,10 @@ def Shard_GatherOp : Shard_CollectiveCommunicationOpBase<"gather", [
]> {
let summary = "Gather over a device grid.";
let description = [{
- Gathers on device `root` along the `gather_axis` tensor axis.
- `root` specifies the coordinates of a device along `grid_axes`.
- It uniquely identifies the root device for each device group.
- The result tensor on non-root devices is undefined.
- Using it will result in undefined behavior.
+ Concatenates all tensor slices from a device group defined by `grid_axes` along
+ the tensor dimension `gather_axis` and returns the resulting tensor on each
+ `root` device. The result on all other (non-root) devices is undefined.
+ The `root` device is defined by its in-group multi-index.
Example:
```mlir
@@ -821,7 +828,9 @@ def Shard_RecvOp : Shard_CollectiveCommunicationOpBase<"recv", [
]> {
let summary = "Send over a device grid.";
let description = [{
- Receive from a device within a device group.
+ Receive a tensor from device `source`, which is defined by its in-group
+ multi-index. The groups are defined by `grid_axes`.
+ The content of the input tensor is ignored.
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
@@ -845,13 +854,15 @@ def Shard_ReduceOp : Shard_CollectiveCommunicationOpBase<"reduce", [
]> {
let summary = "Reduce over a device grid.";
let description = [{
- Reduces on device `root` within each device group.
- `root` specifies the coordinates of a device along `grid_axes`.
- It uniquely identifies the root device within its device group.
- The accumulation element type is specified by the result type and
- it does not need to match the input element type.
- The input element is converted to the result element type before
- performing the reduction.
+ Reduces the input tensor across all devices within the groups defined by
+ `grid_axes`, using the specified reduction method. The operation performs an
+ element-wise reduction over the tensor slices from all devices in each group.
+ The reduction result will be returned on the `root` device of each group.
+ It is undefined on all other (non-root) devices.
+ The `root` device is defined by its in-group multi-index.
+ The accumulation element type is determined by the result type and does not
+ need to match the input element type. Before performing the reduction, each
+ input element is converted to the result element type.
Attributes:
`reduction`: Indicates the reduction method.
@@ -886,16 +897,18 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter"
SameOperandsAndResultRank]> {
let summary = "Reduce-scatter over a device grid.";
let description = [{
- After the reduction, the result is scattered within each device group.
- The tensor is split along `scatter_axis` and the pieces distributed
- across the device group.
+ Reduces the input tensor across all devices within the groups defined by
+ `grid_axes` using the specified reduction method. The reduction is performed
+ element-wise across the tensor pieces from all devices in the group.
+ After reduction, the reduction result is scattered (split and distributed)
+ across the device group along `scatter_axis`.
Example:
```
shard.grid @grid0(shape = 2x2)
...
%1 = shard.reduce_scatter %0 on @grid0 grid_axes = [1]
reduction = <max> scatter_axis = 0
- : tensor<3x4xf32> -> tensor<1x4xf64>
+ : tensor<2x2xf32> -> tensor<1x2xf64>
```
Input:
```
@@ -916,13 +929,13 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter"
Result:
```
+-------+
- | 6 8 | <- devices (0, 0)
+ | 5 6 | <- devices (0, 0)
+-------+
- | 10 12 | <- devices (0, 1)
+ | 7 8 | <- devices (0, 1)
+-------+
- | 22 24 | <- devices (1, 0)
+ | 13 14 | <- devices (1, 0)
+-------+
- | 26 28 | <- devices (1, 1)
+ | 15 16 | <- devices (1, 1)
+-------+
```
}];
@@ -950,8 +963,10 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
]> {
let summary = "Scatter over a device grid.";
let description = [{
- For each device group split the input tensor on the `root` device along
- axis `scatter_axis` and scatter the parts across the group devices.
+ For each device group defined by `grid_axes`, the input tensor on the `root`
+ device is split along axis `scatter_axis` and distributed across the group.
+ The content of the input on all other (non-root) devices is ignored.
+ The `root` device is defined by its in-group multi-index.
Example:
```
@@ -968,8 +983,8 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
(0, 1)
+-------+-------+ | scatter tensor
- device (0, 0) -> | | | | axis 0
- | | | ↓
+ device (0, 0) -> | * * | * * | | axis 0
+ | * * | * * | ↓
+-------+-------+
device (1, 0) -> | 1 2 | 5 6 |
| 3 4 | 7 8 |
@@ -1018,7 +1033,8 @@ def Shard_SendOp : Shard_CollectiveCommunicationOpBase<"send", [
]> {
let summary = "Send over a device grid.";
let description = [{
- Send from one device to another within a device group.
+ Send the input tensor to device `destination`, which is defined by its in-group
+ multi-index. The groups are defined by `grid_axes`.
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
@@ -1043,12 +1059,11 @@ def Shard_ShiftOp : Shard_CollectiveCommunicationOpBase<"shift", [
]> {
let summary = "Shift over a device grid.";
let description = [{
- Within each device group shift along grid axis `shift_axis` by an offset
- `offset`.
- The result on devices that do not have a corresponding source is undefined.
- `shift_axis` must be one of `grid_axes`.
- If the `rotate` attribute is present,
- instead of a shift a rotation is done.
+ Within each device group defined by `grid_axes`, shifts input tensors along the
+ device grid's axis `shift_axis` by the specified offset. The `shift_axis` must
+ be one of the `grid_axes`. If the `rotate` attribute is set, the shift is circular.
+ That is, the offset wraps around according to the group size along `shift_axis`.
+ Otherwise, the results on devices without a corresponding source are undefined.
Example:
```
diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h
index fc69b03..f6353a9 100644
--- a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h
+++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h
@@ -10,6 +10,7 @@
#define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H
#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/SMT/IR/SMTOps.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpDefinition.h"
diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td
index b987cb3..9d9783a 100644
--- a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td
+++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td
@@ -16,7 +16,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
- NoTerminator
+ SingleBlockImplicitTerminator<"::mlir::smt::YieldOp">
]> {
let cppNamespace = [{ mlir::transform::smt }];
@@ -24,14 +24,20 @@ def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
let description = [{
Allows expressing constraints on params using the SMT dialect.
- Each Transform dialect param provided as an operand has a corresponding
+ Each Transform-dialect param provided as an operand has a corresponding
argument of SMT-type in the region. The SMT-Dialect ops in the region use
- these arguments as operands.
+ these params-as-SMT-vars as operands, thereby expressing relevant
+ constraints on their allowed values.
+
+ Computations w.r.t. passed-in params can also be expressed through the
+ region's SMT-ops. Namely, the constraints express relationships to other
+ SMT-variables which can then be yielded from the region (with `smt.yield`).
The semantics of this op is that all the ops in the region together express
a constraint on the params-interpreted-as-smt-vars. The op fails in case the
expressed constraint is not satisfiable per SMTLIB semantics. Otherwise the
- op succeeds.
+ op succeeds and any one satisfying assignment is used to map the
+ SMT-variables yielded in the region to `transform.param`s.
---
@@ -42,9 +48,10 @@ def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
}];
let arguments = (ins Variadic<TransformParamTypeInterface>:$params);
+ let results = (outs Variadic<TransformParamTypeInterface>:$results);
let regions = (region SizedRegion<1>:$body);
let assemblyFormat =
- "`(` $params `)` attr-dict `:` type(operands) $body";
+ "`(` $params `)` attr-dict `:` functional-type(operands, results) $body";
let hasVerifier = 1;
}
diff --git a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td
index b80ee2c..e9425e8 100644
--- a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td
+++ b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td
@@ -43,9 +43,41 @@ class WasmSSA_BlockLikeOp<string mnemonic, string summaryStr> :
let assemblyFormat = "(`(`$inputs^`)` `:` type($inputs))? attr-dict `:` $body `>` $target";
}
-def WasmSSA_BlockOp : WasmSSA_BlockLikeOp<"block", "Create a nesting level"> {}
+def WasmSSA_BlockOp : WasmSSA_BlockLikeOp<
+ "block",
+ "Create a nesting level with a label at its exit."> {
+ let description = [{
+ Defines a Wasm block, creating a new nested scope.
+ A block contains a body region and an optional list of input values.
+ Control can enter the block and later branch out to the block target.
+ Example:
+
+ ```mlir
+
+ wasmssa.block {
+
+ // instructions
+
+ } > ^successor
+ ```
+ }];
+}
+
+def WasmSSA_LoopOp : WasmSSA_BlockLikeOp<
+ "loop",
+ "Create a nesting level that define its entry as jump target."> {
+ let description = [{
+ Represents a Wasm loop construct. This defines a nesting level with
+ a label at the entry of the region.
-def WasmSSA_LoopOp : WasmSSA_BlockLikeOp<"loop", "Create a nesting level similar to Block Op, except that it has itself as a successor."> {}
+ Example:
+
+ ```mlir
+
+ wasmssa.loop {
+
+ } > ^successor
+ ```
+ }];
+}
def WasmSSA_BlockReturnOp : WasmSSA_Op<"block_return", [Terminator,
DeclareOpInterfaceMethods<LabelBranchingOpInterface>]> {
@@ -55,9 +87,16 @@ def WasmSSA_BlockReturnOp : WasmSSA_Op<"block_return", [Terminator,
::mlir::Block* getTarget();
}];
let description = [{
- Marks a return from the current block.
+ Escape from the current nesting level and return the control flow to its successor.
+ Optionally, mark the arguments that should be transferred to the successor block.
- Example:
+ This shouldn't be confused with branch operations that target the label defined
+ by the nesting level operation.
+
+ For instance, a `wasmssa.block_return` in a loop gives control back to the
+ successor of the loop, whereas a `branch` targeting the loop flows back to the entry block of the loop.
+
+ Example:
```mlir
wasmssa.block_return
@@ -127,12 +166,18 @@ def WasmSSA_FuncOp : WasmSSA_Op<"func", [
- Arguments of the entry block of type `!wasm<local T>`, with T the corresponding type
in the function type.
+ By default, `wasmssa.func` has nested visibility. Functions exported by the module
+ are marked with the `exported` attribute, which gives them public visibility.
+
Example:
```mlir
- // A simple function with no arguments that returns a float32
+ // Internal function with no arguments that returns a float32
wasmssa.func @my_f32_func() -> f32
+ // Exported function with no arguments that returns a float32
+ wasmssa.func exported @my_f32_func() -> f32
+
// A function that takes a local ref argument
wasmssa.func @i64_wrap(%a: !wasmssa<local ref to i64>) -> i32
```
@@ -141,7 +186,7 @@ def WasmSSA_FuncOp : WasmSSA_Op<"func", [
WasmSSA_FuncTypeAttr: $functionType,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs,
- DefaultValuedAttr<StrAttr, "\"nested\"">:$sym_visibility);
+ UnitAttr: $exported);
let regions = (region AnyRegion: $body);
let extraClassDeclaration = [{
@@ -162,6 +207,12 @@ def WasmSSA_FuncOp : WasmSSA_Op<"func", [
/// Returns the result types of this function.
ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); }
+
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return getExported() ?
+ ::mlir::SymbolTable::Visibility::Public :
+ ::mlir::SymbolTable::Visibility::Nested;
+ };
}];
let builders = [
@@ -207,8 +258,7 @@ def WasmSSA_FuncImportOp : WasmSSA_Op<"import_func", [
StrAttr: $importName,
WasmSSA_FuncTypeAttr: $type,
OptionalAttr<DictArrayAttr>:$arg_attrs,
- OptionalAttr<DictArrayAttr>:$res_attrs,
- OptionalAttr<StrAttr>:$sym_visibility);
+ OptionalAttr<DictArrayAttr>:$res_attrs);
let extraClassDeclaration = [{
bool isDeclaration() const { return true; }
@@ -221,6 +271,10 @@ def WasmSSA_FuncImportOp : WasmSSA_Op<"import_func", [
::llvm::ArrayRef<Type> getResultTypes() {
return getType().getResults();
}
+
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return ::mlir::SymbolTable::Visibility::Nested;
+ };
}];
let builders = [
OpBuilder<(ins "StringRef":$symbol,
@@ -238,30 +292,41 @@ def WasmSSA_GlobalOp : WasmSSA_Op<"global", [
let arguments = (ins SymbolNameAttr: $sym_name,
WasmSSA_ValTypeAttr: $type,
UnitAttr: $isMutable,
- OptionalAttr<StrAttr>:$sym_visibility);
+ UnitAttr: $exported);
let description = [{
WebAssembly global variable.
Body contains the initialization instructions for the variable value.
The body must contain only instructions considered `const` in a webassembly context,
such as `wasmssa.const` or `global.get`.
+ By default, `wasmssa.global` has nested visibility. Globals exported by the module
+ are marked with the `exported` attribute, which gives them public visibility.
+
Example:
```mlir
- // Define a global_var, a mutable i32 global variable equal to 10.
- wasmssa.global @global_var i32 mutable nested : {
+ // Define module_global_var, an internal mutable i32 global variable equal to 10.
+ wasmssa.global @module_global_var i32 mutable : {
%[[VAL_0:.*]] = wasmssa.const 10 : i32
wasmssa.return %[[VAL_0]] : i32
}
+
+ // Define global_var, an exported constant i32 global variable equal to 42.
+ wasmssa.global exported @global_var i32 : {
+ %[[VAL_0:.*]] = wasmssa.const 42 : i32
+ wasmssa.return %[[VAL_0]] : i32
+ }
```
}];
let regions = (region AnyRegion: $initializer);
- let builders = [
- OpBuilder<(ins "StringRef":$symbol,
- "Type": $type,
- "bool": $isMutable)>
- ];
+ let extraClassDeclaration = [{
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return getExported() ?
+ ::mlir::SymbolTable::Visibility::Public :
+ ::mlir::SymbolTable::Visibility::Nested;
+ };
+ }];
let hasCustomAssemblyFormat = 1;
}
@@ -283,18 +348,14 @@ def WasmSSA_GlobalImportOp : WasmSSA_Op<"import_global", [
StrAttr: $moduleName,
StrAttr: $importName,
WasmSSA_ValTypeAttr: $type,
- UnitAttr: $isMutable,
- OptionalAttr<StrAttr>:$sym_visibility);
+ UnitAttr: $isMutable);
let extraClassDeclaration = [{
bool isDeclaration() const { return true; }
+
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return ::mlir::SymbolTable::Visibility::Nested;
+ };
}];
- let builders = [
- OpBuilder<(ins "StringRef":$symbol,
- "StringRef":$moduleName,
- "StringRef":$importName,
- "Type": $type,
- "bool": $isMutable)>
- ];
let hasCustomAssemblyFormat = 1;
}
@@ -442,23 +503,33 @@ def WasmSSA_MemOp : WasmSSA_Op<"memory", [Symbol]> {
Define a memory to be used by the program.
Multiple memories can be defined in the same module.
+ By default, `wasmssa.memory` has nested visibility. Memories exported by
+ the module are marked with the `exported` attribute, which gives them public
+ visibility.
+
Example:
```mlir
- // Define the `mem_0` memory with defined bounds of 0 -> 65536
+ // Define the `mem_0` (internal) memory with defined size bounds of [0:65536]
wasmssa.memory @mem_0 !wasmssa<limit[0:65536]>
+
+ // Define the `mem_1` exported memory with a minimum size of 512
+ wasmssa.memory exported @mem_1 !wasmssa<limit[512:]>
```
}];
let arguments = (ins SymbolNameAttr: $sym_name,
WasmSSA_LimitTypeAttr: $limits,
- OptionalAttr<StrAttr>:$sym_visibility);
- let builders = [
- OpBuilder<(ins
- "::llvm::StringRef":$symbol,
- "wasmssa::LimitType":$limit)>
- ];
+ UnitAttr: $exported);
- let assemblyFormat = "$sym_name custom<WasmVisibility>($sym_visibility) $limits attr-dict";
+ let extraClassDeclaration = [{
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return getExported() ?
+ ::mlir::SymbolTable::Visibility::Public :
+ ::mlir::SymbolTable::Visibility::Nested;
+ };
+ }];
+
+ let assemblyFormat = "(`exported` $exported^)? $sym_name $limits attr-dict";
}
def WasmSSA_MemImportOp : WasmSSA_Op<"import_mem", [Symbol, ImportOpInterface]> {
@@ -476,16 +547,13 @@ def WasmSSA_MemImportOp : WasmSSA_Op<"import_mem", [Symbol, ImportOpInterface]>
let arguments = (ins SymbolNameAttr: $sym_name,
StrAttr: $moduleName,
StrAttr: $importName,
- WasmSSA_LimitTypeAttr: $limits,
- OptionalAttr<StrAttr>:$sym_visibility);
+ WasmSSA_LimitTypeAttr: $limits);
let extraClassDeclaration = [{
- bool isDeclaration() const { return true; }
+ bool isDeclaration() const { return true; }
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return ::mlir::SymbolTable::Visibility::Nested;
+ };
}];
- let builders = [OpBuilder<(ins
- "::llvm::StringRef":$symbol,
- "::llvm::StringRef":$moduleName,
- "::llvm::StringRef":$importName,
- "wasmssa::LimitType":$limits)>];
let assemblyFormat = "$importName `from` $moduleName `as` $sym_name attr-dict";
}
@@ -493,11 +561,15 @@ def WasmSSA_TableOp : WasmSSA_Op<"table", [Symbol]> {
let summary= "WebAssembly table value";
let arguments = (ins SymbolNameAttr: $sym_name,
WasmSSA_TableTypeAttr: $type,
- OptionalAttr<StrAttr>:$sym_visibility);
- let builders = [OpBuilder<(ins
- "::llvm::StringRef":$symbol,
- "wasmssa::TableType":$type)>];
- let assemblyFormat = "$sym_name custom<WasmVisibility>($sym_visibility) $type attr-dict";
+ UnitAttr: $exported);
+ let extraClassDeclaration = [{
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return getExported() ?
+ ::mlir::SymbolTable::Visibility::Public :
+ ::mlir::SymbolTable::Visibility::Nested;
+ };
+ }];
+ let assemblyFormat = "(`exported` $exported^)? $sym_name $type attr-dict";
}
def WasmSSA_TableImportOp : WasmSSA_Op<"import_table", [Symbol, ImportOpInterface]> {
@@ -515,17 +587,14 @@ def WasmSSA_TableImportOp : WasmSSA_Op<"import_table", [Symbol, ImportOpInterfac
let arguments = (ins SymbolNameAttr: $sym_name,
StrAttr: $moduleName,
StrAttr: $importName,
- WasmSSA_TableTypeAttr: $type,
- OptionalAttr<StrAttr>:$sym_visibility);
+ WasmSSA_TableTypeAttr: $type);
let extraClassDeclaration = [{
bool isDeclaration() const { return true; }
+ ::mlir::SymbolTable::Visibility getVisibility() {
+ return ::mlir::SymbolTable::Visibility::Nested;
+ };
}];
let assemblyFormat = "$importName `from` $moduleName `as` $sym_name attr-dict";
- let builders = [OpBuilder<(ins
- "::llvm::StringRef":$symbol,
- "::llvm::StringRef":$moduleName,
- "::llvm::StringRef":$importName,
- "wasmssa::TableType":$type)>];
}
def WasmSSA_ReturnOp : WasmSSA_Op<"return", [Terminator]> {
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 6b4e3dd..8427ba5 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -623,6 +623,14 @@ class VectorOfLengthAndType<list<int> allowedLengths,
VectorOfNonZeroRankOf<allowedTypes>.summary # VectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;
+class FixedVectorOfShapeAndType<list<int> shape, Type elType>: ShapedContainerType<
+ [elType],
+ And<[IsVectorOfShape<shape>, IsFixedVectorOfAnyRankTypePred]>,
+ "vector<" # !interleave(shape, "x") # "x" # elType # ">",
+ "::mlir::VectorType">,
+ BuildableType<"::mlir::VectorType::get({" # !interleave(shape, " ,") # "} , " # elType.builderCall # " );">;
+
+
// Any fixed-length vector where the number of elements is from the given
// `allowedLengths` list and the type is from the given `allowedTypes` list
class FixedVectorOfLengthAndType<list<int> allowedLengths,
diff --git a/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h
index 21adde8..cd9ef5b 100644
--- a/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h
+++ b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h
@@ -19,6 +19,14 @@ namespace mlir {
struct WasmBinaryEncoding {
/// Byte encodings for Wasm instructions.
struct OpCode {
+ // Control instructions.
+ static constexpr std::byte block{0x02};
+ static constexpr std::byte loop{0x03};
+ static constexpr std::byte ifOpCode{0x04};
+ static constexpr std::byte elseOpCode{0x05};
+ static constexpr std::byte branchIf{0x0D};
+ static constexpr std::byte call{0x10};
+
// Locals, globals, constants.
static constexpr std::byte localGet{0x20};
static constexpr std::byte localSet{0x21};
@@ -29,6 +37,42 @@ struct WasmBinaryEncoding {
static constexpr std::byte constFP32{0x43};
static constexpr std::byte constFP64{0x44};
+ // Comparisons.
+ static constexpr std::byte eqzI32{0x45};
+ static constexpr std::byte eqI32{0x46};
+ static constexpr std::byte neI32{0x47};
+ static constexpr std::byte ltSI32{0x48};
+ static constexpr std::byte ltUI32{0x49};
+ static constexpr std::byte gtSI32{0x4A};
+ static constexpr std::byte gtUI32{0x4B};
+ static constexpr std::byte leSI32{0x4C};
+ static constexpr std::byte leUI32{0x4D};
+ static constexpr std::byte geSI32{0x4E};
+ static constexpr std::byte geUI32{0x4F};
+ static constexpr std::byte eqzI64{0x50};
+ static constexpr std::byte eqI64{0x51};
+ static constexpr std::byte neI64{0x52};
+ static constexpr std::byte ltSI64{0x53};
+ static constexpr std::byte ltUI64{0x54};
+ static constexpr std::byte gtSI64{0x55};
+ static constexpr std::byte gtUI64{0x56};
+ static constexpr std::byte leSI64{0x57};
+ static constexpr std::byte leUI64{0x58};
+ static constexpr std::byte geSI64{0x59};
+ static constexpr std::byte geUI64{0x5A};
+ static constexpr std::byte eqF32{0x5B};
+ static constexpr std::byte neF32{0x5C};
+ static constexpr std::byte ltF32{0x5D};
+ static constexpr std::byte gtF32{0x5E};
+ static constexpr std::byte leF32{0x5F};
+ static constexpr std::byte geF32{0x60};
+ static constexpr std::byte eqF64{0x61};
+ static constexpr std::byte neF64{0x62};
+ static constexpr std::byte ltF64{0x63};
+ static constexpr std::byte gtF64{0x64};
+ static constexpr std::byte leF64{0x65};
+ static constexpr std::byte geF64{0x66};
+
// Numeric operations.
static constexpr std::byte clzI32{0x67};
static constexpr std::byte ctzI32{0x68};
@@ -93,6 +137,33 @@ struct WasmBinaryEncoding {
static constexpr std::byte maxF64{0xA5};
static constexpr std::byte copysignF64{0xA6};
static constexpr std::byte wrap{0xA7};
+
+ // Conversion operations
+ static constexpr std::byte extendS{0xAC};
+ static constexpr std::byte extendU{0xAD};
+ static constexpr std::byte convertSI32F32{0xB2};
+ static constexpr std::byte convertUI32F32{0xB3};
+ static constexpr std::byte convertSI64F32{0xB4};
+ static constexpr std::byte convertUI64F32{0xB5};
+
+ static constexpr std::byte demoteF64ToF32{0xB6};
+
+ static constexpr std::byte convertSI32F64{0xB7};
+ static constexpr std::byte convertUI32F64{0xB8};
+ static constexpr std::byte convertSI64F64{0xB9};
+ static constexpr std::byte convertUI64F64{0xBA};
+
+ static constexpr std::byte promoteF32ToF64{0xBB};
+ static constexpr std::byte reinterpretF32AsI32{0xBC};
+ static constexpr std::byte reinterpretF64AsI64{0xBD};
+ static constexpr std::byte reinterpretI32AsF32{0xBE};
+ static constexpr std::byte reinterpretI64AsF64{0xBF};
+
+ static constexpr std::byte extendI328S{0xC0};
+ static constexpr std::byte extendI3216S{0xC1};
+ static constexpr std::byte extendI648S{0xC2};
+ static constexpr std::byte extendI6416S{0xC3};
+ static constexpr std::byte extendI6432S{0xC4};
};
/// Byte encodings of types in Wasm binaries
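As an illustrative sketch (not part of the header), the new comparison opcodes might be consulted like this while decoding a function body; only the constants come from `WasmBinaryEncoding`, and the helper name is made up.

```cpp
// Hypothetical decoder fragment using the opcode constants above.
static bool isI32Comparison(std::byte opcode) {
  using OpCode = mlir::WasmBinaryEncoding::OpCode;
  // The i32 comparison opcodes occupy the contiguous range 0x45-0x4F.
  return opcode >= OpCode::eqzI32 && opcode <= OpCode::geUI32;
}
```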