diff options
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 26 | ||||
| -rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 82 | ||||
| -rw-r--r-- | mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 85 | ||||
| -rw-r--r-- | mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 81 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Utils/DialectConversion.cpp | 120 |
5 files changed, 301 insertions, 93 deletions
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index ec182f1..9348d3c1 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -865,13 +865,7 @@ struct NVGPUMBarrierArriveLowering adaptor.getMbarId(), rewriter); Type tokenType = getTypeConverter()->convertType( nvgpu::MBarrierTokenType::get(op->getContext())); - if (isMbarrierShared(op.getBarriers().getType())) { - rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveSharedOp>(op, tokenType, - barrier); - } else { - rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType, - barrier); - } + rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType, barrier); return success(); } }; @@ -892,13 +886,8 @@ struct NVGPUMBarrierArriveNoCompleteLowering Type tokenType = getTypeConverter()->convertType( nvgpu::MBarrierTokenType::get(op->getContext())); Value count = truncToI32(b, adaptor.getCount()); - if (isMbarrierShared(op.getBarriers().getType())) { - rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>( - op, tokenType, barrier, count); - } else { - rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>( - op, tokenType, barrier, count); - } + rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>( + op, tokenType, barrier, count); return success(); } }; @@ -915,13 +904,8 @@ struct NVGPUMBarrierTestWaitLowering getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), adaptor.getMbarId(), rewriter); Type retType = rewriter.getI1Type(); - if (isMbarrierShared(op.getBarriers().getType())) { - rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitSharedOp>( - op, retType, barrier, adaptor.getToken()); - } else { - rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>( - op, retType, barrier, adaptor.getToken()); - } + rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(op, retType, barrier, + adaptor.getToken()); return success(); } }; diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 262d9b7..d43f881 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -1752,15 +1752,21 @@ std::string NVVM::MBarrierInitOp::getPtx() { // getIntrinsicID/getIntrinsicIDAndArgs methods //===----------------------------------------------------------------------===// +static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) { + auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType()); + return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS); +} + +static bool isPtrInSharedCTASpace(mlir::Value ptr) { + return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared); +} + mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs( Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { auto thisOp = cast<NVVM::MBarrierInitOp>(op); - unsigned addressSpace = - llvm::cast<LLVM::LLVMPointerType>(thisOp.getAddr().getType()) - .getAddressSpace(); - llvm::Intrinsic::ID id = (addressSpace == NVVMMemorySpace::Shared) - ? llvm::Intrinsic::nvvm_mbarrier_init_shared - : llvm::Intrinsic::nvvm_mbarrier_init; + bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); + llvm::Intrinsic::ID id = isShared ? llvm::Intrinsic::nvvm_mbarrier_init_shared + : llvm::Intrinsic::nvvm_mbarrier_init; // Fill the Intrinsic Args llvm::SmallVector<llvm::Value *> args; @@ -1773,16 +1779,72 @@ mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs( mlir::NVVM::IDArgPair MBarrierInvalOp::getIntrinsicIDAndArgs( Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { auto thisOp = cast<NVVM::MBarrierInvalOp>(op); - unsigned addressSpace = - llvm::cast<LLVM::LLVMPointerType>(thisOp.getAddr().getType()) - .getAddressSpace(); - llvm::Intrinsic::ID id = (addressSpace == NVVMMemorySpace::Shared) + bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); + llvm::Intrinsic::ID id = isShared ? llvm::Intrinsic::nvvm_mbarrier_inval_shared : llvm::Intrinsic::nvvm_mbarrier_inval; return {id, {mt.lookupValue(thisOp.getAddr())}}; } +mlir::NVVM::IDArgPair MBarrierArriveOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierArriveOp>(op); + bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); + llvm::Intrinsic::ID id = isShared + ? llvm::Intrinsic::nvvm_mbarrier_arrive_shared + : llvm::Intrinsic::nvvm_mbarrier_arrive; + + return {id, {mt.lookupValue(thisOp.getAddr())}}; +} + +mlir::NVVM::IDArgPair MBarrierArriveNocompleteOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierArriveNocompleteOp>(op); + bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); + llvm::Intrinsic::ID id = + isShared ? llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete_shared + : llvm::Intrinsic::nvvm_mbarrier_arrive_noComplete; + // Fill the Intrinsic Args + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(thisOp.getAddr())); + args.push_back(mt.lookupValue(thisOp.getCount())); + + return {id, std::move(args)}; +} + +mlir::NVVM::IDArgPair MBarrierTestWaitOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op); + bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); + llvm::Intrinsic::ID id = isShared + ? llvm::Intrinsic::nvvm_mbarrier_test_wait_shared + : llvm::Intrinsic::nvvm_mbarrier_test_wait; + // Fill the Intrinsic Args + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(thisOp.getAddr())); + args.push_back(mt.lookupValue(thisOp.getState())); + + return {id, std::move(args)}; +} + +mlir::NVVM::IDArgPair CpAsyncMBarrierArriveOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::CpAsyncMBarrierArriveOp>(op); + bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); + + llvm::Intrinsic::ID id; + if (thisOp.getNoinc()) { + id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc_shared + : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_noinc; + } else { + id = isShared ? llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive_shared + : llvm::Intrinsic::nvvm_cp_async_mbarrier_arrive; + } + + return {id, {mt.lookupValue(thisOp.getAddr())}}; +} + #define CP_ASYNC_ID_IMPL(mod, size, suffix) \ llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 397107b..fb5d1e7 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -280,27 +280,82 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError, FailureOr<SmallVector<Value>> LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) { - // TODO: handle order attribute - auto hasDefaultOrder = [&]() { - DenseI32ArrayAttr order = getOrder(); - return !order || isIdentityPermutation(llvm::to_vector_of<int64_t>( - llvm::reverse(order.asArrayRef()))); - }; - if (!hasDefaultOrder()) - return mlir::emitError(loc, "order attribute is currently not supported."); - SmallVector<int64_t> layout; + SmallVector<int64_t> sgLayoutInt; if (isForWorkgroup()) { - layout = getEffectiveSgLayoutAsInt(); + sgLayoutInt = getEffectiveSgLayoutAsInt(); } else if (isForSubgroup()) { - layout = getEffectiveLaneLayoutAsInt(); + sgLayoutInt = getEffectiveLaneLayoutAsInt(); } else { return failure(); } - auto dims = llvm::map_to_vector(layout, [&](int64_t d) -> Value { - return builder.createOrFold<arith::ConstantIndexOp>(loc, d); - }); - return affine::delinearizeIndex(builder, loc, linearId, dims); + DenseI32ArrayAttr orderAttr = getOrder(); + + // Handle order attribute + SmallVector<int64_t> order; + if (orderAttr && !orderAttr.empty()) { + order = llvm::to_vector( + llvm::map_range(orderAttr.asArrayRef(), + [](int32_t idx) { return static_cast<int64_t>(idx); })); + } else { + // Default order: [1, 0] for 2D (row-major), [2, 1, 0] for 3D, etc. + order = llvm::to_vector( + llvm::reverse(llvm::seq<int64_t>(0, sgLayoutInt.size()))); + } + + if (order.size() != sgLayoutInt.size()) { + return failure(); + } + + SmallVector<Value> result(sgLayoutInt.size()); + Value remaining = linearId; + + /// Process dimensions in the order they appear in the order array + /// The first dimension in order is the fastest-changing + /// + /// Example walkthrough for linearId=22, sgLayout=[2,4,4], order=[2,1,0]: + /// + /// Initial: remaining=22, dimIdx = order[i], dimSize = sgLayout[dimIdx], + /// result=[?,?,?] + /// + /// i=0 (process columns, dimIdx=2, dimSize=4): + /// result[2] = 22 % 4 = 2 (column coordinate) + /// remaining = 22 / 4 = 5 (5 complete groups of 4 columns processed) + /// + /// i=1 (process rows, dimIdx=1, dimSize=4): + /// result[1] = 5 % 4 = 1 (row coordinate) + /// remaining = 5 / 4 = 1 (1 complete group of 4 rows processed) + /// + /// i=2 (process layers, dimIdx=0, dimSize=2): + /// result[0] = 1 % 2 = 1 (layer coordinate) + /// (no remaining update - last iteration) + /// + /// Final result: [1,1,2] = Layer 1, Row 1, Column 2 + for (size_t i = 0; i < order.size(); ++i) { + int64_t dimIdx = order[i]; + int64_t dimSize = sgLayoutInt[dimIdx]; + + Value dimSizeVal = + builder.createOrFold<arith::ConstantIndexOp>(loc, dimSize); + + /// Extract the coordinate for this dimension using modulo operation + /// This gives us "how far within this dimension" we are + /// e.g., linearId=22, dimSize=4: 22 % 4 = 2 (we're at position 2 within + /// this dimension) + result[dimIdx] = + builder.createOrFold<index::RemUOp>(loc, remaining, dimSizeVal); + + /// Update remaining for the next dimension by removing what we've already + /// processed. Division tells us "how many complete groups of this dimension + /// we've gone through" e.g., linearId=22, dimSize=4: 22 / 4 = 5 (we've + /// completed 5 groups of 4) Skip this for the last iteration since there's + /// no next dimension to process + if (i < order.size() - 1) { + remaining = + builder.createOrFold<index::DivUOp>(loc, remaining, dimSizeVal); + } + } + return result; } /// Implements DistributeLayoutAttr::computeDistributedCoords to generate diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index d12a04df..0a9ef0a 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -1219,6 +1219,70 @@ struct WgToSgMultiDimReductionOp } }; +// This pattern transforms vector.transpose ops to work at subgroup level. +struct WgToSgVectorTransposeOp + : public OpConversionPattern<vector::TransposeOp> { + using OpConversionPattern<vector::TransposeOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + VectorType resultType = op.getResultVectorType(); + + ArrayRef<int64_t> wgShape = resultType.getShape(); + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op.getResult()); + if (!layout || !layout.isForWorkgroup()) + return failure(); + + xegpu::DistributeLayoutAttr sourceLayout = + xegpu::getDistributeLayoutAttr(op.getVector()); + if (!sourceLayout || !sourceLayout.isForWorkgroup()) + return failure(); + + SmallVector<int64_t> sourceSgLayout = + sourceLayout.getEffectiveSgLayoutAsInt(); + SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt(); + DenseI32ArrayAttr sourceOrder = sourceLayout.getOrder(); + DenseI32ArrayAttr resultOrder = layout.getOrder(); + + if (!sourceOrder || !resultOrder) { + return rewriter.notifyMatchFailure( + op, "Both source and result must have order attributes"); + } + + ArrayRef<int64_t> permutation = op.getPermutation(); + size_t permutationSize = permutation.size(); + if (sourceSgLayout.size() != permutationSize || + resultSgLayout.size() != permutationSize) { + return rewriter.notifyMatchFailure( + op, "Layouts and permutation must have the same rank"); + } + + // Check that sgLayout, sgData & order are properly transposed for source + // and result + if (!layout.isTransposeOf(sourceLayout, permutation)) + return rewriter.notifyMatchFailure( + op, "Result layout is not a valid transpose of source layout " + "according to permutation"); + + SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; + VectorType newResultType = + VectorType::get(sgShape, resultType.getElementType()); + SmallVector<Value> newTransposeOps; + for (auto src : adaptor.getVector()) { + auto newTranspose = vector::TransposeOp::create( + rewriter, op.getLoc(), newResultType, src, permutation); + xegpu::setDistributeLayoutAttr(newTranspose->getResult(0), + layout.dropSgLayoutAndData()); + newTransposeOps.push_back(newTranspose.getResult()); + } + + rewriter.replaceOpWithMultiple(op, {newTransposeOps}); + return success(); + } +}; + } // namespace namespace mlir { @@ -1233,7 +1297,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) { WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset, WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp, - WgToSgMultiDimReductionOp>(patterns.getContext()); + WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp>( + patterns.getContext()); } } // namespace xegpu } // namespace mlir @@ -1360,7 +1425,9 @@ void XeGPUWgToSgDistributePass::runOnOperation() { return isLegal(layout); }); - target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp>( + target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp, + vector::TransposeOp, vector::BroadcastOp, + vector::MultiDimReductionOp>( [=](Operation *op) -> bool { // Check for either a SliceAttr or LayoutAttr on the result. auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0)); @@ -1379,16 +1446,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() { return isLegal(layout); }); - target.addDynamicallyLegalOp<vector::BroadcastOp>( - [=](vector::BroadcastOp op) -> bool { - return isLegal(xegpu::getDistributeLayoutAttr(op.getResult())); - }); - - target.addDynamicallyLegalOp<vector::MultiDimReductionOp>( - [=](vector::MultiDimReductionOp op) -> bool { - return isLegal(xegpu::getDistributeLayoutAttr(op.getResult())); - }); - target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>( [=](xegpu::ConvertLayoutOp op) -> bool { return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout()); diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 2fe0697..f8c38fa 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -92,6 +92,22 @@ static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) { return pt; } +namespace { +enum OpConversionMode { + /// In this mode, the conversion will ignore failed conversions to allow + /// illegal operations to co-exist in the IR. + Partial, + + /// In this mode, all operations must be legal for the given target for the + /// conversion to succeed. + Full, + + /// In this mode, operations are analyzed for legality. No actual rewrites are + /// applied to the operations on success. + Analysis, +}; +} // namespace + //===----------------------------------------------------------------------===// // ConversionValueMapping //===----------------------------------------------------------------------===// @@ -866,8 +882,9 @@ namespace mlir { namespace detail { struct ConversionPatternRewriterImpl : public RewriterBase::Listener { explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter, - const ConversionConfig &config) - : rewriter(rewriter), config(config), + const ConversionConfig &config, + OperationConverter &opConverter) + : rewriter(rewriter), config(config), opConverter(opConverter), notifyingRewriter(rewriter.getContext(), config.listener) {} //===--------------------------------------------------------------------===// @@ -1124,6 +1141,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// Dialect conversion configuration. const ConversionConfig &config; + /// The operation converter to use for recursive legalization. + OperationConverter &opConverter; + /// A set of erased operations. This set is utilized only if /// `allowPatternRollback` is set to "false". Conceptually, this set is /// similar to `replacedOps` (which is maintained when the flag is set to @@ -2084,9 +2104,10 @@ void ConversionPatternRewriterImpl::notifyMatchFailure( //===----------------------------------------------------------------------===// ConversionPatternRewriter::ConversionPatternRewriter( - MLIRContext *ctx, const ConversionConfig &config) - : PatternRewriter(ctx), - impl(new detail::ConversionPatternRewriterImpl(*this, config)) { + MLIRContext *ctx, const ConversionConfig &config, + OperationConverter &opConverter) + : PatternRewriter(ctx), impl(new detail::ConversionPatternRewriterImpl( + *this, config, opConverter)) { setListener(impl.get()); } @@ -2207,6 +2228,37 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys, return success(); } +LogicalResult ConversionPatternRewriter::legalize(Region *r) { + // Fast path: If the region is empty, there is nothing to legalize. + if (r->empty()) + return success(); + + // Gather a list of all operations to legalize. This is done before + // converting the entry block signature because unrealized_conversion_cast + // ops should not be included. + SmallVector<Operation *> ops; + for (Block &b : *r) + for (Operation &op : b) + ops.push_back(&op); + + // If the current pattern runs with a type converter, convert the entry block + // signature. + if (const TypeConverter *converter = impl->currentTypeConverter) { + std::optional<TypeConverter::SignatureConversion> conversion = + converter->convertBlockSignature(&r->front()); + if (!conversion) + return failure(); + applySignatureConversion(&r->front(), *conversion, converter); + } + + // Legalize all operations in the region. + for (Operation *op : ops) + if (failed(legalize(op))) + return failure(); + + return success(); +} + void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues) { @@ -3192,22 +3244,6 @@ static void reconcileUnrealizedCasts( // OperationConverter //===----------------------------------------------------------------------===// -namespace { -enum OpConversionMode { - /// In this mode, the conversion will ignore failed conversions to allow - /// illegal operations to co-exist in the IR. - Partial, - - /// In this mode, all operations must be legal for the given target for the - /// conversion to succeed. - Full, - - /// In this mode, operations are analyzed for legality. No actual rewrites are - /// applied to the operations on success. - Analysis, -}; -} // namespace - namespace mlir { // This class converts operations to a given conversion target via a set of // rewrite patterns. The conversion behaves differently depending on the @@ -3217,16 +3253,20 @@ struct OperationConverter { const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode) - : rewriter(ctx, config), opLegalizer(rewriter, target, patterns), + : rewriter(ctx, config, *this), opLegalizer(rewriter, target, patterns), mode(mode) {} /// Converts the given operations to the conversion target. LogicalResult convertOperations(ArrayRef<Operation *> ops); -private: - /// Converts an operation with the given rewriter. - LogicalResult convert(Operation *op); + /// Converts a single operation. If `isRecursiveLegalization` is "true", the + /// conversion is a recursive legalization request, triggered from within a + /// pattern. In that case, do not emit errors because there will be another + /// attempt at legalizing the operation later (via the regular pre-order + /// legalization mechanism). + LogicalResult convert(Operation *op, bool isRecursiveLegalization = false); +private: /// The rewriter to use when converting operations. ConversionPatternRewriter rewriter; @@ -3238,32 +3278,42 @@ private: }; } // namespace mlir -LogicalResult OperationConverter::convert(Operation *op) { +LogicalResult ConversionPatternRewriter::legalize(Operation *op) { + return impl->opConverter.convert(op, /*isRecursiveLegalization=*/true); +} + +LogicalResult OperationConverter::convert(Operation *op, + bool isRecursiveLegalization) { const ConversionConfig &config = rewriter.getConfig(); // Legalize the given operation. if (failed(opLegalizer.legalize(op))) { // Handle the case of a failed conversion for each of the different modes. // Full conversions expect all operations to be converted. - if (mode == OpConversionMode::Full) - return op->emitError() - << "failed to legalize operation '" << op->getName() << "'"; + if (mode == OpConversionMode::Full) { + if (!isRecursiveLegalization) + op->emitError() << "failed to legalize operation '" << op->getName() + << "'"; + return failure(); + } // Partial conversions allow conversions to fail iff the operation was not // explicitly marked as illegal. If the user provided a `unlegalizedOps` // set, non-legalizable ops are added to that set. if (mode == OpConversionMode::Partial) { - if (opLegalizer.isIllegal(op)) - return op->emitError() - << "failed to legalize operation '" << op->getName() - << "' that was explicitly marked illegal"; - if (config.unlegalizedOps) + if (opLegalizer.isIllegal(op)) { + if (!isRecursiveLegalization) + op->emitError() << "failed to legalize operation '" << op->getName() + << "' that was explicitly marked illegal"; + return failure(); + } + if (config.unlegalizedOps && !isRecursiveLegalization) config.unlegalizedOps->insert(op); } } else if (mode == OpConversionMode::Analysis) { // Analysis conversions don't fail if any operations fail to legalize, // they are only interested in the operations that were successfully // legalized. - if (config.legalizableOps) + if (config.legalizableOps && !isRecursiveLegalization) config.legalizableOps->insert(op); } return success(); |
