diff options
author | Krzysztof Drewniak <krzysdrewniak@gmail.com> | 2024-12-23 22:01:09 -0600 |
---|---|---|
committer | Krzysztof Drewniak <krzysdrewniak@gmail.com> | 2024-12-23 22:01:34 -0600 |
commit | 0fb33d53c2d8f32776664c6bae31a2f49399031f (patch) | |
tree | cb20f857157728c90200ab17eb0d34ab077be9fe | |
parent | 5d1c98f9a5060f51be27d8a56ddf7ec62b19ab0d (diff) | |
download | llvm-users/krzysz00/enum-prop-properly.zip llvm-users/krzysz00/enum-prop-properly.tar.gz llvm-users/krzysz00/enum-prop-properly.tar.bz2 |
Compiles, LLVM attr syntax is wrong, will fix laterusers/krzysz00/enum-prop-properly
-rw-r--r-- | mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td | 3 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 8 | ||||
-rw-r--r-- | mlir/include/mlir/IR/EnumAttr.td | 60 | ||||
-rw-r--r-- | mlir/include/mlir/IR/ODSSupport.h | 60 | ||||
-rw-r--r-- | mlir/include/mlir/IR/Properties.td | 19 | ||||
-rw-r--r-- | mlir/lib/IR/ODSSupport.cpp | 90 |
6 files changed, 217 insertions, 23 deletions
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td index c08b75d..570d1e8 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td @@ -504,6 +504,9 @@ def LLVM_IntegerOverflowFlagsAttr : let assemblyFormat = "`<` $value `>`"; } +def LLVM_IntegerOverflowFlagsProp : + EnumProp<IntegerOverflowFlags, "none">; + //===----------------------------------------------------------------------===// // FastmathFlags //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index a3aa53b..f0fc966 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -60,7 +60,7 @@ class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName, list<Trait> traits = []> : LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName, !listconcat([DeclareOpInterfaceMethods<IntegerOverflowFlagsInterface>], traits)> { - dag iofArg = (ins EnumProp<"IntegerOverflowFlags", "", "IntegerOverflowFlags::none">:$overflowFlags); + dag iofArg = (ins LLVM_IntegerOverflowFlagsProp:$overflowFlags); let arguments = !con(commonArgs, iofArg); string mlirBuilder = [{ @@ -69,7 +69,7 @@ class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName, $res = op; }]; let assemblyFormat = [{ - $lhs `,` $rhs `` custom<OverflowFlags>($overflowFlags) attr-dict `:` type($res) + $lhs `,` $rhs (`overflow` `` `<` $overflowFlags^ `>`)? attr-dict `:` type($res) }]; string llvmBuilder = "$res = builder.Create" # instName # @@ -558,10 +558,10 @@ class LLVM_CastOpWithOverflowFlag<string mnemonic, string instName, Type type, Type resultType, list<Trait> traits = []> : LLVM_Op<mnemonic, !listconcat([Pure], [DeclareOpInterfaceMethods<IntegerOverflowFlagsInterface>], traits)>, LLVM_Builder<"$res = builder.Create" # instName # "($arg, $_resultType, /*Name=*/\"\", op.hasNoUnsignedWrap(), op.hasNoSignedWrap());"> { - let arguments = (ins type:$arg, EnumProp<"IntegerOverflowFlags", "", "IntegerOverflowFlags::none">:$overflowFlags); + let arguments = (ins type:$arg, LLVM_IntegerOverflowFlagsProp:$overflowFlags); let results = (outs resultType:$res); let builders = [LLVM_OneResultOpBuilder]; - let assemblyFormat = "$arg `` custom<OverflowFlags>($overflowFlags) attr-dict `:` type($arg) `to` type($res)"; + let assemblyFormat = "$arg (`overflow` `` `<` $overflowFlags^ `>`)? attr-dict `:` type($arg) `to` type($res)"; string llvmInstName = instName; string mlirBuilder = [{ auto op = $_builder.create<$_qualCppClassName>( diff --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td index 9fec28f..227ac2c 100644 --- a/mlir/include/mlir/IR/EnumAttr.td +++ b/mlir/include/mlir/IR/EnumAttr.td @@ -10,6 +10,7 @@ #define ENUMATTR_TD include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/Properties.td" //===----------------------------------------------------------------------===// // Enum attribute kinds @@ -417,6 +418,65 @@ class EnumAttr<Dialect dialect, EnumAttrInfo enumInfo, string name = "", let assemblyFormat = "$value"; } +class _enumPropPred<string underlyingType, list<EnumAttrCaseInfo> cases> { + Pred ret = Or<!foreach(c, cases, + CPred<c.value # " == static_cast<" # underlyingType # ">($_self))">)>; +} +class _bitEnumPropPred<string underlyingType, list<EnumAttrCaseInfo> cases> { + Pred ret = CPred<"0 == (static_cast<"# underlyingType # ">($_self) & ~(" + # !foldl(0, cases, acc, c, !or(acc, c.value)) # "))">; +} + +/// A property backed by a C++ enum, analogous to `EnumAttr`. +/// +/// When converted to an attribute or stored in bitcode, the enum is +/// converted to its underlying integer type. +class EnumProp<EnumAttrInfo enumInfo, string defaultCase = ""> : + Property<enumInfo.cppNamespace # "::" # enumInfo.className, enumInfo.summary> { + let description = enumInfo.description; + + EnumAttrInfo enumAttrInfo = enumInfo; + + let convertFromAttribute = + enumInfo.underlyingType # " intStorage;" # [{ + if (failed(convertFromAttribute(intStorage, $_attr, $_diag))) + return ::mlir::failure(); + $_storage = static_cast<}] # storageType # [{>(intStorage); + return ::mlir::success(); + }]; + let convertToAttribute = [{ + return convertToAttribute($_ctxt, static_cast<}] # enumInfo.underlyingType # [{>($_storage)); + }]; + + // The bytecode reader always wants things as uint64_t, so use that + // no matter what the underlying type is. + let writeToMlirBytecode = [{ + $_writer.writeVarInt(static_cast<uint64_t>($_storage)); + }]; + let readFromMlirBytecode = [{ + uint64_t val; + if (failed($_reader.readVarInt(val))) + return ::mlir::failure(); + $_storage = static_cast<}] # storageType # [{>(val); + }]; + + let predicate = + !if(!isa<BitEnumAttr>(enumInfo), + _bitEnumPropPred<enumInfo.underlyingType, enumInfo.enumerants>.ret, + _enumPropPred<enumInfo.underlyingType, enumInfo.enumerants>.ret); + + let optionalParser = [{ + auto value = ::mlir::FieldParser<std::optional<}] # storageType # [{>>::parse($_parser); + if (::mlir::failed(value)) + return ::mlir::failure(); + if (!value->has_value()) + return std::nullopt; + $_storage = **value; + }]; + + let defaultValue = !if(!ne(defaultCase, ""), storageType # "::" # defaultCase, ""); +} + class _symbolToValue<EnumAttrInfo enumAttrInfo, string case> { defvar cases = !filter(iter, enumAttrInfo.enumerants, !eq(iter.str, case)); diff --git a/mlir/include/mlir/IR/ODSSupport.h b/mlir/include/mlir/IR/ODSSupport.h index 25d6f3d..61bd0f8 100644 --- a/mlir/include/mlir/IR/ODSSupport.h +++ b/mlir/include/mlir/IR/ODSSupport.h @@ -43,6 +43,66 @@ convertFromAttribute(int32_t &storage, Attribute attr, /// Convert the provided int32_t to an IntegerAttr attribute. Attribute convertToAttribute(MLIRContext *ctx, int32_t storage); +/// Convert an IntegerAttr attribute to an int16_t, or return an error if the +/// attribute isn't an IntegerAttr. If the optional diagnostic is provided an +/// error message is also emitted. +LogicalResult +convertFromAttribute(int16_t &storage, Attribute attr, + function_ref<InFlightDiagnostic()> emitError); + +/// Convert the provided int16_t to an IntegerAttr attribute. +Attribute convertToAttribute(MLIRContext *ctx, int16_t storage); + +/// Convert an IntegerAttr attribute to an int8_t, or return an error if the +/// attribute isn't an IntegerAttr. If the optional diagnostic is provided an +/// error message is also emitted. +LogicalResult +convertFromAttribute(int8_t &storage, Attribute attr, + function_ref<InFlightDiagnostic()> emitError); + +/// Convert the provided int8_t to an IntegerAttr attribute. +Attribute convertToAttribute(MLIRContext *ctx, int8_t storage); + +/// Convert an IntegerAttr attribute to a uint64_t, or return an error if the +/// attribute isn't an IntegerAttr. If the optional diagnostic is provided an +/// error message is also emitted. +LogicalResult +convertFromAttribute(uint64_t &storage, Attribute attr, + function_ref<InFlightDiagnostic()> emitError); + +/// Convert the provided unt64_t to an IntegerAttr attribute. +Attribute convertToAttribute(MLIRContext *ctx, uint64_t storage); + +/// Convert an IntegerAttr attribute to a uint32_t, or return an error if the +/// attribute isn't an IntegerAttr. If the optional diagnostic is provided an +/// error message is also emitted. +LogicalResult +convertFromAttribute(uint32_t &storage, Attribute attr, + function_ref<InFlightDiagnostic()> emitError); + +/// Convert the provided uint32_t to an IntegerAttr attribute. +Attribute convertToAttribute(MLIRContext *ctx, uint32_t storage); + +/// Convert an IntegerAttr attribute to a uint16_t, or return an error if the +/// attribute isn't an IntegerAttr. If the optional diagnostic is provided an +/// error message is also emitted. +LogicalResult +convertFromAttribute(uint16_t &storage, Attribute attr, + function_ref<InFlightDiagnostic()> emitError); + +/// Convert the provided uint16_t to an IntegerAttr attribute. +Attribute convertToAttribute(MLIRContext *ctx, uint16_t storage); + +/// Convert an IntegerAttr attribute to a uint8_t, or return an error if the +/// attribute isn't an IntegerAttr. If the optional diagnostic is provided an +/// error message is also emitted. +LogicalResult +convertFromAttribute(uint8_t &storage, Attribute attr, + function_ref<InFlightDiagnostic()> emitError); + +/// Convert the provided uint8_t to an IntegerAttr attribute. +Attribute convertToAttribute(MLIRContext *ctx, uint8_t storage); + /// Extract the string from `attr` into `storage`. If `attr` is not a /// `StringAttr`, return failure and emit an error into the diagnostic from /// `emitError`. diff --git a/mlir/include/mlir/IR/Properties.td b/mlir/include/mlir/IR/Properties.td index 212b858..ef48d6e 100644 --- a/mlir/include/mlir/IR/Properties.td +++ b/mlir/include/mlir/IR/Properties.td @@ -238,25 +238,6 @@ def I64Prop : IntProp<"int64_t">; def I32Property : IntProp<"int32_t">, Deprecated<"moved to shorter name I32Prop">; def I64Property : IntProp<"int64_t">, Deprecated<"moved to shorter name I64Prop">; -class EnumProp<string storageTypeParam, string desc = "", string default = ""> : - Property<storageTypeParam, desc> { - // TODO: implement predicate for enum validity. - let writeToMlirBytecode = [{ - $_writer.writeVarInt(static_cast<uint64_t>($_storage)); - }]; - let readFromMlirBytecode = [{ - uint64_t val; - if (failed($_reader.readVarInt(val))) - return ::mlir::failure(); - $_storage = static_cast<}] # storageTypeParam # [{>(val); - }]; - let defaultValue = default; -} - -class EnumProperty<string storageTypeParam, string desc = "", string default = ""> : - EnumProp<storageTypeParam, desc, default>, - Deprecated<"moved to shorter name EnumProp">; - // Note: only a class so we can deprecate the old name class _cls_StringProp : Property<"std::string", "string"> { let interfaceType = "::llvm::StringRef"; diff --git a/mlir/lib/IR/ODSSupport.cpp b/mlir/lib/IR/ODSSupport.cpp index d56c75e..e569a7a 100644 --- a/mlir/lib/IR/ODSSupport.cpp +++ b/mlir/lib/IR/ODSSupport.cpp @@ -49,6 +49,96 @@ Attribute mlir::convertToAttribute(MLIRContext *ctx, int32_t storage) { } LogicalResult +mlir::convertFromAttribute(int16_t &storage, Attribute attr, + function_ref<InFlightDiagnostic()> emitError) { + auto valueAttr = dyn_cast<IntegerAttr>(attr); + if (!valueAttr) { + emitError() << "expected IntegerAttr for key `value`"; + return failure(); + } + storage = valueAttr.getValue().getSExtValue(); + return success(); +} +Attribute mlir::convertToAttribute(MLIRContext *ctx, int16_t storage) { + return IntegerAttr::get(IntegerType::get(ctx, 16), storage); +} + +LogicalResult +mlir::convertFromAttribute(int8_t &storage, Attribute attr, + function_ref<InFlightDiagnostic()> emitError) { + auto valueAttr = dyn_cast<IntegerAttr>(attr); + if (!valueAttr) { + emitError() << "expected IntegerAttr for key `value`"; + return failure(); + } + storage = valueAttr.getValue().getSExtValue(); + return success(); +} +Attribute mlir::convertToAttribute(MLIRContext *ctx, int8_t storage) { + return IntegerAttr::get(IntegerType::get(ctx, 8), storage); +} + +LogicalResult +mlir::convertFromAttribute(uint64_t &storage, Attribute attr, + function_ref<InFlightDiagnostic()> emitError) { + auto valueAttr = dyn_cast<IntegerAttr>(attr); + if (!valueAttr) { + emitError() << "expected IntegerAttr for key `value`"; + return failure(); + } + storage = valueAttr.getValue().getZExtValue(); + return success(); +} +Attribute mlir::convertToAttribute(MLIRContext *ctx, uint64_t storage) { + return IntegerAttr::get(IntegerType::get(ctx, 64), storage); +} + +LogicalResult +mlir::convertFromAttribute(uint32_t &storage, Attribute attr, + function_ref<InFlightDiagnostic()> emitError) { + auto valueAttr = dyn_cast<IntegerAttr>(attr); + if (!valueAttr) { + emitError() << "expected IntegerAttr for key `value`"; + return failure(); + } + storage = valueAttr.getValue().getZExtValue(); + return success(); +} +Attribute mlir::convertToAttribute(MLIRContext *ctx, uint32_t storage) { + return IntegerAttr::get(IntegerType::get(ctx, 32), storage); +} + +LogicalResult +mlir::convertFromAttribute(uint16_t &storage, Attribute attr, + function_ref<InFlightDiagnostic()> emitError) { + auto valueAttr = dyn_cast<IntegerAttr>(attr); + if (!valueAttr) { + emitError() << "expected IntegerAttr for key `value`"; + return failure(); + } + storage = valueAttr.getValue().getZExtValue(); + return success(); +} +Attribute mlir::convertToAttribute(MLIRContext *ctx, uint16_t storage) { + return IntegerAttr::get(IntegerType::get(ctx, 16), storage); +} + +LogicalResult +mlir::convertFromAttribute(uint8_t &storage, Attribute attr, + function_ref<InFlightDiagnostic()> emitError) { + auto valueAttr = dyn_cast<IntegerAttr>(attr); + if (!valueAttr) { + emitError() << "expected IntegerAttr for key `value`"; + return failure(); + } + storage = valueAttr.getValue().getZExtValue(); + return success(); +} +Attribute mlir::convertToAttribute(MLIRContext *ctx, uint8_t storage) { + return IntegerAttr::get(IntegerType::get(ctx, 8), storage); +} + +LogicalResult mlir::convertFromAttribute(std::string &storage, Attribute attr, function_ref<InFlightDiagnostic()> emitError) { auto valueAttr = dyn_cast<StringAttr>(attr); |