Diffstat (limited to 'mlir/lib/Conversion')
52 files changed, 1295 insertions, 1260 deletions
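Every hunk below is the same mechanical migration: op construction moves from the rewriter-centric `rewriter.create<OpTy>(loc, ...)` template to the op's static `OpTy::create(rewriter, loc, ...)` builder, with the builder passed as the first argument. A minimal before/after sketch of the pattern (illustrative only; `makeZero` is a hypothetical helper, not part of the patch):

  // Both spellings build the same LLVM dialect constant; only the call
  // syntax changes, so the hunks below are behavior-preserving.
  static Value makeZero(ConversionPatternRewriter &rewriter, Location loc) {
    Type i32 = rewriter.getI32Type();
    // Old form, removed throughout this diff:
    //   return rewriter.create<LLVM::ConstantOp>(loc, i32, 0);
    // New form, added throughout this diff:
    return LLVM::ConstantOp::create(rewriter, loc, i32, 0);
  }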
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 309476c..64720bf 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -50,20 +50,20 @@ static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
   if (i32 == valTy)
     return val;
   return valTy.getWidth() > 32
-             ? Value(rewriter.create<LLVM::TruncOp>(loc, i32, val))
-             : Value(rewriter.create<LLVM::ZExtOp>(loc, i32, val));
+             ? Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
+             : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
 }
 
 static Value createI32Constant(ConversionPatternRewriter &rewriter,
                                Location loc, int32_t value) {
   Type i32 = rewriter.getI32Type();
-  return rewriter.create<LLVM::ConstantOp>(loc, i32, value);
+  return LLVM::ConstantOp::create(rewriter, loc, i32, value);
 }
 
 static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
                               bool value) {
   Type llvmI1 = rewriter.getI1Type();
-  return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
+  return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value);
 }
 
 /// Returns the linear index used to access an element in the memref.
@@ -78,11 +78,11 @@ static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
         ShapedType::isDynamic(stride)
             ? convertUnsignedToI32(rewriter, loc,
                                    memRefDescriptor.stride(rewriter, loc, i))
-            : rewriter.create<LLVM::ConstantOp>(loc, i32, stride);
-    increment = rewriter.create<LLVM::MulOp>(loc, increment, strideValue);
+            : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
+    increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
   }
-  index =
-      index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
+  index = index ? LLVM::AddOp::create(rewriter, loc, index, increment)
+                : increment;
 }
 return index ? index : createI32Constant(rewriter, loc, 0);
 }
@@ -110,14 +110,14 @@ static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
   for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
     Value size = memrefDescriptor.size(rewriter, loc, i);
     Value stride = memrefDescriptor.stride(rewriter, loc, i);
-    Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
+    Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
     maxIndex = maxIndex
-                   ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
+                   ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
                    : maxThisDim;
   }
   Value maxIndexI32 = convertUnsignedToI32(rewriter, loc, maxIndex);
   Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
-  return rewriter.create<LLVM::MulOp>(loc, maxIndexI32, byteWidthConst);
+  return LLVM::MulOp::create(rewriter, loc, maxIndexI32, byteWidthConst);
 }
 
 static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
@@ -132,14 +132,14 @@ static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
   Value stride;
   if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {
     Value cacheStrideZext =
-        rewriter.create<LLVM::ZExtOp>(loc, i16, cacheSwizzleStride);
-    Value swizzleBit = rewriter.create<LLVM::ConstantOp>(
-        loc, i16, rewriter.getI16IntegerAttr(1 << 14));
-    stride = rewriter.create<LLVM::OrOp>(loc, cacheStrideZext, swizzleBit,
-                                         /*isDisjoint=*/true);
+        LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
+    Value swizzleBit = LLVM::ConstantOp::create(
+        rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
+    stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
+                                /*isDisjoint=*/true);
   } else {
-    stride = rewriter.create<LLVM::ConstantOp>(loc, i16,
-                                               rewriter.getI16IntegerAttr(0));
+    stride = LLVM::ConstantOp::create(rewriter, loc, i16,
+                                      rewriter.getI16IntegerAttr(0));
   }
   // Get the number of elements.
   // Flag word:
@@ -209,20 +209,21 @@ struct FatRawBufferCastLowering
           : descriptor.alignedPtr(rewriter, loc);
   Value offset = adaptor.getResetOffset()
-                     ? rewriter.create<LLVM::ConstantOp>(
-                           loc, getIndexType(), rewriter.getIndexAttr(0))
+                     ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+                                                rewriter.getIndexAttr(0))
                      : descriptor.offset(rewriter, loc);
   bool hasSizes = memrefType.getRank() > 0;
   // No need to unpack() and pack() all the individual sizes and strides,
   // so we'll just extract the arrays.
-  Value sizes = hasSizes ? rewriter.create<LLVM::ExtractValueOp>(
-                               loc, descriptor, kSizePosInMemRefDescriptor)
-                         : Value{};
-  Value strides = hasSizes
-                      ? rewriter.create<LLVM::ExtractValueOp>(
-                            loc, descriptor, kStridePosInMemRefDescriptor)
-                      : Value{};
+  Value sizes = hasSizes
+                    ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
+                                                   kSizePosInMemRefDescriptor)
+                    : Value{};
+  Value strides =
+      hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
+                                              kStridePosInMemRefDescriptor)
+               : Value{};
 
   Value fatPtr = makeBufferRsrc(
       rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
@@ -231,17 +232,17 @@ struct FatRawBufferCastLowering
   Value result = MemRefDescriptor::poison(
       rewriter, loc,
       getTypeConverter()->convertType(op.getResult().getType()));
-  result = rewriter.create<LLVM::InsertValueOp>(
-      loc, result, fatPtr, kAllocatedPtrPosInMemRefDescriptor);
-  result = rewriter.create<LLVM::InsertValueOp>(
-      loc, result, fatPtr, kAlignedPtrPosInMemRefDescriptor);
-  result = rewriter.create<LLVM::InsertValueOp>(loc, result, offset,
-                                                kOffsetPosInMemRefDescriptor);
+  SmallVector<int64_t> pos{kAllocatedPtrPosInMemRefDescriptor};
+  result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos);
+  result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
+                                       kAlignedPtrPosInMemRefDescriptor);
+  result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
+                                       kOffsetPosInMemRefDescriptor);
   if (hasSizes) {
-    result = rewriter.create<LLVM::InsertValueOp>(loc, result, sizes,
-                                                  kSizePosInMemRefDescriptor);
-    result = rewriter.create<LLVM::InsertValueOp>(
-        loc, result, strides, kStridePosInMemRefDescriptor);
+    result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes,
+                                         kSizePosInMemRefDescriptor);
+    result = LLVM::InsertValueOp::create(rewriter, loc, result, strides,
+                                         kStridePosInMemRefDescriptor);
   }
   rewriter.replaceOp(op, result);
   return success();
@@ -342,8 +343,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
   SmallVector<Value, 6> args;
   if (storeData) {
     if (llvmBufferValType != llvmWantedDataType) {
-      Value castForStore =
-          rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
+      Value castForStore = LLVM::BitcastOp::create(
+          rewriter, loc, llvmBufferValType, storeData);
       args.push_back(castForStore);
     } else {
       args.push_back(storeData);
@@ -352,8 +353,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
   if (atomicCmpData) {
     if (llvmBufferValType != llvmWantedDataType) {
-      Value castForCmp = rewriter.create<LLVM::BitcastOp>(
-          loc, llvmBufferValType, atomicCmpData);
+      Value castForCmp = LLVM::BitcastOp::create(
+          rewriter, loc, llvmBufferValType, atomicCmpData);
       args.push_back(castForCmp);
     } else {
       args.push_back(atomicCmpData);
@@ -382,18 +383,18 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
   if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
       indexOffset && *indexOffset > 0) {
     Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
-    voffset =
-        voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
-                : extraOffsetConst;
+    voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
+                                            extraOffsetConst)
+                      : extraOffsetConst;
   }
-  voffset = rewriter.create<LLVM::MulOp>(loc, voffset, byteWidthConst);
+  voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
   args.push_back(voffset);
 
   // SGPR offset.
   Value sgprOffset = adaptor.getSgprOffset();
   if (!sgprOffset)
     sgprOffset = createI32Constant(rewriter, loc, 0);
-  sgprOffset = rewriter.create<LLVM::MulOp>(loc, sgprOffset, byteWidthConst);
+  sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
   args.push_back(sgprOffset);
 
   // bit 0: GLC = 0 (atomics drop value, less coherency)
@@ -403,13 +404,13 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
   llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
                                          llvmBufferValType);
-  Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
-                                                  ArrayRef<NamedAttribute>());
+  Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
+                                         ArrayRef<NamedAttribute>());
   if (lowered->getNumResults() == 1) {
     Value replacement = lowered->getResult(0);
     if (llvmBufferValType != llvmWantedDataType) {
-      replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
-                                                     replacement);
+      replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
+                                            replacement);
     }
     rewriter.replaceOp(gpuOp, replacement);
   } else {
@@ -480,16 +481,16 @@ struct MemoryCounterWaitOpLowering
     if (chipset.majorVersion >= 12) {
       Location loc = op.getLoc();
       if (std::optional<int> ds = adaptor.getDs())
-        rewriter.create<ROCDL::WaitDscntOp>(loc, *ds);
+        ROCDL::WaitDscntOp::create(rewriter, loc, *ds);
 
       if (std::optional<int> load = adaptor.getLoad())
-        rewriter.create<ROCDL::WaitLoadcntOp>(loc, *load);
+        ROCDL::WaitLoadcntOp::create(rewriter, loc, *load);
 
       if (std::optional<int> store = adaptor.getStore())
-        rewriter.create<ROCDL::WaitStorecntOp>(loc, *store);
+        ROCDL::WaitStorecntOp::create(rewriter, loc, *store);
 
       if (std::optional<int> exp = adaptor.getExp())
-        rewriter.create<ROCDL::WaitExpcntOp>(loc, *exp);
+        ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);
 
       rewriter.eraseOp(op);
       return success();
@@ -571,12 +572,12 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
              << chipset.majorVersion;
 
       Location loc = op->getLoc();
-      rewriter.create<ROCDL::SWaitcntOp>(loc, ldsOnlyBits);
+      ROCDL::SWaitcntOp::create(rewriter, loc, ldsOnlyBits);
       rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op);
     } else {
       Location loc = op->getLoc();
-      rewriter.create<ROCDL::WaitDscntOp>(loc, 0);
-      rewriter.create<ROCDL::BarrierSignalOp>(loc, -1);
+      ROCDL::WaitDscntOp::create(rewriter, loc, 0);
+      ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
       rewriter.replaceOpWithNewOp<ROCDL::BarrierWaitOp>(op, -1);
     }
 
@@ -622,19 +623,21 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
   Type inputType = input.getType();
   if (auto vectorType = dyn_cast<VectorType>(inputType)) {
     if (vectorType.getElementType().isBF16() && !allowBf16)
-      return rewriter.create<LLVM::BitcastOp>(
-          loc, vectorType.clone(rewriter.getI16Type()), input);
+      return LLVM::BitcastOp::create(
+          rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
     if (vectorType.getElementType().isInteger(8) &&
         vectorType.getNumElements() <= 8)
-      return rewriter.create<LLVM::BitcastOp>(
-          loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
+      return LLVM::BitcastOp::create(
+          rewriter, loc,
+          rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
     if (isa<IntegerType>(vectorType.getElementType()) &&
         vectorType.getElementTypeBitWidth() <= 8) {
       int64_t numWords = llvm::divideCeil(
          vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
          32);
-      return rewriter.create<LLVM::BitcastOp>(
-          loc, VectorType::get(numWords, rewriter.getI32Type()), input);
+      return LLVM::BitcastOp::create(
+          rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
+          input);
     }
   }
   return input;
@@ -655,8 +658,8 @@ static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
   Type inputType = input.getType();
   Type outputType = rewriter.getI32Type();
   if (auto intType = dyn_cast<IntegerType>(inputType))
-    return rewriter.create<LLVM::ZExtOp>(loc, outputType, input);
-  return rewriter.create<LLVM::BitcastOp>(loc, outputType, input);
+    return LLVM::ZExtOp::create(rewriter, loc, outputType, input);
+  return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
 }
 
 /// Push an input operand. If it is a float type, nothing to do. If it is
@@ -682,8 +685,8 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
     Type elemType = vectorType.getElementType();
 
     if (elemType.isBF16())
-      llvmInput = rewriter.create<LLVM::BitcastOp>(
-          loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
+      llvmInput = LLVM::BitcastOp::create(
+          rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
     if (elemType.getIntOrFloatBitWidth() > 8) {
       operands.push_back(llvmInput);
       return;
@@ -719,7 +722,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
   // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
   // Add in the zeros here.
   if (numBits < 32)
-    castInput = rewriter.create<LLVM::ZExtOp>(loc, i32, castInput);
+    castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
   operands.push_back(castInput);
 }
 
@@ -739,8 +742,8 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
   auto vectorType = dyn_cast<VectorType>(inputType);
   Type elemType = vectorType.getElementType();
   if (elemType.isBF16())
-    output = rewriter.create<LLVM::BitcastOp>(
-        loc, vectorType.clone(rewriter.getI16Type()), output);
+    output = LLVM::BitcastOp::create(
+        rewriter, loc, vectorType.clone(rewriter.getI16Type()), output);
   operands.push_back(output);
   if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
     operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
@@ -1098,7 +1101,7 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
     };
     Value lowered = rewriter.create(loweredOp)->getResult(0);
     if (outType != intrinsicOutType)
-      lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
+      lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
     rewriter.replaceOp(op, lowered);
     return success();
   }
@@ -1198,8 +1201,8 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
 
     Operation *maybeCastBack = lowered;
     if (rawOutType != outType)
-      maybeCastBack =
-          rewriter.create<LLVM::BitcastOp>(loc, outType, lowered->getResult(0));
+      maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
+                                              lowered->getResult(0));
     rewriter.replaceOp(op, maybeCastBack->getResults());
 
     return success();
@@ -1249,22 +1252,22 @@ struct TransposeLoadOpLowering
     switch (elementTypeSize) {
     case 4: {
       assert(numElements == 16);
-      auto rocdlOp =
-          rewriter.create<ROCDL::ds_read_tr4_b64>(loc, rocdlResultType, srcPtr);
+      auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
+                                                    rocdlResultType, srcPtr);
       rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
       break;
     }
     case 6: {
       assert(numElements == 16);
-      auto rocdlOp =
-          rewriter.create<ROCDL::ds_read_tr6_b96>(loc, rocdlResultType, srcPtr);
+      auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
+                                                    rocdlResultType, srcPtr);
       rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
       break;
     }
     case 8: {
       assert(numElements == 8);
-      auto rocdlOp =
-          rewriter.create<ROCDL::ds_read_tr8_b64>(loc, rocdlResultType, srcPtr);
+      auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
+                                                    rocdlResultType, srcPtr);
       rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
       break;
     }
@@ -1422,21 +1425,21 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
   Type sourceElemType = getElementTypeOrSelf(op.getSource());
   // Extend to a v4i8
   if (!sourceVecType || sourceVecType.getNumElements() < 4) {
-    Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
+    Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
     if (!sourceVecType) {
-      longVec = rewriter.create<LLVM::InsertElementOp>(
-          loc, longVec, source, createI32Constant(rewriter, loc, 0));
+      longVec = LLVM::InsertElementOp::create(
+          rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
     } else {
       for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
         Value idx = createI32Constant(rewriter, loc, i);
-        Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
+        Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
         longVec =
-            rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
+            LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
       }
     }
     source = longVec;
   }
-  Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
+  Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
   if (resultVecType) {
     if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
       rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
@@ -1488,21 +1491,21 @@ LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
 
   // Extend to a packedVectorType
   if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
-    Value longVec = rewriter.create<LLVM::ZeroOp>(loc, packedVecType);
+    Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
     if (!sourceVecType) {
-      longVec = rewriter.create<LLVM::InsertElementOp>(
-          loc, longVec, source, createI32Constant(rewriter, loc, 0));
+      longVec = LLVM::InsertElementOp::create(
+          rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
     } else {
       for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
        Value idx = createI32Constant(rewriter, loc, i);
-        Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
+        Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
        longVec =
-            rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
+            LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
      }
    }
    source = longVec;
  }
-  Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
+  Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
 
   if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
     rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
@@ -1560,54 +1563,57 @@ LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
   Value scale = adaptor.getScale();
   Value existing = adaptor.getExisting();
   if (existing)
-    existing = rewriter.create<LLVM::BitcastOp>(loc, intResultType, existing);
+    existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
   else
-    existing = rewriter.create<LLVM::ZeroOp>(loc, intResultType);
 
+    existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);
   if (sourceVecType.getNumElements() < 2) {
     Value c0 = createI32Constant(rewriter, loc, 0);
-    Value elem0 = rewriter.create<LLVM::ExtractElementOp>(loc, source, c0);
+    Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
     VectorType v2 = VectorType::get(2, sourceElemType);
-    source = rewriter.create<LLVM::ZeroOp>(loc, v2);
-    source = rewriter.create<LLVM::InsertElementOp>(loc, source, elem0, c0);
+    source = LLVM::ZeroOp::create(rewriter, loc, v2);
+    source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
   }
 
   Value sourceA, sourceB;
   if (sourceElemType.isF32()) {
     Value c0 = createI32Constant(rewriter, loc, 0);
     Value c1 = createI32Constant(rewriter, loc, 1);
-    sourceA = rewriter.create<LLVM::ExtractElementOp>(loc, source, c0);
-    sourceB = rewriter.create<LLVM::ExtractElementOp>(loc, source, c1);
+    sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
+    sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
   }
 
   Value result;
   if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
-    result = rewriter.create<ROCDL::CvtScaleF32PkBf8F32Op>(
-        loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
+    result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
+                                                  existing, sourceA, sourceB,
+                                                  scale, op.getIndex());
   else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
-    result = rewriter.create<ROCDL::CvtScaleF32PkBf8F16Op>(
-        loc, intResultType, existing, source, scale, op.getIndex());
+    result = ROCDL::CvtScaleF32PkBf8F16Op::create(
+        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
   else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
-    result = rewriter.create<ROCDL::CvtScaleF32PkBf8Bf16Op>(
-        loc, intResultType, existing, source, scale, op.getIndex());
+    result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
+        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
   else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
-    result = rewriter.create<ROCDL::CvtScaleF32PkFp8F32Op>(
-        loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
+    result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
+                                                  existing, sourceA, sourceB,
+                                                  scale, op.getIndex());
   else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
-    result = rewriter.create<ROCDL::CvtScaleF32PkFp8F16Op>(
-        loc, intResultType, existing, source, scale, op.getIndex());
+    result = ROCDL::CvtScaleF32PkFp8F16Op::create(
+        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
   else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
-    result = rewriter.create<ROCDL::CvtScaleF32PkFp8Bf16Op>(
-        loc, intResultType, existing, source, scale, op.getIndex());
+    result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
+        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
   else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
-    result = rewriter.create<ROCDL::CvtScaleF32PkFp4F32Op>(
-        loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
+    result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
+                                                  existing, sourceA, sourceB,
+                                                  scale, op.getIndex());
   else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
-    result = rewriter.create<ROCDL::CvtScaleF32PkFp4F16Op>(
-        loc, intResultType, existing, source, scale, op.getIndex());
+    result = ROCDL::CvtScaleF32PkFp4F16Op::create(
+        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
   else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
-    result = rewriter.create<ROCDL::CvtScaleF32PkFp4Bf16Op>(
-        loc, intResultType, existing, source, scale, op.getIndex());
+    result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
+        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
   else
     return failure();
 
@@ -1632,20 +1638,20 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
   Value sourceA = adaptor.getSourceA();
   Value sourceB = adaptor.getSourceB();
   if (!sourceB)
-    sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
+    sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType());
   Value existing = adaptor.getExisting();
   if (existing)
-    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
+    existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
   else
-    existing = rewriter.create<LLVM::UndefOp>(loc, i32);
+    existing = LLVM::UndefOp::create(rewriter, loc, i32);
 
   Value result;
   if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
-    result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
-                                                   existing, op.getWordIndex());
+    result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
+                                          existing, op.getWordIndex());
   else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
-    result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
-                                                   existing, op.getWordIndex());
+    result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
+                                          existing, op.getWordIndex());
 
   result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
       op, getTypeConverter()->convertType(resultType), result);
@@ -1669,17 +1675,17 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
   Value stoch = adaptor.getStochiasticParam();
   Value existing = adaptor.getExisting();
   if (existing)
-    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
+    existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
   else
-    existing = rewriter.create<LLVM::UndefOp>(loc, i32);
+    existing = LLVM::UndefOp::create(rewriter, loc, i32);
 
   Value result;
   if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
-    result = rewriter.create<ROCDL::CvtSrBf8F32Op>(
-        loc, i32, source, stoch, existing, op.getStoreIndex());
+    result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
+                                          existing, op.getStoreIndex());
   else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
-    result = rewriter.create<ROCDL::CvtSrFp8F32Op>(
-        loc, i32, source, stoch, existing, op.getStoreIndex());
+    result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
+                                          existing, op.getStoreIndex());
 
   result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
       op, getTypeConverter()->convertType(resultType), result);
@@ -1723,14 +1729,15 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
       if (operandType.getIntOrFloatBitWidth() <= 16) {
        if (llvm::isa<FloatType>(operandType)) {
          operand =
-              rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
+              LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
        }
        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
-        Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
-        operand = rewriter.create<LLVM::InsertElementOp>(
-            loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
-        operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand);
+        Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
+        operand =
+            LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
+                                          createI32Constant(rewriter, loc, 0));
+        operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
       }
       return operand;
     };
@@ -1817,14 +1824,15 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
     bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();
 
     // create a ROCDL_DPPMovOp instruction with the appropriate attributes
-    auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
-        loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
+    auto dppMovOp =
+        ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
+                                   rowMask, bankMask, boundCtrl);
 
     Value result = dppMovOp.getRes();
     if (srcType.getIntOrFloatBitWidth() < 32) {
-      result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
+      result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result);
       if (!llvm::isa<IntegerType>(srcType)) {
-        result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
+        result = LLVM::BitcastOp::create(rewriter, loc, srcType, result);
       }
     }
 
@@ -1858,7 +1866,7 @@ struct AMDGPUSwizzleBitModeLowering
     SmallVector<Value> swizzled;
     for (Value v : decomposed) {
       Value res =
-          rewriter.create<ROCDL::DsSwizzleOp>(loc, v.getType(), v, maskValue);
+          ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
       swizzled.emplace_back(res);
     }
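The buffer lowerings above emit straight-line IR for strided addressing: getLinearIndexI32 accumulates index_i * stride_i per dimension, and getNumRecords takes max_i(size_i * stride_i) scaled by the element byte width as the resource extent. The scalar arithmetic those helpers materialize as LLVM dialect ops is, in plain C++ (illustrative sketch only, not MLIR code; linearIndex is hypothetical):

  #include <cstdint>
  #include <vector>

  // Linear element offset of a strided memref access: dot(indices, strides).
  int32_t linearIndex(const std::vector<int32_t> &indices,
                      const std::vector<int32_t> &strides) {
    int32_t index = 0;
    for (size_t i = 0; i < indices.size(); ++i)
      index += indices[i] * strides[i];
    return index;
  }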
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index 3b143ca..3b148f9 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -50,9 +50,9 @@ static Value buildMinMaxReductionSeq(Location loc,
   Value value = *valueIt++;
   for (; valueIt != values.end(); ++valueIt) {
     if (predicate == arith::CmpIPredicate::sgt)
-      value = builder.create<arith::MaxSIOp>(loc, value, *valueIt);
+      value = arith::MaxSIOp::create(builder, loc, value, *valueIt);
     else
-      value = builder.create<arith::MinSIOp>(loc, value, *valueIt);
+      value = arith::MinSIOp::create(builder, loc, value, *valueIt);
   }
 
   return value;
@@ -154,9 +154,9 @@ public:
     Value lowerBound = lowerAffineLowerBound(op, rewriter);
     Value upperBound = lowerAffineUpperBound(op, rewriter);
     Value step =
-        rewriter.create<arith::ConstantIndexOp>(loc, op.getStepAsInt());
-    auto scfForOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound,
-                                                step, op.getInits());
+        arith::ConstantIndexOp::create(rewriter, loc, op.getStepAsInt());
+    auto scfForOp = scf::ForOp::create(rewriter, loc, lowerBound, upperBound,
+                                       step, op.getInits());
     rewriter.eraseBlock(scfForOp.getBody());
     rewriter.inlineRegionBefore(op.getRegion(), scfForOp.getRegion(),
                                 scfForOp.getRegion().end());
@@ -197,7 +197,7 @@ public:
     }
     steps.reserve(op.getSteps().size());
     for (int64_t step : op.getSteps())
-      steps.push_back(rewriter.create<arith::ConstantIndexOp>(loc, step));
+      steps.push_back(arith::ConstantIndexOp::create(rewriter, loc, step));
 
     // Get the terminator op.
     auto affineParOpTerminator =
@@ -205,9 +205,9 @@ public:
     scf::ParallelOp parOp;
     if (op.getResults().empty()) {
       // Case with no reduction operations/return values.
-      parOp = rewriter.create<scf::ParallelOp>(loc, lowerBoundTuple,
-                                               upperBoundTuple, steps,
-                                               /*bodyBuilderFn=*/nullptr);
+      parOp = scf::ParallelOp::create(rewriter, loc, lowerBoundTuple,
+                                      upperBoundTuple, steps,
+                                      /*bodyBuilderFn=*/nullptr);
       rewriter.eraseBlock(parOp.getBody());
       rewriter.inlineRegionBefore(op.getRegion(), parOp.getRegion(),
                                   parOp.getRegion().end());
@@ -233,9 +233,9 @@ public:
       identityVals.push_back(
          arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc));
     }
-    parOp = rewriter.create<scf::ParallelOp>(
-        loc, lowerBoundTuple, upperBoundTuple, steps, identityVals,
-        /*bodyBuilderFn=*/nullptr);
+    parOp = scf::ParallelOp::create(rewriter, loc, lowerBoundTuple,
+                                    upperBoundTuple, steps, identityVals,
+                                    /*bodyBuilderFn=*/nullptr);
 
     // Copy the body of the affine.parallel op.
     rewriter.eraseBlock(parOp.getBody());
@@ -261,7 +261,7 @@ public:
       Value reductionResult = arith::getReductionOp(
          reductionOpValue, rewriter, loc, reductionBody.getArgument(0),
          reductionBody.getArgument(1));
-      rewriter.create<scf::ReduceReturnOp>(loc, reductionResult);
+      scf::ReduceReturnOp::create(rewriter, loc, reductionResult);
     }
     rewriter.replaceOp(op, parOp.getResults());
     return success();
@@ -278,7 +278,7 @@ public:
 
     // Now we just have to handle the condition logic.
     auto integerSet = op.getIntegerSet();
-    Value zeroConstant = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value zeroConstant = arith::ConstantIndexOp::create(rewriter, loc, 0);
     SmallVector<Value, 8> operands(op.getOperands());
     auto operandsRef = llvm::ArrayRef(operands);
@@ -298,18 +298,18 @@ public:
       auto pred =
          isEquality ? arith::CmpIPredicate::eq : arith::CmpIPredicate::sge;
       Value cmpVal =
-          rewriter.create<arith::CmpIOp>(loc, pred, affResult, zeroConstant);
-      cond = cond
-                 ? rewriter.create<arith::AndIOp>(loc, cond, cmpVal).getResult()
-                 : cmpVal;
+          arith::CmpIOp::create(rewriter, loc, pred, affResult, zeroConstant);
+      cond =
+          cond ? arith::AndIOp::create(rewriter, loc, cond, cmpVal).getResult()
+               : cmpVal;
     }
 
     cond = cond ? cond
-                : rewriter.create<arith::ConstantIntOp>(loc, /*value=*/1,
-                                                        /*width=*/1);
+                : arith::ConstantIntOp::create(rewriter, loc, /*value=*/1,
+                                               /*width=*/1);
 
     bool hasElseRegion = !op.getElseRegion().empty();
-    auto ifOp = rewriter.create<scf::IfOp>(loc, op.getResultTypes(), cond,
-                                           hasElseRegion);
+    auto ifOp = scf::IfOp::create(rewriter, loc, op.getResultTypes(), cond,
+                                  hasElseRegion);
     rewriter.inlineRegionBefore(op.getThenRegion(),
                                 &ifOp.getThenRegion().back());
     rewriter.eraseBlock(&ifOp.getThenRegion().back());
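AffineIfLowering, just above, flattens an integer-set condition into arith ops: each constraint result is compared against zero (eq for equalities, sge for inequalities), the comparisons are AND-chained, and an empty set degenerates to a true constant. The equivalent scalar logic, as an illustrative C++ sketch (evalIntegerSet is hypothetical, not part of the patch):

  #include <vector>

  // True iff every integer-set constraint holds; mirrors the cmpi/andi chain.
  bool evalIntegerSet(const std::vector<long> &results,
                      const std::vector<bool> &isEquality) {
    bool cond = true; // an empty set lowers to a constant true
    for (size_t i = 0; i < results.size(); ++i)
      cond = cond && (isEquality[i] ? results[i] == 0 : results[i] >= 0);
    return cond;
  }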
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 156c679..8230591 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -49,8 +49,9 @@ struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
   using OpRewritePattern::OpRewritePattern;
 
   Chipset chipset;
-  ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
-      : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
+  ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset,
+                             PatternBenefit benefit)
+      : OpRewritePattern::OpRewritePattern(ctx, benefit), chipset(chipset) {}
 
   LogicalResult matchAndRewrite(arith::ExtFOp op,
                                 PatternRewriter &rewriter) const override;
@@ -59,9 +60,9 @@ struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
   bool saturateFP8 = false;
   TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
-                               Chipset chipset)
-      : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
-        chipset(chipset) {}
+                               Chipset chipset, PatternBenefit benefit)
+      : OpRewritePattern::OpRewritePattern(ctx, benefit),
+        saturateFP8(saturateFP8), chipset(chipset) {}
 
   Chipset chipset;
   LogicalResult matchAndRewrite(arith::TruncFOp op,
@@ -81,9 +82,6 @@ struct ScalingExtFRewritePattern final
     : OpRewritePattern<arith::ScalingExtFOp> {
   using OpRewritePattern::OpRewritePattern;
 
-  ScalingExtFRewritePattern(MLIRContext *ctx)
-      : OpRewritePattern::OpRewritePattern(ctx) {}
-
   LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
                                 PatternRewriter &rewriter) const override;
 };
@@ -92,9 +90,6 @@ struct ScalingTruncFRewritePattern final
     : OpRewritePattern<arith::ScalingTruncFOp> {
   using OpRewritePattern::OpRewritePattern;
 
-  ScalingTruncFRewritePattern(MLIRContext *ctx)
-      : OpRewritePattern::OpRewritePattern(ctx) {}
-
   LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
                                 PatternRewriter &rewriter) const override;
 };
@@ -115,9 +110,9 @@ static Value castF32To(Type desType, Value f32, Location loc,
   if (elementType.isF32())
     return f32;
   if (elementType.getIntOrFloatBitWidth() < 32)
-    return rewriter.create<arith::TruncFOp>(loc, desType, f32);
+    return arith::TruncFOp::create(rewriter, loc, desType, f32);
   if (elementType.getIntOrFloatBitWidth() > 32)
-    return rewriter.create<arith::ExtFOp>(loc, desType, f32);
+    return arith::ExtFOp::create(rewriter, loc, desType, f32);
   llvm_unreachable("The only 32-bit float type is f32");
 }
 
@@ -139,27 +134,27 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
   VectorType extResType = VectorType::get(2, rewriter.getF32Type());
   if (!inVecType) {
-    Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
-        loc, rewriter.getF32Type(), in, 0);
+    Value asFloat = amdgpu::ExtPackedFp8Op::create(
+        rewriter, loc, rewriter.getF32Type(), in, 0);
     Value result = castF32To(outElemType, asFloat, loc, rewriter);
     rewriter.replaceOp(op, result);
     return success();
   }
 
   int64_t numElements = inVecType.getNumElements();
 
-  Value zero = rewriter.create<arith::ConstantOp>(
-      loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
+  Value zero = arith::ConstantOp::create(
+      rewriter, loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
   VectorType outType = cast<VectorType>(op.getOut().getType());
 
   if (inVecType.getShape().empty()) {
     Value zerodSplat =
        rewriter.createOrFold<vector::BroadcastOp>(loc, outType, zero);
     Value scalarIn =
-        rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
+        vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{});
     Value scalarExt =
-        rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
-    Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zerodSplat,
-                                                     ArrayRef<int64_t>{});
+        arith::ExtFOp::create(rewriter, loc, outElemType, scalarIn);
+    Value result = vector::InsertOp::create(rewriter, loc, scalarExt,
+                                            zerodSplat, ArrayRef<int64_t>{});
     rewriter.replaceOp(op, result);
     return success();
   }
@@ -171,32 +166,32 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
   if (inVecType.getRank() > 1) {
     inVecType = VectorType::get(SmallVector<int64_t>{numElements},
                                 inVecType.getElementType());
-    in = rewriter.create<vector::ShapeCastOp>(loc, inVecType, in);
+    in = vector::ShapeCastOp::create(rewriter, loc, inVecType, in);
   }
 
   for (int64_t i = 0; i < numElements; i += 4) {
     int64_t elemsThisOp = std::min(numElements, i + 4) - i;
-    Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
-        loc, in, i, elemsThisOp, 1);
+    Value inSlice = vector::ExtractStridedSliceOp::create(rewriter, loc, in, i,
+                                                          elemsThisOp, 1);
     for (int64_t j = 0; j < elemsThisOp; j += 2) {
       if (i + j + 1 < numElements) {
         // Convert two 8-bit elements
-        Value asFloats = rewriter.create<amdgpu::ExtPackedFp8Op>(
-            loc, extResType, inSlice, j / 2);
+        Value asFloats = amdgpu::ExtPackedFp8Op::create(
+            rewriter, loc, extResType, inSlice, j / 2);
         Type desType = VectorType::get(2, outElemType);
         Value asType = castF32To(desType, asFloats, loc, rewriter);
-        result = rewriter.create<vector::InsertStridedSliceOp>(
-            loc, asType, result, i + j, 1);
+        result = vector::InsertStridedSliceOp::create(rewriter, loc, asType,
+                                                      result, i + j, 1);
       } else {
         // Convert a 8-bit element
-        Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
-            loc, rewriter.getF32Type(), inSlice, j / 2 * 2);
+        Value asFloat = amdgpu::ExtPackedFp8Op::create(
+            rewriter, loc, rewriter.getF32Type(), inSlice, j / 2 * 2);
         Value asType = castF32To(outElemType, asFloat, loc, rewriter);
-        result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
+        result = vector::InsertOp::create(rewriter, loc, asType, result, i + j);
       }
     }
   }
 
   if (inVecType.getRank() != outType.getRank()) {
-    result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
+    result = vector::ShapeCastOp::create(rewriter, loc, outType, result);
   }
 
   rewriter.replaceOp(op, result);
@@ -208,9 +203,9 @@ static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
   if (type.isF32())
     return value;
   if (type.getIntOrFloatBitWidth() < 32)
-    return rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), value);
+    return arith::ExtFOp::create(rewriter, loc, rewriter.getF32Type(), value);
   if (type.getIntOrFloatBitWidth() > 32)
-    return rewriter.create<arith::TruncFOp>(loc, rewriter.getF32Type(), value);
+    return arith::TruncFOp::create(rewriter, loc, rewriter.getF32Type(),
+                                   value);
   llvm_unreachable("The only 32-bit float type is f32");
 }
 
@@ -250,13 +245,15 @@ static Value clampInput(PatternRewriter &rewriter, Location loc,
       loc, arith::CmpFPredicate::OEQ, source, negInf);
   Value isNan = rewriter.createOrFold<arith::CmpFOp>(
       loc, arith::CmpFPredicate::UNO, source, source);
-  Value isNonFinite = rewriter.create<arith::OrIOp>(
-      loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan);
+  Value isNonFinite = arith::OrIOp::create(
+      rewriter, loc, arith::OrIOp::create(rewriter, loc, isInf, isNegInf),
+      isNan);
 
-  Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst);
-  Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
+  Value clampedBelow = arith::MaximumFOp::create(rewriter, loc, source, minCst);
+  Value clamped =
+      arith::MinimumFOp::create(rewriter, loc, clampedBelow, maxCst);
   Value res =
-      rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped);
+      arith::SelectOp::create(rewriter, loc, isNonFinite, source, clamped);
   return res;
 }
 
@@ -290,25 +287,25 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
   VectorType truncResType = VectorType::get(4, outElemType);
   if (!inVectorTy) {
     Value asFloat = castToF32(in, loc, rewriter);
-    Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
-        loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
+    Value asF8s = amdgpu::PackedTrunc2xFp8Op::create(
+        rewriter, loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
         /*existing=*/nullptr);
-    Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
+    Value result = vector::ExtractOp::create(rewriter, loc, asF8s, 0);
     rewriter.replaceOp(op, result);
     return success();
   }
 
   int64_t numElements = outVecType.getNumElements();
-  Value zero = rewriter.create<arith::ConstantOp>(
-      loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
+  Value zero = arith::ConstantOp::create(
+      rewriter, loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
 
   if (outVecType.getShape().empty()) {
     Value scalarIn =
-        rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
+        vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{});
     // Recurse to send the 0-D vector case to the 1-D vector case
     Value scalarTrunc =
-        rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
-    Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
-                                                     ArrayRef<int64_t>{});
+        arith::TruncFOp::create(rewriter, loc, outElemType, scalarIn);
+    Value result = vector::InsertOp::create(rewriter, loc, scalarTrunc, zero,
+                                            ArrayRef<int64_t>{});
     rewriter.replaceOp(op, result);
     return success();
   }
@@ -320,32 +317,32 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
   if (inVectorTy.getRank() > 1) {
     inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
                                  inVectorTy.getElementType());
-    in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
+    in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in);
   }
 
   for (int64_t i = 0; i < numElements; i += 4) {
     int64_t elemsThisOp = std::min(numElements, i + 4) - i;
     Value thisResult = nullptr;
     for (int64_t j = 0; j < elemsThisOp; j += 2) {
-      Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j);
+      Value elemA = vector::ExtractOp::create(rewriter, loc, in, i + j);
       Value asFloatA = castToF32(elemA, loc, rewriter);
       Value asFloatB = nullptr;
       if (j + 1 < elemsThisOp) {
-        Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1);
+        Value elemB = vector::ExtractOp::create(rewriter, loc, in, i + j + 1);
         asFloatB = castToF32(elemB, loc, rewriter);
       }
-      thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
-          loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
+      thisResult = amdgpu::PackedTrunc2xFp8Op::create(
+          rewriter, loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
     }
     if (elemsThisOp < 4)
-      thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, thisResult, 0, elemsThisOp, 1);
-    result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
-                                                           result, i, 1);
+      thisResult = vector::ExtractStridedSliceOp::create(
+          rewriter, loc, thisResult, 0, elemsThisOp, 1);
+    result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult,
+                                                  result, i, 1);
   }
 
   if (inVectorTy.getRank() != outVecType.getRank()) {
-    result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result);
+    result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result);
   }
 
   rewriter.replaceOp(op, result);
@@ -373,10 +370,10 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
 
   // Handle the case where input type is not a vector type
   if (!inVectorTy) {
-    auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
+    auto sourceB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type());
     Value asF16s =
-        rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
-    Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0);
+        ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, in, sourceB);
+    Value result = vector::ExtractOp::create(rewriter, loc, asF16s, 0);
     rewriter.replaceOp(op, result);
     return success();
   }
@@ -389,7 +386,7 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
   if (inVectorTy.getRank() > 1) {
     inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
                                  inVectorTy.getElementType());
-    in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
+    in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in);
   }
 
   // Handle the vector case. We also handle the (uncommon) case where the vector
@@ -397,25 +394,25 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
   for (int64_t i = 0; i < numElements; i += 2) {
     int64_t elemsThisOp = std::min(numElements, i + 2) - i;
     Value thisResult = nullptr;
-    Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i);
-    Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
+    Value elemA = vector::ExtractOp::create(rewriter, loc, in, i);
+    Value elemB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type());
 
     if (elemsThisOp == 2) {
-      elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1);
+      elemB = vector::ExtractOp::create(rewriter, loc, in, i + 1);
     }
 
     thisResult =
-        rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB);
+        ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, elemA, elemB);
 
     // Place back the truncated result into the possibly larger vector. If we
     // are operating on a size 2 vector, these operations should be folded away
-    thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
-        loc, thisResult, 0, elemsThisOp, 1);
-    result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
-                                                           result, i, 1);
+    thisResult = vector::ExtractStridedSliceOp::create(
+        rewriter, loc, thisResult, 0, elemsThisOp, 1);
+    result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult,
+                                                  result, i, 1);
   }
 
   if (inVectorTy.getRank() != outVecType.getRank()) {
-    result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result);
+    result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result);
   }
 
   rewriter.replaceOp(op, result);
@@ -452,7 +449,7 @@ LogicalResult
 ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
                                            PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
-  constexpr int64_t opWidth = 2;
+  constexpr int64_t opOutWidth = 2;
 
   Value in = op.getIn();
   Value scale = op.getScale();
@@ -463,6 +460,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
   Type scaleType = getElementTypeOrSelf(scale);
   Type outType = getElementTypeOrSelf(out);
 
+  int64_t opInWidth = 32 / inType.getIntOrFloatBitWidth();
+
   VectorType outVecType = dyn_cast<VectorType>(out.getType());
   VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
 
@@ -472,28 +471,29 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
   Type scaleF32Type =
       scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
   if (scaleType.getIntOrFloatBitWidth() < 32)
-    scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
+    scale = arith::ExtFOp::create(rewriter, loc, scaleF32Type, scale);
   else if (scaleType.getIntOrFloatBitWidth() > 32)
-    scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);
+    scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);
 
-  VectorType extScaleResultType = VectorType::get(opWidth, outType);
+  VectorType extScaleResultType = VectorType::get(opOutWidth, outType);
 
   if (!outVecType) {
-    Value inCast = rewriter.create<vector::BroadcastOp>(
-        loc, VectorType::get(1, inType), in);
+    Value inCast = vector::BroadcastOp::create(rewriter, loc,
+                                               VectorType::get(1, inType), in);
     // TODO: replace this with non-packed ScaledExtOp
-    Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
-        loc, extScaleResultType, inCast, scale, 0);
+    Value scaleExt = amdgpu::ScaledExtPackedOp::create(
+        rewriter, loc, extScaleResultType, inCast, scale, 0);
     scaleExt = rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleExt, 0);
     return success();
   }
 
   VectorType inVecType = cast<VectorType>(in.getType());
   Value origScale = getOriginalVectorValue(op.getScale());
+  VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());
 
   ArrayRef<int64_t> inShape = inVecType.getShape();
   SmallVector<int64_t> originalScaleShape;
-  if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
+  if (origScaleVecType)
     llvm::append_range(originalScaleShape, origScaleVecType.getShape());
 
   originalScaleShape.insert(originalScaleShape.end(),
@@ -508,45 +508,52 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
 
   int64_t blockSize = computeProduct(ratio);
 
-  Value zero = rewriter.create<arith::ConstantOp>(
-      loc, outType, rewriter.getFloatAttr(outType, 0.0));
+  Value zero = arith::ConstantOp::create(rewriter, loc, outType,
+                                         rewriter.getFloatAttr(outType, 0.0));
   Value result =
       rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
 
   for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
     SmallVector<int64_t> strides(offsets.size(), 1);
-    Value block = rewriter.create<vector::ExtractStridedSliceOp>(
-        loc, in, offsets, ratio, strides);
+    Value block = vector::ExtractStridedSliceOp::create(
+        rewriter, loc, in, offsets, ratio, strides);
     VectorType block1DType = VectorType::get(blockSize, inType);
     Value block1D =
-        rewriter.create<vector::ShapeCastOp>(loc, block1DType, block);
+        vector::ShapeCastOp::create(rewriter, loc, block1DType, block);
     Value uniformScale =
-        rewriter.create<vector::ExtractOp>(loc, scale, offsets);
+        vector::ExtractOp::create(rewriter, loc, scale, offsets);
 
     VectorType blockResultType = VectorType::get(blockSize, outType);
     Value blockResult =
         rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
 
-    for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
+    for (int64_t i = 0, inSliceWidth = std::min(opInWidth, blockSize - i);
          i < blockSize;
-         i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
-      Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, block1D, i, sliceWidth, 1);
-      // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
-      Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
-          loc, extScaleResultType, slice, uniformScale, 0);
-      if (sliceWidth != opWidth)
-        scaleExt = rewriter.create<vector::ExtractStridedSliceOp>(
-            loc, scaleExt, 0, sliceWidth, 1);
-      blockResult = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, scaleExt, blockResult, i, 1);
+         i += inSliceWidth, inSliceWidth = std::min(opInWidth, blockSize - i)) {
+      Value inSlice = vector::ExtractStridedSliceOp::create(
+          rewriter, loc, block1D, i, inSliceWidth, 1);
+      for (int64_t j = 0,
+                   outSliceWidth = std::min(opOutWidth, inSliceWidth - j);
+           j < inSliceWidth; j += outSliceWidth,
+                   outSliceWidth = std::min(opOutWidth, inSliceWidth - j)) {
+        // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
+        Value scaleExt = amdgpu::ScaledExtPackedOp::create(
+            rewriter, loc, extScaleResultType, inSlice, uniformScale,
+            j / opOutWidth);
+        if (outSliceWidth < opOutWidth) {
+          scaleExt = vector::ExtractStridedSliceOp::create(
+              rewriter, loc, scaleExt, 0, outSliceWidth, 1);
+        }
+        blockResult = vector::InsertStridedSliceOp::create(
+            rewriter, loc, scaleExt, blockResult, i + j, 1);
+      }
     }
 
     VectorType resultType = VectorType::get(ratio, outType);
     Value cast =
-        rewriter.create<vector::ShapeCastOp>(loc, resultType, blockResult);
-    result = rewriter.create<vector::InsertStridedSliceOp>(loc, cast, result,
-                                                           offsets, strides);
+        vector::ShapeCastOp::create(rewriter, loc, resultType, blockResult);
+    result = vector::InsertStridedSliceOp::create(rewriter, loc, cast, result,
+                                                  offsets, strides);
   }
 
   rewriter.replaceOp(op, result);
@@ -558,7 +565,7 @@ LogicalResult
 ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
                                              PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
-  constexpr int64_t opWidth = 2;
+  constexpr int64_t opInWidth = 2;
 
   Value in = op.getIn();
   Value scale = op.getScale();
@@ -571,28 +578,28 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
   VectorType outVecType = dyn_cast<VectorType>(out.getType());
   VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
-
   if (outVecType && outVecType.isScalable())
     return failure();
 
   Type scaleF32Type =
       scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
   if (scaleType.getIntOrFloatBitWidth() < 32)
-    scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
+    scale = arith::ExtFOp::create(rewriter, loc, scaleF32Type, scale);
   else if (scaleType.getIntOrFloatBitWidth() > 32)
-    scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);
+    scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);
 
-  Value zero = rewriter.create<arith::ConstantOp>(
-      loc, outType, rewriter.getFloatAttr(outType, 0.0));
-  unsigned numPackedElem = 32 / outType.getIntOrFloatBitWidth();
-  VectorType truncScaleResultType = VectorType::get(numPackedElem, outType);
+  Value zero = arith::ConstantOp::create(rewriter, loc, outType,
+                                         rewriter.getFloatAttr(outType, 0.0));
+  int64_t opOutWidth = 32 / outType.getIntOrFloatBitWidth();
+  VectorType truncScaleResultType = VectorType::get(opOutWidth, outType);
 
   if (!outVecType) {
     Type inVecType = VectorType::get(1, inType);
-    Value inCast = rewriter.create<vector::BroadcastOp>(loc, inVecType, in);
+    Value inCast = vector::BroadcastOp::create(rewriter, loc, inVecType, in);
     // TODO: replace this with non-packed ScaledTruncOp
-    Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
-        loc, truncScaleResultType, inCast, scale, 0, /*existing=*/nullptr);
+    Value scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+        rewriter, loc, truncScaleResultType, inCast, scale, 0,
+        /*existing=*/nullptr);
     scaleTrunc =
         rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleTrunc, 0);
     return success();
@@ -600,16 +607,16 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
   VectorType inVecType = cast<VectorType>(in.getType());
   Value origScale = getOriginalVectorValue(op.getScale());
+  VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());
 
   ArrayRef<int64_t> inShape = inVecType.getShape();
-  SmallVector<int64_t> originalScaleShape;
-  if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
-    llvm::append_range(originalScaleShape, origScaleVecType.getShape());
+  SmallVector<int64_t> scaleShape;
+  if (origScaleVecType)
+    llvm::append_range(scaleShape, origScaleVecType.getShape());
 
-  originalScaleShape.insert(originalScaleShape.end(),
-                            inShape.size() - originalScaleShape.size(), 1);
+  scaleShape.insert(scaleShape.end(), inShape.size() - scaleShape.size(), 1);
 
-  auto maybeRatio = computeShapeRatio(inShape, originalScaleShape);
+  auto maybeRatio = computeShapeRatio(inShape, scaleShape);
   assert(maybeRatio &&
          "failed to derive block size from broadcast or splat operation");
 
@@ -623,41 +630,57 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
   for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
     SmallVector<int64_t> strides(offsets.size(), 1);
-    Value block = rewriter.create<vector::ExtractStridedSliceOp>(
-        loc, in, offsets, ratio, strides);
+    Value block = vector::ExtractStridedSliceOp::create(
+        rewriter, loc, in, offsets, ratio, strides);
     VectorType block1DType = VectorType::get(blockSize, inType);
     Value block1D =
-        rewriter.create<vector::ShapeCastOp>(loc, block1DType, block);
+        vector::ShapeCastOp::create(rewriter, loc, block1DType, block);
     Value uniformScale =
-        rewriter.create<vector::ExtractOp>(loc, scale, offsets);
+        vector::ExtractOp::create(rewriter, loc, scale, offsets);
 
     VectorType blockResultType = VectorType::get(blockSize, outType);
     Value blockResult =
         rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
 
-    for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
-         i < blockSize;
-         i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
-      Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, block1D, i, sliceWidth, 1);
-      // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
-      Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
-          loc, truncScaleResultType, slice, uniformScale, 0,
-          /*existing=*/nullptr);
-      int64_t packedWidth =
-          cast<VectorType>(scaleTrunc.getType()).getNumElements();
-      if (packedWidth != opWidth)
-        scaleTrunc = rewriter.create<vector::ExtractStridedSliceOp>(
-            loc, scaleTrunc, 0, sliceWidth, 1);
-      blockResult = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, scaleTrunc, blockResult, i, 1);
+    for (int64_t i = 0, outSliceWidth = std::min(opOutWidth, blockSize - i);
+         i < blockSize; i += outSliceWidth,
+                 outSliceWidth = std::min(opOutWidth, blockSize - i)) {
+      Value scaleTrunc;
+      // Case where <= 2 elements are being truncated.
+      if (outSliceWidth <= opInWidth) {
+        Value slice = vector::ExtractStridedSliceOp::create(
+            rewriter, loc, block1D, i, outSliceWidth, 1);
+        // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
+        scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+            rewriter, loc, truncScaleResultType, slice, uniformScale, 0,
+            /*existing=*/nullptr);
+      } else {
+        scaleTrunc = vector::BroadcastOp::create(rewriter, loc,
+                                                 truncScaleResultType, zero);
+        for (int64_t j = 0,
+                     inSliceWidth = std::min(opInWidth, outSliceWidth - j);
+             j < outSliceWidth; j += opInWidth,
+                     inSliceWidth = std::min(opInWidth, outSliceWidth - j)) {
+          Value slice = vector::ExtractStridedSliceOp::create(
+              rewriter, loc, block1D, i + j, inSliceWidth, 1);
+          scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+              rewriter, loc, truncScaleResultType, slice, uniformScale,
+              j / opInWidth, scaleTrunc);
+        }
+      }
+      if (outSliceWidth != opOutWidth) {
+        scaleTrunc = vector::ExtractStridedSliceOp::create(
+            rewriter, loc, scaleTrunc, 0, outSliceWidth, 1);
+      }
+      blockResult = vector::InsertStridedSliceOp::create(
+          rewriter, loc, scaleTrunc, blockResult, i, 1);
     }
 
     VectorType resultType = VectorType::get(ratio, outType);
     Value cast =
-        rewriter.create<vector::ShapeCastOp>(loc, resultType, blockResult);
-    result = rewriter.create<vector::InsertStridedSliceOp>(loc, cast, result,
-                                                           offsets, strides);
+        vector::ShapeCastOp::create(rewriter, loc, resultType, blockResult);
+    result = vector::InsertStridedSliceOp::create(rewriter, loc, cast, result,
+                                                  offsets, strides);
   }
 
   rewriter.replaceOp(op, result);
@@ -667,19 +690,21 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
 void mlir::arith::populateArithToAMDGPUConversionPatterns(
     RewritePatternSet &patterns, bool convertFP8Arithmetic,
-    bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
+    bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset,
+    PatternBenefit benefit) {
   if (convertFP8Arithmetic) {
-    patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset);
-    patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
-                                               saturateFP8Truncf, chipset);
+    patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset,
+                                             benefit);
+    patterns.add<TruncFToFloat8RewritePattern>(
+        patterns.getContext(), saturateFP8Truncf, chipset, benefit);
   }
   if (allowPackedF16Rtz)
-    patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
+    patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext(), benefit);
 
   if (chipset >= kGfx950) {
-    patterns.add<ScalingExtFRewritePattern>(patterns.getContext());
patterns.add<ScalingExtFRewritePattern>(patterns.getContext()); - patterns.add<ScalingTruncFRewritePattern>(patterns.getContext()); + patterns.add<ScalingExtFRewritePattern>(patterns.getContext(), benefit); + patterns.add<ScalingTruncFRewritePattern>(patterns.getContext(), benefit); } } diff --git a/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp index cbe0b3f..ba48943 100644 --- a/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp +++ b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp @@ -74,15 +74,15 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> { VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); auto denseAttr1D = DenseElementsAttr::get( tileSliceType, denseAttr.getSplatValue<Attribute>()); - auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D); + auto constantOp1D = arith::ConstantOp::create(rewriter, loc, denseAttr1D); - auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); + auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType); auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex, Value currentTile) { // Create 'arm_sme.insert_tile_slice' to write vector to tile // slice. - auto nextTile = b.create<arm_sme::InsertTileSliceOp>( - loc, tileType, constantOp1D, currentTile, tileSliceIndex); + auto nextTile = arm_sme::InsertTileSliceOp::create( + b, loc, tileType, constantOp1D, currentTile, tileSliceIndex); return nextTile.getResult(); }; auto forOp = mlir::arm_sme::createLoopOverTileSlices( diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index a5c08a6..515fe5c 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -110,9 +110,9 @@ public: emitc::CmpPredicate predicate; switch (op.getPredicate()) { case arith::CmpFPredicate::AlwaysFalse: { - auto constant = rewriter.create<emitc::ConstantOp>( - op.getLoc(), rewriter.getI1Type(), - rewriter.getBoolAttr(/*value=*/false)); + auto constant = + emitc::ConstantOp::create(rewriter, op.getLoc(), rewriter.getI1Type(), + rewriter.getBoolAttr(/*value=*/false)); rewriter.replaceOp(op, constant); return success(); } @@ -179,9 +179,9 @@ public: return success(); } case arith::CmpFPredicate::AlwaysTrue: { - auto constant = rewriter.create<emitc::ConstantOp>( - op.getLoc(), rewriter.getI1Type(), - rewriter.getBoolAttr(/*value=*/true)); + auto constant = + emitc::ConstantOp::create(rewriter, op.getLoc(), rewriter.getI1Type(), + rewriter.getBoolAttr(/*value=*/true)); rewriter.replaceOp(op, constant); return success(); } @@ -189,8 +189,8 @@ public: // Compare the values naively auto cmpResult = - rewriter.create<emitc::CmpOp>(op.getLoc(), op.getType(), predicate, - adaptor.getLhs(), adaptor.getRhs()); + emitc::CmpOp::create(rewriter, op.getLoc(), op.getType(), predicate, + adaptor.getLhs(), adaptor.getRhs()); // Adjust the results for unordered/ordered semantics if (unordered) { @@ -213,16 +213,16 @@ private: Value isNaN(ConversionPatternRewriter &rewriter, Location loc, Value operand) const { // A value is NaN exactly when it compares unequal to itself. - return rewriter.create<emitc::CmpOp>( - loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, operand, operand); + return emitc::CmpOp::create(rewriter, loc, rewriter.getI1Type(), + emitc::CmpPredicate::ne, operand, operand); } /// Return a value that is true if \p operand is not NaN. 
Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc, Value operand) const { // A value is not NaN exactly when it compares equal to itself. - return rewriter.create<emitc::CmpOp>( - loc, rewriter.getI1Type(), emitc::CmpPredicate::eq, operand, operand); + return emitc::CmpOp::create(rewriter, loc, rewriter.getI1Type(), + emitc::CmpPredicate::eq, operand, operand); } /// Return a value that is true if the operands \p first and \p second are @@ -231,8 +231,8 @@ private: Location loc, Value first, Value second) const { auto firstIsNaN = isNaN(rewriter, loc, first); auto secondIsNaN = isNaN(rewriter, loc, second); - return rewriter.create<emitc::LogicalOrOp>(loc, rewriter.getI1Type(), - firstIsNaN, secondIsNaN); + return emitc::LogicalOrOp::create(rewriter, loc, rewriter.getI1Type(), + firstIsNaN, secondIsNaN); } /// Return a value that is true if the operands \p first and \p second are @@ -241,8 +241,8 @@ private: Value first, Value second) const { auto firstIsNotNaN = isNotNaN(rewriter, loc, first); auto secondIsNotNaN = isNotNaN(rewriter, loc, second); - return rewriter.create<emitc::LogicalAndOp>(loc, rewriter.getI1Type(), - firstIsNotNaN, secondIsNotNaN); + return emitc::LogicalAndOp::create(rewriter, loc, rewriter.getI1Type(), + firstIsNotNaN, secondIsNotNaN); } }; @@ -378,10 +378,10 @@ public: Type attrType = (emitc::isPointerWideType(operandType)) ? rewriter.getIndexType() : operandType; - auto constOne = rewriter.create<emitc::ConstantOp>( - op.getLoc(), operandType, rewriter.getOneAttr(attrType)); - auto oneAndOperand = rewriter.create<emitc::BitwiseAndOp>( - op.getLoc(), operandType, adaptor.getIn(), constOne); + auto constOne = emitc::ConstantOp::create( + rewriter, op.getLoc(), operandType, rewriter.getOneAttr(attrType)); + auto oneAndOperand = emitc::BitwiseAndOp::create( + rewriter, op.getLoc(), operandType, adaptor.getIn(), constOne); rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType, oneAndOperand); return success(); @@ -402,8 +402,8 @@ public: Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType); // Actual cast (may change bitwidth) - auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(), - castDestType, actualOp); + auto cast = + emitc::CastOp::create(rewriter, op.getLoc(), castDestType, actualOp); // Cast to the expected output type auto result = adaptValueType(cast, rewriter, opReturnType); @@ -466,9 +466,8 @@ public: Value lhsAdapted = adaptValueType(uiBinOp.getLhs(), rewriter, unsignedType); Value rhsAdapted = adaptValueType(uiBinOp.getRhs(), rewriter, unsignedType); - auto newDivOp = - rewriter.create<EmitCOp>(uiBinOp.getLoc(), unsignedType, - ArrayRef<Value>{lhsAdapted, rhsAdapted}); + auto newDivOp = EmitCOp::create(rewriter, uiBinOp.getLoc(), unsignedType, + ArrayRef<Value>{lhsAdapted, rhsAdapted}); Value resultAdapted = adaptValueType(newDivOp, rewriter, newRetTy); rewriter.replaceOp(uiBinOp, resultAdapted); return success(); @@ -508,8 +507,8 @@ public: Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType); Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType); - Value arithmeticResult = rewriter.template create<EmitCOp>( - op.getLoc(), arithmeticType, lhs, rhs); + Value arithmeticResult = + EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs); Value result = adaptValueType(arithmeticResult, rewriter, type); @@ -548,8 +547,8 @@ public: Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType); Value rhs = adaptValueType(adaptor.getRhs(), rewriter, 
arithmeticType); - Value arithmeticResult = rewriter.template create<EmitCOp>( - op.getLoc(), arithmeticType, lhs, rhs); + Value arithmeticResult = + EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs); Value result = adaptValueType(arithmeticResult, rewriter, type); @@ -588,38 +587,40 @@ public: // Add a runtime check for overflow Value width; if (emitc::isPointerWideType(type)) { - Value eight = rewriter.create<emitc::ConstantOp>( - op.getLoc(), rhsType, rewriter.getIndexAttr(8)); - emitc::CallOpaqueOp sizeOfCall = rewriter.create<emitc::CallOpaqueOp>( - op.getLoc(), rhsType, "sizeof", ArrayRef<Value>{eight}); - width = rewriter.create<emitc::MulOp>(op.getLoc(), rhsType, eight, - sizeOfCall.getResult(0)); + Value eight = emitc::ConstantOp::create(rewriter, op.getLoc(), rhsType, + rewriter.getIndexAttr(8)); + emitc::CallOpaqueOp sizeOfCall = emitc::CallOpaqueOp::create( + rewriter, op.getLoc(), rhsType, "sizeof", ArrayRef<Value>{eight}); + width = emitc::MulOp::create(rewriter, op.getLoc(), rhsType, eight, + sizeOfCall.getResult(0)); } else { - width = rewriter.create<emitc::ConstantOp>( - op.getLoc(), rhsType, + width = emitc::ConstantOp::create( + rewriter, op.getLoc(), rhsType, rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth())); } - Value excessCheck = rewriter.create<emitc::CmpOp>( - op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width); + Value excessCheck = + emitc::CmpOp::create(rewriter, op.getLoc(), rewriter.getI1Type(), + emitc::CmpPredicate::lt, rhs, width); // Any concrete value is a valid refinement of poison. - Value poison = rewriter.create<emitc::ConstantOp>( - op.getLoc(), arithmeticType, + Value poison = emitc::ConstantOp::create( + rewriter, op.getLoc(), arithmeticType, (isa<IntegerType>(arithmeticType) ? 
rewriter.getIntegerAttr(arithmeticType, 0) : rewriter.getIndexAttr(0))); - emitc::ExpressionOp ternary = rewriter.create<emitc::ExpressionOp>( - op.getLoc(), arithmeticType, /*do_not_inline=*/false); + emitc::ExpressionOp ternary = emitc::ExpressionOp::create( + rewriter, op.getLoc(), arithmeticType, /*do_not_inline=*/false); Block &bodyBlock = ternary.getBodyRegion().emplaceBlock(); auto currentPoint = rewriter.getInsertionPoint(); rewriter.setInsertionPointToStart(&bodyBlock); Value arithmeticResult = - rewriter.create<EmitCOp>(op.getLoc(), arithmeticType, lhs, rhs); - Value resultOrPoison = rewriter.create<emitc::ConditionalOp>( - op.getLoc(), arithmeticType, excessCheck, arithmeticResult, poison); - rewriter.create<emitc::YieldOp>(op.getLoc(), resultOrPoison); + EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs); + Value resultOrPoison = + emitc::ConditionalOp::create(rewriter, op.getLoc(), arithmeticType, + excessCheck, arithmeticResult, poison); + emitc::YieldOp::create(rewriter, op.getLoc(), resultOrPoison); rewriter.setInsertionPoint(op->getBlock(), currentPoint); Value result = adaptValueType(ternary, rewriter, type); @@ -700,11 +701,12 @@ public: /*isSigned=*/false); } - Value result = rewriter.create<emitc::CastOp>( - castOp.getLoc(), actualResultType, adaptor.getOperands()); + Value result = emitc::CastOp::create( + rewriter, castOp.getLoc(), actualResultType, adaptor.getOperands()); if (isa<arith::FPToUIOp>(castOp)) { - result = rewriter.create<emitc::CastOp>(castOp.getLoc(), dstType, result); + result = + emitc::CastOp::create(rewriter, castOp.getLoc(), dstType, result); } rewriter.replaceOp(castOp, result); @@ -746,8 +748,8 @@ public: } Value fpCastOperand = adaptor.getIn(); if (actualOperandType != operandType) { - fpCastOperand = rewriter.template create<emitc::CastOp>( - castOp.getLoc(), actualOperandType, fpCastOperand); + fpCastOperand = emitc::CastOp::create(rewriter, castOp.getLoc(), + actualOperandType, fpCastOperand); } rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand); diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index f7bf581..18e857c 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -293,11 +293,11 @@ LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite( [&](Type llvm1DVectorTy, ValueRange operands) -> Value { typename OpTy::Adaptor adaptor(operands); if (targetBits < sourceBits) { - return rewriter.create<LLVM::TruncOp>(op.getLoc(), llvm1DVectorTy, - adaptor.getIn()); + return LLVM::TruncOp::create(rewriter, op.getLoc(), llvm1DVectorTy, + adaptor.getIn()); } - return rewriter.create<ExtCastTy>(op.getLoc(), llvm1DVectorTy, - adaptor.getIn()); + return ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy, + adaptor.getIn()); }, rewriter); } @@ -324,12 +324,12 @@ LogicalResult AddUIExtendedOpLowering::matchAndRewrite( Type newOverflowType = typeConverter->convertType(overflowResultType); Type structType = LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType}); - Value addOverflow = rewriter.create<LLVM::UAddWithOverflowOp>( - loc, structType, adaptor.getLhs(), adaptor.getRhs()); + Value addOverflow = LLVM::UAddWithOverflowOp::create( + rewriter, loc, structType, adaptor.getLhs(), adaptor.getRhs()); Value sumExtracted = - rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 0); + LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 0); Value 
overflowExtracted = - rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 1); + LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 1); rewriter.replaceOp(op, {sumExtracted, overflowExtracted}); return success(); } @@ -381,15 +381,15 @@ LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite( "LLVM dialect should support all signless integer types"); using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>; - Value lhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getLhs()); - Value rhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getRhs()); - Value mulExt = rewriter.create<LLVM::MulOp>(loc, wideType, lhsExt, rhsExt); + Value lhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getLhs()); + Value rhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getRhs()); + Value mulExt = LLVM::MulOp::create(rewriter, loc, wideType, lhsExt, rhsExt); // Split the 2*N-bit wide result into two N-bit values. - Value low = rewriter.create<LLVM::TruncOp>(loc, resultType, mulExt); - Value shiftVal = rewriter.create<LLVM::ConstantOp>(loc, shiftValAttr); - Value highExt = rewriter.create<LLVM::LShrOp>(loc, mulExt, shiftVal); - Value high = rewriter.create<LLVM::TruncOp>(loc, resultType, highExt); + Value low = LLVM::TruncOp::create(rewriter, loc, resultType, mulExt); + Value shiftVal = LLVM::ConstantOp::create(rewriter, loc, shiftValAttr); + Value highExt = LLVM::LShrOp::create(rewriter, loc, mulExt, shiftVal); + Value high = LLVM::TruncOp::create(rewriter, loc, resultType, highExt); rewriter.replaceOp(op, {low, high}); return success(); @@ -435,8 +435,8 @@ CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { OpAdaptor adaptor(operands); - return rewriter.create<LLVM::ICmpOp>( - op.getLoc(), llvm1DVectorTy, + return LLVM::ICmpOp::create( + rewriter, op.getLoc(), llvm1DVectorTy, convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()), adaptor.getLhs(), adaptor.getRhs()); }, @@ -471,8 +471,8 @@ CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { OpAdaptor adaptor(operands); - return rewriter.create<LLVM::FCmpOp>( - op.getLoc(), llvm1DVectorTy, + return LLVM::FCmpOp::create( + rewriter, op.getLoc(), llvm1DVectorTy, convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()), adaptor.getLhs(), adaptor.getRhs(), fmf); }, diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index 434d7df..265293b 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -99,6 +99,17 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, return builder.getF32FloatAttr(dstVal.convertToFloat()); } +// Get an IntegerAttr from a FloatAttr while preserving the bits. +// Useful for converting float constants to integer constants without +// changing the underlying bit pattern. +static IntegerAttr +getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, + ConversionPatternRewriter &rewriter) { + APFloat floatVal = floatAttr.getValue(); + APInt intVal = floatVal.bitcastToAPInt(); + return rewriter.getIntegerAttr(dstType, intVal); +} + /// Returns true if the given `type` is a boolean scalar or vector type. 
static bool isBoolScalarOrVector(Type type) { assert(type && "Not a valid type"); @@ -117,12 +128,12 @@ static Value getScalarOrVectorConstInt(Type type, uint64_t value, if (auto vectorType = dyn_cast<VectorType>(type)) { Attribute element = IntegerAttr::get(vectorType.getElementType(), value); auto attr = SplatElementsAttr::get(vectorType, element); - return builder.create<spirv::ConstantOp>(loc, vectorType, attr); + return spirv::ConstantOp::create(builder, loc, vectorType, attr); } if (auto intType = dyn_cast<IntegerType>(type)) - return builder.create<spirv::ConstantOp>( - loc, type, builder.getIntegerAttr(type, value)); + return spirv::ConstantOp::create(builder, loc, type, + builder.getIntegerAttr(type, value)); return nullptr; } @@ -296,8 +307,18 @@ struct ConstantCompositeOpPattern final SmallVector<Attribute, 8> elements; if (isa<FloatType>(srcElemType)) { for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) { - FloatAttr dstAttr = - convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter); + Attribute dstAttr = nullptr; + // Handle 8-bit float conversion to 8-bit integer. + auto *typeConverter = getTypeConverter<SPIRVTypeConverter>(); + if (typeConverter->getOptions().emulateUnsupportedFloatTypes && + srcElemType.getIntOrFloatBitWidth() == 8 && + isa<IntegerType>(dstElemType)) { + dstAttr = + getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter); + } else { + dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), + rewriter); + } if (!dstAttr) return failure(); elements.push_back(dstAttr); @@ -361,11 +382,19 @@ struct ConstantScalarOpPattern final // Floating-point types. if (isa<FloatType>(srcType)) { auto srcAttr = cast<FloatAttr>(cstAttr); - auto dstAttr = srcAttr; + Attribute dstAttr = srcAttr; // Floating-point types not supported in the target environment are all // converted to float type. - if (srcType != dstType) { + auto *typeConverter = getTypeConverter<SPIRVTypeConverter>(); + if (typeConverter->getOptions().emulateUnsupportedFloatTypes && + srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) && + dstType.getIntOrFloatBitWidth() == 8) { + // If the source is an 8-bit float, convert it to an 8-bit integer. + dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter); + if (!dstAttr) + return failure(); + } else if (srcType != dstType) { dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter); if (!dstAttr) return failure(); @@ -418,18 +447,19 @@ static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, Type type = lhs.getType(); // Calculate the remainder with spirv.UMod. - Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs); - Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs); - Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs); + Value lhsAbs = SignedAbsOp::create(builder, loc, type, lhs); + Value rhsAbs = SignedAbsOp::create(builder, loc, type, rhs); + Value abs = spirv::UModOp::create(builder, loc, lhsAbs, rhsAbs); // Fix the sign. 
Value isPositive; if (lhs == signOperand) - isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs); + isPositive = spirv::IEqualOp::create(builder, loc, lhs, lhsAbs); else - isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs); - Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs); - return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate); + isPositive = spirv::IEqualOp::create(builder, loc, rhs, rhsAbs); + Value absNegate = spirv::SNegateOp::create(builder, loc, type, abs); + return spirv::SelectOp::create(builder, loc, type, isPositive, abs, + absNegate); } /// Converts arith.remsi to GLSL SPIR-V ops. @@ -601,13 +631,13 @@ struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> { Value allOnes; if (auto intTy = dyn_cast<IntegerType>(dstType)) { unsigned componentBitwidth = intTy.getWidth(); - allOnes = rewriter.create<spirv::ConstantOp>( - loc, intTy, + allOnes = spirv::ConstantOp::create( + rewriter, loc, intTy, rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth))); } else if (auto vectorTy = dyn_cast<VectorType>(dstType)) { unsigned componentBitwidth = vectorTy.getElementTypeBitWidth(); - allOnes = rewriter.create<spirv::ConstantOp>( - loc, vectorTy, + allOnes = spirv::ConstantOp::create( + rewriter, loc, vectorTy, SplatElementsAttr::get(vectorTy, APInt::getAllOnes(componentBitwidth))); } else { @@ -653,8 +683,8 @@ struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> { // First shift left to squeeze out all leading bits beyond the original // bitwidth. Here we need to use the original source and result type's // bitwidth. - auto shiftLOp = rewriter.create<spirv::ShiftLeftLogicalOp>( - op.getLoc(), dstType, adaptor.getIn(), shiftSize); + auto shiftLOp = spirv::ShiftLeftLogicalOp::create( + rewriter, op.getLoc(), dstType, adaptor.getIn(), shiftSize); // Then we perform arithmetic right shift to make sure we have the right // sign bits for negative values. @@ -757,9 +787,9 @@ struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> { auto srcType = adaptor.getOperands().front().getType(); // Check if (x & 1) == 1. 
Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); - Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>( - loc, srcType, adaptor.getOperands()[0], mask); - Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask); + Value maskedSrc = spirv::BitwiseAndOp::create( + rewriter, loc, srcType, adaptor.getOperands()[0], mask); + Value isOne = spirv::IEqualOp::create(rewriter, loc, maskedSrc, mask); Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); @@ -914,9 +944,9 @@ public: if (auto vectorType = dyn_cast<VectorType>(dstType)) type = VectorType::get(vectorType.getShape(), type); Value extLhs = - rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs()); + arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getLhs()); Value extRhs = - rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs()); + arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getRhs()); rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs, extRhs); @@ -1067,12 +1097,12 @@ public: replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter); } } else { - Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs()); - Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs()); + Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs()); + Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs()); - replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan); + replace = spirv::LogicalOrOp::create(rewriter, loc, lhsIsNan, rhsIsNan); if (op.getPredicate() == arith::CmpFPredicate::ORD) - replace = rewriter.create<spirv::LogicalNotOp>(loc, replace); + replace = spirv::LogicalNotOp::create(rewriter, loc, replace); } rewriter.replaceOp(op, replace); @@ -1094,17 +1124,17 @@ public: ConversionPatternRewriter &rewriter) const override { Type dstElemTy = adaptor.getLhs().getType(); Location loc = op->getLoc(); - Value result = rewriter.create<spirv::IAddCarryOp>(loc, adaptor.getLhs(), - adaptor.getRhs()); + Value result = spirv::IAddCarryOp::create(rewriter, loc, adaptor.getLhs(), + adaptor.getRhs()); - Value sumResult = rewriter.create<spirv::CompositeExtractOp>( - loc, result, llvm::ArrayRef(0)); - Value carryValue = rewriter.create<spirv::CompositeExtractOp>( - loc, result, llvm::ArrayRef(1)); + Value sumResult = spirv::CompositeExtractOp::create(rewriter, loc, result, + llvm::ArrayRef(0)); + Value carryValue = spirv::CompositeExtractOp::create(rewriter, loc, result, + llvm::ArrayRef(1)); // Convert the carry value to boolean. 
Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter); - Value carryResult = rewriter.create<spirv::IEqualOp>(loc, carryValue, one); + Value carryResult = spirv::IEqualOp::create(rewriter, loc, carryValue, one); rewriter.replaceOp(op, {sumResult, carryResult}); return success(); @@ -1125,12 +1155,12 @@ public: ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); Value result = - rewriter.create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs()); + SPIRVMulOp::create(rewriter, loc, adaptor.getLhs(), adaptor.getRhs()); - Value low = rewriter.create<spirv::CompositeExtractOp>(loc, result, - llvm::ArrayRef(0)); - Value high = rewriter.create<spirv::CompositeExtractOp>(loc, result, - llvm::ArrayRef(1)); + Value low = spirv::CompositeExtractOp::create(rewriter, loc, result, + llvm::ArrayRef(0)); + Value high = spirv::CompositeExtractOp::create(rewriter, loc, result, + llvm::ArrayRef(1)); rewriter.replaceOp(op, {low, high}); return success(); @@ -1183,20 +1213,20 @@ public: Location loc = op.getLoc(); Value spirvOp = - rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands()); + SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands()); if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) { rewriter.replaceOp(op, spirvOp); return success(); } - Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs()); - Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs()); + Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs()); + Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs()); - Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan, - adaptor.getLhs(), spirvOp); - Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan, - adaptor.getRhs(), select1); + Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan, + adaptor.getLhs(), spirvOp); + Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan, + adaptor.getRhs(), select1); rewriter.replaceOp(op, select2); return success(); @@ -1237,7 +1267,7 @@ public: Location loc = op.getLoc(); Value spirvOp = - rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands()); + SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands()); if (!shouldInsertNanGuards<SPIRVOp>() || bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) { @@ -1245,13 +1275,13 @@ public: return success(); } - Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs()); - Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs()); + Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs()); + Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs()); - Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan, - adaptor.getRhs(), spirvOp); - Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan, - adaptor.getLhs(), select1); + Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan, + adaptor.getRhs(), spirvOp); + Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan, + adaptor.getLhs(), select1); rewriter.replaceOp(op, select2); return success(); @@ -1351,6 +1381,7 @@ struct ConvertArithToSPIRVPass SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); // Use 
UnrealizedConversionCast as the bridge so that we don't need to pull diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp index 9c6de93..e34b368 100644 --- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp +++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp @@ -12,7 +12,6 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { @@ -41,11 +40,11 @@ public: Value c2d = op.getC(); Location loc = op.getLoc(); Value b1d = - rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, b2d); + vector::ShapeCastOp::create(rewriter, loc, flattenedVectorType, b2d); Value c1d = - rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, c2d); - Value newOp = rewriter.create<SdotOp>(loc, op.getRes().getType(), op.getA(), - b1d, c1d); + vector::ShapeCastOp::create(rewriter, loc, flattenedVectorType, c2d); + Value newOp = SdotOp::create(rewriter, loc, op.getRes().getType(), + op.getA(), b1d, c1d); rewriter.replaceOp(op, {newOp}); return success(); } diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp index 21ea444..8a2e3b63 100644 --- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp +++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp @@ -45,38 +45,38 @@ static Operation *createLoadTileSliceIntrinsic( if (layout == arm_sme::TileSliceLayout::Horizontal) { switch (type) { case arm_sme::ArmSMETileType::ZAB: - return rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1b_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAH: - return rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1h_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAS: - return rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1w_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAD: - return rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1d_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAQ: - return rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1q_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); } } else { switch (type) { case arm_sme::ArmSMETileType::ZAB: - return rewriter.create<arm_sme::aarch64_sme_ld1b_vert>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1b_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAH: - return rewriter.create<arm_sme::aarch64_sme_ld1h_vert>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1h_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAS: - return rewriter.create<arm_sme::aarch64_sme_ld1w_vert>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1w_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case 
arm_sme::ArmSMETileType::ZAD: - return rewriter.create<arm_sme::aarch64_sme_ld1d_vert>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1d_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAQ: - return rewriter.create<arm_sme::aarch64_sme_ld1q_vert>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1q_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); break; } } @@ -91,38 +91,38 @@ static Operation *createStoreTileSliceIntrinsic( if (layout == arm_sme::TileSliceLayout::Horizontal) { switch (type) { case arm_sme::ArmSMETileType::ZAB: - return rewriter.create<arm_sme::aarch64_sme_st1b_horiz>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1b_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAH: - return rewriter.create<arm_sme::aarch64_sme_st1h_horiz>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1h_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAS: - return rewriter.create<arm_sme::aarch64_sme_st1w_horiz>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1w_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAD: - return rewriter.create<arm_sme::aarch64_sme_st1d_horiz>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1d_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAQ: - return rewriter.create<arm_sme::aarch64_sme_st1q_horiz>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1q_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); } } else { switch (type) { case arm_sme::ArmSMETileType::ZAB: - return rewriter.create<arm_sme::aarch64_sme_st1b_vert>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1b_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAH: - return rewriter.create<arm_sme::aarch64_sme_st1h_vert>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1h_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAS: - return rewriter.create<arm_sme::aarch64_sme_st1w_vert>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1w_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAD: - return rewriter.create<arm_sme::aarch64_sme_st1d_vert>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1d_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAQ: - return rewriter.create<arm_sme::aarch64_sme_st1q_vert>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1q_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); } } llvm_unreachable("unknown type in createStoreTileSliceIntrinsic"); @@ -146,16 +146,16 @@ createAllocaForTile(RewriterBase &rewriter, Location loc, // Move to the first operation in the function. rewriter.setInsertionPointToStart(&func.getBlocks().front()); // Create an alloca matching the tile size of the `tileOp`. 
- auto vscale = rewriter.create<vector::VectorScaleOp>(loc); + auto vscale = vector::VectorScaleOp::create(rewriter, loc); auto tileElementType = tileOp.getTileType().getElementType(); auto memrefType = MemRefType::get( {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType); unsigned minElements = arm_sme::getSMETileSliceMinNumElts(tileElementType); auto minElementsOp = - rewriter.create<arith::ConstantIndexOp>(loc, minElements); - auto vectorLen = rewriter.create<arith::MulIOp>(loc, vscale, minElementsOp); - auto alloca = rewriter.create<memref::AllocaOp>( - loc, memrefType, ValueRange{vectorLen, vectorLen}); + arith::ConstantIndexOp::create(rewriter, loc, minElements); + auto vectorLen = arith::MulIOp::create(rewriter, loc, vscale, minElementsOp); + auto alloca = memref::AllocaOp::create(rewriter, loc, memrefType, + ValueRange{vectorLen, vectorLen}); return alloca; } @@ -293,10 +293,10 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { Value tileMemory, Value sliceIndex) const { auto llvmType = getTypeConverter()->convertType(tileMemory.getType()); auto descriptor = - rewriter.create<UnrealizedConversionCastOp>(loc, llvmType, tileMemory); - auto zero = rewriter.create<arith::ConstantIntOp>(loc, 0, /*width=*/64); - auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>( - loc, rewriter.getI64Type(), sliceIndex); + UnrealizedConversionCastOp::create(rewriter, loc, llvmType, tileMemory); + auto zero = arith::ConstantIntOp::create(rewriter, loc, 0, /*width=*/64); + auto sliceIndexI64 = arith::IndexCastOp::create( + rewriter, loc, rewriter.getI64Type(), sliceIndex); return getStridedElementPtr( static_cast<ConversionPatternRewriter &>(rewriter), loc, llvm::cast<MemRefType>(tileMemory.getType()), descriptor.getResult(0), @@ -309,28 +309,29 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { arm_sme::ArmSMETileType tileType, VectorType sliceType, IntegerAttr tileId, Value sliceIndex) const { // Cast the slice index to an i32. - auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>( - loc, rewriter.getI32Type(), sliceIndex); + auto sliceIndexI32 = arith::IndexCastOp::create( + rewriter, loc, rewriter.getI32Type(), sliceIndex); // Create an all-true predicate for the slice. auto predicateType = sliceType.clone(rewriter.getI1Type()); - auto allTruePredicate = rewriter.create<arith::ConstantOp>( - loc, DenseElementsAttr::get(predicateType, true)); + auto allTruePredicate = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(predicateType, true)); // Create padding vector (never used due to all-true predicate). - auto padVector = rewriter.create<LLVM::PoisonOp>(loc, sliceType); + auto padVector = LLVM::PoisonOp::create(rewriter, loc, sliceType); // Get a pointer to the current slice. auto slicePtr = getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex); // Read the value of the current slice from ZA. - auto currentTileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>( - loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32); + auto currentTileSlice = arm_sme::aarch64_sme_read_horiz::create( + rewriter, loc, sliceType, padVector, allTruePredicate, tileId, + sliceIndexI32); // Load the new tile slice back from memory into ZA. createLoadTileSliceIntrinsic( rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal, allTruePredicate, slicePtr, tileId, sliceIndexI32); // Store the current tile slice to memory. 
- auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); - rewriter.create<vector::StoreOp>(loc, currentTileSlice, tileAlloca, - ValueRange{sliceIndex, zero}); + auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + vector::StoreOp::create(rewriter, loc, currentTileSlice, tileAlloca, + ValueRange{sliceIndex, zero}); } /// Emits a full in-place swap of the contents of a tile in ZA and a @@ -341,12 +342,14 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { RewriterBase::InsertionGuard guard(rewriter); // Create an scf.for over all tile slices. auto minNumElts = - rewriter.create<arith::ConstantIndexOp>(loc, sliceType.getDimSize(0)); - auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); - auto upperBound = rewriter.create<arith::MulIOp>( - loc, minNumElts, rewriter.create<vector::VectorScaleOp>(loc)); - auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); - auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step); + arith::ConstantIndexOp::create(rewriter, loc, sliceType.getDimSize(0)); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); + auto upperBound = + arith::MulIOp::create(rewriter, loc, minNumElts, + vector::VectorScaleOp::create(rewriter, loc)); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); + auto forOp = + scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step); // Emit a swap for each tile slice. rewriter.setInsertionPointToStart(forOp.getBody()); auto sliceIndex = forOp.getInductionVar(); @@ -479,8 +482,8 @@ struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> { // // This holds for all tile sizes. int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt()); - rewriter.create<arm_sme::aarch64_sme_zero>( - loc, rewriter.getI32IntegerAttr(zeroMask)); + arm_sme::aarch64_sme_zero::create(rewriter, loc, + rewriter.getI32IntegerAttr(zeroMask)); // Create a placeholder op to preserve dataflow. // Note: Place the `get_tile` op at the start of the block. This ensures @@ -513,8 +516,8 @@ struct LoadTileSliceConversion auto tileSlice = loadTileSliceOp.getTileSliceIndex(); // Cast tile slice to i32 for intrinsic. - auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>( - loc, rewriter.getI32Type(), tileSlice); + auto tileSliceI32 = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI32Type(), tileSlice); // Create all active predicate mask. auto maskOp = loadTileSliceOp.getMask(); @@ -559,8 +562,8 @@ struct StoreTileSliceConversion auto tileSlice = storeTileSliceOp.getTileSliceIndex(); // Cast tile slice to i32 for intrinsic. - auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>( - loc, rewriter.getI32Type(), tileSlice); + auto tileSliceI32 = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI32Type(), tileSlice); auto maskOp = storeTileSliceOp.getMask(); @@ -595,28 +598,29 @@ struct InsertTileSliceConversion auto tileSlice = insertTileSliceOp.getTileSliceIndex(); // Cast tile slice from index to i32 for intrinsic. - auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>( - loc, rewriter.getI32Type(), tileSlice); + auto tileSliceI32 = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI32Type(), tileSlice); // Create all active predicate mask. 
- auto one = rewriter.create<arith::ConstantOp>( - loc, rewriter.getI1Type(), + auto one = arith::ConstantOp::create( + rewriter, loc, rewriter.getI1Type(), rewriter.getIntegerAttr(rewriter.getI1Type(), 1)); auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(), /*scalableDims=*/{true}); - auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one); + auto allActiveMask = + vector::BroadcastOp::create(rewriter, loc, predTy, one); // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice. switch (insertTileSliceOp.getLayout()) { case arm_sme::TileSliceLayout::Horizontal: - rewriter.create<arm_sme::aarch64_sme_write_horiz>( - loc, tileId, tileSliceI32, allActiveMask, - insertTileSliceOp.getVector()); + arm_sme::aarch64_sme_write_horiz::create(rewriter, loc, tileId, + tileSliceI32, allActiveMask, + insertTileSliceOp.getVector()); break; case arm_sme::TileSliceLayout::Vertical: - rewriter.create<arm_sme::aarch64_sme_write_vert>( - loc, tileId, tileSliceI32, allActiveMask, - insertTileSliceOp.getVector()); + arm_sme::aarch64_sme_write_vert::create(rewriter, loc, tileId, + tileSliceI32, allActiveMask, + insertTileSliceOp.getVector()); break; } @@ -646,16 +650,16 @@ struct ExtractTileSliceConversion // Create an 'all true' predicate for the tile slice. auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type()); - auto allTruePredicate = rewriter.create<arith::ConstantOp>( - loc, DenseElementsAttr::get(predicateType, true)); + auto allTruePredicate = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(predicateType, true)); // Zero destination/fallback for tile slice extraction. - auto zeroVector = rewriter.create<arith::ConstantOp>( - loc, sliceType, rewriter.getZeroAttr(sliceType)); + auto zeroVector = arith::ConstantOp::create( + rewriter, loc, sliceType, rewriter.getZeroAttr(sliceType)); // Cast tile slice from index to i32 for intrinsic. - auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>( - loc, rewriter.getI32Type(), sliceIndex); + auto sliceIndexI32 = arith::IndexCastOp::create( + rewriter, loc, rewriter.getI32Type(), sliceIndex); // Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice. switch (extractTileSlice.getLayout()) { @@ -743,7 +747,7 @@ struct OuterProductOpConversion Value acc = outerProductOp.getAcc(); if (!acc) { // Initialize accumulator with zero. - auto zero = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType); + auto zero = arm_sme::ZeroOp::create(rewriter, loc, resultVectorType); zero.setTileId(tileId); acc = zero; } @@ -754,16 +758,16 @@ struct OuterProductOpConversion if (!lhsMask || !rhsMask) { auto predTy = outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type()); - Value allActiveMask = rewriter.create<arith::ConstantOp>( - loc, DenseElementsAttr::get(predTy, true)); + Value allActiveMask = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(predTy, true)); lhsMask = allActiveMask; rhsMask = allActiveMask; } // Create 'arm_sme.intr.mopa' outer product intrinsic. - rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileId, lhsMask, rhsMask, - outerProductOp.getLhs(), - outerProductOp.getRhs()); + arm_sme::aarch64_sme_mopa::create(rewriter, loc, tileId, lhsMask, rhsMask, + outerProductOp.getLhs(), + outerProductOp.getRhs()); // The outerproduct intrinsics have no result, replace // 'arm_sme.outerproduct' with the input tile to preserve dataflow. 
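Nearly every hunk in this diff follows the same mechanical rewrite: the member-template form `rewriter.create<OpTy>(loc, ...)` becomes the static form `OpTy::create(rewriter, loc, ...)`, with the builder promoted to the first argument (the hunk above also swaps the deprecated `vector::SplatOp` for `vector::BroadcastOp`). The following is an illustrative sketch of a pattern written against the new convention, not code from this patch; `mydialect::FooOp` is a hypothetical op assumed to have a single index-typed result.

struct FooOpLowering : public OpRewritePattern<mydialect::FooOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(mydialect::FooOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // New convention: the op class is the entry point and the builder is
    // passed explicitly. Old form: rewriter.create<arith::ConstantIndexOp>(loc, 0).
    Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
    rewriter.replaceOp(op, zero);
    return success();
  }
};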
@@ -792,7 +796,7 @@ struct OuterProductWideningOpConversion Value acc = op.getAcc(); if (!acc) { // Initalize accumulator with zero. - auto zero = rewriter.create<arm_sme::ZeroOp>(loc, op.getResultType()); + auto zero = arm_sme::ZeroOp::create(rewriter, loc, op.getResultType()); zero.setTileId(tileId); acc = zero; } @@ -801,14 +805,14 @@ struct OuterProductWideningOpConversion Value rhsMask = op.getRhsMask(); if (!lhsMask || !rhsMask) { auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type()); - Value allActiveMask = rewriter.create<arith::ConstantOp>( - loc, DenseElementsAttr::get(predTy, true)); + Value allActiveMask = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(predTy, true)); lhsMask = allActiveMask; rhsMask = allActiveMask; } - rewriter.create<OuterProductWideningIntrOp>( - loc, tileId, lhsMask, rhsMask, adaptor.getLhs(), adaptor.getRhs()); + OuterProductWideningIntrOp::create(rewriter, loc, tileId, lhsMask, rhsMask, + adaptor.getLhs(), adaptor.getRhs()); // The outerproduct intrinsics have no result, replace // 'arm_sme.outerproduct' with the input tile to preserve dataflow. @@ -843,13 +847,13 @@ struct StreamingVLOpConversion auto *intrOp = [&]() -> Operation * { switch (streamingVlOp.getTypeSize()) { case arm_sme::TypeSize::Byte: - return rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type); + return arm_sme::aarch64_sme_cntsb::create(rewriter, loc, i64Type); case arm_sme::TypeSize::Half: - return rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type); + return arm_sme::aarch64_sme_cntsh::create(rewriter, loc, i64Type); case arm_sme::TypeSize::Word: - return rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type); + return arm_sme::aarch64_sme_cntsw::create(rewriter, loc, i64Type); case arm_sme::TypeSize::Double: - return rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type); + return arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type); } llvm_unreachable("unknown type size in StreamingVLOpConversion"); }(); @@ -872,8 +876,8 @@ static void mergeConsecutiveTileZerosInBlock(Block *block) { if (zeroOpsToMerge.size() <= 1) return; IRRewriter rewriter(zeroOpsToMerge.front()); - rewriter.create<arm_sme::aarch64_sme_zero>( - zeroOpsToMerge.front().getLoc(), + arm_sme::aarch64_sme_zero::create( + rewriter, zeroOpsToMerge.front().getLoc(), rewriter.getI32IntegerAttr(mergedZeroMask)); for (auto zeroOp : zeroOpsToMerge) rewriter.eraseOp(zeroOp); diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp index 458628c..e28d5122 100644 --- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp +++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp @@ -39,7 +39,7 @@ SmallVector<Value, 2> getMemrefIndices(ValueRange indices, unsigned rank, auto tileSliceOffset = tileSliceIndex; auto baseIndexPlusTileSliceOffset = - rewriter.create<arith::AddIOp>(loc, indices[0], tileSliceOffset); + arith::AddIOp::create(rewriter, loc, indices[0], tileSliceOffset); outIndices.push_back(baseIndexPlusTileSliceOffset); outIndices.push_back(indices[1]); @@ -59,10 +59,11 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices( if (memrefIndices.size() != 2) return rewriter.notifyMatchFailure(loc, "invalid number of indices"); - auto minTileSlices = rewriter.create<arith::ConstantIndexOp>( - loc, arm_sme::getSMETileSliceMinNumElts(tileType.getElementType())); + auto minTileSlices = arith::ConstantIndexOp::create( + rewriter, loc, + arm_sme::getSMETileSliceMinNumElts(tileType.getElementType())); auto vscale = - 
rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType()); + vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType()); auto predicateType = VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true); @@ -70,7 +71,7 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices( // elements in a vector of SVL bits for a given element type (SVL_B, // SVL_H, ..., SVL_Q). auto numTileSlices = - rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale); + arith::MulIOp::create(rewriter, loc, minTileSlices, vscale); Value predicate; Value upperBound; @@ -82,30 +83,30 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices( // The upper bound of the loop must be clamped at `numTileSlices` as // `vector.create_mask` allows operands to be greater than the size of a // dimension. - auto numRowI64 = rewriter.create<arith::IndexCastOp>( - loc, rewriter.getI64Type(), maskDim0); - auto numTileSlicesI64 = rewriter.create<arith::IndexCastOp>( - loc, rewriter.getI64Type(), numTileSlices); + auto numRowI64 = arith::IndexCastOp::create( + rewriter, loc, rewriter.getI64Type(), maskDim0); + auto numTileSlicesI64 = arith::IndexCastOp::create( + rewriter, loc, rewriter.getI64Type(), numTileSlices); auto upperBoundI64 = - rewriter.create<arith::MinSIOp>(loc, numRowI64, numTileSlicesI64); - upperBound = rewriter.create<arith::IndexCastOp>( - loc, rewriter.getIndexType(), upperBoundI64); + arith::MinSIOp::create(rewriter, loc, numRowI64, numTileSlicesI64); + upperBound = arith::IndexCastOp::create( + rewriter, loc, rewriter.getIndexType(), upperBoundI64); predicate = - rewriter.create<vector::CreateMaskOp>(loc, predicateType, maskDim1); + vector::CreateMaskOp::create(rewriter, loc, predicateType, maskDim1); } else { upperBound = numTileSlices; // No mask. Create an 'all true' predicate for the tile slice. - predicate = rewriter.create<arith::ConstantOp>( - loc, DenseElementsAttr::get(predicateType, true)); + predicate = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(predicateType, true)); } bool hasCarriedArgs = bool(initTile); - auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); - auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); - auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step, - hasCarriedArgs ? ValueRange{initTile} - : ValueRange{}); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); + auto forOp = + scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step, + hasCarriedArgs ? ValueRange{initTile} : ValueRange{}); rewriter.setInsertionPointToStart(forOp.getBody()); Value tileSliceIndex = forOp.getInductionVar(); @@ -118,7 +119,7 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices( assert(bool(nextTile) == hasCarriedArgs); if (nextTile) - rewriter.create<scf::YieldOp>(loc, nextTile); + scf::YieldOp::create(rewriter, loc, nextTile); return forOp; } @@ -194,9 +195,9 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> { // Initialize tile with zero to satisfy padding. Inactive cols will be // zeroed anyway since the loads use zeroing predication. For inactive // rows however, no load will occur so these need to be zeroed. 
- initTile = rewriter.create<arm_sme::ZeroOp>(loc, tileType); + initTile = arm_sme::ZeroOp::create(rewriter, loc, tileType); } else { - initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); + initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType); } // Create a loop to load the active tile slices from memory. @@ -207,9 +208,10 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> { Value currentTile) -> Value { // Create 'arm_sme.load_tile_slice' to load tile slice from memory // into tile. - return rewriter.create<arm_sme::LoadTileSliceOp>( - loc, tileType, tileLoadOp.getBase(), predicate, currentTile, - memrefIndices, tileSliceIndex, tileLoadOp.getLayout()); + return arm_sme::LoadTileSliceOp::create( + rewriter, loc, tileType, tileLoadOp.getBase(), predicate, + currentTile, memrefIndices, tileSliceIndex, + tileLoadOp.getLayout()); }); if (failed(forOp)) @@ -283,22 +285,22 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion auto numRows = createMaskOp.getOperands()[0]; auto numCols = createMaskOp.getOperands()[1]; - auto numColsI32 = rewriter.create<arith::IndexCastUIOp>( - loc, rewriter.getI32Type(), numCols); + auto numColsI32 = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI32Type(), numCols); - auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); + auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType); // Create a loop that loads each ZA tile slice from memory. - auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); - auto minTileSlices = rewriter.create<arith::ConstantIndexOp>( - loc, arm_sme::getSMETileSliceMinNumElts(tileElementType)); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); + auto minTileSlices = arith::ConstantIndexOp::create( + rewriter, loc, arm_sme::getSMETileSliceMinNumElts(tileElementType)); auto vscale = - rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType()); - auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); + vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType()); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); auto numTileSlices = - rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale); - auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, - step, ValueRange{initTile}); + arith::MulIOp::create(rewriter, loc, minTileSlices, vscale); + auto forOp = scf::ForOp::create(rewriter, loc, lowerBound, numTileSlices, + step, ValueRange{initTile}); rewriter.setInsertionPointToStart(forOp.getBody()); @@ -306,17 +308,18 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion auto currentTile = forOp.getRegionIterArg(0); // Combine masks. 
- auto rowIsActive = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows); - auto rowIsActiveI32 = rewriter.create<arith::ExtSIOp>( - loc, rewriter.getI32Type(), rowIsActive); - auto mask = rewriter.create<arith::AndIOp>(loc, rowIsActiveI32, numColsI32); - auto maskIndex = - rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), mask); + auto rowIsActive = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows); + auto rowIsActiveI32 = arith::ExtSIOp::create( + rewriter, loc, rewriter.getI32Type(), rowIsActive); + auto mask = + arith::AndIOp::create(rewriter, loc, rowIsActiveI32, numColsI32); + auto maskIndex = arith::IndexCastOp::create(rewriter, loc, + rewriter.getIndexType(), mask); auto predicateType = VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true); - auto maskOp1D = rewriter.create<vector::CreateMaskOp>( - loc, predicateType, maskIndex.getResult()); + auto maskOp1D = vector::CreateMaskOp::create(rewriter, loc, predicateType, + maskIndex.getResult()); auto memrefIndices = getMemrefIndices( tileLoadOp.getIndices(), tileLoadOp.getMemRefType().getRank(), @@ -324,17 +327,19 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion // Splat pad into 1-D vector matching type of tile slice. VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); - auto pad1DOp = rewriter.create<vector::SplatOp>(loc, tileSliceType, padOp); + auto pad1DOp = + vector::BroadcastOp::create(rewriter, loc, tileSliceType, padOp); - auto loadSlice = rewriter.create<vector::MaskedLoadOp>( - loc, tileSliceType, tileLoadOp.getBase(), memrefIndices, maskOp1D, - /*passthru=*/pad1DOp); + auto loadSlice = vector::MaskedLoadOp::create(rewriter, loc, tileSliceType, + tileLoadOp.getBase(), + memrefIndices, maskOp1D, + /*passthru=*/pad1DOp); // Create 'arm_sme.insert_tile_slice' to insert slice into tile. 
- auto insertSlice = rewriter.create<arm_sme::InsertTileSliceOp>( - loc, tileType, loadSlice->getResult(0), currentTile, tileSliceIndex, - tileLoadOp.getLayout()); - rewriter.create<scf::YieldOp>(loc, insertSlice.getResult()); + auto insertSlice = arm_sme::InsertTileSliceOp::create( + rewriter, loc, tileType, loadSlice->getResult(0), currentTile, + tileSliceIndex, tileLoadOp.getLayout()); + scf::YieldOp::create(rewriter, loc, insertSlice.getResult()); rewriter.setInsertionPointAfter(forOp); diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp index 94f7caa..29e6552 100644 --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -18,7 +18,6 @@ #include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -203,7 +202,7 @@ static void addAsyncRuntimeApiDeclarations(ModuleOp module) { auto addFuncDecl = [&](StringRef name, FunctionType type) { if (module.lookupSymbol(name)) return; - builder.create<func::FuncOp>(name, type).setPrivate(); + func::FuncOp::create(builder, name, type).setPrivate(); }; MLIRContext *ctx = module.getContext(); @@ -254,15 +253,15 @@ static void addResumeFunction(ModuleOp module) { auto voidTy = LLVM::LLVMVoidType::get(ctx); Type ptrType = AsyncAPI::opaquePointerType(ctx); - auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>( - kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType})); + auto resumeOp = LLVM::LLVMFuncOp::create( + moduleBuilder, kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType})); resumeOp.setPrivate(); auto *block = resumeOp.addEntryBlock(moduleBuilder); auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block); - blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0)); - blockBuilder.create<LLVM::ReturnOp>(ValueRange()); + LLVM::CoroResumeOp::create(blockBuilder, resumeOp.getArgument(0)); + LLVM::ReturnOp::create(blockBuilder, ValueRange()); } //===----------------------------------------------------------------------===// @@ -282,7 +281,8 @@ public: // in patterns for other dialects. auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) -> Value { - auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs); + auto cast = + UnrealizedConversionCastOp::create(builder, loc, type, inputs); return cast.getResult(0); }; @@ -343,8 +343,8 @@ public: // Constants for initializing coroutine frame. auto constZero = - rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0); - auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, ptrType); + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), 0); + auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, ptrType); // Get coroutine id: @llvm.coro.id. rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>( @@ -372,33 +372,33 @@ public: // Get coroutine frame size: @llvm.coro.size.i64. Value coroSize = - rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type()); + LLVM::CoroSizeOp::create(rewriter, loc, rewriter.getI64Type()); // Get coroutine frame alignment: @llvm.coro.align.i64. Value coroAlign = - rewriter.create<LLVM::CoroAlignOp>(loc, rewriter.getI64Type()); + LLVM::CoroAlignOp::create(rewriter, loc, rewriter.getI64Type()); // Round up the size to be multiple of the alignment. 
Since aligned_alloc // requires the size parameter to be an integral multiple of the alignment // parameter. auto makeConstant = [&](uint64_t c) { - return rewriter.create<LLVM::ConstantOp>(op->getLoc(), - rewriter.getI64Type(), c); + return LLVM::ConstantOp::create(rewriter, op->getLoc(), + rewriter.getI64Type(), c); }; - coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize, coroAlign); + coroSize = LLVM::AddOp::create(rewriter, op->getLoc(), coroSize, coroAlign); coroSize = - rewriter.create<LLVM::SubOp>(op->getLoc(), coroSize, makeConstant(1)); + LLVM::SubOp::create(rewriter, op->getLoc(), coroSize, makeConstant(1)); Value negCoroAlign = - rewriter.create<LLVM::SubOp>(op->getLoc(), makeConstant(0), coroAlign); + LLVM::SubOp::create(rewriter, op->getLoc(), makeConstant(0), coroAlign); coroSize = - rewriter.create<LLVM::AndOp>(op->getLoc(), coroSize, negCoroAlign); + LLVM::AndOp::create(rewriter, op->getLoc(), coroSize, negCoroAlign); // Allocate memory for the coroutine frame. auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn( rewriter, op->getParentOfType<ModuleOp>(), rewriter.getI64Type()); if (failed(allocFuncOp)) return failure(); - auto coroAlloc = rewriter.create<LLVM::CallOp>( - loc, allocFuncOp.value(), ValueRange{coroAlign, coroSize}); + auto coroAlloc = LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(), + ValueRange{coroAlign, coroSize}); // Begin a coroutine: @llvm.coro.begin. auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId(); @@ -427,7 +427,7 @@ public: // Get a pointer to the coroutine frame memory: @llvm.coro.free. auto coroMem = - rewriter.create<LLVM::CoroFreeOp>(loc, ptrType, adaptor.getOperands()); + LLVM::CoroFreeOp::create(rewriter, loc, ptrType, adaptor.getOperands()); // Free the memory. auto freeFuncOp = @@ -455,15 +455,15 @@ public: matchAndRewrite(CoroEndOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // We are not in the block that is part of the unwind sequence. - auto constFalse = rewriter.create<LLVM::ConstantOp>( - op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false)); - auto noneToken = rewriter.create<LLVM::NoneTokenOp>(op->getLoc()); + auto constFalse = + LLVM::ConstantOp::create(rewriter, op->getLoc(), rewriter.getI1Type(), + rewriter.getBoolAttr(false)); + auto noneToken = LLVM::NoneTokenOp::create(rewriter, op->getLoc()); // Mark the end of a coroutine: @llvm.coro.end. auto coroHdl = adaptor.getHandle(); - rewriter.create<LLVM::CoroEndOp>( - op->getLoc(), rewriter.getI1Type(), - ValueRange({coroHdl, constFalse, noneToken})); + LLVM::CoroEndOp::create(rewriter, op->getLoc(), rewriter.getI1Type(), + ValueRange({coroHdl, constFalse, noneToken})); rewriter.eraseOp(op); return success(); @@ -534,13 +534,13 @@ public: auto loc = op->getLoc(); // This is not a final suspension point. - auto constFalse = rewriter.create<LLVM::ConstantOp>( - loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); + auto constFalse = LLVM::ConstantOp::create( + rewriter, loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); // Suspend a coroutine: @llvm.coro.suspend auto coroState = adaptor.getState(); - auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>( - loc, i8, ValueRange({coroState, constFalse})); + auto coroSuspend = LLVM::CoroSuspendOp::create( + rewriter, loc, i8, ValueRange({coroState, constFalse})); // Cast return code to i32.
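The Add/Sub/And sequence in the hunk above implements the usual power-of-two round-up, size' = (size + align - 1) & -align, with -align formed as sub 0, coroAlign because the coroutine alignment is only known at runtime. A minimal standalone sketch of the same arithmetic; the helper name and the example values are illustrative, not taken from the patch:

#include <cassert>
#include <cstdint>

// Round `size` up to the next multiple of `align`. `align` must be a
// power of two, so `0 - align` (two's complement negation) is a mask
// that clears the low log2(align) bits.
static uint64_t roundUpToAlignment(uint64_t size, uint64_t align) {
  return (size + align - 1) & (0 - align);
}

int main() {
  assert(roundUpToAlignment(40, 16) == 48); // 40 rounds up to 3 * 16
  assert(roundUpToAlignment(64, 16) == 64); // exact multiples are unchanged
  return 0;
}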
@@ -551,7 +551,7 @@ public: llvm::SmallVector<Block *, 2> caseDest = {op.getResumeDest(), op.getCleanupDest()}; rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( - op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()), + op, LLVM::SExtOp::create(rewriter, loc, i32, coroSuspend.getResult()), /*defaultDestination=*/op.getSuspendDest(), /*defaultOperands=*/ValueRange(), /*caseValues=*/caseValues, @@ -602,11 +602,11 @@ public: // %Size = getelementptr %T* null, int 1 // %SizeI = ptrtoint %T* %Size to i64 - auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, storagePtrType); + auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, storagePtrType); auto gep = - rewriter.create<LLVM::GEPOp>(loc, storagePtrType, storedType, - nullPtr, ArrayRef<LLVM::GEPArg>{1}); - return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep); + LLVM::GEPOp::create(rewriter, loc, storagePtrType, storedType, + nullPtr, ArrayRef<LLVM::GEPArg>{1}); + return LLVM::PtrToIntOp::create(rewriter, loc, i64, gep); }; rewriter.replaceOpWithNewOp<func::CallOp>(op, kCreateValue, resultType, @@ -739,8 +739,8 @@ public: .Case<ValueType>([](Type) { return kAwaitValue; }) .Case<GroupType>([](Type) { return kAwaitGroup; }); - rewriter.create<func::CallOp>(op->getLoc(), apiFuncName, TypeRange(), - adaptor.getOperands()); + func::CallOp::create(rewriter, op->getLoc(), apiFuncName, TypeRange(), + adaptor.getOperands()); rewriter.eraseOp(op); return success(); @@ -772,13 +772,12 @@ public: // A pointer to coroutine resume intrinsic wrapper. addResumeFunction(op->getParentOfType<ModuleOp>()); - auto resumePtr = rewriter.create<LLVM::AddressOfOp>( - op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()), - kResume); + auto resumePtr = LLVM::AddressOfOp::create( + rewriter, op->getLoc(), + AsyncAPI::opaquePointerType(rewriter.getContext()), kResume); - rewriter.create<func::CallOp>( - op->getLoc(), apiFuncName, TypeRange(), - ValueRange({operand, handle, resumePtr.getRes()})); + func::CallOp::create(rewriter, op->getLoc(), apiFuncName, TypeRange(), + ValueRange({operand, handle, resumePtr.getRes()})); rewriter.eraseOp(op); return success(); @@ -801,9 +800,9 @@ public: ConversionPatternRewriter &rewriter) const override { // A pointer to coroutine resume intrinsic wrapper. addResumeFunction(op->getParentOfType<ModuleOp>()); - auto resumePtr = rewriter.create<LLVM::AddressOfOp>( - op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()), - kResume); + auto resumePtr = LLVM::AddressOfOp::create( + rewriter, op->getLoc(), + AsyncAPI::opaquePointerType(rewriter.getContext()), kResume); // Call async runtime API to execute a coroutine in the managed thread. auto coroHdl = adaptor.getHandle(); @@ -832,8 +831,8 @@ public: // Get a pointer to the async value storage from the runtime. auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext()); auto storage = adaptor.getStorage(); - auto storagePtr = rewriter.create<func::CallOp>( - loc, kGetValueStorage, TypeRange(ptrType), storage); + auto storagePtr = func::CallOp::create(rewriter, loc, kGetValueStorage, + TypeRange(ptrType), storage); // Cast from i8* to the LLVM pointer type. auto valueType = op.getValue().getType(); @@ -845,7 +844,7 @@ public: Value castedStoragePtr = storagePtr.getResult(0); // Store the yielded value into the async value storage. auto value = adaptor.getValue(); - rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr); + LLVM::StoreOp::create(rewriter, loc, value, castedStoragePtr); // Erase the original runtime store operation. 
rewriter.eraseOp(op); @@ -872,8 +871,8 @@ public: // Get a pointer to the async value storage from the runtime. auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext()); auto storage = adaptor.getStorage(); - auto storagePtr = rewriter.create<func::CallOp>( - loc, kGetValueStorage, TypeRange(ptrType), storage); + auto storagePtr = func::CallOp::create(rewriter, loc, kGetValueStorage, + TypeRange(ptrType), storage); // Cast from i8* to the LLVM pointer type. auto valueType = op.getResult().getType(); @@ -960,9 +959,9 @@ public: LogicalResult matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto count = rewriter.create<arith::ConstantOp>( - op->getLoc(), rewriter.getI64Type(), - rewriter.getI64IntegerAttr(op.getCount())); + auto count = + arith::ConstantOp::create(rewriter, op->getLoc(), rewriter.getI64Type(), + rewriter.getI64IntegerAttr(op.getCount())); auto operand = adaptor.getOperand(); rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), apiFunctionName, diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp index b9991f3..3edcbb8 100644 --- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp +++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp @@ -47,30 +47,29 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> { if (auto unrankedType = dyn_cast<UnrankedMemRefType>(type)) { // Constants - Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); - Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value one = arith::ConstantIndexOp::create(rewriter, loc, 1); // Dynamically evaluate the size and shape of the unranked memref - Value rank = rewriter.create<memref::RankOp>(loc, op.getInput()); + Value rank = memref::RankOp::create(rewriter, loc, op.getInput()); MemRefType allocType = MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()); - Value shape = rewriter.create<memref::AllocaOp>(loc, allocType, rank); + Value shape = memref::AllocaOp::create(rewriter, loc, allocType, rank); // Create a loop to query dimension sizes, store them as a shape, and // compute the total size of the memref auto loopBody = [&](OpBuilder &builder, Location loc, Value i, ValueRange args) { auto acc = args.front(); - auto dim = rewriter.create<memref::DimOp>(loc, op.getInput(), i); + auto dim = memref::DimOp::create(rewriter, loc, op.getInput(), i); - rewriter.create<memref::StoreOp>(loc, dim, shape, i); - acc = rewriter.create<arith::MulIOp>(loc, acc, dim); + memref::StoreOp::create(rewriter, loc, dim, shape, i); + acc = arith::MulIOp::create(rewriter, loc, acc, dim); - rewriter.create<scf::YieldOp>(loc, acc); + scf::YieldOp::create(rewriter, loc, acc); }; - auto size = rewriter - .create<scf::ForOp>(loc, zero, rank, one, ValueRange(one), - loopBody) + auto size = scf::ForOp::create(rewriter, loc, zero, rank, one, + ValueRange(one), loopBody) .getResult(0); MemRefType memrefType = MemRefType::get({ShapedType::kDynamic}, @@ -78,9 +77,9 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> { // Allocate new memref with 1D dynamic shape, then reshape into the // shape of the original unranked memref - alloc = rewriter.create<memref::AllocOp>(loc, memrefType, size); + alloc = memref::AllocOp::create(rewriter, loc, memrefType, size); alloc 
= - rewriter.create<memref::ReshapeOp>(loc, unrankedType, alloc, shape); + memref::ReshapeOp::create(rewriter, loc, unrankedType, alloc, shape); } else { MemRefType memrefType = cast<MemRefType>(type); MemRefLayoutAttrInterface layout; @@ -103,14 +102,15 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> { } // Allocate a memref with identity layout. - alloc = rewriter.create<memref::AllocOp>(loc, allocType, dynamicOperands); + alloc = + memref::AllocOp::create(rewriter, loc, allocType, dynamicOperands); // Cast the allocation to the specified type if needed. if (memrefType != allocType) alloc = - rewriter.create<memref::CastOp>(op->getLoc(), memrefType, alloc); + memref::CastOp::create(rewriter, op->getLoc(), memrefType, alloc); } - rewriter.create<memref::CopyOp>(loc, op.getInput(), alloc); + memref::CopyOp::create(rewriter, loc, op.getInput(), alloc); rewriter.replaceOp(op, alloc); return success(); } diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index f84375b..785cb82 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -43,7 +43,7 @@ add_subdirectory(MathToSPIRV) add_subdirectory(MemRefToEmitC) add_subdirectory(MemRefToLLVM) add_subdirectory(MemRefToSPIRV) -add_subdirectory(MeshToMPI) +add_subdirectory(ShardToMPI) add_subdirectory(MPIToLLVM) add_subdirectory(NVGPUToNVVM) add_subdirectory(NVVMToLLVM) diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp index 6f0fc29..35ad99c 100644 --- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp +++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp @@ -64,10 +64,46 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns( patterns.getContext(), "__ocml_cabs_f32"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>( patterns.getContext(), "__ocml_cabs_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float32Type>>( + patterns.getContext(), "__ocml_carg_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float64Type>>( + patterns.getContext(), "__ocml_carg_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float32Type>>( + patterns.getContext(), "__ocml_conj_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float64Type>>( + patterns.getContext(), "__ocml_conj_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float32Type>>( + patterns.getContext(), "__ocml_ccos_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float64Type>>( + patterns.getContext(), "__ocml_ccos_f64"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float32Type>>( patterns.getContext(), "__ocml_cexp_f32"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float64Type>>( patterns.getContext(), "__ocml_cexp_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float32Type>>( + patterns.getContext(), "__ocml_clog_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float64Type>>( + patterns.getContext(), "__ocml_clog_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float32Type>>( + patterns.getContext(), "__ocml_cpow_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float64Type>>( + patterns.getContext(), "__ocml_cpow_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, 
Float32Type>>( + patterns.getContext(), "__ocml_csin_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float64Type>>( + patterns.getContext(), "__ocml_csin_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float32Type>>( + patterns.getContext(), "__ocml_csqrt_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float64Type>>( + patterns.getContext(), "__ocml_csqrt_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float32Type>>( + patterns.getContext(), "__ocml_ctan_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float64Type>>( + patterns.getContext(), "__ocml_ctan_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float32Type>>( + patterns.getContext(), "__ocml_ctanh_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float64Type>>( + patterns.getContext(), "__ocml_ctanh_f64"); } namespace { @@ -86,7 +122,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() { ConversionTarget target(getContext()); target.addLegalDialect<func::FuncDialect>(); - target.addIllegalOp<complex::AbsOp, complex::ExpOp>(); + target.addIllegalOp<complex::AbsOp, complex::AngleOp, complex::ConjOp, + complex::CosOp, complex::ExpOp, complex::LogOp, + complex::PowOp, complex::SinOp, complex::SqrtOp, + complex::TanOp, complex::TanhOp>(); if (failed(applyPartialConversion(op, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index eeff8a9..5ad514d 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -12,7 +12,6 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" #include <type_traits> diff --git a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp index c8311eb..5ac838c 100644 --- a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp +++ b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp @@ -144,12 +144,11 @@ ControlFlowToSCFTransformation::createUnreachableTerminator(Location loc, return emitError(loc, "Cannot create unreachable terminator for '") << parentOp->getName() << "'"; - return builder - .create<func::ReturnOp>( - loc, llvm::map_to_vector(funcOp.getResultTypes(), - [&](Type type) { - return getUndefValue(loc, builder, type); - })) + return func::ReturnOp::create( + builder, loc, + llvm::map_to_vector( + funcOp.getResultTypes(), + [&](Type type) { return getUndefValue(loc, builder, type); })) .getOperation(); } diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp index 03f4bf4..56b6181 100644 --- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp +++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp @@ -43,6 +43,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() { SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); // TODO: We should also take care of block argument type 
conversion. diff --git a/mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp b/mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp index c9b1dc1..ee6d7d5 100644 --- a/mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp +++ b/mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp @@ -9,8 +9,6 @@ #include "mlir/Conversion/ConvertToEmitC/ConvertToEmitCPass.h" #include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h" -#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" diff --git a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp index 252245d..c70b5f0 100644 --- a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp +++ b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp @@ -9,7 +9,6 @@ #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" -#include "llvm/ADT/DenseSet.h" using namespace mlir; diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp index 8ed9f65..c0439a4 100644 --- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp +++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp @@ -42,6 +42,7 @@ void ConvertFuncToSPIRVPass::runOnOperation() { SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); RewritePatternSet patterns(context); diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp index 63eb6c58..3cfbd89 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -579,8 +579,8 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder, auto function = [&] { if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName)) return function; - return OpBuilder::atBlockEnd(module.getBody()) - .create<LLVM::LLVMFuncOp>(loc, functionName, functionType); + auto builder = OpBuilder::atBlockEnd(module.getBody()); + return LLVM::LLVMFuncOp::create(builder, loc, functionName, functionType); }(); return LLVM::CallOp::create(builder, loc, function, arguments); } diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index a19194e..3545acb 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -385,6 +385,14 @@ LogicalResult GPUModuleConversion::matchAndRewrite( if (auto attr = moduleOp->getAttrOfType<spirv::TargetEnvAttr>( spirv::getTargetEnvAttrName())) spvModule->setAttr(spirv::getTargetEnvAttrName(), attr); + if (ArrayAttr targets = moduleOp.getTargetsAttr()) { + for (Attribute targetAttr : targets) + if (auto spirvTargetEnvAttr = + dyn_cast<spirv::TargetEnvAttr>(targetAttr)) { + spvModule->setAttr(spirv::getTargetEnvAttrName(), spirvTargetEnvAttr); + break; + } + } rewriter.eraseOp(moduleOp); return success(); @@ -507,25 +515,27 @@ LogicalResult GPURotateConversion::matchAndRewrite( getTypeConverter<SPIRVTypeConverter>()->getTargetEnv(); unsigned subgroupSize = targetEnv.getAttr().getResourceLimits().getSubgroupSize(); - IntegerAttr widthAttr; - if 
(!matchPattern(rotateOp.getWidth(), m_Constant(&widthAttr)) || - widthAttr.getValue().getZExtValue() > subgroupSize) + unsigned width = rotateOp.getWidth(); + if (width > subgroupSize) return rewriter.notifyMatchFailure( - rotateOp, - "rotate width is not a constant or larger than target subgroup size"); + rotateOp, "rotate width is larger than target subgroup size"); Location loc = rotateOp.getLoc(); auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup); + Value offsetVal = + arith::ConstantOp::create(rewriter, loc, adaptor.getOffsetAttr()); + Value widthVal = + arith::ConstantOp::create(rewriter, loc, adaptor.getWidthAttr()); Value rotateResult = spirv::GroupNonUniformRotateKHROp::create( - rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset(), - adaptor.getWidth()); + rewriter, loc, scope, adaptor.getValue(), offsetVal, widthVal); Value validVal; - if (widthAttr.getValue().getZExtValue() == subgroupSize) { + if (width == subgroupSize) { validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter); } else { + IntegerAttr widthAttr = adaptor.getWidthAttr(); Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr); validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult, - laneId, adaptor.getWidth()); + laneId, widthVal); } rewriter.replaceOp(rotateOp, {rotateResult, validVal}); @@ -559,8 +569,8 @@ static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc, builder, loc, builder.getI32Type(), builder.getIntegerAttr(builder.getI32Type(), *clusterSize)); - return builder - .create<NonUniformOp>(loc, type, scope, groupOp, arg, clusterSizeValue) + return NonUniformOp::create(builder, loc, type, scope, groupOp, arg, + clusterSizeValue) .getResult(); } diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp index a344f88..5eab057 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp @@ -48,9 +48,36 @@ struct GPUToSPIRVPass final : impl::ConvertGPUToSPIRVBase<GPUToSPIRVPass> { void runOnOperation() override; private: + /// Queries the target environment from 'targets' attribute of the given + /// `moduleOp`. + spirv::TargetEnvAttr lookupTargetEnvInTargets(gpu::GPUModuleOp moduleOp); + + /// Queries the target environment from 'targets' attribute of the given + /// `moduleOp` or returns target environment as returned by + /// `spirv::lookupTargetEnvOrDefault` if not provided by 'targets'. 
+ spirv::TargetEnvAttr lookupTargetEnvOrDefault(gpu::GPUModuleOp moduleOp); bool mapMemorySpace; }; +spirv::TargetEnvAttr +GPUToSPIRVPass::lookupTargetEnvInTargets(gpu::GPUModuleOp moduleOp) { + if (ArrayAttr targets = moduleOp.getTargetsAttr()) { + for (Attribute targetAttr : targets) + if (auto spirvTargetEnvAttr = dyn_cast<spirv::TargetEnvAttr>(targetAttr)) + return spirvTargetEnvAttr; + } + + return {}; +} + +spirv::TargetEnvAttr +GPUToSPIRVPass::lookupTargetEnvOrDefault(gpu::GPUModuleOp moduleOp) { + if (spirv::TargetEnvAttr targetEnvAttr = lookupTargetEnvInTargets(moduleOp)) + return targetEnvAttr; + + return spirv::lookupTargetEnvOrDefault(moduleOp); +} + void GPUToSPIRVPass::runOnOperation() { MLIRContext *context = &getContext(); ModuleOp module = getOperation(); @@ -58,9 +85,8 @@ void GPUToSPIRVPass::runOnOperation() { SmallVector<Operation *, 1> gpuModules; OpBuilder builder(context); - auto targetEnvSupportsKernelCapability = [](gpu::GPUModuleOp moduleOp) { - Operation *gpuModule = moduleOp.getOperation(); - auto targetAttr = spirv::lookupTargetEnvOrDefault(gpuModule); + auto targetEnvSupportsKernelCapability = [this](gpu::GPUModuleOp moduleOp) { + auto targetAttr = lookupTargetEnvOrDefault(moduleOp); spirv::TargetEnv targetEnv(targetAttr); return targetEnv.allows(spirv::Capability::Kernel); }; @@ -86,7 +112,7 @@ void GPUToSPIRVPass::runOnOperation() { // TargetEnv attributes. for (Operation *gpuModule : gpuModules) { spirv::TargetEnvAttr targetAttr = - spirv::lookupTargetEnvOrDefault(gpuModule); + lookupTargetEnvOrDefault(cast<gpu::GPUModuleOp>(gpuModule)); // Map MemRef memory space to SPIR-V storage class first if requested. if (mapMemorySpace) { diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index ecd5b63..2568044 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -272,14 +272,13 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( // Allocate memory, copy, and free the source if necessary. Value memory = - toDynamic - ? builder - .create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize) - .getResult() - : LLVM::AllocaOp::create(builder, loc, getPtrType(), - IntegerType::get(getContext(), 8), - allocationSize, - /*alignment=*/0); + toDynamic ? 
LLVM::CallOp::create(builder, loc, mallocFunc.value(), + allocationSize) + .getResult() + : LLVM::AllocaOp::create(builder, loc, getPtrType(), + IntegerType::get(getContext(), 8), + allocationSize, + /*alignment=*/0); Value source = desc.memRefDescPtr(builder, loc); LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false); if (!toDynamic) diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp index 5b68eb8..e5496e5 100644 --- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp +++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp @@ -35,7 +35,7 @@ static Op getOrDefineGlobal(ModuleOp &moduleOp, const Location loc, if (!(ret = moduleOp.lookupSymbol<Op>(name))) { ConversionPatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(moduleOp.getBody()); - ret = rewriter.template create<Op>(loc, std::forward<Args>(args)...); + ret = Op::create(rewriter, loc, std::forward<Args>(args)...); } return ret; } diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp index 08a4566..cde2340 100644 --- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp +++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp @@ -17,13 +17,12 @@ #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" namespace mlir { #define GEN_PASS_DEF_CONVERTMATHTOFUNCS @@ -33,7 +32,6 @@ namespace mlir { using namespace mlir; #define DEBUG_TYPE "math-to-funcs" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") namespace { // Pattern to convert vector operations to scalar operations. 
@@ -654,10 +652,8 @@ FPowIOpLowering::matchAndRewrite(math::FPowIOp op, /// } static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) { if (!isa<IntegerType>(elementType)) { - LLVM_DEBUG({ - DBGS() << "non-integer element type for CtlzFunc; type was: "; - elementType.print(llvm::dbgs()); - }); + LDBG() << "non-integer element type for CtlzFunc; type was: " + << elementType; llvm_unreachable("non-integer element type"); } int64_t bitWidth = elementType.getIntOrFloatBitWidth(); @@ -699,7 +695,8 @@ static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) { scf::IfOp ifOp = scf::IfOp::create(builder, elementType, inputEqZero, /*addThenBlock=*/true, /*addElseBlock=*/true); - ifOp.getThenBodyBuilder().create<scf::YieldOp>(loc, bitWidthValue); + auto thenBuilder = ifOp.getThenBodyBuilder(); + scf::YieldOp::create(thenBuilder, loc, bitWidthValue); auto elseBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, &ifOp.getElseRegion().front()); diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp index 93d8b49..df219f3 100644 --- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp +++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/MathToROCDL/MathToROCDL.h" +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -21,7 +22,6 @@ #include "../GPUCommon/GPUOpsLowering.h" #include "../GPUCommon/OpToFuncCallLowering.h" -#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" namespace mlir { #define GEN_PASS_DEF_CONVERTMATHTOROCDL @@ -31,7 +31,6 @@ namespace mlir { using namespace mlir; #define DEBUG_TYPE "math-to-rocdl" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") template <typename OpTy> static void populateOpPatterns(const LLVMTypeConverter &converter, diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp index a877ad2..1787e0a 100644 --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -488,7 +488,12 @@ namespace mlir { void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { // Core patterns - patterns.add<CopySignPattern>(typeConverter, patterns.getContext()); + patterns + .add<CopySignPattern, + CheckedElementwiseOpPattern<math::IsInfOp, spirv::IsInfOp>, + CheckedElementwiseOpPattern<math::IsNaNOp, spirv::IsNanOp>, + CheckedElementwiseOpPattern<math::IsFiniteOp, spirv::IsFiniteOp>>( + typeConverter, patterns.getContext()); // GLSL patterns patterns diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index e882845..6bd0e2d 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -19,10 +19,18 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeRange.h" +#include "mlir/IR/Value.h" #include "mlir/Transforms/DialectConversion.h" +#include <cstdint> using namespace mlir; +static bool isMemRefTypeLegalForEmitC(MemRefType memRefType) { + return memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() && + memRefType.getRank() != 0 && + !llvm::is_contained(memRefType.getShape(), 0); +} + namespace { /// 
Implement the interface to convert MemRef to EmitC. struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface { @@ -89,6 +97,68 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) { return resultTy; } +struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + Location loc = allocOp.getLoc(); + MemRefType memrefType = allocOp.getType(); + if (!isMemRefTypeLegalForEmitC(memrefType)) { + return rewriter.notifyMatchFailure( + loc, "incompatible memref type for EmitC conversion"); + } + + Type sizeTType = emitc::SizeTType::get(rewriter.getContext()); + Type elementType = memrefType.getElementType(); + IndexType indexType = rewriter.getIndexType(); + emitc::CallOpaqueOp sizeofElementOp = rewriter.create<emitc::CallOpaqueOp>( + loc, sizeTType, rewriter.getStringAttr("sizeof"), ValueRange{}, + ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)})); + + int64_t numElements = 1; + for (int64_t dimSize : memrefType.getShape()) { + numElements *= dimSize; + } + Value numElementsValue = rewriter.create<emitc::ConstantOp>( + loc, indexType, rewriter.getIndexAttr(numElements)); + + Value totalSizeBytes = rewriter.create<emitc::MulOp>( + loc, sizeTType, sizeofElementOp.getResult(0), numElementsValue); + + emitc::CallOpaqueOp allocCall; + StringAttr allocFunctionName; + Value alignmentValue; + SmallVector<Value, 2> argsVec; + if (allocOp.getAlignment()) { + allocFunctionName = rewriter.getStringAttr(alignedAllocFunctionName); + alignmentValue = rewriter.create<emitc::ConstantOp>( + loc, sizeTType, + rewriter.getIntegerAttr(indexType, + allocOp.getAlignment().value_or(0))); + argsVec.push_back(alignmentValue); + } else { + allocFunctionName = rewriter.getStringAttr(mallocFunctionName); + } + + argsVec.push_back(totalSizeBytes); + ValueRange args(argsVec); + + allocCall = rewriter.create<emitc::CallOpaqueOp>( + loc, + emitc::PointerType::get( + emitc::OpaqueType::get(rewriter.getContext(), "void")), + allocFunctionName, args); + + emitc::PointerType targetPointerType = emitc::PointerType::get(elementType); + emitc::CastOp castOp = rewriter.create<emitc::CastOp>( + loc, targetPointerType, allocCall.getResult(0)); + + rewriter.replaceOp(allocOp, castOp); + return success(); + } +}; + struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> { using OpConversionPattern::OpConversionPattern; @@ -223,9 +293,7 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> { void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { typeConverter.addConversion( [&](MemRefType memRefType) -> std::optional<Type> { - if (!memRefType.hasStaticShape() || - !memRefType.getLayout().isIdentity() || memRefType.getRank() == 0 || - llvm::is_contained(memRefType.getShape(), 0)) { + if (!isMemRefTypeLegalForEmitC(memRefType)) { return {}; } Type convertedElementType = @@ -252,6 +320,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { void mlir::populateMemRefToEmitCConversionPatterns( RewritePatternSet &patterns, const TypeConverter &converter) { - patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad, - ConvertStore>(converter, patterns.getContext()); + patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal, + ConvertLoad, ConvertStore>(converter, 
patterns.getContext()); } diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp index cf25c09..e78dd76 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp @@ -15,6 +15,7 @@ #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Attributes.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -28,9 +29,11 @@ using namespace mlir; namespace { struct ConvertMemRefToEmitCPass : public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> { + using Base::Base; void runOnOperation() override { TypeConverter converter; - + ConvertMemRefToEmitCOptions options; + options.lowerToCpp = this->lowerToCpp; // Fallback for other types. converter.addConversion([](Type type) -> std::optional<Type> { if (!emitc::isSupportedEmitCType(type)) @@ -50,6 +53,37 @@ struct ConvertMemRefToEmitCPass if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); + + mlir::ModuleOp module = getOperation(); + module.walk([&](mlir::emitc::CallOpaqueOp callOp) { + if (callOp.getCallee() != alignedAllocFunctionName && + callOp.getCallee() != mallocFunctionName) { + return mlir::WalkResult::advance(); + } + + for (auto &op : *module.getBody()) { + emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op); + if (!includeOp) { + continue; + } + if (includeOp.getIsStandardInclude() && + ((options.lowerToCpp && + includeOp.getInclude() == cppStandardLibraryHeader) || + (!options.lowerToCpp && + includeOp.getInclude() == cStandardLibraryHeader))) { + return mlir::WalkResult::interrupt(); + } + } + + mlir::OpBuilder builder(module.getBody(), module.getBody()->begin()); + StringAttr includeAttr = + builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader + : cStandardLibraryHeader); + builder.create<mlir::emitc::IncludeOp>( + module.getLoc(), includeAttr, + /*is_standard_include=*/builder.getUnitAttr()); + return mlir::WalkResult::interrupt(); + }); } }; } // namespace diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 53a1912..dc2035b 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -24,11 +24,12 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/IRMapping.h" #include "mlir/Pass/Pass.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/MathExtras.h" + #include <optional> #define DEBUG_TYPE "memref-to-llvm" -#define DBGS() llvm::dbgs() << "[" DEBUG_TYPE "] " namespace mlir { #define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS @@ -575,8 +576,8 @@ private: Value sizePtr = LLVM::GEPOp::create(rewriter, loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr, idxPlusOne); - return rewriter - .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr) + return LLVM::LoadOp::create(rewriter, loc, + getTypeConverter()->getIndexType(), sizePtr) .getResult(); } @@ -1848,8 +1849,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) { return LLVM::AtomicBinOp::xchg; case arith::AtomicRMWKind::maximumf: // TODO: remove this by end of 2025. 
- LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw maximumf changed " - "from fmax to fmaximum, expect more NaNs"); + LDBG() << "the lowering of memref.atomicrmw maximumf changed " + "from fmax to fmaximum, expect more NaNs"; return LLVM::AtomicBinOp::fmaximum; case arith::AtomicRMWKind::maxnumf: return LLVM::AtomicBinOp::fmax; @@ -1859,8 +1860,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) { return LLVM::AtomicBinOp::umax; case arith::AtomicRMWKind::minimumf: // TODO: remove this by end of 2025. - LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw minimum changed " - "from fmin to fminimum, expect more NaNs"); + LDBG() << "the lowering of memref.atomicrmw minimum changed " + "from fmin to fminimum, expect more NaNs"; return LLVM::AtomicBinOp::fminimum; case arith::AtomicRMWKind::minnumf: return LLVM::AtomicBinOp::fmin; diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 905287e1..2549a9c 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -21,19 +21,17 @@ #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include <optional> #define DEBUG_TYPE "nvgpu-to-nvvm" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define DBGSE() (llvm::dbgs()) namespace mlir { #define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS @@ -1106,13 +1104,13 @@ struct NVGPUGenerateWarpgroupDescriptorLowering // // [0,14) start_address dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit); - LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: " - << "leading_off:" << leadDimVal << "\t" - << "stride_off :" << strideDimVal << "\t" - << "base_offset:" << offsetVal << "\t" - << "layout_type:" << swizzle << " (" - << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind) - << ")\n start_addr : " << baseAddr << "\n"); + LDBG() << "Generating warpgroup.descriptor: " + << "leading_off:" << leadDimVal << "\t" + << "stride_off :" << strideDimVal << "\t" + << "base_offset:" << offsetVal << "\t" + << "layout_type:" << swizzle << " (" + << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind) + << ")\n start_addr : " << baseAddr; rewriter.replaceOp(op, dsc); return success(); @@ -1282,8 +1280,8 @@ struct NVGPUWarpgroupMmaOpLowering } else { llvm_unreachable("msg: not supported K shape"); } - LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM - << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n"); + LDBG() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM + << ", n = " << wgmmaN << ", k = " << wgmmaK << "]"; } /// Generates WGMMATypesAttr from MLIR Type @@ -1367,9 +1365,9 @@ struct NVGPUWarpgroupMmaOpLowering int tileShapeA = matrixTypeA.getDimSize(1); int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte; incrementVal = incrementVal >> exclude4LSB; - LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k - << "] [wgmma descriptors] Descriptor A + " - << incrementVal << " | \t "); + LDBG() << "\t\t[m: " << i << " n: " << j << " k: " << k + << "] [wgmma descriptors] Descriptor A + " << incrementVal + << " | \t "; if (!incrementVal) return desc; return 
makeAdd(desc, makeI64Const(b, incrementVal)); @@ -1392,7 +1390,7 @@ struct NVGPUWarpgroupMmaOpLowering int byte = elemB.getIntOrFloatBitWidth() / 8; int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte; incrementVal = incrementVal >> exclude4LSB; - LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n"); + LDBG() << "Descriptor B + " << incrementVal; if (!incrementVal) return desc; return makeAdd(desc, makeI64Const(b, incrementVal)); @@ -1401,15 +1399,14 @@ struct NVGPUWarpgroupMmaOpLowering /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix /// descriptors and arranges them based on induction variables: i, j, and k. Value generateWgmma(int i, int j, int k, Value matrixC) { - LLVM_DEBUG(DBGS() << "\t wgmma." - << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK - << "(A[" << (iterationM * wgmmaM) << ":" - << (iterationM * wgmmaM) + wgmmaM << "][" - << (iterationK * wgmmaK) << ":" - << (iterationK * wgmmaK + wgmmaK) << "] * " - << " B[" << (iterationK * wgmmaK) << ":" - << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" - << wgmmaN << "])\n"); + LDBG() << "\t wgmma." + << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK << "(A[" + << (iterationM * wgmmaM) << ":" << (iterationM * wgmmaM) + wgmmaM + << "][" << (iterationK * wgmmaK) << ":" + << (iterationK * wgmmaK + wgmmaK) << "] * " + << " B[" << (iterationK * wgmmaK) << ":" + << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" << wgmmaN + << "])"; Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k); Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k); @@ -1468,9 +1465,9 @@ struct NVGPUWarpgroupMmaOpLowering totalM = op.getDescriptorA().getType().getTensor().getDimSize(0); totalN = op.getDescriptorB().getType().getTensor().getDimSize(1); totalK = op.getDescriptorA().getType().getTensor().getDimSize(1); - LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN - << "] += A[" << totalM << "][" << totalK << "] * B[" - << totalK << "][" << totalN << "] ---===\n"); + LDBG() << "===--- GEMM D[" << totalM << "][" << totalN << "] += A[" + << totalM << "][" << totalK << "] * B[" << totalK << "][" << totalN + << "] ---==="; // Find the shape for one wgmma instruction findWgmmaShape( diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp index 662ee9e..91788f9 100644 --- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp +++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp @@ -25,11 +25,10 @@ #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #define DEBUG_TYPE "nvvm-to-llvm" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define DBGSNL() (llvm::dbgs() << "\n") namespace mlir { #define GEN_PASS_DEF_CONVERTNVVMTOLLVMPASS @@ -52,17 +51,17 @@ struct PtxLowering LogicalResult matchAndRewrite(BasicPtxBuilderInterface op, PatternRewriter &rewriter) const override { if (op.hasIntrinsic()) { - LLVM_DEBUG(DBGS() << "Ptx Builder does not lower \n\t" << op << "\n"); + LDBG() << "Ptx Builder does not lower \n\t" << op; return failure(); } SmallVector<std::pair<Value, PTXRegisterMod>> asmValues; - LLVM_DEBUG(DBGS() << op.getPtx() << "\n"); + LDBG() << op.getPtx(); PtxBuilder generator(op, rewriter); op.getAsmValues(rewriter, asmValues); for (auto &[asmValue, modifier] : asmValues) { - LLVM_DEBUG(DBGSNL() << asmValue << "\t Modifier : " << &modifier); + LDBG() << asmValue << "\t Modifier : " << &modifier; 
generator.insertValue(asmValue, modifier); } diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp index 3e434ea..5bd1d49 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -49,7 +49,7 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList, assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type"); predList.emplace_back(pos, builder.getIsNotNull()); - if (auto attr = dyn_cast<pdl::AttributeOp>(val.getDefiningOp())) { + if (auto attr = val.getDefiningOp<pdl::AttributeOp>()) { // If the attribute has a type or value, add a constraint. if (Value type = attr.getValueType()) getTreePredicates(predList, type, builder, inputs, builder.getType(pos)); diff --git a/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp b/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp index e1a9fa59..2d9c661f 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp @@ -14,9 +14,7 @@ #include "RootOrdering.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SmallVector.h" -#include <queue> #include <utility> using namespace mlir; diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp index 240491a..807be7e 100644 --- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp +++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp @@ -582,6 +582,7 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp, // block. This should be reconsidered if we allow break/continue in SCF. rewriter.setInsertionPointToEnd(before); auto condOp = cast<ConditionOp>(before->getTerminator()); + SmallVector<Value> args = llvm::to_vector(condOp.getArgs()); rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(), after, condOp.getArgs(), continuation, ValueRange()); @@ -593,7 +594,7 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp, // Replace the op with values "yielded" from the "before" region, which are // visible by dominance. 
- rewriter.replaceOp(whileOp, condOp.getArgs()); + rewriter.replaceOp(whileOp, args); return success(); } diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp index f191f35..badd2f6 100644 --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp @@ -25,9 +25,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/IRMapping.h" #include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/Passes.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/Support/Debug.h" #include <optional> diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index aae3271..9b61540 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -1493,11 +1493,11 @@ public: Value extended; if (op2TypeWidth < dstTypeWidth) { if (isUnsignedIntegerOrVector(op2Type)) { - extended = rewriter.template create<LLVM::ZExtOp>( - loc, dstType, adaptor.getOperand2()); + extended = + LLVM::ZExtOp::create(rewriter, loc, dstType, adaptor.getOperand2()); } else { - extended = rewriter.template create<LLVM::SExtOp>( - loc, dstType, adaptor.getOperand2()); + extended = + LLVM::SExtOp::create(rewriter, loc, dstType, adaptor.getOperand2()); } } else if (op2TypeWidth == dstTypeWidth) { extended = adaptor.getOperand2(); @@ -1505,8 +1505,8 @@ public: return failure(); } - Value result = rewriter.template create<LLVMOp>( - loc, dstType, adaptor.getOperand1(), extended); + Value result = + LLVMOp::create(rewriter, loc, dstType, adaptor.getOperand1(), extended); rewriter.replaceOp(op, result); return success(); } diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp index 7025c5a..0ff9fb3 100644 --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -14,7 +14,6 @@ #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/IRMapping.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/STLExtras.h" diff --git a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt b/mlir/lib/Conversion/ShardToMPI/CMakeLists.txt index 15560aa..564f36f 100644 --- a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt +++ b/mlir/lib/Conversion/ShardToMPI/CMakeLists.txt @@ -1,8 +1,8 @@ -add_mlir_conversion_library(MLIRMeshToMPI - MeshToMPI.cpp +add_mlir_conversion_library(MLIRShardToMPI + ShardToMPI.cpp ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MeshToMPI + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ShardToMPI DEPENDS MLIRConversionPassIncGen @@ -17,7 +17,7 @@ add_mlir_conversion_library(MLIRMeshToMPI MLIRLinalgTransforms MLIRMemRefDialect MLIRPass - MLIRMeshDialect + MLIRShardDialect MLIRMPIDialect MLIRTransforms ) diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp index 63b1fda..fa9e544 100644 --- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp +++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp @@ -1,4 +1,4 @@ -//===- MeshToMPI.cpp - Mesh to MPI dialect conversion -----------------===// +//===- ShardToMPI.cpp - Shard to MPI dialect conversion -----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
// See https://llvm.org/LICENSE.txt for license information. @@ -6,11 +6,11 @@ // //===----------------------------------------------------------------------===// // -// This file implements a translation of Mesh communication ops tp MPI ops. +// This file implements a translation of Shard communication ops to MPI ops. // //===----------------------------------------------------------------------===// -#include "mlir/Conversion/MeshToMPI/MeshToMPI.h" +#include "mlir/Conversion/ShardToMPI/ShardToMPI.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -20,11 +20,11 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MPI/IR/MPI.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Mesh/IR/MeshDialect.h" -#include "mlir/Dialect/Mesh/IR/MeshOps.h" -#include "mlir/Dialect/Mesh/Transforms/Simplifications.h" -#include "mlir/Dialect/Mesh/Transforms/Transforms.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Shard/IR/ShardDialect.h" +#include "mlir/Dialect/Shard/IR/ShardOps.h" +#include "mlir/Dialect/Shard/Transforms/Simplifications.h" +#include "mlir/Dialect/Shard/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Builders.h" @@ -35,16 +35,15 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#define DEBUG_TYPE "mesh-to-mpi" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define DEBUG_TYPE "shard-to-mpi" namespace mlir { -#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS +#define GEN_PASS_DEF_CONVERTSHARDTOMPIPASS #include "mlir/Conversion/Passes.h.inc" } // namespace mlir using namespace mlir; -using namespace mesh; +using namespace shard; namespace { /// Converts a vector of OpFoldResults (ints) into vector of Values of the @@ -177,9 +176,8 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> { auto type = RankedTensorType::get({nSplits, 2}, i64); Value resHaloSizes = haloSizes.empty() - ? rewriter - .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0}, - i64) + ? tensor::EmptyOp::create(rewriter, loc, + std::array<int64_t, 2>{0, 0}, i64) .getResult() : tensor::FromElementsOp::create(rewriter, loc, type, haloSizes) .getResult(); @@ -188,18 +186,18 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> { // maxSplitSize+1}. Store the offsets in the tensor but set trailing // elements for smaller split-groups to -1. 
Computing the max size of the // split groups needs using collectiveProcessGroupSize (which needs the - // MeshOp) + // GridOp) Value resOffsets; if (adaptor.getStaticShardedDimsOffsets().empty()) { resOffsets = tensor::EmptyOp::create(rewriter, loc, std::array<int64_t, 2>{0, 0}, i64); } else { SymbolTableCollection symbolTableCollection; - auto meshOp = getMesh(op, symbolTableCollection); + auto gridOp = getGrid(op, symbolTableCollection); int64_t maxSplitSize = 0; for (auto axes : splitAxes) { int64_t splitSize = - collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape()); + collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape()); assert(splitSize != ShapedType::kDynamic); maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize); } @@ -218,7 +216,7 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> { int64_t curr = 0; for (auto [i, axes] : llvm::enumerate(splitAxes)) { int64_t splitSize = - collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape()); + collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape()); assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize); ++splitSize; // add one for the total size ArrayRef<Value> values(&offsets[curr], splitSize); @@ -264,20 +262,20 @@ struct ConvertProcessMultiIndexOp SymbolTableCollection symbolTableCollection; Location loc = op.getLoc(); - auto meshOp = getMesh(op, symbolTableCollection); - // For now we only support static mesh shapes - if (ShapedType::isDynamicShape(meshOp.getShape())) + auto gridOp = getGrid(op, symbolTableCollection); + // For now we only support static grid shapes + if (ShapedType::isDynamicShape(gridOp.getShape())) return failure(); SmallVector<Value> dims; llvm::transform( - meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) { + gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) { return arith::ConstantIndexOp::create(rewriter, loc, i).getResult(); }); - Value rank = ProcessLinearIndexOp::create(rewriter, op.getLoc(), meshOp); + Value rank = ProcessLinearIndexOp::create(rewriter, op.getLoc(), gridOp); auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims); - // optionally extract subset of mesh axes + // optionally extract subset of grid axes auto axes = adaptor.getAxes(); if (!axes.empty()) { SmallVector<Value> subIndex; @@ -306,13 +304,11 @@ public: auto ctx = op.getContext(); Value commWorld = mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(ctx)); - auto rank = - rewriter - .create<mpi::CommRankOp>( - loc, - TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()}, - commWorld) - .getRank(); + auto rank = mpi::CommRankOp::create( + rewriter, loc, + TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()}, + commWorld) + .getRank(); rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(), rank); return success(); @@ -338,12 +334,12 @@ struct ConvertNeighborsLinearIndicesOp Location loc = op.getLoc(); SymbolTableCollection symbolTableCollection; - auto meshOp = getMesh(op, symbolTableCollection); + auto gridOp = getGrid(op, symbolTableCollection); auto mIdx = adaptor.getDevice(); auto orgIdx = mIdx[axes[0]]; SmallVector<Value> dims; llvm::transform( - meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) { + gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) { return arith::ConstantIndexOp::create(rewriter, loc, i).getResult(); }); Value dimSz = dims[axes[0]]; @@ -394,14 +390,14 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> { auto sharding = 
op.getSharding().getDefiningOp<ShardingOp>(); if (!sharding) { return op->emitError() - << "Expected SharingOp as defining op for sharding" + << "Expected ShardingOp as defining op for sharding" << " but found " << adaptor.getSharding()[0].getDefiningOp(); } // Compute the sharded shape by applying the sharding to the input shape. // If shardedDimsOffsets is not defined in the sharding, the shard shape is // computed by dividing the dimension size by the number of shards in that - // dimension (which is given by the size of the mesh axes provided in + // dimension (which is given by the size of the grid axes provided in // split-axes). Odd elements get distributed to trailing shards. If a // shardedDimsOffsets is provided, the shard shape is computed by // subtracting the offset of the current shard from the offset of the next @@ -431,11 +427,11 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> { SmallVector<Value> multiIdx = getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index); - // Get the MeshOp, the mesh shape is needed to compute the sharded shape. + // Get the GridOp, the grid shape is needed to compute the sharded shape. SymbolTableCollection symbolTableCollection; - auto meshOp = getMesh(sharding, symbolTableCollection); - // For now we only support static mesh shapes - if (ShapedType::isDynamicShape(meshOp.getShape())) + auto gridOp = getGrid(sharding, symbolTableCollection); + // For now we only support static grid shapes + if (ShapedType::isDynamicShape(gridOp.getShape())) return failure(); auto splitAxes = sharding.getSplitAxes().getAxes(); @@ -455,7 +451,7 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> { tmp); } - // With static mesh shape the sizes of the split axes are known. + // With static grid shape the sizes of the split axes are known. // Hence the start/pos for each split axes in shardDimsOffsets can be // computed statically. int64_t pos = 0; @@ -475,10 +471,10 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> { // Create a value from the static position in shardDimsOffsets. Value posVal = arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(pos)); - // Get the index of the local shard in the mesh axis. + // Get the index of the local shard in the grid axis. Value idx = multiIdx[axes[0]]; auto numShards = - collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape()); + collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape()); if (shardedDimsOffs) { // If sharded dims offsets are provided, use them to compute the // sharded shape.
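Two mechanical changes repeat throughout this patch: the Mesh dialect and its helpers are renamed to Shard/Grid (getMesh and MeshOp become getGrid and GridOp), and op construction moves from the builder-centric rewriter.create<OpTy>(loc, ...) to the op-centric static OpTy::create(rewriter, loc, ...). A minimal sketch of the latter, using a hypothetical helper that is not part of the patch:

    #include "mlir/Dialect/Arith/IR/Arith.h"
    #include "mlir/IR/Builders.h"

    // Illustrative only: both forms build the same arith.constant op; the
    // patch mechanically rewrites the first form into the second.
    static mlir::Value buildZeroIndex(mlir::OpBuilder &b, mlir::Location loc) {
      // Old builder-centric form, being phased out:
      //   return b.create<mlir::arith::ConstantIndexOp>(loc, 0).getResult();
      // New op-centric form, as used in this patch:
      return mlir::arith::ConstantIndexOp::create(b, loc, 0).getResult();
    }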
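The ConvertAllReduceOp hunk just below documents the MPI_Comm_split logic: the color is shared by processes that agree on every non-reduced grid coordinate, and the key orders the members of each group along the reduced axes. A rough runtime analogue in plain C++/MPI (hypothetical, assuming a standard MPI installation; not part of the patch):

    #include <mpi.h>

    // Processes with equal color (same position along all non-reduced axes)
    // end up in one sub-communicator; key fixes their rank order within it,
    // so a subsequent all-reduce spans only the reduced grid axes.
    MPI_Comm splitForAllReduce(MPI_Comm world, int color, int key) {
      MPI_Comm sub;
      MPI_Comm_split(world, color, key, &sub);
      return sub;
    }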
@@ -556,13 +552,13 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> { matchAndRewrite(AllReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SymbolTableCollection symbolTableCollection; - auto mesh = adaptor.getMesh(); - mlir::mesh::MeshOp meshOp = getMesh(op, symbolTableCollection); - if (!meshOp) - return op->emitError() << "No mesh found for AllReduceOp"; - if (ShapedType::isDynamicShape(meshOp.getShape())) + auto grid = adaptor.getGrid(); + mlir::shard::GridOp gridOp = getGrid(op, symbolTableCollection); + if (!gridOp) + return op->emitError() << "No grid found for AllReduceOp"; + if (ShapedType::isDynamicShape(gridOp.getShape())) return op->emitError() - << "Dynamic mesh shape not supported in AllReduceOp"; + << "Dynamic grid shape not supported in AllReduceOp"; ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter); Value input = adaptor.getInput(); @@ -592,27 +588,27 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> { linalg::CopyOp::create(iBuilder, input, buffer); // Get an MPI_Comm_split for the AllReduce operation. - // The color is the linear index of the process in the mesh along the - // non-reduced axes. The key is the linear index of the process in the mesh + // The color is the linear index of the process in the grid along the + // non-reduced axes. The key is the linear index of the process in the grid // along the reduced axes. - SmallVector<Type> indexResultTypes(meshOp.getShape().size(), + SmallVector<Type> indexResultTypes(gridOp.getShape().size(), iBuilder.getIndexType()); SmallVector<Value> myMultiIndex = - ProcessMultiIndexOp::create(iBuilder, indexResultTypes, mesh) + ProcessMultiIndexOp::create(iBuilder, indexResultTypes, grid) .getResult(); Value zero = arith::ConstantIndexOp::create(iBuilder, 0); SmallVector<Value> multiKey(myMultiIndex.size(), zero); - auto redAxes = adaptor.getMeshAxes(); + auto redAxes = adaptor.getGridAxes(); for (auto axis : redAxes) { multiKey[axis] = myMultiIndex[axis]; myMultiIndex[axis] = zero; } Value color = - createProcessLinearIndex(mesh, myMultiIndex, redAxes, iBuilder); + createProcessLinearIndex(grid, myMultiIndex, redAxes, iBuilder); color = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), color); - Value key = createProcessLinearIndex(mesh, multiKey, redAxes, iBuilder); + Value key = createProcessLinearIndex(grid, multiKey, redAxes, iBuilder); key = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), key); // Finally split the communicator @@ -698,15 +694,14 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { } auto rank = cast<ShapedType>(array.getType()).getRank(); auto opSplitAxes = adaptor.getSplitAxes().getAxes(); - auto mesh = adaptor.getMesh(); - auto meshOp = getMesh(op, symbolTableCollection); + auto grid = adaptor.getGrid(); + auto gridOp = getGrid(op, symbolTableCollection); // subviews need Index values for (auto &sz : haloSizes) { if (auto value = dyn_cast<Value>(sz)) - sz = - rewriter - .create<arith::IndexCastOp>(loc, rewriter.getIndexType(), value) - .getResult(); + sz = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(), + value) + .getResult(); } // most of the offset/size/stride data is the same for all dims @@ -745,10 +740,10 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0 auto zero = arith::ConstantOp::create(rewriter, loc, zeroAttr); - SmallVector<Type> 
indexResultTypes(meshOp.getShape().size(), + SmallVector<Type> indexResultTypes(gridOp.getShape().size(), rewriter.getIndexType()); auto myMultiIndex = - ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, mesh) + ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, grid) .getResult(); // traverse all split axes from high to low dim for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) { @@ -758,9 +753,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2); // Get the linearized ids of the neighbors (down and up) for the // given split - auto tmp = rewriter - .create<NeighborsLinearIndicesOp>(loc, mesh, myMultiIndex, - splitAxes) + auto tmp = NeighborsLinearIndicesOp::create(rewriter, loc, grid, + myMultiIndex, splitAxes) .getResults(); // MPI operates on i32... Value neighbourIDs[2] = { @@ -791,7 +785,7 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1] : haloSizes[currHaloDim * 2]; // Check if we need to send and/or receive - // Processes on the mesh borders have only one neighbor + // Processes on the grid borders have only one neighbor auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1]; auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0]; auto hasFrom = arith::CmpIOp::create( @@ -869,8 +863,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { } }; -struct ConvertMeshToMPIPass - : public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> { +struct ConvertShardToMPIPass + : public impl::ConvertShardToMPIPassBase<ConvertShardToMPIPass> { using Base::Base; /// Run the dialect converter on the module. @@ -879,12 +873,12 @@ struct ConvertShardToMPIPass RewritePatternSet patterns(ctxt); ConversionTarget target(getContext()); - // Define a type converter to convert mesh::ShardingType, + // Define a type converter to convert shard::ShardingType, // mostly for use in return operations. TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); - // convert mesh::ShardingType to a tuple of RankedTensorTypes + // convert shard::ShardingType to a tuple of RankedTensorTypes typeConverter.addConversion( [](ShardingType type, SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> { @@ -920,10 +914,10 @@ struct ConvertShardToMPIPass return results; }); - // No mesh dialect should left after conversion... - target.addIllegalDialect<mesh::MeshDialect>(); - // ...except the global MeshOp. MeshShapeOp which will get folded later. - target.addLegalOp<mesh::MeshOp, mesh::MeshShapeOp>(); + // No shard dialect ops should be left after conversion... + target.addIllegalDialect<shard::ShardDialect>(); + // ...except the global GridOp and GridShapeOp, which will get folded later. + target.addLegalOp<shard::GridOp, shard::GridShapeOp>(); // Allow all the stuff that our patterns will convert to target.addLegalDialect< BuiltinDialect, mpi::MPIDialect, scf::SCFDialect, arith::ArithDialect, @@ -951,7 +945,7 @@ struct ConvertShardToMPIPass // Folding patterns cannot be mixed with conversion patterns -> extra pass.
patterns.clear(); SymbolTableCollection symbolTableCollection; - mlir::mesh::populateFoldingPatterns(patterns, symbolTableCollection); + mlir::shard::populateFoldingPatterns(patterns, symbolTableCollection); (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp index f07386e..8cd650e 100644 --- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp +++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp @@ -41,6 +41,7 @@ class ConvertTensorToSPIRVPass SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); RewritePatternSet patterns(context); diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index ec55091..0e3de06 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -22,7 +22,6 @@ #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" @@ -570,10 +569,9 @@ static Value createLinalgBodyCalculationForElementwiseOp( // to UIToFP. if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) { auto unrealizedCast = - rewriter - .create<UnrealizedConversionCastOp>( - loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), - args[0]) + UnrealizedConversionCastOp::create( + rewriter, loc, + rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), args[0]) .getResult(0); return arith::UIToFPOp::create(rewriter, loc, resultTypes[0], unrealizedCast); @@ -869,14 +867,13 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc, // Emit 'linalg.generic' op auto resultTensor = - opBuilder - .create<linalg::GenericOp>( - loc, outputTensor.getType(), operand, outputTensor, affineMaps, - getNParallelLoopsAttrs(rank), - [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) { - // Emit 'linalg.yield' op - linalg::YieldOp::create(opBuilder, loc, blockArgs.front()); - }) + linalg::GenericOp::create( + opBuilder, loc, outputTensor.getType(), operand, outputTensor, + affineMaps, getNParallelLoopsAttrs(rank), + [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) { + // Emit 'linalg.yield' op + linalg::YieldOp::create(opBuilder, loc, blockArgs.front()); + }) .getResult(0); // Cast to original operand type if necessary @@ -1156,11 +1153,9 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, inputs.push_back(input); // First fill the output buffer with the init value. 
- auto emptyTensor = - rewriter - .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(), - dynDims) - .getResult(); + auto emptyTensor = tensor::EmptyOp::create(rewriter, loc, reduceShape, + resultTy.getElementType(), dynDims) + .getResult(); auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter); if (!fillValueAttr) @@ -1168,10 +1163,10 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, op, "No initial value found for reduction operation"); auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr); - auto filledTensor = rewriter - .create<linalg::FillOp>(loc, ValueRange{fillValue}, - ValueRange{emptyTensor}) - .result(); + auto filledTensor = + linalg::FillOp::create(rewriter, loc, ValueRange{fillValue}, + ValueRange{emptyTensor}) + .result(); outputs.push_back(filledTensor); bool isNanIgnoreMode = false; @@ -1187,14 +1182,12 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, auto trueAttr = rewriter.getBoolAttr(true); auto trueValue = arith::ConstantOp::create(rewriter, loc, trueAttr); auto emptyBoolTensor = - rewriter - .create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(), - dynDims) + tensor::EmptyOp::create(rewriter, loc, reduceShape, + trueValue.getType(), dynDims) .getResult(); auto allResultsNaNTensor = - rewriter - .create<linalg::FillOp>(loc, ValueRange{trueValue}, - ValueRange{emptyBoolTensor}) + linalg::FillOp::create(rewriter, loc, ValueRange{trueValue}, + ValueRange{emptyBoolTensor}) .result(); // Note that because the linalg::ReduceOp has two variadic arguments // (inputs and outputs) and it has the SameVariadicOperandSize trait we @@ -1262,22 +1255,19 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false)); auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr); auto emptyNanTensor = - rewriter - .create<tensor::EmptyOp>(loc, reduceShape, - resultTy.getElementType(), dynDims) + tensor::EmptyOp::create(rewriter, loc, reduceShape, + resultTy.getElementType(), dynDims) .getResult(); auto nanFilledTensor = - rewriter - .create<linalg::FillOp>(loc, ValueRange{nanValue}, - ValueRange{emptyNanTensor}) + linalg::FillOp::create(rewriter, loc, ValueRange{nanValue}, + ValueRange{emptyNanTensor}) .result(); // Create an empty tensor, non need to fill this since it will be // overwritten by the select. auto finalEmptyTensor = - rewriter - .create<tensor::EmptyOp>(loc, reduceShape, - resultTy.getElementType(), dynDims) + tensor::EmptyOp::create(rewriter, loc, reduceShape, + resultTy.getElementType(), dynDims) .getResult(); // Do a selection between the tensors akin to: @@ -1504,12 +1494,11 @@ public: Value shift = shiftConstant ? 
shiftConstant : blockArgs[shiftArg]; if (valueTy.isUnsignedInteger()) { - value = nestedBuilder - .create<UnrealizedConversionCastOp>( - nestedLoc, - nestedBuilder.getIntegerType( - valueTy.getIntOrFloatBitWidth()), - value) + value = UnrealizedConversionCastOp::create( + nestedBuilder, nestedLoc, + nestedBuilder.getIntegerType( + valueTy.getIntOrFloatBitWidth()), + value) .getResult(0); } if (valueTy.getIntOrFloatBitWidth() < 32) { @@ -1558,9 +1547,8 @@ public: } if (outIntType.isUnsignedInteger()) { - value = nestedBuilder - .create<UnrealizedConversionCastOp>(nestedLoc, - outIntType, value) + value = UnrealizedConversionCastOp::create(nestedBuilder, nestedLoc, + outIntType, value) .getResult(0); } linalg::YieldOp::create(nestedBuilder, loc, value); @@ -2096,10 +2084,9 @@ public: Value axisDimSize = tensor::DimOp::create(rewriter, loc, input, axis); // First fill the output buffer with the init value. - auto emptyTensor = rewriter - .create<tensor::EmptyOp>(loc, inputTy.getShape(), - inputTy.getElementType(), - ArrayRef<Value>({dynDims})) + auto emptyTensor = tensor::EmptyOp::create( + rewriter, loc, inputTy.getShape(), + inputTy.getElementType(), ArrayRef<Value>({dynDims})) .getResult(); SmallVector<AffineMap, 2> affineMaps = { rewriter.getMultiDimIdentityMap(resultTy.getRank())}; @@ -2242,23 +2229,22 @@ public: } // First fill the output buffer for the index. - auto emptyTensorIdx = rewriter - .create<tensor::EmptyOp>(loc, resultTy.getShape(), - outElementTy, dynDims) - .getResult(); + auto emptyTensorIdx = + tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), + outElementTy, dynDims) + .getResult(); auto fillValueIdx = arith::ConstantOp::create( rewriter, loc, rewriter.getIntegerAttr(outElementTy, 0)); auto filledTensorIdx = - rewriter - .create<linalg::FillOp>(loc, ValueRange{fillValueIdx}, - ValueRange{emptyTensorIdx}) + linalg::FillOp::create(rewriter, loc, ValueRange{fillValueIdx}, + ValueRange{emptyTensorIdx}) .result(); // Second fill the output buffer for the running max. 
- auto emptyTensorMax = rewriter - .create<tensor::EmptyOp>(loc, resultTy.getShape(), - inElementTy, dynDims) - .getResult(); + auto emptyTensorMax = + tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), inElementTy, + dynDims) + .getResult(); auto fillValueMaxAttr = createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter); @@ -2269,9 +2255,8 @@ public: auto fillValueMax = arith::ConstantOp::create(rewriter, loc, fillValueMaxAttr); auto filledTensorMax = - rewriter - .create<linalg::FillOp>(loc, ValueRange{fillValueMax}, - ValueRange{emptyTensorMax}) + linalg::FillOp::create(rewriter, loc, ValueRange{fillValueMax}, + ValueRange{emptyTensorMax}) .result(); // We need to reduce along the arg-max axis, with parallel operations along @@ -2372,9 +2357,8 @@ public: auto loc = op.getLoc(); auto emptyTensor = - rewriter - .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy, - dynamicDims) + tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), + resultElementTy, dynamicDims) .getResult(); SmallVector<AffineMap, 2> affineMaps = { @@ -2449,10 +2433,10 @@ public: } } - auto emptyTensor = rewriter - .create<tensor::EmptyOp>(loc, resultTy.getShape(), - resultElementTy, dynDims) - .getResult(); + auto emptyTensor = + tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), + resultElementTy, dynDims) + .getResult(); SmallVector<AffineMap, 2> affineMaps = { rewriter.getMultiDimIdentityMap(resultTy.getRank()), @@ -2586,10 +2570,10 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> { tensor::EmptyOp::create(rewriter, loc, type, dynamicSizes); auto fillValueAttr = rewriter.getZeroAttr(type.getElementType()); auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr); - auto filledTensor = rewriter - .create<linalg::FillOp>(loc, ValueRange{fillValue}, - ValueRange{emptyTensor}) - .result(); + auto filledTensor = + linalg::FillOp::create(rewriter, loc, ValueRange{fillValue}, + ValueRange{emptyTensor}) + .result(); return filledTensor; } diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp index 3a20524..da1fb20 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -64,19 +64,20 @@ linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias, Value conv, Value result, ArrayRef<AffineMap> indexingMaps) { ShapedType resultTy = cast<ShapedType>(conv.getType()); - return rewriter - .create<linalg::GenericOp>( - loc, resultTy, ValueRange({bias, conv}), result, indexingMaps, - getNParallelLoopsAttrs(resultTy.getRank()), - [](OpBuilder &builder, Location loc, ValueRange args) { - Value biasVal = args[0]; - Type resType = args[1].getType(); - if (resType != biasVal.getType()) { - biasVal = arith::ExtSIOp::create(builder, loc, resType, biasVal); - } - Value added = arith::AddIOp::create(builder, loc, biasVal, args[1]); - linalg::YieldOp::create(builder, loc, added); - }) + return linalg::GenericOp::create( + rewriter, loc, resultTy, ValueRange({bias, conv}), result, + indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()), + [](OpBuilder &builder, Location loc, ValueRange args) { + Value biasVal = args[0]; + Type resType = args[1].getType(); + if (resType != biasVal.getType()) { + biasVal = + arith::ExtSIOp::create(builder, loc, resType, biasVal); + } + Value added = + arith::AddIOp::create(builder, loc, biasVal, args[1]); + linalg::YieldOp::create(builder, loc, added); + }) 
.getResult(0); } @@ -124,23 +125,23 @@ static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter, indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank)); // Build the broadcast-like operation as a linalg.generic. - return rewriter - .create<linalg::GenericOp>( - loc, resultTy, ValueRange({source}), result, indexingMaps, - getNParallelLoopsAttrs(resultTy.getRank()), - [&resultTy](OpBuilder &builder, Location loc, ValueRange args) { - Value biasVal = args[0]; - Type resType = args[1].getType(); - if (resType != biasVal.getType()) { - biasVal = - resultTy.getElementType().isFloat() - ? arith::ExtFOp::create(builder, loc, resType, biasVal) - .getResult() - : arith::ExtSIOp::create(builder, loc, resType, biasVal) - .getResult(); - } - linalg::YieldOp::create(builder, loc, biasVal); - }) + return linalg::GenericOp::create( + rewriter, loc, resultTy, ValueRange({source}), result, + indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()), + [&resultTy](OpBuilder &builder, Location loc, ValueRange args) { + Value biasVal = args[0]; + Type resType = args[1].getType(); + if (resType != biasVal.getType()) { + biasVal = + resultTy.getElementType().isFloat() + ? arith::ExtFOp::create(builder, loc, resType, biasVal) + .getResult() + : arith::ExtSIOp::create(builder, loc, resType, + biasVal) + .getResult(); + } + linalg::YieldOp::create(builder, loc, biasVal); + }) .getResult(0); } @@ -397,21 +398,19 @@ public: auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp); auto kZpVal = arith::ConstantOp::create(rewriter, loc, kZp); - Value conv = - rewriter - .create<LinalgConvQOp>( - loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal}, - ValueRange{broadcastBias}, strideAttr, dilationAttr) - ->getResult(0); + Value conv = LinalgConvQOp::create( + rewriter, loc, resultTy, + ValueRange{input, weight, iZpVal, kZpVal}, + ValueRange{broadcastBias}, strideAttr, dilationAttr) + ->getResult(0); rewriter.replaceOp(op, conv); return success(); } - Value conv = rewriter - .create<LinalgConvOp>( - loc, accTy, ValueRange{input, weight}, - ValueRange{broadcastBias}, strideAttr, dilationAttr) + Value conv = LinalgConvOp::create( + rewriter, loc, accTy, ValueRange{input, weight}, + ValueRange{broadcastBias}, strideAttr, dilationAttr) ->getResult(0); // We may need to truncate back to the result type if the accumulator was @@ -529,9 +528,8 @@ public: Value emptyTensor = tensor::EmptyOp::create( rewriter, loc, linalgConvTy.getShape(), accETy, filteredDims); Value zero = arith::ConstantOp::create(rewriter, loc, resultZeroAttr); - Value zeroTensor = rewriter - .create<linalg::FillOp>(loc, ValueRange{zero}, - ValueRange{emptyTensor}) + Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero}, + ValueRange{emptyTensor}) .result(); Value biasEmptyTensor = tensor::EmptyOp::create( @@ -544,10 +542,9 @@ public: indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank)); if (hasNullZps) { - Value conv = rewriter - .create<linalg::DepthwiseConv2DNhwcHwcmOp>( - loc, linalgConvTy, ValueRange{input, weight}, - ValueRange{zeroTensor}, strideAttr, dilationAttr) + Value conv = linalg::DepthwiseConv2DNhwcHwcmOp::create( + rewriter, loc, linalgConvTy, ValueRange{input, weight}, + ValueRange{zeroTensor}, strideAttr, dilationAttr) .getResult(0); // We may need to truncate back to the result type if the accumulator was @@ -565,22 +562,20 @@ public: rewriter, loc, resultTy, conv, reassociationMap); Value result = - rewriter - .create<linalg::GenericOp>( - loc, resultTy, 
ValueRange({bias, convReshape}), - biasEmptyTensor, indexingMaps, - getNParallelLoopsAttrs(resultRank), - [&](OpBuilder &nestedBuilder, Location nestedLoc, - ValueRange args) { - Value added; - if (llvm::isa<FloatType>(inputETy)) - added = arith::AddFOp::create(nestedBuilder, loc, args[0], - args[1]); - else - added = arith::AddIOp::create(nestedBuilder, loc, args[0], - args[1]); - linalg::YieldOp::create(nestedBuilder, nestedLoc, added); - }) + linalg::GenericOp::create( + rewriter, loc, resultTy, ValueRange({bias, convReshape}), + biasEmptyTensor, indexingMaps, getNParallelLoopsAttrs(resultRank), + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange args) { + Value added; + if (llvm::isa<FloatType>(inputETy)) + added = arith::AddFOp::create(nestedBuilder, loc, args[0], + args[1]); + else + added = arith::AddIOp::create(nestedBuilder, loc, args[0], + args[1]); + linalg::YieldOp::create(nestedBuilder, nestedLoc, added); + }) .getResult(0); rewriter.replaceOp(op, result); } else { @@ -588,12 +583,11 @@ public: IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal); auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp); auto kZpVal = arith::ConstantOp::create(rewriter, loc, wZp); - Value conv = - rewriter - .create<linalg::DepthwiseConv2DNhwcHwcmQOp>( - loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal}, - ValueRange{zeroTensor}, strideAttr, dilationAttr) - .getResult(0); + Value conv = linalg::DepthwiseConv2DNhwcHwcmQOp::create( + rewriter, loc, linalgConvTy, + ValueRange{input, weight, iZpVal, kZpVal}, + ValueRange{zeroTensor}, strideAttr, dilationAttr) + .getResult(0); SmallVector<ReassociationExprs, 4> reassociationMap; createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter); Value convReshape = tensor::CollapseShapeOp::create( @@ -639,9 +633,8 @@ public: auto emptyTensor = tensor::EmptyOp::create(rewriter, loc, outputTy.getShape(), outputTy.getElementType(), filteredDims); - Value zeroTensor = rewriter - .create<linalg::FillOp>(loc, ValueRange{zero}, - ValueRange{emptyTensor}) + Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero}, + ValueRange{emptyTensor}) .result(); FailureOr<int64_t> maybeAZp = op.getAZeroPoint(); @@ -910,20 +903,18 @@ public: rewriter, loc, accTy.getShape(), accETy, dynamicDims); Value filledEmptyTensor = - rewriter - .create<linalg::FillOp>(loc, ValueRange{initialValue}, - ValueRange{poolEmptyTensor}) + linalg::FillOp::create(rewriter, loc, ValueRange{initialValue}, + ValueRange{poolEmptyTensor}) .result(); Value fakeWindowDims = tensor::EmptyOp::create(rewriter, loc, kernel, accETy); // Sum across the pooled region. 
- Value poolingOp = rewriter - .create<linalg::PoolingNhwcSumOp>( - loc, ArrayRef<Type>{accTy}, - ValueRange{paddedInput, fakeWindowDims}, - filledEmptyTensor, strideAttr, dilationAttr) + Value poolingOp = linalg::PoolingNhwcSumOp::create( + rewriter, loc, ArrayRef<Type>{accTy}, + ValueRange{paddedInput, fakeWindowDims}, + filledEmptyTensor, strideAttr, dilationAttr) .getResult(0); // Normalize the summed value by the number of elements grouped in each @@ -1050,10 +1041,9 @@ public: Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8); auto scaled = - rewriter - .create<tosa::ApplyScaleOp>( - loc, rewriter.getI32Type(), poolVal, multiplier, shift, - rewriter.getStringAttr("SINGLE_ROUND")) + tosa::ApplyScaleOp::create( + rewriter, loc, rewriter.getI32Type(), poolVal, multiplier, + shift, rewriter.getStringAttr("SINGLE_ROUND")) .getResult(); // If we have quantization information we need to apply output diff --git a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp index b83f5ec9..f8efb34 100644 --- a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp +++ b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp @@ -13,7 +13,6 @@ #include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h" #include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" using namespace mlir; diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp index 125ea1e..9efa34a 100644 --- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp +++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp @@ -255,66 +255,6 @@ struct BroadcastOpToArmSMELowering } }; -/// Conversion pattern for vector.splat. -/// -/// Example: -/// -/// %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32> -/// -/// is converted to: -/// -/// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32> -/// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices -/// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) -/// { -/// %tile_update = arm_sme.insert_tile_slice -/// %broadcast_to_1d, %iter_tile[%tile_slice_index] : -/// vector<[4]xi32> into vector<[4]x[4]xi32> -/// scf.yield %tile_update : vector<[4]x[4]xi32> -/// } -/// -/// This is identical to vector.broadcast of a scalar. -struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> { - using OpRewritePattern<vector::SplatOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::SplatOp splatOp, - PatternRewriter &rewriter) const final { - auto tileType = splatOp.getResult().getType(); - if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) - return failure(); - - auto loc = splatOp.getLoc(); - auto srcType = splatOp.getOperand().getType(); - - assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat"); - // Avoid unused-variable warning when building without assertions. - (void)srcType; - - // First, broadcast the scalar to a 1-d vector. 
- VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); - Value broadcastOp1D = vector::BroadcastOp::create( - rewriter, loc, tileSliceType, splatOp.getInput()); - - auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType); - - auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex, - Value currentTile) { - auto nextTile = arm_sme::InsertTileSliceOp::create( - b, loc, tileType, broadcastOp1D, currentTile, tileSliceIndex); - return nextTile.getResult(); - }; - - // Next, create a loop over ZA tile slices and "move" the generated 1-d - // vector to each slice. - auto forOp = - createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody); - - rewriter.replaceOp(splatOp, forOp.getResult(0)); - - return success(); - } -}; - /// Conversion pattern for vector.transpose. /// /// Stores the input tile to memory and reloads vertically. @@ -791,11 +731,25 @@ struct ExtractFromCreateMaskToPselLowering } }; +// Convert all `vector.splat` to `vector.broadcast`. There is a path from +// `vector.broadcast` to ArmSME via another pattern. +struct ConvertSplatToBroadcast : public OpRewritePattern<vector::SplatOp> { + using OpRewritePattern<vector::SplatOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::SplatOp splatOp, + PatternRewriter &rewriter) const final { + + rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(), + splatOp.getInput()); + return success(); + } +}; + } // namespace void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns, MLIRContext &ctx) { - patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering, + patterns.add<BroadcastOpToArmSMELowering, ConvertSplatToBroadcast, TransferReadToArmSMELowering, TransferWriteToArmSMELowering, TransposeOpToArmSMELowering, VectorLoadToArmSMELowering, VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering, diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp index 77aab85..1d1904f 100644 --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -31,10 +31,9 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/DebugLog.h" #define DEBUG_TYPE "vector-to-gpu" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define DBGSNL() (llvm::dbgs() << "\n") namespace mlir { #define GEN_PASS_DEF_CONVERTVECTORTOGPU @@ -366,7 +365,7 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op, // by all operations. 
if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) { if (!supportsMMaMatrixType(op, useNvGpu)) { - LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n"); + LDBG() << "cannot convert op: " << *op; return true; } return false; @@ -482,14 +481,12 @@ struct CombineTransferReadOpTranspose final permutationMap.compose(transferReadOp.getPermutationMap()); auto loc = op.getLoc(); - Value result = - rewriter - .create<vector::TransferReadOp>( - loc, resultType, transferReadOp.getBase(), - transferReadOp.getIndices(), AffineMapAttr::get(newMap), - transferReadOp.getPadding(), transferReadOp.getMask(), - transferReadOp.getInBoundsAttr()) - .getResult(); + Value result = vector::TransferReadOp::create( + rewriter, loc, resultType, transferReadOp.getBase(), + transferReadOp.getIndices(), AffineMapAttr::get(newMap), + transferReadOp.getPadding(), transferReadOp.getMask(), + transferReadOp.getInBoundsAttr()) + .getResult(); // Fuse through the integer extend op. if (extOp) { @@ -550,7 +547,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, std::optional<int64_t> stride = getStaticallyKnownRowStride(op.getShapedType()); if (!stride.has_value()) { - LLVM_DEBUG(DBGS() << "no stride\n"); + LDBG() << "no stride"; return rewriter.notifyMatchFailure(op, "no stride"); } @@ -585,7 +582,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, isTranspose ? rewriter.getUnitAttr() : UnitAttr()); valueMapping[mappingResult] = load; - LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n"); + LDBG() << "transfer read to: " << load; return success(); } @@ -599,13 +596,13 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op, std::optional<int64_t> stride = getStaticallyKnownRowStride(op.getShapedType()); if (!stride.has_value()) { - LLVM_DEBUG(DBGS() << "no stride\n"); + LDBG() << "no stride"; return rewriter.notifyMatchFailure(op, "no stride"); } auto it = valueMapping.find(op.getVector()); if (it == valueMapping.end()) { - LLVM_DEBUG(DBGS() << "no mapping\n"); + LDBG() << "no mapping"; return rewriter.notifyMatchFailure(op, "no mapping"); } @@ -615,9 +612,9 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op, rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr()); (void)store; - LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n"); + LDBG() << "transfer write to: " << store; - LLVM_DEBUG(DBGS() << "erase: " << op << "\n"); + LDBG() << "erase: " << op; rewriter.eraseOp(op); return success(); } @@ -643,21 +640,21 @@ convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op, FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = nvgpu::getWarpMatrixInfo(op); if (failed(warpMatrixInfo)) { - LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n"); + LDBG() << "no warpMatrixInfo"; return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); } FailureOr<nvgpu::FragmentElementInfo> regInfo = nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); if (failed(regInfo)) { - LLVM_DEBUG(DBGS() << "not mma sync reg info\n"); + LDBG() << "not mma sync reg info"; return rewriter.notifyMatchFailure(op, "not mma sync reg info"); } VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); auto dense = dyn_cast<SplatElementsAttr>(op.getValue()); if (!dense) { - LLVM_DEBUG(DBGS() << "not a splat\n"); + LDBG() << "not a splat"; return rewriter.notifyMatchFailure(op, "not a splat"); } @@ -679,8 +676,8 @@ static FailureOr<bool> isTransposed(vector::TransferReadOp op) { mlir::AffineMap map = op.getPermutationMap(); if 
(map.getNumResults() != 2) { - LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` " - "is not a 2d operand\n"); + LDBG() << "Failed because the result of `vector.transfer_read` " + "is not a 2d operand"; return failure(); } @@ -693,8 +690,8 @@ static FailureOr<bool> isTransposed(vector::TransferReadOp op) { auto exprN = dyn_cast<AffineDimExpr>(dN); if (!exprM || !exprN) { - LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim " - "expressions, then transpose cannot be determined.\n"); + LDBG() << "Failed because expressions are not affine dim " + "expressions, so the transpose cannot be determined."; return failure(); } @@ -711,20 +708,20 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = nvgpu::getWarpMatrixInfo(op); if (failed(warpMatrixInfo)) { - LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n"); + LDBG() << "no warpMatrixInfo"; return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); } FailureOr<nvgpu::FragmentElementInfo> regInfo = nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); if (failed(regInfo)) { - LLVM_DEBUG(DBGS() << "not mma sync reg info\n"); + LDBG() << "not mma sync reg info"; return rewriter.notifyMatchFailure(op, "not mma sync reg info"); } FailureOr<bool> transpose = isTransposed(op); if (failed(transpose)) { - LLVM_DEBUG(DBGS() << "failed to determine the transpose\n"); + LDBG() << "failed to determine the transpose"; return rewriter.notifyMatchFailure( op, "Op should likely not be converted to a nvgpu.ldmatrix call."); } @@ -733,10 +730,8 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose); if (failed(params)) { - LLVM_DEBUG( - DBGS() - << "failed to convert vector.transfer_read to ldmatrix. " - << "Op should likely not be converted to a nvgpu.ldmatrix call.\n"); + LDBG() << "failed to convert vector.transfer_read to ldmatrix. 
" + << "Op should likely not be converted to a nvgpu.ldmatrix call."; return rewriter.notifyMatchFailure( op, "failed to convert vector.transfer_read to ldmatrix; this op " "likely should not be converted to a nvgpu.ldmatrix call."); @@ -747,7 +742,7 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, FailureOr<AffineMap> offsets = nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params); if (failed(offsets)) { - LLVM_DEBUG(DBGS() << "no offsets\n"); + LDBG() << "no offsets"; return rewriter.notifyMatchFailure(op, "no offsets"); } @@ -936,7 +931,7 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op, vector::StoreOp::create(rewriter, loc, el, op.getBase(), newIndices); } - LLVM_DEBUG(DBGS() << "erase: " << op << "\n"); + LDBG() << "erase: " << op; rewriter.eraseOp(op); return success(); } @@ -1134,9 +1129,9 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, loop.getNumResults()))) rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); - LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n"); - LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n"); - LLVM_DEBUG(DBGS() << "erase: " << loop); + LDBG() << "newLoop now: " << newLoop; + LDBG() << "stripped scf.for: " << loop; + LDBG() << "erase: " << loop; rewriter.eraseOp(loop); return newLoop; @@ -1152,7 +1147,7 @@ static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op, for (const auto &operand : llvm::enumerate(op.getInitArgs())) { auto it = valueMapping.find(operand.value()); if (it == valueMapping.end()) { - LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n"); + LDBG() << "no value mapping for: " << operand.value(); continue; } argMapping.push_back(std::make_pair( @@ -1170,7 +1165,7 @@ static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op, loopBody.getArgument(mapping.second + newForOp.getNumInductionVars()); } - LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n"); + LDBG() << "scf.for to: " << newForOp; return success(); } @@ -1193,7 +1188,7 @@ convertYieldOp(RewriterBase &rewriter, scf::YieldOp op, } scf::YieldOp::create(rewriter, op.getLoc(), yieldOperands); - LLVM_DEBUG(DBGS() << "erase: " << op << "\n"); + LDBG() << "erase: " << op; rewriter.eraseOp(op); return success(); } @@ -1246,7 +1241,7 @@ LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter, auto globalRes = LogicalResult::success(); for (Operation *op : ops) { - LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n"); + LDBG() << "Process op: " << *op; // Apparently callers do not want to early exit on failure here. 
auto res = LogicalResult::success(); if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) { diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 9cd491c..17a79e3 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -29,7 +29,9 @@ #include "mlir/Target/LLVMIR/TypeToLLVM.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/APFloat.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/Support/Casting.h" + #include <optional> using namespace mlir; @@ -1068,39 +1070,6 @@ public: } }; -class VectorExtractElementOpConversion - : public ConvertOpToLLVMPattern<vector::ExtractElementOp> { -public: - using ConvertOpToLLVMPattern< - vector::ExtractElementOp>::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto vectorType = extractEltOp.getSourceVectorType(); - auto llvmType = typeConverter->convertType(vectorType.getElementType()); - - // Bail if result type cannot be lowered. - if (!llvmType) - return failure(); - - if (vectorType.getRank() == 0) { - Location loc = extractEltOp.getLoc(); - auto idxType = rewriter.getIndexType(); - auto zero = LLVM::ConstantOp::create(rewriter, loc, - typeConverter->convertType(idxType), - rewriter.getIntegerAttr(idxType, 0)); - rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( - extractEltOp, llvmType, adaptor.getVector(), zero); - return success(); - } - - rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( - extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition()); - return success(); - } -}; - class VectorExtractOpConversion : public ConvertOpToLLVMPattern<vector::ExtractOp> { public: @@ -1204,39 +1173,6 @@ public: } }; -class VectorInsertElementOpConversion - : public ConvertOpToLLVMPattern<vector::InsertElementOp> { -public: - using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto vectorType = insertEltOp.getDestVectorType(); - auto llvmType = typeConverter->convertType(vectorType); - - // Bail if result type cannot be lowered. 
- if (!llvmType) - return failure(); - - if (vectorType.getRank() == 0) { - Location loc = insertEltOp.getLoc(); - auto idxType = rewriter.getIndexType(); - auto zero = LLVM::ConstantOp::create(rewriter, loc, - typeConverter->convertType(idxType), - rewriter.getIntegerAttr(idxType, 0)); - rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( - insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero); - return success(); - } - - rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( - insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), - adaptor.getPosition()); - return success(); - } -}; - class VectorInsertOpConversion : public ConvertOpToLLVMPattern<vector::InsertOp> { public: @@ -2242,8 +2178,7 @@ void mlir::populateVectorToLLVMConversionPatterns( VectorGatherOpConversion, VectorScatterOpConversion>( converter, useVectorAlignment); patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion, - VectorExtractElementOpConversion, VectorExtractOpConversion, - VectorFMAOp1DConversion, VectorInsertElementOpConversion, + VectorExtractOpConversion, VectorFMAOp1DConversion, VectorInsertOpConversion, VectorPrintOpConversion, VectorTypeCastOpConversion, VectorScaleOpConversion, VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index d3d0a45..cf10869 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -96,13 +96,16 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorGatherLoweringPatterns(patterns); if (armI8MM) { if (armNeon) - arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns); + arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns); if (armSVE) - populateLowerContractionToSVEI8MMPatternPatterns(patterns); + populateLowerContractionToSVEI8MMPatterns(patterns); + } + if (armBF16) { + if (armNeon) + arm_neon::populateLowerContractionToNeonBFMMLAPatterns(patterns); + if (armSVE) + populateLowerContractionToSVEBFMMLAPatterns(patterns); } - if (armBF16) - populateLowerContractionToSVEBFMMLAPatterns(patterns); - (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index 4c1047a..508f4e2 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -24,7 +24,6 @@ #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -691,7 +690,7 @@ struct PrepareTransferWriteConversion /// %lastIndex = arith.subi %length, %c1 : index /// vector.print punctuation <open> /// scf.for %i = %c0 to %length step %c1 { -/// %el = vector.extractelement %v[%i : index] : vector<[4]xi32> +/// %el = vector.extract %v[%i] : i32 from vector<[4]xi32> /// vector.print %el : i32 punctuation <no_punctuation> /// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index /// scf.if %notLastIndex { @@ -1644,7 +1643,7 @@ struct Strategy1d<TransferWriteOp> { /// Is rewritten to approximately the following pseudo-IR: /// ``` /// for i = 0 to 9 { -/// %t = vector.extractelement %vec[i] : vector<9xf32> +/// %t = vector.extract %vec[i] : f32 from 
vector<9xf32> /// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32> /// } /// ``` diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 986eae3..a4be7d4 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -335,63 +335,6 @@ struct VectorInsertOpConvert final } }; -struct VectorExtractElementOpConvert final - : public OpConversionPattern<vector::ExtractElementOp> { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type resultType = getTypeConverter()->convertType(extractOp.getType()); - if (!resultType) - return failure(); - - if (isa<spirv::ScalarType>(adaptor.getVector().getType())) { - rewriter.replaceOp(extractOp, adaptor.getVector()); - return success(); - } - - APInt cstPos; - if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos))) - rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( - extractOp, resultType, adaptor.getVector(), - rewriter.getI32ArrayAttr({static_cast<int>(cstPos.getSExtValue())})); - else - rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>( - extractOp, resultType, adaptor.getVector(), adaptor.getPosition()); - return success(); - } -}; - -struct VectorInsertElementOpConvert final - : public OpConversionPattern<vector::InsertElementOp> { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type vectorType = getTypeConverter()->convertType(insertOp.getType()); - if (!vectorType) - return failure(); - - if (isa<spirv::ScalarType>(vectorType)) { - rewriter.replaceOp(insertOp, adaptor.getSource()); - return success(); - } - - APInt cstPos; - if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos))) - rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( - insertOp, adaptor.getSource(), adaptor.getDest(), - cstPos.getSExtValue()); - else - rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>( - insertOp, vectorType, insertOp.getDest(), adaptor.getSource(), - adaptor.getPosition()); - return success(); - } -}; - struct VectorInsertStridedSliceOpConvert final : public OpConversionPattern<vector::InsertStridedSliceOp> { using OpConversionPattern::OpConversionPattern; @@ -1107,12 +1050,11 @@ struct VectorToElementOpConvert final void mlir::populateVectorToSPIRVPatterns( const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.add< - VectorBitcastConvert, VectorBroadcastConvert, - VectorExtractElementOpConvert, VectorExtractOpConvert, + VectorBitcastConvert, VectorBroadcastConvert, VectorExtractOpConvert, VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>, VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert, - VectorToElementOpConvert, VectorInsertElementOpConvert, - VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>, + VectorToElementOpConvert, VectorInsertOpConvert, + VectorReductionPattern<GL_INT_MAX_MIN_OPS>, VectorReductionPattern<CL_INT_MAX_MIN_OPS>, VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>, VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast, diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp index 2411af0..4dfcb2b 100644 --- 
a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp +++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp @@ -10,7 +10,6 @@ #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" -#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/XeVMDialect.h"
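A recurring change in the VectorToGPU hunks above replaces the per-file DBGS() macro with LDBG() from llvm/Support/DebugLog.h. A small before/after sketch (assuming a recent LLVM where that header provides LDBG(); the function below is hypothetical and not from the patch). LDBG() filters on DEBUG_TYPE and appends the newline itself, which is why the trailing "\n" disappears from the migrated messages:

    #include "llvm/Support/DebugLog.h"

    #define DEBUG_TYPE "vector-to-gpu"

    static void reportNoStride() {
      // Old pattern, removed by this patch:
      //   LLVM_DEBUG(DBGS() << "no stride\n");
      // New pattern: DEBUG_TYPE-based filtering and the trailing newline
      // are handled by LDBG() itself.
      LDBG() << "no stride";
    }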