//===-- LLVMInsertChainFolder.cpp -----------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "flang/Optimizer/CodeGen/LLVMInsertChainFolder.h" #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Builders.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "flang-insert-folder" #include namespace { // Helper class to construct the attribute elements of an aggregate value being // folded without creating a full mlir::Attribute representation for each step // of the insert value chain, which would both be expensive in terms of // compilation time and memory (since the intermediate Attribute would survive, // unused, inside the mlir context). class InsertChainBackwardFolder { // Type for the current value of an element of the aggregate value being // constructed by the insert chain. // At any point of the insert chain, the value of an element is either: // - nullptr: not yet known, the insert has not yet been seen. // - an mlir::Attribute: the element is fully defined. // - a nested InsertChainBackwardFolder: the element is itself an aggregate // and its sub-elements have been partially defined (insert with mutliple // indices have been seen). // The insertion folder assumes backward walk of the insert chain. Once an // element or sub-element has been defined, it is not overriden by new // insertions (last insert wins). using InFlightValue = llvm::PointerUnion; public: InsertChainBackwardFolder( mlir::Type type, std::deque *folderStorage) : values(getNumElements(type), mlir::Attribute{}), folderStorage{folderStorage}, type{type} {} /// Push bool pushValue(mlir::Attribute val, llvm::ArrayRef at); mlir::Attribute finalize(mlir::Attribute defaultFieldValue); private: static int64_t getNumElements(mlir::Type type) { if (auto structTy = llvm::dyn_cast_if_present(type)) return structTy.getBody().size(); if (auto arrayTy = llvm::dyn_cast_if_present(type)) return arrayTy.getNumElements(); return 0; } static mlir::Type getSubElementType(mlir::Type type, int64_t field) { if (auto arrayTy = llvm::dyn_cast_if_present(type)) return arrayTy.getElementType(); if (auto structTy = llvm::dyn_cast_if_present(type)) return structTy.getBody()[field]; return nullptr; } // Current element value of the aggregate value being built. llvm::SmallVector values; // std::deque is used to allocate storage for nested list and guarantee the // stability of the InsertChainBackwardFolder* used as element value. std::deque *folderStorage; // Type of the aggregate value being built. mlir::Type type; }; } // namespace // Helper to fold the value being inserted by an llvm.insert_value. // This may call tryFoldingLLVMInsertChain if the value is an aggregate and // was itself constructed by a different insert chain. // Returns a nullptr Attribute if the value could not be folded. static mlir::Attribute getAttrIfConstant(mlir::Value val, mlir::OpBuilder &rewriter) { if (auto cst = val.getDefiningOp()) return cst.getValue(); if (auto insert = val.getDefiningOp()) { llvm::FailureOr attr = fir::tryFoldingLLVMInsertChain(val, rewriter); if (succeeded(attr)) return *attr; return nullptr; } if (val.getDefiningOp()) return mlir::LLVM::ZeroAttr::get(val.getContext()); if (val.getDefiningOp()) return mlir::LLVM::UndefAttr::get(val.getContext()); if (mlir::Operation *op = val.getDefiningOp()) { unsigned resNum = llvm::cast(val).getResultNumber(); llvm::SmallVector results; if (mlir::succeeded(rewriter.tryFold(op, results)) && results.size() > resNum) { if (auto cst = results[resNum].getDefiningOp()) return cst.getValue(); } } if (auto trunc = val.getDefiningOp()) if (auto attr = getAttrIfConstant(trunc.getArg(), rewriter)) if (auto intAttr = llvm::dyn_cast(attr)) return mlir::IntegerAttr::get(trunc.getType(), intAttr.getInt()); LLVM_DEBUG(llvm::dbgs() << "cannot fold insert value operand: " << val << "\n"); return nullptr; } mlir::Attribute InsertChainBackwardFolder::finalize(mlir::Attribute defaultFieldValue) { llvm::SmallVector attrs = llvm::map_to_vector( values, [&](InFlightValue inFlight) -> mlir::Attribute { if (!inFlight) return defaultFieldValue; if (auto attr = llvm::dyn_cast(inFlight)) return attr; return llvm::cast(inFlight)->finalize( defaultFieldValue); }); return mlir::ArrayAttr::get(type.getContext(), attrs); } bool InsertChainBackwardFolder::pushValue(mlir::Attribute val, llvm::ArrayRef at) { if (at.size() == 0 || at[0] >= static_cast(values.size())) return false; InFlightValue &inFlight = values[at[0]]; if (!inFlight) { if (at.size() == 1) { inFlight = val; return true; } // This is the first insert to a nested field. Create a // InsertChainBackwardFolder for the current element value. mlir::Type subType = getSubElementType(type, at[0]); if (!subType) return false; InsertChainBackwardFolder &inFlightList = folderStorage->emplace_back(subType, folderStorage); inFlight = &inFlightList; return inFlightList.pushValue(val, at.drop_front()); } // Keep last inserted value if already set. if (llvm::isa(inFlight)) return true; auto *inFlightList = llvm::cast(inFlight); if (at.size() == 1) { if (!llvm::isa(val)) { LLVM_DEBUG(llvm::dbgs() << "insert chain sub-element partially overwritten initial " "value is not zero or undef\n"); return false; } inFlight = inFlightList->finalize(val); return true; } return inFlightList->pushValue(val, at.drop_front()); } llvm::FailureOr fir::tryFoldingLLVMInsertChain(mlir::Value val, mlir::OpBuilder &rewriter) { if (auto cst = val.getDefiningOp()) return cst.getValue(); if (auto insert = val.getDefiningOp()) { LLVM_DEBUG(llvm::dbgs() << "trying to fold insert chain:" << val << "\n"); if (auto structTy = llvm::dyn_cast(insert.getType())) { mlir::LLVM::InsertValueOp currentInsert = insert; mlir::LLVM::InsertValueOp lastInsert; std::deque folderStorage; InsertChainBackwardFolder inFlightList(structTy, &folderStorage); while (currentInsert) { mlir::Attribute attr = getAttrIfConstant(currentInsert.getValue(), rewriter); if (!attr) return llvm::failure(); if (!inFlightList.pushValue(attr, currentInsert.getPosition())) return llvm::failure(); lastInsert = currentInsert; currentInsert = currentInsert.getContainer() .getDefiningOp(); } mlir::Attribute defaultVal; if (lastInsert) { if (lastInsert.getContainer().getDefiningOp()) defaultVal = mlir::LLVM::ZeroAttr::get(val.getContext()); else if (lastInsert.getContainer().getDefiningOp()) defaultVal = mlir::LLVM::UndefAttr::get(val.getContext()); } if (!defaultVal) { LLVM_DEBUG(llvm::dbgs() << "insert chain initial value is not Zero or Undef\n"); return llvm::failure(); } return inFlightList.finalize(defaultVal); } } return llvm::failure(); }