aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp')
-rw-r--r--mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp41
1 files changed, 37 insertions, 4 deletions
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index f44552c..a90dcc8 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -699,6 +699,35 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
return success();
}
+template <typename OpAdaptor>
+static FailureOr<SmallVector<Value>>
+extractLoadCoordsForComposite(memref::LoadOp loadOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) {
+  // At present we only support linear "tiling" as specified in Vulkan; this
+ // means that texels are assumed to be laid out in memory in a row-major
+ // order. This allows us to support any memref layout that is a permutation of
+ // the dimensions. Future work will pass an optional image layout to the
+ // rewrite pattern so that we can support optimized target specific tilings.
+ SmallVector<Value> indices = adaptor.getIndices();
+ AffineMap map = loadOp.getMemRefType().getLayout().getAffineMap();
+ if (!map.isPermutation())
+ return rewriter.notifyMatchFailure(
+ loadOp,
+ "Cannot lower memrefs with memory layout which is not a permutation");
+
+  // The memref's layout determines the dimension ordering, so we need to follow
+ // the map to get the ordering of the dimensions/indices.
+ const unsigned dimCount = map.getNumDims();
+ SmallVector<Value, 3> coords(dimCount);
+ for (unsigned dim = 0; dim < dimCount; ++dim)
+ coords[map.getDimPosition(dim)] = indices[dim];
+
+ // We need to reverse the coordinates because the memref layout is slowest to
+  // fastest moving and the vector coordinates for the image op are fastest to
+ // slowest moving.
+ return llvm::to_vector(llvm::reverse(coords));
+}
+
LogicalResult
ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -755,13 +784,17 @@ ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
// Build a vector of coordinates or just a scalar index if we have a 1D image.
Value coords;
- if (memrefType.getRank() != 1) {
+ if (memrefType.getRank() == 1) {
+ coords = adaptor.getIndices()[0];
+ } else {
+ FailureOr<SmallVector<Value>> maybeCoords =
+ extractLoadCoordsForComposite(loadOp, adaptor, rewriter);
+ if (failed(maybeCoords))
+ return failure();
auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()},
adaptor.getIndices().getType()[0]);
coords = spirv::CompositeConstructOp::create(rewriter, loc, coordVectorType,
- adaptor.getIndices());
- } else {
- coords = adaptor.getIndices()[0];
+ maybeCoords.value());
}
// Fetch the value out of the image.