aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndrea Faulds <andrea.faulds@amd.com>2025-02-05 14:27:35 +0100
committerGitHub <noreply@github.com>2025-02-05 14:27:35 +0100
commit5df62bdc9be9c258c5ac45c8093b71e23777fa0e (patch)
treee5d4d69635cc4c7b24d324e756a94dabc08ba8a2
parent814db6c53faeb1dc66361b67cf30a5e42036c1bb (diff)
downloadllvm-5df62bdc9be9c258c5ac45c8093b71e23777fa0e.zip
llvm-5df62bdc9be9c258c5ac45c8093b71e23777fa0e.tar.gz
llvm-5df62bdc9be9c258c5ac45c8093b71e23777fa0e.tar.bz2
[mlir][spirv] Support poison index when converting vector.insert/extract (#125560)
This modifies the conversion patterns so that, in the case where the index is known statically to be poison, the insertion/extraction is replaced by an arbitrary junk constant value, and in the dynamic case, the index is sanitized at runtime. This avoids triggering a UB in both cases. The dynamic case is definitely a pessimisation of the generated code, but the use of dynamic indexes is expected to be very rare and already slow on real-world GPU compilers ingesting SPIR-V, so the impact should be negligible. Resolves #124162.
-rw-r--r--mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp77
-rw-r--r--mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir47
2 files changed, 107 insertions, 17 deletions
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index af882cb..2c8bc14 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -137,6 +137,33 @@ struct VectorBroadcastConvert final
}
};
+// SPIR-V does not have a concept of a poison index for certain instructions,
+// which creates a UB hazard when lowering from otherwise equivalent Vector
+// dialect instructions, because this index will be considered out-of-bounds.
+// To avoid this, this function implements a dynamic sanitization that returns
+// some arbitrary safe index. For power-of-two vector sizes, this uses a bitmask
+// (presumably more efficient), and otherwise index 0 (always in-bounds).
+static Value sanitizeDynamicIndex(ConversionPatternRewriter &rewriter,
+ Location loc, Value dynamicIndex,
+ int64_t kPoisonIndex, unsigned vectorSize) {
+ if (llvm::isPowerOf2_32(vectorSize)) {
+ Value inBoundsMask = rewriter.create<spirv::ConstantOp>(
+ loc, dynamicIndex.getType(),
+ rewriter.getIntegerAttr(dynamicIndex.getType(), vectorSize - 1));
+ return rewriter.create<spirv::BitwiseAndOp>(loc, dynamicIndex,
+ inBoundsMask);
+ }
+ Value poisonIndex = rewriter.create<spirv::ConstantOp>(
+ loc, dynamicIndex.getType(),
+ rewriter.getIntegerAttr(dynamicIndex.getType(), kPoisonIndex));
+ Value cmpResult =
+ rewriter.create<spirv::IEqualOp>(loc, dynamicIndex, poisonIndex);
+ return rewriter.create<spirv::SelectOp>(
+ loc, cmpResult,
+ spirv::ConstantOp::getZero(dynamicIndex.getType(), loc, rewriter),
+ dynamicIndex);
+}
+
struct VectorExtractOpConvert final
: public OpConversionPattern<vector::ExtractOp> {
using OpConversionPattern::OpConversionPattern;
@@ -154,14 +181,26 @@ struct VectorExtractOpConvert final
}
if (std::optional<int64_t> id =
- getConstantIntValue(extractOp.getMixedPosition()[0]))
- rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
- extractOp, dstType, adaptor.getVector(),
- rewriter.getI32ArrayAttr(id.value()));
- else
+ getConstantIntValue(extractOp.getMixedPosition()[0])) {
+ // TODO: ExtractOp::fold() already can fold a static poison index to
+ // ub.poison; remove this once ub.poison can be converted to SPIR-V.
+ if (id == vector::ExtractOp::kPoisonIndex) {
+ // Arbitrary choice of poison result, intended to stick out.
+ Value zero =
+ spirv::ConstantOp::getZero(dstType, extractOp.getLoc(), rewriter);
+ rewriter.replaceOp(extractOp, zero);
+ } else
+ rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+ extractOp, dstType, adaptor.getVector(),
+ rewriter.getI32ArrayAttr(id.value()));
+ } else {
+ Value sanitizedIndex = sanitizeDynamicIndex(
+ rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0],
+ vector::ExtractOp::kPoisonIndex,
+ extractOp.getSourceVectorType().getNumElements());
rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
- extractOp, dstType, adaptor.getVector(),
- adaptor.getDynamicPosition()[0]);
+ extractOp, dstType, adaptor.getVector(), sanitizedIndex);
+ }
return success();
}
};
@@ -266,13 +305,25 @@ struct VectorInsertOpConvert final
}
if (std::optional<int64_t> id =
- getConstantIntValue(insertOp.getMixedPosition()[0]))
- rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
- insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
- else
+ getConstantIntValue(insertOp.getMixedPosition()[0])) {
+ // TODO: ExtractOp::fold() already can fold a static poison index to
+ // ub.poison; remove this once ub.poison can be converted to SPIR-V.
+ if (id == vector::InsertOp::kPoisonIndex) {
+ // Arbitrary choice of poison result, intended to stick out.
+ Value zero = spirv::ConstantOp::getZero(insertOp.getDestVectorType(),
+ insertOp.getLoc(), rewriter);
+ rewriter.replaceOp(insertOp, zero);
+ } else
+ rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+ insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
+ } else {
+ Value sanitizedIndex = sanitizeDynamicIndex(
+ rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
+ vector::InsertOp::kPoisonIndex,
+ insertOp.getDestVectorType().getNumElements());
rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
- insertOp, insertOp.getDest(), adaptor.getSource(),
- adaptor.getDynamicPosition()[0]);
+ insertOp, insertOp.getDest(), adaptor.getSource(), sanitizedIndex);
+ }
return success();
}
};
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 383215c..5fd7324 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -176,7 +176,8 @@ func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
// -----
func.func @extract_poison_idx(%arg0 : vector<4xf32>) -> f32 {
- // expected-error@+1 {{index -1 out of bounds for 'vector<4xf32>'}}
+ // CHECK: %[[ZERO:.+]] = spirv.Constant 0.000000e+00
+ // CHECK: return %[[ZERO]]
%0 = vector.extract %arg0[-1] : f32 from vector<4xf32>
return %0: f32
}
@@ -208,12 +209,31 @@ func.func @extract_size1_vector_dynamic(%arg0 : vector<1xf32>, %id : index) -> f
// CHECK-LABEL: @extract_dynamic
// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ARG1:.*]]: index
// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
-// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
+// CHECK: %[[MASK:.+]] = spirv.Constant 3 :
+// CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %[[ID]], %[[MASK]] :
+// CHECK: spirv.VectorExtractDynamic %[[V]][%[[MASKED]]] : vector<4xf32>, i32
func.func @extract_dynamic(%arg0 : vector<4xf32>, %id : index) -> f32 {
%0 = vector.extract %arg0[%id] : f32 from vector<4xf32>
return %0: f32
}
+// -----
+
+// CHECK-LABEL: @extract_dynamic_non_pow2
+// CHECK-SAME: %[[V:.*]]: vector<3xf32>, %[[ARG1:.*]]: index
+// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
+// CHECK: %[[POISON:.+]] = spirv.Constant -1 :
+// CHECK: %[[CMP:.+]] = spirv.IEqual %[[ID]], %[[POISON]]
+// CHECK: %[[ZERO:.+]] = spirv.Constant 0 :
+// CHECK: %[[SELECT:.+]] = spirv.Select %[[CMP]], %[[ZERO]], %[[ID]] :
+// CHECK: spirv.VectorExtractDynamic %[[V]][%[[SELECT]]] : vector<3xf32>, i32
+func.func @extract_dynamic_non_pow2(%arg0 : vector<3xf32>, %id : index) -> f32 {
+ %0 = vector.extract %arg0[%id] : f32 from vector<3xf32>
+ return %0: f32
+}
+
+// -----
+
// CHECK-LABEL: @extract_dynamic_cst
// CHECK-SAME: %[[V:.*]]: vector<4xf32>
// CHECK: spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
@@ -264,8 +284,10 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
// -----
+// CHECK-LABEL: @insert_poison_idx
+// CHECK: %[[ZERO:.+]] = spirv.Constant dense<0.000000e+00>
+// CHECK: return %[[ZERO]]
func.func @insert_poison_idx(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
- // expected-error@+1 {{index -1 out of bounds for 'vector<4xf32>'}}
%1 = vector.insert %arg1, %arg0[-1] : f32 into vector<4xf32>
return %1: vector<4xf32>
}
@@ -306,7 +328,9 @@ func.func @insert_size1_vector_dynamic(%arg0 : vector<1xf32>, %arg1: f32, %id :
// CHECK-LABEL: @insert_dynamic
// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ARG2:.*]]: index
// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32
-// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32
+// CHECK: %[[MASK:.+]] = spirv.Constant 3 :
+// CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %[[ID]], %[[MASK]] :
+// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[MASKED]]] : vector<4xf32>, i32
func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> {
%0 = vector.insert %val, %arg0[%id] : f32 into vector<4xf32>
return %0: vector<4xf32>
@@ -314,6 +338,21 @@ func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vect
// -----
+// CHECK-LABEL: @insert_dynamic_non_pow2
+// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<3xf32>, %[[ARG2:.*]]: index
+// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32
+// CHECK: %[[POISON:.+]] = spirv.Constant -1 :
+// CHECK: %[[CMP:.+]] = spirv.IEqual %[[ID]], %[[POISON]]
+// CHECK: %[[ZERO:.+]] = spirv.Constant 0 :
+// CHECK: %[[SELECT:.+]] = spirv.Select %[[CMP]], %[[ZERO]], %[[ID]] :
+// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[SELECT]]] : vector<3xf32>, i32
+func.func @insert_dynamic_non_pow2(%val: f32, %arg0 : vector<3xf32>, %id : index) -> vector<3xf32> {
+ %0 = vector.insert %val, %arg0[%id] : f32 into vector<3xf32>
+ return %0: vector<3xf32>
+}
+
+// -----
+
// CHECK-LABEL: @insert_dynamic_cst
// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>
// CHECK: spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32>