aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/IR/Constants.cpp
diff options
context:
space:
mode:
authorPaul Walker <paul.walker@arm.com>2024-02-22 14:07:16 +0000
committerGitHub <noreply@github.com>2024-02-22 14:07:16 +0000
commitcbb24e139d0753d755d17fbe6bfac48ab44d0721 (patch)
tree3e4f75895dd6b9520ec68b53d0591d51203f454e /llvm/lib/IR/Constants.cpp
parent5b8e5604c297aa8fd09bf641d12d0a663e0ea801 (diff)
downloadllvm-cbb24e139d0753d755d17fbe6bfac48ab44d0721.zip
llvm-cbb24e139d0753d755d17fbe6bfac48ab44d0721.tar.gz
llvm-cbb24e139d0753d755d17fbe6bfac48ab44d0721.tar.bz2
[LLVM][IR] Add native vector support to ConstantInt & ConstantFP. (#74502)
NOTE: For brevity the following talks about ConstantInt but everything extends to cover ConstantFP as well. Whilst ConstantInt::get() supports the creation of vectors whereby each lane has the same value, it achieves this via other constants: * ConstantVector for fixed-length vectors * ConstantExprs for scalable vectors However, ConstantExprs are being deprecated and ConstantVector is not space efficient for larger vector types. By extending ConstantInt we can represent vector splats by only storing the underlying scalar value. More specifically: * ConstantInt gains an ElementCount variant of get(). * LLVMContext is extended to map <EC,APInt>->ConstantInt. * BitcodeReader/Writer support is extended to allow vector types. Whilst this patch adds the base support, more work is required before it's production ready. For example, there's likely to be many places where isa<ConstantInt> assumes a scalar type. Accordingly the default behaviour of ConstantInt::get() remains unchanged but a set of flags are added to allow wider testing and thus help with the migration: --use-constant-int-for-fixed-length-splat --use-constant-fp-for-fixed-length-splat --use-constant-int-for-scalable-splat --use-constant-fp-for-scalable-splat NOTE: No change is required to the bitcode format because types and values are handled separately. NOTE: For similar reasons as above, code generation doesn't work out-the-box.
Diffstat (limited to 'llvm/lib/IR/Constants.cpp')
-rw-r--r--llvm/lib/IR/Constants.cpp94
1 files changed, 89 insertions, 5 deletions
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index a38b912..e6b92aa 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -35,6 +35,20 @@
using namespace llvm;
using namespace PatternMatch;
+// As set of temporary options to help migrate how splats are represented.
+static cl::opt<bool> UseConstantIntForFixedLengthSplat(
+ "use-constant-int-for-fixed-length-splat", cl::init(false), cl::Hidden,
+ cl::desc("Use ConstantInt's native fixed-length vector splat support."));
+static cl::opt<bool> UseConstantFPForFixedLengthSplat(
+ "use-constant-fp-for-fixed-length-splat", cl::init(false), cl::Hidden,
+ cl::desc("Use ConstantFP's native fixed-length vector splat support."));
+static cl::opt<bool> UseConstantIntForScalableSplat(
+ "use-constant-int-for-scalable-splat", cl::init(false), cl::Hidden,
+ cl::desc("Use ConstantInt's native scalable vector splat support."));
+static cl::opt<bool> UseConstantFPForScalableSplat(
+ "use-constant-fp-for-scalable-splat", cl::init(false), cl::Hidden,
+ cl::desc("Use ConstantFP's native scalable vector splat support."));
+
//===----------------------------------------------------------------------===//
// Constant Class
//===----------------------------------------------------------------------===//
@@ -825,9 +839,11 @@ bool Constant::isManifestConstant() const {
// ConstantInt
//===----------------------------------------------------------------------===//
-ConstantInt::ConstantInt(IntegerType *Ty, const APInt &V)
+ConstantInt::ConstantInt(Type *Ty, const APInt &V)
: ConstantData(Ty, ConstantIntVal), Val(V) {
- assert(V.getBitWidth() == Ty->getBitWidth() && "Invalid constant for type");
+ assert(V.getBitWidth() ==
+ cast<IntegerType>(Ty->getScalarType())->getBitWidth() &&
+ "Invalid constant for type");
}
ConstantInt *ConstantInt::getTrue(LLVMContext &Context) {
@@ -885,6 +901,26 @@ ConstantInt *ConstantInt::get(LLVMContext &Context, const APInt &V) {
return Slot.get();
}
+// Get a ConstantInt vector with each lane set to the same APInt.
+ConstantInt *ConstantInt::get(LLVMContext &Context, ElementCount EC,
+ const APInt &V) {
+ // Get an existing value or the insertion position.
+ std::unique_ptr<ConstantInt> &Slot =
+ Context.pImpl->IntSplatConstants[std::make_pair(EC, V)];
+ if (!Slot) {
+ IntegerType *ITy = IntegerType::get(Context, V.getBitWidth());
+ VectorType *VTy = VectorType::get(ITy, EC);
+ Slot.reset(new ConstantInt(VTy, V));
+ }
+
+#ifndef NDEBUG
+ IntegerType *ITy = IntegerType::get(Context, V.getBitWidth());
+ VectorType *VTy = VectorType::get(ITy, EC);
+ assert(Slot->getType() == VTy);
+#endif
+ return Slot.get();
+}
+
Constant *ConstantInt::get(Type *Ty, uint64_t V, bool isSigned) {
Constant *C = get(cast<IntegerType>(Ty->getScalarType()), V, isSigned);
@@ -1024,6 +1060,26 @@ ConstantFP* ConstantFP::get(LLVMContext &Context, const APFloat& V) {
return Slot.get();
}
+// Get a ConstantFP vector with each lane set to the same APFloat.
+ConstantFP *ConstantFP::get(LLVMContext &Context, ElementCount EC,
+ const APFloat &V) {
+ // Get an existing value or the insertion position.
+ std::unique_ptr<ConstantFP> &Slot =
+ Context.pImpl->FPSplatConstants[std::make_pair(EC, V)];
+ if (!Slot) {
+ Type *EltTy = Type::getFloatingPointTy(Context, V.getSemantics());
+ VectorType *VTy = VectorType::get(EltTy, EC);
+ Slot.reset(new ConstantFP(VTy, V));
+ }
+
+#ifndef NDEBUG
+ Type *EltTy = Type::getFloatingPointTy(Context, V.getSemantics());
+ VectorType *VTy = VectorType::get(EltTy, EC);
+ assert(Slot->getType() == VTy);
+#endif
+ return Slot.get();
+}
+
Constant *ConstantFP::getInfinity(Type *Ty, bool Negative) {
const fltSemantics &Semantics = Ty->getScalarType()->getFltSemantics();
Constant *C = get(Ty->getContext(), APFloat::getInf(Semantics, Negative));
@@ -1036,7 +1092,7 @@ Constant *ConstantFP::getInfinity(Type *Ty, bool Negative) {
ConstantFP::ConstantFP(Type *Ty, const APFloat &V)
: ConstantData(Ty, ConstantFPVal), Val(V) {
- assert(&V.getSemantics() == &Ty->getFltSemantics() &&
+ assert(&V.getSemantics() == &Ty->getScalarType()->getFltSemantics() &&
"FP type Mismatch");
}
@@ -1356,11 +1412,13 @@ Constant *ConstantVector::getImpl(ArrayRef<Constant*> V) {
bool isZero = C->isNullValue();
bool isUndef = isa<UndefValue>(C);
bool isPoison = isa<PoisonValue>(C);
+ bool isSplatFP = UseConstantFPForFixedLengthSplat && isa<ConstantFP>(C);
+ bool isSplatInt = UseConstantIntForFixedLengthSplat && isa<ConstantInt>(C);
- if (isZero || isUndef) {
+ if (isZero || isUndef || isSplatFP || isSplatInt) {
for (unsigned i = 1, e = V.size(); i != e; ++i)
if (V[i] != C) {
- isZero = isUndef = isPoison = false;
+ isZero = isUndef = isPoison = isSplatFP = isSplatInt = false;
break;
}
}
@@ -1371,6 +1429,12 @@ Constant *ConstantVector::getImpl(ArrayRef<Constant*> V) {
return PoisonValue::get(T);
if (isUndef)
return UndefValue::get(T);
+ if (isSplatFP)
+ return ConstantFP::get(C->getContext(), T->getElementCount(),
+ cast<ConstantFP>(C)->getValue());
+ if (isSplatInt)
+ return ConstantInt::get(C->getContext(), T->getElementCount(),
+ cast<ConstantInt>(C)->getValue());
// Check to see if all of the elements are ConstantFP or ConstantInt and if
// the element type is compatible with ConstantDataVector. If so, use it.
@@ -1384,6 +1448,16 @@ Constant *ConstantVector::getImpl(ArrayRef<Constant*> V) {
Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) {
if (!EC.isScalable()) {
+ // Maintain special handling of zero.
+ if (!V->isNullValue()) {
+ if (UseConstantIntForFixedLengthSplat && isa<ConstantInt>(V))
+ return ConstantInt::get(V->getContext(), EC,
+ cast<ConstantInt>(V)->getValue());
+ if (UseConstantFPForFixedLengthSplat && isa<ConstantFP>(V))
+ return ConstantFP::get(V->getContext(), EC,
+ cast<ConstantFP>(V)->getValue());
+ }
+
// If this splat is compatible with ConstantDataVector, use it instead of
// ConstantVector.
if ((isa<ConstantFP>(V) || isa<ConstantInt>(V)) &&
@@ -1394,6 +1468,16 @@ Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) {
return get(Elts);
}
+ // Maintain special handling of zero.
+ if (!V->isNullValue()) {
+ if (UseConstantIntForScalableSplat && isa<ConstantInt>(V))
+ return ConstantInt::get(V->getContext(), EC,
+ cast<ConstantInt>(V)->getValue());
+ if (UseConstantFPForScalableSplat && isa<ConstantFP>(V))
+ return ConstantFP::get(V->getContext(), EC,
+ cast<ConstantFP>(V)->getValue());
+ }
+
Type *VTy = VectorType::get(V->getType(), EC);
if (V->isNullValue())