diff options
author | Finn Plummer <50529406+inbelic@users.noreply.github.com> | 2024-03-21 08:49:27 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-21 08:49:27 -0700 |
commit | 38f8a3cf0d75cd25e13d3757027f7356e4466cb9 (patch) | |
tree | 3a5e2932cf70ec0e28df79c979fff4a208870280 /mlir | |
parent | 6295e677220bb6ec1fa8abe2f4a94b513b91b786 (diff) | |
download | llvm-38f8a3cf0d75cd25e13d3757027f7356e4466cb9.zip llvm-38f8a3cf0d75cd25e13d3757027f7356e4466cb9.tar.gz llvm-38f8a3cf0d75cd25e13d3757027f7356e4466cb9.tar.bz2 |
[mlir][spirv] Improve folding of MemRef to SPIRV Lowering (#85433)
Investigate the lowering of MemRef Load/Store ops and implement
additional folding of the created ops.
This aims to improve the readability of the generated lowered SPIR-V code.
Part of work llvm#70704
Diffstat (limited to 'mlir')
8 files changed, 93 insertions, 210 deletions
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index 0acb214..81b9f55 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -50,11 +50,12 @@ static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, assert(targetBits % sourceBits == 0); Type type = srcIdx.getType(); IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits); - auto idx = builder.create<spirv::ConstantOp>(loc, type, idxAttr); + auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr); IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits); - auto srcBitsValue = builder.create<spirv::ConstantOp>(loc, type, srcBitsAttr); - auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx); - return builder.create<spirv::IMulOp>(loc, type, m, srcBitsValue); + auto srcBitsValue = + builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr); + auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx); + return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue); } /// Returns an adjusted spirv::AccessChainOp. Based on the @@ -74,11 +75,11 @@ adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter, Value lastDim = op->getOperand(op.getNumOperands() - 1); Type type = lastDim.getType(); IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits); - auto idx = builder.create<spirv::ConstantOp>(loc, type, attr); + auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr); auto indices = llvm::to_vector<4>(op.getIndices()); // There are two elements if this is a 1-D tensor. 
assert(indices.size() == 2); - indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx); + indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx); Type t = typeConverter.convertType(op.getComponentPtr().getType()); return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices); } @@ -91,7 +92,8 @@ static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, return srcBool; Value zero = spirv::ConstantOp::getZero(dstType, loc, builder); Value one = spirv::ConstantOp::getOne(dstType, loc, builder); - return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero); + return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one, + zero); } /// Returns the `targetBits`-bit value shifted by the given `offset`, and cast @@ -111,10 +113,10 @@ static Value shiftValue(Location loc, Value value, Value offset, Value mask, loc, builder.getIntegerType(targetBits), value); } - value = builder.create<spirv::BitwiseAndOp>(loc, value, mask); + value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask); } - return builder.create<spirv::ShiftLeftLogicalOp>(loc, value.getType(), value, - offset); + return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(), + value, offset); } /// Returns true if the allocations of memref `type` generated from `allocOp` @@ -165,7 +167,7 @@ static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) { return srcInt; auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder); - return builder.create<spirv::IEqualOp>(loc, srcInt, one); + return builder.createOrFold<spirv::IEqualOp>(loc, srcInt, one); } //===----------------------------------------------------------------------===// @@ -597,13 +599,14 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, // ____XXXX________ -> ____________XXXX Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1); Value offset = getOffsetForBitwidth(loc, 
lastDim, srcBits, dstBits, rewriter); - Value result = rewriter.create<spirv::ShiftRightArithmeticOp>( + Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>( loc, spvLoadOp.getType(), spvLoadOp, offset); // Apply the mask to extract corresponding bits. - Value mask = rewriter.create<spirv::ConstantOp>( + Value mask = rewriter.createOrFold<spirv::ConstantOp>( loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1)); - result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask); + result = + rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask); // Apply sign extension on the loading value unconditionally. The signedness // semantic is carried in the operator itself, we relies other pattern to @@ -611,11 +614,11 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, IntegerAttr shiftValueAttr = rewriter.getIntegerAttr(dstType, dstBits - srcBits); Value shiftValue = - rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr); - result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result, - shiftValue); - result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result, - shiftValue); + rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr); + result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType, + result, shiftValue); + result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>( + loc, dstType, result, shiftValue); rewriter.replaceOp(loadOp, result); @@ -744,11 +747,12 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, // Create a mask to clear the destination. E.g., if it is the second i8 in // i32, 0xFFFF00FF is created. 
- Value mask = rewriter.create<spirv::ConstantOp>( + Value mask = rewriter.createOrFold<spirv::ConstantOp>( loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1)); - Value clearBitsMask = - rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset); - clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask); + Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>( + loc, dstType, mask, offset); + clearBitsMask = + rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask); Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter); Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp, @@ -910,7 +914,7 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite( int64_t attrVal = cast<IntegerAttr>(offset.get<Attribute>()).getInt(); Attribute attr = rewriter.getIntegerAttr(intType, attrVal); - return rewriter.create<spirv::ConstantOp>(loc, intType, attr); + return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr); }(); rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>( diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 2b79c80..4072608 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -991,15 +991,16 @@ Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides, // broken down into progressive small steps so we can have intermediate steps // using other dialects. At the moment SPIR-V is the final sink. 
- Value linearizedIndex = builder.create<spirv::ConstantOp>( + Value linearizedIndex = builder.createOrFold<spirv::ConstantOp>( loc, integerType, IntegerAttr::get(integerType, offset)); for (const auto &index : llvm::enumerate(indices)) { - Value strideVal = builder.create<spirv::ConstantOp>( + Value strideVal = builder.createOrFold<spirv::ConstantOp>( loc, integerType, IntegerAttr::get(integerType, strides[index.index()])); - Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value()); + Value update = + builder.createOrFold<spirv::IMulOp>(loc, index.value(), strideVal); linearizedIndex = - builder.create<spirv::IAddOp>(loc, linearizedIndex, update); + builder.createOrFold<spirv::IAddOp>(loc, update, linearizedIndex); } return linearizedIndex; } diff --git a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir index fa12da8..4339799 100644 --- a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir @@ -60,13 +60,9 @@ module attributes { // CHECK: %[[INDEX2:.*]] = spirv.IAdd %[[ARG4]], %[[LOCALINVOCATIONIDX]] %13 = arith.addi %arg4, %3 : index // CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32 - // CHECK: %[[OFFSET1_0:.*]] = spirv.Constant 0 : i32 // CHECK: %[[STRIDE1_1:.*]] = spirv.Constant 4 : i32 - // CHECK: %[[UPDATE1_1:.*]] = spirv.IMul %[[STRIDE1_1]], %[[INDEX1]] : i32 - // CHECK: %[[OFFSET1_1:.*]] = spirv.IAdd %[[OFFSET1_0]], %[[UPDATE1_1]] : i32 - // CHECK: %[[STRIDE1_2:.*]] = spirv.Constant 1 : i32 - // CHECK: %[[UPDATE1_2:.*]] = spirv.IMul %[[STRIDE1_2]], %[[INDEX2]] : i32 - // CHECK: %[[OFFSET1_2:.*]] = spirv.IAdd %[[OFFSET1_1]], %[[UPDATE1_2]] : i32 + // CHECK: %[[UPDATE1_1:.*]] = spirv.IMul %[[INDEX1]], %[[STRIDE1_1]] : i32 + // CHECK: %[[OFFSET1_2:.*]] = spirv.IAdd %[[INDEX2]], %[[UPDATE1_1]] : i32 // CHECK: %[[PTR1:.*]] = spirv.AccessChain %[[ARG0]]{{\[}}%[[ZERO]], %[[OFFSET1_2]]{{\]}} // CHECK-NEXT: %[[VAL1:.*]] = spirv.Load "StorageBuffer" 
%[[PTR1]] %14 = memref.load %arg0[%12, %13] : memref<12x4xf32, #spirv.storage_class<StorageBuffer>> diff --git a/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir b/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir index 470c853..52ed14e 100644 --- a/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir @@ -12,16 +12,10 @@ module attributes { // CHECK-LABEL: @load_i1 func.func @load_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>) -> i1 { // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32 - // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32 - // CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32 - // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] + // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]] // CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]] - // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32 - // CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32 - // CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32 - // CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32 // CHECK: %[[MASK:.+]] = spirv.Constant 255 : i32 - // CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 + // CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32 // CHECK: %[[T2:.+]] = spirv.Constant 24 : i32 // CHECK: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32 // CHECK: %[[T4:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32 @@ -37,32 +31,20 @@ func.func @load_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>) -> i1 // INDEX64-LABEL: @load_i8 func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8 { // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32 - // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32 - // CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32 - // CHECK: %[[PTR:.+]] = spirv.AccessChain 
%{{.+}}[%[[ZERO]], %[[QUOTIENT]]] + // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]] // CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]] - // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32 - // CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32 - // CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32 - // CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32 // CHECK: %[[MASK:.+]] = spirv.Constant 255 : i32 - // CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 + // CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32 // CHECK: %[[T2:.+]] = spirv.Constant 24 : i32 // CHECK: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32 // CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32 // CHECK: builtin.unrealized_conversion_cast %[[SR]] // INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64 - // INDEX64: %[[FOUR:.+]] = spirv.Constant 4 : i64 - // INDEX64: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i64 - // INDEX64: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] : {{.+}}, i64, i64 + // INDEX64: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]] : {{.+}}, i64, i64 // INDEX64: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]] : i32 - // INDEX64: %[[EIGHT:.+]] = spirv.Constant 8 : i64 - // INDEX64: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i64 - // INDEX64: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i64 - // INDEX64: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i64 // INDEX64: %[[MASK:.+]] = spirv.Constant 255 : i32 - // INDEX64: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 + // INDEX64: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32 // INDEX64: %[[T2:.+]] = spirv.Constant 24 : i32 // INDEX64: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32 // INDEX64: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32 @@ -76,15 
+58,12 @@ func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8 func.func @load_i16(%arg0: memref<10xi16, #spirv.storage_class<StorageBuffer>>, %index : index) -> i16 { // CHECK: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32 // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32 - // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32 - // CHECK: %[[UPDATE:.+]] = spirv.IMul %[[ONE]], %[[ARG1_CAST]] : i32 - // CHECK: %[[FLAT_IDX:.+]] = spirv.IAdd %[[ZERO]], %[[UPDATE]] : i32 // CHECK: %[[TWO:.+]] = spirv.Constant 2 : i32 - // CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[FLAT_IDX]], %[[TWO]] : i32 + // CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[ARG1_CAST]], %[[TWO]] : i32 // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] // CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]] // CHECK: %[[SIXTEEN:.+]] = spirv.Constant 16 : i32 - // CHECK: %[[IDX:.+]] = spirv.UMod %[[FLAT_IDX]], %[[TWO]] : i32 + // CHECK: %[[IDX:.+]] = spirv.UMod %[[ARG1_CAST]], %[[TWO]] : i32 // CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[SIXTEEN]] : i32 // CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32 // CHECK: %[[MASK:.+]] = spirv.Constant 65535 : i32 @@ -110,20 +89,12 @@ func.func @load_f32(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>) { func.func @store_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>, %value: i1) { // CHECK: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32 - // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32 - // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32 - // CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32 - // CHECK: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32 - // CHECK: %[[MASK1:.+]] = spirv.Constant 255 : i32 - // CHECK: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32 - // CHECK: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32 + // CHECK: 
%[[MASK:.+]] = spirv.Constant -256 : i32 // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32 // CHECK: %[[CASTED_ARG1:.+]] = spirv.Select %[[ARG1]], %[[ONE]], %[[ZERO]] : i1, i32 - // CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CASTED_ARG1]], %[[OFFSET]] : i32, i32 - // CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32 - // CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]] + // CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]] // CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]] - // CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]] + // CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[CASTED_ARG1]] memref.store %value, %arg0[] : memref<i1, #spirv.storage_class<StorageBuffer>> return } @@ -136,36 +107,22 @@ func.func @store_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>, %val // CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32 // CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32 - // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32 - // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32 - // CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32 - // CHECK: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32 // CHECK: %[[MASK1:.+]] = spirv.Constant 255 : i32 - // CHECK: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32 - // CHECK: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32 + // CHECK: %[[MASK2:.+]] = spirv.Constant -256 : i32 // CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32 - // CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32 - // CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32 - // CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]] - // CHECK: spirv.AtomicAnd <Device> 
<AcquireRelease> %[[PTR]], %[[MASK]] - // CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]] + // CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]] + // CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK2]] + // CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[CLAMPED_VAL]] // INDEX64-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32 // INDEX64-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] // INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64 - // INDEX64: %[[FOUR:.+]] = spirv.Constant 4 : i64 - // INDEX64: %[[EIGHT:.+]] = spirv.Constant 8 : i64 - // INDEX64: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i64 - // INDEX64: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i64 // INDEX64: %[[MASK1:.+]] = spirv.Constant 255 : i32 - // INDEX64: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i64 - // INDEX64: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32 + // INDEX64: %[[MASK2:.+]] = spirv.Constant -256 : i32 // INDEX64: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32 - // INDEX64: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i64 - // INDEX64: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i64 - // INDEX64: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]] : {{.+}}, i64, i64 - // INDEX64: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]] - // INDEX64: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]] + // INDEX64: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]] : {{.+}}, i64, i64 + // INDEX64: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK2]] + // INDEX64: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[CLAMPED_VAL]] memref.store %value, %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>> return } @@ -177,19 +134,16 @@ func.func @store_i16(%arg0: memref<10xi16, 
#spirv.storage_class<StorageBuffer>>, // CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] // CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32 // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32 - // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32 - // CHECK: %[[UPDATE:.+]] = spirv.IMul %[[ONE]], %[[ARG1_CAST]] : i32 - // CHECK: %[[FLAT_IDX:.+]] = spirv.IAdd %[[ZERO]], %[[UPDATE]] : i32 // CHECK: %[[TWO:.+]] = spirv.Constant 2 : i32 // CHECK: %[[SIXTEEN:.+]] = spirv.Constant 16 : i32 - // CHECK: %[[IDX:.+]] = spirv.UMod %[[FLAT_IDX]], %[[TWO]] : i32 + // CHECK: %[[IDX:.+]] = spirv.UMod %[[ARG1_CAST]], %[[TWO]] : i32 // CHECK: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[SIXTEEN]] : i32 // CHECK: %[[MASK1:.+]] = spirv.Constant 65535 : i32 // CHECK: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32 // CHECK: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32 // CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG2_CAST]], %[[MASK1]] : i32 // CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32 - // CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[FLAT_IDX]], %[[TWO]] : i32 + // CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ARG1_CAST]], %[[TWO]] : i32 // CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]] // CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]] // CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]] @@ -222,15 +176,12 @@ module attributes { func.func @load_i4(%arg0: memref<?xi4, #spirv.storage_class<StorageBuffer>>, %i: index) -> i4 { // CHECK: %[[INDEX:.+]] = builtin.unrealized_conversion_cast %{{.+}} : index to i32 // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32 - // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32 - // CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[INDEX]] : i32 - // CHECK: %[[OFFSET:.+]] = spirv.IAdd %[[ZERO]], %[[MUL]] : i32 // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32 - // 
CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[OFFSET]], %[[EIGHT]] : i32 + // CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[INDEX]], %[[EIGHT]] : i32 // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] // CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]] : i32 // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32 - // CHECK: %[[IDX:.+]] = spirv.UMod %[[OFFSET]], %[[EIGHT]] : i32 + // CHECK: %[[IDX:.+]] = spirv.UMod %[[INDEX]], %[[EIGHT]] : i32 // CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[FOUR]] : i32 // CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32 // CHECK: %[[MASK:.+]] = spirv.Constant 15 : i32 @@ -248,19 +199,16 @@ func.func @store_i4(%arg0: memref<?xi4, #spirv.storage_class<StorageBuffer>>, %v // CHECK: %[[VAL:.+]] = builtin.unrealized_conversion_cast %{{.+}} : i4 to i32 // CHECK: %[[INDEX:.+]] = builtin.unrealized_conversion_cast %{{.+}} : index to i32 // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32 - // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32 - // CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[INDEX]] : i32 - // CHECK: %[[OFFSET:.+]] = spirv.IAdd %[[ZERO]], %[[MUL]] : i32 // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32 - // CHECK: %[[FOUR:.+]] = spirv.Constant [[OFFSET]] : i32 - // CHECK: %[[IDX:.+]] = spirv.UMod %[[OFFSET]], %[[EIGHT]] : i32 + // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32 + // CHECK: %[[IDX:.+]] = spirv.UMod %[[INDEX]], %[[EIGHT]] : i32 // CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[FOUR]] : i32 // CHECK: %[[MASK1:.+]] = spirv.Constant 15 : i32 // CHECK: %[[SL:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[BITS]] : i32, i32 // CHECK: %[[MASK2:.+]] = spirv.Not %[[SL]] : i32 // CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[VAL]], %[[MASK1]] : i32 // CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[BITS]] : i32, i32 - // CHECK: %[[ACCESS_INDEX:.+]] = spirv.SDiv %[[OFFSET]], %[[EIGHT]] : i32 + // CHECK: %[[ACCESS_INDEX:.+]] = spirv.SDiv %[[INDEX]], %[[EIGHT]] : 
i32 // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ACCESS_INDEX]]] // CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK2]] // CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]] @@ -283,16 +231,10 @@ module attributes { // INDEX64-LABEL: @load_i8 func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8 { // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32 - // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32 - // CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32 - // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] + // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]] // CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]] - // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32 - // CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32 - // CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32 - // CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32 // CHECK: %[[MASK:.+]] = spirv.Constant 255 : i32 - // CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 + // CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32 // CHECK: %[[T2:.+]] = spirv.Constant 24 : i32 // CHECK: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32 // CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32 @@ -300,16 +242,10 @@ func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8 // CHECK: return %[[CAST]] : i8 // INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64 - // INDEX64: %[[FOUR:.+]] = spirv.Constant 4 : i64 - // INDEX64: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i64 - // INDEX64: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] : {{.+}}, i64, i64 + // INDEX64: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]] : {{.+}}, i64, i64 // INDEX64: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]] : i32 - // 
INDEX64: %[[EIGHT:.+]] = spirv.Constant 8 : i64 - // INDEX64: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i64 - // INDEX64: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i64 - // INDEX64: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i64 // INDEX64: %[[MASK:.+]] = spirv.Constant 255 : i32 - // INDEX64: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 + // INDEX64: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32 // INDEX64: %[[T2:.+]] = spirv.Constant 24 : i32 // INDEX64: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32 // INDEX64: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32 @@ -326,37 +262,19 @@ func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8 func.func @store_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>, %value: i8) { // CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32 - // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32 - // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32 - // CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32 - // CHECK: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32 - // CHECK: %[[MASK1:.+]] = spirv.Constant 255 : i32 - // CHECK: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32 - // CHECK: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32 + // CHECK: %[[MASK1:.+]] = spirv.Constant -256 : i32 // CHECK: %[[ARG1_CAST:.+]] = spirv.UConvert %[[ARG1]] : i8 to i32 - // CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32 - // CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32 - // CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32 - // CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]] - // CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]] - // CHECK: spirv.AtomicOr <Device> <AcquireRelease> 
%[[PTR]], %[[STORE_VAL]] + // CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]] + // CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK1]] + // CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[ARG1_CAST]] // INDEX64-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] // INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64 - // INDEX64: %[[FOUR:.+]] = spirv.Constant 4 : i64 - // INDEX64: %[[EIGHT:.+]] = spirv.Constant 8 : i64 - // INDEX64: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i64 - // INDEX64: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i64 - // INDEX64: %[[MASK1:.+]] = spirv.Constant 255 : i32 - // INDEX64: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i64 - // INDEX64: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32 + // INDEX64: %[[MASK1:.+]] = spirv.Constant -256 : i32 // INDEX64: %[[ARG1_CAST:.+]] = spirv.UConvert %[[ARG1]] : i8 to i32 - // INDEX64: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32 - // INDEX64: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i64 - // INDEX64: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i64 - // INDEX64: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]] : {{.+}}, i64, i64 - // INDEX64: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]] - // INDEX64: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]] + // INDEX64: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]] : {{.+}}, i64, i64 + // INDEX64: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK1]] + // INDEX64: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[ARG1_CAST]] memref.store %value, %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>> return } diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir index feb6d4e..10c03a2 100644 --- 
a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir @@ -70,11 +70,8 @@ func.func @load_store_unknown_dim(%i: index, %source: memref<?xi32, #spirv.stora func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i : index) -> i1 { // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x i8, stride=1> [0])>, StorageBuffer> // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]] - // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32 - // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32 - // CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[IDX_CAST]] : i32 - // CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO]], %[[MUL]] : i32 - // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ZERO]], %[[ADD]]] + // CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32 + // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ZERO]], %[[IDX_CAST]]] // CHECK: %[[VAL:.+]] = spirv.Load "StorageBuffer" %[[ADDR]] : i8 // CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8 // CHECK: %[[BOOL:.+]] = spirv.IEqual %[[VAL]], %[[ONE_I8]] : i8 @@ -90,15 +87,10 @@ func.func @store_i1(%dst: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i: %true = arith.constant true // CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x i8, stride=1> [0])>, StorageBuffer> // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]] - // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32 - // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32 - // CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[IDX_CAST]] : i32 - // CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO]], %[[MUL]] : i32 - // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[DST_CAST]][%[[ZERO]], %[[ADD]]] - // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 
0 : i8 + // CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32 + // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[DST_CAST]][%[[ZERO]], %[[IDX_CAST]]] // CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8 - // CHECK: %[[RES:.+]] = spirv.Select %{{.+}}, %[[ONE_I8]], %[[ZERO_I8]] : i1, i8 - // CHECK: spirv.Store "StorageBuffer" %[[ADDR]], %[[RES]] : i8 + // CHECK: spirv.Store "StorageBuffer" %[[ADDR]], %[[ONE_I8]] : i8 memref.store %true, %dst[%i]: memref<4xi1, #spirv.storage_class<StorageBuffer>> return } @@ -234,11 +226,7 @@ func.func @load_store_unknown_dim(%i: index, %source: memref<?xi32, #spirv.stora func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %i : index) -> i1 { // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<4 x i8>, CrossWorkgroup> // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]] - // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32 - // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32 - // CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[IDX_CAST]] : i32 - // CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO]], %[[MUL]] : i32 - // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ADD]]] + // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[IDX_CAST]]] // CHECK: %[[VAL:.+]] = spirv.Load "CrossWorkgroup" %[[ADDR]] : i8 // CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8 // CHECK: %[[BOOL:.+]] = spirv.IEqual %[[VAL]], %[[ONE_I8]] : i8 @@ -254,15 +242,9 @@ func.func @store_i1(%dst: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %i %true = arith.constant true // CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<4 x i8>, CrossWorkgroup> // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]] - // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32 - // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32 - // 
CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[IDX_CAST]] : i32 - // CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO]], %[[MUL]] : i32 - // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[DST_CAST]][%[[ADD]]] - // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8 + // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[DST_CAST]][%[[IDX_CAST]]] // CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8 - // CHECK: %[[RES:.+]] = spirv.Select %{{.+}}, %[[ONE_I8]], %[[ZERO_I8]] : i1, i8 - // CHECK: spirv.Store "CrossWorkgroup" %[[ADDR]], %[[RES]] : i8 + // CHECK: spirv.Store "CrossWorkgroup" %[[ADDR]], %[[ONE_I8]] : i8 memref.store %true, %dst[%i]: memref<4xi1, #spirv.storage_class<CrossWorkgroup>> return } diff --git a/mlir/test/Conversion/SCFToSPIRV/for.mlir b/mlir/test/Conversion/SCFToSPIRV/for.mlir index 02558463..81661ec 100644 --- a/mlir/test/Conversion/SCFToSPIRV/for.mlir +++ b/mlir/test/Conversion/SCFToSPIRV/for.mlir @@ -19,17 +19,9 @@ func.func @loop_kernel(%arg2 : memref<10xf32, #spirv.storage_class<StorageBuffer // CHECK: spirv.BranchConditional %[[CMP]], ^[[BODY:.*]], ^[[MERGE:.*]] // CHECK: ^[[BODY]]: // CHECK: %[[ZERO1:.*]] = spirv.Constant 0 : i32 - // CHECK: %[[OFFSET1:.*]] = spirv.Constant 0 : i32 - // CHECK: %[[STRIDE1:.*]] = spirv.Constant 1 : i32 - // CHECK: %[[UPDATE1:.*]] = spirv.IMul %[[STRIDE1]], %[[INDVAR]] : i32 - // CHECK: %[[INDEX1:.*]] = spirv.IAdd %[[OFFSET1]], %[[UPDATE1]] : i32 - // CHECK: spirv.AccessChain {{%.*}}{{\[}}%[[ZERO1]], %[[INDEX1]]{{\]}} + // CHECK: spirv.AccessChain {{%.*}}{{\[}}%[[ZERO1]], %[[INDVAR]]{{\]}} // CHECK: %[[ZERO2:.*]] = spirv.Constant 0 : i32 - // CHECK: %[[OFFSET2:.*]] = spirv.Constant 0 : i32 - // CHECK: %[[STRIDE2:.*]] = spirv.Constant 1 : i32 - // CHECK: %[[UPDATE2:.*]] = spirv.IMul %[[STRIDE2]], %[[INDVAR]] : i32 - // CHECK: %[[INDEX2:.*]] = spirv.IAdd %[[OFFSET2]], %[[UPDATE2]] : i32 - // CHECK: spirv.AccessChain {{%.*}}[%[[ZERO2]], %[[INDEX2]]] + // CHECK: spirv.AccessChain {{%.*}}[%[[ZERO2]], %[[INDVAR]]] // CHECK: %[[INCREMENT:.*]] = 
spirv.IAdd %[[INDVAR]], %[[STEP]] : i32 // CHECK: spirv.Branch ^[[HEADER]](%[[INCREMENT]] : i32) // CHECK: ^[[MERGE]] diff --git a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir index 19de613..32d0fbe 100644 --- a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir +++ b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir @@ -14,14 +14,12 @@ func.func @tensor_extract_constant(%a : index, %b: index, %c: index) -> i32 { // CHECK: spirv.Store "Function" %[[VAR]], %[[CST]] : !spirv.array<12 x i32> // CHECK: %[[C0:.+]] = spirv.Constant 0 : i32 // CHECK: %[[C6:.+]] = spirv.Constant 6 : i32 - // CHECK: %[[MUL0:.+]] = spirv.IMul %[[C6]], %[[A]] : i32 - // CHECK: %[[ADD0:.+]] = spirv.IAdd %[[C0]], %[[MUL0]] : i32 + // CHECK: %[[MUL0:.+]] = spirv.IMul %[[A]], %[[C6]] : i32 // CHECK: %[[C3:.+]] = spirv.Constant 3 : i32 - // CHECK: %[[MUL1:.+]] = spirv.IMul %[[C3]], %[[B]] : i32 - // CHECK: %[[ADD1:.+]] = spirv.IAdd %[[ADD0]], %[[MUL1]] : i32 + // CHECK: %[[MUL1:.+]] = spirv.IMul %[[B]], %[[C3]] : i32 + // CHECK: %[[ADD1:.+]] = spirv.IAdd %[[MUL1]], %[[MUL0]] : i32 // CHECK: %[[C1:.+]] = spirv.Constant 1 : i32 - // CHECK: %[[MUL2:.+]] = spirv.IMul %[[C1]], %[[C]] : i32 - // CHECK: %[[ADD2:.+]] = spirv.IAdd %[[ADD1]], %[[MUL2]] : i32 + // CHECK: %[[ADD2:.+]] = spirv.IAdd %[[C]], %[[ADD1]] : i32 // CHECK: %[[AC:.+]] = spirv.AccessChain %[[VAR]][%[[ADD2]]] // CHECK: %[[VAL:.+]] = spirv.Load "Function" %[[AC]] : i32 %extract = tensor.extract %cst[%a, %b, %c] : tensor<2x2x3xi32> diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index c9984091..cddc4ee 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -720,9 +720,7 @@ module attributes { // CHECK: %[[CST1:.+]] = spirv.Constant 0 : i32 // CHECK: %[[CST2:.+]] = spirv.Constant 0 : i32 // 
CHECK: %[[CST3:.+]] = spirv.Constant 1 : i32 -// CHECK: %[[S2:.+]] = spirv.IMul %[[CST3]], %[[S1]] : i32 -// CHECK: %[[S3:.+]] = spirv.IAdd %[[CST2]], %[[S2]] : i32 -// CHECK: %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S3]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 +// CHECK: %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S1]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 // CHECK: %[[S5:.+]] = spirv.Bitcast %[[S4]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer> // CHECK: %[[R0:.+]] = spirv.Load "StorageBuffer" %[[S5]] : vector<4xf32> // CHECK: return %[[R0]] : vector<4xf32> @@ -743,11 +741,9 @@ func.func @vector_load(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer> // CHECK: %[[CST0_1:.+]] = spirv.Constant 0 : i32 // CHECK: %[[CST0_2:.+]] = spirv.Constant 0 : i32 // CHECK: %[[CST4:.+]] = spirv.Constant 4 : i32 -// CHECK: %[[S3:.+]] = spirv.IMul %[[CST4]], %[[S1]] : i32 -// CHECK: %[[S4:.+]] = spirv.IAdd %[[CST0_2]], %[[S3]] : i32 +// CHECK: %[[S3:.+]] = spirv.IMul %[[S1]], %[[CST4]] : i32 // CHECK: %[[CST1:.+]] = spirv.Constant 1 : i32 -// CHECK: %[[S5:.+]] = spirv.IMul %[[CST1]], %[[S2]] : i32 -// CHECK: %[[S6:.+]] = spirv.IAdd %[[S4]], %[[S5]] : i32 +// CHECK: %[[S6:.+]] = spirv.IAdd %[[S2]], %[[S3]] : i32 // CHECK: %[[S7:.+]] = spirv.AccessChain %[[S0]][%[[CST0_1]], %[[S6]]] : !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 // CHECK: %[[S8:.+]] = spirv.Bitcast %[[S7]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer> // CHECK: %[[R0:.+]] = spirv.Load "StorageBuffer" %[[S8]] : vector<4xf32> @@ -768,9 +764,7 @@ func.func @vector_load_2d(%arg0 : memref<4x4xf32, #spirv.storage_class<StorageBu // CHECK: %[[CST1:.+]] = spirv.Constant 0 : i32 // CHECK: %[[CST2:.+]] = spirv.Constant 0 : i32 // CHECK: %[[CST3:.+]] = spirv.Constant 1 : i32 -// CHECK: 
%[[S2:.+]] = spirv.IMul %[[CST3]], %[[S1]] : i32 -// CHECK: %[[S3:.+]] = spirv.IAdd %[[CST2]], %[[S2]] : i32 -// CHECK: %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S3]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 +// CHECK: %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S1]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 // CHECK: %[[S5:.+]] = spirv.Bitcast %[[S4]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer> // CHECK: spirv.Store "StorageBuffer" %[[S5]], %[[ARG1]] : vector<4xf32> func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<4xf32>) { @@ -790,11 +784,9 @@ func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer // CHECK: %[[CST0_1:.+]] = spirv.Constant 0 : i32 // CHECK: %[[CST0_2:.+]] = spirv.Constant 0 : i32 // CHECK: %[[CST4:.+]] = spirv.Constant 4 : i32 -// CHECK: %[[S3:.+]] = spirv.IMul %[[CST4]], %[[S1]] : i32 -// CHECK: %[[S4:.+]] = spirv.IAdd %[[CST0_2]], %[[S3]] : i32 +// CHECK: %[[S3:.+]] = spirv.IMul %[[S1]], %[[CST4]] : i32 // CHECK: %[[CST1:.+]] = spirv.Constant 1 : i32 -// CHECK: %[[S5:.+]] = spirv.IMul %[[CST1]], %[[S2]] : i32 -// CHECK: %[[S6:.+]] = spirv.IAdd %[[S4]], %[[S5]] : i32 +// CHECK: %[[S6:.+]] = spirv.IAdd %[[S2]], %[[S3]] : i32 // CHECK: %[[S7:.+]] = spirv.AccessChain %[[S0]][%[[CST0_1]], %[[S6]]] : !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 // CHECK: %[[S8:.+]] = spirv.Bitcast %[[S7]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer> // CHECK: spirv.Store "StorageBuffer" %[[S8]], %[[ARG1]] : vector<4xf32> |