aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Target/LLVMIR/ModuleImport.cpp
diff options
context:
space:
mode:
authorAndrzej WarzyƄski <andrzej.warzynski@arm.com>2025-03-06 09:04:30 +0000
committerGitHub <noreply@github.com>2025-03-06 09:04:30 +0000
commit620c38371de46d45e9095936c823e3e40a2c5f64 (patch)
tree1e27d600022ff5e809352cc0bb294de3d7bd2490 /mlir/lib/Target/LLVMIR/ModuleImport.cpp
parentd61d2197390161db86b48d044970f48132139ccb (diff)
downloadllvm-620c38371de46d45e9095936c823e3e40a2c5f64.zip
llvm-620c38371de46d45e9095936c823e3e40a2c5f64.tar.gz
llvm-620c38371de46d45e9095936c823e3e40a2c5f64.tar.bz2
[mlir][nfc] De-duplicate tests from `Type::isIntOrFloat` (#129710)
This PR makes sure that we always use `Type::isIntOrFloat` rather than re-implementing this condition inline. Also, it removes `isScalarType` that effectively re-implemented this method.
Diffstat (limited to 'mlir/lib/Target/LLVMIR/ModuleImport.cpp')
-rw-r--r--mlir/lib/Target/LLVMIR/ModuleImport.cpp11
1 files changed, 3 insertions, 8 deletions
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 2d3c0ef..823e5f0 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -759,11 +759,6 @@ void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,
iface->setAttr(iface.getFastmathAttrName(), attr);
}
-/// Returns if `type` is a scalar integer or floating-point type.
-static bool isScalarType(Type type) {
- return isa<IntegerType, FloatType>(type);
-}
-
/// Returns `type` if it is a builtin integer or floating-point vector type that
/// can be used to create an attribute or nullptr otherwise. If provided,
/// `arrayShape` is added to the shape of the vector to create an attribute that
@@ -781,7 +776,7 @@ static Type getVectorTypeForAttr(Type type, ArrayRef<int64_t> arrayShape = {}) {
// An LLVM dialect vector can only contain scalars.
Type elementType = LLVM::getVectorElementType(type);
- if (!isScalarType(elementType))
+ if (!elementType.isIntOrFloat())
return {};
SmallVector<int64_t> shape(arrayShape);
@@ -794,7 +789,7 @@ Type ModuleImport::getBuiltinTypeForAttr(Type type) {
return {};
// Return builtin integer and floating-point types as is.
- if (isScalarType(type))
+ if (type.isIntOrFloat())
return type;
// Return builtin vectors of integer and floating-point types as is.
@@ -808,7 +803,7 @@ Type ModuleImport::getBuiltinTypeForAttr(Type type) {
arrayShape.push_back(arrayType.getNumElements());
type = arrayType.getElementType();
}
- if (isScalarType(type))
+ if (type.isIntOrFloat())
return RankedTensorType::get(arrayShape, type);
return getVectorTypeForAttr(type, arrayShape);
}