aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Transforms/Utils/DialectConversion.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms/Utils/DialectConversion.cpp')
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp84
1 files changed, 47 insertions, 37 deletions
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 36ee87b..df9700f 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3406,10 +3406,19 @@ void TypeConverter::SignatureConversion::remapInput(
SmallVector<Value, 1>(replacements.begin(), replacements.end())};
}
-LogicalResult TypeConverter::convertType(Type t,
- SmallVectorImpl<Type> &results) const {
- assert(t && "expected non-null type");
-
+/// Internal implementation of the type conversion.
+/// This is used with either a Type or a Value as the first argument.
+/// - we can cache the context-free conversions until the last registered
+/// context-aware conversion.
+/// - we can't cache the result of type conversion happening after context-aware
+/// conversions, because the type converter may return different results for the
+/// same input type.
+LogicalResult
+TypeConverter::convertTypeImpl(PointerUnion<Type, Value> typeOrValue,
+ SmallVectorImpl<Type> &results) const {
+ assert(typeOrValue && "expected non-null type");
+ Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType()
+ : cast<Type>(typeOrValue);
{
std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
std::defer_lock);
@@ -3431,52 +3440,53 @@ LogicalResult TypeConverter::convertType(Type t,
// registered first.
size_t currentCount = results.size();
+ // We can cache the context-free conversions until the last registered
+ // context-aware conversion. But only if we're processing a Value right now.
+ auto isCacheable = [&](int index) {
+ int numberOfConversionsUntilContextAware =
+ conversions.size() - 1 - contextAwareTypeConversionsIndex;
+ return index < numberOfConversionsUntilContextAware;
+ };
+
std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
std::defer_lock);
- for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
- if (std::optional<LogicalResult> result = converter(t, results)) {
- if (t.getContext()->isMultithreadingEnabled())
- cacheWriteLock.lock();
- if (!succeeded(*result)) {
- assert(results.size() == currentCount &&
- "failed type conversion should not change results");
- cachedDirectConversions.try_emplace(t, nullptr);
- return failure();
- }
- auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
- if (newTypes.size() == 1)
- cachedDirectConversions.try_emplace(t, newTypes.front());
- else
- cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
+ for (auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) {
+ const ConversionCallbackFn &converter = indexedConverter.value();
+ std::optional<LogicalResult> result = converter(typeOrValue, results);
+ if (!result) {
+ assert(results.size() == currentCount &&
+ "failed type conversion should not change results");
+ continue;
+ }
+ if (!isCacheable(indexedConverter.index()))
return success();
- } else {
+ if (t.getContext()->isMultithreadingEnabled())
+ cacheWriteLock.lock();
+ if (!succeeded(*result)) {
assert(results.size() == currentCount &&
"failed type conversion should not change results");
+ cachedDirectConversions.try_emplace(t, nullptr);
+ return failure();
}
+ auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
+ if (newTypes.size() == 1)
+ cachedDirectConversions.try_emplace(t, newTypes.front());
+ else
+ cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
+ return success();
}
return failure();
}
-LogicalResult TypeConverter::convertType(Value v,
+LogicalResult TypeConverter::convertType(Type t,
SmallVectorImpl<Type> &results) const {
- assert(v && "expected non-null value");
-
- // If this type converter does not have context-aware type conversions, call
- // the type-based overload, which has caching.
- if (!hasContextAwareTypeConversions)
- return convertType(v.getType(), results);
+ return convertTypeImpl(t, results);
+}
- // Walk the added converters in reverse order to apply the most recently
- // registered first.
- for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
- if (std::optional<LogicalResult> result = converter(v, results)) {
- if (!succeeded(*result))
- return failure();
- return success();
- }
- }
- return failure();
+LogicalResult TypeConverter::convertType(Value v,
+ SmallVectorImpl<Type> &results) const {
+ return convertTypeImpl(v, results);
}
Type TypeConverter::convertType(Type t) const {