//===-- FIRToMemRef.cpp - Convert FIR loads and stores to MemRef ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass lowers FIR dialect memory operations to the MemRef dialect.
// In particular it:
//
//   - Rewrites `fir.alloca` to `memref.alloca`.
//
//   - Rewrites `fir.load` / `fir.store` to `memref.load` / `memref.store`.
//
//   - Allows FIR and MemRef to coexist by introducing `fir.convert` at
//     memory-use sites. Memory operations (`memref.load`, `memref.store`,
//     `memref.reinterpret_cast`, etc.) see MemRef-typed values, while the
//     original FIR-typed values remain available for non-memory uses. For
//     example:
//
//       %fir_ref = ... : !fir.ref<!fir.array<...>>
//       %memref = fir.convert %fir_ref
//                  : !fir.ref<!fir.array<...>> -> memref<...>
//       %val = memref.load %memref[...] : memref<...>
//       fir.call @callee(%fir_ref) : (!fir.ref<!fir.array<...>>) -> ()
//
//     Here the MemRef-typed value is used for `memref.load`, while the
//     original FIR-typed value is preserved for `fir.call`.
//
//   - Computes shapes, strides, and indices as needed for slices and shifts
//     and emits `memref.reinterpret_cast` when dynamic layout is required
//     (TODO: use memref.cast instead).
// //===----------------------------------------------------------------------===// #include "flang/Optimizer/Builder/CUFCommon.h" #include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h" #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIROpsSupport.h" #include "flang/Optimizer/Dialect/FIRType.h" #include "flang/Optimizer/Dialect/Support/FIRContext.h" #include "flang/Optimizer/Dialect/Support/KindMapping.h" #include "flang/Optimizer/Transforms/FIRToMemRefTypeConverter.h" #include "flang/Optimizer/Transforms/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/OpenACC/OpenACC.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Region.h" #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/IR/Verifier.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #define DEBUG_TYPE "fir-to-memref" using namespace mlir; namespace fir { #define GEN_PASS_DEF_FIRTOMEMREF #include "flang/Optimizer/Transforms/Passes.h.inc" static bool isMarshalLike(Operation *op) { auto convert = dyn_cast_if_present(op); if (!convert) return false; bool resIsMemRef = isa(convert.getType()); bool argIsMemRef = isa(convert.getValue().getType()); assert(!(resIsMemRef && argIsMemRef) 
&& "unexpected fir.convert memref -> memref in isMarshalLike"); return resIsMemRef || argIsMemRef; } using MemRefInfo = FailureOr>>; static llvm::cl::opt enableFIRConvertOptimizations( "enable-fir-convert-opts", llvm::cl::desc("enable emilinating redundant fir.convert in FIR-to-MemRef"), llvm::cl::init(false), llvm::cl::Hidden); class FIRToMemRef : public fir::impl::FIRToMemRefBase { public: void runOnOperation() override; private: llvm::SmallSetVector eraseOps; DominanceInfo *domInfo = nullptr; void rewriteAlloca(fir::AllocaOp, PatternRewriter &, FIRToMemRefTypeConverter &); void rewriteLoadOp(fir::LoadOp, PatternRewriter &, FIRToMemRefTypeConverter &); void rewriteStoreOp(fir::StoreOp, PatternRewriter &, FIRToMemRefTypeConverter &); MemRefInfo getMemRefInfo(Value, PatternRewriter &, FIRToMemRefTypeConverter &, Operation *); MemRefInfo convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp, PatternRewriter &, FIRToMemRefTypeConverter &); void replaceFIRMemrefs(Value, Value, PatternRewriter &) const; FailureOr getFIRConvert(Operation *memOp, Operation *memref, PatternRewriter &, FIRToMemRefTypeConverter &); FailureOr> getMemrefIndices(fir::ArrayCoorOp, Operation *, PatternRewriter &, Value, Value) const; bool memrefIsOptional(Operation *) const; Value canonicalizeIndex(Value, PatternRewriter &) const; template void getShapeFrom(OpTy op, SmallVector &shapeVec, SmallVector &shiftVec, SmallVector &sliceVec) const; void populateShapeAndShift(SmallVectorImpl &shapeVec, SmallVectorImpl &shiftVec, fir::ShapeShiftOp shift) const; void populateShift(SmallVectorImpl &vec, fir::ShiftOp shift) const; void populateShape(SmallVectorImpl &vec, fir::ShapeOp shape) const; unsigned getRankFromEmbox(fir::EmboxOp embox) const { auto memrefType = embox.getMemref().getType(); Type unwrappedType = fir::unwrapRefType(memrefType); if (auto seqType = dyn_cast(unwrappedType)) return seqType.getDimension(); return 0; } bool isCompilerGeneratedAlloca(Operation *op) const; void 
copyAttribute(Operation *from, Operation *to, llvm::StringRef name) const; Type getBaseType(Type type, bool complexBaseTypes = false) const; bool memrefIsDeviceData(Operation *memref) const; mlir::Attribute findCudaDataAttr(Value val) const; }; void FIRToMemRef::populateShapeAndShift(SmallVectorImpl &shapeVec, SmallVectorImpl &shiftVec, fir::ShapeShiftOp shift) const { for (mlir::OperandRange::iterator i = shift.getPairs().begin(), endIter = shift.getPairs().end(); i != endIter;) { shiftVec.push_back(*i++); shapeVec.push_back(*i++); } } bool FIRToMemRef::isCompilerGeneratedAlloca(Operation *op) const { if (!isa(op)) llvm_unreachable("expected alloca op"); return !op->getAttr("bindc_name") && !op->getAttr("uniq_name"); } void FIRToMemRef::copyAttribute(Operation *from, Operation *to, llvm::StringRef name) const { if (Attribute value = from->getAttr(name)) to->setAttr(name, value); } Type FIRToMemRef::getBaseType(Type type, bool complexBaseTypes) const { if (fir::isa_fir_type(type)) { type = fir::getFortranElementType(type); } else if (auto memrefTy = dyn_cast(type)) { type = memrefTy.getElementType(); } if (!complexBaseTypes) if (auto complexTy = dyn_cast(type)) type = complexTy.getElementType(); return type; } bool FIRToMemRef::memrefIsDeviceData(Operation *memref) const { if (isa(memref)) return true; return cuf::hasDeviceDataAttr(memref); } mlir::Attribute FIRToMemRef::findCudaDataAttr(Value val) const { Value currentVal = val; llvm::SmallPtrSet visited; while (currentVal) { Operation *defOp = currentVal.getDefiningOp(); if (!defOp || !visited.insert(defOp).second) break; if (cuf::DataAttributeAttr cudaAttr = cuf::getDataAttr(defOp)) return cudaAttr; // TODO: This is a best-effort backward walk; it is easy to miss attributes // as FIR evolves. Long term, it would be preferable if the necessary // information was carried in the type system (or otherwise made available // without relying on a walk-back through defining ops). 
if (auto reboxOp = dyn_cast(defOp)) { currentVal = reboxOp.getBox(); } else if (auto convertOp = dyn_cast(defOp)) { currentVal = convertOp->getOperand(0); } else if (auto emboxOp = dyn_cast(defOp)) { currentVal = emboxOp.getMemref(); } else if (auto boxAddrOp = dyn_cast(defOp)) { currentVal = boxAddrOp.getVal(); } else if (auto declareOp = dyn_cast(defOp)) { currentVal = declareOp.getMemref(); } else { break; } } return nullptr; } void FIRToMemRef::populateShift(SmallVectorImpl &vec, fir::ShiftOp shift) const { vec.append(shift.getOrigins().begin(), shift.getOrigins().end()); } void FIRToMemRef::populateShape(SmallVectorImpl &vec, fir::ShapeOp shape) const { vec.append(shape.getExtents().begin(), shape.getExtents().end()); } template void FIRToMemRef::getShapeFrom(OpTy op, SmallVector &shapeVec, SmallVector &shiftVec, SmallVector &sliceVec) const { if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { Value shapeVal = op.getShape(); if (shapeVal) { Operation *shapeValOp = shapeVal.getDefiningOp(); if (auto shapeOp = dyn_cast(shapeValOp)) { populateShape(shapeVec, shapeOp); } else if (auto shapeShiftOp = dyn_cast(shapeValOp)) { populateShapeAndShift(shapeVec, shiftVec, shapeShiftOp); } else if (auto shiftOp = dyn_cast(shapeValOp)) { populateShift(shiftVec, shiftOp); } } Value sliceVal = op.getSlice(); if (sliceVal) { if (auto sliceOp = sliceVal.getDefiningOp()) { auto triples = sliceOp.getTriples(); sliceVec.append(triples.begin(), triples.end()); } } } } void FIRToMemRef::rewriteAlloca(fir::AllocaOp firAlloca, PatternRewriter &rewriter, FIRToMemRefTypeConverter &typeConverter) { if (!typeConverter.convertibleType(firAlloca.getInType())) return; if (typeConverter.isEmptyArray(firAlloca.getType())) return; rewriter.setInsertionPointAfter(firAlloca); Type type = firAlloca.getType(); MemRefType memrefTy = typeConverter.convertMemrefType(type); Location loc = firAlloca.getLoc(); SmallVector sizes = firAlloca.getOperands(); std::reverse(sizes.begin(), 
sizes.end()); auto alloca = memref::AllocaOp::create(rewriter, loc, memrefTy, sizes); copyAttribute(firAlloca, alloca, firAlloca.getBindcNameAttrName()); copyAttribute(firAlloca, alloca, firAlloca.getUniqNameAttrName()); copyAttribute(firAlloca, alloca, cuf::getDataAttrName()); auto convert = fir::ConvertOp::create(rewriter, loc, type, alloca); rewriter.replaceOp(firAlloca, convert); if (isCompilerGeneratedAlloca(alloca)) { for (Operation *userOp : convert->getUsers()) { if (auto declareOp = dyn_cast(userOp)) { LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: removing declare for compiler temp:\n"; declareOp->dump()); declareOp->replaceAllUsesWith(convert); eraseOps.insert(userOp); } } } } bool FIRToMemRef::memrefIsOptional(Operation *op) const { if (auto declare = dyn_cast(op)) { if (fir::FortranVariableOpInterface(declare).isOptional()) return true; Value operand = declare.getMemref(); Operation *operandOp = operand.getDefiningOp(); if (operandOp && isa(operandOp)) return true; } for (mlir::Value result : op->getResults()) for (mlir::Operation *userOp : result.getUsers()) if (isa(userOp)) return true; // TODO: If `op` is not a `fir.declare`, OPTIONAL information may still be // present on a related `fir.declare` reached by tracing the address/box // through common forwarding ops (e.g. `fir.convert`, `fir.rebox`, // `fir.embox`, `fir.box_addr`), then checking `declare.isOptional()`. Add the // search after FIR improves on it. 
return false; } static Value castTypeToIndexType(Value originalValue, PatternRewriter &rewriter) { if (originalValue.getType().isIndex()) return originalValue; Type indexType = rewriter.getIndexType(); return arith::IndexCastOp::create(rewriter, originalValue.getLoc(), indexType, originalValue); } FailureOr> FIRToMemRef::getMemrefIndices(fir::ArrayCoorOp arrayCoorOp, Operation *memref, PatternRewriter &rewriter, Value converted, Value one) const { IndexType indexTy = rewriter.getIndexType(); SmallVector indices; Location loc = arrayCoorOp->getLoc(); SmallVector shiftVec, shapeVec, sliceVec; int rank = arrayCoorOp.getIndices().size(); getShapeFrom(arrayCoorOp, shapeVec, shiftVec, sliceVec); if (auto embox = dyn_cast_or_null(memref)) { getShapeFrom(embox, shapeVec, shiftVec, sliceVec); rank = getRankFromEmbox(embox); } SmallVector sliceLbs, sliceStrides; for (size_t i = 0; i < sliceVec.size(); i += 3) { sliceLbs.push_back(castTypeToIndexType(sliceVec[i], rewriter)); sliceStrides.push_back(castTypeToIndexType(sliceVec[i + 2], rewriter)); } const bool isShifted = !shiftVec.empty(); const bool isSliced = !sliceVec.empty(); ValueRange idxs = arrayCoorOp.getIndices(); Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); SmallVector filledPositions(rank, false); for (int i = 0; i < rank; ++i) { Value step = isSliced ? sliceStrides[i] : one; Operation *stepOp = step.getDefiningOp(); if (stepOp && mlir::isa_and_nonnull(stepOp)) { Value shift = isShifted ? shiftVec[i] : one; Value sliceLb = isSliced ? 
sliceLbs[i] : shift; Value offset = arith::SubIOp::create(rewriter, loc, sliceLb, shift); indices.push_back(offset); filledPositions[i] = true; } else { indices.push_back(zero); } } int arrayCoorIdx = 0; for (int i = 0; i < rank; ++i) { if (filledPositions[i]) continue; assert((unsigned int)arrayCoorIdx < idxs.size() && "empty dimension should be eliminated\n"); Value index = canonicalizeIndex(idxs[arrayCoorIdx], rewriter); Type cTy = index.getType(); if (!llvm::isa(cTy)) { assert(cTy.isSignlessInteger() && "expected signless integer type"); index = arith::IndexCastOp::create(rewriter, loc, indexTy, index); } Value shift = isShifted ? shiftVec[i] : one; Value stride = isSliced ? sliceStrides[i] : one; Value sliceLb = isSliced ? sliceLbs[i] : shift; Value oneIdx = arith::ConstantIndexOp::create(rewriter, loc, 1); Value indexAdjustment = isSliced ? oneIdx : sliceLb; Value delta = arith::SubIOp::create(rewriter, loc, index, indexAdjustment); Value scaled = arith::MulIOp::create(rewriter, loc, delta, stride); Value offset = arith::SubIOp::create(rewriter, loc, sliceLb, shift); Value finalIndex = arith::AddIOp::create(rewriter, loc, scaled, offset); indices[i] = finalIndex; arrayCoorIdx++; } std::reverse(indices.begin(), indices.end()); return indices; } MemRefInfo FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp, PatternRewriter &rewriter, FIRToMemRefTypeConverter &typeConverter) { IndexType indexTy = rewriter.getIndexType(); Value firMemref = arrayCoorOp.getMemref(); if (!typeConverter.convertibleMemrefType(firMemref.getType())) return failure(); if (typeConverter.isEmptyArray(firMemref.getType())) return failure(); if (auto blockArg = dyn_cast(firMemref)) { Value elemRef = arrayCoorOp.getResult(); rewriter.setInsertionPointAfter(arrayCoorOp); Location loc = arrayCoorOp->getLoc(); Type elemMemrefTy = typeConverter.convertMemrefType(elemRef.getType()); Value converted = fir::ConvertOp::create(rewriter, loc, elemMemrefTy, elemRef); 
SmallVector indices; return std::pair{converted, indices}; } Operation *memref = firMemref.getDefiningOp(); FailureOr converted; if (enableFIRConvertOptimizations && isMarshalLike(memref) && !fir::isa_fir_type(firMemref.getType())) { converted = firMemref; rewriter.setInsertionPoint(arrayCoorOp); } else { Operation *arrayCoorOperation = arrayCoorOp.getOperation(); rewriter.setInsertionPoint(arrayCoorOp); if (memrefIsOptional(memref)) { auto ifOp = arrayCoorOperation->getParentOfType(); if (ifOp) { Operation *condition = ifOp.getCondition().getDefiningOp(); if (condition && isa(condition)) if (condition->getOperand(0) == firMemref) { if (arrayCoorOperation->getParentRegion() == &ifOp.getThenRegion()) rewriter.setInsertionPointToStart( &(ifOp.getThenRegion().front())); else if (arrayCoorOperation->getParentRegion() == &ifOp.getElseRegion()) rewriter.setInsertionPointToStart( &(ifOp.getElseRegion().front())); } } } converted = getFIRConvert(memOp, memref, rewriter, typeConverter); if (failed(converted)) return failure(); rewriter.setInsertionPointAfter(arrayCoorOp); } Location loc = arrayCoorOp->getLoc(); Value one = arith::ConstantIndexOp::create(rewriter, loc, 1); FailureOr> failureOrIndices = getMemrefIndices(arrayCoorOp, memref, rewriter, *converted, one); if (failed(failureOrIndices)) return failure(); SmallVector indices = *failureOrIndices; if (converted == firMemref) return std::pair{*converted, indices}; Value convertedVal = *converted; MemRefType memRefTy = dyn_cast(convertedVal.getType()); bool isRebox = firMemref.getDefiningOp() != nullptr; if (memRefTy.hasStaticShape() && !isRebox) return std::pair{*converted, indices}; unsigned rank = arrayCoorOp.getIndices().size(); if (auto embox = firMemref.getDefiningOp()) rank = getRankFromEmbox(embox); SmallVector sizes; sizes.reserve(rank); SmallVector strides; strides.reserve(rank); SmallVector shapeVec, shiftVec, sliceVec; getShapeFrom(arrayCoorOp, shapeVec, shiftVec, sliceVec); Value box = firMemref; if 
(!isa(firMemref)) { if (auto embox = firMemref.getDefiningOp()) getShapeFrom(embox, shapeVec, shiftVec, sliceVec); else if (auto rebox = firMemref.getDefiningOp()) getShapeFrom(rebox, shapeVec, shiftVec, sliceVec); } if (shapeVec.empty()) { auto boxElementSize = fir::BoxEleSizeOp::create(rewriter, loc, indexTy, box); for (unsigned i = 0; i < rank; ++i) { Value dim = arith::ConstantIndexOp::create(rewriter, loc, rank - i - 1); auto boxDims = fir::BoxDimsOp::create(rewriter, loc, indexTy, indexTy, indexTy, box, dim); Value extent = boxDims->getResult(1); sizes.push_back(castTypeToIndexType(extent, rewriter)); Value byteStride = boxDims->getResult(2); Value div = arith::DivSIOp::create(rewriter, loc, byteStride, boxElementSize); strides.push_back(castTypeToIndexType(div, rewriter)); } } else { Value oneIdx = arith::ConstantIndexOp::create(rewriter, arrayCoorOp->getLoc(), 1); for (unsigned i = rank - 1; i > 0; --i) { Value size = shapeVec[i]; sizes.push_back(castTypeToIndexType(size, rewriter)); Value stride = shapeVec[0]; for (unsigned j = 1; j <= i - 1; ++j) stride = arith::MulIOp::create(rewriter, loc, shapeVec[j], stride); strides.push_back(castTypeToIndexType(stride, rewriter)); } sizes.push_back(castTypeToIndexType(shapeVec[0], rewriter)); strides.push_back(oneIdx); } assert(strides.size() == sizes.size() && sizes.size() == rank); int64_t dynamicOffset = ShapedType::kDynamic; SmallVector dynamicStrides(rank, ShapedType::kDynamic); auto stridedLayout = StridedLayoutAttr::get(convertedVal.getContext(), dynamicOffset, dynamicStrides); SmallVector dynamicShape(rank, ShapedType::kDynamic); memRefTy = MemRefType::get(dynamicShape, memRefTy.getElementType(), stridedLayout); Value offset = arith::ConstantIndexOp::create(rewriter, loc, 0); auto reinterpret = memref::ReinterpretCastOp::create( rewriter, loc, memRefTy, *converted, offset, sizes, strides); Value result = reinterpret->getResult(0); return std::pair{result, indices}; } FailureOr 
FIRToMemRef::getFIRConvert(Operation *memOp, Operation *op, PatternRewriter &rewriter, FIRToMemRefTypeConverter &typeConverter) { if (enableFIRConvertOptimizations && !op->hasOneUse() && !memrefIsOptional(op)) { for (Operation *userOp : op->getUsers()) { if (auto convertOp = dyn_cast(userOp)) { Value converted = convertOp.getResult(); if (!isa(converted.getType())) continue; if (userOp->getParentOp() == memOp->getParentOp() && domInfo->dominates(userOp, memOp)) return converted; } } } assert(op->getNumResults() == 1 && "expecting one result"); Value basePtr = op->getResult(0); MemRefType memrefTy = typeConverter.convertMemrefType(basePtr.getType()); Type baseTy = memrefTy.getElementType(); if (fir::isa_std_type(baseTy) && memrefTy.getRank() == 0) { if (auto convertOp = basePtr.getDefiningOp()) { Value input = convertOp.getOperand(); if (auto alloca = input.getDefiningOp()) { assert(alloca.getType() == memrefTy && "expected same types"); if (isCompilerGeneratedAlloca(alloca)) return alloca.getResult(); } } } const Location loc = op->getLoc(); if (isa(basePtr.getType())) { Operation *baseOp = basePtr.getDefiningOp(); auto boxAddrOp = fir::BoxAddrOp::create(rewriter, loc, basePtr); if (auto cudaAttr = findCudaDataAttr(basePtr)) boxAddrOp->setAttr(cuf::getDataAttrName(), cudaAttr); basePtr = boxAddrOp; memrefTy = typeConverter.convertMemrefType(basePtr.getType()); if (baseOp) { auto sameBaseBoxTypes = [&](Type baseType, Type memrefType) -> bool { Type emboxBaseTy = getBaseType(baseType, true); Type emboxMemrefTy = getBaseType(memrefType, true); return emboxBaseTy == emboxMemrefTy; }; if (auto embox = dyn_cast_or_null(baseOp)) { if (!sameBaseBoxTypes(embox.getType(), embox.getMemref().getType())) { LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: embox base type and memref type are not " "the same, bailing out of conversion\n"); return failure(); } if (embox.getSlice() && embox.getSlice().getDefiningOp()) { Type originalType = embox.getMemref().getType(); basePtr = 
embox.getMemref(); if (typeConverter.convertibleMemrefType(originalType)) { auto convertedMemrefTy = typeConverter.convertMemrefType(originalType); memrefTy = convertedMemrefTy; } else { return failure(); } } } if (auto rebox = dyn_cast(baseOp)) { if (!sameBaseBoxTypes(rebox.getType(), rebox.getBox().getType())) { LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: rebox base type and box type are not the " "same, bailing out of conversion\n"); return failure(); } Type originalType = rebox.getBox().getType(); if (auto boxTy = dyn_cast(originalType)) originalType = boxTy.getElementType(); if (!typeConverter.convertibleMemrefType(originalType)) { return failure(); } else { auto convertedMemrefTy = typeConverter.convertMemrefType(originalType); memrefTy = convertedMemrefTy; } } } } auto convert = fir::ConvertOp::create(rewriter, loc, memrefTy, basePtr); return convert->getResult(0); } Value FIRToMemRef::canonicalizeIndex(Value index, PatternRewriter &rewriter) const { if (auto blockArg = dyn_cast(index)) return index; Operation *op = index.getDefiningOp(); if (auto constant = dyn_cast(op)) { if (!constant.getType().isIndex()) { Value v = arith::ConstantIndexOp::create(rewriter, op->getLoc(), constant.value()); return v; } return constant; } if (auto extsi = dyn_cast(op)) { Value operand = extsi.getOperand(); if (auto indexCast = operand.getDefiningOp()) { Value v = indexCast.getOperand(); return v; } return canonicalizeIndex(operand, rewriter); } if (auto add = dyn_cast(op)) { Value lhs = canonicalizeIndex(add.getLhs(), rewriter); Value rhs = canonicalizeIndex(add.getRhs(), rewriter); if (lhs.getType() == rhs.getType()) return arith::AddIOp::create(rewriter, op->getLoc(), lhs, rhs); } return index; } MemRefInfo FIRToMemRef::getMemRefInfo(Value firMemref, PatternRewriter &rewriter, FIRToMemRefTypeConverter &typeConverter, Operation *memOp) { Operation *memrefOp = firMemref.getDefiningOp(); if (!memrefOp) { if (auto blockArg = dyn_cast(firMemref)) { 
rewriter.setInsertionPoint(memOp); Type memrefTy = typeConverter.convertMemrefType(blockArg.getType()); if (auto mt = dyn_cast(memrefTy)) if (auto inner = llvm::dyn_cast(mt.getElementType())) memrefTy = inner; Value converted = fir::ConvertOp::create(rewriter, blockArg.getLoc(), memrefTy, blockArg); SmallVector indices; return std::pair{converted, indices}; } llvm_unreachable( "FIRToMemRef: expected defining op or block argument for FIR memref"); } if (auto arrayCoorOp = dyn_cast(memrefOp)) { MemRefInfo memrefInfo = convertArrayCoorOp(memOp, arrayCoorOp, rewriter, typeConverter); if (succeeded(memrefInfo)) { for (auto user : memrefOp->getUsers()) { if (!isa(user)) { LLVM_DEBUG( llvm::dbgs() << "FIRToMemRef: array memref used by unsupported op:\n"; firMemref.dump(); user->dump()); return memrefInfo; } } eraseOps.insert(memrefOp); } return memrefInfo; } rewriter.setInsertionPoint(memOp); if (isMarshalLike(memrefOp)) { FailureOr converted = getFIRConvert(memOp, memrefOp, rewriter, typeConverter); if (failed(converted)) { LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: expected FIR memref in convert, bailing " "out:\n"; firMemref.dump()); return failure(); } SmallVector indices; return std::pair{*converted, indices}; } if (auto declareOp = dyn_cast(memrefOp)) { FailureOr converted = getFIRConvert(memOp, declareOp, rewriter, typeConverter); if (failed(converted)) { LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: unable to create convert for scalar " "memref:\n"; firMemref.dump()); return failure(); } SmallVector indices; return std::pair{*converted, indices}; } if (auto coordinateOp = dyn_cast(memrefOp)) { FailureOr converted = getFIRConvert(memOp, coordinateOp, rewriter, typeConverter); if (failed(converted)) { LLVM_DEBUG( llvm::dbgs() << "FIRToMemRef: unable to create convert for derived-type " "memref:\n"; firMemref.dump()); return failure(); } SmallVector indices; return std::pair{*converted, indices}; } if (auto convertOp = dyn_cast(memrefOp)) { Type fromTy = 
convertOp->getOperand(0).getType(); Type toTy = firMemref.getType(); if (isa(fromTy) && isa(toTy)) { FailureOr converted = getFIRConvert(memOp, convertOp, rewriter, typeConverter); if (failed(converted)) { LLVM_DEBUG( llvm::dbgs() << "FIRToMemRef: unable to create convert for conversion " "op:\n"; firMemref.dump()); return failure(); } SmallVector indices; return std::pair{*converted, indices}; } } if (auto boxAddrOp = dyn_cast(memrefOp)) { FailureOr converted = getFIRConvert(memOp, boxAddrOp, rewriter, typeConverter); if (failed(converted)) { LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: unable to create convert for box_addr " "op:\n"; firMemref.dump()); return failure(); } SmallVector indices; return std::pair{*converted, indices}; } if (memrefIsDeviceData(memrefOp)) { FailureOr converted = getFIRConvert(memOp, memrefOp, rewriter, typeConverter); if (failed(converted)) return failure(); SmallVector indices; return std::pair{*converted, indices}; } LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: unable to create convert for memref value:\n"; firMemref.dump()); return failure(); } void FIRToMemRef::replaceFIRMemrefs(Value firMemref, Value converted, PatternRewriter &rewriter) const { Operation *op = firMemref.getDefiningOp(); if (op && (isa(op) || isMarshalLike(op))) return; SmallPtrSet worklist; for (auto user : firMemref.getUsers()) { if (isMarshalLike(user) || isa(user)) continue; if (!domInfo->dominates(converted, user)) continue; if (!(isa(user->getParentOp()) || isa(user->getParentOp()))) worklist.insert(user); } Type ty = firMemref.getType(); for (auto op : worklist) { rewriter.setInsertionPoint(op); Location loc = op->getLoc(); Value replaceConvert = fir::ConvertOp::create(rewriter, loc, ty, converted); op->replaceUsesOfWith(firMemref, replaceConvert); } worklist.clear(); for (auto user : firMemref.getUsers()) { if (isMarshalLike(user) || isa(user)) continue; if (isa(user->getParentOp()) || isa(user->getParentOp())) if (domInfo->dominates(converted, user)) 
worklist.insert(user); } if (worklist.empty()) return; while (!worklist.empty()) { Operation *parentOp = (*worklist.begin())->getParentOp(); Value replaceConvert; SmallVector erase; for (auto op : worklist) { if (op->getParentOp() != parentOp) continue; if (!replaceConvert) { rewriter.setInsertionPoint(parentOp); replaceConvert = fir::ConvertOp::create(rewriter, op->getLoc(), ty, converted); } op->replaceUsesOfWith(firMemref, replaceConvert); erase.push_back(op); } for (auto op : erase) worklist.erase(op); } } void FIRToMemRef::rewriteLoadOp(fir::LoadOp load, PatternRewriter &rewriter, FIRToMemRefTypeConverter &typeConverter) { Value firMemref = load.getMemref(); if (!typeConverter.convertibleType(firMemref.getType())) return; LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: attempting to convert FIR load:\n"; load.dump(); firMemref.dump()); MemRefInfo memrefInfo = getMemRefInfo(firMemref, rewriter, typeConverter, load.getOperation()); if (failed(memrefInfo)) return; Type originalType = load.getResult().getType(); Value converted = memrefInfo->first; SmallVector indices = memrefInfo->second; LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: convert for FIR load created successfully:\n"; converted.dump()); rewriter.setInsertionPointAfter(load); Attribute attr = (load.getOperation())->getAttr("tbaa"); memref::LoadOp loadOp = rewriter.replaceOpWithNewOp(load, converted, indices); if (attr) loadOp.getOperation()->setAttr("tbaa", attr); LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: new memref.load op:\n"; loadOp.dump(); assert(succeeded(verify(loadOp)))); if (isa(originalType)) { Value logicalVal = fir::ConvertOp::create(rewriter, loadOp.getLoc(), originalType, loadOp); loadOp.getResult().replaceAllUsesExcept(logicalVal, logicalVal.getDefiningOp()); } if (!isa(originalType)) replaceFIRMemrefs(firMemref, converted, rewriter); } void FIRToMemRef::rewriteStoreOp(fir::StoreOp store, PatternRewriter &rewriter, FIRToMemRefTypeConverter &typeConverter) { Value firMemref = store.getMemref(); if 
(!typeConverter.convertibleType(firMemref.getType())) return; LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: attempting to convert FIR store:\n"; store.dump(); firMemref.dump()); MemRefInfo memrefInfo = getMemRefInfo(firMemref, rewriter, typeConverter, store.getOperation()); if (failed(memrefInfo)) return; Value converted = memrefInfo->first; SmallVector indices = memrefInfo->second; LLVM_DEBUG( llvm::dbgs() << "FIRToMemRef: convert for FIR store created successfully:\n"; converted.dump()); Value value = store.getValue(); rewriter.setInsertionPointAfter(store); if (isa(value.getType())) { Type convertedType = typeConverter.convertType(value.getType()); value = fir::ConvertOp::create(rewriter, store.getLoc(), convertedType, value); } Attribute attr = (store.getOperation())->getAttr("tbaa"); memref::StoreOp storeOp = rewriter.replaceOpWithNewOp( store, value, converted, indices); if (attr) storeOp.getOperation()->setAttr("tbaa", attr); LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: new memref.store op:\n"; storeOp.dump(); assert(succeeded(verify(storeOp)))); bool isLogicalRef = false; if (fir::ReferenceType refTy = llvm::dyn_cast(firMemref.getType())) isLogicalRef = llvm::isa(refTy.getEleTy()); if (!isLogicalRef) replaceFIRMemrefs(firMemref, converted, rewriter); } void FIRToMemRef::runOnOperation() { LLVM_DEBUG(llvm::dbgs() << "Enter FIRToMemRef()\n"); func::FuncOp op = getOperation(); MLIRContext *context = op.getContext(); ModuleOp mod = op->getParentOfType(); FIRToMemRefTypeConverter typeConverter(mod); typeConverter.setConvertComplexTypes(true); PatternRewriter rewriter(context); domInfo = new DominanceInfo(op); op.walk([&](fir::AllocaOp alloca) { rewriteAlloca(alloca, rewriter, typeConverter); }); op.walk([&](Operation *op) { if (fir::LoadOp loadOp = dyn_cast(op)) rewriteLoadOp(loadOp, rewriter, typeConverter); else if (fir::StoreOp storeOp = dyn_cast(op)) rewriteStoreOp(storeOp, rewriter, typeConverter); }); for (auto eraseOp : eraseOps) rewriter.eraseOp(eraseOp); 
eraseOps.clear(); if (domInfo) delete domInfo; LLVM_DEBUG(llvm::dbgs() << "After FIRToMemRef()\n"; op.dump(); llvm::dbgs() << "Exit FIRToMemRef()\n";); } } // namespace fir