// RUN: mlir-opt -transform-interpreter -cse -split-input-file %s | FileCheck %s
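// Tile the matmul by [10, 20] and fuse its linalg.fill producer; the fill is
// expected to be recreated on the init slice inside the tiled scf.for nest.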

func.func @gemm_fill_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.0 : f32
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
  %init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
  %gemm = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %gemm : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %a, %b, %c = transform.structured.fuse %matmul tile_sizes [10, 20]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}
//      CHECK: func.func @gemm_fill_fusion(
// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
//      CHECK:   %[[INIT:.+]] = tensor.empty
//      CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] =
// CHECK-SAME:       iter_args(%[[ITERARG0:.+]] = %[[INIT]])
//      CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] =
// CHECK-SAME:         iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
//  CHECK-DAG:       %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
//  CHECK-DAG:       %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
//  CHECK-DAG:       %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]]
//      CHECK:       %[[FILL_TILE:.+]] = linalg.fill
// CHECK-SAME:           outs(%[[INIT_TILE]] :
//      CHECK:       %[[GEMM_TILE:.+]] = linalg.matmul
// CHECK-SAME:           ins(%[[LHS_TILE]], %[[RHS_TILE]] :
// CHECK-SAME:           outs(%[[FILL_TILE]] :
//      CHECK:       %[[INSERT:.+]] = tensor.insert_slice %[[GEMM_TILE]] into %[[ITERARG1]][%[[IV0]], %[[IV1]]]
//      CHECK:       scf.yield %[[INSERT]]

// -----
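// Tile the elementwise (bias-add) generic and fuse the matmul and its fill
// producer into the same two-deep scf.for nest.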

func.func @gemm_generic_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
    %arg2 : tensor<?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.0 : f32
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
  %init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
  %gemm = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
  %generic = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%gemm, %arg2 : tensor<?x?xf32>, tensor<?xf32>) outs(%init : tensor<?x?xf32>) {
    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
      %add = arith.addf %b0, %b1 : f32
      linalg.yield %add : f32
  } -> tensor<?x?xf32>
  return %generic : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %a, %b, %c = transform.structured.fuse %generic tile_sizes [10, 20]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}
//      CHECK: func.func @gemm_generic_fusion(
// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?xf32>)
//      CHECK:   %[[INIT:.+]] = tensor.empty
//      CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] =
// CHECK-SAME:       iter_args(%[[ITERARG0:.+]] = %[[INIT]])
//      CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] =
// CHECK-SAME:         iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
//  CHECK-DAG:       %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
//  CHECK-DAG:       %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
//  CHECK-DAG:       %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]]
//      CHECK:       %[[FILL_TILE:.+]] = linalg.fill
// CHECK-SAME:           outs(%[[INIT_TILE]] :
//      CHECK:       %[[GEMM_TILE:.+]] = linalg.matmul
// CHECK-SAME:           ins(%[[LHS_TILE]], %[[RHS_TILE]] :
// CHECK-SAME:           outs(%[[FILL_TILE]] :
//  CHECK-DAG:       %[[BIAS_TILE:.+]] = tensor.extract_slice %[[ARG2]][%[[IV1]]]
//  CHECK-DAG:       %[[OUTS_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]]
//      CHECK:       %[[GENERIC_TILE:.+]] = linalg.generic
// CHECK-SAME:           ins(%[[GEMM_TILE]], %[[BIAS_TILE]] :
// CHECK-SAME:           outs(%[[OUTS_TILE]] :
//      CHECK:       %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV0]], %[[IV1]]]
//      CHECK:       scf.yield %[[INSERT]]

// -----
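// Tile only the second matmul along the M dimension and fuse the first
// matmul together with both fills into the single scf.for loop.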

func.func @gemm_gemm_fusion(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?x?xf32>, %rhs1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.0 : f32
  %d0 = tensor.dim %lhs0, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %rhs0, %c1 : tensor<?x?xf32>
  %init0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
  %fill0 = linalg.fill ins(%cst : f32) outs(%init0 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %gemm0 = linalg.matmul
      ins(%lhs0, %rhs0 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill0 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %d2 = tensor.dim %rhs1, %c1 : tensor<?x?xf32>
  %init1 = tensor.empty(%d0, %d2) : tensor<?x?xf32>
  %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %gemm1 = linalg.matmul
      ins(%gemm0, %rhs1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill1 : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %gemm1 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %matmuls = transform.structured.match ops{["linalg.matmul"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %mm1, %mm2 = transform.split_handle %matmuls
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %a, %b = transform.structured.fuse %mm2 tile_sizes [10]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}
//      CHECK: func.func @gemm_gemm_fusion(
// CHECK-SAME:     %[[LHS0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME:     %[[RHS0:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
// CHECK-SAME:     %[[RHS1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
//  CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[LHS0]], %[[C0]]
//  CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[RHS0]], %[[C1]]
//  CHECK-DAG:   %[[INIT0:.+]] = tensor.empty(%[[D0]], %[[D1]])
//  CHECK-DAG:   %[[D2:.+]] = tensor.dim %[[RHS1]], %[[C1]]
//      CHECK:   %[[INIT1:.+]] = tensor.empty(%[[D0]], %[[D2]])
//      CHECK:   scf.for %[[IV:[a-zA-Z0-9]+]] =
// CHECK-SAME:       iter_args(%[[ITERARG:.+]] = %[[INIT1]])
//  CHECK-DAG:     %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
//  CHECK-DAG:     %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][0, 0]
//  CHECK-DAG:     %[[INIT0_TILE:.+]] = tensor.extract_slice %[[INIT0]][%[[IV]], 0]
//      CHECK:     %[[FILL0_TILE:.+]] = linalg.fill
// CHECK-SAME:         outs(%[[INIT0_TILE]] :
//      CHECK:     %[[GEMM0_TILE:.+]] = linalg.matmul
// CHECK-SAME:         ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
// CHECK-SAME:         outs(%[[FILL0_TILE]] :
//  CHECK-DAG:     %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
//  CHECK-DAG:     %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG]][%[[IV]], 0]
//      CHECK:     %[[FILL1_TILE:.+]] = linalg.fill
// CHECK-SAME:         outs(%[[INIT1_TILE]] :
//      CHECK:     %[[GEMM1_TILE:.+]] = linalg.matmul
// CHECK-SAME:         ins(%[[GEMM0_TILE]], %[[RHS1_TILE]] :
// CHECK-SAME:         outs(%[[FILL1_TILE]] :
//      CHECK:     %[[INSERT:.+]] = tensor.insert_slice %[[GEMM1_TILE]] into %[[ITERARG]][%[[IV]], 0]
//      CHECK:     scf.yield %[[INSERT]]

// -----
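// Tile the transposing generic and fuse the matmul and fill producers; the
// output slices are indexed with the induction variables swapped.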

func.func @gemm_transpose_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.0 : f32
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
  %init0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%init0 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %gemm = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
  %init1 = tensor.empty(%d1, %d0) : tensor<?x?xf32>
  %transpose = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%gemm : tensor<?x?xf32>) outs(%init1 : tensor<?x?xf32>) {
    ^bb0(%b0 : f32, %b1 : f32):
      linalg.yield %b0 : f32
  } -> tensor<?x?xf32>
  return %transpose : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %a, %b, %c = transform.structured.fuse %generic tile_sizes [10, 20]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}
//      CHECK: func.func @gemm_transpose_fusion(
// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
//  CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
//  CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
//  CHECK-DAG:   %[[INIT0:.+]] = tensor.empty(%[[D0]], %[[D1]])
//  CHECK-DAG:   %[[INIT1:.+]] = tensor.empty(%[[D1]], %[[D0]])
//      CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] =
// CHECK-SAME:       iter_args(%[[ITERARG0:.+]] = %[[INIT1]])
//      CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] =
// CHECK-SAME:         iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
//  CHECK-DAG:       %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
//  CHECK-DAG:       %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
//  CHECK-DAG:       %[[INIT0_TILE:.+]] = tensor.extract_slice %[[INIT0]][%[[IV0]], %[[IV1]]]
//      CHECK:       %[[FILL_TILE:.+]] = linalg.fill
// CHECK-SAME:           outs(%[[INIT0_TILE]] :
//      CHECK:       %[[GEMM_TILE:.+]] = linalg.matmul
// CHECK-SAME:           ins(%[[LHS_TILE]], %[[RHS_TILE]] :
// CHECK-SAME:           outs(%[[FILL_TILE]] :
//  CHECK-DAG:       %[[OUTS_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV1]], %[[IV0]]]
//      CHECK:       %[[GENERIC_TILE:.+]] = linalg.generic
// CHECK-SAME:           ins(%[[GEMM_TILE]] :
// CHECK-SAME:           outs(%[[OUTS_TILE]] :
//      CHECK:       %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]]
//      CHECK:       scf.yield %[[INSERT]]

// -----
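// Same fusion as above, but with interchange [1, 0] the outer loop walks the
// N dimension, so the slice offsets use the interchanged induction variables.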

func.func @interchange_matmul_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
  %cst = arith.constant 0.0 : f32
  %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %3 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%2 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
      ^bb0(%b0 : f32, %b1 : f32):
        %4 = arith.addf %b0, %b0 : f32
        linalg.yield %4 : f32
      } -> tensor<?x?xf32>
  return %3 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %a, %b, %c = transform.structured.fuse %generic tile_sizes [10, 20] interchange[1, 0]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}
//      CHECK: func.func @interchange_matmul_fusion(
// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
//      CHECK:   %[[INIT:.+]] = tensor.empty
//      CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] =
// CHECK-SAME:       iter_args(%[[ITERARG0:.+]] = %[[INIT]])
//      CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] =
// CHECK-SAME:         iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
//  CHECK-DAG:       %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0]
//  CHECK-DAG:       %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV0]]]
//  CHECK-DAG:       %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV1]], %[[IV0]]]
//      CHECK:       %[[FILL_TILE:.+]] = linalg.fill
// CHECK-SAME:           outs(%[[INIT_TILE]] :
//      CHECK:       %[[GEMM_TILE:.+]] = linalg.matmul
// CHECK-SAME:           ins(%[[LHS_TILE]], %[[RHS_TILE]] :
// CHECK-SAME:           outs(%[[FILL_TILE]] :
//      CHECK:       %[[INIT_TILE_2:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV1]], %[[IV0]]]
//      CHECK:       %[[GENERIC_TILE:.+]] = linalg.generic
// CHECK-SAME:           ins(%[[GEMM_TILE]] :
// CHECK-SAME:           outs(%[[INIT_TILE_2]] :
//      CHECK:       %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]]
//      CHECK:       scf.yield %[[INSERT]]

// -----
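// The generic reads the matmul result twice; fusion should produce a single
// tiled matmul whose result feeds both ins operands of the tiled generic.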

func.func @matmul_plus_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
                         %arg2: tensor<?x?xf32>) -> tensor<?x?xf32>{
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = tensor.dim %arg2, %c0 : tensor<?x?xf32>
  %1 = tensor.dim %arg2, %c1 : tensor<?x?xf32>
  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %3 = tensor.dim %2, %c0 : tensor<?x?xf32>
  %4 = tensor.dim %2, %c1 : tensor<?x?xf32>
  %5 = tensor.empty(%3, %4) : tensor<?x?xf32>
  %6 = linalg.generic
    {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                      affine_map<(d0, d1) -> (d0, d1)>,
                      affine_map<(d0, d1) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel"]}
    ins(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%5 : tensor<?x?xf32>) {
    ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
      %7 = arith.addf %arg3, %arg4 : f32
      linalg.yield %7 : f32
    } -> tensor<?x?xf32>
  return %6 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %a, %b, %c = transform.structured.fuse %generic tile_sizes [10, 20]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}
//       CHECK: func @matmul_plus_matmul
//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//       CHECK:   %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]]
//  CHECK-SAME:     iter_args(%[[ARG4:.+]] = %{{[a-zA-Z0-9_]+}})
//       CHECK:     %[[YIELD:.+]] = scf.for %[[IV1:[a-zA-Z0-9_]+]]
//  CHECK-SAME:       iter_args(%[[ARG6:.+]] = %[[ARG4]])
//   CHECK-DAG:       %[[ST_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
//   CHECK-DAG:       %[[ST_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
//   CHECK-DAG:       %[[ST_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]]
//       CHECK:       %[[MATMUL:.+]] = linalg.matmul
//  CHECK-SAME:         ins(%[[ST_ARG0]], %[[ST_ARG1]] :
//  CHECK-SAME:         outs(%[[ST_ARG2]] :
//       CHECK:       %[[ST_ARG6:.+]] = tensor.extract_slice %[[ARG6]][%[[IV0]], %[[IV1]]]
//       CHECK:       %[[ST_RESULT:.+]] = linalg.generic
//  CHECK-SAME:         ins(%[[MATMUL]], %[[MATMUL]] :
//  CHECK-SAME:         outs(%[[ST_ARG6]] :
//       CHECK:       %[[UPDATE:.+]] = tensor.insert_slice %[[ST_RESULT]]
//  CHECK-SAME:         into %[[ARG6]][%[[IV0]], %[[IV1]]]
//       CHECK:       scf.yield %[[UPDATE]]
//       CHECK:     scf.yield %[[YIELD]]
//       CHECK:   return %[[RESULT]]

// -----
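// One operand reads the matmul result through a transposed map, so the matmul
// is recomputed for the transposed tile: two tiled matmuls feed the fused
// generic.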

func.func @matmul_plus_transpose_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
                         %arg2: tensor<?x?xf32>) -> tensor<?x?xf32>{
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = tensor.dim %arg2, %c0 : tensor<?x?xf32>
  %1 = tensor.dim %arg2, %c1 : tensor<?x?xf32>
  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %3 = tensor.dim %2, %c0 : tensor<?x?xf32>
  %4 = tensor.dim %2, %c1 : tensor<?x?xf32>
  %5 = tensor.empty(%3, %4) : tensor<?x?xf32>
  %6 = linalg.generic
    {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                      affine_map<(d0, d1) -> (d1, d0)>,
                      affine_map<(d0, d1) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel"]}
    ins(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%5 : tensor<?x?xf32>) {
    ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
      %7 = arith.addf %arg3, %arg4 : f32
      linalg.yield %7 : f32
    } -> tensor<?x?xf32>
  return %6 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %a, %b, %c = transform.structured.fuse %generic tile_sizes [10, 20]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}
//       CHECK: func @matmul_plus_transpose_matmul
//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//       CHECK:   %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]]
//  CHECK-SAME:     iter_args(%[[ARG4:.+]] = %{{[a-zA-Z0-9_]+}})
//       CHECK:     %[[YIELD:.+]] = scf.for %[[IV1:[a-zA-Z0-9_]+]]
//  CHECK-SAME:       iter_args(%[[ARG6:.+]] = %[[ARG4]])
//   CHECK-DAG:       %[[ST_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
//   CHECK-DAG:       %[[ST_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
//   CHECK-DAG:       %[[ST_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]]
//       CHECK:       %[[LHS:.+]] = linalg.matmul
//  CHECK-SAME:         ins(%[[ST_ARG0]], %[[ST_ARG1]]
//  CHECK-SAME:           : tensor<?x?xf32>, tensor<?x?xf32>)
//  CHECK-SAME:         outs(%[[ST_ARG2]] : tensor<?x?xf32>)
//   CHECK-DAG:       %[[STR_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0]
//   CHECK-DAG:       %[[STR_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV0]]]
//   CHECK-DAG:       %[[STR_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV1]], %[[IV0]]]
//       CHECK:       %[[RHS:.+]] = linalg.matmul
//  CHECK-SAME:         ins(%[[STR_ARG0]], %[[STR_ARG1]] :
//  CHECK-SAME:         outs(%[[STR_ARG2]] :
//       CHECK:       %[[ST_ARG6:.+]] = tensor.extract_slice %[[ARG6]][%[[IV0]], %[[IV1]]]
//       CHECK:       %[[ST_RESULT:.+]] = linalg.generic
//  CHECK-SAME:         ins(%[[LHS]], %[[RHS]] :
//  CHECK-SAME:         outs(%[[ST_ARG6]] :
//       CHECK:       %[[UPDATE:.+]] = tensor.insert_slice %[[ST_RESULT]]
//  CHECK-SAME:         into %[[ARG6]][%[[IV0]], %[[IV1]]]
//       CHECK:       scf.yield %[[UPDATE]]
//       CHECK:     scf.yield %[[YIELD]]
//       CHECK:   return %[[RESULT]]

// -----
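// Tile the last matmul of the chain along M and fuse both producer matmuls
// into the loop; the original untiled matmuls remain as sources for the
// tensor.dim queries.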

func.func @matmul_sequence_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
    %arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>, %arg4: tensor<?x?xf32>,
    %arg5: tensor<?x?xf32>, %arg6: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N0] * [N0, N1]
  %1 = linalg.matmul ins(%0, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%arg4 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N1] * [N1, N2]
  %2 = linalg.matmul ins(%1, %arg5 : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%arg6 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N2] * [N2, N3]
  return %2 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %matmuls = transform.structured.match ops{["linalg.matmul"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %mm1, %mm2, %mm3 = transform.split_handle %matmuls
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    %a, %b = transform.structured.fuse %mm3 tile_sizes [10]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}
//       CHECK: #[[MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 10)>
//       CHECK: func @matmul_sequence_fusion(
//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//  CHECK-SAME:   %[[ARG5:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//  CHECK-SAME:   %[[ARG6:[a-zA-Z0-9_]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
//   CHECK-DAG:   %[[ORIG_GEMM1:.+]] = linalg.matmul ins(%[[ARG0]], %[[ARG1]] :
//   CHECK-DAG:   %[[ORIG_GEMM2:.+]] = linalg.matmul ins(%[[ORIG_GEMM1]], %[[ARG3]] :
//   CHECK-DAG:   %[[M:.+]] = tensor.dim %[[ORIG_GEMM2]], %[[C0]]
//   CHECK-DAG:   %[[N2:.+]] = tensor.dim %[[ORIG_GEMM2]], %[[C1]]
//   CHECK-DAG:   %[[N3:.+]] = tensor.dim %[[ARG5]], %[[C1]]
//       CHECK:   %[[R0:.+]] = scf.for %[[IV:[a-zA-Z0-9_]+]] =
//  CHECK-SAME:       iter_args(%[[ARG8:.+]] = %[[ARG6]]) -> (tensor<?x?xf32>) {
//   CHECK-DAG:     %[[N1:.+]] = tensor.dim %[[ORIG_GEMM1]], %[[C1]]
//   CHECK-DAG:     %[[N0:.+]] = tensor.dim %[[ARG0]], %[[C1]]
//   CHECK-DAG:     %[[TILE_M:.+]] = affine.min #[[MAP]](%[[IV]])[%[[M]]]
//   CHECK-DAG:     %[[SLICE_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[TILE_M]], %[[N0]]]
//   CHECK-DAG:     %[[SLICE_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, 0] [%[[N0]], %[[N1]]]
//   CHECK-DAG:     %[[SLICE_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV]], 0] [%[[TILE_M]], %[[N1]]]
//   CHECK-DAG:     %[[TILE_GEMM1:.+]] = linalg.matmul ins(%[[SLICE_ARG0]], %[[SLICE_ARG1]] :
//  CHECK-SAME:         outs(%[[SLICE_ARG2]] :
//   CHECK-DAG:     %[[SLICE_ARG3:.+]] = tensor.extract_slice %[[ARG3]][0, 0] [%[[N1]], %[[N2]]]
//   CHECK-DAG:     %[[SLICE_ARG4:.+]] = tensor.extract_slice %[[ARG4]][%[[IV]], 0] [%[[TILE_M]], %[[N2]]]
//   CHECK-DAG:     %[[TILE_GEMM2:.+]] = linalg.matmul ins(%[[TILE_GEMM1]], %[[SLICE_ARG3]] :
//  CHECK-SAME:         outs(%[[SLICE_ARG4]] :
//   CHECK-DAG:     %[[SLICE_ARG5:.+]] = tensor.extract_slice %[[ARG5]][0, 0] [%[[N2]], %[[N3]]]
//   CHECK-DAG:     %[[SLICE_ARG6:.+]] = tensor.extract_slice %[[ARG8]][%[[IV]], 0] [%[[TILE_M]], %[[N3]]]
//   CHECK-DAG:     %[[TILE_GEMM3:.+]] = linalg.matmul
//  CHECK-SAME:         ins(%[[TILE_GEMM2]], %[[SLICE_ARG5]] :
//  CHECK-SAME:         outs(%[[SLICE_ARG6]] :
//       CHECK:     %[[UPDATE:.+]] = tensor.insert_slice %[[TILE_GEMM3]] into %[[ARG8]][%[[IV]], 0] [%[[TILE_M]], %[[N3]]]
//       CHECK:     scf.yield %[[UPDATE]]

// -----
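// A softmax-like chain of reductions; tiling the final elementwise generic
// along the parallel dimension fuses the two reduction generics and their
// fills, since the reduction dimension is left untiled.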

func.func @reduction_sequence(%arg0: tensor<30x3xf32>) -> tensor<30x3xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %cst_0 = arith.constant 0xFF800000 : f32
  %0 = tensor.empty() : tensor<30xf32>
  %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<30xf32>) -> tensor<30xf32>
  %2 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%arg0 : tensor<30x3xf32>) outs(%1 : tensor<30xf32>) {
    ^bb0(%arg1: f32, %arg2: f32):
      %8 = arith.maximumf %arg2, %arg1 : f32
      linalg.yield %8 : f32
    } -> tensor<30xf32>
  %3 = tensor.empty() : tensor<30x3xf32>
  %4 = linalg.fill ins(%cst : f32) outs(%0 : tensor<30xf32>) -> tensor<30xf32>
  %5:2 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>,
                       affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%arg0, %2 : tensor<30x3xf32>, tensor<30xf32>) outs(%4, %3 : tensor<30xf32>, tensor<30x3xf32>) {
    ^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32):
      %8 = arith.subf %arg1, %arg2 : f32
      %9 = math.exp %8 : f32
      %10 = arith.addf %arg3, %9 : f32
      linalg.yield %10, %9 : f32, f32
    } -> (tensor<30xf32>, tensor<30x3xf32>)
  %6 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%5#1, %5#0 : tensor<30x3xf32>, tensor<30xf32>) outs(%3 : tensor<30x3xf32>) {
    ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
      %8 = arith.divf %arg1, %arg2 : f32
      linalg.yield %8 : f32
    } -> tensor<30x3xf32>
  return %6 : tensor<30x3xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %generics = transform.structured.match ops{["linalg.generic"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %generic1, %generic2, %generic3 = transform.split_handle %generics
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    %a, %b = transform.structured.fuse %generic3 tile_sizes [10]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}
//       CHECK: func @reduction_sequence(%[[ARG0:.+]]: tensor<30x3xf32>)
//   CHECK-DAG:   %[[INIT0:.+]] = tensor.empty() : tensor<30xf32>
//   CHECK-DAG:   %[[INIT1:.+]] = tensor.empty() : tensor<30x3xf32>
//       CHECK:   %[[RESULT:[a-zA-Z0-9]+]] = scf.for %[[IV:[a-zA-Z0-9]+]]
//  CHECK-SAME:       iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]])
//   CHECK-DAG:     %[[ARG0_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0]
//   CHECK-DAG:     %[[INIT0_SLICE:.+]] = tensor.extract_slice %[[INIT0]][%[[IV]]]
//       CHECK:     %[[FILL0:.+]] = linalg.fill
//  CHECK-SAME:         outs(%[[INIT0_SLICE]] :
//       CHECK:     %[[GENERIC0:.+]] = linalg.generic
//  CHECK-SAME:         ins(%[[ARG0_SLICE]] :
//  CHECK-SAME:         outs(%[[FILL0]] :
//       CHECK:     %[[FILL1:.+]] = linalg.fill
//  CHECK-SAME:         outs(%[[INIT0_SLICE]] :
//       CHECK:     %[[INIT1_SLICE:.+]] = tensor.extract_slice %[[INIT1]][%[[IV]], 0]
//       CHECK:     %[[GENERIC1:.+]]:2 = linalg.generic
//  CHECK-SAME:         ins(%[[ARG0_SLICE]], %[[GENERIC0]] :
//  CHECK-SAME:         outs(%[[FILL1]], %[[INIT1_SLICE]] :
//       CHECK:     %[[ITERARG0_SLICE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
//       CHECK:     %[[GENERIC2:.+]] = linalg.generic
//  CHECK-SAME:         ins(%[[GENERIC1]]#1, %[[GENERIC1]]#0 :
//  CHECK-SAME:         outs(%[[ITERARG0_SLICE]] :
//   CHECK-DAG:     %[[INSERTSLICE:.+]] = tensor.insert_slice %[[GENERIC2]] into %[[ITERARG0]][%[[IV]], 0]
//       CHECK:     scf.yield %[[INSERTSLICE]]
//       CHECK:   return %[[RESULT]]

// -----
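// Fuse a linalg.generic producer into a tiled tensor.pad: each tile is
// computed under an scf.if whose else branch runs the generic on a slice of
// the input and pads (then casts) the result.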

func.func @pad_producer_fusion(%arg0 : tensor<10xf32>) -> tensor<16xf32> {
  %0 = tensor.empty() : tensor<10xf32>
  %1 = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]}
      ins(%arg0 : tensor<10xf32>) outs(%0 : tensor<10xf32>) {
    ^bb0(%b0 : f32, %b1 : f32):
      %2 = arith.addf %b0, %b0 : f32
      linalg.yield %2 : f32
  } -> tensor<10xf32>
  %cst = arith.constant 0.0 : f32
  %2 = tensor.pad %1 low[4] high[2] {
    ^bb0(%arg1 : index):
      tensor.yield %cst : f32
  } : tensor<10xf32> to tensor<16xf32>
  return %2 : tensor<16xf32>
}
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %pad = transform.structured.match ops{["tensor.pad"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %a, %b = transform.structured.fuse %pad tile_sizes [8]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}
// CHECK-LABEL: func @pad_producer_fusion
//  CHECK-SAME:     %[[ARG0:.+]]: tensor<10xf32>
//       CHECK:   %[[FOR_RESULT:.+]] = scf.for
//       CHECK:     %[[IF_RESULT:.+]] = scf.if
//       CHECK:     else
//       CHECK:       %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
//       CHECK:       %[[GENERIC:.+]] = linalg.generic
//  CHECK-SAME:           ins(%[[SLICE]] :
//       CHECK:       %[[PAD:.+]] = tensor.pad %[[GENERIC]]
//       CHECK:       %[[CAST:.+]] = tensor.cast %[[PAD]]
//       CHECK:       scf.yield %[[CAST]]
//       CHECK:     %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[IF_RESULT]]
//       CHECK:     scf.yield %[[INSERT_SLICE]]
//       CHECK:   return %[[FOR_RESULT]]

// -----
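// Fuse a linalg.unpack producer whose inner tiles do not line up with the
// requested tile size: the fused unpack result is re-sliced before it feeds
// the tiled generic.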

func.func @imperfect_unpack_producer_fusion(%source: tensor<1x1x288x8x4xf32>, %dest: tensor<1x2x1152xf32>) -> tensor<1x2x1152xf32> {
  %0 = linalg.unpack %source
      outer_dims_perm = [0, 1, 2]
      inner_dims_pos = [1, 2]
      inner_tiles = [8, 4] into %dest
      : tensor<1x1x288x8x4xf32> -> tensor<1x2x1152xf32>
  %1 = tensor.empty() : tensor<1x2x1152xf32>
  %cst = arith.constant 1.0 : f32
  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                                        affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
                       iterator_types = ["parallel", "parallel", "parallel"]}
                       ins(%0 : tensor<1x2x1152xf32>)
                       outs(%1 : tensor<1x2x1152xf32>) {
  ^bb0(%in: f32, %out: f32):
    %7 = arith.addf %in, %cst : f32
    linalg.yield %7 : f32
  } -> tensor<1x2x1152xf32>
  return %2 : tensor<1x2x1152xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %a, %b = transform.structured.fuse %generic tile_sizes [0, 1, 0]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}

// CHECK-LABEL: func @imperfect_unpack_producer_fusion
//  CHECK-SAME:     %[[ARG0:.+]]: tensor<1x1x288x8x4xf32>
//  CHECK-SAME:     %[[ARG1:.+]]: tensor<1x2x1152xf32>
//       CHECK:   %[[FOR_RESULT:.+]] = scf.for{{.*}}iter_args(%[[ITER_ARG:.+]] = {{.*}})
//       CHECK:     %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
//       CHECK:     %[[UNPACK:.+]] = linalg.unpack %[[SLICE]]
//   CHECK-DAG:     %[[UNPACK_SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
//   CHECK-DAG:     %[[INIT_SLICE:.+]] = tensor.extract_slice %[[ITER_ARG]]
//       CHECK:     %[[GENERIC:.+]] = linalg.generic
//  CHECK-SAME:         ins(%[[UNPACK_SLICE]]
//  CHECK-SAME:         outs(%[[INIT_SLICE]]
//       CHECK:     %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[GENERIC]] into %[[ITER_ARG]]
//       CHECK:     scf.yield %[[INSERT_SLICE]]
//       CHECK:   return %[[FOR_RESULT]]

// -----
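// Tile a transpose-like 4-d generic with sizes [1, 16, 16, 16]; there is no
// producer to fuse, so the result is a plain four-deep tiled loop nest.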

#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d1)>
module {
  func.func private @tile_one_consumer_using_tile_and_fuse(%arg0: tensor<16x128x48x96xf32>, %arg1: tensor<16x96x48x128xf32>) -> tensor<16x96x48x128xf32> {
    %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<16x128x48x96xf32>) outs(%arg1 : tensor<16x96x48x128xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<16x96x48x128xf32>
    return %0 : tensor<16x96x48x128xf32>
  }
}
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %a, %loops:4 = transform.structured.fuse %generic tile_sizes [1, 16, 16, 16] interchange [0, 1, 2, 3]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}

// CHECK:           func.func private @tile_one_consumer_using_tile_and_fuse(%[[VAL_0:.*]]: tensor<16x128x48x96xf32>, %[[VAL_1:.*]]: tensor<16x96x48x128xf32>) -> tensor<16x96x48x128xf32> {
// CHECK:             %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK:             %[[VAL_3:.*]] = arith.constant 16 : index
// CHECK:             %[[VAL_4:.*]] = arith.constant 128 : index
// CHECK:             %[[VAL_5:.*]] = arith.constant 48 : index
// CHECK:             %[[VAL_6:.*]] = arith.constant 96 : index
// CHECK:             %[[VAL_7:.*]] = arith.constant 1 : index
// CHECK:             %[[VAL_8:.*]] = scf.for %[[VAL_9:.*]] = %[[VAL_2]] to %[[VAL_3]] step %[[VAL_7]] iter_args(%[[VAL_10:.*]] = %[[VAL_1]]) -> (tensor<16x96x48x128xf32>) {
// CHECK:               %[[VAL_11:.*]] = scf.for %[[VAL_12:.*]] = %[[VAL_2]] to %[[VAL_4]] step %[[VAL_3]] iter_args(%[[VAL_13:.*]] = %[[VAL_10]]) -> (tensor<16x96x48x128xf32>) {
// CHECK:                 %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_3]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (tensor<16x96x48x128xf32>) {
// CHECK:                   %[[VAL_17:.*]] = scf.for %[[VAL_18:.*]] = %[[VAL_2]] to %[[VAL_6]] step %[[VAL_3]] iter_args(%[[VAL_19:.*]] = %[[VAL_16]]) -> (tensor<16x96x48x128xf32>) {
// CHECK:                     %[[VAL_20:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_9]], %[[VAL_12]], %[[VAL_15]], %[[VAL_18]]] [1, 16, 16, 16] [1, 1, 1, 1] : tensor<16x128x48x96xf32> to tensor<1x16x16x16xf32>
// CHECK:                     %[[VAL_21:.*]] = tensor.extract_slice %[[VAL_19]]{{\[}}%[[VAL_9]], %[[VAL_18]], %[[VAL_15]], %[[VAL_12]]] [1, 16, 16, 16] [1, 1, 1, 1] : tensor<16x96x48x128xf32> to tensor<1x16x16x16xf32>
// CHECK:                     %[[VAL_22:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_20]] : tensor<1x16x16x16xf32>) outs(%[[VAL_21]] : tensor<1x16x16x16xf32>) {
// CHECK:                     ^bb0(%[[VAL_23:.*]]: f32, %[[VAL_24:.*]]: f32):
// CHECK:                       linalg.yield %[[VAL_23]] : f32
// CHECK:                     } -> tensor<1x16x16x16xf32>
// CHECK:                     %[[VAL_25:.*]] = tensor.insert_slice %[[VAL_26:.*]] into %[[VAL_19]]{{\[}}%[[VAL_9]], %[[VAL_18]], %[[VAL_15]], %[[VAL_12]]] [1, 16, 16, 16] [1, 1, 1, 1] : tensor<1x16x16x16xf32> into tensor<16x96x48x128xf32>
// CHECK:                     scf.yield %[[VAL_25]] : tensor<16x96x48x128xf32>
// CHECK:                   }
// CHECK:                   scf.yield %[[VAL_27:.*]] : tensor<16x96x48x128xf32>
// CHECK:                 }
// CHECK:                 scf.yield %[[VAL_28:.*]] : tensor<16x96x48x128xf32>
// CHECK:               }
// CHECK:               scf.yield %[[VAL_29:.*]] : tensor<16x96x48x128xf32>
// CHECK:             }
// CHECK:             return %[[VAL_30:.*]] : tensor<16x96x48x128xf32>
// CHECK:           }
// CHECK:         }