1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
|
//===- LowerQuantOps.cpp - Lower 'quant' dialect ops ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Transforms `quant.dcast` and `quant.qcast` into lower-level ops.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/Dialect/Quant/Transforms/Passes.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace quant {
#define GEN_PASS_DEF_LOWERQUANTOPS
#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
namespace {
// Strip the tensor wrapper from 'inputType', if there is one: a tensor type
// yields its element type, any other type is handed back unchanged.
Type getScalarType(Type inputType) {
  auto tensorType = dyn_cast<TensorType>(inputType);
  return tensorType ? tensorType.getElementType() : inputType;
}
// Describe the shape of 'input' as a mixed list of attributes (static
// dimensions) and values (dynamic dimensions), as produced by
// 'tensor::getMixedSizes()'. A non-tensor input has no shape, so the returned
// list is empty in that case.
SmallVector<OpFoldResult> getScalarOrTensorShape(OpBuilder &builder,
                                                 Location loc, Value input) {
  // Anything that is not a tensor is treated as a scalar.
  if (!isa<TensorType>(input.getType()))
    return {};
  return tensor::getMixedSizes(builder, loc, input);
}
// Build a type that mirrors the "container" of 'referenceType' but holds
// elements of type 'elementType': when 'referenceType' is a tensor, the result
// is a clone of that tensor type with the new element type; otherwise the
// result is 'elementType' itself.
Type getScalarOrTensorType(Type elementType, Type referenceType) {
  auto tensorType = dyn_cast<TensorType>(referenceType);
  if (!tensorType)
    return elementType;
  return tensorType.clone(elementType);
}
// Broadcast 'scalar' so that it can be combined with a value of type
// 'referenceType'. If 'referenceType' is a tensor, a 'tensor.splat' of 'scalar'
// with shape 'referenceShape' is emitted and returned. Otherwise 'scalar' is
// returned untouched and 'referenceShape' is expected to be empty.
Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar,
                                Type referenceType,
                                ArrayRef<OpFoldResult> referenceShape) {
  // Tensor reference: replicate the scalar across the requested shape.
  if (isa<TensorType>(referenceType))
    return tensor::SplatOp::create(builder, loc, scalar, referenceShape);

  // Scalar reference: nothing to broadcast, and no shape should be supplied.
  assert(referenceShape.empty());
  return scalar;
}
// Collapse an unranked tensor into a 1D ranked tensor.
//
// - input
//   Unranked tensor.
//
// Return values:
//
// - flatInput
//   1D ranked tensor with a single dynamic dimension holding every element of
//   'input'.
//
// - inputShape
//   1D extent tensor with the shape of the original unranked input, as
//   produced by 'shape.shape_of'.
//
std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
                                              Value input) {
  auto *context = builder.getContext();

  // Query the runtime shape of the input and reduce it to an element count.
  Value originalShape = shape::ShapeOfOp::create(
      builder, loc, shape::getExtentTensorType(context), input);
  Value numElements = shape::NumElementsOp::create(
      builder, loc, builder.getIndexType(), originalShape);

  // 'tensor.reshape' takes the target shape as a tensor operand, so wrap the
  // element count in a 1-element extent tensor.
  Value targetShape = tensor::FromElementsOp::create(
      builder, loc, shape::getExtentTensorType(context, 1), numElements);

  // Reshape the input into a 1D tensor with a dynamic extent.
  auto elementType =
      cast<UnrankedTensorType>(input.getType()).getElementType();
  auto flatType = RankedTensorType::get({ShapedType::kDynamic}, elementType);
  Value flatInput =
      tensor::ReshapeOp::create(builder, loc, flatType, input, targetShape);

  return {flatInput, originalShape};
}
// Collapse an unranked tensor into a 3D ranked tensor whose middle dimension
// is dimension 'axis' of the input. All dimensions to the left of 'axis' are
// folded into the first result dimension, and all dimensions to its right are
// folded into the last one.
//
// - input
//   Unranked tensor.
//
// - axis
//   Index of the input dimension around which the other input dimensions are
//   collapsed.
//
// - axisSize
//   Size of input dimension 'axis'.
//
// Return values:
//
// - flatInput
//   3D ranked tensor of shape [?, axisSize, ?].
//
// - inputShape
//   1D extent tensor with the shape of the original unranked input.
//
std::pair<Value, Value>
flattenUnrankedTensorAroundAxis(OpBuilder &builder, Location loc, Value input,
                                int64_t axis, int64_t axisSize) {
  auto *context = builder.getContext();
  auto indexType = builder.getIndexType();
  auto extentTensorType = shape::getExtentTensorType(context);

  // Full runtime shape of the unranked input.
  Value originalShape =
      shape::ShapeOfOp::create(builder, loc, extentTensorType, input);

  // Split positions: just before 'axis' and just after it.
  Value splitBeforeAxis = arith::ConstantIndexOp::create(builder, loc, axis);
  Value splitAfterAxis =
      arith::ConstantIndexOp::create(builder, loc, axis + 1);

  // Head of the split at 'axis': the dimensions left of the axis. Their
  // product is the first dimension of the flattened tensor.
  Value leadingShape =
      shape::SplitAtOp::create(builder, loc,
                               TypeRange{extentTensorType, extentTensorType},
                               originalShape, splitBeforeAxis)
          .getResult(0);
  Value leadingSize =
      shape::NumElementsOp::create(builder, loc, indexType, leadingShape);

  // Tail of the split at 'axis + 1': the dimensions right of the axis. Their
  // product is the last dimension of the flattened tensor.
  Value trailingShape =
      shape::SplitAtOp::create(builder, loc,
                               TypeRange{extentTensorType, extentTensorType},
                               originalShape, splitAfterAxis)
          .getResult(1);
  Value trailingSize =
      shape::NumElementsOp::create(builder, loc, indexType, trailingShape);

  // Assemble the 3-element target shape [leadingSize, axisSize, trailingSize].
  Value axisSizeValue = arith::ConstantIndexOp::create(builder, loc, axisSize);
  Value targetShape = tensor::FromElementsOp::create(
      builder, loc, shape::getExtentTensorType(context, 3),
      ValueRange{leadingSize, axisSizeValue, trailingSize});

  // Reshape the input into the 3D tensor.
  auto elementType =
      cast<UnrankedTensorType>(input.getType()).getElementType();
  auto flatType = RankedTensorType::get(
      {ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType);
  Value flatInput =
      tensor::ReshapeOp::create(builder, loc, flatType, input, targetShape);

  return {flatInput, originalShape};
}
// Reshape a ranked tensor into an unranked tensor with the given shape. This
// undoes the flattening performed by the 'flattenUnrankedTensor*()' helpers.
//
// - input
//   Ranked tensor.
//
// - inputShape
//   1D extent tensor describing the shape to restore.
//
Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
                                 Value inputShape) {
  // The result keeps the element type of the ranked input but drops its rank.
  auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
  return tensor::ReshapeOp::create(
      builder, loc, UnrankedTensorType::get(elementType), input, inputShape);
}
// Create a tensor constant containing all scales in a per-channel quantized
// type. The constant is a 1D tensor with one element per scale, whose element
// type is the expressed type of 'quantizedType'. Example:
//
//   !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
//
// produces
//
//   %cst = arith.constant dense<[2.0, 3.0]> : tensor<2xf32>
//
Value materializePerChannelScales(OpBuilder &builder, Location loc,
                                  UniformQuantizedPerAxisType quantizedType) {
  auto scales = quantizedType.getScales();
  auto expressedType = quantizedType.getExpressedType();

  // Turn every scale into a float attribute of the expressed type.
  auto scaleAttrs = llvm::map_to_vector(scales, [&](double scale) -> Attribute {
    return builder.getFloatAttr(expressedType, scale);
  });

  // One tensor element per channel.
  auto tensorType = RankedTensorType::get(
      {static_cast<int64_t>(scales.size())}, expressedType);
  auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
  return arith::ConstantOp::create(builder, loc, tensorType, scalesAttr);
}
// Create a tensor constant containing all zero points in a per-channel
// quantized type. The constant is a 1D tensor with one element per zero point,
// whose element type is the storage type of 'quantizedType'. Example:
//
//   !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
//
// produces
//
//   %cst = arith.constant dense<[10, 20]> : tensor<2xi8>
//
Value materializePerChannelZeroPoints(
    OpBuilder &builder, Location loc,
    UniformQuantizedPerAxisType quantizedType) {
  auto zeroPoints = quantizedType.getZeroPoints();
  auto storageType = quantizedType.getStorageType();

  // Turn every zero point into an integer attribute of the storage type.
  auto zeroPointAttrs =
      llvm::map_to_vector(zeroPoints, [&](int64_t zeroPoint) -> Attribute {
        return builder.getIntegerAttr(storageType, zeroPoint);
      });

  // One tensor element per channel.
  auto tensorType = RankedTensorType::get(
      {static_cast<int64_t>(zeroPoints.size())}, storageType);
  auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs);
  return arith::ConstantOp::create(builder, loc, tensorType, zeroPointsAttr);
}
// Emit an 'arith.constant' holding every scale of a sub-channel quantized
// type. The constant has the same shape as the scales attribute stored in
// 'quantizedType' and uses its expressed type as the element type. Example:
//
//   !quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>
//
// produces
//
//   %cst = arith.constant dense<[[2.0, 3.0], [4.0, 5.0]]> : tensor<2x2xf32>
//
Value materializeSubChannelScales(
    OpBuilder &builder, Location loc,
    UniformQuantizedSubChannelType quantizedType) {
  auto scalesAttrIn = quantizedType.getScales();
  auto expressedType = quantizedType.getExpressedType();

  // Shape is inherited from the stored scales; only the element type changes.
  auto resultType =
      RankedTensorType::get(scalesAttrIn.getType().getShape(), expressedType);

  // Re-create every scale as a float attribute of the expressed type.
  auto toFloatAttr = [&](const APFloat &value) -> Attribute {
    return builder.getFloatAttr(expressedType, value);
  };
  auto elements =
      llvm::map_to_vector(scalesAttrIn.getValues<APFloat>(), toFloatAttr);

  auto denseScales = DenseElementsAttr::get(resultType, elements);
  return arith::ConstantOp::create(builder, loc, resultType, denseScales);
}
// Emit an 'arith.constant' holding every zero point of a sub-channel quantized
// type. The constant has the same shape as the zero points attribute stored in
// 'quantizedType' and uses its storage type as the element type. Example:
//
//   !quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>
//
// produces
//
//   %cst = arith.constant dense<[[10, 20], [30, 40]]> : tensor<2x2xi8>
//
Value materializeSubChannelZeroPoints(
    OpBuilder &builder, Location loc,
    UniformQuantizedSubChannelType quantizedType) {
  auto zeroPointsAttrIn = quantizedType.getZeroPoints();
  auto storageType = quantizedType.getStorageType();

  // Shape is inherited from the stored zero points; only the element type
  // changes.
  auto resultType = RankedTensorType::get(
      zeroPointsAttrIn.getType().getShape(), storageType);

  // Re-create every zero point as an integer attribute of the storage type.
  auto toIntegerAttr = [&](const APInt &value) -> Attribute {
    return builder.getIntegerAttr(storageType, value);
  };
  auto elements =
      llvm::map_to_vector(zeroPointsAttrIn.getValues<APInt>(), toIntegerAttr);

  auto denseZeroPoints = DenseElementsAttr::get(resultType, elements);
  return arith::ConstantOp::create(builder, loc, resultType, denseZeroPoints);
}
// Restrict the given scalar or tensor input to the storage bounds encoded in
// the given quantized type. If the type declares no storage bounds, the input
// is returned unchanged and no ops are emitted.
//
// - input
//   Scalar or ranked tensor input. The element type must match the storage type
//   of 'quantizedType'.
//
// - inputShape
//   If 'input' is a tensor, combination of attributes/values representing its
//   static/dynamic dimensions. If 'input' is a scalar, empty list.
//
// - quantizedType
//   Quantized type providing the storage type, its signedness, and the
//   storage bounds used for clamping.
Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
                          ArrayRef<OpFoldResult> inputShape,
                          QuantizedType quantizedType) {
  // Nothing to do when the quantized type does not narrow the storage range.
  if (!quantizedType.hasStorageTypeBounds())
    return input;

  auto storageType = quantizedType.getStorageType();
  auto valueType = input.getType();

  // Scalar constants for the lower and upper bound.
  Value lowerBoundScalar = arith::ConstantIntOp::create(
      builder, loc, storageType, quantizedType.getStorageTypeMin());
  Value upperBoundScalar = arith::ConstantIntOp::create(
      builder, loc, storageType, quantizedType.getStorageTypeMax());

  // Broadcast the bounds to the shape of the input when it is a tensor.
  Value lowerBound = getScalarOrTensorConstant(builder, loc, lowerBoundScalar,
                                               valueType, inputShape);
  Value upperBound = getScalarOrTensorConstant(builder, loc, upperBoundScalar,
                                               valueType, inputShape);

  // Apply max(lower) followed by min(upper), using the comparison flavor that
  // matches the signedness of the quantized type.
  if (quantizedType.isSigned()) {
    Value raised = arith::MaxSIOp::create(builder, loc, input, lowerBound);
    return arith::MinSIOp::create(builder, loc, raised, upperBound);
  }
  Value raised = arith::MaxUIOp::create(builder, loc, input, lowerBound);
  return arith::MinUIOp::create(builder, loc, raised, upperBound);
}
// Convert a floating-point value to an integer of type 'resultType', emitting
// 'arith.fptosi' when 'isSigned' is set and 'arith.fptoui' otherwise.
Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input,
                            Type resultType, bool isSigned) {
  if (!isSigned)
    return arith::FPToUIOp::create(builder, loc, resultType, input);
  return arith::FPToSIOp::create(builder, loc, resultType, input);
}
// Convert an integer value to a floating-point value of type 'resultType',
// emitting 'arith.sitofp' when 'isSigned' is set and 'arith.uitofp' otherwise.
Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
                            Type resultType, bool isSigned) {
  if (!isSigned)
    return arith::UIToFPOp::create(builder, loc, resultType, input);
  return arith::SIToFPOp::create(builder, loc, resultType, input);
}
// Quantize a scalar or ranked tensor value: divide by the scale, add the zero
// point (unless it is a known constant zero), convert to the storage type, and
// clamp to the storage bounds encoded in the given quantized type.
//
// See function 'convertRanked()' below for a description of the arguments.
Value quantizeValue(OpBuilder &builder, Location loc, Value input,
                    ArrayRef<OpFoldResult> inputShape, Value scale,
                    Value zeroPoint, QuantizedType quantizedType) {
  auto valueType = input.getType();
  bool isSigned = quantizedType.isSigned();

  // Broadcast the scale to the shape of the input if the input is a tensor.
  Value scaleOperand =
      getScalarOrTensorConstant(builder, loc, scale, valueType, inputShape);

  // input / scale
  Value storedFloat = arith::DivFOp::create(builder, loc, input, scaleOperand);

  // The zero-point addition is only emitted when the zero point is not
  // statically known to be zero.
  if (!matchPattern(zeroPoint, m_Zero())) {
    // Broadcast the zero point to the shape of the input if necessary.
    Value zeroPointOperand = getScalarOrTensorConstant(
        builder, loc, zeroPoint, valueType, inputShape);
    // Bring the zero point from the storage type to the expressed type.
    Value zeroPointFloat = convertIntegerToFloat(
        builder, loc, zeroPointOperand, scaleOperand.getType(), isSigned);
    // input / scale + zeroPoint
    storedFloat =
        arith::AddFOp::create(builder, loc, storedFloat, zeroPointFloat);
  }

  // Convert the stored value to the storage type.
  auto storageValueType =
      getScalarOrTensorType(quantizedType.getStorageType(), valueType);
  Value storedInt = convertFloatToInteger(builder, loc, storedFloat,
                                          storageValueType, isSigned);

  // Clamp the stored value if the quantized type has storage bounds.
  return clampScalarOrTensor(builder, loc, storedInt, inputShape,
                             quantizedType);
}
// Dequantize a scalar or ranked tensor input: convert the stored value to the
// expressed type, subtract the zero point (unless it is a known constant
// zero), and multiply by the scale.
//
// See function 'convertRanked()' below for a description of the arguments.
Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
                      ArrayRef<OpFoldResult> inputShape, Value scale,
                      Value zeroPoint, QuantizedType quantizedType) {
  auto valueType = input.getType();
  bool isSigned = quantizedType.isSigned();

  // Broadcast the scale to the shape of the input if the input is a tensor.
  Value scaleOperand =
      getScalarOrTensorConstant(builder, loc, scale, valueType, inputShape);
  auto expressedValueType = scaleOperand.getType();

  // Stored integer value -> floating point.
  Value expressed = convertIntegerToFloat(builder, loc, input,
                                          expressedValueType, isSigned);

  // The zero-point subtraction is only emitted when the zero point is not
  // statically known to be zero.
  if (!matchPattern(zeroPoint, m_Zero())) {
    // Broadcast the zero point to the shape of the input if necessary.
    Value zeroPointOperand = getScalarOrTensorConstant(
        builder, loc, zeroPoint, valueType, inputShape);
    // Bring the zero point from the storage type to the expressed type.
    Value zeroPointFloat = convertIntegerToFloat(
        builder, loc, zeroPointOperand, expressedValueType, isSigned);
    // Subtract the zero point from the stored value.
    expressed = arith::SubFOp::create(builder, loc, expressed, zeroPointFloat);
  }

  // (stored - zeroPoint) * scale
  return arith::MulFOp::create(builder, loc, expressed, scaleOperand);
}
// Quantize or dequantize a scalar or ranked tensor input with the given scale
// and zero point values, depending on the kind of 'op'.
//
// - op
//   'quant.qcast' (input is quantized) or 'quant.dcast' (input is
//   dequantized). Any other op is a programming error.
//
// - input
//   Scalar or ranked tensor value.
//
// - inputShape
//   If 'input' is a tensor, combination of attributes/values representing its
//   static/dynamic dimensions. If 'input' is a scalar, empty list.
//
// - scale
//   Scale as a floating-point scalar value.
//
// - zeroPoint
//   Zero point as an integer scalar value.
//
// - quantizedType
//   Scalar quantized type of the result ('quant.qcast') or of the input
//   ('quant.dcast').
//
Value convertRanked(OpBuilder &builder, Location loc, Operation *op,
                    Value input, ArrayRef<OpFoldResult> inputShape, Value scale,
                    Value zeroPoint, QuantizedType quantizedType) {
  if (isa<DequantizeCastOp>(op))
    return dequantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
                           quantizedType);

  if (isa<QuantizeCastOp>(op))
    return quantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
                         quantizedType);

  llvm_unreachable("unexpected quant op");
}
// Lower an operation that uses per-layer quantization and whose input is a
// scalar or a ranked tensor. The single scale and zero point of the quantized
// type are materialized as scalar constants and handed to 'convertRanked()'.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar or ranked tensor.
//
// - quantizedType
//   Per-layer quantized type.
//
Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op,
                            Value input, UniformQuantizedType quantizedType) {
  // Scale constant, expressed in the expressed type.
  auto expressedType = quantizedType.getExpressedType();
  Value scaleConstant = arith::ConstantOp::create(
      builder, loc, expressedType,
      builder.getFloatAttr(expressedType, quantizedType.getScale()));

  // Zero point constant, expressed in the storage type.
  auto storageType = quantizedType.getStorageType();
  Value zeroPointConstant = arith::ConstantOp::create(
      builder, loc, storageType,
      builder.getIntegerAttr(storageType, quantizedType.getZeroPoint()));

  // Shape used to broadcast the constants when the input is a tensor.
  auto shape = getScalarOrTensorShape(builder, loc, input);

  return convertRanked(builder, loc, op, input, shape, scaleConstant,
                       zeroPointConstant, quantizedType);
}
// Lower an operation that uses per-layer quantization. Scalars and ranked
// tensors are converted directly. An unranked tensor is first flattened into a
// 1D tensor, converted, and then reshaped back to its original shape.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar, ranked tensor, or unranked tensor.
//
// - quantizedType
//   Per-layer quantized type.
//
Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op,
                      Value input, UniformQuantizedType quantizedType) {
  // Scalar or ranked tensor: no reshaping required.
  if (!isa<UnrankedTensorType>(input.getType()))
    return convertPerLayerRanked(builder, loc, op, input, quantizedType);

  // Unranked tensor: flatten, convert, and restore the original shape.
  auto [flatInput, originalShape] = flattenUnrankedTensor(builder, loc, input);
  Value flatResult =
      convertPerLayerRanked(builder, loc, op, flatInput, quantizedType);
  return restoreUnrankedTensorShape(builder, loc, flatResult, originalShape);
}
// Lower an operation that uses per-channel quantization and whose input is a
// ranked tensor. The conversion is emitted as a 'linalg.generic' op with
// all-parallel iterators in which every input element is paired with the
// scale and zero point selected by its index along 'channelAxis'.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Ranked tensor (the input type is cast to 'RankedTensorType').
//
// - quantizedType
//   Per-channel quantized type.
//
// - channelAxis
//   Dimension of 'input' used to index the per-channel scales and zero
//   points.
//
Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
                              Value input,
                              UniformQuantizedPerAxisType quantizedType,
                              int64_t channelAxis) {
  auto *context = builder.getContext();
  auto rankedInputType = cast<RankedTensorType>(input.getType());
  auto rank = rankedInputType.getRank();

  // 1D constants holding the per-channel parameters.
  auto scalesTensor = materializePerChannelScales(builder, loc, quantizedType);
  auto zeroPointsTensor =
      materializePerChannelZeroPoints(builder, loc, quantizedType);

  // A floating-point input is being quantized, so the result holds storage
  // values; otherwise the input is being dequantized into expressed values.
  Type resultElementType = quantizedType.getExpressedType();
  if (isa<FloatType>(rankedInputType.getElementType()))
    resultElementType = quantizedType.getStorageType();

  // Output tensor with the same shape as the input.
  auto outputShape = tensor::getMixedSizes(builder, loc, input);
  Value output =
      tensor::EmptyOp::create(builder, loc, outputShape, resultElementType);

  // Input and output are accessed with the identity map, while the 1D scales
  // and zero points are indexed by the channel-axis iterator only.
  SmallVector<utils::IteratorType> iterators(rank,
                                             utils::IteratorType::parallel);
  auto identityMap = builder.getMultiDimIdentityMap(rank);
  auto channelMap = AffineMap::get(
      rank, 0, builder.getAffineDimExpr(channelAxis), context);
  SmallVector<AffineMap> maps{identityMap, channelMap, channelMap,
                              identityMap};

  auto converted =
      linalg::GenericOp::create(
          builder, loc,
          output.getType(),                                // resultType
          ValueRange{input, scalesTensor, zeroPointsTensor}, // inputs
          ValueRange{output},                              // outputs
          maps, iterators,
          [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
            // Block arguments: element, scale, zero point, output element.
            assert(args.size() == 4);
            auto element = args[0];
            auto elementScale = args[1];
            auto elementZeroPoint = args[2];
            // Inside the body everything is a scalar, hence the empty shape.
            auto convertedElement =
                convertRanked(nestedBuilder, nestedLoc, op, element, {},
                              elementScale, elementZeroPoint, quantizedType);
            linalg::YieldOp::create(nestedBuilder, nestedLoc,
                                    convertedElement);
          })
          .getResult(0);
  return converted;
}
// Convert an operation using per-channel quantization. A ranked tensor is
// converted directly along the quantized dimension of 'quantizedType'. An
// unranked tensor is first reshaped into a 3D tensor whose middle dimension is
// the channel axis, converted with channel axis 1, and then reshaped back to
// its original shape.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Ranked or unranked tensor.
//
// - quantizedType
//   Per-channel quantized type.
//
Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
                        Value input,
                        UniformQuantizedPerAxisType quantizedType) {
  // Flatten unranked tensor into a 3D ranked tensor if necessary
  bool isUnranked = isa<UnrankedTensorType>(input.getType());
  int64_t channelAxis = quantizedType.getQuantizedDimension();
  // The number of scales gives the size of the channel axis.
  int64_t channelAxisSize =
      static_cast<int64_t>(quantizedType.getScales().size());
  Value inputShape;
  if (isUnranked) {
    std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
        builder, loc, input, channelAxis, channelAxisSize);
    // In the flattened [?, channelAxisSize, ?] tensor the channel axis is
    // always the middle dimension.
    channelAxis = 1;
  }

  // Work on a ranked tensor
  auto result = convertPerChannelRanked(builder, loc, op, input, quantizedType,
                                        channelAxis);

  // Restore original tensor shape if unranked
  if (isUnranked)
    result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
  return result;
}
// Lower an operation that uses sub-channel quantization. The conversion is
// emitted as a 'linalg.generic' op with all-parallel iterators in which every
// input element is paired with the scale and zero point of the block that
// contains it.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Ranked tensor (the input type is cast to 'RankedTensorType').
//
// - quantizedType
//   Sub-channel quantized type.
//
Value convertSubChannel(OpBuilder &builder, Location loc, Operation *op,
                        Value input,
                        UniformQuantizedSubChannelType quantizedType) {
  auto *context = builder.getContext();
  auto rankedInputType = cast<RankedTensorType>(input.getType());
  auto rank = rankedInputType.getRank();

  // Constants holding the per-block parameters.
  auto scalesTensor = materializeSubChannelScales(builder, loc, quantizedType);
  auto zeroPointsTensor =
      materializeSubChannelZeroPoints(builder, loc, quantizedType);

  // A floating-point input is being quantized, so the result holds storage
  // values; otherwise the input is being dequantized into expressed values.
  Type resultElementType = quantizedType.getExpressedType();
  if (isa<FloatType>(rankedInputType.getElementType()))
    resultElementType = quantizedType.getStorageType();

  // Output tensor with the same shape as the input.
  auto outputShape = tensor::getMixedSizes(builder, loc, input);
  Value output =
      tensor::EmptyOp::create(builder, loc, outputShape, resultElementType);

  SmallVector<utils::IteratorType> iterators(rank,
                                             utils::IteratorType::parallel);

  // Index expressions for the scales/zero points: a quantized dimension maps
  // to 'dim floordiv blockSize' (the block index), every other dimension maps
  // to the constant 0.
  SmallVector<AffineExpr> blockIndexExprs(rank,
                                          builder.getAffineConstantExpr(0));
  for (auto [dimension, blockSize] : quantizedType.getBlockSizeInfo())
    blockIndexExprs[dimension] =
        builder.getAffineDimExpr(dimension).floorDiv(blockSize);

  // Input and output are accessed with the identity map.
  auto blockMap = AffineMap::get(rank, 0, blockIndexExprs, context);
  auto identityMap = builder.getMultiDimIdentityMap(rank);
  SmallVector<AffineMap> maps{identityMap, blockMap, blockMap, identityMap};

  auto converted =
      linalg::GenericOp::create(
          builder, loc,
          output.getType(),                                // resultType
          ValueRange{input, scalesTensor, zeroPointsTensor}, // inputs
          ValueRange{output},                              // outputs
          maps, iterators,
          [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
            // Block arguments: element, scale, zero point, output element.
            assert(args.size() == 4);
            auto element = args[0];
            auto elementScale = args[1];
            auto elementZeroPoint = args[2];
            // Inside the body everything is a scalar, hence the empty shape.
            auto convertedElement =
                convertRanked(nestedBuilder, nestedLoc, op, element, {},
                              elementScale, elementZeroPoint, quantizedType);
            linalg::YieldOp::create(nestedBuilder, nestedLoc,
                                    convertedElement);
          })
          .getResult(0);
  return converted;
}
// Lower a quantization operation by dispatching on the kind of quantized
// type involved.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar, ranked tensor, or unranked tensor. The element type matches
//   the storage type (quant.dcast) or expressed type (quant.qcast) of
//   'quantizedType'.
//
// - quantizedType
//   Per-layer, per-channel, or sub-channel quantized type. Any other type is
//   a programming error.
//
Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
                       Value input, Type quantizedType) {
  // Per-layer quantization
  if (auto perLayerType = dyn_cast<UniformQuantizedType>(quantizedType))
    return convertPerLayer(builder, loc, op, input, perLayerType);

  // Per-channel quantization
  if (auto perChannelType =
          dyn_cast<UniformQuantizedPerAxisType>(quantizedType))
    return convertPerChannel(builder, loc, op, input, perChannelType);

  // Sub-channel quantization
  if (auto subChannelType =
          dyn_cast<UniformQuantizedSubChannelType>(quantizedType))
    return convertSubChannel(builder, loc, op, input, subChannelType);

  llvm_unreachable("unexpected quantized type");
}
// Lowering pattern for 'quant.dcast'. The quantized input is first cast to its
// storage type with 'quant.scast', and the resulting stored value is then
// dequantized with lower-level ops.
struct DequantizeCastOpConversion
    : public OpConversionPattern<quant::DequantizeCastOp> {
  using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    Value quantizedInput = op.getInput();

    // Quantized type of the input (element type if the input is a tensor).
    auto quantizedType =
        cast<QuantizedType>(getScalarType(op.getInput().getType()));

    // Reinterpret the quantized input as a value of the storage type, keeping
    // the tensor shape (if any) intact.
    auto storageValueType = getScalarOrTensorType(
        quantizedType.getStorageType(), quantizedInput.getType());
    Value storedInput = quant::StorageCastOp::create(
        rewriter, loc, storageValueType, quantizedInput);

    // Dequantize the stored value and use it in place of the original op.
    auto dequantized =
        convertQuantized(rewriter, loc, op, storedInput, quantizedType);
    rewriter.replaceOp(op, dequantized);
    return success();
  }
};
// Lowering pattern for 'quant.qcast'. The input is quantized into a value of
// the storage type with lower-level ops, and that stored value is then cast to
// the quantized result type with 'quant.scast'.
struct QuantizeCastOpConversion
    : public OpConversionPattern<quant::QuantizeCastOp> {
  using OpConversionPattern<quant::QuantizeCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(quant::QuantizeCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto resultType = op.getResult().getType();

    // Quantized type of the result (element type if the result is a tensor).
    auto quantizedType = getScalarType(resultType);

    // Compute the stored value from the expressed input.
    auto stored =
        convertQuantized(rewriter, loc, op, op.getInput(), quantizedType);

    // Cast stored value to result quantized value
    rewriter.replaceOpWithNewOp<quant::StorageCastOp>(op, resultType, stored);
    return success();
  }
};
// Pass that lowers 'quant.dcast' and 'quant.qcast' by running the patterns
// registered in 'populateLowerQuantOpsPatterns()' as a partial dialect
// conversion.
struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
  void runOnOperation() override {
    auto &context = getContext();

    // Every 'quant' op is illegal except 'quant.scast', which the lowering
    // patterns themselves emit. Ops of the dialects targeted by the lowering
    // are legal.
    ConversionTarget target(context);
    target.addLegalOp<quant::StorageCastOp>();
    target.addIllegalDialect<quant::QuantDialect>();
    target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
                           shape::ShapeDialect, tensor::TensorDialect>();

    RewritePatternSet patterns(&context);
    populateLowerQuantOpsPatterns(patterns);

    // Fail the pass if the conversion itself fails.
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
// Register the lowering patterns for 'quant.dcast' and 'quant.qcast' in the
// given pattern set.
void populateLowerQuantOpsPatterns(RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<DequantizeCastOpConversion, QuantizeCastOpConversion>(context);
}
} // namespace quant
} // namespace mlir
|