diff options
author | Ehsan Toosi <ehsan.nadjaran_toosi@dfki.de> | 2020-07-31 15:20:37 +0200 |
---|---|---|
committer | Ehsan Toosi <ehsan.nadjaran_toosi@dfki.de> | 2020-09-02 13:26:55 +0200 |
commit | 94f5d248772ba0f1f9c8b0746fe75a5d246c5540 (patch) | |
tree | c8db17acf62a82953b5aa7d755079bae518f9b7b /mlir/test/lib | |
parent | f72121254da48bf668c35918b53c96cf8c568342 (diff) | |
download | llvm-94f5d248772ba0f1f9c8b0746fe75a5d246c5540.zip llvm-94f5d248772ba0f1f9c8b0746fe75a5d246c5540.tar.gz llvm-94f5d248772ba0f1f9c8b0746fe75a5d246c5540.tar.bz2 |
[mlir] Extend BufferAssignmentTypeConverter with result conversion callbacks
In this PR, the users of BufferPlacement can configure
BufferAssginmentTypeConverter. These new configurations would give the user more
freedom in the process of converting function signature, and return and call
operation conversions.
These are the new features:
- Accepting callback functions for decomposing types (i.e. 1 to N type
conversion such as unpacking tuple types).
- Defining ResultConversionKind for specifying whether a function result
with a certain type should be appended to the function arguments list or
should be kept as function result. (Usage:
converter.setResultConversionKind<MemRefType>(AppendToArgumentList))
- Accepting callback functions for composing or decomposing values (i.e. N
to 1 and 1 to N value conversion).
Differential Revision: https://reviews.llvm.org/D85133
Diffstat (limited to 'mlir/test/lib')
-rw-r--r-- | mlir/test/lib/Dialect/Test/TestOps.td | 29 | ||||
-rw-r--r-- | mlir/test/lib/Transforms/TestBufferPlacement.cpp | 48 |
2 files changed, 73 insertions, 4 deletions
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index bc26a86..f03c953 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1669,7 +1669,7 @@ def TableGenBuildOp5 : TEST_Op<"tblgen_build_5", let results = (outs AnyType:$result); let extraClassDeclaration = [{ - static LogicalResult inferReturnTypes(MLIRContext *, + static LogicalResult inferReturnTypes(MLIRContext *, Optional<Location> location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) { @@ -1679,4 +1679,31 @@ def TableGenBuildOp5 : TEST_Op<"tblgen_build_5", }]; } +//===----------------------------------------------------------------------===// +// Test BufferPlacement +//===----------------------------------------------------------------------===// + +def GetTupleElementOp: TEST_Op<"get_tuple_element"> { + let description = [{ + Test op that returns a specified element of the tuple. + }]; + + let arguments = (ins + TupleOf<[AnyType]>, + I32Attr:$index + ); + let results = (outs AnyType); +} + +def MakeTupleOp: TEST_Op<"make_tuple"> { + let description = [{ + Test op that creates a tuple value from a list of values. + }]; + + let arguments = (ins + Variadic<AnyType>:$inputs + ); + let results = (outs TupleOf<[AnyType]>); +} + #endif // TEST_OPS diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp index 6cc0924..14b72b9 100644 --- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp +++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp @@ -11,6 +11,8 @@ // //===----------------------------------------------------------------------===// +#include "TestDialect.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/IR/Function.h" #include "mlir/IR/Operation.h" @@ -109,14 +111,16 @@ struct TestBufferPlacementPreparationPass void populateTensorLinalgToBufferLinalgConversionPattern( MLIRContext *context, BufferAssignmentPlacer *placer, - TypeConverter *converter, OwningRewritePatternList *patterns) { + BufferAssignmentTypeConverter *converter, + OwningRewritePatternList *patterns) { populateWithBufferAssignmentOpConversionPatterns< - mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp, - allowMemrefFunctionResults>(context, placer, converter, patterns); + mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, placer, + converter, patterns); patterns->insert<GenericOpConverter>(context, placer, converter); } void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<TestDialect>(); registry.insert<linalg::LinalgDialect>(); } @@ -127,6 +131,8 @@ struct TestBufferPlacementPreparationPass // Mark all Standard operations legal. target.addLegalDialect<StandardOpsDialect>(); + target.addLegalOp<MakeTupleOp>(); + target.addLegalOp<GetTupleElementOp>(); // Mark all Linalg operations illegal as long as they work on tensors. auto isLegalOperation = [&](Operation *op) { @@ -149,6 +155,42 @@ struct TestBufferPlacementPreparationPass converter.isLegal(&funcOp.getBody()); }); + auto kind = allowMemrefFunctionResults + ? BufferAssignmentTypeConverter::KeepAsFunctionResult + : BufferAssignmentTypeConverter::AppendToArgumentsList; + converter.setResultConversionKind<RankedTensorType, MemRefType>(kind); + converter.setResultConversionKind<UnrankedTensorType, UnrankedMemRefType>( + kind); + + converter.addDecomposeTypeConversion( + [](TupleType tupleType, SmallVectorImpl<Type> &types) { + tupleType.getFlattenedTypes(types); + return success(); + }); + + converter.addArgumentMaterialization( + [](OpBuilder &builder, TupleType resultType, ValueRange inputs, + Location loc) -> Optional<Value> { + if (inputs.size() == 1) + return llvm::None; + TypeRange TypeRange = inputs.getTypes(); + SmallVector<Type, 2> types(TypeRange.begin(), TypeRange.end()); + TupleType tuple = TupleType::get(types, builder.getContext()); + mlir::Value value = builder.create<MakeTupleOp>(loc, tuple, inputs); + return value; + }); + + converter.addDecomposeValueConversion([](OpBuilder &builder, Location loc, + TupleType resultType, Value value, + SmallVectorImpl<Value> &values) { + for (unsigned i = 0, e = resultType.size(); i < e; ++i) { + Value res = builder.create<GetTupleElementOp>( + loc, resultType.getType(i), value, builder.getI32IntegerAttr(i)); + values.push_back(res); + } + return success(); + }); + // Walk over all the functions to apply buffer assignment. this->getOperation().walk([&](FuncOp function) -> WalkResult { OwningRewritePatternList patterns; |