//===-- CUFFunctionRewrite.cpp --------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "flang/Optimizer/CodeGen/TypeConverter.h" #include "flang/Optimizer/Dialect/FIRDialect.h" #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIRType.h" #include "flang/Optimizer/Support/DataLayout.h" #include "flang/Optimizer/Transforms/Passes.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/ValueRange.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/Debug.h" #include #define DEBUG_TYPE "flang-cuf-function-rewrite" namespace fir { #define GEN_PASS_DEF_CUFFUNCTIONREWRITE #include "flang/Optimizer/Transforms/Passes.h.inc" } // namespace fir using namespace mlir; namespace { using genFunctionType = std::function; class CallConversion : public OpRewritePattern { public: CallConversion(MLIRContext *context) : OpRewritePattern(context) {} LogicalResult matchAndRewrite(fir::CallOp op, mlir::PatternRewriter &rewriter) const override { auto callee = op.getCallee(); if (!callee) return failure(); auto name = callee->getRootReference().getValue(); if (genMappings_.contains(name)) { auto fct = genMappings_.find(name); mlir::Value result = fct->second(rewriter, op); if (result) rewriter.replaceOp(op, result); else rewriter.eraseOp(op); return success(); } return failure(); } private: static mlir::Value genOnDevice(mlir::PatternRewriter &rewriter, fir::CallOp op) { assert(op.getArgs().size() == 0 && "expect 0 arguments"); mlir::Location loc = op.getLoc(); unsigned inGPUMod = op->getParentOfType() ? 1 : 0; mlir::Type i1Ty = rewriter.getIntegerType(1); mlir::Value t = mlir::arith::ConstantOp::create( rewriter, loc, i1Ty, rewriter.getIntegerAttr(i1Ty, inGPUMod)); return fir::ConvertOp::create(rewriter, loc, op.getResult(0).getType(), t); } const llvm::StringMap genMappings_ = { {"on_device", &genOnDevice}}; }; class CUFFunctionRewrite : public fir::impl::CUFFunctionRewriteBase { public: void runOnOperation() override { auto *ctx = &getContext(); mlir::RewritePatternSet patterns(ctx); patterns.insert(patterns.getContext()); if (mlir::failed( mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) { mlir::emitError(mlir::UnknownLoc::get(ctx), "error in CUFFunctionRewrite op conversion\n"); signalPassFailure(); } } }; } // namespace