aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/Dialect/Vector/IR/VectorOps.td2
-rw-r--r--mlir/include/mlir/IR/CommonTypeConstraints.td15
-rw-r--r--mlir/test/Dialect/Vector/invalid.mlir21
3 files changed, 36 insertions, 2 deletions
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 7fc56b1..d751894 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1972,7 +1972,7 @@ def Vector_GatherOp :
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
]>,
- Arguments<(ins Arg<AnyShaped, "", [MemRead]>:$base,
+ Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
VectorOfNonZeroRankOf<[I1]>:$mask,
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index e6f17de..45ec184 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -63,6 +63,9 @@ def IsTensorTypePred : CPred<"::llvm::isa<::mlir::TensorType>($_self)">;
// Whether a type is a MemRefType.
def IsMemRefTypePred : CPred<"::llvm::isa<::mlir::MemRefType>($_self)">;
+// Whether a type is a TensorType or a MemRefType.
+def IsTensorOrMemRefTypePred : Or<[IsTensorTypePred, IsMemRefTypePred]>;
+
// Whether a type is an UnrankedMemRefType
def IsUnrankedMemRefTypePred
: CPred<"::llvm::isa<::mlir::UnrankedMemRefType>($_self)">;
@@ -426,7 +429,9 @@ class ValueSemanticsContainerOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, HasValueSemanticsPred,
"container with value semantics">;
+//===----------------------------------------------------------------------===//
// Vector types.
+//===----------------------------------------------------------------------===//
class VectorOfNonZeroRankOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsVectorOfNonZeroRankTypePred, "vector",
@@ -755,7 +760,7 @@ class StaticShapeTensorOf<list<Type> allowedTypes>
def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>;
//===----------------------------------------------------------------------===//
-// Memref type.
+// Memref types.
//===----------------------------------------------------------------------===//
// Any unranked memref whose element type is from the given `allowedTypes` list.
@@ -879,6 +884,14 @@ class NestedTupleOf<list<Type> allowedTypes> :
"nested tuple">;
//===----------------------------------------------------------------------===//
+// Mixed types
+//===----------------------------------------------------------------------===//
+
+class TensorOrMemRef<list<Type> allowedTypes> :
+ ShapedContainerType<allowedTypes, IsTensorOrMemRefTypePred, "Tensor or MemRef",
+ "::mlir::ShapedType">;
+
+//===----------------------------------------------------------------------===//
// Common type constraints
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index dbf829e..3a83209 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1409,6 +1409,16 @@ func.func @maskedstore_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1
// -----
+func.func @gather_from_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+ %c0 = arith.constant 0 : index
+ // expected-error@+1 {{'vector.gather' op operand #0 must be Tensor or MemRef of any type values, but got 'vector<16xf32>'}}
+ %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
+ : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
func.func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
@@ -1469,6 +1479,17 @@ func.func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector
// -----
+func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+ %c0 = arith.constant 0 : index
+ // expected-error@+2 {{custom op 'vector.scatter' invalid kind of type specified}}
+ vector.scatter %base[%c0][%indices], %mask, %pass_thru
+ : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
+
func.func @scatter_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %value: vector<16xf32>) {
%c0 = arith.constant 0 : index