aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorDenis Khalikov <khalikov.denis@huawei.com>2019-12-05 13:10:10 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-12-05 13:10:44 -0800
commite67acfa4684e4bee38d3b4c90eff1e78adc62cef (patch)
tree430b259445efeae3ff83d7e28d331016b1d27b0c /mlir
parent33a64540ade2dc9e860ddd6d4c1adbd1088e94c2 (diff)
downloadllvm-e67acfa4684e4bee38d3b4c90eff1e78adc62cef.zip
llvm-e67acfa4684e4bee38d3b4c90eff1e78adc62cef.tar.gz
llvm-e67acfa4684e4bee38d3b4c90eff1e78adc62cef.tar.bz2
[spirv] Add CompositeInsertOp operation
A CompositeInsertOp operation make a copy of a composite object, while modifying one part of it. Closes tensorflow/mlir#292 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/292 from denis0x0D:sandbox/composite_insert 2200962b9057bda53cd2f2866b461e2797196380 PiperOrigin-RevId: 284036551
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td29
-rw-r--r--mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td118
-rw-r--r--mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td47
-rw-r--r--mlir/lib/Dialect/SPIRV/SPIRVOps.cpp107
-rw-r--r--mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir9
-rw-r--r--mlir/test/Dialect/SPIRV/composite-ops.mlir179
-rw-r--r--mlir/test/Dialect/SPIRV/ops.mlir144
7 files changed, 408 insertions, 225 deletions
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index c7acc37..62095a5 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -1076,6 +1076,7 @@ def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>;
def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>;
def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>;
def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>;
+def SPV_OC_OpCompositeInsert : I32EnumAttrCase<"OpCompositeInsert", 82>;
def SPV_OC_OpConvertFToU : I32EnumAttrCase<"OpConvertFToU", 109>;
def SPV_OC_OpConvertFToS : I32EnumAttrCase<"OpConvertFToS", 110>;
def SPV_OC_OpConvertSToF : I32EnumAttrCase<"OpConvertSToF", 111>;
@@ -1170,20 +1171,20 @@ def SPV_OpcodeAttr :
SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad,
SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
- SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpConvertFToU,
- SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
- SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast,
- SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub,
- SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv,
- SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod,
- SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr,
- SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual,
- SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
- SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
- SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
- SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
- SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan,
- SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
+ SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert,
+ SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF,
+ SPV_OC_OpConvertUToF, SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert,
+ SPV_OC_OpBitcast, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd,
+ SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv,
+ SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod,
+ SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
+ SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
+ SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
+ SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
+ SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
+ SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
+ SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
+ SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td
new file mode 100644
index 0000000..7165050
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td
@@ -0,0 +1,118 @@
+//===-- SPIRVCompositeOps.td - MLIR SPIR-V Composite Ops ---*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains composite ops for SPIR-V dialect. It corresponds
+// to "3.32.12. Composite Instructions" of the SPIR-V spec.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SPIRV_COMPOSITE_OPS
+#define SPIRV_COMPOSITE_OPS
+
+include "mlir/Dialect/SPIRV/SPIRVBase.td"
+
+def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> {
+ let summary = "Extract a part of a composite object.";
+
+ let description = [{
+ Result Type must be the type of object selected by the last provided
+ index. The instruction result is the extracted object.
+
+ Composite is the composite to extract from.
+
+ Indexes walk the type hierarchy, potentially down to component
+ granularity, to select the part to extract. All indexes must be in
+ bounds. All composite constituents use zero-based numbering, as
+ described by their OpType… instruction.
+
+ ### Custom assembly form
+
+ ``` {.ebnf}
+ composite-extract-op ::= ssa-id `=` `spv.CompositeExtract` ssa-use
+ `[` integer-literal (',' integer-literal)* `]`
+ `:` composite-type
+ ```
+
+ For example:
+
+ ```
+ %0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+ %1 = spv.Load "Function" %0 ["Volatile"] : !spv.array<4x!spv.array<4xf32>>
+ %2 = spv.CompositeExtract %1[1 : i32] : !spv.array<4x!spv.array<4xf32>>
+ ```
+
+ }];
+
+ let arguments = (ins
+ SPV_Composite:$composite,
+ I32ArrayAttr:$indices
+ );
+
+ let results = (outs
+ SPV_Type:$component
+ );
+
+ let hasFolder = 1;
+}
+
+// -----
+
+def SPV_CompositeInsertOp : SPV_Op<"CompositeInsert", [NoSideEffect]> {
+ let summary = [{
+ Make a copy of a composite object, while modifying one part of it.
+ }];
+
+ let description = [{
+ Result Type must be the same type as Composite.
+
+ Object is the object to use as the modified part.
+
+ Composite is the composite to copy all but the modified part from.
+
+ Indexes walk the type hierarchy of Composite to the desired depth,
+ potentially down to component granularity, to select the part to modify.
+ All indexes must be in bounds. All composite constituents use zero-based
+ numbering, as described by their OpType… instruction. The type of the
+ part selected to modify must match the type of Object.
+
+ ### Custom assembly form
+
+ ``` {.ebnf}
+ composite-insert-op ::= ssa-id `=` `spv.CompositeInsert` ssa-use, ssa-use
+ `[` integer-literal (',' integer-literal)* `]`
+ `:` object-type `into` composite-type
+ ```
+
+ For example:
+
+ ```
+ %0 = spv.CompositeInsert %object, %composite[1 : i32] : f32 into !spv.array<4xf32>
+ ```
+ }];
+
+ let arguments = (ins
+ SPV_Type:$object,
+ SPV_Composite:$composite,
+ I32ArrayAttr:$indices
+ );
+
+ let results = (outs
+ SPV_Composite:$result
+ );
+}
+
+#endif // SPIRV_COMPOSITE_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
index 000f1dd..bbb99da 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
@@ -35,6 +35,7 @@ include "mlir/Dialect/SPIRV/SPIRVArithmeticOps.td"
include "mlir/Dialect/SPIRV/SPIRVAtomicOps.td"
include "mlir/Dialect/SPIRV/SPIRVBitOps.td"
include "mlir/Dialect/SPIRV/SPIRVCastOps.td"
+include "mlir/Dialect/SPIRV/SPIRVCompositeOps.td"
include "mlir/Dialect/SPIRV/SPIRVControlFlowOps.td"
include "mlir/Dialect/SPIRV/SPIRVGLSLOps.td"
include "mlir/Dialect/SPIRV/SPIRVGroupOps.td"
@@ -108,52 +109,6 @@ def SPV_AccessChainOp : SPV_Op<"AccessChain", [NoSideEffect]> {
// -----
-def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> {
- let summary = "Extract a part of a composite object.";
-
- let description = [{
- Result Type must be the type of object selected by the last provided
- index. The instruction result is the extracted object.
-
- Composite is the composite to extract from.
-
- Indexes walk the type hierarchy, potentially down to component
- granularity, to select the part to extract. All indexes must be in
- bounds. All composite constituents use zero-based numbering, as
- described by their OpType… instruction.
-
- ### Custom assembly form
-
- ``` {.ebnf}
- composite-extract-op ::= ssa-id `=` `spv.CompositeExtract` ssa-use
- `[` integer-literal (',' integer-literal)* `]`
- `:` composite-type
- ```
-
- For example:
-
- ```
- %0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
- %1 = spv.Load "Function" %0 ["Volatile"] : !spv.array<4x!spv.array<4xf32>>
- %2 = spv.CompositeExtract %1[1 : i32] : !spv.array<4x!spv.array<4xf32>>
- ```
-
- }];
-
- let arguments = (ins
- SPV_Composite:$composite,
- I32ArrayAttr:$indices
- );
-
- let results = (outs
- SPV_Type:$component
- );
-
- let hasFolder = 1;
-}
-
-// -----
-
def SPV_ControlBarrierOp : SPV_Op<"ControlBarrier", []> {
let summary = [{
Wait for other invocations of this module to reach the current point of
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 99705f6d..34e2e88 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -377,6 +377,34 @@ static unsigned getBitWidth(Type type) {
llvm_unreachable("unhandled bit width computation for type");
}
+/// Walks the given type hierarchy with the given indices, potentially down
+/// to component granularity, to select an element type. Returns null type and
+/// emits errors with the given loc on failure.
+static Type getElementType(Type type, ArrayAttr indices, Location loc) {
+ if (!indices.size()) {
+ emitError(loc, "expected at least one index");
+ return nullptr;
+ }
+
+ int32_t index;
+ for (auto indexAttr : indices) {
+ index = indexAttr.dyn_cast<IntegerAttr>().getInt();
+ if (auto cType = type.dyn_cast<spirv::CompositeType>()) {
+ if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
+ emitError(loc, "index ") << index << " out of bounds for " << type;
+ return nullptr;
+ }
+ type = cType.getElementType(index);
+ } else {
+ emitError(loc, "cannot extract from non-composite type ")
+ << type << " with index " << index;
+ return nullptr;
+ }
+ }
+
+ return type;
+}
+
/// Returns true if the given `block` only contains one `spv._merge` op.
static inline bool isMergeBlock(Block &block) {
return !block.empty() && std::next(block.begin()) == block.end() &&
@@ -1094,28 +1122,11 @@ static void print(spirv::CompositeExtractOp compositeExtractOp,
}
static LogicalResult verify(spirv::CompositeExtractOp compExOp) {
- auto resultType = compExOp.composite()->getType();
auto indicesArrayAttr = compExOp.indices().dyn_cast<ArrayAttr>();
-
- if (!indicesArrayAttr.size()) {
- return compExOp.emitOpError(
- "expected at least one index for spv.CompositeExtractOp");
- }
-
- int32_t index;
- for (auto indexAttr : indicesArrayAttr) {
- index = indexAttr.dyn_cast<IntegerAttr>().getInt();
- if (auto cType = resultType.dyn_cast<spirv::CompositeType>()) {
- if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
- return compExOp.emitOpError("index ")
- << index << " out of bounds for " << resultType;
- }
- resultType = cType.getElementType(index);
- } else {
- return compExOp.emitError("cannot extract from non-composite type ")
- << resultType << " with index " << index;
- }
- }
+ auto resultType = getElementType(compExOp.composite()->getType(),
+ indicesArrayAttr, compExOp.getLoc());
+ if (!resultType)
+ return failure();
if (resultType != compExOp.getType()) {
return compExOp.emitOpError("invalid result type: expected ")
@@ -1136,6 +1147,60 @@ OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
}
//===----------------------------------------------------------------------===//
+// spv.CompositeInsert
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseCompositeInsertOp(OpAsmParser &parser,
+ OperationState &state) {
+ SmallVector<OpAsmParser::OperandType, 2> operands;
+ Type objectType, compositeType;
+ Attribute indicesAttr;
+ auto loc = parser.getCurrentLocation();
+
+ return failure(
+ parser.parseOperandList(operands, 2) ||
+ parser.parseAttribute(indicesAttr, kIndicesAttrName, state.attributes) ||
+ parser.parseColonType(objectType) ||
+ parser.parseKeywordType("into", compositeType) ||
+ parser.resolveOperands(operands, {objectType, compositeType}, loc,
+ state.operands) ||
+ parser.addTypesToList(compositeType, state.types));
+}
+
+static LogicalResult verify(spirv::CompositeInsertOp compositeInsertOp) {
+ auto indicesArrayAttr = compositeInsertOp.indices().dyn_cast<ArrayAttr>();
+ auto objectType =
+ getElementType(compositeInsertOp.composite()->getType(), indicesArrayAttr,
+ compositeInsertOp.getLoc());
+ if (!objectType)
+ return failure();
+
+ if (objectType != compositeInsertOp.object()->getType()) {
+ return compositeInsertOp.emitOpError("object operand type should be ")
+ << objectType << ", but found "
+ << compositeInsertOp.object()->getType();
+ }
+
+ if (compositeInsertOp.composite()->getType() != compositeInsertOp.getType()) {
+ return compositeInsertOp.emitOpError("result type should be the same as "
+ "the composite type, but found ")
+ << compositeInsertOp.composite()->getType() << " vs "
+ << compositeInsertOp.getType();
+ }
+
+ return success();
+}
+
+static void print(spirv::CompositeInsertOp compositeInsertOp,
+ OpAsmPrinter &printer) {
+ printer << spirv::CompositeInsertOp::getOperationName() << " "
+ << *compositeInsertOp.object() << ", "
+ << *compositeInsertOp.composite() << compositeInsertOp.indices()
+ << " : " << compositeInsertOp.object()->getType() << " into "
+ << compositeInsertOp.composite()->getType();
+}
+
+//===----------------------------------------------------------------------===//
// spv.constant
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir b/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir
new file mode 100644
index 0000000..a3f74ca
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-translate -split-input-file -test-spirv-roundtrip %s | FileCheck %s
+
+spv.module "Logical" "GLSL450" {
+ func @composite_insert(%arg0 : !spv.struct<f32, !spv.struct<!spv.array<4xf32>, f32>>, %arg1: !spv.array<4xf32>) -> !spv.struct<f32, !spv.struct<!spv.array<4xf32>, f32>> {
+ // CHECK: {{%.*}} = spv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32, 0 : i32] : !spv.array<4 x f32> into !spv.struct<f32, !spv.struct<!spv.array<4 x f32>, f32>>
+ %0 = spv.CompositeInsert %arg1, %arg0[1 : i32, 0 : i32] : !spv.array<4xf32> into !spv.struct<f32, !spv.struct<!spv.array<4xf32>, f32>>
+ spv.ReturnValue %0: !spv.struct<f32, !spv.struct<!spv.array<4xf32>, f32>>
+ }
+}
diff --git a/mlir/test/Dialect/SPIRV/composite-ops.mlir b/mlir/test/Dialect/SPIRV/composite-ops.mlir
new file mode 100644
index 0000000..353080c
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/composite-ops.mlir
@@ -0,0 +1,179 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.CompositeExtractOp
+//===----------------------------------------------------------------------===//
+
+func @composite_extract_array(%arg0: !spv.array<4xf32>) -> f32 {
+ // CHECK: {{%.*}} = spv.CompositeExtract {{%.*}}[1 : i32] : !spv.array<4 x f32>
+ %0 = spv.CompositeExtract %arg0[1 : i32] : !spv.array<4xf32>
+ return %0: f32
+}
+
+// -----
+
+func @composite_extract_struct(%arg0 : !spv.struct<f32, !spv.array<4xf32>>) -> f32 {
+ // CHECK: {{%.*}} = spv.CompositeExtract {{%.*}}[1 : i32, 2 : i32] : !spv.struct<f32, !spv.array<4 x f32>>
+ %0 = spv.CompositeExtract %arg0[1 : i32, 2 : i32] : !spv.struct<f32, !spv.array<4xf32>>
+ return %0 : f32
+}
+
+// -----
+
+func @composite_extract_vector(%arg0 : vector<4xf32>) -> f32 {
+ // CHECK: {{%.*}} = spv.CompositeExtract {{%.*}}[1 : i32] : vector<4xf32>
+ %0 = spv.CompositeExtract %arg0[1 : i32] : vector<4xf32>
+ return %0 : f32
+}
+
+// -----
+
+func @composite_extract_no_ssa_operand() -> () {
+ // expected-error @+1 {{expected SSA operand}}
+ %0 = spv.CompositeExtract [4 : i32, 1 : i32] : !spv.array<4x!spv.array<4xf32>>
+ return
+}
+
+// -----
+
+func @composite_extract_invalid_index_type_1() -> () {
+ %0 = spv.constant 10 : i32
+ %1 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+ %2 = spv.Load "Function" %1 ["Volatile"] : !spv.array<4x!spv.array<4xf32>>
+ // expected-error @+1 {{expected non-function type}}
+ %3 = spv.CompositeExtract %2[%0] : !spv.array<4x!spv.array<4xf32>>
+ return
+}
+
+// -----
+
+func @composite_extract_invalid_index_type_2(%arg0 : !spv.array<4x!spv.array<4xf32>>) -> () {
+ // expected-error @+1 {{op attribute 'indices' failed to satisfy constraint: 32-bit integer array attribute}}
+ %0 = spv.CompositeExtract %arg0[1] : !spv.array<4x!spv.array<4xf32>>
+ return
+}
+
+// -----
+
+func @composite_extract_invalid_index_identifier(%arg0 : !spv.array<4x!spv.array<4xf32>>) -> () {
+ // expected-error @+1 {{expected bare identifier}}
+ %0 = spv.CompositeExtract %arg0(1 : i32) : !spv.array<4x!spv.array<4xf32>>
+ return
+}
+
+// -----
+
+func @composite_extract_2D_array_out_of_bounds_access_1(%arg0: !spv.array<4x!spv.array<4xf32>>) -> () {
+ // expected-error @+1 {{index 4 out of bounds for '!spv.array<4 x !spv.array<4 x f32>>'}}
+ %0 = spv.CompositeExtract %arg0[4 : i32, 1 : i32] : !spv.array<4x!spv.array<4xf32>>
+ return
+}
+
+// -----
+
+func @composite_extract_2D_array_out_of_bounds_access_2(%arg0: !spv.array<4x!spv.array<4xf32>>
+) -> () {
+ // expected-error @+1 {{index 4 out of bounds for '!spv.array<4 x f32>'}}
+ %0 = spv.CompositeExtract %arg0[1 : i32, 4 : i32] : !spv.array<4x!spv.array<4xf32>>
+ return
+}
+
+// -----
+
+func @composite_extract_struct_element_out_of_bounds_access(%arg0 : !spv.struct<f32, !spv.array<4xf32>>) -> () {
+ // expected-error @+1 {{index 2 out of bounds for '!spv.struct<f32, !spv.array<4 x f32>>'}}
+ %0 = spv.CompositeExtract %arg0[2 : i32, 0 : i32] : !spv.struct<f32, !spv.array<4xf32>>
+ return
+}
+
+// -----
+
+func @composite_extract_vector_out_of_bounds_access(%arg0: vector<4xf32>) -> () {
+ // expected-error @+1 {{index 4 out of bounds for 'vector<4xf32>'}}
+ %0 = spv.CompositeExtract %arg0[4 : i32] : vector<4xf32>
+ return
+}
+
+// -----
+
+func @composite_extract_invalid_types_1(%arg0: !spv.array<4x!spv.array<4xf32>>) -> () {
+ // expected-error @+1 {{cannot extract from non-composite type 'f32' with index 3}}
+ %0 = spv.CompositeExtract %arg0[1 : i32, 2 : i32, 3 : i32] : !spv.array<4x!spv.array<4xf32>>
+ return
+}
+
+// -----
+
+func @composite_extract_invalid_types_2(%arg0: f32) -> () {
+ // expected-error @+1 {{cannot extract from non-composite type 'f32' with index 1}}
+ %0 = spv.CompositeExtract %arg0[1 : i32] : f32
+ return
+}
+
+// -----
+
+func @composite_extract_invalid_extracted_type(%arg0: !spv.array<4x!spv.array<4xf32>>) -> () {
+ // expected-error @+1 {{expected at least one index for spv.CompositeExtract}}
+ %0 = spv.CompositeExtract %arg0[] : !spv.array<4x!spv.array<4xf32>>
+ return
+}
+
+// -----
+
+func @composite_extract_result_type_mismatch(%arg0: !spv.array<4xf32>) -> i32 {
+ // expected-error @+1 {{invalid result type: expected 'f32' but provided 'i32'}}
+ %0 = "spv.CompositeExtract"(%arg0) {indices = [2: i32]} : (!spv.array<4xf32>) -> (i32)
+ return %0: i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.CompositeInsert
+//===----------------------------------------------------------------------===//
+
+func @composite_insert_array(%arg0: !spv.array<4xf32>, %arg1: f32) -> !spv.array<4xf32> {
+ // CHECK: {{%.*}} = spv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32] : f32 into !spv.array<4 x f32>
+ %0 = spv.CompositeInsert %arg1, %arg0[1 : i32] : f32 into !spv.array<4xf32>
+ return %0: !spv.array<4xf32>
+}
+
+// -----
+
+func @composite_insert_struct(%arg0: !spv.struct<!spv.array<4xf32>, f32>, %arg1: !spv.array<4xf32>) -> !spv.struct<!spv.array<4xf32>, f32> {
+ // CHECK: {{%.*}} = spv.CompositeInsert {{%.*}}, {{%.*}}[0 : i32] : !spv.array<4 x f32> into !spv.struct<!spv.array<4 x f32>, f32>
+ %0 = spv.CompositeInsert %arg1, %arg0[0 : i32] : !spv.array<4xf32> into !spv.struct<!spv.array<4xf32>, f32>
+ return %0: !spv.struct<!spv.array<4xf32>, f32>
+}
+
+// -----
+
+func @composite_insert_no_indices(%arg0: !spv.array<4xf32>, %arg1: f32) -> !spv.array<4xf32> {
+ // expected-error @+1 {{expected at least one index}}
+ %0 = spv.CompositeInsert %arg1, %arg0[] : f32 into !spv.array<4xf32>
+ return %0: !spv.array<4xf32>
+}
+
+// -----
+
+func @composite_insert_out_of_bounds(%arg0: !spv.array<4xf32>, %arg1: f32) -> !spv.array<4xf32> {
+ // expected-error @+1 {{index 4 out of bounds}}
+ %0 = spv.CompositeInsert %arg1, %arg0[4 : i32] : f32 into !spv.array<4xf32>
+ return %0: !spv.array<4xf32>
+}
+
+// -----
+
+func @composite_insert_invalid_object_type(%arg0: !spv.array<4xf32>, %arg1: f64) -> !spv.array<4xf32> {
+ // expected-error @+1 {{object operand type should be 'f32', but found 'f64'}}
+ %0 = spv.CompositeInsert %arg1, %arg0[3 : i32] : f64 into !spv.array<4xf32>
+ return %0: !spv.array<4xf32>
+}
+
+// -----
+
+func @composite_insert_invalid_result_type(%arg0: !spv.array<4xf32>, %arg1 : f32) -> !spv.array<4xf64> {
+ // expected-error @+1 {{result type should be the same as the composite type, but found '!spv.array<4 x f32>' vs '!spv.array<4 x f64>'}}
+ %0 = "spv.CompositeInsert"(%arg1, %arg0) {indices = [0: i32]} : (f32, !spv.array<4xf32>) -> !spv.array<4xf64>
+ return %0: !spv.array<4xf64>
+}
diff --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir
index 784af94..9fb0b41 100644
--- a/mlir/test/Dialect/SPIRV/ops.mlir
+++ b/mlir/test/Dialect/SPIRV/ops.mlir
@@ -277,150 +277,6 @@ func @bitreverse(%arg: i32) -> i32 {
// -----
//===----------------------------------------------------------------------===//
-// spv.CompositeExtractOp
-//===----------------------------------------------------------------------===//
-
-func @composite_extract_f32_from_1D_array(%arg0: !spv.array<4xf32>) -> f32 {
- // CHECK: %0 = spv.CompositeExtract %arg0[1 : i32] : !spv.array<4 x f32>
- %0 = spv.CompositeExtract %arg0[1 : i32] : !spv.array<4xf32>
- return %0: f32
-}
-
-// -----
-
-func @composite_extract_f32_from_2D_array(%arg0: !spv.array<4x!spv.array<4xf32>>) -> f32 {
- // CHECK: %0 = spv.CompositeExtract %arg0[1 : i32, 2 : i32] : !spv.array<4 x !spv.array<4 x f32>>
- %0 = spv.CompositeExtract %arg0[1 : i32, 2 : i32] : !spv.array<4x!spv.array<4xf32>>
- return %0: f32
-}
-
-// -----
-
-func @composite_extract_1D_array_from_2D_array(%arg0: !spv.array<4x!spv.array<4xf32>>) -> !spv.array<4xf32> {
- // CHECK: %0 = spv.CompositeExtract %arg0[1 : i32] : !spv.array<4 x !spv.array<4 x f32>>
- %0 = spv.CompositeExtract %arg0[1 : i32] : !spv.array<4x!spv.array<4xf32>>
- return %0 : !spv.array<4xf32>
-}
-
-// -----
-
-func @composite_extract_struct(%arg0 : !spv.struct<f32, !spv.array<4xf32>>) -> f32 {
- // CHECK: %0 = spv.CompositeExtract %arg0[1 : i32, 2 : i32] : !spv.struct<f32, !spv.array<4 x f32>>
- %0 = spv.CompositeExtract %arg0[1 : i32, 2 : i32] : !spv.struct<f32, !spv.array<4xf32>>
- return %0 : f32
-}
-
-// -----
-
-func @composite_extract_vector(%arg0 : vector<4xf32>) -> f32 {
- // CHECK: %0 = spv.CompositeExtract %arg0[1 : i32] : vector<4xf32>
- %0 = spv.CompositeExtract %arg0[1 : i32] : vector<4xf32>
- return %0 : f32
-}
-
-// -----
-
-func @composite_extract_no_ssa_operand() -> () {
- // expected-error @+1 {{expected SSA operand}}
- %0 = spv.CompositeExtract [4 : i32, 1 : i32] : !spv.array<4x!spv.array<4xf32>>
- return
-}
-
-// -----
-
-func @composite_extract_invalid_index_type_1() -> () {
- %0 = spv.constant 10 : i32
- %1 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
- %2 = spv.Load "Function" %1 ["Volatile"] : !spv.array<4x!spv.array<4xf32>>
- // expected-error @+1 {{expected non-function type}}
- %3 = spv.CompositeExtract %2[%0] : !spv.array<4x!spv.array<4xf32>>
- return
-}
-
-// -----
-
-func @composite_extract_invalid_index_type_2(%arg0 : !spv.array<4x!spv.array<4xf32>>) -> () {
- // expected-error @+1 {{op attribute 'indices' failed to satisfy constraint: 32-bit integer array attribute}}
- %0 = spv.CompositeExtract %arg0[1] : !spv.array<4x!spv.array<4xf32>>
- return
-}
-
-// -----
-
-func @composite_extract_invalid_index_identifier(%arg0 : !spv.array<4x!spv.array<4xf32>>) -> () {
- // expected-error @+1 {{expected bare identifier}}
- %0 = spv.CompositeExtract %arg0(1 : i32) : !spv.array<4x!spv.array<4xf32>>
- return
-}
-
-// -----
-
-func @composite_extract_2D_array_out_of_bounds_access_1(%arg0: !spv.array<4x!spv.array<4xf32>>) -> () {
- // expected-error @+1 {{index 4 out of bounds for '!spv.array<4 x !spv.array<4 x f32>>'}}
- %0 = spv.CompositeExtract %arg0[4 : i32, 1 : i32] : !spv.array<4x!spv.array<4xf32>>
- return
-}
-
-// -----
-
-func @composite_extract_2D_array_out_of_bounds_access_2(%arg0: !spv.array<4x!spv.array<4xf32>>
-) -> () {
- // expected-error @+1 {{index 4 out of bounds for '!spv.array<4 x f32>'}}
- %0 = spv.CompositeExtract %arg0[1 : i32, 4 : i32] : !spv.array<4x!spv.array<4xf32>>
- return
-}
-
-// -----
-
-func @composite_extract_struct_element_out_of_bounds_access(%arg0 : !spv.struct<f32, !spv.array<4xf32>>) -> () {
- // expected-error @+1 {{index 2 out of bounds for '!spv.struct<f32, !spv.array<4 x f32>>'}}
- %0 = spv.CompositeExtract %arg0[2 : i32, 0 : i32] : !spv.struct<f32, !spv.array<4xf32>>
- return
-}
-
-// -----
-
-func @composite_extract_vector_out_of_bounds_access(%arg0: vector<4xf32>) -> () {
- // expected-error @+1 {{index 4 out of bounds for 'vector<4xf32>'}}
- %0 = spv.CompositeExtract %arg0[4 : i32] : vector<4xf32>
- return
-}
-
-// -----
-
-func @composite_extract_invalid_types_1(%arg0: !spv.array<4x!spv.array<4xf32>>) -> () {
- // expected-error @+1 {{cannot extract from non-composite type 'f32' with index 3}}
- %0 = spv.CompositeExtract %arg0[1 : i32, 2 : i32, 3 : i32] : !spv.array<4x!spv.array<4xf32>>
- return
-}
-
-// -----
-
-func @composite_extract_invalid_types_2(%arg0: f32) -> () {
- // expected-error @+1 {{cannot extract from non-composite type 'f32' with index 1}}
- %0 = spv.CompositeExtract %arg0[1 : i32] : f32
- return
-}
-
-// -----
-
-func @composite_extract_invalid_extracted_type(%arg0: !spv.array<4x!spv.array<4xf32>>) -> () {
- // expected-error @+1 {{expected at least one index for spv.CompositeExtract}}
- %0 = spv.CompositeExtract %arg0[] : !spv.array<4x!spv.array<4xf32>>
- return
-}
-
-// -----
-
-func @composite_extract_result_type_mismatch(%arg0: !spv.array<4xf32>) -> i32 {
- // expected-error @+1 {{invalid result type: expected 'f32' but provided 'i32'}}
- %0 = "spv.CompositeExtract"(%arg0) {indices = [2: i32]} : (!spv.array<4xf32>) -> (i32)
- return %0: i32
-}
-
-// -----
-
-//===----------------------------------------------------------------------===//
// spv.ControlBarrier
//===----------------------------------------------------------------------===//