diff options
Diffstat (limited to 'mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp')
-rw-r--r-- | mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp | 41 |
1 files changed, 37 insertions, 4 deletions
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index f44552c..a90dcc8 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -699,6 +699,35 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, return success(); } +template <typename OpAdaptor> +static FailureOr<SmallVector<Value>> +extractLoadCoordsForComposite(memref::LoadOp loadOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) { + // At present we only support linear "tiling" as specified in Vulkan, this + // means that texels are assumed to be laid out in memory in a row-major + // order. This allows us to support any memref layout that is a permutation of + // the dimensions. Future work will pass an optional image layout to the + // rewrite pattern so that we can support optimized target specific tilings. + SmallVector<Value> indices = adaptor.getIndices(); + AffineMap map = loadOp.getMemRefType().getLayout().getAffineMap(); + if (!map.isPermutation()) + return rewriter.notifyMatchFailure( + loadOp, + "Cannot lower memrefs with memory layout which is not a permutation"); + + // The memrefs layout determines the dimension ordering so we need to follow + // the map to get the ordering of the dimensions/indices. + const unsigned dimCount = map.getNumDims(); + SmallVector<Value, 3> coords(dimCount); + for (unsigned dim = 0; dim < dimCount; ++dim) + coords[map.getDimPosition(dim)] = indices[dim]; + + // We need to reverse the coordinates because the memref layout is slowest to + // fastest moving and the vector coordinates for the image op is fastest to + // slowest moving. + return llvm::to_vector(llvm::reverse(coords)); +} + LogicalResult ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -755,13 +784,17 @@ ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, // Build a vector of coordinates or just a scalar index if we have a 1D image. Value coords; - if (memrefType.getRank() != 1) { + if (memrefType.getRank() == 1) { + coords = adaptor.getIndices()[0]; + } else { + FailureOr<SmallVector<Value>> maybeCoords = + extractLoadCoordsForComposite(loadOp, adaptor, rewriter); + if (failed(maybeCoords)) + return failure(); auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()}, adaptor.getIndices().getType()[0]); coords = spirv::CompositeConstructOp::create(rewriter, loc, coordVectorType, - adaptor.getIndices()); - } else { - coords = adaptor.getIndices()[0]; + maybeCoords.value()); } // Fetch the value out of the image. |