aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp8
-rw-r--r--mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp2
-rw-r--r--mlir/lib/Analysis/DataFlowFramework.cpp20
-rw-r--r--mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp16
-rw-r--r--mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp55
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp4
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp8
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp8
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp4
-rw-r--r--mlir/lib/Dialect/Tosa/IR/TosaOps.cpp122
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp14
-rw-r--r--mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp18
-rw-r--r--mlir/lib/IR/PatternMatch.cpp10
-rw-r--r--mlir/lib/Support/TypeID.cpp3
-rw-r--r--mlir/lib/Transforms/RemoveDeadValues.cpp1
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp164
-rw-r--r--mlir/lib/Transforms/Utils/Inliner.cpp33
17 files changed, 300 insertions, 190 deletions
diff --git a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
index 51fa773..fb5649e 100644
--- a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
@@ -16,6 +16,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include <cassert>
#define DEBUG_TYPE "constant-propagation"
@@ -46,7 +47,7 @@ void ConstantValue::print(raw_ostream &os) const {
LogicalResult SparseConstantPropagation::visitOperation(
Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands,
ArrayRef<Lattice<ConstantValue> *> results) {
- LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n");
+ LDBG() << "SCP: Visiting operation: " << *op;
// Don't try to simulate the results of a region operation as we can't
// guarantee that folding will be out-of-place. We don't allow in-place
@@ -98,12 +99,11 @@ LogicalResult SparseConstantPropagation::visitOperation(
// Merge in the result of the fold, either a constant or a value.
OpFoldResult foldResult = std::get<1>(it);
if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(foldResult)) {
- LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n");
+ LDBG() << "Folded to constant: " << attr;
propagateIfChanged(lattice,
lattice->join(ConstantValue(attr, op->getDialect())));
} else {
- LLVM_DEBUG(llvm::dbgs()
- << "Folded to value: " << cast<Value>(foldResult) << "\n");
+ LDBG() << "Folded to value: " << cast<Value>(foldResult);
AbstractSparseForwardDataFlowAnalysis::join(
lattice, *getLatticeElement(cast<Value>(foldResult)));
}
diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
index 197f97f..509f520 100644
--- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
@@ -294,7 +294,7 @@ RunLivenessAnalysis::RunLivenessAnalysis(Operation *op) {
solver.load<LivenessAnalysis>(symbolTable);
LDBG() << "Initializing and running solver";
(void)solver.initializeAndRun(op);
- LDBG() << "Dumping liveness state for op";
+ LDBG() << "RunLivenessAnalysis initialized for op: " << op->getName();
}
const Liveness *RunLivenessAnalysis::getLiveness(Value val) {
diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
index 176d53e..16f7033 100644
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -14,7 +14,7 @@
#include "llvm/ADT/iterator.h"
#include "llvm/Config/abi-breaking.h"
#include "llvm/Support/Casting.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "dataflow"
@@ -44,9 +44,8 @@ void AnalysisState::addDependency(ProgramPoint *dependent,
(void)inserted;
DATAFLOW_DEBUG({
if (inserted) {
- llvm::dbgs() << "Creating dependency between " << debugName << " of "
- << anchor << "\nand " << debugName << " on " << dependent
- << "\n";
+ LDBG() << "Creating dependency between " << debugName << " of " << anchor
+ << "\nand " << debugName << " on " << dependent;
}
});
}
@@ -116,8 +115,7 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
// Initialize the analyses.
for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
- DATAFLOW_DEBUG(llvm::dbgs()
- << "Priming analysis: " << analysis.debugName << "\n");
+ DATAFLOW_DEBUG(LDBG() << "Priming analysis: " << analysis.debugName);
if (failed(analysis.initialize(top)))
return failure();
}
@@ -129,8 +127,8 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
auto [point, analysis] = worklist.front();
worklist.pop();
- DATAFLOW_DEBUG(llvm::dbgs() << "Invoking '" << analysis->debugName
- << "' on: " << point << "\n");
+ DATAFLOW_DEBUG(LDBG() << "Invoking '" << analysis->debugName
+ << "' on: " << point);
if (failed(analysis->visit(point)))
return failure();
}
@@ -143,9 +141,9 @@ void DataFlowSolver::propagateIfChanged(AnalysisState *state,
assert(isRunning &&
"DataFlowSolver is not running, should not use propagateIfChanged");
if (changed == ChangeResult::Change) {
- DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << state->debugName
- << " of " << state->anchor << "\n"
- << "Value: " << *state << "\n");
+ DATAFLOW_DEBUG(LDBG() << "Propagating update to " << state->debugName
+ << " of " << state->anchor << "\n"
+ << "Value: " << *state);
state->onUpdate(this);
}
}
diff --git a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
index 3c00b32..6265f46 100644
--- a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
@@ -15,13 +15,13 @@
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/InterleavedRange.h"
using namespace mlir;
using namespace mlir::affine;
#define DEBUG_TYPE "decompose-affine-ops"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define DBGSNL() (llvm::dbgs() << "\n")
/// Count the number of loops surrounding `operand` such that operand could be
/// hoisted above.
@@ -115,7 +115,7 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
return rewriter.notifyMatchFailure(
op, "only add or mul binary expr can be reassociated");
- LLVM_DEBUG(DBGS() << "Start decomposeIntoFinerGrainedOps: " << op << "\n");
+ LDBG() << "Start decomposeIntoFinerGrainedOps: " << op;
// 2. Iteratively extract the RHS subexpressions while the top-level binary
// expr kind remains the same.
@@ -125,11 +125,11 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
auto currentBinExpr = dyn_cast<AffineBinaryOpExpr>(remainingExp);
if (!currentBinExpr || currentBinExpr.getKind() != binExpr.getKind()) {
subExpressions.push_back(remainingExp);
- LLVM_DEBUG(DBGS() << "--terminal: " << subExpressions.back() << "\n");
+ LDBG() << "--terminal: " << subExpressions.back();
break;
}
subExpressions.push_back(currentBinExpr.getRHS());
- LLVM_DEBUG(DBGS() << "--subExpr: " << subExpressions.back() << "\n");
+ LDBG() << "--subExpr: " << subExpressions.back();
remainingExp = currentBinExpr.getLHS();
}
@@ -146,9 +146,7 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
llvm::stable_sort(subExpressions, [&](AffineExpr e1, AffineExpr e2) {
return getMaxSymbol(e1) < getMaxSymbol(e2);
});
- LLVM_DEBUG(
- llvm::interleaveComma(subExpressions, DBGS() << "--sorted subexprs: ");
- llvm::dbgs() << "\n");
+ LDBG() << "--sorted subexprs: " << llvm::interleaved(subExpressions);
// 4. Merge sorted subExpressions iteratively, thus achieving reassociation.
auto s0 = getAffineSymbolExpr(0, ctx);
@@ -162,7 +160,7 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
Value tmp = createSubApply(rewriter, op, subExpressions[i]);
current = AffineApplyOp::create(rewriter, op.getLoc(), binMap,
ValueRange{current, tmp});
- LLVM_DEBUG(DBGS() << "--reassociate into: " << current << "\n");
+ LDBG() << "--reassociate into: " << current;
}
// 5. Replace original op.
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
index 8493b60..2521512 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
@@ -19,11 +19,10 @@
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/IntEqClasses.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#define DEBUG_TYPE "affine-min-max"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
using namespace mlir;
using namespace mlir::affine;
@@ -39,7 +38,7 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
ValueRange operands = affineOp.getOperands();
static constexpr bool isMin = std::is_same_v<AffineOp, AffineMinOp>;
- LLVM_DEBUG({ DBGS() << "analyzing value: `" << affineOp << "`\n"; });
+ LDBG() << "analyzing value: `" << affineOp;
// Create a `Variable` list with values corresponding to each of the results
// in the affine affineMap.
@@ -48,12 +47,9 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
[&](unsigned i) {
return Variable(affineMap.getSliceMap(i, 1), operands);
});
- LLVM_DEBUG({
- DBGS() << "- constructed variables are: "
- << llvm::interleaved_array(llvm::map_range(
- variables, [](const Variable &v) { return v.getMap(); }))
- << "`\n";
- });
+ LDBG() << "- constructed variables are: "
+ << llvm::interleaved_array(llvm::map_range(
+ variables, [](const Variable &v) { return v.getMap(); }));
// Get the comparison operation.
ComparisonOperator cmpOp =
@@ -72,10 +68,8 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
// Initialize the bound.
Variable *bound = &v;
- LLVM_DEBUG({
- DBGS() << "- inspecting variable: #" << i << ", with map: `" << v.getMap()
- << "`\n";
- });
+ LDBG() << "- inspecting variable: #" << i << ", with map: `" << v.getMap()
+ << "`\n";
// Check against the other variables.
for (size_t j = i + 1; j < variables.size(); ++j) {
@@ -87,10 +81,8 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
// Get the bound of the equivalence class or itself.
Variable *nv = bounds.lookup_or(jEqClass, &variables[j]);
- LLVM_DEBUG({
- DBGS() << "- comparing with variable: #" << jEqClass
- << ", with map: " << nv->getMap() << "\n";
- });
+ LDBG() << "- comparing with variable: #" << jEqClass
+ << ", with map: " << nv->getMap();
// Compare the variables.
FailureOr<bool> cmpResult =
@@ -98,18 +90,14 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
// The variables cannot be compared.
if (failed(cmpResult)) {
- LLVM_DEBUG({
- DBGS() << "-- classes: #" << i << ", #" << jEqClass
- << " cannot be merged\n";
- });
+ LDBG() << "-- classes: #" << i << ", #" << jEqClass
+ << " cannot be merged";
continue;
}
// Join the equivalent classes and update the bound if necessary.
- LLVM_DEBUG({
- DBGS() << "-- merging classes: #" << i << ", #" << jEqClass
- << ", is cmp(lhs, rhs): " << *cmpResult << "`\n";
- });
+ LDBG() << "-- merging classes: #" << i << ", #" << jEqClass
+ << ", is cmp(lhs, rhs): " << *cmpResult << "`";
if (*cmpResult) {
boundedClasses.join(eqClass, jEqClass);
} else {
@@ -124,8 +112,7 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
// Return if there's no simplification.
if (bounds.size() >= affineMap.getNumResults()) {
- LLVM_DEBUG(
- { DBGS() << "- the affine operation couldn't get simplified\n"; });
+ LDBG() << "- the affine operation couldn't get simplified";
return false;
}
@@ -135,13 +122,11 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
for (auto [k, bound] : bounds)
results.push_back(bound->getMap().getResult(0));
- LLVM_DEBUG({
- DBGS() << "- starting from map: " << affineMap << "\n";
- DBGS() << "- creating new map with: \n";
- DBGS() << "--- dims: " << affineMap.getNumDims() << "\n";
- DBGS() << "--- syms: " << affineMap.getNumSymbols() << "\n";
- DBGS() << "--- res: " << llvm::interleaved_array(results) << "\n";
- });
+ LDBG() << "- starting from map: " << affineMap;
+ LDBG() << "- creating new map with:";
+ LDBG() << "--- dims: " << affineMap.getNumDims();
+ LDBG() << "--- syms: " << affineMap.getNumSymbols();
+ LDBG() << "--- res: " << llvm::interleaved_array(results);
affineMap =
AffineMap::get(0, affineMap.getNumSymbols() + affineMap.getNumDims(),
@@ -149,7 +134,7 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
// Update the affine op.
rewriter.modifyOpInPlace(affineOp, [&]() { affineOp.setMap(affineMap); });
- LLVM_DEBUG({ DBGS() << "- simplified affine op: `" << affineOp << "`\n"; });
+ LDBG() << "- simplified affine op: `" << affineOp << "`";
return true;
}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp b/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp
index c926dfb..5c8c2de 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/MathExtras.h"
@@ -21,7 +22,6 @@
using namespace mlir;
#define DEBUG_TYPE "linalg-transforms"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
static Attribute linearId0(MLIRContext *ctx) {
return gpu::GPUThreadMappingAttr::get(ctx, gpu::MappingId::LinearDim0);
@@ -81,7 +81,7 @@ transform::gpu::CopyMappingInfo::CopyMappingInfo(MLIRContext *ctx,
this->threadMapping =
llvm::to_vector(ArrayRef(allThreadMappings)
.take_back(this->smallestBoundingTileSizes.size()));
- LLVM_DEBUG(this->print(DBGS()); llvm::dbgs() << "\n");
+ LDBG() << *this;
}
int64_t transform::gpu::CopyMappingInfo::maxContiguousElementsToTransfer(
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
index 2fe72a3..d4a3e5f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -15,14 +15,13 @@
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
using namespace mlir;
#define DEBUG_TYPE "linalg-transforms"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
//===----------------------------------------------------------------------===//
// StructuredMatchOp
@@ -39,7 +38,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation(
return emitSilenceableError() << "expected a Linalg op";
}
// If errors are suppressed, succeed and set all results to empty lists.
- LLVM_DEBUG(DBGS() << "optional nested matcher expected a Linalg op");
+ LDBG() << "optional nested matcher expected a Linalg op";
results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
return DiagnosedSilenceableFailure::success();
}
@@ -75,8 +74,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation(
// When they are defined in this block, we additionally check if we have
// already applied the operation that defines them. If not, the
// corresponding results will be set to empty lists.
- LLVM_DEBUG(DBGS() << "optional nested matcher failed: " << diag.getMessage()
- << "\n");
+ LDBG() << "optional nested matcher failed: " << diag.getMessage();
(void)diag.silence();
SmallVector<OpOperand *> undefinedOperands;
for (OpOperand &terminatorOperand :
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0170837..793eec7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1913,14 +1913,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
sourceShape.end());
- ReifiedRankedShapedTypeDims reifiedRetShapes;
- LogicalResult status =
- cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
- .reifyResultShapes(rewriter, reifiedRetShapes);
- if (status.failed()) {
- LDBG() << "Unable to reify result shapes of " << unpackOp;
- return failure();
- }
Location loc = unpackOp->getLoc();
auto padValue = arith::ConstantOp::create(
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
index 106c3b4..cce80db 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
@@ -80,10 +80,6 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
for (auto &&[opOffset, sourceOffset, sourceStride, opSize] :
llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
sourceOp.getMixedStrides(), op.getMixedSizes())) {
- // We only support static sizes.
- if (isa<Value>(opSize)) {
- return failure();
- }
sizes.push_back(opSize);
Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(opOffset),
sourceOffsetAttr =
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index ecd93ff..3cafb19 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -3647,6 +3647,22 @@ std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
return std::nullopt;
}
+static void printInitializationList(OpAsmPrinter &parser,
+ Block::BlockArgListType blocksArgs,
+ ValueRange initializers,
+ StringRef prefix = "") {
+ assert(blocksArgs.size() == initializers.size() &&
+ "expected same length of arguments and initializers");
+ if (initializers.empty())
+ return;
+
+ parser << prefix << '(';
+ llvm::interleaveComma(
+ llvm::zip(blocksArgs, initializers), parser,
+ [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
+ parser << ")";
+}
+
// parse and print of IfOp refer to the implementation of SCF dialect.
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
// Create the regions for 'then'.
@@ -3654,16 +3670,64 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
Region *thenRegion = result.addRegion();
Region *elseRegion = result.addRegion();
- auto &builder = parser.getBuilder();
OpAsmParser::UnresolvedOperand cond;
- // Create a i1 tensor type for the boolean condition.
- Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
- if (parser.parseOperand(cond) ||
- parser.resolveOperand(cond, i1Type, result.operands))
+
+ if (parser.parseOperand(cond))
return failure();
- // Parse optional results type list.
- if (parser.parseOptionalArrowTypeList(result.types))
+
+ SmallVector<OpAsmParser::Argument, 4> regionArgs;
+ SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
+
+ // Parse the optional block arguments
+ OptionalParseResult listResult =
+ parser.parseOptionalAssignmentList(regionArgs, operands);
+ if (listResult.has_value() && failed(listResult.value()))
return failure();
+
+ // Parse a colon.
+ if (failed(parser.parseColon()))
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected type for condition operand");
+
+ // Parse the type of the condition operand
+ Type condType;
+ if (failed(parser.parseType(condType)))
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected type for condition operand");
+
+ // Resolve operand with provided type
+ if (failed(parser.resolveOperand(cond, condType, result.operands)))
+ return failure();
+
+ // Parse optional block arg types
+ if (listResult.has_value()) {
+ FunctionType functionType;
+
+ if (failed(parser.parseType(functionType)))
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected list of types for block arguments "
+ << "followed by arrow type and list of return types";
+
+ result.addTypes(functionType.getResults());
+
+ if (functionType.getNumInputs() != operands.size()) {
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected as many input types as operands "
+ << "(expected " << operands.size() << " got "
+ << functionType.getNumInputs() << ")";
+ }
+
+ // Resolve input operands.
+ if (failed(parser.resolveOperands(operands, functionType.getInputs(),
+ parser.getCurrentLocation(),
+ result.operands)))
+ return failure();
+ } else {
+ // Parse optional results type list.
+ if (parser.parseOptionalArrowTypeList(result.types))
+ return failure();
+ }
+
// Parse the 'then' region.
if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
return failure();
@@ -3681,26 +3745,28 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
}
void IfOp::print(OpAsmPrinter &p) {
- bool printBlockTerminators = false;
-
p << " " << getCondition();
- if (!getResults().empty()) {
- p << " -> (" << getResultTypes() << ")";
- // Print yield explicitly if the op defines values.
- printBlockTerminators = true;
+
+ printInitializationList(p, getThenGraph().front().getArguments(),
+ getInputList(), " ");
+ p << " : ";
+ p << getCondition().getType();
+
+ if (!getInputList().empty()) {
+ p << " (";
+ llvm::interleaveComma(getInputList().getTypes(), p);
+ p << ")";
}
- p << ' ';
- p.printRegion(getThenGraph(),
- /*printEntryBlockArgs=*/false,
- /*printBlockTerminators=*/printBlockTerminators);
+ p.printArrowTypeList(getResultTypes());
+ p << " ";
+
+ p.printRegion(getThenGraph());
// Print the 'else' regions if it exists and has a block.
auto &elseRegion = getElseGraph();
if (!elseRegion.empty()) {
p << " else ";
- p.printRegion(elseRegion,
- /*printEntryBlockArgs=*/false,
- /*printBlockTerminators=*/printBlockTerminators);
+ p.printRegion(elseRegion);
}
p.printOptionalAttrDict((*this)->getAttrs());
@@ -3909,22 +3975,6 @@ ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
parser.parseOptionalAttrDictWithKeyword(result.attributes));
}
-static void printInitializationList(OpAsmPrinter &parser,
- Block::BlockArgListType blocksArgs,
- ValueRange initializers,
- StringRef prefix = "") {
- assert(blocksArgs.size() == initializers.size() &&
- "expected same length of arguments and initializers");
- if (initializers.empty())
- return;
-
- parser << prefix << '(';
- llvm::interleaveComma(
- llvm::zip(blocksArgs, initializers), parser,
- [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
- parser << ")";
-}
-
void WhileOp::print(OpAsmPrinter &parser) {
printInitializationList(parser, getCondGraph().front().getArguments(),
getInputList(), " ");
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 32b5fb6..8ec7765 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1248,16 +1248,14 @@ bool checkErrorIfCondIf(Operation *op) {
// })
//
// Simplified:
- // %0 = tosa.cond_if %arg2 {
- // tosa.yield %arg0
+ // %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) {
+ // ^bb0(%arg3, %arg4):
+ // tosa.yield %arg3
// } else {
- // tosa.yield %arg1
+ // ^bb0(%arg3, %arg4):
+ // tosa.yield %arg4
// }
- //
- // Unfortunately, the simplified syntax does not encapsulate values
- // used in then/else regions (see 'simplified' example above), so it
- // must be rewritten to use the generic syntax in order to be conformant
- // to the specification.
+
return failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) ||
failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else"));
}
diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
index c0d20d4..e297f7c 100644
--- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
@@ -23,7 +23,6 @@
#define DEBUG_TYPE "transform-dialect"
#define DEBUG_TYPE_FULL "transform-dialect-full"
#define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
#ifndef NDEBUG
#define FULL_LDBG(X) \
DEBUGLOG_WITH_STREAM_AND_TYPE(llvm::dbgs(), DEBUG_TYPE_FULL)
@@ -818,16 +817,14 @@ void transform::TransformState::compactOpHandles() {
DiagnosedSilenceableFailure
transform::TransformState::applyTransform(TransformOpInterface transform) {
- LLVM_DEBUG({
- DBGS() << "applying: ";
- transform->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
- llvm::dbgs() << "\n";
- });
+ LDBG() << "applying: "
+ << OpWithFlags(transform, OpPrintingFlags().skipRegions());
FULL_LDBG() << "Top-level payload before application:\n" << *getTopLevel();
auto printOnFailureRAII = llvm::make_scope_exit([this] {
(void)this;
- LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print(
- llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm()););
+ LDBG() << "Failing Top-level payload:\n"
+ << OpWithFlags(getTopLevel(),
+ OpPrintingFlags().printGenericOpForm());
});
// Set current transform op.
@@ -995,8 +992,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
printOnFailureRAII.release();
DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
- DBGS() << "Top-level payload:\n";
- getTopLevel()->print(llvm::dbgs());
+ LDBG() << "Top-level payload:\n" << *getTopLevel();
});
return result;
}
@@ -1273,7 +1269,7 @@ void transform::TrackingListener::notifyMatchFailure(
LLVM_DEBUG({
Diagnostic diag(loc, DiagnosticSeverity::Remark);
reasonCallback(diag);
- DBGS() << "Match Failure : " << diag.str();
+ LDBG() << "Match Failure : " << diag.str();
});
}
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 5c98417..9332f55 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -156,6 +156,11 @@ void RewriterBase::eraseOp(Operation *op) {
assert(op->use_empty() && "expected 'op' to have no uses");
auto *rewriteListener = dyn_cast_if_present<Listener>(listener);
+ // If the current insertion point is before the erased operation, we adjust
+ // the insertion point to be after the operation.
+ if (getInsertionPoint() == op->getIterator())
+ setInsertionPointAfter(op);
+
// Fast path: If no listener is attached, the op can be dropped in one go.
if (!rewriteListener) {
op->erase();
@@ -320,6 +325,11 @@ void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
moveOpBefore(&source->front(), dest, before);
}
+ // If the current insertion point is within the source block, adjust the
+ // insertion point to the destination block.
+ if (getInsertionBlock() == source)
+ setInsertionPoint(dest, getInsertionPoint());
+
// Erase the source block.
assert(source->empty() && "expected 'source' to be empty");
eraseBlock(source);
diff --git a/mlir/lib/Support/TypeID.cpp b/mlir/lib/Support/TypeID.cpp
index 01ad910..304253c 100644
--- a/mlir/lib/Support/TypeID.cpp
+++ b/mlir/lib/Support/TypeID.cpp
@@ -27,9 +27,6 @@ namespace {
struct ImplicitTypeIDRegistry {
/// Lookup or insert a TypeID for the given type name.
TypeID lookupOrInsert(StringRef typeName) {
- LLVM_DEBUG(llvm::dbgs() << "ImplicitTypeIDRegistry::lookupOrInsert("
- << typeName << ")\n");
-
// Perform a heuristic check to see if this type is in an anonymous
// namespace. String equality is not valid for anonymous types, so we try to
// abort whenever we see them.
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 5650de2..4ccb83f 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -60,7 +60,6 @@
#include <vector>
#define DEBUG_TYPE "remove-dead-values"
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
namespace mlir {
#define GEN_PASS_DEF_REMOVEDEADVALUES
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 7502dc6..08803e0 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Rewrite/PatternApplicator.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
@@ -508,9 +509,11 @@ private:
class MoveBlockRewrite : public BlockRewrite {
public:
MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
- Region *region, Block *insertBeforeBlock)
- : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region),
- insertBeforeBlock(insertBeforeBlock) {}
+ Region *previousRegion, Region::iterator previousIt)
+ : BlockRewrite(Kind::MoveBlock, rewriterImpl, block),
+ region(previousRegion),
+ insertBeforeBlock(previousIt == previousRegion->end() ? nullptr
+ : &*previousIt) {}
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::MoveBlock;
@@ -617,9 +620,12 @@ protected:
class MoveOperationRewrite : public OperationRewrite {
public:
MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
- Operation *op, Block *block, Operation *insertBeforeOp)
- : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block),
- insertBeforeOp(insertBeforeOp) {}
+ Operation *op, OpBuilder::InsertPoint previous)
+ : OperationRewrite(Kind::MoveOperation, rewriterImpl, op),
+ block(previous.getBlock()),
+ insertBeforeOp(previous.getPoint() == previous.getBlock()->end()
+ ? nullptr
+ : &*previous.getPoint()) {}
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::MoveOperation;
@@ -1588,23 +1594,30 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
void ConversionPatternRewriterImpl::notifyOperationInserted(
Operation *op, OpBuilder::InsertPoint previous) {
+ // If no previous insertion point is provided, the op used to be detached.
+ bool wasDetached = !previous.isSet();
LLVM_DEBUG({
- logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
- << ")\n";
+ logger.startLine() << "** Insert : '" << op->getName() << "' (" << op
+ << ")";
+ if (wasDetached)
+ logger.getOStream() << " (was detached)";
+ logger.getOStream() << "\n";
});
assert(!wasOpReplaced(op->getParentOp()) &&
"attempting to insert into a block within a replaced/erased op");
- if (!previous.isSet()) {
- // This is a newly created op.
+ if (wasDetached) {
+ // If the op was detached, it is most likely a newly created op.
+ // TODO: If the same op is inserted multiple times from a detached state,
+ // the rollback mechanism may erase the same op multiple times. This is a
+ // bug in the rollback-based dialect conversion driver.
appendRewrite<CreateOperationRewrite>(op);
patternNewOps.insert(op);
return;
}
- Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
- ? nullptr
- : &*previous.getPoint();
- appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp);
+
+ // The op was moved from one place to another.
+ appendRewrite<MoveOperationRewrite>(op, previous);
}
void ConversionPatternRewriterImpl::replaceOp(
@@ -1669,29 +1682,40 @@ void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
void ConversionPatternRewriterImpl::notifyBlockInserted(
Block *block, Region *previous, Region::iterator previousIt) {
- assert(!wasOpReplaced(block->getParentOp()) &&
- "attempting to insert into a region within a replaced/erased op");
+ // If no previous insertion point is provided, the block used to be detached.
+ bool wasDetached = !previous;
+ Operation *newParentOp = block->getParentOp();
LLVM_DEBUG(
{
- Operation *parent = block->getParentOp();
+ Operation *parent = newParentOp;
if (parent) {
logger.startLine() << "** Insert Block into : '" << parent->getName()
- << "'(" << parent << ")\n";
+ << "' (" << parent << ")";
} else {
logger.startLine()
- << "** Insert Block into detached Region (nullptr parent op)'\n";
+ << "** Insert Block into detached Region (nullptr parent op)";
}
+ if (wasDetached)
+ logger.getOStream() << " (was detached)";
+ logger.getOStream() << "\n";
});
+ assert(!wasOpReplaced(newParentOp) &&
+ "attempting to insert into a region within a replaced/erased op");
+ (void)newParentOp;
patternInsertedBlocks.insert(block);
- if (!previous) {
- // This is a newly created block.
+ if (wasDetached) {
+ // If the block was detached, it is most likely a newly created block.
+ // TODO: If the same block is inserted multiple times from a detached state,
+ // the rollback mechanism may erase the same block multiple times. This is a
+ // bug in the rollback-based dialect conversion driver.
appendRewrite<CreateBlockRewrite>(block);
return;
}
- Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt;
- appendRewrite<MoveBlockRewrite>(block, previous, prevBlock);
+
+ // The block was moved from one place to another.
+ appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
}
void ConversionPatternRewriterImpl::inlineBlockBefore(Block *source,
@@ -1736,6 +1760,12 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
impl->logger.startLine()
<< "** Replace : '" << op->getName() << "'(" << op << ")\n";
});
+
+ // If the current insertion point is before the erased operation, we adjust
+ // the insertion point to be after the operation.
+ if (getInsertionPoint() == op->getIterator())
+ setInsertionPointAfter(op);
+
SmallVector<SmallVector<Value>> newVals =
llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
return v ? SmallVector<Value>{v} : SmallVector<Value>();
@@ -1751,6 +1781,12 @@ void ConversionPatternRewriter::replaceOpWithMultiple(
impl->logger.startLine()
<< "** Replace : '" << op->getName() << "'(" << op << ")\n";
});
+
+ // If the current insertion point is before the erased operation, we adjust
+ // the insertion point to be after the operation.
+ if (getInsertionPoint() == op->getIterator())
+ setInsertionPointAfter(op);
+
impl->replaceOp(op, std::move(newValues));
}
@@ -1759,6 +1795,12 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
impl->logger.startLine()
<< "** Erase : '" << op->getName() << "'(" << op << ")\n";
});
+
+ // If the current insertion point is before the erased operation, we adjust
+ // the insertion point to be after the operation.
+ if (getInsertionPoint() == op->getIterator())
+ setInsertionPointAfter(op);
+
SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {});
impl->replaceOp(op, std::move(nullRepls));
}
@@ -1865,6 +1907,11 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
moveOpBefore(&source->front(), dest, before);
}
+ // If the current insertion point is within the source block, adjust the
+ // insertion point to the destination block.
+ if (getInsertionBlock() == source)
+ setInsertionPoint(dest, getInsertionPoint());
+
// Erase the source block.
eraseBlock(source);
}
@@ -1996,6 +2043,7 @@ private:
/// Legalize the resultant IR after successfully applying the given pattern.
LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
ConversionPatternRewriter &rewriter,
+ const RewriterState &curState,
const SetVector<Operation *> &newOps,
const SetVector<Operation *> &modifiedOps,
const SetVector<Block *> &insertedBlocks);
@@ -2193,23 +2241,39 @@ OperationLegalizer::legalizeWithFold(Operation *op,
rewriterImpl.logger.startLine() << "* Fold {\n";
rewriterImpl.logger.indent();
});
- (void)rewriterImpl;
+
+ // Clear pattern state, so that the next pattern application starts with a
+ // clean slate. (The op/block sets are populated by listener notifications.)
+ auto cleanup = llvm::make_scope_exit([&]() {
+ rewriterImpl.patternNewOps.clear();
+ rewriterImpl.patternModifiedOps.clear();
+ rewriterImpl.patternInsertedBlocks.clear();
+ });
+
+ // Upon failure, undo all changes made by the folder.
+ RewriterState curState = rewriterImpl.getCurrentState();
// Try to fold the operation.
StringRef opName = op->getName().getStringRef();
SmallVector<Value, 2> replacementValues;
SmallVector<Operation *, 2> newOps;
rewriter.setInsertionPoint(op);
+ rewriter.startOpModification(op);
if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
+ rewriter.cancelOpModification(op);
return failure();
}
+ rewriter.finalizeOpModification(op);
// An empty list of replacement values indicates that the fold was in-place.
// As the operation changed, a new legalization needs to be attempted.
if (replacementValues.empty())
return legalize(op, rewriter);
+ // Insert a replacement for 'op' with the folded replacement values.
+ rewriter.replaceOp(op, replacementValues);
+
// Recursively legalize any new constant operations.
for (Operation *newOp : newOps) {
if (failed(legalize(newOp, rewriter))) {
@@ -2222,16 +2286,12 @@ OperationLegalizer::legalizeWithFold(Operation *op,
"op '" + opName +
"' folder rollback of IR modifications requested");
}
- // Legalization failed: erase all materialized constants.
- for (Operation *op : newOps)
- rewriter.eraseOp(op);
+ rewriterImpl.resetState(
+ curState, std::string(op->getName().getStringRef()) + " folder");
return failure();
}
}
- // Insert a replacement for 'op' with the folded replacement values.
- rewriter.replaceOp(op, replacementValues);
-
LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
return success();
}
@@ -2241,6 +2301,32 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
ConversionPatternRewriter &rewriter) {
auto &rewriterImpl = rewriter.getImpl();
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ Operation *checkOp;
+ std::optional<OperationFingerPrint> topLevelFingerPrint;
+ if (!rewriterImpl.config.allowPatternRollback) {
+ // The op may be getting erased, so we have to check the parent op.
+ // (In rare cases, a pattern may even erase the parent op, which will cause
+ // a crash here. Expensive checks are "best effort".) Skip the check if the
+ // op does not have a parent op.
+ if ((checkOp = op->getParentOp())) {
+ if (!op->getContext()->isMultithreadingEnabled()) {
+ topLevelFingerPrint = OperationFingerPrint(checkOp);
+ } else {
+ // Another thread may be modifying a sibling operation. Therefore, the
+ // fingerprinting mechanism of the parent op works only in
+ // single-threaded mode.
+ LLVM_DEBUG({
+ rewriterImpl.logger.startLine()
+ << "WARNING: Multi-threadeding is enabled. Some dialect "
+ "conversion expensive checks are skipped in multithreading "
+ "mode!\n";
+ });
+ }
+ }
+ }
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
// Functor that returns if the given pattern may be applied.
auto canApply = [&](const Pattern &pattern) {
bool canApply = canApplyPattern(op, pattern, rewriter);
@@ -2253,6 +2339,17 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
RewriterState curState = rewriterImpl.getCurrentState();
auto onFailure = [&](const Pattern &pattern) {
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ if (!rewriterImpl.config.allowPatternRollback) {
+ // Returning "failure" after modifying IR is not allowed.
+ if (checkOp) {
+ OperationFingerPrint fingerPrintAfterPattern(checkOp);
+ if (fingerPrintAfterPattern != *topLevelFingerPrint)
+ llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
+ "' returned failure but IR did change");
+ }
+ }
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
rewriterImpl.patternNewOps.clear();
rewriterImpl.patternModifiedOps.clear();
rewriterImpl.patternInsertedBlocks.clear();
@@ -2281,7 +2378,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
moveAndReset(rewriterImpl.patternModifiedOps);
SetVector<Block *> insertedBlocks =
moveAndReset(rewriterImpl.patternInsertedBlocks);
- auto result = legalizePatternResult(op, pattern, rewriter, newOps,
+ auto result = legalizePatternResult(op, pattern, rewriter, curState, newOps,
modifiedOps, insertedBlocks);
appliedPatterns.erase(&pattern);
if (failed(result)) {
@@ -2324,7 +2421,7 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
LogicalResult OperationLegalizer::legalizePatternResult(
Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter,
- const SetVector<Operation *> &newOps,
+ const RewriterState &curState, const SetVector<Operation *> &newOps,
const SetVector<Operation *> &modifiedOps,
const SetVector<Block *> &insertedBlocks) {
auto &impl = rewriter.getImpl();
@@ -2340,7 +2437,8 @@ LogicalResult OperationLegalizer::legalizePatternResult(
return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
};
if (!replacedRoot() && !updatedRootInPlace())
- llvm::report_fatal_error("expected pattern to replace the root operation");
+ llvm::report_fatal_error(
+ "expected pattern to replace the root operation or modify it in place");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Legalize each of the actions registered during application.
diff --git a/mlir/lib/Transforms/Utils/Inliner.cpp b/mlir/lib/Transforms/Utils/Inliner.cpp
index b639e87f..26c965c 100644
--- a/mlir/lib/Transforms/Utils/Inliner.cpp
+++ b/mlir/lib/Transforms/Utils/Inliner.cpp
@@ -21,7 +21,7 @@
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/STLExtras.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#define DEBUG_TYPE "inlining"
@@ -348,13 +348,11 @@ static void collectCallOps(iterator_range<Region::iterator> blocks,
// InlinerInterfaceImpl
//===----------------------------------------------------------------------===//
-#ifndef NDEBUG
static std::string getNodeName(CallOpInterface op) {
if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee()))
return debugString(op);
return "_unnamed_callee_";
}
-#endif
/// Return true if the specified `inlineHistoryID` indicates an inline history
/// that already includes `node`.
@@ -614,10 +612,10 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});
LLVM_DEBUG({
- llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n";
+ LDBG() << "* Inliner: Initial calls in SCC are: {";
for (unsigned i = 0, e = calls.size(); i < e; ++i)
- llvm::dbgs() << " " << i << ". " << calls[i].call << ",\n";
- llvm::dbgs() << "}\n";
+ LDBG() << " " << i << ". " << calls[i].call << ",";
+ LDBG() << "}";
});
// Try to inline each of the call operations. Don't cache the end iterator
@@ -635,9 +633,9 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
CallOpInterface call = it.call;
LLVM_DEBUG({
if (doInline)
- llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n";
+ LDBG() << "* Inlining call: " << i << ". " << call;
else
- llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n";
+ LDBG() << "* Not inlining call: " << i << ". " << call;
});
if (!doInline)
continue;
@@ -654,7 +652,7 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
cast<CallableOpInterface>(targetRegion->getParentOp()),
targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
if (failed(inlineResult)) {
- LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
+ LDBG() << "** Failed to inline";
continue;
}
inlinedAnyCalls = true;
@@ -667,19 +665,16 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
auto historyToString = [](InlineHistoryT h) {
return h.has_value() ? std::to_string(*h) : "root";
};
- (void)historyToString;
- LLVM_DEBUG(llvm::dbgs()
- << "* new inlineHistory entry: " << newInlineHistoryID << ". ["
- << getNodeName(call) << ", " << historyToString(inlineHistoryID)
- << "]\n");
+ LDBG() << "* new inlineHistory entry: " << newInlineHistoryID << ". ["
+ << getNodeName(call) << ", " << historyToString(inlineHistoryID)
+ << "]";
for (unsigned k = prevSize; k != calls.size(); ++k) {
callHistory.push_back(newInlineHistoryID);
- LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[i].call
- << "}\n with historyID = " << newInlineHistoryID
- << ", added due to inlining of\n call {" << call
- << "}\n with historyID = "
- << historyToString(inlineHistoryID) << "\n");
+ LDBG() << "* new call " << k << " {" << calls[k].call
+ << "}\n with historyID = " << newInlineHistoryID
+ << ", added due to inlining of\n call {" << call
+ << "}\n with historyID = " << historyToString(inlineHistoryID);
}
// If the inlining was successful, Merge the new uses into the source node.