aboutsummaryrefslogtreecommitdiff
path: root/flang/lib/Lower/CUDA.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'flang/lib/Lower/CUDA.cpp')
-rw-r--r--flang/lib/Lower/CUDA.cpp157
1 files changed, 157 insertions, 0 deletions
diff --git a/flang/lib/Lower/CUDA.cpp b/flang/lib/Lower/CUDA.cpp
new file mode 100644
index 0000000..f6d0078
--- /dev/null
+++ b/flang/lib/Lower/CUDA.cpp
@@ -0,0 +1,157 @@
+//===-- CUDA.cpp -- CUDA Fortran specific lowering ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Lower/CUDA.h"
+#include "flang/Lower/AbstractConverter.h"
+#include "flang/Optimizer/Builder/Todo.h"
+
+#define DEBUG_TYPE "flang-lower-cuda"
+
+void Fortran::lower::initializeDeviceComponentAllocator(
+ Fortran::lower::AbstractConverter &converter,
+ const Fortran::semantics::Symbol &sym, const fir::MutableBoxValue &box) {
+ if (const auto *details{
+ sym.GetUltimate()
+ .detailsIf<Fortran::semantics::ObjectEntityDetails>()}) {
+ const Fortran::semantics::DeclTypeSpec *type{details->type()};
+ const Fortran::semantics::DerivedTypeSpec *derived{type ? type->AsDerived()
+ : nullptr};
+ if (derived) {
+ if (!FindCUDADeviceAllocatableUltimateComponent(*derived))
+ return; // No device components.
+
+ fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+ mlir::Location loc = converter.getCurrentLocation();
+
+ mlir::Type baseTy = fir::unwrapRefType(box.getAddr().getType());
+
+ // Only pointer and allocatable needs post allocation initialization
+ // of components descriptors.
+ if (!fir::isAllocatableType(baseTy) && !fir::isPointerType(baseTy))
+ return;
+
+ // Extract the derived type.
+ mlir::Type ty = fir::getDerivedType(baseTy);
+ auto recTy = mlir::dyn_cast<fir::RecordType>(ty);
+ assert(recTy && "expected fir::RecordType");
+
+ if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(baseTy))
+ baseTy = boxTy.getEleTy();
+ baseTy = fir::unwrapRefType(baseTy);
+
+ Fortran::semantics::UltimateComponentIterator components{*derived};
+ mlir::Value loadedBox = fir::LoadOp::create(builder, loc, box.getAddr());
+ mlir::Value addr;
+ if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(baseTy)) {
+ mlir::Type idxTy = builder.getIndexType();
+ mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
+ mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
+ llvm::SmallVector<fir::DoLoopOp> loops;
+ llvm::SmallVector<mlir::Value> indices;
+ llvm::SmallVector<mlir::Value> extents;
+ for (unsigned i = 0; i < seqTy.getDimension(); ++i) {
+ mlir::Value dim = builder.createIntegerConstant(loc, idxTy, i);
+ auto dimInfo = fir::BoxDimsOp::create(builder, loc, idxTy, idxTy,
+ idxTy, loadedBox, dim);
+ mlir::Value lbub = mlir::arith::AddIOp::create(
+ builder, loc, dimInfo.getResult(0), dimInfo.getResult(1));
+ mlir::Value ext =
+ mlir::arith::SubIOp::create(builder, loc, lbub, one);
+ mlir::Value cmp = mlir::arith::CmpIOp::create(
+ builder, loc, mlir::arith::CmpIPredicate::sgt, ext, zero);
+ ext = mlir::arith::SelectOp::create(builder, loc, cmp, ext, zero);
+ extents.push_back(ext);
+
+ auto loop = fir::DoLoopOp::create(
+ builder, loc, dimInfo.getResult(0), dimInfo.getResult(1),
+ dimInfo.getResult(2), /*isUnordered=*/true,
+ /*finalCount=*/false, mlir::ValueRange{});
+ loops.push_back(loop);
+ indices.push_back(loop.getInductionVar());
+ builder.setInsertionPointToStart(loop.getBody());
+ }
+ mlir::Value boxAddr = fir::BoxAddrOp::create(builder, loc, loadedBox);
+ auto shape = fir::ShapeOp::create(builder, loc, extents);
+ addr = fir::ArrayCoorOp::create(
+ builder, loc, fir::ReferenceType::get(recTy), boxAddr, shape,
+ /*slice=*/mlir::Value{}, indices, /*typeparms=*/mlir::ValueRange{});
+ } else {
+ addr = fir::BoxAddrOp::create(builder, loc, loadedBox);
+ }
+ for (const auto &compSym : components) {
+ if (Fortran::semantics::IsDeviceAllocatable(compSym)) {
+ llvm::SmallVector<mlir::Value> coord;
+ mlir::Type fieldTy = gatherDeviceComponentCoordinatesAndType(
+ builder, loc, compSym, recTy, coord);
+ assert(coord.size() == 1 && "expect one coordinate");
+ mlir::Value comp = fir::CoordinateOp::create(
+ builder, loc, builder.getRefType(fieldTy), addr, coord[0]);
+ cuf::DataAttributeAttr dataAttr =
+ Fortran::lower::translateSymbolCUFDataAttribute(
+ builder.getContext(), compSym);
+ cuf::SetAllocatorIndexOp::create(builder, loc, comp, dataAttr);
+ }
+ }
+ }
+ }
+}
+
+mlir::Type Fortran::lower::gatherDeviceComponentCoordinatesAndType(
+ fir::FirOpBuilder &builder, mlir::Location loc,
+ const Fortran::semantics::Symbol &sym, fir::RecordType recTy,
+ llvm::SmallVector<mlir::Value> &coordinates) {
+ unsigned fieldIdx = recTy.getFieldIndex(sym.name().ToString());
+ mlir::Type fieldTy;
+ if (fieldIdx != std::numeric_limits<unsigned>::max()) {
+ // Field found in the base record type.
+ auto fieldName = recTy.getTypeList()[fieldIdx].first;
+ fieldTy = recTy.getTypeList()[fieldIdx].second;
+ mlir::Value fieldIndex = fir::FieldIndexOp::create(
+ builder, loc, fir::FieldType::get(fieldTy.getContext()), fieldName,
+ recTy,
+ /*typeParams=*/mlir::ValueRange{});
+ coordinates.push_back(fieldIndex);
+ } else {
+ // Field not found in base record type, search in potential
+ // record type components.
+ for (auto component : recTy.getTypeList()) {
+ if (auto childRecTy = mlir::dyn_cast<fir::RecordType>(component.second)) {
+ fieldIdx = childRecTy.getFieldIndex(sym.name().ToString());
+ if (fieldIdx != std::numeric_limits<unsigned>::max()) {
+ mlir::Value parentFieldIndex = fir::FieldIndexOp::create(
+ builder, loc, fir::FieldType::get(childRecTy.getContext()),
+ component.first, recTy,
+ /*typeParams=*/mlir::ValueRange{});
+ coordinates.push_back(parentFieldIndex);
+ auto fieldName = childRecTy.getTypeList()[fieldIdx].first;
+ fieldTy = childRecTy.getTypeList()[fieldIdx].second;
+ mlir::Value childFieldIndex = fir::FieldIndexOp::create(
+ builder, loc, fir::FieldType::get(fieldTy.getContext()),
+ fieldName, childRecTy,
+ /*typeParams=*/mlir::ValueRange{});
+ coordinates.push_back(childFieldIndex);
+ break;
+ }
+ }
+ }
+ }
+ if (coordinates.empty())
+ TODO(loc, "device resident component in complex derived-type hierarchy");
+ return fieldTy;
+}
+
+cuf::DataAttributeAttr Fortran::lower::translateSymbolCUFDataAttribute(
+ mlir::MLIRContext *mlirContext, const Fortran::semantics::Symbol &sym) {
+ std::optional<Fortran::common::CUDADataAttr> cudaAttr =
+ Fortran::semantics::GetCUDADataAttr(&sym.GetUltimate());
+ return cuf::getDataAttribute(mlirContext, cudaAttr);
+}