//===- LoweringPrepare.cpp - pareparation work for LLVM lowering ----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "PassDetail.h" #include "clang/AST/ASTContext.h" #include "clang/Basic/TargetInfo.h" #include "clang/CIR/Dialect/Builder/CIRBaseBuilder.h" #include "clang/CIR/Dialect/IR/CIRDialect.h" #include "clang/CIR/Dialect/IR/CIROpsEnums.h" #include "clang/CIR/Dialect/Passes.h" #include "clang/CIR/MissingFeatures.h" #include using namespace mlir; using namespace cir; namespace { struct LoweringPreparePass : public LoweringPrepareBase { LoweringPreparePass() = default; void runOnOperation() override; void runOnOp(mlir::Operation *op); void lowerCastOp(cir::CastOp op); void lowerComplexDivOp(cir::ComplexDivOp op); void lowerComplexMulOp(cir::ComplexMulOp op); void lowerUnaryOp(cir::UnaryOp op); void lowerArrayDtor(cir::ArrayDtor op); void lowerArrayCtor(cir::ArrayCtor op); cir::FuncOp buildRuntimeFunction( mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc, cir::FuncType type, cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage); /// /// AST related /// ----------- clang::ASTContext *astCtx; /// Tracks current module. mlir::ModuleOp mlirModule; void setASTContext(clang::ASTContext *c) { astCtx = c; } }; } // namespace cir::FuncOp LoweringPreparePass::buildRuntimeFunction( mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc, cir::FuncType type, cir::GlobalLinkageKind linkage) { cir::FuncOp f = dyn_cast_or_null(SymbolTable::lookupNearestSymbolFrom( mlirModule, StringAttr::get(mlirModule->getContext(), name))); if (!f) { f = builder.create(loc, name, type); f.setLinkageAttr( cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage)); mlir::SymbolTable::setSymbolVisibility( f, mlir::SymbolTable::Visibility::Private); assert(!cir::MissingFeatures::opFuncExtraAttrs()); } return f; } static mlir::Value lowerScalarToComplexCast(mlir::MLIRContext &ctx, cir::CastOp op) { cir::CIRBaseBuilderTy builder(ctx); builder.setInsertionPoint(op); mlir::Value src = op.getSrc(); mlir::Value imag = builder.getNullValue(src.getType(), op.getLoc()); return builder.createComplexCreate(op.getLoc(), src, imag); } static mlir::Value lowerComplexToScalarCast(mlir::MLIRContext &ctx, cir::CastOp op, cir::CastKind elemToBoolKind) { cir::CIRBaseBuilderTy builder(ctx); builder.setInsertionPoint(op); mlir::Value src = op.getSrc(); if (!mlir::isa(op.getType())) return builder.createComplexReal(op.getLoc(), src); // Complex cast to bool: (bool)(a+bi) => (bool)a || (bool)b mlir::Value srcReal = builder.createComplexReal(op.getLoc(), src); mlir::Value srcImag = builder.createComplexImag(op.getLoc(), src); cir::BoolType boolTy = builder.getBoolTy(); mlir::Value srcRealToBool = builder.createCast(op.getLoc(), elemToBoolKind, srcReal, boolTy); mlir::Value srcImagToBool = builder.createCast(op.getLoc(), elemToBoolKind, srcImag, boolTy); return builder.createLogicalOr(op.getLoc(), srcRealToBool, srcImagToBool); } static mlir::Value lowerComplexToComplexCast(mlir::MLIRContext &ctx, cir::CastOp op, cir::CastKind scalarCastKind) { CIRBaseBuilderTy builder(ctx); builder.setInsertionPoint(op); mlir::Value src = op.getSrc(); auto dstComplexElemTy = mlir::cast(op.getType()).getElementType(); mlir::Value srcReal = builder.createComplexReal(op.getLoc(), src); mlir::Value srcImag = builder.createComplexImag(op.getLoc(), src); mlir::Value dstReal = builder.createCast(op.getLoc(), scalarCastKind, srcReal, dstComplexElemTy); mlir::Value dstImag = builder.createCast(op.getLoc(), scalarCastKind, srcImag, dstComplexElemTy); return builder.createComplexCreate(op.getLoc(), dstReal, dstImag); } void LoweringPreparePass::lowerCastOp(cir::CastOp op) { mlir::MLIRContext &ctx = getContext(); mlir::Value loweredValue = [&]() -> mlir::Value { switch (op.getKind()) { case cir::CastKind::float_to_complex: case cir::CastKind::int_to_complex: return lowerScalarToComplexCast(ctx, op); case cir::CastKind::float_complex_to_real: case cir::CastKind::int_complex_to_real: return lowerComplexToScalarCast(ctx, op, op.getKind()); case cir::CastKind::float_complex_to_bool: return lowerComplexToScalarCast(ctx, op, cir::CastKind::float_to_bool); case cir::CastKind::int_complex_to_bool: return lowerComplexToScalarCast(ctx, op, cir::CastKind::int_to_bool); case cir::CastKind::float_complex: return lowerComplexToComplexCast(ctx, op, cir::CastKind::floating); case cir::CastKind::float_complex_to_int_complex: return lowerComplexToComplexCast(ctx, op, cir::CastKind::float_to_int); case cir::CastKind::int_complex: return lowerComplexToComplexCast(ctx, op, cir::CastKind::integral); case cir::CastKind::int_complex_to_float_complex: return lowerComplexToComplexCast(ctx, op, cir::CastKind::int_to_float); default: return nullptr; } }(); if (loweredValue) { op.replaceAllUsesWith(loweredValue); op.erase(); } } static mlir::Value buildComplexBinOpLibCall( LoweringPreparePass &pass, CIRBaseBuilderTy &builder, llvm::StringRef (*libFuncNameGetter)(llvm::APFloat::Semantics), mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag) { cir::FPTypeInterface elementTy = mlir::cast(ty.getElementType()); llvm::StringRef libFuncName = libFuncNameGetter( llvm::APFloat::SemanticsToEnum(elementTy.getFloatSemantics())); llvm::SmallVector libFuncInputTypes(4, elementTy); cir::FuncType libFuncTy = cir::FuncType::get(libFuncInputTypes, ty); // Insert a declaration for the runtime function to be used in Complex // multiplication and division when needed cir::FuncOp libFunc; { mlir::OpBuilder::InsertionGuard ipGuard{builder}; builder.setInsertionPointToStart(pass.mlirModule.getBody()); libFunc = pass.buildRuntimeFunction(builder, libFuncName, loc, libFuncTy); } cir::CallOp call = builder.createCallOp(loc, libFunc, {lhsReal, lhsImag, rhsReal, rhsImag}); return call.getResult(); } static llvm::StringRef getComplexDivLibCallName(llvm::APFloat::Semantics semantics) { switch (semantics) { case llvm::APFloat::S_IEEEhalf: return "__divhc3"; case llvm::APFloat::S_IEEEsingle: return "__divsc3"; case llvm::APFloat::S_IEEEdouble: return "__divdc3"; case llvm::APFloat::S_PPCDoubleDouble: return "__divtc3"; case llvm::APFloat::S_x87DoubleExtended: return "__divxc3"; case llvm::APFloat::S_IEEEquad: return "__divtc3"; default: llvm_unreachable("unsupported floating point type"); } } static mlir::Value buildAlgebraicComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag) { // (a+bi) / (c+di) = ((ac+bd)/(cc+dd)) + ((bc-ad)/(cc+dd))i mlir::Value &a = lhsReal; mlir::Value &b = lhsImag; mlir::Value &c = rhsReal; mlir::Value &d = rhsImag; mlir::Value ac = builder.createBinop(loc, a, cir::BinOpKind::Mul, c); // a*c mlir::Value bd = builder.createBinop(loc, b, cir::BinOpKind::Mul, d); // b*d mlir::Value cc = builder.createBinop(loc, c, cir::BinOpKind::Mul, c); // c*c mlir::Value dd = builder.createBinop(loc, d, cir::BinOpKind::Mul, d); // d*d mlir::Value acbd = builder.createBinop(loc, ac, cir::BinOpKind::Add, bd); // ac+bd mlir::Value ccdd = builder.createBinop(loc, cc, cir::BinOpKind::Add, dd); // cc+dd mlir::Value resultReal = builder.createBinop(loc, acbd, cir::BinOpKind::Div, ccdd); mlir::Value bc = builder.createBinop(loc, b, cir::BinOpKind::Mul, c); // b*c mlir::Value ad = builder.createBinop(loc, a, cir::BinOpKind::Mul, d); // a*d mlir::Value bcad = builder.createBinop(loc, bc, cir::BinOpKind::Sub, ad); // bc-ad mlir::Value resultImag = builder.createBinop(loc, bcad, cir::BinOpKind::Div, ccdd); return builder.createComplexCreate(loc, resultReal, resultImag); } static mlir::Value buildRangeReductionComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag) { // Implements Smith's algorithm for complex division. // SMITH, R. L. Algorithm 116: Complex division. Commun. ACM 5, 8 (1962). // Let: // - lhs := a+bi // - rhs := c+di // - result := lhs / rhs = e+fi // // The algorithm pseudocode looks like follows: // if fabs(c) >= fabs(d): // r := d / c // tmp := c + r*d // e = (a + b*r) / tmp // f = (b - a*r) / tmp // else: // r := c / d // tmp := d + r*c // e = (a*r + b) / tmp // f = (b*r - a) / tmp mlir::Value &a = lhsReal; mlir::Value &b = lhsImag; mlir::Value &c = rhsReal; mlir::Value &d = rhsImag; auto trueBranchBuilder = [&](mlir::OpBuilder &, mlir::Location) { mlir::Value r = builder.createBinop(loc, d, cir::BinOpKind::Div, c); // r := d / c mlir::Value rd = builder.createBinop(loc, r, cir::BinOpKind::Mul, d); // r*d mlir::Value tmp = builder.createBinop(loc, c, cir::BinOpKind::Add, rd); // tmp := c + r*d mlir::Value br = builder.createBinop(loc, b, cir::BinOpKind::Mul, r); // b*r mlir::Value abr = builder.createBinop(loc, a, cir::BinOpKind::Add, br); // a + b*r mlir::Value e = builder.createBinop(loc, abr, cir::BinOpKind::Div, tmp); mlir::Value ar = builder.createBinop(loc, a, cir::BinOpKind::Mul, r); // a*r mlir::Value bar = builder.createBinop(loc, b, cir::BinOpKind::Sub, ar); // b - a*r mlir::Value f = builder.createBinop(loc, bar, cir::BinOpKind::Div, tmp); mlir::Value result = builder.createComplexCreate(loc, e, f); builder.createYield(loc, result); }; auto falseBranchBuilder = [&](mlir::OpBuilder &, mlir::Location) { mlir::Value r = builder.createBinop(loc, c, cir::BinOpKind::Div, d); // r := c / d mlir::Value rc = builder.createBinop(loc, r, cir::BinOpKind::Mul, c); // r*c mlir::Value tmp = builder.createBinop(loc, d, cir::BinOpKind::Add, rc); // tmp := d + r*c mlir::Value ar = builder.createBinop(loc, a, cir::BinOpKind::Mul, r); // a*r mlir::Value arb = builder.createBinop(loc, ar, cir::BinOpKind::Add, b); // a*r + b mlir::Value e = builder.createBinop(loc, arb, cir::BinOpKind::Div, tmp); mlir::Value br = builder.createBinop(loc, b, cir::BinOpKind::Mul, r); // b*r mlir::Value bra = builder.createBinop(loc, br, cir::BinOpKind::Sub, a); // b*r - a mlir::Value f = builder.createBinop(loc, bra, cir::BinOpKind::Div, tmp); mlir::Value result = builder.createComplexCreate(loc, e, f); builder.createYield(loc, result); }; auto cFabs = builder.create(loc, c); auto dFabs = builder.create(loc, d); cir::CmpOp cmpResult = builder.createCompare(loc, cir::CmpOpKind::ge, cFabs, dFabs); auto ternary = builder.create( loc, cmpResult, trueBranchBuilder, falseBranchBuilder); return ternary.getResult(); } static mlir::Type higherPrecisionElementTypeForComplexArithmetic( mlir::MLIRContext &context, clang::ASTContext &cc, CIRBaseBuilderTy &builder, mlir::Type elementType) { auto getHigherPrecisionFPType = [&context](mlir::Type type) -> mlir::Type { if (mlir::isa(type)) return cir::SingleType::get(&context); if (mlir::isa(type) || mlir::isa(type)) return cir::DoubleType::get(&context); if (mlir::isa(type)) return cir::LongDoubleType::get(&context, type); return type; }; auto getFloatTypeSemantics = [&cc](mlir::Type type) -> const llvm::fltSemantics & { const clang::TargetInfo &info = cc.getTargetInfo(); if (mlir::isa(type)) return info.getHalfFormat(); if (mlir::isa(type)) return info.getBFloat16Format(); if (mlir::isa(type)) return info.getFloatFormat(); if (mlir::isa(type)) return info.getDoubleFormat(); if (mlir::isa(type)) { if (cc.getLangOpts().OpenMP && cc.getLangOpts().OpenMPIsTargetDevice) llvm_unreachable("NYI Float type semantics with OpenMP"); return info.getLongDoubleFormat(); } if (mlir::isa(type)) { if (cc.getLangOpts().OpenMP && cc.getLangOpts().OpenMPIsTargetDevice) llvm_unreachable("NYI Float type semantics with OpenMP"); return info.getFloat128Format(); } assert(false && "Unsupported float type semantics"); }; const mlir::Type higherElementType = getHigherPrecisionFPType(elementType); const llvm::fltSemantics &elementTypeSemantics = getFloatTypeSemantics(elementType); const llvm::fltSemantics &higherElementTypeSemantics = getFloatTypeSemantics(higherElementType); // Check that the promoted type can handle the intermediate values without // overflowing. This can be interpreted as: // (SmallerType.LargestFiniteVal * SmallerType.LargestFiniteVal) * 2 <= // LargerType.LargestFiniteVal. // In terms of exponent it gives this formula: // (SmallerType.LargestFiniteVal * SmallerType.LargestFiniteVal // doubles the exponent of SmallerType.LargestFiniteVal) if (llvm::APFloat::semanticsMaxExponent(elementTypeSemantics) * 2 + 1 <= llvm::APFloat::semanticsMaxExponent(higherElementTypeSemantics)) { return higherElementType; } // The intermediate values can't be represented in the promoted type // without overflowing. return {}; } static mlir::Value lowerComplexDiv(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, mlir::Location loc, cir::ComplexDivOp op, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag, mlir::MLIRContext &mlirCx, clang::ASTContext &cc) { cir::ComplexType complexTy = op.getType(); if (mlir::isa(complexTy.getElementType())) { cir::ComplexRangeKind range = op.getRange(); if (range == cir::ComplexRangeKind::Improved) return buildRangeReductionComplexDiv(builder, loc, lhsReal, lhsImag, rhsReal, rhsImag); if (range == cir::ComplexRangeKind::Full) return buildComplexBinOpLibCall(pass, builder, &getComplexDivLibCallName, loc, complexTy, lhsReal, lhsImag, rhsReal, rhsImag); if (range == cir::ComplexRangeKind::Promoted) { mlir::Type originalElementType = complexTy.getElementType(); mlir::Type higherPrecisionElementType = higherPrecisionElementTypeForComplexArithmetic(mlirCx, cc, builder, originalElementType); if (!higherPrecisionElementType) return buildRangeReductionComplexDiv(builder, loc, lhsReal, lhsImag, rhsReal, rhsImag); cir::CastKind floatingCastKind = cir::CastKind::floating; lhsReal = builder.createCast(floatingCastKind, lhsReal, higherPrecisionElementType); lhsImag = builder.createCast(floatingCastKind, lhsImag, higherPrecisionElementType); rhsReal = builder.createCast(floatingCastKind, rhsReal, higherPrecisionElementType); rhsImag = builder.createCast(floatingCastKind, rhsImag, higherPrecisionElementType); mlir::Value algebraicResult = buildAlgebraicComplexDiv( builder, loc, lhsReal, lhsImag, rhsReal, rhsImag); mlir::Value resultReal = builder.createComplexReal(loc, algebraicResult); mlir::Value resultImag = builder.createComplexImag(loc, algebraicResult); mlir::Value finalReal = builder.createCast(floatingCastKind, resultReal, originalElementType); mlir::Value finalImag = builder.createCast(floatingCastKind, resultImag, originalElementType); return builder.createComplexCreate(loc, finalReal, finalImag); } } return buildAlgebraicComplexDiv(builder, loc, lhsReal, lhsImag, rhsReal, rhsImag); } void LoweringPreparePass::lowerComplexDivOp(cir::ComplexDivOp op) { cir::CIRBaseBuilderTy builder(getContext()); builder.setInsertionPointAfter(op); mlir::Location loc = op.getLoc(); mlir::TypedValue lhs = op.getLhs(); mlir::TypedValue rhs = op.getRhs(); mlir::Value lhsReal = builder.createComplexReal(loc, lhs); mlir::Value lhsImag = builder.createComplexImag(loc, lhs); mlir::Value rhsReal = builder.createComplexReal(loc, rhs); mlir::Value rhsImag = builder.createComplexImag(loc, rhs); mlir::Value loweredResult = lowerComplexDiv(*this, builder, loc, op, lhsReal, lhsImag, rhsReal, rhsImag, getContext(), *astCtx); op.replaceAllUsesWith(loweredResult); op.erase(); } static llvm::StringRef getComplexMulLibCallName(llvm::APFloat::Semantics semantics) { switch (semantics) { case llvm::APFloat::S_IEEEhalf: return "__mulhc3"; case llvm::APFloat::S_IEEEsingle: return "__mulsc3"; case llvm::APFloat::S_IEEEdouble: return "__muldc3"; case llvm::APFloat::S_PPCDoubleDouble: return "__multc3"; case llvm::APFloat::S_x87DoubleExtended: return "__mulxc3"; case llvm::APFloat::S_IEEEquad: return "__multc3"; default: llvm_unreachable("unsupported floating point type"); } } static mlir::Value lowerComplexMul(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, mlir::Location loc, cir::ComplexMulOp op, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag) { // (a+bi) * (c+di) = (ac-bd) + (ad+bc)i mlir::Value resultRealLhs = builder.createBinop(loc, lhsReal, cir::BinOpKind::Mul, rhsReal); mlir::Value resultRealRhs = builder.createBinop(loc, lhsImag, cir::BinOpKind::Mul, rhsImag); mlir::Value resultImagLhs = builder.createBinop(loc, lhsReal, cir::BinOpKind::Mul, rhsImag); mlir::Value resultImagRhs = builder.createBinop(loc, lhsImag, cir::BinOpKind::Mul, rhsReal); mlir::Value resultReal = builder.createBinop( loc, resultRealLhs, cir::BinOpKind::Sub, resultRealRhs); mlir::Value resultImag = builder.createBinop( loc, resultImagLhs, cir::BinOpKind::Add, resultImagRhs); mlir::Value algebraicResult = builder.createComplexCreate(loc, resultReal, resultImag); cir::ComplexType complexTy = op.getType(); cir::ComplexRangeKind rangeKind = op.getRange(); if (mlir::isa(complexTy.getElementType()) || rangeKind == cir::ComplexRangeKind::Basic || rangeKind == cir::ComplexRangeKind::Improved || rangeKind == cir::ComplexRangeKind::Promoted) return algebraicResult; assert(!cir::MissingFeatures::fastMathFlags()); // Check whether the real part and the imaginary part of the result are both // NaN. If so, emit a library call to compute the multiplication instead. // We check a value against NaN by comparing the value against itself. mlir::Value resultRealIsNaN = builder.createIsNaN(loc, resultReal); mlir::Value resultImagIsNaN = builder.createIsNaN(loc, resultImag); mlir::Value resultRealAndImagAreNaN = builder.createLogicalAnd(loc, resultRealIsNaN, resultImagIsNaN); return builder .create( loc, resultRealAndImagAreNaN, [&](mlir::OpBuilder &, mlir::Location) { mlir::Value libCallResult = buildComplexBinOpLibCall( pass, builder, &getComplexMulLibCallName, loc, complexTy, lhsReal, lhsImag, rhsReal, rhsImag); builder.createYield(loc, libCallResult); }, [&](mlir::OpBuilder &, mlir::Location) { builder.createYield(loc, algebraicResult); }) .getResult(); } void LoweringPreparePass::lowerComplexMulOp(cir::ComplexMulOp op) { cir::CIRBaseBuilderTy builder(getContext()); builder.setInsertionPointAfter(op); mlir::Location loc = op.getLoc(); mlir::TypedValue lhs = op.getLhs(); mlir::TypedValue rhs = op.getRhs(); mlir::Value lhsReal = builder.createComplexReal(loc, lhs); mlir::Value lhsImag = builder.createComplexImag(loc, lhs); mlir::Value rhsReal = builder.createComplexReal(loc, rhs); mlir::Value rhsImag = builder.createComplexImag(loc, rhs); mlir::Value loweredResult = lowerComplexMul(*this, builder, loc, op, lhsReal, lhsImag, rhsReal, rhsImag); op.replaceAllUsesWith(loweredResult); op.erase(); } void LoweringPreparePass::lowerUnaryOp(cir::UnaryOp op) { mlir::Type ty = op.getType(); if (!mlir::isa(ty)) return; mlir::Location loc = op.getLoc(); cir::UnaryOpKind opKind = op.getKind(); CIRBaseBuilderTy builder(getContext()); builder.setInsertionPointAfter(op); mlir::Value operand = op.getInput(); mlir::Value operandReal = builder.createComplexReal(loc, operand); mlir::Value operandImag = builder.createComplexImag(loc, operand); mlir::Value resultReal; mlir::Value resultImag; switch (opKind) { case cir::UnaryOpKind::Inc: case cir::UnaryOpKind::Dec: resultReal = builder.createUnaryOp(loc, opKind, operandReal); resultImag = operandImag; break; case cir::UnaryOpKind::Plus: case cir::UnaryOpKind::Minus: resultReal = builder.createUnaryOp(loc, opKind, operandReal); resultImag = builder.createUnaryOp(loc, opKind, operandImag); break; case cir::UnaryOpKind::Not: resultReal = operandReal; resultImag = builder.createUnaryOp(loc, cir::UnaryOpKind::Minus, operandImag); break; } mlir::Value result = builder.createComplexCreate(loc, resultReal, resultImag); op.replaceAllUsesWith(result); op.erase(); } static void lowerArrayDtorCtorIntoLoop(cir::CIRBaseBuilderTy &builder, clang::ASTContext *astCtx, mlir::Operation *op, mlir::Type eltTy, mlir::Value arrayAddr, uint64_t arrayLen, bool isCtor) { // Generate loop to call into ctor/dtor for every element. mlir::Location loc = op->getLoc(); // TODO: instead of getting the size from the AST context, create alias for // PtrDiffTy and unify with CIRGen stuff. const unsigned sizeTypeSize = astCtx->getTypeSize(astCtx->getSignedSizeType()); uint64_t endOffset = isCtor ? arrayLen : arrayLen - 1; mlir::Value endOffsetVal = builder.getUnsignedInt(loc, endOffset, sizeTypeSize); auto begin = cir::CastOp::create(builder, loc, eltTy, cir::CastKind::array_to_ptrdecay, arrayAddr); mlir::Value end = cir::PtrStrideOp::create(builder, loc, eltTy, begin, endOffsetVal); mlir::Value start = isCtor ? begin : end; mlir::Value stop = isCtor ? end : begin; mlir::Value tmpAddr = builder.createAlloca( loc, /*addr type*/ builder.getPointerTo(eltTy), /*var type*/ eltTy, "__array_idx", builder.getAlignmentAttr(1)); builder.createStore(loc, start, tmpAddr); cir::DoWhileOp loop = builder.createDoWhile( loc, /*condBuilder=*/ [&](mlir::OpBuilder &b, mlir::Location loc) { auto currentElement = b.create(loc, eltTy, tmpAddr); mlir::Type boolTy = cir::BoolType::get(b.getContext()); auto cmp = builder.create(loc, boolTy, cir::CmpOpKind::ne, currentElement, stop); builder.createCondition(cmp); }, /*bodyBuilder=*/ [&](mlir::OpBuilder &b, mlir::Location loc) { auto currentElement = b.create(loc, eltTy, tmpAddr); cir::CallOp ctorCall; op->walk([&](cir::CallOp c) { ctorCall = c; }); assert(ctorCall && "expected ctor call"); // Array elements get constructed in order but destructed in reverse. mlir::Value stride; if (isCtor) stride = builder.getUnsignedInt(loc, 1, sizeTypeSize); else stride = builder.getSignedInt(loc, -1, sizeTypeSize); ctorCall->moveBefore(stride.getDefiningOp()); ctorCall->setOperand(0, currentElement); auto nextElement = cir::PtrStrideOp::create(builder, loc, eltTy, currentElement, stride); // Store the element pointer to the temporary variable builder.createStore(loc, nextElement, tmpAddr); builder.createYield(loc); }); op->replaceAllUsesWith(loop); op->erase(); } void LoweringPreparePass::lowerArrayDtor(cir::ArrayDtor op) { CIRBaseBuilderTy builder(getContext()); builder.setInsertionPointAfter(op.getOperation()); mlir::Type eltTy = op->getRegion(0).getArgument(0).getType(); assert(!cir::MissingFeatures::vlas()); auto arrayLen = mlir::cast(op.getAddr().getType().getPointee()).getSize(); lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(), arrayLen, false); } void LoweringPreparePass::lowerArrayCtor(cir::ArrayCtor op) { cir::CIRBaseBuilderTy builder(getContext()); builder.setInsertionPointAfter(op.getOperation()); mlir::Type eltTy = op->getRegion(0).getArgument(0).getType(); assert(!cir::MissingFeatures::vlas()); auto arrayLen = mlir::cast(op.getAddr().getType().getPointee()).getSize(); lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(), arrayLen, true); } void LoweringPreparePass::runOnOp(mlir::Operation *op) { if (auto arrayCtor = dyn_cast(op)) lowerArrayCtor(arrayCtor); else if (auto arrayDtor = dyn_cast(op)) lowerArrayDtor(arrayDtor); else if (auto cast = mlir::dyn_cast(op)) lowerCastOp(cast); else if (auto complexDiv = mlir::dyn_cast(op)) lowerComplexDivOp(complexDiv); else if (auto complexMul = mlir::dyn_cast(op)) lowerComplexMulOp(complexMul); else if (auto unary = mlir::dyn_cast(op)) lowerUnaryOp(unary); } void LoweringPreparePass::runOnOperation() { mlir::Operation *op = getOperation(); if (isa<::mlir::ModuleOp>(op)) mlirModule = cast<::mlir::ModuleOp>(op); llvm::SmallVector opsToTransform; op->walk([&](mlir::Operation *op) { if (mlir::isa(op)) opsToTransform.push_back(op); }); for (mlir::Operation *o : opsToTransform) runOnOp(o); } std::unique_ptr mlir::createLoweringPreparePass() { return std::make_unique(); } std::unique_ptr mlir::createLoweringPreparePass(clang::ASTContext *astCtx) { auto pass = std::make_unique(); pass->setASTContext(astCtx); return std::move(pass); }