diff options
author | Amr Hesham <amr96@programmer.net> | 2025-05-12 20:22:02 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-05-12 20:22:02 +0200 |
commit | a6c4ca8e934474d3ed76718788fb086c28a10863 (patch) | |
tree | 605b454e94f47efe9f01593bedf75a69f5afb451 /clang/lib | |
parent | 377a0476ab34b8c7274562aefe6d3a1614e477a4 (diff) | |
download | llvm-a6c4ca8e934474d3ed76718788fb086c28a10863.zip llvm-a6c4ca8e934474d3ed76718788fb086c28a10863.tar.gz llvm-a6c4ca8e934474d3ed76718788fb086c28a10863.tar.bz2 |
[CIR] Upstream insert op for VectorType (#139146)
This change adds an insert op for VectorType
Issue https://github.com/llvm/llvm-project/issues/136487
Diffstat (limited to 'clang/lib')
-rw-r--r-- | clang/lib/CIR/CodeGen/CIRGenExpr.cpp | 40 | ||||
-rw-r--r-- | clang/lib/CIR/CodeGen/CIRGenValue.h | 27 | ||||
-rw-r--r-- | clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 11 | ||||
-rw-r--r-- | clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h | 10 |
4 files changed, 78 insertions, 10 deletions
diff --git a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp index 711a652..0386961 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp @@ -205,6 +205,17 @@ Address CIRGenFunction::emitPointerWithAlignment(const Expr *expr, void CIRGenFunction::emitStoreThroughLValue(RValue src, LValue dst, bool isInit) { if (!dst.isSimple()) { + if (dst.isVectorElt()) { + // Read/modify/write the vector, inserting the new element + const mlir::Location loc = dst.getVectorPointer().getLoc(); + const mlir::Value vector = + builder.createLoad(loc, dst.getVectorAddress().getPointer()); + const mlir::Value newVector = builder.create<cir::VecInsertOp>( + loc, vector, src.getScalarVal(), dst.getVectorIdx()); + builder.createStore(loc, newVector, dst.getVectorAddress().getPointer()); + return; + } + cgm.errorNYI(dst.getPointer().getLoc(), "emitStoreThroughLValue: non-simple lvalue"); return; @@ -418,6 +429,13 @@ RValue CIRGenFunction::emitLoadOfLValue(LValue lv, SourceLocation loc) { if (lv.isSimple()) return RValue::get(emitLoadOfScalar(lv, loc)); + if (lv.isVectorElt()) { + const mlir::Value load = + builder.createLoad(getLoc(loc), lv.getVectorAddress().getPointer()); + return RValue::get(builder.create<cir::VecExtractOp>(getLoc(loc), load, + lv.getVectorIdx())); + } + cgm.errorNYI(loc, "emitLoadOfLValue"); return RValue::get(nullptr); } @@ -638,12 +656,6 @@ static Address emitArraySubscriptPtr(CIRGenFunction &cgf, LValue CIRGenFunction::emitArraySubscriptExpr(const clang::ArraySubscriptExpr *e) { - if (e->getBase()->getType()->isVectorType() && - !isa<ExtVectorElementExpr>(e->getBase())) { - cgm.errorNYI(e->getSourceRange(), "emitArraySubscriptExpr: VectorType"); - return LValue::makeAddr(Address::invalid(), e->getType(), LValueBaseInfo()); - } - if (isa<ExtVectorElementExpr>(e->getBase())) { cgm.errorNYI(e->getSourceRange(), "emitArraySubscriptExpr: ExtVectorElementExpr"); @@ -666,18 +678,28 @@ CIRGenFunction::emitArraySubscriptExpr(const clang::ArraySubscriptExpr *e) { assert((e->getIdx() == e->getLHS() || e->getIdx() == e->getRHS()) && "index was neither LHS nor RHS"); - auto emitIdxAfterBase = [&]() -> mlir::Value { + auto emitIdxAfterBase = [&](bool promote) -> mlir::Value { const mlir::Value idx = emitScalarExpr(e->getIdx()); // Extend or truncate the index type to 32 or 64-bits. auto ptrTy = mlir::dyn_cast<cir::PointerType>(idx.getType()); - if (ptrTy && mlir::isa<cir::IntType>(ptrTy.getPointee())) + if (promote && ptrTy && ptrTy.isPtrTo<cir::IntType>()) cgm.errorNYI(e->getSourceRange(), "emitArraySubscriptExpr: index type cast"); return idx; }; - const mlir::Value idx = emitIdxAfterBase(); + // If the base is a vector type, then we are forming a vector element + // with this subscript. + if (e->getBase()->getType()->isVectorType() && + !isa<ExtVectorElementExpr>(e->getBase())) { + const mlir::Value idx = emitIdxAfterBase(/*promote=*/false); + const LValue lhs = emitLValue(e->getBase()); + return LValue::makeVectorElt(lhs.getAddress(), idx, e->getBase()->getType(), + lhs.getBaseInfo()); + } + + const mlir::Value idx = emitIdxAfterBase(/*promote=*/true); if (const Expr *array = getSimpleArrayDecayOperand(e->getBase())) { LValue arrayLV; if (const auto *ase = dyn_cast<ArraySubscriptExpr>(array)) diff --git a/clang/lib/CIR/CodeGen/CIRGenValue.h b/clang/lib/CIR/CodeGen/CIRGenValue.h index ce87496..3feadfaf 100644 --- a/clang/lib/CIR/CodeGen/CIRGenValue.h +++ b/clang/lib/CIR/CodeGen/CIRGenValue.h @@ -116,6 +116,7 @@ class LValue { // this is the alignment of the whole vector) unsigned alignment; mlir::Value v; + mlir::Value vectorIdx; // Index for vector subscript mlir::Type elementType; LValueBaseInfo baseInfo; @@ -136,6 +137,7 @@ class LValue { public: bool isSimple() const { return lvType == Simple; } + bool isVectorElt() const { return lvType == VectorElt; } bool isBitField() const { return lvType == BitField; } // TODO: Add support for volatile @@ -176,6 +178,31 @@ public: r.initialize(t, t.getQualifiers(), address.getAlignment(), baseInfo); return r; } + + Address getVectorAddress() const { + return Address(getVectorPointer(), elementType, getAlignment()); + } + + mlir::Value getVectorPointer() const { + assert(isVectorElt()); + return v; + } + + mlir::Value getVectorIdx() const { + assert(isVectorElt()); + return vectorIdx; + } + + static LValue makeVectorElt(Address vecAddress, mlir::Value index, + clang::QualType t, LValueBaseInfo baseInfo) { + LValue r; + r.lvType = VectorElt; + r.v = vecAddress.getPointer(); + r.elementType = vecAddress.getElementType(); + r.vectorIdx = index; + r.initialize(t, t.getQualifiers(), vecAddress.getAlignment(), baseInfo); + return r; + } }; /// An aggregate value slot. diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 5986655..9c46bd3 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -1646,7 +1646,8 @@ void ConvertCIRToLLVMPass::runOnOperation() { CIRToLLVMTrapOpLowering, CIRToLLVMUnaryOpLowering, CIRToLLVMVecCreateOpLowering, - CIRToLLVMVecExtractOpLowering + CIRToLLVMVecExtractOpLowering, + CIRToLLVMVecInsertOpLowering // clang-format on >(converter, patterns.getContext()); @@ -1763,6 +1764,14 @@ mlir::LogicalResult CIRToLLVMVecExtractOpLowering::matchAndRewrite( return mlir::success(); } +mlir::LogicalResult CIRToLLVMVecInsertOpLowering::matchAndRewrite( + cir::VecInsertOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + rewriter.replaceOpWithNewOp<mlir::LLVM::InsertElementOp>( + op, adaptor.getVec(), adaptor.getValue(), adaptor.getIndex()); + return mlir::success(); +} + std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() { return std::make_unique<ConvertCIRToLLVMPass>(); } diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h index 0ac1b6d..bd077e3 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h @@ -322,6 +322,16 @@ public: mlir::ConversionPatternRewriter &) const override; }; +class CIRToLLVMVecInsertOpLowering + : public mlir::OpConversionPattern<cir::VecInsertOp> { +public: + using mlir::OpConversionPattern<cir::VecInsertOp>::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::VecInsertOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + } // namespace direct } // namespace cir |