diff options
Diffstat (limited to 'llvm/lib/CodeGen/CodeGenPrepare.cpp')
-rw-r--r-- | llvm/lib/CodeGen/CodeGenPrepare.cpp | 154 |
1 files changed, 89 insertions, 65 deletions
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp index 3272f36..9a4ed2f 100644 --- a/llvm/lib/CodeGen/CodeGenPrepare.cpp +++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp @@ -5314,88 +5314,112 @@ bool CodeGenPrepare::optimizeMemoryInst(Instruction *MemoryInst, Value *Addr, /// zero index. bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst, Value *Ptr) { - const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr); - if (!GEP || !GEP->hasIndices()) + // FIXME: Support scalable vectors. + if (isa<ScalableVectorType>(Ptr->getType())) return false; - // If the GEP and the gather/scatter aren't in the same BB, don't optimize. - // FIXME: We should support this by sinking the GEP. - if (MemoryInst->getParent() != GEP->getParent()) - return false; - - SmallVector<Value *, 2> Ops(GEP->op_begin(), GEP->op_end()); + Value *NewAddr; - bool RewriteGEP = false; + if (const auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) { + // Don't optimize GEPs that don't have indices. + if (!GEP->hasIndices()) + return false; - if (Ops[0]->getType()->isVectorTy()) { - Ops[0] = const_cast<Value *>(getSplatValue(Ops[0])); - if (!Ops[0]) + // If the GEP and the gather/scatter aren't in the same BB, don't optimize. + // FIXME: We should support this by sinking the GEP. + if (MemoryInst->getParent() != GEP->getParent()) return false; - RewriteGEP = true; - } - unsigned FinalIndex = Ops.size() - 1; + SmallVector<Value *, 2> Ops(GEP->op_begin(), GEP->op_end()); - // Ensure all but the last index is 0. - // FIXME: This isn't strictly required. All that's required is that they are - // all scalars or splats. - for (unsigned i = 1; i < FinalIndex; ++i) { - auto *C = dyn_cast<Constant>(Ops[i]); - if (!C) - return false; - if (isa<VectorType>(C->getType())) - C = C->getSplatValue(); - auto *CI = dyn_cast_or_null<ConstantInt>(C); - if (!CI || !CI->isZero()) - return false; - // Scalarize the index if needed. - Ops[i] = CI; - } - - // Try to scalarize the final index. - if (Ops[FinalIndex]->getType()->isVectorTy()) { - if (Value *V = const_cast<Value *>(getSplatValue(Ops[FinalIndex]))) { - auto *C = dyn_cast<ConstantInt>(V); - // Don't scalarize all zeros vector. - if (!C || !C->isZero()) { - Ops[FinalIndex] = V; - RewriteGEP = true; - } + bool RewriteGEP = false; + + if (Ops[0]->getType()->isVectorTy()) { + Ops[0] = getSplatValue(Ops[0]); + if (!Ops[0]) + return false; + RewriteGEP = true; } - } - // If we made any changes or the we have extra operands, we need to generate - // new instructions. - if (!RewriteGEP && Ops.size() == 2) - return false; + unsigned FinalIndex = Ops.size() - 1; - unsigned NumElts = cast<FixedVectorType>(Ptr->getType())->getNumElements(); + // Ensure all but the last index is 0. + // FIXME: This isn't strictly required. All that's required is that they are + // all scalars or splats. + for (unsigned i = 1; i < FinalIndex; ++i) { + auto *C = dyn_cast<Constant>(Ops[i]); + if (!C) + return false; + if (isa<VectorType>(C->getType())) + C = C->getSplatValue(); + auto *CI = dyn_cast_or_null<ConstantInt>(C); + if (!CI || !CI->isZero()) + return false; + // Scalarize the index if needed. + Ops[i] = CI; + } + + // Try to scalarize the final index. + if (Ops[FinalIndex]->getType()->isVectorTy()) { + if (Value *V = getSplatValue(Ops[FinalIndex])) { + auto *C = dyn_cast<ConstantInt>(V); + // Don't scalarize all zeros vector. + if (!C || !C->isZero()) { + Ops[FinalIndex] = V; + RewriteGEP = true; + } + } + } - IRBuilder<> Builder(MemoryInst); + // If we made any changes or the we have extra operands, we need to generate + // new instructions. + if (!RewriteGEP && Ops.size() == 2) + return false; - Type *ScalarIndexTy = DL->getIndexType(Ops[0]->getType()->getScalarType()); + unsigned NumElts = cast<FixedVectorType>(Ptr->getType())->getNumElements(); - Value *NewAddr; + IRBuilder<> Builder(MemoryInst); - // If the final index isn't a vector, emit a scalar GEP containing all ops - // and a vector GEP with all zeroes final index. - if (!Ops[FinalIndex]->getType()->isVectorTy()) { - NewAddr = Builder.CreateGEP(Ops[0], makeArrayRef(Ops).drop_front()); - auto *IndexTy = FixedVectorType::get(ScalarIndexTy, NumElts); - NewAddr = Builder.CreateGEP(NewAddr, Constant::getNullValue(IndexTy)); - } else { - Value *Base = Ops[0]; - Value *Index = Ops[FinalIndex]; + Type *ScalarIndexTy = DL->getIndexType(Ops[0]->getType()->getScalarType()); - // Create a scalar GEP if there are more than 2 operands. - if (Ops.size() != 2) { - // Replace the last index with 0. - Ops[FinalIndex] = Constant::getNullValue(ScalarIndexTy); - Base = Builder.CreateGEP(Base, makeArrayRef(Ops).drop_front()); + // If the final index isn't a vector, emit a scalar GEP containing all ops + // and a vector GEP with all zeroes final index. + if (!Ops[FinalIndex]->getType()->isVectorTy()) { + NewAddr = Builder.CreateGEP(Ops[0], makeArrayRef(Ops).drop_front()); + auto *IndexTy = FixedVectorType::get(ScalarIndexTy, NumElts); + NewAddr = Builder.CreateGEP(NewAddr, Constant::getNullValue(IndexTy)); + } else { + Value *Base = Ops[0]; + Value *Index = Ops[FinalIndex]; + + // Create a scalar GEP if there are more than 2 operands. + if (Ops.size() != 2) { + // Replace the last index with 0. + Ops[FinalIndex] = Constant::getNullValue(ScalarIndexTy); + Base = Builder.CreateGEP(Base, makeArrayRef(Ops).drop_front()); + } + + // Now create the GEP with scalar pointer and vector index. + NewAddr = Builder.CreateGEP(Base, Index); } + } else if (!isa<Constant>(Ptr)) { + // Not a GEP, maybe its a splat and we can create a GEP to enable + // SelectionDAGBuilder to use it as a uniform base. + Value *V = getSplatValue(Ptr); + if (!V) + return false; + + unsigned NumElts = cast<FixedVectorType>(Ptr->getType())->getNumElements(); + + IRBuilder<> Builder(MemoryInst); - // Now create the GEP with scalar pointer and vector index. - NewAddr = Builder.CreateGEP(Base, Index); + // Emit a vector GEP with a scalar pointer and all 0s vector index. + Type *ScalarIndexTy = DL->getIndexType(V->getType()->getScalarType()); + auto *IndexTy = FixedVectorType::get(ScalarIndexTy, NumElts); + NewAddr = Builder.CreateGEP(V, Constant::getNullValue(IndexTy)); + } else { + // Constant, SelectionDAGBuilder knows to check if its a splat. + return false; } MemoryInst->replaceUsesOfWith(Ptr, NewAddr); |