From e9e1c411b6db8fb739c2c7af0d41bdd48eeed3e5 Mon Sep 17 00:00:00 2001 From: Tom Eccles Date: Thu, 7 Dec 2023 10:35:00 +0000 Subject: [mlir][LLVM] Add nsw and nuw flags (#74508) The implementation of these are modeled after the existing fastmath flags for floating point arithmetic. --- mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td | 23 +++++++++ mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td | 57 ++++++++++++++++++++++ mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 30 ++++++++++-- mlir/include/mlir/Target/LLVMIR/ModuleImport.h | 6 +++ mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 8 ++- mlir/lib/Target/LLVMIR/ModuleImport.cpp | 13 +++++ mlir/test/Dialect/LLVMIR/roundtrip.mlir | 10 ++++ mlir/test/Target/LLVMIR/Import/nsw_nuw.ll | 14 ++++++ mlir/test/Target/LLVMIR/nsw_nuw.mlir | 14 ++++++ 9 files changed, 170 insertions(+), 5 deletions(-) create mode 100644 mlir/test/Target/LLVMIR/Import/nsw_nuw.ll create mode 100644 mlir/test/Target/LLVMIR/nsw_nuw.mlir (limited to 'mlir') diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td index 3b59844..a7b269e 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td @@ -429,6 +429,29 @@ def DISubprogramFlags : I32BitEnumAttr< } //===----------------------------------------------------------------------===// +// IntegerOverflowFlags +//===----------------------------------------------------------------------===// + +def IOFnone : I32BitEnumAttrCaseNone<"none">; +def IOFnsw : I32BitEnumAttrCaseBit<"nsw", 0>; +def IOFnuw : I32BitEnumAttrCaseBit<"nuw", 1>; + +def IntegerOverflowFlags : I32BitEnumAttr< + "IntegerOverflowFlags", + "LLVM integer overflow flags", + [IOFnone, IOFnsw, IOFnuw]> { + let separator = ", "; + let cppNamespace = "::mlir::LLVM"; + let genSpecializedAttr = 0; + let printBitEnumPrimaryGroups = 1; +} + +def LLVM_IntegerOverflowFlagsAttr : + EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +//===----------------------------------------------------------------------===// // FastmathFlags //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td index c5d65f7..81589eaf 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td @@ -48,6 +48,63 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> { ]; } +def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface"> { + let description = [{ + Access to op integer overflow flags. + }]; + + let cppNamespace = "::mlir::LLVM"; + + let methods = [ + InterfaceMethod< + /*desc=*/ "Returns an IntegerOverflowFlagsAttr attribute for the operation", + /*returnType=*/ "IntegerOverflowFlagsAttr", + /*methodName=*/ "getOverflowAttr", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + auto op = cast(this->getOperation()); + return op.getOverflowFlagsAttr(); + }] + >, + InterfaceMethod< + /*desc=*/ "Returns whether the operation has the No Unsigned Wrap keyword", + /*returnType=*/ "bool", + /*methodName=*/ "hasNoUnsignedWrap", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + auto op = cast(this->getOperation()); + IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue(); + return bitEnumContainsAll(flags, IntegerOverflowFlags::nuw); + }] + >, + InterfaceMethod< + /*desc=*/ "Returns whether the operation has the No Signed Wrap keyword", + /*returnType=*/ "bool", + /*methodName=*/ "hasNoSignedWrap", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + auto op = cast(this->getOperation()); + IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue(); + return bitEnumContainsAll(flags, IntegerOverflowFlags::nsw); + }] + >, + StaticInterfaceMethod< + /*desc=*/ [{Returns the name of the IntegerOveflowFlagsAttr attribute + for the operation}], + /*returnType=*/ "StringRef", + /*methodName=*/ "getIntegerOverflowAttrName", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + return "overflowFlags"; + }] + > + ]; +} + def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> { let description = [{ An interface for operations that can carry branch weights metadata. It diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 92460fa..88d9cd2 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -55,6 +55,26 @@ class LLVM_IntArithmeticOp($_location, $lhs, $rhs); }]; } +class LLVM_IntArithmeticOpWithOverflowFlag traits = []> : + LLVM_ArithmeticOpBase], traits)> { + dag iofArg = ( + ins DefaultValuedAttr:$overflowFlags); + let arguments = !con(commonArgs, iofArg); + string mlirBuilder = [{ + auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs); + moduleImport.setIntegerOverflowFlagsAttr(inst, op); + $res = op; + }]; + let assemblyFormat = [{ + $lhs `,` $rhs (`overflow` `` $overflowFlags^)? + custom(attr-dict) `:` type($res) + }]; + string llvmBuilder = + "$res = builder.Create" # instName # + "($lhs, $rhs, /*Name=*/\"\", op.hasNoUnsignedWrap(), op.hasNoSignedWrap());"; +} class LLVM_FloatArithmeticOp traits = []> : LLVM_ArithmeticOpBase; -def LLVM_SubOp : LLVM_IntArithmeticOp<"sub", "Sub">; -def LLVM_MulOp : LLVM_IntArithmeticOp<"mul", "Mul", [Commutative]>; +def LLVM_AddOp : LLVM_IntArithmeticOpWithOverflowFlag<"add", "Add", + [Commutative]>; +def LLVM_SubOp : LLVM_IntArithmeticOpWithOverflowFlag<"sub", "Sub", []>; +def LLVM_MulOp : LLVM_IntArithmeticOpWithOverflowFlag<"mul", "Mul", + [Commutative]>; def LLVM_UDivOp : LLVM_IntArithmeticOp<"udiv", "UDiv">; def LLVM_SDivOp : LLVM_IntArithmeticOp<"sdiv", "SDiv">; def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">; @@ -102,7 +124,7 @@ def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or"> { let hasFolder = 1; } def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">; -def LLVM_ShlOp : LLVM_IntArithmeticOp<"shl", "Shl"> { +def LLVM_ShlOp : LLVM_IntArithmeticOpWithOverflowFlag<"shl", "Shl", []> { let hasFolder = 1; } def LLVM_LShrOp : LLVM_IntArithmeticOp<"lshr", "LShr">; diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h index b8e449d..b49d2f5 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h @@ -172,6 +172,12 @@ public: /// attributes of LLVMFuncOp `funcOp`. void processFunctionAttributes(llvm::Function *func, LLVMFuncOp funcOp); + /// Sets the integer overflow flags (nsw/nuw) attribute for the imported + /// operation `op` given the original instruction `inst`. Asserts if the + /// operation does not implement the integer overflow flag interface. + void setIntegerOverflowFlagsAttr(llvm::Instruction *inst, + Operation *op) const; + /// Sets the fastmath flags attribute for the imported operation `op` given /// the original instruction `inst`. Asserts if the operation does not /// implement the fastmath interface. diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index c67bba0..53e1088 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -69,7 +69,13 @@ static ParseResult parseLLVMOpAttrs(OpAsmParser &parser, static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op, DictionaryAttr attrs) { - printer.printOptionalAttrDict(processFMFAttr(attrs.getValue())); + auto filteredAttrs = processFMFAttr(attrs.getValue()); + if (auto iface = dyn_cast(op)) + printer.printOptionalAttrDict( + filteredAttrs, + /*elidedAttrs=*/{iface.getIntegerOverflowAttrName()}); + else + printer.printOptionalAttrDict(filteredAttrs); } /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 4bdffa5..7c51ee7 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -619,6 +619,19 @@ void ModuleImport::setNonDebugMetadataAttrs(llvm::Instruction *inst, } } +void ModuleImport::setIntegerOverflowFlagsAttr(llvm::Instruction *inst, + Operation *op) const { + auto iface = cast(op); + + IntegerOverflowFlags value = {}; + value = bitEnumSet(value, IntegerOverflowFlags::nsw, inst->hasNoSignedWrap()); + value = + bitEnumSet(value, IntegerOverflowFlags::nuw, inst->hasNoUnsignedWrap()); + + auto attr = IntegerOverflowFlagsAttr::get(op->getContext(), value); + iface->setAttr(iface.getIntegerOverflowAttrName(), attr); +} + void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst, Operation *op) const { auto iface = cast(op); diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir index 4258997..594c3de 100644 --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -34,6 +34,16 @@ func.func @ops(%arg0: i32, %arg1: f32, %vptrcmp = llvm.icmp "ne" %arg5, %arg5 : !llvm.vec<2 x ptr> %typecheck_vptrcmp = llvm.add %vptrcmp, %vptrcmp : vector<2 x i1> +// Integer overflow flags +// CHECK: {{.*}} = llvm.add %[[I32]], %[[I32]] overflow : i32 +// CHECK: {{.*}} = llvm.sub %[[I32]], %[[I32]] overflow : i32 +// CHECK: {{.*}} = llvm.mul %[[I32]], %[[I32]] overflow : i32 +// CHECK: {{.*}} = llvm.shl %[[I32]], %[[I32]] overflow : i32 + %add_flag = llvm.add %arg0, %arg0 overflow : i32 + %sub_flag = llvm.sub %arg0, %arg0 overflow : i32 + %mul_flag = llvm.mul %arg0, %arg0 overflow : i32 + %shl_flag = llvm.shl %arg0, %arg0 overflow : i32 + // Floating point binary operations. // // CHECK: {{.*}} = llvm.fadd %[[FLOAT]], %[[FLOAT]] : f32 diff --git a/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll b/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll new file mode 100644 index 0000000..d08098a --- /dev/null +++ b/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll @@ -0,0 +1,14 @@ +; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s + +; CHECK-LABEL: @intflag_inst +define void @intflag_inst(i64 %arg1, i64 %arg2) { + ; CHECK: llvm.add %{{.*}}, %{{.*}} overflow : i64 + %1 = add nsw i64 %arg1, %arg2 + ; CHECK: llvm.sub %{{.*}}, %{{.*}} overflow : i64 + %2 = sub nuw i64 %arg1, %arg2 + ; CHECK: llvm.mul %{{.*}}, %{{.*}} overflow : i64 + %3 = mul nsw nuw i64 %arg1, %arg2 + ; CHECK: llvm.shl %{{.*}}, %{{.*}} overflow : i64 + %4 = shl nuw nsw i64 %arg1, %arg2 + ret void +} diff --git a/mlir/test/Target/LLVMIR/nsw_nuw.mlir b/mlir/test/Target/LLVMIR/nsw_nuw.mlir new file mode 100644 index 0000000..6843c2e --- /dev/null +++ b/mlir/test/Target/LLVMIR/nsw_nuw.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: define void @intflags_func +llvm.func @intflags_func(%arg0: i64, %arg1: i64) { + // CHECK: %{{.*}} = add nsw i64 %{{.*}}, %{{.*}} + %0 = llvm.add %arg0, %arg1 overflow : i64 + // CHECK: %{{.*}} = sub nuw i64 %{{.*}}, %{{.*}} + %1 = llvm.sub %arg0, %arg1 overflow : i64 + // CHECK: %{{.*}} = mul nuw nsw i64 %{{.*}}, %{{.*}} + %2 = llvm.mul %arg0, %arg1 overflow : i64 + // CHECK: %{{.*}} = shl nuw nsw i64 %{{.*}}, %{{.*}} + %3 = llvm.shl %arg0, %arg1 overflow : i64 + llvm.return +} -- cgit v1.1