diff options
author | Matthias Springer <mspringer@nvidia.com> | 2024-06-21 12:09:41 +0200 |
---|---|---|
committer | Matthias Springer <mspringer@nvidia.com> | 2024-06-21 12:09:41 +0200 |
commit | 5326bed0186de6a7690a19945da5d684e831de66 (patch) | |
tree | eb9d113bb9d01af20297f61921e898e64472817f | |
parent | 40800a6a661c5686b389b6dda9c1440c510b46cb (diff) | |
download | llvm-users/matthias-springer/tmp_dialect_conv_decouple.zip llvm-users/matthias-springer/tmp_dialect_conv_decouple.tar.gz llvm-users/matthias-springer/tmp_dialect_conv_decouple.tar.bz2 |
-rw-r--r-- | mlir/include/mlir/Transforms/DialectConversion.h | 26 | ||||
-rw-r--r-- | mlir/lib/Transforms/Utils/DialectConversion.cpp | 388 |
2 files changed, 105 insertions, 309 deletions
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 87b5dd9..369f61f 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -1070,6 +1070,30 @@ public: // ConversionConfig //===----------------------------------------------------------------------===// +/// The type of materialization. +enum MaterializationKind { + /// This materialization materializes a conversion for an illegal block + /// argument type, to a legal one. + Argument, + + /// This materialization materializes a conversion from an illegal type to a + /// legal one. + Target, + + /// This materialization materializes a conversion from a legal type back to + /// an illegal one. + Source +}; + +struct UnresolvedMaterialization { + UnresolvedMaterialization(UnrealizedConversionCastOp op, MaterializationKind kind, const TypeConverter *converter) + : op(op), kind(kind), converter(converter) {} + + UnrealizedConversionCastOp op; + MaterializationKind kind; + const TypeConverter *converter; +}; + /// Dialect conversion configuration. struct ConversionConfig { /// An optional callback used to notify about match failure diagnostics during @@ -1122,6 +1146,8 @@ struct ConversionConfig { // already been modified) and iterators into past IR state cannot be // represented at the moment. RewriterBase::Listener *listener = nullptr; + + SmallVector<UnresolvedMaterialization> *unresolvedMaterializations = nullptr; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 47e0338..614312c 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -672,21 +672,6 @@ public: void rollback() override; }; -/// The type of materialization. -enum MaterializationKind { - /// This materialization materializes a conversion for an illegal block - /// argument type, to a legal one. - Argument, - - /// This materialization materializes a conversion from an illegal type to a - /// legal one. - Target, - - /// This materialization materializes a conversion from a legal type back to - /// an illegal one. - Source -}; - /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast" /// op. Unresolved materializations are erased at the end of the dialect /// conversion. @@ -712,10 +697,10 @@ public: return cast<UnrealizedConversionCastOp>(op); } - void rollback() override; - void cleanup(RewriterBase &rewriter) override; + void rollback() override; + /// Return the type converter of this materialization (which may be null). const TypeConverter *getConverter() const { return converterAndKind.getPointer(); @@ -890,13 +875,22 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// no new IR is created between calls to `eraseOp`/`eraseBlock`. struct SingleEraseRewriter : public RewriterBase, RewriterBase::Listener { public: - SingleEraseRewriter(MLIRContext *context) - : RewriterBase(context, /*listener=*/this) {} + SingleEraseRewriter(MLIRContext *context, ConversionPatternRewriterImpl &impl) + : RewriterBase(context, /*listener=*/this), impl(impl) {} /// Erase the given op (unless it was already erased). void eraseOp(Operation *op) override { if (erased.contains(op)) return; + llvm::errs() << "ERASE OP: " << op << "\n"; + if (impl.config.unresolvedMaterializations) { + for (int i = 0; i < impl.config.unresolvedMaterializations->size(); ++i) { + if ((*impl.config.unresolvedMaterializations)[i].op == op) { + impl.config.unresolvedMaterializations->erase(impl.config.unresolvedMaterializations->begin() +i); + break; + } + } + } op->dropAllUses(); RewriterBase::eraseOp(op); } @@ -916,6 +910,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// Pointers to all erased operations and blocks. DenseSet<void *> erased; + + ConversionPatternRewriterImpl &impl; }; //===--------------------------------------------------------------------===// @@ -1019,6 +1015,7 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { }); // Notify the listener that the operation is about to be replaced. + llvm::errs() << "NOTIFY LISTENER OP REPLACED: " << op << "\n"; if (listener) listener->notifyOperationReplaced(op, replacements); @@ -1063,7 +1060,14 @@ void CreateOperationRewrite::rollback() { op->erase(); } +void UnresolvedMaterializationRewrite::cleanup(RewriterBase &rewriter) { + llvm::errs() << "add cast to vector: " << getOperation().getOperation() << "\n"; + if (rewriterImpl.config.unresolvedMaterializations) + rewriterImpl.config.unresolvedMaterializations->emplace_back(getOperation(), getMaterializationKind(), getConverter()); +} + void UnresolvedMaterializationRewrite::rollback() { + llvm::errs() << "UnresolvedMaterializationRewrite::rollback!\n"; if (getMaterializationKind() == MaterializationKind::Target) { for (Value input : op->getOperands()) rewriterImpl.mapping.erase(input); @@ -1071,10 +1075,6 @@ void UnresolvedMaterializationRewrite::rollback() { op->erase(); } -void UnresolvedMaterializationRewrite::cleanup(RewriterBase &rewriter) { - rewriter.eraseOp(op); -} - void ConversionPatternRewriterImpl::applyRewrites() { // Commit all rewrites. IRRewriter rewriter(context, config.listener); @@ -1082,7 +1082,7 @@ void ConversionPatternRewriterImpl::applyRewrites() { rewrite->commit(rewriter); // Clean up all rewrites. - SingleEraseRewriter eraseRewriter(context); + SingleEraseRewriter eraseRewriter(context, *this); for (auto &rewrite : rewrites) rewrite->cleanup(eraseRewriter); } @@ -2360,12 +2360,6 @@ private: legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl); - /// Legalize any unresolved type materializations. - LogicalResult legalizeUnresolvedMaterializations( - ConversionPatternRewriter &rewriter, - ConversionPatternRewriterImpl &rewriterImpl, - std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping); - /// Legalize an operation result that was marked as "erased". LogicalResult legalizeErasedResult(Operation *op, OpResult result, @@ -2468,9 +2462,7 @@ LogicalResult OperationConverter::finalize(ConversionPatternRewriter &rewriter) { std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping; ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); - if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)) || - failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl, - inverseMapping))) + if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl))) return failure(); // Process requested operation replacements. @@ -2552,278 +2544,6 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes( return success(); } -/// Replace the results of a materialization operation with the given values. -static void -replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl, - ResultRange matResults, ValueRange values, - DenseMap<Value, SmallVector<Value>> &inverseMapping) { - matResults.replaceAllUsesWith(values); - - // For each of the materialization results, update the inverse mappings to - // point to the replacement values. - for (auto [matResult, newValue] : llvm::zip(matResults, values)) { - auto inverseMapIt = inverseMapping.find(matResult); - if (inverseMapIt == inverseMapping.end()) - continue; - - // Update the reverse mapping, or remove the mapping if we couldn't update - // it. Not being able to update signals that the mapping would have become - // circular (i.e. %foo -> newValue -> %foo), which may occur as values are - // propagated through temporary materializations. We simply drop the - // mapping, and let the post-conversion replacement logic handle updating - // uses. - for (Value inverseMapVal : inverseMapIt->second) - if (!rewriterImpl.mapping.tryMap(inverseMapVal, newValue)) - rewriterImpl.mapping.erase(inverseMapVal); - } -} - -/// Compute all of the unresolved materializations that will persist beyond the -/// conversion process, and require inserting a proper user materialization for. -static void computeNecessaryMaterializations( - DenseMap<Operation *, UnresolvedMaterializationRewrite *> - &materializationOps, - ConversionPatternRewriter &rewriter, - ConversionPatternRewriterImpl &rewriterImpl, - DenseMap<Value, SmallVector<Value>> &inverseMapping, - SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) { - auto isLive = [&](Value value) { - auto findFn = [&](Operation *user) { - auto matIt = materializationOps.find(user); - if (matIt != materializationOps.end()) - return !necessaryMaterializations.count(matIt->second); - return rewriterImpl.isOpIgnored(user); - }; - // This value may be replacing another value that has a live user. - for (Value inv : inverseMapping.lookup(value)) - if (llvm::find_if_not(inv.getUsers(), findFn) != inv.user_end()) - return true; - // Or have live users itself. - return llvm::find_if_not(value.getUsers(), findFn) != value.user_end(); - }; - - llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue = - [&](Value invalidRoot, Value value, Type type) { - // Check to see if the input operation was remapped to a variant of the - // output. - Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type); - if (remappedValue.getType() == type && remappedValue != invalidRoot) - return remappedValue; - - // Check to see if the input is a materialization operation that - // provides an inverse conversion. We just check blindly for - // UnrealizedConversionCastOp here, but it has no effect on correctness. - auto inputCastOp = value.getDefiningOp<UnrealizedConversionCastOp>(); - if (inputCastOp && inputCastOp->getNumOperands() == 1) - return lookupRemappedValue(invalidRoot, inputCastOp->getOperand(0), - type); - - return Value(); - }; - - SetVector<UnresolvedMaterializationRewrite *> worklist; - for (auto &rewrite : rewriterImpl.rewrites) { - auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get()); - if (!mat) - continue; - materializationOps.try_emplace(mat->getOperation(), mat); - worklist.insert(mat); - } - while (!worklist.empty()) { - UnresolvedMaterializationRewrite *mat = worklist.pop_back_val(); - UnrealizedConversionCastOp op = mat->getOperation(); - - // We currently only handle target materializations here. - assert(op->getNumResults() == 1 && "unexpected materialization type"); - OpResult opResult = op->getOpResult(0); - Type outputType = opResult.getType(); - Operation::operand_range inputOperands = op.getOperands(); - - // Try to forward propagate operands for user conversion casts that result - // in the input types of the current cast. - for (Operation *user : llvm::make_early_inc_range(opResult.getUsers())) { - auto castOp = dyn_cast<UnrealizedConversionCastOp>(user); - if (!castOp) - continue; - if (castOp->getResultTypes() == inputOperands.getTypes()) { - replaceMaterialization(rewriterImpl, opResult, inputOperands, - inverseMapping); - necessaryMaterializations.remove(materializationOps.lookup(user)); - } - } - - // Try to avoid materializing a resolved materialization if possible. - // Handle the case of a 1-1 materialization. - if (inputOperands.size() == 1) { - // Check to see if the input operation was remapped to a variant of the - // output. - Value remappedValue = - lookupRemappedValue(opResult, inputOperands[0], outputType); - if (remappedValue && remappedValue != opResult) { - replaceMaterialization(rewriterImpl, opResult, remappedValue, - inverseMapping); - necessaryMaterializations.remove(mat); - continue; - } - } else { - // TODO: Avoid materializing other types of conversions here. - } - - // If the materialization does not have any live users, we don't need to - // generate a user materialization for it. - bool isMaterializationLive = isLive(opResult); - if (!isMaterializationLive) - continue; - if (!necessaryMaterializations.insert(mat)) - continue; - - // Reprocess input materializations to see if they have an updated status. - for (Value input : inputOperands) { - if (auto parentOp = input.getDefiningOp<UnrealizedConversionCastOp>()) { - if (auto *mat = materializationOps.lookup(parentOp)) - worklist.insert(mat); - } - } - } -} - -/// Legalize the given unresolved materialization. Returns success if the -/// materialization was legalized, failure otherise. -static LogicalResult legalizeUnresolvedMaterialization( - UnresolvedMaterializationRewrite &mat, - DenseMap<Operation *, UnresolvedMaterializationRewrite *> - &materializationOps, - ConversionPatternRewriter &rewriter, - ConversionPatternRewriterImpl &rewriterImpl, - DenseMap<Value, SmallVector<Value>> &inverseMapping) { - auto findLiveUser = [&](auto &&users) { - auto liveUserIt = llvm::find_if_not( - users, [&](Operation *user) { return rewriterImpl.isOpIgnored(user); }); - return liveUserIt == users.end() ? nullptr : *liveUserIt; - }; - - llvm::unique_function<Value(Value, Type)> lookupRemappedValue = - [&](Value value, Type type) { - // Check to see if the input operation was remapped to a variant of the - // output. - Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type); - if (remappedValue.getType() == type) - return remappedValue; - return Value(); - }; - - UnrealizedConversionCastOp op = mat.getOperation(); - if (!rewriterImpl.ignoredOps.insert(op)) - return success(); - - // We currently only handle target materializations here. - OpResult opResult = op->getOpResult(0); - Operation::operand_range inputOperands = op.getOperands(); - Type outputType = opResult.getType(); - - // If any input to this materialization is another materialization, resolve - // the input first. - for (Value value : op->getOperands()) { - auto valueCast = value.getDefiningOp<UnrealizedConversionCastOp>(); - if (!valueCast) - continue; - - auto matIt = materializationOps.find(valueCast); - if (matIt != materializationOps.end()) - if (failed(legalizeUnresolvedMaterialization( - *matIt->second, materializationOps, rewriter, rewriterImpl, - inverseMapping))) - return failure(); - } - - // Perform a last ditch attempt to avoid materializing a resolved - // materialization if possible. - // Handle the case of a 1-1 materialization. - if (inputOperands.size() == 1) { - // Check to see if the input operation was remapped to a variant of the - // output. - Value remappedValue = lookupRemappedValue(inputOperands[0], outputType); - if (remappedValue && remappedValue != opResult) { - replaceMaterialization(rewriterImpl, opResult, remappedValue, - inverseMapping); - return success(); - } - } else { - // TODO: Avoid materializing other types of conversions here. - } - - // Try to materialize the conversion. - if (const TypeConverter *converter = mat.getConverter()) { - rewriter.setInsertionPoint(op); - Value newMaterialization; - switch (mat.getMaterializationKind()) { - case MaterializationKind::Argument: - // Try to materialize an argument conversion. - // FIXME: The current argument materialization hook expects the original - // output type, even though it doesn't use that as the actual output type - // of the generated IR. The output type is just used as an indicator of - // the type of materialization to do. This behavior is really awkward in - // that it diverges from the behavior of the other hooks, and can be - // easily misunderstood. We should clean up the argument hooks to better - // represent the desired invariants we actually care about. - newMaterialization = converter->materializeArgumentConversion( - rewriter, op->getLoc(), mat.getOrigArgType(), inputOperands); - if (newMaterialization) - break; - - // If an argument materialization failed, fallback to trying a target - // materialization. - [[fallthrough]]; - case MaterializationKind::Target: - newMaterialization = converter->materializeTargetConversion( - rewriter, op->getLoc(), outputType, inputOperands); - break; - case MaterializationKind::Source: - newMaterialization = converter->materializeSourceConversion( - rewriter, op->getLoc(), outputType, inputOperands); - break; - } - if (newMaterialization) { - replaceMaterialization(rewriterImpl, opResult, newMaterialization, - inverseMapping); - return success(); - } - } - - InFlightDiagnostic diag = op->emitError() - << "failed to legalize unresolved materialization " - "from (" - << inputOperands.getTypes() << ") to " << outputType - << " that remained live after conversion"; - if (Operation *liveUser = findLiveUser(op->getUsers())) { - diag.attachNote(liveUser->getLoc()) - << "see existing live user here: " << *liveUser; - } - return failure(); -} - -LogicalResult OperationConverter::legalizeUnresolvedMaterializations( - ConversionPatternRewriter &rewriter, - ConversionPatternRewriterImpl &rewriterImpl, - std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping) { - inverseMapping = rewriterImpl.mapping.getInverse(); - - // As an initial step, compute all of the inserted materializations that we - // expect to persist beyond the conversion process. - DenseMap<Operation *, UnresolvedMaterializationRewrite *> materializationOps; - SetVector<UnresolvedMaterializationRewrite *> necessaryMaterializations; - computeNecessaryMaterializations(materializationOps, rewriter, rewriterImpl, - *inverseMapping, necessaryMaterializations); - - // Once computed, legalize any necessary materializations. - for (auto *mat : necessaryMaterializations) { - if (failed(legalizeUnresolvedMaterialization( - *mat, materializationOps, rewriter, rewriterImpl, *inverseMapping))) - return failure(); - } - return success(); -} - LogicalResult OperationConverter::legalizeErasedResult( Operation *op, OpResult result, ConversionPatternRewriterImpl &rewriterImpl) { @@ -2895,8 +2615,11 @@ LogicalResult OperationConverter::legalizeChangedResultType( // Materialize a conversion for this live result value. Type resultType = result.getType(); - Value convertedValue = replConverter->materializeSourceConversion( - rewriter, op->getLoc(), resultType, newValue); + Value convertedValue = rewriterImpl.buildUnresolvedMaterialization( + MaterializationKind::Source, computeInsertPoint(newValue), + op->getLoc(), /*inputs=*/newValue, /*outputType=*/resultType, + /*origArgType=*/{}, replConverter); + if (!convertedValue) return emitConversionError(); @@ -3429,15 +3152,56 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) { // Op Conversion Entry Points //===----------------------------------------------------------------------===// +static void resolveMaterialization(const UnresolvedMaterialization &mat) { + if (!mat.converter) + return; + IRRewriter rewriter(mat.op); + switch (mat.kind) { + case MaterializationKind::Argument: + { + assert(mat.op->getNumResults() == 1); + Value newMaterialization = mat.converter->materializeArgumentConversion( + rewriter, mat.op->getLoc(), mat.op->getResult(0).getType(), mat.op->getOperands()); + if(newMaterialization) + rewriter.replaceOp(mat.op, newMaterialization); + } + break; + case MaterializationKind::Target: + { + assert(mat.op->getNumResults() == 1); + Value newMaterialization = mat.converter->materializeTargetConversion( + rewriter, mat.op->getLoc(), mat.op->getResult(0).getType(), mat.op->getOperands()); + if(newMaterialization) + rewriter.replaceOp(mat.op, newMaterialization); + } + break; + case MaterializationKind::Source: + { + assert(mat.op->getNumResults() == 1); + Value newMaterialization = mat.converter->materializeArgumentConversion( + rewriter, mat.op->getLoc(), mat.op->getResult(0).getType(), mat.op->getOperands()); + if(newMaterialization) + rewriter.replaceOp(mat.op, newMaterialization); + } + break; + } +} + //===----------------------------------------------------------------------===// // Partial Conversion LogicalResult mlir::applyPartialConversion( ArrayRef<Operation *> ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config) { + SmallVector<UnresolvedMaterialization> unresolvedMaterializations; + config.unresolvedMaterializations = &unresolvedMaterializations; OperationConverter opConverter(target, patterns, config, OpConversionMode::Partial); - return opConverter.convertOperations(ops); + LogicalResult status = opConverter.convertOperations(ops); + if (failed(status)) return failure(); + for (auto mat : unresolvedMaterializations) + resolveMaterialization(mat); + return success(); } LogicalResult mlir::applyPartialConversion(Operation *op, const ConversionTarget &target, @@ -3453,9 +3217,15 @@ LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config) { + SmallVector<UnresolvedMaterialization> unresolvedMaterializations; + config.unresolvedMaterializations = &unresolvedMaterializations; OperationConverter opConverter(target, patterns, config, OpConversionMode::Full); - return opConverter.convertOperations(ops); + LogicalResult status = opConverter.convertOperations(ops); + if (failed(status)) return failure(); + for (auto mat : unresolvedMaterializations) + resolveMaterialization(mat); + return success(); } LogicalResult mlir::applyFullConversion(Operation *op, const ConversionTarget &target, |