aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp36
-rw-r--r--mlir/lib/Dialect/LLVMIR/CMakeLists.txt2
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp3
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp154
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h27
-rw-r--r--mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp17
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp7
7 files changed, 240 insertions, 6 deletions
diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
index 8062b474..a84d10d 100644
--- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
+++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
@@ -258,6 +258,39 @@ getAllocEffectFor(Value value,
return success();
}
+static Operation *isDistinctObjectsOp(Operation *op) {
+ if (op && op->hasTrait<OpTrait::DistinctObjectsTrait>())
+ return op;
+
+ return nullptr;
+}
+
+static Value getDistinctObjectsOperand(Operation *op, Value value) {
+ unsigned argNumber = cast<OpResult>(value).getResultNumber();
+ return op->getOperand(argNumber);
+}
+
+static std::optional<AliasResult> checkDistinctObjects(Value lhs, Value rhs) {
+ // We should already checked that lhs and rhs are different.
+ assert(lhs != rhs && "lhs and rhs must be different");
+
+ // Result and corresponding operand must alias.
+ auto lhsOp = isDistinctObjectsOp(lhs.getDefiningOp());
+ if (lhsOp && getDistinctObjectsOperand(lhsOp, lhs) == rhs)
+ return AliasResult::MustAlias;
+
+ auto rhsOp = isDistinctObjectsOp(rhs.getDefiningOp());
+ if (rhsOp && getDistinctObjectsOperand(rhsOp, rhs) == lhs)
+ return AliasResult::MustAlias;
+
+ // If two different values come from the same `DistinctObjects` operation,
+ // they don't alias.
+ if (lhsOp && lhsOp == rhsOp)
+ return AliasResult::NoAlias;
+
+ return std::nullopt;
+}
+
/// Given the two values, return their aliasing behavior.
AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) {
if (lhs == rhs)
@@ -289,6 +322,9 @@ AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) {
: AliasResult::MayAlias;
}
+ if (std::optional<AliasResult> result = checkDistinctObjects(lhs, rhs))
+ return *result;
+
// Otherwise, neither of the values are constant so check to see if either has
// an allocation effect.
bool lhsHasAlloc = succeeded(getAllocEffectFor(lhs, lhsAlloc, lhsAllocScope));
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index ec581ac..cc66fac 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -8,11 +8,13 @@ add_mlir_dialect_library(MLIRLLVMDialect
IR/LLVMMemorySlot.cpp
IR/LLVMTypes.cpp
IR/LLVMTypeSyntax.cpp
+ IR/LLVMDialectBytecode.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR
DEPENDS
+ MLIRLLVMDialectBytecodeIncGen
MLIRLLVMOpsIncGen
MLIRLLVMTypesIncGen
MLIRLLVMIntrinsicOpsIncGen
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 5d08ccc..7ca09d9 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -29,6 +29,8 @@
#include "llvm/IR/DataLayout.h"
#include "llvm/Support/Error.h"
+#include "LLVMDialectBytecode.h"
+
#include <numeric>
#include <optional>
@@ -4237,6 +4239,7 @@ void LLVMDialect::initialize() {
// Support unknown operations because not all LLVM operations are registered.
allowUnknownOperations();
declarePromisedInterface<DialectInlinerInterface, LLVMDialect>();
+ detail::addBytecodeInterface(this);
}
#define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp
new file mode 100644
index 0000000..41d1f80
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp
@@ -0,0 +1,154 @@
+//===- LLVMDialectBytecode.cpp - LLVM Bytecode Implementation -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "LLVMDialectBytecode.h"
+#include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <type_traits>
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+namespace {
+
+// Provide some forward declarations of the functions that will be generated by
+// the include below.
+static void write(DIExpressionElemAttr attribute,
+ DialectBytecodeWriter &writer);
+static LogicalResult writeAttribute(Attribute attribute,
+ DialectBytecodeWriter &writer);
+
+//===--------------------------------------------------------------------===//
+// Optional ArrayRefs
+//
+// Note that both the writer and reader functions consider attributes to be
+// optional. This is because the attribute may be present or empty.
+//===--------------------------------------------------------------------===//
+
+template <class EntryTy>
+static void writeOptionalArrayRef(DialectBytecodeWriter &writer,
+ ArrayRef<EntryTy> storage) {
+ if (storage.empty()) {
+ writer.writeOwnedBool(false);
+ return;
+ }
+
+ writer.writeOwnedBool(true);
+ writer.writeList(storage, [&](EntryTy val) {
+ if constexpr (std::is_base_of_v<Attribute, EntryTy>) {
+ (void)writer.writeOptionalAttribute(val);
+ } else if constexpr (std::is_integral_v<EntryTy>) {
+ (void)writer.writeVarInt(val);
+ } else {
+ static_assert(true, "EntryTy not supported");
+ }
+ });
+}
+
+template <class EntryTy>
+static LogicalResult readOptionalArrayRef(DialectBytecodeReader &reader,
+ SmallVectorImpl<EntryTy> &storage) {
+ bool isPresent = false;
+ if (failed(reader.readBool(isPresent)))
+ return failure();
+ // Nothing to do here, the array is empty.
+ if (!isPresent)
+ return success();
+
+ auto readEntry = [&]() -> FailureOr<EntryTy> {
+ EntryTy temp;
+ if constexpr (std::is_base_of_v<Attribute, EntryTy>) {
+ if (succeeded(reader.readOptionalAttribute(temp)))
+ return temp;
+ } else if constexpr (std::is_integral_v<EntryTy>) {
+ if (succeeded(reader.readVarInt(temp)))
+ return temp;
+ } else {
+ static_assert(true, "EntryTy not supported");
+ }
+ return failure();
+ };
+
+ return reader.readList(storage, readEntry);
+}
+
+//===--------------------------------------------------------------------===//
+// Optional integral types
+//===--------------------------------------------------------------------===//
+
+template <class EntryTy>
+static void writeOptionalInt(DialectBytecodeWriter &writer,
+ std::optional<EntryTy> storage) {
+ static_assert(std::is_integral_v<EntryTy>,
+ "EntryTy must be an integral type");
+ EntryTy val = storage.value_or(0);
+ writer.writeVarIntWithFlag(val, storage.has_value());
+}
+
+template <class EntryTy>
+static LogicalResult readOptionalInt(DialectBytecodeReader &reader,
+ std::optional<EntryTy> &storage) {
+ static_assert(std::is_integral_v<EntryTy>,
+ "EntryTy must be an integral type");
+ uint64_t result = 0;
+ bool flag = false;
+ if (failed(reader.readVarIntWithFlag(result, flag)))
+ return failure();
+ if (flag)
+ storage = static_cast<EntryTy>(result);
+ else
+ storage = std::nullopt;
+ return success();
+}
+
+//===--------------------------------------------------------------------===//
+// Tablegen generated bytecode functions
+//===--------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/LLVMDialectBytecode.cpp.inc"
+
+//===--------------------------------------------------------------------===//
+// LLVMDialectBytecodeInterface
+//===--------------------------------------------------------------------===//
+
+/// This class implements the bytecode interface for the LLVM dialect.
+struct LLVMDialectBytecodeInterface : public BytecodeDialectInterface {
+ LLVMDialectBytecodeInterface(Dialect *dialect)
+ : BytecodeDialectInterface(dialect) {}
+
+ // Attributes
+ Attribute readAttribute(DialectBytecodeReader &reader) const override {
+ return ::readAttribute(getContext(), reader);
+ }
+
+ LogicalResult writeAttribute(Attribute attr,
+ DialectBytecodeWriter &writer) const override {
+ return ::writeAttribute(attr, writer);
+ }
+
+ // Types
+ Type readType(DialectBytecodeReader &reader) const override {
+ return ::readType(getContext(), reader);
+ }
+
+ LogicalResult writeType(Type type,
+ DialectBytecodeWriter &writer) const override {
+ return ::writeType(type, writer);
+ }
+};
+} // namespace
+
+void LLVM::detail::addBytecodeInterface(LLVMDialect *dialect) {
+ dialect->addInterfaces<LLVMDialectBytecodeInterface>();
+}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h
new file mode 100644
index 0000000..1a17cb4
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h
@@ -0,0 +1,27 @@
+//===- LLVMDialectBytecode.h - LLVM Bytecode Implementation -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header defines hooks into the LLVM dialect bytecode
+// implementation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LIB_MLIR_DIALECT_LLVM_IR_LLVMDIALECTBYTECODE_H
+#define LIB_MLIR_DIALECT_LLVM_IR_LLVMDIALECTBYTECODE_H
+
+namespace mlir::LLVM {
+class LLVMDialect;
+
+namespace detail {
+/// Add the interfaces necessary for encoding the LLVM dialect components in
+/// bytecode.
+void addBytecodeInterface(LLVMDialect *dialect);
+} // namespace detail
+} // namespace mlir::LLVM
+
+#endif // LIB_MLIR_DIALECT_LLVM_IR_LLVMDIALECTBYTECODE_H
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 36c498e..f77784a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -161,11 +161,24 @@ XeGPUBlockingPass::getTileShape(Operation *op) const {
xegpu::UpdateOffsetOp, xegpu::LoadMatrixOp>(op))
return getTileShape(op->getOpResult(0));
if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
- xegpu::LoadGatherOp, xegpu::StoreMatrixOp>(op))
+ xegpu::StoreMatrixOp>(op))
return getTileShape(op->getOpOperand(0));
- if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(op))
+ if (isa<xegpu::StoreNdOp>(op))
return getTileShape(op->getOpOperand(1));
+ // Handle LoadGatherOp and StoreScatterOp (with and without offset)
+ if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
+ if (loadGatherOp.getOffsets())
+ return getTileShape(loadGatherOp->getOpResult(0));
+ else
+ return getTileShape(loadGatherOp->getOpOperand(0));
+ }
+
+ if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op))
+ return getTileShape(storeScatterOp.getOffsets()
+ ? storeScatterOp->getOpOperand(0)
+ : storeScatterOp->getOpOperand(1));
+
if (isa<xegpu::DpasOp>(op)) {
std::optional<SmallVector<int64_t>> aTile =
getTileShape(op->getOpOperand(0));
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 3d19c5a..9b23dd6 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2200,10 +2200,9 @@ void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty,
os << '>';
}
os << '[';
- interleave(
- loc.getLocations(),
- [&](Location loc) { printLocationInternal(loc, pretty); },
- [&]() { os << ", "; });
+ interleaveComma(loc.getLocations(), [&](Location loc) {
+ printLocationInternal(loc, pretty);
+ });
os << ']';
})
.Default([&](LocationAttr loc) {