aboutsummaryrefslogtreecommitdiff
path: root/clang/lib
diff options
context:
space:
mode:
authorMorris Hafner <mmha@users.noreply.github.com>2025-04-08 18:53:54 +0200
committerGitHub <noreply@github.com>2025-04-08 09:53:54 -0700
commit441f87968df5dfb74d710fa32147789be98c20a6 (patch)
treec5d66b40035337c456cdbd8975f160da1969fc22 /clang/lib
parent271399831b780d25dce85715727c841843e10d4b (diff)
downloadllvm-441f87968df5dfb74d710fa32147789be98c20a6.zip
llvm-441f87968df5dfb74d710fa32147789be98c20a6.tar.gz
llvm-441f87968df5dfb74d710fa32147789be98c20a6.tar.bz2
[CIR] Upstream CmpOp (#133159)
This patch adds support for comparison operators with ClangIR, both integral and floating point. --------- Co-authored-by: Morris Hafner <mhafner@nvidia.com> Co-authored-by: Henrich Lauko <xlauko@mail.muni.cz> Co-authored-by: Andy Kaylor <akaylor@nvidia.com>
Diffstat (limited to 'clang/lib')
-rw-r--r--clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp79
-rw-r--r--clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp82
-rw-r--r--clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h13
3 files changed, 174 insertions, 0 deletions
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
index 3863d21..ed49f39 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
@@ -707,6 +707,85 @@ public:
HANDLEBINOP(Xor)
HANDLEBINOP(Or)
#undef HANDLEBINOP
+
+ mlir::Value emitCmp(const BinaryOperator *e) {
+ const mlir::Location loc = cgf.getLoc(e->getExprLoc());
+ mlir::Value result;
+ QualType lhsTy = e->getLHS()->getType();
+ QualType rhsTy = e->getRHS()->getType();
+
+ auto clangCmpToCIRCmp =
+ [](clang::BinaryOperatorKind clangCmp) -> cir::CmpOpKind {
+ switch (clangCmp) {
+ case BO_LT:
+ return cir::CmpOpKind::lt;
+ case BO_GT:
+ return cir::CmpOpKind::gt;
+ case BO_LE:
+ return cir::CmpOpKind::le;
+ case BO_GE:
+ return cir::CmpOpKind::ge;
+ case BO_EQ:
+ return cir::CmpOpKind::eq;
+ case BO_NE:
+ return cir::CmpOpKind::ne;
+ default:
+ llvm_unreachable("unsupported comparison kind for cir.cmp");
+ }
+ };
+
+ if (lhsTy->getAs<MemberPointerType>()) {
+ assert(!cir::MissingFeatures::dataMemberType());
+ assert(e->getOpcode() == BO_EQ || e->getOpcode() == BO_NE);
+ mlir::Value lhs = cgf.emitScalarExpr(e->getLHS());
+ mlir::Value rhs = cgf.emitScalarExpr(e->getRHS());
+ cir::CmpOpKind kind = clangCmpToCIRCmp(e->getOpcode());
+ result = builder.createCompare(loc, kind, lhs, rhs);
+ } else if (!lhsTy->isAnyComplexType() && !rhsTy->isAnyComplexType()) {
+ BinOpInfo boInfo = emitBinOps(e);
+ mlir::Value lhs = boInfo.lhs;
+ mlir::Value rhs = boInfo.rhs;
+
+ if (lhsTy->isVectorType()) {
+ assert(!cir::MissingFeatures::vectorType());
+ cgf.cgm.errorNYI(loc, "vector comparisons");
+ result = builder.getBool(false, loc);
+ } else if (boInfo.isFixedPointOp()) {
+ assert(!cir::MissingFeatures::fixedPointType());
+ cgf.cgm.errorNYI(loc, "fixed point comparisons");
+ result = builder.getBool(false, loc);
+ } else {
+ // integers and pointers
+ if (cgf.cgm.getCodeGenOpts().StrictVTablePointers &&
+ mlir::isa<cir::PointerType>(lhs.getType()) &&
+ mlir::isa<cir::PointerType>(rhs.getType())) {
+ cgf.cgm.errorNYI(loc, "strict vtable pointer comparisons");
+ }
+
+ cir::CmpOpKind kind = clangCmpToCIRCmp(e->getOpcode());
+ result = builder.createCompare(loc, kind, lhs, rhs);
+ }
+ } else {
+ // Complex Comparison: can only be an equality comparison.
+ assert(!cir::MissingFeatures::complexType());
+ cgf.cgm.errorNYI(loc, "complex comparison");
+ result = builder.getBool(false, loc);
+ }
+
+ return emitScalarConversion(result, cgf.getContext().BoolTy, e->getType(),
+ e->getExprLoc());
+ }
+
+// Comparisons.
+#define VISITCOMP(CODE) \
+ mlir::Value VisitBin##CODE(const BinaryOperator *E) { return emitCmp(E); }
+ VISITCOMP(LT)
+ VISITCOMP(GT)
+ VISITCOMP(LE)
+ VISITCOMP(GE)
+ VISITCOMP(EQ)
+ VISITCOMP(NE)
+#undef VISITCOMP
};
LValue ScalarExprEmitter::emitCompoundAssignLValue(
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 81b80e2..7ca3640 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -21,6 +21,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Types.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
@@ -1193,6 +1194,86 @@ mlir::LogicalResult CIRToLLVMBinOpLowering::matchAndRewrite(
return mlir::LogicalResult::success();
}
+/// Convert from a CIR comparison kind to an LLVM IR integral comparison kind.
+static mlir::LLVM::ICmpPredicate
+convertCmpKindToICmpPredicate(cir::CmpOpKind kind, bool isSigned) {
+ using CIR = cir::CmpOpKind;
+ using LLVMICmp = mlir::LLVM::ICmpPredicate;
+ switch (kind) {
+ case CIR::eq:
+ return LLVMICmp::eq;
+ case CIR::ne:
+ return LLVMICmp::ne;
+ case CIR::lt:
+ return (isSigned ? LLVMICmp::slt : LLVMICmp::ult);
+ case CIR::le:
+ return (isSigned ? LLVMICmp::sle : LLVMICmp::ule);
+ case CIR::gt:
+ return (isSigned ? LLVMICmp::sgt : LLVMICmp::ugt);
+ case CIR::ge:
+ return (isSigned ? LLVMICmp::sge : LLVMICmp::uge);
+ }
+ llvm_unreachable("Unknown CmpOpKind");
+}
+
+/// Convert from a CIR comparison kind to an LLVM IR floating-point comparison
+/// kind.
+static mlir::LLVM::FCmpPredicate
+convertCmpKindToFCmpPredicate(cir::CmpOpKind kind) {
+ using CIR = cir::CmpOpKind;
+ using LLVMFCmp = mlir::LLVM::FCmpPredicate;
+ switch (kind) {
+ case CIR::eq:
+ return LLVMFCmp::oeq;
+ case CIR::ne:
+ return LLVMFCmp::une;
+ case CIR::lt:
+ return LLVMFCmp::olt;
+ case CIR::le:
+ return LLVMFCmp::ole;
+ case CIR::gt:
+ return LLVMFCmp::ogt;
+ case CIR::ge:
+ return LLVMFCmp::oge;
+ }
+ llvm_unreachable("Unknown CmpOpKind");
+}
+
+mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
+ cir::CmpOp cmpOp, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const {
+ mlir::Type type = cmpOp.getLhs().getType();
+
+ assert(!cir::MissingFeatures::dataMemberType());
+ assert(!cir::MissingFeatures::methodType());
+
+ // Lower to LLVM comparison op.
+ if (mlir::isa<cir::IntType, mlir::IntegerType>(type)) {
+ bool isSigned = mlir::isa<cir::IntType>(type)
+ ? mlir::cast<cir::IntType>(type).isSigned()
+ : mlir::cast<mlir::IntegerType>(type).isSigned();
+ mlir::LLVM::ICmpPredicate kind =
+ convertCmpKindToICmpPredicate(cmpOp.getKind(), isSigned);
+ rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
+ cmpOp, kind, adaptor.getLhs(), adaptor.getRhs());
+ } else if (auto ptrTy = mlir::dyn_cast<cir::PointerType>(type)) {
+ mlir::LLVM::ICmpPredicate kind =
+ convertCmpKindToICmpPredicate(cmpOp.getKind(),
+ /* isSigned=*/false);
+ rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
+ cmpOp, kind, adaptor.getLhs(), adaptor.getRhs());
+ } else if (mlir::isa<cir::CIRFPTypeInterface>(type)) {
+ mlir::LLVM::FCmpPredicate kind =
+ convertCmpKindToFCmpPredicate(cmpOp.getKind());
+ rewriter.replaceOpWithNewOp<mlir::LLVM::FCmpOp>(
+ cmpOp, kind, adaptor.getLhs(), adaptor.getRhs());
+ } else {
+ return cmpOp.emitError() << "unsupported type for CmpOp: " << type;
+ }
+
+ return mlir::success();
+}
+
static void prepareTypeConverter(mlir::LLVMTypeConverter &converter,
mlir::DataLayout &dataLayout) {
converter.addConversion([&](cir::PointerType type) -> mlir::Type {
@@ -1334,6 +1415,7 @@ void ConvertCIRToLLVMPass::runOnOperation() {
CIRToLLVMBinOpLowering,
CIRToLLVMBrCondOpLowering,
CIRToLLVMBrOpLowering,
+ CIRToLLVMCmpOpLowering,
CIRToLLVMConstantOpLowering,
CIRToLLVMFuncOpLowering,
CIRToLLVMTrapOpLowering,
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
index 6f489fb..d53c4b3 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
@@ -186,6 +186,19 @@ public:
mlir::ConversionPatternRewriter &) const override;
};
+class CIRToLLVMCmpOpLowering : public mlir::OpConversionPattern<cir::CmpOp> {
+public:
+ CIRToLLVMCmpOpLowering(const mlir::TypeConverter &typeConverter,
+ mlir::MLIRContext *context)
+ : OpConversionPattern(typeConverter, context) {
+ setHasBoundedRewriteRecursion();
+ }
+
+ mlir::LogicalResult
+ matchAndRewrite(cir::CmpOp op, OpAdaptor,
+ mlir::ConversionPatternRewriter &) const override;
+};
+
class CIRToLLVMBrOpLowering : public mlir::OpConversionPattern<cir::BrOp> {
public:
using mlir::OpConversionPattern<cir::BrOp>::OpConversionPattern;