diff options
Diffstat (limited to 'mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp')
-rw-r--r-- | mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 214 |
1 files changed, 109 insertions, 105 deletions
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp index 21ea444..8a2e3b63 100644 --- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp +++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp @@ -45,38 +45,38 @@ static Operation *createLoadTileSliceIntrinsic( if (layout == arm_sme::TileSliceLayout::Horizontal) { switch (type) { case arm_sme::ArmSMETileType::ZAB: - return rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1b_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAH: - return rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1h_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAS: - return rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1w_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAD: - return rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1d_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAQ: - return rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1q_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); } } else { switch (type) { case arm_sme::ArmSMETileType::ZAB: - return rewriter.create<arm_sme::aarch64_sme_ld1b_vert>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1b_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAH: - return rewriter.create<arm_sme::aarch64_sme_ld1h_vert>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1h_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAS: - return rewriter.create<arm_sme::aarch64_sme_ld1w_vert>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1w_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAD: - return rewriter.create<arm_sme::aarch64_sme_ld1d_vert>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1d_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAQ: - return rewriter.create<arm_sme::aarch64_sme_ld1q_vert>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1q_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); break; } } @@ -91,38 +91,38 @@ static Operation *createStoreTileSliceIntrinsic( if (layout == arm_sme::TileSliceLayout::Horizontal) { switch (type) { case arm_sme::ArmSMETileType::ZAB: - return rewriter.create<arm_sme::aarch64_sme_st1b_horiz>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1b_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAH: - return rewriter.create<arm_sme::aarch64_sme_st1h_horiz>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1h_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAS: - return rewriter.create<arm_sme::aarch64_sme_st1w_horiz>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1w_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAD: - return rewriter.create<arm_sme::aarch64_sme_st1d_horiz>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1d_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAQ: - return rewriter.create<arm_sme::aarch64_sme_st1q_horiz>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1q_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); } } else { switch (type) { case arm_sme::ArmSMETileType::ZAB: - return rewriter.create<arm_sme::aarch64_sme_st1b_vert>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1b_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAH: - return rewriter.create<arm_sme::aarch64_sme_st1h_vert>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1h_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAS: - return rewriter.create<arm_sme::aarch64_sme_st1w_vert>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1w_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAD: - return rewriter.create<arm_sme::aarch64_sme_st1d_vert>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1d_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAQ: - return rewriter.create<arm_sme::aarch64_sme_st1q_vert>( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1q_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); } } llvm_unreachable("unknown type in createStoreTileSliceIntrinsic"); @@ -146,16 +146,16 @@ createAllocaForTile(RewriterBase &rewriter, Location loc, // Move to the first operation in the function. rewriter.setInsertionPointToStart(&func.getBlocks().front()); // Create an alloca matching the tile size of the `tileOp`. - auto vscale = rewriter.create<vector::VectorScaleOp>(loc); + auto vscale = vector::VectorScaleOp::create(rewriter, loc); auto tileElementType = tileOp.getTileType().getElementType(); auto memrefType = MemRefType::get( {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType); unsigned minElements = arm_sme::getSMETileSliceMinNumElts(tileElementType); auto minElementsOp = - rewriter.create<arith::ConstantIndexOp>(loc, minElements); - auto vectorLen = rewriter.create<arith::MulIOp>(loc, vscale, minElementsOp); - auto alloca = rewriter.create<memref::AllocaOp>( - loc, memrefType, ValueRange{vectorLen, vectorLen}); + arith::ConstantIndexOp::create(rewriter, loc, minElements); + auto vectorLen = arith::MulIOp::create(rewriter, loc, vscale, minElementsOp); + auto alloca = memref::AllocaOp::create(rewriter, loc, memrefType, + ValueRange{vectorLen, vectorLen}); return alloca; } @@ -293,10 +293,10 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { Value tileMemory, Value sliceIndex) const { auto llvmType = getTypeConverter()->convertType(tileMemory.getType()); auto descriptor = - rewriter.create<UnrealizedConversionCastOp>(loc, llvmType, tileMemory); - auto zero = rewriter.create<arith::ConstantIntOp>(loc, 0, /*width=*/64); - auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>( - loc, rewriter.getI64Type(), sliceIndex); + UnrealizedConversionCastOp::create(rewriter, loc, llvmType, tileMemory); + auto zero = arith::ConstantIntOp::create(rewriter, loc, 0, /*width=*/64); + auto sliceIndexI64 = arith::IndexCastOp::create( + rewriter, loc, rewriter.getI64Type(), sliceIndex); return getStridedElementPtr( static_cast<ConversionPatternRewriter &>(rewriter), loc, llvm::cast<MemRefType>(tileMemory.getType()), descriptor.getResult(0), @@ -309,28 +309,29 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { arm_sme::ArmSMETileType tileType, VectorType sliceType, IntegerAttr tileId, Value sliceIndex) const { // Cast the slice index to an i32. - auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>( - loc, rewriter.getI32Type(), sliceIndex); + auto sliceIndexI32 = arith::IndexCastOp::create( + rewriter, loc, rewriter.getI32Type(), sliceIndex); // Create an all-true predicate for the slice. auto predicateType = sliceType.clone(rewriter.getI1Type()); - auto allTruePredicate = rewriter.create<arith::ConstantOp>( - loc, DenseElementsAttr::get(predicateType, true)); + auto allTruePredicate = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(predicateType, true)); // Create padding vector (never used due to all-true predicate). - auto padVector = rewriter.create<LLVM::PoisonOp>(loc, sliceType); + auto padVector = LLVM::PoisonOp::create(rewriter, loc, sliceType); // Get a pointer to the current slice. auto slicePtr = getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex); // Read the value of the current slice from ZA. - auto currentTileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>( - loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32); + auto currentTileSlice = arm_sme::aarch64_sme_read_horiz::create( + rewriter, loc, sliceType, padVector, allTruePredicate, tileId, + sliceIndexI32); // Load the new tile slice back from memory into ZA. createLoadTileSliceIntrinsic( rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal, allTruePredicate, slicePtr, tileId, sliceIndexI32); // Store the current tile slice to memory. - auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); - rewriter.create<vector::StoreOp>(loc, currentTileSlice, tileAlloca, - ValueRange{sliceIndex, zero}); + auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + vector::StoreOp::create(rewriter, loc, currentTileSlice, tileAlloca, + ValueRange{sliceIndex, zero}); } /// Emits a full in-place swap of the contents of a tile in ZA and a @@ -341,12 +342,14 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { RewriterBase::InsertionGuard guard(rewriter); // Create an scf.for over all tile slices. auto minNumElts = - rewriter.create<arith::ConstantIndexOp>(loc, sliceType.getDimSize(0)); - auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); - auto upperBound = rewriter.create<arith::MulIOp>( - loc, minNumElts, rewriter.create<vector::VectorScaleOp>(loc)); - auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); - auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step); + arith::ConstantIndexOp::create(rewriter, loc, sliceType.getDimSize(0)); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); + auto upperBound = + arith::MulIOp::create(rewriter, loc, minNumElts, + vector::VectorScaleOp::create(rewriter, loc)); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); + auto forOp = + scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step); // Emit a swap for each tile slice. rewriter.setInsertionPointToStart(forOp.getBody()); auto sliceIndex = forOp.getInductionVar(); @@ -479,8 +482,8 @@ struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> { // // This holds for all tile sizes. int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt()); - rewriter.create<arm_sme::aarch64_sme_zero>( - loc, rewriter.getI32IntegerAttr(zeroMask)); + arm_sme::aarch64_sme_zero::create(rewriter, loc, + rewriter.getI32IntegerAttr(zeroMask)); // Create a placeholder op to preserve dataflow. // Note: Place the `get_tile` op at the start of the block. This ensures @@ -513,8 +516,8 @@ struct LoadTileSliceConversion auto tileSlice = loadTileSliceOp.getTileSliceIndex(); // Cast tile slice to i32 for intrinsic. - auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>( - loc, rewriter.getI32Type(), tileSlice); + auto tileSliceI32 = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI32Type(), tileSlice); // Create all active predicate mask. auto maskOp = loadTileSliceOp.getMask(); @@ -559,8 +562,8 @@ struct StoreTileSliceConversion auto tileSlice = storeTileSliceOp.getTileSliceIndex(); // Cast tile slice to i32 for intrinsic. - auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>( - loc, rewriter.getI32Type(), tileSlice); + auto tileSliceI32 = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI32Type(), tileSlice); auto maskOp = storeTileSliceOp.getMask(); @@ -595,28 +598,29 @@ struct InsertTileSliceConversion auto tileSlice = insertTileSliceOp.getTileSliceIndex(); // Cast tile slice from index to i32 for intrinsic. - auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>( - loc, rewriter.getI32Type(), tileSlice); + auto tileSliceI32 = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI32Type(), tileSlice); // Create all active predicate mask. - auto one = rewriter.create<arith::ConstantOp>( - loc, rewriter.getI1Type(), + auto one = arith::ConstantOp::create( + rewriter, loc, rewriter.getI1Type(), rewriter.getIntegerAttr(rewriter.getI1Type(), 1)); auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(), /*scalableDims=*/{true}); - auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one); + auto allActiveMask = + vector::BroadcastOp::create(rewriter, loc, predTy, one); // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice. switch (insertTileSliceOp.getLayout()) { case arm_sme::TileSliceLayout::Horizontal: - rewriter.create<arm_sme::aarch64_sme_write_horiz>( - loc, tileId, tileSliceI32, allActiveMask, - insertTileSliceOp.getVector()); + arm_sme::aarch64_sme_write_horiz::create(rewriter, loc, tileId, + tileSliceI32, allActiveMask, + insertTileSliceOp.getVector()); break; case arm_sme::TileSliceLayout::Vertical: - rewriter.create<arm_sme::aarch64_sme_write_vert>( - loc, tileId, tileSliceI32, allActiveMask, - insertTileSliceOp.getVector()); + arm_sme::aarch64_sme_write_vert::create(rewriter, loc, tileId, + tileSliceI32, allActiveMask, + insertTileSliceOp.getVector()); break; } @@ -646,16 +650,16 @@ struct ExtractTileSliceConversion // Create an 'all true' predicate for the tile slice. auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type()); - auto allTruePredicate = rewriter.create<arith::ConstantOp>( - loc, DenseElementsAttr::get(predicateType, true)); + auto allTruePredicate = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(predicateType, true)); // Zero destination/fallback for tile slice extraction. - auto zeroVector = rewriter.create<arith::ConstantOp>( - loc, sliceType, rewriter.getZeroAttr(sliceType)); + auto zeroVector = arith::ConstantOp::create( + rewriter, loc, sliceType, rewriter.getZeroAttr(sliceType)); // Cast tile slice from index to i32 for intrinsic. - auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>( - loc, rewriter.getI32Type(), sliceIndex); + auto sliceIndexI32 = arith::IndexCastOp::create( + rewriter, loc, rewriter.getI32Type(), sliceIndex); // Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice. switch (extractTileSlice.getLayout()) { @@ -743,7 +747,7 @@ struct OuterProductOpConversion Value acc = outerProductOp.getAcc(); if (!acc) { // Initalize accumulator with zero. - auto zero = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType); + auto zero = arm_sme::ZeroOp::create(rewriter, loc, resultVectorType); zero.setTileId(tileId); acc = zero; } @@ -754,16 +758,16 @@ struct OuterProductOpConversion if (!lhsMask || !rhsMask) { auto predTy = outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type()); - Value allActiveMask = rewriter.create<arith::ConstantOp>( - loc, DenseElementsAttr::get(predTy, true)); + Value allActiveMask = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(predTy, true)); lhsMask = allActiveMask; rhsMask = allActiveMask; } // Create 'arm_sme.intr.mopa' outer product intrinsic. - rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileId, lhsMask, rhsMask, - outerProductOp.getLhs(), - outerProductOp.getRhs()); + arm_sme::aarch64_sme_mopa::create(rewriter, loc, tileId, lhsMask, rhsMask, + outerProductOp.getLhs(), + outerProductOp.getRhs()); // The outerproduct intrinsics have no result, replace // 'arm_sme.outerproduct' with the input tile to preserve dataflow. @@ -792,7 +796,7 @@ struct OuterProductWideningOpConversion Value acc = op.getAcc(); if (!acc) { // Initalize accumulator with zero. - auto zero = rewriter.create<arm_sme::ZeroOp>(loc, op.getResultType()); + auto zero = arm_sme::ZeroOp::create(rewriter, loc, op.getResultType()); zero.setTileId(tileId); acc = zero; } @@ -801,14 +805,14 @@ struct OuterProductWideningOpConversion Value rhsMask = op.getRhsMask(); if (!lhsMask || !rhsMask) { auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type()); - Value allActiveMask = rewriter.create<arith::ConstantOp>( - loc, DenseElementsAttr::get(predTy, true)); + Value allActiveMask = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(predTy, true)); lhsMask = allActiveMask; rhsMask = allActiveMask; } - rewriter.create<OuterProductWideningIntrOp>( - loc, tileId, lhsMask, rhsMask, adaptor.getLhs(), adaptor.getRhs()); + OuterProductWideningIntrOp::create(rewriter, loc, tileId, lhsMask, rhsMask, + adaptor.getLhs(), adaptor.getRhs()); // The outerproduct intrinsics have no result, replace // 'arm_sme.outerproduct' with the input tile to preserve dataflow. @@ -843,13 +847,13 @@ struct StreamingVLOpConversion auto *intrOp = [&]() -> Operation * { switch (streamingVlOp.getTypeSize()) { case arm_sme::TypeSize::Byte: - return rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type); + return arm_sme::aarch64_sme_cntsb::create(rewriter, loc, i64Type); case arm_sme::TypeSize::Half: - return rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type); + return arm_sme::aarch64_sme_cntsh::create(rewriter, loc, i64Type); case arm_sme::TypeSize::Word: - return rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type); + return arm_sme::aarch64_sme_cntsw::create(rewriter, loc, i64Type); case arm_sme::TypeSize::Double: - return rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type); + return arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type); } llvm_unreachable("unknown type size in StreamingVLOpConversion"); }(); @@ -872,8 +876,8 @@ static void mergeConsecutiveTileZerosInBlock(Block *block) { if (zeroOpsToMerge.size() <= 1) return; IRRewriter rewriter(zeroOpsToMerge.front()); - rewriter.create<arm_sme::aarch64_sme_zero>( - zeroOpsToMerge.front().getLoc(), + arm_sme::aarch64_sme_zero::create( + rewriter, zeroOpsToMerge.front().getLoc(), rewriter.getI32IntegerAttr(mergedZeroMask)); for (auto zeroOp : zeroOpsToMerge) rewriter.eraseOp(zeroOp); |