aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthias Springer <me@m-sp.org>2026-02-13 16:58:41 +0000
committerMatthias Springer <me@m-sp.org>2026-02-16 14:31:34 +0000
commitb5396781d72ca782729b5c0954f3e0571a479be8 (patch)
treefb12029c8175350dd44ed326eba72d6ec04294ab
parent3765b09d20e01976a6ab6f8b922a6b93751fbf44 (diff)
downloadllvm-users/matthias-springer/split_dense_string_elements.zip
llvm-users/matthias-springer/split_dense_string_elements.tar.gz
llvm-users/matthias-springer/split_dense_string_elements.tar.bz2
[mlir][IR] Separate `DenseStringElementsAttr` from `DenseElementsAttr`users/matthias-springer/split_dense_string_elements
-rw-r--r--mlir/include/mlir-c/BuiltinAttributes.h20
-rw-r--r--mlir/include/mlir/IR/BuiltinAttributes.h25
-rw-r--r--mlir/include/mlir/IR/BuiltinAttributes.td114
-rw-r--r--mlir/include/mlir/IR/CommonAttrConstraints.td4
-rw-r--r--mlir/lib/AsmParser/AttributeParser.cpp36
-rw-r--r--mlir/lib/CAPI/IR/BuiltinAttributes.cpp41
-rw-r--r--mlir/lib/IR/BuiltinAttributes.cpp68
-rw-r--r--mlir/test/IR/parser.mlir4
-rw-r--r--mlir/test/mlir-tblgen/openmp-clause-ops.td2
-rw-r--r--mlir/unittests/IR/AttributeTest.cpp37
10 files changed, 215 insertions, 136 deletions
diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index 69a5094..7461f88 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -465,6 +465,7 @@ MLIR_CAPI_EXPORTED double mlirDenseF64ArrayGetElement(MlirAttribute attr,
/// Checks whether the given attribute is a dense elements attribute.
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseElements(MlirAttribute attr);
+MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseStringElements(MlirAttribute attr);
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseIntElements(MlirAttribute attr);
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseFPElements(MlirAttribute attr);
@@ -546,11 +547,20 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrBFloat16Get(
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrFloat16Get(
MlirType shapedType, intptr_t numElements, const uint16_t *elements);
-/// Creates a dense elements attribute with the given shaped type from string
-/// elements.
-MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrStringGet(
+/// Creates a dense string elements attribute with the given shaped type from
+/// string elements.
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseStringElementsAttrStringGet(
MlirType shapedType, intptr_t numElements, MlirStringRef *strs);
+/// Checks whether the given dense string elements attribute contains a single
+/// replicated value (splat).
+MLIR_CAPI_EXPORTED bool mlirDenseStringElementsAttrIsSplat(MlirAttribute attr);
+
+/// Returns the single replicated value (splat) of the given dense string
+/// elements attribute.
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirDenseStringElementsAttrGetSplatValue(MlirAttribute attr);
+
/// Creates a dense elements attribute that has the same data as the given dense
/// elements attribute and a different shaped type. The new type must have the
/// same total number of elements.
@@ -584,7 +594,7 @@ mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr);
MLIR_CAPI_EXPORTED double
mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr);
MLIR_CAPI_EXPORTED MlirStringRef
-mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr);
+mlirDenseStringElementsAttrGetStringSplatValue(MlirAttribute attr);
/// Returns the pos-th value (flat contiguous indexing) of a specific type
/// contained by the given dense elements attribute.
@@ -613,7 +623,7 @@ MLIR_CAPI_EXPORTED float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr,
MLIR_CAPI_EXPORTED double
mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos);
MLIR_CAPI_EXPORTED MlirStringRef
-mlirDenseElementsAttrGetStringValue(MlirAttribute attr, intptr_t pos);
+mlirDenseStringElementsAttrGetStringValue(MlirAttribute attr, intptr_t pos);
/// Returns the raw data of the given dense elements attribute.
MLIR_CAPI_EXPORTED const void *
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index ee6a8f4..33de7d9 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -152,10 +152,6 @@ public:
/// Overload of the above 'get' method that is specialized for boolean values.
static DenseElementsAttr get(ShapedType type, ArrayRef<bool> values);
- /// Overload of the above 'get' method that is specialized for StringRef
- /// values.
- static DenseElementsAttr get(ShapedType type, ArrayRef<StringRef> values);
-
/// Constructs a dense integer elements attribute from an array of APInt
/// values. Each APInt value is expected to have the same bitwidth as the
/// element type of 'type'. 'type' must be a vector or tensor with static
@@ -232,9 +228,6 @@ public:
/// Accesses the Attribute value at this iterator position.
Attribute operator*() const;
- private:
- friend DenseElementsAttr;
-
/// Constructs a new iterator.
AttributeElementIterator(DenseElementsAttr attr, size_t index);
};
@@ -461,21 +454,6 @@ public:
ElementIterator<T>(rawData, splat, getNumElements()));
}
- /// Try to get the held element values as a range of StringRef.
- template <typename T>
- using StringRefValueTemplateCheckT =
- std::enable_if_t<std::is_same<T, StringRef>::value>;
- template <typename T, typename = StringRefValueTemplateCheckT<T>>
- FailureOr<iterator_range_impl<ElementIterator<StringRef>>>
- tryGetValues() const {
- auto stringRefs = getRawStringData();
- const char *ptr = reinterpret_cast<const char *>(stringRefs.data());
- bool splat = isSplat();
- return iterator_range_impl<ElementIterator<StringRef>>(
- getType(), ElementIterator<StringRef>(ptr, splat, 0),
- ElementIterator<StringRef>(ptr, splat, getNumElements()));
- }
-
/// Try to get the held element values as a range of Attributes.
template <typename T>
using AttributeValueTemplateCheckT =
@@ -578,9 +556,6 @@ public:
/// form the user might expect.
ArrayRef<char> getRawData() const;
- /// Return the raw StringRef data held by this attribute.
- ArrayRef<StringRef> getRawStringData() const;
-
/// Return the type of this ElementsAttr, guaranteed to be a vector or tensor
/// with static shape.
ShapedType getType() const;
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index dced379..52f377a 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -395,8 +395,7 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
//===----------------------------------------------------------------------===//
def Builtin_DenseStringElementsAttr : Builtin_Attr<
- "DenseStringElements", "dense_string_elements", [ElementsAttrInterface],
- "DenseElementsAttr"
+ "DenseStringElements", "dense_string_elements", [ElementsAttrInterface]
> {
let summary = "An Attribute containing a dense multi-dimensional array of "
"strings";
@@ -431,13 +430,101 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
}]>,
];
let extraClassDeclaration = [{
- using DenseElementsAttr::empty;
- using DenseElementsAttr::getNumElements;
- using DenseElementsAttr::getElementType;
- using DenseElementsAttr::getValues;
- using DenseElementsAttr::isSplat;
- using DenseElementsAttr::size;
- using DenseElementsAttr::value_begin;
+ /// Iterator for walking StringRef element values.
+ class StringRefElementIterator
+ : public detail::DenseElementIndexedIteratorImpl<StringRefElementIterator,
+ const StringRef> {
+ public:
+ const StringRef &operator*() const {
+ return reinterpret_cast<const StringRef *>(this->getData())[this->getDataIndex()];
+ }
+ StringRefElementIterator(const char *data, bool isSplat, size_t dataIndex)
+ : detail::DenseElementIndexedIteratorImpl<StringRefElementIterator,
+ const StringRef>(
+ data, isSplat, dataIndex) {}
+ };
+
+ /// Iterator for walking element values as Attribute (StringAttr).
+ class StringAttributeElementIterator
+ : public llvm::indexed_accessor_iterator<StringAttributeElementIterator,
+ const void *, Attribute,
+ Attribute, Attribute> {
+ public:
+ Attribute operator*() const;
+ StringAttributeElementIterator(const DenseStringElementsAttr *attr,
+ size_t index)
+ : llvm::indexed_accessor_iterator<StringAttributeElementIterator,
+ const void *, Attribute,
+ Attribute, Attribute>(
+ attr->getAsOpaquePointer(), index) {}
+ };
+
+ /// Return the type of this attribute (vector or tensor with static shape).
+ ShapedType getType() const;
+
+ /// Helper methods for ElementsAttr interface.
+ bool empty() const { return getNumElements() == 0; }
+ int64_t getNumElements() const { return getType().getNumElements(); }
+ Type getElementType() const { return getType().getElementType(); }
+ bool isSplat() const { return getRawStringData().size() == 1; }
+ int64_t size() const { return getNumElements(); }
+
+ /// Return the raw StringRef data held by this attribute.
+ ArrayRef<StringRef> getRawStringData() const;
+
+ /// Try to get the held element values as a range of StringRef.
+ template <typename T>
+ using StringRefValueTemplateCheckT =
+ std::enable_if_t<std::is_same<T, StringRef>::value>;
+ template <typename T, typename = StringRefValueTemplateCheckT<T>>
+ FailureOr<detail::ElementsAttrRange<StringRefElementIterator>>
+ tryGetValues() const {
+ auto stringRefs = getRawStringData();
+ const char *ptr = reinterpret_cast<const char *>(stringRefs.data());
+ bool splat = isSplat();
+ return detail::ElementsAttrRange<StringRefElementIterator>(
+ getType(), StringRefElementIterator(ptr, splat, 0),
+ StringRefElementIterator(ptr, splat, getNumElements()));
+ }
+
+ /// Try to get the held element values as a range of Attributes.
+ template <typename T>
+ using AttributeValueTemplateCheckT =
+ std::enable_if_t<std::is_same<T, Attribute>::value>;
+ template <typename T, typename = AttributeValueTemplateCheckT<T>>
+ FailureOr<detail::ElementsAttrRange<StringAttributeElementIterator>>
+ tryGetValues() const {
+ return detail::ElementsAttrRange<StringAttributeElementIterator>(
+ getType(), StringAttributeElementIterator(this, 0),
+ StringAttributeElementIterator(this, getNumElements()));
+ }
+
+ template <typename T>
+ auto getValues() const {
+ auto range = tryGetValues<T>();
+ assert(succeeded(range) && "element type cannot be iterated");
+ return std::move(*range);
+ }
+
+ template <typename T>
+ auto value_begin() const { return getValues<T>().begin(); }
+
+ template <typename T>
+ auto value_end() const { return getValues<T>().end(); }
+
+ /// Return the splat value. Asserts that the attribute is a splat.
+ template <typename T>
+ auto getSplatValue() const {
+ assert(isSplat() && "expected the attribute to be a splat");
+ return *value_begin<T>();
+ }
+
+ template <typename T>
+ auto try_value_begin() const {
+ auto range = tryGetValues<T>();
+ using iterator = decltype(range->begin());
+ return failed(range) ? FailureOr<iterator>(failure()) : range->begin();
+ }
/// The set of data types that can be iterated by this attribute.
using ContiguousIterableTypesT = std::tuple<StringRef>;
@@ -449,11 +536,6 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
auto try_value_begin_impl(OverloadToken<T>) const {
return try_value_begin<T>();
}
-
- protected:
- friend DenseElementsAttr;
-
- public:
}];
let genAccessors = 0;
let genStorageClass = 0;
@@ -931,9 +1013,7 @@ def Builtin_SparseElementsAttr : Builtin_Attr<
std::complex<int16_t>, std::complex<int32_t>, std::complex<int64_t>,
// Float types.
APFloat, float, double,
- std::complex<APFloat>, std::complex<float>, std::complex<double>,
- // String types.
- StringRef
+ std::complex<APFloat>, std::complex<float>, std::complex<double>
>;
using ElementsAttr::Trait<SparseElementsAttr>::getValues;
using ElementsAttr::Trait<SparseElementsAttr>::value_begin;
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index ba6cf55..634881f 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -565,8 +565,8 @@ def StringElementsAttr : ElementsAttrBase<
CPred<"::llvm::isa<::mlir::DenseStringElementsAttr>($_self)" >,
"string elements attribute"> {
- let storageType = [{ ::mlir::DenseElementsAttr }];
- let returnType = [{ ::mlir::DenseElementsAttr }];
+ let storageType = [{ ::mlir::DenseStringElementsAttr }];
+ let returnType = [{ ::mlir::DenseStringElementsAttr }];
let convertFromStorage = "$_self";
}
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index dc9744a..15c2e02 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -472,8 +472,8 @@ public:
ParseResult parse(bool allowHex);
/// Build a dense attribute instance with the parsed elements and the given
- /// shaped type.
- DenseElementsAttr getAttr(SMLoc loc, ShapedType type);
+ /// shaped type. Returns DenseElementsAttr or DenseStringElementsAttr.
+ Attribute getAttr(SMLoc loc, ShapedType type);
ArrayRef<int64_t> getShape() const { return shape; }
@@ -487,7 +487,7 @@ private:
std::vector<APFloat> &floatValues);
/// Build a Dense String attribute for the given type.
- DenseElementsAttr getStringAttr(SMLoc loc, ShapedType type, Type eltTy);
+ DenseStringElementsAttr getStringAttr(SMLoc loc, ShapedType type, Type eltTy);
/// Build a Dense attribute with hex data for the given type.
DenseElementsAttr getHexAttr(SMLoc loc, ShapedType type);
@@ -539,7 +539,7 @@ ParseResult TensorLiteralParser::parse(bool allowHex) {
/// Build a dense attribute instance with the parsed elements and the given
/// shaped type.
-DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) {
+Attribute TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) {
Type eltType = type.getElementType();
// Check to see if we parse the literal from a hex string.
@@ -679,8 +679,8 @@ TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy,
}
/// Build a Dense String attribute for the given type.
-DenseElementsAttr TensorLiteralParser::getStringAttr(SMLoc loc, ShapedType type,
- Type eltTy) {
+DenseStringElementsAttr
+TensorLiteralParser::getStringAttr(SMLoc loc, ShapedType type, Type eltTy) {
if (hexStorage.has_value()) {
auto stringValue = hexStorage->getStringValue();
return DenseStringElementsAttr::get(type, {stringValue});
@@ -1174,6 +1174,13 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
if (!type)
return nullptr;
+ // SparseElementsAttr only supports int/float element types.
+ if (!type.getElementType().isIntOrIndexOrFloat()) {
+ emitError(loc) << "sparse elements attribute does not support string "
+ "element type";
+ return nullptr;
+ }
+
// Construct the sparse elements attr using zero element indice/value
// attributes.
ShapedType indicesType =
@@ -1219,9 +1226,10 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
// Otherwise, set the shape to the one parsed by the literal parser.
indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
}
- auto indices = indiceParser.getAttr(indicesLoc, indicesType);
- if (!indices)
+ auto indicesAttr = indiceParser.getAttr(indicesLoc, indicesType);
+ if (!indicesAttr)
return nullptr;
+ auto indices = llvm::cast<DenseIntElementsAttr>(indicesAttr);
// If the values are a splat, set the shape explicitly based on the number of
// indices. The number of indices is encoded in the first dimension of the
@@ -1231,10 +1239,18 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
valuesParser.getShape().empty()
? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
: RankedTensorType::get(valuesParser.getShape(), valuesEltType);
- auto values = valuesParser.getAttr(valuesLoc, valuesType);
- if (!values)
+ auto valuesAttr = valuesParser.getAttr(valuesLoc, valuesType);
+ if (!valuesAttr)
return nullptr;
+ // SparseElementsAttr only supports DenseElementsAttr for values (not string).
+ auto values = llvm::dyn_cast<DenseElementsAttr>(valuesAttr);
+ if (!values) {
+ emitError(valuesLoc)
+ << "dense string elements not supported in sparse elements attribute";
+ return nullptr;
+ }
+
// Build the sparse elements attribute by the indices and values.
return getChecked<SparseElementsAttr>(loc, type, indices, values);
}
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 44a3dea..38d4e40 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -551,6 +551,10 @@ bool mlirAttributeIsADenseElements(MlirAttribute attr) {
return llvm::isa<DenseElementsAttr>(unwrap(attr));
}
+bool mlirAttributeIsADenseStringElements(MlirAttribute attr) {
+ return llvm::isa<DenseStringElementsAttr>(unwrap(attr));
+}
+
bool mlirAttributeIsADenseIntElements(MlirAttribute attr) {
return llvm::isa<DenseIntElementsAttr>(unwrap(attr));
}
@@ -720,16 +724,16 @@ MlirAttribute mlirDenseElementsAttrFloat16Get(MlirType shapedType,
return mlirDenseElementsAttrRawBufferGet(shapedType, bufferSize, buffer);
}
-MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
- intptr_t numElements,
- MlirStringRef *strs) {
+MlirAttribute mlirDenseStringElementsAttrStringGet(MlirType shapedType,
+ intptr_t numElements,
+ MlirStringRef *strs) {
SmallVector<StringRef, 8> values;
values.reserve(numElements);
for (intptr_t i = 0; i < numElements; ++i)
values.push_back(unwrap(strs[i]));
- return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
- values));
+ return wrap(DenseStringElementsAttr::get(
+ llvm::cast<ShapedType>(unwrap(shapedType)), values));
}
MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
@@ -745,10 +749,16 @@ MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) {
return llvm::cast<DenseElementsAttr>(unwrap(attr)).isSplat();
}
-
+bool mlirDenseStringElementsAttrIsSplat(MlirAttribute attr) {
+ return llvm::cast<DenseStringElementsAttr>(unwrap(attr)).isSplat();
+}
MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) {
- return wrap(
- llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<Attribute>());
+ return wrap(llvm::cast<DenseElementsAttr>(unwrap(attr))
+ .getSplatValue<mlir::Attribute>());
+}
+MlirAttribute mlirDenseStringElementsAttrGetSplatValue(MlirAttribute attr) {
+ return wrap(llvm::cast<DenseStringElementsAttr>(unwrap(attr))
+ .getSplatValue<mlir::Attribute>());
}
int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) {
return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<bool>();
@@ -777,9 +787,10 @@ float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr) {
double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) {
return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<double>();
}
-MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) {
- return wrap(
- llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<StringRef>());
+MlirStringRef
+mlirDenseStringElementsAttrGetStringSplatValue(MlirAttribute attr) {
+ return wrap(llvm::cast<DenseStringElementsAttr>(unwrap(attr))
+ .getSplatValue<llvm::StringRef>());
}
//===----------------------------------------------------------------------===//
@@ -822,10 +833,10 @@ float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) {
double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) {
return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<double>()[pos];
}
-MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr,
- intptr_t pos) {
- return wrap(
- llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<StringRef>()[pos]);
+MlirStringRef mlirDenseStringElementsAttrGetStringValue(MlirAttribute attr,
+ intptr_t pos) {
+ return wrap(llvm::cast<DenseStringElementsAttr>(unwrap(attr))
+ .getValues<StringRef>()[pos]);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index bbbc919..b459d32 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -596,16 +596,7 @@ DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
auto owner = llvm::cast<DenseElementsAttr>(getFromOpaquePointer(base));
- Type eltTy = owner.getElementType();
-
- // Handle strings specially.
- if (llvm::isa<DenseStringElementsAttr>(owner)) {
- ArrayRef<StringRef> vals = owner.getRawStringData();
- return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
- }
-
- // All other types should implement DenseElementTypeInterface.
- auto denseEltTy = llvm::cast<DenseElementType>(eltTy);
+ auto denseEltTy = llvm::cast<DenseElementType>(owner.getElementType());
ArrayRef<char> rawData = owner.getRawData();
// Storage is byte-aligned: align bit size up to next byte boundary.
size_t bitSize = denseEltTy.getDenseElementBitSize();
@@ -864,28 +855,13 @@ template class DenseArrayAttrImpl<double>;
/// Method for support type inquiry through isa, cast and dyn_cast.
bool DenseElementsAttr::classof(Attribute attr) {
- return llvm::isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(attr);
+ return llvm::isa<DenseIntOrFPElementsAttr>(attr);
}
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<Attribute> values) {
assert(hasSameNumElementsOrSplat(type, values));
- Type eltType = type.getElementType();
-
- // Handle strings specially.
- if (!llvm::isa<DenseElementType>(eltType)) {
- SmallVector<StringRef, 8> stringValues;
- stringValues.reserve(values.size());
- for (Attribute attr : values) {
- assert(llvm::isa<StringAttr>(attr) &&
- "expected string value for non-DenseElementType element");
- stringValues.push_back(llvm::cast<StringAttr>(attr).getValue());
- }
- return get(type, stringValues);
- }
-
- // All other types go through DenseElementTypeInterface.
- auto denseEltType = llvm::dyn_cast<DenseElementType>(eltType);
+ auto denseEltType = llvm::dyn_cast<DenseElementType>(type.getElementType());
assert(denseEltType &&
"attempted to get DenseElementsAttr with unsupported element type");
SmallVector<char> data;
@@ -906,12 +882,6 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
values.size()));
}
-DenseElementsAttr DenseElementsAttr::get(ShapedType type,
- ArrayRef<StringRef> values) {
- assert(!type.getElementType().isIntOrFloat());
- return DenseStringElementsAttr::get(type, values);
-}
-
/// Constructs a dense integer elements attribute from an array of APInt
/// values. Each APInt value is expected to have the same bitwidth as the
/// element type of 'type'.
@@ -1048,9 +1018,6 @@ bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
/// values are the same.
bool DenseElementsAttr::isSplat() const {
// Splat iff the data array has exactly one element.
- if (isa<DenseStringElementsAttr>(*this))
- return getRawStringData().size() == 1;
- // FP/Int case.
size_t storageSize = llvm::divideCeil(
getDenseElementBitWidth(getType().getElementType()), CHAR_BIT);
return getRawData().size() == storageSize;
@@ -1100,10 +1067,6 @@ ArrayRef<char> DenseElementsAttr::getRawData() const {
return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data;
}
-ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
- return static_cast<DenseStringElementsAttrStorage *>(impl)->data;
-}
-
/// Return a new DenseElementsAttr that has the same data as the current
/// attribute, but has been reshaped to 'newType'. The new type must have the
/// same total number of elements as well as element type.
@@ -1391,6 +1354,27 @@ bool DenseIntElementsAttr::classof(Attribute attr) {
}
//===----------------------------------------------------------------------===//
+// DenseStringElementsAttr
+//===----------------------------------------------------------------------===//
+
+ShapedType DenseStringElementsAttr::getType() const {
+ return static_cast<const DenseStringElementsAttrStorage *>(impl)->type;
+}
+
+ArrayRef<StringRef> DenseStringElementsAttr::getRawStringData() const {
+ return static_cast<const DenseStringElementsAttrStorage *>(impl)->data;
+}
+
+Attribute
+DenseStringElementsAttr::StringAttributeElementIterator::operator*() const {
+ auto attr = llvm::cast<DenseStringElementsAttr>(
+ Attribute::getFromOpaquePointer(this->base));
+ auto data = attr.getRawStringData();
+ return StringAttr::get(attr.isSplat() ? data.front() : data[this->index],
+ attr.getElementType());
+}
+
+//===----------------------------------------------------------------------===//
// DenseResourceElementsAttr
//===----------------------------------------------------------------------===//
@@ -1557,10 +1541,6 @@ Attribute SparseElementsAttr::getZeroAttr() const {
ArrayRef<Attribute>{zero, zero});
}
- // Handle string type.
- if (llvm::isa<DenseStringElementsAttr>(getValues()))
- return StringAttr::get("", eltType);
-
// Otherwise, this is an integer.
return IntegerAttr::get(eltType, 0);
}
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index 3bb6e38..c4a415e 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -797,10 +797,6 @@ func.func @sparsetensorattr() -> () {
// CHECK: "foof321"() {bar = sparse<> : tensor<f32>} : () -> ()
"foof321"(){bar = sparse<> : tensor<f32>} : () -> ()
-// CHECK: "foostr"() {bar = sparse<0, "foo"> : tensor<1x1x1x!unknown<>>} : () -> ()
- "foostr"(){bar = sparse<0, "foo"> : tensor<1x1x1x!unknown<>>} : () -> ()
-// CHECK: "foostr"() {bar = sparse<{{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}"a", "b", "c"]> : tensor<2x2x2x!unknown<>>} : () -> ()
- "foostr"(){bar = sparse<[[1, 1, 0], [0, 1, 0], [0, 0, 1]], ["a", "b", "c"]> : tensor<2x2x2x!unknown<>>} : () -> ()
return
}
diff --git a/mlir/test/mlir-tblgen/openmp-clause-ops.td b/mlir/test/mlir-tblgen/openmp-clause-ops.td
index 3e5896a..c502b21 100644
--- a/mlir/test/mlir-tblgen/openmp-clause-ops.td
+++ b/mlir/test/mlir-tblgen/openmp-clause-ops.td
@@ -59,7 +59,7 @@ def OpenMP_MyFirstClause : OpenMP_Clause<
// CHECK-NEXT: ::mlir::IntegerAttr complexOptIntAttr;
// CHECK-NEXT: ::mlir::ElementsAttr elementsAttr;
-// CHECK-NEXT: ::mlir::DenseElementsAttr stringElementsAttr;
+// CHECK-NEXT: ::mlir::DenseStringElementsAttr stringElementsAttr;
// CHECK-NEXT: }
def OpenMP_MySecondClause : OpenMP_Clause<
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index 404aa8c..f72aeba 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -38,6 +38,21 @@ static void testSplat(Type eltType, const EltTy &splatElt) {
EXPECT_TRUE(newValue == splatElt);
}
+template <>
+void testSplat<StringRef>(Type eltType, const StringRef &splatElt) {
+ RankedTensorType shape = RankedTensorType::get({2, 1}, eltType);
+
+ DenseStringElementsAttr splat = DenseStringElementsAttr::get(shape, splatElt);
+ EXPECT_TRUE(splat.isSplat());
+
+ auto detectedSplat =
+ DenseStringElementsAttr::get(shape, llvm::ArrayRef({splatElt, splatElt}));
+ EXPECT_EQ(detectedSplat, splat);
+
+ for (auto newValue : detectedSplat.getValues<StringRef>())
+ EXPECT_TRUE(newValue == splatElt);
+}
+
namespace {
TEST(DenseSplatTest, BoolSplat) {
MLIRContext context;
@@ -184,8 +199,16 @@ TEST(DenseSplatTest, StringAttrSplat) {
context.allowUnregisteredDialects();
Type stringType =
OpaqueType::get(StringAttr::get(&context, "test"), "string");
+ RankedTensorType shape = RankedTensorType::get({2, 1}, stringType);
Attribute stringAttr = StringAttr::get("test-string", stringType);
- testSplat(stringType, stringAttr);
+ StringRef value = llvm::cast<StringAttr>(stringAttr).getValue();
+ DenseStringElementsAttr splat = DenseStringElementsAttr::get(shape, value);
+ EXPECT_TRUE(splat.isSplat());
+ auto detectedSplat =
+ DenseStringElementsAttr::get(shape, llvm::ArrayRef({value, value}));
+ EXPECT_EQ(detectedSplat, splat);
+ for (auto newValue : detectedSplat.getValues<StringRef>())
+ EXPECT_TRUE(newValue == value);
}
TEST(DenseComplexTest, ComplexFloatSplat) {
@@ -396,11 +419,9 @@ TEST(SparseElementsAttrTest, GetZero) {
IntegerType intTy = IntegerType::get(&context, 32);
FloatType floatTy = Float32Type::get(&context);
- Type stringTy = OpaqueType::get(StringAttr::get(&context, "test"), "string");
ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy);
ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy);
- ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy);
auto indicesType =
RankedTensorType::get({1, 2}, IntegerType::get(&context, 64));
@@ -413,13 +434,8 @@ TEST(SparseElementsAttrTest, GetZero) {
RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy);
auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f});
- RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy);
- auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo")});
-
auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue);
auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue);
- auto sparseString =
- SparseElementsAttr::get(tensorString, indices, stringValue);
// Only index (0, 0) contains an element, others are supposed to return
// the zero/empty value.
@@ -432,11 +448,6 @@ TEST(SparseElementsAttrTest, GetZero) {
cast<FloatAttr>(sparseFloat.getValues<Attribute>()[{1, 1}]);
EXPECT_EQ(zeroFloatValue.getValueAsDouble(), 0.0f);
EXPECT_TRUE(zeroFloatValue.getType() == floatTy);
-
- auto zeroStringValue =
- cast<StringAttr>(sparseString.getValues<Attribute>()[{1, 1}]);
- EXPECT_TRUE(zeroStringValue.empty());
- EXPECT_TRUE(zeroStringValue.getType() == stringTy);
}
//===----------------------------------------------------------------------===//