diff options
| author | Matthias Springer <me@m-sp.org> | 2026-02-13 16:58:41 +0000 |
|---|---|---|
| committer | Matthias Springer <me@m-sp.org> | 2026-02-16 14:31:34 +0000 |
| commit | b5396781d72ca782729b5c0954f3e0571a479be8 (patch) | |
| tree | fb12029c8175350dd44ed326eba72d6ec04294ab | |
| parent | 3765b09d20e01976a6ab6f8b922a6b93751fbf44 (diff) | |
| download | llvm-users/matthias-springer/split_dense_string_elements.zip llvm-users/matthias-springer/split_dense_string_elements.tar.gz llvm-users/matthias-springer/split_dense_string_elements.tar.bz2 | |
[mlir][IR] Separate `DenseStringElementsAttr` from `DenseElementsAttr`users/matthias-springer/split_dense_string_elements
| -rw-r--r-- | mlir/include/mlir-c/BuiltinAttributes.h | 20 | ||||
| -rw-r--r-- | mlir/include/mlir/IR/BuiltinAttributes.h | 25 | ||||
| -rw-r--r-- | mlir/include/mlir/IR/BuiltinAttributes.td | 114 | ||||
| -rw-r--r-- | mlir/include/mlir/IR/CommonAttrConstraints.td | 4 | ||||
| -rw-r--r-- | mlir/lib/AsmParser/AttributeParser.cpp | 36 | ||||
| -rw-r--r-- | mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 41 | ||||
| -rw-r--r-- | mlir/lib/IR/BuiltinAttributes.cpp | 68 | ||||
| -rw-r--r-- | mlir/test/IR/parser.mlir | 4 | ||||
| -rw-r--r-- | mlir/test/mlir-tblgen/openmp-clause-ops.td | 2 | ||||
| -rw-r--r-- | mlir/unittests/IR/AttributeTest.cpp | 37 |
10 files changed, 215 insertions, 136 deletions
diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h index 69a5094..7461f88 100644 --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -465,6 +465,7 @@ MLIR_CAPI_EXPORTED double mlirDenseF64ArrayGetElement(MlirAttribute attr, /// Checks whether the given attribute is a dense elements attribute. MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseElements(MlirAttribute attr); +MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseStringElements(MlirAttribute attr); MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseIntElements(MlirAttribute attr); MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseFPElements(MlirAttribute attr); @@ -546,11 +547,20 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrBFloat16Get( MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrFloat16Get( MlirType shapedType, intptr_t numElements, const uint16_t *elements); -/// Creates a dense elements attribute with the given shaped type from string -/// elements. -MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrStringGet( +/// Creates a dense string elements attribute with the given shaped type from +/// string elements. +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseStringElementsAttrStringGet( MlirType shapedType, intptr_t numElements, MlirStringRef *strs); +/// Checks whether the given dense string elements attribute contains a single +/// replicated value (splat). +MLIR_CAPI_EXPORTED bool mlirDenseStringElementsAttrIsSplat(MlirAttribute attr); + +/// Returns the single replicated value (splat) of the given dense string +/// elements attribute. +MLIR_CAPI_EXPORTED MlirAttribute +mlirDenseStringElementsAttrGetSplatValue(MlirAttribute attr); + /// Creates a dense elements attribute that has the same data as the given dense /// elements attribute and a different shaped type. The new type must have the /// same total number of elements. @@ -584,7 +594,7 @@ mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr); MLIR_CAPI_EXPORTED double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr); MLIR_CAPI_EXPORTED MlirStringRef -mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr); +mlirDenseStringElementsAttrGetStringSplatValue(MlirAttribute attr); /// Returns the pos-th value (flat contiguous indexing) of a specific type /// contained by the given dense elements attribute. @@ -613,7 +623,7 @@ MLIR_CAPI_EXPORTED float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, MLIR_CAPI_EXPORTED double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos); MLIR_CAPI_EXPORTED MlirStringRef -mlirDenseElementsAttrGetStringValue(MlirAttribute attr, intptr_t pos); +mlirDenseStringElementsAttrGetStringValue(MlirAttribute attr, intptr_t pos); /// Returns the raw data of the given dense elements attribute. MLIR_CAPI_EXPORTED const void * diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h index ee6a8f4..33de7d9 100644 --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -152,10 +152,6 @@ public: /// Overload of the above 'get' method that is specialized for boolean values. static DenseElementsAttr get(ShapedType type, ArrayRef<bool> values); - /// Overload of the above 'get' method that is specialized for StringRef - /// values. - static DenseElementsAttr get(ShapedType type, ArrayRef<StringRef> values); - /// Constructs a dense integer elements attribute from an array of APInt /// values. Each APInt value is expected to have the same bitwidth as the /// element type of 'type'. 'type' must be a vector or tensor with static @@ -232,9 +228,6 @@ public: /// Accesses the Attribute value at this iterator position. Attribute operator*() const; - private: - friend DenseElementsAttr; - /// Constructs a new iterator. AttributeElementIterator(DenseElementsAttr attr, size_t index); }; @@ -461,21 +454,6 @@ public: ElementIterator<T>(rawData, splat, getNumElements())); } - /// Try to get the held element values as a range of StringRef. - template <typename T> - using StringRefValueTemplateCheckT = - std::enable_if_t<std::is_same<T, StringRef>::value>; - template <typename T, typename = StringRefValueTemplateCheckT<T>> - FailureOr<iterator_range_impl<ElementIterator<StringRef>>> - tryGetValues() const { - auto stringRefs = getRawStringData(); - const char *ptr = reinterpret_cast<const char *>(stringRefs.data()); - bool splat = isSplat(); - return iterator_range_impl<ElementIterator<StringRef>>( - getType(), ElementIterator<StringRef>(ptr, splat, 0), - ElementIterator<StringRef>(ptr, splat, getNumElements())); - } - /// Try to get the held element values as a range of Attributes. template <typename T> using AttributeValueTemplateCheckT = @@ -578,9 +556,6 @@ public: /// form the user might expect. ArrayRef<char> getRawData() const; - /// Return the raw StringRef data held by this attribute. - ArrayRef<StringRef> getRawStringData() const; - /// Return the type of this ElementsAttr, guaranteed to be a vector or tensor /// with static shape. ShapedType getType() const; diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td index dced379..52f377a 100644 --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -395,8 +395,7 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr< //===----------------------------------------------------------------------===// def Builtin_DenseStringElementsAttr : Builtin_Attr< - "DenseStringElements", "dense_string_elements", [ElementsAttrInterface], - "DenseElementsAttr" + "DenseStringElements", "dense_string_elements", [ElementsAttrInterface] > { let summary = "An Attribute containing a dense multi-dimensional array of " "strings"; @@ -431,13 +430,101 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr< }]>, ]; let extraClassDeclaration = [{ - using DenseElementsAttr::empty; - using DenseElementsAttr::getNumElements; - using DenseElementsAttr::getElementType; - using DenseElementsAttr::getValues; - using DenseElementsAttr::isSplat; - using DenseElementsAttr::size; - using DenseElementsAttr::value_begin; + /// Iterator for walking StringRef element values. + class StringRefElementIterator + : public detail::DenseElementIndexedIteratorImpl<StringRefElementIterator, + const StringRef> { + public: + const StringRef &operator*() const { + return reinterpret_cast<const StringRef *>(this->getData())[this->getDataIndex()]; + } + StringRefElementIterator(const char *data, bool isSplat, size_t dataIndex) + : detail::DenseElementIndexedIteratorImpl<StringRefElementIterator, + const StringRef>( + data, isSplat, dataIndex) {} + }; + + /// Iterator for walking element values as Attribute (StringAttr). + class StringAttributeElementIterator + : public llvm::indexed_accessor_iterator<StringAttributeElementIterator, + const void *, Attribute, + Attribute, Attribute> { + public: + Attribute operator*() const; + StringAttributeElementIterator(const DenseStringElementsAttr *attr, + size_t index) + : llvm::indexed_accessor_iterator<StringAttributeElementIterator, + const void *, Attribute, + Attribute, Attribute>( + attr->getAsOpaquePointer(), index) {} + }; + + /// Return the type of this attribute (vector or tensor with static shape). + ShapedType getType() const; + + /// Helper methods for ElementsAttr interface. + bool empty() const { return getNumElements() == 0; } + int64_t getNumElements() const { return getType().getNumElements(); } + Type getElementType() const { return getType().getElementType(); } + bool isSplat() const { return getRawStringData().size() == 1; } + int64_t size() const { return getNumElements(); } + + /// Return the raw StringRef data held by this attribute. + ArrayRef<StringRef> getRawStringData() const; + + /// Try to get the held element values as a range of StringRef. + template <typename T> + using StringRefValueTemplateCheckT = + std::enable_if_t<std::is_same<T, StringRef>::value>; + template <typename T, typename = StringRefValueTemplateCheckT<T>> + FailureOr<detail::ElementsAttrRange<StringRefElementIterator>> + tryGetValues() const { + auto stringRefs = getRawStringData(); + const char *ptr = reinterpret_cast<const char *>(stringRefs.data()); + bool splat = isSplat(); + return detail::ElementsAttrRange<StringRefElementIterator>( + getType(), StringRefElementIterator(ptr, splat, 0), + StringRefElementIterator(ptr, splat, getNumElements())); + } + + /// Try to get the held element values as a range of Attributes. + template <typename T> + using AttributeValueTemplateCheckT = + std::enable_if_t<std::is_same<T, Attribute>::value>; + template <typename T, typename = AttributeValueTemplateCheckT<T>> + FailureOr<detail::ElementsAttrRange<StringAttributeElementIterator>> + tryGetValues() const { + return detail::ElementsAttrRange<StringAttributeElementIterator>( + getType(), StringAttributeElementIterator(this, 0), + StringAttributeElementIterator(this, getNumElements())); + } + + template <typename T> + auto getValues() const { + auto range = tryGetValues<T>(); + assert(succeeded(range) && "element type cannot be iterated"); + return std::move(*range); + } + + template <typename T> + auto value_begin() const { return getValues<T>().begin(); } + + template <typename T> + auto value_end() const { return getValues<T>().end(); } + + /// Return the splat value. Asserts that the attribute is a splat. + template <typename T> + auto getSplatValue() const { + assert(isSplat() && "expected the attribute to be a splat"); + return *value_begin<T>(); + } + + template <typename T> + auto try_value_begin() const { + auto range = tryGetValues<T>(); + using iterator = decltype(range->begin()); + return failed(range) ? FailureOr<iterator>(failure()) : range->begin(); + } /// The set of data types that can be iterated by this attribute. using ContiguousIterableTypesT = std::tuple<StringRef>; @@ -449,11 +536,6 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr< auto try_value_begin_impl(OverloadToken<T>) const { return try_value_begin<T>(); } - - protected: - friend DenseElementsAttr; - - public: }]; let genAccessors = 0; let genStorageClass = 0; @@ -931,9 +1013,7 @@ def Builtin_SparseElementsAttr : Builtin_Attr< std::complex<int16_t>, std::complex<int32_t>, std::complex<int64_t>, // Float types. APFloat, float, double, - std::complex<APFloat>, std::complex<float>, std::complex<double>, - // String types. - StringRef + std::complex<APFloat>, std::complex<float>, std::complex<double> >; using ElementsAttr::Trait<SparseElementsAttr>::getValues; using ElementsAttr::Trait<SparseElementsAttr>::value_begin; diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td index ba6cf55..634881f 100644 --- a/mlir/include/mlir/IR/CommonAttrConstraints.td +++ b/mlir/include/mlir/IR/CommonAttrConstraints.td @@ -565,8 +565,8 @@ def StringElementsAttr : ElementsAttrBase< CPred<"::llvm::isa<::mlir::DenseStringElementsAttr>($_self)" >, "string elements attribute"> { - let storageType = [{ ::mlir::DenseElementsAttr }]; - let returnType = [{ ::mlir::DenseElementsAttr }]; + let storageType = [{ ::mlir::DenseStringElementsAttr }]; + let returnType = [{ ::mlir::DenseStringElementsAttr }]; let convertFromStorage = "$_self"; } diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp index dc9744a..15c2e02 100644 --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -472,8 +472,8 @@ public: ParseResult parse(bool allowHex); /// Build a dense attribute instance with the parsed elements and the given - /// shaped type. - DenseElementsAttr getAttr(SMLoc loc, ShapedType type); + /// shaped type. Returns DenseElementsAttr or DenseStringElementsAttr. + Attribute getAttr(SMLoc loc, ShapedType type); ArrayRef<int64_t> getShape() const { return shape; } @@ -487,7 +487,7 @@ private: std::vector<APFloat> &floatValues); /// Build a Dense String attribute for the given type. - DenseElementsAttr getStringAttr(SMLoc loc, ShapedType type, Type eltTy); + DenseStringElementsAttr getStringAttr(SMLoc loc, ShapedType type, Type eltTy); /// Build a Dense attribute with hex data for the given type. DenseElementsAttr getHexAttr(SMLoc loc, ShapedType type); @@ -539,7 +539,7 @@ ParseResult TensorLiteralParser::parse(bool allowHex) { /// Build a dense attribute instance with the parsed elements and the given /// shaped type. -DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) { +Attribute TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) { Type eltType = type.getElementType(); // Check to see if we parse the literal from a hex string. @@ -679,8 +679,8 @@ TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy, } /// Build a Dense String attribute for the given type. -DenseElementsAttr TensorLiteralParser::getStringAttr(SMLoc loc, ShapedType type, - Type eltTy) { +DenseStringElementsAttr +TensorLiteralParser::getStringAttr(SMLoc loc, ShapedType type, Type eltTy) { if (hexStorage.has_value()) { auto stringValue = hexStorage->getStringValue(); return DenseStringElementsAttr::get(type, {stringValue}); @@ -1174,6 +1174,13 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) { if (!type) return nullptr; + // SparseElementsAttr only supports int/float element types. + if (!type.getElementType().isIntOrIndexOrFloat()) { + emitError(loc) << "sparse elements attribute does not support string " + "element type"; + return nullptr; + } + // Construct the sparse elements attr using zero element indice/value // attributes. ShapedType indicesType = @@ -1219,9 +1226,10 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) { // Otherwise, set the shape to the one parsed by the literal parser. indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType); } - auto indices = indiceParser.getAttr(indicesLoc, indicesType); - if (!indices) + auto indicesAttr = indiceParser.getAttr(indicesLoc, indicesType); + if (!indicesAttr) return nullptr; + auto indices = llvm::cast<DenseIntElementsAttr>(indicesAttr); // If the values are a splat, set the shape explicitly based on the number of // indices. The number of indices is encoded in the first dimension of the @@ -1231,10 +1239,18 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) { valuesParser.getShape().empty() ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType) : RankedTensorType::get(valuesParser.getShape(), valuesEltType); - auto values = valuesParser.getAttr(valuesLoc, valuesType); - if (!values) + auto valuesAttr = valuesParser.getAttr(valuesLoc, valuesType); + if (!valuesAttr) return nullptr; + // SparseElementsAttr only supports DenseElementsAttr for values (not string). + auto values = llvm::dyn_cast<DenseElementsAttr>(valuesAttr); + if (!values) { + emitError(valuesLoc) + << "dense string elements not supported in sparse elements attribute"; + return nullptr; + } + // Build the sparse elements attribute by the indices and values. return getChecked<SparseElementsAttr>(loc, type, indices, values); } diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index 44a3dea..38d4e40 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -551,6 +551,10 @@ bool mlirAttributeIsADenseElements(MlirAttribute attr) { return llvm::isa<DenseElementsAttr>(unwrap(attr)); } +bool mlirAttributeIsADenseStringElements(MlirAttribute attr) { + return llvm::isa<DenseStringElementsAttr>(unwrap(attr)); +} + bool mlirAttributeIsADenseIntElements(MlirAttribute attr) { return llvm::isa<DenseIntElementsAttr>(unwrap(attr)); } @@ -720,16 +724,16 @@ MlirAttribute mlirDenseElementsAttrFloat16Get(MlirType shapedType, return mlirDenseElementsAttrRawBufferGet(shapedType, bufferSize, buffer); } -MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType, - intptr_t numElements, - MlirStringRef *strs) { +MlirAttribute mlirDenseStringElementsAttrStringGet(MlirType shapedType, + intptr_t numElements, + MlirStringRef *strs) { SmallVector<StringRef, 8> values; values.reserve(numElements); for (intptr_t i = 0; i < numElements; ++i) values.push_back(unwrap(strs[i])); - return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), - values)); + return wrap(DenseStringElementsAttr::get( + llvm::cast<ShapedType>(unwrap(shapedType)), values)); } MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr, @@ -745,10 +749,16 @@ MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr, bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) { return llvm::cast<DenseElementsAttr>(unwrap(attr)).isSplat(); } - +bool mlirDenseStringElementsAttrIsSplat(MlirAttribute attr) { + return llvm::cast<DenseStringElementsAttr>(unwrap(attr)).isSplat(); +} MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) { - return wrap( - llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<Attribute>()); + return wrap(llvm::cast<DenseElementsAttr>(unwrap(attr)) + .getSplatValue<mlir::Attribute>()); +} +MlirAttribute mlirDenseStringElementsAttrGetSplatValue(MlirAttribute attr) { + return wrap(llvm::cast<DenseStringElementsAttr>(unwrap(attr)) + .getSplatValue<mlir::Attribute>()); } int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) { return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<bool>(); @@ -777,9 +787,10 @@ float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr) { double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) { return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<double>(); } -MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) { - return wrap( - llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<StringRef>()); +MlirStringRef +mlirDenseStringElementsAttrGetStringSplatValue(MlirAttribute attr) { + return wrap(llvm::cast<DenseStringElementsAttr>(unwrap(attr)) + .getSplatValue<llvm::StringRef>()); } //===----------------------------------------------------------------------===// @@ -822,10 +833,10 @@ float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) { double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) { return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<double>()[pos]; } -MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr, - intptr_t pos) { - return wrap( - llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<StringRef>()[pos]); +MlirStringRef mlirDenseStringElementsAttrGetStringValue(MlirAttribute attr, + intptr_t pos) { + return wrap(llvm::cast<DenseStringElementsAttr>(unwrap(attr)) + .getValues<StringRef>()[pos]); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp index bbbc919..b459d32 100644 --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -596,16 +596,7 @@ DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { auto owner = llvm::cast<DenseElementsAttr>(getFromOpaquePointer(base)); - Type eltTy = owner.getElementType(); - - // Handle strings specially. - if (llvm::isa<DenseStringElementsAttr>(owner)) { - ArrayRef<StringRef> vals = owner.getRawStringData(); - return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy); - } - - // All other types should implement DenseElementTypeInterface. - auto denseEltTy = llvm::cast<DenseElementType>(eltTy); + auto denseEltTy = llvm::cast<DenseElementType>(owner.getElementType()); ArrayRef<char> rawData = owner.getRawData(); // Storage is byte-aligned: align bit size up to next byte boundary. size_t bitSize = denseEltTy.getDenseElementBitSize(); @@ -864,28 +855,13 @@ template class DenseArrayAttrImpl<double>; /// Method for support type inquiry through isa, cast and dyn_cast. bool DenseElementsAttr::classof(Attribute attr) { - return llvm::isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(attr); + return llvm::isa<DenseIntOrFPElementsAttr>(attr); } DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef<Attribute> values) { assert(hasSameNumElementsOrSplat(type, values)); - Type eltType = type.getElementType(); - - // Handle strings specially. - if (!llvm::isa<DenseElementType>(eltType)) { - SmallVector<StringRef, 8> stringValues; - stringValues.reserve(values.size()); - for (Attribute attr : values) { - assert(llvm::isa<StringAttr>(attr) && - "expected string value for non-DenseElementType element"); - stringValues.push_back(llvm::cast<StringAttr>(attr).getValue()); - } - return get(type, stringValues); - } - - // All other types go through DenseElementTypeInterface. - auto denseEltType = llvm::dyn_cast<DenseElementType>(eltType); + auto denseEltType = llvm::dyn_cast<DenseElementType>(type.getElementType()); assert(denseEltType && "attempted to get DenseElementsAttr with unsupported element type"); SmallVector<char> data; @@ -906,12 +882,6 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, values.size())); } -DenseElementsAttr DenseElementsAttr::get(ShapedType type, - ArrayRef<StringRef> values) { - assert(!type.getElementType().isIntOrFloat()); - return DenseStringElementsAttr::get(type, values); -} - /// Constructs a dense integer elements attribute from an array of APInt /// values. Each APInt value is expected to have the same bitwidth as the /// element type of 'type'. @@ -1048,9 +1018,6 @@ bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, /// values are the same. bool DenseElementsAttr::isSplat() const { // Splat iff the data array has exactly one element. - if (isa<DenseStringElementsAttr>(*this)) - return getRawStringData().size() == 1; - // FP/Int case. size_t storageSize = llvm::divideCeil( getDenseElementBitWidth(getType().getElementType()), CHAR_BIT); return getRawData().size() == storageSize; @@ -1100,10 +1067,6 @@ ArrayRef<char> DenseElementsAttr::getRawData() const { return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data; } -ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const { - return static_cast<DenseStringElementsAttrStorage *>(impl)->data; -} - /// Return a new DenseElementsAttr that has the same data as the current /// attribute, but has been reshaped to 'newType'. The new type must have the /// same total number of elements as well as element type. @@ -1391,6 +1354,27 @@ bool DenseIntElementsAttr::classof(Attribute attr) { } //===----------------------------------------------------------------------===// +// DenseStringElementsAttr +//===----------------------------------------------------------------------===// + +ShapedType DenseStringElementsAttr::getType() const { + return static_cast<const DenseStringElementsAttrStorage *>(impl)->type; +} + +ArrayRef<StringRef> DenseStringElementsAttr::getRawStringData() const { + return static_cast<const DenseStringElementsAttrStorage *>(impl)->data; +} + +Attribute +DenseStringElementsAttr::StringAttributeElementIterator::operator*() const { + auto attr = llvm::cast<DenseStringElementsAttr>( + Attribute::getFromOpaquePointer(this->base)); + auto data = attr.getRawStringData(); + return StringAttr::get(attr.isSplat() ? data.front() : data[this->index], + attr.getElementType()); +} + +//===----------------------------------------------------------------------===// // DenseResourceElementsAttr //===----------------------------------------------------------------------===// @@ -1557,10 +1541,6 @@ Attribute SparseElementsAttr::getZeroAttr() const { ArrayRef<Attribute>{zero, zero}); } - // Handle string type. - if (llvm::isa<DenseStringElementsAttr>(getValues())) - return StringAttr::get("", eltType); - // Otherwise, this is an integer. return IntegerAttr::get(eltType, 0); } diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 3bb6e38..c4a415e 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -797,10 +797,6 @@ func.func @sparsetensorattr() -> () { // CHECK: "foof321"() {bar = sparse<> : tensor<f32>} : () -> () "foof321"(){bar = sparse<> : tensor<f32>} : () -> () -// CHECK: "foostr"() {bar = sparse<0, "foo"> : tensor<1x1x1x!unknown<>>} : () -> () - "foostr"(){bar = sparse<0, "foo"> : tensor<1x1x1x!unknown<>>} : () -> () -// CHECK: "foostr"() {bar = sparse<{{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}"a", "b", "c"]> : tensor<2x2x2x!unknown<>>} : () -> () - "foostr"(){bar = sparse<[[1, 1, 0], [0, 1, 0], [0, 0, 1]], ["a", "b", "c"]> : tensor<2x2x2x!unknown<>>} : () -> () return } diff --git a/mlir/test/mlir-tblgen/openmp-clause-ops.td b/mlir/test/mlir-tblgen/openmp-clause-ops.td index 3e5896a..c502b21 100644 --- a/mlir/test/mlir-tblgen/openmp-clause-ops.td +++ b/mlir/test/mlir-tblgen/openmp-clause-ops.td @@ -59,7 +59,7 @@ def OpenMP_MyFirstClause : OpenMP_Clause< // CHECK-NEXT: ::mlir::IntegerAttr complexOptIntAttr; // CHECK-NEXT: ::mlir::ElementsAttr elementsAttr; -// CHECK-NEXT: ::mlir::DenseElementsAttr stringElementsAttr; +// CHECK-NEXT: ::mlir::DenseStringElementsAttr stringElementsAttr; // CHECK-NEXT: } def OpenMP_MySecondClause : OpenMP_Clause< diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp index 404aa8c..f72aeba 100644 --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -38,6 +38,21 @@ static void testSplat(Type eltType, const EltTy &splatElt) { EXPECT_TRUE(newValue == splatElt); } +template <> +void testSplat<StringRef>(Type eltType, const StringRef &splatElt) { + RankedTensorType shape = RankedTensorType::get({2, 1}, eltType); + + DenseStringElementsAttr splat = DenseStringElementsAttr::get(shape, splatElt); + EXPECT_TRUE(splat.isSplat()); + + auto detectedSplat = + DenseStringElementsAttr::get(shape, llvm::ArrayRef({splatElt, splatElt})); + EXPECT_EQ(detectedSplat, splat); + + for (auto newValue : detectedSplat.getValues<StringRef>()) + EXPECT_TRUE(newValue == splatElt); +} + namespace { TEST(DenseSplatTest, BoolSplat) { MLIRContext context; @@ -184,8 +199,16 @@ TEST(DenseSplatTest, StringAttrSplat) { context.allowUnregisteredDialects(); Type stringType = OpaqueType::get(StringAttr::get(&context, "test"), "string"); + RankedTensorType shape = RankedTensorType::get({2, 1}, stringType); Attribute stringAttr = StringAttr::get("test-string", stringType); - testSplat(stringType, stringAttr); + StringRef value = llvm::cast<StringAttr>(stringAttr).getValue(); + DenseStringElementsAttr splat = DenseStringElementsAttr::get(shape, value); + EXPECT_TRUE(splat.isSplat()); + auto detectedSplat = + DenseStringElementsAttr::get(shape, llvm::ArrayRef({value, value})); + EXPECT_EQ(detectedSplat, splat); + for (auto newValue : detectedSplat.getValues<StringRef>()) + EXPECT_TRUE(newValue == value); } TEST(DenseComplexTest, ComplexFloatSplat) { @@ -396,11 +419,9 @@ TEST(SparseElementsAttrTest, GetZero) { IntegerType intTy = IntegerType::get(&context, 32); FloatType floatTy = Float32Type::get(&context); - Type stringTy = OpaqueType::get(StringAttr::get(&context, "test"), "string"); ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy); ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy); - ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy); auto indicesType = RankedTensorType::get({1, 2}, IntegerType::get(&context, 64)); @@ -413,13 +434,8 @@ TEST(SparseElementsAttrTest, GetZero) { RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy); auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f}); - RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy); - auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo")}); - auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue); auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue); - auto sparseString = - SparseElementsAttr::get(tensorString, indices, stringValue); // Only index (0, 0) contains an element, others are supposed to return // the zero/empty value. @@ -432,11 +448,6 @@ TEST(SparseElementsAttrTest, GetZero) { cast<FloatAttr>(sparseFloat.getValues<Attribute>()[{1, 1}]); EXPECT_EQ(zeroFloatValue.getValueAsDouble(), 0.0f); EXPECT_TRUE(zeroFloatValue.getType() == floatTy); - - auto zeroStringValue = - cast<StringAttr>(sparseString.getValues<Attribute>()[{1, 1}]); - EXPECT_TRUE(zeroStringValue.empty()); - EXPECT_TRUE(zeroStringValue.getType() == stringTy); } //===----------------------------------------------------------------------===// |
