diff options
Diffstat (limited to 'mlir/lib/Conversion')
-rw-r--r-- | mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp | 28 |
1 files changed, 16 insertions, 12 deletions
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index c7ecd83..2e00b42 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -22,6 +22,7 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Visitors.h" #include <cassert> +#include <limits> #include <optional> #define DEBUG_TYPE "memref-to-spirv-pattern" @@ -475,7 +476,12 @@ struct MemoryRequirements { /// Given an accessed SPIR-V pointer, calculates its alignment requirements, if /// any. static FailureOr<MemoryRequirements> -calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) { +calculateMemoryRequirements(Value accessedPtr, bool isNontemporal, + uint64_t preferredAlignment) { + if (preferredAlignment >= std::numeric_limits<uint32_t>::max()) { + return failure(); + } + MLIRContext *ctx = accessedPtr.getContext(); auto memoryAccess = spirv::MemoryAccess::None; @@ -484,7 +490,10 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) { } auto ptrType = cast<spirv::PointerType>(accessedPtr.getType()); - if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) { + bool mayOmitAlignment = + !preferredAlignment && + ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer; + if (mayOmitAlignment) { if (memoryAccess == spirv::MemoryAccess::None) { return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}}; } @@ -493,6 +502,7 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) { } // PhysicalStorageBuffers require the `Aligned` attribute. + // Other storage types may show an `Aligned` attribute. auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType()); if (!pointeeType) return failure(); @@ -504,7 +514,8 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) { memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned; auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess); - auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes); + auto alignmentValue = preferredAlignment ? preferredAlignment : *sizeInBytes; + auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), alignmentValue); return MemoryRequirements{memAccessAttr, alignment}; } @@ -518,16 +529,9 @@ calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) { llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value, "Must be called on either memref::LoadOp or memref::StoreOp"); - Operation *memrefAccessOp = loadOrStoreOp.getOperation(); - auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>( - spirv::attributeName<spirv::MemoryAccess>()); - auto memrefAlignment = - memrefAccessOp->getAttrOfType<IntegerAttr>("alignment"); - if (memrefMemAccess && memrefAlignment) - return MemoryRequirements{memrefMemAccess, memrefAlignment}; - return calculateMemoryRequirements(accessedPtr, - loadOrStoreOp.getNontemporal()); + loadOrStoreOp.getNontemporal(), + loadOrStoreOp.getAlignment().value_or(0)); } LogicalResult |