diff options
Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r-- | mlir/lib/Dialect/AMX/IR/AMXDialect.cpp | 99 | ||||
-rw-r--r-- | mlir/lib/Dialect/LLVMIR/CMakeLists.txt | 2 | ||||
-rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 3 | ||||
-rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp | 154 | ||||
-rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h | 27 | ||||
-rw-r--r-- | mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 11 | ||||
-rw-r--r-- | mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 153 | ||||
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 89 | ||||
-rw-r--r-- | mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 1 | ||||
-rw-r--r-- | mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp | 17 |
10 files changed, 480 insertions, 76 deletions
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp index 68990ef..d9c097c 100644 --- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp +++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp @@ -80,10 +80,22 @@ static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType, LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)}; } +/// Returns stride expressed in number of bytes for the given `elementStride` +/// stride encoded in number of elements of the type `mType`. +static Value computeStrideInBytes(Location loc, MemRefType mType, + Value elementStride, RewriterBase &rewriter) { + Type llvmInt64Type = rewriter.getIntegerType(64); + unsigned bytes = mType.getElementType().getIntOrFloatBitWidth() / 8; + auto attr = rewriter.getI64IntegerAttr(bytes); + Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr); + return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, elementStride) + .getResult(); +} + /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer /// shape may "envelop" the actual tile shape, and may be dynamically sized. -static Value getStride(Location loc, MemRefType mType, Value base, - RewriterBase &rewriter) { +static Value inferStride(Location loc, MemRefType mType, Value base, + RewriterBase &rewriter) { assert(mType.getRank() >= 2 && "Invalid shape for AMX strides"); int64_t preLast = mType.getRank() - 2; Type llvmInt64Type = rewriter.getIntegerType(64); @@ -94,11 +106,8 @@ static Value getStride(Location loc, MemRefType mType, Value base, if (strides[preLast] == ShapedType::kDynamic) { // Dynamic stride needs code to compute the stride at runtime. MemRefDescriptor memrefDescriptor(base); - auto attr = rewriter.getI64IntegerAttr(bytes); - Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr); - return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, - memrefDescriptor.stride(rewriter, loc, preLast)) - .getResult(); + return computeStrideInBytes( + loc, mType, memrefDescriptor.stride(rewriter, loc, preLast), rewriter); } // Use direct constant for static stride. auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes); @@ -117,21 +126,39 @@ amx::TileZeroOp::getIntrinsicOperands(ArrayRef<Value> operands, return getTileSizes(getLoc(), getTileType(), rewriter); } -LogicalResult amx::TileLoadOp::verify() { - MemRefType memrefTy = getMemRefType(); +template <typename OpTy, + typename = std::enable_if_t<std::is_same_v<OpTy, amx::TileLoadOp> || + std::is_same_v<OpTy, amx::TileStoreOp>>> +static LogicalResult tileTransferVerifier(OpTy op) { + MemRefType memrefTy = op.getMemRefType(); unsigned rank = memrefTy.getRank(); - if (rank < 2) - return emitOpError("requires at least 2D memref"); - if (getIndices().size() != rank) - return emitOpError("requires ") << rank << " indices"; - SmallVector<int64_t> strides; - int64_t offset; - if (failed(memrefTy.getStridesAndOffset(strides, offset)) || - strides.back() != 1) - return emitOpError("requires memref with unit innermost stride"); - return verifyTileSize(*this, getTileType()); + if (op.getIndices().size() != rank) + return op.emitOpError("requires ") << rank << " indices"; + + if (failed(verifyTileSize(op, op.getTileType()))) + return failure(); + + // Validate basic buffer properties when the stride is implicit. + if (!op.getStride()) { + if (rank < 2) + return op.emitOpError("requires at least 2D memref"); + SmallVector<int64_t> strides; + int64_t offset; + if (failed(memrefTy.getStridesAndOffset(strides, offset)) || + strides.back() != 1) + return op.emitOpError("requires memref with unit innermost stride"); + } + + return success(); +} + +void amx::TileLoadOp::build(OpBuilder &builder, OperationState &state, Type res, + Value base, ValueRange indices) { + build(builder, state, res, base, indices, /*stride=*/nullptr); } +LogicalResult amx::TileLoadOp::verify() { return tileTransferVerifier(*this); } + SmallVector<Value> amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter, @@ -144,27 +171,23 @@ amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands, intrinsicOperands.push_back( LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(), adaptor.getBase(), adaptor.getIndices())); - intrinsicOperands.push_back( - getStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); + if (Value stride = adaptor.getStride()) + intrinsicOperands.push_back( + computeStrideInBytes(loc, getMemRefType(), stride, rewriter)); + else + intrinsicOperands.push_back( + inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); return intrinsicOperands; } -LogicalResult amx::TileStoreOp::verify() { - MemRefType memrefTy = getMemRefType(); - unsigned rank = memrefTy.getRank(); - if (rank < 2) - return emitOpError("requires at least 2D memref"); - if (getIndices().size() != rank) - return emitOpError("requires ") << rank << " indices"; - SmallVector<int64_t> strides; - int64_t offset; - if (failed(memrefTy.getStridesAndOffset(strides, offset)) || - strides.back() != 1) - return emitOpError("requires memref with unit innermost stride"); - return verifyTileSize(*this, getTileType()); +void amx::TileStoreOp::build(OpBuilder &builder, OperationState &state, + Value base, ValueRange indices, Value val) { + build(builder, state, base, indices, val, /*stride=*/nullptr); } +LogicalResult amx::TileStoreOp::verify() { return tileTransferVerifier(*this); } + SmallVector<Value> amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter, @@ -177,8 +200,12 @@ amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands, intrinsicOperands.push_back( LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(), adaptor.getBase(), adaptor.getIndices())); - intrinsicOperands.push_back( - getStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); + if (Value stride = adaptor.getStride()) + intrinsicOperands.push_back( + computeStrideInBytes(loc, getMemRefType(), stride, rewriter)); + else + intrinsicOperands.push_back( + inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); intrinsicOperands.push_back(adaptor.getVal()); return intrinsicOperands; diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt index ec581ac..cc66fac 100644 --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -8,11 +8,13 @@ add_mlir_dialect_library(MLIRLLVMDialect IR/LLVMMemorySlot.cpp IR/LLVMTypes.cpp IR/LLVMTypeSyntax.cpp + IR/LLVMDialectBytecode.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR DEPENDS + MLIRLLVMDialectBytecodeIncGen MLIRLLVMOpsIncGen MLIRLLVMTypesIncGen MLIRLLVMIntrinsicOpsIncGen diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 5d08ccc..7ca09d9 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -29,6 +29,8 @@ #include "llvm/IR/DataLayout.h" #include "llvm/Support/Error.h" +#include "LLVMDialectBytecode.h" + #include <numeric> #include <optional> @@ -4237,6 +4239,7 @@ void LLVMDialect::initialize() { // Support unknown operations because not all LLVM operations are registered. allowUnknownOperations(); declarePromisedInterface<DialectInlinerInterface, LLVMDialect>(); + detail::addBytecodeInterface(this); } #define GET_OP_CLASSES diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp new file mode 100644 index 0000000..41d1f80 --- /dev/null +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp @@ -0,0 +1,154 @@ +//===- LLVMDialectBytecode.cpp - LLVM Bytecode Implementation -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "LLVMDialectBytecode.h" +#include "mlir/Bytecode/BytecodeImplementation.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include <type_traits> + +using namespace mlir; +using namespace mlir::LLVM; + +namespace { + +// Provide some forward declarations of the functions that will be generated by +// the include below. +static void write(DIExpressionElemAttr attribute, + DialectBytecodeWriter &writer); +static LogicalResult writeAttribute(Attribute attribute, + DialectBytecodeWriter &writer); + +//===--------------------------------------------------------------------===// +// Optional ArrayRefs +// +// Note that both the writer and reader functions consider attributes to be +// optional. This is because the attribute may be present or empty. +//===--------------------------------------------------------------------===// + +template <class EntryTy> +static void writeOptionalArrayRef(DialectBytecodeWriter &writer, + ArrayRef<EntryTy> storage) { + if (storage.empty()) { + writer.writeOwnedBool(false); + return; + } + + writer.writeOwnedBool(true); + writer.writeList(storage, [&](EntryTy val) { + if constexpr (std::is_base_of_v<Attribute, EntryTy>) { + (void)writer.writeOptionalAttribute(val); + } else if constexpr (std::is_integral_v<EntryTy>) { + (void)writer.writeVarInt(val); + } else { + static_assert(true, "EntryTy not supported"); + } + }); +} + +template <class EntryTy> +static LogicalResult readOptionalArrayRef(DialectBytecodeReader &reader, + SmallVectorImpl<EntryTy> &storage) { + bool isPresent = false; + if (failed(reader.readBool(isPresent))) + return failure(); + // Nothing to do here, the array is empty. + if (!isPresent) + return success(); + + auto readEntry = [&]() -> FailureOr<EntryTy> { + EntryTy temp; + if constexpr (std::is_base_of_v<Attribute, EntryTy>) { + if (succeeded(reader.readOptionalAttribute(temp))) + return temp; + } else if constexpr (std::is_integral_v<EntryTy>) { + if (succeeded(reader.readVarInt(temp))) + return temp; + } else { + static_assert(true, "EntryTy not supported"); + } + return failure(); + }; + + return reader.readList(storage, readEntry); +} + +//===--------------------------------------------------------------------===// +// Optional integral types +//===--------------------------------------------------------------------===// + +template <class EntryTy> +static void writeOptionalInt(DialectBytecodeWriter &writer, + std::optional<EntryTy> storage) { + static_assert(std::is_integral_v<EntryTy>, + "EntryTy must be an integral type"); + EntryTy val = storage.value_or(0); + writer.writeVarIntWithFlag(val, storage.has_value()); +} + +template <class EntryTy> +static LogicalResult readOptionalInt(DialectBytecodeReader &reader, + std::optional<EntryTy> &storage) { + static_assert(std::is_integral_v<EntryTy>, + "EntryTy must be an integral type"); + uint64_t result = 0; + bool flag = false; + if (failed(reader.readVarIntWithFlag(result, flag))) + return failure(); + if (flag) + storage = static_cast<EntryTy>(result); + else + storage = std::nullopt; + return success(); +} + +//===--------------------------------------------------------------------===// +// Tablegen generated bytecode functions +//===--------------------------------------------------------------------===// + +#include "mlir/Dialect/LLVMIR/LLVMDialectBytecode.cpp.inc" + +//===--------------------------------------------------------------------===// +// LLVMDialectBytecodeInterface +//===--------------------------------------------------------------------===// + +/// This class implements the bytecode interface for the LLVM dialect. +struct LLVMDialectBytecodeInterface : public BytecodeDialectInterface { + LLVMDialectBytecodeInterface(Dialect *dialect) + : BytecodeDialectInterface(dialect) {} + + // Attributes + Attribute readAttribute(DialectBytecodeReader &reader) const override { + return ::readAttribute(getContext(), reader); + } + + LogicalResult writeAttribute(Attribute attr, + DialectBytecodeWriter &writer) const override { + return ::writeAttribute(attr, writer); + } + + // Types + Type readType(DialectBytecodeReader &reader) const override { + return ::readType(getContext(), reader); + } + + LogicalResult writeType(Type type, + DialectBytecodeWriter &writer) const override { + return ::writeType(type, writer); + } +}; +} // namespace + +void LLVM::detail::addBytecodeInterface(LLVMDialect *dialect) { + dialect->addInterfaces<LLVMDialectBytecodeInterface>(); +} diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h new file mode 100644 index 0000000..1a17cb4 --- /dev/null +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h @@ -0,0 +1,27 @@ +//===- LLVMDialectBytecode.h - LLVM Bytecode Implementation -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header defines hooks into the LLVM dialect bytecode +// implementation. +// +//===----------------------------------------------------------------------===// + +#ifndef LIB_MLIR_DIALECT_LLVM_IR_LLVMDIALECTBYTECODE_H +#define LIB_MLIR_DIALECT_LLVM_IR_LLVMDIALECTBYTECODE_H + +namespace mlir::LLVM { +class LLVMDialect; + +namespace detail { +/// Add the interfaces necessary for encoding the LLVM dialect components in +/// bytecode. +void addBytecodeInterface(LLVMDialect *dialect); +} // namespace detail +} // namespace mlir::LLVM + +#endif // LIB_MLIR_DIALECT_LLVM_IR_LLVMDIALECTBYTECODE_H diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 59013a2..cbc565b 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -5272,11 +5272,18 @@ ArrayRef<int64_t> PackOp::getAllOuterDims() { SmallVector<int64_t> PackOp::getTiledOuterDims() { auto innerDimsPos = getInnerDimsPos(); - auto packedShape = getDestType().getShape(); + SmallVector<int64_t> outerDims(getAllOuterDims()); SmallVector<int64_t> res; + // Recover the original order of the outer dims. + SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm()); + invertPermutationVector(outerDimPermInv); + if (!outerDimPermInv.empty()) + applyPermutationToVector(outerDims, outerDimPermInv); + + // Collect the outer dims corresponding to the tilled inner dims. for (auto index : innerDimsPos) - res.push_back(packedShape[index]); + res.push_back(outerDims[index]); return res; } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index dd9b4c2..d8f983f 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -576,6 +576,86 @@ transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply( // FuseOp //===----------------------------------------------------------------------===// +void transform::FuseOp::build(OpBuilder &builder, OperationState &result, + TypeRange loopTypes, Value target, + ArrayRef<int64_t> staticTileSizes, + ArrayRef<int64_t> staticTileInterchange, + bool applyCleanup, bool useForall) { + return build( + builder, result, loopTypes, + /*target=*/target, + /*mixedTileSizes=*/ + getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)), + /*mixedTileInterchange=*/ + getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)), + applyCleanup, useForall); +} + +void transform::FuseOp::build(OpBuilder &builder, OperationState &result, + Value target, ArrayRef<int64_t> staticTileSizes, + ArrayRef<int64_t> staticTileInterchange, + bool applyCleanup, bool useForall) { + return build( + builder, result, + /*target=*/target, + /*mixedTileSizes=*/ + getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)), + /*mixedTileInterchange=*/ + getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)), + applyCleanup, useForall); +} + +void transform::FuseOp::build(OpBuilder &builder, OperationState &result, + Value target, + ArrayRef<OpFoldResult> mixedTileSizes, + ArrayRef<OpFoldResult> mixedTileInterchange, + bool applyCleanup, bool useForall) { + // Loop types are automaticaly splat by the callee, setting up one is + // enough. + SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>()); + build(builder, result, loopTypes, target, mixedTileSizes, + mixedTileInterchange, applyCleanup, useForall); +} + +void transform::FuseOp::build(OpBuilder &builder, OperationState &result, + TypeRange loopTypes, Value target, + ArrayRef<OpFoldResult> mixedTileSizes, + ArrayRef<OpFoldResult> mixedTileInterchange, + bool applyCleanup, bool useForall) { + SmallVector<int64_t> staticTileSizes; + SmallVector<Value> dynamicTileSizes; + dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes); + SmallVector<int64_t> staticTileInterchange; + SmallVector<Value> dynamicTileInterchange; + dispatchIndexOpFoldResults(mixedTileInterchange, dynamicTileInterchange, + staticTileInterchange); + // Call the default builder which sets up the proper operands segment sizes + // attributes for multiple variadic operands. In the absence of this, + // horrible bugs ensue. + auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes); + auto staticTileInterchangeAttr = + builder.getDenseI64ArrayAttr(staticTileInterchange); + unsigned numExpectedLoops = + useForall ? 1 : staticTileSizes.size() - llvm::count(staticTileSizes, 0); + SmallVector<Type> resultTypes; + resultTypes.reserve(numExpectedLoops); + assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) && + "expected one loop type or as many as loops"); + if (loopTypes.size() == 1) + resultTypes.append(numExpectedLoops, loopTypes[0]); + else + llvm::append_range(resultTypes, loopTypes); + build(builder, result, /*transformed=*/target.getType(), + /*loops=*/resultTypes, + /*target=*/target, + /*tile_sizes=*/dynamicTileSizes, + /*tile_interchange=*/dynamicTileInterchange, + /*static_tile_sizes=*/staticTileSizesAttr, + /*static_tile_interchange=*/staticTileInterchangeAttr, + /*apply_cleanup=*/applyCleanup, + /*use_forall=*/useForall); +} + /// Apply a tiling transformation to all payload ops and store both the /// tiled operation as well as the created tile loops. template <typename Range> @@ -630,13 +710,25 @@ DiagnosedSilenceableFailure transform::FuseOp::apply(transform::TransformRewriter &rewriter, mlir::transform::TransformResults &transformResults, mlir::transform::TransformState &state) { - SmallVector<int64_t> tileSizes = - extractFromIntegerArrayAttr<int64_t>(getTileSizes()); - SmallVector<int64_t> tileInterchange = - extractFromIntegerArrayAttr<int64_t>(getTileInterchange()); + auto transformOp = cast<TransformOpInterface>(getOperation()); + + SmallVector<int64_t> tileSizes; + DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults( + state, transformOp, getMixedTileSizes(), tileSizes); + if (!status.succeeded()) + return status; + SmallVector<int64_t> tileInterchange; + status = reifyMixedParamAndHandleResults( + state, transformOp, getMixedTileInterchange(), tileInterchange); + if (!status.succeeded()) + return status; scf::SCFTilingOptions tilingOptions; tilingOptions.interchangeVector = tileInterchange; + bool useForall = getUseForall(); + tilingOptions.setLoopType(useForall + ? scf::SCFTilingOptions::LoopType::ForallOp + : scf::SCFTilingOptions::LoopType::ForOp); SmallVector<OpFoldResult> tileSizesOfr = getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); tilingOptions = tilingOptions.setTileSizes(tileSizesOfr); @@ -652,9 +744,11 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter, tileAndFuseOptions.cleanupPatterns = std::move(patterns); } + size_t numLoops = + useForall ? 1 : tileSizes.size() - llvm::count(tileSizes, 0); LogicalResult result = applyTilingToAll( - rewriter, getOperation(), state.getPayloadOps(getTarget()), - tileSizes.size() - llvm::count(tileSizes, 0), transformResults, + rewriter, getOperation(), state.getPayloadOps(getTarget()), numLoops, + transformResults, [&](TilingInterface tilingInterfaceOp) -> FailureOr<scf::SCFTileAndFuseResult> { return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp, @@ -665,24 +759,51 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter, } LogicalResult transform::FuseOp::verify() { - SmallVector<int64_t> permutation = - extractFromIntegerArrayAttr<int64_t>(getTileInterchange()); - auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size())); - if (!std::is_permutation(sequence.begin(), sequence.end(), - permutation.begin(), permutation.end())) { - return emitOpError() << "expects interchange to be a permutation, found " - << getTileInterchange(); + auto iterspace_rank = getStaticTileSizes().size(); + ArrayRef<int64_t> permutation = getStaticTileInterchange(); + if (permutation.size() > iterspace_rank) + return emitOpError() + << "interchange length exceeds iteration space dimensions (" + << iterspace_rank << "), found " << getTileInterchange(); + SmallVector<bool> seen(iterspace_rank, false); + for (int64_t v : permutation) { + if (!ShapedType::isDynamic(v)) { + if (v < 0 || v >= static_cast<int64_t>(iterspace_rank)) + return emitOpError() << "expects interchange values to be in range [0, " + << iterspace_rank << "), found: " << v; + if (seen[v]) + return emitOpError() << "found duplicate interchange value: " << v; + seen[v] = true; + } } - SmallVector<int64_t> sizes = - extractFromIntegerArrayAttr<int64_t>(getTileSizes()); - size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0); + ArrayRef<int64_t> sizes = getStaticTileSizes(); + size_t numExpectedLoops = + getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0); if (numExpectedLoops != getNumResults() - 1) return emitOpError() << "expects " << numExpectedLoops << " loop results"; return success(); } +SmallVector<OpFoldResult> transform::FuseOp::getMixedTileSizes() { + return getMixedValues(getStaticTileSizes(), getTileSizes(), getContext()); +} + +SmallVector<OpFoldResult> transform::FuseOp::getMixedTileInterchange() { + return getMixedValues(getStaticTileInterchange(), getTileInterchange(), + getContext()); +} + +void transform::FuseOp::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + consumesHandle(getTargetMutable(), effects); + onlyReadsHandle(getTileSizesMutable(), effects); + onlyReadsHandle(getTileInterchangeMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + modifiesPayload(effects); +} + //===----------------------------------------------------------------------===// // FuseIntoContainingOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 0dac688..eb2d825 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -1134,22 +1134,45 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape, LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( linalg::PackOp packOp, PatternRewriter &rewriter) const { - // TODO: support the case that outer dimensions are not all 1s. A - // tensor.expand_shape will be generated in this case. - if (llvm::any_of(packOp.getAllOuterDims(), + if (llvm::any_of(packOp.getTiledOuterDims(), [](int64_t dim) { return dim != 1; })) { return rewriter.notifyMatchFailure( packOp, "not all outer dimensions of the result are 1s"); } + ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos(); + auto outerDimsPerm = packOp.getOuterDimsPerm(); + + // Verify that there are no: + // * non-unit + un-tiled-outer-dims, + // that are permuted. Supporting such cases would require refining the logic + // that generates the Transpose Op. + if (!llvm::all_of(outerDimsPerm, [&innerDimsPos, &packOp](int64_t dim) { + static int prev = 0; + // Skip tiled dims - these can be permuted. + if (llvm::is_contained(innerDimsPos, dim)) + return true; + + // Check whether this dim has been permuted. Permuting unit dims is fine + // as that's effectively a no-op. + if (dim < prev && (packOp.getType().getShape()[prev] != 1 || + packOp.getType().getShape()[dim] != 1)) + return false; + + prev = dim; + return true; + })) { + return rewriter.notifyMatchFailure( + packOp, "At least one non-unit and un-tiled outer dim is permuted, " + "this is not supported ATM!"); + } + Attribute zeroIdxAttr = rewriter.getIndexAttr(0); Attribute oneIdxAttr = rewriter.getIndexAttr(1); Location loc = packOp.getLoc(); int64_t srcRank = packOp.getSourceRank(); int64_t destRank = packOp.getDestRank(); - ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos(); - int64_t numberOfTiles = innerDimsPos.size(); // 1. Get the input that is going to be packed. If the input requires padding, // add a padding operation and return that as the input. @@ -1160,10 +1183,13 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( // %transposed_tile = linalg.transpose ins(%source_or_padded_source), // outs(%init) // Assumptions made: - // - All outer dims are 1 - the corresponding transposition order doesn't - // matter, but requires all dim indices to be present. + // - All tiled outer dims are 1 - the corresponding transposition order + // doesn't matter, but requires all dim indices to be present. + // - Un-tiled outer dims remain un-permuted. - // 2.1 Get the permutation for linalg.transpose + // 2.1 Get the permutation for linalg.transpose: + // [ untiled-dims, inner-dims-pos ] + // Note, this logic assumes that the untiled dims are not permuted. SmallVector<int64_t> srcPermForTranspose; for (int64_t i = 0; i < srcRank; i++) { // We assume the `k` dimensions of the inner dim position, where `k` is the @@ -1179,9 +1205,21 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( } srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end()); - // 2.2 Create the init tensor for linalg.transpose with the correct shape - SmallVector<OpFoldResult> shapeForEmptyOp(srcRank - numberOfTiles, - oneIdxAttr); + // 2.2 Create the init tensor for linalg.transpose with the correct shape: + // [ untiled-dims, tiled-dims ] + ShapedType inputTy = cast<ShapedType>(input.getType()); + SmallVector<OpFoldResult> shapeForEmptyOp; + for (int64_t i = 0; i < srcRank; i++) { + if (llvm::is_contained(innerDimsPos, i)) { + // The tiled dims are appended after this loop. + continue; + } + if (inputTy.isStaticDim(i)) + shapeForEmptyOp.push_back(rewriter.getIndexAttr(inputTy.getShape()[i])); + else + shapeForEmptyOp.emplace_back( + tensor::DimOp::create(rewriter, loc, input, i).getResult()); + } shapeForEmptyOp.append(packOp.getMixedTiles()); // getMixedTiles() may contain Values pointing to constant ops, not the @@ -1204,25 +1242,36 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty, srcPermForTranspose); - // 3. Insert the inner tile to the destination: + // 3. Insert the inner tile into the destination tensor: // %inserted_tile = tensor.insert_slice(%transposed_tile) - SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr); - SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr); - // Outer dims are all 1s! - SmallVector<OpFoldResult> writeSizes(destRank - numberOfTiles, oneIdxAttr); - SmallVector<int64_t> writeShape; + + // Compute the sizes attribute: + // [ outer-dims, tile-sizes ] + // Note that the output from the transpose Op excludes the tiled outer dims. + // However, given the assumption that: + // * all tiled outer dims == 1, + // we can just use a rank-expanding tensor.insert_slice. + SmallVector<OpFoldResult> writeSizes; + for (auto size : packOp.getAllOuterDims()) { + writeSizes.push_back(rewriter.getIndexAttr(size)); + } for (auto tileSize : packOp.getMixedTiles()) { - auto [tileSizeStatic, tileSizeOfr] = + auto [_, tileSizeOfr] = getSimplifiedOfrAndStaticSizePair(tileSize, rewriter); writeSizes.push_back(tileSizeOfr); - writeShape.push_back(tileSizeStatic); } - // 4. Replace tensor.packOp with tensor.insert_slice created above + // TODO: Add a constructor for tensor.insert_slice that doesn't require + // strides nor offsets. + SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr); + SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr); + auto insert = tensor::InsertSliceOp::create( rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets, writeSizes, writeStrides); + + // 4. Replace tensor.packOp with tensor.insert_slice created above rewriter.replaceOp(packOp, insert.getResult()); return success(); diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index fa97b49..ac72002 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -2310,6 +2310,7 @@ RankedTensorType ExtractSliceOp::inferResultType( sourceTensorType.getEncoding()); } +// TODO: This uses neither offsets nor strides! RankedTensorType ExtractSliceOp::inferResultType( RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) { diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp index 36c498e..f77784a 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp @@ -161,11 +161,24 @@ XeGPUBlockingPass::getTileShape(Operation *op) const { xegpu::UpdateOffsetOp, xegpu::LoadMatrixOp>(op)) return getTileShape(op->getOpResult(0)); if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp, - xegpu::LoadGatherOp, xegpu::StoreMatrixOp>(op)) + xegpu::StoreMatrixOp>(op)) return getTileShape(op->getOpOperand(0)); - if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(op)) + if (isa<xegpu::StoreNdOp>(op)) return getTileShape(op->getOpOperand(1)); + // Handle LoadGatherOp and StoreScatterOp (with and without offset) + if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(op)) { + if (loadGatherOp.getOffsets()) + return getTileShape(loadGatherOp->getOpResult(0)); + else + return getTileShape(loadGatherOp->getOpOperand(0)); + } + + if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op)) + return getTileShape(storeScatterOp.getOffsets() + ? storeScatterOp->getOpOperand(0) + : storeScatterOp->getOpOperand(1)); + if (isa<xegpu::DpasOp>(op)) { std::optional<SmallVector<int64_t>> aTile = getTileShape(op->getOpOperand(0)); |