diff options
author | Denis Khalikov <khalikov.denis@huawei.com> | 2019-12-05 13:10:10 -0800 |
---|---|---|
committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-05 13:10:44 -0800 |
commit | e67acfa4684e4bee38d3b4c90eff1e78adc62cef (patch) | |
tree | 430b259445efeae3ff83d7e28d331016b1d27b0c /mlir | |
parent | 33a64540ade2dc9e860ddd6d4c1adbd1088e94c2 (diff) | |
download | llvm-e67acfa4684e4bee38d3b4c90eff1e78adc62cef.zip llvm-e67acfa4684e4bee38d3b4c90eff1e78adc62cef.tar.gz llvm-e67acfa4684e4bee38d3b4c90eff1e78adc62cef.tar.bz2 |
[spirv] Add CompositeInsertOp operation
A CompositeInsertOp operation make a copy of a composite object,
while modifying one part of it.
Closes tensorflow/mlir#292
COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/292 from denis0x0D:sandbox/composite_insert 2200962b9057bda53cd2f2866b461e2797196380
PiperOrigin-RevId: 284036551
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td | 29 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td | 118 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td | 47 | ||||
-rw-r--r-- | mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 107 | ||||
-rw-r--r-- | mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir | 9 | ||||
-rw-r--r-- | mlir/test/Dialect/SPIRV/composite-ops.mlir | 179 | ||||
-rw-r--r-- | mlir/test/Dialect/SPIRV/ops.mlir | 144 |
7 files changed, 408 insertions, 225 deletions
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index c7acc37..62095a5 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -1076,6 +1076,7 @@ def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>; def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>; def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>; def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>; +def SPV_OC_OpCompositeInsert : I32EnumAttrCase<"OpCompositeInsert", 82>; def SPV_OC_OpConvertFToU : I32EnumAttrCase<"OpConvertFToU", 109>; def SPV_OC_OpConvertFToS : I32EnumAttrCase<"OpConvertFToS", 110>; def SPV_OC_OpConvertSToF : I32EnumAttrCase<"OpConvertSToF", 111>; @@ -1170,20 +1171,20 @@ def SPV_OpcodeAttr : SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate, - SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpConvertFToU, - SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF, - SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast, - SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, - SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, - SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, - SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, - SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, - SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, - SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, - SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, - SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, - SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, - SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, + SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, + SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, + SPV_OC_OpConvertUToF, SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, + SPV_OC_OpBitcast, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, + SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, + SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, + SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, + SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, + SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, + SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, + SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, + SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, + SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, + SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic, diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td new file mode 100644 index 0000000..7165050 --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td @@ -0,0 +1,118 @@ +//===-- SPIRVCompositeOps.td - MLIR SPIR-V Composite Ops ---*- tablegen -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file contains composite ops for SPIR-V dialect. It corresponds +// to "3.32.12. Composite Instructions" of the SPIR-V spec. +// +//===----------------------------------------------------------------------===// + +#ifndef SPIRV_COMPOSITE_OPS +#define SPIRV_COMPOSITE_OPS + +include "mlir/Dialect/SPIRV/SPIRVBase.td" + +def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> { + let summary = "Extract a part of a composite object."; + + let description = [{ + Result Type must be the type of object selected by the last provided + index. The instruction result is the extracted object. + + Composite is the composite to extract from. + + Indexes walk the type hierarchy, potentially down to component + granularity, to select the part to extract. All indexes must be in + bounds. All composite constituents use zero-based numbering, as + described by their OpType… instruction. + + ### Custom assembly form + + ``` {.ebnf} + composite-extract-op ::= ssa-id `=` `spv.CompositeExtract` ssa-use + `[` integer-literal (',' integer-literal)* `]` + `:` composite-type + ``` + + For example: + + ``` + %0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function> + %1 = spv.Load "Function" %0 ["Volatile"] : !spv.array<4x!spv.array<4xf32>> + %2 = spv.CompositeExtract %1[1 : i32] : !spv.array<4x!spv.array<4xf32>> + ``` + + }]; + + let arguments = (ins + SPV_Composite:$composite, + I32ArrayAttr:$indices + ); + + let results = (outs + SPV_Type:$component + ); + + let hasFolder = 1; +} + +// ----- + +def SPV_CompositeInsertOp : SPV_Op<"CompositeInsert", [NoSideEffect]> { + let summary = [{ + Make a copy of a composite object, while modifying one part of it. + }]; + + let description = [{ + Result Type must be the same type as Composite. + + Object is the object to use as the modified part. + + Composite is the composite to copy all but the modified part from. + + Indexes walk the type hierarchy of Composite to the desired depth, + potentially down to component granularity, to select the part to modify. + All indexes must be in bounds. All composite constituents use zero-based + numbering, as described by their OpType… instruction. The type of the + part selected to modify must match the type of Object. + + ### Custom assembly form + + ``` {.ebnf} + composite-insert-op ::= ssa-id `=` `spv.CompositeInsert` ssa-use, ssa-use + `[` integer-literal (',' integer-literal)* `]` + `:` object-type `into` composite-type + ``` + + For example: + + ``` + %0 = spv.CompositeInsert %object, %composite[1 : i32] : f32 into !spv.array<4xf32> + ``` + }]; + + let arguments = (ins + SPV_Type:$object, + SPV_Composite:$composite, + I32ArrayAttr:$indices + ); + + let results = (outs + SPV_Composite:$result + ); +} + +#endif // SPIRV_COMPOSITE_OPS diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td index 000f1dd..bbb99da 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td @@ -35,6 +35,7 @@ include "mlir/Dialect/SPIRV/SPIRVArithmeticOps.td" include "mlir/Dialect/SPIRV/SPIRVAtomicOps.td" include "mlir/Dialect/SPIRV/SPIRVBitOps.td" include "mlir/Dialect/SPIRV/SPIRVCastOps.td" +include "mlir/Dialect/SPIRV/SPIRVCompositeOps.td" include "mlir/Dialect/SPIRV/SPIRVControlFlowOps.td" include "mlir/Dialect/SPIRV/SPIRVGLSLOps.td" include "mlir/Dialect/SPIRV/SPIRVGroupOps.td" @@ -108,52 +109,6 @@ def SPV_AccessChainOp : SPV_Op<"AccessChain", [NoSideEffect]> { // ----- -def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> { - let summary = "Extract a part of a composite object."; - - let description = [{ - Result Type must be the type of object selected by the last provided - index. The instruction result is the extracted object. - - Composite is the composite to extract from. - - Indexes walk the type hierarchy, potentially down to component - granularity, to select the part to extract. All indexes must be in - bounds. All composite constituents use zero-based numbering, as - described by their OpType… instruction. - - ### Custom assembly form - - ``` {.ebnf} - composite-extract-op ::= ssa-id `=` `spv.CompositeExtract` ssa-use - `[` integer-literal (',' integer-literal)* `]` - `:` composite-type - ``` - - For example: - - ``` - %0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function> - %1 = spv.Load "Function" %0 ["Volatile"] : !spv.array<4x!spv.array<4xf32>> - %2 = spv.CompositeExtract %1[1 : i32] : !spv.array<4x!spv.array<4xf32>> - ``` - - }]; - - let arguments = (ins - SPV_Composite:$composite, - I32ArrayAttr:$indices - ); - - let results = (outs - SPV_Type:$component - ); - - let hasFolder = 1; -} - -// ----- - def SPV_ControlBarrierOp : SPV_Op<"ControlBarrier", []> { let summary = [{ Wait for other invocations of this module to reach the current point of diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 99705f6d..34e2e88 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -377,6 +377,34 @@ static unsigned getBitWidth(Type type) { llvm_unreachable("unhandled bit width computation for type"); } +/// Walks the given type hierarchy with the given indices, potentially down +/// to component granularity, to select an element type. Returns null type and +/// emits errors with the given loc on failure. +static Type getElementType(Type type, ArrayAttr indices, Location loc) { + if (!indices.size()) { + emitError(loc, "expected at least one index"); + return nullptr; + } + + int32_t index; + for (auto indexAttr : indices) { + index = indexAttr.dyn_cast<IntegerAttr>().getInt(); + if (auto cType = type.dyn_cast<spirv::CompositeType>()) { + if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) { + emitError(loc, "index ") << index << " out of bounds for " << type; + return nullptr; + } + type = cType.getElementType(index); + } else { + emitError(loc, "cannot extract from non-composite type ") + << type << " with index " << index; + return nullptr; + } + } + + return type; +} + /// Returns true if the given `block` only contains one `spv._merge` op. static inline bool isMergeBlock(Block &block) { return !block.empty() && std::next(block.begin()) == block.end() && @@ -1094,28 +1122,11 @@ static void print(spirv::CompositeExtractOp compositeExtractOp, } static LogicalResult verify(spirv::CompositeExtractOp compExOp) { - auto resultType = compExOp.composite()->getType(); auto indicesArrayAttr = compExOp.indices().dyn_cast<ArrayAttr>(); - - if (!indicesArrayAttr.size()) { - return compExOp.emitOpError( - "expected at least one index for spv.CompositeExtractOp"); - } - - int32_t index; - for (auto indexAttr : indicesArrayAttr) { - index = indexAttr.dyn_cast<IntegerAttr>().getInt(); - if (auto cType = resultType.dyn_cast<spirv::CompositeType>()) { - if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) { - return compExOp.emitOpError("index ") - << index << " out of bounds for " << resultType; - } - resultType = cType.getElementType(index); - } else { - return compExOp.emitError("cannot extract from non-composite type ") - << resultType << " with index " << index; - } - } + auto resultType = getElementType(compExOp.composite()->getType(), + indicesArrayAttr, compExOp.getLoc()); + if (!resultType) + return failure(); if (resultType != compExOp.getType()) { return compExOp.emitOpError("invalid result type: expected ") @@ -1136,6 +1147,60 @@ OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) { } //===----------------------------------------------------------------------===// +// spv.CompositeInsert +//===----------------------------------------------------------------------===// + +static ParseResult parseCompositeInsertOp(OpAsmParser &parser, + OperationState &state) { + SmallVector<OpAsmParser::OperandType, 2> operands; + Type objectType, compositeType; + Attribute indicesAttr; + auto loc = parser.getCurrentLocation(); + + return failure( + parser.parseOperandList(operands, 2) || + parser.parseAttribute(indicesAttr, kIndicesAttrName, state.attributes) || + parser.parseColonType(objectType) || + parser.parseKeywordType("into", compositeType) || + parser.resolveOperands(operands, {objectType, compositeType}, loc, + state.operands) || + parser.addTypesToList(compositeType, state.types)); +} + +static LogicalResult verify(spirv::CompositeInsertOp compositeInsertOp) { + auto indicesArrayAttr = compositeInsertOp.indices().dyn_cast<ArrayAttr>(); + auto objectType = + getElementType(compositeInsertOp.composite()->getType(), indicesArrayAttr, + compositeInsertOp.getLoc()); + if (!objectType) + return failure(); + + if (objectType != compositeInsertOp.object()->getType()) { + return compositeInsertOp.emitOpError("object operand type should be ") + << objectType << ", but found " + << compositeInsertOp.object()->getType(); + } + + if (compositeInsertOp.composite()->getType() != compositeInsertOp.getType()) { + return compositeInsertOp.emitOpError("result type should be the same as " + "the composite type, but found ") + << compositeInsertOp.composite()->getType() << " vs " + << compositeInsertOp.getType(); + } + + return success(); +} + +static void print(spirv::CompositeInsertOp compositeInsertOp, + OpAsmPrinter &printer) { + printer << spirv::CompositeInsertOp::getOperationName() << " " + << *compositeInsertOp.object() << ", " + << *compositeInsertOp.composite() << compositeInsertOp.indices() + << " : " << compositeInsertOp.object()->getType() << " into " + << compositeInsertOp.composite()->getType(); +} + +//===----------------------------------------------------------------------===// // spv.constant //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir b/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir new file mode 100644 index 0000000..a3f74ca --- /dev/null +++ b/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir @@ -0,0 +1,9 @@ +// RUN: mlir-translate -split-input-file -test-spirv-roundtrip %s | FileCheck %s + +spv.module "Logical" "GLSL450" { + func @composite_insert(%arg0 : !spv.struct<f32, !spv.struct<!spv.array<4xf32>, f32>>, %arg1: !spv.array<4xf32>) -> !spv.struct<f32, !spv.struct<!spv.array<4xf32>, f32>> { + // CHECK: {{%.*}} = spv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32, 0 : i32] : !spv.array<4 x f32> into !spv.struct<f32, !spv.struct<!spv.array<4 x f32>, f32>> + %0 = spv.CompositeInsert %arg1, %arg0[1 : i32, 0 : i32] : !spv.array<4xf32> into !spv.struct<f32, !spv.struct<!spv.array<4xf32>, f32>> + spv.ReturnValue %0: !spv.struct<f32, !spv.struct<!spv.array<4xf32>, f32>> + } +} diff --git a/mlir/test/Dialect/SPIRV/composite-ops.mlir b/mlir/test/Dialect/SPIRV/composite-ops.mlir new file mode 100644 index 0000000..353080c --- /dev/null +++ b/mlir/test/Dialect/SPIRV/composite-ops.mlir @@ -0,0 +1,179 @@ +// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s + +//===----------------------------------------------------------------------===// +// spv.CompositeExtractOp +//===----------------------------------------------------------------------===// + +func @composite_extract_array(%arg0: !spv.array<4xf32>) -> f32 { + // CHECK: {{%.*}} = spv.CompositeExtract {{%.*}}[1 : i32] : !spv.array<4 x f32> + %0 = spv.CompositeExtract %arg0[1 : i32] : !spv.array<4xf32> + return %0: f32 +} + +// ----- + +func @composite_extract_struct(%arg0 : !spv.struct<f32, !spv.array<4xf32>>) -> f32 { + // CHECK: {{%.*}} = spv.CompositeExtract {{%.*}}[1 : i32, 2 : i32] : !spv.struct<f32, !spv.array<4 x f32>> + %0 = spv.CompositeExtract %arg0[1 : i32, 2 : i32] : !spv.struct<f32, !spv.array<4xf32>> + return %0 : f32 +} + +// ----- + +func @composite_extract_vector(%arg0 : vector<4xf32>) -> f32 { + // CHECK: {{%.*}} = spv.CompositeExtract {{%.*}}[1 : i32] : vector<4xf32> + %0 = spv.CompositeExtract %arg0[1 : i32] : vector<4xf32> + return %0 : f32 +} + +// ----- + +func @composite_extract_no_ssa_operand() -> () { + // expected-error @+1 {{expected SSA operand}} + %0 = spv.CompositeExtract [4 : i32, 1 : i32] : !spv.array<4x!spv.array<4xf32>> + return +} + +// ----- + +func @composite_extract_invalid_index_type_1() -> () { + %0 = spv.constant 10 : i32 + %1 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function> + %2 = spv.Load "Function" %1 ["Volatile"] : !spv.array<4x!spv.array<4xf32>> + // expected-error @+1 {{expected non-function type}} + %3 = spv.CompositeExtract %2[%0] : !spv.array<4x!spv.array<4xf32>> + return +} + +// ----- + +func @composite_extract_invalid_index_type_2(%arg0 : !spv.array<4x!spv.array<4xf32>>) -> () { + // expected-error @+1 {{op attribute 'indices' failed to satisfy constraint: 32-bit integer array attribute}} + %0 = spv.CompositeExtract %arg0[1] : !spv.array<4x!spv.array<4xf32>> + return +} + +// ----- + +func @composite_extract_invalid_index_identifier(%arg0 : !spv.array<4x!spv.array<4xf32>>) -> () { + // expected-error @+1 {{expected bare identifier}} + %0 = spv.CompositeExtract %arg0(1 : i32) : !spv.array<4x!spv.array<4xf32>> + return +} + +// ----- + +func @composite_extract_2D_array_out_of_bounds_access_1(%arg0: !spv.array<4x!spv.array<4xf32>>) -> () { + // expected-error @+1 {{index 4 out of bounds for '!spv.array<4 x !spv.array<4 x f32>>'}} + %0 = spv.CompositeExtract %arg0[4 : i32, 1 : i32] : !spv.array<4x!spv.array<4xf32>> + return +} + +// ----- + +func @composite_extract_2D_array_out_of_bounds_access_2(%arg0: !spv.array<4x!spv.array<4xf32>> +) -> () { + // expected-error @+1 {{index 4 out of bounds for '!spv.array<4 x f32>'}} + %0 = spv.CompositeExtract %arg0[1 : i32, 4 : i32] : !spv.array<4x!spv.array<4xf32>> + return +} + +// ----- + +func @composite_extract_struct_element_out_of_bounds_access(%arg0 : !spv.struct<f32, !spv.array<4xf32>>) -> () { + // expected-error @+1 {{index 2 out of bounds for '!spv.struct<f32, !spv.array<4 x f32>>'}} + %0 = spv.CompositeExtract %arg0[2 : i32, 0 : i32] : !spv.struct<f32, !spv.array<4xf32>> + return +} + +// ----- + +func @composite_extract_vector_out_of_bounds_access(%arg0: vector<4xf32>) -> () { + // expected-error @+1 {{index 4 out of bounds for 'vector<4xf32>'}} + %0 = spv.CompositeExtract %arg0[4 : i32] : vector<4xf32> + return +} + +// ----- + +func @composite_extract_invalid_types_1(%arg0: !spv.array<4x!spv.array<4xf32>>) -> () { + // expected-error @+1 {{cannot extract from non-composite type 'f32' with index 3}} + %0 = spv.CompositeExtract %arg0[1 : i32, 2 : i32, 3 : i32] : !spv.array<4x!spv.array<4xf32>> + return +} + +// ----- + +func @composite_extract_invalid_types_2(%arg0: f32) -> () { + // expected-error @+1 {{cannot extract from non-composite type 'f32' with index 1}} + %0 = spv.CompositeExtract %arg0[1 : i32] : f32 + return +} + +// ----- + +func @composite_extract_invalid_extracted_type(%arg0: !spv.array<4x!spv.array<4xf32>>) -> () { + // expected-error @+1 {{expected at least one index for spv.CompositeExtract}} + %0 = spv.CompositeExtract %arg0[] : !spv.array<4x!spv.array<4xf32>> + return +} + +// ----- + +func @composite_extract_result_type_mismatch(%arg0: !spv.array<4xf32>) -> i32 { + // expected-error @+1 {{invalid result type: expected 'f32' but provided 'i32'}} + %0 = "spv.CompositeExtract"(%arg0) {indices = [2: i32]} : (!spv.array<4xf32>) -> (i32) + return %0: i32 +} + +// ----- + +//===----------------------------------------------------------------------===// +// spv.CompositeInsert +//===----------------------------------------------------------------------===// + +func @composite_insert_array(%arg0: !spv.array<4xf32>, %arg1: f32) -> !spv.array<4xf32> { + // CHECK: {{%.*}} = spv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32] : f32 into !spv.array<4 x f32> + %0 = spv.CompositeInsert %arg1, %arg0[1 : i32] : f32 into !spv.array<4xf32> + return %0: !spv.array<4xf32> +} + +// ----- + +func @composite_insert_struct(%arg0: !spv.struct<!spv.array<4xf32>, f32>, %arg1: !spv.array<4xf32>) -> !spv.struct<!spv.array<4xf32>, f32> { + // CHECK: {{%.*}} = spv.CompositeInsert {{%.*}}, {{%.*}}[0 : i32] : !spv.array<4 x f32> into !spv.struct<!spv.array<4 x f32>, f32> + %0 = spv.CompositeInsert %arg1, %arg0[0 : i32] : !spv.array<4xf32> into !spv.struct<!spv.array<4xf32>, f32> + return %0: !spv.struct<!spv.array<4xf32>, f32> +} + +// ----- + +func @composite_insert_no_indices(%arg0: !spv.array<4xf32>, %arg1: f32) -> !spv.array<4xf32> { + // expected-error @+1 {{expected at least one index}} + %0 = spv.CompositeInsert %arg1, %arg0[] : f32 into !spv.array<4xf32> + return %0: !spv.array<4xf32> +} + +// ----- + +func @composite_insert_out_of_bounds(%arg0: !spv.array<4xf32>, %arg1: f32) -> !spv.array<4xf32> { + // expected-error @+1 {{index 4 out of bounds}} + %0 = spv.CompositeInsert %arg1, %arg0[4 : i32] : f32 into !spv.array<4xf32> + return %0: !spv.array<4xf32> +} + +// ----- + +func @composite_insert_invalid_object_type(%arg0: !spv.array<4xf32>, %arg1: f64) -> !spv.array<4xf32> { + // expected-error @+1 {{object operand type should be 'f32', but found 'f64'}} + %0 = spv.CompositeInsert %arg1, %arg0[3 : i32] : f64 into !spv.array<4xf32> + return %0: !spv.array<4xf32> +} + +// ----- + +func @composite_insert_invalid_result_type(%arg0: !spv.array<4xf32>, %arg1 : f32) -> !spv.array<4xf64> { + // expected-error @+1 {{result type should be the same as the composite type, but found '!spv.array<4 x f32>' vs '!spv.array<4 x f64>'}} + %0 = "spv.CompositeInsert"(%arg1, %arg0) {indices = [0: i32]} : (f32, !spv.array<4xf32>) -> !spv.array<4xf64> + return %0: !spv.array<4xf64> +} diff --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir index 784af94..9fb0b41 100644 --- a/mlir/test/Dialect/SPIRV/ops.mlir +++ b/mlir/test/Dialect/SPIRV/ops.mlir @@ -277,150 +277,6 @@ func @bitreverse(%arg: i32) -> i32 { // ----- //===----------------------------------------------------------------------===// -// spv.CompositeExtractOp -//===----------------------------------------------------------------------===// - -func @composite_extract_f32_from_1D_array(%arg0: !spv.array<4xf32>) -> f32 { - // CHECK: %0 = spv.CompositeExtract %arg0[1 : i32] : !spv.array<4 x f32> - %0 = spv.CompositeExtract %arg0[1 : i32] : !spv.array<4xf32> - return %0: f32 -} - -// ----- - -func @composite_extract_f32_from_2D_array(%arg0: !spv.array<4x!spv.array<4xf32>>) -> f32 { - // CHECK: %0 = spv.CompositeExtract %arg0[1 : i32, 2 : i32] : !spv.array<4 x !spv.array<4 x f32>> - %0 = spv.CompositeExtract %arg0[1 : i32, 2 : i32] : !spv.array<4x!spv.array<4xf32>> - return %0: f32 -} - -// ----- - -func @composite_extract_1D_array_from_2D_array(%arg0: !spv.array<4x!spv.array<4xf32>>) -> !spv.array<4xf32> { - // CHECK: %0 = spv.CompositeExtract %arg0[1 : i32] : !spv.array<4 x !spv.array<4 x f32>> - %0 = spv.CompositeExtract %arg0[1 : i32] : !spv.array<4x!spv.array<4xf32>> - return %0 : !spv.array<4xf32> -} - -// ----- - -func @composite_extract_struct(%arg0 : !spv.struct<f32, !spv.array<4xf32>>) -> f32 { - // CHECK: %0 = spv.CompositeExtract %arg0[1 : i32, 2 : i32] : !spv.struct<f32, !spv.array<4 x f32>> - %0 = spv.CompositeExtract %arg0[1 : i32, 2 : i32] : !spv.struct<f32, !spv.array<4xf32>> - return %0 : f32 -} - -// ----- - -func @composite_extract_vector(%arg0 : vector<4xf32>) -> f32 { - // CHECK: %0 = spv.CompositeExtract %arg0[1 : i32] : vector<4xf32> - %0 = spv.CompositeExtract %arg0[1 : i32] : vector<4xf32> - return %0 : f32 -} - -// ----- - -func @composite_extract_no_ssa_operand() -> () { - // expected-error @+1 {{expected SSA operand}} - %0 = spv.CompositeExtract [4 : i32, 1 : i32] : !spv.array<4x!spv.array<4xf32>> - return -} - -// ----- - -func @composite_extract_invalid_index_type_1() -> () { - %0 = spv.constant 10 : i32 - %1 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function> - %2 = spv.Load "Function" %1 ["Volatile"] : !spv.array<4x!spv.array<4xf32>> - // expected-error @+1 {{expected non-function type}} - %3 = spv.CompositeExtract %2[%0] : !spv.array<4x!spv.array<4xf32>> - return -} - -// ----- - -func @composite_extract_invalid_index_type_2(%arg0 : !spv.array<4x!spv.array<4xf32>>) -> () { - // expected-error @+1 {{op attribute 'indices' failed to satisfy constraint: 32-bit integer array attribute}} - %0 = spv.CompositeExtract %arg0[1] : !spv.array<4x!spv.array<4xf32>> - return -} - -// ----- - -func @composite_extract_invalid_index_identifier(%arg0 : !spv.array<4x!spv.array<4xf32>>) -> () { - // expected-error @+1 {{expected bare identifier}} - %0 = spv.CompositeExtract %arg0(1 : i32) : !spv.array<4x!spv.array<4xf32>> - return -} - -// ----- - -func @composite_extract_2D_array_out_of_bounds_access_1(%arg0: !spv.array<4x!spv.array<4xf32>>) -> () { - // expected-error @+1 {{index 4 out of bounds for '!spv.array<4 x !spv.array<4 x f32>>'}} - %0 = spv.CompositeExtract %arg0[4 : i32, 1 : i32] : !spv.array<4x!spv.array<4xf32>> - return -} - -// ----- - -func @composite_extract_2D_array_out_of_bounds_access_2(%arg0: !spv.array<4x!spv.array<4xf32>> -) -> () { - // expected-error @+1 {{index 4 out of bounds for '!spv.array<4 x f32>'}} - %0 = spv.CompositeExtract %arg0[1 : i32, 4 : i32] : !spv.array<4x!spv.array<4xf32>> - return -} - -// ----- - -func @composite_extract_struct_element_out_of_bounds_access(%arg0 : !spv.struct<f32, !spv.array<4xf32>>) -> () { - // expected-error @+1 {{index 2 out of bounds for '!spv.struct<f32, !spv.array<4 x f32>>'}} - %0 = spv.CompositeExtract %arg0[2 : i32, 0 : i32] : !spv.struct<f32, !spv.array<4xf32>> - return -} - -// ----- - -func @composite_extract_vector_out_of_bounds_access(%arg0: vector<4xf32>) -> () { - // expected-error @+1 {{index 4 out of bounds for 'vector<4xf32>'}} - %0 = spv.CompositeExtract %arg0[4 : i32] : vector<4xf32> - return -} - -// ----- - -func @composite_extract_invalid_types_1(%arg0: !spv.array<4x!spv.array<4xf32>>) -> () { - // expected-error @+1 {{cannot extract from non-composite type 'f32' with index 3}} - %0 = spv.CompositeExtract %arg0[1 : i32, 2 : i32, 3 : i32] : !spv.array<4x!spv.array<4xf32>> - return -} - -// ----- - -func @composite_extract_invalid_types_2(%arg0: f32) -> () { - // expected-error @+1 {{cannot extract from non-composite type 'f32' with index 1}} - %0 = spv.CompositeExtract %arg0[1 : i32] : f32 - return -} - -// ----- - -func @composite_extract_invalid_extracted_type(%arg0: !spv.array<4x!spv.array<4xf32>>) -> () { - // expected-error @+1 {{expected at least one index for spv.CompositeExtract}} - %0 = spv.CompositeExtract %arg0[] : !spv.array<4x!spv.array<4xf32>> - return -} - -// ----- - -func @composite_extract_result_type_mismatch(%arg0: !spv.array<4xf32>) -> i32 { - // expected-error @+1 {{invalid result type: expected 'f32' but provided 'i32'}} - %0 = "spv.CompositeExtract"(%arg0) {indices = [2: i32]} : (!spv.array<4xf32>) -> (i32) - return %0: i32 -} - -// ----- - -//===----------------------------------------------------------------------===// // spv.ControlBarrier //===----------------------------------------------------------------------===// |