aboutsummaryrefslogtreecommitdiff
path: root/clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp')
-rw-r--r--clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp188
1 files changed, 134 insertions, 54 deletions
diff --git a/clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp b/clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp
index 1d31b22..dbf4878 100644
--- a/clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp
+++ b/clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp
@@ -64,39 +64,125 @@ static bool hasOptionalClassName(const CXXRecordDecl &RD) {
return false;
}
+static const CXXRecordDecl *getOptionalBaseClass(const CXXRecordDecl *RD) {
+ if (RD == nullptr)
+ return nullptr;
+ if (hasOptionalClassName(*RD))
+ return RD;
+
+ if (!RD->hasDefinition())
+ return nullptr;
+
+ for (const CXXBaseSpecifier &Base : RD->bases())
+ if (const CXXRecordDecl *BaseClass =
+ getOptionalBaseClass(Base.getType()->getAsCXXRecordDecl()))
+ return BaseClass;
+
+ return nullptr;
+}
+
namespace {
using namespace ::clang::ast_matchers;
using LatticeTransferState = TransferState<NoopLattice>;
-AST_MATCHER(CXXRecordDecl, hasOptionalClassNameMatcher) {
- return hasOptionalClassName(Node);
+AST_MATCHER(CXXRecordDecl, optionalClass) { return hasOptionalClassName(Node); }
+
+AST_MATCHER(CXXRecordDecl, optionalOrDerivedClass) {
+ return getOptionalBaseClass(&Node) != nullptr;
}
-DeclarationMatcher optionalClass() {
- return classTemplateSpecializationDecl(
- hasOptionalClassNameMatcher(),
- hasTemplateArgument(0, refersToType(type().bind("T"))));
+auto desugarsToOptionalType() {
+ return hasUnqualifiedDesugaredType(
+ recordType(hasDeclaration(cxxRecordDecl(optionalClass()))));
}
-auto optionalOrAliasType() {
+auto desugarsToOptionalOrDerivedType() {
return hasUnqualifiedDesugaredType(
- recordType(hasDeclaration(optionalClass())));
+ recordType(hasDeclaration(cxxRecordDecl(optionalOrDerivedClass()))));
+}
+
+auto hasOptionalType() { return hasType(desugarsToOptionalType()); }
+
+/// Matches any of the spellings of the optional types and sugar, aliases,
+/// derived classes, etc.
+auto hasOptionalOrDerivedType() {
+ return hasType(desugarsToOptionalOrDerivedType());
+}
+
+QualType getPublicType(const Expr *E) {
+ auto *Cast = dyn_cast<ImplicitCastExpr>(E->IgnoreParens());
+ if (Cast == nullptr || Cast->getCastKind() != CK_UncheckedDerivedToBase) {
+ QualType Ty = E->getType();
+ if (Ty->isPointerType())
+ return Ty->getPointeeType();
+ return Ty;
+ }
+
+ // Is the derived type that we're casting from the type of `*this`? In this
+ // special case, we can upcast to the base class even if the base is
+ // non-public.
+ bool CastingFromThis = isa<CXXThisExpr>(Cast->getSubExpr());
+
+ // Find the least-derived type in the path (i.e. the last entry in the list)
+ // that we can access.
+ const CXXBaseSpecifier *PublicBase = nullptr;
+ for (const CXXBaseSpecifier *Base : Cast->path()) {
+ if (Base->getAccessSpecifier() != AS_public && !CastingFromThis)
+ break;
+ PublicBase = Base;
+ CastingFromThis = false;
+ }
+
+ if (PublicBase != nullptr)
+ return PublicBase->getType();
+
+ // We didn't find any public type that we could cast to. There may be more
+ // casts in `getSubExpr()`, so recurse. (If there aren't any more casts, this
+ // will return the type of `getSubExpr()`.)
+ return getPublicType(Cast->getSubExpr());
}
-/// Matches any of the spellings of the optional types and sugar, aliases, etc.
-auto hasOptionalType() { return hasType(optionalOrAliasType()); }
+// Returns the least-derived type for the receiver of `MCE` that
+// `MCE.getImplicitObjectArgument()->IgnoreParentImpCasts()` can be downcast to.
+// Effectively, we upcast until we reach a non-public base class, unless that
+// base is a base of `*this`.
+//
+// This is needed to correctly match methods called on types derived from
+// `std::optional`.
+//
+// Say we have a `struct Derived : public std::optional<int> {} d;` For a call
+// `d.has_value()`, the `getImplicitObjectArgument()` looks like this:
+//
+// ImplicitCastExpr 'const std::__optional_storage_base<int>' lvalue
+// | <UncheckedDerivedToBase (optional -> __optional_storage_base)>
+// `-DeclRefExpr 'Derived' lvalue Var 'd' 'Derived'
+//
+// The type of the implicit object argument is `__optional_storage_base`
+// (since this is the internal type that `has_value()` is declared on). If we
+// call `IgnoreParenImpCasts()` on the implicit object argument, we get the
+// `DeclRefExpr`, which has type `Derived`. Neither of these types is
+// `optional`, and hence neither is sufficient for querying whether we are
+// calling a method on `optional`.
+//
+// Instead, starting with the most derived type, we need to follow the chain of
+// casts
+QualType getPublicReceiverType(const CXXMemberCallExpr &MCE) {
+ return getPublicType(MCE.getImplicitObjectArgument());
+}
+
+AST_MATCHER_P(CXXMemberCallExpr, publicReceiverType,
+ ast_matchers::internal::Matcher<QualType>, InnerMatcher) {
+ return InnerMatcher.matches(getPublicReceiverType(Node), Finder, Builder);
+}
auto isOptionalMemberCallWithNameMatcher(
ast_matchers::internal::Matcher<NamedDecl> matcher,
const std::optional<StatementMatcher> &Ignorable = std::nullopt) {
- auto Exception = unless(Ignorable ? expr(anyOf(*Ignorable, cxxThisExpr()))
- : cxxThisExpr());
- return cxxMemberCallExpr(
- on(expr(Exception,
- anyOf(hasOptionalType(),
- hasType(pointerType(pointee(optionalOrAliasType())))))),
- callee(cxxMethodDecl(matcher)));
+ return cxxMemberCallExpr(Ignorable ? on(expr(unless(*Ignorable)))
+ : anything(),
+ publicReceiverType(desugarsToOptionalType()),
+ callee(cxxMethodDecl(matcher)));
}
auto isOptionalOperatorCallWithName(
@@ -129,49 +215,51 @@ auto inPlaceClass() {
auto isOptionalNulloptConstructor() {
return cxxConstructExpr(
- hasOptionalType(),
hasDeclaration(cxxConstructorDecl(parameterCountIs(1),
- hasParameter(0, hasNulloptType()))));
+ hasParameter(0, hasNulloptType()))),
+ hasOptionalOrDerivedType());
}
auto isOptionalInPlaceConstructor() {
- return cxxConstructExpr(hasOptionalType(),
- hasArgument(0, hasType(inPlaceClass())));
+ return cxxConstructExpr(hasArgument(0, hasType(inPlaceClass())),
+ hasOptionalOrDerivedType());
}
auto isOptionalValueOrConversionConstructor() {
return cxxConstructExpr(
- hasOptionalType(),
unless(hasDeclaration(
cxxConstructorDecl(anyOf(isCopyConstructor(), isMoveConstructor())))),
- argumentCountIs(1), hasArgument(0, unless(hasNulloptType())));
+ argumentCountIs(1), hasArgument(0, unless(hasNulloptType())),
+ hasOptionalOrDerivedType());
}
auto isOptionalValueOrConversionAssignment() {
return cxxOperatorCallExpr(
hasOverloadedOperatorName("="),
- callee(cxxMethodDecl(ofClass(optionalClass()))),
+ callee(cxxMethodDecl(ofClass(optionalOrDerivedClass()))),
unless(hasDeclaration(cxxMethodDecl(
anyOf(isCopyAssignmentOperator(), isMoveAssignmentOperator())))),
argumentCountIs(2), hasArgument(1, unless(hasNulloptType())));
}
auto isOptionalNulloptAssignment() {
- return cxxOperatorCallExpr(hasOverloadedOperatorName("="),
- callee(cxxMethodDecl(ofClass(optionalClass()))),
- argumentCountIs(2),
- hasArgument(1, hasNulloptType()));
+ return cxxOperatorCallExpr(
+ hasOverloadedOperatorName("="),
+ callee(cxxMethodDecl(ofClass(optionalOrDerivedClass()))),
+ argumentCountIs(2), hasArgument(1, hasNulloptType()));
}
auto isStdSwapCall() {
return callExpr(callee(functionDecl(hasName("std::swap"))),
- argumentCountIs(2), hasArgument(0, hasOptionalType()),
- hasArgument(1, hasOptionalType()));
+ argumentCountIs(2),
+ hasArgument(0, hasOptionalOrDerivedType()),
+ hasArgument(1, hasOptionalOrDerivedType()));
}
auto isStdForwardCall() {
return callExpr(callee(functionDecl(hasName("std::forward"))),
- argumentCountIs(1), hasArgument(0, hasOptionalType()));
+ argumentCountIs(1),
+ hasArgument(0, hasOptionalOrDerivedType()));
}
constexpr llvm::StringLiteral ValueOrCallID = "ValueOrCall";
@@ -212,8 +300,9 @@ auto isValueOrNotEqX() {
}
auto isCallReturningOptional() {
- return callExpr(hasType(qualType(anyOf(
- optionalOrAliasType(), referenceType(pointee(optionalOrAliasType()))))));
+ return callExpr(hasType(qualType(
+ anyOf(desugarsToOptionalOrDerivedType(),
+ referenceType(pointee(desugarsToOptionalOrDerivedType()))))));
}
template <typename L, typename R>
@@ -275,12 +364,9 @@ BoolValue *getHasValue(Environment &Env, RecordStorageLocation *OptionalLoc) {
return HasValueVal;
}
-/// Returns true if and only if `Type` is an optional type.
-bool isOptionalType(QualType Type) {
- if (!Type->isRecordType())
- return false;
- const CXXRecordDecl *D = Type->getAsCXXRecordDecl();
- return D != nullptr && hasOptionalClassName(*D);
+QualType valueTypeFromOptionalDecl(const CXXRecordDecl &RD) {
+ auto &CTSD = cast<ClassTemplateSpecializationDecl>(RD);
+ return CTSD.getTemplateArgs()[0].getAsType();
}
/// Returns the number of optional wrappers in `Type`.
@@ -288,15 +374,13 @@ bool isOptionalType(QualType Type) {
/// For example, if `Type` is `optional<optional<int>>`, the result of this
/// function will be 2.
int countOptionalWrappers(const ASTContext &ASTCtx, QualType Type) {
- if (!isOptionalType(Type))
+ const CXXRecordDecl *Optional =
+ getOptionalBaseClass(Type->getAsCXXRecordDecl());
+ if (Optional == nullptr)
return 0;
return 1 + countOptionalWrappers(
ASTCtx,
- cast<ClassTemplateSpecializationDecl>(Type->getAsRecordDecl())
- ->getTemplateArgs()
- .get(0)
- .getAsType()
- .getDesugaredType(ASTCtx));
+ valueTypeFromOptionalDecl(*Optional).getDesugaredType(ASTCtx));
}
StorageLocation *getLocBehindPossiblePointer(const Expr &E,
@@ -843,13 +927,7 @@ auto buildDiagnoseMatchSwitch(
ast_matchers::DeclarationMatcher
UncheckedOptionalAccessModel::optionalClassDecl() {
- return optionalClass();
-}
-
-static QualType valueTypeFromOptionalType(QualType OptionalTy) {
- auto *CTSD =
- cast<ClassTemplateSpecializationDecl>(OptionalTy->getAsCXXRecordDecl());
- return CTSD->getTemplateArgs()[0].getAsType();
+ return cxxRecordDecl(optionalClass());
}
UncheckedOptionalAccessModel::UncheckedOptionalAccessModel(ASTContext &Ctx,
@@ -858,9 +936,11 @@ UncheckedOptionalAccessModel::UncheckedOptionalAccessModel(ASTContext &Ctx,
TransferMatchSwitch(buildTransferMatchSwitch()) {
Env.getDataflowAnalysisContext().setSyntheticFieldCallback(
[&Ctx](QualType Ty) -> llvm::StringMap<QualType> {
- if (!isOptionalType(Ty))
+ const CXXRecordDecl *Optional =
+ getOptionalBaseClass(Ty->getAsCXXRecordDecl());
+ if (Optional == nullptr)
return {};
- return {{"value", valueTypeFromOptionalType(Ty)},
+ return {{"value", valueTypeFromOptionalDecl(*Optional)},
{"has_value", Ctx.BoolTy}};
});
}