Diffstat (limited to 'mlir/lib/Conversion')
 mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp         | 758
 mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp       | 665
 mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt           |  19
 mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp             |  77
 mlir/lib/Conversion/CMakeLists.txt                          |   1
 mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp         |   3
 mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp |   8
 mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp               |   9
 mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp           |   5
 mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp      |   5
 mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp             |  52
 mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp           |  38
 mlir/lib/Conversion/LLVMCommon/Pattern.cpp                  |  21
 mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp            |  42
 mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp                 |  13
 mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp             |   2
 mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp         |  14
 mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp             |  46
 mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp           |   3
 mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp   |   7
 mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp                   |  18
 mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp             |   3
 mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp             |  26
 mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp                   |  35
 mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp                 |  14
 mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp    |  17
 mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp         | 201
 mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp             | 426
 mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp               |  27
 29 files changed, 2090 insertions(+), 465 deletions(-)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 3a307a0..7584b17 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -16,8 +16,10 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
@@ -42,6 +44,7 @@ constexpr Chipset kGfx908 = Chipset(9, 0, 8);
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
constexpr Chipset kGfx950 = Chipset(9, 5, 0);
+constexpr Chipset kGfx1250 = Chipset(12, 5, 0);
/// Convert an unsigned number `val` to i32.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
@@ -79,12 +82,6 @@ static Value createI64Constant(ConversionPatternRewriter &rewriter,
return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
}
-static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
- bool value) {
- Type llvmI1 = rewriter.getI1Type();
- return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value);
-}
-
/// Returns the linear index used to access an element in the memref.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
Location loc, MemRefDescriptor &memRefDescriptor,
@@ -509,10 +506,16 @@ struct MemoryCounterWaitOpLowering
if (std::optional<int> exp = adaptor.getExp())
ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);
+ if (std::optional<int> tensor = adaptor.getTensor())
+ ROCDL::WaitTensorcntOp::create(rewriter, loc, *tensor);
+
rewriter.eraseOp(op);
return success();
}
+ if (adaptor.getTensor())
+ return op.emitOpError("unsupported chipset");
+
auto getVal = [](Attribute attr) -> unsigned {
if (attr)
return cast<IntegerAttr>(attr).getInt();
@@ -684,12 +687,11 @@ static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
/// intrinsics having been defined before the AMD backend supported bfloat. We
/// similarly need to pack 8-bit float types into integers as if they were i8
/// (which they are for the backend's purposes).
-static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
- Location loc,
- const TypeConverter *typeConverter,
- bool isUnsigned, Value llvmInput,
- Value mlirInput,
- SmallVector<Value, 4> &operands) {
+static void wmmaPushInputOperand(
+ ConversionPatternRewriter &rewriter, Location loc,
+ const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput,
+ Value mlirInput, SmallVectorImpl<Value> &operands,
+ SmallVectorImpl<NamedAttribute> &attrs, StringRef attrName) {
Type inputType = llvmInput.getType();
auto vectorType = dyn_cast<VectorType>(inputType);
if (!vectorType) {
@@ -697,10 +699,6 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
return;
}
Type elemType = vectorType.getElementType();
-
- if (elemType.isBF16())
- llvmInput = LLVM::BitcastOp::create(
- rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
if (elemType.getIntOrFloatBitWidth() > 8) {
operands.push_back(llvmInput);
return;
@@ -719,8 +717,8 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
} else if (elemType.isSignedInteger()) {
localIsUnsigned = false;
}
- Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
- operands.push_back(sign);
+ attrs.push_back(
+ NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
}
int64_t numBits =
@@ -751,18 +749,17 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
Location loc,
const TypeConverter *typeConverter,
Value output, int32_t subwordOffset,
- bool clamp, SmallVector<Value, 4> &operands) {
+ bool clamp, SmallVectorImpl<Value> &operands,
+ SmallVectorImpl<NamedAttribute> &attrs) {
Type inputType = output.getType();
auto vectorType = dyn_cast<VectorType>(inputType);
Type elemType = vectorType.getElementType();
- if (elemType.isBF16())
- output = LLVM::BitcastOp::create(
- rewriter, loc, vectorType.clone(rewriter.getI16Type()), output);
operands.push_back(output);
if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
- operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
+ attrs.push_back(
+ NamedAttribute("opsel", rewriter.getBoolAttr(subwordOffset)));
} else if (elemType.isInteger(32)) {
- operands.push_back(createI1Constant(rewriter, loc, clamp));
+ attrs.push_back(NamedAttribute("clamp", rewriter.getBoolAttr(clamp)));
}
}
@@ -1160,7 +1157,7 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
k, isRDNA3);
// Handle gfx1250.
- if (chipset == Chipset{12, 5, 0})
+ if (chipset == kGfx1250)
return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
elemDestType, k);
@@ -1311,11 +1308,33 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
return op->emitOpError("WMMA only supported on gfx11 and gfx12");
- // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
- // need to bitcast bfloats to i16 and then bitcast them back.
+ bool isGFX1250 = chipset >= kGfx1250;
+
+ // The WMMA operations represent vectors of bf16s as vectors of i16s
+ // (except on gfx1250), so we need to bitcast bfloats to i16 and then
+ // bitcast them back.
+ auto aType = cast<VectorType>(adaptor.getSourceA().getType());
+ auto bType = cast<VectorType>(adaptor.getSourceB().getType());
+ auto destCType = cast<VectorType>(adaptor.getDestC().getType());
+ bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
+ bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
+ bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
+ bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250;
VectorType rawOutType = outType;
- if (outType.getElementType().isBF16())
+ if (castOutToI16)
rawOutType = outType.clone(rewriter.getI16Type());
+ Value a = adaptor.getSourceA();
+ if (castAToI16)
+ a = LLVM::BitcastOp::create(rewriter, loc,
+ aType.clone(rewriter.getI16Type()), a);
+ Value b = adaptor.getSourceB();
+ if (castBToI16)
+ b = LLVM::BitcastOp::create(rewriter, loc,
+ bType.clone(rewriter.getI16Type()), b);
+ Value destC = adaptor.getDestC();
+ if (castDestCToI16)
+ destC = LLVM::BitcastOp::create(
+ rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);
std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
@@ -1325,18 +1344,20 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
return op.emitOpError("subwordOffset not supported on gfx12+");
- OperationState loweredOp(loc, *maybeIntrinsic);
- loweredOp.addTypes(rawOutType);
-
SmallVector<Value, 4> operands;
- wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
- adaptor.getSourceA(), op.getSourceA(), operands);
- wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
- adaptor.getSourceB(), op.getSourceB(), operands);
- wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
- op.getSubwordOffset(), op.getClamp(), operands);
+ SmallVector<NamedAttribute, 4> attrs;
+ wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), a,
+ op.getSourceA(), operands, attrs, "signA");
+ wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), b,
+ op.getSourceB(), operands, attrs, "signB");
+ wmmaPushOutputOperand(rewriter, loc, typeConverter, destC,
+ op.getSubwordOffset(), op.getClamp(), operands,
+ attrs);
+ OperationState loweredOp(loc, *maybeIntrinsic);
+ loweredOp.addTypes(rawOutType);
loweredOp.addOperands(operands);
+ loweredOp.addAttributes(attrs);
Operation *lowered = rewriter.create(loweredOp);
Operation *maybeCastBack = lowered;
@@ -1492,6 +1513,20 @@ struct ExtPackedFp8OpLowering final
ConversionPatternRewriter &rewriter) const override;
};
+struct ScaledExtPackedMatrixOpLowering final
+ : public ConvertOpToLLVMPattern<ScaledExtPackedMatrixOp> {
+ ScaledExtPackedMatrixOpLowering(const LLVMTypeConverter &converter,
+ Chipset chipset)
+ : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp>(converter),
+ chipset(chipset) {}
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(ScaledExtPackedMatrixOp op,
+ ScaledExtPackedMatrixOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
struct PackedTrunc2xFp8OpLowering final
: public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
@@ -1600,6 +1635,173 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
return success();
}
+int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, int32_t scaleWaveHalf,
+ int32_t firstScaleByte) {
+ // When lowering amdgpu.scaled_ext_packed_matrix to rocdl.cvt.scale.pk*.f*.f*
+ // operations, the attributes blockSize, sourceType, scaleWaveHalf, and
+ // firstScaleByte are merged into a single attribute scaleSel. This is how
+ // those values are merged together. (Note: scaleWaveHalf isn't a high-level
+ // attribute but is derived from firstScaleLane).
+ assert(llvm::is_contained({16, 32}, blockSize));
+ assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
+
+ const bool isFp8 = bitWidth == 8;
+ const bool isBlock16 = blockSize == 16;
+
+ if (!isFp8) {
+ int32_t bit0 = isBlock16;
+ assert(llvm::is_contained({0, 1, 2}, firstScaleByte));
+ int32_t bit1 = (firstScaleByte == 2) << 1;
+ assert(llvm::is_contained({0, 1}, scaleWaveHalf));
+ int32_t bit2 = scaleWaveHalf << 2;
+ return bit2 | bit1 | bit0;
+ }
+
+ int32_t bit0 = isBlock16;
+ // firstScaleByte is guaranteed to be defined by two bits.
+ assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
+ int32_t bits2and1 = firstScaleByte << 1;
+ assert(llvm::is_contained({0, 1}, scaleWaveHalf));
+ int32_t bit3 = scaleWaveHalf << 3;
+ int32_t bits = bit3 | bits2and1 | bit0;
+ // These are invalid cases.
+ assert(!llvm::is_contained(
+ {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
+ return bits;
+}
+
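A worked trace of the packing above may help; the values are chosen arbitrarily for illustration. For an fp4 source (bitWidth = 4) with blockSize = 16, scaleWaveHalf = 1, and firstScaleByte = 2, the non-fp8 branch gives:

  // bit0 = isBlock16                  = 1
  // bit1 = (firstScaleByte == 2) << 1 = 2
  // bit2 = scaleWaveHalf << 2         = 4
  // scaleSel = bit2 | bit1 | bit0     = 0b111 = 7
  //
  // For the fp8 branch, e.g. getScaleSel(/*blockSize=*/32, /*bitWidth=*/8,
  //                                      /*scaleWaveHalf=*/1, /*firstScaleByte=*/1):
  // bit0 = 0, bits2and1 = 1 << 1 = 2, bit3 = 1 << 3 = 8, so scaleSel = 0b1010 = 10.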
+static std::optional<StringRef>
+scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
+ using fp4 = Float4E2M1FNType;
+ using fp8 = Float8E4M3FNType;
+ using bf8 = Float8E5M2Type;
+ using fp6 = Float6E2M3FNType;
+ using bf6 = Float6E3M2FNType;
+ if (isa<fp4>(srcElemType)) {
+ if (destElemType.isF16())
+ return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
+ if (destElemType.isBF16())
+ return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
+ if (destElemType.isF32())
+ return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
+ return std::nullopt;
+ }
+ if (isa<fp8>(srcElemType)) {
+ if (destElemType.isF16())
+ return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
+ if (destElemType.isBF16())
+ return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
+ if (destElemType.isF32())
+ return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
+ return std::nullopt;
+ }
+ if (isa<bf8>(srcElemType)) {
+ if (destElemType.isF16())
+ return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
+ if (destElemType.isBF16())
+ return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
+ if (destElemType.isF32())
+ return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
+ return std::nullopt;
+ }
+ if (isa<fp6>(srcElemType)) {
+ if (destElemType.isF16())
+ return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
+ if (destElemType.isBF16())
+ return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
+ if (destElemType.isF32())
+ return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
+ return std::nullopt;
+ }
+ if (isa<bf6>(srcElemType)) {
+ if (destElemType.isF16())
+ return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
+ if (destElemType.isBF16())
+ return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
+ if (destElemType.isF32())
+ return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
+ return std::nullopt;
+ }
+ llvm_unreachable("invalid combination of element types for packed conversion "
+ "instructions");
+}
+
+LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite(
+ ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ using fp4 = Float4E2M1FNType;
+ using fp8 = Float8E4M3FNType;
+ using bf8 = Float8E5M2Type;
+ using fp6 = Float6E2M3FNType;
+ using bf6 = Float6E3M2FNType;
+ Location loc = op.getLoc();
+ if (chipset != kGfx1250) {
+ return rewriter.notifyMatchFailure(
+ loc,
+ "Scaled fp packed conversion instructions are not available on target "
+ "architecture and their emulation is not implemented");
+ }
+ // Convert user-facing firstScaleLane (0 or 16) to the half of the wave that
+ // is being selected.
+ int32_t scaleWaveHalf = op.getFirstScaleLane() / 16;
+ int32_t firstScaleByte = op.getFirstScaleByte();
+ int32_t blockSize = op.getBlockSize();
+ auto sourceType = cast<VectorType>(op.getSource().getType());
+ auto srcElemType = cast<FloatType>(sourceType.getElementType());
+ unsigned bitWidth = srcElemType.getWidth();
+
+ auto targetType = cast<VectorType>(op.getResult().getType());
+ auto destElemType = cast<FloatType>(targetType.getElementType());
+
+ IntegerType i32 = rewriter.getI32Type();
+ Value source = adaptor.getSource();
+ Type llvmResultType = typeConverter->convertType(op.getResult().getType());
+ Type packedType = nullptr;
+ if (isa<fp4>(srcElemType)) {
+ packedType = i32;
+ packedType = getTypeConverter()->convertType(packedType);
+ } else if (isa<fp8, bf8>(srcElemType)) {
+ packedType = VectorType::get(2, i32);
+ packedType = getTypeConverter()->convertType(packedType);
+ } else if (isa<fp6, bf6>(srcElemType)) {
+ packedType = VectorType::get(3, i32);
+ packedType = getTypeConverter()->convertType(packedType);
+ } else {
+ llvm_unreachable("invalid element type for packed scaled ext");
+ }
+
+ if (!packedType || !llvmResultType) {
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
+ }
+
+ std::optional<StringRef> maybeIntrinsic =
+ scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
+ if (!maybeIntrinsic.has_value())
+ return op.emitOpError(
+ "no intrinsic matching packed scaled conversion on the given chipset");
+
+ int32_t scaleSel =
+ getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte);
+ Value castedScale =
+ LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
+ Value castedSource =
+ LLVM::BitcastOp::create(rewriter, loc, packedType, source);
+
+ OperationState loweredOp(loc, *maybeIntrinsic);
+ loweredOp.addTypes({llvmResultType});
+ loweredOp.addOperands({castedSource, castedScale});
+
+ SmallVector<NamedAttribute, 1> attrs;
+ attrs.push_back(
+ NamedAttribute("scaleSel", rewriter.getI32IntegerAttr(scaleSel)));
+
+ loweredOp.addAttributes(attrs);
+ Operation *lowered = rewriter.create(loweredOp);
+ rewriter.replaceOp(op, lowered);
+
+ return success();
+}
+
LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -2073,6 +2275,441 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
}
};
+struct AMDGPUMakeDmaBaseLowering
+ : public ConvertOpToLLVMPattern<MakeDmaBaseOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ AMDGPUMakeDmaBaseLowering(const LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<MakeDmaBaseOp>(converter), chipset(chipset) {}
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(MakeDmaBaseOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (chipset < kGfx1250)
+ return op->emitOpError("make_dma_base is only supported on gfx1250");
+
+ Location loc = op.getLoc();
+
+ ValueRange ldsIndices = adaptor.getLdsIndices();
+ Value lds = adaptor.getLds();
+ auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
+
+ Value ldsPtr =
+ getStridedElementPtr(rewriter, loc, ldsMemRefType, lds, ldsIndices);
+
+ ValueRange globalIndices = adaptor.getGlobalIndices();
+ Value global = adaptor.getGlobal();
+ auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
+
+ Value globalPtr = getStridedElementPtr(rewriter, loc, globalMemRefType,
+ global, globalIndices);
+
+ Type i32 = rewriter.getI32Type();
+ Type i64 = rewriter.getI64Type();
+
+ Value castForLdsAddr = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
+ Value castForGlobalAddr =
+ LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);
+
+ Value lowHalf =
+ LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
+
+ Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
+ createI64Constant(rewriter, loc, 32));
+
+ Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
+
+ Value mask = createI32Constant(rewriter, loc, (1ull << 25) - 1);
+ Value validHighHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
+
+ Value typeField = createI32Constant(rewriter, loc, 2 << 30);
+ Value highHalfPlusType =
+ LLVM::OrOp::create(rewriter, loc, validHighHalf, typeField);
+
+ Value c0 = createI32Constant(rewriter, loc, 0);
+ Value c1 = createI32Constant(rewriter, loc, 1);
+ Value c2 = createI32Constant(rewriter, loc, 2);
+ Value c3 = createI32Constant(rewriter, loc, 3);
+
+ Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
+ assert(v4i32 && "expected type conversion to succeed");
+ Value result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
+ result = LLVM::InsertElementOp::create(rewriter, loc, result, c1, c0);
+ result = LLVM::InsertElementOp::create(rewriter, loc, result,
+ castForLdsAddr, c1);
+ result = LLVM::InsertElementOp::create(rewriter, loc, result, lowHalf, c2);
+ result = LLVM::InsertElementOp::create(rewriter, loc, result,
+ highHalfPlusType, c3);
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
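For reference, the v4i32 descriptor assembled by the pattern above has the following layout, read directly off the insertelement sequence (the field names are informal):

  // [0] = 1
  // [1] = LDS element address, as i32
  // [2] = global element address, bits 0..31
  // [3] = (global element address, bits 32..56, i.e. the low 25 bits of the
  //        high half) | (2 << 30)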
+struct AMDGPUMakeDmaDescriptorLowering
+ : public ConvertOpToLLVMPattern<MakeDmaDescriptorOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ AMDGPUMakeDmaDescriptorLowering(const LLVMTypeConverter &converter,
+ Chipset chipset)
+ : ConvertOpToLLVMPattern<MakeDmaDescriptorOp>(converter),
+ chipset(chipset) {}
+ Chipset chipset;
+
+ Value getDGroup0(OpAdaptor adaptor) const { return adaptor.getBase(); }
+
+ Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
+ Value accumulator, Value value, int64_t shift) const {
+ shift = shift % 32;
+ Value shiftAmount;
+ if (shift != 0) {
+ shiftAmount = createI32Constant(rewriter, loc, shift % 32);
+ value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
+ }
+
+ if (matchPattern(accumulator, mlir::m_Zero()))
+ return value;
+
+ return LLVM::OrOp::create(rewriter, loc, accumulator, value);
+ }
+
+ Value setWorkgroupMask(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr0) const {
+ Value mask = op.getWorkgroupMask();
+ if (!mask)
+ return sgpr0;
+
+ Type i32 = rewriter.getI32Type();
+ Value extendedMask = LLVM::ZExtOp::create(rewriter, loc, i32, mask);
+ return setValueAtOffset(rewriter, loc, sgpr0, extendedMask, 0);
+ }
+
+ Value setDataSize(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr0, ArrayRef<Value> consts) const {
+ // Compute data_size.
+ unsigned elementTypeWidthInBits = op.getElementTypeWidth();
+ assert(
+ llvm::is_contained<unsigned>({8, 16, 32, 64}, elementTypeWidthInBits) &&
+ "expected type width to be 8, 16, 32, or 64.");
+ int64_t dataSize = llvm::Log2_32(elementTypeWidthInBits / 8);
+ Value size = createI32Constant(rewriter, loc, dataSize);
+ return setValueAtOffset(rewriter, loc, sgpr0, size, 16);
+ }
+
+ Value setAtomicBarrier(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr0, ArrayRef<Value> consts) const {
+ bool atomic_barrier_enable = adaptor.getAtomicBarrierAddress() != nullptr;
+ if (!atomic_barrier_enable)
+ return sgpr0;
+
+ return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18);
+ }
+
+ Value setIterateEnable(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr0, ArrayRef<Value> consts) const {
+ bool iterate_enable = adaptor.getGlobalIncrement() != nullptr;
+ if (!iterate_enable)
+ return sgpr0;
+
+ // TODO: In future PR, add other required fields for iteration.
+ return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19);
+ }
+
+ Value setPadEnable(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr0, ArrayRef<Value> consts) const {
+ bool pad_enable = op.getPadAmount() != nullptr;
+ if (!pad_enable)
+ return sgpr0;
+
+ return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20);
+ }
+
+ Value setEarlyTimeout(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr0, ArrayRef<Value> consts) const {
+ if (!op.getWorkgroupMask())
+ return sgpr0;
+
+ return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 21);
+ }
+
+ Value setPadInterval(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr0, ArrayRef<Value> consts) const {
+ bool pad_enable = op.getPadAmount() != nullptr;
+ if (!pad_enable)
+ return sgpr0;
+
+ IntegerType i32 = rewriter.getI32Type();
+ Value padInterval = adaptor.getPadInterval();
+ // pre-condition: padInterval is a power of two between 2 and 256.
+ padInterval = LLVM::CountTrailingZerosOp::create(rewriter, loc, i32,
+ padInterval, false);
+ padInterval = LLVM::SubOp::create(rewriter, loc, padInterval, consts[1]);
+ // post-condition: padInterval is a value between 0 and 7.
+ return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22);
+ }
+
+ Value setPadAmount(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr0, ArrayRef<Value> consts) const {
+ bool pad_enable = op.getPadAmount() != nullptr;
+ if (!pad_enable)
+ return sgpr0;
+
+ Value padAmount = adaptor.getPadAmount();
+ // pre-condition: padAmount is a value between 1 and 128.
+ padAmount = LLVM::SubOp::create(rewriter, loc, padAmount, consts[1]);
+ // post-condition: padAmount is a value between 0 and 127.
+ return setValueAtOffset(rewriter, loc, sgpr0, padAmount, 25);
+ }
+
+ Value setAtomicBarrierAddress(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter,
+ Location loc, Value sgpr1,
+ ArrayRef<Value> consts) const {
+ bool atomic_barrier_enable = adaptor.getAtomicBarrierAddress() != nullptr;
+ if (!atomic_barrier_enable)
+ return sgpr1;
+
+ Value atomicBarrierAddress = adaptor.getAtomicBarrierAddress();
+ auto barrierAddressTy =
+ cast<MemRefType>(op.getAtomicBarrierAddress().getType());
+ ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices();
+ atomicBarrierAddress =
+ getStridedElementPtr(rewriter, loc, barrierAddressTy,
+ atomicBarrierAddress, atomicBarrierIndices);
+ IntegerType i32 = rewriter.getI32Type();
+ // pre-condition: atomicBarrierAddress is aligned to 8 bytes which implies
+ // that the 3 LSBs are zero.
+ atomicBarrierAddress =
+ LLVM::PtrToIntOp::create(rewriter, loc, i32, atomicBarrierAddress);
+ atomicBarrierAddress =
+ LLVM::LShrOp::create(rewriter, loc, atomicBarrierAddress, consts[3]);
+ Value mask = createI32Constant(rewriter, loc, 0xFFFF);
+ atomicBarrierAddress =
+ LLVM::AndOp::create(rewriter, loc, atomicBarrierAddress, mask);
+ return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32);
+ }
+
+ std::pair<Value, Value> setTensorDim0(MakeDmaDescriptorOp op,
+ OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter,
+ Location loc, Value sgpr1, Value sgpr2,
+ ArrayRef<Value> consts) const {
+ SmallVector<OpFoldResult> mixedGlobalSizes = op.getMixedGlobalSizes();
+ OpFoldResult tensorDim0OpFoldResult = mixedGlobalSizes.back();
+ Value tensorDim0;
+ if (auto attr = dyn_cast<Attribute>(tensorDim0OpFoldResult))
+ tensorDim0 =
+ createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
+ else
+ tensorDim0 = cast<Value>(tensorDim0OpFoldResult);
+
+ Value c16 = createI32Constant(rewriter, loc, 16);
+ Value tensorDim0High = LLVM::LShrOp::create(rewriter, loc, tensorDim0, c16);
+ sgpr1 = setValueAtOffset(rewriter, loc, sgpr1, tensorDim0, 48);
+ sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDim0High, 48 + 16);
+ return {sgpr1, sgpr2};
+ }
+
+ std::pair<Value, Value> setTensorDim1(MakeDmaDescriptorOp op,
+ OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter,
+ Location loc, Value sgpr2, Value sgpr3,
+ ArrayRef<Value> consts) const {
+ // TODO: Generalize to setTensorDimX.
+ SmallVector<OpFoldResult> mixedGlobalSizes = op.getMixedGlobalSizes();
+ OpFoldResult tensorDim1OpFoldResult = *(mixedGlobalSizes.rbegin() + 1);
+ Value tensorDim1;
+ if (auto attr = dyn_cast<Attribute>(tensorDim1OpFoldResult))
+ tensorDim1 =
+ createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
+ else
+ tensorDim1 = cast<Value>(tensorDim1OpFoldResult);
+
+ Value c16 = createI32Constant(rewriter, loc, 16);
+ Value tensorDim1High = LLVM::LShrOp::create(rewriter, loc, tensorDim1, c16);
+ sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDim1, 80);
+ sgpr3 = setValueAtOffset(rewriter, loc, sgpr3, tensorDim1High, 80 + 16);
+ return {sgpr2, sgpr3};
+ }
+
+ Value setTileDimX(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr, ArrayRef<Value> consts, size_t dimX,
+ int64_t offset) const {
+ SmallVector<OpFoldResult> mixedSharedSizes = op.getMixedSharedSizes();
+
+ if (mixedSharedSizes.size() <= dimX)
+ return sgpr;
+
+ OpFoldResult tileDimXOpFoldResult = *(mixedSharedSizes.rbegin() + dimX);
+ Value tileDimX;
+ if (auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult))
+ tileDimX =
+ createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
+ else
+ tileDimX = cast<Value>(tileDimXOpFoldResult);
+
+ return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset);
+ }
+
+ Value setTileDim0(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr3, ArrayRef<Value> consts) const {
+ return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112);
+ }
+
+ Value setTileDim1(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr4, ArrayRef<Value> consts) const {
+ return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128);
+ }
+
+ Value setTileDim2(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr4, ArrayRef<Value> consts) const {
+ return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144);
+ }
+
+ std::pair<Value, Value>
+ setTensorDimXStride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgprY, Value sgprZ, ArrayRef<Value> consts,
+ size_t dimX, int64_t offset) const {
+ SmallVector<OpFoldResult> mixedGlobalStrides = op.getMixedGlobalStrides();
+
+ if (mixedGlobalStrides.size() <= dimX)
+ return {sgprY, sgprZ};
+
+ OpFoldResult tensorDimXStrideOpFoldResult =
+ *(mixedGlobalStrides.rbegin() + dimX);
+ Value tensorDimXStride;
+ if (auto attr = dyn_cast<Attribute>(tensorDimXStrideOpFoldResult))
+ tensorDimXStride =
+ createI64Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
+ else
+ tensorDimXStride = cast<Value>(tensorDimXStrideOpFoldResult);
+
+ constexpr int64_t first48bits = (1ll << 48) - 1;
+ Value mask = createI64Constant(rewriter, loc, first48bits);
+ tensorDimXStride =
+ LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride);
+ IntegerType i32 = rewriter.getI32Type();
+ Value tensorDimXStrideLow =
+ LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride);
+
+ int64_t shift = (offset % 32) == 0 ? 32 : offset % 32;
+ Value shiftVal = createI64Constant(rewriter, loc, shift);
+ Value tensorDimXStrideHigh =
+ LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal);
+ tensorDimXStrideHigh =
+ LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh);
+
+ sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset);
+ sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh,
+ offset + shift);
+ return {sgprY, sgprZ};
+ }
+
+ std::pair<Value, Value>
+ setTensorDim0Stride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
+ return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
+ 0, 160);
+ }
+
+ std::pair<Value, Value>
+ setTensorDim1Stride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
+ return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
+ 1, 208);
+ }
+
+ Value getDGroup1(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter, Location loc,
+ ArrayRef<Value> consts) const {
+ Value sgprs[8];
+ for (int64_t i = 0; i < 8; i++) {
+ sgprs[i] = consts[0];
+ }
+
+ sgprs[0] = setWorkgroupMask(op, adaptor, rewriter, loc, sgprs[0]);
+ sgprs[0] = setDataSize(op, adaptor, rewriter, loc, sgprs[0], consts);
+ sgprs[0] = setAtomicBarrier(op, adaptor, rewriter, loc, sgprs[0], consts);
+ sgprs[0] = setIterateEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
+ sgprs[0] = setPadEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
+ sgprs[0] = setEarlyTimeout(op, adaptor, rewriter, loc, sgprs[0], consts);
+ sgprs[0] = setPadInterval(op, adaptor, rewriter, loc, sgprs[0], consts);
+ sgprs[0] = setPadAmount(op, adaptor, rewriter, loc, sgprs[0], consts);
+
+ sgprs[1] =
+ setAtomicBarrierAddress(op, adaptor, rewriter, loc, sgprs[1], consts);
+ std::tie(sgprs[1], sgprs[2]) =
+ setTensorDim0(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
+ std::tie(sgprs[2], sgprs[3]) =
+ setTensorDim1(op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
+
+ sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts);
+ sgprs[4] = setTileDim1(op, adaptor, rewriter, loc, sgprs[4], consts);
+ sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts);
+ std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride(
+ op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts);
+ std::tie(sgprs[6], sgprs[7]) = setTensorDim1Stride(
+ op, adaptor, rewriter, loc, sgprs[6], sgprs[7], consts);
+
+ IntegerType i32 = rewriter.getI32Type();
+ Type v8i32 = this->typeConverter->convertType(VectorType::get(8, i32));
+ assert(v8i32 && "expected type conversion to succeed");
+ Value dgroup1 = LLVM::PoisonOp::create(rewriter, loc, v8i32);
+
+ for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts)) {
+ dgroup1 =
+ LLVM::InsertElementOp::create(rewriter, loc, dgroup1, sgpr, constant);
+ }
+
+ return dgroup1;
+ }
+
+ LogicalResult
+ matchAndRewrite(MakeDmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (chipset < kGfx1250)
+ return op->emitOpError(
+ "make_dma_descriptor is only supported on gfx1250");
+
+ if (op.getRank() > 2)
+ return op->emitOpError("unimplemented");
+
+ Location loc = op.getLoc();
+
+ IntegerType i32 = rewriter.getI32Type();
+ [[maybe_unused]] Type v4i32 =
+ this->typeConverter->convertType(VectorType::get(4, i32));
+ assert(v4i32 && "expected type conversion to succeed");
+
+ SmallVector<Value> consts;
+ for (int64_t i = 0; i < 8; i++)
+ consts.push_back(createI32Constant(rewriter, loc, i));
+
+ Value dgroup0 = this->getDGroup0(adaptor);
+ Value dgroup1 = this->getDGroup1(op, adaptor, rewriter, loc, consts);
+
+ SmallVector<Value> results = {dgroup0, dgroup1};
+ rewriter.replaceOpWithMultiple(op, {results});
+ return success();
+ }
+};
+
struct ConvertAMDGPUToROCDLPass
: public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
using Base::Base;
@@ -2087,6 +2724,11 @@ struct ConvertAMDGPUToROCDLPass
RewritePatternSet patterns(ctx);
LLVMTypeConverter converter(ctx);
+ converter.addConversion([&](TDMBaseType type) -> Type {
+ Type i32 = IntegerType::get(type.getContext(), 32);
+ return converter.convertType(VectorType::get(4, i32));
+ });
+
populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
LLVMConversionTarget target(getContext());
target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
@@ -2122,25 +2764,27 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns,
Chipset chipset) {
populateAMDGPUMemorySpaceAttributeConversions(converter);
- patterns
- .add<FatRawBufferCastLowering,
- RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
- RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
- RawBufferOpLowering<RawBufferAtomicFaddOp,
- ROCDL::RawPtrBufferAtomicFaddOp>,
- RawBufferOpLowering<RawBufferAtomicFmaxOp,
- ROCDL::RawPtrBufferAtomicFmaxOp>,
- RawBufferOpLowering<RawBufferAtomicSmaxOp,
- ROCDL::RawPtrBufferAtomicSmaxOp>,
- RawBufferOpLowering<RawBufferAtomicUminOp,
- ROCDL::RawPtrBufferAtomicUminOp>,
- RawBufferOpLowering<RawBufferAtomicCmpswapOp,
- ROCDL::RawPtrBufferAtomicCmpSwap>,
- AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
- SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
- WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
- PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
- PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
- TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset);
+ patterns.add<
+ FatRawBufferCastLowering,
+ RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
+ RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
+ RawBufferOpLowering<RawBufferAtomicFaddOp,
+ ROCDL::RawPtrBufferAtomicFaddOp>,
+ RawBufferOpLowering<RawBufferAtomicFmaxOp,
+ ROCDL::RawPtrBufferAtomicFmaxOp>,
+ RawBufferOpLowering<RawBufferAtomicSmaxOp,
+ ROCDL::RawPtrBufferAtomicSmaxOp>,
+ RawBufferOpLowering<RawBufferAtomicUminOp,
+ ROCDL::RawPtrBufferAtomicUminOp>,
+ RawBufferOpLowering<RawBufferAtomicCmpswapOp,
+ ROCDL::RawPtrBufferAtomicCmpSwap>,
+ AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
+ SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
+ WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering,
+ ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
+ PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
+ GatherToLDSOpLowering, TransposeLoadOpLowering, AMDGPUPermlaneLowering,
+ AMDGPUMakeDmaBaseLowering, AMDGPUMakeDmaDescriptorLowering>(converter,
+ chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
new file mode 100644
index 0000000..79816fc
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
@@ -0,0 +1,665 @@
+//===- ArithToAPFloat.cpp - Arithmetic to APFloat Conversion --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Utils/Utils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::func;
+
+static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
+ StringRef name, FunctionType funcT, bool setPrivate,
+ SymbolTableCollection *symbolTables = nullptr) {
+ OpBuilder::InsertionGuard g(b);
+ assert(!symTable->getRegion(0).empty() && "expected non-empty region");
+ b.setInsertionPointToStart(&symTable->getRegion(0).front());
+ FuncOp funcOp = FuncOp::create(b, symTable->getLoc(), name, funcT);
+ if (setPrivate)
+ funcOp.setPrivate();
+ if (symbolTables) {
+ SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable);
+ symbolTable.insert(funcOp, symTable->getRegion(0).front().begin());
+ }
+ return funcOp;
+}
+
+/// Helper function to look up or create the symbol for a runtime library
+/// function with the given parameter types. Returns an int64_t, unless a
+/// different result type is specified.
+static FailureOr<FuncOp>
+lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable,
+ StringRef name, TypeRange paramTypes,
+ SymbolTableCollection *symbolTables = nullptr,
+ Type resultType = {}) {
+ if (!resultType)
+ resultType = IntegerType::get(symTable->getContext(), 64);
+ std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
+ auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType});
+ FailureOr<FuncOp> func =
+ lookupFnDecl(symTable, funcName, funcT, symbolTables);
+ // Failed due to type mismatch.
+ if (failed(func))
+ return func;
+ // Successfully matched existing decl.
+ if (*func)
+ return *func;
+
+ return createFnDecl(b, symTable, funcName, funcT,
+ /*setPrivate=*/true, symbolTables);
+}
+
+/// Helper function to look up or create the symbol for a runtime library
+/// function for a binary arithmetic operation.
+///
+/// Parameter 1: APFloat semantics
+/// Parameter 2: Left-hand side operand
+/// Parameter 3: Right-hand side operand
+///
+/// This function will return a failure if the function is found but has an
+/// unexpected signature.
+///
+static FailureOr<FuncOp>
+lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
+ SymbolTableCollection *symbolTables = nullptr) {
+ auto i32Type = IntegerType::get(symTable->getContext(), 32);
+ auto i64Type = IntegerType::get(symTable->getContext(), 64);
+ return lookupOrCreateApFloatFn(b, symTable, name, {i32Type, i64Type, i64Type},
+ symbolTables);
+}
+
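As an example of what this helper produces, looking up "add" declares (or reuses) a private runtime function of roughly the following form in the enclosing symbol table (MLIR assembly, illustrative):

  func.func private @_mlir_apfloat_add(i32, i64, i64) -> i64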
+static Value getSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy) {
+ int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
+ return arith::ConstantOp::create(b, loc, b.getI32Type(),
+ b.getIntegerAttr(b.getI32Type(), sem));
+}
+
+/// Given two operands of vector type and vector result type (with the same
+/// shape), call the given function for each pair of scalar operands and
+/// package the result into a vector. If the given operands and result type are
+/// not vectors, call the function directly. The second operand is optional.
+template <typename Fn, typename... Values>
+static Value forEachScalarValue(RewriterBase &rewriter, Location loc,
+ Value operand1, Value operand2, Type resultType,
+ Fn fn) {
+ auto vecTy1 = dyn_cast<VectorType>(operand1.getType());
+ if (operand2) {
+ // Sanity check: Operand types must match.
+ assert(vecTy1 == dyn_cast<VectorType>(operand2.getType()) &&
+ "expected same vector types");
+ }
+ if (!vecTy1) {
+ // Not a vector. Call the function directly.
+ return fn(operand1, operand2, resultType);
+ }
+
+ // Prepare scalar operands.
+ ResultRange scalars1 =
+ vector::ToElementsOp::create(rewriter, loc, operand1)->getResults();
+ SmallVector<Value> scalars2;
+ if (!operand2) {
+ // No second operand. Create a vector of empty values.
+ scalars2.assign(vecTy1.getNumElements(), Value());
+ } else {
+ llvm::append_range(
+ scalars2,
+ vector::ToElementsOp::create(rewriter, loc, operand2)->getResults());
+ }
+
+ // Call the function for each pair of scalar operands.
+ auto resultVecType = cast<VectorType>(resultType);
+ SmallVector<Value> results;
+ for (auto [scalar1, scalar2] : llvm::zip_equal(scalars1, scalars2)) {
+ Value result = fn(scalar1, scalar2, resultVecType.getElementType());
+ results.push_back(result);
+ }
+
+ // Package the results into a vector.
+ return vector::FromElementsOp::create(
+ rewriter, loc,
+ vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()),
+ results);
+}
+
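For vector operands, the helper above scalarizes through the vector dialect. For a pair of vector<4xf16> operands it emits roughly the following (MLIR assembly sketch, names invented):

  %l:4 = vector.to_elements %lhs : vector<4xf16>
  %r:4 = vector.to_elements %rhs : vector<4xf16>
  // ... four scalar invocations fn(%l#i, %r#i, f16), producing %e0 ... %e3 ...
  %res = vector.from_elements %e0, %e1, %e2, %e3 : vector<4xf16>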
+/// Check preconditions for the conversion:
+/// 1. All operands / results must be integers or floats (or vectors thereof).
+/// 2. The bitwidth of the operands / results must be <= 64.
+static LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op) {
+ for (Value value : llvm::concat<Value>(op->getOperands(), op->getResults())) {
+ Type type = value.getType();
+ if (auto vecTy = dyn_cast<VectorType>(type)) {
+ type = vecTy.getElementType();
+ }
+ if (!type.isIntOrFloat()) {
+ return rewriter.notifyMatchFailure(
+ op, "only integers and floats (or vectors thereof) are supported");
+ }
+ if (type.getIntOrFloatBitWidth() > 64)
+ return rewriter.notifyMatchFailure(op,
+ "bitwidth > 64 bits is not supported");
+ }
+ return success();
+}
+
+/// Rewrite a binary arithmetic operation to an APFloat function call.
+template <typename OpTy>
+struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
+ BinaryArithOpToAPFloatConversion(MLIRContext *context,
+ const char *APFloatName,
+ SymbolOpInterface symTable,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
+ APFloatName(APFloatName) {};
+
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const override {
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
+
+ // Get APFloat function from runtime library.
+ FailureOr<FuncOp> fn =
+ lookupOrCreateBinaryFn(rewriter, symTable, APFloatName);
+ if (failed(fn))
+ return fn;
+
+ // Scalarize and convert to APFloat runtime calls.
+ Location loc = op.getLoc();
+ rewriter.setInsertionPoint(op);
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getLhs(), op.getRhs(), op.getType(),
+ [&](Value lhs, Value rhs, Type resultType) {
+ // Cast operands to 64-bit integers.
+ auto floatTy = cast<FloatType>(resultType);
+ auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+ auto int64Type = rewriter.getI64Type();
+ Value lhsBits = arith::ExtUIOp::create(
+ rewriter, loc, int64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, lhs));
+ Value rhsBits = arith::ExtUIOp::create(
+ rewriter, loc, int64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, rhs));
+
+ // Call APFloat function.
+ Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+ SmallVector<Value> params = {semValue, lhsBits, rhsBits};
+ auto resultOp = func::CallOp::create(rewriter, loc,
+ TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*fn), params);
+
+ // Truncate result to the original width.
+ Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
+ resultOp->getResult(0));
+ return arith::BitcastOp::create(rewriter, loc, floatTy,
+ truncatedBits);
+ });
+ rewriter.replaceOp(op, repl);
+ return success();
+ }
+
+ SymbolOpInterface symTable;
+ const char *APFloatName;
+};
+
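To make the flow concrete, a scalar `arith.addf %a, %b : f16` is rewritten by this pattern into roughly the following call sequence (MLIR assembly, illustrative; the semantics constant below is a placeholder for the APFloatBase enum value of f16):

  %lhs_i = arith.bitcast %a : f16 to i16
  %lhs64 = arith.extui %lhs_i : i16 to i64
  %rhs_i = arith.bitcast %b : f16 to i16
  %rhs64 = arith.extui %rhs_i : i16 to i64
  %sem   = arith.constant 0 : i32  // APFloat semantics enum for f16 (placeholder)
  %r64   = func.call @_mlir_apfloat_add(%sem, %lhs64, %rhs64) : (i32, i64, i64) -> i64
  %r16   = arith.trunci %r64 : i64 to i16
  %res   = arith.bitcast %r16 : i16 to f16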
+template <typename OpTy>
+struct FpToFpConversion final : OpRewritePattern<OpTy> {
+ FpToFpConversion(MLIRContext *context, SymbolOpInterface symTable,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<OpTy>(context, benefit), symTable(symTable) {}
+
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const override {
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
+
+ // Get APFloat function from runtime library.
+ auto i32Type = IntegerType::get(symTable->getContext(), 32);
+ auto i64Type = IntegerType::get(symTable->getContext(), 64);
+ FailureOr<FuncOp> fn = lookupOrCreateApFloatFn(
+ rewriter, symTable, "convert", {i32Type, i32Type, i64Type});
+ if (failed(fn))
+ return fn;
+
+ // Scalarize and convert to APFloat runtime calls.
+ Location loc = op.getLoc();
+ rewriter.setInsertionPoint(op);
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+ [&](Value operand1, Value operand2, Type resultType) {
+ // Cast operands to 64-bit integers.
+ auto inFloatTy = cast<FloatType>(operand1.getType());
+ auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
+ Value operandBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
+
+ // Call APFloat function.
+ Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
+ auto outFloatTy = cast<FloatType>(resultType);
+ Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
+ std::array<Value, 3> params = {inSemValue, outSemValue, operandBits};
+ auto resultOp = func::CallOp::create(rewriter, loc,
+ TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*fn), params);
+
+ // Truncate result to the output float width.
+ auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth());
+ Value truncatedBits = arith::TruncIOp::create(
+ rewriter, loc, outIntWType, resultOp->getResult(0));
+ return arith::BitcastOp::create(rewriter, loc, outFloatTy,
+ truncatedBits);
+ });
+ rewriter.replaceOp(op, repl);
+ return success();
+ }
+
+ SymbolOpInterface symTable;
+};
+
+template <typename OpTy>
+struct FpToIntConversion final : OpRewritePattern<OpTy> {
+ FpToIntConversion(MLIRContext *context, SymbolOpInterface symTable,
+ bool isUnsigned, PatternBenefit benefit = 1)
+ : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
+ isUnsigned(isUnsigned) {}
+
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const override {
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
+
+ // Get APFloat function from runtime library.
+ auto i1Type = IntegerType::get(symTable->getContext(), 1);
+ auto i32Type = IntegerType::get(symTable->getContext(), 32);
+ auto i64Type = IntegerType::get(symTable->getContext(), 64);
+ FailureOr<FuncOp> fn =
+ lookupOrCreateApFloatFn(rewriter, symTable, "convert_to_int",
+ {i32Type, i32Type, i1Type, i64Type});
+ if (failed(fn))
+ return fn;
+
+ // Scalarize and convert to APFloat runtime calls.
+ Location loc = op.getLoc();
+ rewriter.setInsertionPoint(op);
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+ [&](Value operand1, Value operand2, Type resultType) {
+ // Cast operands to 64-bit integers.
+ auto inFloatTy = cast<FloatType>(operand1.getType());
+ auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
+ Value operandBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
+
+ // Call APFloat function.
+ Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
+ auto outIntTy = cast<IntegerType>(resultType);
+ Value outWidthValue = arith::ConstantOp::create(
+ rewriter, loc, i32Type,
+ rewriter.getIntegerAttr(i32Type, outIntTy.getWidth()));
+ Value isUnsignedValue = arith::ConstantOp::create(
+ rewriter, loc, i1Type,
+ rewriter.getIntegerAttr(i1Type, isUnsigned));
+ SmallVector<Value> params = {inSemValue, outWidthValue,
+ isUnsignedValue, operandBits};
+ auto resultOp = func::CallOp::create(rewriter, loc,
+ TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*fn), params);
+
+ // Truncate result to the output integer width.
+ return arith::TruncIOp::create(rewriter, loc, outIntTy,
+ resultOp->getResult(0));
+ });
+ rewriter.replaceOp(op, repl);
+ return success();
+ }
+
+ SymbolOpInterface symTable;
+ bool isUnsigned;
+};
+
+template <typename OpTy>
+struct IntToFpConversion final : OpRewritePattern<OpTy> {
+ IntToFpConversion(MLIRContext *context, SymbolOpInterface symTable,
+ bool isUnsigned, PatternBenefit benefit = 1)
+ : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
+ isUnsigned(isUnsigned) {}
+
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const override {
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
+
+ // Get APFloat function from runtime library.
+ auto i1Type = IntegerType::get(symTable->getContext(), 1);
+ auto i32Type = IntegerType::get(symTable->getContext(), 32);
+ auto i64Type = IntegerType::get(symTable->getContext(), 64);
+ FailureOr<FuncOp> fn =
+ lookupOrCreateApFloatFn(rewriter, symTable, "convert_from_int",
+ {i32Type, i32Type, i1Type, i64Type});
+ if (failed(fn))
+ return fn;
+
+ // Scalarize and convert to APFloat runtime calls.
+ Location loc = op.getLoc();
+ rewriter.setInsertionPoint(op);
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+ [&](Value operand1, Value operand2, Type resultType) {
+ // Cast operands to 64-bit integers.
+ auto inIntTy = cast<IntegerType>(operand1.getType());
+ Value operandBits = operand1;
+ if (operandBits.getType().getIntOrFloatBitWidth() < 64) {
+ if (isUnsigned) {
+ operandBits =
+ arith::ExtUIOp::create(rewriter, loc, i64Type, operandBits);
+ } else {
+ operandBits =
+ arith::ExtSIOp::create(rewriter, loc, i64Type, operandBits);
+ }
+ }
+
+ // Call APFloat function.
+ auto outFloatTy = cast<FloatType>(resultType);
+ Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
+ Value inWidthValue = arith::ConstantOp::create(
+ rewriter, loc, i32Type,
+ rewriter.getIntegerAttr(i32Type, inIntTy.getWidth()));
+ Value isUnsignedValue = arith::ConstantOp::create(
+ rewriter, loc, i1Type,
+ rewriter.getIntegerAttr(i1Type, isUnsigned));
+ SmallVector<Value> params = {outSemValue, inWidthValue,
+ isUnsignedValue, operandBits};
+ auto resultOp = func::CallOp::create(rewriter, loc,
+ TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*fn), params);
+
+ // Truncate result to the output float width.
+ auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth());
+ Value truncatedBits = arith::TruncIOp::create(
+ rewriter, loc, outIntWType, resultOp->getResult(0));
+ return arith::BitcastOp::create(rewriter, loc, outFloatTy,
+ truncatedBits);
+ });
+ rewriter.replaceOp(op, repl);
+ return success();
+ }
+
+ SymbolOpInterface symTable;
+ bool isUnsigned;
+};
+
+struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
+ CmpFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<arith::CmpFOp>(context, benefit), symTable(symTable) {}
+
+ LogicalResult matchAndRewrite(arith::CmpFOp op,
+ PatternRewriter &rewriter) const override {
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
+
+ // Get APFloat function from runtime library.
+ auto i1Type = IntegerType::get(symTable->getContext(), 1);
+ auto i8Type = IntegerType::get(symTable->getContext(), 8);
+ auto i32Type = IntegerType::get(symTable->getContext(), 32);
+ auto i64Type = IntegerType::get(symTable->getContext(), 64);
+ FailureOr<FuncOp> fn =
+ lookupOrCreateApFloatFn(rewriter, symTable, "compare",
+ {i32Type, i64Type, i64Type}, nullptr, i8Type);
+ if (failed(fn))
+ return fn;
+
+ // Scalarize and convert to APFloat runtime calls.
+ Location loc = op.getLoc();
+ rewriter.setInsertionPoint(op);
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getLhs(), op.getRhs(), op.getType(),
+ [&](Value lhs, Value rhs, Type resultType) {
+ // Cast operands to 64-bit integers.
+ auto floatTy = cast<FloatType>(lhs.getType());
+ auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+ Value lhsBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, lhs));
+ Value rhsBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, rhs));
+
+ // Call APFloat function.
+ Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+ SmallVector<Value> params = {semValue, lhsBits, rhsBits};
+ Value comparisonResult =
+ func::CallOp::create(rewriter, loc, TypeRange(i8Type),
+ SymbolRefAttr::get(*fn), params)
+ ->getResult(0);
+
+ // Generate an i1 SSA value that is "true" if the comparison result
+ // matches the given `val`.
+ auto checkResult = [&](llvm::APFloat::cmpResult val) {
+ return arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::eq, comparisonResult,
+ arith::ConstantOp::create(
+ rewriter, loc, i8Type,
+ rewriter.getIntegerAttr(i8Type, static_cast<int8_t>(val)))
+ .getResult());
+ };
+ // Generate an i1 SSA value that is "true" if the comparison result
+ // matches any of the given `vals`.
+ std::function<Value(ArrayRef<llvm::APFloat::cmpResult>)>
+ checkResults = [&](ArrayRef<llvm::APFloat::cmpResult> vals) {
+ Value first = checkResult(vals.front());
+ if (vals.size() == 1)
+ return first;
+ Value rest = checkResults(vals.drop_front());
+ return arith::OrIOp::create(rewriter, loc, first, rest)
+ .getResult();
+ };
+
+ // This switch-case statement was taken from arith::applyCmpPredicate.
+ Value result;
+ switch (op.getPredicate()) {
+ case arith::CmpFPredicate::AlwaysFalse:
+ result =
+ arith::ConstantOp::create(rewriter, loc, i1Type,
+ rewriter.getIntegerAttr(i1Type, 0))
+ .getResult();
+ break;
+ case arith::CmpFPredicate::OEQ:
+ result = checkResult(llvm::APFloat::cmpEqual);
+ break;
+ case arith::CmpFPredicate::OGT:
+ result = checkResult(llvm::APFloat::cmpGreaterThan);
+ break;
+ case arith::CmpFPredicate::OGE:
+ result = checkResults(
+ {llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual});
+ break;
+ case arith::CmpFPredicate::OLT:
+ result = checkResult(llvm::APFloat::cmpLessThan);
+ break;
+ case arith::CmpFPredicate::OLE:
+ result = checkResults(
+ {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpEqual});
+ break;
+ case arith::CmpFPredicate::ONE:
+ // Not cmpUnordered and not cmpEqual.
+ result = checkResults(
+ {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpGreaterThan});
+ break;
+ case arith::CmpFPredicate::ORD:
+ // Not cmpUnordered.
+ result = checkResults({llvm::APFloat::cmpLessThan,
+ llvm::APFloat::cmpGreaterThan,
+ llvm::APFloat::cmpEqual});
+ break;
+ case arith::CmpFPredicate::UEQ:
+ result = checkResults(
+ {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpEqual});
+ break;
+ case arith::CmpFPredicate::UGT:
+ result = checkResults(
+ {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpGreaterThan});
+ break;
+ case arith::CmpFPredicate::UGE:
+ result = checkResults({llvm::APFloat::cmpUnordered,
+ llvm::APFloat::cmpGreaterThan,
+ llvm::APFloat::cmpEqual});
+ break;
+ case arith::CmpFPredicate::ULT:
+ result = checkResults(
+ {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan});
+ break;
+ case arith::CmpFPredicate::ULE:
+ result = checkResults({llvm::APFloat::cmpUnordered,
+ llvm::APFloat::cmpLessThan,
+ llvm::APFloat::cmpEqual});
+ break;
+ case arith::CmpFPredicate::UNE:
+ // Not cmpEqual.
+ result = checkResults({llvm::APFloat::cmpLessThan,
+ llvm::APFloat::cmpGreaterThan,
+ llvm::APFloat::cmpUnordered});
+ break;
+ case arith::CmpFPredicate::UNO:
+ result = checkResult(llvm::APFloat::cmpUnordered);
+ break;
+ case arith::CmpFPredicate::AlwaysTrue:
+ result =
+ arith::ConstantOp::create(rewriter, loc, i1Type,
+ rewriter.getIntegerAttr(i1Type, 1))
+ .getResult();
+ break;
+ }
+ return result;
+ });
+ rewriter.replaceOp(op, repl);
+ return success();
+ }
+
+ SymbolOpInterface symTable;
+};
+
+struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
+ NegFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<arith::NegFOp>(context, benefit), symTable(symTable) {}
+
+ LogicalResult matchAndRewrite(arith::NegFOp op,
+ PatternRewriter &rewriter) const override {
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
+
+ // Get APFloat function from runtime library.
+ auto i32Type = IntegerType::get(symTable->getContext(), 32);
+ auto i64Type = IntegerType::get(symTable->getContext(), 64);
+ FailureOr<FuncOp> fn =
+ lookupOrCreateApFloatFn(rewriter, symTable, "neg", {i32Type, i64Type});
+ if (failed(fn))
+ return fn;
+
+ // Scalarize and convert to APFloat runtime calls.
+ Location loc = op.getLoc();
+ rewriter.setInsertionPoint(op);
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+ [&](Value operand1, Value operand2, Type resultType) {
+          // Cast the operand to a 64-bit integer.
+ auto floatTy = cast<FloatType>(operand1.getType());
+ auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+ Value operandBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, operand1));
+
+ // Call APFloat function.
+ Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+ SmallVector<Value> params = {semValue, operandBits};
+ Value negatedBits =
+ func::CallOp::create(rewriter, loc, TypeRange(i64Type),
+ SymbolRefAttr::get(*fn), params)
+ ->getResult(0);
+
+ // Truncate result to the original width.
+ Value truncatedBits =
+ arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
+ return arith::BitcastOp::create(rewriter, loc, floatTy,
+ truncatedBits);
+ });
+ rewriter.replaceOp(op, repl);
+ return success();
+ }
+
+ SymbolOpInterface symTable;
+};
+
+namespace {
+struct ArithToAPFloatConversionPass final
+ : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
+ using Base::Base;
+
+ void runOnOperation() override;
+};
+
+void ArithToAPFloatConversionPass::runOnOperation() {
+ MLIRContext *context = &getContext();
+ RewritePatternSet patterns(context);
+ patterns.add<BinaryArithOpToAPFloatConversion<arith::AddFOp>>(context, "add",
+ getOperation());
+ patterns.add<BinaryArithOpToAPFloatConversion<arith::SubFOp>>(
+ context, "subtract", getOperation());
+ patterns.add<BinaryArithOpToAPFloatConversion<arith::MulFOp>>(
+ context, "multiply", getOperation());
+ patterns.add<BinaryArithOpToAPFloatConversion<arith::DivFOp>>(
+ context, "divide", getOperation());
+ patterns.add<BinaryArithOpToAPFloatConversion<arith::RemFOp>>(
+ context, "remainder", getOperation());
+ patterns.add<BinaryArithOpToAPFloatConversion<arith::MinNumFOp>>(
+ context, "minnum", getOperation());
+ patterns.add<BinaryArithOpToAPFloatConversion<arith::MaxNumFOp>>(
+ context, "maxnum", getOperation());
+ patterns.add<BinaryArithOpToAPFloatConversion<arith::MinimumFOp>>(
+ context, "minimum", getOperation());
+ patterns.add<BinaryArithOpToAPFloatConversion<arith::MaximumFOp>>(
+ context, "maximum", getOperation());
+ patterns
+ .add<FpToFpConversion<arith::ExtFOp>, FpToFpConversion<arith::TruncFOp>,
+ CmpFOpToAPFloatConversion, NegFOpToAPFloatConversion>(
+ context, getOperation());
+ patterns.add<FpToIntConversion<arith::FPToSIOp>>(context, getOperation(),
+ /*isUnsigned=*/false);
+ patterns.add<FpToIntConversion<arith::FPToUIOp>>(context, getOperation(),
+ /*isUnsigned=*/true);
+ patterns.add<IntToFpConversion<arith::SIToFPOp>>(context, getOperation(),
+ /*isUnsigned=*/false);
+ patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(),
+ /*isUnsigned=*/true);
+ LogicalResult result = success();
+ ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
+ if (diag.getSeverity() == DiagnosticSeverity::Error) {
+ result = failure();
+ }
+ // NB: if you don't return failure, no other diag handlers will fire (see
+ // mlir/lib/IR/Diagnostics.cpp:DiagnosticEngineImpl::emit).
+ return failure();
+ });
+ walkAndApplyPatterns(getOperation(), std::move(patterns));
+ if (failed(result))
+ return signalPassFailure();
+}
+} // namespace
diff --git a/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
new file mode 100644
index 0000000..31fce7a
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRArithToAPFloat
+ ArithToAPFloat.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToLLVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRArithDialect
+ MLIRArithTransforms
+ MLIRFuncDialect
+ MLIRFuncUtils
+ MLIRVectorDialect
+ )
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index ba57155..220826d 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -14,6 +14,7 @@
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/TypeUtilities.h"
@@ -36,20 +37,23 @@ namespace {
/// attribute.
template <typename SourceOp, typename TargetOp, bool Constrained,
template <typename, typename> typename AttrConvert =
- AttrConvertPassThrough>
+ AttrConvertPassThrough,
+ bool FailOnUnsupportedFP = false>
struct ConstrainedVectorConvertToLLVMPattern
- : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> {
- using VectorConvertToLLVMPattern<SourceOp, TargetOp,
- AttrConvert>::VectorConvertToLLVMPattern;
+ : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert,
+ FailOnUnsupportedFP> {
+ using VectorConvertToLLVMPattern<
+ SourceOp, TargetOp, AttrConvert,
+ FailOnUnsupportedFP>::VectorConvertToLLVMPattern;
LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (Constrained != static_cast<bool>(op.getRoundingModeAttr()))
return failure();
- return VectorConvertToLLVMPattern<SourceOp, TargetOp,
- AttrConvert>::matchAndRewrite(op, adaptor,
- rewriter);
+ return VectorConvertToLLVMPattern<
+ SourceOp, TargetOp, AttrConvert,
+ FailOnUnsupportedFP>::matchAndRewrite(op, adaptor, rewriter);
}
};
@@ -78,7 +82,8 @@ struct IdentityBitcastLowering final
using AddFOpLowering =
VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using AddIOpLowering =
VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
arith::AttrConvertOverflowToLLVM>;
@@ -87,53 +92,67 @@ using BitcastOpLowering =
VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
using DivFOpLowering =
VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using DivSIOpLowering =
VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
using DivUIOpLowering =
VectorConvertToLLVMPattern<arith::DivUIOp, LLVM::UDivOp>;
-using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp>;
+using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp,
+ AttrConvertPassThrough,
+ /*FailOnUnsupportedFP=*/true>;
using ExtSIOpLowering =
VectorConvertToLLVMPattern<arith::ExtSIOp, LLVM::SExtOp>;
using ExtUIOpLowering =
VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp>;
using FPToSIOpLowering =
- VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>;
+ VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp,
+ AttrConvertPassThrough,
+ /*FailOnUnsupportedFP=*/true>;
using FPToUIOpLowering =
- VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
+ VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp,
+ AttrConvertPassThrough,
+ /*FailOnUnsupportedFP=*/true>;
using MaximumFOpLowering =
VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using MaxNumFOpLowering =
VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using MaxSIOpLowering =
VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>;
using MaxUIOpLowering =
VectorConvertToLLVMPattern<arith::MaxUIOp, LLVM::UMaxOp>;
using MinimumFOpLowering =
VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using MinNumFOpLowering =
VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using MinSIOpLowering =
VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>;
using MinUIOpLowering =
VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>;
using MulFOpLowering =
VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using MulIOpLowering =
VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
arith::AttrConvertOverflowToLLVM>;
using NegFOpLowering =
VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
using RemFOpLowering =
VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using RemSIOpLowering =
VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
using RemUIOpLowering =
@@ -151,21 +170,25 @@ using SIToFPOpLowering =
VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
using SubFOpLowering =
VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using SubIOpLowering =
VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
arith::AttrConvertOverflowToLLVM>;
using TruncFOpLowering =
ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
- false>;
+ false, AttrConvertPassThrough,
+ /*FailOnUnsupportedFP=*/true>;
using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true,
- arith::AttrConverterConstrainedFPToLLVM>;
+ arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using TruncIOpLowering =
VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp,
arith::AttrConvertOverflowToLLVM>;
using UIToFPOpLowering =
- VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>;
+ VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp,
+ AttrConvertPassThrough,
+ /*FailOnUnsupportedFP=*/true>;
using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>;
//===----------------------------------------------------------------------===//
@@ -240,8 +263,7 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
struct SelectOpOneToNLowering : public ConvertOpToLLVMPattern<arith::SelectOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
- using Adaptor =
- typename ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor;
+ using Adaptor = ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor;
LogicalResult
matchAndRewrite(arith::SelectOp op, Adaptor adaptor,
@@ -259,6 +281,7 @@ ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
adaptor.getOperands(), op->getAttrs(),
+ /*propAttr=*/Attribute{},
*getTypeConverter(), rewriter);
}
@@ -460,6 +483,10 @@ CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
LogicalResult
CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
+ if (LLVM::detail::isUnsupportedFloatingPointType(*this->getTypeConverter(),
+ op.getLhs().getType()))
+ return rewriter.notifyMatchFailure(op, "unsupported floating point type");
+
Type operandType = adaptor.getLhs().getType();
Type resultType = op.getResult().getType();
LLVM::FastmathFlags fmf =
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index bebf1b8..613dc6d 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(AffineToStandard)
add_subdirectory(AMDGPUToROCDL)
add_subdirectory(ArithCommon)
add_subdirectory(ArithToAMDGPU)
+add_subdirectory(ArithToAPFloat)
add_subdirectory(ArithToArmSME)
add_subdirectory(ArithToEmitC)
add_subdirectory(ArithToLLVM)
diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
index 86d02e6..6a0c211 100644
--- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
+++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
@@ -96,7 +96,8 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> {
ConversionPatternRewriter &rewriter) const override {
return LLVM::detail::oneToOneRewrite(
op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
- op->getAttrs(), *getTypeConverter(), rewriter);
+ op->getAttrs(), /*propAttr=*/Attribute{}, *getTypeConverter(),
+ rewriter);
}
};
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index 798d8b0..b75968e 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -137,8 +137,7 @@ static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
/// op to llvm.br.
struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
- using Adaptor =
- typename ConvertOpToLLVMPattern<cf::BranchOp>::OneToNOpAdaptor;
+ using Adaptor = ConvertOpToLLVMPattern<cf::BranchOp>::OneToNOpAdaptor;
LogicalResult
matchAndRewrite(cf::BranchOp op, Adaptor adaptor,
@@ -163,8 +162,7 @@ struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
/// branch op to llvm.cond_br.
struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
- using Adaptor =
- typename ConvertOpToLLVMPattern<cf::CondBranchOp>::OneToNOpAdaptor;
+ using Adaptor = ConvertOpToLLVMPattern<cf::CondBranchOp>::OneToNOpAdaptor;
LogicalResult
matchAndRewrite(cf::CondBranchOp op, Adaptor adaptor,
@@ -204,7 +202,7 @@ struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
+ matchAndRewrite(cf::SwitchOp op, cf::SwitchOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Get or convert default block.
FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 93fe2ed..2220f61 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -374,9 +374,12 @@ FailureOr<LLVM::LLVMFuncOp> mlir::convertFuncOpToLLVMFuncOp(
// Create a memory effect attribute corresponding to readnone.
if (funcOp->hasAttr(readnoneAttrName)) {
auto memoryAttr = LLVM::MemoryEffectsAttr::get(
- rewriter.getContext(),
- {LLVM::ModRefInfo::NoModRef, LLVM::ModRefInfo::NoModRef,
- LLVM::ModRefInfo::NoModRef});
+ rewriter.getContext(), {/*other=*/LLVM::ModRefInfo::NoModRef,
+ /*argMem=*/LLVM::ModRefInfo::NoModRef,
+ /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
+ /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
+ /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
+ /*targetMem1=*/LLVM::ModRefInfo::NoModRef});
newFuncOp.setMemoryEffectsAttr(memoryAttr);
}
diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
index 425594b..f143a9e 100644
--- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
+++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
@@ -66,7 +66,10 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
constexpr auto noModRef = mlir::LLVM::ModRefInfo::NoModRef;
auto memAttr = b.getAttr<LLVM::MemoryEffectsAttr>(
/*other=*/noModRef,
- /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
+ /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef,
+ /*errnoMem=*/noModRef,
+ /*targetMem0=*/noModRef,
+ /*targetMem1=*/noModRef);
func.setMemoryEffectsAttr(memAttr);
}
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index d64c4d6..5848489 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -419,7 +419,10 @@ struct LowerGpuOpsToNVVMOpsPass final
if (this->hasRedux)
populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns);
configureGpuToNVVMConversionLegality(target);
- if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
+ ConversionConfig config;
+ config.allowPatternRollback = allowPatternRollback;
+ if (failed(
+ applyPartialConversion(m, target, std::move(llvmPatterns), config)))
signalPassFailure();
}
};
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 99c059c..6254de8 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Types.h"
using namespace mlir;
@@ -57,7 +58,8 @@ static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
if (type.getElementType().isF32())
return type.getOperand() == "COp" ? NVVM::MMATypes::f32
: NVVM::MMATypes::tf32;
-
+ if (type.getElementType().isF64())
+ return NVVM::MMATypes::f64;
if (type.getElementType().isSignedInteger(8))
return NVVM::MMATypes::s8;
if (type.getElementType().isUnsignedInteger(8))
@@ -212,8 +214,13 @@ struct WmmaMmaOpToNVVMLowering
// then passed on to the intrinsic call. Emit llvm ops to extract individual
// values form lowered memrefs.
SmallVector<Value> unpackedOps;
-
auto unpackOp = [&](Value operand) {
+ // f64 a and b fragments are not structs but scalars.
+ if (!isa<LLVM::LLVMStructType>(operand.getType())) {
+ unpackedOps.push_back(operand);
+ return;
+ }
+      // Every other type is lowered to an LLVM struct; extract its values.
auto structType = cast<LLVM::LLVMStructType>(operand.getType());
for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
Value toUse = LLVM::ExtractValueOp::create(rewriter, loc, operand, i);
@@ -276,10 +283,16 @@ struct WmmaConstantOpToNVVMLowering
return failure();
Location loc = subgroupMmaConstantOp.getLoc();
Value cst = adaptor.getOperands()[0];
- LLVM::LLVMStructType type = convertMMAToLLVMType(
+ Type type = convertMMAToLLVMType(
cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
+    // If the converted type is not a struct, it is a scalar f64 fragment.
+ auto structType = dyn_cast<LLVM::LLVMStructType>(type);
+ if (!structType) {
+ rewriter.replaceOp(subgroupMmaConstantOp, cst);
+ return success();
+ }
// If the element type is a vector create a vector from the operand.
- if (auto vecType = dyn_cast<VectorType>(type.getBody()[0])) {
+ if (auto vecType = dyn_cast<VectorType>(structType.getBody()[0])) {
Value vecCst = LLVM::PoisonOp::create(rewriter, loc, vecType);
for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
Value idx = LLVM::ConstantOp::create(rewriter, loc,
@@ -289,8 +302,8 @@ struct WmmaConstantOpToNVVMLowering
}
cst = vecCst;
}
- Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, type);
- for (size_t i : llvm::seq(size_t(0), type.getBody().size())) {
+ Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, structType);
+ for (size_t i : llvm::seq(size_t(0), structType.getBody().size())) {
matrixStruct =
LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, cst, i);
}
@@ -354,10 +367,24 @@ struct WmmaElementwiseOpToNVVMLowering
return failure();
Location loc = subgroupMmaElementwiseOp.getLoc();
size_t numOperands = adaptor.getOperands().size();
- LLVM::LLVMStructType destType = convertMMAToLLVMType(
+ Type destType = convertMMAToLLVMType(
cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
- Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, destType);
- for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
+
+  // If the converted type is not a struct, it is a scalar f64 fragment.
+ LLVM::LLVMStructType structDestTy =
+ dyn_cast<LLVM::LLVMStructType>(destType);
+ if (!structDestTy) {
+ SmallVector<Value> operands;
+ for (auto operand : adaptor.getOperands()) {
+ operands.push_back(operand);
+ }
+ Value element = createScalarOp(
+ rewriter, loc, subgroupMmaElementwiseOp.getOpType(), operands);
+ rewriter.replaceOp(subgroupMmaElementwiseOp, element);
+ return success();
+ }
+ Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, structDestTy);
+ for (size_t i = 0, e = structDestTy.getBody().size(); i < e; ++i) {
SmallVector<Value> extractedOperands;
for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
extractedOperands.push_back(LLVM::ExtractValueOp::create(
@@ -377,13 +404,18 @@ struct WmmaElementwiseOpToNVVMLowering
} // namespace
/// Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
-LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
+Type mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
NVVM::MMAFrag frag = convertOperand(type.getOperand());
NVVM::MMATypes eltType = getElementType(type);
auto nRow = type.getShape()[0];
auto nCol = type.getShape()[1];
std::pair<Type, unsigned> typeInfo =
NVVM::inferMMAType(eltType, frag, nRow, nCol, type.getContext());
+  // Special handling for f64 "a" and "b" fragments: they hold a single f64,
+  // so return a scalar instead of a one-element struct.
+ Type f64Ty = Float64Type::get(type.getContext());
+ if (typeInfo.first == f64Ty && typeInfo.second == 1) {
+ return f64Ty;
+ }
return LLVM::LLVMStructType::getLiteral(
type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
}
diff --git a/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
index bc2f2f2..d4b4c46 100644
--- a/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
+++ b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
@@ -107,16 +107,16 @@ struct ConvertIndexCeilDivSPattern final : OpConversionPattern<CeilDivSOp> {
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value n = adaptor.getLhs();
- Type n_type = n.getType();
+ Type nType = n.getType();
Value m = adaptor.getRhs();
// Define the constants
- Value zero = spirv::ConstantOp::create(rewriter, loc, n_type,
- IntegerAttr::get(n_type, 0));
- Value posOne = spirv::ConstantOp::create(rewriter, loc, n_type,
- IntegerAttr::get(n_type, 1));
- Value negOne = spirv::ConstantOp::create(rewriter, loc, n_type,
- IntegerAttr::get(n_type, -1));
+ Value zero = spirv::ConstantOp::create(rewriter, loc, nType,
+ IntegerAttr::get(nType, 0));
+ Value posOne = spirv::ConstantOp::create(rewriter, loc, nType,
+ IntegerAttr::get(nType, 1));
+ Value negOne = spirv::ConstantOp::create(rewriter, loc, nType,
+ IntegerAttr::get(nType, -1));
// Compute `x`.
Value mPos = spirv::SGreaterThanOp::create(rewriter, loc, m, zero);
@@ -157,14 +157,14 @@ struct ConvertIndexCeilDivUPattern final : OpConversionPattern<CeilDivUOp> {
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value n = adaptor.getLhs();
- Type n_type = n.getType();
+ Type nType = n.getType();
Value m = adaptor.getRhs();
// Define the constants
- Value zero = spirv::ConstantOp::create(rewriter, loc, n_type,
- IntegerAttr::get(n_type, 0));
- Value one = spirv::ConstantOp::create(rewriter, loc, n_type,
- IntegerAttr::get(n_type, 1));
+ Value zero = spirv::ConstantOp::create(rewriter, loc, nType,
+ IntegerAttr::get(nType, 0));
+ Value one = spirv::ConstantOp::create(rewriter, loc, nType,
+ IntegerAttr::get(nType, 1));
// Compute the non-zero result.
Value minusOne = spirv::ISubOp::create(rewriter, loc, n, one);
@@ -193,16 +193,16 @@ struct ConvertIndexFloorDivSPattern final : OpConversionPattern<FloorDivSOp> {
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value n = adaptor.getLhs();
- Type n_type = n.getType();
+ Type nType = n.getType();
Value m = adaptor.getRhs();
// Define the constants
- Value zero = spirv::ConstantOp::create(rewriter, loc, n_type,
- IntegerAttr::get(n_type, 0));
- Value posOne = spirv::ConstantOp::create(rewriter, loc, n_type,
- IntegerAttr::get(n_type, 1));
- Value negOne = spirv::ConstantOp::create(rewriter, loc, n_type,
- IntegerAttr::get(n_type, -1));
+ Value zero = spirv::ConstantOp::create(rewriter, loc, nType,
+ IntegerAttr::get(nType, 0));
+ Value posOne = spirv::ConstantOp::create(rewriter, loc, nType,
+ IntegerAttr::get(nType, 1));
+ Value negOne = spirv::ConstantOp::create(rewriter, loc, nType,
+ IntegerAttr::get(nType, -1));
// Compute `x`.
Value mNeg = spirv::SLessThanOp::create(rewriter, loc, m, zero);
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 48a0319..f28a6cc 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -296,19 +296,13 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
// Detail methods
//===----------------------------------------------------------------------===//
-void LLVM::detail::setNativeProperties(Operation *op,
- IntegerOverflowFlags overflowFlags) {
- if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op))
- iface.setOverflowFlags(overflowFlags);
-}
-
/// Replaces the given operation "op" with a new operation of type "targetOp"
/// and given operands.
LogicalResult LLVM::detail::oneToOneRewrite(
Operation *op, StringRef targetOp, ValueRange operands,
- ArrayRef<NamedAttribute> targetAttrs,
- const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
- IntegerOverflowFlags overflowFlags) {
+ ArrayRef<NamedAttribute> targetAttrs, Attribute propertiesAttr,
+ const LLVMTypeConverter &typeConverter,
+ ConversionPatternRewriter &rewriter) {
unsigned numResults = op->getNumResults();
SmallVector<Type> resultTypes;
@@ -320,11 +314,10 @@ LogicalResult LLVM::detail::oneToOneRewrite(
}
// Create the operation through state since we don't know its C++ type.
- Operation *newOp =
- rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
- resultTypes, targetAttrs);
-
- setNativeProperties(newOp, overflowFlags);
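+  // Thread the native properties through OperationState so that arbitrary
+  // properties (not just integer overflow flags) can be forwarded.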
+ OperationState state(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
+ resultTypes, targetAttrs);
+ state.propertiesAttr = propertiesAttr;
+ Operation *newOp = rewriter.create(state);
// If the operation produced 0 or 1 result, return them immediately.
if (numResults == 0)
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index e7dd0b5..e5969c2 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -105,9 +105,9 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
LogicalResult LLVM::detail::vectorOneToOneRewrite(
Operation *op, StringRef targetOp, ValueRange operands,
- ArrayRef<NamedAttribute> targetAttrs,
- const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
- IntegerOverflowFlags overflowFlags) {
+ ArrayRef<NamedAttribute> targetAttrs, Attribute propertiesAttr,
+ const LLVMTypeConverter &typeConverter,
+ ConversionPatternRewriter &rewriter) {
assert(!operands.empty());
// Cannot convert ops if their operands are not of LLVM type.
@@ -116,18 +116,38 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
auto llvmNDVectorTy = operands[0].getType();
if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy))
- return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter,
- rewriter, overflowFlags);
-
- auto callback = [op, targetOp, targetAttrs, overflowFlags,
+ return oneToOneRewrite(op, targetOp, operands, targetAttrs, propertiesAttr,
+ typeConverter, rewriter);
+ auto callback = [op, targetOp, targetAttrs, propertiesAttr,
&rewriter](Type llvm1DVectorTy, ValueRange operands) {
- Operation *newOp =
- rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp),
- operands, llvm1DVectorTy, targetAttrs);
- LLVM::detail::setNativeProperties(newOp, overflowFlags);
+ OperationState state(op->getLoc(), rewriter.getStringAttr(targetOp),
+ operands, llvm1DVectorTy, targetAttrs);
+ state.propertiesAttr = propertiesAttr;
+ Operation *newOp = rewriter.create(state);
return newOp->getResult(0);
};
return handleMultidimensionalVectors(op, operands, typeConverter, callback,
rewriter);
}
+
+/// Return the given type if it is a floating-point type. If it is a vector
+/// type, return its element type if that is a floating-point type. Return
+/// nullptr otherwise.
+static FloatType getFloatingPointType(Type type) {
+ if (auto floatType = dyn_cast<FloatType>(type))
+ return floatType;
+ if (auto vecType = dyn_cast<VectorType>(type))
+ return dyn_cast<FloatType>(vecType.getElementType());
+ return nullptr;
+}
+
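+/// Return "true" if `type` is a floating-point type (or a vector thereof)
+/// whose element type does not convert to an LLVM floating-point type with
+/// the given type converter.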
+bool LLVM::detail::isUnsupportedFloatingPointType(
+ const TypeConverter &typeConverter, Type type) {
+ FloatType floatType = getFloatingPointType(type);
+ if (!floatType)
+ return false;
+ Type convertedType = typeConverter.convertType(floatType);
+ if (!convertedType)
+ return true;
+ return !isa<FloatType>(convertedType);
+}
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 16ef11a..59a16df 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -93,13 +93,13 @@ public:
/// Different MPI implementations have different communicator types.
/// Using i64 as a portable, intermediate type.
/// Appropriate cast needs to take place before calling MPI functions.
- virtual Value getCommWorld(const Location loc,
+ virtual Value getCommWorld(Location loc,
ConversionPatternRewriter &rewriter) = 0;
/// Type converter provides i64 type for communicator type.
/// Converts to native type, which might be ptr or int or whatever.
- virtual Value castComm(const Location loc,
- ConversionPatternRewriter &rewriter, Value comm) = 0;
+ virtual Value castComm(Location loc, ConversionPatternRewriter &rewriter,
+ Value comm) = 0;
/// Get the MPI_STATUS_IGNORE value (typically a pointer type).
virtual intptr_t getStatusIgnore() = 0;
@@ -109,13 +109,12 @@ public:
/// Gets or creates an MPI datatype as a value which corresponds to the given
/// type.
- virtual Value getDataType(const Location loc,
- ConversionPatternRewriter &rewriter, Type type) = 0;
+ virtual Value getDataType(Location loc, ConversionPatternRewriter &rewriter,
+ Type type) = 0;
/// Gets or creates an MPI_Op value which corresponds to the given
/// enum value.
- virtual Value getMPIOp(const Location loc,
- ConversionPatternRewriter &rewriter,
+ virtual Value getMPIOp(Location loc, ConversionPatternRewriter &rewriter,
mpi::MPI_ReductionOpEnum opAttr) = 0;
};
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index a2dfc12..a922338 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -68,7 +68,7 @@ struct ClampFOpConversion final
return LLVM::detail::handleMultidimensionalVectors(
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) -> Value {
- typename math::ClampFOp::Adaptor adaptor(operands);
+ math::ClampFOp::Adaptor adaptor(operands);
return ROCDL::FMed3Op::create(rewriter, op.getLoc(), llvm1DVectorTy,
adaptor.getValue(), adaptor.getMin(),
adaptor.getMax());
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 11f866c..0a382d8 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -122,7 +122,7 @@ static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
return totalSizeBytes.getResult();
}
-static emitc::ApplyOp
+static emitc::AddressOfOp
createPointerFromEmitcArray(Location loc, OpBuilder &builder,
TypedValue<emitc::ArrayType> arrayValue) {
@@ -133,9 +133,9 @@ createPointerFromEmitcArray(Location loc, OpBuilder &builder,
llvm::SmallVector<mlir::Value> indices(arrayType.getRank(), zeroIndex);
emitc::SubscriptOp subPtr =
emitc::SubscriptOp::create(builder, loc, arrayValue, ValueRange(indices));
- emitc::ApplyOp ptr = emitc::ApplyOp::create(
+ emitc::AddressOfOp ptr = emitc::AddressOfOp::create(
builder, loc, emitc::PointerType::get(arrayType.getElementType()),
- builder.getStringAttr("&"), subPtr);
+ subPtr);
return ptr;
}
@@ -225,12 +225,12 @@ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
auto srcArrayValue =
cast<TypedValue<emitc::ArrayType>>(operands.getSource());
- emitc::ApplyOp srcPtr =
+ emitc::AddressOfOp srcPtr =
createPointerFromEmitcArray(loc, rewriter, srcArrayValue);
auto targetArrayValue =
cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
- emitc::ApplyOp targetPtr =
+ emitc::AddressOfOp targetPtr =
createPointerFromEmitcArray(loc, rewriter, targetArrayValue);
emitc::CallOpaqueOp memCpyCall = emitc::CallOpaqueOp::create(
@@ -319,8 +319,8 @@ struct ConvertGetGlobal final
emitc::GetGlobalOp globalLValue = emitc::GetGlobalOp::create(
rewriter, op.getLoc(), lvalueType, operands.getNameAttr());
emitc::PointerType pointerType = emitc::PointerType::get(resultTy);
- rewriter.replaceOpWithNewOp<emitc::ApplyOp>(
- op, pointerType, rewriter.getStringAttr("&"), globalLValue);
+ rewriter.replaceOpWithNewOp<emitc::AddressOfOp>(op, pointerType,
+ globalLValue);
return success();
}
rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy,
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index ec182f1..64a7f56 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -865,13 +865,7 @@ struct NVGPUMBarrierArriveLowering
adaptor.getMbarId(), rewriter);
Type tokenType = getTypeConverter()->convertType(
nvgpu::MBarrierTokenType::get(op->getContext()));
- if (isMbarrierShared(op.getBarriers().getType())) {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveSharedOp>(op, tokenType,
- barrier);
- } else {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType,
- barrier);
- }
+ rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType, barrier);
return success();
}
};
@@ -892,13 +886,8 @@ struct NVGPUMBarrierArriveNoCompleteLowering
Type tokenType = getTypeConverter()->convertType(
nvgpu::MBarrierTokenType::get(op->getContext()));
Value count = truncToI32(b, adaptor.getCount());
- if (isMbarrierShared(op.getBarriers().getType())) {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>(
- op, tokenType, barrier, count);
- } else {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
- op, tokenType, barrier, count);
- }
+ rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
+ op, tokenType, barrier, count);
return success();
}
};
@@ -915,13 +904,8 @@ struct NVGPUMBarrierTestWaitLowering
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
Type retType = rewriter.getI1Type();
- if (isMbarrierShared(op.getBarriers().getType())) {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitSharedOp>(
- op, retType, barrier, adaptor.getToken());
- } else {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(
- op, retType, barrier, adaptor.getToken());
- }
+ rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(op, retType, barrier,
+ adaptor.getToken());
return success();
}
};
@@ -938,15 +922,12 @@ struct NVGPUMBarrierArriveExpectTxLowering
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
Value txcount = truncToI32(b, adaptor.getTxcount());
-
- if (isMbarrierShared(op.getBarriers().getType())) {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
- op, barrier, txcount, adaptor.getPredicate());
- return success();
- }
-
rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
- op, barrier, txcount, adaptor.getPredicate());
+ op, Type{}, // return-value is optional and is void by default
+ barrier, txcount, // barrier and txcount
+ NVVM::MemScopeKind::CTA, // default scope is CTA
+ false, // relaxed-semantics is false
+ adaptor.getPredicate());
return success();
}
};
@@ -965,13 +946,6 @@ struct NVGPUMBarrierTryWaitParityLowering
Value ticks = truncToI32(b, adaptor.getTicks());
Value phase =
LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity());
-
- if (isMbarrierShared(op.getBarriers().getType())) {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
- op, barrier, phase, ticks);
- return success();
- }
-
rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
phase, ticks);
return success();
diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index 021e31a..7fdc23a 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -66,6 +66,9 @@ struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> {
for (NamedAttribute attr : op->getAttrs()) {
if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
Type convertedType = converter->convertType(typeAttr.getValue());
+ if (!convertedType)
+ return rewriter.notifyMatchFailure(
+ op, "failed to convert type in attribute");
convertedAttrs.emplace_back(attr.getName(),
TypeAttr::get(convertedType));
} else {
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 37cfc9f..03842cc 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -36,6 +36,7 @@ namespace {
struct SCFToControlFlowPass
: public impl::SCFToControlFlowPassBase<SCFToControlFlowPass> {
+ using Base::Base;
void runOnOperation() override;
};
@@ -736,7 +737,9 @@ void SCFToControlFlowPass::runOnOperation() {
target.addIllegalOp<scf::ForallOp, scf::ForOp, scf::IfOp, scf::IndexSwitchOp,
scf::ParallelOp, scf::WhileOp, scf::ExecuteRegionOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
- if (failed(
- applyPartialConversion(getOperation(), target, std::move(patterns))))
+ ConversionConfig config;
+ config.allowPatternRollback = allowPatternRollback;
+ if (failed(applyPartialConversion(getOperation(), target, std::move(patterns),
+ config)))
signalPassFailure();
}
diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
index 76a822b..309121f 100644
--- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
+++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
@@ -453,10 +453,24 @@ static LogicalResult processParallelLoop(
1, 2,
rewriter.getAffineDimExpr(0) * rewriter.getAffineSymbolExpr(0) +
rewriter.getAffineSymbolExpr(1));
+ // Map through cloningMap first so we use values valid at the launch
+ // scope, then ensure they are launch-independent (or cloned constants).
+ Value mappedStep = cloningMap.lookupOrDefault(step);
+ Value mappedLowerBound = cloningMap.lookupOrDefault(lowerBound);
+
+ mappedStep = ensureLaunchIndependent(mappedStep);
+ mappedLowerBound = ensureLaunchIndependent(mappedLowerBound);
+
+ // If either cannot be made available above the launch, fail gracefully.
+ if (!mappedStep || !mappedLowerBound) {
+ return rewriter.notifyMatchFailure(
+ parallelOp, "lower bound / step must be constant or defined above "
+ "the gpu.launch");
+ }
+
newIndex = AffineApplyOp::create(
rewriter, loc, annotation.getMap().compose(lowerAndStep),
- ValueRange{operand, ensureLaunchIndependent(step),
- ensureLaunchIndependent(lowerBound)});
+ ValueRange{operand, mappedStep, mappedLowerBound});
// If there was also a bound, insert that, too.
// TODO: Check that we do not assign bounds twice.
if (annotation.getBound()) {
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 460595b..6423d49 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -188,7 +188,8 @@ createDecl(PatternRewriter &builder, SymbolTable &symbolTable,
OpBuilder::InsertionGuard guard(builder);
Type type = reduce.getOperands()[reductionIndex].getType();
auto decl = omp::DeclareReductionOp::create(builder, reduce.getLoc(),
- "__scf_reduction", type);
+ "__scf_reduction", type,
+ /*byref_element_type=*/{});
symbolTable.insert(decl);
builder.createBlock(&decl.getInitializerRegion(),
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 50fca56..02b61bd 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -1520,20 +1520,12 @@ public:
if (!dstType)
return rewriter.notifyMatchFailure(tanOp, "type conversion failed");
- Location loc = tanOp.getLoc();
- Value sin = LLVM::SinOp::create(rewriter, loc, dstType, tanOp.getOperand());
- Value cos = LLVM::CosOp::create(rewriter, loc, dstType, tanOp.getOperand());
- rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
+ rewriter.replaceOpWithNewOp<LLVM::TanOp>(tanOp, dstType,
+ adaptor.getOperands());
return success();
}
};
-/// Convert `spirv.Tanh` to
-///
-/// exp(2x) - 1
-/// -----------
-/// exp(2x) + 1
-///
class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
public:
using SPIRVToLLVMConversion<spirv::GLTanhOp>::SPIRVToLLVMConversion;
@@ -1546,18 +1538,8 @@ public:
if (!dstType)
return rewriter.notifyMatchFailure(tanhOp, "type conversion failed");
- Location loc = tanhOp.getLoc();
- Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
- Value multiplied =
- LLVM::FMulOp::create(rewriter, loc, dstType, two, tanhOp.getOperand());
- Value exponential = LLVM::ExpOp::create(rewriter, loc, dstType, multiplied);
- Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
- Value numerator =
- LLVM::FSubOp::create(rewriter, loc, dstType, exponential, one);
- Value denominator =
- LLVM::FAddOp::create(rewriter, loc, dstType, exponential, one);
- rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
- denominator);
+ rewriter.replaceOpWithNewOp<LLVM::TanhOp>(tanhOp, dstType,
+ adaptor.getOperands());
return success();
}
};
diff --git a/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp b/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp
index 9921a06..feb0489 100644
--- a/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp
+++ b/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp
@@ -23,8 +23,11 @@ namespace mlir {
using namespace mlir;
-namespace {
+//===----------------------------------------------------------------------===//
+// PoisonOpLowering
+//===----------------------------------------------------------------------===//
+namespace {
struct PoisonOpLowering : public ConvertOpToLLVMPattern<ub::PoisonOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
@@ -32,13 +35,8 @@ struct PoisonOpLowering : public ConvertOpToLLVMPattern<ub::PoisonOp> {
matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
-
} // namespace
-//===----------------------------------------------------------------------===//
-// PoisonOpLowering
-//===----------------------------------------------------------------------===//
-
LogicalResult
PoisonOpLowering::matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -61,6 +59,29 @@ PoisonOpLowering::matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor,
}
//===----------------------------------------------------------------------===//
+// UnreachableOpLowering
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct UnreachableOpLowering
+ : public ConvertOpToLLVMPattern<ub::UnreachableOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(ub::UnreachableOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+LogicalResult
+UnreachableOpLowering::matchAndRewrite(
+ ub::UnreachableOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ rewriter.replaceOpWithNewOp<LLVM::UnreachableOp>(op);
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
@@ -93,7 +114,7 @@ struct UBToLLVMConversionPass
void mlir::ub::populateUBToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
- patterns.add<PoisonOpLowering>(converter);
+ patterns.add<PoisonOpLowering, UnreachableOpLowering>(converter);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
index 244d214..3831387 100644
--- a/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
+++ b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
@@ -40,6 +40,17 @@ struct PoisonOpLowering final : OpConversionPattern<ub::PoisonOp> {
}
};
+struct UnreachableOpLowering final : OpConversionPattern<ub::UnreachableOp> {
+ using Base::Base;
+
+ LogicalResult
+ matchAndRewrite(ub::UnreachableOp op, OpAdaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<spirv::UnreachableOp>(op);
+ return success();
+ }
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -75,5 +86,6 @@ struct UBToSPIRVConversionPass final
void mlir::ub::populateUBToSPIRVConversionPatterns(
const SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
- patterns.add<PoisonOpLowering>(converter, patterns.getContext());
+ patterns.add<PoisonOpLowering, UnreachableOpLowering>(converter,
+ patterns.getContext());
}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 69a317ec..05d541f 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -345,7 +345,8 @@ public:
matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = scatter->getLoc();
- MemRefType memRefType = scatter.getMemRefType();
+ auto memRefType = dyn_cast<MemRefType>(scatter.getBaseType());
+ assert(memRefType && "The base should be bufferized");
if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
return rewriter.notifyMatchFailure(scatter, "memref type not supported");
@@ -1654,6 +1655,20 @@ private:
return failure();
}
}
+ } else if (auto floatTy = dyn_cast<FloatType>(printType)) {
+ // Print other floating-point types using the APFloat runtime library.
+ int32_t sem =
+ llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
+ Value semValue = LLVM::ConstantOp::create(
+ rewriter, loc, rewriter.getI32Type(),
+ rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
+ Value floatBits =
+ LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(), value);
+ printer =
+ LLVM::lookupOrCreateApFloatPrintFn(rewriter, parent, symbolTables);
+ emitCall(rewriter, loc, printer.value(),
+ ValueRange({semValue, floatBits}));
+ return success();
} else {
return failure();
}
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 91c1aa5..079e1e2 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -97,57 +97,23 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
return success();
}
-static xegpu::CreateNdDescOp
-createNdDescriptor(PatternRewriter &rewriter, Location loc,
- xegpu::TensorDescType descType, TypedValue<MemRefType> src,
- Operation::operand_range offsets) {
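+/// Create an xegpu::CreateNdDescOp for `src`. For dynamically shaped memrefs,
+/// the sizes and strides are recovered from the extracted strided metadata.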
+static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
+ Location loc,
+ xegpu::TensorDescType descType,
+ TypedValue<MemRefType> src) {
MemRefType srcTy = src.getType();
auto [strides, offset] = srcTy.getStridesAndOffset();
xegpu::CreateNdDescOp ndDesc;
if (srcTy.hasStaticShape()) {
- ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
- getAsOpFoldResult(offsets));
+ ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
} else {
// In case of any dynamic shapes, source's shape and strides have to be
// explicitly provided.
- SmallVector<Value> sourceDims;
- unsigned srcRank = srcTy.getRank();
- for (unsigned i = 0; i < srcRank; ++i)
- sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));
-
- SmallVector<int64_t> constOffsets;
- SmallVector<Value> dynOffsets;
- for (Value offset : offsets) {
- std::optional<int64_t> staticVal = getConstantIntValue(offset);
- if (!staticVal)
- dynOffsets.push_back(offset);
- constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
- }
-
- SmallVector<Value> dynShapes;
- for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
- if (shape == ShapedType::kDynamic)
- dynShapes.push_back(sourceDims[idx]);
- }
-
- // Compute strides in reverse order.
- SmallVector<Value> dynStrides;
- Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
- // Last stride is guaranteed to be static and unit.
- for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
- accStride =
- arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
- if (strides[i] == ShapedType::kDynamic)
- dynStrides.push_back(accStride);
- }
- std::reverse(dynStrides.begin(), dynStrides.end());
-
- ndDesc = xegpu::CreateNdDescOp::create(
- rewriter, loc, descType, src, dynOffsets, dynShapes, dynStrides,
- DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
- DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
- DenseI64ArrayAttr::get(rewriter.getContext(), strides));
+ auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
+ ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
+ meta.getConstifiedMixedSizes(),
+ meta.getConstifiedMixedStrides());
}
return ndDesc;
@@ -392,6 +358,62 @@ static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
.getResult();
}
+// Collapses the shape of an nD memref to the target rank, applying offsets for
+// the collapsed dimensions. Returns the new memref value and the remaining
+// offsets for the last targetRank dimensions. For example:
+// input: %memref = memref<2x4x8x32xf32>, offsets=[%i0, %i1, %i2, %i3],
+// output: %memref[%i0, %i1, 0, 0] -> memref<8x32xf32>, offsets: [%i2, %i3]
+static std::pair<Value, SmallVector<OpFoldResult>>
+convertMemrefAndOffsetsToTargetRank(PatternRewriter &rewriter, Location loc,
+ Value memref,
+ SmallVector<OpFoldResult> offsets,
+ int64_t targetRank) {
+ auto memrefType = cast<MemRefType>(memref.getType());
+ unsigned rank = memrefType.getRank();
+
+ if (rank <= targetRank)
+ return {memref, offsets};
+
+ int64_t numCombinedDims = rank - targetRank;
+ SmallVector<OpFoldResult> subviewOffsets;
+ SmallVector<OpFoldResult> subviewSizes;
+ SmallVector<OpFoldResult> subviewStrides;
+
+ // For the combined dimensions: use the provided offsets, size=1, stride=1
+ for (unsigned i = 0; i < numCombinedDims; ++i) {
+ subviewOffsets.push_back(offsets[i]);
+ subviewSizes.push_back(rewriter.getI64IntegerAttr(1));
+ subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
+ }
+
+ // For the last targetRank dimensions: offset=0, use full size, stride=1
+ SmallVector<int64_t> resultShape;
+ auto originalShape = memrefType.getShape();
+ auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, memref);
+ for (unsigned i = numCombinedDims; i < rank; ++i) {
+ subviewOffsets.push_back(rewriter.getI64IntegerAttr(0));
+ if (ShapedType::isDynamic(originalShape[i])) {
+ subviewSizes.push_back(meta.getSizes()[i]);
+ resultShape.push_back(ShapedType::kDynamic);
+ } else {
+ subviewSizes.push_back(rewriter.getI64IntegerAttr(originalShape[i]));
+ resultShape.push_back(originalShape[i]);
+ }
+ subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
+ }
+
+ auto resultType = memref::SubViewOp::inferRankReducedResultType(
+ resultShape, memrefType, subviewOffsets, subviewSizes, subviewStrides);
+ auto subviewOp =
+ memref::SubViewOp::create(rewriter, loc, resultType, memref,
+ subviewOffsets, subviewSizes, subviewStrides);
+
+ // Return the remaining offsets for the last targetRank dimensions
+ SmallVector<OpFoldResult> newOffsets(offsets.begin() + numCombinedDims,
+ offsets.end());
+ return {subviewOp.getResult(), newOffsets};
+}
+
template <
typename OpType,
typename = std::enable_if_t<llvm::is_one_of<
@@ -435,7 +457,8 @@ static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
/*chunk_size=*/IntegerAttr{},
/*l1_hint=*/xegpu::CachePolicyAttr{},
/*l2_hint=*/xegpu::CachePolicyAttr{},
- /*l3_hint=*/xegpu::CachePolicyAttr{});
+ /*l3_hint=*/xegpu::CachePolicyAttr{},
+ /*layout=*/nullptr);
rewriter.replaceOp(readOp, gatherOp.getResult());
return success();
@@ -469,7 +492,8 @@ static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
/*chunk_size=*/IntegerAttr{},
/*l1_hint=*/xegpu::CachePolicyAttr{},
/*l2_hint=*/xegpu::CachePolicyAttr{},
- /*l3_hint=*/xegpu::CachePolicyAttr{});
+ /*l3_hint=*/xegpu::CachePolicyAttr{},
+ /*layout=*/nullptr);
rewriter.eraseOp(writeOp);
return success();
}
@@ -495,8 +519,13 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
return lowerToScatteredLoadOp(readOp, rewriter);
}
- // Perform common data transfer checks.
VectorType vecTy = readOp.getVectorType();
+
+ // Lower using load.gather in 1D case
+ if (vecTy.getRank() == 1 && !readOp.hasOutOfBoundsDim())
+ return lowerToScatteredLoadOp(readOp, rewriter);
+
+ // Perform common data transfer checks.
if (failed(storeLoadPreconditions(rewriter, readOp, vecTy)))
return failure();
@@ -523,21 +552,23 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
descShape, elementType, /*array_length=*/1,
/*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);
- xegpu::CreateNdDescOp ndDesc =
- createNdDescriptor(rewriter, loc, descType,
- dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
- readOp.getIndices());
-
DenseI64ArrayAttr transposeAttr =
!isTransposeLoad ? nullptr
: DenseI64ArrayAttr::get(rewriter.getContext(),
ArrayRef<int64_t>{1, 0});
+ auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+ rewriter, loc, readOp.getBase(), getAsOpFoldResult(readOp.getIndices()),
+ vecTy.getRank());
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
- auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
+ xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+ rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+
+ auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
/*packed=*/nullptr, transposeAttr,
/*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ /*l2_hint=*/hint, /*l3_hint=*/hint,
+ /*layout=*/nullptr);
rewriter.replaceOp(readOp, loadOp);
return success();
@@ -575,21 +606,24 @@ struct TransferWriteLowering
if (!map.isMinorIdentity())
return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
+ auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+ rewriter, loc, writeOp.getBase(),
+ getAsOpFoldResult(writeOp.getIndices()), vecTy.getRank());
+
auto descType = xegpu::TensorDescType::get(
vecTy.getShape(), vecTy.getElementType(),
/*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
xegpu::MemorySpace::Global);
- xegpu::CreateNdDescOp ndDesc =
- createNdDescriptor(rewriter, loc, descType,
- dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
- writeOp.getIndices());
-
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
- auto storeOp =
- xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+ rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+
+ auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
+ ndDesc, indices,
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint,
+ /*layout=*/nullptr);
rewriter.replaceOp(writeOp, storeOp);
return success();
@@ -621,7 +655,8 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
/*chunk_size=*/IntegerAttr{},
/*l1_hint=*/xegpu::CachePolicyAttr{},
/*l2_hint=*/xegpu::CachePolicyAttr{},
- /*l3_hint=*/xegpu::CachePolicyAttr{});
+ /*l3_hint=*/xegpu::CachePolicyAttr{},
+ /*layout=*/nullptr);
auto selectOp =
arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
@@ -655,7 +690,8 @@ struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
/*chunk_size=*/IntegerAttr{},
/*l1_hint=*/xegpu::CachePolicyAttr{},
/*l2_hint=*/xegpu::CachePolicyAttr{},
- /*l3_hint=*/xegpu::CachePolicyAttr{});
+ /*l3_hint=*/xegpu::CachePolicyAttr{},
+ /*layout=*/nullptr);
rewriter.eraseOp(scatterOp);
return success();
}
@@ -674,19 +710,25 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
// Boundary check is available only for block instructions.
bool boundaryCheck = vecTy.getRank() > 1;
+ // By default, no specific caching policy is assigned.
+ xegpu::CachePolicyAttr hint = nullptr;
+
+ auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+ rewriter, loc, loadOp.getBase(), getAsOpFoldResult(loadOp.getIndices()),
+ vecTy.getRank());
auto descType = xegpu::TensorDescType::get(
vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
boundaryCheck, xegpu::MemorySpace::Global);
- xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
- rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
- // By default, no specific caching policy is assigned.
- xegpu::CachePolicyAttr hint = nullptr;
- auto loadNdOp = xegpu::LoadNdOp::create(
- rewriter, loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+ rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+ auto loadNdOp =
+ xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
+ /*packed=*/nullptr, /*transpose=*/nullptr,
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint,
+ /*layout=*/nullptr);
rewriter.replaceOp(loadOp, loadNdOp);
return success();
@@ -708,18 +750,25 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
// Boundary check is available only for block instructions.
bool boundaryCheck = vecTy.getRank() > 1;
+ auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+ rewriter, loc, storeOp.getBase(),
+ getAsOpFoldResult(storeOp.getIndices()), vecTy.getRank());
+
auto descType = xegpu::TensorDescType::get(
vecTy.getShape(), vecTy.getElementType(),
/*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
- xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
- rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
+ xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+ rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+
auto storeNdOp =
- xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
+ xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices,
/*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ /*l2_hint=*/hint, /*l3_hint=*/hint,
+ /*layout=*/nullptr);
+
rewriter.replaceOp(storeOp, storeNdOp);
return success();
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 33e8f2e..0ecb50e 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -50,11 +50,10 @@ static constexpr int32_t executionSize{16};
// Offsets to individual fields of the 8xi32 layout nd tensor descriptor.
enum class NdTdescOffset : uint32_t {
- BasePtr = 0, // Base pointer (i64)
- BaseShapeW = 2, // Base shape width (i32)
- BaseShapeH = 3, // Base shape height (i32)
- TensorOffsetW = 4, // Tensor offset W (i32)
- TensorOffsetH = 5 // Tensor offset H (i32)
+ BasePtr = 0, // Base pointer (i64)
+ BaseShapeW = 2, // Base shape width (i32)
+ BaseShapeH = 3, // Base shape height (i32)
+ BasePitch = 4, // Base pitch (i32)
};
static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
@@ -151,6 +150,14 @@ translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
}
}
+//
+// Note:
+// Block operations on tiles of sub-byte element types are handled by
+// emulating them with larger element types.
+// Tensor descriptors are kept intact; only the ops consuming them are
+// emulated.
+//
+
class CreateNdDescToXeVMPattern
: public OpConversionPattern<xegpu::CreateNdDescOp> {
using OpConversionPattern::OpConversionPattern;
@@ -179,16 +186,12 @@ class CreateNdDescToXeVMPattern
Value baseAddr;
Value baseShapeW;
Value baseShapeH;
- Value offsetW;
- Value offsetH;
// Source can be a memref or a pointer (ui64, ui32, i64 or i32).
SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
+ SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
// Descriptor shape is expected to be 2D.
int64_t rank = mixedSizes.size();
- if (rank != 2)
- return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
-
auto sourceTy = source.getType();
auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
// If source is a memref, we need to extract the aligned pointer as index.
@@ -197,10 +200,20 @@ class CreateNdDescToXeVMPattern
if (!sourceMemrefTy.hasRank()) {
return rewriter.notifyMatchFailure(op, "Expected ranked Memref.");
}
- baseAddr =
- memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
+      // Access the adaptor after the failure check to avoid rolling back the
+      // code generated for the materialization cast.
+ baseAddr = adaptor.getSource();
} else {
baseAddr = adaptor.getSource();
+ if (baseAddr.getType() != i64Ty) {
+ // Pointer type may be i32. Cast to i64 if needed.
+ baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
+ }
+ }
+ // 1D tensor descriptor is just the base address.
+ if (rank == 1) {
+ rewriter.replaceOp(op, baseAddr);
+ return success();
}
// Utility for creating offset values from op fold result.
auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
@@ -209,19 +222,11 @@ class CreateNdDescToXeVMPattern
val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
return val;
};
- // Offsets are not supported (0 is used).
- offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
- offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
// Get shape values from op fold results.
baseShapeW = createOffset(mixedSizes, 1);
baseShapeH = createOffset(mixedSizes, 0);
- if (sourceMemrefTy) {
- // Cast index to i64.
- baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
- } else if (baseAddr.getType() != i64Ty) {
- // Pointer type may be i32. Cast to i64 if needed.
- baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
- }
+ // Get pitch value from op fold results.
+ Value basePitch = createOffset(mixedStrides, 0);
// Populate payload.
Value payLoadAsI64 =
vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
@@ -235,12 +240,9 @@ class CreateNdDescToXeVMPattern
payload =
vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
static_cast<int>(NdTdescOffset::BaseShapeH));
- payload = vector::InsertOp::create(
- rewriter, loc, offsetW, payload,
- static_cast<int>(NdTdescOffset::TensorOffsetW));
- payload = vector::InsertOp::create(
- rewriter, loc, offsetH, payload,
- static_cast<int>(NdTdescOffset::TensorOffsetH));
+ payload =
+ vector::InsertOp::create(rewriter, loc, basePitch, payload,
+ static_cast<int>(NdTdescOffset::BasePitch));
rewriter.replaceOp(op, payload);
return success();
}
@@ -257,108 +259,240 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
ConversionPatternRewriter &rewriter) const override {
auto mixedOffsets = op.getMixedOffsets();
int64_t opOffsetsSize = mixedOffsets.size();
- if (opOffsetsSize != 2)
- return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
auto loc = op.getLoc();
auto ctxt = rewriter.getContext();
auto tdesc = adaptor.getTensorDesc();
auto tdescTy = op.getTensorDescType();
- if (tdescTy.getRank() != 2)
- return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
+ auto tileRank = tdescTy.getRank();
+ if (opOffsetsSize != tileRank)
+ return rewriter.notifyMatchFailure(
+ op, "Expected offset rank to match descriptor rank.");
auto elemType = tdescTy.getElementType();
auto elemBitSize = elemType.getIntOrFloatBitWidth();
- if (elemBitSize % 8 != 0)
+ bool isSubByte = elemBitSize < 8;
+ uint64_t wScaleFactor = 1;
+
+ if (!isSubByte && (elemBitSize % 8 != 0))
return rewriter.notifyMatchFailure(
op, "Expected element type bit width to be multiple of 8.");
+ auto tileW = tdescTy.getDimSize(tileRank - 1);
+  // For sub-byte types, only 4-bit elements are currently supported.
+ if (isSubByte) {
+ if (elemBitSize != 4)
+ return rewriter.notifyMatchFailure(
+ op, "Only sub byte types of 4bits are supported.");
+ if (tileRank != 2)
+ return rewriter.notifyMatchFailure(
+ op, "Sub byte types are only supported for 2D tensor descriptors.");
+ auto subByteFactor = 8 / elemBitSize;
+ auto tileH = tdescTy.getDimSize(0);
+ // Handle special case for packed load.
+ if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
+ if (op.getPacked().value_or(false)) {
+        // A packed load is emulated with packed loads of 8-bit elements.
+ if (tileH == systolicDepth * 4 &&
+ tileW == executionSize * subByteFactor) {
+          // Usage case: loading as Matrix B with a pack request.
+          // The source is assumed to be pre-packed into 8-bit elements.
+          // Emulate with 8-bit loads with a pack request.
+ // scaled_tileW = executionSize
+ elemType = rewriter.getIntegerType(8);
+ tileW = executionSize;
+ wScaleFactor = subByteFactor;
+ }
+ }
+ }
+    // If not handled by the packed-load case above, handle the other cases.
+ if (wScaleFactor == 1) {
+ auto sub16BitFactor = subByteFactor * 2;
+ if (tileW == executionSize * sub16BitFactor) {
+        // Usage case: loading as a Matrix A operand.
+        // Emulate with 16-bit loads/stores.
+ // scaled_tileW = executionSize
+ elemType = rewriter.getIntegerType(16);
+ tileW = executionSize;
+ wScaleFactor = sub16BitFactor;
+ } else {
+ return rewriter.notifyMatchFailure(
+ op, "Unsupported tile shape for sub byte types.");
+ }
+ }
+    // Recompute the element bit size for emulation.
+ elemBitSize = elemType.getIntOrFloatBitWidth();
+ }
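The shape checks above choose an emulation element type and a width scale factor. A minimal standalone sketch of that selection for 4-bit elements, in plain C++ rather than the MLIR rewriter code (the function name is made up; systolicDepth and executionSize are passed in rather than read from this file's constants):

// Sketch of the sub-byte (i4) emulation choice: returns {emulatedBits,
// emulatedTileW, wScaleFactor}, mirroring the two cases handled above.
#include <cstdint>
#include <optional>
#include <tuple>

static std::optional<std::tuple<unsigned, int64_t, uint64_t>>
pickI4Emulation(int64_t tileH, int64_t tileW, bool packedLoad,
                int64_t systolicDepth, int64_t executionSize) {
  const unsigned elemBits = 4;
  const uint64_t subByteFactor = 8 / elemBits;       // 2 i4 elements per byte
  const uint64_t sub16BitFactor = subByteFactor * 2; // 4 i4 elements per i16
  // Packed Matrix-B load: emulate with 8-bit packed loads.
  if (packedLoad && tileH == systolicDepth * 4 &&
      tileW == executionSize * static_cast<int64_t>(subByteFactor))
    return std::make_tuple(8u, executionSize, subByteFactor);
  // Matrix-A style load/store: emulate with 16-bit accesses.
  if (tileW == executionSize * static_cast<int64_t>(sub16BitFactor))
    return std::make_tuple(16u, executionSize, sub16BitFactor);
  return std::nullopt; // unsupported tile shape for sub-byte types
}
// E.g. with executionSize = 16, an i4 tile of width 64 (no packing) is
// emulated as 16 i16 elements per row with wScaleFactor = 4.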
- VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
- Value payLoadAsI64 =
- vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
- Value basePtr = vector::ExtractOp::create(
- rewriter, loc, payLoadAsI64, static_cast<int>(NdTdescOffset::BasePtr));
- Value baseShapeW = vector::ExtractOp::create(
- rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
- Value baseShapeH = vector::ExtractOp::create(
- rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
- // Offsets are provided by the op.
- // convert them to i32.
- Value offsetW =
- getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
- offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
- rewriter.getI32Type(), offsetW);
- Value offsetH =
- getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
- offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
- rewriter.getI32Type(), offsetH);
// Get address space from tensor descriptor memory space.
auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
- // Convert base pointer (i64) to LLVM pointer type.
- Value basePtrLLVM =
- LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
- // Compute element byte size and surface width in bytes.
- Value elemByteSize = arith::ConstantIntOp::create(
- rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
- Value surfaceW =
- arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
-
- // Get tile sizes and vblocks from the tensor descriptor type.
- auto tileW = tdescTy.getDimSize(1);
- auto tileH = tdescTy.getDimSize(0);
- int32_t vblocks = tdescTy.getArrayLength();
- if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
- Value src = adaptor.getValue();
- // If store value is a scalar, get value from op instead of adaptor.
- // Adaptor might have optimized away single element vector
- if (src.getType().isIntOrFloat()) {
- src = op.getValue();
+ if (tileRank == 2) {
+ // Compute element byte size.
+ Value elemByteSize = arith::ConstantIntOp::create(
+ rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
+ VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
+ Value payLoadAsI64 =
+ vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
+ Value basePtr =
+ vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
+ static_cast<int>(NdTdescOffset::BasePtr));
+ Value baseShapeW = vector::ExtractOp::create(
+ rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
+ Value baseShapeH = vector::ExtractOp::create(
+ rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
+ Value basePitch = vector::ExtractOp::create(
+ rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BasePitch));
+      // Offsets are provided by the op; convert them to i32.
+ Value offsetW =
+ getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
+ offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
+ rewriter.getI32Type(), offsetW);
+ Value offsetH =
+ getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+ offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
+ rewriter.getI32Type(), offsetH);
+ // Convert base pointer (i64) to LLVM pointer type.
+ Value basePtrLLVM =
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
+      // FIXME: The width/pitch is not the same as baseShapeW; it should be
+      // the stride of the second-to-last dimension in row-major layout.
+      // Compute the width in bytes.
+ Value baseShapeWInBytes =
+ arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
+ // Compute pitch in bytes.
+ Value basePitchBytes =
+ arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);
+
+ if (wScaleFactor > 1) {
+        // Scale offsetW, baseShapeWInBytes, and basePitchBytes for sub-byte
+        // emulation. Note: tileW is already scaled above.
+ Value wScaleFactorValLog2 = arith::ConstantIntOp::create(
+ rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(wScaleFactor));
+ baseShapeWInBytes = arith::ShRSIOp::create(
+ rewriter, loc, baseShapeWInBytes, wScaleFactorValLog2);
+ basePitchBytes = arith::ShRSIOp::create(rewriter, loc, basePitchBytes,
+ wScaleFactorValLog2);
+ offsetW =
+ arith::ShRSIOp::create(rewriter, loc, offsetW, wScaleFactorValLog2);
}
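The right shifts above divide the byte width, pitch, and W offset by wScaleFactor, which is always a power of two here. A quick standalone check of that arithmetic for the i4-emulated-as-i16 case (values are illustrative):

// Check: an i4 row of 128 elements occupies 64 bytes. After emulation the
// element byte size is 2 (i16) and wScaleFactor is 4, so the computed
// byte width must still come out to 64.
#include <cassert>
#include <cstdint>

int main() {
  int64_t baseShapeW = 128;   // row width in original i4 elements
  int64_t elemByteSize = 2;   // emulated i16 element size in bytes
  uint64_t wScaleFactor = 4;  // i4 elements folded into one i16
  int64_t widthInBytes = (baseShapeW * elemByteSize) >> 2; // log2(wScaleFactor)
  assert(widthInBytes == 64);
  int64_t offsetW = 32;                 // offset in i4 elements
  int64_t scaledOffsetW = offsetW >> 2; // offset in emulated i16 elements
  assert(scaledOffsetW * static_cast<int64_t>(wScaleFactor) == offsetW);
  return 0;
}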
- VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
- if (!srcVecTy)
- return rewriter.notifyMatchFailure(
- op, "Expected store value to be a vector type.");
- // Get flat vector type of integer type with matching element bit size.
- VectorType newSrcVecTy =
- encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
- if (srcVecTy != newSrcVecTy)
- src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
- auto storeCacheControl =
- translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
- xevm::BlockStore2dOp::create(
- rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
- offsetH, elemBitSize, tileW, tileH, src,
- xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
- rewriter.eraseOp(op);
- } else {
- auto loadCacheControl =
- translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
- if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
- xevm::BlockPrefetch2dOp::create(
- rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
- offsetH, elemBitSize, tileW, tileH, vblocks,
- xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+ // Get tile height from the tensor descriptor type.
+ auto tileH = tdescTy.getDimSize(0);
+ // Get vblocks from the tensor descriptor type.
+ int32_t vblocks = tdescTy.getArrayLength();
+ if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+ Value src = adaptor.getValue();
+        // If the store value is a scalar, get the value from the op instead
+        // of the adaptor; the adaptor might have optimized away a
+        // single-element vector.
+ if (src.getType().isIntOrFloat()) {
+ src = op.getValue();
+ }
+ VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
+ if (!srcVecTy)
+ return rewriter.notifyMatchFailure(
+ op, "Expected store value to be a vector type.");
+ // Get flat vector type of integer type with matching element bit size.
+ VectorType newSrcVecTy =
+ encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
+ if (srcVecTy != newSrcVecTy)
+ src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+ auto storeCacheControl =
+ translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+ xevm::BlockStore2dOp::create(
+ rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
+ basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, src,
+ xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
rewriter.eraseOp(op);
} else {
- VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
- const bool vnni = op.getPacked().value_or(false);
- auto transposeValue = op.getTranspose();
- bool transpose =
- transposeValue.has_value() && transposeValue.value()[0] == 1;
- VectorType loadedTy = encodeVectorTypeTo(
- dstVecTy, vnni ? rewriter.getI32Type()
- : rewriter.getIntegerType(elemBitSize));
-
- Value resultFlatVec = xevm::BlockLoad2dOp::create(
- rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
- surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
- transpose, vnni,
+ auto loadCacheControl =
+ translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+ if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
+ xevm::BlockPrefetch2dOp::create(
+ rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
+ basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH,
+ vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+ rewriter.eraseOp(op);
+ } else {
+ VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
+ const bool vnni = op.getPacked().value_or(false);
+ auto transposeValue = op.getTranspose();
+ bool transpose =
+ transposeValue.has_value() && transposeValue.value()[0] == 1;
+ VectorType loadedTy = encodeVectorTypeTo(
+ dstVecTy, vnni ? rewriter.getI32Type()
+ : rewriter.getIntegerType(elemBitSize));
+
+ Value resultFlatVec = xevm::BlockLoad2dOp::create(
+ rewriter, loc, loadedTy, basePtrLLVM, baseShapeWInBytes,
+ baseShapeH, basePitchBytes, offsetW, offsetH, elemBitSize, tileW,
+ tileH, vblocks, transpose, vnni,
+ xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+ resultFlatVec = vector::BitCastOp::create(
+ rewriter, loc,
+ encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
+ resultFlatVec);
+ rewriter.replaceOp(op, resultFlatVec);
+ }
+ }
+ } else {
+ // 1D tensor descriptor.
+      // `tdesc` represents the base address as i64.
+      // The offset is given in elements; multiply by the element byte size to
+      // get the byte offset:
+      //   byteOffset = offset * elemByteSize
+ Value offset =
+ getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+ offset = getValueOrCreateCastToIndexLike(rewriter, loc,
+ rewriter.getI64Type(), offset);
+ // Compute element byte size.
+ Value elemByteSize = arith::ConstantIntOp::create(
+ rewriter, loc, rewriter.getI64Type(), elemBitSize / 8);
+ Value byteOffset =
+ rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize);
+ // Final address = basePtr + byteOffset
+ Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
+ loc, tdesc,
+ getValueOrCreateCastToIndexLike(rewriter, loc, rewriter.getI64Type(),
+ byteOffset));
+ // Convert base pointer (i64) to LLVM pointer type.
+ Value finalPtrLLVM =
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
+ if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+ Value src = adaptor.getValue();
+        // If the store value is a scalar, get the value from the op instead
+        // of the adaptor; the adaptor might have optimized away a
+        // single-element vector.
+ if (src.getType().isIntOrFloat()) {
+ src = op.getValue();
+ }
+ VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
+ if (!srcVecTy)
+ return rewriter.notifyMatchFailure(
+ op, "Expected store value to be a vector type.");
+ // Get flat vector type of integer type with matching element bit size.
+ VectorType newSrcVecTy =
+ encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
+ if (srcVecTy != newSrcVecTy)
+ src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+ auto storeCacheControl =
+ translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+ rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
+ op, finalPtrLLVM, src,
+ xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
+ } else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
+ auto loadCacheControl =
+ translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+ VectorType resTy = cast<VectorType>(op.getValue().getType());
+ VectorType loadedTy =
+ encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
+ Value load = xevm::BlockLoadOp::create(
+ rewriter, loc, loadedTy, finalPtrLLVM,
xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
- resultFlatVec = vector::BitCastOp::create(
- rewriter, loc,
- encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
- resultFlatVec);
- rewriter.replaceOp(op, resultFlatVec);
+ if (loadedTy != resTy)
+ load = vector::BitCastOp::create(rewriter, loc, resTy, load);
+ rewriter.replaceOp(op, load);
+ } else {
+ return rewriter.notifyMatchFailure(
+ op, "Unsupported operation: xegpu.prefetch_nd with tensor "
+ "descriptor rank == 1");
}
}
return success();
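In the rank-1 path the lowered descriptor is just the base address, so the final pointer is a single byte-offset addition. A tiny standalone sketch of that computation (the helper name and values are illustrative):

// finalAddr = base + offset * elemByteSize, as computed above for the 1D case.
#include <cassert>
#include <cstdint>

static uint64_t compute1DAddress(uint64_t baseAddr, uint64_t offsetInElems,
                                 unsigned elemBitSize) {
  return baseAddr + offsetInElems * (elemBitSize / 8);
}

int main() {
  // E.g. 24 f16 elements past the base advance the pointer by 48 bytes.
  assert(compute1DAddress(/*baseAddr=*/0x1000, /*offsetInElems=*/24,
                          /*elemBitSize=*/16) == 0x1000 + 48);
  return 0;
}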
@@ -511,9 +645,6 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
}
};
-// Lower xegpu::CreateMemDescOp to memref::ViewOp. Since SLM access instructions
-// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than
-// 32 bits will be converted to 32 bits.
class CreateMemDescOpPattern final
: public OpConversionPattern<xegpu::CreateMemDescOp> {
public:
@@ -522,16 +653,7 @@ public:
matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto resTy = op.getMemDesc();
-
- // Create the result MemRefType with the same shape, element type, and
- // memory space
- auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
-
- Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
- auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
- op.getSource(), zero, ValueRange());
- rewriter.replaceOp(op, viewOp);
+ rewriter.replaceOp(op, adaptor.getSource());
return success();
}
};
@@ -551,17 +673,27 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
auto loc = op.getLoc();
auto ctxt = rewriter.getContext();
- Value basePtrStruct = adaptor.getMemDesc();
+ Value baseAddr32 = adaptor.getMemDesc();
Value mdescVal = op.getMemDesc();
// Load result or Store value Type can be vector or scalar.
- Value data;
- if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>)
- data = op.getResult();
- else
- data = adaptor.getData();
- VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
+ Type dataTy;
+ if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+ Type resType = op.getResult().getType();
+      // Some transforms may leave unit dimensions in the 2D vector; the
+      // adaptor does not catch this for results.
+ if (auto vecType = dyn_cast<VectorType>(resType)) {
+ assert(llvm::count_if(vecType.getShape(),
+ [](int64_t d) { return d != 1; }) <= 1 &&
+ "Expected either 1D vector or nD with unit dimensions");
+ resType = VectorType::get({vecType.getNumElements()},
+ vecType.getElementType());
+ }
+ dataTy = resType;
+ } else
+ dataTy = adaptor.getData().getType();
+ VectorType valOrResVecTy = dyn_cast<VectorType>(dataTy);
if (!valOrResVecTy)
- valOrResVecTy = VectorType::get(1, data.getType());
+ valOrResVecTy = VectorType::get(1, dataTy);
int64_t elemBitWidth =
valOrResVecTy.getElementType().getIntOrFloatBitWidth();
@@ -577,21 +709,14 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
- Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
- rewriter, loc, basePtrStruct);
-
- // Convert base pointer (ptr) to i32
- Value basePtrI32 = arith::IndexCastUIOp::create(
- rewriter, loc, rewriter.getI32Type(), basePtrLLVM);
-
Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
linearOffset = arith::IndexCastUIOp::create(
rewriter, loc, rewriter.getI32Type(), linearOffset);
- basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset,
- elemByteSize);
+ Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32,
+ linearOffset, elemByteSize);
// convert base pointer (i32) to LLVM pointer type
- basePtrLLVM =
+ Value basePtrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);
if (op.getSubgroupBlockIoAttr()) {
@@ -927,20 +1052,22 @@ struct ConvertXeGPUToXeVMPass
return VectorType::get(sum, elemType);
});
typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
+ // Scattered descriptors are not supported in XeVM lowering.
if (type.isScattered())
+ return {};
+ if (type.getRank() == 1)
return IntegerType::get(&getContext(), 64);
auto i32Type = IntegerType::get(&getContext(), 32);
return VectorType::get(8, i32Type);
});
- // Convert MemDescType into flattened MemRefType for SLM
+ // Convert MemDescType into i32 for SLM
typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
- Type elemTy = type.getElementType();
- int numElems = type.getNumElements();
- return MemRefType::get(numElems, elemTy, AffineMap(), 3);
+ return IntegerType::get(&getContext(), 32);
});
typeConverter.addConversion([&](MemRefType type) -> Type {
- // Convert MemRefType to i64 type.
+ if (type.getMemorySpaceAsInt() == 3)
+ return IntegerType::get(&getContext(), 32);
return IntegerType::get(&getContext(), 64);
});
@@ -1057,6 +1184,7 @@ struct ConvertXeGPUToXeVMPass
};
typeConverter.addSourceMaterialization(
singleElementVectorMaterializationCast);
+ typeConverter.addSourceMaterialization(vectorMaterializationCast);
typeConverter.addTargetMaterialization(memrefMaterializationCast);
typeConverter.addTargetMaterialization(ui32MaterializationCast);
typeConverter.addTargetMaterialization(ui64MaterializationCast);
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
index f276984..20a420d 100644
--- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -290,7 +290,7 @@ static LLVM::CallOp createDeviceFunctionCall(
ArrayRef<Type> argTypes, ArrayRef<Value> args,
mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
LLVMFuncAttributeOptions funcAttributeOptions, Operation *op) {
- auto moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
+ auto *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
assert(moduleOp && "Expecting module");
Location loc = op->getLoc();
@@ -401,7 +401,10 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
/*other=*/LLVM::ModRefInfo::NoModRef,
/*argMem=*/LLVM::ModRefInfo::NoModRef,
- /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
+ /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
+ /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
+ /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
+ /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
auto funcAttrs = convergentNoUnwindWillReturnAttrs;
funcAttrs.memEffectsAttr = memAttr;
Value result =
@@ -450,7 +453,10 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
/*other=*/LLVM::ModRefInfo::NoModRef,
/*argMem=*/LLVM::ModRefInfo::Ref,
- /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
+ /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
+ /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
+ /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
+ /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
funcAttr.memEffectsAttr = memAttr;
LLVM::CallOp call = createDeviceFunctionCall(
@@ -556,7 +562,10 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
/*other=*/LLVM::ModRefInfo::NoModRef,
/*argMem=*/LLVM::ModRefInfo::Ref,
- /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
+ /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
+ /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
+ /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
+ /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
funcAttr = noUnwindAttrs;
funcAttr.memEffectsAttr = memAttr;
} else {
@@ -798,7 +807,10 @@ class LaunchConfigOpToOCLPattern : public OpConversionPattern<OpType> {
constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
/*other=*/noModRef,
- /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
+ /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef,
+ /*errnoMem=*/noModRef,
+ /*targetMem0=*/noModRef,
+ /*targetMem1=*/noModRef);
call.setMemoryEffectsAttr(memAttr);
rewriter.replaceOp(op, call);
return success();
@@ -836,7 +848,10 @@ class SubgroupOpWorkitemOpToOCLPattern : public OpConversionPattern<OpType> {
constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
/*other=*/noModRef,
- /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
+ /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef,
+ /*errnoMem=*/noModRef,
+ /*targetMem0=*/noModRef,
+ /*targetMem1=*/noModRef);
call.setMemoryEffectsAttr(memAttr);
rewriter.replaceOp(op, call);
return success();
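Each of these call sites now spells out six NoModRef operands. If the pattern keeps recurring, a small local helper could build the attribute once; a sketch only, assuming the six-operand getAttr construction used in the hunks above (the helper name is made up and is not part of this patch):

// Hypothetical convenience helper; relies on the six-operand
// LLVM::MemoryEffectsAttr construction shown in the hunks above.
static LLVM::MemoryEffectsAttr getAllNoModRefMemoryEffects(Builder &builder) {
  constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
  return builder.getAttr<LLVM::MemoryEffectsAttr>(
      /*other=*/noModRef, /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef,
      /*errnoMem=*/noModRef, /*targetMem0=*/noModRef, /*targetMem1=*/noModRef);
}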