diff options
author | Jeff Niu <jeff@modular.com> | 2024-04-19 09:23:00 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-19 09:23:00 -0700 |
commit | e553ac4d8148291914526f4f66f09e362ce0a63f (patch) | |
tree | 87ec0f996aa034e34484152055fef4ea64c434f4 | |
parent | d86079f93c9f59d31f2cebb55dce24783070bd77 (diff) | |
download | llvm-e553ac4d8148291914526f4f66f09e362ce0a63f.zip llvm-e553ac4d8148291914526f4f66f09e362ce0a63f.tar.gz llvm-e553ac4d8148291914526f4f66f09e362ce0a63f.tar.bz2 |
[mlir][llvm] Port `overflowFlags` to a native operation property (RELAND) (#89410)
This PR changes the LLVM dialect's IntegerOverflowFlags to be stored on
operations as native properties.
Reland to fix flang
-rw-r--r-- | flang/lib/Optimizer/CodeGen/CodeGen.cpp | 10 | ||||
-rw-r--r-- | mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h | 22 | ||||
-rw-r--r-- | mlir/include/mlir/Conversion/LLVMCommon/Pattern.h | 14 | ||||
-rw-r--r-- | mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h | 16 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td | 76 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 23 | ||||
-rw-r--r-- | mlir/include/mlir/Target/LLVMIR/ModuleImport.h | 3 | ||||
-rw-r--r-- | mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp | 7 | ||||
-rw-r--r-- | mlir/lib/Conversion/LLVMCommon/Pattern.cpp | 19 | ||||
-rw-r--r-- | mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp | 26 | ||||
-rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 76 | ||||
-rw-r--r-- | mlir/lib/Target/LLVMIR/ModuleImport.cpp | 7 |
12 files changed, 183 insertions, 116 deletions
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index d909bda..921eac2 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -2110,9 +2110,8 @@ struct XArrayCoorOpConversion const bool baseIsBoxed = coor.getMemref().getType().isa<fir::BaseBoxType>(); TypePair baseBoxTyPair = baseIsBoxed ? getBoxTypePair(coor.getMemref().getType()) : TypePair{}; - mlir::LLVM::IntegerOverflowFlagsAttr nsw = - mlir::LLVM::IntegerOverflowFlagsAttr::get( - rewriter.getContext(), mlir::LLVM::IntegerOverflowFlags::nsw); + mlir::LLVM::IntegerOverflowFlags nsw = + mlir::LLVM::IntegerOverflowFlags::nsw; // For each dimension of the array, generate the offset calculation. for (unsigned i = 0; i < rank; ++i, ++indexOffset, ++shapeOffset, @@ -2396,9 +2395,8 @@ private: auto cpnTy = fir::dyn_cast_ptrOrBoxEleTy(boxObjTy); mlir::Type llvmPtrTy = ::getLlvmPtrType(coor.getContext()); mlir::Type byteTy = ::getI8Type(coor.getContext()); - mlir::LLVM::IntegerOverflowFlagsAttr nsw = - mlir::LLVM::IntegerOverflowFlagsAttr::get( - rewriter.getContext(), mlir::LLVM::IntegerOverflowFlags::nsw); + mlir::LLVM::IntegerOverflowFlags nsw = + mlir::LLVM::IntegerOverflowFlags::nsw; for (unsigned i = 1, last = operands.size(); i < last; ++i) { if (auto arrTy = cpnTy.dyn_cast<fir::SequenceType>()) { diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h index 0891e2b..7ffc861 100644 --- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h +++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h @@ -31,11 +31,6 @@ convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr); LLVM::IntegerOverflowFlags convertArithOverflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags); -/// Creates an LLVM overflow attribute from a given arithmetic overflow -/// attribute. -LLVM::IntegerOverflowFlagsAttr -convertArithOverflowAttrToLLVM(arith::IntegerOverflowFlagsAttr flagsAttr); - /// Creates an LLVM rounding mode enum value from a given arithmetic rounding /// mode enum value. LLVM::RoundingMode @@ -72,6 +67,9 @@ public: } ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); } + LLVM::IntegerOverflowFlags getOverflowFlags() const { + return LLVM::IntegerOverflowFlags::none; + } private: NamedAttrList convertedAttr; @@ -89,19 +87,18 @@ public: // Get the name of the arith overflow attribute. StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName(); // Remove the source overflow attribute. - auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>( - convertedAttr.erase(arithAttrName)); - if (arithAttr) { - StringRef targetAttrName = TargetOp::getIntegerOverflowAttrName(); - convertedAttr.set(targetAttrName, - convertArithOverflowAttrToLLVM(arithAttr)); + if (auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>( + convertedAttr.erase(arithAttrName))) { + overflowFlags = convertArithOverflowFlagsToLLVM(arithAttr.getValue()); } } ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); } + LLVM::IntegerOverflowFlags getOverflowFlags() const { return overflowFlags; } private: NamedAttrList convertedAttr; + LLVM::IntegerOverflowFlags overflowFlags = LLVM::IntegerOverflowFlags::none; }; template <typename SourceOp, typename TargetOp> @@ -132,6 +129,9 @@ public: } ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); } + LLVM::IntegerOverflowFlags getOverflowFlags() const { + return LLVM::IntegerOverflowFlags::none; + } private: NamedAttrList convertedAttr; diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h index f362167..f3bf5b6 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -11,6 +11,7 @@ #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { @@ -18,13 +19,16 @@ class CallOpInterface; namespace LLVM { namespace detail { +/// Handle generically setting flags as native properties on LLVM operations. +void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags); + /// Replaces the given operation "op" with a new operation of type "targetOp" /// and given operands. -LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, - ValueRange operands, - ArrayRef<NamedAttribute> targetAttrs, - const LLVMTypeConverter &typeConverter, - ConversionPatternRewriter &rewriter); +LogicalResult oneToOneRewrite( + Operation *op, StringRef targetOp, ValueRange operands, + ArrayRef<NamedAttribute> targetAttrs, + const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, + IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none); } // namespace detail } // namespace LLVM diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h index 279175b..9642815 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h @@ -54,11 +54,11 @@ LogicalResult handleMultidimensionalVectors( std::function<Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter); -LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp, - ValueRange operands, - ArrayRef<NamedAttribute> targetAttrs, - const LLVMTypeConverter &typeConverter, - ConversionPatternRewriter &rewriter); +LogicalResult vectorOneToOneRewrite( + Operation *op, StringRef targetOp, ValueRange operands, + ArrayRef<NamedAttribute> targetAttrs, + const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, + IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none); } // namespace detail } // namespace LLVM @@ -70,6 +70,9 @@ public: AttrConvertPassThrough(SourceOp srcOp) : srcAttrs(srcOp->getAttrs()) {} ArrayRef<NamedAttribute> getAttrs() const { return srcAttrs; } + LLVM::IntegerOverflowFlags getOverflowFlags() const { + return LLVM::IntegerOverflowFlags::none; + } private: ArrayRef<NamedAttribute> srcAttrs; @@ -100,7 +103,8 @@ public: return LLVM::detail::vectorOneToOneRewrite( op, TargetOp::getOperationName(), adaptor.getOperands(), - attrConvert.getAttrs(), *this->getTypeConverter(), rewriter); + attrConvert.getAttrs(), *this->getTypeConverter(), rewriter, + attrConvert.getOverflowFlags()); } }; } // namespace mlir diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td index cee752a..7085f81 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td @@ -50,58 +50,40 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> { def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface"> { let description = [{ - Access to op integer overflow flags. + This interface defines an LLVM operation with integer overflow flags and + provides a uniform API for accessing them. }]; let cppNamespace = "::mlir::LLVM"; let methods = [ - InterfaceMethod< - /*desc=*/ "Returns an IntegerOverflowFlagsAttr attribute for the operation", - /*returnType=*/ "IntegerOverflowFlagsAttr", - /*methodName=*/ "getOverflowAttr", - /*args=*/ (ins), - /*methodBody=*/ [{}], - /*defaultImpl=*/ [{ - auto op = cast<ConcreteOp>(this->getOperation()); - return op.getOverflowFlagsAttr(); - }] - >, - InterfaceMethod< - /*desc=*/ "Returns whether the operation has the No Unsigned Wrap keyword", - /*returnType=*/ "bool", - /*methodName=*/ "hasNoUnsignedWrap", - /*args=*/ (ins), - /*methodBody=*/ [{}], - /*defaultImpl=*/ [{ - auto op = cast<ConcreteOp>(this->getOperation()); - IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue(); - return bitEnumContainsAll(flags, IntegerOverflowFlags::nuw); - }] - >, - InterfaceMethod< - /*desc=*/ "Returns whether the operation has the No Signed Wrap keyword", - /*returnType=*/ "bool", - /*methodName=*/ "hasNoSignedWrap", - /*args=*/ (ins), - /*methodBody=*/ [{}], - /*defaultImpl=*/ [{ - auto op = cast<ConcreteOp>(this->getOperation()); - IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue(); - return bitEnumContainsAll(flags, IntegerOverflowFlags::nsw); - }] - >, - StaticInterfaceMethod< - /*desc=*/ [{Returns the name of the IntegerOverflowFlagsAttr attribute - for the operation}], - /*returnType=*/ "StringRef", - /*methodName=*/ "getIntegerOverflowAttrName", - /*args=*/ (ins), - /*methodBody=*/ [{}], - /*defaultImpl=*/ [{ - return "overflowFlags"; - }] - > + InterfaceMethod<[{ + Get the integer overflow flags for the operation. + }], "IntegerOverflowFlags", "getOverflowFlags", (ins), [{}], [{ + return $_op.getProperties().overflowFlags; + }]>, + InterfaceMethod<[{ + Set the integer overflow flags for the operation. + }], "void", "setOverflowFlags", (ins "IntegerOverflowFlags":$flags), [{}], [{ + $_op.getProperties().overflowFlags = flags; + }]>, + InterfaceMethod<[{ + Returns whether the operation has the No Unsigned Wrap keyword. + }], "bool", "hasNoUnsignedWrap", (ins), [{}], [{ + return bitEnumContainsAll($_op.getOverflowFlags(), + IntegerOverflowFlags::nuw); + }]>, + InterfaceMethod<[{ + Returns whether the operation has the No Signed Wrap keyword. + }], "bool", "hasNoSignedWrap", (ins), [{}], [{ + return bitEnumContainsAll($_op.getOverflowFlags(), + IntegerOverflowFlags::nsw); + }]>, + StaticInterfaceMethod<[{ + Get the attribute name of the overflow flags property. + }], "StringRef", "getOverflowFlagsAttrName", (ins), [{}], [{ + return "overflowFlags"; + }]>, ]; } diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index f8f9264..eedae4b 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -59,17 +59,30 @@ class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName, list<Trait> traits = []> : LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName, !listconcat([DeclareOpInterfaceMethods<IntegerOverflowFlagsInterface>], traits)> { - dag iofArg = ( - ins DefaultValuedAttr<LLVM_IntegerOverflowFlagsAttr, "{}">:$overflowFlags); + dag iofArg = (ins EnumProperty<"IntegerOverflowFlags">:$overflowFlags); let arguments = !con(commonArgs, iofArg); + + let builders = [ + OpBuilder<(ins "Type":$type, "Value":$lhs, "Value":$rhs, + "IntegerOverflowFlags":$overflowFlags), [{ + build($_builder, $_state, type, lhs, rhs); + $_state.getOrAddProperties<Properties>().overflowFlags = overflowFlags; + }]>, + OpBuilder<(ins "Value":$lhs, "Value":$rhs, + "IntegerOverflowFlags":$overflowFlags), [{ + build($_builder, $_state, lhs, rhs); + $_state.getOrAddProperties<Properties>().overflowFlags = overflowFlags; + }]> + ]; + string mlirBuilder = [{ auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs); - moduleImport.setIntegerOverflowFlagsAttr(inst, op); + moduleImport.setIntegerOverflowFlags(inst, op); $res = op; }]; let assemblyFormat = [{ - $lhs `,` $rhs (`overflow` `` $overflowFlags^)? - custom<LLVMOpAttrs>(attr-dict) `:` type($res) + $lhs `,` $rhs `` custom<OverflowFlags>($overflowFlags) + `` custom<LLVMOpAttrs>(attr-dict) `:` type($res) }]; string llvmBuilder = "$res = builder.Create" # instName # diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h index 1a188b1..04d098d 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h @@ -183,8 +183,7 @@ public: /// Sets the integer overflow flags (nsw/nuw) attribute for the imported /// operation `op` given the original instruction `inst`. Asserts if the /// operation does not implement the integer overflow flag interface. - void setIntegerOverflowFlagsAttr(llvm::Instruction *inst, - Operation *op) const; + void setIntegerOverflowFlags(llvm::Instruction *inst, Operation *op) const; /// Sets the fastmath flags attribute for the imported operation `op` given /// the original instruction `inst`. Asserts if the operation does not diff --git a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp index f12eba9..cf60a04 100644 --- a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp +++ b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp @@ -49,13 +49,6 @@ LLVM::IntegerOverflowFlags mlir::arith::convertArithOverflowFlagsToLLVM( return llvmFlags; } -LLVM::IntegerOverflowFlagsAttr mlir::arith::convertArithOverflowAttrToLLVM( - arith::IntegerOverflowFlagsAttr flagsAttr) { - arith::IntegerOverflowFlags arithFlags = flagsAttr.getValue(); - return LLVM::IntegerOverflowFlagsAttr::get( - flagsAttr.getContext(), convertArithOverflowFlagsToLLVM(arithFlags)); -} - LLVM::RoundingMode mlir::arith::convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode) { switch (roundingMode) { diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index 83c31a2..1886dfa 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -329,14 +329,19 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( // Detail methods //===----------------------------------------------------------------------===// +void LLVM::detail::setNativeProperties(Operation *op, + IntegerOverflowFlags overflowFlags) { + if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op)) + iface.setOverflowFlags(overflowFlags); +} + /// Replaces the given operation "op" with a new operation of type "targetOp" /// and given operands. -LogicalResult -LLVM::detail::oneToOneRewrite(Operation *op, StringRef targetOp, - ValueRange operands, - ArrayRef<NamedAttribute> targetAttrs, - const LLVMTypeConverter &typeConverter, - ConversionPatternRewriter &rewriter) { +LogicalResult LLVM::detail::oneToOneRewrite( + Operation *op, StringRef targetOp, ValueRange operands, + ArrayRef<NamedAttribute> targetAttrs, + const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, + IntegerOverflowFlags overflowFlags) { unsigned numResults = op->getNumResults(); SmallVector<Type> resultTypes; @@ -352,6 +357,8 @@ LLVM::detail::oneToOneRewrite(Operation *op, StringRef targetOp, rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands, resultTypes, targetAttrs); + setNativeProperties(newOp, overflowFlags); + // If the operation produced 0 or 1 result, return them immediately. if (numResults == 0) return rewriter.eraseOp(op), success(); diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp index 544bcc7..626135c 100644 --- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp @@ -103,12 +103,11 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors( return success(); } -LogicalResult -LLVM::detail::vectorOneToOneRewrite(Operation *op, StringRef targetOp, - ValueRange operands, - ArrayRef<NamedAttribute> targetAttrs, - const LLVMTypeConverter &typeConverter, - ConversionPatternRewriter &rewriter) { +LogicalResult LLVM::detail::vectorOneToOneRewrite( + Operation *op, StringRef targetOp, ValueRange operands, + ArrayRef<NamedAttribute> targetAttrs, + const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, + IntegerOverflowFlags overflowFlags) { assert(!operands.empty()); // Cannot convert ops if their operands are not of LLVM type. @@ -118,14 +117,15 @@ LLVM::detail::vectorOneToOneRewrite(Operation *op, StringRef targetOp, auto llvmNDVectorTy = operands[0].getType(); if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter, - rewriter); + rewriter, overflowFlags); - auto callback = [op, targetOp, targetAttrs, &rewriter](Type llvm1DVectorTy, - ValueRange operands) { - return rewriter - .create(op->getLoc(), rewriter.getStringAttr(targetOp), operands, - llvm1DVectorTy, targetAttrs) - ->getResult(0); + auto callback = [op, targetOp, targetAttrs, overflowFlags, + &rewriter](Type llvm1DVectorTy, ValueRange operands) { + Operation *newOp = + rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), + operands, llvm1DVectorTy, targetAttrs); + LLVM::detail::setNativeProperties(newOp, overflowFlags); + return newOp->getResult(0); }; return handleMultidimensionalVectors(op, operands, typeConverter, callback, diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 1db506e..78ff24d 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -47,6 +47,74 @@ using mlir::LLVM::linkage::getMaxEnumValForLinkage; #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc" +//===----------------------------------------------------------------------===// +// Property Helpers +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// IntegerOverflowFlags + +namespace mlir { +static Attribute convertToAttribute(MLIRContext *ctx, + IntegerOverflowFlags flags) { + return IntegerOverflowFlagsAttr::get(ctx, flags); +} + +static LogicalResult +convertFromAttribute(IntegerOverflowFlags &flags, Attribute attr, + function_ref<InFlightDiagnostic()> emitError) { + auto flagsAttr = dyn_cast<IntegerOverflowFlagsAttr>(attr); + if (!flagsAttr) { + return emitError() << "expected 'overflowFlags' attribute to be an " + "IntegerOverflowFlagsAttr, but got " + << attr; + } + flags = flagsAttr.getValue(); + return success(); +} +} // namespace mlir + +static ParseResult parseOverflowFlags(AsmParser &p, + IntegerOverflowFlags &flags) { + if (failed(p.parseOptionalKeyword("overflow"))) { + flags = IntegerOverflowFlags::none; + return success(); + } + if (p.parseLess()) + return failure(); + do { + StringRef kw; + SMLoc loc = p.getCurrentLocation(); + if (p.parseKeyword(&kw)) + return failure(); + std::optional<IntegerOverflowFlags> flag = + symbolizeIntegerOverflowFlags(kw); + if (!flag) + return p.emitError(loc, + "invalid overflow flag: expected nsw, nuw, or none"); + flags = flags | *flag; + } while (succeeded(p.parseOptionalComma())); + return p.parseGreater(); +} + +static void printOverflowFlags(AsmPrinter &p, Operation *op, + IntegerOverflowFlags flags) { + if (flags == IntegerOverflowFlags::none) + return; + p << " overflow<"; + SmallVector<StringRef, 2> strs; + if (bitEnumContainsAny(flags, IntegerOverflowFlags::nsw)) + strs.push_back("nsw"); + if (bitEnumContainsAny(flags, IntegerOverflowFlags::nuw)) + strs.push_back("nuw"); + llvm::interleaveComma(strs, p); + p << ">"; +} + +//===----------------------------------------------------------------------===// +// Attribute Helpers +//===----------------------------------------------------------------------===// + static constexpr const char kElemTypeAttrName[] = "elem_type"; static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) { @@ -70,12 +138,12 @@ static ParseResult parseLLVMOpAttrs(OpAsmParser &parser, static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op, DictionaryAttr attrs) { auto filteredAttrs = processFMFAttr(attrs.getValue()); - if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op)) + if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op)) { printer.printOptionalAttrDict( - filteredAttrs, - /*elidedAttrs=*/{iface.getIntegerOverflowAttrName()}); - else + filteredAttrs, /*elidedAttrs=*/{iface.getOverflowFlagsAttrName()}); + } else { printer.printOptionalAttrDict(filteredAttrs); + } } /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 9d41fa0..191b84a 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -627,8 +627,8 @@ void ModuleImport::setNonDebugMetadataAttrs(llvm::Instruction *inst, } } -void ModuleImport::setIntegerOverflowFlagsAttr(llvm::Instruction *inst, - Operation *op) const { +void ModuleImport::setIntegerOverflowFlags(llvm::Instruction *inst, + Operation *op) const { auto iface = cast<IntegerOverflowFlagsInterface>(op); IntegerOverflowFlags value = {}; @@ -636,8 +636,7 @@ void ModuleImport::setIntegerOverflowFlagsAttr(llvm::Instruction *inst, value = bitEnumSet(value, IntegerOverflowFlags::nuw, inst->hasNoUnsignedWrap()); - auto attr = IntegerOverflowFlagsAttr::get(op->getContext(), value); - iface->setAttr(iface.getIntegerOverflowAttrName(), attr); + iface.setOverflowFlags(value); } void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst, |