aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/SCF/IR/SCF.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/SCF/IR/SCF.cpp')
-rw-r--r--mlir/lib/Dialect/SCF/IR/SCF.cpp96
1 files changed, 95 insertions, 1 deletions
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 9bd13f3..744a595 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -27,6 +27,7 @@
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/DebugLog.h"
@@ -291,9 +292,102 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
}
};
+// Pattern to eliminate ExecuteRegionOp results which forward external
+// values from the region. In case there are multiple yield operations,
+// all of them must have the same operands in order for the pattern to be
+// applicable.
+struct ExecuteRegionForwardingEliminator
+ : public OpRewritePattern<ExecuteRegionOp> {
+ using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExecuteRegionOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.getNumResults() == 0)
+ return failure();
+
+ SmallVector<Operation *> yieldOps;
+ for (Block &block : op.getRegion()) {
+ if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator()))
+ yieldOps.push_back(yield.getOperation());
+ }
+
+ if (yieldOps.empty())
+ return failure();
+
+ // Check if all yield operations have the same operands.
+ auto yieldOpsOperands = yieldOps[0]->getOperands();
+ for (auto *yieldOp : yieldOps) {
+ if (yieldOp->getOperands() != yieldOpsOperands)
+ return failure();
+ }
+
+ SmallVector<Value> externalValues;
+ SmallVector<Value> internalValues;
+ SmallVector<Value> opResultsToReplaceWithExternalValues;
+ SmallVector<Value> opResultsToKeep;
+ for (auto [index, yieldedValue] : llvm::enumerate(yieldOpsOperands)) {
+ if (isValueFromInsideRegion(yieldedValue, op)) {
+ internalValues.push_back(yieldedValue);
+ opResultsToKeep.push_back(op.getResult(index));
+ } else {
+ externalValues.push_back(yieldedValue);
+ opResultsToReplaceWithExternalValues.push_back(op.getResult(index));
+ }
+ }
+ // No yielded external values - nothing to do.
+ if (externalValues.empty())
+ return failure();
+
+ // There are yielded external values - create a new execute_region returning
+ // just the internal values.
+ SmallVector<Type> resultTypes;
+ for (Value value : internalValues)
+ resultTypes.push_back(value.getType());
+ auto newOp =
+ ExecuteRegionOp::create(rewriter, op.getLoc(), TypeRange(resultTypes));
+ newOp->setAttrs(op->getAttrs());
+
+ // Move old op's region to the new operation.
+ rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
+ newOp.getRegion().end());
+
+ // Replace all yield operations with a new yield operation with updated
+ // results. scf.execute_region must have at least one yield operation.
+ for (auto *yieldOp : yieldOps) {
+ rewriter.setInsertionPoint(yieldOp);
+ rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp,
+ ValueRange(internalValues));
+ }
+
+ // Replace the old operation with the external values directly.
+ rewriter.replaceAllUsesWith(opResultsToReplaceWithExternalValues,
+ externalValues);
+ // Replace the old operation's remaining results with the new operation's
+ // results.
+ rewriter.replaceAllUsesWith(opResultsToKeep, newOp.getResults());
+ rewriter.eraseOp(op);
+ return success();
+ }
+
+private:
+ bool isValueFromInsideRegion(Value value,
+ ExecuteRegionOp executeRegionOp) const {
+ // Check if the value is defined within the execute_region
+ if (Operation *defOp = value.getDefiningOp())
+ return &executeRegionOp.getRegion() == defOp->getParentRegion();
+
+ // If it's a block argument, check if it's from within the region
+ if (BlockArgument blockArg = dyn_cast<BlockArgument>(value))
+ return &executeRegionOp.getRegion() == blockArg.getParentRegion();
+
+ return false; // Value is from outside the region
+ }
+};
+
void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
+ results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner,
+ ExecuteRegionForwardingEliminator>(context);
}
void ExecuteRegionOp::getSuccessorRegions(