diff options
Diffstat (limited to 'mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp')
-rw-r--r-- | mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp | 206 |
1 files changed, 140 insertions, 66 deletions
diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp index 67a43c4..92523ca 100644 --- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp +++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" @@ -39,21 +40,22 @@ struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> { // Lower scf::for to emitc::for, implementing result values using // emitc::variable's updated within the loop body. -struct ForLowering : public OpRewritePattern<ForOp> { - using OpRewritePattern<ForOp>::OpRewritePattern; +struct ForLowering : public OpConversionPattern<ForOp> { + using OpConversionPattern<ForOp>::OpConversionPattern; - LogicalResult matchAndRewrite(ForOp forOp, - PatternRewriter &rewriter) const override; + LogicalResult + matchAndRewrite(ForOp forOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; }; // Create an uninitialized emitc::variable op for each result of the given op. template <typename T> -static SmallVector<Value> createVariablesForResults(T op, - PatternRewriter &rewriter) { - SmallVector<Value> resultVariables; - +static LogicalResult +createVariablesForResults(T op, const TypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + SmallVector<Value> &resultVariables) { if (!op.getNumResults()) - return resultVariables; + return success(); Location loc = op->getLoc(); MLIRContext *context = op.getContext(); @@ -62,7 +64,9 @@ static SmallVector<Value> createVariablesForResults(T op, rewriter.setInsertionPoint(op); for (OpResult result : op.getResults()) { - Type resultType = result.getType(); + Type resultType = typeConverter->convertType(result.getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "result type conversion failed"); Type varType = emitc::LValueType::get(resultType); emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, ""); emitc::VariableOp var = @@ -70,13 +74,13 @@ static SmallVector<Value> createVariablesForResults(T op, resultVariables.push_back(var); } - return resultVariables; + return success(); } // Create a series of assign ops assigning given values to given variables at // the current insertion point of given rewriter. -static void assignValues(ValueRange values, SmallVector<Value> &variables, - PatternRewriter &rewriter, Location loc) { +static void assignValues(ValueRange values, ValueRange variables, + ConversionPatternRewriter &rewriter, Location loc) { for (auto [value, var] : llvm::zip(values, variables)) rewriter.create<emitc::AssignOp>(loc, var, value); } @@ -89,18 +93,25 @@ SmallVector<Value> loadValues(const SmallVector<Value> &variables, }); } -static void lowerYield(SmallVector<Value> &resultVariables, - PatternRewriter &rewriter, scf::YieldOp yield) { +static LogicalResult lowerYield(Operation *op, ValueRange resultVariables, + ConversionPatternRewriter &rewriter, + scf::YieldOp yield) { Location loc = yield.getLoc(); - ValueRange operands = yield.getOperands(); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(yield); - assignValues(operands, resultVariables, rewriter, loc); + SmallVector<Value> yieldOperands; + if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands))) { + return rewriter.notifyMatchFailure(op, "failed to lower yield operands"); + } + + assignValues(yieldOperands, resultVariables, rewriter, loc); rewriter.create<emitc::YieldOp>(loc); rewriter.eraseOp(yield); + + return success(); } // Lower the contents of an scf::if/scf::index_switch regions to an @@ -108,27 +119,32 @@ static void lowerYield(SmallVector<Value> &resultVariables, // moved into the respective lowered region, but the scf::yield is replaced not // only with an emitc::yield, but also with a sequence of emitc::assign ops that // set the yielded values into the result variables. -static void lowerRegion(SmallVector<Value> &resultVariables, - PatternRewriter &rewriter, Region ®ion, - Region &loweredRegion) { +static LogicalResult lowerRegion(Operation *op, ValueRange resultVariables, + ConversionPatternRewriter &rewriter, + Region ®ion, Region &loweredRegion) { rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end()); Operation *terminator = loweredRegion.back().getTerminator(); - lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator)); + return lowerYield(op, resultVariables, rewriter, + cast<scf::YieldOp>(terminator)); } -LogicalResult ForLowering::matchAndRewrite(ForOp forOp, - PatternRewriter &rewriter) const { +LogicalResult +ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { Location loc = forOp.getLoc(); // Create an emitc::variable op for each result. These variables will be // assigned to by emitc::assign ops within the loop body. - SmallVector<Value> resultVariables = - createVariablesForResults(forOp, rewriter); + SmallVector<Value> resultVariables; + if (failed(createVariablesForResults(forOp, getTypeConverter(), rewriter, + resultVariables))) + return rewriter.notifyMatchFailure(forOp, + "create variables for results failed"); - assignValues(forOp.getInits(), resultVariables, rewriter, loc); + assignValues(adaptor.getInitArgs(), resultVariables, rewriter, loc); emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>( - loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()); + loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep()); Block *loweredBody = loweredFor.getBody(); @@ -143,13 +159,27 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp, rewriter.restoreInsertionPoint(ip); + // Convert the original region types into the new types by adding unrealized + // casts in the beginning of the loop. This performs the conversion in place. + if (failed(rewriter.convertRegionTypes(&forOp.getRegion(), + *getTypeConverter(), nullptr))) { + return rewriter.notifyMatchFailure(forOp, "region types conversion failed"); + } + + // Register the replacements for the block arguments and inline the body of + // the scf.for loop into the body of the emitc::for loop. + Block *scfBody = &(forOp.getRegion().front()); SmallVector<Value> replacingValues; replacingValues.push_back(loweredFor.getInductionVar()); replacingValues.append(iterArgsValues.begin(), iterArgsValues.end()); + rewriter.mergeBlocks(scfBody, loweredBody, replacingValues); - rewriter.mergeBlocks(forOp.getBody(), loweredBody, replacingValues); - lowerYield(resultVariables, rewriter, - cast<scf::YieldOp>(loweredBody->getTerminator())); + auto result = lowerYield(forOp, resultVariables, rewriter, + cast<scf::YieldOp>(loweredBody->getTerminator())); + + if (failed(result)) { + return result; + } // Load variables into SSA values after the for loop. SmallVector<Value> resultValues = loadValues(resultVariables, rewriter, loc); @@ -160,38 +190,66 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp, // Lower scf::if to emitc::if, implementing result values as emitc::variable's // updated within the then and else regions. -struct IfLowering : public OpRewritePattern<IfOp> { - using OpRewritePattern<IfOp>::OpRewritePattern; +struct IfLowering : public OpConversionPattern<IfOp> { + using OpConversionPattern<IfOp>::OpConversionPattern; - LogicalResult matchAndRewrite(IfOp ifOp, - PatternRewriter &rewriter) const override; + LogicalResult + matchAndRewrite(IfOp ifOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; }; } // namespace -LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, - PatternRewriter &rewriter) const { +LogicalResult +IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { Location loc = ifOp.getLoc(); // Create an emitc::variable op for each result. These variables will be // assigned to by emitc::assign ops within the then & else regions. - SmallVector<Value> resultVariables = - createVariablesForResults(ifOp, rewriter); - - Region &thenRegion = ifOp.getThenRegion(); - Region &elseRegion = ifOp.getElseRegion(); + SmallVector<Value> resultVariables; + if (failed(createVariablesForResults(ifOp, getTypeConverter(), rewriter, + resultVariables))) + return rewriter.notifyMatchFailure(ifOp, + "create variables for results failed"); + + // Utility function to lower the contents of an scf::if region to an emitc::if + // region. The contents of the scf::if regions is moved into the respective + // emitc::if regions, but the scf::yield is replaced not only with an + // emitc::yield, but also with a sequence of emitc::assign ops that set the + // yielded values into the result variables. + auto lowerRegion = [&resultVariables, &rewriter, + &ifOp](Region ®ion, Region &loweredRegion) { + rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end()); + Operation *terminator = loweredRegion.back().getTerminator(); + auto result = lowerYield(ifOp, resultVariables, rewriter, + cast<scf::YieldOp>(terminator)); + if (failed(result)) { + return result; + } + return success(); + }; + + Region &thenRegion = adaptor.getThenRegion(); + Region &elseRegion = adaptor.getElseRegion(); bool hasElseBlock = !elseRegion.empty(); auto loweredIf = - rewriter.create<emitc::IfOp>(loc, ifOp.getCondition(), false, false); + rewriter.create<emitc::IfOp>(loc, adaptor.getCondition(), false, false); Region &loweredThenRegion = loweredIf.getThenRegion(); - lowerRegion(resultVariables, rewriter, thenRegion, loweredThenRegion); + auto result = lowerRegion(thenRegion, loweredThenRegion); + if (failed(result)) { + return result; + } if (hasElseBlock) { Region &loweredElseRegion = loweredIf.getElseRegion(); - lowerRegion(resultVariables, rewriter, elseRegion, loweredElseRegion); + auto result = lowerRegion(elseRegion, loweredElseRegion); + if (failed(result)) { + return result; + } } rewriter.setInsertionPointAfter(ifOp); @@ -203,37 +261,46 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, // Lower scf::index_switch to emitc::switch, implementing result values as // emitc::variable's updated within the case and default regions. -struct IndexSwitchOpLowering : public OpRewritePattern<IndexSwitchOp> { - using OpRewritePattern<IndexSwitchOp>::OpRewritePattern; +struct IndexSwitchOpLowering : public OpConversionPattern<IndexSwitchOp> { + using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(IndexSwitchOp indexSwitchOp, - PatternRewriter &rewriter) const override; + LogicalResult + matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; }; -LogicalResult -IndexSwitchOpLowering::matchAndRewrite(IndexSwitchOp indexSwitchOp, - PatternRewriter &rewriter) const { +LogicalResult IndexSwitchOpLowering::matchAndRewrite( + IndexSwitchOp indexSwitchOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { Location loc = indexSwitchOp.getLoc(); // Create an emitc::variable op for each result. These variables will be // assigned to by emitc::assign ops within the case and default regions. - SmallVector<Value> resultVariables = - createVariablesForResults(indexSwitchOp, rewriter); + SmallVector<Value> resultVariables; + if (failed(createVariablesForResults(indexSwitchOp, getTypeConverter(), + rewriter, resultVariables))) { + return rewriter.notifyMatchFailure(indexSwitchOp, + "create variables for results failed"); + } auto loweredSwitch = rewriter.create<emitc::SwitchOp>( - loc, indexSwitchOp.getArg(), indexSwitchOp.getCases(), - indexSwitchOp.getNumCases()); + loc, adaptor.getArg(), adaptor.getCases(), indexSwitchOp.getNumCases()); // Lowering all case regions. - for (auto pair : llvm::zip(indexSwitchOp.getCaseRegions(), - loweredSwitch.getCaseRegions())) { - lowerRegion(resultVariables, rewriter, std::get<0>(pair), - std::get<1>(pair)); + for (auto pair : + llvm::zip(adaptor.getCaseRegions(), loweredSwitch.getCaseRegions())) { + if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter, + *std::get<0>(pair), std::get<1>(pair)))) { + return failure(); + } } // Lowering default region. - lowerRegion(resultVariables, rewriter, indexSwitchOp.getDefaultRegion(), - loweredSwitch.getDefaultRegion()); + if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter, + adaptor.getDefaultRegion(), + loweredSwitch.getDefaultRegion()))) { + return failure(); + } rewriter.setInsertionPointAfter(indexSwitchOp); SmallVector<Value> results = loadValues(resultVariables, rewriter, loc); @@ -242,15 +309,22 @@ IndexSwitchOpLowering::matchAndRewrite(IndexSwitchOp indexSwitchOp, return success(); } -void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns) { - patterns.add<ForLowering>(patterns.getContext()); - patterns.add<IfLowering>(patterns.getContext()); - patterns.add<IndexSwitchOpLowering>(patterns.getContext()); +void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add<ForLowering>(typeConverter, patterns.getContext()); + patterns.add<IfLowering>(typeConverter, patterns.getContext()); + patterns.add<IndexSwitchOpLowering>(typeConverter, patterns.getContext()); } void SCFToEmitCPass::runOnOperation() { RewritePatternSet patterns(&getContext()); - populateSCFToEmitCConversionPatterns(patterns); + TypeConverter typeConverter; + // Fallback converter + // See note https://mlir.llvm.org/docs/DialectConversion/#type-converter + // Type converters are called most to least recently inserted + typeConverter.addConversion([](Type t) { return t; }); + populateEmitCSizeTTypeConversions(typeConverter); + populateSCFToEmitCConversionPatterns(patterns, typeConverter); // Configure conversion to lower out SCF operations. ConversionTarget target(getContext()); |