| author | Nicolas Vasilache <nicolasvasilache@users.noreply.github.com> | 2023-09-18 15:08:18 +0200 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-09-18 15:08:18 +0200 |
| commit | bf7c490ab73a22620c3d7790c09bfb11b669e51b (patch) | |
| tree | caf99be4c16233ac9868f439e37bf123cc335704 | |
| parent | b8f64431eaf750941490f8fb155b597d618d71a7 (diff) | |
[mlir][Vector] Add a rewrite pattern for better low-precision bitcast(trunci) expansion (#66387)
This revision adds a rewrite that expands vector `bitcast(trunci)`
sequences into a more efficient sequence of vector `shuffle` and
bitwise ops.
Such patterns appear naturally when writing quantization /
dequantization functionality with the vector dialect.
The rewrite enumerates each bit of the result vector and determines its
provenance in the pre-trunci vector. This enumeration is then used to
generate the proper sequence of `shuffle`, `andi`, `shrui`, `shli` and
`ori` ops, followed by an optional final `trunci`/`extui`.
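For illustration, this is the kind of sequence the pattern targets (it is the `f1` test added in this change); after the rewrite, the `trunci`/`bitcast` pair becomes three `vector.shuffle` + `arith.andi` pairs combined with `shrui`/`shli`/`ori` over `vector<20xi64>`, followed by a final `arith.trunci` to `vector<20xi8>` (see the CHECK lines of `f1` below).

```mlir
// Input accepted by the pattern (this is the `f1` test added below).
func.func @f1(%a: vector<32xi64>) -> vector<20xi8> {
  %0 = arith.trunci %a : vector<32xi64> to vector<32xi5>
  %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
  return %1 : vector<20xi8>
}
```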
The rewrite currently only applies to 1-D non-scalable vectors and bails
out if the bitwidth of the final vector element type is not a multiple
of 8. This is a failsafe heuristic determined empirically: when the
resulting element type does not span a whole number of bytes, further
complexities arise that this pattern does not improve; the heavy
lifting still needs to be done by LLVM.
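As a concrete instance of this restriction, the following function (the `f4` test added below) produces `i6` result elements and is therefore left untouched by the pattern:

```mlir
// Not rewritten: the i6 result element bitwidth is not a multiple of 8
// (this is the `f4` test added below).
func.func @f4(%a: vector<16xi16>) -> vector<8xi6> {
  %0 = arith.trunci %a : vector<16xi16> to vector<16xi3>
  %1 = vector.bitcast %0 : vector<16xi3> to vector<8xi6>
  return %1 : vector<8xi6>
}
```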
7 files changed, 658 insertions, 5 deletions
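Besides the new `transform.apply_patterns.vector.rewrite_narrow_types` op exercised in the tests, the patterns can be populated directly from C++ via `populateVectorNarrowTypeRewritePatterns`. The following is a minimal sketch of that route; the helper function itself is hypothetical and not part of this commit:

```cpp
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical helper (not part of this commit): collect the narrow-type
// rewrite patterns added in this change and apply them greedily to `op`.
static mlir::LogicalResult applyNarrowTypeRewrites(mlir::Operation *op) {
  mlir::RewritePatternSet patterns(op->getContext());
  mlir::vector::populateVectorNarrowTypeRewritePatterns(patterns);
  return mlir::applyPatternsAndFoldGreedily(op, std::move(patterns));
}
```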
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td index 9e718a0c..133ee4e 100644 --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -292,6 +292,19 @@ def ApplyLowerTransposePatternsOp : Op<Transform_Dialect, }]; } +def ApplyRewriteNarrowTypePatternsOp : Op<Transform_Dialect, + "apply_patterns.vector.rewrite_narrow_types", + [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> { + let description = [{ + Indicates that vector narrow rewrite operations should be applied. + + This is usually a late step that is run after bufferization as part of the + process of lowering to e.g. LLVM or NVVM. + }]; + + let assemblyFormat = "attr-dict"; +} + def ApplySplitTransferFullPartialPatternsOp : Op<Transform_Dialect, "apply_patterns.vector.split_transfer_full_partial", [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> { diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h index c644090..8652fc7 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -24,6 +24,7 @@ class RewritePatternSet; namespace arith { class NarrowTypeEmulationConverter; +class TruncIOp; } // namespace arith namespace vector { @@ -143,7 +144,7 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns( /// Patterns that remove redundant vector broadcasts. void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns, - PatternBenefit benefit = 1); + PatternBenefit benefit = 1); /// Populate `patterns` with the following patterns. /// @@ -301,6 +302,18 @@ void populateVectorNarrowTypeEmulationPatterns( arith::NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns); +/// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of +/// vector operations comprising `shuffle` and `bitwise` ops. +FailureOr<Value> rewriteBitCastOfTruncI(RewriterBase &rewriter, + vector::BitCastOp bitCastOp, + arith::TruncIOp truncOp, + vector::BroadcastOp maybeBroadcastOp); + +/// Appends patterns for rewriting vector operations over narrow types with +/// ops over wider types. +void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + } // namespace vector } // namespace mlir diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h index f031eb0..9df5548 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -357,6 +357,16 @@ public: return *this; } + /// Set a dim in shape @pos to val. 
+ Builder &setDim(unsigned pos, int64_t val) { + if (storage.empty()) + storage.append(shape.begin(), shape.end()); + assert(pos < storage.size() && "overflow"); + storage[pos] = val; + shape = {storage.data(), storage.size()}; + return *this; + } + operator VectorType() { return VectorType::get(shape, elementType, scalableDims); } diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index b388dea..37127ea 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -159,6 +159,11 @@ void transform::ApplyLowerTransposePatternsOp::populatePatterns( } } +void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns( + RewritePatternSet &patterns) { + populateVectorNarrowTypeRewritePatterns(patterns); +} + void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns( RewritePatternSet &patterns) { vector::VectorTransformsOptions vectorTransformOptions; diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index b2b7bfc..488dcff 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -7,7 +7,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h" #include "mlir/Dialect/Arith/Utils/Utils.h" @@ -15,13 +14,23 @@ #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Value.h" #include "mlir/Transforms/DialectConversion.h" -#include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/MathExtras.h" -#include <cassert> +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include <cstdint> using namespace mlir; +#define DEBUG_TYPE "vector-narrow-type-emulation" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define DBGSNL() (llvm::dbgs() << "\n") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + namespace { //===----------------------------------------------------------------------===// @@ -156,6 +165,292 @@ struct ConvertVectorTransferRead final } // end anonymous namespace //===----------------------------------------------------------------------===// +// RewriteBitCastOfTruncI +//===----------------------------------------------------------------------===// + +namespace { + +/// Helper struct to keep track of the provenance of a contiguous set of bits +/// in a source vector. +struct SourceElementRange { + /// The index of the source vector element that contributes bits to *this. + int64_t sourceElementIdx; + /// The range of bits in the source vector element that contribute to *this. + int64_t sourceBitBegin; + int64_t sourceBitEnd; +}; + +struct SourceElementRangeList : public SmallVector<SourceElementRange> { + /// Given the index of a SourceElementRange in the SourceElementRangeList, + /// compute the amount of bits that need to be shifted to the left to get the + /// bits in their final location. 
This shift amount is simply the sum of the + /// bits *before* `shuffleIdx` (i.e. the bits of `shuffleIdx = 0` are always + /// the LSBs, the bits of `shuffleIdx = ` come next, etc). + int64_t computeLeftShiftAmount(int64_t shuffleIdx) const { + int64_t res = 0; + for (int64_t i = 0; i < shuffleIdx; ++i) + res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin; + return res; + } +}; + +/// Helper struct to enumerate the source elements and bit ranges that are +/// involved in a bitcast operation. +/// This allows rewriting a vector.bitcast into shuffles and bitwise ops for +/// any 1-D vector shape and any source/target bitwidths. +/// This creates and holds a mapping of the form: +/// [dstVectorElementJ] == +/// [ {srcVectorElementX, bitRange}, {srcVectorElementY, bitRange}, ... ] +/// E.g. `vector.bitcast ... : vector<1xi24> to vector<3xi8>` is decomposed as: +/// [0] = {0, [0-8)} +/// [1] = {0, [8-16)} +/// [2] = {0, [16-24)} +/// and `vector.bitcast ... : vector<2xi15> to vector<3xi10>` is decomposed as: +/// [0] = {0, [0, 10)}, {1, [0, 5)} +/// [1] = {1, [5, 10)}, {2, [0, 10)} +struct BitCastBitsEnumerator { + BitCastBitsEnumerator(VectorType sourceVectorType, + VectorType targetVectorType); + + int64_t getMaxNumberOfEntries() { + int64_t numVectors = 0; + for (const auto &l : sourceElementRanges) + numVectors = std::max(numVectors, (int64_t)l.size()); + return numVectors; + } + + VectorType sourceVectorType; + VectorType targetVectorType; + SmallVector<SourceElementRangeList> sourceElementRanges; +}; + +} // namespace + +static raw_ostream &operator<<(raw_ostream &os, + const SmallVector<SourceElementRangeList> &vec) { + for (const auto &l : vec) { + for (auto it : llvm::enumerate(l)) { + os << "{ " << it.value().sourceElementIdx << ": b@[" + << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd + << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } "; + } + os << "\n"; + } + return os; +} + +BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType, + VectorType targetVectorType) + : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) { + + assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() && + "requires -D non-scalable vector type"); + assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() && + "requires -D non-scalable vector type"); + int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth(); + int64_t mostMinorSourceDim = sourceVectorType.getShape().back(); + LDBG("sourceVectorType: " << sourceVectorType); + + int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth(); + int64_t mostMinorTargetDim = targetVectorType.getShape().back(); + LDBG("targetVectorType: " << targetVectorType); + + int64_t bitwidth = targetBitWidth * mostMinorTargetDim; + assert(bitwidth == sourceBitWidth * mostMinorSourceDim && + "source and target bitwidths must match"); + + // Prepopulate one source element range per target element. 
+ sourceElementRanges = SmallVector<SourceElementRangeList>(mostMinorTargetDim); + for (int64_t resultBit = 0; resultBit < bitwidth;) { + int64_t resultElement = resultBit / targetBitWidth; + int64_t resultBitInElement = resultBit % targetBitWidth; + int64_t sourceElementIdx = resultBit / sourceBitWidth; + int64_t sourceBitInElement = resultBit % sourceBitWidth; + int64_t step = std::min(sourceBitWidth - sourceBitInElement, + targetBitWidth - resultBitInElement); + sourceElementRanges[resultElement].push_back( + {sourceElementIdx, sourceBitInElement, sourceBitInElement + step}); + resultBit += step; + } +} + +namespace { +/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take +/// advantage of high-level information to avoid leaving LLVM to scramble with +/// peephole optimizations. + +// BitCastBitsEnumerator encodes for each element of the target vector the +// provenance of the bits in the source vector. We can "transpose" this +// information to build a sequence of shuffles and bitwise ops that will +// produce the desired result. +// +// Let's take the following motivating example to explain the algorithm: +// ``` +// %0 = arith.trunci %a : vector<32xi64> to vector<32xi5> +// %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8> +// ``` +// +// BitCastBitsEnumerator contains the following information: +// ``` +// { 0: b@[0..5) lshl: 0}{1: b@[0..3) lshl: 5 } +// { 1: b@[3..5) lshl: 0}{2: b@[0..5) lshl: 2}{3: b@[0..1) lshl: 7 } +// { 3: b@[1..5) lshl: 0}{4: b@[0..4) lshl: 4 } +// { 4: b@[4..5) lshl: 0}{5: b@[0..5) lshl: 1}{6: b@[0..2) lshl: 6 } +// { 6: b@[2..5) lshl: 0}{7: b@[0..5) lshl: 3 } +// { 8: b@[0..5) lshl: 0}{9: b@[0..3) lshl: 5 } +// { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7 } +// { 11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4 } +// { 12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6 } +// { 14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3} +// { 16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5} +// { 17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7} +// { 19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4} +// { 20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1 }{22: b@[0..2) lshl: 6} +// { 22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3 } +// { 24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5 } +// { 25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7 } +// { 27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4} +// { 28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6} +// { 30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3 } +// ``` +// +// In the above, each row represents one target vector element and each +// column represents one bit contribution from a source vector element. +// The algorithm creates vector.shuffle operations (in this case there are 3 +// shuffles (i.e. the max number of columns in BitCastBitsEnumerator). The +// algorithm populates the bits as follows: +// ``` +// src bits 0 ... +// 1st shuffle |xxxxx |xx |... +// 2nd shuffle | xxx| xxxxx |... +// 3rd shuffle | | x|... +// ``` +// +// The algorithm proceeds as follows: +// 1. for each vector.shuffle, collect the source vectors that participate in +// this shuffle. One source vector per target element of the resulting +// vector.shuffle. If there is no source element contributing bits for the +// current vector.shuffle, take 0 (i.e. row 0 in the above example has only +// 2 columns). +// 2. represent the bitrange in the source vector as a mask. If there is no +// source element contributing bits for the current vector.shuffle, take 0. +// 3. 
shift right by the proper amount to align the source bitrange at +// position 0. This is exactly the low end of the bitrange. For instance, +// the first element of row 2 is `{ 1: b@[3..5) lshl: 0}` and one needs to +// shift right by 3 to get the bits contributed by the source element #1 +// into position 0. +// 4. shift left by the proper amount to to align to the desired position in +// the result element vector. For instance, the contribution of the second +// source element for the first row needs to be shifted by `5` to form the +// first i8 result element. +// Eventually, we end up building the sequence +// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update the +// result vector (i.e. the `shiftright -> shiftleft -> or` part) with the bits +// extracted from the source vector (i.e. the `shuffle -> and` part). +struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp, + PatternRewriter &rewriter) const override { + // The source must be a trunc op. + auto truncOp = + bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>(); + if (!truncOp) + return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source"); + + VectorType targetVectorType = bitCastOp.getResultVectorType(); + if (targetVectorType.getRank() != 1 || targetVectorType.isScalable()) + return rewriter.notifyMatchFailure(bitCastOp, "scalable or >1-D vector"); + // TODO: consider relaxing this restriction in the future if we find ways + // to really work with subbyte elements across the MLIR/LLVM boundary. + int64_t resultBitwidth = targetVectorType.getElementTypeBitWidth(); + if (resultBitwidth % 8 != 0) + return rewriter.notifyMatchFailure(bitCastOp, "bitwidth is not k * 8"); + + VectorType sourceVectorType = bitCastOp.getSourceVectorType(); + BitCastBitsEnumerator be(sourceVectorType, targetVectorType); + LDBG("\n" << be.sourceElementRanges); + + Value initialValue = truncOp.getIn(); + auto initalVectorType = initialValue.getType().cast<VectorType>(); + auto initalElementType = initalVectorType.getElementType(); + auto initalElementBitWidth = initalElementType.getIntOrFloatBitWidth(); + + Value res; + for (int64_t shuffleIdx = 0, e = be.getMaxNumberOfEntries(); shuffleIdx < e; + ++shuffleIdx) { + SmallVector<int64_t> shuffles; + SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts; + + // Create the attribute quantities for the shuffle / mask / shift ops. + for (auto &srcEltRangeList : be.sourceElementRanges) { + bool idxContributesBits = + (shuffleIdx < (int64_t)srcEltRangeList.size()); + int64_t sourceElementIdx = + idxContributesBits ? srcEltRangeList[shuffleIdx].sourceElementIdx + : 0; + shuffles.push_back(sourceElementIdx); + + int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size()) + ? srcEltRangeList[shuffleIdx].sourceBitBegin + : 0; + int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size()) + ? 
srcEltRangeList[shuffleIdx].sourceBitEnd + : 0; + IntegerAttr mask = IntegerAttr::get( + rewriter.getIntegerType(initalElementBitWidth), + llvm::APInt::getBitsSet(initalElementBitWidth, bitLo, bitHi)); + masks.push_back(mask); + + int64_t shiftRight = bitLo; + shiftRightAmounts.push_back(IntegerAttr::get( + rewriter.getIntegerType(initalElementBitWidth), shiftRight)); + + int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx); + shiftLeftAmounts.push_back(IntegerAttr::get( + rewriter.getIntegerType(initalElementBitWidth), shiftLeft)); + } + + // Create vector.shuffle #shuffleIdx. + auto shuffleOp = rewriter.create<vector::ShuffleOp>( + bitCastOp.getLoc(), initialValue, initialValue, shuffles); + // And with the mask. + VectorType vt = VectorType::Builder(initalVectorType) + .setDim(initalVectorType.getRank() - 1, masks.size()); + auto constOp = rewriter.create<arith::ConstantOp>( + bitCastOp.getLoc(), DenseElementsAttr::get(vt, masks)); + Value andValue = rewriter.create<arith::AndIOp>(bitCastOp.getLoc(), + shuffleOp, constOp); + // Align right on 0. + auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>( + bitCastOp.getLoc(), DenseElementsAttr::get(vt, shiftRightAmounts)); + Value shiftedRight = rewriter.create<arith::ShRUIOp>( + bitCastOp.getLoc(), andValue, shiftRightConstantOp); + + auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>( + bitCastOp.getLoc(), DenseElementsAttr::get(vt, shiftLeftAmounts)); + Value shiftedLeft = rewriter.create<arith::ShLIOp>( + bitCastOp.getLoc(), shiftedRight, shiftLeftConstantOp); + + res = res ? rewriter.create<arith::OrIOp>(bitCastOp.getLoc(), res, + shiftedLeft) + : shiftedLeft; + } + + bool narrowing = resultBitwidth <= initalElementBitWidth; + if (narrowing) { + rewriter.replaceOpWithNewOp<arith::TruncIOp>( + bitCastOp, bitCastOp.getResultVectorType(), res); + } else { + rewriter.replaceOpWithNewOp<arith::ExtUIOp>( + bitCastOp, bitCastOp.getResultVectorType(), res); + } + return success(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// // Public Interface Definition //===----------------------------------------------------------------------===// @@ -167,3 +462,8 @@ void vector::populateVectorNarrowTypeEmulationPatterns( patterns.add<ConvertVectorLoad, ConvertVectorTransferRead>( typeConverter, patterns.getContext()); } + +void vector::populateVectorNarrowTypeRewritePatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add<RewriteBitCastOfTruncI>(patterns.getContext(), benefit); +} diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir new file mode 100644 index 0000000..ba6efde --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir @@ -0,0 +1,157 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s + +/// Note: Inspect generated assembly and llvm-mca stats: +/// ==================================================== +/// mlir-opt --test-transform-dialect-interpreter mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir -test-transform-dialect-erase-schedule -test-lower-to-llvm | mlir-translate -mlir-to-llvmir | llc -o - -mcpu=skylake-avx512 --function-sections -filetype=obj > /tmp/a.out; objdump -d --disassemble=f1 --no-addresses --no-show-raw-insn -M att /tmp/a.out | ./build/bin/llvm-mca -mcpu=skylake-avx512 + +// CHECK-LABEL: func.func @f1( +// CHECK-SAME: %[[A:[0-9a-z]*]]: 
vector<32xi64>) -> vector<20xi8> +func.func @f1(%a: vector<32xi64>) -> vector<20xi8> { + /// Rewriting this standalone pattern is about 2x faster on skylake-ax512 according to llvm-mca. + /// Benefit further increases when mixed with other compute ops. + /// + /// The provenance of the 20x8 bits of the result are the following bits in the + /// source vector: + // { 0: b@[0..5) lshl: 0 } { 1: b@[0..3) lshl: 5 } + // { 1: b@[3..5) lshl: 0 } { 2: b@[0..5) lshl: 2 } { 3: b@[0..1) lshl: 7 } + // { 3: b@[1..5) lshl: 0 } { 4: b@[0..4) lshl: 4 } + // { 4: b@[4..5) lshl: 0 } { 5: b@[0..5) lshl: 1 } { 6: b@[0..2) lshl: 6 } + // { 6: b@[2..5) lshl: 0 } { 7: b@[0..5) lshl: 3 } + // { 8: b@[0..5) lshl: 0 } { 9: b@[0..3) lshl: 5 } + // { 9: b@[3..5) lshl: 0 } { 10: b@[0..5) lshl: 2 } { 11: b@[0..1) lshl: 7 } + // { 11: b@[1..5) lshl: 0 } { 12: b@[0..4) lshl: 4 } + // { 12: b@[4..5) lshl: 0 } { 13: b@[0..5) lshl: 1 } { 14: b@[0..2) lshl: 6 } + // { 14: b@[2..5) lshl: 0 } { 15: b@[0..5) lshl: 3 } + // { 16: b@[0..5) lshl: 0 } { 17: b@[0..3) lshl: 5 } + // { 17: b@[3..5) lshl: 0 } { 18: b@[0..5) lshl: 2 } { 19: b@[0..1) lshl: 7 } + // { 19: b@[1..5) lshl: 0 } { 20: b@[0..4) lshl: 4 } + // { 20: b@[4..5) lshl: 0 } { 21: b@[0..5) lshl: 1 } { 22: b@[0..2) lshl: 6 } + // { 22: b@[2..5) lshl: 0 } { 23: b@[0..5) lshl: 3 } + // { 24: b@[0..5) lshl: 0 } { 25: b@[0..3) lshl: 5 } + // { 25: b@[3..5) lshl: 0 } { 26: b@[0..5) lshl: 2 } { 27: b@[0..1) lshl: 7 } + // { 27: b@[1..5) lshl: 0 } { 28: b@[0..4) lshl: 4 } + // { 28: b@[4..5) lshl: 0 } { 29: b@[0..5) lshl: 1 } { 30: b@[0..2) lshl: 6 } + // { 30: b@[2..5) lshl: 0 } { 31: b@[0..5) lshl: 3 } + /// This results in 3 shuffles + 1 shr + 2 shl + 3 and + 2 or. + /// The third vector is empty for positions 0, 2, 4, 5, 7, 9, 10, 12, 14, 15, + /// 17 and 19 (i.e. there are only 2 entries in that row). 
+ /// + /// 0: b@[0..5), 1: b@[3..5), etc + // CHECK-DAG: %[[MASK0:.*]] = arith.constant dense<[31, 24, 30, 16, 28, 31, 24, 30, 16, 28, 31, 24, 30, 16, 28, 31, 24, 30, 16, 28]> : vector<20xi64> + /// 1: b@[0..3), 2: b@[0..5), etc + // CHECK-DAG: %[[MASK1:.*]] = arith.constant dense<[7, 31, 15, 31, 31, 7, 31, 15, 31, 31, 7, 31, 15, 31, 31, 7, 31, 15, 31, 31]> : vector<20xi64> + /// empty, 3: b@[0..1), empty etc + // CHECK-DAG: %[[MASK2:.*]] = arith.constant dense<[0, 1, 0, 3, 0, 0, 1, 0, 3, 0, 0, 1, 0, 3, 0, 0, 1, 0, 3, 0]> : vector<20xi64> + // CHECK-DAG: %[[SHR0_CST:.*]] = arith.constant dense<[0, 3, 1, 4, 2, 0, 3, 1, 4, 2, 0, 3, 1, 4, 2, 0, 3, 1, 4, 2]> : vector<20xi64> + // CHECK-DAG: %[[SHL1_CST:.*]] = arith.constant dense<[5, 2, 4, 1, 3, 5, 2, 4, 1, 3, 5, 2, 4, 1, 3, 5, 2, 4, 1, 3]> : vector<20xi64> + // CHECK-DAG: %[[SHL2_CST:.*]] = arith.constant dense<[8, 7, 8, 6, 8, 8, 7, 8, 6, 8, 8, 7, 8, 6, 8, 8, 7, 8, 6, 8]> : vector<20xi64> + // + // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 1, 3, 4, 6, 8, 9, 11, 12, 14, 16, 17, 19, 20, 22, 24, 25, 27, 28, 30] : vector<32xi64>, vector<32xi64> + // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK0]] : vector<20xi64> + // CHECK: %[[SHR0:.*]] = arith.shrui %[[A0]], %[[SHR0_CST]] : vector<20xi64> + // CHECK: %[[V1:.*]] = vector.shuffle %[[A]], %[[A]] [1, 2, 4, 5, 7, 9, 10, 12, 13, 15, 17, 18, 20, 21, 23, 25, 26, 28, 29, 31] : vector<32xi64>, vector<32xi64> + // CHECK: %[[A1:.*]] = arith.andi %[[V1]], %[[MASK1]] : vector<20xi64> + // CHECK: %[[SHL1:.*]] = arith.shli %[[A1]], %[[SHL1_CST]] : vector<20xi64> + // CHECK: %[[O1:.*]] = arith.ori %[[SHR0]], %[[SHL1]] : vector<20xi64> + // CHECK: %[[V2:.*]] = vector.shuffle %[[A]], %[[A]] [0, 3, 0, 6, 0, 0, 11, 0, 14, 0, 0, 19, 0, 22, 0, 0, 27, 0, 30, 0] : vector<32xi64>, vector<32xi64> + // CHECK: %[[A2:.*]] = arith.andi %[[V2]], %[[MASK2]] : vector<20xi64> + // CHECK: %[[SHL2:.*]] = arith.shli %[[A2]], %[[SHL2_CST]] : vector<20xi64> + // CHECK: %[[O2:.*]] = arith.ori %[[O1]], %[[SHL2]] : vector<20xi64> + // CHECK: %[[TR:.*]] = arith.trunci %[[O2]] : vector<20xi64> to vector<20xi8> + // CHECK-NOT: bitcast + %0 = arith.trunci %a : vector<32xi64> to vector<32xi5> + %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8> + return %1 : vector<20xi8> +} + +// CHECK-LABEL: func.func @f2( +// CHECK-SAME: %[[A:[0-9a-z]*]]: vector<16xi16>) -> vector<3xi16> +func.func @f2(%a: vector<16xi16>) -> vector<3xi16> { + /// Rewriting this standalone pattern is about 1.8x faster on skylake-ax512 according to llvm-mca. + /// Benefit further increases when mixed with other compute ops. + /// + // { 0: b@[0..3) lshl: 0 } { 1: b@[0..3) lshl: 3 } { 2: b@[0..3) lshl: 6 } { 3: b@[0..3) lshl: 9 } { 4: b@[0..3) lshl: 12 } { 5: b@[0..1) lshl: 15 } + // { 5: b@[1..3) lshl: 0 } { 6: b@[0..3) lshl: 2 } { 7: b@[0..3) lshl: 5 } { 8: b@[0..3) lshl: 8 } { 9: b@[0..3) lshl: 11 } { 10: b@[0..2) lshl: 14 } + // { 10: b@[2..3) lshl: 0 } { 11: b@[0..3) lshl: 1 } { 12: b@[0..3) lshl: 4 } { 13: b@[0..3) lshl: 7 } { 14: b@[0..3) lshl: 10 } { 15: b@[0..3) lshl: 13 } + /// 0: b@[0..3), 5: b@[1..3), 10: b@[2..3) + // CHECK-DAG: %[[MASK0:.*]] = arith.constant dense<[7, 6, 4]> : vector<3xi16> + /// 1: b@[0..3), 6: b@[0..3), 11: b@[0..3) + /// ... 
+ // CHECK-DAG: %[[MASK1:.*]] = arith.constant dense<7> : vector<3xi16> + /// 5: b@[0..1), 10: b@[0..2), 15: b@[0..3) + // CHECK-DAG: %[[MASK2:.*]] = arith.constant dense<[1, 3, 7]> : vector<3xi16> + // CHECK-DAG: %[[SHR0_CST:.*]] = arith.constant dense<[0, 1, 2]> : vector<3xi16> + // CHECK-DAG: %[[SHL1_CST:.*]] = arith.constant dense<[3, 2, 1]> : vector<3xi16> + // CHECK-DAG: %[[SHL2_CST:.*]] = arith.constant dense<[6, 5, 4]> : vector<3xi16> + // CHECK-DAG: %[[SHL3_CST:.*]] = arith.constant dense<[9, 8, 7]> : vector<3xi16> + // CHECK-DAG: %[[SHL4_CST:.*]] = arith.constant dense<[12, 11, 10]> : vector<3xi16> + // CHECK-DAG: %[[SHL5_CST:.*]] = arith.constant dense<[15, 14, 13]> : vector<3xi16> + + // + // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 5, 10] : vector<16xi16>, vector<16xi16> + // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK0]] : vector<3xi16> + // CHECK: %[[SHR0:.*]] = arith.shrui %[[A0]], %[[SHR0_CST]] : vector<3xi16> + // CHECK: %[[V1:.*]] = vector.shuffle %[[A]], %[[A]] [1, 6, 11] : vector<16xi16>, vector<16xi16> + // CHECK: %[[A1:.*]] = arith.andi %[[V1]], %[[MASK1]] : vector<3xi16> + // CHECK: %[[SHL1:.*]] = arith.shli %[[A1]], %[[SHL1_CST]] : vector<3xi16> + // CHECK: %[[O1:.*]] = arith.ori %[[SHR0]], %[[SHL1]] : vector<3xi16> + // CHECK: %[[V2:.*]] = vector.shuffle %[[A]], %[[A]] [2, 7, 12] : vector<16xi16>, vector<16xi16> + // CHECK: %[[A2:.*]] = arith.andi %[[V2]], %[[MASK1]] : vector<3xi16> + // CHECK: %[[SHL2:.*]] = arith.shli %[[A2]], %[[SHL2_CST]] : vector<3xi16> + // CHECK: %[[O2:.*]] = arith.ori %[[O1]], %[[SHL2]] : vector<3xi16> + // CHECK: %[[V3:.*]] = vector.shuffle %[[A]], %[[A]] [3, 8, 13] : vector<16xi16>, vector<16xi16> + // CHECK: %[[A3:.*]] = arith.andi %[[V3]], %[[MASK1]] : vector<3xi16> + // CHECK: %[[SHL3:.*]] = arith.shli %[[A3]], %[[SHL3_CST]] : vector<3xi16> + // CHECK: %[[O3:.*]] = arith.ori %[[O2]], %[[SHL3]] : vector<3xi16> + // CHECK: %[[V4:.*]] = vector.shuffle %[[A]], %[[A]] [4, 9, 14] : vector<16xi16>, vector<16xi16> + // CHECK: %[[A4:.*]] = arith.andi %[[V4]], %[[MASK1]] : vector<3xi16> + // CHECK: %[[SHL4:.*]] = arith.shli %[[A4]], %[[SHL4_CST]] : vector<3xi16> + // CHECK: %[[O4:.*]] = arith.ori %[[O3]], %[[SHL4]] : vector<3xi16> + // CHECK: %[[V5:.*]] = vector.shuffle %[[A]], %[[A]] [5, 10, 15] : vector<16xi16>, vector<16xi16> + // CHECK: %[[A5:.*]] = arith.andi %[[V5]], %[[MASK2]] : vector<3xi16> + // CHECK: %[[SHL5:.*]] = arith.shli %[[A5]], %[[SHL5_CST]] : vector<3xi16> + // CHECK: %[[O5:.*]] = arith.ori %[[O4]], %[[SHL5]] : vector<3xi16> + /// No trunci needed as the result is already in i16. + // CHECK-NOT: arith.trunci + // CHECK-NOT: bitcast + %0 = arith.trunci %a : vector<16xi16> to vector<16xi3> + %1 = vector.bitcast %0 : vector<16xi3> to vector<3xi16> + return %1 : vector<3xi16> +} + +/// This pattern requires an extui 16 -> 32 and not a trunci. +// CHECK-LABEL: func.func @f3( +func.func @f3(%a: vector<16xi16>) -> vector<2xi32> { + /// Rewriting this standalone pattern is about 25x faster on skylake-ax512 according to llvm-mca. + /// Benefit further increases when mixed with other compute ops. + /// + // CHECK-NOT: arith.trunci + // CHECK-NOT: bitcast + // CHECK: arith.extui + %0 = arith.trunci %a : vector<16xi16> to vector<16xi4> + %1 = vector.bitcast %0 : vector<16xi4> to vector<2xi32> + return %1 : vector<2xi32> +} + +/// This pattern is not rewritten as the result i6 is not a multiple of i8. 
+// CHECK-LABEL: func.func @f4( +func.func @f4(%a: vector<16xi16>) -> vector<8xi6> { + // CHECK: trunci + // CHECK: bitcast + // CHECK-NOT: shuffle + // CHECK-NOT: andi + // CHECK-NOT: ori + %0 = arith.trunci %a : vector<16xi16> to vector<16xi3> + %1 = vector.bitcast %0 : vector<16xi3> to vector<8xi6> + return %1 : vector<8xi6> +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !transform.any_op): + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + + transform.apply_patterns to %f { + transform.apply_patterns.vector.rewrite_narrow_types + } : !transform.any_op +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir new file mode 100644 index 0000000..44c6087 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir @@ -0,0 +1,155 @@ +/// Run once without applying the pattern and check the source of truth. +// RUN: mlir-opt %s --test-transform-dialect-erase-schedule -test-lower-to-llvm | \ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_c_runner_utils | \ +// RUN: FileCheck %s + +/// Run once with the pattern and compare. +// RUN: mlir-opt %s -test-transform-dialect-interpreter -test-transform-dialect-erase-schedule -test-lower-to-llvm | \ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_c_runner_utils | \ +// RUN: FileCheck %s + +func.func @print_as_i1_16xi5(%v : vector<16xi5>) { + %bitsi16 = vector.bitcast %v : vector<16xi5> to vector<80xi1> + vector.print %bitsi16 : vector<80xi1> + return +} + +func.func @print_as_i1_10xi8(%v : vector<10xi8>) { + %bitsi16 = vector.bitcast %v : vector<10xi8> to vector<80xi1> + vector.print %bitsi16 : vector<80xi1> + return +} + +func.func @f(%v: vector<16xi16>) { + %trunc = arith.trunci %v : vector<16xi16> to vector<16xi5> + func.call @print_as_i1_16xi5(%trunc) : (vector<16xi5>) -> () + // CHECK: ( + // CHECK-SAME: 1, 1, 1, 1, 1, + // CHECK-SAME: 0, 1, 1, 1, 1, + // CHECK-SAME: 1, 0, 1, 1, 1, + // CHECK-SAME: 0, 0, 1, 1, 1, + // CHECK-SAME: 1, 1, 0, 1, 1, + // CHECK-SAME: 0, 1, 0, 1, 1, + // CHECK-SAME: 1, 0, 0, 1, 1, + // CHECK-SAME: 0, 0, 0, 1, 1, + // CHECK-SAME: 1, 1, 1, 0, 1, + // CHECK-SAME: 0, 1, 1, 0, 1, + // CHECK-SAME: 1, 0, 1, 0, 1, + // CHECK-SAME: 0, 0, 1, 0, 1, + // CHECK-SAME: 1, 1, 0, 0, 1, + // CHECK-SAME: 0, 1, 0, 0, 1, + // CHECK-SAME: 1, 0, 0, 0, 1, + // CHECK-SAME: 0, 0, 0, 0, 1 ) + + %bitcast = vector.bitcast %trunc : vector<16xi5> to vector<10xi8> + func.call @print_as_i1_10xi8(%bitcast) : (vector<10xi8>) -> () + // CHECK: ( + // CHECK-SAME: 1, 1, 1, 1, 1, 0, 1, 1, + // CHECK-SAME: 1, 1, 1, 0, 1, 1, 1, 0, + // CHECK-SAME: 0, 1, 1, 1, 1, 1, 0, 1, + // CHECK-SAME: 1, 0, 1, 0, 1, 1, 1, 0, + // CHECK-SAME: 0, 1, 1, 0, 0, 0, 1, 1, + // CHECK-SAME: 1, 1, 1, 0, 1, 0, 1, 1, + // CHECK-SAME: 0, 1, 1, 0, 1, 0, 1, 0, + // CHECK-SAME: 0, 1, 0, 1, 1, 1, 0, 0, + // CHECK-SAME: 1, 0, 1, 0, 0, 1, 1, 0, + // CHECK-SAME: 0, 0, 1, 0, 0, 0, 0, 1 ) + + return +} + +func.func @print_as_i1_8xi3(%v : vector<8xi3>) { + %bitsi12 = vector.bitcast %v : vector<8xi3> to vector<24xi1> + vector.print %bitsi12 : vector<24xi1> + return +} + +func.func @print_as_i1_3xi8(%v : vector<3xi8>) { + %bitsi12 = vector.bitcast %v : vector<3xi8> to vector<24xi1> + vector.print %bitsi12 : vector<24xi1> + return +} + +func.func @f2(%v: vector<8xi32>) { + %trunc = arith.trunci %v : 
vector<8xi32> to vector<8xi3> + func.call @print_as_i1_8xi3(%trunc) : (vector<8xi3>) -> () + // CHECK: ( + // CHECK-SAME: 1, 1, 1, + // CHECK-SAME: 0, 1, 1, + // CHECK-SAME: 1, 0, 1, + // CHECK-SAME: 0, 0, 1, + // CHECK-SAME: 1, 1, 0, + // CHECK-SAME: 0, 1, 0, + // CHECK-SAME: 1, 0, 0, + // CHECK-SAME: 0, 0, 0 ) + + %bitcast = vector.bitcast %trunc : vector<8xi3> to vector<3xi8> + func.call @print_as_i1_3xi8(%bitcast) : (vector<3xi8>) -> () + // CHECK: ( + // CHECK-SAME: 1, 1, 1, 0, 1, 1, 1, 0, + // CHECK-SAME: 1, 0, 0, 1, 1, 1, 0, 0, + // CHECK-SAME: 1, 0, 1, 0, 0, 0, 0, 0 ) + + return +} + +func.func @print_as_i1_2xi24(%v : vector<2xi24>) { + %bitsi48 = vector.bitcast %v : vector<2xi24> to vector<48xi1> + vector.print %bitsi48 : vector<48xi1> + return +} + +func.func @print_as_i1_3xi16(%v : vector<3xi16>) { + %bitsi48 = vector.bitcast %v : vector<3xi16> to vector<48xi1> + vector.print %bitsi48 : vector<48xi1> + return +} + +func.func @f3(%v: vector<2xi48>) { + %trunc = arith.trunci %v : vector<2xi48> to vector<2xi24> + func.call @print_as_i1_2xi24(%trunc) : (vector<2xi24>) -> () + // CHECK: ( + // CHECK-SAME: 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + // CHECK-SAME: 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0 ) + + %bitcast = vector.bitcast %trunc : vector<2xi24> to vector<3xi16> + func.call @print_as_i1_3xi16(%bitcast) : (vector<3xi16>) -> () + // CHECK: ( + // CHECK-SAME: 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + // CHECK-SAME: 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, + // CHECK-SAME: 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0 ) + + return +} + +func.func @entry() { + %v = arith.constant dense<[ + 0xffff, 0xfffe, 0xfffd, 0xfffc, 0xfffb, 0xfffa, 0xfff9, 0xfff8, + 0xfff7, 0xfff6, 0xfff5, 0xfff4, 0xfff3, 0xfff2, 0xfff1, 0xfff0 + ]> : vector<16xi16> + func.call @f(%v) : (vector<16xi16>) -> () + + %v2 = arith.constant dense<[ + 0xffff, 0xfffe, 0xfffd, 0xfffc, 0xfffb, 0xfffa, 0xfff9, 0xfff8 + ]> : vector<8xi32> + func.call @f2(%v2) : (vector<8xi32>) -> () + + %v3 = arith.constant dense<[ + 0xf345aeffffff, 0xffff015f345a + ]> : vector<2xi48> + func.call @f3(%v3) : (vector<2xi48>) -> () + + return +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !transform.any_op): + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + + transform.apply_patterns to %f { + transform.apply_patterns.vector.rewrite_narrow_types + } : !transform.any_op +} |
