Diffstat (limited to 'mlir/lib/Dialect')
17 files changed, 850 insertions, 127 deletions
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp index 68990ef..d9c097c 100644 --- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp +++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp @@ -80,10 +80,22 @@ static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType, LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)}; } +/// Returns the stride in bytes for the given `elementStride`, which is +/// expressed in number of elements of type `mType`. +static Value computeStrideInBytes(Location loc, MemRefType mType, + Value elementStride, RewriterBase &rewriter) { + Type llvmInt64Type = rewriter.getIntegerType(64); + unsigned bytes = mType.getElementType().getIntOrFloatBitWidth() / 8; + auto attr = rewriter.getI64IntegerAttr(bytes); + Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr); + return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, elementStride) .getResult(); +} + /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer /// shape may "envelop" the actual tile shape, and may be dynamically sized. -static Value getStride(Location loc, MemRefType mType, Value base, - RewriterBase &rewriter) { +static Value inferStride(Location loc, MemRefType mType, Value base, + RewriterBase &rewriter) { assert(mType.getRank() >= 2 && "Invalid shape for AMX strides"); int64_t preLast = mType.getRank() - 2; Type llvmInt64Type = rewriter.getIntegerType(64); @@ -94,11 +106,8 @@ static Value getStride(Location loc, MemRefType mType, Value base, if (strides[preLast] == ShapedType::kDynamic) { // Dynamic stride needs code to compute the stride at runtime. MemRefDescriptor memrefDescriptor(base); - auto attr = rewriter.getI64IntegerAttr(bytes); - Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr); - return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, - memrefDescriptor.stride(rewriter, loc, preLast)) - .getResult(); + return computeStrideInBytes( + loc, mType, memrefDescriptor.stride(rewriter, loc, preLast), rewriter); } // Use direct constant for static stride. auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes); @@ -117,21 +126,39 @@ amx::TileZeroOp::getIntrinsicOperands(ArrayRef<Value> operands, return getTileSizes(getLoc(), getTileType(), rewriter); } -LogicalResult amx::TileLoadOp::verify() { - MemRefType memrefTy = getMemRefType(); +template <typename OpTy, + typename = std::enable_if_t<std::is_same_v<OpTy, amx::TileLoadOp> || + std::is_same_v<OpTy, amx::TileStoreOp>>> +static LogicalResult tileTransferVerifier(OpTy op) { + MemRefType memrefTy = op.getMemRefType(); unsigned rank = memrefTy.getRank(); - if (rank < 2) - return emitOpError("requires at least 2D memref"); - if (getIndices().size() != rank) - return emitOpError("requires ") << rank << " indices"; - SmallVector<int64_t> strides; - int64_t offset; - if (failed(memrefTy.getStridesAndOffset(strides, offset)) || - strides.back() != 1) - return emitOpError("requires memref with unit innermost stride"); - return verifyTileSize(*this, getTileType()); + if (op.getIndices().size() != rank) + return op.emitOpError("requires ") << rank << " indices"; + + if (failed(verifyTileSize(op, op.getTileType()))) + return failure(); + + // Validate basic buffer properties when the stride is implicit.
+ if (!op.getStride()) { + if (rank < 2) + return op.emitOpError("requires at least 2D memref"); + SmallVector<int64_t> strides; + int64_t offset; + if (failed(memrefTy.getStridesAndOffset(strides, offset)) || + strides.back() != 1) + return op.emitOpError("requires memref with unit innermost stride"); + } + + return success(); +} + +void amx::TileLoadOp::build(OpBuilder &builder, OperationState &state, Type res, + Value base, ValueRange indices) { + build(builder, state, res, base, indices, /*stride=*/nullptr); } +LogicalResult amx::TileLoadOp::verify() { return tileTransferVerifier(*this); } + SmallVector<Value> amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter, @@ -144,27 +171,23 @@ amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands, intrinsicOperands.push_back( LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(), adaptor.getBase(), adaptor.getIndices())); - intrinsicOperands.push_back( - getStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); + if (Value stride = adaptor.getStride()) + intrinsicOperands.push_back( + computeStrideInBytes(loc, getMemRefType(), stride, rewriter)); + else + intrinsicOperands.push_back( + inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); return intrinsicOperands; } -LogicalResult amx::TileStoreOp::verify() { - MemRefType memrefTy = getMemRefType(); - unsigned rank = memrefTy.getRank(); - if (rank < 2) - return emitOpError("requires at least 2D memref"); - if (getIndices().size() != rank) - return emitOpError("requires ") << rank << " indices"; - SmallVector<int64_t> strides; - int64_t offset; - if (failed(memrefTy.getStridesAndOffset(strides, offset)) || - strides.back() != 1) - return emitOpError("requires memref with unit innermost stride"); - return verifyTileSize(*this, getTileType()); +void amx::TileStoreOp::build(OpBuilder &builder, OperationState &state, + Value base, ValueRange indices, Value val) { + build(builder, state, base, indices, val, /*stride=*/nullptr); } +LogicalResult amx::TileStoreOp::verify() { return tileTransferVerifier(*this); } + SmallVector<Value> amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter, @@ -177,8 +200,12 @@ amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands, intrinsicOperands.push_back( LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(), adaptor.getBase(), adaptor.getIndices())); - intrinsicOperands.push_back( - getStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); + if (Value stride = adaptor.getStride()) + intrinsicOperands.push_back( + computeStrideInBytes(loc, getMemRefType(), stride, rewriter)); + else + intrinsicOperands.push_back( + inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); intrinsicOperands.push_back(adaptor.getVal()); return intrinsicOperands; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp index 624519f..70faa71 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp @@ -64,12 +64,13 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) { module.walk([&](func::CallOp callOp) { if (func::FuncOp calledFunc = dyn_cast_or_null<func::FuncOp>(callOp.resolveCallable())) { - callerMap[calledFunc].insert(callOp); + if (!calledFunc.isPublic() && 
!calledFunc.isExternal()) + callerMap[calledFunc].insert(callOp); } }); for (auto funcOp : module.getOps<func::FuncOp>()) { - if (funcOp.isExternal()) + if (funcOp.isExternal() || funcOp.isPublic()) continue; func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); // TODO: Support functions with multiple blocks. diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt index ec581ac..cc66fac 100644 --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -8,11 +8,13 @@ add_mlir_dialect_library(MLIRLLVMDialect IR/LLVMMemorySlot.cpp IR/LLVMTypes.cpp IR/LLVMTypeSyntax.cpp + IR/LLVMDialectBytecode.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR DEPENDS + MLIRLLVMDialectBytecodeIncGen MLIRLLVMOpsIncGen MLIRLLVMTypesIncGen MLIRLLVMIntrinsicOpsIncGen diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 5d08ccc..7ca09d9 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -29,6 +29,8 @@ #include "llvm/IR/DataLayout.h" #include "llvm/Support/Error.h" +#include "LLVMDialectBytecode.h" + #include <numeric> #include <optional> @@ -4237,6 +4239,7 @@ void LLVMDialect::initialize() { // Support unknown operations because not all LLVM operations are registered. allowUnknownOperations(); declarePromisedInterface<DialectInlinerInterface, LLVMDialect>(); + detail::addBytecodeInterface(this); } #define GET_OP_CLASSES diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp new file mode 100644 index 0000000..41d1f80 --- /dev/null +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp @@ -0,0 +1,154 @@ +//===- LLVMDialectBytecode.cpp - LLVM Bytecode Implementation -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "LLVMDialectBytecode.h" +#include "mlir/Bytecode/BytecodeImplementation.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include <type_traits> + +using namespace mlir; +using namespace mlir::LLVM; + +namespace { + +// Provide some forward declarations of the functions that will be generated by +// the include below. +static void write(DIExpressionElemAttr attribute, + DialectBytecodeWriter &writer); +static LogicalResult writeAttribute(Attribute attribute, + DialectBytecodeWriter &writer); + +//===--------------------------------------------------------------------===// +// Optional ArrayRefs +// +// Note that both the writer and reader functions consider attributes to be +// optional. This is because the attribute may be present or empty. 
+//===--------------------------------------------------------------------===// + +template <class EntryTy> +static void writeOptionalArrayRef(DialectBytecodeWriter &writer, + ArrayRef<EntryTy> storage) { + if (storage.empty()) { + writer.writeOwnedBool(false); + return; + } + + writer.writeOwnedBool(true); + writer.writeList(storage, [&](EntryTy val) { + if constexpr (std::is_base_of_v<Attribute, EntryTy>) { + (void)writer.writeOptionalAttribute(val); + } else if constexpr (std::is_integral_v<EntryTy>) { + (void)writer.writeVarInt(val); + } else { + static_assert(sizeof(EntryTy) == 0, "EntryTy not supported"); + } + }); +} + +template <class EntryTy> +static LogicalResult readOptionalArrayRef(DialectBytecodeReader &reader, + SmallVectorImpl<EntryTy> &storage) { + bool isPresent = false; + if (failed(reader.readBool(isPresent))) + return failure(); + // Nothing to do here, the array is empty. + if (!isPresent) + return success(); + + auto readEntry = [&]() -> FailureOr<EntryTy> { + EntryTy temp; + if constexpr (std::is_base_of_v<Attribute, EntryTy>) { + if (succeeded(reader.readOptionalAttribute(temp))) + return temp; + } else if constexpr (std::is_integral_v<EntryTy>) { + if (succeeded(reader.readVarInt(temp))) + return temp; + } else { + static_assert(sizeof(EntryTy) == 0, "EntryTy not supported"); + } + return failure(); + }; + + return reader.readList(storage, readEntry); +} + +//===--------------------------------------------------------------------===// +// Optional integral types +//===--------------------------------------------------------------------===// + +template <class EntryTy> +static void writeOptionalInt(DialectBytecodeWriter &writer, + std::optional<EntryTy> storage) { + static_assert(std::is_integral_v<EntryTy>, + "EntryTy must be an integral type"); + EntryTy val = storage.value_or(0); + writer.writeVarIntWithFlag(val, storage.has_value()); +} + +template <class EntryTy> +static LogicalResult readOptionalInt(DialectBytecodeReader &reader, + std::optional<EntryTy> &storage) { + static_assert(std::is_integral_v<EntryTy>, + "EntryTy must be an integral type"); + uint64_t result = 0; + bool flag = false; + if (failed(reader.readVarIntWithFlag(result, flag))) + return failure(); + if (flag) + storage = static_cast<EntryTy>(result); + else + storage = std::nullopt; + return success(); +} + +//===--------------------------------------------------------------------===// +// Tablegen generated bytecode functions +//===--------------------------------------------------------------------===// + +#include "mlir/Dialect/LLVMIR/LLVMDialectBytecode.cpp.inc" + +//===--------------------------------------------------------------------===// +// LLVMDialectBytecodeInterface +//===--------------------------------------------------------------------===// + +/// This class implements the bytecode interface for the LLVM dialect.
+struct LLVMDialectBytecodeInterface : public BytecodeDialectInterface { + LLVMDialectBytecodeInterface(Dialect *dialect) + : BytecodeDialectInterface(dialect) {} + + // Attributes + Attribute readAttribute(DialectBytecodeReader &reader) const override { + return ::readAttribute(getContext(), reader); + } + + LogicalResult writeAttribute(Attribute attr, + DialectBytecodeWriter &writer) const override { + return ::writeAttribute(attr, writer); + } + + // Types + Type readType(DialectBytecodeReader &reader) const override { + return ::readType(getContext(), reader); + } + + LogicalResult writeType(Type type, + DialectBytecodeWriter &writer) const override { + return ::writeType(type, writer); + } +}; +} // namespace + +void LLVM::detail::addBytecodeInterface(LLVMDialect *dialect) { + dialect->addInterfaces<LLVMDialectBytecodeInterface>(); +} diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h new file mode 100644 index 0000000..1a17cb4 --- /dev/null +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h @@ -0,0 +1,27 @@ +//===- LLVMDialectBytecode.h - LLVM Bytecode Implementation -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header defines hooks into the LLVM dialect bytecode +// implementation. +// +//===----------------------------------------------------------------------===// + +#ifndef LIB_MLIR_DIALECT_LLVM_IR_LLVMDIALECTBYTECODE_H +#define LIB_MLIR_DIALECT_LLVM_IR_LLVMDIALECTBYTECODE_H + +namespace mlir::LLVM { +class LLVMDialect; + +namespace detail { +/// Add the interfaces necessary for encoding the LLVM dialect components in +/// bytecode. +void addBytecodeInterface(LLVMDialect *dialect); +} // namespace detail +} // namespace mlir::LLVM + +#endif // LIB_MLIR_DIALECT_LLVM_IR_LLVMDIALECTBYTECODE_H diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 5edcc40b..ab54183 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -309,6 +309,17 @@ LogicalResult ConvertBF16x2ToF8x2Op::verify() { return success(); } +LogicalResult ConvertF32x2ToF4x2Op::verify() { + mlir::MLIRContext *ctx = getContext(); + + if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy())) + return emitOpError("Only ") + << mlir::Float4E2M1FNType::get(ctx) + << " type is supported for conversions from f32x2 to f4x2."; + + return success(); +} + LogicalResult BulkStoreOp::verify() { if (getInitVal() != 0) return emitOpError("only 0 is supported for initVal, got ") << getInitVal(); @@ -2047,6 +2058,23 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd, } } +NVVM::IDArgPair +ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op, + LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(op.getA())); + args.push_back(mt.lookupValue(op.getB())); + + bool hasRelu = op.getRelu(); + + llvm::Intrinsic::ID intId = + hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite + : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite; + + return {intId, std::move(args)}; +} + #define GET_F32x2_TO_F6x2_ID(type, has_relu) \ has_relu ? 
llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \ : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index c477c6c..dcc1ef9 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -315,7 +315,8 @@ bool mlir::linalg::detail::isContractionBody( Value yielded = getSourceSkipUnary(terminator->getOperand(0)); Operation *reductionOp = yielded.getDefiningOp(); - if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) { + if (!reductionOp || reductionOp->getNumResults() != 1 || + reductionOp->getNumOperands() != 2) { errs << "expected reduction op to be binary"; return false; } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 59013a2..cbc565b 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -5272,11 +5272,18 @@ ArrayRef<int64_t> PackOp::getAllOuterDims() { SmallVector<int64_t> PackOp::getTiledOuterDims() { auto innerDimsPos = getInnerDimsPos(); - auto packedShape = getDestType().getShape(); + SmallVector<int64_t> outerDims(getAllOuterDims()); SmallVector<int64_t> res; + // Recover the original order of the outer dims. + SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm()); + invertPermutationVector(outerDimPermInv); + if (!outerDimPermInv.empty()) + applyPermutationToVector(outerDims, outerDimPermInv); + + // Collect the outer dims corresponding to the tiled inner dims. for (auto index : innerDimsPos) - res.push_back(packedShape[index]); + res.push_back(outerDims[index]); return res; } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index dd9b4c2..d8f983f 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -576,6 +576,86 @@ transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply( // FuseOp //===----------------------------------------------------------------------===// +void transform::FuseOp::build(OpBuilder &builder, OperationState &result, + TypeRange loopTypes, Value target, + ArrayRef<int64_t> staticTileSizes, + ArrayRef<int64_t> staticTileInterchange, + bool applyCleanup, bool useForall) { + return build( + builder, result, loopTypes, + /*target=*/target, + /*mixedTileSizes=*/ + getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)), + /*mixedTileInterchange=*/ + getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)), + applyCleanup, useForall); +} + +void transform::FuseOp::build(OpBuilder &builder, OperationState &result, + Value target, ArrayRef<int64_t> staticTileSizes, + ArrayRef<int64_t> staticTileInterchange, + bool applyCleanup, bool useForall) { + return build( + builder, result, + /*target=*/target, + /*mixedTileSizes=*/ + getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)), + /*mixedTileInterchange=*/ + getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)), + applyCleanup, useForall); +} + +void transform::FuseOp::build(OpBuilder &builder, OperationState &result, + Value target, + ArrayRef<OpFoldResult> mixedTileSizes, + ArrayRef<OpFoldResult> mixedTileInterchange, + bool applyCleanup, bool useForall) { + // Loop types are automatically splat by the callee, setting up one is + // enough.
+ SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>()); + build(builder, result, loopTypes, target, mixedTileSizes, + mixedTileInterchange, applyCleanup, useForall); +} + +void transform::FuseOp::build(OpBuilder &builder, OperationState &result, + TypeRange loopTypes, Value target, + ArrayRef<OpFoldResult> mixedTileSizes, + ArrayRef<OpFoldResult> mixedTileInterchange, + bool applyCleanup, bool useForall) { + SmallVector<int64_t> staticTileSizes; + SmallVector<Value> dynamicTileSizes; + dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes); + SmallVector<int64_t> staticTileInterchange; + SmallVector<Value> dynamicTileInterchange; + dispatchIndexOpFoldResults(mixedTileInterchange, dynamicTileInterchange, + staticTileInterchange); + // Call the default builder which sets up the proper operands segment sizes + // attributes for multiple variadic operands. In the absence of this, + // horrible bugs ensue. + auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes); + auto staticTileInterchangeAttr = + builder.getDenseI64ArrayAttr(staticTileInterchange); + unsigned numExpectedLoops = + useForall ? 1 : staticTileSizes.size() - llvm::count(staticTileSizes, 0); + SmallVector<Type> resultTypes; + resultTypes.reserve(numExpectedLoops); + assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) && + "expected one loop type or as many as loops"); + if (loopTypes.size() == 1) + resultTypes.append(numExpectedLoops, loopTypes[0]); + else + llvm::append_range(resultTypes, loopTypes); + build(builder, result, /*transformed=*/target.getType(), + /*loops=*/resultTypes, + /*target=*/target, + /*tile_sizes=*/dynamicTileSizes, + /*tile_interchange=*/dynamicTileInterchange, + /*static_tile_sizes=*/staticTileSizesAttr, + /*static_tile_interchange=*/staticTileInterchangeAttr, + /*apply_cleanup=*/applyCleanup, + /*use_forall=*/useForall); +} + /// Apply a tiling transformation to all payload ops and store both the /// tiled operation as well as the created tile loops. template <typename Range> @@ -630,13 +710,25 @@ DiagnosedSilenceableFailure transform::FuseOp::apply(transform::TransformRewriter &rewriter, mlir::transform::TransformResults &transformResults, mlir::transform::TransformState &state) { - SmallVector<int64_t> tileSizes = - extractFromIntegerArrayAttr<int64_t>(getTileSizes()); - SmallVector<int64_t> tileInterchange = - extractFromIntegerArrayAttr<int64_t>(getTileInterchange()); + auto transformOp = cast<TransformOpInterface>(getOperation()); + + SmallVector<int64_t> tileSizes; + DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults( + state, transformOp, getMixedTileSizes(), tileSizes); + if (!status.succeeded()) + return status; + SmallVector<int64_t> tileInterchange; + status = reifyMixedParamAndHandleResults( + state, transformOp, getMixedTileInterchange(), tileInterchange); + if (!status.succeeded()) + return status; scf::SCFTilingOptions tilingOptions; tilingOptions.interchangeVector = tileInterchange; + bool useForall = getUseForall(); + tilingOptions.setLoopType(useForall + ? scf::SCFTilingOptions::LoopType::ForallOp + : scf::SCFTilingOptions::LoopType::ForOp); SmallVector<OpFoldResult> tileSizesOfr = getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); tilingOptions = tilingOptions.setTileSizes(tileSizesOfr); @@ -652,9 +744,11 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter, tileAndFuseOptions.cleanupPatterns = std::move(patterns); } + size_t numLoops = + useForall ? 
1 : tileSizes.size() - llvm::count(tileSizes, 0); LogicalResult result = applyTilingToAll( - rewriter, getOperation(), state.getPayloadOps(getTarget()), - tileSizes.size() - llvm::count(tileSizes, 0), transformResults, + rewriter, getOperation(), state.getPayloadOps(getTarget()), numLoops, + transformResults, [&](TilingInterface tilingInterfaceOp) -> FailureOr<scf::SCFTileAndFuseResult> { return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp, @@ -665,24 +759,51 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter, } LogicalResult transform::FuseOp::verify() { - SmallVector<int64_t> permutation = - extractFromIntegerArrayAttr<int64_t>(getTileInterchange()); - auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size())); - if (!std::is_permutation(sequence.begin(), sequence.end(), - permutation.begin(), permutation.end())) { - return emitOpError() << "expects interchange to be a permutation, found " - << getTileInterchange(); + auto iterSpaceRank = getStaticTileSizes().size(); + ArrayRef<int64_t> permutation = getStaticTileInterchange(); + if (permutation.size() > iterSpaceRank) + return emitOpError() + << "interchange length exceeds iteration space dimensions (" + << iterSpaceRank << "), found " << getTileInterchange(); + SmallVector<bool> seen(iterSpaceRank, false); + for (int64_t v : permutation) { + if (!ShapedType::isDynamic(v)) { + if (v < 0 || v >= static_cast<int64_t>(iterSpaceRank)) + return emitOpError() << "expects interchange values to be in range [0, " + << iterSpaceRank << "), found: " << v; + if (seen[v]) + return emitOpError() << "found duplicate interchange value: " << v; + seen[v] = true; + } } - SmallVector<int64_t> sizes = - extractFromIntegerArrayAttr<int64_t>(getTileSizes()); - size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0); + ArrayRef<int64_t> sizes = getStaticTileSizes(); + size_t numExpectedLoops = + getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0); if (numExpectedLoops != getNumResults() - 1) return emitOpError() << "expects " << numExpectedLoops << " loop results"; return success(); } +SmallVector<OpFoldResult> transform::FuseOp::getMixedTileSizes() { + return getMixedValues(getStaticTileSizes(), getTileSizes(), getContext()); +} + +SmallVector<OpFoldResult> transform::FuseOp::getMixedTileInterchange() { + return getMixedValues(getStaticTileInterchange(), getTileInterchange(), + getContext()); +} + +void transform::FuseOp::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + consumesHandle(getTargetMutable(), effects); + onlyReadsHandle(getTileSizesMutable(), effects); + onlyReadsHandle(getTileInterchangeMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + modifiesPayload(effects); +} + //===----------------------------------------------------------------------===// // FuseIntoContainingOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 0dac688..eb2d825 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -1134,22 +1134,45 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape, LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( linalg::PackOp packOp, PatternRewriter &rewriter) const { - // TODO: support the case that outer dimensions are not all 1s.
A - // tensor.expand_shape will be generated in this case. - if (llvm::any_of(packOp.getAllOuterDims(), + if (llvm::any_of(packOp.getTiledOuterDims(), [](int64_t dim) { return dim != 1; })) { return rewriter.notifyMatchFailure( packOp, "not all outer dimensions of the result are 1s"); } + ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos(); + auto outerDimsPerm = packOp.getOuterDimsPerm(); + + // Verify that no outer dim that is both non-unit and un-tiled is permuted. + // Supporting such cases would require refining the logic that generates the + // Transpose Op. + int64_t prev = 0; + if (!llvm::all_of(outerDimsPerm, [&innerDimsPos, &packOp, + &prev](int64_t dim) { + // Skip tiled dims - these can be permuted. + if (llvm::is_contained(innerDimsPos, dim)) + return true; + + // Check whether this dim has been permuted. Permuting unit dims is fine + // as that's effectively a no-op. + if (dim < prev && (packOp.getType().getShape()[prev] != 1 || + packOp.getType().getShape()[dim] != 1)) + return false; + + prev = dim; + return true; + })) { + return rewriter.notifyMatchFailure( + packOp, "at least one non-unit and un-tiled outer dim is permuted; " + "this is not supported yet"); + } + Attribute zeroIdxAttr = rewriter.getIndexAttr(0); Attribute oneIdxAttr = rewriter.getIndexAttr(1); Location loc = packOp.getLoc(); int64_t srcRank = packOp.getSourceRank(); int64_t destRank = packOp.getDestRank(); - ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos(); - int64_t numberOfTiles = innerDimsPos.size(); // 1. Get the input that is going to be packed. If the input requires padding, // add a padding operation and return that as the input. @@ -1160,10 +1183,13 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( // %transposed_tile = linalg.transpose ins(%source_or_padded_source), // outs(%init) // Assumptions made: - // - All outer dims are 1 - the corresponding transposition order doesn't - // matter, but requires all dim indices to be present. + // - All tiled outer dims are 1 - the corresponding transposition order + // doesn't matter, but requires all dim indices to be present. + // - Un-tiled outer dims remain un-permuted. - // 2.1 Get the permutation for linalg.transpose + // 2.1 Get the permutation for linalg.transpose: + // [ untiled-dims, inner-dims-pos ] + // Note, this logic assumes that the untiled dims are not permuted. SmallVector<int64_t> srcPermForTranspose; for (int64_t i = 0; i < srcRank; i++) { // We assume the `k` dimensions of the inner dim position, where `k` is the } srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end()); - // 2.2 Create the init tensor for linalg.transpose with the correct shape - SmallVector<OpFoldResult> shapeForEmptyOp(srcRank - numberOfTiles, - oneIdxAttr); + // 2.2 Create the init tensor for linalg.transpose with the correct shape: + // [ untiled-dims, tiled-dims ] + ShapedType inputTy = cast<ShapedType>(input.getType()); + SmallVector<OpFoldResult> shapeForEmptyOp; + for (int64_t i = 0; i < srcRank; i++) { + if (llvm::is_contained(innerDimsPos, i)) { + // The tiled dims are appended after this loop.
+ continue; + } + if (inputTy.isStaticDim(i)) + shapeForEmptyOp.push_back(rewriter.getIndexAttr(inputTy.getShape()[i])); + else + shapeForEmptyOp.emplace_back( + tensor::DimOp::create(rewriter, loc, input, i).getResult()); + } shapeForEmptyOp.append(packOp.getMixedTiles()); // getMixedTiles() may contain Values pointing to constant ops, not the @@ -1204,25 +1242,36 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty, srcPermForTranspose); - // 3. Insert the inner tile to the destination: + // 3. Insert the inner tile into the destination tensor: // %inserted_tile = tensor.insert_slice(%transposed_tile) - SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr); - SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr); - // Outer dims are all 1s! - SmallVector<OpFoldResult> writeSizes(destRank - numberOfTiles, oneIdxAttr); - SmallVector<int64_t> writeShape; + + // Compute the sizes attribute: + // [ outer-dims, tile-sizes ] + // Note that the output from the transpose Op excludes the tiled outer dims. + // However, given the assumption that: + // * all tiled outer dims == 1, + // we can just use a rank-expanding tensor.insert_slice. + SmallVector<OpFoldResult> writeSizes; + for (auto size : packOp.getAllOuterDims()) { + writeSizes.push_back(rewriter.getIndexAttr(size)); + } for (auto tileSize : packOp.getMixedTiles()) { - auto [tileSizeStatic, tileSizeOfr] = + auto [_, tileSizeOfr] = getSimplifiedOfrAndStaticSizePair(tileSize, rewriter); writeSizes.push_back(tileSizeOfr); - writeShape.push_back(tileSizeStatic); } - // 4. Replace tensor.packOp with tensor.insert_slice created above + // TODO: Add a constructor for tensor.insert_slice that doesn't require + // strides nor offsets. + SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr); + SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr); + auto insert = tensor::InsertSliceOp::create( rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets, writeSizes, writeStrides); + + // 4. 
Replace tensor.packOp with tensor.insert_slice created above rewriter.replaceOp(packOp, insert.getResult()); return success(); diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt index e25a012..1382c7ac 100644 --- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt @@ -5,7 +5,7 @@ add_mlir_dialect_library(MLIRMemRefDialect ValueBoundsOpInterfaceImpl.cpp ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect + ${PROJECT_SOURCE_DIR}/include/mlir/Dialect/MemRef/IR DEPENDS MLIRMemRefOpsIncGen @@ -18,6 +18,7 @@ add_mlir_dialect_library(MLIRMemRefDialect MLIRDialectUtils MLIRInferIntRangeCommon MLIRInferIntRangeInterface + MLIRInferStridedMetadataInterface MLIRInferTypeOpInterface MLIRIR MLIRMemOpInterfaces diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index e9bdcda..507597b 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -3437,6 +3437,65 @@ SubViewOp::bubbleDownCasts(OpBuilder &builder) { return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable()); } +void SubViewOp::inferStridedMetadataRanges( + ArrayRef<StridedMetadataRange> ranges, GetIntRangeFn getIntRange, + SetStridedMetadataRangeFn setMetadata, int32_t indexBitwidth) { + auto isUninitialized = + +[](IntegerValueRange range) { return range.isUninitialized(); }; + + // Bail early if any of the operands' metadata is not ready: + SmallVector<IntegerValueRange> offsetOperands = + getIntValueRanges(getMixedOffsets(), getIntRange, indexBitwidth); + if (llvm::any_of(offsetOperands, isUninitialized)) + return; + + SmallVector<IntegerValueRange> sizeOperands = + getIntValueRanges(getMixedSizes(), getIntRange, indexBitwidth); + if (llvm::any_of(sizeOperands, isUninitialized)) + return; + + SmallVector<IntegerValueRange> stridesOperands = + getIntValueRanges(getMixedStrides(), getIntRange, indexBitwidth); + if (llvm::any_of(stridesOperands, isUninitialized)) + return; + + StridedMetadataRange sourceRange = + ranges[getSourceMutable().getOperandNumber()]; + if (sourceRange.isUninitialized()) + return; + + ArrayRef<ConstantIntRanges> srcStrides = sourceRange.getStrides(); + + // Get the dropped dims. + llvm::SmallBitVector droppedDims = getDroppedDims(); + + // Compute the new offset, strides and sizes. + ConstantIntRanges offset = sourceRange.getOffsets()[0]; + SmallVector<ConstantIntRanges> strides, sizes; + + for (size_t i = 0, e = droppedDims.size(); i < e; ++i) { + bool dropped = droppedDims.test(i); + // Compute the new offset. + ConstantIntRanges off = + intrange::inferMul({offsetOperands[i].getValue(), srcStrides[i]}); + offset = intrange::inferAdd({offset, off}); + + // Skip dropped dimensions. + if (dropped) + continue; + // Multiply the strides. + strides.push_back( + intrange::inferMul({stridesOperands[i].getValue(), srcStrides[i]})); + // Get the sizes.
+ sizes.push_back(sizeOperands[i].getValue()); + } + + setMetadata(getResult(), + StridedMetadataRange::getRanked( + SmallVector<ConstantIntRanges>({std::move(offset)}), + std::move(sizes), std::move(strides))); +} + //===----------------------------------------------------------------------===// // TransposeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 6564a4e..642ced9 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SmallSet.h" @@ -74,14 +75,16 @@ struct MemRefPointerLikeModel } mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc, - StringRef varName, Type varType, - Value originalVar) const { + StringRef varName, Type varType, Value originalVar, + bool &needsFree) const { auto memrefTy = cast<MemRefType>(pointer); // Check if this is a static memref (all dimensions are known) - if yes // then we can generate an alloca operation. - if (memrefTy.hasStaticShape()) + if (memrefTy.hasStaticShape()) { + needsFree = false; // alloca doesn't need deallocation return memref::AllocaOp::create(builder, loc, memrefTy).getResult(); + } // For dynamic memrefs, extract sizes from the original variable if // provided. Otherwise they cannot be handled. @@ -99,6 +102,7 @@ struct MemRefPointerLikeModel // Note: We only add dynamic sizes to the dynamicSizes array // Static dimensions are handled automatically by AllocOp } + needsFree = true; // alloc needs deallocation return memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes) .getResult(); } @@ -108,10 +112,14 @@ struct MemRefPointerLikeModel } bool genFree(Type pointer, OpBuilder &builder, Location loc, - TypedValue<PointerLikeType> varPtr, Type varType) const { - if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varPtr)) { + TypedValue<PointerLikeType> varToFree, Value allocRes, + Type varType) const { + if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varToFree)) { + // Use allocRes if provided to determine the allocation type + Value valueToInspect = allocRes ? allocRes : memrefValue; + // Walk through casts to find the original allocation - Value currentValue = memrefValue; + Value currentValue = valueToInspect; Operation *originalAlloc = nullptr; // Follow the chain of operations to find the original allocation @@ -150,7 +158,7 @@ struct MemRefPointerLikeModel return true; } if (isa<memref::AllocOp>(originalAlloc)) { - // This is an alloc - generate dealloc + // This is an alloc - generate dealloc on varToFree memref::DeallocOp::create(builder, loc, memrefValue); return true; } @@ -1003,6 +1011,142 @@ struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> { } }; +//===----------------------------------------------------------------------===// +// Recipe Region Helpers +//===----------------------------------------------------------------------===// + +/// Create and populate an init region for privatization recipes. +/// Returns the init block on success, or nullptr on failure. +/// Sets needsFree to indicate if the allocated memory requires deallocation. 
+static std::unique_ptr<Block> createInitRegion(OpBuilder &builder, Location loc, + Type varType, StringRef varName, + ValueRange bounds, + bool &needsFree) { + // Create init block with arguments: original value + bounds + SmallVector<Type> argTypes{varType}; + SmallVector<Location> argLocs{loc}; + for (Value bound : bounds) { + argTypes.push_back(bound.getType()); + argLocs.push_back(loc); + } + + auto initBlock = std::make_unique<Block>(); + initBlock->addArguments(argTypes, argLocs); + builder.setInsertionPointToStart(initBlock.get()); + + Value privatizedValue; + + // Get the block argument that represents the original variable + Value blockArgVar = initBlock->getArgument(0); + + // Generate init region body based on variable type + if (isa<MappableType>(varType)) { + auto mappableTy = cast<MappableType>(varType); + auto typedVar = cast<TypedValue<MappableType>>(blockArgVar); + privatizedValue = mappableTy.generatePrivateInit( + builder, loc, typedVar, varName, bounds, {}, needsFree); + if (!privatizedValue) + return nullptr; + } else { + assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType"); + auto pointerLikeTy = cast<PointerLikeType>(varType); + // Use PointerLikeType's allocation API with the block argument + privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType, + blockArgVar, needsFree); + if (!privatizedValue) + return nullptr; + } + + // Add yield operation to init block + acc::YieldOp::create(builder, loc, privatizedValue); + + return initBlock; +} + +/// Create and populate a copy region for firstprivate recipes. +/// Returns the copy block on success, or nullptr on failure. +/// TODO: Handle MappableType - it does not yet have a copy API. +static std::unique_ptr<Block> createCopyRegion(OpBuilder &builder, Location loc, + Type varType, + ValueRange bounds) { + // Create copy block with arguments: original value + privatized value + + // bounds + SmallVector<Type> copyArgTypes{varType, varType}; + SmallVector<Location> copyArgLocs{loc, loc}; + for (Value bound : bounds) { + copyArgTypes.push_back(bound.getType()); + copyArgLocs.push_back(loc); + } + + auto copyBlock = std::make_unique<Block>(); + copyBlock->addArguments(copyArgTypes, copyArgLocs); + builder.setInsertionPointToStart(copyBlock.get()); + + bool isMappable = isa<MappableType>(varType); + bool isPointerLike = isa<PointerLikeType>(varType); + // TODO: Handle MappableType - it does not yet have a copy API. + // Otherwise, for now just fallback to pointer-like behavior. + if (isMappable && !isPointerLike) + return nullptr; + + // Generate copy region body based on variable type + if (isPointerLike) { + auto pointerLikeTy = cast<PointerLikeType>(varType); + Value originalArg = copyBlock->getArgument(0); + Value privatizedArg = copyBlock->getArgument(1); + + // Generate copy operation using PointerLikeType interface + if (!pointerLikeTy.genCopy( + builder, loc, cast<TypedValue<PointerLikeType>>(privatizedArg), + cast<TypedValue<PointerLikeType>>(originalArg), varType)) + return nullptr; + } + + // Add terminator to copy block + acc::TerminatorOp::create(builder, loc); + + return copyBlock; +} + +/// Create and populate a destroy region for privatization recipes. +/// Returns the destroy block on success, or nullptr if not needed. 
+static std::unique_ptr<Block> createDestroyRegion(OpBuilder &builder, + Location loc, Type varType, + Value allocRes, + ValueRange bounds) { + // Create destroy block with arguments: original value + privatized value + + // bounds + SmallVector<Type> destroyArgTypes{varType, varType}; + SmallVector<Location> destroyArgLocs{loc, loc}; + for (Value bound : bounds) { + destroyArgTypes.push_back(bound.getType()); + destroyArgLocs.push_back(loc); + } + + auto destroyBlock = std::make_unique<Block>(); + destroyBlock->addArguments(destroyArgTypes, destroyArgLocs); + builder.setInsertionPointToStart(destroyBlock.get()); + + bool isMappable = isa<MappableType>(varType); + bool isPointerLike = isa<PointerLikeType>(varType); + // TODO: Handle MappableType - it does not yet have a deallocation API. + // Otherwise, for now just fallback to pointer-like behavior. + if (isMappable && !isPointerLike) + return nullptr; + + assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType"); + auto pointerLikeTy = cast<PointerLikeType>(varType); + auto privatizedArg = + cast<TypedValue<PointerLikeType>>(destroyBlock->getArgument(1)); + // Pass allocRes to help determine the allocation type + if (!pointerLikeTy.genFree(builder, loc, privatizedArg, allocRes, varType)) + return nullptr; + + acc::TerminatorOp::create(builder, loc); + + return destroyBlock; +} + } // namespace //===----------------------------------------------------------------------===// @@ -1050,6 +1194,55 @@ LogicalResult acc::PrivateRecipeOp::verifyRegions() { return success(); } +std::optional<PrivateRecipeOp> +PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc, + StringRef recipeName, Type varType, + StringRef varName, ValueRange bounds) { + // First, validate that we can handle this variable type + bool isMappable = isa<MappableType>(varType); + bool isPointerLike = isa<PointerLikeType>(varType); + + // Unsupported type + if (!isMappable && !isPointerLike) + return std::nullopt; + + // Create init and destroy blocks using shared helpers + OpBuilder::InsertionGuard guard(builder); + + // Save the original insertion point for creating the recipe operation later + auto originalInsertionPoint = builder.saveInsertionPoint(); + + bool needsFree = false; + auto initBlock = + createInitRegion(builder, loc, varType, varName, bounds, needsFree); + if (!initBlock) + return std::nullopt; + + // Only create destroy region if the allocation needs deallocation + std::unique_ptr<Block> destroyBlock; + if (needsFree) { + // Extract the allocated value from the init block's yield operation + auto yieldOp = cast<acc::YieldOp>(initBlock->getTerminator()); + Value allocRes = yieldOp.getOperand(0); + + destroyBlock = createDestroyRegion(builder, loc, varType, allocRes, bounds); + if (!destroyBlock) + return std::nullopt; + } + + // Now create the recipe operation at the original insertion point and attach + // the blocks + builder.restoreInsertionPoint(originalInsertionPoint); + auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType); + + // Move the blocks into the recipe's regions + recipe.getInitRegion().push_back(initBlock.release()); + if (destroyBlock) + recipe.getDestroyRegion().push_back(destroyBlock.release()); + + return recipe; +} + //===----------------------------------------------------------------------===// // FirstprivateRecipeOp //===----------------------------------------------------------------------===// @@ -1080,6 +1273,60 @@ LogicalResult acc::FirstprivateRecipeOp::verifyRegions() { return success(); 
} +std::optional<FirstprivateRecipeOp> +FirstprivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc, + StringRef recipeName, Type varType, + StringRef varName, ValueRange bounds) { + // First, validate that we can handle this variable type + bool isMappable = isa<MappableType>(varType); + bool isPointerLike = isa<PointerLikeType>(varType); + + // Unsupported type + if (!isMappable && !isPointerLike) + return std::nullopt; + + // Create init, copy, and destroy blocks using shared helpers + OpBuilder::InsertionGuard guard(builder); + + // Save the original insertion point for creating the recipe operation later + auto originalInsertionPoint = builder.saveInsertionPoint(); + + bool needsFree = false; + auto initBlock = + createInitRegion(builder, loc, varType, varName, bounds, needsFree); + if (!initBlock) + return std::nullopt; + + auto copyBlock = createCopyRegion(builder, loc, varType, bounds); + if (!copyBlock) + return std::nullopt; + + // Only create destroy region if the allocation needs deallocation + std::unique_ptr<Block> destroyBlock; + if (needsFree) { + // Extract the allocated value from the init block's yield operation + auto yieldOp = cast<acc::YieldOp>(initBlock->getTerminator()); + Value allocRes = yieldOp.getOperand(0); + + destroyBlock = createDestroyRegion(builder, loc, varType, allocRes, bounds); + if (!destroyBlock) + return std::nullopt; + } + + // Now create the recipe operation at the original insertion point and attach + // the blocks + builder.restoreInsertionPoint(originalInsertionPoint); + auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType); + + // Move the blocks into the recipe's regions + recipe.getInitRegion().push_back(initBlock.release()); + recipe.getCopyRegion().push_back(copyBlock.release()); + if (destroyBlock) + recipe.getDestroyRegion().push_back(destroyBlock.release()); + + return recipe; +} + //===----------------------------------------------------------------------===// // ReductionRecipeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index fa97b49..ac72002 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -2310,6 +2310,7 @@ RankedTensorType ExtractSliceOp::inferResultType( sourceTensorType.getEncoding()); } +// TODO: This uses neither offsets nor strides! RankedTensorType ExtractSliceOp::inferResultType( RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index e95338f..12e6475 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -928,17 +928,20 @@ struct WarpOpDeadResult : public WarpDistributionPattern { // Some values may be yielded multiple times and correspond to multiple // results. Deduplicating occurs by taking each result with its matching // yielded value, and: - // 1. recording the unique first position at which the value is yielded. + // 1. recording the unique first position at which the value with uses is + // yielded. // 2. recording for the result, the first position at which the dedup'ed // value is yielded. // 3. 
skipping from the new result types / new yielded values any result // that has no use or whose yielded value has already been seen. for (OpResult result : warpOp.getResults()) { + if (result.use_empty()) + continue; Value yieldOperand = yield.getOperand(result.getResultNumber()); auto it = dedupYieldOperandPositionMap.insert( std::make_pair(yieldOperand, newResultTypes.size())); dedupResultPositionMap.insert(std::make_pair(result, it.first->second)); - if (result.use_empty() || !it.second) + if (!it.second) continue; newResultTypes.push_back(result.getType()); newYieldValues.push_back(yieldOperand); @@ -1843,16 +1846,16 @@ struct WarpOpScfIfOp : public WarpDistributionPattern { newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(), escapingValueDistTypesElse.end()); - llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx; for (auto [idx, val] : llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) { - origToNewYieldIdx[idx] = newWarpOpYieldValues.size(); newWarpOpYieldValues.push_back(val); newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType()); } - // Create the new `WarpOp` with the updated yield values and types. - WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( - rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes); + // Replace the old `WarpOp` with the new one that has additional yield + // values and types. + SmallVector<size_t> newIndices; + WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices); // `ifOp` returns the result of the inner warp op. SmallVector<Type> newIfOpDistResTypes; for (auto [i, res] : llvm::enumerate(ifOp.getResults())) { @@ -1870,8 +1873,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern { OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(newWarpOp); auto newIfOp = scf::IfOp::create( - rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0), - static_cast<bool>(ifOp.thenBlock()), + rewriter, ifOp.getLoc(), newIfOpDistResTypes, + newWarpOp.getResult(newIndices[0]), static_cast<bool>(ifOp.thenBlock()), static_cast<bool>(ifOp.elseBlock())); auto encloseRegionInWarpOp = [&](Block *oldIfBranch, Block *newIfBranch, @@ -1888,7 +1891,7 @@ struct WarpOpScfIfOp : public WarpDistributionPattern { for (size_t i = 0; i < escapingValues.size(); ++i, ++warpResRangeStart) { innerWarpInputVals.push_back( - newWarpOp.getResult(warpResRangeStart)); + newWarpOp.getResult(newIndices[warpResRangeStart])); escapeValToBlockArgIndex[escapingValues[i]] = innerWarpInputTypes.size(); innerWarpInputTypes.push_back(escapingValueInputTypes[i]); @@ -1936,17 +1939,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern { // Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp` // result. for (auto [origIdx, newIdx] : ifResultMapping) - rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx), + rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx), newIfOp.getResult(newIdx), newIfOp); - // Similarly, update any users of the `WarpOp` results that were not - // results of the `IfOp`. - for (auto [origIdx, newIdx] : origToNewYieldIdx) - rewriter.replaceAllUsesWith(warpOp.getResult(origIdx), - newWarpOp.getResult(newIdx)); - // Remove the original `WarpOp` and `IfOp`, they should not have any uses - // at this point. 
- rewriter.eraseOp(ifOp); - rewriter.eraseOp(warpOp); return success(); } @@ -2065,19 +2059,16 @@ struct WarpOpScfForOp : public WarpDistributionPattern { escapingValueDistTypes.begin(), escapingValueDistTypes.end()); // Next, we insert all non-`ForOp` yielded values and their distributed - // types. We also create a mapping between the non-`ForOp` yielded value - // index and the corresponding new `WarpOp` yield value index (needed to - // update users later). - llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping; + // types. for (auto [i, v] : llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) { - nonForResultMapping[i] = newWarpOpYieldValues.size(); newWarpOpYieldValues.push_back(v); newWarpOpDistTypes.push_back(warpOp.getResult(i).getType()); } // Create the new `WarpOp` with the updated yield values and types. - WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( - rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes); + SmallVector<size_t> newIndices; + WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices); // Next, we create a new `ForOp` with the init args yielded by the new // `WarpOp`. @@ -2086,7 +2077,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern { // escaping values in the new `WarpOp`. SmallVector<Value> newForOpOperands; for (size_t i = 0; i < escapingValuesStartIdx; ++i) - newForOpOperands.push_back(newWarpOp.getResult(i)); + newForOpOperands.push_back(newWarpOp.getResult(newIndices[i])); // Create a new `ForOp` outside the new `WarpOp` region. OpBuilder::InsertionGuard g(rewriter); @@ -2110,7 +2101,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern { llvm::SmallDenseMap<Value, int64_t> argIndexMapping; for (size_t i = escapingValuesStartIdx; i < escapingValuesStartIdx + escapingValues.size(); ++i) { - innerWarpInput.push_back(newWarpOp.getResult(i)); + innerWarpInput.push_back(newWarpOp.getResult(newIndices[i])); argIndexMapping[escapingValues[i - escapingValuesStartIdx]] = innerWarpInputType.size(); innerWarpInputType.push_back( @@ -2146,20 +2137,11 @@ struct WarpOpScfForOp : public WarpDistributionPattern { if (!innerWarp.getResults().empty()) scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults()); - // Update the users of original `WarpOp` results that were coming from the + // Update the users of the new `WarpOp` results that were coming from the // original `ForOp` to the corresponding new `ForOp` result. for (auto [origIdx, newIdx] : forResultMapping) - rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx), + rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx), newForOp.getResult(newIdx), newForOp); - // Similarly, update any users of the `WarpOp` results that were not - // results of the `ForOp`. - for (auto [origIdx, newIdx] : nonForResultMapping) - rewriter.replaceAllUsesWith(warpOp.getResult(origIdx), - newWarpOp.getResult(newIdx)); - // Remove the original `WarpOp` and `ForOp`, they should not have any uses - // at this point. - rewriter.eraseOp(forOp); - rewriter.eraseOp(warpOp); // Update any users of escaping values that were forwarded to the // inner `WarpOp`. These values are now arguments of the inner `WarpOp`. 
newForOp.walk([&](Operation *op) { diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp index 36c498e..f77784a 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp @@ -161,11 +161,24 @@ XeGPUBlockingPass::getTileShape(Operation *op) const { xegpu::UpdateOffsetOp, xegpu::LoadMatrixOp>(op)) return getTileShape(op->getOpResult(0)); if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp, - xegpu::LoadGatherOp, xegpu::StoreMatrixOp>(op)) + xegpu::StoreMatrixOp>(op)) return getTileShape(op->getOpOperand(0)); - if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(op)) + if (isa<xegpu::StoreNdOp>(op)) return getTileShape(op->getOpOperand(1)); + // Handle LoadGatherOp and StoreScatterOp (with and without offset). + if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(op)) { + if (loadGatherOp.getOffsets()) + return getTileShape(loadGatherOp->getOpResult(0)); + return getTileShape(loadGatherOp->getOpOperand(0)); + } + + if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op)) + return getTileShape(storeScatterOp.getOffsets() ? storeScatterOp->getOpOperand(0) : storeScatterOp->getOpOperand(1)); + if (isa<xegpu::DpasOp>(op)) { std::optional<SmallVector<int64_t>> aTile = getTileShape(op->getOpOperand(0));
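
Editorial note: the sketches below are standalone illustrations of techniques used in this diff; they are not part of the patch and use plain C++ stand-ins for MLIR types.

The AMX changes make the tile stride either explicit (a new optional `stride` operand, in elements) or inferred from the memref. A minimal sketch of that selection logic, with plain integers in place of MLIR Values (`selectStrideInBytes` is an illustrative name, not the patch's API):

#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Scale an element-count stride to bytes (bit width assumed divisible by 8).
static int64_t computeStrideInBytes(int64_t elementStride,
                                    unsigned elementBitWidth) {
  return elementStride * (elementBitWidth / 8);
}

// An explicit stride (in elements) wins; otherwise fall back to the memref's
// second-to-last dimension stride, mirroring inferStride above.
static int64_t selectStrideInBytes(std::optional<int64_t> explicitStride,
                                   const std::vector<int64_t> &memrefStrides,
                                   unsigned elementBitWidth) {
  if (explicitStride)
    return computeStrideInBytes(*explicitStride, elementBitWidth);
  size_t preLast = memrefStrides.size() - 2;
  return computeStrideInBytes(memrefStrides[preLast], elementBitWidth);
}

int main() {
  // memref<16x64xbf16>, row-major strides [64, 1]: 64 elements * 2 B = 128 B.
  assert(selectStrideInBytes(std::nullopt, {64, 1}, 16) == 128);
  // An explicit stride of 32 bf16 elements overrides the inferred one: 64 B.
  assert(selectStrideInBytes(32, {64, 1}, 16) == 64);
}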
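The optional-integer helpers in LLVMDialectBytecode.cpp encode presence as a flag carried alongside a varint. A standalone model of that round trip, assuming a flag-in-the-low-bit packing (MLIR's actual writeVarIntWithFlag wire format may differ; this only illustrates the presence/absence logic):

#include <cassert>
#include <cstdint>
#include <optional>
#include <type_traits>

// Pack (value, flag) into one integer, flag in the low bit; values are
// assumed to fit in 63 bits for this sketch.
static uint64_t encodeVarIntWithFlag(uint64_t value, bool flag) {
  return (value << 1) | static_cast<uint64_t>(flag);
}

static void decodeVarIntWithFlag(uint64_t encoded, uint64_t &value,
                                 bool &flag) {
  flag = encoded & 1;
  value = encoded >> 1;
}

// Mirrors writeOptionalInt/readOptionalInt: absent values travel as 0 with a
// false flag, and the reader restores std::nullopt.
template <class EntryTy>
static std::optional<EntryTy> roundTrip(std::optional<EntryTy> storage) {
  static_assert(std::is_integral_v<EntryTy>, "EntryTy must be integral");
  uint64_t wire =
      encodeVarIntWithFlag(storage.value_or(0), storage.has_value());
  uint64_t value;
  bool flag;
  decodeVarIntWithFlag(wire, value, flag);
  if (!flag)
    return std::nullopt;
  return static_cast<EntryTy>(value);
}

int main() {
  assert(roundTrip<int32_t>(42) == std::optional<int32_t>(42));
  assert(!roundTrip<int32_t>(std::nullopt).has_value());
}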
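PackOp::getTiledOuterDims in the Linalg change first undoes outer_dims_perm so that inner_dims_pos, which indexes the original dim order, lands on the right outer dims. A sketch of that inversion step, with a local stand-in for invertPermutationVector/applyPermutationToVector (the real helpers' exact conventions may differ):

#include <cassert>
#include <cstdint>
#include <vector>

// Local stand-in for invertPermutationVector: inv maps each original
// position back to its packed position.
static std::vector<int64_t>
invertPermutation(const std::vector<int64_t> &perm) {
  std::vector<int64_t> inv(perm.size());
  for (size_t i = 0; i < perm.size(); ++i)
    inv[perm[i]] = static_cast<int64_t>(i);
  return inv;
}

int main() {
  // Original outer dims [4, 8], written with outer_dims_perm = [1, 0], so
  // the packed layout stores them as [8, 4].
  std::vector<int64_t> outerDims = {8, 4};
  std::vector<int64_t> outerDimsPerm = {1, 0};
  std::vector<int64_t> innerDimsPos = {0}; // indexes the *original* order

  // Undo the permutation: restored[i] = outerDims[inv[i]].
  std::vector<int64_t> inv = invertPermutation(outerDimsPerm);
  std::vector<int64_t> restored(outerDims.size());
  for (size_t i = 0; i < outerDims.size(); ++i)
    restored[i] = outerDims[inv[i]];

  // Tiled outer dims are then read off in original dim order.
  std::vector<int64_t> tiledOuterDims;
  for (int64_t pos : innerDimsPos)
    tiledOuterDims.push_back(restored[pos]);
  assert(tiledOuterDims == std::vector<int64_t>({4}));
}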
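SubViewOp::inferStridedMetadataRanges in the MemRef change composes offset, size, and stride ranges: every dim contributes offset_i * src_stride_i to the new offset, surviving dims multiply strides pairwise, and rank-reduced dims are skipped. The same arithmetic on plain integers instead of ConstantIntRanges (a sketch, not the interface's API):

#include <cassert>
#include <cstdint>
#include <vector>

struct StridedMetadata {
  int64_t offset;
  std::vector<int64_t> sizes;
  std::vector<int64_t> strides;
};

static StridedMetadata
inferSubviewMetadata(const StridedMetadata &src,
                     const std::vector<int64_t> &offsets,
                     const std::vector<int64_t> &sizes,
                     const std::vector<int64_t> &strides,
                     const std::vector<bool> &droppedDims) {
  StridedMetadata res{src.offset, {}, {}};
  for (size_t i = 0; i < droppedDims.size(); ++i) {
    // Every dim contributes to the offset, dropped or not.
    res.offset += offsets[i] * src.strides[i];
    if (droppedDims[i])
      continue; // rank-reduced dims produce no size/stride
    res.strides.push_back(strides[i] * src.strides[i]);
    res.sizes.push_back(sizes[i]);
  }
  return res;
}

int main() {
  // memref<4x8> with row-major strides [8, 1]; take a 2x3 subview at offset
  // (1, 2) with steps (2, 1) and no rank reduction.
  StridedMetadata src{0, {4, 8}, {8, 1}};
  StridedMetadata sub =
      inferSubviewMetadata(src, {1, 2}, {2, 3}, {2, 1}, {false, false});
  assert(sub.offset == 10); // 1 * 8 + 2 * 1
  assert((sub.strides == std::vector<int64_t>{16, 1}));
  assert((sub.sizes == std::vector<int64_t>{2, 3}));
}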