Diffstat (limited to 'clang/lib/CIR')
-rw-r--r-- | clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp | 5
-rw-r--r-- | clang/lib/CIR/CodeGen/CIRGenCall.cpp | 8
-rw-r--r-- | clang/lib/CIR/CodeGen/CIRGenCall.h | 5
-rw-r--r-- | clang/lib/CIR/CodeGen/CIRGenDecl.cpp | 78
-rw-r--r-- | clang/lib/CIR/CodeGen/CIRGenExpr.cpp | 65
-rw-r--r-- | clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp | 91
-rw-r--r-- | clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp | 102
-rw-r--r-- | clang/lib/CIR/CodeGen/CIRGenFunction.cpp | 127
-rw-r--r-- | clang/lib/CIR/CodeGen/CIRGenFunction.h | 10
-rw-r--r-- | clang/lib/CIR/CodeGen/CIRGenModule.cpp | 118
-rw-r--r-- | clang/lib/CIR/CodeGen/CIRGenModule.h | 8
-rw-r--r-- | clang/lib/CIR/CodeGen/CIRGenRecordLayoutBuilder.cpp | 106
-rw-r--r-- | clang/lib/CIR/CodeGen/CIRGenStmt.cpp | 6
-rw-r--r-- | clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 158
-rw-r--r-- | clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp | 3
-rw-r--r-- | clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp | 220
-rw-r--r-- | clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 106
-rw-r--r-- | clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h | 15
18 files changed, 1159 insertions, 72 deletions
diff --git a/clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp b/clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp index ef136f8..9049a01 100644 --- a/clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp @@ -190,6 +190,11 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl &gd, unsigned builtinID, assert(!cir::MissingFeatures::builtinCheckKind()); return emitBuiltinBitOp<cir::BitClzOp>(*this, e, /*poisonZero=*/true); + case Builtin::BI__builtin_ffs: + case Builtin::BI__builtin_ffsl: + case Builtin::BI__builtin_ffsll: + return emitBuiltinBitOp<cir::BitFfsOp>(*this, e); + case Builtin::BI__builtin_parity: case Builtin::BI__builtin_parityl: case Builtin::BI__builtin_parityll: diff --git a/clang/lib/CIR/CodeGen/CIRGenCall.cpp b/clang/lib/CIR/CodeGen/CIRGenCall.cpp index 938d143..fc208ff 100644 --- a/clang/lib/CIR/CodeGen/CIRGenCall.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenCall.cpp @@ -582,6 +582,14 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &funcInfo, cir::FuncOp directFuncOp; if (auto fnOp = dyn_cast<cir::FuncOp>(calleePtr)) { directFuncOp = fnOp; + } else if (auto getGlobalOp = mlir::dyn_cast<cir::GetGlobalOp>(calleePtr)) { + // FIXME(cir): This peephole optimization avoids indirect calls for + // builtins. This should be fixed in the builtin declaration instead by + // not emitting an unnecessary get_global in the first place. + // However, this is also used for no-prototype functions. + mlir::Operation *globalOp = cgm.getGlobalValue(getGlobalOp.getName()); + assert(globalOp && "undefined global function"); + directFuncOp = mlir::cast<cir::FuncOp>(globalOp); } else { [[maybe_unused]] mlir::ValueTypeRange<mlir::ResultRange> resultTypes = calleePtr->getResultTypes(); diff --git a/clang/lib/CIR/CodeGen/CIRGenCall.h b/clang/lib/CIR/CodeGen/CIRGenCall.h index bd11329..a78956b 100644 --- a/clang/lib/CIR/CodeGen/CIRGenCall.h +++ b/clang/lib/CIR/CodeGen/CIRGenCall.h @@ -116,6 +116,11 @@ public: assert(isOrdinary()); return reinterpret_cast<mlir::Operation *>(kindOrFunctionPtr); } + + void setFunctionPointer(mlir::Operation *functionPtr) { + assert(isOrdinary()); + kindOrFunctionPtr = SpecialKind(reinterpret_cast<uintptr_t>(functionPtr)); + } }; /// Type for representing both the decl and type of parameters to a function.
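For context, a minimal, hypothetical C++ example (not part of the patch) of the source the new BitFfsOp handling above covers; __builtin_ffs returns one plus the index of the least significant set bit, or 0 for a zero argument:

    // ffs(0x10) == 5: the lowest set bit is bit 4, so the result is 4 + 1.
    // With the BitFfsOp folder added later in this diff, the constant call
    // folds away at compile time.
    int testFfs(unsigned long long x) {
      return __builtin_ffs(0x10) + __builtin_ffsll(static_cast<long long>(x));
    }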
diff --git a/clang/lib/CIR/CodeGen/CIRGenDecl.cpp b/clang/lib/CIR/CodeGen/CIRGenDecl.cpp index a28ac3c..9e8eaa5 100644 --- a/clang/lib/CIR/CodeGen/CIRGenDecl.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenDecl.cpp @@ -520,7 +520,7 @@ void CIRGenFunction::emitExprAsInit(const Expr *init, const ValueDecl *d, llvm_unreachable("bad evaluation kind"); } -void CIRGenFunction::emitDecl(const Decl &d) { +void CIRGenFunction::emitDecl(const Decl &d, bool evaluateConditionDecl) { switch (d.getKind()) { case Decl::BuiltinTemplate: case Decl::TranslationUnit: @@ -608,11 +608,14 @@ void CIRGenFunction::emitDecl(const Decl &d) { case Decl::UsingDirective: // using namespace X; [C++] assert(!cir::MissingFeatures::generateDebugInfo()); return; - case Decl::Var: { + case Decl::Var: + case Decl::Decomposition: { const VarDecl &vd = cast<VarDecl>(d); assert(vd.isLocalVarDecl() && "Should not see file-scope variables inside a function!"); emitVarDecl(vd); + if (evaluateConditionDecl) + maybeEmitDeferredVarDeclInit(&vd); return; } case Decl::OpenACCDeclare: @@ -632,7 +635,6 @@ void CIRGenFunction::emitDecl(const Decl &d) { case Decl::ImplicitConceptSpecialization: case Decl::TopLevelStmt: case Decl::UsingPack: - case Decl::Decomposition: // This could be moved to join Decl::Var case Decl::OMPDeclareReduction: case Decl::OMPDeclareMapper: cgm.errorNYI(d.getSourceRange(), @@ -649,6 +651,38 @@ void CIRGenFunction::emitNullabilityCheck(LValue lhs, mlir::Value rhs, assert(!cir::MissingFeatures::sanitizers()); } +/// Destroys all the elements of the given array, beginning from last to first. +/// The array cannot be zero-length. +/// +/// \param begin - a type* denoting the first element of the array +/// \param end - a type* denoting one past the end of the array +/// \param elementType - the element type of the array +/// \param destroyer - the function to call to destroy elements +void CIRGenFunction::emitArrayDestroy(mlir::Value begin, mlir::Value end, + QualType elementType, + CharUnits elementAlign, + Destroyer *destroyer) { + assert(!elementType->isArrayType()); + + // Differently from LLVM traditional codegen, use a higher level + // representation instead of lowering directly to a loop. + mlir::Type cirElementType = convertTypeForMem(elementType); + cir::PointerType ptrToElmType = builder.getPointerTo(cirElementType); + + // Emit the dtor call that will execute for every array element. + cir::ArrayDtor::create( + builder, *currSrcLoc, begin, [&](mlir::OpBuilder &b, mlir::Location loc) { + auto arg = b.getInsertionBlock()->addArgument(ptrToElmType, loc); + Address curAddr = Address(arg, cirElementType, elementAlign); + assert(!cir::MissingFeatures::dtorCleanups()); + + // Perform the actual destruction there. + destroyer(*this, curAddr, elementType); + + cir::YieldOp::create(builder, loc); + }); +} + /// Immediately perform the destruction of the given object. 
/// /// \param addr - the address of the object; a type* @@ -658,10 +692,34 @@ void CIRGenFunction::emitNullabilityCheck(LValue lhs, mlir::Value rhs, /// elements void CIRGenFunction::emitDestroy(Address addr, QualType type, Destroyer *destroyer) { - if (getContext().getAsArrayType(type)) - cgm.errorNYI("emitDestroy: array type"); + const ArrayType *arrayType = getContext().getAsArrayType(type); + if (!arrayType) + return destroyer(*this, addr, type); + + mlir::Value length = emitArrayLength(arrayType, type, addr); + + CharUnits elementAlign = addr.getAlignment().alignmentOfArrayElement( + getContext().getTypeSizeInChars(type)); + + auto constantCount = length.getDefiningOp<cir::ConstantOp>(); + if (!constantCount) { + assert(!cir::MissingFeatures::vlas()); + cgm.errorNYI("emitDestroy: variable length array"); + return; + } + + auto constIntAttr = mlir::dyn_cast<cir::IntAttr>(constantCount.getValue()); + // If it's constant zero, we can just skip the entire thing. + if (constIntAttr && constIntAttr.getUInt() == 0) + return; + + mlir::Value begin = addr.getPointer(); + mlir::Value end; // This will be used for future non-constant counts. + emitArrayDestroy(begin, end, type, elementAlign, destroyer); - return destroyer(*this, addr, type); + // If the array destroy didn't use the length op, we can erase it. + if (constantCount.use_empty()) + constantCount.erase(); } CIRGenFunction::Destroyer * @@ -741,3 +799,11 @@ void CIRGenFunction::emitAutoVarTypeCleanup( assert(!cir::MissingFeatures::ehCleanupFlags()); ehStack.pushCleanup<DestroyObject>(cleanupKind, addr, type, destroyer); } + +void CIRGenFunction::maybeEmitDeferredVarDeclInit(const VarDecl *vd) { + if (auto *dd = dyn_cast_if_present<DecompositionDecl>(vd)) { + for (auto *b : dd->flat_bindings()) + if (auto *hd = b->getHoldingVar()) + emitVarDecl(*hd); + } +} diff --git a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp index 7ff5f26..a509ffa 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp @@ -584,6 +584,15 @@ LValue CIRGenFunction::emitDeclRefLValue(const DeclRefExpr *e) { return lv; } + if (const auto *bd = dyn_cast<BindingDecl>(nd)) { + if (e->refersToEnclosingVariableOrCapture()) { + assert(!cir::MissingFeatures::lambdaCaptures()); + cgm.errorNYI(e->getSourceRange(), "emitDeclRefLValue: lambda captures"); + return LValue(); + } + return emitLValue(bd->getBinding()); + } + cgm.errorNYI(e->getSourceRange(), "emitDeclRefLValue: unhandled decl type"); return LValue(); } @@ -949,7 +958,6 @@ LValue CIRGenFunction::emitCastLValue(const CastExpr *e) { case CK_Dynamic: case CK_ToUnion: case CK_BaseToDerived: - case CK_LValueBitCast: case CK_AddressSpaceConversion: case CK_ObjCObjectLValueCast: case CK_VectorSplat: @@ -965,6 +973,18 @@ LValue CIRGenFunction::emitCastLValue(const CastExpr *e) { return {}; } + case CK_LValueBitCast: { + // This must be a reinterpret_cast (or C-style equivalent). + const auto *ce = cast<ExplicitCastExpr>(e); + + cgm.emitExplicitCastExprType(ce, this); + LValue lv = emitLValue(e->getSubExpr()); + Address v = lv.getAddress().withElementType( + builder, convertTypeForMem(ce->getTypeAsWritten()->getPointeeType())); + + return makeAddrLValue(v, e->getType(), lv.getBaseInfo()); + } + case CK_NoOp: { // CK_NoOp can model a qualification conversion, which can remove an array // bound and change the IR type.
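As a rough, hypothetical C++ illustration (not from the patch) of what the DecompositionDecl/BindingDecl support above enables:

    struct Pair { int a, b; };
    int useBindings() {
      Pair p{1, 2};
      // emitDecl now accepts the DecompositionDecl; x and y are BindingDecls.
      auto [x, y] = p;
      // Each DeclRefExpr to x or y resolves through emitDeclRefLValue to the
      // lvalue of the underlying binding.
      return x + y;
    }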
@@ -1269,7 +1289,7 @@ RValue CIRGenFunction::getUndefRValue(QualType ty) { } RValue CIRGenFunction::emitCall(clang::QualType calleeTy, - const CIRGenCallee &callee, + const CIRGenCallee &origCallee, const clang::CallExpr *e, ReturnValueSlot returnValue) { // Get the actual function type. The callee type will always be a pointer to @@ -1280,6 +1300,8 @@ RValue CIRGenFunction::emitCall(clang::QualType calleeTy, calleeTy = getContext().getCanonicalType(calleeTy); auto pointeeTy = cast<PointerType>(calleeTy)->getPointeeType(); + CIRGenCallee callee = origCallee; + if (getLangOpts().CPlusPlus) assert(!cir::MissingFeatures::sanitizers()); @@ -1296,7 +1318,44 @@ RValue CIRGenFunction::emitCall(clang::QualType calleeTy, const CIRGenFunctionInfo &funcInfo = cgm.getTypes().arrangeFreeFunctionCall(args, fnType); - assert(!cir::MissingFeatures::opCallNoPrototypeFunc()); + // C99 6.5.2.2p6: + // If the expression that denotes the called function has a type that does + // not include a prototype, [the default argument promotions are performed]. + // If the number of arguments does not equal the number of parameters, the + // behavior is undefined. If the function is defined with a type that + // includes a prototype, and either the prototype ends with an ellipsis (, + // ...) or the types of the arguments after promotion are not compatible + // with the types of the parameters, the behavior is undefined. If the + // function is defined with a type that does not include a prototype, and + // the types of the arguments after promotion are not compatible with those + // of the parameters after promotion, the behavior is undefined [except in + // some trivial cases]. + // That is, in the general case, we should assume that a call through an + // unprototyped function type works like a *non-variadic* call. The way we + // make this work is to cast to the exact type of the promoted arguments. + if (isa<FunctionNoProtoType>(fnType)) { + assert(!cir::MissingFeatures::opCallChain()); + assert(!cir::MissingFeatures::addressSpace()); + cir::FuncType calleeTy = getTypes().getFunctionType(funcInfo); + // Get the non-variadic function type. + calleeTy = cir::FuncType::get(calleeTy.getInputs(), + calleeTy.getReturnType(), false); + auto calleePtrTy = cir::PointerType::get(calleeTy); + + mlir::Operation *fn = callee.getFunctionPointer(); + mlir::Value addr; + if (auto funcOp = mlir::dyn_cast<cir::FuncOp>(fn)) { + addr = builder.create<cir::GetGlobalOp>( + getLoc(e->getSourceRange()), + cir::PointerType::get(funcOp.getFunctionType()), funcOp.getSymName()); + } else { + addr = fn->getResult(0); + } + + fn = builder.createBitcast(addr, calleePtrTy).getDefiningOp(); + callee.setFunctionPointer(fn); + } + assert(!cir::MissingFeatures::opCallFnInfoOpts()); assert(!cir::MissingFeatures::hip()); assert(!cir::MissingFeatures::opCallMustTail()); diff --git a/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp b/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp index 0d12c5c..51aab95 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp @@ -357,10 +357,97 @@ void AggExprEmitter::visitCXXParenListOrInitListExpr( emitArrayInit(dest.getAddress(), arrayTy, e->getType(), e, args, arrayFiller); return; + } else if (e->getType()->isVariableArrayType()) { + cgf.cgm.errorNYI(e->getSourceRange(), + "visitCXXParenListOrInitListExpr variable array type"); + return; + } + + if (e->getType()->isArrayType()) { + cgf.cgm.errorNYI(e->getSourceRange(), + "visitCXXParenListOrInitListExpr array type"); + return; + } + + assert(e->getType()->isRecordType() && "Only support structs/unions here!"); + + // Do struct initialization; this code just sets each individual member + // to the appropriate value. This makes bitfield support automatic; + // the disadvantage is that the generated code is more difficult for + // the optimizer, especially with bitfields. + unsigned numInitElements = args.size(); + RecordDecl *record = e->getType()->castAs<RecordType>()->getDecl(); + + // We'll need to enter cleanup scopes in case any of the element + // initializers throws an exception. + assert(!cir::MissingFeatures::requiresCleanups()); + + unsigned curInitIndex = 0; + + // Emit initialization of base classes. + if (auto *cxxrd = dyn_cast<CXXRecordDecl>(record)) { + assert(numInitElements >= cxxrd->getNumBases() && + "missing initializer for base class"); + if (cxxrd->getNumBases() > 0) { + cgf.cgm.errorNYI(e->getSourceRange(), + "visitCXXParenListOrInitListExpr base class init"); + return; + } + } + + LValue destLV = cgf.makeAddrLValue(dest.getAddress(), e->getType()); + + if (record->isUnion()) { + cgf.cgm.errorNYI(e->getSourceRange(), + "visitCXXParenListOrInitListExpr union type"); + return; } - cgf.cgm.errorNYI( - "visitCXXParenListOrInitListExpr Record or VariableSizeArray type"); + // Here we iterate over the fields; this makes it simpler to both + // default-initialize fields and skip over unnamed fields. + for (const FieldDecl *field : record->fields()) { + // We're done once we hit the flexible array member. + if (field->getType()->isIncompleteArrayType()) + break; + + // Always skip anonymous bitfields. + if (field->isUnnamedBitField()) + continue; + + // We're done if we reach the end of the explicit initializers, we + // have a zeroed object, and the rest of the fields are + // zero-initializable.
+ if (curInitIndex == numInitElements && dest.isZeroed() && + cgf.getTypes().isZeroInitializable(e->getType())) + break; + LValue lv = + cgf.emitLValueForFieldInitialization(destLV, field, field->getName()); + // We never generate write-barriers for initialized fields. + assert(!cir::MissingFeatures::setNonGC()); + + if (curInitIndex < numInitElements) { + // Store the initializer into the field. + CIRGenFunction::SourceLocRAIIObject loc{ + cgf, cgf.getLoc(record->getSourceRange())}; + emitInitializationToLValue(args[curInitIndex++], lv); + } else { + // We're out of initializers; default-initialize to null + emitNullInitializationToLValue(cgf.getLoc(e->getSourceRange()), lv); + } + + // Push a destructor if necessary. + // FIXME: if we have an array of structures, all explicitly + // initialized, we can end up pushing a linear number of cleanups. + if (field->getType().isDestructedType()) { + cgf.cgm.errorNYI(e->getSourceRange(), + "visitCXXParenListOrInitListExpr destructor"); + return; + } + + // From classic codegen, maybe not useful for CIR: + // If the GEP didn't get used because of a dead zero init or something + // else, clean it up for -O0 builds and general tidiness. + } } // TODO(cir): This could be shared with classic codegen. diff --git a/clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp b/clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp index 7f2e2ce..ea60a95 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp @@ -60,7 +60,7 @@ public: mlir::Value VisitDeclRefExpr(DeclRefExpr *e); mlir::Value VisitGenericSelectionExpr(GenericSelectionExpr *e); mlir::Value VisitImplicitCastExpr(ImplicitCastExpr *e); - mlir::Value VisitInitListExpr(const InitListExpr *e); + mlir::Value VisitInitListExpr(InitListExpr *e); mlir::Value VisitCompoundLiteralExpr(CompoundLiteralExpr *e) { return emitLoadOfLValue(e); @@ -91,6 +91,14 @@ public: } mlir::Value VisitUnaryDeref(const Expr *e); + + mlir::Value VisitUnaryPlus(const UnaryOperator *e); + + mlir::Value VisitPlusMinus(const UnaryOperator *e, cir::UnaryOpKind kind, + QualType promotionType); + + mlir::Value VisitUnaryMinus(const UnaryOperator *e); + mlir::Value VisitUnaryNot(const UnaryOperator *e); struct BinOpInfo { @@ -110,6 +118,7 @@ public: mlir::Value emitBinAdd(const BinOpInfo &op); mlir::Value emitBinSub(const BinOpInfo &op); + mlir::Value emitBinMul(const BinOpInfo &op); QualType getPromotionType(QualType ty, bool isDivOpCode = false) { if (auto *complexTy = ty->getAs<ComplexType>()) { @@ -142,6 +151,7 @@ public: HANDLEBINOP(Add) HANDLEBINOP(Sub) + HANDLEBINOP(Mul) #undef HANDLEBINOP }; } // namespace @@ -189,8 +199,11 @@ mlir::Value ComplexExprEmitter::emitCast(CastKind ck, Expr *op, } case CK_LValueBitCast: { - cgf.cgm.errorNYI("ComplexExprEmitter::emitCast CK_LValueBitCast"); - return {}; + LValue origLV = cgf.emitLValue(op); + Address addr = + origLV.getAddress().withElementType(builder, cgf.convertType(destTy)); + LValue destLV = cgf.makeAddrLValue(addr, destTy); + return emitLoadOfLValue(destLV, op->getExprLoc()); } case CK_LValueToRValueBitCast: { @@ -279,6 +292,41 @@ mlir::Value ComplexExprEmitter::emitCast(CastKind ck, Expr *op, llvm_unreachable("unknown cast resulting in complex value"); } +mlir::Value ComplexExprEmitter::VisitUnaryPlus(const UnaryOperator *e) { + QualType promotionTy = getPromotionType(e->getSubExpr()->getType()); + mlir::Value result = VisitPlusMinus(e, cir::UnaryOpKind::Plus, promotionTy); + if (!promotionTy.isNull()) { + 
cgf.cgm.errorNYI("ComplexExprEmitter::VisitUnaryPlus emitUnPromotedValue"); + return {}; + } + return result; +} + +mlir::Value ComplexExprEmitter::VisitPlusMinus(const UnaryOperator *e, + cir::UnaryOpKind kind, + QualType promotionType) { + assert(kind == cir::UnaryOpKind::Plus || + kind == cir::UnaryOpKind::Minus && + "Invalid UnaryOp kind for ComplexType Plus or Minus"); + + mlir::Value op; + if (!promotionType.isNull()) + op = cgf.emitPromotedComplexExpr(e->getSubExpr(), promotionType); + else + op = Visit(e->getSubExpr()); + return builder.createUnaryOp(cgf.getLoc(e->getExprLoc()), kind, op); +} + +mlir::Value ComplexExprEmitter::VisitUnaryMinus(const UnaryOperator *e) { + QualType promotionTy = getPromotionType(e->getSubExpr()->getType()); + mlir::Value result = VisitPlusMinus(e, cir::UnaryOpKind::Minus, promotionTy); + if (!promotionTy.isNull()) { + cgf.cgm.errorNYI("ComplexExprEmitter::VisitUnaryMinus emitUnPromotedValue"); + return {}; + } + return result; +} + mlir::Value ComplexExprEmitter::emitConstant( const CIRGenFunction::ConstantEmission &constant, Expr *e) { assert(constant && "not a constant"); @@ -452,7 +500,7 @@ mlir::Value ComplexExprEmitter::VisitImplicitCastExpr(ImplicitCastExpr *e) { return emitCast(e->getCastKind(), e->getSubExpr(), e->getType()); } -mlir::Value ComplexExprEmitter::VisitInitListExpr(const InitListExpr *e) { +mlir::Value ComplexExprEmitter::VisitInitListExpr(InitListExpr *e) { mlir::Location loc = cgf.getLoc(e->getExprLoc()); if (e->getNumInits() == 2) { mlir::Value real = cgf.emitScalarExpr(e->getInit(0)); @@ -460,10 +508,8 @@ mlir::Value ComplexExprEmitter::VisitInitListExpr(const InitListExpr *e) { return builder.createComplexCreate(loc, real, imag); } - if (e->getNumInits() == 1) { - cgf.cgm.errorNYI("Create Complex with InitList with size 1"); - return {}; - } + if (e->getNumInits() == 1) + return Visit(e->getInit(0)); assert(e->getNumInits() == 0 && "Unexpected number of inits"); mlir::Type complexTy = cgf.convertType(e->getType()); @@ -533,13 +579,22 @@ mlir::Value ComplexExprEmitter::emitPromoted(const Expr *e, return emitBin##OP(emitBinOps(bo, promotionTy)); HANDLE_BINOP(Add) HANDLE_BINOP(Sub) + HANDLE_BINOP(Mul) #undef HANDLE_BINOP default: break; } - } else if (isa<UnaryOperator>(e)) { - cgf.cgm.errorNYI("emitPromoted UnaryOperator"); - return {}; + } else if (const auto *unaryOp = dyn_cast<UnaryOperator>(e)) { + switch (unaryOp->getOpcode()) { + case UO_Minus: + case UO_Plus: { + auto kind = unaryOp->getOpcode() == UO_Plus ? 
cir::UnaryOpKind::Plus + : cir::UnaryOpKind::Minus; + return VisitPlusMinus(unaryOp, kind, promotionTy); + } + default: + break; + } } mlir::Value result = Visit(const_cast<Expr *>(e)); @@ -584,6 +639,31 @@ mlir::Value ComplexExprEmitter::emitBinSub(const BinOpInfo &op) { return builder.create<cir::ComplexSubOp>(op.loc, op.lhs, op.rhs); } +static cir::ComplexRangeKind +getComplexRangeAttr(LangOptions::ComplexRangeKind range) { + switch (range) { + case LangOptions::CX_Full: + return cir::ComplexRangeKind::Full; + case LangOptions::CX_Improved: + return cir::ComplexRangeKind::Improved; + case LangOptions::CX_Promoted: + return cir::ComplexRangeKind::Promoted; + case LangOptions::CX_Basic: + return cir::ComplexRangeKind::Basic; + case LangOptions::CX_None: + // The default value for ComplexRangeKind is Full is no option is selected + return cir::ComplexRangeKind::Full; + } +} + +mlir::Value ComplexExprEmitter::emitBinMul(const BinOpInfo &op) { + assert(!cir::MissingFeatures::fastMathFlags()); + assert(!cir::MissingFeatures::cgFPOptionsRAII()); + cir::ComplexRangeKind rangeKind = + getComplexRangeAttr(op.fpFeatures.getComplexRange()); + return builder.create<cir::ComplexMulOp>(op.loc, op.lhs, op.rhs, rangeKind); +} + LValue CIRGenFunction::emitComplexAssignmentLValue(const BinaryOperator *e) { assert(e->getOpcode() == BO_Assign && "Expected assign op"); diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.cpp b/clang/lib/CIR/CodeGen/CIRGenFunction.cpp index b4b95d6..3e9de17 100644 --- a/clang/lib/CIR/CodeGen/CIRGenFunction.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenFunction.cpp @@ -923,4 +923,131 @@ CIRGenFunction::emitArrayLength(const clang::ArrayType *origArrayType, return builder.getConstInt(*currSrcLoc, SizeTy, countFromCLAs); } +// TODO(cir): Most of this function can be shared between CIRGen +// and traditional LLVM codegen +void CIRGenFunction::emitVariablyModifiedType(QualType type) { + assert(type->isVariablyModifiedType() && + "Must pass variably modified type to EmitVLASizes!"); + + // We're going to walk down into the type and look for VLA + // expressions. + do { + assert(type->isVariablyModifiedType()); + + const Type *ty = type.getTypePtr(); + switch (ty->getTypeClass()) { + case Type::CountAttributed: + case Type::PackIndexing: + case Type::ArrayParameter: + case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: + case Type::PredefinedSugar: + cgm.errorNYI("CIRGenFunction::emitVariablyModifiedType"); + break; + +#define TYPE(Class, Base) +#define ABSTRACT_TYPE(Class, Base) +#define NON_CANONICAL_TYPE(Class, Base) +#define DEPENDENT_TYPE(Class, Base) case Type::Class: +#define NON_CANONICAL_UNLESS_DEPENDENT_TYPE(Class, Base) +#include "clang/AST/TypeNodes.inc" + llvm_unreachable( + "dependent type must be resolved before the CIR codegen"); + + // These types are never variably-modified. 
+ case Type::Builtin: + case Type::Complex: + case Type::Vector: + case Type::ExtVector: + case Type::ConstantMatrix: + case Type::Record: + case Type::Enum: + case Type::Using: + case Type::TemplateSpecialization: + case Type::ObjCTypeParam: + case Type::ObjCObject: + case Type::ObjCInterface: + case Type::ObjCObjectPointer: + case Type::BitInt: + llvm_unreachable("type class is never variably-modified!"); + + case Type::Elaborated: + type = cast<clang::ElaboratedType>(ty)->getNamedType(); + break; + + case Type::Adjusted: + type = cast<clang::AdjustedType>(ty)->getAdjustedType(); + break; + + case Type::Decayed: + type = cast<clang::DecayedType>(ty)->getPointeeType(); + break; + + case Type::Pointer: + type = cast<clang::PointerType>(ty)->getPointeeType(); + break; + + case Type::BlockPointer: + type = cast<clang::BlockPointerType>(ty)->getPointeeType(); + break; + + case Type::LValueReference: + case Type::RValueReference: + type = cast<clang::ReferenceType>(ty)->getPointeeType(); + break; + + case Type::MemberPointer: + type = cast<clang::MemberPointerType>(ty)->getPointeeType(); + break; + + case Type::ConstantArray: + case Type::IncompleteArray: + // Losing element qualification here is fine. + type = cast<clang::ArrayType>(ty)->getElementType(); + break; + + case Type::VariableArray: { + cgm.errorNYI("CIRGenFunction::emitVariablyModifiedType VLA"); + break; + } + + case Type::FunctionProto: + case Type::FunctionNoProto: + type = cast<clang::FunctionType>(ty)->getReturnType(); + break; + + case Type::Paren: + case Type::TypeOf: + case Type::UnaryTransform: + case Type::Attributed: + case Type::BTFTagAttributed: + case Type::SubstTemplateTypeParm: + case Type::MacroQualified: + // Keep walking after single level desugaring. + type = type.getSingleStepDesugaredType(getContext()); + break; + + case Type::Typedef: + case Type::Decltype: + case Type::Auto: + case Type::DeducedTemplateSpecialization: + // Stop walking: nothing to do. + return; + + case Type::TypeOfExpr: + // Stop walking: emit typeof expression. + emitIgnoredExpr(cast<clang::TypeOfExprType>(ty)->getUnderlyingExpr()); + return; + + case Type::Atomic: + type = cast<clang::AtomicType>(ty)->getValueType(); + break; + + case Type::Pipe: + type = cast<clang::PipeType>(ty)->getElementType(); + break; + } + } while (type->isVariablyModifiedType()); +} + } // namespace clang::CIRGen diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.h b/clang/lib/CIR/CodeGen/CIRGenFunction.h index 4891c74..f9c8636 100644 --- a/clang/lib/CIR/CodeGen/CIRGenFunction.h +++ b/clang/lib/CIR/CodeGen/CIRGenFunction.h @@ -848,6 +848,10 @@ public: /// even if no aggregate location is provided. 
RValue emitAnyExprToTemp(const clang::Expr *e); + void emitArrayDestroy(mlir::Value begin, mlir::Value end, + QualType elementType, CharUnits elementAlign, + Destroyer *destroyer); + mlir::Value emitArrayLength(const clang::ArrayType *arrayType, QualType &baseType, Address &addr); LValue emitArraySubscriptExpr(const clang::ArraySubscriptExpr *e); @@ -866,6 +870,8 @@ public: void emitAutoVarTypeCleanup(const AutoVarEmission &emission, clang::QualType::DestructionKind dtorKind); + void maybeEmitDeferredVarDeclInit(const VarDecl *vd); + void emitBaseInitializer(mlir::Location loc, const CXXRecordDecl *classDecl, CXXCtorInitializer *baseInit); @@ -1055,7 +1061,7 @@ public: void emitCompoundStmtWithoutScope(const clang::CompoundStmt &s); - void emitDecl(const clang::Decl &d); + void emitDecl(const clang::Decl &d, bool evaluateConditionDecl = false); mlir::LogicalResult emitDeclStmt(const clang::DeclStmt &s); LValue emitDeclRefLValue(const clang::DeclRefExpr *e); @@ -1201,6 +1207,8 @@ public: /// inside a function, including static vars etc. void emitVarDecl(const clang::VarDecl &d); + void emitVariablyModifiedType(QualType ty); + mlir::LogicalResult emitWhileStmt(const clang::WhileStmt &s); /// Given an assignment `*lhs = rhs`, emit a test that checks if \p rhs is diff --git a/clang/lib/CIR/CodeGen/CIRGenModule.cpp b/clang/lib/CIR/CodeGen/CIRGenModule.cpp index 3502705..d0f9fc3 100644 --- a/clang/lib/CIR/CodeGen/CIRGenModule.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenModule.cpp @@ -1103,6 +1103,60 @@ cir::GlobalLinkageKind CIRGenModule::getCIRLinkageForDeclarator( return cir::GlobalLinkageKind::ExternalLinkage; } +/// This function is called when we implement a function with no prototype, e.g. +/// "int foo() {}". If there are existing call uses of the old function in the +/// module, this adjusts them to call the new function directly. +/// +/// This is not just a cleanup: the always_inline pass requires direct calls to +/// functions to be able to inline them. If there is a bitcast in the way, it +/// won't inline them. Instcombine normally deletes these calls, but it isn't +/// run at -O0. +void CIRGenModule::replaceUsesOfNonProtoTypeWithRealFunction( + mlir::Operation *old, cir::FuncOp newFn) { + // If we're redefining a global as a function, don't transform it. + auto oldFn = mlir::dyn_cast<cir::FuncOp>(old); + if (!oldFn) + return; + + // TODO(cir): this RAUW ignores the features below. + assert(!cir::MissingFeatures::opFuncExceptions()); + assert(!cir::MissingFeatures::opFuncParameterAttributes()); + assert(!cir::MissingFeatures::opFuncOperandBundles()); + if (oldFn->getAttrs().size() <= 1) + errorNYI(old->getLoc(), + "replaceUsesOfNonProtoTypeWithRealFunction: Attribute forwarding"); + + // Mark new function as originated from a no-proto declaration. + newFn.setNoProto(oldFn.getNoProto()); + + // Iterate through all calls of the no-proto function. + std::optional<mlir::SymbolTable::UseRange> symUses = + oldFn.getSymbolUses(oldFn->getParentOp()); + for (const mlir::SymbolTable::SymbolUse &use : symUses.value()) { + mlir::OpBuilder::InsertionGuard guard(builder); + + if (auto noProtoCallOp = mlir::dyn_cast<cir::CallOp>(use.getUser())) { + builder.setInsertionPoint(noProtoCallOp); + + // Patch call type with the real function type. + cir::CallOp realCallOp = builder.createCallOp( + noProtoCallOp.getLoc(), newFn, noProtoCallOp.getOperands()); + + // Replace old no proto call with fixed call. 
+ noProtoCallOp.replaceAllUsesWith(realCallOp); + noProtoCallOp.erase(); + } else if (auto getGlobalOp = + mlir::dyn_cast<cir::GetGlobalOp>(use.getUser())) { + // Replace type + getGlobalOp.getAddr().setType( + cir::PointerType::get(newFn.getFunctionType())); + } else { + errorNYI(use.getUser()->getLoc(), + "replaceUsesOfNonProtoTypeWithRealFunction: unexpected use"); + } + } +} + cir::GlobalLinkageKind CIRGenModule::getCIRLinkageVarDefinition(const VarDecl *vd, bool isConstant) { assert(!isConstant && "constant variables NYI"); @@ -1208,6 +1262,15 @@ cir::GlobalOp CIRGenModule::getGlobalForStringLiteral(const StringLiteral *s, return gv; } +void CIRGenModule::emitExplicitCastExprType(const ExplicitCastExpr *e, + CIRGenFunction *cgf) { + if (cgf && e->getType()->isVariablyModifiedType()) + cgf->emitVariablyModifiedType(e->getType()); + + assert(!cir::MissingFeatures::generateDebugInfo() && + "emitExplicitCastExprType"); +} + void CIRGenModule::emitDeclContext(const DeclContext *dc) { for (Decl *decl : dc->decls()) { // Unlike other DeclContexts, the contents of an ObjCImplDecl at TU scope @@ -1235,6 +1298,7 @@ void CIRGenModule::emitTopLevelDecl(Decl *decl) { decl->getDeclKindName()); break; + case Decl::CXXConversion: case Decl::CXXMethod: case Decl::Function: { auto *fd = cast<FunctionDecl>(decl); @@ -1244,8 +1308,13 @@ void CIRGenModule::emitTopLevelDecl(Decl *decl) { break; } - case Decl::Var: { + case Decl::Var: + case Decl::Decomposition: { auto *vd = cast<VarDecl>(decl); + if (isa<DecompositionDecl>(decl)) { + errorNYI(decl->getSourceRange(), "global variable decompositions"); + break; + } emitGlobal(vd); break; } @@ -1267,8 +1336,14 @@ void CIRGenModule::emitTopLevelDecl(Decl *decl) { break; // No code generation needed. - case Decl::UsingShadow: + case Decl::ClassTemplate: + case Decl::Concept: + case Decl::CXXDeductionGuide: case Decl::Empty: + case Decl::FunctionTemplate: + case Decl::StaticAssert: + case Decl::TypeAliasTemplate: + case Decl::UsingShadow: break; case Decl::CXXConstructor: @@ -1530,10 +1605,10 @@ static bool shouldAssumeDSOLocal(const CIRGenModule &cgm, const llvm::Triple &tt = cgm.getTriple(); const CodeGenOptions &cgOpts = cgm.getCodeGenOpts(); - if (tt.isWindowsGNUEnvironment()) { - // In MinGW, variables without DLLImport can still be automatically - // imported from a DLL by the linker; don't mark variables that - // potentially could come from another DLL as DSO local. + if (tt.isOSCygMing()) { + // In MinGW and Cygwin, variables without DLLImport can still be + // automatically imported from a DLL by the linker; don't mark variables + // that potentially could come from another DLL as DSO local. // With EmulatedTLS, TLS variables can be autoimported from other DLLs // (and this actually happens in the public interface of libstdc++), so @@ -1692,8 +1767,7 @@ cir::FuncOp CIRGenModule::getOrCreateCIRFunction( // Lookup the entry, lazily creating it if necessary. mlir::Operation *entry = getGlobalValue(mangledName); if (entry) { - if (!isa<cir::FuncOp>(entry)) - errorNYI(d->getSourceRange(), "getOrCreateCIRFunction: non-FuncOp"); + assert(mlir::isa<cir::FuncOp>(entry)); assert(!cir::MissingFeatures::weakRefReference()); @@ -1729,6 +1803,30 @@ cir::FuncOp CIRGenModule::getOrCreateCIRFunction( invalidLoc ? 
theModule->getLoc() : getLoc(funcDecl->getSourceRange()), mangledName, mlir::cast<cir::FuncType>(funcType), funcDecl); + // If we already created a function with the same mangled name (but different + // type) before, take its name and add it to the list of functions to be + // replaced with F at the end of CodeGen. + // + // This happens if there is a prototype for a function (e.g. "int f()") and + // then a definition of a different type (e.g. "int f(int x)"). + if (entry) { + + // Fetch a generic symbol-defining operation and its uses. + auto symbolOp = mlir::cast<mlir::SymbolOpInterface>(entry); + + // This might be an implementation of a function without a prototype, in + // which case, try to do special replacement of calls which match the new + // prototype. The really key thing here is that we also potentially drop + // arguments from the call site so as to make a direct call, which makes the + // inliner happier and suppresses a number of optimizer warnings (!) about + // dropping arguments. + if (symbolOp.getSymbolUses(symbolOp->getParentOp())) + replaceUsesOfNonProtoTypeWithRealFunction(entry, funcOp); + + // Obliterate no-proto declaration. + entry->erase(); + } + if (d) setFunctionAttributes(gd, funcOp, /*isIncompleteFunction=*/false, isThunk); @@ -1805,7 +1903,9 @@ CIRGenModule::createCIRFunction(mlir::Location loc, StringRef name, func = builder.create<cir::FuncOp>(loc, name, funcType); assert(!cir::MissingFeatures::opFuncAstDeclAttr()); - assert(!cir::MissingFeatures::opFuncNoProto()); + + if (funcDecl && !funcDecl->hasPrototype()) + func.setNoProto(true); assert(func.isDeclaration() && "expected empty body"); diff --git a/clang/lib/CIR/CodeGen/CIRGenModule.h b/clang/lib/CIR/CodeGen/CIRGenModule.h index 16922b1..5d07d38 100644 --- a/clang/lib/CIR/CodeGen/CIRGenModule.h +++ b/clang/lib/CIR/CodeGen/CIRGenModule.h @@ -252,6 +252,11 @@ public: getAddrOfGlobal(clang::GlobalDecl gd, ForDefinition_t isForDefinition = NotForDefinition); + /// Emit type info if type of an expression is a variably modified + /// type. Also emit proper debug info for cast types. + void emitExplicitCastExprType(const ExplicitCastExpr *e, + CIRGenFunction *cgf = nullptr); + /// Emit code for a single global function or variable declaration. Forward /// declarations are emitted lazily. void emitGlobal(clang::GlobalDecl gd); @@ -308,6 +313,9 @@ public: static void setInitializer(cir::GlobalOp &op, mlir::Attribute value); + void replaceUsesOfNonProtoTypeWithRealFunction(mlir::Operation *old, + cir::FuncOp newFn); + cir::FuncOp getOrCreateCIRFunction(llvm::StringRef mangledName, mlir::Type funcType, clang::GlobalDecl gd, bool forVTable, diff --git a/clang/lib/CIR/CodeGen/CIRGenRecordLayoutBuilder.cpp b/clang/lib/CIR/CodeGen/CIRGenRecordLayoutBuilder.cpp index 05e8848..0c8ff4bd 100644 --- a/clang/lib/CIR/CodeGen/CIRGenRecordLayoutBuilder.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenRecordLayoutBuilder.cpp @@ -91,6 +91,9 @@ struct CIRRecordLowering final { return astContext.getTargetInfo().getABI().starts_with("aapcs"); } + /// Helper function to check if the target machine is BigEndian. 
+ bool isBigEndian() const { return astContext.getTargetInfo().isBigEndian(); } + CharUnits bitsToCharUnits(uint64_t bitOffset) { return astContext.toCharUnitsFromBits(bitOffset); } @@ -438,9 +441,7 @@ CIRRecordLowering::accumulateBitFields(RecordDecl::field_iterator field, } else if (cirGenTypes.getCGModule() .getCodeGenOpts() .FineGrainedBitfieldAccesses) { - assert(!cir::MissingFeatures::nonFineGrainedBitfields()); - cirGenTypes.getCGModule().errorNYI(field->getSourceRange(), - "NYI FineGrainedBitfield"); + installBest = true; } else { // Otherwise, we're not installing. Update the bit size // of the current span to go all the way to limitOffset, which is @@ -773,7 +774,104 @@ void CIRRecordLowering::computeVolatileBitfields() { !cirGenTypes.getCGModule().getCodeGenOpts().AAPCSBitfieldWidth) return; - assert(!cir::MissingFeatures::armComputeVolatileBitfields()); + for (auto &[field, info] : bitFields) { + mlir::Type resLTy = cirGenTypes.convertTypeForMem(field->getType()); + + if (astContext.toBits(astRecordLayout.getAlignment()) < + getSizeInBits(resLTy).getQuantity()) + continue; + + // CIRRecordLowering::setBitFieldInfo() pre-adjusts the bit-field offsets + // for big-endian targets, but it assumes a container of width + // info.storageSize. Since AAPCS uses a different container size (width + // of the type), we first undo that calculation here and redo it once + // the bit-field offset within the new container is calculated. + const unsigned oldOffset = + isBigEndian() ? info.storageSize - (info.offset + info.size) : info.offset; + // Offset to the bit-field from the beginning of the struct. + const unsigned absoluteOffset = + astContext.toBits(info.storageOffset) + oldOffset; + + // Container size is the width of the bit-field type. + const unsigned storageSize = getSizeInBits(resLTy).getQuantity(); + // Nothing to do if the access uses the desired + // container width and is naturally aligned. + if (info.storageSize == storageSize && (oldOffset % storageSize == 0)) + continue; + + // Offset within the container. + unsigned offset = absoluteOffset & (storageSize - 1); + // Bail out if an aligned load of the container cannot cover the entire + // bit-field. This can happen, for example, if the bit-field is part of a + // packed struct. AAPCS does not define access rules for such cases, so we + // let clang follow its own rules. + if (offset + info.size > storageSize) + continue; + + // Re-adjust offsets for big-endian targets. + if (isBigEndian()) + offset = storageSize - (offset + info.size); + + const CharUnits storageOffset = + astContext.toCharUnitsFromBits(absoluteOffset & ~(storageSize - 1)); + const CharUnits end = storageOffset + + astContext.toCharUnitsFromBits(storageSize) - + CharUnits::One(); + + const ASTRecordLayout &layout = + astContext.getASTRecordLayout(field->getParent()); + // If we would access memory outside the record, then bail out. + const CharUnits recordSize = layout.getSize(); + if (end >= recordSize) + continue; + + // Bail out if performing this load would access non-bit-field members. + bool conflict = false; + for (const auto *f : recordDecl->fields()) { + // Allow sized bit-field overlaps. + if (f->isBitField() && !f->isZeroLengthBitField()) + continue; + + const CharUnits fOffset = astContext.toCharUnitsFromBits( + layout.getFieldOffset(f->getFieldIndex())); + + // As C11 defines, a zero-sized bit-field defines a barrier, so + // fields after and before it should be race condition free. + // The AAPCS acknowledges it and imposes no restrictions when the + // natural container overlaps a zero-length bit-field. + if (f->isZeroLengthBitField()) { + if (end > fOffset && storageOffset < fOffset) { + conflict = true; + break; + } + } + + const CharUnits fEnd = + fOffset + + astContext.toCharUnitsFromBits(astContext.toBits( + getSizeInBits(cirGenTypes.convertTypeForMem(f->getType())))) - + CharUnits::One(); + // If no overlap, continue. + if (end < fOffset || fEnd < storageOffset) + continue; + + // The desired load overlaps a non-bit-field member, bail out. + conflict = true; + break; + } + + if (conflict) + continue; + // Write the new bit-field access parameters. + // As the storage offset is now defined as the number of elements from the + // start of the structure, we should divide the offset by the element size. + info.volatileStorageOffset = + storageOffset / + astContext.toCharUnitsFromBits(storageSize).getQuantity(); + info.volatileStorageSize = storageSize; + info.volatileOffset = offset; + } } void CIRRecordLowering::accumulateBases(const CXXRecordDecl *cxxRecordDecl) { diff --git a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp index 21bee33..75da229e 100644 --- a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp @@ -363,8 +363,8 @@ mlir::LogicalResult CIRGenFunction::emitIfStmt(const IfStmt &s) { mlir::LogicalResult CIRGenFunction::emitDeclStmt(const DeclStmt &s) { assert(builder.getInsertionBlock() && "expected valid insertion point"); - for (const Decl *I : s.decls()) - emitDecl(*I); + for (const Decl *i : s.decls()) + emitDecl(*i, /*evaluateConditionDecl=*/true); return mlir::success(); } @@ -875,7 +875,7 @@ mlir::LogicalResult CIRGenFunction::emitSwitchStmt(const clang::SwitchStmt &s) { return mlir::failure(); if (s.getConditionVariable()) - emitDecl(*s.getConditionVariable()); + emitDecl(*s.getConditionVariable(), /*evaluateConditionDecl=*/true); mlir::Value condV = emitScalarExpr(s.getCond()); diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index cd77166..1c3a310 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -17,6 +17,7 @@ #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Support/LLVM.h" #include "clang/CIR/Dialect/IR/CIROpsDialect.cpp.inc" #include "clang/CIR/Dialect/IR/CIROpsEnums.cpp.inc" @@ -338,7 +339,7 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType, } if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr, - cir::ConstComplexAttr>(attrType)) + cir::ConstComplexAttr, cir::PoisonAttr>(attrType)) return success(); assert(isa<TypedAttr>(attrType) && "What else could we be looking at here?"); @@ -628,6 +629,11 @@ static Value tryFoldCastChain(cir::CastOp op) { } OpFoldResult cir::CastOp::fold(FoldAdaptor adaptor) { + if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getSrc())) { + // Propagate poison value + return cir::PoisonAttr::get(getContext(), getType()); + } + if (getSrc().getType() == getType()) { switch (getKind()) { case cir::CastKind::integral: { @@ -1464,10 +1470,14 @@ ParseResult cir::FuncOp::parse(OpAsmParser &parser, OperationState &state) { llvm::SMLoc loc = parser.getCurrentLocation(); mlir::Builder &builder = parser.getBuilder(); + mlir::StringAttr noProtoNameAttr = getNoProtoAttrName(state.name); mlir::StringAttr visNameAttr = getSymVisibilityAttrName(state.name);
mlir::StringAttr visibilityNameAttr = getGlobalVisibilityAttrName(state.name); mlir::StringAttr dsoLocalNameAttr = getDsoLocalAttrName(state.name); + if (parser.parseOptionalKeyword(noProtoNameAttr).succeeded()) + state.addAttribute(noProtoNameAttr, parser.getBuilder().getUnitAttr()); + // Default to external linkage if no keyword is provided. state.addAttribute(getLinkageAttrNameString(), GlobalLinkageKindAttr::get( @@ -1572,6 +1582,9 @@ mlir::Region *cir::FuncOp::getCallableRegion() { } void cir::FuncOp::print(OpAsmPrinter &p) { + if (getNoProto()) + p << " no_proto"; + if (getComdat()) p << " comdat"; @@ -1782,6 +1795,12 @@ static bool isBoolNot(cir::UnaryOp op) { // // and the argument of the first one (%0) will be used instead. OpFoldResult cir::UnaryOp::fold(FoldAdaptor adaptor) { + if (auto poison = + mlir::dyn_cast_if_present<cir::PoisonAttr>(adaptor.getInput())) { + // Propagate poison values + return poison; + } + if (isBoolNot(*this)) if (auto previous = dyn_cast_or_null<UnaryOp>(getInput().getDefiningOp())) if (isBoolNot(previous)) @@ -2231,6 +2250,143 @@ LogicalResult cir::ComplexImagPtrOp::verify() { } //===----------------------------------------------------------------------===// +// Bit manipulation operations +//===----------------------------------------------------------------------===// + +static OpFoldResult +foldUnaryBitOp(mlir::Attribute inputAttr, + llvm::function_ref<llvm::APInt(const llvm::APInt &)> func, + bool poisonZero = false) { + if (mlir::isa_and_present<cir::PoisonAttr>(inputAttr)) { + // Propagate poison value + return inputAttr; + } + + auto input = mlir::dyn_cast_if_present<IntAttr>(inputAttr); + if (!input) + return nullptr; + + llvm::APInt inputValue = input.getValue(); + if (poisonZero && inputValue.isZero()) + return cir::PoisonAttr::get(input.getType()); + + llvm::APInt resultValue = func(inputValue); + return IntAttr::get(input.getType(), resultValue); +} + +OpFoldResult BitClrsbOp::fold(FoldAdaptor adaptor) { + return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) { + unsigned resultValue = + inputValue.getBitWidth() - inputValue.getSignificantBits(); + return llvm::APInt(inputValue.getBitWidth(), resultValue); + }); +} + +OpFoldResult BitClzOp::fold(FoldAdaptor adaptor) { + return foldUnaryBitOp( + adaptor.getInput(), + [](const llvm::APInt &inputValue) { + unsigned resultValue = inputValue.countLeadingZeros(); + return llvm::APInt(inputValue.getBitWidth(), resultValue); + }, + getPoisonZero()); +} + +OpFoldResult BitCtzOp::fold(FoldAdaptor adaptor) { + return foldUnaryBitOp( + adaptor.getInput(), + [](const llvm::APInt &inputValue) { + return llvm::APInt(inputValue.getBitWidth(), + inputValue.countTrailingZeros()); + }, + getPoisonZero()); +} + +OpFoldResult BitFfsOp::fold(FoldAdaptor adaptor) { + return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) { + unsigned trailingZeros = inputValue.countTrailingZeros(); + unsigned result = + trailingZeros == inputValue.getBitWidth() ? 
0 : trailingZeros + 1; + return llvm::APInt(inputValue.getBitWidth(), result); + }); +} + +OpFoldResult BitParityOp::fold(FoldAdaptor adaptor) { + return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) { + return llvm::APInt(inputValue.getBitWidth(), inputValue.popcount() % 2); + }); +} + +OpFoldResult BitPopcountOp::fold(FoldAdaptor adaptor) { + return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) { + return llvm::APInt(inputValue.getBitWidth(), inputValue.popcount()); + }); +} + +OpFoldResult BitReverseOp::fold(FoldAdaptor adaptor) { + return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) { + return inputValue.reverseBits(); + }); +} + +OpFoldResult ByteSwapOp::fold(FoldAdaptor adaptor) { + return foldUnaryBitOp(adaptor.getInput(), [](const llvm::APInt &inputValue) { + return inputValue.byteSwap(); + }); +} + +OpFoldResult RotateOp::fold(FoldAdaptor adaptor) { + if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getInput()) || + mlir::isa_and_present<cir::PoisonAttr>(adaptor.getAmount())) { + // Propagate poison values + return cir::PoisonAttr::get(getType()); + } + + auto input = mlir::dyn_cast_if_present<IntAttr>(adaptor.getInput()); + auto amount = mlir::dyn_cast_if_present<IntAttr>(adaptor.getAmount()); + if (!input && !amount) + return nullptr; + + // We could fold cir.rotate even if one of its two operands is not a constant: + // - `cir.rotate left/right %0, 0` could be folded into just %0 even if %0 + // is not a constant. + // - `cir.rotate left/right 0/0b111...111, %0` could be folded into 0 or + // 0b111...111 even if %0 is not a constant. + + llvm::APInt inputValue; + if (input) { + inputValue = input.getValue(); + if (inputValue.isZero() || inputValue.isAllOnes()) { + // An input value of all 0s or all 1s will not change after rotation + return input; + } + } + + uint64_t amountValue; + if (amount) { + amountValue = amount.getValue().urem(getInput().getType().getWidth()); + if (amountValue == 0) { + // A shift amount of 0 will not change the input value + return getInput(); + } + } + + if (!input || !amount) + return nullptr; + + assert(inputValue.getBitWidth() == getInput().getType().getWidth() && + "input value must have the same bit width as the input type"); + + llvm::APInt resultValue; + if (isRotateLeft()) + resultValue = inputValue.rotl(amountValue); + else + resultValue = inputValue.rotr(amountValue); + + return IntAttr::get(input.getContext(), input.getType(), resultValue); +} + +//===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp index e505db5..2eaa60c 100644 --- a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp +++ b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp @@ -143,7 +143,8 @@ void CIRCanonicalizePass::runOnOperation() { if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp, ComplexCreateOp, ComplexImagOp, ComplexRealOp, VecCmpOp, VecCreateOp, VecExtractOp, VecShuffleOp, VecShuffleDynamicOp, - VecTernaryOp>(op)) + VecTernaryOp, BitClrsbOp, BitClzOp, BitCtzOp, BitFfsOp, BitParityOp, + BitPopcountOp, BitReverseOp, ByteSwapOp, RotateOp>(op)) ops.push_back(op); }); diff --git a/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp b/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp index cef83ea..66260eb 
100644 --- a/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp +++ b/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp @@ -15,7 +15,6 @@ #include "clang/CIR/Dialect/Passes.h" #include "clang/CIR/MissingFeatures.h" -#include <iostream> #include <memory> using namespace mlir; @@ -28,8 +27,15 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> { void runOnOp(mlir::Operation *op); void lowerCastOp(cir::CastOp op); + void lowerComplexMulOp(cir::ComplexMulOp op); void lowerUnaryOp(cir::UnaryOp op); - void lowerArrayCtor(ArrayCtor op); + void lowerArrayDtor(cir::ArrayDtor op); + void lowerArrayCtor(cir::ArrayCtor op); + + cir::FuncOp buildRuntimeFunction( + mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc, + cir::FuncType type, + cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage); /// /// AST related @@ -37,11 +43,31 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> { clang::ASTContext *astCtx; + /// Tracks current module. + mlir::ModuleOp mlirModule; + void setASTContext(clang::ASTContext *c) { astCtx = c; } }; } // namespace +cir::FuncOp LoweringPreparePass::buildRuntimeFunction( + mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc, + cir::FuncType type, cir::GlobalLinkageKind linkage) { + cir::FuncOp f = dyn_cast_or_null<FuncOp>(SymbolTable::lookupNearestSymbolFrom( + mlirModule, StringAttr::get(mlirModule->getContext(), name))); + if (!f) { + f = builder.create<cir::FuncOp>(loc, name, type); + f.setLinkageAttr( + cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage)); + mlir::SymbolTable::setSymbolVisibility( + f, mlir::SymbolTable::Visibility::Private); + + assert(!cir::MissingFeatures::opFuncExtraAttrs()); + } + return f; +} + static mlir::Value lowerScalarToComplexCast(mlir::MLIRContext &ctx, cir::CastOp op) { cir::CIRBaseBuilderTy builder(ctx); @@ -127,6 +153,124 @@ void LoweringPreparePass::lowerCastOp(cir::CastOp op) { } } +static mlir::Value buildComplexBinOpLibCall( + LoweringPreparePass &pass, CIRBaseBuilderTy &builder, + llvm::StringRef (*libFuncNameGetter)(llvm::APFloat::Semantics), + mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal, + mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag) { + cir::FPTypeInterface elementTy = + mlir::cast<cir::FPTypeInterface>(ty.getElementType()); + + llvm::StringRef libFuncName = libFuncNameGetter( + llvm::APFloat::SemanticsToEnum(elementTy.getFloatSemantics())); + llvm::SmallVector<mlir::Type, 4> libFuncInputTypes(4, elementTy); + + cir::FuncType libFuncTy = cir::FuncType::get(libFuncInputTypes, ty); + + // Insert a declaration for the runtime function to be used in Complex + // multiplication and division when needed + cir::FuncOp libFunc; + { + mlir::OpBuilder::InsertionGuard ipGuard{builder}; + builder.setInsertionPointToStart(pass.mlirModule.getBody()); + libFunc = pass.buildRuntimeFunction(builder, libFuncName, loc, libFuncTy); + } + + cir::CallOp call = + builder.createCallOp(loc, libFunc, {lhsReal, lhsImag, rhsReal, rhsImag}); + return call.getResult(); +} + +static llvm::StringRef +getComplexMulLibCallName(llvm::APFloat::Semantics semantics) { + switch (semantics) { + case llvm::APFloat::S_IEEEhalf: + return "__mulhc3"; + case llvm::APFloat::S_IEEEsingle: + return "__mulsc3"; + case llvm::APFloat::S_IEEEdouble: + return "__muldc3"; + case llvm::APFloat::S_PPCDoubleDouble: + return "__multc3"; + case llvm::APFloat::S_x87DoubleExtended: + return "__mulxc3"; + case llvm::APFloat::S_IEEEquad: + 
return "__multc3"; + default: + llvm_unreachable("unsupported floating point type"); + } +} + +static mlir::Value lowerComplexMul(LoweringPreparePass &pass, + CIRBaseBuilderTy &builder, + mlir::Location loc, cir::ComplexMulOp op, + mlir::Value lhsReal, mlir::Value lhsImag, + mlir::Value rhsReal, mlir::Value rhsImag) { + // (a+bi) * (c+di) = (ac-bd) + (ad+bc)i + mlir::Value resultRealLhs = + builder.createBinop(loc, lhsReal, cir::BinOpKind::Mul, rhsReal); + mlir::Value resultRealRhs = + builder.createBinop(loc, lhsImag, cir::BinOpKind::Mul, rhsImag); + mlir::Value resultImagLhs = + builder.createBinop(loc, lhsReal, cir::BinOpKind::Mul, rhsImag); + mlir::Value resultImagRhs = + builder.createBinop(loc, lhsImag, cir::BinOpKind::Mul, rhsReal); + mlir::Value resultReal = builder.createBinop( + loc, resultRealLhs, cir::BinOpKind::Sub, resultRealRhs); + mlir::Value resultImag = builder.createBinop( + loc, resultImagLhs, cir::BinOpKind::Add, resultImagRhs); + mlir::Value algebraicResult = + builder.createComplexCreate(loc, resultReal, resultImag); + + cir::ComplexType complexTy = op.getType(); + cir::ComplexRangeKind rangeKind = op.getRange(); + if (mlir::isa<cir::IntType>(complexTy.getElementType()) || + rangeKind == cir::ComplexRangeKind::Basic || + rangeKind == cir::ComplexRangeKind::Improved || + rangeKind == cir::ComplexRangeKind::Promoted) + return algebraicResult; + + assert(!cir::MissingFeatures::fastMathFlags()); + + // Check whether the real part and the imaginary part of the result are both + // NaN. If so, emit a library call to compute the multiplication instead. + // We check a value against NaN by comparing the value against itself. + mlir::Value resultRealIsNaN = builder.createIsNaN(loc, resultReal); + mlir::Value resultImagIsNaN = builder.createIsNaN(loc, resultImag); + mlir::Value resultRealAndImagAreNaN = + builder.createLogicalAnd(loc, resultRealIsNaN, resultImagIsNaN); + + return builder + .create<cir::TernaryOp>( + loc, resultRealAndImagAreNaN, + [&](mlir::OpBuilder &, mlir::Location) { + mlir::Value libCallResult = buildComplexBinOpLibCall( + pass, builder, &getComplexMulLibCallName, loc, complexTy, + lhsReal, lhsImag, rhsReal, rhsImag); + builder.createYield(loc, libCallResult); + }, + [&](mlir::OpBuilder &, mlir::Location) { + builder.createYield(loc, algebraicResult); + }) + .getResult(); +} + +void LoweringPreparePass::lowerComplexMulOp(cir::ComplexMulOp op) { + cir::CIRBaseBuilderTy builder(getContext()); + builder.setInsertionPointAfter(op); + mlir::Location loc = op.getLoc(); + mlir::TypedValue<cir::ComplexType> lhs = op.getLhs(); + mlir::TypedValue<cir::ComplexType> rhs = op.getRhs(); + mlir::Value lhsReal = builder.createComplexReal(loc, lhs); + mlir::Value lhsImag = builder.createComplexImag(loc, lhs); + mlir::Value rhsReal = builder.createComplexReal(loc, rhs); + mlir::Value rhsImag = builder.createComplexImag(loc, rhs); + mlir::Value loweredResult = lowerComplexMul(*this, builder, loc, op, lhsReal, + lhsImag, rhsReal, rhsImag); + op.replaceAllUsesWith(loweredResult); + op.erase(); +} + void LoweringPreparePass::lowerUnaryOp(cir::UnaryOp op) { mlir::Type ty = op.getType(); if (!mlir::isa<cir::ComplexType>(ty)) @@ -154,7 +298,8 @@ void LoweringPreparePass::lowerUnaryOp(cir::UnaryOp op) { case cir::UnaryOpKind::Plus: case cir::UnaryOpKind::Minus: - llvm_unreachable("Complex unary Plus/Minus NYI"); + resultReal = builder.createUnaryOp(loc, opKind, operandReal); + resultImag = builder.createUnaryOp(loc, opKind, operandImag); break; case cir::UnaryOpKind::Not: @@ 
-172,28 +317,30 @@ void LoweringPreparePass::lowerUnaryOp(cir::UnaryOp op) { static void lowerArrayDtorCtorIntoLoop(cir::CIRBaseBuilderTy &builder, clang::ASTContext *astCtx, mlir::Operation *op, mlir::Type eltTy, - mlir::Value arrayAddr, - uint64_t arrayLen) { + mlir::Value arrayAddr, uint64_t arrayLen, + bool isCtor) { // Generate loop to call into ctor/dtor for every element. mlir::Location loc = op->getLoc(); - // TODO: instead of fixed integer size, create alias for PtrDiffTy and unify - // with CIRGen stuff. + // TODO: instead of getting the size from the AST context, create alias for + // PtrDiffTy and unify with CIRGen stuff. const unsigned sizeTypeSize = astCtx->getTypeSize(astCtx->getSignedSizeType()); - auto ptrDiffTy = - cir::IntType::get(builder.getContext(), sizeTypeSize, /*isSigned=*/false); - mlir::Value numArrayElementsConst = builder.getUnsignedInt(loc, arrayLen, 64); + uint64_t endOffset = isCtor ? arrayLen : arrayLen - 1; + mlir::Value endOffsetVal = + builder.getUnsignedInt(loc, endOffset, sizeTypeSize); - auto begin = builder.create<cir::CastOp>( - loc, eltTy, cir::CastKind::array_to_ptrdecay, arrayAddr); - mlir::Value end = builder.create<cir::PtrStrideOp>(loc, eltTy, begin, - numArrayElementsConst); + auto begin = cir::CastOp::create(builder, loc, eltTy, + cir::CastKind::array_to_ptrdecay, arrayAddr); + mlir::Value end = + cir::PtrStrideOp::create(builder, loc, eltTy, begin, endOffsetVal); + mlir::Value start = isCtor ? begin : end; + mlir::Value stop = isCtor ? end : begin; mlir::Value tmpAddr = builder.createAlloca( loc, /*addr type*/ builder.getPointerTo(eltTy), /*var type*/ eltTy, "__array_idx", builder.getAlignmentAttr(1)); - builder.createStore(loc, begin, tmpAddr); + builder.createStore(loc, start, tmpAddr); cir::DoWhileOp loop = builder.createDoWhile( loc, @@ -202,7 +349,7 @@ static void lowerArrayDtorCtorIntoLoop(cir::CIRBaseBuilderTy &builder, auto currentElement = b.create<cir::LoadOp>(loc, eltTy, tmpAddr); mlir::Type boolTy = cir::BoolType::get(b.getContext()); auto cmp = builder.create<cir::CmpOp>(loc, boolTy, cir::CmpOpKind::ne, - currentElement, end); + currentElement, stop); builder.createCondition(cmp); }, /*bodyBuilder=*/ @@ -213,15 +360,19 @@ static void lowerArrayDtorCtorIntoLoop(cir::CIRBaseBuilderTy &builder, op->walk([&](cir::CallOp c) { ctorCall = c; }); assert(ctorCall && "expected ctor call"); - auto one = builder.create<cir::ConstantOp>( - loc, ptrDiffTy, cir::IntAttr::get(ptrDiffTy, 1)); + // Array elements get constructed in order but destructed in reverse.
+ mlir::Value stride; + if (isCtor) + stride = builder.getUnsignedInt(loc, 1, sizeTypeSize); + else + stride = builder.getSignedInt(loc, -1, sizeTypeSize); - ctorCall->moveAfter(one); + ctorCall->moveBefore(stride.getDefiningOp()); ctorCall->setOperand(0, currentElement); + auto nextElement = cir::PtrStrideOp::create(builder, loc, eltTy, + currentElement, stride); - // Advance pointer and store them to temporary variable - auto nextElement = - builder.create<cir::PtrStrideOp>(loc, eltTy, currentElement, one); + // Store the element pointer to the temporary variable. builder.createStore(loc, nextElement, tmpAddr); builder.createYield(loc); });
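Note: lowerArrayDtorCtorIntoLoop now serves both cir::ArrayCtor and cir::ArrayDtor with a single do-while loop, walking the array forward for construction and backward for destruction. A standalone C++ sketch of the intended iteration order follows; the names are illustrative, not the pass's API, and the pass phrases the same walk with an explicit +/-1 stride and an induction pointer spilled to "__array_idx".

    #include <cstddef>

    // perElement stands in for the ctor/dtor cir::CallOp that the loop body
    // re-targets at each element.
    template <typename T, typename Fn>
    void forEachArrayElement(T *begin, std::size_t n, bool isCtor,
                             Fn perElement) {
      if (n == 0)
        return;
      T *p = isCtor ? begin : begin + n; // for dtors: one past the last element
      do {
        if (!isCtor)
          --p; // destructors run from the last element down to the first
        perElement(p);
        if (isCtor)
          ++p; // constructors run from the first element up
      } while (p != (isCtor ? begin + n : begin));
    }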
@@ -230,6 +381,18 @@ static void lowerArrayDtorCtorIntoLoop(cir::CIRBaseBuilderTy &builder, op->erase(); } +void LoweringPreparePass::lowerArrayDtor(cir::ArrayDtor op) { + CIRBaseBuilderTy builder(getContext()); + builder.setInsertionPointAfter(op.getOperation()); + + mlir::Type eltTy = op->getRegion(0).getArgument(0).getType(); + assert(!cir::MissingFeatures::vlas()); + auto arrayLen = + mlir::cast<cir::ArrayType>(op.getAddr().getType().getPointee()).getSize(); + lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(), arrayLen, + /*isCtor=*/false); +} + void LoweringPreparePass::lowerArrayCtor(cir::ArrayCtor op) { cir::CIRBaseBuilderTy builder(getContext()); builder.setInsertionPointAfter(op.getOperation()); @@ -238,26 +401,33 @@ void LoweringPreparePass::lowerArrayCtor(cir::ArrayCtor op) { assert(!cir::MissingFeatures::vlas()); auto arrayLen = mlir::cast<cir::ArrayType>(op.getAddr().getType().getPointee()).getSize(); - lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(), - arrayLen); + lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(), arrayLen, + /*isCtor=*/true); } void LoweringPreparePass::runOnOp(mlir::Operation *op) { if (auto arrayCtor = dyn_cast<ArrayCtor>(op)) lowerArrayCtor(arrayCtor); + else if (auto arrayDtor = dyn_cast<cir::ArrayDtor>(op)) + lowerArrayDtor(arrayDtor); else if (auto cast = mlir::dyn_cast<cir::CastOp>(op)) lowerCastOp(cast); + else if (auto complexMul = mlir::dyn_cast<cir::ComplexMulOp>(op)) + lowerComplexMulOp(complexMul); else if (auto unary = mlir::dyn_cast<cir::UnaryOp>(op)) lowerUnaryOp(unary); } void LoweringPreparePass::runOnOperation() { mlir::Operation *op = getOperation(); + if (isa<::mlir::ModuleOp>(op)) + mlirModule = cast<::mlir::ModuleOp>(op); llvm::SmallVector<mlir::Operation *> opsToTransform; op->walk([&](mlir::Operation *op) { - if (mlir::isa<cir::ArrayCtor, cir::CastOp, cir::UnaryOp>(op)) + if (mlir::isa<cir::ArrayCtor, cir::ArrayDtor, cir::CastOp, + cir::ComplexMulOp, cir::UnaryOp>(op)) opsToTransform.push_back(op); });
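Note: when the range kind requires full C Annex G semantics, lowerComplexMul above always tries the four-product formula and falls back to a library call (for example __muldc3 for double, per getComplexMulLibCallName) only when both parts of the result are NaN. A scalar C++ sketch of that contract for double; mulLibCall is a hypothetical stand-in for the compiler-rt routine and is only declared here:

    #include <complex>

    // Hypothetical stand-in for compiler-rt's __muldc3 (declaration only).
    std::complex<double> mulLibCall(double a, double b, double c, double d);

    std::complex<double> mulFullRange(std::complex<double> lhs,
                                      std::complex<double> rhs) {
      double a = lhs.real(), b = lhs.imag(), c = rhs.real(), d = rhs.imag();
      // (a+bi) * (c+di) = (ac-bd) + (ad+bc)i
      double re = a * c - b * d;
      double im = a * d + b * c;
      // NaN is the only value that compares unequal to itself, which is the
      // "is NaN" test the lowered code uses.
      if (re != re && im != im)
        return mulLibCall(a, b, c, d); // full-range recovery path
      return {re, im};
    }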
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 3cd7de0..957a51a 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -521,6 +521,32 @@ mlir::LogicalResult CIRToLLVMBitCtzOpLowering::matchAndRewrite( return mlir::LogicalResult::success(); } +mlir::LogicalResult CIRToLLVMBitFfsOpLowering::matchAndRewrite( + cir::BitFfsOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto resTy = getTypeConverter()->convertType(op.getType()); + auto ctz = rewriter.create<mlir::LLVM::CountTrailingZerosOp>( + op.getLoc(), resTy, adaptor.getInput(), /*is_zero_poison=*/true); + + auto one = rewriter.create<mlir::LLVM::ConstantOp>(op.getLoc(), resTy, 1); + auto ctzAddOne = rewriter.create<mlir::LLVM::AddOp>(op.getLoc(), ctz, one); + + auto zeroInputTy = rewriter.create<mlir::LLVM::ConstantOp>( + op.getLoc(), adaptor.getInput().getType(), 0); + auto isZero = rewriter.create<mlir::LLVM::ICmpOp>( + op.getLoc(), + mlir::LLVM::ICmpPredicateAttr::get(rewriter.getContext(), + mlir::LLVM::ICmpPredicate::eq), + adaptor.getInput(), zeroInputTy); + + auto zero = rewriter.create<mlir::LLVM::ConstantOp>(op.getLoc(), resTy, 0); + auto res = rewriter.create<mlir::LLVM::SelectOp>(op.getLoc(), isZero, zero, + ctzAddOne); + rewriter.replaceOp(op, res); + + return mlir::LogicalResult::success(); +} + mlir::LogicalResult CIRToLLVMBitParityOpLowering::matchAndRewrite( cir::BitParityOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const {
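Note: the BitFfsOp lowering above implements the ffs contract, ffs(0) == 0 and otherwise one plus the index of the least significant set bit. Because llvm.cttz is emitted with is_zero_poison = true, the zero case has to be routed around the cttz result with a select. A minimal C++ equivalent, assuming a 32-bit input:

    #include <cstdint>

    int ffsSketch(std::uint32_t x) {
      if (x == 0)
        return 0;                  // the "input was zero" arm of the select
      return __builtin_ctz(x) + 1; // cttz + 1; __builtin_ctz(0) is undefined,
                                   // mirroring the is_zero_poison flag
    }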
+ return op->emitError("Unexpected callee type!"); + } } else { // indirect call assert(!op->getOperands().empty() && "operands list must no be empty for the indirect call"); @@ -1027,6 +1085,12 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite( mlir::ConversionPatternRewriter &rewriter) const { mlir::Attribute attr = op.getValue(); + if (mlir::isa<cir::PoisonAttr>(attr)) { + rewriter.replaceOpWithNewOp<mlir::LLVM::PoisonOp>( + op, getTypeConverter()->convertType(op.getType())); + return mlir::success(); + } + if (mlir::isa<mlir::IntegerType>(op.getType())) { // Verified cir.const operations cannot actually be of these types, but the // lowering pass may generate temporary cir.const operations with these @@ -1166,6 +1230,30 @@ void CIRToLLVMFuncOpLowering::lowerFuncAttributes( } } +mlir::LogicalResult CIRToLLVMFuncOpLowering::matchAndRewriteAlias( + cir::FuncOp op, llvm::StringRef aliasee, mlir::Type ty, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + SmallVector<mlir::NamedAttribute, 4> attributes; + lowerFuncAttributes(op, /*filterArgAndResAttrs=*/false, attributes); + + mlir::Location loc = op.getLoc(); + auto aliasOp = rewriter.replaceOpWithNewOp<mlir::LLVM::AliasOp>( + op, ty, convertLinkage(op.getLinkage()), op.getName(), op.getDsoLocal(), + /*threadLocal=*/false, attributes); + + // Create the alias body + mlir::OpBuilder builder(op.getContext()); + mlir::Block *block = builder.createBlock(&aliasOp.getInitializerRegion()); + builder.setInsertionPointToStart(block); + // The type of AddressOfOp is always a pointer. + assert(!cir::MissingFeatures::addressSpace()); + mlir::Type ptrTy = mlir::LLVM::LLVMPointerType::get(ty.getContext()); + auto addrOp = mlir::LLVM::AddressOfOp::create(builder, loc, ptrTy, aliasee); + mlir::LLVM::ReturnOp::create(builder, loc, addrOp); + + return mlir::success(); +} + mlir::LogicalResult CIRToLLVMFuncOpLowering::matchAndRewrite( cir::FuncOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { @@ -1190,6 +1278,11 @@ mlir::LogicalResult CIRToLLVMFuncOpLowering::matchAndRewrite( resultType ? resultType : mlir::LLVM::LLVMVoidType::get(getContext()), signatureConversion.getConvertedTypes(), /*isVarArg=*/fnType.isVarArg()); + + // If this is an alias, it needs to be lowered to llvm::AliasOp. + if (std::optional<llvm::StringRef> aliasee = op.getAliasee()) + return matchAndRewriteAlias(op, *aliasee, llvmFnTy, adaptor, rewriter); + // LLVMFuncOp expects a single FileLine Location instead of a fused // location. 
mlir::Location loc = op.getLoc(); @@ -1027,6 +1085,12 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite( mlir::ConversionPatternRewriter &rewriter) const { mlir::Attribute attr = op.getValue(); + if (mlir::isa<cir::PoisonAttr>(attr)) { + rewriter.replaceOpWithNewOp<mlir::LLVM::PoisonOp>( + op, getTypeConverter()->convertType(op.getType())); + return mlir::success(); + } + if (mlir::isa<mlir::IntegerType>(op.getType())) { // Verified cir.const operations cannot actually be of these types, but the // lowering pass may generate temporary cir.const operations with these @@ -1166,6 +1230,30 @@ void CIRToLLVMFuncOpLowering::lowerFuncAttributes( } } +mlir::LogicalResult CIRToLLVMFuncOpLowering::matchAndRewriteAlias( + cir::FuncOp op, llvm::StringRef aliasee, mlir::Type ty, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + SmallVector<mlir::NamedAttribute, 4> attributes; + lowerFuncAttributes(op, /*filterArgAndResAttrs=*/false, attributes); + + mlir::Location loc = op.getLoc(); + auto aliasOp = rewriter.replaceOpWithNewOp<mlir::LLVM::AliasOp>( + op, ty, convertLinkage(op.getLinkage()), op.getName(), op.getDsoLocal(), + /*threadLocal=*/false, attributes); + + // Create the alias initializer region. + mlir::OpBuilder builder(op.getContext()); + mlir::Block *block = builder.createBlock(&aliasOp.getInitializerRegion()); + builder.setInsertionPointToStart(block); + // The type of AddressOfOp is always a pointer. + assert(!cir::MissingFeatures::addressSpace()); + mlir::Type ptrTy = mlir::LLVM::LLVMPointerType::get(ty.getContext()); + auto addrOp = mlir::LLVM::AddressOfOp::create(builder, loc, ptrTy, aliasee); + mlir::LLVM::ReturnOp::create(builder, loc, addrOp); + + return mlir::success(); +} + mlir::LogicalResult CIRToLLVMFuncOpLowering::matchAndRewrite( cir::FuncOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { @@ -1190,6 +1278,11 @@ mlir::LogicalResult CIRToLLVMFuncOpLowering::matchAndRewrite( resultType ? resultType : mlir::LLVM::LLVMVoidType::get(getContext()), signatureConversion.getConvertedTypes(), /*isVarArg=*/fnType.isVarArg()); + + // If this is an alias, it needs to be lowered to mlir::LLVM::AliasOp. + if (std::optional<llvm::StringRef> aliasee = op.getAliasee()) + return matchAndRewriteAlias(op, *aliasee, llvmFnTy, adaptor, rewriter); + // LLVMFuncOp expects a single FileLine Location instead of a fused // location. mlir::Location loc = op.getLoc(); @@ -2083,6 +2176,7 @@ void ConvertCIRToLLVMPass::runOnOperation() { CIRToLLVMBitClrsbOpLowering, CIRToLLVMBitClzOpLowering, CIRToLLVMBitCtzOpLowering, + CIRToLLVMBitFfsOpLowering, CIRToLLVMBitParityOpLowering, CIRToLLVMBitPopcountOpLowering, CIRToLLVMBitReverseOpLowering, diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h index 2911ced..f339d43 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h @@ -84,6 +84,16 @@ public: mlir::ConversionPatternRewriter &) const override; }; +class CIRToLLVMBitFfsOpLowering + : public mlir::OpConversionPattern<cir::BitFfsOp> { +public: + using mlir::OpConversionPattern<cir::BitFfsOp>::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::BitFfsOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + class CIRToLLVMBitParityOpLowering : public mlir::OpConversionPattern<cir::BitParityOp> { public: @@ -257,6 +267,11 @@ class CIRToLLVMFuncOpLowering : public mlir::OpConversionPattern<cir::FuncOp> { cir::FuncOp func, bool filterArgAndResAttrs, mlir::SmallVectorImpl<mlir::NamedAttribute> &result) const; + mlir::LogicalResult + matchAndRewriteAlias(cir::FuncOp op, llvm::StringRef aliasee, mlir::Type ty, + OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const; + public: using mlir::OpConversionPattern<cir::FuncOp>::OpConversionPattern;
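Note: taken together, the FuncOp changes let a cir.func that merely aliases another symbol survive lowering: matchAndRewriteAlias emits an llvm.alias whose initializer region returns llvm.mlir.addressof of the aliasee, and the alias branch in rewriteCallOrInvoke handles direct calls to it. One C-level construct that produces such an alias (illustrative; clang also emits aliases internally, e.g. for C++ destructor variants):

    // impl provides the body; api is emitted as an alias pointing at it.
    extern "C" void impl(void) {}
    extern "C" void api(void) __attribute__((alias("impl")));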