//===- X86VectorUtils.cpp - MLIR Utilities for X86VectorOps -------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"

namespace mlir {
namespace x86vector {

// Infers the iterator types of a contraction from its output indexing map:
// dimensions that appear in the output are parallel, all remaining dimensions
// are reductions.
static FailureOr<SmallVector<mlir::utils::IteratorType>>
inferIteratorsFromOutMap(AffineMap map) {
  if (!map.isProjectedPermutation())
    return failure();
  SmallVector<mlir::utils::IteratorType> iterators(
      map.getNumDims(), mlir::utils::IteratorType::reduction);
  for (auto expr : map.getResults())
    if (auto dim = dyn_cast<AffineDimExpr>(expr))
      iterators[dim.getPosition()] = mlir::utils::IteratorType::parallel;
  return iterators;
}

// Returns true if the operation is in VNNI layout.
// Optionally, the check can be constrained to a specific VNNI blocking factor.
bool isInVnniLayout(Operation *op, ArrayRef<AffineMap> indexingMaps,
                    std::optional<unsigned> blockingFactor) {
  // Narrow down the operation type - VNNI only applies to contractions.
  FailureOr<linalg::ContractionDimensions> dims =
      linalg::inferContractionDims(indexingMaps);
  if (failed(dims))
    return false;

  auto matA = op->getOperand(0);
  auto matB = op->getOperand(1);
  auto typeA = dyn_cast<ShapedType>(matA.getType());
  auto typeB = dyn_cast<ShapedType>(matB.getType());
  if (!typeA || !typeB)
    return false;
  unsigned rankA = typeA.getRank();
  unsigned rankB = typeB.getRank();
  // VNNI format requires at least 1 parallel and 2 reduction dimensions.
  if (rankA < 3 || rankB < 3)
    return false;

  // At least two reduction dimensions are expected:
  // one for the VNNI factor and one for the K dimension.
  if (dims->k.size() < 2)
    return false;

  // Validate affine maps - the VNNI computation should be defined by the two
  // innermost reduction iterators.
  // The input matrix dimensions layout must match the following:
  //   - matrix A - [...][K/vnniFactor][vnniFactor]
  //   - matrix B - [...][K/vnniFactor][N][vnniFactor]
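  // For example (illustrative only, assuming vnniFactor = 2), a matmul in
  // VNNI layout could use the following indexing maps and operand shapes:
  //   A: affine_map<(m, n, k, v) -> (m, k, v)>   shape [M][K/2][2]
  //   B: affine_map<(m, n, k, v) -> (k, n, v)>   shape [K/2][N][2]
  //   C: affine_map<(m, n, k, v) -> (m, n)>      shape [M][N]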
  auto maybeIters = inferIteratorsFromOutMap(indexingMaps[2] /* outs */);
  if (failed(maybeIters))
    return false;
  SmallVector<mlir::utils::IteratorType> iteratorTypes = *maybeIters;

  AffineMap mapA = indexingMaps[0];
  AffineMap mapB = indexingMaps[1];
  auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 1));
  auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 1));
  if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
      iteratorTypes[vnniDimA.getPosition()] !=
          mlir::utils::IteratorType::reduction)
    return false;
  auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 2));
  auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 3));
  if (!redDimA || !redDimB || redDimA != redDimB ||
      iteratorTypes[redDimA.getPosition()] !=
          mlir::utils::IteratorType::reduction)
    return false;
  auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 2));
  if (!parallelDimB || iteratorTypes[parallelDimB.getPosition()] !=
                           mlir::utils::IteratorType::parallel)
    return false;

  // The VNNI factor must be:
  //   - the innermost dimension of both inputs
  //   - statically known
  //   - a multiple of 2 and, when provided, equal to the requested blocking
  //     factor
  auto vnniDimSize = typeB.getShape().back();
  if (vnniDimSize == ShapedType::kDynamic || vnniDimSize == 0 ||
      vnniDimSize % 2 != 0)
    return false;
  if (typeA.getShape().back() != vnniDimSize)
    return false;
  if (blockingFactor && vnniDimSize != *blockingFactor)
    return false;

  // The size of the split reduction dimension must also match between the
  // inputs.
  if (typeA.getShape().end()[-2] != typeB.getShape().end()[-3])
    return false;

  return true;
}

} // namespace x86vector
} // namespace mlir