diff options
Diffstat (limited to 'mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp')
-rw-r--r-- | mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 113 |
1 files changed, 59 insertions, 54 deletions
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp index 458628c..e28d5122 100644 --- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp +++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp @@ -39,7 +39,7 @@ SmallVector<Value, 2> getMemrefIndices(ValueRange indices, unsigned rank, auto tileSliceOffset = tileSliceIndex; auto baseIndexPlusTileSliceOffset = - rewriter.create<arith::AddIOp>(loc, indices[0], tileSliceOffset); + arith::AddIOp::create(rewriter, loc, indices[0], tileSliceOffset); outIndices.push_back(baseIndexPlusTileSliceOffset); outIndices.push_back(indices[1]); @@ -59,10 +59,11 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices( if (memrefIndices.size() != 2) return rewriter.notifyMatchFailure(loc, "invalid number of indices"); - auto minTileSlices = rewriter.create<arith::ConstantIndexOp>( - loc, arm_sme::getSMETileSliceMinNumElts(tileType.getElementType())); + auto minTileSlices = arith::ConstantIndexOp::create( + rewriter, loc, + arm_sme::getSMETileSliceMinNumElts(tileType.getElementType())); auto vscale = - rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType()); + vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType()); auto predicateType = VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true); @@ -70,7 +71,7 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices( // elements in a vector of SVL bits for a given element type (SVL_B, // SVL_H, ..., SVL_Q). auto numTileSlices = - rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale); + arith::MulIOp::create(rewriter, loc, minTileSlices, vscale); Value predicate; Value upperBound; @@ -82,30 +83,30 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices( // The upper bound of the loop must be clamped at `numTileSlices` as // `vector.create_mask` allows operands to be greater than the size of a // dimension. - auto numRowI64 = rewriter.create<arith::IndexCastOp>( - loc, rewriter.getI64Type(), maskDim0); - auto numTileSlicesI64 = rewriter.create<arith::IndexCastOp>( - loc, rewriter.getI64Type(), numTileSlices); + auto numRowI64 = arith::IndexCastOp::create( + rewriter, loc, rewriter.getI64Type(), maskDim0); + auto numTileSlicesI64 = arith::IndexCastOp::create( + rewriter, loc, rewriter.getI64Type(), numTileSlices); auto upperBoundI64 = - rewriter.create<arith::MinSIOp>(loc, numRowI64, numTileSlicesI64); - upperBound = rewriter.create<arith::IndexCastOp>( - loc, rewriter.getIndexType(), upperBoundI64); + arith::MinSIOp::create(rewriter, loc, numRowI64, numTileSlicesI64); + upperBound = arith::IndexCastOp::create( + rewriter, loc, rewriter.getIndexType(), upperBoundI64); predicate = - rewriter.create<vector::CreateMaskOp>(loc, predicateType, maskDim1); + vector::CreateMaskOp::create(rewriter, loc, predicateType, maskDim1); } else { upperBound = numTileSlices; // No mask. Create an 'all true' predicate for the tile slice. - predicate = rewriter.create<arith::ConstantOp>( - loc, DenseElementsAttr::get(predicateType, true)); + predicate = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(predicateType, true)); } bool hasCarriedArgs = bool(initTile); - auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); - auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); - auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step, - hasCarriedArgs ? ValueRange{initTile} - : ValueRange{}); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); + auto forOp = + scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step, + hasCarriedArgs ? ValueRange{initTile} : ValueRange{}); rewriter.setInsertionPointToStart(forOp.getBody()); Value tileSliceIndex = forOp.getInductionVar(); @@ -118,7 +119,7 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices( assert(bool(nextTile) == hasCarriedArgs); if (nextTile) - rewriter.create<scf::YieldOp>(loc, nextTile); + scf::YieldOp::create(rewriter, loc, nextTile); return forOp; } @@ -194,9 +195,9 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> { // Initialize tile with zero to satisfy padding. Inactive cols will be // zeroed anyway since the loads use zeroing predication. For inactive // rows however, no load will occur so these need to be zeroed. - initTile = rewriter.create<arm_sme::ZeroOp>(loc, tileType); + initTile = arm_sme::ZeroOp::create(rewriter, loc, tileType); } else { - initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); + initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType); } // Create a loop to load the active tile slices from memory. @@ -207,9 +208,10 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> { Value currentTile) -> Value { // Create 'arm_sme.load_tile_slice' to load tile slice from memory // into tile. - return rewriter.create<arm_sme::LoadTileSliceOp>( - loc, tileType, tileLoadOp.getBase(), predicate, currentTile, - memrefIndices, tileSliceIndex, tileLoadOp.getLayout()); + return arm_sme::LoadTileSliceOp::create( + rewriter, loc, tileType, tileLoadOp.getBase(), predicate, + currentTile, memrefIndices, tileSliceIndex, + tileLoadOp.getLayout()); }); if (failed(forOp)) @@ -283,22 +285,22 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion auto numRows = createMaskOp.getOperands()[0]; auto numCols = createMaskOp.getOperands()[1]; - auto numColsI32 = rewriter.create<arith::IndexCastUIOp>( - loc, rewriter.getI32Type(), numCols); + auto numColsI32 = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI32Type(), numCols); - auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); + auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType); // Create a loop that loads each ZA tile slice from memory. - auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); - auto minTileSlices = rewriter.create<arith::ConstantIndexOp>( - loc, arm_sme::getSMETileSliceMinNumElts(tileElementType)); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); + auto minTileSlices = arith::ConstantIndexOp::create( + rewriter, loc, arm_sme::getSMETileSliceMinNumElts(tileElementType)); auto vscale = - rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType()); - auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); + vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType()); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); auto numTileSlices = - rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale); - auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, - step, ValueRange{initTile}); + arith::MulIOp::create(rewriter, loc, minTileSlices, vscale); + auto forOp = scf::ForOp::create(rewriter, loc, lowerBound, numTileSlices, + step, ValueRange{initTile}); rewriter.setInsertionPointToStart(forOp.getBody()); @@ -306,17 +308,18 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion auto currentTile = forOp.getRegionIterArg(0); // Combine masks. - auto rowIsActive = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows); - auto rowIsActiveI32 = rewriter.create<arith::ExtSIOp>( - loc, rewriter.getI32Type(), rowIsActive); - auto mask = rewriter.create<arith::AndIOp>(loc, rowIsActiveI32, numColsI32); - auto maskIndex = - rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), mask); + auto rowIsActive = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows); + auto rowIsActiveI32 = arith::ExtSIOp::create( + rewriter, loc, rewriter.getI32Type(), rowIsActive); + auto mask = + arith::AndIOp::create(rewriter, loc, rowIsActiveI32, numColsI32); + auto maskIndex = arith::IndexCastOp::create(rewriter, loc, + rewriter.getIndexType(), mask); auto predicateType = VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true); - auto maskOp1D = rewriter.create<vector::CreateMaskOp>( - loc, predicateType, maskIndex.getResult()); + auto maskOp1D = vector::CreateMaskOp::create(rewriter, loc, predicateType, + maskIndex.getResult()); auto memrefIndices = getMemrefIndices( tileLoadOp.getIndices(), tileLoadOp.getMemRefType().getRank(), @@ -324,17 +327,19 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion // Splat pad into 1-D vector matching type of tile slice. VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); - auto pad1DOp = rewriter.create<vector::SplatOp>(loc, tileSliceType, padOp); + auto pad1DOp = + vector::BroadcastOp::create(rewriter, loc, tileSliceType, padOp); - auto loadSlice = rewriter.create<vector::MaskedLoadOp>( - loc, tileSliceType, tileLoadOp.getBase(), memrefIndices, maskOp1D, - /*passthru=*/pad1DOp); + auto loadSlice = vector::MaskedLoadOp::create(rewriter, loc, tileSliceType, + tileLoadOp.getBase(), + memrefIndices, maskOp1D, + /*passthru=*/pad1DOp); // Create 'arm_sme.insert_tile_slice' to insert slice into tile. - auto insertSlice = rewriter.create<arm_sme::InsertTileSliceOp>( - loc, tileType, loadSlice->getResult(0), currentTile, tileSliceIndex, - tileLoadOp.getLayout()); - rewriter.create<scf::YieldOp>(loc, insertSlice.getResult()); + auto insertSlice = arm_sme::InsertTileSliceOp::create( + rewriter, loc, tileType, loadSlice->getResult(0), currentTile, + tileSliceIndex, tileLoadOp.getLayout()); + scf::YieldOp::create(rewriter, loc, insertSlice.getResult()); rewriter.setInsertionPointAfter(forOp); |