diff options
author | Mehdi Amini <joker.eph@gmail.com> | 2025-09-11 15:17:10 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-09-11 14:17:10 +0000 |
commit | b22f94dcc58e09710c188045b498a201db83d9a2 (patch) | |
tree | bf56938b32307e663cce8895c34ad4f8fef93d09 | |
parent | 4ce74bfb4d218392719ad610c9dfc467042a6b58 (diff) | |
download | llvm-b22f94dcc58e09710c188045b498a201db83d9a2.zip llvm-b22f94dcc58e09710c188045b498a201db83d9a2.tar.gz llvm-b22f94dcc58e09710c188045b498a201db83d9a2.tar.bz2 |
[MLIR] Enable caching of type conversion in the presence of context-aware conversion (#158072)
The current implementation is overly conservative and disable all
possible caching as soon as a context-aware conversion is present.
However the context-aware conversion only affects subsequent converters,
we can cache the previous ones.
This isn't NFC because if fixed a bug where we use to unconditionally
cache when using the `convertType(Type t, ...` API, while now all APIs
are aware of context-aware conversions.
-rw-r--r-- | mlir/docs/DialectConversion.md | 10 | ||||
-rw-r--r-- | mlir/include/mlir/Transforms/DialectConversion.h | 17 | ||||
-rw-r--r-- | mlir/lib/Transforms/Utils/DialectConversion.cpp | 84 |
3 files changed, 66 insertions, 45 deletions
diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md index 7070351..5ae3515 100644 --- a/mlir/docs/DialectConversion.md +++ b/mlir/docs/DialectConversion.md @@ -285,9 +285,13 @@ conversions. A context-unaware conversion function converts a `Type` into a `Type`. A context-aware conversion function converts a `Value` into a type. The latter allows users to customize type conversion rules based on the IR. -Note: When there is at least one context-aware type conversion function, the -result of type conversions can no longer be cached, which can increase -compilation time. Use this feature with caution! +Note: context-aware type conversion functions impact the ability of the +framework to cache the conversion result. In the absence of a context-aware +conversion, all context-free type conversions can be cached. Otherwise only the +context-free conversions added after a context-aware type conversion can be +cached (conversions are applied in reverse order). +As such it is advised to add context-aware conversions as early as possible in +the sequence of `addConversion` calls (so that they apply last). A `materialization` describes how a list of values should be converted to a list of values with specific types. An important distinction from a diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 6949f4a..a096f82 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -433,7 +433,7 @@ private: std::is_same_v<T, Value>, ConversionCallbackFn> wrapCallback(FnT &&callback) { - hasContextAwareTypeConversions = true; + contextAwareTypeConversionsIndex = conversions.size(); return [callback = std::forward<FnT>(callback)]( PointerUnion<Type, Value> typeOrValue, SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> { @@ -555,6 +555,10 @@ private: cachedMultiConversions.clear(); } + /// Internal implementation of the type conversion. + LogicalResult convertTypeImpl(PointerUnion<Type, Value> t, + SmallVectorImpl<Type> &results) const; + /// The set of registered conversion functions. SmallVector<ConversionCallbackFn, 4> conversions; @@ -575,10 +579,13 @@ private: mutable llvm::sys::SmartRWMutex<true> cacheMutex; /// Whether the type converter has context-aware type conversions. I.e., /// conversion rules that depend on the SSA value instead of just the type. - /// Type conversion caching is deactivated when there are context-aware - /// conversions because the type converter may return different results for - /// the same input type. - bool hasContextAwareTypeConversions = false; + /// We store here the index in the `conversions` vector of the last added + /// context-aware conversion, if any. This is useful because we can't cache + /// the result of type conversion happening after context-aware conversions, + /// because the type converter may return different results for the same input + /// type. This is why it is recommened to add context-aware conversions first, + /// any context-free conversions after will benefit from caching. + int contextAwareTypeConversionsIndex = -1; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 36ee87b..df9700f 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -3406,10 +3406,19 @@ void TypeConverter::SignatureConversion::remapInput( SmallVector<Value, 1>(replacements.begin(), replacements.end())}; } -LogicalResult TypeConverter::convertType(Type t, - SmallVectorImpl<Type> &results) const { - assert(t && "expected non-null type"); - +/// Internal implementation of the type conversion. +/// This is used with either a Type or a Value as the first argument. +/// - we can cache the context-free conversions until the last registered +/// context-aware conversion. +/// - we can't cache the result of type conversion happening after context-aware +/// conversions, because the type converter may return different results for the +/// same input type. +LogicalResult +TypeConverter::convertTypeImpl(PointerUnion<Type, Value> typeOrValue, + SmallVectorImpl<Type> &results) const { + assert(typeOrValue && "expected non-null type"); + Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType() + : cast<Type>(typeOrValue); { std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex, std::defer_lock); @@ -3431,52 +3440,53 @@ LogicalResult TypeConverter::convertType(Type t, // registered first. size_t currentCount = results.size(); + // We can cache the context-free conversions until the last registered + // context-aware conversion. But only if we're processing a Value right now. + auto isCacheable = [&](int index) { + int numberOfConversionsUntilContextAware = + conversions.size() - 1 - contextAwareTypeConversionsIndex; + return index < numberOfConversionsUntilContextAware; + }; + std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex, std::defer_lock); - for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) { - if (std::optional<LogicalResult> result = converter(t, results)) { - if (t.getContext()->isMultithreadingEnabled()) - cacheWriteLock.lock(); - if (!succeeded(*result)) { - assert(results.size() == currentCount && - "failed type conversion should not change results"); - cachedDirectConversions.try_emplace(t, nullptr); - return failure(); - } - auto newTypes = ArrayRef<Type>(results).drop_front(currentCount); - if (newTypes.size() == 1) - cachedDirectConversions.try_emplace(t, newTypes.front()); - else - cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes)); + for (auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) { + const ConversionCallbackFn &converter = indexedConverter.value(); + std::optional<LogicalResult> result = converter(typeOrValue, results); + if (!result) { + assert(results.size() == currentCount && + "failed type conversion should not change results"); + continue; + } + if (!isCacheable(indexedConverter.index())) return success(); - } else { + if (t.getContext()->isMultithreadingEnabled()) + cacheWriteLock.lock(); + if (!succeeded(*result)) { assert(results.size() == currentCount && "failed type conversion should not change results"); + cachedDirectConversions.try_emplace(t, nullptr); + return failure(); } + auto newTypes = ArrayRef<Type>(results).drop_front(currentCount); + if (newTypes.size() == 1) + cachedDirectConversions.try_emplace(t, newTypes.front()); + else + cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes)); + return success(); } return failure(); } -LogicalResult TypeConverter::convertType(Value v, +LogicalResult TypeConverter::convertType(Type t, SmallVectorImpl<Type> &results) const { - assert(v && "expected non-null value"); - - // If this type converter does not have context-aware type conversions, call - // the type-based overload, which has caching. - if (!hasContextAwareTypeConversions) - return convertType(v.getType(), results); + return convertTypeImpl(t, results); +} - // Walk the added converters in reverse order to apply the most recently - // registered first. - for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) { - if (std::optional<LogicalResult> result = converter(v, results)) { - if (!succeeded(*result)) - return failure(); - return success(); - } - } - return failure(); +LogicalResult TypeConverter::convertType(Value v, + SmallVectorImpl<Type> &results) const { + return convertTypeImpl(v, results); } Type TypeConverter::convertType(Type t) const { |