Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r--  mlir/lib/Dialect/AMX/IR/AMXDialect.cpp                                    |  99
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp |   5
-rw-r--r--  mlir/lib/Dialect/LLVMIR/CMakeLists.txt                                    |   2
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp                                |   3
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp                        | 154
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h                          |  27
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp                                |  28
-rw-r--r--  mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp                           |   3
-rw-r--r--  mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp                                  |  11
-rw-r--r--  mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp               | 153
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp                         |  89
-rw-r--r--  mlir/lib/Dialect/MemRef/IR/CMakeLists.txt                                 |   3
-rw-r--r--  mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp                                  |  59
-rw-r--r--  mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp                                   | 261
-rw-r--r--  mlir/lib/Dialect/Tensor/IR/TensorOps.cpp                                  |   1
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp                   |  62
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp                       |  17
17 files changed, 850 insertions, 127 deletions
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index 68990ef..d9c097c 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -80,10 +80,22 @@ static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType,
LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)};
}
+/// Returns the stride in bytes for the given `elementStride`, which is
+/// expressed in number of elements of `mType`'s element type.
+static Value computeStrideInBytes(Location loc, MemRefType mType,
+ Value elementStride, RewriterBase &rewriter) {
+ Type llvmInt64Type = rewriter.getIntegerType(64);
+ unsigned bytes = mType.getElementType().getIntOrFloatBitWidth() / 8;
+ auto attr = rewriter.getI64IntegerAttr(bytes);
+ Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
+ return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, elementStride)
+ .getResult();
+}
+
/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
/// shape may "envelop" the actual tile shape, and may be dynamically sized.
-static Value getStride(Location loc, MemRefType mType, Value base,
- RewriterBase &rewriter) {
+static Value inferStride(Location loc, MemRefType mType, Value base,
+ RewriterBase &rewriter) {
assert(mType.getRank() >= 2 && "Invalid shape for AMX strides");
int64_t preLast = mType.getRank() - 2;
Type llvmInt64Type = rewriter.getIntegerType(64);
@@ -94,11 +106,8 @@ static Value getStride(Location loc, MemRefType mType, Value base,
if (strides[preLast] == ShapedType::kDynamic) {
// Dynamic stride needs code to compute the stride at runtime.
MemRefDescriptor memrefDescriptor(base);
- auto attr = rewriter.getI64IntegerAttr(bytes);
- Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
- return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale,
- memrefDescriptor.stride(rewriter, loc, preLast))
- .getResult();
+ return computeStrideInBytes(
+ loc, mType, memrefDescriptor.stride(rewriter, loc, preLast), rewriter);
}
// Use direct constant for static stride.
auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
@@ -117,21 +126,39 @@ amx::TileZeroOp::getIntrinsicOperands(ArrayRef<Value> operands,
return getTileSizes(getLoc(), getTileType(), rewriter);
}
-LogicalResult amx::TileLoadOp::verify() {
- MemRefType memrefTy = getMemRefType();
+template <typename OpTy,
+ typename = std::enable_if_t<std::is_same_v<OpTy, amx::TileLoadOp> ||
+ std::is_same_v<OpTy, amx::TileStoreOp>>>
+static LogicalResult tileTransferVerifier(OpTy op) {
+ MemRefType memrefTy = op.getMemRefType();
unsigned rank = memrefTy.getRank();
- if (rank < 2)
- return emitOpError("requires at least 2D memref");
- if (getIndices().size() != rank)
- return emitOpError("requires ") << rank << " indices";
- SmallVector<int64_t> strides;
- int64_t offset;
- if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
- strides.back() != 1)
- return emitOpError("requires memref with unit innermost stride");
- return verifyTileSize(*this, getTileType());
+ if (op.getIndices().size() != rank)
+ return op.emitOpError("requires ") << rank << " indices";
+
+ if (failed(verifyTileSize(op, op.getTileType())))
+ return failure();
+
+ // Validate basic buffer properties when the stride is implicit.
+ if (!op.getStride()) {
+ if (rank < 2)
+ return op.emitOpError("requires at least 2D memref");
+ SmallVector<int64_t> strides;
+ int64_t offset;
+ if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
+ strides.back() != 1)
+ return op.emitOpError("requires memref with unit innermost stride");
+ }
+
+ return success();
+}
+
+void amx::TileLoadOp::build(OpBuilder &builder, OperationState &state, Type res,
+ Value base, ValueRange indices) {
+ build(builder, state, res, base, indices, /*stride=*/nullptr);
}
+LogicalResult amx::TileLoadOp::verify() { return tileTransferVerifier(*this); }
+
SmallVector<Value>
amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands,
const LLVMTypeConverter &typeConverter,
@@ -144,27 +171,23 @@ amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands,
intrinsicOperands.push_back(
LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
adaptor.getBase(), adaptor.getIndices()));
- intrinsicOperands.push_back(
- getStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
+ if (Value stride = adaptor.getStride())
+ intrinsicOperands.push_back(
+ computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
+ else
+ intrinsicOperands.push_back(
+ inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
return intrinsicOperands;
}
-LogicalResult amx::TileStoreOp::verify() {
- MemRefType memrefTy = getMemRefType();
- unsigned rank = memrefTy.getRank();
- if (rank < 2)
- return emitOpError("requires at least 2D memref");
- if (getIndices().size() != rank)
- return emitOpError("requires ") << rank << " indices";
- SmallVector<int64_t> strides;
- int64_t offset;
- if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
- strides.back() != 1)
- return emitOpError("requires memref with unit innermost stride");
- return verifyTileSize(*this, getTileType());
+void amx::TileStoreOp::build(OpBuilder &builder, OperationState &state,
+ Value base, ValueRange indices, Value val) {
+ build(builder, state, base, indices, val, /*stride=*/nullptr);
}
+LogicalResult amx::TileStoreOp::verify() { return tileTransferVerifier(*this); }
+
SmallVector<Value>
amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands,
const LLVMTypeConverter &typeConverter,
@@ -177,8 +200,12 @@ amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands,
intrinsicOperands.push_back(
LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
adaptor.getBase(), adaptor.getIndices()));
- intrinsicOperands.push_back(
- getStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
+ if (Value stride = adaptor.getStride())
+ intrinsicOperands.push_back(
+ computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
+ else
+ intrinsicOperands.push_back(
+ inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
intrinsicOperands.push_back(adaptor.getVal());
return intrinsicOperands;
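
For illustration, a minimal C++ sketch of how the two forms drive the new lowering (hypothetical values `b`, `loc`, `tileTy`, `baseMemref`, `indices`, and `elemStride`; the strided overload is assumed to be the tablegen-generated builder for the new optional operand):

  // Implicit stride: the lowering derives the byte stride from the memref
  // descriptor via inferStride().
  auto load = amx::TileLoadOp::create(b, loc, tileTy, baseMemref, indices);

  // Explicit stride, expressed in elements: the lowering scales it to bytes
  // via computeStrideInBytes(), and the verifier skips the
  // unit-innermost-stride requirement on the memref.
  auto strided =
      amx::TileLoadOp::create(b, loc, tileTy, baseMemref, indices, elemStride);
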
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
index 624519f..70faa71 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
@@ -64,12 +64,13 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
module.walk([&](func::CallOp callOp) {
if (func::FuncOp calledFunc =
dyn_cast_or_null<func::FuncOp>(callOp.resolveCallable())) {
- callerMap[calledFunc].insert(callOp);
+ if (!calledFunc.isPublic() && !calledFunc.isExternal())
+ callerMap[calledFunc].insert(callOp);
}
});
for (auto funcOp : module.getOps<func::FuncOp>()) {
- if (funcOp.isExternal())
+ if (funcOp.isExternal() || funcOp.isPublic())
continue;
func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
// TODO: Support functions with multiple blocks.
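
As a usage sketch (a hypothetical pass body, not part of this change), the entry point is unchanged; the new checks simply make the transform conservative about functions whose ABI is externally visible:

  void runOnOperation() {
    ModuleOp module = getOperation();
    // Private functions returning a buffer equivalent to one of their
    // arguments lose that result; public and external functions are skipped.
    if (failed(bufferization::dropEquivalentBufferResults(module)))
      signalPassFailure();
  }
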
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index ec581ac..cc66fac 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -8,11 +8,13 @@ add_mlir_dialect_library(MLIRLLVMDialect
IR/LLVMMemorySlot.cpp
IR/LLVMTypes.cpp
IR/LLVMTypeSyntax.cpp
+ IR/LLVMDialectBytecode.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR
DEPENDS
+ MLIRLLVMDialectBytecodeIncGen
MLIRLLVMOpsIncGen
MLIRLLVMTypesIncGen
MLIRLLVMIntrinsicOpsIncGen
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 5d08ccc..7ca09d9 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -29,6 +29,8 @@
#include "llvm/IR/DataLayout.h"
#include "llvm/Support/Error.h"
+#include "LLVMDialectBytecode.h"
+
#include <numeric>
#include <optional>
@@ -4237,6 +4239,7 @@ void LLVMDialect::initialize() {
// Support unknown operations because not all LLVM operations are registered.
allowUnknownOperations();
declarePromisedInterface<DialectInlinerInterface, LLVMDialect>();
+ detail::addBytecodeInterface(this);
}
#define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp
new file mode 100644
index 0000000..41d1f80
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp
@@ -0,0 +1,154 @@
+//===- LLVMDialectBytecode.cpp - LLVM Bytecode Implementation -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "LLVMDialectBytecode.h"
+#include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <type_traits>
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+namespace {
+
+// Provide some forward declarations of the functions that will be generated by
+// the include below.
+static void write(DIExpressionElemAttr attribute,
+ DialectBytecodeWriter &writer);
+static LogicalResult writeAttribute(Attribute attribute,
+ DialectBytecodeWriter &writer);
+
+//===--------------------------------------------------------------------===//
+// Optional ArrayRefs
+//
+// Note that both the writer and reader functions treat the array as optional:
+// an empty array is encoded as absent, so a leading flag records whether any
+// entries follow.
+//===--------------------------------------------------------------------===//
+
+template <class EntryTy>
+static void writeOptionalArrayRef(DialectBytecodeWriter &writer,
+ ArrayRef<EntryTy> storage) {
+ if (storage.empty()) {
+ writer.writeOwnedBool(false);
+ return;
+ }
+
+ writer.writeOwnedBool(true);
+ writer.writeList(storage, [&](EntryTy val) {
+ if constexpr (std::is_base_of_v<Attribute, EntryTy>) {
+ (void)writer.writeOptionalAttribute(val);
+ } else if constexpr (std::is_integral_v<EntryTy>) {
+ (void)writer.writeVarInt(val);
+ } else {
+      static_assert(!sizeof(EntryTy), "EntryTy not supported");
+ }
+ });
+}
+
+template <class EntryTy>
+static LogicalResult readOptionalArrayRef(DialectBytecodeReader &reader,
+ SmallVectorImpl<EntryTy> &storage) {
+ bool isPresent = false;
+ if (failed(reader.readBool(isPresent)))
+ return failure();
+ // Nothing to do here, the array is empty.
+ if (!isPresent)
+ return success();
+
+ auto readEntry = [&]() -> FailureOr<EntryTy> {
+ EntryTy temp;
+ if constexpr (std::is_base_of_v<Attribute, EntryTy>) {
+ if (succeeded(reader.readOptionalAttribute(temp)))
+ return temp;
+ } else if constexpr (std::is_integral_v<EntryTy>) {
+ if (succeeded(reader.readVarInt(temp)))
+ return temp;
+ } else {
+      static_assert(!sizeof(EntryTy), "EntryTy not supported");
+ }
+ return failure();
+ };
+
+ return reader.readList(storage, readEntry);
+}
+
+//===--------------------------------------------------------------------===//
+// Optional integral types
+//===--------------------------------------------------------------------===//
+
+template <class EntryTy>
+static void writeOptionalInt(DialectBytecodeWriter &writer,
+ std::optional<EntryTy> storage) {
+ static_assert(std::is_integral_v<EntryTy>,
+ "EntryTy must be an integral type");
+ EntryTy val = storage.value_or(0);
+ writer.writeVarIntWithFlag(val, storage.has_value());
+}
+
+template <class EntryTy>
+static LogicalResult readOptionalInt(DialectBytecodeReader &reader,
+ std::optional<EntryTy> &storage) {
+ static_assert(std::is_integral_v<EntryTy>,
+ "EntryTy must be an integral type");
+ uint64_t result = 0;
+ bool flag = false;
+ if (failed(reader.readVarIntWithFlag(result, flag)))
+ return failure();
+ if (flag)
+ storage = static_cast<EntryTy>(result);
+ else
+ storage = std::nullopt;
+ return success();
+}
+
+//===--------------------------------------------------------------------===//
+// Tablegen generated bytecode functions
+//===--------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/LLVMDialectBytecode.cpp.inc"
+
+//===--------------------------------------------------------------------===//
+// LLVMDialectBytecodeInterface
+//===--------------------------------------------------------------------===//
+
+/// This class implements the bytecode interface for the LLVM dialect.
+struct LLVMDialectBytecodeInterface : public BytecodeDialectInterface {
+ LLVMDialectBytecodeInterface(Dialect *dialect)
+ : BytecodeDialectInterface(dialect) {}
+
+ // Attributes
+ Attribute readAttribute(DialectBytecodeReader &reader) const override {
+ return ::readAttribute(getContext(), reader);
+ }
+
+ LogicalResult writeAttribute(Attribute attr,
+ DialectBytecodeWriter &writer) const override {
+ return ::writeAttribute(attr, writer);
+ }
+
+ // Types
+ Type readType(DialectBytecodeReader &reader) const override {
+ return ::readType(getContext(), reader);
+ }
+
+ LogicalResult writeType(Type type,
+ DialectBytecodeWriter &writer) const override {
+ return ::writeType(type, writer);
+ }
+};
+} // namespace
+
+void LLVM::detail::addBytecodeInterface(LLVMDialect *dialect) {
+ dialect->addInterfaces<LLVMDialectBytecodeInterface>();
+}
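
The optional-integer helpers hinge on the writer and reader agreeing on a single flag-plus-value convention. A self-contained C++ sketch of just that convention, independent of the MLIR bytecode reader/writer classes:

  #include <cstdint>
  #include <optional>
  #include <utility>

  // Encode: emit (value, has_value); an absent value is stored as 0.
  std::pair<uint64_t, bool> encodeOptional(std::optional<uint32_t> v) {
    return {v.value_or(0), v.has_value()};
  }

  // Decode: the flag alone decides presence, so a stored 0 round-trips
  // unambiguously as either "present 0" or "absent".
  std::optional<uint32_t> decodeOptional(uint64_t raw, bool flag) {
    return flag ? std::optional<uint32_t>(static_cast<uint32_t>(raw))
                : std::nullopt;
  }
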
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h
new file mode 100644
index 0000000..1a17cb4
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h
@@ -0,0 +1,27 @@
+//===- LLVMDialectBytecode.h - LLVM Bytecode Implementation -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header defines hooks into the LLVM dialect bytecode
+// implementation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LIB_MLIR_DIALECT_LLVM_IR_LLVMDIALECTBYTECODE_H
+#define LIB_MLIR_DIALECT_LLVM_IR_LLVMDIALECTBYTECODE_H
+
+namespace mlir::LLVM {
+class LLVMDialect;
+
+namespace detail {
+/// Add the interfaces necessary for encoding the LLVM dialect components in
+/// bytecode.
+void addBytecodeInterface(LLVMDialect *dialect);
+} // namespace detail
+} // namespace mlir::LLVM
+
+#endif // LIB_MLIR_DIALECT_LLVM_IR_LLVMDIALECTBYTECODE_H
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 5edcc40b..ab54183 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -309,6 +309,17 @@ LogicalResult ConvertBF16x2ToF8x2Op::verify() {
return success();
}
+LogicalResult ConvertF32x2ToF4x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
+ return emitOpError("Only ")
+ << mlir::Float4E2M1FNType::get(ctx)
+ << " type is supported for conversions from f32x2 to f4x2.";
+
+ return success();
+}
+
LogicalResult BulkStoreOp::verify() {
if (getInitVal() != 0)
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -2047,6 +2058,23 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
}
}
+NVVM::IDArgPair
+ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
+ LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(op.getA()));
+ args.push_back(mt.lookupValue(op.getB()));
+
+ bool hasRelu = op.getRelu();
+
+ llvm::Intrinsic::ID intId =
+ hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
+ : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
+
+ return {intId, std::move(args)};
+}
+
#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
: llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
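
For reference on the target format: f4E2M1FN values carry 1 sign, 2 exponent, and 1 mantissa bit. The rn_satfinite intrinsics round to nearest-even and saturate out-of-range inputs to the largest finite representable value, while the _relu variants additionally clamp negative results to zero.
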
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index c477c6c..dcc1ef9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -315,7 +315,8 @@ bool mlir::linalg::detail::isContractionBody(
Value yielded = getSourceSkipUnary(terminator->getOperand(0));
Operation *reductionOp = yielded.getDefiningOp();
- if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) {
+ if (!reductionOp || reductionOp->getNumResults() != 1 ||
+ reductionOp->getNumOperands() != 2) {
errs << "expected reduction op to be binary";
return false;
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 59013a2..cbc565b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5272,11 +5272,18 @@ ArrayRef<int64_t> PackOp::getAllOuterDims() {
SmallVector<int64_t> PackOp::getTiledOuterDims() {
auto innerDimsPos = getInnerDimsPos();
- auto packedShape = getDestType().getShape();
+ SmallVector<int64_t> outerDims(getAllOuterDims());
SmallVector<int64_t> res;
+ // Recover the original order of the outer dims.
+ SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
+ invertPermutationVector(outerDimPermInv);
+ if (!outerDimPermInv.empty())
+ applyPermutationToVector(outerDims, outerDimPermInv);
+
+  // Collect the outer dims corresponding to the tiled inner dims.
for (auto index : innerDimsPos)
- res.push_back(packedShape[index]);
+ res.push_back(outerDims[index]);
return res;
}
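
A concrete example of why the inverse permutation matters here: packing tensor<8x16xf32> with inner_dims_pos = [0], inner_tiles = [4], and outer_dims_perm = [1, 0] yields a tensor<16x2x4xf32> destination, so getAllOuterDims() returns [16, 2] in permuted order. Inverting the permutation recovers [2, 16] in source-dim order, and innerDimsPos = [0] then selects the tiled outer dim 2; indexing the packed shape directly, as the old code did, would have returned 16.
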
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index dd9b4c2..d8f983f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -576,6 +576,86 @@ transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
// FuseOp
//===----------------------------------------------------------------------===//
+void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
+ TypeRange loopTypes, Value target,
+ ArrayRef<int64_t> staticTileSizes,
+ ArrayRef<int64_t> staticTileInterchange,
+ bool applyCleanup, bool useForall) {
+ return build(
+ builder, result, loopTypes,
+ /*target=*/target,
+ /*mixedTileSizes=*/
+ getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
+ /*mixedTileInterchange=*/
+ getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)),
+ applyCleanup, useForall);
+}
+
+void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
+ Value target, ArrayRef<int64_t> staticTileSizes,
+ ArrayRef<int64_t> staticTileInterchange,
+ bool applyCleanup, bool useForall) {
+ return build(
+ builder, result,
+ /*target=*/target,
+ /*mixedTileSizes=*/
+ getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
+ /*mixedTileInterchange=*/
+ getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)),
+ applyCleanup, useForall);
+}
+
+void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
+ Value target,
+ ArrayRef<OpFoldResult> mixedTileSizes,
+ ArrayRef<OpFoldResult> mixedTileInterchange,
+ bool applyCleanup, bool useForall) {
+  // Loop types are automatically splat by the callee; setting up one is
+  // enough.
+ SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
+ build(builder, result, loopTypes, target, mixedTileSizes,
+ mixedTileInterchange, applyCleanup, useForall);
+}
+
+void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
+ TypeRange loopTypes, Value target,
+ ArrayRef<OpFoldResult> mixedTileSizes,
+ ArrayRef<OpFoldResult> mixedTileInterchange,
+ bool applyCleanup, bool useForall) {
+ SmallVector<int64_t> staticTileSizes;
+ SmallVector<Value> dynamicTileSizes;
+ dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
+ SmallVector<int64_t> staticTileInterchange;
+ SmallVector<Value> dynamicTileInterchange;
+ dispatchIndexOpFoldResults(mixedTileInterchange, dynamicTileInterchange,
+ staticTileInterchange);
+ // Call the default builder which sets up the proper operands segment sizes
+ // attributes for multiple variadic operands. In the absence of this,
+ // horrible bugs ensue.
+ auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
+ auto staticTileInterchangeAttr =
+ builder.getDenseI64ArrayAttr(staticTileInterchange);
+ unsigned numExpectedLoops =
+ useForall ? 1 : staticTileSizes.size() - llvm::count(staticTileSizes, 0);
+ SmallVector<Type> resultTypes;
+ resultTypes.reserve(numExpectedLoops);
+  assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
+         "expected a single loop type or one type per loop");
+ if (loopTypes.size() == 1)
+ resultTypes.append(numExpectedLoops, loopTypes[0]);
+ else
+ llvm::append_range(resultTypes, loopTypes);
+ build(builder, result, /*transformed=*/target.getType(),
+ /*loops=*/resultTypes,
+ /*target=*/target,
+ /*tile_sizes=*/dynamicTileSizes,
+ /*tile_interchange=*/dynamicTileInterchange,
+ /*static_tile_sizes=*/staticTileSizesAttr,
+ /*static_tile_interchange=*/staticTileInterchangeAttr,
+ /*apply_cleanup=*/applyCleanup,
+ /*use_forall=*/useForall);
+}
+
/// Apply a tiling transformation to all payload ops and store both the
/// tiled operation as well as the created tile loops.
template <typename Range>
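
A hedged sketch of driving the new builders from C++ (hypothetical `b`, `loc`, `targetHandle`, and `dynamicSize`; assumes the generated `create` wrappers matching these build signatures):

  // All-static tile sizes; with useForall a single scf.forall loop results.
  SmallVector<int64_t> tileSizes = {8, 32};
  auto fuse = transform::FuseOp::create(
      b, loc, targetHandle, tileSizes, /*staticTileInterchange=*/{},
      /*applyCleanup=*/true, /*useForall=*/true);

  // Mixed static and dynamic sizes via OpFoldResult.
  SmallVector<OpFoldResult> mixed = {b.getIndexAttr(8), dynamicSize};
  auto fuseMixed = transform::FuseOp::create(
      b, loc, targetHandle, mixed, /*mixedTileInterchange=*/{},
      /*applyCleanup=*/false, /*useForall=*/false);
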
@@ -630,13 +710,25 @@ DiagnosedSilenceableFailure
transform::FuseOp::apply(transform::TransformRewriter &rewriter,
mlir::transform::TransformResults &transformResults,
mlir::transform::TransformState &state) {
- SmallVector<int64_t> tileSizes =
- extractFromIntegerArrayAttr<int64_t>(getTileSizes());
- SmallVector<int64_t> tileInterchange =
- extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
+ auto transformOp = cast<TransformOpInterface>(getOperation());
+
+ SmallVector<int64_t> tileSizes;
+ DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
+ state, transformOp, getMixedTileSizes(), tileSizes);
+ if (!status.succeeded())
+ return status;
+ SmallVector<int64_t> tileInterchange;
+ status = reifyMixedParamAndHandleResults(
+ state, transformOp, getMixedTileInterchange(), tileInterchange);
+ if (!status.succeeded())
+ return status;
scf::SCFTilingOptions tilingOptions;
tilingOptions.interchangeVector = tileInterchange;
+ bool useForall = getUseForall();
+ tilingOptions.setLoopType(useForall
+ ? scf::SCFTilingOptions::LoopType::ForallOp
+ : scf::SCFTilingOptions::LoopType::ForOp);
SmallVector<OpFoldResult> tileSizesOfr =
getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
@@ -652,9 +744,11 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
tileAndFuseOptions.cleanupPatterns = std::move(patterns);
}
+ size_t numLoops =
+ useForall ? 1 : tileSizes.size() - llvm::count(tileSizes, 0);
LogicalResult result = applyTilingToAll(
- rewriter, getOperation(), state.getPayloadOps(getTarget()),
- tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
+ rewriter, getOperation(), state.getPayloadOps(getTarget()), numLoops,
+ transformResults,
[&](TilingInterface tilingInterfaceOp)
-> FailureOr<scf::SCFTileAndFuseResult> {
return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
@@ -665,24 +759,51 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
}
LogicalResult transform::FuseOp::verify() {
- SmallVector<int64_t> permutation =
- extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
- auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
- if (!std::is_permutation(sequence.begin(), sequence.end(),
- permutation.begin(), permutation.end())) {
- return emitOpError() << "expects interchange to be a permutation, found "
- << getTileInterchange();
+  size_t iterspaceRank = getStaticTileSizes().size();
+  ArrayRef<int64_t> permutation = getStaticTileInterchange();
+  if (permutation.size() > iterspaceRank)
+    return emitOpError()
+           << "interchange length exceeds iteration space dimensions ("
+           << iterspaceRank << "), found " << getTileInterchange();
+  SmallVector<bool> seen(iterspaceRank, false);
+  for (int64_t v : permutation) {
+    if (!ShapedType::isDynamic(v)) {
+      if (v < 0 || v >= static_cast<int64_t>(iterspaceRank))
+        return emitOpError() << "expects interchange values to be in range [0, "
+                             << iterspaceRank << "), found: " << v;
+ if (seen[v])
+ return emitOpError() << "found duplicate interchange value: " << v;
+ seen[v] = true;
+ }
}
- SmallVector<int64_t> sizes =
- extractFromIntegerArrayAttr<int64_t>(getTileSizes());
- size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
+ ArrayRef<int64_t> sizes = getStaticTileSizes();
+ size_t numExpectedLoops =
+ getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0);
if (numExpectedLoops != getNumResults() - 1)
return emitOpError() << "expects " << numExpectedLoops << " loop results";
return success();
}
+SmallVector<OpFoldResult> transform::FuseOp::getMixedTileSizes() {
+ return getMixedValues(getStaticTileSizes(), getTileSizes(), getContext());
+}
+
+SmallVector<OpFoldResult> transform::FuseOp::getMixedTileInterchange() {
+ return getMixedValues(getStaticTileInterchange(), getTileInterchange(),
+ getContext());
+}
+
+void transform::FuseOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ consumesHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getTileSizesMutable(), effects);
+ onlyReadsHandle(getTileInterchangeMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
+ modifiesPayload(effects);
+}
+
//===----------------------------------------------------------------------===//
// FuseIntoContainingOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 0dac688..eb2d825 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1134,22 +1134,45 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
linalg::PackOp packOp, PatternRewriter &rewriter) const {
- // TODO: support the case that outer dimensions are not all 1s. A
- // tensor.expand_shape will be generated in this case.
- if (llvm::any_of(packOp.getAllOuterDims(),
+ if (llvm::any_of(packOp.getTiledOuterDims(),
[](int64_t dim) { return dim != 1; })) {
return rewriter.notifyMatchFailure(
packOp, "not all outer dimensions of the result are 1s");
}
+ ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
+ auto outerDimsPerm = packOp.getOuterDimsPerm();
+
+  // Verify that no non-unit, un-tiled outer dim is permuted. Supporting such
+  // cases would require refining the logic that generates the transpose op.
+  // Note: `prev` lives outside the lambda so that its state is scoped to this
+  // one walk over `outerDimsPerm`.
+  int64_t prev = 0;
+  if (!llvm::all_of(outerDimsPerm, [&](int64_t dim) {
+        // Skip tiled dims - these can be permuted.
+        if (llvm::is_contained(innerDimsPos, dim))
+          return true;
+
+        // Check whether this dim has been permuted. Permuting unit dims is fine
+        // as that's effectively a no-op.
+        if (dim < prev && (packOp.getType().getShape()[prev] != 1 ||
+                           packOp.getType().getShape()[dim] != 1))
+          return false;
+
+        prev = dim;
+        return true;
+      })) {
+    return rewriter.notifyMatchFailure(
+        packOp, "at least one non-unit, un-tiled outer dim is permuted; "
+                "this is not supported yet");
+  }
+
Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
Location loc = packOp.getLoc();
int64_t srcRank = packOp.getSourceRank();
int64_t destRank = packOp.getDestRank();
- ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
- int64_t numberOfTiles = innerDimsPos.size();
// 1. Get the input that is going to be packed. If the input requires padding,
// add a padding operation and return that as the input.
@@ -1160,10 +1183,13 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
// %transposed_tile = linalg.transpose ins(%source_or_padded_source),
// outs(%init)
// Assumptions made:
- // - All outer dims are 1 - the corresponding transposition order doesn't
- // matter, but requires all dim indices to be present.
+ // - All tiled outer dims are 1 - the corresponding transposition order
+ // doesn't matter, but requires all dim indices to be present.
+ // - Un-tiled outer dims remain un-permuted.
- // 2.1 Get the permutation for linalg.transpose
+ // 2.1 Get the permutation for linalg.transpose:
+ // [ untiled-dims, inner-dims-pos ]
+  // Note: this logic assumes that the untiled dims are not permuted.
SmallVector<int64_t> srcPermForTranspose;
for (int64_t i = 0; i < srcRank; i++) {
// We assume the `k` dimensions of the inner dim position, where `k` is the
@@ -1179,9 +1205,21 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
}
srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());
- // 2.2 Create the init tensor for linalg.transpose with the correct shape
- SmallVector<OpFoldResult> shapeForEmptyOp(srcRank - numberOfTiles,
- oneIdxAttr);
+ // 2.2 Create the init tensor for linalg.transpose with the correct shape:
+ // [ untiled-dims, tiled-dims ]
+ ShapedType inputTy = cast<ShapedType>(input.getType());
+ SmallVector<OpFoldResult> shapeForEmptyOp;
+ for (int64_t i = 0; i < srcRank; i++) {
+ if (llvm::is_contained(innerDimsPos, i)) {
+ // The tiled dims are appended after this loop.
+ continue;
+ }
+ if (inputTy.isStaticDim(i))
+ shapeForEmptyOp.push_back(rewriter.getIndexAttr(inputTy.getShape()[i]));
+ else
+ shapeForEmptyOp.emplace_back(
+ tensor::DimOp::create(rewriter, loc, input, i).getResult());
+ }
shapeForEmptyOp.append(packOp.getMixedTiles());
// getMixedTiles() may contain Values pointing to constant ops, not the
@@ -1204,25 +1242,36 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
srcPermForTranspose);
- // 3. Insert the inner tile to the destination:
+ // 3. Insert the inner tile into the destination tensor:
// %inserted_tile = tensor.insert_slice(%transposed_tile)
- SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
- SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
- // Outer dims are all 1s!
- SmallVector<OpFoldResult> writeSizes(destRank - numberOfTiles, oneIdxAttr);
- SmallVector<int64_t> writeShape;
+
+ // Compute the sizes attribute:
+ // [ outer-dims, tile-sizes ]
+ // Note that the output from the transpose Op excludes the tiled outer dims.
+ // However, given the assumption that:
+ // * all tiled outer dims == 1,
+ // we can just use a rank-expanding tensor.insert_slice.
+ SmallVector<OpFoldResult> writeSizes;
+ for (auto size : packOp.getAllOuterDims()) {
+ writeSizes.push_back(rewriter.getIndexAttr(size));
+ }
for (auto tileSize : packOp.getMixedTiles()) {
- auto [tileSizeStatic, tileSizeOfr] =
+ auto [_, tileSizeOfr] =
getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
writeSizes.push_back(tileSizeOfr);
- writeShape.push_back(tileSizeStatic);
}
- // 4. Replace tensor.packOp with tensor.insert_slice created above
+ // TODO: Add a constructor for tensor.insert_slice that doesn't require
+  // strides or offsets.
+ SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
+ SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
+
auto insert = tensor::InsertSliceOp::create(
rewriter, loc, transposedOp.getResult()[0], packOp.getDest(),
writeOffsets, writeSizes, writeStrides);
+
+ // 4. Replace tensor.packOp with tensor.insert_slice created above
rewriter.replaceOp(packOp, insert.getResult());
return success();
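
For intuition on the new precondition (hypothetical shapes): packing tensor<4x8x16xf32> with inner_dims_pos = [2], inner_tiles = [16], and outer_dims_perm = [1, 0, 2] swaps the two non-unit, un-tiled outer dims 4 and 8, so the pattern now bails out. With the identity permutation, or when one of the swapped dims is a unit dim, the decomposition into pad + transpose + insert_slice still applies, because only the tiled outer dims (here 16/16 = 1) must be unit.
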
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index e25a012..1382c7ac 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -5,7 +5,7 @@ add_mlir_dialect_library(MLIRMemRefDialect
ValueBoundsOpInterfaceImpl.cpp
ADDITIONAL_HEADER_DIRS
- ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect
+  ${PROJECT_SOURCE_DIR}/include/mlir/Dialect/MemRef/IR
DEPENDS
MLIRMemRefOpsIncGen
@@ -18,6 +18,7 @@ add_mlir_dialect_library(MLIRMemRefDialect
MLIRDialectUtils
MLIRInferIntRangeCommon
MLIRInferIntRangeInterface
+ MLIRInferStridedMetadataInterface
MLIRInferTypeOpInterface
MLIRIR
MLIRMemOpInterfaces
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index e9bdcda..507597b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3437,6 +3437,65 @@ SubViewOp::bubbleDownCasts(OpBuilder &builder) {
return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
+void SubViewOp::inferStridedMetadataRanges(
+ ArrayRef<StridedMetadataRange> ranges, GetIntRangeFn getIntRange,
+ SetStridedMetadataRangeFn setMetadata, int32_t indexBitwidth) {
+ auto isUninitialized =
+ +[](IntegerValueRange range) { return range.isUninitialized(); };
+
+ // Bail early if any of the operands metadata is not ready:
+ SmallVector<IntegerValueRange> offsetOperands =
+ getIntValueRanges(getMixedOffsets(), getIntRange, indexBitwidth);
+ if (llvm::any_of(offsetOperands, isUninitialized))
+ return;
+
+ SmallVector<IntegerValueRange> sizeOperands =
+ getIntValueRanges(getMixedSizes(), getIntRange, indexBitwidth);
+ if (llvm::any_of(sizeOperands, isUninitialized))
+ return;
+
+ SmallVector<IntegerValueRange> stridesOperands =
+ getIntValueRanges(getMixedStrides(), getIntRange, indexBitwidth);
+ if (llvm::any_of(stridesOperands, isUninitialized))
+ return;
+
+ StridedMetadataRange sourceRange =
+ ranges[getSourceMutable().getOperandNumber()];
+ if (sourceRange.isUninitialized())
+ return;
+
+ ArrayRef<ConstantIntRanges> srcStrides = sourceRange.getStrides();
+
+ // Get the dropped dims.
+ llvm::SmallBitVector droppedDims = getDroppedDims();
+
+ // Compute the new offset, strides and sizes.
+ ConstantIntRanges offset = sourceRange.getOffsets()[0];
+ SmallVector<ConstantIntRanges> strides, sizes;
+
+ for (size_t i = 0, e = droppedDims.size(); i < e; ++i) {
+ bool dropped = droppedDims.test(i);
+ // Compute the new offset.
+ ConstantIntRanges off =
+ intrange::inferMul({offsetOperands[i].getValue(), srcStrides[i]});
+ offset = intrange::inferAdd({offset, off});
+
+ // Skip dropped dimensions.
+ if (dropped)
+ continue;
+ // Multiply the strides.
+ strides.push_back(
+ intrange::inferMul({stridesOperands[i].getValue(), srcStrides[i]}));
+ // Get the sizes.
+ sizes.push_back(sizeOperands[i].getValue());
+ }
+
+ setMetadata(getResult(),
+ StridedMetadataRange::getRanked(
+ SmallVector<ConstantIntRanges>({std::move(offset)}),
+ std::move(sizes), std::move(strides)));
+}
+
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
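
A numeric instance of the inference, for intuition: given a source memref<16x32xf32> with strides [32, 1] and offset 0, a subview with offsets [2, 4], sizes [4, 8], and strides [2, 1] folds to offset 0 + 2*32 + 4*1 = 68, result strides [2*32, 1*1] = [64, 1], and sizes [4, 8]. With ranges instead of constants, each of those products and sums is computed by intrange::inferMul / intrange::inferAdd over the operand ranges.
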
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 6564a4e..642ced9 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/SymbolTable.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallSet.h"
@@ -74,14 +75,16 @@ struct MemRefPointerLikeModel
}
mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
- StringRef varName, Type varType,
- Value originalVar) const {
+ StringRef varName, Type varType, Value originalVar,
+ bool &needsFree) const {
auto memrefTy = cast<MemRefType>(pointer);
// Check if this is a static memref (all dimensions are known) - if yes
// then we can generate an alloca operation.
- if (memrefTy.hasStaticShape())
+ if (memrefTy.hasStaticShape()) {
+ needsFree = false; // alloca doesn't need deallocation
return memref::AllocaOp::create(builder, loc, memrefTy).getResult();
+ }
// For dynamic memrefs, extract sizes from the original variable if
// provided. Otherwise they cannot be handled.
@@ -99,6 +102,7 @@ struct MemRefPointerLikeModel
// Note: We only add dynamic sizes to the dynamicSizes array
// Static dimensions are handled automatically by AllocOp
}
+ needsFree = true; // alloc needs deallocation
return memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes)
.getResult();
}
@@ -108,10 +112,14 @@ struct MemRefPointerLikeModel
}
bool genFree(Type pointer, OpBuilder &builder, Location loc,
- TypedValue<PointerLikeType> varPtr, Type varType) const {
- if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varPtr)) {
+ TypedValue<PointerLikeType> varToFree, Value allocRes,
+ Type varType) const {
+ if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varToFree)) {
+ // Use allocRes if provided to determine the allocation type
+ Value valueToInspect = allocRes ? allocRes : memrefValue;
+
// Walk through casts to find the original allocation
- Value currentValue = memrefValue;
+ Value currentValue = valueToInspect;
Operation *originalAlloc = nullptr;
// Follow the chain of operations to find the original allocation
@@ -150,7 +158,7 @@ struct MemRefPointerLikeModel
return true;
}
if (isa<memref::AllocOp>(originalAlloc)) {
- // This is an alloc - generate dealloc
+ // This is an alloc - generate dealloc on varToFree
memref::DeallocOp::create(builder, loc, memrefValue);
return true;
}
@@ -1003,6 +1011,142 @@ struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
}
};
+//===----------------------------------------------------------------------===//
+// Recipe Region Helpers
+//===----------------------------------------------------------------------===//
+
+/// Create and populate an init region for privatization recipes.
+/// Returns the init block on success, or nullptr on failure.
+/// Sets needsFree to indicate if the allocated memory requires deallocation.
+static std::unique_ptr<Block> createInitRegion(OpBuilder &builder, Location loc,
+ Type varType, StringRef varName,
+ ValueRange bounds,
+ bool &needsFree) {
+ // Create init block with arguments: original value + bounds
+ SmallVector<Type> argTypes{varType};
+ SmallVector<Location> argLocs{loc};
+ for (Value bound : bounds) {
+ argTypes.push_back(bound.getType());
+ argLocs.push_back(loc);
+ }
+
+ auto initBlock = std::make_unique<Block>();
+ initBlock->addArguments(argTypes, argLocs);
+ builder.setInsertionPointToStart(initBlock.get());
+
+ Value privatizedValue;
+
+ // Get the block argument that represents the original variable
+ Value blockArgVar = initBlock->getArgument(0);
+
+ // Generate init region body based on variable type
+ if (isa<MappableType>(varType)) {
+ auto mappableTy = cast<MappableType>(varType);
+ auto typedVar = cast<TypedValue<MappableType>>(blockArgVar);
+ privatizedValue = mappableTy.generatePrivateInit(
+ builder, loc, typedVar, varName, bounds, {}, needsFree);
+ if (!privatizedValue)
+ return nullptr;
+ } else {
+ assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
+ auto pointerLikeTy = cast<PointerLikeType>(varType);
+ // Use PointerLikeType's allocation API with the block argument
+ privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
+ blockArgVar, needsFree);
+ if (!privatizedValue)
+ return nullptr;
+ }
+
+ // Add yield operation to init block
+ acc::YieldOp::create(builder, loc, privatizedValue);
+
+ return initBlock;
+}
+
+/// Create and populate a copy region for firstprivate recipes.
+/// Returns the copy block on success, or nullptr on failure.
+/// TODO: Handle MappableType - it does not yet have a copy API.
+static std::unique_ptr<Block> createCopyRegion(OpBuilder &builder, Location loc,
+ Type varType,
+ ValueRange bounds) {
+ // Create copy block with arguments: original value + privatized value +
+ // bounds
+ SmallVector<Type> copyArgTypes{varType, varType};
+ SmallVector<Location> copyArgLocs{loc, loc};
+ for (Value bound : bounds) {
+ copyArgTypes.push_back(bound.getType());
+ copyArgLocs.push_back(loc);
+ }
+
+ auto copyBlock = std::make_unique<Block>();
+ copyBlock->addArguments(copyArgTypes, copyArgLocs);
+ builder.setInsertionPointToStart(copyBlock.get());
+
+ bool isMappable = isa<MappableType>(varType);
+ bool isPointerLike = isa<PointerLikeType>(varType);
+ // TODO: Handle MappableType - it does not yet have a copy API.
+  // Otherwise, for now just fall back to pointer-like behavior.
+ if (isMappable && !isPointerLike)
+ return nullptr;
+
+ // Generate copy region body based on variable type
+ if (isPointerLike) {
+ auto pointerLikeTy = cast<PointerLikeType>(varType);
+ Value originalArg = copyBlock->getArgument(0);
+ Value privatizedArg = copyBlock->getArgument(1);
+
+ // Generate copy operation using PointerLikeType interface
+ if (!pointerLikeTy.genCopy(
+ builder, loc, cast<TypedValue<PointerLikeType>>(privatizedArg),
+ cast<TypedValue<PointerLikeType>>(originalArg), varType))
+ return nullptr;
+ }
+
+ // Add terminator to copy block
+ acc::TerminatorOp::create(builder, loc);
+
+ return copyBlock;
+}
+
+/// Create and populate a destroy region for privatization recipes.
+/// Returns the destroy block on success, or nullptr if not needed.
+static std::unique_ptr<Block> createDestroyRegion(OpBuilder &builder,
+ Location loc, Type varType,
+ Value allocRes,
+ ValueRange bounds) {
+ // Create destroy block with arguments: original value + privatized value +
+ // bounds
+ SmallVector<Type> destroyArgTypes{varType, varType};
+ SmallVector<Location> destroyArgLocs{loc, loc};
+ for (Value bound : bounds) {
+ destroyArgTypes.push_back(bound.getType());
+ destroyArgLocs.push_back(loc);
+ }
+
+ auto destroyBlock = std::make_unique<Block>();
+ destroyBlock->addArguments(destroyArgTypes, destroyArgLocs);
+ builder.setInsertionPointToStart(destroyBlock.get());
+
+ bool isMappable = isa<MappableType>(varType);
+ bool isPointerLike = isa<PointerLikeType>(varType);
+ // TODO: Handle MappableType - it does not yet have a deallocation API.
+  // Otherwise, for now just fall back to pointer-like behavior.
+ if (isMappable && !isPointerLike)
+ return nullptr;
+
+ assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
+ auto pointerLikeTy = cast<PointerLikeType>(varType);
+ auto privatizedArg =
+ cast<TypedValue<PointerLikeType>>(destroyBlock->getArgument(1));
+ // Pass allocRes to help determine the allocation type
+ if (!pointerLikeTy.genFree(builder, loc, privatizedArg, allocRes, varType))
+ return nullptr;
+
+ acc::TerminatorOp::create(builder, loc);
+
+ return destroyBlock;
+}
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -1050,6 +1194,55 @@ LogicalResult acc::PrivateRecipeOp::verifyRegions() {
return success();
}
+std::optional<PrivateRecipeOp>
+PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
+ StringRef recipeName, Type varType,
+ StringRef varName, ValueRange bounds) {
+ // First, validate that we can handle this variable type
+ bool isMappable = isa<MappableType>(varType);
+ bool isPointerLike = isa<PointerLikeType>(varType);
+
+ // Unsupported type
+ if (!isMappable && !isPointerLike)
+ return std::nullopt;
+
+ // Create init and destroy blocks using shared helpers
+ OpBuilder::InsertionGuard guard(builder);
+
+ // Save the original insertion point for creating the recipe operation later
+ auto originalInsertionPoint = builder.saveInsertionPoint();
+
+ bool needsFree = false;
+ auto initBlock =
+ createInitRegion(builder, loc, varType, varName, bounds, needsFree);
+ if (!initBlock)
+ return std::nullopt;
+
+ // Only create destroy region if the allocation needs deallocation
+ std::unique_ptr<Block> destroyBlock;
+ if (needsFree) {
+ // Extract the allocated value from the init block's yield operation
+ auto yieldOp = cast<acc::YieldOp>(initBlock->getTerminator());
+ Value allocRes = yieldOp.getOperand(0);
+
+ destroyBlock = createDestroyRegion(builder, loc, varType, allocRes, bounds);
+ if (!destroyBlock)
+ return std::nullopt;
+ }
+
+ // Now create the recipe operation at the original insertion point and attach
+ // the blocks
+ builder.restoreInsertionPoint(originalInsertionPoint);
+ auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
+
+ // Move the blocks into the recipe's regions
+ recipe.getInitRegion().push_back(initBlock.release());
+ if (destroyBlock)
+ recipe.getDestroyRegion().push_back(destroyBlock.release());
+
+ return recipe;
+}
+
//===----------------------------------------------------------------------===//
// FirstprivateRecipeOp
//===----------------------------------------------------------------------===//
@@ -1080,6 +1273,60 @@ LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
return success();
}
+std::optional<FirstprivateRecipeOp>
+FirstprivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
+ StringRef recipeName, Type varType,
+ StringRef varName, ValueRange bounds) {
+ // First, validate that we can handle this variable type
+ bool isMappable = isa<MappableType>(varType);
+ bool isPointerLike = isa<PointerLikeType>(varType);
+
+ // Unsupported type
+ if (!isMappable && !isPointerLike)
+ return std::nullopt;
+
+ // Create init, copy, and destroy blocks using shared helpers
+ OpBuilder::InsertionGuard guard(builder);
+
+ // Save the original insertion point for creating the recipe operation later
+ auto originalInsertionPoint = builder.saveInsertionPoint();
+
+ bool needsFree = false;
+ auto initBlock =
+ createInitRegion(builder, loc, varType, varName, bounds, needsFree);
+ if (!initBlock)
+ return std::nullopt;
+
+ auto copyBlock = createCopyRegion(builder, loc, varType, bounds);
+ if (!copyBlock)
+ return std::nullopt;
+
+ // Only create destroy region if the allocation needs deallocation
+ std::unique_ptr<Block> destroyBlock;
+ if (needsFree) {
+ // Extract the allocated value from the init block's yield operation
+ auto yieldOp = cast<acc::YieldOp>(initBlock->getTerminator());
+ Value allocRes = yieldOp.getOperand(0);
+
+ destroyBlock = createDestroyRegion(builder, loc, varType, allocRes, bounds);
+ if (!destroyBlock)
+ return std::nullopt;
+ }
+
+ // Now create the recipe operation at the original insertion point and attach
+ // the blocks
+ builder.restoreInsertionPoint(originalInsertionPoint);
+ auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
+
+ // Move the blocks into the recipe's regions
+ recipe.getInitRegion().push_back(initBlock.release());
+ recipe.getCopyRegion().push_back(copyBlock.release());
+ if (destroyBlock)
+ recipe.getDestroyRegion().push_back(destroyBlock.release());
+
+ return recipe;
+}
+
//===----------------------------------------------------------------------===//
// ReductionRecipeOp
//===----------------------------------------------------------------------===//
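
A minimal usage sketch (hypothetical names; assumes an OpBuilder `b` positioned at a point where recipe ops may be created):

  // Build a private recipe for a statically shaped memref variable. The init
  // region then uses memref.alloca, needsFree stays false, and no destroy
  // region is generated.
  auto memrefTy = MemRefType::get({10}, b.getF32Type());
  std::optional<acc::PrivateRecipeOp> recipe =
      acc::PrivateRecipeOp::createAndPopulate(
          b, loc, "privatization_memref_10xf32", memrefTy, "arr",
          /*bounds=*/ValueRange());
  if (!recipe) {
    // The variable type is neither MappableType nor PointerLikeType.
  }
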
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index fa97b49..ac72002 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2310,6 +2310,7 @@ RankedTensorType ExtractSliceOp::inferResultType(
sourceTensorType.getEncoding());
}
+// TODO: This uses neither offsets nor strides!
RankedTensorType ExtractSliceOp::inferResultType(
RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
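
A quick illustration of that TODO: inferResultType(tensor<16x32xf32>, offsets = [2, 4], sizes = [4, 8], strides = [2, 1]) produces tensor<4x8xf32> purely from the sizes; any other offsets or strides yield the same result type.
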
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e95338f..12e6475 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -928,17 +928,20 @@ struct WarpOpDeadResult : public WarpDistributionPattern {
// Some values may be yielded multiple times and correspond to multiple
// results. Deduplicating occurs by taking each result with its matching
// yielded value, and:
- // 1. recording the unique first position at which the value is yielded.
+    // 1. recording the unique first position at which a value that has uses
+    //    is yielded.
// 2. recording for the result, the first position at which the dedup'ed
// value is yielded.
// 3. skipping from the new result types / new yielded values any result
// that has no use or whose yielded value has already been seen.
for (OpResult result : warpOp.getResults()) {
+ if (result.use_empty())
+ continue;
Value yieldOperand = yield.getOperand(result.getResultNumber());
auto it = dedupYieldOperandPositionMap.insert(
std::make_pair(yieldOperand, newResultTypes.size()));
dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
- if (result.use_empty() || !it.second)
+ if (!it.second)
continue;
newResultTypes.push_back(result.getType());
newYieldValues.push_back(yieldOperand);
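
A small scenario showing why the early continue matters (hypothetical): if %v is yielded for results #0 and #1 and only #1 has uses, the old code visited #0 first, recorded a dedup position for %v while dropping #0 as unused, and then also dropped #1 because %v had already been seen, leaving #1 mapped to a slot that was never actually populated. Skipping use-empty results up front guarantees the recorded position belongs to a result that survives.
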
@@ -1843,16 +1846,16 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
escapingValueDistTypesElse.end());
- llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx;
for (auto [idx, val] :
llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
- origToNewYieldIdx[idx] = newWarpOpYieldValues.size();
newWarpOpYieldValues.push_back(val);
newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
}
- // Create the new `WarpOp` with the updated yield values and types.
- WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
- rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+ // Replace the old `WarpOp` with the new one that has additional yield
+ // values and types.
+ SmallVector<size_t> newIndices;
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
// `ifOp` returns the result of the inner warp op.
SmallVector<Type> newIfOpDistResTypes;
for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
@@ -1870,8 +1873,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(newWarpOp);
auto newIfOp = scf::IfOp::create(
- rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0),
- static_cast<bool>(ifOp.thenBlock()),
+ rewriter, ifOp.getLoc(), newIfOpDistResTypes,
+ newWarpOp.getResult(newIndices[0]), static_cast<bool>(ifOp.thenBlock()),
static_cast<bool>(ifOp.elseBlock()));
auto encloseRegionInWarpOp =
[&](Block *oldIfBranch, Block *newIfBranch,
@@ -1888,7 +1891,7 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
for (size_t i = 0; i < escapingValues.size();
++i, ++warpResRangeStart) {
innerWarpInputVals.push_back(
- newWarpOp.getResult(warpResRangeStart));
+ newWarpOp.getResult(newIndices[warpResRangeStart]));
escapeValToBlockArgIndex[escapingValues[i]] =
innerWarpInputTypes.size();
innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
@@ -1936,17 +1939,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
// Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp`
// result.
for (auto [origIdx, newIdx] : ifResultMapping)
- rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+ rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
newIfOp.getResult(newIdx), newIfOp);
- // Similarly, update any users of the `WarpOp` results that were not
- // results of the `IfOp`.
- for (auto [origIdx, newIdx] : origToNewYieldIdx)
- rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
- newWarpOp.getResult(newIdx));
- // Remove the original `WarpOp` and `IfOp`, they should not have any uses
- // at this point.
- rewriter.eraseOp(ifOp);
- rewriter.eraseOp(warpOp);
return success();
}
@@ -2065,19 +2059,16 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
escapingValueDistTypes.begin(),
escapingValueDistTypes.end());
// Next, we insert all non-`ForOp` yielded values and their distributed
- // types. We also create a mapping between the non-`ForOp` yielded value
- // index and the corresponding new `WarpOp` yield value index (needed to
- // update users later).
- llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping;
+ // types.
for (auto [i, v] :
llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
- nonForResultMapping[i] = newWarpOpYieldValues.size();
newWarpOpYieldValues.push_back(v);
newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
}
// Create the new `WarpOp` with the updated yield values and types.
- WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
- rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+ SmallVector<size_t> newIndices;
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
// Next, we create a new `ForOp` with the init args yielded by the new
// `WarpOp`.
@@ -2086,7 +2077,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
// escaping values in the new `WarpOp`.
SmallVector<Value> newForOpOperands;
for (size_t i = 0; i < escapingValuesStartIdx; ++i)
- newForOpOperands.push_back(newWarpOp.getResult(i));
+ newForOpOperands.push_back(newWarpOp.getResult(newIndices[i]));
// Create a new `ForOp` outside the new `WarpOp` region.
OpBuilder::InsertionGuard g(rewriter);
@@ -2110,7 +2101,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
for (size_t i = escapingValuesStartIdx;
i < escapingValuesStartIdx + escapingValues.size(); ++i) {
- innerWarpInput.push_back(newWarpOp.getResult(i));
+ innerWarpInput.push_back(newWarpOp.getResult(newIndices[i]));
argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
innerWarpInputType.size();
innerWarpInputType.push_back(
@@ -2146,20 +2137,11 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
if (!innerWarp.getResults().empty())
scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults());
- // Update the users of original `WarpOp` results that were coming from the
+ // Update the users of the new `WarpOp` results that were coming from the
// original `ForOp` to the corresponding new `ForOp` result.
for (auto [origIdx, newIdx] : forResultMapping)
- rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+ rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
newForOp.getResult(newIdx), newForOp);
- // Similarly, update any users of the `WarpOp` results that were not
- // results of the `ForOp`.
- for (auto [origIdx, newIdx] : nonForResultMapping)
- rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
- newWarpOp.getResult(newIdx));
- // Remove the original `WarpOp` and `ForOp`, they should not have any uses
- // at this point.
- rewriter.eraseOp(forOp);
- rewriter.eraseOp(warpOp);
// Update any users of escaping values that were forwarded to the
// inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
newForOp.walk([&](Operation *op) {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 36c498e..f77784a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -161,11 +161,24 @@ XeGPUBlockingPass::getTileShape(Operation *op) const {
xegpu::UpdateOffsetOp, xegpu::LoadMatrixOp>(op))
return getTileShape(op->getOpResult(0));
if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
- xegpu::LoadGatherOp, xegpu::StoreMatrixOp>(op))
+ xegpu::StoreMatrixOp>(op))
return getTileShape(op->getOpOperand(0));
- if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(op))
+ if (isa<xegpu::StoreNdOp>(op))
return getTileShape(op->getOpOperand(1));
+ // Handle LoadGatherOp and StoreScatterOp (with and without offset)
+  if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
+    if (loadGatherOp.getOffsets())
+      return getTileShape(loadGatherOp->getOpResult(0));
+    return getTileShape(loadGatherOp->getOpOperand(0));
+  }
+
+ if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op))
+ return getTileShape(storeScatterOp.getOffsets()
+ ? storeScatterOp->getOpOperand(0)
+ : storeScatterOp->getOpOperand(1));
+
if (isa<xegpu::DpasOp>(op)) {
std::optional<SmallVector<int64_t>> aTile =
getTileShape(op->getOpOperand(0));