diff options
author | Simon Camphausen <simon.camphausen@iml.fraunhofer.de> | 2024-07-10 15:24:32 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-07-10 15:24:32 +0200 |
commit | d521324e9fa89f2db8229fb1327c7d45df0ff3cc (patch) | |
tree | dff3fb030ae2bdb29adf2c6d6becb2b3f9bc4a82 /mlir/lib/Target/Cpp/TranslateToCpp.cpp | |
parent | b841e2eca3b5c8b408214a46593f6a025e0fe48b (diff) | |
download | llvm-d521324e9fa89f2db8229fb1327c7d45df0ff3cc.zip llvm-d521324e9fa89f2db8229fb1327c7d45df0ff3cc.tar.gz llvm-d521324e9fa89f2db8229fb1327c7d45df0ff3cc.tar.bz2 |
[mlir][EmitC] Unify handling of operations which are emitted in a deferred way (#97804)
Several operations from the EmitC dialect don't produce output directly
during emission, but rather when being used as an operand. These changes
unify the handling of such operations and fix a bug in the emission of
global ops.
Co-authored-by: Marius Brehler <marius.brehler@iml.fraunhofer.de>
Diffstat (limited to 'mlir/lib/Target/Cpp/TranslateToCpp.cpp')
-rw-r--r-- | mlir/lib/Target/Cpp/TranslateToCpp.cpp | 89 |
1 files changed, 43 insertions, 46 deletions
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 6266382..eda8d5c 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -174,6 +174,9 @@ struct CppEmitter { /// Emit an expression as a C expression. LogicalResult emitExpression(ExpressionOp expressionOp); + /// Insert the expression representing the operation into the value cache. + void cacheDeferredOpResult(Value value, StringRef str); + /// Return the existing or a new name for a Value. StringRef getOrCreateName(Value val); @@ -273,6 +276,12 @@ private: }; } // namespace +/// Determine whether expression \p op should be emitted in a deferred way. +static bool hasDeferredEmission(Operation *op) { + return isa_and_nonnull<emitc::GetGlobalOp, emitc::LiteralOp, + emitc::SubscriptOp>(op); +} + /// Determine whether expression \p expressionOp should be emitted inline, i.e. /// as part of its user. This function recommends inlining of any expressions /// that can be inlined unless it is used by another expression, under the @@ -295,10 +304,9 @@ static bool shouldBeInlined(ExpressionOp expressionOp) { Operation *user = *result.getUsers().begin(); - // Do not inline expressions used by subscript operations, since the - // way the subscript operation translation is implemented requires that - // variables be materialized. - if (isa<emitc::SubscriptOp>(user)) + // Do not inline expressions used by operations with deferred emission, since + // their translation requires the materialization of variables. + if (hasDeferredEmission(user)) return false; // Do not inline expressions used by ops with the CExpression trait. If this @@ -370,20 +378,6 @@ static LogicalResult printOperation(CppEmitter &emitter, return emitter.emitOperand(assignOp.getValue()); } -static LogicalResult printOperation(CppEmitter &emitter, - emitc::GetGlobalOp op) { - // Add name to cache so that `hasValueInScope` works. - emitter.getOrCreateName(op.getResult()); - return success(); -} - -static LogicalResult printOperation(CppEmitter &emitter, - emitc::SubscriptOp subscriptOp) { - // Add name to cache so that `hasValueInScope` works. - emitter.getOrCreateName(subscriptOp.getResult()); - return success(); -} - static LogicalResult printBinaryOperation(CppEmitter &emitter, Operation *operation, StringRef binaryOperator) { @@ -621,9 +615,7 @@ static LogicalResult printOperation(CppEmitter &emitter, if (t.getType().isIndex()) { int64_t idx = t.getInt(); Value operand = op.getOperand(idx); - auto literalDef = - dyn_cast_if_present<LiteralOp>(operand.getDefiningOp()); - if (!literalDef && !emitter.hasValueInScope(operand)) + if (!emitter.hasValueInScope(operand)) return op.emitOpError("operand ") << idx << "'s value not defined in scope"; os << emitter.getOrCreateName(operand); @@ -948,8 +940,7 @@ static LogicalResult printFunctionBody(CppEmitter &emitter, // regions. WalkResult result = functionOp->walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult { - if (isa<emitc::LiteralOp>(op) || - isa<emitc::ExpressionOp>(op->getParentOp()) || + if (isa<emitc::ExpressionOp>(op->getParentOp()) || (isa<emitc::ExpressionOp>(op) && shouldBeInlined(cast<emitc::ExpressionOp>(op)))) return WalkResult::skip(); @@ -1001,7 +992,7 @@ static LogicalResult printFunctionBody(CppEmitter &emitter, // trailing semicolon is handled within the printOperation function. bool trailingSemicolon = !isa<cf::CondBranchOp, emitc::DeclareFuncOp, emitc::ForOp, - emitc::IfOp, emitc::LiteralOp, emitc::VerbatimOp>(op); + emitc::IfOp, emitc::VerbatimOp>(op); if (failed(emitter.emitOperation( op, /*trailingSemicolon=*/trailingSemicolon))) @@ -1134,20 +1125,18 @@ std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) { return out; } +void CppEmitter::cacheDeferredOpResult(Value value, StringRef str) { + if (!valueMapper.count(value)) + valueMapper.insert(value, str.str()); +} + /// Return the existing or a new name for a Value. StringRef CppEmitter::getOrCreateName(Value val) { - if (auto literal = dyn_cast_if_present<emitc::LiteralOp>(val.getDefiningOp())) - return literal.getValue(); if (!valueMapper.count(val)) { - if (auto subscript = - dyn_cast_if_present<emitc::SubscriptOp>(val.getDefiningOp())) { - valueMapper.insert(val, getSubscriptName(subscript)); - } else if (auto getGlobal = dyn_cast_if_present<emitc::GetGlobalOp>( - val.getDefiningOp())) { - valueMapper.insert(val, getGlobal.getName().str()); - } else { - valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top())); - } + assert(!hasDeferredEmission(val.getDefiningOp()) && + "cacheDeferredOpResult should have been called on this value, " + "update the emitOperation function."); + valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top())); } return *valueMapper.begin(val); } @@ -1341,9 +1330,6 @@ LogicalResult CppEmitter::emitOperand(Value value) { if (expressionOp && shouldBeInlined(expressionOp)) return emitExpression(expressionOp); - auto literalOp = dyn_cast_if_present<LiteralOp>(value.getDefiningOp()); - if (!literalOp && !hasValueInScope(value)) - return failure(); os << getOrCreateName(value); return success(); } @@ -1399,7 +1385,7 @@ LogicalResult CppEmitter::emitVariableAssignment(OpResult result) { LogicalResult CppEmitter::emitVariableDeclaration(OpResult result, bool trailingSemicolon) { - if (isa<emitc::SubscriptOp>(result.getDefiningOp())) + if (hasDeferredEmission(result.getDefiningOp())) return success(); if (hasValueInScope(result)) { return result.getDefiningOp()->emitError( @@ -1498,16 +1484,27 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp, emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp, emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp, - emitc::GlobalOp, emitc::GetGlobalOp, emitc::IfOp, - emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp, - emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp, - emitc::SubOp, emitc::SubscriptOp, emitc::UnaryMinusOp, - emitc::UnaryPlusOp, emitc::VariableOp, emitc::VerbatimOp>( + emitc::GlobalOp, emitc::IfOp, emitc::IncludeOp, + emitc::LogicalAndOp, emitc::LogicalNotOp, emitc::LogicalOrOp, + emitc::MulOp, emitc::RemOp, emitc::ReturnOp, emitc::SubOp, + emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp, + emitc::VerbatimOp>( [&](auto op) { return printOperation(*this, op); }) // Func ops. .Case<func::CallOp, func::FuncOp, func::ReturnOp>( [&](auto op) { return printOperation(*this, op); }) - .Case<emitc::LiteralOp>([&](auto op) { return success(); }) + .Case<emitc::GetGlobalOp>([&](auto op) { + cacheDeferredOpResult(op.getResult(), op.getName()); + return success(); + }) + .Case<emitc::LiteralOp>([&](auto op) { + cacheDeferredOpResult(op.getResult(), op.getValue()); + return success(); + }) + .Case<emitc::SubscriptOp>([&](auto op) { + cacheDeferredOpResult(op.getResult(), getSubscriptName(op)); + return success(); + }) .Default([&](Operation *) { return op.emitOpError("unable to find printer for op"); }); @@ -1515,7 +1512,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { if (failed(status)) return failure(); - if (isa<emitc::LiteralOp, emitc::SubscriptOp, emitc::GetGlobalOp>(op)) + if (hasDeferredEmission(&op)) return success(); if (getEmittedExpression() || |