From b7360fbe8ca0c9411e89fafd654856c484f84f5e Mon Sep 17 00:00:00 2001
From: erman-gurses <99776114+erman-gurses@users.noreply.github.com>
Date: Fri, 19 Jan 2024 18:44:45 -0500
Subject: [mlir][amdgpu] Shared memory access optimization pass (#75627)

It implements a transformation to optimize accesses to shared memory in the
AMDGPU dialect.

Reference: https://reviews.llvm.org/D127457 (the NVGPU revision whose
description is quoted below).

_This change adds a transformation and pass to the NvGPU dialect that
attempts to optimize reads/writes from a memref representing GPU shared
memory in order to avoid bank conflicts. Given a value representing a
shared memory memref, it traverses all reads/writes within the parent op
and, subject to suitable conditions, rewrites all last dimension index
values such that element locations in the final (col) dimension are given
by `newColIdx = col % vecSize + perm[row](col / vecSize, row)`, where
`perm` is a permutation function indexed by `row` and `vecSize` is the
vector access size in elements (currently assumes 128-bit vectorized
accesses, but this can be made a parameter). This specific transformation
can help optimize typical distributed & vectorized accesses common to
loading matrix multiplication operands to/from shared memory._
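For intuition, the quoted formula can be sketched in plain C++ (illustration
only, not part of this patch; the XOR-based permutation is just one concrete
choice for `perm[row]`, and `vecSize`/`numVecsPerRow` are assumed to be powers
of two):

    #include <cstdint>

    // Sketch of newColIdx = col % vecSize + perm[row](col / vecSize, row),
    // with perm[row] taken to be an XOR of the vector index with the row
    // index (reduced modulo the number of vectors per row), scaled back to
    // element units. The pass derives its actual permutation from the memref
    // shape and element width.
    int64_t newColIdx(int64_t row, int64_t col, int64_t vecSize,
                      int64_t numVecsPerRow) {
      int64_t vecIdx = col / vecSize;                   // vector within the row
      int64_t elemOfs = col % vecSize;                  // element inside the vector
      int64_t permIdx = vecIdx ^ (row % numVecsPerRow); // perm[row](vecIdx, row)
      return permIdx * vecSize + elemOfs;
    }

With this choice, threads that access the same column of consecutive rows are
spread across different banks instead of all hitting the same one.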
---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td       |  17 ++
 .../mlir/Dialect/AMDGPU/Transforms/Passes.h         |   3 +-
 .../mlir/Dialect/AMDGPU/Transforms/Passes.td        |  13 ++
 .../mlir/Dialect/AMDGPU/Transforms/Transforms.h     |  54 +++++
 .../include/mlir/Dialect/AMDGPU/Transforms/Utils.h  |  24 ++
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp        |  15 ++
 mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt   |   2 +
 .../AMDGPU/Transforms/OptimizeSharedMemory.cpp      | 243 +++++++++++++++++++++
 mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp        |  39 ++++
 .../AMDGPU/optimize_shmem_reads_writes.mlir         |  57 +++++
 10 files changed, 466 insertions(+), 1 deletion(-)
 create mode 100644 mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
 create mode 100644 mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h
 create mode 100644 mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
 create mode 100644 mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
 create mode 100644 mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir
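Note (illustration only, not part of the applied diff): for the
memref<256x32xf16, 3> buffers used in the new test below, the arith sequence
emitted by permuteVectorOffset collapses to a constant mask, a shift, and an
XOR. A scalar C++ rendering of that specific case, with constants derived
from the code in OptimizeSharedMemory.cpp, is:

    #include <cstdint>

    // For memref<256x32xf16, 3> with 64-bit vector accesses:
    //   n    = log2(kDefaultVectorSizeBits / 16) = log2(64 / 16) = 2
    //   m    = log2(dimSize(1))                  = log2(32)      = 5
    //   mask = (1 << (m - n)) - 1                = 7
    //   permuteEveryN = max(1, 64 / ((32 * 16) / 8)) = 1
    // These match the test's `arith.constant 7`, `arith.shli ..., %c2`, and
    // `arith.xori` CHECK lines.
    int64_t permuteColIndex(int64_t row, int64_t col) {
      int64_t srcBits = row & 7;       // arith.andi %row, %c7
      int64_t xorBits = srcBits << 2;  // arith.shli %srcBits, %c2
      return col ^ xorBits;            // arith.xori %col, %xorBits
    }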
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index ffb302f..b4bf1b5 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -29,6 +29,23 @@ def AMDGPU_Dialect : Dialect {
     "gpu::GPUDialect"
   ];
   let useDefaultAttributePrinterParser = 1;
+
+  let extraClassDeclaration = [{
+    /// Return true if the given MemRefType has an integer address
+    /// space that matches the ROCDL shared memory address space or
+    /// is a gpu::AddressSpaceAttr attribute with value `workgroup`.
+    static bool hasSharedMemoryAddressSpace(MemRefType type);
+
+    /// Return true if the given Attribute has an integer address
+    /// space that matches the ROCDL shared memory address space or
+    /// is a gpu::AddressSpaceAttr attribute with value `workgroup`.
+    static bool isSharedMemoryAddressSpace(Attribute type);
+
+    /// Defines the MemRef memory space attribute numeric value that indicates
+    /// a memref is located in shared memory. This should correspond to the
+    /// value used in ROCDL.
+    static constexpr unsigned kSharedMemoryAddressSpace = 3;
+  }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
index 8dd5ff1..11d182b 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
@@ -20,7 +20,8 @@ namespace mlir {
 class ConversionTarget;
 namespace amdgpu {
 
-#define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
+#define GEN_PASS_DECL
+
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index e6b27aa..c8059e6d 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -30,4 +30,17 @@ def AmdgpuEmulateAtomicsPass : Pass<"amdgpu-emulate-atomics"> {
                          "Chipset that these operations will run on">];
 }
 
+def OptimizeSharedMemory : Pass<"amdgpu-optimize-shared-memory"> {
+  let summary = "Optimizes accesses to shared memory memrefs in order to reduce bank conflicts.";
+  let description = [{
+    This pass adds a transformation to the AMDGPU dialect that attempts to
+    optimize reads/writes from a memref representing GPU shared memory in
+    order to avoid bank conflicts.
+  }];
+
+  let dependentDialects = [
+    "memref::MemRefDialect", "vector::VectorDialect"
+  ];
+}
+
 #endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
new file mode 100644
index 0000000..140bc12
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
@@ -0,0 +1,54 @@
+//===- Transforms.h - AMDGPU Dialect transformations --------------*-
+// C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares functions that assist transformations for the amdgpu
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_
+#define MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_
+
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+class RewriterBase;
+
+namespace amdgpu {
+
+///
+/// Passes
+///
+
+/// Optimizes vectorized accesses to a shared memory buffer specified by
+/// memrefValue. This transformation assumes the following:
+/// 1) All relevant accesses to `memrefValue` are contained within `parentOp`.
+/// 2) The function will fail precondition checks if any subviews are
+///    taken of `memrefValue`. All reads/writes to `memrefValue` should occur
+///    through `memrefValue` directly.
+///
+/// Shared memory bank conflicts occur when multiple threads attempt to read or
+/// write locations assigned to the same shared memory bank. For `2^N` byte
+/// vectorized accesses, we need to be concerned with conflicts among threads
+/// identified as `(tid) -> tid.floordiv(2^{7-N})`. As such, this transformation
+/// changes any indexed memory access (vector.load, memref.load, etc.)
+/// such that the final dimension's index value is permuted to
+/// `newColIndex = oldColIndex % vectorSize +
+/// perm[rowIndex](oldColIndex/vectorSize, rowIndex)`, where `rowIndex` is the
+/// index for the second-to-last dimension and `perm[rowIndex]` is a permutation
+/// function that depends on the row index. The permutation function is chosen
+/// to ensure that sequential distributed+vectorized reads/writes down a single
+/// dimension of the memref have minimal conflicts.
+mlir::LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
+                                                       Value memrefValue);
+
+} // namespace amdgpu
+} // namespace mlir
+
+#endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h
new file mode 100644
index 0000000..6be57ca
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h
@@ -0,0 +1,24 @@
+//===- Utils.h - Transform utilities ----------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Operation.h"
+
+namespace mlir {
+namespace amdgpu {
+
+/// Get and set the indices that the given load/store operation is operating
+/// on. Preconditions:
+///  - The op must have memory effects.
+///  - Considers memref::LoadOp, vector::LoadOp, vector::TransferReadOp.
+///  - Considers memref::StoreOp, vector::StoreOp, vector::TransferWriteOp.
+///  - Excludes subview ops.
+std::optional<Operation::operand_range> getIndices(Operation *op);
+void setIndices(Operation *op, ArrayRef<Value> indices);
+
+} // namespace amdgpu
+} // namespace mlir
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 2575ad4..4e72fbf 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -43,6 +43,21 @@ void AMDGPUDialect::initialize() {
       >();
 }
 
+bool amdgpu::AMDGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
+  if (!memorySpace)
+    return false;
+  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
+    return intAttr.getInt() == AMDGPUDialect::kSharedMemoryAddressSpace;
+  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+    return gpuAttr.getValue() == gpu::AddressSpace::Workgroup;
+  return false;
+}
+
+bool amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
+  Attribute memorySpace = type.getMemorySpace();
+  return isSharedMemoryAddressSpace(memorySpace);
+}
+
 //===----------------------------------------------------------------------===//
 // 8-bit float ops
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
index e11b6cc..a1a9127 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -1,5 +1,7 @@
 add_mlir_dialect_library(MLIRAMDGPUTransforms
   EmulateAtomics.cpp
+  OptimizeSharedMemory.cpp
+  Utils.cpp
 
   ADDITIONAL_HEADER_DIRS
   {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
new file mode 100644
index 0000000..c7001fc
--- /dev/null
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -0,0 +1,243 @@
+//===- OptimizeSharedMemory.cpp - MLIR AMDGPU pass implementation --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements transforms to optimize accesses to shared memory.
+// It is inspired by
+// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/AMDGPU/Transforms/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/MathExtras.h"
+
+namespace mlir {
+namespace amdgpu {
+#define GEN_PASS_DEF_OPTIMIZESHAREDMEMORY
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
+} // namespace amdgpu
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::amdgpu;
+
+/// The size of a shared memory line according to AMD documentation.
+/// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/instinct-mi200-cdna2-instruction-set-architecture.pdf
+constexpr int64_t kSharedMemoryLineSizeBytes = 64;
+/// We optimize for 64-bit accesses, but this can be made an argument in the
+/// future.
+constexpr int64_t kDefaultVectorSizeBits = 64;
+
+/// Uses `srcIndexValue` to permute `tgtIndexValue` via
+///   `result = xor(floordiv(srcIdxVal, permuteEveryN),
+///                 floordiv(tgtIdxVal, vectorSize))
+///             + tgtIdxVal % vectorSize`
+/// This is done using an optimized sequence of `arith` operations.
+static Value permuteVectorOffset(OpBuilder &b, Location loc,
+                                 ArrayRef<Value> indices, MemRefType memrefTy,
+                                 int64_t srcDim, int64_t tgtDim) {
+  // Adjust the src index to change how often the permutation changes
+  // if necessary.
+  Value src = indices[srcDim];
+
+  // We only want to permute every N iterations of the target dim where N is
+  // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
+  const int64_t permuteEveryN = std::max<int64_t>(
+      1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
+                                        memrefTy.getElementTypeBitWidth()) /
+                                       8));
+
+  // clang-format off
+  // Index bit representation (b0 = least significant bit) for dim(1)
+  // of a `memref<?x?xDT>` is as follows:
+  //   N := log2(128/elementSizeBits)
+  //   M := log2(dimSize(1))
+  // then
+  //   bits[0:N] = sub-vector element offset
+  //   bits[N:M] = vector index
+  // clang-format on
+  int64_t n =
+      llvm::Log2_64(kDefaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
+  int64_t m = llvm::Log2_64(memrefTy.getDimSize(tgtDim));
+
+  // Capture bits[0:(M-N)] of src by first creating a (M-N) mask.
+  int64_t mask = (1LL << (m - n)) - 1;
+  if (permuteEveryN > 1)
+    mask = mask << llvm::Log2_64(permuteEveryN);
+  Value srcBits = b.create<arith::ConstantIndexOp>(loc, mask);
+  srcBits = b.create<arith::AndIOp>(loc, src, srcBits);
+
+  // Use the src bits to permute the target bits b[N:M] containing the
+  // vector offset.
+  if (permuteEveryN > 1) {
+    int64_t shlBits = n - llvm::Log2_64(permuteEveryN);
+    if (shlBits > 0) {
+      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, shlBits);
+      srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
+    } else if (shlBits < 0) {
+      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, -1 * shlBits);
+      srcBits = b.createOrFold<arith::ShRUIOp>(loc, srcBits, finalShiftVal);
+    }
+  } else {
+    Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, n);
+    srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
+  }
+
+  Value permutedVectorIdx =
+      b.create<arith::XOrIOp>(loc, indices[tgtDim], srcBits);
+  return permutedVectorIdx;
+}
+
+static void transformIndices(OpBuilder &builder, Location loc,
+                             SmallVector<Value, 4> &indices,
+                             MemRefType memrefTy, int64_t srcDim,
+                             int64_t tgtDim) {
+  indices[tgtDim] =
+      permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
+}
+
+/// Return all operations within `parentOp` that read from or write to
+/// `shmMemRef`.
+static LogicalResult
+getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
+                      SmallVector<Operation *, 16> &readOps,
+                      SmallVector<Operation *, 16> &writeOps) {
+  parentOp->walk([&](Operation *op) {
+    MemoryEffectOpInterface iface = dyn_cast<MemoryEffectOpInterface>(op);
+    if (!iface)
+      return;
+    std::optional<MemoryEffects::EffectInstance> effect =
+        iface.getEffectOnValue<MemoryEffects::Read>(shmMemRef);
+    if (effect) {
+      readOps.push_back(op);
+      return;
+    }
+    effect = iface.getEffectOnValue<MemoryEffects::Write>(shmMemRef);
+    if (effect)
+      writeOps.push_back(op);
+  });
+
+  // Restrict to a supported set of ops. We also require at least 2D access,
+  // although this could be relaxed.
+  if (llvm::any_of(readOps, [](Operation *op) {
+        return !isa<memref::LoadOp, vector::LoadOp, vector::TransferReadOp>(
+                   op) ||
+               amdgpu::getIndices(op)->size() < 2;
+      }))
+    return failure();
+  if (llvm::any_of(writeOps, [](Operation *op) {
+        return !isa<memref::StoreOp, vector::StoreOp, vector::TransferWriteOp>(
+                   op) ||
+               amdgpu::getIndices(op)->size() < 2;
+      }))
+    return failure();
+
+  return success();
+}
+
+mlir::LogicalResult
+mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
+                                                 Value memrefValue) {
+  auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
+  if (!memRefType ||
+      !amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(memRefType))
+    return failure();
+
+  // Abort if the given value has any sub-views; we do not do any alias
+  // analysis.
+  bool hasSubView = false;
+  parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
+  if (hasSubView)
+    return failure();
+
+  // Check if this is necessary given the assumption of 128b accesses:
+  // If dim[rank-1] is small enough to fit 8 rows in a 128B line.
+  const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
+  const int64_t rowsPerLine =
+      (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
+      rowSize;
+  const int64_t threadGroupSize =
+      1LL << (7 - llvm::Log2_64(kDefaultVectorSizeBits / 8));
+  if (rowsPerLine >= threadGroupSize)
+    return failure();
+
+  // Get sets of operations within the function that read/write to shared
+  // memory.
+  SmallVector<Operation *, 16> shmReadOps;
+  SmallVector<Operation *, 16> shmWriteOps;
+  if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
+                                   shmWriteOps)))
+    return failure();
+
+  if (shmReadOps.empty() || shmWriteOps.empty())
+    return failure();
+
+  OpBuilder builder(parentOp->getContext());
+
+  int64_t tgtDim = memRefType.getRank() - 1;
+  int64_t srcDim = memRefType.getRank() - 2;
+
+  // Transform indices for the ops writing to shared memory.
+  while (!shmWriteOps.empty()) {
+    Operation *shmWriteOp = shmWriteOps.pop_back_val();
+    builder.setInsertionPoint(shmWriteOp);
+
+    auto indices = amdgpu::getIndices(shmWriteOp);
+    SmallVector<Value, 4> transformedIndices(indices->begin(), indices->end());
+    transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
+                     memRefType, srcDim, tgtDim);
+    amdgpu::setIndices(shmWriteOp, transformedIndices);
+  }
+
+  // Transform indices for the ops reading from shared memory.
+  while (!shmReadOps.empty()) {
+    Operation *shmReadOp = shmReadOps.pop_back_val();
+    builder.setInsertionPoint(shmReadOp);
+
+    auto indices = amdgpu::getIndices(shmReadOp);
+    SmallVector<Value, 4> transformedIndices(indices->begin(), indices->end());
+    transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
+                     memRefType, srcDim, tgtDim);
+    amdgpu::setIndices(shmReadOp, transformedIndices);
+  }
+
+  return success();
+}
+
+struct OptimizeSharedMemoryPass
+    : public amdgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
+public:
+  OptimizeSharedMemoryPass() = default;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    SmallVector<memref::AllocOp> shmAllocOps;
+    op->walk([&](memref::AllocOp allocOp) {
+      if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(
+              allocOp.getType()))
+        return;
+      shmAllocOps.push_back(allocOp);
+    });
+    for (auto allocOp : shmAllocOps) {
+      if (failed(optimizeSharedMemoryReadsAndWrites(getOperation(),
+                                                    allocOp.getMemref())))
+        return;
+    }
+  }
+};
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
new file mode 100644
index 0000000..8163eea
--- /dev/null
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
@@ -0,0 +1,39 @@
+#include "mlir/Dialect/AMDGPU/Transforms/Utils.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+using namespace mlir;
+using namespace mlir::amdgpu;
+
+std::optional<Operation::operand_range> amdgpu::getIndices(Operation *op) {
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return loadOp.getIndices();
+  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+    return storeOp.getIndices();
+  if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
+    return vectorReadOp.getIndices();
+  if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
+    return vectorStoreOp.getIndices();
+  if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op))
+    return transferReadOp.getIndices();
+  if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
+    return transferWriteOp.getIndices();
+  return std::nullopt;
+}
+
+void amdgpu::setIndices(Operation *op, ArrayRef<Value> indices) {
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return loadOp.getIndicesMutable().assign(indices);
+  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+    return storeOp.getIndicesMutable().assign(indices);
+  if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
+    return vectorReadOp.getIndicesMutable().assign(indices);
+  if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
+    return vectorStoreOp.getIndicesMutable().assign(indices);
+  if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op))
+    return transferReadOp.getIndicesMutable().assign(indices);
+  if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
+    return transferWriteOp.getIndicesMutable().assign(indices);
+}
diff --git a/mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir b/mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir
new file mode 100644
index 0000000..41111dd
--- /dev/null
+++ b/mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(amdgpu-optimize-shared-memory))' | FileCheck %s
+
+  // CHECK: @optimize_shmem([[arg0:%.+]]: memref<{{.*}}>, [[readRow:%.+]]: index, [[readCol:%.+]]: index, [[writeRow:%.+]]: index, [[writeCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index, [[fragColPerm:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index)
+  func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
+                    %readRow: index, %readCol: index,
+                    %writeRow: index, %writeCol: index,
+                    %fragRow: index, %fragCol: index,
+                    %fragColPerm: index,
+                    %stRow: index, %stCol: index) {
+    // CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f16
+    %cst = arith.constant 0.000000e+00 : f16
+
+    // CHECK: [[shmA:%.+]] = memref.alloc
+    // CHECK: [[shmB:%.+]] = memref.alloc
+    %shmA = memref.alloc() {alignment = 64 : i64} : memref<128x32xf16, 3>
+    %shmB = memref.alloc() {alignment = 64 : i64} : memref<256x32xf16, 3>
+
+    // CHECK: %[[D0:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
+    %0 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
+    // CHECK: [[c7:%.+]] = arith.constant 7 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
+    // CHECK: vector.transfer_write %[[D0:.+]], [[shmB]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
+    vector.transfer_write %0, %shmB[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
+    gpu.barrier
+    gpu.barrier
+    // CHECK: [[c7:%.+]] = arith.constant 7 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+    // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
+    // CHECK: vector.load [[shmB:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<256x32xf16, 3>, vector<8xf16>
+    %1 = vector.load %shmB[%fragRow, %fragColPerm] : memref<256x32xf16, 3>, vector<8xf16>
+
+    // CHECK: %[[D2:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
+    %2 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
+    // CHECK: [[c7:%.+]] = arith.constant 7 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
+    // CHECK: vector.transfer_write %[[D2:.+]], [[shmA:%.+]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
+    vector.transfer_write %2, %shmA[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
+    gpu.barrier
+    gpu.barrier
+    // CHECK: [[c7:%.+]] = arith.constant 7 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+    // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
+    // CHECK: vector.load [[shmA:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<128x32xf16, 3>, vector<8xf16>
+    %3 = vector.load %shmA[%fragRow, %fragColPerm] : memref<128x32xf16, 3>, vector<8xf16>
+    return
+  }
\ No newline at end of file
-- 
cgit v1.1
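Beyond the amdgpu-optimize-shared-memory pass registered above, the exported
optimizeSharedMemoryReadsAndWrites entry point can also be driven from custom
code. A minimal sketch (the helper name below is hypothetical; it mirrors what
OptimizeSharedMemoryPass::runOnOperation in this patch does, except that it
keeps going after a failed precondition check):

    #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
    #include "mlir/Dialect/AMDGPU/Transforms/Transforms.h"
    #include "mlir/Dialect/MemRef/IR/MemRef.h"
    #include "mlir/IR/Operation.h"

    using namespace mlir;

    // Run the transformation on every shared-memory allocation under `root`
    // (e.g. a func.func). Failure only means the preconditions (no subviews,
    // at least 2-D indexed accesses, supported shape) were not met; in that
    // case the IR is left untouched.
    static void optimizeAllSharedMemAllocs(Operation *root) {
      SmallVector<memref::AllocOp> shmAllocOps;
      root->walk([&](memref::AllocOp allocOp) {
        if (amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
          shmAllocOps.push_back(allocOp);
      });
      for (memref::AllocOp allocOp : shmAllocOps)
        (void)amdgpu::optimizeSharedMemoryReadsAndWrites(root,
                                                         allocOp.getMemref());
    }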