aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKrzysztof Drewniak <krzysdrewniak@gmail.com>2024-12-23 22:01:09 -0600
committerKrzysztof Drewniak <krzysdrewniak@gmail.com>2024-12-23 22:01:34 -0600
commit0fb33d53c2d8f32776664c6bae31a2f49399031f (patch)
treecb20f857157728c90200ab17eb0d34ab077be9fe
parent5d1c98f9a5060f51be27d8a56ddf7ec62b19ab0d (diff)
downloadllvm-users/krzysz00/enum-prop-properly.zip
llvm-users/krzysz00/enum-prop-properly.tar.gz
llvm-users/krzysz00/enum-prop-properly.tar.bz2
Compiles, LLVM attr syntax is wrong, will fix laterusers/krzysz00/enum-prop-properly
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td3
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td8
-rw-r--r--mlir/include/mlir/IR/EnumAttr.td60
-rw-r--r--mlir/include/mlir/IR/ODSSupport.h60
-rw-r--r--mlir/include/mlir/IR/Properties.td19
-rw-r--r--mlir/lib/IR/ODSSupport.cpp90
6 files changed, 217 insertions, 23 deletions
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
index c08b75d..570d1e8 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
@@ -504,6 +504,9 @@ def LLVM_IntegerOverflowFlagsAttr :
let assemblyFormat = "`<` $value `>`";
}
+def LLVM_IntegerOverflowFlagsProp :
+ EnumProp<IntegerOverflowFlags, "none">;
+
//===----------------------------------------------------------------------===//
// FastmathFlags
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index a3aa53b..f0fc966 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -60,7 +60,7 @@ class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName,
list<Trait> traits = []> :
LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName,
!listconcat([DeclareOpInterfaceMethods<IntegerOverflowFlagsInterface>], traits)> {
- dag iofArg = (ins EnumProp<"IntegerOverflowFlags", "", "IntegerOverflowFlags::none">:$overflowFlags);
+ dag iofArg = (ins LLVM_IntegerOverflowFlagsProp:$overflowFlags);
let arguments = !con(commonArgs, iofArg);
string mlirBuilder = [{
@@ -69,7 +69,7 @@ class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName,
$res = op;
}];
let assemblyFormat = [{
- $lhs `,` $rhs `` custom<OverflowFlags>($overflowFlags) attr-dict `:` type($res)
+ $lhs `,` $rhs (`overflow` `` `<` $overflowFlags^ `>`)? attr-dict `:` type($res)
}];
string llvmBuilder =
"$res = builder.Create" # instName #
@@ -558,10 +558,10 @@ class LLVM_CastOpWithOverflowFlag<string mnemonic, string instName, Type type,
Type resultType, list<Trait> traits = []> :
LLVM_Op<mnemonic, !listconcat([Pure], [DeclareOpInterfaceMethods<IntegerOverflowFlagsInterface>], traits)>,
LLVM_Builder<"$res = builder.Create" # instName # "($arg, $_resultType, /*Name=*/\"\", op.hasNoUnsignedWrap(), op.hasNoSignedWrap());"> {
- let arguments = (ins type:$arg, EnumProp<"IntegerOverflowFlags", "", "IntegerOverflowFlags::none">:$overflowFlags);
+ let arguments = (ins type:$arg, LLVM_IntegerOverflowFlagsProp:$overflowFlags);
let results = (outs resultType:$res);
let builders = [LLVM_OneResultOpBuilder];
- let assemblyFormat = "$arg `` custom<OverflowFlags>($overflowFlags) attr-dict `:` type($arg) `to` type($res)";
+ let assemblyFormat = "$arg (`overflow` `` `<` $overflowFlags^ `>`)? attr-dict `:` type($arg) `to` type($res)";
string llvmInstName = instName;
string mlirBuilder = [{
auto op = $_builder.create<$_qualCppClassName>(
diff --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td
index 9fec28f..227ac2c 100644
--- a/mlir/include/mlir/IR/EnumAttr.td
+++ b/mlir/include/mlir/IR/EnumAttr.td
@@ -10,6 +10,7 @@
#define ENUMATTR_TD
include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/Properties.td"
//===----------------------------------------------------------------------===//
// Enum attribute kinds
@@ -417,6 +418,65 @@ class EnumAttr<Dialect dialect, EnumAttrInfo enumInfo, string name = "",
let assemblyFormat = "$value";
}
+class _enumPropPred<string underlyingType, list<EnumAttrCaseInfo> cases> {
+ Pred ret = Or<!foreach(c, cases,
+ CPred<c.value # " == static_cast<" # underlyingType # ">($_self))">)>;
+}
+class _bitEnumPropPred<string underlyingType, list<EnumAttrCaseInfo> cases> {
+ Pred ret = CPred<"0 == (static_cast<"# underlyingType # ">($_self) & ~("
+ # !foldl(0, cases, acc, c, !or(acc, c.value)) # "))">;
+}
+
+/// A property backed by a C++ enum, analogous to `EnumAttr`.
+///
+/// When converted to an attribute or stored in bitcode, the enum is
+/// converted to its underlying integer type.
+class EnumProp<EnumAttrInfo enumInfo, string defaultCase = ""> :
+ Property<enumInfo.cppNamespace # "::" # enumInfo.className, enumInfo.summary> {
+ let description = enumInfo.description;
+
+ EnumAttrInfo enumAttrInfo = enumInfo;
+
+ let convertFromAttribute =
+ enumInfo.underlyingType # " intStorage;" # [{
+ if (failed(convertFromAttribute(intStorage, $_attr, $_diag)))
+ return ::mlir::failure();
+ $_storage = static_cast<}] # storageType # [{>(intStorage);
+ return ::mlir::success();
+ }];
+ let convertToAttribute = [{
+ return convertToAttribute($_ctxt, static_cast<}] # enumInfo.underlyingType # [{>($_storage));
+ }];
+
+ // The bytecode reader always wants things as uint64_t, so use that
+ // no matter what the underlying type is.
+ let writeToMlirBytecode = [{
+ $_writer.writeVarInt(static_cast<uint64_t>($_storage));
+ }];
+ let readFromMlirBytecode = [{
+ uint64_t val;
+ if (failed($_reader.readVarInt(val)))
+ return ::mlir::failure();
+ $_storage = static_cast<}] # storageType # [{>(val);
+ }];
+
+ let predicate =
+ !if(!isa<BitEnumAttr>(enumInfo),
+ _bitEnumPropPred<enumInfo.underlyingType, enumInfo.enumerants>.ret,
+ _enumPropPred<enumInfo.underlyingType, enumInfo.enumerants>.ret);
+
+ let optionalParser = [{
+ auto value = ::mlir::FieldParser<std::optional<}] # storageType # [{>>::parse($_parser);
+ if (::mlir::failed(value))
+ return ::mlir::failure();
+ if (!value->has_value())
+ return std::nullopt;
+ $_storage = **value;
+ }];
+
+ let defaultValue = !if(!ne(defaultCase, ""), storageType # "::" # defaultCase, "");
+}
+
class _symbolToValue<EnumAttrInfo enumAttrInfo, string case> {
defvar cases =
!filter(iter, enumAttrInfo.enumerants, !eq(iter.str, case));
diff --git a/mlir/include/mlir/IR/ODSSupport.h b/mlir/include/mlir/IR/ODSSupport.h
index 25d6f3d..61bd0f8 100644
--- a/mlir/include/mlir/IR/ODSSupport.h
+++ b/mlir/include/mlir/IR/ODSSupport.h
@@ -43,6 +43,66 @@ convertFromAttribute(int32_t &storage, Attribute attr,
/// Convert the provided int32_t to an IntegerAttr attribute.
Attribute convertToAttribute(MLIRContext *ctx, int32_t storage);
+/// Convert an IntegerAttr attribute to an int16_t, or return an error if the
+/// attribute isn't an IntegerAttr. If the optional diagnostic is provided an
+/// error message is also emitted.
+LogicalResult
+convertFromAttribute(int16_t &storage, Attribute attr,
+ function_ref<InFlightDiagnostic()> emitError);
+
+/// Convert the provided int16_t to an IntegerAttr attribute.
+Attribute convertToAttribute(MLIRContext *ctx, int16_t storage);
+
+/// Convert an IntegerAttr attribute to an int8_t, or return an error if the
+/// attribute isn't an IntegerAttr. If the optional diagnostic is provided an
+/// error message is also emitted.
+LogicalResult
+convertFromAttribute(int8_t &storage, Attribute attr,
+ function_ref<InFlightDiagnostic()> emitError);
+
+/// Convert the provided int8_t to an IntegerAttr attribute.
+Attribute convertToAttribute(MLIRContext *ctx, int8_t storage);
+
+/// Convert an IntegerAttr attribute to a uint64_t, or return an error if the
+/// attribute isn't an IntegerAttr. If the optional diagnostic is provided an
+/// error message is also emitted.
+LogicalResult
+convertFromAttribute(uint64_t &storage, Attribute attr,
+ function_ref<InFlightDiagnostic()> emitError);
+
+/// Convert the provided unt64_t to an IntegerAttr attribute.
+Attribute convertToAttribute(MLIRContext *ctx, uint64_t storage);
+
+/// Convert an IntegerAttr attribute to a uint32_t, or return an error if the
+/// attribute isn't an IntegerAttr. If the optional diagnostic is provided an
+/// error message is also emitted.
+LogicalResult
+convertFromAttribute(uint32_t &storage, Attribute attr,
+ function_ref<InFlightDiagnostic()> emitError);
+
+/// Convert the provided uint32_t to an IntegerAttr attribute.
+Attribute convertToAttribute(MLIRContext *ctx, uint32_t storage);
+
+/// Convert an IntegerAttr attribute to a uint16_t, or return an error if the
+/// attribute isn't an IntegerAttr. If the optional diagnostic is provided an
+/// error message is also emitted.
+LogicalResult
+convertFromAttribute(uint16_t &storage, Attribute attr,
+ function_ref<InFlightDiagnostic()> emitError);
+
+/// Convert the provided uint16_t to an IntegerAttr attribute.
+Attribute convertToAttribute(MLIRContext *ctx, uint16_t storage);
+
+/// Convert an IntegerAttr attribute to a uint8_t, or return an error if the
+/// attribute isn't an IntegerAttr. If the optional diagnostic is provided an
+/// error message is also emitted.
+LogicalResult
+convertFromAttribute(uint8_t &storage, Attribute attr,
+ function_ref<InFlightDiagnostic()> emitError);
+
+/// Convert the provided uint8_t to an IntegerAttr attribute.
+Attribute convertToAttribute(MLIRContext *ctx, uint8_t storage);
+
/// Extract the string from `attr` into `storage`. If `attr` is not a
/// `StringAttr`, return failure and emit an error into the diagnostic from
/// `emitError`.
diff --git a/mlir/include/mlir/IR/Properties.td b/mlir/include/mlir/IR/Properties.td
index 212b858..ef48d6e 100644
--- a/mlir/include/mlir/IR/Properties.td
+++ b/mlir/include/mlir/IR/Properties.td
@@ -238,25 +238,6 @@ def I64Prop : IntProp<"int64_t">;
def I32Property : IntProp<"int32_t">, Deprecated<"moved to shorter name I32Prop">;
def I64Property : IntProp<"int64_t">, Deprecated<"moved to shorter name I64Prop">;
-class EnumProp<string storageTypeParam, string desc = "", string default = ""> :
- Property<storageTypeParam, desc> {
- // TODO: implement predicate for enum validity.
- let writeToMlirBytecode = [{
- $_writer.writeVarInt(static_cast<uint64_t>($_storage));
- }];
- let readFromMlirBytecode = [{
- uint64_t val;
- if (failed($_reader.readVarInt(val)))
- return ::mlir::failure();
- $_storage = static_cast<}] # storageTypeParam # [{>(val);
- }];
- let defaultValue = default;
-}
-
-class EnumProperty<string storageTypeParam, string desc = "", string default = ""> :
- EnumProp<storageTypeParam, desc, default>,
- Deprecated<"moved to shorter name EnumProp">;
-
// Note: only a class so we can deprecate the old name
class _cls_StringProp : Property<"std::string", "string"> {
let interfaceType = "::llvm::StringRef";
diff --git a/mlir/lib/IR/ODSSupport.cpp b/mlir/lib/IR/ODSSupport.cpp
index d56c75e..e569a7a 100644
--- a/mlir/lib/IR/ODSSupport.cpp
+++ b/mlir/lib/IR/ODSSupport.cpp
@@ -49,6 +49,96 @@ Attribute mlir::convertToAttribute(MLIRContext *ctx, int32_t storage) {
}
LogicalResult
+mlir::convertFromAttribute(int16_t &storage, Attribute attr,
+ function_ref<InFlightDiagnostic()> emitError) {
+ auto valueAttr = dyn_cast<IntegerAttr>(attr);
+ if (!valueAttr) {
+ emitError() << "expected IntegerAttr for key `value`";
+ return failure();
+ }
+ storage = valueAttr.getValue().getSExtValue();
+ return success();
+}
+Attribute mlir::convertToAttribute(MLIRContext *ctx, int16_t storage) {
+ return IntegerAttr::get(IntegerType::get(ctx, 16), storage);
+}
+
+LogicalResult
+mlir::convertFromAttribute(int8_t &storage, Attribute attr,
+ function_ref<InFlightDiagnostic()> emitError) {
+ auto valueAttr = dyn_cast<IntegerAttr>(attr);
+ if (!valueAttr) {
+ emitError() << "expected IntegerAttr for key `value`";
+ return failure();
+ }
+ storage = valueAttr.getValue().getSExtValue();
+ return success();
+}
+Attribute mlir::convertToAttribute(MLIRContext *ctx, int8_t storage) {
+ return IntegerAttr::get(IntegerType::get(ctx, 8), storage);
+}
+
+LogicalResult
+mlir::convertFromAttribute(uint64_t &storage, Attribute attr,
+ function_ref<InFlightDiagnostic()> emitError) {
+ auto valueAttr = dyn_cast<IntegerAttr>(attr);
+ if (!valueAttr) {
+ emitError() << "expected IntegerAttr for key `value`";
+ return failure();
+ }
+ storage = valueAttr.getValue().getZExtValue();
+ return success();
+}
+Attribute mlir::convertToAttribute(MLIRContext *ctx, uint64_t storage) {
+ return IntegerAttr::get(IntegerType::get(ctx, 64), storage);
+}
+
+LogicalResult
+mlir::convertFromAttribute(uint32_t &storage, Attribute attr,
+ function_ref<InFlightDiagnostic()> emitError) {
+ auto valueAttr = dyn_cast<IntegerAttr>(attr);
+ if (!valueAttr) {
+ emitError() << "expected IntegerAttr for key `value`";
+ return failure();
+ }
+ storage = valueAttr.getValue().getZExtValue();
+ return success();
+}
+Attribute mlir::convertToAttribute(MLIRContext *ctx, uint32_t storage) {
+ return IntegerAttr::get(IntegerType::get(ctx, 32), storage);
+}
+
+LogicalResult
+mlir::convertFromAttribute(uint16_t &storage, Attribute attr,
+ function_ref<InFlightDiagnostic()> emitError) {
+ auto valueAttr = dyn_cast<IntegerAttr>(attr);
+ if (!valueAttr) {
+ emitError() << "expected IntegerAttr for key `value`";
+ return failure();
+ }
+ storage = valueAttr.getValue().getZExtValue();
+ return success();
+}
+Attribute mlir::convertToAttribute(MLIRContext *ctx, uint16_t storage) {
+ return IntegerAttr::get(IntegerType::get(ctx, 16), storage);
+}
+
+LogicalResult
+mlir::convertFromAttribute(uint8_t &storage, Attribute attr,
+ function_ref<InFlightDiagnostic()> emitError) {
+ auto valueAttr = dyn_cast<IntegerAttr>(attr);
+ if (!valueAttr) {
+ emitError() << "expected IntegerAttr for key `value`";
+ return failure();
+ }
+ storage = valueAttr.getValue().getZExtValue();
+ return success();
+}
+Attribute mlir::convertToAttribute(MLIRContext *ctx, uint8_t storage) {
+ return IntegerAttr::get(IntegerType::get(ctx, 8), storage);
+}
+
+LogicalResult
mlir::convertFromAttribute(std::string &storage, Attribute attr,
function_ref<InFlightDiagnostic()> emitError) {
auto valueAttr = dyn_cast<StringAttr>(attr);