diff options
author | jeanPerier <jperier@nvidia.com> | 2025-05-20 10:45:29 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-05-20 10:45:29 +0200 |
commit | 80816e792382da286b29f937938ab54ae159f482 (patch) | |
tree | 53f405ef638b704fe72a3b16e89d9e23f9f0deea | |
parent | 034eaeddc30cbaf273744580f15325514d5fb928 (diff) | |
download | llvm-80816e792382da286b29f937938ab54ae159f482.zip llvm-80816e792382da286b29f937938ab54ae159f482.tar.gz llvm-80816e792382da286b29f937938ab54ae159f482.tar.bz2 |
[mlir][LLVM] handle ArrayAttr for constant array of structs (#139724)
While LLVM IR dialect has a way to represent arbitrary LLVM constant
array of structs via an insert chain, it is in practice very expensive
for the compilation time as soon as the array is bigger than a couple
hundred elements. This is because generating and later folding such
insert chain is really not cheap.
This patch allows representing array of struct constants via ArrayAttr in
the LLVM dialect.
-rw-r--r-- | mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 11 | ||||
-rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 93 | ||||
-rw-r--r-- | mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 31 | ||||
-rw-r--r-- | mlir/test/Dialect/LLVMIR/invalid.mlir | 32 | ||||
-rw-r--r-- | mlir/test/Target/LLVMIR/llvmir-invalid.mlir | 5 | ||||
-rw-r--r-- | mlir/test/Target/LLVMIR/llvmir.mlir | 26 |
6 files changed, 179 insertions, 19 deletions
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index f19f9d5..61ba8f7 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -2073,9 +2073,9 @@ def LLVM_ConstantOp Unlike LLVM IR, MLIR does not have first-class constant values. Therefore, all constants must be created as SSA values before being used in other operations. `llvm.mlir.constant` creates such values for scalars, vectors, - strings, and structs. It has a mandatory `value` attribute whose type - depends on the type of the constant value. The type of the constant value - must correspond to the attribute type converted to LLVM IR type. + strings, structs, and array of structs. It has a mandatory `value` attribute + whose type depends on the type of the constant value. The type of the constant + value must correspond to the attribute type converted to LLVM IR type. When creating constant scalars, the `value` attribute must be either an integer attribute or a floating point attribute. The type of the attribute @@ -2097,6 +2097,11 @@ def LLVM_ConstantOp must correspond to the type of the corresponding attribute element converted to LLVM IR. + When creating an array of structs, the `value` attribute must be an array + attribute, itself containing zero, or undef, or array attributes for each + potential nested array type, and the elements of the leaf array attributes + for must match the struct element types or be zero or undef attributes. + Examples: ```mlir diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index c757f3c..d8abf6f 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -3142,6 +3142,74 @@ static bool hasScalableVectorType(Type t) { return false; } +/// Verifies the constant array represented by `arrayAttr` matches the provided +/// `arrayType`. +static LogicalResult verifyStructArrayConstant(LLVM::ConstantOp op, + LLVM::LLVMArrayType arrayType, + ArrayAttr arrayAttr, int dim) { + if (arrayType.getNumElements() != arrayAttr.size()) + return op.emitOpError() + << "array attribute size does not match array type size in " + "dimension " + << dim << ": " << arrayAttr.size() << " vs. " + << arrayType.getNumElements(); + + llvm::DenseSet<Attribute> elementsVerified; + + // Recursively verify sub-dimensions for multidimensional arrays. + if (auto subArrayType = + dyn_cast<LLVM::LLVMArrayType>(arrayType.getElementType())) { + for (auto [idx, elementAttr] : llvm::enumerate(arrayAttr)) + if (elementsVerified.insert(elementAttr).second) { + if (isa<LLVM::ZeroAttr, LLVM::UndefAttr>(elementAttr)) + continue; + auto subArrayAttr = dyn_cast<ArrayAttr>(elementAttr); + if (!subArrayAttr) + return op.emitOpError() + << "nested attribute for sub-array in dimension " << dim + << " at index " << idx + << " must be a zero, or undef, or array attribute"; + if (failed(verifyStructArrayConstant(op, subArrayType, subArrayAttr, + dim + 1))) + return failure(); + } + return success(); + } + + // Forbid usages of ArrayAttr for simple array types that should use + // DenseElementsAttr instead. Note that there would be a use case for such + // array types when one element value is obtained via a ptr-to-int conversion + // from a symbol and cannot be represented in a DenseElementsAttr, but no MLIR + // user needs this so far, and it seems better to avoid people misusing the + // ArrayAttr for simple types. + auto structType = dyn_cast<LLVM::LLVMStructType>(arrayType.getElementType()); + if (!structType) + return op.emitOpError() << "for array with an array attribute must have a " + "struct element type"; + + // Shallow verification that leaf attributes are appropriate as struct initial + // value. + size_t numStructElements = structType.getBody().size(); + for (auto [idx, elementAttr] : llvm::enumerate(arrayAttr)) { + if (elementsVerified.insert(elementAttr).second) { + if (isa<LLVM::ZeroAttr, LLVM::UndefAttr>(elementAttr)) + continue; + auto subArrayAttr = dyn_cast<ArrayAttr>(elementAttr); + if (!subArrayAttr) + return op.emitOpError() + << "nested attribute for struct element at index " << idx + << " must be a zero, or undef, or array attribute"; + if (subArrayAttr.size() != numStructElements) + return op.emitOpError() + << "nested array attribute size for struct element at index " + << idx << " must match struct size: " << subArrayAttr.size() + << " vs. " << numStructElements; + } + } + + return success(); +} + LogicalResult LLVM::ConstantOp::verify() { if (StringAttr sAttr = llvm::dyn_cast<StringAttr>(getValue())) { auto arrayType = llvm::dyn_cast<LLVMArrayType>(getType()); @@ -3208,7 +3276,7 @@ LogicalResult LLVM::ConstantOp::verify() { if (isa<IntegerType>(getType()) && !getType().isInteger(floatWidth)) { return emitOpError() << "expected integer type of width " << floatWidth; } - } else if (isa<ElementsAttr, ArrayAttr>(getValue())) { + } else if (isa<ElementsAttr>(getValue())) { if (hasScalableVectorType(getType())) { // The exact number of elements of a scalable vector is unknown, so we // allow only splat attributes. @@ -3221,15 +3289,20 @@ LogicalResult LLVM::ConstantOp::verify() { if (!isa<VectorType, LLVM::LLVMArrayType>(getType())) return emitOpError() << "expected vector or array type"; // The number of elements of the attribute and the type must match. - int64_t attrNumElements; - if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue())) - attrNumElements = elementsAttr.getNumElements(); - else - attrNumElements = cast<ArrayAttr>(getValue()).size(); - if (getNumElements(getType()) != attrNumElements) - return emitOpError() - << "type and attribute have a different number of elements: " - << getNumElements(getType()) << " vs. " << attrNumElements; + if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue())) { + int64_t attrNumElements = elementsAttr.getNumElements(); + if (getNumElements(getType()) != attrNumElements) + return emitOpError() + << "type and attribute have a different number of elements: " + << getNumElements(getType()) << " vs. " << attrNumElements; + } + } else if (auto arrayAttr = dyn_cast<ArrayAttr>(getValue())) { + auto arrayType = dyn_cast<LLVM::LLVMArrayType>(getType()); + if (!arrayType) + return emitOpError() << "expected array type"; + // When the attribute is an ArrayAttr, check that its nesting matches the + // corresponding ArrayType or VectorType nesting. + return verifyStructArrayConstant(*this, arrayType, arrayAttr, /*dim=*/0); } else { return emitOpError() << "only supports integer, float, string or elements attributes"; diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 95b8ee0..9b5c931 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -553,8 +553,10 @@ static llvm::Constant *convertDenseResourceElementsAttr( llvm::Constant *mlir::LLVM::detail::getLLVMConstant( llvm::Type *llvmType, Attribute attr, Location loc, const ModuleTranslation &moduleTranslation) { - if (!attr) + if (!attr || isa<UndefAttr>(attr)) return llvm::UndefValue::get(llvmType); + if (isa<ZeroAttr>(attr)) + return llvm::Constant::getNullValue(llvmType); if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) { auto arrayAttr = dyn_cast<ArrayAttr>(attr); if (!arrayAttr) { @@ -713,6 +715,33 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant( ArrayRef<char>{stringAttr.getValue().data(), stringAttr.getValue().size()}); } + + // Handle arrays of structs that cannot be represented as DenseElementsAttr + // in MLIR. + if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) { + if (auto *arrayTy = dyn_cast<llvm::ArrayType>(llvmType)) { + llvm::Type *elementType = arrayTy->getElementType(); + Attribute previousElementAttr; + llvm::Constant *elementCst = nullptr; + SmallVector<llvm::Constant *> constants; + constants.reserve(arrayTy->getNumElements()); + for (Attribute elementAttr : arrayAttr) { + // Arrays with a single value or with repeating values are quite common. + // Short-circuit the translation when the element value is the same as + // the previous one. + if (!previousElementAttr || previousElementAttr != elementAttr) { + previousElementAttr = elementAttr; + elementCst = + getLLVMConstant(elementType, elementAttr, loc, moduleTranslation); + if (!elementCst) + return nullptr; + } + constants.push_back(elementCst); + } + return llvm::ConstantArray::get(arrayTy, constants); + } + } + emitError(loc, "unsupported constant value"); return nullptr; } diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index f9ea066..f5adf4b 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1850,3 +1850,35 @@ llvm.func @gep_inbounds_flag_usage(%ptr: !llvm.ptr, %idx: i64) { llvm.getelementptr inbounds_flag %ptr[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)> llvm.return } + +// ----- + +llvm.mlir.global @bad_struct_array_init_size() : !llvm.array<2x!llvm.struct<(i32, f32)>> { + // expected-error@below {{'llvm.mlir.constant' op array attribute size does not match array type size in dimension 0: 1 vs. 2}} + %0 = llvm.mlir.constant([[42 : i32, 1.000000e+00 : f32]]) : !llvm.array<2x!llvm.struct<(i32, f32)>> + llvm.return %0 : !llvm.array<2x!llvm.struct<(i32, f32)>> +} + +// ----- + +llvm.mlir.global @bad_struct_array_init_nesting() : !llvm.array<1x!llvm.array<1x!llvm.array<1x!llvm.struct<(i32)>>>> { + // expected-error@below {{'llvm.mlir.constant' op nested attribute for sub-array in dimension 1 at index 0 must be a zero, or undef, or array attribute}} + %0 = llvm.mlir.constant([[1 : i32]]) : !llvm.array<1x!llvm.array<1x!llvm.array<1x!llvm.struct<(i32)>>>> + llvm.return %0 : !llvm.array<1x!llvm.array<1x!llvm.array<1x!llvm.struct<(i32)>>>> +} + +// ----- + +llvm.mlir.global @bad_struct_array_init_elements() : !llvm.array<1x!llvm.struct<(i32, f32)>> { + // expected-error@below {{'llvm.mlir.constant' op nested array attribute size for struct element at index 0 must match struct size: 1 vs. 2}} + %0 = llvm.mlir.constant([[1 : i32]]) : !llvm.array<1x!llvm.struct<(i32, f32)>> + llvm.return %0 : !llvm.array<1x!llvm.struct<(i32, f32)>> +} + +// ---- + +llvm.mlir.global internal constant @bad_array_attr_simple_type() : !llvm.array<2 x f64> { + // expected-error@below {{'llvm.mlir.constant' op for array with an array attribute must have a struct element type}} + %0 = llvm.mlir.constant([2.5, 7.4]) : !llvm.array<2 x f64> + llvm.return %0 : !llvm.array<2 x f64> +} diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir index 90c0f5a..24a7b42 100644 --- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir @@ -79,11 +79,6 @@ llvm.func @incompatible_integer_type_for_float_attr() -> i32 { // ----- -// expected-error @below{{unsupported constant value}} -llvm.mlir.global internal constant @test([2.5, 7.4]) : !llvm.array<2 x f64> - -// ----- - // expected-error @below{{LLVM attribute 'readonly' does not expect a value}} llvm.func @passthrough_unexpected_value() attributes {passthrough = [["readonly", "42"]]} diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir index 3c8de1c..2376122 100644 --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -3022,3 +3022,29 @@ llvm.func internal @i(%arg0: i32) attributes {dso_local} { llvm.call @testfn3(%arg0) : (i32 {llvm.alignstack = 8 : i64}) -> () llvm.return } + +// ----- + +// CHECK: @test_array_attr_2 = global [2 x { i32, float }] [{ i32, float } { i32 42, float 1.000000e+00 }, { i32, float } { i32 42, float 1.000000e+00 }] +llvm.mlir.global @test_array_attr_2() : !llvm.array<2 x !llvm.struct<(i32, f32)>> { + %0 = llvm.mlir.constant([[42 : i32, 1.000000e+00 : f32],[42 : i32, 1.000000e+00 : f32]]) : !llvm.array<2 x !llvm.struct<(i32, f32)>> + llvm.return %0 : !llvm.array<2 x !llvm.struct<(i32, f32)>> +} + +// CHECK: @test_array_attr_3 = global [2 x [3 x { i32, float }]{{.*}}[3 x { i32, float }] [{ i32, float } { i32 1, float 1.000000e+00 }, { i32, float } { i32 2, float 1.000000e+00 }, { i32, float } { i32 3, float 1.000000e+00 }], [3 x { i32, float }] [{ i32, float } { i32 4, float 1.000000e+00 }, { i32, float } { i32 5, float 1.000000e+00 }, { i32, float } { i32 6, float 1.000000e+00 } +llvm.mlir.global @test_array_attr_3() : !llvm.array<2 x !llvm.array<3 x !llvm.struct<(i32, f32)>>> { + %0 = llvm.mlir.constant([[[1 : i32, 1.000000e+00 : f32], [2 : i32, 1.000000e+00 : f32], [3 : i32, 1.000000e+00 : f32]], [[4 : i32, 1.000000e+00 : f32], [5 : i32, 1.000000e+00 : f32], [6 : i32, 1.000000e+00 : f32]]]) : !llvm.array<2 x !llvm.array<3 x !llvm.struct<(i32, f32)>>> + llvm.return %0 : !llvm.array<2 x !llvm.array<3 x !llvm.struct<(i32, f32)>>> +} + +// CHECK: @test_array_attr_struct_with_ptr = internal constant [2 x { ptr }] [{ ptr } zeroinitializer, { ptr } undef] +llvm.mlir.global internal constant @test_array_attr_struct_with_ptr() : !llvm.array<2 x struct<(ptr)>> { + %0 = llvm.mlir.constant([[#llvm.zero], [#llvm.undef]]) : !llvm.array<2 x struct<(ptr)>> + llvm.return %0 : !llvm.array<2 x struct<(ptr)>> +} + +// CHECK: @test_array_attr_struct_with_struct = internal constant [3 x { i32, float }] [{ i32, float } zeroinitializer, { i32, float } { i32 2, float 1.000000e+00 }, { i32, float } undef] +llvm.mlir.global internal constant @test_array_attr_struct_with_struct() : !llvm.array<3 x struct<(i32, f32)>> { + %0 = llvm.mlir.constant([#llvm.zero, [2 : i32, 1.0 : f32], #llvm.undef]) : !llvm.array<3 x struct<(i32, f32)>> + llvm.return %0 : !llvm.array<3 x struct<(i32, f32)>> +} |