diff options
15 files changed, 379 insertions, 85 deletions
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 83198c9..76d5607 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -247,7 +247,8 @@ public: /// Attempts a 1-1 type conversion, expecting the result type to be /// `TargetType`. Returns the converted type cast to `TargetType` on success, /// and a null type on conversion or cast failure. - template <typename TargetType> TargetType convertType(Type t) const { + template <typename TargetType> + TargetType convertType(Type t) const { return dyn_cast_or_null<TargetType>(convertType(t)); } @@ -657,7 +658,7 @@ struct ConversionPatternRewriterImpl; /// This class implements a pattern rewriter for use with ConversionPatterns. It /// extends the base PatternRewriter and provides special conversion specific /// hooks. -class ConversionPatternRewriter final : public PatternRewriter { +class ConversionPatternRewriter : public PatternRewriter { public: ~ConversionPatternRewriter() override; @@ -708,8 +709,18 @@ public: /// Return the converted values that replace 'keys' with types defined by the /// type converter of the currently executing pattern. Returns failure if the /// remap failed, success otherwise. - LogicalResult getRemappedValues(ValueRange keys, - SmallVectorImpl<Value> &results); + LogicalResult getRemappedValues(ValueRange keys, SmallVector<Value> &results); + + virtual void setCurrentTypeConverter(const TypeConverter *converter); + + virtual const TypeConverter *getCurrentTypeConverter() const; + + /// Populate the operands that are used for constructing the adapter into + /// `remapped`. + virtual LogicalResult getAdapterOperands(StringRef valueDiagTag, + std::optional<Location> inputLoc, + ValueRange values, + SmallVector<Value> &remapped); //===--------------------------------------------------------------------===// // PatternRewriter Hooks @@ -755,6 +766,14 @@ public: /// Return a reference to the internal implementation. detail::ConversionPatternRewriterImpl &getImpl(); +protected: + /// Protected constructor for `OneShotConversionPatternRewriter`. Does not + /// initialize `impl`. + explicit ConversionPatternRewriter(MLIRContext *ctx); + + // Hide unsupported pattern rewriter API. + using OpBuilder::setListener; + private: // Allow OperationConverter to construct new rewriters. friend struct OperationConverter; @@ -765,9 +784,6 @@ private: explicit ConversionPatternRewriter(MLIRContext *ctx, const ConversionConfig &config); - // Hide unsupported pattern rewriter API. - using OpBuilder::setListener; - std::unique_ptr<detail::ConversionPatternRewriterImpl> impl; }; diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h index 763146a..7513427 100644 --- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -18,6 +18,8 @@ namespace mlir { +class ConversionTarget; + /// This enum controls which ops are put on the worklist during a greedy /// pattern rewrite. enum class GreedyRewriteStrictness { @@ -78,6 +80,8 @@ public: /// excluded. GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp; + bool enableOperationDce = true; + /// An optional listener that should be notified about IR modifications. RewriterBase::Listener *listener = nullptr; }; @@ -188,6 +192,10 @@ applyOpPatternsAndFold(ArrayRef<Operation *> ops, GreedyRewriteConfig config = GreedyRewriteConfig(), bool *changed = nullptr, bool *allErased = nullptr); +LogicalResult +applyPartialOneShotConversion(Operation *op, const ConversionTarget &target, + const FrozenRewritePatternSet &patterns); + } // namespace mlir #endif // MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_ diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp index 10ccd5c..f7e0ffe 100644 --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -23,6 +23,7 @@ #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" namespace mlir { @@ -563,8 +564,8 @@ class LowerAffinePass ConversionTarget target(getContext()); target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect, scf::SCFDialect, VectorDialect>(); - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) + if (failed(applyPartialOneShotConversion(getOperation(), target, + std::move(patterns)))) signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index d882f11..dcc3dd3 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include <type_traits> namespace mlir { @@ -479,8 +480,8 @@ struct ArithToLLVMConversionPass LLVMTypeConverter converter(&getContext(), options); mlir::arith::populateArithToLLVMConversionPatterns(converter, patterns); - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) + if (failed(applyPartialOneShotConversion(getOperation(), target, + std::move(patterns)))) signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index d8150ae..885bf7f 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -15,6 +15,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include <memory> #include <type_traits> @@ -1346,8 +1347,8 @@ void ConvertComplexToStandardPass::runOnOperation() { ConversionTarget target(getContext()); target.addLegalDialect<arith::ArithDialect, math::MathDialect>(); target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>(); - if (failed( - applyPartialConversion(getOperation(), target, std::move(patterns)))) + if (failed(applyPartialOneShotConversion(getOperation(), target, + std::move(patterns)))) signalPassFailure(); } } // namespace diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp index b8e5aec..0d81277 100644 --- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -25,6 +25,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/StringRef.h" #include <functional> @@ -240,8 +241,8 @@ struct ConvertControlFlowToLLVM LLVMTypeConverter converter(&getContext(), options); mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) + if (failed(applyPartialOneShotConversion(getOperation(), target, + std::move(patterns)))) signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp index 23e9572..89c5a98 100644 --- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp +++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { #define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS @@ -291,8 +292,8 @@ struct ConvertMathToLLVMPass LLVMTypeConverter converter(&getContext()); populateMathToLLVMConversionPatterns(converter, patterns, approximateLog1p); LLVMConversionTarget target(getContext()); - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) + if (failed(applyPartialOneShotConversion(getOperation(), target, + std::move(patterns)))) signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 11d2975..931cedc 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -25,6 +25,7 @@ #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" @@ -475,9 +476,10 @@ struct ConvertNVGPUToNVVMPass target.addLegalDialect<::mlir::NVVM::NVVMDialect>(); mlir::scf::populateSCFStructuralTypeConversionsAndLegality( converter, patterns, target); - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) + if (failed(applyPartialOneShotConversion(getOperation(), target, + std::move(patterns)))) signalPassFailure(); + // applyPartialConversion } }; diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index d407d60..8f0d560 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1633,6 +1633,9 @@ ConversionPatternRewriter::ConversionPatternRewriter( setListener(impl.get()); } +ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx) + : PatternRewriter(ctx), impl(nullptr) {} + ConversionPatternRewriter::~ConversionPatternRewriter() = default; void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) { @@ -1717,19 +1720,17 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, Value ConversionPatternRewriter::getRemappedValue(Value key) { SmallVector<Value> remappedValues; - if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key, - remappedValues))) + if (failed(getRemappedValues(key, remappedValues))) return nullptr; return remappedValues.front(); } LogicalResult ConversionPatternRewriter::getRemappedValues(ValueRange keys, - SmallVectorImpl<Value> &results) { + SmallVector<Value> &results) { if (keys.empty()) return success(); - return impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys, - results); + return getAdapterOperands("value", /*inputLoc=*/std::nullopt, keys, results); } void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, @@ -1819,6 +1820,22 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { return *impl; } +void ConversionPatternRewriter::setCurrentTypeConverter( + const TypeConverter *converter) { + impl->currentTypeConverter = converter; +} + +const TypeConverter * +ConversionPatternRewriter::getCurrentTypeConverter() const { + return impl->currentTypeConverter; +} + +LogicalResult ConversionPatternRewriter::getAdapterOperands( + StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values, + SmallVector<Value> &remapped) { + return impl->remapValues(valueDiagTag, inputLoc, *this, values, remapped); +} + //===----------------------------------------------------------------------===// // ConversionPattern //===----------------------------------------------------------------------===// @@ -1827,16 +1844,18 @@ LogicalResult ConversionPattern::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter); - auto &rewriterImpl = dialectRewriter.getImpl(); // Track the current conversion pattern type converter in the rewriter. - llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter, - getTypeConverter()); + const TypeConverter *currentTypeConverter = + dialectRewriter.getCurrentTypeConverter(); + auto resetTypeConverter = llvm::make_scope_exit( + [&] { dialectRewriter.setCurrentTypeConverter(currentTypeConverter); }); + dialectRewriter.setCurrentTypeConverter(getTypeConverter()); // Remap the operands of the operation. - SmallVector<Value, 4> operands; - if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter, - op->getOperands(), operands))) { + SmallVector<Value> operands; + if (failed(dialectRewriter.getAdapterOperands("operand", op->getLoc(), + op->getOperands(), operands))) { return failure(); } return matchAndRewrite(op, operands, dialectRewriter); diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 597cb29..99e8282 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -14,10 +14,12 @@ #include "mlir/Config/mlir-config.h" #include "mlir/IR/Action.h" +#include "mlir/IR/Iterators.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/Verifier.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Rewrite/PatternApplicator.h" +#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/BitVector.h" @@ -321,7 +323,7 @@ private: /// to the worklist in the beginning. class GreedyPatternRewriteDriver : public RewriterBase::Listener { protected: - explicit GreedyPatternRewriteDriver(MLIRContext *ctx, + explicit GreedyPatternRewriteDriver(PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns, const GreedyRewriteConfig &config); @@ -329,7 +331,7 @@ protected: void addSingleOpToWorklist(Operation *op); /// Add the given operation and its ancestors to the worklist. - void addToWorklist(Operation *op); + virtual void addToWorklist(Operation *op); /// Notify the driver that the specified operation may have been modified /// in-place. The operation is added to the worklist. @@ -356,7 +358,7 @@ protected: /// The pattern rewriter that is used for making IR modifications and is /// passed to rewrite patterns. - PatternRewriter rewriter; + PatternRewriter &rewriter; /// The worklist for this transformation keeps track of the operations that /// need to be (re)visited. @@ -375,6 +377,11 @@ protected: /// `config.strictMode` is GreedyRewriteStrictness::AnyOp. llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps; +#ifndef NDEBUG + /// A logger used to emit information during the application process. + llvm::ScopedPrinter logger{llvm::dbgs()}; +#endif + private: /// Look over the provided operands for any defining operations that should /// be re-added to the worklist. This function should be called when an @@ -394,11 +401,6 @@ private: notifyMatchFailure(Location loc, function_ref<void(Diagnostic &)> reasonCallback) override; -#ifndef NDEBUG - /// A logger used to emit information during the application process. - llvm::ScopedPrinter logger{llvm::dbgs()}; -#endif - /// The low-level pattern applicator. PatternApplicator matcher; @@ -409,9 +411,9 @@ private: } // namespace GreedyPatternRewriteDriver::GreedyPatternRewriteDriver( - MLIRContext *ctx, const FrozenRewritePatternSet &patterns, + PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns, const GreedyRewriteConfig &config) - : rewriter(ctx), config(config), matcher(patterns) + : rewriter(rewriter), config(config), matcher(patterns) #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS // clang-format off , expensiveChecks( @@ -476,7 +478,7 @@ bool GreedyPatternRewriteDriver::processWorklist() { }); // If the operation is trivially dead - remove it. - if (isOpTriviallyDead(op)) { + if (config.enableOperationDce && isOpTriviallyDead(op)) { rewriter.eraseOp(op); changed = true; @@ -780,7 +782,7 @@ namespace { /// This driver simplfies all ops in a region. class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver { public: - explicit RegionPatternRewriteDriver(MLIRContext *ctx, + explicit RegionPatternRewriteDriver(PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns, const GreedyRewriteConfig &config, Region ®ions); @@ -796,9 +798,9 @@ private: } // namespace RegionPatternRewriteDriver::RegionPatternRewriteDriver( - MLIRContext *ctx, const FrozenRewritePatternSet &patterns, + PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns, const GreedyRewriteConfig &config, Region ®ion) - : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) { + : GreedyPatternRewriteDriver(rewriter, patterns, config), region(region) { // Populate strict mode ops. if (config.strictMode != GreedyRewriteStrictness::AnyOp) { region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); }); @@ -909,8 +911,8 @@ mlir::applyPatternsAndFoldGreedily(Region ®ion, #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS // Start the pattern driver. - RegionPatternRewriteDriver driver(region.getContext(), patterns, config, - region); + PatternRewriter rewriter(region.getContext()); + RegionPatternRewriteDriver driver(rewriter, patterns, config, region); LogicalResult converged = std::move(driver).simplify(changed); LLVM_DEBUG(if (failed(converged)) { llvm::dbgs() << "The pattern rewrite did not converge after scanning " @@ -928,7 +930,7 @@ namespace { class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver { public: explicit MultiOpPatternRewriteDriver( - MLIRContext *ctx, const FrozenRewritePatternSet &patterns, + PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns, const GreedyRewriteConfig &config, ArrayRef<Operation *> ops, llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr); @@ -950,10 +952,10 @@ private: } // namespace MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver( - MLIRContext *ctx, const FrozenRewritePatternSet &patterns, + PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns, const GreedyRewriteConfig &config, ArrayRef<Operation *> ops, llvm::SmallDenseSet<Operation *, 4> *survivingOps) - : GreedyPatternRewriteDriver(ctx, patterns, config), + : GreedyPatternRewriteDriver(rewriter, patterns, config), survivingOps(survivingOps) { if (config.strictMode != GreedyRewriteStrictness::AnyOp) strictModeFilteredOps.insert(ops.begin(), ops.end()); @@ -1040,9 +1042,9 @@ LogicalResult mlir::applyOpPatternsAndFold( #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS // Start the pattern driver. + PatternRewriter rewriter(ops.front()->getContext()); llvm::SmallDenseSet<Operation *, 4> surviving; - MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns, - config, ops, + MultiOpPatternRewriteDriver driver(rewriter, patterns, config, ops, allErased ? &surviving : nullptr); LogicalResult converged = std::move(driver).simplify(ops, changed); if (allErased) @@ -1053,3 +1055,242 @@ LogicalResult mlir::applyOpPatternsAndFold( }); return converged; } + +//===----------------------------------------------------------------------===// +// One-Shot Dialect Conversion Infrastructure +//===----------------------------------------------------------------------===// + +namespace { +/// A conversion rewriter for the One-Shot Dialect Conversion. This rewriter +/// immediately materializes all IR changes. It derives from +/// `ConversionPatternRewriter` so that the existing conversion patterns can +/// be used with the One-Shot Dialect Conversion. +class OneShotConversionPatternRewriter : public ConversionPatternRewriter { +public: + OneShotConversionPatternRewriter(MLIRContext *ctx) + : ConversionPatternRewriter(ctx) {} + + bool canRecoverFromRewriteFailure() const override { return false; } + + void replaceOp(Operation *op, ValueRange newValues) override; + + void replaceOp(Operation *op, Operation *newOp) override { + replaceOp(op, newOp->getResults()); + } + + void eraseOp(Operation *op) override { PatternRewriter::eraseOp(op); } + + void eraseBlock(Block *block) override { PatternRewriter::eraseBlock(block); } + + void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, + ValueRange argValues = std::nullopt) override { + PatternRewriter::inlineBlockBefore(source, dest, before, argValues); + } + using PatternRewriter::inlineBlockBefore; + + void startOpModification(Operation *op) override { + PatternRewriter::startOpModification(op); + } + + void finalizeOpModification(Operation *op) override { + PatternRewriter::finalizeOpModification(op); + } + + void cancelOpModification(Operation *op) override { + PatternRewriter::cancelOpModification(op); + } + + void setCurrentTypeConverter(const TypeConverter *converter) override { + typeConverter = converter; + } + + const TypeConverter *getCurrentTypeConverter() const override { + return typeConverter; + } + + LogicalResult getAdapterOperands(StringRef valueDiagTag, + std::optional<Location> inputLoc, + ValueRange values, + SmallVector<Value> &remapped) override; + +private: + /// Build an unrealized_conversion_cast op or look it up in the cache. + Value buildUnrealizedConversionCast(Location loc, Type type, Value value); + + /// The current type converter. + const TypeConverter *typeConverter; + + /// A cache for unrealized_conversion_casts. To ensure that identical casts + /// are not built multiple times. + DenseMap<std::pair<Value, Type>, Value> castCache; +}; + +void OneShotConversionPatternRewriter::replaceOp(Operation *op, + ValueRange newValues) { + assert(op->getNumResults() == newValues.size()); + for (auto [orig, repl] : llvm::zip_equal(op->getResults(), newValues)) { + if (orig.getType() != repl.getType()) { + // Type mismatch: insert unrealized_conversion cast. + replaceAllUsesWith(orig, buildUnrealizedConversionCast( + op->getLoc(), orig.getType(), repl)); + } else { + // Same type: use replacement value directly. + replaceAllUsesWith(orig, repl); + } + } + eraseOp(op); +} + +Value OneShotConversionPatternRewriter::buildUnrealizedConversionCast( + Location loc, Type type, Value value) { + auto it = castCache.find(std::make_pair(value, type)); + if (it != castCache.end()) + return it->second; + + // Insert cast at the beginning of the block (for block arguments) or right + // after the defining op. + OpBuilder::InsertionGuard g(*this); + Block *insertBlock = value.getParentBlock(); + Block::iterator insertPt = insertBlock->begin(); + if (OpResult inputRes = dyn_cast<OpResult>(value)) + insertPt = ++inputRes.getOwner()->getIterator(); + setInsertionPoint(insertBlock, insertPt); + auto castOp = create<UnrealizedConversionCastOp>(loc, type, value); + castCache[std::make_pair(value, type)] = castOp.getOutputs()[0]; + return castOp.getOutputs()[0]; +} + +class ConversionPatternRewriteDriver : public GreedyPatternRewriteDriver { +public: + ConversionPatternRewriteDriver(PatternRewriter &rewriter, + const FrozenRewritePatternSet &patterns, + const GreedyRewriteConfig &config, + const ConversionTarget &target) + : GreedyPatternRewriteDriver(rewriter, patterns, config), target(target) { + } + + /// Populate the worklist with all illegal ops and start the conversion + /// process. + LogicalResult convert(Operation *op) &&; + +protected: + void addToWorklist(Operation *op) override; + + /// Notify the driver that the specified operation was removed. Update the + /// worklist as needed: The operation and its children are removed from the + /// worklist. + void notifyOperationErased(Operation *op) override; + +private: + const ConversionTarget ⌖ +}; +} // namespace + +LogicalResult ConversionPatternRewriteDriver::convert(Operation *op) && { + op->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>([&](Operation *op) { + auto legalityInfo = target.isLegal(op); + if (!legalityInfo) { + addSingleOpToWorklist(op); + return WalkResult::advance(); + } + if (legalityInfo->isRecursivelyLegal) { + // Don't check this operation's children for conversion if the + // operation is recursively legal. + return WalkResult::skip(); + } + return WalkResult::advance(); + }); + + // Reverse the list so our pop-back loop processes them in-order. + // TODO: newly enqueued ops must also be reversed + worklist.reverse(); + + processWorklist(); + + return success(); +} + +void ConversionPatternRewriteDriver::addToWorklist(Operation *op) { + if (!target.isLegal(op)) + addSingleOpToWorklist(op); +} + +// TODO: Refactor. This is the same as +// `GreedyPatternRewriteDriver::notifyOperationErased`, but does not add ops to +// the worklist. +void ConversionPatternRewriteDriver::notifyOperationErased(Operation *op) { + LLVM_DEBUG({ + logger.startLine() << "** Erase : '" << op->getName() << "'(" << op + << ")\n"; + }); + +#ifndef NDEBUG + // Only ops that are within the configured scope are added to the worklist of + // the greedy pattern rewriter. Moreover, the parent op of the scope region is + // the part of the IR that is taken into account for the "expensive checks". + // A greedy pattern rewrite is not allowed to erase the parent op of the scope + // region, as that would break the worklist handling and the expensive checks. + if (config.scope && config.scope->getParentOp() == op) + llvm_unreachable( + "scope region must not be erased during greedy pattern rewrite"); +#endif // NDEBUG + + if (config.listener) + config.listener->notifyOperationErased(op); + + worklist.remove(op); + + if (config.strictMode != GreedyRewriteStrictness::AnyOp) + strictModeFilteredOps.erase(op); +} + +/// Populate the converted operands in `remapped`. (Based on the currently set +/// type converter.) +LogicalResult OneShotConversionPatternRewriter::getAdapterOperands( + StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values, + SmallVector<Value> &remapped) { + // TODO: Refactor. This is mostly copied from the current dialect conversion. + for (Value v : values) { + // Skip all unrealized_conversion_casts in the chain of defining ops. + Value vBase = v; + while (auto castOp = vBase.getDefiningOp<UnrealizedConversionCastOp>()) + vBase = castOp.getInputs()[0]; + + if (!getCurrentTypeConverter()) { + // No type converter set. Just replicate what the current type conversion + // is doing. + // TODO: We may have to distinguish between newly-inserted an + // pre-existing unrealized_conversion_casts. + remapped.push_back(vBase); + continue; + } + + Type desiredType; + SmallVector<Type, 1> legalTypes; + if (failed(getCurrentTypeConverter()->convertType(v.getType(), legalTypes))) + return failure(); + assert(legalTypes.size() == 1 && "1:N conversion not supported yet"); + desiredType = legalTypes.front(); + if (desiredType == vBase.getType()) { + // Type already matches. No need to convert anything. + remapped.push_back(vBase); + continue; + } + + Location operandLoc = inputLoc ? *inputLoc : v.getLoc(); + remapped.push_back( + buildUnrealizedConversionCast(operandLoc, desiredType, vBase)); + } + return success(); +} + +LogicalResult +mlir::applyPartialOneShotConversion(Operation *op, + const ConversionTarget &target, + const FrozenRewritePatternSet &patterns) { + GreedyRewriteConfig config; + config.enableOperationDce = false; + OneShotConversionPatternRewriter rewriter(op->getContext()); + ConversionPatternRewriteDriver driver(rewriter, patterns, config, target); + return std::move(driver).convert(op); +} diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir index 23e0edd..865840a 100644 --- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir +++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir @@ -430,7 +430,7 @@ func.func @min_reduction_tree(%v1 : index, %v2 : index, %v3 : index, %v4 : index #map6 = affine_map<(d0,d1,d2) -> (d0 + d1 + d2)> // CHECK-LABEL: func @affine_applies( -func.func @affine_applies(%arg0 : index) { +func.func @affine_applies(%arg0 : index) -> (index, index, index, index, index) { // CHECK: %[[c0:.*]] = arith.constant 0 : index %zero = affine.apply #map0() @@ -442,9 +442,7 @@ func.func @affine_applies(%arg0 : index) { %102 = arith.constant 102 : index %copy = affine.apply #map2(%zero) -// CHECK-NEXT: %[[v0:.*]] = arith.addi %[[c0]], %[[c0]] : index // CHECK-NEXT: %[[c1:.*]] = arith.constant 1 : index -// CHECK-NEXT: %[[v1:.*]] = arith.addi %[[v0]], %[[c1]] : index %one = affine.apply #map3(%symbZero)[%zero] // CHECK-NEXT: %[[c2:.*]] = arith.constant 2 : index @@ -466,7 +464,9 @@ func.func @affine_applies(%arg0 : index) { // CHECK-NEXT: %[[v12:.*]] = arith.muli %arg0, %[[c7]] : index // CHECK-NEXT: %[[v13:.*]] = arith.addi %[[v11]], %[[v12]] : index %four = affine.apply #map4(%arg0, %arg0, %arg0, %arg0)[%arg0, %arg0, %arg0] - return + +// CHECK: return %[[c0]], %[[c0]], %[[c0]], %[[c1]], %[[v13]] + return %zero, %symbZero, %copy, %one, %four : index, index, index, index, index } // CHECK-LABEL: func @args_ret_affine_apply( diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir index 56ae930..2360d6e 100644 --- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir +++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir @@ -321,9 +321,9 @@ func.func @integer_cast_0d_vector(%arg0 : vector<i3>) { // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast // CHECK-NEXT: = llvm.sext %[[ARG0]] : vector<1xi3> to vector<1xi6> %0 = arith.extsi %arg0 : vector<i3> to vector<i6> -// CHECK-NEXT: = llvm.zext %[[ARG0]] : vector<1xi3> to vector<1xi6> +// CHECK: = llvm.zext %[[ARG0]] : vector<1xi3> to vector<1xi6> %1 = arith.extui %arg0 : vector<i3> to vector<i6> -// CHECK-NEXT: = llvm.trunc %[[ARG0]] : vector<1xi3> to vector<1xi2> +// CHECK: = llvm.trunc %[[ARG0]] : vector<1xi3> to vector<1xi2> %2 = arith.trunci %arg0 : vector<i3> to vector<i2> return } @@ -478,11 +478,12 @@ func.func @mului_extended_vector1d(%arg0: vector<3xi64>, %arg1: vector<3xi64>) - // ----- // CHECK-LABEL: func @cmpf_2dvector( +// CHECK-SAME: %[[FARG0:.*]]: vector<4x3xf32>, %[[FARG1:.*]]: vector<4x3xf32> func.func @cmpf_2dvector(%arg0 : vector<4x3xf32>, %arg1 : vector<4x3xf32>) { - // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast - // CHECK: %[[ARG1:.*]] = builtin.unrealized_conversion_cast - // CHECK: %[[EXTRACT1:.*]] = llvm.extractvalue %[[ARG0]][0] : !llvm.array<4 x vector<3xf32>> - // CHECK: %[[EXTRACT2:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.array<4 x vector<3xf32>> + // CHECK-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[FARG0]] + // CHECK-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[FARG1]] + // CHECK-DAG: %[[EXTRACT1:.*]] = llvm.extractvalue %[[ARG0]][0] : !llvm.array<4 x vector<3xf32>> + // CHECK-DAG: %[[EXTRACT2:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.array<4 x vector<3xf32>> // CHECK: %[[CMP:.*]] = llvm.fcmp "olt" %[[EXTRACT1]], %[[EXTRACT2]] : vector<3xf32> // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[CMP]], %2[0] : !llvm.array<4 x vector<3xi1>> %0 = arith.cmpf olt, %arg0, %arg1 : vector<4x3xf32> @@ -492,9 +493,10 @@ func.func @cmpf_2dvector(%arg0 : vector<4x3xf32>, %arg1 : vector<4x3xf32>) { // ----- // CHECK-LABEL: func @cmpi_0dvector( +// CHECK-SAME: %[[FARG0:.*]]: vector<i32>, %[[FARG1:.*]]: vector<i32> func.func @cmpi_0dvector(%arg0 : vector<i32>, %arg1 : vector<i32>) { - // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast - // CHECK: %[[ARG1:.*]] = builtin.unrealized_conversion_cast + // CHECK-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[FARG0]] + // CHECK-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[FARG1]] // CHECK: %[[CMP:.*]] = llvm.icmp "ult" %[[ARG0]], %[[ARG1]] : vector<1xi32> %0 = arith.cmpi ult, %arg0, %arg1 : vector<i32> func.return @@ -503,11 +505,12 @@ func.func @cmpi_0dvector(%arg0 : vector<i32>, %arg1 : vector<i32>) { // ----- // CHECK-LABEL: func @cmpi_2dvector( +// CHECK-SAME: %[[FARG0:.*]]: vector<4x3xi32>, %[[FARG1:.*]]: vector<4x3xi32> func.func @cmpi_2dvector(%arg0 : vector<4x3xi32>, %arg1 : vector<4x3xi32>) { - // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast - // CHECK: %[[ARG1:.*]] = builtin.unrealized_conversion_cast - // CHECK: %[[EXTRACT1:.*]] = llvm.extractvalue %[[ARG0]][0] : !llvm.array<4 x vector<3xi32>> - // CHECK: %[[EXTRACT2:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.array<4 x vector<3xi32>> + // CHECK-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[FARG0]] + // CHECK-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[FARG1]] + // CHECK-DAG: %[[EXTRACT1:.*]] = llvm.extractvalue %[[ARG0]][0] : !llvm.array<4 x vector<3xi32>> + // CHECK-DAG: %[[EXTRACT2:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.array<4 x vector<3xi32>> // CHECK: %[[CMP:.*]] = llvm.icmp "ult" %[[EXTRACT1]], %[[EXTRACT2]] : vector<3xi32> // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[CMP]], %2[0] : !llvm.array<4 x vector<3xi1>> %0 = arith.cmpi ult, %arg0, %arg1 : vector<4x3xi32> diff --git a/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir b/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir index 6398934..b234cbb 100644 --- a/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir +++ b/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir @@ -199,9 +199,9 @@ func.func @bitcast_2d(%arg0: vector<2x4xf32>) { // CHECK-LABEL: func @select_2d( func.func @select_2d(%arg0 : vector<4x3xi1>, %arg1 : vector<4x3xi32>, %arg2 : vector<4x3xi32>) { - // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %arg0 - // CHECK: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %arg1 - // CHECK: %[[ARG2:.*]] = builtin.unrealized_conversion_cast %arg2 + // CHECK-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %arg0 + // CHECK-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %arg1 + // CHECK-DAG: %[[ARG2:.*]] = builtin.unrealized_conversion_cast %arg2 // CHECK: %[[EXTRACT1:.*]] = llvm.extractvalue %[[ARG0]][0] : !llvm.array<4 x vector<3xi1>> // CHECK: %[[EXTRACT2:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.array<4 x vector<3xi32>> // CHECK: %[[EXTRACT3:.*]] = llvm.extractvalue %[[ARG2]][0] : !llvm.array<4 x vector<3xi32>> diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir index b861034..599bb6c 100644 --- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir @@ -294,6 +294,7 @@ func.func @subview_leading_operands(%0 : memref<5x3xf32>, %1: memref<5x?xf32>) - // CHECK: %[[MEM:[a-zA-Z0-9]*]]: memref func.func @subview_leading_operands_dynamic(%0 : memref<5x?xf32>) -> memref<3x?xf32, strided<[?, 1], offset: ?>> { // CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]] + // CHECK: %[[CST_STRIDE1:.*]] = llvm.mlir.constant(1 : index) : i64 // CHECK: %[[SIZE1:.*]] = llvm.extractvalue %[[MEMREF]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64 // CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64 @@ -316,7 +317,6 @@ func.func @subview_leading_operands_dynamic(%0 : memref<5x?xf32>) -> memref<3x?x // CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC3]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // Sizes and strides @rank 1: static stride 1, dynamic size unchanged from source memref. // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[SIZE1]], %[[DESC4]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: %[[CST_STRIDE1:.*]] = llvm.mlir.constant(1 : index) : i64 // CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[CST_STRIDE1]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %c0 = arith.constant 1 : index @@ -425,8 +425,6 @@ func.func @collapse_shape_dynamic_with_non_identity_layout( // CHECK: %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK: %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -// CHECK: %[[STRIDE0_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[STRIDE0]] : i64 to index -// CHECK: %[[STRIDE0:.*]] = builtin.unrealized_conversion_cast %[[STRIDE0_TO_IDX]] : index to i64 // CHECK: %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE1]], %[[SIZE2]] : i64 // CHECK: %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index // CHECK: %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64 diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir index 1b046d3..57f8053 100644 --- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -275,8 +275,8 @@ func.func @async_cp_i4( // CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index func.func @async_cp_zfill_f32_align4( %src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) { - // CHECK: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64 - // CHECK: %[[SRC1:.*]] = builtin.unrealized_conversion_cast %[[SRCELEMENTS]] : index to i64 + // CHECK-DAG: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64 + // CHECK-DAG: %[[SRC1:.*]] = builtin.unrealized_conversion_cast %[[SRCELEMENTS]] : index to i64 // CHECK-DAG: %[[BASEDST:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<3 x i64>, array<3 x i64>)> // CHECK-DAG: %[[S2048:.*]] = llvm.mlir.constant(2048 : index) : i64 // CHECK-DAG: %[[LI1:.*]] = llvm.mul %[[IDX1]], %[[S2048]] : i64 @@ -310,8 +310,8 @@ func.func @async_cp_zfill_f32_align4( // CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index) func.func @async_cp_zfill_f32_align1( %src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) { - // CHECK: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64 - // CHECK: %[[SRC1:.*]] = builtin.unrealized_conversion_cast %[[SRCELEMENTS]] : index to i64 + // CHECK-DAG: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64 + // CHECK-DAG: %[[SRC1:.*]] = builtin.unrealized_conversion_cast %[[SRCELEMENTS]] : index to i64 // CHECK-DAG: %[[BASEDST:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<3 x i64>, array<3 x i64>)> // CHECK-DAG: %[[S2048:.*]] = llvm.mlir.constant(2048 : index) : i64 // CHECK-DAG: %[[LI1:.*]] = llvm.mul %[[IDX1]], %[[S2048]] : i64 @@ -533,7 +533,9 @@ func.func @mbarrier_nocomplete() { } // CHECK-LABEL: func @mbarrier_wait +// CHECK-SAME: %[[barriers:.*]]: !nvgpu.mbarrier.group func.func @mbarrier_wait(%barriers : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>, num_barriers = 5>, %token : !tokenType) { +// CHECK: %[[barriersCast:.*]] = builtin.unrealized_conversion_cast %[[barriers]] %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %n = arith.constant 100 : index @@ -545,7 +547,7 @@ func.func @mbarrier_wait(%barriers : !nvgpu.mbarrier.group<memorySpace = #gpu.ad // CHECK: scf.for %[[i:.*]] = // CHECK: %[[S2:.+]] = arith.remui %[[i]], %[[c5]] : index // CHECK: %[[S3:.+]] = builtin.unrealized_conversion_cast %[[S2]] : index to i64 -// CHECK: %[[S4:.+]] = llvm.extractvalue %0[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> +// CHECK: %[[S4:.+]] = llvm.extractvalue %[[barriersCast]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[S5:.+]] = llvm.getelementptr %[[S4]][%[[S3]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 %mbarId = arith.remui %i, %numBarriers : index %isDone = nvgpu.mbarrier.test.wait %barriers[%mbarId], %token : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>, num_barriers = 5>, !tokenType @@ -871,9 +873,9 @@ func.func @warpgroup_mma_128_128_64( %descB: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, %acc: !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>) { -// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64 -// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64 -// CHECK: %[[ARG:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>> to !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> +// CHECK-DAG: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64 +// CHECK-DAG: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64 +// CHECK-DAG: %[[ARG:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>> to !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> // CHECK: nvvm.wgmma.fence.aligned // CHECK: %[[UD:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> // CHECK: %[[S2:.+]] = llvm.extractvalue %[[ARG]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> @@ -1280,9 +1282,9 @@ func.func @warpgroup_matrix_multiply_m128n128k64( to memref<128x128xf32,3> -// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64 -// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64 -// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : memref<128x128xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-DAG: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64 +// CHECK-DAG: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64 +// CHECK-DAG: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : memref<128x128xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[S3:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 // CHECK: %[[S4:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> // CHECK: %[[S5:.+]] = llvm.extractvalue %[[S4]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> @@ -1296,7 +1298,7 @@ func.func @warpgroup_matrix_multiply_m128n128k64( // CHECK: nvvm.wgmma.fence.aligned // CHECK: %[[S137:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> // CHECK: %[[S138:.+]] = llvm.extractvalue %136[0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> -// CHECK: %[[S139:.+]] = nvvm.wgmma.mma_async %0, %1, %[[S138]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <row>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> +// CHECK: %[[S139:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], %[[S138]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <row>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> // CHECK: nvvm.wgmma.mma_async // CHECK: nvvm.wgmma.mma_async // CHECK: %[[S154:.+]] = nvvm.wgmma.mma_async |