diff options
Diffstat (limited to 'mlir/lib/Conversion')
90 files changed, 5386 insertions, 4913 deletions
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index ef35ee2..64720bf 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -50,20 +50,20 @@ static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, if (i32 == valTy) return val; return valTy.getWidth() > 32 - ? Value(rewriter.create<LLVM::TruncOp>(loc, i32, val)) - : Value(rewriter.create<LLVM::ZExtOp>(loc, i32, val)); + ? Value(LLVM::TruncOp::create(rewriter, loc, i32, val)) + : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val)); } static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value) { Type i32 = rewriter.getI32Type(); - return rewriter.create<LLVM::ConstantOp>(loc, i32, value); + return LLVM::ConstantOp::create(rewriter, loc, i32, value); } static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc, bool value) { Type llvmI1 = rewriter.getI1Type(); - return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value); + return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value); } /// Returns the linear index used to access an element in the memref. @@ -78,11 +78,11 @@ static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, ShapedType::isDynamic(stride) ? convertUnsignedToI32(rewriter, loc, memRefDescriptor.stride(rewriter, loc, i)) - : rewriter.create<LLVM::ConstantOp>(loc, i32, stride); - increment = rewriter.create<LLVM::MulOp>(loc, increment, strideValue); + : LLVM::ConstantOp::create(rewriter, loc, i32, stride); + increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue); } - index = - index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment; + index = index ? LLVM::AddOp::create(rewriter, loc, index, increment) + : increment; } return index ? index : createI32Constant(rewriter, loc, 0); } @@ -110,14 +110,14 @@ static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc, for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) { Value size = memrefDescriptor.size(rewriter, loc, i); Value stride = memrefDescriptor.stride(rewriter, loc, i); - Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride); + Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride); maxIndex = maxIndex - ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim) + ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim) : maxThisDim; } Value maxIndexI32 = convertUnsignedToI32(rewriter, loc, maxIndex); Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth); - return rewriter.create<LLVM::MulOp>(loc, maxIndexI32, byteWidthConst); + return LLVM::MulOp::create(rewriter, loc, maxIndexI32, byteWidthConst); } static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, @@ -132,14 +132,14 @@ static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, Value stride; if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) { Value cacheStrideZext = - rewriter.create<LLVM::ZExtOp>(loc, i16, cacheSwizzleStride); - Value swizzleBit = rewriter.create<LLVM::ConstantOp>( - loc, i16, rewriter.getI16IntegerAttr(1 << 14)); - stride = rewriter.create<LLVM::OrOp>(loc, cacheStrideZext, swizzleBit, - /*isDisjoint=*/true); + LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride); + Value swizzleBit = LLVM::ConstantOp::create( + rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14)); + stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit, + /*isDisjoint=*/true); } else { - stride = rewriter.create<LLVM::ConstantOp>(loc, i16, - rewriter.getI16IntegerAttr(0)); + stride = LLVM::ConstantOp::create(rewriter, loc, i16, + rewriter.getI16IntegerAttr(0)); } // Get the number of elements. // Flag word: @@ -209,20 +209,21 @@ struct FatRawBufferCastLowering : descriptor.alignedPtr(rewriter, loc); Value offset = adaptor.getResetOffset() - ? rewriter.create<LLVM::ConstantOp>( - loc, getIndexType(), rewriter.getIndexAttr(0)) + ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(0)) : descriptor.offset(rewriter, loc); bool hasSizes = memrefType.getRank() > 0; // No need to unpack() and pack() all the individual sizes and strides, // so we'll just extract the arrays. - Value sizes = hasSizes ? rewriter.create<LLVM::ExtractValueOp>( - loc, descriptor, kSizePosInMemRefDescriptor) - : Value{}; - Value strides = hasSizes - ? rewriter.create<LLVM::ExtractValueOp>( - loc, descriptor, kStridePosInMemRefDescriptor) - : Value{}; + Value sizes = hasSizes + ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor, + kSizePosInMemRefDescriptor) + : Value{}; + Value strides = + hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor, + kStridePosInMemRefDescriptor) + : Value{}; Value fatPtr = makeBufferRsrc( rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(), @@ -231,17 +232,17 @@ struct FatRawBufferCastLowering Value result = MemRefDescriptor::poison( rewriter, loc, getTypeConverter()->convertType(op.getResult().getType())); - result = rewriter.create<LLVM::InsertValueOp>( - loc, result, fatPtr, kAllocatedPtrPosInMemRefDescriptor); - result = rewriter.create<LLVM::InsertValueOp>( - loc, result, fatPtr, kAlignedPtrPosInMemRefDescriptor); - result = rewriter.create<LLVM::InsertValueOp>(loc, result, offset, - kOffsetPosInMemRefDescriptor); + SmallVector<int64_t> pos{kAllocatedPtrPosInMemRefDescriptor}; + result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos); + result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, + kAlignedPtrPosInMemRefDescriptor); + result = LLVM::InsertValueOp::create(rewriter, loc, result, offset, + kOffsetPosInMemRefDescriptor); if (hasSizes) { - result = rewriter.create<LLVM::InsertValueOp>(loc, result, sizes, - kSizePosInMemRefDescriptor); - result = rewriter.create<LLVM::InsertValueOp>( - loc, result, strides, kStridePosInMemRefDescriptor); + result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes, + kSizePosInMemRefDescriptor); + result = LLVM::InsertValueOp::create(rewriter, loc, result, strides, + kStridePosInMemRefDescriptor); } rewriter.replaceOp(op, result); return success(); @@ -342,8 +343,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> { SmallVector<Value, 6> args; if (storeData) { if (llvmBufferValType != llvmWantedDataType) { - Value castForStore = - rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData); + Value castForStore = LLVM::BitcastOp::create( + rewriter, loc, llvmBufferValType, storeData); args.push_back(castForStore); } else { args.push_back(storeData); @@ -352,8 +353,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> { if (atomicCmpData) { if (llvmBufferValType != llvmWantedDataType) { - Value castForCmp = rewriter.create<LLVM::BitcastOp>( - loc, llvmBufferValType, atomicCmpData); + Value castForCmp = LLVM::BitcastOp::create( + rewriter, loc, llvmBufferValType, atomicCmpData); args.push_back(castForCmp); } else { args.push_back(atomicCmpData); @@ -382,18 +383,18 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> { if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset(); indexOffset && *indexOffset > 0) { Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset); - voffset = - voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst) - : extraOffsetConst; + voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset, + extraOffsetConst) + : extraOffsetConst; } - voffset = rewriter.create<LLVM::MulOp>(loc, voffset, byteWidthConst); + voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst); args.push_back(voffset); // SGPR offset. Value sgprOffset = adaptor.getSgprOffset(); if (!sgprOffset) sgprOffset = createI32Constant(rewriter, loc, 0); - sgprOffset = rewriter.create<LLVM::MulOp>(loc, sgprOffset, byteWidthConst); + sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst); args.push_back(sgprOffset); // bit 0: GLC = 0 (atomics drop value, less coherency) @@ -403,13 +404,13 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> { llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(), llvmBufferValType); - Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args, - ArrayRef<NamedAttribute>()); + Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args, + ArrayRef<NamedAttribute>()); if (lowered->getNumResults() == 1) { Value replacement = lowered->getResult(0); if (llvmBufferValType != llvmWantedDataType) { - replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType, - replacement); + replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType, + replacement); } rewriter.replaceOp(gpuOp, replacement); } else { @@ -419,6 +420,112 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> { } }; +// TODO: AMDGPU backend already have all this bitpacking logic, we should move +// it to some common place. +/// Vmcnt, Expcnt and Lgkmcnt are decoded as follows: +/// Vmcnt = Waitcnt[3:0] (pre-gfx9) +/// Vmcnt = Waitcnt[15:14,3:0] (gfx9,10) +/// Vmcnt = Waitcnt[15:10] (gfx11) +/// Expcnt = Waitcnt[6:4] (pre-gfx11) +/// Expcnt = Waitcnt[2:0] (gfx11) +/// Lgkmcnt = Waitcnt[11:8] (pre-gfx10) +/// Lgkmcnt = Waitcnt[13:8] (gfx10) +/// Lgkmcnt = Waitcnt[9:4] (gfx11) +static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt, + unsigned expcnt, unsigned lgkmcnt) { + if (chipset.majorVersion < 9) { + vmcnt = std::min(15u, vmcnt); + expcnt = std::min(7u, expcnt); + lgkmcnt = std::min(15u, lgkmcnt); + return vmcnt | (expcnt << 4) | (lgkmcnt << 8); + } + if (chipset.majorVersion == 9) { + vmcnt = std::min(63u, vmcnt); + expcnt = std::min(7u, expcnt); + lgkmcnt = std::min(15u, lgkmcnt); + unsigned lowBits = vmcnt & 0xF; + unsigned highBits = (vmcnt >> 4) << 14; + unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8); + return lowBits | highBits | otherCnts; + } + if (chipset.majorVersion == 10) { + vmcnt = std::min(63u, vmcnt); + expcnt = std::min(7u, expcnt); + lgkmcnt = std::min(63u, lgkmcnt); + unsigned lowBits = vmcnt & 0xF; + unsigned highBits = (vmcnt >> 4) << 14; + unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8); + return lowBits | highBits | otherCnts; + } + if (chipset.majorVersion == 11) { + vmcnt = std::min(63u, vmcnt); + expcnt = std::min(7u, expcnt); + lgkmcnt = std::min(63u, lgkmcnt); + return (vmcnt << 10) | expcnt | (lgkmcnt << 4); + } + return failure(); +} + +struct MemoryCounterWaitOpLowering + : public ConvertOpToLLVMPattern<MemoryCounterWaitOp> { + MemoryCounterWaitOpLowering(const LLVMTypeConverter &converter, + Chipset chipset) + : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter), + chipset(chipset) {} + + Chipset chipset; + + LogicalResult + matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (chipset.majorVersion >= 12) { + Location loc = op.getLoc(); + if (std::optional<int> ds = adaptor.getDs()) + ROCDL::WaitDscntOp::create(rewriter, loc, *ds); + + if (std::optional<int> load = adaptor.getLoad()) + ROCDL::WaitLoadcntOp::create(rewriter, loc, *load); + + if (std::optional<int> store = adaptor.getStore()) + ROCDL::WaitStorecntOp::create(rewriter, loc, *store); + + if (std::optional<int> exp = adaptor.getExp()) + ROCDL::WaitExpcntOp::create(rewriter, loc, *exp); + + rewriter.eraseOp(op); + return success(); + } + + auto getVal = [](Attribute attr) -> unsigned { + if (attr) + return cast<IntegerAttr>(attr).getInt(); + + // This value will be clamped to the maximum value for the chipset. + return 1024; + }; + unsigned ds = getVal(adaptor.getDsAttr()); + unsigned exp = getVal(adaptor.getExpAttr()); + + unsigned vmcnt = 1024; + Attribute load = adaptor.getLoadAttr(); + Attribute store = adaptor.getStoreAttr(); + if (load && store) { + vmcnt = getVal(load) + getVal(store); + } else if (load) { + vmcnt = getVal(load); + } else if (store) { + vmcnt = getVal(store); + } + + FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds); + if (failed(waitcnt)) + return op.emitOpError("unsupported chipset"); + + rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt); + return success(); + } +}; + struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> { LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {} @@ -465,12 +572,12 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> { << chipset.majorVersion; Location loc = op->getLoc(); - rewriter.create<ROCDL::SWaitcntOp>(loc, ldsOnlyBits); + ROCDL::SWaitcntOp::create(rewriter, loc, ldsOnlyBits); rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op); } else { Location loc = op->getLoc(); - rewriter.create<ROCDL::WaitDscntOp>(loc, 0); - rewriter.create<ROCDL::BarrierSignalOp>(loc, -1); + ROCDL::WaitDscntOp::create(rewriter, loc, 0); + ROCDL::BarrierSignalOp::create(rewriter, loc, -1); rewriter.replaceOpWithNewOp<ROCDL::BarrierWaitOp>(op, -1); } @@ -516,19 +623,21 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter, Type inputType = input.getType(); if (auto vectorType = dyn_cast<VectorType>(inputType)) { if (vectorType.getElementType().isBF16() && !allowBf16) - return rewriter.create<LLVM::BitcastOp>( - loc, vectorType.clone(rewriter.getI16Type()), input); + return LLVM::BitcastOp::create( + rewriter, loc, vectorType.clone(rewriter.getI16Type()), input); if (vectorType.getElementType().isInteger(8) && vectorType.getNumElements() <= 8) - return rewriter.create<LLVM::BitcastOp>( - loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input); + return LLVM::BitcastOp::create( + rewriter, loc, + rewriter.getIntegerType(vectorType.getNumElements() * 8), input); if (isa<IntegerType>(vectorType.getElementType()) && vectorType.getElementTypeBitWidth() <= 8) { int64_t numWords = llvm::divideCeil( vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32); - return rewriter.create<LLVM::BitcastOp>( - loc, VectorType::get(numWords, rewriter.getI32Type()), input); + return LLVM::BitcastOp::create( + rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()), + input); } } return input; @@ -549,8 +658,8 @@ static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter, Type inputType = input.getType(); Type outputType = rewriter.getI32Type(); if (auto intType = dyn_cast<IntegerType>(inputType)) - return rewriter.create<LLVM::ZExtOp>(loc, outputType, input); - return rewriter.create<LLVM::BitcastOp>(loc, outputType, input); + return LLVM::ZExtOp::create(rewriter, loc, outputType, input); + return LLVM::BitcastOp::create(rewriter, loc, outputType, input); } /// Push an input operand. If it is a float type, nothing to do. If it is @@ -576,8 +685,8 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Type elemType = vectorType.getElementType(); if (elemType.isBF16()) - llvmInput = rewriter.create<LLVM::BitcastOp>( - loc, vectorType.clone(rewriter.getI16Type()), llvmInput); + llvmInput = LLVM::BitcastOp::create( + rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput); if (elemType.getIntOrFloatBitWidth() > 8) { operands.push_back(llvmInput); return; @@ -613,7 +722,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments. // Add in the zeros here. if (numBits < 32) - castInput = rewriter.create<LLVM::ZExtOp>(loc, i32, castInput); + castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput); operands.push_back(castInput); } @@ -633,8 +742,8 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, auto vectorType = dyn_cast<VectorType>(inputType); Type elemType = vectorType.getElementType(); if (elemType.isBF16()) - output = rewriter.create<LLVM::BitcastOp>( - loc, vectorType.clone(rewriter.getI16Type()), output); + output = LLVM::BitcastOp::create( + rewriter, loc, vectorType.clone(rewriter.getI16Type()), output); operands.push_back(output); if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) { operands.push_back(createI1Constant(rewriter, loc, subwordOffset)); @@ -992,7 +1101,7 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> { }; Value lowered = rewriter.create(loweredOp)->getResult(0); if (outType != intrinsicOutType) - lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered); + lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered); rewriter.replaceOp(op, lowered); return success(); } @@ -1092,8 +1201,8 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> { Operation *maybeCastBack = lowered; if (rawOutType != outType) - maybeCastBack = - rewriter.create<LLVM::BitcastOp>(loc, outType, lowered->getResult(0)); + maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType, + lowered->getResult(0)); rewriter.replaceOp(op, maybeCastBack->getResults()); return success(); @@ -1143,22 +1252,22 @@ struct TransposeLoadOpLowering switch (elementTypeSize) { case 4: { assert(numElements == 16); - auto rocdlOp = - rewriter.create<ROCDL::ds_read_tr4_b64>(loc, rocdlResultType, srcPtr); + auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc, + rocdlResultType, srcPtr); rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp); break; } case 6: { assert(numElements == 16); - auto rocdlOp = - rewriter.create<ROCDL::ds_read_tr6_b96>(loc, rocdlResultType, srcPtr); + auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc, + rocdlResultType, srcPtr); rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp); break; } case 8: { assert(numElements == 8); - auto rocdlOp = - rewriter.create<ROCDL::ds_read_tr8_b64>(loc, rocdlResultType, srcPtr); + auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc, + rocdlResultType, srcPtr); rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp); break; } @@ -1316,21 +1425,21 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( Type sourceElemType = getElementTypeOrSelf(op.getSource()); // Extend to a v4i8 if (!sourceVecType || sourceVecType.getNumElements() < 4) { - Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8); + Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8); if (!sourceVecType) { - longVec = rewriter.create<LLVM::InsertElementOp>( - loc, longVec, source, createI32Constant(rewriter, loc, 0)); + longVec = LLVM::InsertElementOp::create( + rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0)); } else { for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) { Value idx = createI32Constant(rewriter, loc, i); - Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx); + Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx); longVec = - rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx); + LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx); } } source = longVec; } - Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source); + Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source); if (resultVecType) { if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) { rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source, @@ -1382,21 +1491,21 @@ LogicalResult ScaledExtPackedOpLowering::matchAndRewrite( // Extend to a packedVectorType if (sourceVecType.getNumElements() < packedVecType.getNumElements()) { - Value longVec = rewriter.create<LLVM::ZeroOp>(loc, packedVecType); + Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType); if (!sourceVecType) { - longVec = rewriter.create<LLVM::InsertElementOp>( - loc, longVec, source, createI32Constant(rewriter, loc, 0)); + longVec = LLVM::InsertElementOp::create( + rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0)); } else { for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) { Value idx = createI32Constant(rewriter, loc, i); - Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx); + Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx); longVec = - rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx); + LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx); } } source = longVec; } - Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source); + Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source); if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32()) rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>( @@ -1454,54 +1563,57 @@ LogicalResult PackedScaledTruncOpLowering::matchAndRewrite( Value scale = adaptor.getScale(); Value existing = adaptor.getExisting(); if (existing) - existing = rewriter.create<LLVM::BitcastOp>(loc, intResultType, existing); + existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing); else - existing = rewriter.create<LLVM::ZeroOp>(loc, intResultType); + existing = LLVM::ZeroOp::create(rewriter, loc, intResultType); if (sourceVecType.getNumElements() < 2) { Value c0 = createI32Constant(rewriter, loc, 0); - Value elem0 = rewriter.create<LLVM::ExtractElementOp>(loc, source, c0); + Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0); VectorType v2 = VectorType::get(2, sourceElemType); - source = rewriter.create<LLVM::ZeroOp>(loc, v2); - source = rewriter.create<LLVM::InsertElementOp>(loc, source, elem0, c0); + source = LLVM::ZeroOp::create(rewriter, loc, v2); + source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0); } Value sourceA, sourceB; if (sourceElemType.isF32()) { Value c0 = createI32Constant(rewriter, loc, 0); Value c1 = createI32Constant(rewriter, loc, 1); - sourceA = rewriter.create<LLVM::ExtractElementOp>(loc, source, c0); - sourceB = rewriter.create<LLVM::ExtractElementOp>(loc, source, c1); + sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0); + sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1); } Value result; if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType)) - result = rewriter.create<ROCDL::CvtScaleF32PkBf8F32Op>( - loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex()); + result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType, + existing, sourceA, sourceB, + scale, op.getIndex()); else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType)) - result = rewriter.create<ROCDL::CvtScaleF32PkBf8F16Op>( - loc, intResultType, existing, source, scale, op.getIndex()); + result = ROCDL::CvtScaleF32PkBf8F16Op::create( + rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType)) - result = rewriter.create<ROCDL::CvtScaleF32PkBf8Bf16Op>( - loc, intResultType, existing, source, scale, op.getIndex()); + result = ROCDL::CvtScaleF32PkBf8Bf16Op::create( + rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType)) - result = rewriter.create<ROCDL::CvtScaleF32PkFp8F32Op>( - loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex()); + result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType, + existing, sourceA, sourceB, + scale, op.getIndex()); else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType)) - result = rewriter.create<ROCDL::CvtScaleF32PkFp8F16Op>( - loc, intResultType, existing, source, scale, op.getIndex()); + result = ROCDL::CvtScaleF32PkFp8F16Op::create( + rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType)) - result = rewriter.create<ROCDL::CvtScaleF32PkFp8Bf16Op>( - loc, intResultType, existing, source, scale, op.getIndex()); + result = ROCDL::CvtScaleF32PkFp8Bf16Op::create( + rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType)) - result = rewriter.create<ROCDL::CvtScaleF32PkFp4F32Op>( - loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex()); + result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType, + existing, sourceA, sourceB, + scale, op.getIndex()); else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType)) - result = rewriter.create<ROCDL::CvtScaleF32PkFp4F16Op>( - loc, intResultType, existing, source, scale, op.getIndex()); + result = ROCDL::CvtScaleF32PkFp4F16Op::create( + rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType)) - result = rewriter.create<ROCDL::CvtScaleF32PkFp4Bf16Op>( - loc, intResultType, existing, source, scale, op.getIndex()); + result = ROCDL::CvtScaleF32PkFp4Bf16Op::create( + rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else return failure(); @@ -1526,20 +1638,20 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite( Value sourceA = adaptor.getSourceA(); Value sourceB = adaptor.getSourceB(); if (!sourceB) - sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType()); + sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType()); Value existing = adaptor.getExisting(); if (existing) - existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing); + existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing); else - existing = rewriter.create<LLVM::UndefOp>(loc, i32); + existing = LLVM::UndefOp::create(rewriter, loc, i32); Value result; if (typeIsExpectedBf8ForChipset(chipset, resultElemType)) - result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB, - existing, op.getWordIndex()); + result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB, + existing, op.getWordIndex()); else if (typeIsExpectedFp8ForChipset(chipset, resultElemType)) - result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB, - existing, op.getWordIndex()); + result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB, + existing, op.getWordIndex()); result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>( op, getTypeConverter()->convertType(resultType), result); @@ -1563,17 +1675,17 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite( Value stoch = adaptor.getStochiasticParam(); Value existing = adaptor.getExisting(); if (existing) - existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing); + existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing); else - existing = rewriter.create<LLVM::UndefOp>(loc, i32); + existing = LLVM::UndefOp::create(rewriter, loc, i32); Value result; if (typeIsExpectedBf8ForChipset(chipset, resultElemType)) - result = rewriter.create<ROCDL::CvtSrBf8F32Op>( - loc, i32, source, stoch, existing, op.getStoreIndex()); + result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch, + existing, op.getStoreIndex()); else if (typeIsExpectedFp8ForChipset(chipset, resultElemType)) - result = rewriter.create<ROCDL::CvtSrFp8F32Op>( - loc, i32, source, stoch, existing, op.getStoreIndex()); + result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch, + existing, op.getStoreIndex()); result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>( op, getTypeConverter()->convertType(resultType), result); @@ -1617,14 +1729,15 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> { if (operandType.getIntOrFloatBitWidth() <= 16) { if (llvm::isa<FloatType>(operandType)) { operand = - rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand); + LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand); } auto llvmVecType = typeConverter->convertType(mlir::VectorType::get( 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType)); - Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType); - operand = rewriter.create<LLVM::InsertElementOp>( - loc, undefVec, operand, createI32Constant(rewriter, loc, 0)); - operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand); + Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType); + operand = + LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand, + createI32Constant(rewriter, loc, 0)); + operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand); } return operand; }; @@ -1711,14 +1824,15 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> { bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue(); // create a ROCDL_DPPMovOp instruction with the appropriate attributes - auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>( - loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl); + auto dppMovOp = + ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl, + rowMask, bankMask, boundCtrl); Value result = dppMovOp.getRes(); if (srcType.getIntOrFloatBitWidth() < 32) { - result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result); + result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result); if (!llvm::isa<IntegerType>(srcType)) { - result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result); + result = LLVM::BitcastOp::create(rewriter, loc, srcType, result); } } @@ -1752,7 +1866,7 @@ struct AMDGPUSwizzleBitModeLowering SmallVector<Value> swizzled; for (Value v : decomposed) { Value res = - rewriter.create<ROCDL::DsSwizzleOp>(loc, v.getType(), v, maskValue); + ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue); swizzled.emplace_back(res); } @@ -1825,9 +1939,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, ROCDL::RawPtrBufferAtomicUminOp>, RawBufferOpLowering<RawBufferAtomicCmpswapOp, ROCDL::RawPtrBufferAtomicCmpSwap>, - AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering, - MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering, - ExtPackedFp8OpLowering, ScaledExtPackedOpLowering, + AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering, + SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering, + WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering, PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering, GatherToLDSOpLowering, TransposeLoadOpLowering>(converter, chipset); diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp index 3b143ca..3b148f9 100644 --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -50,9 +50,9 @@ static Value buildMinMaxReductionSeq(Location loc, Value value = *valueIt++; for (; valueIt != values.end(); ++valueIt) { if (predicate == arith::CmpIPredicate::sgt) - value = builder.create<arith::MaxSIOp>(loc, value, *valueIt); + value = arith::MaxSIOp::create(builder, loc, value, *valueIt); else - value = builder.create<arith::MinSIOp>(loc, value, *valueIt); + value = arith::MinSIOp::create(builder, loc, value, *valueIt); } return value; @@ -154,9 +154,9 @@ public: Value lowerBound = lowerAffineLowerBound(op, rewriter); Value upperBound = lowerAffineUpperBound(op, rewriter); Value step = - rewriter.create<arith::ConstantIndexOp>(loc, op.getStepAsInt()); - auto scfForOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, - step, op.getInits()); + arith::ConstantIndexOp::create(rewriter, loc, op.getStepAsInt()); + auto scfForOp = scf::ForOp::create(rewriter, loc, lowerBound, upperBound, + step, op.getInits()); rewriter.eraseBlock(scfForOp.getBody()); rewriter.inlineRegionBefore(op.getRegion(), scfForOp.getRegion(), scfForOp.getRegion().end()); @@ -197,7 +197,7 @@ public: } steps.reserve(op.getSteps().size()); for (int64_t step : op.getSteps()) - steps.push_back(rewriter.create<arith::ConstantIndexOp>(loc, step)); + steps.push_back(arith::ConstantIndexOp::create(rewriter, loc, step)); // Get the terminator op. auto affineParOpTerminator = @@ -205,9 +205,9 @@ public: scf::ParallelOp parOp; if (op.getResults().empty()) { // Case with no reduction operations/return values. - parOp = rewriter.create<scf::ParallelOp>(loc, lowerBoundTuple, - upperBoundTuple, steps, - /*bodyBuilderFn=*/nullptr); + parOp = scf::ParallelOp::create(rewriter, loc, lowerBoundTuple, + upperBoundTuple, steps, + /*bodyBuilderFn=*/nullptr); rewriter.eraseBlock(parOp.getBody()); rewriter.inlineRegionBefore(op.getRegion(), parOp.getRegion(), parOp.getRegion().end()); @@ -233,9 +233,9 @@ public: identityVals.push_back( arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc)); } - parOp = rewriter.create<scf::ParallelOp>( - loc, lowerBoundTuple, upperBoundTuple, steps, identityVals, - /*bodyBuilderFn=*/nullptr); + parOp = scf::ParallelOp::create(rewriter, loc, lowerBoundTuple, + upperBoundTuple, steps, identityVals, + /*bodyBuilderFn=*/nullptr); // Copy the body of the affine.parallel op. rewriter.eraseBlock(parOp.getBody()); @@ -261,7 +261,7 @@ public: Value reductionResult = arith::getReductionOp( reductionOpValue, rewriter, loc, reductionBody.getArgument(0), reductionBody.getArgument(1)); - rewriter.create<scf::ReduceReturnOp>(loc, reductionResult); + scf::ReduceReturnOp::create(rewriter, loc, reductionResult); } rewriter.replaceOp(op, parOp.getResults()); return success(); @@ -278,7 +278,7 @@ public: // Now we just have to handle the condition logic. auto integerSet = op.getIntegerSet(); - Value zeroConstant = rewriter.create<arith::ConstantIndexOp>(loc, 0); + Value zeroConstant = arith::ConstantIndexOp::create(rewriter, loc, 0); SmallVector<Value, 8> operands(op.getOperands()); auto operandsRef = llvm::ArrayRef(operands); @@ -298,18 +298,18 @@ public: auto pred = isEquality ? arith::CmpIPredicate::eq : arith::CmpIPredicate::sge; Value cmpVal = - rewriter.create<arith::CmpIOp>(loc, pred, affResult, zeroConstant); - cond = cond - ? rewriter.create<arith::AndIOp>(loc, cond, cmpVal).getResult() - : cmpVal; + arith::CmpIOp::create(rewriter, loc, pred, affResult, zeroConstant); + cond = + cond ? arith::AndIOp::create(rewriter, loc, cond, cmpVal).getResult() + : cmpVal; } cond = cond ? cond - : rewriter.create<arith::ConstantIntOp>(loc, /*value=*/1, - /*width=*/1); + : arith::ConstantIntOp::create(rewriter, loc, /*value=*/1, + /*width=*/1); bool hasElseRegion = !op.getElseRegion().empty(); - auto ifOp = rewriter.create<scf::IfOp>(loc, op.getResultTypes(), cond, - hasElseRegion); + auto ifOp = scf::IfOp::create(rewriter, loc, op.getResultTypes(), cond, + hasElseRegion); rewriter.inlineRegionBefore(op.getThenRegion(), &ifOp.getThenRegion().back()); rewriter.eraseBlock(&ifOp.getThenRegion().back()); diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp index cf9bb3a..8230591 100644 --- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp +++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp @@ -49,8 +49,9 @@ struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> { using OpRewritePattern::OpRewritePattern; Chipset chipset; - ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset) - : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {} + ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset, + PatternBenefit benefit) + : OpRewritePattern::OpRewritePattern(ctx, benefit), chipset(chipset) {} LogicalResult matchAndRewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override; @@ -59,9 +60,9 @@ struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> { struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> { bool saturateFP8 = false; TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8, - Chipset chipset) - : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8), - chipset(chipset) {} + Chipset chipset, PatternBenefit benefit) + : OpRewritePattern::OpRewritePattern(ctx, benefit), + saturateFP8(saturateFP8), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(arith::TruncFOp op, @@ -81,9 +82,6 @@ struct ScalingExtFRewritePattern final : OpRewritePattern<arith::ScalingExtFOp> { using OpRewritePattern::OpRewritePattern; - ScalingExtFRewritePattern(MLIRContext *ctx) - : OpRewritePattern::OpRewritePattern(ctx) {} - LogicalResult matchAndRewrite(arith::ScalingExtFOp op, PatternRewriter &rewriter) const override; }; @@ -92,9 +90,6 @@ struct ScalingTruncFRewritePattern final : OpRewritePattern<arith::ScalingTruncFOp> { using OpRewritePattern::OpRewritePattern; - ScalingTruncFRewritePattern(MLIRContext *ctx) - : OpRewritePattern::OpRewritePattern(ctx) {} - LogicalResult matchAndRewrite(arith::ScalingTruncFOp op, PatternRewriter &rewriter) const override; }; @@ -115,9 +110,9 @@ static Value castF32To(Type desType, Value f32, Location loc, if (elementType.isF32()) return f32; if (elementType.getIntOrFloatBitWidth() < 32) - return rewriter.create<arith::TruncFOp>(loc, desType, f32); + return arith::TruncFOp::create(rewriter, loc, desType, f32); if (elementType.getIntOrFloatBitWidth() > 32) - return rewriter.create<arith::ExtFOp>(loc, desType, f32); + return arith::ExtFOp::create(rewriter, loc, desType, f32); llvm_unreachable("The only 32-bit float type is f32"); } @@ -139,64 +134,64 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op, Type outElemType = getElementTypeOrSelf(op.getOut().getType()); VectorType extResType = VectorType::get(2, rewriter.getF32Type()); if (!inVecType) { - Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>( - loc, rewriter.getF32Type(), in, 0); + Value asFloat = amdgpu::ExtPackedFp8Op::create( + rewriter, loc, rewriter.getF32Type(), in, 0); Value result = castF32To(outElemType, asFloat, loc, rewriter); rewriter.replaceOp(op, result); return success(); } int64_t numElements = inVecType.getNumElements(); - Value zero = rewriter.create<arith::ConstantOp>( - loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); + Value zero = arith::ConstantOp::create( + rewriter, loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); VectorType outType = cast<VectorType>(op.getOut().getType()); if (inVecType.getShape().empty()) { Value zerodSplat = - rewriter.createOrFold<vector::SplatOp>(loc, outType, zero); + rewriter.createOrFold<vector::BroadcastOp>(loc, outType, zero); Value scalarIn = - rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{}); + vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{}); Value scalarExt = - rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn); - Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zerodSplat, - ArrayRef<int64_t>{}); + arith::ExtFOp::create(rewriter, loc, outElemType, scalarIn); + Value result = vector::InsertOp::create(rewriter, loc, scalarExt, + zerodSplat, ArrayRef<int64_t>{}); rewriter.replaceOp(op, result); return success(); } VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements}, outType.getElementType()); - Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero); + Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero); if (inVecType.getRank() > 1) { inVecType = VectorType::get(SmallVector<int64_t>{numElements}, inVecType.getElementType()); - in = rewriter.create<vector::ShapeCastOp>(loc, inVecType, in); + in = vector::ShapeCastOp::create(rewriter, loc, inVecType, in); } for (int64_t i = 0; i < numElements; i += 4) { int64_t elemsThisOp = std::min(numElements, i + 4) - i; - Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>( - loc, in, i, elemsThisOp, 1); + Value inSlice = vector::ExtractStridedSliceOp::create(rewriter, loc, in, i, + elemsThisOp, 1); for (int64_t j = 0; j < elemsThisOp; j += 2) { if (i + j + 1 < numElements) { // Convert two 8-bit elements - Value asFloats = rewriter.create<amdgpu::ExtPackedFp8Op>( - loc, extResType, inSlice, j / 2); + Value asFloats = amdgpu::ExtPackedFp8Op::create( + rewriter, loc, extResType, inSlice, j / 2); Type desType = VectorType::get(2, outElemType); Value asType = castF32To(desType, asFloats, loc, rewriter); - result = rewriter.create<vector::InsertStridedSliceOp>( - loc, asType, result, i + j, 1); + result = vector::InsertStridedSliceOp::create(rewriter, loc, asType, + result, i + j, 1); } else { // Convert a 8-bit element - Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>( - loc, rewriter.getF32Type(), inSlice, j / 2 * 2); + Value asFloat = amdgpu::ExtPackedFp8Op::create( + rewriter, loc, rewriter.getF32Type(), inSlice, j / 2 * 2); Value asType = castF32To(outElemType, asFloat, loc, rewriter); - result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j); + result = vector::InsertOp::create(rewriter, loc, asType, result, i + j); } } } if (inVecType.getRank() != outType.getRank()) { - result = rewriter.create<vector::ShapeCastOp>(loc, outType, result); + result = vector::ShapeCastOp::create(rewriter, loc, outType, result); } rewriter.replaceOp(op, result); @@ -208,9 +203,9 @@ static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) { if (type.isF32()) return value; if (type.getIntOrFloatBitWidth() < 32) - return rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), value); + return arith::ExtFOp::create(rewriter, loc, rewriter.getF32Type(), value); if (type.getIntOrFloatBitWidth() > 32) - return rewriter.create<arith::TruncFOp>(loc, rewriter.getF32Type(), value); + return arith::TruncFOp::create(rewriter, loc, rewriter.getF32Type(), value); llvm_unreachable("The only 32-bit float type is f32"); } @@ -250,13 +245,15 @@ static Value clampInput(PatternRewriter &rewriter, Location loc, loc, arith::CmpFPredicate::OEQ, source, negInf); Value isNan = rewriter.createOrFold<arith::CmpFOp>( loc, arith::CmpFPredicate::UNO, source, source); - Value isNonFinite = rewriter.create<arith::OrIOp>( - loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan); + Value isNonFinite = arith::OrIOp::create( + rewriter, loc, arith::OrIOp::create(rewriter, loc, isInf, isNegInf), + isNan); - Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst); - Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst); + Value clampedBelow = arith::MaximumFOp::create(rewriter, loc, source, minCst); + Value clamped = + arith::MinimumFOp::create(rewriter, loc, clampedBelow, maxCst); Value res = - rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped); + arith::SelectOp::create(rewriter, loc, isNonFinite, source, clamped); return res; } @@ -290,62 +287,62 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op, VectorType truncResType = VectorType::get(4, outElemType); if (!inVectorTy) { Value asFloat = castToF32(in, loc, rewriter); - Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>( - loc, truncResType, asFloat, /*sourceB=*/nullptr, 0, + Value asF8s = amdgpu::PackedTrunc2xFp8Op::create( + rewriter, loc, truncResType, asFloat, /*sourceB=*/nullptr, 0, /*existing=*/nullptr); - Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0); + Value result = vector::ExtractOp::create(rewriter, loc, asF8s, 0); rewriter.replaceOp(op, result); return success(); } int64_t numElements = outVecType.getNumElements(); - Value zero = rewriter.create<arith::ConstantOp>( - loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); + Value zero = arith::ConstantOp::create( + rewriter, loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); if (outVecType.getShape().empty()) { Value scalarIn = - rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{}); + vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{}); // Recurse to send the 0-D vector case to the 1-D vector case Value scalarTrunc = - rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn); - Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero, - ArrayRef<int64_t>{}); + arith::TruncFOp::create(rewriter, loc, outElemType, scalarIn); + Value result = vector::InsertOp::create(rewriter, loc, scalarTrunc, zero, + ArrayRef<int64_t>{}); rewriter.replaceOp(op, result); return success(); } VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements}, outVecType.getElementType()); - Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero); + Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero); if (inVectorTy.getRank() > 1) { inVectorTy = VectorType::get(SmallVector<int64_t>{numElements}, inVectorTy.getElementType()); - in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in); + in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in); } for (int64_t i = 0; i < numElements; i += 4) { int64_t elemsThisOp = std::min(numElements, i + 4) - i; Value thisResult = nullptr; for (int64_t j = 0; j < elemsThisOp; j += 2) { - Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j); + Value elemA = vector::ExtractOp::create(rewriter, loc, in, i + j); Value asFloatA = castToF32(elemA, loc, rewriter); Value asFloatB = nullptr; if (j + 1 < elemsThisOp) { - Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1); + Value elemB = vector::ExtractOp::create(rewriter, loc, in, i + j + 1); asFloatB = castToF32(elemB, loc, rewriter); } - thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>( - loc, truncResType, asFloatA, asFloatB, j / 2, thisResult); + thisResult = amdgpu::PackedTrunc2xFp8Op::create( + rewriter, loc, truncResType, asFloatA, asFloatB, j / 2, thisResult); } if (elemsThisOp < 4) - thisResult = rewriter.create<vector::ExtractStridedSliceOp>( - loc, thisResult, 0, elemsThisOp, 1); - result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult, - result, i, 1); + thisResult = vector::ExtractStridedSliceOp::create( + rewriter, loc, thisResult, 0, elemsThisOp, 1); + result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult, + result, i, 1); } if (inVectorTy.getRank() != outVecType.getRank()) { - result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result); + result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result); } rewriter.replaceOp(op, result); @@ -373,22 +370,23 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite( // Handle the case where input type is not a vector type if (!inVectorTy) { - auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type()); + auto sourceB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type()); Value asF16s = - rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB); - Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0); + ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, in, sourceB); + Value result = vector::ExtractOp::create(rewriter, loc, asF16s, 0); rewriter.replaceOp(op, result); return success(); } int64_t numElements = outVecType.getNumElements(); Value zero = rewriter.createOrFold<arith::ConstantOp>( loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); - Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero); + Value result = + rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero); if (inVectorTy.getRank() > 1) { inVectorTy = VectorType::get(SmallVector<int64_t>{numElements}, inVectorTy.getElementType()); - in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in); + in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in); } // Handle the vector case. We also handle the (uncommon) case where the vector @@ -396,25 +394,25 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite( for (int64_t i = 0; i < numElements; i += 2) { int64_t elemsThisOp = std::min(numElements, i + 2) - i; Value thisResult = nullptr; - Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i); - Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type()); + Value elemA = vector::ExtractOp::create(rewriter, loc, in, i); + Value elemB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type()); if (elemsThisOp == 2) { - elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1); + elemB = vector::ExtractOp::create(rewriter, loc, in, i + 1); } thisResult = - rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB); + ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, elemA, elemB); // Place back the truncated result into the possibly larger vector. If we // are operating on a size 2 vector, these operations should be folded away - thisResult = rewriter.create<vector::ExtractStridedSliceOp>( - loc, thisResult, 0, elemsThisOp, 1); - result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult, - result, i, 1); + thisResult = vector::ExtractStridedSliceOp::create( + rewriter, loc, thisResult, 0, elemsThisOp, 1); + result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult, + result, i, 1); } if (inVectorTy.getRank() != outVecType.getRank()) { - result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result); + result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result); } rewriter.replaceOp(op, result); @@ -451,7 +449,7 @@ LogicalResult ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op, PatternRewriter &rewriter) const { Location loc = op.getLoc(); - constexpr int64_t opWidth = 2; + constexpr int64_t opOutWidth = 2; Value in = op.getIn(); Value scale = op.getScale(); @@ -462,6 +460,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op, Type scaleType = getElementTypeOrSelf(scale); Type outType = getElementTypeOrSelf(out); + int64_t opInWidth = 32 / inType.getIntOrFloatBitWidth(); + VectorType outVecType = dyn_cast<VectorType>(out.getType()); VectorType scaleVecType = dyn_cast<VectorType>(scale.getType()); @@ -471,28 +471,29 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op, Type scaleF32Type = scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32; if (scaleType.getIntOrFloatBitWidth() < 32) - scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale); + scale = arith::ExtFOp::create(rewriter, loc, scaleF32Type, scale); else if (scaleType.getIntOrFloatBitWidth() > 32) - scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale); + scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale); - VectorType extScaleResultType = VectorType::get(opWidth, outType); + VectorType extScaleResultType = VectorType::get(opOutWidth, outType); if (!outVecType) { - Value inCast = - rewriter.create<vector::SplatOp>(loc, VectorType::get(1, inType), in); + Value inCast = vector::BroadcastOp::create(rewriter, loc, + VectorType::get(1, inType), in); // TODO: replace this with non-packed ScaledExtOp - Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>( - loc, extScaleResultType, inCast, scale, 0); + Value scaleExt = amdgpu::ScaledExtPackedOp::create( + rewriter, loc, extScaleResultType, inCast, scale, 0); scaleExt = rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleExt, 0); return success(); } VectorType inVecType = cast<VectorType>(in.getType()); Value origScale = getOriginalVectorValue(op.getScale()); + VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType()); ArrayRef<int64_t> inShape = inVecType.getShape(); SmallVector<int64_t> originalScaleShape; - if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType())) + if (origScaleVecType) llvm::append_range(originalScaleShape, origScaleVecType.getShape()); originalScaleShape.insert(originalScaleShape.end(), @@ -507,44 +508,52 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op, int64_t blockSize = computeProduct(ratio); - Value zero = rewriter.create<arith::ConstantOp>( - loc, outType, rewriter.getFloatAttr(outType, 0.0)); - Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero); + Value zero = arith::ConstantOp::create(rewriter, loc, outType, + rewriter.getFloatAttr(outType, 0.0)); + Value result = + rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero); for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) { SmallVector<int64_t> strides(offsets.size(), 1); - Value block = rewriter.create<vector::ExtractStridedSliceOp>( - loc, in, offsets, ratio, strides); + Value block = vector::ExtractStridedSliceOp::create( + rewriter, loc, in, offsets, ratio, strides); VectorType block1DType = VectorType::get(blockSize, inType); Value block1D = - rewriter.create<vector::ShapeCastOp>(loc, block1DType, block); + vector::ShapeCastOp::create(rewriter, loc, block1DType, block); Value uniformScale = - rewriter.create<vector::ExtractOp>(loc, scale, offsets); + vector::ExtractOp::create(rewriter, loc, scale, offsets); VectorType blockResultType = VectorType::get(blockSize, outType); Value blockResult = - rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero); + rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero); - for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i); + for (int64_t i = 0, inSliceWidth = std::min(opInWidth, blockSize - i); i < blockSize; - i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) { - Value slice = rewriter.create<vector::ExtractStridedSliceOp>( - loc, block1D, i, sliceWidth, 1); - // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1 - Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>( - loc, extScaleResultType, slice, uniformScale, 0); - if (sliceWidth != opWidth) - scaleExt = rewriter.create<vector::ExtractStridedSliceOp>( - loc, scaleExt, 0, sliceWidth, 1); - blockResult = rewriter.create<vector::InsertStridedSliceOp>( - loc, scaleExt, blockResult, i, 1); + i += inSliceWidth, inSliceWidth = std::min(opInWidth, blockSize - i)) { + Value inSlice = vector::ExtractStridedSliceOp::create( + rewriter, loc, block1D, i, inSliceWidth, 1); + for (int64_t j = 0, + outSliceWidth = std::min(opOutWidth, inSliceWidth - j); + j < inSliceWidth; j += outSliceWidth, + outSliceWidth = std::min(opOutWidth, inSliceWidth - j)) { + // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1 + Value scaleExt = amdgpu::ScaledExtPackedOp::create( + rewriter, loc, extScaleResultType, inSlice, uniformScale, + j / opOutWidth); + if (outSliceWidth < opOutWidth) { + scaleExt = vector::ExtractStridedSliceOp::create( + rewriter, loc, scaleExt, 0, outSliceWidth, 1); + } + blockResult = vector::InsertStridedSliceOp::create( + rewriter, loc, scaleExt, blockResult, i + j, 1); + } } VectorType resultType = VectorType::get(ratio, outType); Value cast = - rewriter.create<vector::ShapeCastOp>(loc, resultType, blockResult); - result = rewriter.create<vector::InsertStridedSliceOp>(loc, cast, result, - offsets, strides); + vector::ShapeCastOp::create(rewriter, loc, resultType, blockResult); + result = vector::InsertStridedSliceOp::create(rewriter, loc, cast, result, + offsets, strides); } rewriter.replaceOp(op, result); @@ -556,7 +565,7 @@ LogicalResult ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op, PatternRewriter &rewriter) const { Location loc = op.getLoc(); - constexpr int64_t opWidth = 2; + constexpr int64_t opInWidth = 2; Value in = op.getIn(); Value scale = op.getScale(); @@ -569,28 +578,28 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op, VectorType outVecType = dyn_cast<VectorType>(out.getType()); VectorType scaleVecType = dyn_cast<VectorType>(scale.getType()); - if (outVecType && outVecType.isScalable()) return failure(); Type scaleF32Type = scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32; if (scaleType.getIntOrFloatBitWidth() < 32) - scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale); + scale = arith::ExtFOp::create(rewriter, loc, scaleF32Type, scale); else if (scaleType.getIntOrFloatBitWidth() > 32) - scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale); + scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale); - Value zero = rewriter.create<arith::ConstantOp>( - loc, outType, rewriter.getFloatAttr(outType, 0.0)); - unsigned numPackedElem = 32 / outType.getIntOrFloatBitWidth(); - VectorType truncScaleResultType = VectorType::get(numPackedElem, outType); + Value zero = arith::ConstantOp::create(rewriter, loc, outType, + rewriter.getFloatAttr(outType, 0.0)); + int64_t opOutWidth = 32 / outType.getIntOrFloatBitWidth(); + VectorType truncScaleResultType = VectorType::get(opOutWidth, outType); if (!outVecType) { Type inVecType = VectorType::get(1, inType); - Value inCast = rewriter.create<vector::SplatOp>(loc, inVecType, in); + Value inCast = vector::BroadcastOp::create(rewriter, loc, inVecType, in); // TODO: replace this with non-packed ScaledTruncOp - Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>( - loc, truncScaleResultType, inCast, scale, 0, /*existing=*/nullptr); + Value scaleTrunc = amdgpu::PackedScaledTruncOp::create( + rewriter, loc, truncScaleResultType, inCast, scale, 0, + /*existing=*/nullptr); scaleTrunc = rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleTrunc, 0); return success(); @@ -598,16 +607,16 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op, VectorType inVecType = cast<VectorType>(in.getType()); Value origScale = getOriginalVectorValue(op.getScale()); + VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType()); ArrayRef<int64_t> inShape = inVecType.getShape(); - SmallVector<int64_t> originalScaleShape; - if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType())) - llvm::append_range(originalScaleShape, origScaleVecType.getShape()); + SmallVector<int64_t> scaleShape; + if (origScaleVecType) + llvm::append_range(scaleShape, origScaleVecType.getShape()); - originalScaleShape.insert(originalScaleShape.end(), - inShape.size() - originalScaleShape.size(), 1); + scaleShape.insert(scaleShape.end(), inShape.size() - scaleShape.size(), 1); - auto maybeRatio = computeShapeRatio(inShape, originalScaleShape); + auto maybeRatio = computeShapeRatio(inShape, scaleShape); assert(maybeRatio && "failed to derive block size from broadcast or splat operation"); @@ -616,45 +625,62 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op, int64_t blockSize = computeProduct(ratio); - Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero); + Value result = + rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero); for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) { SmallVector<int64_t> strides(offsets.size(), 1); - Value block = rewriter.create<vector::ExtractStridedSliceOp>( - loc, in, offsets, ratio, strides); + Value block = vector::ExtractStridedSliceOp::create( + rewriter, loc, in, offsets, ratio, strides); VectorType block1DType = VectorType::get(blockSize, inType); Value block1D = - rewriter.create<vector::ShapeCastOp>(loc, block1DType, block); + vector::ShapeCastOp::create(rewriter, loc, block1DType, block); Value uniformScale = - rewriter.create<vector::ExtractOp>(loc, scale, offsets); + vector::ExtractOp::create(rewriter, loc, scale, offsets); VectorType blockResultType = VectorType::get(blockSize, outType); Value blockResult = - rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero); - - for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i); - i < blockSize; - i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) { - Value slice = rewriter.create<vector::ExtractStridedSliceOp>( - loc, block1D, i, sliceWidth, 1); - // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1 - Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>( - loc, truncScaleResultType, slice, uniformScale, 0, - /*existing=*/nullptr); - int64_t packedWidth = - cast<VectorType>(scaleTrunc.getType()).getNumElements(); - if (packedWidth != opWidth) - scaleTrunc = rewriter.create<vector::ExtractStridedSliceOp>( - loc, scaleTrunc, 0, sliceWidth, 1); - blockResult = rewriter.create<vector::InsertStridedSliceOp>( - loc, scaleTrunc, blockResult, i, 1); + rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero); + + for (int64_t i = 0, outSliceWidth = std::min(opOutWidth, blockSize - i); + i < blockSize; i += outSliceWidth, + outSliceWidth = std::min(opOutWidth, blockSize - i)) { + Value scaleTrunc; + // Case where <= 2 elements are being truncated. + if (outSliceWidth <= opInWidth) { + Value slice = vector::ExtractStridedSliceOp::create( + rewriter, loc, block1D, i, outSliceWidth, 1); + // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1 + scaleTrunc = amdgpu::PackedScaledTruncOp::create( + rewriter, loc, truncScaleResultType, slice, uniformScale, 0, + /*existing=*/nullptr); + } else { + scaleTrunc = vector::BroadcastOp::create(rewriter, loc, + truncScaleResultType, zero); + for (int64_t j = 0, + inSliceWidth = std::min(opInWidth, outSliceWidth - j); + j < outSliceWidth; j += opInWidth, + inSliceWidth = std::min(opInWidth, outSliceWidth - j)) { + Value slice = vector::ExtractStridedSliceOp::create( + rewriter, loc, block1D, i + j, inSliceWidth, 1); + scaleTrunc = amdgpu::PackedScaledTruncOp::create( + rewriter, loc, truncScaleResultType, slice, uniformScale, + j / opInWidth, scaleTrunc); + } + } + if (outSliceWidth != opOutWidth) { + scaleTrunc = vector::ExtractStridedSliceOp::create( + rewriter, loc, scaleTrunc, 0, outSliceWidth, 1); + } + blockResult = vector::InsertStridedSliceOp::create( + rewriter, loc, scaleTrunc, blockResult, i, 1); } VectorType resultType = VectorType::get(ratio, outType); Value cast = - rewriter.create<vector::ShapeCastOp>(loc, resultType, blockResult); - result = rewriter.create<vector::InsertStridedSliceOp>(loc, cast, result, - offsets, strides); + vector::ShapeCastOp::create(rewriter, loc, resultType, blockResult); + result = vector::InsertStridedSliceOp::create(rewriter, loc, cast, result, + offsets, strides); } rewriter.replaceOp(op, result); @@ -664,19 +690,21 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op, void mlir::arith::populateArithToAMDGPUConversionPatterns( RewritePatternSet &patterns, bool convertFP8Arithmetic, - bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) { + bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset, + PatternBenefit benefit) { if (convertFP8Arithmetic) { - patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset); - patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(), - saturateFP8Truncf, chipset); + patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset, + benefit); + patterns.add<TruncFToFloat8RewritePattern>( + patterns.getContext(), saturateFP8Truncf, chipset, benefit); } if (allowPackedF16Rtz) - patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext()); + patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext(), benefit); if (chipset >= kGfx950) { - patterns.add<ScalingExtFRewritePattern>(patterns.getContext()); - patterns.add<ScalingTruncFRewritePattern>(patterns.getContext()); + patterns.add<ScalingExtFRewritePattern>(patterns.getContext(), benefit); + patterns.add<ScalingTruncFRewritePattern>(patterns.getContext(), benefit); } } diff --git a/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp index cbe0b3f..ba48943 100644 --- a/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp +++ b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp @@ -74,15 +74,15 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> { VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); auto denseAttr1D = DenseElementsAttr::get( tileSliceType, denseAttr.getSplatValue<Attribute>()); - auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D); + auto constantOp1D = arith::ConstantOp::create(rewriter, loc, denseAttr1D); - auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); + auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType); auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex, Value currentTile) { // Create 'arm_sme.insert_tile_slice' to write vector to tile // slice. - auto nextTile = b.create<arm_sme::InsertTileSliceOp>( - loc, tileType, constantOp1D, currentTile, tileSliceIndex); + auto nextTile = arm_sme::InsertTileSliceOp::create( + b, loc, tileType, constantOp1D, currentTile, tileSliceIndex); return nextTile.getResult(); }; auto forOp = mlir::arm_sme::createLoopOverTileSlices( diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index a5c08a6..515fe5c 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -110,9 +110,9 @@ public: emitc::CmpPredicate predicate; switch (op.getPredicate()) { case arith::CmpFPredicate::AlwaysFalse: { - auto constant = rewriter.create<emitc::ConstantOp>( - op.getLoc(), rewriter.getI1Type(), - rewriter.getBoolAttr(/*value=*/false)); + auto constant = + emitc::ConstantOp::create(rewriter, op.getLoc(), rewriter.getI1Type(), + rewriter.getBoolAttr(/*value=*/false)); rewriter.replaceOp(op, constant); return success(); } @@ -179,9 +179,9 @@ public: return success(); } case arith::CmpFPredicate::AlwaysTrue: { - auto constant = rewriter.create<emitc::ConstantOp>( - op.getLoc(), rewriter.getI1Type(), - rewriter.getBoolAttr(/*value=*/true)); + auto constant = + emitc::ConstantOp::create(rewriter, op.getLoc(), rewriter.getI1Type(), + rewriter.getBoolAttr(/*value=*/true)); rewriter.replaceOp(op, constant); return success(); } @@ -189,8 +189,8 @@ public: // Compare the values naively auto cmpResult = - rewriter.create<emitc::CmpOp>(op.getLoc(), op.getType(), predicate, - adaptor.getLhs(), adaptor.getRhs()); + emitc::CmpOp::create(rewriter, op.getLoc(), op.getType(), predicate, + adaptor.getLhs(), adaptor.getRhs()); // Adjust the results for unordered/ordered semantics if (unordered) { @@ -213,16 +213,16 @@ private: Value isNaN(ConversionPatternRewriter &rewriter, Location loc, Value operand) const { // A value is NaN exactly when it compares unequal to itself. - return rewriter.create<emitc::CmpOp>( - loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, operand, operand); + return emitc::CmpOp::create(rewriter, loc, rewriter.getI1Type(), + emitc::CmpPredicate::ne, operand, operand); } /// Return a value that is true if \p operand is not NaN. Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc, Value operand) const { // A value is not NaN exactly when it compares equal to itself. - return rewriter.create<emitc::CmpOp>( - loc, rewriter.getI1Type(), emitc::CmpPredicate::eq, operand, operand); + return emitc::CmpOp::create(rewriter, loc, rewriter.getI1Type(), + emitc::CmpPredicate::eq, operand, operand); } /// Return a value that is true if the operands \p first and \p second are @@ -231,8 +231,8 @@ private: Location loc, Value first, Value second) const { auto firstIsNaN = isNaN(rewriter, loc, first); auto secondIsNaN = isNaN(rewriter, loc, second); - return rewriter.create<emitc::LogicalOrOp>(loc, rewriter.getI1Type(), - firstIsNaN, secondIsNaN); + return emitc::LogicalOrOp::create(rewriter, loc, rewriter.getI1Type(), + firstIsNaN, secondIsNaN); } /// Return a value that is true if the operands \p first and \p second are @@ -241,8 +241,8 @@ private: Value first, Value second) const { auto firstIsNotNaN = isNotNaN(rewriter, loc, first); auto secondIsNotNaN = isNotNaN(rewriter, loc, second); - return rewriter.create<emitc::LogicalAndOp>(loc, rewriter.getI1Type(), - firstIsNotNaN, secondIsNotNaN); + return emitc::LogicalAndOp::create(rewriter, loc, rewriter.getI1Type(), + firstIsNotNaN, secondIsNotNaN); } }; @@ -378,10 +378,10 @@ public: Type attrType = (emitc::isPointerWideType(operandType)) ? rewriter.getIndexType() : operandType; - auto constOne = rewriter.create<emitc::ConstantOp>( - op.getLoc(), operandType, rewriter.getOneAttr(attrType)); - auto oneAndOperand = rewriter.create<emitc::BitwiseAndOp>( - op.getLoc(), operandType, adaptor.getIn(), constOne); + auto constOne = emitc::ConstantOp::create( + rewriter, op.getLoc(), operandType, rewriter.getOneAttr(attrType)); + auto oneAndOperand = emitc::BitwiseAndOp::create( + rewriter, op.getLoc(), operandType, adaptor.getIn(), constOne); rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType, oneAndOperand); return success(); @@ -402,8 +402,8 @@ public: Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType); // Actual cast (may change bitwidth) - auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(), - castDestType, actualOp); + auto cast = + emitc::CastOp::create(rewriter, op.getLoc(), castDestType, actualOp); // Cast to the expected output type auto result = adaptValueType(cast, rewriter, opReturnType); @@ -466,9 +466,8 @@ public: Value lhsAdapted = adaptValueType(uiBinOp.getLhs(), rewriter, unsignedType); Value rhsAdapted = adaptValueType(uiBinOp.getRhs(), rewriter, unsignedType); - auto newDivOp = - rewriter.create<EmitCOp>(uiBinOp.getLoc(), unsignedType, - ArrayRef<Value>{lhsAdapted, rhsAdapted}); + auto newDivOp = EmitCOp::create(rewriter, uiBinOp.getLoc(), unsignedType, + ArrayRef<Value>{lhsAdapted, rhsAdapted}); Value resultAdapted = adaptValueType(newDivOp, rewriter, newRetTy); rewriter.replaceOp(uiBinOp, resultAdapted); return success(); @@ -508,8 +507,8 @@ public: Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType); Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType); - Value arithmeticResult = rewriter.template create<EmitCOp>( - op.getLoc(), arithmeticType, lhs, rhs); + Value arithmeticResult = + EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs); Value result = adaptValueType(arithmeticResult, rewriter, type); @@ -548,8 +547,8 @@ public: Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType); Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType); - Value arithmeticResult = rewriter.template create<EmitCOp>( - op.getLoc(), arithmeticType, lhs, rhs); + Value arithmeticResult = + EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs); Value result = adaptValueType(arithmeticResult, rewriter, type); @@ -588,38 +587,40 @@ public: // Add a runtime check for overflow Value width; if (emitc::isPointerWideType(type)) { - Value eight = rewriter.create<emitc::ConstantOp>( - op.getLoc(), rhsType, rewriter.getIndexAttr(8)); - emitc::CallOpaqueOp sizeOfCall = rewriter.create<emitc::CallOpaqueOp>( - op.getLoc(), rhsType, "sizeof", ArrayRef<Value>{eight}); - width = rewriter.create<emitc::MulOp>(op.getLoc(), rhsType, eight, - sizeOfCall.getResult(0)); + Value eight = emitc::ConstantOp::create(rewriter, op.getLoc(), rhsType, + rewriter.getIndexAttr(8)); + emitc::CallOpaqueOp sizeOfCall = emitc::CallOpaqueOp::create( + rewriter, op.getLoc(), rhsType, "sizeof", ArrayRef<Value>{eight}); + width = emitc::MulOp::create(rewriter, op.getLoc(), rhsType, eight, + sizeOfCall.getResult(0)); } else { - width = rewriter.create<emitc::ConstantOp>( - op.getLoc(), rhsType, + width = emitc::ConstantOp::create( + rewriter, op.getLoc(), rhsType, rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth())); } - Value excessCheck = rewriter.create<emitc::CmpOp>( - op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width); + Value excessCheck = + emitc::CmpOp::create(rewriter, op.getLoc(), rewriter.getI1Type(), + emitc::CmpPredicate::lt, rhs, width); // Any concrete value is a valid refinement of poison. - Value poison = rewriter.create<emitc::ConstantOp>( - op.getLoc(), arithmeticType, + Value poison = emitc::ConstantOp::create( + rewriter, op.getLoc(), arithmeticType, (isa<IntegerType>(arithmeticType) ? rewriter.getIntegerAttr(arithmeticType, 0) : rewriter.getIndexAttr(0))); - emitc::ExpressionOp ternary = rewriter.create<emitc::ExpressionOp>( - op.getLoc(), arithmeticType, /*do_not_inline=*/false); + emitc::ExpressionOp ternary = emitc::ExpressionOp::create( + rewriter, op.getLoc(), arithmeticType, /*do_not_inline=*/false); Block &bodyBlock = ternary.getBodyRegion().emplaceBlock(); auto currentPoint = rewriter.getInsertionPoint(); rewriter.setInsertionPointToStart(&bodyBlock); Value arithmeticResult = - rewriter.create<EmitCOp>(op.getLoc(), arithmeticType, lhs, rhs); - Value resultOrPoison = rewriter.create<emitc::ConditionalOp>( - op.getLoc(), arithmeticType, excessCheck, arithmeticResult, poison); - rewriter.create<emitc::YieldOp>(op.getLoc(), resultOrPoison); + EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs); + Value resultOrPoison = + emitc::ConditionalOp::create(rewriter, op.getLoc(), arithmeticType, + excessCheck, arithmeticResult, poison); + emitc::YieldOp::create(rewriter, op.getLoc(), resultOrPoison); rewriter.setInsertionPoint(op->getBlock(), currentPoint); Value result = adaptValueType(ternary, rewriter, type); @@ -700,11 +701,12 @@ public: /*isSigned=*/false); } - Value result = rewriter.create<emitc::CastOp>( - castOp.getLoc(), actualResultType, adaptor.getOperands()); + Value result = emitc::CastOp::create( + rewriter, castOp.getLoc(), actualResultType, adaptor.getOperands()); if (isa<arith::FPToUIOp>(castOp)) { - result = rewriter.create<emitc::CastOp>(castOp.getLoc(), dstType, result); + result = + emitc::CastOp::create(rewriter, castOp.getLoc(), dstType, result); } rewriter.replaceOp(castOp, result); @@ -746,8 +748,8 @@ public: } Value fpCastOperand = adaptor.getIn(); if (actualOperandType != operandType) { - fpCastOperand = rewriter.template create<emitc::CastOp>( - castOp.getLoc(), actualOperandType, fpCastOperand); + fpCastOperand = emitc::CastOp::create(rewriter, castOp.getLoc(), + actualOperandType, fpCastOperand); } rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand); diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index f7bf581..18e857c 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -293,11 +293,11 @@ LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite( [&](Type llvm1DVectorTy, ValueRange operands) -> Value { typename OpTy::Adaptor adaptor(operands); if (targetBits < sourceBits) { - return rewriter.create<LLVM::TruncOp>(op.getLoc(), llvm1DVectorTy, - adaptor.getIn()); + return LLVM::TruncOp::create(rewriter, op.getLoc(), llvm1DVectorTy, + adaptor.getIn()); } - return rewriter.create<ExtCastTy>(op.getLoc(), llvm1DVectorTy, - adaptor.getIn()); + return ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy, + adaptor.getIn()); }, rewriter); } @@ -324,12 +324,12 @@ LogicalResult AddUIExtendedOpLowering::matchAndRewrite( Type newOverflowType = typeConverter->convertType(overflowResultType); Type structType = LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType}); - Value addOverflow = rewriter.create<LLVM::UAddWithOverflowOp>( - loc, structType, adaptor.getLhs(), adaptor.getRhs()); + Value addOverflow = LLVM::UAddWithOverflowOp::create( + rewriter, loc, structType, adaptor.getLhs(), adaptor.getRhs()); Value sumExtracted = - rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 0); + LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 0); Value overflowExtracted = - rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 1); + LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 1); rewriter.replaceOp(op, {sumExtracted, overflowExtracted}); return success(); } @@ -381,15 +381,15 @@ LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite( "LLVM dialect should support all signless integer types"); using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>; - Value lhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getLhs()); - Value rhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getRhs()); - Value mulExt = rewriter.create<LLVM::MulOp>(loc, wideType, lhsExt, rhsExt); + Value lhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getLhs()); + Value rhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getRhs()); + Value mulExt = LLVM::MulOp::create(rewriter, loc, wideType, lhsExt, rhsExt); // Split the 2*N-bit wide result into two N-bit values. - Value low = rewriter.create<LLVM::TruncOp>(loc, resultType, mulExt); - Value shiftVal = rewriter.create<LLVM::ConstantOp>(loc, shiftValAttr); - Value highExt = rewriter.create<LLVM::LShrOp>(loc, mulExt, shiftVal); - Value high = rewriter.create<LLVM::TruncOp>(loc, resultType, highExt); + Value low = LLVM::TruncOp::create(rewriter, loc, resultType, mulExt); + Value shiftVal = LLVM::ConstantOp::create(rewriter, loc, shiftValAttr); + Value highExt = LLVM::LShrOp::create(rewriter, loc, mulExt, shiftVal); + Value high = LLVM::TruncOp::create(rewriter, loc, resultType, highExt); rewriter.replaceOp(op, {low, high}); return success(); @@ -435,8 +435,8 @@ CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { OpAdaptor adaptor(operands); - return rewriter.create<LLVM::ICmpOp>( - op.getLoc(), llvm1DVectorTy, + return LLVM::ICmpOp::create( + rewriter, op.getLoc(), llvm1DVectorTy, convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()), adaptor.getLhs(), adaptor.getRhs()); }, @@ -471,8 +471,8 @@ CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { OpAdaptor adaptor(operands); - return rewriter.create<LLVM::FCmpOp>( - op.getLoc(), llvm1DVectorTy, + return LLVM::FCmpOp::create( + rewriter, op.getLoc(), llvm1DVectorTy, convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()), adaptor.getLhs(), adaptor.getRhs(), fmf); }, diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index 434d7df..265293b 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -99,6 +99,17 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, return builder.getF32FloatAttr(dstVal.convertToFloat()); } +// Get in IntegerAttr from FloatAttr while preserving the bits. +// Useful for converting float constants to integer constants while preserving +// the bits. +static IntegerAttr +getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, + ConversionPatternRewriter &rewriter) { + APFloat floatVal = floatAttr.getValue(); + APInt intVal = floatVal.bitcastToAPInt(); + return rewriter.getIntegerAttr(dstType, intVal); +} + /// Returns true if the given `type` is a boolean scalar or vector type. static bool isBoolScalarOrVector(Type type) { assert(type && "Not a valid type"); @@ -117,12 +128,12 @@ static Value getScalarOrVectorConstInt(Type type, uint64_t value, if (auto vectorType = dyn_cast<VectorType>(type)) { Attribute element = IntegerAttr::get(vectorType.getElementType(), value); auto attr = SplatElementsAttr::get(vectorType, element); - return builder.create<spirv::ConstantOp>(loc, vectorType, attr); + return spirv::ConstantOp::create(builder, loc, vectorType, attr); } if (auto intType = dyn_cast<IntegerType>(type)) - return builder.create<spirv::ConstantOp>( - loc, type, builder.getIntegerAttr(type, value)); + return spirv::ConstantOp::create(builder, loc, type, + builder.getIntegerAttr(type, value)); return nullptr; } @@ -296,8 +307,18 @@ struct ConstantCompositeOpPattern final SmallVector<Attribute, 8> elements; if (isa<FloatType>(srcElemType)) { for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) { - FloatAttr dstAttr = - convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter); + Attribute dstAttr = nullptr; + // Handle 8-bit float conversion to 8-bit integer. + auto *typeConverter = getTypeConverter<SPIRVTypeConverter>(); + if (typeConverter->getOptions().emulateUnsupportedFloatTypes && + srcElemType.getIntOrFloatBitWidth() == 8 && + isa<IntegerType>(dstElemType)) { + dstAttr = + getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter); + } else { + dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), + rewriter); + } if (!dstAttr) return failure(); elements.push_back(dstAttr); @@ -361,11 +382,19 @@ struct ConstantScalarOpPattern final // Floating-point types. if (isa<FloatType>(srcType)) { auto srcAttr = cast<FloatAttr>(cstAttr); - auto dstAttr = srcAttr; + Attribute dstAttr = srcAttr; // Floating-point types not supported in the target environment are all // converted to float type. - if (srcType != dstType) { + auto *typeConverter = getTypeConverter<SPIRVTypeConverter>(); + if (typeConverter->getOptions().emulateUnsupportedFloatTypes && + srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) && + dstType.getIntOrFloatBitWidth() == 8) { + // If the source is an 8-bit float, convert it to a 8-bit integer. + dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter); + if (!dstAttr) + return failure(); + } else if (srcType != dstType) { dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter); if (!dstAttr) return failure(); @@ -418,18 +447,19 @@ static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, Type type = lhs.getType(); // Calculate the remainder with spirv.UMod. - Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs); - Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs); - Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs); + Value lhsAbs = SignedAbsOp::create(builder, loc, type, lhs); + Value rhsAbs = SignedAbsOp::create(builder, loc, type, rhs); + Value abs = spirv::UModOp::create(builder, loc, lhsAbs, rhsAbs); // Fix the sign. Value isPositive; if (lhs == signOperand) - isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs); + isPositive = spirv::IEqualOp::create(builder, loc, lhs, lhsAbs); else - isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs); - Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs); - return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate); + isPositive = spirv::IEqualOp::create(builder, loc, rhs, rhsAbs); + Value absNegate = spirv::SNegateOp::create(builder, loc, type, abs); + return spirv::SelectOp::create(builder, loc, type, isPositive, abs, + absNegate); } /// Converts arith.remsi to GLSL SPIR-V ops. @@ -601,13 +631,13 @@ struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> { Value allOnes; if (auto intTy = dyn_cast<IntegerType>(dstType)) { unsigned componentBitwidth = intTy.getWidth(); - allOnes = rewriter.create<spirv::ConstantOp>( - loc, intTy, + allOnes = spirv::ConstantOp::create( + rewriter, loc, intTy, rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth))); } else if (auto vectorTy = dyn_cast<VectorType>(dstType)) { unsigned componentBitwidth = vectorTy.getElementTypeBitWidth(); - allOnes = rewriter.create<spirv::ConstantOp>( - loc, vectorTy, + allOnes = spirv::ConstantOp::create( + rewriter, loc, vectorTy, SplatElementsAttr::get(vectorTy, APInt::getAllOnes(componentBitwidth))); } else { @@ -653,8 +683,8 @@ struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> { // First shift left to sequeeze out all leading bits beyond the original // bitwidth. Here we need to use the original source and result type's // bitwidth. - auto shiftLOp = rewriter.create<spirv::ShiftLeftLogicalOp>( - op.getLoc(), dstType, adaptor.getIn(), shiftSize); + auto shiftLOp = spirv::ShiftLeftLogicalOp::create( + rewriter, op.getLoc(), dstType, adaptor.getIn(), shiftSize); // Then we perform arithmetic right shift to make sure we have the right // sign bits for negative values. @@ -757,9 +787,9 @@ struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> { auto srcType = adaptor.getOperands().front().getType(); // Check if (x & 1) == 1. Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); - Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>( - loc, srcType, adaptor.getOperands()[0], mask); - Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask); + Value maskedSrc = spirv::BitwiseAndOp::create( + rewriter, loc, srcType, adaptor.getOperands()[0], mask); + Value isOne = spirv::IEqualOp::create(rewriter, loc, maskedSrc, mask); Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); @@ -914,9 +944,9 @@ public: if (auto vectorType = dyn_cast<VectorType>(dstType)) type = VectorType::get(vectorType.getShape(), type); Value extLhs = - rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs()); + arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getLhs()); Value extRhs = - rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs()); + arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getRhs()); rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs, extRhs); @@ -1067,12 +1097,12 @@ public: replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter); } } else { - Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs()); - Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs()); + Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs()); + Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs()); - replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan); + replace = spirv::LogicalOrOp::create(rewriter, loc, lhsIsNan, rhsIsNan); if (op.getPredicate() == arith::CmpFPredicate::ORD) - replace = rewriter.create<spirv::LogicalNotOp>(loc, replace); + replace = spirv::LogicalNotOp::create(rewriter, loc, replace); } rewriter.replaceOp(op, replace); @@ -1094,17 +1124,17 @@ public: ConversionPatternRewriter &rewriter) const override { Type dstElemTy = adaptor.getLhs().getType(); Location loc = op->getLoc(); - Value result = rewriter.create<spirv::IAddCarryOp>(loc, adaptor.getLhs(), - adaptor.getRhs()); + Value result = spirv::IAddCarryOp::create(rewriter, loc, adaptor.getLhs(), + adaptor.getRhs()); - Value sumResult = rewriter.create<spirv::CompositeExtractOp>( - loc, result, llvm::ArrayRef(0)); - Value carryValue = rewriter.create<spirv::CompositeExtractOp>( - loc, result, llvm::ArrayRef(1)); + Value sumResult = spirv::CompositeExtractOp::create(rewriter, loc, result, + llvm::ArrayRef(0)); + Value carryValue = spirv::CompositeExtractOp::create(rewriter, loc, result, + llvm::ArrayRef(1)); // Convert the carry value to boolean. Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter); - Value carryResult = rewriter.create<spirv::IEqualOp>(loc, carryValue, one); + Value carryResult = spirv::IEqualOp::create(rewriter, loc, carryValue, one); rewriter.replaceOp(op, {sumResult, carryResult}); return success(); @@ -1125,12 +1155,12 @@ public: ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); Value result = - rewriter.create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs()); + SPIRVMulOp::create(rewriter, loc, adaptor.getLhs(), adaptor.getRhs()); - Value low = rewriter.create<spirv::CompositeExtractOp>(loc, result, - llvm::ArrayRef(0)); - Value high = rewriter.create<spirv::CompositeExtractOp>(loc, result, - llvm::ArrayRef(1)); + Value low = spirv::CompositeExtractOp::create(rewriter, loc, result, + llvm::ArrayRef(0)); + Value high = spirv::CompositeExtractOp::create(rewriter, loc, result, + llvm::ArrayRef(1)); rewriter.replaceOp(op, {low, high}); return success(); @@ -1183,20 +1213,20 @@ public: Location loc = op.getLoc(); Value spirvOp = - rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands()); + SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands()); if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) { rewriter.replaceOp(op, spirvOp); return success(); } - Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs()); - Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs()); + Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs()); + Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs()); - Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan, - adaptor.getLhs(), spirvOp); - Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan, - adaptor.getRhs(), select1); + Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan, + adaptor.getLhs(), spirvOp); + Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan, + adaptor.getRhs(), select1); rewriter.replaceOp(op, select2); return success(); @@ -1237,7 +1267,7 @@ public: Location loc = op.getLoc(); Value spirvOp = - rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands()); + SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands()); if (!shouldInsertNanGuards<SPIRVOp>() || bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) { @@ -1245,13 +1275,13 @@ public: return success(); } - Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs()); - Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs()); + Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs()); + Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs()); - Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan, - adaptor.getRhs(), spirvOp); - Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan, - adaptor.getLhs(), select1); + Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan, + adaptor.getRhs(), spirvOp); + Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan, + adaptor.getLhs(), select1); rewriter.replaceOp(op, select2); return success(); @@ -1351,6 +1381,7 @@ struct ConvertArithToSPIRVPass SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); // Use UnrealizedConversionCast as the bridge so that we don't need to pull diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp index 9c6de93..e34b368 100644 --- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp +++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp @@ -12,7 +12,6 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { @@ -41,11 +40,11 @@ public: Value c2d = op.getC(); Location loc = op.getLoc(); Value b1d = - rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, b2d); + vector::ShapeCastOp::create(rewriter, loc, flattenedVectorType, b2d); Value c1d = - rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, c2d); - Value newOp = rewriter.create<SdotOp>(loc, op.getRes().getType(), op.getA(), - b1d, c1d); + vector::ShapeCastOp::create(rewriter, loc, flattenedVectorType, c2d); + Value newOp = SdotOp::create(rewriter, loc, op.getRes().getType(), + op.getA(), b1d, c1d); rewriter.replaceOp(op, {newOp}); return success(); } diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp index 21ea444..8a2e3b63 100644 --- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp +++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp @@ -45,38 +45,38 @@ static Operation *createLoadTileSliceIntrinsic( if (layout == arm_sme::TileSliceLayout::Horizontal) { switch (type) { case arm_sme::ArmSMETileType::ZAB: - return rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1b_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAH: - return rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1h_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAS: - return rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1w_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAD: - return rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1d_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAQ: - return rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1q_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); } } else { switch (type) { case arm_sme::ArmSMETileType::ZAB: - return rewriter.create<arm_sme::aarch64_sme_ld1b_vert>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1b_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAH: - return rewriter.create<arm_sme::aarch64_sme_ld1h_vert>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1h_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAS: - return rewriter.create<arm_sme::aarch64_sme_ld1w_vert>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1w_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAD: - return rewriter.create<arm_sme::aarch64_sme_ld1d_vert>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1d_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAQ: - return rewriter.create<arm_sme::aarch64_sme_ld1q_vert>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1q_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); break; } } @@ -91,38 +91,38 @@ static Operation *createStoreTileSliceIntrinsic( if (layout == arm_sme::TileSliceLayout::Horizontal) { switch (type) { case arm_sme::ArmSMETileType::ZAB: - return rewriter.create<arm_sme::aarch64_sme_st1b_horiz>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1b_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAH: - return rewriter.create<arm_sme::aarch64_sme_st1h_horiz>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1h_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAS: - return rewriter.create<arm_sme::aarch64_sme_st1w_horiz>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1w_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAD: - return rewriter.create<arm_sme::aarch64_sme_st1d_horiz>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1d_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAQ: - return rewriter.create<arm_sme::aarch64_sme_st1q_horiz>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1q_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); } } else { switch (type) { case arm_sme::ArmSMETileType::ZAB: - return rewriter.create<arm_sme::aarch64_sme_st1b_vert>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1b_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAH: - return rewriter.create<arm_sme::aarch64_sme_st1h_vert>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1h_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAS: - return rewriter.create<arm_sme::aarch64_sme_st1w_vert>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1w_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAD: - return rewriter.create<arm_sme::aarch64_sme_st1d_vert>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1d_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAQ: - return rewriter.create<arm_sme::aarch64_sme_st1q_vert>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1q_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); } } llvm_unreachable("unknown type in createStoreTileSliceIntrinsic"); @@ -146,16 +146,16 @@ createAllocaForTile(RewriterBase &rewriter, Location loc, // Move to the first operation in the function. rewriter.setInsertionPointToStart(&func.getBlocks().front()); // Create an alloca matching the tile size of the `tileOp`. - auto vscale = rewriter.create<vector::VectorScaleOp>(loc); + auto vscale = vector::VectorScaleOp::create(rewriter, loc); auto tileElementType = tileOp.getTileType().getElementType(); auto memrefType = MemRefType::get( {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType); unsigned minElements = arm_sme::getSMETileSliceMinNumElts(tileElementType); auto minElementsOp = - rewriter.create<arith::ConstantIndexOp>(loc, minElements); - auto vectorLen = rewriter.create<arith::MulIOp>(loc, vscale, minElementsOp); - auto alloca = rewriter.create<memref::AllocaOp>( - loc, memrefType, ValueRange{vectorLen, vectorLen}); + arith::ConstantIndexOp::create(rewriter, loc, minElements); + auto vectorLen = arith::MulIOp::create(rewriter, loc, vscale, minElementsOp); + auto alloca = memref::AllocaOp::create(rewriter, loc, memrefType, + ValueRange{vectorLen, vectorLen}); return alloca; } @@ -293,10 +293,10 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { Value tileMemory, Value sliceIndex) const { auto llvmType = getTypeConverter()->convertType(tileMemory.getType()); auto descriptor = - rewriter.create<UnrealizedConversionCastOp>(loc, llvmType, tileMemory); - auto zero = rewriter.create<arith::ConstantIntOp>(loc, 0, /*width=*/64); - auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>( - loc, rewriter.getI64Type(), sliceIndex); + UnrealizedConversionCastOp::create(rewriter, loc, llvmType, tileMemory); + auto zero = arith::ConstantIntOp::create(rewriter, loc, 0, /*width=*/64); + auto sliceIndexI64 = arith::IndexCastOp::create( + rewriter, loc, rewriter.getI64Type(), sliceIndex); return getStridedElementPtr( static_cast<ConversionPatternRewriter &>(rewriter), loc, llvm::cast<MemRefType>(tileMemory.getType()), descriptor.getResult(0), @@ -309,28 +309,29 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { arm_sme::ArmSMETileType tileType, VectorType sliceType, IntegerAttr tileId, Value sliceIndex) const { // Cast the slice index to an i32. - auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>( - loc, rewriter.getI32Type(), sliceIndex); + auto sliceIndexI32 = arith::IndexCastOp::create( + rewriter, loc, rewriter.getI32Type(), sliceIndex); // Create an all-true predicate for the slice. auto predicateType = sliceType.clone(rewriter.getI1Type()); - auto allTruePredicate = rewriter.create<arith::ConstantOp>( - loc, DenseElementsAttr::get(predicateType, true)); + auto allTruePredicate = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(predicateType, true)); // Create padding vector (never used due to all-true predicate). - auto padVector = rewriter.create<LLVM::PoisonOp>(loc, sliceType); + auto padVector = LLVM::PoisonOp::create(rewriter, loc, sliceType); // Get a pointer to the current slice. auto slicePtr = getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex); // Read the value of the current slice from ZA. - auto currentTileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>( - loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32); + auto currentTileSlice = arm_sme::aarch64_sme_read_horiz::create( + rewriter, loc, sliceType, padVector, allTruePredicate, tileId, + sliceIndexI32); // Load the new tile slice back from memory into ZA. createLoadTileSliceIntrinsic( rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal, allTruePredicate, slicePtr, tileId, sliceIndexI32); // Store the current tile slice to memory. - auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); - rewriter.create<vector::StoreOp>(loc, currentTileSlice, tileAlloca, - ValueRange{sliceIndex, zero}); + auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + vector::StoreOp::create(rewriter, loc, currentTileSlice, tileAlloca, + ValueRange{sliceIndex, zero}); } /// Emits a full in-place swap of the contents of a tile in ZA and a @@ -341,12 +342,14 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { RewriterBase::InsertionGuard guard(rewriter); // Create an scf.for over all tile slices. auto minNumElts = - rewriter.create<arith::ConstantIndexOp>(loc, sliceType.getDimSize(0)); - auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); - auto upperBound = rewriter.create<arith::MulIOp>( - loc, minNumElts, rewriter.create<vector::VectorScaleOp>(loc)); - auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); - auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step); + arith::ConstantIndexOp::create(rewriter, loc, sliceType.getDimSize(0)); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); + auto upperBound = + arith::MulIOp::create(rewriter, loc, minNumElts, + vector::VectorScaleOp::create(rewriter, loc)); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); + auto forOp = + scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step); // Emit a swap for each tile slice. rewriter.setInsertionPointToStart(forOp.getBody()); auto sliceIndex = forOp.getInductionVar(); @@ -479,8 +482,8 @@ struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> { // // This holds for all tile sizes. int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt()); - rewriter.create<arm_sme::aarch64_sme_zero>( - loc, rewriter.getI32IntegerAttr(zeroMask)); + arm_sme::aarch64_sme_zero::create(rewriter, loc, + rewriter.getI32IntegerAttr(zeroMask)); // Create a placeholder op to preserve dataflow. // Note: Place the `get_tile` op at the start of the block. This ensures @@ -513,8 +516,8 @@ struct LoadTileSliceConversion auto tileSlice = loadTileSliceOp.getTileSliceIndex(); // Cast tile slice to i32 for intrinsic. - auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>( - loc, rewriter.getI32Type(), tileSlice); + auto tileSliceI32 = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI32Type(), tileSlice); // Create all active predicate mask. auto maskOp = loadTileSliceOp.getMask(); @@ -559,8 +562,8 @@ struct StoreTileSliceConversion auto tileSlice = storeTileSliceOp.getTileSliceIndex(); // Cast tile slice to i32 for intrinsic. - auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>( - loc, rewriter.getI32Type(), tileSlice); + auto tileSliceI32 = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI32Type(), tileSlice); auto maskOp = storeTileSliceOp.getMask(); @@ -595,28 +598,29 @@ struct InsertTileSliceConversion auto tileSlice = insertTileSliceOp.getTileSliceIndex(); // Cast tile slice from index to i32 for intrinsic. - auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>( - loc, rewriter.getI32Type(), tileSlice); + auto tileSliceI32 = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI32Type(), tileSlice); // Create all active predicate mask. - auto one = rewriter.create<arith::ConstantOp>( - loc, rewriter.getI1Type(), + auto one = arith::ConstantOp::create( + rewriter, loc, rewriter.getI1Type(), rewriter.getIntegerAttr(rewriter.getI1Type(), 1)); auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(), /*scalableDims=*/{true}); - auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one); + auto allActiveMask = + vector::BroadcastOp::create(rewriter, loc, predTy, one); // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice. switch (insertTileSliceOp.getLayout()) { case arm_sme::TileSliceLayout::Horizontal: - rewriter.create<arm_sme::aarch64_sme_write_horiz>( - loc, tileId, tileSliceI32, allActiveMask, - insertTileSliceOp.getVector()); + arm_sme::aarch64_sme_write_horiz::create(rewriter, loc, tileId, + tileSliceI32, allActiveMask, + insertTileSliceOp.getVector()); break; case arm_sme::TileSliceLayout::Vertical: - rewriter.create<arm_sme::aarch64_sme_write_vert>( - loc, tileId, tileSliceI32, allActiveMask, - insertTileSliceOp.getVector()); + arm_sme::aarch64_sme_write_vert::create(rewriter, loc, tileId, + tileSliceI32, allActiveMask, + insertTileSliceOp.getVector()); break; } @@ -646,16 +650,16 @@ struct ExtractTileSliceConversion // Create an 'all true' predicate for the tile slice. auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type()); - auto allTruePredicate = rewriter.create<arith::ConstantOp>( - loc, DenseElementsAttr::get(predicateType, true)); + auto allTruePredicate = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(predicateType, true)); // Zero destination/fallback for tile slice extraction. - auto zeroVector = rewriter.create<arith::ConstantOp>( - loc, sliceType, rewriter.getZeroAttr(sliceType)); + auto zeroVector = arith::ConstantOp::create( + rewriter, loc, sliceType, rewriter.getZeroAttr(sliceType)); // Cast tile slice from index to i32 for intrinsic. - auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>( - loc, rewriter.getI32Type(), sliceIndex); + auto sliceIndexI32 = arith::IndexCastOp::create( + rewriter, loc, rewriter.getI32Type(), sliceIndex); // Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice. switch (extractTileSlice.getLayout()) { @@ -743,7 +747,7 @@ struct OuterProductOpConversion Value acc = outerProductOp.getAcc(); if (!acc) { // Initalize accumulator with zero. - auto zero = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType); + auto zero = arm_sme::ZeroOp::create(rewriter, loc, resultVectorType); zero.setTileId(tileId); acc = zero; } @@ -754,16 +758,16 @@ struct OuterProductOpConversion if (!lhsMask || !rhsMask) { auto predTy = outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type()); - Value allActiveMask = rewriter.create<arith::ConstantOp>( - loc, DenseElementsAttr::get(predTy, true)); + Value allActiveMask = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(predTy, true)); lhsMask = allActiveMask; rhsMask = allActiveMask; } // Create 'arm_sme.intr.mopa' outer product intrinsic. - rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileId, lhsMask, rhsMask, - outerProductOp.getLhs(), - outerProductOp.getRhs()); + arm_sme::aarch64_sme_mopa::create(rewriter, loc, tileId, lhsMask, rhsMask, + outerProductOp.getLhs(), + outerProductOp.getRhs()); // The outerproduct intrinsics have no result, replace // 'arm_sme.outerproduct' with the input tile to preserve dataflow. @@ -792,7 +796,7 @@ struct OuterProductWideningOpConversion Value acc = op.getAcc(); if (!acc) { // Initalize accumulator with zero. - auto zero = rewriter.create<arm_sme::ZeroOp>(loc, op.getResultType()); + auto zero = arm_sme::ZeroOp::create(rewriter, loc, op.getResultType()); zero.setTileId(tileId); acc = zero; } @@ -801,14 +805,14 @@ struct OuterProductWideningOpConversion Value rhsMask = op.getRhsMask(); if (!lhsMask || !rhsMask) { auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type()); - Value allActiveMask = rewriter.create<arith::ConstantOp>( - loc, DenseElementsAttr::get(predTy, true)); + Value allActiveMask = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(predTy, true)); lhsMask = allActiveMask; rhsMask = allActiveMask; } - rewriter.create<OuterProductWideningIntrOp>( - loc, tileId, lhsMask, rhsMask, adaptor.getLhs(), adaptor.getRhs()); + OuterProductWideningIntrOp::create(rewriter, loc, tileId, lhsMask, rhsMask, + adaptor.getLhs(), adaptor.getRhs()); // The outerproduct intrinsics have no result, replace // 'arm_sme.outerproduct' with the input tile to preserve dataflow. @@ -843,13 +847,13 @@ struct StreamingVLOpConversion auto *intrOp = [&]() -> Operation * { switch (streamingVlOp.getTypeSize()) { case arm_sme::TypeSize::Byte: - return rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type); + return arm_sme::aarch64_sme_cntsb::create(rewriter, loc, i64Type); case arm_sme::TypeSize::Half: - return rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type); + return arm_sme::aarch64_sme_cntsh::create(rewriter, loc, i64Type); case arm_sme::TypeSize::Word: - return rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type); + return arm_sme::aarch64_sme_cntsw::create(rewriter, loc, i64Type); case arm_sme::TypeSize::Double: - return rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type); + return arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type); } llvm_unreachable("unknown type size in StreamingVLOpConversion"); }(); @@ -872,8 +876,8 @@ static void mergeConsecutiveTileZerosInBlock(Block *block) { if (zeroOpsToMerge.size() <= 1) return; IRRewriter rewriter(zeroOpsToMerge.front()); - rewriter.create<arm_sme::aarch64_sme_zero>( - zeroOpsToMerge.front().getLoc(), + arm_sme::aarch64_sme_zero::create( + rewriter, zeroOpsToMerge.front().getLoc(), rewriter.getI32IntegerAttr(mergedZeroMask)); for (auto zeroOp : zeroOpsToMerge) rewriter.eraseOp(zeroOp); diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp index 458628c..e28d5122 100644 --- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp +++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp @@ -39,7 +39,7 @@ SmallVector<Value, 2> getMemrefIndices(ValueRange indices, unsigned rank, auto tileSliceOffset = tileSliceIndex; auto baseIndexPlusTileSliceOffset = - rewriter.create<arith::AddIOp>(loc, indices[0], tileSliceOffset); + arith::AddIOp::create(rewriter, loc, indices[0], tileSliceOffset); outIndices.push_back(baseIndexPlusTileSliceOffset); outIndices.push_back(indices[1]); @@ -59,10 +59,11 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices( if (memrefIndices.size() != 2) return rewriter.notifyMatchFailure(loc, "invalid number of indices"); - auto minTileSlices = rewriter.create<arith::ConstantIndexOp>( - loc, arm_sme::getSMETileSliceMinNumElts(tileType.getElementType())); + auto minTileSlices = arith::ConstantIndexOp::create( + rewriter, loc, + arm_sme::getSMETileSliceMinNumElts(tileType.getElementType())); auto vscale = - rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType()); + vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType()); auto predicateType = VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true); @@ -70,7 +71,7 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices( // elements in a vector of SVL bits for a given element type (SVL_B, // SVL_H, ..., SVL_Q). auto numTileSlices = - rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale); + arith::MulIOp::create(rewriter, loc, minTileSlices, vscale); Value predicate; Value upperBound; @@ -82,30 +83,30 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices( // The upper bound of the loop must be clamped at `numTileSlices` as // `vector.create_mask` allows operands to be greater than the size of a // dimension. - auto numRowI64 = rewriter.create<arith::IndexCastOp>( - loc, rewriter.getI64Type(), maskDim0); - auto numTileSlicesI64 = rewriter.create<arith::IndexCastOp>( - loc, rewriter.getI64Type(), numTileSlices); + auto numRowI64 = arith::IndexCastOp::create( + rewriter, loc, rewriter.getI64Type(), maskDim0); + auto numTileSlicesI64 = arith::IndexCastOp::create( + rewriter, loc, rewriter.getI64Type(), numTileSlices); auto upperBoundI64 = - rewriter.create<arith::MinSIOp>(loc, numRowI64, numTileSlicesI64); - upperBound = rewriter.create<arith::IndexCastOp>( - loc, rewriter.getIndexType(), upperBoundI64); + arith::MinSIOp::create(rewriter, loc, numRowI64, numTileSlicesI64); + upperBound = arith::IndexCastOp::create( + rewriter, loc, rewriter.getIndexType(), upperBoundI64); predicate = - rewriter.create<vector::CreateMaskOp>(loc, predicateType, maskDim1); + vector::CreateMaskOp::create(rewriter, loc, predicateType, maskDim1); } else { upperBound = numTileSlices; // No mask. Create an 'all true' predicate for the tile slice. - predicate = rewriter.create<arith::ConstantOp>( - loc, DenseElementsAttr::get(predicateType, true)); + predicate = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(predicateType, true)); } bool hasCarriedArgs = bool(initTile); - auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); - auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); - auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step, - hasCarriedArgs ? ValueRange{initTile} - : ValueRange{}); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); + auto forOp = + scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step, + hasCarriedArgs ? ValueRange{initTile} : ValueRange{}); rewriter.setInsertionPointToStart(forOp.getBody()); Value tileSliceIndex = forOp.getInductionVar(); @@ -118,7 +119,7 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices( assert(bool(nextTile) == hasCarriedArgs); if (nextTile) - rewriter.create<scf::YieldOp>(loc, nextTile); + scf::YieldOp::create(rewriter, loc, nextTile); return forOp; } @@ -194,9 +195,9 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> { // Initialize tile with zero to satisfy padding. Inactive cols will be // zeroed anyway since the loads use zeroing predication. For inactive // rows however, no load will occur so these need to be zeroed. - initTile = rewriter.create<arm_sme::ZeroOp>(loc, tileType); + initTile = arm_sme::ZeroOp::create(rewriter, loc, tileType); } else { - initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); + initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType); } // Create a loop to load the active tile slices from memory. @@ -207,9 +208,10 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> { Value currentTile) -> Value { // Create 'arm_sme.load_tile_slice' to load tile slice from memory // into tile. - return rewriter.create<arm_sme::LoadTileSliceOp>( - loc, tileType, tileLoadOp.getBase(), predicate, currentTile, - memrefIndices, tileSliceIndex, tileLoadOp.getLayout()); + return arm_sme::LoadTileSliceOp::create( + rewriter, loc, tileType, tileLoadOp.getBase(), predicate, + currentTile, memrefIndices, tileSliceIndex, + tileLoadOp.getLayout()); }); if (failed(forOp)) @@ -283,22 +285,22 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion auto numRows = createMaskOp.getOperands()[0]; auto numCols = createMaskOp.getOperands()[1]; - auto numColsI32 = rewriter.create<arith::IndexCastUIOp>( - loc, rewriter.getI32Type(), numCols); + auto numColsI32 = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI32Type(), numCols); - auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); + auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType); // Create a loop that loads each ZA tile slice from memory. - auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); - auto minTileSlices = rewriter.create<arith::ConstantIndexOp>( - loc, arm_sme::getSMETileSliceMinNumElts(tileElementType)); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); + auto minTileSlices = arith::ConstantIndexOp::create( + rewriter, loc, arm_sme::getSMETileSliceMinNumElts(tileElementType)); auto vscale = - rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType()); - auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); + vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType()); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); auto numTileSlices = - rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale); - auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, - step, ValueRange{initTile}); + arith::MulIOp::create(rewriter, loc, minTileSlices, vscale); + auto forOp = scf::ForOp::create(rewriter, loc, lowerBound, numTileSlices, + step, ValueRange{initTile}); rewriter.setInsertionPointToStart(forOp.getBody()); @@ -306,17 +308,18 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion auto currentTile = forOp.getRegionIterArg(0); // Combine masks. - auto rowIsActive = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows); - auto rowIsActiveI32 = rewriter.create<arith::ExtSIOp>( - loc, rewriter.getI32Type(), rowIsActive); - auto mask = rewriter.create<arith::AndIOp>(loc, rowIsActiveI32, numColsI32); - auto maskIndex = - rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), mask); + auto rowIsActive = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows); + auto rowIsActiveI32 = arith::ExtSIOp::create( + rewriter, loc, rewriter.getI32Type(), rowIsActive); + auto mask = + arith::AndIOp::create(rewriter, loc, rowIsActiveI32, numColsI32); + auto maskIndex = arith::IndexCastOp::create(rewriter, loc, + rewriter.getIndexType(), mask); auto predicateType = VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true); - auto maskOp1D = rewriter.create<vector::CreateMaskOp>( - loc, predicateType, maskIndex.getResult()); + auto maskOp1D = vector::CreateMaskOp::create(rewriter, loc, predicateType, + maskIndex.getResult()); auto memrefIndices = getMemrefIndices( tileLoadOp.getIndices(), tileLoadOp.getMemRefType().getRank(), @@ -324,17 +327,19 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion // Splat pad into 1-D vector matching type of tile slice. VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); - auto pad1DOp = rewriter.create<vector::SplatOp>(loc, tileSliceType, padOp); + auto pad1DOp = + vector::BroadcastOp::create(rewriter, loc, tileSliceType, padOp); - auto loadSlice = rewriter.create<vector::MaskedLoadOp>( - loc, tileSliceType, tileLoadOp.getBase(), memrefIndices, maskOp1D, - /*passthru=*/pad1DOp); + auto loadSlice = vector::MaskedLoadOp::create(rewriter, loc, tileSliceType, + tileLoadOp.getBase(), + memrefIndices, maskOp1D, + /*passthru=*/pad1DOp); // Create 'arm_sme.insert_tile_slice' to insert slice into tile. - auto insertSlice = rewriter.create<arm_sme::InsertTileSliceOp>( - loc, tileType, loadSlice->getResult(0), currentTile, tileSliceIndex, - tileLoadOp.getLayout()); - rewriter.create<scf::YieldOp>(loc, insertSlice.getResult()); + auto insertSlice = arm_sme::InsertTileSliceOp::create( + rewriter, loc, tileType, loadSlice->getResult(0), currentTile, + tileSliceIndex, tileLoadOp.getLayout()); + scf::YieldOp::create(rewriter, loc, insertSlice.getResult()); rewriter.setInsertionPointAfter(forOp); diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp index 94f7caa..29e6552 100644 --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -18,7 +18,6 @@ #include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -203,7 +202,7 @@ static void addAsyncRuntimeApiDeclarations(ModuleOp module) { auto addFuncDecl = [&](StringRef name, FunctionType type) { if (module.lookupSymbol(name)) return; - builder.create<func::FuncOp>(name, type).setPrivate(); + func::FuncOp::create(builder, name, type).setPrivate(); }; MLIRContext *ctx = module.getContext(); @@ -254,15 +253,15 @@ static void addResumeFunction(ModuleOp module) { auto voidTy = LLVM::LLVMVoidType::get(ctx); Type ptrType = AsyncAPI::opaquePointerType(ctx); - auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>( - kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType})); + auto resumeOp = LLVM::LLVMFuncOp::create( + moduleBuilder, kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType})); resumeOp.setPrivate(); auto *block = resumeOp.addEntryBlock(moduleBuilder); auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block); - blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0)); - blockBuilder.create<LLVM::ReturnOp>(ValueRange()); + LLVM::CoroResumeOp::create(blockBuilder, resumeOp.getArgument(0)); + LLVM::ReturnOp::create(blockBuilder, ValueRange()); } //===----------------------------------------------------------------------===// @@ -282,7 +281,8 @@ public: // in patterns for other dialects. auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) -> Value { - auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs); + auto cast = + UnrealizedConversionCastOp::create(builder, loc, type, inputs); return cast.getResult(0); }; @@ -343,8 +343,8 @@ public: // Constants for initializing coroutine frame. auto constZero = - rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0); - auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, ptrType); + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), 0); + auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, ptrType); // Get coroutine id: @llvm.coro.id. rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>( @@ -372,33 +372,33 @@ public: // Get coroutine frame size: @llvm.coro.size.i64. Value coroSize = - rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type()); + LLVM::CoroSizeOp::create(rewriter, loc, rewriter.getI64Type()); // Get coroutine frame alignment: @llvm.coro.align.i64. Value coroAlign = - rewriter.create<LLVM::CoroAlignOp>(loc, rewriter.getI64Type()); + LLVM::CoroAlignOp::create(rewriter, loc, rewriter.getI64Type()); // Round up the size to be multiple of the alignment. Since aligned_alloc // requires the size parameter be an integral multiple of the alignment // parameter. auto makeConstant = [&](uint64_t c) { - return rewriter.create<LLVM::ConstantOp>(op->getLoc(), - rewriter.getI64Type(), c); + return LLVM::ConstantOp::create(rewriter, op->getLoc(), + rewriter.getI64Type(), c); }; - coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize, coroAlign); + coroSize = LLVM::AddOp::create(rewriter, op->getLoc(), coroSize, coroAlign); coroSize = - rewriter.create<LLVM::SubOp>(op->getLoc(), coroSize, makeConstant(1)); + LLVM::SubOp::create(rewriter, op->getLoc(), coroSize, makeConstant(1)); Value negCoroAlign = - rewriter.create<LLVM::SubOp>(op->getLoc(), makeConstant(0), coroAlign); + LLVM::SubOp::create(rewriter, op->getLoc(), makeConstant(0), coroAlign); coroSize = - rewriter.create<LLVM::AndOp>(op->getLoc(), coroSize, negCoroAlign); + LLVM::AndOp::create(rewriter, op->getLoc(), coroSize, negCoroAlign); // Allocate memory for the coroutine frame. auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn( rewriter, op->getParentOfType<ModuleOp>(), rewriter.getI64Type()); if (failed(allocFuncOp)) return failure(); - auto coroAlloc = rewriter.create<LLVM::CallOp>( - loc, allocFuncOp.value(), ValueRange{coroAlign, coroSize}); + auto coroAlloc = LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(), + ValueRange{coroAlign, coroSize}); // Begin a coroutine: @llvm.coro.begin. auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId(); @@ -427,7 +427,7 @@ public: // Get a pointer to the coroutine frame memory: @llvm.coro.free. auto coroMem = - rewriter.create<LLVM::CoroFreeOp>(loc, ptrType, adaptor.getOperands()); + LLVM::CoroFreeOp::create(rewriter, loc, ptrType, adaptor.getOperands()); // Free the memory. auto freeFuncOp = @@ -455,15 +455,15 @@ public: matchAndRewrite(CoroEndOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // We are not in the block that is part of the unwind sequence. - auto constFalse = rewriter.create<LLVM::ConstantOp>( - op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false)); - auto noneToken = rewriter.create<LLVM::NoneTokenOp>(op->getLoc()); + auto constFalse = + LLVM::ConstantOp::create(rewriter, op->getLoc(), rewriter.getI1Type(), + rewriter.getBoolAttr(false)); + auto noneToken = LLVM::NoneTokenOp::create(rewriter, op->getLoc()); // Mark the end of a coroutine: @llvm.coro.end. auto coroHdl = adaptor.getHandle(); - rewriter.create<LLVM::CoroEndOp>( - op->getLoc(), rewriter.getI1Type(), - ValueRange({coroHdl, constFalse, noneToken})); + LLVM::CoroEndOp::create(rewriter, op->getLoc(), rewriter.getI1Type(), + ValueRange({coroHdl, constFalse, noneToken})); rewriter.eraseOp(op); return success(); @@ -534,13 +534,13 @@ public: auto loc = op->getLoc(); // This is not a final suspension point. - auto constFalse = rewriter.create<LLVM::ConstantOp>( - loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); + auto constFalse = LLVM::ConstantOp::create( + rewriter, loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); // Suspend a coroutine: @llvm.coro.suspend auto coroState = adaptor.getState(); - auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>( - loc, i8, ValueRange({coroState, constFalse})); + auto coroSuspend = LLVM::CoroSuspendOp::create( + rewriter, loc, i8, ValueRange({coroState, constFalse})); // Cast return code to i32. @@ -551,7 +551,7 @@ public: llvm::SmallVector<Block *, 2> caseDest = {op.getResumeDest(), op.getCleanupDest()}; rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( - op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()), + op, LLVM::SExtOp::create(rewriter, loc, i32, coroSuspend.getResult()), /*defaultDestination=*/op.getSuspendDest(), /*defaultOperands=*/ValueRange(), /*caseValues=*/caseValues, @@ -602,11 +602,11 @@ public: // %Size = getelementptr %T* null, int 1 // %SizeI = ptrtoint %T* %Size to i64 - auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, storagePtrType); + auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, storagePtrType); auto gep = - rewriter.create<LLVM::GEPOp>(loc, storagePtrType, storedType, - nullPtr, ArrayRef<LLVM::GEPArg>{1}); - return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep); + LLVM::GEPOp::create(rewriter, loc, storagePtrType, storedType, + nullPtr, ArrayRef<LLVM::GEPArg>{1}); + return LLVM::PtrToIntOp::create(rewriter, loc, i64, gep); }; rewriter.replaceOpWithNewOp<func::CallOp>(op, kCreateValue, resultType, @@ -739,8 +739,8 @@ public: .Case<ValueType>([](Type) { return kAwaitValue; }) .Case<GroupType>([](Type) { return kAwaitGroup; }); - rewriter.create<func::CallOp>(op->getLoc(), apiFuncName, TypeRange(), - adaptor.getOperands()); + func::CallOp::create(rewriter, op->getLoc(), apiFuncName, TypeRange(), + adaptor.getOperands()); rewriter.eraseOp(op); return success(); @@ -772,13 +772,12 @@ public: // A pointer to coroutine resume intrinsic wrapper. addResumeFunction(op->getParentOfType<ModuleOp>()); - auto resumePtr = rewriter.create<LLVM::AddressOfOp>( - op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()), - kResume); + auto resumePtr = LLVM::AddressOfOp::create( + rewriter, op->getLoc(), + AsyncAPI::opaquePointerType(rewriter.getContext()), kResume); - rewriter.create<func::CallOp>( - op->getLoc(), apiFuncName, TypeRange(), - ValueRange({operand, handle, resumePtr.getRes()})); + func::CallOp::create(rewriter, op->getLoc(), apiFuncName, TypeRange(), + ValueRange({operand, handle, resumePtr.getRes()})); rewriter.eraseOp(op); return success(); @@ -801,9 +800,9 @@ public: ConversionPatternRewriter &rewriter) const override { // A pointer to coroutine resume intrinsic wrapper. addResumeFunction(op->getParentOfType<ModuleOp>()); - auto resumePtr = rewriter.create<LLVM::AddressOfOp>( - op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()), - kResume); + auto resumePtr = LLVM::AddressOfOp::create( + rewriter, op->getLoc(), + AsyncAPI::opaquePointerType(rewriter.getContext()), kResume); // Call async runtime API to execute a coroutine in the managed thread. auto coroHdl = adaptor.getHandle(); @@ -832,8 +831,8 @@ public: // Get a pointer to the async value storage from the runtime. auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext()); auto storage = adaptor.getStorage(); - auto storagePtr = rewriter.create<func::CallOp>( - loc, kGetValueStorage, TypeRange(ptrType), storage); + auto storagePtr = func::CallOp::create(rewriter, loc, kGetValueStorage, + TypeRange(ptrType), storage); // Cast from i8* to the LLVM pointer type. auto valueType = op.getValue().getType(); @@ -845,7 +844,7 @@ public: Value castedStoragePtr = storagePtr.getResult(0); // Store the yielded value into the async value storage. auto value = adaptor.getValue(); - rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr); + LLVM::StoreOp::create(rewriter, loc, value, castedStoragePtr); // Erase the original runtime store operation. rewriter.eraseOp(op); @@ -872,8 +871,8 @@ public: // Get a pointer to the async value storage from the runtime. auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext()); auto storage = adaptor.getStorage(); - auto storagePtr = rewriter.create<func::CallOp>( - loc, kGetValueStorage, TypeRange(ptrType), storage); + auto storagePtr = func::CallOp::create(rewriter, loc, kGetValueStorage, + TypeRange(ptrType), storage); // Cast from i8* to the LLVM pointer type. auto valueType = op.getResult().getType(); @@ -960,9 +959,9 @@ public: LogicalResult matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto count = rewriter.create<arith::ConstantOp>( - op->getLoc(), rewriter.getI64Type(), - rewriter.getI64IntegerAttr(op.getCount())); + auto count = + arith::ConstantOp::create(rewriter, op->getLoc(), rewriter.getI64Type(), + rewriter.getI64IntegerAttr(op.getCount())); auto operand = adaptor.getOperand(); rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), apiFunctionName, diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp index b9991f3..3edcbb8 100644 --- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp +++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp @@ -47,30 +47,29 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> { if (auto unrankedType = dyn_cast<UnrankedMemRefType>(type)) { // Constants - Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); - Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value one = arith::ConstantIndexOp::create(rewriter, loc, 1); // Dynamically evaluate the size and shape of the unranked memref - Value rank = rewriter.create<memref::RankOp>(loc, op.getInput()); + Value rank = memref::RankOp::create(rewriter, loc, op.getInput()); MemRefType allocType = MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()); - Value shape = rewriter.create<memref::AllocaOp>(loc, allocType, rank); + Value shape = memref::AllocaOp::create(rewriter, loc, allocType, rank); // Create a loop to query dimension sizes, store them as a shape, and // compute the total size of the memref auto loopBody = [&](OpBuilder &builder, Location loc, Value i, ValueRange args) { auto acc = args.front(); - auto dim = rewriter.create<memref::DimOp>(loc, op.getInput(), i); + auto dim = memref::DimOp::create(rewriter, loc, op.getInput(), i); - rewriter.create<memref::StoreOp>(loc, dim, shape, i); - acc = rewriter.create<arith::MulIOp>(loc, acc, dim); + memref::StoreOp::create(rewriter, loc, dim, shape, i); + acc = arith::MulIOp::create(rewriter, loc, acc, dim); - rewriter.create<scf::YieldOp>(loc, acc); + scf::YieldOp::create(rewriter, loc, acc); }; - auto size = rewriter - .create<scf::ForOp>(loc, zero, rank, one, ValueRange(one), - loopBody) + auto size = scf::ForOp::create(rewriter, loc, zero, rank, one, + ValueRange(one), loopBody) .getResult(0); MemRefType memrefType = MemRefType::get({ShapedType::kDynamic}, @@ -78,9 +77,9 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> { // Allocate new memref with 1D dynamic shape, then reshape into the // shape of the original unranked memref - alloc = rewriter.create<memref::AllocOp>(loc, memrefType, size); + alloc = memref::AllocOp::create(rewriter, loc, memrefType, size); alloc = - rewriter.create<memref::ReshapeOp>(loc, unrankedType, alloc, shape); + memref::ReshapeOp::create(rewriter, loc, unrankedType, alloc, shape); } else { MemRefType memrefType = cast<MemRefType>(type); MemRefLayoutAttrInterface layout; @@ -103,14 +102,15 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> { } // Allocate a memref with identity layout. - alloc = rewriter.create<memref::AllocOp>(loc, allocType, dynamicOperands); + alloc = + memref::AllocOp::create(rewriter, loc, allocType, dynamicOperands); // Cast the allocation to the specified type if needed. if (memrefType != allocType) alloc = - rewriter.create<memref::CastOp>(op->getLoc(), memrefType, alloc); + memref::CastOp::create(rewriter, op->getLoc(), memrefType, alloc); } - rewriter.create<memref::CopyOp>(loc, op.getInput(), alloc); + memref::CopyOp::create(rewriter, loc, op.getInput(), alloc); rewriter.replaceOp(op, alloc); return success(); } diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index f84375b..785cb82 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -43,7 +43,7 @@ add_subdirectory(MathToSPIRV) add_subdirectory(MemRefToEmitC) add_subdirectory(MemRefToLLVM) add_subdirectory(MemRefToSPIRV) -add_subdirectory(MeshToMPI) +add_subdirectory(ShardToMPI) add_subdirectory(MPIToLLVM) add_subdirectory(NVGPUToNVVM) add_subdirectory(NVVMToLLVM) diff --git a/mlir/lib/Conversion/ComplexCommon/DivisionConverter.cpp b/mlir/lib/Conversion/ComplexCommon/DivisionConverter.cpp index 70b2238..14fbb9b 100644 --- a/mlir/lib/Conversion/ComplexCommon/DivisionConverter.cpp +++ b/mlir/lib/Conversion/ComplexCommon/DivisionConverter.cpp @@ -23,41 +23,43 @@ void mlir::complex::convertDivToLLVMUsingAlgebraic( ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, LLVM::FastmathFlagsAttr fmf, Value *resultRe, Value *resultIm) { - Value rhsSqNorm = rewriter.create<LLVM::FAddOp>( - loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf), - rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf); + Value rhsSqNorm = LLVM::FAddOp::create( + rewriter, loc, LLVM::FMulOp::create(rewriter, loc, rhsRe, rhsRe, fmf), + LLVM::FMulOp::create(rewriter, loc, rhsIm, rhsIm, fmf), fmf); - Value realNumerator = rewriter.create<LLVM::FAddOp>( - loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf), - rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf); + Value realNumerator = LLVM::FAddOp::create( + rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsRe, fmf), + LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsIm, fmf), fmf); - Value imagNumerator = rewriter.create<LLVM::FSubOp>( - loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf), - rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf); + Value imagNumerator = LLVM::FSubOp::create( + rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf), + LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf); - *resultRe = rewriter.create<LLVM::FDivOp>(loc, realNumerator, rhsSqNorm, fmf); - *resultIm = rewriter.create<LLVM::FDivOp>(loc, imagNumerator, rhsSqNorm, fmf); + *resultRe = + LLVM::FDivOp::create(rewriter, loc, realNumerator, rhsSqNorm, fmf); + *resultIm = + LLVM::FDivOp::create(rewriter, loc, imagNumerator, rhsSqNorm, fmf); } void mlir::complex::convertDivToStandardUsingAlgebraic( ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe, Value *resultIm) { - Value rhsSqNorm = rewriter.create<arith::AddFOp>( - loc, rewriter.create<arith::MulFOp>(loc, rhsRe, rhsRe, fmf), - rewriter.create<arith::MulFOp>(loc, rhsIm, rhsIm, fmf), fmf); + Value rhsSqNorm = arith::AddFOp::create( + rewriter, loc, arith::MulFOp::create(rewriter, loc, rhsRe, rhsRe, fmf), + arith::MulFOp::create(rewriter, loc, rhsIm, rhsIm, fmf), fmf); - Value realNumerator = rewriter.create<arith::AddFOp>( - loc, rewriter.create<arith::MulFOp>(loc, lhsRe, rhsRe, fmf), - rewriter.create<arith::MulFOp>(loc, lhsIm, rhsIm, fmf), fmf); - Value imagNumerator = rewriter.create<arith::SubFOp>( - loc, rewriter.create<arith::MulFOp>(loc, lhsIm, rhsRe, fmf), - rewriter.create<arith::MulFOp>(loc, lhsRe, rhsIm, fmf), fmf); + Value realNumerator = arith::AddFOp::create( + rewriter, loc, arith::MulFOp::create(rewriter, loc, lhsRe, rhsRe, fmf), + arith::MulFOp::create(rewriter, loc, lhsIm, rhsIm, fmf), fmf); + Value imagNumerator = arith::SubFOp::create( + rewriter, loc, arith::MulFOp::create(rewriter, loc, lhsIm, rhsRe, fmf), + arith::MulFOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf); *resultRe = - rewriter.create<arith::DivFOp>(loc, realNumerator, rhsSqNorm, fmf); + arith::DivFOp::create(rewriter, loc, realNumerator, rhsSqNorm, fmf); *resultIm = - rewriter.create<arith::DivFOp>(loc, imagNumerator, rhsSqNorm, fmf); + arith::DivFOp::create(rewriter, loc, imagNumerator, rhsSqNorm, fmf); } // Smith's algorithm to divide complex numbers. It is just a bit smarter @@ -94,181 +96,185 @@ void mlir::complex::convertDivToLLVMUsingRangeReduction( auto elementType = cast<FloatType>(rhsRe.getType()); Value rhsRealImagRatio = - rewriter.create<LLVM::FDivOp>(loc, rhsRe, rhsIm, fmf); - Value rhsRealImagDenom = rewriter.create<LLVM::FAddOp>( - loc, rhsIm, - rewriter.create<LLVM::FMulOp>(loc, rhsRealImagRatio, rhsRe, fmf), fmf); - Value realNumerator1 = rewriter.create<LLVM::FAddOp>( - loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRealImagRatio, fmf), - lhsIm, fmf); - Value resultReal1 = - rewriter.create<LLVM::FDivOp>(loc, realNumerator1, rhsRealImagDenom, fmf); - Value imagNumerator1 = rewriter.create<LLVM::FSubOp>( - loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRealImagRatio, fmf), - lhsRe, fmf); - Value resultImag1 = - rewriter.create<LLVM::FDivOp>(loc, imagNumerator1, rhsRealImagDenom, fmf); + LLVM::FDivOp::create(rewriter, loc, rhsRe, rhsIm, fmf); + Value rhsRealImagDenom = LLVM::FAddOp::create( + rewriter, loc, rhsIm, + LLVM::FMulOp::create(rewriter, loc, rhsRealImagRatio, rhsRe, fmf), fmf); + Value realNumerator1 = LLVM::FAddOp::create( + rewriter, loc, + LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsRealImagRatio, fmf), lhsIm, + fmf); + Value resultReal1 = LLVM::FDivOp::create(rewriter, loc, realNumerator1, + rhsRealImagDenom, fmf); + Value imagNumerator1 = LLVM::FSubOp::create( + rewriter, loc, + LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRealImagRatio, fmf), lhsRe, + fmf); + Value resultImag1 = LLVM::FDivOp::create(rewriter, loc, imagNumerator1, + rhsRealImagDenom, fmf); Value rhsImagRealRatio = - rewriter.create<LLVM::FDivOp>(loc, rhsIm, rhsRe, fmf); - Value rhsImagRealDenom = rewriter.create<LLVM::FAddOp>( - loc, rhsRe, - rewriter.create<LLVM::FMulOp>(loc, rhsImagRealRatio, rhsIm, fmf), fmf); - Value realNumerator2 = rewriter.create<LLVM::FAddOp>( - loc, lhsRe, - rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsImagRealRatio, fmf), fmf); - Value resultReal2 = - rewriter.create<LLVM::FDivOp>(loc, realNumerator2, rhsImagRealDenom, fmf); - Value imagNumerator2 = rewriter.create<LLVM::FSubOp>( - loc, lhsIm, - rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsImagRealRatio, fmf), fmf); - Value resultImag2 = - rewriter.create<LLVM::FDivOp>(loc, imagNumerator2, rhsImagRealDenom, fmf); + LLVM::FDivOp::create(rewriter, loc, rhsIm, rhsRe, fmf); + Value rhsImagRealDenom = LLVM::FAddOp::create( + rewriter, loc, rhsRe, + LLVM::FMulOp::create(rewriter, loc, rhsImagRealRatio, rhsIm, fmf), fmf); + Value realNumerator2 = LLVM::FAddOp::create( + rewriter, loc, lhsRe, + LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsImagRealRatio, fmf), fmf); + Value resultReal2 = LLVM::FDivOp::create(rewriter, loc, realNumerator2, + rhsImagRealDenom, fmf); + Value imagNumerator2 = LLVM::FSubOp::create( + rewriter, loc, lhsIm, + LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsImagRealRatio, fmf), fmf); + Value resultImag2 = LLVM::FDivOp::create(rewriter, loc, imagNumerator2, + rhsImagRealDenom, fmf); // Consider corner cases. // Case 1. Zero denominator, numerator contains at most one NaN value. - Value zero = rewriter.create<LLVM::ConstantOp>( - loc, elementType, rewriter.getZeroAttr(elementType)); - Value rhsRealAbs = rewriter.create<LLVM::FAbsOp>(loc, rhsRe, fmf); - Value rhsRealIsZero = rewriter.create<LLVM::FCmpOp>( - loc, LLVM::FCmpPredicate::oeq, rhsRealAbs, zero); - Value rhsImagAbs = rewriter.create<LLVM::FAbsOp>(loc, rhsIm, fmf); - Value rhsImagIsZero = rewriter.create<LLVM::FCmpOp>( - loc, LLVM::FCmpPredicate::oeq, rhsImagAbs, zero); - Value lhsRealIsNotNaN = - rewriter.create<LLVM::FCmpOp>(loc, LLVM::FCmpPredicate::ord, lhsRe, zero); - Value lhsImagIsNotNaN = - rewriter.create<LLVM::FCmpOp>(loc, LLVM::FCmpPredicate::ord, lhsIm, zero); + Value zero = LLVM::ConstantOp::create(rewriter, loc, elementType, + rewriter.getZeroAttr(elementType)); + Value rhsRealAbs = LLVM::FAbsOp::create(rewriter, loc, rhsRe, fmf); + Value rhsRealIsZero = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::oeq, rhsRealAbs, zero); + Value rhsImagAbs = LLVM::FAbsOp::create(rewriter, loc, rhsIm, fmf); + Value rhsImagIsZero = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::oeq, rhsImagAbs, zero); + Value lhsRealIsNotNaN = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::ord, lhsRe, zero); + Value lhsImagIsNotNaN = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::ord, lhsIm, zero); Value lhsContainsNotNaNValue = - rewriter.create<LLVM::OrOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN); - Value resultIsInfinity = rewriter.create<LLVM::AndOp>( - loc, lhsContainsNotNaNValue, - rewriter.create<LLVM::AndOp>(loc, rhsRealIsZero, rhsImagIsZero)); - Value inf = rewriter.create<LLVM::ConstantOp>( - loc, elementType, + LLVM::OrOp::create(rewriter, loc, lhsRealIsNotNaN, lhsImagIsNotNaN); + Value resultIsInfinity = LLVM::AndOp::create( + rewriter, loc, lhsContainsNotNaNValue, + LLVM::AndOp::create(rewriter, loc, rhsRealIsZero, rhsImagIsZero)); + Value inf = LLVM::ConstantOp::create( + rewriter, loc, elementType, rewriter.getFloatAttr(elementType, APFloat::getInf(elementType.getFloatSemantics()))); Value infWithSignOfrhsReal = - rewriter.create<LLVM::CopySignOp>(loc, inf, rhsRe); + LLVM::CopySignOp::create(rewriter, loc, inf, rhsRe); Value infinityResultReal = - rewriter.create<LLVM::FMulOp>(loc, infWithSignOfrhsReal, lhsRe, fmf); + LLVM::FMulOp::create(rewriter, loc, infWithSignOfrhsReal, lhsRe, fmf); Value infinityResultImag = - rewriter.create<LLVM::FMulOp>(loc, infWithSignOfrhsReal, lhsIm, fmf); + LLVM::FMulOp::create(rewriter, loc, infWithSignOfrhsReal, lhsIm, fmf); // Case 2. Infinite numerator, finite denominator. - Value rhsRealFinite = rewriter.create<LLVM::FCmpOp>( - loc, LLVM::FCmpPredicate::one, rhsRealAbs, inf); - Value rhsImagFinite = rewriter.create<LLVM::FCmpOp>( - loc, LLVM::FCmpPredicate::one, rhsImagAbs, inf); + Value rhsRealFinite = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::one, rhsRealAbs, inf); + Value rhsImagFinite = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::one, rhsImagAbs, inf); Value rhsFinite = - rewriter.create<LLVM::AndOp>(loc, rhsRealFinite, rhsImagFinite); - Value lhsRealAbs = rewriter.create<LLVM::FAbsOp>(loc, lhsRe, fmf); - Value lhsRealInfinite = rewriter.create<LLVM::FCmpOp>( - loc, LLVM::FCmpPredicate::oeq, lhsRealAbs, inf); - Value lhsImagAbs = rewriter.create<LLVM::FAbsOp>(loc, lhsIm, fmf); - Value lhsImagInfinite = rewriter.create<LLVM::FCmpOp>( - loc, LLVM::FCmpPredicate::oeq, lhsImagAbs, inf); + LLVM::AndOp::create(rewriter, loc, rhsRealFinite, rhsImagFinite); + Value lhsRealAbs = LLVM::FAbsOp::create(rewriter, loc, lhsRe, fmf); + Value lhsRealInfinite = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::oeq, lhsRealAbs, inf); + Value lhsImagAbs = LLVM::FAbsOp::create(rewriter, loc, lhsIm, fmf); + Value lhsImagInfinite = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::oeq, lhsImagAbs, inf); Value lhsInfinite = - rewriter.create<LLVM::OrOp>(loc, lhsRealInfinite, lhsImagInfinite); + LLVM::OrOp::create(rewriter, loc, lhsRealInfinite, lhsImagInfinite); Value infNumFiniteDenom = - rewriter.create<LLVM::AndOp>(loc, lhsInfinite, rhsFinite); - Value one = rewriter.create<LLVM::ConstantOp>( - loc, elementType, rewriter.getFloatAttr(elementType, 1)); - Value lhsRealIsInfWithSign = rewriter.create<LLVM::CopySignOp>( - loc, rewriter.create<LLVM::SelectOp>(loc, lhsRealInfinite, one, zero), - lhsRe); - Value lhsImagIsInfWithSign = rewriter.create<LLVM::CopySignOp>( - loc, rewriter.create<LLVM::SelectOp>(loc, lhsImagInfinite, one, zero), - lhsIm); + LLVM::AndOp::create(rewriter, loc, lhsInfinite, rhsFinite); + Value one = LLVM::ConstantOp::create(rewriter, loc, elementType, + rewriter.getFloatAttr(elementType, 1)); + Value lhsRealIsInfWithSign = LLVM::CopySignOp::create( + rewriter, loc, + LLVM::SelectOp::create(rewriter, loc, lhsRealInfinite, one, zero), lhsRe); + Value lhsImagIsInfWithSign = LLVM::CopySignOp::create( + rewriter, loc, + LLVM::SelectOp::create(rewriter, loc, lhsImagInfinite, one, zero), lhsIm); Value lhsRealIsInfWithSignTimesrhsReal = - rewriter.create<LLVM::FMulOp>(loc, lhsRealIsInfWithSign, rhsRe, fmf); + LLVM::FMulOp::create(rewriter, loc, lhsRealIsInfWithSign, rhsRe, fmf); Value lhsImagIsInfWithSignTimesrhsImag = - rewriter.create<LLVM::FMulOp>(loc, lhsImagIsInfWithSign, rhsIm, fmf); - Value resultReal3 = rewriter.create<LLVM::FMulOp>( - loc, inf, - rewriter.create<LLVM::FAddOp>(loc, lhsRealIsInfWithSignTimesrhsReal, - lhsImagIsInfWithSignTimesrhsImag, fmf), + LLVM::FMulOp::create(rewriter, loc, lhsImagIsInfWithSign, rhsIm, fmf); + Value resultReal3 = LLVM::FMulOp::create( + rewriter, loc, inf, + LLVM::FAddOp::create(rewriter, loc, lhsRealIsInfWithSignTimesrhsReal, + lhsImagIsInfWithSignTimesrhsImag, fmf), fmf); Value lhsRealIsInfWithSignTimesrhsImag = - rewriter.create<LLVM::FMulOp>(loc, lhsRealIsInfWithSign, rhsIm, fmf); + LLVM::FMulOp::create(rewriter, loc, lhsRealIsInfWithSign, rhsIm, fmf); Value lhsImagIsInfWithSignTimesrhsReal = - rewriter.create<LLVM::FMulOp>(loc, lhsImagIsInfWithSign, rhsRe, fmf); - Value resultImag3 = rewriter.create<LLVM::FMulOp>( - loc, inf, - rewriter.create<LLVM::FSubOp>(loc, lhsImagIsInfWithSignTimesrhsReal, - lhsRealIsInfWithSignTimesrhsImag, fmf), + LLVM::FMulOp::create(rewriter, loc, lhsImagIsInfWithSign, rhsRe, fmf); + Value resultImag3 = LLVM::FMulOp::create( + rewriter, loc, inf, + LLVM::FSubOp::create(rewriter, loc, lhsImagIsInfWithSignTimesrhsReal, + lhsRealIsInfWithSignTimesrhsImag, fmf), fmf); // Case 3: Finite numerator, infinite denominator. - Value lhsRealFinite = rewriter.create<LLVM::FCmpOp>( - loc, LLVM::FCmpPredicate::one, lhsRealAbs, inf); - Value lhsImagFinite = rewriter.create<LLVM::FCmpOp>( - loc, LLVM::FCmpPredicate::one, lhsImagAbs, inf); + Value lhsRealFinite = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::one, lhsRealAbs, inf); + Value lhsImagFinite = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::one, lhsImagAbs, inf); Value lhsFinite = - rewriter.create<LLVM::AndOp>(loc, lhsRealFinite, lhsImagFinite); - Value rhsRealInfinite = rewriter.create<LLVM::FCmpOp>( - loc, LLVM::FCmpPredicate::oeq, rhsRealAbs, inf); - Value rhsImagInfinite = rewriter.create<LLVM::FCmpOp>( - loc, LLVM::FCmpPredicate::oeq, rhsImagAbs, inf); + LLVM::AndOp::create(rewriter, loc, lhsRealFinite, lhsImagFinite); + Value rhsRealInfinite = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::oeq, rhsRealAbs, inf); + Value rhsImagInfinite = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::oeq, rhsImagAbs, inf); Value rhsInfinite = - rewriter.create<LLVM::OrOp>(loc, rhsRealInfinite, rhsImagInfinite); + LLVM::OrOp::create(rewriter, loc, rhsRealInfinite, rhsImagInfinite); Value finiteNumInfiniteDenom = - rewriter.create<LLVM::AndOp>(loc, lhsFinite, rhsInfinite); - Value rhsRealIsInfWithSign = rewriter.create<LLVM::CopySignOp>( - loc, rewriter.create<LLVM::SelectOp>(loc, rhsRealInfinite, one, zero), - rhsRe); - Value rhsImagIsInfWithSign = rewriter.create<LLVM::CopySignOp>( - loc, rewriter.create<LLVM::SelectOp>(loc, rhsImagInfinite, one, zero), - rhsIm); + LLVM::AndOp::create(rewriter, loc, lhsFinite, rhsInfinite); + Value rhsRealIsInfWithSign = LLVM::CopySignOp::create( + rewriter, loc, + LLVM::SelectOp::create(rewriter, loc, rhsRealInfinite, one, zero), rhsRe); + Value rhsImagIsInfWithSign = LLVM::CopySignOp::create( + rewriter, loc, + LLVM::SelectOp::create(rewriter, loc, rhsImagInfinite, one, zero), rhsIm); Value rhsRealIsInfWithSignTimeslhsReal = - rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRealIsInfWithSign, fmf); + LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsRealIsInfWithSign, fmf); Value rhsImagIsInfWithSignTimeslhsImag = - rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsImagIsInfWithSign, fmf); - Value resultReal4 = rewriter.create<LLVM::FMulOp>( - loc, zero, - rewriter.create<LLVM::FAddOp>(loc, rhsRealIsInfWithSignTimeslhsReal, - rhsImagIsInfWithSignTimeslhsImag, fmf), + LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsImagIsInfWithSign, fmf); + Value resultReal4 = LLVM::FMulOp::create( + rewriter, loc, zero, + LLVM::FAddOp::create(rewriter, loc, rhsRealIsInfWithSignTimeslhsReal, + rhsImagIsInfWithSignTimeslhsImag, fmf), fmf); Value rhsRealIsInfWithSignTimeslhsImag = - rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRealIsInfWithSign, fmf); + LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRealIsInfWithSign, fmf); Value rhsImagIsInfWithSignTimeslhsReal = - rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsImagIsInfWithSign, fmf); - Value resultImag4 = rewriter.create<LLVM::FMulOp>( - loc, zero, - rewriter.create<LLVM::FSubOp>(loc, rhsRealIsInfWithSignTimeslhsImag, - rhsImagIsInfWithSignTimeslhsReal, fmf), + LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsImagIsInfWithSign, fmf); + Value resultImag4 = LLVM::FMulOp::create( + rewriter, loc, zero, + LLVM::FSubOp::create(rewriter, loc, rhsRealIsInfWithSignTimeslhsImag, + rhsImagIsInfWithSignTimeslhsReal, fmf), fmf); - Value realAbsSmallerThanImagAbs = rewriter.create<LLVM::FCmpOp>( - loc, LLVM::FCmpPredicate::olt, rhsRealAbs, rhsImagAbs); - Value resultReal5 = rewriter.create<LLVM::SelectOp>( - loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2); - Value resultImag5 = rewriter.create<LLVM::SelectOp>( - loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2); - Value resultRealSpecialCase3 = rewriter.create<LLVM::SelectOp>( - loc, finiteNumInfiniteDenom, resultReal4, resultReal5); - Value resultImagSpecialCase3 = rewriter.create<LLVM::SelectOp>( - loc, finiteNumInfiniteDenom, resultImag4, resultImag5); - Value resultRealSpecialCase2 = rewriter.create<LLVM::SelectOp>( - loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3); - Value resultImagSpecialCase2 = rewriter.create<LLVM::SelectOp>( - loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3); - Value resultRealSpecialCase1 = rewriter.create<LLVM::SelectOp>( - loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2); - Value resultImagSpecialCase1 = rewriter.create<LLVM::SelectOp>( - loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2); + Value realAbsSmallerThanImagAbs = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::olt, rhsRealAbs, rhsImagAbs); + Value resultReal5 = LLVM::SelectOp::create( + rewriter, loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2); + Value resultImag5 = LLVM::SelectOp::create( + rewriter, loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2); + Value resultRealSpecialCase3 = LLVM::SelectOp::create( + rewriter, loc, finiteNumInfiniteDenom, resultReal4, resultReal5); + Value resultImagSpecialCase3 = LLVM::SelectOp::create( + rewriter, loc, finiteNumInfiniteDenom, resultImag4, resultImag5); + Value resultRealSpecialCase2 = LLVM::SelectOp::create( + rewriter, loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3); + Value resultImagSpecialCase2 = LLVM::SelectOp::create( + rewriter, loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3); + Value resultRealSpecialCase1 = + LLVM::SelectOp::create(rewriter, loc, resultIsInfinity, + infinityResultReal, resultRealSpecialCase2); + Value resultImagSpecialCase1 = + LLVM::SelectOp::create(rewriter, loc, resultIsInfinity, + infinityResultImag, resultImagSpecialCase2); - Value resultRealIsNaN = rewriter.create<LLVM::FCmpOp>( - loc, LLVM::FCmpPredicate::uno, resultReal5, zero); - Value resultImagIsNaN = rewriter.create<LLVM::FCmpOp>( - loc, LLVM::FCmpPredicate::uno, resultImag5, zero); + Value resultRealIsNaN = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::uno, resultReal5, zero); + Value resultImagIsNaN = LLVM::FCmpOp::create( + rewriter, loc, LLVM::FCmpPredicate::uno, resultImag5, zero); Value resultIsNaN = - rewriter.create<LLVM::AndOp>(loc, resultRealIsNaN, resultImagIsNaN); + LLVM::AndOp::create(rewriter, loc, resultRealIsNaN, resultImagIsNaN); - *resultRe = rewriter.create<LLVM::SelectOp>( - loc, resultIsNaN, resultRealSpecialCase1, resultReal5); - *resultIm = rewriter.create<LLVM::SelectOp>( - loc, resultIsNaN, resultImagSpecialCase1, resultImag5); + *resultRe = LLVM::SelectOp::create(rewriter, loc, resultIsNaN, + resultRealSpecialCase1, resultReal5); + *resultIm = LLVM::SelectOp::create(rewriter, loc, resultIsNaN, + resultImagSpecialCase1, resultImag5); } void mlir::complex::convertDivToStandardUsingRangeReduction( @@ -278,179 +284,187 @@ void mlir::complex::convertDivToStandardUsingRangeReduction( auto elementType = cast<FloatType>(rhsRe.getType()); Value rhsRealImagRatio = - rewriter.create<arith::DivFOp>(loc, rhsRe, rhsIm, fmf); - Value rhsRealImagDenom = rewriter.create<arith::AddFOp>( - loc, rhsIm, - rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsRe, fmf), fmf); - Value realNumerator1 = rewriter.create<arith::AddFOp>( - loc, rewriter.create<arith::MulFOp>(loc, lhsRe, rhsRealImagRatio, fmf), - lhsIm, fmf); - Value resultReal1 = rewriter.create<arith::DivFOp>(loc, realNumerator1, - rhsRealImagDenom, fmf); - Value imagNumerator1 = rewriter.create<arith::SubFOp>( - loc, rewriter.create<arith::MulFOp>(loc, lhsIm, rhsRealImagRatio, fmf), - lhsRe, fmf); - Value resultImag1 = rewriter.create<arith::DivFOp>(loc, imagNumerator1, - rhsRealImagDenom, fmf); + arith::DivFOp::create(rewriter, loc, rhsRe, rhsIm, fmf); + Value rhsRealImagDenom = arith::AddFOp::create( + rewriter, loc, rhsIm, + arith::MulFOp::create(rewriter, loc, rhsRealImagRatio, rhsRe, fmf), fmf); + Value realNumerator1 = arith::AddFOp::create( + rewriter, loc, + arith::MulFOp::create(rewriter, loc, lhsRe, rhsRealImagRatio, fmf), lhsIm, + fmf); + Value resultReal1 = arith::DivFOp::create(rewriter, loc, realNumerator1, + rhsRealImagDenom, fmf); + Value imagNumerator1 = arith::SubFOp::create( + rewriter, loc, + arith::MulFOp::create(rewriter, loc, lhsIm, rhsRealImagRatio, fmf), lhsRe, + fmf); + Value resultImag1 = arith::DivFOp::create(rewriter, loc, imagNumerator1, + rhsRealImagDenom, fmf); Value rhsImagRealRatio = - rewriter.create<arith::DivFOp>(loc, rhsIm, rhsRe, fmf); - Value rhsImagRealDenom = rewriter.create<arith::AddFOp>( - loc, rhsRe, - rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsIm, fmf), fmf); - Value realNumerator2 = rewriter.create<arith::AddFOp>( - loc, lhsRe, - rewriter.create<arith::MulFOp>(loc, lhsIm, rhsImagRealRatio, fmf), fmf); - Value resultReal2 = rewriter.create<arith::DivFOp>(loc, realNumerator2, - rhsImagRealDenom, fmf); - Value imagNumerator2 = rewriter.create<arith::SubFOp>( - loc, lhsIm, - rewriter.create<arith::MulFOp>(loc, lhsRe, rhsImagRealRatio, fmf), fmf); - Value resultImag2 = rewriter.create<arith::DivFOp>(loc, imagNumerator2, - rhsImagRealDenom, fmf); + arith::DivFOp::create(rewriter, loc, rhsIm, rhsRe, fmf); + Value rhsImagRealDenom = arith::AddFOp::create( + rewriter, loc, rhsRe, + arith::MulFOp::create(rewriter, loc, rhsImagRealRatio, rhsIm, fmf), fmf); + Value realNumerator2 = arith::AddFOp::create( + rewriter, loc, lhsRe, + arith::MulFOp::create(rewriter, loc, lhsIm, rhsImagRealRatio, fmf), fmf); + Value resultReal2 = arith::DivFOp::create(rewriter, loc, realNumerator2, + rhsImagRealDenom, fmf); + Value imagNumerator2 = arith::SubFOp::create( + rewriter, loc, lhsIm, + arith::MulFOp::create(rewriter, loc, lhsRe, rhsImagRealRatio, fmf), fmf); + Value resultImag2 = arith::DivFOp::create(rewriter, loc, imagNumerator2, + rhsImagRealDenom, fmf); // Consider corner cases. // Case 1. Zero denominator, numerator contains at most one NaN value. - Value zero = rewriter.create<arith::ConstantOp>( - loc, elementType, rewriter.getZeroAttr(elementType)); - Value rhsRealAbs = rewriter.create<math::AbsFOp>(loc, rhsRe, fmf); - Value rhsRealIsZero = rewriter.create<arith::CmpFOp>( - loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero); - Value rhsImagAbs = rewriter.create<math::AbsFOp>(loc, rhsIm, fmf); - Value rhsImagIsZero = rewriter.create<arith::CmpFOp>( - loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero); - Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>( - loc, arith::CmpFPredicate::ORD, lhsRe, zero); - Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>( - loc, arith::CmpFPredicate::ORD, lhsIm, zero); + Value zero = arith::ConstantOp::create(rewriter, loc, elementType, + rewriter.getZeroAttr(elementType)); + Value rhsRealAbs = math::AbsFOp::create(rewriter, loc, rhsRe, fmf); + Value rhsRealIsZero = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero); + Value rhsImagAbs = math::AbsFOp::create(rewriter, loc, rhsIm, fmf); + Value rhsImagIsZero = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero); + Value lhsRealIsNotNaN = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::ORD, lhsRe, zero); + Value lhsImagIsNotNaN = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::ORD, lhsIm, zero); Value lhsContainsNotNaNValue = - rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN); - Value resultIsInfinity = rewriter.create<arith::AndIOp>( - loc, lhsContainsNotNaNValue, - rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero)); - Value inf = rewriter.create<arith::ConstantOp>( - loc, elementType, + arith::OrIOp::create(rewriter, loc, lhsRealIsNotNaN, lhsImagIsNotNaN); + Value resultIsInfinity = arith::AndIOp::create( + rewriter, loc, lhsContainsNotNaNValue, + arith::AndIOp::create(rewriter, loc, rhsRealIsZero, rhsImagIsZero)); + Value inf = arith::ConstantOp::create( + rewriter, loc, elementType, rewriter.getFloatAttr(elementType, APFloat::getInf(elementType.getFloatSemantics()))); Value infWithSignOfRhsReal = - rewriter.create<math::CopySignOp>(loc, inf, rhsRe); + math::CopySignOp::create(rewriter, loc, inf, rhsRe); Value infinityResultReal = - rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsRe, fmf); + arith::MulFOp::create(rewriter, loc, infWithSignOfRhsReal, lhsRe, fmf); Value infinityResultImag = - rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsIm, fmf); + arith::MulFOp::create(rewriter, loc, infWithSignOfRhsReal, lhsIm, fmf); // Case 2. Infinite numerator, finite denominator. - Value rhsRealFinite = rewriter.create<arith::CmpFOp>( - loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf); - Value rhsImagFinite = rewriter.create<arith::CmpFOp>( - loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf); + Value rhsRealFinite = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf); + Value rhsImagFinite = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf); Value rhsFinite = - rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite); - Value lhsRealAbs = rewriter.create<math::AbsFOp>(loc, lhsRe, fmf); - Value lhsRealInfinite = rewriter.create<arith::CmpFOp>( - loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf); - Value lhsImagAbs = rewriter.create<math::AbsFOp>(loc, lhsIm, fmf); - Value lhsImagInfinite = rewriter.create<arith::CmpFOp>( - loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf); + arith::AndIOp::create(rewriter, loc, rhsRealFinite, rhsImagFinite); + Value lhsRealAbs = math::AbsFOp::create(rewriter, loc, lhsRe, fmf); + Value lhsRealInfinite = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf); + Value lhsImagAbs = math::AbsFOp::create(rewriter, loc, lhsIm, fmf); + Value lhsImagInfinite = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf); Value lhsInfinite = - rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite); + arith::OrIOp::create(rewriter, loc, lhsRealInfinite, lhsImagInfinite); Value infNumFiniteDenom = - rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite); - Value one = rewriter.create<arith::ConstantOp>( - loc, elementType, rewriter.getFloatAttr(elementType, 1)); - Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>( - loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero), + arith::AndIOp::create(rewriter, loc, lhsInfinite, rhsFinite); + Value one = arith::ConstantOp::create(rewriter, loc, elementType, + rewriter.getFloatAttr(elementType, 1)); + Value lhsRealIsInfWithSign = math::CopySignOp::create( + rewriter, loc, + arith::SelectOp::create(rewriter, loc, lhsRealInfinite, one, zero), lhsRe); - Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>( - loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero), + Value lhsImagIsInfWithSign = math::CopySignOp::create( + rewriter, loc, + arith::SelectOp::create(rewriter, loc, lhsImagInfinite, one, zero), lhsIm); Value lhsRealIsInfWithSignTimesRhsReal = - rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsRe, fmf); + arith::MulFOp::create(rewriter, loc, lhsRealIsInfWithSign, rhsRe, fmf); Value lhsImagIsInfWithSignTimesRhsImag = - rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsIm, fmf); - Value resultReal3 = rewriter.create<arith::MulFOp>( - loc, inf, - rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal, - lhsImagIsInfWithSignTimesRhsImag, fmf), + arith::MulFOp::create(rewriter, loc, lhsImagIsInfWithSign, rhsIm, fmf); + Value resultReal3 = arith::MulFOp::create( + rewriter, loc, inf, + arith::AddFOp::create(rewriter, loc, lhsRealIsInfWithSignTimesRhsReal, + lhsImagIsInfWithSignTimesRhsImag, fmf), fmf); Value lhsRealIsInfWithSignTimesRhsImag = - rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsIm, fmf); + arith::MulFOp::create(rewriter, loc, lhsRealIsInfWithSign, rhsIm, fmf); Value lhsImagIsInfWithSignTimesRhsReal = - rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsRe, fmf); - Value resultImag3 = rewriter.create<arith::MulFOp>( - loc, inf, - rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal, - lhsRealIsInfWithSignTimesRhsImag, fmf), + arith::MulFOp::create(rewriter, loc, lhsImagIsInfWithSign, rhsRe, fmf); + Value resultImag3 = arith::MulFOp::create( + rewriter, loc, inf, + arith::SubFOp::create(rewriter, loc, lhsImagIsInfWithSignTimesRhsReal, + lhsRealIsInfWithSignTimesRhsImag, fmf), fmf); // Case 3: Finite numerator, infinite denominator. - Value lhsRealFinite = rewriter.create<arith::CmpFOp>( - loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf); - Value lhsImagFinite = rewriter.create<arith::CmpFOp>( - loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf); + Value lhsRealFinite = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf); + Value lhsImagFinite = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf); Value lhsFinite = - rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite); - Value rhsRealInfinite = rewriter.create<arith::CmpFOp>( - loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf); - Value rhsImagInfinite = rewriter.create<arith::CmpFOp>( - loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf); + arith::AndIOp::create(rewriter, loc, lhsRealFinite, lhsImagFinite); + Value rhsRealInfinite = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf); + Value rhsImagInfinite = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf); Value rhsInfinite = - rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite); + arith::OrIOp::create(rewriter, loc, rhsRealInfinite, rhsImagInfinite); Value finiteNumInfiniteDenom = - rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite); - Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>( - loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero), + arith::AndIOp::create(rewriter, loc, lhsFinite, rhsInfinite); + Value rhsRealIsInfWithSign = math::CopySignOp::create( + rewriter, loc, + arith::SelectOp::create(rewriter, loc, rhsRealInfinite, one, zero), rhsRe); - Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>( - loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero), + Value rhsImagIsInfWithSign = math::CopySignOp::create( + rewriter, loc, + arith::SelectOp::create(rewriter, loc, rhsImagInfinite, one, zero), rhsIm); Value rhsRealIsInfWithSignTimesLhsReal = - rewriter.create<arith::MulFOp>(loc, lhsRe, rhsRealIsInfWithSign, fmf); + arith::MulFOp::create(rewriter, loc, lhsRe, rhsRealIsInfWithSign, fmf); Value rhsImagIsInfWithSignTimesLhsImag = - rewriter.create<arith::MulFOp>(loc, lhsIm, rhsImagIsInfWithSign, fmf); - Value resultReal4 = rewriter.create<arith::MulFOp>( - loc, zero, - rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal, - rhsImagIsInfWithSignTimesLhsImag, fmf), + arith::MulFOp::create(rewriter, loc, lhsIm, rhsImagIsInfWithSign, fmf); + Value resultReal4 = arith::MulFOp::create( + rewriter, loc, zero, + arith::AddFOp::create(rewriter, loc, rhsRealIsInfWithSignTimesLhsReal, + rhsImagIsInfWithSignTimesLhsImag, fmf), fmf); Value rhsRealIsInfWithSignTimesLhsImag = - rewriter.create<arith::MulFOp>(loc, lhsIm, rhsRealIsInfWithSign, fmf); + arith::MulFOp::create(rewriter, loc, lhsIm, rhsRealIsInfWithSign, fmf); Value rhsImagIsInfWithSignTimesLhsReal = - rewriter.create<arith::MulFOp>(loc, lhsRe, rhsImagIsInfWithSign, fmf); - Value resultImag4 = rewriter.create<arith::MulFOp>( - loc, zero, - rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag, - rhsImagIsInfWithSignTimesLhsReal, fmf), + arith::MulFOp::create(rewriter, loc, lhsRe, rhsImagIsInfWithSign, fmf); + Value resultImag4 = arith::MulFOp::create( + rewriter, loc, zero, + arith::SubFOp::create(rewriter, loc, rhsRealIsInfWithSignTimesLhsImag, + rhsImagIsInfWithSignTimesLhsReal, fmf), fmf); - Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>( - loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs); - Value resultReal5 = rewriter.create<arith::SelectOp>( - loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2); - Value resultImag5 = rewriter.create<arith::SelectOp>( - loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2); - Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>( - loc, finiteNumInfiniteDenom, resultReal4, resultReal5); - Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>( - loc, finiteNumInfiniteDenom, resultImag4, resultImag5); - Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>( - loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3); - Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>( - loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3); - Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>( - loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2); - Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>( - loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2); + Value realAbsSmallerThanImagAbs = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs); + Value resultReal5 = arith::SelectOp::create( + rewriter, loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2); + Value resultImag5 = arith::SelectOp::create( + rewriter, loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2); + Value resultRealSpecialCase3 = arith::SelectOp::create( + rewriter, loc, finiteNumInfiniteDenom, resultReal4, resultReal5); + Value resultImagSpecialCase3 = arith::SelectOp::create( + rewriter, loc, finiteNumInfiniteDenom, resultImag4, resultImag5); + Value resultRealSpecialCase2 = arith::SelectOp::create( + rewriter, loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3); + Value resultImagSpecialCase2 = arith::SelectOp::create( + rewriter, loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3); + Value resultRealSpecialCase1 = + arith::SelectOp::create(rewriter, loc, resultIsInfinity, + infinityResultReal, resultRealSpecialCase2); + Value resultImagSpecialCase1 = + arith::SelectOp::create(rewriter, loc, resultIsInfinity, + infinityResultImag, resultImagSpecialCase2); - Value resultRealIsNaN = rewriter.create<arith::CmpFOp>( - loc, arith::CmpFPredicate::UNO, resultReal5, zero); - Value resultImagIsNaN = rewriter.create<arith::CmpFOp>( - loc, arith::CmpFPredicate::UNO, resultImag5, zero); + Value resultRealIsNaN = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::UNO, resultReal5, zero); + Value resultImagIsNaN = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::UNO, resultImag5, zero); Value resultIsNaN = - rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN); + arith::AndIOp::create(rewriter, loc, resultRealIsNaN, resultImagIsNaN); - *resultRe = rewriter.create<arith::SelectOp>( - loc, resultIsNaN, resultRealSpecialCase1, resultReal5); - *resultIm = rewriter.create<arith::SelectOp>( - loc, resultIsNaN, resultImagSpecialCase1, resultImag5); + *resultRe = arith::SelectOp::create(rewriter, loc, resultIsNaN, + resultRealSpecialCase1, resultReal5); + *resultIm = arith::SelectOp::create(rewriter, loc, resultIsNaN, + resultImagSpecialCase1, resultImag5); } diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp index e5e8623..86d02e6 100644 --- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp +++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp @@ -35,7 +35,7 @@ static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1; ComplexStructBuilder ComplexStructBuilder::poison(OpBuilder &builder, Location loc, Type type) { - Value val = builder.create<LLVM::PoisonOp>(loc, type); + Value val = LLVM::PoisonOp::create(builder, loc, type); return ComplexStructBuilder(val); } @@ -79,9 +79,9 @@ struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> { LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( op.getContext(), convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); - Value sqNorm = rewriter.create<LLVM::FAddOp>( - loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf), - rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf); + Value sqNorm = LLVM::FAddOp::create( + rewriter, loc, LLVM::FMulOp::create(rewriter, loc, real, real, fmf), + LLVM::FMulOp::create(rewriter, loc, imag, imag, fmf), fmf); rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm); return success(); @@ -191,10 +191,10 @@ struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> { LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( op.getContext(), convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); - Value real = - rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf); - Value imag = - rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); + Value real = LLVM::FAddOp::create(rewriter, loc, arg.lhs.real(), + arg.rhs.real(), fmf); + Value imag = LLVM::FAddOp::create(rewriter, loc, arg.lhs.imag(), + arg.rhs.imag(), fmf); result.setReal(rewriter, loc, real); result.setImaginary(rewriter, loc, imag); @@ -278,13 +278,13 @@ struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> { Value lhsRe = arg.lhs.real(); Value lhsIm = arg.lhs.imag(); - Value real = rewriter.create<LLVM::FSubOp>( - loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf), - rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf); + Value real = LLVM::FSubOp::create( + rewriter, loc, LLVM::FMulOp::create(rewriter, loc, rhsRe, lhsRe, fmf), + LLVM::FMulOp::create(rewriter, loc, rhsIm, lhsIm, fmf), fmf); - Value imag = rewriter.create<LLVM::FAddOp>( - loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf), - rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf); + Value imag = LLVM::FAddOp::create( + rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf), + LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf); result.setReal(rewriter, loc, real); result.setImaginary(rewriter, loc, imag); @@ -313,10 +313,10 @@ struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> { LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( op.getContext(), convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); - Value real = - rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf); - Value imag = - rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); + Value real = LLVM::FSubOp::create(rewriter, loc, arg.lhs.real(), + arg.rhs.real(), fmf); + Value imag = LLVM::FSubOp::create(rewriter, loc, arg.lhs.imag(), + arg.rhs.imag(), fmf); result.setReal(rewriter, loc, real); result.setImaginary(rewriter, loc, imag); diff --git a/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp b/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp index 56269d1..f83cac7 100644 --- a/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp +++ b/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp @@ -84,8 +84,8 @@ LogicalResult ScalarOpToLibmCall<Op, TypeResolver>::matchAndRewrite( rewriter.setInsertionPointToStart(&module->getRegion(0).front()); auto opFunctionTy = FunctionType::get( rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); - opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name, - opFunctionTy); + opFunc = func::FuncOp::create(rewriter, rewriter.getUnknownLoc(), name, + opFunctionTy); opFunc.setPrivate(); } assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name))); diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp index 99d5424..35ad99c 100644 --- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp +++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp @@ -44,8 +44,8 @@ struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern<Op> { rewriter.setInsertionPointToStart(&symTable->getRegion(0).front()); auto funcTy = FunctionType::get( rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); - opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), funcName, - funcTy); + opFunc = func::FuncOp::create(rewriter, rewriter.getUnknownLoc(), + funcName, funcTy); opFunc.setPrivate(); } rewriter.replaceOpWithNewOp<func::CallOp>(op, funcName, op.getType(), @@ -64,10 +64,46 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns( patterns.getContext(), "__ocml_cabs_f32"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>( patterns.getContext(), "__ocml_cabs_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float32Type>>( + patterns.getContext(), "__ocml_carg_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float64Type>>( + patterns.getContext(), "__ocml_carg_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float32Type>>( + patterns.getContext(), "__ocml_conj_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float64Type>>( + patterns.getContext(), "__ocml_conj_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float32Type>>( + patterns.getContext(), "__ocml_ccos_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float64Type>>( + patterns.getContext(), "__ocml_ccos_f64"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float32Type>>( patterns.getContext(), "__ocml_cexp_f32"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float64Type>>( patterns.getContext(), "__ocml_cexp_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float32Type>>( + patterns.getContext(), "__ocml_clog_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float64Type>>( + patterns.getContext(), "__ocml_clog_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float32Type>>( + patterns.getContext(), "__ocml_cpow_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float64Type>>( + patterns.getContext(), "__ocml_cpow_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float32Type>>( + patterns.getContext(), "__ocml_csin_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float64Type>>( + patterns.getContext(), "__ocml_csin_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float32Type>>( + patterns.getContext(), "__ocml_csqrt_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float64Type>>( + patterns.getContext(), "__ocml_csqrt_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float32Type>>( + patterns.getContext(), "__ocml_ctan_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float64Type>>( + patterns.getContext(), "__ocml_ctan_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float32Type>>( + patterns.getContext(), "__ocml_ctanh_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float64Type>>( + patterns.getContext(), "__ocml_ctanh_f64"); } namespace { @@ -86,7 +122,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() { ConversionTarget target(getContext()); target.addLegalDialect<func::FuncDialect>(); - target.addIllegalOp<complex::AbsOp, complex::ExpOp>(); + target.addIllegalOp<complex::AbsOp, complex::AngleOp, complex::ConjOp, + complex::CosOp, complex::ExpOp, complex::LogOp, + complex::PowOp, complex::SinOp, complex::SqrtOp, + complex::TanOp, complex::TanhOp>(); if (failed(applyPartialConversion(op, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index 0c832c4..5ad514d 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -12,7 +12,6 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" #include <type_traits> @@ -31,44 +30,45 @@ enum class AbsFn { abs, sqrt, rsqrt }; // Returns the absolute value, its square root or its reciprocal square root. Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf, ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) { - Value one = b.create<arith::ConstantOp>(real.getType(), - b.getFloatAttr(real.getType(), 1.0)); + Value one = arith::ConstantOp::create(b, real.getType(), + b.getFloatAttr(real.getType(), 1.0)); - Value absReal = b.create<math::AbsFOp>(real, fmf); - Value absImag = b.create<math::AbsFOp>(imag, fmf); + Value absReal = math::AbsFOp::create(b, real, fmf); + Value absImag = math::AbsFOp::create(b, imag, fmf); - Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf); - Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf); + Value max = arith::MaximumFOp::create(b, absReal, absImag, fmf); + Value min = arith::MinimumFOp::create(b, absReal, absImag, fmf); // The lowering below requires NaNs and infinities to work correctly. arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear( fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf); - Value ratio = b.create<arith::DivFOp>(min, max, fmfWithNaNInf); - Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmfWithNaNInf); - Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmfWithNaNInf); + Value ratio = arith::DivFOp::create(b, min, max, fmfWithNaNInf); + Value ratioSq = arith::MulFOp::create(b, ratio, ratio, fmfWithNaNInf); + Value ratioSqPlusOne = arith::AddFOp::create(b, ratioSq, one, fmfWithNaNInf); Value result; if (fn == AbsFn::rsqrt) { - ratioSqPlusOne = b.create<math::RsqrtOp>(ratioSqPlusOne, fmfWithNaNInf); - min = b.create<math::RsqrtOp>(min, fmfWithNaNInf); - max = b.create<math::RsqrtOp>(max, fmfWithNaNInf); + ratioSqPlusOne = math::RsqrtOp::create(b, ratioSqPlusOne, fmfWithNaNInf); + min = math::RsqrtOp::create(b, min, fmfWithNaNInf); + max = math::RsqrtOp::create(b, max, fmfWithNaNInf); } if (fn == AbsFn::sqrt) { - Value quarter = b.create<arith::ConstantOp>( - real.getType(), b.getFloatAttr(real.getType(), 0.25)); + Value quarter = arith::ConstantOp::create( + b, real.getType(), b.getFloatAttr(real.getType(), 0.25)); // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily. - Value sqrt = b.create<math::SqrtOp>(max, fmfWithNaNInf); - Value p025 = b.create<math::PowFOp>(ratioSqPlusOne, quarter, fmfWithNaNInf); - result = b.create<arith::MulFOp>(sqrt, p025, fmfWithNaNInf); + Value sqrt = math::SqrtOp::create(b, max, fmfWithNaNInf); + Value p025 = + math::PowFOp::create(b, ratioSqPlusOne, quarter, fmfWithNaNInf); + result = arith::MulFOp::create(b, sqrt, p025, fmfWithNaNInf); } else { - Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmfWithNaNInf); - result = b.create<arith::MulFOp>(max, sqrt, fmfWithNaNInf); + Value sqrt = math::SqrtOp::create(b, ratioSqPlusOne, fmfWithNaNInf); + result = arith::MulFOp::create(b, max, sqrt, fmfWithNaNInf); } - Value isNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result, - result, fmfWithNaNInf); - return b.create<arith::SelectOp>(isNaN, min, result); + Value isNaN = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, result, + result, fmfWithNaNInf); + return arith::SelectOp::create(b, isNaN, min, result); } struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> { @@ -81,8 +81,8 @@ struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> { arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); - Value real = b.create<complex::ReOp>(adaptor.getComplex()); - Value imag = b.create<complex::ImOp>(adaptor.getComplex()); + Value real = complex::ReOp::create(b, adaptor.getComplex()); + Value imag = complex::ImOp::create(b, adaptor.getComplex()); rewriter.replaceOp(op, computeAbs(real, imag, fmf, b)); return success(); @@ -105,28 +105,28 @@ struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> { Value lhs = adaptor.getLhs(); Value rhs = adaptor.getRhs(); - Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs, fmf); - Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs, fmf); + Value rhsSquared = complex::MulOp::create(b, type, rhs, rhs, fmf); + Value lhsSquared = complex::MulOp::create(b, type, lhs, lhs, fmf); Value rhsSquaredPlusLhsSquared = - b.create<complex::AddOp>(type, rhsSquared, lhsSquared, fmf); + complex::AddOp::create(b, type, rhsSquared, lhsSquared, fmf); Value sqrtOfRhsSquaredPlusLhsSquared = - b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared, fmf); + complex::SqrtOp::create(b, type, rhsSquaredPlusLhsSquared, fmf); Value zero = - b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType)); - Value one = b.create<arith::ConstantOp>(elementType, - b.getFloatAttr(elementType, 1)); - Value i = b.create<complex::CreateOp>(type, zero, one); - Value iTimesLhs = b.create<complex::MulOp>(i, lhs, fmf); - Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs, fmf); + arith::ConstantOp::create(b, elementType, b.getZeroAttr(elementType)); + Value one = arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, 1)); + Value i = complex::CreateOp::create(b, type, zero, one); + Value iTimesLhs = complex::MulOp::create(b, i, lhs, fmf); + Value rhsPlusILhs = complex::AddOp::create(b, rhs, iTimesLhs, fmf); - Value divResult = b.create<complex::DivOp>( - rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf); - Value logResult = b.create<complex::LogOp>(divResult, fmf); + Value divResult = complex::DivOp::create( + b, rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf); + Value logResult = complex::LogOp::create(b, divResult, fmf); - Value negativeOne = b.create<arith::ConstantOp>( - elementType, b.getFloatAttr(elementType, -1)); - Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne); + Value negativeOne = arith::ConstantOp::create( + b, elementType, b.getFloatAttr(elementType, -1)); + Value negativeI = complex::CreateOp::create(b, type, zero, negativeOne); rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult, fmf); return success(); @@ -146,14 +146,18 @@ struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> { auto loc = op.getLoc(); auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType(); - Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs()); - Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs()); - Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs()); - Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs()); + Value realLhs = + complex::ReOp::create(rewriter, loc, type, adaptor.getLhs()); + Value imagLhs = + complex::ImOp::create(rewriter, loc, type, adaptor.getLhs()); + Value realRhs = + complex::ReOp::create(rewriter, loc, type, adaptor.getRhs()); + Value imagRhs = + complex::ImOp::create(rewriter, loc, type, adaptor.getRhs()); Value realComparison = - rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs); + arith::CmpFOp::create(rewriter, loc, p, realLhs, realRhs); Value imagComparison = - rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs); + arith::CmpFOp::create(rewriter, loc, p, imagLhs, imagRhs); rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison, imagComparison); @@ -176,14 +180,14 @@ struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> { mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); - Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs()); - Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs()); - Value resultReal = b.create<BinaryStandardOp>(elementType, realLhs, realRhs, - fmf.getValue()); - Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs()); - Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs()); - Value resultImag = b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs, - fmf.getValue()); + Value realLhs = complex::ReOp::create(b, elementType, adaptor.getLhs()); + Value realRhs = complex::ReOp::create(b, elementType, adaptor.getRhs()); + Value resultReal = BinaryStandardOp::create(b, elementType, realLhs, + realRhs, fmf.getValue()); + Value imagLhs = complex::ImOp::create(b, elementType, adaptor.getLhs()); + Value imagRhs = complex::ImOp::create(b, elementType, adaptor.getRhs()); + Value resultImag = BinaryStandardOp::create(b, elementType, imagLhs, + imagRhs, fmf.getValue()); rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, resultImag); return success(); @@ -205,20 +209,20 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> { arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); Value real = - rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); + complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex()); Value imag = - rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); + complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex()); // Trigonometric ops use a set of common building blocks to convert to real // ops. Here we create these building blocks and call into an op-specific // implementation in the subclass to combine them. - Value half = rewriter.create<arith::ConstantOp>( - loc, elementType, rewriter.getFloatAttr(elementType, 0.5)); - Value exp = rewriter.create<math::ExpOp>(loc, imag, fmf); - Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp, fmf); - Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp, fmf); - Value sin = rewriter.create<math::SinOp>(loc, real, fmf); - Value cos = rewriter.create<math::CosOp>(loc, real, fmf); + Value half = arith::ConstantOp::create( + rewriter, loc, elementType, rewriter.getFloatAttr(elementType, 0.5)); + Value exp = math::ExpOp::create(rewriter, loc, imag, fmf); + Value scaledExp = arith::MulFOp::create(rewriter, loc, half, exp, fmf); + Value reciprocalExp = arith::DivFOp::create(rewriter, loc, half, exp, fmf); + Value sin = math::SinOp::create(rewriter, loc, real, fmf); + Value cos = math::CosOp::create(rewriter, loc, real, fmf); auto resultPair = combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf); @@ -251,11 +255,11 @@ struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> { // Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x // Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x Value sum = - rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp, fmf); - Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos, fmf); + arith::AddFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf); + Value resultReal = arith::MulFOp::create(rewriter, loc, sum, cos, fmf); Value diff = - rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp, fmf); - Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin, fmf); + arith::SubFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf); + Value resultImag = arith::MulFOp::create(rewriter, loc, diff, sin, fmf); return {resultReal, resultImag}; } }; @@ -275,13 +279,13 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> { arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); Value lhsReal = - rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs()); + complex::ReOp::create(rewriter, loc, elementType, adaptor.getLhs()); Value lhsImag = - rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs()); + complex::ImOp::create(rewriter, loc, elementType, adaptor.getLhs()); Value rhsReal = - rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs()); + complex::ReOp::create(rewriter, loc, elementType, adaptor.getRhs()); Value rhsImag = - rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs()); + complex::ImOp::create(rewriter, loc, elementType, adaptor.getRhs()); Value resultReal, resultImag; @@ -318,16 +322,16 @@ struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> { arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); Value real = - rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); + complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex()); Value imag = - rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); - Value expReal = rewriter.create<math::ExpOp>(loc, real, fmf.getValue()); - Value cosImag = rewriter.create<math::CosOp>(loc, imag, fmf.getValue()); + complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex()); + Value expReal = math::ExpOp::create(rewriter, loc, real, fmf.getValue()); + Value cosImag = math::CosOp::create(rewriter, loc, imag, fmf.getValue()); Value resultReal = - rewriter.create<arith::MulFOp>(loc, expReal, cosImag, fmf.getValue()); - Value sinImag = rewriter.create<math::SinOp>(loc, imag, fmf.getValue()); + arith::MulFOp::create(rewriter, loc, expReal, cosImag, fmf.getValue()); + Value sinImag = math::SinOp::create(rewriter, loc, imag, fmf.getValue()); Value resultImag = - rewriter.create<arith::MulFOp>(loc, expReal, sinImag, fmf.getValue()); + arith::MulFOp::create(rewriter, loc, expReal, sinImag, fmf.getValue()); rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, resultImag); @@ -340,11 +344,11 @@ Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg, arith::FastMathFlagsAttr fmf) { auto argType = mlir::cast<FloatType>(arg.getType()); Value poly = - b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[0])); + arith::ConstantOp::create(b, b.getFloatAttr(argType, coefficients[0])); for (unsigned i = 1; i < coefficients.size(); ++i) { - poly = b.create<math::FmaOp>( - poly, arg, - b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[i])), + poly = math::FmaOp::create( + b, poly, arg, + arith::ConstantOp::create(b, b.getFloatAttr(argType, coefficients[i])), fmf); } return poly; @@ -365,26 +369,26 @@ struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> { arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value real = b.create<complex::ReOp>(adaptor.getComplex()); - Value imag = b.create<complex::ImOp>(adaptor.getComplex()); + Value real = complex::ReOp::create(b, adaptor.getComplex()); + Value imag = complex::ImOp::create(b, adaptor.getComplex()); - Value zero = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 0.0)); - Value one = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 1.0)); + Value zero = arith::ConstantOp::create(b, b.getFloatAttr(elemType, 0.0)); + Value one = arith::ConstantOp::create(b, b.getFloatAttr(elemType, 1.0)); - Value expm1Real = b.create<math::ExpM1Op>(real, fmf); - Value expReal = b.create<arith::AddFOp>(expm1Real, one, fmf); + Value expm1Real = math::ExpM1Op::create(b, real, fmf); + Value expReal = arith::AddFOp::create(b, expm1Real, one, fmf); - Value sinImag = b.create<math::SinOp>(imag, fmf); + Value sinImag = math::SinOp::create(b, imag, fmf); Value cosm1Imag = emitCosm1(imag, fmf, b); - Value cosImag = b.create<arith::AddFOp>(cosm1Imag, one, fmf); + Value cosImag = arith::AddFOp::create(b, cosm1Imag, one, fmf); - Value realResult = b.create<arith::AddFOp>( - b.create<arith::MulFOp>(expm1Real, cosImag, fmf), cosm1Imag, fmf); + Value realResult = arith::AddFOp::create( + b, arith::MulFOp::create(b, expm1Real, cosImag, fmf), cosm1Imag, fmf); - Value imagIsZero = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, - zero, fmf.getValue()); - Value imagResult = b.create<arith::SelectOp>( - imagIsZero, zero, b.create<arith::MulFOp>(expReal, sinImag, fmf)); + Value imagIsZero = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, + zero, fmf.getValue()); + Value imagResult = arith::SelectOp::create( + b, imagIsZero, zero, arith::MulFOp::create(b, expReal, sinImag, fmf)); rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realResult, imagResult); @@ -395,8 +399,8 @@ private: Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf, ImplicitLocOpBuilder &b) const { auto argType = mlir::cast<FloatType>(arg.getType()); - auto negHalf = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -0.5)); - auto negOne = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -1.0)); + auto negHalf = arith::ConstantOp::create(b, b.getFloatAttr(argType, -0.5)); + auto negOne = arith::ConstantOp::create(b, b.getFloatAttr(argType, -1.0)); // Algorithm copied from cephes cosm1. SmallVector<double, 7> kCoeffs{ @@ -405,23 +409,23 @@ private: 2.4801587301570552304991E-5, -1.3888888888888872993737E-3, 4.1666666666666666609054E-2, }; - Value cos = b.create<math::CosOp>(arg, fmf); - Value forLargeArg = b.create<arith::AddFOp>(cos, negOne, fmf); + Value cos = math::CosOp::create(b, arg, fmf); + Value forLargeArg = arith::AddFOp::create(b, cos, negOne, fmf); - Value argPow2 = b.create<arith::MulFOp>(arg, arg, fmf); - Value argPow4 = b.create<arith::MulFOp>(argPow2, argPow2, fmf); + Value argPow2 = arith::MulFOp::create(b, arg, arg, fmf); + Value argPow4 = arith::MulFOp::create(b, argPow2, argPow2, fmf); Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf); auto forSmallArg = - b.create<arith::AddFOp>(b.create<arith::MulFOp>(argPow4, poly, fmf), - b.create<arith::MulFOp>(negHalf, argPow2, fmf)); + arith::AddFOp::create(b, arith::MulFOp::create(b, argPow4, poly, fmf), + arith::MulFOp::create(b, negHalf, argPow2, fmf)); // (pi/4)^2 is approximately 0.61685 Value piOver4Pow2 = - b.create<arith::ConstantOp>(b.getFloatAttr(argType, 0.61685)); - Value cond = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, argPow2, - piOver4Pow2, fmf.getValue()); - return b.create<arith::SelectOp>(cond, forLargeArg, forSmallArg); + arith::ConstantOp::create(b, b.getFloatAttr(argType, 0.61685)); + Value cond = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, argPow2, + piOver4Pow2, fmf.getValue()); + return arith::SelectOp::create(b, cond, forLargeArg, forSmallArg); } }; @@ -436,13 +440,13 @@ struct LogOpConversion : public OpConversionPattern<complex::LogOp> { arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(), - fmf.getValue()); - Value resultReal = b.create<math::LogOp>(elementType, abs, fmf.getValue()); - Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); - Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); + Value abs = complex::AbsOp::create(b, elementType, adaptor.getComplex(), + fmf.getValue()); + Value resultReal = math::LogOp::create(b, elementType, abs, fmf.getValue()); + Value real = complex::ReOp::create(b, elementType, adaptor.getComplex()); + Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex()); Value resultImag = - b.create<math::Atan2Op>(elementType, imag, real, fmf.getValue()); + math::Atan2Op::create(b, elementType, imag, real, fmf.getValue()); rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, resultImag); return success(); @@ -460,40 +464,42 @@ struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> { arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value real = b.create<complex::ReOp>(adaptor.getComplex()); - Value imag = b.create<complex::ImOp>(adaptor.getComplex()); + Value real = complex::ReOp::create(b, adaptor.getComplex()); + Value imag = complex::ImOp::create(b, adaptor.getComplex()); - Value half = b.create<arith::ConstantOp>(elementType, - b.getFloatAttr(elementType, 0.5)); - Value one = b.create<arith::ConstantOp>(elementType, - b.getFloatAttr(elementType, 1)); - Value realPlusOne = b.create<arith::AddFOp>(real, one, fmf); - Value absRealPlusOne = b.create<math::AbsFOp>(realPlusOne, fmf); - Value absImag = b.create<math::AbsFOp>(imag, fmf); + Value half = arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, 0.5)); + Value one = arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, 1)); + Value realPlusOne = arith::AddFOp::create(b, real, one, fmf); + Value absRealPlusOne = math::AbsFOp::create(b, realPlusOne, fmf); + Value absImag = math::AbsFOp::create(b, imag, fmf); - Value maxAbs = b.create<arith::MaximumFOp>(absRealPlusOne, absImag, fmf); - Value minAbs = b.create<arith::MinimumFOp>(absRealPlusOne, absImag, fmf); + Value maxAbs = arith::MaximumFOp::create(b, absRealPlusOne, absImag, fmf); + Value minAbs = arith::MinimumFOp::create(b, absRealPlusOne, absImag, fmf); - Value useReal = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, - realPlusOne, absImag, fmf); - Value maxMinusOne = b.create<arith::SubFOp>(maxAbs, one, fmf); + Value useReal = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT, + realPlusOne, absImag, fmf); + Value maxMinusOne = arith::SubFOp::create(b, maxAbs, one, fmf); Value maxAbsOfRealPlusOneAndImagMinusOne = - b.create<arith::SelectOp>(useReal, real, maxMinusOne); + arith::SelectOp::create(b, useReal, real, maxMinusOne); arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear( fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf); - Value minMaxRatio = b.create<arith::DivFOp>(minAbs, maxAbs, fmfWithNaNInf); + Value minMaxRatio = arith::DivFOp::create(b, minAbs, maxAbs, fmfWithNaNInf); Value logOfMaxAbsOfRealPlusOneAndImag = - b.create<math::Log1pOp>(maxAbsOfRealPlusOneAndImagMinusOne, fmf); - Value logOfSqrtPart = b.create<math::Log1pOp>( - b.create<arith::MulFOp>(minMaxRatio, minMaxRatio, fmfWithNaNInf), + math::Log1pOp::create(b, maxAbsOfRealPlusOneAndImagMinusOne, fmf); + Value logOfSqrtPart = math::Log1pOp::create( + b, arith::MulFOp::create(b, minMaxRatio, minMaxRatio, fmfWithNaNInf), fmfWithNaNInf); - Value r = b.create<arith::AddFOp>( - b.create<arith::MulFOp>(half, logOfSqrtPart, fmfWithNaNInf), + Value r = arith::AddFOp::create( + b, arith::MulFOp::create(b, half, logOfSqrtPart, fmfWithNaNInf), logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf); - Value resultReal = b.create<arith::SelectOp>( - b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, r, r, fmfWithNaNInf), + Value resultReal = arith::SelectOp::create( + b, + arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, r, r, + fmfWithNaNInf), minAbs, r); - Value resultImag = b.create<math::Atan2Op>(imag, realPlusOne, fmf); + Value resultImag = math::Atan2Op::create(b, imag, realPlusOne, fmf); rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, resultImag); return success(); @@ -511,22 +517,22 @@ struct MulOpConversion : public OpConversionPattern<complex::MulOp> { auto elementType = cast<FloatType>(type.getElementType()); arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); auto fmfValue = fmf.getValue(); - Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs()); - Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs()); - Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs()); - Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs()); + Value lhsReal = complex::ReOp::create(b, elementType, adaptor.getLhs()); + Value lhsImag = complex::ImOp::create(b, elementType, adaptor.getLhs()); + Value rhsReal = complex::ReOp::create(b, elementType, adaptor.getRhs()); + Value rhsImag = complex::ImOp::create(b, elementType, adaptor.getRhs()); Value lhsRealTimesRhsReal = - b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue); + arith::MulFOp::create(b, lhsReal, rhsReal, fmfValue); Value lhsImagTimesRhsImag = - b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue); - Value real = b.create<arith::SubFOp>(lhsRealTimesRhsReal, - lhsImagTimesRhsImag, fmfValue); + arith::MulFOp::create(b, lhsImag, rhsImag, fmfValue); + Value real = arith::SubFOp::create(b, lhsRealTimesRhsReal, + lhsImagTimesRhsImag, fmfValue); Value lhsImagTimesRhsReal = - b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue); + arith::MulFOp::create(b, lhsImag, rhsReal, fmfValue); Value lhsRealTimesRhsImag = - b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue); - Value imag = b.create<arith::AddFOp>(lhsImagTimesRhsReal, - lhsRealTimesRhsImag, fmfValue); + arith::MulFOp::create(b, lhsReal, rhsImag, fmfValue); + Value imag = arith::AddFOp::create(b, lhsImagTimesRhsReal, + lhsRealTimesRhsImag, fmfValue); rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag); return success(); } @@ -543,11 +549,11 @@ struct NegOpConversion : public OpConversionPattern<complex::NegOp> { auto elementType = cast<FloatType>(type.getElementType()); Value real = - rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); + complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex()); Value imag = - rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); - Value negReal = rewriter.create<arith::NegFOp>(loc, real); - Value negImag = rewriter.create<arith::NegFOp>(loc, imag); + complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex()); + Value negReal = arith::NegFOp::create(rewriter, loc, real); + Value negImag = arith::NegFOp::create(rewriter, loc, imag); rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag); return success(); } @@ -570,11 +576,11 @@ struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> { // Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x // Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x Value sum = - rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp, fmf); - Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin, fmf); + arith::AddFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf); + Value resultReal = arith::MulFOp::create(rewriter, loc, sum, sin, fmf); Value diff = - rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp, fmf); - Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos, fmf); + arith::SubFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf); + Value resultImag = arith::MulFOp::create(rewriter, loc, diff, cos, fmf); return {resultReal, resultImag}; } }; @@ -593,64 +599,65 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> { arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); auto cst = [&](APFloat v) { - return b.create<arith::ConstantOp>(elementType, - b.getFloatAttr(elementType, v)); + return arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, v)); }; const auto &floatSemantics = elementType.getFloatSemantics(); Value zero = cst(APFloat::getZero(floatSemantics)); - Value half = b.create<arith::ConstantOp>(elementType, - b.getFloatAttr(elementType, 0.5)); + Value half = arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, 0.5)); - Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); - Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); + Value real = complex::ReOp::create(b, elementType, adaptor.getComplex()); + Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex()); Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt); - Value argArg = b.create<math::Atan2Op>(imag, real, fmf); - Value sqrtArg = b.create<arith::MulFOp>(argArg, half, fmf); - Value cos = b.create<math::CosOp>(sqrtArg, fmf); - Value sin = b.create<math::SinOp>(sqrtArg, fmf); + Value argArg = math::Atan2Op::create(b, imag, real, fmf); + Value sqrtArg = arith::MulFOp::create(b, argArg, half, fmf); + Value cos = math::CosOp::create(b, sqrtArg, fmf); + Value sin = math::SinOp::create(b, sqrtArg, fmf); // sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply // 0 * inf. Value sinIsZero = - b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, sin, zero, fmf); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, sin, zero, fmf); - Value resultReal = b.create<arith::MulFOp>(absSqrt, cos, fmf); - Value resultImag = b.create<arith::SelectOp>( - sinIsZero, zero, b.create<arith::MulFOp>(absSqrt, sin, fmf)); + Value resultReal = arith::MulFOp::create(b, absSqrt, cos, fmf); + Value resultImag = arith::SelectOp::create( + b, sinIsZero, zero, arith::MulFOp::create(b, absSqrt, sin, fmf)); if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf)) { Value inf = cst(APFloat::getInf(floatSemantics)); Value negInf = cst(APFloat::getInf(floatSemantics, true)); Value nan = cst(APFloat::getNaN(floatSemantics)); - Value absImag = b.create<math::AbsFOp>(elementType, imag, fmf); + Value absImag = math::AbsFOp::create(b, elementType, imag, fmf); - Value absImagIsInf = - b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf); - Value absImagIsNotInf = - b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, absImag, inf, fmf); + Value absImagIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, + absImag, inf, fmf); + Value absImagIsNotInf = arith::CmpFOp::create( + b, arith::CmpFPredicate::ONE, absImag, inf, fmf); Value realIsInf = - b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, inf, fmf); - Value realIsNegInf = - b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, negInf, fmf); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, inf, fmf); + Value realIsNegInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, + real, negInf, fmf); - resultReal = b.create<arith::SelectOp>( - b.create<arith::AndIOp>(realIsNegInf, absImagIsNotInf), zero, + resultReal = arith::SelectOp::create( + b, arith::AndIOp::create(b, realIsNegInf, absImagIsNotInf), zero, resultReal); - resultReal = b.create<arith::SelectOp>( - b.create<arith::OrIOp>(absImagIsInf, realIsInf), inf, resultReal); + resultReal = arith::SelectOp::create( + b, arith::OrIOp::create(b, absImagIsInf, realIsInf), inf, resultReal); - Value imagSignInf = b.create<math::CopySignOp>(inf, imag, fmf); - resultImag = b.create<arith::SelectOp>( - b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, absSqrt, absSqrt), + Value imagSignInf = math::CopySignOp::create(b, inf, imag, fmf); + resultImag = arith::SelectOp::create( + b, + arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, absSqrt, absSqrt), nan, resultImag); - resultImag = b.create<arith::SelectOp>( - b.create<arith::OrIOp>(absImagIsInf, realIsNegInf), imagSignInf, + resultImag = arith::SelectOp::create( + b, arith::OrIOp::create(b, absImagIsInf, realIsNegInf), imagSignInf, resultImag); } Value resultIsZero = - b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absSqrt, zero, fmf); - resultReal = b.create<arith::SelectOp>(resultIsZero, zero, resultReal); - resultImag = b.create<arith::SelectOp>(resultIsZero, zero, resultImag); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, absSqrt, zero, fmf); + resultReal = arith::SelectOp::create(b, resultIsZero, zero, resultReal); + resultImag = arith::SelectOp::create(b, resultIsZero, zero, resultImag); rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, resultImag); @@ -669,19 +676,20 @@ struct SignOpConversion : public OpConversionPattern<complex::SignOp> { mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); - Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); - Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); + Value real = complex::ReOp::create(b, elementType, adaptor.getComplex()); + Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex()); Value zero = - b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType)); + arith::ConstantOp::create(b, elementType, b.getZeroAttr(elementType)); Value realIsZero = - b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, zero); Value imagIsZero = - b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero); - Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero); - auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(), fmf); - Value realSign = b.create<arith::DivFOp>(real, abs, fmf); - Value imagSign = b.create<arith::DivFOp>(imag, abs, fmf); - Value sign = b.create<complex::CreateOp>(type, realSign, imagSign); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero); + Value isZero = arith::AndIOp::create(b, realIsZero, imagIsZero); + auto abs = + complex::AbsOp::create(b, elementType, adaptor.getComplex(), fmf); + Value realSign = arith::DivFOp::create(b, real, abs, fmf); + Value imagSign = arith::DivFOp::create(b, imag, abs, fmf); + Value sign = complex::CreateOp::create(b, type, realSign, imagSign); rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero, adaptor.getComplex(), sign); return success(); @@ -703,84 +711,84 @@ struct TanTanhOpConversion : public OpConversionPattern<Op> { const auto &floatSemantics = elementType.getFloatSemantics(); Value real = - b.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); + complex::ReOp::create(b, loc, elementType, adaptor.getComplex()); Value imag = - b.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); - Value negOne = b.create<arith::ConstantOp>( - elementType, b.getFloatAttr(elementType, -1.0)); + complex::ImOp::create(b, loc, elementType, adaptor.getComplex()); + Value negOne = arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, -1.0)); if constexpr (std::is_same_v<Op, complex::TanOp>) { // tan(x+yi) = -i*tanh(-y + xi) std::swap(real, imag); - real = b.create<arith::MulFOp>(real, negOne, fmf); + real = arith::MulFOp::create(b, real, negOne, fmf); } auto cst = [&](APFloat v) { - return b.create<arith::ConstantOp>(elementType, - b.getFloatAttr(elementType, v)); + return arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, v)); }; Value inf = cst(APFloat::getInf(floatSemantics)); - Value four = b.create<arith::ConstantOp>(elementType, - b.getFloatAttr(elementType, 4.0)); - Value twoReal = b.create<arith::AddFOp>(real, real, fmf); - Value negTwoReal = b.create<arith::MulFOp>(negOne, twoReal, fmf); - - Value expTwoRealMinusOne = b.create<math::ExpM1Op>(twoReal, fmf); - Value expNegTwoRealMinusOne = b.create<math::ExpM1Op>(negTwoReal, fmf); - Value realNum = - b.create<arith::SubFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf); - - Value cosImag = b.create<math::CosOp>(imag, fmf); - Value cosImagSq = b.create<arith::MulFOp>(cosImag, cosImag, fmf); - Value twoCosTwoImagPlusOne = b.create<arith::MulFOp>(cosImagSq, four, fmf); - Value sinImag = b.create<math::SinOp>(imag, fmf); - - Value imagNum = b.create<arith::MulFOp>( - four, b.create<arith::MulFOp>(cosImag, sinImag, fmf), fmf); - - Value expSumMinusTwo = - b.create<arith::AddFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf); + Value four = arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, 4.0)); + Value twoReal = arith::AddFOp::create(b, real, real, fmf); + Value negTwoReal = arith::MulFOp::create(b, negOne, twoReal, fmf); + + Value expTwoRealMinusOne = math::ExpM1Op::create(b, twoReal, fmf); + Value expNegTwoRealMinusOne = math::ExpM1Op::create(b, negTwoReal, fmf); + Value realNum = arith::SubFOp::create(b, expTwoRealMinusOne, + expNegTwoRealMinusOne, fmf); + + Value cosImag = math::CosOp::create(b, imag, fmf); + Value cosImagSq = arith::MulFOp::create(b, cosImag, cosImag, fmf); + Value twoCosTwoImagPlusOne = arith::MulFOp::create(b, cosImagSq, four, fmf); + Value sinImag = math::SinOp::create(b, imag, fmf); + + Value imagNum = arith::MulFOp::create( + b, four, arith::MulFOp::create(b, cosImag, sinImag, fmf), fmf); + + Value expSumMinusTwo = arith::AddFOp::create(b, expTwoRealMinusOne, + expNegTwoRealMinusOne, fmf); Value denom = - b.create<arith::AddFOp>(expSumMinusTwo, twoCosTwoImagPlusOne, fmf); + arith::AddFOp::create(b, expSumMinusTwo, twoCosTwoImagPlusOne, fmf); - Value isInf = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, - expSumMinusTwo, inf, fmf); - Value realLimit = b.create<math::CopySignOp>(negOne, real, fmf); + Value isInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, + expSumMinusTwo, inf, fmf); + Value realLimit = math::CopySignOp::create(b, negOne, real, fmf); - Value resultReal = b.create<arith::SelectOp>( - isInf, realLimit, b.create<arith::DivFOp>(realNum, denom, fmf)); - Value resultImag = b.create<arith::DivFOp>(imagNum, denom, fmf); + Value resultReal = arith::SelectOp::create( + b, isInf, realLimit, arith::DivFOp::create(b, realNum, denom, fmf)); + Value resultImag = arith::DivFOp::create(b, imagNum, denom, fmf); if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf)) { - Value absReal = b.create<math::AbsFOp>(real, fmf); - Value zero = b.create<arith::ConstantOp>( - elementType, b.getFloatAttr(elementType, 0.0)); + Value absReal = math::AbsFOp::create(b, real, fmf); + Value zero = arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, 0.0)); Value nan = cst(APFloat::getNaN(floatSemantics)); - Value absRealIsInf = - b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf); + Value absRealIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, + absReal, inf, fmf); Value imagIsZero = - b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf); - Value absRealIsNotInf = b.create<arith::XOrIOp>( - absRealIsInf, b.create<arith::ConstantIntOp>(true, /*width=*/1)); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf); + Value absRealIsNotInf = arith::XOrIOp::create( + b, absRealIsInf, arith::ConstantIntOp::create(b, true, /*width=*/1)); - Value imagNumIsNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, - imagNum, imagNum, fmf); + Value imagNumIsNaN = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, + imagNum, imagNum, fmf); Value resultRealIsNaN = - b.create<arith::AndIOp>(imagNumIsNaN, absRealIsNotInf); - Value resultImagIsZero = b.create<arith::OrIOp>( - imagIsZero, b.create<arith::AndIOp>(absRealIsInf, imagNumIsNaN)); + arith::AndIOp::create(b, imagNumIsNaN, absRealIsNotInf); + Value resultImagIsZero = arith::OrIOp::create( + b, imagIsZero, arith::AndIOp::create(b, absRealIsInf, imagNumIsNaN)); - resultReal = b.create<arith::SelectOp>(resultRealIsNaN, nan, resultReal); + resultReal = arith::SelectOp::create(b, resultRealIsNaN, nan, resultReal); resultImag = - b.create<arith::SelectOp>(resultImagIsZero, zero, resultImag); + arith::SelectOp::create(b, resultImagIsZero, zero, resultImag); } if constexpr (std::is_same_v<Op, complex::TanOp>) { // tan(x+yi) = -i*tanh(-y + xi) std::swap(resultReal, resultImag); - resultImag = b.create<arith::MulFOp>(resultImag, negOne, fmf); + resultImag = arith::MulFOp::create(b, resultImag, negOne, fmf); } rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, @@ -799,10 +807,10 @@ struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> { auto type = cast<ComplexType>(adaptor.getComplex().getType()); auto elementType = cast<FloatType>(type.getElementType()); Value real = - rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); + complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex()); Value imag = - rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); - Value negImag = rewriter.create<arith::NegFOp>(loc, elementType, imag); + complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex()); + Value negImag = arith::NegFOp::create(rewriter, loc, elementType, imag); rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag); @@ -818,97 +826,102 @@ static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder, arith::FastMathFlags fmf) { auto elementType = cast<FloatType>(type.getElementType()); - Value a = builder.create<complex::ReOp>(lhs); - Value b = builder.create<complex::ImOp>(lhs); + Value a = complex::ReOp::create(builder, lhs); + Value b = complex::ImOp::create(builder, lhs); - Value abs = builder.create<complex::AbsOp>(lhs, fmf); - Value absToC = builder.create<math::PowFOp>(abs, c, fmf); + Value abs = complex::AbsOp::create(builder, lhs, fmf); + Value absToC = math::PowFOp::create(builder, abs, c, fmf); - Value negD = builder.create<arith::NegFOp>(d, fmf); - Value argLhs = builder.create<math::Atan2Op>(b, a, fmf); - Value negDArgLhs = builder.create<arith::MulFOp>(negD, argLhs, fmf); - Value expNegDArgLhs = builder.create<math::ExpOp>(negDArgLhs, fmf); + Value negD = arith::NegFOp::create(builder, d, fmf); + Value argLhs = math::Atan2Op::create(builder, b, a, fmf); + Value negDArgLhs = arith::MulFOp::create(builder, negD, argLhs, fmf); + Value expNegDArgLhs = math::ExpOp::create(builder, negDArgLhs, fmf); - Value coeff = builder.create<arith::MulFOp>(absToC, expNegDArgLhs, fmf); - Value lnAbs = builder.create<math::LogOp>(abs, fmf); - Value cArgLhs = builder.create<arith::MulFOp>(c, argLhs, fmf); - Value dLnAbs = builder.create<arith::MulFOp>(d, lnAbs, fmf); - Value q = builder.create<arith::AddFOp>(cArgLhs, dLnAbs, fmf); - Value cosQ = builder.create<math::CosOp>(q, fmf); - Value sinQ = builder.create<math::SinOp>(q, fmf); + Value coeff = arith::MulFOp::create(builder, absToC, expNegDArgLhs, fmf); + Value lnAbs = math::LogOp::create(builder, abs, fmf); + Value cArgLhs = arith::MulFOp::create(builder, c, argLhs, fmf); + Value dLnAbs = arith::MulFOp::create(builder, d, lnAbs, fmf); + Value q = arith::AddFOp::create(builder, cArgLhs, dLnAbs, fmf); + Value cosQ = math::CosOp::create(builder, q, fmf); + Value sinQ = math::SinOp::create(builder, q, fmf); - Value inf = builder.create<arith::ConstantOp>( - elementType, + Value inf = arith::ConstantOp::create( + builder, elementType, builder.getFloatAttr(elementType, APFloat::getInf(elementType.getFloatSemantics()))); - Value zero = builder.create<arith::ConstantOp>( - elementType, builder.getFloatAttr(elementType, 0.0)); - Value one = builder.create<arith::ConstantOp>( - elementType, builder.getFloatAttr(elementType, 1.0)); - Value complexOne = builder.create<complex::CreateOp>(type, one, zero); - Value complexZero = builder.create<complex::CreateOp>(type, zero, zero); - Value complexInf = builder.create<complex::CreateOp>(type, inf, zero); + Value zero = arith::ConstantOp::create( + builder, elementType, builder.getFloatAttr(elementType, 0.0)); + Value one = arith::ConstantOp::create(builder, elementType, + builder.getFloatAttr(elementType, 1.0)); + Value complexOne = complex::CreateOp::create(builder, type, one, zero); + Value complexZero = complex::CreateOp::create(builder, type, zero, zero); + Value complexInf = complex::CreateOp::create(builder, type, inf, zero); // Case 0: // d^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see // Branch Cuts for Complex Elementary Functions or Much Ado About // Nothing's Sign Bit, W. Kahan, Section 10. Value absEqZero = - builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, abs, zero, fmf); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, abs, zero, fmf); Value dEqZero = - builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero, fmf); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, d, zero, fmf); Value cEqZero = - builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero, fmf); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, c, zero, fmf); Value bEqZero = - builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, b, zero, fmf); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, b, zero, fmf); Value zeroLeC = - builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLE, zero, c, fmf); - Value coeffCosQ = builder.create<arith::MulFOp>(coeff, cosQ, fmf); - Value coeffSinQ = builder.create<arith::MulFOp>(coeff, sinQ, fmf); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OLE, zero, c, fmf); + Value coeffCosQ = arith::MulFOp::create(builder, coeff, cosQ, fmf); + Value coeffSinQ = arith::MulFOp::create(builder, coeff, sinQ, fmf); Value complexOneOrZero = - builder.create<arith::SelectOp>(cEqZero, complexOne, complexZero); + arith::SelectOp::create(builder, cEqZero, complexOne, complexZero); Value coeffCosSin = - builder.create<complex::CreateOp>(type, coeffCosQ, coeffSinQ); - Value cutoff0 = builder.create<arith::SelectOp>( - builder.create<arith::AndIOp>( - builder.create<arith::AndIOp>(absEqZero, dEqZero), zeroLeC), + complex::CreateOp::create(builder, type, coeffCosQ, coeffSinQ); + Value cutoff0 = arith::SelectOp::create( + builder, + arith::AndIOp::create( + builder, arith::AndIOp::create(builder, absEqZero, dEqZero), zeroLeC), complexOneOrZero, coeffCosSin); // Case 1: // x^0 is defined to be 1 for any x, see // Branch Cuts for Complex Elementary Functions or Much Ado About // Nothing's Sign Bit, W. Kahan, Section 10. - Value rhsEqZero = builder.create<arith::AndIOp>(cEqZero, dEqZero); + Value rhsEqZero = arith::AndIOp::create(builder, cEqZero, dEqZero); Value cutoff1 = - builder.create<arith::SelectOp>(rhsEqZero, complexOne, cutoff0); + arith::SelectOp::create(builder, rhsEqZero, complexOne, cutoff0); // Case 2: // 1^(c + d*i) = 1 + 0*i - Value lhsEqOne = builder.create<arith::AndIOp>( - builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one, fmf), + Value lhsEqOne = arith::AndIOp::create( + builder, + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, one, fmf), bEqZero); Value cutoff2 = - builder.create<arith::SelectOp>(lhsEqOne, complexOne, cutoff1); + arith::SelectOp::create(builder, lhsEqOne, complexOne, cutoff1); // Case 3: // inf^(c + 0*i) = inf + 0*i, c > 0 - Value lhsEqInf = builder.create<arith::AndIOp>( - builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf, fmf), + Value lhsEqInf = arith::AndIOp::create( + builder, + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, inf, fmf), bEqZero); - Value rhsGt0 = builder.create<arith::AndIOp>( - dEqZero, - builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero, fmf)); - Value cutoff3 = builder.create<arith::SelectOp>( - builder.create<arith::AndIOp>(lhsEqInf, rhsGt0), complexInf, cutoff2); + Value rhsGt0 = arith::AndIOp::create( + builder, dEqZero, + arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, c, zero, fmf)); + Value cutoff3 = arith::SelectOp::create( + builder, arith::AndIOp::create(builder, lhsEqInf, rhsGt0), complexInf, + cutoff2); // Case 4: // inf^(c + 0*i) = 0 + 0*i, c < 0 - Value rhsLt0 = builder.create<arith::AndIOp>( - dEqZero, - builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero, fmf)); - Value cutoff4 = builder.create<arith::SelectOp>( - builder.create<arith::AndIOp>(lhsEqInf, rhsLt0), complexZero, cutoff3); + Value rhsLt0 = arith::AndIOp::create( + builder, dEqZero, + arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, c, zero, fmf)); + Value cutoff4 = arith::SelectOp::create( + builder, arith::AndIOp::create(builder, lhsEqInf, rhsLt0), complexZero, + cutoff3); return cutoff4; } @@ -923,8 +936,8 @@ struct PowOpConversion : public OpConversionPattern<complex::PowOp> { auto type = cast<ComplexType>(adaptor.getLhs().getType()); auto elementType = cast<FloatType>(type.getElementType()); - Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs()); - Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs()); + Value c = complex::ReOp::create(builder, elementType, adaptor.getRhs()); + Value d = complex::ImOp::create(builder, elementType, adaptor.getRhs()); rewriter.replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(), c, d, op.getFastmath())}); @@ -945,64 +958,64 @@ struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> { arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); auto cst = [&](APFloat v) { - return b.create<arith::ConstantOp>(elementType, - b.getFloatAttr(elementType, v)); + return arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, v)); }; const auto &floatSemantics = elementType.getFloatSemantics(); Value zero = cst(APFloat::getZero(floatSemantics)); Value inf = cst(APFloat::getInf(floatSemantics)); - Value negHalf = b.create<arith::ConstantOp>( - elementType, b.getFloatAttr(elementType, -0.5)); + Value negHalf = arith::ConstantOp::create( + b, elementType, b.getFloatAttr(elementType, -0.5)); Value nan = cst(APFloat::getNaN(floatSemantics)); - Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); - Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); + Value real = complex::ReOp::create(b, elementType, adaptor.getComplex()); + Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex()); Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt); - Value argArg = b.create<math::Atan2Op>(imag, real, fmf); - Value rsqrtArg = b.create<arith::MulFOp>(argArg, negHalf, fmf); - Value cos = b.create<math::CosOp>(rsqrtArg, fmf); - Value sin = b.create<math::SinOp>(rsqrtArg, fmf); + Value argArg = math::Atan2Op::create(b, imag, real, fmf); + Value rsqrtArg = arith::MulFOp::create(b, argArg, negHalf, fmf); + Value cos = math::CosOp::create(b, rsqrtArg, fmf); + Value sin = math::SinOp::create(b, rsqrtArg, fmf); - Value resultReal = b.create<arith::MulFOp>(absRsqrt, cos, fmf); - Value resultImag = b.create<arith::MulFOp>(absRsqrt, sin, fmf); + Value resultReal = arith::MulFOp::create(b, absRsqrt, cos, fmf); + Value resultImag = arith::MulFOp::create(b, absRsqrt, sin, fmf); if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf)) { - Value negOne = b.create<arith::ConstantOp>( - elementType, b.getFloatAttr(elementType, -1)); + Value negOne = arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, -1)); - Value realSignedZero = b.create<math::CopySignOp>(zero, real, fmf); - Value imagSignedZero = b.create<math::CopySignOp>(zero, imag, fmf); + Value realSignedZero = math::CopySignOp::create(b, zero, real, fmf); + Value imagSignedZero = math::CopySignOp::create(b, zero, imag, fmf); Value negImagSignedZero = - b.create<arith::MulFOp>(negOne, imagSignedZero, fmf); + arith::MulFOp::create(b, negOne, imagSignedZero, fmf); - Value absReal = b.create<math::AbsFOp>(real, fmf); - Value absImag = b.create<math::AbsFOp>(imag, fmf); + Value absReal = math::AbsFOp::create(b, real, fmf); + Value absImag = math::AbsFOp::create(b, imag, fmf); - Value absImagIsInf = - b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf); + Value absImagIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, + absImag, inf, fmf); Value realIsNan = - b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real, fmf); - Value realIsInf = - b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf); - Value inIsNanInf = b.create<arith::AndIOp>(absImagIsInf, realIsNan); + arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, real, real, fmf); + Value realIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, + absReal, inf, fmf); + Value inIsNanInf = arith::AndIOp::create(b, absImagIsInf, realIsNan); - Value resultIsZero = b.create<arith::OrIOp>(inIsNanInf, realIsInf); + Value resultIsZero = arith::OrIOp::create(b, inIsNanInf, realIsInf); resultReal = - b.create<arith::SelectOp>(resultIsZero, realSignedZero, resultReal); - resultImag = b.create<arith::SelectOp>(resultIsZero, negImagSignedZero, - resultImag); + arith::SelectOp::create(b, resultIsZero, realSignedZero, resultReal); + resultImag = arith::SelectOp::create(b, resultIsZero, negImagSignedZero, + resultImag); } Value isRealZero = - b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero, fmf); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, zero, fmf); Value isImagZero = - b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf); - Value isZero = b.create<arith::AndIOp>(isRealZero, isImagZero); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf); + Value isZero = arith::AndIOp::create(b, isRealZero, isImagZero); - resultReal = b.create<arith::SelectOp>(isZero, inf, resultReal); - resultImag = b.create<arith::SelectOp>(isZero, nan, resultImag); + resultReal = arith::SelectOp::create(b, isZero, inf, resultReal); + resultImag = arith::SelectOp::create(b, isZero, nan, resultImag); rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, resultImag); @@ -1021,9 +1034,9 @@ struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> { arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); Value real = - rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex()); + complex::ReOp::create(rewriter, loc, type, adaptor.getComplex()); Value imag = - rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex()); + complex::ImOp::create(rewriter, loc, type, adaptor.getComplex()); rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real, fmf); diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp index 13a0844..ff6d369 100644 --- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -73,13 +73,13 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {}); - abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(), - "abort", abortFuncTy); + abortFunc = LLVM::LLVMFuncOp::create(rewriter, rewriter.getUnknownLoc(), + "abort", abortFuncTy); } - rewriter.create<LLVM::CallOp>(loc, abortFunc, ValueRange()); - rewriter.create<LLVM::UnreachableOp>(loc); + LLVM::CallOp::create(rewriter, loc, abortFunc, ValueRange()); + LLVM::UnreachableOp::create(rewriter, loc); } else { - rewriter.create<LLVM::BrOp>(loc, ValueRange(), continuationBlock); + LLVM::BrOp::create(rewriter, loc, ValueRange(), continuationBlock); } // Generate assertion test. diff --git a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp index 9831dca..5ac838c 100644 --- a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp +++ b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp @@ -33,8 +33,8 @@ ControlFlowToSCFTransformation::createStructuredBranchRegionOp( MutableArrayRef<Region> regions) { if (auto condBrOp = dyn_cast<cf::CondBranchOp>(controlFlowCondOp)) { assert(regions.size() == 2); - auto ifOp = builder.create<scf::IfOp>(controlFlowCondOp->getLoc(), - resultTypes, condBrOp.getCondition()); + auto ifOp = scf::IfOp::create(builder, controlFlowCondOp->getLoc(), + resultTypes, condBrOp.getCondition()); ifOp.getThenRegion().takeBody(regions[0]); ifOp.getElseRegion().takeBody(regions[1]); return ifOp.getOperation(); @@ -43,8 +43,8 @@ ControlFlowToSCFTransformation::createStructuredBranchRegionOp( if (auto switchOp = dyn_cast<cf::SwitchOp>(controlFlowCondOp)) { // `getCFGSwitchValue` returns an i32 that we need to convert to index // fist. - auto cast = builder.create<arith::IndexCastUIOp>( - controlFlowCondOp->getLoc(), builder.getIndexType(), + auto cast = arith::IndexCastUIOp::create( + builder, controlFlowCondOp->getLoc(), builder.getIndexType(), switchOp.getFlag()); SmallVector<int64_t> cases; if (auto caseValues = switchOp.getCaseValues()) @@ -55,8 +55,9 @@ ControlFlowToSCFTransformation::createStructuredBranchRegionOp( assert(regions.size() == cases.size() + 1); - auto indexSwitchOp = builder.create<scf::IndexSwitchOp>( - controlFlowCondOp->getLoc(), resultTypes, cast, cases, cases.size()); + auto indexSwitchOp = + scf::IndexSwitchOp::create(builder, controlFlowCondOp->getLoc(), + resultTypes, cast, cases, cases.size()); indexSwitchOp.getDefaultRegion().takeBody(regions[0]); for (auto &&[targetRegion, sourceRegion] : @@ -75,7 +76,7 @@ LogicalResult ControlFlowToSCFTransformation::createStructuredBranchRegionTerminatorOp( Location loc, OpBuilder &builder, Operation *branchRegionOp, Operation *replacedControlFlowOp, ValueRange results) { - builder.create<scf::YieldOp>(loc, results); + scf::YieldOp::create(builder, loc, results); return success(); } @@ -84,23 +85,24 @@ ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp( OpBuilder &builder, Operation *replacedOp, ValueRange loopVariablesInit, Value condition, ValueRange loopVariablesNextIter, Region &&loopBody) { Location loc = replacedOp->getLoc(); - auto whileOp = builder.create<scf::WhileOp>(loc, loopVariablesInit.getTypes(), - loopVariablesInit); + auto whileOp = scf::WhileOp::create( + builder, loc, loopVariablesInit.getTypes(), loopVariablesInit); whileOp.getBefore().takeBody(loopBody); builder.setInsertionPointToEnd(&whileOp.getBefore().back()); // `getCFGSwitchValue` returns a i32. We therefore need to truncate the // condition to i1 first. It is guaranteed to be either 0 or 1 already. - builder.create<scf::ConditionOp>( - loc, builder.create<arith::TruncIOp>(loc, builder.getI1Type(), condition), + scf::ConditionOp::create( + builder, loc, + arith::TruncIOp::create(builder, loc, builder.getI1Type(), condition), loopVariablesNextIter); Block *afterBlock = builder.createBlock(&whileOp.getAfter()); afterBlock->addArguments( loopVariablesInit.getTypes(), SmallVector<Location>(loopVariablesInit.size(), loc)); - builder.create<scf::YieldOp>(loc, afterBlock->getArguments()); + scf::YieldOp::create(builder, loc, afterBlock->getArguments()); return whileOp.getOperation(); } @@ -108,8 +110,8 @@ ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp( Value ControlFlowToSCFTransformation::getCFGSwitchValue(Location loc, OpBuilder &builder, unsigned int value) { - return builder.create<arith::ConstantOp>(loc, - builder.getI32IntegerAttr(value)); + return arith::ConstantOp::create(builder, loc, + builder.getI32IntegerAttr(value)); } void ControlFlowToSCFTransformation::createCFGSwitchOp( @@ -117,15 +119,15 @@ void ControlFlowToSCFTransformation::createCFGSwitchOp( ArrayRef<unsigned int> caseValues, BlockRange caseDestinations, ArrayRef<ValueRange> caseArguments, Block *defaultDest, ValueRange defaultArgs) { - builder.create<cf::SwitchOp>(loc, flag, defaultDest, defaultArgs, - llvm::to_vector_of<int32_t>(caseValues), - caseDestinations, caseArguments); + cf::SwitchOp::create(builder, loc, flag, defaultDest, defaultArgs, + llvm::to_vector_of<int32_t>(caseValues), + caseDestinations, caseArguments); } Value ControlFlowToSCFTransformation::getUndefValue(Location loc, OpBuilder &builder, Type type) { - return builder.create<ub::PoisonOp>(loc, type, nullptr); + return ub::PoisonOp::create(builder, loc, type, nullptr); } FailureOr<Operation *> @@ -142,12 +144,11 @@ ControlFlowToSCFTransformation::createUnreachableTerminator(Location loc, return emitError(loc, "Cannot create unreachable terminator for '") << parentOp->getName() << "'"; - return builder - .create<func::ReturnOp>( - loc, llvm::map_to_vector(funcOp.getResultTypes(), - [&](Type type) { - return getUndefValue(loc, builder, type); - })) + return func::ReturnOp::create( + builder, loc, + llvm::map_to_vector( + funcOp.getResultTypes(), + [&](Type type) { return getUndefValue(loc, builder, type); })) .getOperation(); } diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp index 03f4bf4..56b6181 100644 --- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp +++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp @@ -43,6 +43,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() { SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); // TODO: We should also take care of block argument type conversion. diff --git a/mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp b/mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp index c9b1dc1..ee6d7d5 100644 --- a/mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp +++ b/mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp @@ -9,8 +9,6 @@ #include "mlir/Conversion/ConvertToEmitC/ConvertToEmitCPass.h" #include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h" -#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" diff --git a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp index 252245d..c70b5f0 100644 --- a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp +++ b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp @@ -9,7 +9,6 @@ #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" -#include "llvm/ADT/DenseSet.h" using namespace mlir; diff --git a/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp index f8dc06f..197caeb 100644 --- a/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp +++ b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp @@ -99,8 +99,8 @@ public: } // Create the converted `emitc.func` op. - emitc::FuncOp newFuncOp = rewriter.create<emitc::FuncOp>( - funcOp.getLoc(), funcOp.getName(), + emitc::FuncOp newFuncOp = emitc::FuncOp::create( + rewriter, funcOp.getLoc(), funcOp.getName(), FunctionType::get(rewriter.getContext(), signatureConverter.getConvertedTypes(), resultType ? TypeRange(resultType) : TypeRange())); diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp index 3623563..67bb1c1 100644 --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -115,8 +115,8 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, SmallVector<NamedAttribute> attributes; filterFuncAttributes(funcOp, attributes); - auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>( - loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), + auto wrapperFuncOp = LLVM::LLVMFuncOp::create( + rewriter, loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), wrapperFuncType, LLVM::Linkage::External, /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes); propagateArgResAttrs(rewriter, !!resultStructType, funcOp, wrapperFuncOp); @@ -129,14 +129,14 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, for (auto [index, argType] : llvm::enumerate(type.getInputs())) { Value arg = wrapperFuncOp.getArgument(index + argOffset); if (auto memrefType = dyn_cast<MemRefType>(argType)) { - Value loaded = rewriter.create<LLVM::LoadOp>( - loc, typeConverter.convertType(memrefType), arg); + Value loaded = LLVM::LoadOp::create( + rewriter, loc, typeConverter.convertType(memrefType), arg); MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args); continue; } if (isa<UnrankedMemRefType>(argType)) { - Value loaded = rewriter.create<LLVM::LoadOp>( - loc, typeConverter.convertType(argType), arg); + Value loaded = LLVM::LoadOp::create( + rewriter, loc, typeConverter.convertType(argType), arg); UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args); continue; } @@ -144,14 +144,14 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, args.push_back(arg); } - auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args); + auto call = LLVM::CallOp::create(rewriter, loc, newFuncOp, args); if (resultStructType) { - rewriter.create<LLVM::StoreOp>(loc, call.getResult(), - wrapperFuncOp.getArgument(0)); - rewriter.create<LLVM::ReturnOp>(loc, ValueRange{}); + LLVM::StoreOp::create(rewriter, loc, call.getResult(), + wrapperFuncOp.getArgument(0)); + LLVM::ReturnOp::create(rewriter, loc, ValueRange{}); } else { - rewriter.create<LLVM::ReturnOp>(loc, call.getResults()); + LLVM::ReturnOp::create(rewriter, loc, call.getResults()); } } @@ -182,8 +182,8 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc, filterFuncAttributes(funcOp, attributes); // Create the auxiliary function. - auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>( - loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), + auto wrapperFunc = LLVM::LLVMFuncOp::create( + builder, loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), wrapperType, LLVM::Linkage::External, /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes); propagateArgResAttrs(builder, !!resultStructType, funcOp, wrapperFunc); @@ -201,11 +201,11 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc, if (resultStructType) { // Allocate the struct on the stack and pass the pointer. Type resultType = cast<LLVM::LLVMFunctionType>(wrapperType).getParamType(0); - Value one = builder.create<LLVM::ConstantOp>( - loc, typeConverter.convertType(builder.getIndexType()), + Value one = LLVM::ConstantOp::create( + builder, loc, typeConverter.convertType(builder.getIndexType()), builder.getIntegerAttr(builder.getIndexType(), 1)); Value result = - builder.create<LLVM::AllocaOp>(loc, resultType, resultStructType, one); + LLVM::AllocaOp::create(builder, loc, resultType, resultStructType, one); args.push_back(result); } @@ -229,12 +229,12 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc, wrapperArgsRange.take_front(numToDrop)); auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext()); - Value one = builder.create<LLVM::ConstantOp>( - loc, typeConverter.convertType(builder.getIndexType()), + Value one = LLVM::ConstantOp::create( + builder, loc, typeConverter.convertType(builder.getIndexType()), builder.getIntegerAttr(builder.getIndexType(), 1)); - Value allocated = builder.create<LLVM::AllocaOp>( - loc, ptrTy, packed.getType(), one, /*alignment=*/0); - builder.create<LLVM::StoreOp>(loc, packed, allocated); + Value allocated = LLVM::AllocaOp::create( + builder, loc, ptrTy, packed.getType(), one, /*alignment=*/0); + LLVM::StoreOp::create(builder, loc, packed, allocated); arg = allocated; } else { arg = wrapperArgsRange[0]; @@ -245,14 +245,14 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc, } assert(wrapperArgsRange.empty() && "did not map some of the arguments"); - auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args); + auto call = LLVM::CallOp::create(builder, loc, wrapperFunc, args); if (resultStructType) { Value result = - builder.create<LLVM::LoadOp>(loc, resultStructType, args.front()); - builder.create<LLVM::ReturnOp>(loc, result); + LLVM::LoadOp::create(builder, loc, resultStructType, args.front()); + LLVM::ReturnOp::create(builder, loc, result); } else { - builder.create<LLVM::ReturnOp>(loc, call.getResults()); + LLVM::ReturnOp::create(builder, loc, call.getResults()); } } @@ -283,7 +283,7 @@ static void restoreByValRefArgumentType( Type resTy = typeConverter.convertType( cast<TypeAttr>(byValRefAttr->getValue()).getValue()); - Value valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg); + Value valueArg = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg); rewriter.replaceUsesOfBlockArgument(arg, valueArg); } } @@ -357,8 +357,8 @@ FailureOr<LLVM::LLVMFuncOp> mlir::convertFuncOpToLLVMFuncOp( symbolTable.remove(funcOp); } - auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>( - funcOp.getLoc(), funcOp.getName(), llvmType, linkage, + auto newFuncOp = LLVM::LLVMFuncOp::create( + rewriter, funcOp.getLoc(), funcOp.getName(), llvmType, linkage, /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes); @@ -509,7 +509,7 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<func::ConstantOp> { return rewriter.notifyMatchFailure(op, "failed to convert result type"); auto newOp = - rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, op.getValue()); + LLVM::AddressOfOp::create(rewriter, op.getLoc(), type, op.getValue()); for (const NamedAttribute &attr : op->getAttrs()) { if (attr.getName().strref() == "value") continue; @@ -556,9 +556,10 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> { auto promoted = this->getTypeConverter()->promoteOperands( callOp.getLoc(), /*opOperands=*/callOp->getOperands(), adaptor.getOperands(), rewriter, useBarePtrCallConv); - auto newOp = rewriter.create<LLVM::CallOp>( - callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(), - promoted, callOp->getAttrs()); + auto newOp = LLVM::CallOp::create(rewriter, callOp.getLoc(), + packedResult ? TypeRange(packedResult) + : TypeRange(), + promoted, callOp->getAttrs()); newOp.getProperties().operandSegmentSizes = { static_cast<int32_t>(promoted.size()), 0}; @@ -573,8 +574,8 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> { // Extract individual results from the structure and return them as list. results.reserve(numResults); for (unsigned i = 0; i < numResults; ++i) { - results.push_back(rewriter.create<LLVM::ExtractValueOp>( - callOp.getLoc(), newOp->getResult(0), i)); + results.push_back(LLVM::ExtractValueOp::create( + rewriter, callOp.getLoc(), newOp->getResult(0), i)); } } @@ -726,9 +727,9 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> { return rewriter.notifyMatchFailure(op, "could not convert result types"); } - Value packed = rewriter.create<LLVM::PoisonOp>(loc, packedType); + Value packed = LLVM::PoisonOp::create(rewriter, loc, packedType); for (auto [idx, operand] : llvm::enumerate(updatedOperands)) { - packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx); + packed = LLVM::InsertValueOp::create(rewriter, loc, packed, operand, idx); } rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed, op->getAttrs()); diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp index 8ed9f65..c0439a4 100644 --- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp +++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp @@ -42,6 +42,7 @@ void ConvertFuncToSPIRVPass::runOnOperation() { SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); RewritePatternSet patterns(context); diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index 01ca5e9..1037e29 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -28,7 +28,7 @@ LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp, if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) { OpBuilder::InsertionGuard guard(b); b.setInsertionPointToStart(moduleOp.getBody()); - ret = b.create<LLVM::LLVMFuncOp>(loc, name, type, LLVM::Linkage::External); + ret = LLVM::LLVMFuncOp::create(b, loc, name, type, LLVM::Linkage::External); } return ret; } @@ -68,9 +68,9 @@ mlir::getOrCreateStringConstant(OpBuilder &b, Location loc, OpBuilder::InsertionGuard guard(b); b.setInsertionPointToStart(moduleOp.getBody()); SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix); - return b.create<LLVM::GlobalOp>(loc, globalType, - /*isConstant=*/true, LLVM::Linkage::Internal, - name, attr, alignment, addrSpace); + return LLVM::GlobalOp::create(b, loc, globalType, + /*isConstant=*/true, LLVM::Linkage::Internal, + name, attr, alignment, addrSpace); } LogicalResult @@ -151,8 +151,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, gpuFuncOp.getWorkgroupAttributionAttr( idx, LLVM::LLVMDialect::getAlignAttrName()))) alignment = alignAttr.getInt(); - auto globalOp = rewriter.create<LLVM::GlobalOp>( - gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false, + auto globalOp = LLVM::GlobalOp::create( + rewriter, gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false, LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment, workgroupAddrSpace); workgroupBuffers.push_back(globalOp); @@ -220,8 +220,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, LLVM::CConv callingConvention = gpuFuncOp.isKernel() ? kernelCallingConvention : nonKernelCallingConvention; - auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>( - gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType, + auto llvmFuncOp = LLVM::LLVMFuncOp::create( + rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType, LLVM::Linkage::External, /*dsoLocal=*/false, callingConvention, /*comdat=*/nullptr, attributes); @@ -266,11 +266,11 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) { auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()); - Value address = rewriter.create<LLVM::AddressOfOp>( - loc, ptrType, global.getSymNameAttr()); + Value address = LLVM::AddressOfOp::create(rewriter, loc, ptrType, + global.getSymNameAttr()); Value memory = - rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getType(), - address, ArrayRef<LLVM::GEPArg>{0, 0}); + LLVM::GEPOp::create(rewriter, loc, ptrType, global.getType(), + address, ArrayRef<LLVM::GEPArg>{0, 0}); // Build a memref descriptor pointing to the buffer to plug with the // existing memref infrastructure. This may use more registers than @@ -298,15 +298,16 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, Type elementType = typeConverter->convertType(type.getElementType()); auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), allocaAddrSpace); - Value numElements = rewriter.create<LLVM::ConstantOp>( - gpuFuncOp.getLoc(), int64Ty, type.getNumElements()); + Value numElements = LLVM::ConstantOp::create( + rewriter, gpuFuncOp.getLoc(), int64Ty, type.getNumElements()); uint64_t alignment = 0; if (auto alignAttr = dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr( idx, LLVM::LLVMDialect::getAlignAttrName()))) alignment = alignAttr.getInt(); - Value allocated = rewriter.create<LLVM::AllocaOp>( - gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment); + Value allocated = + LLVM::AllocaOp::create(rewriter, gpuFuncOp.getLoc(), ptrType, + elementType, numElements, alignment); Value descr = MemRefDescriptor::fromStaticShape( rewriter, loc, *getTypeConverter(), type, allocated); signatureConversion.remapInput( @@ -418,8 +419,9 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( {llvmI64, ptrType, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32})); /// Start the printf hostcall - Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0); - auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64); + Value zeroI64 = LLVM::ConstantOp::create(rewriter, loc, llvmI64, 0); + auto printfBeginCall = + LLVM::CallOp::create(rewriter, loc, ocklBegin, zeroI64); Value printfDesc = printfBeginCall.getResult(); // Create the global op or find an existing one. @@ -427,21 +429,21 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat()); // Get a pointer to the format string's first element and pass it to printf() - Value globalPtr = rewriter.create<LLVM::AddressOfOp>( - loc, + Value globalPtr = LLVM::AddressOfOp::create( + rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()), global.getSymNameAttr()); Value stringStart = - rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(), - globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); - Value stringLen = rewriter.create<LLVM::ConstantOp>( - loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size()); + LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(), + globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); + Value stringLen = LLVM::ConstantOp::create( + rewriter, loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size()); - Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1); - Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0); + Value oneI32 = LLVM::ConstantOp::create(rewriter, loc, llvmI32, 1); + Value zeroI32 = LLVM::ConstantOp::create(rewriter, loc, llvmI32, 0); - auto appendFormatCall = rewriter.create<LLVM::CallOp>( - loc, ocklAppendStringN, + auto appendFormatCall = LLVM::CallOp::create( + rewriter, loc, ocklAppendStringN, ValueRange{printfDesc, stringStart, stringLen, adaptor.getArgs().empty() ? oneI32 : zeroI32}); printfDesc = appendFormatCall.getResult(); @@ -456,17 +458,18 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments; arguments.push_back(printfDesc); arguments.push_back( - rewriter.create<LLVM::ConstantOp>(loc, llvmI32, numArgsThisCall)); + LLVM::ConstantOp::create(rewriter, loc, llvmI32, numArgsThisCall)); for (size_t i = group; i < bound; ++i) { Value arg = adaptor.getArgs()[i]; if (auto floatType = dyn_cast<FloatType>(arg.getType())) { if (!floatType.isF64()) - arg = rewriter.create<LLVM::FPExtOp>( - loc, typeConverter->convertType(rewriter.getF64Type()), arg); - arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg); + arg = LLVM::FPExtOp::create( + rewriter, loc, typeConverter->convertType(rewriter.getF64Type()), + arg); + arg = LLVM::BitcastOp::create(rewriter, loc, llvmI64, arg); } if (arg.getType().getIntOrFloatBitWidth() != 64) - arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg); + arg = LLVM::ZExtOp::create(rewriter, loc, llvmI64, arg); arguments.push_back(arg); } @@ -477,7 +480,7 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( auto isLast = (bound == nArgs) ? oneI32 : zeroI32; arguments.push_back(isLast); - auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments); + auto call = LLVM::CallOp::create(rewriter, loc, ocklAppendArgs, arguments); printfDesc = call.getResult(); } rewriter.eraseOp(gpuPrintfOp); @@ -510,13 +513,13 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite( /*alignment=*/0, addressSpace); // Get a pointer to the format string's first element - Value globalPtr = rewriter.create<LLVM::AddressOfOp>( - loc, + Value globalPtr = LLVM::AddressOfOp::create( + rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()), global.getSymNameAttr()); Value stringStart = - rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(), - globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); + LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(), + globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); // Construct arguments and function call auto argsRange = adaptor.getArgs(); @@ -525,7 +528,7 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite( printfArgs.push_back(stringStart); printfArgs.append(argsRange.begin(), argsRange.end()); - rewriter.create<LLVM::CallOp>(loc, printfDecl, printfArgs); + LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs); rewriter.eraseOp(gpuPrintfOp); return success(); } @@ -559,10 +562,10 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite( "printfFormat_", adaptor.getFormat()); // Get a pointer to the format string's first element - Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global); + Value globalPtr = LLVM::AddressOfOp::create(rewriter, loc, global); Value stringStart = - rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(), - globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); + LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(), + globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); SmallVector<Type> types; SmallVector<Value> args; // Promote and pack the arguments into a stack allocation. @@ -572,27 +575,27 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite( assert(type.isIntOrFloat()); if (isa<FloatType>(type)) { type = rewriter.getF64Type(); - promotedArg = rewriter.create<LLVM::FPExtOp>(loc, type, arg); + promotedArg = LLVM::FPExtOp::create(rewriter, loc, type, arg); } types.push_back(type); args.push_back(promotedArg); } Type structType = LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types); - Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), - rewriter.getIndexAttr(1)); + Value one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), + rewriter.getIndexAttr(1)); Value tempAlloc = - rewriter.create<LLVM::AllocaOp>(loc, ptrType, structType, one, - /*alignment=*/0); + LLVM::AllocaOp::create(rewriter, loc, ptrType, structType, one, + /*alignment=*/0); for (auto [index, arg] : llvm::enumerate(args)) { - Value ptr = rewriter.create<LLVM::GEPOp>( - loc, ptrType, structType, tempAlloc, + Value ptr = LLVM::GEPOp::create( + rewriter, loc, ptrType, structType, tempAlloc, ArrayRef<LLVM::GEPArg>{0, static_cast<int32_t>(index)}); - rewriter.create<LLVM::StoreOp>(loc, arg, ptr); + LLVM::StoreOp::create(rewriter, loc, arg, ptr); } std::array<Value, 2> printfArgs = {stringStart, tempAlloc}; - rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs); + LLVM::CallOp::create(rewriter, loc, vprintfDecl, printfArgs); rewriter.eraseOp(gpuPrintfOp); return success(); } @@ -607,23 +610,23 @@ static Value scalarizeVectorOpHelper(Operation *op, ValueRange operands, TypeRange operandTypes(operands); VectorType vectorType = cast<VectorType>(llvm1DVectorTy); Location loc = op->getLoc(); - Value result = rewriter.create<LLVM::PoisonOp>(loc, vectorType); + Value result = LLVM::PoisonOp::create(rewriter, loc, vectorType); Type indexType = converter.convertType(rewriter.getIndexType()); StringAttr name = op->getName().getIdentifier(); Type elementType = vectorType.getElementType(); for (int64_t i = 0; i < vectorType.getNumElements(); ++i) { - Value index = rewriter.create<LLVM::ConstantOp>(loc, indexType, i); + Value index = LLVM::ConstantOp::create(rewriter, loc, indexType, i); auto extractElement = [&](Value operand) -> Value { if (!isa<VectorType>(operand.getType())) return operand; - return rewriter.create<LLVM::ExtractElementOp>(loc, operand, index); + return LLVM::ExtractElementOp::create(rewriter, loc, operand, index); }; auto scalarOperands = llvm::map_to_vector(operands, extractElement); Operation *scalarOp = rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs()); - result = rewriter.create<LLVM::InsertElementOp>( - loc, result, scalarOp->getResult(0), index); + result = LLVM::InsertElementOp::create(rewriter, loc, result, + scalarOp->getResult(0), index); } return result; } @@ -705,10 +708,10 @@ LLVM::GlobalOp getDynamicSharedMemorySymbol( auto zeroSizedArrayType = LLVM::LLVMArrayType::get( typeConverter->convertType(memrefType.getElementType()), 0); - return rewriter.create<LLVM::GlobalOp>( - op->getLoc(), zeroSizedArrayType, /*isConstant=*/false, - LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte, - addressSpace.value()); + return LLVM::GlobalOp::create(rewriter, op->getLoc(), zeroSizedArrayType, + /*isConstant=*/false, LLVM::Linkage::Internal, + symName, /*value=*/Attribute(), alignmentByte, + addressSpace.value()); } LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite( @@ -732,13 +735,13 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite( // Step 3. Get address of the global symbol OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(op); - auto basePtr = rewriter.create<LLVM::AddressOfOp>(loc, shmemOp); + auto basePtr = LLVM::AddressOfOp::create(rewriter, loc, shmemOp); Type baseType = basePtr->getResultTypes().front(); // Step 4. Generate GEP using offsets SmallVector<LLVM::GEPArg> gepArgs = {0}; - Value shmemPtr = rewriter.create<LLVM::GEPOp>(loc, baseType, elementType, - basePtr, gepArgs); + Value shmemPtr = LLVM::GEPOp::create(rewriter, loc, baseType, elementType, + basePtr, gepArgs); // Step 5. Create a memref descriptor SmallVector<Value> shape, strides; Value sizeBytes; @@ -799,9 +802,9 @@ LogicalResult GPUReturnOpLowering::matchAndRewrite( return rewriter.notifyMatchFailure(op, "could not convert result types"); } - Value packed = rewriter.create<LLVM::PoisonOp>(loc, packedType); + Value packed = LLVM::PoisonOp::create(rewriter, loc, packedType); for (auto [idx, operand] : llvm::enumerate(updatedOperands)) { - packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx); + packed = LLVM::InsertValueOp::create(rewriter, loc, packed, operand, idx); } rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed, op->getAttrs()); diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp index 167cabb..3cfbd89 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -79,8 +79,8 @@ protected: uint64_t rank = type.getRank(); Value numElements = desc.size(rewriter, loc, /*pos=*/0); for (unsigned i = 1; i < rank; i++) - numElements = rewriter.create<LLVM::MulOp>( - loc, numElements, desc.size(rewriter, loc, /*pos=*/i)); + numElements = LLVM::MulOp::create(rewriter, loc, numElements, + desc.size(rewriter, loc, /*pos=*/i)); return numElements; } @@ -579,10 +579,10 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder, auto function = [&] { if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName)) return function; - return OpBuilder::atBlockEnd(module.getBody()) - .create<LLVM::LLVMFuncOp>(loc, functionName, functionType); + auto builder = OpBuilder::atBlockEnd(module.getBody()); + return LLVM::LLVMFuncOp::create(builder, loc, functionName, functionType); }(); - return builder.create<LLVM::CallOp>(loc, function, arguments); + return LLVM::CallOp::create(builder, loc, function, arguments); } // Corresponding to cusparseIndexType_t defined in cusparse.h. @@ -780,13 +780,13 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite( // Allocate the underlying buffer and store a pointer to it in the MemRef // descriptor. - auto nullPtr = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPointerType); + auto nullPtr = mlir::LLVM::ZeroOp::create(rewriter, loc, llvmPointerType); Value stream = adaptor.getAsyncDependencies().empty() ? nullPtr : adaptor.getAsyncDependencies().front(); - auto isHostShared = rewriter.create<mlir::LLVM::ConstantOp>( - loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared)); + auto isHostShared = mlir::LLVM::ConstantOp::create( + rewriter, loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared)); Value allocatedPtr = allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared}) @@ -1012,8 +1012,8 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite( uint64_t staticSize = static_cast<uint64_t>(bitwidth / 8) * static_cast<uint64_t>(memrefTy.getNumElements()); - Value sizeArg = rewriter.create<LLVM::ConstantOp>( - loc, getIndexType(), rewriter.getIndexAttr(staticSize)); + Value sizeArg = LLVM::ConstantOp::create( + rewriter, loc, getIndexType(), rewriter.getIndexAttr(staticSize)); llvmArgumentsWithSizes.push_back(llvmArg); // Presumably a bare pointer. llvmArgumentsWithSizes.push_back(sizeArg); } @@ -1025,8 +1025,8 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite( gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(), adaptor.getClusterSizeZ()}; } - rewriter.create<gpu::LaunchFuncOp>( - launchOp.getLoc(), launchOp.getKernelAttr(), + gpu::LaunchFuncOp::create( + rewriter, launchOp.getLoc(), launchOp.getKernelAttr(), gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(), adaptor.getGridSizeZ()}, gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(), @@ -1048,8 +1048,8 @@ static Value bitAndAddrspaceCast(Location loc, const LLVMTypeConverter &typeConverter) { auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.getType()); if (destinationType.getAddressSpace() != sourceTy.getAddressSpace()) - sourcePtr = rewriter.create<LLVM::AddrSpaceCastOp>( - loc, + sourcePtr = LLVM::AddrSpaceCastOp::create( + rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext(), destinationType.getAddressSpace()), sourcePtr); @@ -1072,13 +1072,13 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite( Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc); Type elementPtrType = getElementPtrType(memRefType); - Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType); - Value gepPtr = rewriter.create<LLVM::GEPOp>( - loc, elementPtrType, + Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType); + Value gepPtr = LLVM::GEPOp::create( + rewriter, loc, elementPtrType, typeConverter->convertType(memRefType.getElementType()), nullPtr, numElements); auto sizeBytes = - rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr); + LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gepPtr); auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType, srcDesc.alignedPtr(rewriter, loc), @@ -1123,7 +1123,7 @@ LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite( Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc); auto value = - rewriter.create<LLVM::BitcastOp>(loc, bitCastType, adaptor.getValue()); + LLVM::BitcastOp::create(rewriter, loc, bitCastType, adaptor.getValue()); auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType, dstDesc.alignedPtr(rewriter, loc), *getTypeConverter()); @@ -1150,15 +1150,15 @@ LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite( template <typename T> static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) { Type llvmInt32Type = builder.getIntegerType(32); - return builder.create<LLVM::ConstantOp>(loc, llvmInt32Type, - static_cast<int32_t>(tValue)); + return LLVM::ConstantOp::create(builder, loc, llvmInt32Type, + static_cast<int32_t>(tValue)); } template <typename T> static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) { Type llvmFloat32Type = builder.getF32Type(); - return builder.create<LLVM::ConstantOp>( - loc, llvmFloat32Type, + return LLVM::ConstantOp::create( + builder, loc, llvmFloat32Type, builder.getF32FloatAttr(static_cast<float>(tValue))); } @@ -1189,11 +1189,11 @@ LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite( // the dnmat is used with spmat with 2:4 sparsity if (dims.size() == 2) { if (isSpMMCusparseLtOp(op.getDnTensor())) { - auto handleSz = rewriter.create<LLVM::ConstantOp>( - loc, getIndexType(), rewriter.getIndexAttr(11032)); - handle = rewriter.create<LLVM::AllocaOp>( - loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16); - handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle); + auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(11032)); + handle = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, + llvmInt8Type, handleSz, /*alignment=*/16); + handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle); createLtDnMatCallBuilder .create(loc, rewriter, @@ -1351,11 +1351,11 @@ LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite( auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType)); // CUDA runner asserts the size is 44104 bytes. - auto handleSz = rewriter.create<LLVM::ConstantOp>( - loc, getIndexType(), rewriter.getIndexAttr(44104)); - Value handle = rewriter.create<LLVM::AllocaOp>( - loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16); - handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle); + auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(44104)); + Value handle = LLVM::AllocaOp::create( + rewriter, loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16); + handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle); create2To4SpMatCallBuilder .create(loc, rewriter, @@ -1441,10 +1441,11 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite( genConstInt32From(rewriter, loc, get2To4PruneFlag(op.getSpmatA())); auto computeType = genConstInt32From( rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType())); - auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), - rewriter.getIndexAttr(3)); - auto bufferSize = rewriter.create<LLVM::AllocaOp>( - loc, llvmPointerType, llvmPointerType, three, /*alignment=*/16); + auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(3)); + auto bufferSize = + LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, llvmPointerType, + three, /*alignment=*/16); createCuSparseLtSpMMBufferSizeBuilder .create(loc, rewriter, {bufferSize, modeA, modeB, adaptor.getSpmatA(), @@ -1452,20 +1453,20 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite( pruneFlag, stream}) .getResult(); - auto bufferSizePtr1 = rewriter.create<LLVM::GEPOp>( - loc, llvmPointerType, llvmPointerType, bufferSize, - ValueRange{rewriter.create<LLVM::ConstantOp>( - loc, getIndexType(), rewriter.getIndexAttr(1))}); - auto bufferSizePtr2 = rewriter.create<LLVM::GEPOp>( - loc, llvmPointerType, llvmPointerType, bufferSize, - ValueRange{rewriter.create<LLVM::ConstantOp>( - loc, getIndexType(), rewriter.getIndexAttr(2))}); + auto bufferSizePtr1 = LLVM::GEPOp::create( + rewriter, loc, llvmPointerType, llvmPointerType, bufferSize, + ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(1))}); + auto bufferSizePtr2 = LLVM::GEPOp::create( + rewriter, loc, llvmPointerType, llvmPointerType, bufferSize, + ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(2))}); auto bufferSize0 = - rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSize); + LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSize); auto bufferSize1 = - rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr1); + LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr1); auto bufferSize2 = - rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr2); + LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr2); rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream}); } else { @@ -1669,28 +1670,28 @@ LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite( Location loc = op.getLoc(); auto stream = adaptor.getAsyncDependencies().front(); - auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), - rewriter.getIndexAttr(3)); - auto buffer = rewriter.create<LLVM::AllocaOp>( - loc, llvmPointerType, llvmInt64Type, three, /*alignment=*/16); - - auto rowsPtr = rewriter.create<LLVM::GEPOp>( - loc, llvmPointerType, llvmPointerType, buffer, - ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), - rewriter.getIndexAttr(0))}); - auto colsPtr = rewriter.create<LLVM::GEPOp>( - loc, llvmPointerType, llvmPointerType, buffer, - ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), - rewriter.getIndexAttr(1))}); - auto nnzsPtr = rewriter.create<LLVM::GEPOp>( - loc, llvmPointerType, llvmPointerType, buffer, - ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), - rewriter.getIndexAttr(2))}); + auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(3)); + auto buffer = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, + llvmInt64Type, three, /*alignment=*/16); + + auto rowsPtr = LLVM::GEPOp::create( + rewriter, loc, llvmPointerType, llvmPointerType, buffer, + ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(0))}); + auto colsPtr = LLVM::GEPOp::create( + rewriter, loc, llvmPointerType, llvmPointerType, buffer, + ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(1))}); + auto nnzsPtr = LLVM::GEPOp::create( + rewriter, loc, llvmPointerType, llvmPointerType, buffer, + ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(2))}); createSpMatGetSizeBuilder.create( loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream}); - auto rows = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, rowsPtr); - auto cols = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, colsPtr); - auto nnzs = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, nnzsPtr); + auto rows = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, rowsPtr); + auto cols = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, colsPtr); + auto nnzs = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, nnzsPtr); rewriter.replaceOp(op, {rows, cols, nnzs, stream}); return success(); diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h index aab2409..91c43e8 100644 --- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h +++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h @@ -59,13 +59,13 @@ public: Operation *newOp; switch (op.getDimension()) { case gpu::Dimension::x: - newOp = rewriter.create<XOp>(loc, IntegerType::get(context, 32)); + newOp = XOp::create(rewriter, loc, IntegerType::get(context, 32)); break; case gpu::Dimension::y: - newOp = rewriter.create<YOp>(loc, IntegerType::get(context, 32)); + newOp = YOp::create(rewriter, loc, IntegerType::get(context, 32)); break; case gpu::Dimension::z: - newOp = rewriter.create<ZOp>(loc, IntegerType::get(context, 32)); + newOp = ZOp::create(rewriter, loc, IntegerType::get(context, 32)); break; } @@ -124,11 +124,13 @@ public: rewriter.getContext(), 32, min, max)); } if (indexBitwidth > 32) { - newOp = rewriter.create<LLVM::SExtOp>( - loc, IntegerType::get(context, indexBitwidth), newOp->getResult(0)); + newOp = LLVM::SExtOp::create(rewriter, loc, + IntegerType::get(context, indexBitwidth), + newOp->getResult(0)); } else if (indexBitwidth < 32) { - newOp = rewriter.create<LLVM::TruncOp>( - loc, IntegerType::get(context, indexBitwidth), newOp->getResult(0)); + newOp = LLVM::TruncOp::create(rewriter, loc, + IntegerType::get(context, indexBitwidth), + newOp->getResult(0)); } rewriter.replaceOp(op, newOp->getResults()); diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h index 64cf09e..9f36e5c 100644 --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -103,7 +103,7 @@ public: LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op); auto callOp = - rewriter.create<LLVM::CallOp>(op->getLoc(), funcOp, castedOperands); + LLVM::CallOp::create(rewriter, op->getLoc(), funcOp, castedOperands); if (resultType == adaptor.getOperands().front().getType()) { rewriter.replaceOp(op, {callOp.getResult()}); @@ -115,19 +115,20 @@ public: // there is no guarantee of a specific value being used to indicate true, // compare for inequality with zero (rather than truncate or shift). if (isResultBool) { - Value zero = rewriter.create<LLVM::ConstantOp>( - op->getLoc(), rewriter.getIntegerType(32), - rewriter.getI32IntegerAttr(0)); - Value truncated = rewriter.create<LLVM::ICmpOp>( - op->getLoc(), LLVM::ICmpPredicate::ne, callOp.getResult(), zero); + Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(), + rewriter.getIntegerType(32), + rewriter.getI32IntegerAttr(0)); + Value truncated = + LLVM::ICmpOp::create(rewriter, op->getLoc(), LLVM::ICmpPredicate::ne, + callOp.getResult(), zero); rewriter.replaceOp(op, {truncated}); return success(); } assert(callOp.getResult().getType().isF32() && "only f32 types are supposed to be truncated back"); - Value truncated = rewriter.create<LLVM::FPTruncOp>( - op->getLoc(), adaptor.getOperands().front().getType(), + Value truncated = LLVM::FPTruncOp::create( + rewriter, op->getLoc(), adaptor.getOperands().front().getType(), callOp.getResult()); rewriter.replaceOp(op, {truncated}); return success(); @@ -142,8 +143,9 @@ public: if (!f16Func.empty() && isa<Float16Type>(type)) return operand; - return rewriter.create<LLVM::FPExtOp>( - operand.getLoc(), Float32Type::get(rewriter.getContext()), operand); + return LLVM::FPExtOp::create(rewriter, operand.getLoc(), + Float32Type::get(rewriter.getContext()), + operand); } Type getFunctionType(Type resultType, ValueRange operands) const { @@ -169,7 +171,7 @@ public: // location as debug info metadata inside of a function cannot be used // outside of that function. auto globalloc = op->getLoc()->findInstanceOfOrUnknown<FileLineColLoc>(); - return b.create<LLVMFuncOp>(globalloc, funcName, funcType); + return LLVMFuncOp::create(b, globalloc, funcName, funcType); } StringRef getFunctionName(Type type, SourceOp op) const { diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp index 8b6b553..c2363a1 100644 --- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp +++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp @@ -54,8 +54,8 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, SymbolTable::lookupSymbolIn(symbolTable, name)); if (!func) { OpBuilder b(symbolTable->getRegion(0)); - func = b.create<LLVM::LLVMFuncOp>( - symbolTable->getLoc(), name, + func = LLVM::LLVMFuncOp::create( + b, symbolTable->getLoc(), name, LLVM::LLVMFunctionType::get(resultType, paramTypes)); func.setCConv(LLVM::cconv::CConv::SPIR_FUNC); func.setNoUnwind(true); @@ -79,7 +79,7 @@ static LLVM::CallOp createSPIRVBuiltinCall(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMFuncOp func, ValueRange args) { - auto call = rewriter.create<LLVM::CallOp>(loc, func, args); + auto call = LLVM::CallOp::create(rewriter, loc, func, args); call.setCConv(func.getCConv()); call.setConvergentAttr(func.getConvergentAttr()); call.setNoUnwindAttr(func.getNoUnwindAttr()); @@ -121,7 +121,7 @@ struct GPUBarrierConversion final : ConvertOpToLLVMPattern<gpu::BarrierOp> { constexpr int64_t localMemFenceFlag = 1; Location loc = op->getLoc(); Value flag = - rewriter.create<LLVM::ConstantOp>(loc, flagTy, localMemFenceFlag); + LLVM::ConstantOp::create(rewriter, loc, flagTy, localMemFenceFlag); rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, flag)); return success(); } @@ -162,8 +162,8 @@ struct LaunchConfigConversion : ConvertToLLVMPattern { Location loc = op->getLoc(); gpu::Dimension dim = getDimension(op); - Value dimVal = rewriter.create<LLVM::ConstantOp>(loc, dimTy, - static_cast<int64_t>(dim)); + Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy, + static_cast<int64_t>(dim)); rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, dimVal)); return success(); } @@ -291,13 +291,13 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> { ConversionPatternRewriter &rewriter) { return TypeSwitch<Type, Value>(oldVal.getType()) .Case([&](BFloat16Type) { - return rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI16Type(), - oldVal); + return LLVM::BitcastOp::create(rewriter, loc, rewriter.getI16Type(), + oldVal); }) .Case([&](IntegerType intTy) -> Value { if (intTy.getWidth() == 1) - return rewriter.create<LLVM::ZExtOp>(loc, rewriter.getI8Type(), - oldVal); + return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI8Type(), + oldVal); return oldVal; }) .Default(oldVal); @@ -308,11 +308,11 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> { ConversionPatternRewriter &rewriter) { return TypeSwitch<Type, Value>(newTy) .Case([&](BFloat16Type) { - return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal); + return LLVM::BitcastOp::create(rewriter, loc, newTy, oldVal); }) .Case([&](IntegerType intTy) -> Value { if (intTy.getWidth() == 1) - return rewriter.create<LLVM::TruncOp>(loc, newTy, oldVal); + return LLVM::TruncOp::create(rewriter, loc, newTy, oldVal); return oldVal; }) .Default(oldVal); @@ -349,7 +349,7 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> { bitcastOrTruncAfterShuffle(result, op.getType(0), loc, rewriter); Value trueVal = - rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), true); + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), true); rewriter.replaceOp(op, {resultOrConversion, trueVal}); return success(); } @@ -426,7 +426,7 @@ struct GPUSubgroupOpConversion final : ConvertOpToLLVMPattern<SubgroupOp> { if (indexTy.getIntOrFloatBitWidth() < resultTy.getIntOrFloatBitWidth()) { return failure(); } - result = rewriter.create<LLVM::ZExtOp>(loc, indexTy, result); + result = LLVM::ZExtOp::create(rewriter, loc, indexTy, result); } rewriter.replaceOp(op, result); diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index 1ef6ede..317bfc2 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -118,10 +118,10 @@ struct GPUSubgroupReduceOpLowering Location loc = op->getLoc(); auto int32Type = IntegerType::get(rewriter.getContext(), 32); - Value offset = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1); + Value offset = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1); - auto reduxOp = rewriter.create<NVVM::ReduxOp>(loc, int32Type, op.getValue(), - mode.value(), offset); + auto reduxOp = NVVM::ReduxOp::create(rewriter, loc, int32Type, + op.getValue(), mode.value(), offset); rewriter.replaceOp(op, reduxOp->getResult(0)); return success(); @@ -158,22 +158,22 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> { auto int32Type = IntegerType::get(rewriter.getContext(), 32); auto predTy = IntegerType::get(rewriter.getContext(), 1); - Value one = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 1); - Value minusOne = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1); - Value thirtyTwo = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 32); - Value numLeadInactiveLane = rewriter.create<LLVM::SubOp>( - loc, int32Type, thirtyTwo, adaptor.getWidth()); + Value one = LLVM::ConstantOp::create(rewriter, loc, int32Type, 1); + Value minusOne = LLVM::ConstantOp::create(rewriter, loc, int32Type, -1); + Value thirtyTwo = LLVM::ConstantOp::create(rewriter, loc, int32Type, 32); + Value numLeadInactiveLane = LLVM::SubOp::create( + rewriter, loc, int32Type, thirtyTwo, adaptor.getWidth()); // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`. - Value activeMask = rewriter.create<LLVM::LShrOp>(loc, int32Type, minusOne, - numLeadInactiveLane); + Value activeMask = LLVM::LShrOp::create(rewriter, loc, int32Type, minusOne, + numLeadInactiveLane); Value maskAndClamp; if (op.getMode() == gpu::ShuffleMode::UP) { // Clamp lane: `32 - activeWidth` maskAndClamp = numLeadInactiveLane; } else { // Clamp lane: `activeWidth - 1` - maskAndClamp = - rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.getWidth(), one); + maskAndClamp = LLVM::SubOp::create(rewriter, loc, int32Type, + adaptor.getWidth(), one); } bool predIsUsed = !op->getResult(1).use_empty(); @@ -184,13 +184,14 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> { resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(), {valueTy, predTy}); } - Value shfl = rewriter.create<NVVM::ShflOp>( - loc, resultTy, activeMask, adaptor.getValue(), adaptor.getOffset(), - maskAndClamp, convertShflKind(op.getMode()), returnValueAndIsValidAttr); + Value shfl = NVVM::ShflOp::create( + rewriter, loc, resultTy, activeMask, adaptor.getValue(), + adaptor.getOffset(), maskAndClamp, convertShflKind(op.getMode()), + returnValueAndIsValidAttr); if (predIsUsed) { - Value shflValue = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 0); + Value shflValue = LLVM::ExtractValueOp::create(rewriter, loc, shfl, 0); Value isActiveSrcLane = - rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1); + LLVM::ExtractValueOp::create(rewriter, loc, shfl, 1); rewriter.replaceOp(op, {shflValue, isActiveSrcLane}); } else { rewriter.replaceOp(op, {shfl, nullptr}); @@ -215,16 +216,16 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> { bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>( /*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize); Value newOp = - rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type(), bounds); + NVVM::LaneIdOp::create(rewriter, loc, rewriter.getI32Type(), bounds); // Truncate or extend the result depending on the index bitwidth specified // by the LLVMTypeConverter options. const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth(); if (indexBitwidth > 32) { - newOp = rewriter.create<LLVM::SExtOp>( - loc, IntegerType::get(context, indexBitwidth), newOp); + newOp = LLVM::SExtOp::create( + rewriter, loc, IntegerType::get(context, indexBitwidth), newOp); } else if (indexBitwidth < 32) { - newOp = rewriter.create<LLVM::TruncOp>( - loc, IntegerType::get(context, indexBitwidth), newOp); + newOp = LLVM::TruncOp::create( + rewriter, loc, IntegerType::get(context, indexBitwidth), newOp); } rewriter.replaceOp(op, {newOp}); return success(); @@ -271,10 +272,10 @@ struct AssertOpToAssertfailLowering Block *afterBlock = rewriter.splitBlock(assertBlock, ++assertOp->getIterator()); rewriter.setInsertionPointToEnd(beforeBlock); - rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), afterBlock, - assertBlock); + cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(), afterBlock, + assertBlock); rewriter.setInsertionPointToEnd(assertBlock); - rewriter.create<cf::BranchOp>(loc, afterBlock); + cf::BranchOp::create(rewriter, loc, afterBlock); // Continue cf.assert lowering. rewriter.setInsertionPoint(assertOp); @@ -301,12 +302,12 @@ struct AssertOpToAssertfailLowering // Create constants. auto getGlobal = [&](LLVM::GlobalOp global) { // Get a pointer to the format string's first element. - Value globalPtr = rewriter.create<LLVM::AddressOfOp>( - loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()), + Value globalPtr = LLVM::AddressOfOp::create( + rewriter, loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()), global.getSymNameAttr()); Value start = - rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(), - globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); + LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(), + globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); return start; }; Value assertMessage = getGlobal(getOrCreateStringConstant( @@ -316,8 +317,8 @@ struct AssertOpToAssertfailLowering Value assertFunc = getGlobal(getOrCreateStringConstant( rewriter, loc, moduleOp, i8Type, "assert_func_", funcName)); Value assertLine = - rewriter.create<LLVM::ConstantOp>(loc, i32Type, fileLine); - Value c1 = rewriter.create<LLVM::ConstantOp>(loc, i64Type, 1); + LLVM::ConstantOp::create(rewriter, loc, i32Type, fileLine); + Value c1 = LLVM::ConstantOp::create(rewriter, loc, i64Type, 1); // Insert function call to __assertfail. SmallVector<Value> arguments{assertMessage, assertFile, assertLine, diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp index 45fd933..99c059c 100644 --- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp @@ -126,8 +126,8 @@ struct WmmaLoadOpToNVVMLowering cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()), adaptor.getSrcMemref(), adaptor.getIndices()); - Value leadingDim = rewriter.create<LLVM::ConstantOp>( - loc, rewriter.getI32Type(), + Value leadingDim = LLVM::ConstantOp::create( + rewriter, loc, rewriter.getI32Type(), subgroupMmaLoadMatrixOp.getLeadDimensionAttr()); rewriter.replaceOpWithNewOp<NVVM::WMMALoadOp>( op, resType, dataPtr, leadingDim, m, n, k, layout, eltype, frag); @@ -173,7 +173,7 @@ struct WmmaStoreOpToNVVMLowering auto matrixType = cast<LLVM::LLVMStructType>(adaptor.getSrc().getType()); for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) { Value toUse = - rewriter.create<LLVM::ExtractValueOp>(loc, adaptor.getSrc(), i); + LLVM::ExtractValueOp::create(rewriter, loc, adaptor.getSrc(), i); storeOpOperands.push_back(toUse); } @@ -181,8 +181,8 @@ struct WmmaStoreOpToNVVMLowering rewriter, loc, cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType()), adaptor.getDstMemref(), adaptor.getIndices()); - Value leadingDim = rewriter.create<LLVM::ConstantOp>( - loc, rewriter.getI32Type(), + Value leadingDim = LLVM::ConstantOp::create( + rewriter, loc, rewriter.getI32Type(), subgroupMmaStoreMatrixOp.getLeadDimensionAttr()); rewriter.replaceOpWithNewOp<NVVM::WMMAStoreOp>( op, dataPtr, m, n, k, layout, eltype, storeOpOperands, leadingDim); @@ -216,7 +216,7 @@ struct WmmaMmaOpToNVVMLowering auto unpackOp = [&](Value operand) { auto structType = cast<LLVM::LLVMStructType>(operand.getType()); for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) { - Value toUse = rewriter.create<LLVM::ExtractValueOp>(loc, operand, i); + Value toUse = LLVM::ExtractValueOp::create(rewriter, loc, operand, i); unpackedOps.push_back(toUse); } }; @@ -280,19 +280,19 @@ struct WmmaConstantOpToNVVMLowering cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType())); // If the element type is a vector create a vector from the operand. if (auto vecType = dyn_cast<VectorType>(type.getBody()[0])) { - Value vecCst = rewriter.create<LLVM::PoisonOp>(loc, vecType); + Value vecCst = LLVM::PoisonOp::create(rewriter, loc, vecType); for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) { - Value idx = rewriter.create<LLVM::ConstantOp>( - loc, rewriter.getI32Type(), vecEl); - vecCst = rewriter.create<LLVM::InsertElementOp>(loc, vecType, vecCst, - cst, idx); + Value idx = LLVM::ConstantOp::create(rewriter, loc, + rewriter.getI32Type(), vecEl); + vecCst = LLVM::InsertElementOp::create(rewriter, loc, vecType, vecCst, + cst, idx); } cst = vecCst; } - Value matrixStruct = rewriter.create<LLVM::PoisonOp>(loc, type); + Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, type); for (size_t i : llvm::seq(size_t(0), type.getBody().size())) { matrixStruct = - rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, cst, i); + LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, cst, i); } rewriter.replaceOp(subgroupMmaConstantOp, matrixStruct); return success(); @@ -305,17 +305,17 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs, Type i1Type = builder.getI1Type(); if (auto vecType = dyn_cast<VectorType>(lhs.getType())) i1Type = VectorType::get(vecType.getShape(), i1Type); - Value cmp = builder.create<LLVM::FCmpOp>( - loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt, - lhs, rhs); - Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs); - Value isNan = builder.create<LLVM::FCmpOp>( - loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs); - Value nan = builder.create<LLVM::ConstantOp>( - loc, lhs.getType(), + Value cmp = LLVM::FCmpOp::create( + builder, loc, i1Type, + isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt, lhs, rhs); + Value sel = LLVM::SelectOp::create(builder, loc, cmp, lhs, rhs); + Value isNan = LLVM::FCmpOp::create(builder, loc, i1Type, + LLVM::FCmpPredicate::uno, lhs, rhs); + Value nan = LLVM::ConstantOp::create( + builder, loc, lhs.getType(), builder.getFloatAttr(floatType, APFloat::getQNaN(floatType.getFloatSemantics()))); - return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel); + return LLVM::SelectOp::create(builder, loc, isNan, nan, sel); } static Value createScalarOp(OpBuilder &builder, Location loc, @@ -323,11 +323,11 @@ static Value createScalarOp(OpBuilder &builder, Location loc, ArrayRef<Value> operands) { switch (op) { case gpu::MMAElementwiseOp::ADDF: - return builder.create<LLVM::FAddOp>(loc, operands[0].getType(), operands); + return LLVM::FAddOp::create(builder, loc, operands[0].getType(), operands); case gpu::MMAElementwiseOp::MULF: - return builder.create<LLVM::FMulOp>(loc, operands[0].getType(), operands); + return LLVM::FMulOp::create(builder, loc, operands[0].getType(), operands); case gpu::MMAElementwiseOp::DIVF: - return builder.create<LLVM::FDivOp>(loc, operands[0].getType(), operands); + return LLVM::FDivOp::create(builder, loc, operands[0].getType(), operands); case gpu::MMAElementwiseOp::MAXF: return createMinMaxF(builder, loc, operands[0], operands[1], /*isMin=*/false); @@ -356,18 +356,18 @@ struct WmmaElementwiseOpToNVVMLowering size_t numOperands = adaptor.getOperands().size(); LLVM::LLVMStructType destType = convertMMAToLLVMType( cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType())); - Value matrixStruct = rewriter.create<LLVM::PoisonOp>(loc, destType); + Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, destType); for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) { SmallVector<Value> extractedOperands; for (size_t opIdx = 0; opIdx < numOperands; opIdx++) { - extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>( - loc, adaptor.getOperands()[opIdx], i)); + extractedOperands.push_back(LLVM::ExtractValueOp::create( + rewriter, loc, adaptor.getOperands()[opIdx], i)); } Value element = createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.getOpType(), extractedOperands); matrixStruct = - rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, element, i); + LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, element, i); } rewriter.replaceOp(subgroupMmaElementwiseOp, matrixStruct); return success(); diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp index 456bfab..d22364e 100644 --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -61,10 +61,10 @@ static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter, IntegerType::get(rewriter.getContext(), converter.getIndexTypeBitwidth()); // TODO: use <=> in C++20. if (indexBitwidth > intWidth) { - return rewriter.create<LLVM::SExtOp>(loc, indexBitwidthType, value); + return LLVM::SExtOp::create(rewriter, loc, indexBitwidthType, value); } if (indexBitwidth < intWidth) { - return rewriter.create<LLVM::TruncOp>(loc, indexBitwidthType, value); + return LLVM::TruncOp::create(rewriter, loc, indexBitwidthType, value); } return value; } @@ -82,12 +82,12 @@ static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) { static Value getLaneId(ConversionPatternRewriter &rewriter, Location loc, const unsigned indexBitwidth) { auto int32Type = IntegerType::get(rewriter.getContext(), 32); - Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, 32); - Value minus1 = rewriter.create<arith::ConstantIntOp>(loc, -1, 32); - Value mbcntLo = rewriter.create<ROCDL::MbcntLoOp>(loc, int32Type, - ValueRange{minus1, zero}); - Value laneId = rewriter.create<ROCDL::MbcntHiOp>(loc, int32Type, - ValueRange{minus1, mbcntLo}); + Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32); + Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32); + Value mbcntLo = ROCDL::MbcntLoOp::create(rewriter, loc, int32Type, + ValueRange{minus1, zero}); + Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, int32Type, + ValueRange{minus1, mbcntLo}); return laneId; } static constexpr StringLiteral amdgcnDataLayout = @@ -110,21 +110,21 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> { // followed by: %lid = call @llvm.amdgcn.mbcnt.hi(-1, %mlo) Type intTy = IntegerType::get(context, 32); - Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, 32); - Value minus1 = rewriter.create<arith::ConstantIntOp>(loc, -1, 32); - Value mbcntLo = - rewriter.create<ROCDL::MbcntLoOp>(loc, intTy, ValueRange{minus1, zero}); - Value laneId = rewriter.create<ROCDL::MbcntHiOp>( - loc, intTy, ValueRange{minus1, mbcntLo}); + Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32); + Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32); + Value mbcntLo = ROCDL::MbcntLoOp::create(rewriter, loc, intTy, + ValueRange{minus1, zero}); + Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, intTy, + ValueRange{minus1, mbcntLo}); // Truncate or extend the result depending on the index bitwidth specified // by the LLVMTypeConverter options. const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth(); if (indexBitwidth > 32) { - laneId = rewriter.create<LLVM::SExtOp>( - loc, IntegerType::get(context, indexBitwidth), laneId); + laneId = LLVM::SExtOp::create( + rewriter, loc, IntegerType::get(context, indexBitwidth), laneId); } else if (indexBitwidth < 32) { - laneId = rewriter.create<LLVM::TruncOp>( - loc, IntegerType::get(context, indexBitwidth), laneId); + laneId = LLVM::TruncOp::create( + rewriter, loc, IntegerType::get(context, indexBitwidth), laneId); } rewriter.replaceOp(op, {laneId}); return success(); @@ -149,8 +149,8 @@ struct GPUSubgroupSizeOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp> { /*bitWidth=*/32, /*lower=*/isBeforeGfx10 ? 64 : 32, /*upper=*/op.getUpperBoundAttr().getInt() + 1); } - Value wavefrontOp = rewriter.create<ROCDL::WavefrontSizeOp>( - op.getLoc(), rewriter.getI32Type(), bounds); + Value wavefrontOp = ROCDL::WavefrontSizeOp::create( + rewriter, op.getLoc(), rewriter.getI32Type(), bounds); wavefrontOp = truncOrExtToLLVMType(rewriter, op.getLoc(), wavefrontOp, *getTypeConverter()); rewriter.replaceOp(op, {wavefrontOp}); @@ -190,44 +190,44 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> { auto int32Type = IntegerType::get(rewriter.getContext(), 32); Value width = adaptor.getWidth(); - Value zero = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 0); - Value negwidth = rewriter.create<LLVM::SubOp>(loc, int32Type, zero, width); - Value add = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId, width); + Value zero = LLVM::ConstantOp::create(rewriter, loc, int32Type, 0); + Value negwidth = LLVM::SubOp::create(rewriter, loc, int32Type, zero, width); + Value add = LLVM::AddOp::create(rewriter, loc, int32Type, srcLaneId, width); Value widthOrZeroIfOutside = - rewriter.create<LLVM::AndOp>(loc, int32Type, add, negwidth); + LLVM::AndOp::create(rewriter, loc, int32Type, add, negwidth); Value dstLane; switch (op.getMode()) { case gpu::ShuffleMode::UP: - dstLane = rewriter.create<LLVM::SubOp>(loc, int32Type, srcLaneId, - adaptor.getOffset()); + dstLane = LLVM::SubOp::create(rewriter, loc, int32Type, srcLaneId, + adaptor.getOffset()); break; case gpu::ShuffleMode::DOWN: - dstLane = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId, - adaptor.getOffset()); + dstLane = LLVM::AddOp::create(rewriter, loc, int32Type, srcLaneId, + adaptor.getOffset()); break; case gpu::ShuffleMode::XOR: - dstLane = rewriter.create<LLVM::XOrOp>(loc, int32Type, srcLaneId, - adaptor.getOffset()); + dstLane = LLVM::XOrOp::create(rewriter, loc, int32Type, srcLaneId, + adaptor.getOffset()); break; case gpu::ShuffleMode::IDX: dstLane = adaptor.getOffset(); break; } - Value isActiveSrcLane = rewriter.create<LLVM::ICmpOp>( - loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside); - Value selectDstLane = rewriter.create<LLVM::SelectOp>(loc, isActiveSrcLane, - dstLane, srcLaneId); - Value two = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 2); + Value isActiveSrcLane = LLVM::ICmpOp::create( + rewriter, loc, LLVM::ICmpPredicate::slt, dstLane, widthOrZeroIfOutside); + Value selectDstLane = LLVM::SelectOp::create(rewriter, loc, isActiveSrcLane, + dstLane, srcLaneId); + Value two = LLVM::ConstantOp::create(rewriter, loc, int32Type, 2); Value dwordAlignedDstLane = - rewriter.create<LLVM::ShlOp>(loc, int32Type, selectDstLane, two); + LLVM::ShlOp::create(rewriter, loc, int32Type, selectDstLane, two); SmallVector<Value> decomposed = LLVM::decomposeValue(rewriter, loc, initShflValue, int32Type); SmallVector<Value> swizzled; for (Value v : decomposed) { - Value res = rewriter.create<ROCDL::DsBpermuteOp>(loc, int32Type, - dwordAlignedDstLane, v); + Value res = ROCDL::DsBpermuteOp::create(rewriter, loc, int32Type, + dwordAlignedDstLane, v); swizzled.emplace_back(res); } Value shflValue = diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index b99ed26..1817861 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -169,11 +169,11 @@ LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite( Value vector = spirv::getBuiltinVariableValue(op, builtin, builtinType, rewriter); - Value dim = rewriter.create<spirv::CompositeExtractOp>( - op.getLoc(), builtinType, vector, + Value dim = spirv::CompositeExtractOp::create( + rewriter, op.getLoc(), builtinType, vector, rewriter.getI32ArrayAttr({static_cast<int32_t>(op.getDimension())})); if (forShader && builtinType != indexType) - dim = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType, dim); + dim = spirv::UConvertOp::create(rewriter, op.getLoc(), indexType, dim); rewriter.replaceOp(op, dim); return success(); } @@ -198,8 +198,8 @@ SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite( Value builtinValue = spirv::getBuiltinVariableValue(op, builtin, i32Type, rewriter); if (i32Type != indexType) - builtinValue = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType, - builtinValue); + builtinValue = spirv::UConvertOp::create(rewriter, op.getLoc(), indexType, + builtinValue); rewriter.replaceOp(op, builtinValue); return success(); } @@ -257,8 +257,8 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter, signatureConverter.addInputs(argType.index(), convertedType); } } - auto newFuncOp = rewriter.create<spirv::FuncOp>( - funcOp.getLoc(), funcOp.getName(), + auto newFuncOp = spirv::FuncOp::create( + rewriter, funcOp.getLoc(), funcOp.getName(), rewriter.getFunctionType(signatureConverter.getConvertedTypes(), {})); for (const auto &namedAttr : funcOp->getAttrs()) { if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() || @@ -367,8 +367,8 @@ LogicalResult GPUModuleConversion::matchAndRewrite( // Add a keyword to the module name to avoid symbolic conflict. std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str(); - auto spvModule = rewriter.create<spirv::ModuleOp>( - moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt, + auto spvModule = spirv::ModuleOp::create( + rewriter, moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt, StringRef(spvModuleName)); // Move the region from the module op into the SPIR-V module. @@ -452,42 +452,42 @@ LogicalResult GPUShuffleConversion::matchAndRewrite( switch (shuffleOp.getMode()) { case gpu::ShuffleMode::XOR: { - result = rewriter.create<spirv::GroupNonUniformShuffleXorOp>( - loc, scope, adaptor.getValue(), adaptor.getOffset()); + result = spirv::GroupNonUniformShuffleXorOp::create( + rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset()); validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), shuffleOp.getLoc(), rewriter); break; } case gpu::ShuffleMode::IDX: { - result = rewriter.create<spirv::GroupNonUniformShuffleOp>( - loc, scope, adaptor.getValue(), adaptor.getOffset()); + result = spirv::GroupNonUniformShuffleOp::create( + rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset()); validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), shuffleOp.getLoc(), rewriter); break; } case gpu::ShuffleMode::DOWN: { - result = rewriter.create<spirv::GroupNonUniformShuffleDownOp>( - loc, scope, adaptor.getValue(), adaptor.getOffset()); + result = spirv::GroupNonUniformShuffleDownOp::create( + rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset()); - Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr); + Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr); Value resultLaneId = - rewriter.create<arith::AddIOp>(loc, laneId, adaptor.getOffset()); - validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, - resultLaneId, adaptor.getWidth()); + arith::AddIOp::create(rewriter, loc, laneId, adaptor.getOffset()); + validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult, + resultLaneId, adaptor.getWidth()); break; } case gpu::ShuffleMode::UP: { - result = rewriter.create<spirv::GroupNonUniformShuffleUpOp>( - loc, scope, adaptor.getValue(), adaptor.getOffset()); + result = spirv::GroupNonUniformShuffleUpOp::create( + rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset()); - Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr); + Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr); Value resultLaneId = - rewriter.create<arith::SubIOp>(loc, laneId, adaptor.getOffset()); + arith::SubIOp::create(rewriter, loc, laneId, adaptor.getOffset()); auto i32Type = rewriter.getIntegerType(32); - validVal = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::sge, resultLaneId, - rewriter.create<arith::ConstantOp>( - loc, i32Type, rewriter.getIntegerAttr(i32Type, 0))); + validVal = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sge, resultLaneId, + arith::ConstantOp::create(rewriter, loc, i32Type, + rewriter.getIntegerAttr(i32Type, 0))); break; } } @@ -507,24 +507,27 @@ LogicalResult GPURotateConversion::matchAndRewrite( getTypeConverter<SPIRVTypeConverter>()->getTargetEnv(); unsigned subgroupSize = targetEnv.getAttr().getResourceLimits().getSubgroupSize(); - IntegerAttr widthAttr; - if (!matchPattern(rotateOp.getWidth(), m_Constant(&widthAttr)) || - widthAttr.getValue().getZExtValue() > subgroupSize) + unsigned width = rotateOp.getWidth(); + if (width > subgroupSize) return rewriter.notifyMatchFailure( - rotateOp, - "rotate width is not a constant or larger than target subgroup size"); + rotateOp, "rotate width is larger than target subgroup size"); Location loc = rotateOp.getLoc(); auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup); - Value rotateResult = rewriter.create<spirv::GroupNonUniformRotateKHROp>( - loc, scope, adaptor.getValue(), adaptor.getOffset(), adaptor.getWidth()); + Value offsetVal = + arith::ConstantOp::create(rewriter, loc, adaptor.getOffsetAttr()); + Value widthVal = + arith::ConstantOp::create(rewriter, loc, adaptor.getWidthAttr()); + Value rotateResult = spirv::GroupNonUniformRotateKHROp::create( + rewriter, loc, scope, adaptor.getValue(), offsetVal, widthVal); Value validVal; - if (widthAttr.getValue().getZExtValue() == subgroupSize) { + if (width == subgroupSize) { validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter); } else { - Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr); - validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, - laneId, adaptor.getWidth()); + IntegerAttr widthAttr = adaptor.getWidthAttr(); + Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr); + validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult, + laneId, widthVal); } rewriter.replaceOp(rotateOp, {rotateResult, validVal}); @@ -548,18 +551,18 @@ static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc, ? spirv::GroupOperation::ClusteredReduce : spirv::GroupOperation::Reduce); if (isUniform) { - return builder.create<UniformOp>(loc, type, scope, groupOp, arg) + return UniformOp::create(builder, loc, type, scope, groupOp, arg) .getResult(); } Value clusterSizeValue; if (clusterSize.has_value()) - clusterSizeValue = builder.create<spirv::ConstantOp>( - loc, builder.getI32Type(), + clusterSizeValue = spirv::ConstantOp::create( + builder, loc, builder.getI32Type(), builder.getIntegerAttr(builder.getI32Type(), *clusterSize)); - return builder - .create<NonUniformOp>(loc, type, scope, groupOp, arg, clusterSizeValue) + return NonUniformOp::create(builder, loc, type, scope, groupOp, arg, + clusterSizeValue) .getResult(); } @@ -740,8 +743,8 @@ LogicalResult GPUPrintfConversion::matchAndRewrite( std::string specCstName = makeVarName(moduleOp, llvm::Twine(globalVarName) + "_sc"); - return rewriter.create<spirv::SpecConstantOp>( - loc, rewriter.getStringAttr(specCstName), attr); + return spirv::SpecConstantOp::create( + rewriter, loc, rewriter.getStringAttr(specCstName), attr); }; { Operation *parent = @@ -774,8 +777,8 @@ LogicalResult GPUPrintfConversion::matchAndRewrite( std::string specCstCompositeName = (llvm::Twine(globalVarName) + "_scc").str(); - specCstComposite = rewriter.create<spirv::SpecConstantCompositeOp>( - loc, TypeAttr::get(globalType), + specCstComposite = spirv::SpecConstantCompositeOp::create( + rewriter, loc, TypeAttr::get(globalType), rewriter.getStringAttr(specCstCompositeName), rewriter.getArrayAttr(constituents)); @@ -785,23 +788,24 @@ LogicalResult GPUPrintfConversion::matchAndRewrite( // Define a GlobalVarOp initialized using specialized constants // that is used to specify the printf format string // to be passed to the SPIRV CLPrintfOp. - globalVar = rewriter.create<spirv::GlobalVariableOp>( - loc, ptrType, globalVarName, FlatSymbolRefAttr::get(specCstComposite)); + globalVar = spirv::GlobalVariableOp::create( + rewriter, loc, ptrType, globalVarName, + FlatSymbolRefAttr::get(specCstComposite)); globalVar->setAttr("Constant", rewriter.getUnitAttr()); } // Get SSA value of Global variable and create pointer to i8 to point to // the format string. - Value globalPtr = rewriter.create<spirv::AddressOfOp>(loc, globalVar); - Value fmtStr = rewriter.create<spirv::BitcastOp>( - loc, + Value globalPtr = spirv::AddressOfOp::create(rewriter, loc, globalVar); + Value fmtStr = spirv::BitcastOp::create( + rewriter, loc, spirv::PointerType::get(i8Type, spirv::StorageClass::UniformConstant), globalPtr); // Get printf arguments. auto printfArgs = llvm::to_vector_of<Value, 4>(adaptor.getArgs()); - rewriter.create<spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs); + spirv::CLPrintfOp::create(rewriter, loc, i32Type, fmtStr, printfArgs); // Need to erase the gpu.printf op as gpu.printf does not use result vs // spirv::CLPrintfOp has i32 resultType so cannot replace with new SPIR-V diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp index 0b2c06a..a344f88 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp @@ -144,11 +144,12 @@ void GPUToSPIRVPass::runOnOperation() { if (targetEnvSupportsKernelCapability(moduleOp)) { moduleOp.walk([&](gpu::GPUFuncOp funcOp) { builder.setInsertionPoint(funcOp); - auto newFuncOp = builder.create<func::FuncOp>( - funcOp.getLoc(), funcOp.getName(), funcOp.getFunctionType()); + auto newFuncOp = + func::FuncOp::create(builder, funcOp.getLoc(), funcOp.getName(), + funcOp.getFunctionType()); auto entryBlock = newFuncOp.addEntryBlock(); builder.setInsertionPointToEnd(entryBlock); - builder.create<func::ReturnOp>(funcOp.getLoc()); + func::ReturnOp::create(builder, funcOp.getLoc()); newFuncOp->setAttr(gpu::GPUDialect::getKernelFuncAttrName(), builder.getUnitAttr()); funcOp.erase(); diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp index 7bb86b5..51dc500 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp @@ -283,8 +283,8 @@ struct WmmaLoadOpToSPIRVLowering final int64_t stride = op.getLeadDimension().getSExtValue(); IntegerType i32Type = rewriter.getI32Type(); - auto strideValue = rewriter.create<spirv::ConstantOp>( - loc, i32Type, IntegerAttr::get(i32Type, stride)); + auto strideValue = spirv::ConstantOp::create( + rewriter, loc, i32Type, IntegerAttr::get(i32Type, stride)); bool isColMajor = op.getTranspose().value_or(false); auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor @@ -315,8 +315,8 @@ struct WmmaStoreOpToSPIRVLowering final int64_t stride = op.getLeadDimension().getSExtValue(); IntegerType i32Type = rewriter.getI32Type(); - auto strideValue = rewriter.create<spirv::ConstantOp>( - loc, i32Type, IntegerAttr::get(i32Type, stride)); + auto strideValue = spirv::ConstantOp::create( + rewriter, loc, i32Type, IntegerAttr::get(i32Type, stride)); bool isColMajor = op.getTranspose().value_or(false); auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor diff --git a/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp b/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp index 0473bb5..99d2f6c 100644 --- a/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp +++ b/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp @@ -36,34 +36,34 @@ struct ConvertIndexCeilDivS : mlir::ConvertOpToLLVMPattern<CeilDivSOp> { Location loc = op.getLoc(); Value n = adaptor.getLhs(); Value m = adaptor.getRhs(); - Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0); - Value posOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1); - Value negOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), -1); + Value zero = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 0); + Value posOne = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 1); + Value negOne = LLVM::ConstantOp::create(rewriter, loc, n.getType(), -1); // Compute `x`. Value mPos = - rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, m, zero); - Value x = rewriter.create<LLVM::SelectOp>(loc, mPos, negOne, posOne); + LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::sgt, m, zero); + Value x = LLVM::SelectOp::create(rewriter, loc, mPos, negOne, posOne); // Compute the positive result. - Value nPlusX = rewriter.create<LLVM::AddOp>(loc, n, x); - Value nPlusXDivM = rewriter.create<LLVM::SDivOp>(loc, nPlusX, m); - Value posRes = rewriter.create<LLVM::AddOp>(loc, nPlusXDivM, posOne); + Value nPlusX = LLVM::AddOp::create(rewriter, loc, n, x); + Value nPlusXDivM = LLVM::SDivOp::create(rewriter, loc, nPlusX, m); + Value posRes = LLVM::AddOp::create(rewriter, loc, nPlusXDivM, posOne); // Compute the negative result. - Value negN = rewriter.create<LLVM::SubOp>(loc, zero, n); - Value negNDivM = rewriter.create<LLVM::SDivOp>(loc, negN, m); - Value negRes = rewriter.create<LLVM::SubOp>(loc, zero, negNDivM); + Value negN = LLVM::SubOp::create(rewriter, loc, zero, n); + Value negNDivM = LLVM::SDivOp::create(rewriter, loc, negN, m); + Value negRes = LLVM::SubOp::create(rewriter, loc, zero, negNDivM); // Pick the positive result if `n` and `m` have the same sign and `n` is // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`. Value nPos = - rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, n, zero); - Value sameSign = - rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, nPos, mPos); + LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::sgt, n, zero); + Value sameSign = LLVM::ICmpOp::create(rewriter, loc, + LLVM::ICmpPredicate::eq, nPos, mPos); Value nNonZero = - rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, n, zero); - Value cmp = rewriter.create<LLVM::AndOp>(loc, sameSign, nNonZero); + LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::ne, n, zero); + Value cmp = LLVM::AndOp::create(rewriter, loc, sameSign, nNonZero); rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, posRes, negRes); return success(); } @@ -83,17 +83,17 @@ struct ConvertIndexCeilDivU : mlir::ConvertOpToLLVMPattern<CeilDivUOp> { Location loc = op.getLoc(); Value n = adaptor.getLhs(); Value m = adaptor.getRhs(); - Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0); - Value one = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1); + Value zero = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 0); + Value one = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 1); // Compute the non-zero result. - Value minusOne = rewriter.create<LLVM::SubOp>(loc, n, one); - Value quotient = rewriter.create<LLVM::UDivOp>(loc, minusOne, m); - Value plusOne = rewriter.create<LLVM::AddOp>(loc, quotient, one); + Value minusOne = LLVM::SubOp::create(rewriter, loc, n, one); + Value quotient = LLVM::UDivOp::create(rewriter, loc, minusOne, m); + Value plusOne = LLVM::AddOp::create(rewriter, loc, quotient, one); // Pick the result. Value cmp = - rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, n, zero); + LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::eq, n, zero); rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, zero, plusOne); return success(); } @@ -114,32 +114,32 @@ struct ConvertIndexFloorDivS : mlir::ConvertOpToLLVMPattern<FloorDivSOp> { Location loc = op.getLoc(); Value n = adaptor.getLhs(); Value m = adaptor.getRhs(); - Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0); - Value posOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1); - Value negOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), -1); + Value zero = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 0); + Value posOne = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 1); + Value negOne = LLVM::ConstantOp::create(rewriter, loc, n.getType(), -1); // Compute `x`. Value mNeg = - rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, m, zero); - Value x = rewriter.create<LLVM::SelectOp>(loc, mNeg, posOne, negOne); + LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::slt, m, zero); + Value x = LLVM::SelectOp::create(rewriter, loc, mNeg, posOne, negOne); // Compute the negative result. - Value xMinusN = rewriter.create<LLVM::SubOp>(loc, x, n); - Value xMinusNDivM = rewriter.create<LLVM::SDivOp>(loc, xMinusN, m); - Value negRes = rewriter.create<LLVM::SubOp>(loc, negOne, xMinusNDivM); + Value xMinusN = LLVM::SubOp::create(rewriter, loc, x, n); + Value xMinusNDivM = LLVM::SDivOp::create(rewriter, loc, xMinusN, m); + Value negRes = LLVM::SubOp::create(rewriter, loc, negOne, xMinusNDivM); // Compute the positive result. - Value posRes = rewriter.create<LLVM::SDivOp>(loc, n, m); + Value posRes = LLVM::SDivOp::create(rewriter, loc, n, m); // Pick the negative result if `n` and `m` have different signs and `n` is // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`. Value nNeg = - rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, n, zero); - Value diffSign = - rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, nNeg, mNeg); + LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::slt, n, zero); + Value diffSign = LLVM::ICmpOp::create(rewriter, loc, + LLVM::ICmpPredicate::ne, nNeg, mNeg); Value nNonZero = - rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, n, zero); - Value cmp = rewriter.create<LLVM::AndOp>(loc, diffSign, nNonZero); + LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::ne, n, zero); + Value cmp = LLVM::AndOp::create(rewriter, loc, diffSign, nNonZero); rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, negRes, posRes); return success(); } diff --git a/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp index 4821962..36cfe9d 100644 --- a/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp +++ b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp @@ -111,33 +111,33 @@ struct ConvertIndexCeilDivSPattern final : OpConversionPattern<CeilDivSOp> { Value m = adaptor.getRhs(); // Define the constants - Value zero = rewriter.create<spirv::ConstantOp>( - loc, n_type, IntegerAttr::get(n_type, 0)); - Value posOne = rewriter.create<spirv::ConstantOp>( - loc, n_type, IntegerAttr::get(n_type, 1)); - Value negOne = rewriter.create<spirv::ConstantOp>( - loc, n_type, IntegerAttr::get(n_type, -1)); + Value zero = spirv::ConstantOp::create(rewriter, loc, n_type, + IntegerAttr::get(n_type, 0)); + Value posOne = spirv::ConstantOp::create(rewriter, loc, n_type, + IntegerAttr::get(n_type, 1)); + Value negOne = spirv::ConstantOp::create(rewriter, loc, n_type, + IntegerAttr::get(n_type, -1)); // Compute `x`. - Value mPos = rewriter.create<spirv::SGreaterThanOp>(loc, m, zero); - Value x = rewriter.create<spirv::SelectOp>(loc, mPos, negOne, posOne); + Value mPos = spirv::SGreaterThanOp::create(rewriter, loc, m, zero); + Value x = spirv::SelectOp::create(rewriter, loc, mPos, negOne, posOne); // Compute the positive result. - Value nPlusX = rewriter.create<spirv::IAddOp>(loc, n, x); - Value nPlusXDivM = rewriter.create<spirv::SDivOp>(loc, nPlusX, m); - Value posRes = rewriter.create<spirv::IAddOp>(loc, nPlusXDivM, posOne); + Value nPlusX = spirv::IAddOp::create(rewriter, loc, n, x); + Value nPlusXDivM = spirv::SDivOp::create(rewriter, loc, nPlusX, m); + Value posRes = spirv::IAddOp::create(rewriter, loc, nPlusXDivM, posOne); // Compute the negative result. - Value negN = rewriter.create<spirv::ISubOp>(loc, zero, n); - Value negNDivM = rewriter.create<spirv::SDivOp>(loc, negN, m); - Value negRes = rewriter.create<spirv::ISubOp>(loc, zero, negNDivM); + Value negN = spirv::ISubOp::create(rewriter, loc, zero, n); + Value negNDivM = spirv::SDivOp::create(rewriter, loc, negN, m); + Value negRes = spirv::ISubOp::create(rewriter, loc, zero, negNDivM); // Pick the positive result if `n` and `m` have the same sign and `n` is // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`. - Value nPos = rewriter.create<spirv::SGreaterThanOp>(loc, n, zero); - Value sameSign = rewriter.create<spirv::LogicalEqualOp>(loc, nPos, mPos); - Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero); - Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, sameSign, nNonZero); + Value nPos = spirv::SGreaterThanOp::create(rewriter, loc, n, zero); + Value sameSign = spirv::LogicalEqualOp::create(rewriter, loc, nPos, mPos); + Value nNonZero = spirv::INotEqualOp::create(rewriter, loc, n, zero); + Value cmp = spirv::LogicalAndOp::create(rewriter, loc, sameSign, nNonZero); rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes); return success(); } @@ -161,18 +161,18 @@ struct ConvertIndexCeilDivUPattern final : OpConversionPattern<CeilDivUOp> { Value m = adaptor.getRhs(); // Define the constants - Value zero = rewriter.create<spirv::ConstantOp>( - loc, n_type, IntegerAttr::get(n_type, 0)); - Value one = rewriter.create<spirv::ConstantOp>(loc, n_type, - IntegerAttr::get(n_type, 1)); + Value zero = spirv::ConstantOp::create(rewriter, loc, n_type, + IntegerAttr::get(n_type, 0)); + Value one = spirv::ConstantOp::create(rewriter, loc, n_type, + IntegerAttr::get(n_type, 1)); // Compute the non-zero result. - Value minusOne = rewriter.create<spirv::ISubOp>(loc, n, one); - Value quotient = rewriter.create<spirv::UDivOp>(loc, minusOne, m); - Value plusOne = rewriter.create<spirv::IAddOp>(loc, quotient, one); + Value minusOne = spirv::ISubOp::create(rewriter, loc, n, one); + Value quotient = spirv::UDivOp::create(rewriter, loc, minusOne, m); + Value plusOne = spirv::IAddOp::create(rewriter, loc, quotient, one); // Pick the result - Value cmp = rewriter.create<spirv::IEqualOp>(loc, n, zero); + Value cmp = spirv::IEqualOp::create(rewriter, loc, n, zero); rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, zero, plusOne); return success(); } @@ -197,32 +197,33 @@ struct ConvertIndexFloorDivSPattern final : OpConversionPattern<FloorDivSOp> { Value m = adaptor.getRhs(); // Define the constants - Value zero = rewriter.create<spirv::ConstantOp>( - loc, n_type, IntegerAttr::get(n_type, 0)); - Value posOne = rewriter.create<spirv::ConstantOp>( - loc, n_type, IntegerAttr::get(n_type, 1)); - Value negOne = rewriter.create<spirv::ConstantOp>( - loc, n_type, IntegerAttr::get(n_type, -1)); + Value zero = spirv::ConstantOp::create(rewriter, loc, n_type, + IntegerAttr::get(n_type, 0)); + Value posOne = spirv::ConstantOp::create(rewriter, loc, n_type, + IntegerAttr::get(n_type, 1)); + Value negOne = spirv::ConstantOp::create(rewriter, loc, n_type, + IntegerAttr::get(n_type, -1)); // Compute `x`. - Value mNeg = rewriter.create<spirv::SLessThanOp>(loc, m, zero); - Value x = rewriter.create<spirv::SelectOp>(loc, mNeg, posOne, negOne); + Value mNeg = spirv::SLessThanOp::create(rewriter, loc, m, zero); + Value x = spirv::SelectOp::create(rewriter, loc, mNeg, posOne, negOne); // Compute the negative result - Value xMinusN = rewriter.create<spirv::ISubOp>(loc, x, n); - Value xMinusNDivM = rewriter.create<spirv::SDivOp>(loc, xMinusN, m); - Value negRes = rewriter.create<spirv::ISubOp>(loc, negOne, xMinusNDivM); + Value xMinusN = spirv::ISubOp::create(rewriter, loc, x, n); + Value xMinusNDivM = spirv::SDivOp::create(rewriter, loc, xMinusN, m); + Value negRes = spirv::ISubOp::create(rewriter, loc, negOne, xMinusNDivM); // Compute the positive result. - Value posRes = rewriter.create<spirv::SDivOp>(loc, n, m); + Value posRes = spirv::SDivOp::create(rewriter, loc, n, m); // Pick the negative result if `n` and `m` have different signs and `n` is // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`. - Value nNeg = rewriter.create<spirv::SLessThanOp>(loc, n, zero); - Value diffSign = rewriter.create<spirv::LogicalNotEqualOp>(loc, nNeg, mNeg); - Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero); + Value nNeg = spirv::SLessThanOp::create(rewriter, loc, n, zero); + Value diffSign = + spirv::LogicalNotEqualOp::create(rewriter, loc, nNeg, mNeg); + Value nNonZero = spirv::INotEqualOp::create(rewriter, loc, n, zero); - Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, diffSign, nNonZero); + Value cmp = spirv::LogicalAndOp::create(rewriter, loc, diffSign, nNonZero); rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes); return success(); } diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp index e34d5f7..fce7a3f 100644 --- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp +++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp @@ -32,7 +32,7 @@ MemRefDescriptor::MemRefDescriptor(Value descriptor) MemRefDescriptor MemRefDescriptor::poison(OpBuilder &builder, Location loc, Type descriptorType) { - Value descriptor = builder.create<LLVM::PoisonOp>(loc, descriptorType); + Value descriptor = LLVM::PoisonOp::create(builder, loc, descriptorType); return MemRefDescriptor(descriptor); } @@ -99,21 +99,21 @@ void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, // integer attribute. static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value) { - return builder.create<LLVM::ConstantOp>(loc, resultType, - builder.getIndexAttr(value)); + return LLVM::ConstantOp::create(builder, loc, resultType, + builder.getIndexAttr(value)); } /// Builds IR extracting the offset from the descriptor. Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) { - return builder.create<LLVM::ExtractValueOp>(loc, value, - kOffsetPosInMemRefDescriptor); + return LLVM::ExtractValueOp::create(builder, loc, value, + kOffsetPosInMemRefDescriptor); } /// Builds IR inserting the offset into the descriptor. void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc, Value offset) { - value = builder.create<LLVM::InsertValueOp>(loc, value, offset, - kOffsetPosInMemRefDescriptor); + value = LLVM::InsertValueOp::create(builder, loc, value, offset, + kOffsetPosInMemRefDescriptor); } /// Builds IR inserting the offset into the descriptor. @@ -125,8 +125,9 @@ void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc, /// Builds IR extracting the pos-th size from the descriptor. Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) { - return builder.create<LLVM::ExtractValueOp>( - loc, value, ArrayRef<int64_t>({kSizePosInMemRefDescriptor, pos})); + return LLVM::ExtractValueOp::create( + builder, loc, value, + ArrayRef<int64_t>({kSizePosInMemRefDescriptor, pos})); } Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos, @@ -137,23 +138,25 @@ Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos, // Copy size values to stack-allocated memory. auto one = createIndexAttrConstant(builder, loc, indexType, 1); - auto sizes = builder.create<LLVM::ExtractValueOp>( - loc, value, llvm::ArrayRef<int64_t>({kSizePosInMemRefDescriptor})); - auto sizesPtr = builder.create<LLVM::AllocaOp>(loc, ptrTy, arrayTy, one, - /*alignment=*/0); - builder.create<LLVM::StoreOp>(loc, sizes, sizesPtr); + auto sizes = LLVM::ExtractValueOp::create( + builder, loc, value, + llvm::ArrayRef<int64_t>({kSizePosInMemRefDescriptor})); + auto sizesPtr = LLVM::AllocaOp::create(builder, loc, ptrTy, arrayTy, one, + /*alignment=*/0); + LLVM::StoreOp::create(builder, loc, sizes, sizesPtr); // Load an return size value of interest. - auto resultPtr = builder.create<LLVM::GEPOp>(loc, ptrTy, arrayTy, sizesPtr, - ArrayRef<LLVM::GEPArg>{0, pos}); - return builder.create<LLVM::LoadOp>(loc, indexType, resultPtr); + auto resultPtr = LLVM::GEPOp::create(builder, loc, ptrTy, arrayTy, sizesPtr, + ArrayRef<LLVM::GEPArg>{0, pos}); + return LLVM::LoadOp::create(builder, loc, indexType, resultPtr); } /// Builds IR inserting the pos-th size into the descriptor void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos, Value size) { - value = builder.create<LLVM::InsertValueOp>( - loc, value, size, ArrayRef<int64_t>({kSizePosInMemRefDescriptor, pos})); + value = LLVM::InsertValueOp::create( + builder, loc, value, size, + ArrayRef<int64_t>({kSizePosInMemRefDescriptor, pos})); } void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc, @@ -164,15 +167,16 @@ void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc, /// Builds IR extracting the pos-th stride from the descriptor. Value MemRefDescriptor::stride(OpBuilder &builder, Location loc, unsigned pos) { - return builder.create<LLVM::ExtractValueOp>( - loc, value, ArrayRef<int64_t>({kStridePosInMemRefDescriptor, pos})); + return LLVM::ExtractValueOp::create( + builder, loc, value, + ArrayRef<int64_t>({kStridePosInMemRefDescriptor, pos})); } /// Builds IR inserting the pos-th stride into the descriptor void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos, Value stride) { - value = builder.create<LLVM::InsertValueOp>( - loc, value, stride, + value = LLVM::InsertValueOp::create( + builder, loc, value, stride, ArrayRef<int64_t>({kStridePosInMemRefDescriptor, pos})); } @@ -207,8 +211,8 @@ Value MemRefDescriptor::bufferPtr(OpBuilder &builder, Location loc, ? offset(builder, loc) : createIndexAttrConstant(builder, loc, indexType, offsetCst); Type elementType = converter.convertType(type.getElementType()); - ptr = builder.create<LLVM::GEPOp>(loc, ptr.getType(), elementType, ptr, - offsetVal); + ptr = LLVM::GEPOp::create(builder, loc, ptr.getType(), elementType, ptr, + offsetVal); return ptr; } @@ -303,7 +307,7 @@ UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor) UnrankedMemRefDescriptor UnrankedMemRefDescriptor::poison(OpBuilder &builder, Location loc, Type descriptorType) { - Value descriptor = builder.create<LLVM::PoisonOp>(loc, descriptorType); + Value descriptor = LLVM::PoisonOp::create(builder, loc, descriptorType); return UnrankedMemRefDescriptor(descriptor); } Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) const { @@ -380,19 +384,19 @@ void UnrankedMemRefDescriptor::computeSizes( builder, loc, indexType, llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8)); Value doublePointerSize = - builder.create<LLVM::MulOp>(loc, indexType, two, pointerSize); + LLVM::MulOp::create(builder, loc, indexType, two, pointerSize); // (1 + 2 * rank) * sizeof(index) Value rank = desc.rank(builder, loc); - Value doubleRank = builder.create<LLVM::MulOp>(loc, indexType, two, rank); + Value doubleRank = LLVM::MulOp::create(builder, loc, indexType, two, rank); Value doubleRankIncremented = - builder.create<LLVM::AddOp>(loc, indexType, doubleRank, one); - Value rankIndexSize = builder.create<LLVM::MulOp>( - loc, indexType, doubleRankIncremented, indexSize); + LLVM::AddOp::create(builder, loc, indexType, doubleRank, one); + Value rankIndexSize = LLVM::MulOp::create(builder, loc, indexType, + doubleRankIncremented, indexSize); // Total allocation size. - Value allocationSize = builder.create<LLVM::AddOp>( - loc, indexType, doublePointerSize, rankIndexSize); + Value allocationSize = LLVM::AddOp::create( + builder, loc, indexType, doublePointerSize, rankIndexSize); sizes.push_back(allocationSize); } } @@ -400,13 +404,13 @@ void UnrankedMemRefDescriptor::computeSizes( Value UnrankedMemRefDescriptor::allocatedPtr( OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType) { - return builder.create<LLVM::LoadOp>(loc, elemPtrType, memRefDescPtr); + return LLVM::LoadOp::create(builder, loc, elemPtrType, memRefDescPtr); } void UnrankedMemRefDescriptor::setAllocatedPtr( OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrType, Value allocatedPtr) { - builder.create<LLVM::StoreOp>(loc, allocatedPtr, memRefDescPtr); + LLVM::StoreOp::create(builder, loc, allocatedPtr, memRefDescPtr); } static std::pair<Value, Type> @@ -423,9 +427,9 @@ Value UnrankedMemRefDescriptor::alignedPtr( castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType); Value alignedGep = - builder.create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType, - elementPtrPtr, ArrayRef<LLVM::GEPArg>{1}); - return builder.create<LLVM::LoadOp>(loc, elemPtrType, alignedGep); + LLVM::GEPOp::create(builder, loc, elemPtrPtrType, elemPtrType, + elementPtrPtr, ArrayRef<LLVM::GEPArg>{1}); + return LLVM::LoadOp::create(builder, loc, elemPtrType, alignedGep); } void UnrankedMemRefDescriptor::setAlignedPtr( @@ -435,9 +439,9 @@ void UnrankedMemRefDescriptor::setAlignedPtr( castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType); Value alignedGep = - builder.create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType, - elementPtrPtr, ArrayRef<LLVM::GEPArg>{1}); - builder.create<LLVM::StoreOp>(loc, alignedPtr, alignedGep); + LLVM::GEPOp::create(builder, loc, elemPtrPtrType, elemPtrType, + elementPtrPtr, ArrayRef<LLVM::GEPArg>{1}); + LLVM::StoreOp::create(builder, loc, alignedPtr, alignedGep); } Value UnrankedMemRefDescriptor::offsetBasePtr( @@ -446,8 +450,8 @@ Value UnrankedMemRefDescriptor::offsetBasePtr( auto [elementPtrPtr, elemPtrPtrType] = castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType); - return builder.create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType, - elementPtrPtr, ArrayRef<LLVM::GEPArg>{2}); + return LLVM::GEPOp::create(builder, loc, elemPtrPtrType, elemPtrType, + elementPtrPtr, ArrayRef<LLVM::GEPArg>{2}); } Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc, @@ -456,8 +460,8 @@ Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc, LLVM::LLVMPointerType elemPtrType) { Value offsetPtr = offsetBasePtr(builder, loc, typeConverter, memRefDescPtr, elemPtrType); - return builder.create<LLVM::LoadOp>(loc, typeConverter.getIndexType(), - offsetPtr); + return LLVM::LoadOp::create(builder, loc, typeConverter.getIndexType(), + offsetPtr); } void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc, @@ -467,7 +471,7 @@ void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc, Value offset) { Value offsetPtr = offsetBasePtr(builder, loc, typeConverter, memRefDescPtr, elemPtrType); - builder.create<LLVM::StoreOp>(loc, offset, offsetPtr); + LLVM::StoreOp::create(builder, loc, offset, offsetPtr); } Value UnrankedMemRefDescriptor::sizeBasePtr( @@ -477,8 +481,8 @@ Value UnrankedMemRefDescriptor::sizeBasePtr( Type structTy = LLVM::LLVMStructType::getLiteral( indexTy.getContext(), {elemPtrType, elemPtrType, indexTy, indexTy}); auto resultType = LLVM::LLVMPointerType::get(builder.getContext()); - return builder.create<LLVM::GEPOp>(loc, resultType, structTy, memRefDescPtr, - ArrayRef<LLVM::GEPArg>{0, 3}); + return LLVM::GEPOp::create(builder, loc, resultType, structTy, memRefDescPtr, + ArrayRef<LLVM::GEPArg>{0, 3}); } Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc, @@ -489,8 +493,8 @@ Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc, auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); Value sizeStoreGep = - builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, sizeBasePtr, index); - return builder.create<LLVM::LoadOp>(loc, indexTy, sizeStoreGep); + LLVM::GEPOp::create(builder, loc, ptrType, indexTy, sizeBasePtr, index); + return LLVM::LoadOp::create(builder, loc, indexTy, sizeStoreGep); } void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc, @@ -501,8 +505,8 @@ void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc, auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); Value sizeStoreGep = - builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, sizeBasePtr, index); - builder.create<LLVM::StoreOp>(loc, size, sizeStoreGep); + LLVM::GEPOp::create(builder, loc, ptrType, indexTy, sizeBasePtr, index); + LLVM::StoreOp::create(builder, loc, size, sizeStoreGep); } Value UnrankedMemRefDescriptor::strideBasePtr( @@ -511,7 +515,7 @@ Value UnrankedMemRefDescriptor::strideBasePtr( Type indexTy = typeConverter.getIndexType(); auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); - return builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, sizeBasePtr, rank); + return LLVM::GEPOp::create(builder, loc, ptrType, indexTy, sizeBasePtr, rank); } Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc, @@ -522,8 +526,8 @@ Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc, auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); Value strideStoreGep = - builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, strideBasePtr, index); - return builder.create<LLVM::LoadOp>(loc, indexTy, strideStoreGep); + LLVM::GEPOp::create(builder, loc, ptrType, indexTy, strideBasePtr, index); + return LLVM::LoadOp::create(builder, loc, indexTy, strideStoreGep); } void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc, @@ -534,6 +538,6 @@ void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc, auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); Value strideStoreGep = - builder.create<LLVM::GEPOp>(loc, ptrType, indexTy, strideBasePtr, index); - builder.create<LLVM::StoreOp>(loc, stride, strideStoreGep); + LLVM::GEPOp::create(builder, loc, ptrType, indexTy, strideBasePtr, index); + LLVM::StoreOp::create(builder, loc, stride, strideStoreGep); } diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index c5f72f7..2568044 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -57,8 +57,8 @@ Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value) { - return builder.create<LLVM::ConstantOp>(loc, resultType, - builder.getIndexAttr(value)); + return LLVM::ConstantOp::create(builder, loc, resultType, + builder.getIndexAttr(value)); } Value ConvertToLLVMPattern::getStridedElementPtr( @@ -123,7 +123,7 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes( runningStride = sizes[i]; else if (stride == ShapedType::kDynamic) runningStride = - rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]); + LLVM::MulOp::create(rewriter, loc, runningStride, sizes[i]); else runningStride = createIndexAttrConstant(rewriter, loc, indexType, stride); } @@ -131,10 +131,10 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes( // Buffer size in bytes. Type elementType = typeConverter->convertType(memRefType.getElementType()); auto elementPtrType = LLVM::LLVMPointerType::get(rewriter.getContext()); - Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType); - Value gepPtr = rewriter.create<LLVM::GEPOp>( - loc, elementPtrType, elementType, nullPtr, runningStride); - size = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr); + Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType); + Value gepPtr = LLVM::GEPOp::create(rewriter, loc, elementPtrType, + elementType, nullPtr, runningStride); + size = LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gepPtr); } else { size = runningStride; } @@ -149,10 +149,10 @@ Value ConvertToLLVMPattern::getSizeInBytes( // which is a common pattern of getting the size of a type in bytes. Type llvmType = typeConverter->convertType(type); auto convertedPtrType = LLVM::LLVMPointerType::get(rewriter.getContext()); - auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, convertedPtrType); - auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType, llvmType, - nullPtr, ArrayRef<LLVM::GEPArg>{1}); - return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep); + auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, convertedPtrType); + auto gep = LLVM::GEPOp::create(rewriter, loc, convertedPtrType, llvmType, + nullPtr, ArrayRef<LLVM::GEPArg>{1}); + return LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gep); } Value ConvertToLLVMPattern::getNumElements( @@ -175,7 +175,7 @@ Value ConvertToLLVMPattern::getNumElements( staticSize == ShapedType::kDynamic ? dynamicSizes[dynamicIndex++] : createIndexAttrConstant(rewriter, loc, indexType, staticSize); - numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size); + numElements = LLVM::MulOp::create(rewriter, loc, numElements, size); } else { numElements = staticSize == ShapedType::kDynamic @@ -272,18 +272,17 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( // Allocate memory, copy, and free the source if necessary. Value memory = - toDynamic - ? builder - .create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize) - .getResult() - : builder.create<LLVM::AllocaOp>(loc, getPtrType(), - IntegerType::get(getContext(), 8), - allocationSize, - /*alignment=*/0); + toDynamic ? LLVM::CallOp::create(builder, loc, mallocFunc.value(), + allocationSize) + .getResult() + : LLVM::AllocaOp::create(builder, loc, getPtrType(), + IntegerType::get(getContext(), 8), + allocationSize, + /*alignment=*/0); Value source = desc.memRefDescPtr(builder, loc); - builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, false); + LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false); if (!toDynamic) - builder.create<LLVM::CallOp>(loc, freeFunc.value(), source); + LLVM::CallOp::create(builder, loc, freeFunc.value(), source); // Create a new descriptor. The same descriptor can be returned multiple // times, attempting to modify its pointer can lead to memory leaks @@ -349,8 +348,8 @@ LogicalResult LLVM::detail::oneToOneRewrite( SmallVector<Value, 4> results; results.reserve(numResults); for (unsigned i = 0; i < numResults; ++i) { - results.push_back(rewriter.create<LLVM::ExtractValueOp>( - op->getLoc(), newOp->getResult(0), i)); + results.push_back(LLVM::ExtractValueOp::create(rewriter, op->getLoc(), + newOp->getResult(0), i)); } rewriter.replaceOp(op, results); return success(); @@ -371,8 +370,8 @@ LogicalResult LLVM::detail::intrinsicRewrite( if (numResults != 0) resType = typeConverter.packOperationResults(op->getResultTypes()); - auto callIntrOp = rewriter.create<LLVM::CallIntrinsicOp>( - loc, resType, rewriter.getStringAttr(intrinsic), operands); + auto callIntrOp = LLVM::CallIntrinsicOp::create( + rewriter, loc, resType, rewriter.getStringAttr(intrinsic), operands); // Propagate attributes. callIntrOp->setAttrs(op->getAttrDictionary()); @@ -388,7 +387,7 @@ LogicalResult LLVM::detail::intrinsicRewrite( results.reserve(numResults); Value intrRes = callIntrOp.getResults(); for (unsigned i = 0; i < numResults; ++i) - results.push_back(rewriter.create<LLVM::ExtractValueOp>(loc, intrRes, i)); + results.push_back(LLVM::ExtractValueOp::create(rewriter, loc, intrRes, i)); rewriter.replaceOp(op, results); return success(); @@ -406,7 +405,7 @@ static unsigned getBitWidth(Type type) { static Value createI32Constant(OpBuilder &builder, Location loc, int32_t value) { Type i32 = builder.getI32Type(); - return builder.create<LLVM::ConstantOp>(loc, i32, value); + return LLVM::ConstantOp::create(builder, loc, i32, value); } SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc, @@ -418,17 +417,17 @@ SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc, unsigned srcBitWidth = getBitWidth(srcType); unsigned dstBitWidth = getBitWidth(dstType); if (srcBitWidth == dstBitWidth) { - Value cast = builder.create<LLVM::BitcastOp>(loc, dstType, src); + Value cast = LLVM::BitcastOp::create(builder, loc, dstType, src); return {cast}; } if (dstBitWidth > srcBitWidth) { auto smallerInt = builder.getIntegerType(srcBitWidth); if (srcType != smallerInt) - src = builder.create<LLVM::BitcastOp>(loc, smallerInt, src); + src = LLVM::BitcastOp::create(builder, loc, smallerInt, src); auto largerInt = builder.getIntegerType(dstBitWidth); - Value res = builder.create<LLVM::ZExtOp>(loc, largerInt, src); + Value res = LLVM::ZExtOp::create(builder, loc, largerInt, src); return {res}; } assert(srcBitWidth % dstBitWidth == 0 && @@ -436,12 +435,12 @@ SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc, int64_t numElements = srcBitWidth / dstBitWidth; auto vecType = VectorType::get(numElements, dstType); - src = builder.create<LLVM::BitcastOp>(loc, vecType, src); + src = LLVM::BitcastOp::create(builder, loc, vecType, src); SmallVector<Value> res; for (auto i : llvm::seq(numElements)) { Value idx = createI32Constant(builder, loc, i); - Value elem = builder.create<LLVM::ExtractElementOp>(loc, src, idx); + Value elem = LLVM::ExtractElementOp::create(builder, loc, src, idx); res.emplace_back(elem); } @@ -461,28 +460,28 @@ Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src, if (dstBitWidth < srcBitWidth) { auto largerInt = builder.getIntegerType(srcBitWidth); if (res.getType() != largerInt) - res = builder.create<LLVM::BitcastOp>(loc, largerInt, res); + res = LLVM::BitcastOp::create(builder, loc, largerInt, res); auto smallerInt = builder.getIntegerType(dstBitWidth); - res = builder.create<LLVM::TruncOp>(loc, smallerInt, res); + res = LLVM::TruncOp::create(builder, loc, smallerInt, res); } if (res.getType() != dstType) - res = builder.create<LLVM::BitcastOp>(loc, dstType, res); + res = LLVM::BitcastOp::create(builder, loc, dstType, res); return res; } int64_t numElements = src.size(); auto srcType = VectorType::get(numElements, src.front().getType()); - Value res = builder.create<LLVM::PoisonOp>(loc, srcType); + Value res = LLVM::PoisonOp::create(builder, loc, srcType); for (auto &&[i, elem] : llvm::enumerate(src)) { Value idx = createI32Constant(builder, loc, i); - res = builder.create<LLVM::InsertElementOp>(loc, srcType, res, elem, idx); + res = LLVM::InsertElementOp::create(builder, loc, srcType, res, elem, idx); } if (res.getType() != dstType) - res = builder.create<LLVM::BitcastOp>(loc, dstType, res); + res = LLVM::BitcastOp::create(builder, loc, dstType, res); return res; } @@ -518,20 +517,20 @@ Value mlir::LLVM::getStridedElementPtr(OpBuilder &builder, Location loc, Value stride = ShapedType::isDynamic(strides[i]) ? memRefDescriptor.stride(builder, loc, i) - : builder.create<LLVM::ConstantOp>( - loc, indexType, builder.getIndexAttr(strides[i])); - increment = - builder.create<LLVM::MulOp>(loc, increment, stride, intOverflowFlags); + : LLVM::ConstantOp::create(builder, loc, indexType, + builder.getIndexAttr(strides[i])); + increment = LLVM::MulOp::create(builder, loc, increment, stride, + intOverflowFlags); } - index = index ? builder.create<LLVM::AddOp>(loc, index, increment, - intOverflowFlags) + index = index ? LLVM::AddOp::create(builder, loc, index, increment, + intOverflowFlags) : increment; } Type elementPtrType = memRefDescriptor.getElementPtrType(); - return index ? builder.create<LLVM::GEPOp>( - loc, elementPtrType, - converter.convertType(type.getElementType()), base, index, - noWrapFlags) - : base; + return index + ? LLVM::GEPOp::create(builder, loc, elementPtrType, + converter.convertType(type.getElementType()), + base, index, noWrapFlags) + : base; } diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp index 49c73fb..d95aeba 100644 --- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp +++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp @@ -66,23 +66,23 @@ LogicalResult mlir::LLVM::createPrintStrCall( DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals)); auto arrayTy = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size()); - auto globalOp = builder.create<LLVM::GlobalOp>( - loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, + auto globalOp = LLVM::GlobalOp::create( + builder, loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, ensureSymbolNameIsUnique(moduleOp, symbolName, symbolTables), dataAttr); auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext()); // Emit call to `printStr` in runtime library. builder.restoreInsertionPoint(ip); auto msgAddr = - builder.create<LLVM::AddressOfOp>(loc, ptrTy, globalOp.getName()); + LLVM::AddressOfOp::create(builder, loc, ptrTy, globalOp.getName()); SmallVector<LLVM::GEPArg> indices(1, 0); Value gep = - builder.create<LLVM::GEPOp>(loc, ptrTy, arrayTy, msgAddr, indices); + LLVM::GEPOp::create(builder, loc, ptrTy, arrayTy, msgAddr, indices); FailureOr<LLVM::LLVMFuncOp> printer = LLVM::lookupOrCreatePrintStringFn(builder, moduleOp, runtimeFunctionName); if (failed(printer)) return failure(); - builder.create<LLVM::CallOp>(loc, TypeRange(), - SymbolRefAttr::get(printer.value()), gep); + LLVM::CallOp::create(builder, loc, TypeRange(), + SymbolRefAttr::get(printer.value()), gep); return success(); } diff --git a/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp index 1cd0bd8..13ed4628 100644 --- a/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp +++ b/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp @@ -24,10 +24,10 @@ StructBuilder::StructBuilder(Value v) : value(v), structType(v.getType()) { Value StructBuilder::extractPtr(OpBuilder &builder, Location loc, unsigned pos) const { - return builder.create<LLVM::ExtractValueOp>(loc, value, pos); + return LLVM::ExtractValueOp::create(builder, loc, value, pos); } void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr) { - value = builder.create<LLVM::InsertValueOp>(loc, value, ptr, pos); + value = LLVM::InsertValueOp::create(builder, loc, value, ptr, pos); } diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp index 7312594..1a9bf56 100644 --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -91,7 +91,7 @@ static Value unrankedMemRefMaterialization(OpBuilder &builder, packUnrankedMemRefDesc(builder, resultType, inputs, loc, converter); if (!packed) return Value(); - return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed) + return UnrealizedConversionCastOp::create(builder, loc, resultType, packed) .getResult(0); } @@ -107,7 +107,7 @@ static Value rankedMemRefMaterialization(OpBuilder &builder, packRankedMemRefDesc(builder, resultType, inputs, loc, converter); if (!packed) return Value(); - return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed) + return UnrealizedConversionCastOp::create(builder, loc, resultType, packed) .getResult(0); } @@ -224,12 +224,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, // non-LLVM types persist after an LLVM conversion. addSourceMaterialization([&](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) { - return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs) .getResult(0); }); addTargetMaterialization([&](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) { - return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs) .getResult(0); }); @@ -731,12 +731,12 @@ Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand, // Alloca with proper alignment. We do not expect optimizations of this // alloca op and so we omit allocating at the entry block. auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); - Value one = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(), - builder.getIndexAttr(1)); + Value one = LLVM::ConstantOp::create(builder, loc, builder.getI64Type(), + builder.getIndexAttr(1)); Value allocated = - builder.create<LLVM::AllocaOp>(loc, ptrType, operand.getType(), one); + LLVM::AllocaOp::create(builder, loc, ptrType, operand.getType(), one); // Store into the alloca'ed descriptor. - builder.create<LLVM::StoreOp>(loc, operand, allocated); + LLVM::StoreOp::create(builder, loc, operand, allocated); return allocated; } diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp index bf3f317..e7dd0b5 100644 --- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp @@ -87,17 +87,17 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors( auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy; auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy; auto loc = op->getLoc(); - Value desc = rewriter.create<LLVM::PoisonOp>(loc, resultNDVectoryTy); + Value desc = LLVM::PoisonOp::create(rewriter, loc, resultNDVectoryTy); nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayRef<int64_t> position) { // For this unrolled `position` corresponding to the `linearIndex`^th // element, extract operand vectors SmallVector<Value, 4> extractedOperands; for (const auto &operand : llvm::enumerate(operands)) { - extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>( - loc, operand.value(), position)); + extractedOperands.push_back(LLVM::ExtractValueOp::create( + rewriter, loc, operand.value(), position)); } Value newVal = createOperand(result1DVectorTy, extractedOperands); - desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, newVal, position); + desc = LLVM::InsertValueOp::create(rewriter, loc, desc, newVal, position); }); rewriter.replaceOp(op, desc); return success(); diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp index c3f2131..3f4b4d6 100644 --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -78,8 +78,8 @@ getLibraryCallSymbolRef(Operation *op, PatternRewriter &rewriter) { // Insert before module terminator. rewriter.setInsertionPoint(module.getBody(), std::prev(module.getBody()->end())); - func::FuncOp funcOp = rewriter.create<func::FuncOp>( - op->getLoc(), fnNameAttr.getValue(), libFnType); + func::FuncOp funcOp = func::FuncOp::create(rewriter, op->getLoc(), + fnNameAttr.getValue(), libFnType); // Insert a function attribute that will trigger the emission of the // corresponding `_mlir_ciface_xxx` interface so that external libraries see // a normalized ABI. This interface is added during std to llvm conversion. @@ -100,8 +100,8 @@ createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc, res.push_back(op); continue; } - Value cast = - b.create<memref::CastOp>(loc, makeStridedLayoutDynamic(memrefType), op); + Value cast = memref::CastOp::create( + b, loc, makeStridedLayoutDynamic(memrefType), op); res.push_back(cast); } return res; diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp index d4deff5..e5496e5 100644 --- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp +++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp @@ -35,7 +35,7 @@ static Op getOrDefineGlobal(ModuleOp &moduleOp, const Location loc, if (!(ret = moduleOp.lookupSymbol<Op>(name))) { ConversionPatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(moduleOp.getBody()); - ret = rewriter.template create<Op>(loc, std::forward<Args>(args)...); + ret = Op::create(rewriter, loc, std::forward<Args>(args)...); } return ret; } @@ -54,18 +54,18 @@ std::pair<Value, Value> getRawPtrAndSize(const Location loc, Value memRef, Type elType) { Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); Value dataPtr = - rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1); - Value offset = rewriter.create<LLVM::ExtractValueOp>( - loc, rewriter.getI64Type(), memRef, 2); + LLVM::ExtractValueOp::create(rewriter, loc, ptrType, memRef, 1); + Value offset = LLVM::ExtractValueOp::create(rewriter, loc, + rewriter.getI64Type(), memRef, 2); Value resPtr = - rewriter.create<LLVM::GEPOp>(loc, ptrType, elType, dataPtr, offset); + LLVM::GEPOp::create(rewriter, loc, ptrType, elType, dataPtr, offset); Value size; if (cast<LLVM::LLVMStructType>(memRef.getType()).getBody().size() > 3) { - size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef, - ArrayRef<int64_t>{3, 0}); - size = rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), size); + size = LLVM::ExtractValueOp::create(rewriter, loc, memRef, + ArrayRef<int64_t>{3, 0}); + size = LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), size); } else { - size = rewriter.create<arith::ConstantIntOp>(loc, 1, 32); + size = arith::ConstantIntOp::create(rewriter, loc, 1, 32); } return {resPtr, size}; } @@ -157,13 +157,13 @@ public: Value getCommWorld(const Location loc, ConversionPatternRewriter &rewriter) override { static constexpr int MPI_COMM_WORLD = 0x44000000; - return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), - MPI_COMM_WORLD); + return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), + MPI_COMM_WORLD); } Value castComm(const Location loc, ConversionPatternRewriter &rewriter, Value comm) override { - return rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), comm); + return LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), comm); } intptr_t getStatusIgnore() override { return 1; } @@ -195,7 +195,8 @@ public: mtype = MPI_UINT8_T; else assert(false && "unsupported type"); - return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), mtype); + return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), + mtype); } Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter, @@ -245,7 +246,7 @@ public: op = MPI_REPLACE; break; } - return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), op); + return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), op); } }; @@ -281,16 +282,16 @@ public: getOrDefineExternalStruct(loc, rewriter, name, commStructT); // get address of symbol - auto comm = rewriter.create<LLVM::AddressOfOp>( - loc, LLVM::LLVMPointerType::get(context), - SymbolRefAttr::get(context, name)); - return rewriter.create<LLVM::PtrToIntOp>(loc, rewriter.getI64Type(), comm); + auto comm = LLVM::AddressOfOp::create(rewriter, loc, + LLVM::LLVMPointerType::get(context), + SymbolRefAttr::get(context, name)); + return LLVM::PtrToIntOp::create(rewriter, loc, rewriter.getI64Type(), comm); } Value castComm(const Location loc, ConversionPatternRewriter &rewriter, Value comm) override { - return rewriter.create<LLVM::IntToPtrOp>( - loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm); + return LLVM::IntToPtrOp::create( + rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm); } intptr_t getStatusIgnore() override { return 0; } @@ -330,9 +331,9 @@ public: // make sure global op definition exists getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT); // get address of symbol - return rewriter.create<LLVM::AddressOfOp>( - loc, LLVM::LLVMPointerType::get(context), - SymbolRefAttr::get(context, mtype)); + return LLVM::AddressOfOp::create(rewriter, loc, + LLVM::LLVMPointerType::get(context), + SymbolRefAttr::get(context, mtype)); } Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter, @@ -389,9 +390,9 @@ public: // make sure global op definition exists getOrDefineExternalStruct(loc, rewriter, op, opStructT); // get address of symbol - return rewriter.create<LLVM::AddressOfOp>( - loc, LLVM::LLVMPointerType::get(context), - SymbolRefAttr::get(context, op)); + return LLVM::AddressOfOp::create(rewriter, loc, + LLVM::LLVMPointerType::get(context), + SymbolRefAttr::get(context, op)); } }; @@ -424,7 +425,7 @@ struct InitOpLowering : public ConvertOpToLLVMPattern<mpi::InitOp> { Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); // instantiate nullptr `%nullptr = llvm.mlir.zero : !llvm.ptr` - auto nullPtrOp = rewriter.create<LLVM::ZeroOp>(loc, ptrType); + auto nullPtrOp = LLVM::ZeroOp::create(rewriter, loc, ptrType); Value llvmnull = nullPtrOp.getRes(); // grab a reference to the global module op: @@ -513,9 +514,9 @@ struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> { // get communicator Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm()); - auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1); + auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1); auto outPtr = - rewriter.create<LLVM::AllocaOp>(loc, ptrType, comm.getType(), one); + LLVM::AllocaOp::create(rewriter, loc, ptrType, comm.getType(), one); // int MPI_Comm_split(MPI_Comm comm, int color, int key, MPI_Comm * newcomm) auto funcType = @@ -524,14 +525,14 @@ struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> { LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Comm_split", funcType); - auto callOp = rewriter.create<LLVM::CallOp>( - loc, funcDecl, - ValueRange{comm, adaptor.getColor(), adaptor.getKey(), - outPtr.getRes()}); + auto callOp = + LLVM::CallOp::create(rewriter, loc, funcDecl, + ValueRange{comm, adaptor.getColor(), + adaptor.getKey(), outPtr.getRes()}); // load the communicator into a register - Value res = rewriter.create<LLVM::LoadOp>(loc, i32, outPtr.getResult()); - res = rewriter.create<LLVM::SExtOp>(loc, rewriter.getI64Type(), res); + Value res = LLVM::LoadOp::create(rewriter, loc, i32, outPtr.getResult()); + res = LLVM::SExtOp::create(rewriter, loc, rewriter.getI64Type(), res); // if retval is checked, replace uses of retval with the results from the // call op @@ -580,14 +581,14 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> { moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType); // replace with function call - auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1); - auto rankptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one); - auto callOp = rewriter.create<LLVM::CallOp>( - loc, initDecl, ValueRange{comm, rankptr.getRes()}); + auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1); + auto rankptr = LLVM::AllocaOp::create(rewriter, loc, ptrType, i32, one); + auto callOp = LLVM::CallOp::create(rewriter, loc, initDecl, + ValueRange{comm, rankptr.getRes()}); // load the rank into a register auto loadedRank = - rewriter.create<LLVM::LoadOp>(loc, i32, rankptr.getResult()); + LLVM::LoadOp::create(rewriter, loc, i32, rankptr.getResult()); // if retval is checked, replace uses of retval with the results from the // call op @@ -641,10 +642,10 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> { getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Send", funcType); // replace op with function call - auto funcCall = rewriter.create<LLVM::CallOp>( - loc, funcDecl, - ValueRange{dataPtr, size, dataType, adaptor.getDest(), adaptor.getTag(), - comm}); + auto funcCall = LLVM::CallOp::create(rewriter, loc, funcDecl, + ValueRange{dataPtr, size, dataType, + adaptor.getDest(), + adaptor.getTag(), comm}); if (op.getRetval()) rewriter.replaceOp(op, funcCall.getResult()); else @@ -683,10 +684,10 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> { auto mpiTraits = MPIImplTraits::get(moduleOp); Value dataType = mpiTraits->getDataType(loc, rewriter, elemType); Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm()); - Value statusIgnore = rewriter.create<LLVM::ConstantOp>( - loc, i64, mpiTraits->getStatusIgnore()); + Value statusIgnore = LLVM::ConstantOp::create(rewriter, loc, i64, + mpiTraits->getStatusIgnore()); statusIgnore = - rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, statusIgnore); + LLVM::IntToPtrOp::create(rewriter, loc, ptrType, statusIgnore); // LLVM Function type representing `i32 MPI_Recv(data, count, datatype, dst, // tag, comm)` @@ -698,8 +699,8 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> { getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Recv", funcType); // replace op with function call - auto funcCall = rewriter.create<LLVM::CallOp>( - loc, funcDecl, + auto funcCall = LLVM::CallOp::create( + rewriter, loc, funcDecl, ValueRange{dataPtr, size, dataType, adaptor.getSource(), adaptor.getTag(), comm, statusIgnore}); if (op.getRetval()) @@ -738,9 +739,10 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> { // If input and output are the same, request in-place operation. if (adaptor.getSendbuf() == adaptor.getRecvbuf()) { - sendPtr = rewriter.create<LLVM::ConstantOp>( - loc, i64, reinterpret_cast<int64_t>(mpiTraits->getInPlace())); - sendPtr = rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, sendPtr); + sendPtr = LLVM::ConstantOp::create( + rewriter, loc, i64, + reinterpret_cast<int64_t>(mpiTraits->getInPlace())); + sendPtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, sendPtr); } Value dataType = mpiTraits->getDataType(loc, rewriter, elemType); @@ -757,8 +759,8 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> { getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allreduce", funcType); // replace op with function call - auto funcCall = rewriter.create<LLVM::CallOp>( - loc, funcDecl, + auto funcCall = LLVM::CallOp::create( + rewriter, loc, funcDecl, ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld}); if (op.getRetval()) diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp index 7f4655e..cde2340 100644 --- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp +++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp @@ -17,13 +17,12 @@ #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" namespace mlir { #define GEN_PASS_DEF_CONVERTMATHTOFUNCS @@ -33,7 +32,6 @@ namespace mlir { using namespace mlir; #define DEBUG_TYPE "math-to-funcs" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") namespace { // Pattern to convert vector operations to scalar operations. @@ -121,19 +119,19 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const { initValueAttr = FloatAttr::get(resultElementType, 0.0); else initValueAttr = IntegerAttr::get(resultElementType, 0); - Value result = rewriter.create<arith::ConstantOp>( - loc, DenseElementsAttr::get(vecType, initValueAttr)); + Value result = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(vecType, initValueAttr)); SmallVector<int64_t> strides = computeStrides(shape); for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) { SmallVector<int64_t> positions = delinearize(linearIndex, strides); SmallVector<Value> operands; for (Value input : op->getOperands()) operands.push_back( - rewriter.create<vector::ExtractOp>(loc, input, positions)); + vector::ExtractOp::create(rewriter, loc, input, positions)); Value scalarOp = - rewriter.create<Op>(loc, vecType.getElementType(), operands); + Op::create(rewriter, loc, vecType.getElementType(), operands); result = - rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions); + vector::InsertOp::create(rewriter, loc, scalarOp, result, positions); } rewriter.replaceOp(op, result); return success(); @@ -195,7 +193,7 @@ static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) { FunctionType funcType = FunctionType::get( builder.getContext(), {elementType, elementType}, elementType); - auto funcOp = builder.create<func::FuncOp>(funcName, funcType); + auto funcOp = func::FuncOp::create(builder, funcName, funcType); LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR; Attribute linkage = LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage); @@ -208,12 +206,12 @@ static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) { Value bArg = funcOp.getArgument(0); Value pArg = funcOp.getArgument(1); builder.setInsertionPointToEnd(entryBlock); - Value zeroValue = builder.create<arith::ConstantOp>( - elementType, builder.getIntegerAttr(elementType, 0)); - Value oneValue = builder.create<arith::ConstantOp>( - elementType, builder.getIntegerAttr(elementType, 1)); - Value minusOneValue = builder.create<arith::ConstantOp>( - elementType, + Value zeroValue = arith::ConstantOp::create( + builder, elementType, builder.getIntegerAttr(elementType, 0)); + Value oneValue = arith::ConstantOp::create( + builder, elementType, builder.getIntegerAttr(elementType, 1)); + Value minusOneValue = arith::ConstantOp::create( + builder, elementType, builder.getIntegerAttr(elementType, APInt(elementType.getIntOrFloatBitWidth(), -1ULL, /*isSigned=*/true))); @@ -221,82 +219,83 @@ static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) { // if (p == T(0)) // return T(1); auto pIsZero = - builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroValue); + arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, pArg, zeroValue); Block *thenBlock = builder.createBlock(funcBody); - builder.create<func::ReturnOp>(oneValue); + func::ReturnOp::create(builder, oneValue); Block *fallthroughBlock = builder.createBlock(funcBody); // Set up conditional branch for (p == T(0)). builder.setInsertionPointToEnd(pIsZero->getBlock()); - builder.create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock); + cf::CondBranchOp::create(builder, pIsZero, thenBlock, fallthroughBlock); // if (p < T(0)) { builder.setInsertionPointToEnd(fallthroughBlock); - auto pIsNeg = - builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg, zeroValue); + auto pIsNeg = arith::CmpIOp::create(builder, arith::CmpIPredicate::sle, pArg, + zeroValue); // if (b == T(0)) builder.createBlock(funcBody); auto bIsZero = - builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, zeroValue); + arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, bArg, zeroValue); // return T(1) / T(0); thenBlock = builder.createBlock(funcBody); - builder.create<func::ReturnOp>( - builder.create<arith::DivSIOp>(oneValue, zeroValue).getResult()); + func::ReturnOp::create( + builder, + arith::DivSIOp::create(builder, oneValue, zeroValue).getResult()); fallthroughBlock = builder.createBlock(funcBody); // Set up conditional branch for (b == T(0)). builder.setInsertionPointToEnd(bIsZero->getBlock()); - builder.create<cf::CondBranchOp>(bIsZero, thenBlock, fallthroughBlock); + cf::CondBranchOp::create(builder, bIsZero, thenBlock, fallthroughBlock); // if (b == T(1)) builder.setInsertionPointToEnd(fallthroughBlock); auto bIsOne = - builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, oneValue); + arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, bArg, oneValue); // return T(1); thenBlock = builder.createBlock(funcBody); - builder.create<func::ReturnOp>(oneValue); + func::ReturnOp::create(builder, oneValue); fallthroughBlock = builder.createBlock(funcBody); // Set up conditional branch for (b == T(1)). builder.setInsertionPointToEnd(bIsOne->getBlock()); - builder.create<cf::CondBranchOp>(bIsOne, thenBlock, fallthroughBlock); + cf::CondBranchOp::create(builder, bIsOne, thenBlock, fallthroughBlock); // if (b == T(-1)) { builder.setInsertionPointToEnd(fallthroughBlock); - auto bIsMinusOne = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, - bArg, minusOneValue); + auto bIsMinusOne = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, + bArg, minusOneValue); // if (p & T(1)) builder.createBlock(funcBody); - auto pIsOdd = builder.create<arith::CmpIOp>( - arith::CmpIPredicate::ne, builder.create<arith::AndIOp>(pArg, oneValue), - zeroValue); + auto pIsOdd = arith::CmpIOp::create( + builder, arith::CmpIPredicate::ne, + arith::AndIOp::create(builder, pArg, oneValue), zeroValue); // return T(-1); thenBlock = builder.createBlock(funcBody); - builder.create<func::ReturnOp>(minusOneValue); + func::ReturnOp::create(builder, minusOneValue); fallthroughBlock = builder.createBlock(funcBody); // Set up conditional branch for (p & T(1)). builder.setInsertionPointToEnd(pIsOdd->getBlock()); - builder.create<cf::CondBranchOp>(pIsOdd, thenBlock, fallthroughBlock); + cf::CondBranchOp::create(builder, pIsOdd, thenBlock, fallthroughBlock); // return T(1); // } // b == T(-1) builder.setInsertionPointToEnd(fallthroughBlock); - builder.create<func::ReturnOp>(oneValue); + func::ReturnOp::create(builder, oneValue); fallthroughBlock = builder.createBlock(funcBody); // Set up conditional branch for (b == T(-1)). builder.setInsertionPointToEnd(bIsMinusOne->getBlock()); - builder.create<cf::CondBranchOp>(bIsMinusOne, pIsOdd->getBlock(), - fallthroughBlock); + cf::CondBranchOp::create(builder, bIsMinusOne, pIsOdd->getBlock(), + fallthroughBlock); // return T(0); // } // (p < T(0)) builder.setInsertionPointToEnd(fallthroughBlock); - builder.create<func::ReturnOp>(zeroValue); + func::ReturnOp::create(builder, zeroValue); Block *loopHeader = builder.createBlock( funcBody, funcBody->end(), {elementType, elementType, elementType}, {builder.getLoc(), builder.getLoc(), builder.getLoc()}); // Set up conditional branch for (p < T(0)). builder.setInsertionPointToEnd(pIsNeg->getBlock()); // Set initial values of 'result', 'b' and 'p' for the loop. - builder.create<cf::CondBranchOp>(pIsNeg, bIsZero->getBlock(), loopHeader, - ValueRange{oneValue, bArg, pArg}); + cf::CondBranchOp::create(builder, pIsNeg, bIsZero->getBlock(), loopHeader, + ValueRange{oneValue, bArg, pArg}); // T result = T(1); // while (true) { @@ -313,45 +312,46 @@ static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) { builder.setInsertionPointToEnd(loopHeader); // if (p & T(1)) - auto powerTmpIsOdd = builder.create<arith::CmpIOp>( - arith::CmpIPredicate::ne, - builder.create<arith::AndIOp>(powerTmp, oneValue), zeroValue); + auto powerTmpIsOdd = arith::CmpIOp::create( + builder, arith::CmpIPredicate::ne, + arith::AndIOp::create(builder, powerTmp, oneValue), zeroValue); thenBlock = builder.createBlock(funcBody); // result *= b; - Value newResultTmp = builder.create<arith::MulIOp>(resultTmp, baseTmp); + Value newResultTmp = arith::MulIOp::create(builder, resultTmp, baseTmp); fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), elementType, builder.getLoc()); builder.setInsertionPointToEnd(thenBlock); - builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock); + cf::BranchOp::create(builder, newResultTmp, fallthroughBlock); // Set up conditional branch for (p & T(1)). builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock()); - builder.create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock, - resultTmp); + cf::CondBranchOp::create(builder, powerTmpIsOdd, thenBlock, fallthroughBlock, + resultTmp); // Merged 'result'. newResultTmp = fallthroughBlock->getArgument(0); // p >>= T(1); builder.setInsertionPointToEnd(fallthroughBlock); - Value newPowerTmp = builder.create<arith::ShRUIOp>(powerTmp, oneValue); + Value newPowerTmp = arith::ShRUIOp::create(builder, powerTmp, oneValue); // if (p == T(0)) - auto newPowerIsZero = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, - newPowerTmp, zeroValue); + auto newPowerIsZero = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, + newPowerTmp, zeroValue); // return result; thenBlock = builder.createBlock(funcBody); - builder.create<func::ReturnOp>(newResultTmp); + func::ReturnOp::create(builder, newResultTmp); fallthroughBlock = builder.createBlock(funcBody); // Set up conditional branch for (p == T(0)). builder.setInsertionPointToEnd(newPowerIsZero->getBlock()); - builder.create<cf::CondBranchOp>(newPowerIsZero, thenBlock, fallthroughBlock); + cf::CondBranchOp::create(builder, newPowerIsZero, thenBlock, + fallthroughBlock); // b *= b; // } builder.setInsertionPointToEnd(fallthroughBlock); - Value newBaseTmp = builder.create<arith::MulIOp>(baseTmp, baseTmp); + Value newBaseTmp = arith::MulIOp::create(builder, baseTmp, baseTmp); // Pass new values for 'result', 'b' and 'p' to the loop header. - builder.create<cf::BranchOp>( - ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader); + cf::BranchOp::create( + builder, ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader); return funcOp; } @@ -420,7 +420,7 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module, llvm::raw_string_ostream nameOS(funcName); nameOS << '_' << baseType; nameOS << '_' << powType; - auto funcOp = builder.create<func::FuncOp>(funcName, funcType); + auto funcOp = func::FuncOp::create(builder, funcName, funcType); LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR; Attribute linkage = LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage); @@ -433,46 +433,48 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module, Value bArg = funcOp.getArgument(0); Value pArg = funcOp.getArgument(1); builder.setInsertionPointToEnd(entryBlock); - Value oneBValue = builder.create<arith::ConstantOp>( - baseType, builder.getFloatAttr(baseType, 1.0)); - Value zeroPValue = builder.create<arith::ConstantOp>( - powType, builder.getIntegerAttr(powType, 0)); - Value onePValue = builder.create<arith::ConstantOp>( - powType, builder.getIntegerAttr(powType, 1)); - Value minPValue = builder.create<arith::ConstantOp>( - powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMinValue( - powType.getWidth()))); - Value maxPValue = builder.create<arith::ConstantOp>( - powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMaxValue( - powType.getWidth()))); + Value oneBValue = arith::ConstantOp::create( + builder, baseType, builder.getFloatAttr(baseType, 1.0)); + Value zeroPValue = arith::ConstantOp::create( + builder, powType, builder.getIntegerAttr(powType, 0)); + Value onePValue = arith::ConstantOp::create( + builder, powType, builder.getIntegerAttr(powType, 1)); + Value minPValue = arith::ConstantOp::create( + builder, powType, + builder.getIntegerAttr( + powType, llvm::APInt::getSignedMinValue(powType.getWidth()))); + Value maxPValue = arith::ConstantOp::create( + builder, powType, + builder.getIntegerAttr( + powType, llvm::APInt::getSignedMaxValue(powType.getWidth()))); // if (p == Tp{0}) // return Tb{1}; - auto pIsZero = - builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroPValue); + auto pIsZero = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, pArg, + zeroPValue); Block *thenBlock = builder.createBlock(funcBody); - builder.create<func::ReturnOp>(oneBValue); + func::ReturnOp::create(builder, oneBValue); Block *fallthroughBlock = builder.createBlock(funcBody); // Set up conditional branch for (p == Tp{0}). builder.setInsertionPointToEnd(pIsZero->getBlock()); - builder.create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock); + cf::CondBranchOp::create(builder, pIsZero, thenBlock, fallthroughBlock); builder.setInsertionPointToEnd(fallthroughBlock); // bool isNegativePower{p < Tp{0}} - auto pIsNeg = builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg, - zeroPValue); + auto pIsNeg = arith::CmpIOp::create(builder, arith::CmpIPredicate::sle, pArg, + zeroPValue); // bool isMin{p == std::numeric_limits<Tp>::min()}; auto pIsMin = - builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, minPValue); + arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, pArg, minPValue); // if (isMin) { // p = std::numeric_limits<Tp>::max(); // } else if (isNegativePower) { // p = -p; // } - Value negP = builder.create<arith::SubIOp>(zeroPValue, pArg); - auto pInit = builder.create<arith::SelectOp>(pIsNeg, negP, pArg); - pInit = builder.create<arith::SelectOp>(pIsMin, maxPValue, pInit); + Value negP = arith::SubIOp::create(builder, zeroPValue, pArg); + auto pInit = arith::SelectOp::create(builder, pIsNeg, negP, pArg); + pInit = arith::SelectOp::create(builder, pIsMin, maxPValue, pInit); // Tb result = Tb{1}; // Tb origBase = Tb{b}; @@ -489,7 +491,7 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module, {builder.getLoc(), builder.getLoc(), builder.getLoc()}); // Set initial values of 'result', 'b' and 'p' for the loop. builder.setInsertionPointToEnd(pInit->getBlock()); - builder.create<cf::BranchOp>(loopHeader, ValueRange{oneBValue, bArg, pInit}); + cf::BranchOp::create(builder, loopHeader, ValueRange{oneBValue, bArg, pInit}); // Create loop body. Value resultTmp = loopHeader->getArgument(0); @@ -498,30 +500,30 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module, builder.setInsertionPointToEnd(loopHeader); // if (p & Tp{1}) - auto powerTmpIsOdd = builder.create<arith::CmpIOp>( - arith::CmpIPredicate::ne, - builder.create<arith::AndIOp>(powerTmp, onePValue), zeroPValue); + auto powerTmpIsOdd = arith::CmpIOp::create( + builder, arith::CmpIPredicate::ne, + arith::AndIOp::create(builder, powerTmp, onePValue), zeroPValue); thenBlock = builder.createBlock(funcBody); // result *= b; - Value newResultTmp = builder.create<arith::MulFOp>(resultTmp, baseTmp); + Value newResultTmp = arith::MulFOp::create(builder, resultTmp, baseTmp); fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType, builder.getLoc()); builder.setInsertionPointToEnd(thenBlock); - builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock); + cf::BranchOp::create(builder, newResultTmp, fallthroughBlock); // Set up conditional branch for (p & Tp{1}). builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock()); - builder.create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock, - resultTmp); + cf::CondBranchOp::create(builder, powerTmpIsOdd, thenBlock, fallthroughBlock, + resultTmp); // Merged 'result'. newResultTmp = fallthroughBlock->getArgument(0); // p >>= Tp{1}; builder.setInsertionPointToEnd(fallthroughBlock); - Value newPowerTmp = builder.create<arith::ShRUIOp>(powerTmp, onePValue); + Value newPowerTmp = arith::ShRUIOp::create(builder, powerTmp, onePValue); // if (p == Tp{0}) - auto newPowerIsZero = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, - newPowerTmp, zeroPValue); + auto newPowerIsZero = arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, + newPowerTmp, zeroPValue); // break; // // The conditional branch is finalized below with a jump to @@ -531,10 +533,10 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module, // b *= b; // } builder.setInsertionPointToEnd(fallthroughBlock); - Value newBaseTmp = builder.create<arith::MulFOp>(baseTmp, baseTmp); + Value newBaseTmp = arith::MulFOp::create(builder, baseTmp, baseTmp); // Pass new values for 'result', 'b' and 'p' to the loop header. - builder.create<cf::BranchOp>( - ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader); + cf::BranchOp::create( + builder, ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader); // Set up conditional branch for early loop exit: // if (p == Tp{0}) @@ -542,8 +544,8 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module, Block *loopExit = builder.createBlock(funcBody, funcBody->end(), baseType, builder.getLoc()); builder.setInsertionPointToEnd(newPowerIsZero->getBlock()); - builder.create<cf::CondBranchOp>(newPowerIsZero, loopExit, newResultTmp, - fallthroughBlock, ValueRange{}); + cf::CondBranchOp::create(builder, newPowerIsZero, loopExit, newResultTmp, + fallthroughBlock, ValueRange{}); // if (isMin) { // result *= origBase; @@ -553,11 +555,11 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module, fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType, builder.getLoc()); builder.setInsertionPointToEnd(loopExit); - builder.create<cf::CondBranchOp>(pIsMin, thenBlock, fallthroughBlock, - newResultTmp); + cf::CondBranchOp::create(builder, pIsMin, thenBlock, fallthroughBlock, + newResultTmp); builder.setInsertionPointToEnd(thenBlock); - newResultTmp = builder.create<arith::MulFOp>(newResultTmp, bArg); - builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock); + newResultTmp = arith::MulFOp::create(builder, newResultTmp, bArg); + cf::BranchOp::create(builder, newResultTmp, fallthroughBlock); /// if (isNegativePower) { /// result = Tb{1} / result; @@ -567,15 +569,15 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module, Block *returnBlock = builder.createBlock(funcBody, funcBody->end(), baseType, builder.getLoc()); builder.setInsertionPointToEnd(fallthroughBlock); - builder.create<cf::CondBranchOp>(pIsNeg, thenBlock, returnBlock, - newResultTmp); + cf::CondBranchOp::create(builder, pIsNeg, thenBlock, returnBlock, + newResultTmp); builder.setInsertionPointToEnd(thenBlock); - newResultTmp = builder.create<arith::DivFOp>(oneBValue, newResultTmp); - builder.create<cf::BranchOp>(newResultTmp, returnBlock); + newResultTmp = arith::DivFOp::create(builder, oneBValue, newResultTmp); + cf::BranchOp::create(builder, newResultTmp, returnBlock); // return result; builder.setInsertionPointToEnd(returnBlock); - builder.create<func::ReturnOp>(returnBlock->getArgument(0)); + func::ReturnOp::create(builder, returnBlock->getArgument(0)); return funcOp; } @@ -650,10 +652,8 @@ FPowIOpLowering::matchAndRewrite(math::FPowIOp op, /// } static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) { if (!isa<IntegerType>(elementType)) { - LLVM_DEBUG({ - DBGS() << "non-integer element type for CtlzFunc; type was: "; - elementType.print(llvm::dbgs()); - }); + LDBG() << "non-integer element type for CtlzFunc; type was: " + << elementType; llvm_unreachable("non-integer element type"); } int64_t bitWidth = elementType.getIntOrFloatBitWidth(); @@ -667,7 +667,7 @@ static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) { nameOS << '_' << elementType; FunctionType funcType = FunctionType::get(builder.getContext(), {elementType}, elementType); - auto funcOp = builder.create<func::FuncOp>(funcName, funcType); + auto funcOp = func::FuncOp::create(builder, funcName, funcType); // LinkonceODR ensures that there is only one implementation of this function // across all math.ctlz functions that are lowered in this way. @@ -683,33 +683,35 @@ static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) { Value arg = funcOp.getArgument(0); Type indexType = builder.getIndexType(); - Value bitWidthValue = builder.create<arith::ConstantOp>( - elementType, builder.getIntegerAttr(elementType, bitWidth)); - Value zeroValue = builder.create<arith::ConstantOp>( - elementType, builder.getIntegerAttr(elementType, 0)); + Value bitWidthValue = arith::ConstantOp::create( + builder, elementType, builder.getIntegerAttr(elementType, bitWidth)); + Value zeroValue = arith::ConstantOp::create( + builder, elementType, builder.getIntegerAttr(elementType, 0)); Value inputEqZero = - builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, arg, zeroValue); + arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, arg, zeroValue); // if input == 0, return bit width, else enter loop. - scf::IfOp ifOp = builder.create<scf::IfOp>( - elementType, inputEqZero, /*addThenBlock=*/true, /*addElseBlock=*/true); - ifOp.getThenBodyBuilder().create<scf::YieldOp>(loc, bitWidthValue); + scf::IfOp ifOp = + scf::IfOp::create(builder, elementType, inputEqZero, + /*addThenBlock=*/true, /*addElseBlock=*/true); + auto thenBuilder = ifOp.getThenBodyBuilder(); + scf::YieldOp::create(thenBuilder, loc, bitWidthValue); auto elseBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, &ifOp.getElseRegion().front()); - Value oneIndex = elseBuilder.create<arith::ConstantOp>( - indexType, elseBuilder.getIndexAttr(1)); - Value oneValue = elseBuilder.create<arith::ConstantOp>( - elementType, elseBuilder.getIntegerAttr(elementType, 1)); - Value bitWidthIndex = elseBuilder.create<arith::ConstantOp>( - indexType, elseBuilder.getIndexAttr(bitWidth)); - Value nValue = elseBuilder.create<arith::ConstantOp>( - elementType, elseBuilder.getIntegerAttr(elementType, 0)); - - auto loop = elseBuilder.create<scf::ForOp>( - oneIndex, bitWidthIndex, oneIndex, + Value oneIndex = arith::ConstantOp::create(elseBuilder, indexType, + elseBuilder.getIndexAttr(1)); + Value oneValue = arith::ConstantOp::create( + elseBuilder, elementType, elseBuilder.getIntegerAttr(elementType, 1)); + Value bitWidthIndex = arith::ConstantOp::create( + elseBuilder, indexType, elseBuilder.getIndexAttr(bitWidth)); + Value nValue = arith::ConstantOp::create( + elseBuilder, elementType, elseBuilder.getIntegerAttr(elementType, 0)); + + auto loop = scf::ForOp::create( + elseBuilder, oneIndex, bitWidthIndex, oneIndex, // Initial values for two loop induction variables, the arg which is being // shifted left in each iteration, and the n value which tracks the count // of leading zeros. @@ -725,25 +727,25 @@ static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) { Value argIter = args[0]; Value nIter = args[1]; - Value argIsNonNegative = b.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::slt, argIter, zeroValue); - scf::IfOp ifOp = b.create<scf::IfOp>( - loc, argIsNonNegative, + Value argIsNonNegative = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::slt, argIter, zeroValue); + scf::IfOp ifOp = scf::IfOp::create( + b, loc, argIsNonNegative, [&](OpBuilder &b, Location loc) { // If arg is negative, continue (effectively, break) - b.create<scf::YieldOp>(loc, ValueRange{argIter, nIter}); + scf::YieldOp::create(b, loc, ValueRange{argIter, nIter}); }, [&](OpBuilder &b, Location loc) { // Otherwise, increment n and shift arg left. - Value nNext = b.create<arith::AddIOp>(loc, nIter, oneValue); - Value argNext = b.create<arith::ShLIOp>(loc, argIter, oneValue); - b.create<scf::YieldOp>(loc, ValueRange{argNext, nNext}); + Value nNext = arith::AddIOp::create(b, loc, nIter, oneValue); + Value argNext = arith::ShLIOp::create(b, loc, argIter, oneValue); + scf::YieldOp::create(b, loc, ValueRange{argNext, nNext}); }); - b.create<scf::YieldOp>(loc, ifOp.getResults()); + scf::YieldOp::create(b, loc, ifOp.getResults()); }); - elseBuilder.create<scf::YieldOp>(loop.getResult(1)); + scf::YieldOp::create(elseBuilder, loop.getResult(1)); - builder.create<func::ReturnOp>(ifOp.getResult(0)); + func::ReturnOp::create(builder, ifOp.getResult(0)); return funcOp; } diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp index f4d69ce..853f454 100644 --- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp +++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp @@ -107,8 +107,8 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> { return LLVM::detail::handleMultidimensionalVectors( op.getOperation(), adaptor.getOperands(), typeConverter, [&](Type llvm1DVectorTy, ValueRange operands) { - return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0], - false); + return LLVMOp::create(rewriter, loc, llvm1DVectorTy, operands[0], + false); }, rewriter); } @@ -145,15 +145,16 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> { if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) { LLVM::ConstantOp one; if (LLVM::isCompatibleVectorType(llvmOperandType)) { - one = rewriter.create<LLVM::ConstantOp>( - loc, llvmOperandType, + one = LLVM::ConstantOp::create( + rewriter, loc, llvmOperandType, SplatElementsAttr::get(cast<ShapedType>(llvmOperandType), floatOne)); } else { - one = rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType, floatOne); + one = + LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne); } - auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand(), - expAttrs.getAttrs()); + auto exp = LLVM::ExpOp::create(rewriter, loc, adaptor.getOperand(), + expAttrs.getAttrs()); rewriter.replaceOpWithNewOp<LLVM::FSubOp>( op, llvmOperandType, ValueRange{exp, one}, subAttrs.getAttrs()); return success(); @@ -170,12 +171,13 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> { mlir::VectorType::get({numElements.getKnownMinValue()}, floatType, {numElements.isScalable()}), floatOne); - auto one = - rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); - auto exp = rewriter.create<LLVM::ExpOp>( - loc, llvm1DVectorTy, operands[0], expAttrs.getAttrs()); - return rewriter.create<LLVM::FSubOp>( - loc, llvm1DVectorTy, ValueRange{exp, one}, subAttrs.getAttrs()); + auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy, + splatAttr); + auto exp = LLVM::ExpOp::create(rewriter, loc, llvm1DVectorTy, + operands[0], expAttrs.getAttrs()); + return LLVM::FSubOp::create(rewriter, loc, llvm1DVectorTy, + ValueRange{exp, one}, + subAttrs.getAttrs()); }, rewriter); } @@ -205,16 +207,16 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> { if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) { LLVM::ConstantOp one = isa<VectorType>(llvmOperandType) - ? rewriter.create<LLVM::ConstantOp>( - loc, llvmOperandType, + ? LLVM::ConstantOp::create( + rewriter, loc, llvmOperandType, SplatElementsAttr::get(cast<ShapedType>(llvmOperandType), floatOne)) - : rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType, - floatOne); + : LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, + floatOne); - auto add = rewriter.create<LLVM::FAddOp>( - loc, llvmOperandType, ValueRange{one, adaptor.getOperand()}, - addAttrs.getAttrs()); + auto add = LLVM::FAddOp::create(rewriter, loc, llvmOperandType, + ValueRange{one, adaptor.getOperand()}, + addAttrs.getAttrs()); rewriter.replaceOpWithNewOp<LLVM::LogOp>( op, llvmOperandType, ValueRange{add}, logAttrs.getAttrs()); return success(); @@ -231,13 +233,13 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> { mlir::VectorType::get({numElements.getKnownMinValue()}, floatType, {numElements.isScalable()}), floatOne); - auto one = - rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); - auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, - ValueRange{one, operands[0]}, - addAttrs.getAttrs()); - return rewriter.create<LLVM::LogOp>( - loc, llvm1DVectorTy, ValueRange{add}, logAttrs.getAttrs()); + auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy, + splatAttr); + auto add = LLVM::FAddOp::create(rewriter, loc, llvm1DVectorTy, + ValueRange{one, operands[0]}, + addAttrs.getAttrs()); + return LLVM::LogOp::create(rewriter, loc, llvm1DVectorTy, + ValueRange{add}, logAttrs.getAttrs()); }, rewriter); } @@ -267,15 +269,16 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> { if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) { LLVM::ConstantOp one; if (isa<VectorType>(llvmOperandType)) { - one = rewriter.create<LLVM::ConstantOp>( - loc, llvmOperandType, + one = LLVM::ConstantOp::create( + rewriter, loc, llvmOperandType, SplatElementsAttr::get(cast<ShapedType>(llvmOperandType), floatOne)); } else { - one = rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType, floatOne); + one = + LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne); } - auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand(), - sqrtAttrs.getAttrs()); + auto sqrt = LLVM::SqrtOp::create(rewriter, loc, adaptor.getOperand(), + sqrtAttrs.getAttrs()); rewriter.replaceOpWithNewOp<LLVM::FDivOp>( op, llvmOperandType, ValueRange{one, sqrt}, divAttrs.getAttrs()); return success(); @@ -292,12 +295,13 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> { mlir::VectorType::get({numElements.getKnownMinValue()}, floatType, {numElements.isScalable()}), floatOne); - auto one = - rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr); - auto sqrt = rewriter.create<LLVM::SqrtOp>( - loc, llvm1DVectorTy, operands[0], sqrtAttrs.getAttrs()); - return rewriter.create<LLVM::FDivOp>( - loc, llvm1DVectorTy, ValueRange{one, sqrt}, divAttrs.getAttrs()); + auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy, + splatAttr); + auto sqrt = LLVM::SqrtOp::create(rewriter, loc, llvm1DVectorTy, + operands[0], sqrtAttrs.getAttrs()); + return LLVM::FDivOp::create(rewriter, loc, llvm1DVectorTy, + ValueRange{one, sqrt}, + divAttrs.getAttrs()); }, rewriter); } diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp index a0ce7d3..f7c0d4f 100644 --- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp +++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp @@ -84,20 +84,21 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const { auto shape = vecType.getShape(); int64_t numElements = vecType.getNumElements(); - Value result = rewriter.create<arith::ConstantOp>( - loc, DenseElementsAttr::get( - vecType, FloatAttr::get(vecType.getElementType(), 0.0))); + Value result = arith::ConstantOp::create( + rewriter, loc, + DenseElementsAttr::get(vecType, + FloatAttr::get(vecType.getElementType(), 0.0))); SmallVector<int64_t> strides = computeStrides(shape); for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) { SmallVector<int64_t> positions = delinearize(linearIndex, strides); SmallVector<Value> operands; for (auto input : op->getOperands()) operands.push_back( - rewriter.create<vector::ExtractOp>(loc, input, positions)); + vector::ExtractOp::create(rewriter, loc, input, positions)); Value scalarOp = - rewriter.create<Op>(loc, vecType.getElementType(), operands); + Op::create(rewriter, loc, vecType.getElementType(), operands); result = - rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions); + vector::InsertOp::create(rewriter, loc, scalarOp, result, positions); } rewriter.replaceOp(op, {result}); return success(); @@ -114,9 +115,9 @@ PromoteOpToF32<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const { auto f32 = rewriter.getF32Type(); auto extendedOperands = llvm::to_vector( llvm::map_range(op->getOperands(), [&](Value operand) -> Value { - return rewriter.create<arith::ExtFOp>(loc, f32, operand); + return arith::ExtFOp::create(rewriter, loc, f32, operand); })); - auto newOp = rewriter.create<Op>(loc, f32, extendedOperands); + auto newOp = Op::create(rewriter, loc, f32, extendedOperands); rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, opType, newOp); return success(); } @@ -139,8 +140,8 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op, rewriter.setInsertionPointToStart(&module->getRegion(0).front()); auto opFunctionTy = FunctionType::get( rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); - opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name, - opFunctionTy); + opFunc = func::FuncOp::create(rewriter, rewriter.getUnknownLoc(), name, + opFunctionTy); opFunc.setPrivate(); // By definition Math dialect operations imply LLVM's "readnone" diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp index 93d8b49..df219f3 100644 --- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp +++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/MathToROCDL/MathToROCDL.h" +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -21,7 +22,6 @@ #include "../GPUCommon/GPUOpsLowering.h" #include "../GPUCommon/OpToFuncCallLowering.h" -#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" namespace mlir { #define GEN_PASS_DEF_CONVERTMATHTOROCDL @@ -31,7 +31,6 @@ namespace mlir { using namespace mlir; #define DEBUG_TYPE "math-to-rocdl" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") template <typename OpTy> static void populateOpPatterns(const LLVMTypeConverter &converter, diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp index 59db14e..a877ad2 100644 --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -36,12 +36,12 @@ static Value getScalarOrVectorI32Constant(Type type, int value, if (!vectorType.getElementType().isInteger(32)) return nullptr; SmallVector<int> values(vectorType.getNumElements(), value); - return builder.create<spirv::ConstantOp>(loc, type, - builder.getI32VectorAttr(values)); + return spirv::ConstantOp::create(builder, loc, type, + builder.getI32VectorAttr(values)); } if (type.isInteger(32)) - return builder.create<spirv::ConstantOp>(loc, type, - builder.getI32IntegerAttr(value)); + return spirv::ConstantOp::create(builder, loc, type, + builder.getI32IntegerAttr(value)); return nullptr; } @@ -144,10 +144,11 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> { Type intType = rewriter.getIntegerType(bitwidth); uint64_t intValue = uint64_t(1) << (bitwidth - 1); - Value signMask = rewriter.create<spirv::ConstantOp>( - loc, intType, rewriter.getIntegerAttr(intType, intValue)); - Value valueMask = rewriter.create<spirv::ConstantOp>( - loc, intType, rewriter.getIntegerAttr(intType, intValue - 1u)); + Value signMask = spirv::ConstantOp::create( + rewriter, loc, intType, rewriter.getIntegerAttr(intType, intValue)); + Value valueMask = spirv::ConstantOp::create( + rewriter, loc, intType, + rewriter.getIntegerAttr(intType, intValue - 1u)); if (auto vectorType = dyn_cast<VectorType>(type)) { assert(vectorType.getRank() == 1); @@ -155,26 +156,26 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> { intType = VectorType::get(count, intType); SmallVector<Value> signSplat(count, signMask); - signMask = - rewriter.create<spirv::CompositeConstructOp>(loc, intType, signSplat); + signMask = spirv::CompositeConstructOp::create(rewriter, loc, intType, + signSplat); SmallVector<Value> valueSplat(count, valueMask); - valueMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType, - valueSplat); + valueMask = spirv::CompositeConstructOp::create(rewriter, loc, intType, + valueSplat); } Value lhsCast = - rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getLhs()); + spirv::BitcastOp::create(rewriter, loc, intType, adaptor.getLhs()); Value rhsCast = - rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getRhs()); + spirv::BitcastOp::create(rewriter, loc, intType, adaptor.getRhs()); - Value value = rewriter.create<spirv::BitwiseAndOp>( - loc, intType, ValueRange{lhsCast, valueMask}); - Value sign = rewriter.create<spirv::BitwiseAndOp>( - loc, intType, ValueRange{rhsCast, signMask}); + Value value = spirv::BitwiseAndOp::create(rewriter, loc, intType, + ValueRange{lhsCast, valueMask}); + Value sign = spirv::BitwiseAndOp::create(rewriter, loc, intType, + ValueRange{rhsCast, signMask}); - Value result = rewriter.create<spirv::BitwiseOrOp>(loc, intType, - ValueRange{value, sign}); + Value result = spirv::BitwiseOrOp::create(rewriter, loc, intType, + ValueRange{value, sign}); rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result); return success(); } @@ -214,18 +215,18 @@ struct CountLeadingZerosPattern final Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc); Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc); - Value msb = rewriter.create<spirv::GLFindUMsbOp>(loc, input); + Value msb = spirv::GLFindUMsbOp::create(rewriter, loc, input); // We need to subtract from 31 given that the index returned by GLSL // FindUMsb is counted from the least significant bit. Theoretically this // also gives the correct result even if the integer has all zero bits, in // which case GL FindUMsb would return -1. - Value subMsb = rewriter.create<spirv::ISubOp>(loc, val31, msb); + Value subMsb = spirv::ISubOp::create(rewriter, loc, val31, msb); // However, certain Vulkan implementations have driver bugs for the corner // case where the input is zero. And.. it can be smart to optimize a select // only involving the corner case. So separately compute the result when the // input is either zero or one. - Value subInput = rewriter.create<spirv::ISubOp>(loc, val32, input); - Value cmp = rewriter.create<spirv::ULessThanEqualOp>(loc, input, val1); + Value subInput = spirv::ISubOp::create(rewriter, loc, val32, input); + Value cmp = spirv::ULessThanEqualOp::create(rewriter, loc, input, val1); rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, subInput, subMsb); return success(); @@ -253,7 +254,7 @@ struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> { if (!type) return failure(); - Value exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand()); + Value exp = ExpOp::create(rewriter, loc, type, adaptor.getOperand()); auto one = spirv::ConstantOp::getOne(type, loc, rewriter); rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one); return success(); @@ -283,7 +284,7 @@ struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> { auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter); Value onePlus = - rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperand()); + spirv::FAddOp::create(rewriter, loc, one, adaptor.getOperand()); rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus); return success(); } @@ -321,15 +322,15 @@ struct Log2Log10OpPattern final : public OpConversionPattern<MathLogOp> { auto getConstantValue = [&](double value) { if (auto floatType = dyn_cast<FloatType>(type)) { - return rewriter.create<spirv::ConstantOp>( - loc, type, rewriter.getFloatAttr(floatType, value)); + return spirv::ConstantOp::create( + rewriter, loc, type, rewriter.getFloatAttr(floatType, value)); } if (auto vectorType = dyn_cast<VectorType>(type)) { Type elemType = vectorType.getElementType(); if (isa<FloatType>(elemType)) { - return rewriter.create<spirv::ConstantOp>( - loc, type, + return spirv::ConstantOp::create( + rewriter, loc, type, DenseFPElementsAttr::get( vectorType, FloatAttr::get(elemType, value).getValue())); } @@ -341,7 +342,7 @@ struct Log2Log10OpPattern final : public OpConversionPattern<MathLogOp> { Value constantValue = getConstantValue( std::is_same<MathLogOp, math::Log2Op>() ? log2Reciprocal : log10Reciprocal); - Value log = rewriter.create<SpirvLogOp>(loc, adaptor.getOperand()); + Value log = SpirvLogOp::create(rewriter, loc, adaptor.getOperand()); rewriter.replaceOpWithNewOp<spirv::FMulOp>(operation, type, log, constantValue); return success(); @@ -386,7 +387,7 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> { Location loc = powfOp.getLoc(); Value zero = spirv::ConstantOp::getZero(operandType, loc, rewriter); Value lessThan = - rewriter.create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero); + spirv::FOrdLessThanOp::create(rewriter, loc, adaptor.getLhs(), zero); // Per C/C++ spec: // > pow(base, exponent) returns NaN (and raises FE_INVALID) if base is @@ -394,11 +395,11 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> { // Calculate the reminder from the exponent and check whether it is zero. Value floatOne = spirv::ConstantOp::getOne(operandType, loc, rewriter); Value expRem = - rewriter.create<spirv::FRemOp>(loc, adaptor.getRhs(), floatOne); + spirv::FRemOp::create(rewriter, loc, adaptor.getRhs(), floatOne); Value expRemNonZero = - rewriter.create<spirv::FOrdNotEqualOp>(loc, expRem, zero); + spirv::FOrdNotEqualOp::create(rewriter, loc, expRem, zero); Value cmpNegativeWithFractionalExp = - rewriter.create<spirv::LogicalAndOp>(loc, expRemNonZero, lessThan); + spirv::LogicalAndOp::create(rewriter, loc, expRemNonZero, lessThan); // Create NaN result and replace base value if conditions are met. const auto &floatSemantics = scalarFloatType.getFloatSemantics(); const auto nan = APFloat::getNaN(floatSemantics); @@ -407,10 +408,11 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> { nanAttr = DenseElementsAttr::get(vectorType, nan); Value NanValue = - rewriter.create<spirv::ConstantOp>(loc, operandType, nanAttr); - Value lhs = rewriter.create<spirv::SelectOp>( - loc, cmpNegativeWithFractionalExp, NanValue, adaptor.getLhs()); - Value abs = rewriter.create<spirv::GLFAbsOp>(loc, lhs); + spirv::ConstantOp::create(rewriter, loc, operandType, nanAttr); + Value lhs = + spirv::SelectOp::create(rewriter, loc, cmpNegativeWithFractionalExp, + NanValue, adaptor.getLhs()); + Value abs = spirv::GLFAbsOp::create(rewriter, loc, lhs); // TODO: The following just forcefully casts y into an integer value in // order to properly propagate the sign, assuming integer y cases. It @@ -418,18 +420,18 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> { // Cast exponent to integer and calculate exponent % 2 != 0. Value intRhs = - rewriter.create<spirv::ConvertFToSOp>(loc, intType, adaptor.getRhs()); + spirv::ConvertFToSOp::create(rewriter, loc, intType, adaptor.getRhs()); Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter); Value bitwiseAndOne = - rewriter.create<spirv::BitwiseAndOp>(loc, intRhs, intOne); - Value isOdd = rewriter.create<spirv::IEqualOp>(loc, bitwiseAndOne, intOne); + spirv::BitwiseAndOp::create(rewriter, loc, intRhs, intOne); + Value isOdd = spirv::IEqualOp::create(rewriter, loc, bitwiseAndOne, intOne); // calculate pow based on abs(lhs)^rhs. - Value pow = rewriter.create<spirv::GLPowOp>(loc, abs, adaptor.getRhs()); - Value negate = rewriter.create<spirv::FNegateOp>(loc, pow); + Value pow = spirv::GLPowOp::create(rewriter, loc, abs, adaptor.getRhs()); + Value negate = spirv::FNegateOp::create(rewriter, loc, pow); // if the exponent is odd and lhs < 0, negate the result. Value shouldNegate = - rewriter.create<spirv::LogicalAndOp>(loc, lessThan, isOdd); + spirv::LogicalAndOp::create(rewriter, loc, lessThan, isOdd); rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, shouldNegate, negate, pow); return success(); @@ -455,22 +457,22 @@ struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> { auto one = spirv::ConstantOp::getOne(ty, loc, rewriter); Value half; if (VectorType vty = dyn_cast<VectorType>(ty)) { - half = rewriter.create<spirv::ConstantOp>( - loc, vty, + half = spirv::ConstantOp::create( + rewriter, loc, vty, DenseElementsAttr::get(vty, rewriter.getFloatAttr(ety, 0.5).getValue())); } else { - half = rewriter.create<spirv::ConstantOp>( - loc, ty, rewriter.getFloatAttr(ety, 0.5)); + half = spirv::ConstantOp::create(rewriter, loc, ty, + rewriter.getFloatAttr(ety, 0.5)); } - auto abs = rewriter.create<spirv::GLFAbsOp>(loc, operand); - auto floor = rewriter.create<spirv::GLFloorOp>(loc, abs); - auto sub = rewriter.create<spirv::FSubOp>(loc, abs, floor); + auto abs = spirv::GLFAbsOp::create(rewriter, loc, operand); + auto floor = spirv::GLFloorOp::create(rewriter, loc, abs); + auto sub = spirv::FSubOp::create(rewriter, loc, abs, floor); auto greater = - rewriter.create<spirv::FOrdGreaterThanEqualOp>(loc, sub, half); - auto select = rewriter.create<spirv::SelectOp>(loc, greater, one, zero); - auto add = rewriter.create<spirv::FAddOp>(loc, floor, select); + spirv::FOrdGreaterThanEqualOp::create(rewriter, loc, sub, half); + auto select = spirv::SelectOp::create(rewriter, loc, greater, one, zero); + auto add = spirv::FAddOp::create(rewriter, loc, floor, select); rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp, add, operand); return success(); } diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index 0b7ffa4..6bd0e2d 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -19,10 +19,18 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeRange.h" +#include "mlir/IR/Value.h" #include "mlir/Transforms/DialectConversion.h" +#include <cstdint> using namespace mlir; +static bool isMemRefTypeLegalForEmitC(MemRefType memRefType) { + return memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() && + memRefType.getRank() != 0 && + !llvm::is_contained(memRefType.getShape(), 0); +} + namespace { /// Implement the interface to convert MemRef to EmitC. struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface { @@ -89,6 +97,68 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) { return resultTy; } +struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + Location loc = allocOp.getLoc(); + MemRefType memrefType = allocOp.getType(); + if (!isMemRefTypeLegalForEmitC(memrefType)) { + return rewriter.notifyMatchFailure( + loc, "incompatible memref type for EmitC conversion"); + } + + Type sizeTType = emitc::SizeTType::get(rewriter.getContext()); + Type elementType = memrefType.getElementType(); + IndexType indexType = rewriter.getIndexType(); + emitc::CallOpaqueOp sizeofElementOp = rewriter.create<emitc::CallOpaqueOp>( + loc, sizeTType, rewriter.getStringAttr("sizeof"), ValueRange{}, + ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)})); + + int64_t numElements = 1; + for (int64_t dimSize : memrefType.getShape()) { + numElements *= dimSize; + } + Value numElementsValue = rewriter.create<emitc::ConstantOp>( + loc, indexType, rewriter.getIndexAttr(numElements)); + + Value totalSizeBytes = rewriter.create<emitc::MulOp>( + loc, sizeTType, sizeofElementOp.getResult(0), numElementsValue); + + emitc::CallOpaqueOp allocCall; + StringAttr allocFunctionName; + Value alignmentValue; + SmallVector<Value, 2> argsVec; + if (allocOp.getAlignment()) { + allocFunctionName = rewriter.getStringAttr(alignedAllocFunctionName); + alignmentValue = rewriter.create<emitc::ConstantOp>( + loc, sizeTType, + rewriter.getIntegerAttr(indexType, + allocOp.getAlignment().value_or(0))); + argsVec.push_back(alignmentValue); + } else { + allocFunctionName = rewriter.getStringAttr(mallocFunctionName); + } + + argsVec.push_back(totalSizeBytes); + ValueRange args(argsVec); + + allocCall = rewriter.create<emitc::CallOpaqueOp>( + loc, + emitc::PointerType::get( + emitc::OpaqueType::get(rewriter.getContext(), "void")), + allocFunctionName, args); + + emitc::PointerType targetPointerType = emitc::PointerType::get(elementType); + emitc::CastOp castOp = rewriter.create<emitc::CastOp>( + loc, targetPointerType, allocCall.getResult(0)); + + rewriter.replaceOp(allocOp, castOp); + return success(); + } +}; + struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> { using OpConversionPattern::OpConversionPattern; @@ -160,8 +230,8 @@ struct ConvertGetGlobal final if (opTy.getRank() == 0) { emitc::LValueType lvalueType = emitc::LValueType::get(resultTy); - emitc::GetGlobalOp globalLValue = rewriter.create<emitc::GetGlobalOp>( - op.getLoc(), lvalueType, operands.getNameAttr()); + emitc::GetGlobalOp globalLValue = emitc::GetGlobalOp::create( + rewriter, op.getLoc(), lvalueType, operands.getNameAttr()); emitc::PointerType pointerType = emitc::PointerType::get(resultTy); rewriter.replaceOpWithNewOp<emitc::ApplyOp>( op, pointerType, rewriter.getStringAttr("&"), globalLValue); @@ -191,8 +261,8 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> { return rewriter.notifyMatchFailure(op.getLoc(), "expected array type"); } - auto subscript = rewriter.create<emitc::SubscriptOp>( - op.getLoc(), arrayValue, operands.getIndices()); + auto subscript = emitc::SubscriptOp::create( + rewriter, op.getLoc(), arrayValue, operands.getIndices()); rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript); return success(); @@ -211,8 +281,8 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> { return rewriter.notifyMatchFailure(op.getLoc(), "expected array type"); } - auto subscript = rewriter.create<emitc::SubscriptOp>( - op.getLoc(), arrayValue, operands.getIndices()); + auto subscript = emitc::SubscriptOp::create( + rewriter, op.getLoc(), arrayValue, operands.getIndices()); rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript, operands.getValue()); return success(); @@ -223,9 +293,7 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> { void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { typeConverter.addConversion( [&](MemRefType memRefType) -> std::optional<Type> { - if (!memRefType.hasStaticShape() || - !memRefType.getLayout().isIdentity() || memRefType.getRank() == 0 || - llvm::is_contained(memRefType.getShape(), 0)) { + if (!isMemRefTypeLegalForEmitC(memRefType)) { return {}; } Type convertedElementType = @@ -242,7 +310,7 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { if (inputs.size() != 1) return Value(); - return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs) .getResult(0); }; @@ -252,6 +320,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { void mlir::populateMemRefToEmitCConversionPatterns( RewritePatternSet &patterns, const TypeConverter &converter) { - patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad, - ConvertStore>(converter, patterns.getContext()); + patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal, + ConvertLoad, ConvertStore>(converter, patterns.getContext()); } diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp index cf25c09..e78dd76 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp @@ -15,6 +15,7 @@ #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Attributes.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -28,9 +29,11 @@ using namespace mlir; namespace { struct ConvertMemRefToEmitCPass : public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> { + using Base::Base; void runOnOperation() override { TypeConverter converter; - + ConvertMemRefToEmitCOptions options; + options.lowerToCpp = this->lowerToCpp; // Fallback for other types. converter.addConversion([](Type type) -> std::optional<Type> { if (!emitc::isSupportedEmitCType(type)) @@ -50,6 +53,37 @@ struct ConvertMemRefToEmitCPass if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); + + mlir::ModuleOp module = getOperation(); + module.walk([&](mlir::emitc::CallOpaqueOp callOp) { + if (callOp.getCallee() != alignedAllocFunctionName && + callOp.getCallee() != mallocFunctionName) { + return mlir::WalkResult::advance(); + } + + for (auto &op : *module.getBody()) { + emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op); + if (!includeOp) { + continue; + } + if (includeOp.getIsStandardInclude() && + ((options.lowerToCpp && + includeOp.getInclude() == cppStandardLibraryHeader) || + (!options.lowerToCpp && + includeOp.getInclude() == cStandardLibraryHeader))) { + return mlir::WalkResult::interrupt(); + } + } + + mlir::OpBuilder builder(module.getBody(), module.getBody()->begin()); + StringAttr includeAttr = + builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader + : cStandardLibraryHeader); + builder.create<mlir::emitc::IncludeOp>( + module.getLoc(), includeAttr, + /*is_standard_include=*/builder.getUnitAttr()); + return mlir::WalkResult::interrupt(); + }); } }; } // namespace diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 83681b2..dc2035b 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -24,11 +24,12 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/IRMapping.h" #include "mlir/Pass/Pass.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/MathExtras.h" + #include <optional> #define DEBUG_TYPE "memref-to-llvm" -#define DBGS() llvm::dbgs() << "[" DEBUG_TYPE "] " namespace mlir { #define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS @@ -87,12 +88,12 @@ getAlignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter, /// aligned = bumped - bumped % alignment static Value createAligned(ConversionPatternRewriter &rewriter, Location loc, Value input, Value alignment) { - Value one = rewriter.create<LLVM::ConstantOp>(loc, alignment.getType(), - rewriter.getIndexAttr(1)); - Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one); - Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump); - Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment); - return rewriter.create<LLVM::SubOp>(loc, bumped, mod); + Value one = LLVM::ConstantOp::create(rewriter, loc, alignment.getType(), + rewriter.getIndexAttr(1)); + Value bump = LLVM::SubOp::create(rewriter, loc, alignment, one); + Value bumped = LLVM::AddOp::create(rewriter, loc, input, bump); + Value mod = LLVM::URemOp::create(rewriter, loc, bumped, alignment); + return LLVM::SubOp::create(rewriter, loc, bumped, mod); } /// Computes the byte size for the MemRef element type. @@ -123,8 +124,9 @@ static Value castAllocFuncResult(ConversionPatternRewriter &rewriter, assert(succeeded(maybeMemrefAddrSpace) && "unsupported address space"); unsigned memrefAddrSpace = *maybeMemrefAddrSpace; if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace) - allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>( - loc, LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace), + allocatedPtr = LLVM::AddrSpaceCastOp::create( + rewriter, loc, + LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace), allocatedPtr); return allocatedPtr; } @@ -168,14 +170,14 @@ public: Value alignment = getAlignment(rewriter, loc, op); if (alignment) { // Adjust the allocation size to consider alignment. - sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment); + sizeBytes = LLVM::AddOp::create(rewriter, loc, sizeBytes, alignment); } // Allocate the underlying buffer. Type elementPtrType = this->getElementPtrType(memRefType); assert(elementPtrType && "could not compute element ptr type"); auto results = - rewriter.create<LLVM::CallOp>(loc, allocFuncOp.value(), sizeBytes); + LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(), sizeBytes); Value allocatedPtr = castAllocFuncResult(rewriter, loc, results.getResult(), memRefType, @@ -184,11 +186,11 @@ public: if (alignment) { // Compute the aligned pointer. Value allocatedInt = - rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr); + LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), allocatedPtr); Value alignmentInt = createAligned(rewriter, loc, allocatedInt, alignment); alignedPtr = - rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt); + LLVM::IntToPtrOp::create(rewriter, loc, elementPtrType, alignmentInt); } // Create the MemRef descriptor. @@ -268,8 +270,9 @@ public: sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment); Type elementPtrType = this->getElementPtrType(memRefType); - auto results = rewriter.create<LLVM::CallOp>( - loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes})); + auto results = + LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(), + ValueRange({allocAlignment, sizeBytes})); Value ptr = castAllocFuncResult(rewriter, loc, results.getResult(), memRefType, @@ -360,8 +363,9 @@ struct AllocaOpLowering : public ConvertOpToLLVMPattern<memref::AllocaOp> { auto elementPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace); - auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>( - loc, elementPtrType, elementType, size, op.getAlignment().value_or(0)); + auto allocatedElementPtr = + LLVM::AllocaOp::create(rewriter, loc, elementPtrType, elementType, size, + op.getAlignment().value_or(0)); // Create the MemRef descriptor. auto memRefDescriptor = this->createMemRefDescriptor( @@ -397,7 +401,7 @@ struct AllocaScopeOpLowering remainingOpsBlock, allocaScopeOp.getResultTypes(), SmallVector<Location>(allocaScopeOp->getNumResults(), allocaScopeOp.getLoc())); - rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock); + LLVM::BrOp::create(rewriter, loc, ValueRange(), remainingOpsBlock); } // Inline body region. @@ -407,8 +411,8 @@ struct AllocaScopeOpLowering // Save stack and then branch into the body of the region. rewriter.setInsertionPointToEnd(currentBlock); - auto stackSaveOp = rewriter.create<LLVM::StackSaveOp>(loc, getPtrType()); - rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody); + auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType()); + LLVM::BrOp::create(rewriter, loc, ValueRange(), beforeBody); // Replace the alloca_scope return with a branch that jumps out of the body. // Stack restore before leaving the body region. @@ -420,7 +424,7 @@ struct AllocaScopeOpLowering // Insert stack restore before jumping out the body of the region. rewriter.setInsertionPoint(branchOp); - rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp); + LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp); // Replace the op with values return from the body region. rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments()); @@ -451,11 +455,11 @@ struct AssumeAlignmentOpLowering // This is more direct than ptrtoint-based checks, is explicitly supported, // and works with non-integral address spaces. Value trueCond = - rewriter.create<LLVM::ConstantOp>(loc, rewriter.getBoolAttr(true)); + LLVM::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(true)); Value alignmentConst = createIndexAttrConstant(rewriter, loc, getIndexType(), alignment); - rewriter.create<LLVM::AssumeOp>(loc, trueCond, LLVM::AssumeAlignTag(), ptr, - alignmentConst); + LLVM::AssumeOp::create(rewriter, loc, trueCond, LLVM::AssumeAlignTag(), ptr, + alignmentConst); rewriter.replaceOp(op, memref); return success(); } @@ -559,20 +563,21 @@ private: // Get pointer to offset field of memref<element_type> descriptor. auto indexPtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace); - Value offsetPtr = rewriter.create<LLVM::GEPOp>( - loc, indexPtrTy, elementType, underlyingRankedDesc, - ArrayRef<LLVM::GEPArg>{0, 2}); + Value offsetPtr = + LLVM::GEPOp::create(rewriter, loc, indexPtrTy, elementType, + underlyingRankedDesc, ArrayRef<LLVM::GEPArg>{0, 2}); // The size value that we have to extract can be obtained using GEPop with // `dimOp.index() + 1` index argument. - Value idxPlusOne = rewriter.create<LLVM::AddOp>( - loc, createIndexAttrConstant(rewriter, loc, getIndexType(), 1), + Value idxPlusOne = LLVM::AddOp::create( + rewriter, loc, + createIndexAttrConstant(rewriter, loc, getIndexType(), 1), adaptor.getIndex()); - Value sizePtr = rewriter.create<LLVM::GEPOp>( - loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr, - idxPlusOne); - return rewriter - .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr) + Value sizePtr = LLVM::GEPOp::create(rewriter, loc, indexPtrTy, + getTypeConverter()->getIndexType(), + offsetPtr, idxPlusOne); + return LLVM::LoadOp::create(rewriter, loc, + getTypeConverter()->getIndexType(), sizePtr) .getResult(); } @@ -674,9 +679,10 @@ struct GenericAtomicRMWOpLowering auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType()); auto dataPtr = getStridedElementPtr( rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices()); - Value init = rewriter.create<LLVM::LoadOp>( - loc, typeConverter->convertType(memRefType.getElementType()), dataPtr); - rewriter.create<LLVM::BrOp>(loc, init, loopBlock); + Value init = LLVM::LoadOp::create( + rewriter, loc, typeConverter->convertType(memRefType.getElementType()), + dataPtr); + LLVM::BrOp::create(rewriter, loc, init, loopBlock); // Prepare the body of the loop block. rewriter.setInsertionPointToStart(loopBlock); @@ -696,15 +702,16 @@ struct GenericAtomicRMWOpLowering // Append the cmpxchg op to the end of the loop block. auto successOrdering = LLVM::AtomicOrdering::acq_rel; auto failureOrdering = LLVM::AtomicOrdering::monotonic; - auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>( - loc, dataPtr, loopArgument, result, successOrdering, failureOrdering); + auto cmpxchg = + LLVM::AtomicCmpXchgOp::create(rewriter, loc, dataPtr, loopArgument, + result, successOrdering, failureOrdering); // Extract the %new_loaded and %ok values from the pair. - Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 0); - Value ok = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 1); + Value newLoaded = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 0); + Value ok = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 1); // Conditionally branch to the end or back to the loop depending on %ok. - rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(), - loopBlock, newLoaded); + LLVM::CondBrOp::create(rewriter, loc, ok, endBlock, ArrayRef<Value>(), + loopBlock, newLoaded); rewriter.setInsertionPointToEnd(endBlock); @@ -796,8 +803,8 @@ public: if (!isExternal && isUninitialized) { rewriter.createBlock(&newGlobal.getInitializerRegion()); Value undef[] = { - rewriter.create<LLVM::UndefOp>(newGlobal.getLoc(), arrayTy)}; - rewriter.create<LLVM::ReturnOp>(newGlobal.getLoc(), undef); + LLVM::UndefOp::create(rewriter, newGlobal.getLoc(), arrayTy)}; + LLVM::ReturnOp::create(rewriter, newGlobal.getLoc(), undef); } return success(); } @@ -842,13 +849,13 @@ struct GetGlobalMemrefOpLowering Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter()); auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace); auto addressOf = - rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, op.getName()); + LLVM::AddressOfOp::create(rewriter, loc, ptrTy, op.getName()); // Get the address of the first element in the array by creating a GEP with // the address of the GV as the base, and (rank + 1) number of 0 indices. - auto gep = rewriter.create<LLVM::GEPOp>( - loc, ptrTy, arrayTy, addressOf, - SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0)); + auto gep = + LLVM::GEPOp::create(rewriter, loc, ptrTy, arrayTy, addressOf, + SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0)); // We do not expect the memref obtained using `memref.get_global` to be // ever deallocated. Set the allocated pointer to be known bad value to @@ -857,7 +864,7 @@ struct GetGlobalMemrefOpLowering Value deadBeefConst = createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef); auto deadBeefPtr = - rewriter.create<LLVM::IntToPtrOp>(loc, ptrTy, deadBeefConst); + LLVM::IntToPtrOp::create(rewriter, loc, ptrTy, deadBeefConst); // Both allocated and aligned pointers are same. We could potentially stash // a nullptr for the allocated pointer since we do not expect any dealloc. @@ -1009,8 +1016,8 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> { loc, adaptor.getSource(), rewriter); // rank = ConstantOp srcRank - auto rankVal = rewriter.create<LLVM::ConstantOp>( - loc, getIndexType(), rewriter.getIndexAttr(rank)); + auto rankVal = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(rank)); // poison = PoisonOp UnrankedMemRefDescriptor memRefDesc = UnrankedMemRefDescriptor::poison(rewriter, loc, targetStructType); @@ -1029,7 +1036,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> { auto ptr = memRefDesc.memRefDescPtr(rewriter, loc); // struct = LoadOp ptr - auto loadOp = rewriter.create<LLVM::LoadOp>(loc, targetStructType, ptr); + auto loadOp = LLVM::LoadOp::create(rewriter, loc, targetStructType, ptr); rewriter.replaceOp(memRefCastOp, loadOp.getResult()); } else { llvm_unreachable("Unsupported unranked memref to unranked memref cast"); @@ -1063,32 +1070,33 @@ public: MemRefDescriptor srcDesc(adaptor.getSource()); // Compute number of elements. - Value numElements = rewriter.create<LLVM::ConstantOp>( - loc, getIndexType(), rewriter.getIndexAttr(1)); + Value numElements = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(1)); for (int pos = 0; pos < srcType.getRank(); ++pos) { auto size = srcDesc.size(rewriter, loc, pos); - numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size); + numElements = LLVM::MulOp::create(rewriter, loc, numElements, size); } // Get element size. auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter); // Compute total. Value totalSize = - rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes); + LLVM::MulOp::create(rewriter, loc, numElements, sizeInBytes); Type elementType = typeConverter->convertType(srcType.getElementType()); Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc); Value srcOffset = srcDesc.offset(rewriter, loc); - Value srcPtr = rewriter.create<LLVM::GEPOp>( - loc, srcBasePtr.getType(), elementType, srcBasePtr, srcOffset); + Value srcPtr = LLVM::GEPOp::create(rewriter, loc, srcBasePtr.getType(), + elementType, srcBasePtr, srcOffset); MemRefDescriptor targetDesc(adaptor.getTarget()); Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc); Value targetOffset = targetDesc.offset(rewriter, loc); - Value targetPtr = rewriter.create<LLVM::GEPOp>( - loc, targetBasePtr.getType(), elementType, targetBasePtr, targetOffset); - rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize, - /*isVolatile=*/false); + Value targetPtr = + LLVM::GEPOp::create(rewriter, loc, targetBasePtr.getType(), elementType, + targetBasePtr, targetOffset); + LLVM::MemcpyOp::create(rewriter, loc, targetPtr, srcPtr, totalSize, + /*isVolatile=*/false); rewriter.eraseOp(op); return success(); @@ -1103,8 +1111,8 @@ public: // First make sure we have an unranked memref descriptor representation. auto makeUnranked = [&, this](Value ranked, MemRefType type) { - auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), - type.getRank()); + auto rank = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + type.getRank()); auto *typeConverter = getTypeConverter(); auto ptr = typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter); @@ -1116,7 +1124,7 @@ public: }; // Save stack position before promoting descriptors - auto stackSaveOp = rewriter.create<LLVM::StackSaveOp>(loc, getPtrType()); + auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType()); auto srcMemRefType = dyn_cast<MemRefType>(srcType); Value unrankedSource = @@ -1128,13 +1136,13 @@ public: : adaptor.getTarget(); // Now promote the unranked descriptors to the stack. - auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), - rewriter.getIndexAttr(1)); + auto one = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(1)); auto promote = [&](Value desc) { auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); auto allocated = - rewriter.create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one); - rewriter.create<LLVM::StoreOp>(loc, desc, allocated); + LLVM::AllocaOp::create(rewriter, loc, ptrType, desc.getType(), one); + LLVM::StoreOp::create(rewriter, loc, desc, allocated); return allocated; }; @@ -1149,11 +1157,11 @@ public: sourcePtr.getType(), symbolTables); if (failed(copyFn)) return failure(); - rewriter.create<LLVM::CallOp>(loc, copyFn.value(), - ValueRange{elemSize, sourcePtr, targetPtr}); + LLVM::CallOp::create(rewriter, loc, copyFn.value(), + ValueRange{elemSize, sourcePtr, targetPtr}); // Restore stack used for descriptors - rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp); + LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp); rewriter.eraseOp(op); @@ -1204,9 +1212,9 @@ struct MemorySpaceCastOpLowering MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR, descVals); descVals[0] = - rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]); + LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[0]); descVals[1] = - rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]); + LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[1]); Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(), resultTypeR, descVals); rewriter.replaceOp(op, result); @@ -1241,8 +1249,9 @@ struct MemorySpaceCastOpLowering UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(), result, resultAddrSpace, sizes); Value resultUnderlyingSize = sizes.front(); - Value resultUnderlyingDesc = rewriter.create<LLVM::AllocaOp>( - loc, getPtrType(), rewriter.getI8Type(), resultUnderlyingSize); + Value resultUnderlyingDesc = + LLVM::AllocaOp::create(rewriter, loc, getPtrType(), + rewriter.getI8Type(), resultUnderlyingSize); result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc); // Copy pointers, performing address space casts. @@ -1256,10 +1265,10 @@ struct MemorySpaceCastOpLowering Value alignedPtr = sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(), sourceUnderlyingDesc, sourceElemPtrType); - allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>( - loc, resultElemPtrType, allocatedPtr); - alignedPtr = rewriter.create<LLVM::AddrSpaceCastOp>( - loc, resultElemPtrType, alignedPtr); + allocatedPtr = LLVM::AddrSpaceCastOp::create( + rewriter, loc, resultElemPtrType, allocatedPtr); + alignedPtr = LLVM::AddrSpaceCastOp::create(rewriter, loc, + resultElemPtrType, alignedPtr); result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc, resultElemPtrType, allocatedPtr); @@ -1277,12 +1286,13 @@ struct MemorySpaceCastOpLowering int64_t bytesToSkip = 2 * llvm::divideCeil( getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8); - Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>( - loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip)); - Value copySize = rewriter.create<LLVM::SubOp>( - loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst); - rewriter.create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals, - copySize, /*isVolatile=*/false); + Value bytesToSkipConst = LLVM::ConstantOp::create( + rewriter, loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip)); + Value copySize = + LLVM::SubOp::create(rewriter, loc, getIndexType(), + resultUnderlyingSize, bytesToSkipConst); + LLVM::MemcpyOp::create(rewriter, loc, resultIndexVals, sourceIndexVals, + copySize, /*isVolatile=*/false); rewriter.replaceOp(op, ValueRange{result}); return success(); @@ -1485,7 +1495,7 @@ private: } else { Value shapeOp = reshapeOp.getShape(); Value index = createIndexAttrConstant(rewriter, loc, indexType, i); - dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index); + dimSize = memref::LoadOp::create(rewriter, loc, shapeOp, index); Type indexType = getIndexType(); if (dimSize.getType() != indexType) dimSize = typeConverter->materializeTargetConversion( @@ -1497,7 +1507,7 @@ private: desc.setStride(rewriter, loc, i, stride); // Prepare the stride value for the next dimension. - stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize); + stride = LLVM::MulOp::create(rewriter, loc, stride, dimSize); } *descriptor = desc; @@ -1522,8 +1532,9 @@ private: SmallVector<Value, 4> sizes; UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(), targetDesc, addressSpace, sizes); - Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>( - loc, getPtrType(), IntegerType::get(getContext(), 8), sizes.front()); + Value underlyingDescPtr = LLVM::AllocaOp::create( + rewriter, loc, getPtrType(), IntegerType::get(getContext(), 8), + sizes.front()); targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr); // Extract pointers and offset from the source memref. @@ -1554,7 +1565,7 @@ private: Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc); Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1); Value resultRankMinusOne = - rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex); + LLVM::SubOp::create(rewriter, loc, resultRank, oneIndex); Block *initBlock = rewriter.getInsertionBlock(); Type indexType = getTypeConverter()->getIndexType(); @@ -1568,15 +1579,15 @@ private: rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange()); rewriter.setInsertionPointToEnd(initBlock); - rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}), - condBlock); + LLVM::BrOp::create(rewriter, loc, + ValueRange({resultRankMinusOne, oneIndex}), condBlock); rewriter.setInsertionPointToStart(condBlock); Value indexArg = condBlock->getArgument(0); Value strideArg = condBlock->getArgument(1); Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0); - Value pred = rewriter.create<LLVM::ICmpOp>( - loc, IntegerType::get(rewriter.getContext(), 1), + Value pred = LLVM::ICmpOp::create( + rewriter, loc, IntegerType::get(rewriter.getContext(), 1), LLVM::ICmpPredicate::sge, indexArg, zeroIndex); Block *bodyBlock = @@ -1585,31 +1596,31 @@ private: // Copy size from shape to descriptor. auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext()); - Value sizeLoadGep = rewriter.create<LLVM::GEPOp>( - loc, llvmIndexPtrType, + Value sizeLoadGep = LLVM::GEPOp::create( + rewriter, loc, llvmIndexPtrType, typeConverter->convertType(shapeMemRefType.getElementType()), shapeOperandPtr, indexArg); - Value size = rewriter.create<LLVM::LoadOp>(loc, indexType, sizeLoadGep); + Value size = LLVM::LoadOp::create(rewriter, loc, indexType, sizeLoadGep); UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(), targetSizesBase, indexArg, size); // Write stride value and compute next one. UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(), targetStridesBase, indexArg, strideArg); - Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size); + Value nextStride = LLVM::MulOp::create(rewriter, loc, strideArg, size); // Decrement loop counter and branch back. - Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex); - rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}), - condBlock); + Value decrement = LLVM::SubOp::create(rewriter, loc, indexArg, oneIndex); + LLVM::BrOp::create(rewriter, loc, ValueRange({decrement, nextStride}), + condBlock); Block *remainder = rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint()); // Hook up the cond exit to the remainder. rewriter.setInsertionPointToEnd(condBlock); - rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, ValueRange(), - remainder, ValueRange()); + LLVM::CondBrOp::create(rewriter, loc, pred, bodyBlock, ValueRange(), + remainder, ValueRange()); // Reset position to beginning of new remainder block. rewriter.setInsertionPointToStart(remainder); @@ -1738,7 +1749,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> { return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]); if (nextSize) return runningStride - ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize) + ? LLVM::MulOp::create(rewriter, loc, runningStride, nextSize) : nextSize; assert(!runningStride); return createIndexAttrConstant(rewriter, loc, indexType, 1); @@ -1783,8 +1794,8 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> { // Field 2: Copy the actual aligned pointer to payload. Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc); - alignedPtr = rewriter.create<LLVM::GEPOp>( - loc, alignedPtr.getType(), + alignedPtr = LLVM::GEPOp::create( + rewriter, loc, alignedPtr.getType(), typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr, adaptor.getByteShift()); @@ -1838,8 +1849,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) { return LLVM::AtomicBinOp::xchg; case arith::AtomicRMWKind::maximumf: // TODO: remove this by end of 2025. - LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw maximumf changed " - "from fmax to fmaximum, expect more NaNs"); + LDBG() << "the lowering of memref.atomicrmw maximumf changed " + "from fmax to fmaximum, expect more NaNs"; return LLVM::AtomicBinOp::fmaximum; case arith::AtomicRMWKind::maxnumf: return LLVM::AtomicBinOp::fmax; @@ -1849,8 +1860,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) { return LLVM::AtomicBinOp::umax; case arith::AtomicRMWKind::minimumf: // TODO: remove this by end of 2025. - LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw minimum changed " - "from fmin to fminimum, expect more NaNs"); + LDBG() << "the lowering of memref.atomicrmw minimum changed " + "from fmin to fminimum, expect more NaNs"; return LLVM::AtomicBinOp::fminimum; case arith::AtomicRMWKind::minnumf: return LLVM::AtomicBinOp::fmin; diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index b866afb..7a70533 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -79,7 +79,8 @@ adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter, assert(indices.size() == 2); indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx); Type t = typeConverter.convertType(op.getComponentPtr().getType()); - return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices); + return spirv::AccessChainOp::create(builder, loc, t, op.getBasePtr(), + indices); } /// Casts the given `srcBool` into an integer of `dstType`. @@ -107,8 +108,8 @@ static Value shiftValue(Location loc, Value value, Value offset, Value mask, value = castBoolToIntN(loc, value, dstType, builder); } else { if (valueBits < targetBits) { - value = builder.create<spirv::UConvertOp>( - loc, builder.getIntegerType(targetBits), value); + value = spirv::UConvertOp::create( + builder, loc, builder.getIntegerType(targetBits), value); } value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask); @@ -372,8 +373,8 @@ AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor, std::string varName = std::string("__workgroup_mem__") + std::to_string(std::distance(varOps.begin(), varOps.end())); - varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName, - /*initializer=*/nullptr); + varOp = spirv::GlobalVariableOp::create(rewriter, loc, spirvType, varName, + /*initializer=*/nullptr); } // Get pointer to global variable at the current scope. @@ -572,8 +573,8 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, loadOp, "failed to determine memory requirements"); auto [memoryAccess, alignment] = *memoryRequirements; - Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain, - memoryAccess, alignment); + Value loadVal = spirv::LoadOp::create(rewriter, loc, accessChain, + memoryAccess, alignment); if (isBool) loadVal = castIntNToBool(loc, loadVal, rewriter); rewriter.replaceOp(loadOp, loadVal); @@ -601,8 +602,8 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, loadOp, "failed to determine memory requirements"); auto [memoryAccess, alignment] = *memoryRequirements; - Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr, - memoryAccess, alignment); + Value spvLoadOp = spirv::LoadOp::create(rewriter, loc, dstType, adjustedPtr, + memoryAccess, alignment); // Shift the bits to the rightmost. // ____XXXX________ -> ____________XXXX @@ -770,12 +771,12 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, if (!scope) return rewriter.notifyMatchFailure(storeOp, "atomic scope not available"); - Value result = rewriter.create<spirv::AtomicAndOp>( - loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease, - clearBitsMask); - result = rewriter.create<spirv::AtomicOrOp>( - loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease, - storeVal); + Value result = spirv::AtomicAndOp::create( + rewriter, loc, dstType, adjustedPtr, *scope, + spirv::MemorySemantics::AcquireRelease, clearBitsMask); + result = spirv::AtomicOrOp::create( + rewriter, loc, dstType, adjustedPtr, *scope, + spirv::MemorySemantics::AcquireRelease, storeVal); // The AtomicOrOp has no side effect. Since it is already inserted, we can // just remove the original StoreOp. Note that rewriter.replaceOp() @@ -850,12 +851,12 @@ LogicalResult MemorySpaceCastOpPattern::matchAndRewrite( genericPtrType = typeConverter.convertType(intermediateType); } if (sourceSc != spirv::StorageClass::Generic) { - result = - rewriter.create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result); + result = spirv::PtrCastToGenericOp::create(rewriter, loc, genericPtrType, + result); } if (resultSc != spirv::StorageClass::Generic) { result = - rewriter.create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result); + spirv::GenericCastToPtrOp::create(rewriter, loc, resultPtrType, result); } rewriter.replaceOp(addrCastOp, result); return success(); diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 80b3d85..2549a9c 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -21,19 +21,17 @@ #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include <optional> #define DEBUG_TYPE "nvgpu-to-nvvm" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define DBGSE() (llvm::dbgs()) namespace mlir { #define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS @@ -53,7 +51,7 @@ static Value truncToI32(ImplicitLocOpBuilder &b, Value value) { assert(llvm::isa<IntegerType>(type) && "expected an integer Value"); if (type.getIntOrFloatBitWidth() <= 32) return value; - return b.create<LLVM::TruncOp>(b.getI32Type(), value); + return LLVM::TruncOp::create(b, b.getI32Type(), value); } /// Returns the type for the intrinsic given the vectorResultType of the @@ -113,8 +111,8 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType, Type f32x1Ty = VectorType::get(1, f32Ty); auto makeConst = [&](int32_t index) -> Value { - return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32), - rewriter.getI32IntegerAttr(index)); + return LLVM::ConstantOp::create(rewriter, loc, IntegerType::get(ctx, 32), + rewriter.getI32IntegerAttr(index)); }; if (arrayType) { @@ -126,7 +124,7 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType, arrayType.getElementType() == f32x1Ty) { for (unsigned i = 0; i < structType.getBody().size(); i++) { Value el = - rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i); + LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i); el = rewriter.createOrFold<LLVM::BitcastOp>( loc, arrayType.getElementType(), el); elements.push_back(el); @@ -143,24 +141,24 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType, for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) { Value vec = - rewriter.create<LLVM::PoisonOp>(loc, arrayType.getElementType()); + LLVM::PoisonOp::create(rewriter, loc, arrayType.getElementType()); Value x1 = - rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i * 2); - Value x2 = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, - i * 2 + 1); - vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec, - x1, makeConst(0)); - vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec, - x2, makeConst(1)); + LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, i * 2); + Value x2 = LLVM::ExtractValueOp::create(rewriter, loc, intrinsicResult, + i * 2 + 1); + vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec, + x1, makeConst(0)); + vec = LLVM::InsertElementOp::create(rewriter, loc, vec.getType(), vec, + x2, makeConst(1)); elements.push_back(vec); } } // Create the final vectorized result. - Value result = rewriter.create<LLVM::PoisonOp>(loc, arrayType); + Value result = LLVM::PoisonOp::create(rewriter, loc, arrayType); for (const auto &el : llvm::enumerate(elements)) { - result = rewriter.create<LLVM::InsertValueOp>(loc, result, el.value(), - el.index()); + result = LLVM::InsertValueOp::create(rewriter, loc, result, el.value(), + el.index()); } return result; } @@ -187,7 +185,7 @@ static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b, auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType()); for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) { - Value toUse = b.create<LLVM::ExtractValueOp>(operand, i); + Value toUse = LLVM::ExtractValueOp::create(b, operand, i); // For 4xi8 vectors, the intrinsic expects these to be provided as i32 // scalar types. @@ -195,7 +193,7 @@ static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b, arrayTy.getElementType() == i4x8Ty || (arrayTy.getElementType() == f32x1Ty && operandPtxType == NVVM::MMATypes::tf32)) { - result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse)); + result.push_back(LLVM::BitcastOp::create(b, i32Ty, toUse)); continue; } @@ -208,9 +206,9 @@ static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b, innerArrayTy.getElementType() == f32Ty)) { for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements(); idx < innerSize; idx++) { - result.push_back(b.create<LLVM::ExtractElementOp>( - toUse, - b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(idx)))); + result.push_back(LLVM::ExtractElementOp::create( + b, toUse, + LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(idx)))); } continue; } @@ -285,8 +283,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> { Value srcPtr = getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType, adaptor.getSrcMemref(), adaptor.getIndices()); - Value ldMatrixResult = b.create<NVVM::LdMatrixOp>( - ldMatrixResultType, srcPtr, + Value ldMatrixResult = NVVM::LdMatrixOp::create( + b, ldMatrixResultType, srcPtr, /*num=*/op.getNumTiles(), /*layout=*/op.getTranspose() ? NVVM::MMALayout::col : NVVM::MMALayout::row); @@ -296,13 +294,13 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> { // actual vector type (still of width 32b) and repack them into a result // struct. Type finalResultType = typeConverter->convertType(vectorResultType); - Value result = b.create<LLVM::PoisonOp>(finalResultType); + Value result = LLVM::PoisonOp::create(b, finalResultType); for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) { Value i32Register = - num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i) + num32BitRegs > 1 ? LLVM::ExtractValueOp::create(b, ldMatrixResult, i) : ldMatrixResult; - Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register); - result = b.create<LLVM::InsertValueOp>(result, casted, i); + Value casted = LLVM::BitcastOp::create(b, innerVectorType, i32Register); + result = LLVM::InsertValueOp::create(b, result, casted, i); } rewriter.replaceOp(op, result); @@ -375,16 +373,16 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> { Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]); Type intrinsicResTy = inferIntrinsicResultType( typeConverter->convertType(op->getResultTypes()[0])); - Value intrinsicResult = b.create<NVVM::MmaOp>( - intrinsicResTy, matA, matB, matC, - /*shape=*/gemmShape, - /*b1Op=*/std::nullopt, - /*intOverflow=*/overflow, - /*multiplicandPtxTypes=*/ - std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB}, - /*multiplicandLayouts=*/ - std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row, - NVVM::MMALayout::col}); + Value intrinsicResult = + NVVM::MmaOp::create(b, intrinsicResTy, matA, matB, matC, + /*shape=*/gemmShape, + /*b1Op=*/std::nullopt, + /*intOverflow=*/overflow, + /*multiplicandPtxTypes=*/ + std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB}, + /*multiplicandLayouts=*/ + std::array<NVVM::MMALayout, 2>{ + NVVM::MMALayout::row, NVVM::MMALayout::col}); rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy, intrinsicResult, rewriter)); @@ -565,15 +563,16 @@ static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm( llvm::append_range(asmVals, args); asmVals.push_back(indexData); - return b.create<LLVM::InlineAsmOp>( - /*resultTypes=*/intrinsicResultType, - /*operands=*/asmVals, - /*asm_string=*/asmStr, - /*constraints=*/constraintStr, - /*has_side_effects=*/true, - /*is_align_stack=*/false, LLVM::TailCallKind::None, - /*asm_dialect=*/asmDialectAttr, - /*operand_attrs=*/ArrayAttr()); + return LLVM::InlineAsmOp::create(b, + /*resultTypes=*/intrinsicResultType, + /*operands=*/asmVals, + /*asm_string=*/asmStr, + /*constraints=*/constraintStr, + /*has_side_effects=*/true, + /*is_align_stack=*/false, + LLVM::TailCallKind::None, + /*asm_dialect=*/asmDialectAttr, + /*operand_attrs=*/ArrayAttr()); } /// Lowers `nvgpu.mma.sp.sync` to inline assembly. @@ -631,7 +630,7 @@ struct NVGPUMmaSparseSyncLowering return op->emitOpError() << "Expected metadata type to be LLVM " "VectorType of 2 i16 elements"; sparseMetadata = - b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata); + LLVM::BitcastOp::create(b, rewriter.getI32Type(), sparseMetadata); FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm( b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB, @@ -682,7 +681,7 @@ struct NVGPUAsyncCopyLowering // Intrinsics takes a global pointer so we need an address space cast. auto srcPointerGlobalType = LLVM::LLVMPointerType::get( op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace); - scrPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, scrPtr); + scrPtr = LLVM::AddrSpaceCastOp::create(b, srcPointerGlobalType, scrPtr); int64_t dstElements = adaptor.getDstElements().getZExtValue(); int64_t sizeInBytes = (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8; @@ -697,13 +696,13 @@ struct NVGPUAsyncCopyLowering // The rest of the DstElements in the destination (shared memory) are // filled with zeros. Value c3I32 = - b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3)); - Value bitwidth = b.create<LLVM::ConstantOp>( - b.getI32Type(), + LLVM::ConstantOp::create(b, b.getI32Type(), b.getI32IntegerAttr(3)); + Value bitwidth = LLVM::ConstantOp::create( + b, b.getI32Type(), b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth())); - Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes); - srcBytes = b.create<LLVM::LShrOp>( - b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32); + Value srcElementsI32 = LLVM::TruncOp::create(b, b.getI32Type(), srcBytes); + srcBytes = LLVM::LShrOp::create( + b, LLVM::MulOp::create(b, bitwidth, srcElementsI32), c3I32); } // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than // 16 dst bytes. @@ -712,14 +711,15 @@ struct NVGPUAsyncCopyLowering ? NVVM::LoadCacheModifierKind::CG : NVVM::LoadCacheModifierKind::CA; - b.create<NVVM::CpAsyncOp>( - dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes), + NVVM::CpAsyncOp::create( + b, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes), NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier), srcBytes); // Drop the result token. - Value zero = b.create<LLVM::ConstantOp>( - IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0)); + Value zero = + LLVM::ConstantOp::create(b, IntegerType::get(op.getContext(), 32), + rewriter.getI32IntegerAttr(0)); rewriter.replaceOp(op, zero); return success(); } @@ -733,11 +733,11 @@ struct NVGPUAsyncCreateGroupLowering LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc()); + NVVM::CpAsyncCommitGroupOp::create(rewriter, op.getLoc()); // Drop the result token. - Value zero = rewriter.create<LLVM::ConstantOp>( - op->getLoc(), IntegerType::get(op.getContext(), 32), - rewriter.getI32IntegerAttr(0)); + Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(), + IntegerType::get(op.getContext(), 32), + rewriter.getI32IntegerAttr(0)); rewriter.replaceOp(op, zero); return success(); } @@ -753,7 +753,7 @@ struct NVGPUAsyncWaitLowering ConversionPatternRewriter &rewriter) const override { // If numGroup is not present pick 0 as a conservative correct value. int32_t numGroups = adaptor.getNumGroups().value_or(0); - rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups); + NVVM::CpAsyncWaitGroupOp::create(rewriter, op.getLoc(), numGroups); rewriter.eraseOp(op); return success(); } @@ -771,8 +771,8 @@ struct NVGPUMBarrierCreateLowering SymbolTable symbolTable(moduleOp); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(&moduleOp.front()); - auto global = rewriter.create<memref::GlobalOp>( - funcOp->getLoc(), "__mbarrier", + auto global = memref::GlobalOp::create( + rewriter, funcOp->getLoc(), "__mbarrier", /*sym_visibility=*/rewriter.getStringAttr("private"), /*type=*/barrierType, /*initial_value=*/ElementsAttr(), @@ -974,7 +974,7 @@ struct NVGPUMBarrierTryWaitParityLowering adaptor.getMbarId(), rewriter); Value ticks = truncToI32(b, adaptor.getTicks()); Value phase = - b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity()); + LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity()); if (isMbarrierShared(op.getBarriers().getType())) { rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>( @@ -1063,16 +1063,16 @@ struct NVGPUGenerateWarpgroupDescriptorLowering auto ti64 = b.getIntegerType(64); auto makeConst = [&](uint64_t index) -> Value { - return b.create<LLVM::ConstantOp>(ti64, b.getI64IntegerAttr(index)); + return LLVM::ConstantOp::create(b, ti64, b.getI64IntegerAttr(index)); }; auto shiftLeft = [&](Value value, unsigned shift) -> Value { - return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift)); + return LLVM::ShlOp::create(b, ti64, value, makeConst(shift)); }; auto shiftRight = [&](Value value, unsigned shift) -> Value { - return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift)); + return LLVM::LShrOp::create(b, ti64, value, makeConst(shift)); }; auto insertBit = [&](Value desc, Value val, int startBit) { - return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit)); + return LLVM::OrOp::create(b, ti64, desc, shiftLeft(val, startBit)); }; int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0); @@ -1086,7 +1086,7 @@ struct NVGPUGenerateWarpgroupDescriptorLowering Value baseAddr = getStridedElementPtr( rewriter, op->getLoc(), cast<MemRefType>(op.getTensor().getType()), adaptor.getTensor(), {}); - Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr); + Value basePtr = LLVM::PtrToIntOp::create(b, ti64, baseAddr); // Just use 14 bits for base address Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50); @@ -1104,13 +1104,13 @@ struct NVGPUGenerateWarpgroupDescriptorLowering // // [0,14) start_address dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit); - LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: " - << "leading_off:" << leadDimVal << "\t" - << "stride_off :" << strideDimVal << "\t" - << "base_offset:" << offsetVal << "\t" - << "layout_type:" << swizzle << " (" - << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind) - << ")\n start_addr : " << baseAddr << "\n"); + LDBG() << "Generating warpgroup.descriptor: " + << "leading_off:" << leadDimVal << "\t" + << "stride_off :" << strideDimVal << "\t" + << "base_offset:" << offsetVal << "\t" + << "layout_type:" << swizzle << " (" + << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind) + << ")\n start_addr : " << baseAddr; rewriter.replaceOp(op, dsc); return success(); @@ -1118,8 +1118,8 @@ struct NVGPUGenerateWarpgroupDescriptorLowering }; static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) { - return b.create<LLVM::ConstantOp>(b.getIntegerType(64), - b.getI32IntegerAttr(index)); + return LLVM::ConstantOp::create(b, b.getIntegerType(64), + b.getI32IntegerAttr(index)); } /// Returns a Value that holds data type enum that is expected by CUDA driver. @@ -1182,12 +1182,12 @@ struct NVGPUTmaCreateDescriptorOpLowering auto promotedOperands = getTypeConverter()->promoteOperands( b.getLoc(), op->getOperands(), adaptor.getOperands(), b); - Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type, - makeI64Const(b, 5)); + Value boxArrayPtr = LLVM::AllocaOp::create( + b, llvmPointerType, llvmInt64Type, makeI64Const(b, 5)); for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) { - Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType, - boxArrayPtr, makeI64Const(b, index)); - b.create<LLVM::StoreOp>(value, gep); + Value gep = LLVM::GEPOp::create(b, llvmPointerType, llvmPointerType, + boxArrayPtr, makeI64Const(b, index)); + LLVM::StoreOp::create(b, value, gep); } nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType(); @@ -1280,8 +1280,8 @@ struct NVGPUWarpgroupMmaOpLowering } else { llvm_unreachable("msg: not supported K shape"); } - LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM - << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n"); + LDBG() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM + << ", n = " << wgmmaN << ", k = " << wgmmaK << "]"; } /// Generates WGMMATypesAttr from MLIR Type @@ -1337,7 +1337,7 @@ struct NVGPUWarpgroupMmaOpLowering /// Basic function to generate Add Value makeAdd(Value lhs, Value rhs) { - return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs); + return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs); }; /// Moves the descriptor pointer of matrix-A for the next wgmma instruction. @@ -1365,9 +1365,9 @@ struct NVGPUWarpgroupMmaOpLowering int tileShapeA = matrixTypeA.getDimSize(1); int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte; incrementVal = incrementVal >> exclude4LSB; - LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k - << "] [wgmma descriptors] Descriptor A + " - << incrementVal << " | \t "); + LDBG() << "\t\t[m: " << i << " n: " << j << " k: " << k + << "] [wgmma descriptors] Descriptor A + " << incrementVal + << " | \t "; if (!incrementVal) return desc; return makeAdd(desc, makeI64Const(b, incrementVal)); @@ -1390,7 +1390,7 @@ struct NVGPUWarpgroupMmaOpLowering int byte = elemB.getIntOrFloatBitWidth() / 8; int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte; incrementVal = incrementVal >> exclude4LSB; - LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n"); + LDBG() << "Descriptor B + " << incrementVal; if (!incrementVal) return desc; return makeAdd(desc, makeI64Const(b, incrementVal)); @@ -1399,15 +1399,14 @@ struct NVGPUWarpgroupMmaOpLowering /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix /// descriptors and arranges them based on induction variables: i, j, and k. Value generateWgmma(int i, int j, int k, Value matrixC) { - LLVM_DEBUG(DBGS() << "\t wgmma." - << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK - << "(A[" << (iterationM * wgmmaM) << ":" - << (iterationM * wgmmaM) + wgmmaM << "][" - << (iterationK * wgmmaK) << ":" - << (iterationK * wgmmaK + wgmmaK) << "] * " - << " B[" << (iterationK * wgmmaK) << ":" - << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" - << wgmmaN << "])\n"); + LDBG() << "\t wgmma." + << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK << "(A[" + << (iterationM * wgmmaM) << ":" << (iterationM * wgmmaM) + wgmmaM + << "][" << (iterationK * wgmmaK) << ":" + << (iterationK * wgmmaK + wgmmaK) << "] * " + << " B[" << (iterationK * wgmmaK) << ":" + << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" << wgmmaN + << "])"; Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k); Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k); @@ -1430,29 +1429,30 @@ struct NVGPUWarpgroupMmaOpLowering auto overflow = NVVM::MMAIntOverflowAttr::get( op->getContext(), NVVM::MMAIntOverflow::wrapped); - return b.create<NVVM::WgmmaMmaAsyncOp>( - matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA, - itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB, + return NVVM::WgmmaMmaAsyncOp::create( + b, matrixC.getType(), matrixC, descriptorA, descriptorB, shape, + itypeA, itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow); } /// Generates multiple wgmma instructions to complete the given GEMM shape Value generateWgmmaGroup() { Value wgmmaResult = - b.create<LLVM::PoisonOp>(adaptor.getMatrixC().getType()); + LLVM::PoisonOp::create(b, adaptor.getMatrixC().getType()); // Perform GEMM SmallVector<Value> wgmmaResults; for (int i = 0; i < iterationM; ++i) { - Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i); + Value matrixC = + LLVM::ExtractValueOp::create(b, adaptor.getMatrixC(), i); for (int j = 0; j < iterationN; ++j) for (int k = 0; k < iterationK; ++k) matrixC = generateWgmma(i, j, k, matrixC); wgmmaResults.push_back(matrixC); } for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) { - wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(), - wgmmaResult, matrix, idx); + wgmmaResult = LLVM::InsertValueOp::create(b, wgmmaResult.getType(), + wgmmaResult, matrix, idx); } return wgmmaResult; } @@ -1465,9 +1465,9 @@ struct NVGPUWarpgroupMmaOpLowering totalM = op.getDescriptorA().getType().getTensor().getDimSize(0); totalN = op.getDescriptorB().getType().getTensor().getDimSize(1); totalK = op.getDescriptorA().getType().getTensor().getDimSize(1); - LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN - << "] += A[" << totalM << "][" << totalK << "] * B[" - << totalK << "][" << totalN << "] ---===\n"); + LDBG() << "===--- GEMM D[" << totalM << "][" << totalN << "] += A[" + << totalM << "][" << totalK << "] * B[" << totalK << "][" << totalN + << "] ---==="; // Find the shape for one wgmma instruction findWgmmaShape( @@ -1486,10 +1486,10 @@ struct NVGPUWarpgroupMmaOpLowering /// (WgmmaGroupSyncAlignedOp) for group synchronization /// (WgmmaWaitGroupSyncOp) after the instructions. Value generateWarpgroupMma() { - b.create<NVVM::WgmmaFenceAlignedOp>(); + NVVM::WgmmaFenceAlignedOp::create(b); Value wgmmaResult = generateWgmmaGroup(); - b.create<NVVM::WgmmaGroupSyncAlignedOp>(); - b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup()); + NVVM::WgmmaGroupSyncAlignedOp::create(b); + NVVM::WgmmaWaitGroupSyncOp::create(b, op.getWaitGroup()); return wgmmaResult; } }; @@ -1557,7 +1557,7 @@ struct NVGPUWarpgroupMmaStoreOpLowering Type i32 = b.getI32Type(); auto makeConst = [&](int32_t index) -> Value { - return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index)); + return LLVM::ConstantOp::create(b, i32, b.getI32IntegerAttr(index)); }; Value c1 = makeConst(1); Value c2 = makeConst(2); @@ -1567,29 +1567,29 @@ struct NVGPUWarpgroupMmaStoreOpLowering Value warpSize = makeConst(kWarpSize); auto makeMul = [&](Value lhs, Value rhs) -> Value { - return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs); + return LLVM::MulOp::create(b, lhs.getType(), lhs, rhs); }; auto makeAdd = [&](Value lhs, Value rhs) -> Value { - return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs); + return LLVM::AddOp::create(b, lhs.getType(), lhs, rhs); }; auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y, TypedValue<::mlir::MemRefType> memref) { Type it = b.getIndexType(); - Value idx = b.create<arith::IndexCastOp>(it, x); - Value idy0 = b.create<arith::IndexCastOp>(it, y); - Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1)); - Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i); - Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1); - b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0}); - b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1}); + Value idx = arith::IndexCastOp::create(b, it, x); + Value idy0 = arith::IndexCastOp::create(b, it, y); + Value idy1 = arith::IndexCastOp::create(b, it, makeAdd(y, c1)); + Value d0 = LLVM::ExtractValueOp::create(b, wgmmaResult, i); + Value d1 = LLVM::ExtractValueOp::create(b, wgmmaResult, i + 1); + memref::StoreOp::create(b, d0, memref, ValueRange{idx, idy0}); + memref::StoreOp::create(b, d1, memref, ValueRange{idx, idy1}); }; - Value tidx = b.create<NVVM::ThreadIdXOp>(i32); - Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize); - Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize); - Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4); - Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4); + Value tidx = NVVM::ThreadIdXOp::create(b, i32); + Value laneId = LLVM::URemOp::create(b, i32, tidx, warpSize); + Value warpId = LLVM::UDivOp::create(b, i32, tidx, warpSize); + Value lane4Id = LLVM::UDivOp::create(b, i32, laneId, c4); + Value lane4modId = LLVM::URemOp::create(b, i32, laneId, c4); Value tj = makeMul(lane4modId, c2); Value ti = makeAdd(lane4Id, makeMul(warpId, c16)); @@ -1626,7 +1626,8 @@ struct NVGPUWarpgroupMmaStoreOpLowering auto stype = cast<LLVM::LLVMStructType>(matriDValue.getType()); for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) { auto structType = cast<LLVM::LLVMStructType>(matrixD); - Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx); + Value innerStructValue = + LLVM::ExtractValueOp::create(b, matriDValue, idx); storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset); offset += structType.getBody().size(); } @@ -1648,23 +1649,23 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front()) .getBody() .front(); - Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType)); - Value packStruct = b.create<LLVM::PoisonOp>(packStructType); + Value zero = LLVM::ConstantOp::create(b, elemType, b.getZeroAttr(elemType)); + Value packStruct = LLVM::PoisonOp::create(b, packStructType); SmallVector<Value> innerStructs; // Unpack the structs and set all values to zero for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) { auto structType = cast<LLVM::LLVMStructType>(s); - Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx); + Value structValue = LLVM::ExtractValueOp::create(b, packStruct, idx); for (unsigned i = 0; i < structType.getBody().size(); ++i) { - structValue = b.create<LLVM::InsertValueOp>( - structType, structValue, zero, ArrayRef<int64_t>({i})); + structValue = LLVM::InsertValueOp::create(b, structType, structValue, + zero, ArrayRef<int64_t>({i})); } innerStructs.push_back(structValue); } // Pack the inner structs into a single struct for (auto [idx, matrix] : llvm::enumerate(innerStructs)) { - packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(), - packStruct, matrix, idx); + packStruct = LLVM::InsertValueOp::create(b, packStruct.getType(), + packStruct, matrix, idx); } rewriter.replaceOp(op, packStruct); return success(); @@ -1681,7 +1682,7 @@ struct NVGPUTmaFenceOpLowering ImplicitLocOpBuilder b(op->getLoc(), rewriter); auto i32Ty = b.getI32Type(); Value tensormapSize = - b.create<LLVM::ConstantOp>(i32Ty, rewriter.getI32IntegerAttr(128)); + LLVM::ConstantOp::create(b, i32Ty, rewriter.getI32IntegerAttr(128)); auto memscope = NVVM::MemScopeKindAttr::get(ctx, ::mlir::NVVM::MemScopeKind::SYS); @@ -1716,13 +1717,13 @@ struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> { VectorType inTy = op.getIn().getType(); // apply rcp.approx.ftz.f on each element in vector. auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) { - Value ret1DVec = b.create<LLVM::PoisonOp>(llvm1DVectorTy); + Value ret1DVec = LLVM::PoisonOp::create(b, llvm1DVectorTy); int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements(); for (int i = 0; i < numElems; i++) { - Value idx = b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(i)); - Value elem = b.create<LLVM::ExtractElementOp>(inVec, idx); - Value dst = b.create<NVVM::RcpApproxFtzF32Op>(f32Ty, elem); - ret1DVec = b.create<LLVM::InsertElementOp>(ret1DVec, dst, idx); + Value idx = LLVM::ConstantOp::create(b, i64Ty, b.getI64IntegerAttr(i)); + Value elem = LLVM::ExtractElementOp::create(b, inVec, idx); + Value dst = NVVM::RcpApproxFtzF32Op::create(b, f32Ty, elem); + ret1DVec = LLVM::InsertElementOp::create(b, ret1DVec, dst, idx); } return ret1DVec; }; diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp index 662ee9e..91788f9 100644 --- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp +++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp @@ -25,11 +25,10 @@ #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #define DEBUG_TYPE "nvvm-to-llvm" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define DBGSNL() (llvm::dbgs() << "\n") namespace mlir { #define GEN_PASS_DEF_CONVERTNVVMTOLLVMPASS @@ -52,17 +51,17 @@ struct PtxLowering LogicalResult matchAndRewrite(BasicPtxBuilderInterface op, PatternRewriter &rewriter) const override { if (op.hasIntrinsic()) { - LLVM_DEBUG(DBGS() << "Ptx Builder does not lower \n\t" << op << "\n"); + LDBG() << "Ptx Builder does not lower \n\t" << op; return failure(); } SmallVector<std::pair<Value, PTXRegisterMod>> asmValues; - LLVM_DEBUG(DBGS() << op.getPtx() << "\n"); + LDBG() << op.getPtx(); PtxBuilder generator(op, rewriter); op.getAsmValues(rewriter, asmValues); for (auto &[asmValue, modifier] : asmValues) { - LLVM_DEBUG(DBGSNL() << asmValue << "\t Modifier : " << &modifier); + LDBG() << asmValue << "\t Modifier : " << &modifier; generator.insertValue(asmValue, modifier); } diff --git a/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp index 479725a..f5b3689 100644 --- a/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp +++ b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp @@ -39,8 +39,8 @@ class ExpandIfCondition : public OpRewritePattern<OpTy> { IntegerAttr constAttr; if (!matchPattern(op.getIfCond(), m_Constant(&constAttr))) { - auto ifOp = rewriter.create<scf::IfOp>(op.getLoc(), TypeRange(), - op.getIfCond(), false); + auto ifOp = scf::IfOp::create(rewriter, op.getLoc(), TypeRange(), + op.getIfCond(), false); rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); }); auto thenBodyBuilder = ifOp.getThenBodyBuilder(rewriter.getListener()); thenBodyBuilder.clone(*op.getOperation()); diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp index 7ac9687..021e31a 100644 --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -95,8 +95,8 @@ struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> { } // Create new operation. - auto newOp = rewriter.create<T>(op.getLoc(), resTypes, convertedOperands, - convertedAttrs); + auto newOp = T::create(rewriter, op.getLoc(), resTypes, convertedOperands, + convertedAttrs); // Translate regions. for (auto [originalRegion, convertedRegion] : diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp index 7d20109..b711e33 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -196,7 +196,7 @@ Block *PatternLowering::generateMatcher(MatcherNode &node, Region ®ion, // finalize. if (isa<ExitNode>(node)) { builder.setInsertionPointToEnd(block); - builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc()); + pdl_interp::FinalizeOp::create(builder, matcherFunc.getLoc()); return block; } @@ -272,8 +272,8 @@ Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { auto *operationPos = cast<OperationPosition>(pos); if (operationPos->isOperandDefiningOp()) // Standard (downward) traversal which directly follows the defining op. - value = builder.create<pdl_interp::GetDefiningOpOp>( - loc, builder.getType<pdl::OperationType>(), parentVal); + value = pdl_interp::GetDefiningOpOp::create( + builder, loc, builder.getType<pdl::OperationType>(), parentVal); else // A passthrough operation position. value = parentVal; @@ -287,23 +287,23 @@ Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { // requested to use a representative value (e.g., upward traversal). if (isa<pdl::RangeType>(parentVal.getType()) && usersPos->useRepresentative()) - value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0); + value = pdl_interp::ExtractOp::create(builder, loc, parentVal, 0); else value = parentVal; // The second operation retrieves the users. - value = builder.create<pdl_interp::GetUsersOp>(loc, value); + value = pdl_interp::GetUsersOp::create(builder, loc, value); break; } case Predicates::ForEachPos: { assert(!failureBlockStack.empty() && "expected valid failure block"); - auto foreach = builder.create<pdl_interp::ForEachOp>( - loc, parentVal, failureBlockStack.back(), /*initLoop=*/true); + auto foreach = pdl_interp::ForEachOp::create( + builder, loc, parentVal, failureBlockStack.back(), /*initLoop=*/true); value = foreach.getLoopVariable(); // Create the continuation block. Block *continueBlock = builder.createBlock(&foreach.getRegion()); - builder.create<pdl_interp::ContinueOp>(loc); + pdl_interp::ContinueOp::create(builder, loc); failureBlockStack.push_back(continueBlock); currentBlock = &foreach.getRegion().front(); @@ -311,62 +311,64 @@ Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { } case Predicates::OperandPos: { auto *operandPos = cast<OperandPosition>(pos); - value = builder.create<pdl_interp::GetOperandOp>( - loc, builder.getType<pdl::ValueType>(), parentVal, + value = pdl_interp::GetOperandOp::create( + builder, loc, builder.getType<pdl::ValueType>(), parentVal, operandPos->getOperandNumber()); break; } case Predicates::OperandGroupPos: { auto *operandPos = cast<OperandGroupPosition>(pos); Type valueTy = builder.getType<pdl::ValueType>(); - value = builder.create<pdl_interp::GetOperandsOp>( - loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, + value = pdl_interp::GetOperandsOp::create( + builder, loc, + operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, parentVal, operandPos->getOperandGroupNumber()); break; } case Predicates::AttributePos: { auto *attrPos = cast<AttributePosition>(pos); - value = builder.create<pdl_interp::GetAttributeOp>( - loc, builder.getType<pdl::AttributeType>(), parentVal, + value = pdl_interp::GetAttributeOp::create( + builder, loc, builder.getType<pdl::AttributeType>(), parentVal, attrPos->getName().strref()); break; } case Predicates::TypePos: { if (isa<pdl::AttributeType>(parentVal.getType())) - value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal); + value = pdl_interp::GetAttributeTypeOp::create(builder, loc, parentVal); else - value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal); + value = pdl_interp::GetValueTypeOp::create(builder, loc, parentVal); break; } case Predicates::ResultPos: { auto *resPos = cast<ResultPosition>(pos); - value = builder.create<pdl_interp::GetResultOp>( - loc, builder.getType<pdl::ValueType>(), parentVal, + value = pdl_interp::GetResultOp::create( + builder, loc, builder.getType<pdl::ValueType>(), parentVal, resPos->getResultNumber()); break; } case Predicates::ResultGroupPos: { auto *resPos = cast<ResultGroupPosition>(pos); Type valueTy = builder.getType<pdl::ValueType>(); - value = builder.create<pdl_interp::GetResultsOp>( - loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, + value = pdl_interp::GetResultsOp::create( + builder, loc, + resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, parentVal, resPos->getResultGroupNumber()); break; } case Predicates::AttributeLiteralPos: { auto *attrPos = cast<AttributeLiteralPosition>(pos); - value = - builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue()); + value = pdl_interp::CreateAttributeOp::create(builder, loc, + attrPos->getValue()); break; } case Predicates::TypeLiteralPos: { auto *typePos = cast<TypeLiteralPosition>(pos); Attribute rawTypeAttr = typePos->getValue(); if (TypeAttr typeAttr = dyn_cast<TypeAttr>(rawTypeAttr)) - value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr); + value = pdl_interp::CreateTypeOp::create(builder, loc, typeAttr); else - value = builder.create<pdl_interp::CreateTypesOp>( - loc, cast<ArrayAttr>(rawTypeAttr)); + value = pdl_interp::CreateTypesOp::create(builder, loc, + cast<ArrayAttr>(rawTypeAttr)); break; } case Predicates::ConstraintResultPos: { @@ -413,56 +415,59 @@ void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock, Predicates::Kind kind = question->getKind(); switch (kind) { case Predicates::IsNotNullQuestion: - builder.create<pdl_interp::IsNotNullOp>(loc, val, success, failure); + pdl_interp::IsNotNullOp::create(builder, loc, val, success, failure); break; case Predicates::OperationNameQuestion: { auto *opNameAnswer = cast<OperationNameAnswer>(answer); - builder.create<pdl_interp::CheckOperationNameOp>( - loc, val, opNameAnswer->getValue().getStringRef(), success, failure); + pdl_interp::CheckOperationNameOp::create( + builder, loc, val, opNameAnswer->getValue().getStringRef(), success, + failure); break; } case Predicates::TypeQuestion: { auto *ans = cast<TypeAnswer>(answer); if (isa<pdl::RangeType>(val.getType())) - builder.create<pdl_interp::CheckTypesOp>( - loc, val, llvm::cast<ArrayAttr>(ans->getValue()), success, failure); + pdl_interp::CheckTypesOp::create(builder, loc, val, + llvm::cast<ArrayAttr>(ans->getValue()), + success, failure); else - builder.create<pdl_interp::CheckTypeOp>( - loc, val, llvm::cast<TypeAttr>(ans->getValue()), success, failure); + pdl_interp::CheckTypeOp::create(builder, loc, val, + llvm::cast<TypeAttr>(ans->getValue()), + success, failure); break; } case Predicates::AttributeQuestion: { auto *ans = cast<AttributeAnswer>(answer); - builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(), - success, failure); + pdl_interp::CheckAttributeOp::create(builder, loc, val, ans->getValue(), + success, failure); break; } case Predicates::OperandCountAtLeastQuestion: case Predicates::OperandCountQuestion: - builder.create<pdl_interp::CheckOperandCountOp>( - loc, val, cast<UnsignedAnswer>(answer)->getValue(), + pdl_interp::CheckOperandCountOp::create( + builder, loc, val, cast<UnsignedAnswer>(answer)->getValue(), /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion, success, failure); break; case Predicates::ResultCountAtLeastQuestion: case Predicates::ResultCountQuestion: - builder.create<pdl_interp::CheckResultCountOp>( - loc, val, cast<UnsignedAnswer>(answer)->getValue(), + pdl_interp::CheckResultCountOp::create( + builder, loc, val, cast<UnsignedAnswer>(answer)->getValue(), /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion, success, failure); break; case Predicates::EqualToQuestion: { bool trueAnswer = isa<TrueAnswer>(answer); - builder.create<pdl_interp::AreEqualOp>(loc, val, args.front(), - trueAnswer ? success : failure, - trueAnswer ? failure : success); + pdl_interp::AreEqualOp::create(builder, loc, val, args.front(), + trueAnswer ? success : failure, + trueAnswer ? failure : success); break; } case Predicates::ConstraintQuestion: { auto *cstQuestion = cast<ConstraintQuestion>(question); - auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>( - loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args, - cstQuestion->getIsNegated(), success, failure); + auto applyConstraintOp = pdl_interp::ApplyConstraintOp::create( + builder, loc, cstQuestion->getResultTypes(), cstQuestion->getName(), + args, cstQuestion->getIsNegated(), success, failure); constraintOpMap.insert({cstQuestion, applyConstraintOp}); break; @@ -487,7 +492,7 @@ static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder, blocks.push_back(it.second); values.push_back(cast<PredT>(it.first)->getValue()); } - builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks); + OpT::create(builder, val.getLoc(), val, values, defaultDest, blocks); } void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock, @@ -536,12 +541,14 @@ void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock, unsigned ans = cast<UnsignedAnswer>(child.first)->getValue(); switch (kind) { case Predicates::OperandCountAtLeastQuestion: - builder.create<pdl_interp::CheckOperandCountOp>( - loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); + pdl_interp::CheckOperandCountOp::create(builder, loc, val, ans, + /*compareAtLeast=*/true, + childBlock, defaultDest); break; case Predicates::ResultCountAtLeastQuestion: - builder.create<pdl_interp::CheckResultCountOp>( - loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); + pdl_interp::CheckResultCountOp::create(builder, loc, val, ans, + /*compareAtLeast=*/true, + childBlock, defaultDest); break; default: llvm_unreachable("Generating invalid AtLeast operation"); @@ -619,8 +626,8 @@ void PatternLowering::generate(SuccessNode *successNode, Block *¤tBlock) { rootKindAttr = builder.getStringAttr(*rootKind); builder.setInsertionPointToEnd(currentBlock); - auto matchOp = builder.create<pdl_interp::RecordMatchOp>( - pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(), + auto matchOp = pdl_interp::RecordMatchOp::create( + builder, pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(), rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(), failureBlockStack.back()); @@ -632,8 +639,8 @@ void PatternLowering::generate(SuccessNode *successNode, Block *¤tBlock) { SymbolRefAttr PatternLowering::generateRewriter( pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) { builder.setInsertionPointToEnd(rewriterModule.getBody()); - auto rewriterFunc = builder.create<pdl_interp::FuncOp>( - pattern.getLoc(), "pdl_generated_rewriter", + auto rewriterFunc = pdl_interp::FuncOp::create( + builder, pattern.getLoc(), "pdl_generated_rewriter", builder.getFunctionType({}, {})); rewriterSymbolTable.insert(rewriterFunc); @@ -651,18 +658,18 @@ SymbolRefAttr PatternLowering::generateRewriter( Operation *oldOp = oldValue.getDefiningOp(); if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) { if (Attribute value = attrOp.getValueAttr()) { - return newValue = builder.create<pdl_interp::CreateAttributeOp>( - attrOp.getLoc(), value); + return newValue = pdl_interp::CreateAttributeOp::create( + builder, attrOp.getLoc(), value); } } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) { if (TypeAttr type = typeOp.getConstantTypeAttr()) { - return newValue = builder.create<pdl_interp::CreateTypeOp>( - typeOp.getLoc(), type); + return newValue = pdl_interp::CreateTypeOp::create( + builder, typeOp.getLoc(), type); } } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) { if (ArrayAttr type = typeOp.getConstantTypesAttr()) { - return newValue = builder.create<pdl_interp::CreateTypesOp>( - typeOp.getLoc(), typeOp.getType(), type); + return newValue = pdl_interp::CreateTypesOp::create( + builder, typeOp.getLoc(), typeOp.getType(), type); } } @@ -684,8 +691,9 @@ SymbolRefAttr PatternLowering::generateRewriter( auto mappedArgs = llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue); args.append(mappedArgs.begin(), mappedArgs.end()); - builder.create<pdl_interp::ApplyRewriteOp>( - rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args); + pdl_interp::ApplyRewriteOp::create(builder, rewriter.getLoc(), + /*resultTypes=*/TypeRange(), rewriteName, + args); } else { // Otherwise this is a dag rewriter defined using PDL operations. for (Operation &rewriteOp : *rewriter.getBody()) { @@ -703,7 +711,7 @@ SymbolRefAttr PatternLowering::generateRewriter( llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()), /*results=*/{})); - builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc()); + pdl_interp::FinalizeOp::create(builder, rewriter.getLoc()); return SymbolRefAttr::get( builder.getContext(), pdl_interp::PDLInterpDialect::getRewriterModuleName(), @@ -716,9 +724,9 @@ void PatternLowering::generateRewriter( SmallVector<Value, 2> arguments; for (Value argument : rewriteOp.getArgs()) arguments.push_back(mapRewriteValue(argument)); - auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>( - rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.getNameAttr(), - arguments); + auto interpOp = pdl_interp::ApplyRewriteOp::create( + builder, rewriteOp.getLoc(), rewriteOp.getResultTypes(), + rewriteOp.getNameAttr(), arguments); for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults())) rewriteValues[std::get<0>(it)] = std::get<1>(it); } @@ -726,16 +734,16 @@ void PatternLowering::generateRewriter( void PatternLowering::generateRewriter( pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues, function_ref<Value(Value)> mapRewriteValue) { - Value newAttr = builder.create<pdl_interp::CreateAttributeOp>( - attrOp.getLoc(), attrOp.getValueAttr()); + Value newAttr = pdl_interp::CreateAttributeOp::create( + builder, attrOp.getLoc(), attrOp.getValueAttr()); rewriteValues[attrOp] = newAttr; } void PatternLowering::generateRewriter( pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues, function_ref<Value(Value)> mapRewriteValue) { - builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(), - mapRewriteValue(eraseOp.getOpValue())); + pdl_interp::EraseOp::create(builder, eraseOp.getLoc(), + mapRewriteValue(eraseOp.getOpValue())); } void PatternLowering::generateRewriter( @@ -756,9 +764,9 @@ void PatternLowering::generateRewriter( // Create the new operation. Location loc = operationOp.getLoc(); - Value createdOp = builder.create<pdl_interp::CreateOperationOp>( - loc, *operationOp.getOpName(), types, hasInferredResultTypes, operands, - attributes, operationOp.getAttributeValueNames()); + Value createdOp = pdl_interp::CreateOperationOp::create( + builder, loc, *operationOp.getOpName(), types, hasInferredResultTypes, + operands, attributes, operationOp.getAttributeValueNames()); rewriteValues[operationOp.getOp()] = createdOp; // Generate accesses for any results that have their types constrained. @@ -768,8 +776,8 @@ void PatternLowering::generateRewriter( if (resultTys.size() == 1 && isa<pdl::RangeType>(resultTys[0].getType())) { Value &type = rewriteValues[resultTys[0]]; if (!type) { - auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp); - type = builder.create<pdl_interp::GetValueTypeOp>(loc, results); + auto results = pdl_interp::GetResultsOp::create(builder, loc, createdOp); + type = pdl_interp::GetValueTypeOp::create(builder, loc, results); } return; } @@ -789,12 +797,13 @@ void PatternLowering::generateRewriter( // groups because the exact index of the result is not statically known. Value resultVal; if (seenVariableLength) - resultVal = builder.create<pdl_interp::GetResultsOp>( - loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index()); + resultVal = pdl_interp::GetResultsOp::create( + builder, loc, isVariadic ? valueRangeTy : valueTy, createdOp, + it.index()); else - resultVal = builder.create<pdl_interp::GetResultOp>( - loc, valueTy, createdOp, it.index()); - type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal); + resultVal = pdl_interp::GetResultOp::create(builder, loc, valueTy, + createdOp, it.index()); + type = pdl_interp::GetValueTypeOp::create(builder, loc, resultVal); } } @@ -804,8 +813,8 @@ void PatternLowering::generateRewriter( SmallVector<Value, 4> replOperands; for (Value operand : rangeOp.getArguments()) replOperands.push_back(mapRewriteValue(operand)); - rewriteValues[rangeOp] = builder.create<pdl_interp::CreateRangeOp>( - rangeOp.getLoc(), rangeOp.getType(), replOperands); + rewriteValues[rangeOp] = pdl_interp::CreateRangeOp::create( + builder, rangeOp.getLoc(), rangeOp.getType(), replOperands); } void PatternLowering::generateRewriter( @@ -820,8 +829,8 @@ void PatternLowering::generateRewriter( // Don't use replace if we know the replaced operation has no results. auto opOp = replaceOp.getOpValue().getDefiningOp<pdl::OperationOp>(); if (!opOp || !opOp.getTypeValues().empty()) { - replOperands.push_back(builder.create<pdl_interp::GetResultsOp>( - replOp.getLoc(), mapRewriteValue(replOp))); + replOperands.push_back(pdl_interp::GetResultsOp::create( + builder, replOp.getLoc(), mapRewriteValue(replOp))); } } else { for (Value operand : replaceOp.getReplValues()) @@ -830,29 +839,29 @@ void PatternLowering::generateRewriter( // If there are no replacement values, just create an erase instead. if (replOperands.empty()) { - builder.create<pdl_interp::EraseOp>( - replaceOp.getLoc(), mapRewriteValue(replaceOp.getOpValue())); + pdl_interp::EraseOp::create(builder, replaceOp.getLoc(), + mapRewriteValue(replaceOp.getOpValue())); return; } - builder.create<pdl_interp::ReplaceOp>(replaceOp.getLoc(), - mapRewriteValue(replaceOp.getOpValue()), - replOperands); + pdl_interp::ReplaceOp::create(builder, replaceOp.getLoc(), + mapRewriteValue(replaceOp.getOpValue()), + replOperands); } void PatternLowering::generateRewriter( pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues, function_ref<Value(Value)> mapRewriteValue) { - rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>( - resultOp.getLoc(), builder.getType<pdl::ValueType>(), + rewriteValues[resultOp] = pdl_interp::GetResultOp::create( + builder, resultOp.getLoc(), builder.getType<pdl::ValueType>(), mapRewriteValue(resultOp.getParent()), resultOp.getIndex()); } void PatternLowering::generateRewriter( pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues, function_ref<Value(Value)> mapRewriteValue) { - rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>( - resultOp.getLoc(), resultOp.getType(), + rewriteValues[resultOp] = pdl_interp::GetResultsOp::create( + builder, resultOp.getLoc(), resultOp.getType(), mapRewriteValue(resultOp.getParent()), resultOp.getIndex()); } @@ -863,7 +872,7 @@ void PatternLowering::generateRewriter( // type. if (TypeAttr typeAttr = typeOp.getConstantTypeAttr()) { rewriteValues[typeOp] = - builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr); + pdl_interp::CreateTypeOp::create(builder, typeOp.getLoc(), typeAttr); } } @@ -873,8 +882,8 @@ void PatternLowering::generateRewriter( // If the type isn't constant, the users (e.g. OperationOp) will resolve this // type. if (ArrayAttr typeAttr = typeOp.getConstantTypesAttr()) { - rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>( - typeOp.getLoc(), typeOp.getType(), typeAttr); + rewriteValues[typeOp] = pdl_interp::CreateTypesOp::create( + builder, typeOp.getLoc(), typeOp.getType(), typeAttr); } } @@ -939,10 +948,10 @@ void PatternLowering::generateOperationResultTypeRewriter( !replacedOp->isBeforeInBlock(op)) continue; - Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>( - replacedOp->getLoc(), mapRewriteValue(replOpVal)); - types.push_back(builder.create<pdl_interp::GetValueTypeOp>( - replacedOp->getLoc(), replacedOpResults)); + Value replacedOpResults = pdl_interp::GetResultsOp::create( + builder, replacedOp->getLoc(), mapRewriteValue(replOpVal)); + types.push_back(pdl_interp::GetValueTypeOp::create( + builder, replacedOp->getLoc(), replacedOpResults)); return; } @@ -985,16 +994,18 @@ void PDLToPDLInterpPass::runOnOperation() { // Create the main matcher function This function contains all of the match // related functionality from patterns in the module. OpBuilder builder = OpBuilder::atBlockBegin(module.getBody()); - auto matcherFunc = builder.create<pdl_interp::FuncOp>( - module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(), + auto matcherFunc = pdl_interp::FuncOp::create( + builder, module.getLoc(), + pdl_interp::PDLInterpDialect::getMatcherFunctionName(), builder.getFunctionType(builder.getType<pdl::OperationType>(), /*results=*/{}), /*attrs=*/ArrayRef<NamedAttribute>()); // Create a nested module to hold the functions invoked for rewriting the IR // after a successful match. - ModuleOp rewriterModule = builder.create<ModuleOp>( - module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName()); + ModuleOp rewriterModule = + ModuleOp::create(builder, module.getLoc(), + pdl_interp::PDLInterpDialect::getRewriterModuleName()); // Generate the code for the patterns within the module. PatternLowering generator(matcherFunc, rewriterModule, configMap); diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp index 3e434ea..5bd1d49 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -49,7 +49,7 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList, assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type"); predList.emplace_back(pos, builder.getIsNotNull()); - if (auto attr = dyn_cast<pdl::AttributeOp>(val.getDefiningOp())) { + if (auto attr = val.getDefiningOp<pdl::AttributeOp>()) { // If the attribute has a type or value, add a constraint. if (Value type = attr.getValueType()) getTreePredicates(predList, type, builder, inputs, builder.getType(pos)); diff --git a/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp b/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp index e1a9fa59..2d9c661f 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp @@ -14,9 +14,7 @@ #include "RootOrdering.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SmallVector.h" -#include <queue> #include <utility> using namespace mlir; diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp index 0df91a2..807be7e 100644 --- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp +++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp @@ -340,7 +340,7 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp, Operation *terminator = lastBodyBlock->getTerminator(); rewriter.setInsertionPointToEnd(lastBodyBlock); auto step = forOp.getStep(); - auto stepped = rewriter.create<arith::AddIOp>(loc, iv, step).getResult(); + auto stepped = arith::AddIOp::create(rewriter, loc, iv, step).getResult(); if (!stepped) return failure(); @@ -348,7 +348,7 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp, loopCarried.push_back(stepped); loopCarried.append(terminator->operand_begin(), terminator->operand_end()); auto branchOp = - rewriter.create<cf::BranchOp>(loc, conditionBlock, loopCarried); + cf::BranchOp::create(rewriter, loc, conditionBlock, loopCarried); // Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the // llvm.loop_annotation attribute. @@ -375,16 +375,15 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp, SmallVector<Value, 8> destOperands; destOperands.push_back(lowerBound); llvm::append_range(destOperands, forOp.getInitArgs()); - rewriter.create<cf::BranchOp>(loc, conditionBlock, destOperands); + cf::BranchOp::create(rewriter, loc, conditionBlock, destOperands); // With the body block done, we can fill in the condition block. rewriter.setInsertionPointToEnd(conditionBlock); - auto comparison = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::slt, iv, upperBound); + auto comparison = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::slt, iv, upperBound); - rewriter.create<cf::CondBranchOp>(loc, comparison, firstBodyBlock, - ArrayRef<Value>(), endBlock, - ArrayRef<Value>()); + cf::CondBranchOp::create(rewriter, loc, comparison, firstBodyBlock, + ArrayRef<Value>(), endBlock, ArrayRef<Value>()); // The result of the loop operation is the values of the condition block // arguments except the induction variable on the last iteration. @@ -409,7 +408,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, continueBlock = rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes(), SmallVector<Location>(ifOp.getNumResults(), loc)); - rewriter.create<cf::BranchOp>(loc, remainingOpsBlock); + cf::BranchOp::create(rewriter, loc, remainingOpsBlock); } // Move blocks from the "then" region to the region containing 'scf.if', @@ -419,7 +418,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, Operation *thenTerminator = thenRegion.back().getTerminator(); ValueRange thenTerminatorOperands = thenTerminator->getOperands(); rewriter.setInsertionPointToEnd(&thenRegion.back()); - rewriter.create<cf::BranchOp>(loc, continueBlock, thenTerminatorOperands); + cf::BranchOp::create(rewriter, loc, continueBlock, thenTerminatorOperands); rewriter.eraseOp(thenTerminator); rewriter.inlineRegionBefore(thenRegion, continueBlock); @@ -433,15 +432,15 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, Operation *elseTerminator = elseRegion.back().getTerminator(); ValueRange elseTerminatorOperands = elseTerminator->getOperands(); rewriter.setInsertionPointToEnd(&elseRegion.back()); - rewriter.create<cf::BranchOp>(loc, continueBlock, elseTerminatorOperands); + cf::BranchOp::create(rewriter, loc, continueBlock, elseTerminatorOperands); rewriter.eraseOp(elseTerminator); rewriter.inlineRegionBefore(elseRegion, continueBlock); } rewriter.setInsertionPointToEnd(condBlock); - rewriter.create<cf::CondBranchOp>(loc, ifOp.getCondition(), thenBlock, - /*trueArgs=*/ArrayRef<Value>(), elseBlock, - /*falseArgs=*/ArrayRef<Value>()); + cf::CondBranchOp::create(rewriter, loc, ifOp.getCondition(), thenBlock, + /*trueArgs=*/ArrayRef<Value>(), elseBlock, + /*falseArgs=*/ArrayRef<Value>()); // Ok, we're done! rewriter.replaceOp(ifOp, continueBlock->getArguments()); @@ -459,13 +458,14 @@ ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op, auto ®ion = op.getRegion(); rewriter.setInsertionPointToEnd(condBlock); - rewriter.create<cf::BranchOp>(loc, ®ion.front()); + cf::BranchOp::create(rewriter, loc, ®ion.front()); for (Block &block : region) { if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) { ValueRange terminatorOperands = terminator->getOperands(); rewriter.setInsertionPointToEnd(&block); - rewriter.create<cf::BranchOp>(loc, remainingOpsBlock, terminatorOperands); + cf::BranchOp::create(rewriter, loc, remainingOpsBlock, + terminatorOperands); rewriter.eraseOp(terminator); } } @@ -503,7 +503,7 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp, for (auto [iv, lower, upper, step] : llvm::zip(parallelOp.getInductionVars(), parallelOp.getLowerBound(), parallelOp.getUpperBound(), parallelOp.getStep())) { - ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step, iterArgs); + ForOp forOp = ForOp::create(rewriter, loc, lower, upper, step, iterArgs); ivs.push_back(forOp.getInductionVar()); auto iterRange = forOp.getRegionIterArgs(); iterArgs.assign(iterRange.begin(), iterRange.end()); @@ -517,7 +517,7 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp, // A loop is constructed with an empty "yield" terminator if there are // no results. rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); - rewriter.create<scf::YieldOp>(loc, forOp.getResults()); + scf::YieldOp::create(rewriter, loc, forOp.getResults()); } rewriter.setInsertionPointToStart(forOp.getBody()); @@ -549,7 +549,7 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp, // has been already created in loop construction). if (!yieldOperands.empty()) { rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); - rewriter.create<scf::YieldOp>(loc, yieldOperands); + scf::YieldOp::create(rewriter, loc, yieldOperands); } rewriter.replaceOp(parallelOp, loopResults); @@ -575,13 +575,14 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp, // Branch to the "before" region. rewriter.setInsertionPointToEnd(currentBlock); - rewriter.create<cf::BranchOp>(loc, before, whileOp.getInits()); + cf::BranchOp::create(rewriter, loc, before, whileOp.getInits()); // Replace terminators with branches. Assuming bodies are SESE, which holds // given only the patterns from this file, we only need to look at the last // block. This should be reconsidered if we allow break/continue in SCF. rewriter.setInsertionPointToEnd(before); auto condOp = cast<ConditionOp>(before->getTerminator()); + SmallVector<Value> args = llvm::to_vector(condOp.getArgs()); rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(), after, condOp.getArgs(), continuation, ValueRange()); @@ -593,7 +594,7 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp, // Replace the op with values "yielded" from the "before" region, which are // visible by dominance. - rewriter.replaceOp(whileOp, condOp.getArgs()); + rewriter.replaceOp(whileOp, args); return success(); } @@ -625,14 +626,14 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp, // Branch to the "before" region. rewriter.setInsertionPointToEnd(currentBlock); - rewriter.create<cf::BranchOp>(whileOp.getLoc(), before, whileOp.getInits()); + cf::BranchOp::create(rewriter, whileOp.getLoc(), before, whileOp.getInits()); // Loop around the "before" region based on condition. rewriter.setInsertionPointToEnd(before); auto condOp = cast<ConditionOp>(before->getTerminator()); - rewriter.create<cf::CondBranchOp>(condOp.getLoc(), condOp.getCondition(), - before, condOp.getArgs(), continuation, - ValueRange()); + cf::CondBranchOp::create(rewriter, condOp.getLoc(), condOp.getCondition(), + before, condOp.getArgs(), continuation, + ValueRange()); // Replace the op with values "yielded" from the "before" region, which are // visible by dominance. @@ -695,12 +696,12 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op, SmallVector<ValueRange> caseOperands(caseSuccessors.size(), {}); // Cast switch index to integer case value. - Value caseValue = rewriter.create<arith::IndexCastOp>( - op.getLoc(), rewriter.getI32Type(), op.getArg()); + Value caseValue = arith::IndexCastOp::create( + rewriter, op.getLoc(), rewriter.getI32Type(), op.getArg()); - rewriter.create<cf::SwitchOp>( - op.getLoc(), caseValue, *defaultBlock, ValueRange(), - rewriter.getDenseI32ArrayAttr(caseValues), caseSuccessors, caseOperands); + cf::SwitchOp::create(rewriter, op.getLoc(), caseValue, *defaultBlock, + ValueRange(), rewriter.getDenseI32ArrayAttr(caseValues), + caseSuccessors, caseOperands); rewriter.replaceOp(op, continueBlock->getArguments()); return success(); } diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp index dcb4852..84cbd86 100644 --- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp +++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp @@ -91,7 +91,7 @@ createVariablesForResults(T op, const TypeConverter *typeConverter, Type varType = emitc::LValueType::get(resultType); emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, ""); emitc::VariableOp var = - rewriter.create<emitc::VariableOp>(loc, varType, noInit); + emitc::VariableOp::create(rewriter, loc, varType, noInit); resultVariables.push_back(var); } @@ -103,14 +103,14 @@ createVariablesForResults(T op, const TypeConverter *typeConverter, static void assignValues(ValueRange values, ValueRange variables, ConversionPatternRewriter &rewriter, Location loc) { for (auto [value, var] : llvm::zip(values, variables)) - rewriter.create<emitc::AssignOp>(loc, var, value); + emitc::AssignOp::create(rewriter, loc, var, value); } SmallVector<Value> loadValues(const SmallVector<Value> &variables, PatternRewriter &rewriter, Location loc) { return llvm::map_to_vector<>(variables, [&](Value var) { Type type = cast<emitc::LValueType>(var.getType()).getValueType(); - return rewriter.create<emitc::LoadOp>(loc, type, var).getResult(); + return emitc::LoadOp::create(rewriter, loc, type, var).getResult(); }); } @@ -129,7 +129,7 @@ static LogicalResult lowerYield(Operation *op, ValueRange resultVariables, assignValues(yieldOperands, resultVariables, rewriter, loc); - rewriter.create<emitc::YieldOp>(loc); + emitc::YieldOp::create(rewriter, loc); rewriter.eraseOp(yield); return success(); @@ -164,8 +164,9 @@ ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor, assignValues(adaptor.getInitArgs(), resultVariables, rewriter, loc); - emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>( - loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep()); + emitc::ForOp loweredFor = + emitc::ForOp::create(rewriter, loc, adaptor.getLowerBound(), + adaptor.getUpperBound(), adaptor.getStep()); Block *loweredBody = loweredFor.getBody(); @@ -257,7 +258,7 @@ IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor, bool hasElseBlock = !elseRegion.empty(); auto loweredIf = - rewriter.create<emitc::IfOp>(loc, adaptor.getCondition(), false, false); + emitc::IfOp::create(rewriter, loc, adaptor.getCondition(), false, false); Region &loweredThenRegion = loweredIf.getThenRegion(); auto result = lowerRegion(thenRegion, loweredThenRegion); @@ -304,8 +305,9 @@ LogicalResult IndexSwitchOpLowering::matchAndRewrite( "create variables for results failed"); } - auto loweredSwitch = rewriter.create<emitc::SwitchOp>( - loc, adaptor.getArg(), adaptor.getCases(), indexSwitchOp.getNumCases()); + auto loweredSwitch = + emitc::SwitchOp::create(rewriter, loc, adaptor.getArg(), + adaptor.getCases(), indexSwitchOp.getNumCases()); // Lowering all case regions. for (auto pair : diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp index 844e66e..badd2f6 100644 --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp @@ -25,9 +25,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/IRMapping.h" #include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/Passes.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/Support/Debug.h" #include <optional> @@ -84,8 +82,8 @@ static Operation::operand_range getUpperBoundOperands(AffineForOp forOp) { // Get a Value that corresponds to the loop step. If the step is an attribute, // materialize a corresponding constant using builder. static Value getOrCreateStep(AffineForOp forOp, OpBuilder &builder) { - return builder.create<arith::ConstantIndexOp>(forOp.getLoc(), - forOp.getStepAsInt()); + return arith::ConstantIndexOp::create(builder, forOp.getLoc(), + forOp.getStepAsInt()); } // Get a Value for the loop lower bound. If the value requires computation, @@ -190,12 +188,12 @@ AffineLoopToGpuConverter::collectBounds(AffineForOp forOp, unsigned numLoops) { return std::nullopt; } - Value range = builder.create<arith::SubIOp>(currentLoop.getLoc(), - upperBound, lowerBound); + Value range = arith::SubIOp::create(builder, currentLoop.getLoc(), + upperBound, lowerBound); Value step = getOrCreateStep(currentLoop, builder); if (getConstantIntValue(step) != static_cast<int64_t>(1)) - range = - builder.create<arith::CeilDivSIOp>(currentLoop.getLoc(), range, step); + range = arith::CeilDivSIOp::create(builder, currentLoop.getLoc(), range, + step); dims.push_back(range); lbs.push_back(lowerBound); @@ -221,7 +219,7 @@ void AffineLoopToGpuConverter::createLaunch(AffineForOp rootForOp, // no loop mapped to a specific dimension, use constant "1" as its size. Value constOne = (numBlockDims < 3 || numThreadDims < 3) - ? builder.create<arith::ConstantIndexOp>(rootForOp.getLoc(), 1) + ? arith::ConstantIndexOp::create(builder, rootForOp.getLoc(), 1) : nullptr; Value gridSizeX = numBlockDims > 0 ? dims[0] : constOne; Value gridSizeY = numBlockDims > 1 ? dims[1] : constOne; @@ -232,9 +230,9 @@ void AffineLoopToGpuConverter::createLaunch(AffineForOp rootForOp, // Create a launch op and move the body region of the innermost loop to the // launch op. - auto launchOp = builder.create<gpu::LaunchOp>( - rootForOp.getLoc(), gridSizeX, gridSizeY, gridSizeZ, blockSizeX, - blockSizeY, blockSizeZ); + auto launchOp = + gpu::LaunchOp::create(builder, rootForOp.getLoc(), gridSizeX, gridSizeY, + gridSizeZ, blockSizeX, blockSizeY, blockSizeZ); // Replace the loop terminator (loops contain only a single block) with the // gpu terminator and move the operations from the loop body block to the gpu @@ -244,7 +242,7 @@ void AffineLoopToGpuConverter::createLaunch(AffineForOp rootForOp, Location terminatorLoc = terminator.getLoc(); terminator.erase(); builder.setInsertionPointToEnd(innermostForOp.getBody()); - builder.create<gpu::TerminatorOp>(terminatorLoc, TypeRange()); + gpu::TerminatorOp::create(builder, terminatorLoc, TypeRange()); launchOp.getBody().front().getOperations().splice( launchOp.getBody().front().begin(), innermostForOp.getBody()->getOperations()); @@ -263,10 +261,10 @@ void AffineLoopToGpuConverter::createLaunch(AffineForOp rootForOp, : getDim3Value(launchOp.getThreadIds(), en.index() - numBlockDims); Value step = steps[en.index()]; if (getConstantIntValue(step) != static_cast<int64_t>(1)) - id = builder.create<arith::MulIOp>(rootForOp.getLoc(), step, id); + id = arith::MulIOp::create(builder, rootForOp.getLoc(), step, id); Value ivReplacement = - builder.create<arith::AddIOp>(rootForOp.getLoc(), *lbArgumentIt, id); + arith::AddIOp::create(builder, rootForOp.getLoc(), *lbArgumentIt, id); en.value().replaceAllUsesWith(ivReplacement); std::advance(lbArgumentIt, 1); std::advance(stepArgumentIt, 1); @@ -319,8 +317,8 @@ static Value deriveStaticUpperBound(Value upperBound, if (auto minOp = upperBound.getDefiningOp<AffineMinOp>()) { for (const AffineExpr &result : minOp.getMap().getResults()) { if (auto constExpr = dyn_cast<AffineConstantExpr>(result)) { - return rewriter.create<arith::ConstantIndexOp>(minOp.getLoc(), - constExpr.getValue()); + return arith::ConstantIndexOp::create(rewriter, minOp.getLoc(), + constExpr.getValue()); } } } @@ -344,8 +342,8 @@ static Value deriveStaticUpperBound(Value upperBound, if ((lhs.value() < 0) != (rhs.value() < 0)) return {}; - return rewriter.create<arith::ConstantIndexOp>( - multiplyOp.getLoc(), lhs.value() * rhs.value()); + return arith::ConstantIndexOp::create(rewriter, multiplyOp.getLoc(), + lhs.value() * rhs.value()); } } @@ -422,8 +420,8 @@ static LogicalResult processParallelLoop( if (launchIndependent(val)) return val; if (auto constOp = val.getDefiningOp<arith::ConstantOp>()) - return rewriter.create<arith::ConstantOp>(constOp.getLoc(), - constOp.getValue()); + return arith::ConstantOp::create(rewriter, constOp.getLoc(), + constOp.getValue()); return {}; }; @@ -453,8 +451,8 @@ static LogicalResult processParallelLoop( 1, 2, rewriter.getAffineDimExpr(0) * rewriter.getAffineSymbolExpr(0) + rewriter.getAffineSymbolExpr(1)); - newIndex = rewriter.create<AffineApplyOp>( - loc, annotation.getMap().compose(lowerAndStep), + newIndex = AffineApplyOp::create( + rewriter, loc, annotation.getMap().compose(lowerAndStep), ValueRange{operand, ensureLaunchIndependent(step), ensureLaunchIndependent(lowerBound)}); // If there was also a bound, insert that, too. @@ -498,8 +496,8 @@ static LogicalResult processParallelLoop( 1, 2, ((rewriter.getAffineDimExpr(0) - rewriter.getAffineSymbolExpr(0)) .ceilDiv(rewriter.getAffineSymbolExpr(1)))); - Value launchBound = rewriter.create<AffineApplyOp>( - loc, annotation.getBound().compose(stepMap), + Value launchBound = AffineApplyOp::create( + rewriter, loc, annotation.getBound().compose(stepMap), ValueRange{ ensureLaunchIndependent( cloningMap.lookupOrDefault(upperBound)), @@ -517,10 +515,10 @@ static LogicalResult processParallelLoop( if (!boundIsPrecise) { // We are using an approximation, create a surrounding conditional. Value originalBound = std::get<3>(config); - arith::CmpIOp pred = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::slt, newIndex, + arith::CmpIOp pred = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::slt, newIndex, cloningMap.lookupOrDefault(originalBound)); - scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, pred, false); + scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, pred, false); rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); // Put a sentinel into the worklist so we know when to pop out of the // if body again. We use the launchOp here, as that cannot be part of @@ -530,10 +528,10 @@ static LogicalResult processParallelLoop( } } else { // Create a sequential for loop. - auto loopOp = rewriter.create<scf::ForOp>( - loc, cloningMap.lookupOrDefault(lowerBound), - cloningMap.lookupOrDefault(upperBound), - cloningMap.lookupOrDefault(step)); + auto loopOp = scf::ForOp::create(rewriter, loc, + cloningMap.lookupOrDefault(lowerBound), + cloningMap.lookupOrDefault(upperBound), + cloningMap.lookupOrDefault(step)); newIndex = loopOp.getInductionVar(); rewriter.setInsertionPointToStart(loopOp.getBody()); // Put a sentinel into the worklist so we know when to pop out of the loop @@ -608,12 +606,12 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp, // sizes. Those will be refined later as we discover them from mappings. Location loc = parallelOp.getLoc(); Value constantOne = - rewriter.create<arith::ConstantIndexOp>(parallelOp.getLoc(), 1); - gpu::LaunchOp launchOp = rewriter.create<gpu::LaunchOp>( - parallelOp.getLoc(), constantOne, constantOne, constantOne, constantOne, - constantOne, constantOne); + arith::ConstantIndexOp::create(rewriter, parallelOp.getLoc(), 1); + gpu::LaunchOp launchOp = gpu::LaunchOp::create( + rewriter, parallelOp.getLoc(), constantOne, constantOne, constantOne, + constantOne, constantOne, constantOne); rewriter.setInsertionPointToEnd(&launchOp.getBody().front()); - rewriter.create<gpu::TerminatorOp>(loc); + gpu::TerminatorOp::create(rewriter, loc); rewriter.setInsertionPointToStart(&launchOp.getBody().front()); IRMapping cloningMap; @@ -667,7 +665,7 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp, if (externalValues.size()) return failure(); // Replace by gpu.all_reduce. - auto gpuRedOp = rewriter.create<gpu::AllReduceOp>(loc, newValue); + auto gpuRedOp = gpu::AllReduceOp::create(rewriter, loc, newValue); cloningMap.map(parentLoop->getResult(0), gpuRedOp.getResult()); // Copy region. rewriter.inlineRegionBefore(reduceOp.getRegion(0), gpuRedOp.getRegion(), diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index 584ac2f..34f372a 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -187,8 +187,8 @@ createDecl(PatternRewriter &builder, SymbolTable &symbolTable, scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue) { OpBuilder::InsertionGuard guard(builder); Type type = reduce.getOperands()[reductionIndex].getType(); - auto decl = builder.create<omp::DeclareReductionOp>(reduce.getLoc(), - "__scf_reduction", type); + auto decl = omp::DeclareReductionOp::create(builder, reduce.getLoc(), + "__scf_reduction", type); symbolTable.insert(decl); builder.createBlock(&decl.getInitializerRegion(), @@ -196,8 +196,8 @@ createDecl(PatternRewriter &builder, SymbolTable &symbolTable, {reduce.getOperands()[reductionIndex].getLoc()}); builder.setInsertionPointToEnd(&decl.getInitializerRegion().back()); Value init = - builder.create<LLVM::ConstantOp>(reduce.getLoc(), type, initValue); - builder.create<omp::YieldOp>(reduce.getLoc(), init); + LLVM::ConstantOp::create(builder, reduce.getLoc(), type, initValue); + omp::YieldOp::create(builder, reduce.getLoc(), init); Operation *terminator = &reduce.getReductions()[reductionIndex].front().back(); @@ -227,12 +227,12 @@ static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder, {reduceOperandLoc, reduceOperandLoc}); Block *atomicBlock = &decl.getAtomicReductionRegion().back(); builder.setInsertionPointToEnd(atomicBlock); - Value loaded = builder.create<LLVM::LoadOp>(reduce.getLoc(), decl.getType(), - atomicBlock->getArgument(1)); - builder.create<LLVM::AtomicRMWOp>(reduce.getLoc(), atomicKind, - atomicBlock->getArgument(0), loaded, - LLVM::AtomicOrdering::monotonic); - builder.create<omp::YieldOp>(reduce.getLoc(), ArrayRef<Value>()); + Value loaded = LLVM::LoadOp::create(builder, reduce.getLoc(), decl.getType(), + atomicBlock->getArgument(1)); + LLVM::AtomicRMWOp::create(builder, reduce.getLoc(), atomicKind, + atomicBlock->getArgument(0), loaded, + LLVM::AtomicOrdering::monotonic); + omp::YieldOp::create(builder, reduce.getLoc(), ArrayRef<Value>()); return decl; } @@ -380,8 +380,9 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { // Allocate reduction variables. Make sure the we don't overflow the stack // with local `alloca`s by saving and restoring the stack pointer. Location loc = parallelOp.getLoc(); - Value one = rewriter.create<LLVM::ConstantOp>( - loc, rewriter.getIntegerType(64), rewriter.getI64IntegerAttr(1)); + Value one = + LLVM::ConstantOp::create(rewriter, loc, rewriter.getIntegerType(64), + rewriter.getI64IntegerAttr(1)); SmallVector<Value> reductionVariables; reductionVariables.reserve(parallelOp.getNumReductions()); auto ptrType = LLVM::LLVMPointerType::get(parallelOp.getContext()); @@ -390,9 +391,9 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { isa<LLVM::PointerElementTypeInterface>(init.getType())) && "cannot create a reduction variable if the type is not an LLVM " "pointer element"); - Value storage = - rewriter.create<LLVM::AllocaOp>(loc, ptrType, init.getType(), one, 0); - rewriter.create<LLVM::StoreOp>(loc, init, storage); + Value storage = LLVM::AllocaOp::create(rewriter, loc, ptrType, + init.getType(), one, 0); + LLVM::StoreOp::create(rewriter, loc, init, storage); reductionVariables.push_back(storage); } @@ -411,8 +412,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { assert(redRegion.hasOneBlock() && "expect reduction region to have one block"); Value pvtRedVar = parallelOp.getRegion().addArgument(x.getType(), loc); - Value pvtRedVal = rewriter.create<LLVM::LoadOp>(reduce.getLoc(), - rD.getType(), pvtRedVar); + Value pvtRedVal = LLVM::LoadOp::create(rewriter, reduce.getLoc(), + rD.getType(), pvtRedVar); // Make a copy of the reduction combiner region in the body mlir::OpBuilder builder(rewriter.getContext()); builder.setInsertionPoint(reduce); @@ -427,7 +428,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { assert(yieldOp && yieldOp.getResults().size() == 1 && "expect YieldOp in reduction region to return one result"); Value redVal = yieldOp.getResults()[0]; - rewriter.create<LLVM::StoreOp>(loc, redVal, pvtRedVar); + LLVM::StoreOp::create(rewriter, loc, redVal, pvtRedVar); rewriter.eraseOp(yieldOp); break; } @@ -437,12 +438,12 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { Value numThreadsVar; if (numThreads > 0) { - numThreadsVar = rewriter.create<LLVM::ConstantOp>( - loc, rewriter.getI32IntegerAttr(numThreads)); + numThreadsVar = LLVM::ConstantOp::create( + rewriter, loc, rewriter.getI32IntegerAttr(numThreads)); } // Create the parallel wrapper. - auto ompParallel = rewriter.create<omp::ParallelOp>( - loc, + auto ompParallel = omp::ParallelOp::create( + rewriter, loc, /* allocate_vars = */ llvm::SmallVector<Value>{}, /* allocator_vars = */ llvm::SmallVector<Value>{}, /* if_expr = */ Value{}, @@ -464,7 +465,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { { OpBuilder::InsertionGuard allocaGuard(rewriter); // Create worksharing loop wrapper. - auto wsloopOp = rewriter.create<omp::WsloopOp>(parallelOp.getLoc()); + auto wsloopOp = omp::WsloopOp::create(rewriter, parallelOp.getLoc()); if (!reductionVariables.empty()) { wsloopOp.setReductionSymsAttr( ArrayAttr::get(rewriter.getContext(), reductionSyms)); @@ -476,7 +477,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { wsloopOp.setReductionByref( DenseBoolArrayAttr::get(rewriter.getContext(), reductionByRef)); } - rewriter.create<omp::TerminatorOp>(loc); // omp.parallel terminator. + omp::TerminatorOp::create(rewriter, loc); // omp.parallel terminator. // The wrapper's entry block arguments will define the reduction // variables. @@ -490,8 +491,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { parallelOp.getLoc())); // Create loop nest and populate region with contents of scf.parallel. - auto loopOp = rewriter.create<omp::LoopNestOp>( - parallelOp.getLoc(), parallelOp.getLowerBound(), + auto loopOp = omp::LoopNestOp::create( + rewriter, parallelOp.getLoc(), parallelOp.getLowerBound(), parallelOp.getUpperBound(), parallelOp.getStep()); rewriter.inlineRegionBefore(parallelOp.getRegion(), loopOp.getRegion(), @@ -511,13 +512,13 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { rewriter.splitBlock(&loopOpEntryBlock, loopOpEntryBlock.begin()); rewriter.setInsertionPointToStart(&loopOpEntryBlock); - auto scope = rewriter.create<memref::AllocaScopeOp>(parallelOp.getLoc(), - TypeRange()); - rewriter.create<omp::YieldOp>(loc, ValueRange()); + auto scope = memref::AllocaScopeOp::create( + rewriter, parallelOp.getLoc(), TypeRange()); + omp::YieldOp::create(rewriter, loc, ValueRange()); Block *scopeBlock = rewriter.createBlock(&scope.getBodyRegion()); rewriter.mergeBlocks(ops, scopeBlock); rewriter.setInsertionPointToEnd(&*scope.getBodyRegion().begin()); - rewriter.create<memref::AllocaScopeReturnOp>(loc, ValueRange()); + memref::AllocaScopeReturnOp::create(rewriter, loc, ValueRange()); } } @@ -526,7 +527,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { results.reserve(reductionVariables.size()); for (auto [variable, type] : llvm::zip(reductionVariables, parallelOp.getResultTypes())) { - Value res = rewriter.create<LLVM::LoadOp>(loc, type, variable); + Value res = LLVM::LoadOp::create(rewriter, loc, type, variable); results.push_back(res); } rewriter.replaceOp(parallelOp, results); diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp index 78d1327..dc92367 100644 --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp @@ -71,12 +71,12 @@ void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp, auto pointerType = spirv::PointerType::get(convertedType, spirv::StorageClass::Function); rewriter.setInsertionPoint(newOp); - auto alloc = rewriter.create<spirv::VariableOp>( - loc, pointerType, spirv::StorageClass::Function, - /*initializer=*/nullptr); + auto alloc = spirv::VariableOp::create(rewriter, loc, pointerType, + spirv::StorageClass::Function, + /*initializer=*/nullptr); allocas.push_back(alloc); rewriter.setInsertionPointAfter(newOp); - Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc); + Value loadResult = spirv::LoadOp::create(rewriter, loc, alloc); resultValue.push_back(loadResult); } rewriter.replaceOp(scfOp, resultValue); @@ -135,7 +135,8 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> { // a single back edge from the continue to header block, and a single exit // from header to merge. auto loc = forOp.getLoc(); - auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None); + auto loopOp = + spirv::LoopOp::create(rewriter, loc, spirv::LoopControl::None); loopOp.addEntryAndMergeBlock(rewriter); OpBuilder::InsertionGuard guard(rewriter); @@ -172,16 +173,17 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> { args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end()); // Branch into it from the entry. rewriter.setInsertionPointToEnd(&(loopOp.getBody().front())); - rewriter.create<spirv::BranchOp>(loc, header, args); + spirv::BranchOp::create(rewriter, loc, header, args); // Generate the rest of the loop header. rewriter.setInsertionPointToEnd(header); auto *mergeBlock = loopOp.getMergeBlock(); - auto cmpOp = rewriter.create<spirv::SLessThanOp>( - loc, rewriter.getI1Type(), newIndVar, adaptor.getUpperBound()); + auto cmpOp = spirv::SLessThanOp::create(rewriter, loc, rewriter.getI1Type(), + newIndVar, adaptor.getUpperBound()); - rewriter.create<spirv::BranchConditionalOp>( - loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>()); + spirv::BranchConditionalOp::create(rewriter, loc, cmpOp, body, + ArrayRef<Value>(), mergeBlock, + ArrayRef<Value>()); // Generate instructions to increment the step of the induction variable and // branch to the header. @@ -189,9 +191,9 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> { rewriter.setInsertionPointToEnd(continueBlock); // Add the step to the induction variable and branch to the header. - Value updatedIndVar = rewriter.create<spirv::IAddOp>( - loc, newIndVar.getType(), newIndVar, adaptor.getStep()); - rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar); + Value updatedIndVar = spirv::IAddOp::create( + rewriter, loc, newIndVar.getType(), newIndVar, adaptor.getStep()); + spirv::BranchOp::create(rewriter, loc, header, updatedIndVar); // Infer the return types from the init operands. Vector type may get // converted to CooperativeMatrix or to Vector type, to avoid having complex @@ -237,11 +239,11 @@ struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> { // Create `spirv.selection` operation, selection header block and merge // block. - auto selectionOp = - rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None); + auto selectionOp = spirv::SelectionOp::create( + rewriter, loc, spirv::SelectionControl::None); auto *mergeBlock = rewriter.createBlock(&selectionOp.getBody(), selectionOp.getBody().end()); - rewriter.create<spirv::MergeOp>(loc); + spirv::MergeOp::create(rewriter, loc); OpBuilder::InsertionGuard guard(rewriter); auto *selectionHeaderBlock = @@ -251,7 +253,7 @@ struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> { auto &thenRegion = ifOp.getThenRegion(); auto *thenBlock = &thenRegion.front(); rewriter.setInsertionPointToEnd(&thenRegion.back()); - rewriter.create<spirv::BranchOp>(loc, mergeBlock); + spirv::BranchOp::create(rewriter, loc, mergeBlock); rewriter.inlineRegionBefore(thenRegion, mergeBlock); auto *elseBlock = mergeBlock; @@ -261,15 +263,15 @@ struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> { auto &elseRegion = ifOp.getElseRegion(); elseBlock = &elseRegion.front(); rewriter.setInsertionPointToEnd(&elseRegion.back()); - rewriter.create<spirv::BranchOp>(loc, mergeBlock); + spirv::BranchOp::create(rewriter, loc, mergeBlock); rewriter.inlineRegionBefore(elseRegion, mergeBlock); } // Create a `spirv.BranchConditional` operation for selection header block. rewriter.setInsertionPointToEnd(selectionHeaderBlock); - rewriter.create<spirv::BranchConditionalOp>(loc, adaptor.getCondition(), - thenBlock, ArrayRef<Value>(), - elseBlock, ArrayRef<Value>()); + spirv::BranchConditionalOp::create(rewriter, loc, adaptor.getCondition(), + thenBlock, ArrayRef<Value>(), elseBlock, + ArrayRef<Value>()); replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext, returnTypes); @@ -310,7 +312,7 @@ public: auto loc = terminatorOp.getLoc(); for (unsigned i = 0, e = operands.size(); i < e; i++) - rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]); + spirv::StoreOp::create(rewriter, loc, allocas[i], operands[i]); if (isa<spirv::LoopOp>(parent)) { // For loops we also need to update the branch jumping back to the // header. @@ -319,8 +321,8 @@ public: SmallVector<Value, 8> args(br.getBlockArguments()); args.append(operands.begin(), operands.end()); rewriter.setInsertionPoint(br); - rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(), - args); + spirv::BranchOp::create(rewriter, terminatorOp.getLoc(), br.getTarget(), + args); rewriter.eraseOp(br); } } @@ -340,7 +342,8 @@ struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> { matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = whileOp.getLoc(); - auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None); + auto loopOp = + spirv::LoopOp::create(rewriter, loc, spirv::LoopControl::None); loopOp.addEntryAndMergeBlock(rewriter); Region &beforeRegion = whileOp.getBefore(); @@ -382,7 +385,7 @@ struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> { // Jump from the loop entry block to the loop header block. rewriter.setInsertionPointToEnd(&entryBlock); - rewriter.create<spirv::BranchOp>(loc, &beforeBlock, adaptor.getInits()); + spirv::BranchOp::create(rewriter, loc, &beforeBlock, adaptor.getInits()); auto condLoc = cond.getLoc(); @@ -403,18 +406,18 @@ struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> { // Create local variables before the scf.while op. rewriter.setInsertionPoint(loopOp); - auto alloc = rewriter.create<spirv::VariableOp>( - condLoc, pointerType, spirv::StorageClass::Function, - /*initializer=*/nullptr); + auto alloc = spirv::VariableOp::create(rewriter, condLoc, pointerType, + spirv::StorageClass::Function, + /*initializer=*/nullptr); // Load the final result values after the scf.while op. rewriter.setInsertionPointAfter(loopOp); - auto loadResult = rewriter.create<spirv::LoadOp>(condLoc, alloc); + auto loadResult = spirv::LoadOp::create(rewriter, condLoc, alloc); resultValues[i] = loadResult; // Store the current iteration's result value. rewriter.setInsertionPointToEnd(&beforeBlock); - rewriter.create<spirv::StoreOp>(condLoc, alloc, res); + spirv::StoreOp::create(rewriter, condLoc, alloc, res); } rewriter.setInsertionPointToEnd(&beforeBlock); diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp index d7ae9f0..035f197 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp @@ -68,7 +68,7 @@ static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op) { /// Copies the given number of bytes from src to dst pointers. static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder) { - builder.create<LLVM::MemcpyOp>(loc, dst, src, size, /*isVolatile=*/false); + LLVM::MemcpyOp::create(builder, loc, dst, src, size, /*isVolatile=*/false); } /// Encodes the binding and descriptor set numbers into a new symbolic name. @@ -194,8 +194,8 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> { if (!kernelFunc) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); - kernelFunc = rewriter.create<LLVM::LLVMFuncOp>( - rewriter.getUnknownLoc(), newKernelFuncName, + kernelFunc = LLVM::LLVMFuncOp::create( + rewriter, rewriter.getUnknownLoc(), newKernelFuncName, LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context), ArrayRef<Type>())); rewriter.setInsertionPoint(launchOp); @@ -245,8 +245,8 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> { if (!dstGlobal) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); - dstGlobal = rewriter.create<LLVM::GlobalOp>( - loc, dstGlobalType, + dstGlobal = LLVM::GlobalOp::create( + rewriter, loc, dstGlobalType, /*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute(), /*alignment=*/0); rewriter.setInsertionPoint(launchOp); @@ -255,8 +255,8 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> { // Copy the data from src operand pointer to dst global variable. Save // src, dst and size so that we can copy data back after emulating the // kernel call. - Value dst = rewriter.create<LLVM::AddressOfOp>( - loc, typeConverter->convertType(spirvGlobal.getType()), + Value dst = LLVM::AddressOfOp::create( + rewriter, loc, typeConverter->convertType(spirvGlobal.getType()), dstGlobal.getSymName()); copy(loc, dst, src, sizeBytes, rewriter); diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index 1d92b5d..9b61540 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -94,13 +94,13 @@ static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) { static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter) { if (isa<VectorType>(srcType)) { - return rewriter.create<LLVM::ConstantOp>( - loc, dstType, + return LLVM::ConstantOp::create( + rewriter, loc, dstType, SplatElementsAttr::get(cast<ShapedType>(srcType), minusOneIntegerAttribute(srcType, rewriter))); } - return rewriter.create<LLVM::ConstantOp>( - loc, dstType, minusOneIntegerAttribute(srcType, rewriter)); + return LLVM::ConstantOp::create(rewriter, loc, dstType, + minusOneIntegerAttribute(srcType, rewriter)); } /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value. @@ -108,14 +108,14 @@ static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value) { if (auto vecType = dyn_cast<VectorType>(srcType)) { auto floatType = cast<FloatType>(vecType.getElementType()); - return rewriter.create<LLVM::ConstantOp>( - loc, dstType, + return LLVM::ConstantOp::create( + rewriter, loc, dstType, SplatElementsAttr::get(vecType, rewriter.getFloatAttr(floatType, value))); } auto floatType = cast<FloatType>(srcType); - return rewriter.create<LLVM::ConstantOp>( - loc, dstType, rewriter.getFloatAttr(floatType, value)); + return LLVM::ConstantOp::create(rewriter, loc, dstType, + rewriter.getFloatAttr(floatType, value)); } /// Utility function for bitfield ops: @@ -134,13 +134,13 @@ static Value optionallyTruncateOrExtend(Location loc, Value value, : getBitWidth(srcType); if (valueBitWidth < targetBitWidth) - return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value); + return LLVM::ZExtOp::create(rewriter, loc, llvmType, value); // If the bit widths of `Count` and `Offset` are greater than the bit width // of the target type, they are truncated. Truncation is safe since `Count` // and `Offset` must be no more than 64 for op behaviour to be defined. Hence, // both values can be expressed in 8 bits. if (valueBitWidth > targetBitWidth) - return rewriter.create<LLVM::TruncOp>(loc, llvmType, value); + return LLVM::TruncOp::create(rewriter, loc, llvmType, value); return value; } @@ -151,12 +151,12 @@ static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, auto vectorType = VectorType::get(numElements, toBroadcast.getType()); auto llvmVectorType = typeConverter.convertType(vectorType); auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32)); - Value broadcasted = rewriter.create<LLVM::PoisonOp>(loc, llvmVectorType); + Value broadcasted = LLVM::PoisonOp::create(rewriter, loc, llvmVectorType); for (unsigned i = 0; i < numElements; ++i) { - auto index = rewriter.create<LLVM::ConstantOp>( - loc, llvmI32Type, rewriter.getI32IntegerAttr(i)); - broadcasted = rewriter.create<LLVM::InsertElementOp>( - loc, llvmVectorType, broadcasted, toBroadcast, index); + auto index = LLVM::ConstantOp::create(rewriter, loc, llvmI32Type, + rewriter.getI32IntegerAttr(i)); + broadcasted = LLVM::InsertElementOp::create( + rewriter, loc, llvmVectorType, broadcasted, toBroadcast, index); } return broadcasted; } @@ -217,8 +217,8 @@ static Type convertStructTypePacked(spirv::StructType type, /// Creates LLVM dialect constant with the given value. static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value) { - return rewriter.create<LLVM::ConstantOp>( - loc, IntegerType::get(rewriter.getContext(), 32), + return LLVM::ConstantOp::create( + rewriter, loc, IntegerType::get(rewriter.getContext(), 32), rewriter.getIntegerAttr(rewriter.getI32Type(), value)); } @@ -322,8 +322,9 @@ public: auto llvmIndexType = getTypeConverter()->convertType(indexType); if (!llvmIndexType) return rewriter.notifyMatchFailure(op, "type conversion failed"); - Value zero = rewriter.create<LLVM::ConstantOp>( - op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0)); + Value zero = + LLVM::ConstantOp::create(rewriter, op.getLoc(), llvmIndexType, + rewriter.getIntegerAttr(indexType, 0)); indices.insert(indices.begin(), zero); auto elementType = getTypeConverter()->convertType( @@ -375,20 +376,20 @@ public: // Create a mask with bits set outside [Offset, Offset + Count - 1]. Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); Value maskShiftedByCount = - rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count); - Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType, - maskShiftedByCount, minusOne); + LLVM::ShlOp::create(rewriter, loc, dstType, minusOne, count); + Value negated = LLVM::XOrOp::create(rewriter, loc, dstType, + maskShiftedByCount, minusOne); Value maskShiftedByCountAndOffset = - rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset); - Value mask = rewriter.create<LLVM::XOrOp>( - loc, dstType, maskShiftedByCountAndOffset, minusOne); + LLVM::ShlOp::create(rewriter, loc, dstType, negated, offset); + Value mask = LLVM::XOrOp::create(rewriter, loc, dstType, + maskShiftedByCountAndOffset, minusOne); // Extract unchanged bits from the `Base` that are outside of // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`. Value baseAndMask = - rewriter.create<LLVM::AndOp>(loc, dstType, op.getBase(), mask); + LLVM::AndOp::create(rewriter, loc, dstType, op.getBase(), mask); Value insertShiftedByOffset = - rewriter.create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset); + LLVM::ShlOp::create(rewriter, loc, dstType, op.getInsert(), offset); rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask, insertShiftedByOffset); return success(); @@ -470,23 +471,23 @@ public: auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType)); Value size = isa<VectorType>(srcType) - ? rewriter.create<LLVM::ConstantOp>( - loc, dstType, + ? LLVM::ConstantOp::create( + rewriter, loc, dstType, SplatElementsAttr::get(cast<ShapedType>(srcType), baseSize)) - : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize); + : LLVM::ConstantOp::create(rewriter, loc, dstType, baseSize); // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit // at Offset + Count - 1 is the most significant bit now. Value countPlusOffset = - rewriter.create<LLVM::AddOp>(loc, dstType, count, offset); + LLVM::AddOp::create(rewriter, loc, dstType, count, offset); Value amountToShiftLeft = - rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset); - Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>( - loc, dstType, op.getBase(), amountToShiftLeft); + LLVM::SubOp::create(rewriter, loc, dstType, size, countPlusOffset); + Value baseShiftedLeft = LLVM::ShlOp::create( + rewriter, loc, dstType, op.getBase(), amountToShiftLeft); // Shift the result right, filling the bits with the sign bit. Value amountToShiftRight = - rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft); + LLVM::AddOp::create(rewriter, loc, dstType, offset, amountToShiftLeft); rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft, amountToShiftRight); return success(); @@ -516,13 +517,13 @@ public: // Create a mask with bits set at [0, Count - 1]. Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); Value maskShiftedByCount = - rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count); - Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount, - minusOne); + LLVM::ShlOp::create(rewriter, loc, dstType, minusOne, count); + Value mask = LLVM::XOrOp::create(rewriter, loc, dstType, maskShiftedByCount, + minusOne); // Shift `Base` by `Offset` and apply the mask on it. Value shiftedBase = - rewriter.create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset); + LLVM::LShrOp::create(rewriter, loc, dstType, op.getBase(), offset); rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask); return success(); } @@ -694,8 +695,8 @@ public: auto structType = LLVM::LLVMStructType::getLiteral(context, fields); // Create `llvm.mlir.global` with initializer region containing one block. - auto global = rewriter.create<LLVM::GlobalOp>( - UnknownLoc::get(context), structType, /*isConstant=*/true, + auto global = LLVM::GlobalOp::create( + rewriter, UnknownLoc::get(context), structType, /*isConstant=*/true, LLVM::Linkage::External, executionModeInfoName, Attribute(), /*alignment=*/0); Location loc = global.getLoc(); @@ -704,22 +705,23 @@ public: // Initialize the struct and set the execution mode value. rewriter.setInsertionPointToStart(block); - Value structValue = rewriter.create<LLVM::PoisonOp>(loc, structType); - Value executionMode = rewriter.create<LLVM::ConstantOp>( - loc, llvmI32Type, + Value structValue = LLVM::PoisonOp::create(rewriter, loc, structType); + Value executionMode = LLVM::ConstantOp::create( + rewriter, loc, llvmI32Type, rewriter.getI32IntegerAttr( static_cast<uint32_t>(executionModeAttr.getValue()))); - structValue = rewriter.create<LLVM::InsertValueOp>(loc, structValue, - executionMode, 0); + SmallVector<int64_t> position{0}; + structValue = LLVM::InsertValueOp::create(rewriter, loc, structValue, + executionMode, position); // Insert extra operands if they exist into execution mode info struct. for (unsigned i = 0, e = values.size(); i < e; ++i) { auto attr = values.getValue()[i]; - Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr); - structValue = rewriter.create<LLVM::InsertValueOp>( - loc, structValue, entry, ArrayRef<int64_t>({1, i})); + Value entry = LLVM::ConstantOp::create(rewriter, loc, llvmI32Type, attr); + structValue = LLVM::InsertValueOp::create( + rewriter, loc, structValue, entry, ArrayRef<int64_t>({1, i})); } - rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue})); + LLVM::ReturnOp::create(rewriter, loc, ArrayRef<Value>({structValue})); rewriter.eraseOp(op); return success(); } @@ -913,7 +915,7 @@ public: Location loc = op.getLoc(); Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); - Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.getOperand()); + Value sqrt = LLVM::SqrtOp::create(rewriter, loc, dstType, op.getOperand()); rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt); return success(); } @@ -973,10 +975,10 @@ public: IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter); auto mask = isa<VectorType>(srcType) - ? rewriter.create<LLVM::ConstantOp>( - loc, dstType, + ? LLVM::ConstantOp::create( + rewriter, loc, dstType, SplatElementsAttr::get(cast<VectorType>(srcType), minusOne)) - : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne); + : LLVM::ConstantOp::create(rewriter, loc, dstType, minusOne); rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType, notOp.getOperand(), mask); return success(); @@ -1034,8 +1036,8 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, return func; OpBuilder b(symbolTable->getRegion(0)); - func = b.create<LLVM::LLVMFuncOp>( - symbolTable->getLoc(), name, + func = LLVM::LLVMFuncOp::create( + b, symbolTable->getLoc(), name, LLVM::LLVMFunctionType::get(resultType, paramTypes)); func.setCConv(LLVM::cconv::CConv::SPIR_FUNC); func.setConvergent(convergent); @@ -1047,7 +1049,7 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder, LLVM::LLVMFuncOp func, ValueRange args) { - auto call = builder.create<LLVM::CallOp>(loc, func, args); + auto call = LLVM::CallOp::create(builder, loc, func, args); call.setCConv(func.getCConv()); call.setConvergentAttr(func.getConvergentAttr()); call.setNoUnwindAttr(func.getNoUnwindAttr()); @@ -1078,12 +1080,12 @@ public: lookupOrCreateSPIRVFn(symbolTable, funcName, {i32, i32, i32}, voidTy); Location loc = controlBarrierOp->getLoc(); - Value execution = rewriter.create<LLVM::ConstantOp>( - loc, i32, static_cast<int32_t>(adaptor.getExecutionScope())); - Value memory = rewriter.create<LLVM::ConstantOp>( - loc, i32, static_cast<int32_t>(adaptor.getMemoryScope())); - Value semantics = rewriter.create<LLVM::ConstantOp>( - loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics())); + Value execution = LLVM::ConstantOp::create( + rewriter, loc, i32, static_cast<int32_t>(adaptor.getExecutionScope())); + Value memory = LLVM::ConstantOp::create( + rewriter, loc, i32, static_cast<int32_t>(adaptor.getMemoryScope())); + Value semantics = LLVM::ConstantOp::create( + rewriter, loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics())); auto call = createSPIRVBuiltinCall(loc, rewriter, func, {execution, memory, semantics}); @@ -1255,10 +1257,12 @@ public: lookupOrCreateSPIRVFn(symbolTable, funcName, paramTypes, retTy); Location loc = op.getLoc(); - Value scope = rewriter.create<LLVM::ConstantOp>( - loc, i32Ty, static_cast<int32_t>(adaptor.getExecutionScope())); - Value groupOp = rewriter.create<LLVM::ConstantOp>( - loc, i32Ty, static_cast<int32_t>(adaptor.getGroupOperation())); + Value scope = LLVM::ConstantOp::create( + rewriter, loc, i32Ty, + static_cast<int32_t>(adaptor.getExecutionScope())); + Value groupOp = LLVM::ConstantOp::create( + rewriter, loc, i32Ty, + static_cast<int32_t>(adaptor.getGroupOperation())); SmallVector<Value> operands{scope, groupOp}; operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end()); @@ -1368,7 +1372,7 @@ public: return failure(); Block *headerBlock = loopOp.getHeaderBlock(); rewriter.setInsertionPointToEnd(currentBlock); - rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock); + LLVM::BrOp::create(rewriter, loc, brOp.getBlockArguments(), headerBlock); rewriter.eraseBlock(entryBlock); // Branch from merge block to end block. @@ -1376,7 +1380,7 @@ public: Operation *terminator = mergeBlock->getTerminator(); ValueRange terminatorOperands = terminator->getOperands(); rewriter.setInsertionPointToEnd(mergeBlock); - rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock); + LLVM::BrOp::create(rewriter, loc, terminatorOperands, endBlock); rewriter.inlineRegionBefore(loopOp.getBody(), endBlock); rewriter.replaceOp(loopOp, endBlock->getArguments()); @@ -1434,16 +1438,15 @@ public: Operation *terminator = mergeBlock->getTerminator(); ValueRange terminatorOperands = terminator->getOperands(); rewriter.setInsertionPointToEnd(mergeBlock); - rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock); + LLVM::BrOp::create(rewriter, loc, terminatorOperands, continueBlock); // Link current block to `true` and `false` blocks within the selection. Block *trueBlock = condBrOp.getTrueBlock(); Block *falseBlock = condBrOp.getFalseBlock(); rewriter.setInsertionPointToEnd(currentBlock); - rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock, - condBrOp.getTrueTargetOperands(), - falseBlock, - condBrOp.getFalseTargetOperands()); + LLVM::CondBrOp::create(rewriter, loc, condBrOp.getCondition(), trueBlock, + condBrOp.getTrueTargetOperands(), falseBlock, + condBrOp.getFalseTargetOperands()); rewriter.eraseBlock(headerBlock); rewriter.inlineRegionBefore(op.getBody(), continueBlock); @@ -1490,11 +1493,11 @@ public: Value extended; if (op2TypeWidth < dstTypeWidth) { if (isUnsignedIntegerOrVector(op2Type)) { - extended = rewriter.template create<LLVM::ZExtOp>( - loc, dstType, adaptor.getOperand2()); + extended = + LLVM::ZExtOp::create(rewriter, loc, dstType, adaptor.getOperand2()); } else { - extended = rewriter.template create<LLVM::SExtOp>( - loc, dstType, adaptor.getOperand2()); + extended = + LLVM::SExtOp::create(rewriter, loc, dstType, adaptor.getOperand2()); } } else if (op2TypeWidth == dstTypeWidth) { extended = adaptor.getOperand2(); @@ -1502,8 +1505,8 @@ public: return failure(); } - Value result = rewriter.template create<LLVMOp>( - loc, dstType, adaptor.getOperand1(), extended); + Value result = + LLVMOp::create(rewriter, loc, dstType, adaptor.getOperand1(), extended); rewriter.replaceOp(op, result); return success(); } @@ -1521,8 +1524,8 @@ public: return rewriter.notifyMatchFailure(tanOp, "type conversion failed"); Location loc = tanOp.getLoc(); - Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand()); - Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.getOperand()); + Value sin = LLVM::SinOp::create(rewriter, loc, dstType, tanOp.getOperand()); + Value cos = LLVM::CosOp::create(rewriter, loc, dstType, tanOp.getOperand()); rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos); return success(); } @@ -1549,13 +1552,13 @@ public: Location loc = tanhOp.getLoc(); Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0); Value multiplied = - rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand()); - Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied); + LLVM::FMulOp::create(rewriter, loc, dstType, two, tanhOp.getOperand()); + Value exponential = LLVM::ExpOp::create(rewriter, loc, dstType, multiplied); Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); Value numerator = - rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one); + LLVM::FSubOp::create(rewriter, loc, dstType, exponential, one); Value denominator = - rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one); + LLVM::FAddOp::create(rewriter, loc, dstType, exponential, one); rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator, denominator); return success(); @@ -1594,8 +1597,8 @@ public: if (!elementType) return rewriter.notifyMatchFailure(varOp, "type conversion failed"); Value allocated = - rewriter.create<LLVM::AllocaOp>(loc, dstType, elementType, size); - rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated); + LLVM::AllocaOp::create(rewriter, loc, dstType, elementType, size); + LLVM::StoreOp::create(rewriter, loc, adaptor.getInitializer(), allocated); rewriter.replaceOp(varOp, allocated); return success(); } @@ -1656,7 +1659,7 @@ public: // Create a new `LLVMFuncOp` Location loc = funcOp.getLoc(); StringRef name = funcOp.getName(); - auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType); + auto newFuncOp = LLVM::LLVMFuncOp::create(rewriter, loc, name, llvmType); // Convert SPIR-V Function Control to equivalent LLVM function attribute MLIRContext *context = funcOp.getContext(); @@ -1710,7 +1713,7 @@ public: ConversionPatternRewriter &rewriter) const override { auto newModuleOp = - rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName()); + ModuleOp::create(rewriter, spvModuleOp.getLoc(), spvModuleOp.getName()); rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody()); // Remove the terminator block that was automatically added by builder @@ -1751,7 +1754,7 @@ public: auto componentsArray = components.getValue(); auto *context = rewriter.getContext(); auto llvmI32Type = IntegerType::get(context, 32); - Value targetOp = rewriter.create<LLVM::PoisonOp>(loc, dstType); + Value targetOp = LLVM::PoisonOp::create(rewriter, loc, dstType); for (unsigned i = 0; i < componentsArray.size(); i++) { if (!isa<IntegerAttr>(componentsArray[i])) return op.emitError("unable to support non-constant component"); @@ -1767,16 +1770,17 @@ public: baseVector = vector2; } - Value dstIndex = rewriter.create<LLVM::ConstantOp>( - loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i)); - Value index = rewriter.create<LLVM::ConstantOp>( - loc, llvmI32Type, + Value dstIndex = LLVM::ConstantOp::create( + rewriter, loc, llvmI32Type, + rewriter.getIntegerAttr(rewriter.getI32Type(), i)); + Value index = LLVM::ConstantOp::create( + rewriter, loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal)); - auto extractOp = rewriter.create<LLVM::ExtractElementOp>( - loc, scalarType, baseVector, index); - targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp, - extractOp, dstIndex); + auto extractOp = LLVM::ExtractElementOp::create(rewriter, loc, scalarType, + baseVector, index); + targetOp = LLVM::InsertElementOp::create(rewriter, loc, dstType, targetOp, + extractOp, dstIndex); } rewriter.replaceOp(op, targetOp); return success(); diff --git a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp index da9ad3d..245e60b 100644 --- a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp @@ -32,7 +32,7 @@ public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(shape::CstrRequireOp op, PatternRewriter &rewriter) const override { - rewriter.create<cf::AssertOp>(op.getLoc(), op.getPred(), op.getMsgAttr()); + cf::AssertOp::create(rewriter, op.getLoc(), op.getPred(), op.getMsgAttr()); rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true); return success(); } diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp index bbe1490..0ff9fb3 100644 --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -14,7 +14,6 @@ #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/IRMapping.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/STLExtras.h" @@ -82,40 +81,40 @@ struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> { // number of extent tensors and shifted offsets into them. Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors, ValueRange rankDiffs, Value outputDimension) { - Value one = lb.create<arith::ConstantIndexOp>(1); + Value one = arith::ConstantIndexOp::create(lb, 1); Value broadcastedDim = one; for (auto tup : llvm::zip(extentTensors, rankDiffs)) { Value shape = std::get<0>(tup); Value rankDiff = std::get<1>(tup); - Value outOfBounds = lb.create<arith::CmpIOp>(arith::CmpIPredicate::ult, - outputDimension, rankDiff); + Value outOfBounds = arith::CmpIOp::create(lb, arith::CmpIPredicate::ult, + outputDimension, rankDiff); Type indexTy = lb.getIndexType(); broadcastedDim = - lb.create<IfOp>( - outOfBounds, - [&](OpBuilder &b, Location loc) { - b.create<scf::YieldOp>(loc, broadcastedDim); - }, - [&](OpBuilder &b, Location loc) { - // The broadcasting logic is: - // - if one extent (here we arbitrarily choose the - // extent from the greater-rank operand) is equal to 1, - // then take the extent from the other operand - // - otherwise, take the extent as-is. - // Note that this logic remains correct in the presence - // of dimensions of zero extent. - Value lesserRankOperandDimension = b.create<arith::SubIOp>( - loc, indexTy, outputDimension, rankDiff); - Value lesserRankOperandExtent = b.create<tensor::ExtractOp>( - loc, shape, ValueRange{lesserRankOperandDimension}); - - Value dimIsOne = - b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, - lesserRankOperandExtent, one); - Value dim = b.create<arith::SelectOp>( - loc, dimIsOne, broadcastedDim, lesserRankOperandExtent); - b.create<scf::YieldOp>(loc, dim); - }) + IfOp::create( + lb, outOfBounds, + [&](OpBuilder &b, Location loc) { + scf::YieldOp::create(b, loc, broadcastedDim); + }, + [&](OpBuilder &b, Location loc) { + // The broadcasting logic is: + // - if one extent (here we arbitrarily choose the + // extent from the greater-rank operand) is equal to 1, + // then take the extent from the other operand + // - otherwise, take the extent as-is. + // Note that this logic remains correct in the presence + // of dimensions of zero extent. + Value lesserRankOperandDimension = arith::SubIOp::create( + b, loc, indexTy, outputDimension, rankDiff); + Value lesserRankOperandExtent = tensor::ExtractOp::create( + b, loc, shape, ValueRange{lesserRankOperandDimension}); + + Value dimIsOne = + arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq, + lesserRankOperandExtent, one); + Value dim = arith::SelectOp::create( + b, loc, dimIsOne, broadcastedDim, lesserRankOperandExtent); + scf::YieldOp::create(b, loc, dim); + }) .getResult(0); } return broadcastedDim; @@ -133,7 +132,7 @@ LogicalResult BroadcastOpConverter::matchAndRewrite( auto loc = op.getLoc(); ImplicitLocOpBuilder lb(loc, rewriter); - Value zero = lb.create<arith::ConstantIndexOp>(0); + Value zero = arith::ConstantIndexOp::create(lb, 0); Type indexTy = lb.getIndexType(); // Save all the ranks for bounds checking. Because this is a tensor @@ -141,31 +140,31 @@ LogicalResult BroadcastOpConverter::matchAndRewrite( // dimension in the tensor. SmallVector<Value> ranks, rankDiffs; llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) { - return lb.create<tensor::DimOp>(v, zero); + return tensor::DimOp::create(lb, v, zero); })); // Find the maximum rank Value maxRank = ranks.front(); for (Value v : llvm::drop_begin(ranks, 1)) { - maxRank = lb.create<arith::MaxUIOp>(v, maxRank); + maxRank = arith::MaxUIOp::create(lb, v, maxRank); } // Calculate the difference of ranks and the maximum rank for later offsets. llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) { - return lb.create<arith::SubIOp>(indexTy, maxRank, v); + return arith::SubIOp::create(lb, indexTy, maxRank, v); })); - Value replacement = lb.create<tensor::GenerateOp>( - getExtentTensorType(lb.getContext()), ValueRange{maxRank}, + Value replacement = tensor::GenerateOp::create( + lb, getExtentTensorType(lb.getContext()), ValueRange{maxRank}, [&](OpBuilder &b, Location loc, ValueRange args) { Value broadcastedDim = getBroadcastedDim(ImplicitLocOpBuilder(loc, b), adaptor.getShapes(), rankDiffs, args[0]); - b.create<tensor::YieldOp>(loc, broadcastedDim); + tensor::YieldOp::create(b, loc, broadcastedDim); }); if (replacement.getType() != op.getType()) - replacement = lb.create<tensor::CastOp>(op.getType(), replacement); + replacement = tensor::CastOp::create(lb, op.getType(), replacement); rewriter.replaceOp(op, replacement); return success(); } @@ -193,13 +192,13 @@ LogicalResult ConstShapeOpConverter::matchAndRewrite( auto loc = op.getLoc(); SmallVector<Value, 4> extentOperands; for (auto extent : op.getShape()) { - extentOperands.push_back( - rewriter.create<arith::ConstantIndexOp>(loc, extent.getLimitedValue())); + extentOperands.push_back(arith::ConstantIndexOp::create( + rewriter, loc, extent.getLimitedValue())); } Type resultTy = RankedTensorType::get({op.getShape().size()}, rewriter.getIndexType()); Value tensor = - rewriter.create<tensor::FromElementsOp>(loc, resultTy, extentOperands); + tensor::FromElementsOp::create(rewriter, loc, resultTy, extentOperands); rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor); return success(); } @@ -245,8 +244,8 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite( auto loc = op.getLoc(); ImplicitLocOpBuilder lb(loc, rewriter); - Value zero = lb.create<arith::ConstantIndexOp>(0); - Value one = lb.create<arith::ConstantIndexOp>(1); + Value zero = arith::ConstantIndexOp::create(lb, 0); + Value one = arith::ConstantIndexOp::create(lb, 1); Type indexTy = lb.getIndexType(); // Save all the ranks for bounds checking. Because this is a tensor @@ -254,26 +253,26 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite( // dimension in the tensor. SmallVector<Value> ranks, rankDiffs; llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) { - return lb.create<tensor::DimOp>(v, zero); + return tensor::DimOp::create(lb, v, zero); })); // Find the maximum rank Value maxRank = ranks.front(); for (Value v : llvm::drop_begin(ranks, 1)) { - maxRank = lb.create<arith::MaxUIOp>(v, maxRank); + maxRank = arith::MaxUIOp::create(lb, v, maxRank); } // Calculate the difference of ranks and the maximum rank for later offsets. llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) { - return lb.create<arith::SubIOp>(indexTy, maxRank, v); + return arith::SubIOp::create(lb, indexTy, maxRank, v); })); Type i1Ty = rewriter.getI1Type(); - Value trueVal = - rewriter.create<arith::ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true)); + Value trueVal = arith::ConstantOp::create(rewriter, loc, i1Ty, + rewriter.getBoolAttr(true)); - auto reduceResult = lb.create<ForOp>( - loc, zero, maxRank, one, ValueRange{trueVal}, + auto reduceResult = ForOp::create( + lb, loc, zero, maxRank, one, ValueRange{trueVal}, [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) { // Find a non-1 dim, if it exists. Note that the first part of this // could reuse the Broadcast lowering entirely, but we redo the work @@ -285,38 +284,38 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite( for (auto tup : llvm::zip(adaptor.getShapes(), rankDiffs)) { Value shape, rankDiff; std::tie(shape, rankDiff) = tup; - Value outOfBounds = b.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::ult, iv, rankDiff); + Value outOfBounds = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::ult, iv, rankDiff); broadcastable = - b.create<IfOp>( - loc, outOfBounds, - [&](OpBuilder &b, Location loc) { - // Non existent dimensions are always broadcastable - b.create<scf::YieldOp>(loc, broadcastable); - }, - [&](OpBuilder &b, Location loc) { - // Every value needs to be either 1, or the same non-1 - // value to be broadcastable in this dim. - Value operandDimension = - b.create<arith::SubIOp>(loc, indexTy, iv, rankDiff); - Value dimensionExtent = b.create<tensor::ExtractOp>( - loc, shape, ValueRange{operandDimension}); - - Value equalOne = b.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::eq, dimensionExtent, one); - Value equalBroadcasted = b.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::eq, dimensionExtent, - broadcastedDim); - Value result = b.create<arith::AndIOp>( - loc, broadcastable, - b.create<arith::OrIOp>(loc, equalOne, - equalBroadcasted)); - b.create<scf::YieldOp>(loc, result); - }) + IfOp::create( + b, loc, outOfBounds, + [&](OpBuilder &b, Location loc) { + // Non existent dimensions are always broadcastable + scf::YieldOp::create(b, loc, broadcastable); + }, + [&](OpBuilder &b, Location loc) { + // Every value needs to be either 1, or the same non-1 + // value to be broadcastable in this dim. + Value operandDimension = + arith::SubIOp::create(b, loc, indexTy, iv, rankDiff); + Value dimensionExtent = tensor::ExtractOp::create( + b, loc, shape, ValueRange{operandDimension}); + + Value equalOne = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::eq, dimensionExtent, one); + Value equalBroadcasted = + arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq, + dimensionExtent, broadcastedDim); + Value result = arith::AndIOp::create( + b, loc, broadcastable, + arith::OrIOp::create(b, loc, equalOne, + equalBroadcasted)); + scf::YieldOp::create(b, loc, result); + }) .getResult(0); } - b.create<scf::YieldOp>(loc, broadcastable); + scf::YieldOp::create(b, loc, broadcastable); }); rewriter.replaceOp(op, reduceResult.getResults().front()); @@ -339,7 +338,7 @@ DimOpConverter::matchAndRewrite(DimOp op, OpAdaptor adaptor, // Lower to dim(X, i) to get_extent(shape_of(X), i) and rely on further // lowerings. This can be further optimized if needed to avoid intermediate // steps. - auto shapeOf = rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getValue()); + auto shapeOf = shape::ShapeOfOp::create(rewriter, op.getLoc(), op.getValue()); rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf, op.getIndex()); return success(); @@ -421,16 +420,17 @@ ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor, auto loc = op.getLoc(); - Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); - Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value one = arith::ConstantIndexOp::create(rewriter, loc, 1); Type indexTy = rewriter.getIndexType(); Value rank = - rewriter.create<tensor::DimOp>(loc, indexTy, adaptor.getShape(), zero); + tensor::DimOp::create(rewriter, loc, indexTy, adaptor.getShape(), zero); - auto loop = rewriter.create<scf::ForOp>( - loc, zero, rank, one, op.getInitVals(), + auto loop = scf::ForOp::create( + rewriter, loc, zero, rank, one, op.getInitVals(), [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { - Value extent = b.create<tensor::ExtractOp>(loc, adaptor.getShape(), iv); + Value extent = + tensor::ExtractOp::create(b, loc, adaptor.getShape(), iv); SmallVector<Value, 2> mappedValues{iv, extent}; mappedValues.append(args.begin(), args.end()); @@ -444,7 +444,7 @@ ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor, SmallVector<Value, 2> mappedResults; for (auto result : reduceBody->getTerminator()->getOperands()) mappedResults.push_back(mapping.lookup(result)); - b.create<scf::YieldOp>(loc, mappedResults); + scf::YieldOp::create(b, loc, mappedResults); }); rewriter.replaceOp(op, loop.getResults()); @@ -507,44 +507,44 @@ ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor, auto loc = op.getLoc(); Type indexTy = rewriter.getIndexType(); - Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); Value firstShape = adaptor.getShapes().front(); Value firstRank = - rewriter.create<tensor::DimOp>(loc, indexTy, firstShape, zero); + tensor::DimOp::create(rewriter, loc, indexTy, firstShape, zero); Value result = nullptr; // Generate a linear sequence of compares, all with firstShape as lhs. for (Value shape : adaptor.getShapes().drop_front(1)) { - Value rank = rewriter.create<tensor::DimOp>(loc, indexTy, shape, zero); - Value eqRank = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, - firstRank, rank); - auto same = rewriter.create<IfOp>( - loc, eqRank, + Value rank = tensor::DimOp::create(rewriter, loc, indexTy, shape, zero); + Value eqRank = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::eq, firstRank, rank); + auto same = IfOp::create( + rewriter, loc, eqRank, [&](OpBuilder &b, Location loc) { - Value one = b.create<arith::ConstantIndexOp>(loc, 1); + Value one = arith::ConstantIndexOp::create(b, loc, 1); Value init = - b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(true)); - auto loop = b.create<scf::ForOp>( - loc, zero, firstRank, one, ValueRange{init}, + arith::ConstantOp::create(b, loc, i1Ty, b.getBoolAttr(true)); + auto loop = scf::ForOp::create( + b, loc, zero, firstRank, one, ValueRange{init}, [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) { Value conj = args[0]; Value lhsExtent = - b.create<tensor::ExtractOp>(loc, firstShape, iv); - Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv); - Value eqExtent = b.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent); - Value conjNext = b.create<arith::AndIOp>(loc, conj, eqExtent); - b.create<scf::YieldOp>(loc, ValueRange({conjNext})); + tensor::ExtractOp::create(b, loc, firstShape, iv); + Value rhsExtent = tensor::ExtractOp::create(b, loc, shape, iv); + Value eqExtent = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent); + Value conjNext = arith::AndIOp::create(b, loc, conj, eqExtent); + scf::YieldOp::create(b, loc, ValueRange({conjNext})); }); - b.create<scf::YieldOp>(loc, loop.getResults()); + scf::YieldOp::create(b, loc, loop.getResults()); }, [&](OpBuilder &b, Location loc) { Value result = - b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(false)); - b.create<scf::YieldOp>(loc, result); + arith::ConstantOp::create(b, loc, i1Ty, b.getBoolAttr(false)); + scf::YieldOp::create(b, loc, result); }); result = !result ? same.getResult(0) - : rewriter.create<arith::AndIOp>(loc, result, - same.getResult(0)); + : arith::AndIOp::create(rewriter, loc, result, + same.getResult(0)); } rewriter.replaceOp(op, result); return success(); @@ -581,18 +581,18 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite( int64_t rank = rankedTensorTy.getRank(); for (int64_t i = 0; i < rank; i++) { if (rankedTensorTy.isDynamicDim(i)) { - Value extent = rewriter.create<tensor::DimOp>(loc, tensor, i); + Value extent = tensor::DimOp::create(rewriter, loc, tensor, i); extentValues.push_back(extent); } else { - Value extent = rewriter.create<arith::ConstantIndexOp>( - loc, rankedTensorTy.getDimSize(i)); + Value extent = arith::ConstantIndexOp::create( + rewriter, loc, rankedTensorTy.getDimSize(i)); extentValues.push_back(extent); } } // Materialize extent tensor. - Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>( - loc, RankedTensorType::get({rank}, rewriter.getIndexType()), + Value staticExtentTensor = tensor::FromElementsOp::create( + rewriter, loc, RankedTensorType::get({rank}, rewriter.getIndexType()), extentValues); rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), staticExtentTensor); @@ -601,13 +601,13 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite( // Lower to `tensor.generate` otherwise. auto *ctx = rewriter.getContext(); - Value rank = rewriter.create<tensor::RankOp>(loc, tensor); + Value rank = tensor::RankOp::create(rewriter, loc, tensor); rewriter.replaceOpWithNewOp<tensor::GenerateOp>( op, getExtentTensorType(ctx), ValueRange{rank}, [&](OpBuilder &b, Location loc, ValueRange args) { Value dim = args.front(); - Value extent = b.create<tensor::DimOp>(loc, tensor, dim); - b.create<tensor::YieldOp>(loc, extent); + Value extent = tensor::DimOp::create(b, loc, tensor, dim); + tensor::YieldOp::create(b, loc, extent); }); return success(); @@ -634,22 +634,22 @@ LogicalResult SplitAtOpConversion::matchAndRewrite( return failure(); ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value zero = b.create<arith::ConstantIndexOp>(0); - Value rank = b.create<tensor::DimOp>(adaptor.getOperand(), zero); + Value zero = arith::ConstantIndexOp::create(b, 0); + Value rank = tensor::DimOp::create(b, adaptor.getOperand(), zero); // index < 0 ? index + rank : index Value originalIndex = adaptor.getIndex(); - Value add = b.create<arith::AddIOp>(originalIndex, rank); + Value add = arith::AddIOp::create(b, originalIndex, rank); Value indexIsNegative = - b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, originalIndex, zero); - Value index = b.create<arith::SelectOp>(indexIsNegative, add, originalIndex); + arith::CmpIOp::create(b, arith::CmpIPredicate::slt, originalIndex, zero); + Value index = arith::SelectOp::create(b, indexIsNegative, add, originalIndex); - Value one = b.create<arith::ConstantIndexOp>(1); + Value one = arith::ConstantIndexOp::create(b, 1); Value head = - b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), zero, index, one); - Value tailSize = b.create<arith::SubIOp>(rank, index); - Value tail = b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), index, - tailSize, one); + tensor::ExtractSliceOp::create(b, adaptor.getOperand(), zero, index, one); + Value tailSize = arith::SubIOp::create(b, rank, index); + Value tail = tensor::ExtractSliceOp::create(b, adaptor.getOperand(), index, + tailSize, one); rewriter.replaceOp(op, {head, tail}); return success(); } diff --git a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt b/mlir/lib/Conversion/ShardToMPI/CMakeLists.txt index 15560aa..564f36f 100644 --- a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt +++ b/mlir/lib/Conversion/ShardToMPI/CMakeLists.txt @@ -1,8 +1,8 @@ -add_mlir_conversion_library(MLIRMeshToMPI - MeshToMPI.cpp +add_mlir_conversion_library(MLIRShardToMPI + ShardToMPI.cpp ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MeshToMPI + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ShardToMPI DEPENDS MLIRConversionPassIncGen @@ -17,7 +17,7 @@ add_mlir_conversion_library(MLIRMeshToMPI MLIRLinalgTransforms MLIRMemRefDialect MLIRPass - MLIRMeshDialect + MLIRShardDialect MLIRMPIDialect MLIRTransforms ) diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp index b931284..fa9e544 100644 --- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp +++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp @@ -1,4 +1,4 @@ -//===- MeshToMPI.cpp - Mesh to MPI dialect conversion -----------------===// +//===- ShardToMPI.cpp - Shard to MPI dialect conversion -----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,11 +6,11 @@ // //===----------------------------------------------------------------------===// // -// This file implements a translation of Mesh communication ops tp MPI ops. +// This file implements a translation of Shard communication ops to MPI ops. // //===----------------------------------------------------------------------===// -#include "mlir/Conversion/MeshToMPI/MeshToMPI.h" +#include "mlir/Conversion/ShardToMPI/ShardToMPI.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -20,11 +20,11 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MPI/IR/MPI.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Mesh/IR/MeshDialect.h" -#include "mlir/Dialect/Mesh/IR/MeshOps.h" -#include "mlir/Dialect/Mesh/Transforms/Simplifications.h" -#include "mlir/Dialect/Mesh/Transforms/Transforms.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Shard/IR/ShardDialect.h" +#include "mlir/Dialect/Shard/IR/ShardOps.h" +#include "mlir/Dialect/Shard/Transforms/Simplifications.h" +#include "mlir/Dialect/Shard/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Builders.h" @@ -35,16 +35,15 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#define DEBUG_TYPE "mesh-to-mpi" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define DEBUG_TYPE "shard-to-mpi" namespace mlir { -#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS +#define GEN_PASS_DEF_CONVERTSHARDTOMPIPASS #include "mlir/Conversion/Passes.h.inc" } // namespace mlir using namespace mlir; -using namespace mesh; +using namespace shard; namespace { /// Converts a vector of OpFoldResults (ints) into vector of Values of the @@ -65,7 +64,7 @@ static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc, values.emplace_back(*(dyn++)); } else { TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s); - values.emplace_back(b.create<arith::ConstantOp>(loc, type, val)); + values.emplace_back(arith::ConstantOp::create(b, loc, type, val)); } } return values; @@ -79,9 +78,9 @@ static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b, SmallVector<Value> multiIndex(n); for (int i = n - 1; i >= 0; --i) { - multiIndex[i] = b.create<arith::RemSIOp>(loc, linearIndex, dimensions[i]); + multiIndex[i] = arith::RemSIOp::create(b, loc, linearIndex, dimensions[i]); if (i > 0) - linearIndex = b.create<arith::DivSIOp>(loc, linearIndex, dimensions[i]); + linearIndex = arith::DivSIOp::create(b, loc, linearIndex, dimensions[i]); } return multiIndex; @@ -91,13 +90,13 @@ static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b, Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex, ValueRange dimensions) { - Value linearIndex = b.create<arith::ConstantIndexOp>(loc, 0); - Value stride = b.create<arith::ConstantIndexOp>(loc, 1); + Value linearIndex = arith::ConstantIndexOp::create(b, loc, 0); + Value stride = arith::ConstantIndexOp::create(b, loc, 1); for (int i = multiIndex.size() - 1; i >= 0; --i) { - Value off = b.create<arith::MulIOp>(loc, multiIndex[i], stride); - linearIndex = b.create<arith::AddIOp>(loc, linearIndex, off); - stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]); + Value off = arith::MulIOp::create(b, loc, multiIndex[i], stride); + linearIndex = arith::AddIOp::create(b, loc, linearIndex, off); + stride = arith::MulIOp::create(b, loc, stride, dimensions[i]); } return linearIndex; @@ -144,11 +143,12 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> { auto i64 = rewriter.getI64Type(); std::array<int64_t, 2> shape = {static_cast<int64_t>(splitAxes.size()), maxNAxes}; - Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16); + Value resSplitAxes = tensor::EmptyOp::create(rewriter, loc, shape, i16); auto attr = IntegerAttr::get(i16, -1); - Value fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr); - resSplitAxes = rewriter.create<linalg::FillOp>(loc, fillValue, resSplitAxes) - .getResult(0); + Value fillValue = arith::ConstantOp::create(rewriter, loc, i16, attr); + resSplitAxes = + linalg::FillOp::create(rewriter, loc, fillValue, resSplitAxes) + .getResult(0); // explicitly write values into tensor row by row std::array<int64_t, 2> strides = {1, 1}; @@ -162,9 +162,10 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> { std::array<int64_t, 2> sizes = {1, size}; auto tensorType = RankedTensorType::get({size}, i16); auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef()); - auto vals = rewriter.create<arith::ConstantOp>(loc, tensorType, attrs); - resSplitAxes = rewriter.create<tensor::InsertSliceOp>( - loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides); + auto vals = arith::ConstantOp::create(rewriter, loc, tensorType, attrs); + resSplitAxes = tensor::InsertSliceOp::create(rewriter, loc, vals, + resSplitAxes, empty, empty, + empty, offs, sizes, strides); } // To hold halos sizes, create 2d Tensor with shape {nSplits, 2}. @@ -175,56 +176,56 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> { auto type = RankedTensorType::get({nSplits, 2}, i64); Value resHaloSizes = haloSizes.empty() - ? rewriter - .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0}, - i64) + ? tensor::EmptyOp::create(rewriter, loc, + std::array<int64_t, 2>{0, 0}, i64) .getResult() - : rewriter.create<tensor::FromElementsOp>(loc, type, haloSizes) + : tensor::FromElementsOp::create(rewriter, loc, type, haloSizes) .getResult(); // To hold sharded dims offsets, create Tensor with shape {nSplits, // maxSplitSize+1}. Store the offsets in the tensor but set trailing // elements for smaller split-groups to -1. Computing the max size of the // split groups needs using collectiveProcessGroupSize (which needs the - // MeshOp) + // GridOp) Value resOffsets; if (adaptor.getStaticShardedDimsOffsets().empty()) { - resOffsets = rewriter.create<tensor::EmptyOp>( - loc, std::array<int64_t, 2>{0, 0}, i64); + resOffsets = tensor::EmptyOp::create(rewriter, loc, + std::array<int64_t, 2>{0, 0}, i64); } else { SymbolTableCollection symbolTableCollection; - auto meshOp = getMesh(op, symbolTableCollection); + auto gridOp = getGrid(op, symbolTableCollection); int64_t maxSplitSize = 0; for (auto axes : splitAxes) { int64_t splitSize = - collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape()); + collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape()); assert(splitSize != ShapedType::kDynamic); maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize); } assert(maxSplitSize); ++maxSplitSize; // add one for the total size - resOffsets = rewriter.create<tensor::EmptyOp>( - loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64); - Value zero = rewriter.create<arith::ConstantOp>( - loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic)); + resOffsets = tensor::EmptyOp::create( + rewriter, loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64); + Value zero = arith::ConstantOp::create( + rewriter, loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic)); resOffsets = - rewriter.create<linalg::FillOp>(loc, zero, resOffsets).getResult(0); + linalg::FillOp::create(rewriter, loc, zero, resOffsets).getResult(0); SmallVector<Value> offsets = getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(), adaptor.getDynamicShardedDimsOffsets()); int64_t curr = 0; for (auto [i, axes] : llvm::enumerate(splitAxes)) { int64_t splitSize = - collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape()); + collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape()); assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize); ++splitSize; // add one for the total size ArrayRef<Value> values(&offsets[curr], splitSize); - Value vals = rewriter.create<tensor::FromElementsOp>(loc, values); + Value vals = tensor::FromElementsOp::create(rewriter, loc, values); std::array<int64_t, 2> offs = {static_cast<int64_t>(i), 0}; std::array<int64_t, 2> sizes = {1, splitSize}; - resOffsets = rewriter.create<tensor::InsertSliceOp>( - loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides); + resOffsets = tensor::InsertSliceOp::create(rewriter, loc, vals, + resOffsets, empty, empty, + empty, offs, sizes, strides); curr += splitSize; } } @@ -236,10 +237,10 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> { return failure(); resSplitAxes = - rewriter.create<tensor::CastOp>(loc, resTypes[0], resSplitAxes); + tensor::CastOp::create(rewriter, loc, resTypes[0], resSplitAxes); resHaloSizes = - rewriter.create<tensor::CastOp>(loc, resTypes[1], resHaloSizes); - resOffsets = rewriter.create<tensor::CastOp>(loc, resTypes[2], resOffsets); + tensor::CastOp::create(rewriter, loc, resTypes[1], resHaloSizes); + resOffsets = tensor::CastOp::create(rewriter, loc, resTypes[2], resOffsets); rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>( op, TupleType::get(op.getContext(), resTypes), @@ -261,20 +262,20 @@ struct ConvertProcessMultiIndexOp SymbolTableCollection symbolTableCollection; Location loc = op.getLoc(); - auto meshOp = getMesh(op, symbolTableCollection); - // For now we only support static mesh shapes - if (ShapedType::isDynamicShape(meshOp.getShape())) + auto gridOp = getGrid(op, symbolTableCollection); + // For now we only support static grid shapes + if (ShapedType::isDynamicShape(gridOp.getShape())) return failure(); SmallVector<Value> dims; llvm::transform( - meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) { - return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult(); + gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) { + return arith::ConstantIndexOp::create(rewriter, loc, i).getResult(); }); - Value rank = rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp); + Value rank = ProcessLinearIndexOp::create(rewriter, op.getLoc(), gridOp); auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims); - // optionally extract subset of mesh axes + // optionally extract subset of grid axes auto axes = adaptor.getAxes(); if (!axes.empty()) { SmallVector<Value> subIndex; @@ -302,14 +303,12 @@ public: Location loc = op.getLoc(); auto ctx = op.getContext(); Value commWorld = - rewriter.create<mpi::CommWorldOp>(loc, mpi::CommType::get(ctx)); - auto rank = - rewriter - .create<mpi::CommRankOp>( - loc, - TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()}, - commWorld) - .getRank(); + mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(ctx)); + auto rank = mpi::CommRankOp::create( + rewriter, loc, + TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()}, + commWorld) + .getRank(); rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(), rank); return success(); @@ -335,47 +334,47 @@ struct ConvertNeighborsLinearIndicesOp Location loc = op.getLoc(); SymbolTableCollection symbolTableCollection; - auto meshOp = getMesh(op, symbolTableCollection); + auto gridOp = getGrid(op, symbolTableCollection); auto mIdx = adaptor.getDevice(); auto orgIdx = mIdx[axes[0]]; SmallVector<Value> dims; llvm::transform( - meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) { - return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult(); + gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) { + return arith::ConstantIndexOp::create(rewriter, loc, i).getResult(); }); Value dimSz = dims[axes[0]]; - Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1); - Value minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1); - Value atBorder = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::sle, orgIdx, - rewriter.create<arith::ConstantIndexOp>(loc, 0)); - auto down = rewriter.create<scf::IfOp>( - loc, atBorder, + Value one = arith::ConstantIndexOp::create(rewriter, loc, 1); + Value minus1 = arith::ConstantIndexOp::create(rewriter, loc, -1); + Value atBorder = + arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sle, orgIdx, + arith::ConstantIndexOp::create(rewriter, loc, 0)); + auto down = scf::IfOp::create( + rewriter, loc, atBorder, [&](OpBuilder &builder, Location loc) { - builder.create<scf::YieldOp>(loc, minus1); + scf::YieldOp::create(builder, loc, minus1); }, [&](OpBuilder &builder, Location loc) { SmallVector<Value> tmp = mIdx; tmp[axes[0]] = - rewriter.create<arith::SubIOp>(op.getLoc(), orgIdx, one) + arith::SubIOp::create(rewriter, op.getLoc(), orgIdx, one) .getResult(); - builder.create<scf::YieldOp>( - loc, multiToLinearIndex(loc, rewriter, tmp, dims)); + scf::YieldOp::create(builder, loc, + multiToLinearIndex(loc, rewriter, tmp, dims)); }); - atBorder = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::sge, orgIdx, - rewriter.create<arith::SubIOp>(loc, dimSz, one).getResult()); - auto up = rewriter.create<scf::IfOp>( - loc, atBorder, + atBorder = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sge, orgIdx, + arith::SubIOp::create(rewriter, loc, dimSz, one).getResult()); + auto up = scf::IfOp::create( + rewriter, loc, atBorder, [&](OpBuilder &builder, Location loc) { - builder.create<scf::YieldOp>(loc, minus1); + scf::YieldOp::create(builder, loc, minus1); }, [&](OpBuilder &builder, Location loc) { SmallVector<Value> tmp = mIdx; tmp[axes[0]] = - rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one); - builder.create<scf::YieldOp>( - loc, multiToLinearIndex(loc, rewriter, tmp, dims)); + arith::AddIOp::create(rewriter, op.getLoc(), orgIdx, one); + scf::YieldOp::create(builder, loc, + multiToLinearIndex(loc, rewriter, tmp, dims)); }); rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)}); return success(); @@ -391,14 +390,14 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> { auto sharding = op.getSharding().getDefiningOp<ShardingOp>(); if (!sharding) { return op->emitError() - << "Expected SharingOp as defining op for sharding" + << "Expected ShardingOp as defining op for sharding" << " but found " << adaptor.getSharding()[0].getDefiningOp(); } // Compute the sharded shape by applying the sharding to the input shape. // If shardedDimsOffsets is not defined in the sharding, the shard shape is // computed by dividing the dimension size by the number of shards in that - // dimension (which is given by the size of the mesh axes provided in + // dimension (which is given by the size of the grid axes provided in // split-axes). Odd elements get distributed to trailing shards. If a // shardedDimsOffsets is provided, the shard shape is computed by // subtracting the offset of the current shard from the offset of the next @@ -428,11 +427,11 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> { SmallVector<Value> multiIdx = getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index); - // Get the MeshOp, the mesh shape is needed to compute the sharded shape. + // Get the GridOp, the grid shape is needed to compute the sharded shape. SymbolTableCollection symbolTableCollection; - auto meshOp = getMesh(sharding, symbolTableCollection); - // For now we only support static mesh shapes - if (ShapedType::isDynamicShape(meshOp.getShape())) + auto gridOp = getGrid(sharding, symbolTableCollection); + // For now we only support static grid shapes + if (ShapedType::isDynamicShape(gridOp.getShape())) return failure(); auto splitAxes = sharding.getSplitAxes().getAxes(); @@ -447,19 +446,20 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> { rewriter, loc, sharding.getStaticShardedDimsOffsets(), sharding.getDynamicShardedDimsOffsets(), index); if (!tmp.empty()) - shardedDimsOffs = rewriter.create<tensor::FromElementsOp>( - loc, RankedTensorType::get({(int64_t)tmp.size()}, index), tmp); + shardedDimsOffs = tensor::FromElementsOp::create( + rewriter, loc, RankedTensorType::get({(int64_t)tmp.size()}, index), + tmp); } - // With static mesh shape the sizes of the split axes are known. + // With static grid shape the sizes of the split axes are known. // Hence the start/pos for each split axes in shardDimsOffsets can be // computed statically. int64_t pos = 0; SmallVector<Value> shardShape; Value zero = - rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(index)); + arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(index)); Value one = - rewriter.create<arith::ConstantOp>(loc, rewriter.getOneAttr(index)); + arith::ConstantOp::create(rewriter, loc, rewriter.getOneAttr(index)); // Iterate over the dimensions of the tensor shape, get their split Axes, // and compute the sharded shape. @@ -469,12 +469,12 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> { auto axes = splitAxes[i]; // The current dimension might not be sharded. // Create a value from the static position in shardDimsOffsets. - Value posVal = - rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(pos)); - // Get the index of the local shard in the mesh axis. + Value posVal = arith::ConstantOp::create(rewriter, loc, + rewriter.getIndexAttr(pos)); + // Get the index of the local shard in the grid axis. Value idx = multiIdx[axes[0]]; auto numShards = - collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape()); + collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape()); if (shardedDimsOffs) { // If sharded dims offsets are provided, use them to compute the // sharded shape. @@ -482,29 +482,29 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> { return op->emitError() << "Only single axis sharding is " << "supported for each dimension."; } - idx = rewriter.create<arith::AddIOp>(loc, posVal, idx); + idx = arith::AddIOp::create(rewriter, loc, posVal, idx); // Compute size = shardedDimsOffs[idx+1] - shardedDimsOffs[idx]. Value off = - rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx); - idx = rewriter.create<arith::AddIOp>(loc, idx, one); + tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx); + idx = arith::AddIOp::create(rewriter, loc, idx, one); Value nextOff = - rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx); - Value sz = rewriter.create<arith::SubIOp>(loc, nextOff, off); + tensor::ExtractOp::create(rewriter, loc, shardedDimsOffs, idx); + Value sz = arith::SubIOp::create(rewriter, loc, nextOff, off); shardShape.emplace_back(sz); } else { - Value numShardsVal = rewriter.create<arith::ConstantOp>( - loc, rewriter.getIndexAttr(numShards)); + Value numShardsVal = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexAttr(numShards)); // Compute shard dim size by distributing odd elements to trailing // shards: // sz = dim / numShards // + (idx >= (numShards - (dim % numShards)) ? 1 : 0) - Value sz = rewriter.create<arith::DivSIOp>(loc, dim, numShardsVal); - Value sz1 = rewriter.create<arith::RemSIOp>(loc, dim, numShardsVal); - sz1 = rewriter.create<arith::SubIOp>(loc, numShardsVal, sz1); - auto cond = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::sge, idx, sz1); - Value odd = rewriter.create<arith::SelectOp>(loc, cond, one, zero); - sz = rewriter.create<arith::AddIOp>(loc, sz, odd); + Value sz = arith::DivSIOp::create(rewriter, loc, dim, numShardsVal); + Value sz1 = arith::RemSIOp::create(rewriter, loc, dim, numShardsVal); + sz1 = arith::SubIOp::create(rewriter, loc, numShardsVal, sz1); + auto cond = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sge, idx, sz1); + Value odd = arith::SelectOp::create(rewriter, loc, cond, one, zero); + sz = arith::AddIOp::create(rewriter, loc, sz, odd); shardShape.emplace_back(sz); } pos += numShards + 1; // add one for the total size. @@ -552,13 +552,13 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> { matchAndRewrite(AllReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SymbolTableCollection symbolTableCollection; - auto mesh = adaptor.getMesh(); - mlir::mesh::MeshOp meshOp = getMesh(op, symbolTableCollection); - if (!meshOp) - return op->emitError() << "No mesh found for AllReduceOp"; - if (ShapedType::isDynamicShape(meshOp.getShape())) + auto grid = adaptor.getGrid(); + mlir::shard::GridOp gridOp = getGrid(op, symbolTableCollection); + if (!gridOp) + return op->emitError() << "No grid found for AllReduceOp"; + if (ShapedType::isDynamicShape(gridOp.getShape())) return op->emitError() - << "Dynamic mesh shape not supported in AllReduceOp"; + << "Dynamic grid shape not supported in AllReduceOp"; ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter); Value input = adaptor.getInput(); @@ -568,7 +568,7 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> { if (isa<RankedTensorType>(input.getType())) { auto memrefType = MemRefType::get( inputShape, cast<ShapedType>(input.getType()).getElementType()); - input = iBuilder.create<bufferization::ToBufferOp>(memrefType, input); + input = bufferization::ToBufferOp::create(iBuilder, memrefType, input); } MemRefType inType = cast<MemRefType>(input.getType()); @@ -577,45 +577,45 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> { for (auto i = 0; i < inType.getRank(); ++i) { auto s = inputShape[i]; if (ShapedType::isDynamic(s)) - shape[i] = iBuilder.create<memref::DimOp>(input, s).getResult(); + shape[i] = memref::DimOp::create(iBuilder, input, s).getResult(); else shape[i] = iBuilder.getIndexAttr(s); } // Allocate buffer and copy input to buffer. - Value buffer = iBuilder.create<memref::AllocOp>( - shape, cast<ShapedType>(op.getType()).getElementType()); - iBuilder.create<linalg::CopyOp>(input, buffer); + Value buffer = memref::AllocOp::create( + iBuilder, shape, cast<ShapedType>(op.getType()).getElementType()); + linalg::CopyOp::create(iBuilder, input, buffer); // Get an MPI_Comm_split for the AllReduce operation. - // The color is the linear index of the process in the mesh along the - // non-reduced axes. The key is the linear index of the process in the mesh + // The color is the linear index of the process in the grid along the + // non-reduced axes. The key is the linear index of the process in the grid // along the reduced axes. - SmallVector<Type> indexResultTypes(meshOp.getShape().size(), + SmallVector<Type> indexResultTypes(gridOp.getShape().size(), iBuilder.getIndexType()); SmallVector<Value> myMultiIndex = - iBuilder.create<ProcessMultiIndexOp>(indexResultTypes, mesh) + ProcessMultiIndexOp::create(iBuilder, indexResultTypes, grid) .getResult(); - Value zero = iBuilder.create<arith::ConstantIndexOp>(0); + Value zero = arith::ConstantIndexOp::create(iBuilder, 0); SmallVector<Value> multiKey(myMultiIndex.size(), zero); - auto redAxes = adaptor.getMeshAxes(); + auto redAxes = adaptor.getGridAxes(); for (auto axis : redAxes) { multiKey[axis] = myMultiIndex[axis]; myMultiIndex[axis] = zero; } Value color = - createProcessLinearIndex(mesh, myMultiIndex, redAxes, iBuilder); - color = iBuilder.create<arith::IndexCastOp>(iBuilder.getI32Type(), color); - Value key = createProcessLinearIndex(mesh, multiKey, redAxes, iBuilder); - key = iBuilder.create<arith::IndexCastOp>(iBuilder.getI32Type(), key); + createProcessLinearIndex(grid, myMultiIndex, redAxes, iBuilder); + color = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), color); + Value key = createProcessLinearIndex(grid, multiKey, redAxes, iBuilder); + key = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), key); // Finally split the communicator auto commType = mpi::CommType::get(op->getContext()); - Value commWorld = iBuilder.create<mpi::CommWorldOp>(commType); + Value commWorld = mpi::CommWorldOp::create(iBuilder, commType); auto comm = - iBuilder.create<mpi::CommSplitOp>(commType, commWorld, color, key) + mpi::CommSplitOp::create(iBuilder, commType, commWorld, color, key) .getNewcomm(); Value buffer1d = buffer; @@ -623,19 +623,19 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> { if (inType.getRank() > 1) { ReassociationIndices reassociation(inType.getRank()); std::iota(reassociation.begin(), reassociation.end(), 0); - buffer1d = iBuilder.create<memref::CollapseShapeOp>( - buffer, ArrayRef<ReassociationIndices>(reassociation)); + buffer1d = memref::CollapseShapeOp::create( + iBuilder, buffer, ArrayRef<ReassociationIndices>(reassociation)); } // Create the MPI AllReduce operation. - iBuilder.create<mpi::AllReduceOp>( - TypeRange(), buffer1d, buffer1d, - getMPIReductionOp(adaptor.getReductionAttr()), comm); + mpi::AllReduceOp::create(iBuilder, TypeRange(), buffer1d, buffer1d, + getMPIReductionOp(adaptor.getReductionAttr()), + comm); // If the destination is a memref, cast it to a tensor if (isa<RankedTensorType>(op.getType())) - buffer = iBuilder.create<bufferization::ToTensorOp>(op.getType(), buffer, - true); + buffer = bufferization::ToTensorOp::create(iBuilder, op.getType(), buffer, + true); rewriter.replaceOp(op, buffer); return success(); @@ -676,9 +676,10 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value { if (auto value = dyn_cast<Value>(v)) return value; - return rewriter.create<arith::ConstantOp>( - loc, rewriter.getIndexAttr( - cast<IntegerAttr>(cast<Attribute>(v)).getInt())); + return arith::ConstantOp::create( + rewriter, loc, + rewriter.getIndexAttr( + cast<IntegerAttr>(cast<Attribute>(v)).getInt())); }; auto dest = adaptor.getDestination(); @@ -689,19 +690,18 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { auto mmemrefType = MemRefType::get( dstShape, cast<ShapedType>(array.getType()).getElementType()); array = - rewriter.create<bufferization::ToBufferOp>(loc, mmemrefType, array); + bufferization::ToBufferOp::create(rewriter, loc, mmemrefType, array); } auto rank = cast<ShapedType>(array.getType()).getRank(); auto opSplitAxes = adaptor.getSplitAxes().getAxes(); - auto mesh = adaptor.getMesh(); - auto meshOp = getMesh(op, symbolTableCollection); + auto grid = adaptor.getGrid(); + auto gridOp = getGrid(op, symbolTableCollection); // subviews need Index values for (auto &sz : haloSizes) { if (auto value = dyn_cast<Value>(sz)) - sz = - rewriter - .create<arith::IndexCastOp>(loc, rewriter.getIndexType(), value) - .getResult(); + sz = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(), + value) + .getResult(); } // most of the offset/size/stride data is the same for all dims @@ -713,7 +713,7 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { for (auto i = 0; i < rank; ++i) { auto s = dstShape[i]; if (ShapedType::isDynamic(s)) - shape[i] = rewriter.create<memref::DimOp>(loc, array, s).getResult(); + shape[i] = memref::DimOp::create(rewriter, loc, array, s).getResult(); else shape[i] = rewriter.getIndexAttr(s); @@ -723,12 +723,12 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { offsets[i] = haloSizes[currHaloDim * 2]; // prepare shape and offsets of highest dim's halo exchange - Value _haloSz = rewriter.create<arith::AddIOp>( - loc, toValue(haloSizes[currHaloDim * 2]), + Value _haloSz = arith::AddIOp::create( + rewriter, loc, toValue(haloSizes[currHaloDim * 2]), toValue(haloSizes[currHaloDim * 2 + 1])); // the halo shape of lower dims exlude the halos dimSizes[i] = - rewriter.create<arith::SubIOp>(loc, toValue(shape[i]), _haloSz) + arith::SubIOp::create(rewriter, loc, toValue(shape[i]), _haloSz) .getResult(); } else { dimSizes[i] = shape[i]; @@ -736,14 +736,14 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { } auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something - auto tag = rewriter.create<arith::ConstantOp>(loc, tagAttr); + auto tag = arith::ConstantOp::create(rewriter, loc, tagAttr); auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0 - auto zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr); + auto zero = arith::ConstantOp::create(rewriter, loc, zeroAttr); - SmallVector<Type> indexResultTypes(meshOp.getShape().size(), + SmallVector<Type> indexResultTypes(gridOp.getShape().size(), rewriter.getIndexType()); auto myMultiIndex = - rewriter.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh) + ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, grid) .getResult(); // traverse all split axes from high to low dim for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) { @@ -753,25 +753,26 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2); // Get the linearized ids of the neighbors (down and up) for the // given split - auto tmp = rewriter - .create<NeighborsLinearIndicesOp>(loc, mesh, myMultiIndex, - splitAxes) + auto tmp = NeighborsLinearIndicesOp::create(rewriter, loc, grid, + myMultiIndex, splitAxes) .getResults(); // MPI operates on i32... - Value neighbourIDs[2] = {rewriter.create<arith::IndexCastOp>( - loc, rewriter.getI32Type(), tmp[0]), - rewriter.create<arith::IndexCastOp>( - loc, rewriter.getI32Type(), tmp[1])}; + Value neighbourIDs[2] = { + arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(), + tmp[0]), + arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(), + tmp[1])}; auto lowerRecvOffset = rewriter.getIndexAttr(0); auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]); - auto upperRecvOffset = rewriter.create<arith::SubIOp>( - loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1])); - auto upperSendOffset = rewriter.create<arith::SubIOp>( - loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2])); + auto upperRecvOffset = + arith::SubIOp::create(rewriter, loc, toValue(shape[dim]), + toValue(haloSizes[currHaloDim * 2 + 1])); + auto upperSendOffset = arith::SubIOp::create( + rewriter, loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2])); - Value commWorld = rewriter.create<mpi::CommWorldOp>( - loc, mpi::CommType::get(op->getContext())); + Value commWorld = mpi::CommWorldOp::create( + rewriter, loc, mpi::CommType::get(op->getContext())); // Make sure we send/recv in a way that does not lead to a dead-lock. // The current approach is by far not optimal, this should be at least @@ -784,40 +785,41 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1] : haloSizes[currHaloDim * 2]; // Check if we need to send and/or receive - // Processes on the mesh borders have only one neighbor + // Processes on the grid borders have only one neighbor auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1]; auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0]; - auto hasFrom = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::sge, from, zero); - auto hasTo = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::sge, to, zero); - auto buffer = rewriter.create<memref::AllocOp>( - loc, dimSizes, cast<ShapedType>(array.getType()).getElementType()); + auto hasFrom = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sge, from, zero); + auto hasTo = arith::CmpIOp::create(rewriter, loc, + arith::CmpIPredicate::sge, to, zero); + auto buffer = memref::AllocOp::create( + rewriter, loc, dimSizes, + cast<ShapedType>(array.getType()).getElementType()); // if has neighbor: copy halo data from array to buffer and send - rewriter.create<scf::IfOp>( - loc, hasTo, [&](OpBuilder &builder, Location loc) { + scf::IfOp::create( + rewriter, loc, hasTo, [&](OpBuilder &builder, Location loc) { offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset) : OpFoldResult(upperSendOffset); - auto subview = builder.create<memref::SubViewOp>( - loc, array, offsets, dimSizes, strides); - builder.create<memref::CopyOp>(loc, subview, buffer); - builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to, - commWorld); - builder.create<scf::YieldOp>(loc); + auto subview = memref::SubViewOp::create( + builder, loc, array, offsets, dimSizes, strides); + memref::CopyOp::create(builder, loc, subview, buffer); + mpi::SendOp::create(builder, loc, TypeRange{}, buffer, tag, to, + commWorld); + scf::YieldOp::create(builder, loc); }); // if has neighbor: receive halo data into buffer and copy to array - rewriter.create<scf::IfOp>( - loc, hasFrom, [&](OpBuilder &builder, Location loc) { + scf::IfOp::create( + rewriter, loc, hasFrom, [&](OpBuilder &builder, Location loc) { offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset) : OpFoldResult(lowerRecvOffset); - builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from, - commWorld); - auto subview = builder.create<memref::SubViewOp>( - loc, array, offsets, dimSizes, strides); - builder.create<memref::CopyOp>(loc, buffer, subview); - builder.create<scf::YieldOp>(loc); + mpi::RecvOp::create(builder, loc, TypeRange{}, buffer, tag, from, + commWorld); + auto subview = memref::SubViewOp::create( + builder, loc, array, offsets, dimSizes, strides); + memref::CopyOp::create(builder, loc, buffer, subview); + scf::YieldOp::create(builder, loc); }); - rewriter.create<memref::DeallocOp>(loc, buffer); + memref::DeallocOp::create(rewriter, loc, buffer); offsets[dim] = orgOffset; }; @@ -825,16 +827,17 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown]; Value haloSz = dyn_cast<Value>(v); if (!haloSz) - haloSz = rewriter.create<arith::ConstantOp>( - loc, rewriter.getI32IntegerAttr( - cast<IntegerAttr>(cast<Attribute>(v)).getInt())); - auto hasSize = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::sgt, haloSz, zero); - rewriter.create<scf::IfOp>(loc, hasSize, - [&](OpBuilder &builder, Location loc) { - genSendRecv(upOrDown > 0); - builder.create<scf::YieldOp>(loc); - }); + haloSz = arith::ConstantOp::create( + rewriter, loc, + rewriter.getI32IntegerAttr( + cast<IntegerAttr>(cast<Attribute>(v)).getInt())); + auto hasSize = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sgt, haloSz, zero); + scf::IfOp::create(rewriter, loc, hasSize, + [&](OpBuilder &builder, Location loc) { + genSendRecv(upOrDown > 0); + scf::YieldOp::create(builder, loc); + }); }; doSendRecv(0); @@ -852,16 +855,16 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { rewriter.replaceOp(op, array); } else { assert(isa<RankedTensorType>(op.getResult().getType())); - rewriter.replaceOp(op, rewriter.create<bufferization::ToTensorOp>( - loc, op.getResult().getType(), array, + rewriter.replaceOp(op, bufferization::ToTensorOp::create( + rewriter, loc, op.getResult().getType(), array, /*restrict=*/true, /*writable=*/true)); } return success(); } }; -struct ConvertMeshToMPIPass - : public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> { +struct ConvertShardToMPIPass + : public impl::ConvertShardToMPIPassBase<ConvertShardToMPIPass> { using Base::Base; /// Run the dialect converter on the module. @@ -870,12 +873,12 @@ struct ConvertMeshToMPIPass RewritePatternSet patterns(ctxt); ConversionTarget target(getContext()); - // Define a type converter to convert mesh::ShardingType, + // Define a type converter to convert shard::ShardingType, // mostly for use in return operations. TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); - // convert mesh::ShardingType to a tuple of RankedTensorTypes + // convert shard::ShardingType to a tuple of RankedTensorTypes typeConverter.addConversion( [](ShardingType type, SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> { @@ -911,10 +914,10 @@ struct ConvertMeshToMPIPass return results; }); - // No mesh dialect should left after conversion... - target.addIllegalDialect<mesh::MeshDialect>(); - // ...except the global MeshOp. MeshShapeOp which will get folded later. - target.addLegalOp<mesh::MeshOp, mesh::MeshShapeOp>(); + // No shard dialect should left after conversion... + target.addIllegalDialect<shard::ShardDialect>(); + // ...except the global GridOp. GridShapeOp which will get folded later. + target.addLegalOp<shard::GridOp, shard::GridShapeOp>(); // Allow all the stuff that our patterns will convert to target.addLegalDialect< BuiltinDialect, mpi::MPIDialect, scf::SCFDialect, arith::ArithDialect, @@ -942,7 +945,7 @@ struct ConvertMeshToMPIPass // Folding patterns cannot be mixed with conversion patterns -> extra pass. patterns.clear(); SymbolTableCollection symbolTableCollection; - mlir::mesh::populateFoldingPatterns(patterns, symbolTableCollection); + mlir::shard::populateFoldingPatterns(patterns, symbolTableCollection); (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp index 2c4d275..f24972f6 100644 --- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp +++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp @@ -68,10 +68,10 @@ public: // We could use the initializer directly; but certain driver compilers // have bugs dealing with that. So for now, use spirv.Store for // initialization. - varOp = rewriter.create<spirv::VariableOp>(loc, varType, - spirv::StorageClass::Function, - /*initializer=*/nullptr); - rewriter.create<spirv::StoreOp>(loc, varOp, adaptor.getTensor()); + varOp = spirv::VariableOp::create(rewriter, loc, varType, + spirv::StorageClass::Function, + /*initializer=*/nullptr); + spirv::StoreOp::create(rewriter, loc, varOp, adaptor.getTensor()); } else { // Need to store the value to the local variable. It's questionable // whether we want to support such case though. @@ -83,7 +83,7 @@ public: Value index = spirv::linearizeIndex(adaptor.getIndices(), strides, /*offset=*/0, indexType, loc, rewriter); - auto acOp = rewriter.create<spirv::AccessChainOp>(loc, varOp, index); + auto acOp = spirv::AccessChainOp::create(rewriter, loc, varOp, index); rewriter.replaceOpWithNewOp<spirv::LoadOp>(extractOp, acOp); diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp index f07386e..8cd650e 100644 --- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp +++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp @@ -41,6 +41,7 @@ class ConvertTensorToSPIRVPass SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); RewritePatternSet patterns(context); diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp index 40ad6361..044b725 100644 --- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp +++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp @@ -51,8 +51,8 @@ TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) { Value getConstantValue(Location loc, Type type, int64_t value, PatternRewriter &rewriter) { - return rewriter.create<arith::ConstantOp>( - loc, getConstantAttr(type, value, rewriter)); + return arith::ConstantOp::create(rewriter, loc, + getConstantAttr(type, value, rewriter)); } // This converts the TOSA ApplyScale operator to a set of arithmetic ops, @@ -82,41 +82,41 @@ public: Value one64 = getConstantValue(loc, i64Ty, 1, rewriter); Value thirtyOne32 = getConstantValue(loc, i32Ty, 31, rewriter); - Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift()); + Value shift32 = arith::ExtUIOp::create(rewriter, loc, i32Ty, op.getShift()); // Compute the multiplication in 64-bits then select the high / low parts. Value value64 = value; if (getElementTypeOrSelf(valueTy) != rewriter.getI64Type()) - value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value); + value64 = arith::ExtSIOp::create(rewriter, loc, i64Ty, value); Value multiplier64 = - rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32); + arith::ExtSIOp::create(rewriter, loc, i64Ty, multiplier32); Value multiply64 = - rewriter.create<arith::MulIOp>(loc, value64, multiplier64); + arith::MulIOp::create(rewriter, loc, value64, multiplier64); // Apply normal rounding. - Value shift64 = rewriter.create<arith::ExtUIOp>(loc, i64Ty, shift32); - Value round = rewriter.create<arith::ShLIOp>(loc, one64, shift64); - round = rewriter.create<arith::ShRUIOp>(loc, round, one64); - multiply64 = rewriter.create<arith::AddIOp>(loc, multiply64, round); + Value shift64 = arith::ExtUIOp::create(rewriter, loc, i64Ty, shift32); + Value round = arith::ShLIOp::create(rewriter, loc, one64, shift64); + round = arith::ShRUIOp::create(rewriter, loc, round, one64); + multiply64 = arith::AddIOp::create(rewriter, loc, multiply64, round); // Apply double rounding if necessary. if (op.getRoundingMode() == "DOUBLE_ROUND") { int64_t roundInt = 1 << 30; Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter); Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter); - Value positive = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::sge, value, zero); + Value positive = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sge, value, zero); Value dir = - rewriter.create<arith::SelectOp>(loc, positive, roundUp, roundDown); - Value val = rewriter.create<arith::AddIOp>(loc, dir, multiply64); - Value valid = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32); + arith::SelectOp::create(rewriter, loc, positive, roundUp, roundDown); + Value val = arith::AddIOp::create(rewriter, loc, dir, multiply64); + Value valid = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32); multiply64 = - rewriter.create<arith::SelectOp>(loc, valid, val, multiply64); + arith::SelectOp::create(rewriter, loc, valid, val, multiply64); } - Value result64 = rewriter.create<arith::ShRSIOp>(loc, multiply64, shift64); - Value result32 = rewriter.create<arith::TruncIOp>(loc, i32Ty, result64); + Value result64 = arith::ShRSIOp::create(rewriter, loc, multiply64, shift64); + Value result32 = arith::TruncIOp::create(rewriter, loc, i32Ty, result64); rewriter.replaceOp(op, result32); return success(); @@ -146,7 +146,7 @@ public: Value value32 = op.getValue(); Value multiplier32 = op.getMultiplier(); - Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift()); + Value shift32 = arith::ExtUIOp::create(rewriter, loc, i32Ty, op.getShift()); // Constants used during the scaling operation. Value zero32 = getConstantValue(loc, i32Ty, 0, rewriter); @@ -158,86 +158,87 @@ public: // Compute the multiplication in 64-bits then select the high / low parts. // Grab out the high/low of the computation auto value64 = - rewriter.create<arith::MulSIExtendedOp>(loc, value32, multiplier32); + arith::MulSIExtendedOp::create(rewriter, loc, value32, multiplier32); Value low32 = value64.getLow(); Value high32 = value64.getHigh(); // Determine the direction and amount to shift the high bits. - Value shiftOver32 = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32); - Value roundHighBits = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32); + Value shiftOver32 = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32); + Value roundHighBits = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32); Value shiftHighL = - rewriter.create<arith::SubIOp>(loc, thirtyTwo32, shift32); + arith::SubIOp::create(rewriter, loc, thirtyTwo32, shift32); Value shiftHighR = - rewriter.create<arith::SubIOp>(loc, shift32, thirtyTwo32); + arith::SubIOp::create(rewriter, loc, shift32, thirtyTwo32); shiftHighL = - rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, shiftHighL); + arith::SelectOp::create(rewriter, loc, shiftOver32, zero32, shiftHighL); shiftHighR = - rewriter.create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32); + arith::SelectOp::create(rewriter, loc, shiftOver32, shiftHighR, zero32); // Conditionally perform our double round. if (op.getRoundingMode() == "DOUBLE_ROUND") { Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter); - Value valuePositive = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::sge, value32, zero32); + Value valuePositive = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sge, value32, zero32); - Value roundDir = - rewriter.create<arith::SelectOp>(loc, valuePositive, one32, negOne32); + Value roundDir = arith::SelectOp::create(rewriter, loc, valuePositive, + one32, negOne32); roundDir = - rewriter.create<arith::SelectOp>(loc, shiftOver32, roundDir, zero32); + arith::SelectOp::create(rewriter, loc, shiftOver32, roundDir, zero32); - Value shiftLow = rewriter.create<arith::ShRUIOp>(loc, low32, thirty32); - Value rounded = rewriter.create<arith::AddIOp>(loc, shiftLow, roundDir); - Value carry = rewriter.create<arith::ShRSIOp>(loc, rounded, two32); + Value shiftLow = arith::ShRUIOp::create(rewriter, loc, low32, thirty32); + Value rounded = arith::AddIOp::create(rewriter, loc, shiftLow, roundDir); + Value carry = arith::ShRSIOp::create(rewriter, loc, rounded, two32); Value shiftRound = - rewriter.create<arith::ShLIOp>(loc, roundDir, thirty32); + arith::ShLIOp::create(rewriter, loc, roundDir, thirty32); - low32 = rewriter.create<arith::AddIOp>(loc, low32, shiftRound); - high32 = rewriter.create<arith::AddIOp>(loc, high32, carry); + low32 = arith::AddIOp::create(rewriter, loc, low32, shiftRound); + high32 = arith::AddIOp::create(rewriter, loc, high32, carry); } // Conditionally apply rounding in the low bits. { - Value shiftSubOne = rewriter.create<arith::SubIOp>(loc, shift32, one32); - Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne); - roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, zero32, - roundBit); - - Value newLow32 = rewriter.create<arith::AddIOp>(loc, low32, roundBit); - Value wasRounded = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::ugt, low32, newLow32); + Value shiftSubOne = arith::SubIOp::create(rewriter, loc, shift32, one32); + Value roundBit = arith::ShLIOp::create(rewriter, loc, one32, shiftSubOne); + roundBit = arith::SelectOp::create(rewriter, loc, roundHighBits, zero32, + roundBit); + + Value newLow32 = arith::AddIOp::create(rewriter, loc, low32, roundBit); + Value wasRounded = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::ugt, low32, newLow32); low32 = newLow32; - Value rounded32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, wasRounded); - high32 = rewriter.create<arith::AddIOp>(loc, high32, rounded32); + Value rounded32 = + arith::ExtUIOp::create(rewriter, loc, i32Ty, wasRounded); + high32 = arith::AddIOp::create(rewriter, loc, high32, rounded32); } // Conditionally apply rounding in the high bits. { Value shiftSubOne = - rewriter.create<arith::SubIOp>(loc, shiftHighR, one32); - Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne); - roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, roundBit, - zero32); - high32 = rewriter.create<arith::AddIOp>(loc, high32, roundBit); + arith::SubIOp::create(rewriter, loc, shiftHighR, one32); + Value roundBit = arith::ShLIOp::create(rewriter, loc, one32, shiftSubOne); + roundBit = arith::SelectOp::create(rewriter, loc, roundHighBits, roundBit, + zero32); + high32 = arith::AddIOp::create(rewriter, loc, high32, roundBit); } // Combine the correct high/low bits into the final rescale result. - high32 = rewriter.create<arith::ShLIOp>(loc, high32, shiftHighL); - high32 = rewriter.create<arith::ShRSIOp>(loc, high32, shiftHighR); - low32 = rewriter.create<arith::ShRUIOp>(loc, low32, shift32); - low32 = rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, low32); + high32 = arith::ShLIOp::create(rewriter, loc, high32, shiftHighL); + high32 = arith::ShRSIOp::create(rewriter, loc, high32, shiftHighR); + low32 = arith::ShRUIOp::create(rewriter, loc, low32, shift32); + low32 = arith::SelectOp::create(rewriter, loc, shiftOver32, zero32, low32); // Apply the rounding behavior and shift to the final alignment. - Value result = rewriter.create<arith::AddIOp>(loc, low32, high32); + Value result = arith::AddIOp::create(rewriter, loc, low32, high32); // Truncate if necessary. if (!getElementTypeOrSelf(resultTy).isInteger(32)) { - result = rewriter.create<arith::TruncIOp>(loc, resultTy, result); + result = arith::TruncIOp::create(rewriter, loc, resultTy, result); } rewriter.replaceOp(op, result); diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 2f608bb..0e3de06 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -22,7 +22,6 @@ #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" @@ -70,14 +69,14 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter, return result; // Unordered comparison of NaN against itself will always return true. - Value lhsIsNaN = rewriter.create<arith::CmpFOp>( - op.getLoc(), arith::CmpFPredicate::UNO, lhs, lhs); - Value rhsIsNaN = rewriter.create<arith::CmpFOp>( - op.getLoc(), arith::CmpFPredicate::UNO, rhs, rhs); + Value lhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(), + arith::CmpFPredicate::UNO, lhs, lhs); + Value rhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(), + arith::CmpFPredicate::UNO, rhs, rhs); Value rhsOrResult = - rewriter.create<arith::SelectOp>(op.getLoc(), lhsIsNaN, rhs, result); - return rewriter.create<arith::SelectOp>(op.getLoc(), rhsIsNaN, lhs, - rhsOrResult); + arith::SelectOp::create(rewriter, op.getLoc(), lhsIsNaN, rhs, result); + return arith::SelectOp::create(rewriter, op.getLoc(), rhsIsNaN, lhs, + rhsOrResult); } static Value createLinalgBodyCalculationForElementwiseOp( @@ -89,38 +88,38 @@ static Value createLinalgBodyCalculationForElementwiseOp( // tosa::AbsOp if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy)) - return rewriter.create<math::AbsFOp>(loc, resultTypes, args); + return math::AbsFOp::create(rewriter, loc, resultTypes, args); if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) { - auto zero = rewriter.create<arith::ConstantOp>( - loc, rewriter.getZeroAttr(elementTy)); - auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]); - return rewriter.create<arith::MaxSIOp>(loc, args[0], neg); + auto zero = arith::ConstantOp::create(rewriter, loc, + rewriter.getZeroAttr(elementTy)); + auto neg = arith::SubIOp::create(rewriter, loc, zero, args[0]); + return arith::MaxSIOp::create(rewriter, loc, args[0], neg); } // tosa::AddOp if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy)) - return rewriter.create<arith::AddFOp>(loc, resultTypes, args); + return arith::AddFOp::create(rewriter, loc, resultTypes, args); if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy)) - return rewriter.create<arith::AddIOp>(loc, resultTypes, args); + return arith::AddIOp::create(rewriter, loc, resultTypes, args); // tosa::SubOp if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy)) - return rewriter.create<arith::SubFOp>(loc, resultTypes, args); + return arith::SubFOp::create(rewriter, loc, resultTypes, args); if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy)) - return rewriter.create<arith::SubIOp>(loc, resultTypes, args); + return arith::SubIOp::create(rewriter, loc, resultTypes, args); // tosa::IntDivOp if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy)) - return rewriter.create<arith::DivSIOp>(loc, resultTypes, args); + return arith::DivSIOp::create(rewriter, loc, resultTypes, args); // tosa::ReciprocalOp if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) { auto one = - rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1)); - return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]); + arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1)); + return arith::DivFOp::create(rewriter, loc, resultTypes, one, args[0]); } // tosa::MulOp @@ -140,7 +139,8 @@ static Value createLinalgBodyCalculationForElementwiseOp( "Cannot have shift value for float"); return nullptr; } - return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]); + return arith::MulFOp::create(rewriter, loc, resultTypes, args[0], + args[1]); } if (isa<IntegerType>(elementTy)) { @@ -149,21 +149,21 @@ static Value createLinalgBodyCalculationForElementwiseOp( if (shift > 0) { auto shiftConst = - rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8); + arith::ConstantIntOp::create(rewriter, loc, shift, /*bitwidth=*/8); if (!a.getType().isInteger(32)) - a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a); + a = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), a); if (!b.getType().isInteger(32)) - b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b); + b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b); - auto result = rewriter.create<tosa::ApplyScaleOp>( - loc, rewriter.getI32Type(), a, b, shiftConst, + auto result = tosa::ApplyScaleOp::create( + rewriter, loc, rewriter.getI32Type(), a, b, shiftConst, rewriter.getStringAttr("SINGLE_ROUND")); if (elementTy.isInteger(32)) return result; - return rewriter.create<arith::TruncIOp>(loc, elementTy, result); + return arith::TruncIOp::create(rewriter, loc, elementTy, result); } int aWidth = a.getType().getIntOrFloatBitWidth(); @@ -171,11 +171,11 @@ static Value createLinalgBodyCalculationForElementwiseOp( int cWidth = resultTypes[0].getIntOrFloatBitWidth(); if (aWidth < cWidth) - a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a); + a = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], a); if (bWidth < cWidth) - b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b); + b = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], b); - return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b); + return arith::MulIOp::create(rewriter, loc, resultTypes, a, b); } } @@ -201,14 +201,14 @@ static Value createLinalgBodyCalculationForElementwiseOp( int64_t outZp = *maybeOutZp; if (isa<FloatType>(elementTy)) - return rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]); + return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]); if (isa<IntegerType>(elementTy)) { if (!inZp && !outZp) { - auto constant = rewriter.create<arith::ConstantOp>( - loc, IntegerAttr::get(elementTy, 0)); - return rewriter.create<arith::SubIOp>(loc, resultTypes, constant, - args[0]); + auto constant = arith::ConstantOp::create( + rewriter, loc, IntegerAttr::get(elementTy, 0)); + return arith::SubIOp::create(rewriter, loc, resultTypes, constant, + args[0]); } // Compute the maximum value that can occur in the intermediate buffer. @@ -231,214 +231,214 @@ static Value createLinalgBodyCalculationForElementwiseOp( } Type intermediateType = rewriter.getIntegerType(intermediateBitWidth); - Value zpAddValue = rewriter.create<arith::ConstantOp>( - loc, rewriter.getIntegerAttr(intermediateType, zpAdd)); + Value zpAddValue = arith::ConstantOp::create( + rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd)); // The negation can be applied by doing: // outputValue = inZp + outZp - inputValue auto ext = - rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]); - auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext); + arith::ExtSIOp::create(rewriter, loc, intermediateType, args[0]); + auto sub = arith::SubIOp::create(rewriter, loc, zpAddValue, ext); // Clamp to the negation range. - Value min = rewriter.create<arith::ConstantIntOp>( - loc, intermediateType, + Value min = arith::ConstantIntOp::create( + rewriter, loc, intermediateType, APInt::getSignedMinValue(inputBitWidth).getSExtValue()); - Value max = rewriter.create<arith::ConstantIntOp>( - loc, intermediateType, + Value max = arith::ConstantIntOp::create( + rewriter, loc, intermediateType, APInt::getSignedMaxValue(inputBitWidth).getSExtValue()); auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false); // Truncate to the final value. - return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp); + return arith::TruncIOp::create(rewriter, loc, elementTy, clamp); } } // tosa::BitwiseAndOp if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy)) - return rewriter.create<arith::AndIOp>(loc, resultTypes, args); + return arith::AndIOp::create(rewriter, loc, resultTypes, args); // tosa::BitwiseOrOp if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy)) - return rewriter.create<arith::OrIOp>(loc, resultTypes, args); + return arith::OrIOp::create(rewriter, loc, resultTypes, args); // tosa::BitwiseNotOp if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) { auto allOnesAttr = rewriter.getIntegerAttr( elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth())); - auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr); - return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes); + auto allOnes = arith::ConstantOp::create(rewriter, loc, allOnesAttr); + return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], allOnes); } // tosa::BitwiseXOrOp if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy)) - return rewriter.create<arith::XOrIOp>(loc, resultTypes, args); + return arith::XOrIOp::create(rewriter, loc, resultTypes, args); // tosa::LogicalLeftShiftOp if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy)) - return rewriter.create<arith::ShLIOp>(loc, resultTypes, args); + return arith::ShLIOp::create(rewriter, loc, resultTypes, args); // tosa::LogicalRightShiftOp if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy)) - return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args); + return arith::ShRUIOp::create(rewriter, loc, resultTypes, args); // tosa::ArithmeticRightShiftOp if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) { - auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args); + auto result = arith::ShRSIOp::create(rewriter, loc, resultTypes, args); auto round = cast<BoolAttr>(op->getAttr("round")).getValue(); if (!round) { return result; } Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1); - auto one = - rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1)); - auto zero = - rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0)); + auto one = arith::ConstantOp::create(rewriter, loc, + IntegerAttr::get(elementTy, 1)); + auto zero = arith::ConstantOp::create(rewriter, loc, + IntegerAttr::get(elementTy, 0)); auto i1one = - rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1)); + arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 1)); // Checking that input2 != 0 - auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::sgt, args[1], zero); + auto shiftValueGreaterThanZero = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sgt, args[1], zero); // Checking for the last bit of input1 to be 1 auto subtract = - rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one); + arith::SubIOp::create(rewriter, loc, resultTypes, args[1], one); auto shifted = - rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract) + arith::ShRSIOp::create(rewriter, loc, resultTypes, args[0], subtract) ->getResults(); - auto truncated = rewriter.create<arith::TruncIOp>( - loc, i1Ty, shifted, ArrayRef<NamedAttribute>()); + auto truncated = arith::TruncIOp::create(rewriter, loc, i1Ty, shifted, + ArrayRef<NamedAttribute>()); auto isInputOdd = - rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one); + arith::AndIOp::create(rewriter, loc, i1Ty, truncated, i1one); - auto shouldRound = rewriter.create<arith::AndIOp>( - loc, i1Ty, shiftValueGreaterThanZero, isInputOdd); + auto shouldRound = arith::AndIOp::create( + rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd); auto extended = - rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound); - return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended); + arith::ExtUIOp::create(rewriter, loc, resultTypes, shouldRound); + return arith::AddIOp::create(rewriter, loc, resultTypes, result, extended); } // tosa::ClzOp if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) { - return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]); + return math::CountLeadingZerosOp::create(rewriter, loc, elementTy, args[0]); } // tosa::LogicalAnd if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1)) - return rewriter.create<arith::AndIOp>(loc, resultTypes, args); + return arith::AndIOp::create(rewriter, loc, resultTypes, args); // tosa::LogicalNot if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) { - auto one = rewriter.create<arith::ConstantOp>( - loc, rewriter.getIntegerAttr(elementTy, 1)); - return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one); + auto one = arith::ConstantOp::create(rewriter, loc, + rewriter.getIntegerAttr(elementTy, 1)); + return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], one); } // tosa::LogicalOr if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1)) - return rewriter.create<arith::OrIOp>(loc, resultTypes, args); + return arith::OrIOp::create(rewriter, loc, resultTypes, args); // tosa::LogicalXor if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1)) - return rewriter.create<arith::XOrIOp>(loc, resultTypes, args); + return arith::XOrIOp::create(rewriter, loc, resultTypes, args); // tosa::PowOp if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy)) - return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args); + return mlir::math::PowFOp::create(rewriter, loc, resultTypes, args); // tosa::RsqrtOp if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy)) - return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args); + return mlir::math::RsqrtOp::create(rewriter, loc, resultTypes, args); // tosa::LogOp if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy)) - return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args); + return mlir::math::LogOp::create(rewriter, loc, resultTypes, args); // tosa::ExpOp if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy)) - return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args); + return mlir::math::ExpOp::create(rewriter, loc, resultTypes, args); // tosa::SinOp if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy)) - return rewriter.create<mlir::math::SinOp>(loc, resultTypes, args); + return mlir::math::SinOp::create(rewriter, loc, resultTypes, args); // tosa::CosOp if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy)) - return rewriter.create<mlir::math::CosOp>(loc, resultTypes, args); + return mlir::math::CosOp::create(rewriter, loc, resultTypes, args); // tosa::TanhOp if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy)) - return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args); + return mlir::math::TanhOp::create(rewriter, loc, resultTypes, args); // tosa::ErfOp if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy)) - return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args); + return mlir::math::ErfOp::create(rewriter, loc, resultTypes, args); // tosa::GreaterOp if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy)) - return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, - args[0], args[1]); + return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGT, + args[0], args[1]); if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger()) - return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, - args[0], args[1]); + return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sgt, + args[0], args[1]); // tosa::GreaterEqualOp if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy)) - return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE, - args[0], args[1]); + return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGE, + args[0], args[1]); if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger()) - return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, - args[0], args[1]); + return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge, + args[0], args[1]); // tosa::EqualOp if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy)) - return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ, - args[0], args[1]); + return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OEQ, + args[0], args[1]); if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger()) - return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, - args[0], args[1]); + return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, + args[0], args[1]); // tosa::SelectOp if (isa<tosa::SelectOp>(op)) { elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType(); if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy)) - return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]); + return arith::SelectOp::create(rewriter, loc, args[0], args[1], args[2]); } // tosa::MaximumOp if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) { - auto max = rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]); + auto max = arith::MaximumFOp::create(rewriter, loc, args[0], args[1]); return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MaximumOp>(op), rewriter, args[0], args[1], max); } if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) { - return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]); + return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]); } // tosa::MinimumOp if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) { - auto min = rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]); + auto min = arith::MinimumFOp::create(rewriter, loc, args[0], args[1]); return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MinimumOp>(op), rewriter, args[0], args[1], min); } if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) { - return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]); + return arith::MinSIOp::create(rewriter, loc, args[0], args[1]); } // tosa::CeilOp if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy)) - return rewriter.create<math::CeilOp>(loc, resultTypes, args); + return math::CeilOp::create(rewriter, loc, resultTypes, args); // tosa::FloorOp if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy)) - return rewriter.create<math::FloorOp>(loc, resultTypes, args); + return math::FloorOp::create(rewriter, loc, resultTypes, args); // tosa::ClampOp if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) { @@ -449,10 +449,10 @@ static Value createLinalgBodyCalculationForElementwiseOp( APFloat::rmNearestTiesToEven, &losesInfo); maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(), APFloat::rmNearestTiesToEven, &losesInfo); - auto min = rewriter.create<arith::ConstantOp>( - loc, elementTy, rewriter.getFloatAttr(elementTy, minApf)); - auto max = rewriter.create<arith::ConstantOp>( - loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf)); + auto min = arith::ConstantOp::create( + rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, minApf)); + auto max = arith::ConstantOp::create( + rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf)); auto result = clampFloatHelper(loc, args[0], min, max, rewriter); auto clampOp = llvm::cast<tosa::ClampOp>(op); @@ -478,11 +478,11 @@ static Value createLinalgBodyCalculationForElementwiseOp( // return init if x == NaN else result // Unordered comparison of NaN against itself will always return true. - Value isNaN = rewriter.create<arith::CmpFOp>( - op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]); + Value isNaN = arith::CmpFOp::create( + rewriter, op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]); // TOSA specifies that in "ignore" NaN mode the result is "min" if the input // is NaN. - return rewriter.create<arith::SelectOp>(op->getLoc(), isNaN, min, result); + return arith::SelectOp::create(rewriter, op->getLoc(), isNaN, min, result); } if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) { @@ -515,10 +515,10 @@ static Value createLinalgBodyCalculationForElementwiseOp( min = std::min(min, maxRepresentable); max = std::min(max, maxRepresentable); - auto minVal = rewriter.create<arith::ConstantIntOp>( - loc, min, intTy.getIntOrFloatBitWidth()); - auto maxVal = rewriter.create<arith::ConstantIntOp>( - loc, max, intTy.getIntOrFloatBitWidth()); + auto minVal = arith::ConstantIntOp::create(rewriter, loc, min, + intTy.getIntOrFloatBitWidth()); + auto maxVal = arith::ConstantIntOp::create(rewriter, loc, max, + intTy.getIntOrFloatBitWidth()); return clampIntHelper(loc, args[0], minVal, maxVal, rewriter, intTy.isUnsignedInteger()); } @@ -526,11 +526,11 @@ static Value createLinalgBodyCalculationForElementwiseOp( // tosa::SigmoidOp if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) { auto one = - rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1)); - auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]); - auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate); - auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one); - return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added); + arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1)); + auto negate = arith::NegFOp::create(rewriter, loc, resultTypes, args[0]); + auto exp = mlir::math::ExpOp::create(rewriter, loc, resultTypes, negate); + auto added = arith::AddFOp::create(rewriter, loc, resultTypes, exp, one); + return arith::DivFOp::create(rewriter, loc, resultTypes, one, added); } // tosa::CastOp @@ -549,50 +549,49 @@ static Value createLinalgBodyCalculationForElementwiseOp( return args.front(); if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend) - return rewriter.create<arith::ExtFOp>(loc, resultTypes, args, - ArrayRef<NamedAttribute>()); + return arith::ExtFOp::create(rewriter, loc, resultTypes, args, + ArrayRef<NamedAttribute>()); if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend) - return rewriter.create<arith::TruncFOp>(loc, resultTypes, args, - ArrayRef<NamedAttribute>()); + return arith::TruncFOp::create(rewriter, loc, resultTypes, args, + ArrayRef<NamedAttribute>()); // 1-bit integers need to be treated as signless. if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy)) - return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args, - ArrayRef<NamedAttribute>()); + return arith::UIToFPOp::create(rewriter, loc, resultTypes, args, + ArrayRef<NamedAttribute>()); if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend) - return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args, - ArrayRef<NamedAttribute>()); + return arith::ExtUIOp::create(rewriter, loc, resultTypes, args, + ArrayRef<NamedAttribute>()); // Unsigned integers need an unrealized cast so that they can be passed // to UIToFP. if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) { auto unrealizedCast = - rewriter - .create<UnrealizedConversionCastOp>( - loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), - args[0]) + UnrealizedConversionCastOp::create( + rewriter, loc, + rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), args[0]) .getResult(0); - return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0], - unrealizedCast); + return arith::UIToFPOp::create(rewriter, loc, resultTypes[0], + unrealizedCast); } // All other si-to-fp conversions should be handled by SIToFP. if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy)) - return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args, - ArrayRef<NamedAttribute>()); + return arith::SIToFPOp::create(rewriter, loc, resultTypes, args, + ArrayRef<NamedAttribute>()); // Casting to boolean, floats need to only be checked as not-equal to zero. if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) { - Value zero = rewriter.create<arith::ConstantOp>( - loc, rewriter.getFloatAttr(srcTy, 0.0)); - return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, - args.front(), zero); + Value zero = arith::ConstantOp::create(rewriter, loc, + rewriter.getFloatAttr(srcTy, 0.0)); + return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::UNE, + args.front(), zero); } if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) { - auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]); + auto rounded = math::RoundEvenOp::create(rewriter, loc, args[0]); const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics(); // Check whether neither int min nor int max can be represented in the @@ -601,37 +600,42 @@ static Value createLinalgBodyCalculationForElementwiseOp( APFloat::semanticsMaxExponent(fltSemantics)) { // Use cmp + select to replace infinites by int min / int max. Other // integral values can be represented in the integer space. - auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded); - auto posInf = rewriter.create<arith::ConstantOp>( - loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy), - APFloat::getInf(fltSemantics))); - auto negInf = rewriter.create<arith::ConstantOp>( - loc, rewriter.getFloatAttr( - getElementTypeOrSelf(srcTy), - APFloat::getInf(fltSemantics, /*Negative=*/true))); - auto overflow = rewriter.create<arith::CmpFOp>( - loc, arith::CmpFPredicate::UEQ, rounded, posInf); - auto underflow = rewriter.create<arith::CmpFOp>( - loc, arith::CmpFPredicate::UEQ, rounded, negInf); - auto intMin = rewriter.create<arith::ConstantOp>( - loc, rewriter.getIntegerAttr( - getElementTypeOrSelf(dstTy), - APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()))); - auto intMax = rewriter.create<arith::ConstantOp>( - loc, rewriter.getIntegerAttr( - getElementTypeOrSelf(dstTy), - APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()))); + auto conv = arith::FPToSIOp::create(rewriter, loc, dstTy, rounded); + auto posInf = arith::ConstantOp::create( + rewriter, loc, + rewriter.getFloatAttr(getElementTypeOrSelf(srcTy), + APFloat::getInf(fltSemantics))); + auto negInf = arith::ConstantOp::create( + rewriter, loc, + rewriter.getFloatAttr( + getElementTypeOrSelf(srcTy), + APFloat::getInf(fltSemantics, /*Negative=*/true))); + auto overflow = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::UEQ, rounded, posInf); + auto underflow = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::UEQ, rounded, negInf); + auto intMin = arith::ConstantOp::create( + rewriter, loc, + rewriter.getIntegerAttr( + getElementTypeOrSelf(dstTy), + APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()))); + auto intMax = arith::ConstantOp::create( + rewriter, loc, + rewriter.getIntegerAttr( + getElementTypeOrSelf(dstTy), + APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()))); auto maxClamped = - rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv); - return rewriter.create<arith::SelectOp>(loc, underflow, intMin, - maxClamped); + arith::SelectOp::create(rewriter, loc, overflow, intMax, conv); + return arith::SelectOp::create(rewriter, loc, underflow, intMin, + maxClamped); } - auto intMinFP = rewriter.create<arith::ConstantOp>( - loc, rewriter.getFloatAttr( - getElementTypeOrSelf(srcTy), - APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()) - .getSExtValue())); + auto intMinFP = arith::ConstantOp::create( + rewriter, loc, + rewriter.getFloatAttr( + getElementTypeOrSelf(srcTy), + APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()) + .getSExtValue())); // Check whether the mantissa has enough bits to represent int max. if (cast<FloatType>(srcTy).getFPMantissaWidth() >= @@ -640,58 +644,61 @@ static Value createLinalgBodyCalculationForElementwiseOp( // consists of a single leading bit. Therefore we can clamp the input // in the floating-point domain. - auto intMaxFP = rewriter.create<arith::ConstantOp>( - loc, rewriter.getFloatAttr( - getElementTypeOrSelf(srcTy), - APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) - .getSExtValue())); + auto intMaxFP = arith::ConstantOp::create( + rewriter, loc, + rewriter.getFloatAttr( + getElementTypeOrSelf(srcTy), + APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) + .getSExtValue())); Value clamped = clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter); - return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped); + return arith::FPToSIOp::create(rewriter, loc, dstTy, clamped); } // Due to earlier check we know exponant range is big enough to represent // int min. We can therefore rely on int max + 1 being representable as // well because it's just int min with a positive sign. So clamp the min // value and compare against that to select the max int value if needed. - auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>( - loc, rewriter.getFloatAttr( - getElementTypeOrSelf(srcTy), - static_cast<double>( - APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) - .getSExtValue()) + - 1.0f)); - - auto intMax = rewriter.create<arith::ConstantOp>( - loc, rewriter.getIntegerAttr( - getElementTypeOrSelf(dstTy), - APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()))); + auto intMaxPlusOneFP = arith::ConstantOp::create( + rewriter, loc, + rewriter.getFloatAttr( + getElementTypeOrSelf(srcTy), + static_cast<double>( + APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) + .getSExtValue()) + + 1.0f)); + + auto intMax = arith::ConstantOp::create( + rewriter, loc, + rewriter.getIntegerAttr( + getElementTypeOrSelf(dstTy), + APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()))); auto minClampedFP = - rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP); + arith::MaximumFOp::create(rewriter, loc, rounded, intMinFP); auto minClamped = - rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP); - auto overflow = rewriter.create<arith::CmpFOp>( - loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP); - return rewriter.create<arith::SelectOp>(loc, overflow, intMax, - minClamped); + arith::FPToSIOp::create(rewriter, loc, dstTy, minClampedFP); + auto overflow = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP); + return arith::SelectOp::create(rewriter, loc, overflow, intMax, + minClamped); } // Casting to boolean, integers need to only be checked as not-equal to // zero. if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) { - Value zero = rewriter.create<arith::ConstantIntOp>( - loc, 0, srcTy.getIntOrFloatBitWidth()); - return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, - args.front(), zero); + Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, + srcTy.getIntOrFloatBitWidth()); + return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ne, + args.front(), zero); } if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend) - return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args, - ArrayRef<NamedAttribute>()); + return arith::ExtSIOp::create(rewriter, loc, resultTypes, args, + ArrayRef<NamedAttribute>()); if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) { - return rewriter.create<arith::TruncIOp>(loc, dstTy, args[0]); + return arith::TruncIOp::create(rewriter, loc, dstTy, args[0]); } } @@ -710,14 +717,14 @@ static Value createIndex(PatternRewriter &rewriter, Location loc, auto [it, inserted] = indexPool.try_emplace(index); if (inserted) it->second = - rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(index)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(index)); return it->second; } static Value getTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index) { auto indexValue = createIndex(rewriter, loc, indexPool, index); - return rewriter.create<tensor::DimOp>(loc, tensor, indexValue).getResult(); + return tensor::DimOp::create(rewriter, loc, tensor, indexValue).getResult(); } static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc, @@ -783,7 +790,7 @@ computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) { auto nextSize = getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim); - targetSize = rewriter.create<arith::MaxUIOp>(loc, targetSize, nextSize); + targetSize = arith::MaxUIOp::create(rewriter, loc, targetSize, nextSize); } return {targetSize, nullptr}; } @@ -838,8 +845,8 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc, // Check if broadcast is necessary auto one = createIndex(rewriter, loc, indexPool, 1); auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim); - auto broadcastNecessary = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::eq, runtimeSize, one); + auto broadcastNecessary = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::eq, runtimeSize, one); // Emit 'then' region of 'scf.if' auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) { @@ -855,19 +862,18 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc, operand, index); outputTensorShape.push_back(size); } - Value outputTensor = opBuilder.create<tensor::EmptyOp>( - loc, outputTensorShape, rankedTensorType.getElementType()); + Value outputTensor = tensor::EmptyOp::create( + opBuilder, loc, outputTensorShape, rankedTensorType.getElementType()); // Emit 'linalg.generic' op auto resultTensor = - opBuilder - .create<linalg::GenericOp>( - loc, outputTensor.getType(), operand, outputTensor, affineMaps, - getNParallelLoopsAttrs(rank), - [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) { - // Emit 'linalg.yield' op - opBuilder.create<linalg::YieldOp>(loc, blockArgs.front()); - }) + linalg::GenericOp::create( + opBuilder, loc, outputTensor.getType(), operand, outputTensor, + affineMaps, getNParallelLoopsAttrs(rank), + [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) { + // Emit 'linalg.yield' op + linalg::YieldOp::create(opBuilder, loc, blockArgs.front()); + }) .getResult(0); // Cast to original operand type if necessary @@ -875,17 +881,17 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc, loc, operand.getType(), resultTensor); // Emit 'scf.yield' op - opBuilder.create<scf::YieldOp>(loc, castResultTensor); + scf::YieldOp::create(opBuilder, loc, castResultTensor); }; // Emit 'else' region of 'scf.if' auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) { - opBuilder.create<scf::YieldOp>(loc, operand); + scf::YieldOp::create(opBuilder, loc, operand); }; // Emit 'scf.if' op - auto ifOp = rewriter.create<scf::IfOp>(loc, broadcastNecessary, - emitThenRegion, emitElseRegion); + auto ifOp = scf::IfOp::create(rewriter, loc, broadcastNecessary, + emitThenRegion, emitElseRegion); return ifOp.getResult(0); } @@ -930,8 +936,8 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc, if (!resultType) { return rewriter.notifyMatchFailure(operation, "failed to convert type"); } - Value outputTensor = rewriter.create<tensor::EmptyOp>( - loc, targetShape, resultType.getElementType()); + Value outputTensor = tensor::EmptyOp::create(rewriter, loc, targetShape, + resultType.getElementType()); // Create affine maps. Input affine maps broadcast static dimensions of size // 1. The output affine map is an identity map. @@ -957,8 +963,8 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc, // Emit 'linalg.generic' op bool encounteredError = false; - auto linalgOp = rewriter.create<linalg::GenericOp>( - loc, outputTensor.getType(), operands, outputTensor, affineMaps, + auto linalgOp = linalg::GenericOp::create( + rewriter, loc, outputTensor.getType(), operands, outputTensor, affineMaps, getNParallelLoopsAttrs(rank), [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) { Value opResult = createLinalgBodyCalculationForElementwiseOp( @@ -968,7 +974,7 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc, encounteredError = true; return; } - opBuilder.create<linalg::YieldOp>(loc, opResult); + linalg::YieldOp::create(opBuilder, loc, opResult); }); if (encounteredError) return rewriter.notifyMatchFailure( @@ -1078,42 +1084,42 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op, PatternRewriter &rewriter) { Location loc = op->getLoc(); if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) { - return rewriter.create<arith::AddFOp>(loc, args); + return arith::AddFOp::create(rewriter, loc, args); } if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) { - return rewriter.create<arith::AddIOp>(loc, args); + return arith::AddIOp::create(rewriter, loc, args); } if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy)) { - return rewriter.create<arith::MulFOp>(loc, args); + return arith::MulFOp::create(rewriter, loc, args); } if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy)) { - return rewriter.create<arith::MulIOp>(loc, args); + return arith::MulIOp::create(rewriter, loc, args); } if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) { - return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]); + return arith::MinimumFOp::create(rewriter, loc, args[0], args[1]); } if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) { - return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]); + return arith::MinSIOp::create(rewriter, loc, args[0], args[1]); } if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) { - return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]); + return arith::MaximumFOp::create(rewriter, loc, args[0], args[1]); } if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) { - return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]); + return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]); } if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1)) - return rewriter.create<arith::AndIOp>(loc, args); + return arith::AndIOp::create(rewriter, loc, args); if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1)) - return rewriter.create<arith::OrIOp>(loc, args); + return arith::OrIOp::create(rewriter, loc, args); return {}; } @@ -1139,7 +1145,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, if (axis != i) { reduceShape.push_back(inputTy.getDimSize(i)); if (inputTy.isDynamicDim(i)) - dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i)); + dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i)); } } @@ -1147,22 +1153,20 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, inputs.push_back(input); // First fill the output buffer with the init value. - auto emptyTensor = - rewriter - .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(), - dynDims) - .getResult(); + auto emptyTensor = tensor::EmptyOp::create(rewriter, loc, reduceShape, + resultTy.getElementType(), dynDims) + .getResult(); auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter); if (!fillValueAttr) return rewriter.notifyMatchFailure( op, "No initial value found for reduction operation"); - auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr); - auto filledTensor = rewriter - .create<linalg::FillOp>(loc, ValueRange{fillValue}, - ValueRange{emptyTensor}) - .result(); + auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr); + auto filledTensor = + linalg::FillOp::create(rewriter, loc, ValueRange{fillValue}, + ValueRange{emptyTensor}) + .result(); outputs.push_back(filledTensor); bool isNanIgnoreMode = false; @@ -1176,16 +1180,14 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, // Additionally we have to keep track of whether we've seen any non-NaN // values and then do a final select based on this predicate. auto trueAttr = rewriter.getBoolAttr(true); - auto trueValue = rewriter.create<arith::ConstantOp>(loc, trueAttr); + auto trueValue = arith::ConstantOp::create(rewriter, loc, trueAttr); auto emptyBoolTensor = - rewriter - .create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(), - dynDims) + tensor::EmptyOp::create(rewriter, loc, reduceShape, + trueValue.getType(), dynDims) .getResult(); auto allResultsNaNTensor = - rewriter - .create<linalg::FillOp>(loc, ValueRange{trueValue}, - ValueRange{emptyBoolTensor}) + linalg::FillOp::create(rewriter, loc, ValueRange{trueValue}, + ValueRange{emptyBoolTensor}) .result(); // Note that because the linalg::ReduceOp has two variadic arguments // (inputs and outputs) and it has the SameVariadicOperandSize trait we @@ -1202,8 +1204,8 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, } bool didEncounterError = false; - linalg::LinalgOp linalgOp = rewriter.create<linalg::ReduceOp>( - loc, inputs, outputs, axis, + linalg::LinalgOp linalgOp = linalg::ReduceOp::create( + rewriter, loc, inputs, outputs, axis, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { std::array<Value, 2> binaryArgs{ blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]}; @@ -1219,21 +1221,22 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, auto oldAllResultsNanFlagValue = blockArgs[3]; // Unordered comparison of NaN against itself will always return true. - Value isNaN = nestedBuilder.create<arith::CmpFOp>( - op->getLoc(), arith::CmpFPredicate::UNO, inputValue, inputValue); + Value isNaN = arith::CmpFOp::create(nestedBuilder, op->getLoc(), + arith::CmpFPredicate::UNO, + inputValue, inputValue); // If we've encountered a NaN, take the non-NaN value. - auto selectOp = nestedBuilder.create<arith::SelectOp>( - op->getLoc(), isNaN, initialValue, result); + auto selectOp = arith::SelectOp::create(nestedBuilder, op->getLoc(), + isNaN, initialValue, result); // Update the flag which keeps track of whether we have seen a non-NaN // value. - auto newAllResultsNanFlagValue = nestedBuilder.create<arith::AndIOp>( - op->getLoc(), oldAllResultsNanFlagValue, isNaN); + auto newAllResultsNanFlagValue = arith::AndIOp::create( + nestedBuilder, op->getLoc(), oldAllResultsNanFlagValue, isNaN); resultsToYield.push_back(selectOp); resultsToYield.push_back(newAllResultsNanFlagValue); } else { resultsToYield.push_back(result); } - nestedBuilder.create<linalg::YieldOp>(loc, resultsToYield); + linalg::YieldOp::create(nestedBuilder, loc, resultsToYield); }); if (!didEncounterError) @@ -1250,24 +1253,21 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, auto nanValueAttr = rewriter.getFloatAttr( elementTy, APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false)); - auto nanValue = rewriter.create<arith::ConstantOp>(loc, nanValueAttr); + auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr); auto emptyNanTensor = - rewriter - .create<tensor::EmptyOp>(loc, reduceShape, - resultTy.getElementType(), dynDims) + tensor::EmptyOp::create(rewriter, loc, reduceShape, + resultTy.getElementType(), dynDims) .getResult(); auto nanFilledTensor = - rewriter - .create<linalg::FillOp>(loc, ValueRange{nanValue}, - ValueRange{emptyNanTensor}) + linalg::FillOp::create(rewriter, loc, ValueRange{nanValue}, + ValueRange{emptyNanTensor}) .result(); // Create an empty tensor, non need to fill this since it will be // overwritten by the select. auto finalEmptyTensor = - rewriter - .create<tensor::EmptyOp>(loc, reduceShape, - resultTy.getElementType(), dynDims) + tensor::EmptyOp::create(rewriter, loc, reduceShape, + resultTy.getElementType(), dynDims) .getResult(); // Do a selection between the tensors akin to: @@ -1278,7 +1278,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, ins.push_back(linalgOp->getResult(0)); outs.push_back(finalEmptyTensor); auto linalgSelect = - rewriter.create<linalg::SelectOp>(op->getLoc(), ins, outs); + linalg::SelectOp::create(rewriter, op->getLoc(), ins, outs); linalgOp = linalgSelect; } @@ -1350,7 +1350,7 @@ public: SmallVector<Value> dynDims; for (int i = 0; i < outputTy.getRank(); i++) { if (outputTy.isDynamicDim(i)) { - dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i)); + dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i)); } } @@ -1401,16 +1401,17 @@ public: Value multiplierConstant; int64_t multiplierArg = 0; if (multiplierValues.size() == 1) { - multiplierConstant = rewriter.create<arith::ConstantOp>( - loc, rewriter.getI32IntegerAttr(multiplierValues.front())); + multiplierConstant = arith::ConstantOp::create( + rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front())); } else { SmallVector<AffineExpr, 2> multiplierExprs{ rewriter.getAffineDimExpr(rank - 1)}; auto multiplierType = RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())}, rewriter.getI32Type()); - genericInputs.push_back(rewriter.create<arith::ConstantOp>( - loc, DenseIntElementsAttr::get(multiplierType, multiplierValues))); + genericInputs.push_back(arith::ConstantOp::create( + rewriter, loc, + DenseIntElementsAttr::get(multiplierType, multiplierValues))); indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, multiplierExprs, @@ -1424,16 +1425,16 @@ public: Value shiftConstant; int64_t shiftArg = 0; if (shiftValues.size() == 1) { - shiftConstant = rewriter.create<arith::ConstantOp>( - loc, rewriter.getI8IntegerAttr(shiftValues.front())); + shiftConstant = arith::ConstantOp::create( + rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front())); } else { SmallVector<AffineExpr, 2> shiftExprs = { rewriter.getAffineDimExpr(rank - 1)}; auto shiftType = RankedTensorType::get({static_cast<int64_t>(shiftValues.size())}, rewriter.getIntegerType(8)); - genericInputs.push_back(rewriter.create<arith::ConstantOp>( - loc, DenseIntElementsAttr::get(shiftType, shiftValues))); + genericInputs.push_back(arith::ConstantOp::create( + rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues))); indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, shiftExprs, rewriter.getContext())); @@ -1444,13 +1445,13 @@ public: indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank)); // Construct the indexing maps needed for linalg.generic ops. - Value emptyTensor = rewriter.create<tensor::EmptyOp>( - loc, outputTy.getShape(), outputTy.getElementType(), + Value emptyTensor = tensor::EmptyOp::create( + rewriter, loc, outputTy.getShape(), outputTy.getElementType(), ArrayRef<Value>({dynDims})); - auto linalgOp = rewriter.create<linalg::GenericOp>( - loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps, - getNParallelLoopsAttrs(rank), + auto linalgOp = linalg::GenericOp::create( + rewriter, loc, outputTy, genericInputs, ValueRange{emptyTensor}, + indexingMaps, getNParallelLoopsAttrs(rank), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { Value value = blockArgs[0]; @@ -1466,9 +1467,10 @@ public: const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth(); // Extend zeropoint for sub-32bits widths. const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32; - auto inputZp = nestedBuilder.create<arith::ConstantOp>( - loc, IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth), - *maybeIZp)); + auto inputZp = arith::ConstantOp::create( + nestedBuilder, loc, + IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth), + *maybeIZp)); FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint(); if (failed(maybeOZp)) { @@ -1482,43 +1484,43 @@ public: unsigned outBitWidth = outIntType.getWidth(); const int32_t outAttrBitwidth = 32; assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth"); - auto outputZp = nestedBuilder.create<arith::ConstantOp>( - loc, IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth), - *maybeOZp)); + auto outputZp = arith::ConstantOp::create( + nestedBuilder, loc, + IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth), + *maybeOZp)); Value multiplier = multiplierConstant ? multiplierConstant : blockArgs[multiplierArg]; Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg]; if (valueTy.isUnsignedInteger()) { - value = nestedBuilder - .create<UnrealizedConversionCastOp>( - nestedLoc, - nestedBuilder.getIntegerType( - valueTy.getIntOrFloatBitWidth()), - value) + value = UnrealizedConversionCastOp::create( + nestedBuilder, nestedLoc, + nestedBuilder.getIntegerType( + valueTy.getIntOrFloatBitWidth()), + value) .getResult(0); } if (valueTy.getIntOrFloatBitWidth() < 32) { if (op.getInputUnsigned()) { - value = nestedBuilder.create<arith::ExtUIOp>( - nestedLoc, nestedBuilder.getI32Type(), value); + value = arith::ExtUIOp::create(nestedBuilder, nestedLoc, + nestedBuilder.getI32Type(), value); } else { - value = nestedBuilder.create<arith::ExtSIOp>( - nestedLoc, nestedBuilder.getI32Type(), value); + value = arith::ExtSIOp::create(nestedBuilder, nestedLoc, + nestedBuilder.getI32Type(), value); } } value = - nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp); + arith::SubIOp::create(nestedBuilder, nestedLoc, value, inputZp); - value = nestedBuilder.create<tosa::ApplyScaleOp>( - loc, nestedBuilder.getI32Type(), value, multiplier, shift, - roundingMode); + value = tosa::ApplyScaleOp::create(nestedBuilder, loc, + nestedBuilder.getI32Type(), value, + multiplier, shift, roundingMode); // Move to the new zero-point. value = - nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp); + arith::AddIOp::create(nestedBuilder, nestedLoc, value, outputZp); // Saturate to the output size. int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue(); @@ -1530,27 +1532,26 @@ public: intMax = APInt::getMaxValue(outBitWidth).getZExtValue(); } - auto intMinVal = nestedBuilder.create<arith::ConstantOp>( - loc, nestedBuilder.getI32IntegerAttr(intMin)); - auto intMaxVal = nestedBuilder.create<arith::ConstantOp>( - loc, nestedBuilder.getI32IntegerAttr(intMax)); + auto intMinVal = arith::ConstantOp::create( + nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMin)); + auto intMaxVal = arith::ConstantOp::create( + nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMax)); value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal, nestedBuilder, /*isUnsigned=*/false); if (outIntType.getWidth() < 32) { - value = nestedBuilder.create<arith::TruncIOp>( - nestedLoc, rewriter.getIntegerType(outIntType.getWidth()), - value); + value = arith::TruncIOp::create( + nestedBuilder, nestedLoc, + rewriter.getIntegerType(outIntType.getWidth()), value); } if (outIntType.isUnsignedInteger()) { - value = nestedBuilder - .create<UnrealizedConversionCastOp>(nestedLoc, - outIntType, value) + value = UnrealizedConversionCastOp::create(nestedBuilder, nestedLoc, + outIntType, value) .getResult(0); } - nestedBuilder.create<linalg::YieldOp>(loc, value); + linalg::YieldOp::create(nestedBuilder, loc, value); }); rewriter.replaceOp(op, linalgOp->getResults()); @@ -1608,48 +1609,49 @@ public: auto collapseTy = RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)}, inputTy.getElementType()); - Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, input, - reassociationMap); + Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy, input, + reassociationMap); // Get any dynamic shapes that appear in the input format. llvm::SmallVector<Value> outputDynSize; if (inputTy.isDynamicDim(0)) - outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0)); + outputDynSize.push_back(tensor::DimOp::create(builder, input, 0)); if (inputTy.isDynamicDim(3)) - outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3)); + outputDynSize.push_back(tensor::DimOp::create(builder, input, 3)); // Generate the elementwise operation for casting scaling the input value. auto genericTy = collapseTy.clone(resultTy.getElementType()); - Value empty = builder.create<tensor::EmptyOp>( - genericTy.getShape(), resultTy.getElementType(), outputDynSize); + Value empty = + tensor::EmptyOp::create(builder, genericTy.getShape(), + resultTy.getElementType(), outputDynSize); auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank()); SmallVector<utils::IteratorType> iterators(genericTy.getRank(), utils::IteratorType::parallel); - auto generic = builder.create<linalg::GenericOp>( - genericTy, ValueRange{collapse}, ValueRange{empty}, + auto generic = linalg::GenericOp::create( + builder, genericTy, ValueRange{collapse}, ValueRange{empty}, ArrayRef<AffineMap>{genericMap, genericMap}, iterators, [=](OpBuilder &b, Location loc, ValueRange args) { Value value = args[0]; // This is the quantized case. if (inputTy.getElementType() != resultTy.getElementType()) { - value = - b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value); + value = arith::ExtSIOp::create(b, loc, resultTy.getElementType(), + value); if (isBilinear && scale[0] != 0) { - Value scaleY = b.create<arith::ConstantOp>( - loc, b.getI32IntegerAttr(scale[0])); - value = b.create<arith::MulIOp>(loc, value, scaleY); + Value scaleY = arith::ConstantOp::create( + b, loc, b.getI32IntegerAttr(scale[0])); + value = arith::MulIOp::create(b, loc, value, scaleY); } if (isBilinear && scale[2] != 0) { - Value scaleX = b.create<arith::ConstantOp>( - loc, b.getI32IntegerAttr(scale[2])); - value = b.create<arith::MulIOp>(loc, value, scaleX); + Value scaleX = arith::ConstantOp::create( + b, loc, b.getI32IntegerAttr(scale[2])); + value = arith::MulIOp::create(b, loc, value, scaleX); } } - b.create<linalg::YieldOp>(loc, value); + linalg::YieldOp::create(b, loc, value); }); rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>( @@ -1697,9 +1699,9 @@ public: resizeShape.push_back(channels); auto resizeTy = resultTy.clone(resizeShape); - auto resize = builder.create<tosa::ResizeOp>(resizeTy, input, op.getScale(), - op.getOffset(), op.getBorder(), - op.getMode()); + auto resize = + tosa::ResizeOp::create(builder, resizeTy, input, op.getScale(), + op.getOffset(), op.getBorder(), op.getMode()); // Collapse an unit result dims. SmallVector<ReassociationExprs, 4> reassociationMap(2); @@ -1720,20 +1722,20 @@ public: collapseShape.push_back(channels); auto collapseTy = resultTy.clone(collapseShape); - Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize, - reassociationMap); + Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy, + resize, reassociationMap); // Broadcast the collapsed shape to the output result. llvm::SmallVector<Value> outputDynSize; if (inputTy.isDynamicDim(0)) - outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0)); + outputDynSize.push_back(tensor::DimOp::create(builder, input, 0)); if (inputTy.isDynamicDim(3)) - outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3)); + outputDynSize.push_back(tensor::DimOp::create(builder, input, 3)); SmallVector<utils::IteratorType> iterators(resultTy.getRank(), utils::IteratorType::parallel); - Value empty = builder.create<tensor::EmptyOp>( - resultTy.getShape(), resultTy.getElementType(), outputDynSize); + Value empty = tensor::EmptyOp::create( + builder, resultTy.getShape(), resultTy.getElementType(), outputDynSize); SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)}; if (inputH != 1) @@ -1751,7 +1753,7 @@ public: ArrayRef<AffineMap>{inputMap, outputMap}, iterators, [=](OpBuilder &b, Location loc, ValueRange args) { Value value = args[0]; - b.create<linalg::YieldOp>(loc, value); + linalg::YieldOp::create(b, loc, value); }); return success(); @@ -1789,10 +1791,10 @@ public: SmallVector<AffineMap, 2> affineMaps = { rewriter.getMultiDimIdentityMap(resultTy.getRank())}; - auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy, - *dynamicDimsOr); - auto genericOp = b.create<linalg::GenericOp>( - resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps, + auto emptyTensor = tensor::EmptyOp::create(b, resultTy.getShape(), + resultETy, *dynamicDimsOr); + auto genericOp = linalg::GenericOp::create( + b, resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps, getNParallelLoopsAttrs(resultTy.getRank())); Value resize = genericOp.getResult(0); @@ -1800,19 +1802,21 @@ public: OpBuilder::InsertionGuard regionGuard(b); b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(), TypeRange({resultETy}), loc); - Value batch = b.create<linalg::IndexOp>(0); - Value y = b.create<linalg::IndexOp>(1); - Value x = b.create<linalg::IndexOp>(2); - Value channel = b.create<linalg::IndexOp>(3); + Value batch = linalg::IndexOp::create(b, 0); + Value y = linalg::IndexOp::create(b, 1); + Value x = linalg::IndexOp::create(b, 2); + Value channel = linalg::IndexOp::create(b, 3); Value zeroI32 = - b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type())); - Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy)); - Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1)); - Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1)); + arith::ConstantOp::create(b, b.getZeroAttr(b.getI32Type())); + Value zeroFp = arith::ConstantOp::create(b, b.getZeroAttr(floatTy)); + Value hMax = + arith::ConstantOp::create(b, b.getI32IntegerAttr(imageH - 1)); + Value wMax = + arith::ConstantOp::create(b, b.getI32IntegerAttr(imageW - 1)); - Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y); - Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x); + Value inY = arith::IndexCastOp::create(b, b.getI32Type(), y); + Value inX = arith::IndexCastOp::create(b, b.getI32Type(), x); SmallVector<int64_t> scale, offset, border; if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) || @@ -1824,16 +1828,16 @@ public: } Value yScaleN, yScaleD, xScaleN, xScaleD; - yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0])); - yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1])); - xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2])); - xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3])); + yScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[0])); + yScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[1])); + xScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[2])); + xScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[3])); Value yOffset, xOffset, yBorder, xBorder; - yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0])); - xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1])); - yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0])); - xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1])); + yOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[0])); + xOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[1])); + yBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[0])); + xBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[1])); // Compute the ix and dx values for both the X and Y dimensions. auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in, @@ -1846,16 +1850,16 @@ public: } // x = x * scale_d + offset; // ix = floor(x / scale_n) - Value val = b.create<arith::MulIOp>(in, scaleD); - val = b.create<arith::AddIOp>(val, offset); - index = b.create<arith::FloorDivSIOp>(val, scaleN); + Value val = arith::MulIOp::create(b, in, scaleD); + val = arith::AddIOp::create(b, val, offset); + index = arith::FloorDivSIOp::create(b, val, scaleN); // rx = x % scale_n // dx = rx / scale_n - Value r = b.create<arith::RemSIOp>(val, scaleN); - Value rFp = b.create<arith::SIToFPOp>(floatTy, r); - Value scaleNfp = b.create<arith::UIToFPOp>(floatTy, scaleN); - delta = b.create<arith::DivFOp>(rFp, scaleNfp); + Value r = arith::RemSIOp::create(b, val, scaleN); + Value rFp = arith::SIToFPOp::create(b, floatTy, r); + Value scaleNfp = arith::UIToFPOp::create(b, floatTy, scaleN); + delta = arith::DivFOp::create(b, rFp, scaleNfp); }; // Compute the ix and dx values for the X and Y dimensions - int case. @@ -1870,11 +1874,11 @@ public: // x = x * scale_d + offset; // ix = floor(x / scale_n) // dx = x - ix * scale_n; - Value val = b.create<arith::MulIOp>(in, scaleD); - val = b.create<arith::AddIOp>(val, offset); - index = b.create<arith::DivSIOp>(val, scaleN); - delta = b.create<arith::MulIOp>(index, scaleN); - delta = b.create<arith::SubIOp>(val, delta); + Value val = arith::MulIOp::create(b, in, scaleD); + val = arith::AddIOp::create(b, val, offset); + index = arith::DivSIOp::create(b, val, scaleN); + delta = arith::MulIOp::create(b, index, scaleN); + delta = arith::SubIOp::create(b, val, delta); }; Value ix, iy, dx, dy; @@ -1887,54 +1891,55 @@ public: } if (op.getMode() == "NEAREST_NEIGHBOR") { - auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1)); + auto one = arith::ConstantOp::create(b, b.getI32IntegerAttr(1)); auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale, Value max, int size, ImplicitLocOpBuilder &b) -> Value { if (size == 1) { - return b.create<arith::ConstantIndexOp>(0); + return arith::ConstantIndexOp::create(b, 0); } Value pred; if (floatingPointMode) { - auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f)); - pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h); + auto h = + arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 0.5f)); + pred = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, dval, h); } else { - Value dvalDouble = b.create<arith::ShLIOp>(dval, one); - pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge, - dvalDouble, scale); + Value dvalDouble = arith::ShLIOp::create(b, dval, one); + pred = arith::CmpIOp::create(b, arith::CmpIPredicate::sge, + dvalDouble, scale); } - auto offset = b.create<arith::SelectOp>(pred, one, zeroI32); - val = b.create<arith::AddIOp>(val, offset); + auto offset = arith::SelectOp::create(b, pred, one, zeroI32); + val = arith::AddIOp::create(b, val, offset); val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false); - return b.create<arith::IndexCastOp>(b.getIndexType(), val); + return arith::IndexCastOp::create(b, b.getIndexType(), val); }; iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b); ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b); - Value result = b.create<tensor::ExtractOp>( - input, ValueRange{batch, iy, ix, channel}); + Value result = tensor::ExtractOp::create( + b, input, ValueRange{batch, iy, ix, channel}); - b.create<linalg::YieldOp>(result); + linalg::YieldOp::create(b, result); } else { // The mode here must be BILINEAR. assert(op.getMode() == "BILINEAR"); - auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1)); + auto oneVal = arith::ConstantOp::create(b, b.getI32IntegerAttr(1)); auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in, Value max, ImplicitLocOpBuilder &b) { val0 = in; - val1 = b.create<arith::AddIOp>(val0, oneVal); + val1 = arith::AddIOp::create(b, val0, oneVal); val0 = clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false); val1 = clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false); - val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0); - val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1); + val0 = arith::IndexCastOp::create(b, b.getIndexType(), val0); + val1 = arith::IndexCastOp::create(b, b.getIndexType(), val1); }; // Linalg equivalent to the section below: @@ -1946,27 +1951,27 @@ public: getClampedIdxs(y0, y1, imageH, iy, hMax, b); getClampedIdxs(x0, x1, imageW, ix, wMax, b); - Value y0x0 = b.create<tensor::ExtractOp>( - input, ValueRange{batch, y0, x0, channel}); - Value y0x1 = b.create<tensor::ExtractOp>( - input, ValueRange{batch, y0, x1, channel}); - Value y1x0 = b.create<tensor::ExtractOp>( - input, ValueRange{batch, y1, x0, channel}); - Value y1x1 = b.create<tensor::ExtractOp>( - input, ValueRange{batch, y1, x1, channel}); + Value y0x0 = tensor::ExtractOp::create( + b, input, ValueRange{batch, y0, x0, channel}); + Value y0x1 = tensor::ExtractOp::create( + b, input, ValueRange{batch, y0, x1, channel}); + Value y1x0 = tensor::ExtractOp::create( + b, input, ValueRange{batch, y1, x0, channel}); + Value y1x1 = tensor::ExtractOp::create( + b, input, ValueRange{batch, y1, x1, channel}); if (floatingPointMode) { auto oneVal = - b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f)); + arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 1.0f)); auto interpolate = [&](Value val0, Value val1, Value delta, int inputSize, ImplicitLocOpBuilder &b) -> Value { if (inputSize == 1) return val0; - Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta); - Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta); - Value mul1 = b.create<arith::MulFOp>(val1, delta); - return b.create<arith::AddFOp>(mul0, mul1); + Value oneMinusDelta = arith::SubFOp::create(b, oneVal, delta); + Value mul0 = arith::MulFOp::create(b, val0, oneMinusDelta); + Value mul1 = arith::MulFOp::create(b, val1, delta); + return arith::AddFOp::create(b, mul0, mul1); }; // Linalg equivalent to the section below: @@ -1982,18 +1987,18 @@ public: // Linalg equivalent to the section below: // result = topAcc * (unit_y - dy) + bottomAcc * dy Value result = interpolate(topAcc, bottomAcc, dy, imageH, b); - b.create<linalg::YieldOp>(result); + linalg::YieldOp::create(b, result); } else { // Perform in quantized space. - y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0); - y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1); - y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0); - y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1); + y0x0 = arith::ExtSIOp::create(b, resultETy, y0x0); + y0x1 = arith::ExtSIOp::create(b, resultETy, y0x1); + y1x0 = arith::ExtSIOp::create(b, resultETy, y1x0); + y1x1 = arith::ExtSIOp::create(b, resultETy, y1x1); const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth(); if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) { - dx = b.create<arith::ExtSIOp>(resultETy, dx); - dy = b.create<arith::ExtSIOp>(resultETy, dy); + dx = arith::ExtSIOp::create(b, resultETy, dx); + dy = arith::ExtSIOp::create(b, resultETy, dy); } Value yScaleNExt = yScaleN; @@ -2002,26 +2007,26 @@ public: const int64_t scaleBitwidth = xScaleN.getType().getIntOrFloatBitWidth(); if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) { - yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN); - xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN); + yScaleNExt = arith::ExtSIOp::create(b, resultETy, yScaleN); + xScaleNExt = arith::ExtSIOp::create(b, resultETy, xScaleN); } auto interpolate = [](Value val0, Value val1, Value weight1, Value scale, int inputSize, ImplicitLocOpBuilder &b) -> Value { if (inputSize == 1) - return b.create<arith::MulIOp>(val0, scale); - Value weight0 = b.create<arith::SubIOp>(scale, weight1); - Value mul0 = b.create<arith::MulIOp>(val0, weight0); - Value mul1 = b.create<arith::MulIOp>(val1, weight1); - return b.create<arith::AddIOp>(mul0, mul1); + return arith::MulIOp::create(b, val0, scale); + Value weight0 = arith::SubIOp::create(b, scale, weight1); + Value mul0 = arith::MulIOp::create(b, val0, weight0); + Value mul1 = arith::MulIOp::create(b, val1, weight1); + return arith::AddIOp::create(b, mul0, mul1); }; Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b); Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b); Value result = interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b); - b.create<linalg::YieldOp>(result); + linalg::YieldOp::create(b, result); } } } @@ -2072,17 +2077,16 @@ public: SmallVector<Value> dynDims; for (int i = 0; i < inputTy.getRank(); i++) { if (inputTy.isDynamicDim(i)) { - dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i)); + dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i)); } } - Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis); + Value axisDimSize = tensor::DimOp::create(rewriter, loc, input, axis); // First fill the output buffer with the init value. - auto emptyTensor = rewriter - .create<tensor::EmptyOp>(loc, inputTy.getShape(), - inputTy.getElementType(), - ArrayRef<Value>({dynDims})) + auto emptyTensor = tensor::EmptyOp::create( + rewriter, loc, inputTy.getShape(), + inputTy.getElementType(), ArrayRef<Value>({dynDims})) .getResult(); SmallVector<AffineMap, 2> affineMaps = { rewriter.getMultiDimIdentityMap(resultTy.getRank())}; @@ -2094,22 +2098,22 @@ public: llvm::SmallVector<Value> indices; for (unsigned int i = 0; i < inputTy.getRank(); i++) { Value index = - rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult(); + linalg::IndexOp::create(rewriter, nestedLoc, i).getResult(); if (i == axis) { - auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1); + auto one = arith::ConstantIndexOp::create(rewriter, nestedLoc, 1); auto sizeMinusOne = - rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one); - index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne, - index); + arith::SubIOp::create(rewriter, nestedLoc, axisDimSize, one); + index = arith::SubIOp::create(rewriter, nestedLoc, sizeMinusOne, + index); } indices.push_back(index); } - auto extract = nestedBuilder.create<tensor::ExtractOp>( - nestedLoc, input, indices); - nestedBuilder.create<linalg::YieldOp>(op.getLoc(), - extract.getResult()); + auto extract = tensor::ExtractOp::create(nestedBuilder, nestedLoc, + input, indices); + linalg::YieldOp::create(nestedBuilder, op.getLoc(), + extract.getResult()); }); return success(); } @@ -2148,12 +2152,12 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> { SmallVector<Value> dynDims; for (int i = 0; i < inputTy.getRank(); i++) { if (inputTy.isDynamicDim(i) || multiples[i] == -1) { - dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i)); + dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i)); } } - auto emptyTensor = rewriter.create<tensor::EmptyOp>( - op.getLoc(), genericShape, elementTy, dynDims); + auto emptyTensor = tensor::EmptyOp::create( + rewriter, op.getLoc(), genericShape, elementTy, dynDims); // We needs to map the input shape to the non-broadcasted dimensions. SmallVector<AffineExpr, 4> dimExprs; @@ -2168,12 +2172,12 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> { SmallVector<AffineMap, 2> affineMaps = { readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())}; - auto genericOp = rewriter.create<linalg::GenericOp>( - loc, RankedTensorType::get(genericShape, elementTy), input, + auto genericOp = linalg::GenericOp::create( + rewriter, loc, RankedTensorType::get(genericShape, elementTy), input, ValueRange{emptyTensor}, affineMaps, getNParallelLoopsAttrs(genericShape.size()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin()); + linalg::YieldOp::create(nestedBuilder, op.getLoc(), *args.begin()); }); auto shapeValue = getTosaConstShape( @@ -2220,28 +2224,27 @@ public: SmallVector<Value> dynDims; for (int i = 0; i < inputTy.getRank(); i++) { if (inputTy.isDynamicDim(i) && i != axis) { - dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i)); + dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i)); } } // First fill the output buffer for the index. - auto emptyTensorIdx = rewriter - .create<tensor::EmptyOp>(loc, resultTy.getShape(), - outElementTy, dynDims) - .getResult(); - auto fillValueIdx = rewriter.create<arith::ConstantOp>( - loc, rewriter.getIntegerAttr(outElementTy, 0)); + auto emptyTensorIdx = + tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), + outElementTy, dynDims) + .getResult(); + auto fillValueIdx = arith::ConstantOp::create( + rewriter, loc, rewriter.getIntegerAttr(outElementTy, 0)); auto filledTensorIdx = - rewriter - .create<linalg::FillOp>(loc, ValueRange{fillValueIdx}, - ValueRange{emptyTensorIdx}) + linalg::FillOp::create(rewriter, loc, ValueRange{fillValueIdx}, + ValueRange{emptyTensorIdx}) .result(); // Second fill the output buffer for the running max. - auto emptyTensorMax = rewriter - .create<tensor::EmptyOp>(loc, resultTy.getShape(), - inElementTy, dynDims) - .getResult(); + auto emptyTensorMax = + tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), inElementTy, + dynDims) + .getResult(); auto fillValueMaxAttr = createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter); @@ -2250,11 +2253,10 @@ public: argmaxOp, "unsupported tosa.argmax element type"); auto fillValueMax = - rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr); + arith::ConstantOp::create(rewriter, loc, fillValueMaxAttr); auto filledTensorMax = - rewriter - .create<linalg::FillOp>(loc, ValueRange{fillValueMax}, - ValueRange{emptyTensorMax}) + linalg::FillOp::create(rewriter, loc, ValueRange{fillValueMax}, + ValueRange{emptyTensorMax}) .result(); // We need to reduce along the arg-max axis, with parallel operations along @@ -2274,8 +2276,8 @@ public: bool didEncounterError = false; auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs}, rewriter.getContext()); - auto linalgOp = rewriter.create<linalg::GenericOp>( - loc, ArrayRef<Type>({resultTy, resultMaxTy}), input, + auto linalgOp = linalg::GenericOp::create( + rewriter, loc, ArrayRef<Type>({resultTy, resultMaxTy}), input, ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { @@ -2283,42 +2285,46 @@ public: auto oldIndex = blockArgs[1]; auto oldValue = blockArgs[2]; - Value newIndex = rewriter.create<arith::IndexCastOp>( - nestedLoc, oldIndex.getType(), - rewriter.create<linalg::IndexOp>(loc, axis)); + Value newIndex = arith::IndexCastOp::create( + rewriter, nestedLoc, oldIndex.getType(), + linalg::IndexOp::create(rewriter, loc, axis)); Value predicate; if (isa<FloatType>(inElementTy)) { if (argmaxOp.getNanMode() == "IGNORE") { // Only update index & max value for non NaN values. If all // values are NaNs, the initial index will be return which is 0. - predicate = rewriter.create<arith::CmpFOp>( - nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue); + predicate = arith::CmpFOp::create(rewriter, nestedLoc, + arith::CmpFPredicate::OGT, + newValue, oldValue); } else { // Update max value if either of the following is true: // - new value is bigger // - cur max is not NaN and new value is NaN - Value gt = rewriter.create<arith::CmpFOp>( - nestedLoc, arith::CmpFPredicate::UGT, newValue, oldValue); - Value oldNonNaN = rewriter.create<arith::CmpFOp>( - nestedLoc, arith::CmpFPredicate::ORD, oldValue, oldValue); - predicate = rewriter.create<arith::AndIOp>( - nestedLoc, rewriter.getI1Type(), gt, oldNonNaN); + Value gt = arith::CmpFOp::create(rewriter, nestedLoc, + arith::CmpFPredicate::UGT, + newValue, oldValue); + Value oldNonNaN = arith::CmpFOp::create(rewriter, nestedLoc, + arith::CmpFPredicate::ORD, + oldValue, oldValue); + predicate = arith::AndIOp::create( + rewriter, nestedLoc, rewriter.getI1Type(), gt, oldNonNaN); } } else if (isa<IntegerType>(inElementTy)) { - predicate = rewriter.create<arith::CmpIOp>( - nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue); + predicate = arith::CmpIOp::create(rewriter, nestedLoc, + arith::CmpIPredicate::sgt, + newValue, oldValue); } else { didEncounterError = true; return; } - auto resultMax = rewriter.create<arith::SelectOp>( - nestedLoc, predicate, newValue, oldValue); - auto resultIndex = rewriter.create<arith::SelectOp>( - nestedLoc, predicate, newIndex, oldIndex); - nestedBuilder.create<linalg::YieldOp>( - nestedLoc, ValueRange({resultIndex, resultMax})); + auto resultMax = arith::SelectOp::create( + rewriter, nestedLoc, predicate, newValue, oldValue); + auto resultIndex = arith::SelectOp::create( + rewriter, nestedLoc, predicate, newIndex, oldIndex); + linalg::YieldOp::create(nestedBuilder, nestedLoc, + ValueRange({resultIndex, resultMax})); }); if (didEncounterError) @@ -2351,9 +2357,8 @@ public: auto loc = op.getLoc(); auto emptyTensor = - rewriter - .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy, - dynamicDims) + tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), + resultElementTy, dynamicDims) .getResult(); SmallVector<AffineMap, 2> affineMaps = { @@ -2363,19 +2368,19 @@ public: rewriter.getContext()), rewriter.getMultiDimIdentityMap(resultTy.getRank())}; - auto genericOp = rewriter.create<linalg::GenericOp>( - loc, ArrayRef<Type>({resultTy}), ValueRange{indices}, + auto genericOp = linalg::GenericOp::create( + rewriter, loc, ArrayRef<Type>({resultTy}), ValueRange{indices}, ValueRange{emptyTensor}, affineMaps, getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &b, Location loc, ValueRange args) { auto indexValue = args[0]; - auto index0 = rewriter.create<linalg::IndexOp>(loc, 0); - Value index1 = rewriter.create<arith::IndexCastOp>( - loc, rewriter.getIndexType(), indexValue); - auto index2 = rewriter.create<linalg::IndexOp>(loc, 2); - Value extract = rewriter.create<tensor::ExtractOp>( - loc, input, ValueRange{index0, index1, index2}); - rewriter.create<linalg::YieldOp>(loc, extract); + auto index0 = linalg::IndexOp::create(rewriter, loc, 0); + Value index1 = arith::IndexCastOp::create( + rewriter, loc, rewriter.getIndexType(), indexValue); + auto index2 = linalg::IndexOp::create(rewriter, loc, 2); + Value extract = tensor::ExtractOp::create( + rewriter, loc, input, ValueRange{index0, index1, index2}); + linalg::YieldOp::create(rewriter, loc, extract); }); rewriter.replaceOp(op, genericOp.getResult(0)); return success(); @@ -2424,22 +2429,22 @@ public: for (int i = 0; i < resultTy.getRank(); ++i) { if (inputTy.isDynamicDim(i)) { dynDims.push_back( - rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i)); + tensor::DimOp::create(rewriter, loc, op.getOperand(0), i)); } } - auto emptyTensor = rewriter - .create<tensor::EmptyOp>(loc, resultTy.getShape(), - resultElementTy, dynDims) - .getResult(); + auto emptyTensor = + tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), + resultElementTy, dynDims) + .getResult(); SmallVector<AffineMap, 2> affineMaps = { rewriter.getMultiDimIdentityMap(resultTy.getRank()), rewriter.getMultiDimIdentityMap(resultTy.getRank())}; - auto genericOp = rewriter.create<linalg::GenericOp>( - loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, affineMaps, - getNParallelLoopsAttrs(resultTy.getRank())); + auto genericOp = linalg::GenericOp::create( + rewriter, loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, + affineMaps, getNParallelLoopsAttrs(resultTy.getRank())); rewriter.replaceOp(op, genericOp.getResult(0)); { @@ -2452,69 +2457,69 @@ public: rewriter.setInsertionPointToStart(block); if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) && resultElementTy.isInteger(8)) { - Value index = rewriter.create<arith::IndexCastOp>( - loc, rewriter.getIndexType(), inputValue); - Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128); - index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(), - index, offset); + Value index = arith::IndexCastOp::create( + rewriter, loc, rewriter.getIndexType(), inputValue); + Value offset = arith::ConstantIndexOp::create(rewriter, loc, 128); + index = arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(), + index, offset); Value extract = - rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index}); - rewriter.create<linalg::YieldOp>(loc, extract); + tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index}); + linalg::YieldOp::create(rewriter, loc, extract); return success(); } if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) && resultElementTy.isInteger(32)) { - Value extend = rewriter.create<arith::ExtSIOp>( - loc, rewriter.getI32Type(), inputValue); - - auto offset = rewriter.create<arith::ConstantOp>( - loc, rewriter.getI32IntegerAttr(32768)); - auto seven = rewriter.create<arith::ConstantOp>( - loc, rewriter.getI32IntegerAttr(7)); - auto one = rewriter.create<arith::ConstantOp>( - loc, rewriter.getI32IntegerAttr(1)); - auto b1111111 = rewriter.create<arith::ConstantOp>( - loc, rewriter.getI32IntegerAttr(127)); + Value extend = arith::ExtSIOp::create( + rewriter, loc, rewriter.getI32Type(), inputValue); + + auto offset = arith::ConstantOp::create( + rewriter, loc, rewriter.getI32IntegerAttr(32768)); + auto seven = arith::ConstantOp::create(rewriter, loc, + rewriter.getI32IntegerAttr(7)); + auto one = arith::ConstantOp::create(rewriter, loc, + rewriter.getI32IntegerAttr(1)); + auto b1111111 = arith::ConstantOp::create( + rewriter, loc, rewriter.getI32IntegerAttr(127)); // Compute the index and fractional part from the input value: // value = value + 32768 // index = value >> 7; // fraction = 0x01111111 & value - auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset); - Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven); + auto extendAdd = arith::AddIOp::create(rewriter, loc, extend, offset); + Value index = arith::ShRUIOp::create(rewriter, loc, extendAdd, seven); Value fraction = - rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111); + arith::AndIOp::create(rewriter, loc, extendAdd, b1111111); // Extract the base and next values from the table. // base = (int32_t) table[index]; // next = (int32_t) table[index + 1]; - Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one); + Value indexPlusOne = arith::AddIOp::create(rewriter, loc, index, one); - index = rewriter.create<arith::IndexCastOp>( - loc, rewriter.getIndexType(), index); - indexPlusOne = rewriter.create<arith::IndexCastOp>( - loc, rewriter.getIndexType(), indexPlusOne); + index = arith::IndexCastOp::create(rewriter, loc, + rewriter.getIndexType(), index); + indexPlusOne = arith::IndexCastOp::create( + rewriter, loc, rewriter.getIndexType(), indexPlusOne); Value base = - rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index}); - Value next = rewriter.create<tensor::ExtractOp>( - loc, table, ValueRange{indexPlusOne}); + tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index}); + Value next = tensor::ExtractOp::create(rewriter, loc, table, + ValueRange{indexPlusOne}); base = - rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base); + arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), base); next = - rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next); + arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), next); // Use the fractional part to interpolate between the input values: // result = (base << 7) + (next - base) * fraction - Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven); - Value diff = rewriter.create<arith::SubIOp>(loc, next, base); - Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction); + Value baseScaled = arith::ShLIOp::create(rewriter, loc, base, seven); + Value diff = arith::SubIOp::create(rewriter, loc, next, base); + Value diffScaled = arith::MulIOp::create(rewriter, loc, diff, fraction); Value result = - rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled); + arith::AddIOp::create(rewriter, loc, baseScaled, diffScaled); - rewriter.create<linalg::YieldOp>(loc, result); + linalg::YieldOp::create(rewriter, loc, result); return success(); } @@ -2532,8 +2537,8 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> { static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc, OpFoldResult ofr) { - auto one = builder.create<arith::ConstantIndexOp>(loc, 1); - auto two = builder.create<arith::ConstantIndexOp>(loc, 2); + auto one = arith::ConstantIndexOp::create(builder, loc, 1); + auto two = arith::ConstantIndexOp::create(builder, loc, 2); auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr); auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two); @@ -2562,30 +2567,30 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> { RankedTensorType type, llvm::ArrayRef<Value> dynamicSizes) { auto emptyTensor = - rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes); + tensor::EmptyOp::create(rewriter, loc, type, dynamicSizes); auto fillValueAttr = rewriter.getZeroAttr(type.getElementType()); - auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr); - auto filledTensor = rewriter - .create<linalg::FillOp>(loc, ValueRange{fillValue}, - ValueRange{emptyTensor}) - .result(); + auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr); + auto filledTensor = + linalg::FillOp::create(rewriter, loc, ValueRange{fillValue}, + ValueRange{emptyTensor}) + .result(); return filledTensor; } static Value castIndexToFloat(OpBuilder &builder, Location loc, FloatType type, Value value) { - auto integerVal = builder.create<arith::IndexCastUIOp>( - loc, + auto integerVal = arith::IndexCastUIOp::create( + builder, loc, type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type() : builder.getI32Type(), value); - return builder.create<arith::UIToFPOp>(loc, type, integerVal); + return arith::UIToFPOp::create(builder, loc, type, integerVal); } static Value createLinalgIndex(OpBuilder &builder, Location loc, FloatType type, int64_t index) { - auto indexVal = builder.create<linalg::IndexOp>(loc, index); + auto indexVal = linalg::IndexOp::create(builder, loc, index); return castIndexToFloat(builder, loc, type, indexVal); } @@ -2640,7 +2645,7 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> { // Constants and dimension sizes auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586); - auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr); + auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr); auto constH = castIndexToFloat(rewriter, loc, elementType, dimH); auto constW = castIndexToFloat(rewriter, loc, elementType, dimW); @@ -2650,43 +2655,45 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> { Value sumImag = args[2]; // Indices for angle computation - Value oy = builder.create<linalg::IndexOp>(loc, 1); - Value ox = builder.create<linalg::IndexOp>(loc, 2); - Value iy = builder.create<linalg::IndexOp>(loc, 3); - Value ix = builder.create<linalg::IndexOp>(loc, 4); + Value oy = linalg::IndexOp::create(builder, loc, 1); + Value ox = linalg::IndexOp::create(builder, loc, 2); + Value iy = linalg::IndexOp::create(builder, loc, 3); + Value ix = linalg::IndexOp::create(builder, loc, 4); // Calculating angle without integer parts of components as sin/cos are // periodic: angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W ) // / W); - auto iyXoy = builder.create<index::MulOp>(loc, iy, oy); - auto ixXox = builder.create<index::MulOp>(loc, ix, ox); + auto iyXoy = index::MulOp::create(builder, loc, iy, oy); + auto ixXox = index::MulOp::create(builder, loc, ix, ox); - auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH); - auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW); + auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH); + auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW); auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem); auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem); - auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH); - auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW); - auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent); - auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY); + auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH); + auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW); + auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent); + auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY); // realComponent = valReal * cos(angle) // imagComponent = valReal * sin(angle) - auto cosAngle = builder.create<math::CosOp>(loc, angle); - auto sinAngle = builder.create<math::SinOp>(loc, angle); + auto cosAngle = math::CosOp::create(builder, loc, angle); + auto sinAngle = math::SinOp::create(builder, loc, angle); auto realComponent = - builder.create<arith::MulFOp>(loc, valReal, cosAngle); + arith::MulFOp::create(builder, loc, valReal, cosAngle); auto imagComponent = - builder.create<arith::MulFOp>(loc, valReal, sinAngle); + arith::MulFOp::create(builder, loc, valReal, sinAngle); // outReal = sumReal + realComponent // outImag = sumImag - imagComponent - auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent); - auto outImag = builder.create<arith::SubFOp>(loc, sumImag, imagComponent); + auto outReal = + arith::AddFOp::create(builder, loc, sumReal, realComponent); + auto outImag = + arith::SubFOp::create(builder, loc, sumImag, imagComponent); - builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag}); + linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag}); }; rewriter.replaceOpWithNewOp<linalg::GenericOp>( @@ -2760,7 +2767,7 @@ struct FFT2dConverter final : OpRewritePattern<FFT2dOp> { // Constants and dimension sizes auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586); - auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr); + auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr); Value constH = RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH); Value constW = @@ -2773,57 +2780,59 @@ struct FFT2dConverter final : OpRewritePattern<FFT2dOp> { Value sumImag = args[3]; // Indices for angle computation - Value oy = builder.create<linalg::IndexOp>(loc, 1); - Value ox = builder.create<linalg::IndexOp>(loc, 2); - Value iy = builder.create<linalg::IndexOp>(loc, 3); - Value ix = builder.create<linalg::IndexOp>(loc, 4); + Value oy = linalg::IndexOp::create(builder, loc, 1); + Value ox = linalg::IndexOp::create(builder, loc, 2); + Value iy = linalg::IndexOp::create(builder, loc, 3); + Value ix = linalg::IndexOp::create(builder, loc, 4); // float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * // ox) % W ) / W); - auto iyXoy = builder.create<index::MulOp>(loc, iy, oy); - auto ixXox = builder.create<index::MulOp>(loc, ix, ox); + auto iyXoy = index::MulOp::create(builder, loc, iy, oy); + auto ixXox = index::MulOp::create(builder, loc, ix, ox); - auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH); - auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW); + auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH); + auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW); auto iyRemFloat = RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem); auto ixRemFloat = RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem); - auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH); - auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW); + auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH); + auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW); - auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent); - auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY); + auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent); + auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY); if (inverse.getValue()) { - angle = builder.create<arith::MulFOp>( - loc, angle, - rewriter.create<arith::ConstantOp>( - loc, rewriter.getFloatAttr(real_el_ty, -1.0))); + angle = arith::MulFOp::create( + builder, loc, angle, + arith::ConstantOp::create(rewriter, loc, + rewriter.getFloatAttr(real_el_ty, -1.0))); } // realComponent = val_real * cos(a) + val_imag * sin(a); // imagComponent = -val_real * sin(a) + val_imag * cos(a); - auto cosAngle = builder.create<math::CosOp>(loc, angle); - auto sinAngle = builder.create<math::SinOp>(loc, angle); + auto cosAngle = math::CosOp::create(builder, loc, angle); + auto sinAngle = math::SinOp::create(builder, loc, angle); - auto rcos = builder.create<arith::MulFOp>(loc, valReal, cosAngle); - auto rsin = builder.create<arith::MulFOp>(loc, valImag, sinAngle); - auto realComponent = builder.create<arith::AddFOp>(loc, rcos, rsin); + auto rcos = arith::MulFOp::create(builder, loc, valReal, cosAngle); + auto rsin = arith::MulFOp::create(builder, loc, valImag, sinAngle); + auto realComponent = arith::AddFOp::create(builder, loc, rcos, rsin); - auto icos = builder.create<arith::MulFOp>(loc, valImag, cosAngle); - auto isin = builder.create<arith::MulFOp>(loc, valReal, sinAngle); + auto icos = arith::MulFOp::create(builder, loc, valImag, cosAngle); + auto isin = arith::MulFOp::create(builder, loc, valReal, sinAngle); - auto imagComponent = builder.create<arith::SubFOp>(loc, icos, isin); + auto imagComponent = arith::SubFOp::create(builder, loc, icos, isin); // outReal = sumReal + realComponent // outImag = sumImag - imagComponent - auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent); - auto outImag = builder.create<arith::AddFOp>(loc, sumImag, imagComponent); + auto outReal = + arith::AddFOp::create(builder, loc, sumReal, realComponent); + auto outImag = + arith::AddFOp::create(builder, loc, sumImag, imagComponent); - builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag}); + linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag}); }; rewriter.replaceOpWithNewOp<linalg::GenericOp>( diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp index 00b9a06..da1fb20 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -52,11 +52,11 @@ static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad, highIndices.push_back(rewriter.getIndexAttr(highPad)); } - Value padValue = rewriter.create<arith::ConstantOp>(loc, padAttr); + Value padValue = arith::ConstantOp::create(rewriter, loc, padAttr); - return rewriter.create<tensor::PadOp>( - loc, RankedTensorType::get(paddedShape, inputETy), input, lowIndices, - highIndices, padValue); + return tensor::PadOp::create(rewriter, loc, + RankedTensorType::get(paddedShape, inputETy), + input, lowIndices, highIndices, padValue); } static mlir::Value @@ -64,19 +64,20 @@ linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias, Value conv, Value result, ArrayRef<AffineMap> indexingMaps) { ShapedType resultTy = cast<ShapedType>(conv.getType()); - return rewriter - .create<linalg::GenericOp>( - loc, resultTy, ValueRange({bias, conv}), result, indexingMaps, - getNParallelLoopsAttrs(resultTy.getRank()), - [](OpBuilder &builder, Location loc, ValueRange args) { - Value biasVal = args[0]; - Type resType = args[1].getType(); - if (resType != biasVal.getType()) { - biasVal = builder.create<arith::ExtSIOp>(loc, resType, biasVal); - } - Value added = builder.create<arith::AddIOp>(loc, biasVal, args[1]); - builder.create<linalg::YieldOp>(loc, added); - }) + return linalg::GenericOp::create( + rewriter, loc, resultTy, ValueRange({bias, conv}), result, + indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()), + [](OpBuilder &builder, Location loc, ValueRange args) { + Value biasVal = args[0]; + Type resType = args[1].getType(); + if (resType != biasVal.getType()) { + biasVal = + arith::ExtSIOp::create(builder, loc, resType, biasVal); + } + Value added = + arith::AddIOp::create(builder, loc, biasVal, args[1]); + linalg::YieldOp::create(builder, loc, added); + }) .getResult(0); } @@ -124,29 +125,29 @@ static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter, indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank)); // Build the broadcast-like operation as a linalg.generic. - return rewriter - .create<linalg::GenericOp>( - loc, resultTy, ValueRange({source}), result, indexingMaps, - getNParallelLoopsAttrs(resultTy.getRank()), - [&resultTy](OpBuilder &builder, Location loc, ValueRange args) { - Value biasVal = args[0]; - Type resType = args[1].getType(); - if (resType != biasVal.getType()) { - biasVal = - resultTy.getElementType().isFloat() - ? builder.create<arith::ExtFOp>(loc, resType, biasVal) - .getResult() - : builder.create<arith::ExtSIOp>(loc, resType, biasVal) - .getResult(); - } - builder.create<linalg::YieldOp>(loc, biasVal); - }) + return linalg::GenericOp::create( + rewriter, loc, resultTy, ValueRange({source}), result, + indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()), + [&resultTy](OpBuilder &builder, Location loc, ValueRange args) { + Value biasVal = args[0]; + Type resType = args[1].getType(); + if (resType != biasVal.getType()) { + biasVal = + resultTy.getElementType().isFloat() + ? arith::ExtFOp::create(builder, loc, resType, biasVal) + .getResult() + : arith::ExtSIOp::create(builder, loc, resType, + biasVal) + .getResult(); + } + linalg::YieldOp::create(builder, loc, biasVal); + }) .getResult(0); } static mlir::Value reifyConstantDim(int64_t attr, ImplicitLocOpBuilder &builder) { - return builder.create<arith::ConstantIndexOp>(attr); + return arith::ConstantIndexOp::create(builder, attr); } // Calculating the output width/height using the formula: @@ -160,22 +161,22 @@ static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim, int64_t dilationAttr, OpBuilder &rewriter) { ImplicitLocOpBuilder builder(loc, rewriter); - auto one = rewriter.create<arith::ConstantOp>( - loc, IntegerAttr::get(inputDim.getType(), 1)); + auto one = arith::ConstantOp::create(rewriter, loc, + IntegerAttr::get(inputDim.getType(), 1)); Value padBefore = reifyConstantDim(padBeforeAttr, builder); - Value paddedBefore = builder.create<arith::AddIOp>(inputDim, padBefore); + Value paddedBefore = arith::AddIOp::create(builder, inputDim, padBefore); Value padAfter = reifyConstantDim(padAfterAttr, builder); - Value paddedAfter = builder.create<arith::AddIOp>(paddedBefore, padAfter); + Value paddedAfter = arith::AddIOp::create(builder, paddedBefore, padAfter); - Value subOne = builder.create<arith::SubIOp>(kernelDim, one); + Value subOne = arith::SubIOp::create(builder, kernelDim, one); Value dilation = reifyConstantDim(dilationAttr, builder); - Value dilated = builder.create<arith::MulIOp>(dilation, subOne); - Value addOne = builder.create<arith::AddIOp>(dilated, one); + Value dilated = arith::MulIOp::create(builder, dilation, subOne); + Value addOne = arith::AddIOp::create(builder, dilated, one); - Value subtract = builder.create<arith::SubIOp>(paddedAfter, addOne); + Value subtract = arith::SubIOp::create(builder, paddedAfter, addOne); Value stride = reifyConstantDim(strideAttr, builder); - Value divide = builder.create<arith::DivUIOp>(subtract, stride); - return builder.create<arith::AddIOp>(divide, one); + Value divide = arith::DivUIOp::create(builder, subtract, stride); + return arith::AddIOp::create(builder, divide, one); } // Creates a vector of the dynamic output dims for Conv2D and Depthwise_Conv2D @@ -198,9 +199,9 @@ static SmallVector<Value> inferDynamicDimsForConv( auto padBottom = padAttr[i * 2 + 1]; auto stride = strideAttr[i]; auto dilation = dilationAttr[i]; - Value initDynDim = rewriter.create<tensor::DimOp>(loc, input, inputDim); + Value initDynDim = tensor::DimOp::create(rewriter, loc, input, inputDim); Value kernelDynDim = - rewriter.create<tensor::DimOp>(loc, weight, kernelDim); + tensor::DimOp::create(rewriter, loc, weight, kernelDim); // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y) dynDims[inputDim] = getConvOrPoolOutputDim(loc, initDynDim, padTop, padBottom, @@ -211,7 +212,7 @@ static SmallVector<Value> inferDynamicDimsForConv( // Get the batch/channels dimensions. for (int i = 0; i < inputRank; i++) { if (resultTy.isDynamicDim(i) && !dynDims[i]) - dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i); + dynDims[i] = tensor::DimOp::create(rewriter, loc, input, i); } SmallVector<Value> filteredDims = condenseValues(dynDims); @@ -350,8 +351,8 @@ public: auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm); Type newWeightTy = RankedTensorType::get(newWeightShape, weightTy.getElementType()); - weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight, - weightPermAttr); + weight = tosa::TransposeOp::create(rewriter, loc, newWeightTy, weight, + weightPermAttr); } } @@ -372,8 +373,8 @@ public: auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm); Type newWeightTy = RankedTensorType::get(newWeightShape, weightTy.getElementType()); - weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight, - weightPermAttr); + weight = tosa::TransposeOp::create(rewriter, loc, newWeightTy, weight, + weightPermAttr); } // Extract the attributes for convolution. @@ -384,8 +385,8 @@ public: auto strideAttr = rewriter.getI64TensorAttr(stride); auto dilationAttr = rewriter.getI64TensorAttr(dilation); - Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>( - loc, resultTy.getShape(), accETy, filteredDims); + Value biasEmptyTensor = tensor::EmptyOp::create( + rewriter, loc, resultTy.getShape(), accETy, filteredDims); Value broadcastBias = linalgBroadcastAndMaybeExt(rewriter, loc, bias, biasEmptyTensor); @@ -394,30 +395,28 @@ public: auto iZp = rewriter.getI32IntegerAttr(inputZpVal); auto kZp = rewriter.getI32IntegerAttr(weightZpVal); - auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp); - auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp); + auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp); + auto kZpVal = arith::ConstantOp::create(rewriter, loc, kZp); - Value conv = - rewriter - .create<LinalgConvQOp>( - loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal}, - ValueRange{broadcastBias}, strideAttr, dilationAttr) - ->getResult(0); + Value conv = LinalgConvQOp::create( + rewriter, loc, resultTy, + ValueRange{input, weight, iZpVal, kZpVal}, + ValueRange{broadcastBias}, strideAttr, dilationAttr) + ->getResult(0); rewriter.replaceOp(op, conv); return success(); } - Value conv = rewriter - .create<LinalgConvOp>( - loc, accTy, ValueRange{input, weight}, - ValueRange{broadcastBias}, strideAttr, dilationAttr) + Value conv = LinalgConvOp::create( + rewriter, loc, accTy, ValueRange{input, weight}, + ValueRange{broadcastBias}, strideAttr, dilationAttr) ->getResult(0); // We may need to truncate back to the result type if the accumulator was // wider than the result. if (resultTy != accTy) - conv = rewriter.create<tosa::CastOp>(loc, resultTy, conv); + conv = tosa::CastOp::create(rewriter, loc, resultTy, conv); rewriter.replaceOp(op, conv); return success(); @@ -526,16 +525,15 @@ public: accETy); auto resultZeroAttr = rewriter.getZeroAttr(accETy); - Value emptyTensor = rewriter.create<tensor::EmptyOp>( - loc, linalgConvTy.getShape(), accETy, filteredDims); - Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr); - Value zeroTensor = rewriter - .create<linalg::FillOp>(loc, ValueRange{zero}, - ValueRange{emptyTensor}) + Value emptyTensor = tensor::EmptyOp::create( + rewriter, loc, linalgConvTy.getShape(), accETy, filteredDims); + Value zero = arith::ConstantOp::create(rewriter, loc, resultZeroAttr); + Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero}, + ValueRange{emptyTensor}) .result(); - Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>( - loc, resultTy.getShape(), resultETy, filteredDims); + Value biasEmptyTensor = tensor::EmptyOp::create( + rewriter, loc, resultTy.getShape(), resultETy, filteredDims); // Broadcast the initial value to the output tensor before convolving. SmallVector<AffineMap, 4> indexingMaps; @@ -544,60 +542,56 @@ public: indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank)); if (hasNullZps) { - Value conv = rewriter - .create<linalg::DepthwiseConv2DNhwcHwcmOp>( - loc, linalgConvTy, ValueRange{input, weight}, - ValueRange{zeroTensor}, strideAttr, dilationAttr) + Value conv = linalg::DepthwiseConv2DNhwcHwcmOp::create( + rewriter, loc, linalgConvTy, ValueRange{input, weight}, + ValueRange{zeroTensor}, strideAttr, dilationAttr) .getResult(0); // We may need to truncate back to the result type if the accumulator was // wider than the result. if (accETy != resultETy) - conv = rewriter.create<tosa::CastOp>( - loc, + conv = tosa::CastOp::create( + rewriter, loc, RankedTensorType::get(cast<ShapedType>(conv.getType()).getShape(), resultETy), conv); SmallVector<ReassociationExprs, 4> reassociationMap; createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter); - Value convReshape = rewriter.create<tensor::CollapseShapeOp>( - loc, resultTy, conv, reassociationMap); + Value convReshape = tensor::CollapseShapeOp::create( + rewriter, loc, resultTy, conv, reassociationMap); Value result = - rewriter - .create<linalg::GenericOp>( - loc, resultTy, ValueRange({bias, convReshape}), - biasEmptyTensor, indexingMaps, - getNParallelLoopsAttrs(resultRank), - [&](OpBuilder &nestedBuilder, Location nestedLoc, - ValueRange args) { - Value added; - if (llvm::isa<FloatType>(inputETy)) - added = nestedBuilder.create<arith::AddFOp>(loc, args[0], - args[1]); - else - added = nestedBuilder.create<arith::AddIOp>(loc, args[0], - args[1]); - nestedBuilder.create<linalg::YieldOp>(nestedLoc, added); - }) + linalg::GenericOp::create( + rewriter, loc, resultTy, ValueRange({bias, convReshape}), + biasEmptyTensor, indexingMaps, getNParallelLoopsAttrs(resultRank), + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange args) { + Value added; + if (llvm::isa<FloatType>(inputETy)) + added = arith::AddFOp::create(nestedBuilder, loc, args[0], + args[1]); + else + added = arith::AddIOp::create(nestedBuilder, loc, args[0], + args[1]); + linalg::YieldOp::create(nestedBuilder, nestedLoc, added); + }) .getResult(0); rewriter.replaceOp(op, result); } else { IntegerAttr iZp = rewriter.getI32IntegerAttr(inputZpVal); IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal); - auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp); - auto kZpVal = rewriter.create<arith::ConstantOp>(loc, wZp); - Value conv = - rewriter - .create<linalg::DepthwiseConv2DNhwcHwcmQOp>( - loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal}, - ValueRange{zeroTensor}, strideAttr, dilationAttr) - .getResult(0); + auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp); + auto kZpVal = arith::ConstantOp::create(rewriter, loc, wZp); + Value conv = linalg::DepthwiseConv2DNhwcHwcmQOp::create( + rewriter, loc, linalgConvTy, + ValueRange{input, weight, iZpVal, kZpVal}, + ValueRange{zeroTensor}, strideAttr, dilationAttr) + .getResult(0); SmallVector<ReassociationExprs, 4> reassociationMap; createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter); - Value convReshape = rewriter.create<tensor::CollapseShapeOp>( - loc, resultTy, conv, reassociationMap); + Value convReshape = tensor::CollapseShapeOp::create( + rewriter, loc, resultTy, conv, reassociationMap); Value result = linalgIntBroadcastExtSIAdd( rewriter, loc, bias, convReshape, biasEmptyTensor, indexingMaps); rewriter.replaceOp(op, result); @@ -621,26 +615,26 @@ public: dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank()); if (!outputTy.hasRank() || outputTy.isDynamicDim(0)) { - dynDims[0] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 0); + dynDims[0] = tensor::DimOp::create(rewriter, loc, op->getOperand(0), 0); } if (!outputTy.hasRank() || outputTy.isDynamicDim(1)) { - dynDims[1] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 1); + dynDims[1] = tensor::DimOp::create(rewriter, loc, op->getOperand(0), 1); } if (!outputTy.hasRank() || outputTy.isDynamicDim(2)) { - dynDims[2] = rewriter.create<tensor::DimOp>(loc, op->getOperand(1), 2); + dynDims[2] = tensor::DimOp::create(rewriter, loc, op->getOperand(1), 2); } SmallVector<Value> filteredDims = condenseValues(dynDims); auto zeroAttr = rewriter.getZeroAttr(outputElementTy); - Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr); - auto emptyTensor = rewriter.create<tensor::EmptyOp>( - loc, outputTy.getShape(), outputTy.getElementType(), filteredDims); - Value zeroTensor = rewriter - .create<linalg::FillOp>(loc, ValueRange{zero}, - ValueRange{emptyTensor}) + Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr); + auto emptyTensor = + tensor::EmptyOp::create(rewriter, loc, outputTy.getShape(), + outputTy.getElementType(), filteredDims); + Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero}, + ValueRange{emptyTensor}) .result(); FailureOr<int64_t> maybeAZp = op.getAZeroPoint(); @@ -670,10 +664,10 @@ public: return success(); } - auto aZp = rewriter.create<arith::ConstantOp>( - loc, rewriter.getI32IntegerAttr(aZpVal)); - auto bZp = rewriter.create<arith::ConstantOp>( - loc, rewriter.getI32IntegerAttr(bZpVal)); + auto aZp = arith::ConstantOp::create(rewriter, loc, + rewriter.getI32IntegerAttr(aZpVal)); + auto bZp = arith::ConstantOp::create(rewriter, loc, + rewriter.getI32IntegerAttr(bZpVal)); rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>( op, TypeRange{op.getType()}, ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor); @@ -702,7 +696,7 @@ public: // Batch dimension if (resultTy.isDynamicDim(0)) - dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0)); + dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 0)); // Height/width dimensions for (int64_t dim : {1, 2}) { @@ -713,10 +707,10 @@ public: int64_t index = dim - 1; // Input height/width - Value ihw = rewriter.create<tensor::DimOp>(loc, input, dim); + Value ihw = tensor::DimOp::create(rewriter, loc, input, dim); // Kernel height/width - Value khw = rewriter.create<arith::ConstantIndexOp>(loc, kernel[index]); + Value khw = arith::ConstantIndexOp::create(rewriter, loc, kernel[index]); // Output height/width Value ohw = getConvOrPoolOutputDim(loc, ihw, pad[index * 2], @@ -727,7 +721,7 @@ public: // Channel dimension if (resultTy.isDynamicDim(3)) - dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 3)); + dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 3)); return dynamicDims; } @@ -776,7 +770,7 @@ public: Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter); - Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr); + Value initialValue = arith::ConstantOp::create(rewriter, loc, initialAttr); ArrayRef<int64_t> kernel = op.getKernel(); ArrayRef<int64_t> stride = op.getStride(); @@ -785,15 +779,16 @@ public: Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1}); // Create the linalg op that performs pooling. - Value emptyTensor = rewriter.create<tensor::EmptyOp>( - loc, resultTy.getShape(), resultTy.getElementType(), dynamicDims); + Value emptyTensor = + tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), + resultTy.getElementType(), dynamicDims); Value filledEmptyTensor = - rewriter.create<linalg::FillOp>(loc, initialValue, emptyTensor) + linalg::FillOp::create(rewriter, loc, initialValue, emptyTensor) .result(); Value fakeWindowDims = - rewriter.create<tensor::EmptyOp>(loc, kernel, resultETy); + tensor::EmptyOp::create(rewriter, loc, kernel, resultETy); if (isUnsigned) { rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>( @@ -802,8 +797,8 @@ public: return llvm::success(); } - auto resultOp = rewriter.create<linalg::PoolingNhwcMaxOp>( - op->getLoc(), ArrayRef<Type>{resultTy}, + auto resultOp = linalg::PoolingNhwcMaxOp::create( + rewriter, op->getLoc(), ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr, dilationAttr); @@ -823,9 +818,10 @@ public: // it to include the appropriate checks. If the current value is NaN the // old value of pool will be taken otherwise we use the result. if (nanMode == "IGNORE") { - auto genericOp = rewriter.create<linalg::GenericOp>( - loc, resultOp.getType(0), resultOp.getInputs(), resultOp.getOutputs(), - resultOp.getIndexingMapsArray(), resultOp.getIteratorTypesArray(), + auto genericOp = linalg::GenericOp::create( + rewriter, loc, resultOp.getType(0), resultOp.getInputs(), + resultOp.getOutputs(), resultOp.getIndexingMapsArray(), + resultOp.getIteratorTypesArray(), [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) { IRMapping map; auto oldBlock = resultOp.getRegion().begin(); @@ -833,12 +829,12 @@ public: auto &oldMaxOp = *resultOp.getBlock()->begin(); map.map(oldArgs, blockArgs); auto *newOp = opBuilder.clone(oldMaxOp, map); - Value isNaN = opBuilder.create<arith::CmpFOp>( - loc, arith::CmpFPredicate::UNO, blockArgs.front(), - blockArgs.front()); - auto selectOp = opBuilder.create<arith::SelectOp>( - loc, isNaN, blockArgs.back(), newOp->getResult(0)); - opBuilder.create<linalg::YieldOp>(loc, selectOp.getResult()); + Value isNaN = + arith::CmpFOp::create(opBuilder, loc, arith::CmpFPredicate::UNO, + blockArgs.front(), blockArgs.front()); + auto selectOp = arith::SelectOp::create( + opBuilder, loc, isNaN, blockArgs.back(), newOp->getResult(0)); + linalg::YieldOp::create(opBuilder, loc, selectOp.getResult()); }); rewriter.replaceOp(resultOp, genericOp); } @@ -894,7 +890,7 @@ public: Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter); auto initialAttr = rewriter.getZeroAttr(accETy); - Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr); + Value initialValue = arith::ConstantOp::create(rewriter, loc, initialAttr); ArrayRef<int64_t> kernel = op.getKernel(); ArrayRef<int64_t> stride = op.getStride(); @@ -903,46 +899,44 @@ public: Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1}); // Create the linalg op that performs pooling. - Value poolEmptyTensor = rewriter.create<tensor::EmptyOp>( - loc, accTy.getShape(), accETy, dynamicDims); + Value poolEmptyTensor = tensor::EmptyOp::create( + rewriter, loc, accTy.getShape(), accETy, dynamicDims); Value filledEmptyTensor = - rewriter - .create<linalg::FillOp>(loc, ValueRange{initialValue}, - ValueRange{poolEmptyTensor}) + linalg::FillOp::create(rewriter, loc, ValueRange{initialValue}, + ValueRange{poolEmptyTensor}) .result(); Value fakeWindowDims = - rewriter.create<tensor::EmptyOp>(loc, kernel, accETy); + tensor::EmptyOp::create(rewriter, loc, kernel, accETy); // Sum across the pooled region. - Value poolingOp = rewriter - .create<linalg::PoolingNhwcSumOp>( - loc, ArrayRef<Type>{accTy}, - ValueRange{paddedInput, fakeWindowDims}, - filledEmptyTensor, strideAttr, dilationAttr) + Value poolingOp = linalg::PoolingNhwcSumOp::create( + rewriter, loc, ArrayRef<Type>{accTy}, + ValueRange{paddedInput, fakeWindowDims}, + filledEmptyTensor, strideAttr, dilationAttr) .getResult(0); // Normalize the summed value by the number of elements grouped in each // pool. - Value iH = rewriter.create<tensor::DimOp>(loc, poolingOp, 1); - Value iW = rewriter.create<tensor::DimOp>(loc, poolingOp, 2); + Value iH = tensor::DimOp::create(rewriter, loc, poolingOp, 1); + Value iW = tensor::DimOp::create(rewriter, loc, poolingOp, 2); - auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1); - iH = rewriter.create<arith::SubIOp>(loc, iH, one); - iW = rewriter.create<arith::SubIOp>(loc, iW, one); + auto one = arith::ConstantIndexOp::create(rewriter, loc, 1); + iH = arith::SubIOp::create(rewriter, loc, iH, one); + iW = arith::SubIOp::create(rewriter, loc, iW, one); - Value genericEmptyTensor = rewriter.create<tensor::EmptyOp>( - loc, resultTy.getShape(), resultETy, dynamicDims); + Value genericEmptyTensor = tensor::EmptyOp::create( + rewriter, loc, resultTy.getShape(), resultETy, dynamicDims); auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank()); - auto genericOp = rewriter.create<linalg::GenericOp>( - loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp}, + auto genericOp = linalg::GenericOp::create( + rewriter, loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp}, ValueRange{genericEmptyTensor}, ArrayRef<AffineMap>({affineMap, affineMap}), getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &b, Location loc, ValueRange args) { - auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); + auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0); // Determines what the portion of valid input is covered by the // kernel. @@ -950,30 +944,30 @@ public: if (pad == 0) return valid; - auto padVal = rewriter.create<arith::ConstantIndexOp>(loc, pad); - Value dpos = rewriter.create<arith::SubIOp>(loc, pos, padVal); + auto padVal = arith::ConstantIndexOp::create(rewriter, loc, pad); + Value dpos = arith::SubIOp::create(rewriter, loc, pos, padVal); - Value offset = rewriter.create<arith::MinSIOp>(loc, dpos, zero); - return rewriter.create<arith::AddIOp>(loc, valid, offset) + Value offset = arith::MinSIOp::create(rewriter, loc, dpos, zero); + return arith::AddIOp::create(rewriter, loc, valid, offset) ->getResult(0); }; auto coverageFn = [&](int64_t i, Value isize) -> Value { Value strideVal = - rewriter.create<arith::ConstantIndexOp>(loc, stride[i - 1]); + arith::ConstantIndexOp::create(rewriter, loc, stride[i - 1]); Value val = - rewriter.create<arith::ConstantIndexOp>(loc, kernel[i - 1]); + arith::ConstantIndexOp::create(rewriter, loc, kernel[i - 1]); // Find the position relative to the input tensor's ends. - Value left = rewriter.create<linalg::IndexOp>(loc, i); - Value right = rewriter.create<arith::SubIOp>(loc, isize, left); - left = rewriter.create<arith::MulIOp>(loc, left, strideVal); - right = rewriter.create<arith::MulIOp>(loc, right, strideVal); + Value left = linalg::IndexOp::create(rewriter, loc, i); + Value right = arith::SubIOp::create(rewriter, loc, isize, left); + left = arith::MulIOp::create(rewriter, loc, left, strideVal); + right = arith::MulIOp::create(rewriter, loc, right, strideVal); // Determine how much padding was included. val = padFn(val, left, pad[i * 2]); val = padFn(val, right, pad[i * 2 + 1]); - return rewriter.create<arith::MaxSIOp>(loc, one, val); + return arith::MaxSIOp::create(rewriter, loc, one, val); }; // Compute the indices from either end. @@ -981,95 +975,95 @@ public: Value kW3 = coverageFn(2, iW); // Compute the total number of elements and normalize. - auto count = rewriter.create<arith::IndexCastOp>( - loc, rewriter.getI32Type(), - rewriter.create<arith::MulIOp>(loc, kH3, kW3)); + auto count = arith::IndexCastOp::create( + rewriter, loc, rewriter.getI32Type(), + arith::MulIOp::create(rewriter, loc, kH3, kW3)); // Divide by the number of summed values. For floats this is just // a div however for quantized values input normalization had // to be applied. Value poolVal = args[0]; if (isa<FloatType>(accETy)) { - auto countF = rewriter.create<arith::SIToFPOp>(loc, accETy, count); - poolVal = rewriter.create<arith::DivFOp>(loc, poolVal, countF) + auto countF = arith::SIToFPOp::create(rewriter, loc, accETy, count); + poolVal = arith::DivFOp::create(rewriter, loc, poolVal, countF) ->getResult(0); if (accETy.getIntOrFloatBitWidth() > resultETy.getIntOrFloatBitWidth()) poolVal = - rewriter.create<arith::TruncFOp>(loc, resultETy, poolVal); + arith::TruncFOp::create(rewriter, loc, resultETy, poolVal); } else { // If we have quantization information we need to apply an offset // for the input zp value. if (inputZpVal != 0) { - auto inputZp = rewriter.create<arith::ConstantOp>( - loc, b.getIntegerAttr(accETy, inputZpVal)); + auto inputZp = arith::ConstantOp::create( + rewriter, loc, b.getIntegerAttr(accETy, inputZpVal)); Value offset = - rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp); + arith::MulIOp::create(rewriter, loc, accETy, count, inputZp); poolVal = - rewriter.create<arith::SubIOp>(loc, accETy, poolVal, offset); + arith::SubIOp::create(rewriter, loc, accETy, poolVal, offset); } // Compute: k = 32 - count_leading_zeros(value - 1) - Value one32 = rewriter.create<arith::ConstantOp>( - loc, rewriter.getI32IntegerAttr(1)); - Value thirtyTwo32 = rewriter.create<arith::ConstantOp>( - loc, rewriter.getI32IntegerAttr(32)); + Value one32 = arith::ConstantOp::create( + rewriter, loc, rewriter.getI32IntegerAttr(1)); + Value thirtyTwo32 = arith::ConstantOp::create( + rewriter, loc, rewriter.getI32IntegerAttr(32)); Value countSubOne = - rewriter.create<arith::SubIOp>(loc, count, one32); + arith::SubIOp::create(rewriter, loc, count, one32); Value leadingZeros = - rewriter.create<math::CountLeadingZerosOp>(loc, countSubOne); + math::CountLeadingZerosOp::create(rewriter, loc, countSubOne); Value k = - rewriter.create<arith::SubIOp>(loc, thirtyTwo32, leadingZeros); + arith::SubIOp::create(rewriter, loc, thirtyTwo32, leadingZeros); // Compute: numerator = ((1 << 30) + 1) << k Value k64 = - rewriter.create<arith::ExtUIOp>(loc, rewriter.getI64Type(), k); - Value thirtyShiftPlusOne = rewriter.create<arith::ConstantOp>( - loc, rewriter.getI64IntegerAttr((1 << 30) + 1)); + arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), k); + Value thirtyShiftPlusOne = arith::ConstantOp::create( + rewriter, loc, rewriter.getI64IntegerAttr((1 << 30) + 1)); Value numerator = - rewriter.create<arith::ShLIOp>(loc, thirtyShiftPlusOne, k64); + arith::ShLIOp::create(rewriter, loc, thirtyShiftPlusOne, k64); // Compute: scale.multiplier = numerator / value; - Value count64 = rewriter.create<arith::ExtUIOp>( - loc, rewriter.getI64Type(), count); + Value count64 = arith::ExtUIOp::create( + rewriter, loc, rewriter.getI64Type(), count); Value multiplier = - rewriter.create<arith::DivUIOp>(loc, numerator, count64); - multiplier = rewriter.create<arith::TruncIOp>( - loc, rewriter.getI32Type(), multiplier); + arith::DivUIOp::create(rewriter, loc, numerator, count64); + multiplier = arith::TruncIOp::create( + rewriter, loc, rewriter.getI32Type(), multiplier); // Compute: scale.shift = 30 + k Value k8 = - rewriter.create<arith::TruncIOp>(loc, rewriter.getI8Type(), k); - Value thirty8 = rewriter.create<arith::ConstantOp>( - loc, rewriter.getI8IntegerAttr(30)); - Value shift = rewriter.create<arith::AddIOp>(loc, k8, thirty8); + arith::TruncIOp::create(rewriter, loc, rewriter.getI8Type(), k); + Value thirty8 = arith::ConstantOp::create( + rewriter, loc, rewriter.getI8IntegerAttr(30)); + Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8); auto scaled = - rewriter - .create<tosa::ApplyScaleOp>( - loc, rewriter.getI32Type(), poolVal, multiplier, shift, - rewriter.getStringAttr("SINGLE_ROUND")) + tosa::ApplyScaleOp::create( + rewriter, loc, rewriter.getI32Type(), poolVal, multiplier, + shift, rewriter.getStringAttr("SINGLE_ROUND")) .getResult(); // If we have quantization information we need to apply output // zeropoint. if (outputZpVal != 0) { - auto outputZp = rewriter.create<arith::ConstantOp>( - loc, b.getIntegerAttr(scaled.getType(), outputZpVal)); - scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp) + auto outputZp = arith::ConstantOp::create( + rewriter, loc, + b.getIntegerAttr(scaled.getType(), outputZpVal)); + scaled = arith::AddIOp::create(rewriter, loc, scaled, outputZp) .getResult(); } // Apply Clip. int64_t outBitwidth = resultETy.getIntOrFloatBitWidth(); - auto min = rewriter.create<arith::ConstantIntOp>( - loc, accETy, + auto min = arith::ConstantIntOp::create( + rewriter, loc, accETy, APInt::getSignedMinValue(outBitwidth).getSExtValue()); - auto max = rewriter.create<arith::ConstantIntOp>( - loc, accETy, + auto max = arith::ConstantIntOp::create( + rewriter, loc, accETy, APInt::getSignedMaxValue(outBitwidth).getSExtValue()); auto clamp = clampIntHelper(loc, scaled, min, max, rewriter, /*isUnsigned=*/false); @@ -1078,11 +1072,11 @@ public: // Convert type. if (resultETy != clamp.getType()) { poolVal = - rewriter.create<arith::TruncIOp>(loc, resultETy, poolVal); + arith::TruncIOp::create(rewriter, loc, resultETy, poolVal); } } - rewriter.create<linalg::YieldOp>(loc, poolVal); + linalg::YieldOp::create(rewriter, loc, poolVal); }); rewriter.replaceOp(op, genericOp.getResult(0)); @@ -1107,8 +1101,9 @@ public: auto permutedSizes = applyTOSAPermutation<OpFoldResult>(inputSizes, constantPerms); - auto permutedInit = rewriter.create<tensor::EmptyOp>( - loc, permutedSizes, op.getInput1().getType().getElementType()); + auto permutedInit = + tensor::EmptyOp::create(rewriter, loc, permutedSizes, + op.getInput1().getType().getElementType()); rewriter.replaceOpWithNewOp<linalg::TransposeOp>( op, op.getInput1(), permutedInit, llvm::to_vector(llvm::map_range( diff --git a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp index 7dbccd1..f8efb34 100644 --- a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp +++ b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp @@ -13,7 +13,6 @@ #include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h" #include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" using namespace mlir; @@ -27,8 +26,8 @@ public: LogicalResult matchAndRewrite(tosa::VariableOp op, PatternRewriter &rewriter) const final { auto variableType = tosa::getVariableType(op); - auto newVariable = rewriter.create<mlir::ml_program::GlobalOp>( - op.getLoc(), op.getName(), variableType, /*is_mutable=*/true, + auto newVariable = mlir::ml_program::GlobalOp::create( + rewriter, op.getLoc(), op.getName(), variableType, /*is_mutable=*/true, op.getInitialValueAttr(), /*sym_visibility=*/nullptr); newVariable.setPrivate(); rewriter.replaceOp(op, newVariable); @@ -45,8 +44,8 @@ public: PatternRewriter &rewriter) const final { auto globalSymbolRef = SymbolRefAttr::get(rewriter.getContext(), op.getName()); - auto newVariableWrite = rewriter.create<ml_program::GlobalStoreOp>( - op.getLoc(), globalSymbolRef, op.getInput1()); + auto newVariableWrite = ml_program::GlobalStoreOp::create( + rewriter, op.getLoc(), globalSymbolRef, op.getInput1()); rewriter.replaceOp(op, newVariableWrite); return success(); } @@ -60,8 +59,8 @@ public: PatternRewriter &rewriter) const final { auto globalSymbolRef = SymbolRefAttr::get(rewriter.getContext(), op.getName()); - auto newVariableRead = rewriter.create<ml_program::GlobalLoadOp>( - op.getLoc(), op.getType(), globalSymbolRef); + auto newVariableRead = ml_program::GlobalLoadOp::create( + rewriter, op.getLoc(), op.getType(), globalSymbolRef); rewriter.replaceOp(op, newVariableRead); return success(); diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp index 03f9d20..aa6b416 100644 --- a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp +++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp @@ -30,7 +30,7 @@ static void inlineIfCase(Region &srcRegion, Region &dstRegion, auto yield = cast<YieldOp>(headBlock->getTerminator()); rewriter.setInsertionPoint(yield); - rewriter.create<scf::YieldOp>(yield.getLoc(), yield.getInputs()); + scf::YieldOp::create(rewriter, yield.getLoc(), yield.getInputs()); rewriter.eraseOp(yield); headBlock->eraseArguments(0, headBlock->getNumArguments()); @@ -46,13 +46,13 @@ static void inlineWhileCase(Region &srcRegion, Region &dstRegion, auto yield = cast<YieldOp>(headBlock->getTerminator()); rewriter.setInsertionPoint(yield); if (isCond) { - auto condition = - rewriter.create<tensor::ExtractOp>(yield.getLoc(), yield.getOperand(0)); - rewriter.create<scf::ConditionOp>(yield.getLoc(), condition, - headBlock->getArguments()); + auto condition = tensor::ExtractOp::create(rewriter, yield.getLoc(), + yield.getOperand(0)); + scf::ConditionOp::create(rewriter, yield.getLoc(), condition, + headBlock->getArguments()); } else { rewriter.setInsertionPoint(yield); - rewriter.create<scf::YieldOp>(yield.getLoc(), yield.getInputs()); + scf::YieldOp::create(rewriter, yield.getLoc(), yield.getInputs()); } rewriter.eraseOp(yield); } @@ -66,9 +66,9 @@ public: LogicalResult matchAndRewrite(tosa::IfOp op, PatternRewriter &rewriter) const final { auto condition = - rewriter.create<tensor::ExtractOp>(op.getLoc(), op.getCondition()); - auto newIf = rewriter.create<scf::IfOp>(op.getLoc(), op.getResultTypes(), - condition, true); + tensor::ExtractOp::create(rewriter, op.getLoc(), op.getCondition()); + auto newIf = scf::IfOp::create(rewriter, op.getLoc(), op.getResultTypes(), + condition, true); inlineIfCase(op.getThenGraph(), newIf.getThenRegion(), op.getInputList(), rewriter); @@ -88,7 +88,7 @@ class ScatterOpConverter : public OpRewritePattern<tosa::ScatterOp> { static Value createIndexConst(OpBuilder &builder, Location loc, int64_t value) { - return builder.create<arith::ConstantIndexOp>(loc, value); + return arith::ConstantIndexOp::create(builder, loc, value); } public: @@ -119,9 +119,9 @@ public: auto n = ivs[0]; // Read the index and cast it to index type - auto index = builder.create<tensor::ExtractOp>(loc, indices, ivs); - auto castIndex = builder.create<arith::IndexCastOp>( - loc, builder.getIndexType(), index); + auto index = tensor::ExtractOp::create(builder, loc, indices, ivs); + auto castIndex = arith::IndexCastOp::create( + builder, loc, builder.getIndexType(), index); // Offset, sizes, and strides for the input tensor auto inputOffset = llvm::to_vector(ivs); @@ -130,13 +130,13 @@ public: llvm::SmallVector<Value> sizes = {one, one, dimC}; llvm::SmallVector<Value> strides = {one, one, one}; - auto slice = builder.create<tensor::ExtractSliceOp>( - loc, input, inputOffset, sizes, strides); + auto slice = tensor::ExtractSliceOp::create(builder, loc, input, + inputOffset, sizes, strides); // Insert the slice into the output accumulator tensor. llvm::SmallVector<Value> outputOffset = {n, castIndex, zero}; - auto updated = builder.create<tensor::InsertSliceOp>( - loc, slice, args[0], outputOffset, sizes, strides); + auto updated = tensor::InsertSliceOp::create( + builder, loc, slice, args[0], outputOffset, sizes, strides); return {updated}; }; @@ -155,8 +155,8 @@ public: LogicalResult matchAndRewrite(tosa::WhileOp op, PatternRewriter &rewriter) const final { - auto newWhile = rewriter.create<scf::WhileOp>( - op.getLoc(), op.getResultTypes(), op.getInputList()); + auto newWhile = scf::WhileOp::create( + rewriter, op.getLoc(), op.getResultTypes(), op.getInputList()); rewriter.createBlock(&newWhile.getBefore()); rewriter.createBlock(&newWhile.getAfter()); diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp index c6cbcb0..2945ae3 100644 --- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp +++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp @@ -308,15 +308,15 @@ public: if (ShapedType::isStatic(sizes.back())) continue; - auto dim = rewriter.create<tensor::DimOp>(loc, input, index); - auto offset = rewriter.create<arith::ConstantOp>( - loc, rewriter.getIndexAttr(sliceStarts[index])); - dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset)); + auto dim = tensor::DimOp::create(rewriter, loc, input, index); + auto offset = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexAttr(sliceStarts[index])); + dynSizes.push_back(arith::SubIOp::create(rewriter, loc, dim, offset)); } - auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>( - sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes, - ValueRange({}), rewriter.getDenseI64ArrayAttr(sliceStarts), + auto newSliceOp = tensor::ExtractSliceOp::create( + rewriter, sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), + dynSizes, ValueRange({}), rewriter.getDenseI64ArrayAttr(sliceStarts), rewriter.getDenseI64ArrayAttr(sizes), rewriter.getDenseI64ArrayAttr(strides)); @@ -361,7 +361,7 @@ public: Value padConstant = rewriter.createOrFold<tensor::ExtractOp>( loc, padOp.getPadConst(), - ValueRange({rewriter.create<arith::ConstantIndexOp>(loc, 0)})); + ValueRange({arith::ConstantIndexOp::create(rewriter, loc, 0)})); if (!padConstant) { return rewriter.notifyMatchFailure( @@ -375,16 +375,16 @@ public: highValues.reserve(rank); for (int i = 0; i < rank; i++) { - Value lowVal = rewriter.create<arith::ConstantOp>( - loc, rewriter.getIndexAttr(paddingVals[2 * i])); - Value highVal = rewriter.create<arith::ConstantOp>( - loc, rewriter.getIndexAttr(paddingVals[2 * i + 1])); + Value lowVal = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexAttr(paddingVals[2 * i])); + Value highVal = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexAttr(paddingVals[2 * i + 1])); lowValues.push_back(lowVal); highValues.push_back(highVal); } - auto newPadOp = rewriter.create<tensor::PadOp>( - loc, padOp.getType(), input, lowValues, highValues, padConstant); + auto newPadOp = tensor::PadOp::create(rewriter, loc, padOp.getType(), input, + lowValues, highValues, padConstant); rewriter.replaceOp(padOp, newPadOp.getResult()); return success(); @@ -402,7 +402,7 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> { Location loc = op.getLoc(); int axis = op.getAxis(); Value axisValue = - rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(axis)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(axis)); int64_t rank = resultType.getRank(); SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1)); @@ -439,8 +439,9 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> { } } - Value result = rewriter.create<tensor::EmptyOp>( - loc, resultType.getShape(), resultType.getElementType(), dynDims); + Value result = + tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), + resultType.getElementType(), dynDims); for (auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) { auto sizes = tensor::getMixedSizes(rewriter, op.getLoc(), arg); diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp index d6f9495..9efa34a 100644 --- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp +++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp @@ -226,22 +226,22 @@ struct BroadcastOpToArmSMELowering (srcVectorType && (srcVectorType.getRank() == 0))) { // Broadcast scalar or 0-d vector to 1-d vector. VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); - broadcastOp1D = rewriter.create<vector::BroadcastOp>( - loc, tileSliceType, broadcastOp.getSource()); + broadcastOp1D = vector::BroadcastOp::create(rewriter, loc, tileSliceType, + broadcastOp.getSource()); } else if (srcVectorType && (srcVectorType.getRank() == 1)) // Value to broadcast is already a 1-d vector, nothing to do. broadcastOp1D = broadcastOp.getSource(); else return failure(); - auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); + auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType); auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex, Value currentTile) { // Create 'arm_sme.insert_tile_slice' to broadcast the value // to each tile slice. - auto nextTile = b.create<arm_sme::InsertTileSliceOp>( - loc, tileType, broadcastOp1D, currentTile, tileSliceIndex); + auto nextTile = arm_sme::InsertTileSliceOp::create( + b, loc, tileType, broadcastOp1D, currentTile, tileSliceIndex); return nextTile.getResult(); }; @@ -255,66 +255,6 @@ struct BroadcastOpToArmSMELowering } }; -/// Conversion pattern for vector.splat. -/// -/// Example: -/// -/// %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32> -/// -/// is converted to: -/// -/// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32> -/// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices -/// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) -/// { -/// %tile_update = arm_sme.insert_tile_slice -/// %broadcast_to_1d, %iter_tile[%tile_slice_index] : -/// vector<[4]xi32> into vector<[4]x[4]xi32> -/// scf.yield %tile_update : vector<[4]x[4]xi32> -/// } -/// -/// This is identical to vector.broadcast of a scalar. -struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> { - using OpRewritePattern<vector::SplatOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::SplatOp splatOp, - PatternRewriter &rewriter) const final { - auto tileType = splatOp.getResult().getType(); - if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) - return failure(); - - auto loc = splatOp.getLoc(); - auto srcType = splatOp.getOperand().getType(); - - assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat"); - // Avoid unused-variable warning when building without assertions. - (void)srcType; - - // First, broadcast the scalar to a 1-d vector. - VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); - Value broadcastOp1D = rewriter.create<vector::BroadcastOp>( - loc, tileSliceType, splatOp.getInput()); - - auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); - - auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex, - Value currentTile) { - auto nextTile = b.create<arm_sme::InsertTileSliceOp>( - loc, tileType, broadcastOp1D, currentTile, tileSliceIndex); - return nextTile.getResult(); - }; - - // Next, create a loop over ZA tile slices and "move" the generated 1-d - // vector to each slice. - auto forOp = - createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody); - - rewriter.replaceOp(splatOp, forOp.getResult(0)); - - return success(); - } -}; - /// Conversion pattern for vector.transpose. /// /// Stores the input tile to memory and reloads vertically. @@ -370,22 +310,22 @@ struct TransposeOpToArmSMELowering // Allocate buffer to store input tile to. Value vscale = - rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType()); - Value minTileSlices = rewriter.create<arith::ConstantOp>( - loc, rewriter.getIndexAttr(tileType.getDimSize(0))); + vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType()); + Value minTileSlices = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexAttr(tileType.getDimSize(0))); Value c0 = - rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0)); Value numTileSlices = - rewriter.create<arith::MulIOp>(loc, vscale, minTileSlices); + arith::MulIOp::create(rewriter, loc, vscale, minTileSlices); auto bufferType = MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic}, tileType.getElementType()); - auto buffer = rewriter.create<memref::AllocaOp>( - loc, bufferType, ValueRange{numTileSlices, numTileSlices}); + auto buffer = memref::AllocaOp::create( + rewriter, loc, bufferType, ValueRange{numTileSlices, numTileSlices}); // Store input tile. - auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>( - loc, input, buffer, ValueRange{c0, c0}); + auto tileStoreOp = arm_sme::TileStoreOp::create(rewriter, loc, input, + buffer, ValueRange{c0, c0}); // Reload input tile vertically. rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>( @@ -488,10 +428,10 @@ struct VectorOuterProductToArmSMELowering Value rhsMaskDim = createMaskOp.getOperand(1); VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0); - Value lhsMask = - rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, lhsMaskDim); - Value rhsMask = - rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, rhsMaskDim); + Value lhsMask = vector::CreateMaskOp::create(rewriter, loc, operandMaskType, + lhsMaskDim); + Value rhsMask = vector::CreateMaskOp::create(rewriter, loc, operandMaskType, + rhsMaskDim); return std::make_pair(lhsMask, rhsMask); } @@ -531,8 +471,8 @@ struct VectorExtractToArmSMELowering } Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front(); - auto extractTileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>( - loc, sourceVector, sliceIndex); + auto extractTileSlice = arm_sme::ExtractTileSliceOp::create( + rewriter, loc, sourceVector, sliceIndex); if (position.size() == 1) { // Single index case: Extracts a 1D slice. @@ -593,10 +533,10 @@ struct VectorInsertToArmSMELowering if (position.size() == 2) { // Two indices case: Insert single element into tile. // We need to first extract the existing slice and update the element. - tileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>( - loc, insertOp.getDest(), sliceIndex); - tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice, - position[1]); + tileSlice = arm_sme::ExtractTileSliceOp::create( + rewriter, loc, insertOp.getDest(), sliceIndex); + tileSlice = vector::InsertOp::create(rewriter, loc, source, tileSlice, + position[1]); } // Insert the slice into the destination tile. @@ -642,23 +582,24 @@ struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> { auto loc = printOp.getLoc(); // Create a loop over the rows of the tile. - auto vscale = rewriter.create<vector::VectorScaleOp>(loc); + auto vscale = vector::VectorScaleOp::create(rewriter, loc); auto minTileRows = - rewriter.create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0)); - auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); - auto upperBound = rewriter.create<arith::MulIOp>(loc, minTileRows, vscale); - auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); - auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step); + arith::ConstantIndexOp::create(rewriter, loc, vectorType.getDimSize(0)); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); + auto upperBound = arith::MulIOp::create(rewriter, loc, minTileRows, vscale); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); + auto forOp = + scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step); { // Loop body. rewriter.setInsertionPointToStart(forOp.getBody()); // Extract the current row from the tile. Value rowIndex = forOp.getInductionVar(); - auto tileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>( - loc, printOp.getSource(), rowIndex); + auto tileSlice = arm_sme::ExtractTileSliceOp::create( + rewriter, loc, printOp.getSource(), rowIndex); // Print the row with a 1D vector.print. - rewriter.create<vector::PrintOp>(loc, tileSlice, - printOp.getPunctuation()); + vector::PrintOp::create(rewriter, loc, tileSlice, + printOp.getPunctuation()); } rewriter.eraseOp(printOp); @@ -707,8 +648,8 @@ struct FoldTransferWriteOfExtractTileSlice Value mask = writeOp.getMask(); if (!mask) { auto maskType = writeOp.getVectorType().clone(rewriter.getI1Type()); - mask = rewriter.create<arith::ConstantOp>( - writeOp.getLoc(), maskType, DenseElementsAttr::get(maskType, true)); + mask = arith::ConstantOp::create(rewriter, writeOp.getLoc(), maskType, + DenseElementsAttr::get(maskType, true)); } rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>( @@ -776,10 +717,10 @@ struct ExtractFromCreateMaskToPselLowering // Create the two 1-D masks at the location of the 2-D create_mask (which is // usually outside a loop). This prevents the need for later hoisting. rewriter.setInsertionPoint(createMaskOp); - auto rowMask = rewriter.create<vector::CreateMaskOp>( - loc, rowMaskType, createMaskOp.getOperand(0)); - auto colMask = rewriter.create<vector::CreateMaskOp>( - loc, colMaskType, createMaskOp.getOperand(1)); + auto rowMask = vector::CreateMaskOp::create(rewriter, loc, rowMaskType, + createMaskOp.getOperand(0)); + auto colMask = vector::CreateMaskOp::create(rewriter, loc, colMaskType, + createMaskOp.getOperand(1)); rewriter.setInsertionPoint(extractOp); auto position = @@ -790,11 +731,25 @@ struct ExtractFromCreateMaskToPselLowering } }; +// Convert all `vector.splat` to `vector.broadcast`. There is a path from +// `vector.broadcast` to ArmSME via another pattern. +struct ConvertSplatToBroadcast : public OpRewritePattern<vector::SplatOp> { + using OpRewritePattern<vector::SplatOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::SplatOp splatOp, + PatternRewriter &rewriter) const final { + + rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(), + splatOp.getInput()); + return success(); + } +}; + } // namespace void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns, MLIRContext &ctx) { - patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering, + patterns.add<BroadcastOpToArmSMELowering, ConvertSplatToBroadcast, TransferReadToArmSMELowering, TransferWriteToArmSMELowering, TransposeOpToArmSMELowering, VectorLoadToArmSMELowering, VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering, diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp index 18adaa7..1d1904f 100644 --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -31,10 +31,9 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/DebugLog.h" #define DEBUG_TYPE "vector-to-gpu" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define DBGSNL() (llvm::dbgs() << "\n") namespace mlir { #define GEN_PASS_DEF_CONVERTVECTORTOGPU @@ -366,7 +365,7 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op, // by all operations. if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) { if (!supportsMMaMatrixType(op, useNvGpu)) { - LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n"); + LDBG() << "cannot convert op: " << *op; return true; } return false; @@ -412,22 +411,22 @@ struct PrepareContractToGPUMMA if (maps == infer({{m, k}, {k, n}, {m, n}})) return rewriter.notifyMatchFailure(op, "contraction already prepared"); if (maps == infer({{m, k}, {n, k}, {m, n}})) { - rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); + rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm); } else if (maps == infer({{k, m}, {k, n}, {m, n}})) { - lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); + lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm); } else if (maps == infer({{k, m}, {n, k}, {m, n}})) { - rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); - lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); + rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm); + lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm); } else if (maps == infer({{m, k}, {k, n}, {n, m}})) { std::swap(rhs, lhs); - rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); - lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); + rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm); + lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm); } else if (maps == infer({{m, k}, {n, k}, {n, m}})) { std::swap(rhs, lhs); - rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); + rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm); } else if (maps == infer({{k, m}, {k, n}, {n, m}})) { std::swap(lhs, rhs); - lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); + lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm); } else if (maps == infer({{k, m}, {n, k}, {n, m}})) { std::swap(lhs, rhs); } else { @@ -482,25 +481,23 @@ struct CombineTransferReadOpTranspose final permutationMap.compose(transferReadOp.getPermutationMap()); auto loc = op.getLoc(); - Value result = - rewriter - .create<vector::TransferReadOp>( - loc, resultType, transferReadOp.getBase(), - transferReadOp.getIndices(), AffineMapAttr::get(newMap), - transferReadOp.getPadding(), transferReadOp.getMask(), - transferReadOp.getInBoundsAttr()) - .getResult(); + Value result = vector::TransferReadOp::create( + rewriter, loc, resultType, transferReadOp.getBase(), + transferReadOp.getIndices(), AffineMapAttr::get(newMap), + transferReadOp.getPadding(), transferReadOp.getMask(), + transferReadOp.getInBoundsAttr()) + .getResult(); // Fuse through the integer extend op. if (extOp) { if (isa<arith::ExtSIOp>(extOp)) - result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result) + result = arith::ExtSIOp::create(rewriter, loc, op.getType(), result) .getResult(); else if (isa<arith::ExtUIOp>(extOp)) - result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result) + result = arith::ExtUIOp::create(rewriter, loc, op.getType(), result) .getResult(); else - result = rewriter.create<arith::ExtFOp>(loc, op.getType(), result) + result = arith::ExtFOp::create(rewriter, loc, op.getType(), result) .getResult(); } @@ -550,7 +547,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, std::optional<int64_t> stride = getStaticallyKnownRowStride(op.getShapedType()); if (!stride.has_value()) { - LLVM_DEBUG(DBGS() << "no stride\n"); + LDBG() << "no stride"; return rewriter.notifyMatchFailure(op, "no stride"); } @@ -579,13 +576,13 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, } gpu::MMAMatrixType type = gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType); - Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>( - op.getLoc(), type, op.getBase(), op.getIndices(), + Value load = gpu::SubgroupMmaLoadMatrixOp::create( + rewriter, op.getLoc(), type, op.getBase(), op.getIndices(), rewriter.getIndexAttr(*stride), isTranspose ? rewriter.getUnitAttr() : UnitAttr()); valueMapping[mappingResult] = load; - LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n"); + LDBG() << "transfer read to: " << load; return success(); } @@ -599,25 +596,25 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op, std::optional<int64_t> stride = getStaticallyKnownRowStride(op.getShapedType()); if (!stride.has_value()) { - LLVM_DEBUG(DBGS() << "no stride\n"); + LDBG() << "no stride"; return rewriter.notifyMatchFailure(op, "no stride"); } auto it = valueMapping.find(op.getVector()); if (it == valueMapping.end()) { - LLVM_DEBUG(DBGS() << "no mapping\n"); + LDBG() << "no mapping"; return rewriter.notifyMatchFailure(op, "no mapping"); } Value matrix = it->second; - auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>( - op.getLoc(), matrix, op.getBase(), op.getIndices(), + auto store = gpu::SubgroupMmaStoreMatrixOp::create( + rewriter, op.getLoc(), matrix, op.getBase(), op.getIndices(), rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr()); (void)store; - LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n"); + LDBG() << "transfer write to: " << store; - LLVM_DEBUG(DBGS() << "erase: " << op << "\n"); + LDBG() << "erase: " << op; rewriter.eraseOp(op); return success(); } @@ -643,26 +640,26 @@ convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op, FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = nvgpu::getWarpMatrixInfo(op); if (failed(warpMatrixInfo)) { - LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n"); + LDBG() << "no warpMatrixInfo"; return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); } FailureOr<nvgpu::FragmentElementInfo> regInfo = nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); if (failed(regInfo)) { - LLVM_DEBUG(DBGS() << "not mma sync reg info\n"); + LDBG() << "not mma sync reg info"; return rewriter.notifyMatchFailure(op, "not mma sync reg info"); } VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); auto dense = dyn_cast<SplatElementsAttr>(op.getValue()); if (!dense) { - LLVM_DEBUG(DBGS() << "not a splat\n"); + LDBG() << "not a splat"; return rewriter.notifyMatchFailure(op, "not a splat"); } - Value result = rewriter.create<arith::ConstantOp>( - op.getLoc(), vectorType, + Value result = arith::ConstantOp::create( + rewriter, op.getLoc(), vectorType, DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>())); valueMapping[op.getResult()] = result; return success(); @@ -679,8 +676,8 @@ static FailureOr<bool> isTransposed(vector::TransferReadOp op) { mlir::AffineMap map = op.getPermutationMap(); if (map.getNumResults() != 2) { - LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` " - "is not a 2d operand\n"); + LDBG() << "Failed because the result of `vector.transfer_read` " + "is not a 2d operand"; return failure(); } @@ -693,8 +690,8 @@ static FailureOr<bool> isTransposed(vector::TransferReadOp op) { auto exprN = dyn_cast<AffineDimExpr>(dN); if (!exprM || !exprN) { - LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim " - "expressions, then transpose cannot be determined.\n"); + LDBG() << "Failed because expressions are not affine dim " + "expressions, then transpose cannot be determined."; return failure(); } @@ -711,20 +708,20 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = nvgpu::getWarpMatrixInfo(op); if (failed(warpMatrixInfo)) { - LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n"); + LDBG() << "no warpMatrixInfo"; return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); } FailureOr<nvgpu::FragmentElementInfo> regInfo = nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); if (failed(regInfo)) { - LLVM_DEBUG(DBGS() << "not mma sync reg info\n"); + LDBG() << "not mma sync reg info"; return rewriter.notifyMatchFailure(op, "not mma sync reg info"); } FailureOr<bool> transpose = isTransposed(op); if (failed(transpose)) { - LLVM_DEBUG(DBGS() << "failed to determine the transpose\n"); + LDBG() << "failed to determine the transpose"; return rewriter.notifyMatchFailure( op, "Op should likely not be converted to a nvgpu.ldmatrix call."); } @@ -733,21 +730,19 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose); if (failed(params)) { - LLVM_DEBUG( - DBGS() - << "failed to convert vector.transfer_read to ldmatrix. " - << "Op should likely not be converted to a nvgpu.ldmatrix call.\n"); + LDBG() << "failed to convert vector.transfer_read to ldmatrix. " + << "Op should likely not be converted to a nvgpu.ldmatrix call."; return rewriter.notifyMatchFailure( op, "failed to convert vector.transfer_read to ldmatrix; this op " "likely should not be converted to a nvgpu.ldmatrix call."); } // Adjust the load offset. - auto laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr); + auto laneId = gpu::LaneIdOp::create(rewriter, loc, /*upperBound=*/nullptr); FailureOr<AffineMap> offsets = nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params); if (failed(offsets)) { - LLVM_DEBUG(DBGS() << "no offsets\n"); + LDBG() << "no offsets"; return rewriter.notifyMatchFailure(op, "no offsets"); } @@ -757,8 +752,9 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId}, indices); - nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>( - loc, vectorType, op.getBase(), indices, *transpose, params->numTiles); + nvgpu::LdMatrixOp newOp = + nvgpu::LdMatrixOp::create(rewriter, loc, vectorType, op.getBase(), + indices, *transpose, params->numTiles); valueMapping[op] = newOp->getResult(0); return success(); } @@ -782,17 +778,17 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, "conversion to distributed non-ldmatrix compatible load"); } - Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr); + Value laneId = gpu::LaneIdOp::create(rewriter, loc, /*upperBound=*/nullptr); // This is the individual element type. Type loadedElType = regInfo->registerLLVMType; VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); - Value fill = rewriter.create<arith::ConstantOp>( - op.getLoc(), vectorType.getElementType(), + Value fill = arith::ConstantOp::create( + rewriter, op.getLoc(), vectorType.getElementType(), rewriter.getZeroAttr(vectorType.getElementType())); Value result = - rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType); + vector::BroadcastOp::create(rewriter, op.getLoc(), vectorType, fill); bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity(); @@ -809,16 +805,16 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, if (failed(coords)) return rewriter.notifyMatchFailure(op, "no coords"); - Value logicalValueId = rewriter.create<arith::ConstantOp>( - loc, rewriter.getIndexType(), + Value logicalValueId = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexType(), rewriter.getIndexAttr(i * regInfo->elementsPerRegister)); SmallVector<Value, 4> newIndices; getXferIndices<vector::TransferReadOp>( rewriter, op, *coords, {laneId, logicalValueId}, newIndices); - Value el = rewriter.create<vector::LoadOp>(loc, loadedElType, - op.getBase(), newIndices); - result = rewriter.create<vector::InsertOp>(loc, el, result, i); + Value el = vector::LoadOp::create(rewriter, loc, loadedElType, + op.getBase(), newIndices); + result = vector::InsertOp::create(rewriter, loc, el, result, i); } } else { if (auto vecType = dyn_cast<VectorType>(loadedElType)) { @@ -828,8 +824,8 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1]; innerIdx++) { - Value logicalValueId = rewriter.create<arith::ConstantOp>( - loc, rewriter.getIndexType(), + Value logicalValueId = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexType(), rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx)); FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord( rewriter, op.getLoc(), *warpMatrixInfo); @@ -839,10 +835,10 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, SmallVector<Value, 4> newIndices; getXferIndices<vector::TransferReadOp>( rewriter, op, *coords, {laneId, logicalValueId}, newIndices); - Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType, - op.getBase(), newIndices); - result = rewriter.create<vector::InsertOp>( - op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx}); + Value el = memref::LoadOp::create(rewriter, op.getLoc(), loadedElType, + op.getBase(), newIndices); + result = vector::InsertOp::create(rewriter, op.getLoc(), el, result, + ArrayRef<int64_t>{i, innerIdx}); } } } @@ -916,11 +912,11 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op, return rewriter.notifyMatchFailure(op, "not mma sync reg info"); VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); - Value laneId = rewriter.create<gpu::LaneIdOp>(loc, /*upperBound=*/nullptr); + Value laneId = gpu::LaneIdOp::create(rewriter, loc, /*upperBound=*/nullptr); for (unsigned i = 0; i < vectorType.getShape()[0]; i++) { - Value logicalValueId = rewriter.create<arith::ConstantOp>( - loc, rewriter.getIndexType(), + Value logicalValueId = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexType(), rewriter.getIndexAttr(i * regInfo->elementsPerRegister)); FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord( rewriter, op.getLoc(), *warpMatrixInfo); @@ -928,14 +924,14 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op, return rewriter.notifyMatchFailure(op, "no coords"); Value el = - rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i}); + vector::ExtractOp::create(rewriter, loc, matrix, ArrayRef<int64_t>{i}); SmallVector<Value, 4> newIndices; getXferIndices<vector::TransferWriteOp>( rewriter, op, *coords, {laneId, logicalValueId}, newIndices); - rewriter.create<vector::StoreOp>(loc, el, op.getBase(), newIndices); + vector::StoreOp::create(rewriter, loc, el, op.getBase(), newIndices); } - LLVM_DEBUG(DBGS() << "erase: " << op << "\n"); + LDBG() << "erase: " << op; rewriter.eraseOp(op); return success(); } @@ -1015,8 +1011,8 @@ convertExtractStridedSlice(RewriterBase &rewriter, else if (offsets[1]) sliceOffset[0] = (warpVectorShape[1] / offsets[1]); - Value newOp = rewriter.create<vector::ExtractStridedSliceOp>( - loc, sourceVector, sliceOffset, sliceShape, strides); + Value newOp = vector::ExtractStridedSliceOp::create( + rewriter, loc, sourceVector, sliceOffset, sliceShape, strides); valueMapping[op] = newOp; return success(); @@ -1035,9 +1031,10 @@ convertContractOp(RewriterBase &rewriter, vector::ContractionOp op, itC == valueMapping.end()) return rewriter.notifyMatchFailure(op, "no mapping"); Value opA = itA->second, opB = itB->second, opC = itC->second; - Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>( - op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(), - /*b_transpose=*/UnitAttr()); + Value matmul = gpu::SubgroupMmaComputeOp::create(rewriter, op.getLoc(), + opC.getType(), opA, opB, opC, + /*a_transpose=*/UnitAttr(), + /*b_transpose=*/UnitAttr()); valueMapping[op.getResult()] = matmul; return success(); } @@ -1058,8 +1055,8 @@ convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op, int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0]; int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0]; int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1]; - Value matmul = rewriter.create<nvgpu::MmaSyncOp>( - op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k})); + Value matmul = nvgpu::MmaSyncOp::create(rewriter, op.getLoc(), opA, opB, opC, + rewriter.getI64ArrayAttr({m, n, k})); valueMapping[op.getResult()] = matmul; return success(); } @@ -1076,13 +1073,13 @@ convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op, auto splat = cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>(); auto scalarConstant = - rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat); + arith::ConstantOp::create(rewriter, op.getLoc(), splat.getType(), splat); const char *fragType = inferFragType(op); auto vecType = cast<VectorType>(op.getType()); gpu::MMAMatrixType type = gpu::MMAMatrixType::get( vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType)); - auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>( - op.getLoc(), type, scalarConstant); + auto matrix = gpu::SubgroupMmaConstantMatrixOp::create(rewriter, op.getLoc(), + type, scalarConstant); valueMapping[op.getResult()] = matrix; return success(); } @@ -1100,8 +1097,8 @@ convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op, auto vecType = op.getResultVectorType(); gpu::MMAMatrixType type = gpu::MMAMatrixType::get( vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType)); - auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>( - op.getLoc(), type, op.getSource()); + auto matrix = gpu::SubgroupMmaConstantMatrixOp::create(rewriter, op.getLoc(), + type, op.getSource()); valueMapping[op.getResult()] = matrix; return success(); } @@ -1118,9 +1115,9 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, rewriter.setInsertionPoint(loop); auto operands = llvm::to_vector<4>(loop.getInitArgs()); llvm::append_range(operands, newInitArgs); - scf::ForOp newLoop = rewriter.create<scf::ForOp>( - loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), - operands); + scf::ForOp newLoop = + scf::ForOp::create(rewriter, loop.getLoc(), loop.getLowerBound(), + loop.getUpperBound(), loop.getStep(), operands); rewriter.eraseBlock(newLoop.getBody()); newLoop.getRegion().getBlocks().splice( @@ -1132,9 +1129,9 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, loop.getNumResults()))) rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); - LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n"); - LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n"); - LLVM_DEBUG(DBGS() << "erase: " << loop); + LDBG() << "newLoop now: " << newLoop; + LDBG() << "stripped scf.for: " << loop; + LDBG() << "erase: " << loop; rewriter.eraseOp(loop); return newLoop; @@ -1150,7 +1147,7 @@ static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op, for (const auto &operand : llvm::enumerate(op.getInitArgs())) { auto it = valueMapping.find(operand.value()); if (it == valueMapping.end()) { - LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n"); + LDBG() << "no value mapping for: " << operand.value(); continue; } argMapping.push_back(std::make_pair( @@ -1168,7 +1165,7 @@ static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op, loopBody.getArgument(mapping.second + newForOp.getNumInductionVars()); } - LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n"); + LDBG() << "scf.for to: " << newForOp; return success(); } @@ -1189,9 +1186,9 @@ convertYieldOp(RewriterBase &rewriter, scf::YieldOp op, yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()]; yieldOperands.push_back(it->second); } - rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands); + scf::YieldOp::create(rewriter, op.getLoc(), yieldOperands); - LLVM_DEBUG(DBGS() << "erase: " << op << "\n"); + LDBG() << "erase: " << op; rewriter.eraseOp(op); return success(); } @@ -1220,8 +1217,8 @@ convertElementwiseOp(RewriterBase &rewriter, Operation *op, resultType.getOperand()); } - Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>( - op->getLoc(), resultType, matrixOperands, opType); + Value newOp = gpu::SubgroupMmaElementwiseOp::create( + rewriter, op->getLoc(), resultType, matrixOperands, opType); valueMapping[op->getResult(0)] = newOp; return success(); } @@ -1244,7 +1241,7 @@ LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter, auto globalRes = LogicalResult::success(); for (Operation *op : ops) { - LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n"); + LDBG() << "Process op: " << *op; // Apparently callers do not want to early exit on failure here. auto res = LogicalResult::success(); if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) { diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 501d988..17a79e3 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -29,7 +29,9 @@ #include "mlir/Target/LLVMIR/TypeToLLVM.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/APFloat.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/Support/Casting.h" + #include <optional> using namespace mlir; @@ -43,13 +45,13 @@ static Value insertOne(ConversionPatternRewriter &rewriter, assert(rank > 0 && "0-D vector corner case should have been handled already"); if (rank == 1) { auto idxType = rewriter.getIndexType(); - auto constant = rewriter.create<LLVM::ConstantOp>( - loc, typeConverter.convertType(idxType), + auto constant = LLVM::ConstantOp::create( + rewriter, loc, typeConverter.convertType(idxType), rewriter.getIntegerAttr(idxType, pos)); - return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, - constant); + return LLVM::InsertElementOp::create(rewriter, loc, llvmType, val1, val2, + constant); } - return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos); + return LLVM::InsertValueOp::create(rewriter, loc, val1, val2, pos); } // Helper that picks the proper sequence for extracting. @@ -58,13 +60,13 @@ static Value extractOne(ConversionPatternRewriter &rewriter, Value val, Type llvmType, int64_t rank, int64_t pos) { if (rank <= 1) { auto idxType = rewriter.getIndexType(); - auto constant = rewriter.create<LLVM::ConstantOp>( - loc, typeConverter.convertType(idxType), + auto constant = LLVM::ConstantOp::create( + rewriter, loc, typeConverter.convertType(idxType), rewriter.getIntegerAttr(idxType, pos)); - return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val, - constant); + return LLVM::ExtractElementOp::create(rewriter, loc, llvmType, val, + constant); } - return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos); + return LLVM::ExtractValueOp::create(rewriter, loc, val, pos); } // Helper that returns data layout alignment of a vector. @@ -141,9 +143,9 @@ static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, auto ptrsType = LLVM::getVectorType(pType, vectorType.getDimSize(0), /*isScalable=*/vectorType.getScalableDims()[0]); - return rewriter.create<LLVM::GEPOp>( - loc, ptrsType, typeConverter.convertType(memRefType.getElementType()), - base, index); + return LLVM::GEPOp::create( + rewriter, loc, ptrsType, + typeConverter.convertType(memRefType.getElementType()), base, index); } /// Convert `foldResult` into a Value. Integer attribute is converted to @@ -152,7 +154,7 @@ static Value getAsLLVMValue(OpBuilder &builder, Location loc, OpFoldResult foldResult) { if (auto attr = dyn_cast<Attribute>(foldResult)) { auto intAttr = cast<IntegerAttr>(attr); - return builder.create<LLVM::ConstantOp>(loc, intAttr).getResult(); + return LLVM::ConstantOp::create(builder, loc, intAttr).getResult(); } return cast<Value>(foldResult); @@ -184,41 +186,6 @@ public: } }; -/// Conversion pattern for a vector.matrix_multiply. -/// This is lowered directly to the proper llvm.intr.matrix.multiply. -class VectorMatmulOpConversion - : public ConvertOpToLLVMPattern<vector::MatmulOp> { -public: - using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>( - matmulOp, typeConverter->convertType(matmulOp.getRes().getType()), - adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(), - matmulOp.getLhsColumns(), matmulOp.getRhsColumns()); - return success(); - } -}; - -/// Conversion pattern for a vector.flat_transpose. -/// This is lowered directly to the proper llvm.intr.matrix.transpose. -class VectorFlatTransposeOpConversion - : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> { -public: - using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>( - transOp, typeConverter->convertType(transOp.getRes().getType()), - adaptor.getMatrix(), transOp.getRows(), transOp.getColumns()); - return success(); - } -}; - /// Overloaded utility that replaces a vector.load, vector.store, /// vector.maskedload and vector.maskedstore with their respective LLVM /// couterparts. @@ -475,32 +442,32 @@ class ReductionNeutralFPMax {}; static Value createReductionNeutralValue(ReductionNeutralZero neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { - return rewriter.create<LLVM::ConstantOp>(loc, llvmType, - rewriter.getZeroAttr(llvmType)); + return LLVM::ConstantOp::create(rewriter, loc, llvmType, + rewriter.getZeroAttr(llvmType)); } /// Create the reduction neutral integer one value. static Value createReductionNeutralValue(ReductionNeutralIntOne neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { - return rewriter.create<LLVM::ConstantOp>( - loc, llvmType, rewriter.getIntegerAttr(llvmType, 1)); + return LLVM::ConstantOp::create(rewriter, loc, llvmType, + rewriter.getIntegerAttr(llvmType, 1)); } /// Create the reduction neutral fp one value. static Value createReductionNeutralValue(ReductionNeutralFPOne neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { - return rewriter.create<LLVM::ConstantOp>( - loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0)); + return LLVM::ConstantOp::create(rewriter, loc, llvmType, + rewriter.getFloatAttr(llvmType, 1.0)); } /// Create the reduction neutral all-ones value. static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { - return rewriter.create<LLVM::ConstantOp>( - loc, llvmType, + return LLVM::ConstantOp::create( + rewriter, loc, llvmType, rewriter.getIntegerAttr( llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth()))); } @@ -509,8 +476,8 @@ static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral, static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { - return rewriter.create<LLVM::ConstantOp>( - loc, llvmType, + return LLVM::ConstantOp::create( + rewriter, loc, llvmType, rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue( llvmType.getIntOrFloatBitWidth()))); } @@ -519,8 +486,8 @@ static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral, static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { - return rewriter.create<LLVM::ConstantOp>( - loc, llvmType, + return LLVM::ConstantOp::create( + rewriter, loc, llvmType, rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue( llvmType.getIntOrFloatBitWidth()))); } @@ -529,8 +496,8 @@ static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral, static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { - return rewriter.create<LLVM::ConstantOp>( - loc, llvmType, + return LLVM::ConstantOp::create( + rewriter, loc, llvmType, rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue( llvmType.getIntOrFloatBitWidth()))); } @@ -539,8 +506,8 @@ static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral, static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { - return rewriter.create<LLVM::ConstantOp>( - loc, llvmType, + return LLVM::ConstantOp::create( + rewriter, loc, llvmType, rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue( llvmType.getIntOrFloatBitWidth()))); } @@ -550,8 +517,8 @@ static Value createReductionNeutralValue(ReductionNeutralFPMin neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { auto floatType = cast<FloatType>(llvmType); - return rewriter.create<LLVM::ConstantOp>( - loc, llvmType, + return LLVM::ConstantOp::create( + rewriter, loc, llvmType, rewriter.getFloatAttr( llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(), /*Negative=*/false))); @@ -562,8 +529,8 @@ static Value createReductionNeutralValue(ReductionNeutralFPMax neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { auto floatType = cast<FloatType>(llvmType); - return rewriter.create<LLVM::ConstantOp>( - loc, llvmType, + return LLVM::ConstantOp::create( + rewriter, loc, llvmType, rewriter.getFloatAttr( llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(), /*Negative=*/true))); @@ -591,19 +558,19 @@ static Value createVectorLengthValue(ConversionPatternRewriter &rewriter, auto vShape = vType.getShape(); assert(vShape.size() == 1 && "Unexpected multi-dim vector type"); - Value baseVecLength = rewriter.create<LLVM::ConstantOp>( - loc, rewriter.getI32Type(), + Value baseVecLength = LLVM::ConstantOp::create( + rewriter, loc, rewriter.getI32Type(), rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0])); if (!vType.getScalableDims()[0]) return baseVecLength; // For a scalable vector type, create and return `vScale * baseVecLength`. - Value vScale = rewriter.create<vector::VectorScaleOp>(loc); + Value vScale = vector::VectorScaleOp::create(rewriter, loc); vScale = - rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), vScale); + arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(), vScale); Value scalableVecLength = - rewriter.create<arith::MulIOp>(loc, baseVecLength, vScale); + arith::MulIOp::create(rewriter, loc, baseVecLength, vScale); return scalableVecLength; } @@ -616,10 +583,11 @@ static Value createIntegerReductionArithmeticOpLowering( ConversionPatternRewriter &rewriter, Location loc, Type llvmType, Value vectorOperand, Value accumulator) { - Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand); + Value result = + LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand); if (accumulator) - result = rewriter.create<ScalarOp>(loc, accumulator, result); + result = ScalarOp::create(rewriter, loc, accumulator, result); return result; } @@ -631,11 +599,12 @@ template <class LLVMRedIntrinOp> static Value createIntegerReductionComparisonOpLowering( ConversionPatternRewriter &rewriter, Location loc, Type llvmType, Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) { - Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand); + Value result = + LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand); if (accumulator) { Value cmp = - rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result); - result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result); + LLVM::ICmpOp::create(rewriter, loc, predicate, accumulator, result); + result = LLVM::SelectOp::create(rewriter, loc, cmp, accumulator, result); } return result; } @@ -666,12 +635,11 @@ static Value createFPReductionComparisonOpLowering( ConversionPatternRewriter &rewriter, Location loc, Type llvmType, Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) { Value result = - rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf); + LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand, fmf); if (accumulator) { - result = - rewriter.create<typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>( - loc, result, accumulator); + result = VectorToScalarMapper<LLVMRedIntrinOp>::Type::create( + rewriter, loc, result, accumulator); } return result; @@ -702,7 +670,7 @@ static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter, const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics(); auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics); auto denseValue = DenseElementsAttr::get(cast<ShapedType>(vectorType), value); - return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue); + return LLVM::ConstantOp::create(rewriter, loc, vectorType, denseValue); } /// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked @@ -717,8 +685,8 @@ lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter, Value mask, LLVM::FastmathFlagsAttr fmf) { const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>( rewriter, loc, llvmType, vectorOperand.getType()); - const Value selectedVectorByMask = rewriter.create<LLVM::SelectOp>( - loc, mask, vectorOperand, vectorMaskNeutral); + const Value selectedVectorByMask = LLVM::SelectOp::create( + rewriter, loc, mask, vectorOperand, vectorMaskNeutral); return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>( rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf); } @@ -730,9 +698,9 @@ lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc, Value accumulator, LLVM::FastmathFlagsAttr fmf) { accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc, llvmType, accumulator); - return rewriter.create<LLVMRedIntrinOp>(loc, llvmType, - /*startValue=*/accumulator, - vectorOperand, fmf); + return LLVMRedIntrinOp::create(rewriter, loc, llvmType, + /*startValue=*/accumulator, vectorOperand, + fmf); } /// Overloaded methods to lower a *predicated* reduction to an llvm intrinsic @@ -745,9 +713,8 @@ lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter, Value vectorOperand, Value accumulator) { accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc, llvmType, accumulator); - return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType, - /*startValue=*/accumulator, - vectorOperand); + return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType, + /*startValue=*/accumulator, vectorOperand); } template <class LLVMVPRedIntrinOp, class ReductionNeutral> @@ -758,9 +725,9 @@ static Value lowerPredicatedReductionWithStartValue( llvmType, accumulator); Value vectorLength = createVectorLengthValue(rewriter, loc, vectorOperand.getType()); - return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType, - /*startValue=*/accumulator, - vectorOperand, mask, vectorLength); + return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType, + /*startValue=*/accumulator, vectorOperand, + mask, vectorLength); } template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral, @@ -1071,8 +1038,8 @@ public: // For rank 0 and 1, where both operands have *exactly* the same vector // type, there is direct shuffle support in LLVM. Use it! if (rank <= 1 && v1Type == v2Type) { - Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>( - loc, adaptor.getV1(), adaptor.getV2(), + Value llvmShuffleOp = LLVM::ShuffleVectorOp::create( + rewriter, loc, adaptor.getV1(), adaptor.getV2(), llvm::to_vector_of<int32_t>(mask)); rewriter.replaceOp(shuffleOp, llvmShuffleOp); return success(); @@ -1085,7 +1052,7 @@ public: eltType = arrayType.getElementType(); else eltType = cast<VectorType>(llvmType).getElementType(); - Value insert = rewriter.create<LLVM::PoisonOp>(loc, llvmType); + Value insert = LLVM::PoisonOp::create(rewriter, loc, llvmType); int64_t insPos = 0; for (int64_t extPos : mask) { Value value = adaptor.getV1(); @@ -1103,39 +1070,6 @@ public: } }; -class VectorExtractElementOpConversion - : public ConvertOpToLLVMPattern<vector::ExtractElementOp> { -public: - using ConvertOpToLLVMPattern< - vector::ExtractElementOp>::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto vectorType = extractEltOp.getSourceVectorType(); - auto llvmType = typeConverter->convertType(vectorType.getElementType()); - - // Bail if result type cannot be lowered. - if (!llvmType) - return failure(); - - if (vectorType.getRank() == 0) { - Location loc = extractEltOp.getLoc(); - auto idxType = rewriter.getIndexType(); - auto zero = rewriter.create<LLVM::ConstantOp>( - loc, typeConverter->convertType(idxType), - rewriter.getIntegerAttr(idxType, 0)); - rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( - extractEltOp, llvmType, adaptor.getVector(), zero); - return success(); - } - - rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( - extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition()); - return success(); - } -}; - class VectorExtractOpConversion : public ConvertOpToLLVMPattern<vector::ExtractOp> { public: @@ -1193,13 +1127,14 @@ public: if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) { return failure(); } - extracted = rewriter.create<LLVM::ExtractValueOp>( - loc, extracted, getAsIntegers(position)); + extracted = LLVM::ExtractValueOp::create(rewriter, loc, extracted, + getAsIntegers(position)); } if (extractsScalar) { - extracted = rewriter.create<LLVM::ExtractElementOp>( - loc, extracted, getAsLLVMValue(rewriter, loc, positionVec.back())); + extracted = LLVM::ExtractElementOp::create( + rewriter, loc, extracted, + getAsLLVMValue(rewriter, loc, positionVec.back())); } rewriter.replaceOp(extractOp, extracted); @@ -1238,39 +1173,6 @@ public: } }; -class VectorInsertElementOpConversion - : public ConvertOpToLLVMPattern<vector::InsertElementOp> { -public: - using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto vectorType = insertEltOp.getDestVectorType(); - auto llvmType = typeConverter->convertType(vectorType); - - // Bail if result type cannot be lowered. - if (!llvmType) - return failure(); - - if (vectorType.getRank() == 0) { - Location loc = insertEltOp.getLoc(); - auto idxType = rewriter.getIndexType(); - auto zero = rewriter.create<LLVM::ConstantOp>( - loc, typeConverter->convertType(idxType), - rewriter.getIntegerAttr(idxType, 0)); - rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( - insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero); - return success(); - } - - rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( - insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), - adaptor.getPosition()); - return success(); - } -}; - class VectorInsertOpConversion : public ConvertOpToLLVMPattern<vector::InsertOp> { public: @@ -1342,8 +1244,8 @@ public: // llvm.extractvalue does not support dynamic dimensions. return failure(); } - sourceAggregate = rewriter.create<LLVM::ExtractValueOp>( - loc, adaptor.getDest(), + sourceAggregate = LLVM::ExtractValueOp::create( + rewriter, loc, adaptor.getDest(), getAsIntegers(positionOf1DVectorWithinAggregate)); } else { // No-aggregate case. The destination for the InsertElementOp is just @@ -1351,16 +1253,16 @@ public: sourceAggregate = adaptor.getDest(); } // Insert the scalar into the 1D vector. - sourceAggregate = rewriter.create<LLVM::InsertElementOp>( - loc, sourceAggregate.getType(), sourceAggregate, + sourceAggregate = LLVM::InsertElementOp::create( + rewriter, loc, sourceAggregate.getType(), sourceAggregate, adaptor.getValueToStore(), getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector)); } Value result = sourceAggregate; if (isNestedAggregate) { - result = rewriter.create<LLVM::InsertValueOp>( - loc, adaptor.getDest(), sourceAggregate, + result = LLVM::InsertValueOp::create( + rewriter, loc, adaptor.getDest(), sourceAggregate, getAsIntegers(positionOf1DVectorWithinAggregate)); } @@ -1408,7 +1310,7 @@ struct VectorScalableExtractOpLowering /// ``` /// is rewritten into: /// ``` -/// %r = splat %f0: vector<2x4xf32> +/// %r = vector.broadcast %f0 : f32 to vector<2x4xf32> /// %va = vector.extractvalue %a[0] : vector<2x4xf32> /// %vb = vector.extractvalue %b[0] : vector<2x4xf32> /// %vc = vector.extractvalue %c[0] : vector<2x4xf32> @@ -1439,15 +1341,15 @@ public: auto loc = op.getLoc(); auto elemType = vType.getElementType(); - Value zero = rewriter.create<arith::ConstantOp>( - loc, elemType, rewriter.getZeroAttr(elemType)); - Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero); + Value zero = arith::ConstantOp::create(rewriter, loc, elemType, + rewriter.getZeroAttr(elemType)); + Value desc = vector::BroadcastOp::create(rewriter, loc, vType, zero); for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { - Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i); - Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i); - Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i); - Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC); - desc = rewriter.create<InsertOp>(loc, fma, desc, i); + Value extrLHS = ExtractOp::create(rewriter, loc, op.getLhs(), i); + Value extrRHS = ExtractOp::create(rewriter, loc, op.getRhs(), i); + Value extrACC = ExtractOp::create(rewriter, loc, op.getAcc(), i); + Value fma = FMAOp::create(rewriter, loc, extrLHS, extrRHS, extrACC); + desc = InsertOp::create(rewriter, loc, fma, desc, i); } rewriter.replaceOp(op, desc); return success(); @@ -1537,7 +1439,7 @@ public: desc.setAlignedPtr(rewriter, loc, ptr); // Fill offset 0. auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); - auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); + auto zero = LLVM::ConstantOp::create(rewriter, loc, int64Ty, attr); desc.setOffset(rewriter, loc, zero); // Fill size and stride descriptors in memref. @@ -1546,11 +1448,12 @@ public: int64_t index = indexedSize.index(); auto sizeAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); - auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); + auto size = LLVM::ConstantOp::create(rewriter, loc, int64Ty, sizeAttr); desc.setSize(rewriter, loc, index, size); auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(), (*targetStrides)[index]); - auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); + auto stride = + LLVM::ConstantOp::create(rewriter, loc, int64Ty, strideAttr); desc.setStride(rewriter, loc, index, stride); } @@ -1578,14 +1481,15 @@ public: IntegerType idxType = force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type(); auto loc = op->getLoc(); - Value indices = rewriter.create<LLVM::StepVectorOp>( - loc, LLVM::getVectorType(idxType, dstType.getShape()[0], - /*isScalable=*/true)); + Value indices = LLVM::StepVectorOp::create( + rewriter, loc, + LLVM::getVectorType(idxType, dstType.getShape()[0], + /*isScalable=*/true)); auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, adaptor.getOperands()[0]); - Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound); - Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, - indices, bounds); + Value bounds = BroadcastOp::create(rewriter, loc, indices.getType(), bound); + Value comp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, + indices, bounds); rewriter.replaceOp(op, comp); return success(); } @@ -1741,16 +1645,16 @@ private: switch (conversion) { case PrintConversion::ZeroExt64: - value = rewriter.create<arith::ExtUIOp>( - loc, IntegerType::get(rewriter.getContext(), 64), value); + value = arith::ExtUIOp::create( + rewriter, loc, IntegerType::get(rewriter.getContext(), 64), value); break; case PrintConversion::SignExt64: - value = rewriter.create<arith::ExtSIOp>( - loc, IntegerType::get(rewriter.getContext(), 64), value); + value = arith::ExtSIOp::create( + rewriter, loc, IntegerType::get(rewriter.getContext(), 64), value); break; case PrintConversion::Bitcast16: - value = rewriter.create<LLVM::BitcastOp>( - loc, IntegerType::get(rewriter.getContext(), 16), value); + value = LLVM::BitcastOp::create( + rewriter, loc, IntegerType::get(rewriter.getContext(), 16), value); break; case PrintConversion::None: break; @@ -1762,68 +1666,83 @@ private: // Helper to emit a call. static void emitCall(ConversionPatternRewriter &rewriter, Location loc, Operation *ref, ValueRange params = ValueRange()) { - rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref), - params); + LLVM::CallOp::create(rewriter, loc, TypeRange(), SymbolRefAttr::get(ref), + params); } }; -/// The Splat operation is lowered to an insertelement + a shufflevector -/// operation. Splat to only 0-d and 1-d vector result types are lowered. -struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> { - using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern; +/// A broadcast of a scalar is lowered to an insertelement + a shufflevector +/// operation. Only broadcasts to 0-d and 1-d vectors are lowered by this +/// pattern, the higher rank cases are handled by another pattern. +struct VectorBroadcastScalarToLowRankLowering + : public ConvertOpToLLVMPattern<vector::BroadcastOp> { + using ConvertOpToLLVMPattern<vector::BroadcastOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor, + matchAndRewrite(vector::BroadcastOp broadcast, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - VectorType resultType = cast<VectorType>(splatOp.getType()); + if (isa<VectorType>(broadcast.getSourceType())) + return rewriter.notifyMatchFailure( + broadcast, "broadcast from vector type not handled"); + + VectorType resultType = broadcast.getType(); if (resultType.getRank() > 1) - return failure(); + return rewriter.notifyMatchFailure(broadcast, + "broadcast to 2+-d handled elsewhere"); // First insert it into a poison vector so we can shuffle it. - auto vectorType = typeConverter->convertType(splatOp.getType()); + auto vectorType = typeConverter->convertType(broadcast.getType()); Value poison = - rewriter.create<LLVM::PoisonOp>(splatOp.getLoc(), vectorType); - auto zero = rewriter.create<LLVM::ConstantOp>( - splatOp.getLoc(), + LLVM::PoisonOp::create(rewriter, broadcast.getLoc(), vectorType); + auto zero = LLVM::ConstantOp::create( + rewriter, broadcast.getLoc(), typeConverter->convertType(rewriter.getIntegerType(32)), rewriter.getZeroAttr(rewriter.getIntegerType(32))); // For 0-d vector, we simply do `insertelement`. if (resultType.getRank() == 0) { rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( - splatOp, vectorType, poison, adaptor.getInput(), zero); + broadcast, vectorType, poison, adaptor.getSource(), zero); return success(); } // For 1-d vector, we additionally do a `vectorshuffle`. - auto v = rewriter.create<LLVM::InsertElementOp>( - splatOp.getLoc(), vectorType, poison, adaptor.getInput(), zero); + auto v = + LLVM::InsertElementOp::create(rewriter, broadcast.getLoc(), vectorType, + poison, adaptor.getSource(), zero); - int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0); + int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0); SmallVector<int32_t> zeroValues(width, 0); // Shuffle the value across the desired number of elements. - rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, poison, + rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(broadcast, v, poison, zeroValues); return success(); } }; -/// The Splat operation is lowered to an insertelement + a shufflevector -/// operation. Splat to only 2+-d vector result types are lowered by the -/// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering. -struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> { - using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern; +/// The broadcast of a scalar is lowered to an insertelement + a shufflevector +/// operation. Only broadcasts to 2+-d vector result types are lowered by this +/// pattern, the 1-d case is handled by another pattern. Broadcasts from vectors +/// are not converted to LLVM, only broadcasts from scalars are. +struct VectorBroadcastScalarToNdLowering + : public ConvertOpToLLVMPattern<BroadcastOp> { + using ConvertOpToLLVMPattern<BroadcastOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor, + matchAndRewrite(BroadcastOp broadcast, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - VectorType resultType = splatOp.getType(); + if (isa<VectorType>(broadcast.getSourceType())) + return rewriter.notifyMatchFailure( + broadcast, "broadcast from vector type not handled"); + + VectorType resultType = broadcast.getType(); if (resultType.getRank() <= 1) - return failure(); + return rewriter.notifyMatchFailure( + broadcast, "broadcast to 1-d or 0-d handled elsewhere"); // First insert it into an undef vector so we can shuffle it. - auto loc = splatOp.getLoc(); + auto loc = broadcast.getLoc(); auto vectorTypeInfo = LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter()); auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy; @@ -1832,28 +1751,28 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> { return failure(); // Construct returned value. - Value desc = rewriter.create<LLVM::PoisonOp>(loc, llvmNDVectorTy); + Value desc = LLVM::PoisonOp::create(rewriter, loc, llvmNDVectorTy); - // Construct a 1-D vector with the splatted value that we insert in all the - // places within the returned descriptor. - Value vdesc = rewriter.create<LLVM::PoisonOp>(loc, llvm1DVectorTy); - auto zero = rewriter.create<LLVM::ConstantOp>( - loc, typeConverter->convertType(rewriter.getIntegerType(32)), + // Construct a 1-D vector with the broadcasted value that we insert in all + // the places within the returned descriptor. + Value vdesc = LLVM::PoisonOp::create(rewriter, loc, llvm1DVectorTy); + auto zero = LLVM::ConstantOp::create( + rewriter, loc, typeConverter->convertType(rewriter.getIntegerType(32)), rewriter.getZeroAttr(rewriter.getIntegerType(32))); - Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc, - adaptor.getInput(), zero); + Value v = LLVM::InsertElementOp::create(rewriter, loc, llvm1DVectorTy, + vdesc, adaptor.getSource(), zero); // Shuffle the value across the desired number of elements. int64_t width = resultType.getDimSize(resultType.getRank() - 1); SmallVector<int32_t> zeroValues(width, 0); - v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues); + v = LLVM::ShuffleVectorOp::create(rewriter, loc, v, v, zeroValues); - // Iterate of linear index, convert to coords space and insert splatted 1-D - // vector in each position. + // Iterate of linear index, convert to coords space and insert broadcasted + // 1-D vector in each position. nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) { - desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position); + desc = LLVM::InsertValueOp::create(rewriter, loc, desc, v, position); }); - rewriter.replaceOp(splatOp, desc); + rewriter.replaceOp(broadcast, desc); return success(); } }; @@ -1921,13 +1840,13 @@ struct VectorDeinterleaveOpLowering auto deinterleaveResults = deinterleaveOp.getResultTypes(); auto packedOpResults = llvmTypeConverter->packOperationResults(deinterleaveResults); - auto intrinsic = rewriter.create<LLVM::vector_deinterleave2>( - loc, packedOpResults, adaptor.getSource()); + auto intrinsic = LLVM::vector_deinterleave2::create( + rewriter, loc, packedOpResults, adaptor.getSource()); - auto evenResult = rewriter.create<LLVM::ExtractValueOp>( - loc, intrinsic->getResult(0), 0); - auto oddResult = rewriter.create<LLVM::ExtractValueOp>( - loc, intrinsic->getResult(0), 1); + auto evenResult = LLVM::ExtractValueOp::create( + rewriter, loc, intrinsic->getResult(0), 0); + auto oddResult = LLVM::ExtractValueOp::create(rewriter, loc, + intrinsic->getResult(0), 1); rewriter.replaceOp(deinterleaveOp, ValueRange{evenResult, oddResult}); return success(); @@ -1950,11 +1869,11 @@ struct VectorDeinterleaveOpLowering oddShuffleMask.push_back(i); } - auto poison = rewriter.create<LLVM::PoisonOp>(loc, sourceType); - auto evenShuffle = rewriter.create<LLVM::ShuffleVectorOp>( - loc, adaptor.getSource(), poison, evenShuffleMask); - auto oddShuffle = rewriter.create<LLVM::ShuffleVectorOp>( - loc, adaptor.getSource(), poison, oddShuffleMask); + auto poison = LLVM::PoisonOp::create(rewriter, loc, sourceType); + auto evenShuffle = LLVM::ShuffleVectorOp::create( + rewriter, loc, adaptor.getSource(), poison, evenShuffleMask); + auto oddShuffle = LLVM::ShuffleVectorOp::create( + rewriter, loc, adaptor.getSource(), poison, oddShuffleMask); rewriter.replaceOp(deinterleaveOp, ValueRange{evenShuffle, oddShuffle}); return success(); @@ -1977,9 +1896,9 @@ struct VectorFromElementsLowering return rewriter.notifyMatchFailure(fromElementsOp, "rank > 1 vectors are not supported"); Type llvmType = typeConverter->convertType(vectorType); - Value result = rewriter.create<LLVM::PoisonOp>(loc, llvmType); + Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType); for (auto [idx, val] : llvm::enumerate(adaptor.getElements())) - result = rewriter.create<vector::InsertOp>(loc, val, result, idx); + result = vector::InsertOp::create(rewriter, loc, val, result, idx); rewriter.replaceOp(fromElementsOp, result); return success(); } @@ -2003,12 +1922,12 @@ struct VectorToElementsLowering if (element.use_empty()) continue; - auto constIdx = rewriter.create<LLVM::ConstantOp>( - loc, idxType, rewriter.getIntegerAttr(idxType, idx)); + auto constIdx = LLVM::ConstantOp::create( + rewriter, loc, idxType, rewriter.getIntegerAttr(idxType, idx)); auto llvmType = typeConverter->convertType(element.getType()); - Value result = rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, - source, constIdx); + Value result = LLVM::ExtractElementOp::create(rewriter, loc, llvmType, + source, constIdx); results[idx] = result; } @@ -2035,6 +1954,196 @@ struct VectorScalableStepOpLowering } }; +/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul +/// semantics to: +/// ``` +/// %flattened_a = vector.shape_cast %a +/// %flattened_b = vector.shape_cast %b +/// %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b +/// %d = vector.shape_cast %%flattened_d +/// %e = add %c, %d +/// ``` +/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`. +// +/// This only kicks in when vectorContractLowering is set to Matmul and +/// the vector.contract op is a row-major matrix multiply. +class ContractionOpToMatmulOpLowering + : public vector::MaskableOpRewritePattern<vector::ContractionOp> { +public: + using MaskableOpRewritePattern::MaskableOpRewritePattern; + + ContractionOpToMatmulOpLowering( + vector::VectorContractLowering vectorContractLowering, + MLIRContext *context, PatternBenefit benefit = 100) + : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit) {} + + FailureOr<Value> + matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp, + PatternRewriter &rewriter) const override; +}; + +/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul +/// semantics to: +/// ``` +/// %mta = maybe_transpose +/// %mtb = maybe_transpose +/// %flattened_a = vector.shape_cast %mta +/// %flattened_b = vector.shape_cast %mtb +/// %flattened_d = llvm.intr.matrix.multiply %flattened_a, %flattened_b +/// %mtd = vector.shape_cast %flattened_d +/// %d = maybe_untranspose %mtd +/// %e = add %c, %d +/// ``` +// +/// This only kicks in when vectorContractLowering is set to `Matmul`. +/// vector.transpose operations are inserted if the vector.contract op is not a +/// row-major matrix multiply. +/// +/// Scalable vectors are not supported. +FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp( + vector::ContractionOp op, MaskingOpInterface maskOp, + PatternRewriter &rew) const { + // TODO: Support vector.mask. + if (maskOp) + return failure(); + + auto iteratorTypes = op.getIteratorTypes().getValue(); + if (!isParallelIterator(iteratorTypes[0]) || + !isParallelIterator(iteratorTypes[1]) || + !isReductionIterator(iteratorTypes[2])) + return failure(); + + Type opResType = op.getType(); + VectorType vecType = dyn_cast<VectorType>(opResType); + if (vecType && vecType.isScalable()) { + // Note - this is sufficient to reject all cases with scalable vectors. + return failure(); + } + + Type elementType = op.getLhsType().getElementType(); + if (!elementType.isIntOrFloat()) + return failure(); + + Type dstElementType = vecType ? vecType.getElementType() : opResType; + if (elementType != dstElementType) + return failure(); + + // Perform lhs + rhs transpositions to conform to matmul row-major semantics. + // Bail out if the contraction cannot be put in this form. + MLIRContext *ctx = op.getContext(); + Location loc = op.getLoc(); + AffineExpr m, n, k; + bindDims(rew.getContext(), m, n, k); + // LHS must be A(m, k) or A(k, m). + Value lhs = op.getLhs(); + auto lhsMap = op.getIndexingMapsArray()[0]; + if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx)) + lhs = vector::TransposeOp::create(rew, loc, lhs, ArrayRef<int64_t>{1, 0}); + else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx)) + return failure(); + + // RHS must be B(k, n) or B(n, k). + Value rhs = op.getRhs(); + auto rhsMap = op.getIndexingMapsArray()[1]; + if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx)) + rhs = vector::TransposeOp::create(rew, loc, rhs, ArrayRef<int64_t>{1, 0}); + else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx)) + return failure(); + + // At this point lhs and rhs are in row-major. + VectorType lhsType = cast<VectorType>(lhs.getType()); + VectorType rhsType = cast<VectorType>(rhs.getType()); + int64_t lhsRows = lhsType.getDimSize(0); + int64_t lhsColumns = lhsType.getDimSize(1); + int64_t rhsColumns = rhsType.getDimSize(1); + + Type flattenedLHSType = + VectorType::get(lhsType.getNumElements(), lhsType.getElementType()); + lhs = vector::ShapeCastOp::create(rew, loc, flattenedLHSType, lhs); + + Type flattenedRHSType = + VectorType::get(rhsType.getNumElements(), rhsType.getElementType()); + rhs = vector::ShapeCastOp::create(rew, loc, flattenedRHSType, rhs); + + Value mul = LLVM::MatrixMultiplyOp::create( + rew, loc, + VectorType::get(lhsRows * rhsColumns, + cast<VectorType>(lhs.getType()).getElementType()), + lhs, rhs, lhsRows, lhsColumns, rhsColumns); + + mul = vector::ShapeCastOp::create( + rew, loc, + VectorType::get({lhsRows, rhsColumns}, + getElementTypeOrSelf(op.getAcc().getType())), + mul); + + // ACC must be C(m, n) or C(n, m). + auto accMap = op.getIndexingMapsArray()[2]; + if (accMap == AffineMap::get(3, 0, {n, m}, ctx)) + mul = vector::TransposeOp::create(rew, loc, mul, ArrayRef<int64_t>{1, 0}); + else if (accMap != AffineMap::get(3, 0, {m, n}, ctx)) + llvm_unreachable("invalid contraction semantics"); + + Value res = isa<IntegerType>(elementType) + ? static_cast<Value>( + arith::AddIOp::create(rew, loc, op.getAcc(), mul)) + : static_cast<Value>( + arith::AddFOp::create(rew, loc, op.getAcc(), mul)); + + return res; +} + +/// Lowers vector.transpose to llvm.intr.matrix.transpose +class TransposeOpToMatrixTransposeOpLowering + : public OpRewritePattern<vector::TransposeOp> { +public: + using OpRewritePattern<TransposeOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransposeOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value input = op.getVector(); + VectorType inputType = op.getSourceVectorType(); + VectorType resType = op.getResultVectorType(); + + if (inputType.isScalable()) + return rewriter.notifyMatchFailure( + op, "This lowering does not support scalable vectors"); + + // Set up convenience transposition table. + ArrayRef<int64_t> transp = op.getPermutation(); + + if (resType.getRank() != 2 || transp[0] != 1 || transp[1] != 0) { + return failure(); + } + + Type flattenedType = + VectorType::get(resType.getNumElements(), resType.getElementType()); + auto matrix = + vector::ShapeCastOp::create(rewriter, loc, flattenedType, input); + auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]); + auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]); + Value trans = LLVM::MatrixTransposeOp::create(rewriter, loc, flattenedType, + matrix, rows, columns); + rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans); + return success(); + } +}; + +/// Convert `vector.splat` to `vector.broadcast`. There is a path to LLVM from +/// `vector.broadcast` through other patterns. +struct VectorSplatToBroadcast : public ConvertOpToLLVMPattern<vector::SplatOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(), + adaptor.getInput()); + return success(); + } +}; + } // namespace void mlir::vector::populateVectorRankReducingFMAPattern( @@ -2042,6 +2151,17 @@ void mlir::vector::populateVectorRankReducingFMAPattern( patterns.add<VectorFMAOpNDRewritePattern>(patterns.getContext()); } +void mlir::vector::populateVectorContractToMatrixMultiply( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add<ContractionOpToMatmulOpLowering>(patterns.getContext(), benefit); +} + +void mlir::vector::populateVectorTransposeToFlatTranspose( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add<TransposeOpToMatrixTransposeOpLowering>(patterns.getContext(), + benefit); +} + /// Populate the given list with patterns that convert from Vector to LLVM. void mlir::populateVectorToLLVMConversionPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns, @@ -2058,12 +2178,12 @@ void mlir::populateVectorToLLVMConversionPatterns( VectorGatherOpConversion, VectorScatterOpConversion>( converter, useVectorAlignment); patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion, - VectorExtractElementOpConversion, VectorExtractOpConversion, - VectorFMAOp1DConversion, VectorInsertElementOpConversion, + VectorExtractOpConversion, VectorFMAOp1DConversion, VectorInsertOpConversion, VectorPrintOpConversion, VectorTypeCastOpConversion, VectorScaleOpConversion, VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, - VectorSplatOpLowering, VectorSplatNdOpLowering, + VectorSplatToBroadcast, VectorBroadcastScalarToLowRankLowering, + VectorBroadcastScalarToNdLowering, VectorScalableInsertOpLowering, VectorScalableExtractOpLowering, MaskedReductionOpConversion, VectorInterleaveOpLowering, VectorDeinterleaveOpLowering, VectorFromElementsLowering, @@ -2071,12 +2191,6 @@ void mlir::populateVectorToLLVMConversionPatterns( converter); } -void mlir::populateVectorToLLVMMatrixConversionPatterns( - const LLVMTypeConverter &converter, RewritePatternSet &patterns) { - patterns.add<VectorMatmulOpConversion>(converter); - patterns.add<VectorFlatTransposeOpConversion>(converter); -} - namespace { struct VectorToLLVMDialectInterface : public ConvertToLLVMPatternInterface { using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 549d021..cf10869 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -70,10 +70,22 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorBitCastLoweringPatterns(patterns); populateVectorBroadcastLoweringPatterns(patterns); populateVectorContractLoweringPatterns(patterns, vectorContractLowering); + if (vectorContractLowering == vector::VectorContractLowering::Matmul) { + // This pattern creates a dependency on the LLVM dialect, hence we don't + // include it in `populateVectorContractLoweringPatterns` that is part of + // the Vector dialect (and should not depend on LLVM). + populateVectorContractToMatrixMultiply(patterns); + } populateVectorMaskOpLoweringPatterns(patterns); populateVectorShapeCastLoweringPatterns(patterns); populateVectorInterleaveLoweringPatterns(patterns); populateVectorTransposeLoweringPatterns(patterns, vectorTransposeLowering); + if (vectorTransposeLowering == vector::VectorTransposeLowering::Flat) { + // This pattern creates a dependency on the LLVM dialect, hence we don't + // include it in `populateVectorTransposeLoweringPatterns` that is part of + // the Vector dialect (and should not depend on LLVM). + populateVectorTransposeToFlatTranspose(patterns); + } // Vector transfer ops with rank > 1 should be lowered with VectorToSCF. populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); populateVectorMaskMaterializationPatterns(patterns, @@ -84,9 +96,15 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorGatherLoweringPatterns(patterns); if (armI8MM) { if (armNeon) - arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns); + arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns); + if (armSVE) + populateLowerContractionToSVEI8MMPatterns(patterns); + } + if (armBF16) { + if (armNeon) + arm_neon::populateLowerContractionToNeonBFMMLAPatterns(patterns); if (armSVE) - populateLowerContractionToSVEI8MMPatternPatterns(patterns); + populateLowerContractionToSVEBFMMLAPatterns(patterns); } (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } @@ -96,11 +114,9 @@ void ConvertVectorToLLVMPass::runOnOperation() { LLVMTypeConverter converter(&getContext(), options); RewritePatternSet patterns(&getContext()); populateVectorTransferLoweringPatterns(patterns); - populateVectorToLLVMMatrixConversionPatterns(converter, patterns); populateVectorToLLVMConversionPatterns( converter, patterns, reassociateFPReductions, force32BitVectorIndices, useVectorAlignment); - populateVectorToLLVMMatrixConversionPatterns(converter, patterns); // Architecture specific augmentations. LLVMConversionTarget target(getContext()); diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index ed51b21..508f4e2 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -24,7 +24,6 @@ #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -132,9 +131,9 @@ static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal, Value value) { if (hasRetVal) { assert(value && "Expected non-empty value"); - b.create<scf::YieldOp>(loc, value); + scf::YieldOp::create(b, loc, value); } else { - b.create<scf::YieldOp>(loc); + scf::YieldOp::create(b, loc); } } @@ -154,7 +153,7 @@ static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) { return Value(); Location loc = xferOp.getLoc(); - return b.create<vector::ExtractOp>(loc, xferOp.getMask(), iv); + return vector::ExtractOp::create(b, loc, xferOp.getMask(), iv); } /// Helper function TransferOpConversion and TransferOp1dConversion. @@ -201,22 +200,22 @@ static Value generateInBoundsCheck( Value base = xferOp.getIndices()[*dim]; Value memrefIdx = affine::makeComposedAffineApply(b, loc, d0 + d1, {base, iv}); - cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim, - memrefIdx); + cond = arith::CmpIOp::create(lb, arith::CmpIPredicate::sgt, memrefDim, + memrefIdx); } // Condition check 2: Masked in? if (auto maskCond = generateMaskCheck(b, xferOp, iv)) { if (cond) - cond = lb.create<arith::AndIOp>(cond, maskCond); + cond = arith::AndIOp::create(lb, cond, maskCond); else cond = maskCond; } // If the condition is non-empty, generate an SCF::IfOp. if (cond) { - auto check = lb.create<scf::IfOp>( - cond, + auto check = scf::IfOp::create( + lb, cond, /*thenBuilder=*/ [&](OpBuilder &b, Location loc) { maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc)); @@ -226,7 +225,7 @@ static Value generateInBoundsCheck( if (outOfBoundsCase) { maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc)); } else { - b.create<scf::YieldOp>(loc); + scf::YieldOp::create(b, loc); } }); @@ -303,14 +302,15 @@ static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) { BufferAllocs result; auto bufferType = MemRefType::get({}, xferOp.getVectorType()); - result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType); + result.dataBuffer = memref::AllocaOp::create(b, loc, bufferType); if (xferOp.getMask()) { auto maskType = MemRefType::get({}, xferOp.getMask().getType()); - auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType); + auto maskBuffer = memref::AllocaOp::create(b, loc, maskType); b.setInsertionPoint(xferOp); - b.create<memref::StoreOp>(loc, xferOp.getMask(), maskBuffer); - result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer, ValueRange()); + memref::StoreOp::create(b, loc, xferOp.getMask(), maskBuffer); + result.maskBuffer = + memref::LoadOp::create(b, loc, maskBuffer, ValueRange()); } return result; @@ -421,14 +421,15 @@ struct Strategy<TransferReadOp> { auto bufferType = dyn_cast<ShapedType>(buffer.getType()); auto vecType = dyn_cast<VectorType>(bufferType.getElementType()); auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); - auto newXferOp = b.create<vector::TransferReadOp>( - loc, vecType, xferOp.getBase(), xferIndices, + auto newXferOp = vector::TransferReadOp::create( + b, loc, vecType, xferOp.getBase(), xferIndices, AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), xferOp.getPadding(), Value(), inBoundsAttr); maybeApplyPassLabel(b, newXferOp, options.targetRank); - b.create<memref::StoreOp>(loc, newXferOp.getVector(), buffer, storeIndices); + memref::StoreOp::create(b, loc, newXferOp.getVector(), buffer, + storeIndices); return newXferOp; } @@ -444,8 +445,9 @@ struct Strategy<TransferReadOp> { Location loc = xferOp.getLoc(); auto bufferType = dyn_cast<ShapedType>(buffer.getType()); auto vecType = dyn_cast<VectorType>(bufferType.getElementType()); - auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding()); - b.create<memref::StoreOp>(loc, vec, buffer, storeIndices); + auto vec = + vector::BroadcastOp::create(b, loc, vecType, xferOp.getPadding()); + memref::StoreOp::create(b, loc, vec, buffer, storeIndices); return Value(); } @@ -506,12 +508,12 @@ struct Strategy<TransferWriteOp> { getXferIndices(b, xferOp, iv, xferIndices); Location loc = xferOp.getLoc(); - auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices); + auto vec = memref::LoadOp::create(b, loc, buffer, loadIndices); auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); auto source = loopState.empty() ? xferOp.getBase() : loopState[0]; Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type(); - auto newXferOp = b.create<vector::TransferWriteOp>( - loc, type, vec, source, xferIndices, + auto newXferOp = vector::TransferWriteOp::create( + b, loc, type, vec, source, xferIndices, AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(), inBoundsAttr); @@ -610,8 +612,8 @@ struct PrepareTransferReadConversion } Location loc = xferOp.getLoc(); - rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0), - buffers.dataBuffer); + memref::StoreOp::create(rewriter, loc, newXfer->getResult(0), + buffers.dataBuffer); rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer); return success(); @@ -653,9 +655,9 @@ struct PrepareTransferWriteConversion Location loc = xferOp.getLoc(); auto buffers = allocBuffers(rewriter, xferOp); - rewriter.create<memref::StoreOp>(loc, xferOp.getVector(), - buffers.dataBuffer); - auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer); + memref::StoreOp::create(rewriter, loc, xferOp.getVector(), + buffers.dataBuffer); + auto loadedVec = memref::LoadOp::create(rewriter, loc, buffers.dataBuffer); rewriter.modifyOpInPlace(xferOp, [&]() { xferOp.getValueToStoreMutable().assign(loadedVec); xferOp->setAttr(kPassLabel, rewriter.getUnitAttr()); @@ -688,7 +690,7 @@ struct PrepareTransferWriteConversion /// %lastIndex = arith.subi %length, %c1 : index /// vector.print punctuation <open> /// scf.for %i = %c0 to %length step %c1 { -/// %el = vector.extractelement %v[%i : index] : vector<[4]xi32> +/// %el = vector.extract %v[%i] : i32 from vector<[4]xi32> /// vector.print %el : i32 punctuation <no_punctuation> /// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index /// scf.if %notLastIndex { @@ -735,17 +737,17 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> { auto signlessTargetVectorType = vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy)); auto targetVectorType = vectorType.cloneWith({}, legalIntTy); - value = rewriter.create<vector::BitCastOp>(loc, signlessSourceVectorType, - value); + value = vector::BitCastOp::create(rewriter, loc, signlessSourceVectorType, + value); if (value.getType() != signlessTargetVectorType) { if (width == 1 || intTy.isUnsigned()) - value = rewriter.create<arith::ExtUIOp>(loc, signlessTargetVectorType, - value); + value = arith::ExtUIOp::create(rewriter, loc, + signlessTargetVectorType, value); else - value = rewriter.create<arith::ExtSIOp>(loc, signlessTargetVectorType, - value); + value = arith::ExtSIOp::create(rewriter, loc, + signlessTargetVectorType, value); } - value = rewriter.create<vector::BitCastOp>(loc, targetVectorType, value); + value = vector::BitCastOp::create(rewriter, loc, targetVectorType, value); vectorType = targetVectorType; } @@ -762,29 +764,30 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> { std::multiplies<int64_t>()); auto flatVectorType = VectorType::get({flatLength}, vectorType.getElementType()); - value = rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, value); + value = vector::ShapeCastOp::create(rewriter, loc, flatVectorType, value); } vector::PrintOp firstClose; SmallVector<Value, 8> loopIndices; for (unsigned d = 0; d < shape.size(); d++) { // Setup loop bounds and step. - Value lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); - Value upperBound = rewriter.create<arith::ConstantIndexOp>(loc, shape[d]); - Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1); + Value lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value upperBound = + arith::ConstantIndexOp::create(rewriter, loc, shape[d]); + Value step = arith::ConstantIndexOp::create(rewriter, loc, 1); if (!scalableDimensions.empty() && scalableDimensions[d]) { - auto vscale = rewriter.create<vector::VectorScaleOp>( - loc, rewriter.getIndexType()); - upperBound = rewriter.create<arith::MulIOp>(loc, upperBound, vscale); + auto vscale = vector::VectorScaleOp::create(rewriter, loc, + rewriter.getIndexType()); + upperBound = arith::MulIOp::create(rewriter, loc, upperBound, vscale); } - auto lastIndex = rewriter.create<arith::SubIOp>(loc, upperBound, step); + auto lastIndex = arith::SubIOp::create(rewriter, loc, upperBound, step); // Create a loop to print the elements surrounded by parentheses. - rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open); + vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open); auto loop = - rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step); - auto printClose = rewriter.create<vector::PrintOp>( - loc, vector::PrintPunctuation::Close); + scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step); + auto printClose = vector::PrintOp::create( + rewriter, loc, vector::PrintPunctuation::Close); if (!firstClose) firstClose = printClose; @@ -793,14 +796,14 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> { // Print a comma after all but the last element. rewriter.setInsertionPointToStart(loop.getBody()); - auto notLastIndex = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::ult, loopIdx, lastIndex); - rewriter.create<scf::IfOp>(loc, notLastIndex, - [&](OpBuilder &builder, Location loc) { - builder.create<vector::PrintOp>( - loc, vector::PrintPunctuation::Comma); - builder.create<scf::YieldOp>(loc); - }); + auto notLastIndex = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::ult, loopIdx, lastIndex); + scf::IfOp::create(rewriter, loc, notLastIndex, + [&](OpBuilder &builder, Location loc) { + vector::PrintOp::create( + builder, loc, vector::PrintPunctuation::Comma); + scf::YieldOp::create(builder, loc); + }); rewriter.setInsertionPointToStart(loop.getBody()); } @@ -810,22 +813,23 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> { Value flatIndex; auto currentStride = 1; for (int d = shape.size() - 1; d >= 0; d--) { - auto stride = rewriter.create<arith::ConstantIndexOp>(loc, currentStride); - auto index = rewriter.create<arith::MulIOp>(loc, stride, loopIndices[d]); + auto stride = + arith::ConstantIndexOp::create(rewriter, loc, currentStride); + auto index = arith::MulIOp::create(rewriter, loc, stride, loopIndices[d]); if (flatIndex) - flatIndex = rewriter.create<arith::AddIOp>(loc, flatIndex, index); + flatIndex = arith::AddIOp::create(rewriter, loc, flatIndex, index); else flatIndex = index; currentStride *= shape[d]; } // Print the scalar elements in the inner most loop. - auto element = rewriter.create<vector::ExtractOp>(loc, value, flatIndex); - rewriter.create<vector::PrintOp>(loc, element, - vector::PrintPunctuation::NoPunctuation); + auto element = vector::ExtractOp::create(rewriter, loc, value, flatIndex); + vector::PrintOp::create(rewriter, loc, element, + vector::PrintPunctuation::NoPunctuation); rewriter.setInsertionPointAfter(firstClose); - rewriter.create<vector::PrintOp>(loc, printOp.getPunctuation()); + vector::PrintOp::create(rewriter, loc, printOp.getPunctuation()); rewriter.eraseOp(printOp); return success(); } @@ -916,7 +920,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> { "Failed to unpack one vector dim."); auto castedDataBuffer = - locB.create<vector::TypeCastOp>(*castedDataType, dataBuffer); + vector::TypeCastOp::create(locB, *castedDataType, dataBuffer); // If the xferOp has a mask: Find and cast mask buffer. Value castedMaskBuffer; @@ -935,22 +939,22 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> { auto maskBufferType = cast<MemRefType>(maskBuffer.getType()); MemRefType castedMaskType = *unpackOneDim(maskBufferType); castedMaskBuffer = - locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer); + vector::TypeCastOp::create(locB, castedMaskType, maskBuffer); } } // Loop bounds and step. - auto lb = locB.create<arith::ConstantIndexOp>(0); - auto ub = locB.create<arith::ConstantIndexOp>( - castedDataType->getDimSize(castedDataType->getRank() - 1)); - auto step = locB.create<arith::ConstantIndexOp>(1); + auto lb = arith::ConstantIndexOp::create(locB, 0); + auto ub = arith::ConstantIndexOp::create( + locB, castedDataType->getDimSize(castedDataType->getRank() - 1)); + auto step = arith::ConstantIndexOp::create(locB, 1); // TransferWriteOps that operate on tensors return the modified tensor and // require a loop state. auto loopState = Strategy<OpTy>::initialLoopState(xferOp); // Generate for loop. - auto result = locB.create<scf::ForOp>( - lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(), + auto result = scf::ForOp::create( + locB, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(), [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { Type stateType = loopState.empty() ? Type() : loopState[0].getType(); @@ -975,8 +979,8 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> { SmallVector<Value, 8> loadIndices; getMaskBufferLoadIndices(xferOp, castedMaskBuffer, loadIndices, iv); - auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer, - loadIndices); + auto mask = memref::LoadOp::create(b, loc, castedMaskBuffer, + loadIndices); rewriter.modifyOpInPlace(newXfer, [&]() { newXfer.getMaskMutable().assign(mask); }); @@ -1119,30 +1123,30 @@ struct ScalableTransposeTransferWriteConversion auto transposeSource = transposeOp.getVector(); SmallVector<Value> transposeSourceSlices = llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value { - return rewriter.create<vector::ExtractOp>(loc, transposeSource, idx); + return vector::ExtractOp::create(rewriter, loc, transposeSource, idx); }); // Loop bounds and step. - auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0); + auto lb = arith::ConstantIndexOp::create(rewriter, loc, 0); auto ub = maskDims->empty() ? Value(createVscaleMultiple(vectorType.getDimSize(0))) : vector::getAsValues(rewriter, loc, maskDims->front()).front(); - auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); // Generate a new mask for the slice. VectorType sliceType = VectorType::Builder(vectorType).dropDim(0); Value sliceMask = nullptr; if (!maskDims->empty()) { - sliceMask = rewriter.create<vector::CreateMaskOp>( - loc, sliceType.clone(rewriter.getI1Type()), + sliceMask = vector::CreateMaskOp::create( + rewriter, loc, sliceType.clone(rewriter.getI1Type()), ArrayRef<OpFoldResult>(*maskDims).drop_front()); } Value initDest = isTensorOp(writeOp) ? writeOp.getBase() : Value{}; ValueRange initLoopArgs = initDest ? initDest : ValueRange{}; - auto result = rewriter.create<scf::ForOp>( - loc, lb, ub, step, initLoopArgs, + auto result = scf::ForOp::create( + rewriter, loc, lb, ub, step, initLoopArgs, [&](OpBuilder &b, Location loc, Value iv, ValueRange loopIterArgs) { // Indices for the new transfer op. SmallVector<Value, 8> xferIndices; @@ -1151,25 +1155,25 @@ struct ScalableTransposeTransferWriteConversion // Extract a transposed slice from the source vector. SmallVector<Value> transposeElements = llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value { - return b.create<vector::ExtractOp>( - loc, transposeSourceSlices[idx], iv); + return vector::ExtractOp::create( + b, loc, transposeSourceSlices[idx], iv); }); - auto sliceVec = b.create<vector::FromElementsOp>(loc, sliceType, - transposeElements); + auto sliceVec = vector::FromElementsOp::create(b, loc, sliceType, + transposeElements); // Create the transfer_write for the slice. Value dest = loopIterArgs.empty() ? writeOp.getBase() : loopIterArgs.front(); - auto newWriteOp = b.create<vector::TransferWriteOp>( - loc, sliceVec, dest, xferIndices, + auto newWriteOp = vector::TransferWriteOp::create( + b, loc, sliceVec, dest, xferIndices, ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front()); if (sliceMask) newWriteOp.getMaskMutable().assign(sliceMask); // Yield from the loop. - b.create<scf::YieldOp>(loc, loopIterArgs.empty() - ? ValueRange{} - : newWriteOp.getResult()); + scf::YieldOp::create(b, loc, + loopIterArgs.empty() ? ValueRange{} + : newWriteOp.getResult()); }); if (isTensorOp(writeOp)) @@ -1207,7 +1211,7 @@ static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp, llvm::SmallVector<int64_t, 1> indices({i}); Location loc = xferOp.getLoc(); - auto newMask = b.create<vector::ExtractOp>(loc, xferOp.getMask(), indices); + auto newMask = vector::ExtractOp::create(b, loc, xferOp.getMask(), indices); newXferOp.getMaskMutable().assign(newMask); } @@ -1261,8 +1265,8 @@ struct UnrollTransferReadConversion if (auto insertOp = getInsertOp(xferOp)) return insertOp.getDest(); Location loc = xferOp.getLoc(); - return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(), - xferOp.getPadding()); + return vector::BroadcastOp::create(rewriter, loc, xferOp.getVectorType(), + xferOp.getPadding()); } /// If the result of the TransferReadOp has exactly one user, which is a @@ -1317,7 +1321,7 @@ struct UnrollTransferReadConversion // Generate fully unrolled loop of transfer ops. Location loc = xferOp.getLoc(); for (int64_t i = 0; i < dimSize; ++i) { - Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i); + Value iv = arith::ConstantIndexOp::create(rewriter, loc, i); // FIXME: Rename this lambda - it does much more than just // in-bounds-check generation. @@ -1336,8 +1340,8 @@ struct UnrollTransferReadConversion auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); - auto newXferOp = b.create<vector::TransferReadOp>( - loc, newXferVecType, xferOp.getBase(), xferIndices, + auto newXferOp = vector::TransferReadOp::create( + b, loc, newXferVecType, xferOp.getBase(), xferIndices, AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), xferOp.getPadding(), Value(), inBoundsAttr); maybeAssignMask(b, xferOp, newXferOp, i); @@ -1346,11 +1350,11 @@ struct UnrollTransferReadConversion if (newXferVecType.getRank() == 0) { // vector.insert does not accept rank-0 as the non-indexed // argument. Extract the scalar before inserting. - valToInser = b.create<vector::ExtractOp>(loc, valToInser, - SmallVector<int64_t>()); + valToInser = vector::ExtractOp::create(b, loc, valToInser, + SmallVector<int64_t>()); } - return b.create<vector::InsertOp>(loc, valToInser, vec, - insertionIndices); + return vector::InsertOp::create(b, loc, valToInser, vec, + insertionIndices); }, /*outOfBoundsCase=*/ [&](OpBuilder &b, Location loc) { @@ -1460,7 +1464,7 @@ struct UnrollTransferWriteConversion // Generate fully unrolled loop of transfer ops. Location loc = xferOp.getLoc(); for (int64_t i = 0; i < dimSize; ++i) { - Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i); + Value iv = arith::ConstantIndexOp::create(rewriter, loc, i); auto updatedSource = generateInBoundsCheck( rewriter, xferOp, iv, unpackedDim(xferOp), @@ -1477,20 +1481,20 @@ struct UnrollTransferWriteConversion extractionIndices.push_back(b.getI64IntegerAttr(i)); auto extracted = - b.create<vector::ExtractOp>(loc, vec, extractionIndices); + vector::ExtractOp::create(b, loc, vec, extractionIndices); auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); Value xferVec; if (inputVectorTy.getRank() == 1) { // When target-rank=0, unrolling would causes the vector input // argument into `transfer_write` to become a scalar. We solve // this by broadcasting the scalar to a 0D vector. - xferVec = b.create<vector::BroadcastOp>( - loc, VectorType::get({}, extracted.getType()), extracted); + xferVec = vector::BroadcastOp::create( + b, loc, VectorType::get({}, extracted.getType()), extracted); } else { xferVec = extracted; } - auto newXferOp = b.create<vector::TransferWriteOp>( - loc, sourceType, xferVec, source, xferIndices, + auto newXferOp = vector::TransferWriteOp::create( + b, loc, sourceType, xferVec, source, xferIndices, AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(), inBoundsAttr); @@ -1572,19 +1576,19 @@ struct Strategy1d<TransferReadOp> { b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()), /*inBoundsCase=*/ [&](OpBuilder &b, Location loc) { - Value val = b.create<memref::LoadOp>(loc, xferOp.getBase(), indices); - return b.create<vector::InsertOp>(loc, val, vec, iv); + Value val = memref::LoadOp::create(b, loc, xferOp.getBase(), indices); + return vector::InsertOp::create(b, loc, val, vec, iv); }, /*outOfBoundsCase=*/ [&](OpBuilder & /*b*/, Location loc) { return vec; }); - b.create<scf::YieldOp>(loc, nextVec); + scf::YieldOp::create(b, loc, nextVec); } static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) { // Inititalize vector with padding value. Location loc = xferOp.getLoc(); - return b.create<vector::SplatOp>(loc, xferOp.getVectorType(), - xferOp.getPadding()); + return vector::BroadcastOp::create(b, loc, xferOp.getVectorType(), + xferOp.getPadding()); } }; @@ -1601,10 +1605,10 @@ struct Strategy1d<TransferWriteOp> { generateInBoundsCheck( b, xferOp, iv, dim, /*inBoundsCase=*/[&](OpBuilder &b, Location loc) { - auto val = b.create<vector::ExtractOp>(loc, xferOp.getVector(), iv); - b.create<memref::StoreOp>(loc, val, xferOp.getBase(), indices); + auto val = vector::ExtractOp::create(b, loc, xferOp.getVector(), iv); + memref::StoreOp::create(b, loc, val, xferOp.getBase(), indices); }); - b.create<scf::YieldOp>(loc); + scf::YieldOp::create(b, loc); } static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) { @@ -1639,7 +1643,7 @@ struct Strategy1d<TransferWriteOp> { /// Is rewritten to approximately the following pseudo-IR: /// ``` /// for i = 0 to 9 { -/// %t = vector.extractelement %vec[i] : vector<9xf32> +/// %t = vector.extract %vec[i] : f32 from vector<9xf32> /// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32> /// } /// ``` @@ -1665,15 +1669,15 @@ struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> { // Loop bounds, step, state... Location loc = xferOp.getLoc(); auto vecType = xferOp.getVectorType(); - auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0); + auto lb = arith::ConstantIndexOp::create(rewriter, loc, 0); Value ub = - rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0)); + arith::ConstantIndexOp::create(rewriter, loc, vecType.getDimSize(0)); if (vecType.isScalable()) { Value vscale = - rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType()); - ub = rewriter.create<arith::MulIOp>(loc, ub, vscale); + vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType()); + ub = arith::MulIOp::create(rewriter, loc, ub, vscale); } - auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp); // Generate for loop. diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 21d8e1d..a4be7d4 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -79,6 +79,20 @@ struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> { } }; +// Convert `vector.splat` to `vector.broadcast`. There is a path from +// `vector.broadcast` to SPIRV via other patterns. +struct VectorSplatToBroadcast final + : public OpConversionPattern<vector::SplatOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(), + adaptor.getInput()); + return success(); + } +}; + struct VectorBitcastConvert final : public OpConversionPattern<vector::BitCastOp> { using OpConversionPattern::OpConversionPattern; @@ -147,19 +161,19 @@ static Value sanitizeDynamicIndex(ConversionPatternRewriter &rewriter, Location loc, Value dynamicIndex, int64_t kPoisonIndex, unsigned vectorSize) { if (llvm::isPowerOf2_32(vectorSize)) { - Value inBoundsMask = rewriter.create<spirv::ConstantOp>( - loc, dynamicIndex.getType(), + Value inBoundsMask = spirv::ConstantOp::create( + rewriter, loc, dynamicIndex.getType(), rewriter.getIntegerAttr(dynamicIndex.getType(), vectorSize - 1)); - return rewriter.create<spirv::BitwiseAndOp>(loc, dynamicIndex, - inBoundsMask); + return spirv::BitwiseAndOp::create(rewriter, loc, dynamicIndex, + inBoundsMask); } - Value poisonIndex = rewriter.create<spirv::ConstantOp>( - loc, dynamicIndex.getType(), + Value poisonIndex = spirv::ConstantOp::create( + rewriter, loc, dynamicIndex.getType(), rewriter.getIntegerAttr(dynamicIndex.getType(), kPoisonIndex)); Value cmpResult = - rewriter.create<spirv::IEqualOp>(loc, dynamicIndex, poisonIndex); - return rewriter.create<spirv::SelectOp>( - loc, cmpResult, + spirv::IEqualOp::create(rewriter, loc, dynamicIndex, poisonIndex); + return spirv::SelectOp::create( + rewriter, loc, cmpResult, spirv::ConstantOp::getZero(dynamicIndex.getType(), loc, rewriter), dynamicIndex); } @@ -321,63 +335,6 @@ struct VectorInsertOpConvert final } }; -struct VectorExtractElementOpConvert final - : public OpConversionPattern<vector::ExtractElementOp> { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type resultType = getTypeConverter()->convertType(extractOp.getType()); - if (!resultType) - return failure(); - - if (isa<spirv::ScalarType>(adaptor.getVector().getType())) { - rewriter.replaceOp(extractOp, adaptor.getVector()); - return success(); - } - - APInt cstPos; - if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos))) - rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( - extractOp, resultType, adaptor.getVector(), - rewriter.getI32ArrayAttr({static_cast<int>(cstPos.getSExtValue())})); - else - rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>( - extractOp, resultType, adaptor.getVector(), adaptor.getPosition()); - return success(); - } -}; - -struct VectorInsertElementOpConvert final - : public OpConversionPattern<vector::InsertElementOp> { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type vectorType = getTypeConverter()->convertType(insertOp.getType()); - if (!vectorType) - return failure(); - - if (isa<spirv::ScalarType>(vectorType)) { - rewriter.replaceOp(insertOp, adaptor.getSource()); - return success(); - } - - APInt cstPos; - if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos))) - rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( - insertOp, adaptor.getSource(), adaptor.getDest(), - cstPos.getSExtValue()); - else - rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>( - insertOp, vectorType, insertOp.getDest(), adaptor.getSource(), - adaptor.getPosition()); - return success(); - } -}; - struct VectorInsertStridedSliceOpConvert final : public OpConversionPattern<vector::InsertStridedSliceOp> { using OpConversionPattern::OpConversionPattern; @@ -427,8 +384,8 @@ static SmallVector<Value> extractAllElements( Location loc = reduceOp.getLoc(); for (int i = 0; i < numElements; ++i) { - values.push_back(rewriter.create<spirv::CompositeExtractOp>( - loc, srcVectorType.getElementType(), adaptor.getVector(), + values.push_back(spirv::CompositeExtractOp::create( + rewriter, loc, srcVectorType.getElementType(), adaptor.getVector(), rewriter.getI32ArrayAttr({i}))); } if (Value acc = adaptor.getAcc()) @@ -481,16 +438,16 @@ struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> { #define INT_AND_FLOAT_CASE(kind, iop, fop) \ case vector::CombiningKind::kind: \ if (llvm::isa<IntegerType>(resultType)) { \ - result = rewriter.create<spirv::iop>(loc, resultType, result, next); \ + result = spirv::iop::create(rewriter, loc, resultType, result, next); \ } else { \ assert(llvm::isa<FloatType>(resultType)); \ - result = rewriter.create<spirv::fop>(loc, resultType, result, next); \ + result = spirv::fop::create(rewriter, loc, resultType, result, next); \ } \ break #define INT_OR_FLOAT_CASE(kind, fop) \ case vector::CombiningKind::kind: \ - result = rewriter.create<fop>(loc, resultType, result, next); \ + result = fop::create(rewriter, loc, resultType, result, next); \ break INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp); @@ -537,7 +494,7 @@ struct VectorReductionFloatMinMax final #define INT_OR_FLOAT_CASE(kind, fop) \ case vector::CombiningKind::kind: \ - result = rewriter.create<fop>(loc, resultType, result, next); \ + result = fop::create(rewriter, loc, resultType, result, next); \ break INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp); @@ -556,22 +513,27 @@ struct VectorReductionFloatMinMax final } }; -class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> { +class VectorScalarBroadcastPattern final + : public OpConversionPattern<vector::BroadcastOp> { public: - using OpConversionPattern<vector::SplatOp>::OpConversionPattern; + using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern; LogicalResult - matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor, + matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + if (isa<VectorType>(op.getSourceType())) { + return rewriter.notifyMatchFailure( + op, "only conversion of 'broadcast from scalar' is supported"); + } Type dstType = getTypeConverter()->convertType(op.getType()); if (!dstType) return failure(); if (isa<spirv::ScalarType>(dstType)) { - rewriter.replaceOp(op, adaptor.getInput()); + rewriter.replaceOp(op, adaptor.getSource()); } else { auto dstVecType = cast<VectorType>(dstType); SmallVector<Value, 4> source(dstVecType.getNumElements(), - adaptor.getInput()); + adaptor.getSource()); rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType, source); } @@ -613,8 +575,8 @@ struct VectorShuffleOpConvert final auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()]( Value scalarOrVec, int32_t idx) -> Value { if (auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType())) - return rewriter.create<spirv::CompositeExtractOp>(loc, scalarOrVec, - idx); + return spirv::CompositeExtractOp::create(rewriter, loc, scalarOrVec, + idx); assert(idx == 0 && "Invalid scalar element index"); return scalarOrVec; @@ -712,11 +674,13 @@ struct VectorDeinterleaveOpConvert final // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to // use `spirv::CompositeExtractOp`. if (n == 2) { - auto elem0 = rewriter.create<spirv::CompositeExtractOp>( - loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({0})); + auto elem0 = spirv::CompositeExtractOp::create( + rewriter, loc, newResultType, sourceVector, + rewriter.getI32ArrayAttr({0})); - auto elem1 = rewriter.create<spirv::CompositeExtractOp>( - loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({1})); + auto elem1 = spirv::CompositeExtractOp::create( + rewriter, loc, newResultType, sourceVector, + rewriter.getI32ArrayAttr({1})); rewriter.replaceOp(deinterleaveOp, {elem0, elem1}); return success(); @@ -733,12 +697,12 @@ struct VectorDeinterleaveOpConvert final llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; }); // Create two SPIR-V shuffles. - auto shuffleEven = rewriter.create<spirv::VectorShuffleOp>( - loc, newResultType, sourceVector, sourceVector, + auto shuffleEven = spirv::VectorShuffleOp::create( + rewriter, loc, newResultType, sourceVector, sourceVector, rewriter.getI32ArrayAttr(indicesEven)); - auto shuffleOdd = rewriter.create<spirv::VectorShuffleOp>( - loc, newResultType, sourceVector, sourceVector, + auto shuffleOdd = spirv::VectorShuffleOp::create( + rewriter, loc, newResultType, sourceVector, sourceVector, rewriter.getI32ArrayAttr(indicesOdd)); rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd}); @@ -774,15 +738,19 @@ struct VectorLoadOpConverter final // Use the converted vector type instead of original (single element vector // would get converted to scalar). auto spirvVectorType = typeConverter.convertType(vectorType); + if (!spirvVectorType) + return rewriter.notifyMatchFailure(loadOp, "unsupported vector type"); + auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass); // For single element vectors, we don't need to bitcast the access chain to // the original vector type. Both is going to be the same, a pointer // to a scalar. - Value castedAccessChain = (vectorType.getNumElements() == 1) - ? accessChain - : rewriter.create<spirv::BitcastOp>( - loc, vectorPtrType, accessChain); + Value castedAccessChain = + (vectorType.getNumElements() == 1) + ? accessChain + : spirv::BitcastOp::create(rewriter, loc, vectorPtrType, + accessChain); rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType, castedAccessChain); @@ -821,10 +789,11 @@ struct VectorStoreOpConverter final // For single element vectors, we don't need to bitcast the access chain to // the original vector type. Both is going to be the same, a pointer // to a scalar. - Value castedAccessChain = (vectorType.getNumElements() == 1) - ? accessChain - : rewriter.create<spirv::BitcastOp>( - loc, vectorPtrType, accessChain); + Value castedAccessChain = + (vectorType.getNumElements() == 1) + ? accessChain + : spirv::BitcastOp::create(rewriter, loc, vectorPtrType, + accessChain); rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain, adaptor.getValueToStore()); @@ -905,10 +874,10 @@ private: auto v4i8Type = VectorType::get({4}, i8Type); Location loc = op.getLoc(); Value zero = spirv::ConstantOp::getZero(i8Type, loc, rewriter); - lhsIn = rewriter.create<spirv::CompositeConstructOp>( - loc, v4i8Type, ValueRange{lhsIn, zero}); - rhsIn = rewriter.create<spirv::CompositeConstructOp>( - loc, v4i8Type, ValueRange{rhsIn, zero}); + lhsIn = spirv::CompositeConstructOp::create(rewriter, loc, v4i8Type, + ValueRange{lhsIn, zero}); + rhsIn = spirv::CompositeConstructOp::create(rewriter, loc, v4i8Type, + ValueRange{rhsIn, zero}); } // There's no variant of dot prod ops for unsigned LHS and signed RHS, so @@ -971,14 +940,14 @@ struct VectorReductionToFPDotProd final Attribute oneAttr = rewriter.getFloatAttr(vectorType.getElementType(), 1.0); oneAttr = SplatElementsAttr::get(vectorType, oneAttr); - rhs = rewriter.create<spirv::ConstantOp>(loc, vectorType, oneAttr); + rhs = spirv::ConstantOp::create(rewriter, loc, vectorType, oneAttr); } assert(lhs); assert(rhs); - Value res = rewriter.create<spirv::DotOp>(loc, resultType, lhs, rhs); + Value res = spirv::DotOp::create(rewriter, loc, resultType, lhs, rhs); if (acc) - res = rewriter.create<spirv::FAddOp>(loc, acc, res); + res = spirv::FAddOp::create(rewriter, loc, acc, res); rewriter.replaceOp(op, res); return success(); @@ -1013,7 +982,8 @@ struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> { source.reserve(numElements); for (int64_t i = 0; i < numElements; ++i) { Attribute intAttr = rewriter.getIntegerAttr(intType, i); - Value constOp = rewriter.create<spirv::ConstantOp>(loc, intType, intAttr); + Value constOp = + spirv::ConstantOp::create(rewriter, loc, intType, intAttr); source.push_back(constOp); } rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType, @@ -1056,8 +1026,8 @@ struct VectorToElementOpConvert final if (element.use_empty()) continue; - Value result = rewriter.create<spirv::CompositeExtractOp>( - loc, elementType, adaptor.getSource(), + Value result = spirv::CompositeExtractOp::create( + rewriter, loc, elementType, adaptor.getSource(), rewriter.getI32ArrayAttr({static_cast<int32_t>(idx)})); results[idx] = result; } @@ -1080,20 +1050,19 @@ struct VectorToElementOpConvert final void mlir::populateVectorToSPIRVPatterns( const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.add< - VectorBitcastConvert, VectorBroadcastConvert, - VectorExtractElementOpConvert, VectorExtractOpConvert, + VectorBitcastConvert, VectorBroadcastConvert, VectorExtractOpConvert, VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>, VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert, - VectorToElementOpConvert, VectorInsertElementOpConvert, - VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>, + VectorToElementOpConvert, VectorInsertOpConvert, + VectorReductionPattern<GL_INT_MAX_MIN_OPS>, VectorReductionPattern<CL_INT_MAX_MIN_OPS>, VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>, VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast, - VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert, - VectorInterleaveOpConvert, VectorDeinterleaveOpConvert, - VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter, - VectorStepOpConvert>(typeConverter, patterns.getContext(), - PatternBenefit(1)); + VectorSplatToBroadcast, VectorInsertStridedSliceOpConvert, + VectorShuffleOpConvert, VectorInterleaveOpConvert, + VectorDeinterleaveOpConvert, VectorScalarBroadcastPattern, + VectorLoadOpConverter, VectorStoreOpConverter, VectorStepOpConvert>( + typeConverter, patterns.getContext(), PatternBenefit(1)); // Make sure that the more specialized dot product pattern has higher benefit // than the generic one that extracts all elements. diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp index 2e6a16d..8010755 100644 --- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp +++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp @@ -108,15 +108,15 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc, xegpu::CreateNdDescOp ndDesc; if (srcTy.hasStaticShape()) { - ndDesc = rewriter.create<xegpu::CreateNdDescOp>(loc, descType, src, - getAsOpFoldResult(offsets)); + ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src, + getAsOpFoldResult(offsets)); } else { // In case of any dynamic shapes, source's shape and strides have to be // explicitly provided. SmallVector<Value> sourceDims; unsigned srcRank = srcTy.getRank(); for (unsigned i = 0; i < srcRank; ++i) - sourceDims.push_back(rewriter.create<memref::DimOp>(loc, src, i)); + sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i)); SmallVector<int64_t> constOffsets; SmallVector<Value> dynOffsets; @@ -135,18 +135,18 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc, // Compute strides in reverse order. SmallVector<Value> dynStrides; - Value accStride = rewriter.create<arith::ConstantIndexOp>(loc, 1); + Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1); // Last stride is guaranteed to be static and unit. for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) { accStride = - rewriter.create<arith::MulIOp>(loc, accStride, sourceDims[i + 1]); + arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]); if (strides[i] == ShapedType::kDynamic) dynStrides.push_back(accStride); } std::reverse(dynStrides.begin(), dynStrides.end()); - ndDesc = rewriter.create<xegpu::CreateNdDescOp>( - loc, descType, src, dynOffsets, dynShapes, dynStrides, + ndDesc = xegpu::CreateNdDescOp::create( + rewriter, loc, descType, src, dynOffsets, dynShapes, dynStrides, DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets), DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()), DenseI64ArrayAttr::get(rewriter.getContext(), strides)); @@ -200,10 +200,10 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> { ArrayRef<int64_t>{1, 0}); // By default, no specific caching policy is assigned. xegpu::CachePolicyAttr hint = nullptr; - auto loadOp = rewriter.create<xegpu::LoadNdOp>( - loc, vecTy, ndDesc, /*packed=*/nullptr, transposeAttr, - /*l1_hint=*/hint, - /*l2_hint=*/hint, /*l3_hint=*/hint); + auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, + /*packed=*/nullptr, transposeAttr, + /*l1_hint=*/hint, + /*l2_hint=*/hint, /*l3_hint=*/hint); rewriter.replaceOp(readOp, loadOp); return success(); @@ -238,9 +238,9 @@ struct TransferWriteLowering // By default, no specific caching policy is assigned. xegpu::CachePolicyAttr hint = nullptr; auto storeOp = - rewriter.create<xegpu::StoreNdOp>(loc, writeOp.getVector(), ndDesc, - /*l1_hint=*/hint, - /*l2_hint=*/hint, /*l3_hint=*/hint); + xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc, + /*l1_hint=*/hint, + /*l2_hint=*/hint, /*l3_hint=*/hint); rewriter.replaceOp(writeOp, storeOp); return success(); @@ -269,8 +269,8 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> { // By default, no specific caching policy is assigned. xegpu::CachePolicyAttr hint = nullptr; - auto loadNdOp = rewriter.create<xegpu::LoadNdOp>( - loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr, + auto loadNdOp = xegpu::LoadNdOp::create( + rewriter, loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr, /*l1_hint=*/hint, /*l2_hint=*/hint, /*l3_hint=*/hint); rewriter.replaceOp(loadOp, loadNdOp); @@ -303,9 +303,9 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> { // By default, no specific caching policy is assigned. xegpu::CachePolicyAttr hint = nullptr; auto storeNdOp = - rewriter.create<xegpu::StoreNdOp>(loc, vector, ndDesc, - /*l1_hint=*/hint, - /*l2_hint=*/hint, /*l3_hint=*/hint); + xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, + /*l1_hint=*/hint, + /*l2_hint=*/hint, /*l3_hint=*/hint); rewriter.replaceOp(storeOp, storeNdOp); return success(); @@ -339,8 +339,9 @@ struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> { if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr())) return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps"); - auto dpasOp = rewriter.create<xegpu::DpasOp>( - loc, TypeRange{contractOp.getResultType()}, ValueRange{lhs, rhs, acc}); + auto dpasOp = xegpu::DpasOp::create(rewriter, loc, + TypeRange{contractOp.getResultType()}, + ValueRange{lhs, rhs, acc}); rewriter.replaceOp(contractOp, dpasOp); return success(); diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp index a8380b9..4dfcb2b 100644 --- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp +++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp @@ -10,7 +10,6 @@ #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" -#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/XeVMDialect.h" @@ -251,7 +250,7 @@ static LLVM::CallOp createDeviceFunctionCall( for (auto [idx, attrName] : paramAttrs) funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr()); - auto callOp = rewriter.create<LLVM::CallOp>(loc, funcOp, args); + auto callOp = LLVM::CallOp::create(rewriter, loc, funcOp, args); callOp->setAttrs(funcOp->getAttrs()); return callOp; @@ -299,7 +298,7 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> { VectorType newTy = VectorType::get( vecBitSize / packedType.getIntOrFloatBitWidth(), packedType); if (origTy != newTy) - val = rewriter.create<LLVM::BitcastOp>(loc, newTy, val); + val = LLVM::BitcastOp::create(rewriter, loc, newTy, val); return val; }; @@ -326,7 +325,7 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> { : cOrigTy; VectorType resTy = cTy; if (cOrigTy != cTy) - c = rewriter.create<LLVM::BitcastOp>(loc, cTy, c); + c = LLVM::BitcastOp::create(rewriter, loc, cTy, c); constexpr int32_t systolicDepth{8}; std::string fnName = @@ -352,7 +351,7 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> { ->getResult(0); if (resOrigTy != resTy) - result = rewriter.create<LLVM::BitcastOp>(loc, resOrigTy, result); + result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy, result); rewriter.replaceOp(op, result); return success(); @@ -383,7 +382,7 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> { auto loc = op.getLoc(); const std::string fnName{"_Z8prefetchPU3AS1Kcm"}; Value one = - rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), 1); + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1); SmallVector<Value> args{op.getPtr(), one}; SmallVector<Type> argTypes; for (auto arg : args) @@ -439,11 +438,11 @@ class MemfenceToOCLPattern : public OpConversionPattern<MemfenceOp> { op, "Fence only supports workgroup and device memory scopes."); } Type i32Type = rewriter.getI32Type(); - Value acqRel = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 4); + Value acqRel = LLVM::ConstantOp::create(rewriter, loc, i32Type, 4); Value memScopeConst = - rewriter.create<LLVM::ConstantOp>(loc, i32Type, memScope); + LLVM::ConstantOp::create(rewriter, loc, i32Type, memScope); Value addrSpaceConst = - rewriter.create<LLVM::ConstantOp>(loc, i32Type, addrSpace); + LLVM::ConstantOp::create(rewriter, loc, i32Type, addrSpace); SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst}; SmallVector<Type> argTypes{3, i32Type}; createDeviceFunctionCall(rewriter, mangle(fnName, argTypes), @@ -477,13 +476,13 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> { auto i32Type = rewriter.getI32Type(); Value byteCoord = - rewriter.create<LLVM::UndefOp>(loc, VectorType::get(2, i32Type)); - Value zero = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 0); - Value one = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 1); - byteCoord = rewriter.create<LLVM::InsertElementOp>( - loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero); - byteCoord = rewriter.create<LLVM::InsertElementOp>( - loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one); + LLVM::UndefOp::create(rewriter, loc, VectorType::get(2, i32Type)); + Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, 0); + Value one = LLVM::ConstantOp::create(rewriter, loc, i32Type, 1); + byteCoord = LLVM::InsertElementOp::create( + rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero); + byteCoord = LLVM::InsertElementOp::create( + rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one); SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(), op.getBasePitch(), byteCoord}; SmallVector<Type> retTypes; @@ -504,11 +503,11 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> { } else { auto vecElemType = vecType.getElementType(); auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth(); - Value numElems = rewriter.create<LLVM::ConstantOp>( - loc, i32Type, vecType.getNumElements()); - auto dstOrSrcPtr = rewriter.create<LLVM::AllocaOp>( - loc, LLVM::LLVMPointerType::get(rewriter.getContext()), vecElemType, - numElems); + Value numElems = LLVM::ConstantOp::create(rewriter, loc, i32Type, + vecType.getNumElements()); + auto dstOrSrcPtr = LLVM::AllocaOp::create( + rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()), + vecElemType, numElems); args.push_back(dstOrSrcPtr); if constexpr (isLoad) { // Load funcName += "read"; @@ -530,7 +529,7 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> { bitWidthId = (vecElemBitWidth == 32) ? "j" : ((vecElemBitWidth == 16) ? "t" : "h"); - rewriter.create<LLVM::StoreOp>(loc, op.getStoredVal(), dstOrSrcPtr); + LLVM::StoreOp::create(rewriter, loc, op.getStoredVal(), dstOrSrcPtr); paramAttrs = { std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()), std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()), @@ -563,7 +562,7 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> { } if constexpr (isLoad) rewriter.replaceOp( - op, rewriter.create<LLVM::LoadOp>(loc, vecType, spvLoadDstPtr)); + op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr)); else rewriter.eraseOp(op); return success(); |