aboutsummaryrefslogtreecommitdiff
path: root/mlir/docs/Tutorials/transform/ChH.md
blob: 7c12728ac32466de35e92428fd8392da23b9664a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
# Chapter H: Reproducing Halide Schedule

This chapter demonstrates how a schedule from the [Halide
DSL](http://halide-lang.org) can be implemented using transform dialect for
structured ops.

Note that the IR below is pseudo-code with types removed for brevity. It may
also get out of sync with the current syntax. Always refer to the source code in
[mlir/examples/transform/ChH](https://github.com/llvm/llvm-project/tree/main/mlir/test/Examples/transform/ChH)
as the source of truth.

## Channeled Convolution

The Transform dialect provides a substrate for implementing “transformation
directive” domain-specific languages (DSLs) in MLIR. Such a DSL, at least in its
scheduling part, can target the operations in the Transform dialect that are
later applied by the compiler. Sets of transform operations, or even new
dialects leveraging the same interfaces and infrastructure, can be added to
support a specific DSL for a particular scheduling model. In this chapter, we
will revisit the Halide DSL that has (re)popularized separate specification of
schedules originally for image processing programs.

Two approaches Halide to the Transform dialect are possible:

*   Create a new dialect that corresponds to the computational part of Halide
    DSL, and define a set of transformations wrapped into Transform dialect
    operations, that correspond to the scheduling part of the DSL.
*   Map the Halide abstractions to the existing MLIR abstractions, for both
    parts of the DSL.

We will consider the latter approach as the computational part of the DSL easily
maps to the structured ops in the Linalg dialect. This also gives us the
opportunity to discuss how Linalg transformations on the so-called structured
operations are similar to or different from the existing transformations.

We will consider the 2D channeled convolution example extracted from Halide
[application
examples](https://github.com/halide/Halide/tree/294f80c49bf3bb8582446613c25fcce03b82bcd8/apps/conv_layer).

```cpp
// Sizes of the problem.
const int N = 5, CI = 128, CO = 128, W = 100, H = 80;

// Sized inputs. Note that the order of dimensions is
// inverted in Halide with respect to C++, so the last dimension
// in the list (N for input, CI for filter) is the least
// frequently varying. The C++ equivalent is input[N][H+2][W+2][CI].
Buffer<float, 4> input({CI, W+2, H+2, N}, "input");
Buffer<float, 4> filter({CO, 3, 3, CI}, "filter");
Buffer<float, 1> bias(std::vector<int>{CO}, "bias");

// ... data initialization happens here ...

// Declarations of "mathematical functions" for convolution and relu.
Func conv("conv"), relu("relu");

// Iterators/subscripts.
Var x("x"), y("y"), c("c"), n("n");

// 3D reduction domain (channels and 2 window dimensions),
// dimensions are later referred to as r.x, r.y, r.z.
RDom r(0, CI, 0, 3, 0, 3);

// Core convolution with the result initialized to the bias value.
// Note that the order of iterators is inverted in Halide DSL,
// i.e. `n` corresponds to the lest frequently-varying (outermost) dimension
// here and below.
conv(c, x, y, n) = bias(c);
conv(c, x, y, n) += filter(c, r.y, r.z, r.x) * input(r.x, x + r.y, y + r.z, n);

// ReLU rectification, an elementwise operation.
relu(c, x, y, n) = max(0, conv(c, x, y, n));
```

This can be almost directly converted to Linalg dialect operating on tensors,
which is conceptually closer to the “mathematical function” abstraction and is
where the majority of transformations are available.

```mlir
// Bias. Using a named Linalg operation for brevity.
%bias_init = tensor.empty() : !toutput
%biased = linalg.broadcast ins(%bias : !tbias)
                          outs(%bias_init : !toutput) dimensions = [0, 1, 2]

// Convolution proper. While Linalg has named operations for 2D convolutions,
// the one in the Halide example has an uncommon order of filter dimensions
// and is not supported. It also takes the filter as first argument. This
// code recreates it faithfully using the generic form.
%convolved = linalg.generic {
  iterator_types = ["parallel", "parallel", "parallel", "parallel",
                    "reduction", "reduction", "reduction"],
  indexing_maps = [
    affine_map<(n, y, x, c, rz, ry, rx) -> (rx, rz, ry, c)>,
    affine_map<(n, y, x, c, rz, ry, rx) -> (n, y+rz, x+ry, rx)>,
    affine_map<(n, y, x, c, rz, ry, rx) -> (n, y, x, c)>
  ]
} ins(%filter, %input: !tfilter, !tinput)
  outs(%biased : !toutput) {
^bb0(%in: f32, %f: f32, %b: f32):
  // Note the fastmath attributes that allow operations to be recombined into
  //   %0 = math.fma %in, %f, %b : f32
  // later on and to reorder reductions.
  %m1 = arith.mulf %in, %f  {fastmath = #arith.fastmath<fast>} : f32
  %0 = arith.addf %b, %m1  {fastmath = #arith.fastmath<fast>} : f32
  linalg.yield %0 : f32
} -> !toutput

// ReLU is just a max(0, x).
%c0 = arith.constant 0.0 : f32
%relued = linalg.generic {
  iterator_types = ["parallel", "parallel", "parallel", "parallel"],
  indexing_maps = [
    affine_map<(d0, d1, d2, d3) -> ()>,
    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
  ]
} ins(%c0, %convolved : f32, !toutput)
  outs(%output : !toutput) {
^bb0(%cst: f32, %in: f32, %out: f32):
  %0 = llvm.intr.maxnum(%cst, %in) : (f32, f32) -> f32
  linalg.yield %0 : f32
} -> !toutput
```

In Halide, a function such as `conv` may consist of two parts: a “functional”
initialization computation and an in-place update for reductions. This is
expressed as two C++ statements in the embedded DSL, but internally is
represented in a single object. Linalg doesn’t have such a capability to the
initialization and the update are represented as two distinct Linalg operations
that are not connected to each other. Furthermore, the `x`, `y`, `c`, `n`
variables in Halide DSL correspond to implicit loops iterating over the
corresponding objects, which implies that functions sharing these variables in
their definitions also share the corresponding loops. In other words, the loop
equivalent of the Halide definition starts in a fully-fused form. The Linalg
model is the opposite with each structured operation corresponding to its own
loop nest, resulting in a fully-distributed form. This will affect how the
schedule is constructed later on.

The loop structure for Halide computation resembles the following (adapted from
debug dump with `HL_DEBUG_CODEGEN=1`)

```python
for n
  for y
    for x
      for c
        conv[n, y, x, c] = bias[c]
        for rz
          for ry
            for rx
              conv[n, y, x, c] += filter[rx, rz, ry, c] * input[n, y+rz, x+ry, rx]
        relu[n, y, x, c] = max(0, conv[n, y, x, c])
```

The loop structure for the Linalg computation is as follows (obtained by
`mlir-opt --linalg-generalize-named-ops --empty-tensor-to-alloc-tensor
--one-shot-bufferize --convert-linalg-to-loops`)

```python
for n
  for y
    for x
      for c
        init[n, y, x, c] = bias[c]
for n
  for y
    for x
      for c
        for rz
          for ry
            for rx
              conv[n, y, x, c] += filter[rx, rz, ry, c] * input[n, y+rz, x+ry, rx]
for n
  for y
    for x
      for c
        relu[n, y, x, c] = max(0, conv[n, y, x, c])

```

## Mapping Halide Scheduling Primitives to Linalg Structured Transforms

The complete Halide schedule listed in the example is as follows

```cpp
Var co, ci, xo, xi;
relu.split(c, co, ci, vec * tile_w)
  .split(x, xo, xi, tile_h)
  .reorder(ci, xi, xo, y, n, co)
  .vectorize(ci, vec)
  .unroll(ci)
  .unroll(xi)
  .parallel(y)
  .parallel(n)
  .parallel(co);

conv.compute_at(relu, xo)
  .vectorize(c, vec)
  .unroll(c)
  .unroll(x)
  .unroll(y)
  .update()
  .reorder(c, x, y, r.x, r.y, r.z, n)
  .vectorize(c, vec)
  .unroll(c)
  .unroll(x)
  .unroll(y)
  .unroll(r.x, 2);
```

We will consider only the case without parallelization to avoid the difference
in parallel runtimes generated by Halide and used by MLIR. This schedule
corresponds to a sequence of loop manipulations, unrolling and vectorization.
The following directives are present and can be mapped to transformations on
Linalg as described below.

*   `split` decomposes a loop dimension into two immediately nested loops with
    the inner loop having at most the given number of iterations. This can be
    understood as loop _strip-mining_ or a degenerate case of tiling a single
    dimension using any of `linalg.tile_` transform ops. We will be using
    `transform.structured.tile_using_forall` as this kind of loop is best
    supported by bufferization and can also be turned into a parallel loop later
    on. Unlike Halide, this doesn’t add new dimensions to the original
    operation, but rather creates a loop around it and rewrites the operation
    itself to operate on a subset of the original data.
*   `reorder` rearranges the loops arbitrarily. In Linalg representation, loops
    are implicit and are intended to remain so as long as possible to target
    microkernels. The order of implicit loops in a `linalg.generic` operation
    can be changed by using `transform.structured.interchange`, but this does
    not apply to named operations that need to be “generalized” first by calling
    `transform.structured.generalize`. However, this can only reorder implicit
    dimensions and not the explicit loops materialized by tiling operations that
    can no longer be “folded” into the original operation. Instead, we can
    leverage this behavior by materializing loops directly in the desired order
    by “tiling” to size 1.
*   `vectorize` indicates that the given dimension should be vectorized with the
    given factor; if the loop extent is larger than the factor, the loop is
    effectively split into two parts and the inner one is vectorized. On the
    contrary, structured Linalg op vectorization applies as a global
    transformation to all suitable operations at, e.g., a function scope via
    `transform.structured.vectorize_children_and_apply_patterns`. It relies on
    MLIR’s support for multidimensional vectors to directly map multidimensional
    tensors, which are later decomposed into operations on smaller
    hardware-compatible vectors during lowering.
*   `unroll` performs loop unrolling, fully or up to the given factor. It is
    equivalent to `transform.loop.unroll`.
*   `compute_at` indicates that the value of the function must be computed
    within the given loop that will be produced for another function; depending
    on the relation between loops surrounding functions, this corresponds to
    either a loop distribution or a producer/consumer fusion. Given that the
    Linalg representation starts in the fully distributed form, it can be
    represented as a sequence of `transform.structured.fuse_into_containing_op`
    that operates on `forall` loops materialized by tiling beforehand.


## Recreating the Loop Structure

The three first transformation directives for `relu` in the Halide schedule aim
at producing the following loop structure.

```python
for co
  for n
    for y
      for xo
        for xi
          for ci
            relu[n, y, xo*tile_h + xi, co*tile_w*vec + ci] = ...
```

Note that the outer part of the `c` gets hoisted from all of the surrounding
loops. The implicit loop order for the operation is `n, y, x, c`, so the `co`
loop needs to be materialized first in order to achieve the desired reordering.
The remaining dimensions can be materialized as loops in one transformation.

```mlir
    //                                                             [n  y  x  c]
    %co, %relu2 = transform.structured.tile_using_forall %relu
                                                        tile_sizes [0, 0, 0, 64]
    %n_y_xo, %relu3 = transform.structured.tile_using_forall %relu2
                                                        tile_sizes [1, 1, 5, 0]
```

This will result in the following loops being created in the IR with the nested
elementwise operation operating on a smaller subset of original data via
implicit loops.

```mlir
scf.forall (%co) in (2) {
  scf.forall (%n, %y, %xo) in (5, 80, 20) {
    tensor.extract_slice
    // Implicit dimensions [ni=0:1, y=0:1, xi=0:5, ci=0:64]
    %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> } // ...
    scf.forall.in_parallel {
      tensor.parallel_insert_slice // ...
    }
  }
}
```

The following loop restructuring transformations are `compute_at` and `reorder`
on the `conv` function that need to happen before loops are destroyed by
unrolling and vectorization. They intend to produce the final desired loop
structure.

```python
for co
  for n
    for y
      for xo
        for xi
          for ci
            conv[n, y, x*tile_h + xi, co*tile_w*vec + ci] = ...
        for rz
          for ry
            for rx
              for xi
                for ci
                  conv[n, y, x*tile_h + xi, co*tile_w*vec + ci] += ...
        for xi
          for ci
            relu[n, y, xo*tile_h + xi, co*tile_w*vec + ci] = ...
```

Practically, this corresponds to fusing the convolution initialization and
update into the `co, n, y, xo` loops materialized by tiling earlier. Structured
op transformation set supports fusing the producer of a value into its consumer,
so fusion happens in two stages:

*   first the main convolution update is fused into ReLU that uses it and has
    loops materialized;
*   then the bias initialization is fused into the convolution+relu loop nest.

Each stage consists of two transformations fusing the computational operation
into the outer loop, then the inner loop.

```mlir
%conv2, %co2 = transform.structured.fuse_into_containing_op %conv into %co
%conv3, %n_y_xo2 = transform.structured.fuse_into_containing_op %conv2
  into %n_y_xo

%bias2, %co3 = transform.structured.fuse_into_containing_op %bias into %co2
%bias3, %n_y_xo3 = transform.structured.fuse_into_containing_op %bias2
  into %n_y_xo2
```

To complete the structure, we need to put the `rz, ry, rx` loops outside the
“tile” loops `xi, ci`. This can be achieved materializing the corresponding
loops from the convolution operation. However, these are reduction loops and it
wouldn’t be valid to materialize them as intrinsically parallel “forall” loops.
Instead, we use the dedicated “reduction tiling” transformation and produce
sequential `scf.for` loops. (`scf.forall` loops can also express parallel
reductions, but the corresponding transformation doesn’t handle reductions along
more than one dimension at the moment of writing.)

```mlir
%rz_ry_rx, %red_fill, %conv4, %comb
  = transform.structured.tile_reduction_using_for %conv3
//               n  y  x  c  rz ry rx
  by tile_sizes=[0, 0, 0, 0, 1, 1, 1]
```

This transformation materializes the desired loops around the convolution
operation. It is also more capable than merely producing (reduction) loops: the
transformed code performs `tile_size` partial reductions of `N / tile_size`
elements, potentially in parallel by changing the dimension kind of the
structured operation inside the loop, and then performs a final reduction of
these partial results by producing a new “combiner” structured operation after
the loops. In our case, `tile_size = 1` along all dimensions, so the reduction
is entirely performed by the generated loops. The combiner structured operation
is still produced and adds up the reduction result with the initial value. This
changes the order of floating point operations (so would reduction tiling with
non-unit size) and may affect the final result due to non-commutativity of these
operations, but is explicitly allowed by `fastmath` flags. Halide also emits
LLVM IR with full `fastmath` flags.

Finally, we need to produce innermost loops `xi` and `ci` that are still not
explicit. As our next step is going to be vectorization along `ci`, we need to
take into account the way it operates on MLIR structured operations: rather than
selecting a specific vector size and loop/dimension to vectorize, it directly
substitutes multidimensional vector types for tensor types and updates the
operations accordingly. Therefore, our tensor type should not become trivial,
i.e. size-1, and retain a `vector_size` sized dimension along the desired axis,
`ci`. This can be achieved by tiling with `vector_size` as tile size in that
dimension:

```mlir
//                                                                  n  y  xi ci
%1, %c5 = transform.structured.tile_using_forall %conv4 tile_sizes [0, 0, 1, 16]
%2, %b4 = transform.structured.tile_using_forall %bias3 tile_sizes [0, 0, 1, 16]
%3, %r4 = transform.structured.tile_using_forall %relu3 tile_sizes [0, 0, 1, 16]
%4, %c2 = transform.structured.tile_using_forall %comb  tile_sizes [0, 0, 1, 16]
```

Note that the combiner operation produced by reduction tiling is also tiled here.


## Explicit Loop Unrolling

The remaining unhandled loop transformation is unrolling. Specifically,
unrolling is requested for the innermost loops that form the 4x5 tile of
16-element vector operations to ensure a contiguous sequence of `vfma`
instructions using 20 512-bit vector registers as accumulators. Unrolling
additional loops,, `unroll(y)` and `unroll(r.x, 2)`, is requested in the
schedule but _has no practical effect_. That is, the code, and all intermediate
representations, produced by Halide with these directives removed is _strictly
identical_ to the code with the full schedule. Therefore, we will only unroll
the corresponding loops corresponding to `xi` and `ci` dimensions that actually
get unrolled by Halide.

As tiling in the transform dialect produces handles to the loops materialized by
tiling, unrolling those loops is just a matter of chaining the corresponding
transformation. Note that the inner loop must be unrolled first as unrolling the
outer loop will invalidate the handles to the inner loop.

```mlir
transform.loop.unroll %bias_ci {factor = 4}
transform.loop.unroll %bias_xi {factor = 5}
transform.loop.unroll %conv_ci {factor = 4}
transform.loop.unroll %conv_xi {factor = 5}
transform.loop.unroll %relu_ci {factor = 4}
transform.loop.unroll %relu_xi {factor = 5}
transform.loop.unroll %comb_ci {factor = 4}
transform.loop.unroll %comb_xi {factor = 5}
```

## Vectorization

These transformations produced the desired loop structure and we are now ready
to vectorize. Before proceeding it is desirable to simplify the code as tiling
and fusion may have produced a lot of operations computing tensor subsets and
loop ranges, some of which may be duplicated or excessively complex.
Simplification involving canonicalization, common subexpression elimination,
loop invariant code motion and various rewrite patterns can be applied directly
from the transform dialect. Furthermore, an arbitrary combination of rewrite
patterns can be applied _in one sweep_ to a given scope, a functionality that
_cannot be achieved with conventional compiler passes_ that apply each group of
patterns separately (at least without creating a new pass for each combination
of pattern groups).

```mlir
%f00 = transform.structured.match ops{["func.func"]} in %arg0
transform.apply_patterns to %f00 {
  transform.apply_patterns.canonicalization
  transform.apply_patterns.linalg.tiling_canonicalization
}
transform.apply_cse to %f00

%all_loops = transform.structured.match interface{LoopLikeInterface} in %arg0
transform.apply_licm to %all_loops
```

One final simplification is necessary to produce good vectorized code.
Tiling-by-one as a way of materializing loops produced structured (`linalg`)
operations processing 4D types where only one dimension isn’t unit-sized, e.g.,
`tensor<1x1x1x16xf32>` where 16 is the vector size corresponding to AVX512,
as structured tiling doesn’t modify the rank of the operation in order to
preserve the original structure. Even though the core computation is the same,
the produced code may end up more complicated than necessary, in particular when
decomposing multidimensional vectors into single-dimensional vectors supported
by hardware. Such unit dimensions can be explicitly folded away using the
corresponding pattern set before vectorization.

```mlir
transform.apply_patterns to %f00 {
  transform.apply_patterns.linalg.fold_unit_extent_dims_via_reshapes
}

%fv = transform.structured.vectorize_children_and_apply_patterns %f00
```

This produces the desired code performing arithmetic operations on
`vector<16xf32>` types that can be easily lowered to AVX512 instructions by the
downstream compiler. Vectorization may have created new opportunities for code
simplification, in particular combining tensor subsetting and vector slicing
operations. Another round of simplification can be applied post vectorization.

```mlir
transform.apply_patterns to %fv {
  transform.apply_patterns.canonicalization
  transform.apply_patterns.tensor.fold_tensor_subset_ops_into_vector_transfers
}
transform.apply_cse to %fv
transform.structured.hoist_redundant_vector_transfers %fv
```

## Lowering to LLVM and The Bufferization Hurdle

With the loop restructuring done, the program now needs to be converted to the
executable form. The first step in doing so is _bufferization_, the process that
associates a memory buffer with every tensor in the payload IR. MLIR’s one-shot
bufferization is directly available as a transform operation.

```mlir
%arg1 = transform.bufferization.one_shot_bufferize %arg0 {
  bufferize_function_boundaries = true,
  function_boundary_type_conversion = 1 : i32 }
```

One-shot bufferization itself does not produce buffer deallocations, which may
lead to leaks. So we have to run the buffer deallocation pass pipeline to avoid
them. Note that the transform dialect seamlessly runs named passes and pass
pipelines: if desired, one could replace complex `--pass-pipeline expressions`
with operations. Note that we apply the pipeline to functions rather than entire
module to avoid running it on the transform IR that is contained in the module.

```mlir
%f = transform.structured.match ops{["func.func"]} in %arg1
  : (!transform.any_op) -> !transform.any_op
transform.apply_registered_pass "buffer-deallocation-pipeline" to %f
  : (!transform.any_op) -> !transform.any_op
```

In this particular case, the transformed IR could be directly bufferized. This
is not always the case in general as some operations, in particular
`tensor.empty` may not be bufferizable. Such operations need to be removed
before running the bufferization, which can often be achieved by sufficient
fusion (as in our case), or by running dedicated transformations
`transform.bufferization.eliminate_empty_tensors` that removes the
`tensor.empty` operations only serving for defining the size of a computation or
`transform.bufferization.empty_tensor_to_alloc_tensor` that materializes a new
temporary buffer for empty tensors to be used as local caches.

```mlir
// Apply general canonicalization and CSE to each function after
// bufferization as new simplification opportunities may have appeared.
%fb = transform.structured.match ops{["func.func"]} in %arg1
transform.apply_patterns to %fb {
  transform.apply_patterns.canonicalization
}
transform.apply_cse to %fb

// Lower complex, multidimensional vector operations into simpler
// primitives. This particular selection of the pattern groups corresponds
// to vector dialect operations present in the payload IR at this stage.
// Many of these groups can be parameterized to use different strategies or
// lower-level primitives offering performance trade-offs. In this case, we
// are selecting the simplest strategies.
transform.apply_patterns to %fb {
  transform.apply_patterns.vector.lower_contraction
    lowering_strategy = parallelarith
  transform.apply_patterns.vector.lower_transfer
    max_transfer_rank = 1
  transform.apply_patterns.vector.lower_transpose
    lowering_strategy = eltwise
  transform.apply_patterns.vector.lower_shape_cast
}

// These patterns apply in a separate sweep to avoid transfer-to-scf
// patterns overlap with lower-transfer patterns as they apply to the same
// kind of operations. These patterns may produce local allocations to act
// as temporary caches deep inside loops, which could lead to catastrophic
// performance. Such allocations are moved onto the stack and hoisted from
// all the surrounding loops.
transform.apply_patterns to %fb {
  transform.apply_patterns.vector.transfer_to_scf
  transform.apply_patterns.memref.alloc_to_alloca
  }
transform.bufferization.buffer_loop_hoisting %fb

// A final round of cleanups additionally includes patterns to simplify
// buffer aliasing operations that may have been introduced during
// bufferization and could result in excessively complex address
// computation.
transform.apply_patterns to %fb {
  transform.apply_patterns.memref.fold_memref_alias_ops
  transform.apply_patterns.canonicalization
}
transform.apply_cse to %fb
```

Due to its inter-procedural nature, one-bufferization processes the entire
payload module and thus invalidates all previously created handles. Therefore,
it is typically a late step in the transformation sequence where precise
targeting of transformation is no longer required. The following transformations
are typically module- or function-wide rewrites that are often pattern-based
lowerings. This part of the sequence can be seen as a pass pipeline specified
directly in the transform dialect, with pattern-based lowering passes
constructed _on-the-fly_ from named groups of patterns.

The resulting IR can be further completely lowered to the LLVM dialect, then to
LLVM IR and processed by the LLVM compiler to produce an executable or JITted.

The generated code runs in ~420ms on an Intel processor with Skylake
microarchitecture clocked at 2.0GHz. Given that the computation performs
$5*80*100*128*(2*3*3*128 + 2) ~= 5.9 * 10^9$ floating point operations, it
reaches ~14 GFlops. With 1 FMA unit available, the single-core performance of
the test processor is 64 GFlops $16 * 2 * 2 * 10^9$, where 16 is the vector
width), so only 22% of the theoretical peak is achieved.

The code produced by Halide runs in ~120ms on the same processor, a 3.5x
improvement and 77% of peak. Let us analyze the generated assembly to understand
the source of the difference. The main computational effort is expected to
happen around floating point multiplications and additions in the convolution.
In both cases, the assembly features AVX512 `vfma231ps` instructions operating
on `%zmm` 512-bit vector registers. In the MLIR-generated code, they are
interspersed with memory accesses loading _two _of the `fma` operands before
each operation and leading to increased latency.

```asm
vmovups       -192(%r10), %zmm0
vbroadcastss  -1536(%rdi,%r9), %zmm1
vmovups       112(%rsp), %zmm2
vfmadd231ps   %zmm1, %zmm0, %zmm2     # zmm2 = (zmm0 * zmm1) + zmm2
vmovups       %ymm2, 112(%rsp)
vextractf64x4 $1, %zmm2, 144(%rsp)
// 19 more blocks of either
//  (a) vmovups,vbroadcast,vfma(z,z),vextract,
//  (b) vbroadcast,vfma(z,mem),vextract
```

The Halide-generated code however features compact blocks of `vfma231ps` and
`vbroadcastss` loading one of the operands while the other two are resident in
registers and loaded before `fma`.

```asm
vbroadcastss    -1536(%rsi,%rbx), %zmm25
vmovups         -192(%rdi), %zmm26
vmovups         -128(%rdi), %zmm27
vmovups         -64(%rdi), %zmm28
vmovups         (%rdi), %zmm29
vfmadd231ps     %zmm25, %zmm26, %zmm24  # zmm24 = (zmm26 * zmm25) + zmm24
vfmadd231ps     %zmm25, %zmm27, %zmm23  # zmm23 = (zmm27 * zmm25) + zmm23
vfmadd231ps     %zmm25, %zmm28, %zmm22  # zmm22 = (zmm28 * zmm25) + zmm22
vfmadd231ps     %zmm25, %zmm29, %zmm21  # zmm21 = (zmm29 * zmm25) + zmm21
vbroadcastss    -1024(%rsi,%rbx), %zmm25
vfmadd231ps     %zmm25, %zmm26, %zmm20  # zmm20 = (zmm26 * zmm25) + zmm20
vfmadd231ps     %zmm25, %zmm27, %zmm19  # zmm19 = (zmm27 * zmm25) + zmm19
vfmadd231ps     %zmm25, %zmm28, %zmm18  # zmm18 = (zmm28 * zmm25) + zmm18
vfmadd231ps     %zmm25, %zmm29, %zmm17  # zmm17 = (zmm29 * zmm25) + zmm17
vbroadcastss    -512(%rsi,%rbx), %zmm25

// 3 more blocks of 4 vfmadd231 followed by a vbroadcast
```

Inspecting the progressive intermediate representations produced by MLIR, one
can observe the load(transfer)/fma interspersing at all levels starting after
schedule application. The repeated tensor subsetting operations, that are later
transformed into vector transfer operations, and vector memory loads, are
produced by loop unrolling that was explicitly requested in the schedule! The
issue is the single-assignment model of tensors (and vectors) that results in
long and complex chains of access and update operations that become so long that
the lower-level transformations and the downstream compiler can no longer
simplify them. In fact, unrolling loops early in the transformation sequence can
lead to all sorts of compiler-performance related problems (including the
compiler failing to perform some optimizations due to excessive code length) in
the process.

It is therefore desirable to perform loop unrolling at a later stage,
specifically after bufferization and relevant simplification. However,
bufferization invalidates all loop handles including to loops that we are
willing to unroll. This hurdle can be overcome by matching the payload IR
operations after bufferization to produce new handles. We will first change the
kind of loops produced in the schedule from `scf.for` to `scf.forall` to have
less operations to match by using `transform.structured.tile_using_forall`
instead of `transform.structured.tile` when tiling with sizes `[0, 0, 1, 16]`.
Then we can match all `scf.forall` operations in the payload IR and transform
them into single-iterator `scf.for` loops _after bufferization_.

```mlir
%foralls = transform.structured.match ops{["scf.forall"]} in %arg1
%xi_bias, %ci_bias = transform.loop.forall_to_for %xi_ci_bias
%xi_conv, %ci_conv = transform.loop.forall_to_for %xi_ci_conv
%xi_relu, %ci_relu = transform.loop.forall_to_for %xi_ci_relu
%xi_comb, %ci_comb = transform.loop.forall_to_for %xi_ci_comb
```

We can then move our loop unrolling transformations later in the transformation
sequence as desired. Compiling this new version to assembly produces exactly the
same core computation around `vfmadd231ps` as Halide’s version, which only
differs slightly in allocated registers. Unsurprisingly, this version runs
roughly in 120ms on the same machine.


## Multi-Dimensional Vectors to the Rescue

While we managed to produce similar code to Halide in the previous section, we
did so by rematching generated loops after bufferization, which partially defies
the purpose of using handles to chain transformations in the Transform dialect.
Luckily, this step is not really necessary. It only served as an exercise in
producing the desired loop structure.

Multidimensional structured operations on vectors are lowered to target-specific
vectors by unrolling and splitting. For example, an elementwise arithmetic
operation on `vector<5x64xf32>` is replaced with 5 operations on
`vector<64xf32>` and additional vector value manipulations to recreate the
required type at the MLIR level. Each of these operations is then split into 4
operations on `vector<16xf32>` at the LLVM level where the information about
the target vector width becomes available. Collectively, this has exactly the
same effect as first materializing the 5x4 loop nest, and then fully unrolling
these loops. Therefore, the last stage of tiling, re-matching and unrolling can
be removed from the schedule.

The resulting assembly has all `vbroadcast` grouped together before `vfmadd231`
but otherwise has a similar structure. This grouping is due to each
multi-dimensional vector operation being “unrolled” separately. When executed,
it runs in ~110ms, a slight improvement of 8% over both the previous version and
Halide, and reaches ~53.7 GFlop/s or 84% of peak single-core performance. The
improvement is largely due to the intermediate representation being shorter and
simpler in presence of large-vector operations, which allowed for more
aggressive address computation and load placement optimization.

The final transformation strategy is checked into the repository at
[mlir/examples/transform/ChH/full.mlir](
https://github.com/llvm/llvm-project/tree/main/mlir/test/Examples/transform/ChH/full.mlir).