Diffstat (limited to 'mlir/lib/Conversion')
29 files changed, 2090 insertions, 465 deletions
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 3a307a0..7584b17 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -16,8 +16,10 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" @@ -42,6 +44,7 @@ constexpr Chipset kGfx908 = Chipset(9, 0, 8); constexpr Chipset kGfx90a = Chipset(9, 0, 0xa); constexpr Chipset kGfx942 = Chipset(9, 4, 2); constexpr Chipset kGfx950 = Chipset(9, 5, 0); +constexpr Chipset kGfx1250 = Chipset(12, 5, 0); /// Convert an unsigned number `val` to i32. static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, @@ -79,12 +82,6 @@ static Value createI64Constant(ConversionPatternRewriter &rewriter, return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value); } -static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc, - bool value) { - Type llvmI1 = rewriter.getI1Type(); - return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value); -} - /// Returns the linear index used to access an element in the memref. static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, Location loc, MemRefDescriptor &memRefDescriptor, @@ -509,10 +506,16 @@ struct MemoryCounterWaitOpLowering if (std::optional<int> exp = adaptor.getExp()) ROCDL::WaitExpcntOp::create(rewriter, loc, *exp); + if (std::optional<int> tensor = adaptor.getTensor()) + ROCDL::WaitTensorcntOp::create(rewriter, loc, *tensor); + rewriter.eraseOp(op); return success(); } + if (adaptor.getTensor()) + return op.emitOpError("unsupported chipset"); + auto getVal = [](Attribute attr) -> unsigned { if (attr) return cast<IntegerAttr>(attr).getInt(); @@ -684,12 +687,11 @@ static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter, /// intrinsics having been defined before the AMD backend supported bfloat. We /// similarly need to pack 8-bit float types into integers as if they were i8 /// (which they are for the backend's purposes). 
-static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, - Location loc, - const TypeConverter *typeConverter, - bool isUnsigned, Value llvmInput, - Value mlirInput, - SmallVector<Value, 4> &operands) { +static void wmmaPushInputOperand( + ConversionPatternRewriter &rewriter, Location loc, + const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, + Value mlirInput, SmallVectorImpl<Value> &operands, + SmallVectorImpl<NamedAttribute> &attrs, StringRef attrName) { Type inputType = llvmInput.getType(); auto vectorType = dyn_cast<VectorType>(inputType); if (!vectorType) { @@ -697,10 +699,6 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, return; } Type elemType = vectorType.getElementType(); - - if (elemType.isBF16()) - llvmInput = LLVM::BitcastOp::create( - rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput); if (elemType.getIntOrFloatBitWidth() > 8) { operands.push_back(llvmInput); return; @@ -719,8 +717,8 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, } else if (elemType.isSignedInteger()) { localIsUnsigned = false; } - Value sign = createI1Constant(rewriter, loc, !localIsUnsigned); - operands.push_back(sign); + attrs.push_back( + NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned))); } int64_t numBits = @@ -751,18 +749,17 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, - bool clamp, SmallVector<Value, 4> &operands) { + bool clamp, SmallVectorImpl<Value> &operands, + SmallVectorImpl<NamedAttribute> &attrs) { Type inputType = output.getType(); auto vectorType = dyn_cast<VectorType>(inputType); Type elemType = vectorType.getElementType(); - if (elemType.isBF16()) - output = LLVM::BitcastOp::create( - rewriter, loc, vectorType.clone(rewriter.getI16Type()), output); operands.push_back(output); if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) { - operands.push_back(createI1Constant(rewriter, loc, subwordOffset)); + attrs.push_back( + NamedAttribute("opsel", rewriter.getBoolAttr(subwordOffset))); } else if (elemType.isInteger(32)) { - operands.push_back(createI1Constant(rewriter, loc, clamp)); + attrs.push_back(NamedAttribute("clamp", rewriter.getBoolAttr(clamp))); } } @@ -1160,7 +1157,7 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma, k, isRDNA3); // Handle gfx1250. - if (chipset == Chipset{12, 5, 0}) + if (chipset == kGfx1250) return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType, elemDestType, k); @@ -1311,11 +1308,33 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> { if (chipset.majorVersion != 11 && chipset.majorVersion != 12) return op->emitOpError("WMMA only supported on gfx11 and gfx12"); - // The WMMA operations represent vectors of bf16s as vectors of i16s, so we - // need to bitcast bfloats to i16 and then bitcast them back. + bool isGFX1250 = chipset >= kGfx1250; + + // The WMMA operations represent vectors of bf16s as vectors of i16s + // (except on gfx1250), so we need to bitcast bfloats to i16 and then + // bitcast them back. 
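// Illustrative example (not from the original patch): on targets below
// gfx1250, a bf16 operand such as vector<16xbf16> is rebuilt as
// vector<16xi16> with an llvm.bitcast before being handed to the wmma
// intrinsic, and a bf16 accumulator result is bitcast back afterwards; on
// gfx1250 the intrinsics accept bf16 directly, so all of the cast*ToI16
// flags below stay false and no bitcasts are emitted.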
+ auto aType = cast<VectorType>(adaptor.getSourceA().getType()); + auto bType = cast<VectorType>(adaptor.getSourceB().getType()); + auto destCType = cast<VectorType>(adaptor.getDestC().getType()); + bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250; + bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250; + bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250; + bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250; VectorType rawOutType = outType; - if (outType.getElementType().isBF16()) + if (castOutToI16) rawOutType = outType.clone(rewriter.getI16Type()); + Value a = adaptor.getSourceA(); + if (castAToI16) + a = LLVM::BitcastOp::create(rewriter, loc, + aType.clone(rewriter.getI16Type()), a); + Value b = adaptor.getSourceB(); + if (castBToI16) + b = LLVM::BitcastOp::create(rewriter, loc, + bType.clone(rewriter.getI16Type()), b); + Value destC = adaptor.getDestC(); + if (castDestCToI16) + destC = LLVM::BitcastOp::create( + rewriter, loc, destCType.clone(rewriter.getI16Type()), destC); std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset); @@ -1325,18 +1344,20 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> { if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0) return op.emitOpError("subwordOffset not supported on gfx12+"); - OperationState loweredOp(loc, *maybeIntrinsic); - loweredOp.addTypes(rawOutType); - SmallVector<Value, 4> operands; - wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), - adaptor.getSourceA(), op.getSourceA(), operands); - wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), - adaptor.getSourceB(), op.getSourceB(), operands); - wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(), - op.getSubwordOffset(), op.getClamp(), operands); + SmallVector<NamedAttribute, 4> attrs; + wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), a, + op.getSourceA(), operands, attrs, "signA"); + wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), b, + op.getSourceB(), operands, attrs, "signB"); + wmmaPushOutputOperand(rewriter, loc, typeConverter, destC, + op.getSubwordOffset(), op.getClamp(), operands, + attrs); + OperationState loweredOp(loc, *maybeIntrinsic); + loweredOp.addTypes(rawOutType); loweredOp.addOperands(operands); + loweredOp.addAttributes(attrs); Operation *lowered = rewriter.create(loweredOp); Operation *maybeCastBack = lowered; @@ -1492,6 +1513,20 @@ struct ExtPackedFp8OpLowering final ConversionPatternRewriter &rewriter) const override; }; +struct ScaledExtPackedMatrixOpLowering final + : public ConvertOpToLLVMPattern<ScaledExtPackedMatrixOp> { + ScaledExtPackedMatrixOpLowering(const LLVMTypeConverter &converter, + Chipset chipset) + : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp>(converter), + chipset(chipset) {} + Chipset chipset; + + LogicalResult + matchAndRewrite(ScaledExtPackedMatrixOp op, + ScaledExtPackedMatrixOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + struct PackedTrunc2xFp8OpLowering final : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> { PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter, @@ -1600,6 +1635,173 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( return success(); } +int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, int32_t scaleWaveHalf, + int32_t firstScaleByte) { + // When lowering amdgpu.scaled_ext_packed_matrix to rocdl.cvt.scale.pk*.f*.f* + // operations, the attributes 
blockSize, sourceType, scaleWaveHalf, and + // firstScaleByte are merged into a single attribute scaleSel. This is how + // those values are merged together. (Note: scaleWaveHalf isn't a high-level + // attribute but is derived from firstScaleLane). + assert(llvm::is_contained({16, 32}, blockSize)); + assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth)); + + const bool isFp8 = bitWidth == 8; + const bool isBlock16 = blockSize == 16; + + if (!isFp8) { + int32_t bit0 = isBlock16; + assert(llvm::is_contained({0, 1, 2}, firstScaleByte)); + int32_t bit1 = (firstScaleByte == 2) << 1; + assert(llvm::is_contained({0, 1}, scaleWaveHalf)); + int32_t bit2 = scaleWaveHalf << 2; + return bit2 | bit1 | bit0; + } + + int32_t bit0 = isBlock16; + // firstScaleByte is guaranteed to fit in two bits. + assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte)); + int32_t bits2and1 = firstScaleByte << 1; + assert(llvm::is_contained({0, 1}, scaleWaveHalf)); + int32_t bit3 = scaleWaveHalf << 3; + int32_t bits = bit3 | bits2and1 | bit0; + // These are invalid cases. + assert(!llvm::is_contained( + {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits)); + return bits; +} + +static std::optional<StringRef> +scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) { + using fp4 = Float4E2M1FNType; + using fp8 = Float8E4M3FNType; + using bf8 = Float8E5M2Type; + using fp6 = Float6E2M3FNType; + using bf6 = Float6E3M2FNType; + if (isa<fp4>(srcElemType)) { + if (destElemType.isF16()) + return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName(); + if (destElemType.isBF16()) + return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName(); + if (destElemType.isF32()) + return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName(); + return std::nullopt; + } + if (isa<fp8>(srcElemType)) { + if (destElemType.isF16()) + return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName(); + if (destElemType.isBF16()) + return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName(); + if (destElemType.isF32()) + return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName(); + return std::nullopt; + } + if (isa<bf8>(srcElemType)) { + if (destElemType.isF16()) + return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName(); + if (destElemType.isBF16()) + return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName(); + if (destElemType.isF32()) + return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName(); + return std::nullopt; + } + if (isa<fp6>(srcElemType)) { + if (destElemType.isF16()) + return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName(); + if (destElemType.isBF16()) + return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName(); + if (destElemType.isF32()) + return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName(); + return std::nullopt; + } + if (isa<bf6>(srcElemType)) { + if (destElemType.isF16()) + return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName(); + if (destElemType.isBF16()) + return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName(); + if (destElemType.isF32()) + return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName(); + return std::nullopt; + } + llvm_unreachable("invalid combination of element types for packed conversion " + "instructions"); +} + +LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite( + ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + using fp4 = Float4E2M1FNType; + using fp8 = Float8E4M3FNType; + using bf8 = Float8E5M2Type; + using fp6 = Float6E2M3FNType; + using bf6 = Float6E3M2FNType; + Location loc = op.getLoc(); + if
(chipset != kGfx1250) { + return rewriter.notifyMatchFailure( + loc, + "Scaled fp packed conversion instructions are not available on target " + "architecture and their emulation is not implemented"); + } + // Convert user-facing firstScaleLane (0 or 16) to the half of the wave that + // is being selected. + int32_t scaleWaveHalf = op.getFirstScaleLane() / 16; + int32_t firstScaleByte = op.getFirstScaleByte(); + int32_t blockSize = op.getBlockSize(); + auto sourceType = cast<VectorType>(op.getSource().getType()); + auto srcElemType = cast<FloatType>(sourceType.getElementType()); + unsigned bitWidth = srcElemType.getWidth(); + + auto targetType = cast<VectorType>(op.getResult().getType()); + auto destElemType = cast<FloatType>(targetType.getElementType()); + + IntegerType i32 = rewriter.getI32Type(); + Value source = adaptor.getSource(); + Type llvmResultType = typeConverter->convertType(op.getResult().getType()); + Type packedType = nullptr; + if (isa<fp4>(srcElemType)) { + packedType = i32; + packedType = getTypeConverter()->convertType(packedType); + } else if (isa<fp8, bf8>(srcElemType)) { + packedType = VectorType::get(2, i32); + packedType = getTypeConverter()->convertType(packedType); + } else if (isa<fp6, bf6>(srcElemType)) { + packedType = VectorType::get(3, i32); + packedType = getTypeConverter()->convertType(packedType); + } else { + llvm_unreachable("invalid element type for packed scaled ext"); + } + + if (!packedType || !llvmResultType) { + return rewriter.notifyMatchFailure(op, "type conversion failed"); + } + + std::optional<StringRef> maybeIntrinsic = + scaledExtPacked816ToIntrinsic(srcElemType, destElemType); + if (!maybeIntrinsic.has_value()) + return op.emitOpError( + "no intrinsic matching packed scaled conversion on the given chipset"); + + int32_t scaleSel = + getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte); + Value castedScale = + LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale()); + Value castedSource = + LLVM::BitcastOp::create(rewriter, loc, packedType, source); + + OperationState loweredOp(loc, *maybeIntrinsic); + loweredOp.addTypes({llvmResultType}); + loweredOp.addOperands({castedSource, castedScale}); + + SmallVector<NamedAttribute, 1> attrs; + attrs.push_back( + NamedAttribute("scaleSel", rewriter.getI32IntegerAttr(scaleSel))); + + loweredOp.addAttributes(attrs); + Operation *lowered = rewriter.create(loweredOp); + rewriter.replaceOp(op, lowered); + + return success(); +} + LogicalResult ScaledExtPackedOpLowering::matchAndRewrite( ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -2073,6 +2275,441 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> { } }; +struct AMDGPUMakeDmaBaseLowering + : public ConvertOpToLLVMPattern<MakeDmaBaseOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + AMDGPUMakeDmaBaseLowering(const LLVMTypeConverter &converter, Chipset chipset) + : ConvertOpToLLVMPattern<MakeDmaBaseOp>(converter), chipset(chipset) {} + Chipset chipset; + + LogicalResult + matchAndRewrite(MakeDmaBaseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (chipset < kGfx1250) + return op->emitOpError("make_dma_base is only supported on gfx1250"); + + Location loc = op.getLoc(); + + ValueRange ldsIndices = adaptor.getLdsIndices(); + Value lds = adaptor.getLds(); + auto ldsMemRefType = cast<MemRefType>(op.getLds().getType()); + + Value ldsPtr = + getStridedElementPtr(rewriter, loc, ldsMemRefType, lds, 
ldsIndices); + + ValueRange globalIndices = adaptor.getGlobalIndices(); + Value global = adaptor.getGlobal(); + auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType()); + + Value globalPtr = getStridedElementPtr(rewriter, loc, globalMemRefType, + global, globalIndices); + + Type i32 = rewriter.getI32Type(); + Type i64 = rewriter.getI64Type(); + + Value castForLdsAddr = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr); + Value castForGlobalAddr = + LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr); + + Value lowHalf = + LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr); + + Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr, + createI64Constant(rewriter, loc, 32)); + + Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift); + + Value mask = createI32Constant(rewriter, loc, (1ull << 25) - 1); + Value validHighHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask); + + Value typeField = createI32Constant(rewriter, loc, 2 << 30); + Value highHalfPlusType = + LLVM::OrOp::create(rewriter, loc, validHighHalf, typeField); + + Value c0 = createI32Constant(rewriter, loc, 0); + Value c1 = createI32Constant(rewriter, loc, 1); + Value c2 = createI32Constant(rewriter, loc, 2); + Value c3 = createI32Constant(rewriter, loc, 3); + + Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32)); + assert(v4i32 && "expected type conversion to succeed"); + Value result = LLVM::PoisonOp::create(rewriter, loc, v4i32); + result = LLVM::InsertElementOp::create(rewriter, loc, result, c1, c0); + result = LLVM::InsertElementOp::create(rewriter, loc, result, + castForLdsAddr, c1); + result = LLVM::InsertElementOp::create(rewriter, loc, result, lowHalf, c2); + result = LLVM::InsertElementOp::create(rewriter, loc, result, + highHalfPlusType, c3); + + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct AMDGPUMakeDmaDescriptorLowering + : public ConvertOpToLLVMPattern<MakeDmaDescriptorOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + AMDGPUMakeDmaDescriptorLowering(const LLVMTypeConverter &converter, + Chipset chipset) + : ConvertOpToLLVMPattern<MakeDmaDescriptorOp>(converter), + chipset(chipset) {} + Chipset chipset; + + Value getDGroup0(OpAdaptor adaptor) const { return adaptor.getBase(); } + + Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc, + Value accumulator, Value value, int64_t shift) const { + shift = shift % 32; + Value shiftAmount; + if (shift != 0) { + shiftAmount = createI32Constant(rewriter, loc, shift % 32); + value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount); + } + + if (matchPattern(accumulator, mlir::m_Zero())) + return value; + + return LLVM::OrOp::create(rewriter, loc, accumulator, value); + } + + Value setWorkgroupMask(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr0) const { + Value mask = op.getWorkgroupMask(); + if (!mask) + return sgpr0; + + Type i32 = rewriter.getI32Type(); + Value extendedMask = LLVM::ZExtOp::create(rewriter, loc, i32, mask); + return setValueAtOffset(rewriter, loc, sgpr0, extendedMask, 0); + } + + Value setDataSize(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr0, ArrayRef<Value> consts) const { + // Compute data_size. 
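// Worked example (illustrative): a 16-bit element type is 2 bytes, so
// dataSize = log2(2) = 1; element widths of 8/16/32/64 bits map to the
// encodings 0/1/2/3, which are written into bits [17:16] of sgpr0.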
+ unsigned elementTypeWidthInBits = op.getElementTypeWidth(); + assert( + llvm::is_contained<unsigned>({8, 16, 32, 64}, elementTypeWidthInBits) && + "expected type width to be 8, 16, 32, or 64."); + int64_t dataSize = llvm::Log2_32(elementTypeWidthInBits / 8); + Value size = createI32Constant(rewriter, loc, dataSize); + return setValueAtOffset(rewriter, loc, sgpr0, size, 16); + } + + Value setAtomicBarrier(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr0, ArrayRef<Value> consts) const { + bool atomic_barrier_enable = adaptor.getAtomicBarrierAddress() != nullptr; + if (!atomic_barrier_enable) + return sgpr0; + + return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18); + } + + Value setIterateEnable(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr0, ArrayRef<Value> consts) const { + bool iterate_enable = adaptor.getGlobalIncrement() != nullptr; + if (!iterate_enable) + return sgpr0; + + // TODO: In future PR, add other required fields for iteration. + return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19); + } + + Value setPadEnable(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr0, ArrayRef<Value> consts) const { + bool pad_enable = op.getPadAmount() != nullptr; + if (!pad_enable) + return sgpr0; + + return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20); + } + + Value setEarlyTimeout(MakeDmaDescriptorOp op, OpAdaptor adaptorm, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr0, ArrayRef<Value> consts) const { + if (!op.getWorkgroupMask()) + return sgpr0; + + return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 21); + } + + Value setPadInterval(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr0, ArrayRef<Value> consts) const { + bool pad_enable = op.getPadAmount() != nullptr; + if (!pad_enable) + return sgpr0; + + IntegerType i32 = rewriter.getI32Type(); + Value padInterval = adaptor.getPadInterval(); + // pre-condition: padInterval can be a power of two between 2 and 256. + padInterval = LLVM::CountTrailingZerosOp::create(rewriter, loc, i32, + padInterval, false); + padInterval = LLVM::SubOp::create(rewriter, loc, padInterval, consts[1]); + // post-condition: padInterval can be a value between 0 and 7. + return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22); + } + + Value setPadAmount(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr0, ArrayRef<Value> consts) const { + bool pad_enable = op.getPadAmount() != nullptr; + if (!pad_enable) + return sgpr0; + + Value padAmount = adaptor.getPadAmount(); + // pre-condition: padAmount is a value between 1-128. + padAmount = LLVM::SubOp::create(rewriter, loc, padAmount, consts[1]); + // post-condition: padAmount is a value between 0-127. 
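// Worked example (illustrative): padAmount = 4 is stored as 3 in the 7-bit
// field starting at bit 25, and padInterval = 32 is stored above as
// cttz(32) - 1 = 4 in the 3-bit field starting at bit 22; both fields are
// only populated when padding is enabled.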
+ return setValueAtOffset(rewriter, loc, sgpr0, padAmount, 25); + } + + Value setAtomicBarrierAddress(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Location loc, Value sgpr1, + ArrayRef<Value> consts) const { + bool atomic_barrier_enable = adaptor.getAtomicBarrierAddress() != nullptr; + if (!atomic_barrier_enable) + return sgpr1; + + Value atomicBarrierAddress = adaptor.getAtomicBarrierAddress(); + auto barrierAddressTy = + cast<MemRefType>(op.getAtomicBarrierAddress().getType()); + ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices(); + atomicBarrierAddress = + getStridedElementPtr(rewriter, loc, barrierAddressTy, + atomicBarrierAddress, atomicBarrierIndices); + IntegerType i32 = rewriter.getI32Type(); + // pre-condition: atomicBarrierAddress is aligned to 8 bytes which implies + // that the 3 LSBs are zero. + atomicBarrierAddress = + LLVM::PtrToIntOp::create(rewriter, loc, i32, atomicBarrierAddress); + atomicBarrierAddress = + LLVM::LShrOp::create(rewriter, loc, atomicBarrierAddress, consts[3]); + Value mask = createI32Constant(rewriter, loc, 0xFFFF); + atomicBarrierAddress = + LLVM::AndOp::create(rewriter, loc, atomicBarrierAddress, mask); + return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32); + } + + std::pair<Value, Value> setTensorDim0(MakeDmaDescriptorOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Location loc, Value sgpr1, Value sgpr2, + ArrayRef<Value> consts) const { + SmallVector<OpFoldResult> mixedGlobalSizes = op.getMixedGlobalSizes(); + OpFoldResult tensorDim0OpFoldResult = mixedGlobalSizes.back(); + Value tensorDim0; + if (auto attr = dyn_cast<Attribute>(tensorDim0OpFoldResult)) + tensorDim0 = + createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt()); + else + tensorDim0 = cast<Value>(tensorDim0OpFoldResult); + + Value c16 = createI32Constant(rewriter, loc, 16); + Value tensorDim0High = LLVM::LShrOp::create(rewriter, loc, tensorDim0, c16); + sgpr1 = setValueAtOffset(rewriter, loc, sgpr1, tensorDim0, 48); + sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDim0High, 48 + 16); + return {sgpr1, sgpr2}; + } + + std::pair<Value, Value> setTensorDim1(MakeDmaDescriptorOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Location loc, Value sgpr2, Value sgpr3, + ArrayRef<Value> consts) const { + // TODO: Generalize to setTensorDimX. 
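// Layout note (illustrative): each tensor dimension is a 32-bit field that
// straddles an sgpr boundary. tensor_dim[1] starts at descriptor bit 80, so
// its low 16 bits land in sgpr2[31:16] (80 % 32 == 16) and its high 16 bits
// in sgpr3[15:0], mirroring what setTensorDim0 does at bit 48 with
// sgpr1/sgpr2.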
+ SmallVector<OpFoldResult> mixedGlobalSizes = op.getMixedGlobalSizes(); + OpFoldResult tensorDim1OpFoldResult = *(mixedGlobalSizes.rbegin() + 1); + Value tensorDim1; + if (auto attr = dyn_cast<Attribute>(tensorDim1OpFoldResult)) + tensorDim1 = + createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt()); + else + tensorDim1 = cast<Value>(tensorDim1OpFoldResult); + + Value c16 = createI32Constant(rewriter, loc, 16); + Value tensorDim1High = LLVM::LShrOp::create(rewriter, loc, tensorDim1, c16); + sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDim1, 80); + sgpr3 = setValueAtOffset(rewriter, loc, sgpr3, tensorDim1High, 80 + 16); + return {sgpr2, sgpr3}; + } + + Value setTileDimX(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr, ArrayRef<Value> consts, size_t dimX, + int64_t offset) const { + SmallVector<OpFoldResult> mixedSharedSizes = op.getMixedSharedSizes(); + + if (mixedSharedSizes.size() <= dimX) + return sgpr; + + OpFoldResult tileDimXOpFoldResult = *(mixedSharedSizes.rbegin() + dimX); + Value tileDimX; + if (auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult)) + tileDimX = + createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt()); + else + tileDimX = cast<Value>(tileDimXOpFoldResult); + + return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset); + } + + Value setTileDim0(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr3, ArrayRef<Value> consts) const { + return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112); + } + + Value setTileDim1(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr4, ArrayRef<Value> consts) const { + return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128); + } + + Value setTileDim2(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr4, ArrayRef<Value> consts) const { + return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144); + } + + std::pair<Value, Value> + setTensorDimXStride(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgprY, Value sgprZ, ArrayRef<Value> consts, + size_t dimX, int64_t offset) const { + SmallVector<OpFoldResult> mixedGlobalStrides = op.getMixedGlobalStrides(); + + if (mixedGlobalStrides.size() <= dimX) + return {sgprY, sgprZ}; + + OpFoldResult tensorDimXStrideOpFoldResult = + *(mixedGlobalStrides.rbegin() + dimX); + Value tensorDimXStride; + if (auto attr = dyn_cast<Attribute>(tensorDimXStrideOpFoldResult)) + tensorDimXStride = + createI64Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt()); + else + tensorDimXStride = cast<Value>(tensorDimXStrideOpFoldResult); + + constexpr int64_t first48bits = (1ll << 48) - 1; + Value mask = createI64Constant(rewriter, loc, first48bits); + tensorDimXStride = + LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride); + IntegerType i32 = rewriter.getI32Type(); + Value tensorDimXStrideLow = + LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride); + + int64_t shift = (offset % 32) == 0 ? 
32 : offset % 32; + Value shiftVal = createI64Constant(rewriter, loc, shift); + Value tensorDimXStrideHigh = + LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal); + tensorDimXStrideHigh = + LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh); + + sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset); + sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh, + offset + shift); + return {sgprY, sgprZ}; + } + + std::pair<Value, Value> + setTensorDim0Stride(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const { + return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts, + 0, 160); + } + + std::pair<Value, Value> + setTensorDim1Stride(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const { + return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts, + 1, 208); + } + + Value getDGroup1(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Location loc, + ArrayRef<Value> consts) const { + Value sgprs[8]; + for (int64_t i = 0; i < 8; i++) { + sgprs[i] = consts[0]; + } + + sgprs[0] = setWorkgroupMask(op, adaptor, rewriter, loc, sgprs[0]); + sgprs[0] = setDataSize(op, adaptor, rewriter, loc, sgprs[0], consts); + sgprs[0] = setAtomicBarrier(op, adaptor, rewriter, loc, sgprs[0], consts); + sgprs[0] = setIterateEnable(op, adaptor, rewriter, loc, sgprs[0], consts); + sgprs[0] = setPadEnable(op, adaptor, rewriter, loc, sgprs[0], consts); + sgprs[0] = setEarlyTimeout(op, adaptor, rewriter, loc, sgprs[0], consts); + sgprs[0] = setPadInterval(op, adaptor, rewriter, loc, sgprs[0], consts); + sgprs[0] = setPadAmount(op, adaptor, rewriter, loc, sgprs[0], consts); + + sgprs[1] = + setAtomicBarrierAddress(op, adaptor, rewriter, loc, sgprs[1], consts); + std::tie(sgprs[1], sgprs[2]) = + setTensorDim0(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts); + std::tie(sgprs[2], sgprs[3]) = + setTensorDim1(op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts); + + sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts); + sgprs[4] = setTileDim1(op, adaptor, rewriter, loc, sgprs[4], consts); + sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts); + std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride( + op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts); + std::tie(sgprs[6], sgprs[7]) = setTensorDim1Stride( + op, adaptor, rewriter, loc, sgprs[6], sgprs[7], consts); + + IntegerType i32 = rewriter.getI32Type(); + Type v8i32 = this->typeConverter->convertType(VectorType::get(8, i32)); + assert(v8i32 && "expected type conversion to succeed"); + Value dgroup1 = LLVM::PoisonOp::create(rewriter, loc, v8i32); + + for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts)) { + dgroup1 = + LLVM::InsertElementOp::create(rewriter, loc, dgroup1, sgpr, constant); + } + + return dgroup1; + } + + LogicalResult + matchAndRewrite(MakeDmaDescriptorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (chipset < kGfx1250) + return op->emitOpError( + "make_dma_descriptor is only supported on gfx1250"); + + if (op.getRank() > 2) + return op->emitOpError("unimplemented"); + + Location loc = op.getLoc(); + + IntegerType i32 = rewriter.getI32Type(); + [[maybe_unused]] Type v4i32 = + this->typeConverter->convertType(VectorType::get(4, i32)); + 
assert(v4i32 && "expected type conversion to succeed"); + + SmallVector<Value> consts; + for (int64_t i = 0; i < 8; i++) + consts.push_back(createI32Constant(rewriter, loc, i)); + + Value dgroup0 = this->getDGroup0(adaptor); + Value dgroup1 = this->getDGroup1(op, adaptor, rewriter, loc, consts); + + SmallVector<Value> results = {dgroup0, dgroup1}; + rewriter.replaceOpWithMultiple(op, {results}); + return success(); + } +}; + struct ConvertAMDGPUToROCDLPass : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> { using Base::Base; @@ -2087,6 +2724,11 @@ struct ConvertAMDGPUToROCDLPass RewritePatternSet patterns(ctx); LLVMTypeConverter converter(ctx); + converter.addConversion([&](TDMBaseType type) -> Type { + Type i32 = IntegerType::get(type.getContext(), 32); + return converter.convertType(VectorType::get(4, i32)); + }); + populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset); LLVMConversionTarget target(getContext()); target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>(); @@ -2122,25 +2764,27 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, Chipset chipset) { populateAMDGPUMemorySpaceAttributeConversions(converter); - patterns - .add<FatRawBufferCastLowering, - RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>, - RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>, - RawBufferOpLowering<RawBufferAtomicFaddOp, - ROCDL::RawPtrBufferAtomicFaddOp>, - RawBufferOpLowering<RawBufferAtomicFmaxOp, - ROCDL::RawPtrBufferAtomicFmaxOp>, - RawBufferOpLowering<RawBufferAtomicSmaxOp, - ROCDL::RawPtrBufferAtomicSmaxOp>, - RawBufferOpLowering<RawBufferAtomicUminOp, - ROCDL::RawPtrBufferAtomicUminOp>, - RawBufferOpLowering<RawBufferAtomicCmpswapOp, - ROCDL::RawPtrBufferAtomicCmpSwap>, - AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering, - SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering, - WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering, - PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering, - PackedStochRoundFp8OpLowering, GatherToLDSOpLowering, - TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset); + patterns.add< + FatRawBufferCastLowering, + RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>, + RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>, + RawBufferOpLowering<RawBufferAtomicFaddOp, + ROCDL::RawPtrBufferAtomicFaddOp>, + RawBufferOpLowering<RawBufferAtomicFmaxOp, + ROCDL::RawPtrBufferAtomicFmaxOp>, + RawBufferOpLowering<RawBufferAtomicSmaxOp, + ROCDL::RawPtrBufferAtomicSmaxOp>, + RawBufferOpLowering<RawBufferAtomicUminOp, + ROCDL::RawPtrBufferAtomicUminOp>, + RawBufferOpLowering<RawBufferAtomicCmpswapOp, + ROCDL::RawPtrBufferAtomicCmpSwap>, + AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering, + SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering, + WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering, + ScaledExtPackedOpLowering, PackedScaledTruncOpLowering, + PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering, + GatherToLDSOpLowering, TransposeLoadOpLowering, AMDGPUPermlaneLowering, + AMDGPUMakeDmaBaseLowering, AMDGPUMakeDmaDescriptorLowering>(converter, + chipset); patterns.add<AMDGPUSwizzleBitModeLowering>(converter); } diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp new file mode 100644 index 0000000..79816fc --- /dev/null +++ 
b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp @@ -0,0 +1,665 @@ +//===- ArithToAPFloat.cpp - Arithmetic to APFloat Conversion --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Utils/Utils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Transforms/WalkPatternRewriteDriver.h" + +namespace mlir { +#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; +using namespace mlir::func; + +static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable, + StringRef name, FunctionType funcT, bool setPrivate, + SymbolTableCollection *symbolTables = nullptr) { + OpBuilder::InsertionGuard g(b); + assert(!symTable->getRegion(0).empty() && "expected non-empty region"); + b.setInsertionPointToStart(&symTable->getRegion(0).front()); + FuncOp funcOp = FuncOp::create(b, symTable->getLoc(), name, funcT); + if (setPrivate) + funcOp.setPrivate(); + if (symbolTables) { + SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable); + symbolTable.insert(funcOp, symTable->getRegion(0).front().begin()); + } + return funcOp; +} + +/// Helper function to look up or create the symbol for a runtime library +/// function with the given parameter types. Returns an int64_t, unless a +/// different result type is specified. +static FailureOr<FuncOp> +lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable, + StringRef name, TypeRange paramTypes, + SymbolTableCollection *symbolTables = nullptr, + Type resultType = {}) { + if (!resultType) + resultType = IntegerType::get(symTable->getContext(), 64); + std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str(); + auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType}); + FailureOr<FuncOp> func = + lookupFnDecl(symTable, funcName, funcT, symbolTables); + // Failed due to type mismatch. + if (failed(func)) + return func; + // Successfully matched existing decl. + if (*func) + return *func; + + return createFnDecl(b, symTable, funcName, funcT, + /*setPrivate=*/true, symbolTables); +} + +/// Helper function to look up or create the symbol for a runtime library +/// function for a binary arithmetic operation. +/// +/// Parameter 1: APFloat semantics +/// Parameter 2: Left-hand side operand +/// Parameter 3: Right-hand side operand +/// +/// This function will return a failure if the function is found but has an +/// unexpected signature. 
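/// A matching implementation is assumed to live in a runtime support library;
/// a minimal sketch of what such a helper could look like (hypothetical, for
/// illustration only, using LLVM's APFloat API) is:
///
///   extern "C" int64_t _mlir_apfloat_add(int32_t sem, int64_t lhs,
///                                        int64_t rhs) {
///     const llvm::fltSemantics &s = llvm::APFloatBase::EnumToSemantics(
///         static_cast<llvm::APFloatBase::Semantics>(sem));
///     unsigned width = llvm::APFloatBase::semanticsSizeInBits(s);
///     llvm::APFloat a(s, llvm::APInt(width, lhs));
///     llvm::APFloat b(s, llvm::APInt(width, rhs));
///     a.add(b, llvm::APFloat::rmNearestTiesToEven);
///     return a.bitcastToAPInt().getZExtValue();
///   }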
+/// +static FailureOr<FuncOp> lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name, + SymbolTableCollection *symbolTables = nullptr) { + auto i32Type = IntegerType::get(symTable->getContext(), 32); + auto i64Type = IntegerType::get(symTable->getContext(), 64); + return lookupOrCreateApFloatFn(b, symTable, name, {i32Type, i64Type, i64Type}, + symbolTables); +} + +static Value getSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy) { + int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics()); + return arith::ConstantOp::create(b, loc, b.getI32Type(), + b.getIntegerAttr(b.getI32Type(), sem)); +} + +/// Given two operands of vector type and vector result type (with the same +/// shape), call the given function for each pair of scalar operands and +/// package the result into a vector. If the given operands and result type are +/// not vectors, call the function directly. The second operand is optional. +template <typename Fn, typename... Values> +static Value forEachScalarValue(RewriterBase &rewriter, Location loc, + Value operand1, Value operand2, Type resultType, + Fn fn) { + auto vecTy1 = dyn_cast<VectorType>(operand1.getType()); + if (operand2) { + // Sanity check: Operand types must match. + assert(vecTy1 == dyn_cast<VectorType>(operand2.getType()) && + "expected same vector types"); + } + if (!vecTy1) { + // Not a vector. Call the function directly. + return fn(operand1, operand2, resultType); + } + + // Prepare scalar operands. + ResultRange scalars1 = + vector::ToElementsOp::create(rewriter, loc, operand1)->getResults(); + SmallVector<Value> scalars2; + if (!operand2) { + // No second operand. Fill the list with null values. + scalars2.assign(vecTy1.getNumElements(), Value()); + } else { + llvm::append_range( + scalars2, + vector::ToElementsOp::create(rewriter, loc, operand2)->getResults()); + } + + // Call the function for each pair of scalar operands. + auto resultVecType = cast<VectorType>(resultType); + SmallVector<Value> results; + for (auto [scalar1, scalar2] : llvm::zip_equal(scalars1, scalars2)) { + Value result = fn(scalar1, scalar2, resultVecType.getElementType()); + results.push_back(result); + } + + // Package the results into a vector. + return vector::FromElementsOp::create( + rewriter, loc, + vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()), + results); +} + +/// Check preconditions for the conversion: +/// 1. All operands / results must be integers or floats (or vectors thereof). +/// 2. The bitwidth of the operands / results must be <= 64. +static LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op) { + for (Value value : llvm::concat<Value>(op->getOperands(), op->getResults())) { + Type type = value.getType(); + if (auto vecTy = dyn_cast<VectorType>(type)) { + type = vecTy.getElementType(); + } + if (!type.isIntOrFloat()) { + return rewriter.notifyMatchFailure( + op, "only integers and floats (or vectors thereof) are supported"); + } + if (type.getIntOrFloatBitWidth() > 64) + return rewriter.notifyMatchFailure(op, + "bitwidth > 64 bits is not supported"); + } + return success(); +} + +/// Rewrite a binary arithmetic operation to an APFloat function call.
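///
/// For example (illustrative), `%z = arith.addf %x, %y : f16` is rewritten
/// roughly into:
///
///   %xi  = arith.bitcast %x : f16 to i16
///   %xw  = arith.extui %xi : i16 to i64
///   %yi  = arith.bitcast %y : f16 to i16
///   %yw  = arith.extui %yi : i16 to i64
///   %sem = arith.constant <f16 semantics enum> : i32
///   %r   = func.call @_mlir_apfloat_add(%sem, %xw, %yw)
///            : (i32, i64, i64) -> i64
///   %ri  = arith.trunci %r : i64 to i16
///   %z2  = arith.bitcast %ri : i16 to f16
///
/// Vector operands are scalarized with vector.to_elements first, and the
/// per-element results are reassembled with vector.from_elements.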
+template <typename OpTy> +struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> { + BinaryArithOpToAPFloatConversion(MLIRContext *context, + const char *APFloatName, + SymbolOpInterface symTable, + PatternBenefit benefit = 1) + : OpRewritePattern<OpTy>(context, benefit), symTable(symTable), + APFloatName(APFloatName) {}; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + if (failed(checkPreconditions(rewriter, op))) + return failure(); + + // Get APFloat function from runtime library. + FailureOr<FuncOp> fn = + lookupOrCreateBinaryFn(rewriter, symTable, APFloatName); + if (failed(fn)) + return fn; + + // Scalarize and convert to APFloat runtime calls. + Location loc = op.getLoc(); + rewriter.setInsertionPoint(op); + Value repl = forEachScalarValue( + rewriter, loc, op.getLhs(), op.getRhs(), op.getType(), + [&](Value lhs, Value rhs, Type resultType) { + // Cast operands to 64-bit integers. + auto floatTy = cast<FloatType>(resultType); + auto intWType = rewriter.getIntegerType(floatTy.getWidth()); + auto int64Type = rewriter.getI64Type(); + Value lhsBits = arith::ExtUIOp::create( + rewriter, loc, int64Type, + arith::BitcastOp::create(rewriter, loc, intWType, lhs)); + Value rhsBits = arith::ExtUIOp::create( + rewriter, loc, int64Type, + arith::BitcastOp::create(rewriter, loc, intWType, rhs)); + + // Call APFloat function. + Value semValue = getSemanticsValue(rewriter, loc, floatTy); + SmallVector<Value> params = {semValue, lhsBits, rhsBits}; + auto resultOp = func::CallOp::create(rewriter, loc, + TypeRange(rewriter.getI64Type()), + SymbolRefAttr::get(*fn), params); + + // Truncate result to the original width. + Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType, + resultOp->getResult(0)); + return arith::BitcastOp::create(rewriter, loc, floatTy, + truncatedBits); + }); + rewriter.replaceOp(op, repl); + return success(); + } + + SymbolOpInterface symTable; + const char *APFloatName; +}; + +template <typename OpTy> +struct FpToFpConversion final : OpRewritePattern<OpTy> { + FpToFpConversion(MLIRContext *context, SymbolOpInterface symTable, + PatternBenefit benefit = 1) + : OpRewritePattern<OpTy>(context, benefit), symTable(symTable) {} + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + if (failed(checkPreconditions(rewriter, op))) + return failure(); + + // Get APFloat function from runtime library. + auto i32Type = IntegerType::get(symTable->getContext(), 32); + auto i64Type = IntegerType::get(symTable->getContext(), 64); + FailureOr<FuncOp> fn = lookupOrCreateApFloatFn( + rewriter, symTable, "convert", {i32Type, i32Type, i64Type}); + if (failed(fn)) + return fn; + + // Scalarize and convert to APFloat runtime calls. + Location loc = op.getLoc(); + rewriter.setInsertionPoint(op); + Value repl = forEachScalarValue( + rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(), + [&](Value operand1, Value operand2, Type resultType) { + // Cast operands to 64-bit integers. + auto inFloatTy = cast<FloatType>(operand1.getType()); + auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth()); + Value operandBits = arith::ExtUIOp::create( + rewriter, loc, i64Type, + arith::BitcastOp::create(rewriter, loc, inIntWType, operand1)); + + // Call APFloat function. 
+ Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy); + auto outFloatTy = cast<FloatType>(resultType); + Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy); + std::array<Value, 3> params = {inSemValue, outSemValue, operandBits}; + auto resultOp = func::CallOp::create(rewriter, loc, + TypeRange(rewriter.getI64Type()), + SymbolRefAttr::get(*fn), params); + + // Truncate result to the original width. + auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth()); + Value truncatedBits = arith::TruncIOp::create( + rewriter, loc, outIntWType, resultOp->getResult(0)); + return arith::BitcastOp::create(rewriter, loc, outFloatTy, + truncatedBits); + }); + rewriter.replaceOp(op, repl); + return success(); + } + + SymbolOpInterface symTable; +}; + +template <typename OpTy> +struct FpToIntConversion final : OpRewritePattern<OpTy> { + FpToIntConversion(MLIRContext *context, SymbolOpInterface symTable, + bool isUnsigned, PatternBenefit benefit = 1) + : OpRewritePattern<OpTy>(context, benefit), symTable(symTable), + isUnsigned(isUnsigned) {} + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + if (failed(checkPreconditions(rewriter, op))) + return failure(); + + // Get APFloat function from runtime library. + auto i1Type = IntegerType::get(symTable->getContext(), 1); + auto i32Type = IntegerType::get(symTable->getContext(), 32); + auto i64Type = IntegerType::get(symTable->getContext(), 64); + FailureOr<FuncOp> fn = + lookupOrCreateApFloatFn(rewriter, symTable, "convert_to_int", + {i32Type, i32Type, i1Type, i64Type}); + if (failed(fn)) + return fn; + + // Scalarize and convert to APFloat runtime calls. + Location loc = op.getLoc(); + rewriter.setInsertionPoint(op); + Value repl = forEachScalarValue( + rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(), + [&](Value operand1, Value operand2, Type resultType) { + // Cast operands to 64-bit integers. + auto inFloatTy = cast<FloatType>(operand1.getType()); + auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth()); + Value operandBits = arith::ExtUIOp::create( + rewriter, loc, i64Type, + arith::BitcastOp::create(rewriter, loc, inIntWType, operand1)); + + // Call APFloat function. + Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy); + auto outIntTy = cast<IntegerType>(resultType); + Value outWidthValue = arith::ConstantOp::create( + rewriter, loc, i32Type, + rewriter.getIntegerAttr(i32Type, outIntTy.getWidth())); + Value isUnsignedValue = arith::ConstantOp::create( + rewriter, loc, i1Type, + rewriter.getIntegerAttr(i1Type, isUnsigned)); + SmallVector<Value> params = {inSemValue, outWidthValue, + isUnsignedValue, operandBits}; + auto resultOp = func::CallOp::create(rewriter, loc, + TypeRange(rewriter.getI64Type()), + SymbolRefAttr::get(*fn), params); + + // Truncate result to the original width. 
+ return arith::TruncIOp::create(rewriter, loc, outIntTy, + resultOp->getResult(0)); + }); + rewriter.replaceOp(op, repl); + return success(); + } + + SymbolOpInterface symTable; + bool isUnsigned; +}; + +template <typename OpTy> +struct IntToFpConversion final : OpRewritePattern<OpTy> { + IntToFpConversion(MLIRContext *context, SymbolOpInterface symTable, + bool isUnsigned, PatternBenefit benefit = 1) + : OpRewritePattern<OpTy>(context, benefit), symTable(symTable), + isUnsigned(isUnsigned) {} + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + if (failed(checkPreconditions(rewriter, op))) + return failure(); + + // Get APFloat function from runtime library. + auto i1Type = IntegerType::get(symTable->getContext(), 1); + auto i32Type = IntegerType::get(symTable->getContext(), 32); + auto i64Type = IntegerType::get(symTable->getContext(), 64); + FailureOr<FuncOp> fn = + lookupOrCreateApFloatFn(rewriter, symTable, "convert_from_int", + {i32Type, i32Type, i1Type, i64Type}); + if (failed(fn)) + return fn; + + // Scalarize and convert to APFloat runtime calls. + Location loc = op.getLoc(); + rewriter.setInsertionPoint(op); + Value repl = forEachScalarValue( + rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(), + [&](Value operand1, Value operand2, Type resultType) { + // Cast operands to 64-bit integers. + auto inIntTy = cast<IntegerType>(operand1.getType()); + Value operandBits = operand1; + if (operandBits.getType().getIntOrFloatBitWidth() < 64) { + if (isUnsigned) { + operandBits = + arith::ExtUIOp::create(rewriter, loc, i64Type, operandBits); + } else { + operandBits = + arith::ExtSIOp::create(rewriter, loc, i64Type, operandBits); + } + } + + // Call APFloat function. + auto outFloatTy = cast<FloatType>(resultType); + Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy); + Value inWidthValue = arith::ConstantOp::create( + rewriter, loc, i32Type, + rewriter.getIntegerAttr(i32Type, inIntTy.getWidth())); + Value isUnsignedValue = arith::ConstantOp::create( + rewriter, loc, i1Type, + rewriter.getIntegerAttr(i1Type, isUnsigned)); + SmallVector<Value> params = {outSemValue, inWidthValue, + isUnsignedValue, operandBits}; + auto resultOp = func::CallOp::create(rewriter, loc, + TypeRange(rewriter.getI64Type()), + SymbolRefAttr::get(*fn), params); + + // Truncate result to the original width. + auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth()); + Value truncatedBits = arith::TruncIOp::create( + rewriter, loc, outIntWType, resultOp->getResult(0)); + return arith::BitcastOp::create(rewriter, loc, outFloatTy, + truncatedBits); + }); + rewriter.replaceOp(op, repl); + return success(); + } + + SymbolOpInterface symTable; + bool isUnsigned; +}; + +struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> { + CmpFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, + PatternBenefit benefit = 1) + : OpRewritePattern<arith::CmpFOp>(context, benefit), symTable(symTable) {} + + LogicalResult matchAndRewrite(arith::CmpFOp op, + PatternRewriter &rewriter) const override { + if (failed(checkPreconditions(rewriter, op))) + return failure(); + + // Get APFloat function from runtime library. 
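// Note (illustrative): unlike the arithmetic helpers, the "compare" runtime
// function is declared with an i8 result rather than the default i64; the
// returned byte is expected to encode an llvm::APFloat::cmpResult, and each
// arith.cmpf predicate below is lowered to equality checks against the
// relevant cmpResult values OR'ed together (e.g. OGE becomes
// result == cmpGreaterThan || result == cmpEqual).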
+ auto i1Type = IntegerType::get(symTable->getContext(), 1); + auto i8Type = IntegerType::get(symTable->getContext(), 8); + auto i32Type = IntegerType::get(symTable->getContext(), 32); + auto i64Type = IntegerType::get(symTable->getContext(), 64); + FailureOr<FuncOp> fn = + lookupOrCreateApFloatFn(rewriter, symTable, "compare", + {i32Type, i64Type, i64Type}, nullptr, i8Type); + if (failed(fn)) + return fn; + + // Scalarize and convert to APFloat runtime calls. + Location loc = op.getLoc(); + rewriter.setInsertionPoint(op); + Value repl = forEachScalarValue( + rewriter, loc, op.getLhs(), op.getRhs(), op.getType(), + [&](Value lhs, Value rhs, Type resultType) { + // Cast operands to 64-bit integers. + auto floatTy = cast<FloatType>(lhs.getType()); + auto intWType = rewriter.getIntegerType(floatTy.getWidth()); + Value lhsBits = arith::ExtUIOp::create( + rewriter, loc, i64Type, + arith::BitcastOp::create(rewriter, loc, intWType, lhs)); + Value rhsBits = arith::ExtUIOp::create( + rewriter, loc, i64Type, + arith::BitcastOp::create(rewriter, loc, intWType, rhs)); + + // Call APFloat function. + Value semValue = getSemanticsValue(rewriter, loc, floatTy); + SmallVector<Value> params = {semValue, lhsBits, rhsBits}; + Value comparisonResult = + func::CallOp::create(rewriter, loc, TypeRange(i8Type), + SymbolRefAttr::get(*fn), params) + ->getResult(0); + + // Generate an i1 SSA value that is "true" if the comparison result + // matches the given `val`. + auto checkResult = [&](llvm::APFloat::cmpResult val) { + return arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::eq, comparisonResult, + arith::ConstantOp::create( + rewriter, loc, i8Type, + rewriter.getIntegerAttr(i8Type, static_cast<int8_t>(val))) + .getResult()); + }; + // Generate an i1 SSA value that is "true" if the comparison result + // matches any of the given `vals`. + std::function<Value(ArrayRef<llvm::APFloat::cmpResult>)> + checkResults = [&](ArrayRef<llvm::APFloat::cmpResult> vals) { + Value first = checkResult(vals.front()); + if (vals.size() == 1) + return first; + Value rest = checkResults(vals.drop_front()); + return arith::OrIOp::create(rewriter, loc, first, rest) + .getResult(); + }; + + // This switch-case statement was taken from arith::applyCmpPredicate. + Value result; + switch (op.getPredicate()) { + case arith::CmpFPredicate::AlwaysFalse: + result = + arith::ConstantOp::create(rewriter, loc, i1Type, + rewriter.getIntegerAttr(i1Type, 0)) + .getResult(); + break; + case arith::CmpFPredicate::OEQ: + result = checkResult(llvm::APFloat::cmpEqual); + break; + case arith::CmpFPredicate::OGT: + result = checkResult(llvm::APFloat::cmpGreaterThan); + break; + case arith::CmpFPredicate::OGE: + result = checkResults( + {llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::OLT: + result = checkResult(llvm::APFloat::cmpLessThan); + break; + case arith::CmpFPredicate::OLE: + result = checkResults( + {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::ONE: + // Not cmpUnordered and not cmpEqual. + result = checkResults( + {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpGreaterThan}); + break; + case arith::CmpFPredicate::ORD: + // Not cmpUnordered.
+ result = checkResults({llvm::APFloat::cmpLessThan, + llvm::APFloat::cmpGreaterThan, + llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::UEQ: + result = checkResults( + {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::UGT: + result = checkResults( + {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpGreaterThan}); + break; + case arith::CmpFPredicate::UGE: + result = checkResults({llvm::APFloat::cmpUnordered, + llvm::APFloat::cmpGreaterThan, + llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::ULT: + result = checkResults( + {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan}); + break; + case arith::CmpFPredicate::ULE: + result = checkResults({llvm::APFloat::cmpUnordered, + llvm::APFloat::cmpLessThan, + llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::UNE: + // Not cmpEqual. + result = checkResults({llvm::APFloat::cmpLessThan, + llvm::APFloat::cmpGreaterThan, + llvm::APFloat::cmpUnordered}); + break; + case arith::CmpFPredicate::UNO: + result = checkResult(llvm::APFloat::cmpUnordered); + break; + case arith::CmpFPredicate::AlwaysTrue: + result = + arith::ConstantOp::create(rewriter, loc, i1Type, + rewriter.getIntegerAttr(i1Type, 1)) + .getResult(); + break; + } + return result; + }); + rewriter.replaceOp(op, repl); + return success(); + } + + SymbolOpInterface symTable; +}; + +struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> { + NegFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, + PatternBenefit benefit = 1) + : OpRewritePattern<arith::NegFOp>(context, benefit), symTable(symTable) {} + + LogicalResult matchAndRewrite(arith::NegFOp op, + PatternRewriter &rewriter) const override { + if (failed(checkPreconditions(rewriter, op))) + return failure(); + + // Get APFloat function from runtime library. + auto i32Type = IntegerType::get(symTable->getContext(), 32); + auto i64Type = IntegerType::get(symTable->getContext(), 64); + FailureOr<FuncOp> fn = + lookupOrCreateApFloatFn(rewriter, symTable, "neg", {i32Type, i64Type}); + if (failed(fn)) + return fn; + + // Scalarize and convert to APFloat runtime calls. + Location loc = op.getLoc(); + rewriter.setInsertionPoint(op); + Value repl = forEachScalarValue( + rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(), + [&](Value operand1, Value operand2, Type resultType) { + // Cast operands to 64-bit integers. + auto floatTy = cast<FloatType>(operand1.getType()); + auto intWType = rewriter.getIntegerType(floatTy.getWidth()); + Value operandBits = arith::ExtUIOp::create( + rewriter, loc, i64Type, + arith::BitcastOp::create(rewriter, loc, intWType, operand1)); + + // Call APFloat function. + Value semValue = getSemanticsValue(rewriter, loc, floatTy); + SmallVector<Value> params = {semValue, operandBits}; + Value negatedBits = + func::CallOp::create(rewriter, loc, TypeRange(i64Type), + SymbolRefAttr::get(*fn), params) + ->getResult(0); + + // Truncate result to the original width. 
+ Value truncatedBits = + arith::TruncIOp::create(rewriter, loc, intWType, negatedBits); + return arith::BitcastOp::create(rewriter, loc, floatTy, + truncatedBits); + }); + rewriter.replaceOp(op, repl); + return success(); + } + + SymbolOpInterface symTable; +}; + +namespace { +struct ArithToAPFloatConversionPass final + : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> { + using Base::Base; + + void runOnOperation() override; +}; + +void ArithToAPFloatConversionPass::runOnOperation() { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + patterns.add<BinaryArithOpToAPFloatConversion<arith::AddFOp>>(context, "add", + getOperation()); + patterns.add<BinaryArithOpToAPFloatConversion<arith::SubFOp>>( + context, "subtract", getOperation()); + patterns.add<BinaryArithOpToAPFloatConversion<arith::MulFOp>>( + context, "multiply", getOperation()); + patterns.add<BinaryArithOpToAPFloatConversion<arith::DivFOp>>( + context, "divide", getOperation()); + patterns.add<BinaryArithOpToAPFloatConversion<arith::RemFOp>>( + context, "remainder", getOperation()); + patterns.add<BinaryArithOpToAPFloatConversion<arith::MinNumFOp>>( + context, "minnum", getOperation()); + patterns.add<BinaryArithOpToAPFloatConversion<arith::MaxNumFOp>>( + context, "maxnum", getOperation()); + patterns.add<BinaryArithOpToAPFloatConversion<arith::MinimumFOp>>( + context, "minimum", getOperation()); + patterns.add<BinaryArithOpToAPFloatConversion<arith::MaximumFOp>>( + context, "maximum", getOperation()); + patterns + .add<FpToFpConversion<arith::ExtFOp>, FpToFpConversion<arith::TruncFOp>, + CmpFOpToAPFloatConversion, NegFOpToAPFloatConversion>( + context, getOperation()); + patterns.add<FpToIntConversion<arith::FPToSIOp>>(context, getOperation(), + /*isUnsigned=*/false); + patterns.add<FpToIntConversion<arith::FPToUIOp>>(context, getOperation(), + /*isUnsigned=*/true); + patterns.add<IntToFpConversion<arith::SIToFPOp>>(context, getOperation(), + /*isUnsigned=*/false); + patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(), + /*isUnsigned=*/true); + LogicalResult result = success(); + ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) { + if (diag.getSeverity() == DiagnosticSeverity::Error) { + result = failure(); + } + // NB: if you don't return failure, no other diag handlers will fire (see + // mlir/lib/IR/Diagnostics.cpp:DiagnosticEngineImpl::emit). 
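The handler above only records that an error diagnostic was produced while the patterns ran; returning failure() keeps the diagnostic propagating to previously registered handlers, whereas returning success() would mark it handled and silence them. A minimal sketch of the same capture idiom in isolation (assuming an existing MLIRContext):

#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"

// Returns true if `work` emitted any error diagnostic.
static bool sawErrorWhile(mlir::MLIRContext &ctx,
                          llvm::function_ref<void()> work) {
  bool sawError = false;
  mlir::ScopedDiagnosticHandler handler(&ctx, [&](mlir::Diagnostic &diag) {
    if (diag.getSeverity() == mlir::DiagnosticSeverity::Error)
      sawError = true;
    return mlir::failure(); // let other handlers (e.g. printing) still run
  });
  work();
  return sawError;
}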
+ return failure(); + }); + walkAndApplyPatterns(getOperation(), std::move(patterns)); + if (failed(result)) + return signalPassFailure(); +} +} // namespace diff --git a/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt new file mode 100644 index 0000000..31fce7a --- /dev/null +++ b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_conversion_library(MLIRArithToAPFloat + ArithToAPFloat.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToLLVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRArithTransforms + MLIRFuncDialect + MLIRFuncUtils + MLIRVectorDialect + ) diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index ba57155..220826d 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -14,6 +14,7 @@ #include "mlir/Conversion/LLVMCommon/VectorPattern.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/TypeUtilities.h" @@ -36,20 +37,23 @@ namespace { /// attribute. template <typename SourceOp, typename TargetOp, bool Constrained, template <typename, typename> typename AttrConvert = - AttrConvertPassThrough> + AttrConvertPassThrough, + bool FailOnUnsupportedFP = false> struct ConstrainedVectorConvertToLLVMPattern - : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> { - using VectorConvertToLLVMPattern<SourceOp, TargetOp, - AttrConvert>::VectorConvertToLLVMPattern; + : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert, + FailOnUnsupportedFP> { + using VectorConvertToLLVMPattern< + SourceOp, TargetOp, AttrConvert, + FailOnUnsupportedFP>::VectorConvertToLLVMPattern; LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (Constrained != static_cast<bool>(op.getRoundingModeAttr())) return failure(); - return VectorConvertToLLVMPattern<SourceOp, TargetOp, - AttrConvert>::matchAndRewrite(op, adaptor, - rewriter); + return VectorConvertToLLVMPattern< + SourceOp, TargetOp, AttrConvert, + FailOnUnsupportedFP>::matchAndRewrite(op, adaptor, rewriter); } }; @@ -78,7 +82,8 @@ struct IdentityBitcastLowering final using AddFOpLowering = VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp, - arith::AttrConvertFastMathToLLVM>; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using AddIOpLowering = VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp, arith::AttrConvertOverflowToLLVM>; @@ -87,53 +92,67 @@ using BitcastOpLowering = VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>; using DivFOpLowering = VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp, - arith::AttrConvertFastMathToLLVM>; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using DivSIOpLowering = VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>; using DivUIOpLowering = VectorConvertToLLVMPattern<arith::DivUIOp, LLVM::UDivOp>; -using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp>; +using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp, + AttrConvertPassThrough, + /*FailOnUnsupportedFP=*/true>; using ExtSIOpLowering = 
VectorConvertToLLVMPattern<arith::ExtSIOp, LLVM::SExtOp>; using ExtUIOpLowering = VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp>; using FPToSIOpLowering = - VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>; + VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp, + AttrConvertPassThrough, + /*FailOnUnsupportedFP=*/true>; using FPToUIOpLowering = - VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>; + VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp, + AttrConvertPassThrough, + /*FailOnUnsupportedFP=*/true>; using MaximumFOpLowering = VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp, - arith::AttrConvertFastMathToLLVM>; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using MaxNumFOpLowering = VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp, - arith::AttrConvertFastMathToLLVM>; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using MaxSIOpLowering = VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>; using MaxUIOpLowering = VectorConvertToLLVMPattern<arith::MaxUIOp, LLVM::UMaxOp>; using MinimumFOpLowering = VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp, - arith::AttrConvertFastMathToLLVM>; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using MinNumFOpLowering = VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp, - arith::AttrConvertFastMathToLLVM>; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using MinSIOpLowering = VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>; using MinUIOpLowering = VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>; using MulFOpLowering = VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp, - arith::AttrConvertFastMathToLLVM>; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using MulIOpLowering = VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp, arith::AttrConvertOverflowToLLVM>; using NegFOpLowering = VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp, - arith::AttrConvertFastMathToLLVM>; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>; using RemFOpLowering = VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp, - arith::AttrConvertFastMathToLLVM>; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using RemSIOpLowering = VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>; using RemUIOpLowering = @@ -151,21 +170,25 @@ using SIToFPOpLowering = VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>; using SubFOpLowering = VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp, - arith::AttrConvertFastMathToLLVM>; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using SubIOpLowering = VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp, arith::AttrConvertOverflowToLLVM>; using TruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp, - false>; + false, AttrConvertPassThrough, + /*FailOnUnsupportedFP=*/true>; using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern< arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true, - arith::AttrConverterConstrainedFPToLLVM>; + arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>; using TruncIOpLowering = VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp, arith::AttrConvertOverflowToLLVM>; using UIToFPOpLowering = - VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>; + 
VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp, + AttrConvertPassThrough, + /*FailOnUnsupportedFP=*/true>; using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>; //===----------------------------------------------------------------------===// @@ -240,8 +263,7 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> { struct SelectOpOneToNLowering : public ConvertOpToLLVMPattern<arith::SelectOp> { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - using Adaptor = - typename ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor; + using Adaptor = ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor; LogicalResult matchAndRewrite(arith::SelectOp op, Adaptor adaptor, @@ -259,6 +281,7 @@ ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(), op->getAttrs(), + /*propAttr=*/Attribute{}, *getTypeConverter(), rewriter); } @@ -460,6 +483,10 @@ CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, LogicalResult CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + if (LLVM::detail::isUnsupportedFloatingPointType(*this->getTypeConverter(), + op.getLhs().getType())) + return rewriter.notifyMatchFailure(op, "unsupported floating point type"); + Type operandType = adaptor.getLhs().getType(); Type resultType = op.getResult().getType(); LLVM::FastmathFlags fmf = diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index bebf1b8..613dc6d 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -2,6 +2,7 @@ add_subdirectory(AffineToStandard) add_subdirectory(AMDGPUToROCDL) add_subdirectory(ArithCommon) add_subdirectory(ArithToAMDGPU) +add_subdirectory(ArithToAPFloat) add_subdirectory(ArithToArmSME) add_subdirectory(ArithToEmitC) add_subdirectory(ArithToLLVM) diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp index 86d02e6..6a0c211 100644 --- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp +++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp @@ -96,7 +96,8 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> { ConversionPatternRewriter &rewriter) const override { return LLVM::detail::oneToOneRewrite( op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(), - op->getAttrs(), *getTypeConverter(), rewriter); + op->getAttrs(), /*propAttr=*/Attribute{}, *getTypeConverter(), + rewriter); } }; diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp index 798d8b0..b75968e 100644 --- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -137,8 +137,7 @@ static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) { /// op to llvm.br. 
struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> { using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern; - using Adaptor = - typename ConvertOpToLLVMPattern<cf::BranchOp>::OneToNOpAdaptor; + using Adaptor = ConvertOpToLLVMPattern<cf::BranchOp>::OneToNOpAdaptor; LogicalResult matchAndRewrite(cf::BranchOp op, Adaptor adaptor, @@ -163,8 +162,7 @@ struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> { /// branch op to llvm.cond_br. struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> { using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern; - using Adaptor = - typename ConvertOpToLLVMPattern<cf::CondBranchOp>::OneToNOpAdaptor; + using Adaptor = ConvertOpToLLVMPattern<cf::CondBranchOp>::OneToNOpAdaptor; LogicalResult matchAndRewrite(cf::CondBranchOp op, Adaptor adaptor, @@ -204,7 +202,7 @@ struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> { using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor, + matchAndRewrite(cf::SwitchOp op, cf::SwitchOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Get or convert default block. FailureOr<Block *> convertedDefaultBlock = getConvertedBlock( diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp index 93fe2ed..2220f61 100644 --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -374,9 +374,12 @@ FailureOr<LLVM::LLVMFuncOp> mlir::convertFuncOpToLLVMFuncOp( // Create a memory effect attribute corresponding to readnone. if (funcOp->hasAttr(readnoneAttrName)) { auto memoryAttr = LLVM::MemoryEffectsAttr::get( - rewriter.getContext(), - {LLVM::ModRefInfo::NoModRef, LLVM::ModRefInfo::NoModRef, - LLVM::ModRefInfo::NoModRef}); + rewriter.getContext(), {/*other=*/LLVM::ModRefInfo::NoModRef, + /*argMem=*/LLVM::ModRefInfo::NoModRef, + /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef, + /*errnoMem=*/LLVM::ModRefInfo::NoModRef, + /*targetMem0=*/LLVM::ModRefInfo::NoModRef, + /*targetMem1=*/LLVM::ModRefInfo::NoModRef}); newFuncOp.setMemoryEffectsAttr(memoryAttr); } diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp index 425594b..f143a9e 100644 --- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp +++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp @@ -66,7 +66,10 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, constexpr auto noModRef = mlir::LLVM::ModRefInfo::NoModRef; auto memAttr = b.getAttr<LLVM::MemoryEffectsAttr>( /*other=*/noModRef, - /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef); + /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef, + /*errnoMem=*/noModRef, + /*targetMem0=*/noModRef, + /*targetMem1=*/noModRef); func.setMemoryEffectsAttr(memAttr); } diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index d64c4d6..5848489 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -419,7 +419,10 @@ struct LowerGpuOpsToNVVMOpsPass final if (this->hasRedux) populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns); configureGpuToNVVMConversionLegality(target); - if (failed(applyPartialConversion(m, target, std::move(llvmPatterns)))) + ConversionConfig config; + 
config.allowPatternRollback = allowPatternRollback; + if (failed( + applyPartialConversion(m, target, std::move(llvmPatterns), config))) signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp index 99c059c..6254de8 100644 --- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" using namespace mlir; @@ -57,7 +58,8 @@ static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) { if (type.getElementType().isF32()) return type.getOperand() == "COp" ? NVVM::MMATypes::f32 : NVVM::MMATypes::tf32; - + if (type.getElementType().isF64()) + return NVVM::MMATypes::f64; if (type.getElementType().isSignedInteger(8)) return NVVM::MMATypes::s8; if (type.getElementType().isUnsignedInteger(8)) @@ -212,8 +214,13 @@ struct WmmaMmaOpToNVVMLowering // then passed on to the intrinsic call. Emit llvm ops to extract individual // values form lowered memrefs. SmallVector<Value> unpackedOps; - auto unpackOp = [&](Value operand) { + // f64 a and b fragments are not structs but scalars. + if (!isa<LLVM::LLVMStructType>(operand.getType())) { + unpackedOps.push_back(operand); + return; + } + // every other type is lowered to an LLVM struct, extract the values. auto structType = cast<LLVM::LLVMStructType>(operand.getType()); for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) { Value toUse = LLVM::ExtractValueOp::create(rewriter, loc, operand, i); @@ -276,10 +283,16 @@ struct WmmaConstantOpToNVVMLowering return failure(); Location loc = subgroupMmaConstantOp.getLoc(); Value cst = adaptor.getOperands()[0]; - LLVM::LLVMStructType type = convertMMAToLLVMType( + Type type = convertMMAToLLVMType( cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType())); + // If the element is not a struct, it means it's a scalar f64. + auto structType = dyn_cast<LLVM::LLVMStructType>(type); + if (!structType) { + rewriter.replaceOp(subgroupMmaConstantOp, cst); + return success(); + } // If the element type is a vector create a vector from the operand. - if (auto vecType = dyn_cast<VectorType>(type.getBody()[0])) { + if (auto vecType = dyn_cast<VectorType>(structType.getBody()[0])) { Value vecCst = LLVM::PoisonOp::create(rewriter, loc, vecType); for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) { Value idx = LLVM::ConstantOp::create(rewriter, loc, @@ -289,8 +302,8 @@ struct WmmaConstantOpToNVVMLowering } cst = vecCst; } - Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, type); - for (size_t i : llvm::seq(size_t(0), type.getBody().size())) { + Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, structType); + for (size_t i : llvm::seq(size_t(0), structType.getBody().size())) { matrixStruct = LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, cst, i); } @@ -354,10 +367,24 @@ struct WmmaElementwiseOpToNVVMLowering return failure(); Location loc = subgroupMmaElementwiseOp.getLoc(); size_t numOperands = adaptor.getOperands().size(); - LLVM::LLVMStructType destType = convertMMAToLLVMType( + Type destType = convertMMAToLLVMType( cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType())); - Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, destType); - for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) { + + // If the element is not a struct, it means it's a scalar f64. 
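The special case above exists because the f64 gpu.subgroup_mma A/B fragments carry a single f64 per lane, so convertMMAToLLVMType (further down in this file) returns a plain f64 rather than a one-element LLVM struct. A hedged sketch of the resulting dispatch, assuming the usual headers of this file and a fragment value already lowered to either f64 or an LLVM struct:

// Collect intrinsic operands from a lowered fragment: scalar f64 fragments
// pass through unchanged, struct fragments are exploded field by field.
static void collectFragmentValues(OpBuilder &b, Location loc, Value fragment,
                                  SmallVectorImpl<Value> &out) {
  auto structTy = dyn_cast<LLVM::LLVMStructType>(fragment.getType());
  if (!structTy) {
    out.push_back(fragment); // single-f64 a/b fragment
    return;
  }
  for (size_t i = 0, e = structTy.getBody().size(); i < e; ++i)
    out.push_back(LLVM::ExtractValueOp::create(b, loc, fragment, i));
}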
+ LLVM::LLVMStructType structDestTy = + dyn_cast<LLVM::LLVMStructType>(destType); + if (!structDestTy) { + SmallVector<Value> operands; + for (auto operand : adaptor.getOperands()) { + operands.push_back(operand); + } + Value element = createScalarOp( + rewriter, loc, subgroupMmaElementwiseOp.getOpType(), operands); + rewriter.replaceOp(subgroupMmaElementwiseOp, element); + return success(); + } + Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, structDestTy); + for (size_t i = 0, e = structDestTy.getBody().size(); i < e; ++i) { SmallVector<Value> extractedOperands; for (size_t opIdx = 0; opIdx < numOperands; opIdx++) { extractedOperands.push_back(LLVM::ExtractValueOp::create( @@ -377,13 +404,18 @@ struct WmmaElementwiseOpToNVVMLowering } // namespace /// Return the LLVMStructureType corresponding to the MMAMatrixType `type`. -LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) { +Type mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) { NVVM::MMAFrag frag = convertOperand(type.getOperand()); NVVM::MMATypes eltType = getElementType(type); auto nRow = type.getShape()[0]; auto nCol = type.getShape()[1]; std::pair<Type, unsigned> typeInfo = NVVM::inferMMAType(eltType, frag, nRow, nCol, type.getContext()); + // Special handling for f64 a and b fragments + Type f64Ty = Float64Type::get(type.getContext()); + if (typeInfo.first == f64Ty && typeInfo.second == 1) { + return f64Ty; + } return LLVM::LLVMStructType::getLiteral( type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first)); } diff --git a/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp index bc2f2f2..d4b4c46 100644 --- a/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp +++ b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp @@ -107,16 +107,16 @@ struct ConvertIndexCeilDivSPattern final : OpConversionPattern<CeilDivSOp> { ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value n = adaptor.getLhs(); - Type n_type = n.getType(); + Type nType = n.getType(); Value m = adaptor.getRhs(); // Define the constants - Value zero = spirv::ConstantOp::create(rewriter, loc, n_type, - IntegerAttr::get(n_type, 0)); - Value posOne = spirv::ConstantOp::create(rewriter, loc, n_type, - IntegerAttr::get(n_type, 1)); - Value negOne = spirv::ConstantOp::create(rewriter, loc, n_type, - IntegerAttr::get(n_type, -1)); + Value zero = spirv::ConstantOp::create(rewriter, loc, nType, + IntegerAttr::get(nType, 0)); + Value posOne = spirv::ConstantOp::create(rewriter, loc, nType, + IntegerAttr::get(nType, 1)); + Value negOne = spirv::ConstantOp::create(rewriter, loc, nType, + IntegerAttr::get(nType, -1)); // Compute `x`. Value mPos = spirv::SGreaterThanOp::create(rewriter, loc, m, zero); @@ -157,14 +157,14 @@ struct ConvertIndexCeilDivUPattern final : OpConversionPattern<CeilDivUOp> { ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value n = adaptor.getLhs(); - Type n_type = n.getType(); + Type nType = n.getType(); Value m = adaptor.getRhs(); // Define the constants - Value zero = spirv::ConstantOp::create(rewriter, loc, n_type, - IntegerAttr::get(n_type, 0)); - Value one = spirv::ConstantOp::create(rewriter, loc, n_type, - IntegerAttr::get(n_type, 1)); + Value zero = spirv::ConstantOp::create(rewriter, loc, nType, + IntegerAttr::get(nType, 0)); + Value one = spirv::ConstantOp::create(rewriter, loc, nType, + IntegerAttr::get(nType, 1)); // Compute the non-zero result. 
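For reference, the unsigned ceiling-division sequence being built here follows the usual overflow-safe identity (an assumption about the unchanged code below, which is only partially shown); a standalone sketch with a worked value, independent of the SPIR-V ops:

// ceil(n / m) for unsigned n and m != 0, without risking overflow in n + m - 1:
//   n == 0 ? 0 : (n - 1) / m + 1
static unsigned ceilDivU(unsigned n, unsigned m) {
  return n == 0 ? 0 : (n - 1) / m + 1; // e.g. ceilDivU(7, 3) == 2 + 1 == 3
}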
Value minusOne = spirv::ISubOp::create(rewriter, loc, n, one); @@ -193,16 +193,16 @@ struct ConvertIndexFloorDivSPattern final : OpConversionPattern<FloorDivSOp> { ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value n = adaptor.getLhs(); - Type n_type = n.getType(); + Type nType = n.getType(); Value m = adaptor.getRhs(); // Define the constants - Value zero = spirv::ConstantOp::create(rewriter, loc, n_type, - IntegerAttr::get(n_type, 0)); - Value posOne = spirv::ConstantOp::create(rewriter, loc, n_type, - IntegerAttr::get(n_type, 1)); - Value negOne = spirv::ConstantOp::create(rewriter, loc, n_type, - IntegerAttr::get(n_type, -1)); + Value zero = spirv::ConstantOp::create(rewriter, loc, nType, + IntegerAttr::get(nType, 0)); + Value posOne = spirv::ConstantOp::create(rewriter, loc, nType, + IntegerAttr::get(nType, 1)); + Value negOne = spirv::ConstantOp::create(rewriter, loc, nType, + IntegerAttr::get(nType, -1)); // Compute `x`. Value mNeg = spirv::SLessThanOp::create(rewriter, loc, m, zero); diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index 48a0319..f28a6cc 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -296,19 +296,13 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( // Detail methods //===----------------------------------------------------------------------===// -void LLVM::detail::setNativeProperties(Operation *op, - IntegerOverflowFlags overflowFlags) { - if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op)) - iface.setOverflowFlags(overflowFlags); -} - /// Replaces the given operation "op" with a new operation of type "targetOp" /// and given operands. LogicalResult LLVM::detail::oneToOneRewrite( Operation *op, StringRef targetOp, ValueRange operands, - ArrayRef<NamedAttribute> targetAttrs, - const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, - IntegerOverflowFlags overflowFlags) { + ArrayRef<NamedAttribute> targetAttrs, Attribute propertiesAttr, + const LLVMTypeConverter &typeConverter, + ConversionPatternRewriter &rewriter) { unsigned numResults = op->getNumResults(); SmallVector<Type> resultTypes; @@ -320,11 +314,10 @@ LogicalResult LLVM::detail::oneToOneRewrite( } // Create the operation through state since we don't know its C++ type. - Operation *newOp = - rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands, - resultTypes, targetAttrs); - - setNativeProperties(newOp, overflowFlags); + OperationState state(op->getLoc(), rewriter.getStringAttr(targetOp), operands, + resultTypes, targetAttrs); + state.propertiesAttr = propertiesAttr; + Operation *newOp = rewriter.create(state); // If the operation produced 0 or 1 result, return them immediately. 
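Building the target op through an OperationState lets the rewrite attach native properties (for example overflow flags) generically via propertiesAttr, replacing the removed setNativeProperties hook. A small sketch of the idiom, creating an op by name when its C++ type is unknown (names are illustrative):

// propertiesAttr is turned into the op's native property storage when the
// operation is materialized.
static Operation *createWithProps(OpBuilder &b, Location loc, StringRef name,
                                  ValueRange operands, TypeRange results,
                                  Attribute propertiesAttr) {
  OperationState state(loc, name, operands, results, /*attributes=*/{});
  state.propertiesAttr = propertiesAttr;
  return b.create(state);
}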
if (numResults == 0) diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp index e7dd0b5..e5969c2 100644 --- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp @@ -105,9 +105,9 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors( LogicalResult LLVM::detail::vectorOneToOneRewrite( Operation *op, StringRef targetOp, ValueRange operands, - ArrayRef<NamedAttribute> targetAttrs, - const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, - IntegerOverflowFlags overflowFlags) { + ArrayRef<NamedAttribute> targetAttrs, Attribute propertiesAttr, + const LLVMTypeConverter &typeConverter, + ConversionPatternRewriter &rewriter) { assert(!operands.empty()); // Cannot convert ops if their operands are not of LLVM type. @@ -116,18 +116,38 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite( auto llvmNDVectorTy = operands[0].getType(); if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) - return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter, - rewriter, overflowFlags); - - auto callback = [op, targetOp, targetAttrs, overflowFlags, + return oneToOneRewrite(op, targetOp, operands, targetAttrs, propertiesAttr, + typeConverter, rewriter); + auto callback = [op, targetOp, targetAttrs, propertiesAttr, &rewriter](Type llvm1DVectorTy, ValueRange operands) { - Operation *newOp = - rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), - operands, llvm1DVectorTy, targetAttrs); - LLVM::detail::setNativeProperties(newOp, overflowFlags); + OperationState state(op->getLoc(), rewriter.getStringAttr(targetOp), + operands, llvm1DVectorTy, targetAttrs); + state.propertiesAttr = propertiesAttr; + Operation *newOp = rewriter.create(state); return newOp->getResult(0); }; return handleMultidimensionalVectors(op, operands, typeConverter, callback, rewriter); } + +/// Return the given type if it's a floating point type. If the given type is +/// a vector type, return its element type if it's a floating point type. +static FloatType getFloatingPointType(Type type) { + if (auto floatType = dyn_cast<FloatType>(type)) + return floatType; + if (auto vecType = dyn_cast<VectorType>(type)) + return dyn_cast<FloatType>(vecType.getElementType()); + return nullptr; +} + +bool LLVM::detail::isUnsupportedFloatingPointType( + const TypeConverter &typeConverter, Type type) { + FloatType floatType = getFloatingPointType(type); + if (!floatType) + return false; + Type convertedType = typeConverter.convertType(floatType); + if (!convertedType) + return true; + return !isa<FloatType>(convertedType); +} diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp index 16ef11a..59a16df 100644 --- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp +++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp @@ -93,13 +93,13 @@ public: /// Different MPI implementations have different communicator types. /// Using i64 as a portable, intermediate type. /// Appropriate cast needs to take place before calling MPI functions. - virtual Value getCommWorld(const Location loc, + virtual Value getCommWorld(Location loc, ConversionPatternRewriter &rewriter) = 0; /// Type converter provides i64 type for communicator type. /// Converts to native type, which might be ptr or int or whatever. 
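The isUnsupportedFloatingPointType helper added above in VectorPattern.cpp is what the new FailOnUnsupportedFP flag keys off: if the converter maps a float type (or a vector's element type) to something that is not a FloatType, the direct LLVM lowering declines the op so the ArithToAPFloat patterns can take it. A rough sketch of the guard inside a pattern's matchAndRewrite, under the assumption that the LLVM type converter lowers 8-bit float types such as f8E4M3FN to i8:

// f16      -> f16 : still a FloatType, lower to the LLVM float op as before.
// f8E4M3FN -> i8  : not a FloatType, bail out and leave the arith op in place.
if (LLVM::detail::isUnsupportedFloatingPointType(*getTypeConverter(),
                                                 op.getLhs().getType()))
  return rewriter.notifyMatchFailure(op, "unsupported floating point type");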
- virtual Value castComm(const Location loc, - ConversionPatternRewriter &rewriter, Value comm) = 0; + virtual Value castComm(Location loc, ConversionPatternRewriter &rewriter, + Value comm) = 0; /// Get the MPI_STATUS_IGNORE value (typically a pointer type). virtual intptr_t getStatusIgnore() = 0; @@ -109,13 +109,12 @@ public: /// Gets or creates an MPI datatype as a value which corresponds to the given /// type. - virtual Value getDataType(const Location loc, - ConversionPatternRewriter &rewriter, Type type) = 0; + virtual Value getDataType(Location loc, ConversionPatternRewriter &rewriter, + Type type) = 0; /// Gets or creates an MPI_Op value which corresponds to the given /// enum value. - virtual Value getMPIOp(const Location loc, - ConversionPatternRewriter &rewriter, + virtual Value getMPIOp(Location loc, ConversionPatternRewriter &rewriter, mpi::MPI_ReductionOpEnum opAttr) = 0; }; diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp index a2dfc12..a922338 100644 --- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp +++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp @@ -68,7 +68,7 @@ struct ClampFOpConversion final return LLVM::detail::handleMultidimensionalVectors( op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) -> Value { - typename math::ClampFOp::Adaptor adaptor(operands); + math::ClampFOp::Adaptor adaptor(operands); return ROCDL::FMed3Op::create(rewriter, op.getLoc(), llvm1DVectorTy, adaptor.getValue(), adaptor.getMin(), adaptor.getMax()); diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index 11f866c..0a382d8 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -122,7 +122,7 @@ static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType, return totalSizeBytes.getResult(); } -static emitc::ApplyOp +static emitc::AddressOfOp createPointerFromEmitcArray(Location loc, OpBuilder &builder, TypedValue<emitc::ArrayType> arrayValue) { @@ -133,9 +133,9 @@ createPointerFromEmitcArray(Location loc, OpBuilder &builder, llvm::SmallVector<mlir::Value> indices(arrayType.getRank(), zeroIndex); emitc::SubscriptOp subPtr = emitc::SubscriptOp::create(builder, loc, arrayValue, ValueRange(indices)); - emitc::ApplyOp ptr = emitc::ApplyOp::create( + emitc::AddressOfOp ptr = emitc::AddressOfOp::create( builder, loc, emitc::PointerType::get(arrayType.getElementType()), - builder.getStringAttr("&"), subPtr); + subPtr); return ptr; } @@ -225,12 +225,12 @@ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> { auto srcArrayValue = cast<TypedValue<emitc::ArrayType>>(operands.getSource()); - emitc::ApplyOp srcPtr = + emitc::AddressOfOp srcPtr = createPointerFromEmitcArray(loc, rewriter, srcArrayValue); auto targetArrayValue = cast<TypedValue<emitc::ArrayType>>(operands.getTarget()); - emitc::ApplyOp targetPtr = + emitc::AddressOfOp targetPtr = createPointerFromEmitcArray(loc, rewriter, targetArrayValue); emitc::CallOpaqueOp memCpyCall = emitc::CallOpaqueOp::create( @@ -319,8 +319,8 @@ struct ConvertGetGlobal final emitc::GetGlobalOp globalLValue = emitc::GetGlobalOp::create( rewriter, op.getLoc(), lvalueType, operands.getNameAttr()); emitc::PointerType pointerType = emitc::PointerType::get(resultTy); - rewriter.replaceOpWithNewOp<emitc::ApplyOp>( - op, pointerType, rewriter.getStringAttr("&"), globalLValue); + 
rewriter.replaceOpWithNewOp<emitc::AddressOfOp>(op, pointerType, + globalLValue); return success(); } rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy, diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index ec182f1..64a7f56 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -865,13 +865,7 @@ struct NVGPUMBarrierArriveLowering adaptor.getMbarId(), rewriter); Type tokenType = getTypeConverter()->convertType( nvgpu::MBarrierTokenType::get(op->getContext())); - if (isMbarrierShared(op.getBarriers().getType())) { - rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveSharedOp>(op, tokenType, - barrier); - } else { - rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType, - barrier); - } + rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType, barrier); return success(); } }; @@ -892,13 +886,8 @@ struct NVGPUMBarrierArriveNoCompleteLowering Type tokenType = getTypeConverter()->convertType( nvgpu::MBarrierTokenType::get(op->getContext())); Value count = truncToI32(b, adaptor.getCount()); - if (isMbarrierShared(op.getBarriers().getType())) { - rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>( - op, tokenType, barrier, count); - } else { - rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>( - op, tokenType, barrier, count); - } + rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>( + op, tokenType, barrier, count); return success(); } }; @@ -915,13 +904,8 @@ struct NVGPUMBarrierTestWaitLowering getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), adaptor.getMbarId(), rewriter); Type retType = rewriter.getI1Type(); - if (isMbarrierShared(op.getBarriers().getType())) { - rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitSharedOp>( - op, retType, barrier, adaptor.getToken()); - } else { - rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>( - op, retType, barrier, adaptor.getToken()); - } + rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(op, retType, barrier, + adaptor.getToken()); return success(); } }; @@ -938,15 +922,12 @@ struct NVGPUMBarrierArriveExpectTxLowering getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), adaptor.getMbarId(), rewriter); Value txcount = truncToI32(b, adaptor.getTxcount()); - - if (isMbarrierShared(op.getBarriers().getType())) { - rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>( - op, barrier, txcount, adaptor.getPredicate()); - return success(); - } - rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>( - op, barrier, txcount, adaptor.getPredicate()); + op, Type{}, // return-value is optional and is void by default + barrier, txcount, // barrier and txcount + NVVM::MemScopeKind::CTA, // default scope is CTA + false, // relaxed-semantics is false + adaptor.getPredicate()); return success(); } }; @@ -965,13 +946,6 @@ struct NVGPUMBarrierTryWaitParityLowering Value ticks = truncToI32(b, adaptor.getTicks()); Value phase = LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity()); - - if (isMbarrierShared(op.getBarriers().getType())) { - rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>( - op, barrier, phase, ticks); - return success(); - } - rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier, phase, ticks); return success(); diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp index 021e31a..7fdc23a 100644 --- 
a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -66,6 +66,9 @@ struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> { for (NamedAttribute attr : op->getAttrs()) { if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) { Type convertedType = converter->convertType(typeAttr.getValue()); + if (!convertedType) + return rewriter.notifyMatchFailure( + op, "failed to convert type in attribute"); convertedAttrs.emplace_back(attr.getName(), TypeAttr::get(convertedType)); } else { diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp index 37cfc9f..03842cc 100644 --- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp +++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp @@ -36,6 +36,7 @@ namespace { struct SCFToControlFlowPass : public impl::SCFToControlFlowPassBase<SCFToControlFlowPass> { + using Base::Base; void runOnOperation() override; }; @@ -736,7 +737,9 @@ void SCFToControlFlowPass::runOnOperation() { target.addIllegalOp<scf::ForallOp, scf::ForOp, scf::IfOp, scf::IndexSwitchOp, scf::ParallelOp, scf::WhileOp, scf::ExecuteRegionOp>(); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); - if (failed( - applyPartialConversion(getOperation(), target, std::move(patterns)))) + ConversionConfig config; + config.allowPatternRollback = allowPatternRollback; + if (failed(applyPartialConversion(getOperation(), target, std::move(patterns), + config))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp index 76a822b..309121f 100644 --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp @@ -453,10 +453,24 @@ static LogicalResult processParallelLoop( 1, 2, rewriter.getAffineDimExpr(0) * rewriter.getAffineSymbolExpr(0) + rewriter.getAffineSymbolExpr(1)); + // Map through cloningMap first so we use values valid at the launch + // scope, then ensure they are launch-independent (or cloned constants). + Value mappedStep = cloningMap.lookupOrDefault(step); + Value mappedLowerBound = cloningMap.lookupOrDefault(lowerBound); + + mappedStep = ensureLaunchIndependent(mappedStep); + mappedLowerBound = ensureLaunchIndependent(mappedLowerBound); + + // If either cannot be made available above the launch, fail gracefully. + if (!mappedStep || !mappedLowerBound) { + return rewriter.notifyMatchFailure( + parallelOp, "lower bound / step must be constant or defined above " + "the gpu.launch"); + } + newIndex = AffineApplyOp::create( rewriter, loc, annotation.getMap().compose(lowerAndStep), - ValueRange{operand, ensureLaunchIndependent(step), - ensureLaunchIndependent(lowerBound)}); + ValueRange{operand, mappedStep, mappedLowerBound}); // If there was also a bound, insert that, too. // TODO: Check that we do not assign bounds twice. 
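The composed lowerAndStep map evaluates to processor_id * step + lower_bound, and that expression is emitted inside the launch body, which is why both values must be constants or defined above the gpu.launch. A trivial host-side model of the mapping (illustrative only):

// (d0)[s0, s1] -> (d0 * s0 + s1) with d0 = hardware id, s0 = step, s1 = lb.
static int64_t mappedInductionVar(int64_t hwId, int64_t step, int64_t lb) {
  return hwId * step + lb; // e.g. id 5, step 2, lb 8 -> induction variable 18
}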
if (annotation.getBound()) { diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index 460595b..6423d49 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -188,7 +188,8 @@ createDecl(PatternRewriter &builder, SymbolTable &symbolTable, OpBuilder::InsertionGuard guard(builder); Type type = reduce.getOperands()[reductionIndex].getType(); auto decl = omp::DeclareReductionOp::create(builder, reduce.getLoc(), - "__scf_reduction", type); + "__scf_reduction", type, + /*byref_element_type=*/{}); symbolTable.insert(decl); builder.createBlock(&decl.getInitializerRegion(), diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index 50fca56..02b61bd 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -1520,20 +1520,12 @@ public: if (!dstType) return rewriter.notifyMatchFailure(tanOp, "type conversion failed"); - Location loc = tanOp.getLoc(); - Value sin = LLVM::SinOp::create(rewriter, loc, dstType, tanOp.getOperand()); - Value cos = LLVM::CosOp::create(rewriter, loc, dstType, tanOp.getOperand()); - rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos); + rewriter.replaceOpWithNewOp<LLVM::TanOp>(tanOp, dstType, + adaptor.getOperands()); return success(); } }; -/// Convert `spirv.Tanh` to -/// -/// exp(2x) - 1 -/// ----------- -/// exp(2x) + 1 -/// class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> { public: using SPIRVToLLVMConversion<spirv::GLTanhOp>::SPIRVToLLVMConversion; @@ -1546,18 +1538,8 @@ public: if (!dstType) return rewriter.notifyMatchFailure(tanhOp, "type conversion failed"); - Location loc = tanhOp.getLoc(); - Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0); - Value multiplied = - LLVM::FMulOp::create(rewriter, loc, dstType, two, tanhOp.getOperand()); - Value exponential = LLVM::ExpOp::create(rewriter, loc, dstType, multiplied); - Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); - Value numerator = - LLVM::FSubOp::create(rewriter, loc, dstType, exponential, one); - Value denominator = - LLVM::FAddOp::create(rewriter, loc, dstType, exponential, one); - rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator, - denominator); + rewriter.replaceOpWithNewOp<LLVM::TanhOp>(tanhOp, dstType, + adaptor.getOperands()); return success(); } }; diff --git a/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp b/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp index 9921a06..feb0489 100644 --- a/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp +++ b/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp @@ -23,8 +23,11 @@ namespace mlir { using namespace mlir; -namespace { +//===----------------------------------------------------------------------===// +// PoisonOpLowering +//===----------------------------------------------------------------------===// +namespace { struct PoisonOpLowering : public ConvertOpToLLVMPattern<ub::PoisonOp> { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -32,13 +35,8 @@ struct PoisonOpLowering : public ConvertOpToLLVMPattern<ub::PoisonOp> { matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; - } // namespace -//===----------------------------------------------------------------------===// -// PoisonOpLowering -//===----------------------------------------------------------------------===// - LogicalResult 
PoisonOpLowering::matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -61,6 +59,29 @@ PoisonOpLowering::matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor, } //===----------------------------------------------------------------------===// +// UnreachableOpLowering +//===----------------------------------------------------------------------===// + +namespace { +struct UnreachableOpLowering + : public ConvertOpToLLVMPattern<ub::UnreachableOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(ub::UnreachableOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; +} // namespace +LogicalResult + +UnreachableOpLowering::matchAndRewrite( + ub::UnreachableOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + rewriter.replaceOpWithNewOp<LLVM::UnreachableOp>(op); + return success(); +} + +//===----------------------------------------------------------------------===// // Pass Definition //===----------------------------------------------------------------------===// @@ -93,7 +114,7 @@ struct UBToLLVMConversionPass void mlir::ub::populateUBToLLVMConversionPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns) { - patterns.add<PoisonOpLowering>(converter); + patterns.add<PoisonOpLowering, UnreachableOpLowering>(converter); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp index 244d214..3831387 100644 --- a/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp +++ b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp @@ -40,6 +40,17 @@ struct PoisonOpLowering final : OpConversionPattern<ub::PoisonOp> { } }; +struct UnreachableOpLowering final : OpConversionPattern<ub::UnreachableOp> { + using Base::Base; + + LogicalResult + matchAndRewrite(ub::UnreachableOp op, OpAdaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<spirv::UnreachableOp>(op); + return success(); + } +}; + } // namespace //===----------------------------------------------------------------------===// @@ -75,5 +86,6 @@ struct UBToSPIRVConversionPass final void mlir::ub::populateUBToSPIRVConversionPatterns( const SPIRVTypeConverter &converter, RewritePatternSet &patterns) { - patterns.add<PoisonOpLowering>(converter, patterns.getContext()); + patterns.add<PoisonOpLowering, UnreachableOpLowering>(converter, + patterns.getContext()); } diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 69a317ec..05d541f 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -345,7 +345,8 @@ public: matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = scatter->getLoc(); - MemRefType memRefType = scatter.getMemRefType(); + auto memRefType = dyn_cast<MemRefType>(scatter.getBaseType()); + assert(memRefType && "The base should be bufferized"); if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter()))) return rewriter.notifyMatchFailure(scatter, "memref type not supported"); @@ -1654,6 +1655,20 @@ private: return failure(); } } + } else if (auto floatTy = dyn_cast<FloatType>(printType)) { + // Print other floating-point types using the APFloat runtime library. 
+ int32_t sem = + llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics()); + Value semValue = LLVM::ConstantOp::create( + rewriter, loc, rewriter.getI32Type(), + rewriter.getIntegerAttr(rewriter.getI32Type(), sem)); + Value floatBits = + LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(), value); + printer = + LLVM::lookupOrCreateApFloatPrintFn(rewriter, parent, symbolTables); + emitCall(rewriter, loc, printer.value(), + ValueRange({semValue, floatBits})); + return success(); } else { return failure(); } diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp index 91c1aa5..079e1e2 100644 --- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp +++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp @@ -97,57 +97,23 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter, return success(); } -static xegpu::CreateNdDescOp -createNdDescriptor(PatternRewriter &rewriter, Location loc, - xegpu::TensorDescType descType, TypedValue<MemRefType> src, - Operation::operand_range offsets) { +static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter, + Location loc, + xegpu::TensorDescType descType, + TypedValue<MemRefType> src) { MemRefType srcTy = src.getType(); auto [strides, offset] = srcTy.getStridesAndOffset(); xegpu::CreateNdDescOp ndDesc; if (srcTy.hasStaticShape()) { - ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src, - getAsOpFoldResult(offsets)); + ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src); } else { // In case of any dynamic shapes, source's shape and strides have to be // explicitly provided. - SmallVector<Value> sourceDims; - unsigned srcRank = srcTy.getRank(); - for (unsigned i = 0; i < srcRank; ++i) - sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i)); - - SmallVector<int64_t> constOffsets; - SmallVector<Value> dynOffsets; - for (Value offset : offsets) { - std::optional<int64_t> staticVal = getConstantIntValue(offset); - if (!staticVal) - dynOffsets.push_back(offset); - constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic)); - } - - SmallVector<Value> dynShapes; - for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) { - if (shape == ShapedType::kDynamic) - dynShapes.push_back(sourceDims[idx]); - } - - // Compute strides in reverse order. - SmallVector<Value> dynStrides; - Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1); - // Last stride is guaranteed to be static and unit. 
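The hand-rolled size/stride reconstruction that is deleted here (memref.dim per dimension plus a running stride product) is what memref.extract_strided_metadata provides directly; its "constified" accessors fold statically known entries back into attributes. A small sketch of the query, assuming a ranked memref value src and the usual MemRef dialect headers:

static void getMixedSizesAndStrides(OpBuilder &rewriter, Location loc,
                                    Value src,
                                    SmallVectorImpl<OpFoldResult> &sizes,
                                    SmallVectorImpl<OpFoldResult> &strides) {
  auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
  // Static entries come back as IntegerAttrs, dynamic ones as SSA values.
  llvm::append_range(sizes, meta.getConstifiedMixedSizes());
  llvm::append_range(strides, meta.getConstifiedMixedStrides());
}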
- for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) { - accStride = - arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]); - if (strides[i] == ShapedType::kDynamic) - dynStrides.push_back(accStride); - } - std::reverse(dynStrides.begin(), dynStrides.end()); - - ndDesc = xegpu::CreateNdDescOp::create( - rewriter, loc, descType, src, dynOffsets, dynShapes, dynStrides, - DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets), - DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()), - DenseI64ArrayAttr::get(rewriter.getContext(), strides)); + auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src); + ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src, + meta.getConstifiedMixedSizes(), + meta.getConstifiedMixedStrides()); } return ndDesc; @@ -392,6 +358,62 @@ static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp, .getResult(); } +// Collapses shapes of a nD memref to the target rank while applying offsets for +// the collapsed dimensions. Returns the new memref value and the remaining +// offsets for the last targetRank dimensions. For example: +// input: %memref = memref<2x4x8x32xf32>, offsets=[%i0, %i1, %i2, %i3], +// output: %memref[%i0, %i1, 0, 0] -> memref<8x32xf32>, offsets: [%i2, %i3] +static std::pair<Value, SmallVector<OpFoldResult>> +convertMemrefAndOffsetsToTargetRank(PatternRewriter &rewriter, Location loc, + Value memref, + SmallVector<OpFoldResult> offsets, + int64_t targetRank) { + auto memrefType = cast<MemRefType>(memref.getType()); + unsigned rank = memrefType.getRank(); + + if (rank <= targetRank) + return {memref, offsets}; + + int64_t numCombinedDims = rank - targetRank; + SmallVector<OpFoldResult> subviewOffsets; + SmallVector<OpFoldResult> subviewSizes; + SmallVector<OpFoldResult> subviewStrides; + + // For the combined dimensions: use the provided offsets, size=1, stride=1 + for (unsigned i = 0; i < numCombinedDims; ++i) { + subviewOffsets.push_back(offsets[i]); + subviewSizes.push_back(rewriter.getI64IntegerAttr(1)); + subviewStrides.push_back(rewriter.getI64IntegerAttr(1)); + } + + // For the last targetRank dimensions: offset=0, use full size, stride=1 + SmallVector<int64_t> resultShape; + auto originalShape = memrefType.getShape(); + auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, memref); + for (unsigned i = numCombinedDims; i < rank; ++i) { + subviewOffsets.push_back(rewriter.getI64IntegerAttr(0)); + if (ShapedType::isDynamic(originalShape[i])) { + subviewSizes.push_back(meta.getSizes()[i]); + resultShape.push_back(ShapedType::kDynamic); + } else { + subviewSizes.push_back(rewriter.getI64IntegerAttr(originalShape[i])); + resultShape.push_back(originalShape[i]); + } + subviewStrides.push_back(rewriter.getI64IntegerAttr(1)); + } + + auto resultType = memref::SubViewOp::inferRankReducedResultType( + resultShape, memrefType, subviewOffsets, subviewSizes, subviewStrides); + auto subviewOp = + memref::SubViewOp::create(rewriter, loc, resultType, memref, + subviewOffsets, subviewSizes, subviewStrides); + + // Return the remaining offsets for the last targetRank dimensions + SmallVector<OpFoldResult> newOffsets(offsets.begin() + numCombinedDims, + offsets.end()); + return {subviewOp.getResult(), newOffsets}; +} + template < typename OpType, typename = std::enable_if_t<llvm::is_one_of< @@ -435,7 +457,8 @@ static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp, /*chunk_size=*/IntegerAttr{}, /*l1_hint=*/xegpu::CachePolicyAttr{}, 
/*l2_hint=*/xegpu::CachePolicyAttr{}, - /*l3_hint=*/xegpu::CachePolicyAttr{}); + /*l3_hint=*/xegpu::CachePolicyAttr{}, + /*layout=*/nullptr); rewriter.replaceOp(readOp, gatherOp.getResult()); return success(); @@ -469,7 +492,8 @@ static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp, /*chunk_size=*/IntegerAttr{}, /*l1_hint=*/xegpu::CachePolicyAttr{}, /*l2_hint=*/xegpu::CachePolicyAttr{}, - /*l3_hint=*/xegpu::CachePolicyAttr{}); + /*l3_hint=*/xegpu::CachePolicyAttr{}, + /*layout=*/nullptr); rewriter.eraseOp(writeOp); return success(); } @@ -495,8 +519,13 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> { return lowerToScatteredLoadOp(readOp, rewriter); } - // Perform common data transfer checks. VectorType vecTy = readOp.getVectorType(); + + // Lower using load.gather in 1D case + if (vecTy.getRank() == 1 && !readOp.hasOutOfBoundsDim()) + return lowerToScatteredLoadOp(readOp, rewriter); + + // Perform common data transfer checks. if (failed(storeLoadPreconditions(rewriter, readOp, vecTy))) return failure(); @@ -523,21 +552,23 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> { descShape, elementType, /*array_length=*/1, /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global); - xegpu::CreateNdDescOp ndDesc = - createNdDescriptor(rewriter, loc, descType, - dyn_cast<TypedValue<MemRefType>>(readOp.getBase()), - readOp.getIndices()); - DenseI64ArrayAttr transposeAttr = !isTransposeLoad ? nullptr : DenseI64ArrayAttr::get(rewriter.getContext(), ArrayRef<int64_t>{1, 0}); + auto [src, indices] = convertMemrefAndOffsetsToTargetRank( + rewriter, loc, readOp.getBase(), getAsOpFoldResult(readOp.getIndices()), + vecTy.getRank()); // By default, no specific caching policy is assigned. xegpu::CachePolicyAttr hint = nullptr; - auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, + xegpu::CreateNdDescOp ndDesc = createNdDescriptor( + rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src)); + + auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices, /*packed=*/nullptr, transposeAttr, /*l1_hint=*/hint, - /*l2_hint=*/hint, /*l3_hint=*/hint); + /*l2_hint=*/hint, /*l3_hint=*/hint, + /*layout=*/nullptr); rewriter.replaceOp(readOp, loadOp); return success(); @@ -575,21 +606,24 @@ struct TransferWriteLowering if (!map.isMinorIdentity()) return rewriter.notifyMatchFailure(writeOp, "Expects identity map"); + auto [src, indices] = convertMemrefAndOffsetsToTargetRank( + rewriter, loc, writeOp.getBase(), + getAsOpFoldResult(writeOp.getIndices()), vecTy.getRank()); + auto descType = xegpu::TensorDescType::get( vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(), xegpu::MemorySpace::Global); - xegpu::CreateNdDescOp ndDesc = - createNdDescriptor(rewriter, loc, descType, - dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()), - writeOp.getIndices()); - // By default, no specific caching policy is assigned. 
xegpu::CachePolicyAttr hint = nullptr; - auto storeOp = - xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc, - /*l1_hint=*/hint, - /*l2_hint=*/hint, /*l3_hint=*/hint); + xegpu::CreateNdDescOp ndDesc = createNdDescriptor( + rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src)); + + auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), + ndDesc, indices, + /*l1_hint=*/hint, + /*l2_hint=*/hint, /*l3_hint=*/hint, + /*layout=*/nullptr); rewriter.replaceOp(writeOp, storeOp); return success(); @@ -621,7 +655,8 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> { /*chunk_size=*/IntegerAttr{}, /*l1_hint=*/xegpu::CachePolicyAttr{}, /*l2_hint=*/xegpu::CachePolicyAttr{}, - /*l3_hint=*/xegpu::CachePolicyAttr{}); + /*l3_hint=*/xegpu::CachePolicyAttr{}, + /*layout=*/nullptr); auto selectOp = arith::SelectOp::create(rewriter, loc, gatherOp.getMask(), @@ -655,7 +690,8 @@ struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> { /*chunk_size=*/IntegerAttr{}, /*l1_hint=*/xegpu::CachePolicyAttr{}, /*l2_hint=*/xegpu::CachePolicyAttr{}, - /*l3_hint=*/xegpu::CachePolicyAttr{}); + /*l3_hint=*/xegpu::CachePolicyAttr{}, + /*layout=*/nullptr); rewriter.eraseOp(scatterOp); return success(); } @@ -674,19 +710,25 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> { // Boundary check is available only for block instructions. bool boundaryCheck = vecTy.getRank() > 1; + // By default, no specific caching policy is assigned. + xegpu::CachePolicyAttr hint = nullptr; + + auto [src, indices] = convertMemrefAndOffsetsToTargetRank( + rewriter, loc, loadOp.getBase(), getAsOpFoldResult(loadOp.getIndices()), + vecTy.getRank()); auto descType = xegpu::TensorDescType::get( vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global); - xegpu::CreateNdDescOp ndDesc = createNdDescriptor( - rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices()); - // By default, no specific caching policy is assigned. - xegpu::CachePolicyAttr hint = nullptr; - auto loadNdOp = xegpu::LoadNdOp::create( - rewriter, loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr, - /*l1_hint=*/hint, - /*l2_hint=*/hint, /*l3_hint=*/hint); + xegpu::CreateNdDescOp ndDesc = createNdDescriptor( + rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src)); + auto loadNdOp = + xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices, + /*packed=*/nullptr, /*transpose=*/nullptr, + /*l1_hint=*/hint, + /*l2_hint=*/hint, /*l3_hint=*/hint, + /*layout=*/nullptr); rewriter.replaceOp(loadOp, loadNdOp); return success(); @@ -708,18 +750,25 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> { // Boundary check is available only for block instructions. bool boundaryCheck = vecTy.getRank() > 1; + auto [src, indices] = convertMemrefAndOffsetsToTargetRank( + rewriter, loc, storeOp.getBase(), + getAsOpFoldResult(storeOp.getIndices()), vecTy.getRank()); + auto descType = xegpu::TensorDescType::get( vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global); - xegpu::CreateNdDescOp ndDesc = createNdDescriptor( - rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices()); // By default, no specific caching policy is assigned. 
xegpu::CachePolicyAttr hint = nullptr; + xegpu::CreateNdDescOp ndDesc = createNdDescriptor( + rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src)); + + auto storeNdOp = - xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, + xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices, /*l1_hint=*/hint, - /*l2_hint=*/hint, /*l3_hint=*/hint); + /*l2_hint=*/hint, /*l3_hint=*/hint, + /*layout=*/nullptr); + rewriter.replaceOp(storeOp, storeNdOp); return success(); diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 33e8f2e..0ecb50e 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -50,11 +50,10 @@ static constexpr int32_t executionSize{16}; // Offsets to individual fields of the 8xi32 layout nd tensor descriptor. enum class NdTdescOffset : uint32_t { - BasePtr = 0, // Base pointer (i64) - BaseShapeW = 2, // Base shape width (i32) - BaseShapeH = 3, // Base shape height (i32) - TensorOffsetW = 4, // Tensor offset W (i32) - TensorOffsetH = 5 // Tensor offset H (i32) + BasePtr = 0, // Base pointer (i64) + BaseShapeW = 2, // Base shape width (i32) + BaseShapeH = 3, // Base shape height (i32) + BasePitch = 4, // Base pitch (i32) }; static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) { @@ -151,6 +150,14 @@ translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint, } } +// +// Note: +// Block operations on tiles of sub-byte element types are handled by +// emulating them with larger element types. +// Tensor descriptors are kept intact; only the ops consuming them are +// emulated. +// + class CreateNdDescToXeVMPattern : public OpConversionPattern<xegpu::CreateNdDescOp> { using OpConversionPattern::OpConversionPattern; @@ -179,16 +186,12 @@ class CreateNdDescToXeVMPattern Value baseAddr; Value baseShapeW; Value baseShapeH; - Value offsetW; - Value offsetH; // Source can be a memref or a pointer (ui64, ui32, i64 or i32). SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes(); + SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides(); // Descriptor shape is expected to be 2D. int64_t rank = mixedSizes.size(); - if (rank != 2) - return rewriter.notifyMatchFailure(op, "Expected 2D shape."); - auto sourceTy = source.getType(); auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy); // If source is a memref, we need to extract the aligned pointer as index. @@ -197,10 +200,20 @@ class CreateNdDescToXeVMPattern if (!sourceMemrefTy.hasRank()) { return rewriter.notifyMatchFailure(op, "Expected ranked Memref."); } - baseAddr = - memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source); + // Access adaptor after failure check to avoid rolling back generated code + // for materialization cast. + baseAddr = adaptor.getSource(); } else { baseAddr = adaptor.getSource(); + if (baseAddr.getType() != i64Ty) { + // Pointer type may be i32. Cast to i64 if needed. + baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr); + } + } + // 1D tensor descriptor is just the base address. + if (rank == 1) { + rewriter.replaceOp(op, baseAddr); + return success(); } // Utility for creating offset values from op fold result. auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec, uint32_t idx) -> Value { val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val); return val; }; - // Offsets are not supported (0 is used).
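// After this change the nd tensor descriptor payload built by
// CreateNdDescToXeVMPattern carries a row pitch instead of per-tensor offsets.
// A standalone sketch of the resulting 8xi32 layout (not part of the patch;
// assumes a little-endian target so the i64 base pointer occupies lanes 0-1):
#include <array>
#include <cstdint>
#include <cstring>

std::array<uint32_t, 8> packNdTdescPayload(uint64_t basePtr, uint32_t baseShapeW,
                                           uint32_t baseShapeH,
                                           uint32_t basePitch) {
  std::array<uint32_t, 8> payload{};
  std::memcpy(payload.data(), &basePtr, sizeof(basePtr)); // BasePtr, lanes 0-1
  payload[2] = baseShapeW;                                // BaseShapeW
  payload[3] = baseShapeH;                                // BaseShapeH
  payload[4] = basePitch;                                 // BasePitch
  return payload;
}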
- offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0); - offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0); // Get shape values from op fold results. baseShapeW = createOffset(mixedSizes, 1); baseShapeH = createOffset(mixedSizes, 0); - if (sourceMemrefTy) { - // Cast index to i64. - baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr); - } else if (baseAddr.getType() != i64Ty) { - // Pointer type may be i32. Cast to i64 if needed. - baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr); - } + // Get pitch value from op fold results. + Value basePitch = createOffset(mixedStrides, 0); // Populate payload. Value payLoadAsI64 = vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload); @@ -235,12 +240,9 @@ class CreateNdDescToXeVMPattern payload = vector::InsertOp::create(rewriter, loc, baseShapeH, payload, static_cast<int>(NdTdescOffset::BaseShapeH)); - payload = vector::InsertOp::create( - rewriter, loc, offsetW, payload, - static_cast<int>(NdTdescOffset::TensorOffsetW)); - payload = vector::InsertOp::create( - rewriter, loc, offsetH, payload, - static_cast<int>(NdTdescOffset::TensorOffsetH)); + payload = + vector::InsertOp::create(rewriter, loc, basePitch, payload, + static_cast<int>(NdTdescOffset::BasePitch)); rewriter.replaceOp(op, payload); return success(); } @@ -257,108 +259,240 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> { ConversionPatternRewriter &rewriter) const override { auto mixedOffsets = op.getMixedOffsets(); int64_t opOffsetsSize = mixedOffsets.size(); - if (opOffsetsSize != 2) - return rewriter.notifyMatchFailure(op, "Expected 2D offsets."); auto loc = op.getLoc(); auto ctxt = rewriter.getContext(); auto tdesc = adaptor.getTensorDesc(); auto tdescTy = op.getTensorDescType(); - if (tdescTy.getRank() != 2) - return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor."); + auto tileRank = tdescTy.getRank(); + if (opOffsetsSize != tileRank) + return rewriter.notifyMatchFailure( + op, "Expected offset rank to match descriptor rank."); auto elemType = tdescTy.getElementType(); auto elemBitSize = elemType.getIntOrFloatBitWidth(); - if (elemBitSize % 8 != 0) + bool isSubByte = elemBitSize < 8; + uint64_t wScaleFactor = 1; + + if (!isSubByte && (elemBitSize % 8 != 0)) return rewriter.notifyMatchFailure( op, "Expected element type bit width to be multiple of 8."); + auto tileW = tdescTy.getDimSize(tileRank - 1); + // For sub byte types, only 4bits are currently supported. + if (isSubByte) { + if (elemBitSize != 4) + return rewriter.notifyMatchFailure( + op, "Only sub byte types of 4bits are supported."); + if (tileRank != 2) + return rewriter.notifyMatchFailure( + op, "Sub byte types are only supported for 2D tensor descriptors."); + auto subByteFactor = 8 / elemBitSize; + auto tileH = tdescTy.getDimSize(0); + // Handle special case for packed load. + if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) { + if (op.getPacked().value_or(false)) { + // packed load is implemented as packed loads of 8bit elements. + if (tileH == systolicDepth * 4 && + tileW == executionSize * subByteFactor) { + // Usage case for loading as Matrix B with pack request. + // source is assumed to pre-packed into 8bit elements + // Emulate with 8bit loads with pack request. + // scaled_tileW = executionSize + elemType = rewriter.getIntegerType(8); + tileW = executionSize; + wScaleFactor = subByteFactor; + } + } + } + // If not handled by packed load case above, handle other cases. 
+ if (wScaleFactor == 1) { + auto sub16BitFactor = subByteFactor * 2; + if (tileW == executionSize * sub16BitFactor) { + // Usage case for loading as Matrix A operand + // Emulate with 16bit loads/stores. + // scaled_tileW = executionSize + elemType = rewriter.getIntegerType(16); + tileW = executionSize; + wScaleFactor = sub16BitFactor; + } else { + return rewriter.notifyMatchFailure( + op, "Unsupported tile shape for sub byte types."); + } + } + // recompute element bit size for emulation. + elemBitSize = elemType.getIntOrFloatBitWidth(); + } - VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type()); - Value payLoadAsI64 = - vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc); - Value basePtr = vector::ExtractOp::create( - rewriter, loc, payLoadAsI64, static_cast<int>(NdTdescOffset::BasePtr)); - Value baseShapeW = vector::ExtractOp::create( - rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW)); - Value baseShapeH = vector::ExtractOp::create( - rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH)); - // Offsets are provided by the op. - // convert them to i32. - Value offsetW = - getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]); - offsetW = getValueOrCreateCastToIndexLike(rewriter, loc, - rewriter.getI32Type(), offsetW); - Value offsetH = - getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]); - offsetH = getValueOrCreateCastToIndexLike(rewriter, loc, - rewriter.getI32Type(), offsetH); // Get address space from tensor descriptor memory space. auto ptrTypeLLVM = LLVM::LLVMPointerType::get( ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace())); - // Convert base pointer (i64) to LLVM pointer type. - Value basePtrLLVM = - LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr); - // Compute element byte size and surface width in bytes. - Value elemByteSize = arith::ConstantIntOp::create( - rewriter, loc, rewriter.getI32Type(), elemBitSize / 8); - Value surfaceW = - arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize); - - // Get tile sizes and vblocks from the tensor descriptor type. - auto tileW = tdescTy.getDimSize(1); - auto tileH = tdescTy.getDimSize(0); - int32_t vblocks = tdescTy.getArrayLength(); - if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) { - Value src = adaptor.getValue(); - // If store value is a scalar, get value from op instead of adaptor. - // Adaptor might have optimized away single element vector - if (src.getType().isIntOrFloat()) { - src = op.getValue(); + if (tileRank == 2) { + // Compute element byte size. + Value elemByteSize = arith::ConstantIntOp::create( + rewriter, loc, rewriter.getI32Type(), elemBitSize / 8); + VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type()); + Value payLoadAsI64 = + vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc); + Value basePtr = + vector::ExtractOp::create(rewriter, loc, payLoadAsI64, + static_cast<int>(NdTdescOffset::BasePtr)); + Value baseShapeW = vector::ExtractOp::create( + rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW)); + Value baseShapeH = vector::ExtractOp::create( + rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH)); + Value basePitch = vector::ExtractOp::create( + rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BasePitch)); + // Offsets are provided by the op. + // convert them to i32. 
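// The i4 emulation above picks the emulated element type purely from the tile
// shape. A standalone sketch of that selection (not part of the patch;
// executionSize is 16 as in this file, systolicDepth is assumed to be 8 here,
// and isPackedLoad stands in for the LoadNdOp packed-attribute check):
#include <cstdint>
#include <optional>

struct I4Emulation {
  unsigned emulatedBits; // 8 for the packed B-operand case, 16 otherwise
  int64_t scaledTileW;   // tile width in emulated elements
  uint64_t wScaleFactor; // i4 elements folded into one emulated element
};

std::optional<I4Emulation> selectI4Emulation(int64_t tileH, int64_t tileW,
                                             bool isPackedLoad,
                                             int64_t executionSize = 16,
                                             int64_t systolicDepth = 8) {
  const int64_t subByteFactor = 8 / 4; // i4 elements per byte
  // Packed B-operand load: the source is assumed to be pre-packed, so emulate
  // with 8-bit loads and keep the pack request.
  if (isPackedLoad && tileH == systolicDepth * 4 &&
      tileW == executionSize * subByteFactor)
    return I4Emulation{8, executionSize, uint64_t(subByteFactor)};
  // A-operand style tile: emulate with 16-bit loads/stores.
  const int64_t sub16BitFactor = subByteFactor * 2;
  if (tileW == executionSize * sub16BitFactor)
    return I4Emulation{16, executionSize, uint64_t(sub16BitFactor)};
  return std::nullopt; // unsupported tile shape for sub-byte types
}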
+ Value offsetW = + getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]); + offsetW = getValueOrCreateCastToIndexLike(rewriter, loc, + rewriter.getI32Type(), offsetW); + Value offsetH = + getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]); + offsetH = getValueOrCreateCastToIndexLike(rewriter, loc, + rewriter.getI32Type(), offsetH); + // Convert base pointer (i64) to LLVM pointer type. + Value basePtrLLVM = + LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr); + // FIXME: width or pitch is not the same as baseShapeW it should be the + // stride of the second to last dimension in row major layout. + // Compute width in bytes. + Value baseShapeWInBytes = + arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize); + // Compute pitch in bytes. + Value basePitchBytes = + arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize); + + if (wScaleFactor > 1) { + // Scale offsetW, baseShapeWInBytes for sub byte emulation. + // Note: tileW is already scaled above. + Value wScaleFactorValLog2 = arith::ConstantIntOp::create( + rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(wScaleFactor)); + baseShapeWInBytes = arith::ShRSIOp::create( + rewriter, loc, baseShapeWInBytes, wScaleFactorValLog2); + basePitchBytes = arith::ShRSIOp::create(rewriter, loc, basePitchBytes, + wScaleFactorValLog2); + offsetW = + arith::ShRSIOp::create(rewriter, loc, offsetW, wScaleFactorValLog2); } - VectorType srcVecTy = dyn_cast<VectorType>(src.getType()); - if (!srcVecTy) - return rewriter.notifyMatchFailure( - op, "Expected store value to be a vector type."); - // Get flat vector type of integer type with matching element bit size. - VectorType newSrcVecTy = - encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize)); - if (srcVecTy != newSrcVecTy) - src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src); - auto storeCacheControl = - translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()); - xevm::BlockStore2dOp::create( - rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW, - offsetH, elemBitSize, tileW, tileH, src, - xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl)); - rewriter.eraseOp(op); - } else { - auto loadCacheControl = - translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()); - if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) { - xevm::BlockPrefetch2dOp::create( - rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW, - offsetH, elemBitSize, tileW, tileH, vblocks, - xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl)); + // Get tile height from the tensor descriptor type. + auto tileH = tdescTy.getDimSize(0); + // Get vblocks from the tensor descriptor type. + int32_t vblocks = tdescTy.getArrayLength(); + if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) { + Value src = adaptor.getValue(); + // If store value is a scalar, get value from op instead of adaptor. + // Adaptor might have optimized away single element vector + if (src.getType().isIntOrFloat()) { + src = op.getValue(); + } + VectorType srcVecTy = dyn_cast<VectorType>(src.getType()); + if (!srcVecTy) + return rewriter.notifyMatchFailure( + op, "Expected store value to be a vector type."); + // Get flat vector type of integer type with matching element bit size. 
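// When wScaleFactor > 1, baseShapeW and basePitch still count original i4
// elements while elemByteSize is that of the wider emulated type, and offsetW
// is still an i4 element count, so all three are scaled back down before being
// fed to the 2D block ops. A standalone sketch of that correction (not part of
// the patch; the pass does the same thing with an arithmetic shift right by
// log2(wScaleFactor), which is always a power of two here):
#include <cstdint>

struct Block2dGeometry {
  int32_t widthBytes; // baseShapeW * emulated element byte size
  int32_t pitchBytes; // basePitch * emulated element byte size
  int32_t offsetW;    // W offset, counted in i4 elements
};

Block2dGeometry rescaleForI4Emulation(Block2dGeometry g, uint64_t wScaleFactor) {
  if (wScaleFactor <= 1)
    return g;
  g.widthBytes /= int32_t(wScaleFactor); // 2 for i4->i8, 4 for i4->i16
  g.pitchBytes /= int32_t(wScaleFactor);
  g.offsetW /= int32_t(wScaleFactor);    // now counted in emulated elements
  return g;
}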
+ VectorType newSrcVecTy = + encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize)); + if (srcVecTy != newSrcVecTy) + src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src); + auto storeCacheControl = + translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()); + xevm::BlockStore2dOp::create( + rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH, + basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, src, + xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl)); rewriter.eraseOp(op); } else { - VectorType dstVecTy = cast<VectorType>(op.getValue().getType()); - const bool vnni = op.getPacked().value_or(false); - auto transposeValue = op.getTranspose(); - bool transpose = - transposeValue.has_value() && transposeValue.value()[0] == 1; - VectorType loadedTy = encodeVectorTypeTo( - dstVecTy, vnni ? rewriter.getI32Type() - : rewriter.getIntegerType(elemBitSize)); - - Value resultFlatVec = xevm::BlockLoad2dOp::create( - rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH, - surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks, - transpose, vnni, + auto loadCacheControl = + translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()); + if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) { + xevm::BlockPrefetch2dOp::create( + rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH, + basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, + vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl)); + rewriter.eraseOp(op); + } else { + VectorType dstVecTy = cast<VectorType>(op.getValue().getType()); + const bool vnni = op.getPacked().value_or(false); + auto transposeValue = op.getTranspose(); + bool transpose = + transposeValue.has_value() && transposeValue.value()[0] == 1; + VectorType loadedTy = encodeVectorTypeTo( + dstVecTy, vnni ? rewriter.getI32Type() + : rewriter.getIntegerType(elemBitSize)); + + Value resultFlatVec = xevm::BlockLoad2dOp::create( + rewriter, loc, loadedTy, basePtrLLVM, baseShapeWInBytes, + baseShapeH, basePitchBytes, offsetW, offsetH, elemBitSize, tileW, + tileH, vblocks, transpose, vnni, + xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl)); + resultFlatVec = vector::BitCastOp::create( + rewriter, loc, + encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()), + resultFlatVec); + rewriter.replaceOp(op, resultFlatVec); + } + } + } else { + // 1D tensor descriptor. + // `tdesc` represents base address as i64 + // Offset in number of elements, need to multiply by element byte size. + // Compute byte offset. + // byteOffset = offset * elementByteSize + Value offset = + getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]); + offset = getValueOrCreateCastToIndexLike(rewriter, loc, + rewriter.getI64Type(), offset); + // Compute element byte size. + Value elemByteSize = arith::ConstantIntOp::create( + rewriter, loc, rewriter.getI64Type(), elemBitSize / 8); + Value byteOffset = + rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize); + // Final address = basePtr + byteOffset + Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>( + loc, tdesc, + getValueOrCreateCastToIndexLike(rewriter, loc, rewriter.getI64Type(), + byteOffset)); + // Convert base pointer (i64) to LLVM pointer type. + Value finalPtrLLVM = + LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64); + if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) { + Value src = adaptor.getValue(); + // If store value is a scalar, get value from op instead of adaptor. 
+ // Adaptor might have optimized away single element vector + if (src.getType().isIntOrFloat()) { + src = op.getValue(); + } + VectorType srcVecTy = dyn_cast<VectorType>(src.getType()); + if (!srcVecTy) + return rewriter.notifyMatchFailure( + op, "Expected store value to be a vector type."); + // Get flat vector type of integer type with matching element bit size. + VectorType newSrcVecTy = + encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize)); + if (srcVecTy != newSrcVecTy) + src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src); + auto storeCacheControl = + translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()); + rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>( + op, finalPtrLLVM, src, + xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl)); + } else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) { + auto loadCacheControl = + translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()); + VectorType resTy = cast<VectorType>(op.getValue().getType()); + VectorType loadedTy = + encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize)); + Value load = xevm::BlockLoadOp::create( + rewriter, loc, loadedTy, finalPtrLLVM, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl)); - resultFlatVec = vector::BitCastOp::create( - rewriter, loc, - encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()), - resultFlatVec); - rewriter.replaceOp(op, resultFlatVec); + if (loadedTy != resTy) + load = vector::BitCastOp::create(rewriter, loc, resTy, load); + rewriter.replaceOp(op, load); + } else { + return rewriter.notifyMatchFailure( + op, "Unsupported operation: xegpu.prefetch_nd with tensor " + "descriptor rank == 1"); } } return success(); @@ -511,9 +645,6 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> { } }; -// Lower xegpu::CreateMemDescOp to memref::ViewOp. Since SLM access instructions -// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than -// 32 bits will be converted to 32 bits. class CreateMemDescOpPattern final : public OpConversionPattern<xegpu::CreateMemDescOp> { public: @@ -522,16 +653,7 @@ public: matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto resTy = op.getMemDesc(); - - // Create the result MemRefType with the same shape, element type, and - // memory space - auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy); - - Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0); - auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy, - op.getSource(), zero, ValueRange()); - rewriter.replaceOp(op, viewOp); + rewriter.replaceOp(op, adaptor.getSource()); return success(); } }; @@ -551,17 +673,27 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> { auto loc = op.getLoc(); auto ctxt = rewriter.getContext(); - Value basePtrStruct = adaptor.getMemDesc(); + Value baseAddr32 = adaptor.getMemDesc(); Value mdescVal = op.getMemDesc(); // Load result or Store value Type can be vector or scalar. - Value data; - if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) - data = op.getResult(); - else - data = adaptor.getData(); - VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType()); + Type dataTy; + if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) { + Type resType = op.getResult().getType(); + // Some transforms may leave unit dimension in the 2D vector, adaptors do + // not catch it for results. 
+ if (auto vecType = dyn_cast<VectorType>(resType)) { + assert(llvm::count_if(vecType.getShape(), + [](int64_t d) { return d != 1; }) <= 1 && + "Expected either 1D vector or nD with unit dimensions"); + resType = VectorType::get({vecType.getNumElements()}, + vecType.getElementType()); + } + dataTy = resType; + } else + dataTy = adaptor.getData().getType(); + VectorType valOrResVecTy = dyn_cast<VectorType>(dataTy); if (!valOrResVecTy) - valOrResVecTy = VectorType::get(1, data.getType()); + valOrResVecTy = VectorType::get(1, dataTy); int64_t elemBitWidth = valOrResVecTy.getElementType().getIntOrFloatBitWidth(); @@ -577,21 +709,14 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> { auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType()); - Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create( - rewriter, loc, basePtrStruct); - - // Convert base pointer (ptr) to i32 - Value basePtrI32 = arith::IndexCastUIOp::create( - rewriter, loc, rewriter.getI32Type(), basePtrLLVM); - Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets); linearOffset = arith::IndexCastUIOp::create( rewriter, loc, rewriter.getI32Type(), linearOffset); - basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset, - elemByteSize); + Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32, + linearOffset, elemByteSize); // convert base pointer (i32) to LLVM pointer type - basePtrLLVM = + Value basePtrLLVM = LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32); if (op.getSubgroupBlockIoAttr()) { @@ -927,20 +1052,22 @@ struct ConvertXeGPUToXeVMPass return VectorType::get(sum, elemType); }); typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type { + // Scattered descriptors are not supported in XeVM lowering. if (type.isScattered()) + return {}; + if (type.getRank() == 1) return IntegerType::get(&getContext(), 64); auto i32Type = IntegerType::get(&getContext(), 32); return VectorType::get(8, i32Type); }); - // Convert MemDescType into flattened MemRefType for SLM + // Convert MemDescType into i32 for SLM typeConverter.addConversion([&](xegpu::MemDescType type) -> Type { - Type elemTy = type.getElementType(); - int numElems = type.getNumElements(); - return MemRefType::get(numElems, elemTy, AffineMap(), 3); + return IntegerType::get(&getContext(), 32); }); typeConverter.addConversion([&](MemRefType type) -> Type { - // Convert MemRefType to i64 type. 
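// With MemDescType now lowered to a plain i32 SLM base address, the matrix
// load/store path above reduces to flattening the result shape and doing
// integer address arithmetic. A standalone sketch (not part of the patch;
// hypothetical names, and the address formula is what addOffsetToBaseAddr in
// this file presumably computes):
#include <cassert>
#include <cstdint>
#include <vector>

// Results such as vector<1x16xf16> are treated as vector<16xf16>: at most one
// non-unit dimension is expected, and only the element count survives.
int64_t flattenedLength(const std::vector<int64_t> &shape) {
  int64_t numElements = 1;
  int nonUnitDims = 0;
  for (int64_t d : shape) {
    nonUnitDims += (d != 1);
    numElements *= d;
  }
  assert(nonUnitDims <= 1 && "expected 1D vector or nD with unit dimensions");
  return numElements;
}

// Final SLM address: 32-bit base plus the linearized element offset scaled by
// the element byte size.
uint32_t slmAddress(uint32_t baseAddr, uint32_t linearOffset,
                    uint32_t elemByteSize) {
  return baseAddr + linearOffset * elemByteSize;
}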
+ if (type.getMemorySpaceAsInt() == 3) + return IntegerType::get(&getContext(), 32); return IntegerType::get(&getContext(), 64); }); @@ -1057,6 +1184,7 @@ struct ConvertXeGPUToXeVMPass }; typeConverter.addSourceMaterialization( singleElementVectorMaterializationCast); + typeConverter.addSourceMaterialization(vectorMaterializationCast); typeConverter.addTargetMaterialization(memrefMaterializationCast); typeConverter.addTargetMaterialization(ui32MaterializationCast); typeConverter.addTargetMaterialization(ui64MaterializationCast); diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp index f276984..20a420d 100644 --- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp +++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp @@ -290,7 +290,7 @@ static LLVM::CallOp createDeviceFunctionCall( ArrayRef<Type> argTypes, ArrayRef<Value> args, mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs, LLVMFuncAttributeOptions funcAttributeOptions, Operation *op) { - auto moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>(); + auto *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>(); assert(moduleOp && "Expecting module"); Location loc = op->getLoc(); @@ -401,7 +401,10 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> { auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>( /*other=*/LLVM::ModRefInfo::NoModRef, /*argMem=*/LLVM::ModRefInfo::NoModRef, - /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef); + /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef, + /*errnoMem=*/LLVM::ModRefInfo::NoModRef, + /*targetMem0=*/LLVM::ModRefInfo::NoModRef, + /*targetMem1=*/LLVM::ModRefInfo::NoModRef); auto funcAttrs = convergentNoUnwindWillReturnAttrs; funcAttrs.memEffectsAttr = memAttr; Value result = @@ -450,7 +453,10 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> { auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>( /*other=*/LLVM::ModRefInfo::NoModRef, /*argMem=*/LLVM::ModRefInfo::Ref, - /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef); + /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef, + /*errnoMem=*/LLVM::ModRefInfo::NoModRef, + /*targetMem0=*/LLVM::ModRefInfo::NoModRef, + /*targetMem1=*/LLVM::ModRefInfo::NoModRef); funcAttr.memEffectsAttr = memAttr; LLVM::CallOp call = createDeviceFunctionCall( @@ -556,7 +562,10 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> { auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>( /*other=*/LLVM::ModRefInfo::NoModRef, /*argMem=*/LLVM::ModRefInfo::Ref, - /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef); + /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef, + /*errnoMem=*/LLVM::ModRefInfo::NoModRef, + /*targetMem0=*/LLVM::ModRefInfo::NoModRef, + /*targetMem1=*/LLVM::ModRefInfo::NoModRef); funcAttr = noUnwindAttrs; funcAttr.memEffectsAttr = memAttr; } else { @@ -798,7 +807,10 @@ class LaunchConfigOpToOCLPattern : public OpConversionPattern<OpType> { constexpr auto noModRef = LLVM::ModRefInfo::NoModRef; auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>( /*other=*/noModRef, - /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef); + /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef, + /*errnoMem=*/noModRef, + /*targetMem0=*/noModRef, + /*targetMem1=*/noModRef); call.setMemoryEffectsAttr(memAttr); rewriter.replaceOp(op, call); return success(); @@ -836,7 +848,10 @@ class SubgroupOpWorkitemOpToOCLPattern : public OpConversionPattern<OpType> { constexpr auto noModRef = LLVM::ModRefInfo::NoModRef; auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>( 
/*other=*/noModRef, - /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef); + /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef, + /*errnoMem=*/noModRef, + /*targetMem0=*/noModRef, + /*targetMem1=*/noModRef); call.setMemoryEffectsAttr(memAttr); rewriter.replaceOp(op, call); return success();
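// The six-field all-NoModRef memory-effects attribute above is now spelled out
// at several call sites in this file. A possible local helper that would
// centralize it (purely a sketch, not part of the patch; assumes the
// surrounding file's MLIR includes and using-declarations):
static LLVM::MemoryEffectsAttr getNoModRefMemoryEffects(Builder &builder) {
  constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
  return builder.getAttr<LLVM::MemoryEffectsAttr>(
      /*other=*/noModRef, /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef,
      /*errnoMem=*/noModRef, /*targetMem0=*/noModRef,
      /*targetMem1=*/noModRef);
}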
