diff options
Diffstat (limited to 'flang/lib')
-rw-r--r-- | flang/lib/Lower/OpenMP/Clauses.cpp | 5 | ||||
-rw-r--r-- | flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp | 208 | ||||
-rw-r--r-- | flang/lib/Parser/openmp-parsers.cpp | 5 | ||||
-rw-r--r-- | flang/lib/Parser/unparse.cpp | 7 | ||||
-rw-r--r-- | flang/lib/Semantics/check-omp-structure.cpp | 6 | ||||
-rw-r--r-- | flang/lib/Semantics/resolve-directives.cpp | 64 | ||||
-rw-r--r-- | flang/lib/Semantics/resolve-names.cpp | 4 |
7 files changed, 265 insertions, 34 deletions
diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp index 48b90cc..fac37a3 100644 --- a/flang/lib/Lower/OpenMP/Clauses.cpp +++ b/flang/lib/Lower/OpenMP/Clauses.cpp @@ -1036,6 +1036,11 @@ Link make(const parser::OmpClause::Link &inp, return Link{/*List=*/makeObjects(inp.v, semaCtx)}; } +LoopRange make(const parser::OmpClause::Looprange &inp, + semantics::SemanticsContext &semaCtx) { + llvm_unreachable("Unimplemented: looprange"); +} + Map make(const parser::OmpClause::Map &inp, semantics::SemanticsContext &semaCtx) { // inp.v -> parser::OmpMapClause diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp index d8e36ea..9969ee4 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp @@ -2284,6 +2284,213 @@ public: } }; +static std::pair<mlir::Value, hlfir::AssociateOp> +getVariable(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value val) { + // If it is an expression - create a variable from it, or forward + // the value otherwise. 
+ hlfir::AssociateOp associate; + if (!mlir::isa<hlfir::ExprType>(val.getType())) + return {val, associate}; + hlfir::Entity entity{val}; + mlir::NamedAttribute byRefAttr = fir::getAdaptToByRefAttr(builder); + associate = hlfir::genAssociateExpr(loc, builder, entity, entity.getType(), + "", byRefAttr); + return {associate.getBase(), associate}; +} + +class IndexOpConversion : public mlir::OpRewritePattern<hlfir::IndexOp> { +public: + using mlir::OpRewritePattern<hlfir::IndexOp>::OpRewritePattern; + + llvm::LogicalResult + matchAndRewrite(hlfir::IndexOp op, + mlir::PatternRewriter &rewriter) const override { + // We simplify only limited cases: + // 1) a substring length shall be known at compile time + // 2) if a substring length is 0 then replace with 1 for forward search, + // or otherwise with the string length + 1 (builder shall const-fold if + // lookup direction is known at compile time). + // 3) for known string length at compile time, if it is + // shorter than substring => replace with zero. + // 4) if a substring length is one => inline as simple search loop + // 5) for forward search with input strings of kind=1 runtime is faster. + // Do not simplify in all the other cases relying on a runtime call. + + fir::FirOpBuilder builder{rewriter, op.getOperation()}; + const mlir::Location &loc = op->getLoc(); + + auto resultTy = op.getType(); + mlir::Value back = op.getBack(); + mlir::Value substrLen = + hlfir::genCharLength(loc, builder, hlfir::Entity{op.getSubstr()}); + + auto substrLenCst = fir::getIntIfConstant(substrLen); + if (!substrLenCst) { + return rewriter.notifyMatchFailure( + op, "substring length unknown at compile time"); + } + mlir::Value strLen = + hlfir::genCharLength(loc, builder, hlfir::Entity{op.getStr()}); + auto i1Ty = builder.getI1Type(); + auto idxTy = builder.getIndexType(); + if (*substrLenCst == 0) { + mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1); + // zero length substring. 
For back search replace with + // strLen+1, or otherwise with 1. + mlir::Value strEnd = mlir::arith::AddIOp::create( + builder, loc, builder.createConvert(loc, idxTy, strLen), oneIdx); + if (back) + back = builder.createConvert(loc, i1Ty, back); + else + back = builder.createIntegerConstant(loc, i1Ty, 0); + mlir::Value result = + mlir::arith::SelectOp::create(builder, loc, back, strEnd, oneIdx); + + rewriter.replaceOp(op, builder.createConvert(loc, resultTy, result)); + return mlir::success(); + } + + if (auto strLenCst = fir::getIntIfConstant(strLen)) { + if (*strLenCst < *substrLenCst) { + rewriter.replaceOp(op, builder.createIntegerConstant(loc, resultTy, 0)); + return mlir::success(); + } + if (*strLenCst == 0) { + // both strings have zero length + rewriter.replaceOp(op, builder.createIntegerConstant(loc, resultTy, 1)); + return mlir::success(); + } + } + if (*substrLenCst != 1) { + return rewriter.notifyMatchFailure( + op, "rely on runtime implementation if substring length > 1"); + } + // For forward search and character kind=1 the runtime uses memchr + // which is well optimized. But it looks like memchr idiom is not recognized + // in LLVM yet. On a micro-kernel test with strings of length 40 runtime + // had ~2x less execution time vs inlined code. For unknown search direction + // at compile time pessimistically assume "forward". + std::optional<bool> isBack; + if (back) { + if (auto backCst = fir::getIntIfConstant(back)) + isBack = *backCst != 0; + } else { + isBack = false; + } + auto charTy = mlir::cast<fir::CharacterType>( + hlfir::getFortranElementType(op.getSubstr().getType())); + unsigned kind = charTy.getFKind(); + if (kind == 1 && (!isBack || !*isBack)) { + return rewriter.notifyMatchFailure( + op, "rely on runtime implementation for character kind 1"); + } + + // All checks are passed here. Generate single character search loop. 
+ auto [strV, strAssociate] = getVariable(builder, loc, op.getStr()); + auto [substrV, substrAssociate] = getVariable(builder, loc, op.getSubstr()); + hlfir::Entity str{strV}; + hlfir::Entity substr{substrV}; + mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1); + + auto genExtractAndConvertToInt = [&charTy, &idxTy, &oneIdx, + kind](mlir::Location loc, + fir::FirOpBuilder &builder, + hlfir::Entity &charStr, + mlir::Value index) { + auto bits = builder.getKindMap().getCharacterBitsize(kind); + auto intTy = builder.getIntegerType(bits); + auto charLen1Ty = + fir::CharacterType::getSingleton(builder.getContext(), kind); + mlir::Type designatorTy = + fir::ReferenceType::get(charLen1Ty, fir::isa_volatile_type(charTy)); + auto idxAttr = builder.getIntegerAttr(idxTy, 0); + + auto singleChr = hlfir::DesignateOp::create( + builder, loc, designatorTy, charStr, /*component=*/{}, + /*compShape=*/mlir::Value{}, hlfir::DesignateOp::Subscripts{}, + /*substring=*/mlir::ValueRange{index, index}, + /*complexPart=*/std::nullopt, + /*shape=*/mlir::Value{}, /*typeParams=*/mlir::ValueRange{oneIdx}, + fir::FortranVariableFlagsAttr{}); + auto chrVal = fir::LoadOp::create(builder, loc, singleChr); + mlir::Value intVal = fir::ExtractValueOp::create( + builder, loc, intTy, chrVal, builder.getArrayAttr(idxAttr)); + return intVal; + }; + + auto wantChar = genExtractAndConvertToInt(loc, builder, substr, oneIdx); + + // Generate search loop body with the following C equivalent: + // idx_t result = 0; + // idx_t end = strlen + 1; + // char want = substr[0]; + // for (idx_t idx = 1; idx < end; ++idx) { + // if (result == 0) { + // idx_t at = back ? end - idx: idx; + // result = str[at-1] == want ? 
at : result; + // } + // } + if (!back) + back = builder.createIntegerConstant(loc, i1Ty, 0); + else + back = builder.createConvert(loc, i1Ty, back); + mlir::Value strEnd = mlir::arith::AddIOp::create( + builder, loc, builder.createConvert(loc, idxTy, strLen), oneIdx); + mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0); + auto genSearchBody = [&](mlir::Location loc, fir::FirOpBuilder &builder, + mlir::ValueRange index, + mlir::ValueRange reductionArgs) + -> llvm::SmallVector<mlir::Value, 1> { + assert(index.size() == 1 && "expected single loop"); + assert(reductionArgs.size() == 1 && "expected single reduction value"); + mlir::Value inRes = reductionArgs[0]; + auto resEQzero = mlir::arith::CmpIOp::create( + builder, loc, mlir::arith::CmpIPredicate::eq, inRes, zeroIdx); + + mlir::Value res = + builder + .genIfOp(loc, {idxTy}, resEQzero, + /*withElseRegion=*/true) + .genThen([&]() { + mlir::Value idx = builder.createConvert(loc, idxTy, index[0]); + // offset = back ? end - idx : idx; + mlir::Value offset = mlir::arith::SelectOp::create( + builder, loc, back, + mlir::arith::SubIOp::create(builder, loc, strEnd, idx), + idx); + + auto haveChar = + genExtractAndConvertToInt(loc, builder, str, offset); + auto charsEQ = mlir::arith::CmpIOp::create( + builder, loc, mlir::arith::CmpIPredicate::eq, haveChar, + wantChar); + mlir::Value newVal = mlir::arith::SelectOp::create( + builder, loc, charsEQ, offset, inRes); + + fir::ResultOp::create(builder, loc, newVal); + }) + .genElse([&]() { fir::ResultOp::create(builder, loc, inRes); }) + .getResults()[0]; + return {res}; + }; + + llvm::SmallVector<mlir::Value, 1> loopOut = + hlfir::genLoopNestWithReductions(loc, builder, {strLen}, + /*reductionInits=*/{zeroIdx}, + genSearchBody, + /*isUnordered=*/false); + mlir::Value result = builder.createConvert(loc, resultTy, loopOut[0]); + + if (strAssociate) + hlfir::EndAssociateOp::create(builder, loc, strAssociate); + if (substrAssociate) + 
hlfir::EndAssociateOp::create(builder, loc, substrAssociate); + + rewriter.replaceOp(op, result); + return mlir::success(); + } +}; + template <typename Op> class MatmulConversion : public mlir::OpRewritePattern<Op> { public: @@ -2955,6 +3162,7 @@ public: patterns.insert<ArrayShiftConversion<hlfir::CShiftOp>>(context); patterns.insert<ArrayShiftConversion<hlfir::EOShiftOp>>(context); patterns.insert<CmpCharOpConversion>(context); + patterns.insert<IndexOpConversion>(context); patterns.insert<MatmulConversion<hlfir::MatmulTransposeOp>>(context); patterns.insert<ReductionConversion<hlfir::CountOp>>(context); patterns.insert<ReductionConversion<hlfir::AnyOp>>(context); diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp index ea09fe0..9507021 100644 --- a/flang/lib/Parser/openmp-parsers.cpp +++ b/flang/lib/Parser/openmp-parsers.cpp @@ -1023,6 +1023,9 @@ TYPE_PARSER( maybe(":"_tok >> nonemptyList(Parser<OmpLinearClause::Modifier>{})), /*PostModified=*/pure(true))) +TYPE_PARSER(construct<OmpLoopRangeClause>( + scalarIntConstantExpr, "," >> scalarIntConstantExpr)) + // OpenMPv5.2 12.5.2 detach-clause -> DETACH (event-handle) TYPE_PARSER(construct<OmpDetachClause>(Parser<OmpObject>{})) @@ -1207,6 +1210,8 @@ TYPE_PARSER( // parenthesized(Parser<OmpLinearClause>{}))) || "LINK" >> construct<OmpClause>(construct<OmpClause::Link>( parenthesized(Parser<OmpObjectList>{}))) || + "LOOPRANGE" >> construct<OmpClause>(construct<OmpClause::Looprange>( + parenthesized(Parser<OmpLoopRangeClause>{}))) || "MAP" >> construct<OmpClause>(construct<OmpClause::Map>( parenthesized(Parser<OmpMapClause>{}))) || "MATCH" >> construct<OmpClause>(construct<OmpClause::Match>( diff --git a/flang/lib/Parser/unparse.cpp b/flang/lib/Parser/unparse.cpp index 0fbd347..0511f5b 100644 --- a/flang/lib/Parser/unparse.cpp +++ b/flang/lib/Parser/unparse.cpp @@ -2345,6 +2345,13 @@ public: } } } + void Unparse(const OmpLoopRangeClause &x) { + Word("LOOPRANGE("); + 
Walk(std::get<0>(x.t)); + Put(", "); + Walk(std::get<1>(x.t)); + Put(")"); + } void Unparse(const OmpReductionClause &x) { using Modifier = OmpReductionClause::Modifier; Walk(std::get<std::optional<std::list<Modifier>>>(x.t), ": "); diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp index cc2dd0a..db030bb 100644 --- a/flang/lib/Semantics/check-omp-structure.cpp +++ b/flang/lib/Semantics/check-omp-structure.cpp @@ -3106,6 +3106,12 @@ CHECK_REQ_CONSTANT_SCALAR_INT_CLAUSE(Collapse, OMPC_collapse) CHECK_REQ_CONSTANT_SCALAR_INT_CLAUSE(Safelen, OMPC_safelen) CHECK_REQ_CONSTANT_SCALAR_INT_CLAUSE(Simdlen, OMPC_simdlen) +void OmpStructureChecker::Enter(const parser::OmpClause::Looprange &x) { + context_.Say(GetContext().clauseSource, + "LOOPRANGE clause is not implemented yet"_err_en_US, + ContextDirectiveAsFortran()); +} + // Restrictions specific to each clause are implemented apart from the // generalized restrictions. diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp index a4c8922f..270642a 100644 --- a/flang/lib/Semantics/resolve-directives.cpp +++ b/flang/lib/Semantics/resolve-directives.cpp @@ -362,6 +362,24 @@ public: explicit OmpAttributeVisitor(SemanticsContext &context) : DirectiveAttributeVisitor(context) {} + static const Scope &scopingUnit(const Scope &scope) { + const Scope *iter{&scope}; + for (; !iter->IsTopLevel(); iter = &iter->parent()) { + switch (iter->kind()) { + case Scope::Kind::BlockConstruct: + case Scope::Kind::BlockData: + case Scope::Kind::DerivedType: + case Scope::Kind::MainProgram: + case Scope::Kind::Module: + case Scope::Kind::Subprogram: + return *iter; + default: + break; + } + } + return *iter; + } + template <typename A> void Walk(const A &x) { parser::Walk(x, *this); } template <typename A> bool Pre(const A &) { return true; } template <typename A> void Post(const A &) {} @@ -952,7 +970,6 @@ private: void ResolveOmpNameList(const 
std::list<parser::Name> &, Symbol::Flag); void ResolveOmpName(const parser::Name &, Symbol::Flag); Symbol *ResolveName(const parser::Name *); - Symbol *ResolveOmpObjectScope(const parser::Name *); Symbol *DeclareOrMarkOtherAccessEntity(const parser::Name &, Symbol::Flag); Symbol *DeclareOrMarkOtherAccessEntity(Symbol &, Symbol::Flag); void CheckMultipleAppearances( @@ -2920,31 +2937,6 @@ Symbol *OmpAttributeVisitor::ResolveOmpCommonBlockName( return nullptr; } -// Use this function over ResolveOmpName when an omp object's scope needs -// resolving, it's symbol flag isn't important and a simple check for resolution -// failure is desired. Using ResolveOmpName means needing to work with the -// context to check for failure, whereas here a pointer comparison is all that's -// needed. -Symbol *OmpAttributeVisitor::ResolveOmpObjectScope(const parser::Name *name) { - - // TODO: Investigate whether the following block can be replaced by, or - // included in, the ResolveOmpName function - if (auto *prev{name ? GetContext().scope.parent().FindSymbol(name->source) - : nullptr}) { - name->symbol = prev; - return nullptr; - } - - // TODO: Investigate whether the following block can be replaced by, or - // included in, the ResolveOmpName function - if (auto *ompSymbol{ - name ? 
GetContext().scope.FindSymbol(name->source) : nullptr}) { - name->symbol = ompSymbol; - return ompSymbol; - } - return nullptr; -} - void OmpAttributeVisitor::ResolveOmpObjectList( const parser::OmpObjectList &ompObjectList, Symbol::Flag ompFlag) { for (const auto &ompObject : ompObjectList.v) { @@ -3023,13 +3015,19 @@ void OmpAttributeVisitor::ResolveOmpDesignator( context_.Say(designator.source, "List items specified in the ALLOCATE directive must not have the ALLOCATABLE attribute unless the directive is associated with an ALLOCATE statement"_err_en_US); } - if ((ompFlag == Symbol::Flag::OmpDeclarativeAllocateDirective || - ompFlag == Symbol::Flag::OmpExecutableAllocateDirective) && - ResolveOmpObjectScope(name) == nullptr) { - context_.Say(designator.source, // 2.15.3 - "List items must be declared in the same scoping unit in which the %s directive appears"_err_en_US, - parser::ToUpperCaseLetters( - llvm::omp::getOpenMPDirectiveName(directive, version))); + bool checkScope{ompFlag == Symbol::Flag::OmpDeclarativeAllocateDirective}; + // In 5.1 the scope check only applies to declarative allocate. 
+ if (version == 50 && !checkScope) { + checkScope = ompFlag == Symbol::Flag::OmpExecutableAllocateDirective; + } + if (checkScope) { + if (scopingUnit(GetContext().scope) != + scopingUnit(symbol->GetUltimate().owner())) { + context_.Say(designator.source, // 2.15.3 + "List items must be declared in the same scoping unit in which the %s directive appears"_err_en_US, + parser::ToUpperCaseLetters( + llvm::omp::getOpenMPDirectiveName(directive, version))); + } } if (ompFlag == Symbol::Flag::OmpReduction) { // Using variables inside of a namelist in OpenMP reductions diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp index 2f350f0..ef0b8cd 100644 --- a/flang/lib/Semantics/resolve-names.cpp +++ b/flang/lib/Semantics/resolve-names.cpp @@ -1618,12 +1618,14 @@ public: void Post(const parser::OpenMPDeclareTargetConstruct &) { SkipImplicitTyping(false); } - bool Pre(const parser::OpenMPDeclarativeAllocate &) { + bool Pre(const parser::OpenMPDeclarativeAllocate &x) { + AddOmpSourceRange(x.source); SkipImplicitTyping(true); return true; } void Post(const parser::OpenMPDeclarativeAllocate &) { SkipImplicitTyping(false); + messageHandler().set_currStmtSource(std::nullopt); } bool Pre(const parser::OpenMPDeclarativeConstruct &x) { AddOmpSourceRange(x.source); |