diff options
Diffstat (limited to 'mlir')
185 files changed, 5811 insertions, 1130 deletions
diff --git a/mlir/docs/Canonicalization.md b/mlir/docs/Canonicalization.md index 686e500..2622c08 100644 --- a/mlir/docs/Canonicalization.md +++ b/mlir/docs/Canonicalization.md @@ -55,7 +55,7 @@ Some important things to think about w.r.t. canonicalization patterns: * It is always good to eliminate operations entirely when possible, e.g. by folding known identities (like "x + 0 = x"). -* Pattens with expensive running time (i.e. have O(n) complexity) or +* Patterns with expensive running time (i.e. have O(n) complexity) or complicated cost models don't belong to canonicalization: since the algorithm is executed iteratively until fixed-point we want patterns that execute quickly (in particular their matching phase). diff --git a/mlir/docs/Dialects/Shard.md b/mlir/docs/Dialects/Shard.md index eb6ff61..573e888 100644 --- a/mlir/docs/Dialects/Shard.md +++ b/mlir/docs/Dialects/Shard.md @@ -27,9 +27,9 @@ the tensor is sharded - not specified manually. ### Device Groups -Each collective operation runs within a group of devices. You define groups -using the `grid` and `grid_axes` attributes, which describe how to slice the -full device grid into smaller groups. +Collective operations run within groups of devices, which are defined +using the `grid` and `grid_axes` attributes. These describe +how the full device grid is sliced into smaller groups. Devices that have the same coordinates *outside* the listed `grid_axes` belong to the same group. diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h index 2db1d84..fe42a20 100644 --- a/mlir/include/mlir-c/Rewrite.h +++ b/mlir/include/mlir-c/Rewrite.h @@ -352,7 +352,7 @@ typedef struct { /// Create a rewrite pattern that matches the operation /// with the given rootName, corresponding to mlir::OpRewritePattern. -MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate( +MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePatternCreate( MlirStringRef rootName, unsigned benefit, MlirContext context, MlirRewritePatternCallbacks callbacks, void *userData, size_t nGeneratedNames, MlirStringRef *generatedNames); diff --git a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h index 46573e79..60f1888 100644 --- a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h +++ b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h @@ -9,6 +9,7 @@ #define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_ #include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/IR/PatternMatch.h" #include <memory> @@ -19,8 +20,11 @@ class Pass; #include "mlir/Conversion/Passes.h.inc" /// Populate the given list with patterns that convert from Math to ROCDL calls. -void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter, - RewritePatternSet &patterns); +// `chipset` specifies the AMDGPU chipset to target. If `std::nullopt`, +// none of the chipset dependent patterns are added. +void populateMathToROCDLConversionPatterns( + const LLVMTypeConverter &converter, RewritePatternSet &patterns, + std::optional<amdgpu::Chipset> chipset); } // namespace mlir #endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_ diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 25e9d34..9f76f5d 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -778,6 +778,10 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> { let summary = "Convert Math dialect to ROCDL library calls"; let description = [{ This pass converts supported Math ops to ROCDL library calls. + + The chipset option specifies the target AMDGPU architecture. If the chipset + is empty, none of the chipset-dependent patterns are added, and the pass + will not attempt to parse the chipset. }]; let dependentDialects = [ "arith::ArithDialect", @@ -785,6 +789,9 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> { "ROCDL::ROCDLDialect", "vector::VectorDialect", ]; + let options = [Option<"chipset", "chipset", "std::string", + /*default=*/"\"\"", + "Chipset that these operations will run on">]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Affine/LoopUtils.h b/mlir/include/mlir/Dialect/Affine/LoopUtils.h index 9b59af7..830c394 100644 --- a/mlir/include/mlir/Dialect/Affine/LoopUtils.h +++ b/mlir/include/mlir/Dialect/Affine/LoopUtils.h @@ -61,7 +61,7 @@ LogicalResult loopUnrollUpToFactor(AffineForOp forOp, uint64_t unrollFactor); /// Returns true if `loops` is a perfectly nested loop nest, where loops appear /// in it from outermost to innermost. -bool LLVM_ATTRIBUTE_UNUSED isPerfectlyNested(ArrayRef<AffineForOp> loops); +[[maybe_unused]] bool isPerfectlyNested(ArrayRef<AffineForOp> loops); /// Get perfectly nested sequence of loops starting at root of loop nest /// (the first op being another AffineFor, and the second op - a terminator). diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 9753dca..d0811a2 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -973,6 +973,7 @@ def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector", custom<ShuffleType>(ref(type($v1)), type($res), ref($mask)) }]; + let hasFolder = 1; let hasVerifier = 1; string llvmInstName = "ShuffleVector"; @@ -1985,6 +1986,7 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [ OptionalAttr<StrAttr>:$instrument_function_exit, OptionalAttr<UnitAttr>:$no_inline, OptionalAttr<UnitAttr>:$always_inline, + OptionalAttr<UnitAttr>:$inline_hint, OptionalAttr<UnitAttr>:$no_unwind, OptionalAttr<UnitAttr>:$will_return, OptionalAttr<UnitAttr>:$optimize_none, @@ -2037,6 +2039,9 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [ /// Returns true if the `always_inline` attribute is set, false otherwise. bool isAlwaysInline() { return bool(getAlwaysInlineAttr()); } + /// Returns true if the `inline_hint` attribute is set, false otherwise. + bool isInlineHint() { return bool(getInlineHintAttr()); } + /// Returns true if the `optimize_none` attribute is set, false otherwise. bool isOptimizeNone() { return bool(getOptimizeNoneAttr()); } }]; diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 6925cec..68f31e6 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -412,6 +412,32 @@ def ROCDL_WaitExpcntOp: ROCDL_ConcreteNonMemIntrOp<"s.wait.expcnt", [], 0, [0], let assemblyFormat = "$count attr-dict"; } +def ROCDL_WaitAsynccntOp: ROCDL_ConcreteNonMemIntrOp<"s.wait.asynccnt", [], 0, [0], ["count"]>, + Arguments<(ins I16Attr:$count)> { + let summary = "Wait until ASYNCCNT is less than or equal to `count`"; + let description = [{ + Wait for the counter specified to be less-than or equal-to the `count` + before continuing. + + Available on gfx1250+. + }]; + let results = (outs); + let assemblyFormat = "$count attr-dict"; +} + +def ROCDL_WaitTensorcntOp: ROCDL_ConcreteNonMemIntrOp<"s.wait.tensorcnt", [], 0, [0], ["count"]>, + Arguments<(ins I16Attr:$count)> { + let summary = "Wait until TENSORCNT is less than or equal to `count`"; + let description = [{ + Wait for the counter specified to be less-than or equal-to the `count` + before continuing. + + Available on gfx1250+. + }]; + let results = (outs); + let assemblyFormat = "$count attr-dict"; +} + def ROCDL_SetPrioOp : ROCDL_ConcreteNonMemIntrOp<"s.setprio", [], 0, [0], ["priority"]>, Arguments<(ins I16Attr:$priority)> { let assemblyFormat = "$priority attr-dict"; diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h index 8f87235..b8aa497 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h @@ -183,6 +183,10 @@ static constexpr StringLiteral getRoutineInfoAttrName() { return StringLiteral("acc.routine_info"); } +static constexpr StringLiteral getVarNameAttrName() { + return VarNameAttr::name; +} + static constexpr StringLiteral getCombinedConstructsAttrName() { return CombinedConstructsTypeAttr::name; } diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index fecf81b..1eaa21b4 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -415,6 +415,13 @@ def OpenACC_ConstructResource : Resource<"::mlir::acc::ConstructResource">; // Define a resource for the OpenACC current device setting. def OpenACC_CurrentDeviceIdResource : Resource<"::mlir::acc::CurrentDeviceIdResource">; +// Attribute for saving variable names - this can be attached to non-acc-dialect +// operations in order to ensure the name is preserved. +def OpenACC_VarNameAttr : OpenACC_Attr<"VarName", "var_name"> { + let parameters = (ins StringRefParameter<"">:$name); + let assemblyFormat = "`<` $name `>`"; +} + // Used for data specification in data clauses (2.7.1). // Either (or both) extent and upperbound must be specified. def OpenACC_DataBoundsOp : OpenACC_Op<"bounds", diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td index 6736bc8..93e9e3d 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td @@ -73,12 +73,18 @@ def OpenACC_PointerLikeTypeInterface : TypeInterface<"PointerLikeType"> { InterfaceMethod< /*description=*/[{ Generates allocation operations for the pointer-like type. It will create - an allocate that produces memory space for an instance of the current type. + an allocate operation that produces memory space for an instance of the + current type. The `varName` parameter is optional and can be used to provide a name - for the allocated variable. If the current type is represented - in a way that it does not capture the pointee type, `varType` must be - passed in to provide the necessary type information. + for the allocated variable. When provided, it must be used by the + implementation; and if the implementing dialect does not have its own + way to save it, the discardable `acc.var_name` attribute from the acc + dialect will be used. + + If the current type is represented in a way that it does not capture + the pointee type, `varType` must be passed in to provide the necessary + type information. The `originalVar` parameter is optional but enables support for dynamic types (e.g., dynamic memrefs). When provided, implementations can extract diff --git a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td index b9d7163..5e68f75e 100644 --- a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td +++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td @@ -494,7 +494,9 @@ def Shard_AllGatherOp : Shard_CollectiveCommunicationOpBase<"all_gather", [ ]> { let summary = "All-gather over a device grid."; let description = [{ - Gathers along the `gather_axis` tensor axis. + Concatenates all tensor slices from a device group defined by `grid_axes` along + the tensor dimension `gather_axis` and replicates the result across all devices + in the group. Example: ```mlir @@ -546,10 +548,13 @@ def Shard_AllReduceOp : Shard_CollectiveCommunicationOpBase<"all_reduce", [ SameOperandsAndResultShape]> { let summary = "All-reduce over a device grid."; let description = [{ - The accumulation element type is specified by the result type and - it does not need to match the input element type. - The input element is converted to the result element type before - performing the reduction. + Reduces the input tensor across all devices within the groups defined by + `grid_axes`, using the specified reduction method. The operation performs an + element-wise reduction over the tensor slices from all devices in each group. + Each device in a group receives a replicated copy of the reduction result. + The accumulation element type is determined by the result type and does not + need to match the input element type. Before performing the reduction, each + input element is converted to the result element type. Attributes: `reduction`: Indicates the reduction method. @@ -583,13 +588,15 @@ def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [ SameOperandsAndResultElementType, SameOperandsAndResultRank ]> { - let summary = "All-slice over a device grid. This is the inverse of all-gather."; + let summary = "All-slice over a device grid."; let description = [{ - Slice along the `slice_axis` tensor axis. - This operation can be thought of as the inverse of all-gather. - Technically, it is not required that all processes have the same input tensor. - Each process will slice a piece of its local tensor based on its in-group device index. - The operation does not communicate data between devices. + Within each device group defined by `grid_axes`, slices the input tensor along + the `slice_axis` dimension. It can be viewed as the inverse of an all-gather if + the input data is replicated along the `slice_axis`. + Each process simply crops its local data to the slice corresponding to its + in-group device index. + Notice: `AllSliceOp` does not involve any communication between devices and + devices within a group may not have replicated input data. Example: ```mlir @@ -610,7 +617,7 @@ def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [ ``` Result: ``` - gather tensor + slice tensor axis 1 ------------> +-------+-------+ @@ -646,8 +653,10 @@ def Shard_AllToAllOp : Shard_CollectiveCommunicationOpBase<"all_to_all", [ SameOperandsAndResultRank]> { let summary = "All-to-all over a device grid."; let description = [{ - Performs an all-to-all on tensor pieces split along `split_axis`. - The resulting pieces are concatenated along `concat_axis` on ech device. + Each participant logically splits its input along split_axis, + then scatters the resulting pieces across the group defined by `grid_axes`. + After receiving data pieces from other participants' scatters, + it concatenates them along concat_axis to produce the final result. Example: ``` @@ -702,10 +711,9 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [ ]> { let summary = "Broadcast over a device grid."; let description = [{ - Broadcast the tensor on `root` to all devices in each respective group. - The operation broadcasts along grid axes `grid_axes`. - The `root` device specifies the in-group multi-index that is broadcast to - all other devices in the group. + Copies the input tensor on `root` to all devices in each group defined by + `grid_axes`. The `root` device is defined by its in-group multi-index. + The contents of input tensors on non-root devices are ignored. Example: ``` @@ -722,7 +730,7 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [ +-------+-------+ | broadcast device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1) | along axis 0 +-------+-------+ ↓ - device (1, 0) -> | | | <- device (1, 1) + device (1, 0) -> | * * | * * | <- device (1, 1) +-------+-------+ ``` @@ -758,11 +766,10 @@ def Shard_GatherOp : Shard_CollectiveCommunicationOpBase<"gather", [ ]> { let summary = "Gather over a device grid."; let description = [{ - Gathers on device `root` along the `gather_axis` tensor axis. - `root` specifies the coordinates of a device along `grid_axes`. - It uniquely identifies the root device for each device group. - The result tensor on non-root devices is undefined. - Using it will result in undefined behavior. + Concatenates all tensor slices from a device group defined by `grid_axes` along + the tensor dimension `gather_axis` and returns the resulting tensor on each + `root` device. The result on all other (non-root) devices is undefined. + The `root` device is defined by its in-group multi-index. Example: ```mlir @@ -821,7 +828,9 @@ def Shard_RecvOp : Shard_CollectiveCommunicationOpBase<"recv", [ ]> { let summary = "Send over a device grid."; let description = [{ - Receive from a device within a device group. + Receive tensor from device `source`, which is defined by its in-group + multi-index. The groups are defined by `grid_axes`. + The content of input tensor is ignored. }]; let arguments = !con(commonArgs, (ins AnyNon0RankedTensor:$input, @@ -845,13 +854,15 @@ def Shard_ReduceOp : Shard_CollectiveCommunicationOpBase<"reduce", [ ]> { let summary = "Reduce over a device grid."; let description = [{ - Reduces on device `root` within each device group. - `root` specifies the coordinates of a device along `grid_axes`. - It uniquely identifies the root device within its device group. - The accumulation element type is specified by the result type and - it does not need to match the input element type. - The input element is converted to the result element type before - performing the reduction. + Reduces the input tensor across all devices within the groups defined by + `grid_axes`, using the specified reduction method. The operation performs an + element-wise reduction over the tensor slices from all devices in each group. + The reduction result will be returned on the `root` device of each group. + It is undefined on all other (non-root) devices. + The `root` device is defined by its in-group multi-index. + The accumulation element type is determined by the result type and does not + need to match the input element type. Before performing the reduction, each + input element is converted to the result element type. Attributes: `reduction`: Indicates the reduction method. @@ -886,16 +897,18 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter" SameOperandsAndResultRank]> { let summary = "Reduce-scatter over a device grid."; let description = [{ - After the reduction, the result is scattered within each device group. - The tensor is split along `scatter_axis` and the pieces distributed - across the device group. + Reduces the input tensor across all devices within the groups defined by + `grid_axes` using the specified reduction method. The reduction is performed + element-wise across the tensor pieces from all devices in the group. + After reduction, the reduction result is scattered (split and distributed) + across the device group along `scatter_axis`. Example: ``` shard.grid @grid0(shape = 2x2) ... %1 = shard.reduce_scatter %0 on @grid0 grid_axes = [1] reduction = <max> scatter_axis = 0 - : tensor<3x4xf32> -> tensor<1x4xf64> + : tensor<2x2xf32> -> tensor<1x2xf64> ``` Input: ``` @@ -916,13 +929,13 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter" Result: ``` +-------+ - | 6 8 | <- devices (0, 0) + | 5 6 | <- devices (0, 0) +-------+ - | 10 12 | <- devices (0, 1) + | 7 8 | <- devices (0, 1) +-------+ - | 22 24 | <- devices (1, 0) + | 13 14 | <- devices (1, 0) +-------+ - | 26 28 | <- devices (1, 1) + | 15 16 | <- devices (1, 1) +-------+ ``` }]; @@ -950,8 +963,10 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [ ]> { let summary = "Scatter over a device grid."; let description = [{ - For each device group split the input tensor on the `root` device along - axis `scatter_axis` and scatter the parts across the group devices. + For each device group defined by `grid_axes`, the input tensor on the `root` + device is split along axis `scatter_axis` and distributed across the group. + The content of the input on all other (non-root) devices is ignored. + The `root` device is defined by its in-group multi-index. Example: ``` @@ -968,8 +983,8 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [ (0, 1) ↓ +-------+-------+ | scatter tensor - device (0, 0) -> | | | | axis 0 - | | | ↓ + device (0, 0) -> | * * | * * | | axis 0 + | * * | * * | ↓ +-------+-------+ device (1, 0) -> | 1 2 | 5 6 | | 3 4 | 7 8 | @@ -1018,7 +1033,8 @@ def Shard_SendOp : Shard_CollectiveCommunicationOpBase<"send", [ ]> { let summary = "Send over a device grid."; let description = [{ - Send from one device to another within a device group. + Send input tensor to device `destination`, which is defined by its in-group + multi-index. The groups are defined by `grid_axes`. }]; let arguments = !con(commonArgs, (ins AnyNon0RankedTensor:$input, @@ -1043,12 +1059,11 @@ def Shard_ShiftOp : Shard_CollectiveCommunicationOpBase<"shift", [ ]> { let summary = "Shift over a device grid."; let description = [{ - Within each device group shift along grid axis `shift_axis` by an offset - `offset`. - The result on devices that do not have a corresponding source is undefined. - `shift_axis` must be one of `grid_axes`. - If the `rotate` attribute is present, - instead of a shift a rotation is done. + Within each device group defined by `grid_axes`, shifts input tensors along the + device grid's axis `shift_axis` by the specified offset. The `shift_axis` must + be one of the `grid_axes`. If the `rotate` attribute is set, the shift is circular. + That is, the offset wraps around according to the group size along `shift_axis`. + Otherwise, the results on devices without a corresponding source are undefined. Example: ``` diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h index 10491f6..4ecf03c 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h @@ -50,28 +50,63 @@ TargetEnvAttr getDefaultTargetEnv(MLIRContext *context); /// returned by getDefaultTargetEnv() if not provided. TargetEnvAttr lookupTargetEnvOrDefault(Operation *op); +/// A thin wrapper around the SpecificationVersion enum to represent +/// and provide utilities around the TOSA specification version. +class TosaSpecificationVersion { +public: + TosaSpecificationVersion(uint32_t major, uint32_t minor) + : majorVersion(major), minorVersion(minor) {} + TosaSpecificationVersion(SpecificationVersion version) + : TosaSpecificationVersion(fromVersionEnum(version)) {} + + bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const { + return this->majorVersion == baseVersion.majorVersion && + this->minorVersion >= baseVersion.minorVersion; + } + + uint32_t getMajor() const { return majorVersion; } + uint32_t getMinor() const { return minorVersion; } + +private: + uint32_t majorVersion = 0; + uint32_t minorVersion = 0; + + static TosaSpecificationVersion + fromVersionEnum(SpecificationVersion version) { + switch (version) { + case SpecificationVersion::V_1_0: + return TosaSpecificationVersion(1, 0); + case SpecificationVersion::V_1_1_DRAFT: + return TosaSpecificationVersion(1, 1); + } + llvm_unreachable("Unknown TOSA version"); + } +}; + +llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version); + /// This class represents the capability enabled in the target implementation /// such as profile, extension, and level. It's a wrapper class around /// tosa::TargetEnvAttr. class TargetEnv { public: TargetEnv() {} - explicit TargetEnv(Level level, const ArrayRef<Profile> &profiles, + explicit TargetEnv(SpecificationVersion specificationVersion, Level level, + const ArrayRef<Profile> &profiles, const ArrayRef<Extension> &extensions) - : level(level) { + : specificationVersion(specificationVersion), level(level) { enabledProfiles.insert_range(profiles); enabledExtensions.insert_range(extensions); } explicit TargetEnv(TargetEnvAttr targetAttr) - : TargetEnv(targetAttr.getLevel(), targetAttr.getProfiles(), - targetAttr.getExtensions()) {} + : TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(), + targetAttr.getProfiles(), targetAttr.getExtensions()) {} void addProfile(Profile p) { enabledProfiles.insert(p); } void addExtension(Extension e) { enabledExtensions.insert(e); } - // TODO implement the following utilities. - // Version getSpecVersion() const; + SpecificationVersion getSpecVersion() const { return specificationVersion; } TosaLevel getLevel() const { if (level == Level::eightK) @@ -105,6 +140,7 @@ public: } private: + SpecificationVersion specificationVersion; Level level; llvm::SmallSet<Profile, 3> enabledProfiles; llvm::SmallSet<Extension, 13> enabledExtensions; diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc index 1f718ac..c1b5e78 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc @@ -2,441 +2,779 @@ // `tools/genspec.py` in https://git.mlplatform.org/tosa/specification.git profileComplianceMap = { {"tosa.argmax", - {{{Profile::pro_int}, {{i8T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, i32T}, {fp32T, i32T}}}}}, + {{{Profile::pro_int}, {{{i8T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, i32T}, SpecificationVersion::V_1_0}, + {{fp32T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.avg_pool2d", - {{{Profile::pro_int}, {{i8T, i8T, i8T, i32T, i8T}}}, + {{{Profile::pro_int}, + {{{i8T, i8T, i8T, i32T, i8T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, fp16T, fp16T, fp16T, fp16T}, - {fp16T, fp16T, fp16T, fp32T, fp16T}, - {fp32T, fp32T, fp32T, fp32T, fp32T}}}}}, + {{{fp16T, fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp16T, fp16T, fp16T, fp32T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.conv2d", - {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}}, + {{{Profile::pro_int}, + {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, - {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, - {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}}, + {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}, + SpecificationVersion::V_1_0}}}}}, {"tosa.conv3d", - {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}}, + {{{Profile::pro_int}, + {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, - {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, - {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}}, + {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}, + SpecificationVersion::V_1_0}}}}}, {"tosa.depthwise_conv2d", - {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}}, + {{{Profile::pro_int}, + {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, - {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, - {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}}, + {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}, + SpecificationVersion::V_1_0}}}}}, {"tosa.matmul", - {{{Profile::pro_int}, {{i8T, i8T, i8T, i8T, i32T}}}, + {{{Profile::pro_int}, + {{{i8T, i8T, i8T, i8T, i32T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, fp16T, fp16T, fp16T, fp16T}, - {fp16T, fp16T, fp16T, fp16T, fp32T}, - {fp32T, fp32T, fp32T, fp32T, fp32T}}}}}, + {{{fp16T, fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp16T, fp16T, fp16T, fp16T, fp32T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.max_pool2d", - {{{Profile::pro_int}, {{i8T, i8T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, {{{i8T, i8T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.transpose_conv2d", - {{{Profile::pro_int}, {{i8T, i8T, i32T, i8T, i8T, i32T, i32T}}}, + {{{Profile::pro_int}, + {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, - {fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, - {fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}}}}}, + {{{fp16T, fp16T, fp16T, fp16T, fp16T, fp16T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp16T, fp16T, fp16T, fp16T, fp16T, fp32T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T, fp32T, fp32T, fp32T, fp32T}, + SpecificationVersion::V_1_0}}}}}, {"tosa.clamp", - {{{Profile::pro_int}, {{i8T, i8T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.erf", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.sigmoid", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.tanh", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, {{{i8T, i8T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.erf", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.sigmoid", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.tanh", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.add", - {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf}, - {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.arithmetic_right_shift", {{{Profile::pro_int}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.bitwise_and", {{{Profile::pro_int}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.bitwise_or", {{{Profile::pro_int}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.bitwise_xor", {{{Profile::pro_int}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.intdiv", - {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}, + anyOf}}}, {"tosa.logical_and", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}}}, {"tosa.logical_left_shift", {{{Profile::pro_int, Profile::pro_fp}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}, anyOf}}}, {"tosa.logical_right_shift", {{{Profile::pro_int, Profile::pro_fp}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}, anyOf}}}, {"tosa.logical_or", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}}}, {"tosa.logical_xor", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}}}, {"tosa.maximum", - {{{Profile::pro_int}, {{i32T, i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {{{Profile::pro_int}, {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.minimum", - {{{Profile::pro_int}, {{i32T, i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {{{Profile::pro_int}, {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.mul", - {{{Profile::pro_int}, {{i8T, i8T, i32T}, {i16T, i16T, i32T}}}, - {{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf}, - {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {{{Profile::pro_int}, + {{{i8T, i8T, i32T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_int, Profile::pro_fp}, + {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.pow", - {{{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {{{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.sub", - {{{Profile::pro_int, Profile::pro_fp}, {{i32T, i32T, i32T}}, anyOf}, - {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, - {"tosa.table", {{{Profile::pro_int}, {{i8T, i8T, i8T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.table", + {{{Profile::pro_int}, {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}}}}}, {"tosa.abs", - {{{Profile::pro_int}, {{i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, {{{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.bitwise_not", - {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}}}, - {"tosa.ceil", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.clz", {{{Profile::pro_int}, {{i32T, i32T}}}}}, - {"tosa.cos", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.exp", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.floor", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.log", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.ceil", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.clz", + {{{Profile::pro_int}, {{{i32T, i32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.cos", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.exp", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.floor", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.log", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.logical_not", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}}}, {"tosa.negate", {{{Profile::pro_int}, - {{i8T, i8T, i8T, i8T}, - {i16T, i16T, i16T, i16T}, - {i32T, i32T, i32T, i32T}}}, + {{{i8T, i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T, fp32T}}}}}, + {{{fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reciprocal", - {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.rsqrt", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, - {"tosa.sin", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.rsqrt", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.sin", + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.select", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, {{Profile::pro_int}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.equal", - {{{Profile::pro_int}, {{i32T, i32T, boolT}}}, - {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}}, + {{{Profile::pro_int}, + {{{i32T, i32T, boolT}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, boolT}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, boolT}, SpecificationVersion::V_1_0}}}}}, {"tosa.greater", - {{{Profile::pro_int}, {{i32T, i32T, boolT}}}, - {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}}, + {{{Profile::pro_int}, + {{{i32T, i32T, boolT}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, boolT}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, boolT}, SpecificationVersion::V_1_0}}}}}, {"tosa.greater_equal", - {{{Profile::pro_int}, {{i32T, i32T, boolT}}}, - {{Profile::pro_fp}, {{fp16T, fp16T, boolT}, {fp32T, fp32T, boolT}}}}}, + {{{Profile::pro_int}, + {{{i32T, i32T, boolT}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, boolT}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, boolT}, SpecificationVersion::V_1_0}}}}}, {"tosa.reduce_all", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}}}, {"tosa.reduce_any", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}}}, {"tosa.reduce_max", - {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reduce_min", - {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reduce_product", - {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reduce_sum", - {{{Profile::pro_int}, {{i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, {{{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.concat", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}, - {{Profile::pro_int}, {{i8T, i8T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.pad", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT, boolT}}, anyOf}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, {{Profile::pro_int}, - {{i8T, i8T, i8T}, {i16T, i16T, i16T}, {i32T, i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T}}}}}, + {{{i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reshape", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}, - {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reverse", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}, - {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.slice", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}, - {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.tile", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}, - {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.transpose", - {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}, anyOf}, - {{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int, Profile::pro_fp}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}}, + anyOf}, + {{Profile::pro_int}, + {{{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.gather", {{{Profile::pro_int}, - {{i8T, i32T, i8T}, {i16T, i32T, i16T}, {i32T, i32T, i32T}}}, - {{Profile::pro_fp}, {{fp16T, i32T, fp16T}, {fp32T, i32T, fp32T}}}}}, + {{{i8T, i32T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i32T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, i32T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, i32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.scatter", {{{Profile::pro_int}, - {{i8T, i32T, i8T, i8T}, - {i16T, i32T, i16T, i16T}, - {i32T, i32T, i32T, i32T}}}, + {{{i8T, i32T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i32T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}, {{Profile::pro_fp}, - {{fp16T, i32T, fp16T, fp16T}, {fp32T, i32T, fp32T, fp32T}}}}}, + {{{fp16T, i32T, fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, i32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.resize", - {{{Profile::pro_int}, {{i8T, i32T}, {i8T, i8T}}}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{{Profile::pro_int}, + {{{i8T, i32T}, SpecificationVersion::V_1_0}, + {{i8T, i8T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.cast", {{{Profile::pro_int}, - {{boolT, i8T}, - {boolT, i16T}, - {boolT, i32T}, - {i8T, boolT}, - {i8T, i16T}, - {i8T, i32T}, - {i16T, boolT}, - {i16T, i8T}, - {i16T, i32T}, - {i32T, boolT}, - {i32T, i8T}, - {i32T, i16T}}}, - {{Profile::pro_fp}, - {{i8T, fp16T}, - {i8T, fp32T}, - {i16T, fp16T}, - {i16T, fp32T}, - {i32T, fp16T}, - {i32T, fp32T}, - {fp16T, i8T}, - {fp16T, i16T}, - {fp16T, i32T}, - {fp16T, fp32T}, - {fp32T, i8T}, - {fp32T, i16T}, - {fp32T, i32T}, - {fp32T, fp16T}}}}}, + {{{boolT, i8T}, SpecificationVersion::V_1_0}, + {{boolT, i16T}, SpecificationVersion::V_1_0}, + {{boolT, i32T}, SpecificationVersion::V_1_0}, + {{i8T, boolT}, SpecificationVersion::V_1_0}, + {{i8T, i16T}, SpecificationVersion::V_1_0}, + {{i8T, i32T}, SpecificationVersion::V_1_0}, + {{i16T, boolT}, SpecificationVersion::V_1_0}, + {{i16T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i32T}, SpecificationVersion::V_1_0}, + {{i32T, boolT}, SpecificationVersion::V_1_0}, + {{i32T, i8T}, SpecificationVersion::V_1_0}, + {{i32T, i16T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{i8T, fp16T}, SpecificationVersion::V_1_0}, + {{i8T, fp32T}, SpecificationVersion::V_1_0}, + {{i16T, fp16T}, SpecificationVersion::V_1_0}, + {{i16T, fp32T}, SpecificationVersion::V_1_0}, + {{i32T, fp16T}, SpecificationVersion::V_1_0}, + {{i32T, fp32T}, SpecificationVersion::V_1_0}, + {{fp16T, i8T}, SpecificationVersion::V_1_0}, + {{fp16T, i16T}, SpecificationVersion::V_1_0}, + {{fp16T, i32T}, SpecificationVersion::V_1_0}, + {{fp16T, fp32T}, SpecificationVersion::V_1_0}, + {{fp32T, i8T}, SpecificationVersion::V_1_0}, + {{fp32T, i16T}, SpecificationVersion::V_1_0}, + {{fp32T, i32T}, SpecificationVersion::V_1_0}, + {{fp32T, fp16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.rescale", {{{Profile::pro_int}, - {{i8T, i8T, i8T, i8T}, - {i8T, i8T, i16T, i16T}, - {i8T, i8T, i32T, i32T}, - {i16T, i16T, i8T, i8T}, - {i16T, i16T, i16T, i16T}, - {i16T, i16T, i32T, i32T}, - {i32T, i32T, i8T, i8T}, - {i32T, i32T, i16T, i16T}, - {i32T, i32T, i32T, i32T}}}}}, + {{{i8T, i8T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i8T, i8T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i16T, i16T, i32T, i32T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.const", {{{Profile::pro_int, Profile::pro_fp}, - {{boolT}, {i8T}, {i16T}, {i32T}}, + {{{boolT}, SpecificationVersion::V_1_0}, + {{i8T}, SpecificationVersion::V_1_0}, + {{i16T}, SpecificationVersion::V_1_0}, + {{i32T}, SpecificationVersion::V_1_0}}, anyOf}, - {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}}, + {{Profile::pro_fp}, + {{{fp16T}, SpecificationVersion::V_1_0}, + {{fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.identity", {{{Profile::pro_int, Profile::pro_fp}, - {{boolT, boolT}, {i8T, i8T}, {i16T, i16T}, {i32T, i32T}}, + {{{boolT, boolT}, SpecificationVersion::V_1_0}, + {{i8T, i8T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}, + {{i32T, i32T}, SpecificationVersion::V_1_0}}, anyOf}, - {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}}, + {{Profile::pro_fp}, + {{{fp16T, fp16T}, SpecificationVersion::V_1_0}, + {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.variable", - {{{Profile::pro_int}, {{i8T}}}, {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}}, + {{{Profile::pro_int}, {{{i8T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T}, SpecificationVersion::V_1_0}, + {{fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.variable_write", - {{{Profile::pro_int}, {{i8T}}}, {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}}, + {{{Profile::pro_int}, {{{i8T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T}, SpecificationVersion::V_1_0}, + {{fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.variable_read", - {{{Profile::pro_int}, {{i8T}}}, {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}}, + {{{Profile::pro_int}, {{{i8T}, SpecificationVersion::V_1_0}}}, + {{Profile::pro_fp}, + {{{fp16T}, SpecificationVersion::V_1_0}, + {{fp32T}, SpecificationVersion::V_1_0}}}}}, }; extensionComplianceMap = { {"tosa.argmax", - {{{Extension::int16}, {{i16T, i32T}}}, - {{Extension::fp8e4m3}, {{fp8e4m3T, i32T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, i32T}}}, - {{Extension::bf16}, {{bf16T, i32T}}}}}, + {{{Extension::int16}, {{{i16T, i32T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e4m3}, {{{fp8e4m3T, i32T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, {{{fp8e5m2T, i32T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.avg_pool2d", - {{{Extension::int16}, {{i16T, i16T, i16T, i32T, i16T}}}, - {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}}, + {{{Extension::int16}, + {{{i16T, i16T, i16T, i32T, i16T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T}, + SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T}, + SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, + {{{bf16T, bf16T, bf16T, fp32T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.conv2d", - {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}}, - {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}}, + {{{Extension::int4}, + {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Extension::int16}, + {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e4m3}, - {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}}, + {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e5m2}, - {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}}, + {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::bf16}, - {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}}, + {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}, + SpecificationVersion::V_1_0}}}}}, {"tosa.conv3d", - {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}}, - {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}}, + {{{Extension::int4}, + {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Extension::int16}, + {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e4m3}, - {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}}, + {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e5m2}, - {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}}, + {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::bf16}, - {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}}, + {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}, + SpecificationVersion::V_1_0}}}}}, {"tosa.depthwise_conv2d", - {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}}, - {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}}, + {{{Extension::int4}, + {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Extension::int16}, + {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e4m3}, - {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}}, + {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e5m2}, - {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}}, + {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::bf16}, - {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}}, - {"tosa.fft2d", {{{Extension::fft}, {{fp32T, fp32T, fp32T, fp32T}}}}}, + {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}, + SpecificationVersion::V_1_0}}}}}, + {"tosa.fft2d", + {{{Extension::fft}, + {{{fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.matmul", - {{{Extension::int16}, {{i16T, i16T, i16T, i16T, i48T}}}, + {{{Extension::int16}, + {{{i16T, i16T, i16T, i16T, i48T}, SpecificationVersion::V_1_0}}}, {{Extension::fp8e4m3}, - {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T}, - {fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp32T}}}, + {{{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp32T}, + SpecificationVersion::V_1_1_DRAFT}}}, {{Extension::fp8e5m2}, - {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T}, - {fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp32T}}}, + {{{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T}, + SpecificationVersion::V_1_0}, + {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp32T}, + SpecificationVersion::V_1_1_DRAFT}}}, {{Extension::fp8e4m3, Extension::fp8e5m2}, - {{fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp16T}, - {fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp32T}, - {fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp16T}, - {fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp32T}}, + {{{fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp16T}, + SpecificationVersion::V_1_1_DRAFT}, + {{fp8e4m3T, fp8e5m2T, fp8e4m3T, fp8e5m2T, fp32T}, + SpecificationVersion::V_1_1_DRAFT}, + {{fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp16T}, + SpecificationVersion::V_1_1_DRAFT}, + {{fp8e5m2T, fp8e4m3T, fp8e5m2T, fp8e4m3T, fp32T}, + SpecificationVersion::V_1_1_DRAFT}}, allOf}, - {{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T, fp32T}}}}}, + {{Extension::bf16}, + {{{bf16T, bf16T, bf16T, bf16T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.max_pool2d", - {{{Extension::int16}, {{i16T, i16T}}}, - {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.rfft2d", {{{Extension::fft}, {{fp32T, fp32T, fp32T}}}}}, + {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.rfft2d", + {{{Extension::fft}, + {{{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.transpose_conv2d", - {{{Extension::int4}, {{i8T, i4T, i32T, i8T, i4T, i32T, i32T}}}, - {{Extension::int16}, {{i16T, i8T, i48T, i16T, i8T, i48T, i48T}}}, + {{{Extension::int4}, + {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}}, + {{Extension::int16}, + {{{i16T, i8T, i48T, i16T, i8T, i48T, i48T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e4m3}, - {{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}}}, + {{{fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T, fp8e4m3T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::fp8e5m2}, - {{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}}}, + {{{fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T, fp8e5m2T, fp16T, fp16T}, + SpecificationVersion::V_1_0}}}, {{Extension::bf16}, - {{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}}}}}, + {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T}, + SpecificationVersion::V_1_0}}}}}, {"tosa.clamp", - {{{Extension::int16}, {{i16T, i16T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.erf", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.sigmoid", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.tanh", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.add", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, - {"tosa.maximum", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, - {"tosa.minimum", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, - {"tosa.mul", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, - {"tosa.pow", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, - {"tosa.sub", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, - {"tosa.table", {{{Extension::int16}, {{i16T, i16T, i32T}}}}}, - {"tosa.abs", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.ceil", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.cos", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.exp", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.floor", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.log", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T}}}}}, - {"tosa.reciprocal", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.rsqrt", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.sin", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.select", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, - {"tosa.equal", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}}, - {"tosa.greater", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}}, - {"tosa.greater_equal", {{{Extension::bf16}, {{bf16T, bf16T, boolT}}}}}, - {"tosa.reduce_max", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.reduce_min", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.reduce_product", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.reduce_sum", {{{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.erf", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.sigmoid", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.tanh", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.add", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.maximum", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.minimum", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.mul", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.pow", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.sub", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.table", + {{{Extension::int16}, + {{{i16T, i16T, i32T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.abs", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.ceil", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.cos", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.exp", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.floor", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.log", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.negate", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.reciprocal", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.rsqrt", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.sin", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.select", + {{{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.equal", + {{{Extension::bf16}, + {{{bf16T, bf16T, boolT}, SpecificationVersion::V_1_0}}}}}, + {"tosa.greater", + {{{Extension::bf16}, + {{{bf16T, bf16T, boolT}, SpecificationVersion::V_1_0}}}}}, + {"tosa.greater_equal", + {{{Extension::bf16}, + {{{bf16T, bf16T, boolT}, SpecificationVersion::V_1_0}}}}}, + {"tosa.reduce_max", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.reduce_min", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.reduce_product", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.reduce_sum", + {{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.concat", - {{{Extension::int16}, {{i16T, i16T}}}, - {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.pad", - {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, + {{{bf16T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reshape", - {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.reverse", - {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.slice", - {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.tile", - {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.transpose", - {{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.gather", - {{{Extension::fp8e4m3}, {{fp8e4m3T, i32T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, i32T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, i32T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, i32T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, i32T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, + {{{bf16T, i32T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.scatter", - {{{Extension::fp8e4m3}, {{fp8e4m3T, i32T, fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, i32T, fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, i32T, bf16T, bf16T}}}}}, + {{{Extension::fp8e4m3}, + {{{fp8e4m3T, i32T, fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, i32T, fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, + {{{bf16T, i32T, bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.resize", - {{{Extension::int16}, {{i16T, i48T}, {i16T, i16T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, + {{{Extension::int16}, + {{{i16T, i48T}, SpecificationVersion::V_1_0}, + {{i16T, i16T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.cast", {{{Extension::bf16}, - {{i8T, bf16T}, - {i16T, bf16T}, - {i32T, bf16T}, - {bf16T, i8T}, - {bf16T, i16T}, - {bf16T, i32T}, - {bf16T, fp32T}, - {fp32T, bf16T}}}, + {{{i8T, bf16T}, SpecificationVersion::V_1_0}, + {{i16T, bf16T}, SpecificationVersion::V_1_0}, + {{i32T, bf16T}, SpecificationVersion::V_1_0}, + {{bf16T, i8T}, SpecificationVersion::V_1_0}, + {{bf16T, i16T}, SpecificationVersion::V_1_0}, + {{bf16T, i32T}, SpecificationVersion::V_1_0}, + {{bf16T, fp32T}, SpecificationVersion::V_1_0}, + {{fp32T, bf16T}, SpecificationVersion::V_1_0}}}, {{Extension::bf16, Extension::fp8e4m3}, - {{bf16T, fp8e4m3T}, {fp8e4m3T, bf16T}}, + {{{bf16T, fp8e4m3T}, SpecificationVersion::V_1_0}, + {{fp8e4m3T, bf16T}, SpecificationVersion::V_1_0}}, allOf}, {{Extension::bf16, Extension::fp8e5m2}, - {{bf16T, fp8e5m2T}, {fp8e5m2T, bf16T}}, + {{{bf16T, fp8e5m2T}, SpecificationVersion::V_1_0}, + {{fp8e5m2T, bf16T}, SpecificationVersion::V_1_0}}, allOf}, {{Extension::fp8e4m3}, - {{fp8e4m3T, fp16T}, - {fp8e4m3T, fp32T}, - {fp16T, fp8e4m3T}, - {fp32T, fp8e4m3T}}}, + {{{fp8e4m3T, fp16T}, SpecificationVersion::V_1_0}, + {{fp8e4m3T, fp32T}, SpecificationVersion::V_1_0}, + {{fp16T, fp8e4m3T}, SpecificationVersion::V_1_0}, + {{fp32T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, {{Extension::fp8e5m2}, - {{fp8e5m2T, fp16T}, - {fp8e5m2T, fp32T}, - {fp16T, fp8e5m2T}, - {fp32T, fp8e5m2T}}}}}, + {{{fp8e5m2T, fp16T}, SpecificationVersion::V_1_0}, + {{fp8e5m2T, fp32T}, SpecificationVersion::V_1_0}, + {{fp16T, fp8e5m2T}, SpecificationVersion::V_1_0}, + {{fp32T, fp8e5m2T}, SpecificationVersion::V_1_0}}}}}, {"tosa.rescale", {{{Extension::int16}, - {{i48T, i48T, i8T, i8T}, - {i48T, i48T, i16T, i16T}, - {i48T, i48T, i32T, i32T}}}}}, + {{{i48T, i48T, i8T, i8T}, SpecificationVersion::V_1_0}, + {{i48T, i48T, i16T, i16T}, SpecificationVersion::V_1_0}, + {{i48T, i48T, i32T, i32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.const", - {{{Extension::int4}, {{i4T}}}, - {{Extension::int16}, {{i48T}}}, - {{Extension::fp8e4m3}, {{fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T}}}}}, + {{{Extension::int4}, {{{i4T}, SpecificationVersion::V_1_0}}}, + {{Extension::int16}, {{{i48T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e4m3}, {{{fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, {{{fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T}, SpecificationVersion::V_1_0}}}}}, {"tosa.identity", - {{{Extension::int4}, {{i4T, i4T}}}, - {{Extension::int16}, {{i48T, i48T}}}, - {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}}, - {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T}}}, - {{Extension::bf16}, {{bf16T, bf16T}}}}}, - {"tosa.variable", {{{Extension::variable}, {{i8T}, {fp16T}, {fp32T}}}}}, + {{{Extension::int4}, {{{i4T, i4T}, SpecificationVersion::V_1_0}}}, + {{Extension::int16}, {{{i48T, i48T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e4m3}, + {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}}, + {{Extension::fp8e5m2}, + {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}}, + {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}}, + {"tosa.variable", + {{{Extension::variable}, + {{{i8T}, SpecificationVersion::V_1_0}, + {{fp16T}, SpecificationVersion::V_1_0}, + {{fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.variable_write", - {{{Extension::variable}, {{i8T}, {fp16T}, {fp32T}}}}}, + {{{Extension::variable}, + {{{i8T}, SpecificationVersion::V_1_0}, + {{fp16T}, SpecificationVersion::V_1_0}, + {{fp32T}, SpecificationVersion::V_1_0}}}}}, {"tosa.variable_read", - {{{Extension::variable}, {{i8T}, {fp16T}, {fp32T}}}}}, + {{{Extension::variable}, + {{{i8T}, SpecificationVersion::V_1_0}, + {{fp16T}, SpecificationVersion::V_1_0}, + {{fp32T}, SpecificationVersion::V_1_0}}}}}, }; + // End of auto-generated metadata diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td index 38cb293..8376a4c 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -221,7 +221,7 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic, } //===----------------------------------------------------------------------===// -// TOSA Spec Section 1.5. +// TOSA Profiles and extensions // // Profile: // INT : Integer Inference. Integer operations, primarily 8 and 32-bit values. @@ -293,12 +293,6 @@ def Tosa_ExtensionAttr def Tosa_ExtensionArrayAttr : TypedArrayAttrBase<Tosa_ExtensionAttr, "TOSA extension array attribute">; -def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>; -def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">; - -def Tosa_LevelAttr - : Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>; - // The base class for defining op availability dimensions. class Availability { // The following are fields for controlling the generated C++ OpInterface. @@ -405,17 +399,40 @@ class Extension<list<I32EnumAttrCase> extensions> : Availability { } //===----------------------------------------------------------------------===// +// TOSA Levels +//===----------------------------------------------------------------------===// + +def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>; +def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">; + +def Tosa_LevelAttr + : Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>; + +//===----------------------------------------------------------------------===// +// TOSA Specification versions +//===----------------------------------------------------------------------===// + +def Tosa_V_1_0 : I32EnumAttrCase<"V_1_0", 0, "1.0">; +def Tosa_V_1_1_DRAFT : I32EnumAttrCase<"V_1_1_DRAFT", 1, "1.1.draft">; + +def Tosa_SpecificationVersion : Tosa_I32EnumAttr< + "SpecificationVersion", "TOSA specification version", "specification_version", + [Tosa_V_1_0, Tosa_V_1_1_DRAFT]>; + +//===----------------------------------------------------------------------===// // TOSA target environment. //===----------------------------------------------------------------------===// def Tosa_TargetEnv : Tosa_Attr<"TargetEnv", "target_env"> { let summary = "Target environment information."; let parameters = ( ins + "SpecificationVersion": $specification_version, "Level": $level, ArrayRefParameter<"Profile">: $profiles, ArrayRefParameter<"Extension">: $extensions ); - let assemblyFormat = "`<` `level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` " + let assemblyFormat = "`<` `specification_version` `=` $specification_version `,` " + "`level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` " "`extensions` `=` `[` $extensions `]` `>`"; } diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h index 8f5c72b..7b946ad 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h @@ -36,12 +36,15 @@ enum CheckCondition { allOf }; +using VersionedTypeInfo = + std::pair<SmallVector<TypeInfo>, SpecificationVersion>; + template <typename T> struct OpComplianceInfo { // Certain operations require multiple modes enabled. // e.g. cast bf16 to fp8e4m3 requires EXT-BF16 and EXT-FP8E4M3. SmallVector<T> mode; - SmallVector<SmallVector<TypeInfo>> operandTypeInfoSet; + SmallVector<VersionedTypeInfo> operandTypeInfoSet; CheckCondition condition = CheckCondition::anyOf; }; @@ -130,9 +133,8 @@ public: // Find the required profiles or extensions from the compliance info according // to the operand type combination. template <typename T> - SmallVector<T> findMatchedProfile(Operation *op, - SmallVector<OpComplianceInfo<T>> compInfo, - CheckCondition &condition); + OpComplianceInfo<T> + findMatchedEntry(Operation *op, SmallVector<OpComplianceInfo<T>> compInfo); SmallVector<Profile> getCooperativeProfiles(Extension ext) { switch (ext) { @@ -168,8 +170,7 @@ public: private: template <typename T> - FailureOr<SmallVector<T>> getOperatorDefinition(Operation *op, - CheckCondition &condition); + FailureOr<OpComplianceInfo<T>> getOperatorDefinition(Operation *op); OperationProfileComplianceMap profileComplianceMap; OperationExtensionComplianceMap extensionComplianceMap; diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td index 6ae19d8..14b00b0 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -137,6 +137,13 @@ def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> { ]; let options = [ + Option<"specificationVersion", "specification_version", "mlir::tosa::SpecificationVersion", + /*default=*/"mlir::tosa::SpecificationVersion::V_1_0", + "The specification version that TOSA operators should conform to.", + [{::llvm::cl::values( + clEnumValN(mlir::tosa::SpecificationVersion::V_1_0, "1.0", "TOSA Specification version 1.0"), + clEnumValN(mlir::tosa::SpecificationVersion::V_1_1_DRAFT, "1.1.draft", "TOSA Specification version 1.1.draft") + )}]>, Option<"level", "level", "mlir::tosa::Level", /*default=*/"mlir::tosa::Level::eightK", "The TOSA level that operators should conform to. A TOSA level defines " diff --git a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td index b80ee2c..e9425e8 100644 --- a/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td +++ b/mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td @@ -43,9 +43,41 @@ class WasmSSA_BlockLikeOp<string mnemonic, string summaryStr> : let assemblyFormat = "(`(`$inputs^`)` `:` type($inputs))? attr-dict `:` $body `>` $target"; } -def WasmSSA_BlockOp : WasmSSA_BlockLikeOp<"block", "Create a nesting level"> {} +def WasmSSA_BlockOp : WasmSSA_BlockLikeOp< + "block", + "Create a nesting level with a label at its exit."> { + let description = [{ + Defines a Wasm block, creating a new nested scope. + A block contains a body region and an optional list of input values. + Control can enter the block and later branch out to the block target. + Example: + + ```mlir + + wasmssa.block { + + // instructions + + } > ^successor + }]; +} + +def WasmSSA_LoopOp : WasmSSA_BlockLikeOp< + "loop", + "Create a nesting level that define its entry as jump target."> { + let description = [{ + Represents a Wasm loop construct. This defines a nesting level with + a label at the entry of the region. -def WasmSSA_LoopOp : WasmSSA_BlockLikeOp<"loop", "Create a nesting level similar to Block Op, except that it has itself as a successor."> {} + Example: + + ```mlir + + wasmssa.loop { + + } > ^successor + }]; +} def WasmSSA_BlockReturnOp : WasmSSA_Op<"block_return", [Terminator, DeclareOpInterfaceMethods<LabelBranchingOpInterface>]> { @@ -55,9 +87,16 @@ def WasmSSA_BlockReturnOp : WasmSSA_Op<"block_return", [Terminator, ::mlir::Block* getTarget(); }]; let description = [{ - Marks a return from the current block. + Escape from the current nesting level and return the control flow to its successor. + Optionally, mark the arguments that should be transfered to the successor block. - Example: + This shouldn't be confused with branch operations that targets the label defined + by the nesting level operation. + + For instance, a `wasmssa.block_return` in a loop will give back control to the + successor of the loop, where a `branch` targeting the loop will flow back to the entry block of the loop. + + Example: ```mlir wasmssa.block_return @@ -127,12 +166,18 @@ def WasmSSA_FuncOp : WasmSSA_Op<"func", [ - Arguments of the entry block of type `!wasm<local T>`, with T the corresponding type in the function type. + By default, `wasmssa.func` have nested visibility. Functions exported by the module + are marked with the exported attribute. This gives them public visibility. + Example: ```mlir - // A simple function with no arguments that returns a float32 + // Internal function with no arguments that returns a float32 wasmssa.func @my_f32_func() -> f32 + // Exported function with no arguments that returns a float32 + wasmssa.func exported @my_f32_func() -> f32 + // A function that takes a local ref argument wasmssa.func @i64_wrap(%a: !wasmssa<local ref to i64>) -> i32 ``` @@ -141,7 +186,7 @@ def WasmSSA_FuncOp : WasmSSA_Op<"func", [ WasmSSA_FuncTypeAttr: $functionType, OptionalAttr<DictArrayAttr>:$arg_attrs, OptionalAttr<DictArrayAttr>:$res_attrs, - DefaultValuedAttr<StrAttr, "\"nested\"">:$sym_visibility); + UnitAttr: $exported); let regions = (region AnyRegion: $body); let extraClassDeclaration = [{ @@ -162,6 +207,12 @@ def WasmSSA_FuncOp : WasmSSA_Op<"func", [ /// Returns the result types of this function. ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); } + + ::mlir::SymbolTable::Visibility getVisibility() { + return getExported() ? + ::mlir::SymbolTable::Visibility::Public : + ::mlir::SymbolTable::Visibility::Nested; + }; }]; let builders = [ @@ -207,8 +258,7 @@ def WasmSSA_FuncImportOp : WasmSSA_Op<"import_func", [ StrAttr: $importName, WasmSSA_FuncTypeAttr: $type, OptionalAttr<DictArrayAttr>:$arg_attrs, - OptionalAttr<DictArrayAttr>:$res_attrs, - OptionalAttr<StrAttr>:$sym_visibility); + OptionalAttr<DictArrayAttr>:$res_attrs); let extraClassDeclaration = [{ bool isDeclaration() const { return true; } @@ -221,6 +271,10 @@ def WasmSSA_FuncImportOp : WasmSSA_Op<"import_func", [ ::llvm::ArrayRef<Type> getResultTypes() { return getType().getResults(); } + + ::mlir::SymbolTable::Visibility getVisibility() { + return ::mlir::SymbolTable::Visibility::Nested; + }; }]; let builders = [ OpBuilder<(ins "StringRef":$symbol, @@ -238,30 +292,41 @@ def WasmSSA_GlobalOp : WasmSSA_Op<"global", [ let arguments = (ins SymbolNameAttr: $sym_name, WasmSSA_ValTypeAttr: $type, UnitAttr: $isMutable, - OptionalAttr<StrAttr>:$sym_visibility); + UnitAttr: $exported); let description = [{ WebAssembly global variable. Body contains the initialization instructions for the variable value. The body must contain only instructions considered `const` in a webassembly context, such as `wasmssa.const` or `global.get`. + By default, `wasmssa.global` have nested visibility. Global exported by the module + are marked with the exported attribute. This gives them public visibility. + Example: ```mlir - // Define a global_var, a mutable i32 global variable equal to 10. - wasmssa.global @global_var i32 mutable nested : { + // Define module_global_var, an internal mutable i32 global variable equal to 10. + wasmssa.global @module_global_var i32 mutable : { %[[VAL_0:.*]] = wasmssa.const 10 : i32 wasmssa.return %[[VAL_0]] : i32 } + + // Define global_var, an exported constant i32 global variable equal to 42. + wasmssa.global @global_var i32 : { + %[[VAL_0:.*]] = wasmssa.const 42 : i32 + wasmssa.return %[[VAL_0]] : i32 + } ``` }]; let regions = (region AnyRegion: $initializer); - let builders = [ - OpBuilder<(ins "StringRef":$symbol, - "Type": $type, - "bool": $isMutable)> - ]; + let extraClassDeclaration = [{ + ::mlir::SymbolTable::Visibility getVisibility() { + return getExported() ? + ::mlir::SymbolTable::Visibility::Public : + ::mlir::SymbolTable::Visibility::Nested; + }; + }]; let hasCustomAssemblyFormat = 1; } @@ -283,18 +348,14 @@ def WasmSSA_GlobalImportOp : WasmSSA_Op<"import_global", [ StrAttr: $moduleName, StrAttr: $importName, WasmSSA_ValTypeAttr: $type, - UnitAttr: $isMutable, - OptionalAttr<StrAttr>:$sym_visibility); + UnitAttr: $isMutable); let extraClassDeclaration = [{ bool isDeclaration() const { return true; } + + ::mlir::SymbolTable::Visibility getVisibility() { + return ::mlir::SymbolTable::Visibility::Nested; + }; }]; - let builders = [ - OpBuilder<(ins "StringRef":$symbol, - "StringRef":$moduleName, - "StringRef":$importName, - "Type": $type, - "bool": $isMutable)> - ]; let hasCustomAssemblyFormat = 1; } @@ -442,23 +503,33 @@ def WasmSSA_MemOp : WasmSSA_Op<"memory", [Symbol]> { Define a memory to be used by the program. Multiple memories can be defined in the same module. + By default, `wasmssa.memory` have nested visibility. Memory exported by + the module are marked with the exported attribute. This gives them public + visibility. + Example: ```mlir - // Define the `mem_0` memory with defined bounds of 0 -> 65536 + // Define the `mem_0` (internal) memory with defined size bounds of [0:65536] wasmssa.memory @mem_0 !wasmssa<limit[0:65536]> + + // Define the `mem_1` exported memory with minimal size of 512 + wasmssa.memory exported @mem_1 !wasmssa<limit[512:]> ``` }]; let arguments = (ins SymbolNameAttr: $sym_name, WasmSSA_LimitTypeAttr: $limits, - OptionalAttr<StrAttr>:$sym_visibility); - let builders = [ - OpBuilder<(ins - "::llvm::StringRef":$symbol, - "wasmssa::LimitType":$limit)> - ]; + UnitAttr: $exported); - let assemblyFormat = "$sym_name custom<WasmVisibility>($sym_visibility) $limits attr-dict"; + let extraClassDeclaration = [{ + ::mlir::SymbolTable::Visibility getVisibility() { + return getExported() ? + ::mlir::SymbolTable::Visibility::Public : + ::mlir::SymbolTable::Visibility::Nested; + }; + }]; + + let assemblyFormat = "(`exported` $exported^)? $sym_name $limits attr-dict"; } def WasmSSA_MemImportOp : WasmSSA_Op<"import_mem", [Symbol, ImportOpInterface]> { @@ -476,16 +547,13 @@ def WasmSSA_MemImportOp : WasmSSA_Op<"import_mem", [Symbol, ImportOpInterface]> let arguments = (ins SymbolNameAttr: $sym_name, StrAttr: $moduleName, StrAttr: $importName, - WasmSSA_LimitTypeAttr: $limits, - OptionalAttr<StrAttr>:$sym_visibility); + WasmSSA_LimitTypeAttr: $limits); let extraClassDeclaration = [{ - bool isDeclaration() const { return true; } + bool isDeclaration() const { return true; } + ::mlir::SymbolTable::Visibility getVisibility() { + return ::mlir::SymbolTable::Visibility::Nested; + }; }]; - let builders = [OpBuilder<(ins - "::llvm::StringRef":$symbol, - "::llvm::StringRef":$moduleName, - "::llvm::StringRef":$importName, - "wasmssa::LimitType":$limits)>]; let assemblyFormat = "$importName `from` $moduleName `as` $sym_name attr-dict"; } @@ -493,11 +561,15 @@ def WasmSSA_TableOp : WasmSSA_Op<"table", [Symbol]> { let summary= "WebAssembly table value"; let arguments = (ins SymbolNameAttr: $sym_name, WasmSSA_TableTypeAttr: $type, - OptionalAttr<StrAttr>:$sym_visibility); - let builders = [OpBuilder<(ins - "::llvm::StringRef":$symbol, - "wasmssa::TableType":$type)>]; - let assemblyFormat = "$sym_name custom<WasmVisibility>($sym_visibility) $type attr-dict"; + UnitAttr: $exported); + let extraClassDeclaration = [{ + ::mlir::SymbolTable::Visibility getVisibility() { + return getExported() ? + ::mlir::SymbolTable::Visibility::Public : + ::mlir::SymbolTable::Visibility::Nested; + }; + }]; + let assemblyFormat = "(`exported` $exported^)? $sym_name $type attr-dict"; } def WasmSSA_TableImportOp : WasmSSA_Op<"import_table", [Symbol, ImportOpInterface]> { @@ -515,17 +587,14 @@ def WasmSSA_TableImportOp : WasmSSA_Op<"import_table", [Symbol, ImportOpInterfac let arguments = (ins SymbolNameAttr: $sym_name, StrAttr: $moduleName, StrAttr: $importName, - WasmSSA_TableTypeAttr: $type, - OptionalAttr<StrAttr>:$sym_visibility); + WasmSSA_TableTypeAttr: $type); let extraClassDeclaration = [{ bool isDeclaration() const { return true; } + ::mlir::SymbolTable::Visibility getVisibility() { + return ::mlir::SymbolTable::Visibility::Nested; + }; }]; let assemblyFormat = "$importName `from` $moduleName `as` $sym_name attr-dict"; - let builders = [OpBuilder<(ins - "::llvm::StringRef":$symbol, - "::llvm::StringRef":$moduleName, - "::llvm::StringRef":$importName, - "wasmssa::TableType":$type)>]; } def WasmSSA_ReturnOp : WasmSSA_Op<"return", [Terminator]> { diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index 5695d5d..19a5231 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -712,10 +712,14 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> { return getAttrs().contains(name); } - ArrayAttr getStrides() { + ArrayAttr getStrideAttr() { return getAttrs().getAs<ArrayAttr>("stride"); } + ArrayAttr getBlockAttr() { + return getAttrs().getAs<ArrayAttr>("block"); + } + }]; } diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 73f9061..426377f 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -1298,14 +1298,14 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure, } def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, - AllElementTypesMatch<["mem_desc", "res"]>, - AllRanksMatch<["mem_desc", "res"]>]> { + AllElementTypesMatch<["mem_desc", "res"]>]> { let arguments = (ins XeGPU_MemDesc:$mem_desc, Variadic<Index>: $offsets, DenseI64ArrayAttr: $const_offsets, + OptionalAttr<UnitAttr>:$subgroup_block_io, OptionalAttr<DistributeLayoutAttr>:$layout ); - let results = (outs XeGPU_ValueType:$res); + let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$res); let assemblyFormat = [{ $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets) prop-dict attr-dict `` `:` type(operands) `->` type(results) @@ -1319,6 +1319,9 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, Arguments: - `mem_desc`: the memory descriptor identifying the SLM region. - `offsets`: the coordinates within the matrix to read from. + - `subgroup_block_io`: [optional] An attribute indicating that the operation can be + lowered to a subgroup block load. When this attribute is present, + the offsets are subgroup-uniform across all lanes. - `layout`: [optional] An attribute for guiding distributions among subgroups and/or work-items. It currently can accept either LayoutAttr or SliceAttr. @@ -1336,7 +1339,10 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, } ArrayRef<int64_t> getDataShape() { - return getRes().getType().getShape(); + auto resTy = getRes().getType(); + if (auto vecTy = llvm::dyn_cast<VectorType>(resTy)) + return vecTy.getShape(); + return {}; } }]; @@ -1344,13 +1350,13 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, } def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, - AllElementTypesMatch<["mem_desc", "data"]>, - AllRanksMatch<["mem_desc", "data"]>]> { + AllElementTypesMatch<["mem_desc", "data"]>]> { let arguments = (ins - XeGPU_ValueType:$data, + AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$data, XeGPU_MemDesc:$mem_desc, Variadic<Index>: $offsets, DenseI64ArrayAttr: $const_offsets, + OptionalAttr<UnitAttr>:$subgroup_block_io, OptionalAttr<DistributeLayoutAttr>:$layout ); let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets) @@ -1364,6 +1370,9 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, - `mem_desc`: the memory descriptor specifying the SLM region. - `offsets`: the coordinates within the matrix where the data will be written. - `data`: the values to be stored in the matrix. + - `subgroup_block_io`: [optional] An attribute indicating that the operation can be + lowered to a subgroup block store. When this attribute is present, + the offsets are subgroup-uniform across all lanes. - `layout`: [optional] An attribute for guiding distributions among subgroups and/or work-items. It currently can accept either LayoutAttr or SliceAttr. @@ -1378,7 +1387,10 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, } ArrayRef<int64_t> getDataShape() { - return getData().getType().getShape(); + auto DataTy = getData().getType(); + if (auto vecTy = llvm::dyn_cast<VectorType>(DataTy)) + return vecTy.getShape(); + return {}; } }]; @@ -1386,41 +1398,4 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, let hasVerifier = 1; } -def XeGPU_MemDescSubviewOp: XeGPU_Op<"mem_desc_subview", - [Pure, ViewLikeOpInterface, AllElementTypesMatch<["src", "res"]>]> { - let description = [{ - Creates a subview of a memory descriptor. The resulting memory descriptor can have - a lower rank than the source; in this case, the result dimensions correspond to the - higher-order dimensions of the source memory descriptor. - - Arguments: - - `src` : a memory descriptor. - - `offsets` : the coordinates within the matrix the subview will be created from. - - Results: - - `res` : a memory descriptor with smaller size. - - }]; - let arguments = (ins XeGPU_MemDesc:$src, - Variadic<Index>:$offsets, - DenseI64ArrayAttr:$const_offsets); - let results = (outs XeGPU_MemDesc:$res); - let assemblyFormat = [{$src `` custom<DynamicIndexList>($offsets, $const_offsets) prop-dict - attr-dict `` `:` qualified(type($src)) `->` qualified(type($res))}]; - let builders = [ - OpBuilder<(ins "Type": $res, "Value":$src, "llvm::ArrayRef<OpFoldResult>": $offsets)> - ]; - - let extraClassDeclaration = [{ - mlir::Value getViewSource() { return getSrc(); } - - SmallVector<OpFoldResult> getMixedOffsets() { - return getMixedValues(getConstOffsets(), getOffsets(), getContext()); - } - }]; - - let hasVerifier = 1; -} - - #endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td index 84902b2..b1196fb 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td @@ -237,12 +237,11 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout()); } - ArrayAttr getStrides() { + ArrayAttr getStrideAttr() { auto layout = getMemLayout(); if (layout && layout.hasAttr("stride")) { - return layout.getStrides(); + return layout.getStrideAttr(); } - // derive and return default strides SmallVector<int64_t> defaultStrides; llvm::append_range(defaultStrides, getShape().drop_front()); @@ -250,6 +249,63 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m Builder builder(getContext()); return builder.getI64ArrayAttr(defaultStrides); } + + ArrayAttr getBlockAttr() { + auto layout = getMemLayout(); + if (layout && layout.hasAttr("block")) { + return layout.getBlockAttr(); + } + Builder builder(getContext()); + return builder.getI64ArrayAttr({}); + } + + /// Heuristic to determine if the MemDesc uses column-major layout, + /// based on the rank and the value of the first stride dimension. + bool isColMajor() { + auto dim0 = dyn_cast<IntegerAttr>(getStrideAttr()[0]); + return getRank() == 2 && dim0.getInt() == 1; + } + + // Get the Blocking shape for a MemDescType, Which is represented + // as an attribute in MemDescType. By default it is the shape + // of the mdescTy + SmallVector<int64_t> getBlockShape() { + SmallVector<int64_t> size(getShape()); + ArrayAttr blockAttr = getBlockAttr(); + if (!blockAttr.empty()) { + size.clear(); + for (auto attr : blockAttr.getValue()) { + size.push_back(cast<IntegerAttr>(attr).getInt()); + } + } + return size; + } + + // Get strides as vector of integer. + // If it contains block attribute, the strides are blocked strides. + // + // The blocking is applied to the base matrix shape derived from the + // memory descriptor's stride information. If the matrix described by + // the memory descriptor is not contiguous, it is assumed that the base + // matrix is contiguous and follows the same memory layout. + // + // It first computes the original matrix shape using the stride info, + // then computes the number of blocks in each dimension of original shape, + // then compute the outer block shape and stride, + // then combines the inner and outer block shape and stride + // e.g. for `mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>` + // its memory layout tuple is ([2,32,16,8],[128,256,1,16]) + // for `mem_desc<256x32xf16, @block=[8, 16]>` with default @stride[32, 1] + // its memory layout tuple is ([32,2,8,16],[256,128,16,1]) + SmallVector<int64_t> getStrideShape(); + + /// Generates instructions to compute the linearize offset + // if the memory descriptor is blocked, it returns linearize offset based on the blocked layout + // the strides of memory descriptor is always considered regardless of blocked or not + Value getLinearOffsets(OpBuilder &builder, + Location loc, ArrayRef<OpFoldResult> offsets); + + }]; let hasCustomAssemblyFormat = true; diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h index 252da21..997aef2 100644 --- a/mlir/include/mlir/TableGen/CodeGenHelpers.h +++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h @@ -88,7 +88,7 @@ public: /// /// Constraints that do not meet the restriction that they can only reference /// `$_self` and `$_op` are not uniqued. - void emitOpConstraints(ArrayRef<const llvm::Record *> opDefs); + void emitOpConstraints(); /// Unique all compatible type and attribute constraints from a pattern file /// and emit them at the top of the generated file. diff --git a/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h index 21adde8..cd9ef5b 100644 --- a/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h +++ b/mlir/include/mlir/Target/Wasm/WasmBinaryEncoding.h @@ -19,6 +19,14 @@ namespace mlir { struct WasmBinaryEncoding { /// Byte encodings for Wasm instructions. struct OpCode { + // Control instructions. + static constexpr std::byte block{0x02}; + static constexpr std::byte loop{0x03}; + static constexpr std::byte ifOpCode{0x04}; + static constexpr std::byte elseOpCode{0x05}; + static constexpr std::byte branchIf{0x0D}; + static constexpr std::byte call{0x10}; + // Locals, globals, constants. static constexpr std::byte localGet{0x20}; static constexpr std::byte localSet{0x21}; @@ -29,6 +37,42 @@ struct WasmBinaryEncoding { static constexpr std::byte constFP32{0x43}; static constexpr std::byte constFP64{0x44}; + // Comparisons. + static constexpr std::byte eqzI32{0x45}; + static constexpr std::byte eqI32{0x46}; + static constexpr std::byte neI32{0x47}; + static constexpr std::byte ltSI32{0x48}; + static constexpr std::byte ltUI32{0x49}; + static constexpr std::byte gtSI32{0x4A}; + static constexpr std::byte gtUI32{0x4B}; + static constexpr std::byte leSI32{0x4C}; + static constexpr std::byte leUI32{0x4D}; + static constexpr std::byte geSI32{0x4E}; + static constexpr std::byte geUI32{0x4F}; + static constexpr std::byte eqzI64{0x50}; + static constexpr std::byte eqI64{0x51}; + static constexpr std::byte neI64{0x52}; + static constexpr std::byte ltSI64{0x53}; + static constexpr std::byte ltUI64{0x54}; + static constexpr std::byte gtSI64{0x55}; + static constexpr std::byte gtUI64{0x56}; + static constexpr std::byte leSI64{0x57}; + static constexpr std::byte leUI64{0x58}; + static constexpr std::byte geSI64{0x59}; + static constexpr std::byte geUI64{0x5A}; + static constexpr std::byte eqF32{0x5B}; + static constexpr std::byte neF32{0x5C}; + static constexpr std::byte ltF32{0x5D}; + static constexpr std::byte gtF32{0x5E}; + static constexpr std::byte leF32{0x5F}; + static constexpr std::byte geF32{0x60}; + static constexpr std::byte eqF64{0x61}; + static constexpr std::byte neF64{0x62}; + static constexpr std::byte ltF64{0x63}; + static constexpr std::byte gtF64{0x64}; + static constexpr std::byte leF64{0x65}; + static constexpr std::byte geF64{0x66}; + // Numeric operations. static constexpr std::byte clzI32{0x67}; static constexpr std::byte ctzI32{0x68}; @@ -93,6 +137,33 @@ struct WasmBinaryEncoding { static constexpr std::byte maxF64{0xA5}; static constexpr std::byte copysignF64{0xA6}; static constexpr std::byte wrap{0xA7}; + + // Conversion operations + static constexpr std::byte extendS{0xAC}; + static constexpr std::byte extendU{0xAD}; + static constexpr std::byte convertSI32F32{0xB2}; + static constexpr std::byte convertUI32F32{0xB3}; + static constexpr std::byte convertSI64F32{0xB4}; + static constexpr std::byte convertUI64F32{0xB5}; + + static constexpr std::byte demoteF64ToF32{0xB6}; + + static constexpr std::byte convertSI32F64{0xB7}; + static constexpr std::byte convertUI32F64{0xB8}; + static constexpr std::byte convertSI64F64{0xB9}; + static constexpr std::byte convertUI64F64{0xBA}; + + static constexpr std::byte promoteF32ToF64{0xBB}; + static constexpr std::byte reinterpretF32AsI32{0xBC}; + static constexpr std::byte reinterpretF64AsI64{0xBD}; + static constexpr std::byte reinterpretI32AsF32{0xBE}; + static constexpr std::byte reinterpretI64AsF64{0xBF}; + + static constexpr std::byte extendI328S{0xC0}; + static constexpr std::byte extendI3216S{0xC1}; + static constexpr std::byte extendI648S{0xC2}; + static constexpr std::byte extendI6416S{0xC3}; + static constexpr std::byte extendI6432S{0xC4}; }; /// Byte encodings of types in Wasm binaries diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp index 30ce1fb..6588b53 100644 --- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp +++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp @@ -1244,8 +1244,9 @@ bool FlatLinearValueConstraints::areVarsAlignedWithOther( /// Checks if the SSA values associated with `cst`'s variables in range /// [start, end) are unique. -static bool LLVM_ATTRIBUTE_UNUSED areVarsUnique( - const FlatLinearValueConstraints &cst, unsigned start, unsigned end) { +[[maybe_unused]] static bool +areVarsUnique(const FlatLinearValueConstraints &cst, unsigned start, + unsigned end) { assert(start <= cst.getNumDimAndSymbolVars() && "Start position out of bounds"); @@ -1267,14 +1268,14 @@ static bool LLVM_ATTRIBUTE_UNUSED areVarsUnique( } /// Checks if the SSA values associated with `cst`'s variables are unique. -static bool LLVM_ATTRIBUTE_UNUSED +[[maybe_unused]] static bool areVarsUnique(const FlatLinearValueConstraints &cst) { return areVarsUnique(cst, 0, cst.getNumDimAndSymbolVars()); } /// Checks if the SSA values associated with `cst`'s variables of kind `kind` /// are unique. -static bool LLVM_ATTRIBUTE_UNUSED +[[maybe_unused]] static bool areVarsUnique(const FlatLinearValueConstraints &cst, VarKind kind) { if (kind == VarKind::SetDim) diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp index a1cbe29..547a4c2 100644 --- a/mlir/lib/Analysis/Presburger/Simplex.cpp +++ b/mlir/lib/Analysis/Presburger/Simplex.cpp @@ -34,7 +34,7 @@ using Direction = Simplex::Direction; const int nullIndex = std::numeric_limits<int>::max(); // Return a + scale*b; -LLVM_ATTRIBUTE_UNUSED +[[maybe_unused]] static SmallVector<DynamicAPInt, 8> scaleAndAddForAssert(ArrayRef<DynamicAPInt> a, const DynamicAPInt &scale, ArrayRef<DynamicAPInt> b) { diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 7b17106..06d0256 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2730,6 +2730,17 @@ public: operation->get(), toMlirStringRef(name))); } + static void + forEachAttr(MlirOperation op, + llvm::function_ref<void(MlirStringRef, MlirAttribute)> fn) { + intptr_t n = mlirOperationGetNumAttributes(op); + for (intptr_t i = 0; i < n; ++i) { + MlirNamedAttribute na = mlirOperationGetAttribute(op, i); + MlirStringRef name = mlirIdentifierStr(na.name); + fn(name, na.attribute); + } + } + static void bind(nb::module_ &m) { nb::class_<PyOpAttributeMap>(m, "OpAttributeMap") .def("__contains__", &PyOpAttributeMap::dunderContains) @@ -2737,7 +2748,50 @@ public: .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) .def("__setitem__", &PyOpAttributeMap::dunderSetItem) - .def("__delitem__", &PyOpAttributeMap::dunderDelItem); + .def("__delitem__", &PyOpAttributeMap::dunderDelItem) + .def("__iter__", + [](PyOpAttributeMap &self) { + nb::list keys; + PyOpAttributeMap::forEachAttr( + self.operation->get(), + [&](MlirStringRef name, MlirAttribute) { + keys.append(nb::str(name.data, name.length)); + }); + return nb::iter(keys); + }) + .def("keys", + [](PyOpAttributeMap &self) { + nb::list out; + PyOpAttributeMap::forEachAttr( + self.operation->get(), + [&](MlirStringRef name, MlirAttribute) { + out.append(nb::str(name.data, name.length)); + }); + return out; + }) + .def("values", + [](PyOpAttributeMap &self) { + nb::list out; + PyOpAttributeMap::forEachAttr( + self.operation->get(), + [&](MlirStringRef, MlirAttribute attr) { + out.append(PyAttribute(self.operation->getContext(), attr) + .maybeDownCast()); + }); + return out; + }) + .def("items", [](PyOpAttributeMap &self) { + nb::list out; + PyOpAttributeMap::forEachAttr( + self.operation->get(), + [&](MlirStringRef name, MlirAttribute attr) { + out.append(nb::make_tuple( + nb::str(name.data, name.length), + PyAttribute(self.operation->getContext(), attr) + .maybeDownCast())); + }); + return out; + }); } private: diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 5ddb3fb..0f0ed22 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -205,7 +205,7 @@ public: nb::object res = f(opView, PyPatternRewriter(rewriter)); return logicalResultFromObject(res); }; - MlirRewritePattern pattern = mlirOpRewritePattenCreate( + MlirRewritePattern pattern = mlirOpRewritePatternCreate( rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(), /* nGeneratedNames */ 0, /* generatedNames */ nullptr); diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index 46c329d..41ceb15 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -341,7 +341,7 @@ private: } // namespace mlir -MlirRewritePattern mlirOpRewritePattenCreate( +MlirRewritePattern mlirOpRewritePatternCreate( MlirStringRef rootName, unsigned benefit, MlirContext context, MlirRewritePatternCallbacks callbacks, void *userData, size_t nGeneratedNames, MlirStringRef *generatedNames) { diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp index b215211..c03f3a5 100644 --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -484,5 +484,5 @@ void mlir::populateGpuToROCDLConversionPatterns( GPUSubgroupBroadcastOpToROCDL>(converter); patterns.add<GPUSubgroupSizeOpToROCDL>(converter, chipset); - populateMathToROCDLConversionPatterns(converter, patterns); + populateMathToROCDLConversionPatterns(converter, patterns, chipset); } diff --git a/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt b/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt index 2771955a..8cc3fde 100644 --- a/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt +++ b/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt @@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRMathToROCDL Core LINK_LIBS PUBLIC + MLIRAMDGPUUtils MLIRDialectUtils MLIRFuncDialect MLIRGPUToGPURuntimeTransforms diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp index df219f3..a2dfc12 100644 --- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp +++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp @@ -10,6 +10,8 @@ #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" @@ -19,6 +21,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/DebugLog.h" #include "../GPUCommon/GPUOpsLowering.h" #include "../GPUCommon/OpToFuncCallLowering.h" @@ -42,8 +45,46 @@ static void populateOpPatterns(const LLVMTypeConverter &converter, f32ApproxFunc, f16Func); } +struct ClampFOpConversion final + : public ConvertOpToLLVMPattern<math::ClampFOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Only f16 and f32 types are supported by fmed3 + Type opTy = op.getType(); + Type resultType = getTypeConverter()->convertType(opTy); + + if (auto vectorType = dyn_cast<VectorType>(opTy)) + opTy = vectorType.getElementType(); + + if (!isa<Float16Type, Float32Type>(opTy)) + return rewriter.notifyMatchFailure( + op, "fmed3 only supports f16 and f32 types"); + + // Handle multi-dimensional vectors (converted to LLVM arrays) + if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType)) + return LLVM::detail::handleMultidimensionalVectors( + op.getOperation(), adaptor.getOperands(), *getTypeConverter(), + [&](Type llvm1DVectorTy, ValueRange operands) -> Value { + typename math::ClampFOp::Adaptor adaptor(operands); + return ROCDL::FMed3Op::create(rewriter, op.getLoc(), llvm1DVectorTy, + adaptor.getValue(), adaptor.getMin(), + adaptor.getMax()); + }, + rewriter); + + // Handle 1D vectors and scalars directly + rewriter.replaceOpWithNewOp<ROCDL::FMed3Op>(op, op.getType(), op.getValue(), + op.getMin(), op.getMax()); + return success(); + } +}; + void mlir::populateMathToROCDLConversionPatterns( - const LLVMTypeConverter &converter, RewritePatternSet &patterns) { + const LLVMTypeConverter &converter, RewritePatternSet &patterns, + std::optional<amdgpu::Chipset> chipset) { // Handled by mathToLLVM: math::AbsIOp // Handled by mathToLLVM: math::AbsFOp // Handled by mathToLLVM: math::CopySignOp @@ -118,15 +159,21 @@ void mlir::populateMathToROCDLConversionPatterns( // worth creating a separate pass for it. populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32", "__ocml_fmod_f64", "__ocml_fmod_f16"); + + if (chipset.has_value() && chipset->majorVersion >= 9) { + patterns.add<ClampFOpConversion>(converter); + } else { + LDBG() << "Chipset dependent patterns were not added"; + } } -namespace { -struct ConvertMathToROCDLPass - : public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> { - ConvertMathToROCDLPass() = default; +struct ConvertMathToROCDLPass final + : impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> { + using impl::ConvertMathToROCDLBase< + ConvertMathToROCDLPass>::ConvertMathToROCDLBase; + void runOnOperation() override; }; -} // namespace void ConvertMathToROCDLPass::runOnOperation() { auto m = getOperation(); @@ -135,10 +182,21 @@ void ConvertMathToROCDLPass::runOnOperation() { RewritePatternSet patterns(&getContext()); LowerToLLVMOptions options(ctx, DataLayout(m)); LLVMTypeConverter converter(ctx, options); - populateMathToROCDLConversionPatterns(converter, patterns); + + FailureOr<amdgpu::Chipset> maybeChipset; + if (!chipset.empty()) { + maybeChipset = amdgpu::Chipset::parse(chipset); + if (failed(maybeChipset)) + return signalPassFailure(); + } + populateMathToROCDLConversionPatterns( + converter, patterns, + succeeded(maybeChipset) ? std::optional(*maybeChipset) : std::nullopt); + ConversionTarget target(getContext()); - target.addLegalDialect<BuiltinDialect, func::FuncDialect, - vector::VectorDialect, LLVM::LLVMDialect>(); + target + .addLegalDialect<BuiltinDialect, func::FuncDialect, vector::VectorDialect, + LLVM::LLVMDialect, ROCDL::ROCDLDialect>(); target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp, diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp index f0d8b78..610ce1f 100644 --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -407,11 +407,11 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> { if (auto vectorType = dyn_cast<VectorType>(operandType)) nanAttr = DenseElementsAttr::get(vectorType, nan); - Value NanValue = + Value nanValue = spirv::ConstantOp::create(rewriter, loc, operandType, nanAttr); Value lhs = spirv::SelectOp::create(rewriter, loc, cmpNegativeWithFractionalExp, - NanValue, adaptor.getLhs()); + nanValue, adaptor.getLhs()); Value abs = spirv::GLFAbsOp::create(rewriter, loc, lhs); // TODO: The following just forcefully casts y into an integer value in diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 5355909..41d8d53 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1723,17 +1723,18 @@ struct VectorBroadcastScalarToLowRankLowering return success(); } - // For 1-d vector, we additionally do a `vectorshuffle`. auto v = LLVM::InsertElementOp::create(rewriter, broadcast.getLoc(), vectorType, poison, adaptor.getSource(), zero); + // For 1-d vector, we additionally do a `shufflevector`. int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0); SmallVector<int32_t> zeroValues(width, 0); // Shuffle the value across the desired number of elements. - rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(broadcast, v, poison, - zeroValues); + auto shuffle = rewriter.createOrFold<LLVM::ShuffleVectorOp>( + broadcast.getLoc(), v, poison, zeroValues); + rewriter.replaceOp(broadcast, shuffle); return success(); } }; diff --git a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt index 84b2580..dd9edc4 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt +++ b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt @@ -21,6 +21,7 @@ add_mlir_conversion_library(MLIRXeGPUToXeVM MLIRIndexDialect MLIRSCFDialect MLIRXeGPUDialect + MLIRXeGPUUtils MLIRPass MLIRTransforms MLIRSCFTransforms diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 71687b1..fcbf66d 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -20,7 +20,9 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/STLExtras.h" @@ -62,6 +64,7 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) { case xegpu::MemorySpace::SLM: return static_cast<int>(xevm::AddrSpace::SHARED); } + llvm_unreachable("Unknown XeGPU memory space"); } // Get same bitwidth flat vector type of new element type. @@ -185,6 +188,7 @@ class CreateNdDescToXeVMPattern int64_t rank = mixedSizes.size(); if (rank != 2) return rewriter.notifyMatchFailure(op, "Expected 2D shape."); + auto sourceTy = source.getType(); auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy); // If source is a memref, we need to extract the aligned pointer as index. @@ -363,10 +367,11 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> { // Add a builder that creates // offset * elemByteSize + baseAddr -static Value addOffset(ConversionPatternRewriter &rewriter, Location loc, - Value baseAddr, Value offset, int64_t elemByteSize) { +static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter, + Location loc, Value baseAddr, Value offset, + int64_t elemByteSize) { Value byteSize = arith::ConstantIntOp::create( - rewriter, loc, rewriter.getI64Type(), elemByteSize); + rewriter, loc, baseAddr.getType(), elemByteSize); Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize); Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset); return newAddr; @@ -390,7 +395,8 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> { // Load result or Store valye Type can be vector or scalar. Type valOrResTy; if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) - valOrResTy = op.getResult().getType(); + valOrResTy = + this->getTypeConverter()->convertType(op.getResult().getType()); else valOrResTy = adaptor.getValue().getType(); VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy); @@ -441,7 +447,8 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> { // If offset is provided, we add them to the base pointer. // Offset is in number of elements, we need to multiply by // element byte size. - basePtrI64 = addOffset(rewriter, loc, basePtrI64, offset, elemByteSize); + basePtrI64 = + addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize); } // Convert base pointer (i64) to LLVM pointer type. Value basePtrLLVM = @@ -504,6 +511,147 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> { } }; +// Lower xegpu::CreateMemDescOp to memref::ViewOp. Since SLM access instructions +// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than +// 32 bits will be converted to 32 bits. +class CreateMemDescOpPattern final + : public OpConversionPattern<xegpu::CreateMemDescOp> { +public: + using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto resTy = op.getMemDesc(); + + // Create the result MemRefType with the same shape, element type, and + // memory space + auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy); + + Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0); + auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy, + op.getSource(), zero, ValueRange()); + rewriter.replaceOp(op, viewOp); + return success(); + } +}; + +template <typename OpType, + typename = std::enable_if_t<llvm::is_one_of< + OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>> +class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> { + using OpConversionPattern<OpType>::OpConversionPattern; + LogicalResult + matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + SmallVector<OpFoldResult> offsets = op.getMixedOffsets(); + if (offsets.empty()) + return rewriter.notifyMatchFailure(op, "Expected offset to be provided."); + + auto loc = op.getLoc(); + auto ctxt = rewriter.getContext(); + Value basePtrStruct = adaptor.getMemDesc(); + Value mdescVal = op.getMemDesc(); + // Load result or Store value Type can be vector or scalar. + Value data; + if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) + data = op.getResult(); + else + data = adaptor.getData(); + VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType()); + if (!valOrResVecTy) + valOrResVecTy = VectorType::get(1, data.getType()); + + int64_t elemBitWidth = + valOrResVecTy.getElementType().getIntOrFloatBitWidth(); + // Element type must be multiple of 8 bits. + if (elemBitWidth % 8 != 0) + return rewriter.notifyMatchFailure( + op, "Expected element type bit width to be multiple of 8."); + int64_t elemByteSize = elemBitWidth / 8; + + // Default memory space is SLM. + LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get( + ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM)); + + auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType()); + + Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create( + rewriter, loc, basePtrStruct); + + // Convert base pointer (ptr) to i32 + Value basePtrI32 = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI32Type(), basePtrLLVM); + + Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets); + linearOffset = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI32Type(), linearOffset); + basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset, + elemByteSize); + + // convert base pointer (i32) to LLVM pointer type + basePtrLLVM = + LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32); + + if (op.getSubgroupBlockIoAttr()) { + // if the attribute 'subgroup_block_io' is set to true, it lowers to + // xevm.blockload + + Type intElemTy = rewriter.getIntegerType(elemBitWidth); + VectorType intVecTy = + VectorType::get(valOrResVecTy.getShape(), intElemTy); + + if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) { + Value loadOp = + xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM); + if (intVecTy != valOrResVecTy) { + loadOp = + vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp); + } + rewriter.replaceOp(op, loadOp); + } else { + Value dataToStore = adaptor.getData(); + if (valOrResVecTy != intVecTy) { + dataToStore = + vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore); + } + xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore, + nullptr); + rewriter.eraseOp(op); + } + return success(); + } + + if (valOrResVecTy.getNumElements() >= 1) { + auto chipOpt = xegpu::getChipStr(op); + if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) { + // the lowering for chunk load only works for pvc and bmg + return rewriter.notifyMatchFailure( + op, "The lowering is specific to pvc or bmg."); + } + } + + if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) { + // if the size of valOrResVecTy is 1, it lowers to a scalar load/store + // operation. LLVM load/store does not support vector of size 1, so we + // need to handle this case separately. + auto scalarTy = valOrResVecTy.getElementType(); + LLVM::LoadOp loadOp; + if (valOrResVecTy.getNumElements() == 1) + loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM); + else + loadOp = + LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM); + rewriter.replaceOp(op, loadOp); + } else { + LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM); + rewriter.eraseOp(op); + } + return success(); + } +}; + class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -546,8 +694,8 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> { op, "Expected element type bit width to be multiple of 8."); elemByteSize = elemBitWidth / 8; } - basePtrI64 = - addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize); + basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets, + elemByteSize); } } // Default memory space is global. @@ -784,6 +932,13 @@ struct ConvertXeGPUToXeVMPass auto i32Type = IntegerType::get(&getContext(), 32); return VectorType::get(8, i32Type); }); + // Convert MemDescType into flattened MemRefType for SLM + typeConverter.addConversion([&](xegpu::MemDescType type) -> Type { + Type elemTy = type.getElementType(); + int numElems = type.getNumElements(); + return MemRefType::get(numElems, elemTy, AffineMap(), 3); + }); + typeConverter.addConversion([&](MemRefType type) -> Type { // Convert MemRefType to i64 type. return IntegerType::get(&getContext(), 64); @@ -878,10 +1033,30 @@ struct ConvertXeGPUToXeVMPass } return {}; }; - typeConverter.addSourceMaterialization(memrefMaterializationCast); - typeConverter.addSourceMaterialization(ui64MaterializationCast); - typeConverter.addSourceMaterialization(ui32MaterializationCast); - typeConverter.addSourceMaterialization(vectorMaterializationCast); + + // If result type of original op is single element vector and lowered type + // is scalar. This materialization cast creates a single element vector by + // broadcasting the scalar value. + auto singleElementVectorMaterializationCast = + [](OpBuilder &builder, Type type, ValueRange inputs, + Location loc) -> Value { + if (inputs.size() != 1) + return {}; + auto input = inputs.front(); + if (input.getType().isIntOrIndexOrFloat()) { + // If the input is a scalar, and the target type is a vector of single + // element, create a single element vector by broadcasting. + if (auto vecTy = dyn_cast<VectorType>(type)) { + if (vecTy.getNumElements() == 1) { + return vector::BroadcastOp::create(builder, loc, vecTy, input) + .getResult(); + } + } + } + return {}; + }; + typeConverter.addSourceMaterialization( + singleElementVectorMaterializationCast); typeConverter.addTargetMaterialization(memrefMaterializationCast); typeConverter.addTargetMaterialization(ui32MaterializationCast); typeConverter.addTargetMaterialization(ui64MaterializationCast); @@ -918,6 +1093,9 @@ void mlir::populateXeGPUToXeVMConversionPatterns( LoadStoreToXeVMPattern<xegpu::LoadGatherOp>, LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>( typeConverter, patterns.getContext()); + patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>, + LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>, + CreateMemDescOpPattern>(typeConverter, patterns.getContext()); patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter, patterns.getContext()); } diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index f405d0c..c798adb 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -757,13 +757,13 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> { offset = numElements - 4l; } Type scaleSrcElemType = scaleSrcType.getElementType(); - auto newSrcType = VectorType::get(SmallVector<int64_t>({numElements}), - scaleSrcElemType); + auto newSrcType = + VectorType::get(ArrayRef{numElements}, scaleSrcElemType); Value newScaleSrc = vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc); auto extract = vector::ExtractStridedSliceOp::create( - rewriter, loc, newScaleSrc, ArrayRef<int64_t>{offset}, - ArrayRef<int64_t>{size}, ArrayRef<int64_t>{1}); + rewriter, loc, newScaleSrc, ArrayRef{offset}, ArrayRef{size}, + ArrayRef{int64_t(1)}); rewriter.modifyOpInPlace(op, [&] { op->setOperand(opIdx, extract); setOpsel(opIdx, opsel); diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 7e5ce26..749e2ba 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -125,9 +125,9 @@ static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest, // Use "unused attribute" marker to silence clang-tidy warning stemming from // the inability to see through "llvm::TypeSwitch". template <> -bool LLVM_ATTRIBUTE_UNUSED remainsLegalAfterInline(AffineApplyOp op, - Region *src, Region *dest, - const IRMapping &mapping) { +[[maybe_unused]] bool remainsLegalAfterInline(AffineApplyOp op, Region *src, + Region *dest, + const IRMapping &mapping) { // If it's a valid dimension, we need to check that it remains so. if (isValidDim(op.getResult(), src)) return remainsLegalAfterInline( @@ -1032,8 +1032,8 @@ static void simplifyMinOrMaxExprWithOperands(AffineMap &map, /// Simplify the map while exploiting information on the values in `operands`. // Use "unused attribute" marker to silence warning stemming from the inability // to see through the template expansion. -static void LLVM_ATTRIBUTE_UNUSED -simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) { +[[maybe_unused]] static void simplifyMapWithOperands(AffineMap &map, + ArrayRef<Value> operands) { assert(map.getNumInputs() == operands.size() && "invalid operands for map"); SmallVector<AffineExpr> newResults; newResults.reserve(map.getNumResults()); @@ -1125,6 +1125,141 @@ static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp, return success(*map != initialMap); } +/// Recursively traverse `e`. If `e` or one of its sub-expressions has the form +/// e1 + e2 + ... + eK, where the e_i are a super(multi)set of `exprsToRemove`, +/// place a map between e and `newVal` + sum({e1, e2, .. eK} - exprsToRemove) +/// into `replacementsMap`. If no entries were added to `replacementsMap`, +/// nothing was found. +static void shortenAddChainsContainingAll( + AffineExpr e, const llvm::SmallDenseSet<AffineExpr, 4> &exprsToRemove, + AffineExpr newVal, DenseMap<AffineExpr, AffineExpr> &replacementsMap) { + auto binOp = dyn_cast<AffineBinaryOpExpr>(e); + if (!binOp) + return; + AffineExpr lhs = binOp.getLHS(); + AffineExpr rhs = binOp.getRHS(); + if (binOp.getKind() != AffineExprKind::Add) { + shortenAddChainsContainingAll(lhs, exprsToRemove, newVal, replacementsMap); + shortenAddChainsContainingAll(rhs, exprsToRemove, newVal, replacementsMap); + return; + } + SmallVector<AffineExpr> toPreserve; + llvm::SmallDenseSet<AffineExpr, 4> ourTracker(exprsToRemove); + AffineExpr thisTerm = rhs; + AffineExpr nextTerm = lhs; + + while (thisTerm) { + if (!ourTracker.erase(thisTerm)) { + toPreserve.push_back(thisTerm); + shortenAddChainsContainingAll(thisTerm, exprsToRemove, newVal, + replacementsMap); + } + auto nextBinOp = dyn_cast_if_present<AffineBinaryOpExpr>(nextTerm); + if (!nextBinOp || nextBinOp.getKind() != AffineExprKind::Add) { + thisTerm = nextTerm; + nextTerm = AffineExpr(); + } else { + thisTerm = nextBinOp.getRHS(); + nextTerm = nextBinOp.getLHS(); + } + } + if (!ourTracker.empty()) + return; + // We reverse the terms to be preserved here in order to preserve + // associativity between them. + AffineExpr newExpr = newVal; + for (AffineExpr preserved : llvm::reverse(toPreserve)) + newExpr = newExpr + preserved; + replacementsMap.insert({e, newExpr}); +} + +/// If this map contains of the expression `x_1 + x_1 * C_1 + ... x_n * C_N + +/// ...` (not necessarily in order) where the set of the `x_i` is the set of +/// outputs of an `affine.delinearize_index` whos inverse is that expression, +/// replace that expression with the input of that delinearize_index op. +/// +/// `unitDimInput` is the input that was detected as the potential start to this +/// replacement chain - if it isn't the rightmost result of the delinearization, +/// this method fails. (This is intended to ensure we don't have redundant scans +/// over the same expression). +/// +/// While this currently only handles delinearizations with a constant basis, +/// that isn't a fundamental limitation. +/// +/// This is a utility function for `replaceDimOrSym` below. +static LogicalResult replaceAffineDelinearizeIndexInverseExpression( + AffineDelinearizeIndexOp delinOp, Value resultToReplace, AffineMap *map, + SmallVectorImpl<Value> &dims, SmallVectorImpl<Value> &syms) { + if (!delinOp.getDynamicBasis().empty()) + return failure(); + if (resultToReplace != delinOp.getMultiIndex().back()) + return failure(); + + MLIRContext *ctx = delinOp.getContext(); + SmallVector<AffineExpr> resToExpr(delinOp.getNumResults(), AffineExpr()); + for (auto [pos, dim] : llvm::enumerate(dims)) { + auto asResult = dyn_cast_if_present<OpResult>(dim); + if (!asResult) + continue; + if (asResult.getOwner() == delinOp.getOperation()) + resToExpr[asResult.getResultNumber()] = getAffineDimExpr(pos, ctx); + } + for (auto [pos, sym] : llvm::enumerate(syms)) { + auto asResult = dyn_cast_if_present<OpResult>(sym); + if (!asResult) + continue; + if (asResult.getOwner() == delinOp.getOperation()) + resToExpr[asResult.getResultNumber()] = getAffineSymbolExpr(pos, ctx); + } + if (llvm::is_contained(resToExpr, AffineExpr())) + return failure(); + + bool isDimReplacement = llvm::all_of(resToExpr, llvm::IsaPred<AffineDimExpr>); + int64_t stride = 1; + llvm::SmallDenseSet<AffineExpr, 4> expectedExprs; + // This isn't zip_equal since sometimes the delinearize basis is missing a + // size for the first result. + for (auto [binding, size] : llvm::zip( + llvm::reverse(resToExpr), llvm::reverse(delinOp.getStaticBasis()))) { + expectedExprs.insert(binding * getAffineConstantExpr(stride, ctx)); + stride *= size; + } + if (resToExpr.size() != delinOp.getStaticBasis().size()) + expectedExprs.insert(resToExpr[0] * stride); + + DenseMap<AffineExpr, AffineExpr> replacements; + AffineExpr delinInExpr = isDimReplacement + ? getAffineDimExpr(dims.size(), ctx) + : getAffineSymbolExpr(syms.size(), ctx); + + for (AffineExpr e : map->getResults()) + shortenAddChainsContainingAll(e, expectedExprs, delinInExpr, replacements); + if (replacements.empty()) + return failure(); + + AffineMap origMap = *map; + if (isDimReplacement) + dims.push_back(delinOp.getLinearIndex()); + else + syms.push_back(delinOp.getLinearIndex()); + *map = origMap.replace(replacements, dims.size(), syms.size()); + + // Blank out dead dimensions and symbols + for (AffineExpr e : resToExpr) { + if (auto d = dyn_cast<AffineDimExpr>(e)) { + unsigned pos = d.getPosition(); + if (!map->isFunctionOfDim(pos)) + dims[pos] = nullptr; + } + if (auto s = dyn_cast<AffineSymbolExpr>(e)) { + unsigned pos = s.getPosition(); + if (!map->isFunctionOfSymbol(pos)) + syms[pos] = nullptr; + } + } + return success(); +} + /// Replace all occurrences of AffineExpr at position `pos` in `map` by the /// defining AffineApplyOp expression and operands. /// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced. @@ -1157,6 +1292,11 @@ static LogicalResult replaceDimOrSym(AffineMap *map, syms); } + if (auto delinOp = v.getDefiningOp<affine::AffineDelinearizeIndexOp>()) { + return replaceAffineDelinearizeIndexInverseExpression(delinOp, v, map, dims, + syms); + } + auto affineApply = v.getDefiningOp<AffineApplyOp>(); if (!affineApply) return failure(); diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp index cd216ef..4743941 100644 --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -1357,7 +1357,7 @@ bool mlir::affine::isValidLoopInterchangePermutation( /// Returns true if `loops` is a perfectly nested loop nest, where loops appear /// in it from outermost to innermost. -bool LLVM_ATTRIBUTE_UNUSED +[[maybe_unused]] bool mlir::affine::isPerfectlyNested(ArrayRef<AffineForOp> loops) { assert(!loops.empty() && "no loops provided"); @@ -1920,8 +1920,7 @@ generatePointWiseCopy(Location loc, Value memref, Value fastMemRef, return copyNestRoot; } -static InFlightDiagnostic LLVM_ATTRIBUTE_UNUSED -emitRemarkForBlock(Block &block) { +[[maybe_unused]] static InFlightDiagnostic emitRemarkForBlock(Block &block) { return block.getParentOp()->emitRemark(); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp index 70faa71..bc17990 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp @@ -41,18 +41,37 @@ namespace bufferization { using namespace mlir; -/// Return the unique ReturnOp that terminates `funcOp`. -/// Return nullptr if there is no such unique ReturnOp. -static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) { - func::ReturnOp returnOp; +/// Get all the ReturnOp in the funcOp. +static SmallVector<func::ReturnOp> getReturnOps(func::FuncOp funcOp) { + SmallVector<func::ReturnOp> returnOps; for (Block &b : funcOp.getBody()) { if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) { - if (returnOp) - return nullptr; - returnOp = candidateOp; + returnOps.push_back(candidateOp); } } - return returnOp; + return returnOps; +} + +/// Get the operands at the specified position for all returnOps. +static SmallVector<Value> +getReturnOpsOperandInPos(ArrayRef<func::ReturnOp> returnOps, size_t pos) { + return llvm::map_to_vector(returnOps, [&](func::ReturnOp returnOp) { + return returnOp.getOperand(pos); + }); +} + +/// Check if all given values are the same buffer as the block argument (modulo +/// cast ops). +static bool operandsEqualFuncArgument(ArrayRef<Value> operands, + BlockArgument argument) { + for (Value val : operands) { + while (auto castOp = val.getDefiningOp<memref::CastOp>()) + val = castOp.getSource(); + + if (val != argument) + return false; + } + return true; } LogicalResult @@ -72,40 +91,45 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) { for (auto funcOp : module.getOps<func::FuncOp>()) { if (funcOp.isExternal() || funcOp.isPublic()) continue; - func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); - // TODO: Support functions with multiple blocks. - if (!returnOp) + SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp); + if (returnOps.empty()) continue; // Compute erased results. - SmallVector<Value> newReturnValues; - BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults()); + size_t numReturnOps = returnOps.size(); + size_t numReturnValues = funcOp.getFunctionType().getNumResults(); + SmallVector<SmallVector<Value>> newReturnValues(numReturnOps); + BitVector erasedResultIndices(numReturnValues); DenseMap<int64_t, int64_t> resultToArgs; - for (const auto &it : llvm::enumerate(returnOp.getOperands())) { + for (size_t i = 0; i < numReturnValues; ++i) { bool erased = false; + SmallVector<Value> returnOperands = + getReturnOpsOperandInPos(returnOps, i); for (BlockArgument bbArg : funcOp.getArguments()) { - Value val = it.value(); - while (auto castOp = val.getDefiningOp<memref::CastOp>()) - val = castOp.getSource(); - - if (val == bbArg) { - resultToArgs[it.index()] = bbArg.getArgNumber(); + if (operandsEqualFuncArgument(returnOperands, bbArg)) { + resultToArgs[i] = bbArg.getArgNumber(); erased = true; break; } } if (erased) { - erasedResultIndices.set(it.index()); + erasedResultIndices.set(i); } else { - newReturnValues.push_back(it.value()); + for (auto [newReturnValue, operand] : + llvm::zip(newReturnValues, returnOperands)) { + newReturnValue.push_back(operand); + } } } // Update function. if (failed(funcOp.eraseResults(erasedResultIndices))) return failure(); - returnOp.getOperandsMutable().assign(newReturnValues); + + for (auto [returnOp, newReturnValue] : + llvm::zip(returnOps, newReturnValues)) + returnOp.getOperandsMutable().assign(newReturnValue); // Update function calls. for (func::CallOp callOp : callerMap[funcOp]) { diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 7ca09d9..3eae67f 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -2826,6 +2826,20 @@ LogicalResult ShuffleVectorOp::verify() { return success(); } +// Folding for shufflevector op when v1 is single element 1D vector +// and the mask is a single zero. OpFoldResult will be v1 in this case. +OpFoldResult ShuffleVectorOp::fold(FoldAdaptor adaptor) { + // Check if operand 0 is a single element vector. + auto vecType = llvm::dyn_cast<VectorType>(getV1().getType()); + if (!vecType || vecType.getRank() != 1 || vecType.getNumElements() != 1) + return {}; + // Check if the mask is a single zero. + // Note: The mask is guaranteed to be non-empty. + if (getMask().size() != 1 || getMask()[0] != 0) + return {}; + return getV1(); +} + //===----------------------------------------------------------------------===// // Implementations for LLVM::LLVMFuncOp. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index 01a16ce..ac35eea 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -134,10 +134,10 @@ static void printExtTypeParams(AsmPrinter &p, ArrayRef<Type> typeParams, /// These are unused for now. /// TODO: Move over to these once more types have been migrated to TypeDef. -LLVM_ATTRIBUTE_UNUSED static OptionalParseResult +[[maybe_unused]] static OptionalParseResult generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value); -LLVM_ATTRIBUTE_UNUSED static LogicalResult -generatedTypePrinter(Type def, AsmPrinter &printer); +[[maybe_unused]] static LogicalResult generatedTypePrinter(Type def, + AsmPrinter &printer); #include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc" diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 55fe0d9..2a8c330 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -798,6 +798,26 @@ LogicalResult MmaOp::verify() { " attribute"); } + // Validate layout combinations. According to the operation description, most + // MMA operations require layoutA=row and layoutB=col. Only m8n8k4 with f16 + // can use other layout combinations. + bool isM8N8K4_F16 = + (mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 && + getMultiplicandAPtxType() == MMATypes::f16); + + if (!isM8N8K4_F16) { + // For all other shapes/types, layoutA must be row and layoutB must be col + if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) { + return emitOpError("requires layoutA = #nvvm.mma_layout<row> and " + "layoutB = #nvvm.mma_layout<col> for shape <") + << mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2] + << "> with element types " + << stringifyEnum(*getMultiplicandAPtxType()) << " and " + << stringifyEnum(*getMultiplicandBPtxType()) + << ". Only m8n8k4 with f16 supports other layouts."; + } + } + return success(); } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index d8f983f..6192d79 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3024,10 +3024,10 @@ ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); } if (dynamicPointParseResult.has_value()) { - Type ChunkSizesType; + Type chunkSizesType; if (failed(*dynamicPointParseResult) || parser.parseComma() || - parser.parseType(ChunkSizesType) || - parser.resolveOperand(dynamicChunkSizes, ChunkSizesType, + parser.parseType(chunkSizesType) || + parser.resolveOperand(dynamicChunkSizes, chunkSizesType, result.operands)) { return failure(); } @@ -3399,9 +3399,9 @@ void transform::ContinuousTileSizesOp::getEffects( } static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op, - Type targetType, Type tile_sizes, + Type targetType, Type tileSizes, Type) { - printer.printFunctionalType(TypeRange{targetType}, TypeRange{tile_sizes}); + printer.printFunctionalType(TypeRange{targetType}, TypeRange{tileSizes}); } static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser, diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 507597b..94947b7 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2158,11 +2158,45 @@ public: return success(); } }; + +struct ReinterpretCastOpConstantFolder + : public OpRewritePattern<ReinterpretCastOp> { +public: + using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(ReinterpretCastOp op, + PatternRewriter &rewriter) const override { + unsigned srcStaticCount = llvm::count_if( + llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(), + op.getMixedStrides()), + [](OpFoldResult ofr) { return isa<Attribute>(ofr); }); + + SmallVector<OpFoldResult> offsets = {op.getConstifiedMixedOffset()}; + SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes(); + SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides(); + + // TODO: Using counting comparison instead of direct comparison because + // getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns + // IntegerAttrs, while constifyIndexValues (and therefore + // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs. + if (srcStaticCount == + llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides), + [](OpFoldResult ofr) { return isa<Attribute>(ofr); })) + return failure(); + + auto newReinterpretCast = ReinterpretCastOp::create( + rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides); + + rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast); + return success(); + } +}; } // namespace void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context); + results.add<ReinterpretCastOpExtractStridedMetadataFolder, + ReinterpretCastOpConstantFolder>(context); } FailureOr<std::optional<SmallVector<Value>>> diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp index 49b7162..6f815ae 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp @@ -121,7 +121,7 @@ struct EmulateWideIntPass final [&typeConverter](Operation *op) { return typeConverter.isLegal(op); }); RewritePatternSet patterns(ctx); - // Add common pattenrs to support contants, functions, etc. + // Add common patterns to support contants, functions, etc. arith::populateArithWideIntEmulationPatterns(typeConverter, patterns); memref::populateMemRefWideIntEmulationPatterns(typeConverter, patterns); diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 642ced9..90cbbd8 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -40,6 +40,16 @@ static bool isScalarLikeType(Type type) { return type.isIntOrIndexOrFloat() || isa<ComplexType>(type); } +/// Helper function to attach the `VarName` attribute to an operation +/// if a variable name is provided. +static void attachVarNameAttr(Operation *op, OpBuilder &builder, + StringRef varName) { + if (!varName.empty()) { + auto varNameAttr = acc::VarNameAttr::get(builder.getContext(), varName); + op->setAttr(acc::getVarNameAttrName(), varNameAttr); + } +} + struct MemRefPointerLikeModel : public PointerLikeType::ExternalModel<MemRefPointerLikeModel, MemRefType> { @@ -83,7 +93,9 @@ struct MemRefPointerLikeModel // then we can generate an alloca operation. if (memrefTy.hasStaticShape()) { needsFree = false; // alloca doesn't need deallocation - return memref::AllocaOp::create(builder, loc, memrefTy).getResult(); + auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy); + attachVarNameAttr(allocaOp, builder, varName); + return allocaOp.getResult(); } // For dynamic memrefs, extract sizes from the original variable if @@ -103,8 +115,10 @@ struct MemRefPointerLikeModel // Static dimensions are handled automatically by AllocOp } needsFree = true; // alloc needs deallocation - return memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes) - .getResult(); + auto allocOp = + memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes); + attachVarNameAttr(allocOp, builder, varName); + return allocOp.getResult(); } // TODO: Unranked not yet supported. diff --git a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp index 135c033..645cbff 100644 --- a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp +++ b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp @@ -158,7 +158,7 @@ static FailureOr<GridOp> getGridAndVerify(Operation *op, } template <typename It> -bool isUnique(It begin, It end) { +static bool isUnique(It begin, It end) { if (begin == end) { return true; } diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp index a1711a6..069191c 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp @@ -143,8 +143,8 @@ void VarInfo::setNum(Var::Num n) { /// Helper function for `assertUsageConsistency` to better handle SMLoc /// mismatches. -LLVM_ATTRIBUTE_UNUSED static llvm::SMLoc -minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) { +[[maybe_unused]] static llvm::SMLoc minSMLoc(AsmParser &parser, llvm::SMLoc sm1, + llvm::SMLoc sm2) { const auto loc1 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm1)); assert(loc1 && "Could not get `FileLineColLoc` for first `SMLoc`"); const auto loc2 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm2)); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp index f539502..684c088 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp @@ -43,8 +43,8 @@ using namespace mlir::sparse_tensor; //===----------------------------------------------------------------------===// #ifndef NDEBUG -LLVM_ATTRIBUTE_UNUSED static void dumpIndexMemRef(OpBuilder &builder, - Location loc, Value memref) { +[[maybe_unused]] static void dumpIndexMemRef(OpBuilder &builder, Location loc, + Value memref) { memref = memref::CastOp::create( builder, loc, UnrankedMemRefType::get(builder.getIndexType(), 0), memref); createFuncCall(builder, loc, "printMemrefInd", TypeRange{}, diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp index 5aad671..1cba1bb 100644 --- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tosa/IR/TargetEnv.h" +#include "llvm/Support/FormatVariadic.h" namespace mlir { namespace tosa { @@ -27,7 +28,7 @@ TargetEnvAttr lookupTargetEnv(Operation *op) { } TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) { - return TargetEnvAttr::get(context, Level::eightK, + return TargetEnvAttr::get(context, SpecificationVersion::V_1_0, Level::eightK, {Profile::pro_int, Profile::pro_fp}, {}); } @@ -38,5 +39,9 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) { return getDefaultTargetEnv(op->getContext()); } +llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) { + return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor()); +} + } // namespace tosa } // namespace mlir diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp index bcb880a..a0661e4 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp @@ -61,8 +61,8 @@ public: ModuleOp mod = getOperation(); MLIRContext *ctx = &getContext(); - const auto targetEnvAttr = - TargetEnvAttr::get(ctx, level, selectedProfiles, selectedExtensions); + const auto targetEnvAttr = TargetEnvAttr::get( + ctx, specificationVersion, level, selectedProfiles, selectedExtensions); mod->setAttr(TargetEnvAttr::name, targetEnvAttr); } diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp index 20f9333..f072e3e 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp @@ -335,16 +335,15 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) { //===----------------------------------------------------------------------===// template <typename T> -FailureOr<SmallVector<T>> -TosaProfileCompliance::getOperatorDefinition(Operation *op, - CheckCondition &condition) { +FailureOr<OpComplianceInfo<T>> +TosaProfileCompliance::getOperatorDefinition(Operation *op) { const std::string opName = op->getName().getStringRef().str(); const auto complianceMap = getProfileComplianceMap<T>(); const auto it = complianceMap.find(opName); if (it == complianceMap.end()) return {}; - return findMatchedProfile<T>(op, it->second, condition); + return findMatchedEntry<T>(op, it->second); } template <typename T> @@ -356,22 +355,21 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension( if (specRequiredModeSet.size() == 0) return success(); - CheckCondition condition = CheckCondition::invalid; - const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition); - if (failed(maybeOpRequiredMode)) { + const auto maybeOpDefinition = getOperatorDefinition<T>(op); + if (failed(maybeOpDefinition)) { // Operators such as control-flow and shape ops do not have an operand type // restriction. When the profile compliance information of operation is not // found, confirm if the target have enabled the profile required from the // specification. - int mode_count = 0; + int modeCount = 0; for (const auto &cands : specRequiredModeSet) { if (targetEnv.allowsAnyOf(cands)) return success(); - mode_count += cands.size(); + modeCount += cands.size(); } op->emitOpError() << "illegal: requires" - << (mode_count > 1 ? " any of " : " ") << "[" + << (modeCount > 1 ? " any of " : " ") << "[" << llvm::join(stringifyProfile<T>(specRequiredModeSet), ", ") << "] but not enabled in target\n"; @@ -381,7 +379,10 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension( // Find the required profiles or extensions according to the operand type // combination. - const auto opRequiredMode = maybeOpRequiredMode.value(); + const auto opDefinition = maybeOpDefinition.value(); + const SmallVector<T> opRequiredMode = opDefinition.mode; + const CheckCondition condition = opDefinition.condition; + if (opRequiredMode.size() == 0) { // No matched restriction found. return success(); @@ -437,6 +438,21 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension( } } + // Ensure the matched op compliance version does not exceed the target + // specification version. + const VersionedTypeInfo versionedTypeInfo = + opDefinition.operandTypeInfoSet[0]; + const TosaSpecificationVersion complianceVersion{versionedTypeInfo.second}; + const TosaSpecificationVersion targetVersion{targetEnv.getSpecVersion()}; + if (!targetVersion.isBackwardsCompatibleWith(complianceVersion)) { + op->emitOpError() << "illegal: the target specification version (" + << stringifyVersion(targetVersion) + << ") is not backwards compatible with the op compliance " + "specification version (" + << stringifyVersion(complianceVersion) << ")\n"; + return failure(); + } + return success(); } @@ -461,14 +477,14 @@ TosaProfileCompliance::checkExtension(Operation *op, } LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) { - CheckCondition condition = CheckCondition::invalid; - const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition); - const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition); + const auto maybeProfDef = getOperatorDefinition<Profile>(op); + const auto maybeExtDef = getOperatorDefinition<Extension>(op); if (failed(maybeProfDef) && failed(maybeExtDef)) return success(); - const bool hasEntry = (succeeded(maybeProfDef) && !maybeProfDef->empty()) || - (succeeded(maybeExtDef) && !maybeExtDef->empty()); + const bool hasEntry = + (succeeded(maybeProfDef) && !maybeProfDef->mode.empty()) || + (succeeded(maybeExtDef) && !maybeExtDef->mode.empty()); if (!hasEntry) { std::string message; llvm::raw_string_ostream os(message); @@ -488,7 +504,9 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) { SmallVector<TypeInfo> bestTypeInfo; const auto searchBestMatch = [&](auto map) { for (const auto &complianceInfos : map[opName]) { - for (const auto &typeInfos : complianceInfos.operandTypeInfoSet) { + for (const auto &versionedTypeInfos : + complianceInfos.operandTypeInfoSet) { + const SmallVector<TypeInfo> typeInfos = versionedTypeInfos.first; const int matches = llvm::count_if( llvm::zip_equal(current, typeInfos), [&](const auto zipType) { return isSameTypeInfo(std::get<0>(zipType), @@ -520,9 +538,8 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) { // Find the profiles or extensions requirement according to the signature of // type of the operand list. template <typename T> -SmallVector<T> TosaProfileCompliance::findMatchedProfile( - Operation *op, SmallVector<OpComplianceInfo<T>> compInfo, - CheckCondition &condition) { +OpComplianceInfo<T> TosaProfileCompliance::findMatchedEntry( + Operation *op, SmallVector<OpComplianceInfo<T>> compInfo) { assert(compInfo.size() != 0 && "profile-based compliance information is empty"); @@ -533,27 +550,30 @@ SmallVector<T> TosaProfileCompliance::findMatchedProfile( return {}; for (size_t i = 0; i < compInfo.size(); i++) { - SmallVector<SmallVector<TypeInfo>> sets = compInfo[i].operandTypeInfoSet; - for (SmallVector<TypeInfo> expected : sets) { + SmallVector<VersionedTypeInfo> sets = compInfo[i].operandTypeInfoSet; + for (const auto &set : sets) { + SmallVector<TypeInfo> expected = set.first; assert(present.size() == expected.size() && "the entries for profile-based compliance do not match between " "the generated metadata and the type definition retrieved from " " the operation"); - bool is_found = true; + bool isFound = true; // Compare the type signature between the given operation and the // compliance metadata. for (size_t j = 0; j < expected.size(); j++) { if (!isSameTypeInfo(present[j], expected[j])) { // Verify the next mode set from the list. - is_found = false; + isFound = false; break; } } - if (is_found == true) { - condition = compInfo[i].condition; - return compInfo[i].mode; + if (isFound == true) { + SmallVector<VersionedTypeInfo> typeInfoSet{set}; + OpComplianceInfo<T> info{compInfo[i].mode, typeInfoSet, + compInfo[i].condition}; + return info; } } } diff --git a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp index 9a24c2b..a2cff6a 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp @@ -21,10 +21,10 @@ using namespace mlir; // These are automatically generated by ODS but are not used as the Transform // dialect uses a different dispatch mechanism to support dialect extensions. -LLVM_ATTRIBUTE_UNUSED static OptionalParseResult +[[maybe_unused]] static OptionalParseResult generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value); -LLVM_ATTRIBUTE_UNUSED static LogicalResult -generatedTypePrinter(Type def, AsmPrinter &printer); +[[maybe_unused]] static LogicalResult generatedTypePrinter(Type def, + AsmPrinter &printer); #define GET_TYPEDEF_CLASSES #include "mlir/Dialect/Transform/IR/TransformTypes.cpp.inc" diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index 025ee9a..c809c502 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -91,7 +91,7 @@ mlir::vector::isTranspose2DSlice(vector::TransposeOp op) { // Check whether the two source vector dimensions that are greater than one // must be transposed with each other so that we can apply one of the 2-D - // transpose pattens. Otherwise, these patterns are not applicable. + // transpose patterns. Otherwise, these patterns are not applicable. if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1], op.getPermutation())) return failure(); diff --git a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp index 89b62a2..a514ea9 100644 --- a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp +++ b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp @@ -12,6 +12,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Region.h" #include "mlir/IR/SymbolTable.h" @@ -39,28 +40,6 @@ void printElseRegion(OpAsmPrinter &opPrinter, Operation *op, opPrinter.printKeywordOrString("else "); opPrinter.printRegion(elseRegion); } - -ParseResult parseWasmVisibility(OpAsmParser &opParser, StringAttr &visibility) { - std::string keyword; - auto initLocation = opParser.getCurrentLocation(); - std::ignore = opParser.parseOptionalKeywordOrString(&keyword); - if (keyword == "nested" or keyword == "") { - visibility = StringAttr::get(opParser.getContext(), "nested"); - return ParseResult::success(); - } - - if (keyword == "public" || keyword == "private") { - visibility = StringAttr::get(opParser.getContext(), keyword); - return ParseResult::success(); - } - opParser.emitError(initLocation, "expecting symbol visibility"); - return ParseResult::failure(); -} - -void printWasmVisibility(OpAsmPrinter &opPrinter, Operation *op, - Attribute visibility) { - opPrinter.printKeywordOrString(cast<StringAttr>(visibility).strref()); -} } // namespace #define GET_OP_CLASSES @@ -167,10 +146,23 @@ Block *FuncOp::addEntryBlock() { void FuncOp::build(OpBuilder &odsBuilder, OperationState &odsState, StringRef symbol, FunctionType funcType) { - FuncOp::build(odsBuilder, odsState, symbol, funcType, {}, {}, "nested"); + FuncOp::build(odsBuilder, odsState, symbol, funcType, {}, {}); } ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { + auto *ctx = parser.getContext(); + std::string visibilityString; + auto loc = parser.getNameLoc(); + ParseResult res = parser.parseOptionalKeywordOrString(&visibilityString); + bool exported{false}; + if (res.succeeded()) { + if (visibilityString != "exported") + return parser.emitError( + loc, "expecting either `exported` or symbol name. got ") + << visibilityString; + exported = true; + } + auto buildFuncType = [&parser](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, function_interface_impl::VariadicFlag, @@ -191,11 +183,13 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { return builder.getFunctionType(argTypesWithoutLocal, results); }; - - return function_interface_impl::parseFunctionOp( + auto funcParseRes = function_interface_impl::parseFunctionOp( parser, result, /*allowVariadic=*/false, getFunctionTypeAttrName(result.name), buildFuncType, getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); + if (exported) + result.addAttribute(getExportedAttrName(result.name), UnitAttr::get(ctx)); + return funcParseRes; } LogicalResult FuncOp::verifyBody() { @@ -224,9 +218,18 @@ LogicalResult FuncOp::verifyBody() { } void FuncOp::print(OpAsmPrinter &p) { + /// If exported, print it before and mask it before printing + /// using generic interface. + auto exported = getExported(); + if (exported) { + p << " exported"; + removeExportedAttr(); + } function_interface_impl::printFunctionOp( p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName()); + if (exported) + setExported(true); } //===----------------------------------------------------------------------===// @@ -237,38 +240,37 @@ void FuncImportOp::build(OpBuilder &odsBuilder, OperationState &odsState, StringRef symbol, StringRef moduleName, StringRef importName, FunctionType type) { FuncImportOp::build(odsBuilder, odsState, symbol, moduleName, importName, - type, {}, {}, odsBuilder.getStringAttr("nested")); + type, {}, {}); } //===----------------------------------------------------------------------===// // GlobalOp //===----------------------------------------------------------------------===// - -void GlobalOp::build(OpBuilder &odsBuilder, OperationState &odsState, - StringRef symbol, Type type, bool isMutable) { - GlobalOp::build(odsBuilder, odsState, symbol, type, isMutable, - odsBuilder.getStringAttr("nested")); -} - // Custom formats ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) { StringAttr symbolName; Type globalType; auto *ctx = parser.getContext(); - ParseResult res = parser.parseSymbolName( - symbolName, SymbolTable::getSymbolAttrName(), result.attributes); + std::string visibilityString; + auto loc = parser.getNameLoc(); + ParseResult res = parser.parseOptionalKeywordOrString(&visibilityString); + if (res.succeeded()) { + if (visibilityString != "exported") + return parser.emitError( + loc, "expecting either `exported` or symbol name. got ") + << visibilityString; + result.addAttribute(getExportedAttrName(result.name), UnitAttr::get(ctx)); + } + res = parser.parseSymbolName(symbolName, SymbolTable::getSymbolAttrName(), + result.attributes); res = parser.parseType(globalType); result.addAttribute(getTypeAttrName(result.name), TypeAttr::get(globalType)); std::string mutableString; res = parser.parseOptionalKeywordOrString(&mutableString); if (res.succeeded() && mutableString == "mutable") result.addAttribute("isMutable", UnitAttr::get(ctx)); - std::string visibilityString; - res = parser.parseOptionalKeywordOrString(&visibilityString); - if (res.succeeded()) - result.addAttribute("sym_visibility", - StringAttr::get(ctx, visibilityString)); + res = parser.parseColon(); Region *globalInitRegion = result.addRegion(); res = parser.parseRegion(*globalInitRegion); @@ -276,11 +278,11 @@ ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) { } void GlobalOp::print(OpAsmPrinter &printer) { + if (getExported()) + printer << " exported"; printer << " @" << getSymName().str() << " " << getType(); if (getIsMutable()) printer << " mutable"; - if (auto vis = getSymVisibility()) - printer << " " << *vis; printer << " :"; Region &body = getRegion(); if (!body.empty()) { @@ -319,13 +321,6 @@ GlobalGetOp::verifySymbolUses(SymbolTableCollection &symbolTable) { // GlobalImportOp //===----------------------------------------------------------------------===// -void GlobalImportOp::build(OpBuilder &odsBuilder, OperationState &odsState, - StringRef symbol, StringRef moduleName, - StringRef importName, Type type, bool isMutable) { - GlobalImportOp::build(odsBuilder, odsState, symbol, moduleName, importName, - type, isMutable, odsBuilder.getStringAttr("nested")); -} - ParseResult GlobalImportOp::parse(OpAsmParser &parser, OperationState &result) { auto *ctx = parser.getContext(); ParseResult res = parseImportOp(parser, result); @@ -335,12 +330,8 @@ ParseResult GlobalImportOp::parse(OpAsmParser &parser, OperationState &result) { res = parser.parseOptionalKeywordOrString(&mutableOrSymVisString); if (res.succeeded() && mutableOrSymVisString == "mutable") { result.addAttribute("isMutable", UnitAttr::get(ctx)); - res = parser.parseOptionalKeywordOrString(&mutableOrSymVisString); } - if (res.succeeded()) - result.addAttribute("sym_visibility", - StringAttr::get(ctx, mutableOrSymVisString)); res = parser.parseColon(); Type importedType; @@ -356,8 +347,6 @@ void GlobalImportOp::print(OpAsmPrinter &printer) { << "\" as @" << getSymName(); if (getIsMutable()) printer << " mutable"; - if (auto vis = getSymVisibility()) - printer << " " << *vis; printer << " : " << getType(); } @@ -431,27 +420,6 @@ LogicalResult LocalTeeOp::verify() { Block *LoopOp::getLabelTarget() { return &getBody().front(); } //===----------------------------------------------------------------------===// -// MemOp -//===----------------------------------------------------------------------===// - -void MemOp::build(OpBuilder &odsBuilder, OperationState &odsState, - StringRef symbol, LimitType limit) { - MemOp::build(odsBuilder, odsState, symbol, limit, - odsBuilder.getStringAttr("nested")); -} - -//===----------------------------------------------------------------------===// -// MemImportOp -//===----------------------------------------------------------------------===// - -void MemImportOp::build(OpBuilder &odsBuilder, OperationState &odsState, - StringRef symbol, StringRef moduleName, - StringRef importName, LimitType limits) { - MemImportOp::build(odsBuilder, odsState, symbol, moduleName, importName, - limits, odsBuilder.getStringAttr("nested")); -} - -//===----------------------------------------------------------------------===// // ReinterpretOp //===----------------------------------------------------------------------===// @@ -471,24 +439,3 @@ LogicalResult ReinterpretOp::verify() { //===----------------------------------------------------------------------===// void ReturnOp::build(OpBuilder &odsBuilder, OperationState &odsState) {} - -//===----------------------------------------------------------------------===// -// TableOp -//===----------------------------------------------------------------------===// - -void TableOp::build(OpBuilder &odsBuilder, OperationState &odsState, - StringRef symbol, TableType type) { - TableOp::build(odsBuilder, odsState, symbol, type, - odsBuilder.getStringAttr("nested")); -} - -//===----------------------------------------------------------------------===// -// TableImportOp -//===----------------------------------------------------------------------===// - -void TableImportOp::build(OpBuilder &odsBuilder, OperationState &odsState, - StringRef symbol, StringRef moduleName, - StringRef importName, TableType type) { - TableImportOp::build(odsBuilder, odsState, symbol, moduleName, importName, - type, odsBuilder.getStringAttr("nested")); -} diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 9beb22d..1599ae9 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -727,6 +727,152 @@ void MemLayoutAttr::print(AsmPrinter &printer) const { } printer << ">"; } +// a helper utility to perform binary operation on OpFoldResult. +// If both a and b are attributes, it will simply return the result. +// Otherwise, the corresponding arith op will be generated, and an +// contant op will be created if one of them is an attribute. +template <typename ArithOp> +OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc, + OpBuilder &builder) { + auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a); + auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b); + return builder.create<ArithOp>(loc, aVal, bVal).getResult(); +} + +// a helper utility to perform division operation on OpFoldResult and int64_t. +#define div(a, b) \ + genBinOp<arith::DivSIOp>(a, builder.getIndexAttr(b), loc, builder) + +// a helper utility to perform reminder operation on OpFoldResult and int64_t. +#define rem(a, b) \ + genBinOp<arith::RemSIOp>(a, builder.getIndexAttr(b), loc, builder) + +// a helper utility to perform multiply operation on OpFoldResult and int64_t. +#define mul(a, b) \ + genBinOp<arith::MulIOp>(a, builder.getIndexAttr(b), loc, builder) + +// a helper utility to perform addition operation on two OpFoldResult. +#define add(a, b) genBinOp<arith::AddIOp>(a, b, loc, builder) + +// block the given offsets according to the block shape +// say the original offset is [y, x], and the block shape is [By, Bx], +// then the blocked offset is [y/By, x/Bx, y%By, x%Bx] +SmallVector<OpFoldResult> getBlockedOffsets(OpBuilder &builder, Location loc, + ArrayRef<OpFoldResult> offsets, + ArrayRef<int64_t> blockShape) { + + assert(offsets.size() == blockShape.size() && + "offsets and blockShape must have the same size"); + SmallVector<OpFoldResult> blockedOffsets; + SmallVector<OpFoldResult> divs, rems; + + for (auto [offset, block] : llvm::zip(offsets, blockShape)) { + divs.push_back(div(offset, block)); + rems.push_back(rem(offset, block)); + } + blockedOffsets.append(divs.begin(), divs.end()); + blockedOffsets.append(rems.begin(), rems.end()); + + return blockedOffsets; +} + +// Get strides as vector of integer for MemDesc. +SmallVector<int64_t> MemDescType::getStrideShape() { + + SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end()); + + ArrayAttr strideAttr = getStrideAttr(); + SmallVector<int64_t> strides; + for (Attribute attr : strideAttr.getValue()) { + strides.push_back(cast<IntegerAttr>(attr).getInt()); + } + + SmallVector<int64_t> innerBlkShape = getBlockShape(); + + // get perm from FCD to LCD + // perm[i] = the dim with i-th smallest stride + SmallVector<int, 4> perm = + llvm::to_vector<4>(llvm::seq<int>(0, strides.size())); + llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; }); + + assert(strides[perm[0]] == 1 && "inner most dim must have stride 1"); + + SmallVector<int64_t> innerBlkStride(innerBlkShape.size()); + innerBlkStride[perm[0]] = 1; + for (size_t i = 1; i < perm.size(); ++i) + innerBlkStride[perm[i]] = + innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]]; + + // compute the original matrix shape using the stride info + // and compute the number of blocks in each dimension + // The shape of highest dim can't be derived from stride info, + // but doesn't impact the stride computation for blocked layout. + SmallVector<int64_t> matrixShapeOrig(matrixShape.size()); + SmallVector<int64_t> BlkShapeOrig(matrixShape.size()); + for (size_t i = 0; i < perm.size() - 1; ++i) { + matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]]; + BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]]; + } + + int64_t innerBlkSize = 1; + for (auto s : innerBlkShape) + innerBlkSize *= s; + + SmallVector<int64_t> outerBlkStride(matrixShape.size()); + outerBlkStride[perm[0]] = innerBlkSize; + for (size_t i = 0; i < perm.size() - 1; ++i) { + outerBlkStride[perm[i + 1]] = + outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]]; + } + + // combine the inner and outer strides + SmallVector<int64_t> blockedStrides; + blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end()); + blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end()); + + return blockedStrides; +} + +// Calculate the linear offset using the blocked offsets and stride +Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc, + ArrayRef<OpFoldResult> offsets) { + + SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end()); + SmallVector<int64_t> blockShape = getBlockShape(); + SmallVector<int64_t> strides = getStrideShape(); + SmallVector<OpFoldResult> blockedOffsets; + + // blockshape equal to matrixshape means no blocking + if (llvm::equal(blockShape, matrixShape)) { + // remove the outer dims from strides + strides.erase(strides.begin(), strides.begin() + matrixShape.size()); + } else { + assert(offsets.size() == blockShape.size() && + "offsets and blockShape must have the same size"); + // say the original offset is [y, x], and the block shape is [By, Bx], + // then the blocked offset is [y/By, x/Bx, y%By, x%Bx] + + SmallVector<OpFoldResult> divs, rems; + + for (auto [offset, block] : llvm::zip(offsets, blockShape)) { + divs.push_back(div(offset, block)); + rems.push_back(rem(offset, block)); + } + blockedOffsets.append(divs.begin(), divs.end()); + blockedOffsets.append(rems.begin(), rems.end()); + offsets = blockedOffsets; + } + + // Start with initial value as matrix descriptor's base offset. + Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0); + for (size_t i = 0; i < offsets.size(); ++i) { + OpFoldResult mulResult = mul(offsets[i], strides[i]); + Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult); + linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset); + } + + return linearOffset; +} } // namespace xegpu } // namespace mlir diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 81b5788..abd12e2 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -20,8 +20,8 @@ #define DEBUG_TYPE "xegpu" -namespace mlir { -namespace xegpu { +using namespace mlir; +using namespace mlir::xegpu; static bool isSharedMemory(const MemRefType &memrefTy) { Attribute attr = memrefTy.getMemorySpace(); @@ -173,6 +173,49 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy, return success(); } +LogicalResult +IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy, + UnitAttr subgroup_block_io, + function_ref<InFlightDiagnostic()> emitError) { + + if (!dataTy) { + if (subgroup_block_io) + return emitError() << "subgroup_block_io " + "are only allowed when result is a 1D VectorType."; + else + return success(); + } + + if (mdescTy.getRank() != 2) + return emitError() << "mem_desc must be 2D."; + + ArrayRef<int64_t> dataShape = dataTy.getShape(); + ArrayRef<int64_t> mdescShape = mdescTy.getShape(); + + if (dataShape.size() == 2) { + if (subgroup_block_io) + return emitError() << "subgroup_block_io " + "are only allowed when result is a 1D VectorType."; + if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape), + [](auto p) { return std::get<0>(p) > std::get<1>(p); })) + return emitError() << "data shape must not exceed mem_desc shape."; + } else { + SmallVector<int64_t> blockShape = mdescTy.getBlockShape(); + // if the subgroup_block_io attribute is set, mdescTy must have block + // attribute + if (subgroup_block_io && !blockShape.size()) + return emitError() << "mem_desc must have block attribute when " + "subgroup_block_io is set."; + // if the subgroup_block_io attribute is set, the memdesc should be row + // major + if (subgroup_block_io && mdescTy.isColMajor()) + return emitError() << "mem_desc should be row major when " + "subgroup_block_io is set."; + } + + return success(); +} + //===----------------------------------------------------------------------===// // XeGPU_CreateNdDescOp //===----------------------------------------------------------------------===// @@ -1049,23 +1092,20 @@ void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res, llvm::SmallVector<int64_t> staticOffsets; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); + // Call the generated builder with all parameters (including optional ones as + // nullptr/empty) build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr, - layout); + /*subgroup_block_io=*/nullptr, layout); } LogicalResult LoadMatrixOp::verify() { - VectorType resTy = getRes().getType(); - MemDescType mdescTy = getMemDesc().getType(); - if (mdescTy.getRank() != 2) - return emitOpError("mem_desc must be 2D."); + auto resTy = dyn_cast<VectorType>(getRes().getType()); + UnitAttr subgroup_block_io = getSubgroupBlockIoAttr(); + MemDescType mdescTy = getMemDesc().getType(); - ArrayRef<int64_t> valueShape = resTy.getShape(); - ArrayRef<int64_t> mdescShape = mdescTy.getShape(); - if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape), - [](auto p) { return std::get<0>(p) > std::get<1>(p); })) - return emitOpError("result shape must not exceed mem_desc shape."); - return success(); + return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io, + [&]() { return emitError(); }); } //===----------------------------------------------------------------------===// @@ -1080,62 +1120,18 @@ void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data, dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr, - layout); + /*subgroup_block_io=*/nullptr, layout); } LogicalResult StoreMatrixOp::verify() { - VectorType dataTy = getData().getType(); - MemDescType mdescTy = getMemDesc().getType(); - - if (mdescTy.getRank() != 2) - return emitOpError("mem_desc must be 2D."); - - ArrayRef<int64_t> dataShape = dataTy.getShape(); - ArrayRef<int64_t> mdescShape = mdescTy.getShape(); - if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape), - [](auto p) { return std::get<0>(p) > std::get<1>(p); })) - return emitOpError("data shape must not exceed mem_desc shape."); - - return success(); -} - -//===----------------------------------------------------------------------===// -// XeGPU_MemDescSubviewOp -//===----------------------------------------------------------------------===// - -void MemDescSubviewOp::build(OpBuilder &builder, OperationState &state, - Type resTy, Value src, - llvm::ArrayRef<OpFoldResult> offsets) { - llvm::SmallVector<Value> dynamicOffsets; - llvm::SmallVector<int64_t> staticOffsets; - dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); - auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); - build(builder, state, resTy, src, dynamicOffsets, staticOffsetsAttr); -} - -LogicalResult MemDescSubviewOp::verify() { - MemDescType srcTy = getSrc().getType(); - MemDescType resTy = getRes().getType(); - ArrayRef<int64_t> srcShape = srcTy.getShape(); - ArrayRef<int64_t> resShape = resTy.getShape(); - - if (srcTy.getRank() < resTy.getRank()) - return emitOpError("result rank must not exceed source rank."); - if (llvm::any_of( - llvm::zip_equal(resShape, srcShape.take_back(resShape.size())), - [](auto p) { return std::get<0>(p) > std::get<1>(p); })) - return emitOpError("result shape must not exceed source shape."); - - if (srcTy.getStrides() != resTy.getStrides()) - return emitOpError("result must inherit the source strides."); - - return success(); + auto dataTy = dyn_cast<VectorType>(getData().getType()); + UnitAttr subgroup_block_io = getSubgroupBlockIoAttr(); + MemDescType mdescTy = getMemDesc().getType(); + return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io, + [&]() { return emitError(); }); } -} // namespace xegpu -} // namespace mlir - namespace mlir { #include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc> } // namespace mlir diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index a178d0f..aafa1b7 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -941,7 +941,9 @@ struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> { LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - VectorType valueTy = op.getType(); + VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType()); + assert(valueTy && "the value type must be vector type!"); + std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op); if (!targetShape || targetShape->size() != (size_t)valueTy.getRank()) return failure(); @@ -984,7 +986,8 @@ struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> { return failure(); Location loc = op.getLoc(); - VectorType valueTy = op.getData().getType(); + VectorType valueTy = llvm::dyn_cast<VectorType>(op.getData().getType()); + assert(valueTy && "the value type must be vector type!"); ArrayRef<int64_t> shape = valueTy.getShape(); auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr()); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index c28d2fc..31a967d 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -991,7 +991,8 @@ struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> { return failure(); ArrayRef<int64_t> wgShape = op.getDataShape(); - VectorType valueTy = op.getRes().getType(); + VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType()); + assert(valueTy && "the value type must be vector type!"); Type elemTy = valueTy.getElementType(); xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 89b81cf..5f63fe6 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -1204,7 +1204,7 @@ AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount, /// present in result expressions is less than `dimCount` and the highest index /// of symbolic identifier present in result expressions is less than /// `symbolCount`. -LLVM_ATTRIBUTE_UNUSED static bool +[[maybe_unused]] static bool willBeValidAffineMap(unsigned dimCount, unsigned symbolCount, ArrayRef<AffineExpr> results) { int64_t maxDimPosition = -1; diff --git a/mlir/lib/TableGen/CodeGenHelpers.cpp b/mlir/lib/TableGen/CodeGenHelpers.cpp index cb90ef8..d52d5e7 100644 --- a/mlir/lib/TableGen/CodeGenHelpers.cpp +++ b/mlir/lib/TableGen/CodeGenHelpers.cpp @@ -49,9 +49,7 @@ StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter( raw_ostream &os, const RecordKeeper &records, StringRef tag) : os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {} -void StaticVerifierFunctionEmitter::emitOpConstraints( - ArrayRef<const Record *> opDefs) { - NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace()); +void StaticVerifierFunctionEmitter::emitOpConstraints() { emitTypeConstraints(); emitAttrConstraints(); emitPropConstraints(); diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 9603813..857e31b 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -2604,6 +2604,7 @@ static constexpr std::array kExplicitLLVMFuncOpAttributes{ StringLiteral("denormal-fp-math-f32"), StringLiteral("fp-contract"), StringLiteral("frame-pointer"), + StringLiteral("inlinehint"), StringLiteral("instrument-function-entry"), StringLiteral("instrument-function-exit"), StringLiteral("memory"), @@ -2643,6 +2644,8 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func, funcOp.setNoInline(true); if (func->hasFnAttribute(llvm::Attribute::AlwaysInline)) funcOp.setAlwaysInline(true); + if (func->hasFnAttribute(llvm::Attribute::InlineHint)) + funcOp.setInlineHint(true); if (func->hasFnAttribute(llvm::Attribute::OptimizeNone)) funcOp.setOptimizeNone(true); if (func->hasFnAttribute(llvm::Attribute::Convergent)) diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 845a14f..147613f 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -1652,6 +1652,8 @@ static void convertFunctionAttributes(LLVMFuncOp func, llvmFunc->addFnAttr(llvm::Attribute::NoInline); if (func.getAlwaysInlineAttr()) llvmFunc->addFnAttr(llvm::Attribute::AlwaysInline); + if (func.getInlineHintAttr()) + llvmFunc->addFnAttr(llvm::Attribute::InlineHint); if (func.getOptimizeNoneAttr()) llvmFunc->addFnAttr(llvm::Attribute::OptimizeNone); if (func.getConvergentAttr()) diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 0c3e87a..d9ad8fb 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -2619,6 +2619,11 @@ LogicalResult ControlFlowStructurizer::structurize() { // region. We cannot handle such cases given that once a value is sinked into // the SelectionOp/LoopOp's region, there is no escape for it. for (auto *block : constructBlocks) { + if (!block->use_empty()) + return emitError(block->getParent()->getLoc(), + "failed control flow structurization: " + "block has uses outside of the " + "enclosing selection/loop construct"); for (Operation &op : *block) if (!op.use_empty()) return op.emitOpError("failed control flow structurization: value has " diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp index 51c6077..366ba8f 100644 --- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp +++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp @@ -14,6 +14,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Location.h" #include "mlir/Support/LLVM.h" @@ -138,6 +139,10 @@ using ImportDesc = using parsed_inst_t = FailureOr<SmallVector<Value>>; +struct EmptyBlockMarker {}; +using BlockTypeParseResult = + std::variant<EmptyBlockMarker, TypeIdxRecord, Type>; + struct WasmModuleSymbolTables { SmallVector<FunctionSymbolRefContainer> funcSymbols; SmallVector<GlobalSymbolRefContainer> globalSymbols; @@ -175,6 +180,9 @@ class ParserHead; /// Wrapper around SmallVector to only allow access as push and pop on the /// stack. Makes sure that there are no "free accesses" on the stack to preserve /// its state. +/// This class also keep tracks of the Wasm labels defined by different ops, +/// which can be targeted by control flow ops. This can be modeled as part of +/// the Value Stack as Wasm control flow ops can only target enclosing labels. class ValueStack { private: struct LabelLevel { @@ -206,6 +214,16 @@ public: /// if an error occurs. LogicalResult pushResults(ValueRange results, Location *opLoc); + void addLabelLevel(LabelLevelOpInterface levelOp) { + labelLevel.push_back({values.size(), levelOp}); + LDBG() << "Adding a new frame context to ValueStack"; + } + + void dropLabelLevel() { + assert(!labelLevel.empty() && "Trying to drop a frame from empty context"); + auto newSize = labelLevel.pop_back_val().stackIdx; + values.truncate(newSize); + } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) /// A simple dump function for debugging. /// Writes output to llvm::dbgs(). @@ -214,6 +232,7 @@ public: private: SmallVector<Value> values; + SmallVector<LabelLevel> labelLevel; }; using local_val_t = TypedValue<wasmssa::LocalRefType>; @@ -248,6 +267,19 @@ private: buildNumericOp(OpBuilder &builder, std::enable_if_t<std::is_arithmetic_v<valueType>> * = nullptr); + /// Construct a conversion operation of type \p opType that takes a value from + /// type \p inputType on the stack and will produce a value of type + /// \p outputType. + /// + /// \p opType - The WASM dialect operation to build. + /// \p inputType - The operand type for the built instruction. + /// \p outputType - The result type for the built instruction. + /// + /// \returns The parsed instruction result, or failure. + template <typename opType, typename inputType, typename outputType, + typename... extraArgsT> + inline parsed_inst_t buildConvertOp(OpBuilder &builder, extraArgsT...); + /// This function generates a dispatch tree to associate an opcode with a /// parser. Parsers are registered by specialising the /// `parseSpecificInstruction` function for the op code to handle. @@ -280,11 +312,105 @@ private: } } + /// + /// RAII guard class for creating a nesting level + /// + struct NestingContextGuard { + NestingContextGuard(ExpressionParser &parser, LabelLevelOpInterface levelOp) + : parser{parser} { + parser.addNestingContextLevel(levelOp); + } + NestingContextGuard(NestingContextGuard &&other) : parser{other.parser} { + other.shouldDropOnDestruct = false; + } + NestingContextGuard(NestingContextGuard const &) = delete; + ~NestingContextGuard() { + if (shouldDropOnDestruct) + parser.dropNestingContextLevel(); + } + ExpressionParser &parser; + bool shouldDropOnDestruct = true; + }; + + void addNestingContextLevel(LabelLevelOpInterface levelOp) { + valueStack.addLabelLevel(levelOp); + } + + void dropNestingContextLevel() { + // Should always succeed as we are droping the frame that was previously + // created. + valueStack.dropLabelLevel(); + } + + llvm::FailureOr<FunctionType> getFuncTypeFor(OpBuilder &builder, + EmptyBlockMarker) { + return builder.getFunctionType({}, {}); + } + + llvm::FailureOr<FunctionType> getFuncTypeFor(OpBuilder &builder, + TypeIdxRecord type) { + if (type.id >= symbols.moduleFuncTypes.size()) + return emitError(*currentOpLoc, + "type index references nonexistent type (") + << type.id << "). Only " << symbols.moduleFuncTypes.size() + << " types are registered"; + return symbols.moduleFuncTypes[type.id]; + } + + llvm::FailureOr<FunctionType> getFuncTypeFor(OpBuilder &builder, + Type valType) { + return builder.getFunctionType({}, {valType}); + } + + llvm::FailureOr<FunctionType> + getFuncTypeFor(OpBuilder &builder, BlockTypeParseResult parseResult) { + return std::visit( + [this, &builder](auto value) { return getFuncTypeFor(builder, value); }, + parseResult); + } + + llvm::FailureOr<FunctionType> + getFuncTypeFor(OpBuilder &builder, + llvm::FailureOr<BlockTypeParseResult> parseResult) { + if (llvm::failed(parseResult)) + return failure(); + return getFuncTypeFor(builder, *parseResult); + } + + llvm::FailureOr<FunctionType> parseBlockFuncType(OpBuilder &builder); + struct ParseResultWithInfo { SmallVector<Value> opResults; std::byte endingByte; }; + template <typename FilterT = ByteSequence<WasmBinaryEncoding::endByte>> + /// @param blockToFill: the block which content will be populated + /// @param resType: the type that this block is supposed to return + llvm::FailureOr<std::byte> + parseBlockContent(OpBuilder &builder, Block *blockToFill, TypeRange resTypes, + Location opLoc, LabelLevelOpInterface levelOp, + FilterT parseEndBytes = {}) { + OpBuilder::InsertionGuard guard{builder}; + builder.setInsertionPointToStart(blockToFill); + LDBG() << "parsing a block of type " + << builder.getFunctionType(blockToFill->getArgumentTypes(), + resTypes); + auto nC = addNesting(levelOp); + + if (failed(pushResults(blockToFill->getArguments()))) + return failure(); + auto bodyParsingRes = parse(builder, parseEndBytes); + if (failed(bodyParsingRes)) + return failure(); + auto returnOperands = popOperands(resTypes); + if (failed(returnOperands)) + return failure(); + builder.create<BlockReturnOp>(opLoc, *returnOperands); + LDBG() << "end of parsing of a block"; + return bodyParsingRes->endingByte; + } + public: template <std::byte ParseEndByte = WasmBinaryEncoding::endByte> parsed_inst_t parse(OpBuilder &builder, UniqueByte<ParseEndByte> = {}); @@ -294,7 +420,11 @@ public: parse(OpBuilder &builder, ByteSequence<ExpressionParseEnd...> parsingEndFilters); - FailureOr<SmallVector<Value>> popOperands(TypeRange operandTypes) { + NestingContextGuard addNesting(LabelLevelOpInterface levelOp) { + return NestingContextGuard{*this, levelOp}; + } + + FailureOr<llvm::SmallVector<Value>> popOperands(TypeRange operandTypes) { return valueStack.popOperands(operandTypes, ¤tOpLoc.value()); } @@ -308,6 +438,12 @@ public: template <typename OpToCreate> parsed_inst_t parseSetOrTee(OpBuilder &); + /// Blocks and Loops have a similar format and differ only in how their exit + /// is handled which doesn´t matter at parsing time. Factorizes in one + /// function. + template <typename OpToCreate> + parsed_inst_t parseBlockLikeOp(OpBuilder &); + private: std::optional<Location> currentOpLoc; ParserHead &parser; @@ -586,6 +722,29 @@ public: return success(); } + llvm::FailureOr<BlockTypeParseResult> parseBlockType(MLIRContext *ctx) { + auto loc = getLocation(); + auto blockIndicator = peek(); + if (failed(blockIndicator)) + return failure(); + if (*blockIndicator == WasmBinaryEncoding::Type::emptyBlockType) { + offset += 1; + return {EmptyBlockMarker{}}; + } + if (isValueOneOf(*blockIndicator, valueTypesEncodings)) + return parseValueType(ctx); + /// Block type idx is a 32 bit positive integer encoded as a 33 bit signed + /// value + auto typeIdx = parseI64(); + if (failed(typeIdx)) + return failure(); + if (*typeIdx < 0 || *typeIdx > std::numeric_limits<uint32_t>::max()) + return emitError(loc, "type ID should be representable with an unsigned " + "32 bits integer. Got ") + << *typeIdx; + return {TypeIdxRecord{static_cast<uint32_t>(*typeIdx)}}; + } + bool end() const { return curHead().empty(); } ParserHead copy() const { return *this; } @@ -701,17 +860,41 @@ inline parsed_inst_t ExpressionParser::parseSpecificInstruction(OpBuilder &) { void ValueStack::dump() const { llvm::dbgs() << "================= Wasm ValueStack =======================\n"; llvm::dbgs() << "size: " << size() << "\n"; + llvm::dbgs() << "nbFrames: " << labelLevel.size() << '\n'; llvm::dbgs() << "<Top>" << "\n"; // Stack is pushed to via push_back. Therefore the top of the stack is the // end of the vector. Iterate in reverse so that the first thing we print // is the top of the stack. + auto indexGetter = [this]() { + size_t idx = labelLevel.size(); + return [this, idx]() mutable -> std::optional<std::pair<size_t, size_t>> { + llvm::dbgs() << "IDX: " << idx << '\n'; + if (idx == 0) + return std::nullopt; + auto frameId = idx - 1; + auto frameLimit = labelLevel[frameId].stackIdx; + idx -= 1; + return {{frameId, frameLimit}}; + }; + }; + auto getNextFrameIndex = indexGetter(); + auto nextFrameIdx = getNextFrameIndex(); size_t stackSize = size(); - for (size_t idx = 0; idx < stackSize; idx++) { + for (size_t idx = 0; idx < stackSize; ++idx) { size_t actualIdx = stackSize - 1 - idx; + while (nextFrameIdx && (nextFrameIdx->second > actualIdx)) { + llvm::dbgs() << " --------------- Frame (" << nextFrameIdx->first + << ")\n"; + nextFrameIdx = getNextFrameIndex(); + } llvm::dbgs() << " "; values[actualIdx].dump(); } + while (nextFrameIdx) { + llvm::dbgs() << " --------------- Frame (" << nextFrameIdx->first << ")\n"; + nextFrameIdx = getNextFrameIndex(); + } llvm::dbgs() << "<Bottom>" << "\n"; llvm::dbgs() << "=========================================================\n"; @@ -726,7 +909,7 @@ parsed_inst_t ValueStack::popOperands(TypeRange operandTypes, Location *opLoc) { return emitError(*opLoc, "stack doesn't contain enough values. trying to get ") << operandTypes.size() << " operands on a stack containing only " - << values.size() << " values."; + << values.size() << " values"; size_t stackIdxOffset = values.size() - operandTypes.size(); SmallVector<Value> res{}; res.reserve(operandTypes.size()); @@ -735,8 +918,7 @@ parsed_inst_t ValueStack::popOperands(TypeRange operandTypes, Location *opLoc) { Type stackType = operand.getType(); if (stackType != operandTypes[i]) return emitError(*opLoc, "invalid operand type on stack. expecting ") - << operandTypes[i] << ", value on stack is of type " << stackType - << "."; + << operandTypes[i] << ", value on stack is of type " << stackType; LDBG() << " POP: " << operand; res.push_back(operand); } @@ -792,6 +974,151 @@ ExpressionParser::parse(OpBuilder &builder, } } +llvm::FailureOr<FunctionType> +ExpressionParser::parseBlockFuncType(OpBuilder &builder) { + return getFuncTypeFor(builder, parser.parseBlockType(builder.getContext())); +} + +template <typename OpToCreate> +parsed_inst_t ExpressionParser::parseBlockLikeOp(OpBuilder &builder) { + auto opLoc = currentOpLoc; + auto funcType = parseBlockFuncType(builder); + if (failed(funcType)) + return failure(); + + auto inputTypes = funcType->getInputs(); + auto inputOps = popOperands(inputTypes); + if (failed(inputOps)) + return failure(); + + Block *curBlock = builder.getBlock(); + Region *curRegion = curBlock->getParent(); + auto resTypes = funcType->getResults(); + llvm::SmallVector<Location> locations{}; + locations.resize(resTypes.size(), *currentOpLoc); + auto *successor = + builder.createBlock(curRegion, curRegion->end(), resTypes, locations); + builder.setInsertionPointToEnd(curBlock); + auto blockOp = + builder.create<OpToCreate>(*currentOpLoc, *inputOps, successor); + auto *blockBody = blockOp.createBlock(); + if (failed(parseBlockContent(builder, blockBody, resTypes, *opLoc, blockOp))) + return failure(); + builder.setInsertionPointToStart(successor); + return {ValueRange{successor->getArguments()}}; +} + +template <> +inline parsed_inst_t +ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::block>( + OpBuilder &builder) { + return parseBlockLikeOp<BlockOp>(builder); +} + +template <> +inline parsed_inst_t +ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::loop>( + OpBuilder &builder) { + return parseBlockLikeOp<LoopOp>(builder); +} + +template <> +inline parsed_inst_t ExpressionParser::parseSpecificInstruction< + WasmBinaryEncoding::OpCode::ifOpCode>(OpBuilder &builder) { + auto opLoc = currentOpLoc; + auto funcType = parseBlockFuncType(builder); + if (failed(funcType)) + return failure(); + + LDBG() << "Parsing an if instruction of type " << *funcType; + auto inputTypes = funcType->getInputs(); + auto conditionValue = popOperands(builder.getI32Type()); + if (failed(conditionValue)) + return failure(); + auto inputOps = popOperands(inputTypes); + if (failed(inputOps)) + return failure(); + + Block *curBlock = builder.getBlock(); + Region *curRegion = curBlock->getParent(); + auto resTypes = funcType->getResults(); + llvm::SmallVector<Location> locations{}; + locations.resize(resTypes.size(), *currentOpLoc); + auto *successor = + builder.createBlock(curRegion, curRegion->end(), resTypes, locations); + builder.setInsertionPointToEnd(curBlock); + auto ifOp = builder.create<IfOp>(*currentOpLoc, conditionValue->front(), + *inputOps, successor); + auto *ifEntryBlock = ifOp.createIfBlock(); + constexpr auto ifElseFilter = + ByteSequence<WasmBinaryEncoding::endByte, + WasmBinaryEncoding::OpCode::elseOpCode>{}; + auto parseIfRes = parseBlockContent(builder, ifEntryBlock, resTypes, *opLoc, + ifOp, ifElseFilter); + if (failed(parseIfRes)) + return failure(); + if (*parseIfRes == WasmBinaryEncoding::OpCode::elseOpCode) { + LDBG() << " else block is present."; + Block *elseEntryBlock = ifOp.createElseBlock(); + auto parseElseRes = + parseBlockContent(builder, elseEntryBlock, resTypes, *opLoc, ifOp); + if (failed(parseElseRes)) + return failure(); + } + builder.setInsertionPointToStart(successor); + return {ValueRange{successor->getArguments()}}; +} + +template <> +inline parsed_inst_t ExpressionParser::parseSpecificInstruction< + WasmBinaryEncoding::OpCode::branchIf>(OpBuilder &builder) { + auto level = parser.parseLiteral<uint32_t>(); + if (failed(level)) + return failure(); + Block *curBlock = builder.getBlock(); + Region *curRegion = curBlock->getParent(); + auto sip = builder.saveInsertionPoint(); + Block *elseBlock = builder.createBlock(curRegion, curRegion->end()); + auto condition = popOperands(builder.getI32Type()); + if (failed(condition)) + return failure(); + builder.restoreInsertionPoint(sip); + auto targetOp = + LabelBranchingOpInterface::getTargetOpFromBlock(curBlock, *level); + if (failed(targetOp)) + return failure(); + auto inputTypes = targetOp->getLabelTarget()->getArgumentTypes(); + auto branchArgs = popOperands(inputTypes); + if (failed(branchArgs)) + return failure(); + builder.create<BranchIfOp>(*currentOpLoc, condition->front(), + builder.getUI32IntegerAttr(*level), *branchArgs, + elseBlock); + builder.setInsertionPointToStart(elseBlock); + return {*branchArgs}; +} + +template <> +inline parsed_inst_t +ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::call>( + OpBuilder &builder) { + auto loc = *currentOpLoc; + auto funcIdx = parser.parseLiteral<uint32_t>(); + if (failed(funcIdx)) + return failure(); + if (*funcIdx >= symbols.funcSymbols.size()) + return emitError(loc, "Invalid function index: ") << *funcIdx; + auto callee = symbols.funcSymbols[*funcIdx]; + llvm::ArrayRef<Type> inTypes = callee.functionType.getInputs(); + llvm::ArrayRef<Type> resTypes = callee.functionType.getResults(); + parsed_inst_t inOperands = popOperands(inTypes); + if (failed(inOperands)) + return failure(); + auto callOp = + builder.create<FuncCallOp>(loc, resTypes, callee.symbol, *inOperands); + return {callOp.getResults()}; +} + template <> inline parsed_inst_t ExpressionParser::parseSpecificInstruction< WasmBinaryEncoding::OpCode::localGet>(OpBuilder &builder) { @@ -834,7 +1161,7 @@ parsed_inst_t ExpressionParser::parseSetOrTee(OpBuilder &builder) { if (valueStack.empty()) return emitError( *currentOpLoc, - "invalid stack access, trying to access a value on an empty stack."); + "invalid stack access, trying to access a value on an empty stack"); parsed_inst_t poppedOp = popOperands(locals[*id].getType().getElementType()); if (failed(poppedOp)) @@ -1000,11 +1327,23 @@ inline parsed_inst_t ExpressionParser::buildNumericOp( BUILD_NUMERIC_BINOP_FP(CopySignOp, copysign) BUILD_NUMERIC_BINOP_FP(DivOp, div) +BUILD_NUMERIC_BINOP_FP(GeOp, ge) +BUILD_NUMERIC_BINOP_FP(GtOp, gt) +BUILD_NUMERIC_BINOP_FP(LeOp, le) +BUILD_NUMERIC_BINOP_FP(LtOp, lt) BUILD_NUMERIC_BINOP_FP(MaxOp, max) BUILD_NUMERIC_BINOP_FP(MinOp, min) BUILD_NUMERIC_BINOP_INT(AndOp, and) BUILD_NUMERIC_BINOP_INT(DivSIOp, divS) BUILD_NUMERIC_BINOP_INT(DivUIOp, divU) +BUILD_NUMERIC_BINOP_INT(GeSIOp, geS) +BUILD_NUMERIC_BINOP_INT(GeUIOp, geU) +BUILD_NUMERIC_BINOP_INT(GtSIOp, gtS) +BUILD_NUMERIC_BINOP_INT(GtUIOp, gtU) +BUILD_NUMERIC_BINOP_INT(LeSIOp, leS) +BUILD_NUMERIC_BINOP_INT(LeUIOp, leU) +BUILD_NUMERIC_BINOP_INT(LtSIOp, ltS) +BUILD_NUMERIC_BINOP_INT(LtUIOp, ltU) BUILD_NUMERIC_BINOP_INT(OrOp, or) BUILD_NUMERIC_BINOP_INT(RemSIOp, remS) BUILD_NUMERIC_BINOP_INT(RemUIOp, remU) @@ -1015,7 +1354,9 @@ BUILD_NUMERIC_BINOP_INT(ShRSOp, shrS) BUILD_NUMERIC_BINOP_INT(ShRUOp, shrU) BUILD_NUMERIC_BINOP_INT(XOrOp, xor) BUILD_NUMERIC_BINOP_INTFP(AddOp, add) +BUILD_NUMERIC_BINOP_INTFP(EqOp, eq) BUILD_NUMERIC_BINOP_INTFP(MulOp, mul) +BUILD_NUMERIC_BINOP_INTFP(NeOp, ne) BUILD_NUMERIC_BINOP_INTFP(SubOp, sub) BUILD_NUMERIC_UNARY_OP_FP(AbsOp, abs) BUILD_NUMERIC_UNARY_OP_FP(CeilOp, ceil) @@ -1025,6 +1366,7 @@ BUILD_NUMERIC_UNARY_OP_FP(SqrtOp, sqrt) BUILD_NUMERIC_UNARY_OP_FP(TruncOp, trunc) BUILD_NUMERIC_UNARY_OP_INT(ClzOp, clz) BUILD_NUMERIC_UNARY_OP_INT(CtzOp, ctz) +BUILD_NUMERIC_UNARY_OP_INT(EqzOp, eqz) BUILD_NUMERIC_UNARY_OP_INT(PopCntOp, popcnt) // Don't need these anymore so let's undef them. @@ -1036,6 +1378,105 @@ BUILD_NUMERIC_UNARY_OP_INT(PopCntOp, popcnt) #undef BUILD_NUMERIC_OP #undef BUILD_NUMERIC_CAST_OP +template <typename opType, typename inputType, typename outputType, + typename... extraArgsT> +inline parsed_inst_t ExpressionParser::buildConvertOp(OpBuilder &builder, + extraArgsT... extraArgs) { + static_assert(std::is_arithmetic_v<inputType>, + "InputType should be an arithmetic type"); + static_assert(std::is_arithmetic_v<outputType>, + "OutputType should be an arithmetic type"); + auto intype = buildLiteralType<inputType>(builder); + auto outType = buildLiteralType<outputType>(builder); + auto operand = popOperands(intype); + if (failed(operand)) + return failure(); + auto op = builder.create<opType>(*currentOpLoc, outType, operand->front(), + extraArgs...); + LDBG() << "Built operation: " << op; + return {{op.getResult()}}; +} + +template <> +inline parsed_inst_t ExpressionParser::parseSpecificInstruction< + WasmBinaryEncoding::OpCode::demoteF64ToF32>(OpBuilder &builder) { + return buildConvertOp<DemoteOp, double, float>(builder); +} + +template <> +inline parsed_inst_t +ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::wrap>( + OpBuilder &builder) { + return buildConvertOp<WrapOp, int64_t, int32_t>(builder); +} + +#define BUILD_CONVERSION_OP(IN_T, OUT_T, SOURCE_OP, TARGET_OP) \ + template <> \ + inline parsed_inst_t ExpressionParser::parseSpecificInstruction< \ + WasmBinaryEncoding::OpCode::SOURCE_OP>(OpBuilder & builder) { \ + return buildConvertOp<TARGET_OP, IN_T, OUT_T>(builder); \ + } + +#define BUILD_CONVERT_OP_FOR(DEST_T, WIDTH) \ + BUILD_CONVERSION_OP(uint32_t, DEST_T, convertUI32F##WIDTH, ConvertUOp) \ + BUILD_CONVERSION_OP(int32_t, DEST_T, convertSI32F##WIDTH, ConvertSOp) \ + BUILD_CONVERSION_OP(uint64_t, DEST_T, convertUI64F##WIDTH, ConvertUOp) \ + BUILD_CONVERSION_OP(int64_t, DEST_T, convertSI64F##WIDTH, ConvertSOp) + +BUILD_CONVERT_OP_FOR(float, 32) +BUILD_CONVERT_OP_FOR(double, 64) + +#undef BUILD_CONVERT_OP_FOR + +BUILD_CONVERSION_OP(int32_t, int64_t, extendS, ExtendSI32Op) +BUILD_CONVERSION_OP(int32_t, int64_t, extendU, ExtendUI32Op) + +#undef BUILD_CONVERSION_OP + +#define BUILD_SLICE_EXTEND_PARSER(IT_WIDTH, EXTRACT_WIDTH) \ + template <> \ + parsed_inst_t ExpressionParser::parseSpecificInstruction< \ + WasmBinaryEncoding::OpCode::extendI##IT_WIDTH##EXTRACT_WIDTH##S>( \ + OpBuilder & builder) { \ + using inout_t = int##IT_WIDTH##_t; \ + auto attr = builder.getUI32IntegerAttr(EXTRACT_WIDTH); \ + return buildConvertOp<ExtendLowBitsSOp, inout_t, inout_t>(builder, attr); \ + } + +BUILD_SLICE_EXTEND_PARSER(32, 8) +BUILD_SLICE_EXTEND_PARSER(32, 16) +BUILD_SLICE_EXTEND_PARSER(64, 8) +BUILD_SLICE_EXTEND_PARSER(64, 16) +BUILD_SLICE_EXTEND_PARSER(64, 32) + +#undef BUILD_SLICE_EXTEND_PARSER + +template <> +inline parsed_inst_t ExpressionParser::parseSpecificInstruction< + WasmBinaryEncoding::OpCode::promoteF32ToF64>(OpBuilder &builder) { + return buildConvertOp<PromoteOp, float, double>(builder); +} + +#define BUILD_REINTERPRET_PARSER(WIDTH, FP_TYPE) \ + template <> \ + inline parsed_inst_t ExpressionParser::parseSpecificInstruction< \ + WasmBinaryEncoding::OpCode::reinterpretF##WIDTH##AsI##WIDTH>(OpBuilder & \ + builder) { \ + return buildConvertOp<ReinterpretOp, FP_TYPE, int##WIDTH##_t>(builder); \ + } \ + \ + template <> \ + inline parsed_inst_t ExpressionParser::parseSpecificInstruction< \ + WasmBinaryEncoding::OpCode::reinterpretI##WIDTH##AsF##WIDTH>(OpBuilder & \ + builder) { \ + return buildConvertOp<ReinterpretOp, int##WIDTH##_t, FP_TYPE>(builder); \ + } + +BUILD_REINTERPRET_PARSER(32, float) +BUILD_REINTERPRET_PARSER(64, double) + +#undef BUILD_REINTERPRET_PARSER + class WasmBinaryParser { private: struct SectionRegistry { @@ -1153,7 +1594,7 @@ private: if (tid.id >= symbols.moduleFuncTypes.size()) return emitError(loc, "invalid type id: ") << tid.id << ". Only " << symbols.moduleFuncTypes.size() - << " type registration."; + << " type registrations"; FunctionType type = symbols.moduleFuncTypes[tid.id]; std::string symbol = symbols.getNewFuncSymbolName(); auto funcOp = FuncImportOp::create(builder, loc, symbol, moduleName, @@ -1221,7 +1662,7 @@ public: FileLineColLoc magicLoc = parser.getLocation(); FailureOr<StringRef> magic = parser.consumeNBytes(wasmHeader.size()); if (failed(magic) || magic->compare(wasmHeader)) { - emitError(magicLoc, "source file does not contain valid Wasm header."); + emitError(magicLoc, "source file does not contain valid Wasm header"); return; } auto const expectedVersionString = StringRef{"\1\0\0\0", 4}; @@ -1391,7 +1832,7 @@ WasmBinaryParser::parseSectionItem<WasmSectionType::EXPORT>(ParserHead &ph, return failure(); Operation *op = SymbolTable::lookupSymbolIn(mOp, *currentSymbol); - SymbolTable::setSymbolVisibility(op, SymbolTable::Visibility::Public); + op->setAttr("exported", UnitAttr::get(op->getContext())); StringAttr symName = SymbolTable::getSymbolName(op); return SymbolTable{mOp}.rename(symName, *exportName); } diff --git a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp index 9670285..3fda5a7 100644 --- a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp +++ b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp @@ -93,7 +93,7 @@ void CodeGen::generate(const ast::Module &astModule, ModuleOp module) { // Emit function to add the generated matchers to the pattern list. os << "template <typename... ConfigsT>\n" - "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns(" + "[[maybe_unused]] static void populateGeneratedPDLLPatterns(" "::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {\n"; for (const auto &name : patternNames) os << " patterns.add<" << name diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 9f5246d..ffa96ad 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -137,6 +137,16 @@ declare_mlir_dialect_python_bindings( declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/OpenACCOps.td + SOURCES + dialects/openacc.py + DIALECT_NAME acc + DEPENDS acc_common_td + ) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/GPUOps.td SOURCES_GLOB dialects/gpu/*.py DIALECT_NAME gpu diff --git a/mlir/python/mlir/dialects/OpenACCOps.td b/mlir/python/mlir/dialects/OpenACCOps.td new file mode 100644 index 0000000..69a3002 --- /dev/null +++ b/mlir/python/mlir/dialects/OpenACCOps.td @@ -0,0 +1,14 @@ +//===-- OpenACCOps.td - Entry point for OpenACCOps bind ------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_OPENACC_OPS +#define PYTHON_BINDINGS_OPENACC_OPS + +include "mlir/Dialect/OpenACC/OpenACCOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/gpu/__init__.py b/mlir/python/mlir/dialects/gpu/__init__.py index 4cd80aa..b14ea68 100644 --- a/mlir/python/mlir/dialects/gpu/__init__.py +++ b/mlir/python/mlir/dialects/gpu/__init__.py @@ -3,5 +3,151 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .._gpu_ops_gen import * +from .._gpu_ops_gen import _Dialect from .._gpu_enum_gen import * from ..._mlir_libs._mlirDialectsGPU import * +from typing import Callable, Sequence, Union, Optional, List + +try: + from ...ir import ( + FunctionType, + TypeAttr, + StringAttr, + UnitAttr, + Block, + InsertionPoint, + ArrayAttr, + Type, + DictAttr, + Attribute, + DenseI32ArrayAttr, + ) + from .._ods_common import ( + get_default_loc_context as _get_default_loc_context, + _cext as _ods_cext, + ) +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + + +@_ods_cext.register_operation(_Dialect, replace=True) +class GPUFuncOp(GPUFuncOp): + __doc__ = GPUFuncOp.__doc__ + + KERNEL_ATTR_NAME = "gpu.kernel" + KNOWN_BLOCK_SIZE_ATTR_NAME = "known_block_size" + KNOWN_GRID_SIZE_ATTR_NAME = "known_grid_size" + + FUNCTION_TYPE_ATTR_NAME = "function_type" + SYM_NAME_ATTR_NAME = "sym_name" + ARGUMENT_ATTR_NAME = "arg_attrs" + RESULT_ATTR_NAME = "res_attrs" + + def __init__( + self, + function_type: Union[FunctionType, TypeAttr], + sym_name: Optional[Union[str, StringAttr]] = None, + kernel: Optional[bool] = None, + workgroup_attrib_attrs: Optional[Sequence[dict]] = None, + private_attrib_attrs: Optional[Sequence[dict]] = None, + known_block_size: Optional[Union[Sequence[int], DenseI32ArrayAttr]] = None, + known_grid_size: Optional[Union[Sequence[int], DenseI32ArrayAttr]] = None, + loc=None, + ip=None, + body_builder: Optional[Callable[[GPUFuncOp], None]] = None, + ): + """ + Create a GPUFuncOp with the provided `function_type`, `sym_name`, + `kernel`, `workgroup_attrib_attrs`, `private_attrib_attrs`, `known_block_size`, + `known_grid_size`, and `body_builder`. + - `function_type` is a FunctionType or a TypeAttr. + - `sym_name` is a string or a StringAttr representing the function name. + - `kernel` is a boolean representing whether the function is a kernel. + - `workgroup_attrib_attrs` is an optional list of dictionaries. + - `private_attrib_attrs` is an optional list of dictionaries. + - `known_block_size` is an optional list of integers or a DenseI32ArrayAttr representing the known block size. + - `known_grid_size` is an optional list of integers or a DenseI32ArrayAttr representing the known grid size. + - `body_builder` is an optional callback. When provided, a new entry block + is created and the callback is invoked with the new op as argument within + an InsertionPoint context already set for the block. The callback is + expected to insert a terminator in the block. + """ + function_type = ( + TypeAttr.get(function_type) + if not isinstance(function_type, TypeAttr) + else function_type + ) + super().__init__( + function_type, + workgroup_attrib_attrs=workgroup_attrib_attrs, + private_attrib_attrs=private_attrib_attrs, + loc=loc, + ip=ip, + ) + + if isinstance(sym_name, str): + self.attributes[self.SYM_NAME_ATTR_NAME] = StringAttr.get(sym_name) + elif isinstance(sym_name, StringAttr): + self.attributes[self.SYM_NAME_ATTR_NAME] = sym_name + else: + raise ValueError("sym_name must be a string or a StringAttr") + + if kernel: + self.attributes[self.KERNEL_ATTR_NAME] = UnitAttr.get() + + if known_block_size is not None: + if isinstance(known_block_size, Sequence): + block_size = DenseI32ArrayAttr.get(known_block_size) + self.attributes[self.KNOWN_BLOCK_SIZE_ATTR_NAME] = block_size + elif isinstance(known_block_size, DenseI32ArrayAttr): + self.attributes[self.KNOWN_BLOCK_SIZE_ATTR_NAME] = known_block_size + else: + raise ValueError( + "known_block_size must be a list of integers or a DenseI32ArrayAttr" + ) + + if known_grid_size is not None: + if isinstance(known_grid_size, Sequence): + grid_size = DenseI32ArrayAttr.get(known_grid_size) + self.attributes[self.KNOWN_GRID_SIZE_ATTR_NAME] = grid_size + elif isinstance(known_grid_size, DenseI32ArrayAttr): + self.attributes[self.KNOWN_GRID_SIZE_ATTR_NAME] = known_grid_size + else: + raise ValueError( + "known_grid_size must be a list of integers or a DenseI32ArrayAttr" + ) + + if body_builder is not None: + with InsertionPoint(self.add_entry_block()): + body_builder(self) + + @property + def name(self) -> StringAttr: + return StringAttr(self.attributes[self.SYM_NAME_ATTR_NAME]) + + @property + def is_kernel(self) -> bool: + return self.KERNEL_ATTR_NAME in self.attributes + + def add_entry_block(self) -> Block: + if len(self.body.blocks) > 0: + raise RuntimeError(f"Entry block already exists for {self.name.value}") + + function_type = self.function_type.value + return self.body.blocks.append( + *function_type.inputs, + arg_locs=[self.location for _ in function_type.inputs], + ) + + @property + def entry_block(self) -> Block: + if len(self.body.blocks) == 0: + raise RuntimeError( + f"Entry block does not exist for {self.name.value}." + + " Do you need to call the add_entry_block() method on this GPUFuncOp?" + ) + return self.body.blocks[0] + + @property + def arguments(self) -> Sequence[Type]: + return self.function_type.value.inputs diff --git a/mlir/python/mlir/dialects/openacc.py b/mlir/python/mlir/dialects/openacc.py new file mode 100644 index 0000000..057f71a --- /dev/null +++ b/mlir/python/mlir/dialects/openacc.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._acc_ops_gen import * diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir index dbff233..455f886 100644 --- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir +++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-opt %s -convert-math-to-rocdl -allow-unregistered-dialect -split-input-file | FileCheck %s +// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx803})' | FileCheck %s --check-prefix=PRE9 +// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx942})' | FileCheck %s --check-prefix=POST9 module @test_module { // CHECK: llvm.func @__ocml_fmod_f16(f16, f16) -> f16 @@ -596,3 +597,76 @@ module @test_module { func.return %result : vector<2x2xf16> } } + +// ----- + +// f16 clamp → rocdl.fmed3 on gfx9+ +// CHECK-LABEL: func.func @clampf_f16 +func.func @clampf_f16(%x: f16, %lo: f16, %hi: f16) -> f16 { + %r = math.clampf %x to [%lo, %hi] : f16 + return %r : f16 + // POST9: rocdl.fmed3 {{.*}} : f16 + // PRE9-NOT: rocdl.fmed3 + // PRE9: math.clampf {{.*}} : f16 +} + +// f32 clamp → rocdl.fmed3 on gfx9+ +// CHECK-LABEL: func.func @clampf_f32 +func.func @clampf_f32(%x: f32, %lo: f32, %hi: f32) -> f32 { + %r = math.clampf %x to [%lo, %hi] : f32 + return %r : f32 + // POST9: rocdl.fmed3 {{.*}} : f32 + // PRE9-NOT: rocdl.fmed3 + // PRE9: math.clampf {{.*}} : f32 +} + +// ----- + +// Vector f16 clamp → rocdl.fmed3 on gfx9+ +// CHECK-LABEL: func.func @clampf_vector_f16 +func.func @clampf_vector_f16(%x: vector<2xf16>, %lo: vector<2xf16>, %hi: vector<2xf16>) -> vector<2xf16> { + %r = math.clampf %x to [%lo, %hi] : vector<2xf16> + return %r : vector<2xf16> + // POST9: rocdl.fmed3 {{.*}} : vector<2xf16> + // PRE9-NOT: rocdl.fmed3 + // PRE9: math.clampf {{.*}} : vector<2xf16> +} + +// ----- + +// Vector f32 clamp → rocdl.fmed3 on gfx9+ +// CHECK-LABEL: func.func @clampf_vector_f32 +func.func @clampf_vector_f32(%x: vector<2xf32>, %lo: vector<2xf32>, %hi: vector<2xf32>) -> vector<2xf32> { + %r = math.clampf %x to [%lo, %hi] : vector<2xf32> + return %r : vector<2xf32> + // POST9: rocdl.fmed3 {{.*}} : vector<2xf32> + // PRE9-NOT: rocdl.fmed3 + // PRE9: math.clampf {{.*}} : vector<2xf32> +} + +// ----- + +// Multi-dimensional vector f16 clamp → rocdl.fmed3 on gfx9+ (unrolled to 1D vectors) +// CHECK-LABEL: func.func @clampf_vector_2d_f16 +func.func @clampf_vector_2d_f16(%x: vector<2x2xf16>, %lo: vector<2x2xf16>, %hi: vector<2x2xf16>) -> vector<2x2xf16> { + %r = math.clampf %x to [%lo, %hi] : vector<2x2xf16> + return %r : vector<2x2xf16> + // POST9: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to !llvm.array<2 x vector<2xf16>> + // POST9: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>> + // POST9: rocdl.fmed3 {{.*}} : vector<2xf16> + // POST9: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>> + // POST9: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>> + // POST9: rocdl.fmed3 {{.*}} : vector<2xf16> + // POST9: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>> + // PRE9-NOT: rocdl.fmed3 + // PRE9: math.clampf {{.*}} : vector<2x2xf16> +} + +// ----- +// CHECK-LABEL: func.func @clampf_bf16 +func.func @clampf_bf16(%x: bf16, %lo: bf16, %hi: bf16) -> bf16 { + %r = math.clampf %x to [%lo, %hi] : bf16 + return %r : bf16 + // CHECK: math.clampf {{.*}} : bf16 + // CHECK-NOT: rocdl.fmed3 +} diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 2d33888..d669a3b 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -76,6 +76,18 @@ func.func @broadcast_vec1d_from_f32(%arg0: f32) -> vector<2xf32> { // ----- +func.func @broadcast_single_elem_vec1d_from_f32(%arg0: f32) -> vector<1xf32> { + %0 = vector.broadcast %arg0 : f32 to vector<1xf32> + return %0 : vector<1xf32> +} +// CHECK-LABEL: @broadcast_single_elem_vec1d_from_f32 +// CHECK-SAME: %[[A:.*]]: f32) +// CHECK: %[[T0:.*]] = llvm.insertelement %[[A]] +// CHECK-NOT: llvm.shufflevector +// CHECK: return %[[T0]] : vector<1xf32> + +// ----- + func.func @broadcast_vec1d_from_f32_scalable(%arg0: f32) -> vector<[2]xf32> { %0 = vector.broadcast %arg0 : f32 to vector<[2]xf32> return %0 : vector<[2]xf32> diff --git a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir index e6f22f0..a9ab0be 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir @@ -1,17 +1,13 @@ // RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s -#sg_map_a_f16 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> -#sg_map_b_f16 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]> -#sg_map_c_f32 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]> - -gpu.module @load_store_check { +gpu.module @test_kernel { // CHECK-LABEL: func.func @dpas( // CHECK-SAME: %[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32> func.func @dpas(%a_loaded: vector<8xf16>, %b_loaded: vector<16xf16>, %c_loaded: vector<8xf32>) -> vector<8xf32> { // Loads are checked in a separate test. // CHECK: %[[D:.*]] = xevm.mma %[[ARG0]], %[[ARG1]], %[[ARG2]] {shape = <m = 8, n = 16, k = 16>, types = <d = f32, a = f16, b = f16, c = f32>} // CHECK-SAME: : (vector<8xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32> - %d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded {a_layout = #sg_map_a_f16, b_layout = #sg_map_b_f16, c_layout = #sg_map_c_f32} + %d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32> return %d : vector<8xf32> } diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir new file mode 100644 index 0000000..d4cb493 --- /dev/null +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir @@ -0,0 +1,201 @@ +// RUN: mlir-opt -split-input-file -convert-xegpu-to-xevm -cse %s | FileCheck %s + +gpu.module @test_kernel [#xevm.target<chip = "pvc">] { + + // e.g. for mem_desc<32x32xf16, @strides=[1, 16]> + // its memory layout tuple is (blocked shape = [1,1,32,32],strides=[1024,1024,32,1]) + //CHECK-LABEL: load_store_matrix_1 + gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> f32 { + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32> + + //CHECK: %[[TID:.*]] = gpu.thread_id x + //CHECK: %[[C1:.*]] = arith.constant 1 : index + //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index + //CHECK: %[[C4:.*]] = arith.constant 4 : i32 + //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32 + //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32 + + %tid_x = gpu.thread_id x + %c0 = arith.constant 0 : index + %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32 + + //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3> + + xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index + + gpu.return %1: f32 + } + +// e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]> + // its memory layout tuple is ([2,4,16,16],[256,512,1,16]) + //CHECK-LABEL: load_store_matrix_2 + gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> f16 { + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>> + //CHECK: %[[c0:.*]] = arith.constant 0 : index + //CHECK: %[[tid_x:.*]] = gpu.thread_id x + //CHECK: %[[c13:.*]] = arith.constant 13 : index + //CHECK: %[[c16:.*]] = arith.constant 16 : index + //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c13]], %[[c16]] : index + //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index + //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index + //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index + + //CHECK: %[[c256:.*]] = arith.constant 256 : index + //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index + //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index + //CHECK: %[[c512:.*]] = arith.constant 512 : index + //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index + //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index + //CHECK: %[[c1:.*]] = arith.constant 1 : index + //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index + //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index + //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index + //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index + + //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> f16 + + + %tid_x = gpu.thread_id x + %c13 = arith.constant 13 : index + %1 = xegpu.load_matrix %0[%c13, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> f16 + + //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3> + + xegpu.store_matrix %1, %0[%c13, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index + gpu.return %1: f16 + } + + + // e.g. for mem_desc<32x64xf16, @block=[16, 16]> + // its memory layout tuple is ([2,4,16,16],[1024,256,16,1]) + //CHECK-LABEL: load_store_matrix_3 + gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> f16 { + //CHECK: %[[c0:.*]] = arith.constant 0 : index + //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3> + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>> + + //CHECK: %[[tid_x:.*]] = gpu.thread_id x + //CHECK: %[[c19:.*]] = arith.constant 19 : index + %tid_x = gpu.thread_id x + %c19 = arith.constant 19: index + + //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index + //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32 + //CHECK: %[[c16:.*]] = arith.constant 16 : index + //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index + //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index + //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index + //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index + //CHECK: %[[c1024:.*]] = arith.constant 1024 : index + //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index + //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index + //CHECK: %[[c256:.*]] = arith.constant 256 : index + //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c256]] : index + //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index + //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c16]] : index + //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index + //CHECK: %[[c1:.*]] = arith.constant 1 : index + //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index + //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index + + //CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16 + %1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> f16 + + //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3> + xegpu.store_matrix %1, %0[%c19, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index + + //CHECK: gpu.return %[[loaded]] : f16 + gpu.return %1: f16 + } + + // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 16]> + // its memory layout tuple is ([2,4,16,16],[256,512,1,16]) + //CHECK-LABEL: load_store_matrix_4 + gpu.func @load_store_matrix_4(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>> + + //CHECK: %[[c0:.*]] = arith.constant 0 : index + //CHECK: %[[tid_x:.*]] = gpu.thread_id x + + //CHECK: %[[c16:.*]] = arith.constant 16 : index + //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index + //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index + //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index + //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index + + //CHECK: %[[c256:.*]] = arith.constant 256 : index + //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index + //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index + //CHECK: %[[c512:.*]] = arith.constant 512 : index + //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index + //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index + //CHECK: %[[c1:.*]] = arith.constant 1 : index + //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index + //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index + //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index + //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index + + //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> vector<8xf16> + + %tid_x = gpu.thread_id x + %c16 = arith.constant 16 : index + %1 = xegpu.load_matrix %0[%c16, %tid_x] : !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<8xf16> + + //CHECK: llvm.store %[[loaded]], {{.*}} : vector<8xf16>, !llvm.ptr<3> + xegpu.store_matrix %1, %0[%c16, %tid_x] : vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index + + gpu.return %1: vector<8xf16> + } + + + // e.g. for mem_desc<32x64xf16, @block=[16, 16]> + // its memory layout tuple is ([2,4,16,16],[1024,256,16,1]) + //CHECK-LABEL: load_store_matrix_5 + gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { + //CHECK: %[[c0:.*]] = arith.constant 0 : index + //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3> + + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>> + + //CHECK: %[[c16:.*]] = arith.constant 16 : index + //CHECK: %[[c48:.*]] = arith.constant 48 : index + + %c16 = arith.constant 16 : index + %c48 = arith.constant 48 : index + + //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index + //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32 + //CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index + //CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index + //CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index + //CHECK: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index + //CHECK: %[[c1024:.*]] = arith.constant 1024 : index + //CHECK: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index + //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index + //CHECK: %[[c256:.*]] = arith.constant 256 : index + //CHECK: %[[mul1:.*]] = arith.muli %[[offset2]], %[[c256]] : index + //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index + //CHECK: %[[mul2:.*]] = arith.muli %[[offset1]], %[[c16]] : index + //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index + //CHECK: %[[c1:.*]] = arith.constant 1 : index + //CHECK: %[[mul3:.*]] = arith.muli %[[offset3]], %[[c1]] : index + //CHECK: %[[linearOffset:.*]] = arith.addi %[[mul3]], %[[add2]] : index + //CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i32 + //CHECK: %[[c2:.*]] = arith.constant 2 : i32 + //CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i32 + //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI64]], %[[byteOffset]] : i32 + //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i32 to !llvm.ptr<3> + //CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16> + //CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16> + + %1 = xegpu.load_matrix %0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> vector<8xf16> + + //CHECK: %[[storeDataI16:.*]] = vector.bitcast %[[loaded]] : vector<8xf16> to vector<8xi16> + //CHECK: xevm.blockstore %[[ptr]], %[[storeDataI16]] : (!llvm.ptr<3>, vector<8xi16>) + + xegpu.store_matrix %1, %0[%c16, %c48] {subgroup_block_io}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index + + gpu.return %1: vector<8xf16> + } + +} diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir index 0b150e9..9c552d8 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir @@ -14,19 +14,36 @@ gpu.func @load_gather_i64_src_value_offset(%src: i64, %offset: vector<1xindex>) // CHECK: %[[VAR4:.*]] = arith.addi %[[ARG0]], %[[VAR3]] : i64 // CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1> // CHECK: %[[VAR6:.*]] = scf.if %[[VAR2]] -> (f16) { - // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR5]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>} : !llvm.ptr<1> -> vector<1xf16> - // CHECK: %[[VAR8:.*]] = vector.extract %[[VAR7]][0] : f16 from vector<1xf16> - // CHECK: scf.yield %[[VAR8]] : f16 - // CHECK: } else { - // CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1xf16> - // CHECK: %[[VAR7:.*]] = vector.extract %[[CST_0]][0] : f16 from vector<1xf16> + // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR5]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>} : !llvm.ptr<1> -> f16 // CHECK: scf.yield %[[VAR7]] : f16 + // CHECK: } else { + // CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f16 + // CHECK: scf.yield %[[CST_0]] : f16 // CHECK: } %3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : i64, vector<1xindex>, vector<1xi1> -> vector<1xf16> gpu.return } } + +// ----- +gpu.module @test { +// CHECK-LABEL: @source_materialize_single_elem_vec +// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: memref<1xf16> +gpu.func @source_materialize_single_elem_vec(%src: i64, %offset: vector<1xindex>, %dst: memref<1xf16>) { + %1 = arith.constant dense<1>: vector<1xi1> + %3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> + : i64, vector<1xindex>, vector<1xi1> -> vector<1xf16> + // CHECK: %[[VAR_IF:.*]] = scf.if + // CHECK: %[[VAR_RET:.*]] = vector.broadcast %[[VAR_IF]] : f16 to vector<1xf16> + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: vector.store %[[VAR_RET]], %[[ARG2]][%[[C0]]] : memref<1xf16>, vector<1xf16> + %c0 = arith.constant 0 : index + vector.store %3, %dst[%c0] : memref<1xf16>, vector<1xf16> + gpu.return +} +} + // ----- gpu.module @test { diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir index e56079c..1169cd1 100644 --- a/mlir/test/Dialect/Affine/canonicalize.mlir +++ b/mlir/test/Dialect/Affine/canonicalize.mlir @@ -2235,6 +2235,136 @@ func.func @affine_leading_zero_no_outer_bound(%arg0: index, %arg1: index) -> ind // ----- +// CHECK-LABEL: func @delin_apply_cancel_exact +// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>) +// CHECK-COUNT-6: memref.store %[[ARG0]], %[[ARG1]][%[[ARG0]]] +// CHECK-NOT: memref.store +// CHECK: return +func.func @delin_apply_cancel_exact(%arg0: index, %arg1: memref<?xindex>) { + %a:3 = affine.delinearize_index %arg0 into (4, 5) : index, index, index + %b:3 = affine.delinearize_index %arg0 into (3, 4, 5) : index, index, index + %c:2 = affine.delinearize_index %arg0 into (20) : index, index + + %t1 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 5 + s2 * 20)>()[%a#2, %a#1, %a#0] + memref.store %t1, %arg1[%t1] : memref<?xindex> + + %t2 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s2 * 20 + s1 * 5)>()[%a#2, %a#1, %a#0] + memref.store %t2, %arg1[%t2] : memref<?xindex> + + %t3 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 20 + s2 * 5 + s0)>()[%a#2, %a#0, %a#1] + memref.store %t3, %arg1[%t3] : memref<?xindex> + + %t4 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 5 + s2 * 20)>()[%b#2, %b#1, %b#0] + memref.store %t4, %arg1[%t4] : memref<?xindex> + + %t5 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 20)>()[%c#1, %c#0] + memref.store %t5, %arg1[%t5] : memref<?xindex> + + %t6 = affine.apply affine_map<()[s0, s1] -> (s1 * 20 + s0)>()[%c#1, %c#0] + memref.store %t6, %arg1[%t5] : memref<?xindex> + + return +} + +// ----- + +// CHECK-LABEL: func @delin_apply_cancel_exact_dim +// CHECK: affine.for %[[arg1:.+]] = 0 to 256 +// CHECK: memref.store %[[arg1]] +// CHECK: return +func.func @delin_apply_cancel_exact_dim(%arg0: memref<?xindex>) { + affine.for %arg1 = 0 to 256 { + %a:3 = affine.delinearize_index %arg1 into (2, 2, 64) : index, index, index + %i = affine.apply affine_map<(d0, d1, d2) -> (d0 + d1 * 128 + d2 * 64)>(%a#2, %a#0, %a#1) + memref.store %i, %arg0[%i] : memref<?xindex> + } + return +} + +// ----- + +// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 512)> +// CHECK-LABEL: func @delin_apply_cancel_const_term +// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>) +// CHECK: affine.apply #[[$MAP]]()[%[[ARG0]]] +// CHECK: return +func.func @delin_apply_cancel_const_term(%arg0: index, %arg1: memref<?xindex>) { + %a:3 = affine.delinearize_index %arg0 into (2, 2, 64) : index, index, index + + %t1 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 128 + s2 * 64 + 512)>()[%a#2, %a#0, %a#1] + memref.store %t1, %arg1[%t1] : memref<?xindex> + + return +} + +// ----- + +// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 512)> +// CHECK-LABEL: func @delin_apply_cancel_var_term +// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>, %[[ARG2:.+]]: index) +// CHECK: affine.apply #[[$MAP]]()[%[[ARG2]], %[[ARG0]]] +// CHECK: return +func.func @delin_apply_cancel_var_term(%arg0: index, %arg1: memref<?xindex>, %arg2: index) { + %a:3 = affine.delinearize_index %arg0 into (2, 2, 64) : index, index, index + + %t1 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 + s1 * 128 + s2 * 64 + s3 + 512)>()[%a#2, %a#0, %a#1, %arg2] + memref.store %t1, %arg1[%t1] : memref<?xindex> + + return +} + +// ----- + +// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2 + s0 ceildiv 4)> +// CHECK-LABEL: func @delin_apply_cancel_nested_exprs +// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>) +// CHECK: affine.apply #[[$MAP]]()[%[[ARG0]]] +// CHECK: return +func.func @delin_apply_cancel_nested_exprs(%arg0: index, %arg1: memref<?xindex>) { + %a:2 = affine.delinearize_index %arg0 into (20) : index, index + + %t1 = affine.apply affine_map<()[s0, s1] -> ((s0 + s1 * 20) ceildiv 4 + (s1 * 20 + s0) * 2)>()[%a#1, %a#0] + memref.store %t1, %arg1[%t1] : memref<?xindex> + + return +} + +// ----- + +// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)> +// CHECK-LABEL: func @delin_apply_cancel_preserve_rotation +// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>) +// CHECK: %[[A:.+]]:2 = affine.delinearize_index %[[ARG0]] into (20) +// CHECK: affine.apply #[[$MAP]]()[%[[A]]#1, %[[ARG0]]] +// CHECK: return +func.func @delin_apply_cancel_preserve_rotation(%arg0: index, %arg1: memref<?xindex>) { + %a:2 = affine.delinearize_index %arg0 into (20) : index, index + + %t1 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 20 + s0)>()[%a#1, %a#0] + memref.store %t1, %arg1[%t1] : memref<?xindex> + + return +} + +// ----- + +// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 5)> +// CHECK-LABEL: func @delin_apply_dont_cancel_partial +// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>) +// CHECK: %[[A:.+]]:3 = affine.delinearize_index %[[ARG0]] into (3, 4, 5) +// CHECK: affine.apply #[[$MAP]]()[%[[A]]#2, %[[A]]#1] +// CHECK: return +func.func @delin_apply_dont_cancel_partial(%arg0: index, %arg1: memref<?xindex>) { + %a:3 = affine.delinearize_index %arg0 into (3, 4, 5) : index, index, index + + %t1 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 5)>()[%a#2, %a#1] + memref.store %t1, %arg1[%t1] : memref<?xindex> + + return +} + +// ----- + // CHECK-LABEL: @cst_value_to_cst_attr_basis_delinearize_index // CHECK-SAME: (%[[ARG0:.*]]: index) // CHECK: %[[RET:.*]]:3 = affine.delinearize_index %[[ARG0]] into (3, 4, 2) : index, index diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir index 8accf6e..755e3a3 100644 --- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir +++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir @@ -235,6 +235,17 @@ llvm.func @fold_gep_canon(%x : !llvm.ptr) -> !llvm.ptr { // ----- +// CHECK-LABEL: fold_shufflevector +// CHECK-SAME: %[[ARG1:[[:alnum:]]+]]: vector<1xf32>, %[[ARG2:[[:alnum:]]+]]: vector<1xf32> +llvm.func @fold_shufflevector(%v1 : vector<1xf32>, %v2 : vector<1xf32>) -> vector<1xf32> { + // CHECK-NOT: llvm.shufflevector + %c = llvm.shufflevector %v1, %v2 [0] : vector<1xf32> + // CHECK: llvm.return %[[ARG1]] + llvm.return %c : vector<1xf32> +} + +// ----- + // Check that LLVM constants participate in cross-dialect constant folding. The // resulting constant is created in the arith dialect because the last folded // operation belongs to it. diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir index 358bd33..242c04f 100644 --- a/mlir/test/Dialect/LLVMIR/rocdl.mlir +++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir @@ -1035,6 +1035,20 @@ llvm.func @rocdl.s.wait.expcnt() { llvm.return } +llvm.func @rocdl.s.wait.asynccnt() { + // CHECK-LABEL: rocdl.s.wait.asynccnt + // CHECK: rocdl.s.wait.asynccnt 0 + rocdl.s.wait.asynccnt 0 + llvm.return +} + +llvm.func @rocdl.s.wait.tensorcnt() { + // CHECK-LABEL: rocdl.s.wait.tensorcnt + // CHECK: rocdl.s.wait.tensorcnt 0 + rocdl.s.wait.tensorcnt 0 + llvm.return +} + // ----- llvm.func @rocdl.readfirstlane(%src : f32) -> f32 { diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir index 35f520a..93a0336 100644 --- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir +++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir @@ -1,5 +1,9 @@ // RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s +///---------------------------------------------------------------------------------------- +/// Tests for linalg.dot +///---------------------------------------------------------------------------------------- + // CHECK-LABEL: contraction_dot func.func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) { @@ -20,6 +24,10 @@ module attributes {transform.with_named_sequence} { // ----- +///---------------------------------------------------------------------------------------- +/// Tests for linalg.matvec +///---------------------------------------------------------------------------------------- + // CHECK-LABEL: contraction_matvec func.func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) { @@ -41,6 +49,10 @@ module attributes {transform.with_named_sequence} { // ----- +///---------------------------------------------------------------------------------------- +/// Tests for linalg.matmul +///---------------------------------------------------------------------------------------- + // CHECK-LABEL: contraction_matmul func.func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) { // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32> @@ -138,6 +150,10 @@ module attributes {transform.with_named_sequence} { // ----- +///---------------------------------------------------------------------------------------- +/// Tests for linalg.batch_matmul +///---------------------------------------------------------------------------------------- + // CHECK-LABEL: contraction_batch_matmul func.func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) { // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584x1584xf32> @@ -159,6 +175,10 @@ module attributes {transform.with_named_sequence} { // ----- +///---------------------------------------------------------------------------------------- +/// Tests for linalg.cantract +///---------------------------------------------------------------------------------------- + // CHECK-LABEL: @matmul_as_contract // CHECK-SAME: %[[A:.*]]: tensor<24x12xf32> // CHECK-SAME: %[[B:.*]]: tensor<12x25xf32> @@ -220,6 +240,10 @@ module attributes {transform.with_named_sequence} { // ----- +///---------------------------------------------------------------------------------------- +/// Tests for linalg.fill +///---------------------------------------------------------------------------------------- + // CHECK-LABEL: func @test_vectorize_fill func.func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) { // CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32> @@ -259,70 +283,14 @@ module attributes {transform.with_named_sequence} { // ----- -// CHECK-LABEL: func @test_vectorize_copy -func.func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) { - // CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32> - // CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32> - memref.copy %A, %B : memref<8x16xf32> to memref<8x16xf32> - return -} +///---------------------------------------------------------------------------------------- +/// Tests for linalg.pack +///---------------------------------------------------------------------------------------- -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op - %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op - transform.yield - } -} +// Note, see a similar test in: +// * vectorization.mlir. -// ----- - -// CHECK-LABEL: func @test_vectorize_copy_0d -func.func @test_vectorize_copy_0d(%A : memref<f32>, %B : memref<f32>) { - // CHECK-SAME: (%[[A:.*]]: memref<f32>, %[[B:.*]]: memref<f32>) - // CHECK: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref<f32>, vector<f32> - // CHECK: %[[val:.*]] = vector.extract %[[V]][] : f32 from vector<f32> - // CHECK: %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector<f32> - // CHECK: vector.transfer_write %[[VV]], %[[B]][] : vector<f32>, memref<f32> - memref.copy %A, %B : memref<f32> to memref<f32> - return -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op - %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op - transform.yield - } -} - -// ----- - -// CHECK-LABEL: func @test_vectorize_copy_complex -// CHECK-NOT: vector< -func.func @test_vectorize_copy_complex(%A : memref<8x16xcomplex<f32>>, %B : memref<8x16xcomplex<f32>>) { - memref.copy %A, %B : memref<8x16xcomplex<f32>> to memref<8x16xcomplex<f32>> - return -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op - %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op - transform.yield - } -} - -// ----- - -// Input identical as the test in vectorization.mlir. Output is different - -// vector sizes are inferred (rather than user-specified) and hence _no_ -// masking was used. - -func.func @test_vectorize_pack(%arg0: tensor<32x8x16xf32>, %arg1: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> { +func.func @pack_no_padding(%arg0: tensor<32x8x16xf32>, %arg1: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> { %pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x8x16xf32> -> tensor<4x1x32x16x2xf32> return %pack : tensor<4x1x32x16x2xf32> } @@ -336,7 +304,7 @@ module attributes {transform.with_named_sequence} { } } -// CHECK-LABEL: func.func @test_vectorize_pack( +// CHECK-LABEL: func.func @pack_no_padding( // CHECK-SAME: %[[VAL_0:.*]]: tensor<32x8x16xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> { // CHECK-DAG: %[[VAL_2:.*]] = ub.poison : f32 @@ -349,13 +317,16 @@ module attributes {transform.with_named_sequence} { // ----- -func.func @test_vectorize_padded_pack(%arg0: tensor<32x7x15xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> { +// Note, see a similar test in: +// * vectorization.mlir. + +func.func @pack_with_padding(%arg0: tensor<32x7x15xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> { %pad = arith.constant 0.000000e+00 : f32 %pack = linalg.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32> return %pack : tensor<32x4x1x16x2xf32> } -// CHECK-LABEL: func.func @test_vectorize_padded_pack( +// CHECK-LABEL: func.func @pack_with_padding( // CHECK-SAME: %[[VAL_0:.*]]: tensor<32x7x15xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> { // CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32 @@ -377,6 +348,10 @@ module attributes {transform.with_named_sequence} { // ----- +///---------------------------------------------------------------------------------------- +/// Tests for linalg.map +///---------------------------------------------------------------------------------------- + func.func @vectorize_map(%arg0: memref<64xf32>, %arg1: memref<64xf32>, %arg2: memref<64xf32>) { linalg.map ins(%arg0, %arg1 : memref<64xf32>, memref<64xf32>) @@ -403,6 +378,10 @@ module attributes {transform.with_named_sequence} { // ----- +///---------------------------------------------------------------------------------------- +/// Tests for linalg.transpose +///---------------------------------------------------------------------------------------- + func.func @vectorize_transpose(%arg0: memref<16x32x64xf32>, %arg1: memref<32x64x16xf32>) { linalg.transpose ins(%arg0 : memref<16x32x64xf32>) @@ -424,6 +403,10 @@ module attributes {transform.with_named_sequence} { // ----- +///---------------------------------------------------------------------------------------- +/// Tests for linalg.reduce +///---------------------------------------------------------------------------------------- + func.func @vectorize_reduce(%arg0: memref<16x32x64xf32>, %arg1: memref<16x64xf32>) { linalg.reduce ins(%arg0 : memref<16x32x64xf32>) @@ -449,6 +432,10 @@ module attributes {transform.with_named_sequence} { // ----- +///---------------------------------------------------------------------------------------- +/// Tests for linalg.generic +///---------------------------------------------------------------------------------------- + #matmul_trait = { indexing_maps = [ affine_map<(m, n, k) -> (m, k)>, @@ -1446,6 +1433,8 @@ module attributes {transform.with_named_sequence} { // ----- +// TODO: Two Linalg Ops in one tests - either split or document "why". + // CHECK-DAG: #[[$M6:.*]] = affine_map<(d0, d1) -> (d0, 0)> // CHECK-LABEL: func @fused_broadcast_red_2d @@ -1896,3 +1885,65 @@ module attributes {transform.with_named_sequence} { } } +// ----- + +///---------------------------------------------------------------------------------------- +/// Tests for memref.copy +///---------------------------------------------------------------------------------------- + +// CHECK-LABEL: func @test_vectorize_copy +func.func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) { + // CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32> + // CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32> + memref.copy %A, %B : memref<8x16xf32> to memref<8x16xf32> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func @test_vectorize_copy_0d +func.func @test_vectorize_copy_0d(%A : memref<f32>, %B : memref<f32>) { + // CHECK-SAME: (%[[A:.*]]: memref<f32>, %[[B:.*]]: memref<f32>) + // CHECK: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref<f32>, vector<f32> + // CHECK: %[[val:.*]] = vector.extract %[[V]][] : f32 from vector<f32> + // CHECK: %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector<f32> + // CHECK: vector.transfer_write %[[VV]], %[[B]][] : vector<f32>, memref<f32> + memref.copy %A, %B : memref<f32> to memref<f32> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func @test_vectorize_copy_complex +// CHECK-NOT: vector< +func.func @test_vectorize_copy_complex(%A : memref<8x16xcomplex<f32>>, %B : memref<8x16xcomplex<f32>>) { + memref.copy %A, %B : memref<8x16xcomplex<f32>> to memref<8x16xcomplex<f32>> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op + transform.yield + } +} diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir index 11bea8d..1304a90 100644 --- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir +++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir @@ -1307,14 +1307,17 @@ func.func @test_vectorize_unpack_no_vector_sizes_permute(%source: tensor<4x7x4xf /// Tests for linalg.pack ///---------------------------------------------------------------------------------------- -// Input identical as the test in vectorization-with-patterns.mlir. Output is -// different - vector sizes are inferred (rather than user-specified) and hence -// masking was used. +// This packing requires no padding, so no out-of-bounds read/write vector Ops. -// CHECK-LABEL: func @test_vectorize_pack +// Note, see a similar test in: +// * vectorization-with-patterns.mlir +// The output is identical (the input vector sizes == the inferred vector +// sizes, i.e. the tensor sizes). + +// CHECK-LABEL: func @pack_no_padding // CHECK-SAME: %[[SRC:.*]]: tensor<32x8x16xf32>, // CHECK-SAME: %[[DEST:.*]]: tensor<4x1x32x16x2xf32> -func.func @test_vectorize_pack(%src: tensor<32x8x16xf32>, %dest: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> { +func.func @pack_no_padding(%src: tensor<32x8x16xf32>, %dest: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> { %pack = linalg.pack %src outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %dest : tensor<32x8x16xf32> -> tensor<4x1x32x16x2xf32> return %pack : tensor<4x1x32x16x2xf32> } @@ -1325,9 +1328,9 @@ func.func @test_vectorize_pack(%src: tensor<32x8x16xf32>, %dest: tensor<4x1x32x1 // CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32> // CHECK: %[[TR:.*]] = vector.transpose %[[SC]], [1, 3, 0, 4, 2] : vector<32x4x2x1x16xf32> to vector<4x1x32x16x2xf32> // CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index -// CHECK: %[[write:.*]] = vector.transfer_write %[[TR]], %[[DEST]][%[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]]] +// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[TR]], %[[DEST]][%[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]]] // CHECK-SAME: {in_bounds = [true, true, true, true, true]} : vector<4x1x32x16x2xf32>, tensor<4x1x32x16x2xf32> -// CHECK: return %[[write]] : tensor<4x1x32x16x2xf32> +// CHECK: return %[[WRITE]] : tensor<4x1x32x16x2xf32> module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%src: !transform.any_op {transform.readonly}) { @@ -1339,14 +1342,18 @@ module attributes {transform.with_named_sequence} { // ----- -// Input identical as the test in vectorization-with-patterns.mlir. Output is -// different - vector sizes are inferred (rather than user-specified) and hence -// masking was used. +// This packing does require padding, so there are out-of-bounds read/write +// vector Ops. + +// Note, see a similar test in: +// * vectorization-with-patterns.mlir. +// The output is different (the input vector sizes != inferred vector sizes, +// i.e. the tensor sizes). -// CHECK-LABEL: func @test_vectorize_padded_pack +// CHECK-LABEL: func @pack_with_padding // CHECK-SAME: %[[SRC:.*]]: tensor<32x7x15xf32>, // CHECK-SAME: %[[DEST:.*]]: tensor<32x4x1x16x2xf32> -func.func @test_vectorize_padded_pack(%src: tensor<32x7x15xf32>, %dest: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> { +func.func @pack_with_padding(%src: tensor<32x7x15xf32>, %dest: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> { %pad = arith.constant 0.000000e+00 : f32 %pack = linalg.pack %src padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %dest : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32> return %pack : tensor<32x4x1x16x2xf32> @@ -1364,9 +1371,9 @@ func.func @test_vectorize_padded_pack(%src: tensor<32x7x15xf32>, %dest: tensor<3 // CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32> // CHECK: %[[TR:.*]] = vector.transpose %[[SC]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32> // CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index -// CHECK: %[[write:.*]] = vector.transfer_write %[[TR]], %[[DEST]][%[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]]] +// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[TR]], %[[DEST]][%[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]]] // CHECK-SAME: {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32> -// CHECK: return %[[write]] : tensor<32x4x1x16x2xf32> +// CHECK: return %[[WRITE]] : tensor<32x4x1x16x2xf32> module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { @@ -1378,10 +1385,46 @@ module attributes {transform.with_named_sequence} { // ----- -// CHECK-LABEL: func @test_vectorize_dynamic_pack +// This packing does require padding, so there are out-of-bounds read/write +// vector Ops. + +// Note, see a similar test in: +// * vectorization-with-patterns.mlir. +// The output is identical (in both cases the vector sizes are inferred). + +// CHECK-LABEL: func @pack_with_padding_no_vector_sizes +// CHECK-SAME: %[[SRC:.*]]: tensor<32x7x15xf32>, +// CHECK-SAME: %[[DEST:.*]]: tensor<32x4x1x16x2xf32> +func.func @pack_with_padding_no_vector_sizes(%src: tensor<32x7x15xf32>, %dest: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> { + %pad = arith.constant 0.000000e+00 : f32 + %pack = linalg.pack %src padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %dest : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32> + return %pack : tensor<32x4x1x16x2xf32> +} +// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[READ:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[CST]] +// CHECK-SAME: {in_bounds = [true, false, false]} : tensor<32x7x15xf32>, vector<32x8x16xf32> +// CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32> +// CHECK: %[[TR:.*]] = vector.transpose %[[SC]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32> +// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[TR]], %[[DEST]][%[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]]] +// CHECK-SAME: {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32> +// CHECK: return %[[WRITE]] : tensor<32x4x1x16x2xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.structured.vectorize %0 : !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func @pack_with_dynamic_dims // CHECK-SAME: %[[SRC:.*]]: tensor<?x?xf32>, // CHECK-SAME: %[[DEST:.*]]: tensor<?x?x16x2xf32> -func.func @test_vectorize_dynamic_pack(%src: tensor<?x?xf32>, %dest: tensor<?x?x16x2xf32>) -> tensor<?x?x16x2xf32> { +func.func @pack_with_dynamic_dims(%src: tensor<?x?xf32>, %dest: tensor<?x?x16x2xf32>) -> tensor<?x?x16x2xf32> { %pack = linalg.pack %src inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %dest : tensor<?x?xf32> -> tensor<?x?x16x2xf32> return %pack : tensor<?x?x16x2xf32> } @@ -1418,64 +1461,6 @@ module attributes {transform.with_named_sequence} { } } -// ----- - -// CHECK-LABEL: func @test_vectorize_pack_no_vector_sizes -// CHECK-SAME: %[[SRC:.*]]: tensor<64x4xf32>, -// CHECK-SAME: %[[DEST:.*]]: tensor<2x4x16x2xf32> -func.func @test_vectorize_pack_no_vector_sizes(%src: tensor<64x4xf32>, %dest: tensor<2x4x16x2xf32>) -> tensor<2x4x16x2xf32> { - %pack = linalg.pack %src outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [16, 2] into %dest : tensor<64x4xf32> -> tensor<2x4x16x2xf32> - return %pack : tensor<2x4x16x2xf32> -} -// CHECK-DAG: %[[CST:.*]] = ub.poison : f32 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[READ:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CST]] -// CHECK-SAME: {in_bounds = [true, true]} : tensor<64x4xf32>, vector<64x4xf32> -// CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<64x4xf32> to vector<4x16x2x2xf32> -// CHECK: %[[TR:.*]] = vector.transpose %[[SC]], [2, 0, 1, 3] : vector<4x16x2x2xf32> to vector<2x4x16x2xf32> -// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index -// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[TR]], %[[DEST]][%[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]]] -// CHECK-SAME: {in_bounds = [true, true, true, true]} : vector<2x4x16x2xf32>, tensor<2x4x16x2xf32> -// CHECK: return %[[WRITE]] : tensor<2x4x16x2xf32> - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op - transform.structured.vectorize %0 : !transform.any_op - transform.yield - } -} - -// ----- - -// CHECK-LABEL: test_vectorize_padded_pack_no_vector_sizes -// CHECK-SAME: %[[SRC:.*]]: tensor<32x7x15xf32>, -// CHECK-SAME: %[[DEST:.*]]: tensor<32x4x1x16x2xf32> -func.func @test_vectorize_padded_pack_no_vector_sizes(%src: tensor<32x7x15xf32>, %dest: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> { - %pad = arith.constant 0.000000e+00 : f32 - %pack = linalg.pack %src padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %dest : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32> - return %pack : tensor<32x4x1x16x2xf32> -} -// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[READ:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[CST]] -// CHECK-SAME: {in_bounds = [true, false, false]} : tensor<32x7x15xf32>, vector<32x8x16xf32> -// CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32> -// CHECK: %[[TR:.*]] = vector.transpose %[[SC]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32> -// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index -// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[TR]], %[[DEST]][%[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]], %[[C0_1]]] -// CHECK-SAME: {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32> -// CHECK: return %[[WRITE]] : tensor<32x4x1x16x2xf32> - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op - transform.structured.vectorize %0 : !transform.any_op - transform.yield - } -} - - ///---------------------------------------------------------------------------------------- /// Tests for other Ops ///---------------------------------------------------------------------------------------- diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index 16b7a5c..7160b52 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -911,6 +911,21 @@ func.func @reinterpret_noop(%arg : memref<2x3x4xf32>) -> memref<2x3x4xf32> { // ----- +// CHECK-LABEL: func @reinterpret_constant_fold +// CHECK-SAME: (%[[ARG:.*]]: memref<f32>) +// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [100, 100], strides: [100, 1] +// CHECK: %[[CAST:.*]] = memref.cast %[[RES]] +// CHECK: return %[[CAST]] +func.func @reinterpret_constant_fold(%arg0: memref<f32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c100 = arith.constant 100 : index + %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%c0], sizes: [%c100, %c100], strides: [%c100, %c1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>> + return %reinterpret_cast : memref<?x?xf32, strided<[?, ?], offset: ?>> +} + +// ----- + // CHECK-LABEL: func @reinterpret_of_reinterpret // CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index) // CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1] @@ -996,10 +1011,9 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?x // when the strides don't match. // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_stride // CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>) -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]] -// CHECK: return %[[RES]] +// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [4, 2, 2], strides: [1, 1, 1] +// CHECK: %[[CAST:.*]] = memref.cast %[[RES]] +// CHECK: return %[[CAST]] func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> { %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref<f32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> @@ -1011,11 +1025,9 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : me // when the offset doesn't match. // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_offset // CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>) -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]] -// CHECK: return %[[RES]] +// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [8, 2], strides: [2, 1] +// CHECK: %[[CAST:.*]] = memref.cast %[[RES]] +// CHECK: return %[[CAST]] func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> { %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index %m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>> diff --git a/mlir/test/Dialect/OpenACC/pointer-like-interface-alloc.mlir b/mlir/test/Dialect/OpenACC/pointer-like-interface-alloc.mlir index 603ace8..3d4bec7 100644 --- a/mlir/test/Dialect/OpenACC/pointer-like-interface-alloc.mlir +++ b/mlir/test/Dialect/OpenACC/pointer-like-interface-alloc.mlir @@ -3,7 +3,7 @@ func.func @test_static_memref_alloc() { %0 = memref.alloca() {test.ptr} : memref<10x20xf32> // CHECK: Successfully generated alloc for operation: %[[ORIG:.*]] = memref.alloca() {test.ptr} : memref<10x20xf32> - // CHECK: Generated: %{{.*}} = memref.alloca() : memref<10x20xf32> + // CHECK: Generated: %{{.*}} = memref.alloca() {acc.var_name = #acc.var_name<"test_alloc">} : memref<10x20xf32> return } @@ -19,6 +19,6 @@ func.func @test_dynamic_memref_alloc() { // CHECK: Generated: %[[DIM0:.*]] = memref.dim %[[ORIG]], %[[C0]] : memref<?x?xf32> // CHECK: Generated: %[[C1:.*]] = arith.constant 1 : index // CHECK: Generated: %[[DIM1:.*]] = memref.dim %[[ORIG]], %[[C1]] : memref<?x?xf32> - // CHECK: Generated: %{{.*}} = memref.alloc(%[[DIM0]], %[[DIM1]]) : memref<?x?xf32> + // CHECK: Generated: %{{.*}} = memref.alloc(%[[DIM0]], %[[DIM1]]) {acc.var_name = #acc.var_name<"test_alloc">} : memref<?x?xf32> return } diff --git a/mlir/test/Dialect/OpenACC/recipe-populate-firstprivate.mlir b/mlir/test/Dialect/OpenACC/recipe-populate-firstprivate.mlir index 35355c6..8846c9e 100644 --- a/mlir/test/Dialect/OpenACC/recipe-populate-firstprivate.mlir +++ b/mlir/test/Dialect/OpenACC/recipe-populate-firstprivate.mlir @@ -2,7 +2,7 @@ // CHECK: acc.firstprivate.recipe @firstprivate_scalar : memref<f32> init { // CHECK: ^bb0(%{{.*}}: memref<f32>): -// CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<f32> +// CHECK: %[[ALLOC:.*]] = memref.alloca() {acc.var_name = #acc.var_name<"scalar">} : memref<f32> // CHECK: acc.yield %[[ALLOC]] : memref<f32> // CHECK: } copy { // CHECK: ^bb0(%[[SRC:.*]]: memref<f32>, %[[DST:.*]]: memref<f32>): @@ -20,7 +20,7 @@ func.func @test_scalar() { // CHECK: acc.firstprivate.recipe @firstprivate_static_2d : memref<10x20xf32> init { // CHECK: ^bb0(%{{.*}}: memref<10x20xf32>): -// CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<10x20xf32> +// CHECK: %[[ALLOC:.*]] = memref.alloca() {acc.var_name = #acc.var_name<"static_2d">} : memref<10x20xf32> // CHECK: acc.yield %[[ALLOC]] : memref<10x20xf32> // CHECK: } copy { // CHECK: ^bb0(%[[SRC:.*]]: memref<10x20xf32>, %[[DST:.*]]: memref<10x20xf32>): @@ -42,7 +42,7 @@ func.func @test_static_2d() { // CHECK: %[[DIM0:.*]] = memref.dim %[[ARG]], %[[C0]] : memref<?x?xf32> // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[DIM1:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<?x?xf32> -// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) : memref<?x?xf32> +// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) {acc.var_name = #acc.var_name<"dynamic_2d">} : memref<?x?xf32> // CHECK: acc.yield %[[ALLOC]] : memref<?x?xf32> // CHECK: } copy { // CHECK: ^bb0(%[[SRC:.*]]: memref<?x?xf32>, %[[DST:.*]]: memref<?x?xf32>): @@ -65,7 +65,7 @@ func.func @test_dynamic_2d(%arg0: index, %arg1: index) { // CHECK: ^bb0(%[[ARG:.*]]: memref<10x?xf32>): // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[DIM1:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<10x?xf32> -// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM1]]) : memref<10x?xf32> +// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM1]]) {acc.var_name = #acc.var_name<"mixed_dims">} : memref<10x?xf32> // CHECK: acc.yield %[[ALLOC]] : memref<10x?xf32> // CHECK: } copy { // CHECK: ^bb0(%[[SRC:.*]]: memref<10x?xf32>, %[[DST:.*]]: memref<10x?xf32>): @@ -86,7 +86,7 @@ func.func @test_mixed_dims(%arg0: index) { // CHECK: acc.firstprivate.recipe @firstprivate_scalar_int : memref<i32> init { // CHECK: ^bb0(%{{.*}}: memref<i32>): -// CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<i32> +// CHECK: %[[ALLOC:.*]] = memref.alloca() {acc.var_name = #acc.var_name<"scalar_int">} : memref<i32> // CHECK: acc.yield %[[ALLOC]] : memref<i32> // CHECK: } copy { // CHECK: ^bb0(%[[SRC:.*]]: memref<i32>, %[[DST:.*]]: memref<i32>): diff --git a/mlir/test/Dialect/OpenACC/recipe-populate-private.mlir b/mlir/test/Dialect/OpenACC/recipe-populate-private.mlir index 8403ee8..3d5a918 100644 --- a/mlir/test/Dialect/OpenACC/recipe-populate-private.mlir +++ b/mlir/test/Dialect/OpenACC/recipe-populate-private.mlir @@ -2,7 +2,7 @@ // CHECK: acc.private.recipe @private_scalar : memref<f32> init { // CHECK: ^bb0(%{{.*}}: memref<f32>): -// CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<f32> +// CHECK: %[[ALLOC:.*]] = memref.alloca() {acc.var_name = #acc.var_name<"scalar">} : memref<f32> // CHECK: acc.yield %[[ALLOC]] : memref<f32> // CHECK: } // CHECK-NOT: destroy @@ -16,7 +16,7 @@ func.func @test_scalar() { // CHECK: acc.private.recipe @private_static_2d : memref<10x20xf32> init { // CHECK: ^bb0(%{{.*}}: memref<10x20xf32>): -// CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<10x20xf32> +// CHECK: %[[ALLOC:.*]] = memref.alloca() {acc.var_name = #acc.var_name<"static_2d">} : memref<10x20xf32> // CHECK: acc.yield %[[ALLOC]] : memref<10x20xf32> // CHECK: } // CHECK-NOT: destroy @@ -34,7 +34,7 @@ func.func @test_static_2d() { // CHECK: %[[DIM0:.*]] = memref.dim %[[ARG]], %[[C0]] : memref<?x?xf32> // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[DIM1:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<?x?xf32> -// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) : memref<?x?xf32> +// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) {acc.var_name = #acc.var_name<"dynamic_2d">} : memref<?x?xf32> // CHECK: acc.yield %[[ALLOC]] : memref<?x?xf32> // CHECK: } destroy { // CHECK: ^bb0(%{{.*}}: memref<?x?xf32>, %[[VAL:.*]]: memref<?x?xf32>): @@ -53,7 +53,7 @@ func.func @test_dynamic_2d(%arg0: index, %arg1: index) { // CHECK: ^bb0(%[[ARG:.*]]: memref<10x?xf32>): // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[DIM1:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<10x?xf32> -// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM1]]) : memref<10x?xf32> +// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM1]]) {acc.var_name = #acc.var_name<"mixed_dims">} : memref<10x?xf32> // CHECK: acc.yield %[[ALLOC]] : memref<10x?xf32> // CHECK: } destroy { // CHECK: ^bb0(%{{.*}}: memref<10x?xf32>, %[[VAL:.*]]: memref<10x?xf32>): @@ -70,7 +70,7 @@ func.func @test_mixed_dims(%arg0: index) { // CHECK: acc.private.recipe @private_scalar_int : memref<i32> init { // CHECK: ^bb0(%{{.*}}: memref<i32>): -// CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<i32> +// CHECK: %[[ALLOC:.*]] = memref.alloca() {acc.var_name = #acc.var_name<"scalar_int">} : memref<i32> // CHECK: acc.yield %[[ALLOC]] : memref<i32> // CHECK: } // CHECK-NOT: destroy diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir index b6c72be..f66cf7a 100644 --- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir @@ -490,3 +490,32 @@ func.func @collapse_shape_regression( tensor.collapse_shape %0[[0, 1]] : tensor<5x6xf32> into tensor<30xf32> return } + +// ----- + +// CHECK-LABEL: func private @mult_return_callee( +// CHECK-SAME: %[[T:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[COND:.*]]: i1, +// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index) -> index { +// CHECK: cf.cond_br %[[COND]], ^bb1, ^bb2 +// CHECK: ^bb1: +// CHECK: return %[[A]] : index +// CHECK: ^bb2: +// CHECK: return %[[B]] : index +func.func private @mult_return_callee(%t: tensor<?xf32>, %cond:i1, %a: index, %b: index) -> (tensor<10xf32>, index) { + %casted = tensor.cast %t : tensor<?xf32> to tensor<10xf32> + cf.cond_br %cond,^a, ^b +^a: + return %casted, %a : tensor<10xf32>, index +^b: + return %casted, %b : tensor<10xf32>, index +} + +// CHECK-LABEL: func @mult_return( +// CHECK-SAME: %[[T:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[COND:.*]]: i1, +// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index) -> (memref<?xf32, strided<[?], offset: ?>>, index) { +func.func @mult_return(%t: tensor<?xf32>, %cond:i1, %a: index, %b: index) -> (tensor<10xf32>, index) { + // CHECK: %[[RET:.*]] = call @mult_return_callee(%[[T]], %[[COND]], %[[A]], %[[B]]) : (memref<?xf32, strided<[?], offset: ?>>, i1, index, index) -> index + // CHECK: return %[[T]], %[[RET]] : memref<?xf32, strided<[?], offset: ?>>, index + %t_res, %v = func.call @mult_return_callee(%t, %cond, %a, %b) : (tensor<?xf32>, i1, index, index) -> (tensor<10xf32>, index) + return %t_res, %v : tensor<10xf32>, index +} diff --git a/mlir/test/Dialect/Tosa/tosa-attach-target.mlir b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir index d6c886c..a0c59c0 100644 --- a/mlir/test/Dialect/Tosa/tosa-attach-target.mlir +++ b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir @@ -1,12 +1,14 @@ // RUN: mlir-opt %s -split-input-file -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,dynamic level=none" | FileCheck %s --check-prefix=CHECK-ALL // RUN: mlir-opt %s -split-input-file -tosa-attach-target="level=8k" | FileCheck %s --check-prefix=CHECK-LVL-8K // RUN: mlir-opt %s -split-input-file -tosa-attach-target | FileCheck %s --check-prefix=CHECK-DEFAULT +// RUN: mlir-opt %s -split-input-file -tosa-attach-target="specification_version=1.1.draft" | FileCheck %s --check-prefix=CHECK-VERSION-1P1 // ----- -// CHECK-ALL: module attributes {tosa.target_env = #tosa.target_env<level = none, profiles = [pro_int, pro_fp], extensions = [int16, int4, bf16, fp8e4m3, fp8e5m2, fft, variable, controlflow, doubleround, inexactround, dynamic]>} -// CHECK-LVL-8K: module attributes {tosa.target_env = #tosa.target_env<level = "8k", profiles = [], extensions = []>} -// CHECK-DEFAULT: module attributes {tosa.target_env = #tosa.target_env<level = "8k", profiles = [], extensions = []>} +// CHECK-ALL: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.0", level = none, profiles = [pro_int, pro_fp], extensions = [int16, int4, bf16, fp8e4m3, fp8e5m2, fft, variable, controlflow, doubleround, inexactround, dynamic]>} +// CHECK-LVL-8K: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.0", level = "8k", profiles = [], extensions = []>} +// CHECK-DEFAULT: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.0", level = "8k", profiles = [], extensions = []>} +// CHECK-VERSION-1P1: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.1.draft", level = "8k", profiles = [], extensions = []>} // CHECK-LABEL: test_simple func.func @test_simple(%arg0 : tensor<1x1x1x1xf32>, %arg1 : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> { %1 = tosa.add %arg0, %arg1 : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir new file mode 100644 index 0000000..51089df --- /dev/null +++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.0 profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" + +// ----- + +func.func @test_matmul_fp8_mixed_precision_operands(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> { + %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN> + %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2> + // expected-error@+1 {{'tosa.matmul' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}} + %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E5M2>, tensor<1xf8E4M3FN>, tensor<1xf8E5M2>) -> tensor<1x14x28xf16> + return %0 : tensor<1x14x28xf16> +} + +// ----- + +func.func @test_matmul_fp8_input_fp32_acc_type(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf32> { + %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN> + %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN> + // expected-error@+1 {{'tosa.matmul' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}} + %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E4M3FN>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x14x28xf32> + return %0 : tensor<1x14x28xf32> +} diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir new file mode 100644 index 0000000..8164509 --- /dev/null +++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir @@ -0,0 +1,20 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" | FileCheck %s + +// ----- + +func.func @test_matmul_fp8_mixed_precision_operands(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> { + %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN> + %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2> + %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E5M2>, tensor<1xf8E4M3FN>, tensor<1xf8E5M2>) -> tensor<1x14x28xf16> + return %0 : tensor<1x14x28xf16> +} + +// ----- + +// CHECK-LABEL: test_matmul_fp8_input_fp32_acc_type +func.func @test_matmul_fp8_input_fp32_acc_type(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf32> { + %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN> + %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN> + %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E4M3FN>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x14x28xf32> + return %0 : tensor<1x14x28xf32> +} diff --git a/mlir/test/Dialect/WasmSSA/custom_parser/global.mlir b/mlir/test/Dialect/WasmSSA/custom_parser/global.mlir index b9b3420..a25abbd 100644 --- a/mlir/test/Dialect/WasmSSA/custom_parser/global.mlir +++ b/mlir/test/Dialect/WasmSSA/custom_parser/global.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt %s | FileCheck %s module { - wasmssa.import_global "from_js" from "env" as @global_0 nested : i32 + wasmssa.import_global "from_js" from "env" as @global_0 : i32 wasmssa.global @global_1 i32 : { %0 = wasmssa.const 10 : i32 @@ -21,7 +21,7 @@ module { } } -// CHECK-LABEL: wasmssa.import_global "from_js" from "env" as @global_0 nested : i32 +// CHECK-LABEL: wasmssa.import_global "from_js" from "env" as @global_0 : i32 // CHECK-LABEL: wasmssa.global @global_1 i32 : { // CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32 diff --git a/mlir/test/Dialect/WasmSSA/custom_parser/if.mlir b/mlir/test/Dialect/WasmSSA/custom_parser/if.mlir index 01068cb..cee3c69 100644 --- a/mlir/test/Dialect/WasmSSA/custom_parser/if.mlir +++ b/mlir/test/Dialect/WasmSSA/custom_parser/if.mlir @@ -1,6 +1,6 @@ // RUN: mlir-opt %s | FileCheck %s -// CHECK-LABEL: wasmssa.func nested @func_0( +// CHECK-LABEL: wasmssa.func @func_0( // CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) -> i32 { // CHECK: %[[VAL_0:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32 // CHECK: wasmssa.if %[[VAL_0]] : { @@ -12,7 +12,7 @@ // CHECK: }> ^bb1 // CHECK: ^bb1(%[[VAL_3:.*]]: f32): // CHECK: wasmssa.return %[[VAL_3]] : f32 -wasmssa.func nested @func_0(%arg0 : !wasmssa<local ref to i32>) -> i32 { +wasmssa.func @func_0(%arg0 : !wasmssa<local ref to i32>) -> i32 { %cond = wasmssa.local_get %arg0 : ref to i32 wasmssa.if %cond : { %c0 = wasmssa.const 0.5 : f32 @@ -25,7 +25,7 @@ wasmssa.func nested @func_0(%arg0 : !wasmssa<local ref to i32>) -> i32 { wasmssa.return %retVal : f32 } -// CHECK-LABEL: wasmssa.func nested @func_1( +// CHECK-LABEL: wasmssa.func @func_1( // CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) -> i32 { // CHECK: %[[VAL_0:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32 // CHECK: %[[VAL_1:.*]] = wasmssa.local of type i32 @@ -38,7 +38,7 @@ wasmssa.func nested @func_0(%arg0 : !wasmssa<local ref to i32>) -> i32 { // CHECK: ^bb1: // CHECK: %[[VAL_4:.*]] = wasmssa.local_get %[[VAL_1]] : ref to i32 // CHECK: wasmssa.return %[[VAL_4]] : i32 -wasmssa.func nested @func_1(%arg0 : !wasmssa<local ref to i32>) -> i32 { +wasmssa.func @func_1(%arg0 : !wasmssa<local ref to i32>) -> i32 { %cond = wasmssa.local_get %arg0 : ref to i32 %var = wasmssa.local of type i32 %zero = wasmssa.const 0 diff --git a/mlir/test/Dialect/WasmSSA/custom_parser/import.mlir b/mlir/test/Dialect/WasmSSA/custom_parser/import.mlir index 3cc0548..dc23229 100644 --- a/mlir/test/Dialect/WasmSSA/custom_parser/import.mlir +++ b/mlir/test/Dialect/WasmSSA/custom_parser/import.mlir @@ -5,13 +5,13 @@ module { wasmssa.import_func "bar" from "my_module" as @func_1 {sym_visibility = "nested", type = (i32) -> ()} wasmssa.import_table "table" from "my_module" as @table_0 {sym_visibility = "nested", type = !wasmssa<tabletype !wasmssa.funcref [2:]>} wasmssa.import_mem "mem" from "my_module" as @mem_0 {limits = !wasmssa<limit[2:]>, sym_visibility = "nested"} - wasmssa.import_global "glob" from "my_module" as @global_0 nested : i32 - wasmssa.import_global "glob_mut" from "my_other_module" as @global_1 mutable nested : i32 + wasmssa.import_global "glob" from "my_module" as @global_0 : i32 + wasmssa.import_global "glob_mut" from "my_other_module" as @global_1 mutable : i32 } // CHECK-LABEL: wasmssa.import_func "foo" from "my_module" as @func_0 {sym_visibility = "nested", type = (i32) -> ()} // CHECK: wasmssa.import_func "bar" from "my_module" as @func_1 {sym_visibility = "nested", type = (i32) -> ()} // CHECK: wasmssa.import_table "table" from "my_module" as @table_0 {sym_visibility = "nested", type = !wasmssa<tabletype !wasmssa.funcref [2:]>} // CHECK: wasmssa.import_mem "mem" from "my_module" as @mem_0 {limits = !wasmssa<limit[2:]>, sym_visibility = "nested"} -// CHECK: wasmssa.import_global "glob" from "my_module" as @global_0 nested : i32 -// CHECK: wasmssa.import_global "glob_mut" from "my_other_module" as @global_1 mutable nested : i32 +// CHECK: wasmssa.import_global "glob" from "my_module" as @global_0 : i32 +// CHECK: wasmssa.import_global "glob_mut" from "my_other_module" as @global_1 mutable : i32 diff --git a/mlir/test/Dialect/WasmSSA/custom_parser/local.mlir b/mlir/test/Dialect/WasmSSA/custom_parser/local.mlir index 3f6423f..f613ebf 100644 --- a/mlir/test/Dialect/WasmSSA/custom_parser/local.mlir +++ b/mlir/test/Dialect/WasmSSA/custom_parser/local.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt %s | FileCheck %s module { - wasmssa.func nested @func_0() -> f32 { + wasmssa.func @func_0() -> f32 { %0 = wasmssa.local of type f32 %1 = wasmssa.local of type f32 %2 = wasmssa.const 8.000000e+00 : f32 @@ -9,7 +9,7 @@ module { %4 = wasmssa.add %2 %3 : f32 wasmssa.return %4 : f32 } - wasmssa.func nested @func_1() -> i32 { + wasmssa.func @func_1() -> i32 { %0 = wasmssa.local of type i32 %1 = wasmssa.local of type i32 %2 = wasmssa.const 8 : i32 @@ -17,13 +17,13 @@ module { %4 = wasmssa.add %2 %3 : i32 wasmssa.return %4 : i32 } - wasmssa.func nested @func_2(%arg0: !wasmssa<local ref to i32>) -> i32 { + wasmssa.func @func_2(%arg0: !wasmssa<local ref to i32>) -> i32 { %0 = wasmssa.const 3 : i32 wasmssa.return %0 : i32 } } -// CHECK-LABEL: wasmssa.func nested @func_0() -> f32 { +// CHECK-LABEL: wasmssa.func @func_0() -> f32 { // CHECK: %[[VAL_0:.*]] = wasmssa.local of type f32 // CHECK: %[[VAL_1:.*]] = wasmssa.local of type f32 // CHECK: %[[VAL_2:.*]] = wasmssa.const 8.000000e+00 : f32 @@ -31,7 +31,7 @@ module { // CHECK: %[[VAL_4:.*]] = wasmssa.add %[[VAL_2]] %[[VAL_3]] : f32 // CHECK: wasmssa.return %[[VAL_4]] : f32 -// CHECK-LABEL: wasmssa.func nested @func_1() -> i32 { +// CHECK-LABEL: wasmssa.func @func_1() -> i32 { // CHECK: %[[VAL_0:.*]] = wasmssa.local of type i32 // CHECK: %[[VAL_1:.*]] = wasmssa.local of type i32 // CHECK: %[[VAL_2:.*]] = wasmssa.const 8 : i32 @@ -39,7 +39,7 @@ module { // CHECK: %[[VAL_4:.*]] = wasmssa.add %[[VAL_2]] %[[VAL_3]] : i32 // CHECK: wasmssa.return %[[VAL_4]] : i32 -// CHECK-LABEL: wasmssa.func nested @func_2( +// CHECK-LABEL: wasmssa.func @func_2( // CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) -> i32 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 3 : i32 // CHECK: wasmssa.return %[[VAL_0]] : i32 diff --git a/mlir/test/Dialect/WasmSSA/custom_parser/memory.mlir b/mlir/test/Dialect/WasmSSA/custom_parser/memory.mlir index 47551db..ca6ebe0 100644 --- a/mlir/test/Dialect/WasmSSA/custom_parser/memory.mlir +++ b/mlir/test/Dialect/WasmSSA/custom_parser/memory.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt %s | FileCheck %s -// CHECK: wasmssa.memory @mem0 public !wasmssa<limit[0: 65536]> -wasmssa.memory @mem0 public !wasmssa<limit[0:65536]> - -// CHECK: wasmssa.memory @mem1 nested !wasmssa<limit[512:]> +// CHECK: wasmssa.memory @mem1 !wasmssa<limit[512:]> wasmssa.memory @mem1 !wasmssa<limit[512:]> + +// CHECK: wasmssa.memory exported @mem2 !wasmssa<limit[0: 65536]> +wasmssa.memory exported @mem2 !wasmssa<limit[0:65536]> diff --git a/mlir/test/Dialect/WasmSSA/custom_parser/table.mlir b/mlir/test/Dialect/WasmSSA/custom_parser/table.mlir index 5a874f4..ea630de 100644 --- a/mlir/test/Dialect/WasmSSA/custom_parser/table.mlir +++ b/mlir/test/Dialect/WasmSSA/custom_parser/table.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt %s | FileCheck %s -// CHECK: wasmssa.table @tab0 public !wasmssa<tabletype !wasmssa.externref [0: 65536]> -wasmssa.table @tab0 public !wasmssa<tabletype !wasmssa.externref [0:65536]> +// CHECK: wasmssa.table exported @tab0 !wasmssa<tabletype !wasmssa.externref [0: 65536]> +wasmssa.table exported @tab0 !wasmssa<tabletype !wasmssa.externref [0:65536]> -// CHECK: wasmssa.table @tab1 nested !wasmssa<tabletype !wasmssa.funcref [348:]> +// CHECK: wasmssa.table @tab1 !wasmssa<tabletype !wasmssa.funcref [348:]> wasmssa.table @tab1 !wasmssa<tabletype !wasmssa.funcref [348:]> diff --git a/mlir/test/Dialect/WasmSSA/extend-invalid.mlir b/mlir/test/Dialect/WasmSSA/extend-invalid.mlir index 8d78280..7687e5f 100644 --- a/mlir/test/Dialect/WasmSSA/extend-invalid.mlir +++ b/mlir/test/Dialect/WasmSSA/extend-invalid.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt %s -split-input-file -verify-diagnostics -wasmssa.func nested @extend_low_64() -> i32 { +wasmssa.func @extend_low_64() -> i32 { %0 = wasmssa.const 10 : i32 // expected-error@+1 {{extend op can only take 8, 16 or 32 bits. Got 64}} %1 = wasmssa.extend 64 low bits from %0: i32 @@ -10,7 +10,7 @@ wasmssa.func nested @extend_low_64() -> i32 { // ----- -wasmssa.func nested @extend_too_much() -> i32 { +wasmssa.func @extend_too_much() -> i32 { %0 = wasmssa.const 10 : i32 // expected-error@+1 {{trying to extend the 32 low bits from a 'i32' value is illegal}} %1 = wasmssa.extend 32 low bits from %0: i32 diff --git a/mlir/test/Dialect/WasmSSA/global-invalid.mlir b/mlir/test/Dialect/WasmSSA/global-invalid.mlir index b9cafd8..c5bc606 100644 --- a/mlir/test/Dialect/WasmSSA/global-invalid.mlir +++ b/mlir/test/Dialect/WasmSSA/global-invalid.mlir @@ -13,7 +13,7 @@ module { // ----- module { - wasmssa.import_global "glob" from "my_module" as @global_0 mutable nested : i32 + wasmssa.import_global "glob" from "my_module" as @global_0 mutable : i32 wasmssa.global @global_1 i32 : { // expected-error@+1 {{global.get op is considered constant if it's referring to a import.global symbol marked non-mutable}} %0 = wasmssa.global_get @global_0 : i32 @@ -30,3 +30,13 @@ module { wasmssa.return %0 : i32 } } + +// ----- + +module { + // expected-error@+1 {{expecting either `exported` or symbol name. got exproted}} + wasmssa.global exproted @global_1 i32 : { + %0 = wasmssa.const 17 : i32 + wasmssa.return %0 : i32 + } +} diff --git a/mlir/test/Dialect/WasmSSA/locals-invalid.mlir b/mlir/test/Dialect/WasmSSA/locals-invalid.mlir index 35c590b..eaad80e 100644 --- a/mlir/test/Dialect/WasmSSA/locals-invalid.mlir +++ b/mlir/test/Dialect/WasmSSA/locals-invalid.mlir @@ -1,6 +1,6 @@ // RUN: mlir-opt %s -split-input-file -verify-diagnostics -wasmssa.func nested @local_set_err(%arg0: !wasmssa<local ref to i32>) -> i64 { +wasmssa.func @local_set_err(%arg0: !wasmssa<local ref to i32>) -> i64 { %0 = wasmssa.const 3 : i64 // expected-error@+1 {{input type and result type of local.set do not match}} wasmssa.local_set %arg0 : ref to i32 to %0 : i64 @@ -9,7 +9,7 @@ wasmssa.func nested @local_set_err(%arg0: !wasmssa<local ref to i32>) -> i64 { // ----- -wasmssa.func nested @local_tee_err(%arg0: !wasmssa<local ref to i32>) -> i32 { +wasmssa.func @local_tee_err(%arg0: !wasmssa<local ref to i32>) -> i32 { %0 = wasmssa.const 3 : i64 // expected-error@+1 {{input type and output type of local.tee do not match}} %1 = wasmssa.local_tee %arg0 : ref to i32 to %0 : i64 diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index 228ef69d..ebbe3ce 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -858,7 +858,7 @@ func.func @load_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16> // ----- func.func @load_mem_desc_invalid_result_size(%arg0: !xegpu.mem_desc<16x64xf16>) { - // expected-error@+1 {{result shape must not exceed mem_desc shape}} + // expected-error@+1 {{data shape must not exceed mem_desc shape}} %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<32x16xf16> return } @@ -871,6 +871,14 @@ func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) { } // ----- +func.func @load_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>) { + // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}} + %data2 = xegpu.load_matrix %arg0[8, 8] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16> + return +} + + +// ----- func.func @store_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf32>) { // expected-error@+1 {{failed to verify that all of {mem_desc, data} have same element type}} xegpu.store_matrix %arg1, %arg0[8, 8] : vector<16x16xf32>, !xegpu.mem_desc<16x64xf16> @@ -892,30 +900,16 @@ func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: ve } // ----- -func.func @mem_desc_subview_size_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) { - // expected-error@+1 {{result shape must not exceed source shape}} - %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<32x16xf16> - return -} - -// ----- -func.func @mem_desc_subview_layout_mismatch(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride =[1, 16]>>) { - // expected-error@+1 {{result must inherit the source strides}} - %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride =[1, 16]>> -> !xegpu.mem_desc<8x16xf16> - return -} - -// ----- -func.func @mem_desc_subview_element_type_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) { - // expected-error@+1 {{failed to verify that all of {src, res} have same element type}} - %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf32, #xegpu.mem_layout<stride =[64, 1]>> +func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) { + // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}} + xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> return } // ----- -func.func @mem_desc_subview_rank_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) { - // expected-error@+1 {{result rank must not exceed source rank}} - %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<4x8x16xf16> +func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) { + // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}} + xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> return } diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir index bb37902..0a10f68 100644 --- a/mlir/test/Dialect/XeGPU/ops.mlir +++ b/mlir/test/Dialect/XeGPU/ops.mlir @@ -825,53 +825,73 @@ gpu.func @create_mem_desc_with_stride() { gpu.return } -// CHECK: gpu.func @load_mem_desc([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>) -gpu.func @load_mem_desc(%arg0: !xegpu.mem_desc<16x64xf16>) { +// CHECK: gpu.func @load_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>) +gpu.func @load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) { // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16> %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16> gpu.return } -// CHECK: gpu.func @load_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) -gpu.func @load_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) { +// CHECK: gpu.func @load_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) +gpu.func @load_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) { // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8x16xf16> %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8x16xf16> gpu.return } +// CHECK: gpu.func @simt_load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) +gpu.func @simt_load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) { + // CHECK: xegpu.load_matrix [[ARG0]][8, 16] : !xegpu.mem_desc<16x64xf16> -> vector<1xf16> + %data = xegpu.load_matrix %arg0[8, 16]: !xegpu.mem_desc<16x64xf16> -> vector<1xf16> + gpu.return +} + +// CHECK: gpu.func @simt_load_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>) +gpu.func @simt_load_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>) { + // CHECK: xegpu.load_matrix [[ARG0]][8, 16] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>> -> vector<8xf16> + %data = xegpu.load_matrix %arg0[8, 16] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>> -> vector<8xf16> + gpu.return +} + +// CHECK: gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) +gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) { + // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16> + %data = xegpu.load_matrix %arg0[8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16> + gpu.return +} -// CHECK: gpu.func @store_mem_desc([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>, [[ARG1:%.+]]: vector<16x16xf16>) -gpu.func @store_mem_desc(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf16>) { +// CHECK: gpu.func @store_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>, [[ARG1:%.+]]: vector<16x16xf16>) +gpu.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf16>) { // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> xegpu.store_matrix %arg1, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> gpu.return } -// CHECK: gpu.func @store_mem_desc_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, [[ARG1:%.+]]: vector<16x16xf16>) -gpu.func @store_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<16x16xf16>) { +// CHECK: gpu.func @store_matrix_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, [[ARG1:%.+]]: vector<16x16xf16>) +gpu.func @store_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<16x16xf16>) { // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][0, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> xegpu.store_matrix %arg1, %arg0[0, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> gpu.return } -// CHECK: gpu.func @mem_desc_subview([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>) -gpu.func @mem_desc_subview(%arg0: !xegpu.mem_desc<16x64xf16>) { - //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>> - %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>> +// CHECK: gpu.func @simt_store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<1xf16>) { +gpu.func @simt_store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<1xf16>) { + // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 16] : vector<1xf16>, !xegpu.mem_desc<16x64xf16> + xegpu.store_matrix %arg1, %arg0[8, 16]: vector<1xf16>, !xegpu.mem_desc<16x64xf16> gpu.return } -// CHECK: gpu.func @mem_desc_subview_lower_rank([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>) -gpu.func @mem_desc_subview_lower_rank(%arg0: !xegpu.mem_desc<16x64xf16>) { - //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<16xf16, #xegpu.mem_layout<stride = [64, 1]>> - %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<16xf16, #xegpu.mem_layout<stride = [64, 1]>> +// CHECK: gpu.func @simt_store_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>, %arg1: vector<8xf16>) +gpu.func @simt_store_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>, %arg1: vector<8xf16>) { + // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 16] <{subgroup_block_io}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>> + xegpu.store_matrix %arg1, %arg0[8, 16] <{subgroup_block_io}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>> gpu.return } -// CHECK: gpu.func @mem_desc_subview_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) -gpu.func @mem_desc_subview_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) { - //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [1, 16]>> - %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [1, 16]>> +// CHECK: gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<8xf16>) { +gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<8xf16>) { + // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] : vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> + xegpu.store_matrix %arg1, %arg0[8, 8] : vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> gpu.return } diff --git a/mlir/test/Target/LLVMIR/Import/function-attributes.ll b/mlir/test/Target/LLVMIR/Import/function-attributes.ll index cc3d799..00d09ba 100644 --- a/mlir/test/Target/LLVMIR/Import/function-attributes.ll +++ b/mlir/test/Target/LLVMIR/Import/function-attributes.ll @@ -393,6 +393,12 @@ declare void @alwaysinline_attribute() alwaysinline // ----- +; CHECK-LABEL: @inlinehint_attribute +; CHECK-SAME: attributes {inline_hint} +declare void @inlinehint_attribute() inlinehint + +// ----- + ; CHECK-LABEL: @optnone_attribute ; CHECK-SAME: attributes {no_inline, optimize_none} declare void @optnone_attribute() noinline optnone diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir index 69814f2..cc243c8 100644 --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -2555,6 +2555,17 @@ llvm.func @always_inline() attributes { always_inline } { // ----- +// CHECK-LABEL: @inline_hint +// CHECK-SAME: #[[ATTRS:[0-9]+]] +llvm.func @inline_hint() attributes { inline_hint } { + llvm.return +} + +// CHECK: #[[ATTRS]] +// CHECK-SAME: inlinehint + +// ----- + // CHECK-LABEL: @optimize_none // CHECK-SAME: #[[ATTRS:[0-9]+]] llvm.func @optimize_none() attributes { no_inline, optimize_none } { diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir index b2fe2f5..6cccfe4 100644 --- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir @@ -568,6 +568,18 @@ llvm.func @clusterlaunchcontrol_query_cancel_get_first_cta_id_invalid_return_typ llvm.return } +// ----- + +// Test that ensures invalid row/col layouts for matrices A and B are not accepted +llvm.func @nvvm_mma_m16n8k32_s4_s4(%a0 : i32, %a1 : i32, %b0 : i32, %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32,i32,i32,i32)> { + // expected-error@+1 {{Only m8n8k4 with f16 supports other layouts.}} + %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3] + {layoutA = #nvvm.mma_layout<col>, layoutB = #nvvm.mma_layout<col>, + multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, + intOverflowBehavior=#nvvm.mma_int_overflow<satfinite>, + shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> + llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> +} // ----- diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir index fdd2c91..6536fac 100644 --- a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir @@ -276,6 +276,20 @@ llvm.func @rocdl.s.wait.expcnt() { llvm.return } +llvm.func @rocdl.s.wait.asynccnt() { + // CHECK-LABEL: rocdl.s.wait.asynccnt + // CHECK-NEXT: call void @llvm.amdgcn.s.wait.asynccnt(i16 0) + rocdl.s.wait.asynccnt 0 + llvm.return +} + +llvm.func @rocdl.s.wait.tensorcnt() { + // CHECK-LABEL: rocdl.s.wait.tensorcnt + // CHECK-NEXT: call void @llvm.amdgcn.s.wait.tensorcnt(i16 0) + rocdl.s.wait.tensorcnt 0 + llvm.return +} + llvm.func @rocdl.setprio() { // CHECK: call void @llvm.amdgcn.s.setprio(i16 0) rocdl.s.setprio 0 diff --git a/mlir/test/Target/Wasm/abs.mlir b/mlir/test/Target/Wasm/abs.mlir index 9c45ba7..fe3602a 100644 --- a/mlir/test/Target/Wasm/abs.mlir +++ b/mlir/test/Target/Wasm/abs.mlir @@ -12,12 +12,12 @@ ) */ -// CHECK-LABEL: wasmssa.func @abs_f32() -> f32 { +// CHECK-LABEL: wasmssa.func exported @abs_f32() -> f32 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f32 // CHECK: %[[VAL_1:.*]] = wasmssa.abs %[[VAL_0]] : f32 // CHECK: wasmssa.return %[[VAL_1]] : f32 -// CHECK-LABEL: wasmssa.func @abs_f64() -> f64 { +// CHECK-LABEL: wasmssa.func exported @abs_f64() -> f64 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f64 // CHECK: %[[VAL_1:.*]] = wasmssa.abs %[[VAL_0]] : f64 // CHECK: wasmssa.return %[[VAL_1]] : f64 diff --git a/mlir/test/Target/Wasm/add_div.mlir b/mlir/test/Target/Wasm/add_div.mlir new file mode 100644 index 0000000..8a87c60 --- /dev/null +++ b/mlir/test/Target/Wasm/add_div.mlir @@ -0,0 +1,40 @@ +// RUN: yaml2obj %S/inputs/add_div.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s + +/* Source code used to create this test: + (module $test.wasm + (type (;0;) (func (param i32) (result i32))) + (type (;1;) (func (param i32 i32) (result i32))) + (import "env" "twoTimes" (func $twoTimes (type 0))) + (func $add (type 1) (param i32 i32) (result i32) + local.get 0 + call $twoTimes + local.get 1 + call $twoTimes + i32.add + i32.const 2 + i32.div_s) + (memory (;0;) 2) + (global $__stack_pointer (mut i32) (i32.const 66560)) + (export "memory" (memory 0)) + (export "add" (func $add))) +*/ + +// CHECK-LABEL: wasmssa.import_func "twoTimes" from "env" as @func_0 {type = (i32) -> i32} + +// CHECK-LABEL: wasmssa.func exported @add( +// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>, +// CHECK-SAME: %[[ARG1:.*]]: !wasmssa<local ref to i32>) -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32 +// CHECK: %[[VAL_1:.*]] = wasmssa.call @func_0(%[[VAL_0]]) : (i32) -> i32 +// CHECK: %[[VAL_2:.*]] = wasmssa.local_get %[[ARG1]] : ref to i32 +// CHECK: %[[VAL_3:.*]] = wasmssa.call @func_0(%[[VAL_2]]) : (i32) -> i32 +// CHECK: %[[VAL_4:.*]] = wasmssa.add %[[VAL_1]] %[[VAL_3]] : i32 +// CHECK: %[[VAL_5:.*]] = wasmssa.const 2 : i32 +// CHECK: %[[VAL_6:.*]] = wasmssa.div_si %[[VAL_4]] %[[VAL_5]] : i32 +// CHECK: wasmssa.return %[[VAL_6]] : i32 +// CHECK: } +// CHECK: wasmssa.memory exported @memory !wasmssa<limit[2:]> + +// CHECK-LABEL: wasmssa.global @global_0 i32 mutable : { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 66560 : i32 +// CHECK: wasmssa.return %[[VAL_0]] : i32 diff --git a/mlir/test/Target/Wasm/and.mlir b/mlir/test/Target/Wasm/and.mlir index 4c0fea0..323d41a 100644 --- a/mlir/test/Target/Wasm/and.mlir +++ b/mlir/test/Target/Wasm/and.mlir @@ -14,13 +14,13 @@ ) */ -// CHECK-LABEL: wasmssa.func @and_i32() -> i32 { +// CHECK-LABEL: wasmssa.func exported @and_i32() -> i32 { // CHECK: %0 = wasmssa.const 10 : i32 // CHECK: %1 = wasmssa.const 3 : i32 // CHECK: %2 = wasmssa.and %0 %1 : i32 // CHECK: wasmssa.return %2 : i32 -// CHECK-LABEL: wasmssa.func @and_i64() -> i64 { +// CHECK-LABEL: wasmssa.func exported @and_i64() -> i64 { // CHECK: %0 = wasmssa.const 10 : i64 // CHECK: %1 = wasmssa.const 3 : i64 // CHECK: %2 = wasmssa.and %0 %1 : i64 diff --git a/mlir/test/Target/Wasm/block.mlir b/mlir/test/Target/Wasm/block.mlir new file mode 100644 index 0000000..c85fc1e --- /dev/null +++ b/mlir/test/Target/Wasm/block.mlir @@ -0,0 +1,16 @@ +// RUN: yaml2obj %S/inputs/block.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s + +/* Source code used to create this test: +(module +(func(export "i_am_a_block") +(block $i_am_a_block) +) +) +*/ + +// CHECK-LABEL: wasmssa.func exported @i_am_a_block() { +// CHECK: wasmssa.block : { +// CHECK: wasmssa.block_return +// CHECK: }> ^bb1 +// CHECK: ^bb1: +// CHECK: wasmssa.return diff --git a/mlir/test/Target/Wasm/block_complete_type.mlir b/mlir/test/Target/Wasm/block_complete_type.mlir new file mode 100644 index 0000000..67df198 --- /dev/null +++ b/mlir/test/Target/Wasm/block_complete_type.mlir @@ -0,0 +1,24 @@ +// RUN: yaml2obj %S/inputs/block_complete_type.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s + +/* Source code used to create this test: +(module + (type (;0;) (func (param i32) (result i32))) + (type (;1;) (func (result i32))) + (func (;0;) (type 1) (result i32) + i32.const 14 + block (param i32) (result i32) ;; label = @1 + i32.const 1 + i32.add + end)) +*/ + +// CHECK-LABEL: wasmssa.func @func_0() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 14 : i32 +// CHECK: wasmssa.block(%[[VAL_0]]) : i32 : { +// CHECK: ^bb0(%[[VAL_1:.*]]: i32): +// CHECK: %[[VAL_2:.*]] = wasmssa.const 1 : i32 +// CHECK: %[[VAL_3:.*]] = wasmssa.add %[[VAL_1]] %[[VAL_2]] : i32 +// CHECK: wasmssa.block_return %[[VAL_3]] : i32 +// CHECK: }> ^bb1 +// CHECK: ^bb1(%[[VAL_4:.*]]: i32): +// CHECK: wasmssa.return %[[VAL_4]] : i32 diff --git a/mlir/test/Target/Wasm/block_value_type.mlir b/mlir/test/Target/Wasm/block_value_type.mlir new file mode 100644 index 0000000..fa30f08 --- /dev/null +++ b/mlir/test/Target/Wasm/block_value_type.mlir @@ -0,0 +1,19 @@ +// RUN: yaml2obj %S/inputs/block_value_type.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s + +/* Source code used to create this test: +(module + (type (;0;) (func (result i32))) + (func (;0;) (type 0) (result i32) + block (result i32) ;; label = @1 + i32.const 17 + end)) +*/ + + +// CHECK-LABEL: wasmssa.func @func_0() -> i32 { +// CHECK: wasmssa.block : { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 17 : i32 +// CHECK: wasmssa.block_return %[[VAL_0]] : i32 +// CHECK: }> ^bb1 +// CHECK: ^bb1(%[[VAL_1:.*]]: i32): +// CHECK: wasmssa.return %[[VAL_1]] : i32 diff --git a/mlir/test/Target/Wasm/branch_if.mlir b/mlir/test/Target/Wasm/branch_if.mlir new file mode 100644 index 0000000..c91ff37 --- /dev/null +++ b/mlir/test/Target/Wasm/branch_if.mlir @@ -0,0 +1,29 @@ +// RUN: yaml2obj %S/inputs/branch_if.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s + +/* Source code used to create this test: +(module + (type $produce_i32 (func (result i32))) + (func (type $produce_i32) + (block $my_block (type $produce_i32) + i32.const 1 + i32.const 2 + br_if $my_block + i32.const 1 + i32.add + ) + ) +) +*/ + +// CHECK-LABEL: wasmssa.func @func_0() -> i32 { +// CHECK: wasmssa.block : { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 1 : i32 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 2 : i32 +// CHECK: wasmssa.branch_if %[[VAL_1]] to level 0 with args(%[[VAL_0]] : i32) else ^bb1 +// CHECK: ^bb1: +// CHECK: %[[VAL_2:.*]] = wasmssa.const 1 : i32 +// CHECK: %[[VAL_3:.*]] = wasmssa.add %[[VAL_0]] %[[VAL_2]] : i32 +// CHECK: wasmssa.block_return %[[VAL_3]] : i32 +// CHECK: }> ^bb1 +// CHECK: ^bb1(%[[VAL_4:.*]]: i32): +// CHECK: wasmssa.return %[[VAL_4]] : i32 diff --git a/mlir/test/Target/Wasm/call.mlir b/mlir/test/Target/Wasm/call.mlir new file mode 100644 index 0000000..c0169aa --- /dev/null +++ b/mlir/test/Target/Wasm/call.mlir @@ -0,0 +1,17 @@ +// RUN: yaml2obj %S/inputs/call.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s + +/* Source code used to create this test: +(module +(func $forty_two (result i32) +i32.const 42) +(func(export "forty_two")(result i32) +call $forty_two)) +*/ + +// CHECK-LABEL: wasmssa.func @func_0() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 42 : i32 +// CHECK: wasmssa.return %[[VAL_0]] : i32 + +// CHECK-LABEL: wasmssa.func exported @forty_two() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.call @func_0 : () -> i32 +// CHECK: wasmssa.return %[[VAL_0]] : i32 diff --git a/mlir/test/Target/Wasm/clz.mlir b/mlir/test/Target/Wasm/clz.mlir index 3e6641d..858c09d 100644 --- a/mlir/test/Target/Wasm/clz.mlir +++ b/mlir/test/Target/Wasm/clz.mlir @@ -14,12 +14,12 @@ ) */ -// CHECK-LABEL: wasmssa.func @clz_i32() -> i32 { +// CHECK-LABEL: wasmssa.func exported @clz_i32() -> i32 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32 // CHECK: %[[VAL_1:.*]] = wasmssa.clz %[[VAL_0]] : i32 // CHECK: wasmssa.return %[[VAL_1]] : i32 -// CHECK-LABEL: wasmssa.func @clz_i64() -> i64 { +// CHECK-LABEL: wasmssa.func exported @clz_i64() -> i64 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64 // CHECK: %[[VAL_1:.*]] = wasmssa.clz %[[VAL_0]] : i64 // CHECK: wasmssa.return %[[VAL_1]] : i64 diff --git a/mlir/test/Target/Wasm/comparison_ops.mlir b/mlir/test/Target/Wasm/comparison_ops.mlir new file mode 100644 index 0000000..91e3a6a --- /dev/null +++ b/mlir/test/Target/Wasm/comparison_ops.mlir @@ -0,0 +1,269 @@ +// RUN: yaml2obj %S/inputs/comparison_ops.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s +/* Source code used to create this test: +(module + (func $lt_si32 (result i32) + i32.const 12 + i32.const 50 + i32.lt_s + ) + (func $le_si32 (result i32) + i32.const 12 + i32.const 50 + i32.le_s + ) + (func $lt_ui32 (result i32) + i32.const 12 + i32.const 50 + i32.lt_u + ) + (func $le_ui32 (result i32) + i32.const 12 + i32.const 50 + i32.le_u + ) + (func $gt_si32 (result i32) + i32.const 12 + i32.const 50 + i32.gt_s + ) + (func $gt_ui32 (result i32) + i32.const 12 + i32.const 50 + i32.gt_u + ) + (func $ge_si32 (result i32) + i32.const 12 + i32.const 50 + i32.ge_s + ) + (func $ge_ui32 (result i32) + i32.const 12 + i32.const 50 + i32.ge_u + ) + (func $lt_si64 (result i32) + i64.const 12 + i64.const 50 + i64.lt_s + ) + (func $le_si64 (result i32) + i64.const 12 + i64.const 50 + i64.le_s + ) + (func $lt_ui64 (result i32) + i64.const 12 + i64.const 50 + i64.lt_u + ) + (func $le_ui64 (result i32) + i64.const 12 + i64.const 50 + i64.le_u + ) + (func $gt_si64 (result i32) + i64.const 12 + i64.const 50 + i64.gt_s + ) + (func $gt_ui64 (result i32) + i64.const 12 + i64.const 50 + i64.gt_u + ) + (func $ge_si64 (result i32) + i64.const 12 + i64.const 50 + i64.ge_s + ) + (func $ge_ui64 (result i32) + i64.const 12 + i64.const 50 + i64.ge_u + ) + (func $lt_f32 (result i32) + f32.const 5 + f32.const 14 + f32.lt + ) + (func $le_f32 (result i32) + f32.const 5 + f32.const 14 + f32.le + ) + (func $gt_f32 (result i32) + f32.const 5 + f32.const 14 + f32.gt + ) + (func $ge_f32 (result i32) + f32.const 5 + f32.const 14 + f32.ge + ) + (func $lt_f64 (result i32) + f64.const 5 + f64.const 14 + f64.lt + ) + (func $le_f64 (result i32) + f64.const 5 + f64.const 14 + f64.le + ) + (func $gt_f64 (result i32) + f64.const 5 + f64.const 14 + f64.gt + ) + (func $ge_f64 (result i32) + f64.const 5 + f64.const 14 + f64.ge + ) +) +*/ + +// CHECK-LABEL: wasmssa.func @func_0() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32 +// CHECK: %[[VAL_2:.*]] = wasmssa.lt_si %[[VAL_0]] %[[VAL_1]] : i32 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 + +// CHECK-LABEL: wasmssa.func @func_1() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32 +// CHECK: %[[VAL_2:.*]] = wasmssa.le_si %[[VAL_0]] %[[VAL_1]] : i32 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 + +// CHECK-LABEL: wasmssa.func @func_2() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32 +// CHECK: %[[VAL_2:.*]] = wasmssa.lt_ui %[[VAL_0]] %[[VAL_1]] : i32 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 + +// CHECK-LABEL: wasmssa.func @func_3() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32 +// CHECK: %[[VAL_2:.*]] = wasmssa.le_ui %[[VAL_0]] %[[VAL_1]] : i32 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 + +// CHECK-LABEL: wasmssa.func @func_4() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32 +// CHECK: %[[VAL_2:.*]] = wasmssa.gt_si %[[VAL_0]] %[[VAL_1]] : i32 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 + +// CHECK-LABEL: wasmssa.func @func_5() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32 +// CHECK: %[[VAL_2:.*]] = wasmssa.gt_ui %[[VAL_0]] %[[VAL_1]] : i32 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 + +// CHECK-LABEL: wasmssa.func @func_6() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32 +// CHECK: %[[VAL_2:.*]] = wasmssa.ge_si %[[VAL_0]] %[[VAL_1]] : i32 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 + +// CHECK-LABEL: wasmssa.func @func_7() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32 +// CHECK: %[[VAL_2:.*]] = wasmssa.ge_ui %[[VAL_0]] %[[VAL_1]] : i32 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 + +// CHECK-LABEL: wasmssa.func @func_8() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i64 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i64 +// CHECK: %[[VAL_2:.*]] = wasmssa.lt_si %[[VAL_0]] %[[VAL_1]] : i64 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 + +// CHECK-LABEL: wasmssa.func @func_9() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i64 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i64 +// CHECK: %[[VAL_2:.*]] = wasmssa.le_si %[[VAL_0]] %[[VAL_1]] : i64 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 + +// CHECK-LABEL: wasmssa.func @func_10() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i64 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i64 +// CHECK: %[[VAL_2:.*]] = wasmssa.lt_ui %[[VAL_0]] %[[VAL_1]] : i64 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 + +// CHECK-LABEL: wasmssa.func @func_11() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i64 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i64 +// CHECK: %[[VAL_2:.*]] = wasmssa.le_ui %[[VAL_0]] %[[VAL_1]] : i64 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 + +// CHECK-LABEL: wasmssa.func @func_12() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i64 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i64 +// CHECK: %[[VAL_2:.*]] = wasmssa.gt_si %[[VAL_0]] %[[VAL_1]] : i64 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 + +// CHECK-LABEL: wasmssa.func @func_13() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i64 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i64 +// CHECK: %[[VAL_2:.*]] = wasmssa.gt_ui %[[VAL_0]] %[[VAL_1]] : i64 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 + +// CHECK-LABEL: wasmssa.func @func_14() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i64 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i64 +// CHECK: %[[VAL_2:.*]] = wasmssa.ge_si %[[VAL_0]] %[[VAL_1]] : i64 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 + +// CHECK-LABEL: wasmssa.func @func_15() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i64 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i64 +// CHECK: %[[VAL_2:.*]] = wasmssa.ge_ui %[[VAL_0]] %[[VAL_1]] : i64 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 + +// CHECK-LABEL: wasmssa.func @func_16() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f32 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f32 +// CHECK: %[[VAL_2:.*]] = wasmssa.lt %[[VAL_0]] %[[VAL_1]] : f32 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 + +// CHECK-LABEL: wasmssa.func @func_17() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f32 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f32 +// CHECK: %[[VAL_2:.*]] = wasmssa.le %[[VAL_0]] %[[VAL_1]] : f32 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 + +// CHECK-LABEL: wasmssa.func @func_18() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f32 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f32 +// CHECK: %[[VAL_2:.*]] = wasmssa.gt %[[VAL_0]] %[[VAL_1]] : f32 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 + +// CHECK-LABEL: wasmssa.func @func_19() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f32 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f32 +// CHECK: %[[VAL_2:.*]] = wasmssa.ge %[[VAL_0]] %[[VAL_1]] : f32 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 + +// CHECK-LABEL: wasmssa.func @func_20() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f64 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f64 +// CHECK: %[[VAL_2:.*]] = wasmssa.lt %[[VAL_0]] %[[VAL_1]] : f64 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 + +// CHECK-LABEL: wasmssa.func @func_21() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f64 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f64 +// CHECK: %[[VAL_2:.*]] = wasmssa.le %[[VAL_0]] %[[VAL_1]] : f64 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 + +// CHECK-LABEL: wasmssa.func @func_22() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f64 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f64 +// CHECK: %[[VAL_2:.*]] = wasmssa.gt %[[VAL_0]] %[[VAL_1]] : f64 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 + +// CHECK-LABEL: wasmssa.func @func_23() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f64 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f64 +// CHECK: %[[VAL_2:.*]] = wasmssa.ge %[[VAL_0]] %[[VAL_1]] : f64 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 diff --git a/mlir/test/Target/Wasm/const.mlir b/mlir/test/Target/Wasm/const.mlir index aa9e76f..adb792a 100644 --- a/mlir/test/Target/Wasm/const.mlir +++ b/mlir/test/Target/Wasm/const.mlir @@ -16,22 +16,22 @@ ) */ -// CHECK-LABEL: wasmssa.func nested @func_0() -> i32 { +// CHECK-LABEL: wasmssa.func @func_0() -> i32 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 1 : i32 // CHECK: wasmssa.return %[[VAL_0]] : i32 // CHECK: } -// CHECK-LABEL: wasmssa.func nested @func_1() -> i64 { +// CHECK-LABEL: wasmssa.func @func_1() -> i64 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 3 : i64 // CHECK: wasmssa.return %[[VAL_0]] : i64 // CHECK: } -// CHECK-LABEL: wasmssa.func nested @func_2() -> f32 { +// CHECK-LABEL: wasmssa.func @func_2() -> f32 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 4.000000e+00 : f32 // CHECK: wasmssa.return %[[VAL_0]] : f32 // CHECK: } -// CHECK-LABEL: wasmssa.func nested @func_3() -> f64 { +// CHECK-LABEL: wasmssa.func @func_3() -> f64 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 9.000000e+00 : f64 // CHECK: wasmssa.return %[[VAL_0]] : f64 // CHECK: } diff --git a/mlir/test/Target/Wasm/convert.mlir b/mlir/test/Target/Wasm/convert.mlir new file mode 100644 index 0000000..ddc29a7 --- /dev/null +++ b/mlir/test/Target/Wasm/convert.mlir @@ -0,0 +1,85 @@ +// RUN: yaml2obj %S/inputs/convert.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s + +/* Source code used to generate this test: +(module + (func (export "convert_i32_u_to_f32") (result f32) + i32.const 10 + f32.convert_i32_u + ) + + (func (export "convert_i32_s_to_f32") (result f32) + i32.const 42 + f32.convert_i32_s + ) + + (func (export "convert_i64_u_to_f32") (result f32) + i64.const 17 + f32.convert_i64_u + ) + + (func (export "convert_i64s_to_f32") (result f32) + i64.const 10 + f32.convert_i64_s + ) + + (func (export "convert_i32_u_to_f64") (result f64) + i32.const 10 + f64.convert_i32_u + ) + + (func (export "convert_i32_s_to_f64") (result f64) + i32.const 42 + f64.convert_i32_s + ) + + (func (export "convert_i64_u_to_f64") (result f64) + i64.const 17 + f64.convert_i64_u + ) + + (func (export "convert_i64s_to_f64") (result f64) + i64.const 10 + f64.convert_i64_s + ) +) +*/ + +// CHECK-LABEL: wasmssa.func exported @convert_i32_u_to_f32() -> f32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32 +// CHECK: %[[VAL_1:.*]] = wasmssa.convert_u %[[VAL_0]] : i32 to f32 +// CHECK: wasmssa.return %[[VAL_1]] : f32 + +// CHECK-LABEL: wasmssa.func exported @convert_i32_s_to_f32() -> f32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 42 : i32 +// CHECK: %[[VAL_1:.*]] = wasmssa.convert_s %[[VAL_0]] : i32 to f32 +// CHECK: wasmssa.return %[[VAL_1]] : f32 + +// CHECK-LABEL: wasmssa.func exported @convert_i64_u_to_f32() -> f32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 17 : i64 +// CHECK: %[[VAL_1:.*]] = wasmssa.convert_u %[[VAL_0]] : i64 to f32 +// CHECK: wasmssa.return %[[VAL_1]] : f32 + +// CHECK-LABEL: wasmssa.func exported @convert_i64s_to_f32() -> f32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64 +// CHECK: %[[VAL_1:.*]] = wasmssa.convert_s %[[VAL_0]] : i64 to f32 +// CHECK: wasmssa.return %[[VAL_1]] : f32 + +// CHECK-LABEL: wasmssa.func exported @convert_i32_u_to_f64() -> f64 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32 +// CHECK: %[[VAL_1:.*]] = wasmssa.convert_u %[[VAL_0]] : i32 to f64 +// CHECK: wasmssa.return %[[VAL_1]] : f64 + +// CHECK-LABEL: wasmssa.func exported @convert_i32_s_to_f64() -> f64 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 42 : i32 +// CHECK: %[[VAL_1:.*]] = wasmssa.convert_s %[[VAL_0]] : i32 to f64 +// CHECK: wasmssa.return %[[VAL_1]] : f64 + +// CHECK-LABEL: wasmssa.func exported @convert_i64_u_to_f64() -> f64 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 17 : i64 +// CHECK: %[[VAL_1:.*]] = wasmssa.convert_u %[[VAL_0]] : i64 to f64 +// CHECK: wasmssa.return %[[VAL_1]] : f64 + +// CHECK-LABEL: wasmssa.func exported @convert_i64s_to_f64() -> f64 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64 +// CHECK: %[[VAL_1:.*]] = wasmssa.convert_s %[[VAL_0]] : i64 to f64 +// CHECK: wasmssa.return %[[VAL_1]] : f64 diff --git a/mlir/test/Target/Wasm/copysign.mlir b/mlir/test/Target/Wasm/copysign.mlir index 33d7a56..90c5b11 100644 --- a/mlir/test/Target/Wasm/copysign.mlir +++ b/mlir/test/Target/Wasm/copysign.mlir @@ -16,14 +16,14 @@ ) */ -// CHECK-LABEL: wasmssa.func @copysign_f32() -> f32 { +// CHECK-LABEL: wasmssa.func exported @copysign_f32() -> f32 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f32 // CHECK: %[[VAL_1:.*]] = wasmssa.const 1.000000e+00 : f32 // CHECK: %[[VAL_2:.*]] = wasmssa.copysign %[[VAL_0]] %[[VAL_1]] : f32 // CHECK: wasmssa.return %[[VAL_2]] : f32 // CHECK: } -// CHECK-LABEL: wasmssa.func @copysign_f64() -> f64 { +// CHECK-LABEL: wasmssa.func exported @copysign_f64() -> f64 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f64 // CHECK: %[[VAL_1:.*]] = wasmssa.const 1.000000e+00 : f64 // CHECK: %[[VAL_2:.*]] = wasmssa.copysign %[[VAL_0]] %[[VAL_1]] : f64 diff --git a/mlir/test/Target/Wasm/ctz.mlir b/mlir/test/Target/Wasm/ctz.mlir index 6c0806f..9e7cc5e 100644 --- a/mlir/test/Target/Wasm/ctz.mlir +++ b/mlir/test/Target/Wasm/ctz.mlir @@ -14,12 +14,12 @@ ) */ -// CHECK-LABEL: wasmssa.func @ctz_i32() -> i32 { +// CHECK-LABEL: wasmssa.func exported @ctz_i32() -> i32 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32 // CHECK: %[[VAL_1:.*]] = wasmssa.ctz %[[VAL_0]] : i32 // CHECK: wasmssa.return %[[VAL_1]] : i32 -// CHECK-LABEL: wasmssa.func @ctz_i64() -> i64 { +// CHECK-LABEL: wasmssa.func exported @ctz_i64() -> i64 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64 // CHECK: %[[VAL_1:.*]] = wasmssa.ctz %[[VAL_0]] : i64 // CHECK: wasmssa.return %[[VAL_1]] : i64 diff --git a/mlir/test/Target/Wasm/demote.mlir b/mlir/test/Target/Wasm/demote.mlir new file mode 100644 index 0000000..3d2bc05 --- /dev/null +++ b/mlir/test/Target/Wasm/demote.mlir @@ -0,0 +1,15 @@ +// RUN: yaml2obj %S/inputs/demote.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s + +/* Source code used to create this test: +(module + (func $main (result f32) + f64.const 2.24 + f32.demote_f64 + ) +) +*/ + +// CHECK-LABEL: wasmssa.func @func_0() -> f32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 2.240000e+00 : f64 +// CHECK: %[[VAL_1:.*]] = wasmssa.demote %[[VAL_0]] : f64 to f32 +// CHECK: wasmssa.return %[[VAL_1]] : f32 diff --git a/mlir/test/Target/Wasm/div.mlir b/mlir/test/Target/Wasm/div.mlir index c91f780..4967d96 100644 --- a/mlir/test/Target/Wasm/div.mlir +++ b/mlir/test/Target/Wasm/div.mlir @@ -66,61 +66,61 @@ ) */ -// CHECK-LABEL: wasmssa.func @div_u_i32() -> i32 { +// CHECK-LABEL: wasmssa.func exported @div_u_i32() -> i32 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32 // CHECK: %[[VAL_1:.*]] = wasmssa.const 2 : i32 // CHECK: %[[VAL_2:.*]] = wasmssa.div_ui %[[VAL_0]] %[[VAL_1]] : i32 // CHECK: wasmssa.return %[[VAL_2]] : i32 -// CHECK-LABEL: wasmssa.func @div_u_i32_zero() -> i32 { +// CHECK-LABEL: wasmssa.func exported @div_u_i32_zero() -> i32 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32 // CHECK: %[[VAL_1:.*]] = wasmssa.const 0 : i32 // CHECK: %[[VAL_2:.*]] = wasmssa.div_ui %[[VAL_0]] %[[VAL_1]] : i32 // CHECK: wasmssa.return %[[VAL_2]] : i32 -// CHECK-LABEL: wasmssa.func @div_s_i32() -> i32 { +// CHECK-LABEL: wasmssa.func exported @div_s_i32() -> i32 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32 // CHECK: %[[VAL_1:.*]] = wasmssa.const 2 : i32 // CHECK: %[[VAL_2:.*]] = wasmssa.div_si %[[VAL_0]] %[[VAL_1]] : i32 // CHECK: wasmssa.return %[[VAL_2]] : i32 -// CHECK-LABEL: wasmssa.func @div_s_i32_zero() -> i32 { +// CHECK-LABEL: wasmssa.func exported @div_s_i32_zero() -> i32 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32 // CHECK: %[[VAL_1:.*]] = wasmssa.const 0 : i32 // CHECK: %[[VAL_2:.*]] = wasmssa.div_si %[[VAL_0]] %[[VAL_1]] : i32 // CHECK: wasmssa.return %[[VAL_2]] : i32 -// CHECK-LABEL: wasmssa.func @div_u_i64() -> i64 { +// CHECK-LABEL: wasmssa.func exported @div_u_i64() -> i64 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64 // CHECK: %[[VAL_1:.*]] = wasmssa.const 2 : i64 // CHECK: %[[VAL_2:.*]] = wasmssa.div_ui %[[VAL_0]] %[[VAL_1]] : i64 // CHECK: wasmssa.return %[[VAL_2]] : i64 -// CHECK-LABEL: wasmssa.func @div_u_i64_zero() -> i64 { +// CHECK-LABEL: wasmssa.func exported @div_u_i64_zero() -> i64 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64 // CHECK: %[[VAL_1:.*]] = wasmssa.const 0 : i64 // CHECK: %[[VAL_2:.*]] = wasmssa.div_ui %[[VAL_0]] %[[VAL_1]] : i64 // CHECK: wasmssa.return %[[VAL_2]] : i64 -// CHECK-LABEL: wasmssa.func @div_s_i64() -> i64 { +// CHECK-LABEL: wasmssa.func exported @div_s_i64() -> i64 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64 // CHECK: %[[VAL_1:.*]] = wasmssa.const 2 : i64 // CHECK: %[[VAL_2:.*]] = wasmssa.div_si %[[VAL_0]] %[[VAL_1]] : i64 // CHECK: wasmssa.return %[[VAL_2]] : i64 -// CHECK-LABEL: wasmssa.func @div_s_i64_zero() -> i64 { +// CHECK-LABEL: wasmssa.func exported @div_s_i64_zero() -> i64 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64 // CHECK: %[[VAL_1:.*]] = wasmssa.const 0 : i64 // CHECK: %[[VAL_2:.*]] = wasmssa.div_si %[[VAL_0]] %[[VAL_1]] : i64 // CHECK: wasmssa.return %[[VAL_2]] : i64 -// CHECK-LABEL: wasmssa.func @div_f32() -> f32 { +// CHECK-LABEL: wasmssa.func exported @div_f32() -> f32 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f32 // CHECK: %[[VAL_1:.*]] = wasmssa.const 2.000000e+00 : f32 // CHECK: %[[VAL_2:.*]] = wasmssa.div %[[VAL_0]] %[[VAL_1]] : f32 // CHECK: wasmssa.return %[[VAL_2]] : f32 -// CHECK-LABEL: wasmssa.func @div_f64() -> f64 { +// CHECK-LABEL: wasmssa.func exported @div_f64() -> f64 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f64 // CHECK: %[[VAL_1:.*]] = wasmssa.const 2.000000e+00 : f64 // CHECK: %[[VAL_2:.*]] = wasmssa.div %[[VAL_0]] %[[VAL_1]] : f64 diff --git a/mlir/test/Target/Wasm/double_nested_loop.mlir b/mlir/test/Target/Wasm/double_nested_loop.mlir new file mode 100644 index 0000000..8b3e499 --- /dev/null +++ b/mlir/test/Target/Wasm/double_nested_loop.mlir @@ -0,0 +1,63 @@ +// RUN: yaml2obj %S/inputs/double_nested_loop.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s + +/* +(module + (func + ;; create a local variable and initialize it to 0 + (local $i i32) + (local $j i32) + + (loop $my_loop + + ;; add one to $i + local.get $i + i32.const 1 + i32.add + local.set $i + (loop $my_second_loop (result i32) + i32.const 1 + local.get $j + i32.const 12 + i32.add + local.tee $j + local.get $i + i32.gt_s + br_if $my_second_loop + ) + i32.const 10 + i32.lt_s + br_if $my_loop + ) + ) +) +*/ + +// CHECK-LABEL: wasmssa.func @func_0() { +// CHECK: %[[VAL_0:.*]] = wasmssa.local of type i32 +// CHECK: %[[VAL_1:.*]] = wasmssa.local of type i32 +// CHECK: wasmssa.loop : { +// CHECK: %[[VAL_2:.*]] = wasmssa.local_get %[[VAL_0]] : ref to i32 +// CHECK: %[[VAL_3:.*]] = wasmssa.const 1 : i32 +// CHECK: %[[VAL_4:.*]] = wasmssa.add %[[VAL_2]] %[[VAL_3]] : i32 +// CHECK: wasmssa.local_set %[[VAL_0]] : ref to i32 to %[[VAL_4]] : i32 +// CHECK: wasmssa.loop : { +// CHECK: %[[VAL_5:.*]] = wasmssa.const 1 : i32 +// CHECK: %[[VAL_6:.*]] = wasmssa.local_get %[[VAL_1]] : ref to i32 +// CHECK: %[[VAL_7:.*]] = wasmssa.const 12 : i32 +// CHECK: %[[VAL_8:.*]] = wasmssa.add %[[VAL_6]] %[[VAL_7]] : i32 +// CHECK: %[[VAL_9:.*]] = wasmssa.local_tee %[[VAL_1]] : ref to i32 to %[[VAL_8]] : i32 +// CHECK: %[[VAL_10:.*]] = wasmssa.local_get %[[VAL_0]] : ref to i32 +// CHECK: %[[VAL_11:.*]] = wasmssa.gt_si %[[VAL_9]] %[[VAL_10]] : i32 -> i32 +// CHECK: wasmssa.branch_if %[[VAL_11]] to level 0 else ^bb1 +// CHECK: ^bb1: +// CHECK: wasmssa.block_return %[[VAL_5]] : i32 +// CHECK: }> ^bb1 +// CHECK: ^bb1(%[[VAL_12:.*]]: i32): +// CHECK: %[[VAL_13:.*]] = wasmssa.const 10 : i32 +// CHECK: %[[VAL_14:.*]] = wasmssa.lt_si %[[VAL_12]] %[[VAL_13]] : i32 -> i32 +// CHECK: wasmssa.branch_if %[[VAL_14]] to level 0 else ^bb2 +// CHECK: ^bb2: +// CHECK: wasmssa.block_return +// CHECK: }> ^bb1 +// CHECK: ^bb1: +// CHECK: wasmssa.return diff --git a/mlir/test/Target/Wasm/empty_blocks_list_and_stack.mlir b/mlir/test/Target/Wasm/empty_blocks_list_and_stack.mlir new file mode 100644 index 0000000..5c98f1a --- /dev/null +++ b/mlir/test/Target/Wasm/empty_blocks_list_and_stack.mlir @@ -0,0 +1,53 @@ +// RUN: yaml2obj %S/inputs/empty_blocks_list_and_stack.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s + +/* Source code used to create this test: +(module + (func (param $num i32) + (block $b1 + (block $b2 + (block $b3 + ) + ) + ) + ) + + (func (param $num i32) + (block $b1) + (block $b2) + (block $b3) + ) +) + +*/ + +// CHECK-LABEL: wasmssa.func @func_0( +// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) { +// CHECK: wasmssa.block : { +// CHECK: wasmssa.block : { +// CHECK: wasmssa.block : { +// CHECK: wasmssa.block_return +// CHECK: }> ^bb1 +// CHECK: ^bb1: +// CHECK: wasmssa.block_return +// CHECK: }> ^bb1 +// CHECK: ^bb1: +// CHECK: wasmssa.block_return +// CHECK: }> ^bb1 +// CHECK: ^bb1: +// CHECK: wasmssa.return + +// CHECK-LABEL: wasmssa.func @func_1( +// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) { +// CHECK: wasmssa.block : { +// CHECK: wasmssa.block_return +// CHECK: }> ^bb1 +// CHECK: ^bb1: +// CHECK: wasmssa.block : { +// CHECK: wasmssa.block_return +// CHECK: }> ^bb2 +// CHECK: ^bb2: +// CHECK: wasmssa.block : { +// CHECK: wasmssa.block_return +// CHECK: }> ^bb3 +// CHECK: ^bb3: +// CHECK: wasmssa.return diff --git a/mlir/test/Target/Wasm/eq.mlir b/mlir/test/Target/Wasm/eq.mlir new file mode 100644 index 0000000..ba3ae2f --- /dev/null +++ b/mlir/test/Target/Wasm/eq.mlir @@ -0,0 +1,56 @@ +// RUN: yaml2obj %S/inputs/eq.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s +/* Source code used to create this test: +(module + (func $eq_i32 (result i32) + i32.const 12 + i32.const 50 + i32.eq + ) + + (func $eq_i64 (result i32) + i64.const 20 + i64.const 5 + i64.eq + ) + + (func $eq_f32 (result i32) + f32.const 5 + f32.const 14 + f32.eq + ) + + (func $eq_f64 (result i32) + f64.const 17 + f64.const 0 + f64.eq + ) +) +*/ + +// CHECK-LABEL: wasmssa.func @func_0() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32 +// CHECK: %[[VAL_2:.*]] = wasmssa.eq %[[VAL_0]] %[[VAL_1]] : i32 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 +// CHECK: } + +// CHECK-LABEL: wasmssa.func @func_1() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 20 : i64 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 5 : i64 +// CHECK: %[[VAL_2:.*]] = wasmssa.eq %[[VAL_0]] %[[VAL_1]] : i64 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 +// CHECK: } + +// CHECK-LABEL: wasmssa.func @func_2() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f32 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f32 +// CHECK: %[[VAL_2:.*]] = wasmssa.eq %[[VAL_0]] %[[VAL_1]] : f32 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 +// CHECK: } + +// CHECK-LABEL: wasmssa.func @func_3() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.700000e+01 : f64 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 0.000000e+00 : f64 +// CHECK: %[[VAL_2:.*]] = wasmssa.eq %[[VAL_0]] %[[VAL_1]] : f64 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 +// CHECK: } diff --git a/mlir/test/Target/Wasm/eqz.mlir b/mlir/test/Target/Wasm/eqz.mlir new file mode 100644 index 0000000..55cf94a --- /dev/null +++ b/mlir/test/Target/Wasm/eqz.mlir @@ -0,0 +1,21 @@ +// RUN: yaml2obj %S/inputs/eqz.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s +/* Source code used to create this test: +(module + (func (export "eqz_i32") (result i32) + i32.const 13 + i32.eqz) + + (func (export "eqz_i64") (result i32) + i64.const 13 + i64.eqz) +) +*/ +// CHECK-LABEL: wasmssa.func exported @eqz_i32() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 13 : i32 +// CHECK: %[[VAL_1:.*]] = wasmssa.eqz %[[VAL_0]] : i32 -> i32 +// CHECK: wasmssa.return %[[VAL_1]] : i32 + +// CHECK-LABEL: wasmssa.func exported @eqz_i64() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 13 : i64 +// CHECK: %[[VAL_1:.*]] = wasmssa.eqz %[[VAL_0]] : i64 -> i32 +// CHECK: wasmssa.return %[[VAL_1]] : i32 diff --git a/mlir/test/Target/Wasm/extend.mlir b/mlir/test/Target/Wasm/extend.mlir new file mode 100644 index 0000000..5d4446a --- /dev/null +++ b/mlir/test/Target/Wasm/extend.mlir @@ -0,0 +1,69 @@ +// RUN: yaml2obj %S/inputs/extend.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s + +/* Source code used to create this test: +(module + (func $i32_s (result i64) + i32.const 10 + i64.extend_i32_s + ) + (func $i32_u (result i64) + i32.const 10 + i64.extend_i32_u + ) + (func $extend8_32 (result i32) + i32.const 10 + i32.extend8_s + ) + (func $extend16_32 (result i32) + i32.const 10 + i32.extend16_s + ) + (func $extend8_64 (result i64) + i64.const 10 + i64.extend8_s + ) + (func $extend16_64 (result i64) + i64.const 10 + i64.extend16_s + ) + (func $extend32_64 (result i64) + i64.const 10 + i64.extend32_s + ) +) +*/ + +// CHECK-LABEL: wasmssa.func @func_0() -> i64 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32 +// CHECK: %[[VAL_1:.*]] = wasmssa.extend_i32_s %[[VAL_0]] to i64 +// CHECK: wasmssa.return %[[VAL_1]] : i64 + +// CHECK-LABEL: wasmssa.func @func_1() -> i64 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32 +// CHECK: %[[VAL_1:.*]] = wasmssa.extend_i32_u %[[VAL_0]] to i64 +// CHECK: wasmssa.return %[[VAL_1]] : i64 + +// CHECK-LABEL: wasmssa.func @func_2() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32 +// CHECK: %[[VAL_1:.*]] = wasmssa.extend 8 : ui32 low bits from %[[VAL_0]] : i32 +// CHECK: wasmssa.return %[[VAL_1]] : i32 + +// CHECK-LABEL: wasmssa.func @func_3() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32 +// CHECK: %[[VAL_1:.*]] = wasmssa.extend 16 : ui32 low bits from %[[VAL_0]] : i32 +// CHECK: wasmssa.return %[[VAL_1]] : i32 + +// CHECK-LABEL: wasmssa.func @func_4() -> i64 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64 +// CHECK: %[[VAL_1:.*]] = wasmssa.extend 8 : ui32 low bits from %[[VAL_0]] : i64 +// CHECK: wasmssa.return %[[VAL_1]] : i64 + +// CHECK-LABEL: wasmssa.func @func_5() -> i64 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64 +// CHECK: %[[VAL_1:.*]] = wasmssa.extend 16 : ui32 low bits from %[[VAL_0]] : i64 +// CHECK: wasmssa.return %[[VAL_1]] : i64 + +// CHECK-LABEL: wasmssa.func @func_6() -> i64 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64 +// CHECK: %[[VAL_1:.*]] = wasmssa.extend 32 : ui32 low bits from %[[VAL_0]] : i64 +// CHECK: wasmssa.return %[[VAL_1]] : i64 diff --git a/mlir/test/Target/Wasm/global.mlir b/mlir/test/Target/Wasm/global.mlir index e72fe69..1e4fe44 100644 --- a/mlir/test/Target/Wasm/global.mlir +++ b/mlir/test/Target/Wasm/global.mlir @@ -29,9 +29,9 @@ i32.add ) */ -// CHECK-LABEL: wasmssa.import_global "from_js" from "env" as @global_0 nested : i32 +// CHECK-LABEL: wasmssa.import_global "from_js" from "env" as @global_0 : i32 -// CHECK-LABEL: wasmssa.func nested @func_0() -> i32 { +// CHECK-LABEL: wasmssa.func @func_0() -> i32 { // CHECK: %[[VAL_0:.*]] = wasmssa.global_get @global_0 : i32 // CHECK: %[[VAL_1:.*]] = wasmssa.global_get @global_1 : i32 // CHECK: %[[VAL_2:.*]] = wasmssa.add %[[VAL_0]] %[[VAL_1]] : i32 @@ -41,26 +41,26 @@ i32.add // CHECK: %[[VAL_6:.*]] = wasmssa.add %[[VAL_2]] %[[VAL_5]] : i32 // CHECK: wasmssa.return %[[VAL_6]] : i32 -// CHECK-LABEL: wasmssa.global @global_1 i32 nested : { +// CHECK-LABEL: wasmssa.global @global_1 i32 : { // CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32 // CHECK: wasmssa.return %[[VAL_0]] : i32 -// CHECK-LABEL: wasmssa.global @global_2 i32 mutable nested : { +// CHECK-LABEL: wasmssa.global @global_2 i32 mutable : { // CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32 // CHECK: wasmssa.return %[[VAL_0]] : i32 -// CHECK-LABEL: wasmssa.global @global_3 i32 mutable nested : { +// CHECK-LABEL: wasmssa.global @global_3 i32 mutable : { // CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32 // CHECK: wasmssa.return %[[VAL_0]] : i32 -// CHECK-LABEL: wasmssa.global @global_4 i64 nested : { +// CHECK-LABEL: wasmssa.global @global_4 i64 : { // CHECK: %[[VAL_0:.*]] = wasmssa.const 11 : i64 // CHECK: wasmssa.return %[[VAL_0]] : i64 -// CHECK-LABEL: wasmssa.global @global_5 f32 nested : { +// CHECK-LABEL: wasmssa.global @global_5 f32 : { // CHECK: %[[VAL_0:.*]] = wasmssa.const 1.200000e+01 : f32 // CHECK: wasmssa.return %[[VAL_0]] : f32 -// CHECK-LABEL: wasmssa.global @global_6 f64 nested : { +// CHECK-LABEL: wasmssa.global @global_6 f64 : { // CHECK: %[[VAL_0:.*]] = wasmssa.const 1.300000e+01 : f64 // CHECK: wasmssa.return %[[VAL_0]] : f64 diff --git a/mlir/test/Target/Wasm/if.mlir b/mlir/test/Target/Wasm/if.mlir new file mode 100644 index 0000000..2d7bfbe --- /dev/null +++ b/mlir/test/Target/Wasm/if.mlir @@ -0,0 +1,112 @@ +// RUN: yaml2obj %S/inputs/if.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s + +/* Source code used to create this test: +(module +(type $intMapper (func (param $input i32) (result i32))) +(func $if_else (type $intMapper) + local.get 0 + i32.const 1 + i32.and + if $isOdd (result i32) + local.get 0 + i32.const 3 + i32.mul + i32.const 1 + i32.add + else + local.get 0 + i32.const 1 + i32.shr_u + end +) + +(func $if_only (type $intMapper) + local.get 0 + local.get 0 + i32.const 1 + i32.and + if $isOdd (type $intMapper) + i32.const 1 + i32.add + end +) + +(func $if_if (type $intMapper) + local.get 0 + i32.ctz + if $isEven (result i32) + i32.const 2 + local.get 0 + i32.const 1 + i32.shr_u + i32.ctz + if $isMultipleOfFour (type $intMapper) + i32.const 2 + i32.add + end + else + i32.const 1 + end +) +) +*/ +// CHECK-LABEL: wasmssa.func @func_0( +// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 1 : i32 +// CHECK: %[[VAL_2:.*]] = wasmssa.and %[[VAL_0]] %[[VAL_1]] : i32 +// CHECK: wasmssa.if %[[VAL_2]] : { +// CHECK: %[[VAL_3:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32 +// CHECK: %[[VAL_4:.*]] = wasmssa.const 3 : i32 +// CHECK: %[[VAL_5:.*]] = wasmssa.mul %[[VAL_3]] %[[VAL_4]] : i32 +// CHECK: %[[VAL_6:.*]] = wasmssa.const 1 : i32 +// CHECK: %[[VAL_7:.*]] = wasmssa.add %[[VAL_5]] %[[VAL_6]] : i32 +// CHECK: wasmssa.block_return %[[VAL_7]] : i32 +// CHECK: } "else "{ +// CHECK: %[[VAL_8:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32 +// CHECK: %[[VAL_9:.*]] = wasmssa.const 1 : i32 +// CHECK: %[[VAL_10:.*]] = wasmssa.shr_u %[[VAL_8]] by %[[VAL_9]] bits : i32 +// CHECK: wasmssa.block_return %[[VAL_10]] : i32 +// CHECK: }> ^bb1 +// CHECK: ^bb1(%[[VAL_11:.*]]: i32): +// CHECK: wasmssa.return %[[VAL_11]] : i32 + +// CHECK-LABEL: wasmssa.func @func_1( +// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32 +// CHECK: %[[VAL_1:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32 +// CHECK: %[[VAL_2:.*]] = wasmssa.const 1 : i32 +// CHECK: %[[VAL_3:.*]] = wasmssa.and %[[VAL_1]] %[[VAL_2]] : i32 +// CHECK: wasmssa.if %[[VAL_3]](%[[VAL_0]]) : i32 : { +// CHECK: ^bb0(%[[VAL_4:.*]]: i32): +// CHECK: %[[VAL_5:.*]] = wasmssa.const 1 : i32 +// CHECK: %[[VAL_6:.*]] = wasmssa.add %[[VAL_4]] %[[VAL_5]] : i32 +// CHECK: wasmssa.block_return %[[VAL_6]] : i32 +// CHECK: } > ^bb1 +// CHECK: ^bb1(%[[VAL_7:.*]]: i32): +// CHECK: wasmssa.return %[[VAL_7]] : i32 + +// CHECK-LABEL: wasmssa.func @func_2( +// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32 +// CHECK: %[[VAL_1:.*]] = wasmssa.ctz %[[VAL_0]] : i32 +// CHECK: wasmssa.if %[[VAL_1]] : { +// CHECK: %[[VAL_2:.*]] = wasmssa.const 2 : i32 +// CHECK: %[[VAL_3:.*]] = wasmssa.local_get %[[ARG0]] : ref to i32 +// CHECK: %[[VAL_4:.*]] = wasmssa.const 1 : i32 +// CHECK: %[[VAL_5:.*]] = wasmssa.shr_u %[[VAL_3]] by %[[VAL_4]] bits : i32 +// CHECK: %[[VAL_6:.*]] = wasmssa.ctz %[[VAL_5]] : i32 +// CHECK: wasmssa.if %[[VAL_6]](%[[VAL_2]]) : i32 : { +// CHECK: ^bb0(%[[VAL_7:.*]]: i32): +// CHECK: %[[VAL_8:.*]] = wasmssa.const 2 : i32 +// CHECK: %[[VAL_9:.*]] = wasmssa.add %[[VAL_7]] %[[VAL_8]] : i32 +// CHECK: wasmssa.block_return %[[VAL_9]] : i32 +// CHECK: } > ^bb1 +// CHECK: ^bb1(%[[VAL_10:.*]]: i32): +// CHECK: wasmssa.block_return %[[VAL_10]] : i32 +// CHECK: } "else "{ +// CHECK: %[[VAL_11:.*]] = wasmssa.const 1 : i32 +// CHECK: wasmssa.block_return %[[VAL_11]] : i32 +// CHECK: }> ^bb1 +// CHECK: ^bb1(%[[VAL_12:.*]]: i32): +// CHECK: wasmssa.return %[[VAL_12]] : i32 diff --git a/mlir/test/Target/Wasm/import.mlir b/mlir/test/Target/Wasm/import.mlir index 541dcf3..dcdfa52 100644 --- a/mlir/test/Target/Wasm/import.mlir +++ b/mlir/test/Target/Wasm/import.mlir @@ -11,9 +11,9 @@ ) */ -// CHECK-LABEL: wasmssa.import_func "foo" from "my_module" as @func_0 {sym_visibility = "nested", type = (i32) -> ()} -// CHECK: wasmssa.import_func "bar" from "my_module" as @func_1 {sym_visibility = "nested", type = (i32) -> ()} -// CHECK: wasmssa.import_table "table" from "my_module" as @table_0 {sym_visibility = "nested", type = !wasmssa<tabletype !wasmssa.funcref [2:]>} -// CHECK: wasmssa.import_mem "mem" from "my_module" as @mem_0 {limits = !wasmssa<limit[2:]>, sym_visibility = "nested"} -// CHECK: wasmssa.import_global "glob" from "my_module" as @global_0 nested : i32 -// CHECK: wasmssa.import_global "glob_mut" from "my_other_module" as @global_1 mutable nested : i32 +// CHECK-LABEL: wasmssa.import_func "foo" from "my_module" as @func_0 {type = (i32) -> ()} +// CHECK: wasmssa.import_func "bar" from "my_module" as @func_1 {type = (i32) -> ()} +// CHECK: wasmssa.import_table "table" from "my_module" as @table_0 {type = !wasmssa<tabletype !wasmssa.funcref [2:]>} +// CHECK: wasmssa.import_mem "mem" from "my_module" as @mem_0 {limits = !wasmssa<limit[2:]>} +// CHECK: wasmssa.import_global "glob" from "my_module" as @global_0 : i32 +// CHECK: wasmssa.import_global "glob_mut" from "my_other_module" as @global_1 mutable : i32 diff --git a/mlir/test/Target/Wasm/inputs/add_div.yaml.wasm b/mlir/test/Target/Wasm/inputs/add_div.yaml.wasm new file mode 100644 index 0000000..865c315 --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/add_div.yaml.wasm @@ -0,0 +1,50 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: + - I32 + ReturnTypes: + - I32 + - Index: 1 + ParamTypes: + - I32 + - I32 + ReturnTypes: + - I32 + - Type: IMPORT + Imports: + - Module: env + Field: twoTimes + Kind: FUNCTION + SigIndex: 0 + - Type: FUNCTION + FunctionTypes: [ 1 ] + - Type: MEMORY + Memories: + - Minimum: 0x2 + - Type: GLOBAL + Globals: + - Index: 0 + Type: I32 + Mutable: true + InitExpr: + Opcode: I32_CONST + Value: 66560 + - Type: EXPORT + Exports: + - Name: memory + Kind: MEMORY + Index: 0 + - Name: add + Kind: FUNCTION + Index: 1 + - Type: CODE + Functions: + - Index: 1 + Locals: [] + Body: 20001000200110006A41026D0B +... diff --git a/mlir/test/Target/Wasm/inputs/block.yaml.wasm b/mlir/test/Target/Wasm/inputs/block.yaml.wasm new file mode 100644 index 0000000..dd5118a --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/block.yaml.wasm @@ -0,0 +1,22 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: [] + ReturnTypes: [] + - Type: FUNCTION + FunctionTypes: [ 0 ] + - Type: EXPORT + Exports: + - Name: i_am_a_block + Kind: FUNCTION + Index: 0 + - Type: CODE + Functions: + - Index: 0 + Locals: [] + Body: 02400B0B +... diff --git a/mlir/test/Target/Wasm/inputs/block_complete_type.yaml.wasm b/mlir/test/Target/Wasm/inputs/block_complete_type.yaml.wasm new file mode 100644 index 0000000..7a125bf --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/block_complete_type.yaml.wasm @@ -0,0 +1,23 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: + - I32 + ReturnTypes: + - I32 + - Index: 1 + ParamTypes: [] + ReturnTypes: + - I32 + - Type: FUNCTION + FunctionTypes: [ 1 ] + - Type: CODE + Functions: + - Index: 0 + Locals: [] + Body: 410E020041016A0B0B +... diff --git a/mlir/test/Target/Wasm/inputs/block_value_type.yaml.wasm b/mlir/test/Target/Wasm/inputs/block_value_type.yaml.wasm new file mode 100644 index 0000000..4ba291d --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/block_value_type.yaml.wasm @@ -0,0 +1,18 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: [] + ReturnTypes: + - I32 + - Type: FUNCTION + FunctionTypes: [ 0 ] + - Type: CODE + Functions: + - Index: 0 + Locals: [] + Body: 027F41110B0B +... diff --git a/mlir/test/Target/Wasm/inputs/branch_if.yaml.wasm b/mlir/test/Target/Wasm/inputs/branch_if.yaml.wasm new file mode 100644 index 0000000..40536ed --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/branch_if.yaml.wasm @@ -0,0 +1,18 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: [] + ReturnTypes: + - I32 + - Type: FUNCTION + FunctionTypes: [ 0 ] + - Type: CODE + Functions: + - Index: 0 + Locals: [] + Body: 027F410141020D0041016A0B0B +... diff --git a/mlir/test/Target/Wasm/inputs/call.yaml.wasm b/mlir/test/Target/Wasm/inputs/call.yaml.wasm new file mode 100644 index 0000000..535a623 --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/call.yaml.wasm @@ -0,0 +1,26 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: [] + ReturnTypes: + - I32 + - Type: FUNCTION + FunctionTypes: [ 0, 0 ] + - Type: EXPORT + Exports: + - Name: forty_two + Kind: FUNCTION + Index: 1 + - Type: CODE + Functions: + - Index: 0 + Locals: [] + Body: 412A0B + - Index: 1 + Locals: [] + Body: 10000B +... diff --git a/mlir/test/Target/Wasm/inputs/comparison_ops.yaml.wasm b/mlir/test/Target/Wasm/inputs/comparison_ops.yaml.wasm new file mode 100644 index 0000000..cde9ee1 --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/comparison_ops.yaml.wasm @@ -0,0 +1,88 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: [] + ReturnTypes: + - I32 + - Type: FUNCTION + FunctionTypes: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0 ] + - Type: CODE + Functions: + - Index: 0 + Locals: [] + Body: 410C4132480B + - Index: 1 + Locals: [] + Body: 410C41324C0B + - Index: 2 + Locals: [] + Body: 410C4132490B + - Index: 3 + Locals: [] + Body: 410C41324D0B + - Index: 4 + Locals: [] + Body: 410C41324A0B + - Index: 5 + Locals: [] + Body: 410C41324B0B + - Index: 6 + Locals: [] + Body: 410C41324E0B + - Index: 7 + Locals: [] + Body: 410C41324F0B + - Index: 8 + Locals: [] + Body: 420C4232530B + - Index: 9 + Locals: [] + Body: 420C4232570B + - Index: 10 + Locals: [] + Body: 420C4232540B + - Index: 11 + Locals: [] + Body: 420C4232580B + - Index: 12 + Locals: [] + Body: 420C4232550B + - Index: 13 + Locals: [] + Body: 420C4232560B + - Index: 14 + Locals: [] + Body: 420C4232590B + - Index: 15 + Locals: [] + Body: 420C42325A0B + - Index: 16 + Locals: [] + Body: 430000A04043000060415D0B + - Index: 17 + Locals: [] + Body: 430000A04043000060415F0B + - Index: 18 + Locals: [] + Body: 430000A04043000060415E0B + - Index: 19 + Locals: [] + Body: 430000A0404300006041600B + - Index: 20 + Locals: [] + Body: 440000000000001440440000000000002C40630B + - Index: 21 + Locals: [] + Body: 440000000000001440440000000000002C40650B + - Index: 22 + Locals: [] + Body: 440000000000001440440000000000002C40640B + - Index: 23 + Locals: [] + Body: 440000000000001440440000000000002C40660B +... diff --git a/mlir/test/Target/Wasm/inputs/convert.yaml.wasm b/mlir/test/Target/Wasm/inputs/convert.yaml.wasm new file mode 100644 index 0000000..c346a75 --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/convert.yaml.wasm @@ -0,0 +1,69 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: [] + ReturnTypes: + - F32 + - Index: 1 + ParamTypes: [] + ReturnTypes: + - F64 + - Type: FUNCTION + FunctionTypes: [ 0, 0, 0, 0, 1, 1, 1, 1 ] + - Type: EXPORT + Exports: + - Name: convert_i32_u_to_f32 + Kind: FUNCTION + Index: 0 + - Name: convert_i32_s_to_f32 + Kind: FUNCTION + Index: 1 + - Name: convert_i64_u_to_f32 + Kind: FUNCTION + Index: 2 + - Name: convert_i64s_to_f32 + Kind: FUNCTION + Index: 3 + - Name: convert_i32_u_to_f64 + Kind: FUNCTION + Index: 4 + - Name: convert_i32_s_to_f64 + Kind: FUNCTION + Index: 5 + - Name: convert_i64_u_to_f64 + Kind: FUNCTION + Index: 6 + - Name: convert_i64s_to_f64 + Kind: FUNCTION + Index: 7 + - Type: CODE + Functions: + - Index: 0 + Locals: [] + Body: 410AB30B + - Index: 1 + Locals: [] + Body: 412AB20B + - Index: 2 + Locals: [] + Body: 4211B50B + - Index: 3 + Locals: [] + Body: 420AB40B + - Index: 4 + Locals: [] + Body: 410AB80B + - Index: 5 + Locals: [] + Body: 412AB70B + - Index: 6 + Locals: [] + Body: 4211BA0B + - Index: 7 + Locals: [] + Body: 420AB90B +... diff --git a/mlir/test/Target/Wasm/inputs/demote.yaml.wasm b/mlir/test/Target/Wasm/inputs/demote.yaml.wasm new file mode 100644 index 0000000..3997045 --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/demote.yaml.wasm @@ -0,0 +1,18 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: [] + ReturnTypes: + - F32 + - Type: FUNCTION + FunctionTypes: [ 0 ] + - Type: CODE + Functions: + - Index: 0 + Locals: [] + Body: 44EC51B81E85EB0140B60B +... diff --git a/mlir/test/Target/Wasm/inputs/double_nested_loop.yaml.wasm b/mlir/test/Target/Wasm/inputs/double_nested_loop.yaml.wasm new file mode 100644 index 0000000..41a2944 --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/double_nested_loop.yaml.wasm @@ -0,0 +1,19 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: [] + ReturnTypes: [] + - Type: FUNCTION + FunctionTypes: [ 0 ] + - Type: CODE + Functions: + - Index: 0 + Locals: + - Type: I32 + Count: 2 + Body: 0340200041016A2100037F41012001410C6A220120004A0D000B410A480D000B0B +... diff --git a/mlir/test/Target/Wasm/inputs/empty_blocks_list_and_stack.yaml.wasm b/mlir/test/Target/Wasm/inputs/empty_blocks_list_and_stack.yaml.wasm new file mode 100644 index 0000000..3171409 --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/empty_blocks_list_and_stack.yaml.wasm @@ -0,0 +1,21 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: + - I32 + ReturnTypes: [] + - Type: FUNCTION + FunctionTypes: [ 0, 0 ] + - Type: CODE + Functions: + - Index: 0 + Locals: [] + Body: 0240024002400B0B0B0B + - Index: 1 + Locals: [] + Body: 02400B02400B02400B0B +... diff --git a/mlir/test/Target/Wasm/inputs/eq.yaml.wasm b/mlir/test/Target/Wasm/inputs/eq.yaml.wasm new file mode 100644 index 0000000..1998369 --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/eq.yaml.wasm @@ -0,0 +1,27 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: [] + ReturnTypes: + - I32 + - Type: FUNCTION + FunctionTypes: [ 0, 0, 0, 0 ] + - Type: CODE + Functions: + - Index: 0 + Locals: [] + Body: 410C4132460B + - Index: 1 + Locals: [] + Body: 42144205510B + - Index: 2 + Locals: [] + Body: 430000A04043000060415B0B + - Index: 3 + Locals: [] + Body: 440000000000003140440000000000000000610B +... diff --git a/mlir/test/Target/Wasm/inputs/eqz.yaml.wasm b/mlir/test/Target/Wasm/inputs/eqz.yaml.wasm new file mode 100644 index 0000000..894ac50 --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/eqz.yaml.wasm @@ -0,0 +1,29 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: [] + ReturnTypes: + - I32 + - Type: FUNCTION + FunctionTypes: [ 0, 0 ] + - Type: EXPORT + Exports: + - Name: eqz_i32 + Kind: FUNCTION + Index: 0 + - Name: eqz_i64 + Kind: FUNCTION + Index: 1 + - Type: CODE + Functions: + - Index: 0 + Locals: [] + Body: 410D450B + - Index: 1 + Locals: [] + Body: 420D500B +... diff --git a/mlir/test/Target/Wasm/inputs/extend.yaml.wasm b/mlir/test/Target/Wasm/inputs/extend.yaml.wasm new file mode 100644 index 0000000..7e872ba --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/extend.yaml.wasm @@ -0,0 +1,40 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: [] + ReturnTypes: + - I64 + - Index: 1 + ParamTypes: [] + ReturnTypes: + - I32 + - Type: FUNCTION + FunctionTypes: [ 0, 0, 1, 1, 0, 0, 0 ] + - Type: CODE + Functions: + - Index: 0 + Locals: [] + Body: 410AAC0B + - Index: 1 + Locals: [] + Body: 410AAD0B + - Index: 2 + Locals: [] + Body: 410AC00B + - Index: 3 + Locals: [] + Body: 410AC10B + - Index: 4 + Locals: [] + Body: 420AC20B + - Index: 5 + Locals: [] + Body: 420AC30B + - Index: 6 + Locals: [] + Body: 420AC40B +... diff --git a/mlir/test/Target/Wasm/inputs/if.yaml.wasm b/mlir/test/Target/Wasm/inputs/if.yaml.wasm new file mode 100644 index 0000000..ccc38f6 --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/if.yaml.wasm @@ -0,0 +1,25 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: + - I32 + ReturnTypes: + - I32 + - Type: FUNCTION + FunctionTypes: [ 0, 0, 0 ] + - Type: CODE + Functions: + - Index: 0 + Locals: [] + Body: 2000410171047F200041036C41016A0520004101760B0B + - Index: 1 + Locals: [] + Body: 20002000410171040041016A0B0B + - Index: 2 + Locals: [] + Body: 200068047F4102200041017668040041026A0B0541010B0B +... diff --git a/mlir/test/Target/Wasm/inputs/loop.yaml.wasm b/mlir/test/Target/Wasm/inputs/loop.yaml.wasm new file mode 100644 index 0000000..9d33894 --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/loop.yaml.wasm @@ -0,0 +1,17 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: [] + ReturnTypes: [] + - Type: FUNCTION + FunctionTypes: [ 0 ] + - Type: CODE + Functions: + - Index: 0 + Locals: [] + Body: 03400B0B +... diff --git a/mlir/test/Target/Wasm/inputs/loop_with_inst.yaml.wasm b/mlir/test/Target/Wasm/inputs/loop_with_inst.yaml.wasm new file mode 100644 index 0000000..4b8cc54 --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/loop_with_inst.yaml.wasm @@ -0,0 +1,20 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: [] + ReturnTypes: + - I32 + - Type: FUNCTION + FunctionTypes: [ 0 ] + - Type: CODE + Functions: + - Index: 0 + Locals: + - Type: I32 + Count: 1 + Body: 037F200041016A21002000410A480B0B +... diff --git a/mlir/test/Target/Wasm/inputs/ne.yaml.wasm b/mlir/test/Target/Wasm/inputs/ne.yaml.wasm new file mode 100644 index 0000000..0167519 --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/ne.yaml.wasm @@ -0,0 +1,27 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: [] + ReturnTypes: + - I32 + - Type: FUNCTION + FunctionTypes: [ 0, 0, 0, 0 ] + - Type: CODE + Functions: + - Index: 0 + Locals: [] + Body: 410C4132470B + - Index: 1 + Locals: [] + Body: 42144205520B + - Index: 2 + Locals: [] + Body: 430000A04043000060415C0B + - Index: 3 + Locals: [] + Body: 440000000000003140440000000000000000620B +... diff --git a/mlir/test/Target/Wasm/inputs/promote.yaml.wasm b/mlir/test/Target/Wasm/inputs/promote.yaml.wasm new file mode 100644 index 0000000..d38603e --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/promote.yaml.wasm @@ -0,0 +1,18 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: [] + ReturnTypes: + - F64 + - Type: FUNCTION + FunctionTypes: [ 0 ] + - Type: CODE + Functions: + - Index: 0 + Locals: [] + Body: 4300002841BB0B +... diff --git a/mlir/test/Target/Wasm/inputs/reinterpret.yaml.wasm b/mlir/test/Target/Wasm/inputs/reinterpret.yaml.wasm new file mode 100644 index 0000000..c01c1b1 --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/reinterpret.yaml.wasm @@ -0,0 +1,53 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: [] + ReturnTypes: + - I32 + - Index: 1 + ParamTypes: [] + ReturnTypes: + - I64 + - Index: 2 + ParamTypes: [] + ReturnTypes: + - F32 + - Index: 3 + ParamTypes: [] + ReturnTypes: + - F64 + - Type: FUNCTION + FunctionTypes: [ 0, 1, 2, 3 ] + - Type: EXPORT + Exports: + - Name: i32.reinterpret_f32 + Kind: FUNCTION + Index: 0 + - Name: i64.reinterpret_f64 + Kind: FUNCTION + Index: 1 + - Name: f32.reinterpret_i32 + Kind: FUNCTION + Index: 2 + - Name: f64.reinterpret_i64 + Kind: FUNCTION + Index: 3 + - Type: CODE + Functions: + - Index: 0 + Locals: [] + Body: 43000080BFBC0B + - Index: 1 + Locals: [] + Body: 44000000000000F0BFBD0B + - Index: 2 + Locals: [] + Body: 417FBE0B + - Index: 3 + Locals: [] + Body: 427FBF0B +... diff --git a/mlir/test/Target/Wasm/inputs/rounding.yaml.wasm b/mlir/test/Target/Wasm/inputs/rounding.yaml.wasm new file mode 100644 index 0000000..c6e8bf6 --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/rounding.yaml.wasm @@ -0,0 +1,37 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: [] + ReturnTypes: + - F64 + - Index: 1 + ParamTypes: [] + ReturnTypes: + - F32 + - Type: FUNCTION + FunctionTypes: [ 0, 1, 0, 1, 0, 1 ] + - Type: CODE + Functions: + - Index: 0 + Locals: [] + Body: 4433333333333328C09B0B + - Index: 1 + Locals: [] + Body: 43A01ACF3F8D0B + - Index: 2 + Locals: [] + Body: 4433333333333328C09C0B + - Index: 3 + Locals: [] + Body: 43A01ACF3F8E0B + - Index: 4 + Locals: [] + Body: 4433333333333328C09D0B + - Index: 5 + Locals: [] + Body: 43A01ACF3F8F0B +... diff --git a/mlir/test/Target/Wasm/inputs/wrap.yaml.wasm b/mlir/test/Target/Wasm/inputs/wrap.yaml.wasm new file mode 100644 index 0000000..51c0b02 --- /dev/null +++ b/mlir/test/Target/Wasm/inputs/wrap.yaml.wasm @@ -0,0 +1,24 @@ +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: + - I64 + ReturnTypes: + - I32 + - Type: FUNCTION + FunctionTypes: [ 0 ] + - Type: EXPORT + Exports: + - Name: i64_wrap + Kind: FUNCTION + Index: 0 + - Type: CODE + Functions: + - Index: 0 + Locals: [] + Body: 2000A70B +... diff --git a/mlir/test/Target/Wasm/invalid_block_type_index.yaml b/mlir/test/Target/Wasm/invalid_block_type_index.yaml new file mode 100644 index 0000000..5b83e2e --- /dev/null +++ b/mlir/test/Target/Wasm/invalid_block_type_index.yaml @@ -0,0 +1,28 @@ + +# RUN: yaml2obj %s | not mlir-translate --import-wasm -o - 2>&1 | FileCheck %s + +# CHECK: type index references nonexistent type (2) + +--- !WASM +FileHeader: + Version: 0x1 +Sections: + - Type: TYPE + Signatures: + - Index: 0 + ParamTypes: + - I32 + ReturnTypes: + - I32 + - Index: 1 + ParamTypes: [] + ReturnTypes: + - I32 + - Type: FUNCTION + FunctionTypes: [ 1 ] + - Type: CODE + Functions: + - Index: 0 + Locals: [] + Body: 410E020241016A0B0B +# -----------------------------^^ Invalid type ID diff --git a/mlir/test/Target/Wasm/local.mlir b/mlir/test/Target/Wasm/local.mlir index 32f5900..9844f9c 100644 --- a/mlir/test/Target/Wasm/local.mlir +++ b/mlir/test/Target/Wasm/local.mlir @@ -29,7 +29,7 @@ ) */ -// CHECK-LABEL: wasmssa.func nested @func_0() -> f32 { +// CHECK-LABEL: wasmssa.func @func_0() -> f32 { // CHECK: %[[VAL_0:.*]] = wasmssa.local of type f32 // CHECK: %[[VAL_1:.*]] = wasmssa.local of type f32 // CHECK: %[[VAL_2:.*]] = wasmssa.const 8.000000e+00 : f32 @@ -40,7 +40,7 @@ // CHECK: %[[VAL_6:.*]] = wasmssa.add %[[VAL_3]] %[[VAL_5]] : f32 // CHECK: wasmssa.return %[[VAL_6]] : f32 -// CHECK-LABEL: wasmssa.func nested @func_1() -> i32 { +// CHECK-LABEL: wasmssa.func @func_1() -> i32 { // CHECK: %[[VAL_0:.*]] = wasmssa.local of type i32 // CHECK: %[[VAL_1:.*]] = wasmssa.local of type i32 // CHECK: %[[VAL_2:.*]] = wasmssa.const 8 : i32 @@ -51,7 +51,7 @@ // CHECK: %[[VAL_6:.*]] = wasmssa.add %[[VAL_3]] %[[VAL_5]] : i32 // CHECK: wasmssa.return %[[VAL_6]] : i32 -// CHECK-LABEL: wasmssa.func nested @func_2( +// CHECK-LABEL: wasmssa.func @func_2( // CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i32>) -> i32 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 3 : i32 // CHECK: wasmssa.local_set %[[ARG0]] : ref to i32 to %[[VAL_0]] : i32 diff --git a/mlir/test/Target/Wasm/loop.mlir b/mlir/test/Target/Wasm/loop.mlir new file mode 100644 index 0000000..29ad502 --- /dev/null +++ b/mlir/test/Target/Wasm/loop.mlir @@ -0,0 +1,17 @@ +// RUN: yaml2obj %S/inputs/loop.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s + +/* IR generated from: +(module + (func + (loop $my_loop + ) + ) +)*/ + +// CHECK-LABEL: wasmssa.func @func_0() { +// CHECK: wasmssa.loop : { +// CHECK: wasmssa.block_return +// CHECK: }> ^bb1 +// CHECK: ^bb1: +// CHECK: wasmssa.return +// CHECK: } diff --git a/mlir/test/Target/Wasm/loop_with_inst.mlir b/mlir/test/Target/Wasm/loop_with_inst.mlir new file mode 100644 index 0000000..311d007 --- /dev/null +++ b/mlir/test/Target/Wasm/loop_with_inst.mlir @@ -0,0 +1,33 @@ +// RUN: yaml2obj %S/inputs/loop_with_inst.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s + +/* Code used to create this test: + +(module + (func (result i32) + (local $i i32) + (loop $my_loop (result i32) + local.get $i + i32.const 1 + i32.add + local.set $i + local.get $i + i32.const 10 + i32.lt_s + ) + ) +)*/ + +// CHECK-LABEL: wasmssa.func @func_0() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.local of type i32 +// CHECK: wasmssa.loop : { +// CHECK: %[[VAL_1:.*]] = wasmssa.local_get %[[VAL_0]] : ref to i32 +// CHECK: %[[VAL_2:.*]] = wasmssa.const 1 : i32 +// CHECK: %[[VAL_3:.*]] = wasmssa.add %[[VAL_1]] %[[VAL_2]] : i32 +// CHECK: wasmssa.local_set %[[VAL_0]] : ref to i32 to %[[VAL_3]] : i32 +// CHECK: %[[VAL_4:.*]] = wasmssa.local_get %[[VAL_0]] : ref to i32 +// CHECK: %[[VAL_5:.*]] = wasmssa.const 10 : i32 +// CHECK: %[[VAL_6:.*]] = wasmssa.lt_si %[[VAL_4]] %[[VAL_5]] : i32 -> i32 +// CHECK: wasmssa.block_return %[[VAL_6]] : i32 +// CHECK: }> ^bb1 +// CHECK: ^bb1(%[[VAL_7:.*]]: i32): +// CHECK: wasmssa.return %[[VAL_7]] : i32 diff --git a/mlir/test/Target/Wasm/max.mlir b/mlir/test/Target/Wasm/max.mlir index 4ef2042..9160bde 100644 --- a/mlir/test/Target/Wasm/max.mlir +++ b/mlir/test/Target/Wasm/max.mlir @@ -16,14 +16,14 @@ ) */ -// CHECK-LABEL: wasmssa.func @min_f32() -> f32 { +// CHECK-LABEL: wasmssa.func exported @min_f32() -> f32 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f32 // CHECK: %[[VAL_1:.*]] = wasmssa.const 1.000000e+00 : f32 // CHECK: %[[VAL_2:.*]] = wasmssa.max %[[VAL_0]] %[[VAL_1]] : f32 // CHECK: wasmssa.return %[[VAL_2]] : f32 -// CHECK-LABEL: wasmssa.func @min_f64() -> f64 { +// CHECK-LABEL: wasmssa.func exported @min_f64() -> f64 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f64 // CHECK: %[[VAL_1:.*]] = wasmssa.const 1.000000e+00 : f64 // CHECK: %[[VAL_2:.*]] = wasmssa.max %[[VAL_0]] %[[VAL_1]] : f64 diff --git a/mlir/test/Target/Wasm/memory_min_eq_max.mlir b/mlir/test/Target/Wasm/memory_min_eq_max.mlir index 2ba5ab5..ea8f719 100644 --- a/mlir/test/Target/Wasm/memory_min_eq_max.mlir +++ b/mlir/test/Target/Wasm/memory_min_eq_max.mlir @@ -4,4 +4,4 @@ (module (memory 0 0)) */ -// CHECK-LABEL: wasmssa.memory @mem_0 nested !wasmssa<limit[0: 0]> +// CHECK-LABEL: wasmssa.memory @mem_0 !wasmssa<limit[0: 0]> diff --git a/mlir/test/Target/Wasm/memory_min_max.mlir b/mlir/test/Target/Wasm/memory_min_max.mlir index ebf6418..88782ec 100644 --- a/mlir/test/Target/Wasm/memory_min_max.mlir +++ b/mlir/test/Target/Wasm/memory_min_max.mlir @@ -4,4 +4,4 @@ (module (memory 0 65536)) */ -// CHECK-LABEL: wasmssa.memory @mem_0 nested !wasmssa<limit[0: 65536]> +// CHECK-LABEL: wasmssa.memory @mem_0 !wasmssa<limit[0: 65536]> diff --git a/mlir/test/Target/Wasm/memory_min_no_max.mlir b/mlir/test/Target/Wasm/memory_min_no_max.mlir index 8d88786..c10c5cc 100644 --- a/mlir/test/Target/Wasm/memory_min_no_max.mlir +++ b/mlir/test/Target/Wasm/memory_min_no_max.mlir @@ -4,4 +4,4 @@ (module (memory 1)) */ -// CHECK-LABEL: wasmssa.memory @mem_0 nested !wasmssa<limit[1:]> +// CHECK-LABEL: wasmssa.memory @mem_0 !wasmssa<limit[1:]> diff --git a/mlir/test/Target/Wasm/min.mlir b/mlir/test/Target/Wasm/min.mlir index 1058c7d..2372bcc 100644 --- a/mlir/test/Target/Wasm/min.mlir +++ b/mlir/test/Target/Wasm/min.mlir @@ -16,13 +16,13 @@ ) */ -// CHECK-LABEL: wasmssa.func @min_f32() -> f32 { +// CHECK-LABEL: wasmssa.func exported @min_f32() -> f32 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f32 // CHECK: %[[VAL_1:.*]] = wasmssa.const 1.000000e+00 : f32 // CHECK: %[[VAL_2:.*]] = wasmssa.min %[[VAL_0]] %[[VAL_1]] : f32 // CHECK: wasmssa.return %[[VAL_2]] : f32 -// CHECK-LABEL: wasmssa.func @min_f64() -> f64 { +// CHECK-LABEL: wasmssa.func exported @min_f64() -> f64 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f64 // CHECK: %[[VAL_1:.*]] = wasmssa.const 1.000000e+00 : f64 // CHECK: %[[VAL_2:.*]] = wasmssa.min %[[VAL_0]] %[[VAL_1]] : f64 diff --git a/mlir/test/Target/Wasm/ne.mlir b/mlir/test/Target/Wasm/ne.mlir new file mode 100644 index 0000000..331df75 --- /dev/null +++ b/mlir/test/Target/Wasm/ne.mlir @@ -0,0 +1,52 @@ +// RUN: yaml2obj %S/inputs/ne.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s +/* Source code used to create this test: +(module + (func $ne_i32 (result i32) + i32.const 12 + i32.const 50 + i32.ne + ) + + (func $ne_i64 (result i32) + i64.const 20 + i64.const 5 + i64.ne + ) + + (func $ne_f32 (result i32) + f32.const 5 + f32.const 14 + f32.ne + ) + + (func $ne_f64 (result i32) + f64.const 17 + f64.const 0 + f64.ne + ) +) +*/ + +// CHECK-LABEL: wasmssa.func @func_0() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32 +// CHECK: %[[VAL_2:.*]] = wasmssa.ne %[[VAL_0]] %[[VAL_1]] : i32 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 + +// CHECK-LABEL: wasmssa.func @func_1() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 20 : i64 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 5 : i64 +// CHECK: %[[VAL_2:.*]] = wasmssa.ne %[[VAL_0]] %[[VAL_1]] : i64 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 + +// CHECK-LABEL: wasmssa.func @func_2() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f32 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f32 +// CHECK: %[[VAL_2:.*]] = wasmssa.ne %[[VAL_0]] %[[VAL_1]] : f32 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 + +// CHECK-LABEL: wasmssa.func @func_3() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.700000e+01 : f64 +// CHECK: %[[VAL_1:.*]] = wasmssa.const 0.000000e+00 : f64 +// CHECK: %[[VAL_2:.*]] = wasmssa.ne %[[VAL_0]] %[[VAL_1]] : f64 -> i32 +// CHECK: wasmssa.return %[[VAL_2]] : i32 diff --git a/mlir/test/Target/Wasm/neg.mlir b/mlir/test/Target/Wasm/neg.mlir index 5811ab50..dae8ee5 100644 --- a/mlir/test/Target/Wasm/neg.mlir +++ b/mlir/test/Target/Wasm/neg.mlir @@ -12,12 +12,12 @@ ) */ -// CHECK-LABEL: wasmssa.func @neg_f32() -> f32 { +// CHECK-LABEL: wasmssa.func exported @neg_f32() -> f32 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f32 // CHECK: %[[VAL_1:.*]] = wasmssa.neg %[[VAL_0]] : f32 // CHECK: wasmssa.return %[[VAL_1]] : f32 -// CHECK-LABEL: wasmssa.func @neg_f64() -> f64 { +// CHECK-LABEL: wasmssa.func exported @neg_f64() -> f64 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f64 // CHECK: %[[VAL_1:.*]] = wasmssa.neg %[[VAL_0]] : f64 // CHECK: wasmssa.return %[[VAL_1]] : f64 diff --git a/mlir/test/Target/Wasm/or.mlir b/mlir/test/Target/Wasm/or.mlir index 521f2ba..be0b3d7 100644 --- a/mlir/test/Target/Wasm/or.mlir +++ b/mlir/test/Target/Wasm/or.mlir @@ -14,13 +14,13 @@ ) */ -// CHECK-LABEL: wasmssa.func @or_i32() -> i32 { +// CHECK-LABEL: wasmssa.func exported @or_i32() -> i32 { // CHECK: %0 = wasmssa.const 10 : i32 // CHECK: %1 = wasmssa.const 3 : i32 // CHECK: %2 = wasmssa.or %0 %1 : i32 // CHECK: wasmssa.return %2 : i32 -// CHECK-LABEL: wasmssa.func @or_i64() -> i64 { +// CHECK-LABEL: wasmssa.func exported @or_i64() -> i64 { // CHECK: %0 = wasmssa.const 10 : i64 // CHECK: %1 = wasmssa.const 3 : i64 // CHECK: %2 = wasmssa.or %0 %1 : i64 diff --git a/mlir/test/Target/Wasm/popcnt.mlir b/mlir/test/Target/Wasm/popcnt.mlir index 235333a..bfaa8eb 100644 --- a/mlir/test/Target/Wasm/popcnt.mlir +++ b/mlir/test/Target/Wasm/popcnt.mlir @@ -14,12 +14,12 @@ ) */ -// CHECK-LABEL: wasmssa.func @popcnt_i32() -> i32 { +// CHECK-LABEL: wasmssa.func exported @popcnt_i32() -> i32 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i32 // CHECK: %[[VAL_1:.*]] = wasmssa.popcnt %[[VAL_0]] : i32 // CHECK: wasmssa.return %[[VAL_1]] : i32 -// CHECK-LABEL: wasmssa.func @popcnt_i64() -> i64 { +// CHECK-LABEL: wasmssa.func exported @popcnt_i64() -> i64 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 10 : i64 // CHECK: %[[VAL_1:.*]] = wasmssa.popcnt %[[VAL_0]] : i64 // CHECK: wasmssa.return %[[VAL_1]] : i64 diff --git a/mlir/test/Target/Wasm/promote.mlir b/mlir/test/Target/Wasm/promote.mlir new file mode 100644 index 0000000..44c31b6 --- /dev/null +++ b/mlir/test/Target/Wasm/promote.mlir @@ -0,0 +1,14 @@ +// RUN: yaml2obj %S/inputs/promote.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s + +/* Source code used to generate this test: +(module + (func $main (result f64) + f32.const 10.5 + f64.promote_f32 + ) +)*/ + +// CHECK-LABEL: wasmssa.func @func_0() -> f64 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.050000e+01 : f32 +// CHECK: %[[VAL_1:.*]] = wasmssa.promote %[[VAL_0]] : f32 to f64 +// CHECK: wasmssa.return %[[VAL_1]] : f64 diff --git a/mlir/test/Target/Wasm/reinterpret.mlir b/mlir/test/Target/Wasm/reinterpret.mlir new file mode 100644 index 0000000..574d13f --- /dev/null +++ b/mlir/test/Target/Wasm/reinterpret.mlir @@ -0,0 +1,46 @@ +// RUN: yaml2obj %S/inputs/reinterpret.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s + +/* +Test generated from: +(module + (func (export "i32.reinterpret_f32") (result i32) + f32.const -1 + i32.reinterpret_f32 + ) + + (func (export "i64.reinterpret_f64") (result i64) + f64.const -1 + i64.reinterpret_f64 + ) + + (func (export "f32.reinterpret_i32") (result f32) + i32.const -1 + f32.reinterpret_i32 + ) + + (func (export "f64.reinterpret_i64") (result f64) + i64.const -1 + f64.reinterpret_i64 + ) +) +*/ + +// CHECK-LABEL: wasmssa.func exported @i32.reinterpret_f32() -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const -1.000000e+00 : f32 +// CHECK: %[[VAL_1:.*]] = wasmssa.reinterpret %[[VAL_0]] : f32 as i32 +// CHECK: wasmssa.return %[[VAL_1]] : i32 + +// CHECK-LABEL: wasmssa.func exported @i64.reinterpret_f64() -> i64 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const -1.000000e+00 : f64 +// CHECK: %[[VAL_1:.*]] = wasmssa.reinterpret %[[VAL_0]] : f64 as i64 +// CHECK: wasmssa.return %[[VAL_1]] : i64 + +// CHECK-LABEL: wasmssa.func exported @f32.reinterpret_i32() -> f32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const -1 : i32 +// CHECK: %[[VAL_1:.*]] = wasmssa.reinterpret %[[VAL_0]] : i32 as f32 +// CHECK: wasmssa.return %[[VAL_1]] : f32 + +// CHECK-LABEL: wasmssa.func exported @f64.reinterpret_i64() -> f64 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const -1 : i64 +// CHECK: %[[VAL_1:.*]] = wasmssa.reinterpret %[[VAL_0]] : i64 as f64 +// CHECK: wasmssa.return %[[VAL_1]] : f64 diff --git a/mlir/test/Target/Wasm/rem.mlir b/mlir/test/Target/Wasm/rem.mlir index b19b8d9..16c9c78 100644 --- a/mlir/test/Target/Wasm/rem.mlir +++ b/mlir/test/Target/Wasm/rem.mlir @@ -24,28 +24,28 @@ ) */ -// CHECK-LABEL: wasmssa.func @rem_u_i32() -> i32 { +// CHECK-LABEL: wasmssa.func exported @rem_u_i32() -> i32 { // CHECK: %0 = wasmssa.const 10 : i32 // CHECK: %1 = wasmssa.const 3 : i32 // CHECK: %2 = wasmssa.rem_ui %0 %1 : i32 // CHECK: wasmssa.return %2 : i32 // CHECK: } -// CHECK-LABEL: wasmssa.func @rem_u_i64() -> i64 { +// CHECK-LABEL: wasmssa.func exported @rem_u_i64() -> i64 { // CHECK: %0 = wasmssa.const 10 : i64 // CHECK: %1 = wasmssa.const 3 : i64 // CHECK: %2 = wasmssa.rem_ui %0 %1 : i64 // CHECK: wasmssa.return %2 : i64 // CHECK: } -// CHECK-LABEL: wasmssa.func @rem_s_i32() -> i32 { +// CHECK-LABEL: wasmssa.func exported @rem_s_i32() -> i32 { // CHECK: %0 = wasmssa.const 10 : i32 // CHECK: %1 = wasmssa.const 3 : i32 // CHECK: %2 = wasmssa.rem_si %0 %1 : i32 // CHECK: wasmssa.return %2 : i32 // CHECK: } -// CHECK-LABEL: wasmssa.func @rem_s_i64() -> i64 { +// CHECK-LABEL: wasmssa.func exported @rem_s_i64() -> i64 { // CHECK: %0 = wasmssa.const 10 : i64 // CHECK: %1 = wasmssa.const 3 : i64 // CHECK: %2 = wasmssa.rem_si %0 %1 : i64 diff --git a/mlir/test/Target/Wasm/rotl.mlir b/mlir/test/Target/Wasm/rotl.mlir index ec573554..4c2e5af 100644 --- a/mlir/test/Target/Wasm/rotl.mlir +++ b/mlir/test/Target/Wasm/rotl.mlir @@ -14,13 +14,13 @@ ) */ -// CHECK-LABEL: wasmssa.func @rotl_i32() -> i32 { +// CHECK-LABEL: wasmssa.func exported @rotl_i32() -> i32 { // CHECK: %0 = wasmssa.const 10 : i32 // CHECK: %1 = wasmssa.const 3 : i32 // CHECK: %2 = wasmssa.rotl %0 by %1 bits : i32 // CHECK: wasmssa.return %2 : i32 -// CHECK-LABEL: wasmssa.func @rotl_i64() -> i64 { +// CHECK-LABEL: wasmssa.func exported @rotl_i64() -> i64 { // CHECK: %0 = wasmssa.const 10 : i64 // CHECK: %1 = wasmssa.const 3 : i64 // CHECK: %2 = wasmssa.rotl %0 by %1 bits : i64 diff --git a/mlir/test/Target/Wasm/rotr.mlir b/mlir/test/Target/Wasm/rotr.mlir index 5618b43..ec403d0 100644 --- a/mlir/test/Target/Wasm/rotr.mlir +++ b/mlir/test/Target/Wasm/rotr.mlir @@ -14,13 +14,13 @@ ) */ -// CHECK-LABEL: wasmssa.func @rotr_i32() -> i32 { +// CHECK-LABEL: wasmssa.func exported @rotr_i32() -> i32 { // CHECK: %0 = wasmssa.const 10 : i32 // CHECK: %1 = wasmssa.const 3 : i32 // CHECK: %2 = wasmssa.rotr %0 by %1 bits : i32 // CHECK: wasmssa.return %2 : i32 -// CHECK-LABEL: wasmssa.func @rotr_i64() -> i64 { +// CHECK-LABEL: wasmssa.func exported @rotr_i64() -> i64 { // CHECK: %0 = wasmssa.const 10 : i64 // CHECK: %1 = wasmssa.const 3 : i64 // CHECK: %2 = wasmssa.rotr %0 by %1 bits : i64 diff --git a/mlir/test/Target/Wasm/rounding.mlir b/mlir/test/Target/Wasm/rounding.mlir new file mode 100644 index 0000000..947637e --- /dev/null +++ b/mlir/test/Target/Wasm/rounding.mlir @@ -0,0 +1,50 @@ +// RUN: yaml2obj %S/inputs/rounding.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s +/* Source code used to create this test: +(module + (func $ceil_f64 (result f64) + f64.const -12.1 + f64.ceil + ) + (func $ceil_f32 (result f32) + f32.const 1.618 + f32.ceil + ) + (func $floor_f64 (result f64) + f64.const -12.1 + f64.floor + ) + (func $floor_f32 (result f32) + f32.const 1.618 + f32.floor + ) +*/ + +// CHECK-LABEL: wasmssa.func @func_0() -> f64 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const -1.210000e+01 : f64 +// CHECK: %[[VAL_1:.*]] = wasmssa.ceil %[[VAL_0]] : f64 +// CHECK: wasmssa.return %[[VAL_1]] : f64 + +// CHECK-LABEL: wasmssa.func @func_1() -> f32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.618000e+00 : f32 +// CHECK: %[[VAL_1:.*]] = wasmssa.ceil %[[VAL_0]] : f32 +// CHECK: wasmssa.return %[[VAL_1]] : f32 + +// CHECK-LABEL: wasmssa.func @func_2() -> f64 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const -1.210000e+01 : f64 +// CHECK: %[[VAL_1:.*]] = wasmssa.floor %[[VAL_0]] : f64 +// CHECK: wasmssa.return %[[VAL_1]] : f64 + +// CHECK-LABEL: wasmssa.func @func_3() -> f32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.618000e+00 : f32 +// CHECK: %[[VAL_1:.*]] = wasmssa.floor %[[VAL_0]] : f32 +// CHECK: wasmssa.return %[[VAL_1]] : f32 + +// CHECK-LABEL: wasmssa.func @func_4() -> f64 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const -1.210000e+01 : f64 +// CHECK: %[[VAL_1:.*]] = wasmssa.trunc %[[VAL_0]] : f64 +// CHECK: wasmssa.return %[[VAL_1]] : f64 + +// CHECK-LABEL: wasmssa.func @func_5() -> f32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.const 1.618000e+00 : f32 +// CHECK: %[[VAL_1:.*]] = wasmssa.trunc %[[VAL_0]] : f32 +// CHECK: wasmssa.return %[[VAL_1]] : f32 diff --git a/mlir/test/Target/Wasm/shl.mlir b/mlir/test/Target/Wasm/shl.mlir index f2bdd57..1363112 100644 --- a/mlir/test/Target/Wasm/shl.mlir +++ b/mlir/test/Target/Wasm/shl.mlir @@ -14,13 +14,13 @@ ) */ -// CHECK-LABEL: wasmssa.func @shl_i32() -> i32 { +// CHECK-LABEL: wasmssa.func exported @shl_i32() -> i32 { // CHECK: %0 = wasmssa.const 10 : i32 // CHECK: %1 = wasmssa.const 3 : i32 // CHECK: %2 = wasmssa.shl %0 by %1 bits : i32 // CHECK: wasmssa.return %2 : i32 -// CHECK-LABEL: wasmssa.func @shl_i64() -> i64 { +// CHECK-LABEL: wasmssa.func exported @shl_i64() -> i64 { // CHECK: %0 = wasmssa.const 10 : i64 // CHECK: %1 = wasmssa.const 3 : i64 // CHECK: %2 = wasmssa.shl %0 by %1 bits : i64 diff --git a/mlir/test/Target/Wasm/shr_s.mlir b/mlir/test/Target/Wasm/shr_s.mlir index 247d9be..da1a38f 100644 --- a/mlir/test/Target/Wasm/shr_s.mlir +++ b/mlir/test/Target/Wasm/shr_s.mlir @@ -14,13 +14,13 @@ ) */ -// CHECK-LABEL: wasmssa.func @shr_s_i32() -> i32 { +// CHECK-LABEL: wasmssa.func exported @shr_s_i32() -> i32 { // CHECK: %0 = wasmssa.const 10 : i32 // CHECK: %1 = wasmssa.const 3 : i32 // CHECK: %2 = wasmssa.shr_s %0 by %1 bits : i32 // CHECK: wasmssa.return %2 : i32 -// CHECK-LABEL: wasmssa.func @shr_s_i64() -> i64 { +// CHECK-LABEL: wasmssa.func exported @shr_s_i64() -> i64 { // CHECK: %0 = wasmssa.const 10 : i64 // CHECK: %1 = wasmssa.const 3 : i64 // CHECK: %2 = wasmssa.shr_s %0 by %1 bits : i64 diff --git a/mlir/test/Target/Wasm/shr_u.mlir b/mlir/test/Target/Wasm/shr_u.mlir index 9a79eed..2991c2a 100644 --- a/mlir/test/Target/Wasm/shr_u.mlir +++ b/mlir/test/Target/Wasm/shr_u.mlir @@ -14,13 +14,13 @@ ) */ -// CHECK-LABEL: wasmssa.func @shr_u_i32() -> i32 { +// CHECK-LABEL: wasmssa.func exported @shr_u_i32() -> i32 { // CHECK: %0 = wasmssa.const 10 : i32 // CHECK: %1 = wasmssa.const 3 : i32 // CHECK: %2 = wasmssa.shr_u %0 by %1 bits : i32 // CHECK: wasmssa.return %2 : i32 -// CHECK-LABEL: wasmssa.func @shr_u_i64() -> i64 { +// CHECK-LABEL: wasmssa.func exported @shr_u_i64() -> i64 { // CHECK: %0 = wasmssa.const 10 : i64 // CHECK: %1 = wasmssa.const 3 : i64 // CHECK: %2 = wasmssa.shr_u %0 by %1 bits : i64 diff --git a/mlir/test/Target/Wasm/sqrt.mlir b/mlir/test/Target/Wasm/sqrt.mlir index 77444ad..6b968d6 100644 --- a/mlir/test/Target/Wasm/sqrt.mlir +++ b/mlir/test/Target/Wasm/sqrt.mlir @@ -12,12 +12,12 @@ ) */ -// CHECK-LABEL: wasmssa.func @sqrt_f32() -> f32 { +// CHECK-LABEL: wasmssa.func exported @sqrt_f32() -> f32 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f32 // CHECK: %[[VAL_1:.*]] = wasmssa.sqrt %[[VAL_0]] : f32 // CHECK: wasmssa.return %[[VAL_1]] : f32 -// CHECK-LABEL: wasmssa.func @sqrt_f64() -> f64 { +// CHECK-LABEL: wasmssa.func exported @sqrt_f64() -> f64 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 1.000000e+01 : f64 // CHECK: %[[VAL_1:.*]] = wasmssa.sqrt %[[VAL_0]] : f64 // CHECK: wasmssa.return %[[VAL_1]] : f64 diff --git a/mlir/test/Target/Wasm/sub.mlir b/mlir/test/Target/Wasm/sub.mlir index b9c6caf..5b242f4 100644 --- a/mlir/test/Target/Wasm/sub.mlir +++ b/mlir/test/Target/Wasm/sub.mlir @@ -27,25 +27,25 @@ ) */ -// CHECK-LABEL: wasmssa.func nested @func_0() -> i32 { +// CHECK-LABEL: wasmssa.func @func_0() -> i32 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 12 : i32 // CHECK: %[[VAL_1:.*]] = wasmssa.const 50 : i32 // CHECK: %[[VAL_2:.*]] = wasmssa.sub %[[VAL_0]] %[[VAL_1]] : i32 // CHECK: wasmssa.return %[[VAL_2]] : i32 -// CHECK-LABEL: wasmssa.func nested @func_1() -> i64 { +// CHECK-LABEL: wasmssa.func @func_1() -> i64 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 20 : i64 // CHECK: %[[VAL_1:.*]] = wasmssa.const 5 : i64 // CHECK: %[[VAL_2:.*]] = wasmssa.sub %[[VAL_0]] %[[VAL_1]] : i64 // CHECK: wasmssa.return %[[VAL_2]] : i64 -// CHECK-LABEL: wasmssa.func nested @func_2() -> f32 { +// CHECK-LABEL: wasmssa.func @func_2() -> f32 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 5.000000e+00 : f32 // CHECK: %[[VAL_1:.*]] = wasmssa.const 1.400000e+01 : f32 // CHECK: %[[VAL_2:.*]] = wasmssa.sub %[[VAL_0]] %[[VAL_1]] : f32 // CHECK: wasmssa.return %[[VAL_2]] : f32 -// CHECK-LABEL: wasmssa.func nested @func_3() -> f64 { +// CHECK-LABEL: wasmssa.func @func_3() -> f64 { // CHECK: %[[VAL_0:.*]] = wasmssa.const 1.700000e+01 : f64 // CHECK: %[[VAL_1:.*]] = wasmssa.const 0.000000e+00 : f64 // CHECK: %[[VAL_2:.*]] = wasmssa.sub %[[VAL_0]] %[[VAL_1]] : f64 diff --git a/mlir/test/Target/Wasm/wrap.mlir b/mlir/test/Target/Wasm/wrap.mlir new file mode 100644 index 0000000..1266758 --- /dev/null +++ b/mlir/test/Target/Wasm/wrap.mlir @@ -0,0 +1,15 @@ +// RUN: yaml2obj %S/inputs/wrap.yaml.wasm -o - | mlir-translate --import-wasm | FileCheck %s +/* Source code used to create this test: +(module + (func (export "i64_wrap") (param $in i64) (result i32) + local.get $in + i32.wrap_i64 + ) +) +*/ + +// CHECK-LABEL: wasmssa.func exported @i64_wrap( +// CHECK-SAME: %[[ARG0:.*]]: !wasmssa<local ref to i64>) -> i32 { +// CHECK: %[[VAL_0:.*]] = wasmssa.local_get %[[ARG0]] : ref to i64 +// CHECK: %[[VAL_1:.*]] = wasmssa.wrap %[[VAL_0]] : i64 to i32 +// CHECK: wasmssa.return %[[VAL_1]] : i32 diff --git a/mlir/test/Target/Wasm/xor.mlir b/mlir/test/Target/Wasm/xor.mlir index 94691de..56407db 100644 --- a/mlir/test/Target/Wasm/xor.mlir +++ b/mlir/test/Target/Wasm/xor.mlir @@ -14,13 +14,13 @@ ) */ -// CHECK-LABEL: wasmssa.func @xor_i32() -> i32 { +// CHECK-LABEL: wasmssa.func exported @xor_i32() -> i32 { // CHECK: %0 = wasmssa.const 10 : i32 // CHECK: %1 = wasmssa.const 3 : i32 // CHECK: %2 = wasmssa.xor %0 %1 : i32 // CHECK: wasmssa.return %2 : i32 -// CHECK-LABEL: wasmssa.func @xor_i64() -> i64 { +// CHECK-LABEL: wasmssa.func exported @xor_i64() -> i64 { // CHECK: %0 = wasmssa.const 10 : i64 // CHECK: %1 = wasmssa.const 3 : i64 // CHECK: %2 = wasmssa.xor %0 %1 : i64 diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp index 727c84c..8c5c8e8 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -276,10 +276,8 @@ void TestLinalgTransforms::runOnOperation() { Operation *consumer = opOperand->getOwner(); // If we have a pack/unpack consumer and a producer that has multiple // uses, do not apply the folding patterns. - if (isa<linalg::PackOp, linalg::UnPackOp>(consumer) && - isa<TilingInterface>(producer) && !producer->hasOneUse()) - return false; - return true; + return !(isa<linalg::PackOp, linalg::UnPackOp>(consumer) && + isa<TilingInterface>(producer) && !producer->hasOneUse()); }; applyFoldIntoPackAndUnpackPatterns(rootOp, controlFn); } diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp index 97fc699..496f18b 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -938,10 +938,10 @@ public: // These are automatically generated by ODS but are not used as the Transform // dialect uses a different dispatch mechanism to support dialect extensions. -LLVM_ATTRIBUTE_UNUSED static OptionalParseResult +[[maybe_unused]] static OptionalParseResult generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value); -LLVM_ATTRIBUTE_UNUSED static LogicalResult -generatedTypePrinter(Type def, AsmPrinter &printer); +[[maybe_unused]] static LogicalResult generatedTypePrinter(Type def, + AsmPrinter &printer); #define GET_TYPEDEF_CLASSES #include "TestTransformDialectExtensionTypes.cpp.inc" diff --git a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll index 4e869e5..4be30d8 100644 --- a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll +++ b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll @@ -28,7 +28,7 @@ // CHECK: operation "test.op3" // CHECK: )mlir", context), std::forward<ConfigsT>(configs)...) -// CHECK: static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns(::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) { +// CHECK{LITERAL}: [[maybe_unused]] static void populateGeneratedPDLLPatterns(::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) { // CHECK-NEXT: patterns.add<GeneratedPDLLPattern0>(patterns.getContext(), configs...); // CHECK-NEXT: patterns.add<NamedPattern>(patterns.getContext(), configs...); // CHECK-NEXT: patterns.add<GeneratedPDLLPattern1>(patterns.getContext(), configs...); diff --git a/mlir/test/python/dialects/gpu/dialect.py b/mlir/test/python/dialects/gpu/dialect.py index 26ee9f3..66c4018 100644 --- a/mlir/test/python/dialects/gpu/dialect.py +++ b/mlir/test/python/dialects/gpu/dialect.py @@ -1,6 +1,7 @@ # RUN: %PYTHON %s | FileCheck %s from mlir.ir import * +import mlir.ir as ir import mlir.dialects.gpu as gpu import mlir.dialects.gpu.passes from mlir.passmanager import * @@ -64,3 +65,95 @@ def testObjectAttr(): # CHECK: #gpu.object<#nvvm.target, kernels = <[#gpu.kernel_metadata<"kernel", () -> ()>]>, "BC\C0\DE5\14\00\00\05\00\00\00b\0C0$MY\BEf"> print(o) assert o.kernels == kernelTable + + +# CHECK-LABEL: testGPUFuncOp +@run +def testGPUFuncOp(): + assert gpu.GPUFuncOp.__doc__ is not None + module = Module.create() + with InsertionPoint(module.body): + gpu_module_name = StringAttr.get("gpu_module") + gpumodule = gpu.GPUModuleOp(gpu_module_name) + block = gpumodule.bodyRegion.blocks.append() + + def builder(func: gpu.GPUFuncOp) -> None: + gpu.GlobalIdOp(gpu.Dimension.x) + gpu.ReturnOp([]) + + with InsertionPoint(block): + name = StringAttr.get("kernel0") + func_type = ir.FunctionType.get(inputs=[], results=[]) + type_attr = TypeAttr.get(func_type) + func = gpu.GPUFuncOp(type_attr, name) + func.attributes["sym_name"] = name + func.attributes["gpu.kernel"] = UnitAttr.get() + + try: + func.entry_block + assert False, "Expected RuntimeError" + except RuntimeError as e: + assert ( + str(e) + == "Entry block does not exist for kernel0. Do you need to call the add_entry_block() method on this GPUFuncOp?" + ) + + block = func.add_entry_block() + with InsertionPoint(block): + builder(func) + + try: + func.add_entry_block() + assert False, "Expected RuntimeError" + except RuntimeError as e: + assert str(e) == "Entry block already exists for kernel0" + + func = gpu.GPUFuncOp( + func_type, + sym_name="kernel1", + kernel=True, + body_builder=builder, + known_block_size=[1, 2, 3], + known_grid_size=DenseI32ArrayAttr.get([4, 5, 6]), + ) + + assert func.name.value == "kernel1" + assert func.function_type.value == func_type + assert func.arg_attrs == None + assert func.res_attrs == None + assert func.arguments == [] + assert func.entry_block == func.body.blocks[0] + assert func.is_kernel + assert func.known_block_size == DenseI32ArrayAttr.get( + [1, 2, 3] + ), func.known_block_size + assert func.known_grid_size == DenseI32ArrayAttr.get( + [4, 5, 6] + ), func.known_grid_size + + func = gpu.GPUFuncOp( + func_type, + sym_name="non_kernel_func", + body_builder=builder, + ) + assert not func.is_kernel + assert func.known_block_size is None + assert func.known_grid_size is None + + print(module) + + # CHECK: gpu.module @gpu_module + # CHECK: gpu.func @kernel0() kernel { + # CHECK: %[[VAL_0:.*]] = gpu.global_id x + # CHECK: gpu.return + # CHECK: } + # CHECK: gpu.func @kernel1() kernel attributes + # CHECK-SAME: known_block_size = array<i32: 1, 2, 3> + # CHECK-SAME: known_grid_size = array<i32: 4, 5, 6> + # CHECK: %[[VAL_0:.*]] = gpu.global_id x + # CHECK: gpu.return + # CHECK: } + # CHECK: gpu.func @non_kernel_func() { + # CHECK: %[[VAL_0:.*]] = gpu.global_id x + # CHECK: gpu.return + # CHECK: } diff --git a/mlir/test/python/dialects/openacc.py b/mlir/test/python/dialects/openacc.py new file mode 100644 index 0000000..8f2142a --- /dev/null +++ b/mlir/test/python/dialects/openacc.py @@ -0,0 +1,171 @@ +# RUN: %PYTHON %s | FileCheck %s +from unittest import result +from mlir.ir import ( + Context, + FunctionType, + Location, + Module, + InsertionPoint, + IntegerType, + IndexType, + MemRefType, + F32Type, + Block, + ArrayAttr, + Attribute, + UnitAttr, + StringAttr, + DenseI32ArrayAttr, + ShapedType, +) +from mlir.dialects import openacc, func, arith, memref +from mlir.extras import types + + +def run(f): + print("\n// TEST:", f.__name__) + with Context(), Location.unknown(): + f() + return f + + +@run +def testParallelMemcpy(): + module = Module.create() + + dynamic = ShapedType.get_dynamic_size() + memref_f32_1d_any = MemRefType.get([dynamic], types.f32()) + + with InsertionPoint(module.body): + function_type = FunctionType.get( + [memref_f32_1d_any, memref_f32_1d_any, types.i64()], [] + ) + f = func.FuncOp( + type=function_type, + name="memcpy_idiom", + ) + f.attributes["sym_visibility"] = StringAttr.get("public") + + with InsertionPoint(f.add_entry_block()): + c1024 = arith.ConstantOp(types.i32(), 1024) + c128 = arith.ConstantOp(types.i32(), 128) + + arg0, arg1, arg2 = f.arguments + + copied = openacc.copyin( + acc_var=arg0.type, + var=arg0, + var_type=types.f32(), + bounds=[], + async_operands=[], + implicit=False, + structured=True, + ) + created = openacc.create_( + acc_var=arg1.type, + var=arg1, + var_type=types.f32(), + bounds=[], + async_operands=[], + implicit=False, + structured=True, + ) + + parallel_op = openacc.ParallelOp( + asyncOperands=[], + waitOperands=[], + numGangs=[c1024], + numWorkers=[], + vectorLength=[c128], + reductionOperands=[], + privateOperands=[], + firstprivateOperands=[], + dataClauseOperands=[], + ) + + # Set required device_type and segment attributes to satisfy verifier + acc_device_none = ArrayAttr.get([Attribute.parse("#acc.device_type<none>")]) + parallel_op.numGangsDeviceType = acc_device_none + parallel_op.numGangsSegments = DenseI32ArrayAttr.get([1]) + parallel_op.vectorLengthDeviceType = acc_device_none + + parallel_block = Block.create_at_start(parent=parallel_op.region, arg_types=[]) + + with InsertionPoint(parallel_block): + c0 = arith.ConstantOp(types.i64(), 0) + c1 = arith.ConstantOp(types.i64(), 1) + + loop_op = openacc.LoopOp( + results_=[], + lowerbound=[c0], + upperbound=[f.arguments[2]], + step=[c1], + gangOperands=[], + workerNumOperands=[], + vectorOperands=[], + tileOperands=[], + cacheOperands=[], + privateOperands=[], + reductionOperands=[], + firstprivateOperands=[], + ) + + # Set loop attributes: gang and independent on device_type<none> + acc_device_none = ArrayAttr.get([Attribute.parse("#acc.device_type<none>")]) + loop_op.gang = acc_device_none + loop_op.independent = acc_device_none + + loop_block = Block.create_at_start( + parent=loop_op.region, arg_types=[types.i64()] + ) + + with InsertionPoint(loop_block): + idx = arith.index_cast(out=IndexType.get(), in_=loop_block.arguments[0]) + val = memref.load(memref=copied, indices=[idx]) + memref.store(value=val, memref=created, indices=[idx]) + openacc.YieldOp([]) + + openacc.YieldOp([]) + + deleted = openacc.delete( + acc_var=copied, + bounds=[], + async_operands=[], + implicit=False, + structured=True, + ) + copied = openacc.copyout( + acc_var=created, + var=arg1, + var_type=types.f32(), + bounds=[], + async_operands=[], + implicit=False, + structured=True, + ) + func.ReturnOp([]) + + print(module) + + # CHECK: TEST: testParallelMemcpy + # CHECK-LABEL: func.func public @memcpy_idiom( + # CHECK-SAME: %[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: i64) { + # CHECK: %[[CONSTANT_0:.*]] = arith.constant 1024 : i32 + # CHECK: %[[CONSTANT_1:.*]] = arith.constant 128 : i32 + # CHECK: %[[COPYIN_0:.*]] = acc.copyin varPtr(%[[ARG0]] : memref<?xf32>) -> memref<?xf32> + # CHECK: %[[CREATE_0:.*]] = acc.create varPtr(%[[ARG1]] : memref<?xf32>) -> memref<?xf32> + # CHECK: acc.parallel num_gangs({%[[CONSTANT_0]] : i32}) vector_length(%[[CONSTANT_1]] : i32) { + # CHECK: %[[CONSTANT_2:.*]] = arith.constant 0 : i64 + # CHECK: %[[CONSTANT_3:.*]] = arith.constant 1 : i64 + # CHECK: acc.loop gang control(%[[VAL_0:.*]] : i64) = (%[[CONSTANT_2]] : i64) to (%[[ARG2]] : i64) step (%[[CONSTANT_3]] : i64) { + # CHECK: %[[INDEX_CAST_0:.*]] = arith.index_cast %[[VAL_0]] : i64 to index + # CHECK: %[[LOAD_0:.*]] = memref.load %[[COPYIN_0]]{{\[}}%[[INDEX_CAST_0]]] : memref<?xf32> + # CHECK: memref.store %[[LOAD_0]], %[[CREATE_0]]{{\[}}%[[INDEX_CAST_0]]] : memref<?xf32> + # CHECK: acc.yield + # CHECK: } attributes {independent = [#acc.device_type<none>]} + # CHECK: acc.yield + # CHECK: } + # CHECK: acc.delete accPtr(%[[COPYIN_0]] : memref<?xf32>) + # CHECK: acc.copyout accPtr(%[[CREATE_0]] : memref<?xf32>) to varPtr(%[[ARG1]] : memref<?xf32>) + # CHECK: return + # CHECK: } diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index cb4cfc8c..1d4ede1 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -569,12 +569,30 @@ def testOperationAttributes(): # CHECK: Attribute value b'text' print(f"Attribute value {sattr.value_bytes}") + # Python dict-style iteration # We don't know in which order the attributes are stored. - # CHECK-DAG: NamedAttribute(dependent="text") - # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64) - # CHECK-DAG: NamedAttribute(some.attribute=1 : i8) - for attr in op.attributes: - print(str(attr)) + # CHECK-DAG: dependent + # CHECK-DAG: other.attribute + # CHECK-DAG: some.attribute + for name in op.attributes: + print(name) + + # Basic dict-like introspection + # CHECK: True + print("some.attribute" in op.attributes) + # CHECK: False + print("missing" in op.attributes) + # CHECK: Keys: ['dependent', 'other.attribute', 'some.attribute'] + print("Keys:", sorted(op.attributes.keys())) + # CHECK: Values count 3 + print("Values count", len(op.attributes.values())) + # CHECK: Items count 3 + print("Items count", len(op.attributes.items())) + + # Dict() conversion test + d = {k: v.value for k, v in dict(op.attributes).items()} + # CHECK: Dict mapping {'dependent': 'text', 'other.attribute': 3.0, 'some.attribute': 1} + print("Dict mapping", d) # Check that exceptions are raised as expected. try: diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp index 96af14d..11a2db4 100644 --- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp @@ -416,7 +416,7 @@ static void emitOneEnumToConversion(const Record *record, raw_ostream &os) { // Emit the function converting the enum attribute to its LLVM counterpart. os << formatv( - "static LLVM_ATTRIBUTE_UNUSED {0} convert{1}ToLLVM({2}::{1} value) {{\n", + "[[maybe_unused]] static {0} convert{1}ToLLVM({2}::{1} value) {{\n", llvmClass, cppClassName, cppNamespace); os << " switch (value) {\n"; @@ -444,7 +444,7 @@ static void emitOneCEnumToConversion(const Record *record, raw_ostream &os) { StringRef cppNamespace = enumAttr.getCppNamespace(); // Emit the function converting the enum attribute to its LLVM counterpart. - os << formatv("static LLVM_ATTRIBUTE_UNUSED int64_t " + os << formatv("[[maybe_unused]] static int64_t " "convert{0}ToLLVM({1}::{0} value) {{\n", cppClassName, cppNamespace); os << " switch (value) {\n"; @@ -474,7 +474,7 @@ static void emitOneEnumFromConversion(const Record *record, raw_ostream &os) { StringRef cppNamespace = enumInfo.getCppNamespace(); // Emit the function converting the enum attribute from its LLVM counterpart. - os << formatv("inline LLVM_ATTRIBUTE_UNUSED {0}::{1} convert{1}FromLLVM({2} " + os << formatv("[[maybe_unused]] inline {0}::{1} convert{1}FromLLVM({2} " "value) {{\n", cppNamespace, cppClassName, llvmClass); os << " switch (value) {\n"; @@ -509,10 +509,9 @@ static void emitOneCEnumFromConversion(const Record *record, raw_ostream &os) { StringRef cppNamespace = enumInfo.getCppNamespace(); // Emit the function converting the enum attribute from its LLVM counterpart. - os << formatv( - "inline LLVM_ATTRIBUTE_UNUSED {0}::{1} convert{1}FromLLVM(int64_t " - "value) {{\n", - cppNamespace, cppClassName); + os << formatv("[[maybe_unused]] inline {0}::{1} convert{1}FromLLVM(int64_t " + "value) {{\n", + cppNamespace, cppClassName); os << " switch (value) {\n"; for (const auto &enumerant : enumInfo.getAllCases()) { diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index daae3c7..3718648 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -4896,7 +4896,7 @@ static void emitOpClassDefs(const RecordKeeper &records, constraintPrefix); os << formatv(opCommentHeader, "Local Utility Method", "Definitions"); staticVerifierEmitter.collectOpConstraints(defs); - staticVerifierEmitter.emitOpConstraints(defs); + staticVerifierEmitter.emitOpConstraints(); // Emit the classes. emitOpClasses(records, defs, os, staticVerifierEmitter, diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 40bc1a9..c3034bb8 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -2120,7 +2120,7 @@ static void emitRewriters(const RecordKeeper &records, raw_ostream &os) { } // Emit function to add the generated matchers to the pattern list. - os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(" + os << "[[maybe_unused]] void populateWithGenerated(" "::mlir::RewritePatternSet &patterns) {\n"; for (const auto &name : rewriterNames) { os << " patterns.add<" << name << ">(patterns.getContext());\n"; diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp index fe26fc1..2a58305 100644 --- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp +++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp @@ -113,8 +113,7 @@ static Match tensorMatch(TensorId tid) { return Match(tid); } static Match synZeroMatch() { return Match(); } #define IMPL_BINOP_PATTERN(OP, KIND) \ - LLVM_ATTRIBUTE_UNUSED static Match OP##Match(const Match &e0, \ - const Match &e1) { \ + [[maybe_unused]] static Match OP##Match(const Match &e0, const Match &e1) { \ return Match(KIND, e0, e1); \ } FOREVERY_BINOP(IMPL_BINOP_PATTERN) |