diff options
author | Matthias Springer <mspringer@nvidia.com> | 2024-06-23 21:20:27 +0200 |
---|---|---|
committer | Matthias Springer <mspringer@nvidia.com> | 2024-06-23 21:20:27 +0200 |
commit | 61e54a1dbfea9d0a177b7a29aa3c5dd8c6b6bb6b (patch) | |
tree | 03b914ee9f4e450ca9d6000de9abe913ecab7c8c | |
parent | 0ae7616116a9e31171de4b7fb98c18b4c0c92b68 (diff) | |
download | llvm-users/matthias-springer/argument_mat_bbarg_list.zip llvm-users/matthias-springer/argument_mat_bbarg_list.tar.gz llvm-users/matthias-springer/argument_mat_bbarg_list.tar.bz2 |
improve addArgumentMaterialization APIusers/matthias-springer/argument_mat_bbarg_list
6 files changed, 45 insertions, 13 deletions
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index f83f3a3..aa30281 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -186,7 +186,7 @@ public: std::decay_t<FnT>>::template arg_t<1>> void addArgumentMaterialization(FnT &&callback) { argumentMaterializations.emplace_back( - wrapMaterialization<T>(std::forward<FnT>(callback))); + wrapArgumentMaterialization<T>(std::forward<FnT>(callback))); } /// This method registers a materialization that will be called when /// converting a legal type to an illegal source type. This is used when @@ -300,9 +300,9 @@ public: /// methods. Value materializeArgumentConversion(OpBuilder &builder, Location loc, Type resultType, - ValueRange inputs) const { - return materializeConversion(argumentMaterializations, builder, loc, - resultType, inputs); + Block::BlockArgListType inputs) const { + return materializeArgumentConversion(argumentMaterializations, builder, loc, + resultType, inputs); } Value materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const { @@ -332,6 +332,8 @@ private: /// The signature of the callback used to materialize a conversion. using MaterializationCallbackFn = std::function<std::optional<Value>( OpBuilder &, Type, ValueRange, Location)>; + using ArgumentMaterializationCallbackFn = std::function<std::optional<Value>( + OpBuilder &, Type, Block::BlockArgListType, Location)>; /// The signature of the callback used to convert a type attribute. using TypeAttributeConversionCallbackFn = @@ -343,6 +345,10 @@ private: materializeConversion(ArrayRef<MaterializationCallbackFn> materializations, OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) const; + Value materializeArgumentConversion( + ArrayRef<ArgumentMaterializationCallbackFn> materializations, + OpBuilder &builder, Location loc, Type resultType, + Block::BlockArgListType inputs) const; /// Generate a wrapper for the given callback. This allows for accepting /// different callback forms, that all compose into a single version. @@ -397,6 +403,18 @@ private: return std::nullopt; }; } + template <typename T, typename FnT> + ArgumentMaterializationCallbackFn + wrapArgumentMaterialization(FnT &&callback) const { + return + [callback = std::forward<FnT>(callback)]( + OpBuilder &builder, Type resultType, Block::BlockArgListType inputs, + Location loc) -> std::optional<Value> { + if (T derivedType = dyn_cast<T>(resultType)) + return callback(builder, derivedType, inputs, loc); + return std::nullopt; + }; + } /// Generate a wrapper for the given memory space conversion callback. The /// callback may take any subclass of `Attribute` and the wrapper will check @@ -428,7 +446,7 @@ private: SmallVector<ConversionCallbackFn, 4> conversions; /// The list of registered materialization functions. - SmallVector<MaterializationCallbackFn, 2> argumentMaterializations; + SmallVector<ArgumentMaterializationCallbackFn, 2> argumentMaterializations; SmallVector<MaterializationCallbackFn, 2> sourceMaterializations; SmallVector<MaterializationCallbackFn, 2> targetMaterializations; diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp index f5620a6..3f1114c 100644 --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -168,14 +168,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, inputs); }); addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType, - ValueRange inputs, + Block::BlockArgListType inputs, Location loc) -> std::optional<Value> { if (inputs.size() == 1) { // This is a bare pointer. We allow bare pointers only for function entry // blocks. - BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front()); - if (!barePtr) - return std::nullopt; + BlockArgument barePtr = inputs.front(); Block *block = barePtr.getOwner(); if (!block->isEntryBlock() || !isa<FunctionOpInterface>(block->getParentOp())) diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp index fa030cb..027f237 100644 --- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp +++ b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp @@ -176,7 +176,7 @@ struct DecomposeCallGraphTypesForCallOp replacedValues.push_back(decomposedValues.front()); } else { // Materialize a single Value to replace the original Value. - Value materialized = getTypeConverter()->materializeArgumentConversion( + Value materialized = getTypeConverter()->materializeSourceConversion( rewriter, op.getLoc(), op.getType(i), decomposedValues); replacedValues.push_back(materialized); } diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index e6c0ee2..a250084 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -2841,6 +2841,7 @@ static LogicalResult legalizeUnresolvedMaterialization( if (const TypeConverter *converter = mat.getConverter()) { rewriter.setInsertionPoint(op); Value newMaterialization; + SmallVector<BlockArgument> blockArgs; switch (mat.getMaterializationKind()) { case MaterializationKind::Argument: // Try to materialize an argument conversion. @@ -2851,8 +2852,10 @@ static LogicalResult legalizeUnresolvedMaterialization( // that it diverges from the behavior of the other hooks, and can be // easily misunderstood. We should clean up the argument hooks to better // represent the desired invariants we actually care about. + blockArgs = llvm::map_to_vector( + inputOperands, [](Value v) { return cast<BlockArgument>(v); }); newMaterialization = converter->materializeArgumentConversion( - rewriter, op->getLoc(), mat.getOrigOutputType(), inputOperands); + rewriter, op->getLoc(), mat.getOrigOutputType(), blockArgs); if (newMaterialization) break; @@ -3134,6 +3137,17 @@ Value TypeConverter::materializeConversion( return nullptr; } +Value TypeConverter::materializeArgumentConversion( + ArrayRef<ArgumentMaterializationCallbackFn> materializations, + OpBuilder &builder, Location loc, Type resultType, + Block::BlockArgListType inputs) const { + for (const ArgumentMaterializationCallbackFn &fn : + llvm::reverse(materializations)) + if (std::optional<Value> result = fn(builder, resultType, inputs, loc)) + return *result; + return nullptr; +} + std::optional<TypeConverter::SignatureConversion> TypeConverter::convertBlockSignature(Block *block) const { SignatureConversion conversion(block->getNumArguments()); diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp index f6e8e9e..6512d89 100644 --- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp +++ b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp @@ -394,9 +394,10 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter, assert(castKind == getCastKindName(CastKind::Argument) && "unexpected value of cast kind attribute"); assert(llvm::all_of(operands, llvm::IsaPred<BlockArgument>)); + SmallVector<BlockArgument> blockArgs = llvm::map_to_vector( + operands, [](Value v) { return cast<BlockArgument>(v); }); maybeResult = typeConverter.materializeArgumentConversion( - rewriter, castOp->getLoc(), resultTypes.front(), - castOp.getOperands()); + rewriter, castOp->getLoc(), resultTypes.front(), blockArgs); } if (!maybeResult.has_value() || !maybeResult.value()) { emitError(castOp->getLoc()) diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp index 0d7dce2..a1f0758 100644 --- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp +++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp @@ -124,6 +124,7 @@ struct TestDecomposeCallGraphTypes return success(); }); typeConverter.addArgumentMaterialization(buildMakeTupleOp); + typeConverter.addSourceMaterialization(buildMakeTupleOp); ValueDecomposer decomposer; decomposer.addDecomposeValueConversion(buildDecomposeTuple); |