//===-- FIROps.cpp --------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//

// NOTE(review): the template argument lists (`<...>`) in this part of the file
// appear to have been lost (e.g. `llvm::cl::opt`, `mlir::dyn_cast(...)`) —
// restore them from the authoritative copy before building.

#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
#include "flang/Optimizer/Support/Utils.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"

namespace {
#include "flang/Optimizer/Dialect/CanonicalizationPatterns.inc"
} // namespace

/// Command-line switch `-strict-fir-volatile-verifier` (off by default).
static llvm::cl::opt clUseStrictVolatileVerification(
    "strict-fir-volatile-verifier", llvm::cl::init(false),
    llvm::cl::desc(
        "use stricter verifier for FIR operations with volatile types"));

/// Return the current value of the `-strict-fir-volatile-verifier` option.
bool fir::useStrictVolatileVerification() {
  return clUseStrictVolatileVerification;
}

/// Copy onto \p toOp every attribute of \p fromOp whose name starts with the
/// OpenACC dialect namespace. Does nothing when either operation is null.
static void propagateAttributes(mlir::Operation *fromOp,
                                mlir::Operation *toOp) {
  if (!fromOp || !toOp)
    return;

  for (mlir::NamedAttribute attr : fromOp->getAttrs()) {
    // Only attributes prefixed with the OpenACC dialect namespace are copied;
    // an attribute of the same name already on `toOp` is overwritten.
    if (attr.getName().getValue().starts_with(
            mlir::acc::OpenACCDialect::getDialectNamespace()))
      toOp->setAttr(attr.getName(), attr.getValue());
  }
}

/// Return true if a sequence type is of some incomplete size or a record type
/// is malformed or contains an incomplete sequence type. An incomplete sequence
/// type is one with more unknown extents in the type than have been provided
/// via `dynamicExtents`. Sequence types with an unknown rank are incomplete by
/// definition.
///
/// Note the polarity: `true` means the type is NOT acceptable.
/// \p visited holds the names of the record types currently being traversed
/// so that recursive record types terminate.
static bool verifyInType(mlir::Type inType,
                         llvm::SmallVectorImpl &visited,
                         unsigned dynamicExtents = 0) {
  if (auto st = mlir::dyn_cast(inType)) {
    auto shape = st.getShape();
    // An empty shape is rejected (unknown rank, per the comment above).
    if (shape.size() == 0)
      return true;
    for (std::size_t i = 0, end = shape.size(); i < end; ++i) {
      if (shape[i] != fir::SequenceType::getUnknownExtent())
        continue;
      // Each unknown extent consumes one of the provided dynamic extents;
      // running out of them makes the type incomplete. The post-decrement
      // only wraps on the path that returns immediately.
      if (dynamicExtents-- == 0)
        return true;
    }
  } else if (auto rt = mlir::dyn_cast(inType)) {
    // don't recurse if we're already visiting this one
    if (llvm::is_contained(visited, rt.getName()))
      return false;
    // keep track of record types currently being visited
    visited.push_back(rt.getName());
    // Fields are checked with the default `dynamicExtents` of 0.
    for (auto &field : rt.getTypeList())
      if (verifyInType(field.second, visited))
        return true;
    visited.pop_back();
  }
  return false;
}

/// Return true when \p numParams does NOT match what \p inType requires
/// (again, `true` signals an error). The check is done on the type obtained
/// from `fir::unwrapSequenceType`:
///  - with `numParams > 0`: the type bound to `recTy` must report exactly
///    `numParams` via `getNumLenParams()`; the type bound to `chrTy` must have
///    a dynamic length and `numParams` must be 1; any other type is an error;
///  - with `numParams == 0`: only a `chrTy` without a constant length is an
///    error.
static bool verifyTypeParamCount(mlir::Type inType, unsigned numParams) {
  auto ty = fir::unwrapSequenceType(inType);
  if (numParams > 0) {
    if (auto recTy = mlir::dyn_cast(ty))
      return numParams != recTy.getNumLenParams();
    if (auto chrTy = mlir::dyn_cast(ty))
      return !(numParams == 1 && chrTy.hasDynamicLen());
    return true;
  }
  if (auto chrTy = mlir::dyn_cast(ty))
    return !chrTy.hasConstantLen();
  return false;
}

/// Parser shared by Alloca and Allocmem
///
/// operation ::= %res = (`fir.alloca` | `fir.allocmem`) $in_type
///                      ( `(` $typeparams `)` )? ( `,` $shape )?
///                      attr-dict-without-keyword
///
/// \p wrapResultType maps the parsed `in_type` to the op result type and
/// returns a null type when `in_type` may not be allocated.
///
/// NOTE(review): the template parameter/argument lists (`<...>`) in this part
/// of the file appear to have been lost (e.g. `template static`,
/// `llvm::SmallVector operands`) — restore them before building.
template
static mlir::ParseResult parseAllocatableOp(FN wrapResultType,
                                            mlir::OpAsmParser &parser,
                                            mlir::OperationState &result) {
  mlir::Type intype;
  if (parser.parseType(intype))
    return mlir::failure();
  auto &builder = parser.getBuilder();
  result.addAttribute("in_type", mlir::TypeAttr::get(intype));
  // `operands` accumulates the type parameters first, then the shape
  // operands; `typeVec` holds the matching types in the same order.
  llvm::SmallVector operands;
  llvm::SmallVector typeVec;
  bool hasOperands = false;
  std::int32_t typeparamsSize = 0;
  if (!parser.parseOptionalLParen()) {
    // parse the LEN params of the derived type: `(` operands `:` types `)`
    if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) ||
        parser.parseColonTypeList(typeVec) || parser.parseRParen())
      return mlir::failure();
    typeparamsSize = operands.size();
    hasOperands = true;
  }
  std::int32_t shapeSize = 0;
  if (!parser.parseOptionalComma()) {
    // parse size to scale by, vector of n dimensions of type index
    if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None))
      return mlir::failure();
    shapeSize = operands.size() - typeparamsSize;
    auto idxTy = builder.getIndexType();
    // Shape operands carry no explicit type in the syntax; they are all
    // resolved as `index`.
    for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i)
      typeVec.push_back(idxTy);
    hasOperands = true;
  }
  if (hasOperands &&
      parser.resolveOperands(operands, typeVec, parser.getNameLoc(),
                             result.operands))
    return mlir::failure();
  mlir::Type restype = wrapResultType(intype);
  if (!restype) {
    parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype;
    return mlir::failure();
  }
  // Record how `result.operands` splits into {typeparams, shape}.
  result.addAttribute("operandSegmentSizes", builder.getDenseI32ArrayAttr(
                                                 {typeparamsSize, shapeSize}));
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.addTypeToList(restype, result.types))
    return mlir::failure();
  return mlir::success();
}

/// Printer shared by the ops parsed with parseAllocatableOp. Prints the
/// in-type, the optional `(typeparams : types)` group, each shape operand
/// preceded by `, `, and the attribute dictionary without the `in_type` and
/// `operandSegmentSizes` attributes.
template
static void printAllocatableOp(mlir::OpAsmPrinter &p, OP &op) {
  p << ' ' << op.getInType();
  if (!op.getTypeparams().empty()) {
    p << '(' << op.getTypeparams() << " : " << op.getTypeparams().getTypes()
      << ')';
  }
  // print the shape of the allocation (if any); all must be index type
  for (auto sh : op.getShape()) {
    p << ", ";
    p.printOperand(sh);
  }
  p.printOptionalAttrDict(op->getAttrs(), {"in_type", "operandSegmentSizes"});
}

//===----------------------------------------------------------------------===//
// AllocaOp
//===----------------------------------------------------------------------===//

/// Create a legal memory reference as return type
/// Returns a null type when \p intype is rejected by the `isa` check below.
static mlir::Type wrapAllocaResultType(mlir::Type intype) {
  // FIR semantics: memory references to memory references are disallowed
  if (mlir::isa(intype))
    return {};
  return fir::ReferenceType::get(intype);
}

/// Return the element type of this op's result type.
mlir::Type fir::AllocaOp::getAllocatedType() {
  return mlir::cast(getType()).getEleTy();
}

/// Return the `fir::ReferenceType` wrapping \p ty.
mlir::Type fir::AllocaOp::getRefTy(mlir::Type ty) {
  return fir::ReferenceType::get(ty);
}

/// Builder: named allocation, not pinned, no bindc name. \p attributes are
/// added to \p result after the main build.
void fir::AllocaOp::build(mlir::OpBuilder &builder,
                          mlir::OperationState &result, mlir::Type inType,
                          llvm::StringRef uniqName,
                          mlir::ValueRange typeparams, mlir::ValueRange shape,
                          llvm::ArrayRef attributes) {
  auto nameAttr = builder.getStringAttr(uniqName);
  build(builder, result, wrapAllocaResultType(inType), inType, nameAttr, {},
        /*pinned=*/false, typeparams, shape);
  result.addAttributes(attributes);
}

/// Builder: named allocation with an explicit \p pinned flag, no bindc name.
void fir::AllocaOp::build(mlir::OpBuilder &builder,
                          mlir::OperationState &result, mlir::Type inType,
                          llvm::StringRef uniqName, bool pinned,
                          mlir::ValueRange typeparams, mlir::ValueRange shape,
                          llvm::ArrayRef attributes) {
  auto nameAttr = builder.getStringAttr(uniqName);
  build(builder, result, wrapAllocaResultType(inType), inType, nameAttr, {},
        pinned, typeparams, shape);
  result.addAttributes(attributes);
}

/// Builder: unique name and bindc name, not pinned. Empty names are passed on
/// as null attributes rather than empty strings.
void fir::AllocaOp::build(mlir::OpBuilder &builder,
                          mlir::OperationState &result, mlir::Type inType,
                          llvm::StringRef uniqName, llvm::StringRef bindcName,
                          mlir::ValueRange typeparams, mlir::ValueRange shape,
                          llvm::ArrayRef attributes) {
  auto nameAttr =
      uniqName.empty() ? mlir::StringAttr{} : builder.getStringAttr(uniqName);
  auto bindcAttr =
      bindcName.empty() ? mlir::StringAttr{} : builder.getStringAttr(bindcName);
  build(builder, result, wrapAllocaResultType(inType), inType, nameAttr,
        bindcAttr, /*pinned=*/false, typeparams, shape);
  result.addAttributes(attributes);
}

/// Builder: unique name, bindc name and explicit \p pinned flag. Empty names
/// are passed on as null attributes rather than empty strings.
void fir::AllocaOp::build(mlir::OpBuilder &builder,
                          mlir::OperationState &result, mlir::Type inType,
                          llvm::StringRef uniqName, llvm::StringRef bindcName,
                          bool pinned, mlir::ValueRange typeparams,
                          mlir::ValueRange shape,
                          llvm::ArrayRef attributes) {
  auto nameAttr =
      uniqName.empty() ? mlir::StringAttr{} : builder.getStringAttr(uniqName);
  auto bindcAttr =
      bindcName.empty() ? mlir::StringAttr{} : builder.getStringAttr(bindcName);
  build(builder, result, wrapAllocaResultType(inType), inType, nameAttr,
        bindcAttr, pinned, typeparams, shape);
  result.addAttributes(attributes);
}

/// Builder: anonymous allocation, not pinned.
void fir::AllocaOp::build(mlir::OpBuilder &builder,
                          mlir::OperationState &result, mlir::Type inType,
                          mlir::ValueRange typeparams, mlir::ValueRange shape,
                          llvm::ArrayRef attributes) {
  build(builder, result, wrapAllocaResultType(inType), inType, {}, {},
        /*pinned=*/false, typeparams, shape);
  result.addAttributes(attributes);
}

/// Builder: anonymous allocation with an explicit \p pinned flag.
void fir::AllocaOp::build(mlir::OpBuilder &builder,
                          mlir::OperationState &result, mlir::Type inType,
                          bool pinned, mlir::ValueRange typeparams,
                          mlir::ValueRange shape,
                          llvm::ArrayRef attributes) {
  build(builder, result, wrapAllocaResultType(inType), inType, {}, {}, pinned,
        typeparams, shape);
  result.addAttributes(attributes);
}

/// Parse with the shared allocatable-op parser and the alloca result wrapper.
mlir::ParseResult fir::AllocaOp::parse(mlir::OpAsmParser &parser,
                                       mlir::OperationState &result) {
  return parseAllocatableOp(wrapAllocaResultType, parser, result);
}

/// Print with the shared allocatable-op printer.
void fir::AllocaOp::print(mlir::OpAsmPrinter &p) {
  printAllocatableOp(p, *this);
}

/// Verify the in-type against the number of shape operands and LEN
/// parameters, and that the result type passes the `isa` check below.
llvm::LogicalResult fir::AllocaOp::verify() {
  llvm::SmallVector visited;
  // Both helpers return true on error.
  if (verifyInType(getInType(), visited, numShapeOperands()))
    return emitOpError("invalid type for allocation");
  if (verifyTypeParamCount(getInType(), numLenParams()))
    return emitOpError("LEN params do not correspond to type");
  mlir::Type outType = getType();
  if (!mlir::isa(outType))
    return emitOpError("must be a !fir.ref type");
  return mlir::success();
}

/// Return true when \p op has either of the two traits tested below or
/// satisfies the `isa` check; getOwnerRegion() uses this to stop its walk.
bool fir::AllocaOp::ownsNestedAlloca(mlir::Operation *op) {
  return op->hasTrait() || op->hasTrait() || mlir::isa(*op);
}

/// Walk up the parent operations of this alloca and return the region (of the
/// first parent for which ownsNestedAlloca() holds) that contains the alloca.
/// Returns nullptr when an unregistered parent is met first or when no such
/// parent exists.
mlir::Region *fir::AllocaOp::getOwnerRegion() {
  mlir::Operation *currentOp = getOperation();
  while (mlir::Operation *parentOp = currentOp->getParentOp()) {
    // If the operation was not registered, inquiries about its traits will be
    // incorrect and it is not possible to reason about the operation. This
    // should not happen in a normal Fortran compilation flow, but be foolproof.
    if (!parentOp->isRegistered())
      return nullptr;
    if (fir::AllocaOp::ownsNestedAlloca(parentOp))
      return currentOp->getParentRegion();
    currentOp = parentOp;
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AllocMemOp
//===----------------------------------------------------------------------===//

/// Create a legal heap reference as return type
/// Returns a null type when \p intype is rejected by the `isa` check below.
static mlir::Type wrapAllocMemResultType(mlir::Type intype) {
  // Fortran semantics: C852 an entity cannot be both ALLOCATABLE and POINTER
  // 8.5.3 note 1 prohibits ALLOCATABLE procedures as well
  // FIR semantics: one may not allocate a memory reference value
  if (mlir::isa(intype))
    return {};
  return fir::HeapType::get(intype);
}

/// Return the element type of this op's result type.
mlir::Type fir::AllocMemOp::getAllocatedType() {
  return mlir::cast(getType()).getEleTy();
}

/// Return the `fir::HeapType` wrapping \p ty.
mlir::Type fir::AllocMemOp::getRefTy(mlir::Type ty) {
  return fir::HeapType::get(ty);
}

/// Builder: named allocation without a bindc name. \p attributes are added to
/// \p result after the main build.
void fir::AllocMemOp::build(mlir::OpBuilder &builder,
                            mlir::OperationState &result, mlir::Type inType,
                            llvm::StringRef uniqName,
                            mlir::ValueRange typeparams, mlir::ValueRange shape,
                            llvm::ArrayRef attributes) {
  auto nameAttr = builder.getStringAttr(uniqName);
  build(builder, result, wrapAllocMemResultType(inType), inType, nameAttr, {},
        typeparams, shape);
  result.addAttributes(attributes);
}

/// Builder: unique name and bindc name.
// NOTE(review): unlike the fir::AllocaOp builders taking (uniqName,
// bindcName), empty names are not mapped to null attributes here — confirm
// that attaching empty string attributes is intended.
void fir::AllocMemOp::build(mlir::OpBuilder &builder,
                            mlir::OperationState &result, mlir::Type inType,
                            llvm::StringRef uniqName, llvm::StringRef bindcName,
                            mlir::ValueRange typeparams, mlir::ValueRange shape,
                            llvm::ArrayRef attributes) {
  auto nameAttr = builder.getStringAttr(uniqName);
  auto bindcAttr = builder.getStringAttr(bindcName);
  build(builder, result, wrapAllocMemResultType(inType), inType, nameAttr,
        bindcAttr, typeparams, shape);
  result.addAttributes(attributes);
}

/// Builder: anonymous allocation.
void fir::AllocMemOp::build(mlir::OpBuilder &builder,
                            mlir::OperationState &result, mlir::Type inType,
                            mlir::ValueRange typeparams, mlir::ValueRange shape,
                            llvm::ArrayRef attributes) {
  build(builder, result, wrapAllocMemResultType(inType), inType, {}, {},
        typeparams, shape);
  result.addAttributes(attributes);
}

/// Parse with the shared allocatable-op parser and the allocmem result
/// wrapper.
mlir::ParseResult fir::AllocMemOp::parse(mlir::OpAsmParser &parser,
                                         mlir::OperationState &result) {
  return parseAllocatableOp(wrapAllocMemResultType, parser, result);
}

/// Print with the shared allocatable-op printer.
void fir::AllocMemOp::print(mlir::OpAsmPrinter &p) {
  printAllocatableOp(p, *this);
}

/// Verify the in-type against the number of shape operands and LEN
/// parameters, the kind of the result type, and that the allocated element
/// type is not a box of unknown size.
llvm::LogicalResult fir::AllocMemOp::verify() {
  llvm::SmallVector visited;
  // Both helpers return true on error.
  if (verifyInType(getInType(), visited, numShapeOperands()))
    return emitOpError("invalid type for allocation");
  if (verifyTypeParamCount(getInType(), numLenParams()))
    return emitOpError("LEN params do not correspond to type");
  mlir::Type outType = getType();
  if (!mlir::dyn_cast(outType))
    return emitOpError("must be a !fir.heap type");
  if (fir::isa_unknown_size_box(fir::dyn_cast_ptrEleTy(outType)))
    return emitOpError("cannot allocate !fir.box of unknown rank or type");
  return mlir::success();
}

//===----------------------------------------------------------------------===//
// ArrayCoorOp
//===----------------------------------------------------------------------===//

// CHARACTERs and derived types with LEN PARAMETERs are dependent types that
// require runtime values to fully define the type of an object.
static bool validTypeParams(mlir::Type dynTy, mlir::ValueRange typeParams, bool allowParamsForBox = false) { dynTy = fir::unwrapAllRefAndSeqType(dynTy); if (mlir::isa(dynTy)) { // A box value will contain type parameter values itself. if (!allowParamsForBox) return typeParams.size() == 0; // A boxed value may have no length parameters, when the lengths // are assumed. If dynamic lengths are used, then proceed // to the verification below. if (typeParams.size() == 0) return true; dynTy = fir::getFortranElementType(dynTy); } // Derived type must have all type parameters satisfied. if (auto recTy = mlir::dyn_cast(dynTy)) return typeParams.size() == recTy.getNumLenParams(); // Characters with non-constant LEN must have a type parameter value. if (auto charTy = mlir::dyn_cast(dynTy)) if (charTy.hasDynamicLen()) return typeParams.size() == 1; // Otherwise, any type parameters are invalid. return typeParams.size() == 0; } llvm::LogicalResult fir::ArrayCoorOp::verify() { auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(getMemref().getType()); auto arrTy = mlir::dyn_cast(eleTy); if (!arrTy) return emitOpError("must be a reference to an array"); auto arrDim = arrTy.getDimension(); if (auto shapeOp = getShape()) { auto shapeTy = shapeOp.getType(); unsigned shapeTyRank = 0; if (auto s = mlir::dyn_cast(shapeTy)) { shapeTyRank = s.getRank(); } else if (auto ss = mlir::dyn_cast(shapeTy)) { shapeTyRank = ss.getRank(); } else { auto s = mlir::cast(shapeTy); shapeTyRank = s.getRank(); // TODO: it looks like PreCGRewrite and CodeGen can support // fir.shift with plain array reference, so we may consider // removing this check. if (!mlir::isa(getMemref().getType())) return emitOpError("shift can only be provided with fir.box memref"); } if (arrDim && arrDim != shapeTyRank) return emitOpError("rank of dimension mismatched"); // TODO: support slicing with changing the number of dimensions, // e.g. 
when array_coor represents an element access to array(:,1,:) // slice: the shape is 3D and the number of indices is 2 in this case. if (shapeTyRank != getIndices().size()) return emitOpError("number of indices do not match dim rank"); } if (auto sliceOp = getSlice()) { if (auto sl = mlir::dyn_cast_or_null(sliceOp.getDefiningOp())) if (!sl.getSubstr().empty()) return emitOpError("array_coor cannot take a slice with substring"); if (auto sliceTy = mlir::dyn_cast(sliceOp.getType())) if (sliceTy.getRank() != arrDim) return emitOpError("rank of dimension in slice mismatched"); } if (!validTypeParams(getMemref().getType(), getTypeparams())) return emitOpError("invalid type parameters"); return mlir::success(); } // Pull in fir.embox and fir.rebox into fir.array_coor when possible. struct SimplifyArrayCoorOp : public mlir::OpRewritePattern { using mlir::OpRewritePattern::OpRewritePattern; llvm::LogicalResult matchAndRewrite(fir::ArrayCoorOp op, mlir::PatternRewriter &rewriter) const override { mlir::Value memref = op.getMemref(); if (!mlir::isa(memref.getType())) return mlir::failure(); mlir::Value boxedMemref, boxedShape, boxedSlice; if (auto emboxOp = mlir::dyn_cast_or_null(memref.getDefiningOp())) { boxedMemref = emboxOp.getMemref(); boxedShape = emboxOp.getShape(); boxedSlice = emboxOp.getSlice(); // If any of operands, that are not currently supported for migration // to ArrayCoorOp, is present, don't rewrite. if (!emboxOp.getTypeparams().empty() || emboxOp.getSourceBox() || emboxOp.getAccessMap()) return mlir::failure(); } else if (auto reboxOp = mlir::dyn_cast_or_null( memref.getDefiningOp())) { boxedMemref = reboxOp.getBox(); boxedShape = reboxOp.getShape(); // Avoid pulling in rebox that performs reshaping. // There is no way to represent box reshaping with array_coor. 
if (boxedShape && !mlir::isa(boxedShape.getType())) return mlir::failure(); boxedSlice = reboxOp.getSlice(); } else { return mlir::failure(); } bool boxedShapeIsShift = boxedShape && mlir::isa(boxedShape.getType()); bool boxedShapeIsShape = boxedShape && mlir::isa(boxedShape.getType()); bool boxedShapeIsShapeShift = boxedShape && mlir::isa(boxedShape.getType()); // Slices changing the number of dimensions are not supported // for array_coor yet. unsigned origBoxRank; if (mlir::isa(boxedMemref.getType())) origBoxRank = fir::getBoxRank(boxedMemref.getType()); else if (auto arrTy = mlir::dyn_cast( fir::unwrapRefType(boxedMemref.getType()))) origBoxRank = arrTy.getDimension(); else return mlir::failure(); if (fir::getBoxRank(memref.getType()) != origBoxRank) return mlir::failure(); // Slices with substring are not supported by array_coor. if (boxedSlice) if (auto sliceOp = mlir::dyn_cast_or_null(boxedSlice.getDefiningOp())) if (!sliceOp.getSubstr().empty()) return mlir::failure(); // If embox/rebox and array_coor have conflicting shapes or slices, // do nothing. if (op.getShape() && boxedShape && boxedShape != op.getShape()) return mlir::failure(); if (op.getSlice() && boxedSlice && boxedSlice != op.getSlice()) return mlir::failure(); std::optional shiftedIndices; // The embox/rebox and array_coor either have compatible // shape/slice at this point or shape/slice is null // in one of them but not in the other. // The compatibility means they are equal or both null. 
if (!op.getShape()) { if (boxedShape) { if (op.getSlice()) { if (!boxedSlice) { if (boxedShapeIsShift) { // %0 = fir.rebox %arg(%shift) // %1 = fir.array_coor %0 [%slice] %idx // Both the slice indices and %idx are 1-based, so the rebox // may be pulled in as: // %1 = fir.array_coor %arg [%slice] %idx boxedShape = nullptr; } else if (boxedShapeIsShape) { // %0 = fir.embox %arg(%shape) // %1 = fir.array_coor %0 [%slice] %idx // Pull in as: // %1 = fir.array_coor %arg(%shape) [%slice] %idx } else if (boxedShapeIsShapeShift) { // %0 = fir.embox %arg(%shapeshift) // %1 = fir.array_coor %0 [%slice] %idx // Pull in as: // %shape = fir.shape // %1 = fir.array_coor %arg(%shape) [%slice] %idx boxedShape = getShapeFromShapeShift(boxedShape, rewriter); if (!boxedShape) return mlir::failure(); } else { return mlir::failure(); } } else { if (boxedShapeIsShift) { // %0 = fir.rebox %arg(%shift) [%slice] // %1 = fir.array_coor %0 [%slice] %idx // This FIR may only be valid if the shape specifies // that all lower bounds are 1s and the slice's start indices // and strides are all 1s. // We could pull in the rebox as: // %1 = fir.array_coor %arg [%slice] %idx // Do not do anything for the time being. return mlir::failure(); } else if (boxedShapeIsShape) { // %0 = fir.embox %arg(%shape) [%slice] // %1 = fir.array_coor %0 [%slice] %idx // This FIR may only be valid if the slice's start indices // and strides are all 1s. // We could pull in the embox as: // %1 = fir.array_coor %arg(%shape) [%slice] %idx return mlir::failure(); } else if (boxedShapeIsShapeShift) { // %0 = fir.embox %arg(%shapeshift) [%slice] // %1 = fir.array_coor %0 [%slice] %idx // This FIR may only be valid if the shape specifies // that all lower bounds are 1s and the slice's start indices // and strides are all 1s. 
// We could pull in the embox as: // %shape = fir.shape // %1 = fir.array_coor %arg(%shape) [%slice] %idx return mlir::failure(); } else { return mlir::failure(); } } } else { // !op.getSlice() if (!boxedSlice) { if (boxedShapeIsShift) { // %0 = fir.rebox %arg(%shift) // %1 = fir.array_coor %0 %idx // Pull in as: // %1 = fir.array_coor %arg %idx boxedShape = nullptr; } else if (boxedShapeIsShape) { // %0 = fir.embox %arg(%shape) // %1 = fir.array_coor %0 %idx // Pull in as: // %1 = fir.array_coor %arg(%shape) %idx } else if (boxedShapeIsShapeShift) { // %0 = fir.embox %arg(%shapeshift) // %1 = fir.array_coor %0 %idx // Pull in as: // %shape = fir.shape // %1 = fir.array_coor %arg(%shape) %idx boxedShape = getShapeFromShapeShift(boxedShape, rewriter); if (!boxedShape) return mlir::failure(); } else { return mlir::failure(); } } else { if (boxedShapeIsShift) { // %0 = fir.embox %arg(%shift) [%slice] // %1 = fir.array_coor %0 %idx // Pull in as: // %tmp = arith.addi %idx, %shift.origin // %idx_shifted = arith.subi %tmp, 1 // %1 = fir.array_coor %arg(%shift) %[slice] %idx_shifted shiftedIndices = getShiftedIndices(boxedShape, op.getIndices(), rewriter); if (!shiftedIndices) return mlir::failure(); } else if (boxedShapeIsShape) { // %0 = fir.embox %arg(%shape) [%slice] // %1 = fir.array_coor %0 %idx // Pull in as: // %1 = fir.array_coor %arg(%shape) %[slice] %idx } else if (boxedShapeIsShapeShift) { // %0 = fir.embox %arg(%shapeshift) [%slice] // %1 = fir.array_coor %0 %idx // Pull in as: // %tmp = arith.addi %idx, %shapeshift.lb // %idx_shifted = arith.subi %tmp, 1 // %1 = fir.array_coor %arg(%shapeshift) %[slice] %idx_shifted shiftedIndices = getShiftedIndices(boxedShape, op.getIndices(), rewriter); if (!shiftedIndices) return mlir::failure(); } else { return mlir::failure(); } } } } else { // !boxedShape if (op.getSlice()) { if (!boxedSlice) { // %0 = fir.rebox %arg // %1 = fir.array_coor %0 [%slice] %idx // Pull in as: // %1 = fir.array_coor %arg [%slice] %idx } 
else { // %0 = fir.rebox %arg [%slice] // %1 = fir.array_coor %0 [%slice] %idx // This is a valid FIR iff the slice's lower bounds // and strides are all 1s. // Pull in as: // %1 = fir.array_coor %arg [%slice] %idx } } else { // !op.getSlice() if (!boxedSlice) { // %0 = fir.rebox %arg // %1 = fir.array_coor %0 %idx // Pull in as: // %1 = fir.array_coor %arg %idx } else { // %0 = fir.rebox %arg [%slice] // %1 = fir.array_coor %0 %idx // Pull in as: // %1 = fir.array_coor %arg [%slice] %idx } } } } else { // op.getShape() if (boxedShape) { // Check if pulling in non-default shape is correct. if (op.getSlice()) { if (!boxedSlice) { // %0 = fir.embox %arg(%shape) // %1 = fir.array_coor %0(%shape) [%slice] %idx // Pull in as: // %1 = fir.array_coor %arg(%shape) [%slice] %idx } else { // %0 = fir.embox %arg(%shape) [%slice] // %1 = fir.array_coor %0(%shape) [%slice] %idx // Pull in as: // %1 = fir.array_coor %arg(%shape) [%slice] %idx } } else { // !op.getSlice() if (!boxedSlice) { // %0 = fir.embox %arg(%shape) // %1 = fir.array_coor %0(%shape) %idx // Pull in as: // %1 = fir.array_coor %arg(%shape) %idx } else { // %0 = fir.embox %arg(%shape) [%slice] // %1 = fir.array_coor %0(%shape) %idx // Pull in as: // %1 = fir.array_coor %arg(%shape) [%slice] %idx } } } else { // !boxedShape if (op.getSlice()) { if (!boxedSlice) { // %0 = fir.rebox %arg // %1 = fir.array_coor %0(%shape) [%slice] %idx // Pull in as: // %1 = fir.array_coor %arg(%shape) [%slice] %idx } else { // %0 = fir.rebox %arg [%slice] // %1 = fir.array_coor %0(%shape) [%slice] %idx return mlir::failure(); } } else { // !op.getSlice() if (!boxedSlice) { // %0 = fir.rebox %arg // %1 = fir.array_coor %0(%shape) %idx // Pull in as: // %1 = fir.array_coor %arg(%shape) %idx } else { // %0 = fir.rebox %arg [%slice] // %1 = fir.array_coor %0(%shape) %idx // Cannot pull in without adjusting the slice indices. 
return mlir::failure(); } } } } // TODO: temporarily avoid producing array_coor with the shape shift // and plain array reference (it seems to be a limitation of // ArrayCoorOp verifier). if (!mlir::isa(boxedMemref.getType())) { if (boxedShape) { if (mlir::isa(boxedShape.getType())) return mlir::failure(); } else if (op.getShape() && mlir::isa(op.getShape().getType())) { return mlir::failure(); } } rewriter.modifyOpInPlace(op, [&]() { op.getMemrefMutable().assign(boxedMemref); if (boxedShape) op.getShapeMutable().assign(boxedShape); if (boxedSlice) op.getSliceMutable().assign(boxedSlice); if (shiftedIndices) op.getIndicesMutable().assign(*shiftedIndices); }); return mlir::success(); } private: using IndicesVectorTy = std::vector; // If v is a shape_shift operation: // fir.shape_shift %l1, %e1, %l2, %e2, ... // create: // fir.shape %e1, %e2, ... static mlir::Value getShapeFromShapeShift(mlir::Value v, mlir::PatternRewriter &rewriter) { auto shapeShiftOp = mlir::dyn_cast_or_null(v.getDefiningOp()); if (!shapeShiftOp) return nullptr; mlir::OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(shapeShiftOp); return fir::ShapeOp::create(rewriter, shapeShiftOp.getLoc(), shapeShiftOp.getExtents()); } static std::optional getShiftedIndices(mlir::Value v, mlir::ValueRange indices, mlir::PatternRewriter &rewriter) { auto insertAdjustments = [&](mlir::Operation *op, mlir::ValueRange lbs) { // Compute the shifted indices using the extended type. // Note that this can probably result in less efficient // MLIR and further LLVM IR due to the extra conversions. 
mlir::OpBuilder::InsertPoint savedIP = rewriter.saveInsertionPoint(); rewriter.setInsertionPoint(op); mlir::Location loc = op->getLoc(); mlir::Type idxTy = rewriter.getIndexType(); mlir::Value one = mlir::arith::ConstantOp::create( rewriter, loc, idxTy, rewriter.getIndexAttr(1)); rewriter.restoreInsertionPoint(savedIP); auto nsw = mlir::arith::IntegerOverflowFlags::nsw; IndicesVectorTy shiftedIndices; for (auto [lb, idx] : llvm::zip(lbs, indices)) { mlir::Value extLb = fir::ConvertOp::create(rewriter, loc, idxTy, lb); mlir::Value extIdx = fir::ConvertOp::create(rewriter, loc, idxTy, idx); mlir::Value add = mlir::arith::AddIOp::create(rewriter, loc, extIdx, extLb, nsw); mlir::Value sub = mlir::arith::SubIOp::create(rewriter, loc, add, one, nsw); shiftedIndices.push_back(sub); } return shiftedIndices; }; if (auto shiftOp = mlir::dyn_cast_or_null(v.getDefiningOp())) { return insertAdjustments(shiftOp.getOperation(), shiftOp.getOrigins()); } else if (auto shapeShiftOp = mlir::dyn_cast_or_null( v.getDefiningOp())) { return insertAdjustments(shapeShiftOp.getOperation(), shapeShiftOp.getOrigins()); } return std::nullopt; } }; void fir::ArrayCoorOp::getCanonicalizationPatterns( mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) { // TODO: !fir.shape<1> operand may be removed from array_coor always. 
patterns.add(context); } //===----------------------------------------------------------------------===// // ArrayLoadOp //===----------------------------------------------------------------------===// static mlir::Type adjustedElementType(mlir::Type t) { if (auto ty = mlir::dyn_cast(t)) { auto eleTy = ty.getEleTy(); if (fir::isa_char(eleTy)) return eleTy; if (fir::isa_derived(eleTy)) return eleTy; if (mlir::isa(eleTy)) return eleTy; } return t; } std::vector fir::ArrayLoadOp::getExtents() { if (auto sh = getShape()) if (auto *op = sh.getDefiningOp()) { if (auto shOp = mlir::dyn_cast(op)) { auto extents = shOp.getExtents(); return {extents.begin(), extents.end()}; } return mlir::cast(op).getExtents(); } return {}; } void fir::ArrayLoadOp::getEffects( llvm::SmallVectorImpl< mlir::SideEffects::EffectInstance> &effects) { effects.emplace_back(mlir::MemoryEffects::Read::get(), &getMemrefMutable(), mlir::SideEffects::DefaultResource::get()); addVolatileMemoryEffects({getMemref().getType()}, effects); } llvm::LogicalResult fir::ArrayLoadOp::verify() { auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(getMemref().getType()); auto arrTy = mlir::dyn_cast(eleTy); if (!arrTy) return emitOpError("must be a reference to an array"); auto arrDim = arrTy.getDimension(); if (auto shapeOp = getShape()) { auto shapeTy = shapeOp.getType(); unsigned shapeTyRank = 0u; if (auto s = mlir::dyn_cast(shapeTy)) { shapeTyRank = s.getRank(); } else if (auto ss = mlir::dyn_cast(shapeTy)) { shapeTyRank = ss.getRank(); } else { auto s = mlir::cast(shapeTy); shapeTyRank = s.getRank(); if (!mlir::isa(getMemref().getType())) return emitOpError("shift can only be provided with fir.box memref"); } if (arrDim && arrDim != shapeTyRank) return emitOpError("rank of dimension mismatched"); } if (auto sliceOp = getSlice()) { if (auto sl = mlir::dyn_cast_or_null(sliceOp.getDefiningOp())) if (!sl.getSubstr().empty()) return emitOpError("array_load cannot take a slice with substring"); if (auto sliceTy = 
mlir::dyn_cast(sliceOp.getType())) if (sliceTy.getRank() != arrDim) return emitOpError("rank of dimension in slice mismatched"); } if (!validTypeParams(getMemref().getType(), getTypeparams())) return emitOpError("invalid type parameters"); return mlir::success(); } //===----------------------------------------------------------------------===// // ArrayMergeStoreOp //===----------------------------------------------------------------------===// llvm::LogicalResult fir::ArrayMergeStoreOp::verify() { if (!mlir::isa(getOriginal().getDefiningOp())) return emitOpError("operand #0 must be result of a fir.array_load op"); if (auto sl = getSlice()) { if (auto sliceOp = mlir::dyn_cast_or_null(sl.getDefiningOp())) { if (!sliceOp.getSubstr().empty()) return emitOpError( "array_merge_store cannot take a slice with substring"); if (!sliceOp.getFields().empty()) { // This is an intra-object merge, where the slice is projecting the // subfields that are to be overwritten by the merge operation. auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(getMemref().getType()); if (auto seqTy = mlir::dyn_cast(eleTy)) { auto projTy = fir::applyPathToType(seqTy.getEleTy(), sliceOp.getFields()); if (fir::unwrapSequenceType(getOriginal().getType()) != projTy) return emitOpError( "type of origin does not match sliced memref type"); if (fir::unwrapSequenceType(getSequence().getType()) != projTy) return emitOpError( "type of sequence does not match sliced memref type"); return mlir::success(); } return emitOpError("referenced type is not an array"); } } return mlir::success(); } auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(getMemref().getType()); if (getOriginal().getType() != eleTy) return emitOpError("type of origin does not match memref element type"); if (getSequence().getType() != eleTy) return emitOpError("type of sequence does not match memref element type"); if (!validTypeParams(getMemref().getType(), getTypeparams())) return emitOpError("invalid type parameters"); return mlir::success(); } void 
fir::ArrayMergeStoreOp::getEffects( llvm::SmallVectorImpl< mlir::SideEffects::EffectInstance> &effects) { effects.emplace_back(mlir::MemoryEffects::Write::get(), &getMemrefMutable(), mlir::SideEffects::DefaultResource::get()); addVolatileMemoryEffects({getMemref().getType()}, effects); } //===----------------------------------------------------------------------===// // ArrayFetchOp //===----------------------------------------------------------------------===// // Template function used for both array_fetch and array_update verification. template mlir::Type validArraySubobject(A op) { auto ty = op.getSequence().getType(); return fir::applyPathToType(ty, op.getIndices()); } llvm::LogicalResult fir::ArrayFetchOp::verify() { auto arrTy = mlir::cast(getSequence().getType()); auto indSize = getIndices().size(); if (indSize < arrTy.getDimension()) return emitOpError("number of indices != dimension of array"); if (indSize == arrTy.getDimension() && ::adjustedElementType(getElement().getType()) != arrTy.getEleTy()) return emitOpError("return type does not match array"); auto ty = validArraySubobject(*this); if (!ty || ty != ::adjustedElementType(getType())) return emitOpError("return type and/or indices do not type check"); if (!mlir::isa(getSequence().getDefiningOp())) return emitOpError("argument #0 must be result of fir.array_load"); if (!validTypeParams(arrTy, getTypeparams())) return emitOpError("invalid type parameters"); return mlir::success(); } //===----------------------------------------------------------------------===// // ArrayAccessOp //===----------------------------------------------------------------------===// llvm::LogicalResult fir::ArrayAccessOp::verify() { auto arrTy = mlir::cast(getSequence().getType()); std::size_t indSize = getIndices().size(); if (indSize < arrTy.getDimension()) return emitOpError("number of indices != dimension of array"); if (indSize == arrTy.getDimension() && getElement().getType() != 
fir::ReferenceType::get(arrTy.getEleTy())) return emitOpError("return type does not match array"); mlir::Type ty = validArraySubobject(*this); if (!ty || fir::ReferenceType::get(ty) != getType()) return emitOpError("return type and/or indices do not type check"); if (!validTypeParams(arrTy, getTypeparams())) return emitOpError("invalid type parameters"); return mlir::success(); } //===----------------------------------------------------------------------===// // ArrayUpdateOp //===----------------------------------------------------------------------===// llvm::LogicalResult fir::ArrayUpdateOp::verify() { if (fir::isa_ref_type(getMerge().getType())) return emitOpError("does not support reference type for merge"); auto arrTy = mlir::cast(getSequence().getType()); auto indSize = getIndices().size(); if (indSize < arrTy.getDimension()) return emitOpError("number of indices != dimension of array"); if (indSize == arrTy.getDimension() && ::adjustedElementType(getMerge().getType()) != arrTy.getEleTy()) return emitOpError("merged value does not have element type"); auto ty = validArraySubobject(*this); if (!ty || ty != ::adjustedElementType(getMerge().getType())) return emitOpError("merged value and/or indices do not type check"); if (!validTypeParams(arrTy, getTypeparams())) return emitOpError("invalid type parameters"); return mlir::success(); } //===----------------------------------------------------------------------===// // ArrayModifyOp //===----------------------------------------------------------------------===// llvm::LogicalResult fir::ArrayModifyOp::verify() { auto arrTy = mlir::cast(getSequence().getType()); auto indSize = getIndices().size(); if (indSize < arrTy.getDimension()) return emitOpError("number of indices must match array dimension"); return mlir::success(); } //===----------------------------------------------------------------------===// // BoxAddrOp //===----------------------------------------------------------------------===// void 
fir::BoxAddrOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
                      mlir::Value val) {
  // Derive the result type from the kind of box being unwrapped:
  //  - base box: its element type if that is already a reference type,
  //    otherwise a fir.ref to it;
  //  - boxchar: a fir.ref to the character element type;
  //  - boxproc: the element type itself.
  mlir::Type type =
      llvm::TypeSwitch<mlir::Type, mlir::Type>(val.getType())
          .Case<fir::BaseBoxType>([&](fir::BaseBoxType ty) -> mlir::Type {
            mlir::Type eleTy = ty.getEleTy();
            if (fir::isa_ref_type(eleTy))
              return eleTy;
            return fir::ReferenceType::get(eleTy);
          })
          .Case<fir::BoxCharType>([&](fir::BoxCharType ty) -> mlir::Type {
            return fir::ReferenceType::get(ty.getEleTy());
          })
          .Case<fir::BoxProcType>(
              [&](fir::BoxProcType ty) { return ty.getEleTy(); })
          .Default([&](const auto &) { return mlir::Type{}; });
  assert(type && "bad val type");
  build(builder, result, type, val);
}

/// Fold box_addr of an embox / emboxchar back to the boxed memref when the
/// memref type is exactly the result type.
mlir::OpFoldResult fir::BoxAddrOp::fold(FoldAdaptor adaptor) {
  if (auto *v = getVal().getDefiningOp()) {
    if (auto box = mlir::dyn_cast<fir::EmboxOp>(v)) {
      // Fold only if not sliced
      if (!box.getSlice() && box.getMemref().getType() == getType()) {
        propagateAttributes(getOperation(), box.getMemref().getDefiningOp());
        return box.getMemref();
      }
    }
    if (auto box = mlir::dyn_cast<fir::EmboxCharOp>(v))
      if (box.getMemref().getType() == getType())
        return box.getMemref();
  }
  return {};
}

//===----------------------------------------------------------------------===//
// BoxCharLenOp
//===----------------------------------------------------------------------===//

/// Fold boxchar_len of an emboxchar to the length operand of that emboxchar.
mlir::OpFoldResult fir::BoxCharLenOp::fold(FoldAdaptor adaptor) {
  if (auto v = getVal().getDefiningOp()) {
    if (auto box = mlir::dyn_cast<fir::EmboxCharOp>(v))
      return box.getLen();
  }
  return {};
}

//===----------------------------------------------------------------------===//
// BoxDimsOp
//===----------------------------------------------------------------------===//

/// Get the three result types packed in a tuple type.
mlir::Type fir::BoxDimsOp::getTupleType() {
  // note: triple, but 4 is nearest power of 2
  llvm::SmallVector<mlir::Type, 4> triple{
      getResult(0).getType(), getResult(1).getType(), getResult(2).getType()};
  return mlir::TupleType::get(getContext(), triple);
}

//===----------------------------------------------------------------------===//
// BoxRankOp
//===----------------------------------------------------------------------===// void fir::BoxRankOp::getEffects( llvm::SmallVectorImpl< mlir::SideEffects::EffectInstance> &effects) { mlir::OpOperand &inputBox = getBoxMutable(); if (fir::isBoxAddress(inputBox.get().getType())) effects.emplace_back(mlir::MemoryEffects::Read::get(), &inputBox, mlir::SideEffects::DefaultResource::get()); } //===----------------------------------------------------------------------===// // CallOp //===----------------------------------------------------------------------===// mlir::FunctionType fir::CallOp::getFunctionType() { return mlir::FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); } void fir::CallOp::print(mlir::OpAsmPrinter &p) { bool isDirect = getCallee().has_value(); p << ' '; if (isDirect) p << *getCallee(); else p << getOperand(0); p << '(' << (*this)->getOperands().drop_front(isDirect ? 0 : 1) << ')'; // Print `proc_attrs<...>`, if present. fir::FortranProcedureFlagsEnumAttr procAttrs = getProcedureAttrsAttr(); if (procAttrs && procAttrs.getValue() != fir::FortranProcedureFlagsEnum::none) { p << ' ' << fir::FortranProcedureFlagsEnumAttr::getMnemonic(); p.printStrippedAttrOrType(procAttrs); } // Print 'fastmath<...>' (if it has non-default value) before // any other attributes. mlir::arith::FastMathFlagsAttr fmfAttr = getFastmathAttr(); if (fmfAttr.getValue() != mlir::arith::FastMathFlags::none) { p << ' ' << mlir::arith::FastMathFlagsAttr::getMnemonic(); p.printStrippedAttrOrType(fmfAttr); } p.printOptionalAttrDict((*this)->getAttrs(), {fir::CallOp::getCalleeAttrNameStr(), getFastmathAttrName(), getProcedureAttrsAttrName(), getArgAttrsAttrName(), getResAttrsAttrName()}); p << " : "; mlir::call_interface_impl::printFunctionSignature( p, getArgs().drop_front(isDirect ? 
0 : 1).getTypes(), getArgAttrsAttr(), /*isVariadic=*/false, getResultTypes(), getResAttrsAttr()); } mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser, mlir::OperationState &result) { llvm::SmallVector operands; if (parser.parseOperandList(operands)) return mlir::failure(); mlir::NamedAttrList attrs; mlir::SymbolRefAttr funcAttr; bool isDirect = operands.empty(); if (isDirect) if (parser.parseAttribute(funcAttr, fir::CallOp::getCalleeAttrNameStr(), attrs)) return mlir::failure(); if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren)) return mlir::failure(); // Parse `proc_attrs<...>`, if present. fir::FortranProcedureFlagsEnumAttr procAttr; if (mlir::succeeded(parser.parseOptionalKeyword( fir::FortranProcedureFlagsEnumAttr::getMnemonic()))) if (parser.parseCustomAttributeWithFallback( procAttr, mlir::Type{}, getProcedureAttrsAttrName(result.name), attrs)) return mlir::failure(); // Parse 'fastmath<...>', if present. mlir::arith::FastMathFlagsAttr fmfAttr; llvm::StringRef fmfAttrName = getFastmathAttrName(result.name); if (mlir::succeeded(parser.parseOptionalKeyword(fmfAttrName))) if (parser.parseCustomAttributeWithFallback(fmfAttr, mlir::Type{}, fmfAttrName, attrs)) return mlir::failure(); if (parser.parseOptionalAttrDict(attrs) || parser.parseColon()) return mlir::failure(); llvm::SmallVector argTypes; llvm::SmallVector resTypes; llvm::SmallVector argAttrs; llvm::SmallVector resultAttrs; if (mlir::call_interface_impl::parseFunctionSignature( parser, argTypes, argAttrs, resTypes, resultAttrs)) return parser.emitError(parser.getNameLoc(), "expected function type"); mlir::FunctionType funcType = mlir::FunctionType::get(parser.getContext(), argTypes, resTypes); if (isDirect) { if (parser.resolveOperands(operands, funcType.getInputs(), parser.getNameLoc(), result.operands)) return mlir::failure(); } else { auto funcArgs = llvm::ArrayRef(operands) .drop_front(); if (parser.resolveOperand(operands[0], funcType, result.operands) || 
parser.resolveOperands(funcArgs, funcType.getInputs(), parser.getNameLoc(), result.operands)) return mlir::failure(); } result.attributes = attrs; mlir::call_interface_impl::addArgAndResultAttrs( parser.getBuilder(), result, argAttrs, resultAttrs, getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); result.addTypes(funcType.getResults()); return mlir::success(); } void fir::CallOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, mlir::func::FuncOp callee, mlir::ValueRange operands) { result.addOperands(operands); result.addAttribute(getCalleeAttrNameStr(), mlir::SymbolRefAttr::get(callee)); result.addTypes(callee.getFunctionType().getResults()); } void fir::CallOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, mlir::SymbolRefAttr callee, llvm::ArrayRef results, mlir::ValueRange operands) { result.addOperands(operands); if (callee) result.addAttribute(getCalleeAttrNameStr(), callee); result.addTypes(results); } //===----------------------------------------------------------------------===// // CharConvertOp //===----------------------------------------------------------------------===// llvm::LogicalResult fir::CharConvertOp::verify() { auto unwrap = [&](mlir::Type t) { t = fir::unwrapSequenceType(fir::dyn_cast_ptrEleTy(t)); return mlir::dyn_cast(t); }; auto inTy = unwrap(getFrom().getType()); auto outTy = unwrap(getTo().getType()); if (!(inTy && outTy)) return emitOpError("not a reference to a character"); if (inTy.getFKind() == outTy.getFKind()) return emitOpError("buffers must have different KIND values"); return mlir::success(); } //===----------------------------------------------------------------------===// // CmpOp //===----------------------------------------------------------------------===// template static void printCmpOp(mlir::OpAsmPrinter &p, OPTY op) { p << ' '; auto predSym = mlir::arith::symbolizeCmpFPredicate( op->template getAttrOfType( OPTY::getPredicateAttrName()) .getInt()); 
assert(predSym.has_value() && "invalid symbol value for predicate"); p << '"' << mlir::arith::stringifyCmpFPredicate(predSym.value()) << '"' << ", "; p.printOperand(op.getLhs()); p << ", "; p.printOperand(op.getRhs()); p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{OPTY::getPredicateAttrName()}); p << " : " << op.getLhs().getType(); } template static mlir::ParseResult parseCmpOp(mlir::OpAsmParser &parser, mlir::OperationState &result) { llvm::SmallVector ops; mlir::NamedAttrList attrs; mlir::Attribute predicateNameAttr; mlir::Type type; if (parser.parseAttribute(predicateNameAttr, OPTY::getPredicateAttrName(), attrs) || parser.parseComma() || parser.parseOperandList(ops, 2) || parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) || parser.resolveOperands(ops, type, result.operands)) return mlir::failure(); if (!mlir::isa(predicateNameAttr)) return parser.emitError(parser.getNameLoc(), "expected string comparison predicate attribute"); // Rewrite string attribute to an enum value. 
llvm::StringRef predicateName = mlir::cast(predicateNameAttr).getValue(); auto predicate = fir::CmpcOp::getPredicateByName(predicateName); auto builder = parser.getBuilder(); mlir::Type i1Type = builder.getI1Type(); attrs.set(OPTY::getPredicateAttrName(), builder.getI64IntegerAttr(static_cast(predicate))); result.attributes = attrs; result.addTypes({i1Type}); return mlir::success(); } //===----------------------------------------------------------------------===// // CmpcOp //===----------------------------------------------------------------------===// void fir::buildCmpCOp(mlir::OpBuilder &builder, mlir::OperationState &result, mlir::arith::CmpFPredicate predicate, mlir::Value lhs, mlir::Value rhs) { result.addOperands({lhs, rhs}); result.types.push_back(builder.getI1Type()); result.addAttribute( fir::CmpcOp::getPredicateAttrName(), builder.getI64IntegerAttr(static_cast(predicate))); } mlir::arith::CmpFPredicate fir::CmpcOp::getPredicateByName(llvm::StringRef name) { auto pred = mlir::arith::symbolizeCmpFPredicate(name); assert(pred.has_value() && "invalid predicate name"); return pred.value(); } void fir::CmpcOp::print(mlir::OpAsmPrinter &p) { printCmpOp(p, *this); } mlir::ParseResult fir::CmpcOp::parse(mlir::OpAsmParser &parser, mlir::OperationState &result) { return parseCmpOp(parser, result); } //===----------------------------------------------------------------------===// // VolatileCastOp //===----------------------------------------------------------------------===// static bool typesMatchExceptForVolatility(mlir::Type fromType, mlir::Type toType) { // If we can change only the volatility and get identical types, then we // match. if (fir::updateTypeWithVolatility(fromType, fir::isa_volatile_type(toType)) == toType) return true; // Otherwise, recurse on the element types if the base classes are the same. 
const bool match = llvm::TypeSwitch(fromType) .Case( [&](auto type) { using TYPE = decltype(type); // If we are not the same base class, then we don't match. auto castedToType = mlir::dyn_cast(toType); if (!castedToType) return false; // If we are the same base class, we match if the element types // match. return typesMatchExceptForVolatility(type.getEleTy(), castedToType.getEleTy()); }) .Default([](mlir::Type) { return false; }); return match; } llvm::LogicalResult fir::VolatileCastOp::verify() { mlir::Type fromType = getValue().getType(); mlir::Type toType = getType(); if (!typesMatchExceptForVolatility(fromType, toType)) return emitOpError("types must be identical except for volatility ") << fromType << " / " << toType; return mlir::success(); } mlir::OpFoldResult fir::VolatileCastOp::fold(FoldAdaptor adaptor) { if (getValue().getType() == getType()) return getValue(); return {}; } //===----------------------------------------------------------------------===// // ConvertOp //===----------------------------------------------------------------------===// void fir::ConvertOp::getCanonicalizationPatterns( mlir::RewritePatternSet &results, mlir::MLIRContext *context) { results.insert( context); } mlir::OpFoldResult fir::ConvertOp::fold(FoldAdaptor adaptor) { if (getValue().getType() == getType()) return getValue(); if (matchPattern(getValue(), mlir::m_Op())) { auto inner = mlir::cast(getValue().getDefiningOp()); // (convert (convert 'a : logical -> i1) : i1 -> logical) ==> forward 'a if (auto toTy = mlir::dyn_cast(getType())) if (auto fromTy = mlir::dyn_cast(inner.getValue().getType())) if (mlir::isa(inner.getType()) && (toTy == fromTy)) return inner.getValue(); // (convert (convert 'a : i1 -> logical) : logical -> i1) ==> forward 'a if (auto toTy = mlir::dyn_cast(getType())) if (auto fromTy = mlir::dyn_cast(inner.getValue().getType())) if (mlir::isa(inner.getType()) && (toTy == fromTy) && (fromTy.getWidth() == 1)) return inner.getValue(); } return {}; } bool 
fir::ConvertOp::isInteger(mlir::Type ty) { return mlir::isa(ty); } bool fir::ConvertOp::isIntegerCompatible(mlir::Type ty) { return isInteger(ty) || mlir::isa(ty); } bool fir::ConvertOp::isFloatCompatible(mlir::Type ty) { return mlir::isa(ty); } bool fir::ConvertOp::isPointerCompatible(mlir::Type ty) { return mlir::isa(ty); } static std::optional getVectorElementType(mlir::Type ty) { mlir::Type elemTy; if (mlir::isa(ty)) elemTy = mlir::dyn_cast(ty).getElementType(); else if (mlir::isa(ty)) elemTy = mlir::dyn_cast(ty).getElementType(); else return std::nullopt; // e.g. fir.vector<4:ui32> => mlir.vector<4xi32> // e.g. mlir.vector<4xui32> => mlir.vector<4xi32> if (elemTy.isUnsignedInteger()) { elemTy = mlir::IntegerType::get( ty.getContext(), mlir::dyn_cast(elemTy).getWidth()); } return elemTy; } static std::optional getVectorLen(mlir::Type ty) { if (mlir::isa(ty)) return mlir::dyn_cast(ty).getLen(); else if (mlir::isa(ty)) { // fir.vector only supports 1-D vector if (!(mlir::dyn_cast(ty).isScalable())) return mlir::dyn_cast(ty).getShape()[0]; } return std::nullopt; } bool fir::ConvertOp::areVectorsCompatible(mlir::Type inTy, mlir::Type outTy) { if (!(mlir::isa(inTy) && mlir::isa(outTy)) && !(mlir::isa(inTy) && mlir::isa(outTy))) return false; // Only support integer, unsigned and real vector // Both vectors must have the same element type std::optional inElemTy = getVectorElementType(inTy); std::optional outElemTy = getVectorElementType(outTy); if (!inElemTy.has_value() || !outElemTy.has_value() || inElemTy.value() != outElemTy.value()) return false; // Both vectors must have the same number of elements std::optional inLen = getVectorLen(inTy); std::optional outLen = getVectorLen(outTy); if (!inLen.has_value() || !outLen.has_value() || inLen.value() != outLen.value()) return false; return true; } static bool areRecordsCompatible(mlir::Type inTy, mlir::Type outTy) { // Both records must have the same field types. 
// Trust frontend semantics for in-depth checks, such as if both records // have the BIND(C) attribute. auto inRecTy = mlir::dyn_cast(inTy); auto outRecTy = mlir::dyn_cast(outTy); return inRecTy && outRecTy && inRecTy.getTypeList() == outRecTy.getTypeList(); } bool fir::ConvertOp::canBeConverted(mlir::Type inType, mlir::Type outType) { if (inType == outType) return true; return (isPointerCompatible(inType) && isPointerCompatible(outType)) || (isIntegerCompatible(inType) && isIntegerCompatible(outType)) || (isInteger(inType) && isFloatCompatible(outType)) || (isFloatCompatible(inType) && isInteger(outType)) || (isFloatCompatible(inType) && isFloatCompatible(outType)) || (isIntegerCompatible(inType) && isPointerCompatible(outType)) || (isPointerCompatible(inType) && isIntegerCompatible(outType)) || (mlir::isa(inType) && mlir::isa(outType)) || (mlir::isa(inType) && mlir::isa(outType)) || (fir::isa_complex(inType) && fir::isa_complex(outType)) || (fir::isBoxedRecordType(inType) && fir::isPolymorphicType(outType)) || (fir::isPolymorphicType(inType) && fir::isPolymorphicType(outType)) || (fir::isPolymorphicType(inType) && mlir::isa(outType)) || areVectorsCompatible(inType, outType) || areRecordsCompatible(inType, outType); } // In general, ptrtoint-like conversions are allowed to lose volatility // information because they are either: // // 1. passing an entity to an external function and there's nothing we can do // about volatility after that happens, or // 2. for code generation, at which point we represent volatility with // attributes on the LLVM instructions and intrinsics. // // For all other cases, volatility ought to match exactly. 
static mlir::LogicalResult verifyVolatility(mlir::Type inType, mlir::Type outType) { const bool toLLVMPointer = mlir::isa(outType); const bool toInteger = fir::isa_integer(outType); // When converting references to classes or allocatables into boxes for // runtime arguments, we cast away all the volatility information and pass a // box. This is allowed. const bool isBoxNoneLike = [&]() { if (fir::isBoxNone(outType)) return true; if (auto referenceType = mlir::dyn_cast(outType)) { if (fir::isBoxNone(referenceType.getElementType())) { return true; } } return false; }(); const bool isPtrToIntLike = toLLVMPointer || toInteger || isBoxNoneLike; if (isPtrToIntLike) { return mlir::success(); } // In all other cases, we need to check for an exact volatility match. return mlir::success(fir::isa_volatile_type(inType) == fir::isa_volatile_type(outType)); } llvm::LogicalResult fir::ConvertOp::verify() { mlir::Type inType = getValue().getType(); mlir::Type outType = getType(); if (fir::useStrictVolatileVerification()) { if (failed(verifyVolatility(inType, outType))) { return emitOpError("this conversion does not preserve volatility: ") << inType << " / " << outType; } } if (canBeConverted(inType, outType)) return mlir::success(); return emitOpError("invalid type conversion") << getValue().getType() << " / " << getType(); } //===----------------------------------------------------------------------===// // CoordinateOp //===----------------------------------------------------------------------===// void fir::CoordinateOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Type resultType, mlir::Value ref, mlir::ValueRange coor) { llvm::SmallVector fieldIndices; llvm::SmallVector dynamicIndices; bool anyField = false; for (mlir::Value index : coor) { if (auto field = index.getDefiningOp()) { auto recTy = mlir::cast(field.getOnType()); fieldIndices.push_back(recTy.getFieldIndex(field.getFieldId())); anyField = true; } else { 
fieldIndices.push_back(fir::CoordinateOp::kDynamicIndex); dynamicIndices.push_back(index); } } auto typeAttr = mlir::TypeAttr::get(ref.getType()); if (anyField) { build(builder, result, resultType, ref, dynamicIndices, typeAttr, builder.getDenseI32ArrayAttr(fieldIndices)); } else { build(builder, result, resultType, ref, dynamicIndices, typeAttr, nullptr); } } void fir::CoordinateOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Type resultType, mlir::Value ref, llvm::ArrayRef coor) { llvm::SmallVector fieldIndices; llvm::SmallVector dynamicIndices; bool anyField = false; for (fir::IntOrValue index : coor) { llvm::TypeSwitch(index) .Case([&](mlir::IntegerAttr intAttr) { fieldIndices.push_back(intAttr.getInt()); anyField = true; }) .Case([&](mlir::Value value) { dynamicIndices.push_back(value); fieldIndices.push_back(fir::CoordinateOp::kDynamicIndex); }); } auto typeAttr = mlir::TypeAttr::get(ref.getType()); if (anyField) { build(builder, result, resultType, ref, dynamicIndices, typeAttr, builder.getDenseI32ArrayAttr(fieldIndices)); } else { build(builder, result, resultType, ref, dynamicIndices, typeAttr, nullptr); } } void fir::CoordinateOp::print(mlir::OpAsmPrinter &p) { p << ' ' << getRef(); if (!getFieldIndicesAttr()) { p << ", " << getCoor(); } else { mlir::Type eleTy = fir::getFortranElementType(getRef().getType()); for (auto index : getIndices()) { p << ", "; llvm::TypeSwitch(index) .Case([&](mlir::IntegerAttr intAttr) { if (auto recordType = llvm::dyn_cast(eleTy)) { int fieldId = intAttr.getInt(); if (fieldId < static_cast(recordType.getNumFields())) { auto nameAndType = recordType.getTypeList()[fieldId]; p << std::get(nameAndType); eleTy = fir::getFortranElementType( std::get(nameAndType)); return; } } // Invalid index, still print it so that invalid IR can be // investigated. 
p << intAttr; }) .Case([&](mlir::Value value) { p << value; }); } } p.printOptionalAttrDict( (*this)->getAttrs(), /*elideAttrs=*/{getBaseTypeAttrName(), getFieldIndicesAttrName()}); p << " : "; p.printFunctionalType(getOperandTypes(), (*this)->getResultTypes()); } mlir::ParseResult fir::CoordinateOp::parse(mlir::OpAsmParser &parser, mlir::OperationState &result) { mlir::OpAsmParser::UnresolvedOperand memref; if (parser.parseOperand(memref) || parser.parseComma()) return mlir::failure(); llvm::SmallVector coorOperands; llvm::SmallVector> fieldNames; llvm::SmallVector fieldIndices; while (true) { llvm::StringRef fieldName; if (mlir::succeeded(parser.parseOptionalKeyword(&fieldName))) { fieldNames.push_back({fieldName, static_cast(fieldIndices.size())}); // Actual value will be computed later when base type has been parsed. fieldIndices.push_back(0); } else { mlir::OpAsmParser::UnresolvedOperand index; if (parser.parseOperand(index)) return mlir::failure(); fieldIndices.push_back(fir::CoordinateOp::kDynamicIndex); coorOperands.push_back(index); } if (mlir::failed(parser.parseOptionalComma())) break; } llvm::SmallVector allOperands; allOperands.push_back(memref); allOperands.append(coorOperands.begin(), coorOperands.end()); mlir::FunctionType funcTy; auto loc = parser.getCurrentLocation(); if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(funcTy) || parser.resolveOperands(allOperands, funcTy.getInputs(), loc, result.operands) || parser.addTypesToList(funcTy.getResults(), result.types)) return mlir::failure(); result.addAttribute(getBaseTypeAttrName(result.name), mlir::TypeAttr::get(funcTy.getInput(0))); if (!fieldNames.empty()) { mlir::Type eleTy = fir::getFortranElementType(funcTy.getInput(0)); for (auto [fieldName, operandPosition] : fieldNames) { auto recTy = llvm::dyn_cast(eleTy); if (!recTy) return parser.emitError( loc, "base must be a derived type when field name appears"); unsigned fieldNum = recTy.getFieldIndex(fieldName); if 
(fieldNum > recTy.getNumFields()) return parser.emitError(loc) << "field '" << fieldName << "' is not a component or subcomponent of the base type"; fieldIndices[operandPosition] = fieldNum; eleTy = fir::getFortranElementType( std::get(recTy.getTypeList()[fieldNum])); } result.addAttribute(getFieldIndicesAttrName(result.name), parser.getBuilder().getDenseI32ArrayAttr(fieldIndices)); } return mlir::success(); } llvm::LogicalResult fir::CoordinateOp::verify() { const mlir::Type refTy = getRef().getType(); if (fir::isa_ref_type(refTy)) { auto eleTy = fir::dyn_cast_ptrEleTy(refTy); if (auto arrTy = mlir::dyn_cast(eleTy)) { if (arrTy.hasUnknownShape()) return emitOpError("cannot find coordinate in unknown shape"); if (arrTy.getConstantRows() < arrTy.getDimension() - 1) return emitOpError("cannot find coordinate with unknown extents"); } if (!(fir::isa_aggregate(eleTy) || fir::isa_complex(eleTy) || fir::isa_char_string(eleTy))) return emitOpError("cannot apply to this element type"); } auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(refTy); unsigned dimension = 0; const unsigned numCoors = getCoor().size(); for (auto coorOperand : llvm::enumerate(getCoor())) { auto co = coorOperand.value(); if (dimension == 0 && mlir::isa(eleTy)) { dimension = mlir::cast(eleTy).getDimension(); if (dimension == 0) return emitOpError("cannot apply to array of unknown rank"); } if (auto *defOp = co.getDefiningOp()) { if (auto index = mlir::dyn_cast(defOp)) { // Recovering a LEN type parameter only makes sense from a boxed // value. For a bare reference, the LEN type parameters must be // passed as additional arguments to `index`. 
if (mlir::isa(refTy)) { if (coorOperand.index() != numCoors - 1) return emitOpError("len_param_index must be last argument"); if (getNumOperands() != 2) return emitOpError("too many operands for len_param_index case"); } if (eleTy != index.getOnType()) emitOpError( "len_param_index type not compatible with reference type"); return mlir::success(); } else if (auto index = mlir::dyn_cast(defOp)) { if (eleTy != index.getOnType()) emitOpError("field_index type not compatible with reference type"); if (auto recTy = mlir::dyn_cast(eleTy)) { eleTy = recTy.getType(index.getFieldName()); continue; } return emitOpError("field_index not applied to !fir.type"); } } if (dimension) { if (--dimension == 0) eleTy = mlir::cast(eleTy).getElementType(); } else { if (auto t = mlir::dyn_cast(eleTy)) { // FIXME: Generally, we don't know which field of the tuple is being // referred to unless the operand is a constant. Just assume everything // is good in the tuple case for now. return mlir::success(); } else if (auto t = mlir::dyn_cast(eleTy)) { // FIXME: This is the same as the tuple case. 
return mlir::success(); } else if (auto t = mlir::dyn_cast(eleTy)) { eleTy = t.getElementType(); } else if (auto t = mlir::dyn_cast(eleTy)) { if (t.getLen() == fir::CharacterType::singleton()) return emitOpError("cannot apply to character singleton"); eleTy = fir::CharacterType::getSingleton(t.getContext(), t.getFKind()); if (fir::unwrapRefType(getType()) != eleTy) return emitOpError("character type mismatch"); } else { return emitOpError("invalid parameters (too many)"); } } } return mlir::success(); } fir::CoordinateIndicesAdaptor fir::CoordinateOp::getIndices() { return CoordinateIndicesAdaptor(getFieldIndicesAttr(), getCoor()); } //===----------------------------------------------------------------------===// // DispatchOp //===----------------------------------------------------------------------===// llvm::LogicalResult fir::DispatchOp::verify() { // Check that pass_arg_pos is in range of actual operands. pass_arg_pos is // unsigned so check for less than zero is not needed. if (getPassArgPos() && *getPassArgPos() > (getArgOperands().size() - 1)) return emitOpError( "pass_arg_pos must be smaller than the number of operands"); // Operand pointed by pass_arg_pos must have polymorphic type. 
if (getPassArgPos() && !fir::isPolymorphicType(getArgOperands()[*getPassArgPos()].getType())) return emitOpError("pass_arg_pos must be a polymorphic operand"); return mlir::success(); } mlir::FunctionType fir::DispatchOp::getFunctionType() { return mlir::FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); } //===----------------------------------------------------------------------===// // TypeInfoOp //===----------------------------------------------------------------------===// void fir::TypeInfoOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, fir::RecordType type, fir::RecordType parentType, llvm::ArrayRef attrs) { result.addRegion(); result.addRegion(); result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(type.getName())); result.addAttribute(getTypeAttrName(result.name), mlir::TypeAttr::get(type)); if (parentType) result.addAttribute(getParentTypeAttrName(result.name), mlir::TypeAttr::get(parentType)); result.addAttributes(attrs); } llvm::LogicalResult fir::TypeInfoOp::verify() { if (!getDispatchTable().empty()) for (auto &op : getDispatchTable().front().without_terminator()) if (!mlir::isa(op)) return op.emitOpError("dispatch table must contain dt_entry"); if (!mlir::isa(getType())) return emitOpError("type must be a fir.type"); if (getParentType() && !mlir::isa(*getParentType())) return emitOpError("parent_type must be a fir.type"); return mlir::success(); } //===----------------------------------------------------------------------===// // EmboxOp //===----------------------------------------------------------------------===// // Conversions from reference types to box types must preserve volatility. 
static llvm::LogicalResult verifyEmboxOpVolatilityInvariants(mlir::Type memrefType, mlir::Type resultType) { if (!fir::useStrictVolatileVerification()) return mlir::success(); mlir::Type boxElementType = llvm::TypeSwitch(resultType) .Case( [&](auto type) { return type.getEleTy(); }) .Default([&](mlir::Type type) { return type; }); // If the embox is simply wrapping a non-volatile type into a volatile box, // we're not losing any volatility information. if (boxElementType == memrefType) { return mlir::success(); } // Otherwise, the volatility of the input and result must match. const bool volatilityMatches = fir::isa_volatile_type(memrefType) == fir::isa_volatile_type(resultType); return mlir::success(volatilityMatches); } llvm::LogicalResult fir::EmboxOp::verify() { auto eleTy = fir::dyn_cast_ptrEleTy(getMemref().getType()); bool isArray = false; if (auto seqTy = mlir::dyn_cast(eleTy)) { eleTy = seqTy.getEleTy(); isArray = true; } if (hasLenParams()) { auto lenPs = numLenParams(); if (auto rt = mlir::dyn_cast(eleTy)) { if (lenPs != rt.getNumLenParams()) return emitOpError("number of LEN params does not correspond" " to the !fir.type type"); } else if (auto strTy = mlir::dyn_cast(eleTy)) { if (strTy.getLen() != fir::CharacterType::unknownLen()) return emitOpError("CHARACTER already has static LEN"); } else { return emitOpError("LEN parameters require CHARACTER or derived type"); } for (auto lp : getTypeparams()) if (!fir::isa_integer(lp.getType())) return emitOpError("LEN parameters must be integral type"); } if (getShape() && !isArray) return emitOpError("shape must not be provided for a scalar"); if (getSlice() && !isArray) return emitOpError("slice must not be provided for a scalar"); if (getSourceBox() && !mlir::isa(getResult().getType())) return emitOpError("source_box must be used with fir.class result type"); if (failed(verifyEmboxOpVolatilityInvariants(getMemref().getType(), getResult().getType()))) return emitOpError( "cannot convert between volatile and 
non-volatile types:") << " " << getMemref().getType() << " " << getResult().getType(); return mlir::success(); } /// Returns true if \p extent matches the extent of the \p box's /// dimension \p dim. static bool isBoxExtent(mlir::Value box, std::int64_t dim, mlir::Value extent) { if (auto op = extent.getDefiningOp()) if (op.getVal() == box && op.getExtent() == extent) if (auto dimOperand = fir::getIntIfConstant(op.getDim())) return *dimOperand == dim; return false; } /// Returns true if \p lb matches the lower bound of the \p box's /// dimension \p dim. If \p mayHaveNonDefaultLowerBounds is false, /// then \p lb may be an integer constant 1. static bool isBoxLb(mlir::Value box, std::int64_t dim, mlir::Value lb, bool mayHaveNonDefaultLowerBounds = true) { if (auto op = lb.getDefiningOp()) { if (op.getVal() == box && op.getLowerBound() == lb) if (auto dimOperand = fir::getIntIfConstant(op.getDim())) return *dimOperand == dim; } else if (!mayHaveNonDefaultLowerBounds) { if (auto constantLb = fir::getIntIfConstant(lb)) return *constantLb == 1; } return false; } /// Returns true if \p ub matches the upper bound of the \p box's /// dimension \p dim. If \p mayHaveNonDefaultLowerBounds is false, /// then the dimension's lower bound may be an integer constant 1. /// Note that the upper bound is usually a result of computation /// involving the lower bound and the extent, and the function /// tries its best to recognize the computation pattern. /// The conservative result 'false' does not necessarily mean /// that \p ub is not an actual upper bound value. 
static bool isBoxUb(mlir::Value box, std::int64_t dim, mlir::Value ub, bool mayHaveNonDefaultLowerBounds = true) { if (auto sub1 = ub.getDefiningOp()) { auto one = fir::getIntIfConstant(sub1.getOperand(1)); if (!one || *one != 1) return false; if (auto add = sub1.getOperand(0).getDefiningOp()) if ((isBoxLb(box, dim, add.getOperand(0)) && isBoxExtent(box, dim, add.getOperand(1))) || (isBoxLb(box, dim, add.getOperand(1)) && isBoxExtent(box, dim, add.getOperand(0)))) return true; } else if (!mayHaveNonDefaultLowerBounds) { return isBoxExtent(box, dim, ub); } return false; } /// Checks if the given \p sliceOp specifies a contiguous /// array slice. If \p checkWhole is true, then the check /// is done for all dimensions, otherwise, only for the innermost /// dimension. /// The simplest way to prove that this is an contiguous slice /// is to check whether the slice stride(s) is 1. /// For more complex cases, extra information must be provided /// by the caller: /// * \p origBox - if not null, then the source array is represented /// with this !fir.box value. The box is used to recognize /// the full dimension slices, which are specified by the triplets /// computed from the dimensions' lower bounds and extents. /// * \p mayHaveNonDefaultLowerBounds may be set to false to indicate /// that the source entity has default lower bounds, so the full /// dimension slices computations may use 1 for the lower bound. static bool isContiguousArraySlice(fir::SliceOp sliceOp, bool checkWhole = true, mlir::Value origBox = nullptr, bool mayHaveNonDefaultLowerBounds = true) { if (sliceOp.getFields().empty() && sliceOp.getSubstr().empty()) { // TODO: generalize code for the triples analysis with // hlfir::designatePreservesContinuity, especially when // recognition of the whole dimension slices is added. 
auto triples = sliceOp.getTriples(); assert((triples.size() % 3) == 0 && "invalid triples size"); // A slice with step=1 in the innermost dimension preserves // the continuity of the array in the innermost dimension. // If checkWhole is false, then check only the innermost slice triples. std::size_t checkUpTo = checkWhole ? triples.size() : 3; checkUpTo = std::min(checkUpTo, triples.size()); for (std::size_t i = 0; i < checkUpTo; i += 3) { if (triples[i] != triples[i + 1]) { // This is a section of the dimension. Only allow it // to be the first triple, if the source of the slice // is a boxed array. If it is a raw pointer, then // the result will still be contiguous, as long as // the strides are all ones. // When origBox is not null, we must prove that the triple // covers the whole dimension and the stride is one, // before claiming contiguity for this dimension. if (i != 0 && origBox) { std::int64_t dim = i / 3; if (!isBoxLb(origBox, dim, triples[i], mayHaveNonDefaultLowerBounds) || !isBoxUb(origBox, dim, triples[i + 1], mayHaveNonDefaultLowerBounds)) return false; } auto constantStep = fir::getIntIfConstant(triples[i + 2]); if (!constantStep || *constantStep != 1) return false; } } return true; } return false; } bool fir::isContiguousEmbox(fir::EmboxOp embox, bool checkWhole) { auto sliceArg = embox.getSlice(); if (!sliceArg) return true; if (auto sliceOp = mlir::dyn_cast_or_null(sliceArg.getDefiningOp())) return isContiguousArraySlice(sliceOp, checkWhole); return false; } //===----------------------------------------------------------------------===// // EmboxCharOp //===----------------------------------------------------------------------===// llvm::LogicalResult fir::EmboxCharOp::verify() { auto eleTy = fir::dyn_cast_ptrEleTy(getMemref().getType()); if (!mlir::dyn_cast_or_null(eleTy)) return mlir::failure(); return mlir::success(); } //===----------------------------------------------------------------------===// // EmboxProcOp 
//===----------------------------------------------------------------------===// llvm::LogicalResult fir::EmboxProcOp::verify() { // host bindings (optional) must be a reference to a tuple if (auto h = getHost()) { if (auto r = mlir::dyn_cast(h.getType())) if (mlir::isa(r.getEleTy())) return mlir::success(); return mlir::failure(); } return mlir::success(); } //===----------------------------------------------------------------------===// // TypeDescOp //===----------------------------------------------------------------------===// void fir::TypeDescOp::build(mlir::OpBuilder &, mlir::OperationState &result, mlir::TypeAttr inty) { result.addAttribute("in_type", inty); result.addTypes(TypeDescType::get(inty.getValue())); } mlir::ParseResult fir::TypeDescOp::parse(mlir::OpAsmParser &parser, mlir::OperationState &result) { mlir::Type intype; if (parser.parseType(intype)) return mlir::failure(); result.addAttribute("in_type", mlir::TypeAttr::get(intype)); mlir::Type restype = fir::TypeDescType::get(intype); if (parser.addTypeToList(restype, result.types)) return mlir::failure(); return mlir::success(); } void fir::TypeDescOp::print(mlir::OpAsmPrinter &p) { p << ' ' << getOperation()->getAttr("in_type"); p.printOptionalAttrDict(getOperation()->getAttrs(), {"in_type"}); } llvm::LogicalResult fir::TypeDescOp::verify() { mlir::Type resultTy = getType(); if (auto tdesc = mlir::dyn_cast(resultTy)) { if (tdesc.getOfTy() != getInType()) return emitOpError("wrapped type mismatched"); return mlir::success(); } return emitOpError("must be !fir.tdesc type"); } //===----------------------------------------------------------------------===// // GlobalOp //===----------------------------------------------------------------------===// mlir::Type fir::GlobalOp::resultType() { return wrapAllocaResultType(getType()); } mlir::ParseResult fir::GlobalOp::parse(mlir::OpAsmParser &parser, mlir::OperationState &result) { // Parse the optional linkage llvm::StringRef linkage; auto &builder = 
parser.getBuilder(); if (mlir::succeeded(parser.parseOptionalKeyword(&linkage))) { if (fir::GlobalOp::verifyValidLinkage(linkage)) return mlir::failure(); mlir::StringAttr linkAttr = builder.getStringAttr(linkage); result.addAttribute(fir::GlobalOp::getLinkNameAttrName(result.name), linkAttr); } // Parse the name as a symbol reference attribute. mlir::SymbolRefAttr nameAttr; if (parser.parseAttribute(nameAttr, fir::GlobalOp::getSymrefAttrName(result.name), result.attributes)) return mlir::failure(); result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), nameAttr.getRootReference()); bool simpleInitializer = false; if (mlir::succeeded(parser.parseOptionalLParen())) { mlir::Attribute attr; if (parser.parseAttribute(attr, getInitValAttrName(result.name), result.attributes) || parser.parseRParen()) return mlir::failure(); simpleInitializer = true; } if (parser.parseOptionalAttrDict(result.attributes)) return mlir::failure(); if (succeeded( parser.parseOptionalKeyword(getConstantAttrName(result.name)))) { // if "constant" keyword then mark this as a constant, not a variable result.addAttribute(getConstantAttrName(result.name), builder.getUnitAttr()); } if (succeeded(parser.parseOptionalKeyword(getTargetAttrName(result.name)))) result.addAttribute(getTargetAttrName(result.name), builder.getUnitAttr()); mlir::Type globalType; if (parser.parseColonType(globalType)) return mlir::failure(); result.addAttribute(fir::GlobalOp::getTypeAttrName(result.name), mlir::TypeAttr::get(globalType)); if (simpleInitializer) { result.addRegion(); } else { // Parse the optional initializer body. 
auto parseResult = parser.parseOptionalRegion(*result.addRegion(), /*arguments=*/{}); if (parseResult.has_value() && mlir::failed(*parseResult)) return mlir::failure(); } return mlir::success(); } void fir::GlobalOp::print(mlir::OpAsmPrinter &p) { if (getLinkName()) p << ' ' << *getLinkName(); p << ' '; p.printAttributeWithoutType(getSymrefAttr()); if (auto val = getValueOrNull()) p << '(' << val << ')'; // Print all other attributes that are not pretty printed here. p.printOptionalAttrDict((*this)->getAttrs(), /*elideAttrs=*/{ getSymNameAttrName(), getSymrefAttrName(), getTypeAttrName(), getConstantAttrName(), getTargetAttrName(), getLinkNameAttrName(), getInitValAttrName()}); if (getOperation()->getAttr(getConstantAttrName())) p << " " << getConstantAttrName().strref(); if (getOperation()->getAttr(getTargetAttrName())) p << " " << getTargetAttrName().strref(); p << " : "; p.printType(getType()); if (hasInitializationBody()) { p << ' '; p.printRegion(getOperation()->getRegion(0), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); } } void fir::GlobalOp::appendInitialValue(mlir::Operation *op) { getBlock().getOperations().push_back(op); } void fir::GlobalOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, llvm::StringRef name, bool isConstant, bool isTarget, mlir::Type type, mlir::Attribute initialVal, mlir::StringAttr linkage, llvm::ArrayRef attrs) { result.addRegion(); result.addAttribute(getTypeAttrName(result.name), mlir::TypeAttr::get(type)); result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); result.addAttribute(getSymrefAttrName(result.name), mlir::SymbolRefAttr::get(builder.getContext(), name)); if (isConstant) result.addAttribute(getConstantAttrName(result.name), builder.getUnitAttr()); if (isTarget) result.addAttribute(getTargetAttrName(result.name), builder.getUnitAttr()); if (initialVal) result.addAttribute(getInitValAttrName(result.name), initialVal); if (linkage) 
result.addAttribute(getLinkNameAttrName(result.name), linkage); result.attributes.append(attrs.begin(), attrs.end()); } void fir::GlobalOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, llvm::StringRef name, mlir::Type type, mlir::Attribute initialVal, mlir::StringAttr linkage, llvm::ArrayRef attrs) { build(builder, result, name, /*isConstant=*/false, /*isTarget=*/false, type, {}, linkage, attrs); } void fir::GlobalOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, llvm::StringRef name, bool isConstant, bool isTarget, mlir::Type type, mlir::StringAttr linkage, llvm::ArrayRef attrs) { build(builder, result, name, isConstant, isTarget, type, {}, linkage, attrs); } void fir::GlobalOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, llvm::StringRef name, mlir::Type type, mlir::StringAttr linkage, llvm::ArrayRef attrs) { build(builder, result, name, /*isConstant=*/false, /*isTarget=*/false, type, {}, linkage, attrs); } void fir::GlobalOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, llvm::StringRef name, bool isConstant, bool isTarget, mlir::Type type, llvm::ArrayRef attrs) { build(builder, result, name, isConstant, isTarget, type, mlir::StringAttr{}, attrs); } void fir::GlobalOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, llvm::StringRef name, mlir::Type type, llvm::ArrayRef attrs) { build(builder, result, name, /*isConstant=*/false, /*isTarget=*/false, type, attrs); } mlir::ParseResult fir::GlobalOp::verifyValidLinkage(llvm::StringRef linkage) { // Supporting only a subset of the LLVM linkage types for now static const char *validNames[] = {"common", "internal", "linkonce", "linkonce_odr", "weak"}; return mlir::success(llvm::is_contained(validNames, linkage)); } //===----------------------------------------------------------------------===// // GlobalLenOp //===----------------------------------------------------------------------===// mlir::ParseResult 
fir::GlobalLenOp::parse(mlir::OpAsmParser &parser, mlir::OperationState &result) { llvm::StringRef fieldName; if (failed(parser.parseOptionalKeyword(&fieldName))) { mlir::StringAttr fieldAttr; if (parser.parseAttribute(fieldAttr, fir::GlobalLenOp::getLenParamAttrName(), result.attributes)) return mlir::failure(); } else { result.addAttribute(fir::GlobalLenOp::getLenParamAttrName(), parser.getBuilder().getStringAttr(fieldName)); } mlir::IntegerAttr constant; if (parser.parseComma() || parser.parseAttribute(constant, fir::GlobalLenOp::getIntAttrName(), result.attributes)) return mlir::failure(); return mlir::success(); } void fir::GlobalLenOp::print(mlir::OpAsmPrinter &p) { p << ' ' << getOperation()->getAttr(fir::GlobalLenOp::getLenParamAttrName()) << ", " << getOperation()->getAttr(fir::GlobalLenOp::getIntAttrName()); } //===----------------------------------------------------------------------===// // FieldIndexOp //===----------------------------------------------------------------------===// template mlir::ParseResult parseFieldLikeOp(mlir::OpAsmParser &parser, mlir::OperationState &result) { llvm::StringRef fieldName; auto &builder = parser.getBuilder(); mlir::Type recty; if (parser.parseOptionalKeyword(&fieldName) || parser.parseComma() || parser.parseType(recty)) return mlir::failure(); result.addAttribute(fir::FieldIndexOp::getFieldAttrName(), builder.getStringAttr(fieldName)); if (!mlir::dyn_cast(recty)) return mlir::failure(); result.addAttribute(fir::FieldIndexOp::getTypeAttrName(), mlir::TypeAttr::get(recty)); if (!parser.parseOptionalLParen()) { llvm::SmallVector operands; llvm::SmallVector types; auto loc = parser.getNameLoc(); if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) || parser.parseColonTypeList(types) || parser.parseRParen() || parser.resolveOperands(operands, types, loc, result.operands)) return mlir::failure(); } mlir::Type fieldType = TY::get(builder.getContext()); if (parser.addTypeToList(fieldType, result.types)) 
return mlir::failure(); return mlir::success(); } mlir::ParseResult fir::FieldIndexOp::parse(mlir::OpAsmParser &parser, mlir::OperationState &result) { return parseFieldLikeOp(parser, result); } template void printFieldLikeOp(mlir::OpAsmPrinter &p, OP &op) { p << ' ' << op.getOperation() ->template getAttrOfType( fir::FieldIndexOp::getFieldAttrName()) .getValue() << ", " << op.getOperation()->getAttr(fir::FieldIndexOp::getTypeAttrName()); if (op.getNumOperands()) { p << '('; p.printOperands(op.getTypeparams()); auto sep = ") : "; for (auto op : op.getTypeparams()) { p << sep; if (op) p.printType(op.getType()); else p << "()"; sep = ", "; } } } void fir::FieldIndexOp::print(mlir::OpAsmPrinter &p) { printFieldLikeOp(p, *this); } void fir::FieldIndexOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, llvm::StringRef fieldName, mlir::Type recTy, mlir::ValueRange operands) { result.addAttribute(getFieldAttrName(), builder.getStringAttr(fieldName)); result.addAttribute(getTypeAttrName(), mlir::TypeAttr::get(recTy)); result.addOperands(operands); } llvm::SmallVector fir::FieldIndexOp::getAttributes() { llvm::SmallVector attrs; attrs.push_back(getFieldIdAttr()); attrs.push_back(getOnTypeAttr()); return attrs; } //===----------------------------------------------------------------------===// // InsertOnRangeOp //===----------------------------------------------------------------------===// static mlir::ParseResult parseCustomRangeSubscript(mlir::OpAsmParser &parser, mlir::DenseIntElementsAttr &coord) { llvm::SmallVector lbounds; llvm::SmallVector ubounds; if (parser.parseKeyword("from") || parser.parseCommaSeparatedList( mlir::AsmParser::Delimiter::Paren, [&] { return parser.parseInteger(lbounds.emplace_back(0)); }) || parser.parseKeyword("to") || parser.parseCommaSeparatedList(mlir::AsmParser::Delimiter::Paren, [&] { return parser.parseInteger(ubounds.emplace_back(0)); })) return mlir::failure(); llvm::SmallVector zippedBounds; for (auto zip : 
llvm::zip(lbounds, ubounds)) { zippedBounds.push_back(std::get<0>(zip)); zippedBounds.push_back(std::get<1>(zip)); } coord = mlir::Builder(parser.getContext()).getIndexTensorAttr(zippedBounds); return mlir::success(); } static void printCustomRangeSubscript(mlir::OpAsmPrinter &printer, fir::InsertOnRangeOp op, mlir::DenseIntElementsAttr coord) { printer << "from ("; auto enumerate = llvm::enumerate(coord.getValues()); // Even entries are the lower bounds. llvm::interleaveComma( make_filter_range( enumerate, [](auto indexed_value) { return indexed_value.index() % 2 == 0; }), printer, [&](auto indexed_value) { printer << indexed_value.value(); }); printer << ") to ("; // Odd entries are the upper bounds. llvm::interleaveComma( make_filter_range( enumerate, [](auto indexed_value) { return indexed_value.index() % 2 != 0; }), printer, [&](auto indexed_value) { printer << indexed_value.value(); }); printer << ")"; } /// Range bounds must be nonnegative, and the range must not be empty. llvm::LogicalResult fir::InsertOnRangeOp::verify() { if (fir::hasDynamicSize(getSeq().getType())) return emitOpError("must have constant shape and size"); mlir::DenseIntElementsAttr coorAttr = getCoor(); if (coorAttr.size() < 2 || coorAttr.size() % 2 != 0) return emitOpError("has uneven number of values in ranges"); bool rangeIsKnownToBeNonempty = false; for (auto i = coorAttr.getValues().end(), b = coorAttr.getValues().begin(); i != b;) { int64_t ub = (*--i); int64_t lb = (*--i); if (lb < 0 || ub < 0) return emitOpError("negative range bound"); if (rangeIsKnownToBeNonempty) continue; if (lb > ub) return emitOpError("empty range"); rangeIsKnownToBeNonempty = lb < ub; } return mlir::success(); } bool fir::InsertOnRangeOp::isFullRange() { auto extents = getType().getShape(); mlir::DenseIntElementsAttr indexes = getCoor(); if (indexes.size() / 2 != static_cast(extents.size())) return false; auto cur_index = indexes.value_begin(); for (unsigned i = 0; i < indexes.size(); i += 2) { if 
(*(cur_index++) != 0) return false; if (*(cur_index++) != extents[i / 2] - 1) return false; } return true; } //===----------------------------------------------------------------------===// // InsertValueOp //===----------------------------------------------------------------------===// static bool checkIsIntegerConstant(mlir::Attribute attr, std::int64_t conVal) { if (auto iattr = mlir::dyn_cast(attr)) return iattr.getInt() == conVal; return false; } static bool isZero(mlir::Attribute a) { return checkIsIntegerConstant(a, 0); } static bool isOne(mlir::Attribute a) { return checkIsIntegerConstant(a, 1); } // Undo some complex patterns created in the front-end and turn them back into // complex ops. template struct UndoComplexPattern : public mlir::RewritePattern { UndoComplexPattern(mlir::MLIRContext *ctx) : mlir::RewritePattern("fir.insert_value", 2, ctx) {} llvm::LogicalResult matchAndRewrite(mlir::Operation *op, mlir::PatternRewriter &rewriter) const override { auto insval = mlir::dyn_cast_or_null(op); if (!insval || !mlir::isa(insval.getType())) return mlir::failure(); auto insval2 = mlir::dyn_cast_or_null( insval.getAdt().getDefiningOp()); if (!insval2) return mlir::failure(); auto binf = mlir::dyn_cast_or_null(insval.getVal().getDefiningOp()); auto binf2 = mlir::dyn_cast_or_null(insval2.getVal().getDefiningOp()); if (!binf || !binf2 || insval.getCoor().size() != 1 || !isOne(insval.getCoor()[0]) || insval2.getCoor().size() != 1 || !isZero(insval2.getCoor()[0])) return mlir::failure(); auto eai = mlir::dyn_cast_or_null( binf.getLhs().getDefiningOp()); auto ebi = mlir::dyn_cast_or_null( binf.getRhs().getDefiningOp()); auto ear = mlir::dyn_cast_or_null( binf2.getLhs().getDefiningOp()); auto ebr = mlir::dyn_cast_or_null( binf2.getRhs().getDefiningOp()); if (!eai || !ebi || !ear || !ebr || ear.getAdt() != eai.getAdt() || ebr.getAdt() != ebi.getAdt() || eai.getCoor().size() != 1 || !isOne(eai.getCoor()[0]) || ebi.getCoor().size() != 1 || !isOne(ebi.getCoor()[0]) || 
ear.getCoor().size() != 1 || !isZero(ear.getCoor()[0]) || ebr.getCoor().size() != 1 || !isZero(ebr.getCoor()[0])) return mlir::failure(); rewriter.replaceOpWithNewOp(op, ear.getAdt(), ebr.getAdt()); return mlir::success(); } }; void fir::InsertValueOp::getCanonicalizationPatterns( mlir::RewritePatternSet &results, mlir::MLIRContext *context) { results.insert, UndoComplexPattern>(context); } //===----------------------------------------------------------------------===// // IterWhileOp //===----------------------------------------------------------------------===// void fir::IterWhileOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Value lb, mlir::Value ub, mlir::Value step, mlir::Value iterate, bool finalCountValue, mlir::ValueRange iterArgs, llvm::ArrayRef attributes) { result.addOperands({lb, ub, step, iterate}); if (finalCountValue) { result.addTypes(builder.getIndexType()); result.addAttribute(getFinalValueAttrNameStr(), builder.getUnitAttr()); } result.addTypes(iterate.getType()); result.addOperands(iterArgs); for (auto v : iterArgs) result.addTypes(v.getType()); mlir::Region *bodyRegion = result.addRegion(); bodyRegion->push_back(new mlir::Block{}); bodyRegion->front().addArgument(builder.getIndexType(), result.location); bodyRegion->front().addArgument(iterate.getType(), result.location); bodyRegion->front().addArguments( iterArgs.getTypes(), llvm::SmallVector(iterArgs.size(), result.location)); result.addAttributes(attributes); } mlir::ParseResult fir::IterWhileOp::parse(mlir::OpAsmParser &parser, mlir::OperationState &result) { auto &builder = parser.getBuilder(); mlir::OpAsmParser::Argument inductionVariable, iterateVar; mlir::OpAsmParser::UnresolvedOperand lb, ub, step, iterateInput; if (parser.parseLParen() || parser.parseArgument(inductionVariable) || parser.parseEqual()) return mlir::failure(); // Parse loop bounds. 
auto indexType = builder.getIndexType(); auto i1Type = builder.getIntegerType(1); if (parser.parseOperand(lb) || parser.resolveOperand(lb, indexType, result.operands) || parser.parseKeyword("to") || parser.parseOperand(ub) || parser.resolveOperand(ub, indexType, result.operands) || parser.parseKeyword("step") || parser.parseOperand(step) || parser.parseRParen() || parser.resolveOperand(step, indexType, result.operands) || parser.parseKeyword("and") || parser.parseLParen() || parser.parseArgument(iterateVar) || parser.parseEqual() || parser.parseOperand(iterateInput) || parser.parseRParen() || parser.resolveOperand(iterateInput, i1Type, result.operands)) return mlir::failure(); // Parse the initial iteration arguments. auto prependCount = false; // Induction variable. llvm::SmallVector regionArgs; regionArgs.push_back(inductionVariable); regionArgs.push_back(iterateVar); if (succeeded(parser.parseOptionalKeyword("iter_args"))) { llvm::SmallVector operands; llvm::SmallVector regionTypes; // Parse assignment list and results type list. if (parser.parseAssignmentList(regionArgs, operands) || parser.parseArrowTypeList(regionTypes)) return mlir::failure(); if (regionTypes.size() == operands.size() + 2) prependCount = true; llvm::ArrayRef resTypes = regionTypes; resTypes = prependCount ? resTypes.drop_front(2) : resTypes; // Resolve input operands. for (auto operandType : llvm::zip(operands, resTypes)) if (parser.resolveOperand(std::get<0>(operandType), std::get<1>(operandType), result.operands)) return mlir::failure(); if (prependCount) { result.addTypes(regionTypes); } else { result.addTypes(i1Type); result.addTypes(resTypes); } } else if (succeeded(parser.parseOptionalArrow())) { llvm::SmallVector typeList; if (parser.parseLParen() || parser.parseTypeList(typeList) || parser.parseRParen()) return mlir::failure(); // Type list must be "(index, i1)". 
if (typeList.size() != 2 || !mlir::isa(typeList[0]) || !typeList[1].isSignlessInteger(1)) return mlir::failure(); result.addTypes(typeList); prependCount = true; } else { result.addTypes(i1Type); } if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) return mlir::failure(); llvm::SmallVector argTypes; // Induction variable (hidden) if (prependCount) result.addAttribute(IterWhileOp::getFinalValueAttrNameStr(), builder.getUnitAttr()); else argTypes.push_back(indexType); // Loop carried variables (including iterate) argTypes.append(result.types.begin(), result.types.end()); // Parse the body region. auto *body = result.addRegion(); if (regionArgs.size() != argTypes.size()) return parser.emitError( parser.getNameLoc(), "mismatch in number of loop-carried values and defined values"); for (size_t i = 0, e = regionArgs.size(); i != e; ++i) regionArgs[i].type = argTypes[i]; if (parser.parseRegion(*body, regionArgs)) return mlir::failure(); fir::IterWhileOp::ensureTerminator(*body, builder, result.location); return mlir::success(); } llvm::LogicalResult fir::IterWhileOp::verify() { // Check that the body defines as single block argument for the induction // variable. auto *body = getBody(); if (!body->getArgument(1).getType().isInteger(1)) return emitOpError( "expected body second argument to be an index argument for " "the induction variable"); if (!body->getArgument(0).getType().isIndex()) return emitOpError( "expected body first argument to be an index argument for " "the induction variable"); auto opNumResults = getNumResults(); if (getFinalValue()) { // Result type must be "(index, i1, ...)". if (!mlir::isa(getResult(0).getType())) return emitOpError("result #0 expected to be index"); if (!getResult(1).getType().isSignlessInteger(1)) return emitOpError("result #1 expected to be i1"); opNumResults--; } else { // iterate_while always returns the early exit induction value. 
// Result type must be "(i1, ...)" if (!getResult(0).getType().isSignlessInteger(1)) return emitOpError("result #0 expected to be i1"); } if (opNumResults == 0) return mlir::failure(); if (getNumIterOperands() != opNumResults) return emitOpError( "mismatch in number of loop-carried values and defined values"); if (getNumRegionIterArgs() != opNumResults) return emitOpError( "mismatch in number of basic block args and defined values"); auto iterOperands = getIterOperands(); auto iterArgs = getRegionIterArgs(); auto opResults = getFinalValue() ? getResults().drop_front() : getResults(); unsigned i = 0u; for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { if (std::get<0>(e).getType() != std::get<2>(e).getType()) return emitOpError() << "types mismatch between " << i << "th iter operand and defined value"; if (std::get<1>(e).getType() != std::get<2>(e).getType()) return emitOpError() << "types mismatch between " << i << "th iter region arg and defined value"; i++; } return mlir::success(); } void fir::IterWhileOp::print(mlir::OpAsmPrinter &p) { p << " (" << getInductionVar() << " = " << getLowerBound() << " to " << getUpperBound() << " step " << getStep() << ") and ("; assert(hasIterOperands()); auto regionArgs = getRegionIterArgs(); auto operands = getIterOperands(); p << regionArgs.front() << " = " << *operands.begin() << ")"; if (regionArgs.size() > 1) { p << " iter_args("; llvm::interleaveComma( llvm::zip(regionArgs.drop_front(), operands.drop_front()), p, [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); }); p << ") -> ("; llvm::interleaveComma( llvm::drop_begin(getResultTypes(), getFinalValue() ? 
0 : 1), p); p << ")"; } else if (getFinalValue()) { p << " -> (" << getResultTypes() << ')'; } p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {getFinalValueAttrNameStr()}); p << ' '; p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); } llvm::SmallVector fir::IterWhileOp::getLoopRegions() { return {&getRegion()}; } mlir::BlockArgument fir::IterWhileOp::iterArgToBlockArg(mlir::Value iterArg) { for (auto i : llvm::enumerate(getInitArgs())) if (iterArg == i.value()) return getRegion().front().getArgument(i.index() + 1); return {}; } void fir::IterWhileOp::resultToSourceOps( llvm::SmallVectorImpl &results, unsigned resultNum) { auto oper = getFinalValue() ? resultNum + 1 : resultNum; auto *term = getRegion().front().getTerminator(); if (oper < term->getNumOperands()) results.push_back(term->getOperand(oper)); } mlir::Value fir::IterWhileOp::blockArgToSourceOp(unsigned blockArgNum) { if (blockArgNum > 0 && blockArgNum <= getInitArgs().size()) return getInitArgs()[blockArgNum - 1]; return {}; } std::optional> fir::IterWhileOp::getYieldedValuesMutable() { auto *term = getRegion().front().getTerminator(); return getFinalValue() ? 
term->getOpOperands().drop_front() : term->getOpOperands(); } //===----------------------------------------------------------------------===// // LenParamIndexOp //===----------------------------------------------------------------------===// mlir::ParseResult fir::LenParamIndexOp::parse(mlir::OpAsmParser &parser, mlir::OperationState &result) { return parseFieldLikeOp(parser, result); } void fir::LenParamIndexOp::print(mlir::OpAsmPrinter &p) { printFieldLikeOp(p, *this); } void fir::LenParamIndexOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, llvm::StringRef fieldName, mlir::Type recTy, mlir::ValueRange operands) { result.addAttribute(getFieldAttrName(), builder.getStringAttr(fieldName)); result.addAttribute(getTypeAttrName(), mlir::TypeAttr::get(recTy)); result.addOperands(operands); } llvm::SmallVector fir::LenParamIndexOp::getAttributes() { llvm::SmallVector attrs; attrs.push_back(getFieldIdAttr()); attrs.push_back(getOnTypeAttr()); return attrs; } //===----------------------------------------------------------------------===// // LoadOp //===----------------------------------------------------------------------===// void fir::LoadOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Value refVal) { if (!refVal) { mlir::emitError(result.location, "LoadOp has null argument"); return; } auto eleTy = fir::dyn_cast_ptrEleTy(refVal.getType()); if (!eleTy) { mlir::emitError(result.location, "not a memory reference type"); return; } build(builder, result, eleTy, refVal); } void fir::LoadOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Type resTy, mlir::Value refVal) { if (!refVal) { mlir::emitError(result.location, "LoadOp has null argument"); return; } result.addOperands(refVal); result.addTypes(resTy); } mlir::ParseResult fir::LoadOp::getElementOf(mlir::Type &ele, mlir::Type ref) { if ((ele = fir::dyn_cast_ptrEleTy(ref))) return mlir::success(); return mlir::failure(); } mlir::ParseResult 
fir::LoadOp::parse(mlir::OpAsmParser &parser, mlir::OperationState &result) { mlir::Type type; mlir::OpAsmParser::UnresolvedOperand oper; if (parser.parseOperand(oper) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type) || parser.resolveOperand(oper, type, result.operands)) return mlir::failure(); mlir::Type eleTy; if (fir::LoadOp::getElementOf(eleTy, type) || parser.addTypeToList(eleTy, result.types)) return mlir::failure(); return mlir::success(); } void fir::LoadOp::print(mlir::OpAsmPrinter &p) { p << ' '; p.printOperand(getMemref()); p.printOptionalAttrDict(getOperation()->getAttrs(), {}); p << " : " << getMemref().getType(); } void fir::LoadOp::getEffects( llvm::SmallVectorImpl< mlir::SideEffects::EffectInstance> &effects) { effects.emplace_back(mlir::MemoryEffects::Read::get(), &getMemrefMutable(), mlir::SideEffects::DefaultResource::get()); addVolatileMemoryEffects({getMemref().getType()}, effects); } //===----------------------------------------------------------------------===// // DoLoopOp //===----------------------------------------------------------------------===// void fir::DoLoopOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Value lb, mlir::Value ub, mlir::Value step, bool unordered, bool finalCountValue, mlir::ValueRange iterArgs, mlir::ValueRange reduceOperands, llvm::ArrayRef reduceAttrs, llvm::ArrayRef attributes) { result.addOperands({lb, ub, step}); result.addOperands(reduceOperands); result.addOperands(iterArgs); result.addAttribute(getOperandSegmentSizeAttr(), builder.getDenseI32ArrayAttr( {1, 1, 1, static_cast(reduceOperands.size()), static_cast(iterArgs.size())})); if (finalCountValue) { result.addTypes(builder.getIndexType()); result.addAttribute(getFinalValueAttrName(result.name), builder.getUnitAttr()); } for (auto v : iterArgs) result.addTypes(v.getType()); mlir::Region *bodyRegion = result.addRegion(); bodyRegion->push_back(new mlir::Block{}); if (iterArgs.empty() && 
!finalCountValue) fir::DoLoopOp::ensureTerminator(*bodyRegion, builder, result.location); bodyRegion->front().addArgument(builder.getIndexType(), result.location); bodyRegion->front().addArguments( iterArgs.getTypes(), llvm::SmallVector(iterArgs.size(), result.location)); if (unordered) result.addAttribute(getUnorderedAttrName(result.name), builder.getUnitAttr()); if (!reduceAttrs.empty()) result.addAttribute(getReduceAttrsAttrName(result.name), builder.getArrayAttr(reduceAttrs)); result.addAttributes(attributes); } mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser, mlir::OperationState &result) { auto &builder = parser.getBuilder(); mlir::OpAsmParser::Argument inductionVariable; mlir::OpAsmParser::UnresolvedOperand lb, ub, step; // Parse the induction variable followed by '='. if (parser.parseArgument(inductionVariable) || parser.parseEqual()) return mlir::failure(); // Parse loop bounds. auto indexType = builder.getIndexType(); if (parser.parseOperand(lb) || parser.resolveOperand(lb, indexType, result.operands) || parser.parseKeyword("to") || parser.parseOperand(ub) || parser.resolveOperand(ub, indexType, result.operands) || parser.parseKeyword("step") || parser.parseOperand(step) || parser.resolveOperand(step, indexType, result.operands)) return mlir::failure(); if (mlir::succeeded(parser.parseOptionalKeyword("unordered"))) result.addAttribute("unordered", builder.getUnitAttr()); // Parse the reduction arguments. llvm::SmallVector reduceOperands; llvm::SmallVector reduceArgTypes; if (succeeded(parser.parseOptionalKeyword("reduce"))) { // Parse reduction attributes and variables. 
llvm::SmallVector attributes; if (failed(parser.parseCommaSeparatedList( mlir::AsmParser::Delimiter::Paren, [&]() { if (parser.parseAttribute(attributes.emplace_back()) || parser.parseArrow() || parser.parseOperand(reduceOperands.emplace_back()) || parser.parseColonType(reduceArgTypes.emplace_back())) return mlir::failure(); return mlir::success(); }))) return mlir::failure(); // Resolve input operands. for (auto operand_type : llvm::zip(reduceOperands, reduceArgTypes)) if (parser.resolveOperand(std::get<0>(operand_type), std::get<1>(operand_type), result.operands)) return mlir::failure(); llvm::SmallVector arrayAttr(attributes.begin(), attributes.end()); result.addAttribute(getReduceAttrsAttrName(result.name), builder.getArrayAttr(arrayAttr)); } // Parse the optional initial iteration arguments. llvm::SmallVector regionArgs; llvm::SmallVector iterOperands; llvm::SmallVector argTypes; bool prependCount = false; regionArgs.push_back(inductionVariable); if (succeeded(parser.parseOptionalKeyword("iter_args"))) { // Parse assignment list and results type list. if (parser.parseAssignmentList(regionArgs, iterOperands) || parser.parseArrowTypeList(result.types)) return mlir::failure(); if (result.types.size() == iterOperands.size() + 1) prependCount = true; // Resolve input operands. llvm::ArrayRef resTypes = result.types; for (auto operand_type : llvm::zip( iterOperands, prependCount ? 
resTypes.drop_front() : resTypes)) if (parser.resolveOperand(std::get<0>(operand_type), std::get<1>(operand_type), result.operands)) return mlir::failure(); } else if (succeeded(parser.parseOptionalArrow())) { if (parser.parseKeyword("index")) return mlir::failure(); result.types.push_back(indexType); prependCount = true; } // Set the operandSegmentSizes attribute result.addAttribute(getOperandSegmentSizeAttr(), builder.getDenseI32ArrayAttr( {1, 1, 1, static_cast(reduceOperands.size()), static_cast(iterOperands.size())})); if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) return mlir::failure(); // Induction variable. if (prependCount) result.addAttribute(DoLoopOp::getFinalValueAttrName(result.name), builder.getUnitAttr()); else argTypes.push_back(indexType); // Loop carried variables argTypes.append(result.types.begin(), result.types.end()); // Parse the body region. auto *body = result.addRegion(); if (regionArgs.size() != argTypes.size()) return parser.emitError( parser.getNameLoc(), "mismatch in number of loop-carried values and defined values"); for (size_t i = 0, e = regionArgs.size(); i != e; ++i) regionArgs[i].type = argTypes[i]; if (parser.parseRegion(*body, regionArgs)) return mlir::failure(); DoLoopOp::ensureTerminator(*body, builder, result.location); return mlir::success(); } fir::DoLoopOp fir::getForInductionVarOwner(mlir::Value val) { auto ivArg = mlir::dyn_cast(val); if (!ivArg) return {}; assert(ivArg.getOwner() && "unlinked block argument"); auto *containingInst = ivArg.getOwner()->getParentOp(); return mlir::dyn_cast_or_null(containingInst); } // Lifted from loop.loop llvm::LogicalResult fir::DoLoopOp::verify() { // Check that the body defines as single block argument for the induction // variable. 
auto *body = getBody(); if (!body->getArgument(0).getType().isIndex()) return emitOpError( "expected body first argument to be an index argument for " "the induction variable"); auto opNumResults = getNumResults(); if (opNumResults == 0) return mlir::success(); if (getFinalValue()) { if (getUnordered()) return emitOpError("unordered loop has no final value"); opNumResults--; } if (getNumIterOperands() != opNumResults) return emitOpError( "mismatch in number of loop-carried values and defined values"); if (getNumRegionIterArgs() != opNumResults) return emitOpError( "mismatch in number of basic block args and defined values"); auto iterOperands = getIterOperands(); auto iterArgs = getRegionIterArgs(); auto opResults = getFinalValue() ? getResults().drop_front() : getResults(); unsigned i = 0u; for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { if (std::get<0>(e).getType() != std::get<2>(e).getType()) return emitOpError() << "types mismatch between " << i << "th iter operand and defined value"; if (std::get<1>(e).getType() != std::get<2>(e).getType()) return emitOpError() << "types mismatch between " << i << "th iter region arg and defined value"; i++; } auto reduceAttrs = getReduceAttrsAttr(); if (getNumReduceOperands() != (reduceAttrs ? 
reduceAttrs.size() : 0)) return emitOpError( "mismatch in number of reduction variables and reduction attributes"); return mlir::success(); } void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) { bool printBlockTerminators = false; p << ' ' << getInductionVar() << " = " << getLowerBound() << " to " << getUpperBound() << " step " << getStep(); if (getUnordered()) p << " unordered"; if (hasReduceOperands()) { p << " reduce("; auto attrs = getReduceAttrsAttr(); auto operands = getReduceOperands(); llvm::interleaveComma(llvm::zip(attrs, operands), p, [&](auto it) { p << std::get<0>(it) << " -> " << std::get<1>(it) << " : " << std::get<1>(it).getType(); }); p << ')'; printBlockTerminators = true; } if (hasIterOperands()) { p << " iter_args("; auto regionArgs = getRegionIterArgs(); auto operands = getIterOperands(); llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); }); p << ") -> (" << getResultTypes() << ')'; printBlockTerminators = true; } else if (getFinalValue()) { p << " -> " << getResultTypes(); printBlockTerminators = true; } p.printOptionalAttrDictWithKeyword( (*this)->getAttrs(), {"unordered", "finalValue", "reduceAttrs", "operandSegmentSizes"}); p << ' '; p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, printBlockTerminators); } llvm::SmallVector fir::DoLoopOp::getLoopRegions() { return {&getRegion()}; } /// Translate a value passed as an iter_arg to the corresponding block /// argument in the body of the loop. mlir::BlockArgument fir::DoLoopOp::iterArgToBlockArg(mlir::Value iterArg) { for (auto i : llvm::enumerate(getInitArgs())) if (iterArg == i.value()) return getRegion().front().getArgument(i.index() + 1); return {}; } /// Translate the result vector (by index number) to the corresponding value /// to the `fir.result` Op. void fir::DoLoopOp::resultToSourceOps( llvm::SmallVectorImpl &results, unsigned resultNum) { auto oper = getFinalValue() ? 
resultNum + 1 : resultNum; auto *term = getRegion().front().getTerminator(); if (oper < term->getNumOperands()) results.push_back(term->getOperand(oper)); } /// Translate the block argument (by index number) to the corresponding value /// passed as an iter_arg to the parent DoLoopOp. mlir::Value fir::DoLoopOp::blockArgToSourceOp(unsigned blockArgNum) { if (blockArgNum > 0 && blockArgNum <= getInitArgs().size()) return getInitArgs()[blockArgNum - 1]; return {}; } std::optional> fir::DoLoopOp::getYieldedValuesMutable() { auto *term = getRegion().front().getTerminator(); return getFinalValue() ? term->getOpOperands().drop_front() : term->getOpOperands(); } //===----------------------------------------------------------------------===// // DTEntryOp //===----------------------------------------------------------------------===// mlir::ParseResult fir::DTEntryOp::parse(mlir::OpAsmParser &parser, mlir::OperationState &result) { llvm::StringRef methodName; // allow `methodName` or `"methodName"` if (failed(parser.parseOptionalKeyword(&methodName))) { mlir::StringAttr methodAttr; if (parser.parseAttribute(methodAttr, getMethodAttrName(result.name), result.attributes)) return mlir::failure(); } else { result.addAttribute(getMethodAttrName(result.name), parser.getBuilder().getStringAttr(methodName)); } mlir::SymbolRefAttr calleeAttr; if (parser.parseComma() || parser.parseAttribute(calleeAttr, fir::DTEntryOp::getProcAttrNameStr(), result.attributes)) return mlir::failure(); return mlir::success(); } void fir::DTEntryOp::print(mlir::OpAsmPrinter &p) { p << ' ' << getMethodAttr() << ", " << getProcAttr(); } //===----------------------------------------------------------------------===// // ReboxOp //===----------------------------------------------------------------------===// /// Get the scalar type related to a fir.box type. /// Example: return f32 for !fir.box>. 
static mlir::Type getBoxScalarEleTy(mlir::Type boxTy) { auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(boxTy); if (auto seqTy = mlir::dyn_cast(eleTy)) return seqTy.getEleTy(); return eleTy; } /// Test if \p t1 and \p t2 are compatible character types (if they can /// represent the same type at runtime). static bool areCompatibleCharacterTypes(mlir::Type t1, mlir::Type t2) { auto c1 = mlir::dyn_cast(t1); auto c2 = mlir::dyn_cast(t2); if (!c1 || !c2) return false; if (c1.hasDynamicLen() || c2.hasDynamicLen()) return true; return c1.getLen() == c2.getLen(); } llvm::LogicalResult fir::ReboxOp::verify() { auto inputBoxTy = getBox().getType(); if (fir::isa_unknown_size_box(inputBoxTy)) return emitOpError("box operand must not have unknown rank or type"); auto outBoxTy = getType(); if (fir::isa_unknown_size_box(outBoxTy)) return emitOpError("result type must not have unknown rank or type"); auto inputRank = fir::getBoxRank(inputBoxTy); auto inputEleTy = getBoxScalarEleTy(inputBoxTy); auto outRank = fir::getBoxRank(outBoxTy); auto outEleTy = getBoxScalarEleTy(outBoxTy); if (auto sliceVal = getSlice()) { // Slicing case if (mlir::cast(sliceVal.getType()).getRank() != inputRank) return emitOpError("slice operand rank must match box operand rank"); if (auto shapeVal = getShape()) { if (auto shiftTy = mlir::dyn_cast(shapeVal.getType())) { if (shiftTy.getRank() != inputRank) return emitOpError("shape operand and input box ranks must match " "when there is a slice"); } else { return emitOpError("shape operand must absent or be a fir.shift " "when there is a slice"); } } if (auto sliceOp = sliceVal.getDefiningOp()) { auto slicedRank = mlir::cast(sliceOp).getOutRank(); if (slicedRank != outRank) return emitOpError("result type rank and rank after applying slice " "operand must match"); } } else { // Reshaping case unsigned shapeRank = inputRank; if (auto shapeVal = getShape()) { auto ty = shapeVal.getType(); if (auto shapeTy = mlir::dyn_cast(ty)) { shapeRank = shapeTy.getRank(); } else 
if (auto shapeShiftTy = mlir::dyn_cast(ty)) { shapeRank = shapeShiftTy.getRank(); } else { auto shiftTy = mlir::cast(ty); shapeRank = shiftTy.getRank(); if (shapeRank != inputRank) return emitOpError("shape operand and input box ranks must match " "when the shape is a fir.shift"); } } if (shapeRank != outRank) return emitOpError("result type and shape operand ranks must match"); } if (inputEleTy != outEleTy) { // TODO: check that outBoxTy is a parent type of inputBoxTy for derived // types. // Character input and output types with constant length may be different if // there is a substring in the slice, otherwise, they must match. If any of // the types is a character with dynamic length, the other type can be any // character type. const bool typeCanMismatch = mlir::isa(inputEleTy) || mlir::isa(outEleTy) || (mlir::isa(inputEleTy) && mlir::isa(outEleTy)) || (getSlice() && mlir::isa(inputEleTy)) || (getSlice() && fir::isa_complex(inputEleTy) && mlir::isa(outEleTy)) || areCompatibleCharacterTypes(inputEleTy, outEleTy); if (!typeCanMismatch) return emitOpError( "op input and output element types must match for intrinsic types"); } return mlir::success(); } //===----------------------------------------------------------------------===// // ReboxAssumedRankOp //===----------------------------------------------------------------------===// static bool areCompatibleAssumedRankElementType(mlir::Type inputEleTy, mlir::Type outEleTy) { if (inputEleTy == outEleTy) return true; // Output is unlimited polymorphic -> output dynamic type is the same as input // type. if (mlir::isa(outEleTy)) return true; // Output/Input are derived types. Assuming input extends output type, output // dynamic type is the output static type, unless output is polymorphic. 
if (mlir::isa(inputEleTy) && mlir::isa(outEleTy)) return true; if (areCompatibleCharacterTypes(inputEleTy, outEleTy)) return true; return false; } llvm::LogicalResult fir::ReboxAssumedRankOp::verify() { mlir::Type inputType = getBox().getType(); if (!mlir::isa(inputType) && !fir::isBoxAddress(inputType)) return emitOpError("input must be a box or box address"); mlir::Type inputEleTy = mlir::cast(fir::unwrapRefType(inputType)) .unwrapInnerType(); mlir::Type outEleTy = mlir::cast(getType()).unwrapInnerType(); if (!areCompatibleAssumedRankElementType(inputEleTy, outEleTy)) return emitOpError("input and output element types are incompatible"); return mlir::success(); } void fir::ReboxAssumedRankOp::getEffects( llvm::SmallVectorImpl< mlir::SideEffects::EffectInstance> &effects) { mlir::OpOperand &inputBox = getBoxMutable(); if (fir::isBoxAddress(inputBox.get().getType())) effects.emplace_back(mlir::MemoryEffects::Read::get(), &inputBox, mlir::SideEffects::DefaultResource::get()); } //===----------------------------------------------------------------------===// // ResultOp //===----------------------------------------------------------------------===// llvm::LogicalResult fir::ResultOp::verify() { auto *parentOp = (*this)->getParentOp(); auto results = parentOp->getResults(); auto operands = (*this)->getOperands(); if (parentOp->getNumResults() != getNumOperands()) return emitOpError() << "parent of result must have same arity"; for (auto e : llvm::zip(results, operands)) if (std::get<0>(e).getType() != std::get<1>(e).getType()) return emitOpError() << "types mismatch between result op and its parent"; return mlir::success(); } //===----------------------------------------------------------------------===// // SaveResultOp //===----------------------------------------------------------------------===// llvm::LogicalResult fir::SaveResultOp::verify() { auto resultType = getValue().getType(); if (resultType != fir::dyn_cast_ptrEleTy(getMemref().getType())) return 
emitOpError("value type must match memory reference type"); if (fir::isa_unknown_size_box(resultType)) return emitOpError("cannot save !fir.box of unknown rank or type"); if (mlir::isa(resultType)) { if (getShape() || !getTypeparams().empty()) return emitOpError( "must not have shape or length operands if the value is a fir.box"); return mlir::success(); } // fir.record or fir.array case. unsigned shapeTyRank = 0; if (auto shapeVal = getShape()) { auto shapeTy = shapeVal.getType(); if (auto s = mlir::dyn_cast(shapeTy)) shapeTyRank = s.getRank(); else shapeTyRank = mlir::cast(shapeTy).getRank(); } auto eleTy = resultType; if (auto seqTy = mlir::dyn_cast(resultType)) { if (seqTy.getDimension() != shapeTyRank) emitOpError("shape operand must be provided and have the value rank " "when the value is a fir.array"); eleTy = seqTy.getEleTy(); } else { if (shapeTyRank != 0) emitOpError( "shape operand should only be provided if the value is a fir.array"); } if (auto recTy = mlir::dyn_cast(eleTy)) { if (recTy.getNumLenParams() != getTypeparams().size()) emitOpError("length parameters number must match with the value type " "length parameters"); } else if (auto charTy = mlir::dyn_cast(eleTy)) { if (getTypeparams().size() > 1) emitOpError("no more than one length parameter must be provided for " "character value"); } else { if (!getTypeparams().empty()) emitOpError("length parameters must not be provided for this value type"); } return mlir::success(); } //===----------------------------------------------------------------------===// // IntegralSwitchTerminator //===----------------------------------------------------------------------===// static constexpr llvm::StringRef getCompareOffsetAttr() { return "compare_operand_offsets"; } static constexpr llvm::StringRef getTargetOffsetAttr() { return "target_operand_offsets"; } template static llvm::LogicalResult verifyIntegralSwitchTerminator(OpT op) { if (!mlir::isa( op.getSelector().getType())) return op.emitOpError("must be an 
integer"); auto cases = op->template getAttrOfType(op.getCasesAttr()).getValue(); auto count = op.getNumDest(); if (count == 0) return op.emitOpError("must have at least one successor"); if (op.getNumConditions() != count) return op.emitOpError("number of cases and targets don't match"); if (op.targetOffsetSize() != count) return op.emitOpError("incorrect number of successor operand groups"); for (decltype(count) i = 0; i != count; ++i) { if (!mlir::isa(cases[i])) return op.emitOpError("invalid case alternative"); } return mlir::success(); } static mlir::ParseResult parseIntegralSwitchTerminator( mlir::OpAsmParser &parser, mlir::OperationState &result, llvm::StringRef casesAttr, llvm::StringRef operandSegmentAttr) { mlir::OpAsmParser::UnresolvedOperand selector; mlir::Type type; if (fir::parseSelector(parser, result, selector, type)) return mlir::failure(); llvm::SmallVector ivalues; llvm::SmallVector dests; llvm::SmallVector> destArgs; while (true) { mlir::Attribute ivalue; // Integer or Unit mlir::Block *dest; llvm::SmallVector destArg; mlir::NamedAttrList temp; if (parser.parseAttribute(ivalue, "i", temp) || parser.parseComma() || parser.parseSuccessorAndUseList(dest, destArg)) return mlir::failure(); ivalues.push_back(ivalue); dests.push_back(dest); destArgs.push_back(destArg); if (!parser.parseOptionalRSquare()) break; if (parser.parseComma()) return mlir::failure(); } auto &bld = parser.getBuilder(); result.addAttribute(casesAttr, bld.getArrayAttr(ivalues)); llvm::SmallVector argOffs; int32_t sumArgs = 0; const auto count = dests.size(); for (std::remove_const_t i = 0; i != count; ++i) { result.addSuccessors(dests[i]); result.addOperands(destArgs[i]); auto argSize = destArgs[i].size(); argOffs.push_back(argSize); sumArgs += argSize; } result.addAttribute(operandSegmentAttr, bld.getDenseI32ArrayAttr({1, 0, sumArgs})); result.addAttribute(getTargetOffsetAttr(), bld.getDenseI32ArrayAttr(argOffs)); return mlir::success(); } template static void 
printIntegralSwitchTerminator(OpT op, mlir::OpAsmPrinter &p) { p << ' '; p.printOperand(op.getSelector()); p << " : " << op.getSelector().getType() << " ["; auto cases = op->template getAttrOfType(op.getCasesAttr()).getValue(); auto count = op.getNumConditions(); for (decltype(count) i = 0; i != count; ++i) { if (i) p << ", "; auto &attr = cases[i]; if (auto intAttr = mlir::dyn_cast_or_null(attr)) p << intAttr.getValue(); else p.printAttribute(attr); p << ", "; op.printSuccessorAtIndex(p, i); } p << ']'; p.printOptionalAttrDict( op->getAttrs(), {op.getCasesAttr(), getCompareOffsetAttr(), getTargetOffsetAttr(), op.getOperandSegmentSizeAttr()}); } //===----------------------------------------------------------------------===// // SelectOp //===----------------------------------------------------------------------===// llvm::LogicalResult fir::SelectOp::verify() { return verifyIntegralSwitchTerminator(*this); } mlir::ParseResult fir::SelectOp::parse(mlir::OpAsmParser &parser, mlir::OperationState &result) { return parseIntegralSwitchTerminator(parser, result, getCasesAttr(), getOperandSegmentSizeAttr()); } void fir::SelectOp::print(mlir::OpAsmPrinter &p) { printIntegralSwitchTerminator(*this, p); } template static A getSubOperands(unsigned pos, A allArgs, mlir::DenseI32ArrayAttr ranges, AdditionalArgs &&...additionalArgs) { unsigned start = 0; for (unsigned i = 0; i < pos; ++i) start += ranges[i]; return allArgs.slice(start, ranges[pos], std::forward(additionalArgs)...); } static mlir::MutableOperandRange getMutableSuccessorOperands(unsigned pos, mlir::MutableOperandRange operands, llvm::StringRef offsetAttr) { mlir::Operation *owner = operands.getOwner(); mlir::NamedAttribute targetOffsetAttr = *owner->getAttrDictionary().getNamed(offsetAttr); return getSubOperands( pos, operands, mlir::cast(targetOffsetAttr.getValue()), mlir::MutableOperandRange::OperandSegment(pos, targetOffsetAttr)); } std::optional fir::SelectOp::getCompareOperands(unsigned) { return {}; } 
std::optional> fir::SelectOp::getCompareOperands(llvm::ArrayRef, unsigned) { return {}; } mlir::SuccessorOperands fir::SelectOp::getSuccessorOperands(unsigned oper) { return mlir::SuccessorOperands(::getMutableSuccessorOperands( oper, getTargetArgsMutable(), getTargetOffsetAttr())); } std::optional> fir::SelectOp::getSuccessorOperands(llvm::ArrayRef operands, unsigned oper) { auto a = (*this)->getAttrOfType(getTargetOffsetAttr()); auto segments = (*this)->getAttrOfType( getOperandSegmentSizeAttr()); return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } std::optional fir::SelectOp::getSuccessorOperands(mlir::ValueRange operands, unsigned oper) { auto a = (*this)->getAttrOfType(getTargetOffsetAttr()); auto segments = (*this)->getAttrOfType( getOperandSegmentSizeAttr()); return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } unsigned fir::SelectOp::targetOffsetSize() { return (*this) ->getAttrOfType(getTargetOffsetAttr()) .size(); } //===----------------------------------------------------------------------===// // SelectCaseOp //===----------------------------------------------------------------------===// std::optional fir::SelectCaseOp::getCompareOperands(unsigned cond) { auto a = (*this)->getAttrOfType(getCompareOffsetAttr()); return {getSubOperands(cond, getCompareArgs(), a)}; } std::optional> fir::SelectCaseOp::getCompareOperands(llvm::ArrayRef operands, unsigned cond) { auto a = (*this)->getAttrOfType(getCompareOffsetAttr()); auto segments = (*this)->getAttrOfType( getOperandSegmentSizeAttr()); return {getSubOperands(cond, getSubOperands(1, operands, segments), a)}; } std::optional fir::SelectCaseOp::getCompareOperands(mlir::ValueRange operands, unsigned cond) { auto a = (*this)->getAttrOfType(getCompareOffsetAttr()); auto segments = (*this)->getAttrOfType( getOperandSegmentSizeAttr()); return {getSubOperands(cond, getSubOperands(1, operands, segments), a)}; } mlir::SuccessorOperands 
fir::SelectCaseOp::getSuccessorOperands(unsigned oper) { return mlir::SuccessorOperands(::getMutableSuccessorOperands( oper, getTargetArgsMutable(), getTargetOffsetAttr())); } std::optional> fir::SelectCaseOp::getSuccessorOperands(llvm::ArrayRef operands, unsigned oper) { auto a = (*this)->getAttrOfType(getTargetOffsetAttr()); auto segments = (*this)->getAttrOfType( getOperandSegmentSizeAttr()); return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } std::optional fir::SelectCaseOp::getSuccessorOperands(mlir::ValueRange operands, unsigned oper) { auto a = (*this)->getAttrOfType(getTargetOffsetAttr()); auto segments = (*this)->getAttrOfType( getOperandSegmentSizeAttr()); return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } // parser for fir.select_case Op mlir::ParseResult fir::SelectCaseOp::parse(mlir::OpAsmParser &parser, mlir::OperationState &result) { mlir::OpAsmParser::UnresolvedOperand selector; mlir::Type type; if (fir::parseSelector(parser, result, selector, type)) return mlir::failure(); llvm::SmallVector attrs; llvm::SmallVector opers; llvm::SmallVector dests; llvm::SmallVector> destArgs; llvm::SmallVector argOffs; std::int32_t offSize = 0; while (true) { mlir::Attribute attr; mlir::Block *dest; llvm::SmallVector destArg; mlir::NamedAttrList temp; if (parser.parseAttribute(attr, "a", temp) || isValidCaseAttr(attr) || parser.parseComma()) return mlir::failure(); attrs.push_back(attr); if (mlir::dyn_cast_or_null(attr)) { argOffs.push_back(0); } else if (mlir::dyn_cast_or_null(attr)) { mlir::OpAsmParser::UnresolvedOperand oper1; mlir::OpAsmParser::UnresolvedOperand oper2; if (parser.parseOperand(oper1) || parser.parseComma() || parser.parseOperand(oper2) || parser.parseComma()) return mlir::failure(); opers.push_back(oper1); opers.push_back(oper2); argOffs.push_back(2); offSize += 2; } else { mlir::OpAsmParser::UnresolvedOperand oper; if (parser.parseOperand(oper) || parser.parseComma()) return mlir::failure(); 
opers.push_back(oper); argOffs.push_back(1); ++offSize; } if (parser.parseSuccessorAndUseList(dest, destArg)) return mlir::failure(); dests.push_back(dest); destArgs.push_back(destArg); if (mlir::succeeded(parser.parseOptionalRSquare())) break; if (parser.parseComma()) return mlir::failure(); } result.addAttribute(fir::SelectCaseOp::getCasesAttr(), parser.getBuilder().getArrayAttr(attrs)); if (parser.resolveOperands(opers, type, result.operands)) return mlir::failure(); llvm::SmallVector targOffs; int32_t toffSize = 0; const auto count = dests.size(); for (std::remove_const_t i = 0; i != count; ++i) { result.addSuccessors(dests[i]); result.addOperands(destArgs[i]); auto argSize = destArgs[i].size(); targOffs.push_back(argSize); toffSize += argSize; } auto &bld = parser.getBuilder(); result.addAttribute(fir::SelectCaseOp::getOperandSegmentSizeAttr(), bld.getDenseI32ArrayAttr({1, offSize, toffSize})); result.addAttribute(getCompareOffsetAttr(), bld.getDenseI32ArrayAttr(argOffs)); result.addAttribute(getTargetOffsetAttr(), bld.getDenseI32ArrayAttr(targOffs)); return mlir::success(); } void fir::SelectCaseOp::print(mlir::OpAsmPrinter &p) { p << ' '; p.printOperand(getSelector()); p << " : " << getSelector().getType() << " ["; auto cases = getOperation()->getAttrOfType(getCasesAttr()).getValue(); auto count = getNumConditions(); for (decltype(count) i = 0; i != count; ++i) { if (i) p << ", "; p << cases[i] << ", "; if (!mlir::isa(cases[i])) { auto caseArgs = *getCompareOperands(i); p.printOperand(*caseArgs.begin()); p << ", "; if (mlir::isa(cases[i])) { p.printOperand(*(++caseArgs.begin())); p << ", "; } } printSuccessorAtIndex(p, i); } p << ']'; p.printOptionalAttrDict(getOperation()->getAttrs(), {getCasesAttr(), getCompareOffsetAttr(), getTargetOffsetAttr(), getOperandSegmentSizeAttr()}); } unsigned fir::SelectCaseOp::compareOffsetSize() { return (*this) ->getAttrOfType(getCompareOffsetAttr()) .size(); } unsigned fir::SelectCaseOp::targetOffsetSize() { return (*this) 
->getAttrOfType(getTargetOffsetAttr()) .size(); } void fir::SelectCaseOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Value selector, llvm::ArrayRef compareAttrs, llvm::ArrayRef cmpOperands, llvm::ArrayRef destinations, llvm::ArrayRef destOperands, llvm::ArrayRef attributes) { result.addOperands(selector); result.addAttribute(getCasesAttr(), builder.getArrayAttr(compareAttrs)); llvm::SmallVector operOffs; int32_t operSize = 0; for (auto attr : compareAttrs) { if (mlir::isa(attr)) { operOffs.push_back(2); operSize += 2; } else if (mlir::isa(attr)) { operOffs.push_back(0); } else { operOffs.push_back(1); ++operSize; } } for (auto ops : cmpOperands) result.addOperands(ops); result.addAttribute(getCompareOffsetAttr(), builder.getDenseI32ArrayAttr(operOffs)); const auto count = destinations.size(); for (auto d : destinations) result.addSuccessors(d); const auto opCount = destOperands.size(); llvm::SmallVector argOffs; std::int32_t sumArgs = 0; for (std::remove_const_t i = 0; i != count; ++i) { if (i < opCount) { result.addOperands(destOperands[i]); const auto argSz = destOperands[i].size(); argOffs.push_back(argSz); sumArgs += argSz; } else { argOffs.push_back(0); } } result.addAttribute(getOperandSegmentSizeAttr(), builder.getDenseI32ArrayAttr({1, operSize, sumArgs})); result.addAttribute(getTargetOffsetAttr(), builder.getDenseI32ArrayAttr(argOffs)); result.addAttributes(attributes); } /// This builder has a slightly simplified interface in that the list of /// operands need not be partitioned by the builder. Instead the operands are /// partitioned here, before being passed to the default builder. This /// partitioning is unchecked, so can go awry on bad input. 
void fir::SelectCaseOp::build(mlir::OpBuilder &builder,
                              mlir::OperationState &result,
                              mlir::Value selector,
                              llvm::ArrayRef compareAttrs,
                              llvm::ArrayRef cmpOpList,
                              llvm::ArrayRef destinations,
                              llvm::ArrayRef destOperands,
                              llvm::ArrayRef attributes) {
  llvm::SmallVector cmpOpers;
  // Walk the flat operand list, slicing off 2, 0 or 1 operands per case
  // according to the kind of its attribute. `iter` is advanced without any
  // bounds check against the end of `cmpOpList`.
  auto iter = cmpOpList.begin();
  for (auto &attr : compareAttrs) {
    if (mlir::isa(attr)) {
      cmpOpers.push_back(mlir::ValueRange({iter, iter + 2}));
      iter += 2;
    } else if (mlir::isa(attr)) {
      cmpOpers.push_back(mlir::ValueRange{});
    } else {
      cmpOpers.push_back(mlir::ValueRange({iter, iter + 1}));
      ++iter;
    }
  }
  // Delegate to the builder that takes the operands grouped per case.
  build(builder, result, selector, compareAttrs, cmpOpers, destinations,
        destOperands, attributes);
}

/// Verify the selector type, that there is at least one successor, that the
/// numbers of conditions, compare operand groups and successor operand groups
/// all equal the number of successors, and that every case attribute is one
/// of the accepted kinds.
llvm::LogicalResult fir::SelectCaseOp::verify() {
  if (!mlir::isa(getSelector().getType()))
    return emitOpError("must be an integer, character, or logical");
  auto cases =
      getOperation()->getAttrOfType(getCasesAttr()).getValue();
  auto count = getNumDest();
  if (count == 0)
    return emitOpError("must have at least one successor");
  if (getNumConditions() != count)
    return emitOpError("number of conditions and successors don't match");
  if (compareOffsetSize() != count)
    return emitOpError("incorrect number of compare operand groups");
  if (targetOffsetSize() != count)
    return emitOpError("incorrect number of successor operand groups");
  for (decltype(count) i = 0; i != count; ++i) {
    auto &attr = cases[i];
    if (!(mlir::isa(attr) || mlir::isa(attr) || mlir::isa(attr) ||
          mlir::isa(attr) || mlir::isa(attr)))
      return emitOpError("incorrect select case attribute type");
  }
  return mlir::success();
}

//===----------------------------------------------------------------------===//
// SelectRankOp
//===----------------------------------------------------------------------===//

/// Verification is delegated to verifyIntegralSwitchTerminator.
llvm::LogicalResult fir::SelectRankOp::verify() {
  return verifyIntegralSwitchTerminator(*this);
}

/// Parsing is delegated to parseIntegralSwitchTerminator, which is given the
/// names of the cases and operand-segment-size attributes.
mlir::ParseResult fir::SelectRankOp::parse(mlir::OpAsmParser &parser,
                                           mlir::OperationState &result) {
  return parseIntegralSwitchTerminator(parser, result,
getCasesAttr(), getOperandSegmentSizeAttr());
}

/// Printing is delegated to printIntegralSwitchTerminator.
void fir::SelectRankOp::print(mlir::OpAsmPrinter &p) {
  printIntegralSwitchTerminator(*this, p);
}

/// Always returns an empty optional.
std::optional
fir::SelectRankOp::getCompareOperands(unsigned) {
  return {};
}

/// Always returns an empty optional.
std::optional>
fir::SelectRankOp::getCompareOperands(llvm::ArrayRef, unsigned) {
  return {};
}

/// Return the operands of successor `oper`, obtained from the mutable target
/// arguments and the target-offset attribute.
mlir::SuccessorOperands fir::SelectRankOp::getSuccessorOperands(unsigned oper) {
  return mlir::SuccessorOperands(::getMutableSuccessorOperands(
      oper, getTargetArgsMutable(), getTargetOffsetAttr()));
}

/// Return the sub-range of `operands` that belongs to successor `oper`:
/// getSubOperands is first applied with index 2 and the operand-segment-size
/// attribute, then with `oper` and the target-offset attribute.
std::optional>
fir::SelectRankOp::getSuccessorOperands(llvm::ArrayRef operands,
                                        unsigned oper) {
  auto a =
      (*this)->getAttrOfType(getTargetOffsetAttr());
  auto segments = (*this)->getAttrOfType(
      getOperandSegmentSizeAttr());
  return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
}

/// ValueRange flavor of the overload above; same two-step getSubOperands
/// selection.
std::optional
fir::SelectRankOp::getSuccessorOperands(mlir::ValueRange operands,
                                        unsigned oper) {
  auto a =
      (*this)->getAttrOfType(getTargetOffsetAttr());
  auto segments = (*this)->getAttrOfType(
      getOperandSegmentSizeAttr());
  return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
}

/// Return the number of entries in the target-offset attribute.
unsigned fir::SelectRankOp::targetOffsetSize() {
  return (*this)
      ->getAttrOfType(getTargetOffsetAttr())
      .size();
}

//===----------------------------------------------------------------------===//
// SelectTypeOp
//===----------------------------------------------------------------------===//

/// Always returns an empty optional.
std::optional
fir::SelectTypeOp::getCompareOperands(unsigned) {
  return {};
}

/// Always returns an empty optional.
std::optional>
fir::SelectTypeOp::getCompareOperands(llvm::ArrayRef, unsigned) {
  return {};
}

/// Return the operands of successor `oper`, obtained from the mutable target
/// arguments and the target-offset attribute.
mlir::SuccessorOperands fir::SelectTypeOp::getSuccessorOperands(unsigned oper) {
  return mlir::SuccessorOperands(::getMutableSuccessorOperands(
      oper, getTargetArgsMutable(), getTargetOffsetAttr()));
}

/// Return the sub-range of `operands` that belongs to successor `oper`:
/// getSubOperands is first applied with index 2 and the operand-segment-size
/// attribute, then with `oper` and the target-offset attribute.
std::optional>
fir::SelectTypeOp::getSuccessorOperands(llvm::ArrayRef operands,
                                        unsigned oper) {
  auto a =
      (*this)->getAttrOfType(getTargetOffsetAttr());
  auto segments = (*this)->getAttrOfType(
      getOperandSegmentSizeAttr());
return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
}

/// ValueRange flavor of getSuccessorOperands: applies getSubOperands with
/// index 2 and the operand-segment-size attribute, then with `oper` and the
/// target-offset attribute.
std::optional
fir::SelectTypeOp::getSuccessorOperands(mlir::ValueRange operands,
                                        unsigned oper) {
  auto a =
      (*this)->getAttrOfType(getTargetOffsetAttr());
  auto segments = (*this)->getAttrOfType(
      getOperandSegmentSizeAttr());
  return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
}

/// Parse the selector (via fir::parseSelector) followed by a list of
/// `<attribute>, <successor>(<args>)` entries separated by commas and closed
/// by `]`. The attributes become the cases attribute; the successors and their
/// arguments are added to `result` together with the operand-segment-size and
/// target-offset attributes.
mlir::ParseResult fir::SelectTypeOp::parse(mlir::OpAsmParser &parser,
                                           mlir::OperationState &result) {
  mlir::OpAsmParser::UnresolvedOperand selector;
  mlir::Type type;
  if (fir::parseSelector(parser, result, selector, type))
    return mlir::failure();

  llvm::SmallVector attrs;
  llvm::SmallVector dests;
  llvm::SmallVector> destArgs;
  while (true) {
    mlir::Attribute attr;
    mlir::Block *dest;
    llvm::SmallVector destArg;
    // `temp` only receives the named attribute produced by parseAttribute; it
    // is not used afterwards.
    mlir::NamedAttrList temp;
    if (parser.parseAttribute(attr, "a", temp) || parser.parseComma() ||
        parser.parseSuccessorAndUseList(dest, destArg))
      return mlir::failure();
    attrs.push_back(attr);
    dests.push_back(dest);
    destArgs.push_back(destArg);
    // A closing `]` terminates the list; otherwise a comma must follow.
    if (mlir::succeeded(parser.parseOptionalRSquare()))
      break;
    if (parser.parseComma())
      return mlir::failure();
  }
  auto &bld = parser.getBuilder();
  result.addAttribute(fir::SelectTypeOp::getCasesAttr(),
                      bld.getArrayAttr(attrs));
  // Per-successor argument counts and their total.
  llvm::SmallVector argOffs;
  int32_t offSize = 0;
  const auto count = dests.size();
  for (std::remove_const_t i = 0; i != count; ++i) {
    result.addSuccessors(dests[i]);
    result.addOperands(destArgs[i]);
    auto argSize = destArgs[i].size();
    argOffs.push_back(argSize);
    offSize += argSize;
  }
  // The middle segment is always empty for this op.
  result.addAttribute(fir::SelectTypeOp::getOperandSegmentSizeAttr(),
                      bld.getDenseI32ArrayAttr({1, 0, offSize}));
  result.addAttribute(getTargetOffsetAttr(), bld.getDenseI32ArrayAttr(argOffs));
  return mlir::success();
}

/// Return the number of entries in the target-offset attribute.
unsigned fir::SelectTypeOp::targetOffsetSize() {
  return (*this)
      ->getAttrOfType(getTargetOffsetAttr())
      .size();
}

/// Print the op as ` <selector> : <type> [<case>, <successor>, ...]` followed
/// by any attributes other than the cases, offset and segment-size attributes,
/// which are elided.
void fir::SelectTypeOp::print(mlir::OpAsmPrinter &p) {
  p << ' ';
  p.printOperand(getSelector());
  p << " : " << getSelector().getType() << " [";
  auto cases =
      getOperation()->getAttrOfType(getCasesAttr()).getValue();
  auto count = getNumConditions();
  for (decltype(count) i = 0; i != count; ++i) {
    if (i)
      p << ", ";
    p << cases[i] << ", ";
    printSuccessorAtIndex(p, i);
  }
  p << ']';
  p.printOptionalAttrDict(getOperation()->getAttrs(),
                          {getCasesAttr(), getCompareOffsetAttr(),
                           getTargetOffsetAttr(),
                           fir::SelectTypeOp::getOperandSegmentSizeAttr()});
}

/// Verify the selector type, that a default guard (if any) is the last case,
/// that there is at least one successor, that the numbers of conditions and
/// successor operand groups equal the number of successors, and that every
/// type guard is one of the accepted attribute kinds.
llvm::LogicalResult fir::SelectTypeOp::verify() {
  if (!mlir::isa(getSelector().getType()))
    return emitOpError("must be a fir.class or fir.box type");
  if (auto boxType = mlir::dyn_cast(getSelector().getType()))
    if (!mlir::isa(boxType.getEleTy()))
      return emitOpError("selector must be polymorphic");
  auto typeGuardAttr = getCases();
  for (unsigned idx = 0; idx < typeGuardAttr.size(); ++idx)
    if (mlir::isa(typeGuardAttr[idx]) &&
        idx != typeGuardAttr.size() - 1)
      return emitOpError("default must be the last attribute");
  auto count = getNumDest();
  if (count == 0)
    return emitOpError("must have at least one successor");
  if (getNumConditions() != count)
    return emitOpError("number of conditions and successors don't match");
  if (targetOffsetSize() != count)
    return emitOpError("incorrect number of successor operand groups");
  for (unsigned i = 0; i != count; ++i) {
    if (!mlir::isa(
            typeGuardAttr[i]))
      return emitOpError("invalid type-case alternative");
  }
  return mlir::success();
}

/// Build a SelectTypeOp from the selector, the type guard attributes and the
/// destination blocks. Destinations beyond the end of `destOperands` get zero
/// block arguments.
void fir::SelectTypeOp::build(mlir::OpBuilder &builder,
                              mlir::OperationState &result,
                              mlir::Value selector,
                              llvm::ArrayRef typeOperands,
                              llvm::ArrayRef destinations,
                              llvm::ArrayRef destOperands,
                              llvm::ArrayRef attributes) {
  result.addOperands(selector);
  result.addAttribute(getCasesAttr(), builder.getArrayAttr(typeOperands));
  const auto count = destinations.size();
  for (mlir::Block *dest : destinations)
    result.addSuccessors(dest);
  // Per-destination argument counts and their total.
  const auto opCount = destOperands.size();
  llvm::SmallVector argOffs;
  int32_t sumArgs = 0;
  for (std::remove_const_t i = 0; i != count; ++i) {
    if (i < opCount) {
result.addOperands(destOperands[i]);
      const auto argSz = destOperands[i].size();
      argOffs.push_back(argSz);
      sumArgs += argSz;
    } else {
      // No operand list was supplied for this destination.
      argOffs.push_back(0);
    }
  }
  // The middle segment is always empty for this op.
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr({1, 0, sumArgs}));
  result.addAttribute(getTargetOffsetAttr(),
                      builder.getDenseI32ArrayAttr(argOffs));
  result.addAttributes(attributes);
}

//===----------------------------------------------------------------------===//
// ShapeOp
//===----------------------------------------------------------------------===//

/// Verify that the rank of the result type equals the number of extents.
llvm::LogicalResult fir::ShapeOp::verify() {
  auto size = getExtents().size();
  auto shapeTy = mlir::dyn_cast(getType());
  assert(shapeTy && "must be a shape type");
  if (shapeTy.getRank() != size)
    return emitOpError("shape type rank mismatch");
  return mlir::success();
}

/// Build a ShapeOp whose result type is a fir::ShapeType with one dimension
/// per extent.
void fir::ShapeOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
                         mlir::ValueRange extents) {
  auto type = fir::ShapeType::get(builder.getContext(), extents.size());
  build(builder, result, type, extents);
}

//===----------------------------------------------------------------------===//
// ShapeShiftOp
//===----------------------------------------------------------------------===//

/// Verify that the operands come in pairs (between 1 and 16 of them) and that
/// the rank of the result type equals the number of pairs.
llvm::LogicalResult fir::ShapeShiftOp::verify() {
  auto size = getPairs().size();
  if (size < 2 || size > 16 * 2)
    return emitOpError("incorrect number of args");
  if (size % 2 != 0)
    return emitOpError("requires a multiple of 2 args");
  auto shapeTy = mlir::dyn_cast(getType());
  assert(shapeTy && "must be a shape shift type");
  if (shapeTy.getRank() * 2 != size)
    return emitOpError("shape type rank mismatch");
  return mlir::success();
}

//===----------------------------------------------------------------------===//
// ShiftOp
//===----------------------------------------------------------------------===//

/// Verify that the rank of the result type equals the number of origins.
llvm::LogicalResult fir::ShiftOp::verify() {
  auto size = getOrigins().size();
  auto shiftTy = mlir::dyn_cast(getType());
  assert(shiftTy && "must be a shift type");
  if
(shiftTy.getRank() != size)
    return emitOpError("shift type rank mismatch");
  return mlir::success();
}

//===----------------------------------------------------------------------===//
// SliceOp
//===----------------------------------------------------------------------===//

/// Build a SliceOp whose result type is a fir::SliceType of rank
/// `trips.size() / 3` (integer division).
void fir::SliceOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
                         mlir::ValueRange trips, mlir::ValueRange path,
                         mlir::ValueRange substr) {
  const auto rank = trips.size() / 3;
  auto sliceTy = fir::SliceType::get(builder.getContext(), rank);
  build(builder, result, sliceTy, trips, path, substr);
}

/// Return the output rank of a slice op. The output rank must be between 1 and
/// the rank of the array being sliced (inclusive).
///
/// Only the element at index 1 of every triple is inspected: the dimension is
/// counted unless that value is defined by the op kind tested below. An empty
/// `triples` yields 0; a non-empty one must yield a positive rank (asserted).
unsigned fir::SliceOp::getOutputRank(mlir::ValueRange triples) {
  unsigned rank = 0;
  if (!triples.empty()) {
    for (unsigned i = 1, end = triples.size(); i < end; i += 3) {
      auto *op = triples[i].getDefiningOp();
      if (!mlir::isa_and_nonnull(op))
        ++rank;
    }
    assert(rank > 0);
  }
  return rank;
}

/// Verify that the operands come in triples (between 1 and 16 of them) and
/// that the rank of the result type equals the number of triples.
llvm::LogicalResult fir::SliceOp::verify() {
  auto size = getTriples().size();
  if (size < 3 || size > 16 * 3)
    return emitOpError("incorrect number of args for triple");
  if (size % 3 != 0)
    return emitOpError("requires a multiple of 3 args");
  auto sliceTy = mlir::dyn_cast(getType());
  assert(sliceTy && "must be a slice type");
  if (sliceTy.getRank() * 3 != size)
    return emitOpError("slice type rank mismatch");
  return mlir::success();
}

//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//

/// Return the element type of `refType` as computed by
/// fir::dyn_cast_ptrEleTy.
mlir::Type fir::StoreOp::elementType(mlir::Type refType) {
  return fir::dyn_cast_ptrEleTy(refType);
}

/// Parse `<value> to <memref> <attr-dict>? : <memref-type>`. The value operand
/// is resolved against the element type derived from the parsed memref type.
mlir::ParseResult fir::StoreOp::parse(mlir::OpAsmParser &parser,
                                      mlir::OperationState &result) {
  mlir::Type type;
  mlir::OpAsmParser::UnresolvedOperand oper;
  mlir::OpAsmParser::UnresolvedOperand store;
  if (parser.parseOperand(oper) || parser.parseKeyword("to") ||
parser.parseOperand(store) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(oper, fir::StoreOp::elementType(type),
                            result.operands) ||
      parser.resolveOperand(store, type, result.operands))
    return mlir::failure();
  return mlir::success();
}

/// Print the op as ` <value> to <memref> <attr-dict> : <memref-type>`.
void fir::StoreOp::print(mlir::OpAsmPrinter &p) {
  p << ' ';
  p.printOperand(getValue());
  p << " to ";
  p.printOperand(getMemref());
  p.printOptionalAttrDict(getOperation()->getAttrs(), {});
  p << " : " << getMemref().getType();
}

/// Verify that the stored value's type equals the element type of the memory
/// reference.
llvm::LogicalResult fir::StoreOp::verify() {
  if (getValue().getType() != fir::dyn_cast_ptrEleTy(getMemref().getType()))
    return emitOpError("store value type must match memory reference type");
  return mlir::success();
}

/// Convenience builder that forwards to the builder taking an extra trailing
/// argument, passing `{}` for it.
void fir::StoreOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
                         mlir::Value value, mlir::Value memref) {
  build(builder, result, value, memref, {});
}

/// Report a write effect on the memref operand, plus whatever
/// addVolatileMemoryEffects adds for the memref type.
void fir::StoreOp::getEffects(
    llvm::SmallVectorImpl<
        mlir::SideEffects::EffectInstance>
        &effects) {
  effects.emplace_back(mlir::MemoryEffects::Write::get(), &getMemrefMutable(),
                       mlir::SideEffects::DefaultResource::get());
  addVolatileMemoryEffects({getMemref().getType()}, effects);
}

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

/// Build a CopyOp; `noOverlap` is turned into a unit attribute when true and
/// a null attribute when false.
void fir::CopyOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
                        mlir::Value source, mlir::Value destination,
                        bool noOverlap) {
  mlir::UnitAttr noOverlapAttr =
      noOverlap ?
builder.getUnitAttr() : mlir::UnitAttr{};
  build(builder, result, source, destination, noOverlapAttr);
}

/// Verify that source and destination have the same type once
/// fir::unwrapRefType has been applied to both.
llvm::LogicalResult fir::CopyOp::verify() {
  mlir::Type sourceType = fir::unwrapRefType(getSource().getType());
  mlir::Type destinationType = fir::unwrapRefType(getDestination().getType());
  if (sourceType != destinationType)
    return emitOpError("source and destination must have the same value type");
  return mlir::success();
}

/// Report a read effect on the source operand and a write effect on the
/// destination operand, plus whatever addVolatileMemoryEffects adds for the
/// two operand types.
void fir::CopyOp::getEffects(
    llvm::SmallVectorImpl<
        mlir::SideEffects::EffectInstance>
        &effects) {
  effects.emplace_back(mlir::MemoryEffects::Read::get(), &getSourceMutable(),
                       mlir::SideEffects::DefaultResource::get());
  effects.emplace_back(mlir::MemoryEffects::Write::get(),
                       &getDestinationMutable(),
                       mlir::SideEffects::DefaultResource::get());
  addVolatileMemoryEffects({getDestination().getType(), getSource().getType()},
                           effects);
}

//===----------------------------------------------------------------------===//
// StringLitOp
//===----------------------------------------------------------------------===//

/// Return the FKind of the element type of the op's result type.
inline fir::CharacterType::KindTy stringLitOpGetKind(fir::StringLitOp op) {
  auto eleTy = mlir::cast(op.getType()).getElementType();
  return mlir::cast(eleTy).getFKind();
}

/// Return true iff the character kind of the literal is not 1.
bool fir::StringLitOp::isWideValue() {
  return stringLitOpGetKind(*this) != 1;
}

/// Create a named attribute `name` holding `v` as a 64-bit integer attribute.
/// NOTE(review): the assertion requires a strictly positive value, while the
/// message in StringLitOp::verify only requires the size to be non-negative;
/// confirm whether a zero length is meant to be accepted here.
static mlir::NamedAttribute
mkNamedIntegerAttr(mlir::OpBuilder &builder, llvm::StringRef name, int64_t v) {
  assert(v > 0);
  return builder.getNamedAttr(
      name, builder.getIntegerAttr(builder.getIntegerType(64), v));
}

/// Build a StringLitOp from a string value. The length attribute is `len`
/// when provided and `inType.getLen()` otherwise.
void fir::StringLitOp::build(mlir::OpBuilder &builder,
                             mlir::OperationState &result,
                             fir::CharacterType inType, llvm::StringRef val,
                             std::optional len) {
  auto valAttr = builder.getNamedAttr(value(), builder.getStringAttr(val));
  int64_t length = len ?
*len : inType.getLen();
  auto lenAttr = mkNamedIntegerAttr(builder, size(), length);
  result.addAttributes({valAttr, lenAttr});
  result.addTypes(inType);
}

/// Convert a list of character codes into an array attribute of integer
/// attributes whose bit width is `8 * sizeof(C)`.
template
static mlir::ArrayAttr convertToArrayAttr(mlir::OpBuilder &builder,
                                          llvm::ArrayRef xlist) {
  llvm::SmallVector attrs;
  auto ty = builder.getIntegerType(8 * sizeof(C));
  for (auto ch : xlist)
    attrs.push_back(builder.getIntegerAttr(ty, ch));
  return builder.getArrayAttr(attrs);
}

/// Build a StringLitOp from a list of character codes, stored under the
/// xlist attribute. The length attribute is `len` when provided and
/// `inType.getLen()` otherwise.
void fir::StringLitOp::build(mlir::OpBuilder &builder,
                             mlir::OperationState &result,
                             fir::CharacterType inType,
                             llvm::ArrayRef vlist,
                             std::optional len) {
  auto valAttr =
      builder.getNamedAttr(xlist(), convertToArrayAttr(builder, vlist));
  std::int64_t length = len ? *len : inType.getLen();
  auto lenAttr = mkNamedIntegerAttr(builder, size(), length);
  result.addAttributes({valAttr, lenAttr});
  result.addTypes(inType);
}

/// Build a StringLitOp from a list of character codes, stored under the
/// xlist attribute. The length attribute is `len` when provided and
/// `inType.getLen()` otherwise.
void fir::StringLitOp::build(mlir::OpBuilder &builder,
                             mlir::OperationState &result,
                             fir::CharacterType inType,
                             llvm::ArrayRef vlist,
                             std::optional len) {
  auto valAttr =
      builder.getNamedAttr(xlist(), convertToArrayAttr(builder, vlist));
  std::int64_t length = len ? *len : inType.getLen();
  auto lenAttr = mkNamedIntegerAttr(builder, size(), length);
  result.addAttributes({valAttr, lenAttr});
  result.addTypes(inType);
}

/// Build a StringLitOp from a list of character codes, stored under the
/// xlist attribute. The length attribute is `len` when provided and
/// `inType.getLen()` otherwise.
void fir::StringLitOp::build(mlir::OpBuilder &builder,
                             mlir::OperationState &result,
                             fir::CharacterType inType,
                             llvm::ArrayRef vlist,
                             std::optional len) {
  auto valAttr =
      builder.getNamedAttr(xlist(), convertToArrayAttr(builder, vlist));
  std::int64_t length = len ?
*len : inType.getLen();
  auto lenAttr = mkNamedIntegerAttr(builder, size(), length);
  result.addAttributes({valAttr, lenAttr});
  result.addTypes(inType);
}

/// Parse `<constant> ( <size> ) : <character-type>`. The constant is stored
/// under the value attribute or the xlist attribute depending on its kind.
/// The result type is rebuilt from the parsed type's FKind and the parsed
/// size, so the length written in the trailing type is not used.
mlir::ParseResult fir::StringLitOp::parse(mlir::OpAsmParser &parser,
                                          mlir::OperationState &result) {
  auto &builder = parser.getBuilder();
  mlir::Attribute val;
  // `attrs` only receives the attribute under the placeholder name "fake";
  // the attribute is re-added to `result` under its real name below.
  mlir::NamedAttrList attrs;
  llvm::SMLoc trailingTypeLoc;
  if (parser.parseAttribute(val, "fake", attrs))
    return mlir::failure();
  if (auto v = mlir::dyn_cast(val))
    result.attributes.push_back(
        builder.getNamedAttr(fir::StringLitOp::value(), v));
  else if (auto v = mlir::dyn_cast(val))
    result.attributes.push_back(
        builder.getNamedAttr(fir::StringLitOp::xlist(), v));
  else if (auto v = mlir::dyn_cast(val))
    result.attributes.push_back(
        builder.getNamedAttr(fir::StringLitOp::xlist(), v));
  else
    return parser.emitError(parser.getCurrentLocation(),
                            "found an invalid constant");
  mlir::IntegerAttr sz;
  mlir::Type type;
  if (parser.parseLParen() ||
      parser.parseAttribute(sz, fir::StringLitOp::size(), result.attributes) ||
      parser.parseRParen() || parser.getCurrentLocation(&trailingTypeLoc) ||
      parser.parseColonType(type))
    return mlir::failure();
  auto charTy = mlir::dyn_cast(type);
  if (!charTy)
    return parser.emitError(trailingTypeLoc, "must have character type");
  type = fir::CharacterType::get(builder.getContext(), charTy.getFKind(),
                                 sz.getInt());
  if (!type || parser.addTypesToList(type, result.types))
    return mlir::failure();
  return mlir::success();
}

/// Print the op as ` <value>(<size>) : <type>`.
void fir::StringLitOp::print(mlir::OpAsmPrinter &p) {
  p << ' ' << getValue() << '(';
  p << mlir::cast(getSize()).getValue() << ") : ";
  p.printType(getType());
}

/// Verify that the size is not negative and, when an xlist attribute is
/// present, that it is either a list whose elements are all integers or the
/// other attribute kind accepted below.
llvm::LogicalResult fir::StringLitOp::verify() {
  if (mlir::cast(getSize()).getValue().isNegative())
    return emitOpError("size must be non-negative");
  if (auto xl = getOperation()->getAttr(fir::StringLitOp::xlist())) {
    if (auto xList = mlir::dyn_cast(xl)) {
      for (auto a : xList)
        if (!mlir::isa(a))
          return emitOpError("values in initializer must be integers");
    } else if
(mlir::isa(xl)) { // do nothing } else { return emitOpError("has unexpected attribute"); } } return mlir::success(); } //===----------------------------------------------------------------------===// // UnboxProcOp //===----------------------------------------------------------------------===// llvm::LogicalResult fir::UnboxProcOp::verify() { if (auto eleTy = fir::dyn_cast_ptrEleTy(getRefTuple().getType())) if (mlir::isa(eleTy)) return mlir::success(); return emitOpError("second output argument has bad type"); } //===----------------------------------------------------------------------===// // IfOp //===----------------------------------------------------------------------===// void fir::IfOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Value cond, bool withElseRegion) { build(builder, result, {}, cond, withElseRegion); } void fir::IfOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, mlir::TypeRange resultTypes, mlir::Value cond, bool withElseRegion) { result.addOperands(cond); result.addTypes(resultTypes); mlir::Region *thenRegion = result.addRegion(); thenRegion->push_back(new mlir::Block()); if (resultTypes.empty()) IfOp::ensureTerminator(*thenRegion, builder, result.location); mlir::Region *elseRegion = result.addRegion(); if (withElseRegion) { elseRegion->push_back(new mlir::Block()); if (resultTypes.empty()) IfOp::ensureTerminator(*elseRegion, builder, result.location); } } // These 3 functions copied from scf.if implementation. /// Given the region at `index`, or the parent operation if `index` is None, /// return the successor regions. These are the regions that may be selected /// during the flow of control. void fir::IfOp::getSuccessorRegions( mlir::RegionBranchPoint point, llvm::SmallVectorImpl ®ions) { // The `then` and the `else` region branch back to the parent operation. if (!point.isParent()) { regions.push_back(mlir::RegionSuccessor(getResults())); return; } // Don't consider the else region if it is empty. 
regions.push_back(mlir::RegionSuccessor(&getThenRegion())); // Don't consider the else region if it is empty. mlir::Region *elseRegion = &this->getElseRegion(); if (elseRegion->empty()) regions.push_back(mlir::RegionSuccessor()); else regions.push_back(mlir::RegionSuccessor(elseRegion)); } void fir::IfOp::getEntrySuccessorRegions( llvm::ArrayRef operands, llvm::SmallVectorImpl ®ions) { FoldAdaptor adaptor(operands); auto boolAttr = mlir::dyn_cast_or_null(adaptor.getCondition()); if (!boolAttr || boolAttr.getValue()) regions.emplace_back(&getThenRegion()); // If the else region is empty, execution continues after the parent op. if (!boolAttr || !boolAttr.getValue()) { if (!getElseRegion().empty()) regions.emplace_back(&getElseRegion()); else regions.emplace_back(getResults()); } } void fir::IfOp::getRegionInvocationBounds( llvm::ArrayRef operands, llvm::SmallVectorImpl &invocationBounds) { if (auto cond = mlir::dyn_cast_or_null(operands[0])) { // If the condition is known, then one region is known to be executed once // and the other zero times. invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0); invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1); } else { // Non-constant condition. Each region may be executed 0 or 1 times. 
invocationBounds.assign(2, {0, 1});
  }
}

/// Parse `<cond> [<weights-keyword> ( <weights> )] [-> <types>] <then-region>
/// [else <else-region>] [<attr-dict>]`. The condition is resolved as an i1
/// value. Both regions are always added to `result`; the else region stays
/// empty when no `else` keyword is present.
mlir::ParseResult fir::IfOp::parse(mlir::OpAsmParser &parser,
                                   mlir::OperationState &result) {
  result.regions.reserve(2);
  mlir::Region *thenRegion = result.addRegion();
  mlir::Region *elseRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  mlir::OpAsmParser::UnresolvedOperand cond;
  mlir::Type i1Type = builder.getIntegerType(1);
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, i1Type, result.operands))
    return mlir::failure();

  // Optional region weights, written as `<keyword>(<i32 array>)`.
  if (mlir::succeeded(
          parser.parseOptionalKeyword(getWeightsAttrAssemblyName()))) {
    if (parser.parseLParen())
      return mlir::failure();
    mlir::DenseI32ArrayAttr weights;
    if (parser.parseCustomAttributeWithFallback(weights, mlir::Type{}))
      return mlir::failure();
    if (weights)
      result.addAttribute(getRegionWeightsAttrName(result.name), weights);
    if (parser.parseRParen())
      return mlir::failure();
  }

  if (parser.parseOptionalArrowTypeList(result.types))
    return mlir::failure();

  if (parser.parseRegion(*thenRegion, {}, {}))
    return mlir::failure();
  fir::IfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
                              result.location);

  if (mlir::succeeded(parser.parseOptionalKeyword("else"))) {
    if (parser.parseRegion(*elseRegion, {}, {}))
      return mlir::failure();
    fir::IfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
                                result.location);
  }

  // Parse the optional attribute list.
if (parser.parseOptionalAttrDict(result.attributes))
    return mlir::failure();
  return mlir::success();
}

/// Verify that an op producing results has a non-empty else region.
llvm::LogicalResult fir::IfOp::verify() {
  if (getNumResults() != 0 && getElseRegion().empty())
    return emitOpError("must have an else block if defining values");
  return mlir::success();
}

/// Print the condition, the optional region weights, the optional result
/// types, the `then` region and, when non-empty, the `else` region. Block
/// terminators are only printed when the op has results. The region-weights
/// attribute is elided from the trailing attribute dictionary.
void fir::IfOp::print(mlir::OpAsmPrinter &p) {
  bool printBlockTerminators = false;
  p << ' ' << getCondition();
  if (auto weights = getRegionWeightsAttr()) {
    p << ' ' << getWeightsAttrAssemblyName() << '(';
    p.printStrippedAttrOrType(weights);
    p << ')';
  }
  if (!getResults().empty()) {
    p << " -> (" << getResultTypes() << ')';
    printBlockTerminators = true;
  }
  p << ' ';
  p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false,
                printBlockTerminators);

  // Print the 'else' regions if it exists and has a block.
  auto &otherReg = getElseRegion();
  if (!otherReg.empty()) {
    p << " else ";
    p.printRegion(otherReg, /*printEntryBlockArgs=*/false,
                  printBlockTerminators);
  }
  p.printOptionalAttrDict((*this)->getAttrs(),
                          /*elideAttrs=*/{getRegionWeightsAttrName()});
}

/// Append to `results` the operand at position `resultNum` of the terminator
/// of the `then` region's first block and of the `else` region's first block,
/// skipping a terminator that has too few operands. Both regions are accessed
/// with front() without an emptiness check.
void fir::IfOp::resultToSourceOps(llvm::SmallVectorImpl &results,
                                  unsigned resultNum) {
  auto *term = getThenRegion().front().getTerminator();
  if (resultNum < term->getNumOperands())
    results.push_back(term->getOperand(resultNum));
  term = getElseRegion().front().getTerminator();
  if (resultNum < term->getNumOperands())
    results.push_back(term->getOperand(resultNum));
}

//===----------------------------------------------------------------------===//
// BoxOffsetOp
//===----------------------------------------------------------------------===//

/// Verify the box_ref operand type and the addressed field. The element type
/// of the operand must be one of the two box kinds tested below; only the
/// base_addr and derived_type fields may be addressed, and derived_type
/// additionally requires a box for which fir::boxHasAddendum holds.
llvm::LogicalResult fir::BoxOffsetOp::verify() {
  auto boxType = mlir::dyn_cast_or_null(
      fir::dyn_cast_ptrEleTy(getBoxRef().getType()));
  mlir::Type boxCharType;
  if (!boxType) {
    // Not the first box kind: fall back to the second one.
    boxCharType = mlir::dyn_cast_or_null(
        fir::dyn_cast_ptrEleTy(getBoxRef().getType()));
    if (!boxCharType)
      return emitOpError("box_ref operand must have !fir.ref> or "
                         "!fir.ref> type");
    if (getField() ==
fir::BoxFieldAttr::derived_type)
      return emitOpError("cannot address derived_type field of a fir.boxchar");
  }
  if (getField() != fir::BoxFieldAttr::base_addr &&
      getField() != fir::BoxFieldAttr::derived_type)
    return emitOpError("cannot address provided field");
  if (getField() == fir::BoxFieldAttr::derived_type) {
    if (!fir::boxHasAddendum(boxType))
      return emitOpError("can only address derived_type field of derived type "
                         "or unlimited polymorphic fir.box");
  }
  return mlir::success();
}

/// Build a BoxOffsetOp, deriving the result type from the operand type and
/// the requested field: an LLVM pointer to a reference to the value type for
/// base_addr, an LLVM pointer to a type descriptor for derived_type, and the
/// unwrapped value type itself for any other field.
void fir::BoxOffsetOp::build(mlir::OpBuilder &builder,
                             mlir::OperationState &result, mlir::Value boxRef,
                             fir::BoxFieldAttr field) {
  mlir::Type valueType =
      fir::unwrapPassByRefType(fir::unwrapRefType(boxRef.getType()));
  mlir::Type resultType = valueType;
  if (field == fir::BoxFieldAttr::base_addr)
    resultType = fir::LLVMPointerType::get(fir::ReferenceType::get(valueType));
  else if (field == fir::BoxFieldAttr::derived_type)
    resultType = fir::LLVMPointerType::get(
        fir::TypeDescType::get(fir::unwrapSequenceType(valueType)));
  build(builder, result, {resultType}, boxRef, field);
}

//===----------------------------------------------------------------------===//

/// Return success if `attr` is one of the accepted case attribute kinds and
/// failure otherwise.
mlir::ParseResult fir::isValidCaseAttr(mlir::Attribute attr) {
  if (mlir::isa(attr))
    return mlir::success();
  return mlir::failure();
}

/// Return the number of compare operands consumed by the cases that precede
/// index `dest`. Each case contributes 0, 1 or 2 operands depending on the
/// kind of its attribute.
unsigned fir::getCaseArgumentOffset(llvm::ArrayRef cases,
                                    unsigned dest) {
  unsigned o = 0;
  for (unsigned i = 0; i < dest; ++i) {
    auto &attr = cases[i];
    if (!mlir::dyn_cast_or_null(attr)) {
      ++o;
      if (mlir::dyn_cast_or_null(attr))
        ++o;
    }
  }
  return o;
}

/// Parse `<selector> : <type> [`, resolving the selector operand into
/// `result.operands`. The opening square bracket is consumed.
mlir::ParseResult
fir::parseSelector(mlir::OpAsmParser &parser, mlir::OperationState &result,
                   mlir::OpAsmParser::UnresolvedOperand &selector,
                   mlir::Type &type) {
  if (parser.parseOperand(selector) || parser.parseColonType(type) ||
      parser.resolveOperand(selector, type, result.operands) ||
      parser.parseLSquare())
    return mlir::failure();
  return mlir::success();
}

/// Return the function named `name`, looking first in `symbolTable` (when
/// provided) and then in `module`. If it does not exist, create it at the end
/// of the module body with private visibility.
mlir::func::FuncOp fir::createFuncOp(mlir::Location loc, mlir::ModuleOp module,
                                     llvm::StringRef name,
mlir::FunctionType type,
                                     llvm::ArrayRef attrs,
                                     const mlir::SymbolTable *symbolTable) {
  if (symbolTable)
    if (auto f = symbolTable->lookup(name)) {
      // The cross-check against the module is only compiled in when
      // EXPENSIVE_CHECKS is defined.
#ifdef EXPENSIVE_CHECKS
      assert(f == module.lookupSymbol(name) &&
             "symbolTable and module out of sync");
#endif
      return f;
    }
  if (auto f = module.lookupSymbol(name))
    return f;
  mlir::OpBuilder modBuilder(module.getBodyRegion());
  modBuilder.setInsertionPointToEnd(module.getBody());
  auto result = mlir::func::FuncOp::create(modBuilder, loc, name, type, attrs);
  result.setVisibility(mlir::SymbolTable::Visibility::Private);
  return result;
}

/// Return the global named `name`, looking first in `symbolTable` (when
/// provided) and then in `module`. If it does not exist, create it with
/// private visibility using a builder constructed on the module's body
/// region. Unlike createFuncOp, no insertion point is set explicitly after
/// constructing the builder.
fir::GlobalOp fir::createGlobalOp(mlir::Location loc, mlir::ModuleOp module,
                                  llvm::StringRef name, mlir::Type type,
                                  llvm::ArrayRef attrs,
                                  const mlir::SymbolTable *symbolTable) {
  if (symbolTable)
    if (auto g = symbolTable->lookup(name)) {
      // The cross-check against the module is only compiled in when
      // EXPENSIVE_CHECKS is defined.
#ifdef EXPENSIVE_CHECKS
      assert(g == module.lookupSymbol(name) &&
             "symbolTable and module out of sync");
#endif
      return g;
    }
  if (auto g = module.lookupSymbol(name))
    return g;
  mlir::OpBuilder modBuilder(module.getBodyRegion());
  auto result = fir::GlobalOp::create(modBuilder, loc, name, type, attrs);
  result.setVisibility(mlir::SymbolTable::Visibility::Private);
  return result;
}

/// Return true if any argument attribute dictionary of `func` contains the
/// attribute named by fir::getHostAssocAttrName().
bool fir::hasHostAssociationArgument(mlir::func::FuncOp func) {
  if (auto allArgAttrs = func.getAllArgAttrs())
    for (auto attr : allArgAttrs)
      if (auto dict = mlir::dyn_cast_or_null(attr))
        if (dict.get(fir::getHostAssocAttrName()))
          return true;
  return false;
}

// Test if value's definition has the specified set of
// attributeNames. The value's definition is one of the operations
// that are able to carry the Fortran variable attributes, e.g.
// fir.alloca or fir.allocmem. Function arguments may also represent
// value definitions and carry relevant attributes.
//
// If it is not possible to reach the limited set of definition
// entities from the given value, then the function will return
// std::nullopt.
// Otherwise, the definition is known and the return
// value is computed as:
//   * if checkAny is true, then the function will return true
//     iff any of the attributeNames attributes is set on the definition.
//   * if checkAny is false, then the function will return true
//     iff all of the attributeNames attributes are set on the definition.
static std::optional
valueCheckFirAttributes(mlir::Value value,
                        llvm::ArrayRef attributeNames,
                        bool checkAny) {
  // Compare by attribute name only; attribute values are not inspected.
  auto testAttributeSets = [&](llvm::ArrayRef setAttrs,
                               llvm::ArrayRef checkAttrs) {
    if (checkAny) {
      // Return true iff any of checkAttrs attributes is present
      // in setAttrs set.
      for (llvm::StringRef checkAttrName : checkAttrs)
        if (llvm::any_of(setAttrs, [&](mlir::NamedAttribute setAttr) {
              return setAttr.getName() == checkAttrName;
            }))
          return true;

      return false;
    }

    // Return true iff all attributes from checkAttrs are present
    // in setAttrs set.
    for (mlir::StringRef checkAttrName : checkAttrs)
      if (llvm::none_of(setAttrs, [&](mlir::NamedAttribute setAttr) {
            return setAttr.getName() == checkAttrName;
          }))
        return false;

    return true;
  };
  // If this is a fir.box that was loaded, the fir attributes will be on the
  // related fir.ref creation.
  if (mlir::isa(value.getType()))
    if (auto definingOp = value.getDefiningOp())
      if (auto loadOp = mlir::dyn_cast(definingOp))
        value = loadOp.getMemref();
  // If this is a function argument, look in the argument attributes.
  if (auto blockArg = mlir::dyn_cast(value)) {
    if (blockArg.getOwner() && blockArg.getOwner()->isEntryBlock())
      if (auto funcOp = mlir::dyn_cast(
              blockArg.getOwner()->getParentOp()))
        return testAttributeSets(
            mlir::cast(*funcOp).getArgAttrs(
                blockArg.getArgNumber()),
            attributeNames);

    // If it is not a function argument, the attributes are unknown.
    return std::nullopt;
  }

  if (auto definingOp = value.getDefiningOp()) {
    // If this is an allocated value, look at the allocation attributes.
if (mlir::isa(definingOp) || mlir::isa(definingOp))
      return testAttributeSets(definingOp->getAttrs(), attributeNames);
    // If this is an imported global, look at AddrOfOp and GlobalOp attributes.
    // Both operations are looked at because use/host associated variable (the
    // AddrOfOp) can have ASYNCHRONOUS/VOLATILE attributes even if the ultimate
    // entity (the globalOp) does not have them.
    if (auto addressOfOp = mlir::dyn_cast(definingOp)) {
      if (testAttributeSets(addressOfOp->getAttrs(), attributeNames))
        return true;
      if (auto module = definingOp->getParentOfType())
        if (auto globalOp =
                module.lookupSymbol(addressOfOp.getSymbol()))
          return testAttributeSets(globalOp->getAttrs(), attributeNames);
    }
  }
  // TODO: Construct associated entities attributes. Decide where the fir
  // attributes must be placed/looked for in this case.
  return std::nullopt;
}

/// Return true if the definition of `value` carries any of `attributeNames`.
/// When the definition cannot be determined the answer is conservatively
/// true.
bool fir::valueMayHaveFirAttributes(
    mlir::Value value, llvm::ArrayRef attributeNames) {
  std::optional mayHaveAttr =
      valueCheckFirAttributes(value, attributeNames, /*checkAny=*/true);
  return mayHaveAttr.value_or(true);
}

/// Return true if the definition of `value` is known and carries
/// `attributeName`. When the definition cannot be determined the answer is
/// false.
bool fir::valueHasFirAttribute(mlir::Value value,
                               llvm::StringRef attributeName) {
  std::optional mayHaveAttr =
      valueCheckFirAttributes(value, {attributeName}, /*checkAny=*/false);
  return mayHaveAttr.value_or(false);
}

/// Return true if at least one argument of `func` has the attribute `attr`.
bool fir::anyFuncArgsHaveAttr(mlir::func::FuncOp func, llvm::StringRef attr) {
  for (unsigned i = 0, end = func.getNumArguments(); i < end; ++i)
    if (func.getArgAttr(i, attr))
      return true;
  return false;
}

/// Return the integer value of `value` when it is defined by one of the two
/// constant-like operations checked below and that operation's value is an
/// integer attribute; otherwise return an empty optional. Block arguments
/// have no defining op and therefore yield an empty optional.
std::optional fir::getIntIfConstant(mlir::Value value) {
  if (auto *definingOp = value.getDefiningOp()) {
    if (auto cst = mlir::dyn_cast(definingOp))
      if (auto intAttr = mlir::dyn_cast(cst.getValue()))
        return intAttr.getInt();
    if (auto llConstOp = mlir::dyn_cast(definingOp))
      if (auto attr = mlir::dyn_cast(llConstOp.getValue()))
        return attr.getValue().getSExtValue();
  }
  return {};
}

/// Return true if `v` is a dummy argument: either the result of a declare-like
/// op that has a dummy scope, or a block argument of the entry block of an op
/// of the kind tested at the end of the function.
bool fir::isDummyArgument(mlir::Value v) {
  auto blockArg{mlir::dyn_cast(v)};
  if (!blockArg) {
    auto defOp =
v.getDefiningOp(); if (defOp) { if (auto declareOp = mlir::dyn_cast(defOp)) if (declareOp.getDummyScope()) return true; } return false; } auto *owner{blockArg.getOwner()}; return owner->isEntryBlock() && mlir::isa(owner->getParentOp()); } mlir::Type fir::applyPathToType(mlir::Type eleTy, mlir::ValueRange path) { for (auto i = path.begin(), end = path.end(); eleTy && i < end;) { eleTy = llvm::TypeSwitch(eleTy) .Case([&](fir::RecordType ty) { if (auto *op = (*i++).getDefiningOp()) { if (auto off = mlir::dyn_cast(op)) return ty.getType(off.getFieldName()); if (auto off = mlir::dyn_cast(op)) return ty.getType(fir::toInt(off)); } return mlir::Type{}; }) .Case([&](fir::SequenceType ty) { bool valid = true; const auto rank = ty.getDimension(); for (std::remove_const_t ii = 0; valid && ii < rank; ++ii) valid = i < end && fir::isa_integer((*i++).getType()); return valid ? ty.getEleTy() : mlir::Type{}; }) .Case([&](mlir::TupleType ty) { if (auto *op = (*i++).getDefiningOp()) if (auto off = mlir::dyn_cast(op)) return ty.getType(fir::toInt(off)); return mlir::Type{}; }) .Case([&](mlir::ComplexType ty) { if (fir::isa_integer((*i++).getType())) return ty.getElementType(); return mlir::Type{}; }) .Default([&](const auto &) { return mlir::Type{}; }); } return eleTy; } bool fir::reboxPreservesContinuity(fir::ReboxOp rebox, bool mayHaveNonDefaultLowerBounds, bool checkWhole) { // If slicing is not involved, then the rebox does not affect // the continuity of the array. auto sliceArg = rebox.getSlice(); if (!sliceArg) return true; if (auto sliceOp = mlir::dyn_cast_or_null(sliceArg.getDefiningOp())) return isContiguousArraySlice(sliceOp, checkWhole, rebox.getBox(), mayHaveNonDefaultLowerBounds); return false; } std::optional fir::getAllocaByteSize(fir::AllocaOp alloca, const mlir::DataLayout &dl, const fir::KindMapping &kindMap) { mlir::Type type = alloca.getInType(); // TODO: should use the constant operands when all info is not available in // the type. 
if (!alloca.isDynamic()) if (auto sizeAndAlignment = getTypeSizeAndAlignment(alloca.getLoc(), type, dl, kindMap)) return sizeAndAlignment->first; return std::nullopt; } //===----------------------------------------------------------------------===// // DeclareOp //===----------------------------------------------------------------------===// llvm::LogicalResult fir::DeclareOp::verify() { auto fortranVar = mlir::cast(this->getOperation()); return fortranVar.verifyDeclareLikeOpImpl(getMemref()); } //===----------------------------------------------------------------------===// // PackArrayOp //===----------------------------------------------------------------------===// llvm::LogicalResult fir::PackArrayOp::verify() { mlir::Type arrayType = getArray().getType(); if (!validTypeParams(arrayType, getTypeparams(), /*allowParamsForBox=*/true)) return emitOpError("invalid type parameters"); if (getInnermost() && fir::getBoxRank(arrayType) == 1) return emitOpError( "'innermost' is invalid for 1D arrays, use 'whole' instead"); return mlir::success(); } void fir::PackArrayOp::getEffects( llvm::SmallVectorImpl< mlir::SideEffects::EffectInstance> &effects) { if (getStack()) effects.emplace_back( mlir::MemoryEffects::Allocate::get(), mlir::SideEffects::AutomaticAllocationScopeResource::get()); else effects.emplace_back(mlir::MemoryEffects::Allocate::get(), mlir::SideEffects::DefaultResource::get()); if (!getNoCopy()) effects.emplace_back(mlir::MemoryEffects::Read::get(), mlir::SideEffects::DefaultResource::get()); } static mlir::ParseResult parsePackArrayConstraints(mlir::OpAsmParser &parser, mlir::IntegerAttr &maxSize, mlir::IntegerAttr &maxElementSize, mlir::IntegerAttr &minStride) { mlir::OperationName opName = mlir::OperationName( fir::PackArrayOp::getOperationName(), parser.getContext()); struct { llvm::StringRef name; mlir::IntegerAttr &ref; } attributes[] = { {fir::PackArrayOp::getMaxSizeAttrName(opName), maxSize}, {fir::PackArrayOp::getMaxElementSizeAttrName(opName), 
maxElementSize}, {fir::PackArrayOp::getMinStrideAttrName(opName), minStride}}; mlir::NamedAttrList parsedAttrs; if (succeeded(parser.parseOptionalAttrDict(parsedAttrs))) { for (auto parsedAttr : parsedAttrs) { for (auto opAttr : attributes) { if (parsedAttr.getName() == opAttr.name) opAttr.ref = mlir::cast(parsedAttr.getValue()); } } return mlir::success(); } return mlir::failure(); } static void printPackArrayConstraints(mlir::OpAsmPrinter &p, fir::PackArrayOp &op, const mlir::IntegerAttr &maxSize, const mlir::IntegerAttr &maxElementSize, const mlir::IntegerAttr &minStride) { llvm::SmallVector attributes; if (maxSize) attributes.emplace_back(op.getMaxSizeAttrName(), maxSize); if (maxElementSize) attributes.emplace_back(op.getMaxElementSizeAttrName(), maxElementSize); if (minStride) attributes.emplace_back(op.getMinStrideAttrName(), minStride); p.printOptionalAttrDict(attributes); } //===----------------------------------------------------------------------===// // UnpackArrayOp //===----------------------------------------------------------------------===// llvm::LogicalResult fir::UnpackArrayOp::verify() { if (auto packOp = getTemp().getDefiningOp()) if (getStack() != packOp.getStack()) return emitOpError() << "the pack operation uses different memory for " "the temporary (stack vs heap): " << *packOp.getOperation() << "\n"; return mlir::success(); } void fir::UnpackArrayOp::getEffects( llvm::SmallVectorImpl< mlir::SideEffects::EffectInstance> &effects) { if (getStack()) effects.emplace_back( mlir::MemoryEffects::Free::get(), mlir::SideEffects::AutomaticAllocationScopeResource::get()); else effects.emplace_back(mlir::MemoryEffects::Free::get(), mlir::SideEffects::DefaultResource::get()); if (!getNoCopy()) effects.emplace_back(mlir::MemoryEffects::Write::get(), mlir::SideEffects::DefaultResource::get()); } //===----------------------------------------------------------------------===// // IsContiguousBoxOp 
//===----------------------------------------------------------------------===// namespace { struct SimplifyIsContiguousBoxOp : public mlir::OpRewritePattern { using mlir::OpRewritePattern::OpRewritePattern; mlir::LogicalResult matchAndRewrite(fir::IsContiguousBoxOp op, mlir::PatternRewriter &rewriter) const override; }; } // namespace mlir::LogicalResult SimplifyIsContiguousBoxOp::matchAndRewrite( fir::IsContiguousBoxOp op, mlir::PatternRewriter &rewriter) const { auto boxType = mlir::cast(op.getBox().getType()); // Nothing to do for assumed-rank arrays and !fir.box. if (boxType.isAssumedRank() || fir::isBoxNone(boxType)) return mlir::failure(); if (fir::getBoxRank(boxType) == 0) { // Scalars are always contiguous. mlir::Type i1Type = rewriter.getI1Type(); rewriter.replaceOpWithNewOp( op, i1Type, rewriter.getIntegerAttr(i1Type, 1)); return mlir::success(); } // TODO: support more patterns, e.g. a result of fir.embox without // the slice is contiguous. We can add fir::isSimplyContiguous(box) // that walks def-use to figure it out. return mlir::failure(); } void fir::IsContiguousBoxOp::getCanonicalizationPatterns( mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) { patterns.add(context); } //===----------------------------------------------------------------------===// // BoxTotalElementsOp //===----------------------------------------------------------------------===// namespace { struct SimplifyBoxTotalElementsOp : public mlir::OpRewritePattern { using mlir::OpRewritePattern::OpRewritePattern; mlir::LogicalResult matchAndRewrite(fir::BoxTotalElementsOp op, mlir::PatternRewriter &rewriter) const override; }; } // namespace mlir::LogicalResult SimplifyBoxTotalElementsOp::matchAndRewrite( fir::BoxTotalElementsOp op, mlir::PatternRewriter &rewriter) const { auto boxType = mlir::cast(op.getBox().getType()); // Nothing to do for assumed-rank arrays and !fir.box. 
if (boxType.isAssumedRank() || fir::isBoxNone(boxType)) return mlir::failure(); if (fir::getBoxRank(boxType) == 0) { // Scalar: 1 element. rewriter.replaceOpWithNewOp( op, op.getType(), rewriter.getIntegerAttr(op.getType(), 1)); return mlir::success(); } // TODO: support more cases, e.g. !fir.box>. return mlir::failure(); } void fir::BoxTotalElementsOp::getCanonicalizationPatterns( mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) { patterns.add(context); } //===----------------------------------------------------------------------===// // LocalitySpecifierOp //===----------------------------------------------------------------------===// // TODO This is a copy of omp::PrivateClauseOp::verifiyRegions(). Once we find a // solution to merge both ops into one this duplication will not be needed. See: // https://discourse.llvm.org/t/dialect-for-data-locality-sharing-specifiers-clauses-in-openmp-openacc-and-do-concurrent/86108. llvm::LogicalResult fir::LocalitySpecifierOp::verifyRegions() { mlir::Type argType = getArgType(); auto verifyTerminator = [&](mlir::Operation *terminator, bool yieldsValue) -> llvm::LogicalResult { if (!terminator->getBlock()->getSuccessors().empty()) return llvm::success(); if (!llvm::isa(terminator)) return mlir::emitError(terminator->getLoc()) << "expected exit block terminator to be an `fir.yield` op."; YieldOp yieldOp = llvm::cast(terminator); mlir::TypeRange yieldedTypes = yieldOp.getResults().getTypes(); if (!yieldsValue) { if (yieldedTypes.empty()) return llvm::success(); return mlir::emitError(terminator->getLoc()) << "Did not expect any values to be yielded."; } if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType) return llvm::success(); auto error = mlir::emitError(yieldOp.getLoc()) << "Invalid yielded value. 
Expected type: " << argType << ", got: "; if (yieldedTypes.empty()) error << "None"; else error << yieldedTypes; return error; }; auto verifyRegion = [&](mlir::Region ®ion, unsigned expectedNumArgs, llvm::StringRef regionName, bool yieldsValue) -> llvm::LogicalResult { assert(!region.empty()); if (region.getNumArguments() != expectedNumArgs) return mlir::emitError(region.getLoc()) << "`" << regionName << "`: " << "expected " << expectedNumArgs << " region arguments, got: " << region.getNumArguments(); for (mlir::Block &block : region) { // MLIR will verify the absence of the terminator for us. if (!block.mightHaveTerminator()) continue; if (failed(verifyTerminator(block.getTerminator(), yieldsValue))) return llvm::failure(); } return llvm::success(); }; // Ensure all of the region arguments have the same type for (mlir::Region *region : getRegions()) for (mlir::Type ty : region->getArgumentTypes()) if (ty != argType) return emitError() << "Region argument type mismatch: got " << ty << " expected " << argType << "."; mlir::Region &initRegion = getInitRegion(); if (!initRegion.empty() && failed(verifyRegion(getInitRegion(), /*expectedNumArgs=*/2, "init", /*yieldsValue=*/true))) return llvm::failure(); LocalitySpecifierType dsType = getLocalitySpecifierType(); if (dsType == LocalitySpecifierType::Local && !getCopyRegion().empty()) return emitError("`local` specifiers do not require a `copy` region."); if (dsType == LocalitySpecifierType::LocalInit && getCopyRegion().empty()) return emitError( "`local_init` specifiers require at least a `copy` region."); if (dsType == LocalitySpecifierType::LocalInit && failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy", /*yieldsValue=*/true))) return llvm::failure(); if (!getDeallocRegion().empty() && failed(verifyRegion(getDeallocRegion(), /*expectedNumArgs=*/1, "dealloc", /*yieldsValue=*/false))) return llvm::failure(); return llvm::success(); } // TODO This is a copy of omp::DeclareReductionOp::verifiyRegions(). 
Once we // find a solution to merge both ops into one this duplication will not be // needed. mlir::LogicalResult fir::DeclareReductionOp::verifyRegions() { if (!getAllocRegion().empty()) { for (YieldOp yieldOp : getAllocRegion().getOps()) { if (yieldOp.getResults().size() != 1 || yieldOp.getResults().getTypes()[0] != getType()) return emitOpError() << "expects alloc region to yield a value " "of the reduction type"; } } if (getInitializerRegion().empty()) return emitOpError() << "expects non-empty initializer region"; mlir::Block &initializerEntryBlock = getInitializerRegion().front(); if (initializerEntryBlock.getNumArguments() == 1) { if (!getAllocRegion().empty()) return emitOpError() << "expects two arguments to the initializer region " "when an allocation region is used"; } else if (initializerEntryBlock.getNumArguments() == 2) { if (getAllocRegion().empty()) return emitOpError() << "expects one argument to the initializer region " "when no allocation region is used"; } else { return emitOpError() << "expects one or two arguments to the initializer region"; } for (mlir::Value arg : initializerEntryBlock.getArguments()) if (arg.getType() != getType()) return emitOpError() << "expects initializer region argument to match " "the reduction type"; for (YieldOp yieldOp : getInitializerRegion().getOps()) { if (yieldOp.getResults().size() != 1 || yieldOp.getResults().getTypes()[0] != getType()) return emitOpError() << "expects initializer region to yield a value " "of the reduction type"; } if (getReductionRegion().empty()) return emitOpError() << "expects non-empty reduction region"; mlir::Block &reductionEntryBlock = getReductionRegion().front(); if (reductionEntryBlock.getNumArguments() != 2 || reductionEntryBlock.getArgumentTypes()[0] != reductionEntryBlock.getArgumentTypes()[1] || reductionEntryBlock.getArgumentTypes()[0] != getType()) return emitOpError() << "expects reduction region with two arguments of " "the reduction type"; for (YieldOp yieldOp : 
getReductionRegion().getOps()) { if (yieldOp.getResults().size() != 1 || yieldOp.getResults().getTypes()[0] != getType()) return emitOpError() << "expects reduction region to yield a value " "of the reduction type"; } if (!getAtomicReductionRegion().empty()) { mlir::Block &atomicReductionEntryBlock = getAtomicReductionRegion().front(); if (atomicReductionEntryBlock.getNumArguments() != 2 || atomicReductionEntryBlock.getArgumentTypes()[0] != atomicReductionEntryBlock.getArgumentTypes()[1]) return emitOpError() << "expects atomic reduction region with two " "arguments of the same type"; } if (getCleanupRegion().empty()) return mlir::success(); mlir::Block &cleanupEntryBlock = getCleanupRegion().front(); if (cleanupEntryBlock.getNumArguments() != 1 || cleanupEntryBlock.getArgument(0).getType() != getType()) return emitOpError() << "expects cleanup region with one argument " "of the reduction type"; return mlir::success(); } //===----------------------------------------------------------------------===// // DoConcurrentOp //===----------------------------------------------------------------------===// llvm::LogicalResult fir::DoConcurrentOp::verify() { mlir::Block *body = getBody(); if (body->empty()) return emitOpError("body cannot be empty"); if (!body->mightHaveTerminator() || !mlir::isa(body->getTerminator())) return emitOpError("must be terminated by 'fir.do_concurrent.loop'"); return mlir::success(); } //===----------------------------------------------------------------------===// // DoConcurrentLoopOp //===----------------------------------------------------------------------===// static mlir::ParseResult parseSpecifierList( mlir::OpAsmParser &parser, mlir::OperationState &result, llvm::StringRef specifierKeyword, llvm::StringRef symsAttrName, llvm::SmallVectorImpl ®ionArgs, llvm::SmallVectorImpl ®ionArgTypes, int32_t &numSpecifierOperands, bool isReduce = false) { auto &builder = parser.getBuilder(); llvm::SmallVector specifierOperands; if 
(failed(parser.parseOptionalKeyword(specifierKeyword))) return mlir::success(); std::size_t oldArgTypesSize = regionArgTypes.size(); if (failed(parser.parseLParen())) return mlir::failure(); llvm::SmallVector isByRefVec; llvm::SmallVector spceifierSymbolVec; llvm::SmallVector attributes; if (failed(parser.parseCommaSeparatedList([&]() { if (isReduce) isByRefVec.push_back( parser.parseOptionalKeyword("byref").succeeded()); if (failed(parser.parseAttribute(spceifierSymbolVec.emplace_back()))) return mlir::failure(); if (isReduce && failed(parser.parseAttribute(attributes.emplace_back()))) return mlir::failure(); if (parser.parseOperand(specifierOperands.emplace_back()) || parser.parseArrow() || parser.parseArgument(regionArgs.emplace_back())) return mlir::failure(); return mlir::success(); }))) return mlir::failure(); if (failed(parser.parseColon())) return mlir::failure(); if (failed(parser.parseCommaSeparatedList([&]() { if (failed(parser.parseType(regionArgTypes.emplace_back()))) return mlir::failure(); return mlir::success(); }))) return mlir::failure(); if (regionArgs.size() != regionArgTypes.size()) return parser.emitError(parser.getNameLoc(), "mismatch in number of " + specifierKeyword.str() + " arg and types"); if (failed(parser.parseRParen())) return mlir::failure(); for (auto operandType : llvm::zip_equal(specifierOperands, llvm::drop_begin(regionArgTypes, oldArgTypesSize))) if (parser.resolveOperand(std::get<0>(operandType), std::get<1>(operandType), result.operands)) return mlir::failure(); if (isReduce) result.addAttribute( fir::DoConcurrentLoopOp::getReduceByrefAttrName(result.name), isByRefVec.empty() ? 
nullptr : mlir::DenseBoolArrayAttr::get(builder.getContext(), isByRefVec)); llvm::SmallVector symbolAttrs(spceifierSymbolVec.begin(), spceifierSymbolVec.end()); result.addAttribute(symsAttrName, builder.getArrayAttr(symbolAttrs)); if (isReduce) { llvm::SmallVector arrayAttr(attributes.begin(), attributes.end()); result.addAttribute( fir::DoConcurrentLoopOp::getReduceAttrsAttrName(result.name), builder.getArrayAttr(arrayAttr)); } numSpecifierOperands = specifierOperands.size(); return mlir::success(); } mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser, mlir::OperationState &result) { auto &builder = parser.getBuilder(); // Parse an opening `(` followed by induction variables followed by `)` llvm::SmallVector regionArgs; if (parser.parseArgumentList(regionArgs, mlir::OpAsmParser::Delimiter::Paren)) return mlir::failure(); llvm::SmallVector argTypes(regionArgs.size(), builder.getIndexType()); // Parse loop bounds. llvm::SmallVector lower; if (parser.parseEqual() || parser.parseOperandList(lower, regionArgs.size(), mlir::OpAsmParser::Delimiter::Paren) || parser.resolveOperands(lower, builder.getIndexType(), result.operands)) return mlir::failure(); llvm::SmallVector upper; if (parser.parseKeyword("to") || parser.parseOperandList(upper, regionArgs.size(), mlir::OpAsmParser::Delimiter::Paren) || parser.resolveOperands(upper, builder.getIndexType(), result.operands)) return mlir::failure(); // Parse step values. 
llvm::SmallVector steps; if (parser.parseKeyword("step") || parser.parseOperandList(steps, regionArgs.size(), mlir::OpAsmParser::Delimiter::Paren) || parser.resolveOperands(steps, builder.getIndexType(), result.operands)) return mlir::failure(); int32_t numLocalOperands = 0; if (failed(parseSpecifierList(parser, result, "local", getLocalSymsAttrName(result.name), regionArgs, argTypes, numLocalOperands))) return mlir::failure(); int32_t numReduceOperands = 0; if (failed(parseSpecifierList( parser, result, "reduce", getReduceSymsAttrName(result.name), regionArgs, argTypes, numReduceOperands, /*isReduce=*/true))) return mlir::failure(); // Set `operandSegmentSizes` attribute. result.addAttribute( DoConcurrentLoopOp::getOperandSegmentSizeAttr(), builder.getDenseI32ArrayAttr({static_cast(lower.size()), static_cast(upper.size()), static_cast(steps.size()), static_cast(numLocalOperands), static_cast(numReduceOperands)})); // Now parse the body. for (auto [arg, type] : llvm::zip_equal(regionArgs, argTypes)) arg.type = type; mlir::Region *body = result.addRegion(); if (parser.parseRegion(*body, regionArgs)) return mlir::failure(); // Parse attributes. 
if (parser.parseOptionalAttrDict(result.attributes)) return mlir::failure(); return mlir::success(); } void fir::DoConcurrentLoopOp::print(mlir::OpAsmPrinter &p) { p << " (" << getBody()->getArguments().slice(0, getNumInductionVars()) << ") = (" << getLowerBound() << ") to (" << getUpperBound() << ") step (" << getStep() << ")"; if (!getLocalVars().empty()) { p << " local("; llvm::interleaveComma(llvm::zip_equal(getLocalSymsAttr(), getLocalVars(), getRegionLocalArgs()), p, [&](auto it) { p << std::get<0>(it) << " " << std::get<1>(it) << " -> " << std::get<2>(it); }); p << " : "; llvm::interleaveComma(getLocalVars(), p, [&](auto it) { p << it.getType(); }); p << ")"; } if (!getReduceVars().empty()) { p << " reduce("; llvm::interleaveComma( llvm::zip_equal(getReduceByrefAttr().asArrayRef(), getReduceSymsAttr(), getReduceAttrsAttr(), getReduceVars(), getRegionReduceArgs()), p, [&](auto it) { if (std::get<0>(it)) p << "byref "; p << std::get<1>(it) << " " << std::get<2>(it) << " " << std::get<3>(it) << " -> " << std::get<4>(it); }); p << " : "; llvm::interleaveComma(getReduceVars(), p, [&](auto it) { p << it.getType(); }); p << ")"; } p << ' '; p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); p.printOptionalAttrDict( (*this)->getAttrs(), /*elidedAttrs=*/{DoConcurrentLoopOp::getOperandSegmentSizeAttr(), DoConcurrentLoopOp::getLocalSymsAttrName(), DoConcurrentLoopOp::getReduceSymsAttrName(), DoConcurrentLoopOp::getReduceAttrsAttrName(), DoConcurrentLoopOp::getReduceByrefAttrName()}); } llvm::SmallVector fir::DoConcurrentLoopOp::getLoopRegions() { return {&getRegion()}; } llvm::LogicalResult fir::DoConcurrentLoopOp::verify() { mlir::Operation::operand_range lbValues = getLowerBound(); mlir::Operation::operand_range ubValues = getUpperBound(); mlir::Operation::operand_range stepValues = getStep(); mlir::Operation::operand_range localVars = getLocalVars(); mlir::Operation::operand_range reduceVars = getReduceVars(); if (lbValues.empty()) return emitOpError( "needs 
at least one tuple element for lowerBound, upperBound and step"); if (lbValues.size() != ubValues.size() || ubValues.size() != stepValues.size()) return emitOpError("different number of tuple elements for lowerBound, " "upperBound or step"); // Check that the body defines the same number of block arguments as the // number of tuple elements in step. mlir::Block *body = getBody(); unsigned numIndVarArgs = body->getNumArguments() - localVars.size() - reduceVars.size(); if (numIndVarArgs != stepValues.size()) return emitOpError() << "expects the same number of induction variables: " << body->getNumArguments() << " as bound and step values: " << stepValues.size(); for (auto arg : body->getArguments().slice(0, numIndVarArgs)) if (!arg.getType().isIndex()) return emitOpError( "expects arguments for the induction variable to be of index type"); auto reduceAttrs = getReduceAttrsAttr(); if (getNumReduceOperands() != (reduceAttrs ? reduceAttrs.size() : 0)) return emitOpError( "mismatch in number of reduction variables and reduction attributes"); return mlir::success(); } std::optional> fir::DoConcurrentLoopOp::getLoopInductionVars() { return llvm::SmallVector{ getBody()->getArguments().slice(0, getLowerBound().size())}; } //===----------------------------------------------------------------------===// // FIROpsDialect //===----------------------------------------------------------------------===// void fir::FIROpsDialect::registerOpExternalInterfaces() { // Attach default declare target interfaces to operations which can be marked // as declare target. fir::GlobalOp::attachInterface< mlir::omp::DeclareTargetDefaultModel>(*getContext()); } // Tablegen operators #define GET_OP_CLASSES #include "flang/Optimizer/Dialect/FIROps.cpp.inc"