aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp')
-rw-r--r--mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp214
1 files changed, 109 insertions, 105 deletions
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 21ea444..8a2e3b63 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -45,38 +45,38 @@ static Operation *createLoadTileSliceIntrinsic(
if (layout == arm_sme::TileSliceLayout::Horizontal) {
switch (type) {
case arm_sme::ArmSMETileType::ZAB:
- return rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1b_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAH:
- return rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1h_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAS:
- return rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1w_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAD:
- return rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1d_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAQ:
- return rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1q_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
}
} else {
switch (type) {
case arm_sme::ArmSMETileType::ZAB:
- return rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1b_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAH:
- return rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1h_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAS:
- return rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1w_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAD:
- return rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1d_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAQ:
- return rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_ld1q_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
break;
}
}
@@ -91,38 +91,38 @@ static Operation *createStoreTileSliceIntrinsic(
if (layout == arm_sme::TileSliceLayout::Horizontal) {
switch (type) {
case arm_sme::ArmSMETileType::ZAB:
- return rewriter.create<arm_sme::aarch64_sme_st1b_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1b_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAH:
- return rewriter.create<arm_sme::aarch64_sme_st1h_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1h_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAS:
- return rewriter.create<arm_sme::aarch64_sme_st1w_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1w_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAD:
- return rewriter.create<arm_sme::aarch64_sme_st1d_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1d_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAQ:
- return rewriter.create<arm_sme::aarch64_sme_st1q_horiz>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1q_horiz::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
}
} else {
switch (type) {
case arm_sme::ArmSMETileType::ZAB:
- return rewriter.create<arm_sme::aarch64_sme_st1b_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1b_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAH:
- return rewriter.create<arm_sme::aarch64_sme_st1h_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1h_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAS:
- return rewriter.create<arm_sme::aarch64_sme_st1w_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1w_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAD:
- return rewriter.create<arm_sme::aarch64_sme_st1d_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1d_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
case arm_sme::ArmSMETileType::ZAQ:
- return rewriter.create<arm_sme::aarch64_sme_st1q_vert>(
- loc, maskOp, ptr, tileId, tileSliceI32);
+ return arm_sme::aarch64_sme_st1q_vert::create(rewriter, loc, maskOp, ptr,
+ tileId, tileSliceI32);
}
}
llvm_unreachable("unknown type in createStoreTileSliceIntrinsic");
@@ -146,16 +146,16 @@ createAllocaForTile(RewriterBase &rewriter, Location loc,
// Move to the first operation in the function.
rewriter.setInsertionPointToStart(&func.getBlocks().front());
// Create an alloca matching the tile size of the `tileOp`.
- auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+ auto vscale = vector::VectorScaleOp::create(rewriter, loc);
auto tileElementType = tileOp.getTileType().getElementType();
auto memrefType = MemRefType::get(
{ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType);
unsigned minElements = arm_sme::getSMETileSliceMinNumElts(tileElementType);
auto minElementsOp =
- rewriter.create<arith::ConstantIndexOp>(loc, minElements);
- auto vectorLen = rewriter.create<arith::MulIOp>(loc, vscale, minElementsOp);
- auto alloca = rewriter.create<memref::AllocaOp>(
- loc, memrefType, ValueRange{vectorLen, vectorLen});
+ arith::ConstantIndexOp::create(rewriter, loc, minElements);
+ auto vectorLen = arith::MulIOp::create(rewriter, loc, vscale, minElementsOp);
+ auto alloca = memref::AllocaOp::create(rewriter, loc, memrefType,
+ ValueRange{vectorLen, vectorLen});
return alloca;
}
@@ -293,10 +293,10 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
Value tileMemory, Value sliceIndex) const {
auto llvmType = getTypeConverter()->convertType(tileMemory.getType());
auto descriptor =
- rewriter.create<UnrealizedConversionCastOp>(loc, llvmType, tileMemory);
- auto zero = rewriter.create<arith::ConstantIntOp>(loc, 0, /*width=*/64);
- auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI64Type(), sliceIndex);
+ UnrealizedConversionCastOp::create(rewriter, loc, llvmType, tileMemory);
+ auto zero = arith::ConstantIntOp::create(rewriter, loc, 0, /*width=*/64);
+ auto sliceIndexI64 = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getI64Type(), sliceIndex);
return getStridedElementPtr(
static_cast<ConversionPatternRewriter &>(rewriter), loc,
llvm::cast<MemRefType>(tileMemory.getType()), descriptor.getResult(0),
@@ -309,28 +309,29 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
arm_sme::ArmSMETileType tileType, VectorType sliceType,
IntegerAttr tileId, Value sliceIndex) const {
// Cast the slice index to an i32.
- auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI32Type(), sliceIndex);
+ auto sliceIndexI32 = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getI32Type(), sliceIndex);
// Create an all-true predicate for the slice.
auto predicateType = sliceType.clone(rewriter.getI1Type());
- auto allTruePredicate = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(predicateType, true));
+ auto allTruePredicate = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(predicateType, true));
// Create padding vector (never used due to all-true predicate).
- auto padVector = rewriter.create<LLVM::PoisonOp>(loc, sliceType);
+ auto padVector = LLVM::PoisonOp::create(rewriter, loc, sliceType);
// Get a pointer to the current slice.
auto slicePtr =
getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);
// Read the value of the current slice from ZA.
- auto currentTileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
- loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32);
+ auto currentTileSlice = arm_sme::aarch64_sme_read_horiz::create(
+ rewriter, loc, sliceType, padVector, allTruePredicate, tileId,
+ sliceIndexI32);
// Load the new tile slice back from memory into ZA.
createLoadTileSliceIntrinsic(
rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal,
allTruePredicate, slicePtr, tileId, sliceIndexI32);
// Store the current tile slice to memory.
- auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- rewriter.create<vector::StoreOp>(loc, currentTileSlice, tileAlloca,
- ValueRange{sliceIndex, zero});
+ auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ vector::StoreOp::create(rewriter, loc, currentTileSlice, tileAlloca,
+ ValueRange{sliceIndex, zero});
}
/// Emits a full in-place swap of the contents of a tile in ZA and a
@@ -341,12 +342,14 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
RewriterBase::InsertionGuard guard(rewriter);
// Create an scf.for over all tile slices.
auto minNumElts =
- rewriter.create<arith::ConstantIndexOp>(loc, sliceType.getDimSize(0));
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto upperBound = rewriter.create<arith::MulIOp>(
- loc, minNumElts, rewriter.create<vector::VectorScaleOp>(loc));
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+ arith::ConstantIndexOp::create(rewriter, loc, sliceType.getDimSize(0));
+ auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ auto upperBound =
+ arith::MulIOp::create(rewriter, loc, minNumElts,
+ vector::VectorScaleOp::create(rewriter, loc));
+ auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ auto forOp =
+ scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step);
// Emit a swap for each tile slice.
rewriter.setInsertionPointToStart(forOp.getBody());
auto sliceIndex = forOp.getInductionVar();
@@ -479,8 +482,8 @@ struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
//
// This holds for all tile sizes.
int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt());
- rewriter.create<arm_sme::aarch64_sme_zero>(
- loc, rewriter.getI32IntegerAttr(zeroMask));
+ arm_sme::aarch64_sme_zero::create(rewriter, loc,
+ rewriter.getI32IntegerAttr(zeroMask));
// Create a placeholder op to preserve dataflow.
// Note: Place the `get_tile` op at the start of the block. This ensures
@@ -513,8 +516,8 @@ struct LoadTileSliceConversion
auto tileSlice = loadTileSliceOp.getTileSliceIndex();
// Cast tile slice to i32 for intrinsic.
- auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
- loc, rewriter.getI32Type(), tileSlice);
+ auto tileSliceI32 = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI32Type(), tileSlice);
// Create all active predicate mask.
auto maskOp = loadTileSliceOp.getMask();
@@ -559,8 +562,8 @@ struct StoreTileSliceConversion
auto tileSlice = storeTileSliceOp.getTileSliceIndex();
// Cast tile slice to i32 for intrinsic.
- auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
- loc, rewriter.getI32Type(), tileSlice);
+ auto tileSliceI32 = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI32Type(), tileSlice);
auto maskOp = storeTileSliceOp.getMask();
@@ -595,28 +598,29 @@ struct InsertTileSliceConversion
auto tileSlice = insertTileSliceOp.getTileSliceIndex();
// Cast tile slice from index to i32 for intrinsic.
- auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
- loc, rewriter.getI32Type(), tileSlice);
+ auto tileSliceI32 = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI32Type(), tileSlice);
// Create all active predicate mask.
- auto one = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI1Type(),
+ auto one = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getI1Type(),
rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
/*scalableDims=*/{true});
- auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
+ auto allActiveMask =
+ vector::BroadcastOp::create(rewriter, loc, predTy, one);
// Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
switch (insertTileSliceOp.getLayout()) {
case arm_sme::TileSliceLayout::Horizontal:
- rewriter.create<arm_sme::aarch64_sme_write_horiz>(
- loc, tileId, tileSliceI32, allActiveMask,
- insertTileSliceOp.getVector());
+ arm_sme::aarch64_sme_write_horiz::create(rewriter, loc, tileId,
+ tileSliceI32, allActiveMask,
+ insertTileSliceOp.getVector());
break;
case arm_sme::TileSliceLayout::Vertical:
- rewriter.create<arm_sme::aarch64_sme_write_vert>(
- loc, tileId, tileSliceI32, allActiveMask,
- insertTileSliceOp.getVector());
+ arm_sme::aarch64_sme_write_vert::create(rewriter, loc, tileId,
+ tileSliceI32, allActiveMask,
+ insertTileSliceOp.getVector());
break;
}
@@ -646,16 +650,16 @@ struct ExtractTileSliceConversion
// Create an 'all true' predicate for the tile slice.
auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type());
- auto allTruePredicate = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(predicateType, true));
+ auto allTruePredicate = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(predicateType, true));
// Zero destination/fallback for tile slice extraction.
- auto zeroVector = rewriter.create<arith::ConstantOp>(
- loc, sliceType, rewriter.getZeroAttr(sliceType));
+ auto zeroVector = arith::ConstantOp::create(
+ rewriter, loc, sliceType, rewriter.getZeroAttr(sliceType));
// Cast tile slice from index to i32 for intrinsic.
- auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
- loc, rewriter.getI32Type(), sliceIndex);
+ auto sliceIndexI32 = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getI32Type(), sliceIndex);
// Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice.
switch (extractTileSlice.getLayout()) {
@@ -743,7 +747,7 @@ struct OuterProductOpConversion
Value acc = outerProductOp.getAcc();
if (!acc) {
// Initalize accumulator with zero.
- auto zero = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);
+ auto zero = arm_sme::ZeroOp::create(rewriter, loc, resultVectorType);
zero.setTileId(tileId);
acc = zero;
}
@@ -754,16 +758,16 @@ struct OuterProductOpConversion
if (!lhsMask || !rhsMask) {
auto predTy =
outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type());
- Value allActiveMask = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(predTy, true));
+ Value allActiveMask = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(predTy, true));
lhsMask = allActiveMask;
rhsMask = allActiveMask;
}
// Create 'arm_sme.intr.mopa' outer product intrinsic.
- rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileId, lhsMask, rhsMask,
- outerProductOp.getLhs(),
- outerProductOp.getRhs());
+ arm_sme::aarch64_sme_mopa::create(rewriter, loc, tileId, lhsMask, rhsMask,
+ outerProductOp.getLhs(),
+ outerProductOp.getRhs());
// The outerproduct intrinsics have no result, replace
// 'arm_sme.outerproduct' with the input tile to preserve dataflow.
@@ -792,7 +796,7 @@ struct OuterProductWideningOpConversion
Value acc = op.getAcc();
if (!acc) {
// Initalize accumulator with zero.
- auto zero = rewriter.create<arm_sme::ZeroOp>(loc, op.getResultType());
+ auto zero = arm_sme::ZeroOp::create(rewriter, loc, op.getResultType());
zero.setTileId(tileId);
acc = zero;
}
@@ -801,14 +805,14 @@ struct OuterProductWideningOpConversion
Value rhsMask = op.getRhsMask();
if (!lhsMask || !rhsMask) {
auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type());
- Value allActiveMask = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(predTy, true));
+ Value allActiveMask = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(predTy, true));
lhsMask = allActiveMask;
rhsMask = allActiveMask;
}
- rewriter.create<OuterProductWideningIntrOp>(
- loc, tileId, lhsMask, rhsMask, adaptor.getLhs(), adaptor.getRhs());
+ OuterProductWideningIntrOp::create(rewriter, loc, tileId, lhsMask, rhsMask,
+ adaptor.getLhs(), adaptor.getRhs());
// The outerproduct intrinsics have no result, replace
// 'arm_sme.outerproduct' with the input tile to preserve dataflow.
@@ -843,13 +847,13 @@ struct StreamingVLOpConversion
auto *intrOp = [&]() -> Operation * {
switch (streamingVlOp.getTypeSize()) {
case arm_sme::TypeSize::Byte:
- return rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type);
+ return arm_sme::aarch64_sme_cntsb::create(rewriter, loc, i64Type);
case arm_sme::TypeSize::Half:
- return rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type);
+ return arm_sme::aarch64_sme_cntsh::create(rewriter, loc, i64Type);
case arm_sme::TypeSize::Word:
- return rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type);
+ return arm_sme::aarch64_sme_cntsw::create(rewriter, loc, i64Type);
case arm_sme::TypeSize::Double:
- return rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type);
+ return arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type);
}
llvm_unreachable("unknown type size in StreamingVLOpConversion");
}();
@@ -872,8 +876,8 @@ static void mergeConsecutiveTileZerosInBlock(Block *block) {
if (zeroOpsToMerge.size() <= 1)
return;
IRRewriter rewriter(zeroOpsToMerge.front());
- rewriter.create<arm_sme::aarch64_sme_zero>(
- zeroOpsToMerge.front().getLoc(),
+ arm_sme::aarch64_sme_zero::create(
+ rewriter, zeroOpsToMerge.front().getLoc(),
rewriter.getI32IntegerAttr(mergedZeroMask));
for (auto zeroOp : zeroOpsToMerge)
rewriter.eraseOp(zeroOp);