aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthias Springer <mspringer@nvidia.com>2024-06-23 21:20:27 +0200
committerMatthias Springer <mspringer@nvidia.com>2024-06-23 21:20:27 +0200
commit61e54a1dbfea9d0a177b7a29aa3c5dd8c6b6bb6b (patch)
tree03b914ee9f4e450ca9d6000de9abe913ecab7c8c
parent0ae7616116a9e31171de4b7fb98c18b4c0c92b68 (diff)
downloadllvm-users/matthias-springer/argument_mat_bbarg_list.zip
llvm-users/matthias-springer/argument_mat_bbarg_list.tar.gz
llvm-users/matthias-springer/argument_mat_bbarg_list.tar.bz2
improve addArgumentMaterialization APIusers/matthias-springer/argument_mat_bbarg_list
-rw-r--r--mlir/include/mlir/Transforms/DialectConversion.h28
-rw-r--r--mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp6
-rw-r--r--mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp2
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp16
-rw-r--r--mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp5
-rw-r--r--mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp1
6 files changed, 45 insertions, 13 deletions
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index f83f3a3..aa30281 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -186,7 +186,7 @@ public:
std::decay_t<FnT>>::template arg_t<1>>
void addArgumentMaterialization(FnT &&callback) {
argumentMaterializations.emplace_back(
- wrapMaterialization<T>(std::forward<FnT>(callback)));
+ wrapArgumentMaterialization<T>(std::forward<FnT>(callback)));
}
/// This method registers a materialization that will be called when
/// converting a legal type to an illegal source type. This is used when
@@ -300,9 +300,9 @@ public:
/// methods.
Value materializeArgumentConversion(OpBuilder &builder, Location loc,
Type resultType,
- ValueRange inputs) const {
- return materializeConversion(argumentMaterializations, builder, loc,
- resultType, inputs);
+ Block::BlockArgListType inputs) const {
+ return materializeArgumentConversion(argumentMaterializations, builder, loc,
+ resultType, inputs);
}
Value materializeSourceConversion(OpBuilder &builder, Location loc,
Type resultType, ValueRange inputs) const {
@@ -332,6 +332,8 @@ private:
/// The signature of the callback used to materialize a conversion.
using MaterializationCallbackFn = std::function<std::optional<Value>(
OpBuilder &, Type, ValueRange, Location)>;
+ using ArgumentMaterializationCallbackFn = std::function<std::optional<Value>(
+ OpBuilder &, Type, Block::BlockArgListType, Location)>;
/// The signature of the callback used to convert a type attribute.
using TypeAttributeConversionCallbackFn =
@@ -343,6 +345,10 @@ private:
materializeConversion(ArrayRef<MaterializationCallbackFn> materializations,
OpBuilder &builder, Location loc, Type resultType,
ValueRange inputs) const;
+ Value materializeArgumentConversion(
+ ArrayRef<ArgumentMaterializationCallbackFn> materializations,
+ OpBuilder &builder, Location loc, Type resultType,
+ Block::BlockArgListType inputs) const;
/// Generate a wrapper for the given callback. This allows for accepting
/// different callback forms, that all compose into a single version.
@@ -397,6 +403,18 @@ private:
return std::nullopt;
};
}
+ template <typename T, typename FnT>
+ ArgumentMaterializationCallbackFn
+ wrapArgumentMaterialization(FnT &&callback) const {
+ return
+ [callback = std::forward<FnT>(callback)](
+ OpBuilder &builder, Type resultType, Block::BlockArgListType inputs,
+ Location loc) -> std::optional<Value> {
+ if (T derivedType = dyn_cast<T>(resultType))
+ return callback(builder, derivedType, inputs, loc);
+ return std::nullopt;
+ };
+ }
/// Generate a wrapper for the given memory space conversion callback. The
/// callback may take any subclass of `Attribute` and the wrapper will check
@@ -428,7 +446,7 @@ private:
SmallVector<ConversionCallbackFn, 4> conversions;
/// The list of registered materialization functions.
- SmallVector<MaterializationCallbackFn, 2> argumentMaterializations;
+ SmallVector<ArgumentMaterializationCallbackFn, 2> argumentMaterializations;
SmallVector<MaterializationCallbackFn, 2> sourceMaterializations;
SmallVector<MaterializationCallbackFn, 2> targetMaterializations;
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index f5620a6..3f1114c 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -168,14 +168,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
inputs);
});
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
- ValueRange inputs,
+ Block::BlockArgListType inputs,
Location loc) -> std::optional<Value> {
if (inputs.size() == 1) {
// This is a bare pointer. We allow bare pointers only for function entry
// blocks.
- BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
- if (!barePtr)
- return std::nullopt;
+ BlockArgument barePtr = inputs.front();
Block *block = barePtr.getOwner();
if (!block->isEntryBlock() ||
!isa<FunctionOpInterface>(block->getParentOp()))
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
index fa030cb..027f237 100644
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
@@ -176,7 +176,7 @@ struct DecomposeCallGraphTypesForCallOp
replacedValues.push_back(decomposedValues.front());
} else {
// Materialize a single Value to replace the original Value.
- Value materialized = getTypeConverter()->materializeArgumentConversion(
+ Value materialized = getTypeConverter()->materializeSourceConversion(
rewriter, op.getLoc(), op.getType(i), decomposedValues);
replacedValues.push_back(materialized);
}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index e6c0ee2..a250084 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2841,6 +2841,7 @@ static LogicalResult legalizeUnresolvedMaterialization(
if (const TypeConverter *converter = mat.getConverter()) {
rewriter.setInsertionPoint(op);
Value newMaterialization;
+ SmallVector<BlockArgument> blockArgs;
switch (mat.getMaterializationKind()) {
case MaterializationKind::Argument:
// Try to materialize an argument conversion.
@@ -2851,8 +2852,10 @@ static LogicalResult legalizeUnresolvedMaterialization(
// that it diverges from the behavior of the other hooks, and can be
// easily misunderstood. We should clean up the argument hooks to better
// represent the desired invariants we actually care about.
+ blockArgs = llvm::map_to_vector(
+ inputOperands, [](Value v) { return cast<BlockArgument>(v); });
newMaterialization = converter->materializeArgumentConversion(
- rewriter, op->getLoc(), mat.getOrigOutputType(), inputOperands);
+ rewriter, op->getLoc(), mat.getOrigOutputType(), blockArgs);
if (newMaterialization)
break;
@@ -3134,6 +3137,17 @@ Value TypeConverter::materializeConversion(
return nullptr;
}
+Value TypeConverter::materializeArgumentConversion(
+ ArrayRef<ArgumentMaterializationCallbackFn> materializations,
+ OpBuilder &builder, Location loc, Type resultType,
+ Block::BlockArgListType inputs) const {
+ for (const ArgumentMaterializationCallbackFn &fn :
+ llvm::reverse(materializations))
+ if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
+ return *result;
+ return nullptr;
+}
+
std::optional<TypeConverter::SignatureConversion>
TypeConverter::convertBlockSignature(Block *block) const {
SignatureConversion conversion(block->getNumArguments());
diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
index f6e8e9e..6512d89 100644
--- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
+++ b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
@@ -394,9 +394,10 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
assert(castKind == getCastKindName(CastKind::Argument) &&
"unexpected value of cast kind attribute");
assert(llvm::all_of(operands, llvm::IsaPred<BlockArgument>));
+ SmallVector<BlockArgument> blockArgs = llvm::map_to_vector(
+ operands, [](Value v) { return cast<BlockArgument>(v); });
maybeResult = typeConverter.materializeArgumentConversion(
- rewriter, castOp->getLoc(), resultTypes.front(),
- castOp.getOperands());
+ rewriter, castOp->getLoc(), resultTypes.front(), blockArgs);
}
if (!maybeResult.has_value() || !maybeResult.value()) {
emitError(castOp->getLoc())
diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
index 0d7dce2..a1f0758 100644
--- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
+++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
@@ -124,6 +124,7 @@ struct TestDecomposeCallGraphTypes
return success();
});
typeConverter.addArgumentMaterialization(buildMakeTupleOp);
+ typeConverter.addSourceMaterialization(buildMakeTupleOp);
ValueDecomposer decomposer;
decomposer.addDecomposeValueConversion(buildDecomposeTuple);