//===-- HlfirIntrinsics.cpp -----------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ // //===----------------------------------------------------------------------===// #include "flang/Lower/HlfirIntrinsics.h" #include "flang/Optimizer/Builder/BoxValue.h" #include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Optimizer/Builder/HLFIRTools.h" #include "flang/Optimizer/Builder/IntrinsicCall.h" #include "flang/Optimizer/Builder/MutableBox.h" #include "flang/Optimizer/Builder/Todo.h" #include "flang/Optimizer/Dialect/FIRType.h" #include "flang/Optimizer/HLFIR/HLFIRDialect.h" #include "flang/Optimizer/HLFIR/HLFIROps.h" #include "mlir/IR/Value.h" #include "llvm/ADT/SmallVector.h" #include namespace { class HlfirTransformationalIntrinsic { public: explicit HlfirTransformationalIntrinsic(fir::FirOpBuilder &builder, mlir::Location loc) : builder(builder), loc(loc) {} virtual ~HlfirTransformationalIntrinsic() = default; hlfir::EntityWithAttributes lower(const Fortran::lower::PreparedActualArguments &loweredActuals, const fir::IntrinsicArgumentLoweringRules *argLowering, mlir::Type stmtResultType) { mlir::Value res = lowerImpl(loweredActuals, argLowering, stmtResultType); for (const hlfir::CleanupFunction &fn : cleanupFns) fn(); return {hlfir::EntityWithAttributes{res}}; } protected: fir::FirOpBuilder &builder; mlir::Location loc; llvm::SmallVector cleanupFns; virtual mlir::Value lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals, const fir::IntrinsicArgumentLoweringRules *argLowering, mlir::Type stmtResultType) = 0; llvm::SmallVector getOperandVector( const Fortran::lower::PreparedActualArguments &loweredActuals, const fir::IntrinsicArgumentLoweringRules *argLowering); mlir::Type computeResultType(mlir::Value argArray, mlir::Type stmtResultType); template inline OP createOp(BUILD_ARGS... args) { return builder.create(loc, args...); } mlir::Value loadBoxAddress( const std::optional &arg); void addCleanup(std::optional cleanup) { if (cleanup) cleanupFns.emplace_back(std::move(*cleanup)); } }; template class HlfirReductionIntrinsic : public HlfirTransformationalIntrinsic { public: using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic; protected: mlir::Value lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals, const fir::IntrinsicArgumentLoweringRules *argLowering, mlir::Type stmtResultType) override; }; using HlfirSumLowering = HlfirReductionIntrinsic; using HlfirProductLowering = HlfirReductionIntrinsic; using HlfirMaxvalLowering = HlfirReductionIntrinsic; using HlfirMinvalLowering = HlfirReductionIntrinsic; using HlfirAnyLowering = HlfirReductionIntrinsic; using HlfirAllLowering = HlfirReductionIntrinsic; template class HlfirMinMaxLocIntrinsic : public HlfirTransformationalIntrinsic { public: using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic; protected: mlir::Value lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals, const fir::IntrinsicArgumentLoweringRules *argLowering, mlir::Type stmtResultType) override; }; using HlfirMinlocLowering = HlfirMinMaxLocIntrinsic; using HlfirMaxlocLowering = HlfirMinMaxLocIntrinsic; template class HlfirProductIntrinsic : public HlfirTransformationalIntrinsic { public: using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic; protected: mlir::Value lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals, const fir::IntrinsicArgumentLoweringRules *argLowering, mlir::Type stmtResultType) override; }; using HlfirMatmulLowering = HlfirProductIntrinsic; using HlfirDotProductLowering = HlfirProductIntrinsic; class HlfirTransposeLowering : public HlfirTransformationalIntrinsic { public: using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic; protected: mlir::Value lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals, const fir::IntrinsicArgumentLoweringRules *argLowering, mlir::Type stmtResultType) override; }; class HlfirCountLowering : public HlfirTransformationalIntrinsic { public: using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic; protected: mlir::Value lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals, const fir::IntrinsicArgumentLoweringRules *argLowering, mlir::Type stmtResultType) override; }; class HlfirCharExtremumLowering : public HlfirTransformationalIntrinsic { public: HlfirCharExtremumLowering(fir::FirOpBuilder &builder, mlir::Location loc, hlfir::CharExtremumPredicate pred) : HlfirTransformationalIntrinsic(builder, loc), pred{pred} {} protected: mlir::Value lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals, const fir::IntrinsicArgumentLoweringRules *argLowering, mlir::Type stmtResultType) override; protected: hlfir::CharExtremumPredicate pred; }; } // namespace mlir::Value HlfirTransformationalIntrinsic::loadBoxAddress( const std::optional &arg) { if (!arg) return mlir::Value{}; hlfir::Entity actual = arg->getActual(loc, builder); if (!arg->handleDynamicOptional()) { if (actual.isMutableBox()) { // this is a box address type but is not dynamically optional. Just load // the box, assuming it is well formed (!fir.ref> -> // !fir.box<...>) return builder.create(loc, actual.getBase()); } return actual; } auto [exv, cleanup] = hlfir::translateToExtendedValue(loc, builder, actual); addCleanup(cleanup); mlir::Value isPresent = arg->getIsPresent(); // createBox will not do create any invalid memory dereferences if exv is // absent. The created fir.box will not be usable, but the SelectOp below // ensures it won't be. mlir::Value box = builder.createBox(loc, exv); mlir::Type boxType = box.getType(); auto absent = builder.create(loc, boxType); auto boxOrAbsent = builder.create( loc, boxType, isPresent, box, absent); return boxOrAbsent; } static mlir::Value loadOptionalValue( mlir::Location loc, fir::FirOpBuilder &builder, const std::optional &arg, hlfir::Entity actual) { if (!arg->handleDynamicOptional()) return hlfir::loadTrivialScalar(loc, builder, actual); mlir::Value isPresent = arg->getIsPresent(); mlir::Type eleType = hlfir::getFortranElementType(actual.getType()); return builder .genIfOp(loc, {eleType}, isPresent, /*withElseRegion=*/true) .genThen([&]() { assert(actual.isScalar() && fir::isa_trivial(eleType) && "must be a numerical or logical scalar"); hlfir::Entity val = hlfir::loadTrivialScalar(loc, builder, actual); builder.create(loc, val); }) .genElse([&]() { mlir::Value zero = fir::factory::createZeroValue(builder, loc, eleType); builder.create(loc, zero); }) .getResults()[0]; } llvm::SmallVector HlfirTransformationalIntrinsic::getOperandVector( const Fortran::lower::PreparedActualArguments &loweredActuals, const fir::IntrinsicArgumentLoweringRules *argLowering) { llvm::SmallVector operands; operands.reserve(loweredActuals.size()); for (size_t i = 0; i < loweredActuals.size(); ++i) { std::optional arg = loweredActuals[i]; if (!arg) { operands.emplace_back(); continue; } hlfir::Entity actual = arg->getActual(loc, builder); mlir::Value valArg; if (!argLowering) { valArg = hlfir::loadTrivialScalar(loc, builder, actual); } else { fir::ArgLoweringRule argRules = fir::lowerIntrinsicArgumentAs(*argLowering, i); if (argRules.lowerAs == fir::LowerIntrinsicArgAs::Box) valArg = loadBoxAddress(arg); else if (!argRules.handleDynamicOptional && argRules.lowerAs != fir::LowerIntrinsicArgAs::Inquired) valArg = hlfir::derefPointersAndAllocatables(loc, builder, actual); else if (argRules.handleDynamicOptional && argRules.lowerAs == fir::LowerIntrinsicArgAs::Value) valArg = loadOptionalValue(loc, builder, arg, actual); else if (argRules.handleDynamicOptional) TODO(loc, "hlfir transformational intrinsic dynamically optional " "argument without box lowering"); else valArg = actual.getBase(); } operands.emplace_back(valArg); } return operands; } mlir::Type HlfirTransformationalIntrinsic::computeResultType(mlir::Value argArray, mlir::Type stmtResultType) { mlir::Type normalisedResult = hlfir::getFortranElementOrSequenceType(stmtResultType); if (auto array = normalisedResult.dyn_cast()) { hlfir::ExprType::Shape resultShape = hlfir::ExprType::Shape{array.getShape()}; mlir::Type elementType = array.getEleTy(); return hlfir::ExprType::get(builder.getContext(), resultShape, elementType, /*polymorphic=*/false); } else if (auto resCharType = mlir::dyn_cast(stmtResultType)) { normalisedResult = hlfir::ExprType::get( builder.getContext(), hlfir::ExprType::Shape{}, resCharType, false); } return normalisedResult; } template mlir::Value HlfirReductionIntrinsic::lowerImpl( const Fortran::lower::PreparedActualArguments &loweredActuals, const fir::IntrinsicArgumentLoweringRules *argLowering, mlir::Type stmtResultType) { auto operands = getOperandVector(loweredActuals, argLowering); mlir::Value array = operands[0]; mlir::Value dim = operands[1]; // dim, mask can be NULL if these arguments are not given if (dim) dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim}); mlir::Type resultTy = computeResultType(array, stmtResultType); OP op; if constexpr (HAS_MASK) op = createOp(resultTy, array, dim, /*mask=*/operands[2]); else op = createOp(resultTy, array, dim); return op; } template mlir::Value HlfirMinMaxLocIntrinsic::lowerImpl( const Fortran::lower::PreparedActualArguments &loweredActuals, const fir::IntrinsicArgumentLoweringRules *argLowering, mlir::Type stmtResultType) { auto operands = getOperandVector(loweredActuals, argLowering); mlir::Value array = operands[0]; mlir::Value dim = operands[1]; mlir::Value mask = operands[2]; mlir::Value back = operands[4]; // dim, mask and back can be NULL if these arguments are not given. if (dim) dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim}); if (back) back = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{back}); mlir::Type resultTy = computeResultType(array, stmtResultType); return createOp(resultTy, array, dim, mask, back); } template mlir::Value HlfirProductIntrinsic::lowerImpl( const Fortran::lower::PreparedActualArguments &loweredActuals, const fir::IntrinsicArgumentLoweringRules *argLowering, mlir::Type stmtResultType) { auto operands = getOperandVector(loweredActuals, argLowering); mlir::Type resultType = computeResultType(operands[0], stmtResultType); return createOp(resultType, operands[0], operands[1]); } mlir::Value HlfirTransposeLowering::lowerImpl( const Fortran::lower::PreparedActualArguments &loweredActuals, const fir::IntrinsicArgumentLoweringRules *argLowering, mlir::Type stmtResultType) { auto operands = getOperandVector(loweredActuals, argLowering); hlfir::ExprType::Shape resultShape; mlir::Type normalisedResult = hlfir::getFortranElementOrSequenceType(stmtResultType); auto array = normalisedResult.cast(); llvm::ArrayRef arrayShape = array.getShape(); assert(arrayShape.size() == 2 && "arguments to transpose have a rank of 2"); mlir::Type elementType = array.getEleTy(); resultShape.push_back(arrayShape[0]); resultShape.push_back(arrayShape[1]); if (auto resCharType = mlir::dyn_cast(elementType)) if (!resCharType.hasConstantLen()) { // The FunctionRef expression might have imprecise character // type at this point, and we can improve it by propagating // the constant length from the argument. auto argCharType = mlir::dyn_cast( hlfir::getFortranElementType(operands[0].getType())); if (argCharType && argCharType.hasConstantLen()) elementType = fir::CharacterType::get( builder.getContext(), resCharType.getFKind(), argCharType.getLen()); } mlir::Type resultTy = hlfir::ExprType::get(builder.getContext(), resultShape, elementType, fir::isPolymorphicType(stmtResultType)); return createOp(resultTy, operands[0]); } mlir::Value HlfirCountLowering::lowerImpl( const Fortran::lower::PreparedActualArguments &loweredActuals, const fir::IntrinsicArgumentLoweringRules *argLowering, mlir::Type stmtResultType) { auto operands = getOperandVector(loweredActuals, argLowering); mlir::Value array = operands[0]; mlir::Value dim = operands[1]; if (dim) dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim}); mlir::Type resultType = computeResultType(array, stmtResultType); return createOp(resultType, array, dim); } mlir::Value HlfirCharExtremumLowering::lowerImpl( const Fortran::lower::PreparedActualArguments &loweredActuals, const fir::IntrinsicArgumentLoweringRules *argLowering, mlir::Type stmtResultType) { auto operands = getOperandVector(loweredActuals, argLowering); assert(operands.size() >= 2); return createOp(pred, mlir::ValueRange{operands}); } std::optional Fortran::lower::lowerHlfirIntrinsic( fir::FirOpBuilder &builder, mlir::Location loc, const std::string &name, const Fortran::lower::PreparedActualArguments &loweredActuals, const fir::IntrinsicArgumentLoweringRules *argLowering, mlir::Type stmtResultType) { // If the result is of a derived type that may need finalization, // we have to use DestroyOp with 'finalize' attribute for the result // of the intrinsic operation. if (name == "sum") return HlfirSumLowering{builder, loc}.lower(loweredActuals, argLowering, stmtResultType); if (name == "product") return HlfirProductLowering{builder, loc}.lower(loweredActuals, argLowering, stmtResultType); if (name == "any") return HlfirAnyLowering{builder, loc}.lower(loweredActuals, argLowering, stmtResultType); if (name == "all") return HlfirAllLowering{builder, loc}.lower(loweredActuals, argLowering, stmtResultType); if (name == "matmul") return HlfirMatmulLowering{builder, loc}.lower(loweredActuals, argLowering, stmtResultType); if (name == "dot_product") return HlfirDotProductLowering{builder, loc}.lower( loweredActuals, argLowering, stmtResultType); // FIXME: the result may need finalization. if (name == "transpose") return HlfirTransposeLowering{builder, loc}.lower( loweredActuals, argLowering, stmtResultType); if (name == "count") return HlfirCountLowering{builder, loc}.lower(loweredActuals, argLowering, stmtResultType); if (name == "maxval") return HlfirMaxvalLowering{builder, loc}.lower(loweredActuals, argLowering, stmtResultType); if (name == "minval") return HlfirMinvalLowering{builder, loc}.lower(loweredActuals, argLowering, stmtResultType); if (name == "minloc") return HlfirMinlocLowering{builder, loc}.lower(loweredActuals, argLowering, stmtResultType); if (name == "maxloc") return HlfirMaxlocLowering{builder, loc}.lower(loweredActuals, argLowering, stmtResultType); if (mlir::isa(stmtResultType)) { if (name == "min") return HlfirCharExtremumLowering{builder, loc, hlfir::CharExtremumPredicate::min} .lower(loweredActuals, argLowering, stmtResultType); if (name == "max") return HlfirCharExtremumLowering{builder, loc, hlfir::CharExtremumPredicate::max} .lower(loweredActuals, argLowering, stmtResultType); } return std::nullopt; }