diff options
Diffstat (limited to 'mlir/lib/Target/LLVMIR/ModuleImport.cpp')
-rw-r--r-- | mlir/lib/Target/LLVMIR/ModuleImport.cpp | 11 |
1 files changed, 3 insertions, 8 deletions
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 2d3c0ef..823e5f0 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -759,11 +759,6 @@ void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst, iface->setAttr(iface.getFastmathAttrName(), attr); } -/// Returns if `type` is a scalar integer or floating-point type. -static bool isScalarType(Type type) { - return isa<IntegerType, FloatType>(type); -} - /// Returns `type` if it is a builtin integer or floating-point vector type that /// can be used to create an attribute or nullptr otherwise. If provided, /// `arrayShape` is added to the shape of the vector to create an attribute that @@ -781,7 +776,7 @@ static Type getVectorTypeForAttr(Type type, ArrayRef<int64_t> arrayShape = {}) { // An LLVM dialect vector can only contain scalars. Type elementType = LLVM::getVectorElementType(type); - if (!isScalarType(elementType)) + if (!elementType.isIntOrFloat()) return {}; SmallVector<int64_t> shape(arrayShape); @@ -794,7 +789,7 @@ Type ModuleImport::getBuiltinTypeForAttr(Type type) { return {}; // Return builtin integer and floating-point types as is. - if (isScalarType(type)) + if (type.isIntOrFloat()) return type; // Return builtin vectors of integer and floating-point types as is. @@ -808,7 +803,7 @@ Type ModuleImport::getBuiltinTypeForAttr(Type type) { arrayShape.push_back(arrayType.getNumElements()); type = arrayType.getElementType(); } - if (isScalarType(type)) + if (type.isIntOrFloat()) return RankedTensorType::get(arrayShape, type); return getVectorTypeForAttr(type, arrayShape); } |