//===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Pass/Pass.h" namespace mlir { #define GEN_PASS_DEF_CONVERTOPENMPTOLLVMPASS #include "mlir/Conversion/Passes.h.inc" } // namespace mlir using namespace mlir; namespace { /// A pattern that converts the result and operand types, attributes, and region /// arguments of an OpenMP operation to the LLVM dialect. /// /// Attributes are copied verbatim by default, and only translated if they are /// type attributes. /// /// Region bodies, if any, are not modified and expected to either be processed /// by the conversion infrastructure or already contain ops compatible with LLVM /// dialect types. template struct OpenMPOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; OpenMPOpConversion(LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1) : ConvertOpToLLVMPattern(typeConverter, benefit) { // Operations using CanonicalLoopInfoType are lowered only by // mlir::translateModuleToLLVMIR() using the OpenMPIRBuilder. Until then, // the type and operations using it must be preserved. typeConverter.addConversion( [&](::mlir::omp::CanonicalLoopInfoType type) { return type; }); } LogicalResult matchAndRewrite(T op, typename T::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Translate result types. const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); SmallVector resTypes; if (failed(converter->convertTypes(op->getResultTypes(), resTypes))) return failure(); // Translate type attributes. // They are kept unmodified except if they are type attributes. SmallVector convertedAttrs; for (NamedAttribute attr : op->getAttrs()) { if (auto typeAttr = dyn_cast(attr.getValue())) { Type convertedType = converter->convertType(typeAttr.getValue()); convertedAttrs.emplace_back(attr.getName(), TypeAttr::get(convertedType)); } else { convertedAttrs.push_back(attr); } } // Translate operands. SmallVector convertedOperands; convertedOperands.reserve(op->getNumOperands()); for (auto [originalOperand, convertedOperand] : llvm::zip_equal(op->getOperands(), adaptor.getOperands())) { if (!originalOperand) return failure(); // TODO: Revisit whether we need to trigger an error specifically for this // set of operations. Consider removing this check or updating the list. if constexpr (llvm::is_one_of::value) { if (isa(originalOperand.getType())) { // TODO: Support memref type in variable operands return rewriter.notifyMatchFailure(op, "memref is not supported yet"); } } convertedOperands.push_back(convertedOperand); } // Create new operation. auto newOp = T::create(rewriter, op.getLoc(), resTypes, convertedOperands, convertedAttrs); // Translate regions. for (auto [originalRegion, convertedRegion] : llvm::zip_equal(op->getRegions(), newOp->getRegions())) { rewriter.inlineRegionBefore(originalRegion, convertedRegion, convertedRegion.end()); if (failed(rewriter.convertRegionTypes(&convertedRegion, *this->getTypeConverter()))) return failure(); } // Delete old operation and replace result uses with those of the new one. rewriter.replaceOp(op, newOp->getResults()); return success(); } }; } // namespace void mlir::configureOpenMPToLLVMConversionLegality( ConversionTarget &target, const LLVMTypeConverter &typeConverter) { target.addDynamicallyLegalOp< #define GET_OP_LIST #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" >([&](Operation *op) { return typeConverter.isLegal(op->getOperandTypes()) && typeConverter.isLegal(op->getResultTypes()) && llvm::all_of(op->getRegions(), [&](Region ®ion) { return typeConverter.isLegal(®ion); }) && llvm::all_of(op->getAttrs(), [&](NamedAttribute attr) { auto typeAttr = dyn_cast(attr.getValue()); return !typeAttr || typeConverter.isLegal(typeAttr.getValue()); }); }); } /// Add an `OpenMPOpConversion` conversion pattern for each operation type /// passed as template argument. template static inline RewritePatternSet & addOpenMPOpConversions(LLVMTypeConverter &converter, RewritePatternSet &patterns) { return patterns.add...>(converter); } void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns) { // This type is allowed when converting OpenMP to LLVM Dialect, it carries // bounds information for map clauses and the operation and type are // discarded on lowering to LLVM-IR from the OpenMP dialect. converter.addConversion( [&](omp::MapBoundsType type) -> Type { return type; }); // Add conversions for all OpenMP operations. addOpenMPOpConversions< #define GET_OP_LIST #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" >(converter, patterns); } namespace { struct ConvertOpenMPToLLVMPass : public impl::ConvertOpenMPToLLVMPassBase { using Base::Base; void runOnOperation() override; }; } // namespace void ConvertOpenMPToLLVMPass::runOnOperation() { auto module = getOperation(); // Convert to OpenMP operations with LLVM IR dialect RewritePatternSet patterns(&getContext()); LLVMTypeConverter converter(&getContext()); arith::populateArithToLLVMConversionPatterns(converter, patterns); cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); cf::populateAssertToLLVMConversionPattern(converter, patterns); populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns); populateFuncToLLVMConversionPatterns(converter, patterns); populateOpenMPToLLVMConversionPatterns(converter, patterns); LLVMConversionTarget target(getContext()); target.addLegalOp(); configureOpenMPToLLVMConversionLegality(target, converter); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } //===----------------------------------------------------------------------===// // ConvertToLLVMPatternInterface implementation //===----------------------------------------------------------------------===// namespace { /// Implement the interface to convert OpenMP to LLVM. struct OpenMPToLLVMDialectInterface : public ConvertToLLVMPatternInterface { using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; void loadDependentDialects(MLIRContext *context) const final { context->loadDialect(); } /// Hook for derived dialect interface to provide conversion patterns /// and mark dialect legal for the conversion target. void populateConvertToLLVMConversionPatterns( ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const final { configureOpenMPToLLVMConversionLegality(target, typeConverter); populateOpenMPToLLVMConversionPatterns(typeConverter, patterns); } }; } // namespace void mlir::registerConvertOpenMPToLLVMInterface(DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) { dialect->addInterfaces(); }); }