aboutsummaryrefslogtreecommitdiff
path: root/flang/lib
diff options
context:
space:
mode:
Diffstat (limited to 'flang/lib')
-rw-r--r--flang/lib/Lower/OpenMP/Clauses.cpp5
-rw-r--r--flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp208
-rw-r--r--flang/lib/Parser/openmp-parsers.cpp5
-rw-r--r--flang/lib/Parser/unparse.cpp7
-rw-r--r--flang/lib/Semantics/check-omp-structure.cpp6
-rw-r--r--flang/lib/Semantics/resolve-directives.cpp64
-rw-r--r--flang/lib/Semantics/resolve-names.cpp4
7 files changed, 265 insertions, 34 deletions
diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp
index 48b90cc..fac37a3 100644
--- a/flang/lib/Lower/OpenMP/Clauses.cpp
+++ b/flang/lib/Lower/OpenMP/Clauses.cpp
@@ -1036,6 +1036,11 @@ Link make(const parser::OmpClause::Link &inp,
return Link{/*List=*/makeObjects(inp.v, semaCtx)};
}
+LoopRange make(const parser::OmpClause::Looprange &inp,
+ semantics::SemanticsContext &semaCtx) {
+ llvm_unreachable("Unimplemented: looprange");
+}
+
Map make(const parser::OmpClause::Map &inp,
semantics::SemanticsContext &semaCtx) {
// inp.v -> parser::OmpMapClause
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index d8e36ea..9969ee4 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -2284,6 +2284,213 @@ public:
}
};
+static std::pair<mlir::Value, hlfir::AssociateOp>
+getVariable(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value val) {
+ // If it is an expression - create a variable from it, or forward
+ // the value otherwise.
+ hlfir::AssociateOp associate;
+ if (!mlir::isa<hlfir::ExprType>(val.getType()))
+ return {val, associate};
+ hlfir::Entity entity{val};
+ mlir::NamedAttribute byRefAttr = fir::getAdaptToByRefAttr(builder);
+ associate = hlfir::genAssociateExpr(loc, builder, entity, entity.getType(),
+ "", byRefAttr);
+ return {associate.getBase(), associate};
+}
+
+class IndexOpConversion : public mlir::OpRewritePattern<hlfir::IndexOp> {
+public:
+ using mlir::OpRewritePattern<hlfir::IndexOp>::OpRewritePattern;
+
+ llvm::LogicalResult
+ matchAndRewrite(hlfir::IndexOp op,
+ mlir::PatternRewriter &rewriter) const override {
+ // We simplify only limited cases:
+ // 1) a substring length shall be known at compile time
+ // 2) if a substring length is 0 then replace with 1 for forward search,
+ // or otherwise with the string length + 1 (builder shall const-fold if
+ // lookup direction is known at compile time).
+ // 3) for known string length at compile time, if it is
+ // shorter than substring => replace with zero.
+ // 4) if a substring length is one => inline as simple search loop
+ // 5) for forward search with input strings of kind=1 runtime is faster.
+ // Do not simplify in all the other cases relying on a runtime call.
+
+ fir::FirOpBuilder builder{rewriter, op.getOperation()};
+ const mlir::Location &loc = op->getLoc();
+
+ auto resultTy = op.getType();
+ mlir::Value back = op.getBack();
+ mlir::Value substrLen =
+ hlfir::genCharLength(loc, builder, hlfir::Entity{op.getSubstr()});
+
+ auto substrLenCst = fir::getIntIfConstant(substrLen);
+ if (!substrLenCst) {
+ return rewriter.notifyMatchFailure(
+ op, "substring length unknown at compile time");
+ }
+ mlir::Value strLen =
+ hlfir::genCharLength(loc, builder, hlfir::Entity{op.getStr()});
+ auto i1Ty = builder.getI1Type();
+ auto idxTy = builder.getIndexType();
+ if (*substrLenCst == 0) {
+ mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1);
+ // zero length substring. For back search replace with
+ // strLen+1, or otherwise with 1.
+ mlir::Value strEnd = mlir::arith::AddIOp::create(
+ builder, loc, builder.createConvert(loc, idxTy, strLen), oneIdx);
+ if (back)
+ back = builder.createConvert(loc, i1Ty, back);
+ else
+ back = builder.createIntegerConstant(loc, i1Ty, 0);
+ mlir::Value result =
+ mlir::arith::SelectOp::create(builder, loc, back, strEnd, oneIdx);
+
+ rewriter.replaceOp(op, builder.createConvert(loc, resultTy, result));
+ return mlir::success();
+ }
+
+ if (auto strLenCst = fir::getIntIfConstant(strLen)) {
+ if (*strLenCst < *substrLenCst) {
+ rewriter.replaceOp(op, builder.createIntegerConstant(loc, resultTy, 0));
+ return mlir::success();
+ }
+ if (*strLenCst == 0) {
+ // both strings have zero length
+ rewriter.replaceOp(op, builder.createIntegerConstant(loc, resultTy, 1));
+ return mlir::success();
+ }
+ }
+ if (*substrLenCst != 1) {
+ return rewriter.notifyMatchFailure(
+ op, "rely on runtime implementation if substring length > 1");
+ }
+ // For forward search and character kind=1 the runtime uses memchr
+ // which well optimized. But it looks like memchr idiom is not recognized
+ // in LLVM yet. On a micro-kernel test with strings of length 40 runtime
+ // had ~2x less execution time vs inlined code. For unknown search direction
+ // at compile time pessimistically assume "forward".
+ std::optional<bool> isBack;
+ if (back) {
+ if (auto backCst = fir::getIntIfConstant(back))
+ isBack = *backCst != 0;
+ } else {
+ isBack = false;
+ }
+ auto charTy = mlir::cast<fir::CharacterType>(
+ hlfir::getFortranElementType(op.getSubstr().getType()));
+ unsigned kind = charTy.getFKind();
+ if (kind == 1 && (!isBack || !*isBack)) {
+ return rewriter.notifyMatchFailure(
+ op, "rely on runtime implementation for character kind 1");
+ }
+
+ // All checks are passed here. Generate single character search loop.
+ auto [strV, strAssociate] = getVariable(builder, loc, op.getStr());
+ auto [substrV, substrAssociate] = getVariable(builder, loc, op.getSubstr());
+ hlfir::Entity str{strV};
+ hlfir::Entity substr{substrV};
+ mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1);
+
+ auto genExtractAndConvertToInt = [&charTy, &idxTy, &oneIdx,
+ kind](mlir::Location loc,
+ fir::FirOpBuilder &builder,
+ hlfir::Entity &charStr,
+ mlir::Value index) {
+ auto bits = builder.getKindMap().getCharacterBitsize(kind);
+ auto intTy = builder.getIntegerType(bits);
+ auto charLen1Ty =
+ fir::CharacterType::getSingleton(builder.getContext(), kind);
+ mlir::Type designatorTy =
+ fir::ReferenceType::get(charLen1Ty, fir::isa_volatile_type(charTy));
+ auto idxAttr = builder.getIntegerAttr(idxTy, 0);
+
+ auto singleChr = hlfir::DesignateOp::create(
+ builder, loc, designatorTy, charStr, /*component=*/{},
+ /*compShape=*/mlir::Value{}, hlfir::DesignateOp::Subscripts{},
+ /*substring=*/mlir::ValueRange{index, index},
+ /*complexPart=*/std::nullopt,
+ /*shape=*/mlir::Value{}, /*typeParams=*/mlir::ValueRange{oneIdx},
+ fir::FortranVariableFlagsAttr{});
+ auto chrVal = fir::LoadOp::create(builder, loc, singleChr);
+ mlir::Value intVal = fir::ExtractValueOp::create(
+ builder, loc, intTy, chrVal, builder.getArrayAttr(idxAttr));
+ return intVal;
+ };
+
+ auto wantChar = genExtractAndConvertToInt(loc, builder, substr, oneIdx);
+
+ // Generate search loop body with the following C equivalent:
+ // idx_t result = 0;
+ // idx_t end = strlen + 1;
+ // char want = substr[0];
+ // for (idx_t idx = 1; idx < end; ++idx) {
+ // if (result == 0) {
+ // idx_t at = back ? end - idx: idx;
+ // result = str[at-1] == want ? at : result;
+ // }
+ // }
+ if (!back)
+ back = builder.createIntegerConstant(loc, i1Ty, 0);
+ else
+ back = builder.createConvert(loc, i1Ty, back);
+ mlir::Value strEnd = mlir::arith::AddIOp::create(
+ builder, loc, builder.createConvert(loc, idxTy, strLen), oneIdx);
+ mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
+ auto genSearchBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+ mlir::ValueRange index,
+ mlir::ValueRange reductionArgs)
+ -> llvm::SmallVector<mlir::Value, 1> {
+ assert(index.size() == 1 && "expected single loop");
+ assert(reductionArgs.size() == 1 && "expected single reduction value");
+ mlir::Value inRes = reductionArgs[0];
+ auto resEQzero = mlir::arith::CmpIOp::create(
+ builder, loc, mlir::arith::CmpIPredicate::eq, inRes, zeroIdx);
+
+ mlir::Value res =
+ builder
+ .genIfOp(loc, {idxTy}, resEQzero,
+ /*withElseRegion=*/true)
+ .genThen([&]() {
+ mlir::Value idx = builder.createConvert(loc, idxTy, index[0]);
+ // offset = back ? end - idx : idx;
+ mlir::Value offset = mlir::arith::SelectOp::create(
+ builder, loc, back,
+ mlir::arith::SubIOp::create(builder, loc, strEnd, idx),
+ idx);
+
+ auto haveChar =
+ genExtractAndConvertToInt(loc, builder, str, offset);
+ auto charsEQ = mlir::arith::CmpIOp::create(
+ builder, loc, mlir::arith::CmpIPredicate::eq, haveChar,
+ wantChar);
+ mlir::Value newVal = mlir::arith::SelectOp::create(
+ builder, loc, charsEQ, offset, inRes);
+
+ fir::ResultOp::create(builder, loc, newVal);
+ })
+ .genElse([&]() { fir::ResultOp::create(builder, loc, inRes); })
+ .getResults()[0];
+ return {res};
+ };
+
+ llvm::SmallVector<mlir::Value, 1> loopOut =
+ hlfir::genLoopNestWithReductions(loc, builder, {strLen},
+ /*reductionInits=*/{zeroIdx},
+ genSearchBody,
+ /*isUnordered=*/false);
+ mlir::Value result = builder.createConvert(loc, resultTy, loopOut[0]);
+
+ if (strAssociate)
+ hlfir::EndAssociateOp::create(builder, loc, strAssociate);
+ if (substrAssociate)
+ hlfir::EndAssociateOp::create(builder, loc, substrAssociate);
+
+ rewriter.replaceOp(op, result);
+ return mlir::success();
+ }
+};
+
template <typename Op>
class MatmulConversion : public mlir::OpRewritePattern<Op> {
public:
@@ -2955,6 +3162,7 @@ public:
patterns.insert<ArrayShiftConversion<hlfir::CShiftOp>>(context);
patterns.insert<ArrayShiftConversion<hlfir::EOShiftOp>>(context);
patterns.insert<CmpCharOpConversion>(context);
+ patterns.insert<IndexOpConversion>(context);
patterns.insert<MatmulConversion<hlfir::MatmulTransposeOp>>(context);
patterns.insert<ReductionConversion<hlfir::CountOp>>(context);
patterns.insert<ReductionConversion<hlfir::AnyOp>>(context);
diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp
index ea09fe0..9507021 100644
--- a/flang/lib/Parser/openmp-parsers.cpp
+++ b/flang/lib/Parser/openmp-parsers.cpp
@@ -1023,6 +1023,9 @@ TYPE_PARSER(
maybe(":"_tok >> nonemptyList(Parser<OmpLinearClause::Modifier>{})),
/*PostModified=*/pure(true)))
+TYPE_PARSER(construct<OmpLoopRangeClause>(
+ scalarIntConstantExpr, "," >> scalarIntConstantExpr))
+
// OpenMPv5.2 12.5.2 detach-clause -> DETACH (event-handle)
TYPE_PARSER(construct<OmpDetachClause>(Parser<OmpObject>{}))
@@ -1207,6 +1210,8 @@ TYPE_PARSER( //
parenthesized(Parser<OmpLinearClause>{}))) ||
"LINK" >> construct<OmpClause>(construct<OmpClause::Link>(
parenthesized(Parser<OmpObjectList>{}))) ||
+ "LOOPRANGE" >> construct<OmpClause>(construct<OmpClause::Looprange>(
+ parenthesized(Parser<OmpLoopRangeClause>{}))) ||
"MAP" >> construct<OmpClause>(construct<OmpClause::Map>(
parenthesized(Parser<OmpMapClause>{}))) ||
"MATCH" >> construct<OmpClause>(construct<OmpClause::Match>(
diff --git a/flang/lib/Parser/unparse.cpp b/flang/lib/Parser/unparse.cpp
index 0fbd347..0511f5b 100644
--- a/flang/lib/Parser/unparse.cpp
+++ b/flang/lib/Parser/unparse.cpp
@@ -2345,6 +2345,13 @@ public:
}
}
}
+ void Unparse(const OmpLoopRangeClause &x) {
+ Word("LOOPRANGE(");
+ Walk(std::get<0>(x.t));
+ Put(", ");
+ Walk(std::get<1>(x.t));
+ Put(")");
+ }
void Unparse(const OmpReductionClause &x) {
using Modifier = OmpReductionClause::Modifier;
Walk(std::get<std::optional<std::list<Modifier>>>(x.t), ": ");
diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp
index cc2dd0a..db030bb 100644
--- a/flang/lib/Semantics/check-omp-structure.cpp
+++ b/flang/lib/Semantics/check-omp-structure.cpp
@@ -3106,6 +3106,12 @@ CHECK_REQ_CONSTANT_SCALAR_INT_CLAUSE(Collapse, OMPC_collapse)
CHECK_REQ_CONSTANT_SCALAR_INT_CLAUSE(Safelen, OMPC_safelen)
CHECK_REQ_CONSTANT_SCALAR_INT_CLAUSE(Simdlen, OMPC_simdlen)
+void OmpStructureChecker::Enter(const parser::OmpClause::Looprange &x) {
+ context_.Say(GetContext().clauseSource,
+ "LOOPRANGE clause is not implemented yet"_err_en_US,
+ ContextDirectiveAsFortran());
+}
+
// Restrictions specific to each clause are implemented apart from the
// generalized restrictions.
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index a4c8922f..270642a 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -362,6 +362,24 @@ public:
explicit OmpAttributeVisitor(SemanticsContext &context)
: DirectiveAttributeVisitor(context) {}
+ static const Scope &scopingUnit(const Scope &scope) {
+ const Scope *iter{&scope};
+ for (; !iter->IsTopLevel(); iter = &iter->parent()) {
+ switch (iter->kind()) {
+ case Scope::Kind::BlockConstruct:
+ case Scope::Kind::BlockData:
+ case Scope::Kind::DerivedType:
+ case Scope::Kind::MainProgram:
+ case Scope::Kind::Module:
+ case Scope::Kind::Subprogram:
+ return *iter;
+ default:
+ break;
+ }
+ }
+ return *iter;
+ }
+
template <typename A> void Walk(const A &x) { parser::Walk(x, *this); }
template <typename A> bool Pre(const A &) { return true; }
template <typename A> void Post(const A &) {}
@@ -952,7 +970,6 @@ private:
void ResolveOmpNameList(const std::list<parser::Name> &, Symbol::Flag);
void ResolveOmpName(const parser::Name &, Symbol::Flag);
Symbol *ResolveName(const parser::Name *);
- Symbol *ResolveOmpObjectScope(const parser::Name *);
Symbol *DeclareOrMarkOtherAccessEntity(const parser::Name &, Symbol::Flag);
Symbol *DeclareOrMarkOtherAccessEntity(Symbol &, Symbol::Flag);
void CheckMultipleAppearances(
@@ -2920,31 +2937,6 @@ Symbol *OmpAttributeVisitor::ResolveOmpCommonBlockName(
return nullptr;
}
-// Use this function over ResolveOmpName when an omp object's scope needs
-// resolving, it's symbol flag isn't important and a simple check for resolution
-// failure is desired. Using ResolveOmpName means needing to work with the
-// context to check for failure, whereas here a pointer comparison is all that's
-// needed.
-Symbol *OmpAttributeVisitor::ResolveOmpObjectScope(const parser::Name *name) {
-
- // TODO: Investigate whether the following block can be replaced by, or
- // included in, the ResolveOmpName function
- if (auto *prev{name ? GetContext().scope.parent().FindSymbol(name->source)
- : nullptr}) {
- name->symbol = prev;
- return nullptr;
- }
-
- // TODO: Investigate whether the following block can be replaced by, or
- // included in, the ResolveOmpName function
- if (auto *ompSymbol{
- name ? GetContext().scope.FindSymbol(name->source) : nullptr}) {
- name->symbol = ompSymbol;
- return ompSymbol;
- }
- return nullptr;
-}
-
void OmpAttributeVisitor::ResolveOmpObjectList(
const parser::OmpObjectList &ompObjectList, Symbol::Flag ompFlag) {
for (const auto &ompObject : ompObjectList.v) {
@@ -3023,13 +3015,19 @@ void OmpAttributeVisitor::ResolveOmpDesignator(
context_.Say(designator.source,
"List items specified in the ALLOCATE directive must not have the ALLOCATABLE attribute unless the directive is associated with an ALLOCATE statement"_err_en_US);
}
- if ((ompFlag == Symbol::Flag::OmpDeclarativeAllocateDirective ||
- ompFlag == Symbol::Flag::OmpExecutableAllocateDirective) &&
- ResolveOmpObjectScope(name) == nullptr) {
- context_.Say(designator.source, // 2.15.3
- "List items must be declared in the same scoping unit in which the %s directive appears"_err_en_US,
- parser::ToUpperCaseLetters(
- llvm::omp::getOpenMPDirectiveName(directive, version)));
+ bool checkScope{ompFlag == Symbol::Flag::OmpDeclarativeAllocateDirective};
+ // In 5.1 the scope check only applies to declarative allocate.
+ if (version == 50 && !checkScope) {
+ checkScope = ompFlag == Symbol::Flag::OmpExecutableAllocateDirective;
+ }
+ if (checkScope) {
+ if (scopingUnit(GetContext().scope) !=
+ scopingUnit(symbol->GetUltimate().owner())) {
+ context_.Say(designator.source, // 2.15.3
+ "List items must be declared in the same scoping unit in which the %s directive appears"_err_en_US,
+ parser::ToUpperCaseLetters(
+ llvm::omp::getOpenMPDirectiveName(directive, version)));
+ }
}
if (ompFlag == Symbol::Flag::OmpReduction) {
// Using variables inside of a namelist in OpenMP reductions
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index 2f350f0..ef0b8cd 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -1618,12 +1618,14 @@ public:
void Post(const parser::OpenMPDeclareTargetConstruct &) {
SkipImplicitTyping(false);
}
- bool Pre(const parser::OpenMPDeclarativeAllocate &) {
+ bool Pre(const parser::OpenMPDeclarativeAllocate &x) {
+ AddOmpSourceRange(x.source);
SkipImplicitTyping(true);
return true;
}
void Post(const parser::OpenMPDeclarativeAllocate &) {
SkipImplicitTyping(false);
+ messageHandler().set_currStmtSource(std::nullopt);
}
bool Pre(const parser::OpenMPDeclarativeConstruct &x) {
AddOmpSourceRange(x.source);