diff options
Diffstat (limited to 'mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp')
| -rw-r--r-- | mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 68 |
1 files changed, 68 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index b2f1d84..8c9c137 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -1042,6 +1042,65 @@ struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> { } }; +/// Remove empty acc.kernel_environment operations. If the operation has wait +/// operands, create a acc.wait operation to preserve synchronization. +struct RemoveEmptyKernelEnvironment + : public OpRewritePattern<acc::KernelEnvironmentOp> { + using OpRewritePattern<acc::KernelEnvironmentOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(acc::KernelEnvironmentOp op, + PatternRewriter &rewriter) const override { + assert(op->getNumRegions() == 1 && "expected op to have one region"); + + Block &block = op.getRegion().front(); + if (!block.empty()) + return failure(); + + // Conservatively disable canonicalization of empty acc.kernel_environment + // operations if the wait operands in the kernel_environment cannot be fully + // represented by acc.wait operation. + + // Disable canonicalization if device type is not the default + if (auto deviceTypeAttr = op.getWaitOperandsDeviceTypeAttr()) { + for (auto attr : deviceTypeAttr) { + if (auto dtAttr = mlir::dyn_cast<acc::DeviceTypeAttr>(attr)) { + if (dtAttr.getValue() != mlir::acc::DeviceType::None) + return failure(); + } + } + } + + // Disable canonicalization if any wait segment has a devnum + if (auto hasDevnumAttr = op.getHasWaitDevnumAttr()) { + for (auto attr : hasDevnumAttr) { + if (auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>(attr)) { + if (boolAttr.getValue()) + return failure(); + } + } + } + + // Disable canonicalization if there are multiple wait segments + if (auto segmentsAttr = op.getWaitOperandsSegmentsAttr()) { + if (segmentsAttr.size() > 1) + return failure(); + } + + // Remove empty kernel environment. + // Preserve synchronization by creating acc.wait operation if needed. + if (!op.getWaitOperands().empty() || op.getWaitOnlyAttr()) + rewriter.replaceOpWithNewOp<acc::WaitOp>(op, op.getWaitOperands(), + /*asyncOperand=*/Value(), + /*waitDevnum=*/Value(), + /*async=*/nullptr, + /*ifCond=*/Value()); + else + rewriter.eraseOp(op); + + return success(); + } +}; + //===----------------------------------------------------------------------===// // Recipe Region Helpers //===----------------------------------------------------------------------===// @@ -2691,6 +2750,15 @@ void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results, } //===----------------------------------------------------------------------===// +// KernelEnvironmentOp +//===----------------------------------------------------------------------===// + +void acc::KernelEnvironmentOp::getCanonicalizationPatterns( + RewritePatternSet &results, MLIRContext *context) { + results.add<RemoveEmptyKernelEnvironment>(context); +} + +//===----------------------------------------------------------------------===// // LoopOp //===----------------------------------------------------------------------===// |
