diff options
author | Vladislav Vinogradov <vlad.vinogradov@intel.com> | 2021-02-27 15:21:00 +0300 |
---|---|---|
committer | Vladislav Vinogradov <vlad.vinogradov@intel.com> | 2021-03-17 16:44:24 +0300 |
commit | fee90542326bc1d81ba684bfc0a2cd21cb04e650 (patch) | |
tree | a264cd57123f92e3a4e30171d57d09bfbeaae9c6 | |
parent | 2571a0936719b50facfee492ccbaf4916272be36 (diff) | |
download | llvm-fee90542326bc1d81ba684bfc0a2cd21cb04e650.zip llvm-fee90542326bc1d81ba684bfc0a2cd21cb04e650.tar.gz llvm-fee90542326bc1d81ba684bfc0a2cd21cb04e650.tar.bz2 |
[mlir][ODS] Support specialized Attribute class for Enums
Add a feature to `EnumAttr` definition to generate
specialized Attribute class for the particular enumeration.
This class will inherit `StringAttr` or `IntegerAttr` and
will override `classof` and `getValue` methods.
With this class the enumeration predicate can be checked with simple
RTTI calls (`isa`, `dyn_cast`) and it will return the typed enumeration
directly instead of raw string/integer.
Based on the following discussion:
https://llvm.discourse.group/t/rfc-add-enum-attribute-decorator-class/2252
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D97836
20 files changed, 266 insertions, 96 deletions
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index ddc0ed3..3623565 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -200,7 +200,7 @@ def LLVM_ICmpOp : LLVM_Op<"icmp", [NoSideEffect]> { OpBuilder<(ins "ICmpPredicate":$predicate, "Value":$lhs, "Value":$rhs), [{ build($_builder, $_state, IntegerType::get(lhs.getType().getContext(), 1), - $_builder.getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs); + predicate, lhs, rhs); }]>]; let parser = [{ return parseCmpOp<ICmpPredicate>(parser, result); }]; let printer = [{ printICmpOp(p, *this); }]; @@ -246,14 +246,6 @@ def LLVM_FCmpOp : LLVM_Op<"fcmp", [ let llvmBuilder = [{ $res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs); }]; - let builders = [ - OpBuilder<(ins "FCmpPredicate":$predicate, "Value":$lhs, "Value":$rhs, - CArg<"FastmathFlags", "{}">:$fmf), - [{ - build($_builder, $_state, IntegerType::get(lhs.getType().getContext(), 1), - $_builder.getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs, - ::mlir::LLVM::FMFAttr::get($_builder.getContext(), fmf)); - }]>]; let parser = [{ return parseCmpOp<FCmpPredicate>(parser, result); }]; let printer = [{ printFCmpOp(p, *this); }]; } diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVEnums.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVEnums.h index e5838ef..ac128ac 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVEnums.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVEnums.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_SPIRV_IR_SPIRVENUMS_H_ #define MLIR_DIALECT_SPIRV_IR_SPIRVENUMS_H_ +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/DenseMapInfo.h" #include "llvm/ADT/StringRef.h" diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td index 0cffdb5..77fa63f 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td @@ -184,7 +184,7 @@ def SPV_LoadOp : SPV_Op<"Load", []> { let builders = [ OpBuilder<(ins "Value":$basePtr, - CArg<"IntegerAttr", "{}">:$memory_access, + CArg<"MemoryAccessAttr", "{}">:$memory_access, CArg<"IntegerAttr", "{}">:$alignment)> ]; } diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td index 69fb073..f26a771 100644 --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -53,6 +53,7 @@ def CombiningKind : BitEnumAttr< COMBINING_KIND_MAX, COMBINING_KIND_AND, COMBINING_KIND_OR, COMBINING_KIND_XOR]> { let cppNamespace = "::mlir::vector"; + let genSpecializedAttr = 0; } def Vector_CombiningKindAttr : DialectAttr< diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 268056d..bdae05f 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1142,7 +1142,9 @@ class BitEnumAttrCase<string sym, int val, string str = sym> : } // Additional information for an enum attribute. -class EnumAttrInfo<string name, list<EnumAttrCaseInfo> cases> { +class EnumAttrInfo< + string name, list<EnumAttrCaseInfo> cases, Attr baseClass> : + Attr<baseClass.predicate, baseClass.summary> { // The C++ enum class name string className = name; @@ -1188,6 +1190,28 @@ class EnumAttrInfo<string name, list<EnumAttrCaseInfo> cases> { // static constexpr unsigned <fn-name>(); // ``` string maxEnumValFnName = "getMaxEnumValFor" # name; + + // Generate specialized Attribute class + bit genSpecializedAttr = 1; + // The underlying Attribute class, which holds the enum value + Attr baseAttrClass = baseClass; + // The name of specialized Enum Attribute class + string specializedAttrClassName = name # Attr; + + // Override Attr class fields for specialized class + let predicate = !if(genSpecializedAttr, + CPred<"$_self.isa<" # cppNamespace # "::" # specializedAttrClassName # ">()">, + baseAttrClass.predicate); + let storageType = !if(genSpecializedAttr, + cppNamespace # "::" # specializedAttrClassName, + baseAttrClass.storageType); + let returnType = !if(genSpecializedAttr, + cppNamespace # "::" # className, + baseAttrClass.returnType); + let constBuilderCall = !if(genSpecializedAttr, + cppNamespace # "::" # specializedAttrClassName # "::get($_builder.getContext(), $0)", + baseAttrClass.constBuilderCall); + let valueType = baseAttrClass.valueType; } // An enum attribute backed by StringAttr. @@ -1195,47 +1219,44 @@ class EnumAttrInfo<string name, list<EnumAttrCaseInfo> cases> { // Op attributes of this kind are stored as StringAttr. Extra verification will // be generated on the string though: only the symbols of the allowed cases are // permitted as the string value. -class StrEnumAttr<string name, string summary, list<StrEnumAttrCase> cases> - : EnumAttrInfo<name, cases>, +class StrEnumAttr<string name, string summary, list<StrEnumAttrCase> cases> : + EnumAttrInfo<name, cases, StringBasedAttr< And<[StrAttr.predicate, Or<!foreach(case, cases, case.predicate)>]>, !if(!empty(summary), "allowed string cases: " # !interleave(!foreach(case, cases, "'" # case.symbol # "'"), ", "), - summary)>; + summary)>> { + // Disable specialized Attribute class for `StringAttr` backend by default. + let genSpecializedAttr = 0; +} // An enum attribute backed by IntegerAttr. // // Op attributes of this kind are stored as IntegerAttr. Extra verification will // be generated on the integer though: only the values of the allowed cases are // permitted as the integer value. -class IntEnumAttr<I intType, string name, string summary, - list<IntEnumAttrCaseBase> cases> : - EnumAttrInfo<name, cases>, - SignlessIntegerAttrBase<intType, - !if(!empty(summary), "allowed " # intType.summary # " cases: " # - !interleave(!foreach(case, cases, case.value), ", "), summary)> { +class IntEnumAttrBase<I intType, list<IntEnumAttrCaseBase> cases, string summary> : + SignlessIntegerAttrBase<intType, summary> { let predicate = And<[ - SignlessIntegerAttrBase<intType, "">.predicate, + SignlessIntegerAttrBase<intType, summary>.predicate, Or<!foreach(case, cases, case.predicate)>]>; } -class I32EnumAttr<string name, string summary, - list<I32EnumAttrCase> cases> : +class IntEnumAttr<I intType, string name, string summary, + list<IntEnumAttrCaseBase> cases> : + EnumAttrInfo<name, cases, + IntEnumAttrBase<intType, cases, + !if(!empty(summary), "allowed " # intType.summary # " cases: " # + !interleave(!foreach(case, cases, case.value), ", "), + summary)>>; + +class I32EnumAttr<string name, string summary, list<I32EnumAttrCase> cases> : IntEnumAttr<I32, name, summary, cases> { - let returnType = cppNamespace # "::" # name; let underlyingType = "uint32_t"; - let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())"; - let constBuilderCall = - "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))"; } -class I64EnumAttr<string name, string summary, - list<I64EnumAttrCase> cases> : +class I64EnumAttr<string name, string summary, list<I64EnumAttrCase> cases> : IntEnumAttr<I64, name, summary, cases> { - let returnType = cppNamespace # "::" # name; let underlyingType = "uint64_t"; - let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())"; - let constBuilderCall = - "$_builder.getI64IntegerAttr(static_cast<int64_t>($0))"; } // A bit enum stored with 32-bit IntegerAttr. @@ -1244,9 +1265,8 @@ class I64EnumAttr<string name, string summary, // be generated on the integer to make sure only allowed bit are set. Besides, // helper methods are generated to parse a string separated with a specified // delimiter to a symbol and vice versa. -class BitEnumAttr<string name, string summary, - list<BitEnumAttrCase> cases> : - EnumAttrInfo<name, cases>, SignlessIntegerAttrBase<I32, summary> { +class BitEnumAttrBase<list<BitEnumAttrCase> cases, string summary> : + SignlessIntegerAttrBase<I32, summary> { let predicate = And<[ I32Attr.predicate, // Make sure we don't have unknown bit set. @@ -1254,12 +1274,11 @@ class BitEnumAttr<string name, string summary, # !interleave(!foreach(case, cases, case.value # "u"), "|") # ")))"> ]>; +} - let returnType = cppNamespace # "::" # name; +class BitEnumAttr<string name, string summary, list<BitEnumAttrCase> cases> : + EnumAttrInfo<name, cases, BitEnumAttrBase<cases, summary>> { let underlyingType = "uint32_t"; - let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())"; - let constBuilderCall = - "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))"; // We need to return a string because we may concatenate symbols for multiple // bits together. diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h index dc6c969..a8292a9 100644 --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -202,6 +202,10 @@ public: // Returns all allowed cases for this enum attribute. std::vector<EnumAttrCase> getAllCases() const; + + bool genSpecializedAttr() const; + llvm::Record *getBaseAttrClass() const; + StringRef getSpecializedAttrClassName() const; }; class StructFieldAttr { diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp index 8e7540f..19837fe6 100644 --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp @@ -155,9 +155,7 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands, // header to merge. scf::ForOpAdaptor forOperands(operands); auto loc = forOp.getLoc(); - auto loopControl = rewriter.getI32IntegerAttr( - static_cast<uint32_t>(spirv::LoopControl::None)); - auto loopOp = rewriter.create<spirv::LoopOp>(loc, loopControl); + auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None); loopOp.addEntryAndMergeBlock(); OpBuilder::InsertionGuard guard(rewriter); @@ -238,11 +236,9 @@ IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands, scf::IfOpAdaptor ifOperands(operands); auto loc = ifOp.getLoc(); - // Create `spv.mlir.selection` operation, selection header block and merge - // block. - auto selectionControl = rewriter.getI32IntegerAttr( - static_cast<uint32_t>(spirv::SelectionControl::None)); - auto selectionOp = rewriter.create<spirv::SelectionOp>(loc, selectionControl); + // Create `spv.selection` operation, selection header block and merge block. + auto selectionOp = + rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None); auto *mergeBlock = rewriter.createBlock(&selectionOp.body(), selectionOp.body().end()); rewriter.create<spirv::MergeOp>(loc); diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index 871f54b..3a139b4 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -826,10 +826,8 @@ public: return failure(); rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>( - operation, dstType, - rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)), - operation.operand1(), operation.operand2(), - LLVM::FMFAttr::get(operation.getContext(), {})); + operation, dstType, predicate, operation.operand1(), + operation.operand2()); return success(); } }; @@ -849,9 +847,8 @@ public: return failure(); rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>( - operation, dstType, - rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)), - operation.operand1(), operation.operand2()); + operation, dstType, predicate, operation.operand1(), + operation.operand2()); return success(); } }; diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index 91e520e..2490f35 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -3069,8 +3069,7 @@ struct CmpIOpLowering : public ConvertOpToLLVMPattern<CmpIOp> { rewriter.replaceOpWithNewOp<LLVM::ICmpOp>( cmpiOp, typeConverter->convertType(cmpiOp.getResult().getType()), - rewriter.getI64IntegerAttr(static_cast<int64_t>( - convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()))), + convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()), transformed.lhs(), transformed.rhs()); return success(); @@ -3085,12 +3084,10 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> { ConversionPatternRewriter &rewriter) const override { CmpFOpAdaptor transformed(operands); - auto fmf = LLVM::FMFAttr::get(cmpfOp.getContext(), {}); rewriter.replaceOpWithNewOp<LLVM::FCmpOp>( cmpfOp, typeConverter->convertType(cmpfOp.getResult().getType()), - rewriter.getI64IntegerAttr(static_cast<int64_t>( - convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()))), - transformed.lhs(), transformed.rhs(), fmf); + convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()), + transformed.lhs(), transformed.rhs()); return success(); } diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp index ed1b72c..025029a 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -1017,7 +1017,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, srcBits, dstBits, rewriter); Value spvLoadOp = rewriter.create<spirv::LoadOp>( loc, dstType, adjustedPtr, - loadOp->getAttrOfType<IntegerAttr>( + loadOp->getAttrOfType<spirv::MemoryAccessAttr>( spirv::attributeName<spirv::MemoryAccess>()), loadOp->getAttrOfType<IntegerAttr>("alignment")); diff --git a/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp b/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp index 6ccb59a..b032169 100644 --- a/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp @@ -36,7 +36,7 @@ ParallelLoopDimMapping getParallelLoopDimMappingAttr(Processor processor, MLIRContext *context = map.getContext(); OpBuilder builder(context); return ParallelLoopDimMapping::get( - builder.getI64IntegerAttr(static_cast<int32_t>(processor)), + ProcessorAttr::get(builder.getContext(), processor), AffineMapAttr::get(map), AffineMapAttr::get(bound), context); } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVEnums.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVEnums.cpp index a289d9d..a851906 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVEnums.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVEnums.cpp @@ -12,6 +12,8 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" +#include "mlir/IR/BuiltinTypes.h" + #include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index dd2dc3d..21bcfe4 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1659,7 +1659,7 @@ void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state, spirv::FuncOp function, ArrayRef<Attribute> interfaceVars) { build(builder, state, - builder.getI32IntegerAttr(static_cast<int32_t>(executionModel)), + spirv::ExecutionModelAttr::get(builder.getContext(), executionModel), builder.getSymbolRefAttr(function), builder.getArrayAttr(interfaceVars)); } @@ -1721,7 +1721,7 @@ void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state, spirv::ExecutionMode executionMode, ArrayRef<int32_t> params) { build(builder, state, builder.getSymbolRefAttr(function), - builder.getI32IntegerAttr(static_cast<int32_t>(executionMode)), + spirv::ExecutionModeAttr::get(builder.getContext(), executionMode), builder.getI32ArrayAttr(params)); } @@ -2243,10 +2243,10 @@ static LogicalResult verify(spirv::GroupNonUniformElectOp groupOp) { //===----------------------------------------------------------------------===// void spirv::LoadOp::build(OpBuilder &builder, OperationState &state, - Value basePtr, IntegerAttr memory_access, + Value basePtr, MemoryAccessAttr memoryAccess, IntegerAttr alignment) { auto ptrType = basePtr.getType().cast<spirv::PointerType>(); - build(builder, state, ptrType.getPointeeType(), basePtr, memory_access, + build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess, alignment); } @@ -2784,9 +2784,8 @@ void spirv::SelectionOp::addMergeBlock() { spirv::SelectionOp spirv::SelectionOp::createIfThen( Location loc, Value condition, function_ref<void(OpBuilder &builder)> thenBody, OpBuilder &builder) { - auto selectionControl = builder.getI32IntegerAttr( - static_cast<uint32_t>(spirv::SelectionControl::None)); - auto selectionOp = builder.create<spirv::SelectionOp>(loc, selectionControl); + auto selectionOp = + builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None); selectionOp.addMergeBlock(); Block *mergeBlock = selectionOp.getMergeBlock(); diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp index 99d9d8a..3b949b0 100644 --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -231,6 +231,18 @@ std::vector<EnumAttrCase> EnumAttr::getAllCases() const { return cases; } +bool EnumAttr::genSpecializedAttr() const { + return def->getValueAsBit("genSpecializedAttr"); +} + +llvm::Record *EnumAttr::getBaseAttrClass() const { + return def->getValueAsDef("baseAttrClass"); +} + +StringRef EnumAttr::getSpecializedAttrClassName() const { + return def->getValueAsString("specializedAttrClassName"); +} + StructFieldAttr::StructFieldAttr(const llvm::Record *record) : def(record) { assert(def->isSubClassOf("StructFieldAttr") && "must be subclass of TableGen 'StructFieldAttr' class"); diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp index 06e7f81..6137fee3 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp @@ -331,7 +331,8 @@ Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) { return emitError(unknownLoc, "missing Execution Model specification in OpEntryPoint"); } - auto execModel = opBuilder.getI32IntegerAttr(words[wordIndex++]); + auto execModel = spirv::ExecutionModelAttr::get( + context, static_cast<spirv::ExecutionModel>(words[wordIndex++])); if (wordIndex >= words.size()) { return emitError(unknownLoc, "missing <id> in OpEntryPoint"); } @@ -383,7 +384,8 @@ Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) { if (wordIndex >= words.size()) { return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode"); } - auto execMode = opBuilder.getI32IntegerAttr(words[wordIndex++]); + auto execMode = spirv::ExecutionModeAttr::get( + context, static_cast<spirv::ExecutionMode>(words[wordIndex++])); // Get the values SmallVector<Attribute, 4> attrListElems; @@ -417,8 +419,11 @@ Deserializer::processOp<spirv::ControlBarrierOp>(ArrayRef<uint32_t> operands) { argAttrs.push_back(argAttr); } - opBuilder.create<spirv::ControlBarrierOp>(unknownLoc, argAttrs[0], - argAttrs[1], argAttrs[2]); + opBuilder.create<spirv::ControlBarrierOp>( + unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(), + argAttrs[1].cast<spirv::ScopeAttr>(), + argAttrs[2].cast<spirv::MemorySemanticsAttr>()); + return success(); } @@ -483,8 +488,9 @@ Deserializer::processOp<spirv::MemoryBarrierOp>(ArrayRef<uint32_t> operands) { argAttrs.push_back(argAttr); } - opBuilder.create<spirv::MemoryBarrierOp>(unknownLoc, argAttrs[0], - argAttrs[1]); + opBuilder.create<spirv::MemoryBarrierOp>( + unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(), + argAttrs[1].cast<spirv::MemorySemanticsAttr>()); return success(); } diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 171d9b7..c54c168 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -1640,7 +1640,7 @@ ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) { // merge block so that the newly created SelectionOp will be inserted there. OpBuilder builder(&mergeBlock->front()); - auto control = builder.getI32IntegerAttr(selectionControl); + auto control = static_cast<spirv::SelectionControl>(selectionControl); auto selectionOp = builder.create<spirv::SelectionOp>(location, control); selectionOp.addMergeBlock(); @@ -1652,7 +1652,7 @@ spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) { // merge block so that the newly created LoopOp will be inserted there. OpBuilder builder(&mergeBlock->front()); - auto control = builder.getI32IntegerAttr(loopControl); + auto control = static_cast<spirv::LoopControl>(loopControl); auto loopOp = builder.create<spirv::LoopOp>(location, control); loopOp.addEntryAndMergeBlock(); diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 1968ebd..b8956e4 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1052,39 +1052,39 @@ def OneResultOp3 : TEST_Op<"one_result3"> { } // Test using multi-result op as a whole -def : Pat<(ThreeResultOp MultiResultOpKind1), - (AnotherThreeResultOp MultiResultOpKind1)>; +def : Pat<(ThreeResultOp MultiResultOpKind1:$kind), + (AnotherThreeResultOp $kind)>; // Test using multi-result op as a whole for partial replacement -def : Pattern<(ThreeResultOp MultiResultOpKind2), - [(TwoResultOp MultiResultOpKind2), - (OneResultOp1 MultiResultOpKind2)]>; -def : Pattern<(ThreeResultOp MultiResultOpKind3), - [(OneResultOp2 MultiResultOpKind3), - (AnotherTwoResultOp MultiResultOpKind3)]>; +def : Pattern<(ThreeResultOp MultiResultOpKind2:$kind), + [(TwoResultOp $kind), + (OneResultOp1 $kind)]>; +def : Pattern<(ThreeResultOp MultiResultOpKind3:$kind), + [(OneResultOp2 $kind), + (AnotherTwoResultOp $kind)]>; // Test using results separately in a multi-result op -def : Pattern<(ThreeResultOp MultiResultOpKind4), - [(TwoResultOp:$res1__0 MultiResultOpKind4), - (OneResultOp1 MultiResultOpKind4), - (TwoResultOp:$res2__1 MultiResultOpKind4)]>; +def : Pattern<(ThreeResultOp MultiResultOpKind4:$kind), + [(TwoResultOp:$res1__0 $kind), + (OneResultOp1 $kind), + (TwoResultOp:$res2__1 $kind)]>; // Test referencing a single value in the value pack // This rule only matches TwoResultOp if its second result has no use. -def : Pattern<(TwoResultOp:$res MultiResultOpKind5), - [(OneResultOp2 MultiResultOpKind5), - (OneResultOp1 MultiResultOpKind5)], +def : Pattern<(TwoResultOp:$res MultiResultOpKind5:$kind), + [(OneResultOp2 $kind), + (OneResultOp1 $kind)], [(HasNoUseOf:$res__1)]>; // Test using auxiliary ops for replacing multi-result op def : Pattern< - (ThreeResultOp MultiResultOpKind6), [ + (ThreeResultOp MultiResultOpKind6:$kind), [ // Auxiliary op generated to help building the final result but not // directly used to replace the source op's results. - (TwoResultOp:$interm MultiResultOpKind6), + (TwoResultOp:$interm $kind), (OneResultOp3 $interm__1), - (AnotherTwoResultOp MultiResultOpKind6) + (AnotherTwoResultOp $kind) ]>; //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp index e207e31..aa8841a 100644 --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/TableGen/Attribute.h" +#include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" @@ -22,12 +23,16 @@ using llvm::formatv; using llvm::isDigit; +using llvm::PrintFatalError; using llvm::raw_ostream; using llvm::Record; using llvm::RecordKeeper; using llvm::StringRef; +using mlir::tblgen::Attribute; using mlir::tblgen::EnumAttr; using mlir::tblgen::EnumAttrCase; +using mlir::tblgen::FmtContext; +using mlir::tblgen::tgfmt; static std::string makeIdentifier(StringRef str) { if (!str.empty() && isDigit(static_cast<unsigned char>(str.front()))) { @@ -303,6 +308,78 @@ static void emitUnderlyingToSymFnForIntEnum(const Record &enumDef, << "}\n\n"; } +static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef enumName = enumAttr.getEnumClassName(); + StringRef symToStrFnName = enumAttr.getSymbolToStringFnName(); + StringRef strToSymFnName = enumAttr.getStringToSymbolFnName(); + StringRef attrClassName = enumAttr.getSpecializedAttrClassName(); + llvm::Record *baseAttrDef = enumAttr.getBaseAttrClass(); + Attribute baseAttr(baseAttrDef); + + // Emit classof method + + os << formatv("bool {0}::classof(::mlir::Attribute attr) {{\n", + attrClassName); + + mlir::tblgen::Pred baseAttrPred = baseAttr.getPredicate(); + if (baseAttrPred.isNull()) + PrintFatalError("ERROR: baseAttrClass for EnumAttr has no Predicate\n"); + + std::string condition = baseAttrPred.getCondition(); + FmtContext verifyCtx; + verifyCtx.withSelf("attr"); + os << tgfmt(" return $0;\n", /*ctx=*/nullptr, tgfmt(condition, &verifyCtx)); + + os << "}\n"; + + // Emit get method + + os << formatv("{0} {0}::get(::mlir::MLIRContext *context, {1} val) {{\n", + attrClassName, enumName); + + if (enumAttr.isSubClassOf("StrEnumAttr")) { + os << formatv(" ::mlir::StringAttr baseAttr = " + "::mlir::StringAttr::get(context, {0}(val));\n", + symToStrFnName); + } else { + StringRef underlyingType = enumAttr.getUnderlyingType(); + + // Assuming that it is IntegerAttr constraint + int64_t bitwidth = 64; + if (baseAttrDef->getValue("valueType")) { + auto *valueTypeDef = baseAttrDef->getValueAsDef("valueType"); + if (valueTypeDef->getValue("bitwidth")) + bitwidth = valueTypeDef->getValueAsInt("bitwidth"); + } + + os << formatv(" ::mlir::IntegerType intType = " + "::mlir::IntegerType::get(context, {0});\n", + bitwidth); + os << formatv(" ::mlir::IntegerAttr baseAttr = " + "::mlir::IntegerAttr::get(intType, static_cast<{0}>(val));\n", + underlyingType); + } + os << formatv(" return baseAttr.cast<{0}>();\n", attrClassName); + + os << "}\n"; + + // Emit getValue method + + os << formatv("{0} {1}::getValue() const {{\n", enumName, attrClassName); + + if (enumAttr.isSubClassOf("StrEnumAttr")) { + os << formatv(" const auto res = {0}(::mlir::StringAttr::getValue());\n", + strToSymFnName); + os << " return res.getValue();\n"; + } else { + os << formatv(" return static_cast<{0}>(::mlir::IntegerAttr::getInt());\n", + enumName); + } + + os << "}\n"; +} + static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) { EnumAttr enumAttr(enumDef); @@ -391,6 +468,23 @@ inline ::llvm::Optional<{0}> symbolizeEnum<{0}>(::llvm::StringRef str) { )"; os << formatv(symbolizeEnumStr, enumName, strToSymFnName); + const char *const attrClassDecl = R"( +class {1} : public ::mlir::{2} { +public: + using ValueType = {0}; + using ::mlir::{2}::{2}; + static bool classof(::mlir::Attribute attr); + static {1} get(::mlir::MLIRContext *context, {0} val); + {0} getValue() const; +}; +)"; + if (enumAttr.genSpecializedAttr()) { + StringRef attrClassName = enumAttr.getSpecializedAttrClassName(); + StringRef baseAttrClassName = + enumAttr.isSubClassOf("StrEnumAttr") ? "StringAttr" : "IntegerAttr"; + os << formatv(attrClassDecl, enumName, attrClassName, baseAttrClassName); + } + for (auto ns : llvm::reverse(namespaces)) os << "} // namespace " << ns << "\n"; @@ -428,6 +522,9 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) { emitUnderlyingToSymFnForIntEnum(enumDef, os); } + if (enumAttr.genSpecializedAttr()) + emitSpecializedAttrDef(enumDef, os); + for (auto ns : llvm::reverse(namespaces)) os << "} // namespace " << ns << "\n"; os << "\n"; diff --git a/mlir/unittests/TableGen/EnumsGenTest.cpp b/mlir/unittests/TableGen/EnumsGenTest.cpp index a558019..a873658 100644 --- a/mlir/unittests/TableGen/EnumsGenTest.cpp +++ b/mlir/unittests/TableGen/EnumsGenTest.cpp @@ -6,21 +6,29 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" #include "mlir/Support/LLVM.h" + #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSwitch.h" + #include "gmock/gmock.h" + #include <type_traits> /// Pull in generated enum utility declarations and definitions. #include "EnumsGenTest.h.inc" + #include "EnumsGenTest.cpp.inc" /// Test namespaces and enum class/utility names. using Outer::Inner::ConvertToEnum; using Outer::Inner::ConvertToString; using Outer::Inner::StrEnum; +using Outer::Inner::StrEnumAttr; TEST(EnumsGenTest, GeneratedStrEnumDefinition) { EXPECT_EQ(0u, static_cast<uint64_t>(StrEnum::CaseA)); @@ -110,3 +118,41 @@ TEST(EnumsGenTest, GeneratedCustomStringToSymbolFn) { auto none = symbolizePrettyIntEnum("Case1"); EXPECT_FALSE(none); } + +TEST(EnumsGenTest, GeneratedIntAttributeClass) { + mlir::MLIRContext ctx; + I32Enum rawVal = I32Enum::Case5; + + I32EnumAttr enumAttr = I32EnumAttr::get(&ctx, rawVal); + EXPECT_NE(enumAttr, nullptr); + EXPECT_EQ(enumAttr.getValue(), rawVal); + + mlir::Type intType = mlir::IntegerType::get(&ctx, 32); + mlir::Attribute intAttr = mlir::IntegerAttr::get(intType, 5); + EXPECT_TRUE(intAttr.isa<I32EnumAttr>()); + EXPECT_EQ(intAttr, enumAttr); +} + +TEST(EnumsGenTest, GeneratedStringAttributeClass) { + mlir::MLIRContext ctx; + StrEnum rawVal = StrEnum::CaseA; + + StrEnumAttr enumAttr = StrEnumAttr::get(&ctx, rawVal); + EXPECT_NE(enumAttr, nullptr); + EXPECT_EQ(enumAttr.getValue(), rawVal); + + mlir::Attribute strAttr = mlir::StringAttr::get(&ctx, "CaseA"); + EXPECT_TRUE(strAttr.isa<StrEnumAttr>()); + EXPECT_EQ(strAttr, enumAttr); +} + +TEST(EnumsGenTest, GeneratedBitAttributeClass) { + mlir::MLIRContext ctx; + + mlir::Type intType = mlir::IntegerType::get(&ctx, 32); + mlir::Attribute intAttr = mlir::IntegerAttr::get( + intType, + static_cast<uint32_t>(BitEnumWithNone::Bit1 | BitEnumWithNone::Bit3)); + EXPECT_TRUE(intAttr.isa<BitEnumWithNoneAttr>()); + EXPECT_TRUE(intAttr.isa<BitEnumWithoutNoneAttr>()); +} diff --git a/mlir/unittests/TableGen/enums.td b/mlir/unittests/TableGen/enums.td index b2c8f6f..cdcc182 100644 --- a/mlir/unittests/TableGen/enums.td +++ b/mlir/unittests/TableGen/enums.td @@ -15,6 +15,7 @@ def StrEnum: StrEnumAttr<"StrEnum", "A test enum", [CaseA, CaseB]> { let cppNamespace = "Outer::Inner"; let stringToSymbolFnName = "ConvertToEnum"; let symbolToStringFnName = "ConvertToString"; + let genSpecializedAttr = 1; } def Case5: I32EnumAttrCase<"Case5", 5>; |