24 files changed, 796 insertions, 295 deletions
diff --git a/clang/lib/Sema/SemaConcept.cpp b/clang/lib/Sema/SemaConcept.cpp index f4df63c..9cbd1bd 100644 --- a/clang/lib/Sema/SemaConcept.cpp +++ b/clang/lib/Sema/SemaConcept.cpp @@ -604,6 +604,10 @@ ConstraintSatisfactionChecker::SubstitutionInTemplateArguments( return std::nullopt; const NormalizedConstraint::OccurenceList &Used = Constraint.mappingOccurenceList(); + // The empty MLTAL situation should only occur when evaluating non-dependent + // constraints. + if (!MLTAL.getNumSubstitutedLevels()) + MLTAL.addOuterTemplateArguments(TD, {}, /*Final=*/false); SubstitutedOuterMost = llvm::to_vector_of<TemplateArgument>(MLTAL.getOutermost()); unsigned Offset = 0; @@ -623,9 +627,7 @@ ConstraintSatisfactionChecker::SubstitutionInTemplateArguments( if (Offset < SubstitutedOuterMost.size()) SubstitutedOuterMost.erase(SubstitutedOuterMost.begin() + Offset); - MLTAL.replaceOutermostTemplateArguments( - const_cast<NamedDecl *>(Constraint.getConstraintDecl()), - SubstitutedOuterMost); + MLTAL.replaceOutermostTemplateArguments(TD, SubstitutedOuterMost); return std::move(MLTAL); } @@ -956,11 +958,20 @@ ExprResult ConstraintSatisfactionChecker::Evaluate( ? Constraint.getPackSubstitutionIndex() : PackSubstitutionIndex; - Sema::InstantiatingTemplate _(S, ConceptId->getBeginLoc(), - Sema::InstantiatingTemplate::ConstraintsCheck{}, - ConceptId->getNamedConcept(), - MLTAL.getInnermost(), - Constraint.getSourceRange()); + Sema::InstantiatingTemplate InstTemplate( + S, ConceptId->getBeginLoc(), + Sema::InstantiatingTemplate::ConstraintsCheck{}, + ConceptId->getNamedConcept(), + // We may have empty template arguments when checking non-dependent + // nested constraint expressions. + // In such cases, non-SFINAE errors would have already been diagnosed + // during parameter mapping substitution, so the instantiating template + // arguments are less useful here. + MLTAL.getNumSubstitutedLevels() ? MLTAL.getInnermost() + : ArrayRef<TemplateArgument>{}, + Constraint.getSourceRange()); + if (InstTemplate.isInvalid()) + return ExprError(); unsigned Size = Satisfaction.Details.size(); diff --git a/clang/test/SemaTemplate/concepts.cpp b/clang/test/SemaTemplate/concepts.cpp index 1dbb989..3fbe7c0 100644 --- a/clang/test/SemaTemplate/concepts.cpp +++ b/clang/test/SemaTemplate/concepts.cpp @@ -1404,6 +1404,18 @@ static_assert(!std::is_constructible_v<span<4>, array<int, 3>>); } +namespace case7 { + +template <class _Tp, class _Up> +concept __same_as_impl = __is_same(_Tp, _Up); +template <class _Tp, class _Up> +concept same_as = __same_as_impl<_Tp, _Up>; +template <typename> +concept IsEntitySpec = + requires { requires same_as<void, void>; }; + +} + } namespace GH162125 { diff --git a/llvm/docs/QualGroup.rst b/llvm/docs/QualGroup.rst index b45f569..5c05e4e 100644 --- a/llvm/docs/QualGroup.rst +++ b/llvm/docs/QualGroup.rst @@ -75,6 +75,16 @@ They meet the criteria for inclusion below. Knowing their handles help us keep t - capitan-davide - capitan_davide - capitan-davide + * - Jorge Pinto Sousa + - Critical Techworks + - sousajo-cc + - sousajo-cc + - sousajo-cc + * - José Rui Simões + - Critical Software + - jr-simoes + - jr_simoes + - iznogoud-zz * - Oscar Slotosch - Validas - slotosch @@ -100,6 +110,11 @@ They meet the criteria for inclusion below. Knowing their handles help us keep t - YoungJunLee - YoungJunLee - IamYJLee + * - Zaky Hermawan + - No Affiliation + - ZakyHermawan + - quarkz99 + - zakyHermawan Organizations are limited to three representatives within the group to maintain diversity. 
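The SemaConcept.cpp change above covers constraint checking when the MultiLevelTemplateArgumentList is empty, which arises for non-dependent nested requirements such as the case7 regression test. The following standalone sketch is illustrative only (it mirrors the test's shape and relies on Clang's builtin __is_same; the concept names are taken from the test, minus the reserved underscores):

    // A nested requirement that names no template parameters of the
    // enclosing concept is non-dependent: it is evaluated with no
    // substituted template arguments in scope, the empty-MLTAL case
    // the patch handles.
    template <class T, class U>
    concept same_as_impl = __is_same(T, U);

    template <class T, class U>
    concept same_as = same_as_impl<T, U>;

    template <typename>
    concept IsEntitySpec =
        requires { requires same_as<void, void>; }; // non-dependent

    static_assert(IsEntitySpec<int>); // same_as<void, void> holds

Because the nested `requires same_as<void, void>` does not depend on the parameter of IsEntitySpec, the checker reaches the parameter-mapping substitution with no outer template arguments, which is why the patch adds an outer level on demand.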
diff --git a/llvm/include/llvm/IR/ConstantFPRange.h b/llvm/include/llvm/IR/ConstantFPRange.h index face5da..d47f6c0 100644 --- a/llvm/include/llvm/IR/ConstantFPRange.h +++ b/llvm/include/llvm/IR/ConstantFPRange.h @@ -216,6 +216,12 @@ public: /// Get the range without infinities. It is useful when we apply ninf flag to /// range of operands/results. LLVM_ABI ConstantFPRange getWithoutInf() const; + + /// Return a new range in the specified format with the specified rounding + /// mode. + LLVM_ABI ConstantFPRange + cast(const fltSemantics &DstSem, + APFloat::roundingMode RM = APFloat::rmNearestTiesToEven) const; }; inline raw_ostream &operator<<(raw_ostream &OS, const ConstantFPRange &CR) { diff --git a/llvm/lib/IR/ConstantFPRange.cpp b/llvm/lib/IR/ConstantFPRange.cpp index 2477e22..070e833 100644 --- a/llvm/lib/IR/ConstantFPRange.cpp +++ b/llvm/lib/IR/ConstantFPRange.cpp @@ -326,6 +326,8 @@ std::optional<bool> ConstantFPRange::getSignBit() const { } bool ConstantFPRange::operator==(const ConstantFPRange &CR) const { + assert(&getSemantics() == &CR.getSemantics() && + "Should only use the same semantics"); if (MayBeSNaN != CR.MayBeSNaN || MayBeQNaN != CR.MayBeQNaN) return false; return Lower.bitwiseIsEqual(CR.Lower) && Upper.bitwiseIsEqual(CR.Upper); @@ -425,3 +427,20 @@ ConstantFPRange ConstantFPRange::getWithoutInf() const { return ConstantFPRange(std::move(NewLower), std::move(NewUpper), MayBeQNaN, MayBeSNaN); } + +ConstantFPRange ConstantFPRange::cast(const fltSemantics &DstSem, + APFloat::roundingMode RM) const { + bool LosesInfo; + APFloat NewLower = Lower; + APFloat NewUpper = Upper; + // Be conservative: return the full range if the conversion is invalid. + if (NewLower.convert(DstSem, RM, &LosesInfo) == APFloat::opInvalidOp || + NewLower.isNaN()) + return getFull(DstSem); + if (NewUpper.convert(DstSem, RM, &LosesInfo) == APFloat::opInvalidOp || + NewUpper.isNaN()) + return getFull(DstSem); + return ConstantFPRange(std::move(NewLower), std::move(NewUpper), + /*MayBeQNaNVal=*/MayBeQNaN || MayBeSNaN, + /*MayBeSNaNVal=*/false); +} diff --git a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp index f692180..944a1e2 100644 --- a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp +++ b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp @@ -585,6 +585,10 @@ PPCTargetLowering::PPCTargetLowering(const PPCTargetMachine &TM, // We cannot sextinreg(i1). Expand to shifts. setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand); + // Custom handling for the PowerPC ucmp instruction. + setOperationAction(ISD::UCMP, MVT::i32, Custom); + setOperationAction(ISD::UCMP, MVT::i64, isPPC64 ? Custom : Expand); + // NOTE: EH_SJLJ_SETJMP/_LONGJMP supported here is NOT intended to support // SjLj exception handling but a light-weight setjmp/longjmp replacement to // support continuation, user-level threading, and etc.. As a result, no @@ -12618,6 +12622,33 @@ SDValue PPCTargetLowering::LowerSSUBO(SDValue Op, SelectionDAG &DAG) const { return DAG.getMergeValues({Sub, OverflowTrunc}, dl); } +// Lower unsigned 3-way compare producing -1/0/1. +SDValue PPCTargetLowering::LowerUCMP(SDValue Op, SelectionDAG &DAG) const { + SDLoc DL(Op); + SDValue A = DAG.getFreeze(Op.getOperand(0)); + SDValue B = DAG.getFreeze(Op.getOperand(1)); + EVT OpVT = A.getValueType(); // operand type + EVT ResVT = Op.getValueType(); // result type + + // First compute diff = A - B (will become subf). + SDValue Diff = DAG.getNode(ISD::SUB, DL, OpVT, A, B); + + // Generate B - A using SUBC to capture carry.
+ SDVTList VTs = DAG.getVTList(OpVT, MVT::i32); + SDValue SubC = DAG.getNode(PPCISD::SUBC, DL, VTs, B, A); + SDValue CA0 = SubC.getValue(1); + + // t2 = A - B + CA0 using SUBE. + SDValue SubE1 = DAG.getNode(PPCISD::SUBE, DL, VTs, A, B, CA0); + SDValue CA1 = SubE1.getValue(1); + + // res = diff - t2 + CA1 using SUBE (produces desired -1/0/1). + SDValue ResPair = DAG.getNode(PPCISD::SUBE, DL, VTs, Diff, SubE1, CA1); + + // Extract the first result and truncate to the result type if needed. + return DAG.getSExtOrTrunc(ResPair.getValue(0), DL, ResVT); +} + /// LowerOperation - Provide custom lowering hooks for some operations. /// SDValue PPCTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { @@ -12722,6 +12753,8 @@ SDValue PPCTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { case ISD::UADDO_CARRY: case ISD::USUBO_CARRY: return LowerADDSUBO_CARRY(Op, DAG); + case ISD::UCMP: + return LowerUCMP(Op, DAG); } } diff --git a/llvm/lib/Target/PowerPC/PPCISelLowering.h b/llvm/lib/Target/PowerPC/PPCISelLowering.h index 6694305..59f3387 100644 --- a/llvm/lib/Target/PowerPC/PPCISelLowering.h +++ b/llvm/lib/Target/PowerPC/PPCISelLowering.h @@ -1318,6 +1318,7 @@ namespace llvm { SDValue LowerIS_FPCLASS(SDValue Op, SelectionDAG &DAG) const; SDValue LowerADDSUBO_CARRY(SDValue Op, SelectionDAG &DAG) const; SDValue LowerADDSUBO(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerUCMP(SDValue Op, SelectionDAG &DAG) const; SDValue lowerToLibCall(const char *LibCallName, SDValue Op, SelectionDAG &DAG) const; SDValue lowerLibCallBasedOnType(const char *LibCallFloatName, diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.td b/llvm/lib/Target/RISCV/RISCVInstrInfo.td index 7a14929..b9e01c3 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.td @@ -1653,17 +1653,18 @@ def riscv_selectcc_frag : PatFrag<(ops node:$lhs, node:$rhs, node:$cc, node:$falsev), [{}], IntCCtoRISCVCC>; -multiclass SelectCC_GPR_rrirr<DAGOperand valty, ValueType vt> { +multiclass SelectCC_GPR_rrirr<DAGOperand valty, ValueType vt, + ValueType cmpvt = XLenVT> { let usesCustomInserter = 1 in def _Using_CC_GPR : Pseudo<(outs valty:$dst), (ins GPR:$lhs, GPR:$rhs, cond_code:$cc, valty:$truev, valty:$falsev), [(set valty:$dst, - (riscv_selectcc_frag:$cc (XLenVT GPR:$lhs), GPR:$rhs, cond, + (riscv_selectcc_frag:$cc (cmpvt GPR:$lhs), GPR:$rhs, cond, (vt valty:$truev), valty:$falsev))]>; // Explicitly select 0 in the condition to X0. The register coalescer doesn't // always do it.
- def : Pat<(riscv_selectcc_frag:$cc (XLenVT GPR:$lhs), 0, cond, (vt valty:$truev), + def : Pat<(riscv_selectcc_frag:$cc (cmpvt GPR:$lhs), 0, cond, (vt valty:$truev), valty:$falsev), (!cast<Instruction>(NAME#"_Using_CC_GPR") GPR:$lhs, (XLenVT X0), (IntCCtoRISCVCC $cc), valty:$truev, valty:$falsev)>; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoD.td b/llvm/lib/Target/RISCV/RISCVInstrInfoD.td index b9510ef..65e7e3b 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoD.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoD.td @@ -59,9 +59,9 @@ def FPR64IN32X : RegisterOperand<GPRPair> { def DExt : ExtInfo<"", "", [HasStdExtD], f64, FPR64, FPR32, FPR64, ?>; def ZdinxExt : ExtInfo<"_INX", "Zfinx", [HasStdExtZdinx, IsRV64], - f64, FPR64INX, FPR32INX, FPR64INX, ?>; + f64, FPR64INX, FPR32INX, FPR64INX, ?, i64>; def Zdinx32Ext : ExtInfo<"_IN32X", "ZdinxRV32Only", [HasStdExtZdinx, IsRV32], - f64, FPR64IN32X, FPR32INX, FPR64IN32X, ?>; + f64, FPR64IN32X, FPR32INX, FPR64IN32X, ?, i32>; defvar DExts = [DExt, ZdinxExt, Zdinx32Ext]; defvar DExtsRV64 = [DExt, ZdinxExt]; @@ -261,8 +261,10 @@ let Predicates = [HasStdExtZdinx, IsRV32] in { /// Float conversion operations // f64 -> f32, f32 -> f64 -def : Pat<(any_fpround FPR64IN32X:$rs1), (FCVT_S_D_IN32X FPR64IN32X:$rs1, FRM_DYN)>; -def : Pat<(any_fpextend FPR32INX:$rs1), (FCVT_D_S_IN32X FPR32INX:$rs1, FRM_RNE)>; +def : Pat<(any_fpround FPR64IN32X:$rs1), + (FCVT_S_D_IN32X FPR64IN32X:$rs1, (i32 FRM_DYN))>; +def : Pat<(any_fpextend FPR32INX:$rs1), + (FCVT_D_S_IN32X FPR32INX:$rs1, (i32 FRM_RNE))>; } // Predicates = [HasStdExtZdinx, IsRV32] // [u]int<->double conversion patterns must be gated on IsRV32 or IsRV64, so @@ -321,7 +323,7 @@ def : Pat<(any_fsqrt FPR64INX:$rs1), (FSQRT_D_INX FPR64INX:$rs1, FRM_DYN)>; def : Pat<(fneg FPR64INX:$rs1), (FSGNJN_D_INX $rs1, $rs1)>; def : Pat<(fabs FPR64INX:$rs1), (FSGNJX_D_INX $rs1, $rs1)>; -def : Pat<(riscv_fclass FPR64INX:$rs1), (FCLASS_D_INX $rs1)>; +def : Pat<(i64 (riscv_fclass FPR64INX:$rs1)), (FCLASS_D_INX $rs1)>; def : PatFprFpr<fcopysign, FSGNJ_D_INX, FPR64INX, f64>; def : PatFprFpr<riscv_fsgnjx, FSGNJX_D_INX, FPR64INX, f64>; @@ -354,41 +356,46 @@ def : Pat<(fneg (any_fma_nsz FPR64INX:$rs1, FPR64INX:$rs2, FPR64INX:$rs3)), } // Predicates = [HasStdExtZdinx, IsRV64] let Predicates = [HasStdExtZdinx, IsRV32] in { -def : Pat<(any_fsqrt FPR64IN32X:$rs1), (FSQRT_D_IN32X FPR64IN32X:$rs1, FRM_DYN)>; +def : Pat<(any_fsqrt FPR64IN32X:$rs1), + (FSQRT_D_IN32X FPR64IN32X:$rs1, (i32 FRM_DYN))>; def : Pat<(fneg FPR64IN32X:$rs1), (FSGNJN_D_IN32X $rs1, $rs1)>; def : Pat<(fabs FPR64IN32X:$rs1), (FSGNJX_D_IN32X $rs1, $rs1)>; -def : Pat<(riscv_fclass FPR64IN32X:$rs1), (FCLASS_D_IN32X $rs1)>; +def : Pat<(i32 (riscv_fclass FPR64IN32X:$rs1)), (FCLASS_D_IN32X $rs1)>; def : PatFprFpr<fcopysign, FSGNJ_D_IN32X, FPR64IN32X, f64>; def : PatFprFpr<riscv_fsgnjx, FSGNJX_D_IN32X, FPR64IN32X, f64>; def : Pat<(fcopysign FPR64IN32X:$rs1, (fneg FPR64IN32X:$rs2)), (FSGNJN_D_IN32X FPR64IN32X:$rs1, FPR64IN32X:$rs2)>; def : Pat<(fcopysign FPR64IN32X:$rs1, FPR32INX:$rs2), - (FSGNJ_D_IN32X $rs1, (FCVT_D_S_IN32X $rs2, FRM_RNE))>; + (FSGNJ_D_IN32X $rs1, (FCVT_D_S_IN32X $rs2, (i32 FRM_RNE)))>; def : Pat<(fcopysign FPR32INX:$rs1, FPR64IN32X:$rs2), - (FSGNJ_S_INX $rs1, (FCVT_S_D_IN32X $rs2, FRM_DYN))>; + (FSGNJ_S_INX $rs1, (FCVT_S_D_IN32X $rs2, (i32 FRM_DYN)))>; // fmadd: rs1 * rs2 + rs3 def : Pat<(any_fma FPR64IN32X:$rs1, FPR64IN32X:$rs2, FPR64IN32X:$rs3), - (FMADD_D_IN32X $rs1, $rs2, $rs3, FRM_DYN)>; + (FMADD_D_IN32X $rs1, $rs2, $rs3, (i32 FRM_DYN))>; // fmsub: rs1 * rs2 
- rs3 def : Pat<(any_fma FPR64IN32X:$rs1, FPR64IN32X:$rs2, (fneg FPR64IN32X:$rs3)), - (FMSUB_D_IN32X FPR64IN32X:$rs1, FPR64IN32X:$rs2, FPR64IN32X:$rs3, FRM_DYN)>; + (FMSUB_D_IN32X FPR64IN32X:$rs1, FPR64IN32X:$rs2, FPR64IN32X:$rs3, + (i32 FRM_DYN))>; // fnmsub: -rs1 * rs2 + rs3 def : Pat<(any_fma (fneg FPR64IN32X:$rs1), FPR64IN32X:$rs2, FPR64IN32X:$rs3), - (FNMSUB_D_IN32X FPR64IN32X:$rs1, FPR64IN32X:$rs2, FPR64IN32X:$rs3, FRM_DYN)>; + (FNMSUB_D_IN32X FPR64IN32X:$rs1, FPR64IN32X:$rs2, FPR64IN32X:$rs3, + (i32 FRM_DYN))>; // fnmadd: -rs1 * rs2 - rs3 def : Pat<(any_fma (fneg FPR64IN32X:$rs1), FPR64IN32X:$rs2, (fneg FPR64IN32X:$rs3)), - (FNMADD_D_IN32X FPR64IN32X:$rs1, FPR64IN32X:$rs2, FPR64IN32X:$rs3, FRM_DYN)>; + (FNMADD_D_IN32X FPR64IN32X:$rs1, FPR64IN32X:$rs2, FPR64IN32X:$rs3, + (i32 FRM_DYN))>; // fnmadd: -(rs1 * rs2 + rs3) (the nsz flag on the FMA) def : Pat<(fneg (any_fma_nsz FPR64IN32X:$rs1, FPR64IN32X:$rs2, FPR64IN32X:$rs3)), - (FNMADD_D_IN32X FPR64IN32X:$rs1, FPR64IN32X:$rs2, FPR64IN32X:$rs3, FRM_DYN)>; + (FNMADD_D_IN32X FPR64IN32X:$rs1, FPR64IN32X:$rs2, FPR64IN32X:$rs3, + (i32 FRM_DYN))>; } // Predicates = [HasStdExtZdinx, IsRV32] // The ratified 20191213 ISA spec defines fmin and fmax in a way that matches @@ -441,42 +448,42 @@ def : PatSetCC<FPR64, any_fsetccs, SETOLE, FLE_D, f64>; let Predicates = [HasStdExtZdinx, IsRV64] in { // Match signaling FEQ_D -def : Pat<(XLenVT (strict_fsetccs (f64 FPR64INX:$rs1), FPR64INX:$rs2, SETEQ)), +def : Pat<(XLenVT (strict_fsetccs FPR64INX:$rs1, FPR64INX:$rs2, SETEQ)), (AND (XLenVT (FLE_D_INX $rs1, $rs2)), (XLenVT (FLE_D_INX $rs2, $rs1)))>; -def : Pat<(XLenVT (strict_fsetccs (f64 FPR64INX:$rs1), FPR64INX:$rs2, SETOEQ)), +def : Pat<(XLenVT (strict_fsetccs FPR64INX:$rs1, FPR64INX:$rs2, SETOEQ)), (AND (XLenVT (FLE_D_INX $rs1, $rs2)), (XLenVT (FLE_D_INX $rs2, $rs1)))>; // If both operands are the same, use a single FLE. -def : Pat<(XLenVT (strict_fsetccs (f64 FPR64INX:$rs1), FPR64INX:$rs1, SETEQ)), +def : Pat<(XLenVT (strict_fsetccs FPR64INX:$rs1, FPR64INX:$rs1, SETEQ)), (FLE_D_INX $rs1, $rs1)>; -def : Pat<(XLenVT (strict_fsetccs (f64 FPR64INX:$rs1), FPR64INX:$rs1, SETOEQ)), +def : Pat<(XLenVT (strict_fsetccs FPR64INX:$rs1, FPR64INX:$rs1, SETOEQ)), (FLE_D_INX $rs1, $rs1)>; -def : PatSetCC<FPR64INX, any_fsetccs, SETLT, FLT_D_INX, f64>; -def : PatSetCC<FPR64INX, any_fsetccs, SETOLT, FLT_D_INX, f64>; -def : PatSetCC<FPR64INX, any_fsetccs, SETLE, FLE_D_INX, f64>; -def : PatSetCC<FPR64INX, any_fsetccs, SETOLE, FLE_D_INX, f64>; +def : PatSetCC<FPR64INX, any_fsetccs, SETLT, FLT_D_INX, f64, i64>; +def : PatSetCC<FPR64INX, any_fsetccs, SETOLT, FLT_D_INX, f64, i64>; +def : PatSetCC<FPR64INX, any_fsetccs, SETLE, FLE_D_INX, f64, i64>; +def : PatSetCC<FPR64INX, any_fsetccs, SETOLE, FLE_D_INX, f64, i64>; } // Predicates = [HasStdExtZdinx, IsRV64] let Predicates = [HasStdExtZdinx, IsRV32] in { // Match signaling FEQ_D -def : Pat<(XLenVT (strict_fsetccs (f64 FPR64IN32X:$rs1), FPR64IN32X:$rs2, SETEQ)), +def : Pat<(i32 (strict_fsetccs FPR64IN32X:$rs1, FPR64IN32X:$rs2, SETEQ)), (AND (XLenVT (FLE_D_IN32X $rs1, $rs2)), (XLenVT (FLE_D_IN32X $rs2, $rs1)))>; -def : Pat<(XLenVT (strict_fsetccs (f64 FPR64IN32X:$rs1), FPR64IN32X:$rs2, SETOEQ)), +def : Pat<(i32 (strict_fsetccs FPR64IN32X:$rs1, FPR64IN32X:$rs2, SETOEQ)), (AND (XLenVT (FLE_D_IN32X $rs1, $rs2)), (XLenVT (FLE_D_IN32X $rs2, $rs1)))>; // If both operands are the same, use a single FLE. 
-def : Pat<(XLenVT (strict_fsetccs (f64 FPR64IN32X:$rs1), FPR64IN32X:$rs1, SETEQ)), +def : Pat<(i32 (strict_fsetccs FPR64IN32X:$rs1, FPR64IN32X:$rs1, SETEQ)), (FLE_D_IN32X $rs1, $rs1)>; -def : Pat<(XLenVT (strict_fsetccs (f64 FPR64IN32X:$rs1), FPR64IN32X:$rs1, SETOEQ)), +def : Pat<(i32 (strict_fsetccs FPR64IN32X:$rs1, FPR64IN32X:$rs1, SETOEQ)), (FLE_D_IN32X $rs1, $rs1)>; -def : PatSetCC<FPR64IN32X, any_fsetccs, SETLT, FLT_D_IN32X, f64>; -def : PatSetCC<FPR64IN32X, any_fsetccs, SETOLT, FLT_D_IN32X, f64>; -def : PatSetCC<FPR64IN32X, any_fsetccs, SETLE, FLE_D_IN32X, f64>; -def : PatSetCC<FPR64IN32X, any_fsetccs, SETOLE, FLE_D_IN32X, f64>; +def : PatSetCC<FPR64IN32X, any_fsetccs, SETLT, FLT_D_IN32X, f64, i32>; +def : PatSetCC<FPR64IN32X, any_fsetccs, SETOLT, FLT_D_IN32X, f64, i32>; +def : PatSetCC<FPR64IN32X, any_fsetccs, SETLE, FLE_D_IN32X, f64, i32>; +def : PatSetCC<FPR64IN32X, any_fsetccs, SETOLE, FLE_D_IN32X, f64, i32>; } // Predicates = [HasStdExtZdinx, IsRV32] let Predicates = [HasStdExtD] in { @@ -511,7 +518,7 @@ def SplitF64Pseudo } // Predicates = [HasStdExtD, NoStdExtZfa, IsRV32] let Predicates = [HasStdExtZdinx, IsRV64] in { -defm Select_FPR64INX : SelectCC_GPR_rrirr<FPR64INX, f64>; +defm Select_FPR64INX : SelectCC_GPR_rrirr<FPR64INX, f64, i64>; def PseudoFROUND_D_INX : PseudoFROUND<FPR64INX, f64>; @@ -523,9 +530,9 @@ def : StPat<store, SD, GPR, f64>; } // Predicates = [HasStdExtZdinx, IsRV64] let Predicates = [HasStdExtZdinx, IsRV32] in { -defm Select_FPR64IN32X : SelectCC_GPR_rrirr<FPR64IN32X, f64>; +defm Select_FPR64IN32X : SelectCC_GPR_rrirr<FPR64IN32X, f64, i32>; -def PseudoFROUND_D_IN32X : PseudoFROUND<FPR64IN32X, f64>; +def PseudoFROUND_D_IN32X : PseudoFROUND<FPR64IN32X, f64, i32>; /// Loads let hasSideEffects = 0, mayLoad = 1, mayStore = 0, Size = 8, isCodeGenOnly = 1 in diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoF.td b/llvm/lib/Target/RISCV/RISCVInstrInfoF.td index fde030e..6571d99 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoF.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoF.td @@ -131,7 +131,7 @@ def FPR32INX : RegisterOperand<GPRF32> { // The DAGOperand can be unset if the predicates are not enough to define it. 
class ExtInfo<string suffix, string space, list<Predicate> predicates, ValueType primaryvt, DAGOperand primaryty, DAGOperand f32ty, - DAGOperand f64ty, DAGOperand f16ty> { + DAGOperand f64ty, DAGOperand f16ty, ValueType intvt = XLenVT> { list<Predicate> Predicates = predicates; string Suffix = suffix; string Space = space; @@ -140,6 +140,7 @@ class ExtInfo<string suffix, string space, list<Predicate> predicates, DAGOperand F32Ty = f32ty; DAGOperand F64Ty = f64ty; ValueType PrimaryVT = primaryvt; + ValueType IntVT = intvt; } def FExt : ExtInfo<"", "", [HasStdExtF], f32, FPR32, FPR32, ?, ?>; @@ -314,9 +315,9 @@ multiclass FPCmp_rr_m<bits<7> funct7, bits<3> funct3, string opcodestr, def Ext.Suffix : FPCmp_rr<funct7, funct3, opcodestr, Ext.PrimaryTy, Commutable>; } -class PseudoFROUND<DAGOperand Ty, ValueType vt> +class PseudoFROUND<DAGOperand Ty, ValueType vt, ValueType intvt = XLenVT> : Pseudo<(outs Ty:$rd), (ins Ty:$rs1, Ty:$rs2, ixlenimm:$rm), - [(set Ty:$rd, (vt (riscv_fround Ty:$rs1, Ty:$rs2, timm:$rm)))]> { + [(set Ty:$rd, (vt (riscv_fround Ty:$rs1, Ty:$rs2, (intvt timm:$rm))))]> { let hasSideEffects = 0; let mayLoad = 0; let mayStore = 0; @@ -529,13 +530,14 @@ def fpimm0 : PatLeaf<(fpimm), [{ return N->isExactlyValue(+0.0); }]>; /// Generic pattern classes class PatSetCC<DAGOperand Ty, SDPatternOperator OpNode, CondCode Cond, - RVInstCommon Inst, ValueType vt> - : Pat<(XLenVT (OpNode (vt Ty:$rs1), Ty:$rs2, Cond)), (Inst $rs1, $rs2)>; + RVInstCommon Inst, ValueType vt, ValueType intvt = XLenVT> + : Pat<(intvt (OpNode (vt Ty:$rs1), Ty:$rs2, Cond)), (Inst $rs1, $rs2)>; multiclass PatSetCC_m<SDPatternOperator OpNode, CondCode Cond, RVInstCommon Inst, ExtInfo Ext> { let Predicates = Ext.Predicates in def Ext.Suffix : PatSetCC<Ext.PrimaryTy, OpNode, Cond, - !cast<RVInstCommon>(Inst#Ext.Suffix), Ext.PrimaryVT>; + !cast<RVInstCommon>(Inst#Ext.Suffix), + Ext.PrimaryVT, Ext.IntVT>; } class PatFprFpr<SDPatternOperator OpNode, RVInstR Inst, @@ -549,14 +551,15 @@ multiclass PatFprFpr_m<SDPatternOperator OpNode, RVInstR Inst, } class PatFprFprDynFrm<SDPatternOperator OpNode, RVInstRFrm Inst, - DAGOperand RegTy, ValueType vt> - : Pat<(OpNode (vt RegTy:$rs1), (vt RegTy:$rs2)), (Inst $rs1, $rs2, FRM_DYN)>; + DAGOperand RegTy, ValueType vt, ValueType intvt> + : Pat<(OpNode (vt RegTy:$rs1), (vt RegTy:$rs2)), + (Inst $rs1, $rs2,(intvt FRM_DYN))>; multiclass PatFprFprDynFrm_m<SDPatternOperator OpNode, RVInstRFrm Inst, ExtInfo Ext> { let Predicates = Ext.Predicates in def Ext.Suffix : PatFprFprDynFrm<OpNode, !cast<RVInstRFrm>(Inst#Ext.Suffix), - Ext.PrimaryTy, Ext.PrimaryVT>; + Ext.PrimaryTy, Ext.PrimaryVT, Ext.IntVT>; } /// Float conversion operations diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td index 014da99..52a2b29 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td @@ -69,16 +69,16 @@ def ZhinxminExt : ExtInfo<"_INX", "Zfinx", f16, FPR16INX, FPR32INX, ?, FPR16INX>; def ZhinxZdinxExt : ExtInfo<"_INX", "Zfinx", [HasStdExtZhinx, HasStdExtZdinx, IsRV64], - ?, ?, FPR32INX, FPR64INX, FPR16INX>; + ?, ?, FPR32INX, FPR64INX, FPR16INX, i64>; def ZhinxminZdinxExt : ExtInfo<"_INX", "Zfinx", [HasStdExtZhinxmin, HasStdExtZdinx, IsRV64], - ?, ?, FPR32INX, FPR64INX, FPR16INX>; + ?, ?, FPR32INX, FPR64INX, FPR16INX, i64>; def ZhinxZdinx32Ext : ExtInfo<"_IN32X", "ZdinxGPRPairRV32", [HasStdExtZhinx, HasStdExtZdinx, IsRV32], - ?, ?, FPR32INX, FPR64IN32X, FPR16INX>; + ?, ?, FPR32INX, FPR64IN32X, FPR16INX, i32>; def 
ZhinxminZdinx32Ext : ExtInfo<"_IN32X", "ZdinxGPRPairRV32", [HasStdExtZhinxmin, HasStdExtZdinx, IsRV32], - ?, ?, FPR32INX, FPR64IN32X, FPR16INX>; + ?, ?, FPR32INX, FPR64IN32X, FPR16INX, i32>; defvar ZfhExts = [ZfhExt, ZhinxExt]; defvar ZfhminExts = [ZfhminExt, ZhinxminExt]; diff --git a/llvm/test/CodeGen/PowerPC/memcmp.ll b/llvm/test/CodeGen/PowerPC/memcmp.ll index 39f9269..4998d87 100644 --- a/llvm/test/CodeGen/PowerPC/memcmp.ll +++ b/llvm/test/CodeGen/PowerPC/memcmp.ll @@ -6,12 +6,10 @@ define signext i32 @memcmp8(ptr nocapture readonly %buffer1, ptr nocapture reado ; CHECK: # %bb.0: ; CHECK-NEXT: ldbrx 3, 0, 3 ; CHECK-NEXT: ldbrx 4, 0, 4 -; CHECK-NEXT: cmpld 3, 4 -; CHECK-NEXT: subc 3, 4, 3 -; CHECK-NEXT: subfe 3, 4, 4 -; CHECK-NEXT: li 4, -1 -; CHECK-NEXT: neg 3, 3 -; CHECK-NEXT: isellt 3, 4, 3 +; CHECK-NEXT: subc 6, 4, 3 +; CHECK-NEXT: sub 5, 3, 4 +; CHECK-NEXT: subfe 3, 4, 3 +; CHECK-NEXT: subfe 3, 3, 5 ; CHECK-NEXT: extsw 3, 3 ; CHECK-NEXT: blr %call = tail call signext i32 @memcmp(ptr %buffer1, ptr %buffer2, i64 8) @@ -23,11 +21,11 @@ define signext i32 @memcmp4(ptr nocapture readonly %buffer1, ptr nocapture reado ; CHECK: # %bb.0: ; CHECK-NEXT: lwbrx 3, 0, 3 ; CHECK-NEXT: lwbrx 4, 0, 4 -; CHECK-NEXT: cmplw 3, 4 -; CHECK-NEXT: sub 5, 4, 3 -; CHECK-NEXT: li 3, -1 -; CHECK-NEXT: rldicl 5, 5, 1, 63 -; CHECK-NEXT: isellt 3, 3, 5 +; CHECK-NEXT: subc 6, 4, 3 +; CHECK-NEXT: sub 5, 3, 4 +; CHECK-NEXT: subfe 3, 4, 3 +; CHECK-NEXT: subfe 3, 3, 5 +; CHECK-NEXT: extsw 3, 3 ; CHECK-NEXT: blr %call = tail call signext i32 @memcmp(ptr %buffer1, ptr %buffer2, i64 4) ret i32 %call diff --git a/llvm/test/CodeGen/PowerPC/ucmp.ll b/llvm/test/CodeGen/PowerPC/ucmp.ll index d2dff6e..4d393dd 100644 --- a/llvm/test/CodeGen/PowerPC/ucmp.ll +++ b/llvm/test/CodeGen/PowerPC/ucmp.ll @@ -4,12 +4,10 @@ define i8 @ucmp_8_8(i8 zeroext %x, i8 zeroext %y) nounwind { ; CHECK-LABEL: ucmp_8_8: ; CHECK: # %bb.0: -; CHECK-NEXT: cmplw 3, 4 -; CHECK-NEXT: sub 5, 4, 3 -; CHECK-NEXT: li 3, -1 -; CHECK-NEXT: rldicl 5, 5, 1, 63 -; CHECK-NEXT: rldic 3, 3, 0, 32 -; CHECK-NEXT: isellt 3, 3, 5 +; CHECK-NEXT: subc 6, 4, 3 +; CHECK-NEXT: sub 5, 3, 4 +; CHECK-NEXT: subfe 3, 4, 3 +; CHECK-NEXT: subfe 3, 3, 5 ; CHECK-NEXT: blr %1 = call i8 @llvm.ucmp(i8 %x, i8 %y) ret i8 %1 @@ -18,12 +16,10 @@ define i8 @ucmp_8_8(i8 zeroext %x, i8 zeroext %y) nounwind { define i8 @ucmp_8_16(i16 zeroext %x, i16 zeroext %y) nounwind { ; CHECK-LABEL: ucmp_8_16: ; CHECK: # %bb.0: -; CHECK-NEXT: cmplw 3, 4 -; CHECK-NEXT: sub 5, 4, 3 -; CHECK-NEXT: li 3, -1 -; CHECK-NEXT: rldicl 5, 5, 1, 63 -; CHECK-NEXT: rldic 3, 3, 0, 32 -; CHECK-NEXT: isellt 3, 3, 5 +; CHECK-NEXT: subc 6, 4, 3 +; CHECK-NEXT: sub 5, 3, 4 +; CHECK-NEXT: subfe 3, 4, 3 +; CHECK-NEXT: subfe 3, 3, 5 ; CHECK-NEXT: blr %1 = call i8 @llvm.ucmp(i16 %x, i16 %y) ret i8 %1 @@ -32,14 +28,10 @@ define i8 @ucmp_8_16(i16 zeroext %x, i16 zeroext %y) nounwind { define i8 @ucmp_8_32(i32 %x, i32 %y) nounwind { ; CHECK-LABEL: ucmp_8_32: ; CHECK: # %bb.0: -; CHECK-NEXT: clrldi 5, 4, 32 -; CHECK-NEXT: clrldi 6, 3, 32 -; CHECK-NEXT: sub 5, 5, 6 -; CHECK-NEXT: cmplw 3, 4 -; CHECK-NEXT: li 3, -1 -; CHECK-NEXT: rldic 3, 3, 0, 32 -; CHECK-NEXT: rldicl 5, 5, 1, 63 -; CHECK-NEXT: isellt 3, 3, 5 +; CHECK-NEXT: subc 6, 4, 3 +; CHECK-NEXT: sub 5, 3, 4 +; CHECK-NEXT: subfe 3, 4, 3 +; CHECK-NEXT: subfe 3, 3, 5 ; CHECK-NEXT: blr %1 = call i8 @llvm.ucmp(i32 %x, i32 %y) ret i8 %1 @@ -48,12 +40,10 @@ define i8 @ucmp_8_32(i32 %x, i32 %y) nounwind { define i8 @ucmp_8_64(i64 %x, i64 %y) nounwind { ; CHECK-LABEL: ucmp_8_64: ; CHECK: 
# %bb.0: -; CHECK-NEXT: cmpld 3, 4 -; CHECK-NEXT: subc 3, 4, 3 -; CHECK-NEXT: subfe 3, 4, 4 -; CHECK-NEXT: li 4, -1 -; CHECK-NEXT: neg 3, 3 -; CHECK-NEXT: isellt 3, 4, 3 +; CHECK-NEXT: subc 6, 4, 3 +; CHECK-NEXT: sub 5, 3, 4 +; CHECK-NEXT: subfe 3, 4, 3 +; CHECK-NEXT: subfe 3, 3, 5 ; CHECK-NEXT: blr %1 = call i8 @llvm.ucmp(i64 %x, i64 %y) ret i8 %1 @@ -82,14 +72,10 @@ define i8 @ucmp_8_128(i128 %x, i128 %y) nounwind { define i32 @ucmp_32_32(i32 %x, i32 %y) nounwind { ; CHECK-LABEL: ucmp_32_32: ; CHECK: # %bb.0: -; CHECK-NEXT: clrldi 5, 4, 32 -; CHECK-NEXT: clrldi 6, 3, 32 -; CHECK-NEXT: sub 5, 5, 6 -; CHECK-NEXT: cmplw 3, 4 -; CHECK-NEXT: li 3, -1 -; CHECK-NEXT: rldic 3, 3, 0, 32 -; CHECK-NEXT: rldicl 5, 5, 1, 63 -; CHECK-NEXT: isellt 3, 3, 5 +; CHECK-NEXT: subc 6, 4, 3 +; CHECK-NEXT: sub 5, 3, 4 +; CHECK-NEXT: subfe 3, 4, 3 +; CHECK-NEXT: subfe 3, 3, 5 ; CHECK-NEXT: blr %1 = call i32 @llvm.ucmp(i32 %x, i32 %y) ret i32 %1 @@ -98,12 +84,10 @@ define i32 @ucmp_32_32(i32 %x, i32 %y) nounwind { define i32 @ucmp_32_64(i64 %x, i64 %y) nounwind { ; CHECK-LABEL: ucmp_32_64: ; CHECK: # %bb.0: -; CHECK-NEXT: cmpld 3, 4 -; CHECK-NEXT: subc 3, 4, 3 -; CHECK-NEXT: subfe 3, 4, 4 -; CHECK-NEXT: li 4, -1 -; CHECK-NEXT: neg 3, 3 -; CHECK-NEXT: isellt 3, 4, 3 +; CHECK-NEXT: subc 6, 4, 3 +; CHECK-NEXT: sub 5, 3, 4 +; CHECK-NEXT: subfe 3, 4, 3 +; CHECK-NEXT: subfe 3, 3, 5 ; CHECK-NEXT: blr %1 = call i32 @llvm.ucmp(i64 %x, i64 %y) ret i32 %1 @@ -112,12 +96,10 @@ define i32 @ucmp_32_64(i64 %x, i64 %y) nounwind { define i64 @ucmp_64_64(i64 %x, i64 %y) nounwind { ; CHECK-LABEL: ucmp_64_64: ; CHECK: # %bb.0: -; CHECK-NEXT: subc 5, 4, 3 -; CHECK-NEXT: cmpld 3, 4 -; CHECK-NEXT: li 3, -1 -; CHECK-NEXT: subfe 5, 4, 4 -; CHECK-NEXT: neg 5, 5 -; CHECK-NEXT: isellt 3, 3, 5 +; CHECK-NEXT: subc 6, 4, 3 +; CHECK-NEXT: sub 5, 3, 4 +; CHECK-NEXT: subfe 3, 4, 3 +; CHECK-NEXT: subfe 3, 3, 5 ; CHECK-NEXT: blr %1 = call i64 @llvm.ucmp(i64 %x, i64 %y) ret i64 %1 diff --git a/llvm/test/Transforms/AtomicExpand/X86/expand-atomic-non-integer.ll b/llvm/test/Transforms/AtomicExpand/X86/expand-atomic-non-integer.ll index 5929c15..84c7df1 100644 --- a/llvm/test/Transforms/AtomicExpand/X86/expand-atomic-non-integer.ll +++ b/llvm/test/Transforms/AtomicExpand/X86/expand-atomic-non-integer.ll @@ -1,152 +1,190 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6 ; RUN: opt -S %s -passes=atomic-expand -mtriple=x86_64-linux-gnu | FileCheck %s ; This file tests the functions `llvm::convertAtomicLoadToIntegerType` and -; `llvm::convertAtomicStoreToIntegerType`. If X86 stops using this +; `llvm::convertAtomicStoreToIntegerType`. If X86 stops using this ; functionality, please move this test to a target which still is. 
define float @float_load_expand(ptr %ptr) { -; CHECK-LABEL: @float_load_expand -; CHECK: %1 = load atomic i32, ptr %ptr unordered, align 4 -; CHECK: %2 = bitcast i32 %1 to float -; CHECK: ret float %2 +; CHECK-LABEL: define float @float_load_expand( +; CHECK-SAME: ptr [[PTR:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = load atomic i32, ptr [[PTR]] unordered, align 4 +; CHECK-NEXT: [[TMP2:%.*]] = bitcast i32 [[TMP1]] to float +; CHECK-NEXT: ret float [[TMP2]] +; %res = load atomic float, ptr %ptr unordered, align 4 ret float %res } define float @float_load_expand_seq_cst(ptr %ptr) { -; CHECK-LABEL: @float_load_expand_seq_cst -; CHECK: %1 = load atomic i32, ptr %ptr seq_cst, align 4 -; CHECK: %2 = bitcast i32 %1 to float -; CHECK: ret float %2 +; CHECK-LABEL: define float @float_load_expand_seq_cst( +; CHECK-SAME: ptr [[PTR:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = load atomic i32, ptr [[PTR]] seq_cst, align 4 +; CHECK-NEXT: [[TMP2:%.*]] = bitcast i32 [[TMP1]] to float +; CHECK-NEXT: ret float [[TMP2]] +; %res = load atomic float, ptr %ptr seq_cst, align 4 ret float %res } define float @float_load_expand_vol(ptr %ptr) { -; CHECK-LABEL: @float_load_expand_vol -; CHECK: %1 = load atomic volatile i32, ptr %ptr unordered, align 4 -; CHECK: %2 = bitcast i32 %1 to float -; CHECK: ret float %2 +; CHECK-LABEL: define float @float_load_expand_vol( +; CHECK-SAME: ptr [[PTR:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = load atomic volatile i32, ptr [[PTR]] unordered, align 4 +; CHECK-NEXT: [[TMP2:%.*]] = bitcast i32 [[TMP1]] to float +; CHECK-NEXT: ret float [[TMP2]] +; %res = load atomic volatile float, ptr %ptr unordered, align 4 ret float %res } define float @float_load_expand_addr1(ptr addrspace(1) %ptr) { -; CHECK-LABEL: @float_load_expand_addr1 -; CHECK: %1 = load atomic i32, ptr addrspace(1) %ptr unordered, align 4 -; CHECK: %2 = bitcast i32 %1 to float -; CHECK: ret float %2 +; CHECK-LABEL: define float @float_load_expand_addr1( +; CHECK-SAME: ptr addrspace(1) [[PTR:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = load atomic i32, ptr addrspace(1) [[PTR]] unordered, align 4 +; CHECK-NEXT: [[TMP2:%.*]] = bitcast i32 [[TMP1]] to float +; CHECK-NEXT: ret float [[TMP2]] +; %res = load atomic float, ptr addrspace(1) %ptr unordered, align 4 ret float %res } define void @float_store_expand(ptr %ptr, float %v) { -; CHECK-LABEL: @float_store_expand -; CHECK: %1 = bitcast float %v to i32 -; CHECK: store atomic i32 %1, ptr %ptr unordered, align 4 +; CHECK-LABEL: define void @float_store_expand( +; CHECK-SAME: ptr [[PTR:%.*]], float [[V:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = bitcast float [[V]] to i32 +; CHECK-NEXT: store atomic i32 [[TMP1]], ptr [[PTR]] unordered, align 4 +; CHECK-NEXT: ret void +; store atomic float %v, ptr %ptr unordered, align 4 ret void } define void @float_store_expand_seq_cst(ptr %ptr, float %v) { -; CHECK-LABEL: @float_store_expand_seq_cst -; CHECK: %1 = bitcast float %v to i32 -; CHECK: store atomic i32 %1, ptr %ptr seq_cst, align 4 +; CHECK-LABEL: define void @float_store_expand_seq_cst( +; CHECK-SAME: ptr [[PTR:%.*]], float [[V:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = bitcast float [[V]] to i32 +; CHECK-NEXT: store atomic i32 [[TMP1]], ptr [[PTR]] seq_cst, align 4 +; CHECK-NEXT: ret void +; store atomic float %v, ptr %ptr seq_cst, align 4 ret void } define void @float_store_expand_vol(ptr %ptr, float %v) { -; CHECK-LABEL: @float_store_expand_vol -; CHECK: %1 = bitcast float %v to i32 -; CHECK: store atomic volatile i32 %1, ptr %ptr unordered, align 4 +; CHECK-LABEL: define void @float_store_expand_vol( +; 
CHECK-SAME: ptr [[PTR:%.*]], float [[V:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = bitcast float [[V]] to i32 +; CHECK-NEXT: store atomic volatile i32 [[TMP1]], ptr [[PTR]] unordered, align 4 +; CHECK-NEXT: ret void +; store atomic volatile float %v, ptr %ptr unordered, align 4 ret void } define void @float_store_expand_addr1(ptr addrspace(1) %ptr, float %v) { -; CHECK-LABEL: @float_store_expand_addr1 -; CHECK: %1 = bitcast float %v to i32 -; CHECK: store atomic i32 %1, ptr addrspace(1) %ptr unordered, align 4 +; CHECK-LABEL: define void @float_store_expand_addr1( +; CHECK-SAME: ptr addrspace(1) [[PTR:%.*]], float [[V:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = bitcast float [[V]] to i32 +; CHECK-NEXT: store atomic i32 [[TMP1]], ptr addrspace(1) [[PTR]] unordered, align 4 +; CHECK-NEXT: ret void +; store atomic float %v, ptr addrspace(1) %ptr unordered, align 4 ret void } define void @pointer_cmpxchg_expand(ptr %ptr, ptr %v) { -; CHECK-LABEL: @pointer_cmpxchg_expand -; CHECK: %1 = ptrtoint ptr %v to i64 -; CHECK: %2 = cmpxchg ptr %ptr, i64 0, i64 %1 seq_cst monotonic -; CHECK: %3 = extractvalue { i64, i1 } %2, 0 -; CHECK: %4 = extractvalue { i64, i1 } %2, 1 -; CHECK: %5 = inttoptr i64 %3 to ptr -; CHECK: %6 = insertvalue { ptr, i1 } poison, ptr %5, 0 -; CHECK: %7 = insertvalue { ptr, i1 } %6, i1 %4, 1 +; CHECK-LABEL: define void @pointer_cmpxchg_expand( +; CHECK-SAME: ptr [[PTR:%.*]], ptr [[V:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = ptrtoint ptr [[V]] to i64 +; CHECK-NEXT: [[TMP2:%.*]] = cmpxchg ptr [[PTR]], i64 0, i64 [[TMP1]] seq_cst monotonic, align 8 +; CHECK-NEXT: [[TMP3:%.*]] = extractvalue { i64, i1 } [[TMP2]], 0 +; CHECK-NEXT: [[TMP4:%.*]] = extractvalue { i64, i1 } [[TMP2]], 1 +; CHECK-NEXT: [[TMP5:%.*]] = inttoptr i64 [[TMP3]] to ptr +; CHECK-NEXT: [[TMP6:%.*]] = insertvalue { ptr, i1 } poison, ptr [[TMP5]], 0 +; CHECK-NEXT: [[TMP7:%.*]] = insertvalue { ptr, i1 } [[TMP6]], i1 [[TMP4]], 1 +; CHECK-NEXT: ret void +; cmpxchg ptr %ptr, ptr null, ptr %v seq_cst monotonic ret void } define void @pointer_cmpxchg_expand2(ptr %ptr, ptr %v) { -; CHECK-LABEL: @pointer_cmpxchg_expand2 -; CHECK: %1 = ptrtoint ptr %v to i64 -; CHECK: %2 = cmpxchg ptr %ptr, i64 0, i64 %1 release monotonic -; CHECK: %3 = extractvalue { i64, i1 } %2, 0 -; CHECK: %4 = extractvalue { i64, i1 } %2, 1 -; CHECK: %5 = inttoptr i64 %3 to ptr -; CHECK: %6 = insertvalue { ptr, i1 } poison, ptr %5, 0 -; CHECK: %7 = insertvalue { ptr, i1 } %6, i1 %4, 1 +; CHECK-LABEL: define void @pointer_cmpxchg_expand2( +; CHECK-SAME: ptr [[PTR:%.*]], ptr [[V:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = ptrtoint ptr [[V]] to i64 +; CHECK-NEXT: [[TMP2:%.*]] = cmpxchg ptr [[PTR]], i64 0, i64 [[TMP1]] release monotonic, align 8 +; CHECK-NEXT: [[TMP3:%.*]] = extractvalue { i64, i1 } [[TMP2]], 0 +; CHECK-NEXT: [[TMP4:%.*]] = extractvalue { i64, i1 } [[TMP2]], 1 +; CHECK-NEXT: [[TMP5:%.*]] = inttoptr i64 [[TMP3]] to ptr +; CHECK-NEXT: [[TMP6:%.*]] = insertvalue { ptr, i1 } poison, ptr [[TMP5]], 0 +; CHECK-NEXT: [[TMP7:%.*]] = insertvalue { ptr, i1 } [[TMP6]], i1 [[TMP4]], 1 +; CHECK-NEXT: ret void +; cmpxchg ptr %ptr, ptr null, ptr %v release monotonic ret void } define void @pointer_cmpxchg_expand3(ptr %ptr, ptr %v) { -; CHECK-LABEL: @pointer_cmpxchg_expand3 -; CHECK: %1 = ptrtoint ptr %v to i64 -; CHECK: %2 = cmpxchg ptr %ptr, i64 0, i64 %1 seq_cst seq_cst -; CHECK: %3 = extractvalue { i64, i1 } %2, 0 -; CHECK: %4 = extractvalue { i64, i1 } %2, 1 -; CHECK: %5 = inttoptr i64 %3 to ptr -; CHECK: %6 = insertvalue { ptr, i1 } poison, ptr %5, 0 -; CHECK: %7 = 
insertvalue { ptr, i1 } %6, i1 %4, 1 +; CHECK-LABEL: define void @pointer_cmpxchg_expand3( +; CHECK-SAME: ptr [[PTR:%.*]], ptr [[V:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = ptrtoint ptr [[V]] to i64 +; CHECK-NEXT: [[TMP2:%.*]] = cmpxchg ptr [[PTR]], i64 0, i64 [[TMP1]] seq_cst seq_cst, align 8 +; CHECK-NEXT: [[TMP3:%.*]] = extractvalue { i64, i1 } [[TMP2]], 0 +; CHECK-NEXT: [[TMP4:%.*]] = extractvalue { i64, i1 } [[TMP2]], 1 +; CHECK-NEXT: [[TMP5:%.*]] = inttoptr i64 [[TMP3]] to ptr +; CHECK-NEXT: [[TMP6:%.*]] = insertvalue { ptr, i1 } poison, ptr [[TMP5]], 0 +; CHECK-NEXT: [[TMP7:%.*]] = insertvalue { ptr, i1 } [[TMP6]], i1 [[TMP4]], 1 +; CHECK-NEXT: ret void +; cmpxchg ptr %ptr, ptr null, ptr %v seq_cst seq_cst ret void } define void @pointer_cmpxchg_expand4(ptr %ptr, ptr %v) { -; CHECK-LABEL: @pointer_cmpxchg_expand4 -; CHECK: %1 = ptrtoint ptr %v to i64 -; CHECK: %2 = cmpxchg weak ptr %ptr, i64 0, i64 %1 seq_cst seq_cst -; CHECK: %3 = extractvalue { i64, i1 } %2, 0 -; CHECK: %4 = extractvalue { i64, i1 } %2, 1 -; CHECK: %5 = inttoptr i64 %3 to ptr -; CHECK: %6 = insertvalue { ptr, i1 } poison, ptr %5, 0 -; CHECK: %7 = insertvalue { ptr, i1 } %6, i1 %4, 1 +; CHECK-LABEL: define void @pointer_cmpxchg_expand4( +; CHECK-SAME: ptr [[PTR:%.*]], ptr [[V:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = ptrtoint ptr [[V]] to i64 +; CHECK-NEXT: [[TMP2:%.*]] = cmpxchg weak ptr [[PTR]], i64 0, i64 [[TMP1]] seq_cst seq_cst, align 8 +; CHECK-NEXT: [[TMP3:%.*]] = extractvalue { i64, i1 } [[TMP2]], 0 +; CHECK-NEXT: [[TMP4:%.*]] = extractvalue { i64, i1 } [[TMP2]], 1 +; CHECK-NEXT: [[TMP5:%.*]] = inttoptr i64 [[TMP3]] to ptr +; CHECK-NEXT: [[TMP6:%.*]] = insertvalue { ptr, i1 } poison, ptr [[TMP5]], 0 +; CHECK-NEXT: [[TMP7:%.*]] = insertvalue { ptr, i1 } [[TMP6]], i1 [[TMP4]], 1 +; CHECK-NEXT: ret void +; cmpxchg weak ptr %ptr, ptr null, ptr %v seq_cst seq_cst ret void } define void @pointer_cmpxchg_expand5(ptr %ptr, ptr %v) { -; CHECK-LABEL: @pointer_cmpxchg_expand5 -; CHECK: %1 = ptrtoint ptr %v to i64 -; CHECK: %2 = cmpxchg volatile ptr %ptr, i64 0, i64 %1 seq_cst seq_cst -; CHECK: %3 = extractvalue { i64, i1 } %2, 0 -; CHECK: %4 = extractvalue { i64, i1 } %2, 1 -; CHECK: %5 = inttoptr i64 %3 to ptr -; CHECK: %6 = insertvalue { ptr, i1 } poison, ptr %5, 0 -; CHECK: %7 = insertvalue { ptr, i1 } %6, i1 %4, 1 +; CHECK-LABEL: define void @pointer_cmpxchg_expand5( +; CHECK-SAME: ptr [[PTR:%.*]], ptr [[V:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = ptrtoint ptr [[V]] to i64 +; CHECK-NEXT: [[TMP2:%.*]] = cmpxchg volatile ptr [[PTR]], i64 0, i64 [[TMP1]] seq_cst seq_cst, align 8 +; CHECK-NEXT: [[TMP3:%.*]] = extractvalue { i64, i1 } [[TMP2]], 0 +; CHECK-NEXT: [[TMP4:%.*]] = extractvalue { i64, i1 } [[TMP2]], 1 +; CHECK-NEXT: [[TMP5:%.*]] = inttoptr i64 [[TMP3]] to ptr +; CHECK-NEXT: [[TMP6:%.*]] = insertvalue { ptr, i1 } poison, ptr [[TMP5]], 0 +; CHECK-NEXT: [[TMP7:%.*]] = insertvalue { ptr, i1 } [[TMP6]], i1 [[TMP4]], 1 +; CHECK-NEXT: ret void +; cmpxchg volatile ptr %ptr, ptr null, ptr %v seq_cst seq_cst ret void } -define void @pointer_cmpxchg_expand6(ptr addrspace(1) %ptr, - ptr addrspace(2) %v) { -; CHECK-LABEL: @pointer_cmpxchg_expand6 -; CHECK: %1 = ptrtoint ptr addrspace(2) %v to i64 -; CHECK: %2 = cmpxchg ptr addrspace(1) %ptr, i64 0, i64 %1 seq_cst seq_cst -; CHECK: %3 = extractvalue { i64, i1 } %2, 0 -; CHECK: %4 = extractvalue { i64, i1 } %2, 1 -; CHECK: %5 = inttoptr i64 %3 to ptr addrspace(2) -; CHECK: %6 = insertvalue { ptr addrspace(2), i1 } poison, ptr addrspace(2) %5, 0 -; CHECK: %7 = insertvalue { ptr 
addrspace(2), i1 } %6, i1 %4, 1 +define void @pointer_cmpxchg_expand6(ptr addrspace(1) %ptr, ptr addrspace(2) %v) { +; CHECK-LABEL: define void @pointer_cmpxchg_expand6( +; CHECK-SAME: ptr addrspace(1) [[PTR:%.*]], ptr addrspace(2) [[V:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = ptrtoint ptr addrspace(2) [[V]] to i64 +; CHECK-NEXT: [[TMP2:%.*]] = cmpxchg ptr addrspace(1) [[PTR]], i64 0, i64 [[TMP1]] seq_cst seq_cst, align 8 +; CHECK-NEXT: [[TMP3:%.*]] = extractvalue { i64, i1 } [[TMP2]], 0 +; CHECK-NEXT: [[TMP4:%.*]] = extractvalue { i64, i1 } [[TMP2]], 1 +; CHECK-NEXT: [[TMP5:%.*]] = inttoptr i64 [[TMP3]] to ptr addrspace(2) +; CHECK-NEXT: [[TMP6:%.*]] = insertvalue { ptr addrspace(2), i1 } poison, ptr addrspace(2) [[TMP5]], 0 +; CHECK-NEXT: [[TMP7:%.*]] = insertvalue { ptr addrspace(2), i1 } [[TMP6]], i1 [[TMP4]], 1 +; CHECK-NEXT: ret void +; cmpxchg ptr addrspace(1) %ptr, ptr addrspace(2) null, ptr addrspace(2) %v seq_cst seq_cst ret void } diff --git a/llvm/unittests/IR/ConstantFPRangeTest.cpp b/llvm/unittests/IR/ConstantFPRangeTest.cpp index 5bc516d..58a65b9 100644 --- a/llvm/unittests/IR/ConstantFPRangeTest.cpp +++ b/llvm/unittests/IR/ConstantFPRangeTest.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "llvm/IR/ConstantFPRange.h" +#include "llvm/ADT/APFloat.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Operator.h" #include "gtest/gtest.h" @@ -818,4 +819,110 @@ TEST_F(ConstantFPRangeTest, getWithout) { APFloat::getLargest(Sem, /*Negative=*/true), APFloat(3.0))); } +TEST_F(ConstantFPRangeTest, cast) { + const fltSemantics &F16Sem = APFloat::IEEEhalf(); + const fltSemantics &BF16Sem = APFloat::BFloat(); + const fltSemantics &F32Sem = APFloat::IEEEsingle(); + const fltSemantics &F8NanOnlySem = APFloat::Float8E4M3FN(); + // normal -> normal (exact) + EXPECT_EQ(ConstantFPRange::getNonNaN(APFloat(1.0), APFloat(2.0)).cast(F32Sem), + ConstantFPRange::getNonNaN(APFloat(1.0f), APFloat(2.0f))); + EXPECT_EQ( + ConstantFPRange::getNonNaN(APFloat(-2.0f), APFloat(-1.0f)).cast(Sem), + ConstantFPRange::getNonNaN(APFloat(-2.0), APFloat(-1.0))); + // normal -> normal (inexact) + EXPECT_EQ( + ConstantFPRange::getNonNaN(APFloat(3.141592653589793), + APFloat(6.283185307179586)) + .cast(F32Sem), + ConstantFPRange::getNonNaN(APFloat(3.14159274f), APFloat(6.28318548f))); + // normal -> subnormal + EXPECT_EQ(ConstantFPRange::getNonNaN(APFloat(-5e-8), APFloat(5e-8)) + .cast(F16Sem) + .classify(), + fcSubnormal | fcZero); + // normal -> zero + EXPECT_EQ(ConstantFPRange::getNonNaN( + APFloat::getSmallestNormalized(Sem, /*Negative=*/true), + APFloat::getSmallestNormalized(Sem, /*Negative=*/false)) + .cast(F32Sem) + .classify(), + fcZero); + // normal -> inf + EXPECT_EQ(ConstantFPRange::getNonNaN(APFloat(-65536.0), APFloat(65536.0)) + .cast(F16Sem), + ConstantFPRange::getNonNaN(F16Sem)); + // nan -> qnan + EXPECT_EQ( + ConstantFPRange::getNaNOnly(Sem, /*MayBeQNaN=*/true, /*MayBeSNaN=*/false) + .cast(F32Sem), + ConstantFPRange::getNaNOnly(F32Sem, /*MayBeQNaN=*/true, + /*MayBeSNaN=*/false)); + EXPECT_EQ( + ConstantFPRange::getNaNOnly(Sem, /*MayBeQNaN=*/false, /*MayBeSNaN=*/true) + .cast(F32Sem), + ConstantFPRange::getNaNOnly(F32Sem, /*MayBeQNaN=*/true, + /*MayBeSNaN=*/false)); + EXPECT_EQ( + ConstantFPRange::getNaNOnly(Sem, /*MayBeQNaN=*/true, /*MayBeSNaN=*/true) + .cast(F32Sem), + ConstantFPRange::getNaNOnly(F32Sem, /*MayBeQNaN=*/true, + /*MayBeSNaN=*/false)); + // For BF16 -> F32, signaling bit is still lost. 
+ EXPECT_EQ(ConstantFPRange::getNaNOnly(BF16Sem, /*MayBeQNaN=*/true, + /*MayBeSNaN=*/true) + .cast(F32Sem), + ConstantFPRange::getNaNOnly(F32Sem, /*MayBeQNaN=*/true, + /*MayBeSNaN=*/false)); + // inf -> nan only (return full set for now) + EXPECT_EQ(ConstantFPRange::getNonNaN(APFloat::getInf(Sem, /*Negative=*/true), + APFloat::getInf(Sem, /*Negative=*/false)) + .cast(F8NanOnlySem), + ConstantFPRange::getFull(F8NanOnlySem)); + // other rounding modes + EXPECT_EQ( + ConstantFPRange::getNonNaN(APFloat::getSmallest(Sem, /*Negative=*/true), + APFloat::getSmallest(Sem, /*Negative=*/false)) + .cast(F32Sem, APFloat::rmTowardNegative), + ConstantFPRange::getNonNaN( + APFloat::getSmallest(F32Sem, /*Negative=*/true), + APFloat::getZero(F32Sem, /*Negative=*/false))); + EXPECT_EQ( + ConstantFPRange::getNonNaN(APFloat::getSmallest(Sem, /*Negative=*/true), + APFloat::getSmallest(Sem, /*Negative=*/false)) + .cast(F32Sem, APFloat::rmTowardPositive), + ConstantFPRange::getNonNaN( + APFloat::getZero(F32Sem, /*Negative=*/true), + APFloat::getSmallest(F32Sem, /*Negative=*/false))); + EXPECT_EQ( + ConstantFPRange::getNonNaN( + APFloat::getSmallestNormalized(Sem, /*Negative=*/true), + APFloat::getSmallestNormalized(Sem, /*Negative=*/false)) + .cast(F32Sem, APFloat::rmTowardZero), + ConstantFPRange::getNonNaN(APFloat::getZero(F32Sem, /*Negative=*/true), + APFloat::getZero(F32Sem, /*Negative=*/false))); + + EnumerateValuesInConstantFPRange( + ConstantFPRange::getFull(APFloat::Float8E4M3()), + [&](const APFloat &V) { + bool LosesInfo = false; + + APFloat DoubleV = V; + DoubleV.convert(Sem, APFloat::rmNearestTiesToEven, &LosesInfo); + ConstantFPRange DoubleCR = ConstantFPRange(V).cast(Sem); + EXPECT_TRUE(DoubleCR.contains(DoubleV)) + << "Casting " << V << " to double failed. " << DoubleCR + << " doesn't contain " << DoubleV; + + auto &FP4Sem = APFloat::Float4E2M1FN(); + APFloat FP4V = V; + FP4V.convert(FP4Sem, APFloat::rmNearestTiesToEven, &LosesInfo); + ConstantFPRange FP4CR = ConstantFPRange(V).cast(FP4Sem); + EXPECT_TRUE(FP4CR.contains(FP4V)) + << "Casting " << V << " to FP4E2M1FN failed. 
" << FP4CR + << " doesn't contain " << FP4V; + }, + /*IgnoreNaNPayload=*/true); +} + } // anonymous namespace diff --git a/llvm/utils/profcheck-xfail.txt b/llvm/utils/profcheck-xfail.txt index 092d63d..39ff476 100644 --- a/llvm/utils/profcheck-xfail.txt +++ b/llvm/utils/profcheck-xfail.txt @@ -73,7 +73,9 @@ CodeGen/Hexagon/loop-idiom/hexagon-memmove2.ll CodeGen/Hexagon/loop-idiom/memmove-rt-check.ll CodeGen/NVPTX/lower-ctor-dtor.ll CodeGen/RISCV/zmmul.ll +CodeGen/SPIRV/hlsl-resources/UniqueImplicitBindingNumber.ll CodeGen/WebAssembly/memory-interleave.ll +CodeGen/X86/global-variable-partition-with-dap.ll CodeGen/X86/masked_gather_scatter.ll CodeGen/X86/nocfivalue.ll DebugInfo/AArch64/ir-outliner.ll @@ -1095,6 +1097,7 @@ Transforms/LoopSimplifyCFG/invalidate-scev-dispositions.ll Transforms/LoopSimplifyCFG/lcssa.ll Transforms/LoopSimplifyCFG/live_block_marking.ll Transforms/LoopSimplifyCFG/mssa_update.ll +Transforms/LoopSimplifyCFG/pr117537.ll Transforms/LoopSimplifyCFG/update_parents.ll Transforms/LoopUnroll/peel-last-iteration-expansion-cost.ll Transforms/LoopUnroll/peel-last-iteration-with-guards.ll diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h index 5dd285e..2db1d84 100644 --- a/mlir/include/mlir-c/Rewrite.h +++ b/mlir/include/mlir-c/Rewrite.h @@ -38,6 +38,7 @@ DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void); DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void); DEFINE_C_API_STRUCT(MlirRewritePatternSet, void); DEFINE_C_API_STRUCT(MlirPatternRewriter, void); +DEFINE_C_API_STRUCT(MlirRewritePattern, const void); //===----------------------------------------------------------------------===// /// RewriterBase API inherited from OpBuilder @@ -302,11 +303,15 @@ MLIR_CAPI_EXPORTED void mlirIRRewriterDestroy(MlirRewriterBase rewriter); /// FrozenRewritePatternSet API //===----------------------------------------------------------------------===// +/// Freeze the given MlirRewritePatternSet to a MlirFrozenRewritePatternSet. +/// Note that the ownership of the input set is transferred into the frozen set +/// after this call. MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet -mlirFreezeRewritePattern(MlirRewritePatternSet op); +mlirFreezeRewritePattern(MlirRewritePatternSet set); +/// Destroy the given MlirFrozenRewritePatternSet. MLIR_CAPI_EXPORTED void -mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op); +mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set); MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp( MlirOperation op, MlirFrozenRewritePatternSet patterns, @@ -325,6 +330,51 @@ MLIR_CAPI_EXPORTED MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter); //===----------------------------------------------------------------------===// +/// RewritePattern API +//===----------------------------------------------------------------------===// + +/// Callbacks to construct a rewrite pattern. +typedef struct { + /// Optional constructor for the user data. + /// Set to nullptr to disable it. + void (*construct)(void *userData); + /// Optional destructor for the user data. + /// Set to nullptr to disable it. + void (*destruct)(void *userData); + /// The callback function to match against code rooted at the specified + /// operation, and perform the rewrite if the match is successful, + /// corresponding to RewritePattern::matchAndRewrite. 
+ MlirLogicalResult (*matchAndRewrite)(MlirRewritePattern pattern, + MlirOperation op, + MlirPatternRewriter rewriter, + void *userData); +} MlirRewritePatternCallbacks; + +/// Create a rewrite pattern that matches the operation +/// with the given rootName, corresponding to mlir::OpRewritePattern. +MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePattenCreate( + MlirStringRef rootName, unsigned benefit, MlirContext context, + MlirRewritePatternCallbacks callbacks, void *userData, + size_t nGeneratedNames, MlirStringRef *generatedNames); + +//===----------------------------------------------------------------------===// +/// RewritePatternSet API +//===----------------------------------------------------------------------===// + +/// Create an empty MlirRewritePatternSet. +MLIR_CAPI_EXPORTED MlirRewritePatternSet +mlirRewritePatternSetCreate(MlirContext context); + +/// Destruct the given MlirRewritePatternSet. +MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set); + +/// Add the given MlirRewritePattern into a MlirRewritePatternSet. +/// Note that the ownership of the pattern is transferred to the set after this +/// call. +MLIR_CAPI_EXPORTED void mlirRewritePatternSetAdd(MlirRewritePatternSet set, + MlirRewritePattern pattern); + +//===----------------------------------------------------------------------===// /// PDLPatternModule API //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/CAPI/Rewrite.h b/mlir/include/mlir/CAPI/Rewrite.h index 1038c0a..9c96d35 100644 --- a/mlir/include/mlir/CAPI/Rewrite.h +++ b/mlir/include/mlir/CAPI/Rewrite.h @@ -18,7 +18,19 @@ #include "mlir-c/Rewrite.h" #include "mlir/CAPI/Wrap.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" DEFINE_C_API_PTR_METHODS(MlirRewriterBase, mlir::RewriterBase) +DEFINE_C_API_PTR_METHODS(MlirRewritePattern, const mlir::RewritePattern) +DEFINE_C_API_PTR_METHODS(MlirRewritePatternSet, mlir::RewritePatternSet) +DEFINE_C_API_PTR_METHODS(MlirFrozenRewritePatternSet, + mlir::FrozenRewritePatternSet) +DEFINE_C_API_PTR_METHODS(MlirPatternRewriter, mlir::PatternRewriter) + +#if MLIR_ENABLE_PDL_IN_PATTERNMATCH +DEFINE_C_API_PTR_METHODS(MlirPDLPatternModule, mlir::PDLPatternModule) +DEFINE_C_API_PTR_METHODS(MlirPDLResultList, mlir::PDLResultList) +DEFINE_C_API_PTR_METHODS(MlirPDLValue, const mlir::PDLValue) +#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH #endif // MLIR_CAPIREWRITER_H diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index e2a0331..89fbeb7 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -3233,35 +3233,15 @@ def NVVM_CpAsyncBulkGlobalToSharedClusterOp : attr-dict `:` type($dstMem) `,` type($srcMem) }]; + let extraClassDeclaration = [{ + static mlir::NVVM::IDArgPair + getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase& builder); + }]; string llvmBuilder = [{ - // Arguments to the intrinsic: - // dst, mbar, src, size - // multicast_mask, cache_hint, - // flag for multicast_mask, - // flag for cache_hint - llvm::SmallVector<llvm::Value *> translatedOperands; - translatedOperands.push_back($dstMem); - translatedOperands.push_back($mbar); - translatedOperands.push_back($srcMem); - translatedOperands.push_back($size); - - // Multicast, if available - llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext(); - auto *i16Unused = 
llvm::ConstantInt::get(llvm::Type::getInt16Ty(ctx), 0); - bool isMulticast = op.getMulticastMask() ? true : false; - translatedOperands.push_back(isMulticast ? $multicastMask : i16Unused); - - // Cachehint, if available - auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0); - bool isCacheHint = op.getL2CacheHint() ? true : false; - translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused); - - // Flag arguments for multicast and cachehint - translatedOperands.push_back(builder.getInt1(isMulticast)); - translatedOperands.push_back(builder.getInt1(isCacheHint)); - - createIntrinsicCall(builder, - llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster, translatedOperands); + auto [id, args] = NVVM::CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs( + *op, moduleTranslation, builder); + createIntrinsicCall(builder, id, args); }]; } diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 9e3d970..47685567 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -45,6 +45,16 @@ public: return PyInsertionPoint(PyOperation::forOperation(ctx, op)); } + void replaceOp(MlirOperation op, MlirOperation newOp) { + mlirRewriterBaseReplaceOpWithOperation(base, op, newOp); + } + + void replaceOp(MlirOperation op, const std::vector<MlirValue> &values) { + mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data()); + } + + void eraseOp(MlirOperation op) { mlirRewriterBaseEraseOp(base, op); } + private: MlirRewriterBase base; PyMlirContextRef ctx; @@ -165,13 +175,115 @@ private: MlirFrozenRewritePatternSet set; }; +class PyRewritePatternSet { +public: + PyRewritePatternSet(MlirContext ctx) + : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {} + ~PyRewritePatternSet() { + if (set.ptr) + mlirRewritePatternSetDestroy(set); + } + + void add(MlirStringRef rootName, unsigned benefit, + const nb::callable &matchAndRewrite) { + MlirRewritePatternCallbacks callbacks; + callbacks.construct = [](void *userData) { + nb::handle(static_cast<PyObject *>(userData)).inc_ref(); + }; + callbacks.destruct = [](void *userData) { + nb::handle(static_cast<PyObject *>(userData)).dec_ref(); + }; + callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op, + MlirPatternRewriter rewriter, + void *userData) -> MlirLogicalResult { + nb::handle f(static_cast<PyObject *>(userData)); + nb::object res = f(op, PyPatternRewriter(rewriter)); + return logicalResultFromObject(res); + }; + MlirRewritePattern pattern = mlirOpRewritePattenCreate( + rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(), + /* nGeneratedNames */ 0, + /* generatedNames */ nullptr); + mlirRewritePatternSetAdd(set, pattern); + } + + PyFrozenRewritePatternSet freeze() { + MlirRewritePatternSet s = set; + set.ptr = nullptr; + return mlirFreezeRewritePattern(s); + } + +private: + MlirRewritePatternSet set; + MlirContext ctx; +}; + } // namespace /// Create the `mlir.rewrite` here. 
void mlir::python::populateRewriteSubmodule(nb::module_ &m) { - nb::class_<PyPatternRewriter>(m, "PatternRewriter") - .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint, - "The current insertion point of the PatternRewriter."); + //---------------------------------------------------------------------------- + // Mapping of the PatternRewriter + //---------------------------------------------------------------------------- + nb:: + class_<PyPatternRewriter>(m, "PatternRewriter") + .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint, + "The current insertion point of the PatternRewriter.") + .def( + "replace_op", + [](PyPatternRewriter &self, MlirOperation op, + MlirOperation newOp) { self.replaceOp(op, newOp); }, + "Replace an operation with a new operation.", nb::arg("op"), + nb::arg("new_op"), + // clang-format off + nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", new_op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None") + // clang-format on + ) + .def( + "replace_op", + [](PyPatternRewriter &self, MlirOperation op, + const std::vector<MlirValue> &values) { + self.replaceOp(op, values); + }, + "Replace an operation with a list of values.", nb::arg("op"), + nb::arg("values"), + // clang-format off + nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None") + // clang-format on + ) + .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.", + nb::arg("op"), + // clang-format off + nb::sig("def erase_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None") + // clang-format on + ); + + //---------------------------------------------------------------------------- + // Mapping of the RewritePatternSet + //---------------------------------------------------------------------------- + nb::class_<PyRewritePatternSet>(m, "RewritePatternSet") + .def( + "__init__", + [](PyRewritePatternSet &self, DefaultingPyMlirContext context) { + new (&self) PyRewritePatternSet(context.get()->get()); + }, + "context"_a = nb::none()) + .def( + "add", + [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn, + unsigned benefit) { + std::string opName = + nb::cast<std::string>(root.attr("OPERATION_NAME")); + self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit, + fn); + }, + "root"_a, "fn"_a, "benefit"_a = 1, + "Add a new rewrite pattern on the given root operation with the " + "callable as the matching and rewriting function and the given " + "benefit.") + .def("freeze", &PyRewritePatternSet::freeze, + "Freeze the pattern set into a frozen one."); + //---------------------------------------------------------------------------- // Mapping of the PDLResultList and PDLModule //---------------------------------------------------------------------------- @@ -237,7 +349,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { .def( "freeze", [](PyPDLPatternModule &self) { - return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern( + return PyFrozenRewritePatternSet(mlirFreezeRewritePattern( mlirRewritePatternSetFromPDLPatternModule(self.get()))); }, nb::keep_alive<0, 1>()) diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index c15a73b..46c329d 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -270,35 +270,16 @@ void mlirIRRewriterDestroy(MlirRewriterBase rewriter) { /// RewritePatternSet and FrozenRewritePatternSet API 
 //===----------------------------------------------------------------------===//
 
-static inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
-  assert(module.ptr && "unexpected null module");
-  return *(static_cast<mlir::RewritePatternSet *>(module.ptr));
-}
-
-static inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
-  return {module};
-}
-
-static inline mlir::FrozenRewritePatternSet *
-unwrap(MlirFrozenRewritePatternSet module) {
-  assert(module.ptr && "unexpected null module");
-  return static_cast<mlir::FrozenRewritePatternSet *>(module.ptr);
-}
-
-static inline MlirFrozenRewritePatternSet
-wrap(mlir::FrozenRewritePatternSet *module) {
-  return {module};
-}
-
-MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) {
-  auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op)));
-  op.ptr = nullptr;
+MlirFrozenRewritePatternSet
+mlirFreezeRewritePattern(MlirRewritePatternSet set) {
+  auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(set)));
+  set.ptr = nullptr;
   return wrap(m);
 }
 
-void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) {
-  delete unwrap(op);
-  op.ptr = nullptr;
+void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set) {
+  delete unwrap(set);
+  set.ptr = nullptr;
 }
 
 MlirLogicalResult
@@ -319,33 +300,86 @@ mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
 /// PatternRewriter API
 //===----------------------------------------------------------------------===//
 
-inline mlir::PatternRewriter *unwrap(MlirPatternRewriter rewriter) {
-  assert(rewriter.ptr && "unexpected null rewriter");
-  return static_cast<mlir::PatternRewriter *>(rewriter.ptr);
+MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
+  return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
 }
 
-inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) {
-  return {rewriter};
-}
+//===----------------------------------------------------------------------===//
+/// RewritePattern API
+//===----------------------------------------------------------------------===//
 
-MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
-  return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
+namespace mlir {
+
+class ExternalRewritePattern : public mlir::RewritePattern {
+public:
+  ExternalRewritePattern(MlirRewritePatternCallbacks callbacks, void *userData,
+                         StringRef rootName, PatternBenefit benefit,
+                         MLIRContext *context,
+                         ArrayRef<StringRef> generatedNames)
+      : RewritePattern(rootName, benefit, context, generatedNames),
+        callbacks(callbacks), userData(userData) {
+    if (callbacks.construct)
+      callbacks.construct(userData);
+  }
+
+  ~ExternalRewritePattern() {
+    if (callbacks.destruct)
+      callbacks.destruct(userData);
+  }
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    return unwrap(callbacks.matchAndRewrite(
+        wrap(static_cast<const mlir::RewritePattern *>(this)), wrap(op),
+        wrap(&rewriter), userData));
+  }
+
+private:
+  MlirRewritePatternCallbacks callbacks;
+  void *userData;
+};
+
+} // namespace mlir
+
+MlirRewritePattern mlirOpRewritePattenCreate(
+    MlirStringRef rootName, unsigned benefit, MlirContext context,
+    MlirRewritePatternCallbacks callbacks, void *userData,
+    size_t nGeneratedNames, MlirStringRef *generatedNames) {
+  std::vector<mlir::StringRef> generatedNamesVec;
+  generatedNamesVec.reserve(nGeneratedNames);
+  for (size_t i = 0; i < nGeneratedNames; ++i) {
+    generatedNamesVec.push_back(unwrap(generatedNames[i]));
+  }
+  return wrap(new mlir::ExternalRewritePattern(
+      callbacks, userData, unwrap(rootName), PatternBenefit(benefit),
+      unwrap(context), generatedNamesVec));
 }
 
 //===----------------------------------------------------------------------===//
-/// PDLPatternModule API
+/// RewritePatternSet API
 //===----------------------------------------------------------------------===//
 
-#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
-static inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) {
-  assert(module.ptr && "unexpected null module");
-  return static_cast<mlir::PDLPatternModule *>(module.ptr);
+MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context) {
+  return wrap(new mlir::RewritePatternSet(unwrap(context)));
+}
+
+void mlirRewritePatternSetDestroy(MlirRewritePatternSet set) {
+  delete unwrap(set);
 }
 
-static inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) {
-  return {module};
+void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
+                              MlirRewritePattern pattern) {
+  std::unique_ptr<mlir::RewritePattern> patternPtr(
+      const_cast<mlir::RewritePattern *>(unwrap(pattern)));
+  pattern.ptr = nullptr;
+  unwrap(set)->add(std::move(patternPtr));
 }
 
+//===----------------------------------------------------------------------===//
+/// PDLPatternModule API
+//===----------------------------------------------------------------------===//
+
+#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
 MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) {
   return wrap(new mlir::PDLPatternModule(
       mlir::OwningOpRef<mlir::ModuleOp>(unwrap(op))));
@@ -363,22 +397,6 @@ mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) {
   return wrap(m);
 }
 
-inline const mlir::PDLValue *unwrap(MlirPDLValue value) {
-  assert(value.ptr && "unexpected null PDL value");
-  return static_cast<const mlir::PDLValue *>(value.ptr);
-}
-
-inline MlirPDLValue wrap(const mlir::PDLValue *value) { return {value}; }
-
-inline mlir::PDLResultList *unwrap(MlirPDLResultList results) {
-  assert(results.ptr && "unexpected null PDL results");
-  return static_cast<mlir::PDLResultList *>(results.ptr);
-}
-
-inline MlirPDLResultList wrap(mlir::PDLResultList *results) {
-  return {results};
-}
-
 MlirValue mlirPDLValueAsValue(MlirPDLValue value) {
   return wrap(unwrap(value)->dyn_cast<mlir::Value>());
 }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
index a50ddbe..624519f 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
@@ -55,16 +55,6 @@ static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
   return returnOp;
 }
 
-/// Return the func::FuncOp called by `callOp`.
-static func::FuncOp getCalledFunction(CallOpInterface callOp) {
-  SymbolRefAttr sym =
-      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
-  if (!sym)
-    return nullptr;
-  return dyn_cast_or_null<func::FuncOp>(
-      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
-}
-
 LogicalResult
 mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
   IRRewriter rewriter(module.getContext());
@@ -72,7 +62,8 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
   DenseMap<func::FuncOp, DenseSet<func::CallOp>> callerMap;
   // Collect the mapping of functions to their call sites.
   module.walk([&](func::CallOp callOp) {
-    if (func::FuncOp calledFunc = getCalledFunction(callOp)) {
+    if (func::FuncOp calledFunc =
+            dyn_cast_or_null<func::FuncOp>(callOp.resolveCallable())) {
       callerMap[calledFunc].insert(callOp);
     }
   });
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 7f419a0..5edcc40b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1593,6 +1593,39 @@ mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
   return {id, std::move(args)};
 }
 
+mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op);
+  llvm::SmallVector<llvm::Value *> args;
+
+  // Fill the intrinsic args: dst, mbar, src, size.
+  args.push_back(mt.lookupValue(thisOp.getDstMem()));
+  args.push_back(mt.lookupValue(thisOp.getMbar()));
+  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
+  args.push_back(mt.lookupValue(thisOp.getSize()));
+
+  // Multicast mask, if available.
+  mlir::Value multicastMask = thisOp.getMulticastMask();
+  const bool hasMulticastMask = static_cast<bool>(multicastMask);
+  llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
+  args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) : i16Unused);
+
+  // Cache hint, if available.
+  mlir::Value cacheHint = thisOp.getL2CacheHint();
+  const bool hasCacheHint = static_cast<bool>(cacheHint);
+  llvm::Value *i64Unused = llvm::ConstantInt::get(builder.getInt64Ty(), 0);
+  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
+
+  // Flag arguments for multicast and cache hint.
+  args.push_back(builder.getInt1(hasMulticastMask));
+  args.push_back(builder.getInt1(hasCacheHint));
+
+  llvm::Intrinsic::ID id =
+      llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
+
+  return {id, std::move(args)};
+}
+
 mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
     Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
   auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
new file mode 100644
index 0000000..acf7db2
--- /dev/null
+++ b/mlir/test/python/rewrite.py
@@ -0,0 +1,69 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+from mlir.ir import *
+from mlir.passmanager import *
+from mlir.dialects.builtin import ModuleOp
+from mlir.dialects import arith
+from mlir.rewrite import *
+
+
+def run(f):
+    print("\nTEST:", f.__name__)
+    f()
+
+
+# CHECK-LABEL: TEST: testRewritePattern
+@run
+def testRewritePattern():
+    def to_muli(op, rewriter):
+        with rewriter.ip:
+            new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location)
+            rewriter.replace_op(op, new_op.owner)
+
+    def constant_1_to_2(op, rewriter):
+        c = op.attributes["value"].value
+        if c != 1:
+            return True  # failed to match
+        with rewriter.ip:
+            new_op = arith.constant(op.result.type, 2, loc=op.location)
+            rewriter.replace_op(op, [new_op])
+
+    with Context():
+        patterns = RewritePatternSet()
+        patterns.add(arith.AddIOp, to_muli)
+        patterns.add(arith.ConstantOp, constant_1_to_2)
+        frozen = patterns.freeze()
+
+        module = ModuleOp.parse(
+            r"""
+            module {
+              func.func @add(%a: i64, %b: i64) -> i64 {
+                %sum = arith.addi %a, %b : i64
+                return %sum : i64
+              }
+            }
+            """
+        )
+
+        apply_patterns_and_fold_greedily(module, frozen)
+        # CHECK: %0 = arith.muli %arg0, %arg1 : i64
+        # CHECK: return %0 : i64
+        print(module)
+
+        module = ModuleOp.parse(
+            r"""
+            module {
+              func.func @const() -> (i64, i64) {
+                %0 = arith.constant 1 : i64
+                %1 = arith.constant 3 : i64
+                return %0, %1 : i64, i64
+              }
+            }
+            """
+        )
+
+        apply_patterns_and_fold_greedily(module, frozen)
+        # CHECK: %c2_i64 = arith.constant 2 : i64
+        # CHECK: %c3_i64 = arith.constant 3 : i64
+        # CHECK: return %c2_i64, %c3_i64 : i64, i64
+        print(module)
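For reference, the new rewrite-pattern C API added in mlir/lib/CAPI/Transforms/Rewrite.cpp above (mlirRewritePatternSetCreate, mlirOpRewritePattenCreate, mlirRewritePatternSetAdd, mlirFreezeRewritePattern) can also be driven directly from C++, which is exactly what the PyRewritePatternSet binding does. What follows is a minimal sketch, not part of the patch: it assumes the declarations are exported through mlir-c/Rewrite.h (the diff only shows the implementation file), and the root op name "arith.addi" and the erase-only callback are illustrative placeholders.

// Minimal sketch of the rewrite-pattern C API, under the assumptions above.
#include "mlir-c/IR.h"
#include "mlir-c/Rewrite.h"   // assumed location of the new declarations
#include "mlir-c/Support.h"

// matchAndRewrite callback: erase every matched op and report success.
// A real pattern would inspect `op` and return failure on a non-match.
static MlirLogicalResult eraseMatchedOp(MlirRewritePattern, MlirOperation op,
                                        MlirPatternRewriter rewriter,
                                        void *userData) {
  (void)userData; // no per-pattern state in this sketch
  mlirRewriterBaseEraseOp(mlirPatternRewriterAsBase(rewriter), op);
  return mlirLogicalResultSuccess();
}

static MlirFrozenRewritePatternSet buildFrozenSet(MlirContext ctx) {
  MlirRewritePatternSet set = mlirRewritePatternSetCreate(ctx);

  MlirRewritePatternCallbacks callbacks;
  callbacks.construct = nullptr; // nothing to retain for this pattern
  callbacks.destruct = nullptr;  // nothing to release
  callbacks.matchAndRewrite = eraseMatchedOp;

  MlirRewritePattern pattern = mlirOpRewritePattenCreate(
      mlirStringRefCreateFromCString("arith.addi"), /*benefit=*/1, ctx,
      callbacks, /*userData=*/nullptr,
      /*nGeneratedNames=*/0, /*generatedNames=*/nullptr);

  // The set takes ownership of the pattern, and freezing consumes the set.
  mlirRewritePatternSetAdd(set, pattern);
  return mlirFreezeRewritePattern(set);
}

The frozen set can then be handed to mlirApplyPatternsAndFoldGreedilyWithOp (see the hunk above) and released with mlirFrozenRewritePatternSetDestroy once rewriting is done, mirroring what the new mlir/test/python/rewrite.py exercises through apply_patterns_and_fold_greedily.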