diff options
Diffstat (limited to 'mlir/lib/Transforms/Utils/DialectConversion.cpp')
-rw-r--r-- | mlir/lib/Transforms/Utils/DialectConversion.cpp | 84 |
1 files changed, 47 insertions, 37 deletions
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 36ee87b..df9700f 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -3406,10 +3406,19 @@ void TypeConverter::SignatureConversion::remapInput( SmallVector<Value, 1>(replacements.begin(), replacements.end())}; } -LogicalResult TypeConverter::convertType(Type t, - SmallVectorImpl<Type> &results) const { - assert(t && "expected non-null type"); - +/// Internal implementation of the type conversion. +/// This is used with either a Type or a Value as the first argument. +/// - we can cache the context-free conversions until the last registered +/// context-aware conversion. +/// - we can't cache the result of type conversion happening after context-aware +/// conversions, because the type converter may return different results for the +/// same input type. +LogicalResult +TypeConverter::convertTypeImpl(PointerUnion<Type, Value> typeOrValue, + SmallVectorImpl<Type> &results) const { + assert(typeOrValue && "expected non-null type"); + Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType() + : cast<Type>(typeOrValue); { std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex, std::defer_lock); @@ -3431,52 +3440,53 @@ LogicalResult TypeConverter::convertType(Type t, // registered first. size_t currentCount = results.size(); + // We can cache the context-free conversions until the last registered + // context-aware conversion. But only if we're processing a Value right now. + auto isCacheable = [&](int index) { + int numberOfConversionsUntilContextAware = + conversions.size() - 1 - contextAwareTypeConversionsIndex; + return index < numberOfConversionsUntilContextAware; + }; + std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex, std::defer_lock); - for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) { - if (std::optional<LogicalResult> result = converter(t, results)) { - if (t.getContext()->isMultithreadingEnabled()) - cacheWriteLock.lock(); - if (!succeeded(*result)) { - assert(results.size() == currentCount && - "failed type conversion should not change results"); - cachedDirectConversions.try_emplace(t, nullptr); - return failure(); - } - auto newTypes = ArrayRef<Type>(results).drop_front(currentCount); - if (newTypes.size() == 1) - cachedDirectConversions.try_emplace(t, newTypes.front()); - else - cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes)); + for (auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) { + const ConversionCallbackFn &converter = indexedConverter.value(); + std::optional<LogicalResult> result = converter(typeOrValue, results); + if (!result) { + assert(results.size() == currentCount && + "failed type conversion should not change results"); + continue; + } + if (!isCacheable(indexedConverter.index())) return success(); - } else { + if (t.getContext()->isMultithreadingEnabled()) + cacheWriteLock.lock(); + if (!succeeded(*result)) { assert(results.size() == currentCount && "failed type conversion should not change results"); + cachedDirectConversions.try_emplace(t, nullptr); + return failure(); } + auto newTypes = ArrayRef<Type>(results).drop_front(currentCount); + if (newTypes.size() == 1) + cachedDirectConversions.try_emplace(t, newTypes.front()); + else + cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes)); + return success(); } return failure(); } -LogicalResult TypeConverter::convertType(Value v, +LogicalResult TypeConverter::convertType(Type t, SmallVectorImpl<Type> &results) const { - assert(v && "expected non-null value"); - - // If this type converter does not have context-aware type conversions, call - // the type-based overload, which has caching. - if (!hasContextAwareTypeConversions) - return convertType(v.getType(), results); + return convertTypeImpl(t, results); +} - // Walk the added converters in reverse order to apply the most recently - // registered first. - for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) { - if (std::optional<LogicalResult> result = converter(v, results)) { - if (!succeeded(*result)) - return failure(); - return success(); - } - } - return failure(); +LogicalResult TypeConverter::convertType(Value v, + SmallVectorImpl<Type> &results) const { + return convertTypeImpl(v, results); } Type TypeConverter::convertType(Type t) const { |