Diffstat (limited to 'mlir/lib/Conversion')
-rw-r--r--  mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 286
-rw-r--r--  mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp | 44
-rw-r--r--  mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp | 321
-rw-r--r--  mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp | 8
-rw-r--r--  mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp | 106
-rw-r--r--  mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp | 38
-rw-r--r--  mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 141
-rw-r--r--  mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp | 9
-rw-r--r--  mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 214
-rw-r--r--  mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 113
-rw-r--r--  mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp | 107
-rw-r--r--  mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp | 32
-rw-r--r--  mlir/lib/Conversion/CMakeLists.txt | 2
-rw-r--r--  mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp | 41
-rw-r--r--  mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp | 1
-rw-r--r--  mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp | 11
-rw-r--r--  mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp | 1
-rw-r--r--  mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp | 2
-rw-r--r--  mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp | 1
-rw-r--r--  mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp | 1
-rw-r--r--  mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp | 4
-rw-r--r--  mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 32
-rw-r--r--  mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp | 34
-rw-r--r--  mlir/lib/Conversion/LLVMCommon/Pattern.cpp | 15
-rw-r--r--  mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 2
-rw-r--r--  mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp | 13
-rw-r--r--  mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp | 3
-rw-r--r--  mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp | 7
-rw-r--r--  mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp | 78
-rw-r--r--  mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp | 36
-rw-r--r--  mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 15
-rw-r--r--  mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 53
-rw-r--r--  mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp | 9
-rw-r--r--  mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp | 2
-rw-r--r--  mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp | 2
-rw-r--r--  mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp | 3
-rw-r--r--  mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp | 2
-rw-r--r--  mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp | 12
-rw-r--r--  mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp | 1
-rw-r--r--  mlir/lib/Conversion/ShardToMPI/CMakeLists.txt (renamed from mlir/lib/Conversion/MeshToMPI/CMakeLists.txt) | 8
-rw-r--r--  mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp (renamed from mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp) | 146
-rw-r--r--  mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp | 1
-rw-r--r--  mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 134
-rw-r--r--  mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp | 158
-rw-r--r--  mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp | 1
-rw-r--r--  mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp | 76
-rw-r--r--  mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp | 75
-rw-r--r--  mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 71
-rw-r--r--  mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp | 13
-rw-r--r--  mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp | 5
-rw-r--r--  mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 64
-rw-r--r--  mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp | 1
52 files changed, 1295 insertions, 1260 deletions
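
The bulk of this diff is a mechanical builder-API migration: each templated call `rewriter.create<OpTy>(loc, args...)` becomes the equivalent static form `OpTy::create(rewriter, loc, args...)`. The operations constructed are identical; only the call spelling (plus clang-format reindentation) changes, though a few files such as ArithToAMDGPU.cpp also carry functional changes noted further below. A minimal sketch of the two spellings (the `loc` and `i32` values here are illustrative, not tied to any one hunk):

    // Old spelling: member-template builder on the rewriter.
    Value c0 = rewriter.create<LLVM::ConstantOp>(loc, i32, /*value=*/0);

    // New spelling: static create() taking the builder as its first argument.
    Value c0 = LLVM::ConstantOp::create(rewriter, loc, i32, /*value=*/0);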
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 309476c..64720bf 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -50,20 +50,20 @@ static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
if (i32 == valTy)
return val;
return valTy.getWidth() > 32
- ? Value(rewriter.create<LLVM::TruncOp>(loc, i32, val))
- : Value(rewriter.create<LLVM::ZExtOp>(loc, i32, val));
+ ? Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
+ : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
}
static Value createI32Constant(ConversionPatternRewriter &rewriter,
Location loc, int32_t value) {
Type i32 = rewriter.getI32Type();
- return rewriter.create<LLVM::ConstantOp>(loc, i32, value);
+ return LLVM::ConstantOp::create(rewriter, loc, i32, value);
}
static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
bool value) {
Type llvmI1 = rewriter.getI1Type();
- return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
+ return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value);
}
/// Returns the linear index used to access an element in the memref.
@@ -78,11 +78,11 @@ static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
ShapedType::isDynamic(stride)
? convertUnsignedToI32(rewriter, loc,
memRefDescriptor.stride(rewriter, loc, i))
- : rewriter.create<LLVM::ConstantOp>(loc, i32, stride);
- increment = rewriter.create<LLVM::MulOp>(loc, increment, strideValue);
+ : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
+ increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
}
- index =
- index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
+ index = index ? LLVM::AddOp::create(rewriter, loc, index, increment)
+ : increment;
}
return index ? index : createI32Constant(rewriter, loc, 0);
}
@@ -110,14 +110,14 @@ static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
Value size = memrefDescriptor.size(rewriter, loc, i);
Value stride = memrefDescriptor.stride(rewriter, loc, i);
- Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
+ Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
maxIndex = maxIndex
- ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
+ ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
: maxThisDim;
}
Value maxIndexI32 = convertUnsignedToI32(rewriter, loc, maxIndex);
Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
- return rewriter.create<LLVM::MulOp>(loc, maxIndexI32, byteWidthConst);
+ return LLVM::MulOp::create(rewriter, loc, maxIndexI32, byteWidthConst);
}
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
@@ -132,14 +132,14 @@ static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
Value stride;
if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {
Value cacheStrideZext =
- rewriter.create<LLVM::ZExtOp>(loc, i16, cacheSwizzleStride);
- Value swizzleBit = rewriter.create<LLVM::ConstantOp>(
- loc, i16, rewriter.getI16IntegerAttr(1 << 14));
- stride = rewriter.create<LLVM::OrOp>(loc, cacheStrideZext, swizzleBit,
- /*isDisjoint=*/true);
+ LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
+ Value swizzleBit = LLVM::ConstantOp::create(
+ rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
+ stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
+ /*isDisjoint=*/true);
} else {
- stride = rewriter.create<LLVM::ConstantOp>(loc, i16,
- rewriter.getI16IntegerAttr(0));
+ stride = LLVM::ConstantOp::create(rewriter, loc, i16,
+ rewriter.getI16IntegerAttr(0));
}
// Get the number of elements.
// Flag word:
@@ -209,20 +209,21 @@ struct FatRawBufferCastLowering
: descriptor.alignedPtr(rewriter, loc);
Value offset = adaptor.getResetOffset()
- ? rewriter.create<LLVM::ConstantOp>(
- loc, getIndexType(), rewriter.getIndexAttr(0))
+ ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ rewriter.getIndexAttr(0))
: descriptor.offset(rewriter, loc);
bool hasSizes = memrefType.getRank() > 0;
// No need to unpack() and pack() all the individual sizes and strides,
// so we'll just extract the arrays.
- Value sizes = hasSizes ? rewriter.create<LLVM::ExtractValueOp>(
- loc, descriptor, kSizePosInMemRefDescriptor)
- : Value{};
- Value strides = hasSizes
- ? rewriter.create<LLVM::ExtractValueOp>(
- loc, descriptor, kStridePosInMemRefDescriptor)
- : Value{};
+ Value sizes = hasSizes
+ ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
+ kSizePosInMemRefDescriptor)
+ : Value{};
+ Value strides =
+ hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
+ kStridePosInMemRefDescriptor)
+ : Value{};
Value fatPtr = makeBufferRsrc(
rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
@@ -231,17 +232,17 @@ struct FatRawBufferCastLowering
Value result = MemRefDescriptor::poison(
rewriter, loc,
getTypeConverter()->convertType(op.getResult().getType()));
- result = rewriter.create<LLVM::InsertValueOp>(
- loc, result, fatPtr, kAllocatedPtrPosInMemRefDescriptor);
- result = rewriter.create<LLVM::InsertValueOp>(
- loc, result, fatPtr, kAlignedPtrPosInMemRefDescriptor);
- result = rewriter.create<LLVM::InsertValueOp>(loc, result, offset,
- kOffsetPosInMemRefDescriptor);
+ SmallVector<int64_t> pos{kAllocatedPtrPosInMemRefDescriptor};
+ result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos);
+ result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
+ kAlignedPtrPosInMemRefDescriptor);
+ result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
+ kOffsetPosInMemRefDescriptor);
if (hasSizes) {
- result = rewriter.create<LLVM::InsertValueOp>(loc, result, sizes,
- kSizePosInMemRefDescriptor);
- result = rewriter.create<LLVM::InsertValueOp>(
- loc, result, strides, kStridePosInMemRefDescriptor);
+ result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes,
+ kSizePosInMemRefDescriptor);
+ result = LLVM::InsertValueOp::create(rewriter, loc, result, strides,
+ kStridePosInMemRefDescriptor);
}
rewriter.replaceOp(op, result);
return success();
@@ -342,8 +343,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
SmallVector<Value, 6> args;
if (storeData) {
if (llvmBufferValType != llvmWantedDataType) {
- Value castForStore =
- rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
+ Value castForStore = LLVM::BitcastOp::create(
+ rewriter, loc, llvmBufferValType, storeData);
args.push_back(castForStore);
} else {
args.push_back(storeData);
@@ -352,8 +353,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
if (atomicCmpData) {
if (llvmBufferValType != llvmWantedDataType) {
- Value castForCmp = rewriter.create<LLVM::BitcastOp>(
- loc, llvmBufferValType, atomicCmpData);
+ Value castForCmp = LLVM::BitcastOp::create(
+ rewriter, loc, llvmBufferValType, atomicCmpData);
args.push_back(castForCmp);
} else {
args.push_back(atomicCmpData);
@@ -382,18 +383,18 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
indexOffset && *indexOffset > 0) {
Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
- voffset =
- voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
- : extraOffsetConst;
+ voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
+ extraOffsetConst)
+ : extraOffsetConst;
}
- voffset = rewriter.create<LLVM::MulOp>(loc, voffset, byteWidthConst);
+ voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
args.push_back(voffset);
// SGPR offset.
Value sgprOffset = adaptor.getSgprOffset();
if (!sgprOffset)
sgprOffset = createI32Constant(rewriter, loc, 0);
- sgprOffset = rewriter.create<LLVM::MulOp>(loc, sgprOffset, byteWidthConst);
+ sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
args.push_back(sgprOffset);
// bit 0: GLC = 0 (atomics drop value, less coherency)
@@ -403,13 +404,13 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
llvmBufferValType);
- Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
- ArrayRef<NamedAttribute>());
+ Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
+ ArrayRef<NamedAttribute>());
if (lowered->getNumResults() == 1) {
Value replacement = lowered->getResult(0);
if (llvmBufferValType != llvmWantedDataType) {
- replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
- replacement);
+ replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
+ replacement);
}
rewriter.replaceOp(gpuOp, replacement);
} else {
@@ -480,16 +481,16 @@ struct MemoryCounterWaitOpLowering
if (chipset.majorVersion >= 12) {
Location loc = op.getLoc();
if (std::optional<int> ds = adaptor.getDs())
- rewriter.create<ROCDL::WaitDscntOp>(loc, *ds);
+ ROCDL::WaitDscntOp::create(rewriter, loc, *ds);
if (std::optional<int> load = adaptor.getLoad())
- rewriter.create<ROCDL::WaitLoadcntOp>(loc, *load);
+ ROCDL::WaitLoadcntOp::create(rewriter, loc, *load);
if (std::optional<int> store = adaptor.getStore())
- rewriter.create<ROCDL::WaitStorecntOp>(loc, *store);
+ ROCDL::WaitStorecntOp::create(rewriter, loc, *store);
if (std::optional<int> exp = adaptor.getExp())
- rewriter.create<ROCDL::WaitExpcntOp>(loc, *exp);
+ ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);
rewriter.eraseOp(op);
return success();
@@ -571,12 +572,12 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
<< chipset.majorVersion;
Location loc = op->getLoc();
- rewriter.create<ROCDL::SWaitcntOp>(loc, ldsOnlyBits);
+ ROCDL::SWaitcntOp::create(rewriter, loc, ldsOnlyBits);
rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op);
} else {
Location loc = op->getLoc();
- rewriter.create<ROCDL::WaitDscntOp>(loc, 0);
- rewriter.create<ROCDL::BarrierSignalOp>(loc, -1);
+ ROCDL::WaitDscntOp::create(rewriter, loc, 0);
+ ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
rewriter.replaceOpWithNewOp<ROCDL::BarrierWaitOp>(op, -1);
}
@@ -622,19 +623,21 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
Type inputType = input.getType();
if (auto vectorType = dyn_cast<VectorType>(inputType)) {
if (vectorType.getElementType().isBF16() && !allowBf16)
- return rewriter.create<LLVM::BitcastOp>(
- loc, vectorType.clone(rewriter.getI16Type()), input);
+ return LLVM::BitcastOp::create(
+ rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
if (vectorType.getElementType().isInteger(8) &&
vectorType.getNumElements() <= 8)
- return rewriter.create<LLVM::BitcastOp>(
- loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
+ return LLVM::BitcastOp::create(
+ rewriter, loc,
+ rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
if (isa<IntegerType>(vectorType.getElementType()) &&
vectorType.getElementTypeBitWidth() <= 8) {
int64_t numWords = llvm::divideCeil(
vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
32);
- return rewriter.create<LLVM::BitcastOp>(
- loc, VectorType::get(numWords, rewriter.getI32Type()), input);
+ return LLVM::BitcastOp::create(
+ rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
+ input);
}
}
return input;
@@ -655,8 +658,8 @@ static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
Type inputType = input.getType();
Type outputType = rewriter.getI32Type();
if (auto intType = dyn_cast<IntegerType>(inputType))
- return rewriter.create<LLVM::ZExtOp>(loc, outputType, input);
- return rewriter.create<LLVM::BitcastOp>(loc, outputType, input);
+ return LLVM::ZExtOp::create(rewriter, loc, outputType, input);
+ return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
}
/// Push an input operand. If it is a float type, nothing to do. If it is
@@ -682,8 +685,8 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
Type elemType = vectorType.getElementType();
if (elemType.isBF16())
- llvmInput = rewriter.create<LLVM::BitcastOp>(
- loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
+ llvmInput = LLVM::BitcastOp::create(
+ rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
if (elemType.getIntOrFloatBitWidth() > 8) {
operands.push_back(llvmInput);
return;
@@ -719,7 +722,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
// (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
// Add in the zeros here.
if (numBits < 32)
- castInput = rewriter.create<LLVM::ZExtOp>(loc, i32, castInput);
+ castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
operands.push_back(castInput);
}
@@ -739,8 +742,8 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
auto vectorType = dyn_cast<VectorType>(inputType);
Type elemType = vectorType.getElementType();
if (elemType.isBF16())
- output = rewriter.create<LLVM::BitcastOp>(
- loc, vectorType.clone(rewriter.getI16Type()), output);
+ output = LLVM::BitcastOp::create(
+ rewriter, loc, vectorType.clone(rewriter.getI16Type()), output);
operands.push_back(output);
if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
@@ -1098,7 +1101,7 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
};
Value lowered = rewriter.create(loweredOp)->getResult(0);
if (outType != intrinsicOutType)
- lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
+ lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
rewriter.replaceOp(op, lowered);
return success();
}
@@ -1198,8 +1201,8 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
Operation *maybeCastBack = lowered;
if (rawOutType != outType)
- maybeCastBack =
- rewriter.create<LLVM::BitcastOp>(loc, outType, lowered->getResult(0));
+ maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
+ lowered->getResult(0));
rewriter.replaceOp(op, maybeCastBack->getResults());
return success();
@@ -1249,22 +1252,22 @@ struct TransposeLoadOpLowering
switch (elementTypeSize) {
case 4: {
assert(numElements == 16);
- auto rocdlOp =
- rewriter.create<ROCDL::ds_read_tr4_b64>(loc, rocdlResultType, srcPtr);
+ auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
+ rocdlResultType, srcPtr);
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
break;
}
case 6: {
assert(numElements == 16);
- auto rocdlOp =
- rewriter.create<ROCDL::ds_read_tr6_b96>(loc, rocdlResultType, srcPtr);
+ auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
+ rocdlResultType, srcPtr);
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
break;
}
case 8: {
assert(numElements == 8);
- auto rocdlOp =
- rewriter.create<ROCDL::ds_read_tr8_b64>(loc, rocdlResultType, srcPtr);
+ auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
+ rocdlResultType, srcPtr);
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
break;
}
@@ -1422,21 +1425,21 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
Type sourceElemType = getElementTypeOrSelf(op.getSource());
// Extend to a v4i8
if (!sourceVecType || sourceVecType.getNumElements() < 4) {
- Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
+ Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
if (!sourceVecType) {
- longVec = rewriter.create<LLVM::InsertElementOp>(
- loc, longVec, source, createI32Constant(rewriter, loc, 0));
+ longVec = LLVM::InsertElementOp::create(
+ rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
} else {
for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
Value idx = createI32Constant(rewriter, loc, i);
- Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
+ Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
longVec =
- rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
+ LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
}
}
source = longVec;
}
- Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
+ Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
if (resultVecType) {
if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
@@ -1488,21 +1491,21 @@ LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
// Extend to a packedVectorType
if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
- Value longVec = rewriter.create<LLVM::ZeroOp>(loc, packedVecType);
+ Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
if (!sourceVecType) {
- longVec = rewriter.create<LLVM::InsertElementOp>(
- loc, longVec, source, createI32Constant(rewriter, loc, 0));
+ longVec = LLVM::InsertElementOp::create(
+ rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
} else {
for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
Value idx = createI32Constant(rewriter, loc, i);
- Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
+ Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
longVec =
- rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
+ LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
}
}
source = longVec;
}
- Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
+ Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
@@ -1560,54 +1563,57 @@ LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
Value scale = adaptor.getScale();
Value existing = adaptor.getExisting();
if (existing)
- existing = rewriter.create<LLVM::BitcastOp>(loc, intResultType, existing);
+ existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
else
- existing = rewriter.create<LLVM::ZeroOp>(loc, intResultType);
+ existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);
if (sourceVecType.getNumElements() < 2) {
Value c0 = createI32Constant(rewriter, loc, 0);
- Value elem0 = rewriter.create<LLVM::ExtractElementOp>(loc, source, c0);
+ Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
VectorType v2 = VectorType::get(2, sourceElemType);
- source = rewriter.create<LLVM::ZeroOp>(loc, v2);
- source = rewriter.create<LLVM::InsertElementOp>(loc, source, elem0, c0);
+ source = LLVM::ZeroOp::create(rewriter, loc, v2);
+ source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
}
Value sourceA, sourceB;
if (sourceElemType.isF32()) {
Value c0 = createI32Constant(rewriter, loc, 0);
Value c1 = createI32Constant(rewriter, loc, 1);
- sourceA = rewriter.create<LLVM::ExtractElementOp>(loc, source, c0);
- sourceB = rewriter.create<LLVM::ExtractElementOp>(loc, source, c1);
+ sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
+ sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
}
Value result;
if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkBf8F32Op>(
- loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
+ existing, sourceA, sourceB,
+ scale, op.getIndex());
else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkBf8F16Op>(
- loc, intResultType, existing, source, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkBf8F16Op::create(
+ rewriter, loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkBf8Bf16Op>(
- loc, intResultType, existing, source, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
+ rewriter, loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkFp8F32Op>(
- loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
+ existing, sourceA, sourceB,
+ scale, op.getIndex());
else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkFp8F16Op>(
- loc, intResultType, existing, source, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkFp8F16Op::create(
+ rewriter, loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkFp8Bf16Op>(
- loc, intResultType, existing, source, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
+ rewriter, loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkFp4F32Op>(
- loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
+ existing, sourceA, sourceB,
+ scale, op.getIndex());
else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkFp4F16Op>(
- loc, intResultType, existing, source, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkFp4F16Op::create(
+ rewriter, loc, intResultType, existing, source, scale, op.getIndex());
else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
- result = rewriter.create<ROCDL::CvtScaleF32PkFp4Bf16Op>(
- loc, intResultType, existing, source, scale, op.getIndex());
+ result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
+ rewriter, loc, intResultType, existing, source, scale, op.getIndex());
else
return failure();
@@ -1632,20 +1638,20 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
Value sourceA = adaptor.getSourceA();
Value sourceB = adaptor.getSourceB();
if (!sourceB)
- sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
+ sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType());
Value existing = adaptor.getExisting();
if (existing)
- existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
+ existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
else
- existing = rewriter.create<LLVM::UndefOp>(loc, i32);
+ existing = LLVM::UndefOp::create(rewriter, loc, i32);
Value result;
if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
- result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
- existing, op.getWordIndex());
+ result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
+ existing, op.getWordIndex());
else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
- result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
- existing, op.getWordIndex());
+ result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
+ existing, op.getWordIndex());
result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
op, getTypeConverter()->convertType(resultType), result);
@@ -1669,17 +1675,17 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
Value stoch = adaptor.getStochiasticParam();
Value existing = adaptor.getExisting();
if (existing)
- existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
+ existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
else
- existing = rewriter.create<LLVM::UndefOp>(loc, i32);
+ existing = LLVM::UndefOp::create(rewriter, loc, i32);
Value result;
if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
- result = rewriter.create<ROCDL::CvtSrBf8F32Op>(
- loc, i32, source, stoch, existing, op.getStoreIndex());
+ result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
+ existing, op.getStoreIndex());
else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
- result = rewriter.create<ROCDL::CvtSrFp8F32Op>(
- loc, i32, source, stoch, existing, op.getStoreIndex());
+ result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
+ existing, op.getStoreIndex());
result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
op, getTypeConverter()->convertType(resultType), result);
@@ -1723,14 +1729,15 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
if (operandType.getIntOrFloatBitWidth() <= 16) {
if (llvm::isa<FloatType>(operandType)) {
operand =
- rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
+ LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
}
auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
- Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
- operand = rewriter.create<LLVM::InsertElementOp>(
- loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
- operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand);
+ Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
+ operand =
+ LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
+ createI32Constant(rewriter, loc, 0));
+ operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
}
return operand;
};
@@ -1817,14 +1824,15 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();
// create a ROCDL_DPPMovOp instruction with the appropriate attributes
- auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
- loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
+ auto dppMovOp =
+ ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
+ rowMask, bankMask, boundCtrl);
Value result = dppMovOp.getRes();
if (srcType.getIntOrFloatBitWidth() < 32) {
- result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
+ result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result);
if (!llvm::isa<IntegerType>(srcType)) {
- result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
+ result = LLVM::BitcastOp::create(rewriter, loc, srcType, result);
}
}
@@ -1858,7 +1866,7 @@ struct AMDGPUSwizzleBitModeLowering
SmallVector<Value> swizzled;
for (Value v : decomposed) {
Value res =
- rewriter.create<ROCDL::DsSwizzleOp>(loc, v.getType(), v, maskValue);
+ ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
swizzled.emplace_back(res);
}
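
One detail worth noting from the AMDGPUToROCDL changes above: the untemplated overload that builds an operation from an `OperationState` (as in `rewriter.create(loweredOp)` inside `MFMAOpLowering`) is left untouched, since there is no concrete op class on which to call a static `create`. A hedged sketch of that surviving pattern; the `intrinsicName`, `operands`, and `resultTypes` names are assumptions for illustration:

    // Ops assembled dynamically via OperationState keep the member form;
    // only calls templated on a concrete op class are migrated.
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addOperands(operands);
    loweredOp.addTypes(resultTypes);
    Value lowered = rewriter.create(loweredOp)->getResult(0);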
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index 3b143ca..3b148f9 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -50,9 +50,9 @@ static Value buildMinMaxReductionSeq(Location loc,
Value value = *valueIt++;
for (; valueIt != values.end(); ++valueIt) {
if (predicate == arith::CmpIPredicate::sgt)
- value = builder.create<arith::MaxSIOp>(loc, value, *valueIt);
+ value = arith::MaxSIOp::create(builder, loc, value, *valueIt);
else
- value = builder.create<arith::MinSIOp>(loc, value, *valueIt);
+ value = arith::MinSIOp::create(builder, loc, value, *valueIt);
}
return value;
@@ -154,9 +154,9 @@ public:
Value lowerBound = lowerAffineLowerBound(op, rewriter);
Value upperBound = lowerAffineUpperBound(op, rewriter);
Value step =
- rewriter.create<arith::ConstantIndexOp>(loc, op.getStepAsInt());
- auto scfForOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound,
- step, op.getInits());
+ arith::ConstantIndexOp::create(rewriter, loc, op.getStepAsInt());
+ auto scfForOp = scf::ForOp::create(rewriter, loc, lowerBound, upperBound,
+ step, op.getInits());
rewriter.eraseBlock(scfForOp.getBody());
rewriter.inlineRegionBefore(op.getRegion(), scfForOp.getRegion(),
scfForOp.getRegion().end());
@@ -197,7 +197,7 @@ public:
}
steps.reserve(op.getSteps().size());
for (int64_t step : op.getSteps())
- steps.push_back(rewriter.create<arith::ConstantIndexOp>(loc, step));
+ steps.push_back(arith::ConstantIndexOp::create(rewriter, loc, step));
// Get the terminator op.
auto affineParOpTerminator =
@@ -205,9 +205,9 @@ public:
scf::ParallelOp parOp;
if (op.getResults().empty()) {
// Case with no reduction operations/return values.
- parOp = rewriter.create<scf::ParallelOp>(loc, lowerBoundTuple,
- upperBoundTuple, steps,
- /*bodyBuilderFn=*/nullptr);
+ parOp = scf::ParallelOp::create(rewriter, loc, lowerBoundTuple,
+ upperBoundTuple, steps,
+ /*bodyBuilderFn=*/nullptr);
rewriter.eraseBlock(parOp.getBody());
rewriter.inlineRegionBefore(op.getRegion(), parOp.getRegion(),
parOp.getRegion().end());
@@ -233,9 +233,9 @@ public:
identityVals.push_back(
arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc));
}
- parOp = rewriter.create<scf::ParallelOp>(
- loc, lowerBoundTuple, upperBoundTuple, steps, identityVals,
- /*bodyBuilderFn=*/nullptr);
+ parOp = scf::ParallelOp::create(rewriter, loc, lowerBoundTuple,
+ upperBoundTuple, steps, identityVals,
+ /*bodyBuilderFn=*/nullptr);
// Copy the body of the affine.parallel op.
rewriter.eraseBlock(parOp.getBody());
@@ -261,7 +261,7 @@ public:
Value reductionResult = arith::getReductionOp(
reductionOpValue, rewriter, loc, reductionBody.getArgument(0),
reductionBody.getArgument(1));
- rewriter.create<scf::ReduceReturnOp>(loc, reductionResult);
+ scf::ReduceReturnOp::create(rewriter, loc, reductionResult);
}
rewriter.replaceOp(op, parOp.getResults());
return success();
@@ -278,7 +278,7 @@ public:
// Now we just have to handle the condition logic.
auto integerSet = op.getIntegerSet();
- Value zeroConstant = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value zeroConstant = arith::ConstantIndexOp::create(rewriter, loc, 0);
SmallVector<Value, 8> operands(op.getOperands());
auto operandsRef = llvm::ArrayRef(operands);
@@ -298,18 +298,18 @@ public:
auto pred =
isEquality ? arith::CmpIPredicate::eq : arith::CmpIPredicate::sge;
Value cmpVal =
- rewriter.create<arith::CmpIOp>(loc, pred, affResult, zeroConstant);
- cond = cond
- ? rewriter.create<arith::AndIOp>(loc, cond, cmpVal).getResult()
- : cmpVal;
+ arith::CmpIOp::create(rewriter, loc, pred, affResult, zeroConstant);
+ cond =
+ cond ? arith::AndIOp::create(rewriter, loc, cond, cmpVal).getResult()
+ : cmpVal;
}
cond = cond ? cond
- : rewriter.create<arith::ConstantIntOp>(loc, /*value=*/1,
- /*width=*/1);
+ : arith::ConstantIntOp::create(rewriter, loc, /*value=*/1,
+ /*width=*/1);
bool hasElseRegion = !op.getElseRegion().empty();
- auto ifOp = rewriter.create<scf::IfOp>(loc, op.getResultTypes(), cond,
- hasElseRegion);
+ auto ifOp = scf::IfOp::create(rewriter, loc, op.getResultTypes(), cond,
+ hasElseRegion);
rewriter.inlineRegionBefore(op.getThenRegion(),
&ifOp.getThenRegion().back());
rewriter.eraseBlock(&ifOp.getThenRegion().back());
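
Besides the builder migration, the ArithToAMDGPU diff that follows threads an explicit `PatternBenefit` through the pattern constructors and drops constructors that merely forwarded the context. Callers would then supply the benefit at registration time; a hedged sketch of such a call site (not part of this diff, and the helper's name is hypothetical):

    // Hypothetical registration helper; only the pattern's constructor
    // signature (ctx, chipset, benefit) is taken from the diff below.
    static void addExtFPatterns(RewritePatternSet &patterns, Chipset chipset,
                                PatternBenefit benefit) {
      patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset,
                                               benefit);
    }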
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 156c679..8230591 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -49,8 +49,9 @@ struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
using OpRewritePattern::OpRewritePattern;
Chipset chipset;
- ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
- : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
+ ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset,
+ PatternBenefit benefit)
+ : OpRewritePattern::OpRewritePattern(ctx, benefit), chipset(chipset) {}
LogicalResult matchAndRewrite(arith::ExtFOp op,
PatternRewriter &rewriter) const override;
@@ -59,9 +60,9 @@ struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
bool saturateFP8 = false;
TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
- Chipset chipset)
- : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
- chipset(chipset) {}
+ Chipset chipset, PatternBenefit benefit)
+ : OpRewritePattern::OpRewritePattern(ctx, benefit),
+ saturateFP8(saturateFP8), chipset(chipset) {}
Chipset chipset;
LogicalResult matchAndRewrite(arith::TruncFOp op,
@@ -81,9 +82,6 @@ struct ScalingExtFRewritePattern final
: OpRewritePattern<arith::ScalingExtFOp> {
using OpRewritePattern::OpRewritePattern;
- ScalingExtFRewritePattern(MLIRContext *ctx)
- : OpRewritePattern::OpRewritePattern(ctx) {}
-
LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
PatternRewriter &rewriter) const override;
};
@@ -92,9 +90,6 @@ struct ScalingTruncFRewritePattern final
: OpRewritePattern<arith::ScalingTruncFOp> {
using OpRewritePattern::OpRewritePattern;
- ScalingTruncFRewritePattern(MLIRContext *ctx)
- : OpRewritePattern::OpRewritePattern(ctx) {}
-
LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
PatternRewriter &rewriter) const override;
};
@@ -115,9 +110,9 @@ static Value castF32To(Type desType, Value f32, Location loc,
if (elementType.isF32())
return f32;
if (elementType.getIntOrFloatBitWidth() < 32)
- return rewriter.create<arith::TruncFOp>(loc, desType, f32);
+ return arith::TruncFOp::create(rewriter, loc, desType, f32);
if (elementType.getIntOrFloatBitWidth() > 32)
- return rewriter.create<arith::ExtFOp>(loc, desType, f32);
+ return arith::ExtFOp::create(rewriter, loc, desType, f32);
llvm_unreachable("The only 32-bit float type is f32");
}
@@ -139,27 +134,27 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
VectorType extResType = VectorType::get(2, rewriter.getF32Type());
if (!inVecType) {
- Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
- loc, rewriter.getF32Type(), in, 0);
+ Value asFloat = amdgpu::ExtPackedFp8Op::create(
+ rewriter, loc, rewriter.getF32Type(), in, 0);
Value result = castF32To(outElemType, asFloat, loc, rewriter);
rewriter.replaceOp(op, result);
return success();
}
int64_t numElements = inVecType.getNumElements();
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
+ Value zero = arith::ConstantOp::create(
+ rewriter, loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
VectorType outType = cast<VectorType>(op.getOut().getType());
if (inVecType.getShape().empty()) {
Value zerodSplat =
rewriter.createOrFold<vector::BroadcastOp>(loc, outType, zero);
Value scalarIn =
- rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
+ vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{});
Value scalarExt =
- rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
- Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zerodSplat,
- ArrayRef<int64_t>{});
+ arith::ExtFOp::create(rewriter, loc, outElemType, scalarIn);
+ Value result = vector::InsertOp::create(rewriter, loc, scalarExt,
+ zerodSplat, ArrayRef<int64_t>{});
rewriter.replaceOp(op, result);
return success();
}
@@ -171,32 +166,32 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
if (inVecType.getRank() > 1) {
inVecType = VectorType::get(SmallVector<int64_t>{numElements},
inVecType.getElementType());
- in = rewriter.create<vector::ShapeCastOp>(loc, inVecType, in);
+ in = vector::ShapeCastOp::create(rewriter, loc, inVecType, in);
}
for (int64_t i = 0; i < numElements; i += 4) {
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
- Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, in, i, elemsThisOp, 1);
+ Value inSlice = vector::ExtractStridedSliceOp::create(rewriter, loc, in, i,
+ elemsThisOp, 1);
for (int64_t j = 0; j < elemsThisOp; j += 2) {
if (i + j + 1 < numElements) { // Convert two 8-bit elements
- Value asFloats = rewriter.create<amdgpu::ExtPackedFp8Op>(
- loc, extResType, inSlice, j / 2);
+ Value asFloats = amdgpu::ExtPackedFp8Op::create(
+ rewriter, loc, extResType, inSlice, j / 2);
Type desType = VectorType::get(2, outElemType);
Value asType = castF32To(desType, asFloats, loc, rewriter);
- result = rewriter.create<vector::InsertStridedSliceOp>(
- loc, asType, result, i + j, 1);
+ result = vector::InsertStridedSliceOp::create(rewriter, loc, asType,
+ result, i + j, 1);
} else { // Convert an 8-bit element
- Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
- loc, rewriter.getF32Type(), inSlice, j / 2 * 2);
+ Value asFloat = amdgpu::ExtPackedFp8Op::create(
+ rewriter, loc, rewriter.getF32Type(), inSlice, j / 2 * 2);
Value asType = castF32To(outElemType, asFloat, loc, rewriter);
- result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
+ result = vector::InsertOp::create(rewriter, loc, asType, result, i + j);
}
}
}
if (inVecType.getRank() != outType.getRank()) {
- result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
+ result = vector::ShapeCastOp::create(rewriter, loc, outType, result);
}
rewriter.replaceOp(op, result);
@@ -208,9 +203,9 @@ static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
if (type.isF32())
return value;
if (type.getIntOrFloatBitWidth() < 32)
- return rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), value);
+ return arith::ExtFOp::create(rewriter, loc, rewriter.getF32Type(), value);
if (type.getIntOrFloatBitWidth() > 32)
- return rewriter.create<arith::TruncFOp>(loc, rewriter.getF32Type(), value);
+ return arith::TruncFOp::create(rewriter, loc, rewriter.getF32Type(), value);
llvm_unreachable("The only 32-bit float type is f32");
}
@@ -250,13 +245,15 @@ static Value clampInput(PatternRewriter &rewriter, Location loc,
loc, arith::CmpFPredicate::OEQ, source, negInf);
Value isNan = rewriter.createOrFold<arith::CmpFOp>(
loc, arith::CmpFPredicate::UNO, source, source);
- Value isNonFinite = rewriter.create<arith::OrIOp>(
- loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan);
+ Value isNonFinite = arith::OrIOp::create(
+ rewriter, loc, arith::OrIOp::create(rewriter, loc, isInf, isNegInf),
+ isNan);
- Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst);
- Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
+ Value clampedBelow = arith::MaximumFOp::create(rewriter, loc, source, minCst);
+ Value clamped =
+ arith::MinimumFOp::create(rewriter, loc, clampedBelow, maxCst);
Value res =
- rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped);
+ arith::SelectOp::create(rewriter, loc, isNonFinite, source, clamped);
return res;
}
@@ -290,25 +287,25 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
VectorType truncResType = VectorType::get(4, outElemType);
if (!inVectorTy) {
Value asFloat = castToF32(in, loc, rewriter);
- Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
- loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
+ Value asF8s = amdgpu::PackedTrunc2xFp8Op::create(
+ rewriter, loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
/*existing=*/nullptr);
- Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
+ Value result = vector::ExtractOp::create(rewriter, loc, asF8s, 0);
rewriter.replaceOp(op, result);
return success();
}
int64_t numElements = outVecType.getNumElements();
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
+ Value zero = arith::ConstantOp::create(
+ rewriter, loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
if (outVecType.getShape().empty()) {
Value scalarIn =
- rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
+ vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{});
// Recurse to send the 0-D vector case to the 1-D vector case
Value scalarTrunc =
- rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
- Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
- ArrayRef<int64_t>{});
+ arith::TruncFOp::create(rewriter, loc, outElemType, scalarIn);
+ Value result = vector::InsertOp::create(rewriter, loc, scalarTrunc, zero,
+ ArrayRef<int64_t>{});
rewriter.replaceOp(op, result);
return success();
}
@@ -320,32 +317,32 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
if (inVectorTy.getRank() > 1) {
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
inVectorTy.getElementType());
- in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
+ in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in);
}
for (int64_t i = 0; i < numElements; i += 4) {
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
Value thisResult = nullptr;
for (int64_t j = 0; j < elemsThisOp; j += 2) {
- Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j);
+ Value elemA = vector::ExtractOp::create(rewriter, loc, in, i + j);
Value asFloatA = castToF32(elemA, loc, rewriter);
Value asFloatB = nullptr;
if (j + 1 < elemsThisOp) {
- Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1);
+ Value elemB = vector::ExtractOp::create(rewriter, loc, in, i + j + 1);
asFloatB = castToF32(elemB, loc, rewriter);
}
- thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
- loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
+ thisResult = amdgpu::PackedTrunc2xFp8Op::create(
+ rewriter, loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
}
if (elemsThisOp < 4)
- thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, thisResult, 0, elemsThisOp, 1);
- result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
- result, i, 1);
+ thisResult = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, thisResult, 0, elemsThisOp, 1);
+ result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult,
+ result, i, 1);
}
if (inVectorTy.getRank() != outVecType.getRank()) {
- result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result);
+ result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result);
}
rewriter.replaceOp(op, result);
@@ -373,10 +370,10 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
// Handle the case where input type is not a vector type
if (!inVectorTy) {
- auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
+ auto sourceB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type());
Value asF16s =
- rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
- Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0);
+ ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, in, sourceB);
+ Value result = vector::ExtractOp::create(rewriter, loc, asF16s, 0);
rewriter.replaceOp(op, result);
return success();
}
@@ -389,7 +386,7 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
if (inVectorTy.getRank() > 1) {
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
inVectorTy.getElementType());
- in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
+ in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in);
}
// Handle the vector case. We also handle the (uncommon) case where the vector
@@ -397,25 +394,25 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
for (int64_t i = 0; i < numElements; i += 2) {
int64_t elemsThisOp = std::min(numElements, i + 2) - i;
Value thisResult = nullptr;
- Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i);
- Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
+ Value elemA = vector::ExtractOp::create(rewriter, loc, in, i);
+ Value elemB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type());
if (elemsThisOp == 2) {
- elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1);
+ elemB = vector::ExtractOp::create(rewriter, loc, in, i + 1);
}
thisResult =
- rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB);
+ ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, elemA, elemB);
// Place back the truncated result into the possibly larger vector. If we
// are operating on a size 2 vector, these operations should be folded away
- thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, thisResult, 0, elemsThisOp, 1);
- result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
- result, i, 1);
+ thisResult = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, thisResult, 0, elemsThisOp, 1);
+ result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult,
+ result, i, 1);
}
if (inVectorTy.getRank() != outVecType.getRank()) {
- result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result);
+ result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result);
}
rewriter.replaceOp(op, result);
@@ -452,7 +449,7 @@ LogicalResult
ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
PatternRewriter &rewriter) const {
Location loc = op.getLoc();
- constexpr int64_t opWidth = 2;
+ constexpr int64_t opOutWidth = 2;
Value in = op.getIn();
Value scale = op.getScale();
@@ -463,6 +460,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
Type scaleType = getElementTypeOrSelf(scale);
Type outType = getElementTypeOrSelf(out);
+ int64_t opInWidth = 32 / inType.getIntOrFloatBitWidth();
+
VectorType outVecType = dyn_cast<VectorType>(out.getType());
VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
@@ -472,28 +471,29 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
Type scaleF32Type =
scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
if (scaleType.getIntOrFloatBitWidth() < 32)
- scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
+ scale = arith::ExtFOp::create(rewriter, loc, scaleF32Type, scale);
else if (scaleType.getIntOrFloatBitWidth() > 32)
- scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);
+ scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);
- VectorType extScaleResultType = VectorType::get(opWidth, outType);
+ VectorType extScaleResultType = VectorType::get(opOutWidth, outType);
if (!outVecType) {
- Value inCast = rewriter.create<vector::BroadcastOp>(
- loc, VectorType::get(1, inType), in);
+ Value inCast = vector::BroadcastOp::create(rewriter, loc,
+ VectorType::get(1, inType), in);
// TODO: replace this with non-packed ScaledExtOp
- Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
- loc, extScaleResultType, inCast, scale, 0);
+ Value scaleExt = amdgpu::ScaledExtPackedOp::create(
+ rewriter, loc, extScaleResultType, inCast, scale, 0);
scaleExt = rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleExt, 0);
return success();
}
VectorType inVecType = cast<VectorType>(in.getType());
Value origScale = getOriginalVectorValue(op.getScale());
+ VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());
ArrayRef<int64_t> inShape = inVecType.getShape();
SmallVector<int64_t> originalScaleShape;
- if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
+ if (origScaleVecType)
llvm::append_range(originalScaleShape, origScaleVecType.getShape());
originalScaleShape.insert(originalScaleShape.end(),
@@ -508,45 +508,52 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
int64_t blockSize = computeProduct(ratio);
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, outType, rewriter.getFloatAttr(outType, 0.0));
+ Value zero = arith::ConstantOp::create(rewriter, loc, outType,
+ rewriter.getFloatAttr(outType, 0.0));
Value result =
rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
SmallVector<int64_t> strides(offsets.size(), 1);
- Value block = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, in, offsets, ratio, strides);
+ Value block = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, in, offsets, ratio, strides);
VectorType block1DType = VectorType::get(blockSize, inType);
Value block1D =
- rewriter.create<vector::ShapeCastOp>(loc, block1DType, block);
+ vector::ShapeCastOp::create(rewriter, loc, block1DType, block);
Value uniformScale =
- rewriter.create<vector::ExtractOp>(loc, scale, offsets);
+ vector::ExtractOp::create(rewriter, loc, scale, offsets);
VectorType blockResultType = VectorType::get(blockSize, outType);
Value blockResult =
rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
- for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
+ for (int64_t i = 0, inSliceWidth = std::min(opInWidth, blockSize - i);
i < blockSize;
- i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
- Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, block1D, i, sliceWidth, 1);
- // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
- Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
- loc, extScaleResultType, slice, uniformScale, 0);
- if (sliceWidth != opWidth)
- scaleExt = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, scaleExt, 0, sliceWidth, 1);
- blockResult = rewriter.create<vector::InsertStridedSliceOp>(
- loc, scaleExt, blockResult, i, 1);
+ i += inSliceWidth, inSliceWidth = std::min(opInWidth, blockSize - i)) {
+ Value inSlice = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, block1D, i, inSliceWidth, 1);
+ for (int64_t j = 0,
+ outSliceWidth = std::min(opOutWidth, inSliceWidth - j);
+ j < inSliceWidth; j += outSliceWidth,
+ outSliceWidth = std::min(opOutWidth, inSliceWidth - j)) {
+ // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
+ Value scaleExt = amdgpu::ScaledExtPackedOp::create(
+ rewriter, loc, extScaleResultType, inSlice, uniformScale,
+ j / opOutWidth);
+ if (outSliceWidth < opOutWidth) {
+ scaleExt = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, scaleExt, 0, outSliceWidth, 1);
+ }
+ blockResult = vector::InsertStridedSliceOp::create(
+ rewriter, loc, scaleExt, blockResult, i + j, 1);
+ }
}
VectorType resultType = VectorType::get(ratio, outType);
Value cast =
- rewriter.create<vector::ShapeCastOp>(loc, resultType, blockResult);
- result = rewriter.create<vector::InsertStridedSliceOp>(loc, cast, result,
- offsets, strides);
+ vector::ShapeCastOp::create(rewriter, loc, resultType, blockResult);
+ result = vector::InsertStridedSliceOp::create(rewriter, loc, cast, result,
+ offsets, strides);
}
rewriter.replaceOp(op, result);
@@ -558,7 +565,7 @@ LogicalResult
ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
PatternRewriter &rewriter) const {
Location loc = op.getLoc();
- constexpr int64_t opWidth = 2;
+ constexpr int64_t opInWidth = 2;
Value in = op.getIn();
Value scale = op.getScale();
@@ -571,28 +578,28 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
VectorType outVecType = dyn_cast<VectorType>(out.getType());
VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
-
if (outVecType && outVecType.isScalable())
return failure();
Type scaleF32Type =
scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
if (scaleType.getIntOrFloatBitWidth() < 32)
- scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
+ scale = arith::ExtFOp::create(rewriter, loc, scaleF32Type, scale);
else if (scaleType.getIntOrFloatBitWidth() > 32)
- scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);
+ scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, outType, rewriter.getFloatAttr(outType, 0.0));
- unsigned numPackedElem = 32 / outType.getIntOrFloatBitWidth();
- VectorType truncScaleResultType = VectorType::get(numPackedElem, outType);
+ Value zero = arith::ConstantOp::create(rewriter, loc, outType,
+ rewriter.getFloatAttr(outType, 0.0));
+ int64_t opOutWidth = 32 / outType.getIntOrFloatBitWidth();
+ VectorType truncScaleResultType = VectorType::get(opOutWidth, outType);
if (!outVecType) {
Type inVecType = VectorType::get(1, inType);
- Value inCast = rewriter.create<vector::BroadcastOp>(loc, inVecType, in);
+ Value inCast = vector::BroadcastOp::create(rewriter, loc, inVecType, in);
// TODO: replace this with non-packed ScaledTruncOp
- Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
- loc, truncScaleResultType, inCast, scale, 0, /*existing=*/nullptr);
+ Value scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+ rewriter, loc, truncScaleResultType, inCast, scale, 0,
+ /*existing=*/nullptr);
scaleTrunc =
rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleTrunc, 0);
return success();
@@ -600,16 +607,16 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
VectorType inVecType = cast<VectorType>(in.getType());
Value origScale = getOriginalVectorValue(op.getScale());
+ VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());
ArrayRef<int64_t> inShape = inVecType.getShape();
- SmallVector<int64_t> originalScaleShape;
- if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
- llvm::append_range(originalScaleShape, origScaleVecType.getShape());
+ SmallVector<int64_t> scaleShape;
+ if (origScaleVecType)
+ llvm::append_range(scaleShape, origScaleVecType.getShape());
- originalScaleShape.insert(originalScaleShape.end(),
- inShape.size() - originalScaleShape.size(), 1);
+ scaleShape.insert(scaleShape.end(), inShape.size() - scaleShape.size(), 1);
- auto maybeRatio = computeShapeRatio(inShape, originalScaleShape);
+ auto maybeRatio = computeShapeRatio(inShape, scaleShape);
assert(maybeRatio &&
"failed to derive block size from broadcast or splat operation");
@@ -623,41 +630,57 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
SmallVector<int64_t> strides(offsets.size(), 1);
- Value block = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, in, offsets, ratio, strides);
+ Value block = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, in, offsets, ratio, strides);
VectorType block1DType = VectorType::get(blockSize, inType);
Value block1D =
- rewriter.create<vector::ShapeCastOp>(loc, block1DType, block);
+ vector::ShapeCastOp::create(rewriter, loc, block1DType, block);
Value uniformScale =
- rewriter.create<vector::ExtractOp>(loc, scale, offsets);
+ vector::ExtractOp::create(rewriter, loc, scale, offsets);
VectorType blockResultType = VectorType::get(blockSize, outType);
Value blockResult =
rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
- for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
- i < blockSize;
- i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
- Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, block1D, i, sliceWidth, 1);
- // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
- Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
- loc, truncScaleResultType, slice, uniformScale, 0,
- /*existing=*/nullptr);
- int64_t packedWidth =
- cast<VectorType>(scaleTrunc.getType()).getNumElements();
- if (packedWidth != opWidth)
- scaleTrunc = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, scaleTrunc, 0, sliceWidth, 1);
- blockResult = rewriter.create<vector::InsertStridedSliceOp>(
- loc, scaleTrunc, blockResult, i, 1);
+ for (int64_t i = 0, outSliceWidth = std::min(opOutWidth, blockSize - i);
+ i < blockSize; i += outSliceWidth,
+ outSliceWidth = std::min(opOutWidth, blockSize - i)) {
+ Value scaleTrunc;
+ // Case where <= 2 elements are being truncated.
+ if (outSliceWidth <= opInWidth) {
+ Value slice = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, block1D, i, outSliceWidth, 1);
+          // TODO: use a non-packed ScaledTruncOp when outSliceWidth == 1
+ scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+ rewriter, loc, truncScaleResultType, slice, uniformScale, 0,
+ /*existing=*/nullptr);
+ } else {
+ scaleTrunc = vector::BroadcastOp::create(rewriter, loc,
+ truncScaleResultType, zero);
+ for (int64_t j = 0,
+ inSliceWidth = std::min(opInWidth, outSliceWidth - j);
+ j < outSliceWidth; j += opInWidth,
+ inSliceWidth = std::min(opInWidth, outSliceWidth - j)) {
+ Value slice = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, block1D, i + j, inSliceWidth, 1);
+ scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+ rewriter, loc, truncScaleResultType, slice, uniformScale,
+ j / opInWidth, scaleTrunc);
+ }
+ }
+ if (outSliceWidth != opOutWidth) {
+ scaleTrunc = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, scaleTrunc, 0, outSliceWidth, 1);
+ }
+ blockResult = vector::InsertStridedSliceOp::create(
+ rewriter, loc, scaleTrunc, blockResult, i, 1);
}
VectorType resultType = VectorType::get(ratio, outType);
Value cast =
- rewriter.create<vector::ShapeCastOp>(loc, resultType, blockResult);
- result = rewriter.create<vector::InsertStridedSliceOp>(loc, cast, result,
- offsets, strides);
+ vector::ShapeCastOp::create(rewriter, loc, resultType, blockResult);
+ result = vector::InsertStridedSliceOp::create(rewriter, loc, cast, result,
+ offsets, strides);
}
rewriter.replaceOp(op, result);
@@ -667,19 +690,21 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
void mlir::arith::populateArithToAMDGPUConversionPatterns(
RewritePatternSet &patterns, bool convertFP8Arithmetic,
- bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
+ bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset,
+ PatternBenefit benefit) {
if (convertFP8Arithmetic) {
- patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset);
- patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
- saturateFP8Truncf, chipset);
+ patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset,
+ benefit);
+ patterns.add<TruncFToFloat8RewritePattern>(
+ patterns.getContext(), saturateFP8Truncf, chipset, benefit);
}
if (allowPackedF16Rtz)
- patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
+ patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext(), benefit);
if (chipset >= kGfx950) {
- patterns.add<ScalingExtFRewritePattern>(patterns.getContext());
- patterns.add<ScalingTruncFRewritePattern>(patterns.getContext());
+ patterns.add<ScalingExtFRewritePattern>(patterns.getContext(), benefit);
+ patterns.add<ScalingTruncFRewritePattern>(patterns.getContext(), benefit);
}
}
diff --git a/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
index cbe0b3f..ba48943 100644
--- a/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
+++ b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
@@ -74,15 +74,15 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
auto denseAttr1D = DenseElementsAttr::get(
tileSliceType, denseAttr.getSplatValue<Attribute>());
- auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
+ auto constantOp1D = arith::ConstantOp::create(rewriter, loc, denseAttr1D);
- auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
Value currentTile) {
// Create 'arm_sme.insert_tile_slice' to write vector to tile
// slice.
- auto nextTile = b.create<arm_sme::InsertTileSliceOp>(
- loc, tileType, constantOp1D, currentTile, tileSliceIndex);
+ auto nextTile = arm_sme::InsertTileSliceOp::create(
+ b, loc, tileType, constantOp1D, currentTile, tileSliceIndex);
return nextTile.getResult();
};
auto forOp = mlir::arm_sme::createLoopOverTileSlices(
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index a5c08a6..515fe5c 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -110,9 +110,9 @@ public:
emitc::CmpPredicate predicate;
switch (op.getPredicate()) {
case arith::CmpFPredicate::AlwaysFalse: {
- auto constant = rewriter.create<emitc::ConstantOp>(
- op.getLoc(), rewriter.getI1Type(),
- rewriter.getBoolAttr(/*value=*/false));
+ auto constant =
+ emitc::ConstantOp::create(rewriter, op.getLoc(), rewriter.getI1Type(),
+ rewriter.getBoolAttr(/*value=*/false));
rewriter.replaceOp(op, constant);
return success();
}
@@ -179,9 +179,9 @@ public:
return success();
}
case arith::CmpFPredicate::AlwaysTrue: {
- auto constant = rewriter.create<emitc::ConstantOp>(
- op.getLoc(), rewriter.getI1Type(),
- rewriter.getBoolAttr(/*value=*/true));
+ auto constant =
+ emitc::ConstantOp::create(rewriter, op.getLoc(), rewriter.getI1Type(),
+ rewriter.getBoolAttr(/*value=*/true));
rewriter.replaceOp(op, constant);
return success();
}
@@ -189,8 +189,8 @@ public:
// Compare the values naively
auto cmpResult =
- rewriter.create<emitc::CmpOp>(op.getLoc(), op.getType(), predicate,
- adaptor.getLhs(), adaptor.getRhs());
+ emitc::CmpOp::create(rewriter, op.getLoc(), op.getType(), predicate,
+ adaptor.getLhs(), adaptor.getRhs());
// Adjust the results for unordered/ordered semantics
if (unordered) {
@@ -213,16 +213,16 @@ private:
Value isNaN(ConversionPatternRewriter &rewriter, Location loc,
Value operand) const {
// A value is NaN exactly when it compares unequal to itself.
- return rewriter.create<emitc::CmpOp>(
- loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, operand, operand);
+ return emitc::CmpOp::create(rewriter, loc, rewriter.getI1Type(),
+ emitc::CmpPredicate::ne, operand, operand);
}
/// Return a value that is true if \p operand is not NaN.
Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc,
Value operand) const {
// A value is not NaN exactly when it compares equal to itself.
- return rewriter.create<emitc::CmpOp>(
- loc, rewriter.getI1Type(), emitc::CmpPredicate::eq, operand, operand);
+ return emitc::CmpOp::create(rewriter, loc, rewriter.getI1Type(),
+ emitc::CmpPredicate::eq, operand, operand);
}
/// Return a value that is true if the operands \p first and \p second are
@@ -231,8 +231,8 @@ private:
Location loc, Value first, Value second) const {
auto firstIsNaN = isNaN(rewriter, loc, first);
auto secondIsNaN = isNaN(rewriter, loc, second);
- return rewriter.create<emitc::LogicalOrOp>(loc, rewriter.getI1Type(),
- firstIsNaN, secondIsNaN);
+ return emitc::LogicalOrOp::create(rewriter, loc, rewriter.getI1Type(),
+ firstIsNaN, secondIsNaN);
}
/// Return a value that is true if the operands \p first and \p second are
@@ -241,8 +241,8 @@ private:
Value first, Value second) const {
auto firstIsNotNaN = isNotNaN(rewriter, loc, first);
auto secondIsNotNaN = isNotNaN(rewriter, loc, second);
- return rewriter.create<emitc::LogicalAndOp>(loc, rewriter.getI1Type(),
- firstIsNotNaN, secondIsNotNaN);
+ return emitc::LogicalAndOp::create(rewriter, loc, rewriter.getI1Type(),
+ firstIsNotNaN, secondIsNotNaN);
}
};
@@ -378,10 +378,10 @@ public:
Type attrType = (emitc::isPointerWideType(operandType))
? rewriter.getIndexType()
: operandType;
- auto constOne = rewriter.create<emitc::ConstantOp>(
- op.getLoc(), operandType, rewriter.getOneAttr(attrType));
- auto oneAndOperand = rewriter.create<emitc::BitwiseAndOp>(
- op.getLoc(), operandType, adaptor.getIn(), constOne);
+ auto constOne = emitc::ConstantOp::create(
+ rewriter, op.getLoc(), operandType, rewriter.getOneAttr(attrType));
+ auto oneAndOperand = emitc::BitwiseAndOp::create(
+ rewriter, op.getLoc(), operandType, adaptor.getIn(), constOne);
rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType,
oneAndOperand);
return success();
@@ -402,8 +402,8 @@ public:
Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);
// Actual cast (may change bitwidth)
- auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(),
- castDestType, actualOp);
+ auto cast =
+ emitc::CastOp::create(rewriter, op.getLoc(), castDestType, actualOp);
// Cast to the expected output type
auto result = adaptValueType(cast, rewriter, opReturnType);
@@ -466,9 +466,8 @@ public:
Value lhsAdapted = adaptValueType(uiBinOp.getLhs(), rewriter, unsignedType);
Value rhsAdapted = adaptValueType(uiBinOp.getRhs(), rewriter, unsignedType);
- auto newDivOp =
- rewriter.create<EmitCOp>(uiBinOp.getLoc(), unsignedType,
- ArrayRef<Value>{lhsAdapted, rhsAdapted});
+ auto newDivOp = EmitCOp::create(rewriter, uiBinOp.getLoc(), unsignedType,
+ ArrayRef<Value>{lhsAdapted, rhsAdapted});
Value resultAdapted = adaptValueType(newDivOp, rewriter, newRetTy);
rewriter.replaceOp(uiBinOp, resultAdapted);
return success();
@@ -508,8 +507,8 @@ public:
Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
- Value arithmeticResult = rewriter.template create<EmitCOp>(
- op.getLoc(), arithmeticType, lhs, rhs);
+ Value arithmeticResult =
+ EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
Value result = adaptValueType(arithmeticResult, rewriter, type);
@@ -548,8 +547,8 @@ public:
Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
- Value arithmeticResult = rewriter.template create<EmitCOp>(
- op.getLoc(), arithmeticType, lhs, rhs);
+ Value arithmeticResult =
+ EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
Value result = adaptValueType(arithmeticResult, rewriter, type);
@@ -588,38 +587,40 @@ public:
// Add a runtime check for overflow
Value width;
if (emitc::isPointerWideType(type)) {
- Value eight = rewriter.create<emitc::ConstantOp>(
- op.getLoc(), rhsType, rewriter.getIndexAttr(8));
- emitc::CallOpaqueOp sizeOfCall = rewriter.create<emitc::CallOpaqueOp>(
- op.getLoc(), rhsType, "sizeof", ArrayRef<Value>{eight});
- width = rewriter.create<emitc::MulOp>(op.getLoc(), rhsType, eight,
- sizeOfCall.getResult(0));
+ Value eight = emitc::ConstantOp::create(rewriter, op.getLoc(), rhsType,
+ rewriter.getIndexAttr(8));
+ emitc::CallOpaqueOp sizeOfCall = emitc::CallOpaqueOp::create(
+ rewriter, op.getLoc(), rhsType, "sizeof", ArrayRef<Value>{eight});
+ width = emitc::MulOp::create(rewriter, op.getLoc(), rhsType, eight,
+ sizeOfCall.getResult(0));
} else {
- width = rewriter.create<emitc::ConstantOp>(
- op.getLoc(), rhsType,
+ width = emitc::ConstantOp::create(
+ rewriter, op.getLoc(), rhsType,
rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth()));
}
- Value excessCheck = rewriter.create<emitc::CmpOp>(
- op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width);
+ Value excessCheck =
+ emitc::CmpOp::create(rewriter, op.getLoc(), rewriter.getI1Type(),
+ emitc::CmpPredicate::lt, rhs, width);
// Any concrete value is a valid refinement of poison.
- Value poison = rewriter.create<emitc::ConstantOp>(
- op.getLoc(), arithmeticType,
+ Value poison = emitc::ConstantOp::create(
+ rewriter, op.getLoc(), arithmeticType,
(isa<IntegerType>(arithmeticType)
? rewriter.getIntegerAttr(arithmeticType, 0)
: rewriter.getIndexAttr(0)));
- emitc::ExpressionOp ternary = rewriter.create<emitc::ExpressionOp>(
- op.getLoc(), arithmeticType, /*do_not_inline=*/false);
+ emitc::ExpressionOp ternary = emitc::ExpressionOp::create(
+ rewriter, op.getLoc(), arithmeticType, /*do_not_inline=*/false);
Block &bodyBlock = ternary.getBodyRegion().emplaceBlock();
auto currentPoint = rewriter.getInsertionPoint();
rewriter.setInsertionPointToStart(&bodyBlock);
Value arithmeticResult =
- rewriter.create<EmitCOp>(op.getLoc(), arithmeticType, lhs, rhs);
- Value resultOrPoison = rewriter.create<emitc::ConditionalOp>(
- op.getLoc(), arithmeticType, excessCheck, arithmeticResult, poison);
- rewriter.create<emitc::YieldOp>(op.getLoc(), resultOrPoison);
+ EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
+ Value resultOrPoison =
+ emitc::ConditionalOp::create(rewriter, op.getLoc(), arithmeticType,
+ excessCheck, arithmeticResult, poison);
+ emitc::YieldOp::create(rewriter, op.getLoc(), resultOrPoison);
rewriter.setInsertionPoint(op->getBlock(), currentPoint);
Value result = adaptValueType(ternary, rewriter, type);
@@ -700,11 +701,12 @@ public:
/*isSigned=*/false);
}
- Value result = rewriter.create<emitc::CastOp>(
- castOp.getLoc(), actualResultType, adaptor.getOperands());
+ Value result = emitc::CastOp::create(
+ rewriter, castOp.getLoc(), actualResultType, adaptor.getOperands());
if (isa<arith::FPToUIOp>(castOp)) {
- result = rewriter.create<emitc::CastOp>(castOp.getLoc(), dstType, result);
+ result =
+ emitc::CastOp::create(rewriter, castOp.getLoc(), dstType, result);
}
rewriter.replaceOp(castOp, result);
@@ -746,8 +748,8 @@ public:
}
Value fpCastOperand = adaptor.getIn();
if (actualOperandType != operandType) {
- fpCastOperand = rewriter.template create<emitc::CastOp>(
- castOp.getLoc(), actualOperandType, fpCastOperand);
+ fpCastOperand = emitc::CastOp::create(rewriter, castOp.getLoc(),
+ actualOperandType, fpCastOperand);
}
rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index f7bf581..18e857c 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -293,11 +293,11 @@ LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
[&](Type llvm1DVectorTy, ValueRange operands) -> Value {
typename OpTy::Adaptor adaptor(operands);
if (targetBits < sourceBits) {
- return rewriter.create<LLVM::TruncOp>(op.getLoc(), llvm1DVectorTy,
- adaptor.getIn());
+ return LLVM::TruncOp::create(rewriter, op.getLoc(), llvm1DVectorTy,
+ adaptor.getIn());
}
- return rewriter.create<ExtCastTy>(op.getLoc(), llvm1DVectorTy,
- adaptor.getIn());
+ return ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy,
+ adaptor.getIn());
},
rewriter);
}
@@ -324,12 +324,12 @@ LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
Type newOverflowType = typeConverter->convertType(overflowResultType);
Type structType =
LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
- Value addOverflow = rewriter.create<LLVM::UAddWithOverflowOp>(
- loc, structType, adaptor.getLhs(), adaptor.getRhs());
+ Value addOverflow = LLVM::UAddWithOverflowOp::create(
+ rewriter, loc, structType, adaptor.getLhs(), adaptor.getRhs());
Value sumExtracted =
- rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 0);
+ LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 0);
Value overflowExtracted =
- rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 1);
+ LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 1);
rewriter.replaceOp(op, {sumExtracted, overflowExtracted});
return success();
}
@@ -381,15 +381,15 @@ LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
"LLVM dialect should support all signless integer types");
using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
- Value lhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getLhs());
- Value rhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getRhs());
- Value mulExt = rewriter.create<LLVM::MulOp>(loc, wideType, lhsExt, rhsExt);
+ Value lhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getLhs());
+ Value rhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getRhs());
+ Value mulExt = LLVM::MulOp::create(rewriter, loc, wideType, lhsExt, rhsExt);
// Split the 2*N-bit wide result into two N-bit values.
- Value low = rewriter.create<LLVM::TruncOp>(loc, resultType, mulExt);
- Value shiftVal = rewriter.create<LLVM::ConstantOp>(loc, shiftValAttr);
- Value highExt = rewriter.create<LLVM::LShrOp>(loc, mulExt, shiftVal);
- Value high = rewriter.create<LLVM::TruncOp>(loc, resultType, highExt);
+ Value low = LLVM::TruncOp::create(rewriter, loc, resultType, mulExt);
+ Value shiftVal = LLVM::ConstantOp::create(rewriter, loc, shiftValAttr);
+ Value highExt = LLVM::LShrOp::create(rewriter, loc, mulExt, shiftVal);
+ Value high = LLVM::TruncOp::create(rewriter, loc, resultType, highExt);
rewriter.replaceOp(op, {low, high});
return success();
@@ -435,8 +435,8 @@ CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) {
OpAdaptor adaptor(operands);
- return rewriter.create<LLVM::ICmpOp>(
- op.getLoc(), llvm1DVectorTy,
+ return LLVM::ICmpOp::create(
+ rewriter, op.getLoc(), llvm1DVectorTy,
convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
adaptor.getLhs(), adaptor.getRhs());
},
@@ -471,8 +471,8 @@ CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) {
OpAdaptor adaptor(operands);
- return rewriter.create<LLVM::FCmpOp>(
- op.getLoc(), llvm1DVectorTy,
+ return LLVM::FCmpOp::create(
+ rewriter, op.getLoc(), llvm1DVectorTy,
convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
adaptor.getLhs(), adaptor.getRhs(), fmf);
},
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 434d7df..265293b 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -99,6 +99,17 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
return builder.getF32FloatAttr(dstVal.convertToFloat());
}
+/// Get an IntegerAttr from a FloatAttr while preserving the bits.
+/// Useful for converting float constants to integer constants of the same
+/// width without changing the underlying bit pattern.
+static IntegerAttr
+getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
+ ConversionPatternRewriter &rewriter) {
+ APFloat floatVal = floatAttr.getValue();
+ APInt intVal = floatVal.bitcastToAPInt();
+ return rewriter.getIntegerAttr(dstType, intVal);
+}
+
/// Returns true if the given `type` is a boolean scalar or vector type.
static bool isBoolScalarOrVector(Type type) {
assert(type && "Not a valid type");
@@ -117,12 +128,12 @@ static Value getScalarOrVectorConstInt(Type type, uint64_t value,
if (auto vectorType = dyn_cast<VectorType>(type)) {
Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
auto attr = SplatElementsAttr::get(vectorType, element);
- return builder.create<spirv::ConstantOp>(loc, vectorType, attr);
+ return spirv::ConstantOp::create(builder, loc, vectorType, attr);
}
if (auto intType = dyn_cast<IntegerType>(type))
- return builder.create<spirv::ConstantOp>(
- loc, type, builder.getIntegerAttr(type, value));
+ return spirv::ConstantOp::create(builder, loc, type,
+ builder.getIntegerAttr(type, value));
return nullptr;
}
@@ -296,8 +307,18 @@ struct ConstantCompositeOpPattern final
SmallVector<Attribute, 8> elements;
if (isa<FloatType>(srcElemType)) {
for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
- FloatAttr dstAttr =
- convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
+ Attribute dstAttr = nullptr;
+ // Handle 8-bit float conversion to 8-bit integer.
+ auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
+ if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
+ srcElemType.getIntOrFloatBitWidth() == 8 &&
+ isa<IntegerType>(dstElemType)) {
+ dstAttr =
+ getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter);
+ } else {
+ dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType),
+ rewriter);
+ }
if (!dstAttr)
return failure();
elements.push_back(dstAttr);
@@ -361,11 +382,19 @@ struct ConstantScalarOpPattern final
// Floating-point types.
if (isa<FloatType>(srcType)) {
auto srcAttr = cast<FloatAttr>(cstAttr);
- auto dstAttr = srcAttr;
+ Attribute dstAttr = srcAttr;
// Floating-point types not supported in the target environment are all
// converted to float type.
- if (srcType != dstType) {
+ auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
+ if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
+ srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
+ dstType.getIntOrFloatBitWidth() == 8) {
+        // If the source is an 8-bit float, convert it to an 8-bit integer.
+ dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter);
+ if (!dstAttr)
+ return failure();
+ } else if (srcType != dstType) {
dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
if (!dstAttr)
return failure();
@@ -418,18 +447,19 @@ static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
Type type = lhs.getType();
// Calculate the remainder with spirv.UMod.
- Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs);
- Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs);
- Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
+ Value lhsAbs = SignedAbsOp::create(builder, loc, type, lhs);
+ Value rhsAbs = SignedAbsOp::create(builder, loc, type, rhs);
+ Value abs = spirv::UModOp::create(builder, loc, lhsAbs, rhsAbs);
// Fix the sign.
Value isPositive;
if (lhs == signOperand)
- isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
+ isPositive = spirv::IEqualOp::create(builder, loc, lhs, lhsAbs);
else
- isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
- Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
- return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
+ isPositive = spirv::IEqualOp::create(builder, loc, rhs, rhsAbs);
+ Value absNegate = spirv::SNegateOp::create(builder, loc, type, abs);
+ return spirv::SelectOp::create(builder, loc, type, isPositive, abs,
+ absNegate);
}
/// Converts arith.remsi to GLSL SPIR-V ops.
@@ -601,13 +631,13 @@ struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
Value allOnes;
if (auto intTy = dyn_cast<IntegerType>(dstType)) {
unsigned componentBitwidth = intTy.getWidth();
- allOnes = rewriter.create<spirv::ConstantOp>(
- loc, intTy,
+ allOnes = spirv::ConstantOp::create(
+ rewriter, loc, intTy,
rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
} else if (auto vectorTy = dyn_cast<VectorType>(dstType)) {
unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
- allOnes = rewriter.create<spirv::ConstantOp>(
- loc, vectorTy,
+ allOnes = spirv::ConstantOp::create(
+ rewriter, loc, vectorTy,
SplatElementsAttr::get(vectorTy,
APInt::getAllOnes(componentBitwidth)));
} else {
@@ -653,8 +683,8 @@ struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
    // First shift left to squeeze out all leading bits beyond the original
// bitwidth. Here we need to use the original source and result type's
// bitwidth.
- auto shiftLOp = rewriter.create<spirv::ShiftLeftLogicalOp>(
- op.getLoc(), dstType, adaptor.getIn(), shiftSize);
+ auto shiftLOp = spirv::ShiftLeftLogicalOp::create(
+ rewriter, op.getLoc(), dstType, adaptor.getIn(), shiftSize);
// Then we perform arithmetic right shift to make sure we have the right
// sign bits for negative values.
@@ -757,9 +787,9 @@ struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
auto srcType = adaptor.getOperands().front().getType();
// Check if (x & 1) == 1.
Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
- Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
- loc, srcType, adaptor.getOperands()[0], mask);
- Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
+ Value maskedSrc = spirv::BitwiseAndOp::create(
+ rewriter, loc, srcType, adaptor.getOperands()[0], mask);
+ Value isOne = spirv::IEqualOp::create(rewriter, loc, maskedSrc, mask);
Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
@@ -914,9 +944,9 @@ public:
if (auto vectorType = dyn_cast<VectorType>(dstType))
type = VectorType::get(vectorType.getShape(), type);
Value extLhs =
- rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
+ arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getLhs());
Value extRhs =
- rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs());
+ arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getRhs());
rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
extRhs);
@@ -1067,12 +1097,12 @@ public:
replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
}
} else {
- Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
- Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
+ Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
+ Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
- replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
+ replace = spirv::LogicalOrOp::create(rewriter, loc, lhsIsNan, rhsIsNan);
if (op.getPredicate() == arith::CmpFPredicate::ORD)
- replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
+ replace = spirv::LogicalNotOp::create(rewriter, loc, replace);
}
rewriter.replaceOp(op, replace);
@@ -1094,17 +1124,17 @@ public:
ConversionPatternRewriter &rewriter) const override {
Type dstElemTy = adaptor.getLhs().getType();
Location loc = op->getLoc();
- Value result = rewriter.create<spirv::IAddCarryOp>(loc, adaptor.getLhs(),
- adaptor.getRhs());
+ Value result = spirv::IAddCarryOp::create(rewriter, loc, adaptor.getLhs(),
+ adaptor.getRhs());
- Value sumResult = rewriter.create<spirv::CompositeExtractOp>(
- loc, result, llvm::ArrayRef(0));
- Value carryValue = rewriter.create<spirv::CompositeExtractOp>(
- loc, result, llvm::ArrayRef(1));
+ Value sumResult = spirv::CompositeExtractOp::create(rewriter, loc, result,
+ llvm::ArrayRef(0));
+ Value carryValue = spirv::CompositeExtractOp::create(rewriter, loc, result,
+ llvm::ArrayRef(1));
// Convert the carry value to boolean.
Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
- Value carryResult = rewriter.create<spirv::IEqualOp>(loc, carryValue, one);
+ Value carryResult = spirv::IEqualOp::create(rewriter, loc, carryValue, one);
rewriter.replaceOp(op, {sumResult, carryResult});
return success();
@@ -1125,12 +1155,12 @@ public:
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Value result =
- rewriter.create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs());
+ SPIRVMulOp::create(rewriter, loc, adaptor.getLhs(), adaptor.getRhs());
- Value low = rewriter.create<spirv::CompositeExtractOp>(loc, result,
- llvm::ArrayRef(0));
- Value high = rewriter.create<spirv::CompositeExtractOp>(loc, result,
- llvm::ArrayRef(1));
+ Value low = spirv::CompositeExtractOp::create(rewriter, loc, result,
+ llvm::ArrayRef(0));
+ Value high = spirv::CompositeExtractOp::create(rewriter, loc, result,
+ llvm::ArrayRef(1));
rewriter.replaceOp(op, {low, high});
return success();
@@ -1183,20 +1213,20 @@ public:
Location loc = op.getLoc();
Value spirvOp =
- rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
+ SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
rewriter.replaceOp(op, spirvOp);
return success();
}
- Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
- Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
+ Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
+ Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
- Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
- adaptor.getLhs(), spirvOp);
- Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
- adaptor.getRhs(), select1);
+ Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
+ adaptor.getLhs(), spirvOp);
+ Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
+ adaptor.getRhs(), select1);
rewriter.replaceOp(op, select2);
return success();
@@ -1237,7 +1267,7 @@ public:
Location loc = op.getLoc();
Value spirvOp =
- rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
+ SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
if (!shouldInsertNanGuards<SPIRVOp>() ||
bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
@@ -1245,13 +1275,13 @@ public:
return success();
}
- Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
- Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
+ Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
+ Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
- Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
- adaptor.getRhs(), spirvOp);
- Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
- adaptor.getLhs(), select1);
+ Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
+ adaptor.getRhs(), spirvOp);
+ Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
+ adaptor.getLhs(), select1);
rewriter.replaceOp(op, select2);
return success();
@@ -1351,6 +1381,7 @@ struct ConvertArithToSPIRVPass
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
// Use UnrealizedConversionCast as the bridge so that we don't need to pull
diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
index 9c6de93..e34b368 100644
--- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
@@ -12,7 +12,6 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
@@ -41,11 +40,11 @@ public:
Value c2d = op.getC();
Location loc = op.getLoc();
Value b1d =
- rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, b2d);
+ vector::ShapeCastOp::create(rewriter, loc, flattenedVectorType, b2d);
Value c1d =
- rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, c2d);
- Value newOp = rewriter.create<SdotOp>(loc, op.getRes().getType(), op.getA(),
- b1d, c1d);
+ vector::ShapeCastOp::create(rewriter, loc, flattenedVectorType, c2d);
+ Value newOp = SdotOp::create(rewriter, loc, op.getRes().getType(),
+ op.getA(), b1d, c1d);
rewriter.replaceOp(op, {newOp});
return success();
}
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 21ea444..8a2e3b63 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -45,38 +45,38 @@ static Operation *createLoadTileSliceIntrinsic(
if (layout == arm_sme::TileSliceLayout::Horizontal) {
switch (type) {
case arm_sme::ArmSMETileType::ZAB:
- return rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1b_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAH:
- return rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1h_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAS:
- return rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1w_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAD:
- return rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1d_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAQ:
- return rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1q_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
}
} else {
switch (type) {
case arm_sme::ArmSMETileType::ZAB:
- return rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1b_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAH:
- return rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1h_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAS:
- return rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1w_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAD:
- return rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1d_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAQ:
- return rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1q_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
break;
}
}
@@ -91,38 +91,38 @@ static Operation *createStoreTileSliceIntrinsic(
if (layout == arm_sme::TileSliceLayout::Horizontal) {
switch (type) {
case arm_sme::ArmSMETileType::ZAB:
- return rewriter.create<arm_sme::aarch64_sme_st1b_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1b_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAH:
- return rewriter.create<arm_sme::aarch64_sme_st1h_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1h_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAS:
- return rewriter.create<arm_sme::aarch64_sme_st1w_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1w_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAD:
- return rewriter.create<arm_sme::aarch64_sme_st1d_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1d_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAQ:
- return rewriter.create<arm_sme::aarch64_sme_st1q_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1q_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
}
} else {
switch (type) {
case arm_sme::ArmSMETileType::ZAB:
- return rewriter.create<arm_sme::aarch64_sme_st1b_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1b_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAH:
- return rewriter.create<arm_sme::aarch64_sme_st1h_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1h_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAS:
- return rewriter.create<arm_sme::aarch64_sme_st1w_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1w_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAD:
- return rewriter.create<arm_sme::aarch64_sme_st1d_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1d_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAQ:
- return rewriter.create<arm_sme::aarch64_sme_st1q_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1q_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
}
}
llvm_unreachable("unknown type in createStoreTileSliceIntrinsic");
@@ -146,16 +146,16 @@ createAllocaForTile(RewriterBase &rewriter, Location loc,
// Move to the first operation in the function.
rewriter.setInsertionPointToStart(&func.getBlocks().front());
// Create an alloca matching the tile size of the `tileOp`.
- auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+ auto vscale = vector::VectorScaleOp::create(rewriter, loc);
auto tileElementType = tileOp.getTileType().getElementType();
auto memrefType = MemRefType::get(
{ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType);
unsigned minElements = arm_sme::getSMETileSliceMinNumElts(tileElementType);
auto minElementsOp =
- rewriter.create<arith::ConstantIndexOp>(loc, minElements);
- auto vectorLen = rewriter.create<arith::MulIOp>(loc, vscale, minElementsOp);
- auto alloca = rewriter.create<memref::AllocaOp>(
- loc, memrefType, ValueRange{vectorLen, vectorLen});
+ arith::ConstantIndexOp::create(rewriter, loc, minElements);
+ auto vectorLen = arith::MulIOp::create(rewriter, loc, vscale, minElementsOp);
+ auto alloca = memref::AllocaOp::create(rewriter, loc, memrefType,
+ ValueRange{vectorLen, vectorLen});
return alloca;
}
@@ -293,10 +293,10 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
Value tileMemory, Value sliceIndex) const {
auto llvmType = getTypeConverter()->convertType(tileMemory.getType());
auto descriptor =
- rewriter.create<UnrealizedConversionCastOp>(loc, llvmType, tileMemory);
- auto zero = rewriter.create<arith::ConstantIntOp>(loc, 0, /*width=*/64);
- auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI64Type(), sliceIndex);
+ UnrealizedConversionCastOp::create(rewriter, loc, llvmType, tileMemory);
+ auto zero = arith::ConstantIntOp::create(rewriter, loc, 0, /*width=*/64);
+ auto sliceIndexI64 = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getI64Type(), sliceIndex);
return getStridedElementPtr(
static_cast<ConversionPatternRewriter &>(rewriter), loc,
llvm::cast<MemRefType>(tileMemory.getType()), descriptor.getResult(0),
@@ -309,28 +309,29 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
arm_sme::ArmSMETileType tileType, VectorType sliceType,
IntegerAttr tileId, Value sliceIndex) const {
// Cast the slice index to an i32.
- auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI32Type(), sliceIndex);
+ auto sliceIndexI32 = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getI32Type(), sliceIndex);
// Create an all-true predicate for the slice.
auto predicateType = sliceType.clone(rewriter.getI1Type());
- auto allTruePredicate = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(predicateType, true));
+ auto allTruePredicate = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(predicateType, true));
// Create padding vector (never used due to all-true predicate).
- auto padVector = rewriter.create<LLVM::PoisonOp>(loc, sliceType);
+ auto padVector = LLVM::PoisonOp::create(rewriter, loc, sliceType);
// Get a pointer to the current slice.
auto slicePtr =
getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);
// Read the value of the current slice from ZA.
- auto currentTileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
- loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32);
+ auto currentTileSlice = arm_sme::aarch64_sme_read_horiz::create(
+ rewriter, loc, sliceType, padVector, allTruePredicate, tileId,
+ sliceIndexI32);
// Load the new tile slice back from memory into ZA.
createLoadTileSliceIntrinsic(
rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal,
allTruePredicate, slicePtr, tileId, sliceIndexI32);
// Store the current tile slice to memory.
- auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- rewriter.create<vector::StoreOp>(loc, currentTileSlice, tileAlloca,
- ValueRange{sliceIndex, zero});
+ auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ vector::StoreOp::create(rewriter, loc, currentTileSlice, tileAlloca,
+ ValueRange{sliceIndex, zero});
}
/// Emits a full in-place swap of the contents of a tile in ZA and a
@@ -341,12 +342,14 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
RewriterBase::InsertionGuard guard(rewriter);
// Create an scf.for over all tile slices.
auto minNumElts =
- rewriter.create<arith::ConstantIndexOp>(loc, sliceType.getDimSize(0));
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto upperBound = rewriter.create<arith::MulIOp>(
- loc, minNumElts, rewriter.create<vector::VectorScaleOp>(loc));
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+ arith::ConstantIndexOp::create(rewriter, loc, sliceType.getDimSize(0));
+ auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ auto upperBound =
+ arith::MulIOp::create(rewriter, loc, minNumElts,
+ vector::VectorScaleOp::create(rewriter, loc));
+ auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ auto forOp =
+ scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step);
// Emit a swap for each tile slice.
rewriter.setInsertionPointToStart(forOp.getBody());
auto sliceIndex = forOp.getInductionVar();
@@ -479,8 +482,8 @@ struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
//
// This holds for all tile sizes.
int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt());
- rewriter.create<arm_sme::aarch64_sme_zero>(
- loc, rewriter.getI32IntegerAttr(zeroMask));
+ arm_sme::aarch64_sme_zero::create(rewriter, loc,
+ rewriter.getI32IntegerAttr(zeroMask));
// Create a placeholder op to preserve dataflow.
// Note: Place the `get_tile` op at the start of the block. This ensures
@@ -513,8 +516,8 @@ struct LoadTileSliceConversion
auto tileSlice = loadTileSliceOp.getTileSliceIndex();
// Cast tile slice to i32 for intrinsic.
- auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
- loc, rewriter.getI32Type(), tileSlice);
+ auto tileSliceI32 = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI32Type(), tileSlice);
// Create all active predicate mask.
auto maskOp = loadTileSliceOp.getMask();
@@ -559,8 +562,8 @@ struct StoreTileSliceConversion
auto tileSlice = storeTileSliceOp.getTileSliceIndex();
// Cast tile slice to i32 for intrinsic.
- auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
- loc, rewriter.getI32Type(), tileSlice);
+ auto tileSliceI32 = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI32Type(), tileSlice);
auto maskOp = storeTileSliceOp.getMask();
@@ -595,28 +598,29 @@ struct InsertTileSliceConversion
auto tileSlice = insertTileSliceOp.getTileSliceIndex();
// Cast tile slice from index to i32 for intrinsic.
- auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
- loc, rewriter.getI32Type(), tileSlice);
+ auto tileSliceI32 = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI32Type(), tileSlice);
// Create all active predicate mask.
- auto one = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI1Type(),
+ auto one = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getI1Type(),
rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
/*scalableDims=*/{true});
- auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
+ auto allActiveMask =
+ vector::BroadcastOp::create(rewriter, loc, predTy, one);
// Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
switch (insertTileSliceOp.getLayout()) {
case arm_sme::TileSliceLayout::Horizontal:
- rewriter.create<arm_sme::aarch64_sme_write_horiz>(
- loc, tileId, tileSliceI32, allActiveMask,
- insertTileSliceOp.getVector());
+ arm_sme::aarch64_sme_write_horiz::create(rewriter, loc, tileId,
+ tileSliceI32, allActiveMask,
+ insertTileSliceOp.getVector());
break;
case arm_sme::TileSliceLayout::Vertical:
- rewriter.create<arm_sme::aarch64_sme_write_vert>(
- loc, tileId, tileSliceI32, allActiveMask,
- insertTileSliceOp.getVector());
+ arm_sme::aarch64_sme_write_vert::create(rewriter, loc, tileId,
+ tileSliceI32, allActiveMask,
+ insertTileSliceOp.getVector());
break;
}
@@ -646,16 +650,16 @@ struct ExtractTileSliceConversion
// Create an 'all true' predicate for the tile slice.
auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type());
- auto allTruePredicate = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(predicateType, true));
+ auto allTruePredicate = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(predicateType, true));
// Zero destination/fallback for tile slice extraction.
- auto zeroVector = rewriter.create<arith::ConstantOp>(
- loc, sliceType, rewriter.getZeroAttr(sliceType));
+ auto zeroVector = arith::ConstantOp::create(
+ rewriter, loc, sliceType, rewriter.getZeroAttr(sliceType));
// Cast tile slice from index to i32 for intrinsic.
- auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI32Type(), sliceIndex);
+ auto sliceIndexI32 = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getI32Type(), sliceIndex);
// Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice.
switch (extractTileSlice.getLayout()) {
@@ -743,7 +747,7 @@ struct OuterProductOpConversion
Value acc = outerProductOp.getAcc();
if (!acc) {
      // Initialize accumulator with zero.
- auto zero = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);
+ auto zero = arm_sme::ZeroOp::create(rewriter, loc, resultVectorType);
zero.setTileId(tileId);
acc = zero;
}
@@ -754,16 +758,16 @@ struct OuterProductOpConversion
if (!lhsMask || !rhsMask) {
auto predTy =
outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type());
- Value allActiveMask = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(predTy, true));
+ Value allActiveMask = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(predTy, true));
lhsMask = allActiveMask;
rhsMask = allActiveMask;
}
// Create 'arm_sme.intr.mopa' outer product intrinsic.
- rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileId, lhsMask, rhsMask,
- outerProductOp.getLhs(),
- outerProductOp.getRhs());
+ arm_sme::aarch64_sme_mopa::create(rewriter, loc, tileId, lhsMask, rhsMask,
+ outerProductOp.getLhs(),
+ outerProductOp.getRhs());
// The outerproduct intrinsics have no result, replace
// 'arm_sme.outerproduct' with the input tile to preserve dataflow.
@@ -792,7 +796,7 @@ struct OuterProductWideningOpConversion
Value acc = op.getAcc();
if (!acc) {
      // Initialize accumulator with zero.
- auto zero = rewriter.create<arm_sme::ZeroOp>(loc, op.getResultType());
+ auto zero = arm_sme::ZeroOp::create(rewriter, loc, op.getResultType());
zero.setTileId(tileId);
acc = zero;
}
@@ -801,14 +805,14 @@ struct OuterProductWideningOpConversion
Value rhsMask = op.getRhsMask();
if (!lhsMask || !rhsMask) {
auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type());
- Value allActiveMask = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(predTy, true));
+ Value allActiveMask = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(predTy, true));
lhsMask = allActiveMask;
rhsMask = allActiveMask;
}
- rewriter.create<OuterProductWideningIntrOp>(
- loc, tileId, lhsMask, rhsMask, adaptor.getLhs(), adaptor.getRhs());
+ OuterProductWideningIntrOp::create(rewriter, loc, tileId, lhsMask, rhsMask,
+ adaptor.getLhs(), adaptor.getRhs());
// The outerproduct intrinsics have no result, replace
// 'arm_sme.outerproduct' with the input tile to preserve dataflow.
@@ -843,13 +847,13 @@ struct StreamingVLOpConversion
auto *intrOp = [&]() -> Operation * {
switch (streamingVlOp.getTypeSize()) {
case arm_sme::TypeSize::Byte:
- return rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type);
+ return arm_sme::aarch64_sme_cntsb::create(rewriter, loc, i64Type);
case arm_sme::TypeSize::Half:
- return rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type);
+ return arm_sme::aarch64_sme_cntsh::create(rewriter, loc, i64Type);
case arm_sme::TypeSize::Word:
- return rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type);
+ return arm_sme::aarch64_sme_cntsw::create(rewriter, loc, i64Type);
case arm_sme::TypeSize::Double:
- return rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type);
+ return arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type);
}
llvm_unreachable("unknown type size in StreamingVLOpConversion");
}();
@@ -872,8 +876,8 @@ static void mergeConsecutiveTileZerosInBlock(Block *block) {
if (zeroOpsToMerge.size() <= 1)
return;
IRRewriter rewriter(zeroOpsToMerge.front());
- rewriter.create<arm_sme::aarch64_sme_zero>(
- zeroOpsToMerge.front().getLoc(),
+ arm_sme::aarch64_sme_zero::create(
+ rewriter, zeroOpsToMerge.front().getLoc(),
rewriter.getI32IntegerAttr(mergedZeroMask));
for (auto zeroOp : zeroOpsToMerge)
rewriter.eraseOp(zeroOp);
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 458628c..e28d5122 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -39,7 +39,7 @@ SmallVector<Value, 2> getMemrefIndices(ValueRange indices, unsigned rank,
auto tileSliceOffset = tileSliceIndex;
auto baseIndexPlusTileSliceOffset =
- rewriter.create<arith::AddIOp>(loc, indices[0], tileSliceOffset);
+ arith::AddIOp::create(rewriter, loc, indices[0], tileSliceOffset);
outIndices.push_back(baseIndexPlusTileSliceOffset);
outIndices.push_back(indices[1]);
@@ -59,10 +59,11 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
if (memrefIndices.size() != 2)
return rewriter.notifyMatchFailure(loc, "invalid number of indices");
- auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
- loc, arm_sme::getSMETileSliceMinNumElts(tileType.getElementType()));
+ auto minTileSlices = arith::ConstantIndexOp::create(
+ rewriter, loc,
+ arm_sme::getSMETileSliceMinNumElts(tileType.getElementType()));
auto vscale =
- rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+ vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType());
auto predicateType =
VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
@@ -70,7 +71,7 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
// elements in a vector of SVL bits for a given element type (SVL_B,
// SVL_H, ..., SVL_Q).
auto numTileSlices =
- rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+ arith::MulIOp::create(rewriter, loc, minTileSlices, vscale);
Value predicate;
Value upperBound;
@@ -82,30 +83,30 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
// The upper bound of the loop must be clamped at `numTileSlices` as
// `vector.create_mask` allows operands to be greater than the size of a
// dimension.
- auto numRowI64 = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI64Type(), maskDim0);
- auto numTileSlicesI64 = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI64Type(), numTileSlices);
+ auto numRowI64 = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getI64Type(), maskDim0);
+ auto numTileSlicesI64 = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getI64Type(), numTileSlices);
auto upperBoundI64 =
- rewriter.create<arith::MinSIOp>(loc, numRowI64, numTileSlicesI64);
- upperBound = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getIndexType(), upperBoundI64);
+ arith::MinSIOp::create(rewriter, loc, numRowI64, numTileSlicesI64);
+ upperBound = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getIndexType(), upperBoundI64);
predicate =
- rewriter.create<vector::CreateMaskOp>(loc, predicateType, maskDim1);
+ vector::CreateMaskOp::create(rewriter, loc, predicateType, maskDim1);
} else {
upperBound = numTileSlices;
// No mask. Create an 'all true' predicate for the tile slice.
- predicate = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(predicateType, true));
+ predicate = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(predicateType, true));
}
bool hasCarriedArgs = bool(initTile);
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step,
- hasCarriedArgs ? ValueRange{initTile}
- : ValueRange{});
+ auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ auto forOp =
+ scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step,
+ hasCarriedArgs ? ValueRange{initTile} : ValueRange{});
rewriter.setInsertionPointToStart(forOp.getBody());
Value tileSliceIndex = forOp.getInductionVar();
@@ -118,7 +119,7 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
assert(bool(nextTile) == hasCarriedArgs);
if (nextTile)
- rewriter.create<scf::YieldOp>(loc, nextTile);
+ scf::YieldOp::create(rewriter, loc, nextTile);
return forOp;
}
@@ -194,9 +195,9 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
// Initialize tile with zero to satisfy padding. Inactive cols will be
// zeroed anyway since the loads use zeroing predication. For inactive
// rows however, no load will occur so these need to be zeroed.
- initTile = rewriter.create<arm_sme::ZeroOp>(loc, tileType);
+ initTile = arm_sme::ZeroOp::create(rewriter, loc, tileType);
} else {
- initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
}
// Create a loop to load the active tile slices from memory.
@@ -207,9 +208,10 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
Value currentTile) -> Value {
// Create 'arm_sme.load_tile_slice' to load tile slice from memory
// into tile.
- return rewriter.create<arm_sme::LoadTileSliceOp>(
- loc, tileType, tileLoadOp.getBase(), predicate, currentTile,
- memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
+ return arm_sme::LoadTileSliceOp::create(
+ rewriter, loc, tileType, tileLoadOp.getBase(), predicate,
+ currentTile, memrefIndices, tileSliceIndex,
+ tileLoadOp.getLayout());
});
if (failed(forOp))
@@ -283,22 +285,22 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
auto numRows = createMaskOp.getOperands()[0];
auto numCols = createMaskOp.getOperands()[1];
- auto numColsI32 = rewriter.create<arith::IndexCastUIOp>(
- loc, rewriter.getI32Type(), numCols);
+ auto numColsI32 = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI32Type(), numCols);
- auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
// Create a loop that loads each ZA tile slice from memory.
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
- loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
+ auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ auto minTileSlices = arith::ConstantIndexOp::create(
+ rewriter, loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
auto vscale =
- rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType());
+ auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
auto numTileSlices =
- rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices,
- step, ValueRange{initTile});
+ arith::MulIOp::create(rewriter, loc, minTileSlices, vscale);
+ auto forOp = scf::ForOp::create(rewriter, loc, lowerBound, numTileSlices,
+ step, ValueRange{initTile});
rewriter.setInsertionPointToStart(forOp.getBody());
@@ -306,17 +308,18 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
auto currentTile = forOp.getRegionIterArg(0);
// Combine masks.
- auto rowIsActive = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
- auto rowIsActiveI32 = rewriter.create<arith::ExtSIOp>(
- loc, rewriter.getI32Type(), rowIsActive);
- auto mask = rewriter.create<arith::AndIOp>(loc, rowIsActiveI32, numColsI32);
- auto maskIndex =
- rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), mask);
+ auto rowIsActive = arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
+ auto rowIsActiveI32 = arith::ExtSIOp::create(
+ rewriter, loc, rewriter.getI32Type(), rowIsActive);
+ auto mask =
+ arith::AndIOp::create(rewriter, loc, rowIsActiveI32, numColsI32);
+ auto maskIndex = arith::IndexCastOp::create(rewriter, loc,
+ rewriter.getIndexType(), mask);
auto predicateType =
VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
- auto maskOp1D = rewriter.create<vector::CreateMaskOp>(
- loc, predicateType, maskIndex.getResult());
+ auto maskOp1D = vector::CreateMaskOp::create(rewriter, loc, predicateType,
+ maskIndex.getResult());
auto memrefIndices = getMemrefIndices(
tileLoadOp.getIndices(), tileLoadOp.getMemRefType().getRank(),
@@ -324,17 +327,19 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
// Splat pad into 1-D vector matching type of tile slice.
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
- auto pad1DOp = rewriter.create<vector::SplatOp>(loc, tileSliceType, padOp);
+ auto pad1DOp =
+ vector::BroadcastOp::create(rewriter, loc, tileSliceType, padOp);
- auto loadSlice = rewriter.create<vector::MaskedLoadOp>(
- loc, tileSliceType, tileLoadOp.getBase(), memrefIndices, maskOp1D,
- /*passthru=*/pad1DOp);
+ auto loadSlice = vector::MaskedLoadOp::create(rewriter, loc, tileSliceType,
+ tileLoadOp.getBase(),
+ memrefIndices, maskOp1D,
+ /*passthru=*/pad1DOp);
// Create 'arm_sme.insert_tile_slice' to insert slice into tile.
- auto insertSlice = rewriter.create<arm_sme::InsertTileSliceOp>(
- loc, tileType, loadSlice->getResult(0), currentTile, tileSliceIndex,
- tileLoadOp.getLayout());
- rewriter.create<scf::YieldOp>(loc, insertSlice.getResult());
+ auto insertSlice = arm_sme::InsertTileSliceOp::create(
+ rewriter, loc, tileType, loadSlice->getResult(0), currentTile,
+ tileSliceIndex, tileLoadOp.getLayout());
+ scf::YieldOp::create(rewriter, loc, insertSlice.getResult());
rewriter.setInsertionPointAfter(forOp);
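
One change in this file is not purely mechanical: the padding splat above switches from `vector::SplatOp` to `vector::BroadcastOp`. Broadcasting a scalar to a vector subsumes the splat case, so the emitted IR is equivalent here. A minimal sketch, assuming a scalar `pad` and a 1-D `VectorType sliceType` as in the pattern:

  // vector.splat %pad (old form):
  Value padVecOld = vector::SplatOp::create(rewriter, loc, sliceType, pad);
  // vector.broadcast %pad (new form, equivalent for a scalar source):
  Value padVecNew = vector::BroadcastOp::create(rewriter, loc, sliceType, pad);
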
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 94f7caa..29e6552 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -18,7 +18,6 @@
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -203,7 +202,7 @@ static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
auto addFuncDecl = [&](StringRef name, FunctionType type) {
if (module.lookupSymbol(name))
return;
- builder.create<func::FuncOp>(name, type).setPrivate();
+ func::FuncOp::create(builder, name, type).setPrivate();
};
MLIRContext *ctx = module.getContext();
@@ -254,15 +253,15 @@ static void addResumeFunction(ModuleOp module) {
auto voidTy = LLVM::LLVMVoidType::get(ctx);
Type ptrType = AsyncAPI::opaquePointerType(ctx);
- auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
- kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType}));
+ auto resumeOp = LLVM::LLVMFuncOp::create(
+ moduleBuilder, kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType}));
resumeOp.setPrivate();
auto *block = resumeOp.addEntryBlock(moduleBuilder);
auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
- blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0));
- blockBuilder.create<LLVM::ReturnOp>(ValueRange());
+ LLVM::CoroResumeOp::create(blockBuilder, resumeOp.getArgument(0));
+ LLVM::ReturnOp::create(blockBuilder, ValueRange());
}
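
`addResumeFunction` uses `ImplicitLocOpBuilder`, which carries a fixed `Location`, and the static `create` overloads accept it directly, so the location argument disappears from each call. A minimal sketch, assuming a `Block *block`, a `Location loc`, and a hypothetical coroutine handle `hdl`:

  auto b = ImplicitLocOpBuilder::atBlockEnd(loc, block);
  // No Location argument: the builder supplies `loc` on every create.
  LLVM::CoroResumeOp::create(b, hdl);
  LLVM::ReturnOp::create(b, ValueRange());
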
//===----------------------------------------------------------------------===//
@@ -282,7 +281,8 @@ public:
// in patterns for other dialects.
auto addUnrealizedCast = [](OpBuilder &builder, Type type,
ValueRange inputs, Location loc) -> Value {
- auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+ auto cast =
+ UnrealizedConversionCastOp::create(builder, loc, type, inputs);
return cast.getResult(0);
};
@@ -343,8 +343,8 @@ public:
// Constants for initializing coroutine frame.
auto constZero =
- rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0);
- auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, ptrType);
+ LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), 0);
+ auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, ptrType);
// Get coroutine id: @llvm.coro.id.
rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
@@ -372,33 +372,33 @@ public:
// Get coroutine frame size: @llvm.coro.size.i64.
Value coroSize =
- rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type());
+ LLVM::CoroSizeOp::create(rewriter, loc, rewriter.getI64Type());
// Get coroutine frame alignment: @llvm.coro.align.i64.
Value coroAlign =
- rewriter.create<LLVM::CoroAlignOp>(loc, rewriter.getI64Type());
+ LLVM::CoroAlignOp::create(rewriter, loc, rewriter.getI64Type());
    // Round up the size to be a multiple of the alignment, since
    // aligned_alloc requires the size parameter to be an integral multiple
    // of the alignment parameter.
auto makeConstant = [&](uint64_t c) {
- return rewriter.create<LLVM::ConstantOp>(op->getLoc(),
- rewriter.getI64Type(), c);
+ return LLVM::ConstantOp::create(rewriter, op->getLoc(),
+ rewriter.getI64Type(), c);
};
- coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize, coroAlign);
+ coroSize = LLVM::AddOp::create(rewriter, op->getLoc(), coroSize, coroAlign);
coroSize =
- rewriter.create<LLVM::SubOp>(op->getLoc(), coroSize, makeConstant(1));
+ LLVM::SubOp::create(rewriter, op->getLoc(), coroSize, makeConstant(1));
Value negCoroAlign =
- rewriter.create<LLVM::SubOp>(op->getLoc(), makeConstant(0), coroAlign);
+ LLVM::SubOp::create(rewriter, op->getLoc(), makeConstant(0), coroAlign);
coroSize =
- rewriter.create<LLVM::AndOp>(op->getLoc(), coroSize, negCoroAlign);
+ LLVM::AndOp::create(rewriter, op->getLoc(), coroSize, negCoroAlign);
// Allocate memory for the coroutine frame.
auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
rewriter, op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
if (failed(allocFuncOp))
return failure();
- auto coroAlloc = rewriter.create<LLVM::CallOp>(
- loc, allocFuncOp.value(), ValueRange{coroAlign, coroSize});
+ auto coroAlloc = LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(),
+ ValueRange{coroAlign, coroSize});
// Begin a coroutine: @llvm.coro.begin.
auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId();
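
The add/sub/and sequence above is the usual round-up trick for a power-of-two alignment: size' = (size + align - 1) & -align. For example, with size = 100 and align = 16, the computation gives 115 & -16 = 112, the smallest multiple of 16 that is at least 100; `negCoroAlign` materializes `-align` as `0 - align`, whose two's-complement bit pattern clears the low log2(align) bits.
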
@@ -427,7 +427,7 @@ public:
// Get a pointer to the coroutine frame memory: @llvm.coro.free.
auto coroMem =
- rewriter.create<LLVM::CoroFreeOp>(loc, ptrType, adaptor.getOperands());
+ LLVM::CoroFreeOp::create(rewriter, loc, ptrType, adaptor.getOperands());
// Free the memory.
auto freeFuncOp =
@@ -455,15 +455,15 @@ public:
matchAndRewrite(CoroEndOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// We are not in the block that is part of the unwind sequence.
- auto constFalse = rewriter.create<LLVM::ConstantOp>(
- op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false));
- auto noneToken = rewriter.create<LLVM::NoneTokenOp>(op->getLoc());
+ auto constFalse =
+ LLVM::ConstantOp::create(rewriter, op->getLoc(), rewriter.getI1Type(),
+ rewriter.getBoolAttr(false));
+ auto noneToken = LLVM::NoneTokenOp::create(rewriter, op->getLoc());
// Mark the end of a coroutine: @llvm.coro.end.
auto coroHdl = adaptor.getHandle();
- rewriter.create<LLVM::CoroEndOp>(
- op->getLoc(), rewriter.getI1Type(),
- ValueRange({coroHdl, constFalse, noneToken}));
+ LLVM::CoroEndOp::create(rewriter, op->getLoc(), rewriter.getI1Type(),
+ ValueRange({coroHdl, constFalse, noneToken}));
rewriter.eraseOp(op);
return success();
@@ -534,13 +534,13 @@ public:
auto loc = op->getLoc();
// This is not a final suspension point.
- auto constFalse = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
+ auto constFalse = LLVM::ConstantOp::create(
+ rewriter, loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
// Suspend a coroutine: @llvm.coro.suspend
auto coroState = adaptor.getState();
- auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>(
- loc, i8, ValueRange({coroState, constFalse}));
+ auto coroSuspend = LLVM::CoroSuspendOp::create(
+ rewriter, loc, i8, ValueRange({coroState, constFalse}));
// Cast return code to i32.
@@ -551,7 +551,7 @@ public:
llvm::SmallVector<Block *, 2> caseDest = {op.getResumeDest(),
op.getCleanupDest()};
rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
- op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()),
+ op, LLVM::SExtOp::create(rewriter, loc, i32, coroSuspend.getResult()),
/*defaultDestination=*/op.getSuspendDest(),
/*defaultOperands=*/ValueRange(),
/*caseValues=*/caseValues,
@@ -602,11 +602,11 @@ public:
// %Size = getelementptr %T* null, int 1
// %SizeI = ptrtoint %T* %Size to i64
- auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, storagePtrType);
+ auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, storagePtrType);
auto gep =
- rewriter.create<LLVM::GEPOp>(loc, storagePtrType, storedType,
- nullPtr, ArrayRef<LLVM::GEPArg>{1});
- return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep);
+ LLVM::GEPOp::create(rewriter, loc, storagePtrType, storedType,
+ nullPtr, ArrayRef<LLVM::GEPArg>{1});
+ return LLVM::PtrToIntOp::create(rewriter, loc, i64, gep);
};
rewriter.replaceOpWithNewOp<func::CallOp>(op, kCreateValue, resultType,
@@ -739,8 +739,8 @@ public:
.Case<ValueType>([](Type) { return kAwaitValue; })
.Case<GroupType>([](Type) { return kAwaitGroup; });
- rewriter.create<func::CallOp>(op->getLoc(), apiFuncName, TypeRange(),
- adaptor.getOperands());
+ func::CallOp::create(rewriter, op->getLoc(), apiFuncName, TypeRange(),
+ adaptor.getOperands());
rewriter.eraseOp(op);
return success();
@@ -772,13 +772,12 @@ public:
// A pointer to coroutine resume intrinsic wrapper.
addResumeFunction(op->getParentOfType<ModuleOp>());
- auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
- op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()),
- kResume);
+ auto resumePtr = LLVM::AddressOfOp::create(
+ rewriter, op->getLoc(),
+ AsyncAPI::opaquePointerType(rewriter.getContext()), kResume);
- rewriter.create<func::CallOp>(
- op->getLoc(), apiFuncName, TypeRange(),
- ValueRange({operand, handle, resumePtr.getRes()}));
+ func::CallOp::create(rewriter, op->getLoc(), apiFuncName, TypeRange(),
+ ValueRange({operand, handle, resumePtr.getRes()}));
rewriter.eraseOp(op);
return success();
@@ -801,9 +800,9 @@ public:
ConversionPatternRewriter &rewriter) const override {
// A pointer to coroutine resume intrinsic wrapper.
addResumeFunction(op->getParentOfType<ModuleOp>());
- auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
- op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()),
- kResume);
+ auto resumePtr = LLVM::AddressOfOp::create(
+ rewriter, op->getLoc(),
+ AsyncAPI::opaquePointerType(rewriter.getContext()), kResume);
// Call async runtime API to execute a coroutine in the managed thread.
auto coroHdl = adaptor.getHandle();
@@ -832,8 +831,8 @@ public:
// Get a pointer to the async value storage from the runtime.
auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext());
auto storage = adaptor.getStorage();
- auto storagePtr = rewriter.create<func::CallOp>(
- loc, kGetValueStorage, TypeRange(ptrType), storage);
+ auto storagePtr = func::CallOp::create(rewriter, loc, kGetValueStorage,
+ TypeRange(ptrType), storage);
// Cast from i8* to the LLVM pointer type.
auto valueType = op.getValue().getType();
@@ -845,7 +844,7 @@ public:
Value castedStoragePtr = storagePtr.getResult(0);
// Store the yielded value into the async value storage.
auto value = adaptor.getValue();
- rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr);
+ LLVM::StoreOp::create(rewriter, loc, value, castedStoragePtr);
// Erase the original runtime store operation.
rewriter.eraseOp(op);
@@ -872,8 +871,8 @@ public:
// Get a pointer to the async value storage from the runtime.
auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext());
auto storage = adaptor.getStorage();
- auto storagePtr = rewriter.create<func::CallOp>(
- loc, kGetValueStorage, TypeRange(ptrType), storage);
+ auto storagePtr = func::CallOp::create(rewriter, loc, kGetValueStorage,
+ TypeRange(ptrType), storage);
// Cast from i8* to the LLVM pointer type.
auto valueType = op.getResult().getType();
@@ -960,9 +959,9 @@ public:
LogicalResult
matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto count = rewriter.create<arith::ConstantOp>(
- op->getLoc(), rewriter.getI64Type(),
- rewriter.getI64IntegerAttr(op.getCount()));
+ auto count =
+ arith::ConstantOp::create(rewriter, op->getLoc(), rewriter.getI64Type(),
+ rewriter.getI64IntegerAttr(op.getCount()));
auto operand = adaptor.getOperand();
rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), apiFunctionName,
diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
index b9991f3..3edcbb8 100644
--- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -47,30 +47,29 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
if (auto unrankedType = dyn_cast<UnrankedMemRefType>(type)) {
// Constants
- Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
// Dynamically evaluate the size and shape of the unranked memref
- Value rank = rewriter.create<memref::RankOp>(loc, op.getInput());
+ Value rank = memref::RankOp::create(rewriter, loc, op.getInput());
MemRefType allocType =
MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType());
- Value shape = rewriter.create<memref::AllocaOp>(loc, allocType, rank);
+ Value shape = memref::AllocaOp::create(rewriter, loc, allocType, rank);
// Create a loop to query dimension sizes, store them as a shape, and
// compute the total size of the memref
auto loopBody = [&](OpBuilder &builder, Location loc, Value i,
ValueRange args) {
auto acc = args.front();
- auto dim = rewriter.create<memref::DimOp>(loc, op.getInput(), i);
+ auto dim = memref::DimOp::create(rewriter, loc, op.getInput(), i);
- rewriter.create<memref::StoreOp>(loc, dim, shape, i);
- acc = rewriter.create<arith::MulIOp>(loc, acc, dim);
+ memref::StoreOp::create(rewriter, loc, dim, shape, i);
+ acc = arith::MulIOp::create(rewriter, loc, acc, dim);
- rewriter.create<scf::YieldOp>(loc, acc);
+ scf::YieldOp::create(rewriter, loc, acc);
};
- auto size = rewriter
- .create<scf::ForOp>(loc, zero, rank, one, ValueRange(one),
- loopBody)
+ auto size = scf::ForOp::create(rewriter, loc, zero, rank, one,
+ ValueRange(one), loopBody)
.getResult(0);
MemRefType memrefType = MemRefType::get({ShapedType::kDynamic},
@@ -78,9 +77,9 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
// Allocate new memref with 1D dynamic shape, then reshape into the
// shape of the original unranked memref
- alloc = rewriter.create<memref::AllocOp>(loc, memrefType, size);
+ alloc = memref::AllocOp::create(rewriter, loc, memrefType, size);
alloc =
- rewriter.create<memref::ReshapeOp>(loc, unrankedType, alloc, shape);
+ memref::ReshapeOp::create(rewriter, loc, unrankedType, alloc, shape);
} else {
MemRefType memrefType = cast<MemRefType>(type);
MemRefLayoutAttrInterface layout;
@@ -103,14 +102,15 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
}
// Allocate a memref with identity layout.
- alloc = rewriter.create<memref::AllocOp>(loc, allocType, dynamicOperands);
+ alloc =
+ memref::AllocOp::create(rewriter, loc, allocType, dynamicOperands);
// Cast the allocation to the specified type if needed.
if (memrefType != allocType)
alloc =
- rewriter.create<memref::CastOp>(op->getLoc(), memrefType, alloc);
+ memref::CastOp::create(rewriter, op->getLoc(), memrefType, alloc);
}
- rewriter.create<memref::CopyOp>(loc, op.getInput(), alloc);
+ memref::CopyOp::create(rewriter, loc, op.getInput(), alloc);
rewriter.replaceOp(op, alloc);
return success();
}
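
The loop built above computes the shape and element count of an unranked memref at runtime. In pseudo-C++ terms (where `dim(input, i)` stands in for the generated `memref.dim`, and `rank` is a runtime value), the IR evaluates:

  int64_t size = 1;                 // scf.for iter_arg, seeded with %one
  for (int64_t i = 0; i < rank; ++i) {
    int64_t d = dim(input, i);      // memref.dim
    shape[i] = d;                   // memref.store into the alloca'd shape
    size *= d;                      // arith.muli, yielded back into the loop
  }
  // `size` feeds the 1-D memref.alloc; `shape` feeds memref.reshape.
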
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index f84375b..785cb82 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -43,7 +43,7 @@ add_subdirectory(MathToSPIRV)
add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
add_subdirectory(MemRefToSPIRV)
-add_subdirectory(MeshToMPI)
+add_subdirectory(ShardToMPI)
add_subdirectory(MPIToLLVM)
add_subdirectory(NVGPUToNVVM)
add_subdirectory(NVVMToLLVM)
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 6f0fc29..35ad99c 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -64,10 +64,46 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
patterns.getContext(), "__ocml_cabs_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>(
patterns.getContext(), "__ocml_cabs_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float32Type>>(
+ patterns.getContext(), "__ocml_carg_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float64Type>>(
+ patterns.getContext(), "__ocml_carg_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float32Type>>(
+ patterns.getContext(), "__ocml_conj_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float64Type>>(
+ patterns.getContext(), "__ocml_conj_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float32Type>>(
+ patterns.getContext(), "__ocml_ccos_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float64Type>>(
+ patterns.getContext(), "__ocml_ccos_f64");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float32Type>>(
patterns.getContext(), "__ocml_cexp_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float64Type>>(
patterns.getContext(), "__ocml_cexp_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float32Type>>(
+ patterns.getContext(), "__ocml_clog_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float64Type>>(
+ patterns.getContext(), "__ocml_clog_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float32Type>>(
+ patterns.getContext(), "__ocml_cpow_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float64Type>>(
+ patterns.getContext(), "__ocml_cpow_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float32Type>>(
+ patterns.getContext(), "__ocml_csin_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float64Type>>(
+ patterns.getContext(), "__ocml_csin_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float32Type>>(
+ patterns.getContext(), "__ocml_csqrt_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float64Type>>(
+ patterns.getContext(), "__ocml_csqrt_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float32Type>>(
+ patterns.getContext(), "__ocml_ctan_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float64Type>>(
+ patterns.getContext(), "__ocml_ctan_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float32Type>>(
+ patterns.getContext(), "__ocml_ctanh_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float64Type>>(
+ patterns.getContext(), "__ocml_ctanh_f64");
}
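
Every op here is registered twice, varying only in the float type and the `__ocml_*` libcall name. If this table keeps growing, a small helper could fold each pair; a hypothetical sketch (`addOcmlPair` is not part of this patch):

  template <typename ComplexOp>
  static void addOcmlPair(RewritePatternSet &patterns, StringRef f32Name,
                          StringRef f64Name) {
    patterns.add<ComplexOpToROCDLLibraryCalls<ComplexOp, Float32Type>>(
        patterns.getContext(), f32Name);
    patterns.add<ComplexOpToROCDLLibraryCalls<ComplexOp, Float64Type>>(
        patterns.getContext(), f64Name);
  }
  // Usage: addOcmlPair<complex::TanOp>(patterns, "__ocml_ctan_f32",
  //                                    "__ocml_ctan_f64");
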
namespace {
@@ -86,7 +122,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
ConversionTarget target(getContext());
target.addLegalDialect<func::FuncDialect>();
- target.addIllegalOp<complex::AbsOp, complex::ExpOp>();
+ target.addIllegalOp<complex::AbsOp, complex::AngleOp, complex::ConjOp,
+ complex::CosOp, complex::ExpOp, complex::LogOp,
+ complex::PowOp, complex::SinOp, complex::SqrtOp,
+ complex::TanOp, complex::TanhOp>();
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index eeff8a9..5ad514d 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -12,7 +12,6 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include <type_traits>
diff --git a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
index c8311eb..5ac838c 100644
--- a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
@@ -144,12 +144,11 @@ ControlFlowToSCFTransformation::createUnreachableTerminator(Location loc,
return emitError(loc, "Cannot create unreachable terminator for '")
<< parentOp->getName() << "'";
- return builder
- .create<func::ReturnOp>(
- loc, llvm::map_to_vector(funcOp.getResultTypes(),
- [&](Type type) {
- return getUndefValue(loc, builder, type);
- }))
+ return func::ReturnOp::create(
+ builder, loc,
+ llvm::map_to_vector(
+ funcOp.getResultTypes(),
+ [&](Type type) { return getUndefValue(loc, builder, type); }))
.getOperation();
}
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
index 03f4bf4..56b6181 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
@@ -43,6 +43,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
// TODO: We should also take care of block argument type conversion.
diff --git a/mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp b/mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp
index c9b1dc1..ee6d7d5 100644
--- a/mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp
+++ b/mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp
@@ -9,8 +9,6 @@
#include "mlir/Conversion/ConvertToEmitC/ConvertToEmitCPass.h"
#include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h"
-#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
-#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
diff --git a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp
index 252245d..c70b5f0 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp
+++ b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp
@@ -9,7 +9,6 @@
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
-#include "llvm/ADT/DenseSet.h"
using namespace mlir;
diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
index 8ed9f65..c0439a4 100644
--- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
@@ -42,6 +42,7 @@ void ConvertFuncToSPIRVPass::runOnOperation() {
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 63eb6c58..3cfbd89 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -579,8 +579,8 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
auto function = [&] {
if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
return function;
- return OpBuilder::atBlockEnd(module.getBody())
- .create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
+ auto builder = OpBuilder::atBlockEnd(module.getBody());
+ return LLVM::LLVMFuncOp::create(builder, loc, functionName, functionType);
}();
return LLVM::CallOp::create(builder, loc, function, arguments);
}
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index a19194e..3545acb 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -385,6 +385,14 @@ LogicalResult GPUModuleConversion::matchAndRewrite(
if (auto attr = moduleOp->getAttrOfType<spirv::TargetEnvAttr>(
spirv::getTargetEnvAttrName()))
spvModule->setAttr(spirv::getTargetEnvAttrName(), attr);
+ if (ArrayAttr targets = moduleOp.getTargetsAttr()) {
+ for (Attribute targetAttr : targets)
+ if (auto spirvTargetEnvAttr =
+ dyn_cast<spirv::TargetEnvAttr>(targetAttr)) {
+ spvModule->setAttr(spirv::getTargetEnvAttrName(), spirvTargetEnvAttr);
+ break;
+ }
+ }
rewriter.eraseOp(moduleOp);
return success();
@@ -507,25 +515,27 @@ LogicalResult GPURotateConversion::matchAndRewrite(
getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
unsigned subgroupSize =
targetEnv.getAttr().getResourceLimits().getSubgroupSize();
- IntegerAttr widthAttr;
- if (!matchPattern(rotateOp.getWidth(), m_Constant(&widthAttr)) ||
- widthAttr.getValue().getZExtValue() > subgroupSize)
+ unsigned width = rotateOp.getWidth();
+ if (width > subgroupSize)
return rewriter.notifyMatchFailure(
- rotateOp,
- "rotate width is not a constant or larger than target subgroup size");
+ rotateOp, "rotate width is larger than target subgroup size");
Location loc = rotateOp.getLoc();
auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
+ Value offsetVal =
+ arith::ConstantOp::create(rewriter, loc, adaptor.getOffsetAttr());
+ Value widthVal =
+ arith::ConstantOp::create(rewriter, loc, adaptor.getWidthAttr());
Value rotateResult = spirv::GroupNonUniformRotateKHROp::create(
- rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset(),
- adaptor.getWidth());
+ rewriter, loc, scope, adaptor.getValue(), offsetVal, widthVal);
Value validVal;
- if (widthAttr.getValue().getZExtValue() == subgroupSize) {
+ if (width == subgroupSize) {
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter);
} else {
+ IntegerAttr widthAttr = adaptor.getWidthAttr();
Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
- laneId, adaptor.getWidth());
+ laneId, widthVal);
}
rewriter.replaceOp(rotateOp, {rotateResult, validVal});
@@ -559,8 +569,8 @@ static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc,
builder, loc, builder.getI32Type(),
builder.getIntegerAttr(builder.getI32Type(), *clusterSize));
- return builder
- .create<NonUniformOp>(loc, type, scope, groupOp, arg, clusterSizeValue)
+ return NonUniformOp::create(builder, loc, type, scope, groupOp, arg,
+ clusterSizeValue)
.getResult();
}
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index a344f88..5eab057 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -48,9 +48,36 @@ struct GPUToSPIRVPass final : impl::ConvertGPUToSPIRVBase<GPUToSPIRVPass> {
void runOnOperation() override;
private:
+  /// Queries the target environment from the 'targets' attribute of the
+  /// given `moduleOp`.
+ spirv::TargetEnvAttr lookupTargetEnvInTargets(gpu::GPUModuleOp moduleOp);
+
+  /// Queries the target environment from the 'targets' attribute of the
+  /// given `moduleOp`, or falls back to the environment returned by
+  /// `spirv::lookupTargetEnvOrDefault` when 'targets' does not provide one.
+ spirv::TargetEnvAttr lookupTargetEnvOrDefault(gpu::GPUModuleOp moduleOp);
bool mapMemorySpace;
};
+spirv::TargetEnvAttr
+GPUToSPIRVPass::lookupTargetEnvInTargets(gpu::GPUModuleOp moduleOp) {
+ if (ArrayAttr targets = moduleOp.getTargetsAttr()) {
+ for (Attribute targetAttr : targets)
+ if (auto spirvTargetEnvAttr = dyn_cast<spirv::TargetEnvAttr>(targetAttr))
+ return spirvTargetEnvAttr;
+ }
+
+ return {};
+}
+
+spirv::TargetEnvAttr
+GPUToSPIRVPass::lookupTargetEnvOrDefault(gpu::GPUModuleOp moduleOp) {
+ if (spirv::TargetEnvAttr targetEnvAttr = lookupTargetEnvInTargets(moduleOp))
+ return targetEnvAttr;
+
+ return spirv::lookupTargetEnvOrDefault(moduleOp);
+}
+
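
Taken together, the two helpers give the module's own 'targets' attribute precedence: a `spirv::TargetEnvAttr` entry on the `gpu.module` wins, and only when none is present does the pass fall back to `spirv::lookupTargetEnvOrDefault`. This matches the `GPUModuleConversion` change earlier in this diff, which copies the same attribute onto the produced `spirv.module`. A minimal usage sketch, mirroring the capability check below:

  spirv::TargetEnvAttr attr = lookupTargetEnvOrDefault(moduleOp);
  spirv::TargetEnv env(attr);
  bool isKernel = env.allows(spirv::Capability::Kernel);
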
void GPUToSPIRVPass::runOnOperation() {
MLIRContext *context = &getContext();
ModuleOp module = getOperation();
@@ -58,9 +85,8 @@ void GPUToSPIRVPass::runOnOperation() {
SmallVector<Operation *, 1> gpuModules;
OpBuilder builder(context);
- auto targetEnvSupportsKernelCapability = [](gpu::GPUModuleOp moduleOp) {
- Operation *gpuModule = moduleOp.getOperation();
- auto targetAttr = spirv::lookupTargetEnvOrDefault(gpuModule);
+ auto targetEnvSupportsKernelCapability = [this](gpu::GPUModuleOp moduleOp) {
+ auto targetAttr = lookupTargetEnvOrDefault(moduleOp);
spirv::TargetEnv targetEnv(targetAttr);
return targetEnv.allows(spirv::Capability::Kernel);
};
@@ -86,7 +112,7 @@ void GPUToSPIRVPass::runOnOperation() {
// TargetEnv attributes.
for (Operation *gpuModule : gpuModules) {
spirv::TargetEnvAttr targetAttr =
- spirv::lookupTargetEnvOrDefault(gpuModule);
+ lookupTargetEnvOrDefault(cast<gpu::GPUModuleOp>(gpuModule));
// Map MemRef memory space to SPIR-V storage class first if requested.
if (mapMemorySpace) {
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index ecd5b63..2568044 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -272,14 +272,13 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
// Allocate memory, copy, and free the source if necessary.
Value memory =
- toDynamic
- ? builder
- .create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize)
- .getResult()
- : LLVM::AllocaOp::create(builder, loc, getPtrType(),
- IntegerType::get(getContext(), 8),
- allocationSize,
- /*alignment=*/0);
+ toDynamic ? LLVM::CallOp::create(builder, loc, mallocFunc.value(),
+ allocationSize)
+ .getResult()
+ : LLVM::AllocaOp::create(builder, loc, getPtrType(),
+ IntegerType::get(getContext(), 8),
+ allocationSize,
+ /*alignment=*/0);
Value source = desc.memRefDescPtr(builder, loc);
LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false);
if (!toDynamic)
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 5b68eb8..e5496e5 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -35,7 +35,7 @@ static Op getOrDefineGlobal(ModuleOp &moduleOp, const Location loc,
if (!(ret = moduleOp.lookupSymbol<Op>(name))) {
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
- ret = rewriter.template create<Op>(loc, std::forward<Args>(args)...);
+ ret = Op::create(rewriter, loc, std::forward<Args>(args)...);
}
return ret;
}
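
Inside a template like `getOrDefineGlobal`, the old spelling carried the `template` disambiguator (`rewriter.template create<Op>(...)`); the static form drops it and forwards the same way. A minimal sketch of the pattern, not tied to any particular op:

  template <typename Op, typename... Args>
  static Op buildOp(ConversionPatternRewriter &rewriter, Location loc,
                    Args &&...args) {
    // Old: rewriter.template create<Op>(loc, std::forward<Args>(args)...);
    return Op::create(rewriter, loc, std::forward<Args>(args)...);
  }

The same cleanup appears in the SPIRVToLLVM hunks further down, where `rewriter.template create<LLVM::ZExtOp>` and friends lose the disambiguator.
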
diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
index 08a4566..cde2340 100644
--- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
+++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
@@ -17,13 +17,12 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTMATHTOFUNCS
@@ -33,7 +32,6 @@ namespace mlir {
using namespace mlir;
#define DEBUG_TYPE "math-to-funcs"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
namespace {
// Pattern to convert vector operations to scalar operations.
@@ -654,10 +652,8 @@ FPowIOpLowering::matchAndRewrite(math::FPowIOp op,
/// }
static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) {
if (!isa<IntegerType>(elementType)) {
- LLVM_DEBUG({
- DBGS() << "non-integer element type for CtlzFunc; type was: ";
- elementType.print(llvm::dbgs());
- });
+ LDBG() << "non-integer element type for CtlzFunc; type was: "
+ << elementType;
llvm_unreachable("non-integer element type");
}
int64_t bitWidth = elementType.getIntOrFloatBitWidth();
@@ -699,7 +695,8 @@ static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) {
scf::IfOp ifOp =
scf::IfOp::create(builder, elementType, inputEqZero,
/*addThenBlock=*/true, /*addElseBlock=*/true);
- ifOp.getThenBodyBuilder().create<scf::YieldOp>(loc, bitWidthValue);
+ auto thenBuilder = ifOp.getThenBodyBuilder();
+ scf::YieldOp::create(thenBuilder, loc, bitWidthValue);
auto elseBuilder =
ImplicitLocOpBuilder::atBlockEnd(loc, &ifOp.getElseRegion().front());
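
This file, like several others in this diff, retires the per-file `DBGS()` macro in favor of `LDBG()` from llvm/Support/DebugLog.h. `LDBG()` yields a debug stream that is gated on `-debug-only=DEBUG_TYPE` and terminates the line itself, so both the `LLVM_DEBUG({...})` wrapper and the trailing `"\n"` disappear. A before/after sketch (`ty` is a hypothetical `Type`):

  // Before:
  //   LLVM_DEBUG(DBGS() << "unsupported element type: " << ty << "\n");
  // After (prefix and newline handled by the macro):
  LDBG() << "unsupported element type: " << ty;
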
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index 93d8b49..df219f3 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -21,7 +22,6 @@
#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"
-#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTMATHTOROCDL
@@ -31,7 +31,6 @@ namespace mlir {
using namespace mlir;
#define DEBUG_TYPE "math-to-rocdl"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
template <typename OpTy>
static void populateOpPatterns(const LLVMTypeConverter &converter,
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index a877ad2..1787e0a 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -488,7 +488,12 @@ namespace mlir {
void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
// Core patterns
- patterns.add<CopySignPattern>(typeConverter, patterns.getContext());
+ patterns
+ .add<CopySignPattern,
+ CheckedElementwiseOpPattern<math::IsInfOp, spirv::IsInfOp>,
+ CheckedElementwiseOpPattern<math::IsNaNOp, spirv::IsNanOp>,
+ CheckedElementwiseOpPattern<math::IsFiniteOp, spirv::IsFiniteOp>>(
+ typeConverter, patterns.getContext());
// GLSL patterns
patterns
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index e882845..6bd0e2d 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -19,10 +19,18 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
+#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
+#include <cstdint>
using namespace mlir;
+static bool isMemRefTypeLegalForEmitC(MemRefType memRefType) {
+ return memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() &&
+ memRefType.getRank() != 0 &&
+ !llvm::is_contained(memRefType.getShape(), 0);
+}
+
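
The new helper centralizes the legality test that was previously inlined in the type converter: a memref is convertible to EmitC only if its shape is static, its layout is the identity, its rank is at least one, and no dimension is zero-sized. Illustrative checks, assuming `f32` is a `Float32Type` handle:

  isMemRefTypeLegalForEmitC(MemRefType::get({4, 8}, f32));  // true
  isMemRefTypeLegalForEmitC(MemRefType::get({}, f32));      // false: rank 0
  isMemRefTypeLegalForEmitC(MemRefType::get({4, 0}, f32));  // false: zero dim
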
namespace {
/// Implement the interface to convert MemRef to EmitC.
struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
@@ -89,6 +97,68 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
return resultTy;
}
+struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = allocOp.getLoc();
+ MemRefType memrefType = allocOp.getType();
+ if (!isMemRefTypeLegalForEmitC(memrefType)) {
+ return rewriter.notifyMatchFailure(
+ loc, "incompatible memref type for EmitC conversion");
+ }
+
+ Type sizeTType = emitc::SizeTType::get(rewriter.getContext());
+ Type elementType = memrefType.getElementType();
+ IndexType indexType = rewriter.getIndexType();
+ emitc::CallOpaqueOp sizeofElementOp = rewriter.create<emitc::CallOpaqueOp>(
+ loc, sizeTType, rewriter.getStringAttr("sizeof"), ValueRange{},
+ ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)}));
+
+ int64_t numElements = 1;
+ for (int64_t dimSize : memrefType.getShape()) {
+ numElements *= dimSize;
+ }
+ Value numElementsValue = rewriter.create<emitc::ConstantOp>(
+ loc, indexType, rewriter.getIndexAttr(numElements));
+
+ Value totalSizeBytes = rewriter.create<emitc::MulOp>(
+ loc, sizeTType, sizeofElementOp.getResult(0), numElementsValue);
+
+ emitc::CallOpaqueOp allocCall;
+ StringAttr allocFunctionName;
+ Value alignmentValue;
+ SmallVector<Value, 2> argsVec;
+ if (allocOp.getAlignment()) {
+ allocFunctionName = rewriter.getStringAttr(alignedAllocFunctionName);
+ alignmentValue = rewriter.create<emitc::ConstantOp>(
+ loc, sizeTType,
+ rewriter.getIntegerAttr(indexType,
+ allocOp.getAlignment().value_or(0)));
+ argsVec.push_back(alignmentValue);
+ } else {
+ allocFunctionName = rewriter.getStringAttr(mallocFunctionName);
+ }
+
+ argsVec.push_back(totalSizeBytes);
+ ValueRange args(argsVec);
+
+ allocCall = rewriter.create<emitc::CallOpaqueOp>(
+ loc,
+ emitc::PointerType::get(
+ emitc::OpaqueType::get(rewriter.getContext(), "void")),
+ allocFunctionName, args);
+
+ emitc::PointerType targetPointerType = emitc::PointerType::get(elementType);
+ emitc::CastOp castOp = rewriter.create<emitc::CastOp>(
+ loc, targetPointerType, allocCall.getResult(0));
+
+ rewriter.replaceOp(allocOp, castOp);
+ return success();
+ }
+};
+
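
For a legal static memref, `ConvertAlloc` emits a byte count of `sizeof(element) * numElements` and a call to `malloc`, or to `aligned_alloc` when the op carries an alignment, followed by a cast to the element pointer type. Roughly the following C, sketched for a hypothetical `memref<4x8xf32>`:

  // memref.alloc() : memref<4x8xf32>
  float *buf = (float *)malloc(sizeof(float) * 32);
  // memref.alloc() {alignment = 64} : memref<4x8xf32>
  float *abuf = (float *)aligned_alloc(64, sizeof(float) * 32);
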
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
using OpConversionPattern::OpConversionPattern;
@@ -223,9 +293,7 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
typeConverter.addConversion(
[&](MemRefType memRefType) -> std::optional<Type> {
- if (!memRefType.hasStaticShape() ||
- !memRefType.getLayout().isIdentity() || memRefType.getRank() == 0 ||
- llvm::is_contained(memRefType.getShape(), 0)) {
+ if (!isMemRefTypeLegalForEmitC(memRefType)) {
return {};
}
Type convertedElementType =
@@ -252,6 +320,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
void mlir::populateMemRefToEmitCConversionPatterns(
RewritePatternSet &patterns, const TypeConverter &converter) {
- patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad,
- ConvertStore>(converter, patterns.getContext());
+ patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
+ ConvertLoad, ConvertStore>(converter, patterns.getContext());
}
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
index cf25c09..e78dd76 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
@@ -15,6 +15,7 @@
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -28,9 +29,11 @@ using namespace mlir;
namespace {
struct ConvertMemRefToEmitCPass
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
+ using Base::Base;
void runOnOperation() override {
TypeConverter converter;
-
+ ConvertMemRefToEmitCOptions options;
+ options.lowerToCpp = this->lowerToCpp;
// Fallback for other types.
converter.addConversion([](Type type) -> std::optional<Type> {
if (!emitc::isSupportedEmitCType(type))
@@ -50,6 +53,37 @@ struct ConvertMemRefToEmitCPass
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
+
+ mlir::ModuleOp module = getOperation();
+ module.walk([&](mlir::emitc::CallOpaqueOp callOp) {
+ if (callOp.getCallee() != alignedAllocFunctionName &&
+ callOp.getCallee() != mallocFunctionName) {
+ return mlir::WalkResult::advance();
+ }
+
+ for (auto &op : *module.getBody()) {
+ emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op);
+ if (!includeOp) {
+ continue;
+ }
+ if (includeOp.getIsStandardInclude() &&
+ ((options.lowerToCpp &&
+ includeOp.getInclude() == cppStandardLibraryHeader) ||
+ (!options.lowerToCpp &&
+ includeOp.getInclude() == cStandardLibraryHeader))) {
+ return mlir::WalkResult::interrupt();
+ }
+ }
+
+ mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
+ StringAttr includeAttr =
+ builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader
+ : cStandardLibraryHeader);
+ builder.create<mlir::emitc::IncludeOp>(
+ module.getLoc(), includeAttr,
+ /*is_standard_include=*/builder.getUnitAttr());
+ return mlir::WalkResult::interrupt();
+ });
}
};
} // namespace
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 53a1912..dc2035b 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -24,11 +24,12 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/MathExtras.h"
+
#include <optional>
#define DEBUG_TYPE "memref-to-llvm"
-#define DBGS() llvm::dbgs() << "[" DEBUG_TYPE "] "
namespace mlir {
#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
@@ -575,8 +576,8 @@ private:
Value sizePtr = LLVM::GEPOp::create(rewriter, loc, indexPtrTy,
getTypeConverter()->getIndexType(),
offsetPtr, idxPlusOne);
- return rewriter
- .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
+ return LLVM::LoadOp::create(rewriter, loc,
+ getTypeConverter()->getIndexType(), sizePtr)
.getResult();
}
@@ -1848,8 +1849,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
return LLVM::AtomicBinOp::xchg;
case arith::AtomicRMWKind::maximumf:
// TODO: remove this by end of 2025.
- LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw maximumf changed "
- "from fmax to fmaximum, expect more NaNs");
+ LDBG() << "the lowering of memref.atomicrmw maximumf changed "
+ "from fmax to fmaximum, expect more NaNs";
return LLVM::AtomicBinOp::fmaximum;
case arith::AtomicRMWKind::maxnumf:
return LLVM::AtomicBinOp::fmax;
@@ -1859,8 +1860,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
return LLVM::AtomicBinOp::umax;
case arith::AtomicRMWKind::minimumf:
// TODO: remove this by end of 2025.
- LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw minimum changed "
- "from fmin to fminimum, expect more NaNs");
+ LDBG() << "the lowering of memref.atomicrmw minimum changed "
+ "from fmin to fminimum, expect more NaNs";
return LLVM::AtomicBinOp::fminimum;
case arith::AtomicRMWKind::minnumf:
return LLVM::AtomicBinOp::fmin;
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 905287e1..2549a9c 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -21,19 +21,17 @@
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
#define DEBUG_TYPE "nvgpu-to-nvvm"
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-#define DBGSE() (llvm::dbgs())
namespace mlir {
#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
@@ -1106,13 +1104,13 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
// // [0,14) start_address
dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
- LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
- << "leading_off:" << leadDimVal << "\t"
- << "stride_off :" << strideDimVal << "\t"
- << "base_offset:" << offsetVal << "\t"
- << "layout_type:" << swizzle << " ("
- << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
- << ")\n start_addr : " << baseAddr << "\n");
+ LDBG() << "Generating warpgroup.descriptor: "
+ << "leading_off:" << leadDimVal << "\t"
+ << "stride_off :" << strideDimVal << "\t"
+ << "base_offset:" << offsetVal << "\t"
+ << "layout_type:" << swizzle << " ("
+ << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
+ << ")\n start_addr : " << baseAddr;
rewriter.replaceOp(op, dsc);
return success();
@@ -1282,8 +1280,8 @@ struct NVGPUWarpgroupMmaOpLowering
} else {
llvm_unreachable("msg: not supported K shape");
}
- LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
- << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
+ LDBG() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
+ << ", n = " << wgmmaN << ", k = " << wgmmaK << "]";
}
/// Generates WGMMATypesAttr from MLIR Type
@@ -1367,9 +1365,9 @@ struct NVGPUWarpgroupMmaOpLowering
int tileShapeA = matrixTypeA.getDimSize(1);
int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
incrementVal = incrementVal >> exclude4LSB;
- LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
- << "] [wgmma descriptors] Descriptor A + "
- << incrementVal << " | \t ");
+ LDBG() << "\t\t[m: " << i << " n: " << j << " k: " << k
+ << "] [wgmma descriptors] Descriptor A + " << incrementVal
+ << " | \t ";
if (!incrementVal)
return desc;
return makeAdd(desc, makeI64Const(b, incrementVal));
@@ -1392,7 +1390,7 @@ struct NVGPUWarpgroupMmaOpLowering
int byte = elemB.getIntOrFloatBitWidth() / 8;
int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
incrementVal = incrementVal >> exclude4LSB;
- LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
+ LDBG() << "Descriptor B + " << incrementVal;
if (!incrementVal)
return desc;
return makeAdd(desc, makeI64Const(b, incrementVal));
@@ -1401,15 +1399,14 @@ struct NVGPUWarpgroupMmaOpLowering
/// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
/// descriptors and arranges them based on induction variables: i, j, and k.
Value generateWgmma(int i, int j, int k, Value matrixC) {
- LLVM_DEBUG(DBGS() << "\t wgmma."
- << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
- << "(A[" << (iterationM * wgmmaM) << ":"
- << (iterationM * wgmmaM) + wgmmaM << "]["
- << (iterationK * wgmmaK) << ":"
- << (iterationK * wgmmaK + wgmmaK) << "] * "
- << " B[" << (iterationK * wgmmaK) << ":"
- << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
- << wgmmaN << "])\n");
+ LDBG() << "\t wgmma."
+ << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK << "(A["
+ << (iterationM * wgmmaM) << ":" << (iterationM * wgmmaM) + wgmmaM
+ << "][" << (iterationK * wgmmaK) << ":"
+ << (iterationK * wgmmaK + wgmmaK) << "] * "
+ << " B[" << (iterationK * wgmmaK) << ":"
+ << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" << wgmmaN
+ << "])";
Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
@@ -1468,9 +1465,9 @@ struct NVGPUWarpgroupMmaOpLowering
totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
- LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
- << "] += A[" << totalM << "][" << totalK << "] * B["
- << totalK << "][" << totalN << "] ---===\n");
+ LDBG() << "===--- GEMM D[" << totalM << "][" << totalN << "] += A["
+ << totalM << "][" << totalK << "] * B[" << totalK << "][" << totalN
+ << "] ---===";
// Find the shape for one wgmma instruction
findWgmmaShape(
diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
index 662ee9e..91788f9 100644
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -25,11 +25,10 @@
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "nvvm-to-llvm"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define DBGSNL() (llvm::dbgs() << "\n")
namespace mlir {
#define GEN_PASS_DEF_CONVERTNVVMTOLLVMPASS
@@ -52,17 +51,17 @@ struct PtxLowering
LogicalResult matchAndRewrite(BasicPtxBuilderInterface op,
PatternRewriter &rewriter) const override {
if (op.hasIntrinsic()) {
- LLVM_DEBUG(DBGS() << "Ptx Builder does not lower \n\t" << op << "\n");
+ LDBG() << "Ptx Builder does not lower \n\t" << op;
return failure();
}
SmallVector<std::pair<Value, PTXRegisterMod>> asmValues;
- LLVM_DEBUG(DBGS() << op.getPtx() << "\n");
+ LDBG() << op.getPtx();
PtxBuilder generator(op, rewriter);
op.getAsmValues(rewriter, asmValues);
for (auto &[asmValue, modifier] : asmValues) {
- LLVM_DEBUG(DBGSNL() << asmValue << "\t Modifier : " << &modifier);
+ LDBG() << asmValue << "\t Modifier : " << &modifier;
generator.insertValue(asmValue, modifier);
}
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
index 3e434ea..5bd1d49 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
@@ -49,7 +49,7 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type");
predList.emplace_back(pos, builder.getIsNotNull());
- if (auto attr = dyn_cast<pdl::AttributeOp>(val.getDefiningOp())) {
+ if (auto attr = val.getDefiningOp<pdl::AttributeOp>()) {
// If the attribute has a type or value, add a constraint.
if (Value type = attr.getValueType())
getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp b/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp
index e1a9fa59..2d9c661f 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp
@@ -14,9 +14,7 @@
#include "RootOrdering.h"
#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallVector.h"
-#include <queue>
#include <utility>
using namespace mlir;
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 240491a..807be7e 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -582,6 +582,7 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
// block. This should be reconsidered if we allow break/continue in SCF.
rewriter.setInsertionPointToEnd(before);
auto condOp = cast<ConditionOp>(before->getTerminator());
+ SmallVector<Value> args = llvm::to_vector(condOp.getArgs());
rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
after, condOp.getArgs(),
continuation, ValueRange());
@@ -593,7 +594,7 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
// Replace the op with values "yielded" from the "before" region, which are
// visible by dominance.
- rewriter.replaceOp(whileOp, condOp.getArgs());
+ rewriter.replaceOp(whileOp, args);
return success();
}
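
The added `SmallVector` copy is the functional fix in this hunk: `replaceOpWithNewOp<cf::CondBranchOp>` erases `condOp`, so the subsequent `rewriter.replaceOp(whileOp, condOp.getArgs())` would have read the operands of an already-erased op. Snapshotting the arguments into `args` before the branch is created keeps them valid for the final replacement.
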
diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
index f191f35..badd2f6 100644
--- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
+++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
@@ -25,9 +25,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/Support/Debug.h"
#include <optional>
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index aae3271..9b61540 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -1493,11 +1493,11 @@ public:
Value extended;
if (op2TypeWidth < dstTypeWidth) {
if (isUnsignedIntegerOrVector(op2Type)) {
- extended = rewriter.template create<LLVM::ZExtOp>(
- loc, dstType, adaptor.getOperand2());
+ extended =
+ LLVM::ZExtOp::create(rewriter, loc, dstType, adaptor.getOperand2());
} else {
- extended = rewriter.template create<LLVM::SExtOp>(
- loc, dstType, adaptor.getOperand2());
+ extended =
+ LLVM::SExtOp::create(rewriter, loc, dstType, adaptor.getOperand2());
}
} else if (op2TypeWidth == dstTypeWidth) {
extended = adaptor.getOperand2();
@@ -1505,8 +1505,8 @@ public:
return failure();
}
- Value result = rewriter.template create<LLVMOp>(
- loc, dstType, adaptor.getOperand1(), extended);
+ Value result =
+ LLVMOp::create(rewriter, loc, dstType, adaptor.getOperand1(), extended);
rewriter.replaceOp(op, result);
return success();
}
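This hunk is one instance of the mechanical migration applied throughout the patch: `rewriter.create<OpTy>(loc, ...)`, and its dependent-context `rewriter.template create<OpTy>(...)` spelling, become the static `OpTy::create(rewriter, loc, ...)` builder. Besides shortening call sites, the static form needs no `template` keyword inside templated patterns. A minimal sketch:

    // Before: member template; requires `template` in dependent contexts.
    Value a = rewriter.template create<LLVM::ZExtOp>(loc, dstType, in);
    // After: static builder taking the rewriter as its first argument.
    Value b = LLVM::ZExtOp::create(rewriter, loc, dstType, in);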
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 7025c5a..0ff9fb3 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -14,7 +14,6 @@
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
diff --git a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt b/mlir/lib/Conversion/ShardToMPI/CMakeLists.txt
index 15560aa..564f36f 100644
--- a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
+++ b/mlir/lib/Conversion/ShardToMPI/CMakeLists.txt
@@ -1,8 +1,8 @@
-add_mlir_conversion_library(MLIRMeshToMPI
- MeshToMPI.cpp
+add_mlir_conversion_library(MLIRShardToMPI
+ ShardToMPI.cpp
ADDITIONAL_HEADER_DIRS
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MeshToMPI
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ShardToMPI
DEPENDS
MLIRConversionPassIncGen
@@ -17,7 +17,7 @@ add_mlir_conversion_library(MLIRMeshToMPI
MLIRLinalgTransforms
MLIRMemRefDialect
MLIRPass
- MLIRMeshDialect
+ MLIRShardDialect
MLIRMPIDialect
MLIRTransforms
)
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index 63b1fda..fa9e544 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -1,4 +1,4 @@
-//===- MeshToMPI.cpp - Mesh to MPI dialect conversion -----------------===//
+//===- ShardToMPI.cpp - Shard to MPI dialect conversion -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,11 +6,11 @@
//
//===----------------------------------------------------------------------===//
//
-// This file implements a translation of Mesh communication ops tp MPI ops.
+// This file implements a translation of Shard communication ops to MPI ops.
//
//===----------------------------------------------------------------------===//
-#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
+#include "mlir/Conversion/ShardToMPI/ShardToMPI.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -20,11 +20,11 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
-#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Shard/IR/ShardDialect.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
+#include "mlir/Dialect/Shard/Transforms/Simplifications.h"
+#include "mlir/Dialect/Shard/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
@@ -35,16 +35,15 @@
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#define DEBUG_TYPE "mesh-to-mpi"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define DEBUG_TYPE "shard-to-mpi"
namespace mlir {
-#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS
+#define GEN_PASS_DEF_CONVERTSHARDTOMPIPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;
-using namespace mesh;
+using namespace shard;
namespace {
/// Converts a vector of OpFoldResults (ints) into vector of Values of the
@@ -177,9 +176,8 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
auto type = RankedTensorType::get({nSplits, 2}, i64);
Value resHaloSizes =
haloSizes.empty()
- ? rewriter
- .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0},
- i64)
+ ? tensor::EmptyOp::create(rewriter, loc,
+ std::array<int64_t, 2>{0, 0}, i64)
.getResult()
: tensor::FromElementsOp::create(rewriter, loc, type, haloSizes)
.getResult();
@@ -188,18 +186,18 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
// maxSplitSize+1}. Store the offsets in the tensor but set trailing
// elements for smaller split-groups to -1. Computing the max size of the
// split groups requires collectiveProcessGroupSize (which needs the
- // MeshOp)
+ // GridOp)
Value resOffsets;
if (adaptor.getStaticShardedDimsOffsets().empty()) {
resOffsets = tensor::EmptyOp::create(rewriter, loc,
std::array<int64_t, 2>{0, 0}, i64);
} else {
SymbolTableCollection symbolTableCollection;
- auto meshOp = getMesh(op, symbolTableCollection);
+ auto gridOp = getGrid(op, symbolTableCollection);
int64_t maxSplitSize = 0;
for (auto axes : splitAxes) {
int64_t splitSize =
- collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+ collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
assert(splitSize != ShapedType::kDynamic);
maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
}
@@ -218,7 +216,7 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
int64_t curr = 0;
for (auto [i, axes] : llvm::enumerate(splitAxes)) {
int64_t splitSize =
- collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+ collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
++splitSize; // add one for the total size
ArrayRef<Value> values(&offsets[curr], splitSize);
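The resulting layout is easiest to see with a small, purely hypothetical example: with two split groups of sizes 2 and 3, `maxSplitSize` is 3, so the tensor has rows of `maxSplitSize + 1 = 4` elements (the per-shard offsets plus the total size), and the shorter group is padded with -1:

    // Hypothetical {nSplits, maxSplitSize+1} offsets tensor contents:
    //   group 0 (2 shards): { 0, 4, 8, -1 }  // offsets 0 and 4, total 8, pad
    //   group 1 (3 shards): { 0, 3, 6, 9 }   // offsets 0, 3, 6, total 9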
@@ -264,20 +262,20 @@ struct ConvertProcessMultiIndexOp
SymbolTableCollection symbolTableCollection;
Location loc = op.getLoc();
- auto meshOp = getMesh(op, symbolTableCollection);
- // For now we only support static mesh shapes
- if (ShapedType::isDynamicShape(meshOp.getShape()))
+ auto gridOp = getGrid(op, symbolTableCollection);
+ // For now we only support static grid shapes
+ if (ShapedType::isDynamicShape(gridOp.getShape()))
return failure();
SmallVector<Value> dims;
llvm::transform(
- meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
+ gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
return arith::ConstantIndexOp::create(rewriter, loc, i).getResult();
});
- Value rank = ProcessLinearIndexOp::create(rewriter, op.getLoc(), meshOp);
+ Value rank = ProcessLinearIndexOp::create(rewriter, op.getLoc(), gridOp);
auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);
- // optionally extract subset of mesh axes
+ // optionally extract subset of grid axes
auto axes = adaptor.getAxes();
if (!axes.empty()) {
SmallVector<Value> subIndex;
@@ -306,13 +304,11 @@ public:
auto ctx = op.getContext();
Value commWorld =
mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(ctx));
- auto rank =
- rewriter
- .create<mpi::CommRankOp>(
- loc,
- TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
- commWorld)
- .getRank();
+ auto rank = mpi::CommRankOp::create(
+ rewriter, loc,
+ TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
+ commWorld)
+ .getRank();
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
rank);
return success();
@@ -338,12 +334,12 @@ struct ConvertNeighborsLinearIndicesOp
Location loc = op.getLoc();
SymbolTableCollection symbolTableCollection;
- auto meshOp = getMesh(op, symbolTableCollection);
+ auto gridOp = getGrid(op, symbolTableCollection);
auto mIdx = adaptor.getDevice();
auto orgIdx = mIdx[axes[0]];
SmallVector<Value> dims;
llvm::transform(
- meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
+ gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
return arith::ConstantIndexOp::create(rewriter, loc, i).getResult();
});
Value dimSz = dims[axes[0]];
@@ -394,14 +390,14 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
if (!sharding) {
return op->emitError()
- << "Expected SharingOp as defining op for sharding"
+ << "Expected ShardingOp as defining op for sharding"
<< " but found " << adaptor.getSharding()[0].getDefiningOp();
}
// Compute the sharded shape by applying the sharding to the input shape.
// If shardedDimsOffsets is not defined in the sharding, the shard shape is
// computed by dividing the dimension size by the number of shards in that
- // dimension (which is given by the size of the mesh axes provided in
+ // dimension (which is given by the size of the grid axes provided in
// split-axes). Odd elements get distributed to trailing shards. If a
// shardedDimsOffsets is provided, the shard shape is computed by
// subtracting the offset of the current shard from the offset of the next
@@ -431,11 +427,11 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
SmallVector<Value> multiIdx =
getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);
- // Get the MeshOp, the mesh shape is needed to compute the sharded shape.
+ // Get the GridOp, the grid shape is needed to compute the sharded shape.
SymbolTableCollection symbolTableCollection;
- auto meshOp = getMesh(sharding, symbolTableCollection);
- // For now we only support static mesh shapes
- if (ShapedType::isDynamicShape(meshOp.getShape()))
+ auto gridOp = getGrid(sharding, symbolTableCollection);
+ // For now we only support static grid shapes
+ if (ShapedType::isDynamicShape(gridOp.getShape()))
return failure();
auto splitAxes = sharding.getSplitAxes().getAxes();
@@ -455,7 +451,7 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
tmp);
}
- // With static mesh shape the sizes of the split axes are known.
+ // With static grid shape the sizes of the split axes are known.
// Hence the start/pos for each split axis in shardDimsOffsets can be
// computed statically.
int64_t pos = 0;
@@ -475,10 +471,10 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
// Create a value from the static position in shardDimsOffsets.
Value posVal = arith::ConstantOp::create(rewriter, loc,
rewriter.getIndexAttr(pos));
- // Get the index of the local shard in the mesh axis.
+ // Get the index of the local shard in the grid axis.
Value idx = multiIdx[axes[0]];
auto numShards =
- collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+ collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
if (shardedDimsOffs) {
// If sharded dims offsets are provided, use them to compute the
// sharded shape.
@@ -556,13 +552,13 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SymbolTableCollection symbolTableCollection;
- auto mesh = adaptor.getMesh();
- mlir::mesh::MeshOp meshOp = getMesh(op, symbolTableCollection);
- if (!meshOp)
- return op->emitError() << "No mesh found for AllReduceOp";
- if (ShapedType::isDynamicShape(meshOp.getShape()))
+ auto grid = adaptor.getGrid();
+ mlir::shard::GridOp gridOp = getGrid(op, symbolTableCollection);
+ if (!gridOp)
+ return op->emitError() << "No grid found for AllReduceOp";
+ if (ShapedType::isDynamicShape(gridOp.getShape()))
return op->emitError()
- << "Dynamic mesh shape not supported in AllReduceOp";
+ << "Dynamic grid shape not supported in AllReduceOp";
ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
Value input = adaptor.getInput();
@@ -592,27 +588,27 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
linalg::CopyOp::create(iBuilder, input, buffer);
// Get an MPI_Comm_split for the AllReduce operation.
- // The color is the linear index of the process in the mesh along the
- // non-reduced axes. The key is the linear index of the process in the mesh
+ // The color is the linear index of the process in the grid along the
+ // non-reduced axes. The key is the linear index of the process in the grid
// along the reduced axes.
- SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
+ SmallVector<Type> indexResultTypes(gridOp.getShape().size(),
iBuilder.getIndexType());
SmallVector<Value> myMultiIndex =
- ProcessMultiIndexOp::create(iBuilder, indexResultTypes, mesh)
+ ProcessMultiIndexOp::create(iBuilder, indexResultTypes, grid)
.getResult();
Value zero = arith::ConstantIndexOp::create(iBuilder, 0);
SmallVector<Value> multiKey(myMultiIndex.size(), zero);
- auto redAxes = adaptor.getMeshAxes();
+ auto redAxes = adaptor.getGridAxes();
for (auto axis : redAxes) {
multiKey[axis] = myMultiIndex[axis];
myMultiIndex[axis] = zero;
}
Value color =
- createProcessLinearIndex(mesh, myMultiIndex, redAxes, iBuilder);
+ createProcessLinearIndex(grid, myMultiIndex, redAxes, iBuilder);
color = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), color);
- Value key = createProcessLinearIndex(mesh, multiKey, redAxes, iBuilder);
+ Value key = createProcessLinearIndex(grid, multiKey, redAxes, iBuilder);
key = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), key);
// Finally split the communicator
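The color/key construction matches `MPI_Comm_split` semantics: processes that reduce together share a color (their coordinates along the reduced axes are zeroed before linearization), and the key orders each process within its sub-communicator by its position along the reduced axes. A hypothetical illustration for a 2x3 grid reduced along axis 1:

    // Process (r0, r1) on a 2x3 grid, reducing along axis 1:
    //   color = linearize({r0, 0})  // identical across each row
    //   key   = linearize({0, r1})  // rank within the row
    // MPI_Comm_split(comm, color, key, &rowComm); // groups each row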
@@ -698,15 +694,14 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
}
auto rank = cast<ShapedType>(array.getType()).getRank();
auto opSplitAxes = adaptor.getSplitAxes().getAxes();
- auto mesh = adaptor.getMesh();
- auto meshOp = getMesh(op, symbolTableCollection);
+ auto grid = adaptor.getGrid();
+ auto gridOp = getGrid(op, symbolTableCollection);
// subviews need Index values
for (auto &sz : haloSizes) {
if (auto value = dyn_cast<Value>(sz))
- sz =
- rewriter
- .create<arith::IndexCastOp>(loc, rewriter.getIndexType(), value)
- .getResult();
+ sz = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(),
+ value)
+ .getResult();
}
// most of the offset/size/stride data is the same for all dims
@@ -745,10 +740,10 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
auto zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
- SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
+ SmallVector<Type> indexResultTypes(gridOp.getShape().size(),
rewriter.getIndexType());
auto myMultiIndex =
- ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, mesh)
+ ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, grid)
.getResult();
// traverse all split axes from high to low dim
for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
@@ -758,9 +753,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
// Get the linearized ids of the neighbors (down and up) for the
// given split
- auto tmp = rewriter
- .create<NeighborsLinearIndicesOp>(loc, mesh, myMultiIndex,
- splitAxes)
+ auto tmp = NeighborsLinearIndicesOp::create(rewriter, loc, grid,
+ myMultiIndex, splitAxes)
.getResults();
// MPI operates on i32...
Value neighbourIDs[2] = {
@@ -791,7 +785,7 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
: haloSizes[currHaloDim * 2];
// Check if we need to send and/or receive
- // Processes on the mesh borders have only one neighbor
+ // Processes on the grid borders have only one neighbor
auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
auto hasFrom = arith::CmpIOp::create(
@@ -869,8 +863,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
}
};
-struct ConvertMeshToMPIPass
- : public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
+struct ConvertShardToMPIPass
+ : public impl::ConvertShardToMPIPassBase<ConvertShardToMPIPass> {
using Base::Base;
/// Run the dialect converter on the module.
@@ -879,12 +873,12 @@ struct ConvertMeshToMPIPass
RewritePatternSet patterns(ctxt);
ConversionTarget target(getContext());
- // Define a type converter to convert mesh::ShardingType,
+ // Define a type converter to convert shard::ShardingType,
// mostly for use in return operations.
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
- // convert mesh::ShardingType to a tuple of RankedTensorTypes
+ // convert shard::ShardingType to a tuple of RankedTensorTypes
typeConverter.addConversion(
[](ShardingType type,
SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
@@ -920,10 +914,10 @@ struct ConvertMeshToMPIPass
return results;
});
- // No mesh dialect should left after conversion...
- target.addIllegalDialect<mesh::MeshDialect>();
- // ...except the global MeshOp. MeshShapeOp which will get folded later.
- target.addLegalOp<mesh::MeshOp, mesh::MeshShapeOp>();
+ // No shard dialect should be left after conversion...
+ target.addIllegalDialect<shard::ShardDialect>();
+ // ...except the global GridOp and GridShapeOp, which will get folded later.
+ target.addLegalOp<shard::GridOp, shard::GridShapeOp>();
// Allow all the dialects that our patterns will convert to
target.addLegalDialect<
BuiltinDialect, mpi::MPIDialect, scf::SCFDialect, arith::ArithDialect,
@@ -951,7 +945,7 @@ struct ConvertMeshToMPIPass
// Folding patterns cannot be mixed with conversion patterns -> extra pass.
patterns.clear();
SymbolTableCollection symbolTableCollection;
- mlir::mesh::populateFoldingPatterns(patterns, symbolTableCollection);
+ mlir::shard::populateFoldingPatterns(patterns, symbolTableCollection);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};
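The pass therefore runs two separate drivers: a dialect conversion for the shard-to-MPI rewrite, then a greedy rewrite restricted to the folding patterns, since the two kinds of patterns cannot share one driver application. A minimal sketch of that structure, assuming the elided conversion step uses `applyPartialConversion`:

    // Phase 1: dialect conversion against the ConversionTarget.
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      return signalPassFailure();
    // Phase 2: fresh pattern set, greedy folding driver.
    RewritePatternSet foldPatterns(ctxt);
    mlir::shard::populateFoldingPatterns(foldPatterns, symbolTableCollection);
    (void)applyPatternsGreedily(getOperation(), std::move(foldPatterns));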
diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
index f07386e..8cd650e 100644
--- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
@@ -41,6 +41,7 @@ class ConvertTensorToSPIRVPass
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index ec55091..0e3de06 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -22,7 +22,6 @@
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
@@ -570,10 +569,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// to UIToFP.
if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
auto unrealizedCast =
- rewriter
- .create<UnrealizedConversionCastOp>(
- loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
- args[0])
+ UnrealizedConversionCastOp::create(
+ rewriter, loc,
+ rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), args[0])
.getResult(0);
return arith::UIToFPOp::create(rewriter, loc, resultTypes[0],
unrealizedCast);
@@ -869,14 +867,13 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
// Emit 'linalg.generic' op
auto resultTensor =
- opBuilder
- .create<linalg::GenericOp>(
- loc, outputTensor.getType(), operand, outputTensor, affineMaps,
- getNParallelLoopsAttrs(rank),
- [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
- // Emit 'linalg.yield' op
- linalg::YieldOp::create(opBuilder, loc, blockArgs.front());
- })
+ linalg::GenericOp::create(
+ opBuilder, loc, outputTensor.getType(), operand, outputTensor,
+ affineMaps, getNParallelLoopsAttrs(rank),
+ [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
+ // Emit 'linalg.yield' op
+ linalg::YieldOp::create(opBuilder, loc, blockArgs.front());
+ })
.getResult(0);
// Cast to original operand type if necessary
@@ -1156,11 +1153,9 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
inputs.push_back(input);
// First fill the output buffer with the init value.
- auto emptyTensor =
- rewriter
- .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
- dynDims)
- .getResult();
+ auto emptyTensor = tensor::EmptyOp::create(rewriter, loc, reduceShape,
+ resultTy.getElementType(), dynDims)
+ .getResult();
auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
if (!fillValueAttr)
@@ -1168,10 +1163,10 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
op, "No initial value found for reduction operation");
auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
- auto filledTensor = rewriter
- .create<linalg::FillOp>(loc, ValueRange{fillValue},
- ValueRange{emptyTensor})
- .result();
+ auto filledTensor =
+ linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
+ ValueRange{emptyTensor})
+ .result();
outputs.push_back(filledTensor);
bool isNanIgnoreMode = false;
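The empty-then-fill pairing here (and in the hunks that follow) is the standard destination-passing idiom for reductions: `tensor::EmptyOp` materializes the destination shape and `linalg::FillOp` seeds it with the reduction's identity value before it is used as the init operand. Condensed from this hunk, in the new builder style:

    // Create the destination, then fill it with the reduction identity.
    Value empty = tensor::EmptyOp::create(rewriter, loc, reduceShape,
                                          resultTy.getElementType(), dynDims)
                      .getResult();
    Value init = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
    Value filled = linalg::FillOp::create(rewriter, loc, ValueRange{init},
                                          ValueRange{empty})
                       .result();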
@@ -1187,14 +1182,12 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
auto trueAttr = rewriter.getBoolAttr(true);
auto trueValue = arith::ConstantOp::create(rewriter, loc, trueAttr);
auto emptyBoolTensor =
- rewriter
- .create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(),
- dynDims)
+ tensor::EmptyOp::create(rewriter, loc, reduceShape,
+ trueValue.getType(), dynDims)
.getResult();
auto allResultsNaNTensor =
- rewriter
- .create<linalg::FillOp>(loc, ValueRange{trueValue},
- ValueRange{emptyBoolTensor})
+ linalg::FillOp::create(rewriter, loc, ValueRange{trueValue},
+ ValueRange{emptyBoolTensor})
.result();
// Note that because the linalg::ReduceOp has two variadic arguments
// (inputs and outputs) and it has the SameVariadicOperandSize trait we
@@ -1262,22 +1255,19 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr);
auto emptyNanTensor =
- rewriter
- .create<tensor::EmptyOp>(loc, reduceShape,
- resultTy.getElementType(), dynDims)
+ tensor::EmptyOp::create(rewriter, loc, reduceShape,
+ resultTy.getElementType(), dynDims)
.getResult();
auto nanFilledTensor =
- rewriter
- .create<linalg::FillOp>(loc, ValueRange{nanValue},
- ValueRange{emptyNanTensor})
+ linalg::FillOp::create(rewriter, loc, ValueRange{nanValue},
+ ValueRange{emptyNanTensor})
.result();
// Create an empty tensor; no need to fill it since it will be
// overwritten by the select.
auto finalEmptyTensor =
- rewriter
- .create<tensor::EmptyOp>(loc, reduceShape,
- resultTy.getElementType(), dynDims)
+ tensor::EmptyOp::create(rewriter, loc, reduceShape,
+ resultTy.getElementType(), dynDims)
.getResult();
// Do a selection between the tensors akin to:
@@ -1504,12 +1494,11 @@ public:
Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
if (valueTy.isUnsignedInteger()) {
- value = nestedBuilder
- .create<UnrealizedConversionCastOp>(
- nestedLoc,
- nestedBuilder.getIntegerType(
- valueTy.getIntOrFloatBitWidth()),
- value)
+ value = UnrealizedConversionCastOp::create(
+ nestedBuilder, nestedLoc,
+ nestedBuilder.getIntegerType(
+ valueTy.getIntOrFloatBitWidth()),
+ value)
.getResult(0);
}
if (valueTy.getIntOrFloatBitWidth() < 32) {
@@ -1558,9 +1547,8 @@ public:
}
if (outIntType.isUnsignedInteger()) {
- value = nestedBuilder
- .create<UnrealizedConversionCastOp>(nestedLoc,
- outIntType, value)
+ value = UnrealizedConversionCastOp::create(nestedBuilder, nestedLoc,
+ outIntType, value)
.getResult(0);
}
linalg::YieldOp::create(nestedBuilder, loc, value);
@@ -2096,10 +2084,9 @@ public:
Value axisDimSize = tensor::DimOp::create(rewriter, loc, input, axis);
// First fill the output buffer with the init value.
- auto emptyTensor = rewriter
- .create<tensor::EmptyOp>(loc, inputTy.getShape(),
- inputTy.getElementType(),
- ArrayRef<Value>({dynDims}))
+ auto emptyTensor = tensor::EmptyOp::create(
+ rewriter, loc, inputTy.getShape(),
+ inputTy.getElementType(), ArrayRef<Value>({dynDims}))
.getResult();
SmallVector<AffineMap, 2> affineMaps = {
rewriter.getMultiDimIdentityMap(resultTy.getRank())};
@@ -2242,23 +2229,22 @@ public:
}
// First fill the output buffer for the index.
- auto emptyTensorIdx = rewriter
- .create<tensor::EmptyOp>(loc, resultTy.getShape(),
- outElementTy, dynDims)
- .getResult();
+ auto emptyTensorIdx =
+ tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
+ outElementTy, dynDims)
+ .getResult();
auto fillValueIdx = arith::ConstantOp::create(
rewriter, loc, rewriter.getIntegerAttr(outElementTy, 0));
auto filledTensorIdx =
- rewriter
- .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
- ValueRange{emptyTensorIdx})
+ linalg::FillOp::create(rewriter, loc, ValueRange{fillValueIdx},
+ ValueRange{emptyTensorIdx})
.result();
// Second fill the output buffer for the running max.
- auto emptyTensorMax = rewriter
- .create<tensor::EmptyOp>(loc, resultTy.getShape(),
- inElementTy, dynDims)
- .getResult();
+ auto emptyTensorMax =
+ tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), inElementTy,
+ dynDims)
+ .getResult();
auto fillValueMaxAttr =
createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
@@ -2269,9 +2255,8 @@ public:
auto fillValueMax =
arith::ConstantOp::create(rewriter, loc, fillValueMaxAttr);
auto filledTensorMax =
- rewriter
- .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
- ValueRange{emptyTensorMax})
+ linalg::FillOp::create(rewriter, loc, ValueRange{fillValueMax},
+ ValueRange{emptyTensorMax})
.result();
// We need to reduce along the arg-max axis, with parallel operations along
@@ -2372,9 +2357,8 @@ public:
auto loc = op.getLoc();
auto emptyTensor =
- rewriter
- .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
- dynamicDims)
+ tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
+ resultElementTy, dynamicDims)
.getResult();
SmallVector<AffineMap, 2> affineMaps = {
@@ -2449,10 +2433,10 @@ public:
}
}
- auto emptyTensor = rewriter
- .create<tensor::EmptyOp>(loc, resultTy.getShape(),
- resultElementTy, dynDims)
- .getResult();
+ auto emptyTensor =
+ tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
+ resultElementTy, dynDims)
+ .getResult();
SmallVector<AffineMap, 2> affineMaps = {
rewriter.getMultiDimIdentityMap(resultTy.getRank()),
@@ -2586,10 +2570,10 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
tensor::EmptyOp::create(rewriter, loc, type, dynamicSizes);
auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
- auto filledTensor = rewriter
- .create<linalg::FillOp>(loc, ValueRange{fillValue},
- ValueRange{emptyTensor})
- .result();
+ auto filledTensor =
+ linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
+ ValueRange{emptyTensor})
+ .result();
return filledTensor;
}
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 3a20524..da1fb20 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -64,19 +64,20 @@ linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
Value conv, Value result,
ArrayRef<AffineMap> indexingMaps) {
ShapedType resultTy = cast<ShapedType>(conv.getType());
- return rewriter
- .create<linalg::GenericOp>(
- loc, resultTy, ValueRange({bias, conv}), result, indexingMaps,
- getNParallelLoopsAttrs(resultTy.getRank()),
- [](OpBuilder &builder, Location loc, ValueRange args) {
- Value biasVal = args[0];
- Type resType = args[1].getType();
- if (resType != biasVal.getType()) {
- biasVal = arith::ExtSIOp::create(builder, loc, resType, biasVal);
- }
- Value added = arith::AddIOp::create(builder, loc, biasVal, args[1]);
- linalg::YieldOp::create(builder, loc, added);
- })
+ return linalg::GenericOp::create(
+ rewriter, loc, resultTy, ValueRange({bias, conv}), result,
+ indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
+ [](OpBuilder &builder, Location loc, ValueRange args) {
+ Value biasVal = args[0];
+ Type resType = args[1].getType();
+ if (resType != biasVal.getType()) {
+ biasVal =
+ arith::ExtSIOp::create(builder, loc, resType, biasVal);
+ }
+ Value added =
+ arith::AddIOp::create(builder, loc, biasVal, args[1]);
+ linalg::YieldOp::create(builder, loc, added);
+ })
.getResult(0);
}
@@ -124,23 +125,23 @@ static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter,
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
// Build the broadcast-like operation as a linalg.generic.
- return rewriter
- .create<linalg::GenericOp>(
- loc, resultTy, ValueRange({source}), result, indexingMaps,
- getNParallelLoopsAttrs(resultTy.getRank()),
- [&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
- Value biasVal = args[0];
- Type resType = args[1].getType();
- if (resType != biasVal.getType()) {
- biasVal =
- resultTy.getElementType().isFloat()
- ? arith::ExtFOp::create(builder, loc, resType, biasVal)
- .getResult()
- : arith::ExtSIOp::create(builder, loc, resType, biasVal)
- .getResult();
- }
- linalg::YieldOp::create(builder, loc, biasVal);
- })
+ return linalg::GenericOp::create(
+ rewriter, loc, resultTy, ValueRange({source}), result,
+ indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
+ [&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
+ Value biasVal = args[0];
+ Type resType = args[1].getType();
+ if (resType != biasVal.getType()) {
+ biasVal =
+ resultTy.getElementType().isFloat()
+ ? arith::ExtFOp::create(builder, loc, resType, biasVal)
+ .getResult()
+ : arith::ExtSIOp::create(builder, loc, resType,
+ biasVal)
+ .getResult();
+ }
+ linalg::YieldOp::create(builder, loc, biasVal);
+ })
.getResult(0);
}
@@ -397,21 +398,19 @@ public:
auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
auto kZpVal = arith::ConstantOp::create(rewriter, loc, kZp);
- Value conv =
- rewriter
- .create<LinalgConvQOp>(
- loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
- ValueRange{broadcastBias}, strideAttr, dilationAttr)
- ->getResult(0);
+ Value conv = LinalgConvQOp::create(
+ rewriter, loc, resultTy,
+ ValueRange{input, weight, iZpVal, kZpVal},
+ ValueRange{broadcastBias}, strideAttr, dilationAttr)
+ ->getResult(0);
rewriter.replaceOp(op, conv);
return success();
}
- Value conv = rewriter
- .create<LinalgConvOp>(
- loc, accTy, ValueRange{input, weight},
- ValueRange{broadcastBias}, strideAttr, dilationAttr)
+ Value conv = LinalgConvOp::create(
+ rewriter, loc, accTy, ValueRange{input, weight},
+ ValueRange{broadcastBias}, strideAttr, dilationAttr)
->getResult(0);
// We may need to truncate back to the result type if the accumulator was
@@ -529,9 +528,8 @@ public:
Value emptyTensor = tensor::EmptyOp::create(
rewriter, loc, linalgConvTy.getShape(), accETy, filteredDims);
Value zero = arith::ConstantOp::create(rewriter, loc, resultZeroAttr);
- Value zeroTensor = rewriter
- .create<linalg::FillOp>(loc, ValueRange{zero},
- ValueRange{emptyTensor})
+ Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
+ ValueRange{emptyTensor})
.result();
Value biasEmptyTensor = tensor::EmptyOp::create(
@@ -544,10 +542,9 @@ public:
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
if (hasNullZps) {
- Value conv = rewriter
- .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
- loc, linalgConvTy, ValueRange{input, weight},
- ValueRange{zeroTensor}, strideAttr, dilationAttr)
+ Value conv = linalg::DepthwiseConv2DNhwcHwcmOp::create(
+ rewriter, loc, linalgConvTy, ValueRange{input, weight},
+ ValueRange{zeroTensor}, strideAttr, dilationAttr)
.getResult(0);
// We may need to truncate back to the result type if the accumulator was
@@ -565,22 +562,20 @@ public:
rewriter, loc, resultTy, conv, reassociationMap);
Value result =
- rewriter
- .create<linalg::GenericOp>(
- loc, resultTy, ValueRange({bias, convReshape}),
- biasEmptyTensor, indexingMaps,
- getNParallelLoopsAttrs(resultRank),
- [&](OpBuilder &nestedBuilder, Location nestedLoc,
- ValueRange args) {
- Value added;
- if (llvm::isa<FloatType>(inputETy))
- added = arith::AddFOp::create(nestedBuilder, loc, args[0],
- args[1]);
- else
- added = arith::AddIOp::create(nestedBuilder, loc, args[0],
- args[1]);
- linalg::YieldOp::create(nestedBuilder, nestedLoc, added);
- })
+ linalg::GenericOp::create(
+ rewriter, loc, resultTy, ValueRange({bias, convReshape}),
+ biasEmptyTensor, indexingMaps, getNParallelLoopsAttrs(resultRank),
+ [&](OpBuilder &nestedBuilder, Location nestedLoc,
+ ValueRange args) {
+ Value added;
+ if (llvm::isa<FloatType>(inputETy))
+ added = arith::AddFOp::create(nestedBuilder, loc, args[0],
+ args[1]);
+ else
+ added = arith::AddIOp::create(nestedBuilder, loc, args[0],
+ args[1]);
+ linalg::YieldOp::create(nestedBuilder, nestedLoc, added);
+ })
.getResult(0);
rewriter.replaceOp(op, result);
} else {
@@ -588,12 +583,11 @@ public:
IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal);
auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
auto kZpVal = arith::ConstantOp::create(rewriter, loc, wZp);
- Value conv =
- rewriter
- .create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
- loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
- ValueRange{zeroTensor}, strideAttr, dilationAttr)
- .getResult(0);
+ Value conv = linalg::DepthwiseConv2DNhwcHwcmQOp::create(
+ rewriter, loc, linalgConvTy,
+ ValueRange{input, weight, iZpVal, kZpVal},
+ ValueRange{zeroTensor}, strideAttr, dilationAttr)
+ .getResult(0);
SmallVector<ReassociationExprs, 4> reassociationMap;
createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
Value convReshape = tensor::CollapseShapeOp::create(
@@ -639,9 +633,8 @@ public:
auto emptyTensor =
tensor::EmptyOp::create(rewriter, loc, outputTy.getShape(),
outputTy.getElementType(), filteredDims);
- Value zeroTensor = rewriter
- .create<linalg::FillOp>(loc, ValueRange{zero},
- ValueRange{emptyTensor})
+ Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
+ ValueRange{emptyTensor})
.result();
FailureOr<int64_t> maybeAZp = op.getAZeroPoint();
@@ -910,20 +903,18 @@ public:
rewriter, loc, accTy.getShape(), accETy, dynamicDims);
Value filledEmptyTensor =
- rewriter
- .create<linalg::FillOp>(loc, ValueRange{initialValue},
- ValueRange{poolEmptyTensor})
+ linalg::FillOp::create(rewriter, loc, ValueRange{initialValue},
+ ValueRange{poolEmptyTensor})
.result();
Value fakeWindowDims =
tensor::EmptyOp::create(rewriter, loc, kernel, accETy);
// Sum across the pooled region.
- Value poolingOp = rewriter
- .create<linalg::PoolingNhwcSumOp>(
- loc, ArrayRef<Type>{accTy},
- ValueRange{paddedInput, fakeWindowDims},
- filledEmptyTensor, strideAttr, dilationAttr)
+ Value poolingOp = linalg::PoolingNhwcSumOp::create(
+ rewriter, loc, ArrayRef<Type>{accTy},
+ ValueRange{paddedInput, fakeWindowDims},
+ filledEmptyTensor, strideAttr, dilationAttr)
.getResult(0);
// Normalize the summed value by the number of elements grouped in each
@@ -1050,10 +1041,9 @@ public:
Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8);
auto scaled =
- rewriter
- .create<tosa::ApplyScaleOp>(
- loc, rewriter.getI32Type(), poolVal, multiplier, shift,
- rewriter.getStringAttr("SINGLE_ROUND"))
+ tosa::ApplyScaleOp::create(
+ rewriter, loc, rewriter.getI32Type(), poolVal, multiplier,
+ shift, rewriter.getStringAttr("SINGLE_ROUND"))
.getResult();
// If we have quantization information we need to apply output
diff --git a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
index b83f5ec9..f8efb34 100644
--- a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
+++ b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
@@ -13,7 +13,6 @@
#include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 125ea1e..9efa34a 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -255,66 +255,6 @@ struct BroadcastOpToArmSMELowering
}
};
-/// Conversion pattern for vector.splat.
-///
-/// Example:
-///
-/// %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32>
-///
-/// is converted to:
-///
-/// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
-/// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
-/// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
-/// {
-/// %tile_update = arm_sme.insert_tile_slice
-/// %broadcast_to_1d, %iter_tile[%tile_slice_index] :
-/// vector<[4]xi32> into vector<[4]x[4]xi32>
-/// scf.yield %tile_update : vector<[4]x[4]xi32>
-/// }
-///
-/// This is identical to vector.broadcast of a scalar.
-struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
- using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::SplatOp splatOp,
- PatternRewriter &rewriter) const final {
- auto tileType = splatOp.getResult().getType();
- if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
- return failure();
-
- auto loc = splatOp.getLoc();
- auto srcType = splatOp.getOperand().getType();
-
- assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
- // Avoid unused-variable warning when building without assertions.
- (void)srcType;
-
- // First, broadcast the scalar to a 1-d vector.
- VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
- Value broadcastOp1D = vector::BroadcastOp::create(
- rewriter, loc, tileSliceType, splatOp.getInput());
-
- auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
-
- auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
- Value currentTile) {
- auto nextTile = arm_sme::InsertTileSliceOp::create(
- b, loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
- return nextTile.getResult();
- };
-
- // Next, create a loop over ZA tile slices and "move" the generated 1-d
- // vector to each slice.
- auto forOp =
- createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody);
-
- rewriter.replaceOp(splatOp, forOp.getResult(0));
-
- return success();
- }
-};
-
/// Conversion pattern for vector.transpose.
///
/// Stores the input tile to memory and reloads vertically.
@@ -791,11 +731,25 @@ struct ExtractFromCreateMaskToPselLowering
}
};
+// Convert all `vector.splat` to `vector.broadcast`. There is a path from
+// `vector.broadcast` to ArmSME via another pattern.
+struct ConvertSplatToBroadcast : public OpRewritePattern<vector::SplatOp> {
+ using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::SplatOp splatOp,
+ PatternRewriter &rewriter) const final {
+
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
+ splatOp.getInput());
+ return success();
+ }
+};
+
} // namespace
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
- patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
+ patterns.add<BroadcastOpToArmSMELowering, ConvertSplatToBroadcast,
TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 77aab85..1d1904f 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -31,10 +31,9 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/DebugLog.h"
#define DEBUG_TYPE "vector-to-gpu"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define DBGSNL() (llvm::dbgs() << "\n")
namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOGPU
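Alongside the builder migration, this file switches its debug output to `LDBG()` from llvm/Support/DebugLog.h, which replaces the local `DBGS()`/`DBGSNL()` macros: it emits the `[DEBUG_TYPE]` prefix and the trailing newline itself, so call sites drop both the `LLVM_DEBUG(...)` wrapper and the explicit `"\n"`. A minimal sketch:

    // Before: manual macro, manual newline.
    LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
    // After: prefix and newline handled by LDBG().
    LDBG() << "cannot convert op: " << *op;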
@@ -366,7 +365,7 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
// by all operations.
if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
if (!supportsMMaMatrixType(op, useNvGpu)) {
- LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
+ LDBG() << "cannot convert op: " << *op;
return true;
}
return false;
@@ -482,14 +481,12 @@ struct CombineTransferReadOpTranspose final
permutationMap.compose(transferReadOp.getPermutationMap());
auto loc = op.getLoc();
- Value result =
- rewriter
- .create<vector::TransferReadOp>(
- loc, resultType, transferReadOp.getBase(),
- transferReadOp.getIndices(), AffineMapAttr::get(newMap),
- transferReadOp.getPadding(), transferReadOp.getMask(),
- transferReadOp.getInBoundsAttr())
- .getResult();
+ Value result = vector::TransferReadOp::create(
+ rewriter, loc, resultType, transferReadOp.getBase(),
+ transferReadOp.getIndices(), AffineMapAttr::get(newMap),
+ transferReadOp.getPadding(), transferReadOp.getMask(),
+ transferReadOp.getInBoundsAttr())
+ .getResult();
// Fuse through the integer extend op.
if (extOp) {
@@ -550,7 +547,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
std::optional<int64_t> stride =
getStaticallyKnownRowStride(op.getShapedType());
if (!stride.has_value()) {
- LLVM_DEBUG(DBGS() << "no stride\n");
+ LDBG() << "no stride";
return rewriter.notifyMatchFailure(op, "no stride");
}
@@ -585,7 +582,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
isTranspose ? rewriter.getUnitAttr() : UnitAttr());
valueMapping[mappingResult] = load;
- LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
+ LDBG() << "transfer read to: " << load;
return success();
}
@@ -599,13 +596,13 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
std::optional<int64_t> stride =
getStaticallyKnownRowStride(op.getShapedType());
if (!stride.has_value()) {
- LLVM_DEBUG(DBGS() << "no stride\n");
+ LDBG() << "no stride";
return rewriter.notifyMatchFailure(op, "no stride");
}
auto it = valueMapping.find(op.getVector());
if (it == valueMapping.end()) {
- LLVM_DEBUG(DBGS() << "no mapping\n");
+ LDBG() << "no mapping";
return rewriter.notifyMatchFailure(op, "no mapping");
}
@@ -615,9 +612,9 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
(void)store;
- LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");
+ LDBG() << "transfer write to: " << store;
- LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
+ LDBG() << "erase: " << op;
rewriter.eraseOp(op);
return success();
}
@@ -643,21 +640,21 @@ convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
nvgpu::getWarpMatrixInfo(op);
if (failed(warpMatrixInfo)) {
- LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
+ LDBG() << "no warpMatrixInfo";
return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
}
FailureOr<nvgpu::FragmentElementInfo> regInfo =
nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
if (failed(regInfo)) {
- LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
+ LDBG() << "not mma sync reg info";
return rewriter.notifyMatchFailure(op, "not mma sync reg info");
}
VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
if (!dense) {
- LLVM_DEBUG(DBGS() << "not a splat\n");
+ LDBG() << "not a splat";
return rewriter.notifyMatchFailure(op, "not a splat");
}
@@ -679,8 +676,8 @@ static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
mlir::AffineMap map = op.getPermutationMap();
if (map.getNumResults() != 2) {
- LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
- "is not a 2d operand\n");
+ LDBG() << "Failed because the result of `vector.transfer_read` "
+ "is not a 2d operand";
return failure();
}
@@ -693,8 +690,8 @@ static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
auto exprN = dyn_cast<AffineDimExpr>(dN);
if (!exprM || !exprN) {
- LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim "
- "expressions, then transpose cannot be determined.\n");
+ LDBG() << "Failed because expressions are not affine dim "
+ "expressions, then transpose cannot be determined.";
return failure();
}
@@ -711,20 +708,20 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
nvgpu::getWarpMatrixInfo(op);
if (failed(warpMatrixInfo)) {
- LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
+ LDBG() << "no warpMatrixInfo";
return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
}
FailureOr<nvgpu::FragmentElementInfo> regInfo =
nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
if (failed(regInfo)) {
- LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
+ LDBG() << "not mma sync reg info";
return rewriter.notifyMatchFailure(op, "not mma sync reg info");
}
FailureOr<bool> transpose = isTransposed(op);
if (failed(transpose)) {
- LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
+ LDBG() << "failed to determine the transpose";
return rewriter.notifyMatchFailure(
op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
}
@@ -733,10 +730,8 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);
if (failed(params)) {
- LLVM_DEBUG(
- DBGS()
- << "failed to convert vector.transfer_read to ldmatrix. "
- << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
+ LDBG() << "failed to convert vector.transfer_read to ldmatrix. "
+ << "Op should likely not be converted to a nvgpu.ldmatrix call.";
return rewriter.notifyMatchFailure(
op, "failed to convert vector.transfer_read to ldmatrix; this op "
"likely should not be converted to a nvgpu.ldmatrix call.");
@@ -747,7 +742,7 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
FailureOr<AffineMap> offsets =
nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
if (failed(offsets)) {
- LLVM_DEBUG(DBGS() << "no offsets\n");
+ LDBG() << "no offsets";
return rewriter.notifyMatchFailure(op, "no offsets");
}
@@ -936,7 +931,7 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
vector::StoreOp::create(rewriter, loc, el, op.getBase(), newIndices);
}
- LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
+ LDBG() << "erase: " << op;
rewriter.eraseOp(op);
return success();
}
@@ -1134,9 +1129,9 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
loop.getNumResults())))
rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
- LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
- LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
- LLVM_DEBUG(DBGS() << "erase: " << loop);
+ LDBG() << "newLoop now: " << newLoop;
+ LDBG() << "stripped scf.for: " << loop;
+ LDBG() << "erase: " << loop;
rewriter.eraseOp(loop);
return newLoop;
@@ -1152,7 +1147,7 @@ static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
auto it = valueMapping.find(operand.value());
if (it == valueMapping.end()) {
- LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
+ LDBG() << "no value mapping for: " << operand.value();
continue;
}
argMapping.push_back(std::make_pair(
@@ -1170,7 +1165,7 @@ static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
}
- LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
+ LDBG() << "scf.for to: " << newForOp;
return success();
}
@@ -1193,7 +1188,7 @@ convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
}
scf::YieldOp::create(rewriter, op.getLoc(), yieldOperands);
- LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
+ LDBG() << "erase: " << op;
rewriter.eraseOp(op);
return success();
}
@@ -1246,7 +1241,7 @@ LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
auto globalRes = LogicalResult::success();
for (Operation *op : ops) {
- LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
+ LDBG() << "Process op: " << *op;
// Apparently callers do not want to exit early on failure here.
auto res = LogicalResult::success();
if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 9cd491c..17a79e3 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -29,7 +29,9 @@
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/APFloat.h"
+#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/Casting.h"
+
#include <optional>
using namespace mlir;
@@ -1068,39 +1070,6 @@ public:
}
};
-class VectorExtractElementOpConversion
- : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
-public:
- using ConvertOpToLLVMPattern<
- vector::ExtractElementOp>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto vectorType = extractEltOp.getSourceVectorType();
- auto llvmType = typeConverter->convertType(vectorType.getElementType());
-
- // Bail if result type cannot be lowered.
- if (!llvmType)
- return failure();
-
- if (vectorType.getRank() == 0) {
- Location loc = extractEltOp.getLoc();
- auto idxType = rewriter.getIndexType();
- auto zero = LLVM::ConstantOp::create(rewriter, loc,
- typeConverter->convertType(idxType),
- rewriter.getIntegerAttr(idxType, 0));
- rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
- extractEltOp, llvmType, adaptor.getVector(), zero);
- return success();
- }
-
- rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
- extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
- return success();
- }
-};
-
class VectorExtractOpConversion
: public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
@@ -1204,39 +1173,6 @@ public:
}
};
-class VectorInsertElementOpConversion
- : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
-public:
- using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto vectorType = insertEltOp.getDestVectorType();
- auto llvmType = typeConverter->convertType(vectorType);
-
- // Bail if result type cannot be lowered.
- if (!llvmType)
- return failure();
-
- if (vectorType.getRank() == 0) {
- Location loc = insertEltOp.getLoc();
- auto idxType = rewriter.getIndexType();
- auto zero = LLVM::ConstantOp::create(rewriter, loc,
- typeConverter->convertType(idxType),
- rewriter.getIntegerAttr(idxType, 0));
- rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
- insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
- return success();
- }
-
- rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
- insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
- adaptor.getPosition());
- return success();
- }
-};
-
class VectorInsertOpConversion
: public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
@@ -2242,8 +2178,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorGatherOpConversion, VectorScatterOpConversion>(
converter, useVectorAlignment);
patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
- VectorExtractElementOpConversion, VectorExtractOpConversion,
- VectorFMAOp1DConversion, VectorInsertElementOpConversion,
+ VectorExtractOpConversion, VectorFMAOp1DConversion,
VectorInsertOpConversion, VectorPrintOpConversion,
VectorTypeCastOpConversion, VectorScaleOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
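The deleted `VectorExtractElementOpConversion` and `VectorInsertElementOpConversion` patterns (and their SPIR-V counterparts further below) track the deprecation of `vector.extractelement`/`vector.insertelement`; the surviving `vector.extract`/`vector.insert` conversions cover the same cases because those ops also accept dynamic positions. A hedged sketch of the replacement at IR-construction time; treat the exact builder overload taking `OpFoldResult` positions as an assumption:

    // Instead of building vector.extractelement %v[%i], build vector.extract
    // with a dynamic position.
    Value el = vector::ExtractOp::create(rewriter, loc, vec,
                                         ArrayRef<OpFoldResult>{idx});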
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index d3d0a45..cf10869 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -96,13 +96,16 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorGatherLoweringPatterns(patterns);
if (armI8MM) {
if (armNeon)
- arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns);
+ arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
if (armSVE)
- populateLowerContractionToSVEI8MMPatternPatterns(patterns);
+ populateLowerContractionToSVEI8MMPatterns(patterns);
+ }
+ if (armBF16) {
+ if (armNeon)
+ arm_neon::populateLowerContractionToNeonBFMMLAPatterns(patterns);
+ if (armSVE)
+ populateLowerContractionToSVEBFMMLAPatterns(patterns);
}
- if (armBF16)
- populateLowerContractionToSVEBFMMLAPatterns(patterns);
-
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 4c1047a..508f4e2 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -24,7 +24,6 @@
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -691,7 +690,7 @@ struct PrepareTransferWriteConversion
/// %lastIndex = arith.subi %length, %c1 : index
/// vector.print punctuation <open>
/// scf.for %i = %c0 to %length step %c1 {
-/// %el = vector.extractelement %v[%i : index] : vector<[4]xi32>
+/// %el = vector.extract %v[%i] : i32 from vector<[4]xi32>
/// vector.print %el : i32 punctuation <no_punctuation>
/// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index
/// scf.if %notLastIndex {
@@ -1644,7 +1643,7 @@ struct Strategy1d<TransferWriteOp> {
/// Is rewritten to approximately the following pseudo-IR:
/// ```
/// for i = 0 to 9 {
-/// %t = vector.extractelement %vec[i] : vector<9xf32>
+/// %t = vector.extract %vec[i] : f32 from vector<9xf32>
/// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
/// }
/// ```
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 986eae3..a4be7d4 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -335,63 +335,6 @@ struct VectorInsertOpConvert final
}
};
-struct VectorExtractElementOpConvert final
- : public OpConversionPattern<vector::ExtractElementOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Type resultType = getTypeConverter()->convertType(extractOp.getType());
- if (!resultType)
- return failure();
-
- if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
- rewriter.replaceOp(extractOp, adaptor.getVector());
- return success();
- }
-
- APInt cstPos;
- if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
- rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
- extractOp, resultType, adaptor.getVector(),
- rewriter.getI32ArrayAttr({static_cast<int>(cstPos.getSExtValue())}));
- else
- rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
- extractOp, resultType, adaptor.getVector(), adaptor.getPosition());
- return success();
- }
-};
-
-struct VectorInsertElementOpConvert final
- : public OpConversionPattern<vector::InsertElementOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Type vectorType = getTypeConverter()->convertType(insertOp.getType());
- if (!vectorType)
- return failure();
-
- if (isa<spirv::ScalarType>(vectorType)) {
- rewriter.replaceOp(insertOp, adaptor.getSource());
- return success();
- }
-
- APInt cstPos;
- if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
- rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
- insertOp, adaptor.getSource(), adaptor.getDest(),
- cstPos.getSExtValue());
- else
- rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
- insertOp, vectorType, insertOp.getDest(), adaptor.getSource(),
- adaptor.getPosition());
- return success();
- }
-};
-
struct VectorInsertStridedSliceOpConvert final
: public OpConversionPattern<vector::InsertStridedSliceOp> {
using OpConversionPattern::OpConversionPattern;
@@ -1107,12 +1050,11 @@ struct VectorToElementOpConvert final
void mlir::populateVectorToSPIRVPatterns(
const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
patterns.add<
- VectorBitcastConvert, VectorBroadcastConvert,
- VectorExtractElementOpConvert, VectorExtractOpConvert,
+ VectorBitcastConvert, VectorBroadcastConvert, VectorExtractOpConvert,
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
- VectorToElementOpConvert, VectorInsertElementOpConvert,
- VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
+ VectorToElementOpConvert, VectorInsertOpConvert,
+ VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
index 2411af0..4dfcb2b 100644
--- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -10,7 +10,6 @@
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
-#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"