diff options
Diffstat (limited to 'mlir/lib')
35 files changed, 155 insertions, 180 deletions
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp index d57926ec..39d4815 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -243,7 +243,7 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList, .Case<OperandPosition, OperandGroupPosition>([&](auto *pos) { getOperandTreePredicates(predList, val, builder, inputs, pos); }) - .Default([](auto *) { llvm_unreachable("unexpected position kind"); }); + .DefaultUnreachable("unexpected position kind"); } static void getAttributePredicates(pdl::AttributeOp op, diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index 9b61540..50fca56 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -1118,10 +1118,7 @@ StringRef getTypeMangling(Type type, bool isSigned) { llvm_unreachable("Unsupported integer width"); } }) - .Default([](auto) { - llvm_unreachable("No mangling defined"); - return ""; - }); + .DefaultUnreachable("No mangling defined"); } template <typename ReduceOp> diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp index 0f90acf..57877b8 100644 --- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp +++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp @@ -68,9 +68,7 @@ std::string getTypeMangling(Type ty, bool isUnsigned = false) { llvm_unreachable("unhandled integer type"); } }) - .Default([](Type) -> std::string { - llvm_unreachable("unhandled type for mangling"); - }); + .DefaultUnreachable("unhandled type for mangling"); } std::string mangle(StringRef baseName, ArrayRef<Type> types, diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp index 9196d2e..39e398b 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp @@ -170,7 +170,7 @@ public: op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); }) - .Default([&](auto) { llvm_unreachable("unexpected extend op!"); }); + .DefaultUnreachable("unexpected extend op!"); } else if (kind == arm_sme::CombiningKind::Sub) { TypeSwitch<Operation *>(extOp) .Case<arith::ExtFOp>([&](auto) { @@ -188,7 +188,7 @@ public: op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc()); }) - .Default([&](auto) { llvm_unreachable("unexpected extend op!"); }); + .DefaultUnreachable("unexpected extend op!"); } else { llvm_unreachable("unexpected arm_sme::CombiningKind!"); } diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index c0f9132..19eba6b 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -375,7 +375,7 @@ void GPUDialect::printType(Type type, DialectAsmPrinter &os) const { os << shape.back() << 'x' << fragTy.getElementType(); os << ", \"" << fragTy.getOperand() << "\"" << '>'; }) - .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); }); + .DefaultUnreachable("unexpected 'gpu' type kind"); } static LogicalResult verifyKnownLaunchSizeAttr(Operation *op, diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp index 2561f66..0a3ef7d 100644 --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -847,9 +847,7 @@ getThreadIdBuilder(std::optional<TransformOpInterface> transformOp, return GpuLaneIdBuilder(ctx, warpSize, useLinearMapping, *maybeMaskingAttr); }) - .Default([&](DeviceMappingAttrInterface) -> GpuIdBuilder { - llvm_unreachable("unknown mapping attribute"); - }); + .DefaultUnreachable("unknown mapping attribute"); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index ef38027..cee943d 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -1096,10 +1096,8 @@ static Value memsetGetStored(MemsetIntr op, const MemorySlot &slot, Value intVal = buildMemsetValue(type.getWidth()); return LLVM::BitcastOp::create(builder, op.getLoc(), type, intVal); }) - .Default([](Type) -> Value { - llvm_unreachable( - "getStored should not be called on memset to unsupported type"); - }); + .DefaultUnreachable( + "getStored should not be called on memset to unsupported type"); } template <class MemsetIntr> diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp index 297640c..705d07d 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp @@ -45,9 +45,7 @@ static StringRef getTypeKeyword(Type type) { .Case<LLVMStructType>([&](Type) { return "struct"; }) .Case<LLVMTargetExtType>([&](Type) { return "target"; }) .Case<LLVMX86AMXType>([&](Type) { return "x86_amx"; }) - .Default([](Type) -> StringRef { - llvm_unreachable("unexpected 'llvm' type kind"); - }); + .DefaultUnreachable("unexpected 'llvm' type kind"); } /// Prints a structure type. Keeps track of known struct names to handle self- diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp index 38f1a8b..42160a1 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -192,7 +192,7 @@ static void replaceIndexOpsByInductionVariables(RewriterBase &rewriter, .Case([&](affine::AffineForOp affineForOp) { allIvs.push_back(affineForOp.getInductionVar()); }) - .Default([&](Operation *op) { assert(false && "unexpected op"); }); + .DefaultUnreachable("unexpected op"); } assert(linalgOp.getNumLoops() == allIvs.size() && "expected the number of loops and induction variables to match"); diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp index 00a076b..c904556 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp @@ -48,10 +48,7 @@ ElementwiseKind getKind(Operation *op) { .Case([](SquareOp) { return ElementwiseKind::square; }) .Case([](TanhOp) { return ElementwiseKind::tanh; }) .Case([](ErfOp) { return ElementwiseKind::erf; }) - .Default([&](Operation *op) { - llvm_unreachable("unhandled case in named to elementwise"); - return ElementwiseKind::sub; - }); + .DefaultUnreachable("unhandled case in named to elementwise"); } template <typename NamedOpTy> diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index e9a8b25..7863c21 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -1427,10 +1427,7 @@ FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>:: .Case([&](linalg::PoolingNchwMaxOp op) { return std::make_tuple(0, 1, 2, 3); }) - .Default([&](Operation *op) { - llvm_unreachable("unexpected conv2d/pool2d operation."); - return std::make_tuple(0, 0, 0, 0); - }); + .DefaultUnreachable("unexpected conv2d/pool2d operation."); // Only handle the case where at least one of the window dimensions is // of size 1. Other cases can rely on tiling to reduce to such cases. diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 3593b53..24d3722 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -604,9 +604,7 @@ static Operation *materializeTiledShape(OpBuilder &builder, Location loc, builder, loc, valueToTile, sliceParams.offsets, sliceParams.sizes, sliceParams.strides); }) - .Default([](ShapedType) -> Operation * { - llvm_unreachable("Unexpected shaped type"); - }); + .DefaultUnreachable("Unexpected shaped type"); return sliceOp; } diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp index 24da447..214410f 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp @@ -315,7 +315,7 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite( op, op.getType(), subViewOp.getSource(), sourceIndices, op.getTranspose(), op.getNumTiles()); }) - .Default([](Operation *) { llvm_unreachable("unexpected operation."); }); + .DefaultUnreachable("unexpected operation"); return success(); } @@ -367,7 +367,7 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite( op, op.getType(), expandShapeOp.getViewSource(), sourceIndices, op.getMask(), op.getPassThru()); }) - .Default([](Operation *) { llvm_unreachable("unexpected operation."); }); + .DefaultUnreachable("unexpected operation"); return success(); } @@ -415,7 +415,7 @@ LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite( op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices, op.getMask(), op.getPassThru()); }) - .Default([](Operation *) { llvm_unreachable("unexpected operation."); }); + .DefaultUnreachable("unexpected operation"); return success(); } @@ -482,7 +482,7 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite( op, op.getSrc(), subViewOp.getSource(), sourceIndices, op.getLeadDimension(), op.getTransposeAttr()); }) - .Default([](Operation *) { llvm_unreachable("unexpected operation."); }); + .DefaultUnreachable("unexpected operation"); return success(); } @@ -535,7 +535,7 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite( op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(), op.getValueToStore()); }) - .Default([](Operation *) { llvm_unreachable("unexpected operation."); }); + .DefaultUnreachable("unexpected operation"); return success(); } @@ -584,7 +584,7 @@ LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite( op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(), op.getValueToStore()); }) - .Default([](Operation *) { llvm_unreachable("unexpected operation."); }); + .DefaultUnreachable("unexpected operation"); return success(); } diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 5672942..fd4cabbad 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -3425,10 +3425,7 @@ void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { } llvm_unreachable("Unexpected generatee argument"); }) - .Default([&](Operation *op) { - assert(false && "TODO: Custom name for this operation"); - return "transformed"; - }); + .DefaultUnreachable("TODO: Custom name for this operation"); } setNameFn(result, cliName); diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 36685d3..29b770f 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -2177,10 +2177,9 @@ cloneAsInsertSlices(RewriterBase &rewriter, auto clonedOp = cloneAsInsertSlice(rewriter, op); clonedSlices.push_back(clonedOp); }) - .Default([&](Operation *op) { - // Assert here assuming this has already been checked. - assert(0 && "unexpected slice type while cloning as insert slice"); - }); + // Assert here assuming this has already been checked. + .DefaultUnreachable( + "unexpected slice type while cloning as insert slice"); } return clonedSlices; } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp index c8efdf0..24c33f9 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -987,7 +987,7 @@ void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const { .Case<ArrayType, CooperativeMatrixType, PointerType, RuntimeArrayType, ImageType, SampledImageType, StructType, MatrixType, TensorArmType>( [&](auto type) { print(type, os); }) - .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); }); + .DefaultUnreachable("Unhandled SPIR-V type"); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index 7e9a80e..f895807 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -57,7 +57,7 @@ public: for (Type elementType : concreteType.getElementTypes()) add(elementType); }) - .Default([](SPIRVType) { llvm_unreachable("Unhandled type"); }); + .DefaultUnreachable("Unhandled type"); } void add(Type type) { add(cast<SPIRVType>(type)); } @@ -107,7 +107,7 @@ public: for (Type elementType : concreteType.getElementTypes()) add(elementType); }) - .Default([](SPIRVType) { llvm_unreachable("Unhandled type"); }); + .DefaultUnreachable("Unhandled type"); } void add(Type type) { add(cast<SPIRVType>(type)); } @@ -198,8 +198,7 @@ Type CompositeType::getElementType(unsigned index) const { .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); }) .Case<StructType>( [index](StructType type) { return type.getElementType(index); }) - .Default( - [](Type) -> Type { llvm_unreachable("invalid composite type"); }); + .DefaultUnreachable("Invalid composite type"); } unsigned CompositeType::getNumElements() const { @@ -207,9 +206,7 @@ unsigned CompositeType::getNumElements() const { .Case<ArrayType, StructType, TensorArmType, VectorType>( [](auto type) { return type.getNumElements(); }) .Case<MatrixType>([](MatrixType type) { return type.getNumColumns(); }) - .Default([](SPIRVType) -> unsigned { - llvm_unreachable("Invalid type for number of elements query"); - }); + .DefaultUnreachable("Invalid type for number of elements query"); } bool CompositeType::hasCompileTimeKnownNumElements() const { diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 122f61e0..88e1ab6 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -622,7 +622,7 @@ static spirv::Dim convertRank(int64_t rank) { } static spirv::ImageFormat getImageFormat(Type elementType) { - return llvm::TypeSwitch<Type, spirv::ImageFormat>(elementType) + return TypeSwitch<Type, spirv::ImageFormat>(elementType) .Case<Float16Type>([](Float16Type) { return spirv::ImageFormat::R16f; }) .Case<Float32Type>([](Float32Type) { return spirv::ImageFormat::R32f; }) .Case<IntegerType>([](IntegerType intType) { @@ -639,11 +639,7 @@ static spirv::ImageFormat getImageFormat(Type elementType) { llvm_unreachable("Unhandled integer type!"); } }) - .Default([](Type) { - llvm_unreachable("Unhandled element type!"); - // We need to return something here to satisfy the type switch. - return spirv::ImageFormat::R32f; - }); + .DefaultUnreachable("Unhandled element type!"); #undef BIT_WIDTH_CASE } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp index a1e35b8..0fc5cc7 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp @@ -59,7 +59,7 @@ struct DemapInsRewriter : public OpRewritePattern<SourceOp> { // Flattens an affine expression into a list of AffineDimExprs. struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> { - explicit AffineDimCollector(unsigned dimNum) : dims(dimNum){}; + explicit AffineDimCollector(unsigned dimNum) : dims(dimNum) {}; void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); } BitVector dims; }; @@ -67,7 +67,7 @@ struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> { // Flattens an affine expression into a list of AffineDimExprs. struct AffineExprAdmissibleVisitor : public AffineExprVisitor<AffineExprAdmissibleVisitor> { - explicit AffineExprAdmissibleVisitor(bool isOutput) : isOutput(isOutput){}; + explicit AffineExprAdmissibleVisitor(bool isOutput) : isOutput(isOutput) {}; // We only allow AffineDimExpr on output. void visitAddExpr(AffineBinaryOpExpr expr) { @@ -407,7 +407,10 @@ public: }; struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> { - using OpRewritePattern::OpRewritePattern; + GenericOpScheduler(MLIRContext *context, + sparse_tensor::LoopOrderingStrategy strategy) + : OpRewritePattern<linalg::GenericOp>(context), strategy(strategy) {} + LogicalResult matchAndRewrite(linalg::GenericOp linalgOp, PatternRewriter &rewriter) const override { if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() || @@ -420,7 +423,8 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> { if (linalgOp->hasAttr(sorted)) return failure(); - auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp); + // Pass strategy to IterationGraphSorter. + auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp, strategy); bool isAdmissible = false; AffineMap order; // A const list of all masks that we used for iteration graph @@ -582,6 +586,9 @@ private: // TODO: convert more than one? return failure(); } + +private: + sparse_tensor::LoopOrderingStrategy strategy; }; //===----------------------------------------------------------------------===// @@ -786,12 +793,13 @@ struct ForeachOpDemapper } // namespace -void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns, - ReinterpretMapScope scope) { +void mlir::populateSparseReinterpretMap( + RewritePatternSet &patterns, ReinterpretMapScope scope, + sparse_tensor::LoopOrderingStrategy strategy) { if (scope == ReinterpretMapScope::kAll || scope == ReinterpretMapScope::kGenericOnly) { - patterns.add<GenericOpReinterpretMap, GenericOpScheduler>( - patterns.getContext()); + patterns.add<GenericOpReinterpretMap>(patterns.getContext()); + patterns.add<GenericOpScheduler>(patterns.getContext(), strategy); } if (scope == ReinterpretMapScope::kAll || scope == ReinterpretMapScope::kExceptGeneric) { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp index 153b9b1..b660e22 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -67,12 +67,13 @@ struct SparseReinterpretMap SparseReinterpretMap(const SparseReinterpretMap &pass) = default; SparseReinterpretMap(const SparseReinterpretMapOptions &options) { scope = options.scope; + loopOrderingStrategy = options.loopOrderingStrategy; } void runOnOperation() override { auto *ctx = &getContext(); RewritePatternSet patterns(ctx); - populateSparseReinterpretMap(patterns, scope); + populateSparseReinterpretMap(patterns, scope, loopOrderingStrategy); (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -438,6 +439,14 @@ mlir::createSparseReinterpretMapPass(ReinterpretMapScope scope) { return std::make_unique<SparseReinterpretMap>(options); } +std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass( + ReinterpretMapScope scope, sparse_tensor::LoopOrderingStrategy strategy) { + SparseReinterpretMapOptions options; + options.scope = scope; + options.loopOrderingStrategy = strategy; + return std::make_unique<SparseReinterpretMap>(options); +} + std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() { return std::make_unique<PreSparsificationRewritePass>(); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp index c7e463a..73e0f3d 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp @@ -100,7 +100,15 @@ AffineMap IterationGraphSorter::topoSort() { // We always prefer a parallel loop over a reduction loop because putting // a reduction loop early might make the loop sequence inadmissible. auto &it = !parIt.empty() ? parIt : redIt; - auto src = it.back(); + + // Select loop based on strategy. + unsigned src; + switch (strategy) { + case sparse_tensor::LoopOrderingStrategy::kDefault: + src = it.back(); + break; + } + loopOrder.push_back(src); it.pop_back(); // Update in-degree, and push 0-degree node into worklist. @@ -122,8 +130,8 @@ AffineMap IterationGraphSorter::topoSort() { return AffineMap(); } -IterationGraphSorter -IterationGraphSorter::fromGenericOp(linalg::GenericOp genericOp) { +IterationGraphSorter IterationGraphSorter::fromGenericOp( + linalg::GenericOp genericOp, sparse_tensor::LoopOrderingStrategy strategy) { // Must be a demapped sparse kernel. assert(!hasAnyNonIdentityOperandsOrResults(genericOp) && hasAnySparseOperandOrResult(genericOp) && @@ -140,14 +148,16 @@ IterationGraphSorter::fromGenericOp(linalg::GenericOp genericOp) { genericOp.getIteratorTypesArray(); return IterationGraphSorter(std::move(ins), std::move(loopMap), out, outMap, - std::move(iterTypes)); + std::move(iterTypes), strategy); } IterationGraphSorter::IterationGraphSorter( SmallVector<Value> &&ins, SmallVector<AffineMap> &&loop2InsLvl, Value out, - AffineMap loop2OutLvl, SmallVector<utils::IteratorType> &&iterTypes) + AffineMap loop2OutLvl, SmallVector<utils::IteratorType> &&iterTypes, + sparse_tensor::LoopOrderingStrategy strategy) : ins(std::move(ins)), loop2InsLvl(std::move(loop2InsLvl)), out(out), - loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)) { + loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)), + strategy(strategy) { // One map per tensor. assert(loop2InsLvl.size() == ins.size()); // All the affine maps have the same number of dimensions (loops). diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h index a6abe9e..b2a16e9 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_ITERATIONGRAPHSORTER_H_ #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_ITERATIONGRAPHSORTER_H_ +#include "mlir/Dialect/SparseTensor/Transforms/Passes.h" #include "mlir/IR/AffineMap.h" namespace mlir { @@ -41,9 +42,12 @@ enum class SortMask : unsigned { class IterationGraphSorter { public: - /// Factory method that construct an iteration graph sorter - /// for the given linalg.generic operation. - static IterationGraphSorter fromGenericOp(linalg::GenericOp genericOp); + /// Factory method that constructs an iteration graph sorter + /// for the given linalg.generic operation with a specific loop ordering + /// strategy. + static IterationGraphSorter + fromGenericOp(linalg::GenericOp genericOp, + sparse_tensor::LoopOrderingStrategy strategy); /// Returns a permutation that represents the scheduled loop order. /// Note that the returned AffineMap could be null if the kernel @@ -58,7 +62,9 @@ private: IterationGraphSorter(SmallVector<Value> &&ins, SmallVector<AffineMap> &&loop2InsLvl, Value out, AffineMap loop2OutLvl, - SmallVector<utils::IteratorType> &&iterTypes); + SmallVector<utils::IteratorType> &&iterTypes, + sparse_tensor::LoopOrderingStrategy strategy = + sparse_tensor::LoopOrderingStrategy::kDefault); // Adds all the constraints in the given loop to level map. void addConstraints(Value t, AffineMap loop2LvlMap); @@ -84,6 +90,9 @@ private: // InDegree used for topo sort. std::vector<unsigned> inDegree; + + // Loop ordering strategy. + sparse_tensor::LoopOrderingStrategy strategy; }; } // namespace sparse_tensor diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 332f1a0..c51b5e9 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -905,56 +905,29 @@ static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) { return shapeAdaptor.getNumElements() == 1 ? success() : failure(); } -// Returns the first declaration point prior to this operation or failure if -// not found. -static FailureOr<tosa::VariableOp> findVariableDecl(Operation *op, - StringRef symName) { - ModuleOp module = op->getParentOfType<ModuleOp>(); - tosa::VariableOp varOp = nullptr; - - // TODO: Adopt SymbolTable trait to Varible ops. - // Currently, the variable's definition point is searched via walk(), - // starting from the top-level ModuleOp and stopping at the point of use. Once - // TOSA control flow and variable extensions reach the complete state, may - // leverage MLIR's Symbol Table functionality to look up symbol and enhance - // the search to a TOSA specific graph traversal over the IR structure. - module.walk([&](Operation *tempOp) { - // Reach this op itself. - if (tempOp == op) { - return WalkResult::interrupt(); - } - - if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) { - if (symName == tosaOp.getName()) { - varOp = tosaOp; - return WalkResult::interrupt(); - } - } - - return WalkResult::advance(); - }); - - if (varOp) - return varOp; - - return failure(); -} - template <typename T> static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) { - StringRef symName = op.getName(); - FailureOr<tosa::VariableOp> varOp = findVariableDecl(op, symName); - if (failed(varOp)) + Operation *symTableOp = + op->template getParentWithTrait<OpTrait::SymbolTable>(); + if (!symTableOp) + // If the operation is not the scope of a symbol table, we cannot + // verify it against it's declaration. + return success(); + + SymbolTable symTable(symTableOp); + const auto varOp = symTable.lookup<tosa::VariableOp>(op.getName()); + + // Verify prior declaration + if (!varOp) return op->emitOpError("'") - << symName << "' has not been declared by 'tosa.variable'"; + << op.getName() << "' has not been declared by 'tosa.variable'"; // Verify type and shape - auto variableType = getVariableType(varOp.value()); + auto variableType = getVariableType(varOp); if (errorIfTypeOrShapeMismatch(op, type, name, variableType, "the input tensor") .failed()) return failure(); - return success(); } @@ -1418,7 +1391,7 @@ static void buildVariableOp(OpBuilder &builder, OperationState &result, ArrayRef<int64_t> shape = shapedType.getShape(); auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape)); - result.addAttribute("name", nameAttr); + result.addAttribute("sym_name", nameAttr); result.addAttribute("var_shape", varShapeAttr); result.addAttribute("type", elementTypeAttr); result.addAttribute("initial_value", initialValue); @@ -4160,16 +4133,6 @@ LogicalResult tosa::SelectOp::verify() { return success(); } -LogicalResult tosa::VariableOp::verify() { - StringRef symName = getName(); - FailureOr<tosa::VariableOp> varOp = findVariableDecl(*this, symName); - if (succeeded(varOp)) - return emitOpError("illegal to have multiple declaration of '") - << symName << "'"; - - return success(); -} - LogicalResult tosa::VariableReadOp::verify() { if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'") .failed()) diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp index a500228..45cef9c1 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/Transform/IR/Utils.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Verifier.h" #include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/TypeSwitch.h" @@ -140,6 +141,20 @@ LogicalResult transform::TransformDialect::verifyOperationAttribute( "operations with symbol tables"; } + // Pre-verify calls and callables because call graph construction below + // assumes they are valid, but this verifier runs before verifying the + // nested operations. + WalkResult walkResult = op->walk([](Operation *nested) { + if (!isa<CallableOpInterface, CallOpInterface>(nested)) + return WalkResult::advance(); + + if (failed(verify(nested, /*verifyRecursively=*/false))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + if (walkResult.wasInterrupted()) + return failure(); + const mlir::CallGraph callgraph(op); for (auto scc = llvm::scc_begin(&callgraph); !scc.isAtEnd(); ++scc) { if (!scc.hasCycle()) diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 3385b2a..365afab 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -2097,17 +2097,11 @@ void transform::IncludeOp::getEffects( getOperation(), getTarget()); if (!callee) return defaultEffects(); - DiagnosedSilenceableFailure earlyVerifierResult = - verifyNamedSequenceOp(callee, /*emitWarnings=*/false); - if (!earlyVerifierResult.succeeded()) { - (void)earlyVerifierResult.silence(); - return defaultEffects(); - } for (unsigned i = 0, e = getNumOperands(); i < e; ++i) { if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName)) consumesHandle(getOperation()->getOpOperand(i), effects); - else + else if (callee.getArgAttr(i, TransformDialect::kArgReadOnlyAttrName)) onlyReadsHandle(getOperation()->getOpOperand(i), effects); } } @@ -2597,10 +2591,7 @@ transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter, .Case([&](TransformParamTypeInterface param) { return llvm::range_size(state.getParams(getHandle())); }) - .Default([](Type) { - llvm_unreachable("unknown kind of transform dialect type"); - return 0; - }); + .DefaultUnreachable("unknown kind of transform dialect type"); results.setParams(cast<OpResult>(getNum()), rewriter.getI64IntegerAttr(numAssociations)); return DiagnosedSilenceableFailure::success(); @@ -2657,10 +2648,7 @@ transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter, .Case<TransformParamTypeInterface>([&](auto x) { return llvm::range_size(state.getParams(getHandle())); }) - .Default([](auto x) { - llvm_unreachable("unknown transform dialect type interface"); - return -1; - }); + .DefaultUnreachable("unknown transform dialect type interface"); auto produceNumOpsError = [&]() { return emitSilenceableError() diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp index 3b6330b..7823849 100644 --- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp +++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp @@ -364,10 +364,7 @@ static DataLayoutSpecInterface getSpec(Operation *operation) { return llvm::TypeSwitch<Operation *, DataLayoutSpecInterface>(operation) .Case<ModuleOp, DataLayoutOpInterface>( [&](auto op) { return op.getDataLayoutSpec(); }) - .Default([](Operation *) { - llvm_unreachable("expected an op with data layout spec"); - return DataLayoutSpecInterface(); - }); + .DefaultUnreachable("expected an op with data layout spec"); } static TargetSystemSpecInterface getTargetSystemSpec(Operation *operation) { diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp index 5cbea5d..33fbd2a 100644 --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -764,9 +764,7 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) { pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>( [&](auto interpOp) { this->generate(interpOp, writer); }) - .Default([](Operation *) { - llvm_unreachable("unknown `pdl_interp` operation"); - }); + .DefaultUnreachable("unknown `pdl_interp` operation"); } void Generator::generate(pdl_interp::ApplyConstraintOp op, @@ -913,9 +911,7 @@ void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) { .Case([](pdl::OperationType) { return OpCode::ExtractOp; }) .Case([](pdl::ValueType) { return OpCode::ExtractValue; }) .Case([](pdl::TypeType) { return OpCode::ExtractType; }) - .Default([](Type) -> OpCode { - llvm_unreachable("unsupported element type"); - }); + .DefaultUnreachable("unsupported element type"); writer.append(opCode, op.getRange(), op.getIndex(), op.getResult()); } void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) { diff --git a/mlir/lib/TableGen/Interfaces.cpp b/mlir/lib/TableGen/Interfaces.cpp index ec7adf3..b0ad3ee 100644 --- a/mlir/lib/TableGen/Interfaces.cpp +++ b/mlir/lib/TableGen/Interfaces.cpp @@ -25,7 +25,8 @@ using llvm::StringInit; // InterfaceMethod //===----------------------------------------------------------------------===// -InterfaceMethod::InterfaceMethod(const Record *def) : def(def) { +InterfaceMethod::InterfaceMethod(const Record *def, std::string uniqueName) + : def(def), uniqueName(uniqueName) { const DagInit *args = def->getValueAsDag("arguments"); for (unsigned i = 0, e = args->getNumArgs(); i != e; ++i) { arguments.push_back({cast<StringInit>(args->getArg(i))->getValue(), @@ -42,6 +43,9 @@ StringRef InterfaceMethod::getName() const { return def->getValueAsString("name"); } +// Return the name of this method. +StringRef InterfaceMethod::getUniqueName() const { return uniqueName; } + // Return if this method is static. bool InterfaceMethod::isStatic() const { return def->isSubClassOf("StaticInterfaceMethod"); @@ -83,8 +87,19 @@ Interface::Interface(const Record *def) : def(def) { // Initialize the interface methods. auto *listInit = dyn_cast<ListInit>(def->getValueInit("methods")); - for (const Init *init : listInit->getElements()) - methods.emplace_back(cast<DefInit>(init)->getDef()); + // In case of overloaded methods, we need to find a unique name for each for + // the internal function pointer in the "vtable" we generate. This is an + // internal name, we could use a randomly generated name as long as there are + // no collisions. + StringSet<> uniqueNames; + for (const Init *init : listInit->getElements()) { + std::string name = + cast<DefInit>(init)->getDef()->getValueAsString("name").str(); + while (!uniqueNames.insert(name).second) { + name = name + "_" + std::to_string(uniqueNames.size()); + } + methods.emplace_back(cast<DefInit>(init)->getDef(), name); + } // Initialize the interface base classes. auto *basesInit = dyn_cast<ListInit>(def->getValueInit("baseInterfaces")); diff --git a/mlir/lib/Target/LLVM/ModuleToObject.cpp b/mlir/lib/Target/LLVM/ModuleToObject.cpp index 5055cd9..4098ccc 100644 --- a/mlir/lib/Target/LLVM/ModuleToObject.cpp +++ b/mlir/lib/Target/LLVM/ModuleToObject.cpp @@ -56,8 +56,9 @@ ModuleToObject::getOrCreateTargetMachine() { return targetMachine.get(); // Load the target. std::string error; + llvm::Triple parsedTriple(triple); const llvm::Target *target = - llvm::TargetRegistry::lookupTarget(triple, error); + llvm::TargetRegistry::lookupTarget(parsedTriple, error); if (!target) { getOperation().emitError() << "Failed to lookup target for triple '" << triple << "' " << error; @@ -65,8 +66,8 @@ ModuleToObject::getOrCreateTargetMachine() { } // Create the target machine using the target. - targetMachine.reset(target->createTargetMachine(llvm::Triple(triple), chip, - features, {}, {})); + targetMachine.reset( + target->createTargetMachine(parsedTriple, chip, features, {}, {})); if (!targetMachine) return std::nullopt; return targetMachine.get(); diff --git a/mlir/lib/Target/LLVM/ROCDL/Target.cpp b/mlir/lib/Target/LLVM/ROCDL/Target.cpp index c9888c3..f813f8d 100644 --- a/mlir/lib/Target/LLVM/ROCDL/Target.cpp +++ b/mlir/lib/Target/LLVM/ROCDL/Target.cpp @@ -289,7 +289,7 @@ SerializeGPUModuleBase::assembleIsa(StringRef isa) { llvm::Triple triple(llvm::Triple::normalize(targetTriple)); std::string error; const llvm::Target *target = - llvm::TargetRegistry::lookupTarget(triple.normalize(), error); + llvm::TargetRegistry::lookupTarget(triple, error); if (!target) { emitError(loc, Twine("failed to lookup target: ") + error); return std::nullopt; diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 9fcb02e..1e2099d 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -4716,10 +4716,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, info.HasNoWait = updateDataOp.getNowait(); return success(); }) - .Default([&](Operation *op) { - llvm_unreachable("unexpected operation"); - return failure(); - }); + .DefaultUnreachable("unexpected operation"); if (failed(result)) return failure(); @@ -5312,9 +5309,7 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, (void)found; assert(found && "unsupported host_eval use"); }) - .Default([](Operation *) { - llvm_unreachable("unsupported host_eval use"); - }); + .DefaultUnreachable("unsupported host_eval use"); } } } diff --git a/mlir/lib/Target/LLVMIR/Transforms/TargetUtils.cpp b/mlir/lib/Target/LLVMIR/Transforms/TargetUtils.cpp index f1d3622..3f414b6 100644 --- a/mlir/lib/Target/LLVMIR/Transforms/TargetUtils.cpp +++ b/mlir/lib/Target/LLVMIR/Transforms/TargetUtils.cpp @@ -43,16 +43,17 @@ getTargetMachine(mlir::LLVM::TargetAttrInterface attr) { llvm::cast_if_present<LLVM::TargetFeaturesAttr>(attr.getFeatures()); std::string features = featuresAttr ? featuresAttr.getFeaturesString() : ""; + llvm::Triple parsedTriple(triple); std::string error; const llvm::Target *target = - llvm::TargetRegistry::lookupTarget(triple, error); + llvm::TargetRegistry::lookupTarget(parsedTriple, error); if (!target || !error.empty()) { LDBG() << "Looking up target '" << triple << "' failed: " << error << "\n"; return failure(); } - return std::unique_ptr<llvm::TargetMachine>(target->createTargetMachine( - llvm::Triple(triple), chipAKAcpu, features, {}, {})); + return std::unique_ptr<llvm::TargetMachine>( + target->createTargetMachine(parsedTriple, chipAKAcpu, features, {}, {})); } FailureOr<llvm::DataLayout> diff --git a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp index 4d20474..807a94c 100644 --- a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp +++ b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp @@ -74,9 +74,7 @@ public: LLVM::LLVMPointerType, LLVM::LLVMStructType, VectorType, LLVM::LLVMTargetExtType, PtrLikeTypeInterface>( [this](auto type) { return this->translate(type); }) - .Default([](Type t) -> llvm::Type * { - llvm_unreachable("unknown LLVM dialect type"); - }); + .DefaultUnreachable("unknown LLVM dialect type"); // Cache the result of the conversion and return. knownTranslations.try_emplace(type, translated); diff --git a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp index e2c987a..f49d3d0 100644 --- a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp +++ b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp @@ -154,7 +154,7 @@ void NodePrinter::print(Type type) { }) .Case([&](TypeType) { os << "Type"; }) .Case([&](ValueType) { os << "Value"; }) - .Default([](Type) { llvm_unreachable("unknown AST type"); }); + .DefaultUnreachable("unknown AST type"); } void NodePrinter::print(const Node *node) { @@ -182,7 +182,7 @@ void NodePrinter::print(const Node *node) { const VariableDecl, const Module>([&](auto derivedNode) { this->printImpl(derivedNode); }) - .Default([](const Node *) { llvm_unreachable("unknown AST node"); }); + .DefaultUnreachable("unknown AST node"); elementIndentStack.pop_back(); } diff --git a/mlir/lib/Tools/PDLL/AST/Nodes.cpp b/mlir/lib/Tools/PDLL/AST/Nodes.cpp index 159ce62..5aa0937 100644 --- a/mlir/lib/Tools/PDLL/AST/Nodes.cpp +++ b/mlir/lib/Tools/PDLL/AST/Nodes.cpp @@ -72,7 +72,7 @@ public: const Module>( [&](auto derivedNode) { this->visitImpl(derivedNode); }) - .Default([](const Node *) { llvm_unreachable("unknown AST node"); }); + .DefaultUnreachable("unknown AST node"); } private: |