diff options
author | Morris Hafner <mmha@users.noreply.github.com> | 2025-04-08 18:53:54 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-04-08 09:53:54 -0700 |
commit | 441f87968df5dfb74d710fa32147789be98c20a6 (patch) | |
tree | c5d66b40035337c456cdbd8975f160da1969fc22 /clang/lib | |
parent | 271399831b780d25dce85715727c841843e10d4b (diff) | |
download | llvm-441f87968df5dfb74d710fa32147789be98c20a6.zip llvm-441f87968df5dfb74d710fa32147789be98c20a6.tar.gz llvm-441f87968df5dfb74d710fa32147789be98c20a6.tar.bz2 |
[CIR] Upstream CmpOp (#133159)
This patch adds support for comparison operators with ClangIR, both
integral and floating point.
---------
Co-authored-by: Morris Hafner <mhafner@nvidia.com>
Co-authored-by: Henrich Lauko <xlauko@mail.muni.cz>
Co-authored-by: Andy Kaylor <akaylor@nvidia.com>
Diffstat (limited to 'clang/lib')
-rw-r--r-- | clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp | 79 | ||||
-rw-r--r-- | clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 82 | ||||
-rw-r--r-- | clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h | 13 |
3 files changed, 174 insertions, 0 deletions
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp index 3863d21..ed49f39 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp @@ -707,6 +707,85 @@ public: HANDLEBINOP(Xor) HANDLEBINOP(Or) #undef HANDLEBINOP + + mlir::Value emitCmp(const BinaryOperator *e) { + const mlir::Location loc = cgf.getLoc(e->getExprLoc()); + mlir::Value result; + QualType lhsTy = e->getLHS()->getType(); + QualType rhsTy = e->getRHS()->getType(); + + auto clangCmpToCIRCmp = + [](clang::BinaryOperatorKind clangCmp) -> cir::CmpOpKind { + switch (clangCmp) { + case BO_LT: + return cir::CmpOpKind::lt; + case BO_GT: + return cir::CmpOpKind::gt; + case BO_LE: + return cir::CmpOpKind::le; + case BO_GE: + return cir::CmpOpKind::ge; + case BO_EQ: + return cir::CmpOpKind::eq; + case BO_NE: + return cir::CmpOpKind::ne; + default: + llvm_unreachable("unsupported comparison kind for cir.cmp"); + } + }; + + if (lhsTy->getAs<MemberPointerType>()) { + assert(!cir::MissingFeatures::dataMemberType()); + assert(e->getOpcode() == BO_EQ || e->getOpcode() == BO_NE); + mlir::Value lhs = cgf.emitScalarExpr(e->getLHS()); + mlir::Value rhs = cgf.emitScalarExpr(e->getRHS()); + cir::CmpOpKind kind = clangCmpToCIRCmp(e->getOpcode()); + result = builder.createCompare(loc, kind, lhs, rhs); + } else if (!lhsTy->isAnyComplexType() && !rhsTy->isAnyComplexType()) { + BinOpInfo boInfo = emitBinOps(e); + mlir::Value lhs = boInfo.lhs; + mlir::Value rhs = boInfo.rhs; + + if (lhsTy->isVectorType()) { + assert(!cir::MissingFeatures::vectorType()); + cgf.cgm.errorNYI(loc, "vector comparisons"); + result = builder.getBool(false, loc); + } else if (boInfo.isFixedPointOp()) { + assert(!cir::MissingFeatures::fixedPointType()); + cgf.cgm.errorNYI(loc, "fixed point comparisons"); + result = builder.getBool(false, loc); + } else { + // integers and pointers + if (cgf.cgm.getCodeGenOpts().StrictVTablePointers && + mlir::isa<cir::PointerType>(lhs.getType()) && + mlir::isa<cir::PointerType>(rhs.getType())) { + cgf.cgm.errorNYI(loc, "strict vtable pointer comparisons"); + } + + cir::CmpOpKind kind = clangCmpToCIRCmp(e->getOpcode()); + result = builder.createCompare(loc, kind, lhs, rhs); + } + } else { + // Complex Comparison: can only be an equality comparison. + assert(!cir::MissingFeatures::complexType()); + cgf.cgm.errorNYI(loc, "complex comparison"); + result = builder.getBool(false, loc); + } + + return emitScalarConversion(result, cgf.getContext().BoolTy, e->getType(), + e->getExprLoc()); + } + +// Comparisons. +#define VISITCOMP(CODE) \ + mlir::Value VisitBin##CODE(const BinaryOperator *E) { return emitCmp(E); } + VISITCOMP(LT) + VISITCOMP(GT) + VISITCOMP(LE) + VISITCOMP(GE) + VISITCOMP(EQ) + VISITCOMP(NE) +#undef VISITCOMP }; LValue ScalarExprEmitter::emitCompoundAssignLValue( diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 81b80e2..7ca3640 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -21,6 +21,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Types.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" @@ -1193,6 +1194,86 @@ mlir::LogicalResult CIRToLLVMBinOpLowering::matchAndRewrite( return mlir::LogicalResult::success(); } +/// Convert from a CIR comparison kind to an LLVM IR integral comparison kind. +static mlir::LLVM::ICmpPredicate +convertCmpKindToICmpPredicate(cir::CmpOpKind kind, bool isSigned) { + using CIR = cir::CmpOpKind; + using LLVMICmp = mlir::LLVM::ICmpPredicate; + switch (kind) { + case CIR::eq: + return LLVMICmp::eq; + case CIR::ne: + return LLVMICmp::ne; + case CIR::lt: + return (isSigned ? LLVMICmp::slt : LLVMICmp::ult); + case CIR::le: + return (isSigned ? LLVMICmp::sle : LLVMICmp::ule); + case CIR::gt: + return (isSigned ? LLVMICmp::sgt : LLVMICmp::ugt); + case CIR::ge: + return (isSigned ? LLVMICmp::sge : LLVMICmp::uge); + } + llvm_unreachable("Unknown CmpOpKind"); +} + +/// Convert from a CIR comparison kind to an LLVM IR floating-point comparison +/// kind. +static mlir::LLVM::FCmpPredicate +convertCmpKindToFCmpPredicate(cir::CmpOpKind kind) { + using CIR = cir::CmpOpKind; + using LLVMFCmp = mlir::LLVM::FCmpPredicate; + switch (kind) { + case CIR::eq: + return LLVMFCmp::oeq; + case CIR::ne: + return LLVMFCmp::une; + case CIR::lt: + return LLVMFCmp::olt; + case CIR::le: + return LLVMFCmp::ole; + case CIR::gt: + return LLVMFCmp::ogt; + case CIR::ge: + return LLVMFCmp::oge; + } + llvm_unreachable("Unknown CmpOpKind"); +} + +mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite( + cir::CmpOp cmpOp, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + mlir::Type type = cmpOp.getLhs().getType(); + + assert(!cir::MissingFeatures::dataMemberType()); + assert(!cir::MissingFeatures::methodType()); + + // Lower to LLVM comparison op. + if (mlir::isa<cir::IntType, mlir::IntegerType>(type)) { + bool isSigned = mlir::isa<cir::IntType>(type) + ? mlir::cast<cir::IntType>(type).isSigned() + : mlir::cast<mlir::IntegerType>(type).isSigned(); + mlir::LLVM::ICmpPredicate kind = + convertCmpKindToICmpPredicate(cmpOp.getKind(), isSigned); + rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>( + cmpOp, kind, adaptor.getLhs(), adaptor.getRhs()); + } else if (auto ptrTy = mlir::dyn_cast<cir::PointerType>(type)) { + mlir::LLVM::ICmpPredicate kind = + convertCmpKindToICmpPredicate(cmpOp.getKind(), + /* isSigned=*/false); + rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>( + cmpOp, kind, adaptor.getLhs(), adaptor.getRhs()); + } else if (mlir::isa<cir::CIRFPTypeInterface>(type)) { + mlir::LLVM::FCmpPredicate kind = + convertCmpKindToFCmpPredicate(cmpOp.getKind()); + rewriter.replaceOpWithNewOp<mlir::LLVM::FCmpOp>( + cmpOp, kind, adaptor.getLhs(), adaptor.getRhs()); + } else { + return cmpOp.emitError() << "unsupported type for CmpOp: " << type; + } + + return mlir::success(); +} + static void prepareTypeConverter(mlir::LLVMTypeConverter &converter, mlir::DataLayout &dataLayout) { converter.addConversion([&](cir::PointerType type) -> mlir::Type { @@ -1334,6 +1415,7 @@ void ConvertCIRToLLVMPass::runOnOperation() { CIRToLLVMBinOpLowering, CIRToLLVMBrCondOpLowering, CIRToLLVMBrOpLowering, + CIRToLLVMCmpOpLowering, CIRToLLVMConstantOpLowering, CIRToLLVMFuncOpLowering, CIRToLLVMTrapOpLowering, diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h index 6f489fb..d53c4b3 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h @@ -186,6 +186,19 @@ public: mlir::ConversionPatternRewriter &) const override; }; +class CIRToLLVMCmpOpLowering : public mlir::OpConversionPattern<cir::CmpOp> { +public: + CIRToLLVMCmpOpLowering(const mlir::TypeConverter &typeConverter, + mlir::MLIRContext *context) + : OpConversionPattern(typeConverter, context) { + setHasBoundedRewriteRecursion(); + } + + mlir::LogicalResult + matchAndRewrite(cir::CmpOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + class CIRToLLVMBrOpLowering : public mlir::OpConversionPattern<cir::BrOp> { public: using mlir::OpConversionPattern<cir::BrOp>::OpConversionPattern; |