//===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements patterns to convert Vector dialect to SPIRV dialect. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Location.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/Support/FormatVariadic.h" #include #include #include using namespace mlir; /// Returns the integer value from the first valid input element, assuming Value /// inputs are defined by a constant index ops and Attribute inputs are integer /// attributes. static uint64_t getFirstIntValue(ArrayAttr attr) { return (*attr.getAsValueRange().begin()).getZExtValue(); } /// Returns the number of bits for the given scalar/vector type. static int getNumBits(Type type) { // TODO: This does not take into account any memory layout or widening // constraints. E.g., a vector<3xi57> may report to occupy 3x57=171 bit, even // though in practice it will likely be stored as in a 4xi64 vector register. if (auto vectorType = dyn_cast(type)) return vectorType.getNumElements() * vectorType.getElementTypeBitWidth(); return type.getIntOrFloatBitWidth(); } namespace { struct VectorShapeCast final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type dstType = getTypeConverter()->convertType(shapeCastOp.getType()); if (!dstType) return failure(); // If dstType is same as the source type or the vector size is 1, it can be // directly replaced by the source. if (dstType == adaptor.getSource().getType() || shapeCastOp.getResultVectorType().getNumElements() == 1) { rewriter.replaceOp(shapeCastOp, adaptor.getSource()); return success(); } // Lowering for size-n vectors when n > 1 hasn't been implemented. return failure(); } }; // Convert `vector.splat` to `vector.broadcast`. There is a path from // `vector.broadcast` to SPIRV via other patterns. struct VectorSplatToBroadcast final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(splat, splat.getType(), adaptor.getInput()); return success(); } }; struct VectorBitcastConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type dstType = getTypeConverter()->convertType(bitcastOp.getType()); if (!dstType) return failure(); if (dstType == adaptor.getSource().getType()) { rewriter.replaceOp(bitcastOp, adaptor.getSource()); return success(); } // Check that the source and destination type have the same bitwidth. // Depending on the target environment, we may need to emulate certain // types, which can cause issue with bitcast. Type srcType = adaptor.getSource().getType(); if (getNumBits(dstType) != getNumBits(srcType)) { return rewriter.notifyMatchFailure( bitcastOp, llvm::formatv("different source ({0}) and target ({1}) bitwidth", srcType, dstType)); } rewriter.replaceOpWithNewOp(bitcastOp, dstType, adaptor.getSource()); return success(); } }; struct VectorBroadcastConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type resultType = getTypeConverter()->convertType(castOp.getResultVectorType()); if (!resultType) return failure(); if (isa(resultType)) { rewriter.replaceOp(castOp, adaptor.getSource()); return success(); } SmallVector source(castOp.getResultVectorType().getNumElements(), adaptor.getSource()); rewriter.replaceOpWithNewOp(castOp, resultType, source); return success(); } }; // SPIR-V does not have a concept of a poison index for certain instructions, // which creates a UB hazard when lowering from otherwise equivalent Vector // dialect instructions, because this index will be considered out-of-bounds. // To avoid this, this function implements a dynamic sanitization that returns // some arbitrary safe index. For power-of-two vector sizes, this uses a bitmask // (presumably more efficient), and otherwise index 0 (always in-bounds). static Value sanitizeDynamicIndex(ConversionPatternRewriter &rewriter, Location loc, Value dynamicIndex, int64_t kPoisonIndex, unsigned vectorSize) { if (llvm::isPowerOf2_32(vectorSize)) { Value inBoundsMask = spirv::ConstantOp::create( rewriter, loc, dynamicIndex.getType(), rewriter.getIntegerAttr(dynamicIndex.getType(), vectorSize - 1)); return spirv::BitwiseAndOp::create(rewriter, loc, dynamicIndex, inBoundsMask); } Value poisonIndex = spirv::ConstantOp::create( rewriter, loc, dynamicIndex.getType(), rewriter.getIntegerAttr(dynamicIndex.getType(), kPoisonIndex)); Value cmpResult = spirv::IEqualOp::create(rewriter, loc, dynamicIndex, poisonIndex); return spirv::SelectOp::create( rewriter, loc, cmpResult, spirv::ConstantOp::getZero(dynamicIndex.getType(), loc, rewriter), dynamicIndex); } struct VectorExtractOpConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type dstType = getTypeConverter()->convertType(extractOp.getType()); if (!dstType) return failure(); if (isa(adaptor.getVector().getType())) { rewriter.replaceOp(extractOp, adaptor.getVector()); return success(); } if (std::optional id = getConstantIntValue(extractOp.getMixedPosition()[0])) { if (id == vector::ExtractOp::kPoisonIndex) return rewriter.notifyMatchFailure( extractOp, "Static use of poison index handled elsewhere (folded to poison)"); rewriter.replaceOpWithNewOp( extractOp, dstType, adaptor.getVector(), rewriter.getI32ArrayAttr(id.value())); } else { Value sanitizedIndex = sanitizeDynamicIndex( rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0], vector::ExtractOp::kPoisonIndex, extractOp.getSourceVectorType().getNumElements()); rewriter.replaceOpWithNewOp( extractOp, dstType, adaptor.getVector(), sanitizedIndex); } return success(); } }; struct VectorExtractStridedSliceOpConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type dstType = getTypeConverter()->convertType(extractOp.getType()); if (!dstType) return failure(); uint64_t offset = getFirstIntValue(extractOp.getOffsets()); uint64_t size = getFirstIntValue(extractOp.getSizes()); uint64_t stride = getFirstIntValue(extractOp.getStrides()); if (stride != 1) return failure(); Value srcVector = adaptor.getOperands().front(); // Extract vector<1xT> case. if (isa(dstType)) { rewriter.replaceOpWithNewOp(extractOp, srcVector, offset); return success(); } SmallVector indices(size); std::iota(indices.begin(), indices.end(), offset); rewriter.replaceOpWithNewOp( extractOp, dstType, srcVector, srcVector, rewriter.getI32ArrayAttr(indices)); return success(); } }; template struct VectorFmaOpConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type dstType = getTypeConverter()->convertType(fmaOp.getType()); if (!dstType) return failure(); rewriter.replaceOpWithNewOp(fmaOp, dstType, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc()); return success(); } }; struct VectorFromElementsOpConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::FromElementsOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type resultType = getTypeConverter()->convertType(op.getType()); if (!resultType) return failure(); OperandRange elements = op.getElements(); if (isa(resultType)) { // In the case with a single scalar operand / single-element result, // pass through the scalar. rewriter.replaceOp(op, elements[0]); return success(); } // SPIRVTypeConverter rejects vectors with rank > 1, so multi-dimensional // vector.from_elements cases should not need to be handled, only 1d. assert(cast(resultType).getRank() == 1); rewriter.replaceOpWithNewOp(op, resultType, elements); return success(); } }; struct VectorInsertOpConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (isa(insertOp.getValueToStoreType())) return rewriter.notifyMatchFailure(insertOp, "unsupported vector source"); if (!getTypeConverter()->convertType(insertOp.getDestVectorType())) return rewriter.notifyMatchFailure(insertOp, "unsupported dest vector type"); // Special case for inserting scalar values into size-1 vectors. if (insertOp.getValueToStoreType().isIntOrFloat() && insertOp.getDestVectorType().getNumElements() == 1) { rewriter.replaceOp(insertOp, adaptor.getValueToStore()); return success(); } if (std::optional id = getConstantIntValue(insertOp.getMixedPosition()[0])) { if (id == vector::InsertOp::kPoisonIndex) return rewriter.notifyMatchFailure( insertOp, "Static use of poison index handled elsewhere (folded to poison)"); rewriter.replaceOpWithNewOp( insertOp, adaptor.getValueToStore(), adaptor.getDest(), id.value()); } else { Value sanitizedIndex = sanitizeDynamicIndex( rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0], vector::InsertOp::kPoisonIndex, insertOp.getDestVectorType().getNumElements()); rewriter.replaceOpWithNewOp( insertOp, insertOp.getDest(), adaptor.getValueToStore(), sanitizedIndex); } return success(); } }; struct VectorExtractElementOpConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type resultType = getTypeConverter()->convertType(extractOp.getType()); if (!resultType) return failure(); if (isa(adaptor.getVector().getType())) { rewriter.replaceOp(extractOp, adaptor.getVector()); return success(); } APInt cstPos; if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos))) rewriter.replaceOpWithNewOp( extractOp, resultType, adaptor.getVector(), rewriter.getI32ArrayAttr({static_cast(cstPos.getSExtValue())})); else rewriter.replaceOpWithNewOp( extractOp, resultType, adaptor.getVector(), adaptor.getPosition()); return success(); } }; struct VectorInsertElementOpConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type vectorType = getTypeConverter()->convertType(insertOp.getType()); if (!vectorType) return failure(); if (isa(vectorType)) { rewriter.replaceOp(insertOp, adaptor.getSource()); return success(); } APInt cstPos; if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos))) rewriter.replaceOpWithNewOp( insertOp, adaptor.getSource(), adaptor.getDest(), cstPos.getSExtValue()); else rewriter.replaceOpWithNewOp( insertOp, vectorType, insertOp.getDest(), adaptor.getSource(), adaptor.getPosition()); return success(); } }; struct VectorInsertStridedSliceOpConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value srcVector = adaptor.getOperands().front(); Value dstVector = adaptor.getOperands().back(); uint64_t stride = getFirstIntValue(insertOp.getStrides()); if (stride != 1) return failure(); uint64_t offset = getFirstIntValue(insertOp.getOffsets()); if (isa(srcVector.getType())) { assert(!isa(dstVector.getType())); rewriter.replaceOpWithNewOp( insertOp, dstVector.getType(), srcVector, dstVector, rewriter.getI32ArrayAttr(offset)); return success(); } uint64_t totalSize = cast(dstVector.getType()).getNumElements(); uint64_t insertSize = cast(srcVector.getType()).getNumElements(); SmallVector indices(totalSize); std::iota(indices.begin(), indices.end(), 0); std::iota(indices.begin() + offset, indices.begin() + offset + insertSize, totalSize); rewriter.replaceOpWithNewOp( insertOp, dstVector.getType(), dstVector, srcVector, rewriter.getI32ArrayAttr(indices)); return success(); } }; static SmallVector extractAllElements( vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor, VectorType srcVectorType, ConversionPatternRewriter &rewriter) { int numElements = static_cast(srcVectorType.getDimSize(0)); SmallVector values; values.reserve(numElements + (adaptor.getAcc() ? 1 : 0)); Location loc = reduceOp.getLoc(); for (int i = 0; i < numElements; ++i) { values.push_back(spirv::CompositeExtractOp::create( rewriter, loc, srcVectorType.getElementType(), adaptor.getVector(), rewriter.getI32ArrayAttr({i}))); } if (Value acc = adaptor.getAcc()) values.push_back(acc); return values; } struct ReductionRewriteInfo { Type resultType; SmallVector extractedElements; }; FailureOr static getReductionInfo( vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor, ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter) { Type resultType = typeConverter.convertType(op.getType()); if (!resultType) return failure(); auto srcVectorType = dyn_cast(adaptor.getVector().getType()); if (!srcVectorType || srcVectorType.getRank() != 1) return rewriter.notifyMatchFailure(op, "not a 1-D vector source"); SmallVector extractedElements = extractAllElements(op, adaptor, srcVectorType, rewriter); return ReductionRewriteInfo{resultType, std::move(extractedElements)}; } template struct VectorReductionPattern final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto reductionInfo = getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter()); if (failed(reductionInfo)) return failure(); auto [resultType, extractedElements] = *reductionInfo; Location loc = reduceOp->getLoc(); Value result = extractedElements.front(); for (Value next : llvm::drop_begin(extractedElements)) { switch (reduceOp.getKind()) { #define INT_AND_FLOAT_CASE(kind, iop, fop) \ case vector::CombiningKind::kind: \ if (llvm::isa(resultType)) { \ result = spirv::iop::create(rewriter, loc, resultType, result, next); \ } else { \ assert(llvm::isa(resultType)); \ result = spirv::fop::create(rewriter, loc, resultType, result, next); \ } \ break #define INT_OR_FLOAT_CASE(kind, fop) \ case vector::CombiningKind::kind: \ result = fop::create(rewriter, loc, resultType, result, next); \ break INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp); INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp); INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp); INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp); INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp); INT_OR_FLOAT_CASE(MAXSI, SPIRVSMaxOp); case vector::CombiningKind::AND: case vector::CombiningKind::OR: case vector::CombiningKind::XOR: return rewriter.notifyMatchFailure(reduceOp, "unimplemented"); default: return rewriter.notifyMatchFailure(reduceOp, "not handled here"); } #undef INT_AND_FLOAT_CASE #undef INT_OR_FLOAT_CASE } rewriter.replaceOp(reduceOp, result); return success(); } }; template struct VectorReductionFloatMinMax final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto reductionInfo = getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter()); if (failed(reductionInfo)) return failure(); auto [resultType, extractedElements] = *reductionInfo; Location loc = reduceOp->getLoc(); Value result = extractedElements.front(); for (Value next : llvm::drop_begin(extractedElements)) { switch (reduceOp.getKind()) { #define INT_OR_FLOAT_CASE(kind, fop) \ case vector::CombiningKind::kind: \ result = fop::create(rewriter, loc, resultType, result, next); \ break INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp); INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp); INT_OR_FLOAT_CASE(MAXNUMF, SPIRVFMaxOp); INT_OR_FLOAT_CASE(MINNUMF, SPIRVFMinOp); default: return rewriter.notifyMatchFailure(reduceOp, "not handled here"); } #undef INT_OR_FLOAT_CASE } rewriter.replaceOp(reduceOp, result); return success(); } }; class VectorScalarBroadcastPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (isa(op.getSourceType())) { return rewriter.notifyMatchFailure( op, "only conversion of 'broadcast from scalar' is supported"); } Type dstType = getTypeConverter()->convertType(op.getType()); if (!dstType) return failure(); if (isa(dstType)) { rewriter.replaceOp(op, adaptor.getSource()); } else { auto dstVecType = cast(dstType); SmallVector source(dstVecType.getNumElements(), adaptor.getSource()); rewriter.replaceOpWithNewOp(op, dstType, source); } return success(); } }; struct VectorShuffleOpConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { VectorType oldResultType = shuffleOp.getResultVectorType(); Type newResultType = getTypeConverter()->convertType(oldResultType); if (!newResultType) return rewriter.notifyMatchFailure(shuffleOp, "unsupported result vector type"); auto mask = llvm::to_vector_of(shuffleOp.getMask()); VectorType oldV1Type = shuffleOp.getV1VectorType(); VectorType oldV2Type = shuffleOp.getV2VectorType(); // When both operands and the result are SPIR-V vectors, emit a SPIR-V // shuffle. if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 && oldResultType.getNumElements() > 1) { rewriter.replaceOpWithNewOp( shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(), rewriter.getI32ArrayAttr(mask)); return success(); } // When at least one of the operands or the result becomes a scalar after // type conversion for SPIR-V, extract all the required elements and // construct the result vector. auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()]( Value scalarOrVec, int32_t idx) -> Value { if (auto vecTy = dyn_cast(scalarOrVec.getType())) return spirv::CompositeExtractOp::create(rewriter, loc, scalarOrVec, idx); assert(idx == 0 && "Invalid scalar element index"); return scalarOrVec; }; int32_t numV1Elems = oldV1Type.getNumElements(); SmallVector newOperands(mask.size()); for (auto [shuffleIdx, newOperand] : llvm::zip_equal(mask, newOperands)) { Value vec = adaptor.getV1(); int32_t elementIdx = shuffleIdx; if (elementIdx >= numV1Elems) { vec = adaptor.getV2(); elementIdx -= numV1Elems; } newOperand = getElementAtIdx(vec, elementIdx); } // Handle the scalar result corner case. if (newOperands.size() == 1) { rewriter.replaceOp(shuffleOp, newOperands.front()); return success(); } rewriter.replaceOpWithNewOp( shuffleOp, newResultType, newOperands); return success(); } }; struct VectorInterleaveOpConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Check the result vector type. VectorType oldResultType = interleaveOp.getResultVectorType(); Type newResultType = getTypeConverter()->convertType(oldResultType); if (!newResultType) return rewriter.notifyMatchFailure(interleaveOp, "unsupported result vector type"); // Interleave the indices. VectorType sourceType = interleaveOp.getSourceVectorType(); int n = sourceType.getNumElements(); // Input vectors of size 1 are converted to scalars by the type converter. // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to // use `spirv::CompositeConstructOp`. if (n == 1) { Value newOperands[] = {adaptor.getLhs(), adaptor.getRhs()}; rewriter.replaceOpWithNewOp( interleaveOp, newResultType, newOperands); return success(); } auto seq = llvm::seq(2 * n); auto indices = llvm::map_to_vector( seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; }); // Emit a SPIR-V shuffle. rewriter.replaceOpWithNewOp( interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(), rewriter.getI32ArrayAttr(indices)); return success(); } }; struct VectorDeinterleaveOpConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Check the result vector type. VectorType oldResultType = deinterleaveOp.getResultVectorType(); Type newResultType = getTypeConverter()->convertType(oldResultType); if (!newResultType) return rewriter.notifyMatchFailure(deinterleaveOp, "unsupported result vector type"); Location loc = deinterleaveOp->getLoc(); // Deinterleave the indices. Value sourceVector = adaptor.getSource(); VectorType sourceType = deinterleaveOp.getSourceVectorType(); int n = sourceType.getNumElements(); // Output vectors of size 1 are converted to scalars by the type converter. // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to // use `spirv::CompositeExtractOp`. if (n == 2) { auto elem0 = spirv::CompositeExtractOp::create( rewriter, loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({0})); auto elem1 = spirv::CompositeExtractOp::create( rewriter, loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({1})); rewriter.replaceOp(deinterleaveOp, {elem0, elem1}); return success(); } // Indices for `shuffleEven` (result 0). auto seqEven = llvm::seq(n / 2); auto indicesEven = llvm::map_to_vector(seqEven, [](int i) { return i * 2; }); // Indices for `shuffleOdd` (result 1). auto seqOdd = llvm::seq(n / 2); auto indicesOdd = llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; }); // Create two SPIR-V shuffles. auto shuffleEven = spirv::VectorShuffleOp::create( rewriter, loc, newResultType, sourceVector, sourceVector, rewriter.getI32ArrayAttr(indicesEven)); auto shuffleOdd = spirv::VectorShuffleOp::create( rewriter, loc, newResultType, sourceVector, sourceVector, rewriter.getI32ArrayAttr(indicesOdd)); rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd}); return success(); } }; struct VectorLoadOpConverter final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto memrefType = loadOp.getMemRefType(); auto attr = dyn_cast_or_null(memrefType.getMemorySpace()); if (!attr) return rewriter.notifyMatchFailure( loadOp, "expected spirv.storage_class memory space"); const auto &typeConverter = *getTypeConverter(); auto loc = loadOp.getLoc(); Value accessChain = spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(), adaptor.getIndices(), loc, rewriter); if (!accessChain) return rewriter.notifyMatchFailure( loadOp, "failed to get memref element pointer"); spirv::StorageClass storageClass = attr.getValue(); auto vectorType = loadOp.getVectorType(); // Use the converted vector type instead of original (single element vector // would get converted to scalar). auto spirvVectorType = typeConverter.convertType(vectorType); if (!spirvVectorType) return rewriter.notifyMatchFailure(loadOp, "unsupported vector type"); auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass); // For single element vectors, we don't need to bitcast the access chain to // the original vector type. Both is going to be the same, a pointer // to a scalar. Value castedAccessChain = (vectorType.getNumElements() == 1) ? accessChain : spirv::BitcastOp::create(rewriter, loc, vectorPtrType, accessChain); rewriter.replaceOpWithNewOp(loadOp, spirvVectorType, castedAccessChain); return success(); } }; struct VectorStoreOpConverter final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto memrefType = storeOp.getMemRefType(); auto attr = dyn_cast_or_null(memrefType.getMemorySpace()); if (!attr) return rewriter.notifyMatchFailure( storeOp, "expected spirv.storage_class memory space"); const auto &typeConverter = *getTypeConverter(); auto loc = storeOp.getLoc(); Value accessChain = spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(), adaptor.getIndices(), loc, rewriter); if (!accessChain) return rewriter.notifyMatchFailure( storeOp, "failed to get memref element pointer"); spirv::StorageClass storageClass = attr.getValue(); auto vectorType = storeOp.getVectorType(); auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass); // For single element vectors, we don't need to bitcast the access chain to // the original vector type. Both is going to be the same, a pointer // to a scalar. Value castedAccessChain = (vectorType.getNumElements() == 1) ? accessChain : spirv::BitcastOp::create(rewriter, loc, vectorPtrType, accessChain); rewriter.replaceOpWithNewOp(storeOp, castedAccessChain, adaptor.getValueToStore()); return success(); } }; struct VectorReductionToIntDotProd final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::ReductionOp op, PatternRewriter &rewriter) const override { if (op.getKind() != vector::CombiningKind::ADD) return rewriter.notifyMatchFailure(op, "combining kind is not 'add'"); auto resultType = dyn_cast(op.getType()); if (!resultType) return rewriter.notifyMatchFailure(op, "result is not an integer"); int64_t resultBitwidth = resultType.getIntOrFloatBitWidth(); if (!llvm::is_contained({32, 64}, resultBitwidth)) return rewriter.notifyMatchFailure(op, "unsupported integer bitwidth"); VectorType inVecTy = op.getSourceVectorType(); if (!llvm::is_contained({4, 3}, inVecTy.getNumElements()) || inVecTy.getShape().size() != 1 || inVecTy.isScalable()) return rewriter.notifyMatchFailure(op, "unsupported vector shape"); auto mul = op.getVector().getDefiningOp(); if (!mul) return rewriter.notifyMatchFailure( op, "reduction operand is not 'arith.muli'"); if (succeeded(handleCase(op, mul, rewriter))) return success(); if (succeeded(handleCase(op, mul, rewriter))) return success(); if (succeeded(handleCase(op, mul, rewriter))) return success(); if (succeeded(handleCase(op, mul, rewriter))) return success(); return failure(); } private: template static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp mul, PatternRewriter &rewriter) { auto lhs = mul.getLhs().getDefiningOp(); if (!lhs) return failure(); Value lhsIn = lhs.getIn(); auto lhsInType = cast(lhsIn.getType()); if (!lhsInType.getElementType().isInteger(8)) return failure(); auto rhs = mul.getRhs().getDefiningOp(); if (!rhs) return failure(); Value rhsIn = rhs.getIn(); auto rhsInType = cast(rhsIn.getType()); if (!rhsInType.getElementType().isInteger(8)) return failure(); if (op.getSourceVectorType().getNumElements() == 3) { IntegerType i8Type = rewriter.getI8Type(); auto v4i8Type = VectorType::get({4}, i8Type); Location loc = op.getLoc(); Value zero = spirv::ConstantOp::getZero(i8Type, loc, rewriter); lhsIn = spirv::CompositeConstructOp::create(rewriter, loc, v4i8Type, ValueRange{lhsIn, zero}); rhsIn = spirv::CompositeConstructOp::create(rewriter, loc, v4i8Type, ValueRange{rhsIn, zero}); } // There's no variant of dot prod ops for unsigned LHS and signed RHS, so // we have to swap operands instead in that case. if (SwapOperands) std::swap(lhsIn, rhsIn); if (Value acc = op.getAcc()) { rewriter.replaceOpWithNewOp(op, op.getType(), lhsIn, rhsIn, acc, nullptr); } else { rewriter.replaceOpWithNewOp(op, op.getType(), lhsIn, rhsIn, nullptr); } return success(); } }; struct VectorReductionToFPDotProd final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (op.getKind() != vector::CombiningKind::ADD) return rewriter.notifyMatchFailure(op, "combining kind is not 'add'"); auto resultType = getTypeConverter()->convertType(op.getType()); if (!resultType) return rewriter.notifyMatchFailure(op, "result is not a float"); Value vec = adaptor.getVector(); Value acc = adaptor.getAcc(); auto vectorType = dyn_cast(vec.getType()); if (!vectorType) { assert(isa(vec.getType()) && "Expected the vector to be scalarized"); if (acc) { rewriter.replaceOpWithNewOp(op, acc, vec); return success(); } rewriter.replaceOp(op, vec); return success(); } Location loc = op.getLoc(); Value lhs; Value rhs; if (auto mul = vec.getDefiningOp()) { lhs = mul.getLhs(); rhs = mul.getRhs(); } else { // If the operand is not a mul, use a vector of ones for the dot operand // to just sum up all values. lhs = vec; Attribute oneAttr = rewriter.getFloatAttr(vectorType.getElementType(), 1.0); oneAttr = SplatElementsAttr::get(vectorType, oneAttr); rhs = spirv::ConstantOp::create(rewriter, loc, vectorType, oneAttr); } assert(lhs); assert(rhs); Value res = spirv::DotOp::create(rewriter, loc, resultType, lhs, rhs); if (acc) res = spirv::FAddOp::create(rewriter, loc, acc, res); rewriter.replaceOp(op, res); return success(); } }; struct VectorStepOpConvert final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { const auto &typeConverter = *getTypeConverter(); Type dstType = typeConverter.convertType(stepOp.getType()); if (!dstType) return failure(); Location loc = stepOp.getLoc(); int64_t numElements = stepOp.getType().getNumElements(); auto intType = rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth()); // Input vectors of size 1 are converted to scalars by the type converter. // We just create a constant in this case. if (numElements == 1) { Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter); rewriter.replaceOp(stepOp, zero); return success(); } SmallVector source; source.reserve(numElements); for (int64_t i = 0; i < numElements; ++i) { Attribute intAttr = rewriter.getIntegerAttr(intType, i); Value constOp = spirv::ConstantOp::create(rewriter, loc, intType, intAttr); source.push_back(constOp); } rewriter.replaceOpWithNewOp(stepOp, dstType, source); return success(); } }; struct VectorToElementOpConvert final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SmallVector results(toElementsOp->getNumResults()); Location loc = toElementsOp.getLoc(); // Input vectors of size 1 are converted to scalars by the type converter. // We cannot use `spirv::CompositeExtractOp` directly in this case. // For a scalar source, the result is just the scalar itself. if (isa(adaptor.getSource().getType())) { results[0] = adaptor.getSource(); rewriter.replaceOp(toElementsOp, results); return success(); } Type srcElementType = toElementsOp.getElements().getType().front(); Type elementType = getTypeConverter()->convertType(srcElementType); if (!elementType) return rewriter.notifyMatchFailure( toElementsOp, llvm::formatv("failed to convert element type '{0}' to SPIR-V", srcElementType)); for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) { // Create an CompositeExtract operation only for results that are not // dead. if (element.use_empty()) continue; Value result = spirv::CompositeExtractOp::create( rewriter, loc, elementType, adaptor.getSource(), rewriter.getI32ArrayAttr({static_cast(idx)})); results[idx] = result; } rewriter.replaceOp(toElementsOp, results); return success(); } }; } // namespace #define CL_INT_MAX_MIN_OPS \ spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp #define GL_INT_MAX_MIN_OPS \ spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp #define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp #define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp void mlir::populateVectorToSPIRVPatterns( const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.add< VectorBitcastConvert, VectorBroadcastConvert, VectorExtractElementOpConvert, VectorExtractOpConvert, VectorExtractStridedSliceOpConvert, VectorFmaOpConvert, VectorFmaOpConvert, VectorFromElementsOpConvert, VectorToElementOpConvert, VectorInsertElementOpConvert, VectorInsertOpConvert, VectorReductionPattern, VectorReductionPattern, VectorReductionFloatMinMax, VectorReductionFloatMinMax, VectorShapeCast, VectorSplatToBroadcast, VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert, VectorInterleaveOpConvert, VectorDeinterleaveOpConvert, VectorScalarBroadcastPattern, VectorLoadOpConverter, VectorStoreOpConverter, VectorStepOpConvert>( typeConverter, patterns.getContext(), PatternBenefit(1)); // Make sure that the more specialized dot product pattern has higher benefit // than the generic one that extracts all elements. patterns.add(typeConverter, patterns.getContext(), PatternBenefit(2)); } void mlir::populateVectorReductionToSPIRVDotProductPatterns( RewritePatternSet &patterns) { patterns.add(patterns.getContext()); }