//===- ShuffleRewriter.cpp - Implementation of shuffle rewriting ---------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements in-dialect rewriting of the shuffle op for types i64 and // f64, rewriting 64bit shuffles into two 32bit shuffles. This particular // implementation using shifts and truncations can be obtained using clang: by // emitting IR for shuffle operations with `-O3`. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/GPU/Transforms/Passes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/PatternMatch.h" using namespace mlir; namespace { struct GpuShuffleRewriter : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; void initialize() { // Required as the pattern will replace the Op with 2 additional ShuffleOps. setHasBoundedRewriteRecursion(); } LogicalResult matchAndRewrite(gpu::ShuffleOp op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto value = op.getValue(); auto valueType = value.getType(); auto valueLoc = value.getLoc(); auto i32 = rewriter.getI32Type(); auto i64 = rewriter.getI64Type(); // If the type of the value is either i32 or f32, the op is already valid. if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 64) return rewriter.notifyMatchFailure( op, "only 64-bit int/float types are supported"); Value lo, hi; // Float types must be converted to i64 to extract the bits. if (isa(valueType)) value = arith::BitcastOp::create(rewriter, valueLoc, i64, value); // Get the low bits by trunc(value). lo = arith::TruncIOp::create(rewriter, valueLoc, i32, value); // Get the high bits by trunc(value >> 32). auto c32 = arith::ConstantOp::create(rewriter, valueLoc, rewriter.getIntegerAttr(i64, 32)); hi = arith::ShRUIOp::create(rewriter, valueLoc, value, c32); hi = arith::TruncIOp::create(rewriter, valueLoc, i32, hi); // Shuffle the values. ValueRange loRes = gpu::ShuffleOp::create(rewriter, op.getLoc(), lo, op.getOffset(), op.getWidth(), op.getMode()) .getResults(); ValueRange hiRes = gpu::ShuffleOp::create(rewriter, op.getLoc(), hi, op.getOffset(), op.getWidth(), op.getMode()) .getResults(); // Convert lo back to i64. lo = arith::ExtUIOp::create(rewriter, valueLoc, i64, loRes[0]); // Convert hi back to i64. hi = arith::ExtUIOp::create(rewriter, valueLoc, i64, hiRes[0]); hi = arith::ShLIOp::create(rewriter, valueLoc, hi, c32); // Obtain the shuffled bits hi | lo. value = arith::OrIOp::create(rewriter, loc, hi, lo); // Convert the value back to float. if (isa(valueType)) value = arith::BitcastOp::create(rewriter, valueLoc, valueType, value); // Obtain the shuffle validity by combining both validities. auto validity = arith::AndIOp::create(rewriter, loc, loRes[1], hiRes[1]); // Replace the op. rewriter.replaceOp(op, {value, validity}); return success(); } }; } // namespace void mlir::populateGpuShufflePatterns(RewritePatternSet &patterns) { patterns.add(patterns.getContext()); }