Diffstat (limited to 'llvm/lib')
126 files changed, 2202 insertions, 2026 deletions
| diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp index e9e2e7d..da32542 100755 --- a/llvm/lib/Analysis/ConstantFolding.cpp +++ b/llvm/lib/Analysis/ConstantFolding.cpp @@ -2163,18 +2163,42 @@ Constant *ConstantFoldBinaryFP(double (*NativeFP)(double, double),  }  Constant *constantFoldVectorReduce(Intrinsic::ID IID, Constant *Op) { -  FixedVectorType *VT = dyn_cast<FixedVectorType>(Op->getType()); -  if (!VT) -    return nullptr; - -  // This isn't strictly necessary, but handle the special/common case of zero: -  // all integer reductions of a zero input produce zero. -  if (isa<ConstantAggregateZero>(Op)) -    return ConstantInt::get(VT->getElementType(), 0); +  auto *OpVT = cast<VectorType>(Op->getType());    // This is the same as the underlying binops - poison propagates. -  if (isa<PoisonValue>(Op) || Op->containsPoisonElement()) -    return PoisonValue::get(VT->getElementType()); +  if (Op->containsPoisonElement()) +    return PoisonValue::get(OpVT->getElementType()); + +  // Shortcut non-accumulating reductions. +  if (Constant *SplatVal = Op->getSplatValue()) { +    switch (IID) { +    case Intrinsic::vector_reduce_and: +    case Intrinsic::vector_reduce_or: +    case Intrinsic::vector_reduce_smin: +    case Intrinsic::vector_reduce_smax: +    case Intrinsic::vector_reduce_umin: +    case Intrinsic::vector_reduce_umax: +      return SplatVal; +    case Intrinsic::vector_reduce_add: +      if (SplatVal->isNullValue()) +        return SplatVal; +      break; +    case Intrinsic::vector_reduce_mul: +      if (SplatVal->isNullValue() || SplatVal->isOneValue()) +        return SplatVal; +      break; +    case Intrinsic::vector_reduce_xor: +      if (SplatVal->isNullValue()) +        return SplatVal; +      if (OpVT->getElementCount().isKnownMultipleOf(2)) +        return Constant::getNullValue(OpVT->getElementType()); +      break; +    } +  } + +  FixedVectorType *VT = dyn_cast<FixedVectorType>(OpVT); +  if (!VT) +    return nullptr;    // TODO: Handle undef.    auto *EltC = dyn_cast_or_null<ConstantInt>(Op->getAggregateElement(0U)); diff --git a/llvm/lib/Analysis/DependenceAnalysis.cpp b/llvm/lib/Analysis/DependenceAnalysis.cpp index 84ee8c0..11d8294 100644 --- a/llvm/lib/Analysis/DependenceAnalysis.cpp +++ b/llvm/lib/Analysis/DependenceAnalysis.cpp @@ -2854,14 +2854,18 @@ bool DependenceInfo::testMIV(const SCEV *Src, const SCEV *Dst,           banerjeeMIVtest(Src, Dst, Loops, Result);  } -// Given a product, e.g., 10*X*Y, returns the first constant operand, -// in this case 10. If there is no constant part, returns std::nullopt. -static std::optional<APInt> getConstantPart(const SCEV *Expr) { +/// Given a SCEVMulExpr, returns its first operand if its first operand is a +/// constant and the product doesn't overflow in a signed sense. Otherwise, +/// returns std::nullopt. For example, given (10 * X * Y)<nsw>, it returns 10. +/// Notably, if it doesn't have nsw, the multiplication may overflow, and if +/// so, it may not a multiple of 10. 
+static std::optional<APInt> getConstanCoefficient(const SCEV *Expr) {    if (const auto *Constant = dyn_cast<SCEVConstant>(Expr))      return Constant->getAPInt();    if (const auto *Product = dyn_cast<SCEVMulExpr>(Expr))      if (const auto *Constant = dyn_cast<SCEVConstant>(Product->getOperand(0))) -      return Constant->getAPInt(); +      if (Product->hasNoSignedWrap()) +        return Constant->getAPInt();    return std::nullopt;  } @@ -2887,7 +2891,7 @@ bool DependenceInfo::accumulateCoefficientsGCD(const SCEV *Expr,    if (AddRec->getLoop() == CurLoop) {      CurLoopCoeff = Step;    } else { -    std::optional<APInt> ConstCoeff = getConstantPart(Step); +    std::optional<APInt> ConstCoeff = getConstanCoefficient(Step);      // If the coefficient is the product of a constant and other stuff, we can      // use the constant in the GCD computation. @@ -2940,7 +2944,7 @@ bool DependenceInfo::gcdMIVtest(const SCEV *Src, const SCEV *Dst,      const SCEV *Coeff = AddRec->getStepRecurrence(*SE);      // If the coefficient is the product of a constant and other stuff,      // we can use the constant in the GCD computation. -    std::optional<APInt> ConstCoeff = getConstantPart(Coeff); +    std::optional<APInt> ConstCoeff = getConstanCoefficient(Coeff);      if (!ConstCoeff)        return false;      RunningGCD = APIntOps::GreatestCommonDivisor(RunningGCD, ConstCoeff->abs()); @@ -2958,7 +2962,7 @@ bool DependenceInfo::gcdMIVtest(const SCEV *Src, const SCEV *Dst,      const SCEV *Coeff = AddRec->getStepRecurrence(*SE);      // If the coefficient is the product of a constant and other stuff,      // we can use the constant in the GCD computation. -    std::optional<APInt> ConstCoeff = getConstantPart(Coeff); +    std::optional<APInt> ConstCoeff = getConstanCoefficient(Coeff);      if (!ConstCoeff)        return false;      RunningGCD = APIntOps::GreatestCommonDivisor(RunningGCD, ConstCoeff->abs()); @@ -2979,7 +2983,7 @@ bool DependenceInfo::gcdMIVtest(const SCEV *Src, const SCEV *Dst,        } else if (const SCEVMulExpr *Product = dyn_cast<SCEVMulExpr>(Operand)) {          // Search for constant operand to participate in GCD;          // If none found; return false. -        std::optional<APInt> ConstOp = getConstantPart(Product); +        std::optional<APInt> ConstOp = getConstanCoefficient(Product);          if (!ConstOp)            return false;          ExtraGCD = APIntOps::GreatestCommonDivisor(ExtraGCD, ConstOp->abs()); @@ -3032,7 +3036,7 @@ bool DependenceInfo::gcdMIVtest(const SCEV *Src, const SCEV *Dst,      Delta = SE->getMinusSCEV(SrcCoeff, DstCoeff);      // If the coefficient is the product of a constant and other stuff,      // we can use the constant in the GCD computation. -    std::optional<APInt> ConstCoeff = getConstantPart(Delta); +    std::optional<APInt> ConstCoeff = getConstanCoefficient(Delta);      if (!ConstCoeff)        // The difference of the two coefficients might not be a product        // or constant, in which case we give up on this direction. diff --git a/llvm/lib/Analysis/HashRecognize.cpp b/llvm/lib/Analysis/HashRecognize.cpp index 4529123..8974ce5 100644 --- a/llvm/lib/Analysis/HashRecognize.cpp +++ b/llvm/lib/Analysis/HashRecognize.cpp @@ -468,8 +468,11 @@ std::variant<PolynomialInfo, StringRef> HashRecognize::recognizeCRC() const {      // Ensure that the PHIs have exactly two uses:      // the bit-shift, and the XOR (or a cast feeding into the XOR). +    // Also ensure that the SimpleRecurrence's evolution doesn't have stray +    // users.      
if (!ConditionalRecurrence.Phi->hasNUses(2) || -        !SimpleRecurrence.Phi->hasNUses(2)) +        !SimpleRecurrence.Phi->hasNUses(2) || +        SimpleRecurrence.BO->getUniqueUndroppableUser() != SimpleRecurrence.Phi)        return "Recurrences have stray uses";      // Check that the SelectInst ConditionalRecurrence.Step is conditional on diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp index 5164cec..8e3ce49 100644 --- a/llvm/lib/AsmParser/LLParser.cpp +++ b/llvm/lib/AsmParser/LLParser.cpp @@ -4538,6 +4538,9 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {        if (!Indices.empty() && !Ty->isSized(&Visited))          return error(ID.Loc, "base element of getelementptr must be sized"); +      if (!ConstantExpr::isSupportedGetElementPtr(Ty)) +        return error(ID.Loc, "invalid base element for constant getelementptr"); +        if (!GetElementPtrInst::getIndexedType(Ty, Indices))          return error(ID.Loc, "invalid getelementptr indices"); @@ -5639,16 +5642,17 @@ bool LLParser::parseDIBasicType(MDNode *&Result, bool IsDistinct) {    OPTIONAL(name, MDStringField, );                                             \    OPTIONAL(size, MDUnsignedOrMDField, (0, UINT64_MAX));                        \    OPTIONAL(align, MDUnsignedField, (0, UINT32_MAX));                           \ +  OPTIONAL(dataSize, MDUnsignedField, (0, UINT32_MAX));                        \    OPTIONAL(encoding, DwarfAttEncodingField, );                                 \    OPTIONAL(num_extra_inhabitants, MDUnsignedField, (0, UINT32_MAX));           \    OPTIONAL(flags, DIFlagField, );    PARSE_MD_FIELDS();  #undef VISIT_MD_FIELDS -  Result = GET_OR_DISTINCT(DIBasicType, (Context, tag.Val, name.Val, -                                         size.getValueAsMetadata(Context), -                                         align.Val, encoding.Val, -                                         num_extra_inhabitants.Val, flags.Val)); +  Result = GET_OR_DISTINCT( +      DIBasicType, +      (Context, tag.Val, name.Val, size.getValueAsMetadata(Context), align.Val, +       encoding.Val, num_extra_inhabitants.Val, dataSize.Val, flags.Val));    return false;  } @@ -6341,8 +6345,8 @@ bool LLParser::parseDIObjCProperty(MDNode *&Result, bool IsDistinct) {  #undef VISIT_MD_FIELDS    Result = GET_OR_DISTINCT(DIObjCProperty, -                           (Context, name.Val, file.Val, line.Val, setter.Val, -                            getter.Val, attributes.Val, type.Val)); +                           (Context, name.Val, file.Val, line.Val, getter.Val, +                            setter.Val, attributes.Val, type.Val));    return false;  } diff --git a/llvm/lib/Bitcode/Reader/MetadataLoader.cpp b/llvm/lib/Bitcode/Reader/MetadataLoader.cpp index ed0443f..c63dc8f 100644 --- a/llvm/lib/Bitcode/Reader/MetadataLoader.cpp +++ b/llvm/lib/Bitcode/Reader/MetadataLoader.cpp @@ -1531,7 +1531,7 @@ Error MetadataLoader::MetadataLoaderImpl::parseOneMetadata(      break;    }    case bitc::METADATA_BASIC_TYPE: { -    if (Record.size() < 6 || Record.size() > 8) +    if (Record.size() < 6 || Record.size() > 9)        return error("Invalid record");      IsDistinct = Record[0] & 1; @@ -1540,13 +1540,13 @@ Error MetadataLoader::MetadataLoaderImpl::parseOneMetadata(                                  ? static_cast<DINode::DIFlags>(Record[6])                                  : DINode::FlagZero;      uint32_t NumExtraInhabitants = (Record.size() > 7) ? 
Record[7] : 0; - +    uint32_t DataSizeInBits = (Record.size() > 8) ? Record[8] : 0;      Metadata *SizeInBits = getMetadataOrConstant(SizeIsMetadata, Record[3]); -      MetadataList.assignValue(          GET_OR_DISTINCT(DIBasicType,                          (Context, Record[1], getMDString(Record[2]), SizeInBits, -                         Record[4], Record[5], NumExtraInhabitants, Flags)), +                         Record[4], Record[5], NumExtraInhabitants, +                         DataSizeInBits, Flags)),          NextMetadataNo);      NextMetadataNo++;      break; @@ -2323,8 +2323,9 @@ Error MetadataLoader::MetadataLoaderImpl::parseOneMetadata(          GET_OR_DISTINCT(DIObjCProperty,                          (Context, getMDString(Record[1]),                           getMDOrNull(Record[2]), Record[3], -                         getMDString(Record[4]), getMDString(Record[5]), -                         Record[6], getDITypeRefOrNull(Record[7]))), +                         /*GetterName=*/getMDString(Record[5]), +                         /*SetterName=*/getMDString(Record[4]), Record[6], +                         getDITypeRefOrNull(Record[7]))),          NextMetadataNo);      NextMetadataNo++;      break; diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp index 61aa7c2f5..f17656c 100644 --- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp +++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp @@ -1925,6 +1925,7 @@ void ModuleBitcodeWriter::writeDIBasicType(const DIBasicType *N,    Record.push_back(N->getEncoding());    Record.push_back(N->getFlags());    Record.push_back(N->getNumExtraInhabitants()); +  Record.push_back(N->getDataSizeInBits());    Stream.EmitRecord(bitc::METADATA_BASIC_TYPE, Record, Abbrev);    Record.clear(); diff --git a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp index 8aa488f..f65d88a 100644 --- a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp +++ b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp @@ -1443,7 +1443,7 @@ getBBAddrMapFeature(const MachineFunction &MF, int NumMBBSectionRanges,            MF.hasBBSections() && NumMBBSectionRanges > 1,            // Use static_cast to avoid breakage of tests on windows.            static_cast<bool>(BBAddrMapSkipEmitBBEntries), HasCalls, -          static_cast<bool>(EmitBBHash)}; +          static_cast<bool>(EmitBBHash), false};  }  void AsmPrinter::emitBBAddrMapSection(const MachineFunction &MF) { diff --git a/llvm/lib/CodeGen/AsmPrinter/DwarfCompileUnit.cpp b/llvm/lib/CodeGen/AsmPrinter/DwarfCompileUnit.cpp index 518121e..751d373 100644 --- a/llvm/lib/CodeGen/AsmPrinter/DwarfCompileUnit.cpp +++ b/llvm/lib/CodeGen/AsmPrinter/DwarfCompileUnit.cpp @@ -1793,9 +1793,13 @@ void DwarfCompileUnit::createBaseTypeDIEs() {                      "_" + Twine(Btr.BitSize)).toStringRef(Str));      addUInt(Die, dwarf::DW_AT_encoding, dwarf::DW_FORM_data1, Btr.Encoding);      // Round up to smallest number of bytes that contains this number of bits. +    // ExprRefedBaseTypes is populated with types referenced by +    // DW_OP_LLVM_convert operations in location expressions. These are often +    // byte-sized, but one common counter-example is 1-bit sized conversions +    // from `i1` types. TODO: Should these use DW_AT_bit_size? See +    // DwarfUnit::constructTypeDIE.      
addUInt(Die, dwarf::DW_AT_byte_size, std::nullopt,              divideCeil(Btr.BitSize, 8)); -      Btr.Die = &Die;    }  } diff --git a/llvm/lib/CodeGen/AsmPrinter/DwarfUnit.cpp b/llvm/lib/CodeGen/AsmPrinter/DwarfUnit.cpp index e40fb76..b16e1315 100644 --- a/llvm/lib/CodeGen/AsmPrinter/DwarfUnit.cpp +++ b/llvm/lib/CodeGen/AsmPrinter/DwarfUnit.cpp @@ -766,8 +766,19 @@ void DwarfUnit::constructTypeDIE(DIE &Buffer, const DIBasicType *BTy) {      addUInt(Buffer, dwarf::DW_AT_encoding, dwarf::DW_FORM_data1,              BTy->getEncoding()); -  uint64_t Size = BTy->getSizeInBits() >> 3; -  addUInt(Buffer, dwarf::DW_AT_byte_size, std::nullopt, Size); +  uint64_t SizeInBytes = divideCeil(BTy->getSizeInBits(), 8); +  addUInt(Buffer, dwarf::DW_AT_byte_size, std::nullopt, SizeInBytes); +  if (BTy->getTag() == dwarf::Tag::DW_TAG_base_type) { +    // DW_TAG_base_type: +    // If the value of an object of the given type does not fully occupy the +    // storage described by a byte size attribute, the base type entry may also +    // have a DW_AT_bit_size [...] attribute. +    // TODO: Do big endian targets need DW_AT_data_bit_offset? See discussion in +    // pull request #164372. +    if (uint64_t DataSizeInBits = BTy->getDataSizeInBits(); +        DataSizeInBits && DataSizeInBits != SizeInBytes * 8) +      addUInt(Buffer, dwarf::DW_AT_bit_size, std::nullopt, DataSizeInBits); +  }    if (BTy->isBigEndian())      addUInt(Buffer, dwarf::DW_AT_endianity, std::nullopt, dwarf::DW_END_big); @@ -1109,7 +1120,7 @@ void DwarfUnit::constructTypeDIE(DIE &Buffer, const DICompositeType *CTy) {            constructMemberDIE(Buffer, DDTy);          }        } else if (auto *Property = dyn_cast<DIObjCProperty>(Element)) { -        DIE &ElemDie = createAndAddDIE(Property->getTag(), Buffer); +        DIE &ElemDie = createAndAddDIE(Property->getTag(), Buffer, Property);          StringRef PropertyName = Property->getName();          addString(ElemDie, dwarf::DW_AT_APPLE_property_name, PropertyName);          if (Property->getType()) diff --git a/llvm/lib/CodeGen/BasicBlockSectionsProfileReader.cpp b/llvm/lib/CodeGen/BasicBlockSectionsProfileReader.cpp index fbcd614..485b44ae 100644 --- a/llvm/lib/CodeGen/BasicBlockSectionsProfileReader.cpp +++ b/llvm/lib/CodeGen/BasicBlockSectionsProfileReader.cpp @@ -287,6 +287,25 @@ Error BasicBlockSectionsProfileReader::ReadV1Profile() {        }        continue;      } +    case 'h': { // Basic block hash secifier. +      // Skip the profile when the profile iterator (FI) refers to the +      // past-the-end element. 
+      if (FI == ProgramPathAndClusterInfo.end()) +        continue; +      for (auto BBIDHashStr : Values) { +        auto [BBIDStr, HashStr] = BBIDHashStr.split(':'); +        unsigned long long BBID = 0, Hash = 0; +        if (getAsUnsignedInteger(BBIDStr, 10, BBID)) +          return createProfileParseError(Twine("unsigned integer expected: '") + +                                         BBIDStr + "'"); +        if (getAsUnsignedInteger(HashStr, 16, Hash)) +          return createProfileParseError( +              Twine("unsigned integer expected in hex format: '") + HashStr + +              "'"); +        FI->second.BBHashes[BBID] = Hash; +      } +      continue; +    }      default:        return createProfileParseError(Twine("invalid specifier: '") +                                       Twine(Specifier) + "'"); diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp index ca82857..5fab6ec 100644 --- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp +++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp @@ -1893,6 +1893,8 @@ static bool canCreateUndefOrPoison(Register Reg, const MachineRegisterInfo &MRI,    case TargetOpcode::G_UADDSAT:    case TargetOpcode::G_SSUBSAT:    case TargetOpcode::G_USUBSAT: +  case TargetOpcode::G_SBFX: +  case TargetOpcode::G_UBFX:      return false;    case TargetOpcode::G_SSHLSAT:    case TargetOpcode::G_USHLSAT: diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index cf221bb..bdd6bf0 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -2715,6 +2715,12 @@ SDValue DAGCombiner::visitPTRADD(SDNode *N) {            (N->getFlags() & N0->getFlags()) & SDNodeFlags::NoUnsignedWrap;        SDValue Add = DAG.getNode(ISD::ADD, DL, IntVT, {Y, Z}, Flags);        AddToWorklist(Add.getNode()); +      // We can't set InBounds even if both original ptradds were InBounds and +      // NUW: SDAG usually represents pointers as integers, therefore, the +      // matched pattern behaves as if it had implicit casts: +      //   (ptradd inbounds (inttoptr (ptrtoint (ptradd inbounds x, y))), z) +      // The outer inbounds ptradd might therefore rely on a provenance that x +      // does not have.        return DAG.getMemBasePlusOffset(X, Add, DL, Flags);      }    } @@ -2740,6 +2746,12 @@ SDValue DAGCombiner::visitPTRADD(SDNode *N) {          // that.          SDNodeFlags Flags =              (N->getFlags() & N0->getFlags()) & SDNodeFlags::NoUnsignedWrap; +        // We can't set InBounds even if both original ptradds were InBounds and +        // NUW: SDAG usually represents pointers as integers, therefore, the +        // matched pattern behaves as if it had implicit casts: +        //   (ptradd inbounds (inttoptr (ptrtoint (ptradd inbounds GA, v))), c) +        // The outer inbounds ptradd might therefore rely on a provenance that +        // GA does not have.          SDValue Inner = DAG.getMemBasePlusOffset(GAValue, N1, DL, Flags);          AddToWorklist(Inner.getNode());          return DAG.getMemBasePlusOffset(Inner, N0.getOperand(1), DL, Flags); @@ -2763,8 +2775,13 @@ SDValue DAGCombiner::visitPTRADD(SDNode *N) {      bool ZIsConstant = DAG.isConstantIntBuildVectorOrConstantInt(Z);      // If both additions in the original were NUW, reassociation preserves that. 
-    SDNodeFlags ReassocFlags = -        (N->getFlags() & N1->getFlags()) & SDNodeFlags::NoUnsignedWrap; +    SDNodeFlags CommonFlags = N->getFlags() & N1->getFlags(); +    SDNodeFlags ReassocFlags = CommonFlags & SDNodeFlags::NoUnsignedWrap; +    if (CommonFlags.hasNoUnsignedWrap()) { +      // If both operations are NUW and the PTRADD is inbounds, the offests are +      // both non-negative, so the reassociated PTRADDs are also inbounds. +      ReassocFlags |= N->getFlags() & SDNodeFlags::InBounds; +    }      if (ZIsConstant != YIsConstant) {        if (YIsConstant) @@ -22743,7 +22760,10 @@ SDValue DAGCombiner::replaceStoreOfInsertLoad(StoreSDNode *ST) {      NewPtr = DAG.getMemBasePlusOffset(Ptr, TypeSize::getFixed(COffset), DL);      PointerInfo = ST->getPointerInfo().getWithOffset(COffset);    } else { -    NewPtr = TLI.getVectorElementPointer(DAG, Ptr, Value.getValueType(), Idx); +    // The original DAG loaded the entire vector from memory, so arithmetic +    // within it must be inbounds. +    NewPtr = TLI.getInboundsVectorElementPointer(DAG, Ptr, Value.getValueType(), +                                                 Idx);    }    return DAG.getStore(Chain, DL, Elt, NewPtr, PointerInfo, ST->getAlign(), @@ -23506,6 +23526,93 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {      // inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... >      if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT))        return DAG.getSplat(VT, DL, InVal); + +    // Extend this type to be byte-addressable +    EVT OldVT = VT; +    EVT EltVT = VT.getVectorElementType(); +    bool IsByteSized = EltVT.isByteSized(); +    if (!IsByteSized) { +      EltVT = +          EltVT.changeTypeToInteger().getRoundIntegerType(*DAG.getContext()); +      VT = VT.changeElementType(EltVT); +    } + +    // Check if this operation will be handled the default way for its type. +    auto IsTypeDefaultHandled = [this](EVT VT) { +      return TLI.getTypeAction(*DAG.getContext(), VT) == +                 TargetLowering::TypeSplitVector || +             TLI.isOperationExpand(ISD::INSERT_VECTOR_ELT, VT); +    }; + +    // Check if this operation is illegal and will be handled the default way, +    // even after extending the type to be byte-addressable. +    if (IsTypeDefaultHandled(OldVT) && IsTypeDefaultHandled(VT)) { +      // For each dynamic insertelt, the default way will save the vector to +      // the stack, store at an offset, and load the modified vector. This can +      // dramatically increase code size if we have a chain of insertelts on a +      // large vector: requiring O(V*C) stores/loads where V = length of +      // vector and C is length of chain. If each insertelt is only fed into the +      // next, the vector is write-only across this chain, and we can just +      // save once before the chain and load after in O(V + C) operations. +      SmallVector<SDNode *> Seq{N}; +      unsigned NumDynamic = 1; +      while (true) { +        SDValue InVec = Seq.back()->getOperand(0); +        if (InVec.getOpcode() != ISD::INSERT_VECTOR_ELT) +          break; +        Seq.push_back(InVec.getNode()); +        NumDynamic += !isa<ConstantSDNode>(InVec.getOperand(2)); +      } + +      // It always and only makes sense to lower this sequence when we have more +      // than one dynamic insertelt, since we will not have more than V constant +      // insertelts, so we will be reducing the total number of stores+loads. 
+      if (NumDynamic > 1) { +        // In cases where the vector is illegal it will be broken down into +        // parts and stored in parts - we should use the alignment for the +        // smallest part. +        Align SmallestAlign = DAG.getReducedAlign(VT, /*UseABI=*/false); +        SDValue StackPtr = +            DAG.CreateStackTemporary(VT.getStoreSize(), SmallestAlign); +        auto &MF = DAG.getMachineFunction(); +        int FrameIndex = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex(); +        auto PtrInfo = MachinePointerInfo::getFixedStack(MF, FrameIndex); + +        // Save the vector to the stack +        SDValue InVec = Seq.back()->getOperand(0); +        if (!IsByteSized) +          InVec = DAG.getNode(ISD::ANY_EXTEND, DL, VT, InVec); +        SDValue Store = DAG.getStore(DAG.getEntryNode(), DL, InVec, StackPtr, +                                     PtrInfo, SmallestAlign); + +        // Lower each dynamic insertelt to a store +        for (SDNode *N : reverse(Seq)) { +          SDValue Elmnt = N->getOperand(1); +          SDValue Index = N->getOperand(2); + +          // Check if we have to extend the element type +          if (!IsByteSized && Elmnt.getValueType().bitsLT(EltVT)) +            Elmnt = DAG.getNode(ISD::ANY_EXTEND, DL, EltVT, Elmnt); + +          // Store the new element. This may be larger than the vector element +          // type, so use a truncating store. +          SDValue EltPtr = +              TLI.getVectorElementPointer(DAG, StackPtr, VT, Index); +          EVT EltVT = Elmnt.getValueType(); +          Store = DAG.getTruncStore( +              Store, DL, Elmnt, EltPtr, MachinePointerInfo::getUnknownStack(MF), +              EltVT, +              commonAlignment(SmallestAlign, EltVT.getFixedSizeInBits() / 8)); +        } + +        // Load the saved vector from the stack +        SDValue Load = +            DAG.getLoad(VT, DL, Store, StackPtr, PtrInfo, SmallestAlign); +        SDValue LoadV = Load.getValue(0); +        return IsByteSized ? 
LoadV : DAG.getAnyExtOrTrunc(LoadV, DL, OldVT); +      } +    } +      return SDValue();    } diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp index da4e409..9bdf822 100644 --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -10668,19 +10668,20 @@ static SDValue clampDynamicVectorIndex(SelectionDAG &DAG, SDValue Idx,                       DAG.getConstant(MaxIndex, dl, IdxVT));  } -SDValue TargetLowering::getVectorElementPointer(SelectionDAG &DAG, -                                                SDValue VecPtr, EVT VecVT, -                                                SDValue Index) const { +SDValue +TargetLowering::getVectorElementPointer(SelectionDAG &DAG, SDValue VecPtr, +                                        EVT VecVT, SDValue Index, +                                        const SDNodeFlags PtrArithFlags) const {    return getVectorSubVecPointer(        DAG, VecPtr, VecVT,        EVT::getVectorVT(*DAG.getContext(), VecVT.getVectorElementType(), 1), -      Index); +      Index, PtrArithFlags);  } -SDValue TargetLowering::getVectorSubVecPointer(SelectionDAG &DAG, -                                               SDValue VecPtr, EVT VecVT, -                                               EVT SubVecVT, -                                               SDValue Index) const { +SDValue +TargetLowering::getVectorSubVecPointer(SelectionDAG &DAG, SDValue VecPtr, +                                       EVT VecVT, EVT SubVecVT, SDValue Index, +                                       const SDNodeFlags PtrArithFlags) const {    SDLoc dl(Index);    // Make sure the index type is big enough to compute in.    Index = DAG.getZExtOrTrunc(Index, dl, VecPtr.getValueType()); @@ -10704,7 +10705,7 @@ SDValue TargetLowering::getVectorSubVecPointer(SelectionDAG &DAG,    Index = DAG.getNode(ISD::MUL, dl, IdxVT, Index,                        DAG.getConstant(EltSize, dl, IdxVT)); -  return DAG.getMemBasePlusOffset(VecPtr, Index, dl); +  return DAG.getMemBasePlusOffset(VecPtr, Index, dl, PtrArithFlags);  }  //===----------------------------------------------------------------------===// @@ -12382,8 +12383,10 @@ SDValue TargetLowering::scalarizeExtractedVectorLoad(EVT ResultVT,        !IsFast)      return SDValue(); -  SDValue NewPtr = -      getVectorElementPointer(DAG, OriginalLoad->getBasePtr(), InVecVT, EltNo); +  // The original DAG loaded the entire vector from memory, so arithmetic +  // within it must be inbounds. +  SDValue NewPtr = getInboundsVectorElementPointer( +      DAG, OriginalLoad->getBasePtr(), InVecVT, EltNo);    // We are replacing a vector load with a scalar load. The new load must have    // identical memory op ordering to the original. 
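A minimal C++ analogue of the chained dynamic-insertelt lowering added to visitINSERT_VECTOR_ELT above; the function and parameter names are illustrative only and not part of the patch. The idea, as the patch comment describes, is to spill the vector once, turn each dynamic-index insert into a single scalar store, and reload once, so a chain costs O(V + C) memory operations instead of O(V * C).

#include <array>
#include <cstddef>

// Illustrative sketch: the by-value parameter stands in for the single spill,
// each loop iteration for one scalar store, and the return for the single
// reload. The index is clamped to stay inside the stack slot, mirroring the
// clamping the DAG lowering performs.
std::array<float, 16> insertChain(std::array<float, 16> Vec,
                                  const std::size_t *Idx, const float *Val,
                                  std::size_t Chain) {
  for (std::size_t I = 0; I < Chain; ++I)
    Vec[Idx[I] % Vec.size()] = Val[I];
  return Vec;
}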
diff --git a/llvm/lib/DebugInfo/DWARF/DWARFDie.cpp b/llvm/lib/DebugInfo/DWARF/DWARFDie.cpp index db5cc37..6c78ef0 100644 --- a/llvm/lib/DebugInfo/DWARF/DWARFDie.cpp +++ b/llvm/lib/DebugInfo/DWARF/DWARFDie.cpp @@ -129,6 +129,25 @@ prettyLanguageVersionString(const DWARFAttribute &AttrValue,        static_cast<SourceLanguageName>(*LName), *LVersion);  } +static llvm::Expected<llvm::StringRef> +getApplePropertyName(const DWARFDie &PropDIE) { +  if (!PropDIE) +    return llvm::createStringError("invalid DIE"); + +  if (PropDIE.getTag() != DW_TAG_APPLE_property) +    return llvm::createStringError("not referencing a DW_TAG_APPLE_property"); + +  auto PropNameForm = PropDIE.find(DW_AT_APPLE_property_name); +  if (!PropNameForm) +    return ""; + +  auto NameOrErr = PropNameForm->getAsCString(); +  if (!NameOrErr) +    return NameOrErr.takeError(); + +  return *NameOrErr; +} +  static void dumpAttribute(raw_ostream &OS, const DWARFDie &Die,                            const DWARFAttribute &AttrValue, unsigned Indent,                            DIDumpOptions DumpOpts) { @@ -233,6 +252,15 @@ static void dumpAttribute(raw_ostream &OS, const DWARFDie &Die,              Die.getAttributeValueAsReferencedDie(FormValue).getName(                  DINameKind::LinkageName))        OS << Space << "\"" << Name << '\"'; +  } else if (Attr == DW_AT_APPLE_property) { +    auto PropDIE = Die.getAttributeValueAsReferencedDie(FormValue); +    if (auto PropNameOrErr = getApplePropertyName(PropDIE)) +      OS << Space << "\"" << *PropNameOrErr << '\"'; +    else +      DumpOpts.RecoverableErrorHandler(createStringError( +          errc::invalid_argument, +          llvm::formatv("decoding DW_AT_APPLE_property_name: {}", +                        toString(PropNameOrErr.takeError()))));    } else if (Attr == DW_AT_type || Attr == DW_AT_containing_type) {      DWARFDie D = resolveReferencedType(Die, FormValue);      if (D && !D.isNULL()) { diff --git a/llvm/lib/Frontend/Driver/CodeGenOptions.cpp b/llvm/lib/Frontend/Driver/CodeGenOptions.cpp index df88490..b546e81 100644 --- a/llvm/lib/Frontend/Driver/CodeGenOptions.cpp +++ b/llvm/lib/Frontend/Driver/CodeGenOptions.cpp @@ -12,7 +12,6 @@  #include "llvm/TargetParser/Triple.h"  namespace llvm { -extern llvm::cl::opt<bool> DebugInfoCorrelate;  extern llvm::cl::opt<llvm::InstrProfCorrelator::ProfCorrelatorKind>      ProfileCorrelate;  } // namespace llvm @@ -64,8 +63,7 @@ TargetLibraryInfoImpl *createTLII(const llvm::Triple &TargetTriple,  }  std::string getDefaultProfileGenName() { -  return llvm::DebugInfoCorrelate || -                 llvm::ProfileCorrelate != InstrProfCorrelator::NONE +  return llvm::ProfileCorrelate != InstrProfCorrelator::NONE               ? "default_%m.proflite"               : "default_%m.profraw";  } diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 286ed03..0e5926f 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -5473,7 +5473,8 @@ OpenMPIRBuilder::collapseLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,      }      // TODO: Enable UndefinedSanitizer to diagnose an overflow here. -    CollapsedTripCount = Builder.CreateNUWMul(CollapsedTripCount, OrigTripCount); +    CollapsedTripCount = +        Builder.CreateNUWMul(CollapsedTripCount, OrigTripCount);    }    // Create the collapsed loop control flow. 
@@ -9338,9 +9339,8 @@ OpenMPIRBuilder::createAtomicRead(const LocationDescription &Loc,      // target does not support `atomicrmw` of the size of the struct      LoadInst *OldVal = Builder.CreateLoad(XElemTy, X.Var, "omp.atomic.read");      OldVal->setAtomic(AO); -    const DataLayout &LoadDL = OldVal->getModule()->getDataLayout(); -    unsigned LoadSize = -        LoadDL.getTypeStoreSize(OldVal->getPointerOperand()->getType()); +    const DataLayout &DL = OldVal->getModule()->getDataLayout(); +    unsigned LoadSize = DL.getTypeStoreSize(XElemTy);      OpenMPIRBuilder::AtomicInfo atomicInfo(          &Builder, XElemTy, LoadSize * 8, LoadSize * 8, OldVal->getAlign(),          OldVal->getAlign(), true /* UseLibcall */, AllocaIP, X.Var); @@ -9384,9 +9384,8 @@ OpenMPIRBuilder::createAtomicWrite(const LocationDescription &Loc,      XSt->setAtomic(AO);    } else if (XElemTy->isStructTy()) {      LoadInst *OldVal = Builder.CreateLoad(XElemTy, X.Var, "omp.atomic.read"); -    const DataLayout &LoadDL = OldVal->getModule()->getDataLayout(); -    unsigned LoadSize = -        LoadDL.getTypeStoreSize(OldVal->getPointerOperand()->getType()); +    const DataLayout &DL = OldVal->getModule()->getDataLayout(); +    unsigned LoadSize = DL.getTypeStoreSize(XElemTy);      OpenMPIRBuilder::AtomicInfo atomicInfo(          &Builder, XElemTy, LoadSize * 8, LoadSize * 8, OldVal->getAlign(),          OldVal->getAlign(), true /* UseLibcall */, AllocaIP, X.Var); @@ -9581,7 +9580,7 @@ Expected<std::pair<Value *, Value *>> OpenMPIRBuilder::emitAtomicUpdate(      OldVal->setAtomic(AO);      // CurBB      // |     /---\ -		// ContBB    | +    // ContBB    |      // |     \---/      // ExitBB      BasicBlock *CurBB = Builder.GetInsertBlock(); diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp index 3c222f5..95d954f 100644 --- a/llvm/lib/IR/AsmWriter.cpp +++ b/llvm/lib/IR/AsmWriter.cpp @@ -2199,6 +2199,7 @@ static void writeDIBasicType(raw_ostream &Out, const DIBasicType *N,    Printer.printString("name", N->getName());    Printer.printMetadataOrInt("size", N->getRawSizeInBits(), true);    Printer.printInt("align", N->getAlignInBits()); +  Printer.printInt("dataSize", N->getDataSizeInBits());    Printer.printDwarfEnum("encoding", N->getEncoding(),                           dwarf::AttributeEncodingString);    Printer.printInt("num_extra_inhabitants", N->getNumExtraInhabitants()); diff --git a/llvm/lib/IR/DIBuilder.cpp b/llvm/lib/IR/DIBuilder.cpp index 07a870f..ca11ecf 100644 --- a/llvm/lib/IR/DIBuilder.cpp +++ b/llvm/lib/IR/DIBuilder.cpp @@ -261,10 +261,12 @@ DIBasicType *DIBuilder::createNullPtrType() {  DIBasicType *DIBuilder::createBasicType(StringRef Name, uint64_t SizeInBits,                                          unsigned Encoding,                                          DINode::DIFlags Flags, -                                        uint32_t NumExtraInhabitants) { +                                        uint32_t NumExtraInhabitants, +                                        uint32_t DataSizeInBits) {    assert(!Name.empty() && "Unable to create type without name");    return DIBasicType::get(VMContext, dwarf::DW_TAG_base_type, Name, SizeInBits, -                          0, Encoding, NumExtraInhabitants, Flags); +                          0, Encoding, NumExtraInhabitants, DataSizeInBits, +                          Flags);  }  DIFixedPointType * diff --git a/llvm/lib/IR/DebugInfoMetadata.cpp b/llvm/lib/IR/DebugInfoMetadata.cpp index e30df88..fafc325 100644 --- a/llvm/lib/IR/DebugInfoMetadata.cpp +++ 
b/llvm/lib/IR/DebugInfoMetadata.cpp @@ -872,15 +872,18 @@ DIEnumerator *DIEnumerator::getImpl(LLVMContext &Context, const APInt &Value,  DIBasicType *DIBasicType::getImpl(LLVMContext &Context, unsigned Tag,                                    MDString *Name, Metadata *SizeInBits,                                    uint32_t AlignInBits, unsigned Encoding, -                                  uint32_t NumExtraInhabitants, DIFlags Flags, +                                  uint32_t NumExtraInhabitants, +                                  uint32_t DataSizeInBits, DIFlags Flags,                                    StorageType Storage, bool ShouldCreate) {    assert(isCanonical(Name) && "Expected canonical MDString"); -  DEFINE_GETIMPL_LOOKUP(DIBasicType, (Tag, Name, SizeInBits, AlignInBits, -                                      Encoding, NumExtraInhabitants, Flags)); +  DEFINE_GETIMPL_LOOKUP(DIBasicType, +                        (Tag, Name, SizeInBits, AlignInBits, Encoding, +                         NumExtraInhabitants, DataSizeInBits, Flags));    Metadata *Ops[] = {nullptr, nullptr, Name, SizeInBits, nullptr}; -  DEFINE_GETIMPL_STORE(DIBasicType, -                       (Tag, AlignInBits, Encoding, NumExtraInhabitants, Flags), -                       Ops); +  DEFINE_GETIMPL_STORE( +      DIBasicType, +      (Tag, AlignInBits, Encoding, NumExtraInhabitants, DataSizeInBits, Flags), +      Ops);  }  std::optional<DIBasicType::Signedness> DIBasicType::getSignedness() const { diff --git a/llvm/lib/IR/LLVMContextImpl.h b/llvm/lib/IR/LLVMContextImpl.h index e03f993..2c9921d 100644 --- a/llvm/lib/IR/LLVMContextImpl.h +++ b/llvm/lib/IR/LLVMContextImpl.h @@ -480,20 +480,22 @@ template <> struct MDNodeKeyImpl<DIBasicType> {    uint32_t AlignInBits;    unsigned Encoding;    uint32_t NumExtraInhabitants; +  uint32_t DataSizeInBits;    unsigned Flags;    MDNodeKeyImpl(unsigned Tag, MDString *Name, Metadata *SizeInBits,                  uint32_t AlignInBits, unsigned Encoding, -                uint32_t NumExtraInhabitants, unsigned Flags) +                uint32_t NumExtraInhabitants, uint32_t DataSizeInBits, +                unsigned Flags)        : Tag(Tag), Name(Name), SizeInBits(SizeInBits), AlignInBits(AlignInBits),          Encoding(Encoding), NumExtraInhabitants(NumExtraInhabitants), -        Flags(Flags) {} +        DataSizeInBits(DataSizeInBits), Flags(Flags) {}    MDNodeKeyImpl(const DIBasicType *N)        : Tag(N->getTag()), Name(N->getRawName()),          SizeInBits(N->getRawSizeInBits()), AlignInBits(N->getAlignInBits()),          Encoding(N->getEncoding()), -        NumExtraInhabitants(N->getNumExtraInhabitants()), Flags(N->getFlags()) { -  } +        NumExtraInhabitants(N->getNumExtraInhabitants()), +        DataSizeInBits(N->getDataSizeInBits()), Flags(N->getFlags()) {}    bool isKeyOf(const DIBasicType *RHS) const {      return Tag == RHS->getTag() && Name == RHS->getRawName() && @@ -501,6 +503,7 @@ template <> struct MDNodeKeyImpl<DIBasicType> {             AlignInBits == RHS->getAlignInBits() &&             Encoding == RHS->getEncoding() &&             NumExtraInhabitants == RHS->getNumExtraInhabitants() && +           DataSizeInBits == RHS->getDataSizeInBits() &&             Flags == RHS->getFlags();    } diff --git a/llvm/lib/MC/MCParser/AsmLexer.cpp b/llvm/lib/MC/MCParser/AsmLexer.cpp index a6188f0..1af4a29 100644 --- a/llvm/lib/MC/MCParser/AsmLexer.cpp +++ b/llvm/lib/MC/MCParser/AsmLexer.cpp @@ -16,7 +16,6 @@  #include "llvm/ADT/StringExtras.h"  #include "llvm/ADT/StringRef.h"  #include 
"llvm/MC/MCAsmInfo.h" -#include "llvm/MC/MCParser/AsmLexer.h"  #include "llvm/Support/Compiler.h"  #include "llvm/Support/SMLoc.h"  #include "llvm/Support/SaveAndRestore.h" diff --git a/llvm/lib/Object/ELF.cpp b/llvm/lib/Object/ELF.cpp index 6da97f9..354c51d 100644 --- a/llvm/lib/Object/ELF.cpp +++ b/llvm/lib/Object/ELF.cpp @@ -831,17 +831,17 @@ decodeBBAddrMapImpl(const ELFFile<ELFT> &EF,    };    uint8_t Version = 0; -  uint8_t Feature = 0; +  uint16_t Feature = 0;    BBAddrMap::Features FeatEnable{};    while (!ULEBSizeErr && !MetadataDecodeErr && Cur &&           Cur.tell() < Content.size()) {      Version = Data.getU8(Cur);      if (!Cur)        break; -    if (Version < 2 || Version > 4) +    if (Version < 2 || Version > 5)        return createError("unsupported SHT_LLVM_BB_ADDR_MAP version: " +                           Twine(static_cast<int>(Version))); -    Feature = Data.getU8(Cur); // Feature byte +    Feature = Version < 5 ? Data.getU8(Cur) : Data.getU16(Cur);      if (!Cur)        break;      auto FeatEnableOrErr = BBAddrMap::Features::decode(Feature); @@ -858,6 +858,11 @@ decodeBBAddrMapImpl(const ELFFile<ELFT> &EF,                           "basic block hash feature is enabled: version = " +                           Twine(static_cast<int>(Version)) +                           " feature = " + Twine(static_cast<int>(Feature))); +    if (FeatEnable.PostLinkCfg && Version < 5) +      return createError("version should be >= 5 for SHT_LLVM_BB_ADDR_MAP when " +                         "post link cfg feature is enabled: version = " + +                         Twine(static_cast<int>(Version)) + +                         " feature = " + Twine(static_cast<int>(Feature)));      uint32_t NumBlocksInBBRange = 0;      uint32_t NumBBRanges = 1;      typename ELFFile<ELFT>::uintX_t RangeBaseAddress = 0; @@ -946,6 +951,10 @@ decodeBBAddrMapImpl(const ELFFile<ELFT> &EF,          uint64_t BBF = FeatEnable.BBFreq                             ? readULEB128As<uint64_t>(Data, Cur, ULEBSizeErr)                             : 0; +        uint32_t PostLinkBBFreq = +            FeatEnable.PostLinkCfg +                ? readULEB128As<uint32_t>(Data, Cur, ULEBSizeErr) +                : 0;          // Branch probability          llvm::SmallVector<PGOAnalysisMap::PGOBBEntry::SuccessorEntry, 2> @@ -955,13 +964,20 @@ decodeBBAddrMapImpl(const ELFFile<ELFT> &EF,            for (uint64_t I = 0; I < SuccCount; ++I) {              uint32_t BBID = readULEB128As<uint32_t>(Data, Cur, ULEBSizeErr);              uint32_t BrProb = readULEB128As<uint32_t>(Data, Cur, ULEBSizeErr); +            uint32_t PostLinkFreq = +                FeatEnable.PostLinkCfg +                    ? 
readULEB128As<uint32_t>(Data, Cur, ULEBSizeErr) +                    : 0; +              if (PGOAnalyses) -              Successors.push_back({BBID, BranchProbability::getRaw(BrProb)}); +              Successors.push_back( +                  {BBID, BranchProbability::getRaw(BrProb), PostLinkFreq});            }          }          if (PGOAnalyses) -          PGOBBEntries.push_back({BlockFrequency(BBF), std::move(Successors)}); +          PGOBBEntries.push_back( +              {BlockFrequency(BBF), PostLinkBBFreq, std::move(Successors)});        }        if (PGOAnalyses) diff --git a/llvm/lib/ObjectYAML/ELFEmitter.cpp b/llvm/lib/ObjectYAML/ELFEmitter.cpp index 8b75fbe..8530785 100644 --- a/llvm/lib/ObjectYAML/ELFEmitter.cpp +++ b/llvm/lib/ObjectYAML/ELFEmitter.cpp @@ -1465,13 +1465,19 @@ void ELFState<ELFT>::writeSectionContent(    for (const auto &[Idx, E] : llvm::enumerate(*Section.Entries)) {      // Write version and feature values.      if (Section.Type == llvm::ELF::SHT_LLVM_BB_ADDR_MAP) { -      if (E.Version > 4) +      if (E.Version > 5)          WithColor::warning() << "unsupported SHT_LLVM_BB_ADDR_MAP version: "                               << static_cast<int>(E.Version)                               << "; encoding using the most recent version";        CBA.write(E.Version); -      CBA.write(E.Feature); -      SHeader.sh_size += 2; +      SHeader.sh_size += 1; +      if (E.Version < 5) { +        CBA.write(static_cast<uint8_t>(E.Feature)); +        SHeader.sh_size += 1; +      } else { +        CBA.write<uint16_t>(E.Feature, ELFT::Endianness); +        SHeader.sh_size += 2; +      }      }      auto FeatureOrErr = llvm::object::BBAddrMap::Features::decode(E.Feature);      bool MultiBBRangeFeatureEnabled = false; @@ -1556,11 +1562,15 @@ void ELFState<ELFT>::writeSectionContent(      for (const auto &PGOBBE : PGOBBEntries) {        if (PGOBBE.BBFreq)          SHeader.sh_size += CBA.writeULEB128(*PGOBBE.BBFreq); +      if (FeatureOrErr->PostLinkCfg || PGOBBE.PostLinkBBFreq.has_value()) +        SHeader.sh_size += CBA.writeULEB128(PGOBBE.PostLinkBBFreq.value_or(0));        if (PGOBBE.Successors) {          SHeader.sh_size += CBA.writeULEB128(PGOBBE.Successors->size()); -        for (const auto &[ID, BrProb] : *PGOBBE.Successors) { +        for (const auto &[ID, BrProb, PostLinkBrFreq] : *PGOBBE.Successors) {            SHeader.sh_size += CBA.writeULEB128(ID);            SHeader.sh_size += CBA.writeULEB128(BrProb); +          if (FeatureOrErr->PostLinkCfg || PostLinkBrFreq.has_value()) +            SHeader.sh_size += CBA.writeULEB128(PostLinkBrFreq.value_or(0));          }        }      } diff --git a/llvm/lib/ObjectYAML/ELFYAML.cpp b/llvm/lib/ObjectYAML/ELFYAML.cpp index f8a84b0..e5e5fc2 100644 --- a/llvm/lib/ObjectYAML/ELFYAML.cpp +++ b/llvm/lib/ObjectYAML/ELFYAML.cpp @@ -1886,7 +1886,7 @@ void MappingTraits<ELFYAML::BBAddrMapEntry>::mapping(      IO &IO, ELFYAML::BBAddrMapEntry &E) {    assert(IO.getContext() && "The IO context is not initialized");    IO.mapRequired("Version", E.Version); -  IO.mapOptional("Feature", E.Feature, Hex8(0)); +  IO.mapOptional("Feature", E.Feature, Hex16(0));    IO.mapOptional("NumBBRanges", E.NumBBRanges);    IO.mapOptional("BBRanges", E.BBRanges);  } @@ -1920,6 +1920,7 @@ void MappingTraits<ELFYAML::PGOAnalysisMapEntry::PGOBBEntry>::mapping(      IO &IO, ELFYAML::PGOAnalysisMapEntry::PGOBBEntry &E) {    assert(IO.getContext() && "The IO context is not initialized");    IO.mapOptional("BBFreq", E.BBFreq); +  IO.mapOptional("PostLinkBBFreq", 
E.PostLinkBBFreq);    IO.mapOptional("Successors", E.Successors);  } @@ -1929,6 +1930,7 @@ void MappingTraits<ELFYAML::PGOAnalysisMapEntry::PGOBBEntry::SuccessorEntry>::    assert(IO.getContext() && "The IO context is not initialized");    IO.mapRequired("ID", E.ID);    IO.mapRequired("BrProb", E.BrProb); +  IO.mapOptional("PostLinkBrFreq", E.PostLinkBrFreq);  }  void MappingTraits<ELFYAML::GnuHashHeader>::mapping(IO &IO, diff --git a/llvm/lib/Support/AutoConvert.cpp b/llvm/lib/Support/AutoConvert.cpp index 0b6928e..741bb7b 100644 --- a/llvm/lib/Support/AutoConvert.cpp +++ b/llvm/lib/Support/AutoConvert.cpp @@ -96,7 +96,7 @@ std::error_code llvm::setzOSFileTag(int FD, int CCSID, bool Text) {    return std::error_code();  } -ErrorOr<__ccsid_t> llvm::getzOSFileTag(const char *FileName, const int FD) { +ErrorOr<__ccsid_t> llvm::getzOSFileTag(const Twine &FileName, const int FD) {    // If we have a file descriptor, use it to find out file tagging. Otherwise we    // need to use stat() with the file path.    if (FD != -1) { @@ -110,12 +110,12 @@ ErrorOr<__ccsid_t> llvm::getzOSFileTag(const char *FileName, const int FD) {      return Query.fccsid;    }    struct stat Attr; -  if (stat(FileName, &Attr) == -1) +  if (stat(FileName.str().c_str(), &Attr) == -1)      return std::error_code(errno, std::generic_category());    return Attr.st_tag.ft_ccsid;  } -ErrorOr<bool> llvm::needzOSConversion(const char *FileName, const int FD) { +ErrorOr<bool> llvm::needzOSConversion(const Twine &FileName, const int FD) {    ErrorOr<__ccsid_t> Ccsid = getzOSFileTag(FileName, FD);    if (std::error_code EC = Ccsid.getError())      return EC; diff --git a/llvm/lib/Support/BranchProbability.cpp b/llvm/lib/Support/BranchProbability.cpp index e376344..ea42f34 100644 --- a/llvm/lib/Support/BranchProbability.cpp +++ b/llvm/lib/Support/BranchProbability.cpp @@ -111,3 +111,10 @@ uint64_t BranchProbability::scale(uint64_t Num) const {  uint64_t BranchProbability::scaleByInverse(uint64_t Num) const {    return ::scale<0>(Num, D, N);  } + +BranchProbability BranchProbability::pow(unsigned N) const { +  BranchProbability Res = BranchProbability::getOne(); +  for (unsigned I = 0; I < N; ++I) +    Res *= *this; +  return Res; +} diff --git a/llvm/lib/Support/MemoryBuffer.cpp b/llvm/lib/Support/MemoryBuffer.cpp index 1c4645a..23b9f8c 100644 --- a/llvm/lib/Support/MemoryBuffer.cpp +++ b/llvm/lib/Support/MemoryBuffer.cpp @@ -512,7 +512,7 @@ getOpenFileImpl(sys::fs::file_t FD, const Twine &Filename, uint64_t FileSize,    }  #ifdef __MVS__ -  ErrorOr<bool> NeedsConversion = needConversion(Filename.str().c_str(), FD); +  ErrorOr<bool> NeedsConversion = needConversion(Filename, FD);    if (std::error_code EC = NeedsConversion.getError())      return EC;    // File size may increase due to EBCDIC -> UTF-8 conversion, therefore we diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp index 457e540..ccc8eb8 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp @@ -122,7 +122,7 @@ unsigned AArch64InstrInfo::getInstSizeInBytes(const MachineInstr &MI) const {      NumBytes = Desc.getSize() ? 
Desc.getSize() : 4;      const auto *MFI = MF->getInfo<AArch64FunctionInfo>(); -    if (!MFI->shouldSignReturnAddress(MF)) +    if (!MFI->shouldSignReturnAddress(*MF))        return NumBytes;      const auto &STI = MF->getSubtarget<AArch64Subtarget>(); diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td index b9e299e..2871a20 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td @@ -1805,14 +1805,22 @@ def : SHA3_pattern<EOR3, int_aarch64_crypto_eor3u, v8i16>;  def : SHA3_pattern<EOR3, int_aarch64_crypto_eor3u, v4i32>;  def : SHA3_pattern<EOR3, int_aarch64_crypto_eor3u, v2i64>; -class EOR3_pattern<ValueType VecTy> -  : Pat<(xor (xor (VecTy V128:$Vn), (VecTy V128:$Vm)), (VecTy V128:$Va)), -        (EOR3 (VecTy V128:$Vn), (VecTy V128:$Vm), (VecTy V128:$Va))>; - -def : EOR3_pattern<v16i8>; -def : EOR3_pattern<v8i16>; -def : EOR3_pattern<v4i32>; -def : EOR3_pattern<v2i64>; +multiclass EOR3_pattern<ValueType Vec128Ty, ValueType Vec64Ty>{ +  def : Pat<(xor (xor (Vec128Ty V128:$Vn), (Vec128Ty V128:$Vm)), (Vec128Ty V128:$Va)), +        (EOR3 (Vec128Ty V128:$Vn), (Vec128Ty V128:$Vm), (Vec128Ty V128:$Va))>; +  def : Pat<(xor (xor (Vec64Ty V64:$Vn), (Vec64Ty V64:$Vm)), (Vec64Ty V64:$Va)), +            (EXTRACT_SUBREG +              (EOR3 +                (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vn, dsub), +                (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vm, dsub), +                (INSERT_SUBREG (IMPLICIT_DEF), V64:$Va, dsub)), +              dsub)>; +} + +defm : EOR3_pattern<v16i8, v8i8>; +defm : EOR3_pattern<v8i16, v4i16>; +defm : EOR3_pattern<v4i32, v2i32>; +defm : EOR3_pattern<v2i64, v1i64>;  class BCAX_pattern<ValueType VecTy>    : Pat<(xor (VecTy V128:$Vn), (and (VecTy V128:$Vm), (vnot (VecTy V128:$Va)))), diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index fede586..47c1ac4 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -1032,6 +1032,13 @@ AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,      }      break;    } +  case Intrinsic::experimental_vector_extract_last_active: +    if (ST->isSVEorStreamingSVEAvailable()) { +      auto [LegalCost, _] = getTypeLegalizationCost(ICA.getArgTypes()[0]); +      // This should turn into chained clastb instructions. +      return LegalCost; +    } +    break;    default:      break;    } diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp index 9ce1224..aed325c 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp @@ -221,12 +221,22 @@ bool AMDGPUInstructionSelector::selectCOPY(MachineInstr &I) const {  bool AMDGPUInstructionSelector::selectCOPY_SCC_VCC(MachineInstr &I) const {    const DebugLoc &DL = I.getDebugLoc();    MachineBasicBlock *BB = I.getParent(); +  Register VCCReg = I.getOperand(1).getReg(); +  MachineInstr *Cmp; + +  if (STI.getGeneration() >= AMDGPUSubtarget::VOLCANIC_ISLANDS) { +    unsigned CmpOpc = +        STI.isWave64() ? AMDGPU::S_CMP_LG_U64 : AMDGPU::S_CMP_LG_U32; +    Cmp = BuildMI(*BB, &I, DL, TII.get(CmpOpc)).addReg(VCCReg).addImm(0); +  } else { +    // For gfx7 and earlier, S_CMP_LG_U64 doesn't exist, so we use S_OR_B64 +    // which sets SCC as a side effect. 
+    Register DeadDst = MRI->createVirtualRegister(&AMDGPU::SReg_64RegClass); +    Cmp = BuildMI(*BB, &I, DL, TII.get(AMDGPU::S_OR_B64), DeadDst) +              .addReg(VCCReg) +              .addReg(VCCReg); +  } -  unsigned CmpOpc = -      STI.isWave64() ? AMDGPU::S_CMP_LG_U64 : AMDGPU::S_CMP_LG_U32; -  MachineInstr *Cmp = BuildMI(*BB, &I, DL, TII.get(CmpOpc)) -                          .addReg(I.getOperand(1).getReg()) -                          .addImm(0);    if (!constrainSelectedInstRegOperands(*Cmp, TII, TRI, RBI))      return false; diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalize.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalize.cpp index e187959..907f830 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalize.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalize.cpp @@ -24,6 +24,7 @@  #include "llvm/CodeGen/GlobalISel/CSEInfo.h"  #include "llvm/CodeGen/GlobalISel/CSEMIRBuilder.h"  #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h" +#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"  #include "llvm/CodeGen/GlobalISel/Utils.h"  #include "llvm/CodeGen/MachineFunctionPass.h"  #include "llvm/CodeGen/MachineUniformityAnalysis.h" @@ -34,9 +35,17 @@  using namespace llvm;  using namespace AMDGPU; +using namespace llvm::MIPatternMatch;  namespace { +// AMDGPU-specific pattern matchers +template <typename SrcTy> +inline UnaryOp_match<SrcTy, AMDGPU::G_AMDGPU_READANYLANE> +m_GAMDGPUReadAnyLane(const SrcTy &Src) { +  return UnaryOp_match<SrcTy, AMDGPU::G_AMDGPU_READANYLANE>(Src); +} +  class AMDGPURegBankLegalize : public MachineFunctionPass {  public:    static char ID; @@ -160,10 +169,18 @@ AMDGPURegBankLegalizeCombiner::tryMatchRALFromUnmerge(Register Src) {  Register AMDGPURegBankLegalizeCombiner::getReadAnyLaneSrc(Register Src) {    // Src = G_AMDGPU_READANYLANE RALSrc -  auto [RAL, RALSrc] = tryMatch(Src, AMDGPU::G_AMDGPU_READANYLANE); -  if (RAL) +  Register RALSrc; +  if (mi_match(Src, MRI, m_GAMDGPUReadAnyLane(m_Reg(RALSrc))))      return RALSrc; +  // TruncSrc = G_AMDGPU_READANYLANE RALSrc +  // AextSrc = G_TRUNC TruncSrc +  // Src = G_ANYEXT AextSrc +  if (mi_match(Src, MRI, +               m_GAnyExt(m_GTrunc(m_GAMDGPUReadAnyLane(m_Reg(RALSrc)))))) { +    return RALSrc; +  } +    // LoVgpr, HiVgpr = G_UNMERGE_VALUES UnmergeSrc    // LoSgpr = G_AMDGPU_READANYLANE LoVgpr    // HiSgpr = G_AMDGPU_READANYLANE HiVgpr diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp index 5407566..dc8fa7f 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp @@ -500,6 +500,16 @@ void RegBankLegalizeHelper::lowerUnpackMinMax(MachineInstr &MI) {    MI.eraseFromParent();  } +void RegBankLegalizeHelper::lowerUnpackAExt(MachineInstr &MI) { +  auto [Op1Lo, Op1Hi] = unpackAExt(MI.getOperand(1).getReg()); +  auto [Op2Lo, Op2Hi] = unpackAExt(MI.getOperand(2).getReg()); +  auto ResLo = B.buildInstr(MI.getOpcode(), {SgprRB_S32}, {Op1Lo, Op2Lo}); +  auto ResHi = B.buildInstr(MI.getOpcode(), {SgprRB_S32}, {Op1Hi, Op2Hi}); +  B.buildBuildVectorTrunc(MI.getOperand(0).getReg(), +                          {ResLo.getReg(0), ResHi.getReg(0)}); +  MI.eraseFromParent(); +} +  static bool isSignedBFE(MachineInstr &MI) {    if (GIntrinsic *GI = dyn_cast<GIntrinsic>(&MI))      return (GI->is(Intrinsic::amdgcn_sbfe)); @@ -616,6 +626,23 @@ void RegBankLegalizeHelper::lowerSplitTo32(MachineInstr &MI) {    MI.eraseFromParent();  } +void 
RegBankLegalizeHelper::lowerSplitTo16(MachineInstr &MI) { +  Register Dst = MI.getOperand(0).getReg(); +  assert(MRI.getType(Dst) == V2S16); +  auto [Op1Lo32, Op1Hi32] = unpackAExt(MI.getOperand(1).getReg()); +  auto [Op2Lo32, Op2Hi32] = unpackAExt(MI.getOperand(2).getReg()); +  unsigned Opc = MI.getOpcode(); +  auto Flags = MI.getFlags(); +  auto Op1Lo = B.buildTrunc(SgprRB_S16, Op1Lo32); +  auto Op1Hi = B.buildTrunc(SgprRB_S16, Op1Hi32); +  auto Op2Lo = B.buildTrunc(SgprRB_S16, Op2Lo32); +  auto Op2Hi = B.buildTrunc(SgprRB_S16, Op2Hi32); +  auto Lo = B.buildInstr(Opc, {SgprRB_S16}, {Op1Lo, Op2Lo}, Flags); +  auto Hi = B.buildInstr(Opc, {SgprRB_S16}, {Op1Hi, Op2Hi}, Flags); +  B.buildMergeLikeInstr(Dst, {Lo, Hi}); +  MI.eraseFromParent(); +} +  void RegBankLegalizeHelper::lowerSplitTo32Select(MachineInstr &MI) {    Register Dst = MI.getOperand(0).getReg();    LLT DstTy = MRI.getType(Dst); @@ -688,6 +715,8 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,      return lowerUnpackBitShift(MI);    case UnpackMinMax:      return lowerUnpackMinMax(MI); +  case ScalarizeToS16: +    return lowerSplitTo16(MI);    case Ext32To64: {      const RegisterBank *RB = MRI.getRegBank(MI.getOperand(0).getReg());      MachineInstrBuilder Hi; @@ -804,6 +833,8 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,      }      break;    } +  case UnpackAExt: +    return lowerUnpackAExt(MI);    case WidenMMOToS32:      return widenMMOToS32(cast<GAnyLoad>(MI));    } @@ -837,6 +868,7 @@ LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMappingApplyID ID) {      return LLT::scalar(32);    case Sgpr64:    case Vgpr64: +  case UniInVgprS64:      return LLT::scalar(64);    case Sgpr128:    case Vgpr128: @@ -960,6 +992,7 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {    case UniInVcc:    case UniInVgprS16:    case UniInVgprS32: +  case UniInVgprS64:    case UniInVgprV2S16:    case UniInVgprV4S32:    case UniInVgprB32: @@ -1092,6 +1125,7 @@ void RegBankLegalizeHelper::applyMappingDst(        break;      }      case UniInVgprS32: +    case UniInVgprS64:      case UniInVgprV2S16:      case UniInVgprV4S32: {        assert(Ty == getTyFromID(MethodIDs[OpIdx])); @@ -1120,7 +1154,8 @@ void RegBankLegalizeHelper::applyMappingDst(        assert(RB == SgprRB);        Register NewDst = MRI.createVirtualRegister(SgprRB_S32);        Op.setReg(NewDst); -      B.buildTrunc(Reg, NewDst); +      if (!MRI.use_empty(Reg)) +        B.buildTrunc(Reg, NewDst);        break;      }      case InvalidMapping: { diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.h b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.h index d937815..e7598f8 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.h +++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.h @@ -72,6 +72,7 @@ class RegBankLegalizeHelper {    static constexpr LLT P6 = LLT::pointer(6, 32);    MachineRegisterInfo::VRegAttrs SgprRB_S32 = {SgprRB, S32}; +  MachineRegisterInfo::VRegAttrs SgprRB_S16 = {SgprRB, S16};    MachineRegisterInfo::VRegAttrs VgprRB_S32 = {VgprRB, S32};    MachineRegisterInfo::VRegAttrs VccRB_S1 = {VccRB, S1}; @@ -121,9 +122,11 @@ private:    void lowerV_BFE(MachineInstr &MI);    void lowerS_BFE(MachineInstr &MI);    void lowerSplitTo32(MachineInstr &MI); +  void lowerSplitTo16(MachineInstr &MI);    void lowerSplitTo32Select(MachineInstr &MI);    void lowerSplitTo32SExtInReg(MachineInstr &MI);    void lowerUnpackMinMax(MachineInstr &MI); +  void lowerUnpackAExt(MachineInstr &MI);  };  } // end namespace AMDGPU 
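A scalar C++ sketch of what the new UnpackAExt lowering above does for a uniform <2 x s16> integer operation; an add is used purely for illustration and the helper name is not from the patch. Each operand is any-extended into two 32-bit halves, the operation runs on the halves, and the build_vector_trunc keeps only the low 16 bits of each result when repacking.

#include <cstdint>

// Two u16 lanes packed into a u32; each lane is widened (zero is one valid
// choice for the "any" extension bits), operated on at 32 bits, then
// truncated back to 16 bits when the lanes are repacked.
uint32_t addV2S16ViaUnpackAExt(uint32_t A, uint32_t B) {
  uint32_t Lo = (A & 0xffffu) + (B & 0xffffu); // low lanes
  uint32_t Hi = (A >> 16) + (B >> 16);         // high lanes
  return (Lo & 0xffffu) | ((Hi & 0xffffu) << 16);
}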
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp index a67b12a..b22e9bd 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp @@ -470,7 +470,19 @@ RegBankLegalizeRules::RegBankLegalizeRules(const GCNSubtarget &_ST,        .Uni(S16, {{Sgpr32Trunc}, {Sgpr32AExt, Sgpr32AExt}})        .Div(S16, {{Vgpr16}, {Vgpr16, Vgpr16}})        .Uni(S32, {{Sgpr32}, {Sgpr32, Sgpr32}}) -      .Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}}); +      .Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}}) +      .Uni(V2S16, {{SgprV2S16}, {SgprV2S16, SgprV2S16}, UnpackAExt}) +      .Div(V2S16, {{VgprV2S16}, {VgprV2S16, VgprV2S16}}) +      .Uni(S64, {{Sgpr64}, {Sgpr64, Sgpr64}}) +      .Div(S64, {{Vgpr64}, {Vgpr64, Vgpr64}}); + +  addRulesForGOpcs({G_UADDO, G_USUBO}, Standard) +      .Uni(S32, {{Sgpr32, Sgpr32Trunc}, {Sgpr32, Sgpr32}}) +      .Div(S32, {{Vgpr32, Vcc}, {Vgpr32, Vgpr32}}); + +  addRulesForGOpcs({G_UADDE, G_USUBE}, Standard) +      .Uni(S32, {{Sgpr32, Sgpr32Trunc}, {Sgpr32, Sgpr32, Sgpr32AExtBoolInReg}}) +      .Div(S32, {{Vgpr32, Vcc}, {Vgpr32, Vgpr32, Vcc}});    addRulesForGOpcs({G_MUL}, Standard).Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}}); @@ -906,9 +918,20 @@ RegBankLegalizeRules::RegBankLegalizeRules(const GCNSubtarget &_ST,    bool hasSALUFloat = ST->hasSALUFloatInsts();    addRulesForGOpcs({G_FADD}, Standard) +      .Uni(S16, {{UniInVgprS16}, {Vgpr16, Vgpr16}}, !hasSALUFloat) +      .Uni(S16, {{Sgpr16}, {Sgpr16, Sgpr16}}, hasSALUFloat) +      .Div(S16, {{Vgpr16}, {Vgpr16, Vgpr16}})        .Uni(S32, {{Sgpr32}, {Sgpr32, Sgpr32}}, hasSALUFloat)        .Uni(S32, {{UniInVgprS32}, {Vgpr32, Vgpr32}}, !hasSALUFloat) -      .Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}}); +      .Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}}) +      .Uni(S64, {{UniInVgprS64}, {Vgpr64, Vgpr64}}) +      .Div(S64, {{Vgpr64}, {Vgpr64, Vgpr64}}) +      .Uni(V2S16, {{UniInVgprV2S16}, {VgprV2S16, VgprV2S16}}, !hasSALUFloat) +      .Uni(V2S16, {{SgprV2S16}, {SgprV2S16, SgprV2S16}, ScalarizeToS16}, +           hasSALUFloat) +      .Div(V2S16, {{VgprV2S16}, {VgprV2S16, VgprV2S16}}) +      .Any({{UniV2S32}, {{UniInVgprV2S32}, {VgprV2S32, VgprV2S32}}}) +      .Any({{DivV2S32}, {{VgprV2S32}, {VgprV2S32, VgprV2S32}}});    addRulesForGOpcs({G_FPTOUI})        .Any({{UniS32, S32}, {{Sgpr32}, {Sgpr32}}}, hasSALUFloat) diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h index 93e0efd..e6df5d8 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h +++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h @@ -92,8 +92,10 @@ enum UniformityLLTOpPredicateID {    V4S32,    UniV2S16, +  UniV2S32,    DivV2S16, +  DivV2S32,    // B types    B32, @@ -178,7 +180,9 @@ enum RegBankLLTMappingApplyID {    UniInVcc,    UniInVgprS16,    UniInVgprS32, +  UniInVgprS64,    UniInVgprV2S16, +  UniInVgprV2S32,    UniInVgprV4S32,    UniInVgprB32,    UniInVgprB64, @@ -217,13 +221,15 @@ enum LoweringMethodID {    V_BFE,    VgprToVccCopy,    SplitTo32, +  ScalarizeToS16,    SplitTo32Select,    SplitTo32SExtInReg,    Ext32To64,    UniCstExt,    SplitLoad,    WidenLoad, -  WidenMMOToS32 +  WidenMMOToS32, +  UnpackAExt  };  enum FastRulesTypes { diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp index 75a94ac..b28c50e 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp @@ 
-1315,6 +1315,9 @@ void AMDGPUPassConfig::addIRPasses() {        isPassEnabled(EnableImageIntrinsicOptimizer))      addPass(createAMDGPUImageIntrinsicOptimizerPass(&TM)); +  if (EnableUniformIntrinsicCombine) +    addPass(createAMDGPUUniformIntrinsicCombineLegacyPass()); +    // This can be disabled by passing ::Disable here or on the command line    // with --expand-variadics-override=disable.    addPass(createExpandVariadicsPass(ExpandVariadicsMode::Lowering)); @@ -2066,6 +2069,8 @@ void AMDGPUCodeGenPassBuilder::addIRPasses(AddIRPass &addPass) const {    if (isPassEnabled(EnableImageIntrinsicOptimizer))      addPass(AMDGPUImageIntrinsicOptimizerPass(TM)); +  if (EnableUniformIntrinsicCombine) +    addPass(AMDGPUUniformIntrinsicCombinePass());    // This can be disabled by passing ::Disable here or on the command line    // with --expand-variadics-override=disable.    addPass(ExpandVariadicsPass(ExpandVariadicsMode::Lowering)); diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp index b34ab2a..8bb2808 100644 --- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp @@ -7035,9 +7035,15 @@ static SDValue lowerBALLOTIntrinsic(const SITargetLowering &TLI, SDNode *N,    SDLoc SL(N);    if (Src.getOpcode() == ISD::SETCC) { +    SDValue Op0 = Src.getOperand(0); +    SDValue Op1 = Src.getOperand(1); +    // Need to expand bfloat to float for comparison (setcc). +    if (Op0.getValueType() == MVT::bf16) { +      Op0 = DAG.getNode(ISD::FP_EXTEND, SL, MVT::f32, Op0); +      Op1 = DAG.getNode(ISD::FP_EXTEND, SL, MVT::f32, Op1); +    }      // (ballot (ISD::SETCC ...)) -> (AMDGPUISD::SETCC ...) -    return DAG.getNode(AMDGPUISD::SETCC, SL, VT, Src.getOperand(0), -                       Src.getOperand(1), Src.getOperand(2)); +    return DAG.getNode(AMDGPUISD::SETCC, SL, VT, Op0, Op1, Src.getOperand(2));    }    if (const ConstantSDNode *Arg = dyn_cast<ConstantSDNode>(Src)) {      // (ballot 0) -> 0 diff --git a/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp b/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp index d80a6f3..a6c1af2 100644 --- a/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp +++ b/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp @@ -1823,6 +1823,16 @@ void SIRegisterInfo::buildSpillLoadStore(        }      } +    Register FinalValueReg = ValueReg; +    if (LoadStoreOp == AMDGPU::SCRATCH_LOAD_USHORT_SADDR) { +      // If we are loading a 16-bit value with SRAMECC enabled, we need a temp +      // 32-bit VGPR to load and extract 16 bits into the final register. +      ValueReg = +          RS->scavengeRegisterBackwards(AMDGPU::VGPR_32RegClass, MI, false, 0); +      SubReg = ValueReg; +      IsKill = false; +    } +      MachinePointerInfo PInfo = BasePtrInfo.getWithOffset(RegOffset);      MachineMemOperand *NewMMO =          MF->getMachineMemOperand(PInfo, MMO->getFlags(), RemEltSize, @@ -1863,6 +1873,17 @@ void SIRegisterInfo::buildSpillLoadStore(        MIB.addImm(0); // swz      MIB.addMemOperand(NewMMO); +    if (FinalValueReg != ValueReg) { +      // Extract the 16-bit value from the loaded 32-bit value. 
+      ValueReg = getSubReg(ValueReg, AMDGPU::lo16); +      MIB = BuildMI(MBB, MI, DL, TII->get(AMDGPU::V_MOV_B16_t16_e64)) +                .addReg(FinalValueReg, getDefRegState(true)) +                .addImm(0) +                .addReg(ValueReg, getKillRegState(true)) +                .addImm(0); +      ValueReg = FinalValueReg; +    } +      if (!IsAGPR && NeedSuperRegDef)        MIB.addReg(ValueReg, RegState::ImplicitDefine); @@ -2505,7 +2526,9 @@ bool SIRegisterInfo::eliminateFrameIndex(MachineBasicBlock::iterator MI,        unsigned Opc;        if (MI->getOpcode() == AMDGPU::SI_SPILL_V16_RESTORE) {          assert(ST.enableFlatScratch() && "Flat Scratch is not enabled!"); -        Opc = AMDGPU::SCRATCH_LOAD_SHORT_D16_SADDR_t16; +        Opc = ST.d16PreservesUnusedBits() +                  ? AMDGPU::SCRATCH_LOAD_SHORT_D16_SADDR_t16 +                  : AMDGPU::SCRATCH_LOAD_USHORT_SADDR;        } else {          Opc = MI->getOpcode() == AMDGPU::SI_BLOCK_SPILL_V1024_RESTORE                    ? AMDGPU::SCRATCH_LOAD_BLOCK_SADDR diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp index fdba454..6b06534 100644 --- a/llvm/lib/Target/ARM/ARMISelLowering.cpp +++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp @@ -601,10 +601,20 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM_,      setOperationAction(ISD::FP_TO_SINT_SAT, MVT::i64, Custom);      setOperationAction(ISD::FP_TO_UINT_SAT, MVT::i64, Custom); -    if (!Subtarget->hasVFP2Base()) +    if (!Subtarget->hasVFP2Base()) {        setAllExpand(MVT::f32); -    if (!Subtarget->hasFP64()) +    } else { +      for (auto Op : {ISD::STRICT_FADD, ISD::STRICT_FSUB, ISD::STRICT_FMUL, +                      ISD::STRICT_FDIV, ISD::STRICT_FMA, ISD::STRICT_FSQRT}) +        setOperationAction(Op, MVT::f32, Legal); +    } +    if (!Subtarget->hasFP64()) {        setAllExpand(MVT::f64); +    } else { +      for (auto Op : {ISD::STRICT_FADD, ISD::STRICT_FSUB, ISD::STRICT_FMUL, +                      ISD::STRICT_FDIV, ISD::STRICT_FMA, ISD::STRICT_FSQRT}) +        setOperationAction(Op, MVT::f64, Legal); +    }    }    if (Subtarget->hasFullFP16()) { @@ -1281,12 +1291,16 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM_,      if (!Subtarget->hasFPARMv8Base() || !Subtarget->hasFP64()) {        setOperationAction(ISD::FP16_TO_FP, MVT::f64, Expand);        setOperationAction(ISD::FP_TO_FP16, MVT::f64, Expand); +      setOperationAction(ISD::STRICT_FP16_TO_FP, MVT::f64, LibCall); +      setOperationAction(ISD::STRICT_FP_TO_FP16, MVT::f64, LibCall);      }      // fp16 is a special v7 extension that adds f16 <-> f32 conversions.      if (!Subtarget->hasFP16()) {        setOperationAction(ISD::FP16_TO_FP, MVT::f32, Expand);        setOperationAction(ISD::FP_TO_FP16, MVT::f32, Expand); +      setOperationAction(ISD::STRICT_FP16_TO_FP, MVT::f32, LibCall); +      setOperationAction(ISD::STRICT_FP_TO_FP16, MVT::f32, LibCall);      }      // Strict floating-point comparisons need custom lowering. 
@@ -1333,31 +1347,42 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM_,    }    // FP16 often need to be promoted to call lib functions +  // clang-format off    if (Subtarget->hasFullFP16()) { -    setOperationAction(ISD::FREM, MVT::f16, Promote); -    setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand); -    setOperationAction(ISD::FSIN, MVT::f16, Promote); -    setOperationAction(ISD::FCOS, MVT::f16, Promote); -    setOperationAction(ISD::FTAN, MVT::f16, Promote); -    setOperationAction(ISD::FSINCOS, MVT::f16, Promote); -    setOperationAction(ISD::FPOWI, MVT::f16, Promote); -    setOperationAction(ISD::FPOW, MVT::f16, Promote); -    setOperationAction(ISD::FEXP, MVT::f16, Promote); -    setOperationAction(ISD::FEXP2, MVT::f16, Promote); -    setOperationAction(ISD::FEXP10, MVT::f16, Promote); -    setOperationAction(ISD::FLOG, MVT::f16, Promote); -    setOperationAction(ISD::FLOG10, MVT::f16, Promote); -    setOperationAction(ISD::FLOG2, MVT::f16, Promote);      setOperationAction(ISD::LRINT, MVT::f16, Expand);      setOperationAction(ISD::LROUND, MVT::f16, Expand); - -    setOperationAction(ISD::FROUND, MVT::f16, Legal); -    setOperationAction(ISD::FROUNDEVEN, MVT::f16, Legal); -    setOperationAction(ISD::FTRUNC, MVT::f16, Legal); -    setOperationAction(ISD::FNEARBYINT, MVT::f16, Legal); -    setOperationAction(ISD::FRINT, MVT::f16, Legal); -    setOperationAction(ISD::FFLOOR, MVT::f16, Legal); -    setOperationAction(ISD::FCEIL, MVT::f16, Legal); +    setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand); +   +    for (auto Op : {ISD::FREM,          ISD::FPOW,         ISD::FPOWI, +                  ISD::FCOS,          ISD::FSIN,         ISD::FSINCOS, +                  ISD::FSINCOSPI,     ISD::FMODF,        ISD::FACOS, +                  ISD::FASIN,         ISD::FATAN,        ISD::FATAN2, +                  ISD::FCOSH,         ISD::FSINH,        ISD::FTANH, +                  ISD::FTAN,          ISD::FEXP,         ISD::FEXP2, +                  ISD::FEXP10,        ISD::FLOG,         ISD::FLOG2, +                  ISD::FLOG10,        ISD::STRICT_FREM,  ISD::STRICT_FPOW, +                  ISD::STRICT_FPOWI,  ISD::STRICT_FCOS,  ISD::STRICT_FSIN, +                  ISD::STRICT_FACOS,  ISD::STRICT_FASIN, ISD::STRICT_FATAN, +                  ISD::STRICT_FATAN2, ISD::STRICT_FCOSH, ISD::STRICT_FSINH, +                  ISD::STRICT_FTANH,  ISD::STRICT_FEXP,  ISD::STRICT_FEXP2, +                  ISD::STRICT_FLOG,   ISD::STRICT_FLOG2, ISD::STRICT_FLOG10, +                  ISD::STRICT_FTAN}) { +        setOperationAction(Op, MVT::f16, Promote); +    } + +    // Round-to-integer need custom lowering for fp16, as Promote doesn't work +    // because the result type is integer. 
+    for (auto Op : {ISD::STRICT_LROUND, ISD::STRICT_LLROUND, ISD::STRICT_LRINT, ISD::STRICT_LLRINT}) +      setOperationAction(Op, MVT::f16, Custom); +   +    for (auto Op : {ISD::FROUND,         ISD::FROUNDEVEN,        ISD::FTRUNC, +                    ISD::FNEARBYINT,     ISD::FRINT,             ISD::FFLOOR,  +                    ISD::FCEIL,          ISD::STRICT_FROUND,     ISD::STRICT_FROUNDEVEN, +                    ISD::STRICT_FTRUNC,  ISD::STRICT_FNEARBYINT, ISD::STRICT_FRINT,  +                    ISD::STRICT_FFLOOR,  ISD::STRICT_FCEIL}) { +      setOperationAction(Op, MVT::f16, Legal); +    } +    // clang-format on    }    if (Subtarget->hasNEON()) { @@ -10725,6 +10750,19 @@ SDValue ARMTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {      return LowerCMP(Op, DAG);    case ISD::ABS:      return LowerABS(Op, DAG); +  case ISD::STRICT_LROUND: +  case ISD::STRICT_LLROUND: +  case ISD::STRICT_LRINT: +  case ISD::STRICT_LLRINT: { +    assert((Op.getOperand(1).getValueType() == MVT::f16 || +            Op.getOperand(1).getValueType() == MVT::bf16) && +           "Expected custom lowering of rounding operations only for f16"); +    SDLoc DL(Op); +    SDValue Ext = DAG.getNode(ISD::STRICT_FP_EXTEND, DL, {MVT::f32, MVT::Other}, +                              {Op.getOperand(0), Op.getOperand(1)}); +    return DAG.getNode(Op.getOpcode(), DL, {Op.getValueType(), MVT::Other}, +                       {Ext.getValue(1), Ext.getValue(0)}); +  }    }  } @@ -22071,6 +22109,11 @@ bool ARMTargetLowering::isComplexDeinterleavingOperationSupported(            ScalarTy->isIntegerTy(32));  } +ArrayRef<MCPhysReg> ARMTargetLowering::getRoundingControlRegisters() const { +  static const MCPhysReg RCRegs[] = {ARM::FPSCR_RM}; +  return RCRegs; +} +  Value *ARMTargetLowering::createComplexDeinterleavingIR(      IRBuilderBase &B, ComplexDeinterleavingOperation OperationType,      ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, diff --git a/llvm/lib/Target/ARM/ARMISelLowering.h b/llvm/lib/Target/ARM/ARMISelLowering.h index 357d2c5..bf3438b 100644 --- a/llvm/lib/Target/ARM/ARMISelLowering.h +++ b/llvm/lib/Target/ARM/ARMISelLowering.h @@ -1009,6 +1009,8 @@ class VectorType;      bool isUnsupportedFloatingType(EVT VT) const; +    ArrayRef<MCPhysReg> getRoundingControlRegisters() const override; +      SDValue getCMOV(const SDLoc &dl, EVT VT, SDValue FalseVal, SDValue TrueVal,                      SDValue ARMcc, SDValue Flags, SelectionDAG &DAG) const;      SDValue getARMCmp(SDValue LHS, SDValue RHS, ISD::CondCode CC, diff --git a/llvm/lib/Target/ARM/ARMInstrInfo.td b/llvm/lib/Target/ARM/ARMInstrInfo.td index 10d4cd5..f7176a6 100644 --- a/llvm/lib/Target/ARM/ARMInstrInfo.td +++ b/llvm/lib/Target/ARM/ARMInstrInfo.td @@ -473,15 +473,15 @@ def xor_su : PatFrag<(ops node:$lhs, node:$rhs), (xor node:$lhs, node:$rhs)>;  // An 'fmul' node with a single use.  let HasOneUse = 1 in -def fmul_su : PatFrag<(ops node:$lhs, node:$rhs), (fmul node:$lhs, node:$rhs)>; +def fmul_su : PatFrag<(ops node:$lhs, node:$rhs), (any_fmul node:$lhs, node:$rhs)>;  // An 'fadd' node which checks for single non-hazardous use. -def fadd_mlx : PatFrag<(ops node:$lhs, node:$rhs),(fadd node:$lhs, node:$rhs),[{ +def fadd_mlx : PatFrag<(ops node:$lhs, node:$rhs),(any_fadd node:$lhs, node:$rhs),[{    return hasNoVMLxHazardUse(N);  }]>;  // An 'fsub' node which checks for single non-hazardous use. 
-def fsub_mlx : PatFrag<(ops node:$lhs, node:$rhs),(fsub node:$lhs, node:$rhs),[{ +def fsub_mlx : PatFrag<(ops node:$lhs, node:$rhs),(any_fsub node:$lhs, node:$rhs),[{    return hasNoVMLxHazardUse(N);  }]>; diff --git a/llvm/lib/Target/ARM/ARMInstrVFP.td b/llvm/lib/Target/ARM/ARMInstrVFP.td index 6771106..e2cc97b 100644 --- a/llvm/lib/Target/ARM/ARMInstrVFP.td +++ b/llvm/lib/Target/ARM/ARMInstrVFP.td @@ -439,14 +439,14 @@ let TwoOperandAliasConstraint = "$Dn = $Dd", mayRaiseFPException = 1, Uses = [FP  def VADDD  : ADbI<0b11100, 0b11, 0, 0,                    (outs DPR:$Dd), (ins DPR:$Dn, DPR:$Dm),                    IIC_fpALU64, "vadd", ".f64\t$Dd, $Dn, $Dm", -                  [(set DPR:$Dd, (fadd DPR:$Dn, (f64 DPR:$Dm)))]>, +                  [(set DPR:$Dd, (any_fadd DPR:$Dn, (f64 DPR:$Dm)))]>,               Sched<[WriteFPALU64]>;  let TwoOperandAliasConstraint = "$Sn = $Sd", mayRaiseFPException = 1, Uses = [FPSCR_RM] in  def VADDS  : ASbIn<0b11100, 0b11, 0, 0,                     (outs SPR:$Sd), (ins SPR:$Sn, SPR:$Sm),                     IIC_fpALU32, "vadd", ".f32\t$Sd, $Sn, $Sm", -                   [(set SPR:$Sd, (fadd SPR:$Sn, SPR:$Sm))]>, +                   [(set SPR:$Sd, (any_fadd SPR:$Sn, SPR:$Sm))]>,               Sched<[WriteFPALU32]> {    // Some single precision VFP instructions may be executed on both NEON and    // VFP pipelines on A8. @@ -457,21 +457,21 @@ let TwoOperandAliasConstraint = "$Sn = $Sd", mayRaiseFPException = 1, Uses = [FP  def VADDH  : AHbI<0b11100, 0b11, 0, 0,                    (outs HPR:$Sd), (ins HPR:$Sn, HPR:$Sm),                    IIC_fpALU16, "vadd", ".f16\t$Sd, $Sn, $Sm", -                  [(set (f16 HPR:$Sd), (fadd (f16 HPR:$Sn), (f16 HPR:$Sm)))]>, +                  [(set (f16 HPR:$Sd), (any_fadd (f16 HPR:$Sn), (f16 HPR:$Sm)))]>,               Sched<[WriteFPALU32]>;  let TwoOperandAliasConstraint = "$Dn = $Dd", mayRaiseFPException = 1, Uses = [FPSCR_RM] in  def VSUBD  : ADbI<0b11100, 0b11, 1, 0,                    (outs DPR:$Dd), (ins DPR:$Dn, DPR:$Dm),                    IIC_fpALU64, "vsub", ".f64\t$Dd, $Dn, $Dm", -                  [(set DPR:$Dd, (fsub DPR:$Dn, (f64 DPR:$Dm)))]>, +                  [(set DPR:$Dd, (any_fsub DPR:$Dn, (f64 DPR:$Dm)))]>,               Sched<[WriteFPALU64]>;  let TwoOperandAliasConstraint = "$Sn = $Sd", mayRaiseFPException = 1, Uses = [FPSCR_RM] in  def VSUBS  : ASbIn<0b11100, 0b11, 1, 0,                     (outs SPR:$Sd), (ins SPR:$Sn, SPR:$Sm),                     IIC_fpALU32, "vsub", ".f32\t$Sd, $Sn, $Sm", -                   [(set SPR:$Sd, (fsub SPR:$Sn, SPR:$Sm))]>, +                   [(set SPR:$Sd, (any_fsub SPR:$Sn, SPR:$Sm))]>,               Sched<[WriteFPALU32]>{    // Some single precision VFP instructions may be executed on both NEON and    // VFP pipelines on A8. 
@@ -482,42 +482,42 @@ let TwoOperandAliasConstraint = "$Sn = $Sd", mayRaiseFPException = 1, Uses = [FP  def VSUBH  : AHbI<0b11100, 0b11, 1, 0,                    (outs HPR:$Sd), (ins HPR:$Sn, HPR:$Sm),                    IIC_fpALU16, "vsub", ".f16\t$Sd, $Sn, $Sm", -                  [(set (f16 HPR:$Sd), (fsub (f16 HPR:$Sn), (f16 HPR:$Sm)))]>, +                  [(set (f16 HPR:$Sd), (any_fsub (f16 HPR:$Sn), (f16 HPR:$Sm)))]>,              Sched<[WriteFPALU32]>;  let TwoOperandAliasConstraint = "$Dn = $Dd", mayRaiseFPException = 1, Uses = [FPSCR_RM] in  def VDIVD  : ADbI<0b11101, 0b00, 0, 0,                    (outs DPR:$Dd), (ins DPR:$Dn, DPR:$Dm),                    IIC_fpDIV64, "vdiv", ".f64\t$Dd, $Dn, $Dm", -                  [(set DPR:$Dd, (fdiv DPR:$Dn, (f64 DPR:$Dm)))]>, +                  [(set DPR:$Dd, (any_fdiv DPR:$Dn, (f64 DPR:$Dm)))]>,               Sched<[WriteFPDIV64]>;  let TwoOperandAliasConstraint = "$Sn = $Sd", mayRaiseFPException = 1, Uses = [FPSCR_RM] in  def VDIVS  : ASbI<0b11101, 0b00, 0, 0,                    (outs SPR:$Sd), (ins SPR:$Sn, SPR:$Sm),                    IIC_fpDIV32, "vdiv", ".f32\t$Sd, $Sn, $Sm", -                  [(set SPR:$Sd, (fdiv SPR:$Sn, SPR:$Sm))]>, +                  [(set SPR:$Sd, (any_fdiv SPR:$Sn, SPR:$Sm))]>,               Sched<[WriteFPDIV32]>;  let TwoOperandAliasConstraint = "$Sn = $Sd", mayRaiseFPException = 1, Uses = [FPSCR_RM]  in  def VDIVH  : AHbI<0b11101, 0b00, 0, 0,                    (outs HPR:$Sd), (ins HPR:$Sn, HPR:$Sm),                    IIC_fpDIV16, "vdiv", ".f16\t$Sd, $Sn, $Sm", -                  [(set (f16 HPR:$Sd), (fdiv (f16 HPR:$Sn), (f16 HPR:$Sm)))]>, +                  [(set (f16 HPR:$Sd), (any_fdiv (f16 HPR:$Sn), (f16 HPR:$Sm)))]>,               Sched<[WriteFPDIV32]>;  let TwoOperandAliasConstraint = "$Dn = $Dd", mayRaiseFPException = 1, Uses = [FPSCR_RM] in  def VMULD  : ADbI<0b11100, 0b10, 0, 0,                    (outs DPR:$Dd), (ins DPR:$Dn, DPR:$Dm),                    IIC_fpMUL64, "vmul", ".f64\t$Dd, $Dn, $Dm", -                  [(set DPR:$Dd, (fmul DPR:$Dn, (f64 DPR:$Dm)))]>, +                  [(set DPR:$Dd, (any_fmul DPR:$Dn, (f64 DPR:$Dm)))]>,               Sched<[WriteFPMUL64, ReadFPMUL, ReadFPMUL]>;  let TwoOperandAliasConstraint = "$Sn = $Sd", mayRaiseFPException = 1, Uses = [FPSCR_RM] in  def VMULS  : ASbIn<0b11100, 0b10, 0, 0,                     (outs SPR:$Sd), (ins SPR:$Sn, SPR:$Sm),                     IIC_fpMUL32, "vmul", ".f32\t$Sd, $Sn, $Sm", -                   [(set SPR:$Sd, (fmul SPR:$Sn, SPR:$Sm))]>, +                   [(set SPR:$Sd, (any_fmul SPR:$Sn, SPR:$Sm))]>,              Sched<[WriteFPMUL32, ReadFPMUL, ReadFPMUL]> {    // Some single precision VFP instructions may be executed on both NEON and    // VFP pipelines on A8. 
@@ -528,21 +528,21 @@ let TwoOperandAliasConstraint = "$Sn = $Sd", mayRaiseFPException = 1, Uses = [FP  def VMULH  : AHbI<0b11100, 0b10, 0, 0,                    (outs HPR:$Sd), (ins HPR:$Sn, HPR:$Sm),                    IIC_fpMUL16, "vmul", ".f16\t$Sd, $Sn, $Sm", -                  [(set (f16 HPR:$Sd), (fmul (f16 HPR:$Sn), (f16 HPR:$Sm)))]>, +                  [(set (f16 HPR:$Sd), (any_fmul (f16 HPR:$Sn), (f16 HPR:$Sm)))]>,               Sched<[WriteFPMUL32, ReadFPMUL, ReadFPMUL]>;  let TwoOperandAliasConstraint = "$Dn = $Dd", mayRaiseFPException = 1, Uses = [FPSCR_RM] in  def VNMULD : ADbI<0b11100, 0b10, 1, 0,                    (outs DPR:$Dd), (ins DPR:$Dn, DPR:$Dm),                    IIC_fpMUL64, "vnmul", ".f64\t$Dd, $Dn, $Dm", -                  [(set DPR:$Dd, (fneg (fmul DPR:$Dn, (f64 DPR:$Dm))))]>, +                  [(set DPR:$Dd, (fneg (any_fmul DPR:$Dn, (f64 DPR:$Dm))))]>,               Sched<[WriteFPMUL64, ReadFPMUL, ReadFPMUL]>;  let TwoOperandAliasConstraint = "$Sn = $Sd", mayRaiseFPException = 1, Uses = [FPSCR_RM] in  def VNMULS : ASbI<0b11100, 0b10, 1, 0,                    (outs SPR:$Sd), (ins SPR:$Sn, SPR:$Sm),                    IIC_fpMUL32, "vnmul", ".f32\t$Sd, $Sn, $Sm", -                  [(set SPR:$Sd, (fneg (fmul SPR:$Sn, SPR:$Sm)))]>, +                  [(set SPR:$Sd, (fneg (any_fmul SPR:$Sn, SPR:$Sm)))]>,              Sched<[WriteFPMUL32, ReadFPMUL, ReadFPMUL]> {    // Some single precision VFP instructions may be executed on both NEON and    // VFP pipelines on A8. @@ -553,7 +553,7 @@ let TwoOperandAliasConstraint = "$Sn = $Sd", mayRaiseFPException = 1, Uses = [FP  def VNMULH : AHbI<0b11100, 0b10, 1, 0,                    (outs HPR:$Sd), (ins HPR:$Sn, HPR:$Sm),                    IIC_fpMUL16, "vnmul", ".f16\t$Sd, $Sn, $Sm", -                  [(set (f16 HPR:$Sd), (fneg (fmul (f16 HPR:$Sn), (f16 HPR:$Sm))))]>, +                  [(set (f16 HPR:$Sd), (fneg (any_fmul (f16 HPR:$Sn), (f16 HPR:$Sm))))]>,               Sched<[WriteFPMUL32, ReadFPMUL, ReadFPMUL]>;  multiclass vsel_inst<string op, bits<2> opc, int CC> { @@ -587,7 +587,7 @@ defm VSELGE : vsel_inst<"ge", 0b10, 10>;  defm VSELEQ : vsel_inst<"eq", 0b00, 0>;  defm VSELVS : vsel_inst<"vs", 0b01, 6>; -multiclass vmaxmin_inst<string op, bit opc, SDNode SD> { +multiclass vmaxmin_inst<string op, bit opc, PatFrags SD> {    let DecoderNamespace = "VFPV8", PostEncoderMethod = "",        isUnpredicable = 1, mayRaiseFPException = 1 in {      def H : AHbInp<0b11101, 0b00, opc, @@ -610,8 +610,8 @@ multiclass vmaxmin_inst<string op, bit opc, SDNode SD> {    }  } -defm VFP_VMAXNM : vmaxmin_inst<"vmaxnm", 0, fmaxnum>; -defm VFP_VMINNM : vmaxmin_inst<"vminnm", 1, fminnum>; +defm VFP_VMAXNM : vmaxmin_inst<"vmaxnm", 0, any_fmaxnum>; +defm VFP_VMINNM : vmaxmin_inst<"vminnm", 1, any_fminnum>;  // Match reassociated forms only if not sign dependent rounding.  def : Pat<(fmul (fneg DPR:$a), (f64 DPR:$b)), @@ -746,7 +746,7 @@ let mayRaiseFPException = 1, Uses = [FPSCR_RM] in  def VCVTDS  : ASuI<0b11101, 0b11, 0b0111, 0b11, 0,                     (outs DPR:$Dd), (ins SPR:$Sm),                     IIC_fpCVTDS, "vcvt", ".f64.f32\t$Dd, $Sm", "", -                   [(set DPR:$Dd, (fpextend SPR:$Sm))]>, +                   [(set DPR:$Dd, (any_fpextend SPR:$Sm))]>,               Sched<[WriteFPCVT]> {    // Instruction operands.    
bits<5> Dd; @@ -766,7 +766,7 @@ def VCVTDS  : ASuI<0b11101, 0b11, 0b0111, 0b11, 0,  let mayRaiseFPException = 1, Uses = [FPSCR_RM] in  def VCVTSD  : VFPAI<(outs SPR:$Sd), (ins DPR:$Dm), VFPUnaryFrm,                      IIC_fpCVTSD, "vcvt", ".f32.f64\t$Sd, $Dm", "", -                    [(set SPR:$Sd, (fpround DPR:$Dm))]>, +                    [(set SPR:$Sd, (any_fpround DPR:$Dm))]>,                Sched<[WriteFPCVT]> {    // Instruction operands.    bits<5> Sd; @@ -796,7 +796,7 @@ def VCVTBHS: ASuI<0b11101, 0b11, 0b0010, 0b01, 0, (outs SPR:$Sd), (ins SPR:$Sm),                   Requires<[HasFP16]>,               Sched<[WriteFPCVT]>; -def : FP16Pat<(f32 (fpextend (f16 HPR:$Sm))), +def : FP16Pat<(f32 (any_fpextend (f16 HPR:$Sm))),                (VCVTBHS (COPY_TO_REGCLASS (f16 HPR:$Sm), SPR))>;  def : FP16Pat<(f16_to_fp GPR:$a),                (VCVTBHS (COPY_TO_REGCLASS GPR:$a, SPR))>; @@ -808,16 +808,16 @@ def VCVTBSH: ASuI<0b11101, 0b11, 0b0011, 0b01, 0, (outs SPR:$Sd), (ins SPR:$Sda,                   Requires<[HasFP16]>,               Sched<[WriteFPCVT]>; -def : FP16Pat<(f16 (fpround SPR:$Sm)), +def : FP16Pat<(f16 (any_fpround SPR:$Sm)),                (COPY_TO_REGCLASS (VCVTBSH (IMPLICIT_DEF), SPR:$Sm), HPR)>;  def : FP16Pat<(fp_to_f16 SPR:$a),                (i32 (COPY_TO_REGCLASS (VCVTBSH (IMPLICIT_DEF), SPR:$a), GPR))>; -def : FP16Pat<(insertelt (v8f16 MQPR:$src1), (f16 (fpround (f32 SPR:$src2))), imm_even:$lane), +def : FP16Pat<(insertelt (v8f16 MQPR:$src1), (f16 (any_fpround (f32 SPR:$src2))), imm_even:$lane),                (v8f16 (INSERT_SUBREG (v8f16 MQPR:$src1),                                      (VCVTBSH (EXTRACT_SUBREG (v8f16 MQPR:$src1), (SSubReg_f16_reg imm:$lane)),                                               SPR:$src2),                                      (SSubReg_f16_reg imm:$lane)))>; -def : FP16Pat<(insertelt (v4f16 DPR:$src1), (f16 (fpround (f32 SPR:$src2))), imm_even:$lane), +def : FP16Pat<(insertelt (v4f16 DPR:$src1), (f16 (any_fpround (f32 SPR:$src2))), imm_even:$lane),                (v4f16 (INSERT_SUBREG (v4f16 DPR:$src1),                                      (VCVTBSH (EXTRACT_SUBREG (v4f16 DPR:$src1), (SSubReg_f16_reg imm:$lane)),                                               SPR:$src2), @@ -830,9 +830,9 @@ def VCVTTHS: ASuI<0b11101, 0b11, 0b0010, 0b11, 0, (outs SPR:$Sd), (ins SPR:$Sm),                   Requires<[HasFP16]>,               Sched<[WriteFPCVT]>; -def : FP16Pat<(f32 (fpextend (extractelt (v8f16 MQPR:$src), imm_odd:$lane))), +def : FP16Pat<(f32 (any_fpextend (extractelt (v8f16 MQPR:$src), imm_odd:$lane))),                (VCVTTHS (EXTRACT_SUBREG MQPR:$src, (SSubReg_f16_reg imm_odd:$lane)))>; -def : FP16Pat<(f32 (fpextend (extractelt (v4f16 DPR:$src), imm_odd:$lane))), +def : FP16Pat<(f32 (any_fpextend (extractelt (v4f16 DPR:$src), imm_odd:$lane))),                (VCVTTHS (EXTRACT_SUBREG                  (v2f32 (COPY_TO_REGCLASS (v4f16 DPR:$src), DPR_VFP2)),                  (SSubReg_f16_reg imm_odd:$lane)))>; @@ -844,12 +844,12 @@ def VCVTTSH: ASuI<0b11101, 0b11, 0b0011, 0b11, 0, (outs SPR:$Sd), (ins SPR:$Sda,                   Requires<[HasFP16]>,              Sched<[WriteFPCVT]>; -def : FP16Pat<(insertelt (v8f16 MQPR:$src1), (f16 (fpround (f32 SPR:$src2))), imm_odd:$lane), +def : FP16Pat<(insertelt (v8f16 MQPR:$src1), (f16 (any_fpround (f32 SPR:$src2))), imm_odd:$lane),                (v8f16 (INSERT_SUBREG (v8f16 MQPR:$src1),                                      (VCVTTSH (EXTRACT_SUBREG (v8f16 MQPR:$src1), (SSubReg_f16_reg imm:$lane)),        
                                       SPR:$src2),                                      (SSubReg_f16_reg imm:$lane)))>; -def : FP16Pat<(insertelt (v4f16 DPR:$src1), (f16 (fpround (f32 SPR:$src2))), imm_odd:$lane), +def : FP16Pat<(insertelt (v4f16 DPR:$src1), (f16 (any_fpround (f32 SPR:$src2))), imm_odd:$lane),                (v4f16 (INSERT_SUBREG (v4f16 DPR:$src1),                                      (VCVTTSH (EXTRACT_SUBREG (v4f16 DPR:$src1), (SSubReg_f16_reg imm:$lane)),                                               SPR:$src2), @@ -872,7 +872,7 @@ def VCVTBHD : ADuI<0b11101, 0b11, 0b0010, 0b01, 0,    let hasSideEffects = 0;  } -def : FullFP16Pat<(f64 (fpextend (f16 HPR:$Sm))), +def : FullFP16Pat<(f64 (any_fpextend (f16 HPR:$Sm))),                    (VCVTBHD (COPY_TO_REGCLASS (f16 HPR:$Sm), SPR))>,                    Requires<[HasFPARMv8, HasDPVFP]>;  def : FP16Pat<(f64 (f16_to_fp GPR:$a)), @@ -898,7 +898,7 @@ def VCVTBDH : ADuI<0b11101, 0b11, 0b0011, 0b01, 0,    let hasSideEffects = 0;  } -def : FullFP16Pat<(f16 (fpround DPR:$Dm)), +def : FullFP16Pat<(f16 (any_fpround DPR:$Dm)),                    (COPY_TO_REGCLASS (VCVTBDH (IMPLICIT_DEF), DPR:$Dm), HPR)>,                    Requires<[HasFPARMv8, HasDPVFP]>;  def : FP16Pat<(fp_to_f16 (f64 DPR:$a)), @@ -1007,41 +1007,41 @@ multiclass vcvt_inst<string opc, bits<2> rm,    let Predicates = [HasFPARMv8] in {      let Predicates = [HasFullFP16] in { -    def : Pat<(i32 (fp_to_sint (node (f16 HPR:$a)))), +    def : Pat<(i32 (any_fp_to_sint (node (f16 HPR:$a)))),                (COPY_TO_REGCLASS                  (!cast<Instruction>(NAME#"SH") (f16 HPR:$a)),                  GPR)>; -    def : Pat<(i32 (fp_to_uint (node (f16 HPR:$a)))), +    def : Pat<(i32 (any_fp_to_uint (node (f16 HPR:$a)))),                (COPY_TO_REGCLASS                  (!cast<Instruction>(NAME#"UH") (f16 HPR:$a)),                  GPR)>;      } -    def : Pat<(i32 (fp_to_sint (node SPR:$a))), +    def : Pat<(i32 (any_fp_to_sint (node SPR:$a))),                (COPY_TO_REGCLASS                  (!cast<Instruction>(NAME#"SS") SPR:$a),                  GPR)>; -    def : Pat<(i32 (fp_to_uint (node SPR:$a))), +    def : Pat<(i32 (any_fp_to_uint (node SPR:$a))),                (COPY_TO_REGCLASS                  (!cast<Instruction>(NAME#"US") SPR:$a),                  GPR)>;    }    let Predicates = [HasFPARMv8, HasDPVFP] in { -    def : Pat<(i32 (fp_to_sint (node (f64 DPR:$a)))), +    def : Pat<(i32 (any_fp_to_sint (node (f64 DPR:$a)))),                (COPY_TO_REGCLASS                  (!cast<Instruction>(NAME#"SD") DPR:$a),                  GPR)>; -    def : Pat<(i32 (fp_to_uint (node (f64 DPR:$a)))), +    def : Pat<(i32 (any_fp_to_uint (node (f64 DPR:$a)))),                (COPY_TO_REGCLASS                  (!cast<Instruction>(NAME#"UD") DPR:$a),                  GPR)>;    }  } -defm VCVTA : vcvt_inst<"a", 0b00, fround>; +defm VCVTA : vcvt_inst<"a", 0b00, any_fround>;  defm VCVTN : vcvt_inst<"n", 0b01>; -defm VCVTP : vcvt_inst<"p", 0b10, fceil>; -defm VCVTM : vcvt_inst<"m", 0b11, ffloor>; +defm VCVTP : vcvt_inst<"p", 0b10, any_fceil>; +defm VCVTM : vcvt_inst<"m", 0b11, any_ffloor>;  def VNEGD  : ADuI<0b11101, 0b11, 0b0001, 0b01, 0,                    (outs DPR:$Dd), (ins DPR:$Dm), @@ -1103,9 +1103,9 @@ multiclass vrint_inst_zrx<string opc, bit op, bit op2, SDPatternOperator node,          Requires<[HasFPARMv8,HasDPVFP]>;  } -defm VRINTZ : vrint_inst_zrx<"z", 0, 1, ftrunc, [], 0>; -defm VRINTR : vrint_inst_zrx<"r", 0, 0, fnearbyint, [FPSCR_RM], 0>; -defm VRINTX : 
vrint_inst_zrx<"x", 1, 0, frint, [FPSCR_RM], 1>; +defm VRINTZ : vrint_inst_zrx<"z", 0, 1, any_ftrunc, [], 0>; +defm VRINTR : vrint_inst_zrx<"r", 0, 0, any_fnearbyint, [FPSCR_RM], 0>; +defm VRINTX : vrint_inst_zrx<"x", 1, 0, any_frint, [FPSCR_RM], 1>;  multiclass vrint_inst_anpm<string opc, bits<2> rm,                             SDPatternOperator node = null_frag> { @@ -1145,30 +1145,31 @@ multiclass vrint_inst_anpm<string opc, bits<2> rm,          Requires<[HasFPARMv8,HasDPVFP]>;  } -defm VRINTA : vrint_inst_anpm<"a", 0b00, fround>; -defm VRINTN : vrint_inst_anpm<"n", 0b01, froundeven>; -defm VRINTP : vrint_inst_anpm<"p", 0b10, fceil>; -defm VRINTM : vrint_inst_anpm<"m", 0b11, ffloor>; +defm VRINTA : vrint_inst_anpm<"a", 0b00, any_fround>; +defm VRINTN : vrint_inst_anpm<"n", 0b01, any_froundeven>; +defm VRINTP : vrint_inst_anpm<"p", 0b10, any_fceil>; +defm VRINTM : vrint_inst_anpm<"m", 0b11, any_ffloor>; +  let mayRaiseFPException = 1, Uses = [FPSCR_RM] in  def VSQRTD : ADuI<0b11101, 0b11, 0b0001, 0b11, 0,                    (outs DPR:$Dd), (ins DPR:$Dm),                    IIC_fpSQRT64, "vsqrt", ".f64\t$Dd, $Dm", "", -                  [(set DPR:$Dd, (fsqrt (f64 DPR:$Dm)))]>, +                  [(set DPR:$Dd, (any_fsqrt (f64 DPR:$Dm)))]>,               Sched<[WriteFPSQRT64]>;  let mayRaiseFPException = 1, Uses = [FPSCR_RM] in  def VSQRTS : ASuI<0b11101, 0b11, 0b0001, 0b11, 0,                    (outs SPR:$Sd), (ins SPR:$Sm),                    IIC_fpSQRT32, "vsqrt", ".f32\t$Sd, $Sm", "", -                  [(set SPR:$Sd, (fsqrt SPR:$Sm))]>, +                  [(set SPR:$Sd, (any_fsqrt SPR:$Sm))]>,               Sched<[WriteFPSQRT32]>;  let mayRaiseFPException = 1, Uses = [FPSCR_RM] in  def VSQRTH : AHuI<0b11101, 0b11, 0b0001, 0b11, 0,                    (outs HPR:$Sd), (ins HPR:$Sm),                    IIC_fpSQRT16, "vsqrt", ".f16\t$Sd, $Sm", -                  [(set (f16 HPR:$Sd), (fsqrt (f16 HPR:$Sm)))]>; +                  [(set (f16 HPR:$Sd), (any_fsqrt (f16 HPR:$Sm)))]>;  let hasSideEffects = 0 in {  let isMoveReg = 1 in { @@ -1509,10 +1510,10 @@ def VSITOD : AVConv1IDs_Encode<0b11101, 0b11, 0b1000, 0b1011,  }  let Predicates=[HasVFP2, HasDPVFP] in { -  def : VFPPat<(f64 (sint_to_fp GPR:$a)), +  def : VFPPat<(f64 (any_sint_to_fp GPR:$a)),                 (VSITOD (COPY_TO_REGCLASS GPR:$a, SPR))>; -  def : VFPPat<(f64 (sint_to_fp (i32 (alignedload32 addrmode5:$a)))), +  def : VFPPat<(f64 (any_sint_to_fp (i32 (alignedload32 addrmode5:$a)))),                 (VSITOD (VLDRS addrmode5:$a))>;  } @@ -1529,10 +1530,10 @@ def VSITOS : AVConv1InSs_Encode<0b11101, 0b11, 0b1000, 0b1010,    let D = VFPNeonA8Domain;  } -def : VFPNoNEONPat<(f32 (sint_to_fp GPR:$a)), +def : VFPNoNEONPat<(f32 (any_sint_to_fp GPR:$a)),                     (VSITOS (COPY_TO_REGCLASS GPR:$a, SPR))>; -def : VFPNoNEONPat<(f32 (sint_to_fp (i32 (alignedload32 addrmode5:$a)))), +def : VFPNoNEONPat<(f32 (any_sint_to_fp (i32 (alignedload32 addrmode5:$a)))),                     (VSITOS (VLDRS addrmode5:$a))>;  let mayRaiseFPException = 1 in  @@ -1545,7 +1546,7 @@ def VSITOH : AVConv1IHs_Encode<0b11101, 0b11, 0b1000, 0b1001,    let isUnpredicable = 1;  } -def : VFPNoNEONPat<(f16 (sint_to_fp GPR:$a)), +def : VFPNoNEONPat<(f16 (any_sint_to_fp GPR:$a)),                     (VSITOH (COPY_TO_REGCLASS GPR:$a, SPR))>;  let mayRaiseFPException = 1 in  @@ -1558,10 +1559,10 @@ def VUITOD : AVConv1IDs_Encode<0b11101, 0b11, 0b1000, 0b1011,  }  let Predicates=[HasVFP2, HasDPVFP] in { -  def : VFPPat<(f64 (uint_to_fp GPR:$a)), +  def : 
VFPPat<(f64 (any_uint_to_fp GPR:$a)),                 (VUITOD (COPY_TO_REGCLASS GPR:$a, SPR))>; -  def : VFPPat<(f64 (uint_to_fp (i32 (alignedload32 addrmode5:$a)))), +  def : VFPPat<(f64 (any_uint_to_fp (i32 (alignedload32 addrmode5:$a)))),                 (VUITOD (VLDRS addrmode5:$a))>;  } @@ -1578,10 +1579,10 @@ def VUITOS : AVConv1InSs_Encode<0b11101, 0b11, 0b1000, 0b1010,    let D = VFPNeonA8Domain;  } -def : VFPNoNEONPat<(f32 (uint_to_fp GPR:$a)), +def : VFPNoNEONPat<(f32 (any_uint_to_fp GPR:$a)),                     (VUITOS (COPY_TO_REGCLASS GPR:$a, SPR))>; -def : VFPNoNEONPat<(f32 (uint_to_fp (i32 (alignedload32 addrmode5:$a)))), +def : VFPNoNEONPat<(f32 (any_uint_to_fp (i32 (alignedload32 addrmode5:$a)))),                     (VUITOS (VLDRS addrmode5:$a))>;  let mayRaiseFPException = 1 in  @@ -1594,7 +1595,7 @@ def VUITOH : AVConv1IHs_Encode<0b11101, 0b11, 0b1000, 0b1001,    let isUnpredicable = 1;  } -def : VFPNoNEONPat<(f16 (uint_to_fp GPR:$a)), +def : VFPNoNEONPat<(f16 (any_uint_to_fp GPR:$a)),                     (VUITOH (COPY_TO_REGCLASS GPR:$a, SPR))>;  // FP -> Int: @@ -1669,12 +1670,12 @@ def VTOSIZD : AVConv1IsD_Encode<0b11101, 0b11, 0b1101, 0b1011,  }  let Predicates=[HasVFP2, HasDPVFP] in { -  def : VFPPat<(i32 (fp_to_sint (f64 DPR:$a))), +  def : VFPPat<(i32 (any_fp_to_sint (f64 DPR:$a))),                 (COPY_TO_REGCLASS (VTOSIZD DPR:$a), GPR)>;    def : VFPPat<(i32 (fp_to_sint_sat (f64 DPR:$a), i32)),                 (COPY_TO_REGCLASS (VTOSIZD DPR:$a), GPR)>; -  def : VFPPat<(alignedstore32 (i32 (fp_to_sint (f64 DPR:$a))), addrmode5:$ptr), +  def : VFPPat<(alignedstore32 (i32 (any_fp_to_sint (f64 DPR:$a))), addrmode5:$ptr),                 (VSTRS (VTOSIZD DPR:$a), addrmode5:$ptr)>;    def : VFPPat<(alignedstore32 (i32 (fp_to_sint_sat (f64 DPR:$a), i32)), addrmode5:$ptr),                 (VSTRS (VTOSIZD DPR:$a), addrmode5:$ptr)>; @@ -1693,12 +1694,12 @@ def VTOSIZS : AVConv1InsS_Encode<0b11101, 0b11, 0b1101, 0b1010,    let D = VFPNeonA8Domain;  } -def : VFPNoNEONPat<(i32 (fp_to_sint SPR:$a)), +def : VFPNoNEONPat<(i32 (any_fp_to_sint SPR:$a)),                     (COPY_TO_REGCLASS (VTOSIZS SPR:$a), GPR)>;  def : VFPPat<(i32 (fp_to_sint_sat SPR:$a, i32)),               (COPY_TO_REGCLASS (VTOSIZS SPR:$a), GPR)>; -def : VFPNoNEONPat<(alignedstore32 (i32 (fp_to_sint (f32 SPR:$a))), +def : VFPNoNEONPat<(alignedstore32 (i32 (any_fp_to_sint (f32 SPR:$a))),                                     addrmode5:$ptr),                     (VSTRS (VTOSIZS SPR:$a), addrmode5:$ptr)>;  def : VFPPat<(alignedstore32 (i32 (fp_to_sint_sat (f32 SPR:$a), i32)), @@ -1715,7 +1716,7 @@ def VTOSIZH : AVConv1IsH_Encode<0b11101, 0b11, 0b1101, 0b1001,    let isUnpredicable = 1;  } -def : VFPNoNEONPat<(i32 (fp_to_sint (f16 HPR:$a))), +def : VFPNoNEONPat<(i32 (any_fp_to_sint (f16 HPR:$a))),                     (COPY_TO_REGCLASS (VTOSIZH (f16 HPR:$a)), GPR)>;  def : VFPPat<(i32 (fp_to_sint_sat (f16 HPR:$a), i32)),               (COPY_TO_REGCLASS (VTOSIZH (f16 HPR:$a)), GPR)>; @@ -1730,12 +1731,12 @@ def VTOUIZD : AVConv1IsD_Encode<0b11101, 0b11, 0b1100, 0b1011,  }  let Predicates=[HasVFP2, HasDPVFP] in { -  def : VFPPat<(i32 (fp_to_uint (f64 DPR:$a))), +  def : VFPPat<(i32 (any_fp_to_uint (f64 DPR:$a))),                 (COPY_TO_REGCLASS (VTOUIZD DPR:$a), GPR)>;    def : VFPPat<(i32 (fp_to_uint_sat (f64 DPR:$a), i32)),                 (COPY_TO_REGCLASS (VTOUIZD DPR:$a), GPR)>; -  def : VFPPat<(alignedstore32 (i32 (fp_to_uint (f64 DPR:$a))), addrmode5:$ptr), +  def : VFPPat<(alignedstore32 (i32 
(any_fp_to_uint (f64 DPR:$a))), addrmode5:$ptr),                 (VSTRS (VTOUIZD DPR:$a), addrmode5:$ptr)>;    def : VFPPat<(alignedstore32 (i32 (fp_to_uint_sat (f64 DPR:$a), i32)), addrmode5:$ptr),                 (VSTRS (VTOUIZD DPR:$a), addrmode5:$ptr)>; @@ -1754,12 +1755,12 @@ def VTOUIZS : AVConv1InsS_Encode<0b11101, 0b11, 0b1100, 0b1010,    let D = VFPNeonA8Domain;  } -def : VFPNoNEONPat<(i32 (fp_to_uint SPR:$a)), +def : VFPNoNEONPat<(i32 (any_fp_to_uint SPR:$a)),                     (COPY_TO_REGCLASS (VTOUIZS SPR:$a), GPR)>;  def : VFPPat<(i32 (fp_to_uint_sat SPR:$a, i32)),               (COPY_TO_REGCLASS (VTOUIZS SPR:$a), GPR)>; -def : VFPNoNEONPat<(alignedstore32 (i32 (fp_to_uint (f32 SPR:$a))), +def : VFPNoNEONPat<(alignedstore32 (i32 (any_fp_to_uint (f32 SPR:$a))),                                     addrmode5:$ptr),                    (VSTRS (VTOUIZS SPR:$a), addrmode5:$ptr)>;  def : VFPPat<(alignedstore32 (i32 (fp_to_uint_sat (f32 SPR:$a), i32)), @@ -1776,7 +1777,7 @@ def VTOUIZH : AVConv1IsH_Encode<0b11101, 0b11, 0b1100, 0b1001,    let isUnpredicable = 1;  } -def : VFPNoNEONPat<(i32 (fp_to_uint (f16 HPR:$a))), +def : VFPNoNEONPat<(i32 (any_fp_to_uint (f16 HPR:$a))),                     (COPY_TO_REGCLASS (VTOUIZH (f16 HPR:$a)), GPR)>;  def : VFPPat<(i32 (fp_to_uint_sat (f16 HPR:$a), i32)),               (COPY_TO_REGCLASS (VTOUIZH (f16 HPR:$a)), GPR)>; @@ -2320,13 +2321,13 @@ def : Pat<(fadd_mlx HPR:$dstin, (fmul_su (f16 HPR:$a), HPR:$b)),  // Match @llvm.fma.* intrinsics  // (fma x, y, z) -> (vfms z, x, y) -def : Pat<(f64 (fma DPR:$Dn, DPR:$Dm, DPR:$Ddin)), +def : Pat<(f64 (any_fma DPR:$Dn, DPR:$Dm, DPR:$Ddin)),            (VFMAD DPR:$Ddin, DPR:$Dn, DPR:$Dm)>,        Requires<[HasVFP4,HasDPVFP]>; -def : Pat<(f32 (fma SPR:$Sn, SPR:$Sm, SPR:$Sdin)), +def : Pat<(f32 (any_fma SPR:$Sn, SPR:$Sm, SPR:$Sdin)),            (VFMAS SPR:$Sdin, SPR:$Sn, SPR:$Sm)>,        Requires<[HasVFP4]>; -def : Pat<(f16 (fma HPR:$Sn, HPR:$Sm, (f16 HPR:$Sdin))), +def : Pat<(f16 (any_fma HPR:$Sn, HPR:$Sm, (f16 HPR:$Sdin))),            (VFMAH (f16 HPR:$Sdin), (f16 HPR:$Sn), (f16 HPR:$Sm))>,        Requires<[HasFullFP16]>; @@ -2375,13 +2376,13 @@ def : Pat<(fsub_mlx HPR:$dstin, (fmul_su (f16 HPR:$a), HPR:$b)),  // Match @llvm.fma.* intrinsics  // (fma (fneg x), y, z) -> (vfms z, x, y) -def : Pat<(f64 (fma (fneg DPR:$Dn), DPR:$Dm, DPR:$Ddin)), +def : Pat<(f64 (any_fma (fneg DPR:$Dn), DPR:$Dm, DPR:$Ddin)),            (VFMSD DPR:$Ddin, DPR:$Dn, DPR:$Dm)>,        Requires<[HasVFP4,HasDPVFP]>; -def : Pat<(f32 (fma (fneg SPR:$Sn), SPR:$Sm, SPR:$Sdin)), +def : Pat<(f32 (any_fma (fneg SPR:$Sn), SPR:$Sm, SPR:$Sdin)),            (VFMSS SPR:$Sdin, SPR:$Sn, SPR:$Sm)>,        Requires<[HasVFP4]>; -def : Pat<(f16 (fma (fneg (f16 HPR:$Sn)), (f16 HPR:$Sm), (f16 HPR:$Sdin))), +def : Pat<(f16 (any_fma (fneg (f16 HPR:$Sn)), (f16 HPR:$Sm), (f16 HPR:$Sdin))),            (VFMSH (f16 HPR:$Sdin), (f16 HPR:$Sn), (f16 HPR:$Sm))>,        Requires<[HasFullFP16]>; @@ -2427,23 +2428,23 @@ def : Pat<(fsub_mlx (fneg (fmul_su SPR:$a, SPR:$b)), SPR:$dstin),  // Match @llvm.fma.* intrinsics  // (fneg (fma x, y, z)) -> (vfnma z, x, y) -def : Pat<(fneg (fma (f64 DPR:$Dn), (f64 DPR:$Dm), (f64 DPR:$Ddin))), +def : Pat<(fneg (any_fma (f64 DPR:$Dn), (f64 DPR:$Dm), (f64 DPR:$Ddin))),            (VFNMAD DPR:$Ddin, DPR:$Dn, DPR:$Dm)>,        Requires<[HasVFP4,HasDPVFP]>; -def : Pat<(fneg (fma (f32 SPR:$Sn), (f32 SPR:$Sm), (f32 SPR:$Sdin))), +def : Pat<(fneg (any_fma (f32 SPR:$Sn), (f32 SPR:$Sm), (f32 SPR:$Sdin))),            (VFNMAS SPR:$Sdin, 
SPR:$Sn, SPR:$Sm)>,        Requires<[HasVFP4]>; -def : Pat<(fneg (fma (f16 HPR:$Sn), (f16 HPR:$Sm), (f16 (f16 HPR:$Sdin)))), +def : Pat<(fneg (any_fma (f16 HPR:$Sn), (f16 HPR:$Sm), (f16 (f16 HPR:$Sdin)))),            (VFNMAH (f16 HPR:$Sdin), (f16 HPR:$Sn), (f16 HPR:$Sm))>,        Requires<[HasFullFP16]>;  // (fma (fneg x), y, (fneg z)) -> (vfnma z, x, y) -def : Pat<(f64 (fma (fneg DPR:$Dn), DPR:$Dm, (fneg DPR:$Ddin))), +def : Pat<(f64 (any_fma (fneg DPR:$Dn), DPR:$Dm, (fneg DPR:$Ddin))),            (VFNMAD DPR:$Ddin, DPR:$Dn, DPR:$Dm)>,        Requires<[HasVFP4,HasDPVFP]>; -def : Pat<(f32 (fma (fneg SPR:$Sn), SPR:$Sm, (fneg SPR:$Sdin))), +def : Pat<(f32 (any_fma (fneg SPR:$Sn), SPR:$Sm, (fneg SPR:$Sdin))),            (VFNMAS SPR:$Sdin, SPR:$Sn, SPR:$Sm)>,        Requires<[HasVFP4]>; -def : Pat<(f16 (fma (fneg (f16 HPR:$Sn)), (f16 HPR:$Sm), (fneg (f16 HPR:$Sdin)))), +def : Pat<(f16 (any_fma (fneg (f16 HPR:$Sn)), (f16 HPR:$Sm), (fneg (f16 HPR:$Sdin)))),            (VFNMAH (f16 HPR:$Sdin), (f16 HPR:$Sn), (f16 HPR:$Sm))>,        Requires<[HasFullFP16]>; @@ -2488,23 +2489,23 @@ def : Pat<(fsub_mlx (fmul_su SPR:$a, SPR:$b), SPR:$dstin),  // Match @llvm.fma.* intrinsics  // (fma x, y, (fneg z)) -> (vfnms z, x, y)) -def : Pat<(f64 (fma DPR:$Dn, DPR:$Dm, (fneg DPR:$Ddin))), +def : Pat<(f64 (any_fma DPR:$Dn, DPR:$Dm, (fneg DPR:$Ddin))),            (VFNMSD DPR:$Ddin, DPR:$Dn, DPR:$Dm)>,        Requires<[HasVFP4,HasDPVFP]>; -def : Pat<(f32 (fma SPR:$Sn, SPR:$Sm, (fneg SPR:$Sdin))), +def : Pat<(f32 (any_fma SPR:$Sn, SPR:$Sm, (fneg SPR:$Sdin))),            (VFNMSS SPR:$Sdin, SPR:$Sn, SPR:$Sm)>,        Requires<[HasVFP4]>; -def : Pat<(f16 (fma (f16 HPR:$Sn), (f16 HPR:$Sm), (fneg (f16 HPR:$Sdin)))), +def : Pat<(f16 (any_fma (f16 HPR:$Sn), (f16 HPR:$Sm), (fneg (f16 HPR:$Sdin)))),            (VFNMSH (f16 HPR:$Sdin), (f16 HPR:$Sn), (f16 HPR:$Sm))>,        Requires<[HasFullFP16]>;  // (fneg (fma (fneg x), y, z)) -> (vfnms z, x, y) -def : Pat<(fneg (f64 (fma (fneg DPR:$Dn), DPR:$Dm, DPR:$Ddin))), +def : Pat<(fneg (f64 (any_fma (fneg DPR:$Dn), DPR:$Dm, DPR:$Ddin))),            (VFNMSD DPR:$Ddin, DPR:$Dn, DPR:$Dm)>,        Requires<[HasVFP4,HasDPVFP]>; -def : Pat<(fneg (f32 (fma (fneg SPR:$Sn), SPR:$Sm, SPR:$Sdin))), +def : Pat<(fneg (f32 (any_fma (fneg SPR:$Sn), SPR:$Sm, SPR:$Sdin))),            (VFNMSS SPR:$Sdin, SPR:$Sn, SPR:$Sm)>,        Requires<[HasVFP4]>; -def : Pat<(fneg (f16 (fma (fneg (f16 HPR:$Sn)), (f16 HPR:$Sm), (f16 HPR:$Sdin)))), +def : Pat<(fneg (f16 (any_fma (fneg (f16 HPR:$Sn)), (f16 HPR:$Sm), (f16 HPR:$Sdin)))),            (VFNMSH (f16 HPR:$Sdin), (f16 HPR:$Sn), (f16 HPR:$Sm))>,        Requires<[HasFullFP16]>; diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp index 1e4797b..cf8b833 100644 --- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp +++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp @@ -36,9 +36,10 @@ using namespace llvm;  using namespace llvm::dxil;  namespace { -/// A simple Wrapper DiagnosticInfo that generates Module-level diagnostic -/// for TranslateMetadata pass -class DiagnosticInfoTranslateMD : public DiagnosticInfo { + +/// A simple wrapper of DiagnosticInfo that generates module-level diagnostic +/// for the DXILValidateMetadata pass +class DiagnosticInfoValidateMD : public DiagnosticInfo {  private:    const Twine &Msg;    const Module &Mod; @@ -47,9 +48,9 @@ public:    /// \p M is the module for which the diagnostic is being emitted. \p Msg is    /// the message to show. 
Note that this class does not copy this message, so    /// this reference must be valid for the whole life time of the diagnostic. -  DiagnosticInfoTranslateMD(const Module &M, -                            const Twine &Msg LLVM_LIFETIME_BOUND, -                            DiagnosticSeverity Severity = DS_Error) +  DiagnosticInfoValidateMD(const Module &M, +                           const Twine &Msg LLVM_LIFETIME_BOUND, +                           DiagnosticSeverity Severity = DS_Error)        : DiagnosticInfo(DK_Unsupported, Severity), Msg(Msg), Mod(M) {}    void print(DiagnosticPrinter &DP) const override { @@ -57,6 +58,16 @@ public:    }  }; +static void reportError(Module &M, Twine Message, +                        DiagnosticSeverity Severity = DS_Error) { +  M.getContext().diagnose(DiagnosticInfoValidateMD(M, Message, Severity)); +} + +static void reportLoopError(Module &M, Twine Message, +                            DiagnosticSeverity Severity = DS_Error) { +  reportError(M, Twine("Invalid \"llvm.loop\" metadata: ") + Message, Severity); +} +  enum class EntryPropsTag {    ShaderFlags = 0,    GSState, @@ -314,25 +325,122 @@ static void translateBranchMetadata(Module &M, Instruction *BBTerminatorInst) {    BBTerminatorInst->setMetadata("hlsl.controlflow.hint", nullptr);  } -static std::array<unsigned, 6> getCompatibleInstructionMDs(llvm::Module &M) { +// Determines if the metadata node will be compatible with DXIL's loop metadata +// representation. +// +// Reports an error for compatible metadata that is ill-formed. +static bool isLoopMDCompatible(Module &M, Metadata *MD) { +  // DXIL only accepts the following loop hints: +  std::array<StringLiteral, 3> ValidHintNames = {"llvm.loop.unroll.count", +                                                 "llvm.loop.unroll.disable", +                                                 "llvm.loop.unroll.full"}; + +  MDNode *HintMD = dyn_cast<MDNode>(MD); +  if (!HintMD || HintMD->getNumOperands() == 0) +    return false; + +  auto *HintStr = dyn_cast<MDString>(HintMD->getOperand(0)); +  if (!HintStr) +    return false; + +  if (!llvm::is_contained(ValidHintNames, HintStr->getString())) +    return false; + +  auto ValidCountNode = [](MDNode *CountMD) -> bool { +    if (CountMD->getNumOperands() == 2) +      if (auto *Count = dyn_cast<ConstantAsMetadata>(CountMD->getOperand(1))) +        if (isa<ConstantInt>(Count->getValue())) +          return true; +    return false; +  }; + +  if (HintStr->getString() == "llvm.loop.unroll.count") { +    if (!ValidCountNode(HintMD)) { +      reportLoopError(M, "\"llvm.loop.unroll.count\" must have 2 operands and " +                         "the second must be a constant integer"); +      return false; +    } +  } else if (HintMD->getNumOperands() != 1) { +    reportLoopError( +        M, "\"llvm.loop.unroll.disable\" and \"llvm.loop.unroll.full\" " +           "must be provided as a single operand"); +    return false; +  } + +  return true; +} + +static void translateLoopMetadata(Module &M, Instruction *I, MDNode *BaseMD) { +  // A distinct node has the self-referential form: !0 = !{ !0, ... 
} +  auto IsDistinctNode = [](MDNode *Node) -> bool { +    return Node && Node->getNumOperands() != 0 && Node == Node->getOperand(0); +  }; + +  // Set metadata to null to remove empty/ill-formed metadata from the instruction +  if (BaseMD->getNumOperands() == 0 || !IsDistinctNode(BaseMD)) +    return I->setMetadata("llvm.loop", nullptr); + +  // It is valid to have a chain of self-referential loop metadata nodes, as +  // below. We will collapse these into just one when we reconstruct the +  // metadata. +  // +  // Eg: +  // !0 = !{!0, !1} +  // !1 = !{!1, !2} +  // !2 = !{!"llvm.loop.unroll.disable"} +  // +  // So, traverse down a potential self-referential chain +  while (1 < BaseMD->getNumOperands() && +         IsDistinctNode(dyn_cast<MDNode>(BaseMD->getOperand(1)))) +    BaseMD = dyn_cast<MDNode>(BaseMD->getOperand(1)); + +  // To reconstruct a distinct node we create a temporary node that we will +  // then update to create a self-reference. +  llvm::TempMDTuple TempNode = llvm::MDNode::getTemporary(M.getContext(), {}); +  SmallVector<Metadata *> CompatibleOperands = {TempNode.get()}; + +  // Iterate and reconstruct the metadata nodes that contain any hints, +  // stripping any unrecognized metadata. +  ArrayRef<MDOperand> Operands = BaseMD->operands(); +  for (auto &Op : Operands.drop_front()) +    if (isLoopMDCompatible(M, Op.get())) +      CompatibleOperands.push_back(Op.get()); + +  if (2 < CompatibleOperands.size()) +    reportLoopError(M, "Provided conflicting hints"); + +  MDNode *CompatibleLoopMD = MDNode::get(M.getContext(), CompatibleOperands); +  TempNode->replaceAllUsesWith(CompatibleLoopMD); + +  I->setMetadata("llvm.loop", CompatibleLoopMD); +} + +using InstructionMDList = std::array<unsigned, 7>; + +static InstructionMDList getCompatibleInstructionMDs(llvm::Module &M) {    return {        M.getMDKindID("dx.nonuniform"),    M.getMDKindID("dx.controlflow.hints"),        M.getMDKindID("dx.precise"),       llvm::LLVMContext::MD_range, -      llvm::LLVMContext::MD_alias_scope, llvm::LLVMContext::MD_noalias}; +      llvm::LLVMContext::MD_alias_scope, llvm::LLVMContext::MD_noalias, +      M.getMDKindID("llvm.loop")};  }  static void translateInstructionMetadata(Module &M) {    // construct allowlist of valid metadata node kinds -  std::array<unsigned, 6> DXILCompatibleMDs = getCompatibleInstructionMDs(M); +  InstructionMDList DXILCompatibleMDs = getCompatibleInstructionMDs(M); +  unsigned char MDLoopKind = M.getContext().getMDKindID("llvm.loop");    for (Function &F : M) {      for (BasicBlock &BB : F) {        // This needs to be done first so that "hlsl.controlflow.hints" isn't -      // removed in the whitelist below +      // removed in the allow-list below        if (auto *I = BB.getTerminator())          translateBranchMetadata(M, I);        for (auto &I : make_early_inc_range(BB)) { +        if (isa<BranchInst>(I)) +          if (MDNode *LoopMD = I.getMetadata(MDLoopKind)) +            translateLoopMetadata(M, &I, LoopMD);          I.dropUnknownNonDebugMetadata(DXILCompatibleMDs);        }      } @@ -364,6 +472,16 @@ static void cleanModuleFlags(Module &M) {      M.addModuleFlag(Flag.Behavior, Flag.Key->getString(), Flag.Val);  } +using GlobalMDList = std::array<StringLiteral, 7>; + +// The following are compatible with DXIL but not emitted by clang; they can +// be added when applicable: +// dx.typeAnnotations, dx.viewIDState, dx.dxrPayloadAnnotations +static GlobalMDList CompatibleNamedModuleMDs = { +    "llvm.ident",     "llvm.module.flags", "dx.resources",   
"dx.valver", +    "dx.shaderModel", "dx.version",        "dx.entryPoints", +}; +  static void translateGlobalMetadata(Module &M, DXILResourceMap &DRM,                                      DXILResourceTypeMap &DRTM,                                      const ModuleShaderFlags &ShaderFlags, @@ -389,31 +507,23 @@ static void translateGlobalMetadata(Module &M, DXILResourceMap &DRM,      uint64_t CombinedMask = ShaderFlags.getCombinedFlags();      EntryFnMDNodes.emplace_back(          emitTopLevelLibraryNode(M, ResourceMD, CombinedMask)); -  } else if (MMDI.EntryPropertyVec.size() > 1) { -    M.getContext().diagnose(DiagnosticInfoTranslateMD( -        M, "Non-library shader: One and only one entry expected")); -  } +  } else if (1 < MMDI.EntryPropertyVec.size()) +    reportError(M, "Non-library shader: One and only one entry expected");    for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec) { -    const ComputedShaderFlags &EntrySFMask = -        ShaderFlags.getFunctionFlags(EntryProp.Entry); - -    // If ShaderProfile is Library, mask is already consolidated in the -    // top-level library node. Hence it is not emitted.      uint64_t EntryShaderFlags = 0;      if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) { -      EntryShaderFlags = EntrySFMask; -      if (EntryProp.ShaderStage != MMDI.ShaderProfile) { -        M.getContext().diagnose(DiagnosticInfoTranslateMD( -            M, -            "Shader stage '" + -                Twine(getShortShaderStage(EntryProp.ShaderStage) + -                      "' for entry '" + Twine(EntryProp.Entry->getName()) + -                      "' different from specified target profile '" + -                      Twine(Triple::getEnvironmentTypeName(MMDI.ShaderProfile) + -                            "'")))); -      } +      EntryShaderFlags = ShaderFlags.getFunctionFlags(EntryProp.Entry); +      if (EntryProp.ShaderStage != MMDI.ShaderProfile) +        reportError( +            M, "Shader stage '" + +                   Twine(getShortShaderStage(EntryProp.ShaderStage)) + +                   "' for entry '" + Twine(EntryProp.Entry->getName()) + +                   "' different from specified target profile '" + +                   Twine(Triple::getEnvironmentTypeName(MMDI.ShaderProfile) + +                         "'"));      } +      EntryFnMDNodes.emplace_back(emitEntryMD(EntryProp, Signatures, ResourceMD,                                              EntryShaderFlags,                                              MMDI.ShaderProfile)); @@ -426,19 +536,17 @@ static void translateGlobalMetadata(Module &M, DXILResourceMap &DRM,    cleanModuleFlags(M); -  // dx.rootsignatures will have been parsed from its metadata form as its -  // binary form as part of the RootSignatureAnalysisWrapper, so safely -  // remove it as it is not recognized in DXIL -  if (NamedMDNode *RootSignature = M.getNamedMetadata("dx.rootsignatures")) -    RootSignature->eraseFromParent(); +  // Finally, strip all module metadata that is not explicitly specified in the +  // allow-list +  SmallVector<NamedMDNode *> ToStrip; -  // llvm.errno.tbaa was recently added but is not supported in LLVM 3.7 and -  // causes all tests using the DXIL Validator to fail. 
-  // -  // This is a temporary fix and should be replaced with a allowlist once -  // we have determined all metadata that the DXIL Validator allows -  if (NamedMDNode *ErrNo = M.getNamedMetadata("llvm.errno.tbaa")) -    ErrNo->eraseFromParent(); +  for (NamedMDNode &NamedMD : M.named_metadata()) +    if (!NamedMD.getName().starts_with("llvm.dbg.") && +        !llvm::is_contained(CompatibleNamedModuleMDs, NamedMD.getName())) +      ToStrip.push_back(&NamedMD); + +  for (NamedMDNode *NamedMD : ToStrip) +    NamedMD->eraseFromParent();  }  PreservedAnalyses DXILTranslateMetadata::run(Module &M, @@ -454,45 +562,34 @@ PreservedAnalyses DXILTranslateMetadata::run(Module &M,    return PreservedAnalyses::all();  } -namespace { -class DXILTranslateMetadataLegacy : public ModulePass { -public: -  static char ID; // Pass identification, replacement for typeid -  explicit DXILTranslateMetadataLegacy() : ModulePass(ID) {} - -  StringRef getPassName() const override { return "DXIL Translate Metadata"; } - -  void getAnalysisUsage(AnalysisUsage &AU) const override { -    AU.addRequired<DXILResourceTypeWrapperPass>(); -    AU.addRequired<DXILResourceWrapperPass>(); -    AU.addRequired<ShaderFlagsAnalysisWrapper>(); -    AU.addRequired<DXILMetadataAnalysisWrapperPass>(); -    AU.addRequired<RootSignatureAnalysisWrapper>(); - -    AU.addPreserved<DXILMetadataAnalysisWrapperPass>(); -    AU.addPreserved<DXILResourceBindingWrapperPass>(); -    AU.addPreserved<DXILResourceWrapperPass>(); -    AU.addPreserved<RootSignatureAnalysisWrapper>(); -    AU.addPreserved<ShaderFlagsAnalysisWrapper>(); -  } +void DXILTranslateMetadataLegacy::getAnalysisUsage(AnalysisUsage &AU) const { +  AU.addRequired<DXILResourceTypeWrapperPass>(); +  AU.addRequired<DXILResourceWrapperPass>(); +  AU.addRequired<ShaderFlagsAnalysisWrapper>(); +  AU.addRequired<DXILMetadataAnalysisWrapperPass>(); +  AU.addRequired<RootSignatureAnalysisWrapper>(); + +  AU.addPreserved<DXILMetadataAnalysisWrapperPass>(); +  AU.addPreserved<DXILResourceBindingWrapperPass>(); +  AU.addPreserved<DXILResourceWrapperPass>(); +  AU.addPreserved<RootSignatureAnalysisWrapper>(); +  AU.addPreserved<ShaderFlagsAnalysisWrapper>(); +} -  bool runOnModule(Module &M) override { -    DXILResourceMap &DRM = -        getAnalysis<DXILResourceWrapperPass>().getResourceMap(); -    DXILResourceTypeMap &DRTM = -        getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap(); -    const ModuleShaderFlags &ShaderFlags = -        getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags(); -    dxil::ModuleMetadataInfo MMDI = -        getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata(); - -    translateGlobalMetadata(M, DRM, DRTM, ShaderFlags, MMDI); -    translateInstructionMetadata(M); -    return true; -  } -}; +bool DXILTranslateMetadataLegacy::runOnModule(Module &M) { +  DXILResourceMap &DRM = +      getAnalysis<DXILResourceWrapperPass>().getResourceMap(); +  DXILResourceTypeMap &DRTM = +      getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap(); +  const ModuleShaderFlags &ShaderFlags = +      getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags(); +  dxil::ModuleMetadataInfo MMDI = +      getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata(); -} // namespace +  translateGlobalMetadata(M, DRM, DRTM, ShaderFlags, MMDI); +  translateInstructionMetadata(M); +  return true; +}  char DXILTranslateMetadataLegacy::ID = 0; diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.h b/llvm/lib/Target/DirectX/DXILTranslateMetadata.h 
index 4c1ffac..cfb8aaa 100644 --- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.h +++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.h @@ -10,6 +10,7 @@  #define LLVM_TARGET_DIRECTX_DXILTRANSLATEMETADATA_H  #include "llvm/IR/PassManager.h" +#include "llvm/Pass.h"  namespace llvm { @@ -20,6 +21,22 @@ public:    PreservedAnalyses run(Module &M, ModuleAnalysisManager &);  }; +/// Wrapper pass for the legacy pass manager. +/// +/// This is required because the passes that will depend on this are codegen +/// passes which run through the legacy pass manager. +class DXILTranslateMetadataLegacy : public ModulePass { +public: +  static char ID; // Pass identification, replacement for typeid +  explicit DXILTranslateMetadataLegacy() : ModulePass(ID) {} + +  StringRef getPassName() const override { return "DXIL Translate Metadata"; } + +  void getAnalysisUsage(AnalysisUsage &AU) const override; + +  bool runOnModule(Module &M) override; +}; +  } // namespace llvm  #endif // LLVM_TARGET_DIRECTX_DXILTRANSLATEMETADATA_H diff --git a/llvm/lib/Target/Hexagon/HexagonCopyHoisting.cpp b/llvm/lib/Target/Hexagon/HexagonCopyHoisting.cpp index 3b810d0..79863e1 100644 --- a/llvm/lib/Target/Hexagon/HexagonCopyHoisting.cpp +++ b/llvm/lib/Target/Hexagon/HexagonCopyHoisting.cpp @@ -34,7 +34,7 @@ class HexagonCopyHoisting : public MachineFunctionPass {  public:    static char ID; -  HexagonCopyHoisting() : MachineFunctionPass(ID), MFN(nullptr), MRI(nullptr) {} +  HexagonCopyHoisting() : MachineFunctionPass(ID) {}    StringRef getPassName() const override { return "Hexagon Copy Hoisting"; } @@ -56,8 +56,8 @@ public:    void moveCopyInstr(MachineBasicBlock *DestBB,                       std::pair<Register, Register> Key, MachineInstr *MI); -  MachineFunction *MFN; -  MachineRegisterInfo *MRI; +  MachineFunction *MFN = nullptr; +  MachineRegisterInfo *MRI = nullptr;    std::vector<DenseMap<std::pair<Register, Register>, MachineInstr *>>        CopyMIList;  }; diff --git a/llvm/lib/Target/Hexagon/HexagonGenMemAbsolute.cpp b/llvm/lib/Target/Hexagon/HexagonGenMemAbsolute.cpp index 93418f7..a10c937 100644 --- a/llvm/lib/Target/Hexagon/HexagonGenMemAbsolute.cpp +++ b/llvm/lib/Target/Hexagon/HexagonGenMemAbsolute.cpp @@ -34,13 +34,13 @@ STATISTIC(HexagonNumStoreAbsConversions,  namespace {  class HexagonGenMemAbsolute : public MachineFunctionPass { -  const HexagonInstrInfo *TII; -  MachineRegisterInfo *MRI; -  const TargetRegisterInfo *TRI; +  const HexagonInstrInfo *TII = nullptr; +  MachineRegisterInfo *MRI = nullptr; +  const TargetRegisterInfo *TRI = nullptr;  public:    static char ID; -  HexagonGenMemAbsolute() : MachineFunctionPass(ID), TII(0), MRI(0), TRI(0) {} +  HexagonGenMemAbsolute() : MachineFunctionPass(ID) {}    StringRef getPassName() const override {      return "Hexagon Generate Load/Store Set Absolute Address Instruction"; diff --git a/llvm/lib/Target/Hexagon/HexagonPatternsHVX.td b/llvm/lib/Target/Hexagon/HexagonPatternsHVX.td index 1637b91..d19920c 100644 --- a/llvm/lib/Target/Hexagon/HexagonPatternsHVX.td +++ b/llvm/lib/Target/Hexagon/HexagonPatternsHVX.td @@ -612,6 +612,9 @@ let Predicates = [UseHVX] in {             (V6_vandvrt HvxVR:$Vs, (ToI32 0x01010101))>;    def: Pat<(VecQ32 (trunc HVI32:$Vs)),             (V6_vandvrt HvxVR:$Vs, (ToI32 0x01010101))>; +  def: Pat<(VecQ16 (trunc HWI32:$Vss)), +           (Combineq(VecQ32(V6_vandvrt (HiVec $Vss), (ToI32 0x01010101))), +           (VecQ32 (V6_vandvrt (LoVec $Vss), (ToI32 0x01010101))))>;  }  let Predicates = [UseHVX] in { diff --git 
a/llvm/lib/Target/Hexagon/HexagonSubtarget.cpp b/llvm/lib/Target/Hexagon/HexagonSubtarget.cpp index b9cdd6a..ce2de75 100644 --- a/llvm/lib/Target/Hexagon/HexagonSubtarget.cpp +++ b/llvm/lib/Target/Hexagon/HexagonSubtarget.cpp @@ -544,7 +544,7 @@ int HexagonSubtarget::updateLatency(MachineInstr &SrcInst,    if (!hasV60Ops())      return Latency; -  auto &QII = static_cast<const HexagonInstrInfo &>(*getInstrInfo()); +  const HexagonInstrInfo &QII = *getInstrInfo();    // BSB scheduling.    if (QII.isHVXVec(SrcInst) || useBSBScheduling())      Latency = (Latency + 1) >> 1; diff --git a/llvm/lib/Target/Hexagon/HexagonTfrCleanup.cpp b/llvm/lib/Target/Hexagon/HexagonTfrCleanup.cpp index 71bdfc66..5a85f34 100644 --- a/llvm/lib/Target/Hexagon/HexagonTfrCleanup.cpp +++ b/llvm/lib/Target/Hexagon/HexagonTfrCleanup.cpp @@ -43,7 +43,7 @@ namespace {  class HexagonTfrCleanup : public MachineFunctionPass {  public:    static char ID; -  HexagonTfrCleanup() : MachineFunctionPass(ID), HII(0), TRI(0) {} +  HexagonTfrCleanup() : MachineFunctionPass(ID) {}    StringRef getPassName() const override { return "Hexagon TFR Cleanup"; }    void getAnalysisUsage(AnalysisUsage &AU) const override {      AU.setPreservesAll(); @@ -52,8 +52,8 @@ public:    bool runOnMachineFunction(MachineFunction &MF) override;  private: -  const HexagonInstrInfo *HII; -  const TargetRegisterInfo *TRI; +  const HexagonInstrInfo *HII = nullptr; +  const TargetRegisterInfo *TRI = nullptr;    typedef DenseMap<unsigned, uint64_t> ImmediateMap; diff --git a/llvm/lib/Target/LoongArch/LoongArchFloat32InstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchFloat32InstrInfo.td index 690dd73..e86b21c 100644 --- a/llvm/lib/Target/LoongArch/LoongArchFloat32InstrInfo.td +++ b/llvm/lib/Target/LoongArch/LoongArchFloat32InstrInfo.td @@ -365,6 +365,7 @@ def : Pat<(f32 (uint_to_fp (i64 (sexti32 (i64 GPR:$src))))),  // FP Rounding  let Predicates = [HasBasicF, IsLA64] in {  def : PatFpr<frint, FRINT_S, FPR32>; +def : PatFpr<flog2, FLOGB_S, FPR32>;  } // Predicates = [HasBasicF, IsLA64]  let Predicates = [HasBasicF, IsLA32] in { diff --git a/llvm/lib/Target/LoongArch/LoongArchFloat64InstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchFloat64InstrInfo.td index daefbaa..2e88254 100644 --- a/llvm/lib/Target/LoongArch/LoongArchFloat64InstrInfo.td +++ b/llvm/lib/Target/LoongArch/LoongArchFloat64InstrInfo.td @@ -348,6 +348,7 @@ def : Pat<(bitconvert FPR64:$src), (MOVFR2GR_D FPR64:$src)>;  // FP Rounding  let Predicates = [HasBasicD, IsLA64] in {  def : PatFpr<frint, FRINT_D, FPR64>; +def : PatFpr<flog2, FLOGB_D, FPR64>;  } // Predicates = [HasBasicD, IsLA64]  /// Pseudo-instructions needed for the soft-float ABI with LA32D diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp index 80c96c6..a6de839 100644 --- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp +++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp @@ -244,8 +244,10 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,      setOperationAction(ISD::FP_TO_BF16, MVT::f32,                         Subtarget.isSoftFPABI() ? 
LibCall : Custom); -    if (Subtarget.is64Bit()) +    if (Subtarget.is64Bit()) {        setOperationAction(ISD::FRINT, MVT::f32, Legal); +      setOperationAction(ISD::FLOG2, MVT::f32, Legal); +    }      if (!Subtarget.hasBasicD()) {        setOperationAction(ISD::FP_TO_UINT, MVT::i32, Custom); @@ -291,8 +293,10 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,      setOperationAction(ISD::FP_TO_BF16, MVT::f64,                         Subtarget.isSoftFPABI() ? LibCall : Custom); -    if (Subtarget.is64Bit()) +    if (Subtarget.is64Bit()) {        setOperationAction(ISD::FRINT, MVT::f64, Legal); +      setOperationAction(ISD::FLOG2, MVT::f64, Legal); +    }    }    // Set operations for 'LSX' feature. @@ -362,6 +366,7 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,        setOperationAction(ISD::FMA, VT, Legal);        setOperationAction(ISD::FSQRT, VT, Legal);        setOperationAction(ISD::FNEG, VT, Legal); +      setOperationAction(ISD::FLOG2, VT, Legal);        setCondCodeAction({ISD::SETGE, ISD::SETGT, ISD::SETOGE, ISD::SETOGT,                           ISD::SETUGE, ISD::SETUGT},                          VT, Expand); @@ -443,6 +448,7 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,        setOperationAction(ISD::FMA, VT, Legal);        setOperationAction(ISD::FSQRT, VT, Legal);        setOperationAction(ISD::FNEG, VT, Legal); +      setOperationAction(ISD::FLOG2, VT, Legal);        setCondCodeAction({ISD::SETGE, ISD::SETGT, ISD::SETOGE, ISD::SETOGT,                           ISD::SETUGE, ISD::SETUGT},                          VT, Expand); diff --git a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td index 613dea6..ca4ee5f 100644 --- a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td +++ b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td @@ -1593,6 +1593,9 @@ def : Pat<(fma_nsz (fneg v4f64:$xj), v4f64:$xk, v4f64:$xa),  // XVFSQRT_{S/D}  defm : PatXrF<fsqrt, "XVFSQRT">; +// XVFLOGB_{S/D} +defm : PatXrF<flog2, "XVFLOGB">; +  // XVRECIP_{S/D}  def : Pat<(fdiv vsplatf32_fpimm_eq_1, v8f32:$xj),            (XVFRECIP_S v8f32:$xj)>; @@ -2024,6 +2027,24 @@ def : Pat<(v4i32(fp_to_uint v4f64:$vj)),                 (XVFTINTRZ_LU_D v4f64:$vj)),                sub_128)>; +// XVAVG_{B/H/W/D/BU/HU/WU/DU}, XVAVGR_{B/H/W/D/BU/HU/WU/DU} +defm : VAvgPat<sra, "XVAVG_B", v32i8>; +defm : VAvgPat<sra, "XVAVG_H", v16i16>; +defm : VAvgPat<sra, "XVAVG_W", v8i32>; +defm : VAvgPat<sra, "XVAVG_D", v4i64>; +defm : VAvgPat<srl, "XVAVG_BU", v32i8>; +defm : VAvgPat<srl, "XVAVG_HU", v16i16>; +defm : VAvgPat<srl, "XVAVG_WU", v8i32>; +defm : VAvgPat<srl, "XVAVG_DU", v4i64>; +defm : VAvgrPat<sra, "XVAVGR_B", v32i8>; +defm : VAvgrPat<sra, "XVAVGR_H", v16i16>; +defm : VAvgrPat<sra, "XVAVGR_W", v8i32>; +defm : VAvgrPat<sra, "XVAVGR_D", v4i64>; +defm : VAvgrPat<srl, "XVAVGR_BU", v32i8>; +defm : VAvgrPat<srl, "XVAVGR_HU", v16i16>; +defm : VAvgrPat<srl, "XVAVGR_WU", v8i32>; +defm : VAvgrPat<srl, "XVAVGR_DU", v4i64>; +  // abs  def : Pat<(abs v32i8:$xj), (XVSIGNCOV_B v32i8:$xj, v32i8:$xj)>;  def : Pat<(abs v16i16:$xj), (XVSIGNCOV_H v16i16:$xj, v16i16:$xj)>; diff --git a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td index 4619c6b..92402ba 100644 --- a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td +++ b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td @@ -1518,6 +1518,18 @@ multiclass InsertExtractPatV2<ValueType vecty, 
ValueType elemty> {    }  } +multiclass VAvgPat<SDPatternOperator OpNode, string Inst, ValueType vt> { +  def : Pat<(OpNode (vt (add vt:$vj, vt:$vk)), (vt (vsplat_imm_eq_1))), +            (!cast<LAInst>(Inst) vt:$vj, vt:$vk)>; +} + +multiclass VAvgrPat<SDPatternOperator OpNode, string Inst, ValueType vt> { +  def : Pat<(OpNode (vt (add (vt (add vt:$vj, vt:$vk)), +                             (vt (vsplat_imm_eq_1)))), +                    (vt (vsplat_imm_eq_1))), +            (!cast<LAInst>(Inst) vt:$vj, vt:$vk)>; +} +  let Predicates = [HasExtLSX] in {  // VADD_{B/H/W/D} @@ -1783,6 +1795,9 @@ def : Pat<(fma_nsz (fneg v2f64:$vj), v2f64:$vk, v2f64:$va),  // VFSQRT_{S/D}  defm : PatVrF<fsqrt, "VFSQRT">; +// VFLOGB_{S/D} +defm : PatVrF<flog2, "VFLOGB">; +  // VFRECIP_{S/D}  def : Pat<(fdiv vsplatf32_fpimm_eq_1, v4f32:$vj),            (VFRECIP_S v4f32:$vj)>; @@ -2154,6 +2169,24 @@ def : Pat<(f32 f32imm_vldi:$in),  def : Pat<(f64 f64imm_vldi:$in),            (f64 (EXTRACT_SUBREG (VLDI (to_f64imm_vldi f64imm_vldi:$in)), sub_64))>; +// VAVG_{B/H/W/D/BU/HU/WU/DU}, VAVGR_{B/H/W/D/BU/HU/WU/DU} +defm : VAvgPat<sra, "VAVG_B", v16i8>; +defm : VAvgPat<sra, "VAVG_H", v8i16>; +defm : VAvgPat<sra, "VAVG_W", v4i32>; +defm : VAvgPat<sra, "VAVG_D", v2i64>; +defm : VAvgPat<srl, "VAVG_BU", v16i8>; +defm : VAvgPat<srl, "VAVG_HU", v8i16>; +defm : VAvgPat<srl, "VAVG_WU", v4i32>; +defm : VAvgPat<srl, "VAVG_DU", v2i64>; +defm : VAvgrPat<sra, "VAVGR_B", v16i8>; +defm : VAvgrPat<sra, "VAVGR_H", v8i16>; +defm : VAvgrPat<sra, "VAVGR_W", v4i32>; +defm : VAvgrPat<sra, "VAVGR_D", v2i64>; +defm : VAvgrPat<srl, "VAVGR_BU", v16i8>; +defm : VAvgrPat<srl, "VAVGR_HU", v8i16>; +defm : VAvgrPat<srl, "VAVGR_WU", v4i32>; +defm : VAvgrPat<srl, "VAVGR_DU", v2i64>; +  // abs  def : Pat<(abs v16i8:$vj), (VSIGNCOV_B v16i8:$vj, v16i8:$vj)>;  def : Pat<(abs v8i16:$vj), (VSIGNCOV_H v8i16:$vj, v8i16:$vj)>; diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp index 7e7ee75..c667a09 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -1871,17 +1871,6 @@ bool NVPTXScopes::empty() const { return Scopes.size() == 0; }    (is_ch ? 
(CP_ASYNC_BULK_TENSOR_OPCODE(RED, dim, mode, is_s32, _CH))          \           : (CP_ASYNC_BULK_TENSOR_OPCODE(RED, dim, mode, is_s32, ))) -#define GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(dim, mode, is_mc, is_ch, is_s32)   \ -  [&]() -> auto {                                                              \ -    if (is_mc && is_ch)                                                        \ -      return CP_ASYNC_BULK_TENSOR_OPCODE(G2S, dim, mode, is_s32, _MC_CH);      \ -    if (is_ch)                                                                 \ -      return CP_ASYNC_BULK_TENSOR_OPCODE(G2S, dim, mode, is_s32, _CH);         \ -    if (is_mc)                                                                 \ -      return CP_ASYNC_BULK_TENSOR_OPCODE(G2S, dim, mode, is_s32, _MC);         \ -    return CP_ASYNC_BULK_TENSOR_OPCODE(G2S, dim, mode, is_s32, );              \ -  }() -  static unsigned GetCpAsyncBulkTensorS2GReductionOpcode(size_t Dim,                                                         bool IsShared32,                                                         bool IsCacheHint, @@ -1925,112 +1914,6 @@ static unsigned GetCpAsyncBulkTensorS2GReductionOpcode(size_t Dim,    }  } -static unsigned GetCpAsyncBulkTensorG2SOpcode(size_t Dim, bool IsShared32, -                                              bool IsMultiCast, -                                              bool IsCacheHint, bool IsIm2Col) { -  if (IsIm2Col) { -    switch (Dim) { -    case 3: -      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(3D, IM2COL, IsMultiCast, -                                                 IsCacheHint, IsShared32); -    case 4: -      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(4D, IM2COL, IsMultiCast, -                                                 IsCacheHint, IsShared32); -    case 5: -      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(5D, IM2COL, IsMultiCast, -                                                 IsCacheHint, IsShared32); -    default: -      llvm_unreachable("Invalid Dimension in im2col mode for " -                       "GetCpAsyncBulkTensorG2SOpcode."); -    } -  } else { -    switch (Dim) { -    case 1: -      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(1D, TILE, IsMultiCast, -                                                 IsCacheHint, IsShared32); -    case 2: -      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(2D, TILE, IsMultiCast, -                                                 IsCacheHint, IsShared32); -    case 3: -      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(3D, TILE, IsMultiCast, -                                                 IsCacheHint, IsShared32); -    case 4: -      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(4D, TILE, IsMultiCast, -                                                 IsCacheHint, IsShared32); -    case 5: -      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(5D, TILE, IsMultiCast, -                                                 IsCacheHint, IsShared32); -    default: -      llvm_unreachable( -          "Invalid Dimension in tile mode for GetCpAsyncBulkTensorG2SOpcode."); -    } -  } -} - -static size_t GetDimsFromIntrinsic(unsigned IID) { -  switch (IID) { -  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d: -  case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d: -    return 3; -  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d: -  case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d: -    return 4; -  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d: -  case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d: -  
  return 5; -  default: -    llvm_unreachable("Invalid im2col intrinsic in GetDimsFromIntrinsic."); -  } -} - -void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorG2SCommon(SDNode *N, -                                                         bool IsIm2Col) { -  // We have {Chain, Intrinsic-ID} followed by the actual intrisic args: -  // {dst, mbar, src, dims{d0...dN}, im2col_offsets{dims-2} -  // multicast, cache_hint, -  // multicast_flag, cache_hint_flag, cta_group_flag} -  // NumOperands = {Chain, IID} + {Actual intrinsic args} -  //             = {2}          + {8 + dims + im2col_offsets} -  size_t NumOps = N->getNumOperands(); -  size_t NumDims = IsIm2Col ? GetDimsFromIntrinsic(N->getConstantOperandVal(1)) -                            : (NumOps - 10); -  // Offsets is always 'NumDims - 2' and only for im2col mode -  size_t NumOffsets = IsIm2Col ? (NumDims - 2) : 0; -  bool IsCacheHint = N->getConstantOperandVal(NumOps - 2) == 1; -  bool IsMultiCast = N->getConstantOperandVal(NumOps - 3) == 1; -  size_t NumBaseArgs = NumDims + NumOffsets + 3; // for {dst, mbar, src} -  size_t MultiCastIdx = NumBaseArgs + 2;         // for Chain and IID - -  unsigned CTAGroupVal = N->getConstantOperandVal(NumOps - 1); -  if ((CTAGroupVal > 0) && !Subtarget->hasCpAsyncBulkTensorCTAGroupSupport()) -    report_fatal_error( -        formatv("CpAsyncBulkTensorG2S cta_group::1/2 is not supported on sm_{}", -                Subtarget->getSmVersion())); - -  SDLoc DL(N); -  SmallVector<SDValue, 8> Ops(N->ops().slice(2, NumBaseArgs)); - -  // Push MultiCast operand, if available -  if (IsMultiCast) -    Ops.push_back(N->getOperand(MultiCastIdx)); - -  // Push CacheHint operand, if available -  if (IsCacheHint) -    Ops.push_back(N->getOperand(MultiCastIdx + 1)); - -  // Flag for CTA Group -  Ops.push_back(getI32Imm(CTAGroupVal, DL)); - -  // Finally, the chain operand -  Ops.push_back(N->getOperand(0)); - -  bool IsShared32 = -      CurDAG->getDataLayout().getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32; -  unsigned Opcode = GetCpAsyncBulkTensorG2SOpcode( -      NumDims, IsShared32, IsMultiCast, IsCacheHint, IsIm2Col); -  ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops)); -} -  void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorReduceCommon(SDNode *N,                                                              unsigned RedOp,                                                              bool IsIm2Col) { @@ -2175,18 +2058,6 @@ bool NVPTXDAGToDAGISel::tryIntrinsicVoid(SDNode *N) {    switch (IID) {    default:      return false; -  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d: -  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d: -  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d: -  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d: -  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d: -    SelectCpAsyncBulkTensorG2SCommon(N); -    return true; -  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d: -  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d: -  case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d: -    SelectCpAsyncBulkTensorG2SCommon(N, /*IsIm2Col=*/true); -    return true;    case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d:    case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d:    case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d: diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h index c912e70..1cb579b 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h +++ 
b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h @@ -86,7 +86,6 @@ private:    bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N);    void SelectV2I64toI128(SDNode *N);    void SelectI128toV2I64(SDNode *N); -  void SelectCpAsyncBulkTensorG2SCommon(SDNode *N, bool IsIm2Col = false);    void SelectCpAsyncBulkTensorReduceCommon(SDNode *N, unsigned RedOp,                                             bool IsIm2Col = false);    void SelectTcgen05Ld(SDNode *N, bool hasOffset = false); diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index dfde0cc..b260221 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -139,7 +139,6 @@ def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;  def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;  def hasTcgen05Instructions : Predicate<"Subtarget->hasTcgen05Instructions()">;  def hasTcgen05MMAScaleInputDImm : Predicate<"Subtarget->hasTcgen05MMAScaleInputDImm()">; -def hasTMACTAGroupSupport  : Predicate<"Subtarget->hasCpAsyncBulkTensorCTAGroupSupport()">;  def hasF32x2Instructions : Predicate<"Subtarget->hasF32x2Instructions()">;  class hasPTX<int version>: Predicate<"Subtarget->getPTXVersion() >= " # version>; diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index c923f0e..e8758aa 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -599,75 +599,15 @@ class TMA_IM2COL_UTIL<int dim, string mode> {    string base_str = !interleave(!foreach(i, !range(offsets), "$im2col" # i), ", ");  } -// From Global to Shared memory (G2S) -class G2S_STRINGS<int dim, string mode, bit mc, bit ch, bit is_shared32 = 0> { -  string prefix = "cp.async.bulk.tensor"; -  string dir = "shared::cluster.global"; -  string completion = "mbarrier::complete_tx::bytes"; -  string inst_name = prefix -                     # "." # dim # "d" -                     # "." # dir -                     # "." # mode -                     # "." 
# completion -                     # !if(mc, ".multicast::cluster", "") -                     # !if(ch, ".L2::cache_hint", ""); -  string intr_name = "CP_ASYNC_BULK_TENSOR_G2S_" -                     # dim # "D" -                     # !if(is_shared32, "_SHARED32", "") -                     # !if(!eq(mode, "tile"), "_TILE", "_IM2COL"); -} -  def CTAGroupFlags : Operand<i32> {    let PrintMethod = "printCTAGroup";  } -multiclass CP_ASYNC_BULK_TENSOR_G2S_INTR<int dim, bit is_shared32, string mode> { -  defvar dims_dag = TMA_DIMS_UTIL<dim>.ins_dag; -  defvar dims_str = TMA_DIMS_UTIL<dim>.base_str; -  defvar asm_str_default = "$cg [$dst], [$tmap, {{" # dims_str # "}}], [$mbar]"; -  defvar rc = !if(is_shared32, B32, B64); - -  defvar num_im2col = !if(!ge(dim, 3), !add(dim, -2), 0); -  defvar im2col_dag = !if(!eq(mode, "im2col"), -    !dag(ins, !listsplat(B16, num_im2col), !foreach(i, !range(num_im2col), "im2col" # i)), -    (ins)); -  defvar im2col_str = !interleave(!foreach(i, !range(num_im2col), "$im2col" # i), ", "); -  defvar im2col_asm_str = ", {{" # im2col_str # "}}"; - -  defvar asm_str = !if(!eq(mode, "im2col"), -    !strconcat(asm_str_default, im2col_asm_str), asm_str_default); +def tma_cta_group_imm0 : TImmLeaf<i32, [{return Imm == 0;}]>; +def tma_cta_group_imm_any : TImmLeaf<i32, [{return Imm >= 0;}]>; -  def "" : NVPTXInst<(outs), -            !con((ins rc:$dst, rc:$mbar, B64:$tmap), dims_dag, im2col_dag, (ins CTAGroupFlags:$cg)), -            !strconcat(G2S_STRINGS<dim, mode, 0, 0>.inst_name, asm_str, ";")>, -            Requires<[hasPTX<80>, hasSM<90>]>; -  def _MC : NVPTXInst<(outs), -                  !con((ins rc:$dst, rc:$mbar, B64:$tmap), dims_dag, im2col_dag, -                       (ins B16:$mc, CTAGroupFlags:$cg)), -                  !strconcat(G2S_STRINGS<dim, mode, 1, 0>.inst_name, asm_str, ", $mc;")>, -                  Requires<[hasPTX<80>, hasSM<90>]>; -  def _CH : NVPTXInst<(outs), -                  !con((ins rc:$dst, rc:$mbar, B64:$tmap), dims_dag, im2col_dag, -                       (ins B64:$ch, CTAGroupFlags:$cg)), -                  !strconcat(G2S_STRINGS<dim, mode, 0, 1>.inst_name, asm_str, ", $ch;")>, -                  Requires<[hasPTX<80>, hasSM<90>]>; -  def _MC_CH : NVPTXInst<(outs), -                     !con((ins rc:$dst, rc:$mbar, B64:$tmap), dims_dag, im2col_dag, -                          (ins B16:$mc, B64:$ch, CTAGroupFlags:$cg)), -                     !strconcat(G2S_STRINGS<dim, mode, 1, 1>.inst_name, asm_str, ", $mc, $ch;")>, -                     Requires<[hasPTX<80>, hasSM<90>]>; -} - -foreach dim = [1, 2, 3, 4, 5] in { -  foreach shared32 = [true, false] in { -    foreach mode = !if(!ge(dim, 3), ["tile", "im2col"], ["tile"]) in { -      defm G2S_STRINGS<dim, mode, 0, 0, shared32>.intr_name : -        CP_ASYNC_BULK_TENSOR_G2S_INTR<dim, shared32, mode>; -    } -  } -} - -multiclass TMA_TENSOR_G2S_INTR<int dim, string mode, list<Predicate> pred = []> { +multiclass TMA_TENSOR_G2S_INTR<int dim, string mode, list<Predicate> pred, +                               TImmLeaf cta_group_type = tma_cta_group_imm_any> {    defvar dims_dag = TMA_DIMS_UTIL<dim>.ins_dag;    defvar dims_str = TMA_DIMS_UTIL<dim>.base_str;    defvar asm_str_base = "$cg [$dst], [$tmap, {{" # dims_str # "}}], [$mbar]"; @@ -697,10 +637,10 @@ multiclass TMA_TENSOR_G2S_INTR<int dim, string mode, list<Predicate> pred = []>                           !setdagop(dims_dag, intr),                           !setdagop(im2col_dag, intr),                           (intr B16:$mc, B64:$ch)); -  
defvar intr_dag_no_hints   = !con(intr_dag_base, (intr 0,  0,  timm:$cg)); -  defvar intr_dag_with_mc    = !con(intr_dag_base, (intr -1, 0,  timm:$cg)); -  defvar intr_dag_with_ch    = !con(intr_dag_base, (intr 0, -1,  timm:$cg)); -  defvar intr_dag_with_mc_ch = !con(intr_dag_base, (intr -1, -1, timm:$cg)); +  defvar intr_dag_no_hints   = !con(intr_dag_base, (intr 0,  0,  cta_group_type:$cg)); +  defvar intr_dag_with_mc    = !con(intr_dag_base, (intr -1, 0,  cta_group_type:$cg)); +  defvar intr_dag_with_ch    = !con(intr_dag_base, (intr 0, -1,  cta_group_type:$cg)); +  defvar intr_dag_with_mc_ch = !con(intr_dag_base, (intr -1, -1, cta_group_type:$cg));    def "" : NVPTXInst<(outs), ins_dag,               inst_name # asm_str # ";", @@ -719,14 +659,30 @@ multiclass TMA_TENSOR_G2S_INTR<int dim, string mode, list<Predicate> pred = []>                   [intr_dag_with_mc_ch]>,                   Requires<pred>;  } + +foreach dim = 1...5 in { +  defm TMA_G2S_TILE_CG0_ # dim # "D" +      : TMA_TENSOR_G2S_INTR<dim, "tile", [hasPTX<80>, hasSM<90>], +                            tma_cta_group_imm0>; +  defm TMA_G2S_TILE_ # dim # "D" +      : TMA_TENSOR_G2S_INTR<dim, "tile", +                            [callSubtarget<"hasTMABlackwellSupport">]>; +}  foreach dim = 3...5 in { +  defm TMA_G2S_IM2COL_CG0_ # dim # "D" +      : TMA_TENSOR_G2S_INTR<dim, "im2col", [hasPTX<80>, hasSM<90>], +                            tma_cta_group_imm0>; +  defm TMA_G2S_IM2COL_ # dim # "D" +      : TMA_TENSOR_G2S_INTR<dim, "im2col", +                            [callSubtarget<"hasTMABlackwellSupport">]>;    foreach mode = ["im2col_w", "im2col_w_128"] in {      defm TMA_G2S_ # !toupper(mode) # "_" # dim # "D" -      : TMA_TENSOR_G2S_INTR<dim, mode, [hasTMACTAGroupSupport]>; +        : TMA_TENSOR_G2S_INTR<dim, mode, +                              [callSubtarget<"hasTMABlackwellSupport">]>;    }  }  defm TMA_G2S_TILE_GATHER4_2D : TMA_TENSOR_G2S_INTR<5, "tile_gather4", -                               [hasTMACTAGroupSupport]>; +                               [callSubtarget<"hasTMABlackwellSupport">]>;  multiclass TMA_TENSOR_G2S_CTA_INTR<int dim, string mode, list<Predicate> pred = []> {    defvar dims_dag = TMA_DIMS_UTIL<dim>.ins_dag; @@ -784,7 +740,8 @@ foreach dim = 3...5 in {      : TMA_TENSOR_G2S_CTA_INTR<dim, "im2col_w", [hasPTX<86>, hasSM<100>]>;    defm TMA_G2S_CTA_IM2COL_W_128_ # dim # "D" -    : TMA_TENSOR_G2S_CTA_INTR<dim, "im2col_w_128", [hasTMACTAGroupSupport]>; +    : TMA_TENSOR_G2S_CTA_INTR<dim, "im2col_w_128", +                              [callSubtarget<"hasTMABlackwellSupport">]>;  }  defm TMA_G2S_CTA_TILE_GATHER4_2D : TMA_TENSOR_G2S_CTA_INTR<5, "tile_gather4",                                     [hasPTX<86>, hasSM<100>]>; @@ -835,7 +792,7 @@ foreach dim = 1...5 in {    }  }  defm TMA_S2G_TILE_SCATTER4_2D : TMA_TENSOR_S2G_INTR<5, "tile_scatter4", -                                [hasTMACTAGroupSupport]>; +                                [callSubtarget<"hasTMABlackwellSupport">]>;  def TMAReductionFlags : Operand<i32> {    let PrintMethod = "printTmaReductionMode"; @@ -930,11 +887,11 @@ foreach dim = 3...5 in {    foreach mode = ["im2col_w", "im2col_w_128"] in {      defvar suffix = !toupper(mode) # "_" # dim # "D";      defm TMA_TENSOR_PF_ # suffix : TMA_TENSOR_PREFETCH_INTR<dim, mode, -                                   [hasTMACTAGroupSupport]>; +                                   [callSubtarget<"hasTMABlackwellSupport">]>;    }  }  defm TMA_TENSOR_PF_TILE_GATHER4_2D : TMA_TENSOR_PREFETCH_INTR<5, "tile_gather4", 
-                                     [hasTMACTAGroupSupport]>; +                                     [callSubtarget<"hasTMABlackwellSupport">]>;  //Prefetchu and Prefetch diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h index 194dbdc..021b1f6 100644 --- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h +++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h @@ -166,18 +166,15 @@ public:    // f32x2 instructions in Blackwell family    bool hasF32x2Instructions() const; -  // TMA G2S copy with cta_group::1/2 support -  bool hasCpAsyncBulkTensorCTAGroupSupport() const { -    // TODO: Update/tidy-up after the family-conditional support arrives -    switch (FullSmVersion) { -    case 1003: -    case 1013: -      return PTXVersion >= 86; -    case 1033: -      return PTXVersion >= 88; -    default: -      return false; -    } +  // Checks support for following in TMA: +  //  - cta_group::1/2 support +  //  - im2col_w/w_128 mode support +  //  - tile_gather4 mode support +  //  - tile_scatter4 mode support +  bool hasTMABlackwellSupport() const { +    return hasPTXWithFamilySMs(90, {100, 110}) || +           hasPTXWithFamilySMs(88, {100, 101}) || +           hasPTXWithAccelSMs(86, {100, 101});    }    // Prior to CUDA 12.3 ptxas did not recognize that the trap instruction diff --git a/llvm/lib/Target/PowerPC/PPCTargetMachine.cpp b/llvm/lib/Target/PowerPC/PPCTargetMachine.cpp index 000d296..4ff489d 100644 --- a/llvm/lib/Target/PowerPC/PPCTargetMachine.cpp +++ b/llvm/lib/Target/PowerPC/PPCTargetMachine.cpp @@ -296,8 +296,9 @@ PPCTargetMachine::PPCTargetMachine(const Target &T, const Triple &TT,                                     std::optional<Reloc::Model> RM,                                     std::optional<CodeModel::Model> CM,                                     CodeGenOptLevel OL, bool JIT) -    : CodeGenTargetMachineImpl(T, TT.computeDataLayout(), TT, CPU, -                               computeFSAdditions(FS, OL, TT), Options, +    : CodeGenTargetMachineImpl(T, +                               TT.computeDataLayout(Options.MCOptions.ABIName), +                               TT, CPU, computeFSAdditions(FS, OL, TT), Options,                                 getEffectiveRelocModel(TT, RM),                                 getEffectivePPCCodeModel(TT, CM, JIT), OL),        TLOF(createTLOF(getTargetTriple())), diff --git a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp index 8198173..282cf5d 100644 --- a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp +++ b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp @@ -92,6 +92,10 @@ private:    void emitFence(AtomicOrdering FenceOrdering, SyncScope::ID FenceSSID,                   MachineIRBuilder &MIB) const;    bool selectUnmergeValues(MachineInstr &MI, MachineIRBuilder &MIB) const; +  void addVectorLoadStoreOperands(MachineInstr &I, +                                  SmallVectorImpl<SrcOp> &SrcOps, +                                  unsigned &CurOp, bool IsMasked, +                                  bool IsStrided) const;    bool selectIntrinsicWithSideEffects(MachineInstr &I,                                        MachineIRBuilder &MIB) const; @@ -716,6 +720,26 @@ static unsigned selectRegImmLoadStoreOp(unsigned GenericOpc, unsigned OpSize) {    return GenericOpc;  } +void RISCVInstructionSelector::addVectorLoadStoreOperands( +    MachineInstr &I, SmallVectorImpl<SrcOp> &SrcOps, unsigned &CurOp, +    bool IsMasked, bool IsStrided) const { +  // 
Base Pointer +  auto PtrReg = I.getOperand(CurOp++).getReg(); +  SrcOps.push_back(PtrReg); + +  // Stride +  if (IsStrided) { +    auto StrideReg = I.getOperand(CurOp++).getReg(); +    SrcOps.push_back(StrideReg); +  } + +  // Mask +  if (IsMasked) { +    auto MaskReg = I.getOperand(CurOp++).getReg(); +    SrcOps.push_back(MaskReg); +  } +} +  bool RISCVInstructionSelector::selectIntrinsicWithSideEffects(      MachineInstr &I, MachineIRBuilder &MIB) const {    // Find the intrinsic ID. @@ -752,21 +776,7 @@ bool RISCVInstructionSelector::selectIntrinsicWithSideEffects(        SrcOps.push_back(Register(RISCV::NoRegister));      } -    // Base Pointer -    auto PtrReg = I.getOperand(CurOp++).getReg(); -    SrcOps.push_back(PtrReg); - -    // Stride -    if (IsStrided) { -      auto StrideReg = I.getOperand(CurOp++).getReg(); -      SrcOps.push_back(StrideReg); -    } - -    // Mask -    if (IsMasked) { -      auto MaskReg = I.getOperand(CurOp++).getReg(); -      SrcOps.push_back(MaskReg); -    } +    addVectorLoadStoreOperands(I, SrcOps, CurOp, IsMasked, IsStrided);      RISCVVType::VLMUL LMUL = RISCVTargetLowering::getLMUL(getMVTForLLT(VT));      const RISCV::VLEPseudo *P = @@ -795,6 +805,48 @@ bool RISCVInstructionSelector::selectIntrinsicWithSideEffects(      I.eraseFromParent();      return constrainSelectedInstRegOperands(*PseudoMI, TII, TRI, RBI);    } +  case Intrinsic::riscv_vsm: +  case Intrinsic::riscv_vse: +  case Intrinsic::riscv_vse_mask: +  case Intrinsic::riscv_vsse: +  case Intrinsic::riscv_vsse_mask: { +    bool IsMasked = IntrinID == Intrinsic::riscv_vse_mask || +                    IntrinID == Intrinsic::riscv_vsse_mask; +    bool IsStrided = IntrinID == Intrinsic::riscv_vsse || +                     IntrinID == Intrinsic::riscv_vsse_mask; +    LLT VT = MRI->getType(I.getOperand(1).getReg()); +    unsigned Log2SEW = Log2_32(VT.getScalarSizeInBits()); + +    // Sources +    unsigned CurOp = 1; +    SmallVector<SrcOp, 4> SrcOps; // Source registers. 
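    // Illustrative operand layout (a sketch, assuming the usual ordering of
    // these intrinsics): for riscv_vsse_mask the instruction carries
    //   {intrinsic-id, store-value, base-pointer, stride, mask, vl},
    // so SrcOps collects {store-value, base, stride, mask} before VL and SEW
    // are appended to the pseudo below.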
+ +    // Store value +    auto PassthruReg = I.getOperand(CurOp++).getReg(); +    SrcOps.push_back(PassthruReg); + +    addVectorLoadStoreOperands(I, SrcOps, CurOp, IsMasked, IsStrided); + +    RISCVVType::VLMUL LMUL = RISCVTargetLowering::getLMUL(getMVTForLLT(VT)); +    const RISCV::VSEPseudo *P = RISCV::getVSEPseudo( +        IsMasked, IsStrided, Log2SEW, static_cast<unsigned>(LMUL)); + +    auto PseudoMI = MIB.buildInstr(P->Pseudo, {}, SrcOps); + +    // Select VL +    auto VLOpFn = renderVLOp(I.getOperand(CurOp++)); +    for (auto &RenderFn : *VLOpFn) +      RenderFn(PseudoMI); + +    // SEW +    PseudoMI.addImm(Log2SEW); + +    // Memref +    PseudoMI.cloneMemRefs(I); + +    I.eraseFromParent(); +    return constrainSelectedInstRegOperands(*PseudoMI, TII, TRI, RBI); +  }    }  } diff --git a/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp b/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp index 4105618..526675a 100644 --- a/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp +++ b/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp @@ -127,6 +127,10 @@ bool RISCVExpandPseudo::expandMI(MachineBasicBlock &MBB,    case RISCV::PseudoCCAND:    case RISCV::PseudoCCOR:    case RISCV::PseudoCCXOR: +  case RISCV::PseudoCCMAX: +  case RISCV::PseudoCCMAXU: +  case RISCV::PseudoCCMIN: +  case RISCV::PseudoCCMINU:    case RISCV::PseudoCCADDW:    case RISCV::PseudoCCSUBW:    case RISCV::PseudoCCSLL: @@ -217,6 +221,7 @@ bool RISCVExpandPseudo::expandCCOp(MachineBasicBlock &MBB,          .addImm(0);    } else {      unsigned NewOpc; +    // clang-format off      switch (MI.getOpcode()) {      default:        llvm_unreachable("Unexpected opcode!"); @@ -228,6 +233,10 @@ bool RISCVExpandPseudo::expandCCOp(MachineBasicBlock &MBB,      case RISCV::PseudoCCAND:   NewOpc = RISCV::AND;   break;      case RISCV::PseudoCCOR:    NewOpc = RISCV::OR;    break;      case RISCV::PseudoCCXOR:   NewOpc = RISCV::XOR;   break; +    case RISCV::PseudoCCMAX:   NewOpc = RISCV::MAX;   break; +    case RISCV::PseudoCCMIN:   NewOpc = RISCV::MIN;   break; +    case RISCV::PseudoCCMAXU:  NewOpc = RISCV::MAXU;  break; +    case RISCV::PseudoCCMINU:  NewOpc = RISCV::MINU;  break;      case RISCV::PseudoCCADDI:  NewOpc = RISCV::ADDI;  break;      case RISCV::PseudoCCSLLI:  NewOpc = RISCV::SLLI;  break;      case RISCV::PseudoCCSRLI:  NewOpc = RISCV::SRLI;  break; @@ -250,6 +259,7 @@ bool RISCVExpandPseudo::expandCCOp(MachineBasicBlock &MBB,      case RISCV::PseudoCCNDS_BFOS: NewOpc = RISCV::NDS_BFOS; break;      case RISCV::PseudoCCNDS_BFOZ: NewOpc = RISCV::NDS_BFOZ; break;      } +    // clang-format on      if (NewOpc == RISCV::NDS_BFOZ || NewOpc == RISCV::NDS_BFOS) {        BuildMI(TrueBB, DL, TII->get(NewOpc), DestReg) diff --git a/llvm/lib/Target/RISCV/RISCVFeatures.td b/llvm/lib/Target/RISCV/RISCVFeatures.td index b4556f6..cfee6ab 100644 --- a/llvm/lib/Target/RISCV/RISCVFeatures.td +++ b/llvm/lib/Target/RISCV/RISCVFeatures.td @@ -1851,6 +1851,11 @@ def TuneShortForwardBranchOpt  def HasShortForwardBranchOpt : Predicate<"Subtarget->hasShortForwardBranchOpt()">;  def NoShortForwardBranchOpt : Predicate<"!Subtarget->hasShortForwardBranchOpt()">; +def TuneShortForwardBranchIMinMax +    : SubtargetFeature<"short-forward-branch-i-minmax", "HasShortForwardBranchIMinMax", +                       "true", "Enable short forward branch optimization for min,max instructions in Zbb", +                       [TuneShortForwardBranchOpt]>; +  // Some subtargets require a S2V transfer buffer to move scalars into vectors.  
// FIXME: Forming .vx/.vf/.wx/.wf can reduce register pressure.  def TuneNoSinkSplatOperands diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp index 9a6afa1..b25a054 100644 --- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp @@ -3995,6 +3995,7 @@ bool RISCVDAGToDAGISel::hasAllNBitUsers(SDNode *Node, unsigned Bits,      case RISCV::CTZW:      case RISCV::CPOPW:      case RISCV::SLLI_UW: +    case RISCV::ABSW:      case RISCV::FMV_W_X:      case RISCV::FCVT_H_W:      case RISCV::FCVT_H_W_INX: diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 1c930ac..56881f7 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -433,6 +433,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,    if (Subtarget.hasStdExtP() ||        (Subtarget.hasVendorXCValu() && !Subtarget.is64Bit())) {      setOperationAction(ISD::ABS, XLenVT, Legal); +    if (Subtarget.is64Bit()) +      setOperationAction(ISD::ABS, MVT::i32, Custom);    } else if (Subtarget.hasShortForwardBranchOpt()) {      // We can use PseudoCCSUB to implement ABS.      setOperationAction(ISD::ABS, XLenVT, Legal); @@ -14816,8 +14818,16 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,      assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&             "Unexpected custom legalisation"); +    if (Subtarget.hasStdExtP()) { +      SDValue Src = +          DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(0)); +      SDValue Abs = DAG.getNode(RISCVISD::ABSW, DL, MVT::i64, Src); +      Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Abs)); +      return; +    } +      if (Subtarget.hasStdExtZbb()) { -      // Emit a special ABSW node that will be expanded to NEGW+MAX at isel. +      // Emit a special node that will be expanded to NEGW+MAX at isel.        // This allows us to remember that the result is sign extended. Expanding        // to NEGW+MAX here requires a Freeze which breaks ComputeNumSignBits.        
SDValue Src = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, @@ -20290,6 +20300,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,      break;    } +  case RISCVISD::ABSW:    case RISCVISD::CLZW:    case RISCVISD::CTZW: {      // Only the lower 32 bits of the first operand are read @@ -21862,6 +21873,7 @@ unsigned RISCVTargetLowering::ComputeNumSignBitsForTargetNode(    case RISCVISD::REMUW:    case RISCVISD::ROLW:    case RISCVISD::RORW: +  case RISCVISD::ABSW:    case RISCVISD::FCVT_W_RV64:    case RISCVISD::FCVT_WU_RV64:    case RISCVISD::STRICT_FCVT_W_RV64: diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp index 912b82d..c9df787 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp @@ -869,7 +869,7 @@ std::optional<unsigned> getFoldedOpcode(MachineFunction &MF, MachineInstr &MI,    }  } -// This is the version used during inline spilling +// This is the version used during InlineSpiller::spillAroundUses  MachineInstr *RISCVInstrInfo::foldMemoryOperandImpl(      MachineFunction &MF, MachineInstr &MI, ArrayRef<unsigned> Ops,      MachineBasicBlock::iterator InsertPt, int FrameIndex, LiveIntervals *LIS, @@ -1699,6 +1699,10 @@ unsigned getPredicatedOpcode(unsigned Opcode) {    case RISCV::AND:   return RISCV::PseudoCCAND;    case RISCV::OR:    return RISCV::PseudoCCOR;    case RISCV::XOR:   return RISCV::PseudoCCXOR; +  case RISCV::MAX:   return RISCV::PseudoCCMAX; +  case RISCV::MAXU:  return RISCV::PseudoCCMAXU; +  case RISCV::MIN:   return RISCV::PseudoCCMIN; +  case RISCV::MINU:  return RISCV::PseudoCCMINU;    case RISCV::ADDI:  return RISCV::PseudoCCADDI;    case RISCV::SLLI:  return RISCV::PseudoCCSLLI; @@ -1735,7 +1739,8 @@ unsigned getPredicatedOpcode(unsigned Opcode) {  /// return the defining instruction.  static MachineInstr *canFoldAsPredicatedOp(Register Reg,                                             const MachineRegisterInfo &MRI, -                                           const TargetInstrInfo *TII) { +                                           const TargetInstrInfo *TII, +                                           const RISCVSubtarget &STI) {    if (!Reg.isVirtual())      return nullptr;    if (!MRI.hasOneNonDBGUse(Reg)) @@ -1743,6 +1748,12 @@ static MachineInstr *canFoldAsPredicatedOp(Register Reg,    MachineInstr *MI = MRI.getVRegDef(Reg);    if (!MI)      return nullptr; + +  if (!STI.hasShortForwardBranchIMinMax() && +      (MI->getOpcode() == RISCV::MAX || MI->getOpcode() == RISCV::MIN || +       MI->getOpcode() == RISCV::MINU || MI->getOpcode() == RISCV::MAXU)) +    return nullptr; +    // Check if MI can be predicated and folded into the CCMOV.    
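  // For illustration (a sketch, not taken from the patch): with
  // short-forward-branch-i-minmax enabled, a select such as
  //   %r = select %cond, (smax %a, %b), %a
  // lets the MAX producer be folded into the conditional move as PseudoCCMAX,
  // which RISCVExpandPseudo later expands as a short forward branch.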
if (getPredicatedOpcode(MI->getOpcode()) == RISCV::INSTRUCTION_LIST_END)      return nullptr; @@ -1806,10 +1817,10 @@ RISCVInstrInfo::optimizeSelect(MachineInstr &MI,    MachineRegisterInfo &MRI = MI.getParent()->getParent()->getRegInfo();    MachineInstr *DefMI = -      canFoldAsPredicatedOp(MI.getOperand(5).getReg(), MRI, this); +      canFoldAsPredicatedOp(MI.getOperand(5).getReg(), MRI, this, STI);    bool Invert = !DefMI;    if (!DefMI) -    DefMI = canFoldAsPredicatedOp(MI.getOperand(4).getReg(), MRI, this); +    DefMI = canFoldAsPredicatedOp(MI.getOperand(4).getReg(), MRI, this, STI);    if (!DefMI)      return nullptr; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.td b/llvm/lib/Target/RISCV/RISCVInstrInfo.td index 7c89686..9cb53fb 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.td @@ -768,7 +768,7 @@ def BGE  : BranchCC_rri<0b101, "bge">;  def BLTU : BranchCC_rri<0b110, "bltu">;  def BGEU : BranchCC_rri<0b111, "bgeu">; -let IsSignExtendingOpW = 1 in { +let IsSignExtendingOpW = 1, canFoldAsLoad = 1 in {  def LB  : Load_ri<0b000, "lb">, Sched<[WriteLDB, ReadMemBase]>;  def LH  : Load_ri<0b001, "lh">, Sched<[WriteLDH, ReadMemBase]>;  def LW  : Load_ri<0b010, "lw">, Sched<[WriteLDW, ReadMemBase]>; @@ -889,8 +889,10 @@ def CSRRCI : CSR_ii<0b111, "csrrci">;  /// RV64I instructions  let Predicates = [IsRV64] in { +let canFoldAsLoad = 1 in {  def LWU   : Load_ri<0b110, "lwu">, Sched<[WriteLDW, ReadMemBase]>;  def LD    : Load_ri<0b011, "ld">, Sched<[WriteLDD, ReadMemBase]>; +}  def SD    : Store_rri<0b011, "sd">, Sched<[WriteSTD, ReadStoreData, ReadMemBase]>;  let IsSignExtendingOpW = 1 in { diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoD.td b/llvm/lib/Target/RISCV/RISCVInstrInfoD.td index afac37d..4ffe3e6 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoD.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoD.td @@ -71,6 +71,7 @@ defvar DExtsRV64 = [DExt, ZdinxExt];  //===----------------------------------------------------------------------===//  let Predicates = [HasStdExtD] in { +let canFoldAsLoad = 1 in  def FLD : FPLoad_r<0b011, "fld", FPR64, WriteFLD64>;  // Operands for stores are in the order srcreg, base, offset rather than diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoF.td b/llvm/lib/Target/RISCV/RISCVInstrInfoF.td index 6571d99..b30f8ec 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoF.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoF.td @@ -330,6 +330,7 @@ class PseudoFROUND<DAGOperand Ty, ValueType vt, ValueType intvt = XLenVT>  //===----------------------------------------------------------------------===//  let Predicates = [HasStdExtF] in { +let canFoldAsLoad = 1 in  def FLW : FPLoad_r<0b010, "flw", FPR32, WriteFLD32>;  // Operands for stores are in the order srcreg, base, offset rather than diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td index cc085bb..4cbbba3 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td @@ -1461,5 +1461,10 @@ let Predicates = [HasStdExtP, IsRV32] in {  // Codegen patterns  //===----------------------------------------------------------------------===// +def riscv_absw : RVSDNode<"ABSW", SDTIntUnaryOp>; +  let Predicates = [HasStdExtP] in  def : PatGpr<abs, ABS>; + +let Predicates = [HasStdExtP, IsRV64] in +def : PatGpr<riscv_absw, ABSW>; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoSFB.td b/llvm/lib/Target/RISCV/RISCVInstrInfoSFB.td index 0114fbd..5a67a5a 100644 --- 
a/llvm/lib/Target/RISCV/RISCVInstrInfoSFB.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoSFB.td @@ -106,6 +106,10 @@ def PseudoCCSRA : SFBALU_rr;  def PseudoCCAND : SFBALU_rr;  def PseudoCCOR  : SFBALU_rr;  def PseudoCCXOR : SFBALU_rr; +def PseudoCCMAX : SFBALU_rr; +def PseudoCCMIN : SFBALU_rr; +def PseudoCCMAXU : SFBALU_rr; +def PseudoCCMINU : SFBALU_rr;  def PseudoCCADDI : SFBALU_ri;  def PseudoCCANDI : SFBALU_ri; diff --git a/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp b/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp index d08115b..ea98cdb 100644 --- a/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp +++ b/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp @@ -172,6 +172,7 @@ static bool hasAllNBitUsers(const MachineInstr &OrigMI,        case RISCV::CTZW:        case RISCV::CPOPW:        case RISCV::SLLI_UW: +      case RISCV::ABSW:        case RISCV::FMV_W_X:        case RISCV::FCVT_H_W:        case RISCV::FCVT_H_W_INX: diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 3fea21e..3f0424f 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -3151,6 +3151,14 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,      return selectInsertElt(ResVReg, ResType, I);    case Intrinsic::spv_gep:      return selectGEP(ResVReg, ResType, I); +  case Intrinsic::spv_bitcast: { +    Register OpReg = I.getOperand(2).getReg(); +    SPIRVType *OpType = +        OpReg.isValid() ? GR.getSPIRVTypeForVReg(OpReg) : nullptr; +    if (!GR.isBitcastCompatible(ResType, OpType)) +      report_fatal_error("incompatible result and operand types in a bitcast"); +    return selectOpWithSrcs(ResVReg, ResType, I, {OpReg}, SPIRV::OpBitcast); +  }    case Intrinsic::spv_unref_global:    case Intrinsic::spv_init_global: {      MachineInstr *MI = MRI->getVRegDef(I.getOperand(1).getReg()); diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp index 6e444c9..65dffc7 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp @@ -73,16 +73,23 @@ class SPIRVLegalizePointerCast : public FunctionPass {    // Returns the loaded value.    
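  // Illustrative cases: loading a <4 x i32> value through a pointer whose
  // pointee is <4 x float> takes the bitcast path (element types differ,
  // bit widths match), while loading a <2 x float> from a <4 x float>
  // pointee keeps the element type and is narrowed with a shufflevector
  // over indices 0..1.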
Value *loadVectorFromVector(IRBuilder<> &B, FixedVectorType *SourceType,                                FixedVectorType *TargetType, Value *Source) { -    assert(TargetType->getNumElements() <= SourceType->getNumElements());      LoadInst *NewLoad = B.CreateLoad(SourceType, Source);      buildAssignType(B, SourceType, NewLoad);      Value *AssignValue = NewLoad;      if (TargetType->getElementType() != SourceType->getElementType()) { +      const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout(); +      [[maybe_unused]] TypeSize TargetTypeSize = +          DL.getTypeSizeInBits(TargetType); +      [[maybe_unused]] TypeSize SourceTypeSize = +          DL.getTypeSizeInBits(SourceType); +      assert(TargetTypeSize == SourceTypeSize);        AssignValue = B.CreateIntrinsic(Intrinsic::spv_bitcast,                                        {TargetType, SourceType}, {NewLoad});        buildAssignType(B, TargetType, AssignValue); +      return AssignValue;      } +    assert(TargetType->getNumElements() < SourceType->getNumElements());      SmallVector<int> Mask(/* Size= */ TargetType->getNumElements());      for (unsigned I = 0; I < TargetType->getNumElements(); ++I)        Mask[I] = I; diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp index db6f2d6..d538009 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp @@ -192,31 +192,43 @@ static void buildOpBitcast(SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,          .addUse(OpReg);  } -// We do instruction selections early instead of calling MIB.buildBitcast() -// generating the general op code G_BITCAST. When MachineVerifier validates -// G_BITCAST we see a check of a kind: if Source Type is equal to Destination -// Type then report error "bitcast must change the type". This doesn't take into -// account the notion of a typed pointer that is important for SPIR-V where a -// user may and should use bitcast between pointers with different pointee types -// (https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast). -// It's important for correct lowering in SPIR-V, because interpretation of the -// data type is not left to instructions that utilize the pointer, but encoded -// by the pointer declaration, and the SPIRV target can and must handle the -// declaration and use of pointers that specify the type of data they point to. -// It's not feasible to improve validation of G_BITCAST using just information -// provided by low level types of source and destination. Therefore we don't -// produce G_BITCAST as the general op code with semantics different from -// OpBitcast, but rather lower to OpBitcast immediately. As for now, the only -// difference would be that CombinerHelper couldn't transform known patterns -// around G_BUILD_VECTOR. See discussion -// in https://github.com/llvm/llvm-project/pull/110270 for even more context. -static void selectOpBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR, -                             MachineIRBuilder MIB) { +// We lower G_BITCAST to OpBitcast here to avoid a MachineVerifier error. +// The verifier checks if the source and destination LLTs of a G_BITCAST are +// different, but this check is too strict for SPIR-V's typed pointers, which +// may have the same LLT but different SPIRVType (e.g. pointers to different +// pointee types). By lowering to OpBitcast here, we bypass the verifier's +// check. 
See discussion in https://github.com/llvm/llvm-project/pull/110270 +// for more context. +// +// We also handle the llvm.spv.bitcast intrinsic here. If the source and +// destination SPIR-V types are the same, we lower it to a COPY to enable +// further optimizations like copy propagation. +static void lowerBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR, +                          MachineIRBuilder MIB) {    SmallVector<MachineInstr *, 16> ToErase;    for (MachineBasicBlock &MBB : MF) {      for (MachineInstr &MI : MBB) { +      if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) { +        Register DstReg = MI.getOperand(0).getReg(); +        Register SrcReg = MI.getOperand(2).getReg(); +        SPIRVType *DstType = GR->getSPIRVTypeForVReg(DstReg); +        assert( +            DstType && +            "Expected destination SPIR-V type to have been assigned already."); +        SPIRVType *SrcType = GR->getSPIRVTypeForVReg(SrcReg); +        assert(SrcType && +               "Expected source SPIR-V type to have been assigned already."); +        if (DstType == SrcType) { +          MIB.setInsertPt(*MI.getParent(), MI); +          MIB.buildCopy(DstReg, SrcReg); +          ToErase.push_back(&MI); +          continue; +        } +      } +        if (MI.getOpcode() != TargetOpcode::G_BITCAST)          continue; +        MIB.setInsertPt(*MI.getParent(), MI);        buildOpBitcast(GR, MIB, MI.getOperand(0).getReg(),                       MI.getOperand(1).getReg()); @@ -237,16 +249,11 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,    SmallVector<MachineInstr *, 10> ToErase;    for (MachineBasicBlock &MBB : MF) {      for (MachineInstr &MI : MBB) { -      if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast) && -          !isSpvIntrinsic(MI, Intrinsic::spv_ptrcast)) +      if (!isSpvIntrinsic(MI, Intrinsic::spv_ptrcast))          continue;        assert(MI.getOperand(2).isReg());        MIB.setInsertPt(*MI.getParent(), MI);        ToErase.push_back(&MI); -      if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) { -        MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg()); -        continue; -      }        Register Def = MI.getOperand(0).getReg();        Register Source = MI.getOperand(2).getReg();        Type *ElemTy = getMDOperandAsType(MI.getOperand(3).getMetadata(), 0); @@ -1089,7 +1096,7 @@ bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) {    removeImplicitFallthroughs(MF, MIB);    insertSpirvDecorations(MF, GR, MIB);    insertInlineAsm(MF, GR, ST, MIB); -  selectOpBitcasts(MF, GR, MIB); +  lowerBitcasts(MF, GR, MIB);    return true;  } diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyFrameLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyFrameLowering.cpp index 27f7e1a..5a1779c 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyFrameLowering.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyFrameLowering.cpp @@ -81,7 +81,7 @@ WebAssemblyFrameLowering::getLocalForStackObject(MachineFunction &MF,    // Abuse object size to record number of WebAssembly locals allocated to    // this object.    
MFI.setObjectSize(FrameIndex, ValueVTs.size()); -  return static_cast<unsigned>(Local); +  return Local;  }  /// We need a base pointer in the case of having items on the stack that diff --git a/llvm/lib/Target/X86/AsmParser/X86Operand.h b/llvm/lib/Target/X86/AsmParser/X86Operand.h index 89ac53e..a922725 100644 --- a/llvm/lib/Target/X86/AsmParser/X86Operand.h +++ b/llvm/lib/Target/X86/AsmParser/X86Operand.h @@ -620,37 +620,6 @@ struct X86Operand final : public MCParsedAsmOperand {      Inst.addOperand(MCOperand::createReg(Reg));    } -  bool isTILEPair() const { -    return Kind == Register && -           X86MCRegisterClasses[X86::TILERegClassID].contains(getReg()); -  } - -  void addTILEPairOperands(MCInst &Inst, unsigned N) const { -    assert(N == 1 && "Invalid number of operands!"); -    MCRegister Reg = getReg(); -    switch (Reg.id()) { -    default: -      llvm_unreachable("Invalid tile register!"); -    case X86::TMM0: -    case X86::TMM1: -      Reg = X86::TMM0_TMM1; -      break; -    case X86::TMM2: -    case X86::TMM3: -      Reg = X86::TMM2_TMM3; -      break; -    case X86::TMM4: -    case X86::TMM5: -      Reg = X86::TMM4_TMM5; -      break; -    case X86::TMM6: -    case X86::TMM7: -      Reg = X86::TMM6_TMM7; -      break; -    } -    Inst.addOperand(MCOperand::createReg(Reg)); -  } -    void addMemOperands(MCInst &Inst, unsigned N) const {      assert((N == 5) && "Invalid number of operands!");      if (getMemBaseReg()) diff --git a/llvm/lib/Target/X86/Disassembler/X86Disassembler.cpp b/llvm/lib/Target/X86/Disassembler/X86Disassembler.cpp index 4927b45..7d2b5eb 100644 --- a/llvm/lib/Target/X86/Disassembler/X86Disassembler.cpp +++ b/llvm/lib/Target/X86/Disassembler/X86Disassembler.cpp @@ -810,10 +810,6 @@ static int readModRM(struct InternalInstruction *insn) {        if (index > 7)                                                           \          *valid = 0;                                                            \        return prefix##_TMM0 + index;                                            \ -    case TYPE_TMM_PAIR:                                                        \ -      if (index > 7)                                                           \ -        *valid = 0;                                                            \ -      return prefix##_TMM0_TMM1 + (index / 2);                                 \      case TYPE_VK:                                                              \        index &= 0xf;                                                            \        if (index > 7)                                                           \ @@ -2323,7 +2319,6 @@ static bool translateRM(MCInst &mcInst, const OperandSpecifier &operand,    case TYPE_YMM:    case TYPE_ZMM:    case TYPE_TMM: -  case TYPE_TMM_PAIR:    case TYPE_VK_PAIR:    case TYPE_VK:    case TYPE_DEBUGREG: diff --git a/llvm/lib/Target/X86/Disassembler/X86DisassemblerDecoder.h b/llvm/lib/Target/X86/Disassembler/X86DisassemblerDecoder.h index dc9af2c..b0aa70b 100644 --- a/llvm/lib/Target/X86/Disassembler/X86DisassemblerDecoder.h +++ b/llvm/lib/Target/X86/Disassembler/X86DisassemblerDecoder.h @@ -535,12 +535,6 @@ namespace X86Disassembler {    ENTRY(TMM6)                                                                  \    ENTRY(TMM7) -#define REGS_TMM_PAIRS                                                         \ -  ENTRY(TMM0_TMM1)                                                             \ -  ENTRY(TMM2_TMM3)                                                             \ -  
ENTRY(TMM4_TMM5)                                                             \ -  ENTRY(TMM6_TMM7) -  #define ALL_EA_BASES                                                           \    EA_BASES_16BIT                                                               \    EA_BASES_32BIT                                                               \ @@ -565,7 +559,6 @@ namespace X86Disassembler {    REGS_DEBUG                                                                   \    REGS_CONTROL                                                                 \    REGS_TMM                                                                     \ -  REGS_TMM_PAIRS                                                               \    ENTRY(RIP)  /// All possible values of the base field for effective-address diff --git a/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.cpp b/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.cpp index 1c5f166..759d95e 100644 --- a/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.cpp +++ b/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.cpp @@ -467,22 +467,3 @@ void X86InstPrinterCommon::printVKPair(const MCInst *MI, unsigned OpNo,    }    llvm_unreachable("Unknown mask pair register name");  } - -void X86InstPrinterCommon::printTILEPair(const MCInst *MI, unsigned OpNo, -                                         raw_ostream &OS) { -  switch (MI->getOperand(OpNo).getReg()) { -  case X86::TMM0_TMM1: -    printRegName(OS, X86::TMM0); -    return; -  case X86::TMM2_TMM3: -    printRegName(OS, X86::TMM2); -    return; -  case X86::TMM4_TMM5: -    printRegName(OS, X86::TMM4); -    return; -  case X86::TMM6_TMM7: -    printRegName(OS, X86::TMM6); -    return; -  } -  llvm_unreachable("Unknown mask pair register name"); -} diff --git a/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.h b/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.h index 2c9467c..cb55f2f 100644 --- a/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.h +++ b/llvm/lib/Target/X86/MCTargetDesc/X86InstPrinterCommon.h @@ -40,7 +40,6 @@ protected:                        const MCSubtargetInfo &STI);    void printOptionalSegReg(const MCInst *MI, unsigned OpNo, raw_ostream &O);    void printVKPair(const MCInst *MI, unsigned OpNo, raw_ostream &OS); -  void printTILEPair(const MCInst *MI, unsigned OpNo, raw_ostream &OS);  };  } // end namespace llvm diff --git a/llvm/lib/Target/X86/X86.td b/llvm/lib/Target/X86/X86.td index a1fd366..9e291a6 100644 --- a/llvm/lib/Target/X86/X86.td +++ b/llvm/lib/Target/X86/X86.td @@ -274,9 +274,6 @@ def FeatureAMXFP8 : SubtargetFeature<"amx-fp8", "HasAMXFP8", "true",  def FeatureAMXMOVRS : SubtargetFeature<"amx-movrs", "HasAMXMOVRS", "true",                                         "Support AMX-MOVRS instructions",                                         [FeatureAMXTILE]>; -def FeatureAMXTRANSPOSE : SubtargetFeature<"amx-transpose", "HasAMXTRANSPOSE", "true", -                                           "Support AMX amx-transpose instructions", -                                           [FeatureAMXTILE]>;  def FeatureAMXAVX512 : SubtargetFeature<"amx-avx512",                                          "HasAMXAVX512", "true",                                          "Support AMX-AVX512 instructions", @@ -1177,8 +1174,7 @@ def ProcessorFeatures {                                                    FeatureAMXMOVRS,                                                    FeatureAMXAVX512,                                                    FeatureAMXFP8, -                  
                                FeatureAMXTF32, -                                                  FeatureAMXTRANSPOSE]; +                                                  FeatureAMXTF32];    list<SubtargetFeature> DMRFeatures =      !listconcat(GNRDFeatures, DMRAdditionalFeatures); diff --git a/llvm/lib/Target/X86/X86ExpandPseudo.cpp b/llvm/lib/Target/X86/X86ExpandPseudo.cpp index 4a9b824..e3c44c0 100644 --- a/llvm/lib/Target/X86/X86ExpandPseudo.cpp +++ b/llvm/lib/Target/X86/X86ExpandPseudo.cpp @@ -649,149 +649,6 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB,      MI.setDesc(TII->get(Opc));      return true;    } -  // TILEPAIRLOAD is just for TILEPair spill, we don't have corresponding -  // AMX instruction to support it. So, split it to 2 load instructions: -  // "TILEPAIRLOAD TMM0:TMM1, Base, Scale, Index, Offset, Segment" --> -  // "TILELOAD TMM0, Base, Scale, Index, Offset, Segment" + -  // "TILELOAD TMM1, Base, Scale, Index, Offset + TMM_SIZE, Segment" -  case X86::PTILEPAIRLOAD: { -    int64_t Disp = MBBI->getOperand(1 + X86::AddrDisp).getImm(); -    Register TReg = MBBI->getOperand(0).getReg(); -    bool DstIsDead = MBBI->getOperand(0).isDead(); -    Register TReg0 = TRI->getSubReg(TReg, X86::sub_t0); -    Register TReg1 = TRI->getSubReg(TReg, X86::sub_t1); -    unsigned TmmSize = TRI->getRegSizeInBits(X86::TILERegClass) / 8; - -    MachineInstrBuilder MIBLo = -        BuildMI(MBB, MBBI, DL, TII->get(X86::TILELOADD)) -            .addReg(TReg0, RegState::Define | getDeadRegState(DstIsDead)); -    MachineInstrBuilder MIBHi = -        BuildMI(MBB, MBBI, DL, TII->get(X86::TILELOADD)) -            .addReg(TReg1, RegState::Define | getDeadRegState(DstIsDead)); - -    for (int i = 0; i < X86::AddrNumOperands; ++i) { -      MIBLo.add(MBBI->getOperand(1 + i)); -      if (i == X86::AddrDisp) -        MIBHi.addImm(Disp + TmmSize); -      else -        MIBHi.add(MBBI->getOperand(1 + i)); -    } - -    // Make sure the first stride reg used in first tileload is alive. -    MachineOperand &Stride = -        MIBLo.getInstr()->getOperand(1 + X86::AddrIndexReg); -    Stride.setIsKill(false); - -    // Split the memory operand, adjusting the offset and size for the halves. -    MachineMemOperand *OldMMO = MBBI->memoperands().front(); -    MachineFunction *MF = MBB.getParent(); -    MachineMemOperand *MMOLo = MF->getMachineMemOperand(OldMMO, 0, TmmSize); -    MachineMemOperand *MMOHi = -        MF->getMachineMemOperand(OldMMO, TmmSize, TmmSize); - -    MIBLo.setMemRefs(MMOLo); -    MIBHi.setMemRefs(MMOHi); - -    // Delete the pseudo. -    MBB.erase(MBBI); -    return true; -  } -  // Similar with TILEPAIRLOAD, TILEPAIRSTORE is just for TILEPair spill, no -  // corresponding AMX instruction to support it. 
So, split it too: -  // "TILEPAIRSTORE Base, Scale, Index, Offset, Segment, TMM0:TMM1" --> -  // "TILESTORE Base, Scale, Index, Offset, Segment, TMM0" + -  // "TILESTORE Base, Scale, Index, Offset + TMM_SIZE, Segment, TMM1" -  case X86::PTILEPAIRSTORE: { -    int64_t Disp = MBBI->getOperand(X86::AddrDisp).getImm(); -    Register TReg = MBBI->getOperand(X86::AddrNumOperands).getReg(); -    bool SrcIsKill = MBBI->getOperand(X86::AddrNumOperands).isKill(); -    Register TReg0 = TRI->getSubReg(TReg, X86::sub_t0); -    Register TReg1 = TRI->getSubReg(TReg, X86::sub_t1); -    unsigned TmmSize = TRI->getRegSizeInBits(X86::TILERegClass) / 8; - -    MachineInstrBuilder MIBLo = -        BuildMI(MBB, MBBI, DL, TII->get(X86::TILESTORED)); -    MachineInstrBuilder MIBHi = -        BuildMI(MBB, MBBI, DL, TII->get(X86::TILESTORED)); - -    for (int i = 0; i < X86::AddrNumOperands; ++i) { -      MIBLo.add(MBBI->getOperand(i)); -      if (i == X86::AddrDisp) -        MIBHi.addImm(Disp + TmmSize); -      else -        MIBHi.add(MBBI->getOperand(i)); -    } -    MIBLo.addReg(TReg0, getKillRegState(SrcIsKill)); -    MIBHi.addReg(TReg1, getKillRegState(SrcIsKill)); - -    // Make sure the first stride reg used in first tilestore is alive. -    MachineOperand &Stride = MIBLo.getInstr()->getOperand(X86::AddrIndexReg); -    Stride.setIsKill(false); - -    // Split the memory operand, adjusting the offset and size for the halves. -    MachineMemOperand *OldMMO = MBBI->memoperands().front(); -    MachineFunction *MF = MBB.getParent(); -    MachineMemOperand *MMOLo = MF->getMachineMemOperand(OldMMO, 0, TmmSize); -    MachineMemOperand *MMOHi = -        MF->getMachineMemOperand(OldMMO, TmmSize, TmmSize); - -    MIBLo.setMemRefs(MMOLo); -    MIBHi.setMemRefs(MMOHi); - -    // Delete the pseudo. -    MBB.erase(MBBI); -    return true; -  } -  case X86::PT2RPNTLVWZ0V: -  case X86::PT2RPNTLVWZ0T1V: -  case X86::PT2RPNTLVWZ1V: -  case X86::PT2RPNTLVWZ1T1V: -  case X86::PT2RPNTLVWZ0RSV: -  case X86::PT2RPNTLVWZ0RST1V: -  case X86::PT2RPNTLVWZ1RSV: -  case X86::PT2RPNTLVWZ1RST1V: { -    for (unsigned i = 3; i > 0; --i) -      MI.removeOperand(i); -    unsigned Opc; -    switch (Opcode) { -    case X86::PT2RPNTLVWZ0V: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0); -      break; -    case X86::PT2RPNTLVWZ0T1V: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0T1); -      break; -    case X86::PT2RPNTLVWZ1V: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1); -      break; -    case X86::PT2RPNTLVWZ1T1V: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1T1); -      break; -    case X86::PT2RPNTLVWZ0RSV: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0RS); -      break; -    case X86::PT2RPNTLVWZ0RST1V: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0RST1); -      break; -    case X86::PT2RPNTLVWZ1RSV: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1RS); -      break; -    case X86::PT2RPNTLVWZ1RST1V: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1RST1); -      break; -    default: -      llvm_unreachable("Impossible Opcode!"); -    } -    MI.setDesc(TII->get(Opc)); -    return true; -  } -  case X86::PTTRANSPOSEDV: -  case X86::PTCONJTFP16V: { -    for (int i = 2; i > 0; --i) -      MI.removeOperand(i); -    MI.setDesc(TII->get(Opcode == X86::PTTRANSPOSEDV ? 
X86::TTRANSPOSED -                                                     : X86::TCONJTFP16)); -    return true; -  }    case X86::PTCMMIMFP16PSV:    case X86::PTCMMRLFP16PSV:    case X86::PTDPBSSDV: @@ -800,13 +657,7 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB,    case X86::PTDPBUUDV:    case X86::PTDPBF16PSV:    case X86::PTDPFP16PSV: -  case X86::PTTDPBF16PSV: -  case X86::PTTDPFP16PSV: -  case X86::PTTCMMIMFP16PSV: -  case X86::PTTCMMRLFP16PSV: -  case X86::PTCONJTCMMIMFP16PSV:    case X86::PTMMULTF32PSV: -  case X86::PTTMMULTF32PSV:    case X86::PTDPBF8PSV:    case X86::PTDPBHF8PSV:    case X86::PTDPHBF8PSV: @@ -816,6 +667,7 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB,        MI.removeOperand(i);      unsigned Opc;      switch (Opcode) { +      // clang-format off      case X86::PTCMMIMFP16PSV:  Opc = X86::TCMMIMFP16PS; break;      case X86::PTCMMRLFP16PSV:  Opc = X86::TCMMRLFP16PS; break;      case X86::PTDPBSSDV:   Opc = X86::TDPBSSD; break; @@ -824,40 +676,12 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB,      case X86::PTDPBUUDV:   Opc = X86::TDPBUUD; break;      case X86::PTDPBF16PSV: Opc = X86::TDPBF16PS; break;      case X86::PTDPFP16PSV: Opc = X86::TDPFP16PS; break; -    case X86::PTTDPBF16PSV: -      Opc = X86::TTDPBF16PS; -      break; -    case X86::PTTDPFP16PSV: -      Opc = X86::TTDPFP16PS; -      break; -    case X86::PTTCMMIMFP16PSV: -      Opc = X86::TTCMMIMFP16PS; -      break; -    case X86::PTTCMMRLFP16PSV: -      Opc = X86::TTCMMRLFP16PS; -      break; -    case X86::PTCONJTCMMIMFP16PSV: -      Opc = X86::TCONJTCMMIMFP16PS; -      break; -    case X86::PTMMULTF32PSV: -      Opc = X86::TMMULTF32PS; -      break; -    case X86::PTTMMULTF32PSV: -      Opc = X86::TTMMULTF32PS; -      break; -    case X86::PTDPBF8PSV: -      Opc = X86::TDPBF8PS; -      break; -    case X86::PTDPBHF8PSV: -      Opc = X86::TDPBHF8PS; -      break; -    case X86::PTDPHBF8PSV: -      Opc = X86::TDPHBF8PS; -      break; -    case X86::PTDPHF8PSV: -      Opc = X86::TDPHF8PS; -      break; - +    case X86::PTMMULTF32PSV: Opc = X86::TMMULTF32PS; break; +    case X86::PTDPBF8PSV: Opc = X86::TDPBF8PS; break; +    case X86::PTDPBHF8PSV: Opc = X86::TDPBHF8PS; break; +    case X86::PTDPHBF8PSV: Opc = X86::TDPHBF8PS; break; +    case X86::PTDPHF8PSV: Opc = X86::TDPHF8PS; break; +    // clang-format on      default:        llvm_unreachable("Unexpected Opcode");      } diff --git a/llvm/lib/Target/X86/X86FastPreTileConfig.cpp b/llvm/lib/Target/X86/X86FastPreTileConfig.cpp index 787b71d..06f729a 100644 --- a/llvm/lib/Target/X86/X86FastPreTileConfig.cpp +++ b/llvm/lib/Target/X86/X86FastPreTileConfig.cpp @@ -267,24 +267,16 @@ void X86FastPreTileConfig::reload(MachineBasicBlock::iterator UseMI,                      << printReg(TileReg, TRI) << '\n');  } -static unsigned getTileDefNum(MachineRegisterInfo *MRI, Register Reg) { -  if (Reg.isVirtual()) { -    unsigned RegClassID = MRI->getRegClass(Reg)->getID(); -    if (RegClassID == X86::TILERegClassID) -      return 1; -    if (RegClassID == X86::TILEPAIRRegClassID) -      return 2; -  } else { -    if (Reg >= X86::TMM0 && Reg <= X86::TMM7) -      return 1; -    if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7) -      return 2; +static bool isTileRegister(MachineRegisterInfo *MRI, Register Reg) { +  if (Reg.isVirtual() && +      (MRI->getRegClass(Reg)->getID() == X86::TILERegClassID)) { +    return true;    } -  return 0; -} -static bool isTileRegister(MachineRegisterInfo *MRI, Register VirtReg) { -  return 
getTileDefNum(MRI, VirtReg) > 0; +  if (Reg >= X86::TMM0 && Reg <= X86::TMM7) +    return true; + +  return false;  }  static bool isTileDef(MachineRegisterInfo *MRI, MachineInstr &MI) { @@ -296,7 +288,7 @@ static bool isTileDef(MachineRegisterInfo *MRI, MachineInstr &MI) {    if (!MO.isReg())      return false; -  return getTileDefNum(MRI, MO.getReg()) > 0; +  return isTileRegister(MRI, MO.getReg());  }  static ShapeT getShape(MachineRegisterInfo *MRI, Register TileReg) { @@ -636,19 +628,7 @@ bool X86FastPreTileConfig::configBasicBlock(MachineBasicBlock &MBB) {        else if (dominates(MBB, LastShapeMI, ColMI))          LastShapeMI = ColMI;      } -    unsigned TileDefNum = getTileDefNum(MRI, MI.getOperand(0).getReg()); -    if (TileDefNum > 1) { -      for (unsigned I = 1; I < TileDefNum; I++) { -        MachineOperand *ColxMO = &MI.getOperand(2 + I); -        MachineInstr *ColxMI = MRI->getVRegDef(ColxMO->getReg()); -        if (ColxMI->getParent() == &MBB) { -          if (!LastShapeMI) -            LastShapeMI = ColxMI; -          else if (dominates(MBB, LastShapeMI, ColxMI)) -            LastShapeMI = ColxMI; -        } -      } -    } +      // If there is user live out of the tilecfg, spill it and reload in      // before the user.      Register TileReg = MI.getOperand(0).getReg(); diff --git a/llvm/lib/Target/X86/X86FastTileConfig.cpp b/llvm/lib/Target/X86/X86FastTileConfig.cpp index 11d331b..d86ae36 100644 --- a/llvm/lib/Target/X86/X86FastTileConfig.cpp +++ b/llvm/lib/Target/X86/X86FastTileConfig.cpp @@ -77,14 +77,14 @@ INITIALIZE_PASS_BEGIN(X86FastTileConfig, DEBUG_TYPE,  INITIALIZE_PASS_END(X86FastTileConfig, DEBUG_TYPE,                      "Fast Tile Register Configure", false, false) -static unsigned getNumDefTiles(MachineRegisterInfo *MRI, MachineInstr &MI) { +static bool isTileDef(MachineRegisterInfo *MRI, MachineInstr &MI) {    // There is no phi instruction after register allocation.    assert(MI.isPHI() == false);    // The instruction must have 3 operands: tile def, row, col.    // It should be AMX pseudo instruction that have shape operand.    if (MI.isDebugInstr() || MI.isCopy() || MI.getNumOperands() < 3 ||        !MI.isPseudo()) -    return 0; +    return false;    MachineOperand &MO = MI.getOperand(0);    if (MO.isReg()) { @@ -93,24 +93,18 @@ static unsigned getNumDefTiles(MachineRegisterInfo *MRI, MachineInstr &MI) {      // register is not rewritten yet.      
if (Reg.isVirtual()) {        if (MRI->getRegClass(Reg)->getID() == X86::TILERegClassID) -        return 1; -      if (MRI->getRegClass(Reg)->getID() == X86::TILEPAIRRegClassID) -        return 2; +        return true;      }      if (Reg >= X86::TMM0 && Reg <= X86::TMM7) -      return 1; -    if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7) -      return 2; +      return true;    } -  return 0; +  return false;  }  static unsigned getTMMIndex(Register Reg) {    if (Reg >= X86::TMM0 && Reg <= X86::TMM7)      return Reg - X86::TMM0; -  if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7) -    return (Reg - X86::TMM0_TMM1) * 2;    llvm_unreachable("Invalid Tmm Reg!");  } @@ -120,17 +114,14 @@ bool X86FastTileConfig::configBasicBlock(MachineBasicBlock &MBB) {    bool Change = false;    SmallVector<std::pair<unsigned, ShapeT>, 6> ShapeInfos;    for (MachineInstr &MI : reverse(MBB)) { -    unsigned DefNum = getNumDefTiles(MRI, MI); -    if (DefNum == 0 && MI.getOpcode() != X86::PLDTILECFGV) +    if (!isTileDef(MRI, MI) && MI.getOpcode() != X86::PLDTILECFGV)        continue;      // AMX instructions that define tile register.      if (MI.getOpcode() != X86::PLDTILECFGV) {        MachineOperand &Row = MI.getOperand(1);        unsigned TMMIdx = getTMMIndex(MI.getOperand(0).getReg()); -      for (unsigned I = 0; I < DefNum; I++) { -        MachineOperand &Col = MI.getOperand(2 + I); -        ShapeInfos.push_back({TMMIdx + I, ShapeT(&Row, &Col)}); -      } +      MachineOperand &Col = MI.getOperand(2); +      ShapeInfos.push_back({TMMIdx, ShapeT(&Row, &Col)});      } else { // PLDTILECFGV        // Rewrite the shape information to memory. Stack slot should have        // been initialized to zero in pre config. diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp index 4393f6e..d4418c8 100644 --- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp +++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp @@ -337,23 +337,8 @@ namespace {      // lowering but before ISEL.      bool isAMXSDNode(SDNode *N) const {        // Check if N is AMX SDNode: -      // 1. check specific opcode since these carry MVT::Untyped instead of -      // x86amx_type; -      // 2. check result type; -      // 3. check operand type; -      switch (N->getOpcode()) { -      default: -        break; -      case X86::PT2RPNTLVWZ0V: -      case X86::PT2RPNTLVWZ0T1V: -      case X86::PT2RPNTLVWZ1V: -      case X86::PT2RPNTLVWZ1T1V: -      case X86::PT2RPNTLVWZ0RSV: -      case X86::PT2RPNTLVWZ0RST1V: -      case X86::PT2RPNTLVWZ1RSV: -      case X86::PT2RPNTLVWZ1RST1V: -        return true; -      } +      // 1. check result type; +      // 2. 
check operand type;        for (unsigned Idx = 0, E = N->getNumValues(); Idx != E; ++Idx) {          if (N->getValueType(Idx) == MVT::x86amx)            return true; @@ -5398,65 +5383,6 @@ void X86DAGToDAGISel::Select(SDNode *Node) {        ReplaceNode(Node, CNode);        return;      } -    case Intrinsic::x86_t2rpntlvwz0rs: -    case Intrinsic::x86_t2rpntlvwz0rst1: -    case Intrinsic::x86_t2rpntlvwz1rs: -    case Intrinsic::x86_t2rpntlvwz1rst1: -      if (!Subtarget->hasAMXMOVRS()) -        break; -      [[fallthrough]]; -    case Intrinsic::x86_t2rpntlvwz0: -    case Intrinsic::x86_t2rpntlvwz0t1: -    case Intrinsic::x86_t2rpntlvwz1: -    case Intrinsic::x86_t2rpntlvwz1t1: { -      if (!Subtarget->hasAMXTRANSPOSE()) -        break; -      auto *MFI = -          CurDAG->getMachineFunction().getInfo<X86MachineFunctionInfo>(); -      MFI->setAMXProgModel(AMXProgModelEnum::DirectReg); -      unsigned Opc; -      switch (IntNo) { -      default: -        llvm_unreachable("Unexpected intrinsic!"); -      case Intrinsic::x86_t2rpntlvwz0: -        Opc = X86::PT2RPNTLVWZ0; -        break; -      case Intrinsic::x86_t2rpntlvwz0t1: -        Opc = X86::PT2RPNTLVWZ0T1; -        break; -      case Intrinsic::x86_t2rpntlvwz1: -        Opc = X86::PT2RPNTLVWZ1; -        break; -      case Intrinsic::x86_t2rpntlvwz1t1: -        Opc = X86::PT2RPNTLVWZ1T1; -        break; -      case Intrinsic::x86_t2rpntlvwz0rs: -        Opc = X86::PT2RPNTLVWZ0RS; -        break; -      case Intrinsic::x86_t2rpntlvwz0rst1: -        Opc = X86::PT2RPNTLVWZ0RST1; -        break; -      case Intrinsic::x86_t2rpntlvwz1rs: -        Opc = X86::PT2RPNTLVWZ1RS; -        break; -      case Intrinsic::x86_t2rpntlvwz1rst1: -        Opc = X86::PT2RPNTLVWZ1RST1; -        break; -      } -      // FIXME: Match displacement and scale. -      unsigned TIndex = Node->getConstantOperandVal(2); -      SDValue TReg = getI8Imm(TIndex, dl); -      SDValue Base = Node->getOperand(3); -      SDValue Scale = getI8Imm(1, dl); -      SDValue Index = Node->getOperand(4); -      SDValue Disp = CurDAG->getTargetConstant(0, dl, MVT::i32); -      SDValue Segment = CurDAG->getRegister(0, MVT::i16); -      SDValue Chain = Node->getOperand(0); -      SDValue Ops[] = {TReg, Base, Scale, Index, Disp, Segment, Chain}; -      MachineSDNode *CNode = CurDAG->getMachineNode(Opc, dl, MVT::Other, Ops); -      ReplaceNode(Node, CNode); -      return; -    }      }      break;    } diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 5785440..007074c 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -12213,7 +12213,7 @@ static int matchShuffleAsShift(MVT &ShiftVT, unsigned &Opcode,      MVT ShiftSVT = MVT::getIntegerVT(ScalarSizeInBits * Scale);      ShiftVT = ByteShift ? 
MVT::getVectorVT(MVT::i8, SizeInBits / 8)                          : MVT::getVectorVT(ShiftSVT, Size / Scale); -    return (int)ShiftAmt; +    return ShiftAmt;    };    // SSE/AVX supports logical shifts up to 64-bit integers - so we can just @@ -27946,67 +27946,6 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget,        return DAG.getNode(ISD::MERGE_VALUES, dl, Op->getVTList(), SetCC,                           Operation.getValue(1));      } -    case Intrinsic::x86_t2rpntlvwz0rs_internal: -    case Intrinsic::x86_t2rpntlvwz0rst1_internal: -    case Intrinsic::x86_t2rpntlvwz1rs_internal: -    case Intrinsic::x86_t2rpntlvwz1rst1_internal: -    case Intrinsic::x86_t2rpntlvwz0_internal: -    case Intrinsic::x86_t2rpntlvwz0t1_internal: -    case Intrinsic::x86_t2rpntlvwz1_internal: -    case Intrinsic::x86_t2rpntlvwz1t1_internal: { -      auto *X86MFI = DAG.getMachineFunction().getInfo<X86MachineFunctionInfo>(); -      X86MFI->setAMXProgModel(AMXProgModelEnum::ManagedRA); -      unsigned IntNo = Op.getConstantOperandVal(1); -      unsigned Opc = 0; -      switch (IntNo) { -      default: -        llvm_unreachable("Unexpected intrinsic!"); -      case Intrinsic::x86_t2rpntlvwz0_internal: -        Opc = X86::PT2RPNTLVWZ0V; -        break; -      case Intrinsic::x86_t2rpntlvwz0t1_internal: -        Opc = X86::PT2RPNTLVWZ0T1V; -        break; -      case Intrinsic::x86_t2rpntlvwz1_internal: -        Opc = X86::PT2RPNTLVWZ1V; -        break; -      case Intrinsic::x86_t2rpntlvwz1t1_internal: -        Opc = X86::PT2RPNTLVWZ1T1V; -        break; -      case Intrinsic::x86_t2rpntlvwz0rs_internal: -        Opc = X86::PT2RPNTLVWZ0RSV; -        break; -      case Intrinsic::x86_t2rpntlvwz0rst1_internal: -        Opc = X86::PT2RPNTLVWZ0RST1V; -        break; -      case Intrinsic::x86_t2rpntlvwz1rs_internal: -        Opc = X86::PT2RPNTLVWZ1RSV; -        break; -      case Intrinsic::x86_t2rpntlvwz1rst1_internal: -        Opc = X86::PT2RPNTLVWZ1RST1V; -        break; -      } - -      SDLoc DL(Op); -      SDVTList VTs = DAG.getVTList(MVT::Untyped, MVT::Other); - -      SDValue Ops[] = {Op.getOperand(2),                       // Row -                       Op.getOperand(3),                       // Col0 -                       Op.getOperand(4),                       // Col1 -                       Op.getOperand(5),                       // Base -                       DAG.getTargetConstant(1, DL, MVT::i8),  // Scale -                       Op.getOperand(6),                       // Index -                       DAG.getTargetConstant(0, DL, MVT::i32), // Disp -                       DAG.getRegister(0, MVT::i16),           // Segment -                       Op.getOperand(0)};                      // Chain - -      MachineSDNode *Res = DAG.getMachineNode(Opc, DL, VTs, Ops); -      SDValue Res0 = DAG.getTargetExtractSubreg(X86::sub_t0, DL, MVT::x86amx, -                                                SDValue(Res, 0)); -      SDValue Res1 = DAG.getTargetExtractSubreg(X86::sub_t1, DL, MVT::x86amx, -                                                SDValue(Res, 0)); -      return DAG.getMergeValues({Res0, Res1, SDValue(Res, 1)}, DL); -    }      case Intrinsic::x86_atomic_bts_rm:      case Intrinsic::x86_atomic_btc_rm:      case Intrinsic::x86_atomic_btr_rm: { @@ -37745,10 +37684,6 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,      assert (Imm < 8 && "Illegal tmm index");      return X86::TMM0 + Imm;    }; -  auto TMMImmToTMMPair = [](unsigned Imm) { -    
assert(Imm < 8 && "Illegal tmm pair index."); -    return X86::TMM0_TMM1 + Imm / 2; -  };    switch (MI.getOpcode()) {    default:      llvm_unreachable("Unexpected instr type to insert"); @@ -38129,53 +38064,25 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,    case X86::PTDPBHF8PS:    case X86::PTDPHBF8PS:    case X86::PTDPHF8PS: -  case X86::PTTDPBF16PS: -  case X86::PTTDPFP16PS: -  case X86::PTTCMMIMFP16PS: -  case X86::PTTCMMRLFP16PS: -  case X86::PTCONJTCMMIMFP16PS: -  case X86::PTMMULTF32PS: -  case X86::PTTMMULTF32PS: { +  case X86::PTMMULTF32PS: {      unsigned Opc;      switch (MI.getOpcode()) {      default: llvm_unreachable("illegal opcode!"); +      // clang-format off      case X86::PTDPBSSD: Opc = X86::TDPBSSD; break;      case X86::PTDPBSUD: Opc = X86::TDPBSUD; break;      case X86::PTDPBUSD: Opc = X86::TDPBUSD; break;      case X86::PTDPBUUD: Opc = X86::TDPBUUD; break;      case X86::PTDPBF16PS: Opc = X86::TDPBF16PS; break;      case X86::PTDPFP16PS: Opc = X86::TDPFP16PS; break; -    case X86::PTCMMIMFP16PS: -      Opc = X86::TCMMIMFP16PS; -      break; -    case X86::PTCMMRLFP16PS: -      Opc = X86::TCMMRLFP16PS; -      break; +    case X86::PTCMMIMFP16PS: Opc = X86::TCMMIMFP16PS; break; +    case X86::PTCMMRLFP16PS: Opc = X86::TCMMRLFP16PS; break;      case X86::PTDPBF8PS: Opc = X86::TDPBF8PS; break;      case X86::PTDPBHF8PS: Opc = X86::TDPBHF8PS; break;      case X86::PTDPHBF8PS: Opc = X86::TDPHBF8PS; break;      case X86::PTDPHF8PS: Opc = X86::TDPHF8PS; break; -    case X86::PTTDPBF16PS: -      Opc = X86::TTDPBF16PS; -      break; -    case X86::PTTDPFP16PS: -      Opc = X86::TTDPFP16PS; -      break; -    case X86::PTTCMMIMFP16PS: -      Opc = X86::TTCMMIMFP16PS; -      break; -    case X86::PTTCMMRLFP16PS: -      Opc = X86::TTCMMRLFP16PS; -      break; -    case X86::PTCONJTCMMIMFP16PS: -      Opc = X86::TCONJTCMMIMFP16PS; -      break; -    case X86::PTMMULTF32PS: -      Opc = X86::TMMULTF32PS; -      break; -    case X86::PTTMMULTF32PS: -      Opc = X86::TTMMULTF32PS; -      break; +    case X86::PTMMULTF32PS: Opc = X86::TMMULTF32PS; break; +      // clang-format on      }      MachineInstrBuilder MIB = BuildMI(*BB, MI, MIMD, TII->get(Opc)); @@ -38246,70 +38153,6 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,      MI.eraseFromParent(); // The pseudo is gone now.      return BB;    } -  case X86::PT2RPNTLVWZ0: -  case X86::PT2RPNTLVWZ0T1: -  case X86::PT2RPNTLVWZ1: -  case X86::PT2RPNTLVWZ1T1: -  case X86::PT2RPNTLVWZ0RS: -  case X86::PT2RPNTLVWZ0RST1: -  case X86::PT2RPNTLVWZ1RS: -  case X86::PT2RPNTLVWZ1RST1: { -    const DebugLoc &DL = MI.getDebugLoc(); -    unsigned Opc; -#define GET_EGPR_IF_ENABLED(OPC) (Subtarget.hasEGPR() ? 
OPC##_EVEX : OPC) -    switch (MI.getOpcode()) { -    default: -      llvm_unreachable("Unexpected instruction!"); -    case X86::PT2RPNTLVWZ0: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0); -      break; -    case X86::PT2RPNTLVWZ0T1: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0T1); -      break; -    case X86::PT2RPNTLVWZ1: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1); -      break; -    case X86::PT2RPNTLVWZ1T1: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1T1); -      break; -    case X86::PT2RPNTLVWZ0RS: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0RS); -      break; -    case X86::PT2RPNTLVWZ0RST1: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ0RST1); -      break; -    case X86::PT2RPNTLVWZ1RS: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1RS); -      break; -    case X86::PT2RPNTLVWZ1RST1: -      Opc = GET_EGPR_IF_ENABLED(X86::T2RPNTLVWZ1RST1); -      break; -    } -#undef GET_EGPR_IF_ENABLED -    MachineInstrBuilder MIB = BuildMI(*BB, MI, DL, TII->get(Opc)); -    MIB.addReg(TMMImmToTMMPair(MI.getOperand(0).getImm()), RegState::Define); - -    MIB.add(MI.getOperand(1)); // base -    MIB.add(MI.getOperand(2)); // scale -    MIB.add(MI.getOperand(3)); // index -    MIB.add(MI.getOperand(4)); // displacement -    MIB.add(MI.getOperand(5)); // segment -    MI.eraseFromParent();      // The pseudo is gone now. -    return BB; -  } -  case X86::PTTRANSPOSED: -  case X86::PTCONJTFP16: { -    const DebugLoc &DL = MI.getDebugLoc(); -    unsigned Opc = MI.getOpcode() == X86::PTTRANSPOSED ? X86::TTRANSPOSED -                                                       : X86::TCONJTFP16; - -    MachineInstrBuilder MIB = BuildMI(*BB, MI, DL, TII->get(Opc)); -    MIB.addReg(TMMImmToTMMReg(MI.getOperand(0).getImm()), RegState::Define); -    MIB.addReg(TMMImmToTMMReg(MI.getOperand(1).getImm()), RegState::Undef); - -    MI.eraseFromParent(); // The pseudo is gone now. -    return BB; -  }    case X86::PTCVTROWPS2BF16Hrri:    case X86::PTCVTROWPS2BF16Lrri:    case X86::PTCVTROWPS2PHHrri: @@ -48778,15 +48621,19 @@ static SDValue combinePTESTCC(SDValue EFLAGS, X86::CondCode &CC,        SDValue BC0 = peekThroughBitcasts(Op0);        if (BC0.getOpcode() == X86ISD::PCMPEQ &&            ISD::isBuildVectorAllZeros(BC0.getOperand(1).getNode())) { -        SDLoc DL(EFLAGS);          CC = (CC == X86::COND_B ? X86::COND_E : X86::COND_NE); -        SDValue X = DAG.getBitcast(OpVT, BC0.getOperand(0)); -        return DAG.getNode(EFLAGS.getOpcode(), DL, VT, X, X); +        SDValue X = DAG.getBitcast(OpVT, DAG.getFreeze(BC0.getOperand(0))); +        return DAG.getNode(EFLAGS.getOpcode(), SDLoc(EFLAGS), VT, X, X);        }      }    }    if (CC == X86::COND_E || CC == X86::COND_NE) { +    // Canonicalize constant to RHS if we're just using ZF. +    if (Op0 != Op1 && DAG.isConstantIntBuildVectorOrConstantInt(Op0) && +        !DAG.isConstantIntBuildVectorOrConstantInt(Op1)) +      return DAG.getNode(EFLAGS.getOpcode(), SDLoc(EFLAGS), VT, Op1, Op0); +      // TESTZ(X,~Y) == TESTC(Y,X)      if (SDValue NotOp1 = IsNOT(Op1, DAG)) {        CC = (CC == X86::COND_E ? 
X86::COND_B : X86::COND_AE); @@ -48832,7 +48679,7 @@ static SDValue combinePTESTCC(SDValue EFLAGS, X86::CondCode &CC,                MVT FloatSVT = MVT::getFloatingPointVT(EltBits);                MVT FloatVT =                    MVT::getVectorVT(FloatSVT, OpVT.getSizeInBits() / EltBits); -              Res = DAG.getBitcast(FloatVT, Res); +              Res = DAG.getBitcast(FloatVT, DAG.getFreeze(Res));                return DAG.getNode(X86ISD::TESTP, SDLoc(EFLAGS), VT, Res, Res);              } else if (EltBits == 16) {                MVT MovmskVT = BCVT.is128BitVector() ? MVT::v16i8 : MVT::v32i8; @@ -48850,13 +48697,31 @@ static SDValue combinePTESTCC(SDValue EFLAGS, X86::CondCode &CC,        }      } -    // TESTZ(-1,X) == TESTZ(X,X) -    if (ISD::isBuildVectorAllOnes(Op0.getNode())) -      return DAG.getNode(EFLAGS.getOpcode(), SDLoc(EFLAGS), VT, Op1, Op1); -      // TESTZ(X,-1) == TESTZ(X,X) -    if (ISD::isBuildVectorAllOnes(Op1.getNode())) +    if (ISD::isBuildVectorAllOnes(Op1.getNode())) { +      Op0 = DAG.getFreeze(Op0);        return DAG.getNode(EFLAGS.getOpcode(), SDLoc(EFLAGS), VT, Op0, Op0); +    } + +    // Attempt to convert PTESTZ(X,SIGNMASK) -> VTESTPD/PSZ(X,X) on AVX targets. +    if (EFLAGS.getOpcode() == X86ISD::PTEST && Subtarget.hasAVX()) { +      KnownBits KnownOp1 = DAG.computeKnownBits(Op1); +      assert(KnownOp1.getBitWidth() == 64 && +             "Illegal PTEST vector element width"); +      if (KnownOp1.isConstant()) { +        const APInt &Mask = KnownOp1.getConstant(); +        if (Mask.isSignMask()) { +          MVT FpVT = MVT::getVectorVT(MVT::f64, OpVT.getSizeInBits() / 64); +          Op0 = DAG.getBitcast(FpVT, DAG.getFreeze(Op0)); +          return DAG.getNode(X86ISD::TESTP, SDLoc(EFLAGS), VT, Op0, Op0); +        } +        if (Mask.isSplat(32) && Mask.trunc(32).isSignMask()) { +          MVT FpVT = MVT::getVectorVT(MVT::f32, OpVT.getSizeInBits() / 32); +          Op0 = DAG.getBitcast(FpVT, DAG.getFreeze(Op0)); +          return DAG.getNode(X86ISD::TESTP, SDLoc(EFLAGS), VT, Op0, Op0); +        } +      } +    }      // TESTZ(OR(LO(X),HI(X)),OR(LO(Y),HI(Y))) -> TESTZ(X,Y)      // TODO: Add COND_NE handling? @@ -53479,6 +53344,105 @@ static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG,    return SDValue();  } +// Look for a RMW operation that only touches one bit of a larger than legal +// type and fold it to a BTC/BTR/BTS or bit insertion pattern acting on a single +// i32 sub value. +static SDValue narrowBitOpRMW(StoreSDNode *St, const SDLoc &DL, +                              SelectionDAG &DAG, +                              const X86Subtarget &Subtarget) { +  using namespace SDPatternMatch; + +  // Only handle normal stores and its chain was a matching normal load. +  auto *Ld = dyn_cast<LoadSDNode>(St->getChain()); +  if (!ISD::isNormalStore(St) || !St->isSimple() || !Ld || +      !ISD::isNormalLoad(Ld) || !Ld->isSimple() || +      Ld->getBasePtr() != St->getBasePtr() || +      Ld->getOffset() != St->getOffset()) +    return SDValue(); + +  SDValue LoadVal(Ld, 0); +  SDValue StoredVal = St->getValue(); +  EVT VT = StoredVal.getValueType(); + +  // Only narrow larger than legal scalar integers. +  if (!VT.isScalarInteger() || +      VT.getSizeInBits() <= (Subtarget.is64Bit() ? 
64 : 32)) +    return SDValue(); + +  // BTR: X & ~(1 << ShAmt) +  // BTS: X | (1 << ShAmt) +  // BTC: X ^ (1 << ShAmt) +  // +  // BitInsert: (X & ~(1 << ShAmt)) | (InsertBit << ShAmt) +  SDValue InsertBit, ShAmt; +  if (!StoredVal.hasOneUse() || +      !(sd_match(StoredVal, m_And(m_Specific(LoadVal), +                                  m_Not(m_Shl(m_One(), m_Value(ShAmt))))) || +        sd_match(StoredVal, +                 m_Or(m_Specific(LoadVal), m_Shl(m_One(), m_Value(ShAmt)))) || +        sd_match(StoredVal, +                 m_Xor(m_Specific(LoadVal), m_Shl(m_One(), m_Value(ShAmt)))) || +        sd_match(StoredVal, +                 m_Or(m_And(m_Specific(LoadVal), +                            m_Not(m_Shl(m_One(), m_Value(ShAmt)))), +                      m_Shl(m_Value(InsertBit), m_Deferred(ShAmt)))))) +    return SDValue(); + +  // Ensure the shift amount is in bounds. +  KnownBits KnownAmt = DAG.computeKnownBits(ShAmt); +  if (KnownAmt.getMaxValue().uge(VT.getSizeInBits())) +    return SDValue(); + +  // If we're inserting a bit then it must be the LSB. +  if (InsertBit) { +    KnownBits KnownInsert = DAG.computeKnownBits(InsertBit); +    if (KnownInsert.countMinLeadingZeros() < (VT.getSizeInBits() - 1)) +      return SDValue(); +  } + +  // Split the shift into an alignment shift that moves the active i32 block to +  // the bottom bits for truncation and a modulo shift that can act on the i32. +  EVT AmtVT = ShAmt.getValueType(); +  SDValue AlignAmt = DAG.getNode(ISD::AND, DL, AmtVT, ShAmt, +                                 DAG.getSignedConstant(-32LL, DL, AmtVT)); +  SDValue ModuloAmt = +      DAG.getNode(ISD::AND, DL, AmtVT, ShAmt, DAG.getConstant(31, DL, AmtVT)); +  ModuloAmt = DAG.getZExtOrTrunc(ModuloAmt, DL, MVT::i8); + +  // Compute the byte offset for the i32 block that is changed by the RMW. +  // combineTruncate will adjust the load for us in a similar way. +  EVT PtrVT = St->getBasePtr().getValueType(); +  SDValue PtrBitOfs = DAG.getZExtOrTrunc(AlignAmt, DL, PtrVT); +  SDValue PtrByteOfs = DAG.getNode(ISD::SRL, DL, PtrVT, PtrBitOfs, +                                   DAG.getShiftAmountConstant(3, PtrVT, DL)); +  SDValue NewPtr = DAG.getMemBasePlusOffset(St->getBasePtr(), PtrByteOfs, DL, +                                            SDNodeFlags::NoUnsignedWrap); + +  // Reconstruct the BTC/BTR/BTS pattern for the i32 block and store. 
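For illustration, a standalone C++ sketch of what narrowBitOpRMW achieves at the source level (the function names and the 128-bit type are assumptions for this example, not taken from the patch; it assumes Idx is in bounds and x86's little-endian word order):

#include <cstdint>

// Pattern the combine matches: full-width load, single-bit BTS-style update,
// full-width store of a wider-than-legal integer.
void set_bit_wide(unsigned __int128 *P, unsigned Idx) {
  *P |= (unsigned __int128)1 << Idx;
}

// Narrowed equivalent: only the 32-bit word holding the bit is rewritten,
// using the same (Idx & -32) / (Idx & 31) split as the surrounding code.
void set_bit_narrow(uint32_t *Words, unsigned Idx) {
  Words[(Idx & ~31u) / 32] |= 1u << (Idx & 31u);
}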
+  SDValue X = DAG.getNode(ISD::SRL, DL, VT, LoadVal, AlignAmt); +  X = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, X); + +  SDValue Mask = DAG.getNode(ISD::SHL, DL, MVT::i32, +                             DAG.getConstant(1, DL, MVT::i32), ModuloAmt); + +  SDValue Res; +  if (InsertBit) { +    SDValue BitMask = +        DAG.getNode(ISD::SHL, DL, MVT::i32, +                    DAG.getZExtOrTrunc(InsertBit, DL, MVT::i32), ModuloAmt); +    Res = +        DAG.getNode(ISD::AND, DL, MVT::i32, X, DAG.getNOT(DL, Mask, MVT::i32)); +    Res = DAG.getNode(ISD::OR, DL, MVT::i32, Res, BitMask); +  } else { +    if (StoredVal.getOpcode() == ISD::AND) +      Mask = DAG.getNOT(DL, Mask, MVT::i32); +    Res = DAG.getNode(StoredVal.getOpcode(), DL, MVT::i32, X, Mask); +  } + +  return DAG.getStore(St->getChain(), DL, Res, NewPtr, St->getPointerInfo(), +                      Align(), St->getMemOperand()->getFlags()); +} +  static SDValue combineStore(SDNode *N, SelectionDAG &DAG,                              TargetLowering::DAGCombinerInfo &DCI,                              const X86Subtarget &Subtarget) { @@ -53705,6 +53669,9 @@ static SDValue combineStore(SDNode *N, SelectionDAG &DAG,      }    } +  if (SDValue R = narrowBitOpRMW(St, dl, DAG, Subtarget)) +    return R; +    // Convert store(cmov(load(p), x, CC), p) to cstore(x, p, CC)    //         store(cmov(x, load(p), CC), p) to cstore(x, p, InvertCC)    if ((VT == MVT::i16 || VT == MVT::i32 || VT == MVT::i64) && @@ -54492,6 +54459,7 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,  static SDValue detectPMADDUBSW(SDValue In, EVT VT, SelectionDAG &DAG,                                 const X86Subtarget &Subtarget,                                 const SDLoc &DL) { +  using namespace SDPatternMatch;    if (!VT.isVector() || !Subtarget.hasSSSE3())      return SDValue(); @@ -54501,42 +54469,19 @@ static SDValue detectPMADDUBSW(SDValue In, EVT VT, SelectionDAG &DAG,      return SDValue();    SDValue SSatVal = detectSSatPattern(In, VT); -  if (!SSatVal || SSatVal.getOpcode() != ISD::ADD) +  if (!SSatVal)      return SDValue(); -  // Ok this is a signed saturation of an ADD. See if this ADD is adding pairs -  // of multiplies from even/odd elements. -  SDValue N0 = SSatVal.getOperand(0); -  SDValue N1 = SSatVal.getOperand(1); - -  if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL) -    return SDValue(); - -  SDValue N00 = N0.getOperand(0); -  SDValue N01 = N0.getOperand(1); -  SDValue N10 = N1.getOperand(0); -  SDValue N11 = N1.getOperand(1); - +  // See if this is a signed saturation of an ADD, adding pairs of multiplies +  // from even/odd elements, from zero_extend/sign_extend operands. +  //    // TODO: Handle constant vectors and use knownbits/computenumsignbits? -  // Canonicalize zero_extend to LHS. -  if (N01.getOpcode() == ISD::ZERO_EXTEND) -    std::swap(N00, N01); -  if (N11.getOpcode() == ISD::ZERO_EXTEND) -    std::swap(N10, N11); - -  // Ensure we have a zero_extend and a sign_extend. -  if (N00.getOpcode() != ISD::ZERO_EXTEND || -      N01.getOpcode() != ISD::SIGN_EXTEND || -      N10.getOpcode() != ISD::ZERO_EXTEND || -      N11.getOpcode() != ISD::SIGN_EXTEND) +  SDValue N00, N01, N10, N11; +  if (!sd_match(SSatVal, +                m_Add(m_Mul(m_ZExt(m_Value(N00)), m_SExt(m_Value(N01))), +                      m_Mul(m_ZExt(m_Value(N10)), m_SExt(m_Value(N11))))))      return SDValue(); -  // Peek through the extends. 
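For reference, a scalar C++ sketch of the arithmetic the rewritten detectPMADDUBSW matcher looks for: a signed-saturating add of two products whose factors are a zero-extended i8 and a sign-extended i8, i.e. one 16-bit lane of PMADDUBSW (the function name is an assumption for this example):

#include <algorithm>
#include <cstdint>

// One lane: u8*s8 products of an even/odd element pair, summed and
// saturated to i16.
int16_t pmaddubsw_lane(uint8_t A0, int8_t B0, uint8_t A1, int8_t B1) {
  int32_t Sum = int32_t(A0) * B0 + int32_t(A1) * B1;              // zext * sext pairs
  return int16_t(std::clamp<int32_t>(Sum, INT16_MIN, INT16_MAX)); // signed saturation
}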
-  N00 = N00.getOperand(0); -  N01 = N01.getOperand(0); -  N10 = N10.getOperand(0); -  N11 = N11.getOperand(0); -    // Ensure the extend is from vXi8.    if (N00.getValueType().getVectorElementType() != MVT::i8 ||        N01.getValueType().getVectorElementType() != MVT::i8 || @@ -54659,8 +54604,9 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,    // truncation, see if we can convert the shift into a pointer offset instead.    // Limit this to normal (non-ext) scalar integer loads.    if (SrcVT.isScalarInteger() && Src.getOpcode() == ISD::SRL && -      Src.hasOneUse() && Src.getOperand(0).hasOneUse() && -      ISD::isNormalLoad(Src.getOperand(0).getNode())) { +      Src.hasOneUse() && ISD::isNormalLoad(Src.getOperand(0).getNode()) && +      (Src.getOperand(0).hasOneUse() || +       !DAG.getTargetLoweringInfo().isOperationLegal(ISD::LOAD, SrcVT))) {      auto *Ld = cast<LoadSDNode>(Src.getOperand(0));      if (Ld->isSimple() && VT.isByteSized() &&          isPowerOf2_64(VT.getSizeInBits())) { @@ -54668,9 +54614,11 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,        KnownBits KnownAmt = DAG.computeKnownBits(ShAmt);        // Check the shift amount is byte aligned.        // Check the truncation doesn't use any shifted in (zero) top bits. +      // Check the shift amount doesn't depend on the original load.        if (KnownAmt.countMinTrailingZeros() >= 3 &&            KnownAmt.getMaxValue().ule(SrcVT.getSizeInBits() - -                                     VT.getSizeInBits())) { +                                     VT.getSizeInBits()) && +          !Ld->isPredecessorOf(ShAmt.getNode())) {          EVT PtrVT = Ld->getBasePtr().getValueType();          SDValue PtrBitOfs = DAG.getZExtOrTrunc(ShAmt, DL, PtrVT);          SDValue PtrByteOfs = @@ -56458,6 +56406,7 @@ static SDValue combineAVX512SetCCToKMOV(EVT VT, SDValue Op0, ISD::CondCode CC,  static SDValue combineSetCC(SDNode *N, SelectionDAG &DAG,                              TargetLowering::DAGCombinerInfo &DCI,                              const X86Subtarget &Subtarget) { +  using namespace SDPatternMatch;    const ISD::CondCode CC = cast<CondCodeSDNode>(N->getOperand(2))->get();    const SDValue LHS = N->getOperand(0);    const SDValue RHS = N->getOperand(1); @@ -56516,6 +56465,37 @@ static SDValue combineSetCC(SDNode *N, SelectionDAG &DAG,        if (SDValue AndN = MatchAndCmpEq(RHS, LHS))          return DAG.getSetCC(DL, VT, AndN, DAG.getConstant(0, DL, OpVT), CC); +      // If we're performing a bit test on a larger than legal type, attempt +      // to (aligned) shift down the value to the bottom 32-bits and then +      // perform the bittest on the i32 value. +      // ICMP_ZERO(AND(X,SHL(1,IDX))) +      // --> ICMP_ZERO(AND(TRUNC(SRL(X,AND(IDX,-32))),SHL(1,AND(IDX,31)))) +      if (isNullConstant(RHS) && +          OpVT.getScalarSizeInBits() > (Subtarget.is64Bit() ? 64 : 32)) { +        SDValue X, ShAmt; +        if (sd_match(LHS, m_OneUse(m_And(m_Value(X), +                                         m_Shl(m_One(), m_Value(ShAmt)))))) { +          // Only attempt this if the shift amount is known to be in bounds. 
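An illustrative before/after for the new bit-test narrowing in combineSetCC (the 128-bit type and function names are assumptions for this example; Idx is assumed to be in bounds, matching the KnownBits check that follows):

#include <cstdint>

// Shape the combine matches: ICMP_ZERO(AND(X, SHL(1, Idx))) on a
// wider-than-legal integer.
bool bit_clear_wide(unsigned __int128 X, unsigned Idx) {
  return (X & ((unsigned __int128)1 << Idx)) == 0;
}

// Narrowed form: shift the aligned 32-bit chunk down, then do an i32 bit test.
bool bit_clear_narrow(unsigned __int128 X, unsigned Idx) {
  uint32_t Chunk = uint32_t(X >> (Idx & ~31u));
  return (Chunk & (1u << (Idx & 31u))) == 0;
}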
+          KnownBits KnownAmt = DAG.computeKnownBits(ShAmt); +          if (KnownAmt.getMaxValue().ult(OpVT.getScalarSizeInBits())) { +            EVT AmtVT = ShAmt.getValueType(); +            SDValue AlignAmt = +                DAG.getNode(ISD::AND, DL, AmtVT, ShAmt, +                            DAG.getSignedConstant(-32LL, DL, AmtVT)); +            SDValue ModuloAmt = DAG.getNode(ISD::AND, DL, AmtVT, ShAmt, +                                            DAG.getConstant(31, DL, AmtVT)); +            SDValue Mask = DAG.getNode( +                ISD::SHL, DL, MVT::i32, DAG.getConstant(1, DL, MVT::i32), +                DAG.getZExtOrTrunc(ModuloAmt, DL, MVT::i8)); +            X = DAG.getNode(ISD::SRL, DL, OpVT, X, AlignAmt); +            X = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, X); +            X = DAG.getNode(ISD::AND, DL, MVT::i32, X, Mask); +            return DAG.getSetCC(DL, VT, X, DAG.getConstant(0, DL, MVT::i32), +                                CC); +          } +        } +      } +        // cmpeq(trunc(x),C) --> cmpeq(x,C)        // cmpne(trunc(x),C) --> cmpne(x,C)        // iff x upper bits are zero. diff --git a/llvm/lib/Target/X86/X86InstrAMX.td b/llvm/lib/Target/X86/X86InstrAMX.td index 69a5115..522782a 100644 --- a/llvm/lib/Target/X86/X86InstrAMX.td +++ b/llvm/lib/Target/X86/X86InstrAMX.td @@ -338,188 +338,6 @@ let Predicates = [HasAMXFP8, In64BitMode] in {    }  } -let Predicates = [HasAMXTILE, In64BitMode], isPseudo = true, SchedRW = [WriteSystem] in { -  let mayStore = 1 in -  def PTILEPAIRSTORE : PseudoI<(outs), (ins opaquemem:$src1, TILEPair:$src2), []>; -  let mayLoad = 1 in -  def PTILEPAIRLOAD : PseudoI<(outs TILEPair:$dst), (ins opaquemem:$src), []>; -} - -multiclass T2RPNTLVW_Base<bits<8> op1, bits<8> op2, string rs, string suffix> { -  def Z0#rs#suffix    : I<op1, MRMSrcMemFSIB, (outs TILEPair:$dst), (ins sibmem:$src), -                          "t2rpntlvwz0" #!tolower(rs)# "\t{$src, $dst|$dst, $src}", []>, PS; -  def Z0#rs#T1#suffix : I<op2, MRMSrcMemFSIB, (outs TILEPair:$dst), (ins sibmem:$src), -                          "t2rpntlvwz0" #!tolower(rs)# "t1\t{$src, $dst|$dst, $src}", []>, PS; -  def Z1#rs#suffix    : I<op1, MRMSrcMemFSIB, (outs TILEPair:$dst), (ins sibmem:$src), -                          "t2rpntlvwz1" #!tolower(rs)# "\t{$src, $dst|$dst, $src}", []>, PD; -  def Z1#rs#T1#suffix : I<op2, MRMSrcMemFSIB, (outs TILEPair:$dst), (ins sibmem:$src), -                          "t2rpntlvwz1" #!tolower(rs)# "t1\t{$src, $dst|$dst, $src}", []>, PD; -} - -let Predicates = [HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in -  defm T2RPNTLVW : T2RPNTLVW_Base<0x6e, 0x6f, "", "">, T8, VEX; - -let Predicates = [HasAMXTRANSPOSE, HasEGPR, In64BitMode], SchedRW = [WriteSystem] in -  defm T2RPNTLVW : T2RPNTLVW_Base<0x6e, 0x6f, "", "_EVEX">, T8, EVEX, NoCD8; - -let Predicates = [HasAMXMOVRS, HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in -  defm T2RPNTLVW : T2RPNTLVW_Base<0xf8, 0xf9, "RS", "">, T_MAP5, VEX; - -let Predicates = [HasAMXMOVRS, HasAMXTRANSPOSE, HasEGPR, In64BitMode], SchedRW = [WriteSystem] in -  defm T2RPNTLVW : T2RPNTLVW_Base<0xf8, 0xf9, "RS", "_EVEX">, T_MAP5, EVEX, NoCD8; - -let Predicates = [HasAMXTRANSPOSE, In64BitMode] in { -  let SchedRW = [WriteSystem] in { -    def TTRANSPOSED : I<0x5f, MRMSrcReg, (outs TILE:$dst), (ins TILE:$src), -                        "ttransposed\t{$src, $dst|$dst, $src}", []>, VEX, T8, XS; -    let isPseudo = true in { -      def PT2RPNTLVWZ0V : PseudoI<(outs TILEPair:$dst), -                              
    (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), -                                  []>; -      def PT2RPNTLVWZ0T1V : PseudoI<(outs TILEPair:$dst), -                                  (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), -                                  []>; -      def PT2RPNTLVWZ1V : PseudoI<(outs TILEPair:$dst), -                                  (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), -                                  []>; -      def PT2RPNTLVWZ1T1V : PseudoI<(outs TILEPair:$dst), -                                  (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), -                                  []>; -    } - -    def PTTRANSPOSEDV : PseudoI<(outs TILE:$dst), -                                (ins GR16:$src1, GR16:$src2, TILE:$src), -                                [(set TILE: $dst, -                                 (int_x86_ttransposed_internal GR16:$src1, GR16:$src2, -                                  TILE:$src))]>; - -    let usesCustomInserter = 1 in { -      def PT2RPNTLVWZ0 : PseudoI<(outs), (ins u8imm:$dst, -                                 sibmem:$src1), []>; -      def PT2RPNTLVWZ0T1 : PseudoI<(outs), (ins u8imm:$dst, -                                   sibmem:$src1), []>; -      def PT2RPNTLVWZ1 : PseudoI<(outs), (ins u8imm:$dst, -                                 sibmem:$src1), []>; -      def PT2RPNTLVWZ1T1 : PseudoI<(outs), (ins u8imm:$dst, -                                   sibmem:$src1), []>; -      def PTTRANSPOSED : PseudoI<(outs), (ins u8imm:$dst, u8imm:$src), -                                 [(int_x86_ttransposed timm:$dst, timm:$src)]>; -    } -  } -} // HasAMXTILE, HasAMXTRANSPOSE - -let Predicates = [HasAMXBF16, HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in { -  let Constraints = "$src1 = $dst" in -    def TTDPBF16PS : I<0x6c, MRMSrcReg4VOp3, (outs TILE:$dst), -                       (ins TILE:$src1, TILE:$src2, TILE:$src3), -                       "ttdpbf16ps\t{$src3, $src2, $dst|$dst, $src2, $src3}", -                       []>, VEX, VVVV, T8,XS; -  let Constraints = "$src4 = $dst" in -    def PTTDPBF16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1, -                                GR16:$src2, GR16:$src3, TILE:$src4, -                                TILE:$src5, TILE:$src6), -                                [(set TILE: $dst, -                                  (int_x86_ttdpbf16ps_internal GR16:$src1, GR16:$src2, -                                   GR16:$src3, TILE:$src4, TILE:$src5, TILE:$src6))]>; -  let usesCustomInserter = 1 in -    def PTTDPBF16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3), -                              [(int_x86_ttdpbf16ps timm:$src1, timm:$src2, timm:$src3)]>; -} - -let Predicates = [HasAMXFP16, HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in { -  let Constraints = "$src1 = $dst" in -    def TTDPFP16PS : I<0x6c, MRMSrcReg4VOp3, (outs TILE:$dst), -                       (ins TILE:$src1, TILE:$src2, TILE:$src3), -                       "ttdpfp16ps\t{$src3, $src2, $dst|$dst, $src2, $src3}", -                       []>, VEX, VVVV, T8,XD; -  let Constraints = "$src4 = $dst" in -    def PTTDPFP16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1, -                                GR16:$src2, GR16:$src3, TILE:$src4, -                                TILE:$src5, TILE:$src6), -                                [(set TILE: $dst, -                                  (int_x86_ttdpfp16ps_internal GR16:$src1, GR16:$src2, -                                  
 GR16:$src3, TILE:$src4, TILE:$src5, TILE:$src6))]>; -  let usesCustomInserter = 1 in -    def PTTDPFP16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3), -                              [(int_x86_ttdpfp16ps timm:$src1, timm:$src2, timm:$src3)]>; -} - -let Predicates = [HasAMXCOMPLEX, HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in { -  let Constraints = "$src1 = $dst" in { -    def TTCMMIMFP16PS : I<0x6b, MRMSrcReg4VOp3, (outs TILE:$dst), -                          (ins TILE:$src1, TILE:$src2, TILE:$src3), -                          "ttcmmimfp16ps\t{$src3, $src2, $src1|$src1, $src2, $src3}", -                          []>, VEX, VVVV, T8,XD; -    def TTCMMRLFP16PS: I<0x6b, MRMSrcReg4VOp3, (outs TILE:$dst), -                         (ins TILE:$src1, TILE:$src2, TILE:$src3), -                         "ttcmmrlfp16ps\t{$src3, $src2, $src1|$src1, $src2, $src3}", -                         []>, VEX, VVVV, T8,XS; -    def TCONJTCMMIMFP16PS : I<0x6b, MRMSrcReg4VOp3, (outs TILE:$dst), -                          (ins TILE:$src1, TILE:$src2, TILE:$src3), -                          "tconjtcmmimfp16ps\t{$src3, $src2, $src1|$src1, $src2, $src3}", -                          []>, VEX, VVVV, WIG, T8,PS; -  } -  def TCONJTFP16 : I<0x6b, MRMSrcReg, (outs TILE:$dst), (ins TILE:$src), -                     "tconjtfp16\t{$src, $dst|$dst, $src}", []>, VEX, T8,PD; - -  let Constraints = "$src4 = $dst" in { -    def PTTCMMIMFP16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1, -                                  GR16:$src2, GR16:$src3, TILE:$src4, -                                  TILE:$src5, TILE:$src6), -                                  [(set TILE: $dst, -                                    (int_x86_ttcmmimfp16ps_internal GR16:$src1, GR16:$src2, -                                     GR16:$src3, TILE:$src4, TILE:$src5, TILE:$src6))]>; -    def PTTCMMRLFP16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1, -                                  GR16:$src2, GR16:$src3, TILE:$src4, -                                  TILE:$src5, TILE:$src6), -                                  [(set TILE: $dst, -                                    (int_x86_ttcmmrlfp16ps_internal GR16:$src1, GR16:$src2, -                                     GR16:$src3, TILE:$src4, TILE:$src5, TILE:$src6))]>; -    def PTCONJTCMMIMFP16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1, -                                      GR16:$src2, GR16:$src3, TILE:$src4, -                                      TILE:$src5, TILE:$src6), -                                      [(set TILE: $dst, -                                        (int_x86_tconjtcmmimfp16ps_internal GR16:$src1, GR16:$src2, -                                         GR16:$src3, TILE:$src4, TILE:$src5, TILE:$src6))]>; -  } -  def PTCONJTFP16V : PseudoI<(outs TILE:$dst), (ins GR16:$src1, GR16:$src2, TILE:$src3), -                             [(set TILE: $dst, (int_x86_tconjtfp16_internal GR16:$src1, GR16:$src2, TILE:$src3))]>; - -  let usesCustomInserter = 1 in { -    def PTTCMMIMFP16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3), -                                 [(int_x86_ttcmmimfp16ps timm:$src1, timm:$src2, timm:$src3)]>; -    def PTTCMMRLFP16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3), -                                 [(int_x86_ttcmmrlfp16ps timm:$src1, timm:$src2, timm:$src3)]>; -    def PTCONJTCMMIMFP16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3), -                                     [(int_x86_tconjtcmmimfp16ps timm:$src1, 
timm:$src2, timm:$src3)]>; -    def PTCONJTFP16 : PseudoI<(outs), (ins u8imm:$dst, u8imm:$src), -                              [(int_x86_tconjtfp16 timm:$dst, timm:$src)]>; -  } -} - -let Predicates = [HasAMXMOVRS, HasAMXTRANSPOSE, In64BitMode], SchedRW = [WriteSystem] in { -  let isPseudo = true in { -    def PT2RPNTLVWZ0RSV   : PseudoI<(outs TILEPair:$dst), -                              (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), -                              []>; -    def PT2RPNTLVWZ0RST1V : PseudoI<(outs TILEPair:$dst), -                              (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), -                              []>; -    def PT2RPNTLVWZ1RSV   : PseudoI<(outs TILEPair:$dst), -                              (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), -                              []>; -    def PT2RPNTLVWZ1RST1V : PseudoI<(outs TILEPair:$dst), -                              (ins GR16:$src1, GR16:$src2, GR16:$src3, opaquemem:$src4), -                              []>; -  } -  let  usesCustomInserter = 1 in { -    def PT2RPNTLVWZ0RS   : PseudoI<(outs), (ins u8imm:$dst, sibmem:$src1), []>; -    def PT2RPNTLVWZ0RST1 : PseudoI<(outs), (ins u8imm:$dst, sibmem:$src1), []>; -    def PT2RPNTLVWZ1RS   : PseudoI<(outs), (ins u8imm:$dst, sibmem:$src1), []>; -    def PT2RPNTLVWZ1RST1 : PseudoI<(outs), (ins u8imm:$dst, sibmem:$src1), []>; -  } -} // HasAMXMOVRS, HasAMXTRANSPOSE -  multiclass TILELOADDRS_Base<string suffix> {    def suffix    : I<0x4a, MRMSrcMemFSIB, (outs TILE:$dst), (ins sibmem:$src1),                      "tileloaddrs\t{$src1, $dst|$dst, $src1}", []>, T8, XD; @@ -721,29 +539,3 @@ let Predicates = [HasAMXTF32, In64BitMode] in {      }    } // SchedRW = [WriteSystem]  } // HasAMXTF32 - -let Predicates = [HasAMXTF32, HasAMXTRANSPOSE, In64BitMode] in { -  let SchedRW = [WriteSystem] in { -    let Constraints = "$src1 = $dst" in { -      def TTMMULTF32PS: I<0x48, MRMSrcReg4VOp3, (outs TILE:$dst), -                         (ins TILE:$src1, TILE:$src2, TILE:$src3), -                         "ttmmultf32ps\t{$src3, $src2, $dst|$dst, $src2, $src3}", -                         []>, VEX, VVVV, T8, PS; -    } -    let Constraints = "$src4 = $dst" in { -      def PTTMMULTF32PSV : PseudoI<(outs TILE:$dst), -                                   (ins GR16:$src1, GR16:$src2, GR16:$src3, -                                    TILE:$src4, TILE:$src5, TILE:$src6), -                                   [(set TILE:$dst, -                                     (int_x86_ttmmultf32ps_internal GR16:$src1, -                                      GR16:$src2, GR16:$src3, TILE:$src4, -                                      TILE:$src5, TILE:$src6))]>; -    } -    let usesCustomInserter = 1 in { -      def PTTMMULTF32PS : PseudoI<(outs), -                                  (ins u8imm:$src1, u8imm:$src2, u8imm:$src3), -                                  [(int_x86_ttmmultf32ps timm:$src1, timm:$src2, -                                    timm:$src3)]>; -    } -  } // SchedRW = [WriteSystem] -} // HasAMXTF32, HasAMXTRANSPOSE diff --git a/llvm/lib/Target/X86/X86InstrInfo.cpp b/llvm/lib/Target/X86/X86InstrInfo.cpp index 5c23f91..6b2a7a4 100644 --- a/llvm/lib/Target/X86/X86InstrInfo.cpp +++ b/llvm/lib/Target/X86/X86InstrInfo.cpp @@ -4544,11 +4544,6 @@ static unsigned getLoadStoreRegOpcode(Register Reg,      return Load ? 
GET_EGPR_IF_ENABLED(X86::TILELOADD)                  : GET_EGPR_IF_ENABLED(X86::TILESTORED);  #undef GET_EGPR_IF_ENABLED -  case 2048: -    assert(X86::TILEPAIRRegClass.hasSubClassEq(RC) && -           "Unknown 2048-byte regclass"); -    assert(STI.hasAMXTILE() && "Using 2048-bit register requires AMX-TILE"); -    return Load ? X86::PTILEPAIRLOAD : X86::PTILEPAIRSTORE;    }  } @@ -4743,8 +4738,6 @@ static bool isAMXOpcode(unsigned Opc) {    case X86::TILESTORED:    case X86::TILELOADD_EVEX:    case X86::TILESTORED_EVEX: -  case X86::PTILEPAIRLOAD: -  case X86::PTILEPAIRSTORE:      return true;    }  } @@ -4757,8 +4750,7 @@ void X86InstrInfo::loadStoreTileReg(MachineBasicBlock &MBB,    default:      llvm_unreachable("Unexpected special opcode!");    case X86::TILESTORED: -  case X86::TILESTORED_EVEX: -  case X86::PTILEPAIRSTORE: { +  case X86::TILESTORED_EVEX: {      // tilestored %tmm, (%sp, %idx)      MachineRegisterInfo &RegInfo = MBB.getParent()->getRegInfo();      Register VirtReg = RegInfo.createVirtualRegister(&X86::GR64_NOSPRegClass); @@ -4772,8 +4764,7 @@ void X86InstrInfo::loadStoreTileReg(MachineBasicBlock &MBB,      break;    }    case X86::TILELOADD: -  case X86::TILELOADD_EVEX: -  case X86::PTILEPAIRLOAD: { +  case X86::TILELOADD_EVEX: {      // tileloadd (%sp, %idx), %tmm      MachineRegisterInfo &RegInfo = MBB.getParent()->getRegInfo();      Register VirtReg = RegInfo.createVirtualRegister(&X86::GR64_NOSPRegClass); diff --git a/llvm/lib/Target/X86/X86InstrOperands.td b/llvm/lib/Target/X86/X86InstrOperands.td index 5207eca..6ba07f7 100644 --- a/llvm/lib/Target/X86/X86InstrOperands.td +++ b/llvm/lib/Target/X86/X86InstrOperands.td @@ -536,10 +536,3 @@ def VK8Pair : RegisterOperand<VK8PAIR, "printVKPair"> {  def VK16Pair : RegisterOperand<VK16PAIR, "printVKPair"> {    let ParserMatchClass = VK16PairAsmOperand;  } - -let RenderMethod = "addTILEPairOperands" in -  def TILEPairAsmOperand : AsmOperandClass { let Name = "TILEPair"; } - -def TILEPair : RegisterOperand<TILEPAIR, "printTILEPair"> { -  let ParserMatchClass = TILEPairAsmOperand; -} diff --git a/llvm/lib/Target/X86/X86InstrPredicates.td b/llvm/lib/Target/X86/X86InstrPredicates.td index c20bb05..98104a6f 100644 --- a/llvm/lib/Target/X86/X86InstrPredicates.td +++ b/llvm/lib/Target/X86/X86InstrPredicates.td @@ -183,7 +183,6 @@ def HasAMXINT8   : Predicate<"Subtarget->hasAMXINT8()">;  def HasAMXCOMPLEX : Predicate<"Subtarget->hasAMXCOMPLEX()">;  def HasAMXFP8    : Predicate<"Subtarget->hasAMXFP8()">;  def HasAMXMOVRS  : Predicate<"Subtarget->hasAMXMOVRS()">; -def HasAMXTRANSPOSE : Predicate<"Subtarget->hasAMXTRANSPOSE()">;  def HasAMXAVX512 : Predicate<"Subtarget->hasAMXAVX512()">;  def HasAMXTF32   : Predicate<"Subtarget->hasAMXTF32()">;  def HasUINTR     : Predicate<"Subtarget->hasUINTR()">; diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp index 8ffd454..2fc5d38 100644 --- a/llvm/lib/Target/X86/X86LowerAMXType.cpp +++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp @@ -74,22 +74,6 @@ static bool isAMXCast(Instruction *II) {           match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));  } -// Some instructions may return more than one tiles. 
-// e.g: call { x86_amx, x86_amx } @llvm.x86.t2rpntlvwz0.internal -static unsigned getNumDefTiles(IntrinsicInst *II) { -  Type *Ty = II->getType(); -  if (Ty->isX86_AMXTy()) -    return 1; - -  unsigned Num = 0; -  for (unsigned i = 0; i < Ty->getNumContainedTypes(); i++) { -    Type *STy = Ty->getContainedType(i); -    if (STy->isX86_AMXTy()) -      Num++; -  } -  return Num; -} -  static bool isAMXIntrinsic(Value *I) {    auto *II = dyn_cast<IntrinsicInst>(I);    if (!II) @@ -98,7 +82,7 @@ static bool isAMXIntrinsic(Value *I) {      return false;    // Check if return type or parameter is x86_amx. If it is x86_amx    // the intrinsic must be x86 amx intrinsics. -  if (getNumDefTiles(II) > 0) +  if (II->getType()->isX86_AMXTy())      return true;    for (Value *V : II->args()) {      if (V->getType()->isX86_AMXTy()) @@ -137,27 +121,7 @@ static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {    llvm_unreachable("No terminator in the entry block!");  } -class ShapeCalculator { -private: -  const TargetMachine *TM = nullptr; - -  // In AMX intrinsics we let Shape = {Row, Col}, but the -  // RealCol = Col / ElementSize. We may use the RealCol -  // as a new Row for other new created AMX intrinsics. -  std::map<Value *, Value *> Col2Row, Row2Col; - -public: -  ShapeCalculator(const TargetMachine *TargetM) : TM(TargetM) {} -  std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo); -  std::pair<Value *, Value *> getShape(PHINode *Phi); -  Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity); -  Value *getColFromRow(Instruction *II, Value *V, unsigned Granularity); -}; - -Value *ShapeCalculator::getRowFromCol(Instruction *II, Value *V, -                                      unsigned Granularity) { -  if (auto It = Col2Row.find(V); It != Col2Row.end()) -    return It->second; +static Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity) {    IRBuilder<> Builder(II);    Value *RealRow = nullptr;    if (isa<ConstantInt>(V)) @@ -186,47 +150,16 @@ Value *ShapeCalculator::getRowFromCol(Instruction *II, Value *V,          getFirstNonAllocaInTheEntryBlock(*II->getFunction()));      RealRow = NewBuilder.CreateUDiv(V, NewBuilder.getInt16(Granularity));    } -  Col2Row[V] = RealRow;    return RealRow;  } -Value *ShapeCalculator::getColFromRow(Instruction *II, Value *V, -                                      unsigned Granularity) { -  if (auto It = Row2Col.find(V); It != Row2Col.end()) -    return It->second; -  IRBuilder<> Builder(II); -  Value *RealCol = nullptr; -  if (isa<ConstantInt>(V)) -    RealCol = -        Builder.getInt16((cast<ConstantInt>(V)->getSExtValue()) * Granularity); -  else if (isa<Instruction>(V)) { -    Builder.SetInsertPoint(cast<Instruction>(V)); -    RealCol = Builder.CreateNUWMul(V, Builder.getInt16(Granularity)); -    cast<Instruction>(RealCol)->moveAfter(cast<Instruction>(V)); -  } else { -    // When it is not a const value and it is a function argument, we create -    // Row at the entry bb. -    IRBuilder<> NewBuilder( -        getFirstNonAllocaInTheEntryBlock(*II->getFunction())); -    RealCol = NewBuilder.CreateNUWMul(V, NewBuilder.getInt16(Granularity)); -  } -  Row2Col[V] = RealCol; -  return RealCol; -} -  // TODO: Refine the row and col-in-bytes of tile to row and col of matrix. 
-std::pair<Value *, Value *> ShapeCalculator::getShape(IntrinsicInst *II, -                                                      unsigned OpNo) { -  (void)TM; +std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {    IRBuilder<> Builder(II);    Value *Row = nullptr, *Col = nullptr;    switch (II->getIntrinsicID()) {    default:      llvm_unreachable("Expect amx intrinsics"); -  case Intrinsic::x86_t2rpntlvwz0_internal: -  case Intrinsic::x86_t2rpntlvwz0t1_internal: -  case Intrinsic::x86_t2rpntlvwz1_internal: -  case Intrinsic::x86_t2rpntlvwz1t1_internal:    case Intrinsic::x86_tileloadd64_internal:    case Intrinsic::x86_tileloaddt164_internal:    case Intrinsic::x86_tilestored64_internal: @@ -271,13 +204,6 @@ std::pair<Value *, Value *> ShapeCalculator::getShape(IntrinsicInst *II,      }      break;    } -  case Intrinsic::x86_ttransposed_internal: -  case Intrinsic::x86_tconjtfp16_internal: { -    assert((OpNo == 2) && "Illegal Operand Number."); -    Row = getRowFromCol(II, II->getArgOperand(1), 4); -    Col = getColFromRow(II, II->getArgOperand(0), 4); -    break; -  }    case Intrinsic::x86_tcvtrowd2ps_internal:    case Intrinsic::x86_tcvtrowps2bf16h_internal:    case Intrinsic::x86_tcvtrowps2bf16l_internal: @@ -289,34 +215,12 @@ std::pair<Value *, Value *> ShapeCalculator::getShape(IntrinsicInst *II,      Col = II->getArgOperand(1);      break;    } -  case Intrinsic::x86_ttdpbf16ps_internal: -  case Intrinsic::x86_ttdpfp16ps_internal: -  case Intrinsic::x86_ttcmmimfp16ps_internal: -  case Intrinsic::x86_ttcmmrlfp16ps_internal: -  case Intrinsic::x86_tconjtcmmimfp16ps_internal: -  case Intrinsic::x86_ttmmultf32ps_internal: { -    switch (OpNo) { -    case 3: -      Row = II->getArgOperand(0); -      Col = II->getArgOperand(1); -      break; -    case 4: -      Row = getRowFromCol(II, II->getArgOperand(2), 4); -      Col = getColFromRow(II, II->getArgOperand(0), 4); -      break; -    case 5: -      Row = getRowFromCol(II, II->getArgOperand(2), 4); -      Col = II->getArgOperand(1); -      break; -    } -    break; -  }    }    return std::make_pair(Row, Col);  } -std::pair<Value *, Value *> ShapeCalculator::getShape(PHINode *Phi) { +static std::pair<Value *, Value *> getShape(PHINode *Phi) {    Use &U = *(Phi->use_begin());    unsigned OpNo = U.getOperandNo();    User *V = U.getUser(); @@ -349,15 +253,14 @@ std::pair<Value *, Value *> ShapeCalculator::getShape(PHINode *Phi) {  namespace {  class X86LowerAMXType {    Function &Func; -  ShapeCalculator *SC;    // In AMX intrinsics we let Shape = {Row, Col}, but the    // RealCol = Col / ElementSize. We may use the RealCol    // as a new Row for other new created AMX intrinsics. -  std::map<Value *, Value *> Col2Row, Row2Col; +  std::map<Value *, Value *> Col2Row;  public: -  X86LowerAMXType(Function &F, ShapeCalculator *ShapeC) : Func(F), SC(ShapeC) {} +  X86LowerAMXType(Function &F) : Func(F) {}    bool visit();    void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);    void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST); @@ -374,7 +277,7 @@ void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {    Use &U = *(Bitcast->use_begin());    unsigned OpNo = U.getOperandNo();    auto *II = cast<IntrinsicInst>(U.getUser()); -  std::tie(Row, Col) = SC->getShape(II, OpNo); +  std::tie(Row, Col) = getShape(II, OpNo);    IRBuilder<> Builder(Bitcast);    // Use the maximun column as stride.    
Value *Stride = Builder.getInt64(64); @@ -454,7 +357,7 @@ bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {      Builder.CreateStore(Src, AllocaAddr);      // TODO we can pick an constant operand for the shape.      Value *Row = nullptr, *Col = nullptr; -    std::tie(Row, Col) = SC->getShape(II, OpNo); +    std::tie(Row, Col) = getShape(II, OpNo);      std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};      Value *NewInst =          Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, Args); @@ -594,18 +497,11 @@ static Value *getAllocaPos(BasicBlock *BB) {  static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {    assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!"); -  auto *II = dyn_cast<IntrinsicInst>(TileDef); -  unsigned Idx = 0; -  // Extract tile from multiple tiles' def. -  if (auto *Extr = dyn_cast<ExtractValueInst>(TileDef)) { -    assert(Extr->hasIndices() && "Tile extract miss index!"); -    Idx = Extr->getIndices()[0]; -    II = cast<IntrinsicInst>(Extr->getOperand(0)); -  } +  auto *II = cast<IntrinsicInst>(TileDef);    assert(II && "Not tile intrinsic!"); -  Value *Row = II->getOperand(Idx); -  Value *Col = II->getOperand(Idx + 1); +  Value *Row = II->getOperand(0); +  Value *Col = II->getOperand(1);    BasicBlock *BB = TileDef->getParent();    BasicBlock::iterator Iter = TileDef->getIterator(); @@ -624,20 +520,14 @@ static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {    // Get tile shape.    IntrinsicInst *II = nullptr; -  unsigned Idx = 0;    if (IsPHI) {      Value *PhiOp = cast<PHINode>(V)->getIncomingValue(0);      II = cast<IntrinsicInst>(PhiOp); -  } else if (auto *Extr = dyn_cast<ExtractValueInst>(V)) { -    // Extract tile from multiple tiles' def. -    assert(Extr->hasIndices() && "Tile extract miss index!"); -    Idx = Extr->getIndices()[0]; -    II = cast<IntrinsicInst>(Extr->getOperand(0));    } else {      II = cast<IntrinsicInst>(V);    } -  Value *Row = II->getOperand(Idx); -  Value *Col = II->getOperand(Idx + 1); +  Value *Row = II->getOperand(0); +  Value *Col = II->getOperand(1);    Instruction *UserI = cast<Instruction>(U.getUser());    IRBuilder<> Builder(UserI); @@ -848,12 +738,10 @@ namespace {  class X86LowerAMXCast {    Function &Func; -  ShapeCalculator *SC;    std::unique_ptr<DominatorTree> DT;  public: -  X86LowerAMXCast(Function &F, ShapeCalculator *ShapeC) -      : Func(F), SC(ShapeC), DT(nullptr) {} +  X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {}    bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST);    bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);    bool combineTilezero(IntrinsicInst *Cast); @@ -932,7 +820,7 @@ bool X86LowerAMXCast::optimizeAMXCastFromPhi(          if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())            return false;          Value *Row = nullptr, *Col = nullptr; -        std::tie(Row, Col) = SC->getShape(OldPN); +        std::tie(Row, Col) = getShape(OldPN);          // TODO: If it is not constant the Row and Col must domoniate tilezero          // that we are going to create.          
if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col)) @@ -1063,19 +951,6 @@ bool X86LowerAMXCast::optimizeAMXCastFromPhi(    return true;  } -static Value *getShapeFromAMXIntrinsic(Value *Inst, unsigned ShapeIdx, -                                       bool IsRow) { -  if (!isAMXIntrinsic(Inst)) -    return nullptr; - -  auto *II = cast<IntrinsicInst>(Inst); -  if (IsRow) -    return II->getOperand(0); - -  assert(ShapeIdx < 2 && "Currently 2 shapes in 1 instruction at most!"); -  return II->getOperand(ShapeIdx + 1); -} -  // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)  // store <256 x i32> %43, <256 x i32>* %p, align 64  // --> @@ -1090,38 +965,13 @@ bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {    if (!Tile->hasOneUse())      return false; -  // We don't fetch shape from tilestore, we only get shape from tiledef, -  // so we can set the max tile shape to tilestore for special cases. +  auto *II = cast<IntrinsicInst>(Tile); +  // Tile is output from AMX intrinsic. The first operand of the +  // intrinsic is row, the second operand of the intrinsic is column. +  Value *Row = II->getOperand(0); +  Value *Col = II->getOperand(1); +    IRBuilder<> Builder(ST); -  Value *Row = nullptr; -  Value *Col = nullptr; - -  if (isAMXIntrinsic(Tile)) { -    auto *II = cast<IntrinsicInst>(Tile); -    // Tile is output from AMX intrinsic. The first operand of the -    // intrinsic is row, the second operand of the intrinsic is column. -    Row = II->getOperand(0); -    Col = II->getOperand(1); -  } else { -    // Now we supported multi-tiles value in structure, so we may get tile -    // from extracting multi-tiles structure. -    // For example: -    // %6 = call { x86_amx, x86_amx } @llvm.x86.t2rpntlvwz0.internal(i16 %1, -    //      i16 %2, i16 %3, i8* %4, i64 %5) -    // %7 = extractvalue { x86_amx, x86_amx } %6, 0 -    // %8 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %7) -    // store <256 x i32> %8, <256 x i32>* %0, align 1024 -    // -    // TODO: Currently we only handle extractvalue case, enhance me for other -    // cases if possible. -    auto *II = cast<ExtractValueInst>(Tile); -    assert(II && "We meet unhandle source in fetching tile value!"); -    unsigned ShapeIdx = II->getIndices()[0]; -    Value *Tiles = II->getOperand(0); -    Row = getShapeFromAMXIntrinsic(Tiles, ShapeIdx, true); -    Col = getShapeFromAMXIntrinsic(Tiles, ShapeIdx, false); -  } -  assert(Row && Col && "Shape got failed!");    // Stride should be equal to col(measured by bytes)    Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty()); @@ -1146,7 +996,7 @@ bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {    // shape information through def-use chain.    if (!isAMXIntrinsic(II))      return false; -  std::tie(Row, Col) = SC->getShape(II, OpNo); +  std::tie(Row, Col) = getShape(II, OpNo);    IRBuilder<> Builder(LD);    // Stride should be equal to col(measured by bytes)    Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty()); @@ -1189,7 +1039,7 @@ bool X86LowerAMXCast::combineTilezero(IntrinsicInst *Cast) {    if (!isAMXIntrinsic(II))      return false; -  std::tie(Row, Col) = SC->getShape(II, OpNo); +  std::tie(Row, Col) = getShape(II, OpNo);    IRBuilder<> Builder(Cast);    Value *NewInst = @@ -1384,7 +1234,7 @@ bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {      Builder.CreateStore(Src, AllocaAddr);      // TODO we can pick an constant operand for the shape.      
Value *Row = nullptr, *Col = nullptr; -    std::tie(Row, Col) = SC->getShape(II, OpNo); +    std::tie(Row, Col) = getShape(II, OpNo);      std::array<Value *, 4> Args = {          Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};      Value *NewInst = @@ -1445,14 +1295,13 @@ bool lowerAmxType(Function &F, const TargetMachine *TM,      return false;    bool C = false; -  ShapeCalculator SC(TM); -  X86LowerAMXCast LAC(F, &SC); +  X86LowerAMXCast LAC(F);    C |= LAC.combineAMXcast(TLI);    // There might be remaining AMXcast after combineAMXcast and they should be    // handled elegantly.    C |= LAC.transformAllAMXCast(); -  X86LowerAMXType LAT(F, &SC); +  X86LowerAMXType LAT(F);    C |= LAT.visit();    // Prepare for fast register allocation at O0. diff --git a/llvm/lib/Target/X86/X86PreTileConfig.cpp b/llvm/lib/Target/X86/X86PreTileConfig.cpp index 2a1c499..8a1d00d 100644 --- a/llvm/lib/Target/X86/X86PreTileConfig.cpp +++ b/llvm/lib/Target/X86/X86PreTileConfig.cpp @@ -141,15 +141,10 @@ class X86PreTileConfig : public MachineFunctionPass {      if (!MO.isReg() || !MO.getReg().isVirtual())        return false; -    unsigned Shapes = 0; -    if (MRI->getRegClass(MO.getReg())->getID() == X86::TILERegClassID) -      Shapes = 1; -    if (MRI->getRegClass(MO.getReg())->getID() == X86::TILEPAIRRegClassID) -      Shapes = 2; -    if (!Shapes) +    if (MRI->getRegClass(MO.getReg())->getID() != X86::TILERegClassID)        return false; -    collectShapeInfo(MI, Shapes); +    collectShapeInfo(MI);      return true;    } @@ -165,7 +160,7 @@ class X86PreTileConfig : public MachineFunctionPass {    }    /// Collect the shape def information for later use. -  void collectShapeInfo(MachineInstr &MI, unsigned Shapes); +  void collectShapeInfo(MachineInstr &MI);    /// Try to hoist shapes definded below AMX instructions.    bool hoistShapesInBB(MachineBasicBlock *MBB, SmallVectorImpl<MIRef> &Shapes) { @@ -231,7 +226,7 @@ INITIALIZE_PASS_DEPENDENCY(MachineLoopInfoWrapperPass)  INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",                      "Tile Register Pre-configure", false, false) -void X86PreTileConfig::collectShapeInfo(MachineInstr &MI, unsigned Shapes) { +void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) {    auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) {      MIRef MIR(MI, MBB);      auto &Refs = ShapeBBs[MBB]; @@ -240,10 +235,8 @@ void X86PreTileConfig::collectShapeInfo(MachineInstr &MI, unsigned Shapes) {        Refs.insert(I, MIR);    }; -  // All shapes have same row in multi-tile operand. -  SmallVector<Register, 8> WorkList; -  for (unsigned I = 1; I < Shapes + 2; ++I) -    WorkList.push_back(MI.getOperand(I).getReg()); +  SmallVector<Register, 8> WorkList( +      {MI.getOperand(1).getReg(), MI.getOperand(2).getReg()});    while (!WorkList.empty()) {      Register R = WorkList.pop_back_val();      MachineInstr *DefMI = MRI->getVRegDef(R); @@ -252,13 +245,6 @@ void X86PreTileConfig::collectShapeInfo(MachineInstr &MI, unsigned Shapes) {      if (DefMI->isMoveImmediate() || !DefVisited.insert(DefMI).second)        continue; -    // This happens when column = 0 in multi-tile operand. 
-    if (DefMI->getOpcode() == X86::COPY) { -      MachineInstr *MI = MRI->getVRegDef(DefMI->getOperand(1).getReg()); -      if (MI && MI->isMoveImmediate()) -        continue; -    } -      if (DefMI->isPHI()) {        for (unsigned I = 1; I < DefMI->getNumOperands(); I += 2)          if (isLoopBackEdge(DefMBB, DefMI->getOperand(I + 1).getMBB())) diff --git a/llvm/lib/Target/X86/X86RegisterInfo.cpp b/llvm/lib/Target/X86/X86RegisterInfo.cpp index 76979e3..72f3813 100644 --- a/llvm/lib/Target/X86/X86RegisterInfo.cpp +++ b/llvm/lib/Target/X86/X86RegisterInfo.cpp @@ -597,10 +597,6 @@ BitVector X86RegisterInfo::getReservedRegs(const MachineFunction &MF) const {        Reserved.set(*AI);    } -  // Reserve low half pair registers in case they are used by RA aggressively. -  Reserved.set(X86::TMM0_TMM1); -  Reserved.set(X86::TMM2_TMM3); -    assert(checkAllSuperRegsMarked(Reserved,                                   {X86::SIL, X86::DIL, X86::BPL, X86::SPL,                                    X86::SIH, X86::DIH, X86::BPH, X86::SPH})); @@ -621,7 +617,7 @@ unsigned X86RegisterInfo::getNumSupportedRegs(const MachineFunction &MF) const {    // and try to return the minimum number of registers supported by the target.    static_assert((X86::R15WH + 1 == X86::YMM0) && (X86::YMM15 + 1 == X86::K0) &&                      (X86::K6_K7 + 1 == X86::TMMCFG) && -                    (X86::TMM6_TMM7 + 1 == X86::R16) && +                    (X86::TMM7 + 1 == X86::R16) &&                      (X86::R31WH + 1 == X86::NUM_TARGET_REGS),                  "Register number may be incorrect"); @@ -694,8 +690,7 @@ bool X86RegisterInfo::isFixedRegister(const MachineFunction &MF,  }  bool X86RegisterInfo::isTileRegisterClass(const TargetRegisterClass *RC) const { -  return RC->getID() == X86::TILERegClassID || -         RC->getID() == X86::TILEPAIRRegClassID; +  return RC->getID() == X86::TILERegClassID;  }  void X86RegisterInfo::adjustStackMapLiveOutMask(uint32_t *Mask) const { @@ -1062,17 +1057,9 @@ static ShapeT getTileShape(Register VirtReg, VirtRegMap *VRM,    case X86::PTDPFP16PSV:    case X86::PTCMMIMFP16PSV:    case X86::PTCMMRLFP16PSV: -  case X86::PTTRANSPOSEDV: -  case X86::PTTDPBF16PSV: -  case X86::PTTDPFP16PSV: -  case X86::PTTCMMIMFP16PSV: -  case X86::PTTCMMRLFP16PSV: -  case X86::PTCONJTCMMIMFP16PSV: -  case X86::PTCONJTFP16V:    case X86::PTILELOADDRSV:    case X86::PTILELOADDRST1V:    case X86::PTMMULTF32PSV: -  case X86::PTTMMULTF32PSV:    case X86::PTDPBF8PSV:    case X86::PTDPBHF8PSV:    case X86::PTDPHBF8PSV: @@ -1083,56 +1070,7 @@ static ShapeT getTileShape(Register VirtReg, VirtRegMap *VRM,      VRM->assignVirt2Shape(VirtReg, Shape);      return Shape;    } -  case X86::PT2RPNTLVWZ0V: -  case X86::PT2RPNTLVWZ0T1V: -  case X86::PT2RPNTLVWZ1V: -  case X86::PT2RPNTLVWZ1T1V: -  case X86::PT2RPNTLVWZ0RSV: -  case X86::PT2RPNTLVWZ0RST1V: -  case X86::PT2RPNTLVWZ1RSV: -  case X86::PT2RPNTLVWZ1RST1V: { -    MachineOperand &MO1 = MI->getOperand(1); -    MachineOperand &MO2 = MI->getOperand(2); -    MachineOperand &MO3 = MI->getOperand(3); -    ShapeT Shape({&MO1, &MO2, &MO1, &MO3}, MRI); -    VRM->assignVirt2Shape(VirtReg, Shape); -    return Shape; -  } -  } -} - -static bool canHintShape(ShapeT &PhysShape, ShapeT &VirtShape) { -  unsigned PhysShapeNum = PhysShape.getShapeNum(); -  unsigned VirtShapeNum = VirtShape.getShapeNum(); - -  if (PhysShapeNum < VirtShapeNum) -    return false; - -  if (PhysShapeNum == VirtShapeNum) { -    if (PhysShapeNum == 1) -      return PhysShape == VirtShape; - -    for 
(unsigned I = 0; I < PhysShapeNum; I++) { -      ShapeT PShape(PhysShape.getRow(I), PhysShape.getCol(I)); -      ShapeT VShape(VirtShape.getRow(I), VirtShape.getCol(I)); -      if (VShape != PShape) -        return false; -    } -    return true; -  } - -  // Hint subreg of mult-tile reg to single tile reg. -  if (VirtShapeNum == 1) { -    for (unsigned I = 0; I < PhysShapeNum; I++) { -      ShapeT PShape(PhysShape.getRow(I), PhysShape.getCol(I)); -      if (VirtShape == PShape) -        return true; -    }    } - -  // Note: Currently we have no requirement for case of -  // (VirtShapeNum > 1 and PhysShapeNum > VirtShapeNum) -  return false;  }  bool X86RegisterInfo::getRegAllocationHints(Register VirtReg, @@ -1153,7 +1091,7 @@ bool X86RegisterInfo::getRegAllocationHints(Register VirtReg,    if (!VRM)      return BaseImplRetVal; -  if (ID != X86::TILERegClassID && ID != X86::TILEPAIRRegClassID) { +  if (ID != X86::TILERegClassID) {      if (DisableRegAllocNDDHints || !ST.hasNDD() ||          !TRI.isGeneralPurposeRegisterClass(&RC))        return BaseImplRetVal; @@ -1204,7 +1142,7 @@ bool X86RegisterInfo::getRegAllocationHints(Register VirtReg,        return;      }      ShapeT PhysShape = getTileShape(VReg, const_cast<VirtRegMap *>(VRM), MRI); -    if (canHintShape(PhysShape, VirtShape)) +    if (PhysShape == VirtShape)        Hints.push_back(PhysReg);    }; diff --git a/llvm/lib/Target/X86/X86RegisterInfo.td b/llvm/lib/Target/X86/X86RegisterInfo.td index 99b7910..692e42a 100644 --- a/llvm/lib/Target/X86/X86RegisterInfo.td +++ b/llvm/lib/Target/X86/X86RegisterInfo.td @@ -30,8 +30,6 @@ let Namespace = "X86" in {    def sub_ymm      : SubRegIndex<256>;    def sub_mask_0   : SubRegIndex<-1>;    def sub_mask_1   : SubRegIndex<-1, -1>; -  def sub_t0       : SubRegIndex<8192>; -  def sub_t1       : SubRegIndex<8192, 8192>;  }  //===----------------------------------------------------------------------===// @@ -432,10 +430,6 @@ def TMM4:  X86Reg<"tmm4",   4>;  def TMM5:  X86Reg<"tmm5",   5>;  def TMM6:  X86Reg<"tmm6",   6>;  def TMM7:  X86Reg<"tmm7",   7>; -// TMM register pairs -def TPAIRS : RegisterTuples<[sub_t0, sub_t1], -                            [(add TMM0, TMM2, TMM4, TMM6), -                             (add TMM1, TMM3, TMM5, TMM7)]>;  }  // Floating point stack registers. These don't map one-to-one to the FP @@ -862,9 +856,6 @@ def VK64WM  : RegisterClass<"X86", [v64i1], 64, (add VK32WM)> {let Size = 64;}  let CopyCost = -1 in // Don't allow copying of tile registers  def TILE : RegisterClass<"X86", [x86amx], 8192,                           (sequence "TMM%u", 0, 7)> {let Size = 8192;} -// Need check alignment 3rd operand size=1024*2*8 -let isAllocatable = 1 in -def TILEPAIR : RegisterClass<"X86", [untyped], 512, (add TPAIRS)> {let Size = 16384;}  //===----------------------------------------------------------------------===//  // Register categories. 
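With the TMM pair tuples and canHintShape gone, the register-allocation hinting above reduces to a plain shape comparison: a physical TMM register is suggested only when the shape already assigned to it equals the virtual register's shape, since each physical tile carries a single configured shape. A minimal standalone sketch of that check (simplified types and names, not the actual LLVM interfaces):

#include <cstdint>
#include <map>
#include <vector>

// Simplified stand-ins for ShapeT and the TMM register numbering.
struct Shape {
  uint16_t Row = 0, Col = 0;
  bool operator==(const Shape &O) const { return Row == O.Row && Col == O.Col; }
};

// Collect hint candidates: only physical tile registers whose configured
// shape is identical to the virtual register's shape are worth hinting;
// a mismatch would require a different tile configuration.
std::vector<unsigned> collectTileHints(const Shape &VirtShape,
                                       const std::map<unsigned, Shape> &PhysShapes,
                                       const std::vector<unsigned> &AllocationOrder) {
  std::vector<unsigned> Hints;
  for (unsigned PhysReg : AllocationOrder) {
    auto It = PhysShapes.find(PhysReg);
    if (It != PhysShapes.end() && It->second == VirtShape)
      Hints.push_back(PhysReg);
  }
  return Hints;
}

int main() {
  std::map<unsigned, Shape> PhysShapes = {{0, {16, 64}}, {1, {16, 32}}};
  auto Hints = collectTileHints({16, 64}, PhysShapes, {0, 1, 2});
  return Hints.size() == 1 ? 0 : 1; // only register 0 matches the 16x64 shape
}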
diff --git a/llvm/lib/Target/X86/X86TileConfig.cpp b/llvm/lib/Target/X86/X86TileConfig.cpp index 17a44dd..09ef8fb 100644 --- a/llvm/lib/Target/X86/X86TileConfig.cpp +++ b/llvm/lib/Target/X86/X86TileConfig.cpp @@ -74,63 +74,6 @@ INITIALIZE_PASS_DEPENDENCY(VirtRegMapWrapperLegacy)  INITIALIZE_PASS_END(X86TileConfig, DEBUG_TYPE, "Tile Register Configure", false,                      false) -unsigned getAMXRegNum(MachineRegisterInfo *MRI, Register Reg) { -  if (Reg.isVirtual()) { -    unsigned RegClassID = MRI->getRegClass(Reg)->getID(); -    if (RegClassID == X86::TILERegClassID) -      return 1; -    if (RegClassID == X86::TILEPAIRRegClassID) -      return 2; -  } else { -    if (Reg >= X86::TMM0 && Reg <= X86::TMM7) -      return 1; -    if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7) -      return 2; -  } -  return 0; -} - -static void collectVirtRegShapes(MachineRegisterInfo *MRI, VirtRegMap &VRM, -                                 Register VirtReg, -                                 SmallVector<ShapeT, 8> &Phys2Shapes) { -  unsigned Num = getAMXRegNum(MRI, VirtReg); -  MCRegister PhysReg = VRM.getPhys(VirtReg); -  if (!PhysReg) -    return; - -  if (Num == 1) { -    unsigned Index = PhysReg - X86::TMM0; -    if (!Phys2Shapes[Index].isValid()) { -      ShapeT Shape = VRM.getShape(VirtReg); -      Phys2Shapes[Index] = std::move(Shape); -      return; -    } -  } -  // Split tile pair shape info to 2 single tile shape info. e.g: -  // Put TMM0_TMM1's Shape to TMM0's shape + TMM1's Shape in Phys2Shapes. -  if (Num == 2) { -    unsigned Index0 = (PhysReg - X86::TMM0_TMM1) * 2; -    unsigned Index1 = (PhysReg - X86::TMM0_TMM1) * 2 + 1; - -    ShapeT Shape = VRM.getShape(VirtReg); -    assert(Shape.getShapeNum() == 2 && "Unexpected shape number!"); - -    if (!Phys2Shapes[Index0].isValid()) { -      ShapeT Shape0(Shape.getRow(0), Shape.getCol(0), MRI); -      Phys2Shapes[Index0] = std::move(Shape0); -    } - -    if (!Phys2Shapes[Index1].isValid()) { -      ShapeT Shape1(Shape.getRow(1), Shape.getCol(1), MRI); -      Phys2Shapes[Index1] = std::move(Shape1); -    } -  } -} - -static bool isAMXRegClass(MachineRegisterInfo *MRI, Register Reg) { -  return getAMXRegNum(MRI, Reg) > 0; -} -  bool X86TileConfig::runOnMachineFunction(MachineFunction &MF) {    X86MachineFunctionInfo *X86FI = MF.getInfo<X86MachineFunctionInfo>();    // Early exit in the common case of non-AMX code. 
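The X86TileConfig hunks that follow map each virtual tile register to its assigned physical TMM register and then write that register's rows/colsb values into the 64-byte tile-configuration blob described by the "data format for the tile config" comment. For reference, a standalone sketch of that blob's layout as documented for AMX LDTILECFG (field names here are illustrative):

#include <cstdint>

// 64-byte memory operand of LDTILECFG/STTILECFG (palette 1 layout).
// Byte 0 is the palette id, byte 1 is start_row, bytes 16..31 hold the
// per-tile column widths in bytes (2 bytes per tile), and bytes 48..55
// hold the per-tile row counts; reserved bytes must stay zero.
struct TileConfig {
  uint8_t Palette = 0;         // 0:     palette id
  uint8_t StartRow = 0;        // 1:     start_row
  uint8_t Reserved0[14] = {};  // 2-15:  must be zero
  uint16_t ColsB[8] = {};      // 16-31: tmm0..tmm7 bytes per row
  uint8_t Reserved1[16] = {};  // 32-47: must be zero
  uint8_t Rows[8] = {};        // 48-55: tmm0..tmm7 row counts
  uint8_t Reserved2[8] = {};   // 56-63: must be zero
};
static_assert(sizeof(TileConfig) == 64, "tile config blob is 64 bytes");

int main() {
  TileConfig Cfg;
  Cfg.Palette = 1;
  Cfg.Rows[0] = 16;  // configure tmm0 as 16 rows x 64 bytes per row
  Cfg.ColsB[0] = 64;
  return Cfg.Rows[0] == 16 ? 0 : 1;
}

Rows and colsb are the only per-tile fields; everything else in the blob is reserved and must remain zero.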
@@ -138,7 +81,7 @@ bool X86TileConfig::runOnMachineFunction(MachineFunction &MF) {      return false;    const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>(); -  const TargetRegisterInfo *TRI = ST.getRegisterInfo(); +  const X86RegisterInfo *TRI = ST.getRegisterInfo();    const TargetInstrInfo *TII = ST.getInstrInfo();    MachineRegisterInfo &MRI = MF.getRegInfo();    LiveIntervals &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS(); @@ -176,24 +119,29 @@ bool X86TileConfig::runOnMachineFunction(MachineFunction &MF) {    assert(ConstMI && "Cannot find an insertion point");    unsigned AMXRegNum = TRI->getRegClass(X86::TILERegClassID)->getNumRegs(); -  SmallVector<ShapeT, 8> Phys2Shapes(AMXRegNum, ShapeT()); +  SmallVector<Register, 8> Phys2Virt(AMXRegNum, 0);    for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {      Register VirtReg = Register::index2VirtReg(I);      if (MRI.reg_nodbg_empty(VirtReg))        continue; -    if (!isAMXRegClass(&MRI, VirtReg)) +    if (!TRI->isTileRegisterClass(MRI.getRegClass(VirtReg))) +      continue; +    MCRegister PhysReg = VRM.getPhys(VirtReg); +    if (!PhysReg)        continue; -    collectVirtRegShapes(&MRI, VRM, VirtReg, Phys2Shapes); +    unsigned Index = PhysReg - X86::TMM0; +    if (!Phys2Virt[Index]) +      Phys2Virt[Index] = VirtReg;    }    // Fill in the shape of each tile physical register.    for (unsigned I = 0; I < AMXRegNum; ++I) { -    ShapeT Shape = Phys2Shapes[I]; -    if (!Shape.isValid()) +    if (!Phys2Virt[I])        continue;      DebugLoc DL;      bool IsRow = true;      MachineInstr *NewMI = nullptr; +    ShapeT Shape = VRM.getShape(Phys2Virt[I]);      for (auto &R : {Shape.getRow()->getReg(), Shape.getCol()->getReg()}) {        // Here is the data format for the tile config.        
// 0      palette @@ -222,14 +170,7 @@ bool X86TileConfig::runOnMachineFunction(MachineFunction &MF) {                     "Cannot initialize with different shapes");              continue;            } -          if (DefMI.getOperand(1).isImm()) { -            Imm = DefMI.getOperand(1).getImm(); -          } else { -            assert(DefMI.getOpcode() == X86::MOV32r0 && -                   "The opcode is assumed to be MOV32r0 if the operand is not " -                   "immediate."); -            Imm = 0; -          } +          Imm = DefMI.getOperand(1).getImm();            NewMI = addFrameReference(                        BuildMI(MF.front(), ++ConstMI->getIterator(), DL, diff --git a/llvm/lib/TargetParser/Host.cpp b/llvm/lib/TargetParser/Host.cpp index 0849fc7..c164762 100644 --- a/llvm/lib/TargetParser/Host.cpp +++ b/llvm/lib/TargetParser/Host.cpp @@ -2192,7 +2192,6 @@ StringMap<bool> sys::getHostCPUFeatures() {    bool HasLeaf1E = MaxLevel >= 0x1e &&                     !getX86CpuIDAndInfoEx(0x1e, 0x1, &EAX, &EBX, &ECX, &EDX);    Features["amx-fp8"] = HasLeaf1E && ((EAX >> 4) & 1) && HasAMXSave; -  Features["amx-transpose"] = HasLeaf1E && ((EAX >> 5) & 1) && HasAMXSave;    Features["amx-tf32"] = HasLeaf1E && ((EAX >> 6) & 1) && HasAMXSave;    Features["amx-avx512"] = HasLeaf1E && ((EAX >> 7) & 1) && HasAMXSave;    Features["amx-movrs"] = HasLeaf1E && ((EAX >> 8) & 1) && HasAMXSave; diff --git a/llvm/lib/TargetParser/TargetDataLayout.cpp b/llvm/lib/TargetParser/TargetDataLayout.cpp index d765d9c..d735923 100644 --- a/llvm/lib/TargetParser/TargetDataLayout.cpp +++ b/llvm/lib/TargetParser/TargetDataLayout.cpp @@ -208,7 +208,7 @@ static std::string computeMipsDataLayout(const Triple &TT, StringRef ABIName) {    return Ret;  } -static std::string computePowerDataLayout(const Triple &T) { +static std::string computePowerDataLayout(const Triple &T, StringRef ABIName) {    bool is64Bit = T.isPPC64();    std::string Ret; @@ -228,7 +228,8 @@ static std::string computePowerDataLayout(const Triple &T) {    // If the target ABI uses function descriptors, then the alignment of function    // pointers depends on the alignment used to emit the descriptor. Otherwise,    // function pointers are aligned to 32 bits because the instructions must be. -  if ((T.getArch() == Triple::ppc64 && !T.isPPC64ELFv2ABI())) { +  if ((T.getArch() == Triple::ppc64 && +       (!T.isPPC64ELFv2ABI() && ABIName != "elfv2"))) {      Ret += "-Fi64";    } else if (T.isOSAIX()) {      Ret += is64Bit ? 
"-Fi64" : "-Fi32"; @@ -573,7 +574,7 @@ std::string Triple::computeDataLayout(StringRef ABIName) const {    case Triple::ppcle:    case Triple::ppc64:    case Triple::ppc64le: -    return computePowerDataLayout(*this); +    return computePowerDataLayout(*this, ABIName);    case Triple::r600:    case Triple::amdgcn:      return computeAMDDataLayout(*this); diff --git a/llvm/lib/TargetParser/X86TargetParser.cpp b/llvm/lib/TargetParser/X86TargetParser.cpp index b13c795..37e8ad9 100644 --- a/llvm/lib/TargetParser/X86TargetParser.cpp +++ b/llvm/lib/TargetParser/X86TargetParser.cpp @@ -143,7 +143,7 @@ constexpr FeatureBitset FeaturesDiamondRapids =      FeatureAVXVNNIINT8 | FeatureAVXVNNIINT16 | FeatureSHA512 | FeatureSM3 |      FeatureSM4 | FeatureEGPR | FeatureZU | FeatureCCMP | FeaturePush2Pop2 |      FeaturePPX | FeatureNDD | FeatureNF | FeatureMOVRS | FeatureAMX_MOVRS | -    FeatureAMX_AVX512 | FeatureAMX_FP8 | FeatureAMX_TF32 | FeatureAMX_TRANSPOSE; +    FeatureAMX_AVX512 | FeatureAMX_FP8 | FeatureAMX_TF32;  // Intel Atom processors.  // Bonnell has feature parity with Core2 and adds MOVBE. @@ -615,7 +615,6 @@ constexpr FeatureBitset ImpliedFeaturesAMX_FP16 = FeatureAMX_TILE;  constexpr FeatureBitset ImpliedFeaturesAMX_INT8 = FeatureAMX_TILE;  constexpr FeatureBitset ImpliedFeaturesAMX_COMPLEX = FeatureAMX_TILE;  constexpr FeatureBitset ImpliedFeaturesAMX_FP8 = FeatureAMX_TILE; -constexpr FeatureBitset ImpliedFeaturesAMX_TRANSPOSE = FeatureAMX_TILE;  constexpr FeatureBitset ImpliedFeaturesAMX_MOVRS = FeatureAMX_TILE;  constexpr FeatureBitset ImpliedFeaturesAMX_AVX512 =      FeatureAMX_TILE | FeatureAVX10_2; diff --git a/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp index 5ba2167..cc53ec2 100644 --- a/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp @@ -1957,8 +1957,12 @@ Value *DataFlowSanitizer::getShadowAddress(Value *Addr,  Value *DataFlowSanitizer::getShadowAddress(Value *Addr,                                             BasicBlock::iterator Pos) {    IRBuilder<> IRB(Pos->getParent(), Pos); -  Value *ShadowOffset = getShadowOffset(Addr, IRB); -  return getShadowAddress(Addr, Pos, ShadowOffset); +  Value *ShadowAddr = getShadowOffset(Addr, IRB); +  uint64_t ShadowBase = MapParams->ShadowBase; +  if (ShadowBase != 0) +    ShadowAddr = +        IRB.CreateAdd(ShadowAddr, ConstantInt::get(IntptrTy, ShadowBase)); +  return getShadowAddress(Addr, Pos, ShadowAddr);  }  Value *DFSanFunction::combineShadowsThenConvert(Type *T, Value *V1, Value *V2, diff --git a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp index 7795cce..b5548d4 100644 --- a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp +++ b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp @@ -69,14 +69,6 @@ namespace llvm {  // Command line option to enable vtable value profiling. Defined in  // ProfileData/InstrProf.cpp: -enable-vtable-value-profiling=  extern cl::opt<bool> EnableVTableValueProfiling; -// TODO: Remove -debug-info-correlate in next LLVM release, in favor of -// -profile-correlate=debug-info. -cl::opt<bool> DebugInfoCorrelate( -    "debug-info-correlate", -    cl::desc("Use debug info to correlate profiles. 
(Deprecated, use " -             "-profile-correlate=debug-info)"), -    cl::init(false)); -  LLVM_ABI cl::opt<InstrProfCorrelator::ProfCorrelatorKind> ProfileCorrelate(      "profile-correlate",      cl::desc("Use debug info or binary file to correlate profiles."), @@ -1047,7 +1039,7 @@ void InstrLowerer::lowerValueProfileInst(InstrProfValueProfileInst *Ind) {    // in lightweight mode. We need to move the value profile pointer to the    // Counter struct to get this working.    assert( -      !DebugInfoCorrelate && ProfileCorrelate == InstrProfCorrelator::NONE && +      ProfileCorrelate == InstrProfCorrelator::NONE &&        "Value profiling is not yet supported with lightweight instrumentation");    GlobalVariable *Name = Ind->getName();    auto It = ProfileDataMap.find(Name); @@ -1504,7 +1496,7 @@ static inline Constant *getVTableAddrForProfData(GlobalVariable *GV) {  }  void InstrLowerer::getOrCreateVTableProfData(GlobalVariable *GV) { -  assert(!DebugInfoCorrelate && +  assert(ProfileCorrelate != InstrProfCorrelator::DEBUG_INFO &&           "Value profiling is not supported with lightweight instrumentation");    if (GV->isDeclaration() || GV->hasAvailableExternallyLinkage())      return; @@ -1584,8 +1576,7 @@ GlobalVariable *InstrLowerer::setupProfileSection(InstrProfInstBase *Inc,    // Use internal rather than private linkage so the counter variable shows up    // in the symbol table when using debug info for correlation. -  if ((DebugInfoCorrelate || -       ProfileCorrelate == InstrProfCorrelator::DEBUG_INFO) && +  if (ProfileCorrelate == InstrProfCorrelator::DEBUG_INFO &&        TT.isOSBinFormatMachO() && Linkage == GlobalValue::PrivateLinkage)      Linkage = GlobalValue::InternalLinkage; @@ -1691,8 +1682,7 @@ InstrLowerer::getOrCreateRegionCounters(InstrProfCntrInstBase *Inc) {    auto *CounterPtr = setupProfileSection(Inc, IPSK_cnts);    PD.RegionCounters = CounterPtr; -  if (DebugInfoCorrelate || -      ProfileCorrelate == InstrProfCorrelator::DEBUG_INFO) { +  if (ProfileCorrelate == InstrProfCorrelator::DEBUG_INFO) {      LLVMContext &Ctx = M.getContext();      Function *Fn = Inc->getParent()->getParent();      if (auto *SP = Fn->getSubprogram()) { @@ -1737,7 +1727,7 @@ InstrLowerer::getOrCreateRegionCounters(InstrProfCntrInstBase *Inc) {  void InstrLowerer::createDataVariable(InstrProfCntrInstBase *Inc) {    // When debug information is correlated to profile data, a data variable    // is not needed. 
-  if (DebugInfoCorrelate || ProfileCorrelate == InstrProfCorrelator::DEBUG_INFO) +  if (ProfileCorrelate == InstrProfCorrelator::DEBUG_INFO)      return;    GlobalVariable *NamePtr = Inc->getName(); diff --git a/llvm/lib/Transforms/Instrumentation/MemProfUse.cpp b/llvm/lib/Transforms/Instrumentation/MemProfUse.cpp index a6ec6c1..b72d41a 100644 --- a/llvm/lib/Transforms/Instrumentation/MemProfUse.cpp +++ b/llvm/lib/Transforms/Instrumentation/MemProfUse.cpp @@ -127,15 +127,19 @@ static uint64_t computeStackId(const memprof::Frame &Frame) {    return computeStackId(Frame.Function, Frame.LineOffset, Frame.Column);  } +static AllocationType getAllocType(const AllocationInfo *AllocInfo) { +  return getAllocType(AllocInfo->Info.getTotalLifetimeAccessDensity(), +                      AllocInfo->Info.getAllocCount(), +                      AllocInfo->Info.getTotalLifetime()); +} +  static AllocationType addCallStack(CallStackTrie &AllocTrie,                                     const AllocationInfo *AllocInfo,                                     uint64_t FullStackId) {    SmallVector<uint64_t> StackIds;    for (const auto &StackFrame : AllocInfo->CallStack)      StackIds.push_back(computeStackId(StackFrame)); -  auto AllocType = getAllocType(AllocInfo->Info.getTotalLifetimeAccessDensity(), -                                AllocInfo->Info.getAllocCount(), -                                AllocInfo->Info.getTotalLifetime()); +  auto AllocType = getAllocType(AllocInfo);    std::vector<ContextTotalSize> ContextSizeInfo;    if (recordContextSizeInfoForAnalysis()) {      auto TotalSize = AllocInfo->Info.getTotalSize(); @@ -216,7 +220,6 @@ static void HandleUnsupportedAnnotationKinds(GlobalVariable &GVar,    }    LLVM_DEBUG(dbgs() << "Skip annotation for " << GVar.getName() << " due to "                      << Reason << ".\n"); -  return;  }  struct AllocMatchInfo { @@ -406,22 +409,39 @@ handleAllocSite(Instruction &I, CallBase *CI,                  const std::set<const AllocationInfo *> &AllocInfoSet,                  std::map<std::pair<uint64_t, unsigned>, AllocMatchInfo>                      &FullStackIdToAllocMatchInfo) { +  // TODO: Remove this once the profile creation logic deduplicates contexts +  // that are the same other than the IsInlineFrame bool. Until then, keep the +  // largest. +  DenseMap<uint64_t, const AllocationInfo *> UniqueFullContextIdAllocInfo; +  for (auto *AllocInfo : AllocInfoSet) { +    auto FullStackId = computeFullStackId(AllocInfo->CallStack); +    auto [It, Inserted] = +        UniqueFullContextIdAllocInfo.insert({FullStackId, AllocInfo}); +    // If inserted entry, done. +    if (Inserted) +      continue; +    // Keep the larger one, or the noncold one if they are the same size. +    auto CurSize = It->second->Info.getTotalSize(); +    auto NewSize = AllocInfo->Info.getTotalSize(); +    if ((CurSize > NewSize) || +        (CurSize == NewSize && +         getAllocType(AllocInfo) != AllocationType::NotCold)) +      continue; +    It->second = AllocInfo; +  }    // We may match this instruction's location list to multiple MIB    // contexts. Add them to a Trie specialized for trimming the contexts to    // the minimal needed to disambiguate contexts with unique behavior.    CallStackTrie AllocTrie(&ORE, MaxColdSize);    uint64_t TotalSize = 0;    uint64_t TotalColdSize = 0; -  for (auto *AllocInfo : AllocInfoSet) { +  for (auto &[FullStackId, AllocInfo] : UniqueFullContextIdAllocInfo) {      // Check the full inlined call stack against this one.      
// If we found and thus matched all frames on the call, include      // this MIB.      if (stackFrameIncludesInlinedCallStack(AllocInfo->CallStack,                                             InlinedCallStack)) {        NumOfMemProfMatchedAllocContexts++; -      uint64_t FullStackId = 0; -      if (ClPrintMemProfMatchInfo || recordContextSizeInfoForAnalysis()) -        FullStackId = computeFullStackId(AllocInfo->CallStack);        auto AllocType = addCallStack(AllocTrie, AllocInfo, FullStackId);        TotalSize += AllocInfo->Info.getTotalSize();        if (AllocType == AllocationType::Cold) diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp index 71736cf..af53fa0 100644 --- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp +++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp @@ -456,7 +456,7 @@ createIRLevelProfileFlagVar(Module &M,      ProfileVersion |= VARIANT_MASK_INSTR_ENTRY;    if (PGOInstrumentLoopEntries)      ProfileVersion |= VARIANT_MASK_INSTR_LOOP_ENTRIES; -  if (DebugInfoCorrelate || ProfileCorrelate == InstrProfCorrelator::DEBUG_INFO) +  if (ProfileCorrelate == InstrProfCorrelator::DEBUG_INFO)      ProfileVersion |= VARIANT_MASK_DBG_CORRELATE;    if (PGOFunctionEntryCoverage)      ProfileVersion |= diff --git a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp index 66e45ec..e84ca81 100644 --- a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp +++ b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp @@ -122,16 +122,22 @@ static cl::opt<unsigned>                    cl::desc("Maximum cost accepted for the transformation"),                    cl::Hidden, cl::init(50)); -extern cl::opt<bool> ProfcheckDisableMetadataFixes; - -} // namespace llvm -  static cl::opt<double> MaxClonedRate(      "dfa-max-cloned-rate",      cl::desc(          "Maximum cloned instructions rate accepted for the transformation"),      cl::Hidden, cl::init(7.5)); +static cl::opt<unsigned> +    MaxOuterUseBlocks("dfa-max-out-use-blocks", +                      cl::desc("Maximum unduplicated blocks with outer uses " +                               "accepted for the transformation"), +                      cl::Hidden, cl::init(40)); + +extern cl::opt<bool> ProfcheckDisableMetadataFixes; + +} // namespace llvm +  namespace {  class SelectInstToUnfold {    SelectInst *SI; @@ -965,8 +971,16 @@ private:      // SLPVectorizer.      // TODO: Thread the switch partially before reaching the threshold.      uint64_t NumOrigInst = 0; -    for (auto *BB : DuplicateMap.keys()) +    uint64_t NumOuterUseBlock = 0; +    for (auto *BB : DuplicateMap.keys()) {        NumOrigInst += BB->sizeWithoutDebug(); +      // Only unduplicated blocks with single predecessor require new phi +      // nodes. +      for (auto *Succ : successors(BB)) +        if (!DuplicateMap.count(Succ) && Succ->getSinglePredecessor()) +          NumOuterUseBlock++; +    } +      if (double(NumClonedInst) / double(NumOrigInst) > MaxClonedRate) {        LLVM_DEBUG(dbgs() << "DFA Jump Threading: Not jump threading, too much "                             "instructions wll be cloned\n"); @@ -977,6 +991,20 @@ private:        return false;      } +    // Too much unduplicated blocks with outer uses may cause too much +    // insertions of phi nodes for duplicated definitions. TODO: Drop this +    // threshold if we come up with another way to reduce the number of inserted +    // phi nodes. 
+    if (NumOuterUseBlock > MaxOuterUseBlocks) { +      LLVM_DEBUG(dbgs() << "DFA Jump Threading: Not jump threading, too much " +                           "blocks with outer uses\n"); +      ORE->emit([&]() { +        return OptimizationRemarkMissed(DEBUG_TYPE, "NotProfitable", Switch) +               << "Too much blocks with outer uses."; +      }); +      return false; +    } +      InstructionCost DuplicationCost = 0;      unsigned JumpTableSize = 0; diff --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp index 7ebcc21..4ba4ba3 100644 --- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp +++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -162,8 +162,6 @@ class IndVarSimplify {                                   const SCEV *ExitCount,                                   PHINode *IndVar, SCEVExpander &Rewriter); -  bool sinkUnusedInvariants(Loop *L); -  public:    IndVarSimplify(LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT,                   const DataLayout &DL, TargetLibraryInfo *TLI, @@ -1079,85 +1077,6 @@ linearFunctionTestReplace(Loop *L, BasicBlock *ExitingBB,    return true;  } -//===----------------------------------------------------------------------===// -//  sinkUnusedInvariants. A late subpass to cleanup loop preheaders. -//===----------------------------------------------------------------------===// - -/// If there's a single exit block, sink any loop-invariant values that -/// were defined in the preheader but not used inside the loop into the -/// exit block to reduce register pressure in the loop. -bool IndVarSimplify::sinkUnusedInvariants(Loop *L) { -  BasicBlock *ExitBlock = L->getExitBlock(); -  if (!ExitBlock) return false; - -  BasicBlock *Preheader = L->getLoopPreheader(); -  if (!Preheader) return false; - -  bool MadeAnyChanges = false; -  for (Instruction &I : llvm::make_early_inc_range(llvm::reverse(*Preheader))) { - -    // Skip BB Terminator. -    if (Preheader->getTerminator() == &I) -      continue; - -    // New instructions were inserted at the end of the preheader. -    if (isa<PHINode>(I)) -      break; - -    // Don't move instructions which might have side effects, since the side -    // effects need to complete before instructions inside the loop.  Also don't -    // move instructions which might read memory, since the loop may modify -    // memory. Note that it's okay if the instruction might have undefined -    // behavior: LoopSimplify guarantees that the preheader dominates the exit -    // block. -    if (I.mayHaveSideEffects() || I.mayReadFromMemory()) -      continue; - -    // Skip debug or pseudo instructions. -    if (I.isDebugOrPseudoInst()) -      continue; - -    // Skip eh pad instructions. -    if (I.isEHPad()) -      continue; - -    // Don't sink alloca: we never want to sink static alloca's out of the -    // entry block, and correctly sinking dynamic alloca's requires -    // checks for stacksave/stackrestore intrinsics. -    // FIXME: Refactor this check somehow? -    if (isa<AllocaInst>(&I)) -      continue; - -    // Determine if there is a use in or before the loop (direct or -    // otherwise). 
-    bool UsedInLoop = false; -    for (Use &U : I.uses()) { -      Instruction *User = cast<Instruction>(U.getUser()); -      BasicBlock *UseBB = User->getParent(); -      if (PHINode *P = dyn_cast<PHINode>(User)) { -        unsigned i = -          PHINode::getIncomingValueNumForOperand(U.getOperandNo()); -        UseBB = P->getIncomingBlock(i); -      } -      if (UseBB == Preheader || L->contains(UseBB)) { -        UsedInLoop = true; -        break; -      } -    } - -    // If there is, the def must remain in the preheader. -    if (UsedInLoop) -      continue; - -    // Otherwise, sink it to the exit block. -    I.moveBefore(ExitBlock->getFirstInsertionPt()); -    SE->forgetValue(&I); -    MadeAnyChanges = true; -  } - -  return MadeAnyChanges; -} -  static void replaceExitCond(BranchInst *BI, Value *NewCond,                              SmallVectorImpl<WeakTrackingVH> &DeadInsts) {    auto *OldCond = BI->getCondition(); @@ -2065,10 +1984,6 @@ bool IndVarSimplify::run(Loop *L) {    // The Rewriter may not be used from this point on. -  // Loop-invariant instructions in the preheader that aren't used in the -  // loop may be sunk below the loop to reduce register pressure. -  Changed |= sinkUnusedInvariants(L); -    // rewriteFirstIterationLoopExitValues does not rely on the computation of    // trip count and therefore can further simplify exit values in addition to    // rewriteLoopExitValues. diff --git a/llvm/lib/Transforms/Scalar/LICM.cpp b/llvm/lib/Transforms/Scalar/LICM.cpp index b2c526b..d13b990 100644 --- a/llvm/lib/Transforms/Scalar/LICM.cpp +++ b/llvm/lib/Transforms/Scalar/LICM.cpp @@ -211,9 +211,15 @@ static Instruction *cloneInstructionInExitBlock(  static void eraseInstruction(Instruction &I, ICFLoopSafetyInfo &SafetyInfo,                               MemorySSAUpdater &MSSAU); -static void moveInstructionBefore(Instruction &I, BasicBlock::iterator Dest, -                                  ICFLoopSafetyInfo &SafetyInfo, -                                  MemorySSAUpdater &MSSAU, ScalarEvolution *SE); +static void moveInstructionBefore( +    Instruction &I, BasicBlock::iterator Dest, ICFLoopSafetyInfo &SafetyInfo, +    MemorySSAUpdater &MSSAU, ScalarEvolution *SE, +    MemorySSA::InsertionPlace Point = MemorySSA::BeforeTerminator); + +static bool sinkUnusedInvariantsFromPreheaderToExit( +    Loop *L, AAResults *AA, ICFLoopSafetyInfo *SafetyInfo, +    MemorySSAUpdater &MSSAU, ScalarEvolution *SE, DominatorTree *DT, +    SinkAndHoistLICMFlags &SinkFlags, OptimizationRemarkEmitter *ORE);  static void foreachMemoryAccess(MemorySSA *MSSA, Loop *L,                                  function_ref<void(Instruction *)> Fn); @@ -471,6 +477,12 @@ bool LoopInvariantCodeMotion::runOnLoop(Loop *L, AAResults *AA, LoopInfo *LI,                                      TLI, TTI, L, MSSAU, &SafetyInfo, Flags, ORE)              : sinkRegion(DT->getNode(L->getHeader()), AA, LI, DT, TLI, TTI, L,                           MSSAU, &SafetyInfo, Flags, ORE); + +  // sink pre-header defs that are unused in-loop into the unique exit to reduce +  // pressure. 
+  Changed |= sinkUnusedInvariantsFromPreheaderToExit(L, AA, &SafetyInfo, MSSAU, +                                                     SE, DT, Flags, ORE); +    Flags.setIsSink(false);    if (Preheader)      Changed |= hoistRegion(DT->getNode(L->getHeader()), AA, LI, DT, AC, TLI, L, @@ -1456,19 +1468,80 @@ static void eraseInstruction(Instruction &I, ICFLoopSafetyInfo &SafetyInfo,  static void moveInstructionBefore(Instruction &I, BasicBlock::iterator Dest,                                    ICFLoopSafetyInfo &SafetyInfo, -                                  MemorySSAUpdater &MSSAU, -                                  ScalarEvolution *SE) { +                                  MemorySSAUpdater &MSSAU, ScalarEvolution *SE, +                                  MemorySSA::InsertionPlace Point) {    SafetyInfo.removeInstruction(&I);    SafetyInfo.insertInstructionTo(&I, Dest->getParent());    I.moveBefore(*Dest->getParent(), Dest);    if (MemoryUseOrDef *OldMemAcc = cast_or_null<MemoryUseOrDef>(            MSSAU.getMemorySSA()->getMemoryAccess(&I))) -    MSSAU.moveToPlace(OldMemAcc, Dest->getParent(), -                      MemorySSA::BeforeTerminator); +    MSSAU.moveToPlace(OldMemAcc, Dest->getParent(), Point);    if (SE)      SE->forgetBlockAndLoopDispositions(&I);  } +// If there's a single exit block, sink any loop-invariant values that were +// defined in the preheader but not used inside the loop into the exit block +// to reduce register pressure in the loop. +static bool sinkUnusedInvariantsFromPreheaderToExit( +    Loop *L, AAResults *AA, ICFLoopSafetyInfo *SafetyInfo, +    MemorySSAUpdater &MSSAU, ScalarEvolution *SE, DominatorTree *DT, +    SinkAndHoistLICMFlags &SinkFlags, OptimizationRemarkEmitter *ORE) { +  BasicBlock *ExitBlock = L->getExitBlock(); +  if (!ExitBlock) +    return false; + +  BasicBlock *Preheader = L->getLoopPreheader(); +  if (!Preheader) +    return false; + +  bool MadeAnyChanges = false; + +  for (Instruction &I : llvm::make_early_inc_range(llvm::reverse(*Preheader))) { + +    // Skip terminator. +    if (Preheader->getTerminator() == &I) +      continue; + +    // New instructions were inserted at the end of the preheader. +    if (isa<PHINode>(I)) +      break; + +    // Don't move instructions which might have side effects, since the side +    // effects need to complete before instructions inside the loop. Note that +    // it's okay if the instruction might have undefined behavior: LoopSimplify +    // guarantees that the preheader dominates the exit block. +    if (I.mayHaveSideEffects()) +      continue; + +    if (!canSinkOrHoistInst(I, AA, DT, L, MSSAU, true, SinkFlags, nullptr)) +      continue; + +    // Determine if there is a use in or before the loop (direct or +    // otherwise). 
+    bool UsedInLoopOrPreheader = false; +    for (Use &U : I.uses()) { +      auto *UserI = cast<Instruction>(U.getUser()); +      BasicBlock *UseBB = UserI->getParent(); +      if (auto *PN = dyn_cast<PHINode>(UserI)) { +        UseBB = PN->getIncomingBlock(U); +      } +      if (UseBB == Preheader || L->contains(UseBB)) { +        UsedInLoopOrPreheader = true; +        break; +      } +    } +    if (UsedInLoopOrPreheader) +      continue; + +    moveInstructionBefore(I, ExitBlock->getFirstInsertionPt(), *SafetyInfo, +                          MSSAU, SE, MemorySSA::Beginning); +    MadeAnyChanges = true; +  } + +  return MadeAnyChanges; +} +  static Instruction *sinkThroughTriviallyReplaceablePHI(      PHINode *TPN, Instruction *I, LoopInfo *LI,      SmallDenseMap<BasicBlock *, Instruction *, 32> &SunkCopies, diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp index 1a279b6..001215a 100644 --- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -1318,6 +1318,11 @@ public:    /// the loop, in which case some special-case heuristics may be used.    bool AllFixupsOutsideLoop = true; +  /// This records whether all of the fixups using this LSRUse are unconditional +  /// within the loop, meaning they will be executed on every path to the loop +  /// latch. This includes fixups before early exits. +  bool AllFixupsUnconditional = true; +    /// RigidFormula is set to true to guarantee that this use will be associated    /// with a single formula--the one that initially matched. Some SCEV    /// expressions cannot be expanded. This allows LSR to consider the registers @@ -1421,16 +1426,22 @@ void Cost::RateRegister(const Formula &F, const SCEV *Reg,      if (TTI->isIndexedLoadLegal(TTI->MIM_PostInc, AR->getType()) ||          TTI->isIndexedStoreLegal(TTI->MIM_PostInc, AR->getType())) {        const SCEV *Start; -      const SCEVConstant *Step; -      if (match(AR, m_scev_AffineAddRec(m_SCEV(Start), m_SCEVConstant(Step)))) +      const APInt *Step; +      if (match(AR, m_scev_AffineAddRec(m_SCEV(Start), m_scev_APInt(Step)))) {          // If the step size matches the base offset, we could use pre-indexed          // addressing. -        if (((AMK & TTI::AMK_PreIndexed) && F.BaseOffset.isFixed() && -             Step->getAPInt() == F.BaseOffset.getFixedValue()) || -            ((AMK & TTI::AMK_PostIndexed) && !isa<SCEVConstant>(Start) && -             SE->isLoopInvariant(Start, L))) +        bool CanPreIndex = (AMK & TTI::AMK_PreIndexed) && +                           F.BaseOffset.isFixed() && +                           *Step == F.BaseOffset.getFixedValue(); +        bool CanPostIndex = (AMK & TTI::AMK_PostIndexed) && +                            !isa<SCEVConstant>(Start) && +                            SE->isLoopInvariant(Start, L); +        // We can only pre or post index when the load/store is unconditional. +        if ((CanPreIndex || CanPostIndex) && LU.AllFixupsUnconditional)            LoopCost = 0; +      }      } +      // If the loop counts down to zero and we'll be using a hardware loop then      // the addrec will be combined into the hardware loop instruction.      
if (LU.Kind == LSRUse::ICmpZero && F.countsDownToZero() && @@ -1783,6 +1794,9 @@ void LSRUse::print(raw_ostream &OS) const {    if (AllFixupsOutsideLoop)      OS << ", all-fixups-outside-loop"; +  if (AllFixupsUnconditional) +    OS << ", all-fixups-unconditional"; +    if (WidestFixupType)      OS << ", widest fixup type: " << *WidestFixupType;  } @@ -2213,6 +2227,7 @@ class LSRInstance {    void InsertSupplementalFormula(const SCEV *S, LSRUse &LU, size_t LUIdx);    void CountRegisters(const Formula &F, size_t LUIdx);    bool InsertFormula(LSRUse &LU, unsigned LUIdx, const Formula &F); +  bool IsFixupExecutedEachIncrement(const LSRFixup &LF) const;    void CollectLoopInvariantFixupsAndFormulae(); @@ -3607,6 +3622,7 @@ void LSRInstance::CollectFixupsAndInitialFormulae() {      LF.PostIncLoops = TmpPostIncLoops;      LF.Offset = Offset;      LU.AllFixupsOutsideLoop &= LF.isUseFullyOutsideLoop(L); +    LU.AllFixupsUnconditional &= IsFixupExecutedEachIncrement(LF);      // Create SCEV as Formula for calculating baseline cost      if (!VisitedLSRUse.count(LUIdx) && !LF.isUseFullyOutsideLoop(L)) { @@ -3680,6 +3696,14 @@ bool LSRInstance::InsertFormula(LSRUse &LU, unsigned LUIdx, const Formula &F) {    return true;  } +/// Test whether this fixup will be executed each time the corresponding IV +/// increment instruction is executed. +bool LSRInstance::IsFixupExecutedEachIncrement(const LSRFixup &LF) const { +  // If the fixup block dominates the IV increment block then there is no path +  // through the loop to the increment that doesn't pass through the fixup. +  return DT.dominates(LF.UserInst->getParent(), IVIncInsertPos->getParent()); +} +  /// Check for other uses of loop-invariant values which we're tracking. These  /// other uses will pin these values in registers, making them less profitable  /// for elimination. @@ -3803,6 +3827,7 @@ LSRInstance::CollectLoopInvariantFixupsAndFormulae() {          LF.OperandValToReplace = U;          LF.Offset = Offset;          LU.AllFixupsOutsideLoop &= LF.isUseFullyOutsideLoop(L); +        LU.AllFixupsUnconditional &= IsFixupExecutedEachIncrement(LF);          if (!LU.WidestFixupType ||              SE.getTypeSizeInBits(LU.WidestFixupType) <              SE.getTypeSizeInBits(LF.OperandValToReplace->getType())) @@ -4940,6 +4965,7 @@ void LSRInstance::NarrowSearchSpaceByCollapsingUnrolledCode() {        LLVM_DEBUG(dbgs() << "  Deleting use "; LU.print(dbgs()); dbgs() << '\n');        LUThatHas->AllFixupsOutsideLoop &= LU.AllFixupsOutsideLoop; +      LUThatHas->AllFixupsUnconditional &= LU.AllFixupsUnconditional;        // Transfer the fixups of LU to LUThatHas.        
for (LSRFixup &Fixup : LU.Fixups) { diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp index e043d07..08be5df 100644 --- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp +++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp @@ -1534,8 +1534,8 @@ bool MemCpyOptPass::performStackMoveOptzn(Instruction *Load, Instruction *Store,    bool SrcNotDom = false;    auto CaptureTrackingWithModRef = -      [&](Instruction *AI, -          function_ref<bool(Instruction *)> ModRefCallback) -> bool { +      [&](Instruction *AI, function_ref<bool(Instruction *)> ModRefCallback, +          bool &AddressCaptured) -> bool {      SmallVector<Instruction *, 8> Worklist;      Worklist.push_back(AI);      unsigned MaxUsesToExplore = getDefaultMaxUsesToExploreForCaptureTracking(); @@ -1559,8 +1559,9 @@ bool MemCpyOptPass::performStackMoveOptzn(Instruction *Load, Instruction *Store,          if (!Visited.insert(&U).second)            continue;          UseCaptureInfo CI = DetermineUseCaptureKind(U, AI); -        if (capturesAnything(CI.UseCC)) +        if (capturesAnyProvenance(CI.UseCC))            return false; +        AddressCaptured |= capturesAddress(CI.UseCC);          if (UI->mayReadOrWriteMemory()) {            if (UI->isLifetimeStartOrEnd()) { @@ -1627,7 +1628,9 @@ bool MemCpyOptPass::performStackMoveOptzn(Instruction *Load, Instruction *Store,      return true;    }; -  if (!CaptureTrackingWithModRef(DestAlloca, DestModRefCallback)) +  bool DestAddressCaptured = false; +  if (!CaptureTrackingWithModRef(DestAlloca, DestModRefCallback, +                                 DestAddressCaptured))      return false;    // Bailout if Dest may have any ModRef before Store.    if (!ReachabilityWorklist.empty() && @@ -1653,7 +1656,14 @@ bool MemCpyOptPass::performStackMoveOptzn(Instruction *Load, Instruction *Store,      return true;    }; -  if (!CaptureTrackingWithModRef(SrcAlloca, SrcModRefCallback)) +  bool SrcAddressCaptured = false; +  if (!CaptureTrackingWithModRef(SrcAlloca, SrcModRefCallback, +                                 SrcAddressCaptured)) +    return false; + +  // If both the source and destination address are captured, the fact that they +  // are no longer two separate allocations may be observed. +  if (DestAddressCaptured && SrcAddressCaptured)      return false;    // We can do the transformation. 
First, move the SrcAlloca to the start of the diff --git a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp index 5af6c96..bb6c879 100644 --- a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -81,6 +81,7 @@ STATISTIC(  STATISTIC(NumInvariantConditionsInjected,            "Number of invariant conditions injected and unswitched"); +namespace llvm {  static cl::opt<bool> EnableNonTrivialUnswitch(      "enable-nontrivial-unswitch", cl::init(false), cl::Hidden,      cl::desc("Forcibly enables non-trivial loop unswitching rather than " @@ -131,11 +132,17 @@ static cl::opt<bool> InjectInvariantConditions(  static cl::opt<unsigned> InjectInvariantConditionHotnesThreshold(      "simple-loop-unswitch-inject-invariant-condition-hotness-threshold", -    cl::Hidden, cl::desc("Only try to inject loop invariant conditions and " -                         "unswitch on them to eliminate branches that are " -                         "not-taken 1/<this option> times or less."), +    cl::Hidden, +    cl::desc("Only try to inject loop invariant conditions and " +             "unswitch on them to eliminate branches that are " +             "not-taken 1/<this option> times or less."),      cl::init(16)); +static cl::opt<bool> EstimateProfile("simple-loop-unswitch-estimate-profile", +                                     cl::Hidden, cl::init(true)); +extern cl::opt<bool> ProfcheckDisableMetadataFixes; +} // namespace llvm +  AnalysisKey ShouldRunExtraSimpleLoopUnswitch::Key;  namespace {  struct CompareDesc { @@ -268,13 +275,42 @@ static bool areLoopExitPHIsLoopInvariant(const Loop &L,    llvm_unreachable("Basic blocks should never be empty!");  } -/// Copy a set of loop invariant values \p ToDuplicate and insert them at the +/// Copy a set of loop invariant values \p Invariants and insert them at the  /// end of \p BB and conditionally branch on the copied condition. We only  /// branch on a single value. +/// We attempt to estimate the profile of the resulting conditional branch from +/// \p ComputeProfFrom, which is the original conditional branch we're +/// unswitching. +/// When \p Direction is true, the \p Invariants form a disjunction, and the +/// branch conditioned on it exits the loop on the "true" case. When \p +/// Direction is false, the \p Invariants form a conjunction and the branch +/// exits on the "false" case.  static void buildPartialUnswitchConditionalBranch(      BasicBlock &BB, ArrayRef<Value *> Invariants, bool Direction,      BasicBlock &UnswitchedSucc, BasicBlock &NormalSucc, bool InsertFreeze, -    const Instruction *I, AssumptionCache *AC, const DominatorTree &DT) { +    const Instruction *I, AssumptionCache *AC, const DominatorTree &DT, +    const BranchInst &ComputeProfFrom) { + +  SmallVector<uint32_t> BranchWeights; +  bool HasBranchWeights = EstimateProfile && !ProfcheckDisableMetadataFixes && +                          extractBranchWeights(ComputeProfFrom, BranchWeights); +  // If Direction is true, that means we had a disjunction and that the "true" +  // case exits. The probability of the disjunction of the subset of terms is at +  // most as high as the original one. So, if the probability is higher than the +  // one we'd assign in absence of a profile (i.e. 0.5), we will use 0.5, +  // but if it's lower, we will use the original probability. 
+  // Conversely, if Direction is false, that means we had a conjunction, and the +  // probability of exiting is captured in the second branch weight. That +  // probability is a disjunction (of the negation of the original terms). The +  // same reasoning applies as above. +  // Issue #165649: should we expect BFI to conserve, and use that to calculate +  // the branch weights? +  if (HasBranchWeights && +      static_cast<double>(BranchWeights[Direction ? 0 : 1]) / +              static_cast<double>(sum_of(BranchWeights)) > +          0.5) +    HasBranchWeights = false; +    IRBuilder<> IRB(&BB);    IRB.SetCurrentDebugLocation(DebugLoc::getCompilerGenerated()); @@ -287,8 +323,14 @@ static void buildPartialUnswitchConditionalBranch(    Value *Cond = Direction ? IRB.CreateOr(FrozenInvariants)                            : IRB.CreateAnd(FrozenInvariants); -  IRB.CreateCondBr(Cond, Direction ? &UnswitchedSucc : &NormalSucc, -                   Direction ? &NormalSucc : &UnswitchedSucc); +  auto *BR = IRB.CreateCondBr( +      Cond, Direction ? &UnswitchedSucc : &NormalSucc, +      Direction ? &NormalSucc : &UnswitchedSucc, +      HasBranchWeights ? ComputeProfFrom.getMetadata(LLVMContext::MD_prof) +                       : nullptr); +  if (!HasBranchWeights) +    setExplicitlyUnknownBranchWeightsIfProfiled( +        *BR, *BR->getParent()->getParent(), DEBUG_TYPE);  }  /// Copy a set of loop invariant values, and conditionally branch on them. @@ -658,7 +700,7 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT,               " condition!");      buildPartialUnswitchConditionalBranch(          *OldPH, Invariants, ExitDirection, *UnswitchedBB, *NewPH, -        FreezeLoopUnswitchCond, OldPH->getTerminator(), nullptr, DT); +        FreezeLoopUnswitchCond, OldPH->getTerminator(), nullptr, DT, BI);    }    // Update the dominator tree with the added edge. @@ -2477,7 +2519,7 @@ static void unswitchNontrivialInvariants(      else {        buildPartialUnswitchConditionalBranch(            *SplitBB, Invariants, Direction, *ClonedPH, *LoopPH, -          FreezeLoopUnswitchCond, BI, &AC, DT); +          FreezeLoopUnswitchCond, BI, &AC, DT, *BI);      }      DTUpdates.push_back({DominatorTree::Insert, SplitBB, ClonedPH}); diff --git a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp index 9829d4d..11db0ec 100644 --- a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp +++ b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp @@ -674,6 +674,79 @@ BasicBlock *llvm::SplitEdge(BasicBlock *BB, BasicBlock *Succ, DominatorTree *DT,    return SplitBlock(BB, BB->getTerminator(), DT, LI, MSSAU, BBName);  } +/// Helper function to update the cycle or loop information after inserting a +/// new block between a callbr instruction and one of its target blocks.  Adds +/// the new block to the innermost cycle or loop that the callbr instruction and +/// the original target block share. 
+/// \p LCI            cycle or loop information to update +/// \p CallBrBlock    block containing the callbr instruction +/// \p CallBrTarget   new target block of the callbr instruction +/// \p Succ           original target block of the callbr instruction +template <typename TI, typename T> +static bool updateCycleLoopInfo(TI *LCI, BasicBlock *CallBrBlock, +                                BasicBlock *CallBrTarget, BasicBlock *Succ) { +  static_assert(std::is_same_v<TI, CycleInfo> || std::is_same_v<TI, LoopInfo>, +                "type must be CycleInfo or LoopInfo"); +  if (!LCI) +    return false; + +  T *LC; +  if constexpr (std::is_same_v<TI, CycleInfo>) +    LC = LCI->getSmallestCommonCycle(CallBrBlock, Succ); +  else +    LC = LCI->getSmallestCommonLoop(CallBrBlock, Succ); +  if (!LC) +    return false; + +  if constexpr (std::is_same_v<TI, CycleInfo>) +    LCI->addBlockToCycle(CallBrTarget, LC); +  else +    LC->addBasicBlockToLoop(CallBrTarget, *LCI); + +  return true; +} + +BasicBlock *llvm::SplitCallBrEdge(BasicBlock *CallBrBlock, BasicBlock *Succ, +                                  unsigned SuccIdx, DomTreeUpdater *DTU, +                                  CycleInfo *CI, LoopInfo *LI, +                                  bool *UpdatedLI) { +  CallBrInst *CallBr = dyn_cast<CallBrInst>(CallBrBlock->getTerminator()); +  assert(CallBr && "expected callbr terminator"); +  assert(SuccIdx < CallBr->getNumSuccessors() && +         Succ == CallBr->getSuccessor(SuccIdx) && "invalid successor index"); + +  // Create a new block between callbr and the specified successor. +  // splitBlockBefore cannot be re-used here since it cannot split if the split +  // point is a PHI node (because BasicBlock::splitBasicBlockBefore cannot +  // handle that). But we don't need to rewire every part of a potential PHI +  // node. We only care about the edge between CallBrBlock and the original +  // successor. +  BasicBlock *CallBrTarget = +      BasicBlock::Create(CallBrBlock->getContext(), +                         CallBrBlock->getName() + ".target." + Succ->getName(), +                         CallBrBlock->getParent()); +  // Rewire control flow from the new target block to the original successor. +  Succ->replacePhiUsesWith(CallBrBlock, CallBrTarget); +  // Rewire control flow from callbr to the new target block. +  CallBr->setSuccessor(SuccIdx, CallBrTarget); +  // Jump from the new target block to the original successor. 
+  BranchInst::Create(Succ, CallBrTarget); + +  bool Updated = +      updateCycleLoopInfo<LoopInfo, Loop>(LI, CallBrBlock, CallBrTarget, Succ); +  if (UpdatedLI) +    *UpdatedLI = Updated; +  updateCycleLoopInfo<CycleInfo, Cycle>(CI, CallBrBlock, CallBrTarget, Succ); +  if (DTU) { +    DTU->applyUpdates({{DominatorTree::Insert, CallBrBlock, CallBrTarget}}); +    if (DTU->getDomTree().dominates(CallBrBlock, Succ)) +      DTU->applyUpdates({{DominatorTree::Delete, CallBrBlock, Succ}, +                         {DominatorTree::Insert, CallBrTarget, Succ}}); +  } + +  return CallBrTarget; +} +  void llvm::setUnwindEdgeTo(Instruction *TI, BasicBlock *Succ) {    if (auto *II = dyn_cast<InvokeInst>(TI))      II->setUnwindDest(Succ); diff --git a/llvm/lib/Transforms/Utils/ControlFlowUtils.cpp b/llvm/lib/Transforms/Utils/ControlFlowUtils.cpp index 0046a00..287a177 100644 --- a/llvm/lib/Transforms/Utils/ControlFlowUtils.cpp +++ b/llvm/lib/Transforms/Utils/ControlFlowUtils.cpp @@ -13,6 +13,7 @@  #include "llvm/Transforms/Utils/ControlFlowUtils.h"  #include "llvm/ADT/SetVector.h"  #include "llvm/Analysis/DomTreeUpdater.h" +#include "llvm/Analysis/LoopInfo.h"  #include "llvm/IR/Constants.h"  #include "llvm/IR/Instructions.h"  #include "llvm/IR/ValueHandle.h" @@ -281,7 +282,9 @@ std::pair<BasicBlock *, bool> ControlFlowHub::finalize(    for (auto [BB, Succ0, Succ1] : Branches) {  #ifndef NDEBUG -    assert(Incoming.insert(BB).second && "Duplicate entry for incoming block."); +    assert( +        (Incoming.insert(BB).second || isa<CallBrInst>(BB->getTerminator())) && +        "Duplicate entry for incoming block.");  #endif      if (Succ0)        Outgoing.insert(Succ0); diff --git a/llvm/lib/Transforms/Utils/FixIrreducible.cpp b/llvm/lib/Transforms/Utils/FixIrreducible.cpp index 45e1d12..804af22 100644 --- a/llvm/lib/Transforms/Utils/FixIrreducible.cpp +++ b/llvm/lib/Transforms/Utils/FixIrreducible.cpp @@ -79,6 +79,53 @@  // Limitation: The pass cannot handle switch statements and indirect  //             branches. Both must be lowered to plain branches first.  // +// CallBr support: CallBr is handled as a more general branch instruction which +// can have multiple successors. The pass redirects the edges to intermediate +// target blocks that unconditionally branch to the original callbr target +// blocks. This allows the control flow hub to know to which of the original +// target blocks to jump to. +// Example input CFG: +//                        Entry (callbr) +//                       /     \ +//                      v       v +//                      H ----> B +//                      ^      /| +//                       `----' | +//                              v +//                             Exit +// +// becomes: +//                        Entry (callbr) +//                       /     \ +//                      v       v +//                 target.H   target.B +//                      |       | +//                      v       v +//                      H ----> B +//                      ^      /| +//                       `----' | +//                              v +//                             Exit +// +// Note +// OUTPUT CFG: Converted to a natural loop with a new header N. +// +//                        Entry (callbr) +//                       /     \ +//                      v       v +//                 target.H   target.B +//                      \       / +//                       \     / +//                        v   v +//                          N <---. 
+//                         / \     \ +//                        /   \     | +//                       v     v    / +//                       H --> B --' +//                             | +//                             v +//                            Exit +//  //===----------------------------------------------------------------------===//  #include "llvm/Transforms/Utils/FixIrreducible.h" @@ -231,6 +278,7 @@ static bool fixIrreducible(Cycle &C, CycleInfo &CI, DominatorTree &DT,      return false;    LLVM_DEBUG(dbgs() << "Processing cycle:\n" << CI.print(&C) << "\n";); +  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);    ControlFlowHub CHub;    SetVector<BasicBlock *> Predecessors; @@ -242,18 +290,32 @@ static bool fixIrreducible(Cycle &C, CycleInfo &CI, DominatorTree &DT,    }    for (BasicBlock *P : Predecessors) { -    auto *Branch = cast<BranchInst>(P->getTerminator()); -    // Exactly one of the two successors is the header. -    BasicBlock *Succ0 = Branch->getSuccessor(0) == Header ? Header : nullptr; -    BasicBlock *Succ1 = Succ0 ? nullptr : Header; -    if (!Succ0) -      assert(Branch->getSuccessor(1) == Header); -    assert(Succ0 || Succ1); -    CHub.addBranch(P, Succ0, Succ1); - -    LLVM_DEBUG(dbgs() << "Added internal branch: " << P->getName() << " -> " -                      << (Succ0 ? Succ0->getName() : "") << " " -                      << (Succ1 ? Succ1->getName() : "") << "\n"); +    if (BranchInst *Branch = dyn_cast<BranchInst>(P->getTerminator())) { +      // Exactly one of the two successors is the header. +      BasicBlock *Succ0 = Branch->getSuccessor(0) == Header ? Header : nullptr; +      BasicBlock *Succ1 = Succ0 ? nullptr : Header; +      assert(Succ0 || Branch->getSuccessor(1) == Header); +      assert(Succ0 || Succ1); +      CHub.addBranch(P, Succ0, Succ1); + +      LLVM_DEBUG(dbgs() << "Added internal branch: " << printBasicBlock(P) +                        << " -> " << printBasicBlock(Succ0) +                        << (Succ0 && Succ1 ? " " : "") << printBasicBlock(Succ1) +                        << '\n'); +    } else if (CallBrInst *CallBr = dyn_cast<CallBrInst>(P->getTerminator())) { +      for (unsigned I = 0; I < CallBr->getNumSuccessors(); ++I) { +        BasicBlock *Succ = CallBr->getSuccessor(I); +        if (Succ != Header) +          continue; +        BasicBlock *NewSucc = SplitCallBrEdge(P, Succ, I, &DTU, &CI, LI); +        CHub.addBranch(NewSucc, Succ); +        LLVM_DEBUG(dbgs() << "Added internal branch: " +                          << printBasicBlock(NewSucc) << " -> " +                          << printBasicBlock(Succ) << '\n'); +      } +    } else { +      llvm_unreachable("unsupported block terminator"); +    }    }    // Redirect external incoming edges. This includes the edges on the header. @@ -266,17 +328,32 @@ static bool fixIrreducible(Cycle &C, CycleInfo &CI, DominatorTree &DT,    }    for (BasicBlock *P : Predecessors) { -    auto *Branch = cast<BranchInst>(P->getTerminator()); -    BasicBlock *Succ0 = Branch->getSuccessor(0); -    Succ0 = C.contains(Succ0) ? Succ0 : nullptr; -    BasicBlock *Succ1 = -        Branch->isUnconditional() ? nullptr : Branch->getSuccessor(1); -    Succ1 = Succ1 && C.contains(Succ1) ? Succ1 : nullptr; -    CHub.addBranch(P, Succ0, Succ1); - -    LLVM_DEBUG(dbgs() << "Added external branch: " << P->getName() << " -> " -                      << (Succ0 ? Succ0->getName() : "") << " " -                      << (Succ1 ? 
Succ1->getName() : "") << "\n"); +    if (BranchInst *Branch = dyn_cast<BranchInst>(P->getTerminator()); Branch) { +      BasicBlock *Succ0 = Branch->getSuccessor(0); +      Succ0 = C.contains(Succ0) ? Succ0 : nullptr; +      BasicBlock *Succ1 = +          Branch->isUnconditional() ? nullptr : Branch->getSuccessor(1); +      Succ1 = Succ1 && C.contains(Succ1) ? Succ1 : nullptr; +      CHub.addBranch(P, Succ0, Succ1); + +      LLVM_DEBUG(dbgs() << "Added external branch: " << printBasicBlock(P) +                        << " -> " << printBasicBlock(Succ0) +                        << (Succ0 && Succ1 ? " " : "") << printBasicBlock(Succ1) +                        << '\n'); +    } else if (CallBrInst *CallBr = dyn_cast<CallBrInst>(P->getTerminator())) { +      for (unsigned I = 0; I < CallBr->getNumSuccessors(); ++I) { +        BasicBlock *Succ = CallBr->getSuccessor(I); +        if (!C.contains(Succ)) +          continue; +        BasicBlock *NewSucc = SplitCallBrEdge(P, Succ, I, &DTU, &CI, LI); +        CHub.addBranch(NewSucc, Succ); +        LLVM_DEBUG(dbgs() << "Added external branch: " +                          << printBasicBlock(NewSucc) << " -> " +                          << printBasicBlock(Succ) << '\n'); +      } +    } else { +      llvm_unreachable("unsupported block terminator"); +    }    }    // Redirect all the backedges through a "hub" consisting of a series @@ -292,7 +369,6 @@ static bool fixIrreducible(Cycle &C, CycleInfo &CI, DominatorTree &DT,    SetVector<BasicBlock *> Entries;    Entries.insert(C.entry_rbegin(), C.entry_rend()); -  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);    CHub.finalize(&DTU, GuardBlocks, "irr");  #if defined(EXPENSIVE_CHECKS)    assert(DT.verify(DominatorTree::VerificationLevel::Full)); @@ -325,8 +401,6 @@ static bool FixIrreducibleImpl(Function &F, CycleInfo &CI, DominatorTree &DT,    LLVM_DEBUG(dbgs() << "===== Fix irreducible control-flow in function: "                      << F.getName() << "\n"); -  assert(hasOnlySimpleTerminator(F) && "Unsupported block terminator."); -    bool Changed = false;    for (Cycle *TopCycle : CI.toplevel_cycles()) {      for (Cycle *C : depth_first(TopCycle)) { diff --git a/llvm/lib/Transforms/Utils/LoopUnroll.cpp b/llvm/lib/Transforms/Utils/LoopUnroll.cpp index 4fe736a..94dfd3a 100644 --- a/llvm/lib/Transforms/Utils/LoopUnroll.cpp +++ b/llvm/lib/Transforms/Utils/LoopUnroll.cpp @@ -499,9 +499,9 @@ llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,    const unsigned MaxTripCount = SE->getSmallConstantMaxTripCount(L);    const bool MaxOrZero = SE->isBackedgeTakenCountMaxOrZero(L); -  unsigned EstimatedLoopInvocationWeight = 0;    std::optional<unsigned> OriginalTripCount = -      llvm::getLoopEstimatedTripCount(L, &EstimatedLoopInvocationWeight); +      llvm::getLoopEstimatedTripCount(L); +  BranchProbability OriginalLoopProb = llvm::getLoopProbability(L);    // Effectively "DCE" unrolled iterations that are beyond the max tripcount    // and will never be executed. 
@@ -592,11 +592,11 @@ llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,                                                : isEpilogProfitable(L);    if (ULO.Runtime && -      !UnrollRuntimeLoopRemainder(L, ULO.Count, ULO.AllowExpensiveTripCount, -                                  EpilogProfitability, ULO.UnrollRemainder, -                                  ULO.ForgetAllSCEV, LI, SE, DT, AC, TTI, -                                  PreserveLCSSA, ULO.SCEVExpansionBudget, -                                  ULO.RuntimeUnrollMultiExit, RemainderLoop)) { +      !UnrollRuntimeLoopRemainder( +          L, ULO.Count, ULO.AllowExpensiveTripCount, EpilogProfitability, +          ULO.UnrollRemainder, ULO.ForgetAllSCEV, LI, SE, DT, AC, TTI, +          PreserveLCSSA, ULO.SCEVExpansionBudget, ULO.RuntimeUnrollMultiExit, +          RemainderLoop, OriginalTripCount, OriginalLoopProb)) {      if (ULO.Force)        ULO.Runtime = false;      else { @@ -1130,11 +1130,46 @@ llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,      LI->erase(L);      // We shouldn't try to use `L` anymore.      L = nullptr; -  } else if (OriginalTripCount) { -    // Update the trip count. Note that the remainder has already logic -    // computing it in `UnrollRuntimeLoopRemainder`. -    setLoopEstimatedTripCount(L, *OriginalTripCount / ULO.Count, -                              EstimatedLoopInvocationWeight); +  } else { +    // Update metadata for the loop's branch weights and estimated trip count: +    // - If ULO.Runtime, UnrollRuntimeLoopRemainder sets the guard branch +    //   weights, latch branch weights, and estimated trip count of the +    //   remainder loop it creates.  It also sets the branch weights for the +    //   unrolled loop guard it creates.  The branch weights for the unrolled +    //   loop latch are adjusted below.  FIXME: Handle prologue loops. +    // - Otherwise, if unrolled loop iteration latches become unconditional, +    //   branch weights are adjusted above.  FIXME: Actually handle such +    //   unconditional latches. +    // - Otherwise, the original loop's branch weights are correct for the +    //   unrolled loop, so do not adjust them. +    // - In all cases, the unrolled loop's estimated trip count is set below. +    // +    // As an example of the last case, consider what happens if the unroll count +    // is 4 for a loop with an estimated trip count of 10 when we do not create +    // a remainder loop and all iterations' latches remain conditional.  Each +    // unrolled iteration's latch still has the same probability of exiting the +    // loop as it did when in the original loop, and thus it should still have +    // the same branch weights.  Each unrolled iteration's non-zero probability +    // of exiting already appropriately reduces the probability of reaching the +    // remaining iterations just as it did in the original loop.  Trying to also +    // adjust the branch weights of the final unrolled iteration's latch (i.e., +    // the backedge for the unrolled loop as a whole) to reflect its new trip +    // count of 3 will erroneously further reduce its block frequencies. +    // However, in case an analysis later needs to estimate the trip count of +    // the unrolled loop as a whole without considering the branch weights for +    // each unrolled iteration's latch within it, we store the new trip count as +    // separate metadata. 
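A worked instance of the arithmetic the code below performs, using the numbers from the comment's example (illustrative values only; the variable names mirror the surrounding code):

#include <cstdio>

int main() {
  unsigned OriginalTripCount = 10, Count = 4;
  // Runtime epilog remainder: the unrolled loop keeps an estimate of
  // 10 / 4 = 2 iterations and the remainder loop is given 10 % 4 = 2
  // (see setLoopEstimatedTripCount in CloneLoopBlocks further down).
  unsigned WithRemainder = OriginalTripCount / Count;
  // No remainder loop: the leftover iterations still execute in the unrolled
  // loop, so the estimate is rounded up to 3.
  unsigned WithoutRemainder =
      OriginalTripCount / Count + (OriginalTripCount % Count ? 1 : 0);
  std::printf("%u %u %u\n", WithRemainder, OriginalTripCount % Count,
              WithoutRemainder);
  return 0;
}
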
+    if (!OriginalLoopProb.isUnknown() && ULO.Runtime && EpilogProfitability) { +      // Where p is always the probability of executing at least 1 more +      // iteration, the probability for at least n more iterations is p^n. +      setLoopProbability(L, OriginalLoopProb.pow(ULO.Count)); +    } +    if (OriginalTripCount) { +      unsigned NewTripCount = *OriginalTripCount / ULO.Count; +      if (!ULO.Runtime && *OriginalTripCount % ULO.Count) +        ++NewTripCount; +      setLoopEstimatedTripCount(L, NewTripCount); +    }    }    // LoopInfo should not be valid, confirm that. diff --git a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp index 6312831..1e8f6cc 100644 --- a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp +++ b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp @@ -40,6 +40,7 @@  #include "llvm/Transforms/Utils/LoopUtils.h"  #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"  #include "llvm/Transforms/Utils/UnrollLoop.h" +#include <cmath>  using namespace llvm; @@ -195,6 +196,21 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count,    }  } +/// Assume, due to our position in the remainder loop or its guard, anywhere +/// from 0 to \p N more iterations can possibly execute.  Among such cases in +/// the original loop (with loop probability \p OriginalLoopProb), what is the +/// probability of executing at least one more iteration? +static BranchProbability +probOfNextInRemainder(BranchProbability OriginalLoopProb, unsigned N) { +  // Each of these variables holds the original loop's probability that the +  // number of iterations it will execute is some m in the specified range. +  BranchProbability ProbOne = OriginalLoopProb;                // 1 <= m +  BranchProbability ProbTooMany = ProbOne.pow(N + 1);          // N + 1 <= m +  BranchProbability ProbNotTooMany = ProbTooMany.getCompl();   // 0 <= m <= N +  BranchProbability ProbOneNotTooMany = ProbOne - ProbTooMany; // 1 <= m <= N +  return ProbOneNotTooMany / ProbNotTooMany; +} +  /// Connect the unrolling epilog code to the original loop.  /// The unrolling epilog code contains code to execute the  /// 'extra' iterations if the run-time trip count modulo the @@ -221,7 +237,8 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,                            BasicBlock *EpilogPreHeader, BasicBlock *NewPreHeader,                            ValueToValueMapTy &VMap, DominatorTree *DT,                            LoopInfo *LI, bool PreserveLCSSA, ScalarEvolution &SE, -                          unsigned Count, AssumptionCache &AC) { +                          unsigned Count, AssumptionCache &AC, +                          BranchProbability OriginalLoopProb) {    BasicBlock *Latch = L->getLoopLatch();    assert(Latch && "Loop must have a latch");    BasicBlock *EpilogLatch = cast<BasicBlock>(VMap[Latch]); @@ -332,12 +349,19 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,                           PreserveLCSSA);    // Add the branch to the exit block (around the epilog loop)    MDNode *BranchWeights = nullptr; -  if (hasBranchWeightMD(*Latch->getTerminator())) { +  if (OriginalLoopProb.isUnknown() && +      hasBranchWeightMD(*Latch->getTerminator())) {      // Assume equal distribution in interval [0, Count).      
MDBuilder MDB(B.getContext());      BranchWeights = MDB.createBranchWeights(1, Count - 1);    } -  B.CreateCondBr(BrLoopExit, EpilogPreHeader, Exit, BranchWeights); +  BranchInst *RemainderLoopGuard = +      B.CreateCondBr(BrLoopExit, EpilogPreHeader, Exit, BranchWeights); +  if (!OriginalLoopProb.isUnknown()) { +    setBranchProbability(RemainderLoopGuard, +                         probOfNextInRemainder(OriginalLoopProb, Count - 1), +                         /*ForFirstTarget=*/true); +  }    InsertPt->eraseFromParent();    if (DT) {      auto *NewDom = DT->findNearestCommonDominator(Exit, NewExit); @@ -357,14 +381,15 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,  /// The cloned blocks should be inserted between InsertTop and InsertBot.  /// InsertTop should be new preheader, InsertBot new loop exit.  /// Returns the new cloned loop that is created. -static Loop * -CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder, -                const bool UnrollRemainder, -                BasicBlock *InsertTop, -                BasicBlock *InsertBot, BasicBlock *Preheader, +static Loop *CloneLoopBlocks(Loop *L, Value *NewIter, +                             const bool UseEpilogRemainder, +                             const bool UnrollRemainder, BasicBlock *InsertTop, +                             BasicBlock *InsertBot, BasicBlock *Preheader,                               std::vector<BasicBlock *> &NewBlocks,                               LoopBlocksDFS &LoopBlocks, ValueToValueMapTy &VMap, -                             DominatorTree *DT, LoopInfo *LI, unsigned Count) { +                             DominatorTree *DT, LoopInfo *LI, unsigned Count, +                             std::optional<unsigned> OriginalTripCount, +                             BranchProbability OriginalLoopProb) {    StringRef suffix = UseEpilogRemainder ? "epil" : "prol";    BasicBlock *Header = L->getHeader();    BasicBlock *Latch = L->getLoopLatch(); @@ -419,7 +444,8 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder,            Builder.CreateAdd(NewIdx, One, NewIdx->getName() + ".next");        Value *IdxCmp = Builder.CreateICmpNE(IdxNext, NewIter, NewIdx->getName() + ".cmp");        MDNode *BranchWeights = nullptr; -      if (hasBranchWeightMD(*LatchBR)) { +      if ((OriginalLoopProb.isUnknown() || !UseEpilogRemainder) && +          hasBranchWeightMD(*LatchBR)) {          uint32_t ExitWeight;          uint32_t BackEdgeWeight;          if (Count >= 3) { @@ -437,7 +463,29 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder,          MDBuilder MDB(Builder.getContext());          BranchWeights = MDB.createBranchWeights(BackEdgeWeight, ExitWeight);        } -      Builder.CreateCondBr(IdxCmp, FirstLoopBB, InsertBot, BranchWeights); +      BranchInst *RemainderLoopLatch = +          Builder.CreateCondBr(IdxCmp, FirstLoopBB, InsertBot, BranchWeights); +      if (!OriginalLoopProb.isUnknown() && UseEpilogRemainder) { +        // Compute the total frequency of the original loop body from the +        // remainder iterations.  Once we've reached them, the first of them +        // always executes, so its frequency and probability are 1. 
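A numeric sketch of the computation that follows (illustrative only: plain doubles in place of BranchProbability, assuming an original latch probability of 0.9 and an unroll count of 4):

#include <cmath>
#include <cstdio>

// Mirrors probOfNextInRemainder above: among cases where 0..N more iterations
// remain, the probability that at least one more executes.
static double probOfNext(double P, unsigned N) {
  double ProbTooMany = std::pow(P, N + 1);
  return (P - ProbTooMany) / (1 - ProbTooMany);
}

int main() {
  const double P = 0.9;     // assumed original latch probability
  const unsigned Count = 4; // assumed unroll count
  // Expected number of remainder iterations once the remainder loop is
  // entered; the first of them always executes.
  double FreqRemIters = 1;
  double ProbReaching = 1;
  for (unsigned N = Count - 2; N >= 1; --N) {
    ProbReaching *= probOfNext(P, N);
    FreqRemIters += ProbReaching;
  }
  // Latch probability that reproduces that frequency: 1/(1-Prob) = Freq.
  double LatchProb = 1 - 1 / FreqRemIters;
  // Guard probability used by ConnectEpilog is probOfNext(P, Count - 1).
  std::printf("guard: %.3f  freq: %.3f  latch: %.3f\n",
              probOfNext(P, Count - 1), FreqRemIters, LatchProb);
  // Prints roughly: guard: 0.709  freq: 1.930  latch: 0.482
  return 0;
}
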
+        double FreqRemIters = 1; +        if (Count > 2) { +          BranchProbability ProbReaching = BranchProbability::getOne(); +          for (unsigned N = Count - 2; N >= 1; --N) { +            ProbReaching *= probOfNextInRemainder(OriginalLoopProb, N); +            FreqRemIters += double(ProbReaching.getNumerator()) / +                            ProbReaching.getDenominator(); +          } +        } +        // Solve for the loop probability that would produce that frequency. +        // Sum(i=0..inf)(Prob^i) = 1/(1-Prob) = FreqRemIters. +        double ProbDouble = 1 - 1 / FreqRemIters; +        BranchProbability Prob = BranchProbability::getBranchProbability( +            std::round(ProbDouble * BranchProbability::getDenominator()), +            BranchProbability::getDenominator()); +        setBranchProbability(RemainderLoopLatch, Prob, /*ForFirstTarget=*/true); +      }        NewIdx->addIncoming(Zero, InsertTop);        NewIdx->addIncoming(IdxNext, NewBB);        LatchBR->eraseFromParent(); @@ -460,25 +508,13 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder,    Loop *NewLoop = NewLoops[L];    assert(NewLoop && "L should have been cloned"); -  MDNode *LoopID = NewLoop->getLoopID(); -  // Only add loop metadata if the loop is not going to be completely -  // unrolled. -  if (UnrollRemainder) -    return NewLoop; - -  std::optional<MDNode *> NewLoopID = makeFollowupLoopID( -      LoopID, {LLVMLoopUnrollFollowupAll, LLVMLoopUnrollFollowupRemainder}); -  if (NewLoopID) { -    NewLoop->setLoopID(*NewLoopID); - -    // Do not setLoopAlreadyUnrolled if loop attributes have been defined -    // explicitly. -    return NewLoop; -  } +  if (OriginalTripCount && UseEpilogRemainder) +    setLoopEstimatedTripCount(NewLoop, *OriginalTripCount % Count);    // Add unroll disable metadata to disable future unrolling for this loop. -  NewLoop->setLoopAlreadyUnrolled(); +  if (!UnrollRemainder) +    NewLoop->setLoopAlreadyUnrolled();    return NewLoop;  } @@ -603,7 +639,8 @@ bool llvm::UnrollRuntimeLoopRemainder(      LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT, AssumptionCache *AC,      const TargetTransformInfo *TTI, bool PreserveLCSSA,      unsigned SCEVExpansionBudget, bool RuntimeUnrollMultiExit, -    Loop **ResultLoop) { +    Loop **ResultLoop, std::optional<unsigned> OriginalTripCount, +    BranchProbability OriginalLoopProb) {    LLVM_DEBUG(dbgs() << "Trying runtime unrolling on Loop: \n");    LLVM_DEBUG(L->dump());    LLVM_DEBUG(UseEpilogRemainder ? dbgs() << "Using epilog remainder.\n" @@ -823,12 +860,23 @@ bool llvm::UnrollRuntimeLoopRemainder(    BasicBlock *UnrollingLoop = UseEpilogRemainder ? NewPreHeader : PrologExit;    // Branch to either remainder (extra iterations) loop or unrolling loop.    MDNode *BranchWeights = nullptr; -  if (hasBranchWeightMD(*Latch->getTerminator())) { +  if ((OriginalLoopProb.isUnknown() || !UseEpilogRemainder) && +      hasBranchWeightMD(*Latch->getTerminator())) {      // Assume loop is nearly always entered.      MDBuilder MDB(B.getContext());      BranchWeights = MDB.createBranchWeights(EpilogHeaderWeights);    } -  B.CreateCondBr(BranchVal, RemainderLoop, UnrollingLoop, BranchWeights); +  BranchInst *UnrollingLoopGuard = +      B.CreateCondBr(BranchVal, RemainderLoop, UnrollingLoop, BranchWeights); +  if (!OriginalLoopProb.isUnknown() && UseEpilogRemainder) { +    // The original loop's first iteration always happens.  
Compute the +    // probability of the original loop executing Count-1 iterations after that +    // to complete the first iteration of the unrolled loop. +    BranchProbability ProbOne = OriginalLoopProb; +    BranchProbability ProbRest = ProbOne.pow(Count - 1); +    setBranchProbability(UnrollingLoopGuard, ProbRest, +                         /*ForFirstTarget=*/false); +  }    PreHeaderBR->eraseFromParent();    if (DT) {      if (UseEpilogRemainder) @@ -855,9 +903,10 @@ bool llvm::UnrollRuntimeLoopRemainder(    // iterations. This function adds the appropriate CFG connections.    BasicBlock *InsertBot = UseEpilogRemainder ? LatchExit : PrologExit;    BasicBlock *InsertTop = UseEpilogRemainder ? EpilogPreHeader : PrologPreHeader; -  Loop *remainderLoop = CloneLoopBlocks( -      L, ModVal, UseEpilogRemainder, UnrollRemainder, InsertTop, InsertBot, -      NewPreHeader, NewBlocks, LoopBlocks, VMap, DT, LI, Count); +  Loop *remainderLoop = +      CloneLoopBlocks(L, ModVal, UseEpilogRemainder, UnrollRemainder, InsertTop, +                      InsertBot, NewPreHeader, NewBlocks, LoopBlocks, VMap, DT, +                      LI, Count, OriginalTripCount, OriginalLoopProb);    // Insert the cloned blocks into the function.    F->splice(InsertBot->getIterator(), F, NewBlocks[0]->getIterator(), F->end()); @@ -956,7 +1005,8 @@ bool llvm::UnrollRuntimeLoopRemainder(      // Connect the epilog code to the original loop and update the      // PHI functions.      ConnectEpilog(L, ModVal, NewExit, LatchExit, PreHeader, EpilogPreHeader, -                  NewPreHeader, VMap, DT, LI, PreserveLCSSA, *SE, Count, *AC); +                  NewPreHeader, VMap, DT, LI, PreserveLCSSA, *SE, Count, *AC, +                  OriginalLoopProb);      // Update counter in loop for unrolling.      // Use an incrementing IV.  Pre-incr/post-incr is backedge/trip count. diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp index b6ba822..8be471b 100644 --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -962,13 +962,51 @@ bool llvm::setLoopEstimatedTripCount(    if (LatchBranch->getSuccessor(0) != L->getHeader())      std::swap(BackedgeTakenWeight, LatchExitWeight); -  MDBuilder MDB(LatchBranch->getContext()); -    // Set/Update profile metadata. 
-  LatchBranch->setMetadata( -      LLVMContext::MD_prof, -      MDB.createBranchWeights(BackedgeTakenWeight, LatchExitWeight)); +  setBranchWeights(*LatchBranch, {BackedgeTakenWeight, LatchExitWeight}, +                   /*IsExpected=*/false); + +  return true; +} + +BranchProbability llvm::getLoopProbability(Loop *L) { +  BranchInst *LatchBranch = getExpectedExitLoopLatchBranch(L); +  if (!LatchBranch) +    return BranchProbability::getUnknown(); +  bool FirstTargetIsLoop = LatchBranch->getSuccessor(0) == L->getHeader(); +  return getBranchProbability(LatchBranch, FirstTargetIsLoop); +} +bool llvm::setLoopProbability(Loop *L, BranchProbability P) { +  BranchInst *LatchBranch = getExpectedExitLoopLatchBranch(L); +  if (!LatchBranch) +    return false; +  bool FirstTargetIsLoop = LatchBranch->getSuccessor(0) == L->getHeader(); +  return setBranchProbability(LatchBranch, P, FirstTargetIsLoop); +} + +BranchProbability llvm::getBranchProbability(BranchInst *B, +                                             bool ForFirstTarget) { +  if (B->getNumSuccessors() != 2) +    return BranchProbability::getUnknown(); +  uint64_t Weight0, Weight1; +  if (!extractBranchWeights(*B, Weight0, Weight1)) +    return BranchProbability::getUnknown(); +  if (!ForFirstTarget) +    std::swap(Weight0, Weight1); +  return BranchProbability::getBranchProbability(Weight0, Weight0 + Weight1); +} + +bool llvm::setBranchProbability(BranchInst *B, BranchProbability P, +                                bool ForFirstTarget) { +  if (B->getNumSuccessors() != 2) +    return false; +  BranchProbability Prob0 = P; +  BranchProbability Prob1 = P.getCompl(); +  if (!ForFirstTarget) +    std::swap(Prob0, Prob1); +  setBranchWeights(*B, {Prob0.getNumerator(), Prob1.getNumerator()}, +                   /*IsExpected=*/false);    return true;  } diff --git a/llvm/lib/Transforms/Utils/PredicateInfo.cpp b/llvm/lib/Transforms/Utils/PredicateInfo.cpp index a9ab3b3..27fed73 100644 --- a/llvm/lib/Transforms/Utils/PredicateInfo.cpp +++ b/llvm/lib/Transforms/Utils/PredicateInfo.cpp @@ -809,7 +809,6 @@ public:    void emitInstructionAnnot(const Instruction *I,                              formatted_raw_ostream &OS) override {      if (const auto *PI = PredInfo->getPredicateInfoFor(I)) { -      OS << "; Has predicate info\n";        if (const auto *PB = dyn_cast<PredicateBranch>(PI)) {          OS << "; branch predicate info { TrueEdge: " << PB->TrueEdge             << " Comparison:" << *PB->Condition << " Edge: ["; diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index 4fac5d3..7f6d779 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -1866,10 +1866,19 @@ bool SimplifyCFGOpt::hoistCommonCodeFromSuccessors(Instruction *TI,    // If either of the blocks has it's address taken, then we can't do this fold,    // because the code we'd hoist would no longer run when we jump into the block    // by it's address. -  for (auto *Succ : successors(BB)) -    if (Succ->hasAddressTaken() || !Succ->getSinglePredecessor()) +  for (auto *Succ : successors(BB)) { +    if (Succ->hasAddressTaken())        return false; - +    if (Succ->getSinglePredecessor()) +      continue; +    // If Succ has >1 predecessors, continue to check if Succ contains only +    // one `unreachable` inst. Since executing an `unreachable` inst is UB, we +    // can relax the condition based on the assumption that the program would +    // never enter Succ and trigger such UB. 
+    if (isa<UnreachableInst>(*Succ->begin())) +      continue; +    return false; +  }    // The second of pair is a SkipFlags bitmask.    using SuccIterPair = std::pair<BasicBlock::iterator, unsigned>;    SmallVector<SuccIterPair, 8> SuccIterPairs; @@ -5968,14 +5977,14 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,    }    // Prune obsolete incoming values off the successors' PHI nodes. -  for (auto BBI = Dest->begin(); isa<PHINode>(BBI); ++BBI) { +  for (auto &PHI : make_early_inc_range(Dest->phis())) {      unsigned PreviousEdges = Cases->size();      if (Dest == SI->getDefaultDest())        ++PreviousEdges;      for (unsigned I = 0, E = PreviousEdges - 1; I != E; ++I) -      cast<PHINode>(BBI)->removeIncomingValue(SI->getParent()); +      PHI.removeIncomingValue(SI->getParent());    } -  for (auto BBI = OtherDest->begin(); isa<PHINode>(BBI); ++BBI) { +  for (auto &PHI : make_early_inc_range(OtherDest->phis())) {      unsigned PreviousEdges = OtherCases->size();      if (OtherDest == SI->getDefaultDest())        ++PreviousEdges; @@ -5984,7 +5993,7 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,      if (NewBI->isUnconditional())        ++E;      for (unsigned I = 0; I != E; ++I) -      cast<PHINode>(BBI)->removeIncomingValue(SI->getParent()); +      PHI.removeIncomingValue(SI->getParent());    }    // Clean up the default block - it may have phis or other instructions before @@ -7623,7 +7632,9 @@ static bool simplifySwitchOfPowersOfTwo(SwitchInst *SI, IRBuilder<> &Builder,      auto *DefaultCaseBB = SI->getDefaultDest();      BasicBlock *SplitBB = SplitBlock(OrigBB, SI, DTU);      auto It = OrigBB->getTerminator()->getIterator(); -    BranchInst::Create(SplitBB, DefaultCaseBB, IsPow2, It); +    auto *BI = BranchInst::Create(SplitBB, DefaultCaseBB, IsPow2, It); +    // BI is handling the default case for SI, and so should share its DebugLoc. +    BI->setDebugLoc(SI->getDebugLoc());      It->eraseFromParent();      addPredecessorToBlock(DefaultCaseBB, OrigBB, SplitBB); diff --git a/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp b/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp index 9f338db..94c5c170 100644 --- a/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp +++ b/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp @@ -12,7 +12,11 @@  //  // Limitation: This assumes that all terminators in the CFG are direct branches  //             (the "br" instruction). The presence of any other control flow -//             such as indirectbr, switch or callbr will cause an assert. +//             such as indirectbr or switch will cause an assert. +//             The callbr terminator is supported by creating intermediate +//             target blocks that unconditionally branch to the original target +//             blocks. These intermediate target blocks can then be redirected +//             through the ControlFlowHub as usual.  //  //===----------------------------------------------------------------------===// @@ -150,25 +154,55 @@ static bool unifyLoopExits(DominatorTree &DT, LoopInfo &LI, Loop *L) {    SmallVector<BasicBlock *, 8> ExitingBlocks;    L->getExitingBlocks(ExitingBlocks); +  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); +  SmallVector<BasicBlock *, 8> CallBrTargetBlocksToFix;    // Redirect exiting edges through a control flow hub.    ControlFlowHub CHub; -  for (auto *BB : ExitingBlocks) { -    auto *Branch = cast<BranchInst>(BB->getTerminator()); -    BasicBlock *Succ0 = Branch->getSuccessor(0); -    Succ0 = L->contains(Succ0) ? 
nullptr : Succ0; - -    BasicBlock *Succ1 = -        Branch->isUnconditional() ? nullptr : Branch->getSuccessor(1); -    Succ1 = L->contains(Succ1) ? nullptr : Succ1; -    CHub.addBranch(BB, Succ0, Succ1); - -    LLVM_DEBUG(dbgs() << "Added exiting branch: " << BB->getName() << " -> {" -                      << (Succ0 ? Succ0->getName() : "<none>") << ", " -                      << (Succ1 ? Succ1->getName() : "<none>") << "}\n"); + +  for (unsigned I = 0; I < ExitingBlocks.size(); ++I) { +    BasicBlock *BB = ExitingBlocks[I]; +    if (BranchInst *Branch = dyn_cast<BranchInst>(BB->getTerminator())) { +      BasicBlock *Succ0 = Branch->getSuccessor(0); +      Succ0 = L->contains(Succ0) ? nullptr : Succ0; + +      BasicBlock *Succ1 = +          Branch->isUnconditional() ? nullptr : Branch->getSuccessor(1); +      Succ1 = L->contains(Succ1) ? nullptr : Succ1; +      CHub.addBranch(BB, Succ0, Succ1); + +      LLVM_DEBUG(dbgs() << "Added exiting branch: " << printBasicBlock(BB) +                        << " -> " << printBasicBlock(Succ0) +                        << (Succ0 && Succ1 ? " " : "") << printBasicBlock(Succ1) +                        << '\n'); +    } else if (CallBrInst *CallBr = dyn_cast<CallBrInst>(BB->getTerminator())) { +      for (unsigned J = 0; J < CallBr->getNumSuccessors(); ++J) { +        BasicBlock *Succ = CallBr->getSuccessor(J); +        if (L->contains(Succ)) +          continue; +        bool UpdatedLI = false; +        BasicBlock *NewSucc = +            SplitCallBrEdge(BB, Succ, J, &DTU, nullptr, &LI, &UpdatedLI); +        // Even if CallBr and Succ do not have a common parent loop, we need to +        // add the new target block to the parent loop of the current loop. +        if (!UpdatedLI) +          CallBrTargetBlocksToFix.push_back(NewSucc); +        // ExitingBlocks is later used to restore SSA, so we need to make sure +        // that the blocks used for phi nodes in the guard blocks match the +        // predecessors of the guard blocks, which, in the case of callbr, are +        // the new intermediate target blocks instead of the callbr blocks +        // themselves. +        ExitingBlocks[I] = NewSucc; +        CHub.addBranch(NewSucc, Succ); +        LLVM_DEBUG(dbgs() << "Added exiting branch: " +                          << printBasicBlock(NewSucc) << " -> " +                          << printBasicBlock(Succ) << '\n'); +      } +    } else { +      llvm_unreachable("unsupported block terminator"); +    }    }    SmallVector<BasicBlock *, 8> GuardBlocks; -  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);    BasicBlock *LoopExitBlock;    bool ChangedCFG;    std::tie(LoopExitBlock, ChangedCFG) = CHub.finalize( @@ -187,10 +221,19 @@ static bool unifyLoopExits(DominatorTree &DT, LoopInfo &LI, Loop *L) {    // The guard blocks were created outside the loop, so they need to become    // members of the parent loop. -  if (auto ParentLoop = L->getParentLoop()) { +  // Same goes for the callbr target blocks.  Although we try to add them to the +  // smallest common parent loop of the callbr block and the corresponding +  // original target block, there might not have been such a loop, in which case +  // the newly created callbr target blocks are not part of any loop. For nested +  // loops, this might result in them leading to a loop with multiple entry +  // points. 
+  if (auto *ParentLoop = L->getParentLoop()) {      for (auto *G : GuardBlocks) {        ParentLoop->addBasicBlockToLoop(G, LI);      } +    for (auto *C : CallBrTargetBlocksToFix) { +      ParentLoop->addBasicBlockToLoop(C, LI); +    }      ParentLoop->verifyLoop();    } @@ -218,8 +261,6 @@ bool UnifyLoopExitsLegacyPass::runOnFunction(Function &F) {    auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); -  assert(hasOnlySimpleTerminator(F) && "Unsupported block terminator."); -    return runImpl(LI, DT);  } diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index f7968ab..25bf49d 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -3908,7 +3908,7 @@ void LoopVectorizationPlanner::emitInvalidCostRemarks(          continue;        VPCostContext CostCtx(CM.TTI, *CM.TLI, *Plan, CM, CM.CostKind, -                            *CM.PSE.getSE()); +                            *CM.PSE.getSE(), OrigLoop);        precomputeCosts(*Plan, VF, CostCtx);        auto Iter = vp_depth_first_deep(Plan->getVectorLoopRegion()->getEntry());        for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(Iter)) { @@ -4166,7 +4166,7 @@ VectorizationFactor LoopVectorizationPlanner::selectVectorizationFactor() {        // Add on other costs that are modelled in VPlan, but not in the legacy        // cost model.        VPCostContext CostCtx(CM.TTI, *CM.TLI, *P, CM, CM.CostKind, -                            *CM.PSE.getSE()); +                            *CM.PSE.getSE(), OrigLoop);        VPRegionBlock *VectorRegion = P->getVectorLoopRegion();        assert(VectorRegion && "Expected to have a vector region!");        for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>( @@ -5750,13 +5750,18 @@ void LoopVectorizationCostModel::setCostBasedWideningDecision(ElementCount VF) {               getMemoryInstructionCost(I, ElementCount::getFixed(1))));          UpdateMemOpUserCost(cast<LoadInst>(I));        } else if (const auto *Group = getInterleavedAccessGroup(I)) { -        // Scalarize an interleave group of address loads. -        for (unsigned I = 0; I < Group->getFactor(); ++I) { -          if (Instruction *Member = Group->getMember(I)) { -            setWideningDecision( -                Member, VF, CM_Scalarize, -                (VF.getKnownMinValue() * -                 getMemoryInstructionCost(Member, ElementCount::getFixed(1)))); +        // Scalarize all members of this interleaved group when any member +        // is used as an address. The address-used load skips scalarization +        // overhead, other members include it. +        for (unsigned Idx = 0; Idx < Group->getFactor(); ++Idx) { +          if (Instruction *Member = Group->getMember(Idx)) { +            InstructionCost Cost = +                AddrDefs.contains(Member) +                    ? 
(VF.getKnownMinValue() * +                       getMemoryInstructionCost(Member, +                                                ElementCount::getFixed(1))) +                    : getMemInstScalarizationCost(Member, VF); +            setWideningDecision(Member, VF, CM_Scalarize, Cost);              UpdateMemOpUserCost(cast<LoadInst>(Member));            }          } @@ -6871,7 +6876,8 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,  InstructionCost LoopVectorizationPlanner::cost(VPlan &Plan,                                                 ElementCount VF) const { -  VPCostContext CostCtx(CM.TTI, *CM.TLI, Plan, CM, CM.CostKind, *PSE.getSE()); +  VPCostContext CostCtx(CM.TTI, *CM.TLI, Plan, CM, CM.CostKind, *PSE.getSE(), +                        OrigLoop);    InstructionCost Cost = precomputeCosts(Plan, VF, CostCtx);    // Now compute and add the VPlan-based cost. @@ -7105,12 +7111,13 @@ VectorizationFactor LoopVectorizationPlanner::computeBestVF() {    // case, don't trigger the assertion, as the extra simplifications may cause a    // different VF to be picked by the VPlan-based cost model.    VPCostContext CostCtx(CM.TTI, *CM.TLI, BestPlan, CM, CM.CostKind, -                        *CM.PSE.getSE()); +                        *CM.PSE.getSE(), OrigLoop);    precomputeCosts(BestPlan, BestFactor.Width, CostCtx);    // Verify that the VPlan-based and legacy cost models agree, except for VPlans    // with early exits and plans with additional VPlan simplifications. The    // legacy cost model doesn't properly model costs for such loops.    assert((BestFactor.Width == LegacyVF.Width || BestPlan.hasEarlyExit() || +          !Legal->getLAI()->getSymbolicStrides().empty() ||            planContainsAdditionalSimplifications(getPlanFor(BestFactor.Width),                                                  CostCtx, OrigLoop,                                                  BestFactor.Width) || @@ -8335,11 +8342,7 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(                &R) ||            (isa<VPInstruction>(&R) && !UnderlyingValue))          continue; - -      // FIXME: VPlan0, which models a copy of the original scalar loop, should -      // not use VPWidenPHIRecipe to model the phis. -      assert((isa<VPWidenPHIRecipe>(&R) || isa<VPInstruction>(&R)) && -             UnderlyingValue && "unsupported recipe"); +      assert(isa<VPInstruction>(&R) && UnderlyingValue && "unsupported recipe");        // TODO: Gradually replace uses of underlying instruction by analyses on        // VPlan. @@ -8440,7 +8443,7 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(    // and mulacc-reduction are implemented.    
if (!CM.foldTailWithEVL()) {      VPCostContext CostCtx(CM.TTI, *CM.TLI, *Plan, CM, CM.CostKind, -                          *CM.PSE.getSE()); +                          *CM.PSE.getSE(), OrigLoop);      VPlanTransforms::runPass(VPlanTransforms::convertToAbstractRecipes, *Plan,                               CostCtx, Range);    } @@ -9910,7 +9913,7 @@ bool LoopVectorizePass::processLoop(Loop *L) {      bool ForceVectorization =          Hints.getForce() == LoopVectorizeHints::FK_Enabled;      VPCostContext CostCtx(CM.TTI, *CM.TLI, LVP.getPlanFor(VF.Width), CM, -                          CM.CostKind, *CM.PSE.getSE()); +                          CM.CostKind, *CM.PSE.getSE(), L);      if (!ForceVectorization &&          !isOutsideLoopWorkProfitable(Checks, VF, L, PSE, CostCtx,                                       LVP.getPlanFor(VF.Width), SEL, diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 43166c0..1b55a3b 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -16920,7 +16920,10 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(        // otherwise TEPtr depends on TE.        if ((TEInsertBlock != InsertPt->getParent() ||             TEUseEI.EdgeIdx < UseEI.EdgeIdx || TEUseEI.UserTE != UseEI.UserTE) && -          !CheckOrdering(InsertPt)) +          (!CheckOrdering(InsertPt) || +           (UseEI.UserTE->hasCopyableElements() && +            isUsedOutsideBlock(const_cast<Instruction *>(TEInsertPt)) && +            is_contained(UseEI.UserTE->Scalars, TEInsertPt))))          continue;        // The node is reused - exit.        if (CheckAndUseSameNode(TEPtr)) diff --git a/llvm/lib/Transforms/Vectorize/VPlanHelpers.h b/llvm/lib/Transforms/Vectorize/VPlanHelpers.h index 2aaabd9..965426f 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanHelpers.h +++ b/llvm/lib/Transforms/Vectorize/VPlanHelpers.h @@ -350,13 +350,14 @@ struct VPCostContext {    SmallPtrSet<Instruction *, 8> SkipCostComputation;    TargetTransformInfo::TargetCostKind CostKind;    ScalarEvolution &SE; +  const Loop *L;    VPCostContext(const TargetTransformInfo &TTI, const TargetLibraryInfo &TLI,                  const VPlan &Plan, LoopVectorizationCostModel &CM,                  TargetTransformInfo::TargetCostKind CostKind, -                ScalarEvolution &SE) +                ScalarEvolution &SE, const Loop *L)        : TTI(TTI), TLI(TLI), Types(Plan), LLVMCtx(Plan.getContext()), CM(CM), -        CostKind(CostKind), SE(SE) {} +        CostKind(CostKind), SE(SE), L(L) {}    /// Return the cost for \p UI with \p VF using the legacy cost model as    /// fallback until computing the cost of all recipes migrates to VPlan. diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index 9a63c80..bde62dd 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -3167,26 +3167,30 @@ bool VPReplicateRecipe::shouldPack() const {    });  } -/// Returns true if \p Ptr is a pointer computation for which the legacy cost -/// model computes a SCEV expression when computing the address cost. -static bool shouldUseAddressAccessSCEV(const VPValue *Ptr) { +/// Returns a SCEV expression for \p Ptr if it is a pointer computation for +/// which the legacy cost model computes a SCEV expression when computing the +/// address cost. 
Computing SCEVs for VPValues is incomplete and returns +/// SCEVCouldNotCompute in cases the legacy cost model can compute SCEVs. In +/// those cases we fall back to the legacy cost model. Otherwise return nullptr. +static const SCEV *getAddressAccessSCEV(const VPValue *Ptr, ScalarEvolution &SE, +                                        const Loop *L) {    auto *PtrR = Ptr->getDefiningRecipe();    if (!PtrR || !((isa<VPReplicateRecipe>(PtrR) &&                    cast<VPReplicateRecipe>(PtrR)->getOpcode() ==                        Instruction::GetElementPtr) ||                   isa<VPWidenGEPRecipe>(PtrR) ||                   match(Ptr, m_GetElementPtr(m_VPValue(), m_VPValue())))) -    return false; +    return nullptr;    // We are looking for a GEP where all indices are either loop invariant or    // inductions.    for (VPValue *Opd : drop_begin(PtrR->operands())) {      if (!Opd->isDefinedOutsideLoopRegions() &&          !isa<VPScalarIVStepsRecipe, VPWidenIntOrFpInductionRecipe>(Opd)) -      return false; +      return nullptr;    } -  return true; +  return vputils::getSCEVExprForVPValue(Ptr, SE, L);  }  /// Returns true if \p V is used as part of the address of another load or @@ -3354,9 +3358,8 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF,      bool IsLoad = UI->getOpcode() == Instruction::Load;      const VPValue *PtrOp = getOperand(!IsLoad); -    // TODO: Handle cases where we need to pass a SCEV to -    // getAddressComputationCost. -    if (shouldUseAddressAccessSCEV(PtrOp)) +    const SCEV *PtrSCEV = getAddressAccessSCEV(PtrOp, Ctx.SE, Ctx.L); +    if (isa_and_nonnull<SCEVCouldNotCompute>(PtrSCEV))        break;      Type *ValTy = Ctx.Types.inferScalarType(IsLoad ? this : getOperand(0)); @@ -3374,7 +3377,7 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF,      InstructionCost ScalarCost =          ScalarMemOpCost + Ctx.TTI.getAddressComputationCost(                                PtrTy, UsedByLoadStoreAddress ? 
nullptr : &Ctx.SE, -                              nullptr, Ctx.CostKind); +                              PtrSCEV, Ctx.CostKind);      if (isSingleScalar())        return ScalarCost; diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index d9ac26bb..986c801 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -1419,6 +1419,8 @@ static void narrowToSingleScalarRecipes(VPlan &Plan) {                                            true /*IsSingleScalar*/);        Clone->insertBefore(RepOrWidenR);        RepOrWidenR->replaceAllUsesWith(Clone); +      if (isDeadRecipe(*RepOrWidenR)) +        RepOrWidenR->eraseFromParent();      }    }  } @@ -4062,7 +4064,7 @@ void VPlanTransforms::materializeVFAndVFxUF(VPlan &Plan, VPBasicBlock *VectorPH,  DenseMap<const SCEV *, Value *>  VPlanTransforms::expandSCEVs(VPlan &Plan, ScalarEvolution &SE) {    const DataLayout &DL = SE.getDataLayout(); -  SCEVExpander Expander(SE, DL, "induction", /*PreserveLCSSA=*/true); +  SCEVExpander Expander(SE, DL, "induction", /*PreserveLCSSA=*/false);    auto *Entry = cast<VPIRBasicBlock>(Plan.getEntry());    BasicBlock *EntryBB = Entry->getIRBasicBlock(); diff --git a/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp b/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp index 4db92e7..8c23e78 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp @@ -75,7 +75,8 @@ bool vputils::isHeaderMask(const VPValue *V, const VPlan &Plan) {           B == Plan.getBackedgeTakenCount();  } -const SCEV *vputils::getSCEVExprForVPValue(VPValue *V, ScalarEvolution &SE) { +const SCEV *vputils::getSCEVExprForVPValue(const VPValue *V, +                                           ScalarEvolution &SE, const Loop *L) {    if (V->isLiveIn()) {      if (Value *LiveIn = V->getLiveInIRValue())        return SE.getSCEV(LiveIn); @@ -86,6 +87,53 @@ const SCEV *vputils::getSCEVExprForVPValue(VPValue *V, ScalarEvolution &SE) {    return TypeSwitch<const VPRecipeBase *, const SCEV *>(V->getDefiningRecipe())        .Case<VPExpandSCEVRecipe>(            [](const VPExpandSCEVRecipe *R) { return R->getSCEV(); }) +      .Case<VPCanonicalIVPHIRecipe>([&SE, L](const VPCanonicalIVPHIRecipe *R) { +        if (!L) +          return SE.getCouldNotCompute(); +        const SCEV *Start = getSCEVExprForVPValue(R->getOperand(0), SE, L); +        return SE.getAddRecExpr(Start, SE.getOne(Start->getType()), L, +                                SCEV::FlagAnyWrap); +      }) +      .Case<VPDerivedIVRecipe>([&SE, L](const VPDerivedIVRecipe *R) { +        const SCEV *Start = getSCEVExprForVPValue(R->getOperand(0), SE, L); +        const SCEV *IV = getSCEVExprForVPValue(R->getOperand(1), SE, L); +        const SCEV *Scale = getSCEVExprForVPValue(R->getOperand(2), SE, L); +        if (any_of(ArrayRef({Start, IV, Scale}), IsaPred<SCEVCouldNotCompute>)) +          return SE.getCouldNotCompute(); + +        return SE.getAddExpr(SE.getTruncateOrSignExtend(Start, IV->getType()), +                             SE.getMulExpr(IV, SE.getTruncateOrSignExtend( +                                                   Scale, IV->getType()))); +      }) +      .Case<VPScalarIVStepsRecipe>([&SE, L](const VPScalarIVStepsRecipe *R) { +        const SCEV *IV = getSCEVExprForVPValue(R->getOperand(0), SE, L); +        const SCEV *Step = getSCEVExprForVPValue(R->getOperand(1), SE, L); +        if (isa<SCEVCouldNotCompute>(IV) || 
isa<SCEVCouldNotCompute>(Step) || +            !Step->isOne()) +          return SE.getCouldNotCompute(); +        return SE.getMulExpr(SE.getTruncateOrSignExtend(IV, Step->getType()), +                             Step); +      }) +      .Case<VPReplicateRecipe>([&SE, L](const VPReplicateRecipe *R) { +        if (R->getOpcode() != Instruction::GetElementPtr) +          return SE.getCouldNotCompute(); + +        const SCEV *Base = getSCEVExprForVPValue(R->getOperand(0), SE, L); +        if (isa<SCEVCouldNotCompute>(Base)) +          return SE.getCouldNotCompute(); + +        SmallVector<const SCEV *> IndexExprs; +        for (VPValue *Index : drop_begin(R->operands())) { +          const SCEV *IndexExpr = getSCEVExprForVPValue(Index, SE, L); +          if (isa<SCEVCouldNotCompute>(IndexExpr)) +            return SE.getCouldNotCompute(); +          IndexExprs.push_back(IndexExpr); +        } + +        Type *SrcElementTy = cast<GetElementPtrInst>(R->getUnderlyingInstr()) +                                 ->getSourceElementType(); +        return SE.getGEPExpr(Base, IndexExprs, SrcElementTy); +      })        .Default([&SE](const VPRecipeBase *) { return SE.getCouldNotCompute(); });  } diff --git a/llvm/lib/Transforms/Vectorize/VPlanUtils.h b/llvm/lib/Transforms/Vectorize/VPlanUtils.h index 37cd413..c21a0e7 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanUtils.h +++ b/llvm/lib/Transforms/Vectorize/VPlanUtils.h @@ -37,7 +37,8 @@ VPValue *getOrCreateVPValueForSCEVExpr(VPlan &Plan, const SCEV *Expr);  /// Return the SCEV expression for \p V. Returns SCEVCouldNotCompute if no  /// SCEV expression could be constructed. -const SCEV *getSCEVExprForVPValue(VPValue *V, ScalarEvolution &SE); +const SCEV *getSCEVExprForVPValue(const VPValue *V, ScalarEvolution &SE, +                                  const Loop *L = nullptr);  /// Returns true if \p VPV is a single scalar, either because it produces the  /// same value for all lanes or only has its first lane used. | 
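
The VPlanUtils.cpp hunk above teaches getSCEVExprForVPValue to fold recipe operands recursively: the canonical IV becomes an add-recurrence {Start,+,1}, a derived IV becomes Start + IV * Scale, scalar IV steps become IV * Step (unit step only), and a replicated GEP becomes its base plus the scaled indices, with SCEVCouldNotCompute propagated as soon as any operand resists. Below is a minimal, self-contained C++ model of that pattern, not LLVM code: Affine, Recipe, RecipeKind, and getAffineForRecipe are invented stand-ins, and std::nullopt plays the role of SCEVCouldNotCompute.

    // Toy model (not LLVM code): fold "recipes" into closed-form affine
    // expressions, bailing out where the real code returns SCEVCouldNotCompute.
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <optional>
    #include <vector>

    // A toy closed form: the value at loop iteration i is Base + Stride * i.
    struct Affine {
      std::int64_t Base;
      std::int64_t Stride;
    };

    // Hypothetical recipe kinds standing in for the VPlan recipes in the hunk.
    enum class RecipeKind { LiveIn, CanonicalIV, DerivedIV, Gep };

    struct Recipe {
      RecipeKind Kind;
      std::int64_t Imm;                // LiveIn: constant value.
      std::vector<const Recipe *> Ops; // Operands, like VPValue operands.
      std::int64_t ElemSize;           // Gep: element size in bytes.
    };

    // Mirrors the shape of getSCEVExprForVPValue: each case folds its operands
    // recursively and propagates "could not compute" (std::nullopt) upward.
    std::optional<Affine> getAffineForRecipe(const Recipe &R) {
      switch (R.Kind) {
      case RecipeKind::LiveIn:
        return Affine{R.Imm, 0};
      case RecipeKind::CanonicalIV: {
        // Like {Start,+,1}: the start operand plus a unit step per iteration.
        auto Start = getAffineForRecipe(*R.Ops[0]);
        if (!Start || Start->Stride != 0)
          return std::nullopt;
        return Affine{Start->Base, 1};
      }
      case RecipeKind::DerivedIV: {
        // Start + IV * Scale, with Start and Scale loop-invariant.
        auto Start = getAffineForRecipe(*R.Ops[0]);
        auto IV = getAffineForRecipe(*R.Ops[1]);
        auto Scale = getAffineForRecipe(*R.Ops[2]);
        if (!Start || !IV || !Scale || Start->Stride != 0 || Scale->Stride != 0)
          return std::nullopt;
        return Affine{Start->Base + IV->Base * Scale->Base,
                      IV->Stride * Scale->Base};
      }
      case RecipeKind::Gep: {
        // Base + (sum of indices) * ElemSize; give up if any index resists.
        auto Addr = getAffineForRecipe(*R.Ops[0]);
        if (!Addr)
          return std::nullopt;
        for (std::size_t I = 1; I < R.Ops.size(); ++I) {
          auto Idx = getAffineForRecipe(*R.Ops[I]);
          if (!Idx)
            return std::nullopt;
          Addr->Base += Idx->Base * R.ElemSize;
          Addr->Stride += Idx->Stride * R.ElemSize;
        }
        return Addr;
      }
      }
      return std::nullopt;
    }

    int main() {
      Recipe Start{RecipeKind::LiveIn, 0, {}, 1};
      Recipe IV{RecipeKind::CanonicalIV, 0, {&Start}, 1};
      Recipe Base{RecipeKind::LiveIn, 4096, {}, 1};
      Recipe Gep{RecipeKind::Gep, 0, {&Base, &IV}, 8};
      if (auto A = getAffineForRecipe(Gep))
        std::printf("address(i) = %lld + %lld * i\n",
                    static_cast<long long>(A->Base),
                    static_cast<long long>(A->Stride));
    }

The same bail-out-early shape is what lets the real function return a usable address expression for the common strided-GEP case while staying conservative everywhere else.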

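On the consumer side, the VPlanRecipes.cpp hunk replaces the old yes/no check (shouldUseAddressAccessSCEV, which simply deferred to the legacy cost model) with the address SCEV itself: the scalarized-memory path now breaks out only when getAddressAccessSCEV yields SCEVCouldNotCompute, and otherwise forwards the expression to getAddressComputationCost. A rough stand-alone sketch of that decision follows; ToyAddrExpr, addressComputationCost, and scalarizedMemOpCost are made-up names used only to illustrate the control flow, not LLVM APIs.

    #include <cstdio>
    #include <optional>

    // Invented stand-in for an address SCEV: just the per-lane stride in bytes.
    struct ToyAddrExpr {
      long Stride;
    };

    // Stand-in for an address-computation cost hook: with an expression
    // available, simple strides are free; otherwise assume extra arithmetic.
    static int addressComputationCost(const std::optional<ToyAddrExpr> &Ptr) {
      if (Ptr && (Ptr->Stride == 1 || Ptr->Stride == 2 || Ptr->Stride == 4 ||
                  Ptr->Stride == 8))
        return 0;
      return 2;
    }

    // Shape of the scalarized-memory-op path: std::nullopt means "could not
    // compute the address, defer to the legacy cost model" (the `break` in the
    // real code); otherwise price VF scalar accesses plus address computation.
    static std::optional<int>
    scalarizedMemOpCost(bool AddrComputed, std::optional<ToyAddrExpr> Addr,
                        int ScalarMemOpCost, int VF) {
      if (!AddrComputed)
        return std::nullopt;
      return VF * (ScalarMemOpCost + addressComputationCost(Addr));
    }

    int main() {
      auto Strided = scalarizedMemOpCost(true, ToyAddrExpr{8}, 4, 4);
      auto Unknown = scalarizedMemOpCost(false, std::nullopt, 4, 4);
      std::printf("strided: %d, unknown: %s\n", Strided ? *Strided : -1,
                  Unknown ? "computed" : "legacy fallback");
    }

The design point the patch makes is visible even in this toy: once the expression is computed up front, failure is the only reason to fall back, and the cost hook can distinguish simple strided addressing from arbitrary address arithmetic instead of being skipped wholesale.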