author    | Nicolas Vasilache <ntv@google.com> | 2020-01-21 19:37:18 -0500
committer | Nicolas Vasilache <ntv@google.com> | 2020-01-21 19:37:54 -0500
commit    | 89e19e8eddd6dd0dc38d595b6784fb9ce65d9972 (patch)
tree      | 1f50a15ff7fea553e318fd88147410467a3c20f9
parent    | e03ead6771fc97b11cb0c94b7f023142184ad25f (diff)
[mlir][Linalg] Add tensor support to Linalg EDSC Builders
Summary:
This diff extends the Linalg EDSC builders so we can easily create mixed
tensor/buffer linalg.generic ops. This is expected to be useful for
HLO -> Linalg lowering.
The `StructuredIndexed` struct is made to derive from `ValueHandle` and can
now capture a type + indexing expressions. This is used to represent return
tensors.
Pointwise unary and binary builders are extended to allow both output buffers
and return tensors. This has implications for the number of region arguments.
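
For illustration, here is a condensed sketch of the new usage, lifted from the
`linalg_matmul_mixed_tensors` test added in this diff (setup abbreviated; see
mlir/test/EDSC/builder-api-test.cpp below for the full test):

  // Build a matmul whose outputs are tensors; passing `tensorType` as
  // `resultTensorTypes` makes the linalg.generic return a tensor instead of
  // writing into an output buffer.
  auto f32Type = FloatType::getF32(&globalContext());
  auto memrefType = MemRefType::get({-1, -1}, f32Type, {}, 0);
  auto tensorType = RankedTensorType::get({-1, -1}, f32Type);
  auto f = makeFunction("linalg_matmul_mixed_tensors", {},
                        {tensorType, memrefType, tensorType});

  OpBuilder builder(f.getBody());
  ScopedContext scope(builder, f.getLoc());
  linalg_matmul(makeValueHandles(llvm::to_vector<3>(f.getArguments())),
                tensorType);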
Reviewers: ftynse, herhut, hanchung, asaadaldien, stellaraccident
Reviewed By: asaadaldien
Subscribers: merge_guards_bot, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D72863
-rw-r--r-- | mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h | 134
-rw-r--r-- | mlir/lib/Dialect/Linalg/EDSC/Builders.cpp        | 132
-rw-r--r-- | mlir/test/EDSC/builder-api-test.cpp              |  44
3 files changed, 246 insertions, 64 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
index fa81350..4ee30c0 100644
--- a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
+++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
@@ -110,11 +110,14 @@ struct StructuredIndexed {
   operator Value() const /* implicit */ { return value; }
   ArrayRef<AffineExpr> getExprs() { return exprs; }
+  Type getType() { return value.getType(); }
 
 private:
   StructuredIndexed(Value v, ArrayRef<AffineExpr> indexings)
       : value(v), exprs(indexings.begin(), indexings.end()) {
-    assert(v.getType().isa<MemRefType>() && "MemRefType expected");
+    assert((v.getType().isa<MemRefType>() ||
+            v.getType().isa<RankedTensorType>()) &&
+           "MemRef or RankedTensor expected");
   }
   StructuredIndexed(ValueHandle v, ArrayRef<AffineExpr> indexings)
       : StructuredIndexed(v.getValue(), indexings) {}
@@ -125,9 +128,21 @@ private:
 
 inline void defaultRegionBuilder(ArrayRef<BlockArgument> args) {}
 
+/// Build a `linalg.generic` op with the specified inputs, outputs and region.
+///
+/// `otherValues` and `otherAttributes` may be passed and will be appended as
+/// operands and attributes respectively.
+///
+/// This accepts both buffers and tensors as `inputs` but only buffers as
+/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in
+/// which case, the canonical identity indexing_map is assumed.
+//
+// TODO(ntv) In the future we may want to relax this identity assumption (e.g.
+// for automatic differentiation purposes). In that case we will want to make
+// StructuredIndexed work with ValueHandle to encode type or value.
 Operation *makeGenericLinalgOp(
     ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
-    ArrayRef<StructuredIndexed> outputs,
+    ArrayRef<StructuredIndexed> outputs, ArrayRef<Type> resultTensorTypes = {},
     function_ref<void(ArrayRef<BlockArgument>)> regionBuilder =
         defaultRegionBuilder,
     ArrayRef<Value> otherValues = {}, ArrayRef<Attribute> otherAttributes = {});
@@ -167,32 +182,77 @@ void macRegionBuilder(ArrayRef<BlockArgument> args);
 /// with in-place semantics and parallelism.
 
 /// Unary pointwise operation (with broadcast) entry point.
+///
+/// This accepts both buffers and tensors as `inputs` but only buffers as
+/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in
+/// which case, the canonical identity indexing_map is assumed.
+//
+// TODO(ntv) In the future we may want to relax this identity assumption (e.g.
+// for automatic differentiation purposes). In that case we will want to make
+// StructuredIndexed work with ValueHandle to encode type or value.
 using UnaryPointwiseOpBuilder = function_ref<Value(ValueHandle)>;
 Operation *linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
-                            StructuredIndexed I, StructuredIndexed O);
+                            StructuredIndexed I, StructuredIndexed O,
+                            ArrayRef<Type> resultTensorTypes = {});
 
 /// Build a linalg.pointwise with all `parallel` iterators and a region that
 /// computes `O = tanh(I)`. The client is responsible for specifying the proper
 /// indexings when creating the StructuredIndexed.
-Operation *linalg_pointwise_tanh(StructuredIndexed I, StructuredIndexed O);
+///
+/// This accepts both buffers and tensors as `inputs` but only buffers as
+/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in
+/// which case, the canonical identity indexing_map is assumed.
+//
+// TODO(ntv) In the future we may want to relax this identity assumption (e.g.
+// for automatic differentiation purposes). In that case we will want to make
+// StructuredIndexed work with ValueHandle to encode type or value.
+Operation *linalg_pointwise_tanh(StructuredIndexed I, StructuredIndexed O,
+                                 ArrayRef<Type> resultTensorTypes = {});
 
 /// Binary pointwise operation (with broadcast) entry point.
+///
+/// This accepts both buffers and tensors as `inputs` but only buffers as
+/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in
+/// which case, the canonical identity indexing_map is assumed.
+//
+// TODO(ntv) In the future we may want to relax this identity assumption (e.g.
+// for automatic differentiation purposes). In that case we will want to make
+// StructuredIndexed work with ValueHandle to encode type or value.
 using BinaryPointwiseOpBuilder = function_ref<Value(ValueHandle, ValueHandle)>;
 Operation *linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
                             StructuredIndexed I1, StructuredIndexed I2,
-                            StructuredIndexed O);
+                            StructuredIndexed O,
+                            ArrayRef<Type> resultTensorTypes = {});
 
 /// Build a linalg.pointwise with all `parallel` iterators and a region that
 /// computes `O = I1 + I2`. The client is responsible for specifying the proper
 /// indexings when creating the StructuredIndexed.
+///
+/// This accepts both buffers and tensors as `inputs` but only buffers as
+/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in
+/// which case, the canonical identity indexing_map is assumed.
+//
+// TODO(ntv) In the future we may want to relax this identity assumption (e.g.
+// for automatic differentiation purposes). In that case we will want to make
+// StructuredIndexed work with ValueHandle to encode type or value.
 Operation *linalg_pointwise_add(StructuredIndexed I1, StructuredIndexed I2,
-                                StructuredIndexed O);
+                                StructuredIndexed O,
+                                ArrayRef<Type> resultTensorTypes = {});
 
 /// Build a linalg.pointwise with all `parallel` iterators and a region that
 /// computes `O = max(I!, I2)`. The client is responsible for specifying the
 /// proper indexings when creating the StructuredIndexed.
+///
+/// This accepts both buffers and tensors as `inputs` but only buffers as
+/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in
+/// which case, the canonical identity indexing_map is assumed.
+//
+// TODO(ntv) In the future we may want to relax this identity assumption (e.g.
+// for automatic differentiation purposes). In that case we will want to make
+// StructuredIndexed work with ValueHandle to encode type or value.
 Operation *linalg_pointwise_max(StructuredIndexed I1, StructuredIndexed I2,
-                                StructuredIndexed O);
+                                StructuredIndexed O,
+                                ArrayRef<Type> resultTensorTypes = {});
 
 // TODO(ntv): Implement more useful pointwise operations on a per-need basis.
 
@@ -203,11 +263,23 @@ Operation *linalg_pointwise_max(StructuredIndexed I1, StructuredIndexed I2,
 ///    |
 ///    |  C(m, n) += A(m, k) * B(k, n)
 /// ```
-Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC);
+///
+/// This accepts both buffers and tensors as `inputs` but only buffers as
+/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in
+/// which case, the canonical identity indexing_map is assumed.
+//
+// TODO(ntv) In the future we may want to relax this identity assumption (e.g.
+// for automatic differentiation purposes). In that case we will want to make
+// StructuredIndexed work with ValueHandle to encode type or value.
+Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC,
+                         ArrayRef<Type> resultTensorTypes = {});
 
-template <typename Container> Operation *linalg_matmul(Container values) {
+template <typename Container>
+Operation *linalg_matmul(Container values,
+                         ArrayRef<Type> resultTensorTypes = {}) {
   assert(values.size() == 3 && "Expected exactly 3 values");
-  return linalg_matmul(values[0], values[1], values[2]);
+  assert(resultTensorTypes.size() <= 1 && "Expected at most 1 result tensor");
+  return linalg_matmul(values[0], values[1], values[2], resultTensorTypes);
 }
 
 /// Build a linalg.generic, under the current ScopedContext, at the current
@@ -231,16 +303,28 @@ template <typename Container> Operation *linalg_matmul(Container values) {
 ///
 /// For now `...` must be empty (i.e. only 2-D convolutions are supported).
 ///
+/// This accepts both buffers and tensors as `inputs` but only buffers as
+/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in
+/// which case, the canonical identity indexing_map is assumed.
+//
+// TODO(ntv) In the future we may want to relax this identity assumption (e.g.
+// for automatic differentiation purposes). In that case we will want to make
+// StructuredIndexed work with ValueHandle to encode type or value.
+//
 // TODO(ntv) Extend convolution rank with some template magic.
 Operation *linalg_conv_nhwc(ValueHandle vI, ValueHandle vW, ValueHandle vO,
+                            ArrayRef<Type> resultTensorTypes = {},
                             ArrayRef<int> strides = {},
                             ArrayRef<int> dilations = {});
 
 template <typename Container>
-Operation *linalg_conv_nhwc(Container values, ArrayRef<int> strides = {},
-                            ArrayRef<int> dilations = {}) {
+Operation *
+linalg_conv_nhwc(Container values, ArrayRef<Type> resultTensorTypes = {},
+                 ArrayRef<int> strides = {}, ArrayRef<int> dilations = {}) {
   assert(values.size() == 3 && "Expected exactly 3 values");
-  return linalg_conv_nhwc(values[0], values[1], values[2], strides, dilations);
+  assert(resultTensorTypes.size() <= 1 && "Expected at most 1 result tensor");
+  return linalg_conv_nhwc(values[0], values[1], values[2], resultTensorTypes,
+                          strides, dilations);
 }
 
 /// Build a linalg.generic, under the current ScopedContext, at the current
@@ -249,7 +333,7 @@ Operation *linalg_conv_nhwc(Container values, ArrayRef<int> strides = {},
 /// (batch, dm, c, [h, w, ...], [kh, kw, ...]) =
 ///   |  (par, par, par, [par, par, ...], [red, red, ...])
 ///   |
-///   | O(batch, [h, w, ...], c * depth_multiplier) +=
+///   | O(batch, [h, w, ...], c * depthMultiplier) +=
 ///   |   I(batch,
 ///   |     [
 ///   |       stride[0] * h + dilations[0] * kh,
@@ -257,26 +341,40 @@ Operation *linalg_conv_nhwc(Container values, ArrayRef<int> strides = {},
 ///   |     ],
 ///   |     c)
 ///   |   *
-///   |   W([kh, kw, ...], c, depth_multiplier)
+///   |   W([kh, kw, ...], c, depthMultiplier)
 /// ```
 /// If `dilations` or `strides` are left empty, the default value of `1` is used
 /// along each relevant dimension.
 ///
 /// For now `...` must be empty (i.e. only 2-D convolutions are supported).
 ///
+/// This accepts both buffers and tensors as `inputs` but only buffers as
+/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in
+/// which case, the canonical identity indexing_map is assumed.
+//
+// TODO(ntv) In the future we may want to relax this identity assumption (e.g.
+// for automatic differentiation purposes). In that case we will want to make
+// StructuredIndexed work with ValueHandle to encode type or value.
+//
 // TODO(ntv) Extend convolution rank with some template magic.
 Operation *linalg_dilated_conv_nhwc(ValueHandle vI, ValueHandle vW,
-                                    ValueHandle vO, int depth_multiplier = 1,
+                                    ValueHandle vO,
+                                    ArrayRef<Type> resultTensorTypes = {},
+                                    int depthMultiplier = 1,
                                     ArrayRef<int> strides = {},
                                     ArrayRef<int> dilations = {});
 
 template <typename Container>
-Operation *linalg_dilated_conv_nhwc(Container values, int depth_multiplier,
+Operation *linalg_dilated_conv_nhwc(Container values,
+                                    ArrayRef<Type> resultTensorTypes = {},
+                                    int depthMultiplier = 1,
                                     ArrayRef<int> strides = {},
                                     ArrayRef<int> dilations = {}) {
   assert(values.size() == 3 && "Expected exactly 3 values");
+  assert(resultTensorTypes.size() <= 1 && "Expected at most 1 result tensor");
   return linalg_dilated_conv_nhwc(values[0], values[1], values[2],
-                                  depth_multiplier, strides, dilations);
+                                  resultTensorTypes, depthMultiplier, strides,
+                                  dilations);
 }
 
 } // namespace ops
diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
index 0940f56..395a409 100644
--- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
@@ -128,16 +128,20 @@ static void getMaxDimIndex(ArrayRef<StructuredIndexed> structuredIndices,
 
 Operation *mlir::edsc::makeGenericLinalgOp(
     ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
-    ArrayRef<StructuredIndexed> outputs,
+    ArrayRef<StructuredIndexed> outputBuffers, ArrayRef<Type> resultTensorTypes,
     function_ref<void(ArrayRef<BlockArgument>)> regionBuilder,
     ArrayRef<Value> otherValues, ArrayRef<Attribute> otherAttributes) {
+  assert(
+      llvm::all_of(llvm::make_range(outputBuffers.begin(), outputBuffers.end()),
+                   [](Value v) { return v.getType().isa<MemRefType>(); }) &&
+      "output operands must all be buffers.");
   auto &builder = edsc::ScopedContext::getBuilder();
   auto *ctx = builder.getContext();
   unsigned nInputs = inputs.size();
-  unsigned nOutputs = outputs.size();
+  unsigned nOutputs = outputBuffers.size() + resultTensorTypes.size();
   unsigned maxPos = 0;
   getMaxDimIndex(inputs, maxPos);
-  getMaxDimIndex(outputs, maxPos);
+  getMaxDimIndex(outputBuffers, maxPos);
   // maxPos is 0 indexed, need to turn this into a count (i.e. +1)
   unsigned nDims = maxPos + 1;
@@ -146,7 +150,7 @@ Operation *mlir::edsc::makeGenericLinalgOp(
   for (auto in : inputs)
     maps.push_back(
         AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, in.getExprs()));
-  for (auto out : outputs)
+  for (auto out : outputBuffers)
     maps.push_back(
         AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, out.getExprs()));
 
@@ -154,7 +158,7 @@ Operation *mlir::edsc::makeGenericLinalgOp(
   SmallVector<Value, 4> values;
   values.reserve(nViews);
   values.append(inputs.begin(), inputs.end());
-  values.append(outputs.begin(), outputs.end());
+  values.append(outputBuffers.begin(), outputBuffers.end());
 
   auto iteratorStrTypes = functional::map(toString, iteratorTypes);
   // clang-format off
@@ -162,7 +166,7 @@ Operation *mlir::edsc::makeGenericLinalgOp(
       edsc::ScopedContext::getBuilder()
           .create<linalg::GenericOp>(
               edsc::ScopedContext::getLocation(),
-              ArrayRef<Type>{}, // TODO(ntv): support tensors
+              resultTensorTypes,
              values,
              IntegerAttr::get(IntegerType::get(64, ctx), nInputs),
              IntegerAttr::get(IntegerType::get(64, ctx), nOutputs),
@@ -207,7 +211,8 @@ void mlir::edsc::ops::macRegionBuilder(ArrayRef<BlockArgument> args) {
 
 Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
                                              StructuredIndexed I,
-                                             StructuredIndexed O) {
+                                             StructuredIndexed O,
+                                             ArrayRef<Type> resultTensorTypes) {
   SmallVector<edsc::IterType, 4> iterTypes(O.getExprs().size(),
                                            edsc::IterType::Parallel);
   auto fun = [&unaryOp](ArrayRef<BlockArgument> args) {
@@ -215,22 +220,30 @@ Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
     ValueHandle a(args[0]);
     linalg_yield(unaryOp(a));
   };
-  return makeGenericLinalgOp(iterTypes, {I}, {O}, fun);
+
+  // Distinguish between tensor and buffer semantics.
+  if (O.getType().isa<MemRefType>()) {
+    assert(resultTensorTypes.empty());
+    return makeGenericLinalgOp(iterTypes, {I}, {O}, {}, fun);
+  }
+  return makeGenericLinalgOp(iterTypes, {I, O}, {}, resultTensorTypes, fun);
 }
 
-Operation *mlir::edsc::ops::linalg_pointwise_tanh(StructuredIndexed I,
-                                                  StructuredIndexed O) {
+Operation *
+mlir::edsc::ops::linalg_pointwise_tanh(StructuredIndexed I, StructuredIndexed O,
+                                       ArrayRef<Type> resultTensorTypes) {
   ;
   using edsc::intrinsics::tanh;
   UnaryPointwiseOpBuilder unOp([](ValueHandle a) -> Value { return tanh(a); });
-  return linalg_pointwise(unOp, I, O);
+  return linalg_pointwise(unOp, I, O, resultTensorTypes);
 }
 
 /// Binary pointwise operation (with broadcast) entry point.
 Operation *mlir::edsc::ops::linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
                                              StructuredIndexed I1,
                                              StructuredIndexed I2,
-                                             StructuredIndexed O) {
+                                             StructuredIndexed O,
+                                             ArrayRef<Type> resultTensorTypes) {
   SmallVector<edsc::IterType, 4> iterTypes(O.getExprs().size(),
                                            edsc::IterType::Parallel);
   auto fun = [&binaryOp](ArrayRef<BlockArgument> args) {
@@ -238,45 +251,62 @@ Operation *mlir::edsc::ops::linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
     ValueHandle a(args[0]), b(args[1]);
     linalg_yield(binaryOp(a, b));
   };
-  return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, fun);
+  // Distinguish between tensor and buffer semantics.
+  if (O.getType().isa<MemRefType>()) {
+    assert(resultTensorTypes.empty());
+    return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, {}, fun);
+  }
+  return makeGenericLinalgOp(iterTypes, {I1, I2, O}, {}, resultTensorTypes,
+                             fun);
 }
 
-Operation *mlir::edsc::ops::linalg_pointwise_add(StructuredIndexed I1,
-                                                 StructuredIndexed I2,
-                                                 StructuredIndexed O) {
+Operation *
+mlir::edsc::ops::linalg_pointwise_add(StructuredIndexed I1,
+                                      StructuredIndexed I2, StructuredIndexed O,
+                                      ArrayRef<Type> resultTensorTypes) {
   using edsc::op::operator+;
   BinaryPointwiseOpBuilder binOp(
       [](ValueHandle a, ValueHandle b) -> Value { return a + b; });
-  return linalg_pointwise(binOp, I1, I2, O);
+  return linalg_pointwise(binOp, I1, I2, O, resultTensorTypes);
 }
 
-Operation *mlir::edsc::ops::linalg_pointwise_max(StructuredIndexed I1,
-                                                 StructuredIndexed I2,
-                                                 StructuredIndexed O) {
+Operation *
+mlir::edsc::ops::linalg_pointwise_max(StructuredIndexed I1,
+                                      StructuredIndexed I2, StructuredIndexed O,
+                                      ArrayRef<Type> resultTensorTypes) {
   BinaryPointwiseOpBuilder binOp([](ValueHandle a, ValueHandle b) -> Value {
     using edsc::intrinsics::select;
     using edsc::op::operator>;
    return select(a > b, a, b).getValue();
  });
-  return linalg_pointwise(binOp, I1, I2, O);
+  return linalg_pointwise(binOp, I1, I2, O, resultTensorTypes);
 }
 
 Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
-                                          ValueHandle vC) {
-  // clang-format off
+                                          ValueHandle vC,
+                                          ArrayRef<Type> resultTensorTypes) {
   AffineExpr m, n, k;
   bindDims(ScopedContext::getContext(), m, n, k);
   StructuredIndexed A(vA), B(vB), C(vC);
+
+  assert(!C.getType().isa<MemRefType>() || resultTensorTypes.empty());
+  StructuredIndexed allIndexed[3]{A({m, k}), B({k, n}), C({m, n})};
+  ArrayRef<StructuredIndexed> inputs =
+      (C.getType().isa<MemRefType>())
+          ? ArrayRef<StructuredIndexed>{allIndexed, allIndexed + 2}
+          : ArrayRef<StructuredIndexed>{allIndexed, allIndexed + 3};
+  ArrayRef<StructuredIndexed> outputs =
+      (C.getType().isa<MemRefType>())
+          ? ArrayRef<StructuredIndexed>{allIndexed + 2, allIndexed + 3}
+          : ArrayRef<StructuredIndexed>{};
   return makeGenericLinalgOp(
-    {IterType::Parallel, IterType::Parallel, IterType::Reduction},
-    {A({m, k}), B({k, n})},
-    {C({m, n})},
-    macRegionBuilder);
-  // clang-format on
+      {IterType::Parallel, IterType::Parallel, IterType::Reduction}, inputs,
+      outputs, resultTensorTypes, macRegionBuilder);
 }
 
 Operation *mlir::edsc::ops::linalg_conv_nhwc(ValueHandle vI, ValueHandle vW,
                                              ValueHandle vO,
+                                             ArrayRef<Type> resultTensorTypes,
                                              ArrayRef<int> strides,
                                              ArrayRef<int> dilations) {
   MLIRContext *ctx = ScopedContext::getContext();
@@ -294,23 +324,33 @@ Operation *mlir::edsc::ops::linalg_conv_nhwc(ValueHandle vI, ValueHandle vW,
   bindDims(ctx, b, f, h, w, kh, kw, c);
   unsigned numDims = c.cast<AffineDimExpr>().getPosition() + 1;
   StructuredIndexed I(vI), W(vW), O(vO);
+
+  assert(!O.getType().isa<MemRefType>() || resultTensorTypes.empty());
+  // Roundtrip to flattened form to serve as canonicalization and ensure
+  // consistent ordering of subexpressions.
   // clang-format off
-  return makeGenericLinalgOp(
-    {par, par, par, par, red, red, red}, {
+  StructuredIndexed allIndexed[3] = {
     I({b,
-      // Roundtrip to flattened form to serve as canonicalization and ensure
-      // consistent ordering of subexpressions.
       simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0),
       simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0),
       c}),
-    W({kh, kw, c, f})}, {
-    O({b, h, w, f})},
-    macRegionBuilder);
+    W({kh, kw, c, f}),
+    O({b, h, w, f})};
   // clang-format on
+  auto inputs = (O.getType().isa<MemRefType>())
+                    ? ArrayRef<StructuredIndexed>{allIndexed, allIndexed + 2}
+                    : ArrayRef<StructuredIndexed>{allIndexed, allIndexed + 3};
+  ArrayRef<StructuredIndexed> outputs =
+      (O.getType().isa<MemRefType>())
+          ? ArrayRef<StructuredIndexed>{allIndexed + 2, allIndexed + 3}
+          : ArrayRef<StructuredIndexed>{};
+  return makeGenericLinalgOp({par, par, par, par, red, red, red}, inputs,
+                             outputs, resultTensorTypes, macRegionBuilder);
 }
 
 Operation *mlir::edsc::ops::linalg_dilated_conv_nhwc(
-    ValueHandle vI, ValueHandle vW, ValueHandle vO, int depth_multiplier,
+    ValueHandle vI, ValueHandle vW, ValueHandle vO,
+    ArrayRef<Type> resultTensorTypes, int depthMultiplier,
     ArrayRef<int> strides, ArrayRef<int> dilations) {
   MLIRContext *ctx = ScopedContext::getContext();
   // TODO(ntv) some template magic to make everything rank-polymorphic.
@@ -328,16 +368,26 @@ Operation *mlir::edsc::ops::linalg_dilated_conv_nhwc(
   bindDims(ctx, b, dm, c, h, w, kh, kw);
   unsigned numDims = kw.cast<AffineDimExpr>().getPosition() + 1;
   StructuredIndexed I(vI), W(vW), O(vO);
-  return makeGenericLinalgOp(
-    {par, par, par, par, par, red, red}, {
+  // Roundtrip to flattened form to serve as canonicalization and ensure
+  // consistent ordering of subexpressions.
+  // clang-format off
+  StructuredIndexed allIndexed[3] = {
     I({b,
       // Roundtrip to flattened form to serve as canonicalization and ensure
       // consistent ordering of subexpressions.
       simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0),
       simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0),
       c}),
-    W({kh, kw, c, dm})}, {
-    O({b, h, w, simplifyAffineExpr(c * depth_multiplier + dm, numDims, 0)})},
-    macRegionBuilder);
+    W({kh, kw, c, dm}),
+    O({b, h, w, simplifyAffineExpr(c * depthMultiplier + dm, numDims, 0)})};
   // clang-format on
+  auto inputs = (O.getType().isa<MemRefType>())
+                    ? ArrayRef<StructuredIndexed>{allIndexed, allIndexed + 2}
+                    : ArrayRef<StructuredIndexed>{allIndexed, allIndexed + 3};
+  ArrayRef<StructuredIndexed> outputs =
+      (O.getType().isa<MemRefType>())
+          ? ArrayRef<StructuredIndexed>{allIndexed + 2, allIndexed + 3}
+          : ArrayRef<StructuredIndexed>{};
+  return makeGenericLinalgOp({par, par, par, par, par, red, red}, inputs,
+                             outputs, resultTensorTypes, macRegionBuilder);
 }
diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
index 5388446..f3ca261 100644
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -467,7 +467,8 @@ TEST_FUNC(zero_and_sign_extendi_op_i1_to_i8) {
   auto i1Type = IntegerType::get(1, &globalContext());
   auto i8Type = IntegerType::get(8, &globalContext());
   auto memrefType = MemRefType::get({}, i1Type, {}, 0);
-  auto f = makeFunction("zero_and_sign_extendi_op", {}, {memrefType, memrefType});
+  auto f =
+      makeFunction("zero_and_sign_extendi_op", {}, {memrefType, memrefType});
   OpBuilder builder(f.getBody());
   ScopedContext scope(builder, f.getLoc());
@@ -795,10 +796,12 @@ TEST_FUNC(empty_map_load_store) {
 }
 
 // CHECK-LABEL: func @affine_if_op
-// CHECK: affine.if affine_set<([[d0:.*]], [[d1:.*]]){{\[}}[[s0:.*]], [[s1:.*]]{{\]}}
+// CHECK: affine.if affine_set<([[d0:.*]], [[d1:.*]]){{\[}}
+// CHECK-SAME: [[s0:.*]], [[s1:.*]]{{\]}}
 // CHECK-NOT: else
-// CHECK: affine.if affine_set<([[d0:.*]], [[d1:.*]]){{\[}}[[s0:.*]], [[s1:.*]]{{\]}}
-// CHECK-NEXT: } else {
+// CHECK: affine.if affine_set<([[d0:.*]], [[d1:.*]]){{\[}}
+// CHECK-SAME: [[s0:.*]], [[s1:.*]]{{\]}}
+// CHECK-NEXT: } else {
 TEST_FUNC(affine_if_op) {
   using namespace edsc;
   using namespace edsc::intrinsics;
@@ -900,6 +903,36 @@ TEST_FUNC(linalg_matmul_test) {
 }
 
 // clang-format off
+// CHECK-LABEL: func @linalg_matmul_mixed_tensors
+// CHECK: linalg.generic {args_in = 3 : i64, args_out = 1 : i64,
+// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
+/// CHECK: ^bb0(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32):
+// CHECK:   %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32
+// CHECK:   %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32
+// CHECK:   linalg.yield %[[a4]] : f32
+// CHECK: }: tensor<?x?xf32>, memref<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+// clang-format on
+TEST_FUNC(linalg_matmul_mixed_tensors_test) {
+  using namespace edsc;
+  using namespace edsc::ops;
+
+  auto f32Type = FloatType::getF32(&globalContext());
+  auto memrefType = MemRefType::get({-1, -1}, f32Type, {}, 0);
+  auto tensorType = RankedTensorType::get({-1, -1}, f32Type);
+  auto f = makeFunction("linalg_matmul_mixed_tensors", {},
+                        {tensorType, memrefType, tensorType});
+
+  OpBuilder builder(f.getBody());
+  ScopedContext scope(builder, f.getLoc());
+  linalg_matmul(makeValueHandles(llvm::to_vector<3>(f.getArguments())),
+                tensorType);
+
+  f.print(llvm::outs());
+  f.erase();
+}
+
+// clang-format off
 // CHECK-LABEL: func @linalg_conv_nhwc
 // CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
 // CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2 * 3 + d4 * 5, d3 * 4 + d5 * 6, d6)>,
@@ -923,7 +956,7 @@ TEST_FUNC(linalg_conv_nhwc) {
 
   OpBuilder builder(f.getBody());
   ScopedContext scope(builder, f.getLoc());
-  linalg_conv_nhwc(makeValueHandles(llvm::to_vector<3>(f.getArguments())),
+  linalg_conv_nhwc(makeValueHandles(llvm::to_vector<3>(f.getArguments())), {},
                    /*strides=*/{3, 4}, /*dilations=*/{5, 6});
 
   f.print(llvm::outs());
@@ -956,6 +989,7 @@ TEST_FUNC(linalg_dilated_conv_nhwc) {
   ScopedContext scope(builder, f.getLoc());
   linalg_dilated_conv_nhwc(
       makeValueHandles(llvm::to_vector<3>(f.getArguments())),
+      /*outputTensorTypes=*/{},
       /*depth_multiplier=*/7, /*strides=*/{3, 4}, /*dilations=*/{5, 6});
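
Beyond the matmul test above, here is a hypothetical sketch of driving one of the
extended pointwise entry points with a tensor result. The values `vA`, `vB`, `vC`
and the surrounding function setup are assumed for illustration and are not part
of this commit:

  // Assumed setup: vA, vB, vC are ValueHandles over tensor<?x?xf32> values in
  // scope (e.g. block arguments), mirroring the matmul test above.
  auto f32Type = FloatType::getF32(&globalContext());
  AffineExpr i, j;
  bindDims(&globalContext(), i, j);
  StructuredIndexed SA(vA), SB(vB), SC(vC);
  Type resultType = RankedTensorType::get({-1, -1}, f32Type);
  // Because SC has tensor type, the builder emits the generic op in tensor
  // form: SC is passed as an extra input carrying the output indexing and
  // `resultType` becomes the op's result type.
  linalg_pointwise_add(SA({i, j}), SB({i, j}), SC({i, j}), resultType);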