Diffstat (limited to 'mlir')
-rw-r--r--  mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h  29
-rw-r--r--  mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td  243
-rw-r--r--  mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h  2
-rw-r--r--  mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h  8
-rw-r--r--  mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp  2
-rw-r--r--  mlir/lib/Dialect/Linalg/EDSC/Builders.cpp  143
-rw-r--r--  mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp  336
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp  56
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp  59
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp  53
-rw-r--r--  mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir  32
-rw-r--r--  mlir/test/Dialect/Linalg/canonicalize.mlir  6
-rw-r--r--  mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir  39
-rw-r--r--  mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir  33
-rw-r--r--  mlir/test/Dialect/Linalg/fusion-tensor.mlir  198
-rw-r--r--  mlir/test/Dialect/Linalg/fusion.mlir  67
-rw-r--r--  mlir/test/Dialect/Linalg/fusion_indexed_generic.mlir  51
-rw-r--r--  mlir/test/Dialect/Linalg/inlining.mlir  8
-rw-r--r--  mlir/test/Dialect/Linalg/invalid.mlir  240
-rw-r--r--  mlir/test/Dialect/Linalg/loops.mlir  82
-rw-r--r--  mlir/test/Dialect/Linalg/parallel_loops.mlir  25
-rw-r--r--  mlir/test/Dialect/Linalg/roundtrip.mlir  118
-rw-r--r--  mlir/test/Dialect/Linalg/standard.mlir  17
-rw-r--r--  mlir/test/Dialect/Linalg/tensors-to-buffers.mlir  27
-rw-r--r--  mlir/test/Dialect/Linalg/tile.mlir  6
-rw-r--r--  mlir/test/Dialect/Linalg/tile_indexed_generic.mlir  12
-rw-r--r--  mlir/test/Dialect/Linalg/tile_parallel.mlir  17
-rw-r--r--  mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir  15
-rw-r--r--  mlir/test/Dialect/Linalg/transform-patterns.mlir  42
-rw-r--r--  mlir/test/EDSC/builder-api-test.cpp  91
-rw-r--r--  mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir  8
-rw-r--r--  mlir/test/Transforms/buffer-placement-preparation.mlir  85
-rw-r--r--  mlir/test/Transforms/buffer-placement.mlir  218
-rw-r--r--  mlir/test/Transforms/copy-removal.mlir  34
-rw-r--r--  mlir/test/lib/Transforms/TestBufferPlacement.cpp  73
-rw-r--r--  mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp  2
36 files changed, 1307 insertions, 1170 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
index 5b6cb0a..ac9ca95 100644
--- a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
+++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
@@ -30,8 +30,8 @@ class ParallelOp;
namespace edsc {
inline void defaultRegionBuilder(ValueRange args) {}
-/// Build a `linalg.generic` op with the specified `inputs`, `outputs` and
-/// `region`.
+/// Build a `linalg.generic` op with the specified `inputs`, `outputBuffers`,
+/// `initTensors`, `resultTensorTypes` and `region`.
///
/// `otherValues` and `otherAttributes` may be passed and will be appended as
/// operands and attributes respectively.
@@ -41,14 +41,15 @@ inline void defaultRegionBuilder(ValueRange args) {}
///
/// 1. `inputs` may contain StructuredIndexed that capture either buffer or
/// tensor values.
-/// 2. `outputs` may contain StructuredIndexed that capture either buffer values
-/// or tensor types. If both buffer values and tensor types are present, then
-/// all buffer values must appear before any tensor type. Without this
-/// restriction output tensor results would need to be reordered, which would
-/// result in surprising behavior when combined with region definition.
+/// 2. `outputBuffers` may contain StructuredIndexed that capture buffer
+/// values.
+/// 3. `initTensors` may contain tensor values and carry no indexing maps.
+/// 4. `resultTensorTypes` may contain StructuredIndexed that capture return
+/// tensor types.
Operation *makeGenericLinalgOp(
ArrayRef<IteratorType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
- ArrayRef<StructuredIndexed> outputs,
+ ArrayRef<StructuredIndexed> outputBuffers, ArrayRef<Value> initTensors,
+ ArrayRef<StructuredIndexed> resultTensorTypes,
function_ref<void(ValueRange)> regionBuilder = defaultRegionBuilder,
ArrayRef<Value> otherValues = {}, ArrayRef<Attribute> otherAttributes = {});
@@ -139,18 +140,6 @@ linalg_generic_matmul(Value vA, Value vB, Value vC,
/// ```
/// (m, n, k) = (par, par, seq)
/// |
-/// | C(m, n) = sum_k(A(m, k) * B(k, n))
-/// ```
-/// and returns the tensor `C`.
-Operation *
-linalg_generic_matmul(Value vA, Value vB, RankedTensorType tC,
- MatmulRegionBuilder regionBuilder = mulRegionBuilder);
-
-/// Build a linalg.generic, under the current ScopedContext, at the current
-/// insert point, that computes:
-/// ```
-/// (m, n, k) = (par, par, seq)
-/// |
/// | D(m, n) = C(m, n) + sum_k(A(m, k) * B(k, n))
/// ```
/// and returns the tensor `D`.
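For reference, a minimal sketch of calling the updated `makeGenericLinalgOp` with the four operand groups made explicit. It is not part of the patch: it mirrors the buffer-form matmul builder in `Builders.cpp` below and assumes an active EDSC `ScopedContext` with memref-typed `Value`s `vA`, `vB`, `vC` in scope.

```c++
// Sketch only: buffer-form matmul through the new explicit operand groups.
AffineExpr m, n, k;
bindDims(ScopedContext::getContext(), m, n, k);
StructuredIndexed A(vA), B(vB), C(vC);
Operation *op = makeGenericLinalgOp(
    {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
    /*inputs=*/{A({m, k}), B({k, n})},
    /*outputBuffers=*/{C({m, n})},
    /*initTensors=*/{},
    /*resultTensorTypes=*/{},
    macRegionBuilder);
```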
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 2120840..65df012 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -40,11 +40,12 @@ def NamedStructuredOpTrait : NativeOpTrait<"linalg::NamedStructuredOpTrait">;
// depending on the specific Linalg op.
class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
: Op<Linalg_Dialect, mnemonic,
- !listconcat(props, [StructuredOpTraits, LinalgStructuredInterface])> {
+ !listconcat(props, [LinalgStructuredInterface])> {
}
class LinalgStructured_Op<string mnemonic, list<OpTrait> props>
- : LinalgStructuredBase_Op<mnemonic, props> {
+ : LinalgStructuredBase_Op<mnemonic,
+ !listconcat(props, [StructuredOpTraits])> {
code libraryCallName = [{
std::string getLibraryCallName() {
return generateLibraryCallName(getOperation());
@@ -460,43 +461,38 @@ class LinalgOperandOfRank<int rank>: Type<
CPred<"$_self.cast<ShapedType>().getRank() == " # rank>]
>>;
-class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic,
- [SingleBlockImplicitTerminator<"YieldOp">]> {
- let arguments = (ins Variadic<LinalgOperand>:$views,
- I64Attr:$args_in,
- I64Attr:$args_out,
- AffineMapArrayAttr:$indexing_maps,
- ArrayAttr:$iterator_types,
- OptionalAttr<StrAttr>:$doc,
- OptionalAttr<StrAttr>:$library_call,
- Confined<OptionalAttr<I64Attr>,
- [IntMinValue<0>]>:$symbol_source);
- let results = (outs Variadic<AnyRankedTensor>:$output_tensors);
+class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, [
+ NamedStructuredOpTrait,
+ AttrSizedOperandSegments,
+ SingleBlockImplicitTerminator<"YieldOp">]> {
+ let arguments = (ins Variadic<AnyShaped>:$inputs,
+ Variadic<AnyMemRef>:$output_buffers,
+ Variadic<AnyRankedTensor>:$init_tensors,
+ AffineMapArrayAttr:$indexing_maps,
+ ArrayAttr:$iterator_types,
+ OptionalAttr<StrAttr>:$doc,
+ OptionalAttr<StrAttr>:$library_call,
+ Confined<OptionalAttr<I64Attr>, [IntMinValue<0>]>
+ :$symbol_source);
+ let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
let regions = (region AnyRegion:$region);
let extraClassDeclaration = [{
SmallVector<StringRef, 8> linalgTraitAttrNames() {
return SmallVector<StringRef, 8>{
- getArgsInAttrName(), getArgsOutAttrName(), getDocAttrName(),
+ getDocAttrName(),
getIndexingMapsAttrName(), getLibraryCallAttrName(),
getIteratorTypesAttrName(), getSymbolSourceAttrName()
};
}
-
- unsigned getNumInputs() { return args_in(); }
-
- unsigned getNumOutputs() { return args_out(); }
-
StringRef getLibraryCallName() {
return library_call().hasValue() ? library_call().getValue() : "";
}
-
llvm::Optional<unsigned> getSymbolSource() {
auto ss = symbol_source();
return ss.hasValue() ?
llvm::Optional<unsigned>(ss.getValue()) : llvm::None;
}
}];
-
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parseGenericOp(parser, result); }];
}
@@ -505,18 +501,19 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic,
def GenericOp : GenericOpBase<"generic"> {
let description = [{
Generic Linalg op form where the key properties of the computation are
- specified as attributes. In pretty form, a linalg.generic op is written as:
+ specified as attributes. In pretty form, a `linalg.generic` op is written
+ as:
```mlir
- linalg.generic #trait_attribute %A, %B, %C {other-attributes} :
- memref<?x?xf32, stride_specification>,
- memref<?x?xf32, stride_specification>,
- memref<?x?xf32, stride_specification>
+ linalg.generic #trait_attribute
+ ins(%A, %B : memref<?x?xf32, stride_specification>,
+ memref<?x?xf32, stride_specification>)
+ outs(%C : memref<?x?xf32, stride_specification>)
+ attrs = {other-optional-attributes}
+ {region}
```
Where #trait_attributes is an alias of a dictionary attribute containing:
- - args_in: an I64Attr representing the number of input (readonly) views
- - args_out: an I64Attr representing the number of output (readwrite) views
- doc [optional]: a documentation string
- indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input
and output view. Such AffineMapAttr specifies the mapping between the
@@ -547,22 +544,22 @@ def GenericOp : GenericOpBase<"generic"> {
doc = "C(m, n) += A(m, k) * B(k, n)",
indexing_maps = #matmul_accesses,
library_call = "linalg_matmul",
- args_in = 2,
- args_out = 1,
iterator_types = ["parallel", "parallel", "reduction"]
}
```
And can be reused in multiple places as:
```mlir
- linalg.generic #matmul_trait %A, %B, %C [other-attributes] {
+ linalg.generic #matmul_trait
+ ins(%A, %B : memref<?x?xf32, stride_specification>,
+ memref<?x?xf32, stride_specification>)
+ outs(%C : memref<?x?xf32, stride_specification>)
+ {other-optional-attributes} {
^bb0(%a: f32, %b: f32, %c: f32) :
%d = mulf %a, %b: f32
%e = addf %c, %d: f32
linalg.yield %e : f32
- } : memref<?x?xf32, stride_specification>,
- memref<?x?xf32, stride_specification>,
- memref<?x?xf32, stride_specification>
+ }
```
This may lower to either:
@@ -591,30 +588,29 @@ def GenericOp : GenericOpBase<"generic"> {
```
To allow progressive lowering from the value world (a.k.a tensor values) to
- the buffer world (a.k.a memref values), a `linalg.generic` op accepts
- mixing input and output ranked tensor values with input and output memrefs.
+ the buffer world (a.k.a memref values), a `linalg.generic` op allows mixing
+ tensor and buffer operands with tensor results.
```mlir
- %C = linalg.generic #trait_attribute %A, %B {other-attributes} {region} :
- tensor<?x?xf32>,
- memref<?x?xf32, stride_specification>
+ %C = linalg.generic #trait_attribute
+ ins(%A, %B : tensor<?x?xf32>, memref<?x?xf32, stride_specification>)
+ init(%C : tensor<?x?xf32>)
+ {other-optional-attributes}
+ {region}
-> (tensor<?x?xf32>)
```
- In this case, the number of outputs (args_out) must match the sum of (1) the
- number of output buffer operands and (2) the number of tensor return values.
- The semantics is that the `linalg.indexed_generic` op produces (i.e.
- allocates and fills) its tensor return values.
+ The `init` operand and the conventions around mixing tensors and buffers are
+ described in more detail in the "Tensors and Buffers: Conventions and
+ Limitations" section in the [Linalg Document](../docs/Linalg.md)
Tensor values must be legalized by a buffer allocation pass before most
- transformations can be applied. Such legalization moves tensor return values
+ transformations can be applied. Such legalizations move tensor return values
into output buffer operands and updates the region arguments accordingly.
- Transformations that create control-flow around linalg.indexed_generic
- operations are not expected to work with tensors because SSA values do not
- escape naturally. Still, transformations and rewrites that take advantage of
- tensor SSA values are expected to be useful and will be added in the near
- future.
+ The `symbol_source` attribute allows selecting a particular operand and
+ introducing symbols for each operand dimension. Such symbols can then be
+ used in the indexing maps.
Example of 1D convolution with symbols:
```mlir
@@ -632,14 +628,14 @@ def GenericOp : GenericOpBase<"generic"> {
symbol_source = 1
}
- linalg.generic #conv_1d_trait %in, %filter, %out {
+ linalg.generic #conv_1d_trait
+ ins(%in, %filter : memref<?xf32>, memref<?xf32>)
+ outs(%out : memref<?xf32>) {
^bb0(%a: f32, %b: f32, %c: f32) :
%d = mulf %a, %b : f32
%e = addf %c, %d : f32
linalg.yield %e : f32
- } : memref<?xf32>,
- memref<?xf32>,
- memref<?xf32>
+ }
```
where symbol s0 will be substituted with `dim %filter, %c0` i.e. the first
and only dimension of the second operand as specified by the symbol_source
@@ -648,12 +644,28 @@ def GenericOp : GenericOpBase<"generic"> {
let builders = [
OpBuilder<
- "OpBuilder &builder, OperationState &result, ArrayRef<Type> resultTypes, "
- "ValueRange args, int64_t argsIn, int64_t argsOut, "
+ "OpBuilder &builder, OperationState &result, ArrayRef<Type> resultTensorTypes,"
+ "ValueRange inputs, ValueRange outputBuffers, ValueRange initTensors, "
"ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes, "
+ "StringRef doc, StringRef libraryCall, IntegerAttr symbolSource, "
+ "function_ref<void(OpBuilder &, Location, ValueRange)> = nullptr">,
+ OpBuilder<
+ "OpBuilder &builder, OperationState &result, "
+ "ValueRange inputs, ValueRange outputBuffers, "
+ "ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes, "
+ "StringRef doc, StringRef libraryCall, IntegerAttr symbolSource, "
+ "function_ref<void(OpBuilder &, Location, ValueRange)> = nullptr">,
+ OpBuilder<
+ "OpBuilder &builder, OperationState &result, ArrayRef<Type> resultTensorTypes,"
+ "ValueRange inputs, ValueRange outputBuffers, ValueRange initTensors, "
+ "ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes, "
+ "function_ref<void(OpBuilder &, Location, ValueRange)> = nullptr">,
+ OpBuilder<
+ "OpBuilder &builder, OperationState &result, ValueRange inputs, "
+ "ValueRange outputBuffers, ArrayRef<AffineMap> indexingMaps, "
+ "ArrayRef<StringRef> iteratorTypes, "
"function_ref<void(OpBuilder &, Location, ValueRange)> = nullptr">
];
-
let verifier = [{ return ::verify(*this); }];
let hasFolder = 1;
@@ -665,19 +677,19 @@ def GenericOp : GenericOpBase<"generic"> {
def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
let description = [{
Indexed Generic Linalg op form where the key properties of the computation
- are specified as attributes. In pretty form, a linalg.indexed_generic op is
- written as:
+ are specified as attributes. In pretty form, a `linalg.indexed_generic` op
+ is written as:
```mlir
- linalg.indexed_generic #trait_attribute %A, %B, %C {other-attributes} :
- memref<?x?xf32, stride_specification>,
- memref<?x?xf32, stride_specification>,
- memref<?x?xf32, stride_specification>
+ linalg.indexed_generic #trait_attribute
+ ins(%A, %B : memref<?x?xf32, stride_specification>,
+ memref<?x?xf32, stride_specification>)
+ outs(%C : memref<?x?xf32, stride_specification>)
+ attrs = {other-optional-attributes}
+ {region}
```
Where #trait_attributes is an alias of a dictionary attribute containing:
- - args_in: an I64Attr representing the number of input (readonly) views
- - args_out: an I64Attr representing the number of output (readwrite) views
- doc [optional]: a documentation string
- indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input
and output view. Such AffineMapAttr specifies the mapping between the
@@ -705,8 +717,6 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
doc = "C(m, n) += A(m, k) * B(k, n)",
indexing_maps = #matmul_accesses,
library_call = "linalg_matmul",
- args_in = 2,
- args_out = 1,
iterator_types = ["parallel", "parallel", "reduction"]
}
```
@@ -714,23 +724,25 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
And can be reused in multiple places as:
```mlir
- linalg.indexed_generic #matmul_trait %A, %B, %C [other-attributes] {
+ linalg.indexed_generic #matmul_trait
+ ins(%A, %B : memref<?x?xf32, stride_specification>,
+ memref<?x?xf32, stride_specification>)
+ outs(%C : memref<?x?xf32, stride_specification>) {
(%offset_m: index, %offset_n: index, %offset_k: index,
%a: f32, %b: f32, %c: f32) :
"some_optional_computation"(%offset_m, %offset_n, %offset_k)
%d = mulf %a, %b: f32
%e = addf %c, %d: f32
linalg_yield %e : f32
- } : memref<?x?xf32, stride_specification>,
- memref<?x?xf32, stride_specification>,
- memref<?x?xf32, stride_specification>
+ }
```
This may lower to either:
```mlir
call @linalg_matmul(%offset_m, %offset_n, %offset_k, %A, %B, %C) :
- (memref<?x?xf32, stride_specification>,
+ (index, index, index,
+ memref<?x?xf32, stride_specification>,
memref<?x?xf32, stride_specification>,
memref<?x?xf32, stride_specification>)
-> ()
@@ -756,42 +768,83 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
To allow progressive lowering from the value world (a.k.a tensor values) to
the buffer world (a.k.a memref values), a `linalg.indexed_generic` op
- accepts mixing input and output ranked tensor values with input and output
- memrefs.
+ allows mixing tensor and buffer operands with tensor results.
```mlir
- %C = linalg.indexed_generic #trait_attribute %A, %B {other-attributes}
- : tensor<?x?xf32>,
- memref<?x?xf32, stride_specification>
+ %C = linalg.indexed_generic #trait_attribute
+ ins(%A, %B : tensor<?x?xf32>, memref<?x?xf32, stride_specification>)
+ init(%C : tensor<?x?xf32>)
+ {other-optional-attributes}
+ {region_with_index_arguments}
-> (tensor<?x?xf32>)
```
- In this case, the number of outputs (args_out) must match the sum of (1) the
- number of output buffer operands and (2) the number of tensor return values.
- The semantics is that the `linalg.indexed_generic` op produces (i.e.
- allocates and fills) its return values.
+ The `init` operand and the conventions around mixing tensors and buffers are
+ described in more detail in the "Tensors and Buffers: Conventions and
+ Limitations" section in the [Linalg Document](../docs/Linalg.md)
Tensor values must be legalized by a buffer allocation pass before most
- transformations can be applied. Such legalization moves tensor return values
- into output buffer operands and updates the region argument accordingly.
-
- Transformations that create control-flow around linalg.indexed_generic
- operations are not expected to work with tensors because SSA values do not
- escape naturally. Still, transformations and rewrites that take advantage of
- tensor SSA values are expected to be useful and will be added in the near
- future.
+ transformations can be applied. Such legalizations move tensor return values
+ into output buffer operands and update the region arguments accordingly.
+
+ The `symbol_source` attribute allows selecting a particular operand and
+ introducing symbols for each operand dimension. Such symbols can then be
+ used in the indexing maps.
+
+ Example of 1D convolution with symbols:
+ ```mlir
+ #conv_1d_accesses = [
+ affine_map<(m, n)[dimN] -> (m + n - dimN floordiv 2)>, // in
+ affine_map<(m, n)[dimN] -> (n)>, // filter
+ affine_map<(m, n)[dimN] -> (m)> // out
+ ]
+
+ #conv_1d_trait = {
+ doc = "O(m) += I(m + n - size(n) floordiv 2) * K(n)",
+ indexing_maps = #conv_1d_accesses,
+ library_call = "linalg_conv_1d",
+ iterator_types = ["parallel", "parallel"],
+ symbol_source = 1
+ }
+
+ linalg.indexed_generic #conv_1d_trait
+ ins(%in, %filter : memref<?xf32>, memref<?xf32>)
+ outs(%out : memref<?xf32>) {
+ ^bb0(%m: index, %n: index, %a: f32, %b: f32, %c: f32) :
+ %d = mulf %a, %b : f32
+ %e = addf %c, %d : f32
+ linalg.yield %e : f32
+ }
+ ```
+ where symbol s0 will be substituted with `dim %filter, %c0` i.e. the first
+ and only dimension of the second operand as specified by the symbol_source
+ attribute.
}];
let builders = [
OpBuilder<
- "OpBuilder &builder, OperationState &result, ArrayRef<Type> resultTypes, "
- "ValueRange args, int64_t argsIn, int64_t argsOut, "
+ "OpBuilder &builder, OperationState &result, ArrayRef<Type> resultTensorTypes,"
+ "ValueRange inputs, ValueRange outputBuffers, ValueRange initTensors, "
+ "ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes, "
+ "StringRef doc, StringRef libraryCall, IntegerAttr symbolSource, "
+ "function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> = nullptr">,
+ OpBuilder<
+ "OpBuilder &builder, OperationState &result, "
+ "ValueRange inputs, ValueRange outputBuffers, "
+ "ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes, "
+ "StringRef doc, StringRef libraryCall, IntegerAttr symbolSource, "
+ "function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> = nullptr">,
+ OpBuilder<
+ "OpBuilder &builder, OperationState &result, ArrayRef<Type> resultTensorTypes,"
+ "ValueRange inputs, ValueRange outputBuffers, ValueRange initTensors, "
"ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes, "
- "function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> "
- "= nullptr">
+ "function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> = nullptr">,
+ OpBuilder<
+ "OpBuilder &builder, OperationState &result, ValueRange inputs, "
+ "ValueRange outputBuffers, ArrayRef<AffineMap> indexingMaps, "
+ "ArrayRef<StringRef> iteratorTypes, "
+ "function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> = nullptr">
];
-
-
let verifier = [{ return ::verify(*this); }];
let hasFolder = 1;
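As a usage illustration of the new builder overloads declared above, here is a hedged sketch of the shortest buffer-form `GenericOp::build` overload (inputs, output buffers, maps, iterators, body). It is not part of the patch; `buildEltwiseAdd`, `lhs`, `rhs` and `out` are hypothetical names, and a `using namespace mlir;` context is assumed.

```c++
// Sketch: inputs + output buffer, identity maps, all-parallel iterators,
// and a body region that yields the elementwise sum.
static void buildEltwiseAdd(OpBuilder &b, Location loc, Value lhs, Value rhs,
                            Value out) {
  auto rank = out.getType().cast<MemRefType>().getRank();
  SmallVector<AffineMap, 3> maps(3, b.getMultiDimIdentityMap(rank));
  SmallVector<StringRef, 4> iters(rank, getParallelIteratorTypeName());
  b.create<linalg::GenericOp>(
      loc, /*inputs=*/ValueRange{lhs, rhs}, /*outputBuffers=*/ValueRange{out},
      maps, iters,
      [](OpBuilder &nested, Location nestedLoc, ValueRange args) {
        // Block arguments are the scalar elements of (lhs, rhs, out).
        Value sum = nested.create<AddFOp>(nestedLoc, args[0], args[1]);
        nested.create<linalg::YieldOp>(nestedLoc, ValueRange{sum});
      });
}
```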
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
index ae56dd6..1da9362 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
@@ -90,7 +90,7 @@ public:
unsigned getNumOutputs() {
ConcreteType concreteOp = cast<ConcreteType>(this->getOperation());
return concreteOp.output_buffers().size() +
- concreteOp.output_tensors().size();
+ concreteOp.result_tensors().size();
}
static LogicalResult verifyTrait(Operation *op) {
ConcreteType concreteOp = cast<ConcreteType>(op);
diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index 41d6142..805db03 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -58,14 +58,6 @@ constexpr StringRef getIndexingMapsAttrName() { return "indexing_maps"; }
/// op's iterators.
constexpr StringRef getIteratorTypesAttrName() { return "iterator_types"; }
-/// Attribute name for the IntegerAttr which encodes the number of input buffer
-/// arguments.
-constexpr StringRef getArgsInAttrName() { return "args_in"; }
-
-/// Attribute name for the IntegerAttr which encodes the number of input buffer
-/// arguments.
-constexpr StringRef getArgsOutAttrName() { return "args_out"; }
-
/// Attribute name for the StringAttr which encodes an optional documentation
/// string of the structured op.
constexpr StringRef getDocAttrName() { return "doc"; }
diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
index 4c0aa75..18e2693 100644
--- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
+++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
@@ -71,7 +71,7 @@ SingleWorkgroupReduction::matchAsPerformingReduction(
return llvm::None;
// Make sure this is reduction with one input and one output.
- if (genericOp.args_in() != 1 || genericOp.args_out() != 1)
+ if (genericOp.getNumInputs() != 1 || genericOp.getNumOutputs() != 1)
return llvm::None;
auto originalInputType = op->getOperand(0).getType().cast<MemRefType>();
diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
index 44eb5e7..20522fe 100644
--- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
@@ -23,37 +23,36 @@ using namespace mlir::scf;
Operation *mlir::edsc::makeGenericLinalgOp(
ArrayRef<IteratorType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
- ArrayRef<StructuredIndexed> outputs,
+ ArrayRef<StructuredIndexed> outputBuffers, ArrayRef<Value> initTensors,
+ ArrayRef<StructuredIndexed> resultTensorTypes,
function_ref<void(ValueRange)> regionBuilder, ArrayRef<Value> otherValues,
ArrayRef<Attribute> otherAttributes) {
- for (unsigned i = 0, e = outputs.size(); i + 1 < e; ++i)
- assert(!(outputs[i].getType().isa<RankedTensorType>() &&
- outputs[i + 1].getType().isa<MemRefType>()) &&
- "output tensors must be passed after output buffers");
- auto &builder = edsc::ScopedContext::getBuilderRef();
- auto *ctx = builder.getContext();
- unsigned nInputs = inputs.size();
- unsigned nOutputs = outputs.size();
+ OpBuilder &builder = edsc::ScopedContext::getBuilderRef();
+ // Build maps
SmallVector<SmallVector<AffineExpr, 4>, 4> exprsList;
- exprsList.reserve(nInputs + nOutputs);
- for (auto structuredIndexed : inputs)
- exprsList.emplace_back(structuredIndexed.getExprs().begin(),
- structuredIndexed.getExprs().end());
- for (auto structuredIndexed : outputs)
- exprsList.emplace_back(structuredIndexed.getExprs().begin(),
- structuredIndexed.getExprs().end());
+ exprsList.reserve(inputs.size() + outputBuffers.size() + initTensors.size());
+ for (auto container : {inputs, outputBuffers, resultTensorTypes})
+ for (const StructuredIndexed &s : container)
+ exprsList.emplace_back(s.getExprs().begin(), s.getExprs().end());
auto maps = AffineMap::inferFromExprList(exprsList);
- unsigned nViews = nInputs + nOutputs;
- SmallVector<Value, 4> values;
- values.reserve(nViews);
- values.append(inputs.begin(), inputs.end());
- std::copy_if(outputs.begin(), outputs.end(), std::back_inserter(values),
- [](StructuredIndexed s) { return s.hasValue(); });
SmallVector<Type, 4> types;
- std::copy_if(outputs.begin(), outputs.end(), std::back_inserter(types),
- [](StructuredIndexed s) { return !s.hasValue(); });
+ assert(llvm::all_of(resultTensorTypes, [](const StructuredIndexed &s) {
+ return !s.hasValue();
+ }));
+ std::copy(resultTensorTypes.begin(), resultTensorTypes.end(),
+ std::back_inserter(types));
+
+ SmallVector<Value, 4> inputValues, outputBufferValues, initTensorValues;
+ inputValues.reserve(inputs.size());
+ outputBufferValues.reserve(outputBuffers.size());
+ initTensorValues.reserve(initTensors.size());
+ std::copy(inputs.begin(), inputs.end(), std::back_inserter(inputValues));
+ std::copy(outputBuffers.begin(), outputBuffers.end(),
+ std::back_inserter(outputBufferValues));
+ std::copy(initTensors.begin(), initTensors.end(),
+ std::back_inserter(initTensorValues));
auto iteratorStrTypes =
llvm::to_vector<8>(llvm::map_range(iteratorTypes, toString));
@@ -63,9 +62,9 @@ Operation *mlir::edsc::makeGenericLinalgOp(
.create<linalg::GenericOp>(
edsc::ScopedContext::getLocation(),
types,
- values,
- IntegerAttr::get(IntegerType::get(64, ctx), nInputs),
- IntegerAttr::get(IntegerType::get(64, ctx), nOutputs),
+ inputValues,
+ outputBufferValues,
+ initTensorValues,
builder.getAffineMapArrayAttr(maps),
builder.getStrArrayAttr(iteratorStrTypes),
StringAttr() /*doc*/,
@@ -78,11 +77,12 @@ Operation *mlir::edsc::makeGenericLinalgOp(
using namespace edsc;
SmallVector<Type, 4> blockTypes;
- blockTypes.reserve(values.size());
- for (auto it : llvm::enumerate(values))
- blockTypes.push_back((it.index() < nViews)
- ? getElementTypeOrSelf(it.value())
- : it.value().getType());
+ blockTypes.reserve(inputs.size() + outputBuffers.size() + initTensors.size());
+ for (auto container : {inputs, outputBuffers})
+ for (const StructuredIndexed &s : container)
+ blockTypes.push_back(getElementTypeOrSelf(s.getType()));
+ for (Value v : initTensors)
+ blockTypes.push_back(getElementTypeOrSelf(v.getType()));
assert(op->getNumRegions() == 1);
assert(op->getRegion(0).empty());
@@ -113,20 +113,17 @@ Operation *mlir::edsc::ops::linalg_generic_pointwise(
UnaryPointwiseOpBuilder unaryOp, StructuredIndexed I, StructuredIndexed O) {
SmallVector<IteratorType, 4> iterTypes(O.getExprs().size(),
IteratorType::Parallel);
- if (O.getType().isa<RankedTensorType>()) {
- auto fun = [&unaryOp](ValueRange args) {
- assert(args.size() == 1 && "expected 1 block arguments");
- Value a(args[0]);
- linalg_yield(unaryOp(a));
- };
- return makeGenericLinalgOp(iterTypes, {I}, {O}, fun);
- }
auto fun = [&unaryOp](ValueRange args) {
- assert(args.size() == 2 && "expected 2 block arguments");
+ assert(!args.empty() && "expected >= 1 block arguments");
Value a(args[0]);
linalg_yield(unaryOp(a));
};
- return makeGenericLinalgOp(iterTypes, {I}, {O}, fun);
+ if (O.getType().isa<RankedTensorType>())
+ return makeGenericLinalgOp(iterTypes, /*inputs=*/{I}, /*outputBuffers=*/{},
+ /*initTensors=*/{}, /*resultTensorTypes=*/{O},
+ fun);
+ return makeGenericLinalgOp(iterTypes, /*inputs=*/{I}, /*outputBuffers=*/{O},
+ /*initTensors=*/{}, /*resultTensorTypes=*/{}, fun);
}
Operation *mlir::edsc::ops::linalg_generic_pointwise_tanh(StructuredIndexed I,
@@ -141,20 +138,18 @@ Operation *mlir::edsc::ops::linalg_generic_pointwise(
StructuredIndexed I2, StructuredIndexed O) {
SmallVector<IteratorType, 4> iterTypes(O.getExprs().size(),
IteratorType::Parallel);
- if (O.getType().isa<RankedTensorType>()) {
- auto fun = [&binaryOp](ValueRange args) {
- assert(args.size() == 2 && "expected 2 block arguments");
- Value a(args[0]), b(args[1]);
- linalg_yield(binaryOp(a, b));
- };
- return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, fun);
- }
auto fun = [&binaryOp](ValueRange args) {
- assert(args.size() == 3 && "expected 3 block arguments");
+ assert(args.size() >= 2 && "expected >= 2 block arguments");
Value a(args[0]), b(args[1]);
linalg_yield(binaryOp(a, b));
};
- return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, fun);
+ if (O.getType().isa<RankedTensorType>())
+ return makeGenericLinalgOp(
+ iterTypes, /*inputs=*/{I1, I2}, /*outputBuffers=*/{},
+ /*initTensors=*/{}, /*resultTensorTypes=*/{O}, fun);
+ return makeGenericLinalgOp(iterTypes, /*inputs=*/{I1, I2},
+ /*outputBuffers=*/{O},
+ /*initTensors=*/{}, /*resultTensorTypes=*/{}, fun);
}
Operation *mlir::edsc::ops::linalg_generic_pointwise_add(StructuredIndexed I1,
@@ -185,23 +180,10 @@ mlir::edsc::ops::linalg_generic_matmul(Value vA, Value vB, Value vC,
StructuredIndexed A(vA), B(vB), C(vC);
return makeGenericLinalgOp(
{IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
- {A({m, k}), B({k, n})},
- {C({m, n})},
- regionBuilder);
- // clang-format on
-}
-
-Operation *
-mlir::edsc::ops::linalg_generic_matmul(Value vA, Value vB, RankedTensorType tC,
- MatmulRegionBuilder regionBuilder) {
- // clang-format off
- AffineExpr m, n, k;
- bindDims(ScopedContext::getContext(), m, n, k);
- StructuredIndexed A(vA), B(vB), C(tC);
- return makeGenericLinalgOp(
- {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
- {A({m, k}), B({k, n})},
- {C({m, n})},
+ /*inputs=*/{A({m, k}), B({k, n})},
+ /*outputBuffers=*/{C({m, n})},
+ /*initTensors=*/{},
+ /*resultTensorTypes=*/{},
regionBuilder);
// clang-format on
}
@@ -216,8 +198,10 @@ mlir::edsc::ops::linalg_generic_matmul(Value vA, Value vB, Value vC,
StructuredIndexed A(vA), B(vB), C(vC), D(tD);
return makeGenericLinalgOp(
{IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
- {A({m, k}), B({k, n}), C({m, n})},
- {D({m, n})},
+ /*inputs=*/{A({m, k}), B({k, n})},
+ /*outputBuffers=*/{},
+ /*initTensors=*/{C({m, n})},
+ /*resultTensorTypes=*/{D({m, n})},
regionBuilder);
// clang-format on
}
@@ -243,15 +227,18 @@ Operation *mlir::edsc::ops::linalg_generic_conv_nhwc(Value vI, Value vW,
StructuredIndexed I(vI), W(vW), O(vO);
// clang-format off
return makeGenericLinalgOp(
- {par, par, par, par, red, red, red}, {
+ {par, par, par, par, red, red, red},
+ /*inputs=*/{
I({b,
// Roundtrip to flattened form to serve as canonicalization and ensure
// consistent ordering of subexpressions.
simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0),
simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0),
c}),
- W({kh, kw, c, f})}, {
- O({b, h, w, f})},
+ W({kh, kw, c, f}) },
+ /*outputBuffers=*/{ O({b, h, w, f}) },
+ /*initTensors=*/{},
+ /*resultTensorTypes=*/{},
macRegionBuilder);
// clang-format on
}
@@ -276,15 +263,19 @@ Operation *mlir::edsc::ops::linalg_generic_dilated_conv_nhwc(
unsigned numDims = kw.cast<AffineDimExpr>().getPosition() + 1;
StructuredIndexed I(vI), W(vW), O(vO);
return makeGenericLinalgOp(
- {par, par, par, par, par, red, red}, {
+ {par, par, par, par, par, red, red},
+ /*inputs=*/{
I({b,
// Roundtrip to flattened form to serve as canonicalization and ensure
// consistent ordering of subexpressions.
simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0),
simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0),
c}),
- W({kh, kw, c, dm})}, {
+ W({kh, kw, c, dm})},
+ /*outputBuffers=*/{
O({b, h, w, simplifyAffineExpr(c * depth_multiplier + dm, numDims, 0)})},
+ /*initTensors=*/{},
+ /*resultTensorTypes=*/{},
macRegionBuilder);
// clang-format on
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 7b9ba74..e322a85 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -40,6 +40,12 @@ static void buildNamedStructuredOpRegionAndAttributes(
TypeRange outputBufferTypes, TypeRange initTensorTypes,
TypeRange resultTypes);
+static ParseResult
+parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
+ SmallVectorImpl<Type> &inputTypes,
+ SmallVectorImpl<Type> &outputBufferTypes,
+ SmallVectorImpl<Type> &initTensorTypes);
+
template <typename NamedStructuredOpType>
static ParseResult
parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
@@ -53,6 +59,10 @@ template <typename NamedStructuredOpType>
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
OperationState &result);
+template <typename NamedStructuredOpType>
+static void printCommonStructuredOpParts(OpAsmPrinter &p,
+ NamedStructuredOpType op);
+
static void printNamedStructuredOpResults(OpAsmPrinter &p,
TypeRange resultTypes);
@@ -87,24 +97,26 @@ static LogicalResult foldMemRefCast(Operation *op) {
//===----------------------------------------------------------------------===//
// GenericOps
//===----------------------------------------------------------------------===//
-
void GenericOp::build(
- OpBuilder &builder, OperationState &result, ArrayRef<Type> resultTypes,
- ValueRange args, int64_t argsIn, int64_t argsOut,
+ OpBuilder &builder, OperationState &result,
+ ArrayRef<Type> resultTensorTypes, ValueRange inputs,
+ ValueRange outputBuffers, ValueRange initTensors,
ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes,
+ StringRef doc, StringRef libraryCall, IntegerAttr symbolSource,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
- build(builder, result, resultTypes, args, builder.getI64IntegerAttr(argsIn),
- builder.getI64IntegerAttr(argsOut),
+ build(builder, result, resultTensorTypes, inputs, outputBuffers, initTensors,
builder.getAffineMapArrayAttr(indexingMaps),
builder.getStrArrayAttr(iteratorTypes),
- /*doc=*/nullptr, /*library_call=*/nullptr,
- /*symbol_source=*/nullptr);
+ doc.empty() ? StringAttr() : builder.getStringAttr(doc),
+ libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
+ symbolSource);
if (!bodyBuild)
return;
SmallVector<Type, 4> blockArgTypes;
- for (Value arg : args)
- blockArgTypes.push_back(arg.getType().cast<ShapedType>().getElementType());
+ for (ValueRange container : {inputs, outputBuffers, initTensors})
+ for (Value v : container)
+ blockArgTypes.push_back(v.getType().cast<ShapedType>().getElementType());
OpBuilder::InsertionGuard guard(builder);
auto &region = *result.regions.front();
@@ -112,25 +124,62 @@ void GenericOp::build(
bodyBuild(builder, result.location, bodyBlock->getArguments());
}
+void GenericOp::build(
+ OpBuilder &builder, OperationState &result, ValueRange inputs,
+ ValueRange outputBuffers, ArrayRef<AffineMap> indexingMaps,
+ ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
+ IntegerAttr symbolSource,
+ function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
+ build(builder, result, ArrayRef<Type>{}, inputs, outputBuffers, ValueRange{},
+ indexingMaps, iteratorTypes, doc, libraryCall, symbolSource, bodyBuild);
+}
+
+void GenericOp::build(
+ OpBuilder &builder, OperationState &result, ValueRange inputs,
+ ValueRange outputBuffers, ArrayRef<AffineMap> indexingMaps,
+ ArrayRef<StringRef> iteratorTypes,
+ function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
+ build(builder, result, inputs, outputBuffers, indexingMaps, iteratorTypes,
+ /*doc=*/"",
+ /*libraryCall=*/"",
+ /*symbolSource=*/IntegerAttr(), bodyBuild);
+}
+
+void GenericOp::build(
+ OpBuilder &builder, OperationState &result,
+ ArrayRef<Type> resultTensorTypes, ValueRange inputs,
+ ValueRange outputBuffers, ValueRange initTensors,
+ ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes,
+ function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
+ build(builder, result, resultTensorTypes, inputs, outputBuffers, initTensors,
+ indexingMaps, iteratorTypes,
+ /*doc=*/"",
+ /*libraryCall=*/"",
+ /*symbolSource=*/IntegerAttr(), bodyBuild);
+}
+
void IndexedGenericOp::build(
- OpBuilder &builder, OperationState &result, ArrayRef<Type> resultTypes,
- ValueRange args, int64_t argsIn, int64_t argsOut,
+ OpBuilder &builder, OperationState &result,
+ ArrayRef<Type> resultTensorTypes, ValueRange inputs,
+ ValueRange outputBuffers, ValueRange initTensors,
ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes,
+ StringRef doc, StringRef libraryCall, IntegerAttr symbolSource,
function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
bodyBuild) {
- build(builder, result, resultTypes, args, builder.getI64IntegerAttr(argsIn),
- builder.getI64IntegerAttr(argsOut),
+ build(builder, result, resultTensorTypes, inputs, outputBuffers, initTensors,
builder.getAffineMapArrayAttr(indexingMaps),
builder.getStrArrayAttr(iteratorTypes),
- /*doc=*/nullptr, /*library_call=*/nullptr,
- /*symbol_source=*/nullptr);
+ doc.empty() ? StringAttr() : builder.getStringAttr(doc),
+ libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
+ symbolSource);
if (!bodyBuild)
return;
unsigned nLoops = iteratorTypes.size();
SmallVector<Type, 4> blockArgTypes(nLoops, builder.getIndexType());
- for (Value arg : args)
- blockArgTypes.push_back(arg.getType().cast<ShapedType>().getElementType());
+ for (ValueRange container : {inputs, outputBuffers, initTensors})
+ for (Value v : container)
+ blockArgTypes.push_back(v.getType().cast<ShapedType>().getElementType());
OpBuilder::InsertionGuard guard(builder);
auto &region = *result.regions.front();
@@ -140,26 +189,83 @@ void IndexedGenericOp::build(
bodyBlock->getArguments().drop_front(nLoops));
}
+void IndexedGenericOp::build(
+ OpBuilder &builder, OperationState &result, ValueRange inputs,
+ ValueRange outputBuffers, ArrayRef<AffineMap> indexingMaps,
+ ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
+ IntegerAttr symbolSource,
+ function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
+ bodyBuild) {
+ build(builder, result, ArrayRef<Type>{}, inputs, outputBuffers, ValueRange{},
+ indexingMaps, iteratorTypes, doc, libraryCall, symbolSource, bodyBuild);
+}
+
+void IndexedGenericOp::build(
+ OpBuilder &builder, OperationState &result, ValueRange inputs,
+ ValueRange outputBuffers, ArrayRef<AffineMap> indexingMaps,
+ ArrayRef<StringRef> iteratorTypes,
+ function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
+ bodyBuild) {
+ build(builder, result, inputs, outputBuffers, indexingMaps, iteratorTypes,
+ /*doc=*/"",
+ /*libraryCall=*/"",
+ /*symbolSource=*/IntegerAttr(), bodyBuild);
+}
+
+void IndexedGenericOp::build(
+ OpBuilder &builder, OperationState &result,
+ ArrayRef<Type> resultTensorTypes, ValueRange inputs,
+ ValueRange outputBuffers, ValueRange initTensors,
+ ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes,
+ function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
+ bodyBuild) {
+ build(builder, result, resultTensorTypes, inputs, outputBuffers, initTensors,
+ indexingMaps, iteratorTypes,
+ /*doc=*/"",
+ /*libraryCall=*/"",
+ /*symbolSource=*/IntegerAttr(), bodyBuild);
+}
+
template <typename GenericOpType>
static void printGenericOp(OpAsmPrinter &p, GenericOpType op) {
- auto attrNames = op.linalgTraitAttrNames();
- llvm::StringSet<> linalgTraitAttrsSet;
- linalgTraitAttrsSet.insert(attrNames.begin(), attrNames.end());
- SmallVector<NamedAttribute, 8> attrs;
+ p << op.getOperationName() << " ";
+
+ // Print extra attributes.
+ auto genericAttrNames = op.linalgTraitAttrNames();
+
+ llvm::StringSet<> genericAttrNamesSet;
+ genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
+ SmallVector<NamedAttribute, 8> genericAttrs;
for (auto attr : op.getAttrs())
- if (linalgTraitAttrsSet.count(attr.first.strref()) > 0)
- attrs.push_back(attr);
+ if (genericAttrNamesSet.count(attr.first.strref()) > 0)
+ genericAttrs.push_back(attr);
+ if (!genericAttrs.empty()) {
+ auto genericDictAttr = DictionaryAttr::get(genericAttrs, op.getContext());
+ p << genericDictAttr;
+ }
+
+ // Printing is shared with named ops, except for the region and attributes
+ printCommonStructuredOpParts(p, op);
+
+ genericAttrNames.push_back("operand_segment_sizes");
+ genericAttrNamesSet.insert(genericAttrNames.back());
- auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
- p << op.getOperationName() << " " << dictAttr;
- p.printOptionalAttrDict(op.getAttrs(), attrNames);
- p << " " << op.getOperands();
+ bool hasExtraAttrs = false;
+ for (NamedAttribute n : op.getAttrs()) {
+ if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.first.strref())))
+ break;
+ }
+ if (hasExtraAttrs) {
+ p << " attrs = ";
+ p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/genericAttrNames);
+ }
+
+ // Print region.
if (!op.region().empty())
p.printRegion(op.region());
- p << ": " << op.getOperandTypes();
- auto outputTensorTypes = op.getResultTypes();
- if (!outputTensorTypes.empty())
- p << " -> " << outputTensorTypes;
+
+ // Print results.
+ printNamedStructuredOpResults(p, op.result_tensors().getTypes());
}
static void print(OpAsmPrinter &p, GenericOp op) { printGenericOp(p, op); }
@@ -169,7 +275,6 @@ static void print(OpAsmPrinter &p, IndexedGenericOp op) {
}
static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) {
- SmallVector<OpAsmParser::OperandType, 8> operandsInfo, regionOperandsInfo;
DictionaryAttr dictAttr;
// Parse the core linalg traits that must check into a dictAttr.
// The name is unimportant as we will overwrite result.attributes.
@@ -180,26 +285,35 @@ static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) {
result.attributes.assign(dictAttr.getValue().begin(),
dictAttr.getValue().end());
- // Optional attributes may be added.
- if (parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseOperandList(operandsInfo))
+ // Parsing is shared with named ops, except for the region.
+ SmallVector<Type, 1> inputTypes, outputBufferTypes, initTensorTypes;
+ if (parseCommonStructuredOpParts(parser, result, inputTypes,
+ outputBufferTypes, initTensorTypes))
return failure();
- Region &region = *result.addRegion();
+ // Optional attributes may be added.
+ if (succeeded(parser.parseOptionalKeyword("attrs")))
+ if (failed(parser.parseEqual()) ||
+ failed(parser.parseOptionalAttrDict(result.attributes)))
+ return failure();
+
+ SmallVector<OpAsmParser::OperandType, 8> regionOperands;
+ std::unique_ptr<Region> region = std::make_unique<Region>();
SmallVector<Type, 8> operandTypes, regionTypes;
- if (parser.parseRegion(region, regionOperandsInfo, regionTypes))
- return failure();
- if (parser.parseColonTypeList(operandTypes))
+ if (parser.parseRegion(*region, regionOperands, regionTypes))
return failure();
+ result.addRegion(std::move(region));
+
// Generic ops may specify that a subset of its outputs are tensors. Such
// outputs are specified in the result type.
- SmallVector<Type, 8> tensorResultTypes;
- if (parser.parseOptionalArrowTypeList(tensorResultTypes))
+ // TODO: may need to move output parsing before region parsing.
+ // Need to wait for declarative assembly resolution to decide.
+ SmallVector<Type, 1> outputTensorsTypes;
+ if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
return failure();
- if (!tensorResultTypes.empty())
- result.addTypes(tensorResultTypes);
- return parser.resolveOperands(operandsInfo, operandTypes,
- parser.getCurrentLocation(), result.operands);
+ result.addTypes(outputTensorsTypes);
+
+ return success();
}
namespace {
@@ -266,6 +380,11 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
auto nInputViews = op.getNumInputs();
auto nLoops = op.getNumLoops();
+ if (op.inputs().size() + op.output_buffers().size() +
+ op.init_tensors().size() + op.getNumResults() ==
+ 0)
+ return op.emitOpError("expected at least 1 Shaped operand or return");
+
auto &region = op.region();
if (!llvm::hasSingleElement(region))
return op.emitOpError("expected region with 1 block");
@@ -314,27 +433,9 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
return success();
}
-static LogicalResult verify(GenericOp op) {
- // Temporarily hoisted here to avoid duplicating more code.
- // TODO: uniformize with named structured ops.
- auto nInputsAndOutputBuffers = op.getNumInputsAndOutputBuffers();
- if (nInputsAndOutputBuffers != llvm::size(op.views()))
- return op.emitOpError("expected exactly ")
- << nInputsAndOutputBuffers
- << " inputs (tensor or buffer) and output buffer operands";
- return verifyGenericOp(op);
-}
+static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); }
-static LogicalResult verify(IndexedGenericOp op) {
- // Temporarily hoisted here to avoid duplicating more code.
- // TODO: uniformize with named structured ops.
- auto nInputsAndOutputBuffers = op.getNumInputsAndOutputBuffers();
- if (nInputsAndOutputBuffers != llvm::size(op.views()))
- return op.emitOpError("expected exactly ")
- << nInputsAndOutputBuffers
- << " inputs (tensor or buffer) and output buffer operands";
- return verifyGenericOp(op);
-}
+static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
//===----------------------------------------------------------------------===//
// ReshapeOp
@@ -1141,6 +1242,9 @@ static LogicalResult verify(PoolingSumOp op) {
/// Assumes `op` is a LinalgOp.
void mlir::linalg::getDimsOfType(Operation *op, StringRef iteratorTypeName,
SmallVectorImpl<AffineExpr> &res) {
+ if (!cast<LinalgOp>(op).iterator_types())
+ return;
+
unsigned dim = 0;
MLIRContext *ctx = op->getContext();
for (auto tn :
@@ -1341,59 +1445,50 @@ parseNamedStructuredOpResults(OpAsmParser &parser,
return success();
}
-template <typename NamedStructuredOpType>
-static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
- OperationState &result) {
+static ParseResult
+parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
+ SmallVectorImpl<Type> &inputTypes,
+ SmallVectorImpl<Type> &outputBufferTypes,
+ SmallVectorImpl<Type> &initTensorTypes) {
llvm::SMLoc inputsOperandsLoc, outputBuffersOperandsLoc,
initTensorsOperandsLoc;
SmallVector<OpAsmParser::OperandType, 4> inputsOperands,
outputBuffersOperands, initTensorsOperands;
- SmallVector<Type, 1> inputsTypes, outputBuffersTypes, initTensorsTypes,
- outputTensorsTypes;
- std::unique_ptr<Region> regionRegion = std::make_unique<Region>();
- if (parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseKeyword("ins") || parser.parseLParen())
- return failure();
+ parser.parseOptionalAttrDict(result.attributes);
- inputsOperandsLoc = parser.getCurrentLocation();
- if (parser.parseOperandList(inputsOperands) || parser.parseColon() ||
- parser.parseTypeList(inputsTypes) || parser.parseRParen())
- return failure();
+ if (succeeded(parser.parseOptionalKeyword("ins"))) {
+ if (parser.parseLParen())
+ return failure();
+
+ inputsOperandsLoc = parser.getCurrentLocation();
+ if (parser.parseOperandList(inputsOperands) ||
+ parser.parseColonTypeList(inputTypes) || parser.parseRParen())
+ return failure();
+ }
if (succeeded(parser.parseOptionalKeyword("outs"))) {
outputBuffersOperandsLoc = parser.getCurrentLocation();
if (parser.parseLParen() ||
- parser.parseOperandList(outputBuffersOperands) || parser.parseColon() ||
- parser.parseTypeList(outputBuffersTypes) || parser.parseRParen())
+ parser.parseOperandList(outputBuffersOperands) ||
+ parser.parseColonTypeList(outputBufferTypes) || parser.parseRParen())
return failure();
}
if (succeeded(parser.parseOptionalKeyword("init"))) {
initTensorsOperandsLoc = parser.getCurrentLocation();
if (parser.parseLParen() || parser.parseOperandList(initTensorsOperands) ||
- parser.parseColon() || parser.parseTypeList(initTensorsTypes) ||
- parser.parseRParen())
+ parser.parseColonTypeList(initTensorTypes) || parser.parseRParen())
return failure();
}
- if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
- return failure();
-
- if (parseNamedStructuredOpRegion<NamedStructuredOpType>(
- parser, *regionRegion, inputsTypes, outputBuffersTypes,
- initTensorsTypes, outputTensorsTypes))
- return failure();
-
- if (parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc,
+ if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
result.operands) ||
- parser.resolveOperands(outputBuffersOperands, outputBuffersTypes,
+ parser.resolveOperands(outputBuffersOperands, outputBufferTypes,
outputBuffersOperandsLoc, result.operands) ||
- parser.resolveOperands(initTensorsOperands, initTensorsTypes,
+ parser.resolveOperands(initTensorsOperands, initTensorTypes,
initTensorsOperandsLoc, result.operands))
return failure();
- result.addTypes(outputTensorsTypes);
- result.addRegion(std::move(regionRegion));
result.addAttribute("operand_segment_sizes",
parser.getBuilder().getI32VectorAttr(
{static_cast<int32_t>(inputsOperands.size()),
@@ -1402,28 +1497,61 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
return success();
}
+template <typename NamedStructuredOpType>
+static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
+ OperationState &result) {
+ SmallVector<Type, 1> inputTypes, outputBufferTypes, initTensorTypes;
+ if (parseCommonStructuredOpParts(parser, result, inputTypes,
+ outputBufferTypes, initTensorTypes))
+ return failure();
+
+ // TODO: consider merging results parsing into region parsing.
+ // Need to wait for declarative assembly resolution to decide.
+ SmallVector<Type, 1> outputTensorsTypes;
+ if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
+ return failure();
+ result.addTypes(outputTensorsTypes);
+
+ std::unique_ptr<Region> region = std::make_unique<Region>();
+ if (parseNamedStructuredOpRegion<NamedStructuredOpType>(
+ parser, *region, inputTypes, outputBufferTypes, initTensorTypes,
+ outputTensorsTypes))
+ return failure();
+ result.addRegion(std::move(region));
+
+ return success();
+}
+
static void printNamedStructuredOpResults(OpAsmPrinter &p,
TypeRange resultTypes) {
if (resultTypes.empty())
return;
- p << "-> " << resultTypes;
+ p.printOptionalArrowTypeList(resultTypes);
}
template <typename NamedStructuredOpType>
-static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
- p << op.getOperationName();
- p.printOptionalAttrDict(op.getAttrs(),
- /*elidedAttrs=*/{"operand_segment_sizes"});
+static void printCommonStructuredOpParts(OpAsmPrinter &p,
+ NamedStructuredOpType op) {
p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")";
if (!op.output_buffers().empty())
p << " outs(" << op.output_buffers() << " : "
<< op.output_buffers().getTypes() << ")";
if (!op.init_tensors().empty())
p << " init(" << op.init_tensors() << " : " << op.init_tensors().getTypes()
- << ")";
- p << " ";
- printNamedStructuredOpResults(p, op.output_tensors().getTypes());
- p << " ";
+ << ") ";
+}
+
+template <typename NamedStructuredOpType>
+static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
+ p << op.getOperationName();
+ p.printOptionalAttrDict(op.getAttrs(),
+ /*elidedAttrs=*/{"operand_segment_sizes"});
+
+ // Printing is shared with generic ops, except for the region and attributes.
+ printCommonStructuredOpParts(p, op);
+
+ // Results printing.
+ printNamedStructuredOpResults(p, op.result_tensors().getTypes());
// Region is elided.
}
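To make the tensor/init-tensor overload added above concrete, here is a hedged sketch (names such as `buildMatmulOnTensors` and `matmulMaps` are illustrative, not from the patch): it builds a tensor-result matmul where the init tensor `C` seeds the accumulation and no output buffers are involved.

```c++
// Sketch: the result tensor type and the init tensor travel together
// through the new GenericOp::build overload; the region accumulates a * b
// into c.
static Value buildMatmulOnTensors(OpBuilder &b, Location loc, Value A, Value B,
                                  Value C /* init tensor */,
                                  ArrayRef<AffineMap> matmulMaps) {
  SmallVector<Type, 1> resultTypes{C.getType()};
  SmallVector<StringRef, 3> iters{"parallel", "parallel", "reduction"};
  auto genericOp = b.create<linalg::GenericOp>(
      loc, resultTypes, /*inputs=*/ValueRange{A, B},
      /*outputBuffers=*/ValueRange{}, /*initTensors=*/ValueRange{C},
      matmulMaps, iters,
      [](OpBuilder &nested, Location nestedLoc, ValueRange args) {
        // Block arguments: (a, b, c) scalars, in operand order.
        Value mul = nested.create<MulFOp>(nestedLoc, args[0], args[1]);
        Value acc = nested.create<AddFOp>(nestedLoc, args[2], mul);
        nested.create<linalg::YieldOp>(nestedLoc, ValueRange{acc});
      });
  return genericOp.getOperation()->getResult(0);
}
```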
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 65fd197..1366477 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -261,12 +261,14 @@ static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
}
namespace {
+
/// Pattern to replace tensors operands/results that are unit extents.
struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
using OpRewritePattern<GenericOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
- if (!genericOp.hasTensorSemantics())
+ // TODO: support init_tensors and reductions.
+ if (!genericOp.hasTensorSemantics() || !genericOp.init_tensors().empty())
return failure();
MLIRContext *context = rewriter.getContext();
@@ -283,8 +285,7 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
reassociationMaps.push_back(replacementInfo.reassociation);
newIndexingMaps.push_back(replacementInfo.indexMap);
newInputOutputTypes.push_back(replacementInfo.type);
- doCanonicalization =
- doCanonicalization || replacementInfo.type != std::get<1>(it);
+ doCanonicalization |= replacementInfo.type != std::get<1>(it);
}
// If the indexing maps of the result operation are not invertible (i.e. not
@@ -295,32 +296,40 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
// If any operand type change, insert a reshape to convert from the original
// type to the new type.
- SmallVector<Value, 4> newOperands;
- newOperands.reserve(genericOp.getNumOperands());
- for (auto operand : llvm::enumerate(genericOp.getOperands())) {
- if (operand.value().getType() == newInputOutputTypes[operand.index()]) {
- newOperands.push_back(operand.value());
- } else {
- newOperands.push_back(rewriter.create<linalg::TensorReshapeOp>(
- loc, newInputOutputTypes[operand.index()], operand.value(),
- reassociationMaps[operand.index()]));
+ // TODO: get rid of flattenedIdx which assumes operand order and contiguity.
+ unsigned flattenedIdx = 0;
+ auto insertReshapes = [&](ValueRange values) {
+ SmallVector<Value, 4> res;
+ res.reserve(values.size());
+ for (auto operand : llvm::enumerate(values)) {
+ if (operand.value().getType() == newInputOutputTypes[flattenedIdx])
+ res.push_back(operand.value());
+ else
+ res.push_back(rewriter.create<linalg::TensorReshapeOp>(
+ loc, newInputOutputTypes[flattenedIdx], operand.value(),
+ reassociationMaps[flattenedIdx]));
+ ++flattenedIdx;
}
- }
+ return res;
+ };
+
+ SmallVector<Value, 4> newInputs = insertReshapes(genericOp.inputs());
+ SmallVector<Value, 4> newOutputBuffers =
+ insertReshapes(genericOp.output_buffers());
+ SmallVector<Value, 4> newInitTensors =
+ insertReshapes(genericOp.init_tensors());
// If any result type change, insert a reshape to convert from the original
// type to the new type.
SmallVector<Type, 4> resultTypes;
resultTypes.reserve(genericOp.getNumResults());
for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
- resultTypes.push_back(
- newInputOutputTypes[i + genericOp.getNumOperands()]);
+ resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
GenericOp replacementOp = rewriter.create<GenericOp>(
- loc, resultTypes, newOperands, genericOp.args_in(),
- genericOp.args_out(), rewriter.getAffineMapArrayAttr(newIndexingMaps),
- genericOp.iterator_types(),
- /*doc = */ nullptr,
- /*library_call = */ nullptr,
- /*symbol_source = */ nullptr);
+ loc, resultTypes, newInputs, newOutputBuffers, newInitTensors,
+ newIndexingMaps,
+ llvm::to_vector<4>(
+ genericOp.iterator_types().getAsValueRange<StringAttr>()));
rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
replacementOp.region().begin());
@@ -332,12 +341,11 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
RankedTensorType origResultType = genericOp.getResult(result.index())
.getType()
.cast<RankedTensorType>();
- if (origResultType != result.value().getType()) {
+ if (origResultType != result.value().getType())
resultReplacements.push_back(rewriter.create<linalg::TensorReshapeOp>(
loc, origResultType, result.value(), reassociationMaps[index]));
- } else {
+ else
resultReplacements.push_back(result.value());
- }
}
rewriter.replaceOp(genericOp, resultReplacements);
return success();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index adbf4a7..04d4174 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -499,32 +499,31 @@ struct FuseGenericOpsOnTensors {
consumerIndexMaps.end());
// Generate the fused op.
+ // Tensor-level fusion only applies to ops without initTensors and outputBuffers.
LinalgOp fusedOp;
if (isa<GenericOp>(producer.getOperation()) &&
isa<GenericOp>(consumer.getOperation())) {
fusedOp =
rewriter
- .create<GenericOp>(
- rewriter.getUnknownLoc(),
- consumer.getOperation()->getResultTypes(), fusedOperands,
- rewriter.getI64IntegerAttr(fusedOperands.size()),
- rewriter.getI64IntegerAttr(
- consumer.getOperation()->getNumResults()),
- rewriter.getArrayAttr(fusedIndexMaps),
- consumer.iterator_types(),
- /*doc=*/nullptr,
- /*library_call=*/nullptr,
- /*symbol_source=*/nullptr)
+ .create<GenericOp>(consumer.getLoc(),
+ consumer.getOperation()->getResultTypes(),
+ /*inputs=*/fusedOperands,
+ /*outputBuffers=*/ValueRange{},
+ /*initTensors=*/ValueRange{},
+ rewriter.getArrayAttr(fusedIndexMaps),
+ consumer.iterator_types(),
+ /*doc=*/nullptr,
+ /*library_call=*/nullptr,
+ /*symbol_source=*/nullptr)
.getOperation();
} else {
fusedOp =
rewriter
.create<IndexedGenericOp>(
- rewriter.getUnknownLoc(),
- consumer.getOperation()->getResultTypes(), fusedOperands,
- rewriter.getI64IntegerAttr(fusedOperands.size()),
- rewriter.getI64IntegerAttr(
- consumer.getOperation()->getNumResults()),
+ consumer.getLoc(), consumer.getOperation()->getResultTypes(),
+ /*inputs=*/fusedOperands,
+ /*outputBuffers=*/ValueRange{},
+ /*initTensors=*/ValueRange{},
rewriter.getArrayAttr(fusedIndexMaps),
consumer.iterator_types(),
/*doc=*/nullptr,
@@ -812,9 +811,10 @@ struct FuseTensorReshapeOpAsProducer {
}));
LinalgOp fusedOp = createLinalgOpOfSameType(
consumer, rewriter, rewriter.getUnknownLoc(),
- consumerOp->getResultTypes(), fusedOperands,
- rewriter.getI64IntegerAttr(fusedOperands.size()),
- rewriter.getI64IntegerAttr(consumerOp->getNumResults()),
+ consumerOp->getResultTypes(),
+ /*inputs=*/fusedOperands,
+ /*outputBuffers=*/ValueRange{},
+ /*initTensors=*/ValueRange{}, // no init tensors for now.
rewriter.getArrayAttr(indexMapAttrs), consumer.iterator_types(),
/*doc=*/nullptr,
/*library_call=*/nullptr,
@@ -871,10 +871,10 @@ struct FuseTensorReshapeOpAsConsumer {
Operation *producerOp = producer.getOperation();
LinalgOp fusedOp = createLinalgOpOfSameType(
producer, rewriter, rewriter.getUnknownLoc(), consumer.getResultType(),
- producerOp->getOperands(),
- rewriter.getI64IntegerAttr(producerOp->getNumOperands()),
- rewriter.getI64IntegerAttr(1), rewriter.getArrayAttr(indexMapAttrs),
- producer.iterator_types(),
+ /*inputs=*/producerOp->getOperands(),
+ /*outputBuffers=*/ValueRange{},
+ /*initTensors=*/ValueRange{}, // no init tensors for now.
+ rewriter.getArrayAttr(indexMapAttrs), producer.iterator_types(),
/*doc=*/nullptr,
/*library_call=*/nullptr,
/*symbol_source=*/nullptr);
@@ -932,10 +932,10 @@ struct FuseTensorReshapeOpAsConsumer {
}
int rank = dstShape.size();
- int numArgsIn = producer.getNumInputs();
- int numArgsOut = producer.getNumOutputs();
auto genericOp = rewriter.create<linalg::GenericOp>(
- loc, resultTypes, args, numArgsIn, numArgsOut,
+ loc, resultTypes, /*inputs=*/args,
+ /*outputBuffers=*/ValueRange{},
+ /*initTensors=*/ValueRange{},
SmallVector<AffineMap, 3>(args.size() + resultTypes.size(),
rewriter.getMultiDimIdentityMap(rank)),
SmallVector<StringRef, 3>(rank, getParallelIteratorTypeName()));
@@ -995,9 +995,10 @@ struct FuseConstantOpAsProducer {
LinalgOp fusedOp = createLinalgOpOfSameType(
consumer, rewriter, rewriter.getUnknownLoc(),
- consumerOp->getResultTypes(), fusedOperands,
- rewriter.getI64IntegerAttr(consumerOp->getNumOperands() - 1),
- rewriter.getI64IntegerAttr(consumerOp->getNumResults()),
+ consumerOp->getResultTypes(),
+ /*inputs=*/fusedOperands,
+ /*outputBuffers=*/ValueRange{},
+ /*initTensors=*/ValueRange{}, // no init tensors for now.
rewriter.getAffineMapArrayAttr(fusedIndexMaps),
consumer.iterator_types(),
/*doc=*/nullptr,
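
The tensor-level fusion paths above all reduce to the same call shape: the fused
operands become the inputs, and both the output-buffer and init-tensor ranges are
left empty, matching the comment that tensor-level fusion does not handle them
yet. A compressed, illustrative sketch of the ArrayAttr-flavoured overload these
patterns use (placeholder names; the trailing doc/library_call/symbol_source
attributes stay null exactly as in the hunks above):

  #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
  #include "mlir/IR/PatternMatch.h"
  using namespace mlir;

  // Sketch only: mirrors the create<GenericOp> calls in the fusion hunks above.
  static Operation *buildFusedGeneric(PatternRewriter &rewriter, Location loc,
                                      ArrayRef<Type> resultTypes,
                                      ValueRange fusedOperands,
                                      ArrayAttr fusedIndexMaps,
                                      ArrayAttr iteratorTypes) {
    return rewriter
        .create<linalg::GenericOp>(loc, resultTypes,
                                   /*inputs=*/fusedOperands,
                                   /*outputBuffers=*/ValueRange{},
                                   /*initTensors=*/ValueRange{},
                                   fusedIndexMaps, iteratorTypes,
                                   /*doc=*/nullptr,
                                   /*library_call=*/nullptr,
                                   /*symbol_source=*/nullptr)
        .getOperation();
  }
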
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
index 6af0067..7f671fc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
@@ -36,32 +36,45 @@ public:
LogicalResult
matchAndRewrite(linalg::GenericOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
+ linalg::GenericOpAdaptor adaptor(operands,
+ op.getOperation()->getAttrDictionary());
+
+ // TODO: support ops with reduction.
+ if (!op.init_tensors().empty())
+ return failure();
+
+ // All inputs need to be turned into buffers first. Until then, bail out.
+ if (llvm::any_of(adaptor.inputs(),
+ [](Value in) { return !in.getType().isa<MemRefType>(); }))
+ return failure();
+
Location loc = op.getLoc();
- ResultRange results = op.getOperation()->getResults();
- SmallVector<Value, 2> newArgs, newResults;
- newArgs.reserve(operands.size() + results.size());
- newArgs.append(operands.begin(), operands.end());
- newResults.reserve(results.size());
+ SmallVector<Value, 2> outputBuffers, newOutputBuffers;
+ outputBuffers.assign(adaptor.output_buffers().begin(),
+ adaptor.output_buffers().end());
+ newOutputBuffers.reserve(op.getNumOutputs());
+ newOutputBuffers.append(adaptor.output_buffers().begin(),
+ adaptor.output_buffers().end());
// Update all types to memref types.
- for (auto result : results) {
- auto type = result.getType().cast<ShapedType>();
- assert(type && "tensor to buffer conversion expects ranked results");
+ for (Type t : op.getResultTypes()) {
+ auto type = t.cast<ShapedType>();
if (!type.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "dynamic shapes not currently supported");
auto memrefType = MemRefType::get(type.getShape(), type.getElementType());
auto alloc = rewriter.create<AllocOp>(loc, memrefType);
- newArgs.push_back(alloc);
- newResults.push_back(alloc);
+ newOutputBuffers.push_back(alloc);
}
// Generate a new linalg operation that works on buffers.
auto linalgOp = rewriter.create<linalg::GenericOp>(
- loc, llvm::None, newArgs, rewriter.getI64IntegerAttr(operands.size()),
- rewriter.getI64IntegerAttr(results.size()), op.indexing_maps(),
- op.iterator_types(), op.docAttr(), op.library_callAttr(),
- op.symbol_sourceAttr());
+ loc,
+ /*resultTensorTypes=*/ArrayRef<Type>{},
+ /*inputs=*/adaptor.inputs(),
+ /*outputBuffers=*/newOutputBuffers,
+ /*initTensors=*/ValueRange{}, op.indexing_maps(), op.iterator_types(),
+ op.docAttr(), op.library_callAttr(), op.symbol_sourceAttr());
// Create a new block in the region of the new Generic Op.
Block &oldBlock = op.getRegion().front();
@@ -70,23 +83,23 @@ public:
oldBlock.getArgumentTypes());
// Add the result arguments to the new block.
- for (auto result : newResults)
- newBlock->addArgument(
- result.getType().cast<ShapedType>().getElementType());
+ for (Value v : newOutputBuffers)
+ newBlock->addArgument(v.getType().cast<MemRefType>().getElementType());
// Clone the body of the old block to the new block.
BlockAndValueMapping mapping;
for (unsigned i = 0; i < oldBlock.getNumArguments(); i++)
mapping.map(oldBlock.getArgument(i), newBlock->getArgument(i));
+
+ OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToEnd(newBlock);
for (auto &op : oldBlock.getOperations()) {
Operation *clonedOp = rewriter.clone(op, mapping);
mapping.map(op.getResults(), clonedOp->getResults());
}
- // Replace the results of the old Generic Op with the results of the new
- // one.
- rewriter.replaceOp(op, newResults);
+ // Replace the results of the old op with the new output buffers.
+ rewriter.replaceOp(op, newOutputBuffers);
return success();
}
};
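
Condensed, the conversion above does three things: allocate a buffer for every
statically shaped tensor result, rebuild the generic op on buffers with those
allocs as outs and an empty result-type list, and add one block argument per new
buffer before cloning the old body. The allocation step, mirrored as a
standalone, illustrative helper (placeholder name; the static-shape restriction
is the one the pattern itself enforces):

  #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
  #include "mlir/Dialect/StandardOps/IR/Ops.h"
  #include "mlir/Transforms/DialectConversion.h"
  using namespace mlir;

  // Sketch only: mirrors the result-type loop in the hunk above.
  static LogicalResult
  allocateResultBuffers(linalg::GenericOp op,
                        ConversionPatternRewriter &rewriter,
                        SmallVectorImpl<Value> &newOutputBuffers) {
    Location loc = op.getLoc();
    for (Type t : op.getResultTypes()) {
      auto type = t.cast<ShapedType>();
      // Dynamic result shapes would need explicit size computation; bail out
      // as the pattern does.
      if (!type.hasStaticShape())
        return rewriter.notifyMatchFailure(
            op, "dynamic shapes not currently supported");
      auto memrefType = MemRefType::get(type.getShape(), type.getElementType());
      newOutputBuffers.push_back(rewriter.create<AllocOp>(loc, memrefType));
    }
    return success();
  }

The new linalg.generic is then built with an empty resultTensorTypes list, the
converted inputs as ins, and newOutputBuffers as outs, exactly as the hunk shows.
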
diff --git a/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir b/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir
index d437ab1..0fac017 100644
--- a/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir
+++ b/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir
@@ -5,8 +5,6 @@
//===----------------------------------------------------------------------===//
#single_workgroup_reduction_trait = {
- args_in = 1,
- args_out = 1,
iterator_types = ["reduction"],
indexing_maps = [
affine_map<(i) -> (i)>,
@@ -49,11 +47,13 @@ module attributes {
func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>) attributes {
spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}
} {
- linalg.generic #single_workgroup_reduction_trait %input, %output {
+ linalg.generic #single_workgroup_reduction_trait
+ ins(%input : memref<16xi32>)
+ outs(%output : memref<1xi32>) {
^bb(%in: i32, %out: i32):
%sum = addi %in, %out : i32
linalg.yield %sum : i32
- } : memref<16xi32>, memref<1xi32>
+ }
spv.Return
}
}
@@ -63,8 +63,6 @@ func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>)
// Missing shader entry point ABI
#single_workgroup_reduction_trait = {
- args_in = 1,
- args_out = 1,
iterator_types = ["reduction"],
indexing_maps = [
affine_map<(i) -> (i)>,
@@ -78,11 +76,13 @@ module attributes {
} {
func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>) {
// expected-error @+1 {{failed to legalize operation 'linalg.generic'}}
- linalg.generic #single_workgroup_reduction_trait %input, %output {
+ linalg.generic #single_workgroup_reduction_trait
+ ins(%input : memref<16xi32>)
+ outs(%output : memref<1xi32>) {
^bb(%in: i32, %out: i32):
%sum = addi %in, %out : i32
linalg.yield %sum : i32
- } : memref<16xi32>, memref<1xi32>
+ }
return
}
}
@@ -92,8 +92,6 @@ func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>)
// Mismatch between shader entry point ABI and input memref shape
#single_workgroup_reduction_trait = {
- args_in = 1,
- args_out = 1,
iterator_types = ["reduction"],
indexing_maps = [
affine_map<(i) -> (i)>,
@@ -109,11 +107,13 @@ func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>)
spv.entry_point_abi = {local_size = dense<[32, 1, 1]>: vector<3xi32>}
} {
// expected-error @+1 {{failed to legalize operation 'linalg.generic'}}
- linalg.generic #single_workgroup_reduction_trait %input, %output {
+ linalg.generic #single_workgroup_reduction_trait
+ ins(%input : memref<16xi32>)
+ outs(%output : memref<1xi32>) {
^bb(%in: i32, %out: i32):
%sum = addi %in, %out : i32
linalg.yield %sum : i32
- } : memref<16xi32>, memref<1xi32>
+ }
spv.Return
}
}
@@ -123,8 +123,6 @@ func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>)
// Unsupported multi-dimension input memref
#single_workgroup_reduction_trait = {
- args_in = 1,
- args_out = 1,
iterator_types = ["parallel", "reduction"],
indexing_maps = [
affine_map<(i, j) -> (i, j)>,
@@ -140,11 +138,13 @@ func @single_workgroup_reduction(%input: memref<16x8xi32>, %output: memref<16xi3
spv.entry_point_abi = {local_size = dense<[16, 8, 1]>: vector<3xi32>}
} {
// expected-error @+1 {{failed to legalize operation 'linalg.generic'}}
- linalg.generic #single_workgroup_reduction_trait %input, %output {
+ linalg.generic #single_workgroup_reduction_trait
+ ins(%input : memref<16x8xi32>)
+ outs(%output : memref<16xi32>) {
^bb(%in: i32, %out: i32):
%sum = addi %in, %out : i32
linalg.yield %sum : i32
- } : memref<16x8xi32>, memref<16xi32>
+ }
spv.Return
}
}
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 23e9f17..5e0890f 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -182,8 +182,6 @@ func @no_fold_memref_reshape(%arg0 : memref<?x?xf32>) -> memref<?x?xf32>
]
#trait = {
- args_in = 1,
- args_out = 1,
indexing_maps = #accesses,
iterator_types = ["parallel"]
}
@@ -193,10 +191,10 @@ func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf
linalg.copy(%arg0, %arg0): memref<0xf32>, memref<0xf32>
// tensor<0xf32> cannot be dce'ed
- %1 = linalg.generic #trait %arg1 {
+ %1 = linalg.generic #trait ins(%arg1 : tensor<0xf32>) {
^bb(%0: f32) :
linalg.yield %0 : f32
- } : tensor<0xf32> -> tensor<0xf32>
+ } -> tensor<0xf32>
return %1: tensor<0xf32>
}
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index c4ea61a..5e30b52 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -6,8 +6,6 @@
]
#trait = {
- args_in = 1,
- args_out = 1,
iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
indexing_maps = #accesses,
library_call = "some_external_func"
@@ -15,10 +13,11 @@
func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>) -> tensor<?x1x?x1x?xf32>
{
- %0 = linalg.generic #trait %arg0 {
+ %0 = linalg.generic #trait
+ ins(%arg0 : tensor<?x1x?xf32>) {
^bb0(%arg1 : f32) :
linalg.yield %arg1 : f32
- } : tensor<?x1x?xf32> -> tensor<?x1x?x1x?xf32>
+ } -> tensor<?x1x?x1x?xf32>
return %0 : tensor<?x1x?x1x?xf32>
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -40,8 +39,6 @@ func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>) -> tensor<?x1x?x1x?xf32>
#map0 = affine_map<(i, j) -> (i, j)>
#access = [#map0, #map0]
#trait = {
- args_in = 1,
- args_out = 1,
iterator_types = ["parallel", "parallel"],
indexing_maps = #access,
library_call = "some_external_func"
@@ -49,10 +46,11 @@ func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>) -> tensor<?x1x?x1x?xf32>
func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32>
{
- %0 = linalg.generic #trait %arg0 {
+ %0 = linalg.generic #trait
+ ins(%arg0 : tensor<1x1xf32>) {
^bb0(%arg1: f32) :
linalg.yield %arg1 : f32
- } : tensor<1x1xf32> -> tensor<1x1xf32>
+ } -> tensor<1x1xf32>
return %0 : tensor<1x1xf32>
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<() -> ()>
@@ -70,18 +68,17 @@ func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32>
]
#trait = {
- args_in = 1,
- args_out = 1,
indexing_maps = #accesses,
iterator_types = ["parallel"],
library_call = "some_external_fn"
}
func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>) -> tensor<5xf32> {
- %0 = linalg.generic #trait %arg0 {
+ %0 = linalg.generic #trait
+ ins(%arg0 : tensor<1x5xf32>) {
^bb0(%arg2: f32): // no predecessors
linalg.yield %arg2 : f32
- } : tensor<1x5xf32> -> tensor<5xf32>
+ } -> tensor<5xf32>
return %0 : tensor<5xf32>
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
@@ -100,8 +97,6 @@ func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>) -> tensor<5xf32> {
]
#trait = {
- args_in = 2,
- args_out = 1,
indexing_maps = #accesses,
iterator_types = ["parallel", "parallel"],
library_call = "some_external_fn"
@@ -113,11 +108,12 @@ func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) -> tensor<5x5
tensor<5xf32> into tensor<1x5xf32>
%1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
tensor<5xf32> into tensor<5x1xf32>
- %2 = linalg.generic #trait %0, %1 {
+ %2 = linalg.generic #trait
+ ins(%0, %1 : tensor<1x5xf32>, tensor<5x1xf32>) {
^bb0(%arg2: f32, %arg3: f32):
%3 = addf %arg2, %arg3 : f32
linalg.yield %3 : f32
- } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
+ } -> tensor<5x5xf32>
return %2 : tensor<5x5xf32>
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1)>
@@ -138,8 +134,6 @@ func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) -> tensor<5x5
]
#trait = {
- args_in = 1,
- args_out = 1,
indexing_maps = #accesses,
iterator_types = ["parallel", "parallel"],
library_call = "some_external_fn"
@@ -147,10 +141,11 @@ func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) -> tensor<5x5
func @broadcast_scalar(%arg0 : tensor<1x1xf32>) -> tensor<?x?xf32>
{
- %0 = linalg.generic #trait %arg0 {
- ^bb0(%arg1 : f32):
- linalg.yield %arg1 : f32
- } : tensor<1x1xf32> -> tensor<?x?xf32>
+ %0 = linalg.generic #trait
+ ins(%arg0 : tensor<1x1xf32>) {
+ ^bb0(%arg1 : f32):
+ linalg.yield %arg1 : f32
+ } -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> ()>
diff --git a/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir b/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir
index 14c0797..6d75c48 100644
--- a/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir
+++ b/mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir
@@ -6,8 +6,6 @@
]
#trait = {
- args_in = 1,
- args_out = 1,
iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
indexing_maps = #accesses,
library_call = "some_external_func"
@@ -15,10 +13,11 @@
func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>) -> tensor<?x1x?x1x?xf32>
{
- %0 = linalg.generic #trait %arg0 {
+ %0 = linalg.generic #trait
+ ins(%arg0 : tensor<?x1x?xf32>) {
^bb0(%arg1 : f32) :
linalg.yield %arg1 : f32
- } : tensor<?x1x?xf32> -> tensor<?x1x?x1x?xf32>
+ } -> tensor<?x1x?x1x?xf32>
return %0 : tensor<?x1x?x1x?xf32>
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d2)>
@@ -33,8 +32,6 @@ func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>) -> tensor<?x1x?x1x?xf32>
#map0 = affine_map<(i, j) -> (i, j)>
#access = [#map0, #map0]
#trait = {
- args_in = 1,
- args_out = 1,
iterator_types = ["parallel", "parallel"],
indexing_maps = #access,
library_call = "some_external_func"
@@ -42,10 +39,11 @@ func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>) -> tensor<?x1x?x1x?xf32>
func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32>
{
- %0 = linalg.generic #trait %arg0 {
+ %0 = linalg.generic #trait
+ ins(%arg0 : tensor<1x1xf32>) {
^bb0(%arg1: f32) :
linalg.yield %arg1 : f32
- } : tensor<1x1xf32> -> tensor<1x1xf32>
+ } -> tensor<1x1xf32>
return %0 : tensor<1x1xf32>
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<() -> (0, 0)>
@@ -59,8 +57,6 @@ func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32>
#map0 = affine_map<(i, j) -> (i, j)>
#access = [#map0, #map0]
#trait = {
- args_in = 1,
- args_out = 1,
iterator_types = ["parallel", "parallel"],
indexing_maps = #access,
library_call = "some_external_func"
@@ -68,10 +64,12 @@ func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32>
func @drop_all_loops(%arg0 : memref<1x1xf32>, %arg1 : memref<1x1xf32>)
{
- linalg.generic #trait %arg0, %arg1 {
+ linalg.generic #trait
+ ins(%arg0 : memref<1x1xf32>)
+ outs(%arg1 : memref<1x1xf32>) {
^bb0(%arg2: f32, %arg3 : f32) :
linalg.yield %arg2 : f32
- } : memref<1x1xf32>, memref<1x1xf32>
+ }
return
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<() -> (0, 0)>
@@ -88,18 +86,17 @@ func @drop_all_loops(%arg0 : memref<1x1xf32>, %arg1 : memref<1x1xf32>)
]
#trait = {
- args_in = 1,
- args_out = 1,
indexing_maps = #accesses,
iterator_types = ["parallel", "parallel"],
library_call = "some_external_fn"
}
func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>) -> tensor<5xf32> {
- %0 = linalg.generic #trait %arg0 {
- ^bb0(%arg2: f32): // no predecessors
- linalg.yield %arg2 : f32
- } : tensor<1x5xf32> -> tensor<5xf32>
+ %0 = linalg.generic #trait
+ ins(%arg0 : tensor<1x5xf32>) {
+ ^bb0(%arg2: f32): // no predecessors
+ linalg.yield %arg2 : f32
+ } -> tensor<5xf32>
return %0 : tensor<5xf32>
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> (0, d0)>
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
index ac2d3e2..ccadff5 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -6,14 +6,16 @@
// CHECK-LABEL: @add_mul_fusion
func @add_mul_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
{
- %0 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} %arg0, %arg1 {
+ %0 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
%1 = addf %arg3, %arg4 : f32
linalg.yield %1 : f32
- }: tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
- // CHECK: linalg.generic {args_in = 3 : i64, args_out = 1 : i64
+ } -> tensor<?x?xf32>
+ // CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = {{\[}}[[$MAP0]], [[$MAP0]], [[$MAP0]], [[$MAP0]]{{\]}}
- %2 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} %0, %arg2 {
+ %2 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
+ ins(%0, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>) {
// CHECK: ^{{[a-zA-Z0-9_]*}}
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]
@@ -25,7 +27,7 @@ func @add_mul_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : te
// CHECK: linalg.yield
%3 = mulf %arg5, %arg6 : f32
linalg.yield %3 : f32
- }: tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+ } -> tensor<?x?xf32>
return %2 : tensor<?x?xf32>
}
@@ -39,18 +41,20 @@ func @add_mul_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : te
// CHECK-LABEL: @transpose_add_mul_fusion
func @transpose_add_mul_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
{
- %0 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} %arg0, %arg1 {
+ %0 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
%1 = addf %arg3, %arg4 : f32
linalg.yield %1 : f32
- }: tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
- // CHECK: linalg.generic {args_in = 3 : i64, args_out = 1 : i64
+ } -> tensor<?x?xf32>
+ // CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = {{\[}}[[$MAP0]], [[$MAP1]], [[$MAP0]], [[$MAP0]]{{\]}}
- %2 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} %0, %arg2 {
+ %2 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
+ ins(%0, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>) {
^bb0(%arg5: f32, %arg6: f32): // no predecessors
%3 = mulf %arg5, %arg6 : f32
linalg.yield %3 : f32
- }: tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+ } -> tensor<?x?xf32>
return %2 : tensor<?x?xf32>
}
@@ -64,18 +68,20 @@ func @transpose_add_mul_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
// CHECK-LABEL: @add_transpose_mul_fusion
func @add_transpose_mul_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
{
- %0 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} %arg0, %arg1 {
+ %0 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
%1 = addf %arg3, %arg4 : f32
linalg.yield %1 : f32
- }: tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
- // CHECK: linalg.generic {args_in = 3 : i64, args_out = 1 : i64
+ } -> tensor<?x?xf32>
+ // CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = {{\[}}[[$MAP1]], [[$MAP0]], [[$MAP0]], [[$MAP0]]{{\]}}
- %2 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map1, #map0, #map0], iterator_types = ["parallel", "parallel"]} %0, %arg2 {
+ %2 = linalg.generic {indexing_maps = [#map1, #map0, #map0], iterator_types = ["parallel", "parallel"]}
+ ins(%0, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>) {
^bb0(%arg5: f32, %arg6: f32): // no predecessors
%3 = mulf %arg5, %arg6 : f32
linalg.yield %3 : f32
- }: tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+ } -> tensor<?x?xf32>
return %2 : tensor<?x?xf32>
}
@@ -90,18 +96,20 @@ func @add_transpose_mul_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
// CHECK-LABEL: @add_broadcast_mul_fusion
func @add_broadcast_mul_fusion(%arg0: tensor<?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
{
- %0 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel"]} %arg0, %arg1 {
+ %0 = linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel"]}
+ ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>) {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
%1 = addf %arg3, %arg4 : f32
linalg.yield %1 : f32
- }: tensor<?xf32>, tensor<?xf32> -> tensor<?xf32>
- // CHECK: linalg.generic {args_in = 3 : i64, args_out = 1 : i64
+ } -> tensor<?xf32>
+ // CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = {{\[}}[[$MAP1]], [[$MAP1]], [[$MAP0]], [[$MAP0]]
- %2 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map1, #map0, #map0], iterator_types = ["parallel", "parallel"]} %0, %arg2 {
+ %2 = linalg.generic {indexing_maps = [#map1, #map0, #map0], iterator_types = ["parallel", "parallel"]}
+ ins(%0, %arg2 : tensor<?xf32>, tensor<?x?xf32>) {
^bb0(%arg5: f32, %arg6: f32): // no predecessors
%3 = mulf %arg5, %arg6 : f32
linalg.yield %3 : f32
- }: tensor<?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+ } -> tensor<?x?xf32>
return %2 : tensor<?x?xf32>
}
@@ -113,19 +121,21 @@ func @add_broadcast_mul_fusion(%arg0: tensor<?xf32>, %arg1 : tensor<?xf32>, %arg
// CHECK-LABEL: @add_mul_scalar_fusion
func @add_mul_scalar_fusion(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32>
{
- %0 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = []} %arg0, %arg1 {
+ %0 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = []}
+ ins(%arg0, %arg1 : tensor<f32>, tensor<f32>) {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
%1 = addf %arg3, %arg4 : f32
linalg.yield %1 : f32
- }: tensor<f32>, tensor<f32> -> tensor<f32>
- // CHECK: linalg.generic {args_in = 3 : i64, args_out = 1 : i64
+ } -> tensor<f32>
+ // CHECK: linalg.generic {
// CHECK: addf
// CHECK: mulf
- %1 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = []} %0, %arg2 {
+ %1 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = []}
+ ins(%0, %arg2 : tensor<f32>, tensor<f32>) {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
%1 = mulf %arg3, %arg4 : f32
linalg.yield %1 : f32
- }: tensor<f32>, tensor<f32> -> tensor<f32>
+ } -> tensor<f32>
return %1 : tensor<f32>
}
@@ -144,25 +154,23 @@ func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xf32>,
affine_map<(i, j, k, l) -> (j, k)>,
affine_map<(i, j, k, l) -> (l)>] :
tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
- %1 = linalg.generic
- {args_in = 2 : i64, args_out = 1 : i64,
+ %1 = linalg.generic {
indexing_maps = [#map0, #map0, #map0],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
- %0, %arg1 {
+ ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
%1 = mulf %arg3, %arg4 : f32
linalg.yield %1 : f32
- }: tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32> -> tensor<?x?x4x?xf32>
+ } -> tensor<?x?x4x?xf32>
return %1 : tensor<?x?x4x?xf32>
}
// CHECK-LABEL: func @generic_op_reshape_producer_fusion
// CHECK: linalg.generic
-// CHECK-SAME: args_in = 2
-// CHECK-SAME: args_out = 1
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]]
// CHECK-NOT: linalg.generic
+
// -----
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
@@ -173,15 +181,14 @@ func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xf32>,
%arg1 : tensor<?x?x4x5xf32>) ->
tensor<?x?xf32>
{
- %0 = linalg.generic
- {args_in = 2 : i64, args_out = 1 : i64,
+ %0 = linalg.generic {
indexing_maps = [#map0, #map0, #map0],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
- %arg0, %arg1 {
+ ins(%arg0, %arg1 : tensor<?x?x4x5xf32>, tensor<?x?x4x5xf32>) {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
%1 = mulf %arg3, %arg4 : f32
linalg.yield %1 : f32
- }: tensor<?x?x4x5xf32>, tensor<?x?x4x5xf32> -> tensor<?x?x4x5xf32>
+ } -> tensor<?x?x4x5xf32>
%1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
affine_map<(i, j, k, l) -> (j, k, l)>] :
tensor<?x?x4x5xf32> into tensor<?x?xf32>
@@ -190,8 +197,6 @@ func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xf32>,
// CHECK-LABEL: func @generic_op_reshape_consumer_fusion
// CHECK: linalg.generic
-// CHECK-SAME: args_in = 2
-// CHECK-SAME: args_out = 1
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP1]]]
// CHECK-NOT: linalg.generic
@@ -202,15 +207,14 @@ func @generic_op_reshape_consumer_nofusion(%arg0 : tensor<?x?x?x5xf32>,
%arg1 : tensor<?x?x?x5xf32>) ->
tensor<?x?xf32>
{
- %0 = linalg.generic
- {args_in = 2 : i64, args_out = 1 : i64,
+ %0 = linalg.generic {
indexing_maps = [#map0, #map0, #map0],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
- %arg0, %arg1 {
+ ins(%arg0, %arg1 : tensor<?x?x?x5xf32>, tensor<?x?x?x5xf32>) {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
%1 = mulf %arg3, %arg4 : f32
linalg.yield %1 : f32
- }: tensor<?x?x?x5xf32>, tensor<?x?x?x5xf32> -> tensor<?x?x?x5xf32>
+ } -> tensor<?x?x?x5xf32>
%1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
affine_map<(i, j, k, l) -> (j, k, l)>] :
tensor<?x?x?x5xf32> into tensor<?x?xf32>
@@ -229,15 +233,14 @@ func @generic_op_reshape_consumer_nofusion(%arg0 : tensor<?x?x?x5xf32>,
func @generic_op_reshape_consumer_expanding(%arg0: tensor<264x4xf32>)
-> tensor<8x33x4xf32> {
%cst = constant dense<2.000000e+00> : tensor<264x4xf32>
- %0 = linalg.generic
- {args_in = 2 : i64, args_out = 1 : i64,
+ %0 = linalg.generic {
indexing_maps = [#map0, #map0, #map0],
iterator_types = ["parallel", "parallel"]}
- %arg0, %cst {
+ ins(%arg0, %cst : tensor<264x4xf32>, tensor<264x4xf32>) {
^bb0(%arg1: f32, %arg2: f32): // no predecessors
%2 = mulf %arg1, %arg2 : f32
linalg.yield %2 : f32
- }: tensor<264x4xf32>, tensor<264x4xf32> -> tensor<264x4xf32>
+ } -> tensor<264x4xf32>
%1 = linalg.tensor_reshape %0 [#map1, #map2] :
tensor<264x4xf32> into tensor<8x33x4xf32>
return %1 : tensor<8x33x4xf32>
@@ -251,7 +254,8 @@ func @generic_op_reshape_consumer_expanding(%arg0: tensor<264x4xf32>)
// CHECK: %[[CST:.*]] = constant {{.*}} : f32
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
-// CHECK: tensor<264x4xf32> -> tensor<8x33x4xf32>
+// CHECK-SAME: tensor<264x4xf32>
+// CHECK: -> tensor<8x33x4xf32>
// CHECK-NOT: linalg.tensor_reshape
// -----
@@ -261,23 +265,20 @@ func @generic_op_reshape_consumer_expanding(%arg0: tensor<264x4xf32>)
func @generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
{
%0 = constant dense<42.0> : tensor<5xf32>
- %1 = linalg.generic
- {args_in = 2 : i64, args_out = 1 : i64,
+ %1 = linalg.generic {
indexing_maps = [#map0, #map1, #map1],
iterator_types = ["parallel", "parallel", "parallel"]}
- %0, %arg0 {
+ ins(%0, %arg0 : tensor<5xf32>, tensor<5x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%2 = mulf %arg1, %arg2 : f32
linalg.yield %2 : f32
- }: tensor<5xf32>, tensor<5x?x?xf32> -> tensor<5x?x?xf32>
+ } -> tensor<5x?x?xf32>
return %1 : tensor<5x?x?xf32>
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @generic_op_constant_fusion
// CHECK: %[[CST:.*]] = constant {{.*}} : f32
// CHECK: linalg.generic
-// CHECK-SAME: args_in = 1 : i64
-// CHECK-SAME: args_out = 1 : i64
// CHECK: ^{{.*}}(%[[ARG1:.*]]: f32)
// CHECK: mulf %[[CST]], %[[ARG1]]
@@ -289,23 +290,20 @@ func @indexed_generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>)
-> tensor<5x?x?xf32>
{
%0 = constant dense<42.0> : tensor<5xf32>
- %1 = linalg.indexed_generic
- {args_in = 2 : i64, args_out = 1 : i64,
+ %1 = linalg.indexed_generic {
indexing_maps = [#map0, #map1, #map1],
iterator_types = ["parallel", "parallel", "parallel"]}
- %0, %arg0 {
+ ins(%0, %arg0 : tensor<5xf32>, tensor<5x?x?xf32>) {
^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: f32, %arg5 : f32):
%2 = mulf %arg4, %arg5 : f32
linalg.yield %2 : f32
- }: tensor<5xf32>, tensor<5x?x?xf32> -> tensor<5x?x?xf32>
+ } -> tensor<5x?x?xf32>
return %1 : tensor<5x?x?xf32>
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @indexed_generic_op_constant_fusion
// CHECK: %[[CST:.*]] = constant {{.*}} : f32
// CHECK: linalg.indexed_generic
-// CHECK-SAME: args_in = 1 : i64
-// CHECK-SAME: args_out = 1 : i64
// CHECK: ^{{[a-zA-Z0-9_]*}}
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]*]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]*]]: index
@@ -321,23 +319,20 @@ func @generic_op_zero_dim_constant_fusion(%arg0 : tensor<5x?x?xf32>)
-> tensor<5x?x?xf32>
{
%0 = constant dense<42.0> : tensor<f32>
- %1 = linalg.generic
- {args_in = 2 : i64, args_out = 1 : i64,
+ %1 = linalg.generic {
indexing_maps = [#map0, #map1, #map1],
iterator_types = ["parallel", "parallel", "parallel"]}
- %0, %arg0 {
+ ins(%0, %arg0 : tensor<f32>, tensor<5x?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%2 = mulf %arg1, %arg2 : f32
linalg.yield %2 : f32
- }: tensor<f32>, tensor<5x?x?xf32> -> tensor<5x?x?xf32>
+ } -> tensor<5x?x?xf32>
return %1 : tensor<5x?x?xf32>
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @generic_op_zero_dim_constant_fusion
// CHECK: %[[CST:.*]] = constant {{.*}} : f32
// CHECK: linalg.generic
-// CHECK-SAME: args_in = 1 : i64
-// CHECK-SAME: args_out = 1 : i64
// CHECK: ^{{.*}}(%[[ARG1:.*]]: f32)
// CHECK: mulf %[[CST]], %[[ARG1]]
@@ -349,23 +344,20 @@ func @indexed_generic_op_zero_dim_constant_fusion
(%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
{
%0 = constant dense<42.0> : tensor<f32>
- %1 = linalg.indexed_generic
- {args_in = 2 : i64, args_out = 1 : i64,
+ %1 = linalg.indexed_generic {
indexing_maps = [#map0, #map1, #map1],
iterator_types = ["parallel", "parallel", "parallel"]}
- %0, %arg0 {
+ ins(%0, %arg0 : tensor<f32>, tensor<5x?x?xf32>) {
^bb0(%arg1 : index, %arg2 : index, %arg3 : index, %arg4: f32, %arg5: f32):
%2 = mulf %arg4, %arg5 : f32
linalg.yield %2 : f32
- }: tensor<f32>, tensor<5x?x?xf32> -> tensor<5x?x?xf32>
+ } -> tensor<5x?x?xf32>
return %1 : tensor<5x?x?xf32>
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @indexed_generic_op_zero_dim_constant_fusion
// CHECK: %[[CST:.*]] = constant {{.*}} : f32
// CHECK: linalg.indexed_generic
-// CHECK-SAME: args_in = 1 : i64
-// CHECK-SAME: args_out = 1 : i64
// CHECK: ^{{[a-zA-Z0-9_]*}}
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]*]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]*]]: index
@@ -379,34 +371,30 @@ func @indexed_generic_op_zero_dim_constant_fusion
func @generic_op_indexed_generic_op_fusion(%arg0: tensor<?x?xi32>,
%arg1: tensor<?x?xi32>) {
%0 = linalg.generic {
- args_in = 2 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0, #map0],
- iterator_types = ["parallel", "parallel"] } %arg0, %arg1 {
+ iterator_types = ["parallel", "parallel"] }
+ ins(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>) {
^bb0(%arg2: i32, %arg3: i32): // no predecessors
%10 = addi %arg2, %arg3 : i32
linalg.yield %10 : i32
- } : tensor<?x?xi32>, tensor<?x?xi32> -> tensor<?x?xi32>
+ } -> tensor<?x?xi32>
%1 = linalg.indexed_generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel", "parallel"] } %0 {
+ iterator_types = ["parallel", "parallel"] }
+ ins(%0 : tensor<?x?xi32>) {
^bb0(%arg2: index, %arg3: index, %arg4: i32): // no predecessors
%2 = index_cast %arg2 : index to i32
%3 = index_cast %arg3 : index to i32
%4 = addi %arg4, %2 : i32
%5 = subi %4, %3 : i32
linalg.yield %5 : i32
- }: tensor<?x?xi32> -> tensor<?x?xi32>
+ } -> tensor<?x?xi32>
return
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @generic_op_indexed_generic_op_fusion
// CHECK-NOT: linalg.generic
// CHECK: linalg.indexed_generic
-// CHECK-SAME: args_in = 2
-// CHECK-SAME: args_out = 1
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]]
// CHECK: ^{{[a-zA-Z0-9_]*}}
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: index
@@ -426,33 +414,29 @@ func @generic_op_indexed_generic_op_fusion(%arg0: tensor<?x?xi32>,
func @indexed_generic_op_generic_op_fusion(%arg0: tensor<?x?xi32>,
%arg1: tensor<?x?xi32>) {
%0 = linalg.indexed_generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel", "parallel"] } %arg0 {
+ iterator_types = ["parallel", "parallel"] }
+ ins(%arg0 : tensor<?x?xi32>) {
^bb0(%arg2: index, %arg3: index, %arg4: i32): // no predecessors
%2 = index_cast %arg2 : index to i32
%3 = index_cast %arg3 : index to i32
%4 = addi %arg4, %2 : i32
%5 = subi %4, %3 : i32
linalg.yield %5 : i32
- }: tensor<?x?xi32> -> tensor<?x?xi32>
+ } -> tensor<?x?xi32>
%1 = linalg.generic {
- args_in = 2 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0, #map0],
- iterator_types = ["parallel", "parallel"] } %0, %arg1 {
+ iterator_types = ["parallel", "parallel"] }
+ ins(%0, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>) {
^bb0(%arg2: i32, %arg3: i32): // no predecessors
%10 = addi %arg2, %arg3 : i32
linalg.yield %10 : i32
- } : tensor<?x?xi32>, tensor<?x?xi32> -> tensor<?x?xi32>
+ } -> tensor<?x?xi32>
return
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @indexed_generic_op_generic_op_fusion
// CHECK: linalg.indexed_generic
-// CHECK-SAME: args_in = 2
-// CHECK-SAME: args_out = 1
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]]
// CHECK: ^{{[a-zA-Z0-9_]*}}
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: index
@@ -474,36 +458,32 @@ func @indexed_generic_op_generic_op_fusion(%arg0: tensor<?x?xi32>,
#map1 = affine_map<(d0, d1) -> (d0, d1)>
func @indexed_generic_op_fusion(%arg0: tensor<?x?xi32>) {
%0 = linalg.indexed_generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel", "parallel"] } %arg0 {
+ iterator_types = ["parallel", "parallel"] }
+ ins(%arg0 : tensor<?x?xi32>) {
^bb0(%arg2: index, %arg3: index, %arg4: i32): // no predecessors
%2 = index_cast %arg2 : index to i32
%3 = index_cast %arg3 : index to i32
%4 = addi %arg4, %2 : i32
%5 = subi %4, %3 : i32
linalg.yield %5 : i32
- }: tensor<?x?xi32> -> tensor<?x?xi32>
+ } -> tensor<?x?xi32>
%1 = linalg.indexed_generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map1, #map1],
- iterator_types = ["parallel", "parallel"] } %0 {
+ iterator_types = ["parallel", "parallel"] }
+ ins(%0 : tensor<?x?xi32>) {
^bb0(%arg2: index, %arg3: index, %arg4: i32): // no predecessors
%2 = index_cast %arg2 : index to i32
%3 = index_cast %arg3 : index to i32
%4 = addi %arg4, %2 : i32
%5 = subi %4, %3 : i32
linalg.yield %5 : i32
- }: tensor<?x?xi32> -> tensor<?x?xi32>
+ } -> tensor<?x?xi32>
return
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @indexed_generic_op_fusion
// CHECK: linalg.indexed_generic
-// CHECK-SAME: args_in = 1
-// CHECK-SAME: args_out = 1
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
// CHECK: ^{{[a-zA-Z0-9_]*}}
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: index
@@ -533,23 +513,20 @@ func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
affine_map<(i, j, k, l) -> (l)>] :
tensor<?x?x?xi32> into tensor<?x?x4x?xi32>
%1 = linalg.indexed_generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel", "parallel", "parallel", "parallel"] } %0 {
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"] }
+ ins(%0 : tensor<?x?x4x?xi32>) {
^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32): // no predecessors
%2 = index_cast %arg2 : index to i32
%3 = addi %arg6, %2 : i32
linalg.yield %3 : i32
- }: tensor<?x?x4x?xi32> -> tensor<?x?x4x?xi32>
+ } -> tensor<?x?x4x?xi32>
return %1 : tensor<?x?x4x?xi32>
}
// CHECK-LABEL: func @indexed_generic_op_reshape_producer_fusion
// CHECK-NOT: linalg.tensor_reshape
// CHECK: linalg.indexed_generic
-// CHECK-SAME: args_in = 1
-// CHECK-SAME: args_out = 1
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
// CHECK-NOT: linalg.tensor_reshape
@@ -562,15 +539,14 @@ func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
-> tensor<?x?xi32> {
%0 = linalg.indexed_generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel", "parallel", "parallel", "parallel"] } %arg0 {
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"] }
+ ins(%arg0 : tensor<?x?x4x5xi32>) {
^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32): // no predecessors
%2 = index_cast %arg2 : index to i32
%3 = addi %arg6, %2 : i32
linalg.yield %3 : i32
- }: tensor<?x?x4x5xi32> -> tensor<?x?x4x5xi32>
+ } -> tensor<?x?x4x5xi32>
%1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
affine_map<(i, j, k, l) -> (j, k, l)>] :
tensor<?x?x4x5xi32> into tensor<?x?xi32>
@@ -580,7 +556,5 @@ func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
// CHECK-LABEL: func @indexed_generic_op_reshape_consumer_fusion
// CHECK-NOT: linalg.tensor_reshape
// CHECK: linalg.indexed_generic
-// CHECK-SAME: args_in = 1
-// CHECK-SAME: args_out = 1
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
// CHECK-NOT: linalg.tensor_reshape
diff --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir
index 38e1b43..788cb89 100644
--- a/mlir/test/Dialect/Linalg/fusion.mlir
+++ b/mlir/test/Dialect/Linalg/fusion.mlir
@@ -470,8 +470,6 @@ func @f8(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
#id_2d = affine_map<(i, j) -> (i, j)>
#pointwise_2d_trait = {
- args_in = 2,
- args_out = 1,
indexing_maps = [#id_2d, #id_2d, #id_2d],
iterator_types = ["parallel", "parallel"]
}
@@ -483,13 +481,14 @@ func @pointwise(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
%c0 = constant 0 : index
%c3 = constant 3 : index
%c2 = constant 2 : index
- linalg.generic #pointwise_2d_trait %A, %A, %B {
+ linalg.generic #pointwise_2d_trait
+ ins(%A, %A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
+ memref<?x?xf32, offset: 0, strides: [?, ?]>)
+ outs(%B : memref<?x?xf32, offset: 0, strides: [?, ?]>) {
^bb0(%E: f32, %arg5: f32, %arg6: f32): // no predecessors
%2 = addf %E, %arg5 : f32
linalg.yield %2 : f32
- }: memref<?x?xf32, offset: 0, strides: [?, ?]>,
- memref<?x?xf32, offset: 0, strides: [?, ?]>,
- memref<?x?xf32, offset: 0, strides: [?, ?]>
+ }
%0 = dim %B, %c0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
%1 = dim %B, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
scf.for %arg4 = %c0 to %0 step %c2 {
@@ -503,13 +502,14 @@ func @pointwise(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
%6 = std.subview %D[%arg4, %arg5][%c2, %c3][%c1, %c1] :
memref<?x?xf32, offset: 0, strides: [?, ?]> to
memref<?x?xf32, offset: ?, strides: [?, ?]>
- linalg.generic #pointwise_2d_trait %4, %5, %6 {
+ linalg.generic #pointwise_2d_trait
+ ins(%4, %5: memref<?x?xf32, offset: ?, strides: [?, ?]>,
+ memref<?x?xf32, offset: ?, strides: [?, ?]>)
+ outs(%6 : memref<?x?xf32, offset: ?, strides: [?, ?]>) {
^bb0(%arg6: f32, %arg7: f32, %arg8: f32): // no predecessors
%7 = mulf %arg6, %arg7 : f32
linalg.yield %7 : f32
- }: memref<?x?xf32, offset: ?, strides: [?, ?]>,
- memref<?x?xf32, offset: ?, strides: [?, ?]>,
- memref<?x?xf32, offset: ?, strides: [?, ?]>
+ }
}
}
return
@@ -527,8 +527,6 @@ func @pointwise(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
#id_2d = affine_map<(i, j) -> (i, j)>
#pointwise_2d_trait = {
- args_in = 2,
- args_out = 1,
indexing_maps = [#id_2d, #id_2d, #id_2d],
iterator_types = ["parallel", "parallel"]
}
@@ -542,13 +540,13 @@ func @pointwise_no_view(%M: index, %N: index) {
%C = alloc (%M, %N): memref<?x?xf32>
%D = alloc (%M, %N): memref<?x?xf32>
%E = alloc (%M, %N): memref<?x?xf32>
- linalg.generic #pointwise_2d_trait %A, %A, %B {
+ linalg.generic #pointwise_2d_trait
+ ins(%A, %A : memref<?x?xf32>, memref<?x?xf32>)
+ outs(%B : memref<?x?xf32>) {
^bb0(%e: f32, %arg5: f32, %arg6: f32): // no predecessors
%2 = addf %e, %arg5 : f32
linalg.yield %2 : f32
- }: memref<?x?xf32>,
- memref<?x?xf32>,
- memref<?x?xf32>
+ }
%0 = dim %B, %c0 : memref<?x?xf32>
%1 = dim %B, %c1 : memref<?x?xf32>
scf.for %arg4 = %c0 to %0 step %c2 {
@@ -562,13 +560,14 @@ func @pointwise_no_view(%M: index, %N: index) {
%6 = std.subview %D[%arg4, %arg5][%c2, %c3][%c1, %c1] :
memref<?x?xf32> to
memref<?x?xf32, offset: ?, strides: [?, ?]>
- linalg.generic #pointwise_2d_trait %4, %5, %6 {
+ linalg.generic #pointwise_2d_trait
+ ins(%4, %5: memref<?x?xf32, offset: ?, strides: [?, ?]>,
+ memref<?x?xf32, offset: ?, strides: [?, ?]>)
+ outs(%6 : memref<?x?xf32, offset: ?, strides: [?, ?]>) {
^bb0(%arg6: f32, %arg7: f32, %arg8: f32): // no predecessors
%7 = mulf %arg6, %arg7 : f32
linalg.yield %7 : f32
- }: memref<?x?xf32, offset: ?, strides: [?, ?]>,
- memref<?x?xf32, offset: ?, strides: [?, ?]>,
- memref<?x?xf32, offset: ?, strides: [?, ?]>
+ }
}
}
return
@@ -596,25 +595,23 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
%c1 = constant 1 : index
%0 = alloc() {temp = true} : memref<100x10xf32>
linalg.generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map1],
- iterator_types = ["parallel", "parallel"]
- } %arg1, %0 {
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg1 : memref<100xf32>)
+ outs(%0 : memref<100x10xf32>) {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
linalg.yield %arg3 : f32
- }: memref<100xf32>, memref<100x10xf32>
+ }
%1 = alloc() {temp = true} : memref<100x10xf32>
linalg.generic {
- args_in = 2 : i64,
- args_out = 1 : i64,
indexing_maps = [#map1, #map1, #map1],
- iterator_types = ["parallel", "parallel"]
- } %arg0, %0, %1 {
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %0: memref<100x10xf32>, memref<100x10xf32>)
+ outs(%1 : memref<100x10xf32>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
%2 = subf %arg3, %arg4 : f32
linalg.yield %2 : f32
- }: memref<100x10xf32>, memref<100x10xf32>, memref<100x10xf32>
+ }
dealloc %0 : memref<100x10xf32>
%2 = dim %1, %c0 : memref<100x10xf32>
%3 = dim %1, %c1 : memref<100x10xf32>
@@ -627,16 +624,14 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
%7 = std.subview %arg2[%i, %j][%c1, %c1][%c1, %c1] :
memref<100x10xf32> to memref<?x?xf32, #map2>
linalg.generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map1, #map1],
- iterator_types = ["parallel", "parallel"]
- } %6, %7 {
+ iterator_types = ["parallel", "parallel"]}
+ ins(%6 : memref<?x?xf32, #map2>)
+ outs(%7 : memref<?x?xf32, #map2>) {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
%8 = exp %arg3 : f32
linalg.yield %8 : f32
- }: memref<?x?xf32, #map2>,
- memref<?x?xf32, #map2>
+ }
}
}
dealloc %1 : memref<100x10xf32>
diff --git a/mlir/test/Dialect/Linalg/fusion_indexed_generic.mlir b/mlir/test/Dialect/Linalg/fusion_indexed_generic.mlir
index 984da63..ee25617 100644
--- a/mlir/test/Dialect/Linalg/fusion_indexed_generic.mlir
+++ b/mlir/test/Dialect/Linalg/fusion_indexed_generic.mlir
@@ -3,8 +3,6 @@
#map = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
#id_2d = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {
- args_in = 2,
- args_out = 1,
indexing_maps = [#id_2d, #id_2d, #id_2d],
iterator_types = ["parallel", "parallel"]
}
@@ -12,11 +10,13 @@ func @fuse_indexed_generic_consumer(%A: memref<?x?xf32>,
%B: memref<?x?xf32>,
%C: memref<?x?xf32>,
%D: memref<?x?xf32>) {
- linalg.generic #pointwise_2d_trait %A, %B, %C {
+ linalg.generic #pointwise_2d_trait
+ ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
+ outs(%C : memref<?x?xf32>) {
^bb0(%e: f32, %arg5: f32, %arg6: f32): // no predecessors
%2 = addf %e, %arg5 : f32
linalg.yield %2 : f32
- }: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
+ }
%c1 = constant 1 : index
%c0 = constant 0 : index
%c25 = constant 25 : index
@@ -33,10 +33,9 @@ func @fuse_indexed_generic_consumer(%A: memref<?x?xf32>,
memref<?x?xf32> to memref<?x?xf32, #map>
linalg.indexed_generic {
indexing_maps = [#id_2d, #id_2d],
- iterator_types = ["parallel", "parallel"],
- args_in = 1,
- args_out = 1
- } %4, %5 {
+ iterator_types = ["parallel", "parallel"]}
+ ins(%4 : memref<?x?xf32, #map>)
+ outs(%5 : memref<?x?xf32, #map>) {
^bb0(%arg4: index, %arg5: index, %arg6: f32, %arg7: f32):
%6 = addi %arg4, %arg2 : index
%7 = addi %arg5, %arg3 : index
@@ -46,7 +45,7 @@ func @fuse_indexed_generic_consumer(%A: memref<?x?xf32>,
%11 = sitofp %10 : i32 to f32
%12 = addf %9, %11 : f32
linalg.yield %12 : f32
- }: memref<?x?xf32, #map>, memref<?x?xf32, #map>
+ }
}
}
return
@@ -66,8 +65,6 @@ func @fuse_indexed_generic_consumer(%A: memref<?x?xf32>,
#map = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
#id_2d = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {
- args_in = 2,
- args_out = 1,
indexing_maps = [#id_2d, #id_2d, #id_2d],
iterator_types = ["parallel", "parallel"]
}
@@ -79,14 +76,16 @@ func @fuse_indexed_generic_producer(%A: memref<?x?xf32>,
%c0 = constant 0 : index
%c25 = constant 25 : index
%c10 = constant 10 : index
- linalg.indexed_generic #pointwise_2d_trait %A, %B, %C {
+ linalg.indexed_generic #pointwise_2d_trait
+ ins(%A, %B : memref<?x?xf32>, memref<?x?xf32>)
+ outs(%C : memref<?x?xf32>) {
^bb0(%i: index, %j: index, %a: f32, %b: f32, %c: f32): // no predecessors
%i_int = index_cast %i: index to i32
%i_float = sitofp %i_int : i32 to f32
%ab = addf %a, %b : f32
%out = addf %ab, %i_float : f32
linalg.yield %out : f32
- }: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
+ }
%C_X = dim %C, %c0 : memref<?x?xf32>
%C_Y = dim %C, %c1 : memref<?x?xf32>
%D_X = dim %D, %c0 : memref<?x?xf32>
@@ -98,14 +97,13 @@ func @fuse_indexed_generic_producer(%A: memref<?x?xf32>,
memref<?x?xf32> to memref<?x?xf32, #map>
linalg.generic {
indexing_maps = [#id_2d, #id_2d],
- iterator_types = ["parallel", "parallel"],
- args_in = 1,
- args_out = 1
- } %C_view, %D_view {
+ iterator_types = ["parallel", "parallel"]}
+ ins(%C_view : memref<?x?xf32, #map>)
+ outs(%D_view : memref<?x?xf32, #map>) {
^bb0( %a: f32, %b: f32):
%ab = addf %a, %b : f32
linalg.yield %ab : f32
- }: memref<?x?xf32, #map>, memref<?x?xf32, #map>
+ }
}
return
}
@@ -125,8 +123,6 @@ func @fuse_indexed_generic_producer(%A: memref<?x?xf32>,
#map = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
#id_2d = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {
- args_in = 2,
- args_out = 1,
indexing_maps = [#id_2d, #id_2d, #id_2d],
iterator_types = ["parallel", "parallel"]
}
@@ -137,14 +133,16 @@ func @fuse_indexed_generic_producer_tile_second_dim_only(%A: memref<?x?xf32>,
%c1 = constant 1 : index
%c3 = constant 3 : index
%c0 = constant 0 : index
- linalg.indexed_generic #pointwise_2d_trait %A, %B, %C {
+ linalg.indexed_generic #pointwise_2d_trait
+ ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
+ outs(%C : memref<?x?xf32>) {
^bb0(%i: index, %j: index, %a: f32, %b: f32, %c: f32): // no predecessors
%j_int = index_cast %j: index to i32
%j_float = sitofp %j_int : i32 to f32
%ab = addf %a, %b : f32
%out = addf %ab, %j_float : f32
linalg.yield %out : f32
- }: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
+ }
%C_X = dim %C, %c0 : memref<?x?xf32>
%C_Y = dim %C, %c1 : memref<?x?xf32>
%D_X = dim %D, %c0 : memref<?x?xf32>
@@ -161,14 +159,13 @@ func @fuse_indexed_generic_producer_tile_second_dim_only(%A: memref<?x?xf32>,
linalg.generic {
indexing_maps = [#id_2d, #id_2d],
- iterator_types = ["parallel", "parallel"],
- args_in = 1,
- args_out = 1
- } %C_view, %D_view {
+ iterator_types = ["parallel", "parallel"]}
+ ins(%C_view : memref<?x?xf32, #map>)
+ outs(%D_view : memref<?x?xf32, #map>) {
^bb0( %a: f32, %b: f32):
%ab = addf %a, %b : f32
linalg.yield %ab : f32
- }: memref<?x?xf32, #map>, memref<?x?xf32, #map>
+ }
scf.yield
}
return
diff --git a/mlir/test/Dialect/Linalg/inlining.mlir b/mlir/test/Dialect/Linalg/inlining.mlir
index 1e5af26..527b044 100644
--- a/mlir/test/Dialect/Linalg/inlining.mlir
+++ b/mlir/test/Dialect/Linalg/inlining.mlir
@@ -9,8 +9,6 @@
]
#trait = {
- args_in = 1,
- args_out = 1,
indexing_maps = #accesses,
iterator_types = ["parallel"]
}
@@ -23,9 +21,11 @@ func @inline_into(%arg0: memref<?xf32>) {
func @inlined_fn(%arg0: memref<?xf32>) {
// CHECK: linalg.generic
- linalg.generic #trait %arg0, %arg0 {
+ linalg.generic #trait
+ ins(%arg0 : memref<?xf32>)
+ outs(%arg0 : memref<?xf32>) {
^bb(%0 : f32, %1 : f32) :
linalg.yield %0 : f32
- } : memref<?xf32>, memref<?xf32>
+ }
return
}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index dce5c21..004bf92 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -62,52 +62,24 @@ func @yield_parent(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
// -----
func @generic_no_region(%arg0: memref<f32>) {
- // expected-error @+6 {{expected '{' to begin a region}}
+ // expected-error @+5 {{expected '{' to begin a region}}
linalg.generic {
- args_in = 1,
- args_out = 1,
indexing_maps = [ affine_map<() -> (0)> ],
iterator_types = []
- } %arg0 : memref<f32>
-}
-
-// -----
-
-func @generic_at_least_2_operands(%arg0: memref<f32>) {
- // expected-error @+1 {{op expected 2 or more operands}}
- linalg.generic {
- args_in = 1,
- args_out = 1,
- indexing_maps = [ affine_map<() -> (0)> ],
- iterator_types = []
- } %arg0 {} : memref<f32>
-}
-
-// -----
-
-func @generic_exactly_2_views(%arg0: memref<f32>) {
- // expected-error @+1 {{op expected exactly 2 inputs (tensor or buffer) and output buffer operands}}
- linalg.generic {
- args_in = 1,
- args_out = 1,
- indexing_maps = [ affine_map<() -> (0)> ],
- iterator_types = []
- } %arg0, %arg0, %arg0 {}: memref<f32>, memref<f32>, memref<f32>
+ } ins(%arg0 : memref<f32>)
}
// -----
func @generic_mismatched_num_returns(%arg0: memref<f32>) {
- // expected-error @+8 {{op expected number of yield values (1) to match the number of operands of the enclosing LinalgOp (0)}}
+ // expected-error @+6 {{op expected number of yield values (1) to match the number of operands of the enclosing LinalgOp (0)}}
linalg.generic {
- args_in = 0,
- args_out = 1,
- indexing_maps = [ affine_map<() -> ()> ],
- iterator_types = []
- } %arg0 {
+ indexing_maps = [ affine_map<() -> ()> ],
+ iterator_types = []}
+ outs(%arg0 : memref<f32>) {
^bb(%0: f32):
linalg.yield
- }: memref<f32>
+ }
}
// -----
@@ -115,14 +87,12 @@ func @generic_mismatched_num_returns(%arg0: memref<f32>) {
func @generic_symbol_in_map(%arg0: memref<i32>) {
// expected-error @+1 {{expected the number of symbols in indexing_map #0 to match rank of operand `symbol_source`}}
linalg.generic {
- args_in = 0,
- args_out = 1,
indexing_maps = [ affine_map<()[N] -> (0)> ],
- iterator_types = ["parallel"]
- } %arg0 {
+ iterator_types = ["parallel"]}
+ outs(%arg0 : memref<i32>) {
^bb(%i : i32):
linalg.yield %i : i32
- }: memref<i32>
+ }
}
// -----
@@ -130,15 +100,13 @@ func @generic_symbol_in_map(%arg0: memref<i32>) {
func @generic_symbol_source_out_of_range(%arg0: memref<i32>) {
// expected-error @+1 {{symbol_source index out of range}}
linalg.generic {
- args_in = 0,
- args_out = 1,
indexing_maps = [ affine_map<()[N] -> (0)> ],
iterator_types = ["parallel"],
- symbol_source = 1
- } %arg0 {
+ symbol_source = 1}
+ outs(%arg0 : memref<i32>) {
^bb(%i : i32):
linalg.yield %i : i32
- }: memref<i32>
+ }
}
// -----
@@ -146,14 +114,12 @@ func @generic_symbol_source_out_of_range(%arg0: memref<i32>) {
func @generic_wrong_dim_in_map(%arg0: memref<1xi32>) {
// expected-error @+1 {{op expected indexing_map #0 to have 1 dim(s) to match the number of loops}}
linalg.generic {
- args_in = 0,
- args_out = 1,
indexing_maps = [ affine_map<() -> (0)> ],
- iterator_types = ["parallel"]
- } %arg0 {
+ iterator_types = ["parallel"]}
+ outs(%arg0 : memref<1xi32>) {
^bb(%i : i32):
linalg.yield %i : i32
- }: memref<1xi32>
+ }
}
// -----
@@ -161,30 +127,26 @@ func @generic_wrong_dim_in_map(%arg0: memref<1xi32>) {
func @generic_one_d_view(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
// expected-error @+1 {{op expected indexing_map #0 results to match view rank: 'memref<?xf32, affine_map<(d0)[s0] -> (d0 + s0)>>'}}
linalg.generic {
- args_in = 0,
- args_out = 1,
indexing_maps = [ affine_map<() -> (0, 0)> ],
- iterator_types = []
- } %arg0 {
+ iterator_types = []}
+ outs(%arg0 : memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
^bb(%f : f32):
linalg.yield %f: f32
- }: memref<?xf32, affine_map<(i)[off]->(off + i)>>
+ }
}
// -----
func @generic_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
- // expected-error @+9 {{'linalg.yield' op type of yield operand 1 ('i4') doesn't match the element type of the enclosing linalg.generic op ('f32')}}
+ // expected-error @+7 {{'linalg.yield' op type of yield operand 1 ('i4') doesn't match the element type of the enclosing linalg.generic op ('f32')}}
linalg.generic {
- args_in = 0,
- args_out = 1,
indexing_maps = [ affine_map<(i) -> (i)> ],
- iterator_types = ["parallel"]
- } %arg0 {
+ iterator_types = ["parallel"]}
+ outs(%arg0 : memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
^bb(%0: f32):
%1 = constant 1: i4
linalg.yield %1: i4
- }: memref<?xf32, affine_map<(i)[off]->(off + i)>>
+ }
}
// -----
@@ -192,18 +154,16 @@ func @generic_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(o
func @generic_singular_maps(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>, %arg1: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
// expected-error @+1 {{op expected the concatenation of maps in indexing_map to be invertible}}
linalg.generic {
- args_in = 1,
- args_out = 1,
indexing_maps = [
affine_map<(i, j) -> (i + j)>,
affine_map<(i, j) -> (i + j)>
],
- iterator_types = ["parallel","parallel"]
- } %arg0, %arg1 {
- ^bb(%0: f32, %1: f32):
+ iterator_types = ["parallel","parallel"]}
+ ins(%arg0 : memref<?xf32, affine_map<(i)[off]->(off + i)>>)
+ outs(%arg1 : memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
+ ^bb(%0: f32, %1: f32):
linalg.yield %1: f32
- }: memref<?xf32, affine_map<(i)[off]->(off + i)>>,
- memref<?xf32, affine_map<(i)[off]->(off + i)>>
+ }
}
////////////////////////////////////////////////////////////////////////////////
@@ -216,16 +176,15 @@ func @generic_empty_region(%arg0: memref<f32>) {
%f0 = constant 0.0: f32
// expected-error @+1 {{op expects region #0 to have 0 or 1 blocks}}
linalg.generic {
- args_in = 1,
- args_out = 1,
indexing_maps = [ affine_map<() -> (0)> ],
- iterator_types = []
- } %arg0, %arg0 {
+ iterator_types = []}
+ ins(%arg0 : memref<f32>)
+ outs(%arg0 : memref<f32>) {
^bb1:
linalg.yield %f0: f32
^bb2:
linalg.yield %f0: f32
- }: memref<f32>, memref<f32>
+ }
}
// -----
@@ -234,12 +193,11 @@ func @generic_empty_region(%arg0: memref<f32>) {
%f0 = constant 0.0: f32
// expected-error @+1 {{linalg.generic' op expected region with 1 block}}
linalg.generic {
- args_in = 1,
- args_out = 1,
indexing_maps = [ affine_map<() -> (0)> ],
- iterator_types = []
- } %arg0, %arg0 {
- }: memref<f32>, memref<f32>
+ iterator_types = []}
+ ins(%arg0 : memref<f32>)
+ outs(%arg0 : memref<f32>) {
+ }
}
// -----
@@ -247,14 +205,12 @@ func @generic_empty_region(%arg0: memref<f32>) {
func @generic_mismatched_num_arguments(%arg0: memref<f32>) {
// expected-error @+1 {{op expected number of block arguments to match number of operands}}
linalg.generic {
- args_in = 0,
- args_out = 1,
- indexing_maps = [ affine_map<() -> (0)> ],
- iterator_types = []
- } %arg0 {
+ indexing_maps = [ affine_map<() -> (0)> ],
+ iterator_types = []}
+ outs(%arg0 : memref<f32>) {
^bb(%f: f32, %g: f32):
linalg.yield %f: f32
- }: memref<f32>
+ }
}
// -----
@@ -262,14 +218,12 @@ func @generic_mismatched_num_arguments(%arg0: memref<f32>) {
func @generic_block_arg_type(%arg0: memref<f32>) {
// expected-error @+1 {{op expected block argument 1 of the same type as elemental type of output operand: 'memref<f32>'}}
linalg.generic {
- args_in = 0,
- args_out = 1,
indexing_maps = [ affine_map<() -> (0)> ],
- iterator_types = []
- } %arg0 {
+ iterator_types = []}
+ outs(%arg0 : memref<f32>) {
^bb(%i: i1):
linalg.yield %i : i1
- }: memref<f32>
+ }
}
// -----
@@ -277,14 +231,12 @@ func @generic_block_arg_type(%arg0: memref<f32>) {
func @indexed_generic_block_arg_count(%arg0: memref<f32>) {
// expected-error @+1 {{op expected number of block arguments to match number of operands + number of loops}}
linalg.indexed_generic {
- args_in = 0,
- args_out = 1,
indexing_maps = [ affine_map<(d0) -> (d0)> ],
- iterator_types = ["parallel"]
- } %arg0 {
+ iterator_types = ["parallel"]}
+ outs(%arg0 : memref<f32>) {
^bb(%f: f32):
linalg.yield %f : f32
- }: memref<f32>
+ }
}
// -----
@@ -292,14 +244,12 @@ func @indexed_generic_block_arg_count(%arg0: memref<f32>) {
func @indexed_generic_block_induction_var_arg_type(%arg0: memref<f32>) {
// expected-error @+1 {{op expected block argument 1 to be an index}}
linalg.indexed_generic {
- args_in = 0,
- args_out = 1,
indexing_maps = [ affine_map<(d0) -> (d0)> ],
- iterator_types = ["parallel"]
- } %arg0 {
+ iterator_types = ["parallel"]}
+ outs(%arg0 : memref<f32>) {
^bb(%i: f64, %f: f32):
linalg.yield %f: f32
- }: memref<f32>
+ }
}
// -----
@@ -307,14 +257,12 @@ func @indexed_generic_block_induction_var_arg_type(%arg0: memref<f32>) {
func @indexed_generic_block_arg_type(%arg0: memref<f32>) {
// expected-error @+1 {{op expected block argument 2 of the same type as elemental type of output operand: 'memref<f32>'}}
linalg.indexed_generic {
- args_in = 0,
- args_out = 1,
indexing_maps = [ affine_map<(d0) -> (d0)> ],
- iterator_types = ["parallel"]
- } %arg0 {
+ iterator_types = ["parallel"]}
+ outs(%arg0 : memref<f32>) {
^bb(%i: index, %f: i1):
linalg.yield %i: index
- }: memref<f32>
+ }
}
// -----
@@ -322,14 +270,12 @@ func @indexed_generic_block_arg_type(%arg0: memref<f32>) {
func @indexed_generic_arg_count(%arg0: memref<f32>) {
// expected-error @+1 {{op expected number of block arguments to match number of operands + number of loops}}
linalg.indexed_generic {
- args_in = 0,
- args_out = 1,
indexing_maps = [ affine_map<()[] -> ()> ],
- iterator_types = []
- } %arg0 {
+ iterator_types = []}
+ outs(%arg0 : memref<f32>) {
^bb(%0: index, %1: f32):
linalg.yield %1: f32
- } : memref<f32>
+ }
return
}
@@ -338,60 +284,39 @@ func @indexed_generic_arg_count(%arg0: memref<f32>) {
func @indexed_generic_induction_var_arg_type(%arg0: memref<f32>) {
// expected-error @+1 {{op expected block argument 1 to be an index}}
linalg.indexed_generic {
- args_in = 0,
- args_out = 1,
iterator_types = ["parallel"],
- indexing_maps = [ affine_map<(i) -> (i)> ]
- } %arg0 {
+ indexing_maps = [ affine_map<(i) -> (i)> ]}
+ outs(%arg0 : memref<f32>) {
^bb(%0: i32, %1: f32):
linalg.yield %1: f32
- } : memref<f32>
+ }
}
// -----
func @indexed_generic_result_count(%arg0: memref<?xf32>) {
- // expected-error @+8 {{op expected number of yield values (1) to match the number of operands of the enclosing LinalgOp (2)}}
+ // expected-error @+6 {{op expected number of yield values (1) to match the number of operands of the enclosing LinalgOp (2)}}
linalg.indexed_generic {
- args_in = 0,
- args_out = 1,
indexing_maps = [ affine_map<(d0) -> (d0)> ],
- iterator_types = ["parallel"]
- } %arg0 {
+ iterator_types = ["parallel"]}
+ outs(%arg0 : memref<?xf32>) {
^bb(%i: index, %val: f32):
linalg.yield %val, %val: f32, f32
- }: memref<?xf32>
+ }
}
// -----
func @generic_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
- // expected-error @+9 {{type of yield operand 1 ('i1') doesn't match the element type of the enclosing linalg.generic op ('f32')}}
+ // expected-error @+7 {{type of yield operand 1 ('i1') doesn't match the element type of the enclosing linalg.generic op ('f32')}}
linalg.generic {
- args_in = 0,
- args_out = 1,
indexing_maps = [ affine_map<(i) -> (i)> ],
- iterator_types = ["parallel"]
- } %arg0 {
+ iterator_types = ["parallel"]}
+ outs(%arg0 : memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
^bb(%i: f32):
%0 = constant 0: i1
linalg.yield %0: i1
- }: memref<?xf32, affine_map<(i)[off]->(off + i)>>
-}
-
-// -----
-
-func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
- // expected-error @+1 {{op result #0 must be ranked tensor of any type values, but got 'f32'}}
- %0 = linalg.generic {
- args_in = 0,
- args_out = 1,
- indexing_maps = [ affine_map<(i) -> (i)> ],
- iterator_types = ["parallel"]
- } %arg0 {
- ^bb(%i: f32):
- linalg.yield %i: f32
- }: memref<?xf32, affine_map<(i)[off]->(off + i)>> -> f32
+ }
}
// -----
@@ -399,14 +324,12 @@ func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off
func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
// expected-error @+1 {{op result #0 must be ranked tensor of any type values, but got 'f32'}}
%0 = linalg.generic {
- args_in = 0,
- args_out = 1,
indexing_maps = [ affine_map<(i) -> (i)> ],
- iterator_types = ["parallel"]
- } %arg0 {
+ iterator_types = ["parallel"]}
+ ins(%arg0 : memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
^bb(%i: f32):
linalg.yield %i: f32
- }: memref<?xf32, affine_map<(i)[off]->(off + i)>> -> f32
+ } -> f32
}
// -----
@@ -415,14 +338,12 @@ func @generic(%arg0: memref<?x?xi4>) {
// expected-error @+2 {{op expects regions to end with 'linalg.yield', found 'std.addf'}}
// expected-note @+1 {{in custom textual format, the absence of terminator implies 'linalg.yield'}}
linalg.generic {
- args_in = 0,
- args_out = 1,
indexing_maps = [ affine_map<(i) -> (i)> ],
- iterator_types = ["parallel"]
- } %arg0 {
+ iterator_types = ["parallel"]}
+ outs(%arg0 : memref<?x?xi4>) {
^bb(%0: i4) :
%1 = std.addf %0, %0: i4
- } : memref<?x?xi4>
+ }
return
}
@@ -511,23 +432,6 @@ func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?xf32>, %c3: memref<?x?x?x
// -----
-func @generic(%arg0: tensor<?x?xi4>) {
- // expected-error @+1 {{unexpected #results > #outputs}}
- linalg.generic {
- args_in = 1,
- args_out = 1,
- indexing_maps = [ affine_map<(i) -> (i)> ],
- iterator_types = ["parallel"]
- } %arg0 {
- ^bb(%0: i4) :
- %1 = std.addi %0, %0: i4
- linalg.yield %1, %1: i4, i4
- } : tensor<?x?xi4> -> (tensor<?x?xi4>, tensor<?x?xi4>)
- return
-}
-
-// -----
-
func @empty_init_expected(%m: memref<?x?xf32>, %t: tensor<?x?xf32>) {
// expected-error @+1 {{expected empty `init` when op has no results or no reduction dims}}
linalg.matmul ins(%m, %m: memref<?x?xf32>, memref<?x?xf32>)
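Every hunk above in invalid.mlir applies the same mechanical rewrite: the mandatory `args_in`/`args_out` attributes are dropped, the operands that used to follow the attribute dictionary move into `ins(...)`/`outs(...)` clauses that carry their own types, and the trailing type list after the region disappears. Because two attribute lines vanish, relative diagnostic anchors shrink by the same amount (for example `expected-error @+8` becomes `@+6`). A minimal before/after sketch on buffers; `#id`, `%in` and `%out` are illustrative names, not values taken from the tests:

  #id = affine_map<(d0) -> (d0)>

  // Old form: operand counts as attributes, operands after the attribute
  // dictionary, types listed once after the region.
  linalg.generic {
    args_in = 1, args_out = 1,
    indexing_maps = [#id, #id],
    iterator_types = ["parallel"]
  } %in, %out {
  ^bb(%a: f32, %b: f32):
    linalg.yield %a : f32
  }: memref<?xf32>, memref<?xf32>

  // New form: operands grouped by role, each clause typed in place.
  linalg.generic {
    indexing_maps = [#id, #id],
    iterator_types = ["parallel"]}
    ins(%in : memref<?xf32>)
    outs(%out : memref<?xf32>) {
  ^bb(%a: f32, %b: f32):
    linalg.yield %a : f32
  }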
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index b8df79e..348468a 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -560,12 +560,15 @@ func @pooling_sum(%arg0: memref<?x?xf32>,
doc = "B(i,j,k), C(i,k,j) = foo(A(i, j), B(i,j,k), C(i,k,j))"
}
func @generic_region(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %arg2: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
- linalg.generic #trait2 %arg0, %arg1, %arg2 {
+ linalg.generic #trait2
+ ins(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>)
+ outs(%arg1, %arg2 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
+ memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
^bb0(%a: f32, %b: f32, %c: f32):
%d = mulf %a, %b : f32
%e = addf %c, %d : f32
linalg.yield %d, %e : f32, f32
- }: memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
+ }
return
}
// CHECKLOOP-LABEL: @generic_region
@@ -602,7 +605,10 @@ func @indexed_generic_region(
%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
%arg2: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
- linalg.indexed_generic #trait4 %arg0, %arg1, %arg2 {
+ linalg.indexed_generic #trait4
+ ins(%arg0 : memref<?x?xf32, offset: ?, strides: [?, 1]>)
+ outs(%arg1, %arg2 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
+ memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
^bb0(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32):
%result_1 = mulf %a, %b : f32
@@ -613,9 +619,7 @@ func @indexed_generic_region(
%result_2 = addf %c, %ijk_float : f32
linalg.yield %result_1, %result_2 : f32, f32
- }: memref<?x?xf32, offset: ?, strides: [?, 1]>,
- memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
- memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
+ }
return
}
@@ -666,10 +670,12 @@ func @indexed_generic_region(
func @generic_op_zero_rank(%arg0: memref<f32>, %arg1: memref<3x4xf32>)
{
- linalg.generic #trait_broadcast %arg0, %arg1 {
+ linalg.generic #trait_broadcast
+ ins(%arg0 : memref<f32>)
+ outs(%arg1 : memref<3x4xf32>) {
^bb(%a: f32, %b: f32) :
linalg.yield %a : f32
- } : memref<f32>, memref<3x4xf32>
+ }
return
}
@@ -690,13 +696,15 @@ func @generic_op_zero_rank(%arg0: memref<f32>, %arg1: memref<3x4xf32>)
func @indexed_generic_op_zero_rank(%arg0: memref<i32>, %arg1: memref<3x4xi32>)
{
- linalg.indexed_generic #trait_broadcast %arg0, %arg1 {
+ linalg.indexed_generic #trait_broadcast
+ ins(%arg0 : memref<i32>)
+ outs(%arg1 : memref<3x4xi32>) {
^bb(%i: index, %j: index, %a: i32, %b: i32) :
%ij = addi %i, %j : index
%ij_int = index_cast %ij : index to i32
%result = addi %a, %ij_int : i32
linalg.yield %result : i32
- } : memref<i32>, memref<3x4xi32>
+ }
return
}
@@ -736,11 +744,13 @@ func @indexed_generic_op_zero_rank(%arg0: memref<i32>, %arg1: memref<3x4xi32>)
func @generic_op_1D_reduce(%arg0: memref<?xf32>, %arg1: memref<f32>)
{
- linalg.generic #trait_reduce_1D %arg0, %arg1 {
+ linalg.generic #trait_reduce_1D
+ ins(%arg0 : memref<?xf32>)
+ outs(%arg1 : memref<f32>) {
^bb(%a: f32, %b: f32) :
%0 = addf %a, %b : f32
linalg.yield %0 : f32
- } : memref<?xf32>, memref<f32>
+ }
return
}
// CHECKLOOP-LABEL: @generic_op_1D_reduce
@@ -780,14 +790,16 @@ func @indexed_generic_op_1D_reduce(%arg0: memref<?xf32>,
%arg1: memref<f32>,
%arg2: memref<f32>)
{
- linalg.indexed_generic #trait_reduce_init_1D %arg0, %arg1, %arg2 {
+ linalg.indexed_generic #trait_reduce_init_1D
+ ins(%arg0, %arg1 : memref<?xf32>, memref<f32>)
+ outs(%arg2 : memref<f32>) {
^bb(%i : index, %a: f32, %b: f32, %c: f32) :
%0 = constant 0 : index
%1 = cmpi "eq", %0, %i : index
%2 = select %1, %b, %c : f32
%3 = addf %a, %2 : f32
linalg.yield %3 : f32
- } : memref<?xf32>, memref<f32>, memref<f32>
+ }
return
}
// CHECKLOOP-LABEL: @indexed_generic_op_1D_reduce
@@ -823,10 +835,10 @@ func @indexed_generic_op_1D_reduce(%arg0: memref<?xf32>,
}
func @generic_const_init(%arg0: memref<?xf32>) {
%cst = constant 1.0 : f32
- linalg.generic #trait_const_fill %arg0 {
+ linalg.generic #trait_const_fill outs(%arg0 : memref<?xf32>) {
^bb0(%arg1: f32): // no predecessors
linalg.yield %cst : f32
- }: memref<?xf32>
+ }
return
}
// CHECKLOOP-LABEL: @generic_const_init
@@ -855,11 +867,13 @@ func @generic_const_init(%arg0: memref<?xf32>) {
}
func @scalar_code(%arg0: memref<f32>, %arg1 : memref<f32>, %arg2 : memref<f32>)
{
- linalg.generic #scalar_trait %arg0, %arg1, %arg2 {
+ linalg.generic #scalar_trait
+ ins(%arg0, %arg1 : memref<f32>, memref<f32>)
+ outs(%arg2 : memref<f32>) {
^bb(%a : f32, %b : f32, %c : f32) :
%0 = addf %a, %b : f32
linalg.yield %0 : f32
- } : memref<f32>, memref<f32>, memref<f32>
+ }
return
}
// CHECKLOOP-LABEL: @scalar_code
@@ -944,14 +958,14 @@ func @named_batch_matmul(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memre
}
func @conv1d(%in : memref<?xf32>, %filter : memref<?xf32>, %out : memref<?xf32>) -> () {
- linalg.generic #conv_1d_trait %in, %filter, %out {
+ linalg.generic #conv_1d_trait
+ ins(%in, %filter : memref<?xf32>, memref<?xf32>)
+ outs(%out : memref<?xf32>) {
^bb0(%a: f32, %b: f32, %c: f32) :
%d = mulf %a, %b : f32
%e = addf %c, %d : f32
linalg.yield %e : f32
- } : memref<?xf32>,
- memref<?xf32>,
- memref<?xf32>
+ }
return
}
@@ -1012,14 +1026,14 @@ func @conv1d(%in : memref<?xf32>, %filter : memref<?xf32>, %out : memref<?xf32>
}
func @conv2d(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out : memref<?x?xf32>) -> () {
- linalg.generic #conv_2d_trait %in, %filter, %out {
+ linalg.generic #conv_2d_trait
+ ins(%in, %filter : memref<?x?xf32>, memref<?x?xf32>)
+ outs(%out : memref<?x?xf32>) {
^bb0(%a: f32, %b: f32, %c: f32) :
%d = mulf %a, %b : f32
%e = addf %c, %d : f32
linalg.yield %e : f32
- } : memref<?x?xf32>,
- memref<?x?xf32>,
- memref<?x?xf32>
+ }
return
}
@@ -1096,14 +1110,14 @@ func @conv2d(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out : memref<?x
}
func @conv3d(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %out : memref<?x?x?xf32>) -> () {
- linalg.generic #conv_3d_trait %in, %filter, %out {
+ linalg.generic #conv_3d_trait
+ ins(%in, %filter : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%out : memref<?x?x?xf32>) {
^bb0(%a: f32, %b: f32, %c: f32) :
%d = mulf %a, %b : f32
%e = addf %c, %d : f32
linalg.yield %e : f32
- } : memref<?x?x?xf32>,
- memref<?x?x?xf32>,
- memref<?x?x?xf32>
+ }
return
}
@@ -1196,14 +1210,14 @@ func @conv3d(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %out : memre
}
func @conv4d(%in : memref<?x?x?x?xf32>, %filter : memref<?x?x?x?xf32>, %out : memref<?x?x?x?xf32>) -> () {
- linalg.generic #conv_4d_trait %in, %filter, %out {
+ linalg.generic #conv_4d_trait
+ ins(%in, %filter : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
+ outs(%out : memref<?x?x?x?xf32>) {
^bb0(%a: f32, %b: f32, %c: f32) :
%d = mulf %a, %b : f32
%e = addf %c, %d : f32
linalg.yield %e : f32
- } : memref<?x?x?x?xf32>,
- memref<?x?x?x?xf32>,
- memref<?x?x?x?xf32>
+ }
return
}
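The loops.mlir updates above show the same migration when the attributes live in a named trait (`#trait2`, `#conv_1d_trait`, and so on): the trait simply loses its `args_in`/`args_out` entries and the op spells its operands through `ins`/`outs`. A self-contained sketch of the trait-based form, using made-up names (`#sum_1d`, `@sum_vectors`) rather than ones from the tests:

  #id1 = affine_map<(d0) -> (d0)>
  #sum_1d = {
    indexing_maps = [#id1, #id1, #id1],
    iterator_types = ["parallel"]
  }
  func @sum_vectors(%lhs: memref<?xf32>, %rhs: memref<?xf32>, %out: memref<?xf32>) {
    linalg.generic #sum_1d
        ins(%lhs, %rhs : memref<?xf32>, memref<?xf32>)
        outs(%out : memref<?xf32>) {
      ^bb0(%l: f32, %r: f32, %o: f32):
        %0 = addf %l, %r : f32
        linalg.yield %0 : f32
    }
    return
  }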
diff --git a/mlir/test/Dialect/Linalg/parallel_loops.mlir b/mlir/test/Dialect/Linalg/parallel_loops.mlir
index 6c500ec..95eb997 100644
--- a/mlir/test/Dialect/Linalg/parallel_loops.mlir
+++ b/mlir/test/Dialect/Linalg/parallel_loops.mlir
@@ -5,15 +5,14 @@ func @linalg_generic_sum(%lhs: memref<2x2xf32>,
%rhs: memref<2x2xf32>,
%sum: memref<2x2xf32>) {
linalg.generic {
- args_in = 2 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0, #map0],
- iterator_types = ["parallel", "parallel"]
- } %lhs, %rhs, %sum {
+ iterator_types = ["parallel", "parallel"]}
+ ins(%lhs, %rhs : memref<2x2xf32>, memref<2x2xf32>)
+ outs(%sum : memref<2x2xf32>) {
^bb0(%lhs_in: f32, %rhs_in: f32, %sum_out: f32): // no predecessors
%0 = addf %lhs_in, %rhs_in : f32
linalg.yield %0 : f32
- }: memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>
+ }
return
}
// CHECK-LABEL: @linalg_generic_sum
@@ -35,17 +34,17 @@ func @linalg_generic_sum(%lhs: memref<2x2xf32>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
]
#trait = {
- args_in = 1,
- args_out = 1,
iterator_types = ["parallel", "parallel", "reduction", "parallel"],
indexing_maps = #accesses
}
func @lower_outer_parallel(%A: memref<?x?x?x?xf32>, %B: memref<?x?x?xf32>) {
- linalg.generic #trait %A, %B {
+ linalg.generic #trait
+ ins(%A : memref<?x?x?x?xf32>)
+ outs(%B : memref<?x?x?xf32>) {
^bb0(%a: f32, %b: f32):
linalg.yield %a: f32
- } : memref<?x?x?x?xf32>, memref<?x?x?xf32>
+ }
return
}
// CHECK-LABEL: @lower_outer_parallel
@@ -68,17 +67,17 @@ func @lower_outer_parallel(%A: memref<?x?x?x?xf32>, %B: memref<?x?x?xf32>) {
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d5)>
]
#trait = {
- args_in = 1,
- args_out = 1,
iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"],
indexing_maps = #accesses
}
func @lower_mixed_parallel(%A: memref<?x?x?x?x?x?xf32>, %B: memref<?x?x?x?xf32>) {
- linalg.generic #trait %A, %B {
+ linalg.generic #trait
+ ins(%A : memref<?x?x?x?x?x?xf32>)
+ outs(%B : memref<?x?x?x?xf32>) {
^bb0(%a: f32, %b: f32):
linalg.yield %a: f32
- } : memref<?x?x?x?x?x?xf32>, memref<?x?x?x?xf32>
+ }
return
}
// CHECK-LABEL: @lower_mixed_parallel
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 1d58259..5960d55 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -293,8 +293,6 @@ func @pooling_sum(%arg0: memref<?x?x?xf32>,
]
#trait = {
- args_in = 1,
- args_out = 1,
indexing_maps = #accesses,
iterator_types = ["parallel", "parallel", "parallel"],
library_call = "some_external_function_name_1"
@@ -302,37 +300,44 @@ func @pooling_sum(%arg0: memref<?x?x?xf32>,
func @generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
%arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
- linalg.generic #trait {foo = 1} %arg0, %arg1 {
+ linalg.generic #trait
+ ins(%arg0 : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>)
+ outs(%arg1 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>)
+ attrs = {foo = 1} {
^bb(%0: vector<3x4xi4>, %1: f32) :
%f0 = constant 0.0 : f32
linalg.yield %f0 : f32
- } : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
- memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
+ }
return
}
// CHECK-LABEL: func @generic
-// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
-// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
-// CHECK-SAME: library_call = "some_external_function_name_1"
+// CHECK: linalg.generic {
+// CHECK-SAME: indexing_maps = [#{{[0-9a-z]*}}, #{{[0-9a-z]*}}],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"],
+// CHECK-SAME: library_call = "some_external_function_name_1"}
+// CHECK-SAME: ins({{.*}} : memref<?x?xvector<3x4xi4>, #[[$strided2D]]>)
+// CHECK-SAME: outs({{.*}} : memref<?x?x?xf32, #[[$strided3D]]>)
// CHECK-SAME: {foo = 1 : i64}
-// CHECK: memref<?x?xvector<3x4xi4>, #[[$strided2D]]>, memref<?x?x?xf32, #[[$strided3D]]>
func @generic_with_tensor_input(%arg0: tensor<?x?xvector<3x4xi4>>,
%arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
- linalg.generic #trait {foo = 1} %arg0, %arg1 {
+ linalg.generic #trait
+ ins(%arg0 : tensor<?x?xvector<3x4xi4>>)
+ outs(%arg1 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>)
+ attrs = {foo = 1} {
^bb(%0: vector<3x4xi4>, %1: f32) :
%f0 = constant 0.0 : f32
linalg.yield %f0 : f32
- } : tensor<?x?xvector<3x4xi4>>,
- memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
+ }
return
}
// CHECK-LABEL: func @generic_with_tensor_input
-// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
+// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
// CHECK-SAME: library_call = "some_external_function_name_1"}
+// CHECK-SAME: ins({{.*}} : tensor<?x?xvector<3x4xi4>>)
+// CHECK-SAME: outs({{.*}} : memref<?x?x?xf32, #[[$strided3D]]>)
// CHECK-SAME: {foo = 1 : i64}
-// CHECK: tensor<?x?xvector<3x4xi4>>, memref<?x?x?xf32, #[[$strided3D]]>
// -----
@@ -342,8 +347,6 @@ func @generic_with_tensor_input(%arg0: tensor<?x?xvector<3x4xi4>>,
]
#trait2 = {
- args_in = 2,
- args_out = 1,
indexing_maps = #accesses,
iterator_types = ["parallel", "parallel", "parallel"],
library_call = "some_external_function_name_1"
@@ -352,20 +355,22 @@ func @generic_with_tensor_input(%arg0: tensor<?x?xvector<3x4xi4>>,
func @generic_with_tensor_input_and_output(
%arg0: tensor<?x?xvector<3x4xi4>>, %arg1: tensor<?x?x?xf32>)
-> (tensor<?x?x?xf32>) {
- %0 = linalg.generic #trait2 {foo = 1} %arg0, %arg1 {
+ %0 = linalg.generic #trait2
+ ins(%arg0, %arg1 : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32>)
+ attrs = {foo = 1} {
^bb(%0: vector<3x4xi4>, %1: f32) :
%f0 = constant 0.0 : f32
linalg.yield %f0 : f32
- } : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
+ } -> tensor<?x?x?xf32>
return %0 : tensor<?x?x?xf32>
}
// CHECK-LABEL: func @generic_with_tensor_input_and_output
-// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
+// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
// CHECK-SAME: library_call = "some_external_function_name_1"}
+// CHECK-SAME: ins({{.*}} : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32>)
// CHECK-SAME: {foo = 1 : i64}
-// CHECK-SAME: %{{.*}}, %{{.*}}
-// CHECK: tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
+// CHECK: -> tensor<?x?x?xf32>
// CHECK: return {{.*}} : tensor<?x?x?xf32>
// -----
@@ -376,8 +381,6 @@ func @generic_with_tensor_input_and_output(
]
#trait2 = {
- args_in = 2,
- args_out = 1,
indexing_maps = #accesses,
iterator_types = ["parallel", "parallel", "parallel"],
library_call = "some_external_function_name_1"
@@ -386,20 +389,22 @@ func @generic_with_tensor_input_and_output(
func @indexed_generic_with_tensor_input_and_output(
%arg0: tensor<?x?xvector<3x4xi4>>, %arg1: tensor<?x?x?xf32>)
-> (tensor<?x?x?xf32>) {
- %0 = linalg.indexed_generic #trait2 {foo = 1} %arg0, %arg1 {
+ %0 = linalg.indexed_generic #trait2
+ ins(%arg0, %arg1 : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32>)
+ attrs = {foo = 1} {
^bb(%i: index, %j: index, %k: index, %0: vector<3x4xi4>, %1: f32) :
%f0 = constant 0.0 : f32
linalg.yield %f0 : f32
- } : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
+ } -> tensor<?x?x?xf32>
return %0 : tensor<?x?x?xf32>
}
// CHECK-LABEL: func @indexed_generic_with_tensor_input_and_output
-// CHECK: linalg.indexed_generic {args_in = 2 : i64, args_out = 1 : i64,
+// CHECK: linalg.indexed_generic {
// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
// CHECK-SAME: library_call = "some_external_function_name_1"}
+// CHECK-SAME: ins({{.*}} : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32>)
// CHECK-SAME: {foo = 1 : i64}
-// CHECK-SAME: %{{.*}}, %{{.*}}
-// CHECK: tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
+// CHECK: -> tensor<?x?x?xf32>
// CHECK: return {{.*}} : tensor<?x?x?xf32>
// -----
@@ -410,8 +415,6 @@ func @indexed_generic_with_tensor_input_and_output(
]
#trait_broadcast = {
- args_in = 1,
- args_out = 1,
indexing_maps = #broadcast_access,
iterator_types = ["parallel", "parallel"],
library_call = "some_broadcast_external_fn"
@@ -419,19 +422,21 @@ func @indexed_generic_with_tensor_input_and_output(
func @generic_op_zero_rank(%arg0: tensor<f32>) -> (tensor<3x4xf32>)
{
- %0 = linalg.generic #trait_broadcast %arg0 {
+ %0 = linalg.generic #trait_broadcast
+ ins(%arg0 : tensor<f32>) {
^bb(%a: f32) :
linalg.yield %a : f32
- } : tensor<f32> -> tensor<3x4xf32>
+ } -> tensor<3x4xf32>
return %0 : tensor<3x4xf32>
}
func @indexed_generic_op_zero_rank(%arg0: tensor<f32>) -> (tensor<3x4xf32>)
{
- %0 = linalg.indexed_generic #trait_broadcast %arg0 {
+ %0 = linalg.indexed_generic #trait_broadcast
+ ins(%arg0 : tensor<f32>) {
^bb(%i: index, %j: index, %a: f32) :
linalg.yield %a : f32
- } : tensor<f32> -> tensor<3x4xf32>
+ } -> tensor<3x4xf32>
return %0 : tensor<3x4xf32>
}
@@ -446,8 +451,6 @@ func @indexed_generic_op_zero_rank(%arg0: tensor<f32>) -> (tensor<3x4xf32>)
]
#trait3 = {
- args_in = 1,
- args_out = 1,
indexing_maps = #accesses,
iterator_types = ["parallel", "parallel", "parallel"],
library_call = "some_external_function_name_2"
@@ -455,41 +458,48 @@ func @indexed_generic_op_zero_rank(%arg0: tensor<f32>) -> (tensor<3x4xf32>)
func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
%arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
- linalg.generic #trait3 {foo = 1} %arg0, %arg1 {
+ linalg.generic #trait3
+ ins(%arg0 : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>)
+ outs(%arg1 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>)
+ attrs = {foo = 1} {
^bb(%a: vector<3x4xi4>, %b: f32) :
linalg.yield %b : f32
- } : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
- memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
+ }
return
}
// CHECK-LABEL: func @generic_region
-// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
-// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
+// CHECK: linalg.generic {
+// CHECK-SAME: indexing_maps = [#{{[0-9a-z]*}}, #{{[0-9a-z]*}}],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"],
// CHECK-SAME: library_call = "some_external_function_name_2"
-// CHECK-SAME: {foo = 1 : i64}
-// CHECK: ^{{.*}}(%{{.*}}: vector<3x4xi4>, %{{.*}}: f32):
-// CHECK: linalg.yield %{{.*}} : f32
-// CHECK: memref<?x?xvector<3x4xi4>, #[[$strided2D]]>,
-// CHECK-SAME: memref<?x?x?xf32, #[[$strided3D]]>
+// CHECK-SAME: ins({{.*}} : memref<?x?xvector<3x4xi4>, #[[$strided2D]]>)
+// CHECK-SAME: outs({{.*}} : memref<?x?x?xf32, #[[$strided3D]]>)
+// CHECK-SAME: attrs = {foo = 1 : i64} {
+// CHECK: ^{{.*}}(%{{.*}}: vector<3x4xi4>, %{{.*}}: f32):
+// CHECK: linalg.yield %{{.*}} : f32
func @indexed_generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
%arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
- linalg.indexed_generic #trait3 {foo = 1} %arg0, %arg1 {
- ^bb(%i: index, %j: index, %k: index, %a: vector<3x4xi4>, %b: f32) :
+ linalg.indexed_generic #trait3
+ ins(%arg0 : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>)
+ outs(%arg1 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>)
+ attrs = {foo = 1} {
+ ^bb(%i: index, %j: index, %k: index, %a: vector<3x4xi4>, %b: f32) :
linalg.yield %b : f32
- }: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
- memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
+ }
return
}
// CHECK-LABEL: func @indexed_generic
-// CHECK: linalg.indexed_generic {args_in = 1 : i64, args_out = 1 : i64,
-// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
+// CHECK: linalg.indexed_generic {
+// CHECK-SAME: indexing_maps = [#{{[0-9a-z]*}}, #{{[0-9a-z]*}}],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"],
// CHECK-SAME: library_call = "some_external_function_name_2"
+// CHECK-SAME: ins({{.*}} : memref<?x?xvector<3x4xi4>, #[[$strided2D]]>)
+// CHECK-SAME: outs({{.*}} : memref<?x?x?xf32, #[[$strided3D]]>)
// CHECK-SAME: {foo = 1 : i64}
// CHECK: ^{{.*}}(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: vector<3x4xi4>, %{{.*}}: f32):
// CHECK: linalg.yield %{{.*}} : f32
-// CHECK: }: memref<?x?xvector<3x4xi4>, #[[$strided2D]]>,
-// CHECK-SAME: memref<?x?x?xf32, #[[$strided3D]]>
+// CHECK: }
// -----
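For tensor-valued generics, the roundtrip.mlir changes add two more pieces of the new syntax: results are typed after the region with `-> tensor<...>` instead of being listed as output operands, and extra attributes that previously trailed the trait now sit behind an `attrs =` clause. A sketch of the one-input, one-result case; the function name is illustrative, while `attrs = {foo = 1}` mirrors the tests above:

  #id2 = affine_map<(d0, d1) -> (d0, d1)>
  func @exp_on_tensor(%t: tensor<?x?xf32>) -> tensor<?x?xf32> {
    %res = linalg.generic {
        indexing_maps = [#id2, #id2],
        iterator_types = ["parallel", "parallel"]}
        ins(%t : tensor<?x?xf32>)
        attrs = {foo = 1} {
      ^bb(%a: f32):
        %e = exp %a : f32
        linalg.yield %e : f32
    } -> tensor<?x?xf32>
    return %res : tensor<?x?xf32>
  }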
diff --git a/mlir/test/Dialect/Linalg/standard.mlir b/mlir/test/Dialect/Linalg/standard.mlir
index 638fdb8..14b4e2a 100644
--- a/mlir/test/Dialect/Linalg/standard.mlir
+++ b/mlir/test/Dialect/Linalg/standard.mlir
@@ -72,8 +72,6 @@ func @copy_transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %a
affine_map<(m, n, k) -> (m, n)>
]
#matmul_trait = {
- args_in = 2,
- args_out = 1,
iterator_types = ["parallel", "parallel", "reduction"],
indexing_maps = #matmul_accesses,
library_call = "external_outerproduct_matmul"
@@ -88,20 +86,19 @@ func @copy_transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %a
!matrix_type_C = type memref<?x?x!vector_type_C>
func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C) {
- linalg.generic #matmul_trait %A, %B, %C {
+ linalg.generic #matmul_trait
+ ins(%A, %B : !matrix_type_A, !matrix_type_B)
+ outs(%C : !matrix_type_C) {
^bb0(%a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C):
%d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B
linalg.yield %d: !vector_type_C
- } : !matrix_type_A, !matrix_type_B, !matrix_type_C
-
+ }
return
}
// CHECK-LABEL: func @matmul_vec_impl(
// CHECK: call @external_outerproduct_matmul(%{{.*}}) :
#indexed_matmul_trait = {
- args_in = 2,
- args_out = 1,
iterator_types = ["parallel", "parallel", "reduction"],
indexing_maps = #matmul_accesses,
library_call = "external_indexed_outerproduct_matmul"
@@ -109,12 +106,14 @@ func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C
func @matmul_vec_indexed(%A: !matrix_type_A,
%B: !matrix_type_B,
%C: !matrix_type_C) {
- linalg.indexed_generic #indexed_matmul_trait %A, %B, %C {
+ linalg.indexed_generic #indexed_matmul_trait
+ ins(%A, %B : !matrix_type_A, !matrix_type_B)
+ outs(%C : !matrix_type_C) {
^bb0(%i: index, %j: index, %k: index,
%a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C):
%d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B
linalg.yield %d: !vector_type_C
- } : !matrix_type_A, !matrix_type_B, !matrix_type_C
+ }
return
}
// CHECK-LABEL: func @matmul_vec_indexed(
diff --git a/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir b/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir
index ed4f32b..654a13fc 100644
--- a/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir
+++ b/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir
@@ -4,23 +4,24 @@
// CHECK-LABEL: func @multiple_results_generic_op
func @multiple_results_generic_op(%arg0: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
- %0, %1 = linalg.generic {args_in = 1 : i64, args_out = 2 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel"]} %arg0 {
+ %0, %1 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel"]}
+ ins(%arg0 : tensor<4xf32>) {
^bb0(%gen_arg1: f32):
%tmp1 = exp %gen_arg1 : f32
linalg.yield %tmp1, %tmp1 : f32, f32
- }: tensor<4xf32> -> (tensor<4xf32>, tensor<4xf32>)
+ } -> tensor<4xf32>, tensor<4xf32>
return %0, %1 : tensor<4xf32>, tensor<4xf32>
}
// CHECK: (%[[NEW_ARG0:.*]]: [[TYPE:.*]], %[[ARG1_RESULT:.*]]: [[TYPE]], %[[ARG2_RESULT:.*]]: [[TYPE]])
// CHECK: %[[FIRST_ALLOC:.*]] = alloc() : [[TYPE]]
// CHECK: %[[SECOND_ALLOC:.*]] = alloc() : [[TYPE]]
// CHECK: linalg.generic
-// CHECK-SAME: %[[NEW_ARG0]], %[[FIRST_ALLOC]], %[[SECOND_ALLOC]]
+// CHECK-SAME: ins(%[[NEW_ARG0]] : [[TYPE]]
+// CHECK-SAME: outs(%[[FIRST_ALLOC]], %[[SECOND_ALLOC]] : [[TYPE]], [[TYPE]]
// CHECK-NEXT: ^{{[a-z0-9_]*}}
// CHECK-SAME: %{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32
// CHECK-NEXT: %{{.*}} = exp
// CHECK-NEXT: linalg.yield
-// CHECK-NEXT: [[TYPE]], [[TYPE]], [[TYPE]]
// CHECK: linalg.copy(%[[FIRST_ALLOC]], %[[ARG1_RESULT]])
// CHECK: dealloc %[[FIRST_ALLOC]]
// CHECK: linalg.copy(%[[SECOND_ALLOC]], %[[ARG2_RESULT]])
@@ -33,31 +34,33 @@ func @multiple_results_generic_op(%arg0: tensor<4xf32>) -> (tensor<4xf32>, tenso
// CHECK-LABEL: func @chained_operations
func @chained_operations(%arg0: tensor<4xf32>) -> tensor<4xf32> {
- %0 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0 {
+ %0 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
+ ins(%arg0 : tensor<4xf32>) {
^bb0(%gen_arg1: f32):
%tmp1 = exp %gen_arg1 : f32
linalg.yield %tmp1 : f32
- }: tensor<4xf32> -> tensor<4xf32>
- %1 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %0 {
+ } -> tensor<4xf32>
+ %1 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
+ ins(%0 : tensor<4xf32>) {
^bb0(%gen_arg2: f32):
%tmp2 = exp %gen_arg2 : f32
linalg.yield %tmp2 : f32
- }: tensor<4xf32> -> tensor<4xf32>
+ } -> tensor<4xf32>
return %1 : tensor<4xf32>
}
// CHECK: (%[[NEW_ARG0:.*]]: [[TYPE:.*]], %[[ARG1_RESULT:.*]]: [[TYPE]])
// CHECK: %[[FIRST_ALLOC:.*]] = alloc() : [[TYPE]]
// CHECK: linalg.generic
-// CHECK-SAME: %[[NEW_ARG0]], %[[FIRST_ALLOC]]
+// CHECK-SAME: ins(%[[NEW_ARG0]] : [[TYPE]]
+// CHECK-SAME: outs(%[[FIRST_ALLOC]] : [[TYPE]]
// CHECK: ^{{[a-z0-9_]*}}
// CHECK-SAME: %{{.*}}: f32, %{{.*}}: f32
-// CHECK: [[TYPE]], [[TYPE]]
// CHECK: %[[SECOND_ALLOC:.*]] = alloc() : [[TYPE]]
// CHECK: linalg.generic
-// CHECK-SAME: %[[FIRST_ALLOC]], %[[SECOND_ALLOC]]
+// CHECK-SAME: ins(%[[FIRST_ALLOC]] : [[TYPE]]
+// CHECK-SAME: outs(%[[SECOND_ALLOC]] : [[TYPE]]
// CHECK: ^{{[a-z0-9_]*}}
// CHECK-SAME: %{{.*}}: f32, %{{.*}}: f32
-// CHECK: [[TYPE]], [[TYPE]]
// CHECK: dealloc %[[FIRST_ALLOC]]
// CHECK: linalg.copy(%[[SECOND_ALLOC]], %[[ARG1_RESULT]])
// CHECK: dealloc %[[SECOND_ALLOC]]
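The tensors-to-buffers tests also exercise the multi-result form: a single input can feed several tensor results as long as `linalg.yield` produces one value per result and every result type is listed after the region. A sketch mirroring the shape of @multiple_results_generic_op above (names are illustrative):

  #id1 = affine_map<(d0) -> (d0)>
  func @two_results(%in: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
    %0, %1 = linalg.generic {
        indexing_maps = [#id1, #id1, #id1],
        iterator_types = ["parallel"]}
        ins(%in : tensor<4xf32>) {
      ^bb0(%a: f32):
        %e = exp %a : f32
        linalg.yield %e, %e : f32, f32
    } -> tensor<4xf32>, tensor<4xf32>
    return %0, %1 : tensor<4xf32>, tensor<4xf32>
  }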
diff --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir
index cd20d4f..1c0a196 100644
--- a/mlir/test/Dialect/Linalg/tile.mlir
+++ b/mlir/test/Dialect/Linalg/tile.mlir
@@ -349,11 +349,13 @@ func @fill(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: f32) {
func @pointwise(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%arg2: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
- linalg.generic #pointwise_2d_trait %arg0, %arg1, %arg2 {
+ linalg.generic #pointwise_2d_trait
+ ins(%arg0, %arg1 : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>)
+ outs(%arg2 : memref<?x?xf32, offset: ?, strides: [?, 1]>) {
^bb0(%arg4: f32, %arg5: f32, %arg6: f32): // no predecessors
%4 = addf %arg4, %arg5 : f32
linalg.yield %4 : f32
- }: memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>
+ }
return
}
// TILE-2-LABEL: func @pointwise
diff --git a/mlir/test/Dialect/Linalg/tile_indexed_generic.mlir b/mlir/test/Dialect/Linalg/tile_indexed_generic.mlir
index 1f7b8e5..0d38d4e 100644
--- a/mlir/test/Dialect/Linalg/tile_indexed_generic.mlir
+++ b/mlir/test/Dialect/Linalg/tile_indexed_generic.mlir
@@ -10,13 +10,15 @@
iterator_types = ["parallel"]
}
func @indexed_generic_vector(%operand: memref<50xf32>, %result: memref<50xf32>) {
- linalg.indexed_generic #pointwise_1d_trait %operand, %result {
+ linalg.indexed_generic #pointwise_1d_trait
+      ins(%operand : memref<50xf32>)
+ outs(%result : memref<50xf32>) {
^bb0(%i: index, %operand_in: f32, %result_in: f32):
%i_int = index_cast %i: index to i32
%i_float = sitofp %i_int : i32 to f32
%out = addf %operand_in, %i_float : f32
linalg.yield %out : f32
- }: memref<50xf32>, memref<50xf32>
+ }
return
}
// TILE-10n25-LABEL: func @indexed_generic_vector
@@ -53,7 +55,9 @@ func @indexed_generic_vector(%operand: memref<50xf32>, %result: memref<50xf32>)
iterator_types = ["parallel", "parallel"]
}
func @indexed_generic_matrix(%operand: memref<50x100xf32>, %result: memref<50x100xf32>) {
- linalg.indexed_generic #combined_indices_trait %operand, %result {
+ linalg.indexed_generic #combined_indices_trait
+ ins(%operand : memref<50x100xf32>)
+ outs(%result : memref<50x100xf32>) {
^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32):
%i_int = index_cast %i: index to i32
%i_float = sitofp %i_int : i32 to f32
@@ -61,7 +65,7 @@ func @indexed_generic_matrix(%operand: memref<50x100xf32>, %result: memref<50x10
%j_float = sitofp %j_int : i32 to f32
%out = addf %i_float, %j_float : f32
linalg.yield %out : f32
- }: memref<50x100xf32>, memref<50x100xf32>
+ }
return
}
// TILE-10n25-LABEL: func @indexed_generic_matrix
diff --git a/mlir/test/Dialect/Linalg/tile_parallel.mlir b/mlir/test/Dialect/Linalg/tile_parallel.mlir
index ad38095..586823c 100644
--- a/mlir/test/Dialect/Linalg/tile_parallel.mlir
+++ b/mlir/test/Dialect/Linalg/tile_parallel.mlir
@@ -14,13 +14,14 @@
func @sum(%lhs: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%rhs: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%sum: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
- linalg.generic #pointwise_2d_trait %lhs, %rhs, %sum {
+ linalg.generic #pointwise_2d_trait
+ ins(%lhs, %rhs: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+ memref<?x?xf32, offset: ?, strides: [?, 1]>)
+ outs(%sum : memref<?x?xf32, offset: ?, strides: [?, 1]>) {
^bb0(%lhs_in: f32, %rhs_in: f32, %sum_out: f32):
%result = addf %lhs_in, %rhs_in : f32
linalg.yield %result : f32
- }: memref<?x?xf32, offset: ?, strides: [?, 1]>,
- memref<?x?xf32, offset: ?, strides: [?, 1]>,
- memref<?x?xf32, offset: ?, strides: [?, 1]>
+ }
return
}
// TILE-2-LABEL: func @sum(
@@ -33,7 +34,7 @@ func @sum(%lhs: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// TILE-2: [[LHS_SUBVIEW:%.*]] = subview [[LHS]]
// TILE-2: [[RHS_SUBVIEW:%.*]] = subview [[RHS]]
// TILE-2: [[SUM_SUBVIEW:%.*]] = subview [[SUM]]
-// TILE-2: linalg.generic {{.*}} [[LHS_SUBVIEW]], [[RHS_SUBVIEW]], [[SUM_SUBVIEW]] {
+// TILE-2: linalg.generic {{.*}} ins([[LHS_SUBVIEW]], [[RHS_SUBVIEW]]{{.*}} outs([[SUM_SUBVIEW]]
// TILE-02-LABEL: func @sum(
// TILE-02-SAME: [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) {
@@ -45,12 +46,12 @@ func @sum(%lhs: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// TILE-02: [[LHS_SUBVIEW:%.*]] = subview [[LHS]]
// TILE-02: [[RHS_SUBVIEW:%.*]] = subview [[RHS]]
// TILE-02: [[SUM_SUBVIEW:%.*]] = subview [[SUM]]
-// TILE-02: linalg.generic {{.*}} [[LHS_SUBVIEW]], [[RHS_SUBVIEW]], [[SUM_SUBVIEW]] {
+// TILE-02: linalg.generic {{.*}} ins([[LHS_SUBVIEW]], [[RHS_SUBVIEW]]{{.*}} outs([[SUM_SUBVIEW]]
// TILE-002-LABEL: func @sum(
// TILE-002-SAME: [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) {
// TILE-002-NO: scf.parallel
-// TILE-002: linalg.generic {{.*}} [[LHS]], [[RHS]], [[SUM]] {
+// TILE-002: linalg.generic {{.*}} ins([[LHS]], [[RHS]]{{.*}} outs([[SUM]]
// TILE-234-LABEL: func @sum(
// TILE-234-SAME: [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) {
@@ -64,4 +65,4 @@ func @sum(%lhs: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// TILE-234: [[LHS_SUBVIEW:%.*]] = subview [[LHS]]
// TILE-234: [[RHS_SUBVIEW:%.*]] = subview [[RHS]]
// TILE-234: [[SUM_SUBVIEW:%.*]] = subview [[SUM]]
-// TILE-234: linalg.generic {{.*}} [[LHS_SUBVIEW]], [[RHS_SUBVIEW]], [[SUM_SUBVIEW]] {
+// TILE-234: linalg.generic {{.*}} ins([[LHS_SUBVIEW]], [[RHS_SUBVIEW]]{{.*}} outs([[SUM_SUBVIEW]]
diff --git a/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir b/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir
index a9733cf5..7b64742 100644
--- a/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir
+++ b/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir
@@ -59,12 +59,14 @@ func @reduction(%arg0 : memref<?x?x?xf32>,
%arg1 : memref<?x?xf32>,
%arg2 : memref<?xf32>)
{
- linalg.generic #trait %arg0, %arg1, %arg2 {
+ linalg.generic #trait
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?xf32>)
+ outs(%arg2 : memref<?xf32>) {
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
%0 = addf %arg3, %arg4 : f32
%1 = addf %0, %arg5 : f32
linalg.yield %1 : f32
- } : memref<?x?x?xf32>, memref<?x?xf32>, memref<?xf32>
+ }
return
}
@@ -82,7 +84,8 @@ func @reduction(%arg0 : memref<?x?x?xf32>,
// CHECK: %[[SV2:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG5]]]
// CHECK: %[[SV3:.*]] = subview %{{.*}}[%[[ARG4]]]
// CHECK: linalg.generic
-// CHECK-SAME: %[[SV1]], %[[SV2]], %[[SV3]]
+// CHECK-SAME: ins(%[[SV1]], %[[SV2]]
+// CHECK-SAME: outs(%[[SV3]]
// TILE1-LABEL: func @reduction
// TILE1-DAG: %[[C2:.*]] = constant 2 : index
@@ -92,7 +95,8 @@ func @reduction(%arg0 : memref<?x?x?xf32>,
// TILE1: %[[SV2:.*]] = subview %{{.*}}[%[[ARG3]], 0]
// TILE1-NOT: subview
// TILE1: linalg.generic
-// TILE1-SAME: %[[SV1]], %[[SV2]], %{{.*}}
+// TILE1-SAME: ins(%[[SV1]], %[[SV2]]
+// TILE1-SAME: outs(%{{.*}}
// TILE2-LABEL: func @reduction
// TILE2-DAG: %[[C2:.*]] = constant 2 : index
@@ -105,4 +109,5 @@ func @reduction(%arg0 : memref<?x?x?xf32>,
// TILE2: %[[SV2:.*]] = subview %{{.*}}[%[[ARG3]], 0]
// TILE2: %[[SV3:.*]] = subview %{{.*}}[%[[ARG4]]]
// TILE2: linalg.generic
-// TILE2-SAME: %[[SV1]], %[[SV2]], %[[SV3]]
+// TILE2-SAME: ins(%[[SV1]], %[[SV2]]
+// TILE2-SAME: outs(%[[SV3]]
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index 6d00396..1d9c4f9 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -105,12 +105,14 @@ func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
}
func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
%C: memref<8x32xf32>) {
- linalg.generic #matmul_trait %A, %B, %C {
+ linalg.generic #matmul_trait
+ ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>)
+ outs(%C : memref<8x32xf32>) {
^bb(%a: f32, %b: f32, %c: f32) :
%d = mulf %a, %b: f32
%e = addf %c, %d: f32
linalg.yield %e : f32
- } : memref<8x16xf32>, memref<16x32xf32>, memref<8x32xf32>
+ }
return
}
// CHECK-LABEL: func @vectorization_test
@@ -122,12 +124,14 @@ func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
%C: memref<8x32xi32>) {
- linalg.generic #matmul_trait %A, %B, %C {
+ linalg.generic #matmul_trait
+ ins(%A, %B : memref<8x16xi32>, memref<16x32xi32>)
+ outs(%C : memref<8x32xi32>) {
^bb(%a: i32, %b: i32, %c: i32) :
%d = muli %a, %b: i32
%e = addi %c, %d: i32
linalg.yield %e : i32
- } : memref<8x16xi32>, memref<16x32xi32>, memref<8x32xi32>
+ }
return
}
// CHECK-LABEL: func @vectorization_test_integer
@@ -187,23 +191,24 @@ func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) {
func @permute_generic(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
- linalg.generic #generic_matmul_trait %A, %B, %C {
+ linalg.generic #generic_matmul_trait
+ ins(%A, %B : memref<?x?xf32, offset: ?, strides: [?, 1]>,
+ memref<?x?xf32, offset: ?, strides: [?, 1]>)
+ outs(%C : memref<?x?xf32, offset: ?, strides: [?, 1]>) {
^bb(%a: f32, %b: f32, %c: f32):
%d = mulf %a, %b: f32
%e = addf %c, %d: f32
linalg.yield %e: f32
- }: memref<?x?xf32, offset: ?, strides: [?, 1]>,
- memref<?x?xf32, offset: ?, strides: [?, 1]>,
- memref<?x?xf32, offset: ?, strides: [?, 1]>
+ }
return
}
// CHECK-LABEL: func @permute_generic
-// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
+// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [#[[$kn]], #[[$nm]], #[[$km]]],
// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"],
-// CHECK-SAME: library_call = "linalg_matmul"} %{{.*}}, %{{.*}}, %{{.*}}
+// CHECK-SAME: library_call = "linalg_matmul"}
// CHECK: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>,
-// CHECK-SAME: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>,
+// CHECK-SAME: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>
// CHECK-SAME: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>
#indexed_matmul_trait = {
@@ -217,23 +222,24 @@ func @permute_generic_indexed(
%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
- linalg.indexed_generic #indexed_matmul_trait %A, %B, %C {
+ linalg.indexed_generic #indexed_matmul_trait
+ ins(%A, %B : memref<?x?xf32, offset: ?, strides: [?, 1]>,
+ memref<?x?xf32, offset: ?, strides: [?, 1]>)
+ outs(%C : memref<?x?xf32, offset: ?, strides: [?, 1]>) {
^bb(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32):
%d = mulf %a, %b: f32
%e = addf %c, %d: f32
linalg.yield %e: f32
- } : memref<?x?xf32, offset: ?, strides: [?, 1]>,
- memref<?x?xf32, offset: ?, strides: [?, 1]>,
- memref<?x?xf32, offset: ?, strides: [?, 1]>
+ }
return
}
// CHECK-LABEL: func @permute_generic_indexed
-// CHECK: linalg.indexed_generic {args_in = 2 : i64, args_out = 1 : i64,
+// CHECK: linalg.indexed_generic {
// CHECK-SAME: indexing_maps = [#[[$kn]], #[[$nm]], #[[$km]]],
// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"],
-// CHECK-SAME: library_call = "linalg_matmul_indexed"} %{{.*}}, %{{.*}}, %{{.*}}
+// CHECK-SAME: library_call = "linalg_matmul_indexed"}
// CHECK: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>,
-// CHECK-SAME: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>,
+// CHECK-SAME: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>
// CHECK-SAME: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>
func @matvec_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
index ec22dd0..f36b93ec 100644
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -344,6 +344,7 @@ TEST_FUNC(builder_helpers) {
});
});
+ // clang-format off
// CHECK-LABEL: @builder_helpers
// CHECK: affine.for %{{.*}} = affine_map<(d0) -> (d0)>({{.*}}) to affine_map<(d0) -> (d0)>({{.*}}) {
// CHECK-NEXT: affine.for %{{.*}} = affine_map<(d0) -> (d0)>({{.*}}) to affine_map<(d0) -> (d0)>({{.*}}) {
@@ -424,9 +425,11 @@ TEST_FUNC(operator_or) {
Value rhs(f.getArgument(1));
lhs || rhs;
+ // clang-format off
// CHECK-LABEL: @operator_or
// CHECK: [[ARG0:%.*]]: i1, [[ARG1:%.*]]: i1
// CHECK: or [[ARG0]], [[ARG1]]
+ // clang-format on
f.print(llvm::outs());
f.erase();
}
@@ -444,11 +447,13 @@ TEST_FUNC(operator_and) {
Value rhs(f.getArgument(1));
negate(lhs && rhs);
+ // clang-format off
// CHECK-LABEL: @operator_and
// CHECK: [[ARG0:%.*]]: i1, [[ARG1:%.*]]: i1
// CHECK: [[AND:%.*]] = and [[ARG0]], [[ARG1]]
// CHECK: [[TRUE:%.*]] = constant true
// CHECK: subi [[TRUE]], [[AND]] : i1
+ // clang-format on
f.print(llvm::outs());
f.erase();
}
@@ -632,6 +637,7 @@ TEST_FUNC(select_op_f32) {
std_select(ugt(B(i, j), B(i + one, j)), A(zero, zero), A(i, j));
});
+ // clang-format off
// CHECK-LABEL: @select_op
// CHECK: affine.for %{{.*}} = 0 to 1 {
// CHECK-NEXT: affine.for %{{.*}} = 0 to 1 {
@@ -886,22 +892,25 @@ TEST_FUNC(affine_if_op) {
// clang-format off
// CHECK-LABEL: func @linalg_generic_pointwise
-// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
+// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins({{.*}}memref<?x?xf32>, memref<?x?xf32>)
+// CHECK-SAME: outs({{.*}}memref<?x?xf32>)
// CHECK: addf
-// CHECK: }: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
-// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
+// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins({{.*}}memref<?x?xf32>, memref<?x?xf32>)
+// CHECK-SAME: outs({{.*}}memref<?x?xf32>)
// CHECK: cmpf "ogt"
// CHECK: select
-// CHECK: }: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
-// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
+// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%{{[a-z0-9]*}} : memref<?x?xf32>)
+// CHECK-SAME: outs(%{{[a-z0-9]*}} : memref<?x?xf32>)
// CHECK: tanh
-// CHECK: }: memref<?x?xf32>, memref<?x?xf32>
// clang-format on
TEST_FUNC(linalg_generic_pointwise_test) {
using namespace edsc;
@@ -929,14 +938,16 @@ TEST_FUNC(linalg_generic_pointwise_test) {
// clang-format off
// CHECK-LABEL: func @linalg_generic_matmul
-// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
+// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%{{[a-z0-9]*}}, %{{[a-z0-9]*}} : memref<?x?xf32>, memref<?x?xf32>)
+// CHECK-SAME: outs(%{{[a-z0-9]*}} : memref<?x?xf32>)
/// CHECK: ^bb0(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32):
// CHECK: %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32
// CHECK: %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32
// CHECK: linalg.yield %[[a4]] : f32
-// CHECK: }: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
+// CHECK: }
// clang-format on
TEST_FUNC(linalg_generic_matmul_test) {
using namespace edsc;
@@ -958,16 +969,18 @@ TEST_FUNC(linalg_generic_matmul_test) {
// clang-format off
// CHECK-LABEL: func @linalg_generic_conv_nhwc
-// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
+// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2 * 3 + d4 * 5, d3 * 4 + d5 * 6, d6)>,
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d1)>,
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d1)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]}
+// CHECK-SAME: ins(%{{[a-z0-9]*}}, %{{[a-z0-9]*}} : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
+// CHECK-SAME: outs(%{{[a-z0-9]*}} : memref<?x?x?x?xf32>)
/// CHECK: ^bb0(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32):
// CHECK: %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32
// CHECK: %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32
// CHECK: linalg.yield %[[a4]] : f32
-// CHECK: }: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
+// CHECK: }
// clang-format on
TEST_FUNC(linalg_generic_conv_nhwc) {
using namespace edsc;
@@ -992,16 +1005,18 @@ TEST_FUNC(linalg_generic_conv_nhwc) {
// clang-format off
// CHECK-LABEL: func @linalg_generic_dilated_conv_nhwc
-// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
+// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d3 * 3 + d5 * 5, d4 * 4 + d6 * 6, d2)>,
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d2, d1)>,
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d3, d4, d1 + d2 * 7)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
+// CHECK-SAME: ins(%{{[a-z0-9]*}}, %{{[a-z0-9]*}} : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
+// CHECK-SAME: outs(%{{[a-z0-9]*}} : memref<?x?x?x?xf32>)
// CHECK: ^bb0(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32):
// CHECK: %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32
// CHECK: %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32
// CHECK: linalg.yield %[[a4]] : f32
-// CHECK: }: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
+// CHECK: }
// clang-format on
TEST_FUNC(linalg_generic_dilated_conv_nhwc) {
using namespace edsc;
@@ -1053,38 +1068,43 @@ TEST_FUNC(linalg_metadata_ops) {
// clang-format off
// CHECK-LABEL: func @linalg_tensors
-// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
+// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%{{[a-z0-9]*}}, %{{[a-z0-9]*}} : tensor<?x?xf32>, memref<?x?xf32>)
// CHECK: addf
-// CHECK: }: tensor<?x?xf32>, memref<?x?xf32> -> tensor<?x?xf32>
-// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
+// CHECK: } -> tensor<?x?xf32>
+// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%{{[a-z0-9]*}}, %{{[a-z0-9]*}} : tensor<?x?xf32>, tensor<?x?xf32>)
// CHECK: cmpf "ogt"
// CHECK: select
-// CHECK: }: tensor<?x?xf32>, memref<?x?xf32> -> tensor<?x?xf32>
-// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
+// CHECK: } -> tensor<?x?xf32>
+// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%{{[a-z0-9]*}} : tensor<?x?xf32>)
// CHECK: tanh
-// CHECK: }: tensor<?x?xf32> -> tensor<?x?xf32>
-// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
+// CHECK: } -> tensor<?x?xf32>
+// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d2, d1)>,
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%{{[a-z0-9]*}}, %{{[a-z0-9]*}} : tensor<?x?xf32>, memref<?x?xf32>)
// CHECK: mulf
-// CHECK: }: tensor<?x?xf32>, memref<?x?xf32> -> tensor<?x?xf32>
-// CHECK: linalg.generic {args_in = 3 : i64, args_out = 1 : i64,
+// CHECK: } -> tensor<?x?xf32>
+// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d2, d1)>,
-// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1)>,
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-SAME: ins(%{{[a-z0-9]*}}, %{{[a-z0-9]*}} : tensor<?x?xf32>, memref<?x?xf32>)
+// CHECK-SAME: init(%{{[a-z0-9]*}} : tensor<?x?xf32>)
// CHECK: mulf
// CHECK: addf
-// CHECK: }: tensor<?x?xf32>, memref<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+// CHECK: } -> tensor<?x?xf32>
// clang-format on
TEST_FUNC(linalg_tensors_test) {
using namespace edsc;
@@ -1103,10 +1123,15 @@ TEST_FUNC(linalg_tensors_test) {
AffineExpr i, j;
bindDims(&globalContext(), i, j);
StructuredIndexed SA(A), SB(B), SC(tensorType);
- linalg_generic_pointwise_add(SA({i, j}), SB({i, j}), SC({i, j}));
- linalg_generic_pointwise_max(SA({i, j}), SB({i, j}), SC({i, j}));
- linalg_generic_pointwise_tanh(SA({i, j}), SC({i, j}));
- Value o1 = linalg_generic_matmul(A, B, tensorType)->getResult(0);
+ Value added = linalg_generic_pointwise_add(SA({i, j}), SB({i, j}), SC({i, j}))
+ ->getResult(0);
+ Value maxed = linalg_generic_pointwise_max(
+ SA({i, j}), StructuredIndexed(added)({i, j}), SC({i, j}))
+ ->getResult(0);
+ Value tanhed = linalg_generic_pointwise_tanh(StructuredIndexed(maxed)({i, j}),
+ SC({i, j}))
+ ->getResult(0);
+ Value o1 = linalg_generic_matmul(A, B, tanhed, tensorType)->getResult(0);
linalg_generic_matmul(A, B, o1, tensorType);
f.print(llvm::outs());
@@ -1135,19 +1160,19 @@ TEST_FUNC(vector_extractelement_op_i32) {
f.erase();
}
+// clang-format off
// CHECK-LABEL: func @memref_vector_matmul_test(
// CHECK-SAME: %[[A:.*]]: memref<?x?xvector<4x16xf32>>,
// CHECK-SAME: %[[B:.*]]: memref<?x?xvector<16x8xf32>>,
// CHECK-SAME: %[[C:.*]]: memref<?x?xvector<4x8xf32>>)
-// CHECK: linalg.generic {{.*}} %[[A]], %[[B]], %[[C]]
-// CHECK: vector.contract{{.*}}[affine_map<(d0, d1, d2) -> (d0,
-// d2)>,
+// CHECK: linalg.generic {{{.*}}}
+// CHECK-SAME: ins(%[[A]], %[[B]] : memref<?x?xvector<4x16xf32>>, memref<?x?xvector<16x8xf32>>)
+// CHECK-SAME: outs(%[[C]] : memref<?x?xvector<4x8xf32>>)
+// CHECK: vector.contract{{.*}}[affine_map<(d0, d1, d2) -> (d0, d2)>,
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d2, d1)>,
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1)>],
// CHECK-SAME: {{.*}}["parallel", "parallel", "reduction"]
-// CHECK-SAME: vector<4x16xf32>, vector<16x8xf32> into vector<4x8xf32>
-// CHECK: memref<?x?xvector<4x16xf32>>, memref<?x?xvector<16x8xf32>>,
-// CHECK-SAME: memref<?x?xvector<4x8xf32>>
+// clang-format on
TEST_FUNC(memref_vector_matmul_test) {
using namespace edsc;
using namespace edsc::ops;
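For reference, the syntax migration exercised by these test updates replaces the attribute-counted
form (args_in/args_out plus trailing operand types) with explicit ins()/outs()/init() operand lists.
A minimal sketch of the tensor case, with illustrative value and map names:

  // Old form:
  %r = linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
                       indexing_maps = [#map0, #map0],
                       iterator_types = ["parallel"]} %t {
  ^bb0(%a: f32):
    %e = exp %a : f32
    linalg.yield %e : f32
  }: tensor<5xf32> -> tensor<5xf32>

  // New form:
  %r = linalg.generic {indexing_maps = [#map0, #map0],
                       iterator_types = ["parallel"]}
         ins(%t : tensor<5xf32>) {
  ^bb0(%a: f32):
    %e = exp %a : f32
    linalg.yield %e : f32
  } -> tensor<5xf32>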
diff --git a/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir b/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir
index e1dacdf..b9b5d88 100644
--- a/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir
+++ b/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir
@@ -19,15 +19,13 @@ func @void_function_signature_conversion(%arg0: tensor<4x8xf32>) {
func @complex_signature_conversion(%arg0: tensor<5xf32>, %arg1: memref<10xf32>, %arg2: i1, %arg3: f16) -> (i1, tensor<5xf32>, memref<10xf32>, memref<15xf32>, f16) {
%0 = alloc() : memref<15xf32>
%1 = linalg.generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel"]
- } %arg0 {
+ iterator_types = ["parallel"]}
+ ins(%arg0 : tensor<5xf32>) {
^bb0(%gen1_arg0: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
- }: tensor<5xf32> -> tensor<5xf32>
+ } -> tensor<5xf32>
return %arg2, %1, %arg1, %0, %arg3 : i1, tensor<5xf32>, memref<10xf32>, memref<15xf32>, f16
}
// CHECK: (%[[ARG0:.*]]: memref<5xf32>, %[[ARG1:.*]]: memref<10xf32>, %[[ARG2:.*]]: i1, %[[ARG3:.*]]: f16)
diff --git a/mlir/test/Transforms/buffer-placement-preparation.mlir b/mlir/test/Transforms/buffer-placement-preparation.mlir
index b1cfdfd..4fcd225 100644
--- a/mlir/test/Transforms/buffer-placement-preparation.mlir
+++ b/mlir/test/Transforms/buffer-placement-preparation.mlir
@@ -17,11 +17,12 @@ func @func_signature_conversion(%arg0: tensor<4x8xf32>) {
// CHECK-LABEL: func @memref_in_function_results
func @memref_in_function_results(%arg0: tensor<5xf32>, %arg1: memref<10xf32>) -> (tensor<5xf32>, memref<10xf32>, memref<15xf32>) {
%0 = alloc() : memref<15xf32>
- %1 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0 {
+ %1 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
+ ins(%arg0 : tensor<5xf32>) {
^bb0(%gen1_arg0: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
- }: tensor<5xf32> -> tensor<5xf32>
+ } -> tensor<5xf32>
return %1, %arg1, %0 : tensor<5xf32>, memref<10xf32>, memref<15xf32>
}
// CHECK: (%[[ARG0:.*]]: memref<5xf32>, %[[ARG1:.*]]: memref<10xf32>, %[[RESULT:.*]]: memref<5xf32>)
@@ -97,23 +98,25 @@ func @func_and_block_signature_conversion(%arg0 : tensor<2xf32>, %cond : i1, %ar
// CHECK-LABEL: func @compute_allocs_position_simple
func @compute_allocs_position_simple(%cond: i1, %arg0: tensor<2xf32>) -> tensor<2xf32>{
- %0 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0 {
+ %0 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
+ ins(%arg0 : tensor<2xf32>) {
^bb0(%gen1_arg0: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
- }: tensor<2xf32> -> tensor<2xf32>
- %1 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %0 {
+ } -> tensor<2xf32>
+ %1 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
+ ins(%0 : tensor<2xf32>) {
^bb0(%gen2_arg0: f32):
%tmp2 = exp %gen2_arg0 : f32
linalg.yield %tmp2 : f32
- }: tensor<2xf32> -> tensor<2xf32>
+ } -> tensor<2xf32>
return %1 : tensor<2xf32>
}
// CHECK: (%{{.*}}: {{.*}}, %[[ARG0:.*]]: memref<2xf32>,
// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc()
-// CHECK-NEXT: linalg.generic {{.*}} %[[ARG0]], %[[FIRST_ALLOC]]
+// CHECK-NEXT: linalg.generic {{.*}} ins(%[[ARG0]]{{.*}} outs(%[[FIRST_ALLOC]]
// CHECK: %[[SECOND_ALLOC:.*]] = alloc()
-// CHECK-NEXT: linalg.generic {{.*}} %[[FIRST_ALLOC]], %[[SECOND_ALLOC]]
+// CHECK-NEXT: linalg.generic {{.*}} ins(%[[FIRST_ALLOC]]{{.*}} outs(%[[SECOND_ALLOC]]
// -----
@@ -123,78 +126,86 @@ func @compute_allocs_position_simple(%cond: i1, %arg0: tensor<2xf32>) -> tensor<
// CHECK-LABEL: func @compute_allocs_position
func @compute_allocs_position(%cond: i1, %arg0: tensor<2xf32>) -> tensor<2xf32>{
- %0 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0 {
+ %0 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
+ ins(%arg0 : tensor<2xf32>) {
^bb0(%gen1_arg0: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
- }: tensor<2xf32> -> tensor<2xf32>
- %1 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %0 {
+ } -> tensor<2xf32>
+ %1 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
+ ins(%0 : tensor<2xf32>) {
^bb0(%gen2_arg0: f32):
%tmp2 = exp %gen2_arg0 : f32
linalg.yield %tmp2 : f32
- }: tensor<2xf32> -> tensor<2xf32>
+ } -> tensor<2xf32>
cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>),
^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>)
^bb1(%arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>):
- %2 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0 {
+ %2 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
+ ins(%arg0 : tensor<2xf32>) {
^bb0(%gen3_arg0: f32):
%tmp3 = exp %gen3_arg0 : f32
linalg.yield %tmp3 : f32
- }: tensor<2xf32> -> tensor<2xf32>
- %3 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %2 {
+ } -> tensor<2xf32>
+ %3 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
+ ins(%2 : tensor<2xf32>) {
^bb0(%gen4_arg0: f32):
%tmp4 = exp %gen4_arg0 : f32
linalg.yield %tmp4 : f32
- }: tensor<2xf32> -> tensor<2xf32>
+ } -> tensor<2xf32>
br ^exit(%arg1, %arg2 : tensor<2xf32>, tensor<2xf32>)
^bb2(%arg3 : tensor<2xf32>, %arg4 : tensor<2xf32>):
- %4 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0 {
+ %4 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
+ ins(%arg0 : tensor<2xf32>) {
^bb0(%gen5_arg0: f32):
%tmp5 = exp %gen5_arg0 : f32
linalg.yield %tmp5 : f32
- }: tensor<2xf32> -> tensor<2xf32>
- %5 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %4 {
+ } -> tensor<2xf32>
+ %5 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
+ ins(%4 : tensor<2xf32>) {
^bb0(%gen6_arg0: f32):
%tmp6 = exp %gen6_arg0 : f32
linalg.yield %tmp6 : f32
- }: tensor<2xf32> -> tensor<2xf32>
+ } -> tensor<2xf32>
br ^exit(%arg3, %arg4 : tensor<2xf32>, tensor<2xf32>)
^exit(%arg5 : tensor<2xf32>, %arg6 : tensor<2xf32>):
- %6 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0 {
+ %6 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
+ ins(%arg0 : tensor<2xf32>) {
^bb0(%gen7_arg0: f32):
%tmp7 = exp %gen7_arg0 : f32
linalg.yield %tmp7 : f32
- }: tensor<2xf32> -> tensor<2xf32>
- %7 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %6 {
+ } -> tensor<2xf32>
+ %7 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
+ ins(%6 : tensor<2xf32>) {
^bb0(%gen8_arg0: f32):
%tmp8 = exp %gen8_arg0 : f32
linalg.yield %tmp8 : f32
- }: tensor<2xf32> -> tensor<2xf32>
+ } -> tensor<2xf32>
return %7 : tensor<2xf32>
}
// CHECK: (%{{.*}}: {{.*}}, %[[ARG0:.*]]: memref<2xf32>,
// CHECK-NEXT: %[[ALLOC0:.*]] = alloc()
-// CHECK-NEXT: linalg.generic {{.*}} %[[ARG0]], %[[ALLOC0]]
+// CHECK-NEXT: linalg.generic {{.*}} ins(%[[ARG0]]{{.*}} outs(%[[ALLOC0]]
// CHECK: %[[ALLOC1:.*]] = alloc()
-// CHECK-NEXT: linalg.generic {{.*}} %[[ALLOC0]], %[[ALLOC1]]
+// CHECK-NEXT: linalg.generic {{.*}} ins(%[[ALLOC0]]{{.*}} outs(%[[ALLOC1]]
// CHECK: cond_br %{{.*}}, ^[[BB0:.*]]({{.*}}), ^[[BB1:.*]](
// CHECK-NEXT: ^[[BB0]]
// CHECK-NEXT: %[[ALLOC2:.*]] = alloc()
-// CHECK-NEXT: linalg.generic {{.*}} %[[ARG0]], %[[ALLOC2]]
+// CHECK-NEXT: linalg.generic {{.*}} ins(%[[ARG0]]{{.*}} outs(%[[ALLOC2]]
// CHECK: %[[ALLOC3:.*]] = alloc()
-// CHECK-NEXT: linalg.generic {{.*}} %[[ALLOC2]], %[[ALLOC3]]
+// CHECK-NEXT: linalg.generic {{.*}} ins(%[[ALLOC2]]{{.*}} outs(%[[ALLOC3]]
// CHECK: br ^[[EXIT:.*]]({{.*}})
// CHECK-NEXT: ^[[BB1]]
// CHECK-NEXT: %[[ALLOC4:.*]] = alloc()
-// CHECK-NEXT: linalg.generic {{.*}} %[[ARG0]], %[[ALLOC4]]
+// CHECK-NEXT: linalg.generic {{.*}} ins(%[[ARG0]]{{.*}} outs(%[[ALLOC4]]
// CHECK: %[[ALLOC5:.*]] = alloc()
-// CHECK-NEXT: linalg.generic {{.*}} %[[ALLOC4]], %[[ALLOC5]]
+// CHECK-NEXT: linalg.generic {{.*}} ins(%[[ALLOC4]]{{.*}} outs(%[[ALLOC5]]
// CHECK: br ^[[EXIT]]
// CHECK-NEXT: ^[[EXIT]]
// CHECK-NEXT: %[[ALLOC6:.*]] = alloc()
-// CHECK-NEXT: linalg.generic {{.*}} %[[ARG0]], %[[ALLOC6]]
+// CHECK-NEXT: linalg.generic {{.*}} ins(%[[ARG0]]{{.*}} outs(%[[ALLOC6]]
// CHECK: %[[ALLOC7:.*]] = alloc()
-// CHECK-NEXT: linalg.generic {{.*}} %[[ALLOC6]], %[[ALLOC7]]
+// CHECK-NEXT: linalg.generic {{.*}} ins(%[[ALLOC6]]{{.*}} outs(%[[ALLOC7]]
// -----
@@ -211,16 +222,12 @@ func @compute_allocs_position(%cond: i1, %arg0: tensor<2xf32>) -> tensor<2xf32>{
// CHECK-LABEL: func @callee
func @callee(%arg1: tensor<5xf32>) -> tensor<5xf32> {
- %0 = linalg.generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
- indexing_maps = [#map0, #map0],
- iterator_types = ["parallel"]
- } %arg1 {
+ %0 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
+ ins(%arg1 : tensor<5xf32>) {
^bb0(%gen1_arg0: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
- }: tensor<5xf32> -> tensor<5xf32>
+ } -> tensor<5xf32>
return %0 : tensor<5xf32>
}
// CHECK: (%[[CALLEE_ARG:.*]]: memref<5xf32>, %[[CALLEE_RESULT:.*]]: memref<5xf32>)
diff --git a/mlir/test/Transforms/buffer-placement.mlir b/mlir/test/Transforms/buffer-placement.mlir
index e03f8c9..d8b7de3 100644
--- a/mlir/test/Transforms/buffer-placement.mlir
+++ b/mlir/test/Transforms/buffer-placement.mlir
@@ -24,14 +24,14 @@ func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
^bb2:
%0 = alloc() : memref<2xf32>
linalg.generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel"]} %arg1, %0 {
+ iterator_types = ["parallel"]}
+ ins(%arg1: memref<2xf32>)
+ outs(%0: memref<2xf32>) {
^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
- }: memref<2xf32>, memref<2xf32>
+ }
br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>):
"linalg.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
@@ -73,14 +73,14 @@ func @condBranchDynamicType(
^bb2(%0: index):
%1 = alloc(%0) : memref<?xf32>
linalg.generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel"]} %arg1, %1 {
+ iterator_types = ["parallel"]}
+ ins(%arg1: memref<?xf32>)
+ outs(%1: memref<?xf32>) {
^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
- }: memref<?xf32>, memref<?xf32>
+ }
br ^bb3(%1 : memref<?xf32>)
^bb3(%2: memref<?xf32>):
"linalg.copy"(%2, %arg2) : (memref<?xf32>, memref<?xf32>) -> ()
@@ -141,14 +141,14 @@ func @condBranchDynamicTypeNested(
^bb2(%0: index):
%1 = alloc(%0) : memref<?xf32>
linalg.generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel"]} %arg1, %1 {
+ iterator_types = ["parallel"]}
+ ins(%arg1: memref<?xf32>)
+ outs(%1: memref<?xf32>) {
^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
- }: memref<?xf32>, memref<?xf32>
+ }
cond_br %arg0, ^bb3, ^bb4
^bb3:
br ^bb5(%1 : memref<?xf32>)
@@ -224,14 +224,14 @@ func @criticalEdge(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
^bb1:
%0 = alloc() : memref<2xf32>
linalg.generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel"]} %arg1, %0 {
+ iterator_types = ["parallel"]}
+ ins(%arg1: memref<2xf32>)
+ outs(%0: memref<2xf32>) {
^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
- }: memref<2xf32>, memref<2xf32>
+ }
br ^bb2(%0 : memref<2xf32>)
^bb2(%1: memref<2xf32>):
"linalg.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
@@ -262,14 +262,14 @@ func @criticalEdge(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
func @invCriticalEdge(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
%0 = alloc() : memref<2xf32>
linalg.generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel"]} %arg1, %0 {
+ iterator_types = ["parallel"]}
+ ins(%arg1: memref<2xf32>)
+ outs(%0: memref<2xf32>) {
^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
- }: memref<2xf32>, memref<2xf32>
+ }
cond_br %arg0, ^bb1, ^bb2(%arg1 : memref<2xf32>)
^bb1:
br ^bb2(%0 : memref<2xf32>)
@@ -300,14 +300,14 @@ func @invCriticalEdge(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
func @ifElse(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
%0 = alloc() : memref<2xf32>
linalg.generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel"]} %arg1, %0 {
+ iterator_types = ["parallel"]}
+ ins(%arg1: memref<2xf32>)
+ outs(%0: memref<2xf32>) {
^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
- }: memref<2xf32>, memref<2xf32>
+ }
cond_br %arg0,
^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
@@ -318,14 +318,14 @@ func @ifElse(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
^bb3(%5: memref<2xf32>, %6: memref<2xf32>):
%7 = alloc() : memref<2xf32>
linalg.generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel"]} %5, %7 {
+ iterator_types = ["parallel"]}
+ ins(%5: memref<2xf32>)
+ outs(%7: memref<2xf32>) {
^bb0(%gen2_arg0: f32, %gen2_arg1: f32):
%tmp2 = exp %gen2_arg0 : f32
linalg.yield %tmp2 : f32
- }: memref<2xf32>, memref<2xf32>
+ }
"linalg.copy"(%7, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
return
}
@@ -357,14 +357,14 @@ func @ifElse(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
func @ifElseNoUsers(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
%0 = alloc() : memref<2xf32>
linalg.generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel"]} %arg1, %0 {
+ iterator_types = ["parallel"]}
+ ins(%arg1: memref<2xf32>)
+ outs(%0: memref<2xf32>) {
^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
- }: memref<2xf32>, memref<2xf32>
+ }
cond_br %arg0,
^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
@@ -401,14 +401,14 @@ func @ifElseNoUsers(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
func @ifElseNested(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
%0 = alloc() : memref<2xf32>
linalg.generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel"]} %arg1, %0 {
+ iterator_types = ["parallel"]}
+ ins(%arg1: memref<2xf32>)
+ outs(%0: memref<2xf32>) {
^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
- }: memref<2xf32>, memref<2xf32>
+ }
cond_br %arg0,
^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
@@ -423,14 +423,14 @@ func @ifElseNested(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
^bb5(%7: memref<2xf32>, %8: memref<2xf32>):
%9 = alloc() : memref<2xf32>
linalg.generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel"]} %7, %9 {
+ iterator_types = ["parallel"]}
+ ins(%7: memref<2xf32>)
+ outs(%9: memref<2xf32>) {
^bb0(%gen2_arg0: f32, %gen2_arg1: f32):
%tmp2 = exp %gen2_arg0 : f32
linalg.yield %tmp2 : f32
- }: memref<2xf32>, memref<2xf32>
+ }
"linalg.copy"(%9, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
return
}
@@ -456,32 +456,32 @@ func @ifElseNested(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
func @redundantOperations(%arg0: memref<2xf32>) {
%0 = alloc() : memref<2xf32>
linalg.generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel"]} %arg0, %0 {
+ iterator_types = ["parallel"]}
+ ins(%arg0: memref<2xf32>)
+ outs(%0: memref<2xf32>) {
^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
- }: memref<2xf32>, memref<2xf32>
+ }
%1 = alloc() : memref<2xf32>
linalg.generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel"]} %0, %1 {
+ iterator_types = ["parallel"]}
+ ins(%0: memref<2xf32>)
+ outs(%1: memref<2xf32>) {
^bb0(%gen2_arg0: f32, %gen2_arg1: f32):
%tmp2 = exp %gen2_arg0 : f32
linalg.yield %tmp2 : f32
- }: memref<2xf32>, memref<2xf32>
+ }
return
}
// CHECK: (%[[ARG0:.*]]: {{.*}})
// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc()
-// CHECK-NEXT: linalg.generic {{.*}} %[[ARG0]], %[[FIRST_ALLOC]]
+// CHECK-NEXT: linalg.generic {{.*}} ins(%[[ARG0]]{{.*}}outs(%[[FIRST_ALLOC]]
// CHECK: %[[SECOND_ALLOC:.*]] = alloc()
-// CHECK-NEXT: linalg.generic {{.*}} %[[FIRST_ALLOC]], %[[SECOND_ALLOC]]
+// CHECK-NEXT: linalg.generic {{.*}} ins(%[[FIRST_ALLOC]]{{.*}}outs(%[[SECOND_ALLOC]]
// CHECK: dealloc
// CHECK-NEXT: dealloc
// CHECK-NEXT: return
@@ -509,26 +509,26 @@ func @moving_alloc_and_inserting_missing_dealloc(
^bb1:
%0 = alloc() : memref<2xf32>
linalg.generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel"]} %arg0, %0 {
+ iterator_types = ["parallel"]}
+ ins(%arg0: memref<2xf32>)
+ outs(%0: memref<2xf32>) {
^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
- }: memref<2xf32>, memref<2xf32>
+ }
br ^exit(%0 : memref<2xf32>)
^bb2:
%1 = alloc() : memref<2xf32>
linalg.generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel"]} %arg0, %1 {
+ iterator_types = ["parallel"]}
+ ins(%arg0: memref<2xf32>)
+ outs(%1: memref<2xf32>) {
^bb0(%gen2_arg0: f32, %gen2_arg1: f32):
%tmp2 = exp %gen2_arg0 : f32
linalg.yield %tmp2 : f32
- }: memref<2xf32>, memref<2xf32>
+ }
br ^exit(%1 : memref<2xf32>)
^exit(%arg2: memref<2xf32>):
"linalg.copy"(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>) -> ()
@@ -567,14 +567,14 @@ func @moving_invalid_dealloc_op_complex(
^bb2:
%1 = alloc() : memref<2xf32>
linalg.generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel"]} %arg0, %1 {
+ iterator_types = ["parallel"]}
+ ins(%arg0: memref<2xf32>)
+ outs(%1: memref<2xf32>) {
^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
- }: memref<2xf32>, memref<2xf32>
+ }
dealloc %1 : memref<2xf32>
br ^exit(%1 : memref<2xf32>)
^exit(%arg2: memref<2xf32>):
@@ -599,14 +599,14 @@ func @inserting_missing_dealloc_simple(
%arg1: memref<2xf32>) {
%0 = alloc() : memref<2xf32>
linalg.generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel"]} %arg0, %0 {
+ iterator_types = ["parallel"]}
+ ins(%arg0: memref<2xf32>)
+ outs(%0: memref<2xf32>) {
^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
- }: memref<2xf32>, memref<2xf32>
+ }
"linalg.copy"(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> ()
return
}
@@ -625,14 +625,14 @@ func @inserting_missing_dealloc_simple(
func @moving_invalid_dealloc_op(%arg0 : memref<2xf32>, %arg1: memref<2xf32>) {
%0 = alloc() : memref<2xf32>
linalg.generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel"]} %arg0, %0 {
+ iterator_types = ["parallel"]}
+ ins(%arg0: memref<2xf32>)
+ outs(%0: memref<2xf32>) {
^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
- }: memref<2xf32>, memref<2xf32>
+ }
dealloc %0 : memref<2xf32>
"linalg.copy"(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> ()
return
@@ -659,17 +659,21 @@ func @nested_regions_and_cond_branch(%arg0: i1, %arg1: memref<2xf32>, %arg2: mem
br ^bb3(%arg1 : memref<2xf32>)
^bb2:
%0 = alloc() : memref<2xf32>
- linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg1, %0 {
+ linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
+ ins(%arg1: memref<2xf32>)
+ outs(%0: memref<2xf32>) {
^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
%1 = alloc() : memref<2xf32>
- linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg1, %1 {
+ linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
+ ins(%arg1: memref<2xf32>)
+ outs(%1: memref<2xf32>) {
^bb0(%gen2_arg0: f32, %gen2_arg1: f32):
%tmp2 = exp %gen2_arg0 : f32
linalg.yield %tmp2 : f32
- }: memref<2xf32>, memref<2xf32>
+ }
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
- }: memref<2xf32>, memref<2xf32>
+ }
br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>):
"linalg.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
@@ -679,9 +683,9 @@ func @nested_regions_and_cond_branch(%arg0: i1, %arg1: memref<2xf32>, %arg2: mem
// CHECK-NEXT: %[[GENERIC1_ALLOC:.*]] = alloc()
// CHECK-NEXT: cond_br %[[cond]], ^[[BB1:.*]], ^[[BB2:.*]]
// CHECK: ^[[BB2]]:
-// CHECK-NEXT: linalg.generic {{{.*}}} %[[ARG1]], %[[GENERIC1_ALLOC]]
+// CHECK-NEXT: linalg.generic {{{.*}}} ins(%[[ARG1]]{{.*}}outs(%[[GENERIC1_ALLOC]]
// CHECK: %[[GENERIC2_ALLOC:.*]] = alloc()
-// CHECK-NEXT: linalg.generic {{{.*}}} %[[ARG1]], %[[GENERIC2_ALLOC]]
+// CHECK-NEXT: linalg.generic {{{.*}}} ins(%[[ARG1]]{{.*}}outs(%[[GENERIC2_ALLOC]]
// CHECK: dealloc %[[GENERIC2_ALLOC]]
// CHECK-NEXT: %{{.*}} = exp
// CHECK: ^[[BB3:.*]]({{.*}}):
@@ -701,11 +705,13 @@ func @nested_regions_and_cond_branch(%arg0: i1, %arg1: memref<2xf32>, %arg2: mem
func @memref_in_function_results(%arg0: memref<5xf32>, %arg1: memref<10xf32>, %arg2: memref<5xf32>) -> (memref<10xf32>, memref<15xf32>) {
%x = alloc() : memref<15xf32>
%y = alloc() : memref<5xf32>
- linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0, %y {
+ linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
+ ins(%arg0: memref<5xf32>)
+ outs(%y: memref<5xf32>) {
^bb0(%arg3: f32, %arg4: f32):
%2 = exp %arg3 : f32
linalg.yield %2 : f32
- }: memref<5xf32>, memref<5xf32>
+ }
linalg.copy(%y, %arg2) : memref<5xf32>, memref<5xf32>
return %arg1, %x : memref<10xf32>, memref<15xf32>
}
@@ -946,14 +952,14 @@ func @condBranchAlloca(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
^bb2:
%0 = alloca() : memref<2xf32>
linalg.generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel"]} %arg1, %0 {
+ iterator_types = ["parallel"]}
+ ins(%arg1: memref<2xf32>)
+ outs(%0: memref<2xf32>) {
^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
- }: memref<2xf32>, memref<2xf32>
+ }
br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>):
"linalg.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
@@ -975,14 +981,14 @@ func @condBranchAlloca(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
func @ifElseAlloca(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
%0 = alloc() : memref<2xf32>
linalg.generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel"]} %arg1, %0 {
+ iterator_types = ["parallel"]}
+ ins(%arg1: memref<2xf32>)
+ outs(%0: memref<2xf32>) {
^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
- }: memref<2xf32>, memref<2xf32>
+ }
cond_br %arg0,
^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
@@ -993,14 +999,14 @@ func @ifElseAlloca(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
^bb3(%5: memref<2xf32>, %6: memref<2xf32>):
%7 = alloca() : memref<2xf32>
linalg.generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel"]} %5, %7 {
+ iterator_types = ["parallel"]}
+ ins(%5: memref<2xf32>)
+ outs(%7: memref<2xf32>) {
^bb0(%gen2_arg0: f32, %gen2_arg1: f32):
%tmp2 = exp %gen2_arg0 : f32
linalg.yield %tmp2 : f32
- }: memref<2xf32>, memref<2xf32>
+ }
"linalg.copy"(%7, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
return
}
@@ -1021,14 +1027,14 @@ func @ifElseAlloca(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
func @ifElseNestedAlloca(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
%0 = alloca() : memref<2xf32>
linalg.generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel"]} %arg1, %0 {
+ iterator_types = ["parallel"]}
+ ins(%arg1: memref<2xf32>)
+ outs(%0: memref<2xf32>) {
^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
- }: memref<2xf32>, memref<2xf32>
+ }
cond_br %arg0,
^bb1(%arg1, %0 : memref<2xf32>, memref<2xf32>),
^bb2(%0, %arg1 : memref<2xf32>, memref<2xf32>)
@@ -1043,14 +1049,14 @@ func @ifElseNestedAlloca(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>)
^bb5(%7: memref<2xf32>, %8: memref<2xf32>):
%9 = alloc() : memref<2xf32>
linalg.generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel"]} %7, %9 {
+ iterator_types = ["parallel"]}
+ ins(%7: memref<2xf32>)
+ outs(%9: memref<2xf32>) {
^bb0(%gen2_arg0: f32, %gen2_arg1: f32):
%tmp2 = exp %gen2_arg0 : f32
linalg.yield %tmp2 : f32
- }: memref<2xf32>, memref<2xf32>
+ }
"linalg.copy"(%9, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
return
}
@@ -1074,17 +1080,21 @@ func @nestedRegionsAndCondBranchAlloca(%arg0: i1, %arg1: memref<2xf32>, %arg2: m
br ^bb3(%arg1 : memref<2xf32>)
^bb2:
%0 = alloc() : memref<2xf32>
- linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg1, %0 {
+ linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
+ ins(%arg1: memref<2xf32>)
+ outs(%0: memref<2xf32>) {
^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
%1 = alloca() : memref<2xf32>
- linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg1, %1 {
+ linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
+ ins(%arg1: memref<2xf32>)
+ outs(%1: memref<2xf32>) {
^bb0(%gen2_arg0: f32, %gen2_arg1: f32):
%tmp2 = exp %gen2_arg0 : f32
linalg.yield %tmp2 : f32
- }: memref<2xf32>, memref<2xf32>
+ }
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
- }: memref<2xf32>, memref<2xf32>
+ }
br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>):
"linalg.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
@@ -1094,9 +1104,9 @@ func @nestedRegionsAndCondBranchAlloca(%arg0: i1, %arg1: memref<2xf32>, %arg2: m
// CHECK-NEXT: %[[ALLOC:.*]] = alloc()
// CHECK-NEXT: cond_br %[[cond]], ^[[BB1:.*]], ^[[BB2:.*]]
// CHECK: ^[[BB2]]:
-// CHECK-NEXT: linalg.generic {{{.*}}} %[[ARG1]], %[[ALLOC]]
+// CHECK-NEXT: linalg.generic {{{.*}}} ins(%[[ARG1]]{{.*}}outs(%[[ALLOC]]
// CHECK: %[[ALLOCA:.*]] = alloca()
-// CHECK-NEXT: linalg.generic {{{.*}}} %[[ARG1]], %[[ALLOCA]]
+// CHECK-NEXT: linalg.generic {{{.*}}} ins(%[[ARG1]]{{.*}}outs(%[[ALLOCA]]
// CHECK: %{{.*}} = exp
// CHECK: ^[[BB3:.*]]({{.*}}):
// CHECK: linalg.copy
diff --git a/mlir/test/Transforms/copy-removal.mlir b/mlir/test/Transforms/copy-removal.mlir
index a0d1193..1a8ea02 100644
--- a/mlir/test/Transforms/copy-removal.mlir
+++ b/mlir/test/Transforms/copy-removal.mlir
@@ -157,14 +157,14 @@ func @test_with_temp_usage_after_copy() -> memref<5xf32> {
%temp = alloc() : memref<5xf32>
linalg.copy(%ret, %temp) : memref<5xf32>, memref<5xf32>
linalg.generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel"]} %temp, %res {
+ iterator_types = ["parallel"]}
+ ins(%temp : memref<5xf32>)
+ outs(%res : memref<5xf32>) {
^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
- }: memref<5xf32>, memref<5xf32>
+ }
dealloc %ret : memref<5xf32>
return %temp : memref<5xf32>
}
@@ -231,18 +231,18 @@ func @test_ReuseCopyTargetAsSource(%arg0: memref<2xf32>, %result: memref<2xf32>)
// CHECK-NOT: %{{.*}} = alloc
%temp = alloc() : memref<2xf32>
// CHECK-NEXT: linalg.generic
- // CHECK-SAME: %[[ARG0]], %[[RES]]
+ // CHECK-SAME: ins(%[[ARG0]]{{.*}}outs(%[[RES]]
// CHECK-NOT: linalg.copy(%{{.*}}, %[[RES]])
// CHECK-NOT: dealloc %{{.*}}
linalg.generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel"]} %arg0, %temp {
+ iterator_types = ["parallel"]}
+ ins(%arg0 : memref<2xf32>)
+ outs(%temp : memref<2xf32>) {
^bb0(%gen2_arg0: f32, %gen2_arg1: f32):
%tmp2 = exp %gen2_arg0 : f32
linalg.yield %tmp2 : f32
- }: memref<2xf32>, memref<2xf32>
+ }
"linalg.copy"(%temp, %result) : (memref<2xf32>, memref<2xf32>) -> ()
dealloc %temp : memref<2xf32>
// CHECK: return
@@ -261,23 +261,23 @@ func @test_ReuseCopyTargetAsSource(%arg0: memref<2xf32>){
%to = alloc() : memref<2xf32>
%temp = alloc() : memref<2xf32>
linalg.generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel"]} %arg0, %temp {
+ iterator_types = ["parallel"]}
+ ins(%arg0 : memref<2xf32>)
+ outs(%temp : memref<2xf32>) {
^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
- }: memref<2xf32>, memref<2xf32>
+ }
linalg.generic {
- args_in = 1 : i64,
- args_out = 1 : i64,
indexing_maps = [#map0, #map0],
- iterator_types = ["parallel"]} %arg0, %to {
+ iterator_types = ["parallel"]}
+ ins(%arg0 : memref<2xf32>)
+ outs(%to : memref<2xf32>) {
^bb0(%gen2_arg0: f32, %gen2_arg1: f32):
%tmp2 = exp %gen2_arg0 : f32
linalg.yield %tmp2 : f32
- }: memref<2xf32>, memref<2xf32>
+ }
// CHECK: linalg.copy
"linalg.copy"(%temp, %to) : (memref<2xf32>, memref<2xf32>) -> ()
dealloc %temp : memref<2xf32>
diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
index c338f0f..dd6629e 100644
--- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp
+++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
@@ -39,6 +39,11 @@ struct TestBufferPlacementPreparationPass
/// Converts tensor-type generic linalg operations to memref ones using
/// buffer assignment.
+  /// TODO: Avoid the copy-pasta by exposing the pattern from BufferPlacement.h.
+  /// This is limited by not wanting BufferPlacement to depend on Linalg. Fixing
+  /// this probably requires an OpConversionPattern over a generic Operation*;
+  /// for now only RewritePattern, not ConversionPattern, allows this.
+
class GenericOpConverter
: public BufferAssignmentOpConversionPattern<linalg::GenericOp> {
public:
@@ -48,34 +53,47 @@ struct TestBufferPlacementPreparationPass
LogicalResult
matchAndRewrite(linalg::GenericOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
+ linalg::GenericOpAdaptor adaptor(operands,
+ op.getOperation()->getAttrDictionary());
+
+ // TODO: support ops with reduction.
+ if (!op.init_tensors().empty())
+ return failure();
+
+ // All inputs need to be turned into buffers first. Until then, bail out.
+ if (llvm::any_of(adaptor.inputs(), [](Value in) {
+ return !in.getType().isa<MemRefType>();
+ }))
+ return failure();
+
Location loc = op.getLoc();
- ResultRange results = op.getOperation()->getResults();
- SmallVector<Value, 2> newArgs, newResults;
- newArgs.reserve(operands.size() + results.size());
- newArgs.append(operands.begin(), operands.end());
- newResults.reserve(results.size());
+ SmallVector<Value, 2> outputBuffers, newOutputBuffers;
+ outputBuffers.assign(adaptor.output_buffers().begin(),
+ adaptor.output_buffers().end());
+ newOutputBuffers.reserve(op.getNumOutputs());
+ newOutputBuffers.append(adaptor.output_buffers().begin(),
+ adaptor.output_buffers().end());
// Update all types to memref types.
- for (auto result : results) {
- ShapedType type = result.getType().cast<ShapedType>();
- assert(type && "Generic operations with non-shaped typed results are "
- "not currently supported.");
+ for (Type t : op.getResultTypes()) {
+ auto type = t.cast<ShapedType>();
if (!type.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "dynamic shapes not currently supported");
auto memrefType =
MemRefType::get(type.getShape(), type.getElementType());
auto alloc = rewriter.create<AllocOp>(loc, memrefType);
- newArgs.push_back(alloc);
- newResults.push_back(alloc);
+ newOutputBuffers.push_back(alloc);
}
// Generate a new linalg operation that works on buffers.
auto linalgOp = rewriter.create<linalg::GenericOp>(
- loc, llvm::None, newArgs, rewriter.getI64IntegerAttr(operands.size()),
- rewriter.getI64IntegerAttr(results.size()), op.indexing_maps(),
- op.iterator_types(), op.docAttr(), op.library_callAttr(),
- op.symbol_sourceAttr());
+ loc,
+ /*resultTensorTypes=*/ArrayRef<Type>{},
+ /*inputs=*/adaptor.inputs(),
+ /*outputBuffers=*/newOutputBuffers,
+ /*initTensors=*/ValueRange{}, op.indexing_maps(), op.iterator_types(),
+ op.docAttr(), op.library_callAttr(), op.symbol_sourceAttr());
// Create a new block in the region of the new Generic Op.
Block &oldBlock = op.getRegion().front();
@@ -83,23 +101,24 @@ struct TestBufferPlacementPreparationPass
Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
oldBlock.getArgumentTypes());
- // Map the old block arguments to the new ones.
- BlockAndValueMapping mapping;
- mapping.map(oldBlock.getArguments(), newBlock->getArguments());
-
// Add the result arguments to the new block.
- for (auto result : newResults)
- newBlock->addArgument(
- result.getType().cast<ShapedType>().getElementType());
+ for (Value v : newOutputBuffers)
+ newBlock->addArgument(v.getType().cast<MemRefType>().getElementType());
// Clone the body of the old block to the new block.
+ BlockAndValueMapping mapping;
+ for (unsigned i = 0; i < oldBlock.getNumArguments(); i++)
+ mapping.map(oldBlock.getArgument(i), newBlock->getArgument(i));
+
+ OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToEnd(newBlock);
- for (auto &op : oldBlock.getOperations())
- rewriter.clone(op, mapping);
+ for (auto &op : oldBlock.getOperations()) {
+ Operation *clonedOp = rewriter.clone(op, mapping);
+ mapping.map(op.getResults(), clonedOp->getResults());
+ }
- // Replace the results of the old Generic Op with the results of the new
- // one.
- rewriter.replaceOp(op, newResults);
+ // Replace the results of the old op with the new output buffers.
+ rewriter.replaceOp(op, newOutputBuffers);
return success();
}
};
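The converter above rewrites tensor-form generic ops into the buffer form checked by the
buffer-placement tests: it allocates an output buffer per static-shaped result, then emits the op
with ins()/outs() on memrefs. A minimal sketch of the expected output, with illustrative names:

  %buf = alloc() : memref<5xf32>
  linalg.generic {indexing_maps = [#map0, #map0],
                  iterator_types = ["parallel"]}
    ins(%in : memref<5xf32>)
    outs(%buf : memref<5xf32>) {
  ^bb0(%a: f32, %b: f32):
    %e = exp %a : f32
    linalg.yield %e : f32
  }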
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index 99b0c03..4efdaf6 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -1452,7 +1452,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
let arguments = (ins Variadic<AnyShaped>:$inputs,
Variadic<AnyMemRef>:$output_buffers,
Variadic<AnyRankedTensor>:$init_tensors);
- let results = (outs Variadic<AnyRankedTensor>:$output_tensors);
+ let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
let regions = (region AnyRegion:$region);
let builders = [ OpBuilder<