Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r--  mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp | 4
-rw-r--r--  mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 2
-rw-r--r--  mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp | 4
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 6
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Loops.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp | 5
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 5
-rw-r--r--  mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 4
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp | 12
-rw-r--r--  mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 5
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 7
-rw-r--r--  mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 2
-rw-r--r--  mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 11
-rw-r--r--  mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 8
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp | 24
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp | 11
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp | 22
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h | 17
-rw-r--r--  mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 67
-rw-r--r--  mlir/lib/Dialect/Transform/IR/TransformDialect.cpp | 15
-rw-r--r--  mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 18
22 files changed, 116 insertions, 139 deletions
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
index 9196d2e..39e398b 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
@@ -170,7 +170,7 @@ public:
op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
op1.getAcc());
})
- .Default([&](auto) { llvm_unreachable("unexpected extend op!"); });
+ .DefaultUnreachable("unexpected extend op!");
} else if (kind == arm_sme::CombiningKind::Sub) {
TypeSwitch<Operation *>(extOp)
.Case<arith::ExtFOp>([&](auto) {
@@ -188,7 +188,7 @@ public:
op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
op1.getAcc());
})
- .Default([&](auto) { llvm_unreachable("unexpected extend op!"); });
+ .DefaultUnreachable("unexpected extend op!");
} else {
llvm_unreachable("unexpected arm_sme::CombiningKind!");
}
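The recurring change in this file and several of the files below replaces a `.Default` lambda that wraps `llvm_unreachable` with `llvm::TypeSwitch`'s `DefaultUnreachable` terminator, which asserts with the given message and also supplies the switch's result type. A minimal sketch of the pattern, assuming only the message-taking `DefaultUnreachable` member exercised by this diff; the classifier function and its cases are hypothetical:

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/TypeSwitch.h"

// Hypothetical classifier: tags the extend ops it knows about and asserts on
// anything else, mirroring the before/after shape of the change above.
static llvm::StringRef classifyExtend(mlir::Operation *op) {
  return llvm::TypeSwitch<mlir::Operation *, llvm::StringRef>(op)
      .Case<mlir::arith::ExtFOp>([](auto) { return "extf"; })
      .Case<mlir::arith::ExtSIOp>([](auto) { return "extsi"; })
      // Previously: .Default([](auto) { llvm_unreachable("..."); });
      .DefaultUnreachable("unexpected extend op!");
}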
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index c0f9132..19eba6b 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -375,7 +375,7 @@ void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
os << shape.back() << 'x' << fragTy.getElementType();
os << ", \"" << fragTy.getOperand() << "\"" << '>';
})
- .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
+ .DefaultUnreachable("unexpected 'gpu' type kind");
}
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op,
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index 2561f66..0a3ef7d 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -847,9 +847,7 @@ getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
return GpuLaneIdBuilder(ctx, warpSize, useLinearMapping,
*maybeMaskingAttr);
})
- .Default([&](DeviceMappingAttrInterface) -> GpuIdBuilder {
- llvm_unreachable("unknown mapping attribute");
- });
+ .DefaultUnreachable("unknown mapping attribute");
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index ef38027..cee943d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -1096,10 +1096,8 @@ static Value memsetGetStored(MemsetIntr op, const MemorySlot &slot,
Value intVal = buildMemsetValue(type.getWidth());
return LLVM::BitcastOp::create(builder, op.getLoc(), type, intVal);
})
- .Default([](Type) -> Value {
- llvm_unreachable(
- "getStored should not be called on memset to unsupported type");
- });
+ .DefaultUnreachable(
+ "getStored should not be called on memset to unsupported type");
}
template <class MemsetIntr>
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index 297640c..705d07d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -45,9 +45,7 @@ static StringRef getTypeKeyword(Type type) {
.Case<LLVMStructType>([&](Type) { return "struct"; })
.Case<LLVMTargetExtType>([&](Type) { return "target"; })
.Case<LLVMX86AMXType>([&](Type) { return "x86_amx"; })
- .Default([](Type) -> StringRef {
- llvm_unreachable("unexpected 'llvm' type kind");
- });
+ .DefaultUnreachable("unexpected 'llvm' type kind");
}
/// Prints a structure type. Keeps track of known struct names to handle self-
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 38f1a8b..42160a1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -192,7 +192,7 @@ static void replaceIndexOpsByInductionVariables(RewriterBase &rewriter,
.Case([&](affine::AffineForOp affineForOp) {
allIvs.push_back(affineForOp.getInductionVar());
})
- .Default([&](Operation *op) { assert(false && "unexpected op"); });
+ .DefaultUnreachable("unexpected op");
}
assert(linalgOp.getNumLoops() == allIvs.size() &&
"expected the number of loops and induction variables to match");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp
index 00a076b..c904556 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp
@@ -48,10 +48,7 @@ ElementwiseKind getKind(Operation *op) {
.Case([](SquareOp) { return ElementwiseKind::square; })
.Case([](TanhOp) { return ElementwiseKind::tanh; })
.Case([](ErfOp) { return ElementwiseKind::erf; })
- .Default([&](Operation *op) {
- llvm_unreachable("unhandled case in named to elementwise");
- return ElementwiseKind::sub;
- });
+ .DefaultUnreachable("unhandled case in named to elementwise");
}
template <typename NamedOpTy>
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index e9a8b25..7863c21 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1427,10 +1427,7 @@ FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
.Case([&](linalg::PoolingNchwMaxOp op) {
return std::make_tuple(0, 1, 2, 3);
})
- .Default([&](Operation *op) {
- llvm_unreachable("unexpected conv2d/pool2d operation.");
- return std::make_tuple(0, 0, 0, 0);
- });
+ .DefaultUnreachable("unexpected conv2d/pool2d operation.");
// Only handle the case where at least one of the window dimensions is
// of size 1. Other cases can rely on tiling to reduce to such cases.
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 3593b53..24d3722 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -604,9 +604,7 @@ static Operation *materializeTiledShape(OpBuilder &builder, Location loc,
builder, loc, valueToTile, sliceParams.offsets,
sliceParams.sizes, sliceParams.strides);
})
- .Default([](ShapedType) -> Operation * {
- llvm_unreachable("Unexpected shaped type");
- });
+ .DefaultUnreachable("Unexpected shaped type");
return sliceOp;
}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 24da447..214410f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -315,7 +315,7 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
op, op.getType(), subViewOp.getSource(), sourceIndices,
op.getTranspose(), op.getNumTiles());
})
- .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
+ .DefaultUnreachable("unexpected operation");
return success();
}
@@ -367,7 +367,7 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
op.getMask(), op.getPassThru());
})
- .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
+ .DefaultUnreachable("unexpected operation");
return success();
}
@@ -415,7 +415,7 @@ LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
op.getMask(), op.getPassThru());
})
- .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
+ .DefaultUnreachable("unexpected operation");
return success();
}
@@ -482,7 +482,7 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
op, op.getSrc(), subViewOp.getSource(), sourceIndices,
op.getLeadDimension(), op.getTransposeAttr());
})
- .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
+ .DefaultUnreachable("unexpected operation");
return success();
}
@@ -535,7 +535,7 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
op.getValueToStore());
})
- .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
+ .DefaultUnreachable("unexpected operation");
return success();
}
@@ -584,7 +584,7 @@ LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
op.getValueToStore());
})
- .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
+ .DefaultUnreachable("unexpected operation");
return success();
}
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 5672942..fd4cabbad 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -3425,10 +3425,7 @@ void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
}
llvm_unreachable("Unexpected generatee argument");
})
- .Default([&](Operation *op) {
- assert(false && "TODO: Custom name for this operation");
- return "transformed";
- });
+ .DefaultUnreachable("TODO: Custom name for this operation");
}
setNameFn(result, cliName);
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 36685d3..29b770f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -2177,10 +2177,9 @@ cloneAsInsertSlices(RewriterBase &rewriter,
auto clonedOp = cloneAsInsertSlice(rewriter, op);
clonedSlices.push_back(clonedOp);
})
- .Default([&](Operation *op) {
- // Assert here assuming this has already been checked.
- assert(0 && "unexpected slice type while cloning as insert slice");
- });
+ // Assert here assuming this has already been checked.
+ .DefaultUnreachable(
+ "unexpected slice type while cloning as insert slice");
}
return clonedSlices;
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index c8efdf0..24c33f9 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -987,7 +987,7 @@ void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
.Case<ArrayType, CooperativeMatrixType, PointerType, RuntimeArrayType,
ImageType, SampledImageType, StructType, MatrixType, TensorArmType>(
[&](auto type) { print(type, os); })
- .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
+ .DefaultUnreachable("Unhandled SPIR-V type");
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 7e9a80e..f895807 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -57,7 +57,7 @@ public:
for (Type elementType : concreteType.getElementTypes())
add(elementType);
})
- .Default([](SPIRVType) { llvm_unreachable("Unhandled type"); });
+ .DefaultUnreachable("Unhandled type");
}
void add(Type type) { add(cast<SPIRVType>(type)); }
@@ -107,7 +107,7 @@ public:
for (Type elementType : concreteType.getElementTypes())
add(elementType);
})
- .Default([](SPIRVType) { llvm_unreachable("Unhandled type"); });
+ .DefaultUnreachable("Unhandled type");
}
void add(Type type) { add(cast<SPIRVType>(type)); }
@@ -198,8 +198,7 @@ Type CompositeType::getElementType(unsigned index) const {
.Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
.Case<StructType>(
[index](StructType type) { return type.getElementType(index); })
- .Default(
- [](Type) -> Type { llvm_unreachable("invalid composite type"); });
+ .DefaultUnreachable("Invalid composite type");
}
unsigned CompositeType::getNumElements() const {
@@ -207,9 +206,7 @@ unsigned CompositeType::getNumElements() const {
.Case<ArrayType, StructType, TensorArmType, VectorType>(
[](auto type) { return type.getNumElements(); })
.Case<MatrixType>([](MatrixType type) { return type.getNumColumns(); })
- .Default([](SPIRVType) -> unsigned {
- llvm_unreachable("Invalid type for number of elements query");
- });
+ .DefaultUnreachable("Invalid type for number of elements query");
}
bool CompositeType::hasCompileTimeKnownNumElements() const {
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 122f61e0..88e1ab6 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -622,7 +622,7 @@ static spirv::Dim convertRank(int64_t rank) {
}
static spirv::ImageFormat getImageFormat(Type elementType) {
- return llvm::TypeSwitch<Type, spirv::ImageFormat>(elementType)
+ return TypeSwitch<Type, spirv::ImageFormat>(elementType)
.Case<Float16Type>([](Float16Type) { return spirv::ImageFormat::R16f; })
.Case<Float32Type>([](Float32Type) { return spirv::ImageFormat::R32f; })
.Case<IntegerType>([](IntegerType intType) {
@@ -639,11 +639,7 @@ static spirv::ImageFormat getImageFormat(Type elementType) {
llvm_unreachable("Unhandled integer type!");
}
})
- .Default([](Type) {
- llvm_unreachable("Unhandled element type!");
- // We need to return something here to satisfy the type switch.
- return spirv::ImageFormat::R32f;
- });
+ .DefaultUnreachable("Unhandled element type!");
#undef BIT_WIDTH_CASE
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index a1e35b8..0fc5cc7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -59,7 +59,7 @@ struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
// Flattens an affine expression into a list of AffineDimExprs.
struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
- explicit AffineDimCollector(unsigned dimNum) : dims(dimNum){};
+ explicit AffineDimCollector(unsigned dimNum) : dims(dimNum) {};
void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
BitVector dims;
};
@@ -67,7 +67,7 @@ struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
// Flattens an affine expression into a list of AffineDimExprs.
struct AffineExprAdmissibleVisitor
: public AffineExprVisitor<AffineExprAdmissibleVisitor> {
- explicit AffineExprAdmissibleVisitor(bool isOutput) : isOutput(isOutput){};
+ explicit AffineExprAdmissibleVisitor(bool isOutput) : isOutput(isOutput) {};
// We only allow AffineDimExpr on output.
void visitAddExpr(AffineBinaryOpExpr expr) {
@@ -407,7 +407,10 @@ public:
};
struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
- using OpRewritePattern::OpRewritePattern;
+ GenericOpScheduler(MLIRContext *context,
+ sparse_tensor::LoopOrderingStrategy strategy)
+ : OpRewritePattern<linalg::GenericOp>(context), strategy(strategy) {}
+
LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
PatternRewriter &rewriter) const override {
if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
@@ -420,7 +423,8 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
if (linalgOp->hasAttr(sorted))
return failure();
- auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp);
+ // Pass strategy to IterationGraphSorter.
+ auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp, strategy);
bool isAdmissible = false;
AffineMap order;
// A const list of all masks that we used for iteration graph
@@ -582,6 +586,9 @@ private:
// TODO: convert more than one?
return failure();
}
+
+private:
+ sparse_tensor::LoopOrderingStrategy strategy;
};
//===----------------------------------------------------------------------===//
@@ -786,12 +793,13 @@ struct ForeachOpDemapper
} // namespace
-void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
- ReinterpretMapScope scope) {
+void mlir::populateSparseReinterpretMap(
+ RewritePatternSet &patterns, ReinterpretMapScope scope,
+ sparse_tensor::LoopOrderingStrategy strategy) {
if (scope == ReinterpretMapScope::kAll ||
scope == ReinterpretMapScope::kGenericOnly) {
- patterns.add<GenericOpReinterpretMap, GenericOpScheduler>(
- patterns.getContext());
+ patterns.add<GenericOpReinterpretMap>(patterns.getContext());
+ patterns.add<GenericOpScheduler>(patterns.getContext(), strategy);
}
if (scope == ReinterpretMapScope::kAll ||
scope == ReinterpretMapScope::kExceptGeneric) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 153b9b1..b660e22 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -67,12 +67,13 @@ struct SparseReinterpretMap
SparseReinterpretMap(const SparseReinterpretMap &pass) = default;
SparseReinterpretMap(const SparseReinterpretMapOptions &options) {
scope = options.scope;
+ loopOrderingStrategy = options.loopOrderingStrategy;
}
void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
- populateSparseReinterpretMap(patterns, scope);
+ populateSparseReinterpretMap(patterns, scope, loopOrderingStrategy);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};
@@ -438,6 +439,14 @@ mlir::createSparseReinterpretMapPass(ReinterpretMapScope scope) {
return std::make_unique<SparseReinterpretMap>(options);
}
+std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass(
+ ReinterpretMapScope scope, sparse_tensor::LoopOrderingStrategy strategy) {
+ SparseReinterpretMapOptions options;
+ options.scope = scope;
+ options.loopOrderingStrategy = strategy;
+ return std::make_unique<SparseReinterpretMap>(options);
+}
+
std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
return std::make_unique<PreSparsificationRewritePass>();
}
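For reference, the strategy-aware overload added above could be used when assembling a pass pipeline. A hedged sketch, assuming the overload and the `LoopOrderingStrategy` enum are declared in the SparseTensor `Passes.h` header; the pipeline hook itself is hypothetical:

#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"

// Hypothetical pipeline hook: schedules the reinterpret-map pass with an
// explicit loop-ordering strategy instead of the scope-only overload.
static void addSparseReinterpretMap(mlir::OpPassManager &pm) {
  pm.addPass(mlir::createSparseReinterpretMapPass(
      mlir::ReinterpretMapScope::kGenericOnly,
      mlir::sparse_tensor::LoopOrderingStrategy::kDefault));
}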
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
index c7e463a..73e0f3d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
@@ -100,7 +100,15 @@ AffineMap IterationGraphSorter::topoSort() {
// We always prefer a parallel loop over a reduction loop because putting
// a reduction loop early might make the loop sequence inadmissible.
auto &it = !parIt.empty() ? parIt : redIt;
- auto src = it.back();
+
+ // Select loop based on strategy.
+ unsigned src;
+ switch (strategy) {
+ case sparse_tensor::LoopOrderingStrategy::kDefault:
+ src = it.back();
+ break;
+ }
+
loopOrder.push_back(src);
it.pop_back();
// Update in-degree, and push 0-degree node into worklist.
@@ -122,8 +130,8 @@ AffineMap IterationGraphSorter::topoSort() {
return AffineMap();
}
-IterationGraphSorter
-IterationGraphSorter::fromGenericOp(linalg::GenericOp genericOp) {
+IterationGraphSorter IterationGraphSorter::fromGenericOp(
+ linalg::GenericOp genericOp, sparse_tensor::LoopOrderingStrategy strategy) {
// Must be a demapped sparse kernel.
assert(!hasAnyNonIdentityOperandsOrResults(genericOp) &&
hasAnySparseOperandOrResult(genericOp) &&
@@ -140,14 +148,16 @@ IterationGraphSorter::fromGenericOp(linalg::GenericOp genericOp) {
genericOp.getIteratorTypesArray();
return IterationGraphSorter(std::move(ins), std::move(loopMap), out, outMap,
- std::move(iterTypes));
+ std::move(iterTypes), strategy);
}
IterationGraphSorter::IterationGraphSorter(
SmallVector<Value> &&ins, SmallVector<AffineMap> &&loop2InsLvl, Value out,
- AffineMap loop2OutLvl, SmallVector<utils::IteratorType> &&iterTypes)
+ AffineMap loop2OutLvl, SmallVector<utils::IteratorType> &&iterTypes,
+ sparse_tensor::LoopOrderingStrategy strategy)
: ins(std::move(ins)), loop2InsLvl(std::move(loop2InsLvl)), out(out),
- loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)) {
+ loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)),
+ strategy(strategy) {
// One map per tensor.
assert(loop2InsLvl.size() == ins.size());
// All the affine maps have the same number of dimensions (loops).
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h
index a6abe9e..b2a16e9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h
@@ -13,6 +13,7 @@
#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_ITERATIONGRAPHSORTER_H_
#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_ITERATIONGRAPHSORTER_H_
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/IR/AffineMap.h"
namespace mlir {
@@ -41,9 +42,12 @@ enum class SortMask : unsigned {
class IterationGraphSorter {
public:
- /// Factory method that construct an iteration graph sorter
- /// for the given linalg.generic operation.
- static IterationGraphSorter fromGenericOp(linalg::GenericOp genericOp);
+ /// Factory method that constructs an iteration graph sorter
+ /// for the given linalg.generic operation with a specific loop ordering
+ /// strategy.
+ static IterationGraphSorter
+ fromGenericOp(linalg::GenericOp genericOp,
+ sparse_tensor::LoopOrderingStrategy strategy);
/// Returns a permutation that represents the scheduled loop order.
/// Note that the returned AffineMap could be null if the kernel
@@ -58,7 +62,9 @@ private:
IterationGraphSorter(SmallVector<Value> &&ins,
SmallVector<AffineMap> &&loop2InsLvl, Value out,
AffineMap loop2OutLvl,
- SmallVector<utils::IteratorType> &&iterTypes);
+ SmallVector<utils::IteratorType> &&iterTypes,
+ sparse_tensor::LoopOrderingStrategy strategy =
+ sparse_tensor::LoopOrderingStrategy::kDefault);
// Adds all the constraints in the given loop to level map.
void addConstraints(Value t, AffineMap loop2LvlMap);
@@ -84,6 +90,9 @@ private:
// InDegree used for topo sort.
std::vector<unsigned> inDegree;
+
+ // Loop ordering strategy.
+ sparse_tensor::LoopOrderingStrategy strategy;
};
} // namespace sparse_tensor
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 332f1a0..c51b5e9 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -905,56 +905,29 @@ static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
return shapeAdaptor.getNumElements() == 1 ? success() : failure();
}
-// Returns the first declaration point prior to this operation or failure if
-// not found.
-static FailureOr<tosa::VariableOp> findVariableDecl(Operation *op,
- StringRef symName) {
- ModuleOp module = op->getParentOfType<ModuleOp>();
- tosa::VariableOp varOp = nullptr;
-
- // TODO: Adopt SymbolTable trait to Varible ops.
- // Currently, the variable's definition point is searched via walk(),
- // starting from the top-level ModuleOp and stopping at the point of use. Once
- // TOSA control flow and variable extensions reach the complete state, may
- // leverage MLIR's Symbol Table functionality to look up symbol and enhance
- // the search to a TOSA specific graph traversal over the IR structure.
- module.walk([&](Operation *tempOp) {
- // Reach this op itself.
- if (tempOp == op) {
- return WalkResult::interrupt();
- }
-
- if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
- if (symName == tosaOp.getName()) {
- varOp = tosaOp;
- return WalkResult::interrupt();
- }
- }
-
- return WalkResult::advance();
- });
-
- if (varOp)
- return varOp;
-
- return failure();
-}
-
template <typename T>
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
- StringRef symName = op.getName();
- FailureOr<tosa::VariableOp> varOp = findVariableDecl(op, symName);
- if (failed(varOp))
+ Operation *symTableOp =
+ op->template getParentWithTrait<OpTrait::SymbolTable>();
+ if (!symTableOp)
+ // If the operation is not the scope of a symbol table, we cannot
+    // verify it against its declaration.
+ return success();
+
+ SymbolTable symTable(symTableOp);
+ const auto varOp = symTable.lookup<tosa::VariableOp>(op.getName());
+
+ // Verify prior declaration
+ if (!varOp)
return op->emitOpError("'")
- << symName << "' has not been declared by 'tosa.variable'";
+ << op.getName() << "' has not been declared by 'tosa.variable'";
// Verify type and shape
- auto variableType = getVariableType(varOp.value());
+ auto variableType = getVariableType(varOp);
if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
"the input tensor")
.failed())
return failure();
-
return success();
}
@@ -1418,7 +1391,7 @@ static void buildVariableOp(OpBuilder &builder, OperationState &result,
ArrayRef<int64_t> shape = shapedType.getShape();
auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
- result.addAttribute("name", nameAttr);
+ result.addAttribute("sym_name", nameAttr);
result.addAttribute("var_shape", varShapeAttr);
result.addAttribute("type", elementTypeAttr);
result.addAttribute("initial_value", initialValue);
@@ -4160,16 +4133,6 @@ LogicalResult tosa::SelectOp::verify() {
return success();
}
-LogicalResult tosa::VariableOp::verify() {
- StringRef symName = getName();
- FailureOr<tosa::VariableOp> varOp = findVariableDecl(*this, symName);
- if (succeeded(varOp))
- return emitOpError("illegal to have multiple declaration of '")
- << symName << "'";
-
- return success();
-}
-
LogicalResult tosa::VariableReadOp::verify() {
if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
.failed())
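The TOSA change above replaces a module-wide walk with MLIR's symbol-table machinery when resolving `tosa.variable` declarations. A minimal sketch of that lookup, using only the calls that appear in the diff; the free-standing helper is hypothetical:

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/SymbolTable.h"

// Hypothetical helper: resolve the tosa.variable declaring `symName` through
// the nearest enclosing symbol table instead of walking the whole module.
static mlir::tosa::VariableOp lookupVariableDecl(mlir::Operation *user,
                                                 llvm::StringRef symName) {
  mlir::Operation *symTableOp =
      user->getParentWithTrait<mlir::OpTrait::SymbolTable>();
  if (!symTableOp)
    return nullptr; // No enclosing symbol table; nothing to verify against.
  mlir::SymbolTable symTable(symTableOp);
  return symTable.lookup<mlir::tosa::VariableOp>(symName);
}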
diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
index a500228..45cef9c1 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Transform/IR/Utils.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Verifier.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -140,6 +141,20 @@ LogicalResult transform::TransformDialect::verifyOperationAttribute(
"operations with symbol tables";
}
+ // Pre-verify calls and callables because call graph construction below
+ // assumes they are valid, but this verifier runs before verifying the
+ // nested operations.
+ WalkResult walkResult = op->walk([](Operation *nested) {
+ if (!isa<CallableOpInterface, CallOpInterface>(nested))
+ return WalkResult::advance();
+
+ if (failed(verify(nested, /*verifyRecursively=*/false)))
+ return WalkResult::interrupt();
+ return WalkResult::advance();
+ });
+ if (walkResult.wasInterrupted())
+ return failure();
+
const mlir::CallGraph callgraph(op);
for (auto scc = llvm::scc_begin(&callgraph); !scc.isAtEnd(); ++scc) {
if (!scc.hasCycle())
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 3385b2a..365afab 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -2097,17 +2097,11 @@ void transform::IncludeOp::getEffects(
getOperation(), getTarget());
if (!callee)
return defaultEffects();
- DiagnosedSilenceableFailure earlyVerifierResult =
- verifyNamedSequenceOp(callee, /*emitWarnings=*/false);
- if (!earlyVerifierResult.succeeded()) {
- (void)earlyVerifierResult.silence();
- return defaultEffects();
- }
for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
consumesHandle(getOperation()->getOpOperand(i), effects);
- else
+ else if (callee.getArgAttr(i, TransformDialect::kArgReadOnlyAttrName))
onlyReadsHandle(getOperation()->getOpOperand(i), effects);
}
}
@@ -2597,10 +2591,7 @@ transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
.Case([&](TransformParamTypeInterface param) {
return llvm::range_size(state.getParams(getHandle()));
})
- .Default([](Type) {
- llvm_unreachable("unknown kind of transform dialect type");
- return 0;
- });
+ .DefaultUnreachable("unknown kind of transform dialect type");
results.setParams(cast<OpResult>(getNum()),
rewriter.getI64IntegerAttr(numAssociations));
return DiagnosedSilenceableFailure::success();
@@ -2657,10 +2648,7 @@ transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
.Case<TransformParamTypeInterface>([&](auto x) {
return llvm::range_size(state.getParams(getHandle()));
})
- .Default([](auto x) {
- llvm_unreachable("unknown transform dialect type interface");
- return -1;
- });
+ .DefaultUnreachable("unknown transform dialect type interface");
auto produceNumOpsError = [&]() {
return emitSilenceableError()