//===- XeGPUPeepHoleOptimizer.cpp - XeGPU optimize block loads -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/Dialect/XeGPU/uArch/uArchBase.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <optional>

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUPEEPHOLEOPTIMIZER
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-optimize-peephole"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

using namespace mlir;

namespace {

/// Get the 2D lane data from a tensor desc type if it exists.
static std::optional<SmallVector<int64_t>>
getMaybeLaneData(xegpu::TensorDescType tdescType) {
  auto layout = tdescType.getLayoutAttr();
  if (!layout)
    return std::nullopt;
  auto laneData = layout.getEffectiveLaneDataAsInt();
  if (laneData.size() != 2)
    return std::nullopt;
  return laneData;
}

/// Get the 2D lane layout from a tensor desc type if it exists.
static std::optional<SmallVector<int64_t>>
getMaybeLaneLayout(xegpu::TensorDescType tdescType) {
  auto layout = tdescType.getLayoutAttr();
  if (!layout)
    return std::nullopt;
  auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
  if (laneLayout.size() != 2)
    return std::nullopt;
  return laneLayout;
}

/// A layout can be optimized if its lane layout is transposed (lane[0] != 1 &&
/// lane[1] == 1), but its lane data is not [1, 1].
/// Example:
///     !xegpu.tensor_desc<16x16xf16,
///         #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
/// In this case, the lane layout is transposed (from the usual [1, SG_SIZE]
/// form), indicating that this load requires a transpose effect. However, the
/// lane data is [1, 2], meaning that each lane must grab 2 f16 elements from
/// the inner dimension. We convert this to an optimized form by converting the
/// tensor_desc to an i32 element type so that the lane data becomes [1, 1].
/// This allows the later lowering to directly use the load-with-transpose
/// instruction.
static bool canBeOptimizedForTranspose(ArrayRef<int64_t> laneLayout,
                                       ArrayRef<int64_t> laneData) {
  if (laneLayout.size() != 2 || laneData.size() != 2)
    return false;
  if (laneLayout[0] == 1 || laneLayout[1] != 1)
    return false;
  if (laneData[0] != 1 || laneData[1] == 1)
    return false;
  return true;
}

/// A tensor desc type can be optimized if its element type is narrower than
/// 32 bits and its layout can be optimized.
static bool canBeOptimizedForTranspose(xegpu::TensorDescType tdescType) {
  // Data types of 32 bits or wider already have a valid transpose layout, so
  // no conversion is needed.
  int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
  if (elementTyBitwidth >= 32)
    return false;
  auto maybeLaneLayout = getMaybeLaneLayout(tdescType);
  auto maybeLaneData = getMaybeLaneData(tdescType);
  if (!maybeLaneData || !maybeLaneLayout)
    return false;
  return canBeOptimizedForTranspose(*maybeLaneLayout, *maybeLaneData);
}

/// Check if a tensor desc type can be optimized for transpose; if so, return
/// a new optimized tensor desc type with a valid transpose layout. Otherwise,
/// return the original type unchanged.
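///
/// Illustrative sketch (assuming the target's uArch query reports that a
/// 16x8 i32 transposed block load is supported), the type
///     !xegpu.tensor_desc<16x16xf16,
///         #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
/// would be rewritten to
///     !xegpu.tensor_desc<16x8xi32,
///         #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>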
static xegpu::TensorDescType tryOptimize(xegpu::TensorDescType tdescType,
                                         const uArch *targetuArch) {
  if (!canBeOptimizedForTranspose(tdescType))
    return tdescType;
  auto laneData = getMaybeLaneData(tdescType)
                      .value(); // Lane data must exist if we reach here.
  int64_t innerLaneData = laneData[1];
  int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
  // Required shape is total shape of the vector result that this tensor desc
  // must eventually load after adjusting for the new bitwidth and array
  // length.
  SmallVector<int64_t> requiredShape(tdescType.getShape());
  requiredShape.back() =
      requiredShape.back() * tdescType.getArrayLength() / innerLaneData;
  int newBitWidth = elementTyBitwidth * innerLaneData;
  Type newElemTy = IntegerType::get(tdescType.getContext(), newBitWidth);
  // Supported shape is the largest hardware-supported transpose shape that
  // evenly divides the required shape.
  auto *blockLoadTarget = dyn_cast<Subgroup2DBlockLoadInstruction>(
      targetuArch->getInstruction(InstructionKind::Subgroup2DBlockLoad));
  auto maybeHWParams = blockLoadTarget->getBlockWidthHeightCount(
      newElemTy, /** has transform */ false, /** has transpose */ true);
  // If no HW params found, return the original type.
  if (!maybeHWParams)
    return tdescType;
  auto [widths, heights, counts] = maybeHWParams.value();
  // TODO: Currently we expect array length to be 1 for transpose case.
  if (counts.size() != 1 || counts[0] != 1)
    return tdescType;
  int arrayLen = counts[0];
  int supportedHeight =
      xegpu::getLargestDivisor(static_cast<int>(requiredShape[0]), heights);
  int supportedWidth =
      xegpu::getLargestDivisor(static_cast<int>(requiredShape[1]), widths);
  // If no supported height or width found, return the original type.
  if (supportedHeight == -1 || supportedWidth == -1)
    return tdescType;

  SmallVector<int64_t> supportedShape = {supportedHeight, supportedWidth};
  xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
      tdescType.getContext(),
      tdescType.getLayoutAttr().getLaneLayout().asArrayRef(), {1, 1});
  // Array length cannot be larger than 1 for the transpose case.
  return xegpu::TensorDescType::get(supportedShape, newElemTy, arrayLen,
                                    tdescType.getBoundaryCheck(),
                                    tdescType.getMemorySpace(), newLayout);
}

/// Helper to convert an OpFoldResult to Value.
static Value convertToValue(ConversionPatternRewriter &rewriter, Location loc,
                            OpFoldResult ofr) {
  std::optional<int64_t> mayBeInt = getConstantIntValue(ofr);
  if (mayBeInt)
    return arith::ConstantIndexOp::create(rewriter, loc, *mayBeInt).getResult();
  return llvm::cast<Value>(ofr);
}

/// Helper to divide a Value by a constant integer.
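/// For example, dividing by 8 is emitted as a logical right shift by 3, while
/// dividing by a non-power-of-2 constant falls back to an unsigned divide.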
static Value divideByConstant(ConversionPatternRewriter &rewriter, Location loc,
                              Value val, int64_t constant) {
  // If the constant is a power of 2, use right shift for division.
  if (llvm::isPowerOf2_64(constant)) {
    int64_t shiftAmount = llvm::Log2_64(constant);
    return arith::ShRUIOp::create(
               rewriter, loc, val,
               arith::ConstantIndexOp::create(rewriter, loc, shiftAmount)
                   .getResult())
        .getResult();
  }
  auto constantOp =
      arith::ConstantIndexOp::create(rewriter, loc, constant).getResult();
  return arith::DivUIOp::create(rewriter, loc, val, constantOp).getResult();
}

/// This function takes a larger register block `data` and generates multiple
/// smaller loads (size given by `newTensorDesc`) to fill in the `data` block
/// starting from `offsets`.
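///
/// For example, if `data` is a 32x16xi32 block and `newTensorDesc` describes a
/// 16x8xi32 block, a 2x2 grid of loads is generated and each loaded block is
/// inserted into `data` at its corresponding position.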
static Value generateLoads(ConversionPatternRewriter &rewriter,
                           TypedValue<VectorType> data,
                           SmallVector<OpFoldResult> offsets,
                           TypedValue<xegpu::TensorDescType> newTensorDesc,
                           xegpu::LoadNdOp origLoadOp) {
  Location loc = data.getLoc();
  assert(offsets.size() >= 2 && "Expecting at least 2 offsets for 2D LoadNdOp");
  Value offsetDim0 = convertToValue(rewriter, loc, offsets[offsets.size() - 2]);
  Value offsetDim1 = convertToValue(rewriter, loc, offsets[offsets.size() - 1]);
  SmallVector<int64_t> supportedShape(newTensorDesc.getType().getShape());
  // Compute the ratio between original shape and supported shape. We need to
  // generate loads in this ratio arrangement.
  auto shapeRatio = computeShapeRatio(data.getType().getShape(),
                                      supportedShape)
                        .value(); // `ratio` must be defined if we reach here.
  for (int64_t h = 0; h < shapeRatio[0]; ++h) {
    for (int64_t w = 0; w < shapeRatio[1]; ++w) {
      int64_t localOffsetDim0 = h * supportedShape[0];
      int64_t localOffsetDim1 = w * supportedShape[1];
      Value loadOffsetX = arith::AddIOp::create(
          rewriter, loc, offsetDim0,
          arith::ConstantIndexOp::create(rewriter, loc, localOffsetDim0)
              .getResult());
      Value loadOffsetY = arith::AddIOp::create(
          rewriter, loc, offsetDim1,
          arith::ConstantIndexOp::create(rewriter, loc, localOffsetDim1)
              .getResult());
      auto loadOp = xegpu::LoadNdOp::create(
          rewriter, loc,
          VectorType::get(supportedShape, data.getType().getElementType()),
          newTensorDesc, ArrayRef<OpFoldResult>{loadOffsetX, loadOffsetY},
          origLoadOp.getPackedAttr(), origLoadOp.getTransposeAttr(),
          origLoadOp.getL1HintAttr(), origLoadOp.getL2HintAttr(),
          origLoadOp.getL3HintAttr(), origLoadOp.getLayoutAttr());
      // Set the layout for the loadOp.
      auto layoutAttr = newTensorDesc.getType().getLayoutAttr();
      loadOp.setAnchorLayout(layoutAttr);
      // Insert the loaded block into the right position in data.
      auto insertOp = vector::InsertStridedSliceOp::create(
          rewriter, loc, loadOp.getResult(), data,
          ArrayRef<int64_t>{localOffsetDim0, localOffsetDim1},
          ArrayRef<int64_t>{1, 1});
      // InsertOp must have the same layout as newTensorDesc.
      xegpu::setTemporaryLayout(insertOp->getOpResult(0), layoutAttr);
      data = insertOp.getResult();
    }
  }
  return data;
}

/// Checks if a CreateNdDescOp can be optimized for transpose; if so, creates a
/// new CreateNdDescOp with the optimized tensor desc type. This involves
/// extracting the base pointer from the original memory source and adjusting
/// the shape and strides of the tensor desc to fit the new optimized transpose
/// layout.
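///
/// Illustrative sketch (IR syntax elided): a descriptor over a row-major
/// memref<64x64xf16> with lane_data = [1, 2] would be rewritten to use the raw
/// base pointer (extracted as an i64), a 64x32 shape, [32, 1] strides, and an
/// i32 element type, so that lane_data becomes [1, 1].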
class XeGPUCreateNdDescOpPattern final
    : public OpConversionPattern<xegpu::CreateNdDescOp> {
public:
  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp createNdOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto tdescTy = createNdOp.getType();
    // Get the target uArch info.
    auto chipStr = xegpu::getChipStr(createNdOp);
    // Check if the chip is supported.
    assert(
        chipStr && (chipStr.value() == "pvc" || chipStr.value() == "bmg") &&
        "Expecting target chip to be pvc or bmg for transpose optimization.");
    const uArch *targetuArch = xegpu::uArch::getUArch(chipStr.value());

    auto convertType = tryOptimize(tdescTy, targetuArch);
    if (convertType == tdescTy)
      return failure();
    auto strides = createNdOp.getMixedStrides();
    auto maybeConstInnerStride = getConstantIntValue(strides.back());
    // Only row-major memrefs are expected for now.
    if (!maybeConstInnerStride || *maybeConstInnerStride != 1)
      return rewriter.notifyMatchFailure(
          createNdOp, "Expecting row-major memref for transpose optimization.");
    Value source = createNdOp.getSource();
    auto optionalLaneData = getMaybeLaneData(tdescTy);
    assert(optionalLaneData && "Expected 2D lane data");
    auto laneData = optionalLaneData.value();
    int64_t innerLaneData = laneData[1];
    auto memrefType = dyn_cast<MemRefType>(source.getType());
    // Inner dimension of the shape must be adjusted based on innerLaneData.
    SmallVector<OpFoldResult> modifiedShape(createNdOp.getMixedSizes());
    modifiedShape.back() = divideByConstant(
        rewriter, createNdOp.getLoc(),
        convertToValue(rewriter, createNdOp.getLoc(), modifiedShape.back()),
        innerLaneData);
    // Similarly, second to last stride must be adjusted.
    assert(strides.size() >= 2 &&
           "Expected at least 2 strides for CreateNdDescOp");
    SmallVector<OpFoldResult> modifiedStrides(strides);
    modifiedStrides[modifiedStrides.size() - 2] = divideByConstant(
        rewriter, createNdOp.getLoc(),
        convertToValue(rewriter, createNdOp.getLoc(),
                       modifiedStrides[modifiedStrides.size() - 2]),
        innerLaneData);

    // If the source is a static memref, we need to extract the pointer to
    // base address.
    if (memrefType && memrefType.hasStaticShape()) {
      auto extractOp = memref::ExtractAlignedPointerAsIndexOp::create(
          rewriter, createNdOp.getLoc(), source);
      source = arith::IndexCastOp::create(rewriter, createNdOp.getLoc(),
                                          rewriter.getI64Type(),
                                          extractOp.getResult())
                   .getResult();
    }
    // Create a new CreateNdDescOp with the modified shape and converted type.
    auto newCreateNdDescOp = xegpu::CreateNdDescOp::create(
        rewriter, createNdOp.getLoc(), convertType, source, modifiedShape,
        modifiedStrides);
    rewriter.replaceOp(createNdOp, newCreateNdDescOp.getResult());
    return success();
  }
};

/// Checks if a LoadNdOp consumes a tensor desc type that was rewritten for
/// transpose optimization. If so, rewrites the LoadNdOp to align with the
/// adjusted tensor desc type. This can result in multiple LoadNdOps being
/// generated to fill in the original load shape.
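///
/// Illustrative sketch: a load of vector<16x16xf16> from a descriptor that was
/// rewritten to !xegpu.tensor_desc<16x8xi32, ...> becomes one (or more) loads
/// of vector<16x8xi32>, assembled with insert_strided_slice and then bitcast
/// back to vector<16x16xf16>.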
class XeGPULoadNdDescOpPattern final
    : public OpConversionPattern<xegpu::LoadNdOp> {
public:
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp loadNdOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto origTensorDescType = loadNdOp.getTensorDescType();
    auto adaptorType =
        cast<xegpu::TensorDescType>(adaptor.getTensorDesc().getType());
    if (adaptorType == origTensorDescType)
      return failure();
    // Offsets must be adjusted based on innerLaneData.
    auto laneData = getMaybeLaneData(loadNdOp.getTensorDescType()).value();
    int64_t innerLaneData = laneData[1];
    auto offsets = loadNdOp.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(loadNdOp,
                                         "Expecting offsets in LoadNd");
    SmallVector<OpFoldResult> modifiedOffsets(offsets);
    modifiedOffsets.back() = divideByConstant(
        rewriter, loadNdOp.getLoc(),
        convertToValue(rewriter, loadNdOp.getLoc(), modifiedOffsets.back()),
        innerLaneData);
    // Get the 2D data shape (per array slice) of this loadNdOp in its
    // original element type.
    SmallVector<int64_t> origDataShape(origTensorDescType.getShape());
    // Adjust the data shape based on innerLaneData.
    origDataShape.back() /= innerLaneData;
    // HW supported shape is the new tensor desc shape after conversion.
    SmallVector<int64_t> hwSupportedShape(adaptorType.getShape());
    VectorType origVectorType =
        VectorType::get(origDataShape, adaptorType.getElementType());
    Value data;
    // The original load result is 3D when array length > 1; handle each 2D
    // array slice separately.
    if (origTensorDescType.getArrayLength() > 1) {
      SmallVector<Value> arraySlices;
      // Slice offsets are computed relative to the base Y offset so that
      // slice `i` starts at base + i * sliceWidth.
      Value baseOffsetY = convertToValue(rewriter, loadNdOp->getLoc(),
                                         modifiedOffsets.back());
      for (int64_t i = 0; i < origTensorDescType.getArrayLength(); ++i) {
        Value slice = arith::ConstantOp::create(
            rewriter, loadNdOp->getLoc(), origVectorType,
            rewriter.getZeroAttr(origVectorType));
        // Advance the Y offset by one slice width for each array slice.
        modifiedOffsets.back() =
            arith::AddIOp::create(
                rewriter, loadNdOp->getLoc(), baseOffsetY,
                arith::ConstantIndexOp::create(rewriter, loadNdOp->getLoc(),
                                               i * origDataShape[1])
                    .getResult())
                .getResult();
        slice = generateLoads(
            rewriter, cast<TypedValue<VectorType>>(slice), modifiedOffsets,
            cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
            loadNdOp);
        // BitCast back to original load shape without array length.
        auto bitcastType = VectorType::get(origTensorDescType.getShape(),
                                           origTensorDescType.getElementType());
        auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
                                                   bitcastType, slice);
        // BitCastOp must have the same layout as the original loadNdOp.
        xegpu::setTemporaryLayout(bitCastOp->getOpResult(0),
                                  origTensorDescType.getLayoutAttr());
        arraySlices.push_back(bitCastOp.getResult());
      }
      rewriter.replaceOpWithMultiple(loadNdOp, {arraySlices});
      return success();
    }
    data = arith::ConstantOp::create(
        rewriter, loadNdOp->getLoc(),
        VectorType::get(origDataShape, adaptorType.getElementType()),
        rewriter.getZeroAttr(origVectorType));
    data = generateLoads(
        rewriter, cast<TypedValue<VectorType>>(data), modifiedOffsets,
        cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
        loadNdOp);
    auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
                                               loadNdOp.getType(), data);
    // BitCastOp must have the same layout as the original loadNdOp.
    xegpu::setTemporaryLayout(bitCastOp->getOpResult(0),
                              origTensorDescType.getLayoutAttr());
    rewriter.replaceOp(loadNdOp, bitCastOp);
    return success();
  }
};

/// A vector ExtractOp must be processed if the original tensor desc type has
/// an array length greater than 1. In that case, the LoadNdOp is replaced with
/// one LoadNdOp per array slice, making the extraction unnecessary, so we
/// simply replace the ExtractOp with the corresponding slice.
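///
/// For example, `vector.extract %loadResult[1]` over the result of an
/// array-length-2 load is replaced directly by the second per-slice load.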
class VectorExtractOpPattern final
    : public OpConversionPattern<vector::ExtractOp> {
public:
  using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Check if the source of the extraction is split to multiple values.
    if (adaptor.getSource().size() == 1)
      return failure();
    auto mixedPos = extractOp.getMixedPosition();
    if (mixedPos.size() != 1)
      return failure();
    auto mayBeInt = getConstantIntValue(mixedPos[0]);
    if (!mayBeInt)
      return failure();
    rewriter.replaceOp(extractOp, adaptor.getSource()[*mayBeInt]);
    return success();
  }
};

/// Performs a reduction over 2 dimensions by decomposing it into two 1D
/// reductions ordered based on layout to minimize cross-lane communication.
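///
/// For example (illustrative), a reduction over dims [0, 1] of a
/// vector<16x32xf32> distributed with lane_layout = [1, 16] is decomposed into
/// a vector.multi_reduction over dim 0 (intra-lane, no communication) followed
/// by a vector.reduction over the remaining dim (cross-lane).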
class MultiRed2dOpPattern
    : public OpConversionPattern<vector::MultiDimReductionOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(vector::MultiDimReductionOp reductionOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto sourceVecType = reductionOp.getSourceVectorType();
    if (reductionOp.getReductionDims().size() != 2 ||
        sourceVecType.getRank() != 2)
      return rewriter.notifyMatchFailure(
          reductionOp, "Expected 2D multi reduction of a 2D source");
    auto resLayout = xegpu::getDistributeLayoutAttr(reductionOp.getResult());
    // Retrieve and order dims for 1D decomposition (prefer intra-lane first).
    auto dims = llvm::to_vector(reductionOp.getReductionDims());
    auto [intraLaneDim, crossLaneDim] = getReductionDimOrder(dims, resLayout);
    // If no preferred order could be determined, the order does not matter;
    // use the reduction dims as given.
    if (intraLaneDim == -1 || crossLaneDim == -1) {
      intraLaneDim = dims[0];
      crossLaneDim = dims[1];
    }
    auto loc = reductionOp.getLoc();
    auto acc = reductionOp.getAcc();

    // The first (intra-lane) reduction's distribute layout does not include
    // the cross-lane dim.
    auto resSliceLayoutAttr = cast<xegpu::SliceAttr>(resLayout);
    SmallVector<int64_t> dropDims{crossLaneDim};
    auto intraLaneRedResLayout = resSliceLayoutAttr.dropSliceDims(dropDims);

    SmallVector<int64_t> accShape(sourceVecType.getShape());
    accShape.erase(accShape.begin() + intraLaneDim);
    if (acc) {
      acc = vector::BroadcastOp::create(
          rewriter, loc,
          VectorType::get(accShape, sourceVecType.getElementType()), acc);
      xegpu::setDistributeLayoutAttr(
          llvm::dyn_cast<OpResult>(acc),
          cast<xegpu::DistributeLayoutAttr>(intraLaneRedResLayout));
    }
    Value intraLaneReduced = vector::MultiDimReductionOp::create(
        rewriter, loc, reductionOp.getKind(), reductionOp.getSource(), acc,
        ArrayRef<int64_t>(intraLaneDim));
    xegpu::setDistributeLayoutAttr(
        llvm::dyn_cast<OpResult>(intraLaneReduced),
        cast<xegpu::DistributeLayoutAttr>(intraLaneRedResLayout));

    Value crossLaneReduced = vector::ReductionOp::create(
        rewriter, loc, reductionOp.getKind(), intraLaneReduced, nullptr);
    xegpu::setDistributeLayoutAttr(
        llvm::dyn_cast<OpResult>(crossLaneReduced),
        cast<xegpu::DistributeLayoutAttr>(resLayout));
    assert(crossLaneReduced.getType() == reductionOp.getResult().getType() &&
           "Type mismatch");
    rewriter.replaceOp(reductionOp, crossLaneReduced);
    return success();
  }

private:
  std::pair<int64_t, int64_t>
  getReductionDimOrder(ArrayRef<int64_t> reductionDims,
                       xegpu::DistributeLayoutAttr layout) const {
    assert(layout.isForSubgroup() && "Must know the lane layout");
    assert(reductionDims.size() == 2 && "Expected 2D reduction");
    int64_t intra = -1, cross = -1;
    xegpu::LayoutAttr layoutAttr = dyn_cast<xegpu::LayoutAttr>(layout);
    if (auto layoutSliceAttr = dyn_cast<xegpu::SliceAttr>(layout))
      layoutAttr =
          dyn_cast<xegpu::LayoutAttr>(layoutSliceAttr.flatten().getParent());
    assert(layoutAttr);
    SmallVector<int64_t> laneLayout = layoutAttr.getEffectiveLaneLayoutAsInt();

    assert(laneLayout.size() && "Expected a non-empty layout");
    // Try to pick a dim that does not require cross-lane communication.
    for (auto dim : reductionDims) {
      if (laneLayout[dim] == 1)
        intra = dim;
      else
        cross = dim;
    }
    return {intra, cross};
  }
};

} // namespace

void xegpu::populateXeGPUPeepHoleOptimizerPatterns(
    RewritePatternSet &patterns) {
  patterns.add<XeGPUCreateNdDescOpPattern, XeGPULoadNdDescOpPattern,
               VectorExtractOpPattern, MultiRed2dOpPattern>(
      patterns.getContext());
}

namespace {

struct XeGPUPeepHoleOptimizerPass final
    : public xegpu::impl::XeGPUPeepHoleOptimizerBase<
          XeGPUPeepHoleOptimizerPass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    TypeConverter converter;
    RewritePatternSet patterns(&context);
    ConversionTarget target(context);

    // This pass is only meant for PVC and BMG targets. If no supported target
    // is found, exit early.
    bool isTargetSupported = false;
    getOperation()->walk([&](gpu::GPUFuncOp funcOp) {
      auto chipStr = xegpu::getChipStr(funcOp);
      if (chipStr && (chipStr.value() == "pvc" || chipStr.value() == "bmg"))
        isTargetSupported = true;
    });

    if (!isTargetSupported) {
      DBGS() << "XeGPUPeepHoleOptimizerPass only supports PVC and BMG targets."
             << "\n";
      return;
    }

    // CreateNdDescOp and LoadNdOp with optimizable tensor desc types must be
    // converted.
    target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
        [&](xegpu::CreateNdDescOp createNdOp) {
          return !canBeOptimizedForTranspose(createNdOp.getType());
        });
    target.addDynamicallyLegalOp<xegpu::LoadNdOp>(
        [&](xegpu::LoadNdOp loadNdOp) {
          return !canBeOptimizedForTranspose(loadNdOp.getTensorDescType());
        });
    // Vector ExtractOps can have optimizable layouts if they extract from
    // LoadNdOps with array length greater than 1. These ExtractOps must be
    // converted.
    target.addDynamicallyLegalOp<vector::ExtractOp>(
        [&](vector::ExtractOp extractOp) {
          auto layout = xegpu::getTemporaryLayout(
              dyn_cast<OpResult>(extractOp.getResult()));
          if (!layout)
            return true;
          auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
          auto laneData = layout.getEffectiveLaneDataAsInt();
          return !canBeOptimizedForTranspose(laneLayout, laneData);
        });

    target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
        [=](Operation *op) -> bool {
          auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
          if (!layout || !layout.isForSubgroup())
            return true;
          if (auto reductionOp = dyn_cast<vector::MultiDimReductionOp>(op))
            return reductionOp.getReductionDims().size() != 2;
          return true;
        });

    converter.addConversion([](Type type) { return type; });

    target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
                           vector::VectorDialect>();
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    xegpu::populateXeGPUPeepHoleOptimizerPatterns(patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns)))) {
      DBGS() << "Optimize block loads pass failed.\n";
      return signalPassFailure();
    }
  }
};

} // namespace