aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/lib
diff options
context:
space:
mode:
authorEhsan Toosi <ehsan.nadjaran_toosi@dfki.de>2020-07-31 15:20:37 +0200
committerEhsan Toosi <ehsan.nadjaran_toosi@dfki.de>2020-09-02 13:26:55 +0200
commit94f5d248772ba0f1f9c8b0746fe75a5d246c5540 (patch)
treec8db17acf62a82953b5aa7d755079bae518f9b7b /mlir/test/lib
parentf72121254da48bf668c35918b53c96cf8c568342 (diff)
downloadllvm-94f5d248772ba0f1f9c8b0746fe75a5d246c5540.zip
llvm-94f5d248772ba0f1f9c8b0746fe75a5d246c5540.tar.gz
llvm-94f5d248772ba0f1f9c8b0746fe75a5d246c5540.tar.bz2
[mlir] Extend BufferAssignmentTypeConverter with result conversion callbacks
In this PR, the users of BufferPlacement can configure BufferAssginmentTypeConverter. These new configurations would give the user more freedom in the process of converting function signature, and return and call operation conversions. These are the new features: - Accepting callback functions for decomposing types (i.e. 1 to N type conversion such as unpacking tuple types). - Defining ResultConversionKind for specifying whether a function result with a certain type should be appended to the function arguments list or should be kept as function result. (Usage: converter.setResultConversionKind<MemRefType>(AppendToArgumentList)) - Accepting callback functions for composing or decomposing values (i.e. N to 1 and 1 to N value conversion). Differential Revision: https://reviews.llvm.org/D85133
Diffstat (limited to 'mlir/test/lib')
-rw-r--r--mlir/test/lib/Dialect/Test/TestOps.td29
-rw-r--r--mlir/test/lib/Transforms/TestBufferPlacement.cpp48
2 files changed, 73 insertions, 4 deletions
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index bc26a86..f03c953 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1669,7 +1669,7 @@ def TableGenBuildOp5 : TEST_Op<"tblgen_build_5",
let results = (outs AnyType:$result);
let extraClassDeclaration = [{
- static LogicalResult inferReturnTypes(MLIRContext *,
+ static LogicalResult inferReturnTypes(MLIRContext *,
Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
@@ -1679,4 +1679,31 @@ def TableGenBuildOp5 : TEST_Op<"tblgen_build_5",
}];
}
+//===----------------------------------------------------------------------===//
+// Test BufferPlacement
+//===----------------------------------------------------------------------===//
+
+def GetTupleElementOp: TEST_Op<"get_tuple_element"> {
+ let description = [{
+ Test op that returns a specified element of the tuple.
+ }];
+
+ let arguments = (ins
+ TupleOf<[AnyType]>,
+ I32Attr:$index
+ );
+ let results = (outs AnyType);
+}
+
+def MakeTupleOp: TEST_Op<"make_tuple"> {
+ let description = [{
+ Test op that creates a tuple value from a list of values.
+ }];
+
+ let arguments = (ins
+ Variadic<AnyType>:$inputs
+ );
+ let results = (outs TupleOf<[AnyType]>);
+}
+
#endif // TEST_OPS
diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
index 6cc0924..14b72b9 100644
--- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp
+++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
@@ -11,6 +11,8 @@
//
//===----------------------------------------------------------------------===//
+#include "TestDialect.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Operation.h"
@@ -109,14 +111,16 @@ struct TestBufferPlacementPreparationPass
void populateTensorLinalgToBufferLinalgConversionPattern(
MLIRContext *context, BufferAssignmentPlacer *placer,
- TypeConverter *converter, OwningRewritePatternList *patterns) {
+ BufferAssignmentTypeConverter *converter,
+ OwningRewritePatternList *patterns) {
populateWithBufferAssignmentOpConversionPatterns<
- mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp,
- allowMemrefFunctionResults>(context, placer, converter, patterns);
+ mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, placer,
+ converter, patterns);
patterns->insert<GenericOpConverter>(context, placer, converter);
}
void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<TestDialect>();
registry.insert<linalg::LinalgDialect>();
}
@@ -127,6 +131,8 @@ struct TestBufferPlacementPreparationPass
// Mark all Standard operations legal.
target.addLegalDialect<StandardOpsDialect>();
+ target.addLegalOp<MakeTupleOp>();
+ target.addLegalOp<GetTupleElementOp>();
// Mark all Linalg operations illegal as long as they work on tensors.
auto isLegalOperation = [&](Operation *op) {
@@ -149,6 +155,42 @@ struct TestBufferPlacementPreparationPass
converter.isLegal(&funcOp.getBody());
});
+ auto kind = allowMemrefFunctionResults
+ ? BufferAssignmentTypeConverter::KeepAsFunctionResult
+ : BufferAssignmentTypeConverter::AppendToArgumentsList;
+ converter.setResultConversionKind<RankedTensorType, MemRefType>(kind);
+ converter.setResultConversionKind<UnrankedTensorType, UnrankedMemRefType>(
+ kind);
+
+ converter.addDecomposeTypeConversion(
+ [](TupleType tupleType, SmallVectorImpl<Type> &types) {
+ tupleType.getFlattenedTypes(types);
+ return success();
+ });
+
+ converter.addArgumentMaterialization(
+ [](OpBuilder &builder, TupleType resultType, ValueRange inputs,
+ Location loc) -> Optional<Value> {
+ if (inputs.size() == 1)
+ return llvm::None;
+ TypeRange TypeRange = inputs.getTypes();
+ SmallVector<Type, 2> types(TypeRange.begin(), TypeRange.end());
+ TupleType tuple = TupleType::get(types, builder.getContext());
+ mlir::Value value = builder.create<MakeTupleOp>(loc, tuple, inputs);
+ return value;
+ });
+
+ converter.addDecomposeValueConversion([](OpBuilder &builder, Location loc,
+ TupleType resultType, Value value,
+ SmallVectorImpl<Value> &values) {
+ for (unsigned i = 0, e = resultType.size(); i < e; ++i) {
+ Value res = builder.create<GetTupleElementOp>(
+ loc, resultType.getType(i), value, builder.getI32IntegerAttr(i));
+ values.push_back(res);
+ }
+ return success();
+ });
+
// Walk over all the functions to apply buffer assignment.
this->getOperation().walk([&](FuncOp function) -> WalkResult {
OwningRewritePatternList patterns;