diff options
author | Andrzej WarzyĆski <andrzej.warzynski@arm.com> | 2025-03-06 09:04:30 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-03-06 09:04:30 +0000 |
commit | 620c38371de46d45e9095936c823e3e40a2c5f64 (patch) | |
tree | 1e27d600022ff5e809352cc0bb294de3d7bd2490 /mlir/lib/Target/LLVMIR/ModuleImport.cpp | |
parent | d61d2197390161db86b48d044970f48132139ccb (diff) | |
download | llvm-620c38371de46d45e9095936c823e3e40a2c5f64.zip llvm-620c38371de46d45e9095936c823e3e40a2c5f64.tar.gz llvm-620c38371de46d45e9095936c823e3e40a2c5f64.tar.bz2 |
[mlir][nfc] De-duplicate tests from `Type::isIntOrFloat` (#129710)
This PR makes sure that we always use `Type::isIntOrFloat` rather than
re-implementing this condition inline. Also, it removes `isScalarType`
that effectively re-implemented this method.
Diffstat (limited to 'mlir/lib/Target/LLVMIR/ModuleImport.cpp')
-rw-r--r-- | mlir/lib/Target/LLVMIR/ModuleImport.cpp | 11 |
1 files changed, 3 insertions, 8 deletions
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 2d3c0ef..823e5f0 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -759,11 +759,6 @@ void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst, iface->setAttr(iface.getFastmathAttrName(), attr); } -/// Returns if `type` is a scalar integer or floating-point type. -static bool isScalarType(Type type) { - return isa<IntegerType, FloatType>(type); -} - /// Returns `type` if it is a builtin integer or floating-point vector type that /// can be used to create an attribute or nullptr otherwise. If provided, /// `arrayShape` is added to the shape of the vector to create an attribute that @@ -781,7 +776,7 @@ static Type getVectorTypeForAttr(Type type, ArrayRef<int64_t> arrayShape = {}) { // An LLVM dialect vector can only contain scalars. Type elementType = LLVM::getVectorElementType(type); - if (!isScalarType(elementType)) + if (!elementType.isIntOrFloat()) return {}; SmallVector<int64_t> shape(arrayShape); @@ -794,7 +789,7 @@ Type ModuleImport::getBuiltinTypeForAttr(Type type) { return {}; // Return builtin integer and floating-point types as is. - if (isScalarType(type)) + if (type.isIntOrFloat()) return type; // Return builtin vectors of integer and floating-point types as is. @@ -808,7 +803,7 @@ Type ModuleImport::getBuiltinTypeForAttr(Type type) { arrayShape.push_back(arrayType.getNumElements()); type = arrayType.getElementType(); } - if (isScalarType(type)) + if (type.isIntOrFloat()) return RankedTensorType::get(arrayShape, type); return getVectorTypeForAttr(type, arrayShape); } |