aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp')
-rw-r--r--mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp177
1 files changed, 172 insertions, 5 deletions
diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
index 1f239aa..519d9c8 100644
--- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
+++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
@@ -21,6 +21,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
+#include "llvm/Support/LogicalResult.h"
namespace mlir {
#define GEN_PASS_DEF_SCFTOEMITC
@@ -106,7 +107,7 @@ static void assignValues(ValueRange values, ValueRange variables,
emitc::AssignOp::create(rewriter, loc, var, value);
}
-SmallVector<Value> loadValues(const SmallVector<Value> &variables,
+SmallVector<Value> loadValues(ArrayRef<Value> variables,
PatternRewriter &rewriter, Location loc) {
return llvm::map_to_vector<>(variables, [&](Value var) {
Type type = cast<emitc::LValueType>(var.getType()).getValueType();
@@ -116,16 +117,15 @@ SmallVector<Value> loadValues(const SmallVector<Value> &variables,
static LogicalResult lowerYield(Operation *op, ValueRange resultVariables,
ConversionPatternRewriter &rewriter,
- scf::YieldOp yield) {
+ scf::YieldOp yield, bool createYield = true) {
Location loc = yield.getLoc();
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(yield);
SmallVector<Value> yieldOperands;
- if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands))) {
+ if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands)))
return rewriter.notifyMatchFailure(op, "failed to lower yield operands");
- }
assignValues(yieldOperands, resultVariables, rewriter, loc);
@@ -336,11 +336,177 @@ LogicalResult IndexSwitchOpLowering::matchAndRewrite(
return success();
}
+// Lower scf::while to emitc::do using mutable variables to maintain loop state
+// across iterations. The do-while structure ensures the condition is evaluated
+// after each iteration, matching SCF while semantics.
+struct WhileLowering : public OpConversionPattern<WhileOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(WhileOp whileOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = whileOp.getLoc();
+ MLIRContext *context = loc.getContext();
+
+ // Create an emitc::variable op for each result. These variables will be
+ // assigned to by emitc::assign ops within the loop body.
+ SmallVector<Value> resultVariables;
+ if (failed(createVariablesForResults(whileOp, getTypeConverter(), rewriter,
+ resultVariables)))
+ return rewriter.notifyMatchFailure(whileOp,
+ "Failed to create result variables");
+
+ // Create variable storage for loop-carried values to enable imperative
+ // updates while maintaining SSA semantics at conversion boundaries.
+ SmallVector<Value> loopVariables;
+ if (failed(createVariablesForLoopCarriedValues(
+ whileOp, rewriter, loopVariables, loc, context)))
+ return failure();
+
+ if (failed(lowerDoWhile(whileOp, loopVariables, resultVariables, context,
+ rewriter, loc)))
+ return failure();
+
+ rewriter.setInsertionPointAfter(whileOp);
+
+ // Load the final result values from result variables.
+ SmallVector<Value> finalResults =
+ loadValues(resultVariables, rewriter, loc);
+ rewriter.replaceOp(whileOp, finalResults);
+
+ return success();
+ }
+
+private:
+ // Initialize variables for loop-carried values to enable state updates
+ // across iterations without SSA argument passing.
+ LogicalResult createVariablesForLoopCarriedValues(
+ WhileOp whileOp, ConversionPatternRewriter &rewriter,
+ SmallVectorImpl<Value> &loopVars, Location loc,
+ MLIRContext *context) const {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(whileOp);
+
+ emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
+
+ for (Value init : whileOp.getInits()) {
+ Type convertedType = getTypeConverter()->convertType(init.getType());
+ if (!convertedType)
+ return rewriter.notifyMatchFailure(whileOp, "type conversion failed");
+
+ emitc::VariableOp var = rewriter.create<emitc::VariableOp>(
+ loc, emitc::LValueType::get(convertedType), noInit);
+ rewriter.create<emitc::AssignOp>(loc, var.getResult(), init);
+ loopVars.push_back(var);
+ }
+
+ return success();
+ }
+
+ // Lower scf.while to emitc.do.
+ LogicalResult lowerDoWhile(WhileOp whileOp, ArrayRef<Value> loopVars,
+ ArrayRef<Value> resultVars, MLIRContext *context,
+ ConversionPatternRewriter &rewriter,
+ Location loc) const {
+ // Create a global boolean variable to store the loop condition state.
+ Type i1Type = IntegerType::get(context, 1);
+ auto globalCondition =
+ rewriter.create<emitc::VariableOp>(loc, emitc::LValueType::get(i1Type),
+ emitc::OpaqueAttr::get(context, ""));
+ Value conditionVal = globalCondition.getResult();
+
+ auto loweredDo = rewriter.create<emitc::DoOp>(loc);
+
+ // Convert region types to match the target dialect type system.
+ if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(),
+ *getTypeConverter(), nullptr)) ||
+ failed(rewriter.convertRegionTypes(&whileOp.getAfter(),
+ *getTypeConverter(), nullptr))) {
+ return rewriter.notifyMatchFailure(whileOp,
+ "region types conversion failed");
+ }
+
+ // Prepare the before region (condition evaluation) for merging.
+ Block *beforeBlock = &whileOp.getBefore().front();
+ Block *bodyBlock = rewriter.createBlock(&loweredDo.getBodyRegion());
+ rewriter.setInsertionPointToStart(bodyBlock);
+
+ // Load current variable values to use as initial arguments for the
+ // condition block.
+ SmallVector<Value> replacingValues = loadValues(loopVars, rewriter, loc);
+ rewriter.mergeBlocks(beforeBlock, bodyBlock, replacingValues);
+
+ Operation *condTerminator =
+ loweredDo.getBodyRegion().back().getTerminator();
+ scf::ConditionOp condOp = cast<scf::ConditionOp>(condTerminator);
+ rewriter.setInsertionPoint(condOp);
+
+ // Update result variables with values from scf::condition.
+ SmallVector<Value> conditionArgs;
+ for (Value arg : condOp.getArgs()) {
+ conditionArgs.push_back(rewriter.getRemappedValue(arg));
+ }
+ assignValues(conditionArgs, resultVars, rewriter, loc);
+
+ // Convert scf.condition to condition variable assignment.
+ Value condition = rewriter.getRemappedValue(condOp.getCondition());
+ rewriter.create<emitc::AssignOp>(loc, conditionVal, condition);
+
+ // Wrap body region in conditional to preserve scf semantics. Only create
+ // ifOp if after-region is non-empty.
+ if (whileOp.getAfterBody()->getOperations().size() > 1) {
+ auto ifOp = rewriter.create<emitc::IfOp>(loc, condition, false, false);
+
+ // Prepare the after region (loop body) for merging.
+ Block *afterBlock = &whileOp.getAfter().front();
+ Block *ifBodyBlock = rewriter.createBlock(&ifOp.getBodyRegion());
+
+ // Replacement values for after block using condition op arguments.
+ SmallVector<Value> afterReplacingValues;
+ for (Value arg : condOp.getArgs())
+ afterReplacingValues.push_back(rewriter.getRemappedValue(arg));
+
+ rewriter.mergeBlocks(afterBlock, ifBodyBlock, afterReplacingValues);
+
+ if (failed(lowerYield(whileOp, loopVars, rewriter,
+ cast<scf::YieldOp>(ifBodyBlock->getTerminator()))))
+ return failure();
+ }
+
+ rewriter.eraseOp(condOp);
+
+ // Create condition region that loads from the flag variable.
+ Region &condRegion = loweredDo.getConditionRegion();
+ Block *condBlock = rewriter.createBlock(&condRegion);
+ rewriter.setInsertionPointToStart(condBlock);
+
+ auto exprOp = rewriter.create<emitc::ExpressionOp>(
+ loc, i1Type, conditionVal, /*do_not_inline=*/false);
+ Block *exprBlock = rewriter.createBlock(&exprOp.getBodyRegion());
+
+ // Set up the expression block to load the condition variable.
+ exprBlock->addArgument(conditionVal.getType(), loc);
+ rewriter.setInsertionPointToStart(exprBlock);
+
+ // Load the condition value and yield it as the expression result.
+ Value cond =
+ rewriter.create<emitc::LoadOp>(loc, i1Type, exprBlock->getArgument(0));
+ rewriter.create<emitc::YieldOp>(loc, cond);
+
+ // Yield the expression as the condition region result.
+ rewriter.setInsertionPointToEnd(condBlock);
+ rewriter.create<emitc::YieldOp>(loc, exprOp);
+
+ return success();
+ }
+};
+
void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns,
TypeConverter &typeConverter) {
patterns.add<ForLowering>(typeConverter, patterns.getContext());
patterns.add<IfLowering>(typeConverter, patterns.getContext());
patterns.add<IndexSwitchOpLowering>(typeConverter, patterns.getContext());
+ patterns.add<WhileLowering>(typeConverter, patterns.getContext());
}
void SCFToEmitCPass::runOnOperation() {
@@ -357,7 +523,8 @@ void SCFToEmitCPass::runOnOperation() {
// Configure conversion to lower out SCF operations.
ConversionTarget target(getContext());
- target.addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>();
+ target
+ .addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp, scf::WhileOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))