diff options
author | Andrea Faulds <andrea.faulds@amd.com> | 2025-02-05 14:27:35 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-02-05 14:27:35 +0100 |
commit | 5df62bdc9be9c258c5ac45c8093b71e23777fa0e (patch) | |
tree | e5d4d69635cc4c7b24d324e756a94dabc08ba8a2 | |
parent | 814db6c53faeb1dc66361b67cf30a5e42036c1bb (diff) | |
download | llvm-5df62bdc9be9c258c5ac45c8093b71e23777fa0e.zip llvm-5df62bdc9be9c258c5ac45c8093b71e23777fa0e.tar.gz llvm-5df62bdc9be9c258c5ac45c8093b71e23777fa0e.tar.bz2 |
[mlir][spirv] Support poison index when converting vector.insert/extract (#125560)
This modifies the conversion patterns so that, in the case where the
index is known statically to be poison, the insertion/extraction is
replaced by an arbitrary junk constant value, and in the dynamic case,
the index is sanitized at runtime. This avoids triggering a UB in both
cases. The dynamic case is definitely a pessimisation of the generated
code, but the use of dynamic indexes is expected to be very rare and
already slow on real-world GPU compilers ingesting SPIR-V, so the impact
should be negligible.
Resolves #124162.
-rw-r--r-- | mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 77 | ||||
-rw-r--r-- | mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir | 47 |
2 files changed, 107 insertions, 17 deletions
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index af882cb..2c8bc14 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -137,6 +137,33 @@ struct VectorBroadcastConvert final } }; +// SPIR-V does not have a concept of a poison index for certain instructions, +// which creates a UB hazard when lowering from otherwise equivalent Vector +// dialect instructions, because this index will be considered out-of-bounds. +// To avoid this, this function implements a dynamic sanitization that returns +// some arbitrary safe index. For power-of-two vector sizes, this uses a bitmask +// (presumably more efficient), and otherwise index 0 (always in-bounds). +static Value sanitizeDynamicIndex(ConversionPatternRewriter &rewriter, + Location loc, Value dynamicIndex, + int64_t kPoisonIndex, unsigned vectorSize) { + if (llvm::isPowerOf2_32(vectorSize)) { + Value inBoundsMask = rewriter.create<spirv::ConstantOp>( + loc, dynamicIndex.getType(), + rewriter.getIntegerAttr(dynamicIndex.getType(), vectorSize - 1)); + return rewriter.create<spirv::BitwiseAndOp>(loc, dynamicIndex, + inBoundsMask); + } + Value poisonIndex = rewriter.create<spirv::ConstantOp>( + loc, dynamicIndex.getType(), + rewriter.getIntegerAttr(dynamicIndex.getType(), kPoisonIndex)); + Value cmpResult = + rewriter.create<spirv::IEqualOp>(loc, dynamicIndex, poisonIndex); + return rewriter.create<spirv::SelectOp>( + loc, cmpResult, + spirv::ConstantOp::getZero(dynamicIndex.getType(), loc, rewriter), + dynamicIndex); +} + struct VectorExtractOpConvert final : public OpConversionPattern<vector::ExtractOp> { using OpConversionPattern::OpConversionPattern; @@ -154,14 +181,26 @@ struct VectorExtractOpConvert final } if (std::optional<int64_t> id = - getConstantIntValue(extractOp.getMixedPosition()[0])) - rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( - extractOp, dstType, adaptor.getVector(), - rewriter.getI32ArrayAttr(id.value())); - else + getConstantIntValue(extractOp.getMixedPosition()[0])) { + // TODO: ExtractOp::fold() already can fold a static poison index to + // ub.poison; remove this once ub.poison can be converted to SPIR-V. + if (id == vector::ExtractOp::kPoisonIndex) { + // Arbitrary choice of poison result, intended to stick out. + Value zero = + spirv::ConstantOp::getZero(dstType, extractOp.getLoc(), rewriter); + rewriter.replaceOp(extractOp, zero); + } else + rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( + extractOp, dstType, adaptor.getVector(), + rewriter.getI32ArrayAttr(id.value())); + } else { + Value sanitizedIndex = sanitizeDynamicIndex( + rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0], + vector::ExtractOp::kPoisonIndex, + extractOp.getSourceVectorType().getNumElements()); rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>( - extractOp, dstType, adaptor.getVector(), - adaptor.getDynamicPosition()[0]); + extractOp, dstType, adaptor.getVector(), sanitizedIndex); + } return success(); } }; @@ -266,13 +305,25 @@ struct VectorInsertOpConvert final } if (std::optional<int64_t> id = - getConstantIntValue(insertOp.getMixedPosition()[0])) - rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( - insertOp, adaptor.getSource(), adaptor.getDest(), id.value()); - else + getConstantIntValue(insertOp.getMixedPosition()[0])) { + // TODO: ExtractOp::fold() already can fold a static poison index to + // ub.poison; remove this once ub.poison can be converted to SPIR-V. + if (id == vector::InsertOp::kPoisonIndex) { + // Arbitrary choice of poison result, intended to stick out. + Value zero = spirv::ConstantOp::getZero(insertOp.getDestVectorType(), + insertOp.getLoc(), rewriter); + rewriter.replaceOp(insertOp, zero); + } else + rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( + insertOp, adaptor.getSource(), adaptor.getDest(), id.value()); + } else { + Value sanitizedIndex = sanitizeDynamicIndex( + rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0], + vector::InsertOp::kPoisonIndex, + insertOp.getDestVectorType().getNumElements()); rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>( - insertOp, insertOp.getDest(), adaptor.getSource(), - adaptor.getDynamicPosition()[0]); + insertOp, insertOp.getDest(), adaptor.getSource(), sanitizedIndex); + } return success(); } }; diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index 383215c..5fd7324 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -176,7 +176,8 @@ func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) { // ----- func.func @extract_poison_idx(%arg0 : vector<4xf32>) -> f32 { - // expected-error@+1 {{index -1 out of bounds for 'vector<4xf32>'}} + // CHECK: %[[ZERO:.+]] = spirv.Constant 0.000000e+00 + // CHECK: return %[[ZERO]] %0 = vector.extract %arg0[-1] : f32 from vector<4xf32> return %0: f32 } @@ -208,12 +209,31 @@ func.func @extract_size1_vector_dynamic(%arg0 : vector<1xf32>, %id : index) -> f // CHECK-LABEL: @extract_dynamic // CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ARG1:.*]]: index // CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32 -// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32 +// CHECK: %[[MASK:.+]] = spirv.Constant 3 : +// CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %[[ID]], %[[MASK]] : +// CHECK: spirv.VectorExtractDynamic %[[V]][%[[MASKED]]] : vector<4xf32>, i32 func.func @extract_dynamic(%arg0 : vector<4xf32>, %id : index) -> f32 { %0 = vector.extract %arg0[%id] : f32 from vector<4xf32> return %0: f32 } +// ----- + +// CHECK-LABEL: @extract_dynamic_non_pow2 +// CHECK-SAME: %[[V:.*]]: vector<3xf32>, %[[ARG1:.*]]: index +// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32 +// CHECK: %[[POISON:.+]] = spirv.Constant -1 : +// CHECK: %[[CMP:.+]] = spirv.IEqual %[[ID]], %[[POISON]] +// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : +// CHECK: %[[SELECT:.+]] = spirv.Select %[[CMP]], %[[ZERO]], %[[ID]] : +// CHECK: spirv.VectorExtractDynamic %[[V]][%[[SELECT]]] : vector<3xf32>, i32 +func.func @extract_dynamic_non_pow2(%arg0 : vector<3xf32>, %id : index) -> f32 { + %0 = vector.extract %arg0[%id] : f32 from vector<3xf32> + return %0: f32 +} + +// ----- + // CHECK-LABEL: @extract_dynamic_cst // CHECK-SAME: %[[V:.*]]: vector<4xf32> // CHECK: spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32> @@ -264,8 +284,10 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> { // ----- +// CHECK-LABEL: @insert_poison_idx +// CHECK: %[[ZERO:.+]] = spirv.Constant dense<0.000000e+00> +// CHECK: return %[[ZERO]] func.func @insert_poison_idx(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> { - // expected-error@+1 {{index -1 out of bounds for 'vector<4xf32>'}} %1 = vector.insert %arg1, %arg0[-1] : f32 into vector<4xf32> return %1: vector<4xf32> } @@ -306,7 +328,9 @@ func.func @insert_size1_vector_dynamic(%arg0 : vector<1xf32>, %arg1: f32, %id : // CHECK-LABEL: @insert_dynamic // CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ARG2:.*]]: index // CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32 -// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32 +// CHECK: %[[MASK:.+]] = spirv.Constant 3 : +// CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %[[ID]], %[[MASK]] : +// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[MASKED]]] : vector<4xf32>, i32 func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> { %0 = vector.insert %val, %arg0[%id] : f32 into vector<4xf32> return %0: vector<4xf32> @@ -314,6 +338,21 @@ func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vect // ----- +// CHECK-LABEL: @insert_dynamic_non_pow2 +// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<3xf32>, %[[ARG2:.*]]: index +// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32 +// CHECK: %[[POISON:.+]] = spirv.Constant -1 : +// CHECK: %[[CMP:.+]] = spirv.IEqual %[[ID]], %[[POISON]] +// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : +// CHECK: %[[SELECT:.+]] = spirv.Select %[[CMP]], %[[ZERO]], %[[ID]] : +// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[SELECT]]] : vector<3xf32>, i32 +func.func @insert_dynamic_non_pow2(%val: f32, %arg0 : vector<3xf32>, %id : index) -> vector<3xf32> { + %0 = vector.insert %val, %arg0[%id] : f32 into vector<3xf32> + return %0: vector<3xf32> +} + +// ----- + // CHECK-LABEL: @insert_dynamic_cst // CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32> // CHECK: spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32> |