diff options
Diffstat (limited to 'mlir/lib/Dialect/SCF')
-rw-r--r-- | mlir/lib/Dialect/SCF/IR/SCF.cpp | 104 |
1 files changed, 99 insertions, 5 deletions
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index a9da6c2..744a595 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -27,6 +27,7 @@ #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Casting.h" #include "llvm/Support/DebugLog.h" @@ -291,9 +292,102 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> { } }; +// Pattern to eliminate ExecuteRegionOp results which forward external +// values from the region. In case there are multiple yield operations, +// all of them must have the same operands in order for the pattern to be +// applicable. +struct ExecuteRegionForwardingEliminator + : public OpRewritePattern<ExecuteRegionOp> { + using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(ExecuteRegionOp op, + PatternRewriter &rewriter) const override { + if (op.getNumResults() == 0) + return failure(); + + SmallVector<Operation *> yieldOps; + for (Block &block : op.getRegion()) { + if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator())) + yieldOps.push_back(yield.getOperation()); + } + + if (yieldOps.empty()) + return failure(); + + // Check if all yield operations have the same operands. + auto yieldOpsOperands = yieldOps[0]->getOperands(); + for (auto *yieldOp : yieldOps) { + if (yieldOp->getOperands() != yieldOpsOperands) + return failure(); + } + + SmallVector<Value> externalValues; + SmallVector<Value> internalValues; + SmallVector<Value> opResultsToReplaceWithExternalValues; + SmallVector<Value> opResultsToKeep; + for (auto [index, yieldedValue] : llvm::enumerate(yieldOpsOperands)) { + if (isValueFromInsideRegion(yieldedValue, op)) { + internalValues.push_back(yieldedValue); + opResultsToKeep.push_back(op.getResult(index)); + } else { + externalValues.push_back(yieldedValue); + opResultsToReplaceWithExternalValues.push_back(op.getResult(index)); + } + } + // No yielded external values - nothing to do. + if (externalValues.empty()) + return failure(); + + // There are yielded external values - create a new execute_region returning + // just the internal values. + SmallVector<Type> resultTypes; + for (Value value : internalValues) + resultTypes.push_back(value.getType()); + auto newOp = + ExecuteRegionOp::create(rewriter, op.getLoc(), TypeRange(resultTypes)); + newOp->setAttrs(op->getAttrs()); + + // Move old op's region to the new operation. + rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), + newOp.getRegion().end()); + + // Replace all yield operations with a new yield operation with updated + // results. scf.execute_region must have at least one yield operation. + for (auto *yieldOp : yieldOps) { + rewriter.setInsertionPoint(yieldOp); + rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, + ValueRange(internalValues)); + } + + // Replace the old operation with the external values directly. + rewriter.replaceAllUsesWith(opResultsToReplaceWithExternalValues, + externalValues); + // Replace the old operation's remaining results with the new operation's + // results. + rewriter.replaceAllUsesWith(opResultsToKeep, newOp.getResults()); + rewriter.eraseOp(op); + return success(); + } + +private: + bool isValueFromInsideRegion(Value value, + ExecuteRegionOp executeRegionOp) const { + // Check if the value is defined within the execute_region + if (Operation *defOp = value.getDefiningOp()) + return &executeRegionOp.getRegion() == defOp->getParentRegion(); + + // If it's a block argument, check if it's from within the region + if (BlockArgument blockArg = dyn_cast<BlockArgument>(value)) + return &executeRegionOp.getRegion() == blockArg.getParentRegion(); + + return false; // Value is from outside the region + } +}; + void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context); + results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner, + ExecuteRegionForwardingEliminator>(context); } void ExecuteRegionOp::getSuccessorRegions( @@ -2490,8 +2584,8 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> { changed = true; if (!constantTrue) - constantTrue = rewriter.create<arith::ConstantOp>( - op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1)); + constantTrue = arith::ConstantOp::create( + rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1)); rewriter.modifyOpInPlace(use.getOwner(), [&]() { use.set(constantTrue); }); @@ -2500,8 +2594,8 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> { changed = true; if (!constantFalse) - constantFalse = rewriter.create<arith::ConstantOp>( - op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0)); + constantFalse = arith::ConstantOp::create( + rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0)); rewriter.modifyOpInPlace(use.getOwner(), [&]() { use.set(constantFalse); }); |