1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
|
//===- ShuffleRewriter.cpp - Implementation of shuffle rewriting ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect rewriting of the shuffle op for types i64 and
// f64, rewriting 64-bit shuffles into two 32-bit shuffles. This particular
// implementation, using shifts and truncations, can be obtained using clang by
// emitting IR for shuffle operations with `-O3`.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
namespace {
/// Rewrites a gpu.shuffle whose value is a 64-bit integer or float into two
/// gpu.shuffle ops on the low and high 32-bit halves, then reassembles the
/// 64-bit result and combines the two validity flags.
struct GpuShuffleRewriter : public OpRewritePattern<gpu::ShuffleOp> {
  using OpRewritePattern<gpu::ShuffleOp>::OpRewritePattern;

  void initialize() {
    // The rewrite itself emits two new ShuffleOps, i.e. ops of the kind this
    // pattern is rooted on, so the recursion must be declared as bounded.
    setHasBoundedRewriteRecursion();
  }

  /// Matches only when the shuffled value is a 64-bit int/float; on success
  /// the op is replaced by {reassembled value, lo-valid & hi-valid}.
  LogicalResult matchAndRewrite(gpu::ShuffleOp op,
                                PatternRewriter &rewriter) const override {
    Location opLoc = op.getLoc();
    Value source = op.getValue();
    Type sourceType = source.getType();
    Location srcLoc = source.getLoc();
    IntegerType i32Ty = rewriter.getI32Type();
    IntegerType i64Ty = rewriter.getI64Type();

    // Only 64-bit integer or float values are handled by this pattern; every
    // other type is reported as a match failure and left untouched.
    const bool is64BitScalar = sourceType.isIntOrFloat() &&
                               sourceType.getIntOrFloatBitWidth() == 64;
    if (!is64BitScalar)
      return rewriter.notifyMatchFailure(
          op, "only 64-bit int/float types are supported");

    const bool isFloat = isa<FloatType>(sourceType);

    // A float is first bitcast to i64 so that its bits can be split.
    Value bits = source;
    if (isFloat)
      bits = arith::BitcastOp::create(rewriter, srcLoc, i64Ty, source);

    // Low half: trunc(bits).
    Value lowHalf = arith::TruncIOp::create(rewriter, srcLoc, i32Ty, bits);

    // High half: trunc(bits >> 32). The shift constant is reused further down
    // when the high half is moved back into place.
    Value shiftBy32 = arith::ConstantOp::create(
        rewriter, srcLoc, rewriter.getIntegerAttr(i64Ty, 32));
    Value shifted = arith::ShRUIOp::create(rewriter, srcLoc, bits, shiftBy32);
    Value highHalf = arith::TruncIOp::create(rewriter, srcLoc, i32Ty, shifted);

    // Shuffle each half with the original offset, width and mode.
    auto lowShuffle =
        gpu::ShuffleOp::create(rewriter, opLoc, lowHalf, op.getOffset(),
                               op.getWidth(), op.getMode());
    auto highShuffle =
        gpu::ShuffleOp::create(rewriter, opLoc, highHalf, op.getOffset(),
                               op.getWidth(), op.getMode());

    // Zero-extend both shuffled halves back to i64.
    Value lowWide = arith::ExtUIOp::create(rewriter, srcLoc, i64Ty,
                                           lowShuffle->getResult(0));
    Value highWide = arith::ExtUIOp::create(rewriter, srcLoc, i64Ty,
                                            highShuffle->getResult(0));

    // Move the high half into the upper 32 bits and merge: (hi << 32) | lo.
    Value highPlaced =
        arith::ShLIOp::create(rewriter, srcLoc, highWide, shiftBy32);
    Value combined = arith::OrIOp::create(rewriter, opLoc, highPlaced, lowWide);

    // Undo the initial bitcast for float inputs.
    if (isFloat)
      combined =
          arith::BitcastOp::create(rewriter, srcLoc, sourceType, combined);

    // The result is valid only if both partial shuffles were valid.
    Value valid = arith::AndIOp::create(rewriter, opLoc,
                                        lowShuffle->getResult(1),
                                        highShuffle->getResult(1));

    rewriter.replaceOp(op, {combined, valid});
    return success();
  }
};
} // namespace
/// Adds the GpuShuffleRewriter pattern to `patterns`, constructed with the
/// context owned by that pattern set.
void mlir::populateGpuShufflePatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<GpuShuffleRewriter>(context);
}
|