aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthias Springer <mspringer@nvidia.com>2024-06-21 12:09:41 +0200
committerMatthias Springer <mspringer@nvidia.com>2024-06-21 12:09:41 +0200
commit5326bed0186de6a7690a19945da5d684e831de66 (patch)
treeeb9d113bb9d01af20297f61921e898e64472817f
parent40800a6a661c5686b389b6dda9c1440c510b46cb (diff)
downloadllvm-users/matthias-springer/tmp_dialect_conv_decouple.zip
llvm-users/matthias-springer/tmp_dialect_conv_decouple.tar.gz
llvm-users/matthias-springer/tmp_dialect_conv_decouple.tar.bz2
-rw-r--r--mlir/include/mlir/Transforms/DialectConversion.h26
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp388
2 files changed, 105 insertions, 309 deletions
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 87b5dd9..369f61f 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1070,6 +1070,30 @@ public:
// ConversionConfig
//===----------------------------------------------------------------------===//
+/// The type of materialization.
+enum MaterializationKind {
+ /// This materialization materializes a conversion for an illegal block
+ /// argument type, to a legal one.
+ Argument,
+
+ /// This materialization materializes a conversion from an illegal type to a
+ /// legal one.
+ Target,
+
+ /// This materialization materializes a conversion from a legal type back to
+ /// an illegal one.
+ Source
+};
+
+struct UnresolvedMaterialization {
+ UnresolvedMaterialization(UnrealizedConversionCastOp op, MaterializationKind kind, const TypeConverter *converter)
+ : op(op), kind(kind), converter(converter) {}
+
+ UnrealizedConversionCastOp op;
+ MaterializationKind kind;
+ const TypeConverter *converter;
+};
+
/// Dialect conversion configuration.
struct ConversionConfig {
/// An optional callback used to notify about match failure diagnostics during
@@ -1122,6 +1146,8 @@ struct ConversionConfig {
// already been modified) and iterators into past IR state cannot be
// represented at the moment.
RewriterBase::Listener *listener = nullptr;
+
+ SmallVector<UnresolvedMaterialization> *unresolvedMaterializations = nullptr;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 47e0338..614312c 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -672,21 +672,6 @@ public:
void rollback() override;
};
-/// The type of materialization.
-enum MaterializationKind {
- /// This materialization materializes a conversion for an illegal block
- /// argument type, to a legal one.
- Argument,
-
- /// This materialization materializes a conversion from an illegal type to a
- /// legal one.
- Target,
-
- /// This materialization materializes a conversion from a legal type back to
- /// an illegal one.
- Source
-};
-
/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
/// op. Unresolved materializations are erased at the end of the dialect
/// conversion.
@@ -712,10 +697,10 @@ public:
return cast<UnrealizedConversionCastOp>(op);
}
- void rollback() override;
-
void cleanup(RewriterBase &rewriter) override;
+ void rollback() override;
+
/// Return the type converter of this materialization (which may be null).
const TypeConverter *getConverter() const {
return converterAndKind.getPointer();
@@ -890,13 +875,22 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// no new IR is created between calls to `eraseOp`/`eraseBlock`.
struct SingleEraseRewriter : public RewriterBase, RewriterBase::Listener {
public:
- SingleEraseRewriter(MLIRContext *context)
- : RewriterBase(context, /*listener=*/this) {}
+ SingleEraseRewriter(MLIRContext *context, ConversionPatternRewriterImpl &impl)
+ : RewriterBase(context, /*listener=*/this), impl(impl) {}
/// Erase the given op (unless it was already erased).
void eraseOp(Operation *op) override {
if (erased.contains(op))
return;
+ llvm::errs() << "ERASE OP: " << op << "\n";
+ if (impl.config.unresolvedMaterializations) {
+ for (int i = 0; i < impl.config.unresolvedMaterializations->size(); ++i) {
+ if ((*impl.config.unresolvedMaterializations)[i].op == op) {
+ impl.config.unresolvedMaterializations->erase(impl.config.unresolvedMaterializations->begin() +i);
+ break;
+ }
+ }
+ }
op->dropAllUses();
RewriterBase::eraseOp(op);
}
@@ -916,6 +910,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Pointers to all erased operations and blocks.
DenseSet<void *> erased;
+
+ ConversionPatternRewriterImpl &impl;
};
//===--------------------------------------------------------------------===//
@@ -1019,6 +1015,7 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
});
// Notify the listener that the operation is about to be replaced.
+ llvm::errs() << "NOTIFY LISTENER OP REPLACED: " << op << "\n";
if (listener)
listener->notifyOperationReplaced(op, replacements);
@@ -1063,7 +1060,14 @@ void CreateOperationRewrite::rollback() {
op->erase();
}
+void UnresolvedMaterializationRewrite::cleanup(RewriterBase &rewriter) {
+ llvm::errs() << "add cast to vector: " << getOperation().getOperation() << "\n";
+ if (rewriterImpl.config.unresolvedMaterializations)
+ rewriterImpl.config.unresolvedMaterializations->emplace_back(getOperation(), getMaterializationKind(), getConverter());
+}
+
void UnresolvedMaterializationRewrite::rollback() {
+ llvm::errs() << "UnresolvedMaterializationRewrite::rollback!\n";
if (getMaterializationKind() == MaterializationKind::Target) {
for (Value input : op->getOperands())
rewriterImpl.mapping.erase(input);
@@ -1071,10 +1075,6 @@ void UnresolvedMaterializationRewrite::rollback() {
op->erase();
}
-void UnresolvedMaterializationRewrite::cleanup(RewriterBase &rewriter) {
- rewriter.eraseOp(op);
-}
-
void ConversionPatternRewriterImpl::applyRewrites() {
// Commit all rewrites.
IRRewriter rewriter(context, config.listener);
@@ -1082,7 +1082,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
rewrite->commit(rewriter);
// Clean up all rewrites.
- SingleEraseRewriter eraseRewriter(context);
+ SingleEraseRewriter eraseRewriter(context, *this);
for (auto &rewrite : rewrites)
rewrite->cleanup(eraseRewriter);
}
@@ -2360,12 +2360,6 @@ private:
legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl);
- /// Legalize any unresolved type materializations.
- LogicalResult legalizeUnresolvedMaterializations(
- ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &rewriterImpl,
- std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping);
-
/// Legalize an operation result that was marked as "erased".
LogicalResult
legalizeErasedResult(Operation *op, OpResult result,
@@ -2468,9 +2462,7 @@ LogicalResult
OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping;
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
- if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)) ||
- failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
- inverseMapping)))
+ if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
return failure();
// Process requested operation replacements.
@@ -2552,278 +2544,6 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
return success();
}
-/// Replace the results of a materialization operation with the given values.
-static void
-replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl,
- ResultRange matResults, ValueRange values,
- DenseMap<Value, SmallVector<Value>> &inverseMapping) {
- matResults.replaceAllUsesWith(values);
-
- // For each of the materialization results, update the inverse mappings to
- // point to the replacement values.
- for (auto [matResult, newValue] : llvm::zip(matResults, values)) {
- auto inverseMapIt = inverseMapping.find(matResult);
- if (inverseMapIt == inverseMapping.end())
- continue;
-
- // Update the reverse mapping, or remove the mapping if we couldn't update
- // it. Not being able to update signals that the mapping would have become
- // circular (i.e. %foo -> newValue -> %foo), which may occur as values are
- // propagated through temporary materializations. We simply drop the
- // mapping, and let the post-conversion replacement logic handle updating
- // uses.
- for (Value inverseMapVal : inverseMapIt->second)
- if (!rewriterImpl.mapping.tryMap(inverseMapVal, newValue))
- rewriterImpl.mapping.erase(inverseMapVal);
- }
-}
-
-/// Compute all of the unresolved materializations that will persist beyond the
-/// conversion process, and require inserting a proper user materialization for.
-static void computeNecessaryMaterializations(
- DenseMap<Operation *, UnresolvedMaterializationRewrite *>
- &materializationOps,
- ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &rewriterImpl,
- DenseMap<Value, SmallVector<Value>> &inverseMapping,
- SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) {
- auto isLive = [&](Value value) {
- auto findFn = [&](Operation *user) {
- auto matIt = materializationOps.find(user);
- if (matIt != materializationOps.end())
- return !necessaryMaterializations.count(matIt->second);
- return rewriterImpl.isOpIgnored(user);
- };
- // This value may be replacing another value that has a live user.
- for (Value inv : inverseMapping.lookup(value))
- if (llvm::find_if_not(inv.getUsers(), findFn) != inv.user_end())
- return true;
- // Or have live users itself.
- return llvm::find_if_not(value.getUsers(), findFn) != value.user_end();
- };
-
- llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue =
- [&](Value invalidRoot, Value value, Type type) {
- // Check to see if the input operation was remapped to a variant of the
- // output.
- Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
- if (remappedValue.getType() == type && remappedValue != invalidRoot)
- return remappedValue;
-
- // Check to see if the input is a materialization operation that
- // provides an inverse conversion. We just check blindly for
- // UnrealizedConversionCastOp here, but it has no effect on correctness.
- auto inputCastOp = value.getDefiningOp<UnrealizedConversionCastOp>();
- if (inputCastOp && inputCastOp->getNumOperands() == 1)
- return lookupRemappedValue(invalidRoot, inputCastOp->getOperand(0),
- type);
-
- return Value();
- };
-
- SetVector<UnresolvedMaterializationRewrite *> worklist;
- for (auto &rewrite : rewriterImpl.rewrites) {
- auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
- if (!mat)
- continue;
- materializationOps.try_emplace(mat->getOperation(), mat);
- worklist.insert(mat);
- }
- while (!worklist.empty()) {
- UnresolvedMaterializationRewrite *mat = worklist.pop_back_val();
- UnrealizedConversionCastOp op = mat->getOperation();
-
- // We currently only handle target materializations here.
- assert(op->getNumResults() == 1 && "unexpected materialization type");
- OpResult opResult = op->getOpResult(0);
- Type outputType = opResult.getType();
- Operation::operand_range inputOperands = op.getOperands();
-
- // Try to forward propagate operands for user conversion casts that result
- // in the input types of the current cast.
- for (Operation *user : llvm::make_early_inc_range(opResult.getUsers())) {
- auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
- if (!castOp)
- continue;
- if (castOp->getResultTypes() == inputOperands.getTypes()) {
- replaceMaterialization(rewriterImpl, opResult, inputOperands,
- inverseMapping);
- necessaryMaterializations.remove(materializationOps.lookup(user));
- }
- }
-
- // Try to avoid materializing a resolved materialization if possible.
- // Handle the case of a 1-1 materialization.
- if (inputOperands.size() == 1) {
- // Check to see if the input operation was remapped to a variant of the
- // output.
- Value remappedValue =
- lookupRemappedValue(opResult, inputOperands[0], outputType);
- if (remappedValue && remappedValue != opResult) {
- replaceMaterialization(rewriterImpl, opResult, remappedValue,
- inverseMapping);
- necessaryMaterializations.remove(mat);
- continue;
- }
- } else {
- // TODO: Avoid materializing other types of conversions here.
- }
-
- // If the materialization does not have any live users, we don't need to
- // generate a user materialization for it.
- bool isMaterializationLive = isLive(opResult);
- if (!isMaterializationLive)
- continue;
- if (!necessaryMaterializations.insert(mat))
- continue;
-
- // Reprocess input materializations to see if they have an updated status.
- for (Value input : inputOperands) {
- if (auto parentOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
- if (auto *mat = materializationOps.lookup(parentOp))
- worklist.insert(mat);
- }
- }
- }
-}
-
-/// Legalize the given unresolved materialization. Returns success if the
-/// materialization was legalized, failure otherise.
-static LogicalResult legalizeUnresolvedMaterialization(
- UnresolvedMaterializationRewrite &mat,
- DenseMap<Operation *, UnresolvedMaterializationRewrite *>
- &materializationOps,
- ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &rewriterImpl,
- DenseMap<Value, SmallVector<Value>> &inverseMapping) {
- auto findLiveUser = [&](auto &&users) {
- auto liveUserIt = llvm::find_if_not(
- users, [&](Operation *user) { return rewriterImpl.isOpIgnored(user); });
- return liveUserIt == users.end() ? nullptr : *liveUserIt;
- };
-
- llvm::unique_function<Value(Value, Type)> lookupRemappedValue =
- [&](Value value, Type type) {
- // Check to see if the input operation was remapped to a variant of the
- // output.
- Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
- if (remappedValue.getType() == type)
- return remappedValue;
- return Value();
- };
-
- UnrealizedConversionCastOp op = mat.getOperation();
- if (!rewriterImpl.ignoredOps.insert(op))
- return success();
-
- // We currently only handle target materializations here.
- OpResult opResult = op->getOpResult(0);
- Operation::operand_range inputOperands = op.getOperands();
- Type outputType = opResult.getType();
-
- // If any input to this materialization is another materialization, resolve
- // the input first.
- for (Value value : op->getOperands()) {
- auto valueCast = value.getDefiningOp<UnrealizedConversionCastOp>();
- if (!valueCast)
- continue;
-
- auto matIt = materializationOps.find(valueCast);
- if (matIt != materializationOps.end())
- if (failed(legalizeUnresolvedMaterialization(
- *matIt->second, materializationOps, rewriter, rewriterImpl,
- inverseMapping)))
- return failure();
- }
-
- // Perform a last ditch attempt to avoid materializing a resolved
- // materialization if possible.
- // Handle the case of a 1-1 materialization.
- if (inputOperands.size() == 1) {
- // Check to see if the input operation was remapped to a variant of the
- // output.
- Value remappedValue = lookupRemappedValue(inputOperands[0], outputType);
- if (remappedValue && remappedValue != opResult) {
- replaceMaterialization(rewriterImpl, opResult, remappedValue,
- inverseMapping);
- return success();
- }
- } else {
- // TODO: Avoid materializing other types of conversions here.
- }
-
- // Try to materialize the conversion.
- if (const TypeConverter *converter = mat.getConverter()) {
- rewriter.setInsertionPoint(op);
- Value newMaterialization;
- switch (mat.getMaterializationKind()) {
- case MaterializationKind::Argument:
- // Try to materialize an argument conversion.
- // FIXME: The current argument materialization hook expects the original
- // output type, even though it doesn't use that as the actual output type
- // of the generated IR. The output type is just used as an indicator of
- // the type of materialization to do. This behavior is really awkward in
- // that it diverges from the behavior of the other hooks, and can be
- // easily misunderstood. We should clean up the argument hooks to better
- // represent the desired invariants we actually care about.
- newMaterialization = converter->materializeArgumentConversion(
- rewriter, op->getLoc(), mat.getOrigArgType(), inputOperands);
- if (newMaterialization)
- break;
-
- // If an argument materialization failed, fallback to trying a target
- // materialization.
- [[fallthrough]];
- case MaterializationKind::Target:
- newMaterialization = converter->materializeTargetConversion(
- rewriter, op->getLoc(), outputType, inputOperands);
- break;
- case MaterializationKind::Source:
- newMaterialization = converter->materializeSourceConversion(
- rewriter, op->getLoc(), outputType, inputOperands);
- break;
- }
- if (newMaterialization) {
- replaceMaterialization(rewriterImpl, opResult, newMaterialization,
- inverseMapping);
- return success();
- }
- }
-
- InFlightDiagnostic diag = op->emitError()
- << "failed to legalize unresolved materialization "
- "from ("
- << inputOperands.getTypes() << ") to " << outputType
- << " that remained live after conversion";
- if (Operation *liveUser = findLiveUser(op->getUsers())) {
- diag.attachNote(liveUser->getLoc())
- << "see existing live user here: " << *liveUser;
- }
- return failure();
-}
-
-LogicalResult OperationConverter::legalizeUnresolvedMaterializations(
- ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &rewriterImpl,
- std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping) {
- inverseMapping = rewriterImpl.mapping.getInverse();
-
- // As an initial step, compute all of the inserted materializations that we
- // expect to persist beyond the conversion process.
- DenseMap<Operation *, UnresolvedMaterializationRewrite *> materializationOps;
- SetVector<UnresolvedMaterializationRewrite *> necessaryMaterializations;
- computeNecessaryMaterializations(materializationOps, rewriter, rewriterImpl,
- *inverseMapping, necessaryMaterializations);
-
- // Once computed, legalize any necessary materializations.
- for (auto *mat : necessaryMaterializations) {
- if (failed(legalizeUnresolvedMaterialization(
- *mat, materializationOps, rewriter, rewriterImpl, *inverseMapping)))
- return failure();
- }
- return success();
-}
-
LogicalResult OperationConverter::legalizeErasedResult(
Operation *op, OpResult result,
ConversionPatternRewriterImpl &rewriterImpl) {
@@ -2895,8 +2615,11 @@ LogicalResult OperationConverter::legalizeChangedResultType(
// Materialize a conversion for this live result value.
Type resultType = result.getType();
- Value convertedValue = replConverter->materializeSourceConversion(
- rewriter, op->getLoc(), resultType, newValue);
+ Value convertedValue = rewriterImpl.buildUnresolvedMaterialization(
+ MaterializationKind::Source, computeInsertPoint(newValue),
+ op->getLoc(), /*inputs=*/newValue, /*outputType=*/resultType,
+ /*origArgType=*/{}, replConverter);
+
if (!convertedValue)
return emitConversionError();
@@ -3429,15 +3152,56 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
// Op Conversion Entry Points
//===----------------------------------------------------------------------===//
+static void resolveMaterialization(const UnresolvedMaterialization &mat) {
+ if (!mat.converter)
+ return;
+ IRRewriter rewriter(mat.op);
+ switch (mat.kind) {
+ case MaterializationKind::Argument:
+ {
+ assert(mat.op->getNumResults() == 1);
+ Value newMaterialization = mat.converter->materializeArgumentConversion(
+ rewriter, mat.op->getLoc(), mat.op->getResult(0).getType(), mat.op->getOperands());
+ if(newMaterialization)
+ rewriter.replaceOp(mat.op, newMaterialization);
+ }
+ break;
+ case MaterializationKind::Target:
+ {
+ assert(mat.op->getNumResults() == 1);
+ Value newMaterialization = mat.converter->materializeTargetConversion(
+ rewriter, mat.op->getLoc(), mat.op->getResult(0).getType(), mat.op->getOperands());
+ if(newMaterialization)
+ rewriter.replaceOp(mat.op, newMaterialization);
+ }
+ break;
+ case MaterializationKind::Source:
+ {
+ assert(mat.op->getNumResults() == 1);
+ Value newMaterialization = mat.converter->materializeArgumentConversion(
+ rewriter, mat.op->getLoc(), mat.op->getResult(0).getType(), mat.op->getOperands());
+ if(newMaterialization)
+ rewriter.replaceOp(mat.op, newMaterialization);
+ }
+ break;
+ }
+}
+
//===----------------------------------------------------------------------===//
// Partial Conversion
LogicalResult mlir::applyPartialConversion(
ArrayRef<Operation *> ops, const ConversionTarget &target,
const FrozenRewritePatternSet &patterns, ConversionConfig config) {
+ SmallVector<UnresolvedMaterialization> unresolvedMaterializations;
+ config.unresolvedMaterializations = &unresolvedMaterializations;
OperationConverter opConverter(target, patterns, config,
OpConversionMode::Partial);
- return opConverter.convertOperations(ops);
+ LogicalResult status = opConverter.convertOperations(ops);
+ if (failed(status)) return failure();
+ for (auto mat : unresolvedMaterializations)
+ resolveMaterialization(mat);
+ return success();
}
LogicalResult
mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
@@ -3453,9 +3217,15 @@ LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
ConversionConfig config) {
+ SmallVector<UnresolvedMaterialization> unresolvedMaterializations;
+ config.unresolvedMaterializations = &unresolvedMaterializations;
OperationConverter opConverter(target, patterns, config,
OpConversionMode::Full);
- return opConverter.convertOperations(ops);
+ LogicalResult status = opConverter.convertOperations(ops);
+ if (failed(status)) return failure();
+ for (auto mat : unresolvedMaterializations)
+ resolveMaterialization(mat);
+ return success();
}
LogicalResult mlir::applyFullConversion(Operation *op,
const ConversionTarget &target,