diff options
Diffstat (limited to 'mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp')
-rw-r--r-- | mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp | 177 |
1 files changed, 172 insertions, 5 deletions
diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp index 1f239aa..519d9c8 100644 --- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp +++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp @@ -21,6 +21,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" +#include "llvm/Support/LogicalResult.h" namespace mlir { #define GEN_PASS_DEF_SCFTOEMITC @@ -106,7 +107,7 @@ static void assignValues(ValueRange values, ValueRange variables, emitc::AssignOp::create(rewriter, loc, var, value); } -SmallVector<Value> loadValues(const SmallVector<Value> &variables, +SmallVector<Value> loadValues(ArrayRef<Value> variables, PatternRewriter &rewriter, Location loc) { return llvm::map_to_vector<>(variables, [&](Value var) { Type type = cast<emitc::LValueType>(var.getType()).getValueType(); @@ -116,16 +117,15 @@ SmallVector<Value> loadValues(const SmallVector<Value> &variables, static LogicalResult lowerYield(Operation *op, ValueRange resultVariables, ConversionPatternRewriter &rewriter, - scf::YieldOp yield) { + scf::YieldOp yield, bool createYield = true) { Location loc = yield.getLoc(); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(yield); SmallVector<Value> yieldOperands; - if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands))) { + if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands))) return rewriter.notifyMatchFailure(op, "failed to lower yield operands"); - } assignValues(yieldOperands, resultVariables, rewriter, loc); @@ -336,11 +336,177 @@ LogicalResult IndexSwitchOpLowering::matchAndRewrite( return success(); } +// Lower scf::while to emitc::do using mutable variables to maintain loop state +// across iterations. The do-while structure ensures the condition is evaluated +// after each iteration, matching SCF while semantics. +struct WhileLowering : public OpConversionPattern<WhileOp> { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(WhileOp whileOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = whileOp.getLoc(); + MLIRContext *context = loc.getContext(); + + // Create an emitc::variable op for each result. These variables will be + // assigned to by emitc::assign ops within the loop body. + SmallVector<Value> resultVariables; + if (failed(createVariablesForResults(whileOp, getTypeConverter(), rewriter, + resultVariables))) + return rewriter.notifyMatchFailure(whileOp, + "Failed to create result variables"); + + // Create variable storage for loop-carried values to enable imperative + // updates while maintaining SSA semantics at conversion boundaries. + SmallVector<Value> loopVariables; + if (failed(createVariablesForLoopCarriedValues( + whileOp, rewriter, loopVariables, loc, context))) + return failure(); + + if (failed(lowerDoWhile(whileOp, loopVariables, resultVariables, context, + rewriter, loc))) + return failure(); + + rewriter.setInsertionPointAfter(whileOp); + + // Load the final result values from result variables. + SmallVector<Value> finalResults = + loadValues(resultVariables, rewriter, loc); + rewriter.replaceOp(whileOp, finalResults); + + return success(); + } + +private: + // Initialize variables for loop-carried values to enable state updates + // across iterations without SSA argument passing. + LogicalResult createVariablesForLoopCarriedValues( + WhileOp whileOp, ConversionPatternRewriter &rewriter, + SmallVectorImpl<Value> &loopVars, Location loc, + MLIRContext *context) const { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(whileOp); + + emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, ""); + + for (Value init : whileOp.getInits()) { + Type convertedType = getTypeConverter()->convertType(init.getType()); + if (!convertedType) + return rewriter.notifyMatchFailure(whileOp, "type conversion failed"); + + emitc::VariableOp var = rewriter.create<emitc::VariableOp>( + loc, emitc::LValueType::get(convertedType), noInit); + rewriter.create<emitc::AssignOp>(loc, var.getResult(), init); + loopVars.push_back(var); + } + + return success(); + } + + // Lower scf.while to emitc.do. + LogicalResult lowerDoWhile(WhileOp whileOp, ArrayRef<Value> loopVars, + ArrayRef<Value> resultVars, MLIRContext *context, + ConversionPatternRewriter &rewriter, + Location loc) const { + // Create a global boolean variable to store the loop condition state. + Type i1Type = IntegerType::get(context, 1); + auto globalCondition = + rewriter.create<emitc::VariableOp>(loc, emitc::LValueType::get(i1Type), + emitc::OpaqueAttr::get(context, "")); + Value conditionVal = globalCondition.getResult(); + + auto loweredDo = rewriter.create<emitc::DoOp>(loc); + + // Convert region types to match the target dialect type system. + if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(), + *getTypeConverter(), nullptr)) || + failed(rewriter.convertRegionTypes(&whileOp.getAfter(), + *getTypeConverter(), nullptr))) { + return rewriter.notifyMatchFailure(whileOp, + "region types conversion failed"); + } + + // Prepare the before region (condition evaluation) for merging. + Block *beforeBlock = &whileOp.getBefore().front(); + Block *bodyBlock = rewriter.createBlock(&loweredDo.getBodyRegion()); + rewriter.setInsertionPointToStart(bodyBlock); + + // Load current variable values to use as initial arguments for the + // condition block. + SmallVector<Value> replacingValues = loadValues(loopVars, rewriter, loc); + rewriter.mergeBlocks(beforeBlock, bodyBlock, replacingValues); + + Operation *condTerminator = + loweredDo.getBodyRegion().back().getTerminator(); + scf::ConditionOp condOp = cast<scf::ConditionOp>(condTerminator); + rewriter.setInsertionPoint(condOp); + + // Update result variables with values from scf::condition. + SmallVector<Value> conditionArgs; + for (Value arg : condOp.getArgs()) { + conditionArgs.push_back(rewriter.getRemappedValue(arg)); + } + assignValues(conditionArgs, resultVars, rewriter, loc); + + // Convert scf.condition to condition variable assignment. + Value condition = rewriter.getRemappedValue(condOp.getCondition()); + rewriter.create<emitc::AssignOp>(loc, conditionVal, condition); + + // Wrap body region in conditional to preserve scf semantics. Only create + // ifOp if after-region is non-empty. + if (whileOp.getAfterBody()->getOperations().size() > 1) { + auto ifOp = rewriter.create<emitc::IfOp>(loc, condition, false, false); + + // Prepare the after region (loop body) for merging. + Block *afterBlock = &whileOp.getAfter().front(); + Block *ifBodyBlock = rewriter.createBlock(&ifOp.getBodyRegion()); + + // Replacement values for after block using condition op arguments. + SmallVector<Value> afterReplacingValues; + for (Value arg : condOp.getArgs()) + afterReplacingValues.push_back(rewriter.getRemappedValue(arg)); + + rewriter.mergeBlocks(afterBlock, ifBodyBlock, afterReplacingValues); + + if (failed(lowerYield(whileOp, loopVars, rewriter, + cast<scf::YieldOp>(ifBodyBlock->getTerminator())))) + return failure(); + } + + rewriter.eraseOp(condOp); + + // Create condition region that loads from the flag variable. + Region &condRegion = loweredDo.getConditionRegion(); + Block *condBlock = rewriter.createBlock(&condRegion); + rewriter.setInsertionPointToStart(condBlock); + + auto exprOp = rewriter.create<emitc::ExpressionOp>( + loc, i1Type, conditionVal, /*do_not_inline=*/false); + Block *exprBlock = rewriter.createBlock(&exprOp.getBodyRegion()); + + // Set up the expression block to load the condition variable. + exprBlock->addArgument(conditionVal.getType(), loc); + rewriter.setInsertionPointToStart(exprBlock); + + // Load the condition value and yield it as the expression result. + Value cond = + rewriter.create<emitc::LoadOp>(loc, i1Type, exprBlock->getArgument(0)); + rewriter.create<emitc::YieldOp>(loc, cond); + + // Yield the expression as the condition region result. + rewriter.setInsertionPointToEnd(condBlock); + rewriter.create<emitc::YieldOp>(loc, exprOp); + + return success(); + } +}; + void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns, TypeConverter &typeConverter) { patterns.add<ForLowering>(typeConverter, patterns.getContext()); patterns.add<IfLowering>(typeConverter, patterns.getContext()); patterns.add<IndexSwitchOpLowering>(typeConverter, patterns.getContext()); + patterns.add<WhileLowering>(typeConverter, patterns.getContext()); } void SCFToEmitCPass::runOnOperation() { @@ -357,7 +523,8 @@ void SCFToEmitCPass::runOnOperation() { // Configure conversion to lower out SCF operations. ConversionTarget target(getContext()); - target.addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(); + target + .addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp, scf::WhileOp>(); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) |