aboutsummaryrefslogtreecommitdiff
path: root/clang/lib
diff options
context:
space:
mode:
authorAmr Hesham <amr96@programmer.net>2025-05-12 20:22:02 +0200
committerGitHub <noreply@github.com>2025-05-12 20:22:02 +0200
commita6c4ca8e934474d3ed76718788fb086c28a10863 (patch)
tree605b454e94f47efe9f01593bedf75a69f5afb451 /clang/lib
parent377a0476ab34b8c7274562aefe6d3a1614e477a4 (diff)
downloadllvm-a6c4ca8e934474d3ed76718788fb086c28a10863.zip
llvm-a6c4ca8e934474d3ed76718788fb086c28a10863.tar.gz
llvm-a6c4ca8e934474d3ed76718788fb086c28a10863.tar.bz2
[CIR] Upstream insert op for VectorType (#139146)
This change adds an insert op for VectorType Issue https://github.com/llvm/llvm-project/issues/136487
Diffstat (limited to 'clang/lib')
-rw-r--r--clang/lib/CIR/CodeGen/CIRGenExpr.cpp40
-rw-r--r--clang/lib/CIR/CodeGen/CIRGenValue.h27
-rw-r--r--clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp11
-rw-r--r--clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h10
4 files changed, 78 insertions, 10 deletions
diff --git a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
index 711a652..0386961 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
@@ -205,6 +205,17 @@ Address CIRGenFunction::emitPointerWithAlignment(const Expr *expr,
void CIRGenFunction::emitStoreThroughLValue(RValue src, LValue dst,
bool isInit) {
if (!dst.isSimple()) {
+ if (dst.isVectorElt()) {
+ // Read/modify/write the vector, inserting the new element
+ const mlir::Location loc = dst.getVectorPointer().getLoc();
+ const mlir::Value vector =
+ builder.createLoad(loc, dst.getVectorAddress().getPointer());
+ const mlir::Value newVector = builder.create<cir::VecInsertOp>(
+ loc, vector, src.getScalarVal(), dst.getVectorIdx());
+ builder.createStore(loc, newVector, dst.getVectorAddress().getPointer());
+ return;
+ }
+
cgm.errorNYI(dst.getPointer().getLoc(),
"emitStoreThroughLValue: non-simple lvalue");
return;
@@ -418,6 +429,13 @@ RValue CIRGenFunction::emitLoadOfLValue(LValue lv, SourceLocation loc) {
if (lv.isSimple())
return RValue::get(emitLoadOfScalar(lv, loc));
+ if (lv.isVectorElt()) {
+ const mlir::Value load =
+ builder.createLoad(getLoc(loc), lv.getVectorAddress().getPointer());
+ return RValue::get(builder.create<cir::VecExtractOp>(getLoc(loc), load,
+ lv.getVectorIdx()));
+ }
+
cgm.errorNYI(loc, "emitLoadOfLValue");
return RValue::get(nullptr);
}
@@ -638,12 +656,6 @@ static Address emitArraySubscriptPtr(CIRGenFunction &cgf,
LValue
CIRGenFunction::emitArraySubscriptExpr(const clang::ArraySubscriptExpr *e) {
- if (e->getBase()->getType()->isVectorType() &&
- !isa<ExtVectorElementExpr>(e->getBase())) {
- cgm.errorNYI(e->getSourceRange(), "emitArraySubscriptExpr: VectorType");
- return LValue::makeAddr(Address::invalid(), e->getType(), LValueBaseInfo());
- }
-
if (isa<ExtVectorElementExpr>(e->getBase())) {
cgm.errorNYI(e->getSourceRange(),
"emitArraySubscriptExpr: ExtVectorElementExpr");
@@ -666,18 +678,28 @@ CIRGenFunction::emitArraySubscriptExpr(const clang::ArraySubscriptExpr *e) {
assert((e->getIdx() == e->getLHS() || e->getIdx() == e->getRHS()) &&
"index was neither LHS nor RHS");
- auto emitIdxAfterBase = [&]() -> mlir::Value {
+ auto emitIdxAfterBase = [&](bool promote) -> mlir::Value {
const mlir::Value idx = emitScalarExpr(e->getIdx());
// Extend or truncate the index type to 32 or 64-bits.
auto ptrTy = mlir::dyn_cast<cir::PointerType>(idx.getType());
- if (ptrTy && mlir::isa<cir::IntType>(ptrTy.getPointee()))
+ if (promote && ptrTy && ptrTy.isPtrTo<cir::IntType>())
cgm.errorNYI(e->getSourceRange(),
"emitArraySubscriptExpr: index type cast");
return idx;
};
- const mlir::Value idx = emitIdxAfterBase();
+ // If the base is a vector type, then we are forming a vector element
+ // with this subscript.
+ if (e->getBase()->getType()->isVectorType() &&
+ !isa<ExtVectorElementExpr>(e->getBase())) {
+ const mlir::Value idx = emitIdxAfterBase(/*promote=*/false);
+ const LValue lhs = emitLValue(e->getBase());
+ return LValue::makeVectorElt(lhs.getAddress(), idx, e->getBase()->getType(),
+ lhs.getBaseInfo());
+ }
+
+ const mlir::Value idx = emitIdxAfterBase(/*promote=*/true);
if (const Expr *array = getSimpleArrayDecayOperand(e->getBase())) {
LValue arrayLV;
if (const auto *ase = dyn_cast<ArraySubscriptExpr>(array))
diff --git a/clang/lib/CIR/CodeGen/CIRGenValue.h b/clang/lib/CIR/CodeGen/CIRGenValue.h
index ce87496..3feadfaf 100644
--- a/clang/lib/CIR/CodeGen/CIRGenValue.h
+++ b/clang/lib/CIR/CodeGen/CIRGenValue.h
@@ -116,6 +116,7 @@ class LValue {
// this is the alignment of the whole vector)
unsigned alignment;
mlir::Value v;
+ mlir::Value vectorIdx; // Index for vector subscript
mlir::Type elementType;
LValueBaseInfo baseInfo;
@@ -136,6 +137,7 @@ class LValue {
public:
bool isSimple() const { return lvType == Simple; }
+ bool isVectorElt() const { return lvType == VectorElt; }
bool isBitField() const { return lvType == BitField; }
// TODO: Add support for volatile
@@ -176,6 +178,31 @@ public:
r.initialize(t, t.getQualifiers(), address.getAlignment(), baseInfo);
return r;
}
+
+ Address getVectorAddress() const {
+ return Address(getVectorPointer(), elementType, getAlignment());
+ }
+
+ mlir::Value getVectorPointer() const {
+ assert(isVectorElt());
+ return v;
+ }
+
+ mlir::Value getVectorIdx() const {
+ assert(isVectorElt());
+ return vectorIdx;
+ }
+
+ static LValue makeVectorElt(Address vecAddress, mlir::Value index,
+ clang::QualType t, LValueBaseInfo baseInfo) {
+ LValue r;
+ r.lvType = VectorElt;
+ r.v = vecAddress.getPointer();
+ r.elementType = vecAddress.getElementType();
+ r.vectorIdx = index;
+ r.initialize(t, t.getQualifiers(), vecAddress.getAlignment(), baseInfo);
+ return r;
+ }
};
/// An aggregate value slot.
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 5986655..9c46bd3 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1646,7 +1646,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
CIRToLLVMTrapOpLowering,
CIRToLLVMUnaryOpLowering,
CIRToLLVMVecCreateOpLowering,
- CIRToLLVMVecExtractOpLowering
+ CIRToLLVMVecExtractOpLowering,
+ CIRToLLVMVecInsertOpLowering
// clang-format on
>(converter, patterns.getContext());
@@ -1763,6 +1764,14 @@ mlir::LogicalResult CIRToLLVMVecExtractOpLowering::matchAndRewrite(
return mlir::success();
}
+mlir::LogicalResult CIRToLLVMVecInsertOpLowering::matchAndRewrite(
+ cir::VecInsertOp op, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const {
+ rewriter.replaceOpWithNewOp<mlir::LLVM::InsertElementOp>(
+ op, adaptor.getVec(), adaptor.getValue(), adaptor.getIndex());
+ return mlir::success();
+}
+
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
return std::make_unique<ConvertCIRToLLVMPass>();
}
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
index 0ac1b6d..bd077e3 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
@@ -322,6 +322,16 @@ public:
mlir::ConversionPatternRewriter &) const override;
};
+class CIRToLLVMVecInsertOpLowering
+ : public mlir::OpConversionPattern<cir::VecInsertOp> {
+public:
+ using mlir::OpConversionPattern<cir::VecInsertOp>::OpConversionPattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(cir::VecInsertOp op, OpAdaptor,
+ mlir::ConversionPatternRewriter &) const override;
+};
+
} // namespace direct
} // namespace cir