diff options
Diffstat (limited to 'clang/lib/CodeGen/CGOpenMPRuntime.cpp')
-rw-r--r-- | clang/lib/CodeGen/CGOpenMPRuntime.cpp | 391 |
1 files changed, 391 insertions, 0 deletions
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp index a503aaf6..c90e1a4 100644 --- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp @@ -6799,6 +6799,240 @@ LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE(); // code for that information. class MappableExprsHandler { public: + /// Custom comparator for attach-pointer expressions that compares them by + /// complexity (i.e. their component-depth) first, then by the order in which + /// they were computed by collectAttachPtrExprInfo(), if they are semantically + /// different. + struct AttachPtrExprComparator { + const MappableExprsHandler *Handler = nullptr; + // Cache of previous equality comparison results. + mutable llvm::DenseMap<std::pair<const Expr *, const Expr *>, bool> + CachedEqualityComparisons; + + AttachPtrExprComparator(const MappableExprsHandler *H) : Handler(H) {} + + // Return true iff LHS is "less than" RHS. + bool operator()(const Expr *LHS, const Expr *RHS) const { + if (LHS == RHS) + return false; + + // First, compare by complexity (depth) + const auto ItLHS = Handler->AttachPtrComponentDepthMap.find(LHS); + const auto ItRHS = Handler->AttachPtrComponentDepthMap.find(RHS); + + std::optional<size_t> DepthLHS = + (ItLHS != Handler->AttachPtrComponentDepthMap.end()) ? ItLHS->second + : std::nullopt; + std::optional<size_t> DepthRHS = + (ItRHS != Handler->AttachPtrComponentDepthMap.end()) ? 
ItRHS->second + : std::nullopt; + + // std::nullopt (no attach pointer) has lowest complexity + if (!DepthLHS.has_value() && !DepthRHS.has_value()) { + // Both have same complexity, now check semantic equality + if (areEqual(LHS, RHS)) + return false; + // Different semantically, compare by computation order + return wasComputedBefore(LHS, RHS); + } + if (!DepthLHS.has_value()) + return true; // LHS has lower complexity + if (!DepthRHS.has_value()) + return false; // RHS has lower complexity + + // Both have values, compare by depth (lower depth = lower complexity) + if (DepthLHS.value() != DepthRHS.value()) + return DepthLHS.value() < DepthRHS.value(); + + // Same complexity, now check semantic equality + if (areEqual(LHS, RHS)) + return false; + // Different semantically, compare by computation order + return wasComputedBefore(LHS, RHS); + } + + public: + /// Return true if \p LHS and \p RHS are semantically equal. Uses pre-cached + /// results, if available, otherwise does a recursive semantic comparison. + bool areEqual(const Expr *LHS, const Expr *RHS) const { + // Check cache first for faster lookup + const auto CachedResultIt = CachedEqualityComparisons.find({LHS, RHS}); + if (CachedResultIt != CachedEqualityComparisons.end()) + return CachedResultIt->second; + + bool ComparisonResult = areSemanticallyEqual(LHS, RHS); + + // Cache the result for future lookups (both orders since semantic + // equality is commutative) + CachedEqualityComparisons[{LHS, RHS}] = ComparisonResult; + CachedEqualityComparisons[{RHS, LHS}] = ComparisonResult; + return ComparisonResult; + } + + /// Compare the two attach-ptr expressions by their computation order. + /// Returns true iff LHS was computed before RHS by + /// collectAttachPtrExprInfo(). 
+ bool wasComputedBefore(const Expr *LHS, const Expr *RHS) const { + const size_t &OrderLHS = Handler->AttachPtrComputationOrderMap.at(LHS); + const size_t &OrderRHS = Handler->AttachPtrComputationOrderMap.at(RHS); + + return OrderLHS < OrderRHS; + } + + private: + /// Helper function to compare attach-pointer expressions semantically. + /// This function handles various expression types that can be part of an + /// attach-pointer. + /// TODO: Not urgent, but we should ideally return true when comparing + /// `p[10]`, `*(p + 10)`, `*(p + 5 + 5)`, `p[10:1]` etc. + bool areSemanticallyEqual(const Expr *LHS, const Expr *RHS) const { + if (LHS == RHS) + return true; + + // If only one is null, they aren't equal + if (!LHS || !RHS) + return false; + + ASTContext &Ctx = Handler->CGF.getContext(); + // Strip away parentheses and no-op casts to get to the core expression + LHS = LHS->IgnoreParenNoopCasts(Ctx); + RHS = RHS->IgnoreParenNoopCasts(Ctx); + + // Direct pointer comparison of the underlying expressions + if (LHS == RHS) + return true; + + // Check if the expression classes match + if (LHS->getStmtClass() != RHS->getStmtClass()) + return false; + + // Handle DeclRefExpr (variable references) + if (const auto *LD = dyn_cast<DeclRefExpr>(LHS)) { + const auto *RD = dyn_cast<DeclRefExpr>(RHS); + if (!RD) + return false; + return LD->getDecl()->getCanonicalDecl() == + RD->getDecl()->getCanonicalDecl(); + } + + // Handle ArraySubscriptExpr (array indexing like a[i]) + if (const auto *LA = dyn_cast<ArraySubscriptExpr>(LHS)) { + const auto *RA = dyn_cast<ArraySubscriptExpr>(RHS); + if (!RA) + return false; + return areSemanticallyEqual(LA->getBase(), RA->getBase()) && + areSemanticallyEqual(LA->getIdx(), RA->getIdx()); + } + + // Handle MemberExpr (member access like s.m or p->m) + if (const auto *LM = dyn_cast<MemberExpr>(LHS)) { + const auto *RM = dyn_cast<MemberExpr>(RHS); + if (!RM) + return false; + if (LM->getMemberDecl()->getCanonicalDecl() != + 
RM->getMemberDecl()->getCanonicalDecl()) + return false; + return areSemanticallyEqual(LM->getBase(), RM->getBase()); + } + + // Handle UnaryOperator (unary operations like *p, &x, etc.) + if (const auto *LU = dyn_cast<UnaryOperator>(LHS)) { + const auto *RU = dyn_cast<UnaryOperator>(RHS); + if (!RU) + return false; + if (LU->getOpcode() != RU->getOpcode()) + return false; + return areSemanticallyEqual(LU->getSubExpr(), RU->getSubExpr()); + } + + // Handle BinaryOperator (binary operations like p + offset) + if (const auto *LB = dyn_cast<BinaryOperator>(LHS)) { + const auto *RB = dyn_cast<BinaryOperator>(RHS); + if (!RB) + return false; + if (LB->getOpcode() != RB->getOpcode()) + return false; + return areSemanticallyEqual(LB->getLHS(), RB->getLHS()) && + areSemanticallyEqual(LB->getRHS(), RB->getRHS()); + } + + // Handle ArraySectionExpr (array sections like a[0:1]) + // Attach pointers should not contain array-sections, but currently we + // don't emit an error. + if (const auto *LAS = dyn_cast<ArraySectionExpr>(LHS)) { + const auto *RAS = dyn_cast<ArraySectionExpr>(RHS); + if (!RAS) + return false; + return areSemanticallyEqual(LAS->getBase(), RAS->getBase()) && + areSemanticallyEqual(LAS->getLowerBound(), + RAS->getLowerBound()) && + areSemanticallyEqual(LAS->getLength(), RAS->getLength()); + } + + // Handle CastExpr (explicit casts) + if (const auto *LC = dyn_cast<CastExpr>(LHS)) { + const auto *RC = dyn_cast<CastExpr>(RHS); + if (!RC) + return false; + if (LC->getCastKind() != RC->getCastKind()) + return false; + return areSemanticallyEqual(LC->getSubExpr(), RC->getSubExpr()); + } + + // Handle CXXThisExpr (this pointer) + if (isa<CXXThisExpr>(LHS) && isa<CXXThisExpr>(RHS)) + return true; + + // Handle IntegerLiteral (integer constants) + if (const auto *LI = dyn_cast<IntegerLiteral>(LHS)) { + const auto *RI = dyn_cast<IntegerLiteral>(RHS); + if (!RI) + return false; + return LI->getValue() == RI->getValue(); + } + + // Handle CharacterLiteral (character 
constants) + if (const auto *LC = dyn_cast<CharacterLiteral>(LHS)) { + const auto *RC = dyn_cast<CharacterLiteral>(RHS); + if (!RC) + return false; + return LC->getValue() == RC->getValue(); + } + + // Handle FloatingLiteral (floating point constants) + if (const auto *LF = dyn_cast<FloatingLiteral>(LHS)) { + const auto *RF = dyn_cast<FloatingLiteral>(RHS); + if (!RF) + return false; + // Use bitwise comparison for floating point literals + return LF->getValue().bitwiseIsEqual(RF->getValue()); + } + + // Handle StringLiteral (string constants) + if (const auto *LS = dyn_cast<StringLiteral>(LHS)) { + const auto *RS = dyn_cast<StringLiteral>(RHS); + if (!RS) + return false; + return LS->getString() == RS->getString(); + } + + // Handle CXXNullPtrLiteralExpr (nullptr) + if (isa<CXXNullPtrLiteralExpr>(LHS) && isa<CXXNullPtrLiteralExpr>(RHS)) + return true; + + // Handle CXXBoolLiteralExpr (true/false) + if (const auto *LB = dyn_cast<CXXBoolLiteralExpr>(LHS)) { + const auto *RB = dyn_cast<CXXBoolLiteralExpr>(RHS); + if (!RB) + return false; + return LB->getValue() == RB->getValue(); + } + + // Fallback for other forms - use the existing comparison method + return Expr::isSameComparisonOperand(LHS, RHS); + } + }; + /// Get the offset of the OMP_MAP_MEMBER_OF field. static unsigned getFlagMemberOffset() { unsigned Offset = 0; @@ -6876,6 +7110,45 @@ public: bool HasCompleteRecord = false; }; + /// A struct to store the attach pointer and pointee information, to be used + /// when emitting an attach entry. + struct AttachInfoTy { + Address AttachPtrAddr = Address::invalid(); + Address AttachPteeAddr = Address::invalid(); + const ValueDecl *AttachPtrDecl = nullptr; + const Expr *AttachMapExpr = nullptr; + + bool isValid() const { + return AttachPtrAddr.isValid() && AttachPteeAddr.isValid(); + } + }; + + /// Check if there's any component list where the attach pointer expression + /// matches the given captured variable. 
+ bool hasAttachEntryForCapturedVar(const ValueDecl *VD) const { + for (const auto &AttachEntry : AttachPtrExprMap) { + if (AttachEntry.second) { + // Check if the attach pointer expression is a DeclRefExpr that + // references the captured variable + if (const auto *DRE = dyn_cast<DeclRefExpr>(AttachEntry.second)) + if (DRE->getDecl() == VD) + return true; + } + } + return false; + } + + /// Get the previously-cached attach pointer for a component list, if-any. + const Expr *getAttachPtrExpr( + OMPClauseMappableExprCommon::MappableExprComponentListRef Components) + const { + const auto It = AttachPtrExprMap.find(Components); + if (It != AttachPtrExprMap.end()) + return It->second; + + return nullptr; + } + private: /// Kind that defines how a device pointer has to be returned. struct MapInfo { @@ -6948,6 +7221,27 @@ private: /// Map between lambda declarations and their map type. llvm::DenseMap<const ValueDecl *, const OMPMapClause *> LambdasMap; + /// Map from component lists to their attach pointer expressions. + llvm::DenseMap<OMPClauseMappableExprCommon::MappableExprComponentListRef, + const Expr *> + AttachPtrExprMap; + + /// Map from attach pointer expressions to their component depth. + /// nullptr key has std::nullopt depth. This can be used to order attach-ptr + /// expressions with increasing/decreasing depth. + /// The component-depth of `nullptr` (i.e. no attach-ptr) is `std::nullopt`. + /// TODO: Not urgent, but we should ideally use the number of pointer + /// dereferences in an expr as an indicator of its complexity, instead of the + /// component-depth. That would be needed for us to treat `p[1]`, `*(p + 10)`, + /// `*(p + 5 + 5)` together. + llvm::DenseMap<const Expr *, std::optional<size_t>> + AttachPtrComponentDepthMap = {{nullptr, std::nullopt}}; + + /// Map from attach pointer expressions to the order they were computed in, in + /// collectAttachPtrExprInfo(). 
+ llvm::DenseMap<const Expr *, size_t> AttachPtrComputationOrderMap = { + {nullptr, 0}}; + llvm::Value *getExprTypeSize(const Expr *E) const { QualType ExprTy = E->getType().getCanonicalType(); @@ -8167,6 +8461,103 @@ private: } } + /// Returns the address corresponding to \p PointerExpr. + static Address getAttachPtrAddr(const Expr *PointerExpr, + CodeGenFunction &CGF) { + assert(PointerExpr && "Cannot get addr from null attach-ptr expr"); + Address AttachPtrAddr = Address::invalid(); + + if (auto *DRE = dyn_cast<DeclRefExpr>(PointerExpr)) { + // If the pointer is a variable, we can use its address directly. + AttachPtrAddr = CGF.EmitLValue(DRE).getAddress(); + } else if (auto *OASE = dyn_cast<ArraySectionExpr>(PointerExpr)) { + AttachPtrAddr = + CGF.EmitArraySectionExpr(OASE, /*IsLowerBound=*/true).getAddress(); + } else if (auto *ASE = dyn_cast<ArraySubscriptExpr>(PointerExpr)) { + AttachPtrAddr = CGF.EmitLValue(ASE).getAddress(); + } else if (auto *ME = dyn_cast<MemberExpr>(PointerExpr)) { + AttachPtrAddr = CGF.EmitMemberExpr(ME).getAddress(); + } else if (auto *UO = dyn_cast<UnaryOperator>(PointerExpr)) { + assert(UO->getOpcode() == UO_Deref && + "Unexpected unary-operator on attach-ptr-expr"); + AttachPtrAddr = CGF.EmitLValue(UO).getAddress(); + } + assert(AttachPtrAddr.isValid() && + "Failed to get address for attach pointer expression"); + return AttachPtrAddr; + } + + /// Get the address of the attach pointer, and a load from it, to get the + /// pointee base address. + /// \return A pair containing AttachPtrAddr and AttachPteeBaseAddr. The pair + /// contains invalid addresses if \p AttachPtrExpr is null. 
+  static std::pair<Address, Address>
+  getAttachPtrAddrAndPteeBaseAddr(const Expr *AttachPtrExpr,
+                                  CodeGenFunction &CGF) {
+    // No attach pointer: both halves of the result are invalid.
+    if (!AttachPtrExpr)
+      return {Address::invalid(), Address::invalid()};
+
+    // First the address of the pointer itself ...
+    Address AttachPtrAddr = getAttachPtrAddr(AttachPtrExpr, CGF);
+    assert(AttachPtrAddr.isValid() && "Invalid attach pointer addr");
+
+    // ... then a load through it, typed by the canonical pointer type of the
+    // attach pointer expression, to obtain the pointee base address.
+    QualType AttachPtrType =
+        OMPClauseMappableExprCommon::getComponentExprElementType(AttachPtrExpr)
+            .getCanonicalType();
+    const auto *AttachPtrTy = AttachPtrType->castAs<PointerType>();
+
+    Address AttachPteeBaseAddr =
+        CGF.EmitLoadOfPointer(AttachPtrAddr, AttachPtrTy);
+    assert(AttachPteeBaseAddr.isValid() && "Invalid attach pointee base addr");
+
+    return {AttachPtrAddr, AttachPteeBaseAddr};
+  }
+
+  /// Returns whether an attach entry should be emitted for a map on
+  /// \p MapBaseDecl on the directive \p CurDir. The answer is false when
+  /// \p PointerExpr is null; \p MapBaseDecl and \p CGF are not consulted.
+  static bool
+  shouldEmitAttachEntry(const Expr *PointerExpr, const ValueDecl *MapBaseDecl,
+                        CodeGenFunction &CGF,
+                        llvm::PointerUnion<const OMPExecutableDirective *,
+                                           const OMPDeclareMapperDecl *>
+                            CurDir) {
+    if (!PointerExpr)
+      return false;
+
+    // Pointer attachment is always needed for declare mappers ...
+    if (isa<const OMPDeclareMapperDecl *>(CurDir))
+      return true;
+
+    // ... and otherwise only at map-entering time.
+    const auto *Directive = cast<const OMPExecutableDirective *>(CurDir);
+    return isOpenMPTargetMapEnteringDirective(Directive->getDirectiveKind());
+  }
+
+  /// Computes the attach-ptr expr for \p Components, and records it.
+  /// The expression and its depth come from
+  /// OMPClauseMappableExprCommon::findAttachPtrExpr(), called with the
+  /// OpenMPDirectiveKind extracted from \p CurDir (OMPD_declare_mapper when
+  /// \p CurDir is a declare mapper).
+  /// It updates AttachPtrComputationOrderMap, AttachPtrComponentDepthMap, and
+  /// AttachPtrExprMap; entries that already exist are left as they are.
+  void collectAttachPtrExprInfo(
+      OMPClauseMappableExprCommon::MappableExprComponentListRef Components,
+      llvm::PointerUnion<const OMPExecutableDirective *,
+                         const OMPDeclareMapperDecl *>
+          CurDir) {
+
+    OpenMPDirectiveKind CurDirectiveID = OMPD_declare_mapper;
+    if (!isa<const OMPDeclareMapperDecl *>(CurDir))
+      CurDirectiveID =
+          cast<const OMPExecutableDirective *>(CurDir)->getDirectiveKind();
+
+    const auto &[PtrExpr, ComponentDepth] =
+        OMPClauseMappableExprCommon::findAttachPtrExpr(Components,
+                                                       CurDirectiveID);
+
+    // The computation order of a newly seen expression is the number of
+    // expressions recorded before it.
+    AttachPtrComputationOrderMap.try_emplace(
+        PtrExpr, AttachPtrComputationOrderMap.size());
+    AttachPtrComponentDepthMap.try_emplace(PtrExpr, ComponentDepth);
+    AttachPtrExprMap.try_emplace(Components, PtrExpr);
+  }
+
   /// Generate all the base pointers, section pointers, sizes, map types, and
   /// mappers for the extracted mappable expressions (all included in \a
   /// CombinedInfo). Also, for each item that relates with a device pointer, a