//===- VectorToAMX.cpp - Convert vector to AMX dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToAMX/VectorToAMX.h"

#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/Support/DebugLog.h"

#include <numeric>

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOAMX
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

#define DEBUG_TYPE "vector-to-amx"

namespace {

/// Return true if vector shape is compatible with AMX tiles.
/// The validation accounts for VNNI packing.
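/// For example (illustrative shapes), vector<16x16x2xbf16> is a valid VNNI
/// packed operand: the innermost dimension matches the bf16 VNNI factor
/// (32 / 16 = 2), and each of the 16 rows holds 16 * 2 = 32 bf16 elements,
/// i.e. 64 bytes. In contrast, vector<16x32x2xbf16> exceeds the 64-byte row
/// limit.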
static bool verifyAmxShape(VectorType vec) {
  // Check overall shape:
  //   - 2D for plain layout input or output
  //   - 3D for VNNI packed input
  if (vec.getRank() != 2 && vec.getRank() != 3)
    return false;

  ArrayRef<int64_t> shape = vec.getShape();
  int64_t rows = shape[0];
  int64_t cols = shape[1];
  unsigned elemBitWidth = vec.getElementType().getIntOrFloatBitWidth();

  // 3D shape indicates VNNI packed layout.
  if (vec.getRank() == 3) {
    int64_t vnniFactor = 32 / elemBitWidth;
    if (shape.back() != vnniFactor) {
      LDBG() << "invalid VNNI packing factor";
      return false;
    }
    cols *= vnniFactor;
  }

  // An AMX tile supports up to 16 rows of 64 bytes each.
  constexpr unsigned maxRows = 16;
  constexpr unsigned maxBitsPerRow = 64 * 8;
  return rows <= maxRows && (cols * elemBitWidth) <= maxBitsPerRow;
}

/// Check if contraction operands are in AMX-compatible packed VNNI layout.
static LogicalResult isAmxVnniLayout(PatternRewriter &rewriter,
                                     vector::ContractionOp contractOp) {
  VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
  if (!accType || accType.getRank() != 2)
    return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");

  // Expect 3D inputs for VNNI packed data.
  VectorType lhsType = contractOp.getLhs().getType();
  VectorType rhsType = contractOp.getRhs().getType();
  if (lhsType.getRank() != 3 || rhsType.getRank() != 3)
    return rewriter.notifyMatchFailure(contractOp,
                                       "Expects lhs and rhs 3D vectors");

  // Check if shapes are compatible with AMX tile.
  if (!verifyAmxShape(lhsType) || !verifyAmxShape(rhsType) ||
      !verifyAmxShape(accType))
    return rewriter.notifyMatchFailure(contractOp, "Invalid operand shape");

  // Validate affine maps.
  //
  // Iterators can be ordered arbitrarily. Indexing map positions are based on
  // operands' target shapes.
  // The matrix layouts must match the following:
  //   - matrix A - [M]x[K/vnniFactor]x[vnniFactor]
  //   - matrix B - [K/vnniFactor]x[N]x[vnniFactor]
  //   - matrix C - [M]x[N]
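  //
  // For instance (one possible iterator ordering), with d0 = M, d1 = N,
  // d2 = K / vnniFactor, and d3 = vnniFactor, the expected maps are:
  //   A: (d0, d1, d2, d3) -> (d0, d2, d3)
  //   B: (d0, d1, d2, d3) -> (d2, d1, d3)
  //   C: (d0, d1, d2, d3) -> (d0, d1)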
  SmallVector<AffineMap, 4> indexingMaps = contractOp.getIndexingMapsArray();
  AffineMap mapA = indexingMaps[0];
  AffineMap mapB = indexingMaps[1];
  if (mapA.getNumInputs() != 4 || mapA.getNumResults() != 3 ||
      mapB.getNumResults() != 3)
    return rewriter.notifyMatchFailure(contractOp,
                                       "Invalid input indexing maps");
  FailureOr<linalg::ContractionDimensions> dims =
      linalg::inferContractionDims(indexingMaps);
  if (failed(dims))
    return rewriter.notifyMatchFailure(contractOp,
                                       "Failed to infer contraction dims");
  // Two reduction dimensions are expected:
  //   - one for the K dimension
  //   - one for the VNNI factor
  if (dims->k.size() != 2)
    return rewriter.notifyMatchFailure(contractOp,
                                       "Expected two reduction dims");
  assert(dims->m.size() == 1 && dims->n.size() == 1 &&
         "Invalid parallel contraction dims");

  SmallVector<vector::IteratorType> iteratorTypes =
      contractOp.getIteratorTypesArray();
  // Check VNNI dim maps - the innermost dim for A and B inputs.
  auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(2));
  auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(2));
  if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
      iteratorTypes[vnniDimA.getPosition()] != vector::IteratorType::reduction)
    return rewriter.notifyMatchFailure(contractOp, "Invalid VNNI dim map");
  // Check K dim maps - non-transposed row-major layout.
  auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(1));
  auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(0));
  if (!redDimA || !redDimB || redDimA != redDimB ||
      iteratorTypes[redDimA.getPosition()] != vector::IteratorType::reduction)
    return rewriter.notifyMatchFailure(contractOp, "Invalid K dim map");
  // Check M and N dim maps - map to non-transposed output.
  AffineMap mapC = indexingMaps[2];
  auto mDimC = dyn_cast<AffineDimExpr>(mapC.getResult(0));
  auto nDimC = dyn_cast<AffineDimExpr>(mapC.getResult(1));
  if (!mDimC || !nDimC)
    return rewriter.notifyMatchFailure(contractOp, "Invalid acc maps");
  auto parallelDimA = dyn_cast<AffineDimExpr>(mapA.getResult(0));
  if (!parallelDimA ||
      iteratorTypes[parallelDimA.getPosition()] !=
          vector::IteratorType::parallel ||
      parallelDimA != mDimC)
    return rewriter.notifyMatchFailure(contractOp, "Invalid M dim map");
  auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(1));
  if (!parallelDimB ||
      iteratorTypes[parallelDimB.getPosition()] !=
          vector::IteratorType::parallel ||
      parallelDimB != nDimC)
    return rewriter.notifyMatchFailure(contractOp, "Invalid N dim map");

  return success();
}

/// Validate contraction operands for AMX lowering.
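/// The accepted element type combinations mirror the AMX tile multiplication
/// ops below:
///   - i8 x i8 -> i32 (integer multiplication)
///   - bf16 x bf16 -> f32 and f16 x f16 -> f32 (floating-point multiplication)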
static LogicalResult validateOperands(PatternRewriter &rewriter,
                                      vector::ContractionOp contractOp) {
  VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
  if (!accType)
    return rewriter.notifyMatchFailure(contractOp, "Expects vector acc");

  // Check if operand types are compatible with AMX compute ops.
  bool validElemTypes = false;
  Type lhsElemType = contractOp.getLhs().getType().getElementType();
  Type rhsElemType = contractOp.getRhs().getType().getElementType();
  Type accElemType = accType.getElementType();
  if (accElemType.isInteger(32)) {
    validElemTypes = lhsElemType.isInteger(8) && rhsElemType.isInteger(8);
  } else if (accElemType.isF32()) {
    validElemTypes = (lhsElemType.isF16() && rhsElemType.isF16()) ||
                     (lhsElemType.isBF16() && rhsElemType.isBF16());
  }
  if (!validElemTypes)
    return rewriter.notifyMatchFailure(contractOp,
                                       "Invalid combination of operand types");

  if (failed(isAmxVnniLayout(rewriter, contractOp)))
    return failure();

  return success();
}

/// Collapse the two innermost dimensions together.
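/// For example, memref<2x16x16x2xbf16> collapses into memref<2x16x32xbf16>.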
static TypedValue<MemRefType> collapseLastDim(PatternRewriter &rewriter,
                                              TypedValue<MemRefType> memref) {
  int64_t rank = memref.getType().getRank();
  SmallVector<ReassociationIndices> reassocIndices;
  for (auto i : llvm::seq<int64_t>(0, rank - 2))
    reassocIndices.push_back({i});
  reassocIndices.push_back({rank - 2, rank - 1});
  return memref::CollapseShapeOp::create(rewriter, memref.getLoc(), memref,
                                         reassocIndices);
}

/// Attempt to create an AMX tile load/store operation equivalent to the given
/// vector transfer `xfer` op.
/// This approach allows skipping the longer route through registers and a
/// temporary buffer otherwise required to move data to/from an AMX tile.
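///
/// For example (illustrative shapes and indices), a packed transfer read:
///   %v = vector.transfer_read %src[%i, %j, %c0], %pad
///     : memref<64x32x2xbf16>, vector<16x16x2xbf16>
/// becomes, roughly, a direct tile load from a flattened subview:
///   %view = memref.subview %src[%i, %j, %c0] [16, 16, 2] [1, 1, 1]
///   %flat = memref.collapse_shape %view [[0], [1, 2]]
///   %t = amx.tile_load %flat[%c0, %c0] : ... into !amx.tile<16x32xbf16>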
static Operation *
loadStoreFromTransfer(PatternRewriter &rewriter,
                      VectorTransferOpInterface xferOp, bool isPacked,
                      TypedValue<amx::TileType> tileToStore = nullptr) {
  if (!xferOp || !isa<vector::TransferReadOp, vector::TransferWriteOp>(xferOp))
    return nullptr;
  if (xferOp.hasOutOfBoundsDim() ||
      !xferOp.getPermutationMap().isMinorIdentity())
    return nullptr;

  // Extra checks in case of a write op.
  // Stores must not be packed.
  if (isa<vector::TransferWriteOp>(xferOp) &&
      (!tileToStore || isPacked ||
       tileToStore.getType().getShape() != xferOp.getVectorType().getShape()))
    return nullptr;

  // Check for a memref source buffer.
  // AMX data transfers require at least a 2D shape to correctly
  // infer the stride between rows.
  Value base = xferOp.getBase();
  auto memTy = dyn_cast<MemRefType>(base.getType());
  if (!memTy || memTy.getRank() < 2)
    return nullptr;
  int64_t memRank = memTy.getRank();

  // Check that the source buffer has enough contiguous elements to load a
  // whole AMX tile row.
  //
  // To ensure correctness, the validation is conservative and expects the
  // buffer's innermost dimensions to be statically known, equal to or larger
  // than the vector row length, and equal to the VNNI dimension if applicable.
  //
  // This check could be relaxed to accept more arbitrarily shaped buffers as
  // long as there are enough contiguous elements to load a whole row.
  if (!memTy.areTrailingDimsContiguous(isPacked ? 2 : 1))
    return nullptr;
  VectorType vecTy = xferOp.getVectorType();
  ArrayRef<int64_t> vecShape = vecTy.getShape();
  ArrayRef<int64_t> memShape = memTy.getShape();
  if (memShape.back() == ShapedType::kDynamic ||
      memShape.back() < vecShape.back())
    return nullptr;
  if (isPacked &&
      (memShape.back() != vecShape.back() ||
       memShape[memShape.size() - 2] == ShapedType::kDynamic ||
       memShape[memShape.size() - 2] < vecShape[vecShape.size() - 2]))
    return nullptr;

  // Load values directly from the buffer to an AMX tile.
  PatternRewriter::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(xferOp);
  Location loc = xferOp.getLoc();

  // Create a subview of the source buffer based on the transfer op to resolve
  // offsets.
  SmallVector<OpFoldResult> strides(memRank, rewriter.getIndexAttr(1));
  int64_t vecRank = vecTy.getRank();
  assert(memRank >= vecRank &&
         "Expects buffer rank to be at least the vector rank");
  SmallVector<int64_t> shape(memRank - vecRank, 1);
  shape.append(vecShape.begin(), vecShape.end());
  TypedValue<MemRefType> src =
      memref::SubViewOp::create(
          rewriter, loc, base, getAsOpFoldResult(xferOp.getIndices()),
          getAsOpFoldResult(rewriter.getI64ArrayAttr(shape)), strides)
          .getResult();

  // Collapse the VNNI dimension in case of packing.
  if (isPacked)
    src = collapseLastDim(rewriter, src);
  int64_t rows = vecShape[0];
  int64_t cols = std::accumulate(vecShape.begin() + 1, vecShape.end(),
                                 int64_t{1}, std::multiplies<int64_t>());
  auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());

  Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
  SmallVector<Value> tileIndices(src.getType().getRank(), zeroIndex);

  Operation *amxTileOp = nullptr;
  if (isa<vector::TransferReadOp>(xferOp)) {
    amxTileOp =
        amx::TileLoadOp::create(rewriter, loc, tileType, src, tileIndices);
  } else if (isa<vector::TransferWriteOp>(xferOp)) {
    amxTileOp = amx::TileStoreOp::create(rewriter, loc, src, tileIndices,
                                         tileToStore);
  } else {
    llvm_unreachable("unsupported vector transfer op");
  }

  return amxTileOp;
}

/// Attempt to create an AMX tile load operation equivalent to the given
/// vector transfer `readOp`.
/// Returns the loaded AMX tile if successful.
static FailureOr<TypedValue<amx::TileType>>
loadFromTransfer(PatternRewriter &rewriter, vector::TransferReadOp readOp,
                 bool isPacked) {
  amx::TileLoadOp loadOp = dyn_cast_if_present<amx::TileLoadOp>(
      loadStoreFromTransfer(rewriter, readOp, isPacked));
  if (!loadOp)
    return failure();
  return loadOp.getRes();
}

/// Attempt to create an AMX tile store operation equivalent to the given
/// vector transfer `writeOp`.
static LogicalResult storeFromTransfer(PatternRewriter &rewriter,
                                       vector::TransferWriteOp writeOp,
                                       TypedValue<amx::TileType> tileToStore) {
  return success(loadStoreFromTransfer(rewriter, writeOp, /*isPacked=*/false,
                                       tileToStore));
}

/// Load vector values to an AMX tile.
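///
/// If the vector does not come from a compatible transfer read, the data is
/// routed through a temporary stack buffer instead, roughly (illustrative
/// bf16 shapes):
///   %buf = memref.alloca() : memref<16x16x2xbf16>
///   vector.transfer_write %vec, %buf[%c0, %c0, %c0]
///   %flat = memref.collapse_shape %buf [[0], [1, 2]]
///   %tile = amx.tile_load %flat[%c0, %c0] : ... into !amx.tile<16x32xbf16>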
static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
                                          TypedValue<VectorType> vec) {
  Location loc = vec.getLoc();

  VectorType vecTy = vec.getType();
  bool isPacked = vecTy.getRank() == 3;

  // Try to load tile directly from vector producer's buffer.
  auto readOp = vec.getDefiningOp<vector::TransferReadOp>();
  FailureOr<TypedValue<amx::TileType>> tile =
      loadFromTransfer(rewriter, readOp, isPacked);
  if (succeeded(tile))
    return *tile;

  // Transfer the vector to a tile through an intermediate buffer.
  Value buf = memref::AllocaOp::create(
      rewriter, loc, MemRefType::get(vecTy.getShape(), vecTy.getElementType()));
  Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
  SmallVector<Value> indices(vecTy.getRank(), zeroIndex);
  vector::TransferWriteOp::create(rewriter, loc, vec, buf, indices);

  // Collapse the VNNI dimension in case of packing.
  if (isPacked)
    buf = collapseLastDim(rewriter, cast<TypedValue<MemRefType>>(buf));

  ArrayRef<int64_t> shape = vecTy.getShape();
  int64_t rows = shape[0];
  int64_t cols = std::accumulate(shape.begin() + 1, shape.end(), int64_t{1},
                                 std::multiplies<int64_t>());
  auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());

  return amx::TileLoadOp::create(rewriter, loc, tileType, buf,
                                 {zeroIndex, zeroIndex});
}

/// Store an AMX tile in a vector.
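///
/// This mirrors the buffered path of loadTile, roughly (illustrative f32
/// accumulator shape):
///   %buf = memref.alloca() : memref<16x16xf32>
///   amx.tile_store %buf[%c0, %c0], %tile
///   %vec = vector.transfer_read %buf[%c0, %c0] ... : memref<16x16xf32>, vector<16x16xf32>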
static TypedValue<VectorType> storeTile(PatternRewriter &rewriter,
                                        TypedValue<amx::TileType> tile) {
  Location loc = tile.getLoc();

  // Transfer the tile to a vector through an intermediate buffer.
  amx::TileType tileTy = tile.getType();
  Value buf = memref::AllocaOp::create(
      rewriter, loc,
      MemRefType::get(tileTy.getShape(), tileTy.getElementType()));
  Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
  SmallVector<Value> indices(2, zeroIndex);
  amx::TileStoreOp::create(rewriter, loc, buf, indices, tile);

  auto vecTy = VectorType::get(tileTy.getShape(), tileTy.getElementType());
  return vector::TransferReadOp::create(rewriter, loc, vecTy, buf, indices, {});
}

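/// Lower a vector contraction in VNNI layout to an AMX tile multiplication.
///
/// The resulting sequence is, roughly (illustrative bf16 inputs, f32 acc):
///   %a = amx.tile_load ... : !amx.tile<16x32xbf16>
///   %b = amx.tile_load ... : !amx.tile<16x32xbf16>
///   %c = amx.tile_load ... : !amx.tile<16x16xf32>
///   %d = amx.tile_mulf %a, %b, %c : !amx.tile<16x32xbf16>,
///          !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
/// followed by either a direct amx.tile_store or a round trip through a
/// temporary buffer back into a vector.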
struct ContractionToAMX : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    Location loc = contractOp.getLoc();

    if (contractOp.getKind() != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects add combining kind");
    if (failed(validateOperands(rewriter, contractOp)))
      return failure();

    TypedValue<amx::TileType> lhsTile = loadTile(rewriter, contractOp.getLhs());
    TypedValue<amx::TileType> rhsTile = loadTile(rewriter, contractOp.getRhs());
    auto acc = dyn_cast<TypedValue<VectorType>>(contractOp.getAcc());
    assert(acc && "Invalid accumulator type");
    TypedValue<amx::TileType> accTile = loadTile(rewriter, acc);

    TypedValue<amx::TileType> tileMul;
    if (acc.getType().getElementType().isFloat()) {
      tileMul = amx::TileMulFOp::create(rewriter, loc, accTile.getType(),
                                        lhsTile, rhsTile, accTile);
    } else {
      tileMul = amx::TileMulIOp::create(rewriter, loc, accTile.getType(),
                                        lhsTile, rhsTile, accTile);
    }

    // If the contraction result is only written back to memory, try to replace
    // the vector op with an AMX store directly.
    Value res = contractOp.getResult();
    if (res.hasOneUse()) {
      auto writeOp = dyn_cast<vector::TransferWriteOp>(*res.getUsers().begin());
      LogicalResult storeRes = storeFromTransfer(rewriter, writeOp, tileMul);
      if (succeeded(storeRes)) {
        rewriter.eraseOp(writeOp);
        rewriter.eraseOp(contractOp);
        return success();
      }
    }

    // Load the result back into a vector.
    Value newResult = storeTile(rewriter, tileMul);
    rewriter.replaceOp(contractOp, newResult);

    return success();
  }
};

struct ConvertVectorToAMXPass
    : public impl::ConvertVectorToAMXBase<ConvertVectorToAMXPass> {
  void runOnOperation() override {
    MLIRContext &ctx = getContext();
    RewritePatternSet patterns(&ctx);
    populateVectorToAMXConversionPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateVectorToAMXConversionPatterns(RewritePatternSet &patterns) {
  patterns.add<ContractionToAMX>(patterns.getContext());
}