aboutsummaryrefslogtreecommitdiff
path: root/flang/lib/Optimizer
diff options
context:
space:
mode:
Diffstat (limited to 'flang/lib/Optimizer')
-rw-r--r--flang/lib/Optimizer/Builder/HLFIRTools.cpp81
-rw-r--r--flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp207
-rw-r--r--flang/lib/Optimizer/Transforms/AddDebugInfo.cpp9
3 files changed, 276 insertions, 21 deletions
diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index f93eaf7..dbfcae1 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -676,6 +676,34 @@ mlir::Value hlfir::genLBound(mlir::Location loc, fir::FirOpBuilder &builder,
return dimInfo.getLowerBound();
}
+static bool
+getExprLengthParameters(mlir::Value expr,
+ llvm::SmallVectorImpl<mlir::Value> &result) {
+ if (auto concat = expr.getDefiningOp<hlfir::ConcatOp>()) {
+ result.push_back(concat.getLength());
+ return true;
+ }
+ if (auto setLen = expr.getDefiningOp<hlfir::SetLengthOp>()) {
+ result.push_back(setLen.getLength());
+ return true;
+ }
+ if (auto elemental = expr.getDefiningOp<hlfir::ElementalOp>()) {
+ result.append(elemental.getTypeparams().begin(),
+ elemental.getTypeparams().end());
+ return true;
+ }
+ if (auto evalInMem = expr.getDefiningOp<hlfir::EvaluateInMemoryOp>()) {
+ result.append(evalInMem.getTypeparams().begin(),
+ evalInMem.getTypeparams().end());
+ return true;
+ }
+ if (auto apply = expr.getDefiningOp<hlfir::ApplyOp>()) {
+ result.append(apply.getTypeparams().begin(), apply.getTypeparams().end());
+ return true;
+ }
+ return false;
+}
+
void hlfir::genLengthParameters(mlir::Location loc, fir::FirOpBuilder &builder,
Entity entity,
llvm::SmallVectorImpl<mlir::Value> &result) {
@@ -688,29 +716,14 @@ void hlfir::genLengthParameters(mlir::Location loc, fir::FirOpBuilder &builder,
// Going through fir::ExtendedValue would create a temp,
// which is not desired for an inquiry.
// TODO: make this an interface when adding further character producing ops.
- if (auto concat = expr.getDefiningOp<hlfir::ConcatOp>()) {
- result.push_back(concat.getLength());
- return;
- } else if (auto concat = expr.getDefiningOp<hlfir::SetLengthOp>()) {
- result.push_back(concat.getLength());
- return;
- } else if (auto asExpr = expr.getDefiningOp<hlfir::AsExprOp>()) {
+
+ if (auto asExpr = expr.getDefiningOp<hlfir::AsExprOp>()) {
hlfir::genLengthParameters(loc, builder, hlfir::Entity{asExpr.getVar()},
result);
return;
- } else if (auto elemental = expr.getDefiningOp<hlfir::ElementalOp>()) {
- result.append(elemental.getTypeparams().begin(),
- elemental.getTypeparams().end());
- return;
- } else if (auto evalInMem =
- expr.getDefiningOp<hlfir::EvaluateInMemoryOp>()) {
- result.append(evalInMem.getTypeparams().begin(),
- evalInMem.getTypeparams().end());
- return;
- } else if (auto apply = expr.getDefiningOp<hlfir::ApplyOp>()) {
- result.append(apply.getTypeparams().begin(), apply.getTypeparams().end());
- return;
}
+ if (getExprLengthParameters(expr, result))
+ return;
if (entity.isCharacter()) {
result.push_back(hlfir::GetLengthOp::create(builder, loc, expr));
return;
@@ -733,6 +746,36 @@ mlir::Value hlfir::genCharLength(mlir::Location loc, fir::FirOpBuilder &builder,
return lenParams[0];
}
+std::optional<std::int64_t> hlfir::getCharLengthIfConst(hlfir::Entity entity) {
+ if (!entity.isCharacter()) {
+ return std::nullopt;
+ }
+ if (mlir::isa<hlfir::ExprType>(entity.getType())) {
+ mlir::Value expr = entity;
+ if (auto reassoc = expr.getDefiningOp<hlfir::NoReassocOp>())
+ expr = reassoc.getVal();
+
+ if (auto asExpr = expr.getDefiningOp<hlfir::AsExprOp>())
+ return getCharLengthIfConst(hlfir::Entity{asExpr.getVar()});
+
+ llvm::SmallVector<mlir::Value> param;
+ if (getExprLengthParameters(expr, param)) {
+ assert(param.size() == 1 && "characters must have one length parameters");
+ return fir::getIntIfConstant(param.pop_back_val());
+ }
+ return std::nullopt;
+ }
+
+ // entity is a var
+ if (mlir::Value len = tryGettingNonDeferredCharLen(entity))
+ return fir::getIntIfConstant(len);
+ auto charType =
+ mlir::cast<fir::CharacterType>(entity.getFortranElementType());
+ if (charType.hasConstantLen())
+ return charType.getLen();
+ return std::nullopt;
+}
+
mlir::Value hlfir::genRank(mlir::Location loc, fir::FirOpBuilder &builder,
hlfir::Entity entity, mlir::Type resultType) {
if (!entity.isAssumedRank())
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index d8e36ea..ce8ebaa 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -2284,6 +2284,212 @@ public:
}
};
+static std::pair<mlir::Value, hlfir::AssociateOp>
+getVariable(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value val) {
+ // If it is an expression - create a variable from it, or forward
+ // the value otherwise.
+ hlfir::AssociateOp associate;
+ if (!mlir::isa<hlfir::ExprType>(val.getType()))
+ return {val, associate};
+ hlfir::Entity entity{val};
+ mlir::NamedAttribute byRefAttr = fir::getAdaptToByRefAttr(builder);
+ associate = hlfir::genAssociateExpr(loc, builder, entity, entity.getType(),
+ "", byRefAttr);
+ return {associate.getBase(), associate};
+}
+
+class IndexOpConversion : public mlir::OpRewritePattern<hlfir::IndexOp> {
+public:
+ using mlir::OpRewritePattern<hlfir::IndexOp>::OpRewritePattern;
+
+ llvm::LogicalResult
+ matchAndRewrite(hlfir::IndexOp op,
+ mlir::PatternRewriter &rewriter) const override {
+ // We simplify only limited cases:
+ // 1) a substring length shall be known at compile time
+ // 2) if a substring length is 0 then replace with 1 for forward search,
+ // or otherwise with the string length + 1 (builder shall const-fold if
+ // lookup direction is known at compile time).
+ // 3) for known string length at compile time, if it is
+ // shorter than substring => replace with zero.
+ // 4) if a substring length is one => inline as simple search loop
+ // 5) for forward search with input strings of kind=1 runtime is faster.
+ // Do not simplify in all the other cases relying on a runtime call.
+
+ fir::FirOpBuilder builder{rewriter, op.getOperation()};
+ const mlir::Location &loc = op->getLoc();
+
+ auto resultTy = op.getType();
+ mlir::Value back = op.getBack();
+ auto substrLenCst =
+ hlfir::getCharLengthIfConst(hlfir::Entity{op.getSubstr()});
+ if (!substrLenCst) {
+ return rewriter.notifyMatchFailure(
+ op, "substring length unknown at compile time");
+ }
+ hlfir::Entity strEntity{op.getStr()};
+ auto i1Ty = builder.getI1Type();
+ auto idxTy = builder.getIndexType();
+ if (*substrLenCst == 0) {
+ mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1);
+ // zero length substring. For back search replace with
+ // strLen+1, or otherwise with 1.
+ mlir::Value strLen = hlfir::genCharLength(loc, builder, strEntity);
+ mlir::Value strEnd = mlir::arith::AddIOp::create(
+ builder, loc, builder.createConvert(loc, idxTy, strLen), oneIdx);
+ if (back)
+ back = builder.createConvert(loc, i1Ty, back);
+ else
+ back = builder.createIntegerConstant(loc, i1Ty, 0);
+ mlir::Value result =
+ mlir::arith::SelectOp::create(builder, loc, back, strEnd, oneIdx);
+
+ rewriter.replaceOp(op, builder.createConvert(loc, resultTy, result));
+ return mlir::success();
+ }
+
+ if (auto strLenCst = hlfir::getCharLengthIfConst(strEntity)) {
+ if (*strLenCst < *substrLenCst) {
+ rewriter.replaceOp(op, builder.createIntegerConstant(loc, resultTy, 0));
+ return mlir::success();
+ }
+ if (*strLenCst == 0) {
+ // both strings have zero length
+ rewriter.replaceOp(op, builder.createIntegerConstant(loc, resultTy, 1));
+ return mlir::success();
+ }
+ }
+ if (*substrLenCst != 1) {
+ return rewriter.notifyMatchFailure(
+ op, "rely on runtime implementation if substring length > 1");
+ }
+ // For forward search and character kind=1 the runtime uses memchr
+ // which well optimized. But it looks like memchr idiom is not recognized
+ // in LLVM yet. On a micro-kernel test with strings of length 40 runtime
+ // had ~2x less execution time vs inlined code. For unknown search direction
+ // at compile time pessimistically assume "forward".
+ std::optional<bool> isBack;
+ if (back) {
+ if (auto backCst = fir::getIntIfConstant(back))
+ isBack = *backCst != 0;
+ } else {
+ isBack = false;
+ }
+ auto charTy = mlir::cast<fir::CharacterType>(
+ hlfir::getFortranElementType(op.getSubstr().getType()));
+ unsigned kind = charTy.getFKind();
+ if (kind == 1 && (!isBack || !*isBack)) {
+ return rewriter.notifyMatchFailure(
+ op, "rely on runtime implementation for character kind 1");
+ }
+
+ // All checks are passed here. Generate single character search loop.
+ auto [strV, strAssociate] = getVariable(builder, loc, op.getStr());
+ auto [substrV, substrAssociate] = getVariable(builder, loc, op.getSubstr());
+ hlfir::Entity str{strV};
+ hlfir::Entity substr{substrV};
+ mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1);
+
+ auto genExtractAndConvertToInt = [&charTy, &idxTy, &oneIdx,
+ kind](mlir::Location loc,
+ fir::FirOpBuilder &builder,
+ hlfir::Entity &charStr,
+ mlir::Value index) {
+ auto bits = builder.getKindMap().getCharacterBitsize(kind);
+ auto intTy = builder.getIntegerType(bits);
+ auto charLen1Ty =
+ fir::CharacterType::getSingleton(builder.getContext(), kind);
+ mlir::Type designatorTy =
+ fir::ReferenceType::get(charLen1Ty, fir::isa_volatile_type(charTy));
+ auto idxAttr = builder.getIntegerAttr(idxTy, 0);
+
+ auto singleChr = hlfir::DesignateOp::create(
+ builder, loc, designatorTy, charStr, /*component=*/{},
+ /*compShape=*/mlir::Value{}, hlfir::DesignateOp::Subscripts{},
+ /*substring=*/mlir::ValueRange{index, index},
+ /*complexPart=*/std::nullopt,
+ /*shape=*/mlir::Value{}, /*typeParams=*/mlir::ValueRange{oneIdx},
+ fir::FortranVariableFlagsAttr{});
+ auto chrVal = fir::LoadOp::create(builder, loc, singleChr);
+ mlir::Value intVal = fir::ExtractValueOp::create(
+ builder, loc, intTy, chrVal, builder.getArrayAttr(idxAttr));
+ return intVal;
+ };
+
+ auto wantChar = genExtractAndConvertToInt(loc, builder, substr, oneIdx);
+
+ // Generate search loop body with the following C equivalent:
+ // idx_t result = 0;
+ // idx_t end = strlen + 1;
+ // char want = substr[0];
+ // for (idx_t idx = 1; idx < end; ++idx) {
+ // if (result == 0) {
+ // idx_t at = back ? end - idx: idx;
+ // result = str[at-1] == want ? at : result;
+ // }
+ // }
+ mlir::Value strLen = hlfir::genCharLength(loc, builder, strEntity);
+ if (!back)
+ back = builder.createIntegerConstant(loc, i1Ty, 0);
+ else
+ back = builder.createConvert(loc, i1Ty, back);
+ mlir::Value strEnd = mlir::arith::AddIOp::create(
+ builder, loc, builder.createConvert(loc, idxTy, strLen), oneIdx);
+ mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
+ auto genSearchBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+ mlir::ValueRange index,
+ mlir::ValueRange reductionArgs)
+ -> llvm::SmallVector<mlir::Value, 1> {
+ assert(index.size() == 1 && "expected single loop");
+ assert(reductionArgs.size() == 1 && "expected single reduction value");
+ mlir::Value inRes = reductionArgs[0];
+ auto resEQzero = mlir::arith::CmpIOp::create(
+ builder, loc, mlir::arith::CmpIPredicate::eq, inRes, zeroIdx);
+
+ mlir::Value res =
+ builder
+ .genIfOp(loc, {idxTy}, resEQzero,
+ /*withElseRegion=*/true)
+ .genThen([&]() {
+ mlir::Value idx = builder.createConvert(loc, idxTy, index[0]);
+ // offset = back ? end - idx : idx;
+ mlir::Value offset = mlir::arith::SelectOp::create(
+ builder, loc, back,
+ mlir::arith::SubIOp::create(builder, loc, strEnd, idx),
+ idx);
+
+ auto haveChar =
+ genExtractAndConvertToInt(loc, builder, str, offset);
+ auto charsEQ = mlir::arith::CmpIOp::create(
+ builder, loc, mlir::arith::CmpIPredicate::eq, haveChar,
+ wantChar);
+ mlir::Value newVal = mlir::arith::SelectOp::create(
+ builder, loc, charsEQ, offset, inRes);
+
+ fir::ResultOp::create(builder, loc, newVal);
+ })
+ .genElse([&]() { fir::ResultOp::create(builder, loc, inRes); })
+ .getResults()[0];
+ return {res};
+ };
+
+ llvm::SmallVector<mlir::Value, 1> loopOut =
+ hlfir::genLoopNestWithReductions(loc, builder, {strLen},
+ /*reductionInits=*/{zeroIdx},
+ genSearchBody,
+ /*isUnordered=*/false);
+ mlir::Value result = builder.createConvert(loc, resultTy, loopOut[0]);
+
+ if (strAssociate)
+ hlfir::EndAssociateOp::create(builder, loc, strAssociate);
+ if (substrAssociate)
+ hlfir::EndAssociateOp::create(builder, loc, substrAssociate);
+
+ rewriter.replaceOp(op, result);
+ return mlir::success();
+ }
+};
+
template <typename Op>
class MatmulConversion : public mlir::OpRewritePattern<Op> {
public:
@@ -2955,6 +3161,7 @@ public:
patterns.insert<ArrayShiftConversion<hlfir::CShiftOp>>(context);
patterns.insert<ArrayShiftConversion<hlfir::EOShiftOp>>(context);
patterns.insert<CmpCharOpConversion>(context);
+ patterns.insert<IndexOpConversion>(context);
patterns.insert<MatmulConversion<hlfir::MatmulTransposeOp>>(context);
patterns.insert<ReductionConversion<hlfir::CountOp>>(context);
patterns.insert<ReductionConversion<hlfir::AnyOp>>(context);
diff --git a/flang/lib/Optimizer/Transforms/AddDebugInfo.cpp b/flang/lib/Optimizer/Transforms/AddDebugInfo.cpp
index bdf7e4a..e006d2e 100644
--- a/flang/lib/Optimizer/Transforms/AddDebugInfo.cpp
+++ b/flang/lib/Optimizer/Transforms/AddDebugInfo.cpp
@@ -285,11 +285,16 @@ mlir::LLVM::DIModuleAttr AddDebugInfoPass::getOrCreateModuleAttr(
if (auto iter{moduleMap.find(name)}; iter != moduleMap.end()) {
modAttr = iter->getValue();
} else {
+ // When decl is true, it means that module is only being used in this
+ // compilation unit and it is defined elsewhere. But if the file/line/scope
+ // fields are valid, the module is not merged with its definition and is
+ // considered different. So we only set those fields when decl is false.
modAttr = mlir::LLVM::DIModuleAttr::get(
- context, fileAttr, scope, mlir::StringAttr::get(context, name),
+ context, decl ? nullptr : fileAttr, decl ? nullptr : scope,
+ mlir::StringAttr::get(context, name),
/* configMacros */ mlir::StringAttr(),
/* includePath */ mlir::StringAttr(),
- /* apinotes */ mlir::StringAttr(), line, decl);
+ /* apinotes */ mlir::StringAttr(), decl ? 0 : line, decl);
moduleMap[name] = modAttr;
}
return modAttr;