aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Target/LLVMIR/ModuleImport.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Target/LLVMIR/ModuleImport.cpp')
-rw-r--r--mlir/lib/Target/LLVMIR/ModuleImport.cpp11
1 files changed, 3 insertions, 8 deletions
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 2d3c0ef..823e5f0 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -759,11 +759,6 @@ void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,
iface->setAttr(iface.getFastmathAttrName(), attr);
}
-/// Returns if `type` is a scalar integer or floating-point type.
-static bool isScalarType(Type type) {
- return isa<IntegerType, FloatType>(type);
-}
-
/// Returns `type` if it is a builtin integer or floating-point vector type that
/// can be used to create an attribute or nullptr otherwise. If provided,
/// `arrayShape` is added to the shape of the vector to create an attribute that
@@ -781,7 +776,7 @@ static Type getVectorTypeForAttr(Type type, ArrayRef<int64_t> arrayShape = {}) {
// An LLVM dialect vector can only contain scalars.
Type elementType = LLVM::getVectorElementType(type);
- if (!isScalarType(elementType))
+ if (!elementType.isIntOrFloat())
return {};
SmallVector<int64_t> shape(arrayShape);
@@ -794,7 +789,7 @@ Type ModuleImport::getBuiltinTypeForAttr(Type type) {
return {};
// Return builtin integer and floating-point types as is.
- if (isScalarType(type))
+ if (type.isIntOrFloat())
return type;
// Return builtin vectors of integer and floating-point types as is.
@@ -808,7 +803,7 @@ Type ModuleImport::getBuiltinTypeForAttr(Type type) {
arrayShape.push_back(arrayType.getNumElements());
type = arrayType.getElementType();
}
- if (isScalarType(type))
+ if (type.isIntOrFloat())
return RankedTensorType::get(arrayShape, type);
return getVectorTypeForAttr(type, arrayShape);
}