aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp8
-rw-r--r--mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp2
-rw-r--r--mlir/lib/Analysis/DataFlowFramework.cpp20
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp69
-rw-r--r--mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp4
-rw-r--r--mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp64
-rw-r--r--mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp16
-rw-r--r--mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp55
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp28
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp4
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp8
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp1
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp14
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp8
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp4
-rw-r--r--mlir/lib/Dialect/Tosa/IR/TosaOps.cpp122
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp14
-rw-r--r--mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp28
-rw-r--r--mlir/lib/Dialect/Vector/IR/VectorOps.cpp111
-rw-r--r--mlir/lib/Pass/Pass.cpp2
-rw-r--r--mlir/lib/Support/TypeID.cpp3
-rw-r--r--mlir/lib/Target/SPIRV/Serialization/Serializer.cpp13
-rw-r--r--mlir/lib/Transforms/Utils/Inliner.cpp33
23 files changed, 188 insertions, 443 deletions
diff --git a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
index 51fa773..fb5649e 100644
--- a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
@@ -16,6 +16,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include <cassert>
#define DEBUG_TYPE "constant-propagation"
@@ -46,7 +47,7 @@ void ConstantValue::print(raw_ostream &os) const {
LogicalResult SparseConstantPropagation::visitOperation(
Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands,
ArrayRef<Lattice<ConstantValue> *> results) {
- LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n");
+ LDBG() << "SCP: Visiting operation: " << *op;
// Don't try to simulate the results of a region operation as we can't
// guarantee that folding will be out-of-place. We don't allow in-place
@@ -98,12 +99,11 @@ LogicalResult SparseConstantPropagation::visitOperation(
// Merge in the result of the fold, either a constant or a value.
OpFoldResult foldResult = std::get<1>(it);
if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(foldResult)) {
- LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n");
+ LDBG() << "Folded to constant: " << attr;
propagateIfChanged(lattice,
lattice->join(ConstantValue(attr, op->getDialect())));
} else {
- LLVM_DEBUG(llvm::dbgs()
- << "Folded to value: " << cast<Value>(foldResult) << "\n");
+ LDBG() << "Folded to value: " << cast<Value>(foldResult);
AbstractSparseForwardDataFlowAnalysis::join(
lattice, *getLatticeElement(cast<Value>(foldResult)));
}
diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
index 197f97f..509f520 100644
--- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
@@ -294,7 +294,7 @@ RunLivenessAnalysis::RunLivenessAnalysis(Operation *op) {
solver.load<LivenessAnalysis>(symbolTable);
LDBG() << "Initializing and running solver";
(void)solver.initializeAndRun(op);
- LDBG() << "Dumping liveness state for op";
+ LDBG() << "RunLivenessAnalysis initialized for op: " << op->getName();
}
const Liveness *RunLivenessAnalysis::getLiveness(Value val) {
diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
index 176d53e..16f7033 100644
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -14,7 +14,7 @@
#include "llvm/ADT/iterator.h"
#include "llvm/Config/abi-breaking.h"
#include "llvm/Support/Casting.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "dataflow"
@@ -44,9 +44,8 @@ void AnalysisState::addDependency(ProgramPoint *dependent,
(void)inserted;
DATAFLOW_DEBUG({
if (inserted) {
- llvm::dbgs() << "Creating dependency between " << debugName << " of "
- << anchor << "\nand " << debugName << " on " << dependent
- << "\n";
+ LDBG() << "Creating dependency between " << debugName << " of " << anchor
+ << "\nand " << debugName << " on " << dependent;
}
});
}
@@ -116,8 +115,7 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
// Initialize the analyses.
for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
- DATAFLOW_DEBUG(llvm::dbgs()
- << "Priming analysis: " << analysis.debugName << "\n");
+ DATAFLOW_DEBUG(LDBG() << "Priming analysis: " << analysis.debugName);
if (failed(analysis.initialize(top)))
return failure();
}
@@ -129,8 +127,8 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
auto [point, analysis] = worklist.front();
worklist.pop();
- DATAFLOW_DEBUG(llvm::dbgs() << "Invoking '" << analysis->debugName
- << "' on: " << point << "\n");
+ DATAFLOW_DEBUG(LDBG() << "Invoking '" << analysis->debugName
+ << "' on: " << point);
if (failed(analysis->visit(point)))
return failure();
}
@@ -143,9 +141,9 @@ void DataFlowSolver::propagateIfChanged(AnalysisState *state,
assert(isRunning &&
"DataFlowSolver is not running, should not use propagateIfChanged");
if (changed == ChangeResult::Change) {
- DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << state->debugName
- << " of " << state->anchor << "\n"
- << "Value: " << *state << "\n");
+ DATAFLOW_DEBUG(LDBG() << "Propagating update to " << state->debugName
+ << " of " << state->anchor << "\n"
+ << "Value: " << *state);
state->onUpdate(this);
}
}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 4307bc6..17a79e3 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1070,39 +1070,6 @@ public:
}
};
-class VectorExtractElementOpConversion
- : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
-public:
- using ConvertOpToLLVMPattern<
- vector::ExtractElementOp>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto vectorType = extractEltOp.getSourceVectorType();
- auto llvmType = typeConverter->convertType(vectorType.getElementType());
-
- // Bail if result type cannot be lowered.
- if (!llvmType)
- return failure();
-
- if (vectorType.getRank() == 0) {
- Location loc = extractEltOp.getLoc();
- auto idxType = rewriter.getIndexType();
- auto zero = LLVM::ConstantOp::create(rewriter, loc,
- typeConverter->convertType(idxType),
- rewriter.getIntegerAttr(idxType, 0));
- rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
- extractEltOp, llvmType, adaptor.getVector(), zero);
- return success();
- }
-
- rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
- extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
- return success();
- }
-};
-
class VectorExtractOpConversion
: public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
@@ -1206,39 +1173,6 @@ public:
}
};
-class VectorInsertElementOpConversion
- : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
-public:
- using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto vectorType = insertEltOp.getDestVectorType();
- auto llvmType = typeConverter->convertType(vectorType);
-
- // Bail if result type cannot be lowered.
- if (!llvmType)
- return failure();
-
- if (vectorType.getRank() == 0) {
- Location loc = insertEltOp.getLoc();
- auto idxType = rewriter.getIndexType();
- auto zero = LLVM::ConstantOp::create(rewriter, loc,
- typeConverter->convertType(idxType),
- rewriter.getIntegerAttr(idxType, 0));
- rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
- insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
- return success();
- }
-
- rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
- insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
- adaptor.getPosition());
- return success();
- }
-};
-
class VectorInsertOpConversion
: public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
@@ -2244,8 +2178,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorGatherOpConversion, VectorScatterOpConversion>(
converter, useVectorAlignment);
patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
- VectorExtractElementOpConversion, VectorExtractOpConversion,
- VectorFMAOp1DConversion, VectorInsertElementOpConversion,
+ VectorExtractOpConversion, VectorFMAOp1DConversion,
VectorInsertOpConversion, VectorPrintOpConversion,
VectorTypeCastOpConversion, VectorScaleOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index b1af5f0..508f4e2 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -690,7 +690,7 @@ struct PrepareTransferWriteConversion
/// %lastIndex = arith.subi %length, %c1 : index
/// vector.print punctuation <open>
/// scf.for %i = %c0 to %length step %c1 {
-/// %el = vector.extractelement %v[%i : index] : vector<[4]xi32>
+/// %el = vector.extract %v[%i] : i32 from vector<[4]xi32>
/// vector.print %el : i32 punctuation <no_punctuation>
/// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index
/// scf.if %notLastIndex {
@@ -1643,7 +1643,7 @@ struct Strategy1d<TransferWriteOp> {
/// Is rewritten to approximately the following pseudo-IR:
/// ```
/// for i = 0 to 9 {
-/// %t = vector.extractelement %vec[i] : vector<9xf32>
+/// %t = vector.extract %vec[i] : f32 from vector<9xf32>
/// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
/// }
/// ```
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 986eae3..a4be7d4 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -335,63 +335,6 @@ struct VectorInsertOpConvert final
}
};
-struct VectorExtractElementOpConvert final
- : public OpConversionPattern<vector::ExtractElementOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Type resultType = getTypeConverter()->convertType(extractOp.getType());
- if (!resultType)
- return failure();
-
- if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
- rewriter.replaceOp(extractOp, adaptor.getVector());
- return success();
- }
-
- APInt cstPos;
- if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
- rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
- extractOp, resultType, adaptor.getVector(),
- rewriter.getI32ArrayAttr({static_cast<int>(cstPos.getSExtValue())}));
- else
- rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
- extractOp, resultType, adaptor.getVector(), adaptor.getPosition());
- return success();
- }
-};
-
-struct VectorInsertElementOpConvert final
- : public OpConversionPattern<vector::InsertElementOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Type vectorType = getTypeConverter()->convertType(insertOp.getType());
- if (!vectorType)
- return failure();
-
- if (isa<spirv::ScalarType>(vectorType)) {
- rewriter.replaceOp(insertOp, adaptor.getSource());
- return success();
- }
-
- APInt cstPos;
- if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
- rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
- insertOp, adaptor.getSource(), adaptor.getDest(),
- cstPos.getSExtValue());
- else
- rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
- insertOp, vectorType, insertOp.getDest(), adaptor.getSource(),
- adaptor.getPosition());
- return success();
- }
-};
-
struct VectorInsertStridedSliceOpConvert final
: public OpConversionPattern<vector::InsertStridedSliceOp> {
using OpConversionPattern::OpConversionPattern;
@@ -1107,12 +1050,11 @@ struct VectorToElementOpConvert final
void mlir::populateVectorToSPIRVPatterns(
const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
patterns.add<
- VectorBitcastConvert, VectorBroadcastConvert,
- VectorExtractElementOpConvert, VectorExtractOpConvert,
+ VectorBitcastConvert, VectorBroadcastConvert, VectorExtractOpConvert,
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
- VectorToElementOpConvert, VectorInsertElementOpConvert,
- VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
+ VectorToElementOpConvert, VectorInsertOpConvert,
+ VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
diff --git a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
index 3c00b32..6265f46 100644
--- a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
@@ -15,13 +15,13 @@
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/InterleavedRange.h"
using namespace mlir;
using namespace mlir::affine;
#define DEBUG_TYPE "decompose-affine-ops"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define DBGSNL() (llvm::dbgs() << "\n")
/// Count the number of loops surrounding `operand` such that operand could be
/// hoisted above.
@@ -115,7 +115,7 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
return rewriter.notifyMatchFailure(
op, "only add or mul binary expr can be reassociated");
- LLVM_DEBUG(DBGS() << "Start decomposeIntoFinerGrainedOps: " << op << "\n");
+ LDBG() << "Start decomposeIntoFinerGrainedOps: " << op;
// 2. Iteratively extract the RHS subexpressions while the top-level binary
// expr kind remains the same.
@@ -125,11 +125,11 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
auto currentBinExpr = dyn_cast<AffineBinaryOpExpr>(remainingExp);
if (!currentBinExpr || currentBinExpr.getKind() != binExpr.getKind()) {
subExpressions.push_back(remainingExp);
- LLVM_DEBUG(DBGS() << "--terminal: " << subExpressions.back() << "\n");
+ LDBG() << "--terminal: " << subExpressions.back();
break;
}
subExpressions.push_back(currentBinExpr.getRHS());
- LLVM_DEBUG(DBGS() << "--subExpr: " << subExpressions.back() << "\n");
+ LDBG() << "--subExpr: " << subExpressions.back();
remainingExp = currentBinExpr.getLHS();
}
@@ -146,9 +146,7 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
llvm::stable_sort(subExpressions, [&](AffineExpr e1, AffineExpr e2) {
return getMaxSymbol(e1) < getMaxSymbol(e2);
});
- LLVM_DEBUG(
- llvm::interleaveComma(subExpressions, DBGS() << "--sorted subexprs: ");
- llvm::dbgs() << "\n");
+ LDBG() << "--sorted subexprs: " << llvm::interleaved(subExpressions);
// 4. Merge sorted subExpressions iteratively, thus achieving reassociation.
auto s0 = getAffineSymbolExpr(0, ctx);
@@ -162,7 +160,7 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
Value tmp = createSubApply(rewriter, op, subExpressions[i]);
current = AffineApplyOp::create(rewriter, op.getLoc(), binMap,
ValueRange{current, tmp});
- LLVM_DEBUG(DBGS() << "--reassociate into: " << current << "\n");
+ LDBG() << "--reassociate into: " << current;
}
// 5. Replace original op.
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
index 8493b60..2521512 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
@@ -19,11 +19,10 @@
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/IntEqClasses.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#define DEBUG_TYPE "affine-min-max"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
using namespace mlir;
using namespace mlir::affine;
@@ -39,7 +38,7 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
ValueRange operands = affineOp.getOperands();
static constexpr bool isMin = std::is_same_v<AffineOp, AffineMinOp>;
- LLVM_DEBUG({ DBGS() << "analyzing value: `" << affineOp << "`\n"; });
+ LDBG() << "analyzing value: `" << affineOp;
// Create a `Variable` list with values corresponding to each of the results
// in the affine affineMap.
@@ -48,12 +47,9 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
[&](unsigned i) {
return Variable(affineMap.getSliceMap(i, 1), operands);
});
- LLVM_DEBUG({
- DBGS() << "- constructed variables are: "
- << llvm::interleaved_array(llvm::map_range(
- variables, [](const Variable &v) { return v.getMap(); }))
- << "`\n";
- });
+ LDBG() << "- constructed variables are: "
+ << llvm::interleaved_array(llvm::map_range(
+ variables, [](const Variable &v) { return v.getMap(); }));
// Get the comparison operation.
ComparisonOperator cmpOp =
@@ -72,10 +68,8 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
// Initialize the bound.
Variable *bound = &v;
- LLVM_DEBUG({
- DBGS() << "- inspecting variable: #" << i << ", with map: `" << v.getMap()
- << "`\n";
- });
+ LDBG() << "- inspecting variable: #" << i << ", with map: `" << v.getMap()
+ << "`\n";
// Check against the other variables.
for (size_t j = i + 1; j < variables.size(); ++j) {
@@ -87,10 +81,8 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
// Get the bound of the equivalence class or itself.
Variable *nv = bounds.lookup_or(jEqClass, &variables[j]);
- LLVM_DEBUG({
- DBGS() << "- comparing with variable: #" << jEqClass
- << ", with map: " << nv->getMap() << "\n";
- });
+ LDBG() << "- comparing with variable: #" << jEqClass
+ << ", with map: " << nv->getMap();
// Compare the variables.
FailureOr<bool> cmpResult =
@@ -98,18 +90,14 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
// The variables cannot be compared.
if (failed(cmpResult)) {
- LLVM_DEBUG({
- DBGS() << "-- classes: #" << i << ", #" << jEqClass
- << " cannot be merged\n";
- });
+ LDBG() << "-- classes: #" << i << ", #" << jEqClass
+ << " cannot be merged";
continue;
}
// Join the equivalent classes and update the bound if necessary.
- LLVM_DEBUG({
- DBGS() << "-- merging classes: #" << i << ", #" << jEqClass
- << ", is cmp(lhs, rhs): " << *cmpResult << "`\n";
- });
+ LDBG() << "-- merging classes: #" << i << ", #" << jEqClass
+ << ", is cmp(lhs, rhs): " << *cmpResult << "`";
if (*cmpResult) {
boundedClasses.join(eqClass, jEqClass);
} else {
@@ -124,8 +112,7 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
// Return if there's no simplification.
if (bounds.size() >= affineMap.getNumResults()) {
- LLVM_DEBUG(
- { DBGS() << "- the affine operation couldn't get simplified\n"; });
+ LDBG() << "- the affine operation couldn't get simplified";
return false;
}
@@ -135,13 +122,11 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
for (auto [k, bound] : bounds)
results.push_back(bound->getMap().getResult(0));
- LLVM_DEBUG({
- DBGS() << "- starting from map: " << affineMap << "\n";
- DBGS() << "- creating new map with: \n";
- DBGS() << "--- dims: " << affineMap.getNumDims() << "\n";
- DBGS() << "--- syms: " << affineMap.getNumSymbols() << "\n";
- DBGS() << "--- res: " << llvm::interleaved_array(results) << "\n";
- });
+ LDBG() << "- starting from map: " << affineMap;
+ LDBG() << "- creating new map with:";
+ LDBG() << "--- dims: " << affineMap.getNumDims();
+ LDBG() << "--- syms: " << affineMap.getNumSymbols();
+ LDBG() << "--- res: " << llvm::interleaved_array(results);
affineMap =
AffineMap::get(0, affineMap.getNumSymbols() + affineMap.getNumDims(),
@@ -149,7 +134,7 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
// Update the affine op.
rewriter.modifyOpInPlace(affineOp, [&]() { affineOp.setMap(affineMap); });
- LLVM_DEBUG({ DBGS() << "- simplified affine op: `" << affineOp << "`\n"; });
+ LDBG() << "- simplified affine op: `" << affineOp << "`";
return true;
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 27b6617..b56a212 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -32,6 +32,7 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -4622,22 +4623,6 @@ static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
});
}
-/// Returns true if the dimension of `sourceShape` is smaller than the dimension
-/// of the `limitShape`.
-static bool areAllInBound(ArrayRef<int64_t> sourceShape,
- ArrayRef<int64_t> limitShape) {
- assert(
- sourceShape.size() == limitShape.size() &&
- "expected source shape rank, and limit of the shape to have same rank");
- return llvm::all_of(
- llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
- int64_t sourceExtent = std::get<0>(it);
- int64_t limit = std::get<1>(it);
- return ShapedType::isDynamic(sourceExtent) ||
- ShapedType::isDynamic(limit) || sourceExtent <= limit;
- });
-}
-
template <typename OpTy>
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
@@ -4696,11 +4681,6 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
// represents full tiles.
RankedTensorType expectedPackedType = PackOp::inferPackedType(
unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
- if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
- return op->emitError("the shape of output is not large enough to hold the "
- "packed data. Expected at least ")
- << expectedPackedType << ", got " << packedType;
- }
if (!llvm::all_of(
llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
mixedTiles),
@@ -4717,6 +4697,12 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
return op->emitError("mismatch in inner tile sizes specified and shaped of "
"tiled dimension in the packed type");
}
+ if (failed(verifyCompatibleShape(expectedPackedType.getShape(),
+ packedType.getShape()))) {
+ return op->emitError("expected ")
+ << expectedPackedType << " for the packed domain value, got "
+ << packedType;
+ }
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp b/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp
index c926dfb..5c8c2de 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/MathExtras.h"
@@ -21,7 +22,6 @@
using namespace mlir;
#define DEBUG_TYPE "linalg-transforms"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
static Attribute linearId0(MLIRContext *ctx) {
return gpu::GPUThreadMappingAttr::get(ctx, gpu::MappingId::LinearDim0);
@@ -81,7 +81,7 @@ transform::gpu::CopyMappingInfo::CopyMappingInfo(MLIRContext *ctx,
this->threadMapping =
llvm::to_vector(ArrayRef(allThreadMappings)
.take_back(this->smallestBoundingTileSizes.size()));
- LLVM_DEBUG(this->print(DBGS()); llvm::dbgs() << "\n");
+ LDBG() << *this;
}
int64_t transform::gpu::CopyMappingInfo::maxContiguousElementsToTransfer(
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
index 2fe72a3..d4a3e5f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -15,14 +15,13 @@
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
using namespace mlir;
#define DEBUG_TYPE "linalg-transforms"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
//===----------------------------------------------------------------------===//
// StructuredMatchOp
@@ -39,7 +38,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation(
return emitSilenceableError() << "expected a Linalg op";
}
// If errors are suppressed, succeed and set all results to empty lists.
- LLVM_DEBUG(DBGS() << "optional nested matcher expected a Linalg op");
+ LDBG() << "optional nested matcher expected a Linalg op";
results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
return DiagnosedSilenceableFailure::success();
}
@@ -75,8 +74,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation(
// When they are defined in this block, we additionally check if we have
// already applied the operation that defines them. If not, the
// corresponding results will be set to empty lists.
- LLVM_DEBUG(DBGS() << "optional nested matcher failed: " << diag.getMessage()
- << "\n");
+ LDBG() << "optional nested matcher failed: " << diag.getMessage();
(void)diag.silence();
SmallVector<OpOperand *> undefinedOperands;
for (OpOperand &terminatorOperand :
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 277e50b..9d7f4e0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index dad3526..57b610b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -932,20 +932,6 @@ struct PackOpTiling
continue;
}
- // If the dimension needs padding, it is not supported because there are
- // iterations that only write padding values to the whole tile. The
- // consumer fusion is driven by the source, so it is not possible to map
- // an empty slice to the tile.
- bool needExtraPadding =
- ShapedType::isDynamic(destDimSize) || !cstInnerSize ||
- destDimSize * cstInnerSize.value() != srcDimSize;
- // Prioritize the case that the op already says that it does not need
- // padding.
- if (!packOp.getPaddingValue())
- needExtraPadding = false;
- if (needExtraPadding)
- return failure();
-
// Currently fusing `packOp` as consumer only expects perfect tiling
// scenario because even if without padding semantic, the `packOp` may
// also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0170837..793eec7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1913,14 +1913,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
sourceShape.end());
- ReifiedRankedShapedTypeDims reifiedRetShapes;
- LogicalResult status =
- cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
- .reifyResultShapes(rewriter, reifiedRetShapes);
- if (status.failed()) {
- LDBG() << "Unable to reify result shapes of " << unpackOp;
- return failure();
- }
Location loc = unpackOp->getLoc();
auto padValue = arith::ConstantOp::create(
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
index 106c3b4..cce80db 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
@@ -80,10 +80,6 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
for (auto &&[opOffset, sourceOffset, sourceStride, opSize] :
llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
sourceOp.getMixedStrides(), op.getMixedSizes())) {
- // We only support static sizes.
- if (isa<Value>(opSize)) {
- return failure();
- }
sizes.push_back(opSize);
Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(opOffset),
sourceOffsetAttr =
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index ecd93ff..3cafb19 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -3647,6 +3647,22 @@ std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
return std::nullopt;
}
+static void printInitializationList(OpAsmPrinter &parser,
+ Block::BlockArgListType blocksArgs,
+ ValueRange initializers,
+ StringRef prefix = "") {
+ assert(blocksArgs.size() == initializers.size() &&
+ "expected same length of arguments and initializers");
+ if (initializers.empty())
+ return;
+
+ parser << prefix << '(';
+ llvm::interleaveComma(
+ llvm::zip(blocksArgs, initializers), parser,
+ [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
+ parser << ")";
+}
+
// parse and print of IfOp refer to the implementation of SCF dialect.
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
// Create the regions for 'then'.
@@ -3654,16 +3670,64 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
Region *thenRegion = result.addRegion();
Region *elseRegion = result.addRegion();
- auto &builder = parser.getBuilder();
OpAsmParser::UnresolvedOperand cond;
- // Create a i1 tensor type for the boolean condition.
- Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
- if (parser.parseOperand(cond) ||
- parser.resolveOperand(cond, i1Type, result.operands))
+
+ if (parser.parseOperand(cond))
return failure();
- // Parse optional results type list.
- if (parser.parseOptionalArrowTypeList(result.types))
+
+ SmallVector<OpAsmParser::Argument, 4> regionArgs;
+ SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
+
+ // Parse the optional block arguments
+ OptionalParseResult listResult =
+ parser.parseOptionalAssignmentList(regionArgs, operands);
+ if (listResult.has_value() && failed(listResult.value()))
return failure();
+
+ // Parse a colon.
+ if (failed(parser.parseColon()))
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected type for condition operand");
+
+ // Parse the type of the condition operand
+ Type condType;
+ if (failed(parser.parseType(condType)))
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected type for condition operand");
+
+ // Resolve operand with provided type
+ if (failed(parser.resolveOperand(cond, condType, result.operands)))
+ return failure();
+
+ // Parse optional block arg types
+ if (listResult.has_value()) {
+ FunctionType functionType;
+
+ if (failed(parser.parseType(functionType)))
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected list of types for block arguments "
+ << "followed by arrow type and list of return types";
+
+ result.addTypes(functionType.getResults());
+
+ if (functionType.getNumInputs() != operands.size()) {
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected as many input types as operands "
+ << "(expected " << operands.size() << " got "
+ << functionType.getNumInputs() << ")";
+ }
+
+ // Resolve input operands.
+ if (failed(parser.resolveOperands(operands, functionType.getInputs(),
+ parser.getCurrentLocation(),
+ result.operands)))
+ return failure();
+ } else {
+ // Parse optional results type list.
+ if (parser.parseOptionalArrowTypeList(result.types))
+ return failure();
+ }
+
// Parse the 'then' region.
if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
return failure();
@@ -3681,26 +3745,28 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
}
void IfOp::print(OpAsmPrinter &p) {
- bool printBlockTerminators = false;
-
p << " " << getCondition();
- if (!getResults().empty()) {
- p << " -> (" << getResultTypes() << ")";
- // Print yield explicitly if the op defines values.
- printBlockTerminators = true;
+
+ printInitializationList(p, getThenGraph().front().getArguments(),
+ getInputList(), " ");
+ p << " : ";
+ p << getCondition().getType();
+
+ if (!getInputList().empty()) {
+ p << " (";
+ llvm::interleaveComma(getInputList().getTypes(), p);
+ p << ")";
}
- p << ' ';
- p.printRegion(getThenGraph(),
- /*printEntryBlockArgs=*/false,
- /*printBlockTerminators=*/printBlockTerminators);
+ p.printArrowTypeList(getResultTypes());
+ p << " ";
+
+ p.printRegion(getThenGraph());
// Print the 'else' regions if it exists and has a block.
auto &elseRegion = getElseGraph();
if (!elseRegion.empty()) {
p << " else ";
- p.printRegion(elseRegion,
- /*printEntryBlockArgs=*/false,
- /*printBlockTerminators=*/printBlockTerminators);
+ p.printRegion(elseRegion);
}
p.printOptionalAttrDict((*this)->getAttrs());
@@ -3909,22 +3975,6 @@ ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
parser.parseOptionalAttrDictWithKeyword(result.attributes));
}
-static void printInitializationList(OpAsmPrinter &parser,
- Block::BlockArgListType blocksArgs,
- ValueRange initializers,
- StringRef prefix = "") {
- assert(blocksArgs.size() == initializers.size() &&
- "expected same length of arguments and initializers");
- if (initializers.empty())
- return;
-
- parser << prefix << '(';
- llvm::interleaveComma(
- llvm::zip(blocksArgs, initializers), parser,
- [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
- parser << ")";
-}
-
void WhileOp::print(OpAsmPrinter &parser) {
printInitializationList(parser, getCondGraph().front().getArguments(),
getInputList(), " ");
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 32b5fb6..8ec7765 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1248,16 +1248,14 @@ bool checkErrorIfCondIf(Operation *op) {
// })
//
// Simplified:
- // %0 = tosa.cond_if %arg2 {
- // tosa.yield %arg0
+ // %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) {
+ // ^bb0(%arg3, %arg4):
+ // tosa.yield %arg3
// } else {
- // tosa.yield %arg1
+ // ^bb0(%arg3, %arg4):
+ // tosa.yield %arg4
// }
- //
- // Unfortunately, the simplified syntax does not encapsulate values
- // used in then/else regions (see 'simplified' example above), so it
- // must be rewritten to use the generic syntax in order to be conformant
- // to the specification.
+
return failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) ||
failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else"));
}
diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
index c0d20d4..14a4fdf 100644
--- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
@@ -21,17 +21,8 @@
#include "llvm/Support/InterleavedRange.h"
#define DEBUG_TYPE "transform-dialect"
-#define DEBUG_TYPE_FULL "transform-dialect-full"
#define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
-#ifndef NDEBUG
-#define FULL_LDBG(X) \
- DEBUGLOG_WITH_STREAM_AND_TYPE(llvm::dbgs(), DEBUG_TYPE_FULL)
-#else
-#define FULL_LDBG(X) \
- for (bool _c = false; _c; _c = false) \
- ::llvm::nulls()
-#endif
+#define FULL_LDBG() LDBG(4)
using namespace mlir;
@@ -818,16 +809,14 @@ void transform::TransformState::compactOpHandles() {
DiagnosedSilenceableFailure
transform::TransformState::applyTransform(TransformOpInterface transform) {
- LLVM_DEBUG({
- DBGS() << "applying: ";
- transform->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
- llvm::dbgs() << "\n";
- });
+ LDBG() << "applying: "
+ << OpWithFlags(transform, OpPrintingFlags().skipRegions());
FULL_LDBG() << "Top-level payload before application:\n" << *getTopLevel();
auto printOnFailureRAII = llvm::make_scope_exit([this] {
(void)this;
- LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print(
- llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm()););
+ LDBG() << "Failing Top-level payload:\n"
+ << OpWithFlags(getTopLevel(),
+ OpPrintingFlags().printGenericOpForm());
});
// Set current transform op.
@@ -995,8 +984,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
printOnFailureRAII.release();
DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
- DBGS() << "Top-level payload:\n";
- getTopLevel()->print(llvm::dbgs());
+ LDBG() << "Top-level payload:\n" << *getTopLevel();
});
return result;
}
@@ -1273,7 +1261,7 @@ void transform::TrackingListener::notifyMatchFailure(
LLVM_DEBUG({
Diagnostic diag(loc, DiagnosticSeverity::Remark);
reasonCallback(diag);
- DBGS() << "Match Failure : " << diag.str();
+ LDBG() << "Match Failure : " << diag.str();
});
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index bce358d..8789f55 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1258,63 +1258,6 @@ void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
CanonicalizeContractAdd<arith::AddFOp>>(context);
}
-//===----------------------------------------------------------------------===//
-// ExtractElementOp
-//===----------------------------------------------------------------------===//
-
-void ExtractElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
- SetIntRangeFn setResultRanges) {
- setResultRanges(getResult(), argRanges.front());
-}
-
-void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
- Value source) {
- result.addOperands({source});
- result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType());
-}
-
-LogicalResult vector::ExtractElementOp::verify() {
- VectorType vectorType = getSourceVectorType();
- if (vectorType.getRank() == 0) {
- if (getPosition())
- return emitOpError("expected position to be empty with 0-D vector");
- return success();
- }
- if (vectorType.getRank() != 1)
- return emitOpError("unexpected >1 vector rank");
- if (!getPosition())
- return emitOpError("expected position for 1-D vector");
- return success();
-}
-
-OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
- // Skip the 0-D vector here now.
- if (!adaptor.getPosition())
- return {};
-
- // Fold extractelement (splat X) -> X.
- if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
- return splat.getInput();
-
- // Fold extractelement(broadcast(X)) -> X.
- if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
- if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
- return broadcast.getSource();
-
- auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
- auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
- if (!pos || !src)
- return {};
-
- auto srcElements = src.getValues<Attribute>();
-
- uint64_t posIdx = pos.getInt();
- if (posIdx >= srcElements.size())
- return {};
-
- return srcElements[posIdx];
-}
-
// Returns `true` if `index` is either within [0, maxIndex) or equal to
// `poisonValue`.
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
@@ -3184,60 +3127,6 @@ void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
//===----------------------------------------------------------------------===//
-// InsertElementOp
-//===----------------------------------------------------------------------===//
-
-void InsertElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
- SetIntRangeFn setResultRanges) {
- setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
-}
-
-void InsertElementOp::build(OpBuilder &builder, OperationState &result,
- Value source, Value dest) {
- build(builder, result, source, dest, {});
-}
-
-LogicalResult InsertElementOp::verify() {
- auto dstVectorType = getDestVectorType();
- if (dstVectorType.getRank() == 0) {
- if (getPosition())
- return emitOpError("expected position to be empty with 0-D vector");
- return success();
- }
- if (dstVectorType.getRank() != 1)
- return emitOpError("unexpected >1 vector rank");
- if (!getPosition())
- return emitOpError("expected position for 1-D vector");
- return success();
-}
-
-OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
- // Skip the 0-D vector here.
- if (!adaptor.getPosition())
- return {};
-
- auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
- auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
- auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
- if (!src || !dst || !pos)
- return {};
-
- if (src.getType() != getDestVectorType().getElementType())
- return {};
-
- auto dstElements = dst.getValues<Attribute>();
-
- SmallVector<Attribute> results(dstElements);
-
- uint64_t posIdx = pos.getInt();
- if (posIdx >= results.size())
- return {};
- results[posIdx] = src;
-
- return DenseElementsAttr::get(getDestVectorType(), results);
-}
-
-//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 0db9808..7094c8e 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -901,7 +901,7 @@ LogicalResult PassManager::run(Operation *op) {
if (failed(initialize(context, impl->initializationGeneration + 1)))
return failure();
initializationKey = newInitKey;
- pipelineKey = pipelineInitializationKey;
+ pipelineInitializationKey = pipelineKey;
}
// Construct a top level analysis manager for the pipeline.
diff --git a/mlir/lib/Support/TypeID.cpp b/mlir/lib/Support/TypeID.cpp
index 01ad910..304253c 100644
--- a/mlir/lib/Support/TypeID.cpp
+++ b/mlir/lib/Support/TypeID.cpp
@@ -27,9 +27,6 @@ namespace {
struct ImplicitTypeIDRegistry {
/// Lookup or insert a TypeID for the given type name.
TypeID lookupOrInsert(StringRef typeName) {
- LLVM_DEBUG(llvm::dbgs() << "ImplicitTypeIDRegistry::lookupOrInsert("
- << typeName << ")\n");
-
// Perform a heuristic check to see if this type is in an anonymous
// namespace. String equality is not valid for anonymous types, so we try to
// abort whenever we see them.
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 58e5353..a8a2b2e 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -446,6 +446,19 @@ LogicalResult Serializer::processType(Location loc, Type type,
LogicalResult
Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
SetVector<StringRef> &serializationCtx) {
+
+ // Map unsigned integer types to singless integer types.
+ // This is needed otherwise the generated spirv assembly will contain
+ // twice a type declaration (like OpTypeInt 32 0) which is no permitted and
+ // such module fails validation. Indeed at MLIR level the two types are
+ // different and lookup in the cache below misses.
+ // Note: This conversion needs to happen here before the type is looked up in
+ // the cache.
+ if (type.isUnsignedInteger()) {
+ type = IntegerType::get(loc->getContext(), type.getIntOrFloatBitWidth(),
+ IntegerType::SignednessSemantics::Signless);
+ }
+
typeID = getTypeID(type);
if (typeID)
return success();
diff --git a/mlir/lib/Transforms/Utils/Inliner.cpp b/mlir/lib/Transforms/Utils/Inliner.cpp
index b639e87f..26c965c 100644
--- a/mlir/lib/Transforms/Utils/Inliner.cpp
+++ b/mlir/lib/Transforms/Utils/Inliner.cpp
@@ -21,7 +21,7 @@
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/STLExtras.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#define DEBUG_TYPE "inlining"
@@ -348,13 +348,11 @@ static void collectCallOps(iterator_range<Region::iterator> blocks,
// InlinerInterfaceImpl
//===----------------------------------------------------------------------===//
-#ifndef NDEBUG
static std::string getNodeName(CallOpInterface op) {
if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee()))
return debugString(op);
return "_unnamed_callee_";
}
-#endif
/// Return true if the specified `inlineHistoryID` indicates an inline history
/// that already includes `node`.
@@ -614,10 +612,10 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});
LLVM_DEBUG({
- llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n";
+ LDBG() << "* Inliner: Initial calls in SCC are: {";
for (unsigned i = 0, e = calls.size(); i < e; ++i)
- llvm::dbgs() << " " << i << ". " << calls[i].call << ",\n";
- llvm::dbgs() << "}\n";
+ LDBG() << " " << i << ". " << calls[i].call << ",";
+ LDBG() << "}";
});
// Try to inline each of the call operations. Don't cache the end iterator
@@ -635,9 +633,9 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
CallOpInterface call = it.call;
LLVM_DEBUG({
if (doInline)
- llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n";
+ LDBG() << "* Inlining call: " << i << ". " << call;
else
- llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n";
+ LDBG() << "* Not inlining call: " << i << ". " << call;
});
if (!doInline)
continue;
@@ -654,7 +652,7 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
cast<CallableOpInterface>(targetRegion->getParentOp()),
targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
if (failed(inlineResult)) {
- LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
+ LDBG() << "** Failed to inline";
continue;
}
inlinedAnyCalls = true;
@@ -667,19 +665,16 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
auto historyToString = [](InlineHistoryT h) {
return h.has_value() ? std::to_string(*h) : "root";
};
- (void)historyToString;
- LLVM_DEBUG(llvm::dbgs()
- << "* new inlineHistory entry: " << newInlineHistoryID << ". ["
- << getNodeName(call) << ", " << historyToString(inlineHistoryID)
- << "]\n");
+ LDBG() << "* new inlineHistory entry: " << newInlineHistoryID << ". ["
+ << getNodeName(call) << ", " << historyToString(inlineHistoryID)
+ << "]";
for (unsigned k = prevSize; k != calls.size(); ++k) {
callHistory.push_back(newInlineHistoryID);
- LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[i].call
- << "}\n with historyID = " << newInlineHistoryID
- << ", added due to inlining of\n call {" << call
- << "}\n with historyID = "
- << historyToString(inlineHistoryID) << "\n");
+ LDBG() << "* new call " << k << " {" << calls[k].call
+ << "}\n with historyID = " << newInlineHistoryID
+ << ", added due to inlining of\n call {" << call
+ << "}\n with historyID = " << historyToString(inlineHistoryID);
}
// If the inlining was successful, Merge the new uses into the source node.