Diffstat (limited to 'mlir/lib')
-rw-r--r--  mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp | 36
-rw-r--r--  mlir/lib/Analysis/CMakeLists.txt | 2
-rw-r--r--  mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp | 127
-rw-r--r--  mlir/lib/Analysis/FlatLinearValueConstraints.cpp | 9
-rw-r--r--  mlir/lib/Analysis/Presburger/Simplex.cpp | 2
-rw-r--r--  mlir/lib/Bindings/Python/IRCore.cpp | 56
-rw-r--r--  mlir/lib/Bindings/Python/Rewrite.cpp | 10
-rw-r--r--  mlir/lib/CAPI/Transforms/Rewrite.cpp | 47
-rw-r--r--  mlir/lib/Conversion/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp | 2
-rw-r--r--  mlir/lib/Conversion/MathToROCDL/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp | 76
-rw-r--r--  mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp | 4
-rw-r--r--  mlir/lib/Conversion/MathToXeVM/CMakeLists.txt | 22
-rw-r--r--  mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp | 167
-rw-r--r--  mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp | 5
-rw-r--r--  mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 288
-rw-r--r--  mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp | 4
-rw-r--r--  mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp | 7
-rw-r--r--  mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 7
-rw-r--r--  mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp | 4
-rw-r--r--  mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 205
-rw-r--r--  mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 27
-rw-r--r--  mlir/lib/Dialect/AMX/IR/AMXDialect.cpp | 99
-rw-r--r--  mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 321
-rw-r--r--  mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp | 5
-rw-r--r--  mlir/lib/Dialect/Arith/Utils/Utils.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp | 84
-rw-r--r--  mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 3
-rw-r--r--  mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt | 6
-rw-r--r--  mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp | 139
-rw-r--r--  mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp | 4
-rw-r--r--  mlir/lib/Dialect/LLVMIR/CMakeLists.txt | 2
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 17
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp | 154
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h | 27
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp | 6
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 74
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp | 11
-rw-r--r--  mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 3
-rw-r--r--  mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 11
-rw-r--r--  mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 179
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp | 140
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp | 5
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 89
-rw-r--r--  mlir/lib/Dialect/MemRef/IR/CMakeLists.txt | 3
-rw-r--r--  mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 95
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp | 2
-rw-r--r--  mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 265
-rw-r--r--  mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp | 5
-rw-r--r--  mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Shard/IR/ShardOps.cpp | 14
-rw-r--r--  mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp | 4
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 1
-rw-r--r--  mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp | 7
-rw-r--r--  mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 7
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp | 7
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp | 74
-rw-r--r--  mlir/lib/Dialect/Transform/IR/TransformTypes.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Utils/IndexingUtils.cpp | 25
-rw-r--r--  mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 10
-rw-r--r--  mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 108
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 88
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 48
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 52
-rw-r--r--  mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 2
-rw-r--r--  mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp | 141
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 146
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 122
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp | 17
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 7
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 3
-rw-r--r--  mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 3
-rw-r--r--  mlir/lib/IR/AsmPrinter.cpp | 7
-rw-r--r--  mlir/lib/IR/Diagnostics.cpp | 6
-rw-r--r--  mlir/lib/IR/MLIRContext.cpp | 17
-rw-r--r--  mlir/lib/IR/Operation.cpp | 6
-rw-r--r--  mlir/lib/IR/OperationSupport.cpp | 9
-rw-r--r--  mlir/lib/IR/Remarks.cpp | 57
-rw-r--r--  mlir/lib/IR/TypeUtilities.cpp | 4
-rw-r--r--  mlir/lib/Interfaces/CMakeLists.txt | 16
-rw-r--r--  mlir/lib/Interfaces/InferIntRangeInterface.cpp | 19
-rw-r--r--  mlir/lib/Interfaces/InferStridedMetadataInterface.cpp | 36
-rw-r--r--  mlir/lib/RegisterAllPasses.cpp | 1
-rw-r--r--  mlir/lib/Remark/RemarkStreamer.cpp | 4
-rw-r--r--  mlir/lib/Rewrite/ByteCode.cpp | 3
-rw-r--r--  mlir/lib/TableGen/CodeGenHelpers.cpp | 4
-rw-r--r--  mlir/lib/Target/Cpp/TranslateToCpp.cpp | 34
-rw-r--r--  mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp | 6
-rw-r--r--  mlir/lib/Target/LLVMIR/DebugImporter.cpp | 4
-rw-r--r--  mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 2
-rw-r--r--  mlir/lib/Target/LLVMIR/ModuleImport.cpp | 3
-rw-r--r--  mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 5
-rw-r--r--  mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp | 5
-rw-r--r--  mlir/lib/Target/Wasm/TranslateFromWasm.cpp | 461
-rw-r--r--  mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp | 2
-rw-r--r--  mlir/lib/Tools/PDLL/Parser/Parser.cpp | 2
-rw-r--r--  mlir/lib/Tools/mlir-opt/MlirOptMain.cpp | 37
-rw-r--r--  mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp | 2
-rw-r--r--  mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp | 2
-rw-r--r--  mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp | 11
106 files changed, 3626 insertions, 916 deletions
diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
index 8062b474..a84d10d 100644
--- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
+++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
@@ -258,6 +258,39 @@ getAllocEffectFor(Value value,
return success();
}
+static Operation *isDistinctObjectsOp(Operation *op) {
+ if (op && op->hasTrait<OpTrait::DistinctObjectsTrait>())
+ return op;
+
+ return nullptr;
+}
+
+static Value getDistinctObjectsOperand(Operation *op, Value value) {
+ unsigned argNumber = cast<OpResult>(value).getResultNumber();
+ return op->getOperand(argNumber);
+}
+
+static std::optional<AliasResult> checkDistinctObjects(Value lhs, Value rhs) {
+ // We should already checked that lhs and rhs are different.
+ assert(lhs != rhs && "lhs and rhs must be different");
+
+ // Result and corresponding operand must alias.
+ auto lhsOp = isDistinctObjectsOp(lhs.getDefiningOp());
+ if (lhsOp && getDistinctObjectsOperand(lhsOp, lhs) == rhs)
+ return AliasResult::MustAlias;
+
+ auto rhsOp = isDistinctObjectsOp(rhs.getDefiningOp());
+ if (rhsOp && getDistinctObjectsOperand(rhsOp, rhs) == lhs)
+ return AliasResult::MustAlias;
+
+ // If two different values come from the same `DistinctObjects` operation,
+ // they don't alias.
+ if (lhsOp && lhsOp == rhsOp)
+ return AliasResult::NoAlias;
+
+ return std::nullopt;
+}
+
/// Given the two values, return their aliasing behavior.
AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) {
if (lhs == rhs)
@@ -289,6 +322,9 @@ AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) {
: AliasResult::MayAlias;
}
+ if (std::optional<AliasResult> result = checkDistinctObjects(lhs, rhs))
+ return *result;
+
// Otherwise, neither of the values are constant so check to see if either has
// an allocation effect.
bool lhsHasAlloc = succeeded(getAllocEffectFor(lhs, lhsAlloc, lhsAllocScope));
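Editor's note: the new `checkDistinctObjects` helper encodes two facts about ops carrying `DistinctObjectsTrait`: each result must-aliases the operand it forwards, and two different results of the same op never alias. A minimal standalone Python sketch of that decision logic (plain Python objects stand in for MLIR values; this is not the MLIR API):

```python
from typing import Optional

class DistinctObjectsOp:
    """Stand-in for an op with DistinctObjectsTrait: result i forwards
    operand i, and all results point to pairwise-distinct objects."""
    def __init__(self, operands):
        self.operands = list(operands)
        self.results = [("result", self, i) for i in range(len(operands))]

def check_distinct_objects(lhs, rhs) -> Optional[str]:
    assert lhs != rhs, "lhs and rhs must be different"
    lhs_op = lhs[1] if isinstance(lhs, tuple) else None
    rhs_op = rhs[1] if isinstance(rhs, tuple) else None
    # A result must-aliases the operand it forwards.
    if lhs_op is not None and lhs_op.operands[lhs[2]] == rhs:
        return "MustAlias"
    if rhs_op is not None and rhs_op.operands[rhs[2]] == lhs:
        return "MustAlias"
    # Two different results of the same DistinctObjects op never alias.
    if lhs_op is not None and lhs_op is rhs_op:
        return "NoAlias"
    return None  # unknown; fall through to the rest of the analysis

# Example: %r0, %r1 = distinct_objects(%a, %b)
a, b = "a", "b"
op = DistinctObjectsOp([a, b])
r0, r1 = op.results
print(check_distinct_objects(r0, a))   # MustAlias
print(check_distinct_objects(r0, r1))  # NoAlias
print(check_distinct_objects(r0, b))   # None
```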
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index 609cb34..db10ebc 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -40,6 +40,7 @@ add_mlir_library(MLIRAnalysis
DataFlow/IntegerRangeAnalysis.cpp
DataFlow/LivenessAnalysis.cpp
DataFlow/SparseAnalysis.cpp
+ DataFlow/StridedMetadataRangeAnalysis.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Analysis
@@ -53,6 +54,7 @@ add_mlir_library(MLIRAnalysis
MLIRDataLayoutInterfaces
MLIRFunctionInterfaces
MLIRInferIntRangeInterface
+ MLIRInferStridedMetadataInterface
MLIRInferTypeOpInterface
MLIRLoopLikeInterface
MLIRPresburger
diff --git a/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp
new file mode 100644
index 0000000..01c9daf
--- /dev/null
+++ b/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp
@@ -0,0 +1,127 @@
+//===- StridedMetadataRangeAnalysis.cpp - Strided metadata ranges -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the dataflow analysis class for strided metadata range
+// inference, which computes integer ranges for the offsets, sizes, and
+// strides of memref values.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h"
+#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/DebugStringHelper.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
+
+#define DEBUG_TYPE "strided-metadata-range-analysis"
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+/// Get the entry state for a value. For any value that is not a ranked memref,
+/// this function sets the metadata to a top state with no offsets, sizes, or
+/// strides. For `memref` types, this function will use the metadata in the type
+/// to try to deduce as much information as possible.
+static StridedMetadataRange getEntryStateImpl(Value v, int32_t indexBitwidth) {
+ // TODO: generalize this method with a type interface.
+ auto mTy = dyn_cast<BaseMemRefType>(v.getType());
+
+ // If not a memref or it's un-ranked, don't infer any metadata.
+ if (!mTy || !mTy.hasRank())
+ return StridedMetadataRange::getMaxRanges(indexBitwidth, 0, 0, 0);
+
+ // Get the top state.
+ auto metadata =
+ StridedMetadataRange::getMaxRanges(indexBitwidth, mTy.getRank());
+
+ // Compute the offset and strides.
+ int64_t offset;
+ SmallVector<int64_t> strides;
+ if (failed(cast<MemRefType>(mTy).getStridesAndOffset(strides, offset)))
+ return metadata;
+
+ // Refine the metadata if we know it from the type.
+ if (!ShapedType::isDynamic(offset)) {
+ metadata.getOffsets()[0] =
+ ConstantIntRanges::constant(APInt(indexBitwidth, offset));
+ }
+ for (auto &&[size, range] :
+ llvm::zip_equal(mTy.getShape(), metadata.getSizes())) {
+ if (ShapedType::isDynamic(size))
+ continue;
+ range = ConstantIntRanges::constant(APInt(indexBitwidth, size));
+ }
+ for (auto &&[stride, range] :
+ llvm::zip_equal(strides, metadata.getStrides())) {
+ if (ShapedType::isDynamic(stride))
+ continue;
+ range = ConstantIntRanges::constant(APInt(indexBitwidth, stride));
+ }
+
+ return metadata;
+}
+
+StridedMetadataRangeAnalysis::StridedMetadataRangeAnalysis(
+ DataFlowSolver &solver, int32_t indexBitwidth)
+ : SparseForwardDataFlowAnalysis(solver), indexBitwidth(indexBitwidth) {
+ assert(indexBitwidth > 0 && "invalid bitwidth");
+}
+
+void StridedMetadataRangeAnalysis::setToEntryState(
+ StridedMetadataRangeLattice *lattice) {
+ propagateIfChanged(lattice, lattice->join(getEntryStateImpl(
+ lattice->getAnchor(), indexBitwidth)));
+}
+
+LogicalResult StridedMetadataRangeAnalysis::visitOperation(
+ Operation *op, ArrayRef<const StridedMetadataRangeLattice *> operands,
+ ArrayRef<StridedMetadataRangeLattice *> results) {
+ auto inferrable = dyn_cast<InferStridedMetadataOpInterface>(op);
+
+ // Bail if we cannot reason about the op.
+ if (!inferrable) {
+ setAllToEntryStates(results);
+ return success();
+ }
+
+ LDBG() << "Inferring metadata for: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
+
+ // Helper function to retrieve int range values.
+ auto getIntRange = [&](Value value) -> IntegerValueRange {
+ auto lattice = getOrCreateFor<IntegerValueRangeLattice>(
+ getProgramPointAfter(op), value);
+ return lattice ? lattice->getValue() : IntegerValueRange();
+ };
+
+ // Convert the arguments lattices to a vector.
+ SmallVector<StridedMetadataRange> argRanges = llvm::map_to_vector(
+ operands, [](const StridedMetadataRangeLattice *lattice) {
+ return lattice->getValue();
+ });
+
+ // Callback to set metadata on a result.
+ auto joinCallback = [&](Value v, const StridedMetadataRange &md) {
+ auto result = cast<OpResult>(v);
+ assert(llvm::is_contained(op->getResults(), result));
+ LDBG() << "- Inferred metadata: " << md;
+ StridedMetadataRangeLattice *lattice = results[result.getResultNumber()];
+ ChangeResult changed = lattice->join(md);
+ LDBG() << "- Joined metadata: " << lattice->getValue();
+ propagateIfChanged(lattice, changed);
+ };
+
+ // Infer the metadata.
+ inferrable.inferStridedMetadataRanges(argRanges, getIntRange, joinCallback,
+ indexBitwidth);
+ return success();
+}
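Editor's note: `getEntryStateImpl` seeds the lattice from whatever the memref type already states; for a statically shaped, identity-layout memref the offset is 0 and the strides are the row-major suffix products of the shape. A small Python sketch of that computation, added here only to illustrate what the analysis can refine from the type (hypothetical helper, not part of the patch):

```python
def identity_layout_metadata(shape):
    """Offset and row-major strides implied by a static identity-layout
    memref, e.g. memref<4x8x2xf32> -> offset 0, strides [16, 2, 1]."""
    offset = 0
    strides = [1] * len(shape)
    for i in reversed(range(len(shape) - 1)):
        strides[i] = strides[i + 1] * shape[i + 1]
    return offset, list(shape), strides

print(identity_layout_metadata([4, 8, 2]))  # (0, [4, 8, 2], [16, 2, 1])
```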
diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
index 30ce1fb..6588b53 100644
--- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
+++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
@@ -1244,8 +1244,9 @@ bool FlatLinearValueConstraints::areVarsAlignedWithOther(
/// Checks if the SSA values associated with `cst`'s variables in range
/// [start, end) are unique.
-static bool LLVM_ATTRIBUTE_UNUSED areVarsUnique(
- const FlatLinearValueConstraints &cst, unsigned start, unsigned end) {
+[[maybe_unused]] static bool
+areVarsUnique(const FlatLinearValueConstraints &cst, unsigned start,
+ unsigned end) {
assert(start <= cst.getNumDimAndSymbolVars() &&
"Start position out of bounds");
@@ -1267,14 +1268,14 @@ static bool LLVM_ATTRIBUTE_UNUSED areVarsUnique(
}
/// Checks if the SSA values associated with `cst`'s variables are unique.
-static bool LLVM_ATTRIBUTE_UNUSED
+[[maybe_unused]] static bool
areVarsUnique(const FlatLinearValueConstraints &cst) {
return areVarsUnique(cst, 0, cst.getNumDimAndSymbolVars());
}
/// Checks if the SSA values associated with `cst`'s variables of kind `kind`
/// are unique.
-static bool LLVM_ATTRIBUTE_UNUSED
+[[maybe_unused]] static bool
areVarsUnique(const FlatLinearValueConstraints &cst, VarKind kind) {
if (kind == VarKind::SetDim)
diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
index a1cbe29..547a4c2 100644
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -34,7 +34,7 @@ using Direction = Simplex::Direction;
const int nullIndex = std::numeric_limits<int>::max();
// Return a + scale*b;
-LLVM_ATTRIBUTE_UNUSED
+[[maybe_unused]]
static SmallVector<DynamicAPInt, 8>
scaleAndAddForAssert(ArrayRef<DynamicAPInt> a, const DynamicAPInt &scale,
ArrayRef<DynamicAPInt> b) {
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 7b17106..06d0256 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2730,6 +2730,17 @@ public:
operation->get(), toMlirStringRef(name)));
}
+ static void
+ forEachAttr(MlirOperation op,
+ llvm::function_ref<void(MlirStringRef, MlirAttribute)> fn) {
+ intptr_t n = mlirOperationGetNumAttributes(op);
+ for (intptr_t i = 0; i < n; ++i) {
+ MlirNamedAttribute na = mlirOperationGetAttribute(op, i);
+ MlirStringRef name = mlirIdentifierStr(na.name);
+ fn(name, na.attribute);
+ }
+ }
+
static void bind(nb::module_ &m) {
nb::class_<PyOpAttributeMap>(m, "OpAttributeMap")
.def("__contains__", &PyOpAttributeMap::dunderContains)
@@ -2737,7 +2748,50 @@ public:
.def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
.def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
.def("__setitem__", &PyOpAttributeMap::dunderSetItem)
- .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
+ .def("__delitem__", &PyOpAttributeMap::dunderDelItem)
+ .def("__iter__",
+ [](PyOpAttributeMap &self) {
+ nb::list keys;
+ PyOpAttributeMap::forEachAttr(
+ self.operation->get(),
+ [&](MlirStringRef name, MlirAttribute) {
+ keys.append(nb::str(name.data, name.length));
+ });
+ return nb::iter(keys);
+ })
+ .def("keys",
+ [](PyOpAttributeMap &self) {
+ nb::list out;
+ PyOpAttributeMap::forEachAttr(
+ self.operation->get(),
+ [&](MlirStringRef name, MlirAttribute) {
+ out.append(nb::str(name.data, name.length));
+ });
+ return out;
+ })
+ .def("values",
+ [](PyOpAttributeMap &self) {
+ nb::list out;
+ PyOpAttributeMap::forEachAttr(
+ self.operation->get(),
+ [&](MlirStringRef, MlirAttribute attr) {
+ out.append(PyAttribute(self.operation->getContext(), attr)
+ .maybeDownCast());
+ });
+ return out;
+ })
+ .def("items", [](PyOpAttributeMap &self) {
+ nb::list out;
+ PyOpAttributeMap::forEachAttr(
+ self.operation->get(),
+ [&](MlirStringRef name, MlirAttribute attr) {
+ out.append(nb::make_tuple(
+ nb::str(name.data, name.length),
+ PyAttribute(self.operation->getContext(), attr)
+ .maybeDownCast()));
+ });
+ return out;
+ });
}
private:
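Editor's note: with `__iter__`, `keys`, `values`, and `items` bound, the attribute map behaves like a read-only dict from Python. A short usage sketch against the upstream Python bindings (assuming they are importable as `mlir`; the `foo` attribute is illustrative):

```python
from mlir.ir import Context, Location, Module

with Context(), Location.unknown():
    module = Module.parse("module attributes {foo = 42 : i32} {}")
    attrs = module.operation.attributes
    for name in attrs:                  # __iter__ yields attribute names
        print(name)                     # -> foo
    print(list(attrs.keys()))           # ['foo']
    print([str(v) for v in attrs.values()])
    for name, attr in attrs.items():    # (name, attribute) pairs
        print(name, attr)
```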
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index d506b7f..0f0ed22 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -197,10 +197,15 @@ public:
MlirPatternRewriter rewriter,
void *userData) -> MlirLogicalResult {
nb::handle f(static_cast<PyObject *>(userData));
- nb::object res = f(op, PyPatternRewriter(rewriter));
+
+ PyMlirContextRef ctx =
+ PyMlirContext::forContext(mlirOperationGetContext(op));
+ nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
+
+ nb::object res = f(opView, PyPatternRewriter(rewriter));
return logicalResultFromObject(res);
};
- MlirRewritePattern pattern = mlirOpRewritePattenCreate(
+ MlirRewritePattern pattern = mlirOpRewritePatternCreate(
rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
/* nGeneratedNames */ 0,
/* generatedNames */ nullptr);
@@ -261,7 +266,6 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of the RewritePatternSet
//----------------------------------------------------------------------------
- nb::class_<MlirRewritePattern>(m, "RewritePattern");
nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
.def(
"__init__",
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 70dee59..41ceb15 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -270,17 +270,6 @@ void mlirIRRewriterDestroy(MlirRewriterBase rewriter) {
/// RewritePatternSet and FrozenRewritePatternSet API
//===----------------------------------------------------------------------===//
-static inline mlir::FrozenRewritePatternSet *
-unwrap(MlirFrozenRewritePatternSet module) {
- assert(module.ptr && "unexpected null module");
- return static_cast<mlir::FrozenRewritePatternSet *>(module.ptr);
-}
-
-static inline MlirFrozenRewritePatternSet
-wrap(mlir::FrozenRewritePatternSet *module) {
- return {module};
-}
-
MlirFrozenRewritePatternSet
mlirFreezeRewritePattern(MlirRewritePatternSet set) {
auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(set)));
@@ -311,15 +300,6 @@ mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
/// PatternRewriter API
//===----------------------------------------------------------------------===//
-inline mlir::PatternRewriter *unwrap(MlirPatternRewriter rewriter) {
- assert(rewriter.ptr && "unexpected null rewriter");
- return static_cast<mlir::PatternRewriter *>(rewriter.ptr);
-}
-
-inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) {
- return {rewriter};
-}
-
MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
}
@@ -361,7 +341,7 @@ private:
} // namespace mlir
-MlirRewritePattern mlirOpRewritePattenCreate(
+MlirRewritePattern mlirOpRewritePatternCreate(
MlirStringRef rootName, unsigned benefit, MlirContext context,
MlirRewritePatternCallbacks callbacks, void *userData,
size_t nGeneratedNames, MlirStringRef *generatedNames) {
@@ -400,15 +380,6 @@ void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
//===----------------------------------------------------------------------===//
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
-static inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) {
- assert(module.ptr && "unexpected null module");
- return static_cast<mlir::PDLPatternModule *>(module.ptr);
-}
-
-static inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) {
- return {module};
-}
-
MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) {
return wrap(new mlir::PDLPatternModule(
mlir::OwningOpRef<mlir::ModuleOp>(unwrap(op))));
@@ -426,22 +397,6 @@ mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) {
return wrap(m);
}
-inline const mlir::PDLValue *unwrap(MlirPDLValue value) {
- assert(value.ptr && "unexpected null PDL value");
- return static_cast<const mlir::PDLValue *>(value.ptr);
-}
-
-inline MlirPDLValue wrap(const mlir::PDLValue *value) { return {value}; }
-
-inline mlir::PDLResultList *unwrap(MlirPDLResultList results) {
- assert(results.ptr && "unexpected null PDL results");
- return static_cast<mlir::PDLResultList *>(results.ptr);
-}
-
-inline MlirPDLResultList wrap(mlir::PDLResultList *results) {
- return {results};
-}
-
MlirValue mlirPDLValueAsValue(MlirPDLValue value) {
return wrap(unwrap(value)->dyn_cast<mlir::Value>());
}
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 71986f8..bebf1b8 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -40,6 +40,7 @@ add_subdirectory(MathToLibm)
add_subdirectory(MathToLLVM)
add_subdirectory(MathToROCDL)
add_subdirectory(MathToSPIRV)
+add_subdirectory(MathToXeVM)
add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
add_subdirectory(MemRefToSPIRV)
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index b215211..c03f3a5 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -484,5 +484,5 @@ void mlir::populateGpuToROCDLConversionPatterns(
GPUSubgroupBroadcastOpToROCDL>(converter);
patterns.add<GPUSubgroupSizeOpToROCDL>(converter, chipset);
- populateMathToROCDLConversionPatterns(converter, patterns);
+ populateMathToROCDLConversionPatterns(converter, patterns, chipset);
}
diff --git a/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt b/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
index 2771955a..8cc3fde 100644
--- a/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
+++ b/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRMathToROCDL
Core
LINK_LIBS PUBLIC
+ MLIRAMDGPUUtils
MLIRDialectUtils
MLIRFuncDialect
MLIRGPUToGPURuntimeTransforms
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index df219f3..a2dfc12 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -10,6 +10,8 @@
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -19,6 +21,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/DebugLog.h"
#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"
@@ -42,8 +45,46 @@ static void populateOpPatterns(const LLVMTypeConverter &converter,
f32ApproxFunc, f16Func);
}
+struct ClampFOpConversion final
+ : public ConvertOpToLLVMPattern<math::ClampFOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Only f16 and f32 types are supported by fmed3
+ Type opTy = op.getType();
+ Type resultType = getTypeConverter()->convertType(opTy);
+
+ if (auto vectorType = dyn_cast<VectorType>(opTy))
+ opTy = vectorType.getElementType();
+
+ if (!isa<Float16Type, Float32Type>(opTy))
+ return rewriter.notifyMatchFailure(
+ op, "fmed3 only supports f16 and f32 types");
+
+ // Handle multi-dimensional vectors (converted to LLVM arrays)
+ if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType))
+ return LLVM::detail::handleMultidimensionalVectors(
+ op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
+ [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
+ typename math::ClampFOp::Adaptor adaptor(operands);
+ return ROCDL::FMed3Op::create(rewriter, op.getLoc(), llvm1DVectorTy,
+ adaptor.getValue(), adaptor.getMin(),
+ adaptor.getMax());
+ },
+ rewriter);
+
+ // Handle 1D vectors and scalars directly
+ rewriter.replaceOpWithNewOp<ROCDL::FMed3Op>(op, op.getType(), op.getValue(),
+ op.getMin(), op.getMax());
+ return success();
+ }
+};
+
void mlir::populateMathToROCDLConversionPatterns(
- const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ std::optional<amdgpu::Chipset> chipset) {
// Handled by mathToLLVM: math::AbsIOp
// Handled by mathToLLVM: math::AbsFOp
// Handled by mathToLLVM: math::CopySignOp
@@ -118,15 +159,21 @@ void mlir::populateMathToROCDLConversionPatterns(
// worth creating a separate pass for it.
populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
"__ocml_fmod_f64", "__ocml_fmod_f16");
+
+ if (chipset.has_value() && chipset->majorVersion >= 9) {
+ patterns.add<ClampFOpConversion>(converter);
+ } else {
+ LDBG() << "Chipset dependent patterns were not added";
+ }
}
-namespace {
-struct ConvertMathToROCDLPass
- : public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
- ConvertMathToROCDLPass() = default;
+struct ConvertMathToROCDLPass final
+ : impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
+ using impl::ConvertMathToROCDLBase<
+ ConvertMathToROCDLPass>::ConvertMathToROCDLBase;
+
void runOnOperation() override;
};
-} // namespace
void ConvertMathToROCDLPass::runOnOperation() {
auto m = getOperation();
@@ -135,10 +182,21 @@ void ConvertMathToROCDLPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
LowerToLLVMOptions options(ctx, DataLayout(m));
LLVMTypeConverter converter(ctx, options);
- populateMathToROCDLConversionPatterns(converter, patterns);
+
+ FailureOr<amdgpu::Chipset> maybeChipset;
+ if (!chipset.empty()) {
+ maybeChipset = amdgpu::Chipset::parse(chipset);
+ if (failed(maybeChipset))
+ return signalPassFailure();
+ }
+ populateMathToROCDLConversionPatterns(
+ converter, patterns,
+ succeeded(maybeChipset) ? std::optional(*maybeChipset) : std::nullopt);
+
ConversionTarget target(getContext());
- target.addLegalDialect<BuiltinDialect, func::FuncDialect,
- vector::VectorDialect, LLVM::LLVMDialect>();
+ target
+ .addLegalDialect<BuiltinDialect, func::FuncDialect, vector::VectorDialect,
+ LLVM::LLVMDialect, ROCDL::ROCDLDialect>();
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
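Editor's note: the new `ClampFOpConversion` maps `math.clampf` onto `rocdl.fmed3` (median of three) for f16/f32 on gfx9+ chipsets. The lowering works because, for ordered bounds `min <= max`, the median of `{x, min, max}` equals the clamped value. A quick Python check of that identity (illustration only, ignoring NaN edge cases):

```python
def med3(a, b, c):
    """Median of three values, the semantics of ROCDL fmed3."""
    return sorted([a, b, c])[1]

def clamp(x, lo, hi):
    return min(max(x, lo), hi)

lo, hi = 0.0, 1.0
for x in (-2.0, 0.25, 0.5, 3.0):
    assert med3(x, lo, hi) == clamp(x, lo, hi)
print("med3 matches clamp for ordered bounds")
```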
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index f0d8b78..610ce1f 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -407,11 +407,11 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
if (auto vectorType = dyn_cast<VectorType>(operandType))
nanAttr = DenseElementsAttr::get(vectorType, nan);
- Value NanValue =
+ Value nanValue =
spirv::ConstantOp::create(rewriter, loc, operandType, nanAttr);
Value lhs =
spirv::SelectOp::create(rewriter, loc, cmpNegativeWithFractionalExp,
- NanValue, adaptor.getLhs());
+ nanValue, adaptor.getLhs());
Value abs = spirv::GLFAbsOp::create(rewriter, loc, lhs);
// TODO: The following just forcefully casts y into an integer value in
diff --git a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
new file mode 100644
index 0000000..050c0ed
--- /dev/null
+++ b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
@@ -0,0 +1,22 @@
+add_mlir_conversion_library(MLIRMathToXeVM
+ MathToXeVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToXeVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRArithAttrToLLVMConversion
+ MLIRArithDialect
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRMathDialect
+ MLIRXeVMDialect
+ MLIRPass
+ MLIRTransforms
+ )
diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
new file mode 100644
index 0000000..0fe31d0
--- /dev/null
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -0,0 +1,167 @@
+//===-- MathToXeVM.cpp - conversion from Math to XeVM ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
+#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/FormatVariadic.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMATHTOXEVM
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "math-to-xevm"
+
+/// Convert math ops marked with `fast` (`afn`) to native OpenCL intrinsics.
+template <typename Op>
+struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
+
+ ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {}
+
+ LogicalResult
+ matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!isSPIRVCompatibleFloatOrVec(op.getType()))
+ return failure();
+
+ arith::FastMathFlags fastFlags = op.getFastmath();
+ if (!arith::bitEnumContainsAll(fastFlags, arith::FastMathFlags::afn))
+ return rewriter.notifyMatchFailure(op, "not a fastmath `afn` operation");
+
+ SmallVector<Type, 1> operandTypes;
+ for (auto operand : adaptor.getOperands()) {
+ Type opTy = operand.getType();
+      // This pass only supports operations on vectors that are already in
+      // SPIRV-supported vector sizes; distributing unsupported sizes into
+      // SPIRV-supported ones is done by other blocking optimization passes.
+ if (!isSPIRVCompatibleFloatOrVec(opTy))
+ return rewriter.notifyMatchFailure(
+ op, llvm::formatv("incompatible operand type: '{0}'", opTy));
+ operandTypes.push_back(opTy);
+ }
+
+ auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
+ auto funcOpRes = LLVM::lookupOrCreateFn(
+ rewriter, moduleOp, getMangledNativeFuncName(operandTypes),
+ operandTypes, op.getType());
+ assert(!failed(funcOpRes));
+ LLVM::LLVMFuncOp funcOp = funcOpRes.value();
+
+ auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+ op, funcOp, adaptor.getOperands());
+ // Preserve fastmath flags in our MLIR op when converting to llvm function
+ // calls, in order to allow further fastmath optimizations: We thus need to
+ // convert arith fastmath attrs into attrs recognized by llvm.
+ arith::AttrConvertFastMathToLLVM<Op, LLVM::CallOp> fastAttrConverter(op);
+ mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0];
+ callOp->setAttr(fastAttr.getName(), fastAttr.getValue());
+ return success();
+ }
+
+ inline bool isSPIRVCompatibleFloatOrVec(Type type) const {
+ if (type.isFloat())
+ return true;
+ if (auto vecType = dyn_cast<VectorType>(type)) {
+ if (!vecType.getElementType().isFloat())
+ return false;
+ // SPIRV distinguishes between vectors and matrices: OpenCL native math
+      // intrinsics are not compatible with matrices.
+ ArrayRef<int64_t> shape = vecType.getShape();
+ if (shape.size() != 1)
+ return false;
+ // SPIRV only allows vectors of size 2, 3, 4, 8, 16.
+ if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 ||
+ shape[0] == 16)
+ return true;
+ }
+ return false;
+ }
+
+ inline std::string
+ getMangledNativeFuncName(const ArrayRef<Type> operandTypes) const {
+ std::string mangledFuncName =
+ "_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str();
+
+ auto appendFloatToMangledFunc = [&mangledFuncName](Type type) {
+ if (type.isF32())
+ mangledFuncName += "f";
+ else if (type.isF16())
+ mangledFuncName += "Dh";
+ else if (type.isF64())
+ mangledFuncName += "d";
+ };
+
+ for (auto type : operandTypes) {
+ if (auto vecType = dyn_cast<VectorType>(type)) {
+ mangledFuncName += "Dv" + std::to_string(vecType.getShape()[0]) + "_";
+ appendFloatToMangledFunc(vecType.getElementType());
+ } else
+ appendFloatToMangledFunc(type);
+ }
+
+ return mangledFuncName;
+ }
+
+ const StringRef nativeFunc;
+};
+
+void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns,
+ bool convertArith) {
+ patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(patterns.getContext(),
+ "__spirv_ocl_native_exp");
+ patterns.add<ConvertNativeFuncPattern<math::CosOp>>(patterns.getContext(),
+ "__spirv_ocl_native_cos");
+ patterns.add<ConvertNativeFuncPattern<math::Exp2Op>>(
+ patterns.getContext(), "__spirv_ocl_native_exp2");
+ patterns.add<ConvertNativeFuncPattern<math::LogOp>>(patterns.getContext(),
+ "__spirv_ocl_native_log");
+ patterns.add<ConvertNativeFuncPattern<math::Log2Op>>(
+ patterns.getContext(), "__spirv_ocl_native_log2");
+ patterns.add<ConvertNativeFuncPattern<math::Log10Op>>(
+ patterns.getContext(), "__spirv_ocl_native_log10");
+ patterns.add<ConvertNativeFuncPattern<math::PowFOp>>(
+ patterns.getContext(), "__spirv_ocl_native_powr");
+ patterns.add<ConvertNativeFuncPattern<math::RsqrtOp>>(
+ patterns.getContext(), "__spirv_ocl_native_rsqrt");
+ patterns.add<ConvertNativeFuncPattern<math::SinOp>>(patterns.getContext(),
+ "__spirv_ocl_native_sin");
+ patterns.add<ConvertNativeFuncPattern<math::SqrtOp>>(
+ patterns.getContext(), "__spirv_ocl_native_sqrt");
+ patterns.add<ConvertNativeFuncPattern<math::TanOp>>(patterns.getContext(),
+ "__spirv_ocl_native_tan");
+ if (convertArith)
+ patterns.add<ConvertNativeFuncPattern<arith::DivFOp>>(
+ patterns.getContext(), "__spirv_ocl_native_divide");
+}
+
+namespace {
+struct ConvertMathToXeVMPass
+ : public impl::ConvertMathToXeVMBase<ConvertMathToXeVMPass> {
+ using Base::Base;
+ void runOnOperation() override;
+};
+} // namespace
+
+void ConvertMathToXeVMPass::runOnOperation() {
+ RewritePatternSet patterns(&getContext());
+ populateMathToXeVMConversionPatterns(patterns, convertArith);
+ ConversionTarget target(getContext());
+ target.addLegalDialect<BuiltinDialect, LLVM::LLVMDialect>();
+ if (failed(
+ applyPartialConversion(getOperation(), target, std::move(patterns))))
+ signalPassFailure();
+}
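Editor's note: `getMangledNativeFuncName` builds Itanium-style names for the OpenCL native builtins: `_Z`, the length of the base name, the name itself, then `f`/`Dh`/`d` for f32/f16/f64 operands and a `DvN_` prefix for N-element vectors. A small Python sketch reproducing the scheme for illustration (not the MLIR API):

```python
def mangle_native_func(name, operand_types):
    """operand_types: list like ['f32', 'v4xf16'] -> mangled name."""
    suffix = {"f32": "f", "f16": "Dh", "f64": "d"}
    out = f"_Z{len(name)}{name}"
    for ty in operand_types:
        if ty.startswith("v"):               # e.g. 'v4xf16'
            n, elem = ty[1:].split("x")
            out += f"Dv{n}_" + suffix[elem]
        else:
            out += suffix[ty]
    return out

print(mangle_native_func("__spirv_ocl_native_exp", ["f32"]))
# _Z22__spirv_ocl_native_expf
print(mangle_native_func("__spirv_ocl_native_exp", ["v4xf16"]))
# _Z22__spirv_ocl_native_expDv4_Dh
```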
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 2b7bdc9..11f866c 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -22,6 +22,7 @@
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
#include <cstdint>
#include <numeric>
@@ -110,9 +111,7 @@ static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
{TypeAttr::get(memrefType.getElementType())}));
IndexType indexType = builder.getIndexType();
- int64_t numElements = std::accumulate(memrefType.getShape().begin(),
- memrefType.getShape().end(), int64_t{1},
- std::multiplies<int64_t>());
+ int64_t numElements = llvm::product_of(memrefType.getShape());
emitc::ConstantOp numElementsValue = emitc::ConstantOp::create(
builder, loc, indexType, builder.getIndexAttr(numElements));
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index a5336ed..00df14b1 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1392,6 +1392,137 @@ public:
}
};
+// Collapse tensor<1xiN> into tensor<iN>
+// E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
+static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
+ Location loc) {
+ SmallVector<ReassociationExprs, 1> reassociation;
+ // Create the collapsed type
+ auto inputType = cast<RankedTensorType>(input.getType());
+ auto elemType = inputType.getElementType();
+ auto collapsedType = RankedTensorType::get({}, elemType);
+ // Emit the collapse op
+ return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, input,
+ reassociation);
+}
+
+static llvm::SmallVector<int8_t>
+convertToI8(const llvm::SmallVector<int32_t> &input) {
+ llvm::SmallVector<int8_t> output;
+ output.reserve(input.size());
+
+ for (auto v : llvm::map_range(
+ input, [](int32_t val) { return static_cast<int8_t>(val); })) {
+ output.push_back(v);
+ }
+ return output;
+}
+
+// The shift or multiplier may be either constant or non-constant, depending on
+// whether dynamic extension is enabled.
+// - If the shift or multiplier is non-constant, add it as an input to
+// linalg::GenericOp by:
+// 1. Pushing it into 'genericInputs'.
+// 2. Appending a corresponding affine map to 'indexingMaps'.
+// - If the shift or multiplier is constant, set 'constant' instead.
+static void setupLinalgGenericOpInputAndIndexingMap(
+ PatternRewriter &rewriter, llvm::SmallVector<int32_t> &values,
+ SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
+ bool isConstant, tosa::RescaleOp op, Value &constant, int64_t &arg,
+ bool isShift = false) {
+
+ auto loc = op.getLoc();
+ auto inputTy = cast<ShapedType>(op.getInput().getType());
+ unsigned rank = inputTy.getRank();
+ SmallVector<AffineExpr, 2> exprs = {rewriter.getAffineDimExpr(rank - 1)};
+
+ if (isConstant) {
+ // If we are rescaling per-channel then we need to store the
+ // values in a buffer.
+ if (values.size() == 1) {
+ IntegerAttr intAttr = isShift
+ ? rewriter.getI8IntegerAttr(values.front())
+ : rewriter.getI32IntegerAttr(values.front());
+ constant = rewriter.create<arith::ConstantOp>(loc, intAttr);
+ } else {
+ auto elementType =
+ isShift ? rewriter.getIntegerType(8) : rewriter.getI32Type();
+ auto tensorType = RankedTensorType::get(
+ {static_cast<int64_t>(values.size())}, elementType);
+ DenseIntElementsAttr EltAttr;
+ if (isShift)
+ EltAttr = DenseIntElementsAttr::get(tensorType, convertToI8(values));
+ else
+ EltAttr = DenseIntElementsAttr::get(tensorType, values);
+ genericInputs.push_back(
+ arith::ConstantOp::create(rewriter, loc, EltAttr));
+ indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+ /*symbolCount=*/0, exprs,
+ rewriter.getContext()));
+ }
+ } else {
+ // If we are not rescaling per-channel then we need to collapse 1xN to N
+ // and push broadcastMap.
+ auto operand = isShift ? op.getShift() : op.getMultiplier();
+ auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
+ if (tensorType && tensorType.hasStaticShape() &&
+ tensorType.getShape()[0] == 1) {
+ // broadcastMap = affine_map<(d0, d1) -> ()>
+      // It acts as a broadcast of the scalar value in linalg::GenericOp.
+ AffineMap broadcastMap =
+ AffineMap::get(rank, 0, {}, rewriter.getContext());
+ genericInputs.push_back(collapse1xNTensorToN(rewriter, operand, loc));
+ indexingMaps.push_back(broadcastMap);
+ } else {
+ genericInputs.push_back(operand);
+ indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+ /*symbolCount=*/0, exprs,
+ rewriter.getContext()));
+ }
+ }
+ arg = indexingMaps.size() - 1;
+}
+
+// Return the extended Zp to be used in subsequent arithmetic operations.
+static Value getExtendZp(OpBuilder &builder, Type valueTy,
+ FailureOr<int64_t> maybeZp, Location loc,
+ ValueRange blockArgs, int64_t zpArg,
+ bool isOutputZp = false) {
+ Value result;
+ const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
+ const uint32_t attrBitwidth =
+ isOutputZp ? 32 : (bitwidth > 32 ? bitwidth : 32);
+ auto extendType = builder.getIntegerType(attrBitwidth);
+ // The Zp value can be either constant or non-constant, depending on
+ // whether dynamic extension is enabled.
+ // If 'maybeZp' fails, it indicates that Zp is non-constant and will
+ // be passed as an input to linalg::GenericOp.
+ if (failed(maybeZp)) {
+ result = blockArgs[zpArg];
+ auto zpTy = result.getType();
+ if (zpTy.getIntOrFloatBitWidth() < attrBitwidth) {
+ // For ExtUIOp, the input must be signless.
+ // UnrealizedConversionCastOp will cast the input to signless type.
+ if (zpTy.isUnsignedInteger()) {
+ result =
+ UnrealizedConversionCastOp::create(
+ builder, loc,
+ builder.getIntegerType(zpTy.getIntOrFloatBitWidth()), result)
+ .getResult(0);
+ }
+ if (zpTy.isUnsignedInteger()) {
+ return builder.create<arith::ExtUIOp>(loc, extendType, result);
+ } else {
+ return builder.create<arith::ExtSIOp>(loc, extendType, result);
+ }
+ }
+ } else {
+ return builder.create<arith::ConstantOp>(
+ loc, IntegerAttr::get(extendType, *maybeZp));
+ }
+ return result;
+}
+
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
@@ -1423,40 +1554,46 @@ public:
}
}
- // The shift and multiplier values.
DenseElementsAttr shiftElems;
- if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
- return rewriter.notifyMatchFailure(
- op, "tosa.rescale requires constant shift input values");
+ bool isShiftConstant = false;
+ if (matchPattern(op.getShift(), m_Constant(&shiftElems)))
+ isShiftConstant = true;
DenseElementsAttr multiplierElems;
- if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
- return rewriter.notifyMatchFailure(
- op, "tosa.rescale requires constant multiplier input values");
-
- llvm::SmallVector<int8_t> shiftValues =
- llvm::to_vector(shiftElems.getValues<int8_t>());
- // explicit cast is required here
- llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
- llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
- [](IntegerAttr attr) -> int32_t {
- return static_cast<int32_t>(attr.getInt());
- }));
-
- // If we shift by more than the bitwidth, this just sets to 0.
- for (int i = 0, s = multiplierValues.size(); i < s; i++) {
- if (shiftValues[i] > 63) {
- shiftValues[i] = 0;
- multiplierValues[i] = 0;
+ bool isMultiplierConstant = false;
+ if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
+ isMultiplierConstant = true;
+
+ llvm::SmallVector<int32_t> shiftValues;
+ llvm::SmallVector<int32_t> multiplierValues;
+ bool doubleRound;
+
+ if (isMultiplierConstant && isShiftConstant) {
+ // explicit cast is required here
+ shiftValues = llvm::to_vector(llvm::map_range(
+ shiftElems.getValues<IntegerAttr>(), [](IntegerAttr attr) -> int32_t {
+ return static_cast<int32_t>(attr.getInt());
+ }));
+ multiplierValues = llvm::to_vector(
+ llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
+ [](IntegerAttr attr) -> int32_t {
+ return static_cast<int32_t>(attr.getInt());
+ }));
+
+ // If we shift by more than the bitwidth, this just sets to 0.
+ for (int i = 0, s = multiplierValues.size(); i < s; i++) {
+ if (shiftValues[i] > 63) {
+ shiftValues[i] = 0;
+ multiplierValues[i] = 0;
+ }
}
- }
+      // Double rounding only occurs if the shift is greater than 31; check
+      // whether that is ever the case.
+ doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
+ llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
+ } else
+ doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND;
- // Double round only occurs if shift is greater than 31, check that this
- // is ever true.
-
- bool doubleRound =
- op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
- llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
RoundingMode roundingMode =
doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;
@@ -1468,45 +1605,43 @@ public:
// values in a buffer.
Value multiplierConstant;
int64_t multiplierArg = 0;
- if (multiplierValues.size() == 1) {
- multiplierConstant = arith::ConstantOp::create(
- rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
- } else {
- SmallVector<AffineExpr, 2> multiplierExprs{
- rewriter.getAffineDimExpr(rank - 1)};
- auto multiplierType =
- RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
- rewriter.getI32Type());
- genericInputs.push_back(arith::ConstantOp::create(
- rewriter, loc,
- DenseIntElementsAttr::get(multiplierType, multiplierValues)));
-
- indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
- /*symbolCount=*/0, multiplierExprs,
- rewriter.getContext()));
-
- multiplierArg = indexingMaps.size() - 1;
- }
+ setupLinalgGenericOpInputAndIndexingMap(
+ rewriter, multiplierValues, genericInputs, indexingMaps,
+ isMultiplierConstant, op, multiplierConstant, multiplierArg);
// If we are rescaling per-channel then we need to store the shift
// values in a buffer.
Value shiftConstant;
int64_t shiftArg = 0;
- if (shiftValues.size() == 1) {
- shiftConstant = arith::ConstantOp::create(
- rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front()));
- } else {
- SmallVector<AffineExpr, 2> shiftExprs = {
- rewriter.getAffineDimExpr(rank - 1)};
- auto shiftType =
- RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
- rewriter.getIntegerType(8));
- genericInputs.push_back(arith::ConstantOp::create(
- rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
- indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
- /*symbolCount=*/0, shiftExprs,
- rewriter.getContext()));
- shiftArg = indexingMaps.size() - 1;
+ setupLinalgGenericOpInputAndIndexingMap(
+ rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
+ shiftConstant, shiftArg, true);
+
+ // broadcastMap = affine_map<(d0, d1) -> ()>
+    // It acts as a broadcast of the scalar value in linalg::GenericOp.
+ AffineMap broadcastMap = AffineMap::get(rank, 0, {}, rewriter.getContext());
+ FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
+ FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
+ // The inputZp and outputZp may be either constant or non-constant,
+ // depending on whether dynamic extension is enabled.
+    // - If the zp's are non-constant, add them as inputs to
+ // linalg::GenericOp by:
+ // 1. Pushing it into 'genericInputs'.
+ // 2. Appending a corresponding affine map to 'indexingMaps'.
+ // - If the zp's are constant, they would be generated as arith.constant.
+ int64_t iZpArg = 0;
+ if (failed(maybeIZp)) {
+ genericInputs.push_back(
+ collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
+ indexingMaps.push_back(broadcastMap);
+ iZpArg = indexingMaps.size() - 1;
+ }
+ int64_t oZpArg = 0;
+ if (failed(maybeOZp)) {
+ genericInputs.push_back(
+ collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
+ indexingMaps.push_back(broadcastMap);
+ oZpArg = indexingMaps.size() - 1;
}
// Indexing maps for output values.
@@ -1526,36 +1661,17 @@ public:
Type valueTy = value.getType();
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
- if (failed(maybeIZp)) {
- (void)rewriter.notifyMatchFailure(
- op, "input zero point cannot be statically determined");
- return;
- }
-
- const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
- // Extend zeropoint for sub-32bits widths.
- const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
- auto inputZp = arith::ConstantOp::create(
- nestedBuilder, loc,
- IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
- *maybeIZp));
+ auto inputZp = getExtendZp(nestedBuilder, valueTy, maybeIZp,
+ nestedLoc, blockArgs, iZpArg);
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
- if (failed(maybeOZp)) {
- (void)rewriter.notifyMatchFailure(
- op, "output zero point cannot be statically determined");
- return;
- };
+ auto outputZp = getExtendZp(nestedBuilder, valueTy, maybeOZp,
+ nestedLoc, blockArgs, oZpArg, true);
IntegerType outIntType =
cast<IntegerType>(blockArgs.back().getType());
unsigned outBitWidth = outIntType.getWidth();
- const int32_t outAttrBitwidth = 32;
assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
- auto outputZp = arith::ConstantOp::create(
- nestedBuilder, loc,
- IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
- *maybeOZp));
Value multiplier = multiplierConstant ? multiplierConstant
: blockArgs[multiplierArg];
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 802691c..9bf9ca3 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
#include <numeric>
@@ -70,8 +71,7 @@ TensorType inferReshapeExpandedType(TensorType inputType,
// Calculate the product of all elements in 'newShape' except for the -1
// placeholder, which we discard by negating the result.
- int64_t totalSizeNoPlaceholder = -std::accumulate(
- newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>());
+ int64_t totalSizeNoPlaceholder = -llvm::product_of(newShape);
// If there is a 0 component in 'newShape', resolve the placeholder as
// 0.
diff --git a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
index 79c2f23..245a3ef 100644
--- a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
+++ b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
@@ -20,6 +20,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"
#include <numeric>
@@ -265,8 +266,7 @@ loadStoreFromTransfer(PatternRewriter &rewriter,
if (isPacked)
src = collapseLastDim(rewriter, src);
int64_t rows = vecShape[0];
- int64_t cols = std::accumulate(vecShape.begin() + 1, vecShape.end(), 1,
- std::multiplies<int64_t>());
+ int64_t cols = llvm::product_of(vecShape.drop_front());
auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
@@ -336,8 +336,7 @@ static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
ArrayRef<int64_t> shape = vecTy.getShape();
int64_t rows = shape[0];
- int64_t cols = std::accumulate(shape.begin() + 1, shape.end(), 1,
- std::multiplies<int64_t>());
+ int64_t cols = llvm::product_of(shape.drop_front());
auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
return amx::TileLoadOp::create(rewriter, loc, tileType, buf,
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 5355909..41d8d53 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1723,17 +1723,18 @@ struct VectorBroadcastScalarToLowRankLowering
return success();
}
- // For 1-d vector, we additionally do a `vectorshuffle`.
auto v =
LLVM::InsertElementOp::create(rewriter, broadcast.getLoc(), vectorType,
poison, adaptor.getSource(), zero);
+ // For 1-d vector, we additionally do a `shufflevector`.
int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0);
SmallVector<int32_t> zeroValues(width, 0);
// Shuffle the value across the desired number of elements.
- rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(broadcast, v, poison,
- zeroValues);
+ auto shuffle = rewriter.createOrFold<LLVM::ShuffleVectorOp>(
+ broadcast.getLoc(), v, poison, zeroValues);
+ rewriter.replaceOp(broadcast, shuffle);
return success();
}
};
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index c45c45e..c9eba69 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -26,6 +26,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOSCF
@@ -760,8 +761,7 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
if (vectorType.getRank() != 1) {
// Flatten n-D vectors to 1D. This is done to allow indexing with a
// non-constant value.
- auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
- std::multiplies<int64_t>());
+ int64_t flatLength = llvm::product_of(shape);
auto flatVectorType =
VectorType::get({flatLength}, vectorType.getElementType());
value = vector::ShapeCastOp::create(rewriter, loc, flatVectorType, value);
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
index 84b2580..dd9edc4 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
@@ -21,6 +21,7 @@ add_mlir_conversion_library(MLIRXeGPUToXeVM
MLIRIndexDialect
MLIRSCFDialect
MLIRXeGPUDialect
+ MLIRXeGPUUtils
MLIRPass
MLIRTransforms
MLIRSCFTransforms
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 9ead1d8..fcbf66d 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -20,9 +20,12 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -61,6 +64,7 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
case xegpu::MemorySpace::SLM:
return static_cast<int>(xevm::AddrSpace::SHARED);
}
+ llvm_unreachable("Unknown XeGPU memory space");
}
// Get same bitwidth flat vector type of new element type.
@@ -184,6 +188,7 @@ class CreateNdDescToXeVMPattern
int64_t rank = mixedSizes.size();
if (rank != 2)
return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
+
auto sourceTy = source.getType();
auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
// If source is a memref, we need to extract the aligned pointer as index.
@@ -362,10 +367,11 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
// Add a builder that creates
// offset * elemByteSize + baseAddr
-static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
- Value baseAddr, Value offset, int64_t elemByteSize) {
+static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter,
+ Location loc, Value baseAddr, Value offset,
+ int64_t elemByteSize) {
Value byteSize = arith::ConstantIntOp::create(
- rewriter, loc, rewriter.getI64Type(), elemByteSize);
+ rewriter, loc, baseAddr.getType(), elemByteSize);
Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
return newAddr;
@@ -389,7 +395,8 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
 // Load result or Store value Type can be vector or scalar.
Type valOrResTy;
if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
- valOrResTy = op.getResult().getType();
+ valOrResTy =
+ this->getTypeConverter()->convertType(op.getResult().getType());
else
valOrResTy = adaptor.getValue().getType();
VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
@@ -440,7 +447,8 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
// If offset is provided, we add them to the base pointer.
// Offset is in number of elements, we need to multiply by
// element byte size.
- basePtrI64 = addOffset(rewriter, loc, basePtrI64, offset, elemByteSize);
+ basePtrI64 =
+ addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize);
}
// Convert base pointer (i64) to LLVM pointer type.
Value basePtrLLVM =
@@ -503,6 +511,147 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
}
};
+// Lower xegpu::CreateMemDescOp to memref::ViewOp. Since SLM access instructions
+// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than
+// 32 bits will be converted to 32 bits.
+class CreateMemDescOpPattern final
+ : public OpConversionPattern<xegpu::CreateMemDescOp> {
+public:
+ using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ auto resTy = op.getMemDesc();
+
+ // Create the result MemRefType with the same shape, element type, and
+ // memory space
+ auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
+
+ Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
+ auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
+ op.getSource(), zero, ValueRange());
+ rewriter.replaceOp(op, viewOp);
+ return success();
+ }
+};
+
+template <typename OpType,
+ typename = std::enable_if_t<llvm::is_one_of<
+ OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
+class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
+ using OpConversionPattern<OpType>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
+ if (offsets.empty())
+ return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
+
+ auto loc = op.getLoc();
+ auto ctxt = rewriter.getContext();
+ Value basePtrStruct = adaptor.getMemDesc();
+ Value mdescVal = op.getMemDesc();
+ // Load result or Store value Type can be vector or scalar.
+ Value data;
+ if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>)
+ data = op.getResult();
+ else
+ data = adaptor.getData();
+ VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
+ if (!valOrResVecTy)
+ valOrResVecTy = VectorType::get(1, data.getType());
+
+ int64_t elemBitWidth =
+ valOrResVecTy.getElementType().getIntOrFloatBitWidth();
+ // Element type must be multiple of 8 bits.
+ if (elemBitWidth % 8 != 0)
+ return rewriter.notifyMatchFailure(
+ op, "Expected element type bit width to be multiple of 8.");
+ int64_t elemByteSize = elemBitWidth / 8;
+
+ // Default memory space is SLM.
+ LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
+ ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));
+
+ auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
+
+ Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
+ rewriter, loc, basePtrStruct);
+
+ // Convert base pointer (ptr) to i32
+ Value basePtrI32 = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI32Type(), basePtrLLVM);
+
+ Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
+ linearOffset = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI32Type(), linearOffset);
+ basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset,
+ elemByteSize);
+
+ // convert base pointer (i32) to LLVM pointer type
+ basePtrLLVM =
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);
+
+ if (op.getSubgroupBlockIoAttr()) {
+ // If the 'subgroup_block_io' attribute is set, the op lowers to
+ // xevm.blockload / xevm.blockstore.
+
+ Type intElemTy = rewriter.getIntegerType(elemBitWidth);
+ VectorType intVecTy =
+ VectorType::get(valOrResVecTy.getShape(), intElemTy);
+
+ if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+ Value loadOp =
+ xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
+ if (intVecTy != valOrResVecTy) {
+ loadOp =
+ vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
+ }
+ rewriter.replaceOp(op, loadOp);
+ } else {
+ Value dataToStore = adaptor.getData();
+ if (valOrResVecTy != intVecTy) {
+ dataToStore =
+ vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
+ }
+ xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
+ nullptr);
+ rewriter.eraseOp(op);
+ }
+ return success();
+ }
+
+ if (valOrResVecTy.getNumElements() >= 1) {
+ auto chipOpt = xegpu::getChipStr(op);
+ if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) {
+ // the lowering for chunk load only works for pvc and bmg
+ return rewriter.notifyMatchFailure(
+ op, "The lowering is specific to pvc or bmg.");
+ }
+ }
+
+ if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+ // if the size of valOrResVecTy is 1, it lowers to a scalar load/store
+ // operation. LLVM load/store does not support vector of size 1, so we
+ // need to handle this case separately.
+ auto scalarTy = valOrResVecTy.getElementType();
+ LLVM::LoadOp loadOp;
+ if (valOrResVecTy.getNumElements() == 1)
+ loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
+ else
+ loadOp =
+ LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
+ rewriter.replaceOp(op, loadOp);
+ } else {
+ LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
+ rewriter.eraseOp(op);
+ }
+ return success();
+ }
+};
+
class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
@@ -545,8 +694,8 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
op, "Expected element type bit width to be multiple of 8.");
elemByteSize = elemBitWidth / 8;
}
- basePtrI64 =
- addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
+ basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets,
+ elemByteSize);
}
}
// Default memory space is global.
@@ -774,9 +923,7 @@ struct ConvertXeGPUToXeVMPass
if (rank < 1 || type.getNumElements() == 1)
return elemType;
// Otherwise, convert the vector to a flat vector type.
- int64_t sum =
- std::accumulate(type.getShape().begin(), type.getShape().end(),
- int64_t{1}, std::multiplies<int64_t>());
+ int64_t sum = llvm::product_of(type.getShape());
return VectorType::get(sum, elemType);
});
typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
@@ -785,6 +932,13 @@ struct ConvertXeGPUToXeVMPass
auto i32Type = IntegerType::get(&getContext(), 32);
return VectorType::get(8, i32Type);
});
+ // Convert MemDescType into flattened MemRefType for SLM
+ typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
+ Type elemTy = type.getElementType();
+ int numElems = type.getNumElements();
+ return MemRefType::get(numElems, elemTy, AffineMap(), 3);
+ });
+
typeConverter.addConversion([&](MemRefType type) -> Type {
// Convert MemRefType to i64 type.
return IntegerType::get(&getContext(), 64);
@@ -879,10 +1033,30 @@ struct ConvertXeGPUToXeVMPass
}
return {};
};
- typeConverter.addSourceMaterialization(memrefMaterializationCast);
- typeConverter.addSourceMaterialization(ui64MaterializationCast);
- typeConverter.addSourceMaterialization(ui32MaterializationCast);
- typeConverter.addSourceMaterialization(vectorMaterializationCast);
+
+ // If the result type of the original op is a single-element vector and the
+ // lowered type is a scalar, this materialization cast recreates the
+ // single-element vector by broadcasting the scalar value.
+ auto singleElementVectorMaterializationCast =
+ [](OpBuilder &builder, Type type, ValueRange inputs,
+ Location loc) -> Value {
+ if (inputs.size() != 1)
+ return {};
+ auto input = inputs.front();
+ if (input.getType().isIntOrIndexOrFloat()) {
+ // If the input is a scalar, and the target type is a vector of single
+ // element, create a single element vector by broadcasting.
+ if (auto vecTy = dyn_cast<VectorType>(type)) {
+ if (vecTy.getNumElements() == 1) {
+ return vector::BroadcastOp::create(builder, loc, vecTy, input)
+ .getResult();
+ }
+ }
+ }
+ return {};
+ };
+ typeConverter.addSourceMaterialization(
+ singleElementVectorMaterializationCast);
typeConverter.addTargetMaterialization(memrefMaterializationCast);
typeConverter.addTargetMaterialization(ui32MaterializationCast);
typeConverter.addTargetMaterialization(ui64MaterializationCast);
@@ -919,6 +1093,9 @@ void mlir::populateXeGPUToXeVMConversionPatterns(
LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
typeConverter, patterns.getContext());
+ patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
+ LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
+ CreateMemDescOpPattern>(typeConverter, patterns.getContext());
patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
patterns.getContext());
}
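
A note on the renamed addOffsetToBaseAddr helper used throughout this file: offsets carried by the XeGPU ops are counted in elements, so the helper scales them by the element byte size before adding them to the base address, now using the base address's own integer type (i64 for global memory, i32 for SLM). A plain-arithmetic sketch of what the emitted arith ops compute; this is purely illustrative and not the pattern code itself.

#include <cstdint>

// offsetInElements is scaled by the element size, then added to the base.
static uint64_t addOffsetToBaseAddr(uint64_t baseAddr,
                                    uint64_t offsetInElements,
                                    uint64_t elemByteSize) {
  return baseAddr + offsetInElements * elemByteSize;
}
// Example: base 0x1000, offset of 8 f16 elements (2 bytes each)
// -> 0x1000 + 8 * 2 = 0x1010.
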
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index f405d0c..61166db 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -339,6 +339,25 @@ void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
}
//===----------------------------------------------------------------------===//
+// ScaledExtPacked816Op
+//===----------------------------------------------------------------------===//
+LogicalResult ScaledExtPacked816Op::verify() {
+ int blockSize = getBlockSize();
+ assert((blockSize == 16 || blockSize == 32) && "invalid block size");
+ int firstScaleByte = getFirstScaleByte();
+ if (blockSize == 16 && !llvm::is_contained({0, 1}, firstScaleByte)) {
+ return emitOpError(
+ "blockSize of 16 can only have firstScaleByte be 0 or 1.");
+ }
+ if (blockSize == 32 && !llvm::is_contained({0, 2}, firstScaleByte)) {
+ return emitOpError(
+ "blockSize of 32 can only have firstScaleByte be 0 or 2.");
+ }
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// WMMAOp
//===----------------------------------------------------------------------===//
LogicalResult WMMAOp::verify() {
@@ -757,13 +776,13 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
offset = numElements - 4l;
}
Type scaleSrcElemType = scaleSrcType.getElementType();
- auto newSrcType = VectorType::get(SmallVector<int64_t>({numElements}),
- scaleSrcElemType);
+ auto newSrcType =
+ VectorType::get(ArrayRef{numElements}, scaleSrcElemType);
Value newScaleSrc =
vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
auto extract = vector::ExtractStridedSliceOp::create(
- rewriter, loc, newScaleSrc, ArrayRef<int64_t>{offset},
- ArrayRef<int64_t>{size}, ArrayRef<int64_t>{1});
+ rewriter, loc, newScaleSrc, ArrayRef{offset}, ArrayRef{size},
+ ArrayRef{int64_t(1)});
rewriter.modifyOpInPlace(op, [&] {
op->setOperand(opIdx, extract);
setOpsel(opIdx, opsel);
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index 68990ef..d9c097c 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -80,10 +80,22 @@ static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType,
LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)};
}
+/// Returns stride expressed in number of bytes for the given `elementStride`
+/// stride encoded in number of elements of the type `mType`.
+static Value computeStrideInBytes(Location loc, MemRefType mType,
+ Value elementStride, RewriterBase &rewriter) {
+ Type llvmInt64Type = rewriter.getIntegerType(64);
+ unsigned bytes = mType.getElementType().getIntOrFloatBitWidth() / 8;
+ auto attr = rewriter.getI64IntegerAttr(bytes);
+ Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
+ return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, elementStride)
+ .getResult();
+}
+
/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
/// shape may "envelop" the actual tile shape, and may be dynamically sized.
-static Value getStride(Location loc, MemRefType mType, Value base,
- RewriterBase &rewriter) {
+static Value inferStride(Location loc, MemRefType mType, Value base,
+ RewriterBase &rewriter) {
assert(mType.getRank() >= 2 && "Invalid shape for AMX strides");
int64_t preLast = mType.getRank() - 2;
Type llvmInt64Type = rewriter.getIntegerType(64);
@@ -94,11 +106,8 @@ static Value getStride(Location loc, MemRefType mType, Value base,
if (strides[preLast] == ShapedType::kDynamic) {
// Dynamic stride needs code to compute the stride at runtime.
MemRefDescriptor memrefDescriptor(base);
- auto attr = rewriter.getI64IntegerAttr(bytes);
- Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
- return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale,
- memrefDescriptor.stride(rewriter, loc, preLast))
- .getResult();
+ return computeStrideInBytes(
+ loc, mType, memrefDescriptor.stride(rewriter, loc, preLast), rewriter);
}
// Use direct constant for static stride.
auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
@@ -117,21 +126,39 @@ amx::TileZeroOp::getIntrinsicOperands(ArrayRef<Value> operands,
return getTileSizes(getLoc(), getTileType(), rewriter);
}
-LogicalResult amx::TileLoadOp::verify() {
- MemRefType memrefTy = getMemRefType();
+template <typename OpTy,
+ typename = std::enable_if_t<std::is_same_v<OpTy, amx::TileLoadOp> ||
+ std::is_same_v<OpTy, amx::TileStoreOp>>>
+static LogicalResult tileTransferVerifier(OpTy op) {
+ MemRefType memrefTy = op.getMemRefType();
unsigned rank = memrefTy.getRank();
- if (rank < 2)
- return emitOpError("requires at least 2D memref");
- if (getIndices().size() != rank)
- return emitOpError("requires ") << rank << " indices";
- SmallVector<int64_t> strides;
- int64_t offset;
- if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
- strides.back() != 1)
- return emitOpError("requires memref with unit innermost stride");
- return verifyTileSize(*this, getTileType());
+ if (op.getIndices().size() != rank)
+ return op.emitOpError("requires ") << rank << " indices";
+
+ if (failed(verifyTileSize(op, op.getTileType())))
+ return failure();
+
+ // Validate basic buffer properties when the stride is implicit.
+ if (!op.getStride()) {
+ if (rank < 2)
+ return op.emitOpError("requires at least 2D memref");
+ SmallVector<int64_t> strides;
+ int64_t offset;
+ if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
+ strides.back() != 1)
+ return op.emitOpError("requires memref with unit innermost stride");
+ }
+
+ return success();
+}
+
+void amx::TileLoadOp::build(OpBuilder &builder, OperationState &state, Type res,
+ Value base, ValueRange indices) {
+ build(builder, state, res, base, indices, /*stride=*/nullptr);
}
+LogicalResult amx::TileLoadOp::verify() { return tileTransferVerifier(*this); }
+
SmallVector<Value>
amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands,
const LLVMTypeConverter &typeConverter,
@@ -144,27 +171,23 @@ amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands,
intrinsicOperands.push_back(
LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
adaptor.getBase(), adaptor.getIndices()));
- intrinsicOperands.push_back(
- getStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
+ if (Value stride = adaptor.getStride())
+ intrinsicOperands.push_back(
+ computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
+ else
+ intrinsicOperands.push_back(
+ inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
return intrinsicOperands;
}
-LogicalResult amx::TileStoreOp::verify() {
- MemRefType memrefTy = getMemRefType();
- unsigned rank = memrefTy.getRank();
- if (rank < 2)
- return emitOpError("requires at least 2D memref");
- if (getIndices().size() != rank)
- return emitOpError("requires ") << rank << " indices";
- SmallVector<int64_t> strides;
- int64_t offset;
- if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
- strides.back() != 1)
- return emitOpError("requires memref with unit innermost stride");
- return verifyTileSize(*this, getTileType());
+void amx::TileStoreOp::build(OpBuilder &builder, OperationState &state,
+ Value base, ValueRange indices, Value val) {
+ build(builder, state, base, indices, val, /*stride=*/nullptr);
}
+LogicalResult amx::TileStoreOp::verify() { return tileTransferVerifier(*this); }
+
SmallVector<Value>
amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands,
const LLVMTypeConverter &typeConverter,
@@ -177,8 +200,12 @@ amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands,
intrinsicOperands.push_back(
LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
adaptor.getBase(), adaptor.getIndices()));
- intrinsicOperands.push_back(
- getStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
+ if (Value stride = adaptor.getStride())
+ intrinsicOperands.push_back(
+ computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
+ else
+ intrinsicOperands.push_back(
+ inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
intrinsicOperands.push_back(adaptor.getVal());
return intrinsicOperands;
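
The new explicit-stride path above converts a user-provided stride in elements into bytes, while the implicit path still derives the stride from the memref's penultimate dimension. A hedged worked example of the byte computation; the values are illustrative.

#include <cstdint>

// Mirrors computeStrideInBytes: bytes-per-element times stride-in-elements.
static int64_t strideInBytes(unsigned elemBitWidth, int64_t strideInElements) {
  unsigned bytesPerElement = elemBitWidth / 8;
  return static_cast<int64_t>(bytesPerElement) * strideInElements;
}
// Example: bf16 tile (16-bit elements) with a row stride of 64 elements
// -> 2 * 64 = 128 bytes passed to the AMX tile load/store intrinsic.
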
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 7e5ce26..e0a53cd 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -125,9 +125,9 @@ static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
// Use "unused attribute" marker to silence clang-tidy warning stemming from
// the inability to see through "llvm::TypeSwitch".
template <>
-bool LLVM_ATTRIBUTE_UNUSED remainsLegalAfterInline(AffineApplyOp op,
- Region *src, Region *dest,
- const IRMapping &mapping) {
+[[maybe_unused]] bool remainsLegalAfterInline(AffineApplyOp op, Region *src,
+ Region *dest,
+ const IRMapping &mapping) {
// If it's a valid dimension, we need to check that it remains so.
if (isValidDim(op.getResult(), src))
return remainsLegalAfterInline(
@@ -1032,8 +1032,8 @@ static void simplifyMinOrMaxExprWithOperands(AffineMap &map,
/// Simplify the map while exploiting information on the values in `operands`.
// Use "unused attribute" marker to silence warning stemming from the inability
// to see through the template expansion.
-static void LLVM_ATTRIBUTE_UNUSED
-simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) {
+[[maybe_unused]] static void simplifyMapWithOperands(AffineMap &map,
+ ArrayRef<Value> operands) {
assert(map.getNumInputs() == operands.size() && "invalid operands for map");
SmallVector<AffineExpr> newResults;
newResults.reserve(map.getNumResults());
@@ -1125,6 +1125,141 @@ static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp,
return success(*map != initialMap);
}
+/// Recursively traverse `e`. If `e` or one of its sub-expressions has the form
+/// e1 + e2 + ... + eK, where the e_i are a super(multi)set of `exprsToRemove`,
+/// record a mapping from that expression to `newVal` + sum({e1, e2, ..., eK} -
+/// exprsToRemove) in `replacementsMap`. If no entries were added to
+/// `replacementsMap`, nothing was found.
+static void shortenAddChainsContainingAll(
+ AffineExpr e, const llvm::SmallDenseSet<AffineExpr, 4> &exprsToRemove,
+ AffineExpr newVal, DenseMap<AffineExpr, AffineExpr> &replacementsMap) {
+ auto binOp = dyn_cast<AffineBinaryOpExpr>(e);
+ if (!binOp)
+ return;
+ AffineExpr lhs = binOp.getLHS();
+ AffineExpr rhs = binOp.getRHS();
+ if (binOp.getKind() != AffineExprKind::Add) {
+ shortenAddChainsContainingAll(lhs, exprsToRemove, newVal, replacementsMap);
+ shortenAddChainsContainingAll(rhs, exprsToRemove, newVal, replacementsMap);
+ return;
+ }
+ SmallVector<AffineExpr> toPreserve;
+ llvm::SmallDenseSet<AffineExpr, 4> ourTracker(exprsToRemove);
+ AffineExpr thisTerm = rhs;
+ AffineExpr nextTerm = lhs;
+
+ while (thisTerm) {
+ if (!ourTracker.erase(thisTerm)) {
+ toPreserve.push_back(thisTerm);
+ shortenAddChainsContainingAll(thisTerm, exprsToRemove, newVal,
+ replacementsMap);
+ }
+ auto nextBinOp = dyn_cast_if_present<AffineBinaryOpExpr>(nextTerm);
+ if (!nextBinOp || nextBinOp.getKind() != AffineExprKind::Add) {
+ thisTerm = nextTerm;
+ nextTerm = AffineExpr();
+ } else {
+ thisTerm = nextBinOp.getRHS();
+ nextTerm = nextBinOp.getLHS();
+ }
+ }
+ if (!ourTracker.empty())
+ return;
+ // We reverse the terms to be preserved here in order to preserve
+ // associativity between them.
+ AffineExpr newExpr = newVal;
+ for (AffineExpr preserved : llvm::reverse(toPreserve))
+ newExpr = newExpr + preserved;
+ replacementsMap.insert({e, newExpr});
+}
+
+/// If this map contains the expression `x_1 + x_2 * C_2 + ... + x_n * C_n`
+/// (not necessarily in that order), where the set of the `x_i` is the set of
+/// outputs of an `affine.delinearize_index` whose inverse is that expression,
+/// replace that expression with the input of that delinearize_index op.
+///
+/// `resultToReplace` is the delinearization result that was detected as the
+/// potential start of this replacement chain - if it isn't the rightmost
+/// result of the delinearization, this method fails. (This is intended to
+/// ensure we don't perform redundant scans over the same expression.)
+///
+/// While this currently only handles delinearizations with a constant basis,
+/// that isn't a fundamental limitation.
+///
+/// This is a utility function for `replaceDimOrSym` below.
+static LogicalResult replaceAffineDelinearizeIndexInverseExpression(
+ AffineDelinearizeIndexOp delinOp, Value resultToReplace, AffineMap *map,
+ SmallVectorImpl<Value> &dims, SmallVectorImpl<Value> &syms) {
+ if (!delinOp.getDynamicBasis().empty())
+ return failure();
+ if (resultToReplace != delinOp.getMultiIndex().back())
+ return failure();
+
+ MLIRContext *ctx = delinOp.getContext();
+ SmallVector<AffineExpr> resToExpr(delinOp.getNumResults(), AffineExpr());
+ for (auto [pos, dim] : llvm::enumerate(dims)) {
+ auto asResult = dyn_cast_if_present<OpResult>(dim);
+ if (!asResult)
+ continue;
+ if (asResult.getOwner() == delinOp.getOperation())
+ resToExpr[asResult.getResultNumber()] = getAffineDimExpr(pos, ctx);
+ }
+ for (auto [pos, sym] : llvm::enumerate(syms)) {
+ auto asResult = dyn_cast_if_present<OpResult>(sym);
+ if (!asResult)
+ continue;
+ if (asResult.getOwner() == delinOp.getOperation())
+ resToExpr[asResult.getResultNumber()] = getAffineSymbolExpr(pos, ctx);
+ }
+ if (llvm::is_contained(resToExpr, AffineExpr()))
+ return failure();
+
+ bool isDimReplacement = llvm::all_of(resToExpr, llvm::IsaPred<AffineDimExpr>);
+ int64_t stride = 1;
+ llvm::SmallDenseSet<AffineExpr, 4> expectedExprs;
+ // This isn't zip_equal since sometimes the delinearize basis is missing a
+ // size for the first result.
+ for (auto [binding, size] : llvm::zip(
+ llvm::reverse(resToExpr), llvm::reverse(delinOp.getStaticBasis()))) {
+ expectedExprs.insert(binding * getAffineConstantExpr(stride, ctx));
+ stride *= size;
+ }
+ if (resToExpr.size() != delinOp.getStaticBasis().size())
+ expectedExprs.insert(resToExpr[0] * stride);
+
+ DenseMap<AffineExpr, AffineExpr> replacements;
+ AffineExpr delinInExpr = isDimReplacement
+ ? getAffineDimExpr(dims.size(), ctx)
+ : getAffineSymbolExpr(syms.size(), ctx);
+
+ for (AffineExpr e : map->getResults())
+ shortenAddChainsContainingAll(e, expectedExprs, delinInExpr, replacements);
+ if (replacements.empty())
+ return failure();
+
+ AffineMap origMap = *map;
+ if (isDimReplacement)
+ dims.push_back(delinOp.getLinearIndex());
+ else
+ syms.push_back(delinOp.getLinearIndex());
+ *map = origMap.replace(replacements, dims.size(), syms.size());
+
+ // Blank out dead dimensions and symbols
+ for (AffineExpr e : resToExpr) {
+ if (auto d = dyn_cast<AffineDimExpr>(e)) {
+ unsigned pos = d.getPosition();
+ if (!map->isFunctionOfDim(pos))
+ dims[pos] = nullptr;
+ }
+ if (auto s = dyn_cast<AffineSymbolExpr>(e)) {
+ unsigned pos = s.getPosition();
+ if (!map->isFunctionOfSymbol(pos))
+ syms[pos] = nullptr;
+ }
+ }
+ return success();
+}
+
/// Replace all occurrences of AffineExpr at position `pos` in `map` by the
/// defining AffineApplyOp expression and operands.
/// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
@@ -1157,6 +1292,11 @@ static LogicalResult replaceDimOrSym(AffineMap *map,
syms);
}
+ if (auto delinOp = v.getDefiningOp<affine::AffineDelinearizeIndexOp>()) {
+ return replaceAffineDelinearizeIndexInverseExpression(delinOp, v, map, dims,
+ syms);
+ }
+
auto affineApply = v.getDefiningOp<AffineApplyOp>();
if (!affineApply)
return failure();
@@ -2460,6 +2600,65 @@ static LogicalResult foldLoopBounds(AffineForOp forOp) {
return success(folded);
}
+/// Returns constant trip count in trivial cases.
+static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
+ int64_t step = forOp.getStepAsInt();
+ if (!forOp.hasConstantBounds() || step <= 0)
+ return std::nullopt;
+ int64_t lb = forOp.getConstantLowerBound();
+ int64_t ub = forOp.getConstantUpperBound();
+ return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
+}
+
+/// Fold the empty loop.
+static SmallVector<OpFoldResult> AffineForEmptyLoopFolder(AffineForOp forOp) {
+ if (!llvm::hasSingleElement(*forOp.getBody()))
+ return {};
+ if (forOp.getNumResults() == 0)
+ return {};
+ std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
+ if (tripCount == 0) {
+ // The initial values of the iteration arguments would be the op's
+ // results.
+ return forOp.getInits();
+ }
+ SmallVector<Value, 4> replacements;
+ auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
+ auto iterArgs = forOp.getRegionIterArgs();
+ bool hasValDefinedOutsideLoop = false;
+ bool iterArgsNotInOrder = false;
+ for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
+ Value val = yieldOp.getOperand(i);
+ BlockArgument *iterArgIt = llvm::find(iterArgs, val);
+ // TODO: It should be possible to perform a replacement by computing the
+ // last value of the IV based on the bounds and the step.
+ if (val == forOp.getInductionVar())
+ return {};
+ if (iterArgIt == iterArgs.end()) {
+ // `val` is defined outside of the loop.
+ assert(forOp.isDefinedOutsideOfLoop(val) &&
+ "must be defined outside of the loop");
+ hasValDefinedOutsideLoop = true;
+ replacements.push_back(val);
+ } else {
+ unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
+ if (pos != i)
+ iterArgsNotInOrder = true;
+ replacements.push_back(forOp.getInits()[pos]);
+ }
+ }
+ // Bail out when the trip count is unknown and the loop returns any value
+ // defined outside of the loop or any iterArg out of order.
+ if (!tripCount.has_value() &&
+ (hasValDefinedOutsideLoop || iterArgsNotInOrder))
+ return {};
+ // Bail out when the loop iterates more than once and it returns any iterArg
+ // out of order.
+ if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
+ return {};
+ return llvm::to_vector_of<OpFoldResult>(replacements);
+}
+
/// Canonicalize the bounds of the given loop.
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
@@ -2491,79 +2690,30 @@ static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
return success();
}
-namespace {
-/// Returns constant trip count in trivial cases.
-static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
- int64_t step = forOp.getStepAsInt();
- if (!forOp.hasConstantBounds() || step <= 0)
- return std::nullopt;
- int64_t lb = forOp.getConstantLowerBound();
- int64_t ub = forOp.getConstantUpperBound();
- return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
+/// Returns true if the affine.for has zero iterations in trivial cases.
+static bool hasTrivialZeroTripCount(AffineForOp op) {
+ return getTrivialConstantTripCount(op) == 0;
}
-/// This is a pattern to fold trivially empty loop bodies.
-/// TODO: This should be moved into the folding hook.
-struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
- using OpRewritePattern<AffineForOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(AffineForOp forOp,
- PatternRewriter &rewriter) const override {
- // Check that the body only contains a yield.
- if (!llvm::hasSingleElement(*forOp.getBody()))
- return failure();
- if (forOp.getNumResults() == 0)
- return success();
- std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
- if (tripCount == 0) {
- // The initial values of the iteration arguments would be the op's
- // results.
- rewriter.replaceOp(forOp, forOp.getInits());
- return success();
- }
- SmallVector<Value, 4> replacements;
- auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
- auto iterArgs = forOp.getRegionIterArgs();
- bool hasValDefinedOutsideLoop = false;
- bool iterArgsNotInOrder = false;
- for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
- Value val = yieldOp.getOperand(i);
- auto *iterArgIt = llvm::find(iterArgs, val);
- // TODO: It should be possible to perform a replacement by computing the
- // last value of the IV based on the bounds and the step.
- if (val == forOp.getInductionVar())
- return failure();
- if (iterArgIt == iterArgs.end()) {
- // `val` is defined outside of the loop.
- assert(forOp.isDefinedOutsideOfLoop(val) &&
- "must be defined outside of the loop");
- hasValDefinedOutsideLoop = true;
- replacements.push_back(val);
- } else {
- unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
- if (pos != i)
- iterArgsNotInOrder = true;
- replacements.push_back(forOp.getInits()[pos]);
- }
- }
- // Bail out when the trip count is unknown and the loop returns any value
- // defined outside of the loop or any iterArg out of order.
- if (!tripCount.has_value() &&
- (hasValDefinedOutsideLoop || iterArgsNotInOrder))
- return failure();
- // Bail out when the loop iterates more than once and it returns any iterArg
- // out of order.
- if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
- return failure();
- rewriter.replaceOp(forOp, replacements);
- return success();
+LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
+ SmallVectorImpl<OpFoldResult> &results) {
+ bool folded = succeeded(foldLoopBounds(*this));
+ folded |= succeeded(canonicalizeLoopBounds(*this));
+ if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) {
+ // The initial values of the loop-carried variables (iter_args) are the
+ // results of the op. But this must be avoided for an affine.for op that
+ // does not return any results. Since ops that do not return results cannot
+ // be folded away, we would enter an infinite loop of folds on the same
+ // affine.for op.
+ results.assign(getInits().begin(), getInits().end());
+ folded = true;
}
-};
-} // namespace
-
-void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<AffineForEmptyLoopFolder>(context);
+ SmallVector<OpFoldResult> foldResults = AffineForEmptyLoopFolder(*this);
+ if (!foldResults.empty()) {
+ results.assign(foldResults);
+ folded = true;
+ }
+ return success(folded);
}
OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
@@ -2606,27 +2756,6 @@ void AffineForOp::getSuccessorRegions(
regions.push_back(RegionSuccessor(getResults()));
}
-/// Returns true if the affine.for has zero iterations in trivial cases.
-static bool hasTrivialZeroTripCount(AffineForOp op) {
- return getTrivialConstantTripCount(op) == 0;
-}
-
-LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
- SmallVectorImpl<OpFoldResult> &results) {
- bool folded = succeeded(foldLoopBounds(*this));
- folded |= succeeded(canonicalizeLoopBounds(*this));
- if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) {
- // The initial values of the loop-carried variables (iter_args) are the
- // results of the op. But this must be avoided for an affine.for op that
- // does not return any results. Since ops that do not return results cannot
- // be folded away, we would enter an infinite loop of folds on the same
- // affine.for op.
- results.assign(getInits().begin(), getInits().end());
- folded = true;
- }
- return success(folded);
-}
-
AffineBound AffineForOp::getLowerBound() {
return AffineBound(*this, getLowerBoundOperands(), getLowerBoundMap());
}
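
With the empty-loop canonicalization pattern folded into AffineForOp::fold, the guard is getTrivialConstantTripCount. A quick standalone sketch of that ceiling-division trip count; names here are illustrative.

#include <cstdint>
#include <optional>

// Same formula as getTrivialConstantTripCount for constant bounds and a
// positive step: ceil((ub - lb) / step), clamped at zero.
static std::optional<uint64_t> trivialTripCount(int64_t lb, int64_t ub,
                                                int64_t step) {
  if (step <= 0)
    return std::nullopt;
  return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
}
// Examples: lb=0, ub=10, step=3 -> 4 iterations; lb=5, ub=5, step=1 -> 0,
// which lets the fold replace the loop's results with its init values.
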
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index cd216ef..4743941 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -1357,7 +1357,7 @@ bool mlir::affine::isValidLoopInterchangePermutation(
/// Returns true if `loops` is a perfectly nested loop nest, where loops appear
/// in it from outermost to innermost.
-bool LLVM_ATTRIBUTE_UNUSED
+[[maybe_unused]] bool
mlir::affine::isPerfectlyNested(ArrayRef<AffineForOp> loops) {
assert(!loops.empty() && "no loops provided");
@@ -1920,8 +1920,7 @@ generatePointWiseCopy(Location loc, Value memref, Value fastMemRef,
return copyNestRoot;
}
-static InFlightDiagnostic LLVM_ATTRIBUTE_UNUSED
-emitRemarkForBlock(Block &block) {
+[[maybe_unused]] static InFlightDiagnostic emitRemarkForBlock(Block &block) {
return block.getParentOp()->emitRemark();
}
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index b1fc9aa..f54baff 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -351,9 +351,9 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
Value one = ConstantOp::create(builder, loc, resultType,
builder.getOneAttr(resultType));
ArithBuilder arithBuilder(builder, loc);
- return std::accumulate(
- values.begin(), values.end(), one,
- [&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); });
+ return llvm::accumulate(values, one, [&arithBuilder](Value acc, Value v) {
+ return arithBuilder.mul(acc, v);
+ });
}
/// Map strings to float types.
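
Same modernization as in the other files, but using the general llvm::accumulate helper because the reduction here multiplies SSA values through ArithBuilder rather than plain integers. A small sketch of the range-based accumulate shape, with plain integers and purely illustrative names.

#include <cstdint>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"

// llvm::accumulate(range, init, op) folds the range left-to-right, like
// std::accumulate but without spelling out begin()/end().
static int64_t productWithInit(llvm::ArrayRef<int64_t> values, int64_t init) {
  return llvm::accumulate(values, init,
                          [](int64_t acc, int64_t v) { return acc * v; });
}
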
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
index a50ddbe..bc17990 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
@@ -41,28 +41,37 @@ namespace bufferization {
using namespace mlir;
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
- func::ReturnOp returnOp;
+/// Get all the ReturnOps in the funcOp.
+static SmallVector<func::ReturnOp> getReturnOps(func::FuncOp funcOp) {
+ SmallVector<func::ReturnOp> returnOps;
for (Block &b : funcOp.getBody()) {
if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
- if (returnOp)
- return nullptr;
- returnOp = candidateOp;
+ returnOps.push_back(candidateOp);
}
}
- return returnOp;
+ return returnOps;
}
-/// Return the func::FuncOp called by `callOp`.
-static func::FuncOp getCalledFunction(CallOpInterface callOp) {
- SymbolRefAttr sym =
- llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
- if (!sym)
- return nullptr;
- return dyn_cast_or_null<func::FuncOp>(
- SymbolTable::lookupNearestSymbolFrom(callOp, sym));
+/// Get the operands at the specified position for all returnOps.
+static SmallVector<Value>
+getReturnOpsOperandInPos(ArrayRef<func::ReturnOp> returnOps, size_t pos) {
+ return llvm::map_to_vector(returnOps, [&](func::ReturnOp returnOp) {
+ return returnOp.getOperand(pos);
+ });
+}
+
+/// Check if all given values are the same buffer as the block argument (modulo
+/// cast ops).
+static bool operandsEqualFuncArgument(ArrayRef<Value> operands,
+ BlockArgument argument) {
+ for (Value val : operands) {
+ while (auto castOp = val.getDefiningOp<memref::CastOp>())
+ val = castOp.getSource();
+
+ if (val != argument)
+ return false;
+ }
+ return true;
}
LogicalResult
@@ -72,48 +81,55 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
DenseMap<func::FuncOp, DenseSet<func::CallOp>> callerMap;
// Collect the mapping of functions to their call sites.
module.walk([&](func::CallOp callOp) {
- if (func::FuncOp calledFunc = getCalledFunction(callOp)) {
- callerMap[calledFunc].insert(callOp);
+ if (func::FuncOp calledFunc =
+ dyn_cast_or_null<func::FuncOp>(callOp.resolveCallable())) {
+ if (!calledFunc.isPublic() && !calledFunc.isExternal())
+ callerMap[calledFunc].insert(callOp);
}
});
for (auto funcOp : module.getOps<func::FuncOp>()) {
- if (funcOp.isExternal())
+ if (funcOp.isExternal() || funcOp.isPublic())
continue;
- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
- // TODO: Support functions with multiple blocks.
- if (!returnOp)
+ SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+ if (returnOps.empty())
continue;
// Compute erased results.
- SmallVector<Value> newReturnValues;
- BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults());
+ size_t numReturnOps = returnOps.size();
+ size_t numReturnValues = funcOp.getFunctionType().getNumResults();
+ SmallVector<SmallVector<Value>> newReturnValues(numReturnOps);
+ BitVector erasedResultIndices(numReturnValues);
DenseMap<int64_t, int64_t> resultToArgs;
- for (const auto &it : llvm::enumerate(returnOp.getOperands())) {
+ for (size_t i = 0; i < numReturnValues; ++i) {
bool erased = false;
+ SmallVector<Value> returnOperands =
+ getReturnOpsOperandInPos(returnOps, i);
for (BlockArgument bbArg : funcOp.getArguments()) {
- Value val = it.value();
- while (auto castOp = val.getDefiningOp<memref::CastOp>())
- val = castOp.getSource();
-
- if (val == bbArg) {
- resultToArgs[it.index()] = bbArg.getArgNumber();
+ if (operandsEqualFuncArgument(returnOperands, bbArg)) {
+ resultToArgs[i] = bbArg.getArgNumber();
erased = true;
break;
}
}
if (erased) {
- erasedResultIndices.set(it.index());
+ erasedResultIndices.set(i);
} else {
- newReturnValues.push_back(it.value());
+ for (auto [newReturnValue, operand] :
+ llvm::zip(newReturnValues, returnOperands)) {
+ newReturnValue.push_back(operand);
+ }
}
}
// Update function.
if (failed(funcOp.eraseResults(erasedResultIndices)))
return failure();
- returnOp.getOperandsMutable().assign(newReturnValues);
+
+ for (auto [returnOp, newReturnValue] :
+ llvm::zip(returnOps, newReturnValues))
+ returnOp.getOperandsMutable().assign(newReturnValue);
// Update function calls.
for (func::CallOp callOp : callerMap[funcOp]) {
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 19eba6b..b5f8dda 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2460,8 +2460,7 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed,
<< dDim << ")";
scales[i] = eDim / dDim;
}
- if (std::accumulate(scales.begin(), scales.end(), 1,
- std::multiplies<int64_t>()) != warpSize)
+ if (llvm::product_of(scales) != warpSize)
return op->emitOpError()
<< "incompatible distribution dimensions from " << expandedVecType
<< " to " << distributedVecType << " with warp size = " << warpSize;
diff --git a/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt b/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
index 70a9c77..ec68acf 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRGPUPipelines
GPUToNVVMPipeline.cpp
+ GPUToXeVMPipeline.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
@@ -11,12 +12,17 @@ add_mlir_dialect_library(MLIRGPUPipelines
MLIRTransforms
MLIRLinalgTransforms
MLIRAffineToStandard
+ MLIRGPUToLLVMSPV
MLIRGPUToNVVMTransforms
MLIRIndexToLLVM
MLIRMathToLLVM
+ MLIRMathToXeVM
MLIRNVGPUToNVVM
MLIRNVVMToLLVM
MLIRReconcileUnrealizedCasts
MLIRSCFToControlFlow
MLIRVectorToSCF
+ MLIRXeGPUTransforms
+ MLIRXeGPUToXeVM
+ MLIRXeVMToLLVM
)
diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
new file mode 100644
index 0000000..1a1485b
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
@@ -0,0 +1,139 @@
+//===- GPUToXeVMPipeline.cpp - Lowering pipeline to XeVM/LLVM -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a generally usable lowering pipeline from GPU dialect
+// code to XeVM/LLVM. If XeGPU ops are used, they are expected to already be
+// embedded inside the gpu.module code.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
+#include "mlir/Conversion/Passes.h"
+#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
+#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
+#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
+#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Pipelines/Passes.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassOptions.h"
+#include "mlir/Target/LLVM/XeVM/Target.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+
+namespace {
+//===----------------------------------------------------------------------===//
+// Pre-GPU common pipeline for both Host and GPU.
+//===----------------------------------------------------------------------===//
+void buildPreGPUCommonPassPipeline(
+ OpPassManager &pm, const mlir::gpu::GPUToXeVMPipelineOptions &options) {
+ // builtin.module scope passes.
+ pm.addPass(createCSEPass());
+ pm.addPass(createConvertVectorToSCFPass());
+ {
+ GpuXeVMAttachTargetOptions xevmTargetOptions;
+ xevmTargetOptions.moduleMatcher = options.xevmModuleMatcher;
+ xevmTargetOptions.triple = options.zebinTriple;
+ xevmTargetOptions.chip = options.zebinChip;
+ xevmTargetOptions.optLevel = options.optLevel;
+ xevmTargetOptions.cmdOptions = options.cmdOptions;
+ pm.addPass(createGpuXeVMAttachTarget(xevmTargetOptions));
+ }
+ pm.addPass(createLowerAffinePass());
+ pm.addNestedPass<func::FuncOp>(createGpuAsyncRegionPass());
+}
+
+//===----------------------------------------------------------------------===//
+// GPUModule-specific stuff.
+//===----------------------------------------------------------------------===//
+void buildGPUPassPipeline(OpPassManager &pm,
+ const mlir::gpu::GPUToXeVMPipelineOptions &options) {
+ if (options.xegpuOpLevel == "workgroup") {
+ pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUWgToSgDistribute());
+ pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUBlocking());
+ pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
+ }
+ if (options.xegpuOpLevel == "subgroup" ||
+ options.xegpuOpLevel == "workgroup") {
+ pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUPropagateLayout());
+ pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUSubgroupDistribute());
+ pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(createLoopInvariantCodeMotionPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUVectorLinearize());
+ }
+ pm.addNestedPass<gpu::GPUModuleOp>(createConvertMathToXeVM());
+ pm.addNestedPass<gpu::GPUModuleOp>(createConvertXeGPUToXeVMPass());
+ {
+ ConvertGpuOpsToLLVMSPVOpsOptions gpuToLLVMSPVOptions;
+ gpuToLLVMSPVOptions.use64bitIndex = options.use64bitIndex;
+ pm.addNestedPass<gpu::GPUModuleOp>(
+ createConvertGpuOpsToLLVMSPVOps(gpuToLLVMSPVOptions));
+ }
+ pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(createReconcileUnrealizedCastsPass());
+}
+
+//===----------------------------------------------------------------------===//
+// Post-GPU pipeline for both Host and GPU.
+//===----------------------------------------------------------------------===//
+void buildPostGPUCommonPassPipeline(
+ OpPassManager &pm, const mlir::gpu::GPUToXeVMPipelineOptions &options) {
+ // builtin.module scope passes.
+ pm.addPass(createSCFToControlFlowPass());
+ pm.addPass(memref::createExpandStridedMetadataPass());
+ {
+ GpuToLLVMConversionPassOptions gpuToLLVMOptions;
+ gpuToLLVMOptions.hostBarePtrCallConv = options.hostBarePtrCallConv;
+ gpuToLLVMOptions.kernelBarePtrCallConv = options.kernelBarePtrCallConv;
+ pm.addPass(createGpuToLLVMConversionPass(gpuToLLVMOptions));
+ }
+ pm.addPass(createLowerAffinePass());
+ pm.addPass(createConvertToLLVMPass());
+ pm.addPass(createReconcileUnrealizedCastsPass());
+ // gpu-module-to-binary
+ {
+ GpuModuleToBinaryPassOptions gpuToModuleBinOptions;
+ gpuToModuleBinOptions.compilationTarget = options.binaryFormat;
+ gpuToModuleBinOptions.cmdOptions = options.cmdOptions;
+ pm.addPass(createGpuModuleToBinaryPass(gpuToModuleBinOptions));
+ }
+}
+} // namespace
+
+void mlir::gpu::buildLowerToXeVMPassPipeline(
+ OpPassManager &pm, const GPUToXeVMPipelineOptions &options) {
+ // Pre-GPU common pipelines.
+ buildPreGPUCommonPassPipeline(pm, options);
+
+ // GPUModule-specific stuff.
+ buildGPUPassPipeline(pm, options);
+
+ // Post-GPU pipeline for both Host and GPU.
+ buildPostGPUCommonPassPipeline(pm, options);
+}
+
+void mlir::gpu::registerGPUToXeVMPipeline() {
+ PassPipelineRegistration<GPUToXeVMPipelineOptions>(
+ "gpu-lower-to-xevm-pipeline",
+ "The default GPU to XeVM lowering pipeline. It starts by lowering GPU "
+ "code to the "
+ "specified compilation target (default is fatbin) then lowers the host "
+ "code.",
+ buildLowerToXeVMPassPipeline);
+}
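
For downstream tools, the pipeline becomes available once the registration hook is called by name as "gpu-lower-to-xevm-pipeline". A hypothetical minimal registration sketch; the tool structure is illustrative, only registerGPUToXeVMPipeline comes from this patch, and its declaration is assumed to live in mlir/Dialect/GPU/Pipelines/Passes.h alongside the NVVM pipeline.

#include "mlir/Dialect/GPU/Pipelines/Passes.h"

int main(int argc, char **argv) {
  // Makes "gpu-lower-to-xevm-pipeline" resolvable by name in pass pipelines
  // of an mlir-opt-like tool; the rest of the tool setup is omitted here.
  mlir::gpu::registerGPUToXeVMPipeline();
  (void)argc;
  (void)argv;
  return 0;
}
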
diff --git a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
index 88f531f..572b746 100644
--- a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
+++ b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
#include <numeric>
@@ -118,8 +119,7 @@ bool WarpDistributionPattern::delinearizeLaneId(
return false;
sizes.push_back(large / small);
}
- if (std::accumulate(sizes.begin(), sizes.end(), 1,
- std::multiplies<int64_t>()) != warpSize)
+ if (llvm::product_of(sizes) != warpSize)
return false;
AffineExpr s0, s1;
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index ec581ac..cc66fac 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -8,11 +8,13 @@ add_mlir_dialect_library(MLIRLLVMDialect
IR/LLVMMemorySlot.cpp
IR/LLVMTypes.cpp
IR/LLVMTypeSyntax.cpp
+ IR/LLVMDialectBytecode.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR
DEPENDS
+ MLIRLLVMDialectBytecodeIncGen
MLIRLLVMOpsIncGen
MLIRLLVMTypesIncGen
MLIRLLVMIntrinsicOpsIncGen
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 5d08ccc..3eae67f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -29,6 +29,8 @@
#include "llvm/IR/DataLayout.h"
#include "llvm/Support/Error.h"
+#include "LLVMDialectBytecode.h"
+
#include <numeric>
#include <optional>
@@ -2824,6 +2826,20 @@ LogicalResult ShuffleVectorOp::verify() {
return success();
}
+// Folding for shufflevector op when v1 is single element 1D vector
+// and the mask is a single zero. OpFoldResult will be v1 in this case.
+OpFoldResult ShuffleVectorOp::fold(FoldAdaptor adaptor) {
+ // Check if operand 0 is a single element vector.
+ auto vecType = llvm::dyn_cast<VectorType>(getV1().getType());
+ if (!vecType || vecType.getRank() != 1 || vecType.getNumElements() != 1)
+ return {};
+ // Check if the mask is a single zero.
+ // Note: The mask is guaranteed to be non-empty.
+ if (getMask().size() != 1 || getMask()[0] != 0)
+ return {};
+ return getV1();
+}
+
//===----------------------------------------------------------------------===//
// Implementations for LLVM::LLVMFuncOp.
//===----------------------------------------------------------------------===//
@@ -4237,6 +4253,7 @@ void LLVMDialect::initialize() {
// Support unknown operations because not all LLVM operations are registered.
allowUnknownOperations();
declarePromisedInterface<DialectInlinerInterface, LLVMDialect>();
+ detail::addBytecodeInterface(this);
}
#define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp
new file mode 100644
index 0000000..41d1f80
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp
@@ -0,0 +1,154 @@
+//===- LLVMDialectBytecode.cpp - LLVM Bytecode Implementation -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "LLVMDialectBytecode.h"
+#include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <type_traits>
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+namespace {
+
+// Provide some forward declarations of the functions that will be generated by
+// the include below.
+static void write(DIExpressionElemAttr attribute,
+ DialectBytecodeWriter &writer);
+static LogicalResult writeAttribute(Attribute attribute,
+ DialectBytecodeWriter &writer);
+
+//===--------------------------------------------------------------------===//
+// Optional ArrayRefs
+//
+// Note that both the writer and reader functions treat the arrays as
+// optional, since the stored array may be empty.
+//===--------------------------------------------------------------------===//
+
+template <class EntryTy>
+static void writeOptionalArrayRef(DialectBytecodeWriter &writer,
+ ArrayRef<EntryTy> storage) {
+ if (storage.empty()) {
+ writer.writeOwnedBool(false);
+ return;
+ }
+
+ writer.writeOwnedBool(true);
+ writer.writeList(storage, [&](EntryTy val) {
+ if constexpr (std::is_base_of_v<Attribute, EntryTy>) {
+ (void)writer.writeOptionalAttribute(val);
+ } else if constexpr (std::is_integral_v<EntryTy>) {
+ (void)writer.writeVarInt(val);
+ } else {
+ static_assert(true, "EntryTy not supported");
+ }
+ });
+}
+
+template <class EntryTy>
+static LogicalResult readOptionalArrayRef(DialectBytecodeReader &reader,
+ SmallVectorImpl<EntryTy> &storage) {
+ bool isPresent = false;
+ if (failed(reader.readBool(isPresent)))
+ return failure();
+ // Nothing to do here, the array is empty.
+ if (!isPresent)
+ return success();
+
+ auto readEntry = [&]() -> FailureOr<EntryTy> {
+ EntryTy temp;
+ if constexpr (std::is_base_of_v<Attribute, EntryTy>) {
+ if (succeeded(reader.readOptionalAttribute(temp)))
+ return temp;
+ } else if constexpr (std::is_integral_v<EntryTy>) {
+ if (succeeded(reader.readVarInt(temp)))
+ return temp;
+ } else {
+ static_assert(true, "EntryTy not supported");
+ }
+ return failure();
+ };
+
+ return reader.readList(storage, readEntry);
+}
+
+//===--------------------------------------------------------------------===//
+// Optional integral types
+//===--------------------------------------------------------------------===//
+
+template <class EntryTy>
+static void writeOptionalInt(DialectBytecodeWriter &writer,
+ std::optional<EntryTy> storage) {
+ static_assert(std::is_integral_v<EntryTy>,
+ "EntryTy must be an integral type");
+ EntryTy val = storage.value_or(0);
+ writer.writeVarIntWithFlag(val, storage.has_value());
+}
+
+template <class EntryTy>
+static LogicalResult readOptionalInt(DialectBytecodeReader &reader,
+ std::optional<EntryTy> &storage) {
+ static_assert(std::is_integral_v<EntryTy>,
+ "EntryTy must be an integral type");
+ uint64_t result = 0;
+ bool flag = false;
+ if (failed(reader.readVarIntWithFlag(result, flag)))
+ return failure();
+ if (flag)
+ storage = static_cast<EntryTy>(result);
+ else
+ storage = std::nullopt;
+ return success();
+}
+
+//===--------------------------------------------------------------------===//
+// Tablegen generated bytecode functions
+//===--------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/LLVMDialectBytecode.cpp.inc"
+
+//===--------------------------------------------------------------------===//
+// LLVMDialectBytecodeInterface
+//===--------------------------------------------------------------------===//
+
+/// This class implements the bytecode interface for the LLVM dialect.
+struct LLVMDialectBytecodeInterface : public BytecodeDialectInterface {
+ LLVMDialectBytecodeInterface(Dialect *dialect)
+ : BytecodeDialectInterface(dialect) {}
+
+ // Attributes
+ Attribute readAttribute(DialectBytecodeReader &reader) const override {
+ return ::readAttribute(getContext(), reader);
+ }
+
+ LogicalResult writeAttribute(Attribute attr,
+ DialectBytecodeWriter &writer) const override {
+ return ::writeAttribute(attr, writer);
+ }
+
+ // Types
+ Type readType(DialectBytecodeReader &reader) const override {
+ return ::readType(getContext(), reader);
+ }
+
+ LogicalResult writeType(Type type,
+ DialectBytecodeWriter &writer) const override {
+ return ::writeType(type, writer);
+ }
+};
+} // namespace
+
+void LLVM::detail::addBytecodeInterface(LLVMDialect *dialect) {
+ dialect->addInterfaces<LLVMDialectBytecodeInterface>();
+}
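
Conceptually, the optional encodings above pair every payload with a presence flag: an optional integer becomes a single varint-with-flag, and an optional array becomes a leading bool followed by the list only when non-empty. A toy model of the round trip, deliberately not using the real DialectBytecodeReader/Writer API.

#include <cstdint>
#include <optional>
#include <utility>

// {value, present}: mirrors writeVarIntWithFlag / readVarIntWithFlag.
using FlaggedVarInt = std::pair<uint64_t, bool>;

static FlaggedVarInt encodeOptional(std::optional<uint32_t> v) {
  return {v.value_or(0), v.has_value()};
}

static std::optional<uint32_t> decodeOptional(FlaggedVarInt f) {
  if (!f.second)
    return std::nullopt;
  return static_cast<uint32_t>(f.first);
}
// encodeOptional(std::nullopt) -> {0, false}; decodeOptional round-trips it.
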
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h
new file mode 100644
index 0000000..1a17cb4
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h
@@ -0,0 +1,27 @@
+//===- LLVMDialectBytecode.h - LLVM Bytecode Implementation -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header defines hooks into the LLVM dialect bytecode
+// implementation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LIB_MLIR_DIALECT_LLVM_IR_LLVMDIALECTBYTECODE_H
+#define LIB_MLIR_DIALECT_LLVM_IR_LLVMDIALECTBYTECODE_H
+
+namespace mlir::LLVM {
+class LLVMDialect;
+
+namespace detail {
+/// Add the interfaces necessary for encoding the LLVM dialect components in
+/// bytecode.
+void addBytecodeInterface(LLVMDialect *dialect);
+} // namespace detail
+} // namespace mlir::LLVM
+
+#endif // LIB_MLIR_DIALECT_LLVM_IR_LLVMDIALECTBYTECODE_H
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 01a16ce..ac35eea 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -134,10 +134,10 @@ static void printExtTypeParams(AsmPrinter &p, ArrayRef<Type> typeParams,
/// These are unused for now.
/// TODO: Move over to these once more types have been migrated to TypeDef.
-LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
+[[maybe_unused]] static OptionalParseResult
generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
-LLVM_ATTRIBUTE_UNUSED static LogicalResult
-generatedTypePrinter(Type def, AsmPrinter &printer);
+[[maybe_unused]] static LogicalResult generatedTypePrinter(Type def,
+ AsmPrinter &printer);
#include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc"
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 5edcc40b..2a8c330 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -309,6 +309,17 @@ LogicalResult ConvertBF16x2ToF8x2Op::verify() {
return success();
}
+LogicalResult ConvertF32x2ToF4x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
+ return emitOpError("Only ")
+ << mlir::Float4E2M1FNType::get(ctx)
+ << " type is supported for conversions from f32x2 to f4x2.";
+
+ return success();
+}
+
LogicalResult BulkStoreOp::verify() {
if (getInitVal() != 0)
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -787,6 +798,26 @@ LogicalResult MmaOp::verify() {
" attribute");
}
+ // Validate layout combinations. According to the operation description, most
+ // MMA operations require layoutA=row and layoutB=col. Only m8n8k4 with f16
+ // can use other layout combinations.
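+  // For example, m16n8k16 with f16 operands must use layoutA=row and
+  // layoutB=col, whereas m8n8k4 with f16 also accepts row/row, col/row and
+  // col/col.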
+ bool isM8N8K4_F16 =
+ (mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 &&
+ getMultiplicandAPtxType() == MMATypes::f16);
+
+ if (!isM8N8K4_F16) {
+ // For all other shapes/types, layoutA must be row and layoutB must be col
+ if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) {
+ return emitOpError("requires layoutA = #nvvm.mma_layout<row> and "
+ "layoutB = #nvvm.mma_layout<col> for shape <")
+ << mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2]
+ << "> with element types "
+ << stringifyEnum(*getMultiplicandAPtxType()) << " and "
+ << stringifyEnum(*getMultiplicandBPtxType())
+ << ". Only m8n8k4 with f16 supports other layouts.";
+ }
+ }
+
return success();
}
@@ -2047,6 +2078,23 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
}
}
+NVVM::IDArgPair
+ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
+ LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(op.getA()));
+ args.push_back(mt.lookupValue(op.getB()));
+
+ bool hasRelu = op.getRelu();
+
+ llvm::Intrinsic::ID intId =
+ hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
+ : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
+
+ return {intId, std::move(args)};
+}
+
#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
: llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
@@ -2306,6 +2354,32 @@ static void nvvmInferResultRanges(Operation *op, Value result,
}
}
+/// Verify the range attribute satisfies LLVM ConstantRange constructor
+/// requirements for NVVM SpecialRangeableRegisterOp.
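+/// For example, a range attribute with lower == upper == 5 on i32 is
+/// rejected: lower may equal upper only at the type's minimum or maximum
+/// value, the encodings ConstantRange reserves for the empty and full sets.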
+static LogicalResult
+verifyConstantRangeAttr(Operation *op,
+ std::optional<LLVM::ConstantRangeAttr> rangeAttr) {
+ if (!rangeAttr)
+ return success();
+
+ const llvm::APInt &lower = rangeAttr->getLower();
+ const llvm::APInt &upper = rangeAttr->getUpper();
+
+ // Check LLVM ConstantRange constructor condition
+ if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
+ unsigned bitWidth = lower.getBitWidth();
+ llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth);
+ llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth);
+ return op->emitOpError(
+ "invalid range attribute: Lower == Upper, but they aren't min (")
+ << llvm::toString(minVal, 10, false) << ") or max ("
+ << llvm::toString(maxVal, 10, false)
+ << ") value! This is an invalid constant range.";
+ }
+
+ return success();
+}
+
static llvm::Value *getAsPackedI32(llvm::Value *arg,
llvm::IRBuilderBase &builder) {
return builder.CreateBitCast(arg,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
index 17371ec..6d54bb6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
@@ -23,6 +23,7 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
+#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
@@ -180,6 +181,15 @@ void RawBufferAtomicUMinOp::print(mlir::OpAsmPrinter &p) {
// ROCDLDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
+namespace {
+struct ROCDLInlinerInterface final : DialectInlinerInterface {
+ using DialectInlinerInterface::DialectInlinerInterface;
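+  /// All operations within the ROCDL dialect are legal to inline.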
+ bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
+ return true;
+ }
+};
+} // namespace
+
// TODO: This should be the llvm.rocdl dialect once this is supported.
void ROCDLDialect::initialize() {
addOperations<
@@ -194,6 +204,7 @@ void ROCDLDialect::initialize() {
// Support unknown operations because not all ROCDL operations are registered.
allowUnknownOperations();
+ addInterfaces<ROCDLInlinerInterface>();
declarePromisedInterface<gpu::TargetAttrInterface, ROCDLTargetAttr>();
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index c477c6c..dcc1ef9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -315,7 +315,8 @@ bool mlir::linalg::detail::isContractionBody(
Value yielded = getSourceSkipUnary(terminator->getOperand(0));
Operation *reductionOp = yielded.getDefiningOp();
- if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) {
+ if (!reductionOp || reductionOp->getNumResults() != 1 ||
+ reductionOp->getNumOperands() != 2) {
errs << "expected reduction op to be binary";
return false;
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 59013a2..cbc565b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5272,11 +5272,18 @@ ArrayRef<int64_t> PackOp::getAllOuterDims() {
SmallVector<int64_t> PackOp::getTiledOuterDims() {
auto innerDimsPos = getInnerDimsPos();
- auto packedShape = getDestType().getShape();
+ SmallVector<int64_t> outerDims(getAllOuterDims());
SmallVector<int64_t> res;
+ // Recover the original order of the outer dims.
+ SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
+ invertPermutationVector(outerDimPermInv);
+ if (!outerDimPermInv.empty())
+ applyPermutationToVector(outerDims, outerDimPermInv);
+
+  // Collect the outer dims corresponding to the tiled inner dims.
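+  // E.g. for a pack of tensor<16x32xf32> with inner_dims_pos = [1],
+  // inner_tiles = [8] and outer_dims_perm = [1, 0] (dest tensor<4x16x8xf32>),
+  // the outer dims in original order are [16, 4] and this returns [4].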
for (auto index : innerDimsPos)
- res.push_back(packedShape[index]);
+ res.push_back(outerDims[index]);
return res;
}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index dd9b4c2..9a8a63e 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -576,6 +576,86 @@ transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
// FuseOp
//===----------------------------------------------------------------------===//
+void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
+ TypeRange loopTypes, Value target,
+ ArrayRef<int64_t> staticTileSizes,
+ ArrayRef<int64_t> staticTileInterchange,
+ bool applyCleanup, bool useForall) {
+ return build(
+ builder, result, loopTypes,
+ /*target=*/target,
+ /*mixedTileSizes=*/
+ getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
+ /*mixedTileInterchange=*/
+ getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)),
+ applyCleanup, useForall);
+}
+
+void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
+ Value target, ArrayRef<int64_t> staticTileSizes,
+ ArrayRef<int64_t> staticTileInterchange,
+ bool applyCleanup, bool useForall) {
+ return build(
+ builder, result,
+ /*target=*/target,
+ /*mixedTileSizes=*/
+ getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
+ /*mixedTileInterchange=*/
+ getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)),
+ applyCleanup, useForall);
+}
+
+void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
+ Value target,
+ ArrayRef<OpFoldResult> mixedTileSizes,
+ ArrayRef<OpFoldResult> mixedTileInterchange,
+ bool applyCleanup, bool useForall) {
+  // Loop types are automatically splat by the callee; setting up one is
+  // enough.
+ SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
+ build(builder, result, loopTypes, target, mixedTileSizes,
+ mixedTileInterchange, applyCleanup, useForall);
+}
+
+void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
+ TypeRange loopTypes, Value target,
+ ArrayRef<OpFoldResult> mixedTileSizes,
+ ArrayRef<OpFoldResult> mixedTileInterchange,
+ bool applyCleanup, bool useForall) {
+ SmallVector<int64_t> staticTileSizes;
+ SmallVector<Value> dynamicTileSizes;
+ dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
+ SmallVector<int64_t> staticTileInterchange;
+ SmallVector<Value> dynamicTileInterchange;
+ dispatchIndexOpFoldResults(mixedTileInterchange, dynamicTileInterchange,
+ staticTileInterchange);
+ // Call the default builder which sets up the proper operands segment sizes
+ // attributes for multiple variadic operands. In the absence of this,
+ // horrible bugs ensue.
+ auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
+ auto staticTileInterchangeAttr =
+ builder.getDenseI64ArrayAttr(staticTileInterchange);
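+  // One loop handle is produced per non-zero tile size when tiling with
+  // scf.for; scf.forall always produces a single loop handle.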
+ unsigned numExpectedLoops =
+ useForall ? 1 : staticTileSizes.size() - llvm::count(staticTileSizes, 0);
+ SmallVector<Type> resultTypes;
+ resultTypes.reserve(numExpectedLoops);
+ assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
+ "expected one loop type or as many as loops");
+ if (loopTypes.size() == 1)
+ resultTypes.append(numExpectedLoops, loopTypes[0]);
+ else
+ llvm::append_range(resultTypes, loopTypes);
+ build(builder, result, /*transformed=*/target.getType(),
+ /*loops=*/resultTypes,
+ /*target=*/target,
+ /*tile_sizes=*/dynamicTileSizes,
+ /*tile_interchange=*/dynamicTileInterchange,
+ /*static_tile_sizes=*/staticTileSizesAttr,
+ /*static_tile_interchange=*/staticTileInterchangeAttr,
+ /*apply_cleanup=*/applyCleanup,
+ /*use_forall=*/useForall);
+}
+
/// Apply a tiling transformation to all payload ops and store both the
/// tiled operation as well as the created tile loops.
template <typename Range>
@@ -630,13 +710,25 @@ DiagnosedSilenceableFailure
transform::FuseOp::apply(transform::TransformRewriter &rewriter,
mlir::transform::TransformResults &transformResults,
mlir::transform::TransformState &state) {
- SmallVector<int64_t> tileSizes =
- extractFromIntegerArrayAttr<int64_t>(getTileSizes());
- SmallVector<int64_t> tileInterchange =
- extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
+ auto transformOp = cast<TransformOpInterface>(getOperation());
+
+ SmallVector<int64_t> tileSizes;
+ DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
+ state, transformOp, getMixedTileSizes(), tileSizes);
+ if (!status.succeeded())
+ return status;
+ SmallVector<int64_t> tileInterchange;
+ status = reifyMixedParamAndHandleResults(
+ state, transformOp, getMixedTileInterchange(), tileInterchange);
+ if (!status.succeeded())
+ return status;
scf::SCFTilingOptions tilingOptions;
tilingOptions.interchangeVector = tileInterchange;
+ bool useForall = getUseForall();
+ tilingOptions.setLoopType(useForall
+ ? scf::SCFTilingOptions::LoopType::ForallOp
+ : scf::SCFTilingOptions::LoopType::ForOp);
SmallVector<OpFoldResult> tileSizesOfr =
getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
@@ -652,9 +744,11 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
tileAndFuseOptions.cleanupPatterns = std::move(patterns);
}
+ size_t numLoops =
+ useForall ? 1 : tileSizes.size() - llvm::count(tileSizes, 0);
LogicalResult result = applyTilingToAll(
- rewriter, getOperation(), state.getPayloadOps(getTarget()),
- tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
+ rewriter, getOperation(), state.getPayloadOps(getTarget()), numLoops,
+ transformResults,
[&](TilingInterface tilingInterfaceOp)
-> FailureOr<scf::SCFTileAndFuseResult> {
return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
@@ -665,24 +759,51 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
}
LogicalResult transform::FuseOp::verify() {
- SmallVector<int64_t> permutation =
- extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
- auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
- if (!std::is_permutation(sequence.begin(), sequence.end(),
- permutation.begin(), permutation.end())) {
- return emitOpError() << "expects interchange to be a permutation, found "
- << getTileInterchange();
+  auto iterspaceRank = getStaticTileSizes().size();
+  ArrayRef<int64_t> permutation = getStaticTileInterchange();
+  if (permutation.size() > iterspaceRank)
+    return emitOpError()
+           << "interchange length exceeds iteration space dimensions ("
+           << iterspaceRank << "), found " << getTileInterchange();
+  SmallVector<bool> seen(iterspaceRank, false);
+  for (int64_t v : permutation) {
+    if (!ShapedType::isDynamic(v)) {
+      if (v < 0 || v >= static_cast<int64_t>(iterspaceRank))
+        return emitOpError() << "expects interchange values to be in range [0, "
+                             << iterspaceRank << "), found: " << v;
+ if (seen[v])
+ return emitOpError() << "found duplicate interchange value: " << v;
+ seen[v] = true;
+ }
}
- SmallVector<int64_t> sizes =
- extractFromIntegerArrayAttr<int64_t>(getTileSizes());
- size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
+ ArrayRef<int64_t> sizes = getStaticTileSizes();
+ size_t numExpectedLoops =
+ getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0);
if (numExpectedLoops != getNumResults() - 1)
return emitOpError() << "expects " << numExpectedLoops << " loop results";
return success();
}
+SmallVector<OpFoldResult> transform::FuseOp::getMixedTileSizes() {
+ return getMixedValues(getStaticTileSizes(), getTileSizes(), getContext());
+}
+
+SmallVector<OpFoldResult> transform::FuseOp::getMixedTileInterchange() {
+ return getMixedValues(getStaticTileInterchange(), getTileInterchange(),
+ getContext());
+}
+
+void transform::FuseOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ consumesHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getTileSizesMutable(), effects);
+ onlyReadsHandle(getTileInterchangeMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
+ modifiesPayload(effects);
+}
+
//===----------------------------------------------------------------------===//
// FuseIntoContainingOp
//===----------------------------------------------------------------------===//
@@ -2336,26 +2457,24 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
}
// Set options.
- TilingInterface paddedOp;
PadTilingInterfaceOptions options;
options.setPaddingValues(paddingValues)
.setPaddingSizes(getMixedPaddingSizes())
.setPadToMultipleOf(getPadToMultipleOf());
- // Apply padding.
- SmallVector<tensor::PadOp> newPadOps;
- FailureOr<TilingInterface> maybePaddedOp = rewriteAsPaddedOp(
- rewriter, cast<TilingInterface>(targetOp.getOperation()), options,
- newPadOps);
- if (failed(maybePaddedOp)) {
+ auto maybePadOps = rewriteAsPaddedOp(
+ rewriter, cast<TilingInterface>(targetOp.getOperation()), options);
+ if (failed(maybePadOps)) {
auto diag = emitSilenceableError() << "failed to pad op";
diag.attachNote(target->getLoc()) << "target op";
return diag;
}
+ const auto &[paddedOperands, paddedOp, slicedResults] = maybePadOps.value();
// Set transform results.
- paddedOps.push_back(cast<TilingInterface>(maybePaddedOp->getOperation()));
- padOps.append(newPadOps.begin(), newPadOps.end());
+ paddedOps.push_back(paddedOp);
+ padOps.append(paddedOperands.begin(), paddedOperands.end());
+ rewriter.replaceOp(targetOp.getOperation(), slicedResults);
}
results.set(cast<OpResult>(getPadded()), paddedOps);
@@ -2903,10 +3022,10 @@ ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();
}
if (dynamicPointParseResult.has_value()) {
- Type ChunkSizesType;
+ Type chunkSizesType;
if (failed(*dynamicPointParseResult) || parser.parseComma() ||
- parser.parseType(ChunkSizesType) ||
- parser.resolveOperand(dynamicChunkSizes, ChunkSizesType,
+ parser.parseType(chunkSizesType) ||
+ parser.resolveOperand(dynamicChunkSizes, chunkSizesType,
result.operands)) {
return failure();
}
@@ -3278,9 +3397,9 @@ void transform::ContinuousTileSizesOp::getEffects(
}
static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op,
- Type targetType, Type tile_sizes,
+ Type targetType, Type tileSizes,
Type) {
- printer.printFunctionalType(TypeRange{targetType}, TypeRange{tile_sizes});
+ printer.printFunctionalType(TypeRange{targetType}, TypeRange{tileSizes});
}
static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 0956c5d..3e787a2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -95,10 +95,11 @@ static int64_t extractConstantMultiplier(AffineExpr expr) {
/// - affine_map<(d0, d1) -> (d0 * 3 + d1)>
/// In the future, more general interfaces can be devised to encode similar
/// shape evolutions and map between an op and its operands.
-SmallVector<OpFoldResult> linalg::computePaddedShape(
- RewriterBase &rewriter, TypedValue<RankedTensorType> v,
- AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
- const PadTilingInterfaceOptions &options) {
+SmallVector<OpFoldResult>
+linalg::computePaddedShape(OpBuilder &builder, TypedValue<RankedTensorType> v,
+ AffineMap indexingMap,
+ ArrayRef<OpFoldResult> indexingSizes,
+ const PadTilingInterfaceOptions &options) {
Location loc = v.getLoc();
SmallVector<OpFoldResult> paddedShape;
auto tensorType = cast<RankedTensorType>(v.getType());
@@ -109,7 +110,7 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
// "Full-rank" padding specification.
SmallVector<OpFoldResult> paddingSizes =
- getFullRankPaddingSizes(rewriter, indexingSizes, options);
+ getFullRankPaddingSizes(builder, indexingSizes, options);
// For each dimension in the operand's shape, iterate over indexingSizes and
// add the various term contributions.
@@ -147,28 +148,27 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
OpFoldResult paddingDimOfr;
if (options.padToMultipleOf) {
AffineExpr d0, s0;
- bindDims(rewriter.getContext(), d0);
- bindSymbols(rewriter.getContext(), s0);
+ bindDims(builder.getContext(), d0);
+ bindSymbols(builder.getContext(), s0);
AffineMap ceilMap = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0);
AffineMap composedMap = projectedMap.compose(ceilMap);
paddingDimOfr = affine::makeComposedFoldedAffineApply(
- rewriter, loc, composedMap,
- {indexingSizes[paddingDim], paddingSize},
+ builder, loc, composedMap, {indexingSizes[paddingDim], paddingSize},
/*composeAffineMin=*/true);
} else {
// Otherwise just set to paddingSize.
paddingDimOfr = affine::makeComposedFoldedAffineApply(
- rewriter, loc, projectedMap, paddingSize);
+ builder, loc, projectedMap, paddingSize);
}
// Adjust for the maximum accessed index, which is (paddingSize - 1) *
// multiplier.
AffineExpr d0;
- bindDims(rewriter.getContext(), d0);
+ bindDims(builder.getContext(), d0);
int64_t multiplier = extractConstantMultiplier(projectedMap.getResult(0));
AffineMap subtractMap = AffineMap::get(1, 0, d0 - multiplier);
OpFoldResult maxAccessIdx = affine::makeComposedFoldedAffineApply(
- rewriter, loc, subtractMap, {paddingDimOfr});
+ builder, loc, subtractMap, {paddingDimOfr});
terms.push_back(maxAccessIdx);
LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n");
@@ -177,19 +177,19 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
// If there are no terms, just return the dim.
if (terms.empty()) {
paddedShape[resultIndex] =
- createFoldedDimOp(rewriter, loc, v, resultIndex);
+ createFoldedDimOp(builder, loc, v, resultIndex);
continue;
}
// Sum individual terms' contributions.
SmallVector<AffineExpr> dims(terms.size());
- bindDimsList(rewriter.getContext(), MutableArrayRef{dims});
+ bindDimsList(builder.getContext(), MutableArrayRef{dims});
AffineExpr sumExpr = dims.front();
for (unsigned i = 1; i < dims.size(); ++i)
sumExpr = sumExpr + dims[i];
// Add 1 to the maximum accessed index and get the final padded size.
- OpFoldResult paddedDimOfr = affine::makeComposedFoldedAffineApply(
- rewriter, loc, sumExpr + 1, terms);
+ OpFoldResult paddedDimOfr =
+ affine::makeComposedFoldedAffineApply(builder, loc, sumExpr + 1, terms);
paddedShape[resultIndex] = paddedDimOfr;
}
@@ -198,7 +198,7 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
FailureOr<SmallVector<OpFoldResult>>
linalg::computeIndexingMapOpInterfacePaddedShape(
- RewriterBase &rewriter, OpOperand &operandToPad,
+ OpBuilder &builder, OpOperand &operandToPad,
ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options) {
auto transferOp =
llvm::dyn_cast<IndexingMapOpInterface>(operandToPad.getOwner());
@@ -206,9 +206,9 @@ linalg::computeIndexingMapOpInterfacePaddedShape(
return failure();
// clang-format off
- assert(llvm::all_of(iterationDomain, [&rewriter](Range r) {
- return r.offset == OpFoldResult(rewriter.getIndexAttr(0)) &&
- r.stride == OpFoldResult(rewriter.getIndexAttr(1));
+ assert(llvm::all_of(iterationDomain, [&builder](Range r) {
+ return r.offset == OpFoldResult(builder.getIndexAttr(0)) &&
+ r.stride == OpFoldResult(builder.getIndexAttr(1));
}) && "expected 0-offset 1-stride loop ranges");
// clang-format on
SmallVector<OpFoldResult> loopUpperBounds;
@@ -218,13 +218,13 @@ linalg::computeIndexingMapOpInterfacePaddedShape(
AffineMap indexingMap = transferOp.getMatchingIndexingMap(&operandToPad);
return computePaddedShape(
- rewriter, cast<TypedValue<RankedTensorType>>(operandToPad.get()),
+ builder, cast<TypedValue<RankedTensorType>>(operandToPad.get()),
indexingMap, loopUpperBounds, options);
}
/// Pad a single operand to `paddedShape` using `paddingValueAttr` as padding
/// Value.
-static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
+static Value padOperand(OpBuilder &builder, TilingInterface opToPad,
TypedValue<RankedTensorType> v,
ArrayRef<OpFoldResult> paddedShape,
Attribute paddingValueAttr) {
@@ -232,15 +232,15 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
if (auto complexTy =
dyn_cast<ComplexType>(getElementTypeOrSelf(v.getType()))) {
if (auto complexAttr = dyn_cast<ArrayAttr>(paddingValueAttr)) {
- paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
+ paddingValue = complex::ConstantOp::create(builder, opToPad.getLoc(),
complexTy, complexAttr);
}
} else if (isa<ub::PoisonAttr>(paddingValueAttr)) {
- paddingValue = ub::PoisonOp::create(rewriter, opToPad.getLoc(),
+ paddingValue = ub::PoisonOp::create(builder, opToPad.getLoc(),
getElementTypeOrSelf(v.getType()));
} else if (auto typedAttr = dyn_cast<TypedAttr>(paddingValueAttr)) {
paddingValue =
- arith::ConstantOp::create(rewriter, opToPad.getLoc(), typedAttr);
+ arith::ConstantOp::create(builder, opToPad.getLoc(), typedAttr);
}
assert(paddingValue && "failed to create value from padding attribute");
@@ -259,49 +259,48 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
RankedTensorType::get(tensorShape, getElementTypeOrSelf(v));
LLVM_DEBUG(DBGS() << "--SUCCESS, makeComposedPadHighOp with type: "
<< paddedTensorType);
- return makeComposedPadHighOp(rewriter, opToPad.getLoc(), paddedTensorType, v,
+ return makeComposedPadHighOp(builder, opToPad.getLoc(), paddedTensorType, v,
paddingValue, /*nofold=*/false, dynDims);
}
-FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
- RewriterBase &rewriter, TilingInterface opToPad,
- const PadTilingInterfaceOptions &constOptions,
- SmallVector<tensor::PadOp> &padOps,
+FailureOr<PadTilingInterfaceResult> linalg::rewriteAsPaddedOp(
+ OpBuilder &builder, TilingInterface toPad,
+ PadTilingInterfaceOptions options,
const PadSizeComputationFunction &computePaddingSizeFun) {
- LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
+ LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << toPad << "\n");
+ SmallVector<tensor::PadOp> padOps;
+ Location loc = toPad.getLoc();
- Location loc = opToPad.getLoc();
- PadTilingInterfaceOptions options(constOptions);
// Allow inference of pad values if they are not explicitly specified.
// TODO: be mindful about the value depending on the actual operation.
if (options.paddingValues.empty()) {
- SmallVector<Type> types(opToPad->getOperandTypes());
- llvm::append_range(types, opToPad->getResultTypes());
+ SmallVector<Type> types(toPad->getOperandTypes());
+ llvm::append_range(types, toPad->getResultTypes());
for (Type t : types) {
options.paddingValues.push_back(
- rewriter.getZeroAttr(getElementTypeOrSelf(t)));
+ builder.getZeroAttr(getElementTypeOrSelf(t)));
}
}
- if (llvm::any_of(opToPad->getOperands(),
+ if (llvm::any_of(toPad->getOperands(),
[](Value v) { return isa<MemRefType>(v.getType()); })) {
- return rewriter.notifyMatchFailure(opToPad,
- "expected operation on tensors");
+ LLVM_DEBUG(DBGS() << "Not an operation on tensors: FAIL\n");
+ return failure();
}
- OpBuilder::InsertionGuard g(rewriter);
- // Set IP after opToPad because we also take the dims of opToPad's output.
- rewriter.setInsertionPointAfter(opToPad);
+ OpBuilder::InsertionGuard g(builder);
+ // Set IP after toPad because we also take the dims of toPad's output.
+ builder.setInsertionPointAfter(toPad);
// 1. Get the loopUpperBounds from the TilingInterface.
- SmallVector<Range> iterationDomain = opToPad.getIterationDomain(rewriter);
+ SmallVector<Range> iterationDomain = toPad.getIterationDomain(builder);
// 2. For each operand.
SmallVector<Value> newOperands;
- newOperands.reserve(opToPad->getNumOperands());
- for (OpOperand &opOperand : opToPad->getOpOperands()) {
+ newOperands.reserve(toPad->getNumOperands());
+ for (OpOperand &opOperand : toPad->getOpOperands()) {
Value operand = opOperand.get();
- LLVM_DEBUG(DBGS() << "--start padding oprd: " << operand << "\n");
+ LLVM_DEBUG(DBGS() << "--start padding operand: " << operand << "\n");
// 2.a. Skip scalar-like operands.
Type operandType = operand.getType();
@@ -311,30 +310,31 @@ FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
newOperands.push_back(operand);
continue;
}
+
// 2.a. Compute padded shape.
FailureOr<SmallVector<OpFoldResult>> maybePaddedShape =
- computePaddingSizeFun(rewriter, opOperand, iterationDomain, options);
+ computePaddingSizeFun(builder, opOperand, iterationDomain, options);
if (failed(maybePaddedShape)) {
- return rewriter.notifyMatchFailure(opToPad, "could not pad op");
+ LLVM_DEBUG(DBGS() << "Could not get padded shape of operand: FAIL\n");
+ return failure();
}
// 2.b. Expect proper `paddingValues`.
// TODO: we may want to allow garbage padding in the future, in which case
// we would just not assert.
if (opOperand.getOperandNumber() >= options.paddingValues.size()) {
- return rewriter.notifyMatchFailure(opToPad,
- "--no padding value specified");
+ LLVM_DEBUG(DBGS() << "Too few padding values specified: FAIL\n");
+ return failure();
}
Attribute paddingValueAttr =
options.paddingValues[opOperand.getOperandNumber()];
// 2.c. Perform actual padding.
- Value paddedOperand = padOperand(
- rewriter, opToPad, cast<TypedValue<RankedTensorType>>(operand),
- *maybePaddedShape, paddingValueAttr);
+ Value paddedOperand =
+ padOperand(builder, toPad, cast<TypedValue<RankedTensorType>>(operand),
+ *maybePaddedShape, paddingValueAttr);
LLVM_DEBUG(DBGS() << "--done padding operand: " << paddedOperand << "\n");
- // 2.d. Perform actual padding.
newOperands.push_back(paddedOperand);
if (auto padOp = paddedOperand.getDefiningOp<tensor::PadOp>())
padOps.push_back(padOp);
@@ -342,38 +342,34 @@ FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
// 3. Form the resulting tensor::ExtractSliceOp.
ReifiedRankedShapedTypeDims reifiedResultShapes;
- if (failed(reifyResultShapes(rewriter, opToPad, reifiedResultShapes))) {
- LLVM_DEBUG(DBGS() << "--failed to reify result shapes -> FAIL\n");
- return rewriter.notifyMatchFailure(opToPad,
- "failed to reify result shapes");
+ if (failed(reifyResultShapes(builder, toPad, reifiedResultShapes))) {
+ LLVM_DEBUG(DBGS() << "Failed to reify result shapes: FAIL\n");
+ return failure();
}
- assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
+ assert(reifiedResultShapes.size() == toPad->getNumResults() &&
"expected same number of results");
- // Clone `opToPad` to operate on the statically padded shapes.
+ // Clone `toPad` to operate on the statically padded shapes.
auto resultTensorTypes =
- ValueRange(newOperands).take_back(opToPad->getNumResults()).getTypes();
- // clone **should** properly notify the rewriter.
+ ValueRange(newOperands).take_back(toPad->getNumResults()).getTypes();
+ // clone **should** properly notify the builder.
TilingInterface paddedOp =
- clone(rewriter, opToPad, resultTensorTypes, newOperands);
+ clone(builder, toPad, resultTensorTypes, newOperands);
LLVM_DEBUG(DBGS() << "--cloned padded op: " << paddedOp << "\n");
- // Recover the slice out of the new static results. This keeps the original
- // opToPad around because it uses the dims of the original results.
+ // Recover the slice out of the new static results.
SmallVector<Value> paddedSubtensorResults;
- paddedSubtensorResults.reserve(opToPad->getNumResults());
+ paddedSubtensorResults.reserve(toPad->getNumResults());
for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
Value paddedResult = en.value();
int64_t resultNumber = en.index();
int64_t rank = cast<RankedTensorType>(paddedResult.getType()).getRank();
- SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
- SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+ SmallVector<OpFoldResult> offsets(rank, builder.getIndexAttr(0));
+ SmallVector<OpFoldResult> strides(rank, builder.getIndexAttr(1));
paddedSubtensorResults.push_back(tensor::ExtractSliceOp::create(
- rewriter, loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
+ builder, loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
strides));
}
- rewriter.replaceOp(opToPad, paddedSubtensorResults);
-
- return paddedOp;
+ return PadTilingInterfaceResult{padOps, paddedOp, paddedSubtensorResults};
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
index f277c5f..0ae2a9c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
@@ -266,9 +266,8 @@ struct StructuredOpShardingInterface
LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
SmallVector<utils::IteratorType> iteratorTypes =
linalgOp.getIteratorTypesArray();
- unsigned reductionItersCount = std::accumulate(
- iteratorTypes.begin(), iteratorTypes.end(), 0,
- [](unsigned count, utils::IteratorType iter) {
+ unsigned reductionItersCount = llvm::accumulate(
+ iteratorTypes, 0u, [](unsigned count, utils::IteratorType iter) {
return count + (iter == utils::IteratorType::reduction);
});
shard::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 0dac688..eb2d825 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1134,22 +1134,45 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
linalg::PackOp packOp, PatternRewriter &rewriter) const {
- // TODO: support the case that outer dimensions are not all 1s. A
- // tensor.expand_shape will be generated in this case.
- if (llvm::any_of(packOp.getAllOuterDims(),
+ if (llvm::any_of(packOp.getTiledOuterDims(),
[](int64_t dim) { return dim != 1; })) {
return rewriter.notifyMatchFailure(
packOp, "not all outer dimensions of the result are 1s");
}
+ ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
+ auto outerDimsPerm = packOp.getOuterDimsPerm();
+
+ // Verify that there are no:
+ // * non-unit + un-tiled-outer-dims,
+ // that are permuted. Supporting such cases would require refining the logic
+ // that generates the Transpose Op.
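+  // E.g. outer_dims_perm = [1, 0] is rejected when both dims are un-tiled and
+  // at least one of them has a non-unit size in the result.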
+  int64_t prev = 0;
+  if (!llvm::all_of(outerDimsPerm, [&innerDimsPos, &packOp, &prev](int64_t dim) {
+        // Skip tiled dims - these can be permuted.
+        if (llvm::is_contained(innerDimsPos, dim))
+          return true;
+
+        // Check whether this dim has been permuted. Permuting unit dims is
+        // fine as that's effectively a no-op.
+        if (dim < prev && (packOp.getType().getShape()[prev] != 1 ||
+                           packOp.getType().getShape()[dim] != 1))
+          return false;
+
+        prev = dim;
+        return true;
+      })) {
+ return rewriter.notifyMatchFailure(
+ packOp, "At least one non-unit and un-tiled outer dim is permuted, "
+ "this is not supported ATM!");
+ }
+
Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
Location loc = packOp.getLoc();
int64_t srcRank = packOp.getSourceRank();
int64_t destRank = packOp.getDestRank();
- ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
- int64_t numberOfTiles = innerDimsPos.size();
// 1. Get the input that is going to be packed. If the input requires padding,
// add a padding operation and return that as the input.
@@ -1160,10 +1183,13 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
// %transposed_tile = linalg.transpose ins(%source_or_padded_source),
// outs(%init)
// Assumptions made:
- // - All outer dims are 1 - the corresponding transposition order doesn't
- // matter, but requires all dim indices to be present.
+ // - All tiled outer dims are 1 - the corresponding transposition order
+ // doesn't matter, but requires all dim indices to be present.
+ // - Un-tiled outer dims remain un-permuted.
- // 2.1 Get the permutation for linalg.transpose
+ // 2.1 Get the permutation for linalg.transpose:
+ // [ untiled-dims, inner-dims-pos ]
+ // Note, this logic assumes that the untiled dims are not permuted.
SmallVector<int64_t> srcPermForTranspose;
for (int64_t i = 0; i < srcRank; i++) {
// We assume the `k` dimensions of the inner dim position, where `k` is the
@@ -1179,9 +1205,21 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
}
srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());
- // 2.2 Create the init tensor for linalg.transpose with the correct shape
- SmallVector<OpFoldResult> shapeForEmptyOp(srcRank - numberOfTiles,
- oneIdxAttr);
+ // 2.2 Create the init tensor for linalg.transpose with the correct shape:
+ // [ untiled-dims, tiled-dims ]
+ ShapedType inputTy = cast<ShapedType>(input.getType());
+ SmallVector<OpFoldResult> shapeForEmptyOp;
+ for (int64_t i = 0; i < srcRank; i++) {
+ if (llvm::is_contained(innerDimsPos, i)) {
+ // The tiled dims are appended after this loop.
+ continue;
+ }
+ if (inputTy.isStaticDim(i))
+ shapeForEmptyOp.push_back(rewriter.getIndexAttr(inputTy.getShape()[i]));
+ else
+ shapeForEmptyOp.emplace_back(
+ tensor::DimOp::create(rewriter, loc, input, i).getResult());
+ }
shapeForEmptyOp.append(packOp.getMixedTiles());
// getMixedTiles() may contain Values pointing to constant ops, not the
@@ -1204,25 +1242,36 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
srcPermForTranspose);
- // 3. Insert the inner tile to the destination:
+ // 3. Insert the inner tile into the destination tensor:
// %inserted_tile = tensor.insert_slice(%transposed_tile)
- SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
- SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
- // Outer dims are all 1s!
- SmallVector<OpFoldResult> writeSizes(destRank - numberOfTiles, oneIdxAttr);
- SmallVector<int64_t> writeShape;
+
+ // Compute the sizes attribute:
+ // [ outer-dims, tile-sizes ]
+ // Note that the output from the transpose Op excludes the tiled outer dims.
+ // However, given the assumption that:
+ // * all tiled outer dims == 1,
+ // we can just use a rank-expanding tensor.insert_slice.
+ SmallVector<OpFoldResult> writeSizes;
+ for (auto size : packOp.getAllOuterDims()) {
+ writeSizes.push_back(rewriter.getIndexAttr(size));
+ }
for (auto tileSize : packOp.getMixedTiles()) {
- auto [tileSizeStatic, tileSizeOfr] =
+ auto [_, tileSizeOfr] =
getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
writeSizes.push_back(tileSizeOfr);
- writeShape.push_back(tileSizeStatic);
}
- // 4. Replace tensor.packOp with tensor.insert_slice created above
+ // TODO: Add a constructor for tensor.insert_slice that doesn't require
+  // strides or offsets.
+ SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
+ SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
+
auto insert = tensor::InsertSliceOp::create(
rewriter, loc, transposedOp.getResult()[0], packOp.getDest(),
writeOffsets, writeSizes, writeStrides);
+
+ // 4. Replace tensor.packOp with tensor.insert_slice created above
rewriter.replaceOp(packOp, insert.getResult());
return success();
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index e25a012..1382c7ac 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -5,7 +5,7 @@ add_mlir_dialect_library(MLIRMemRefDialect
ValueBoundsOpInterfaceImpl.cpp
ADDITIONAL_HEADER_DIRS
- ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect
+  ${PROJECT_SOURCE_DIR}/include/mlir/Dialect/MemRef/IR
DEPENDS
MLIRMemRefOpsIncGen
@@ -18,6 +18,7 @@ add_mlir_dialect_library(MLIRMemRefDialect
MLIRDialectUtils
MLIRInferIntRangeCommon
MLIRInferIntRangeInterface
+ MLIRInferStridedMetadataInterface
MLIRInferTypeOpInterface
MLIRIR
MLIRMemOpInterfaces
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index e9bdcda..94947b7 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2158,11 +2158,45 @@ public:
return success();
}
};
+
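+/// Replaces a reinterpret_cast with a new one whose offset, sizes and strides
+/// have been "constified" (i.e. replaced by constants where they can be
+/// inferred), followed by a memref.cast back to the original result type.
+/// The pattern only applies when constification makes more values static.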
+struct ReinterpretCastOpConstantFolder
+ : public OpRewritePattern<ReinterpretCastOp> {
+public:
+ using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ReinterpretCastOp op,
+ PatternRewriter &rewriter) const override {
+ unsigned srcStaticCount = llvm::count_if(
+ llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(),
+ op.getMixedStrides()),
+ [](OpFoldResult ofr) { return isa<Attribute>(ofr); });
+
+ SmallVector<OpFoldResult> offsets = {op.getConstifiedMixedOffset()};
+ SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
+ SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();
+
+    // TODO: We use a count comparison instead of a direct comparison because
+ // getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns
+ // IntegerAttrs, while constifyIndexValues (and therefore
+ // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs.
+ if (srcStaticCount ==
+ llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
+ [](OpFoldResult ofr) { return isa<Attribute>(ofr); }))
+ return failure();
+
+ auto newReinterpretCast = ReinterpretCastOp::create(
+ rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides);
+
+ rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast);
+ return success();
+ }
+};
} // namespace
void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
+ results.add<ReinterpretCastOpExtractStridedMetadataFolder,
+ ReinterpretCastOpConstantFolder>(context);
}
FailureOr<std::optional<SmallVector<Value>>>
@@ -3437,6 +3471,65 @@ SubViewOp::bubbleDownCasts(OpBuilder &builder) {
return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
+void SubViewOp::inferStridedMetadataRanges(
+ ArrayRef<StridedMetadataRange> ranges, GetIntRangeFn getIntRange,
+ SetStridedMetadataRangeFn setMetadata, int32_t indexBitwidth) {
+ auto isUninitialized =
+ +[](IntegerValueRange range) { return range.isUninitialized(); };
+
+ // Bail early if any of the operands metadata is not ready:
+ SmallVector<IntegerValueRange> offsetOperands =
+ getIntValueRanges(getMixedOffsets(), getIntRange, indexBitwidth);
+ if (llvm::any_of(offsetOperands, isUninitialized))
+ return;
+
+ SmallVector<IntegerValueRange> sizeOperands =
+ getIntValueRanges(getMixedSizes(), getIntRange, indexBitwidth);
+ if (llvm::any_of(sizeOperands, isUninitialized))
+ return;
+
+ SmallVector<IntegerValueRange> stridesOperands =
+ getIntValueRanges(getMixedStrides(), getIntRange, indexBitwidth);
+ if (llvm::any_of(stridesOperands, isUninitialized))
+ return;
+
+ StridedMetadataRange sourceRange =
+ ranges[getSourceMutable().getOperandNumber()];
+ if (sourceRange.isUninitialized())
+ return;
+
+ ArrayRef<ConstantIntRanges> srcStrides = sourceRange.getStrides();
+
+ // Get the dropped dims.
+ llvm::SmallBitVector droppedDims = getDroppedDims();
+
+ // Compute the new offset, strides and sizes.
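+  // newOffset = srcOffset + sum_i(offset_i * srcStride_i); for each retained
+  // dim d: newStride_d = stride_d * srcStride_d and newSize_d = size_d.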
+ ConstantIntRanges offset = sourceRange.getOffsets()[0];
+ SmallVector<ConstantIntRanges> strides, sizes;
+
+ for (size_t i = 0, e = droppedDims.size(); i < e; ++i) {
+ bool dropped = droppedDims.test(i);
+ // Compute the new offset.
+ ConstantIntRanges off =
+ intrange::inferMul({offsetOperands[i].getValue(), srcStrides[i]});
+ offset = intrange::inferAdd({offset, off});
+
+ // Skip dropped dimensions.
+ if (dropped)
+ continue;
+ // Multiply the strides.
+ strides.push_back(
+ intrange::inferMul({stridesOperands[i].getValue(), srcStrides[i]}));
+ // Get the sizes.
+ sizes.push_back(sizeOperands[i].getValue());
+ }
+
+ setMetadata(getResult(),
+ StridedMetadataRange::getRanked(
+ SmallVector<ConstantIntRanges>({std::move(offset)}),
+ std::move(sizes), std::move(strides)));
+}
+
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
index 49b7162..6f815ae 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
@@ -121,7 +121,7 @@ struct EmulateWideIntPass final
[&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
RewritePatternSet patterns(ctx);
- // Add common pattenrs to support contants, functions, etc.
+  // Add common patterns to support constants, functions, etc.
arith::populateArithWideIntEmulationPatterns(typeConverter, patterns);
memref::populateMemRefWideIntEmulationPatterns(typeConverter, patterns);
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 6564a4e..dcfe2c7 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/SymbolTable.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallSet.h"
@@ -39,6 +40,16 @@ static bool isScalarLikeType(Type type) {
return type.isIntOrIndexOrFloat() || isa<ComplexType>(type);
}
+/// Helper function to attach the `VarName` attribute to an operation
+/// if a variable name is provided.
+static void attachVarNameAttr(Operation *op, OpBuilder &builder,
+ StringRef varName) {
+ if (!varName.empty()) {
+ auto varNameAttr = acc::VarNameAttr::get(builder.getContext(), varName);
+ op->setAttr(acc::getVarNameAttrName(), varNameAttr);
+ }
+}
+
struct MemRefPointerLikeModel
: public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
MemRefType> {
@@ -74,14 +85,18 @@ struct MemRefPointerLikeModel
}
mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
- StringRef varName, Type varType,
- Value originalVar) const {
+ StringRef varName, Type varType, Value originalVar,
+ bool &needsFree) const {
auto memrefTy = cast<MemRefType>(pointer);
// Check if this is a static memref (all dimensions are known) - if yes
// then we can generate an alloca operation.
- if (memrefTy.hasStaticShape())
- return memref::AllocaOp::create(builder, loc, memrefTy).getResult();
+ if (memrefTy.hasStaticShape()) {
+ needsFree = false; // alloca doesn't need deallocation
+ auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy);
+ attachVarNameAttr(allocaOp, builder, varName);
+ return allocaOp.getResult();
+ }
// For dynamic memrefs, extract sizes from the original variable if
// provided. Otherwise they cannot be handled.
@@ -99,8 +114,11 @@ struct MemRefPointerLikeModel
// Note: We only add dynamic sizes to the dynamicSizes array
// Static dimensions are handled automatically by AllocOp
}
- return memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes)
- .getResult();
+ needsFree = true; // alloc needs deallocation
+ auto allocOp =
+ memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes);
+ attachVarNameAttr(allocOp, builder, varName);
+ return allocOp.getResult();
}
// TODO: Unranked not yet supported.
@@ -108,10 +126,14 @@ struct MemRefPointerLikeModel
}
bool genFree(Type pointer, OpBuilder &builder, Location loc,
- TypedValue<PointerLikeType> varPtr, Type varType) const {
- if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varPtr)) {
+ TypedValue<PointerLikeType> varToFree, Value allocRes,
+ Type varType) const {
+ if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varToFree)) {
+ // Use allocRes if provided to determine the allocation type
+ Value valueToInspect = allocRes ? allocRes : memrefValue;
+
// Walk through casts to find the original allocation
- Value currentValue = memrefValue;
+ Value currentValue = valueToInspect;
Operation *originalAlloc = nullptr;
// Follow the chain of operations to find the original allocation
@@ -150,7 +172,7 @@ struct MemRefPointerLikeModel
return true;
}
if (isa<memref::AllocOp>(originalAlloc)) {
- // This is an alloc - generate dealloc
+ // This is an alloc - generate dealloc on varToFree
memref::DeallocOp::create(builder, loc, memrefValue);
return true;
}
@@ -1003,6 +1025,138 @@ struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
}
};
+//===----------------------------------------------------------------------===//
+// Recipe Region Helpers
+//===----------------------------------------------------------------------===//
+
+/// Create and populate an init region for privatization recipes.
+/// Returns success if the region is populated, failure otherwise.
+/// Sets needsFree to indicate if the allocated memory requires deallocation.
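+/// For a statically-shaped memref such as memref<10xf32>, the init region is
+/// roughly:
+///   ^bb0(%arg0: memref<10xf32>):
+///     %0 = memref.alloca() : memref<10xf32>
+///     acc.yield %0 : memref<10xf32>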
+static LogicalResult createInitRegion(OpBuilder &builder, Location loc,
+ Region &initRegion, Type varType,
+ StringRef varName, ValueRange bounds,
+ bool &needsFree) {
+ // Create init block with arguments: original value + bounds
+ SmallVector<Type> argTypes{varType};
+ SmallVector<Location> argLocs{loc};
+ for (Value bound : bounds) {
+ argTypes.push_back(bound.getType());
+ argLocs.push_back(loc);
+ }
+
+ Block *initBlock = builder.createBlock(&initRegion);
+ initBlock->addArguments(argTypes, argLocs);
+ builder.setInsertionPointToStart(initBlock);
+
+ Value privatizedValue;
+
+ // Get the block argument that represents the original variable
+ Value blockArgVar = initBlock->getArgument(0);
+
+ // Generate init region body based on variable type
+ if (isa<MappableType>(varType)) {
+ auto mappableTy = cast<MappableType>(varType);
+ auto typedVar = cast<TypedValue<MappableType>>(blockArgVar);
+ privatizedValue = mappableTy.generatePrivateInit(
+ builder, loc, typedVar, varName, bounds, {}, needsFree);
+ if (!privatizedValue)
+ return failure();
+ } else {
+ assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
+ auto pointerLikeTy = cast<PointerLikeType>(varType);
+ // Use PointerLikeType's allocation API with the block argument
+ privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
+ blockArgVar, needsFree);
+ if (!privatizedValue)
+ return failure();
+ }
+
+ // Add yield operation to init block
+ acc::YieldOp::create(builder, loc, privatizedValue);
+
+ return success();
+}
+
+/// Create and populate a copy region for firstprivate recipes.
+/// Returns success if the region is populated, failure otherwise.
+/// TODO: Handle MappableType - it does not yet have a copy API.
+static LogicalResult createCopyRegion(OpBuilder &builder, Location loc,
+ Region &copyRegion, Type varType,
+ ValueRange bounds) {
+ // Create copy block with arguments: original value + privatized value +
+ // bounds
+ SmallVector<Type> copyArgTypes{varType, varType};
+ SmallVector<Location> copyArgLocs{loc, loc};
+ for (Value bound : bounds) {
+ copyArgTypes.push_back(bound.getType());
+ copyArgLocs.push_back(loc);
+ }
+
+ Block *copyBlock = builder.createBlock(&copyRegion);
+ copyBlock->addArguments(copyArgTypes, copyArgLocs);
+ builder.setInsertionPointToStart(copyBlock);
+
+ bool isMappable = isa<MappableType>(varType);
+ bool isPointerLike = isa<PointerLikeType>(varType);
+ // TODO: Handle MappableType - it does not yet have a copy API.
+ // Otherwise, for now just fallback to pointer-like behavior.
+ if (isMappable && !isPointerLike)
+ return failure();
+
+ // Generate copy region body based on variable type
+ if (isPointerLike) {
+ auto pointerLikeTy = cast<PointerLikeType>(varType);
+ Value originalArg = copyBlock->getArgument(0);
+ Value privatizedArg = copyBlock->getArgument(1);
+
+ // Generate copy operation using PointerLikeType interface
+ if (!pointerLikeTy.genCopy(
+ builder, loc, cast<TypedValue<PointerLikeType>>(privatizedArg),
+ cast<TypedValue<PointerLikeType>>(originalArg), varType))
+ return failure();
+ }
+
+ // Add terminator to copy block
+ acc::TerminatorOp::create(builder, loc);
+
+ return success();
+}
+
+/// Create and populate a destroy region for privatization recipes.
+/// Returns success if the region is populated, failure otherwise.
+static LogicalResult createDestroyRegion(OpBuilder &builder, Location loc,
+ Region &destroyRegion, Type varType,
+ Value allocRes, ValueRange bounds) {
+ // Create destroy block with arguments: original value + privatized value +
+ // bounds
+ SmallVector<Type> destroyArgTypes{varType, varType};
+ SmallVector<Location> destroyArgLocs{loc, loc};
+ for (Value bound : bounds) {
+ destroyArgTypes.push_back(bound.getType());
+ destroyArgLocs.push_back(loc);
+ }
+
+ Block *destroyBlock = builder.createBlock(&destroyRegion);
+ destroyBlock->addArguments(destroyArgTypes, destroyArgLocs);
+ builder.setInsertionPointToStart(destroyBlock);
+
+ auto varToFree =
+ cast<TypedValue<PointerLikeType>>(destroyBlock->getArgument(1));
+ if (isa<MappableType>(varType)) {
+ auto mappableTy = cast<MappableType>(varType);
+ if (!mappableTy.generatePrivateDestroy(builder, loc, varToFree))
+ return failure();
+ } else {
+ assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
+ auto pointerLikeTy = cast<PointerLikeType>(varType);
+ if (!pointerLikeTy.genFree(builder, loc, varToFree, allocRes, varType))
+ return failure();
+ }
+
+ acc::TerminatorOp::create(builder, loc);
+ return success();
+}
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -1050,6 +1204,48 @@ LogicalResult acc::PrivateRecipeOp::verifyRegions() {
return success();
}
+std::optional<PrivateRecipeOp>
+PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
+ StringRef recipeName, Type varType,
+ StringRef varName, ValueRange bounds) {
+ // First, validate that we can handle this variable type
+ bool isMappable = isa<MappableType>(varType);
+ bool isPointerLike = isa<PointerLikeType>(varType);
+
+ // Unsupported type
+ if (!isMappable && !isPointerLike)
+ return std::nullopt;
+
+ OpBuilder::InsertionGuard guard(builder);
+
+ // Create the recipe operation first so regions have proper parent context
+ auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
+
+ // Populate the init region
+ bool needsFree = false;
+ if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
+ varName, bounds, needsFree))) {
+ recipe.erase();
+ return std::nullopt;
+ }
+
+ // Only create destroy region if the allocation needs deallocation
+ if (needsFree) {
+ // Extract the allocated value from the init block's yield operation
+ auto yieldOp =
+ cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
+ Value allocRes = yieldOp.getOperand(0);
+
+ if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
+ varType, allocRes, bounds))) {
+ recipe.erase();
+ return std::nullopt;
+ }
+ }
+
+ return recipe;
+}
+
//===----------------------------------------------------------------------===//
// FirstprivateRecipeOp
//===----------------------------------------------------------------------===//
@@ -1080,6 +1276,55 @@ LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
return success();
}
+std::optional<FirstprivateRecipeOp>
+FirstprivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
+ StringRef recipeName, Type varType,
+ StringRef varName, ValueRange bounds) {
+ // First, validate that we can handle this variable type
+ bool isMappable = isa<MappableType>(varType);
+ bool isPointerLike = isa<PointerLikeType>(varType);
+
+ // Unsupported type
+ if (!isMappable && !isPointerLike)
+ return std::nullopt;
+
+ OpBuilder::InsertionGuard guard(builder);
+
+ // Create the recipe operation first so regions have proper parent context
+ auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
+
+ // Populate the init region
+ bool needsFree = false;
+ if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
+ varName, bounds, needsFree))) {
+ recipe.erase();
+ return std::nullopt;
+ }
+
+ // Populate the copy region
+ if (failed(createCopyRegion(builder, loc, recipe.getCopyRegion(), varType,
+ bounds))) {
+ recipe.erase();
+ return std::nullopt;
+ }
+
+ // Only create destroy region if the allocation needs deallocation
+ if (needsFree) {
+ // Extract the allocated value from the init block's yield operation
+ auto yieldOp =
+ cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
+ Value allocRes = yieldOp.getOperand(0);
+
+ if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
+ varType, allocRes, bounds))) {
+ recipe.erase();
+ return std::nullopt;
+ }
+ }
+
+ return recipe;
+}
+
//===----------------------------------------------------------------------===//
// ReductionRecipeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
index b663908..8c4f80f 100644
--- a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
+++ b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/Quant/Utils/UniformSupport.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "llvm/ADT/STLExtras.h"
#include <numeric>
using namespace mlir;
@@ -76,9 +77,7 @@ UniformQuantizedPerAxisValueConverter::convert(DenseFPElementsAttr attr) {
// using the right quantization parameters.
int64_t flattenIndex = 0;
auto shape = type.getShape();
- int64_t chunkSize =
- std::accumulate(std::next(shape.begin(), quantizationDim + 1),
- shape.end(), 1, std::multiplies<int64_t>());
+ int64_t chunkSize = llvm::product_of(shape.drop_front(quantizationDim + 1));
Type newElementType = IntegerType::get(attr.getContext(), storageBitWidth);
return attr.mapValues(newElementType, [&](const APFloat &old) {
int chunkIndex = (flattenIndex++) / chunkSize;
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 5511998..fe50865 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -400,7 +400,7 @@ LogicalResult spirv::CompositeConstructOp::verify() {
return emitOpError("operand element type mismatch: expected to be ")
<< resultType.getElementType() << ", but provided " << elementType;
}
- unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);
+ unsigned totalCount = llvm::sum_of(sizes);
if (totalCount != cType.getNumElements())
return emitOpError("has incorrect number of operands: expected ")
<< cType.getNumElements() << ", but provided " << totalCount;
diff --git a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
index 08fccfa..645cbff 100644
--- a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
+++ b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
@@ -158,7 +158,7 @@ static FailureOr<GridOp> getGridAndVerify(Operation *op,
}
template <typename It>
-bool isUnique(It begin, It end) {
+static bool isUnique(It begin, It end) {
if (begin == end) {
return true;
}
@@ -1010,18 +1010,6 @@ static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
return success();
}
-template <typename It>
-static auto product(It begin, It end) {
- using ElementType = std::decay_t<decltype(*begin)>;
- return std::accumulate(begin, end, static_cast<ElementType>(1),
- std::multiplies<ElementType>());
-}
-
-template <typename R>
-static auto product(R &&range) {
- return product(adl_begin(range), adl_end(range));
-}
-
static LogicalResult verifyDimensionCompatibility(Location loc,
int64_t expectedDimSize,
int64_t resultDimSize,
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
index a1711a6..069191c 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
@@ -143,8 +143,8 @@ void VarInfo::setNum(Var::Num n) {
/// Helper function for `assertUsageConsistency` to better handle SMLoc
/// mismatches.
-LLVM_ATTRIBUTE_UNUSED static llvm::SMLoc
-minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) {
+[[maybe_unused]] static llvm::SMLoc minSMLoc(AsmParser &parser, llvm::SMLoc sm1,
+ llvm::SMLoc sm2) {
const auto loc1 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm1));
assert(loc1 && "Could not get `FileLineColLoc` for first `SMLoc`");
const auto loc2 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm2));
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index f539502..684c088 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -43,8 +43,8 @@ using namespace mlir::sparse_tensor;
//===----------------------------------------------------------------------===//
#ifndef NDEBUG
-LLVM_ATTRIBUTE_UNUSED static void dumpIndexMemRef(OpBuilder &builder,
- Location loc, Value memref) {
+[[maybe_unused]] static void dumpIndexMemRef(OpBuilder &builder, Location loc,
+ Value memref) {
memref = memref::CastOp::create(
builder, loc, UnrankedMemRefType::get(builder.getIndexType(), 0), memref);
createFuncCall(builder, loc, "printMemrefInd", TypeRange{},
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index fa97b49..ac72002 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2310,6 +2310,7 @@ RankedTensorType ExtractSliceOp::inferResultType(
sourceTensorType.getEncoding());
}
+// TODO: This uses neither offsets nor strides!
RankedTensorType ExtractSliceOp::inferResultType(
RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
index 5aad671..1cba1bb 100644
--- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
+#include "llvm/Support/FormatVariadic.h"
namespace mlir {
namespace tosa {
@@ -27,7 +28,7 @@ TargetEnvAttr lookupTargetEnv(Operation *op) {
}
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) {
- return TargetEnvAttr::get(context, Level::eightK,
+ return TargetEnvAttr::get(context, SpecificationVersion::V_1_0, Level::eightK,
{Profile::pro_int, Profile::pro_fp}, {});
}
@@ -38,5 +39,9 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) {
return getDefaultTargetEnv(op->getContext());
}
+llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) {
+ return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor());
+}
+
} // namespace tosa
} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index c51b5e9..00f84bc 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2368,9 +2368,10 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
}
}
- int64_t newShapeElementsNum = std::accumulate(
- shapeValues.begin(), shapeValues.end(), 1LL,
- [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
+ int64_t newShapeElementsNum =
+ llvm::accumulate(shapeValues, int64_t(1), [](int64_t acc, int64_t dim) {
+ return (dim > 0) ? acc * dim : acc;
+ });
bool isStaticNewShape =
llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
index bcb880a..a0661e4 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
@@ -61,8 +61,8 @@ public:
ModuleOp mod = getOperation();
MLIRContext *ctx = &getContext();
- const auto targetEnvAttr =
- TargetEnvAttr::get(ctx, level, selectedProfiles, selectedExtensions);
+ const auto targetEnvAttr = TargetEnvAttr::get(
+ ctx, specificationVersion, level, selectedProfiles, selectedExtensions);
mod->setAttr(TargetEnvAttr::name, targetEnvAttr);
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
index d33ebe3..5786f53 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
@@ -20,6 +20,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/Matchers.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
using namespace mlir;
@@ -375,8 +376,7 @@ llvm::APInt calculateReducedValue(const mlir::ElementsAttr &oldTensorAttr,
for (int64_t reductionAxisVal = 1; reductionAxisVal < oldShape[reductionAxis];
++reductionAxisVal) {
- int64_t stride = std::accumulate(oldShape.begin() + reductionAxis + 1,
- oldShape.end(), 1, std::multiplies<int>());
+ int64_t stride = llvm::product_of(oldShape.drop_front(reductionAxis + 1));
int64_t index = indexAtOldTensor + stride * reductionAxisVal;
reducedValue =
OperationType::calcOneElement(reducedValue, oldTensor[index]);
@@ -424,8 +424,7 @@ struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
auto oldShape = shapedOldElementsValues.getShape();
auto newShape = resultType.getShape();
- auto newNumOfElements = std::accumulate(newShape.begin(), newShape.end(), 1,
- std::multiplies<int>());
+ int64_t newNumOfElements = llvm::product_of(newShape);
llvm::SmallVector<APInt> newReducedTensor(newNumOfElements);
for (int64_t reductionIndex = 0; reductionIndex < newNumOfElements;
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 20f9333..f072e3e 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -335,16 +335,15 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
//===----------------------------------------------------------------------===//
template <typename T>
-FailureOr<SmallVector<T>>
-TosaProfileCompliance::getOperatorDefinition(Operation *op,
- CheckCondition &condition) {
+FailureOr<OpComplianceInfo<T>>
+TosaProfileCompliance::getOperatorDefinition(Operation *op) {
const std::string opName = op->getName().getStringRef().str();
const auto complianceMap = getProfileComplianceMap<T>();
const auto it = complianceMap.find(opName);
if (it == complianceMap.end())
return {};
- return findMatchedProfile<T>(op, it->second, condition);
+ return findMatchedEntry<T>(op, it->second);
}
template <typename T>
@@ -356,22 +355,21 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
if (specRequiredModeSet.size() == 0)
return success();
- CheckCondition condition = CheckCondition::invalid;
- const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition);
- if (failed(maybeOpRequiredMode)) {
+ const auto maybeOpDefinition = getOperatorDefinition<T>(op);
+ if (failed(maybeOpDefinition)) {
// Operators such as control-flow and shape ops do not have an operand type
// restriction. When the profile compliance information of operation is not
// found, confirm if the target have enabled the profile required from the
// specification.
- int mode_count = 0;
+ int modeCount = 0;
for (const auto &cands : specRequiredModeSet) {
if (targetEnv.allowsAnyOf(cands))
return success();
- mode_count += cands.size();
+ modeCount += cands.size();
}
op->emitOpError() << "illegal: requires"
- << (mode_count > 1 ? " any of " : " ") << "["
+ << (modeCount > 1 ? " any of " : " ") << "["
<< llvm::join(stringifyProfile<T>(specRequiredModeSet),
", ")
<< "] but not enabled in target\n";
@@ -381,7 +379,10 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
// Find the required profiles or extensions according to the operand type
// combination.
- const auto opRequiredMode = maybeOpRequiredMode.value();
+ const auto opDefinition = maybeOpDefinition.value();
+ const SmallVector<T> opRequiredMode = opDefinition.mode;
+ const CheckCondition condition = opDefinition.condition;
+
if (opRequiredMode.size() == 0) {
// No matched restriction found.
return success();
@@ -437,6 +438,21 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
}
}
+ // Ensure the matched op compliance version does not exceed the target
+ // specification version.
+ const VersionedTypeInfo versionedTypeInfo =
+ opDefinition.operandTypeInfoSet[0];
+ const TosaSpecificationVersion complianceVersion{versionedTypeInfo.second};
+ const TosaSpecificationVersion targetVersion{targetEnv.getSpecVersion()};
+ if (!targetVersion.isBackwardsCompatibleWith(complianceVersion)) {
+ op->emitOpError() << "illegal: the target specification version ("
+ << stringifyVersion(targetVersion)
+ << ") is not backwards compatible with the op compliance "
+ "specification version ("
+ << stringifyVersion(complianceVersion) << ")\n";
+ return failure();
+ }
+
return success();
}
@@ -461,14 +477,14 @@ TosaProfileCompliance::checkExtension(Operation *op,
}
LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
- CheckCondition condition = CheckCondition::invalid;
- const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
- const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
+ const auto maybeProfDef = getOperatorDefinition<Profile>(op);
+ const auto maybeExtDef = getOperatorDefinition<Extension>(op);
if (failed(maybeProfDef) && failed(maybeExtDef))
return success();
- const bool hasEntry = (succeeded(maybeProfDef) && !maybeProfDef->empty()) ||
- (succeeded(maybeExtDef) && !maybeExtDef->empty());
+ const bool hasEntry =
+ (succeeded(maybeProfDef) && !maybeProfDef->mode.empty()) ||
+ (succeeded(maybeExtDef) && !maybeExtDef->mode.empty());
if (!hasEntry) {
std::string message;
llvm::raw_string_ostream os(message);
@@ -488,7 +504,9 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
SmallVector<TypeInfo> bestTypeInfo;
const auto searchBestMatch = [&](auto map) {
for (const auto &complianceInfos : map[opName]) {
- for (const auto &typeInfos : complianceInfos.operandTypeInfoSet) {
+ for (const auto &versionedTypeInfos :
+ complianceInfos.operandTypeInfoSet) {
+ const SmallVector<TypeInfo> typeInfos = versionedTypeInfos.first;
const int matches = llvm::count_if(
llvm::zip_equal(current, typeInfos), [&](const auto zipType) {
return isSameTypeInfo(std::get<0>(zipType),
@@ -520,9 +538,8 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
// Find the profiles or extensions requirement according to the signature of
// type of the operand list.
template <typename T>
-SmallVector<T> TosaProfileCompliance::findMatchedProfile(
- Operation *op, SmallVector<OpComplianceInfo<T>> compInfo,
- CheckCondition &condition) {
+OpComplianceInfo<T> TosaProfileCompliance::findMatchedEntry(
+ Operation *op, SmallVector<OpComplianceInfo<T>> compInfo) {
assert(compInfo.size() != 0 &&
"profile-based compliance information is empty");
@@ -533,27 +550,30 @@ SmallVector<T> TosaProfileCompliance::findMatchedProfile(
return {};
for (size_t i = 0; i < compInfo.size(); i++) {
- SmallVector<SmallVector<TypeInfo>> sets = compInfo[i].operandTypeInfoSet;
- for (SmallVector<TypeInfo> expected : sets) {
+ SmallVector<VersionedTypeInfo> sets = compInfo[i].operandTypeInfoSet;
+ for (const auto &set : sets) {
+ SmallVector<TypeInfo> expected = set.first;
assert(present.size() == expected.size() &&
"the entries for profile-based compliance do not match between "
"the generated metadata and the type definition retrieved from "
" the operation");
- bool is_found = true;
+ bool isFound = true;
// Compare the type signature between the given operation and the
// compliance metadata.
for (size_t j = 0; j < expected.size(); j++) {
if (!isSameTypeInfo(present[j], expected[j])) {
// Verify the next mode set from the list.
- is_found = false;
+ isFound = false;
break;
}
}
- if (is_found == true) {
- condition = compInfo[i].condition;
- return compInfo[i].mode;
+      if (isFound) {
+ SmallVector<VersionedTypeInfo> typeInfoSet{set};
+ OpComplianceInfo<T> info{compInfo[i].mode, typeInfoSet,
+ compInfo[i].condition};
+ return info;
}
}
}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
index 9a24c2b..a2cff6a 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
@@ -21,10 +21,10 @@ using namespace mlir;
// These are automatically generated by ODS but are not used as the Transform
// dialect uses a different dispatch mechanism to support dialect extensions.
-LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
+[[maybe_unused]] static OptionalParseResult
generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
-LLVM_ATTRIBUTE_UNUSED static LogicalResult
-generatedTypePrinter(Type def, AsmPrinter &printer);
+[[maybe_unused]] static LogicalResult generatedTypePrinter(Type def,
+ AsmPrinter &printer);
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Transform/IR/TransformTypes.cpp.inc"
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index e1648ab9..305b06eb 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -81,21 +81,10 @@ SmallVector<int64_t> mlir::computeElementwiseMul(ArrayRef<int64_t> v1,
return computeElementwiseMulImpl(v1, v2);
}
-int64_t mlir::computeSum(ArrayRef<int64_t> basis) {
- assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
- "basis must be nonnegative");
- if (basis.empty())
- return 0;
- return std::accumulate(basis.begin(), basis.end(), 1, std::plus<int64_t>());
-}
-
int64_t mlir::computeProduct(ArrayRef<int64_t> basis) {
assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
"basis must be nonnegative");
- if (basis.empty())
- return 1;
- return std::accumulate(basis.begin(), basis.end(), 1,
- std::multiplies<int64_t>());
+ return llvm::product_of(basis);
}
int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
@@ -158,19 +147,11 @@ SmallVector<AffineExpr> mlir::computeElementwiseMul(ArrayRef<AffineExpr> v1,
}
AffineExpr mlir::computeSum(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
- if (basis.empty())
- return getAffineConstantExpr(0, ctx);
- return std::accumulate(basis.begin(), basis.end(),
- getAffineConstantExpr(0, ctx),
- std::plus<AffineExpr>());
+ return llvm::sum_of(basis, getAffineConstantExpr(0, ctx));
}
AffineExpr mlir::computeProduct(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
- if (basis.empty())
- return getAffineConstantExpr(1, ctx);
- return std::accumulate(basis.begin(), basis.end(),
- getAffineConstantExpr(1, ctx),
- std::multiplies<AffineExpr>());
+ return llvm::product_of(basis, getAffineConstantExpr(1, ctx));
}
AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 7b2734d..6e9118e 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -374,11 +374,11 @@ mlir::composeReassociationIndices(
if (consumerReassociations.empty())
return composedIndices;
- size_t consumerDims = std::accumulate(
- consumerReassociations.begin(), consumerReassociations.end(), 0,
- [](size_t all, ReassociationIndicesRef indices) {
- return all + indices.size();
- });
+ size_t consumerDims =
+ llvm::accumulate(consumerReassociations, size_t(0),
+ [](size_t all, ReassociationIndicesRef indices) {
+ return all + indices.size();
+ });
if (producerReassociations.size() != consumerDims)
return std::nullopt;
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a7e3ba8..45c54c7 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2496,8 +2496,7 @@ struct ToElementsOfBroadcast final : OpRewritePattern<ToElementsOp> {
auto srcElems = vector::ToElementsOp::create(
rewriter, toElementsOp.getLoc(), bcastOp.getSource());
- int64_t dstCount = std::accumulate(dstShape.begin(), dstShape.end(), 1,
- std::multiplies<int64_t>());
+ int64_t dstCount = llvm::product_of(dstShape);
SmallVector<Value> replacements;
replacements.reserve(dstCount);
@@ -7602,6 +7601,111 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), result);
}
+namespace {
+
+/// Fold `vector.step -> arith.cmpi` when the step value is compared to a
+/// constant large enough such that the result is the same at all indices.
+///
+/// For example, rewrite the 'greater than' comparison below,
+///
+/// ```mlir
+/// %cst = arith.constant dense<7> : vector<3xindex>
+/// %stp = vector.step : vector<3xindex>
+/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex>
+/// ```
+///
+/// as,
+///
+/// ```mlir
+/// %out = arith.constant dense<false> : vector<3xi1>
+/// ```
+///
+/// Above, `[0, 1, 2] > [7, 7, 7]` => `[false, false, false]`. Because the
+/// result is false at ALL indices we fold. If the constant were 1, then
+/// `[0, 1, 2] > [1, 1, 1]` => `[false, false, true]` and we do not fold,
+/// conservatively preferring the 'compact' vector.step representation.
+///
+/// Note: this folder only works for the case where the constant (`%cst` above)
+/// is the second operand of the comparison. The arith.cmpi canonicalizer will
+/// ensure that constants are always second (on the right).
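+///
+/// A further illustrative example: with `ult` and a constant no smaller than
+/// the number of elements, every lane compares true, so
+///
+/// ```mlir
+/// %cst = arith.constant dense<3> : vector<3xindex>
+/// %stp = vector.step : vector<3xindex>
+/// %out = arith.cmpi ult, %stp, %cst : vector<3xindex>
+/// ```
+///
+/// folds to `arith.constant dense<true> : vector<3xi1>`.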
+struct StepCompareFolder : public OpRewritePattern<StepOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(StepOp stepOp,
+ PatternRewriter &rewriter) const override {
+ const int64_t stepSize = stepOp.getResult().getType().getNumElements();
+
+ for (OpOperand &use : stepOp.getResult().getUses()) {
+ auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
+ if (!cmpiOp)
+ continue;
+
+ // arith.cmpi canonicalizer makes constants final operands.
+ const unsigned stepOperandNumber = use.getOperandNumber();
+ if (stepOperandNumber != 0)
+ continue;
+
+ // Check that operand 1 is a constant.
+ unsigned constOperandNumber = 1;
+ Value otherOperand = cmpiOp.getOperand(constOperandNumber);
+ std::optional<int64_t> maybeConstValue =
+ getConstantIntValue(otherOperand);
+ if (!maybeConstValue.has_value())
+ continue;
+
+ int64_t constValue = maybeConstValue.value();
+ arith::CmpIPredicate pred = cmpiOp.getPredicate();
+
+ auto maybeSplat = [&]() -> std::optional<bool> {
+ // Handle ult (unsigned less than) and uge (unsigned greater equal).
+ if ((pred == arith::CmpIPredicate::ult ||
+ pred == arith::CmpIPredicate::uge) &&
+ stepSize <= constValue)
+ return pred == arith::CmpIPredicate::ult;
+
+ // Handle ule and ugt.
+ if ((pred == arith::CmpIPredicate::ule ||
+ pred == arith::CmpIPredicate::ugt) &&
+ stepSize - 1 <= constValue) {
+ return pred == arith::CmpIPredicate::ule;
+ }
+
+ // Handle eq and ne.
+ if ((pred == arith::CmpIPredicate::eq ||
+ pred == arith::CmpIPredicate::ne) &&
+ stepSize <= constValue)
+ return pred == arith::CmpIPredicate::ne;
+
+ return std::nullopt;
+ }();
+
+ if (!maybeSplat.has_value())
+ continue;
+
+ rewriter.setInsertionPointAfter(cmpiOp);
+
+ auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
+ if (!type)
+ continue;
+
+ auto boolAttr = DenseElementsAttr::get(type, maybeSplat.value());
+ Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
+ type, boolAttr);
+
+ rewriter.replaceOp(cmpiOp, splat);
+ return success();
+ }
+
+ return failure();
+ }
+};
+} // namespace
+
+void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<StepCompareFolder>(context);
+}
+
//===----------------------------------------------------------------------===//
// Vector Masking Utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index c5f22b2..0eba0b1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -21,6 +21,7 @@
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/STLExtras.h"
#include <numeric>
#define DEBUG_TYPE "vector-shape-cast-lowering"
@@ -166,10 +167,7 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
const VectorType resultType = shapeCast.getResultVectorType();
const ArrayRef<int64_t> resultShape = resultType.getShape();
- const int64_t nSlices =
- std::accumulate(sourceShape.begin(), sourceShape.begin() + sourceDim, 1,
- std::multiplies<int64_t>());
-
+ const int64_t nSlices = llvm::product_of(sourceShape.take_front(sourceDim));
SmallVector<int64_t> extractIndex(sourceDim, 0);
SmallVector<int64_t> insertIndex(resultDim, 0);
Value result = ub::PoisonOp::create(rewriter, loc, resultType);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e95338f..7c019e7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -928,17 +928,20 @@ struct WarpOpDeadResult : public WarpDistributionPattern {
// Some values may be yielded multiple times and correspond to multiple
// results. Deduplicating occurs by taking each result with its matching
// yielded value, and:
- // 1. recording the unique first position at which the value is yielded.
+ // 1. recording the unique first position at which the value with uses is
+ // yielded.
// 2. recording for the result, the first position at which the dedup'ed
// value is yielded.
// 3. skipping from the new result types / new yielded values any result
// that has no use or whose yielded value has already been seen.
for (OpResult result : warpOp.getResults()) {
+ if (result.use_empty())
+ continue;
Value yieldOperand = yield.getOperand(result.getResultNumber());
auto it = dedupYieldOperandPositionMap.insert(
std::make_pair(yieldOperand, newResultTypes.size()));
dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
- if (result.use_empty() || !it.second)
+ if (!it.second)
continue;
newResultTypes.push_back(result.getType());
newYieldValues.push_back(yieldOperand);
@@ -1843,16 +1846,16 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
escapingValueDistTypesElse.end());
- llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx;
for (auto [idx, val] :
llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
- origToNewYieldIdx[idx] = newWarpOpYieldValues.size();
newWarpOpYieldValues.push_back(val);
newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
}
- // Create the new `WarpOp` with the updated yield values and types.
- WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
- rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+ // Replace the old `WarpOp` with the new one that has additional yield
+ // values and types.
+ SmallVector<size_t> newIndices;
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
// `ifOp` returns the result of the inner warp op.
SmallVector<Type> newIfOpDistResTypes;
for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
@@ -1870,8 +1873,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(newWarpOp);
auto newIfOp = scf::IfOp::create(
- rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0),
- static_cast<bool>(ifOp.thenBlock()),
+ rewriter, ifOp.getLoc(), newIfOpDistResTypes,
+ newWarpOp.getResult(newIndices[0]), static_cast<bool>(ifOp.thenBlock()),
static_cast<bool>(ifOp.elseBlock()));
auto encloseRegionInWarpOp =
[&](Block *oldIfBranch, Block *newIfBranch,
@@ -1888,7 +1891,7 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
for (size_t i = 0; i < escapingValues.size();
++i, ++warpResRangeStart) {
innerWarpInputVals.push_back(
- newWarpOp.getResult(warpResRangeStart));
+ newWarpOp.getResult(newIndices[warpResRangeStart]));
escapeValToBlockArgIndex[escapingValues[i]] =
innerWarpInputTypes.size();
innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
@@ -1936,17 +1939,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
// Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp`
// result.
for (auto [origIdx, newIdx] : ifResultMapping)
- rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+ rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
newIfOp.getResult(newIdx), newIfOp);
- // Similarly, update any users of the `WarpOp` results that were not
- // results of the `IfOp`.
- for (auto [origIdx, newIdx] : origToNewYieldIdx)
- rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
- newWarpOp.getResult(newIdx));
- // Remove the original `WarpOp` and `IfOp`, they should not have any uses
- // at this point.
- rewriter.eraseOp(ifOp);
- rewriter.eraseOp(warpOp);
return success();
}
@@ -2038,11 +2032,19 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
}
// Newly created `WarpOp` will yield values in following order:
- // 1. All init args of the `ForOp`.
- // 2. All escaping values.
- // 3. All non-`ForOp` yielded values.
+ // 1. Loop bounds.
+ // 2. All init args of the `ForOp`.
+ // 3. All escaping values.
+ // 4. All non-`ForOp` yielded values.
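+    // E.g. (illustration): for a loop with two init args and one escaping
+    // value, the new `WarpOp` yields
+    // [lb, ub, step, init0, init1, escape0, <non-ForOp yielded values>...].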
SmallVector<Value> newWarpOpYieldValues;
SmallVector<Type> newWarpOpDistTypes;
+ newWarpOpYieldValues.insert(
+ newWarpOpYieldValues.end(),
+ {forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()});
+ newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
+ {forOp.getLowerBound().getType(),
+ forOp.getUpperBound().getType(),
+ forOp.getStep().getType()});
for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
newWarpOpYieldValues.push_back(initArg);
// Compute the distributed type for this init arg.
@@ -2065,36 +2067,37 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
escapingValueDistTypes.begin(),
escapingValueDistTypes.end());
// Next, we insert all non-`ForOp` yielded values and their distributed
- // types. We also create a mapping between the non-`ForOp` yielded value
- // index and the corresponding new `WarpOp` yield value index (needed to
- // update users later).
- llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping;
+ // types.
for (auto [i, v] :
llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
- nonForResultMapping[i] = newWarpOpYieldValues.size();
newWarpOpYieldValues.push_back(v);
newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
}
// Create the new `WarpOp` with the updated yield values and types.
- WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
- rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+ SmallVector<size_t> newIndices;
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
// Next, we create a new `ForOp` with the init args yielded by the new
// `WarpOp`.
+ const unsigned initArgsStartIdx = 3; // After loop bounds.
const unsigned escapingValuesStartIdx =
+ initArgsStartIdx +
forOp.getInitArgs().size(); // `ForOp` init args are positioned before
// escaping values in the new `WarpOp`.
SmallVector<Value> newForOpOperands;
- for (size_t i = 0; i < escapingValuesStartIdx; ++i)
- newForOpOperands.push_back(newWarpOp.getResult(i));
+ for (size_t i = initArgsStartIdx; i < escapingValuesStartIdx; ++i)
+ newForOpOperands.push_back(newWarpOp.getResult(newIndices[i]));
// Create a new `ForOp` outside the new `WarpOp` region.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(newWarpOp);
auto newForOp = scf::ForOp::create(
- rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
- forOp.getStep(), newForOpOperands, /*bodyBuilder=*/nullptr,
- forOp.getUnsignedCmp());
+ rewriter, forOp.getLoc(),
+        /*lowerBound=*/newWarpOp.getResult(newIndices[0]),
+        /*upperBound=*/newWarpOp.getResult(newIndices[1]),
+        /*step=*/newWarpOp.getResult(newIndices[2]), newForOpOperands,
+ /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
// Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
// newly created `ForOp`. This `WarpOp` will contain all ops that were
// contained within the original `ForOp` body.
@@ -2110,7 +2113,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
for (size_t i = escapingValuesStartIdx;
i < escapingValuesStartIdx + escapingValues.size(); ++i) {
- innerWarpInput.push_back(newWarpOp.getResult(i));
+ innerWarpInput.push_back(newWarpOp.getResult(newIndices[i]));
argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
innerWarpInputType.size();
innerWarpInputType.push_back(
@@ -2146,20 +2149,11 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
if (!innerWarp.getResults().empty())
scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults());
- // Update the users of original `WarpOp` results that were coming from the
+ // Update the users of the new `WarpOp` results that were coming from the
// original `ForOp` to the corresponding new `ForOp` result.
for (auto [origIdx, newIdx] : forResultMapping)
- rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+ rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
newForOp.getResult(newIdx), newForOp);
- // Similarly, update any users of the `WarpOp` results that were not
- // results of the `ForOp`.
- for (auto [origIdx, newIdx] : nonForResultMapping)
- rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
- newWarpOp.getResult(newIdx));
- // Remove the original `WarpOp` and `ForOp`, they should not have any uses
- // at this point.
- rewriter.eraseOp(forOp);
- rewriter.eraseOp(warpOp);
// Update any users of escaping values that were forwarded to the
// inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
newForOp.walk([&](Operation *op) {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 963b2c8..aa2dd89 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/STLExtras.h"
#define DEBUG_TYPE "vector-drop-unit-dim"
@@ -557,8 +558,7 @@ struct CastAwayConstantMaskLeadingOneDim
// If any of the dropped unit dims has a size of `0`, the entire mask is a
// zero mask, else the unit dim has no effect on the mask.
int64_t flatLeadingSize =
- std::accumulate(dimSizes.begin(), dimSizes.begin() + dropDim + 1,
- static_cast<int64_t>(1), std::multiplies<int64_t>());
+ llvm::product_of(dimSizes.take_front(dropDim + 1));
SmallVector<int64_t> newDimSizes = {flatLeadingSize};
newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 1b656d8..ea93085 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -817,6 +817,50 @@ struct LinearizeVectorToElements final
}
};
+/// Convert broadcasts from scalars or 1-element vectors, such as
+///
+/// ```mlir
+/// vector.broadcast %value : f32 to vector<4x4xf32>
+/// ```
+///
+/// to broadcasts to rank-1 vectors, with shape_casts before/after as needed.
+/// The above becomes,
+///
+/// ```mlir
+/// %out_1d = vector.broadcast %value : f32 to vector<16xf32>
+/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
+/// ```
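+///
+/// A broadcast from a 1-element vector source is handled the same way
+/// (illustrative sketch):
+///
+/// ```mlir
+/// vector.broadcast %v : vector<1xf32> to vector<4x4xf32>
+/// ```
+///
+/// becomes a broadcast of %v to vector<16xf32>, followed by a shape_cast back
+/// to vector<4x4xf32>.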
+struct LinearizeVectorBroadcast final
+ : public OpConversionPattern<vector::BroadcastOp> {
+ using Base::Base;
+
+ LinearizeVectorBroadcast(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ int numElements = 1;
+ Type sourceType = broadcastOp.getSourceType();
+ if (auto vecType = dyn_cast<VectorType>(sourceType)) {
+ numElements = vecType.getNumElements();
+ }
+
+ if (numElements != 1) {
+ return rewriter.notifyMatchFailure(
+ broadcastOp, "only broadcasts of single elements can be linearized.");
+ }
+
+ auto dstTy = getTypeConverter()->convertType(broadcastOp.getType());
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(broadcastOp, dstTy,
+ adaptor.getSource());
+
+ return success();
+ }
+};
+
} // namespace
/// This method defines the set of operations that are linearizable, and hence
@@ -909,8 +953,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
patterns
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
LinearizeVectorCreateMask, LinearizeVectorLoad, LinearizeVectorStore,
- LinearizeVectorFromElements, LinearizeVectorToElements>(
- typeConverter, patterns.getContext());
+ LinearizeVectorBroadcast, LinearizeVectorFromElements,
+ LinearizeVectorToElements>(typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 14639c5..fbae098 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -465,26 +465,33 @@ struct UnrollElementwisePattern : public RewritePattern {
auto targetShape = getTargetShape(options, op);
if (!targetShape)
return failure();
+ int64_t targetShapeRank = targetShape->size();
auto dstVecType = cast<VectorType>(op->getResult(0).getType());
SmallVector<int64_t> originalSize =
*cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
- // Bail-out if rank(source) != rank(target). The main limitation here is the
- // fact that `ExtractStridedSlice` requires the rank for the input and
- // output to match. If needed, we can relax this later.
- if (originalSize.size() != targetShape->size())
- return rewriter.notifyMatchFailure(
- op, "expected input vector rank to match target shape rank");
+ int64_t originalShapeRank = originalSize.size();
+
Location loc = op->getLoc();
+
+ // Handle rank mismatch by adding leading unit dimensions to targetShape
+ SmallVector<int64_t> adjustedTargetShape(originalShapeRank);
+ int64_t rankDiff = originalShapeRank - targetShapeRank;
+ std::fill(adjustedTargetShape.begin(),
+ adjustedTargetShape.begin() + rankDiff, 1);
+ std::copy(targetShape->begin(), targetShape->end(),
+ adjustedTargetShape.begin() + rankDiff);
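+  // E.g. (illustration): originalSize = [2, 4, 8] with *targetShape = [4, 8]
+  // gives rankDiff = 1 and adjustedTargetShape = [1, 4, 8].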
+
+ int64_t adjustedTargetShapeRank = adjustedTargetShape.size();
// Prepare the result vector.
Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
rewriter.getZeroAttr(dstVecType));
- SmallVector<int64_t> strides(targetShape->size(), 1);
- VectorType newVecType =
+ SmallVector<int64_t> strides(adjustedTargetShapeRank, 1);
+ VectorType unrolledVecType =
VectorType::get(*targetShape, dstVecType.getElementType());
// Create the unrolled computation.
for (SmallVector<int64_t> offsets :
- StaticTileOffsetRange(originalSize, *targetShape)) {
+ StaticTileOffsetRange(originalSize, adjustedTargetShape)) {
SmallVector<Value> extractOperands;
for (OpOperand &operand : op->getOpOperands()) {
auto vecType = dyn_cast<VectorType>(operand.get().getType());
@@ -492,14 +499,31 @@ struct UnrollElementwisePattern : public RewritePattern {
extractOperands.push_back(operand.get());
continue;
}
- extractOperands.push_back(
- rewriter.createOrFold<vector::ExtractStridedSliceOp>(
- loc, operand.get(), offsets, *targetShape, strides));
+ Value extracted = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, operand.get(), offsets, adjustedTargetShape, strides);
+
+ // Reshape to remove leading unit dims if needed
+ if (adjustedTargetShapeRank > targetShapeRank) {
+ extracted = rewriter.createOrFold<vector::ShapeCastOp>(
+ loc, VectorType::get(*targetShape, vecType.getElementType()),
+ extracted);
+ }
+ extractOperands.push_back(extracted);
}
+
Operation *newOp = cloneOpWithOperandsAndTypes(
- rewriter, loc, op, extractOperands, newVecType);
+ rewriter, loc, op, extractOperands, unrolledVecType);
+
+ Value computeResult = newOp->getResult(0);
+
+ // Use strides sized to targetShape for proper insertion
+ SmallVector<int64_t> insertStrides =
+ (adjustedTargetShapeRank > targetShapeRank)
+ ? SmallVector<int64_t>(targetShapeRank, 1)
+ : strides;
+
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
- loc, newOp->getResult(0), result, offsets, strides);
+ loc, computeResult, result, offsets, insertStrides);
}
rewriter.replaceOp(op, result);
return success();
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 025ee9a..c809c502 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -91,7 +91,7 @@ mlir::vector::isTranspose2DSlice(vector::TransposeOp op) {
// Check whether the two source vector dimensions that are greater than one
// must be transposed with each other so that we can apply one of the 2-D
- // transpose pattens. Otherwise, these patterns are not applicable.
+ // transpose patterns. Otherwise, these patterns are not applicable.
if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1],
op.getPermutation()))
return failure();
diff --git a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
index 89b62a2..a514ea9 100644
--- a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
+++ b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp
@@ -12,6 +12,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/SymbolTable.h"
@@ -39,28 +40,6 @@ void printElseRegion(OpAsmPrinter &opPrinter, Operation *op,
opPrinter.printKeywordOrString("else ");
opPrinter.printRegion(elseRegion);
}
-
-ParseResult parseWasmVisibility(OpAsmParser &opParser, StringAttr &visibility) {
- std::string keyword;
- auto initLocation = opParser.getCurrentLocation();
- std::ignore = opParser.parseOptionalKeywordOrString(&keyword);
- if (keyword == "nested" or keyword == "") {
- visibility = StringAttr::get(opParser.getContext(), "nested");
- return ParseResult::success();
- }
-
- if (keyword == "public" || keyword == "private") {
- visibility = StringAttr::get(opParser.getContext(), keyword);
- return ParseResult::success();
- }
- opParser.emitError(initLocation, "expecting symbol visibility");
- return ParseResult::failure();
-}
-
-void printWasmVisibility(OpAsmPrinter &opPrinter, Operation *op,
- Attribute visibility) {
- opPrinter.printKeywordOrString(cast<StringAttr>(visibility).strref());
-}
} // namespace
#define GET_OP_CLASSES
@@ -167,10 +146,23 @@ Block *FuncOp::addEntryBlock() {
void FuncOp::build(OpBuilder &odsBuilder, OperationState &odsState,
StringRef symbol, FunctionType funcType) {
- FuncOp::build(odsBuilder, odsState, symbol, funcType, {}, {}, "nested");
+ FuncOp::build(odsBuilder, odsState, symbol, funcType, {}, {});
}
ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
+ auto *ctx = parser.getContext();
+ std::string visibilityString;
+ auto loc = parser.getNameLoc();
+ ParseResult res = parser.parseOptionalKeywordOrString(&visibilityString);
+ bool exported{false};
+ if (res.succeeded()) {
+ if (visibilityString != "exported")
+ return parser.emitError(
+ loc, "expecting either `exported` or symbol name. got ")
+ << visibilityString;
+ exported = true;
+ }
+
auto buildFuncType = [&parser](Builder &builder, ArrayRef<Type> argTypes,
ArrayRef<Type> results,
function_interface_impl::VariadicFlag,
@@ -191,11 +183,13 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
return builder.getFunctionType(argTypesWithoutLocal, results);
};
-
- return function_interface_impl::parseFunctionOp(
+ auto funcParseRes = function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false,
getFunctionTypeAttrName(result.name), buildFuncType,
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+ if (exported)
+ result.addAttribute(getExportedAttrName(result.name), UnitAttr::get(ctx));
+ return funcParseRes;
}
LogicalResult FuncOp::verifyBody() {
@@ -224,9 +218,18 @@ LogicalResult FuncOp::verifyBody() {
}
void FuncOp::print(OpAsmPrinter &p) {
+  // If exported, print the keyword first and temporarily remove the attribute
+  // so the generic function printing interface does not print it again.
+ auto exported = getExported();
+ if (exported) {
+ p << " exported";
+ removeExportedAttr();
+ }
function_interface_impl::printFunctionOp(
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
getArgAttrsAttrName(), getResAttrsAttrName());
+ if (exported)
+ setExported(true);
}
//===----------------------------------------------------------------------===//
@@ -237,38 +240,37 @@ void FuncImportOp::build(OpBuilder &odsBuilder, OperationState &odsState,
StringRef symbol, StringRef moduleName,
StringRef importName, FunctionType type) {
FuncImportOp::build(odsBuilder, odsState, symbol, moduleName, importName,
- type, {}, {}, odsBuilder.getStringAttr("nested"));
+ type, {}, {});
}
//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
-
-void GlobalOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- StringRef symbol, Type type, bool isMutable) {
- GlobalOp::build(odsBuilder, odsState, symbol, type, isMutable,
- odsBuilder.getStringAttr("nested"));
-}
-
// Custom formats
ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
StringAttr symbolName;
Type globalType;
auto *ctx = parser.getContext();
- ParseResult res = parser.parseSymbolName(
- symbolName, SymbolTable::getSymbolAttrName(), result.attributes);
+ std::string visibilityString;
+ auto loc = parser.getNameLoc();
+ ParseResult res = parser.parseOptionalKeywordOrString(&visibilityString);
+ if (res.succeeded()) {
+ if (visibilityString != "exported")
+ return parser.emitError(
+ loc, "expecting either `exported` or symbol name. got ")
+ << visibilityString;
+ result.addAttribute(getExportedAttrName(result.name), UnitAttr::get(ctx));
+ }
+ res = parser.parseSymbolName(symbolName, SymbolTable::getSymbolAttrName(),
+ result.attributes);
res = parser.parseType(globalType);
result.addAttribute(getTypeAttrName(result.name), TypeAttr::get(globalType));
std::string mutableString;
res = parser.parseOptionalKeywordOrString(&mutableString);
if (res.succeeded() && mutableString == "mutable")
result.addAttribute("isMutable", UnitAttr::get(ctx));
- std::string visibilityString;
- res = parser.parseOptionalKeywordOrString(&visibilityString);
- if (res.succeeded())
- result.addAttribute("sym_visibility",
- StringAttr::get(ctx, visibilityString));
+
res = parser.parseColon();
Region *globalInitRegion = result.addRegion();
res = parser.parseRegion(*globalInitRegion);
@@ -276,11 +278,11 @@ ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
}
void GlobalOp::print(OpAsmPrinter &printer) {
+ if (getExported())
+ printer << " exported";
printer << " @" << getSymName().str() << " " << getType();
if (getIsMutable())
printer << " mutable";
- if (auto vis = getSymVisibility())
- printer << " " << *vis;
printer << " :";
Region &body = getRegion();
if (!body.empty()) {
@@ -319,13 +321,6 @@ GlobalGetOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// GlobalImportOp
//===----------------------------------------------------------------------===//
-void GlobalImportOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- StringRef symbol, StringRef moduleName,
- StringRef importName, Type type, bool isMutable) {
- GlobalImportOp::build(odsBuilder, odsState, symbol, moduleName, importName,
- type, isMutable, odsBuilder.getStringAttr("nested"));
-}
-
ParseResult GlobalImportOp::parse(OpAsmParser &parser, OperationState &result) {
auto *ctx = parser.getContext();
ParseResult res = parseImportOp(parser, result);
@@ -335,12 +330,8 @@ ParseResult GlobalImportOp::parse(OpAsmParser &parser, OperationState &result) {
res = parser.parseOptionalKeywordOrString(&mutableOrSymVisString);
if (res.succeeded() && mutableOrSymVisString == "mutable") {
result.addAttribute("isMutable", UnitAttr::get(ctx));
- res = parser.parseOptionalKeywordOrString(&mutableOrSymVisString);
}
- if (res.succeeded())
- result.addAttribute("sym_visibility",
- StringAttr::get(ctx, mutableOrSymVisString));
res = parser.parseColon();
Type importedType;
@@ -356,8 +347,6 @@ void GlobalImportOp::print(OpAsmPrinter &printer) {
<< "\" as @" << getSymName();
if (getIsMutable())
printer << " mutable";
- if (auto vis = getSymVisibility())
- printer << " " << *vis;
printer << " : " << getType();
}
@@ -431,27 +420,6 @@ LogicalResult LocalTeeOp::verify() {
Block *LoopOp::getLabelTarget() { return &getBody().front(); }
//===----------------------------------------------------------------------===//
-// MemOp
-//===----------------------------------------------------------------------===//
-
-void MemOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- StringRef symbol, LimitType limit) {
- MemOp::build(odsBuilder, odsState, symbol, limit,
- odsBuilder.getStringAttr("nested"));
-}
-
-//===----------------------------------------------------------------------===//
-// MemImportOp
-//===----------------------------------------------------------------------===//
-
-void MemImportOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- StringRef symbol, StringRef moduleName,
- StringRef importName, LimitType limits) {
- MemImportOp::build(odsBuilder, odsState, symbol, moduleName, importName,
- limits, odsBuilder.getStringAttr("nested"));
-}
-
-//===----------------------------------------------------------------------===//
// ReinterpretOp
//===----------------------------------------------------------------------===//
@@ -471,24 +439,3 @@ LogicalResult ReinterpretOp::verify() {
//===----------------------------------------------------------------------===//
void ReturnOp::build(OpBuilder &odsBuilder, OperationState &odsState) {}
-
-//===----------------------------------------------------------------------===//
-// TableOp
-//===----------------------------------------------------------------------===//
-
-void TableOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- StringRef symbol, TableType type) {
- TableOp::build(odsBuilder, odsState, symbol, type,
- odsBuilder.getStringAttr("nested"));
-}
-
-//===----------------------------------------------------------------------===//
-// TableImportOp
-//===----------------------------------------------------------------------===//
-
-void TableImportOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- StringRef symbol, StringRef moduleName,
- StringRef importName, TableType type) {
- TableImportOp::build(odsBuilder, odsState, symbol, moduleName, importName,
- type, odsBuilder.getStringAttr("nested"));
-}
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 9beb22d..1599ae9 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -727,6 +727,152 @@ void MemLayoutAttr::print(AsmPrinter &printer) const {
}
printer << ">";
}
+// A helper utility to perform a binary arith operation on two OpFoldResults.
+// Each operand is materialized as an index value (creating a constant op if
+// it is an attribute), and the corresponding arith op is generated on the
+// resulting values.
+template <typename ArithOp>
+OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc,
+ OpBuilder &builder) {
+ auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a);
+ auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b);
+  return ArithOp::create(builder, loc, aVal, bVal).getResult();
+}
+
+// A helper to divide an OpFoldResult by an int64_t.
+#define div(a, b) \
+ genBinOp<arith::DivSIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// A helper to compute the remainder of an OpFoldResult modulo an int64_t.
+#define rem(a, b) \
+ genBinOp<arith::RemSIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// A helper to multiply an OpFoldResult by an int64_t.
+#define mul(a, b) \
+ genBinOp<arith::MulIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// A helper to add two OpFoldResults.
+#define add(a, b) genBinOp<arith::AddIOp>(a, b, loc, builder)
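+
+// For example (illustration only): `div(ofr, 8)` materializes `ofr` and the
+// index constant 8 as values and emits an `arith.divsi` on them, returning
+// the quotient as an OpFoldResult.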
+
+// Block the given offsets according to the block shape. For example, if the
+// original offset is [y, x] and the block shape is [By, Bx], the blocked
+// offset is [y/By, x/Bx, y%By, x%Bx].
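+// Concretely (illustrative numbers): offsets [5, 7] with a block shape of
+// [4, 8] become [5/4, 7/8, 5%4, 7%8] = [1, 0, 1, 7].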
+SmallVector<OpFoldResult> getBlockedOffsets(OpBuilder &builder, Location loc,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<int64_t> blockShape) {
+
+ assert(offsets.size() == blockShape.size() &&
+ "offsets and blockShape must have the same size");
+ SmallVector<OpFoldResult> blockedOffsets;
+ SmallVector<OpFoldResult> divs, rems;
+
+ for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
+ divs.push_back(div(offset, block));
+ rems.push_back(rem(offset, block));
+ }
+ blockedOffsets.append(divs.begin(), divs.end());
+ blockedOffsets.append(rems.begin(), rems.end());
+
+ return blockedOffsets;
+}
+
+// Get the strides of the MemDesc as a vector of integers.
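+// For illustration (hypothetical numbers): a row-major 32x64 MemDesc with
+// strides [64, 1] and an 8x16 block shape stores each block contiguously
+// (128 elements), so the returned blocked strides are [512, 128, 16, 1]:
+// 16/1 step within a block, 128 steps to the next block along the inner dim,
+// and 512 (four blocks) steps to the next block row.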
+SmallVector<int64_t> MemDescType::getStrideShape() {
+
+ SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
+
+ ArrayAttr strideAttr = getStrideAttr();
+ SmallVector<int64_t> strides;
+ for (Attribute attr : strideAttr.getValue()) {
+ strides.push_back(cast<IntegerAttr>(attr).getInt());
+ }
+
+ SmallVector<int64_t> innerBlkShape = getBlockShape();
+
+  // Get the permutation from the fastest-changing dim (FCD) to the
+  // slowest-changing dim: perm[i] is the dim with the i-th smallest stride.
+ SmallVector<int, 4> perm =
+ llvm::to_vector<4>(llvm::seq<int>(0, strides.size()));
+ llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; });
+
+  assert(strides[perm[0]] == 1 && "innermost dim must have stride 1");
+
+ SmallVector<int64_t> innerBlkStride(innerBlkShape.size());
+ innerBlkStride[perm[0]] = 1;
+ for (size_t i = 1; i < perm.size(); ++i)
+ innerBlkStride[perm[i]] =
+ innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]];
+
+  // Compute the original matrix shape using the stride info, along with the
+  // number of blocks in each dimension. The shape of the highest dim can't be
+  // derived from the stride info, but it doesn't impact the stride computation
+  // for the blocked layout.
+ SmallVector<int64_t> matrixShapeOrig(matrixShape.size());
+  SmallVector<int64_t> blkShapeOrig(matrixShape.size());
+ for (size_t i = 0; i < perm.size() - 1; ++i) {
+ matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]];
+    blkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
+ }
+
+ int64_t innerBlkSize = 1;
+ for (auto s : innerBlkShape)
+ innerBlkSize *= s;
+
+ SmallVector<int64_t> outerBlkStride(matrixShape.size());
+ outerBlkStride[perm[0]] = innerBlkSize;
+ for (size_t i = 0; i < perm.size() - 1; ++i) {
+ outerBlkStride[perm[i + 1]] =
+        outerBlkStride[perm[i]] * blkShapeOrig[perm[i]];
+ }
+
+ // combine the inner and outer strides
+ SmallVector<int64_t> blockedStrides;
+ blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end());
+ blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end());
+
+ return blockedStrides;
+}
+
+// Calculate the linear offset using the blocked offsets and strides.
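+// Continuing the illustrative 32x64 / 8x16 example above, an offset [10, 37]
+// is blocked to [10/8, 37/16, 10%8, 37%16] = [1, 2, 2, 5], giving a linear
+// offset of 1*512 + 2*128 + 2*16 + 5*1 = 805.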
+Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
+ ArrayRef<OpFoldResult> offsets) {
+
+ SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
+ SmallVector<int64_t> blockShape = getBlockShape();
+ SmallVector<int64_t> strides = getStrideShape();
+ SmallVector<OpFoldResult> blockedOffsets;
+
+  // A block shape equal to the matrix shape means no blocking.
+ if (llvm::equal(blockShape, matrixShape)) {
+    // Remove the outer dims from the strides.
+ strides.erase(strides.begin(), strides.begin() + matrixShape.size());
+ } else {
+ assert(offsets.size() == blockShape.size() &&
+ "offsets and blockShape must have the same size");
+    // E.g., if the original offset is [y, x] and the block shape is [By, Bx],
+    // the blocked offset is [y/By, x/Bx, y%By, x%Bx].
+
+ SmallVector<OpFoldResult> divs, rems;
+
+ for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
+ divs.push_back(div(offset, block));
+ rems.push_back(rem(offset, block));
+ }
+ blockedOffsets.append(divs.begin(), divs.end());
+ blockedOffsets.append(rems.begin(), rems.end());
+ offsets = blockedOffsets;
+ }
+
+  // Start from a zero offset and accumulate offset * stride for each dim.
+ Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0);
+ for (size_t i = 0; i < offsets.size(); ++i) {
+ OpFoldResult mulResult = mul(offsets[i], strides[i]);
+ Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult);
+ linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset);
+ }
+
+ return linearOffset;
+}
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 81b5788..abd12e2 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -20,8 +20,8 @@
#define DEBUG_TYPE "xegpu"
-namespace mlir {
-namespace xegpu {
+using namespace mlir;
+using namespace mlir::xegpu;
static bool isSharedMemory(const MemRefType &memrefTy) {
Attribute attr = memrefTy.getMemorySpace();
@@ -173,6 +173,49 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
return success();
}
+static LogicalResult
+IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
+ UnitAttr subgroup_block_io,
+ function_ref<InFlightDiagnostic()> emitError) {
+
+ if (!dataTy) {
+ if (subgroup_block_io)
+ return emitError() << "subgroup_block_io "
+ "are only allowed when result is a 1D VectorType.";
+ else
+ return success();
+ }
+
+ if (mdescTy.getRank() != 2)
+ return emitError() << "mem_desc must be 2D.";
+
+ ArrayRef<int64_t> dataShape = dataTy.getShape();
+ ArrayRef<int64_t> mdescShape = mdescTy.getShape();
+
+ if (dataShape.size() == 2) {
+ if (subgroup_block_io)
+ return emitError() << "subgroup_block_io "
+ "are only allowed when result is a 1D VectorType.";
+ if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
+ [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
+ return emitError() << "data shape must not exceed mem_desc shape.";
+ } else {
+ SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
+    // If the subgroup_block_io attribute is set, mdescTy must have a block
+    // attribute.
+    if (subgroup_block_io && blockShape.empty())
+ return emitError() << "mem_desc must have block attribute when "
+ "subgroup_block_io is set.";
+    // If the subgroup_block_io attribute is set, the mem_desc must be row
+    // major.
+ if (subgroup_block_io && mdescTy.isColMajor())
+ return emitError() << "mem_desc should be row major when "
+ "subgroup_block_io is set.";
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//
@@ -1049,23 +1092,20 @@ void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
llvm::SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+ // Call the generated builder with all parameters (including optional ones as
+ // nullptr/empty)
build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
- layout);
+ /*subgroup_block_io=*/nullptr, layout);
}
LogicalResult LoadMatrixOp::verify() {
- VectorType resTy = getRes().getType();
- MemDescType mdescTy = getMemDesc().getType();
- if (mdescTy.getRank() != 2)
- return emitOpError("mem_desc must be 2D.");
+ auto resTy = dyn_cast<VectorType>(getRes().getType());
+ UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
+ MemDescType mdescTy = getMemDesc().getType();
- ArrayRef<int64_t> valueShape = resTy.getShape();
- ArrayRef<int64_t> mdescShape = mdescTy.getShape();
- if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape),
- [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
- return emitOpError("result shape must not exceed mem_desc shape.");
- return success();
+  return isValidMatrixOpParams(resTy, mdescTy, subgroup_block_io,
+ [&]() { return emitError(); });
}
//===----------------------------------------------------------------------===//
@@ -1080,62 +1120,18 @@ void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
- layout);
+ /*subgroup_block_io=*/nullptr, layout);
}
LogicalResult StoreMatrixOp::verify() {
- VectorType dataTy = getData().getType();
- MemDescType mdescTy = getMemDesc().getType();
-
- if (mdescTy.getRank() != 2)
- return emitOpError("mem_desc must be 2D.");
-
- ArrayRef<int64_t> dataShape = dataTy.getShape();
- ArrayRef<int64_t> mdescShape = mdescTy.getShape();
- if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
- [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
- return emitOpError("data shape must not exceed mem_desc shape.");
-
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// XeGPU_MemDescSubviewOp
-//===----------------------------------------------------------------------===//
-
-void MemDescSubviewOp::build(OpBuilder &builder, OperationState &state,
- Type resTy, Value src,
- llvm::ArrayRef<OpFoldResult> offsets) {
- llvm::SmallVector<Value> dynamicOffsets;
- llvm::SmallVector<int64_t> staticOffsets;
- dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
- auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
- build(builder, state, resTy, src, dynamicOffsets, staticOffsetsAttr);
-}
-
-LogicalResult MemDescSubviewOp::verify() {
- MemDescType srcTy = getSrc().getType();
- MemDescType resTy = getRes().getType();
- ArrayRef<int64_t> srcShape = srcTy.getShape();
- ArrayRef<int64_t> resShape = resTy.getShape();
-
- if (srcTy.getRank() < resTy.getRank())
- return emitOpError("result rank must not exceed source rank.");
- if (llvm::any_of(
- llvm::zip_equal(resShape, srcShape.take_back(resShape.size())),
- [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
- return emitOpError("result shape must not exceed source shape.");
-
- if (srcTy.getStrides() != resTy.getStrides())
- return emitOpError("result must inherit the source strides.");
-
- return success();
+ auto dataTy = dyn_cast<VectorType>(getData().getType());
+ UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
+ MemDescType mdescTy = getMemDesc().getType();
+  return isValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io,
+ [&]() { return emitError(); });
}
-} // namespace xegpu
-} // namespace mlir
-
namespace mlir {
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 36c498e..f77784a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -161,11 +161,24 @@ XeGPUBlockingPass::getTileShape(Operation *op) const {
xegpu::UpdateOffsetOp, xegpu::LoadMatrixOp>(op))
return getTileShape(op->getOpResult(0));
if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
- xegpu::LoadGatherOp, xegpu::StoreMatrixOp>(op))
+ xegpu::StoreMatrixOp>(op))
return getTileShape(op->getOpOperand(0));
- if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(op))
+ if (isa<xegpu::StoreNdOp>(op))
return getTileShape(op->getOpOperand(1));
+ // Handle LoadGatherOp and StoreScatterOp (with and without offset)
+ if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
+ if (loadGatherOp.getOffsets())
+ return getTileShape(loadGatherOp->getOpResult(0));
+ else
+ return getTileShape(loadGatherOp->getOpOperand(0));
+ }
+
+ if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op))
+ return getTileShape(storeScatterOp.getOffsets()
+ ? storeScatterOp->getOpOperand(0)
+ : storeScatterOp->getOpOperand(1));
+
if (isa<xegpu::DpasOp>(op)) {
std::optional<SmallVector<int64_t>> aTile =
getTileShape(op->getOpOperand(0));
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index a178d0f..aafa1b7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -941,7 +941,9 @@ struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- VectorType valueTy = op.getType();
+ VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
+    assert(valueTy && "the value type must be a vector type!");
+
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())
return failure();
@@ -984,7 +986,8 @@ struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
return failure();
Location loc = op.getLoc();
- VectorType valueTy = op.getData().getType();
+ VectorType valueTy = llvm::dyn_cast<VectorType>(op.getData().getType());
+    assert(valueTy && "the value type must be a vector type!");
ArrayRef<int64_t> shape = valueTy.getShape();
auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index c28d2fc..31a967d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -991,7 +991,8 @@ struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
return failure();
ArrayRef<int64_t> wgShape = op.getDataShape();
- VectorType valueTy = op.getRes().getType();
+ VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
+    assert(valueTy && "the value type must be a vector type!");
Type elemTy = valueTy.getElementType();
xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index b72d564..2c56a43 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -52,8 +52,7 @@ mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) {
// compute sgSize by multiply elements of laneLayout
// e.g. for 2D layout, sgSize = laneLayout[0] * laneLayout[1]
// e.g. for 1D layout, sgSize = laneLayout[0]
- auto sgSize = std::accumulate(laneLayout.begin(), laneLayout.end(), 1,
- std::multiplies<int64_t>());
+ int64_t sgSize = llvm::product_of(laneLayout);
// Case 1: regular loads/stores
auto scatterAttr = tdescTy.getEncodingOfType<ScatterTensorDescAttr>();
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 3d19c5a..9b23dd6 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2200,10 +2200,9 @@ void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty,
os << '>';
}
os << '[';
- interleave(
- loc.getLocations(),
- [&](Location loc) { printLocationInternal(loc, pretty); },
- [&]() { os << ", "; });
+ interleaveComma(loc.getLocations(), [&](Location loc) {
+ printLocationInternal(loc, pretty);
+ });
os << ']';
})
.Default([&](LocationAttr loc) {
diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
index 776b5c6..4d81918 100644
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -378,8 +378,10 @@ struct SourceMgrDiagnosticHandlerImpl {
}
// Otherwise, try to load the source file.
- std::string ignored;
- unsigned id = mgr.AddIncludeFile(std::string(filename), SMLoc(), ignored);
+ auto bufferOrErr = llvm::MemoryBuffer::getFile(filename);
+ if (!bufferOrErr)
+ return 0;
+ unsigned id = mgr.AddNewSourceBuffer(std::move(*bufferOrErr), SMLoc());
filenameToBufId[filename] = id;
return id;
}
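The change above swaps SourceMgr::AddIncludeFile for an explicit buffer load, so an unreadable file simply yields buffer id 0 instead of going through include-path resolution. A minimal sketch of that pattern, assuming only LLVM's Support library:

#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"

// Returns the SourceMgr buffer id for `filename`, or 0 if it cannot be read.
static unsigned getOrLoadBuffer(llvm::SourceMgr &mgr, llvm::StringRef filename) {
  auto bufferOrErr = llvm::MemoryBuffer::getFile(filename);
  if (!bufferOrErr)
    return 0;
  return mgr.AddNewSourceBuffer(std::move(*bufferOrErr), llvm::SMLoc());
}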
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 1fa04ed..5f63fe6 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -121,6 +121,11 @@ namespace mlir {
class MLIRContextImpl {
public:
//===--------------------------------------------------------------------===//
+ // Remark
+ //===--------------------------------------------------------------------===//
+ std::unique_ptr<remark::detail::RemarkEngine> remarkEngine;
+
+ //===--------------------------------------------------------------------===//
// Debugging
//===--------------------------------------------------------------------===//
@@ -135,11 +140,6 @@ public:
DiagnosticEngine diagEngine;
//===--------------------------------------------------------------------===//
- // Remark
- //===--------------------------------------------------------------------===//
- std::unique_ptr<remark::detail::RemarkEngine> remarkEngine;
-
- //===--------------------------------------------------------------------===//
// Options
//===--------------------------------------------------------------------===//
@@ -357,7 +357,10 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
impl->affineUniquer.registerParametricStorageType<IntegerSetStorage>();
}
-MLIRContext::~MLIRContext() = default;
+MLIRContext::~MLIRContext() {
+  // Finalize the remark engine before destroying anything else.
+ impl->remarkEngine.reset();
+}
/// Copy the specified array of elements into memory managed by the provided
/// bump pointer allocator. This assumes the elements are all PODs.
@@ -1201,7 +1204,7 @@ AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount,
/// present in result expressions is less than `dimCount` and the highest index
/// of symbolic identifier present in result expressions is less than
/// `symbolCount`.
-LLVM_ATTRIBUTE_UNUSED static bool
+[[maybe_unused]] static bool
willBeValidAffineMap(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr> results) {
int64_t maxDimPosition = -1;
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 8bcfa46..ce421f4 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -18,6 +18,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FoldInterfaces.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ErrorHandling.h"
#include <numeric>
@@ -1274,10 +1275,7 @@ LogicalResult OpTrait::impl::verifyValueSizeAttr(Operation *op,
return op->emitOpError("'")
<< attrName << "' attribute cannot have negative elements";
- size_t totalCount =
- std::accumulate(sizes.begin(), sizes.end(), 0,
- [](unsigned all, int32_t one) { return all + one; });
-
+ size_t totalCount = llvm::sum_of(sizes, size_t(0));
if (totalCount != expectedCount)
return op->emitOpError()
<< valueGroupName << " count (" << expectedCount
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 394ac77..2a37f38 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -406,15 +406,13 @@ OperandRangeRange::OperandRangeRange(OperandRange operands,
OperandRange OperandRangeRange::join() const {
const OwnerT &owner = getBase();
ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(owner.second);
- return OperandRange(owner.first,
- std::accumulate(sizeData.begin(), sizeData.end(), 0));
+ return OperandRange(owner.first, llvm::sum_of(sizeData));
}
OperandRange OperandRangeRange::dereference(const OwnerT &object,
ptrdiff_t index) {
ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(object.second);
- uint32_t startIndex =
- std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
+ uint32_t startIndex = llvm::sum_of(sizeData.take_front(index));
return OperandRange(object.first + startIndex, *(sizeData.begin() + index));
}
@@ -565,8 +563,7 @@ MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object,
ptrdiff_t index) {
ArrayRef<int32_t> sizeData =
llvm::cast<DenseI32ArrayAttr>(object.second.getValue());
- uint32_t startIndex =
- std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
+ uint32_t startIndex = llvm::sum_of(sizeData.take_front(index));
return object.first.slice(
startIndex, *(sizeData.begin() + index),
MutableOperandRange::OperandSegment(index, object.second));
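The std::accumulate-to-llvm::sum_of rewrites in this file all follow one prefix-sum pattern over segment-size attributes. A small sketch, assuming llvm::sum_of from llvm/ADT/STLExtras.h (which this patch already uses):

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"

// Operands of segment `index` start after the sizes of all earlier segments.
static unsigned segmentStart(llvm::ArrayRef<int32_t> segmentSizes,
                             unsigned index) {
  return llvm::sum_of(segmentSizes.take_front(index));
}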
diff --git a/mlir/lib/IR/Remarks.cpp b/mlir/lib/IR/Remarks.cpp
index a55f61a..031eae2 100644
--- a/mlir/lib/IR/Remarks.cpp
+++ b/mlir/lib/IR/Remarks.cpp
@@ -16,7 +16,7 @@
#include "llvm/ADT/StringRef.h"
using namespace mlir::remark::detail;
-
+using namespace mlir::remark;
//------------------------------------------------------------------------------
// Remark
//------------------------------------------------------------------------------
@@ -70,7 +70,7 @@ static void printArgs(llvm::raw_ostream &os, llvm::ArrayRef<Remark::Arg> args) {
void Remark::print(llvm::raw_ostream &os, bool printLocation) const {
// Header: [Type] pass:remarkName
StringRef type = getRemarkTypeString();
- StringRef categoryName = getFullCategoryName();
+ StringRef categoryName = getCombinedCategoryName();
StringRef name = remarkName;
os << '[' << type << "] ";
@@ -81,9 +81,10 @@ void Remark::print(llvm::raw_ostream &os, bool printLocation) const {
os << "Function=" << getFunction() << " | ";
if (printLocation) {
- if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(getLocation()))
+ if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(getLocation())) {
os << " @" << flc.getFilename() << ":" << flc.getLine() << ":"
<< flc.getColumn();
+ }
}
printArgs(os, getArgs());
@@ -140,7 +141,7 @@ llvm::remarks::Remark Remark::generateRemark() const {
r.RemarkType = getRemarkType();
r.RemarkName = getRemarkName();
// MLIR does not use passes; instead, it has categories and sub-categories.
- r.PassName = getFullCategoryName();
+ r.PassName = getCombinedCategoryName();
r.FunctionName = getFunction();
r.Loc = locLambda();
for (const Remark::Arg &arg : getArgs()) {
@@ -225,26 +226,42 @@ InFlightRemark RemarkEngine::emitOptimizationRemarkAnalysis(Location loc,
// RemarkEngine
//===----------------------------------------------------------------------===//
-void RemarkEngine::report(const Remark &&remark) {
+void RemarkEngine::reportImpl(const Remark &remark) {
// Stream the remark
- if (remarkStreamer)
+ if (remarkStreamer) {
remarkStreamer->streamOptimizationRemark(remark);
+ }
// Print using MLIR's diagnostic
if (printAsEmitRemarks)
emitRemark(remark.getLocation(), remark.getMsg());
}
+void RemarkEngine::report(const Remark &&remark) {
+ if (remarkEmittingPolicy)
+ remarkEmittingPolicy->reportRemark(remark);
+}
+
RemarkEngine::~RemarkEngine() {
+ if (remarkEmittingPolicy)
+ remarkEmittingPolicy->finalize();
+
if (remarkStreamer)
remarkStreamer->finalize();
}
-llvm::LogicalResult
-RemarkEngine::initialize(std::unique_ptr<MLIRRemarkStreamerBase> streamer,
- std::string *errMsg) {
- // If you need to validate categories/filters, do so here and set errMsg.
+llvm::LogicalResult RemarkEngine::initialize(
+ std::unique_ptr<MLIRRemarkStreamerBase> streamer,
+ std::unique_ptr<RemarkEmittingPolicyBase> remarkEmittingPolicy,
+ std::string *errMsg) {
+
remarkStreamer = std::move(streamer);
+
+  // Route reported remarks through reportImpl.
+  remarkEmittingPolicy->initialize(
+      ReportFn([this](const Remark &remark) { reportImpl(remark); }));
+
+ this->remarkEmittingPolicy = std::move(remarkEmittingPolicy);
return success();
}
@@ -301,14 +318,15 @@ RemarkEngine::RemarkEngine(bool printAsEmitRemarks,
}
llvm::LogicalResult mlir::remark::enableOptimizationRemarks(
- MLIRContext &ctx,
- std::unique_ptr<remark::detail::MLIRRemarkStreamerBase> streamer,
- const remark::RemarkCategories &cats, bool printAsEmitRemarks) {
+ MLIRContext &ctx, std::unique_ptr<detail::MLIRRemarkStreamerBase> streamer,
+ std::unique_ptr<detail::RemarkEmittingPolicyBase> remarkEmittingPolicy,
+ const RemarkCategories &cats, bool printAsEmitRemarks) {
auto engine =
- std::make_unique<remark::detail::RemarkEngine>(printAsEmitRemarks, cats);
+ std::make_unique<detail::RemarkEngine>(printAsEmitRemarks, cats);
std::string errMsg;
- if (failed(engine->initialize(std::move(streamer), &errMsg))) {
+ if (failed(engine->initialize(std::move(streamer),
+ std::move(remarkEmittingPolicy), &errMsg))) {
llvm::report_fatal_error(
llvm::Twine("Failed to initialize remark engine. Error: ") + errMsg);
}
@@ -316,3 +334,12 @@ llvm::LogicalResult mlir::remark::enableOptimizationRemarks(
return success();
}
+
+//===----------------------------------------------------------------------===//
+// Remark emitting policies
+//===----------------------------------------------------------------------===//
+
+namespace mlir::remark {
+RemarkEmittingPolicyAll::RemarkEmittingPolicyAll() = default;
+RemarkEmittingPolicyFinal::RemarkEmittingPolicyFinal() = default;
+} // namespace mlir::remark
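The policy split separates when remarks are reported from how they are streamed: reportImpl does the streaming, while the policy decides whether each remark is forwarded immediately (RemarkEmittingPolicyAll) or held back until finalize() (RemarkEmittingPolicyFinal), which is why ~RemarkEngine() and ~MLIRContext() now finalize the policy first. A standalone sketch of the deferred-reporting idea; these are not the MLIR classes, and the exact filtering the "final" policy applies is not visible in this diff:

#include <functional>
#include <iostream>
#include <string>
#include <vector>

using ReportFn = std::function<void(const std::string &)>;

// Forwards every remark as soon as it is produced.
struct PolicyAll {
  ReportFn report;
  void onRemark(const std::string &r) { report(r); }
  void finalize() {}
};

// Buffers remarks and only reports them when finalize() runs.
struct PolicyFinal {
  ReportFn report;
  std::vector<std::string> buffered;
  void onRemark(const std::string &r) { buffered.push_back(r); }
  void finalize() {
    for (const std::string &r : buffered)
      report(r);
  }
};

int main() {
  PolicyFinal policy{[](const std::string &r) { std::cout << r << "\n"; }, {}};
  policy.onRemark("loop unrolled by 4"); // buffered, nothing printed yet
  policy.finalize();                     // reported here
  return 0;
}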
diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp
index d2d115e..e438631 100644
--- a/mlir/lib/IR/TypeUtilities.cpp
+++ b/mlir/lib/IR/TypeUtilities.cpp
@@ -104,8 +104,8 @@ LogicalResult mlir::verifyCompatibleShapes(TypeRange types1, TypeRange types2) {
LogicalResult mlir::verifyCompatibleDims(ArrayRef<int64_t> dims) {
if (dims.empty())
return success();
- auto staticDim = std::accumulate(
- dims.begin(), dims.end(), dims.front(), [](auto fold, auto dim) {
+ auto staticDim =
+ llvm::accumulate(dims, dims.front(), [](auto fold, auto dim) {
return ShapedType::isDynamic(dim) ? fold : dim;
});
return success(llvm::all_of(dims, [&](auto dim) {
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index 388de1c..f96af02 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -9,6 +9,7 @@ set(LLVM_OPTIONAL_SOURCES
FunctionInterfaces.cpp
IndexingMapOpInterface.cpp
InferIntRangeInterface.cpp
+ InferStridedMetadataInterface.cpp
InferTypeOpInterface.cpp
LoopLikeInterface.cpp
MemOpInterfaces.cpp
@@ -64,6 +65,21 @@ add_mlir_library(MLIRFunctionInterfaces
add_mlir_interface_library(IndexingMapOpInterface)
add_mlir_interface_library(InferIntRangeInterface)
+
+add_mlir_library(MLIRInferStridedMetadataInterface
+ InferStridedMetadataInterface.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces
+
+ DEPENDS
+ MLIRInferStridedMetadataInterfaceIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRInferIntRangeInterface
+ MLIRIR
+)
+
add_mlir_interface_library(InferTypeOpInterface)
add_mlir_library(MLIRLoopLikeInterface
diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
index 9f3e97d..84fc9b8 100644
--- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -146,6 +146,25 @@ raw_ostream &mlir::operator<<(raw_ostream &os, const IntegerValueRange &range) {
return os;
}
+SmallVector<IntegerValueRange>
+mlir::getIntValueRanges(ArrayRef<OpFoldResult> values,
+ GetIntRangeFn getIntRange, int32_t indexBitwidth) {
+ SmallVector<IntegerValueRange> ranges;
+ ranges.reserve(values.size());
+ for (OpFoldResult ofr : values) {
+ if (auto value = dyn_cast<Value>(ofr)) {
+ ranges.push_back(getIntRange(value));
+ continue;
+ }
+
+ // Create a constant range.
+ auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
+ ranges.emplace_back(ConstantIntRanges::constant(
+ attr.getValue().sextOrTrunc(indexBitwidth)));
+ }
+ return ranges;
+}
+
void mlir::intrange::detail::defaultInferResultRanges(
InferIntRangeInterface interface, ArrayRef<IntegerValueRange> argRanges,
SetIntLatticeFn setResultRanges) {
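getIntValueRanges gives range analyses one way to handle mixed static/dynamic index lists: constant attributes become constant ranges at the requested index bitwidth, and SSA values are resolved through the supplied callback. A hypothetical call site (mixedOffsets and lookupRangeFromAnalysis are stand-ins, not APIs from this patch):

// Hypothetical caller: mixedOffsets holds OpFoldResults, some constant
// attributes and some SSA values; the lambda consults whatever dataflow
// state the caller has available.
SmallVector<IntegerValueRange> ranges = getIntValueRanges(
    mixedOffsets, [&](Value v) { return lookupRangeFromAnalysis(v); },
    /*indexBitwidth=*/64);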
diff --git a/mlir/lib/Interfaces/InferStridedMetadataInterface.cpp b/mlir/lib/Interfaces/InferStridedMetadataInterface.cpp
new file mode 100644
index 0000000..483e9f1
--- /dev/null
+++ b/mlir/lib/Interfaces/InferStridedMetadataInterface.cpp
@@ -0,0 +1,36 @@
+//===- InferStridedMetadataInterface.cpp - Strided md inference interface -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/InferStridedMetadataInterface.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include <optional>
+
+using namespace mlir;
+
+#include "mlir/Interfaces/InferStridedMetadataInterface.cpp.inc"
+
+void StridedMetadataRange::print(raw_ostream &os) const {
+ if (isUninitialized()) {
+ os << "strided_metadata<None>";
+ return;
+ }
+ os << "strided_metadata<offset = [";
+ llvm::interleaveComma(*offsets, os, [&](const ConstantIntRanges &range) {
+ os << "{" << range << "}";
+ });
+ os << "], sizes = [";
+ llvm::interleaveComma(sizes, os, [&](const ConstantIntRanges &range) {
+ os << "{" << range << "}";
+ });
+ os << "], strides = [";
+ llvm::interleaveComma(strides, os, [&](const ConstantIntRanges &range) {
+ os << "{" << range << "}";
+ });
+ os << "]>";
+}
diff --git a/mlir/lib/RegisterAllPasses.cpp b/mlir/lib/RegisterAllPasses.cpp
index c67b242..dd413d2de 100644
--- a/mlir/lib/RegisterAllPasses.cpp
+++ b/mlir/lib/RegisterAllPasses.cpp
@@ -98,4 +98,5 @@ void mlir::registerAllPasses() {
sparse_tensor::registerSparseTensorPipelines();
tosa::registerTosaToLinalgPipelines();
gpu::registerGPUToNVVMPipeline();
+ gpu::registerGPUToXeVMPipeline();
}
diff --git a/mlir/lib/Remark/RemarkStreamer.cpp b/mlir/lib/Remark/RemarkStreamer.cpp
index d213a1a..bf36286 100644
--- a/mlir/lib/Remark/RemarkStreamer.cpp
+++ b/mlir/lib/Remark/RemarkStreamer.cpp
@@ -60,6 +60,7 @@ void LLVMRemarkStreamer::finalize() {
namespace mlir::remark {
LogicalResult enableOptimizationRemarksWithLLVMStreamer(
MLIRContext &ctx, StringRef path, llvm::remarks::Format fmt,
+ std::unique_ptr<detail::RemarkEmittingPolicyBase> remarkEmittingPolicy,
const RemarkCategories &cat, bool printAsEmitRemarks) {
FailureOr<std::unique_ptr<detail::MLIRRemarkStreamerBase>> sOr =
@@ -67,7 +68,8 @@ LogicalResult enableOptimizationRemarksWithLLVMStreamer(
if (failed(sOr))
return failure();
- return remark::enableOptimizationRemarks(ctx, std::move(*sOr), cat,
+ return remark::enableOptimizationRemarks(ctx, std::move(*sOr),
+ std::move(remarkEmittingPolicy), cat,
printAsEmitRemarks);
}
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 33fbd2a..42843ea 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -1835,8 +1835,7 @@ executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
return nullptr;
ArrayRef<int32_t> segments = segmentAttr;
- unsigned startIndex =
- std::accumulate(segments.begin(), segments.begin() + index, 0);
+ unsigned startIndex = llvm::sum_of(segments.take_front(index));
values = values.slice(startIndex, *std::next(segments.begin(), index));
LDBG() << " * Extracting range[" << startIndex << ", "
diff --git a/mlir/lib/TableGen/CodeGenHelpers.cpp b/mlir/lib/TableGen/CodeGenHelpers.cpp
index cb90ef8..d52d5e7 100644
--- a/mlir/lib/TableGen/CodeGenHelpers.cpp
+++ b/mlir/lib/TableGen/CodeGenHelpers.cpp
@@ -49,9 +49,7 @@ StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
raw_ostream &os, const RecordKeeper &records, StringRef tag)
: os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {}
-void StaticVerifierFunctionEmitter::emitOpConstraints(
- ArrayRef<const Record *> opDefs) {
- NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace());
+void StaticVerifierFunctionEmitter::emitOpConstraints() {
emitTypeConstraints();
emitAttrConstraints();
emitPropConstraints();
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 5fe5f41..1243511 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -357,11 +357,6 @@ static bool shouldBeInlined(ExpressionOp expressionOp) {
if (expressionOp.getDoNotInline())
return false;
- // Do not inline expressions with side effects to prevent side-effect
- // reordering.
- if (expressionOp.hasSideEffects())
- return false;
-
// Do not inline expressions with multiple uses.
Value result = expressionOp.getResult();
if (!result.hasOneUse())
@@ -377,7 +372,34 @@ static bool shouldBeInlined(ExpressionOp expressionOp) {
// Do not inline expressions used by other expressions or by ops with the
// CExpressionInterface. If this was intended, the user could have been merged
// into the expression op.
- return !isa<emitc::ExpressionOp, emitc::CExpressionInterface>(*user);
+ if (isa<emitc::ExpressionOp, emitc::CExpressionInterface>(*user))
+ return false;
+
+ // Expressions with no side-effects can safely be inlined.
+ if (!expressionOp.hasSideEffects())
+ return true;
+
+ // Expressions with side-effects can be only inlined if side-effect ordering
+ // in the program is provably retained.
+
+ // Require the user to immediately follow the expression.
+ if (++Block::iterator(expressionOp) != Block::iterator(user))
+ return false;
+
+ // These single-operand ops are safe.
+ if (isa<emitc::IfOp, emitc::SwitchOp, emitc::ReturnOp>(user))
+ return true;
+
+ // For assignment look for specific cases to inline as evaluation order of
+ // its lvalue and rvalue is undefined in C.
+ if (auto assignOp = dyn_cast<emitc::AssignOp>(user)) {
+ // Inline if this assignment is of the form `<var> = <expression>`.
+ if (expressionOp.getResult() == assignOp.getValue() &&
+ isa_and_present<VariableOp>(assignOp.getVar().getDefiningOp()))
+ return true;
+ }
+
+ return false;
}
static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
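The rewritten shouldBeInlined only folds a side-effecting expression into its user when the user immediately follows it and is one of the listed safe cases (if/switch/return, or an assignment of the form `<var> = <expression>` to a plain variable). The hazard it guards against is ordinary evaluation-order movement in the emitted C; a hand-written illustration (not emitter output):

#include <cstdio>

static int counter = 0;
static int next() { return ++counter; } // the side-effecting "expression"

int main() {
  // As emitted without inlining: next() runs before the later assignment.
  int v = next(); // v == 1
  counter = 10;
  std::printf("%d\n", v); // prints 1

  // If next() were textually inlined into its use, it would run after the
  // assignment, and the side effect (and observed value) would move with it.
  counter = 0;
  counter = 10;
  std::printf("%d\n", next()); // prints 11
  return 0;
}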
diff --git a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
index e3f075f..8ecb084 100644
--- a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
+++ b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
@@ -464,12 +464,6 @@ static std::string generateOpDefinition(irdl::detail::dictionary &dict,
auto opStrings = getStrings(op);
fillDict(dict, opStrings);
- const auto operandCount = opStrings.opOperandNames.size();
- const auto operandNames =
- operandCount ? joinNameList(opStrings.opOperandNames) : "{\"\"}";
-
- const auto resultNames = joinNameList(opStrings.opResultNames);
-
auto resultTypes = llvm::join(
llvm::map_range(opStrings.opResultNames,
[](StringRef attr) -> std::string {
diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
index 4bbcd8e..db39c70 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
@@ -34,11 +34,9 @@ Location DebugImporter::translateFuncLocation(llvm::Function *func) {
return UnknownLoc::get(context);
// Add a fused location to link the subprogram information.
- StringAttr funcName = StringAttr::get(context, subprogram->getName());
StringAttr fileName = StringAttr::get(context, subprogram->getFilename());
return FusedLocWith<DISubprogramAttr>::get(
- {NameLoc::get(funcName),
- FileLineColLoc::get(fileName, subprogram->getLine(), /*column=*/0)},
+ {FileLineColLoc::get(fileName, subprogram->getLine(), /*column=*/0)},
translate(subprogram), context);
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 1e2099d..8de49dd 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -246,7 +246,7 @@ public:
// Rewrite all uses of the original variable in `BBName`
// with the linear variable in-place
- void rewriteInPlace(llvm::IRBuilderBase &builder, std::string BBName,
+ void rewriteInPlace(llvm::IRBuilderBase &builder, const std::string &BBName,
size_t varIndex) {
llvm::SmallVector<llvm::User *> users;
for (llvm::User *user : linearOrigVal[varIndex]->users())
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 9603813..857e31b 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -2604,6 +2604,7 @@ static constexpr std::array kExplicitLLVMFuncOpAttributes{
StringLiteral("denormal-fp-math-f32"),
StringLiteral("fp-contract"),
StringLiteral("frame-pointer"),
+ StringLiteral("inlinehint"),
StringLiteral("instrument-function-entry"),
StringLiteral("instrument-function-exit"),
StringLiteral("memory"),
@@ -2643,6 +2644,8 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func,
funcOp.setNoInline(true);
if (func->hasFnAttribute(llvm::Attribute::AlwaysInline))
funcOp.setAlwaysInline(true);
+ if (func->hasFnAttribute(llvm::Attribute::InlineHint))
+ funcOp.setInlineHint(true);
if (func->hasFnAttribute(llvm::Attribute::OptimizeNone))
funcOp.setOptimizeNone(true);
if (func->hasFnAttribute(llvm::Attribute::Convergent))
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 5a3eb20..147613f 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -922,8 +922,7 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
assert(opBundleSizes.size() == opBundleTagsAttr.size() &&
"operand bundles and tags do not match");
- numOpBundleOperands =
- std::accumulate(opBundleSizes.begin(), opBundleSizes.end(), size_t(0));
+ numOpBundleOperands = llvm::sum_of(opBundleSizes);
assert(numOpBundleOperands <= intrOp->getNumOperands() &&
"operand bundle operands is more than the number of operands");
@@ -1653,6 +1652,8 @@ static void convertFunctionAttributes(LLVMFuncOp func,
llvmFunc->addFnAttr(llvm::Attribute::NoInline);
if (func.getAlwaysInlineAttr())
llvmFunc->addFnAttr(llvm::Attribute::AlwaysInline);
+ if (func.getInlineHintAttr())
+ llvmFunc->addFnAttr(llvm::Attribute::InlineHint);
if (func.getOptimizeNoneAttr())
llvmFunc->addFnAttr(llvm::Attribute::OptimizeNone);
if (func.getConvergentAttr())
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 0c3e87a..d9ad8fb 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -2619,6 +2619,11 @@ LogicalResult ControlFlowStructurizer::structurize() {
// region. We cannot handle such cases given that once a value is sinked into
// the SelectionOp/LoopOp's region, there is no escape for it.
for (auto *block : constructBlocks) {
+ if (!block->use_empty())
+ return emitError(block->getParent()->getLoc(),
+ "failed control flow structurization: "
+ "block has uses outside of the "
+ "enclosing selection/loop construct");
for (Operation &op : *block)
if (!op.use_empty())
return op.emitOpError("failed control flow structurization: value has "
diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
index 132be4e..366ba8f 100644
--- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
+++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
@@ -14,6 +14,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/Support/LLVM.h"
@@ -138,6 +139,10 @@ using ImportDesc =
using parsed_inst_t = FailureOr<SmallVector<Value>>;
+struct EmptyBlockMarker {};
+using BlockTypeParseResult =
+ std::variant<EmptyBlockMarker, TypeIdxRecord, Type>;
+
struct WasmModuleSymbolTables {
SmallVector<FunctionSymbolRefContainer> funcSymbols;
SmallVector<GlobalSymbolRefContainer> globalSymbols;
@@ -175,6 +180,9 @@ class ParserHead;
/// Wrapper around SmallVector to only allow access as push and pop on the
/// stack. Makes sure that there are no "free accesses" on the stack to preserve
/// its state.
+/// This class also keeps track of the Wasm labels defined by different ops,
+/// which can be targeted by control flow ops. This can be modeled as part of
+/// the Value Stack as Wasm control flow ops can only target enclosing labels.
class ValueStack {
private:
struct LabelLevel {
@@ -206,6 +214,16 @@ public:
/// if an error occurs.
LogicalResult pushResults(ValueRange results, Location *opLoc);
+ void addLabelLevel(LabelLevelOpInterface levelOp) {
+ labelLevel.push_back({values.size(), levelOp});
+ LDBG() << "Adding a new frame context to ValueStack";
+ }
+
+ void dropLabelLevel() {
+    assert(!labelLevel.empty() && "Dropping a frame from an empty context");
+ auto newSize = labelLevel.pop_back_val().stackIdx;
+ values.truncate(newSize);
+ }
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
/// A simple dump function for debugging.
/// Writes output to llvm::dbgs().
@@ -214,6 +232,7 @@ public:
private:
SmallVector<Value> values;
+ SmallVector<LabelLevel> labelLevel;
};
using local_val_t = TypedValue<wasmssa::LocalRefType>;
@@ -248,6 +267,19 @@ private:
buildNumericOp(OpBuilder &builder,
std::enable_if_t<std::is_arithmetic_v<valueType>> * = nullptr);
+ /// Construct a conversion operation of type \p opType that takes a value from
+ /// type \p inputType on the stack and will produce a value of type
+ /// \p outputType.
+ ///
+ /// \p opType - The WASM dialect operation to build.
+ /// \p inputType - The operand type for the built instruction.
+ /// \p outputType - The result type for the built instruction.
+ ///
+ /// \returns The parsed instruction result, or failure.
+ template <typename opType, typename inputType, typename outputType,
+ typename... extraArgsT>
+ inline parsed_inst_t buildConvertOp(OpBuilder &builder, extraArgsT...);
+
/// This function generates a dispatch tree to associate an opcode with a
/// parser. Parsers are registered by specialising the
/// `parseSpecificInstruction` function for the op code to handle.
@@ -280,11 +312,105 @@ private:
}
}
+ ///
+ /// RAII guard class for creating a nesting level
+ ///
+ struct NestingContextGuard {
+ NestingContextGuard(ExpressionParser &parser, LabelLevelOpInterface levelOp)
+ : parser{parser} {
+ parser.addNestingContextLevel(levelOp);
+ }
+ NestingContextGuard(NestingContextGuard &&other) : parser{other.parser} {
+ other.shouldDropOnDestruct = false;
+ }
+ NestingContextGuard(NestingContextGuard const &) = delete;
+ ~NestingContextGuard() {
+ if (shouldDropOnDestruct)
+ parser.dropNestingContextLevel();
+ }
+ ExpressionParser &parser;
+ bool shouldDropOnDestruct = true;
+ };
+
+ void addNestingContextLevel(LabelLevelOpInterface levelOp) {
+ valueStack.addLabelLevel(levelOp);
+ }
+
+ void dropNestingContextLevel() {
+    // Should always succeed as we are dropping the frame that was previously
+ // created.
+ valueStack.dropLabelLevel();
+ }
+
+ llvm::FailureOr<FunctionType> getFuncTypeFor(OpBuilder &builder,
+ EmptyBlockMarker) {
+ return builder.getFunctionType({}, {});
+ }
+
+ llvm::FailureOr<FunctionType> getFuncTypeFor(OpBuilder &builder,
+ TypeIdxRecord type) {
+ if (type.id >= symbols.moduleFuncTypes.size())
+ return emitError(*currentOpLoc,
+ "type index references nonexistent type (")
+ << type.id << "). Only " << symbols.moduleFuncTypes.size()
+ << " types are registered";
+ return symbols.moduleFuncTypes[type.id];
+ }
+
+ llvm::FailureOr<FunctionType> getFuncTypeFor(OpBuilder &builder,
+ Type valType) {
+ return builder.getFunctionType({}, {valType});
+ }
+
+ llvm::FailureOr<FunctionType>
+ getFuncTypeFor(OpBuilder &builder, BlockTypeParseResult parseResult) {
+ return std::visit(
+ [this, &builder](auto value) { return getFuncTypeFor(builder, value); },
+ parseResult);
+ }
+
+ llvm::FailureOr<FunctionType>
+ getFuncTypeFor(OpBuilder &builder,
+ llvm::FailureOr<BlockTypeParseResult> parseResult) {
+ if (llvm::failed(parseResult))
+ return failure();
+ return getFuncTypeFor(builder, *parseResult);
+ }
+
+ llvm::FailureOr<FunctionType> parseBlockFuncType(OpBuilder &builder);
+
struct ParseResultWithInfo {
SmallVector<Value> opResults;
std::byte endingByte;
};
+  /// \param blockToFill the block whose contents will be populated
+  /// \param resTypes the types that this block is supposed to return
+  template <typename FilterT = ByteSequence<WasmBinaryEncoding::endByte>>
+ llvm::FailureOr<std::byte>
+ parseBlockContent(OpBuilder &builder, Block *blockToFill, TypeRange resTypes,
+ Location opLoc, LabelLevelOpInterface levelOp,
+ FilterT parseEndBytes = {}) {
+ OpBuilder::InsertionGuard guard{builder};
+ builder.setInsertionPointToStart(blockToFill);
+ LDBG() << "parsing a block of type "
+ << builder.getFunctionType(blockToFill->getArgumentTypes(),
+ resTypes);
+ auto nC = addNesting(levelOp);
+
+ if (failed(pushResults(blockToFill->getArguments())))
+ return failure();
+ auto bodyParsingRes = parse(builder, parseEndBytes);
+ if (failed(bodyParsingRes))
+ return failure();
+ auto returnOperands = popOperands(resTypes);
+ if (failed(returnOperands))
+ return failure();
+ builder.create<BlockReturnOp>(opLoc, *returnOperands);
+ LDBG() << "end of parsing of a block";
+ return bodyParsingRes->endingByte;
+ }
+
public:
template <std::byte ParseEndByte = WasmBinaryEncoding::endByte>
parsed_inst_t parse(OpBuilder &builder, UniqueByte<ParseEndByte> = {});
@@ -294,7 +420,11 @@ public:
parse(OpBuilder &builder,
ByteSequence<ExpressionParseEnd...> parsingEndFilters);
- FailureOr<SmallVector<Value>> popOperands(TypeRange operandTypes) {
+ NestingContextGuard addNesting(LabelLevelOpInterface levelOp) {
+ return NestingContextGuard{*this, levelOp};
+ }
+
+ FailureOr<llvm::SmallVector<Value>> popOperands(TypeRange operandTypes) {
return valueStack.popOperands(operandTypes, &currentOpLoc.value());
}
@@ -308,6 +438,12 @@ public:
template <typename OpToCreate>
parsed_inst_t parseSetOrTee(OpBuilder &);
+ /// Blocks and Loops have a similar format and differ only in how their exit
+ /// is handled which doesn´t matter at parsing time. Factorizes in one
+ /// function.
+ template <typename OpToCreate>
+ parsed_inst_t parseBlockLikeOp(OpBuilder &);
+
private:
std::optional<Location> currentOpLoc;
ParserHead &parser;
@@ -586,6 +722,29 @@ public:
return success();
}
+ llvm::FailureOr<BlockTypeParseResult> parseBlockType(MLIRContext *ctx) {
+ auto loc = getLocation();
+ auto blockIndicator = peek();
+ if (failed(blockIndicator))
+ return failure();
+ if (*blockIndicator == WasmBinaryEncoding::Type::emptyBlockType) {
+ offset += 1;
+ return {EmptyBlockMarker{}};
+ }
+ if (isValueOneOf(*blockIndicator, valueTypesEncodings))
+ return parseValueType(ctx);
+    // A block type index is a 32-bit positive integer encoded as a 33-bit
+    // signed value.
+ auto typeIdx = parseI64();
+ if (failed(typeIdx))
+ return failure();
+ if (*typeIdx < 0 || *typeIdx > std::numeric_limits<uint32_t>::max())
+ return emitError(loc, "type ID should be representable with an unsigned "
+ "32 bits integer. Got ")
+ << *typeIdx;
+ return {TypeIdxRecord{static_cast<uint32_t>(*typeIdx)}};
+ }
+
bool end() const { return curHead().empty(); }
ParserHead copy() const { return *this; }
@@ -701,17 +860,41 @@ inline parsed_inst_t ExpressionParser::parseSpecificInstruction(OpBuilder &) {
void ValueStack::dump() const {
llvm::dbgs() << "================= Wasm ValueStack =======================\n";
llvm::dbgs() << "size: " << size() << "\n";
+ llvm::dbgs() << "nbFrames: " << labelLevel.size() << '\n';
llvm::dbgs() << "<Top>"
<< "\n";
// Stack is pushed to via push_back. Therefore the top of the stack is the
// end of the vector. Iterate in reverse so that the first thing we print
// is the top of the stack.
+ auto indexGetter = [this]() {
+ size_t idx = labelLevel.size();
+ return [this, idx]() mutable -> std::optional<std::pair<size_t, size_t>> {
+ llvm::dbgs() << "IDX: " << idx << '\n';
+ if (idx == 0)
+ return std::nullopt;
+ auto frameId = idx - 1;
+ auto frameLimit = labelLevel[frameId].stackIdx;
+ idx -= 1;
+ return {{frameId, frameLimit}};
+ };
+ };
+ auto getNextFrameIndex = indexGetter();
+ auto nextFrameIdx = getNextFrameIndex();
size_t stackSize = size();
- for (size_t idx = 0; idx < stackSize; idx++) {
+ for (size_t idx = 0; idx < stackSize; ++idx) {
size_t actualIdx = stackSize - 1 - idx;
+ while (nextFrameIdx && (nextFrameIdx->second > actualIdx)) {
+ llvm::dbgs() << " --------------- Frame (" << nextFrameIdx->first
+ << ")\n";
+ nextFrameIdx = getNextFrameIndex();
+ }
llvm::dbgs() << " ";
values[actualIdx].dump();
}
+ while (nextFrameIdx) {
+ llvm::dbgs() << " --------------- Frame (" << nextFrameIdx->first << ")\n";
+ nextFrameIdx = getNextFrameIndex();
+ }
llvm::dbgs() << "<Bottom>"
<< "\n";
llvm::dbgs() << "=========================================================\n";
@@ -726,7 +909,7 @@ parsed_inst_t ValueStack::popOperands(TypeRange operandTypes, Location *opLoc) {
return emitError(*opLoc,
"stack doesn't contain enough values. trying to get ")
<< operandTypes.size() << " operands on a stack containing only "
- << values.size() << " values.";
+ << values.size() << " values";
size_t stackIdxOffset = values.size() - operandTypes.size();
SmallVector<Value> res{};
res.reserve(operandTypes.size());
@@ -735,8 +918,7 @@ parsed_inst_t ValueStack::popOperands(TypeRange operandTypes, Location *opLoc) {
Type stackType = operand.getType();
if (stackType != operandTypes[i])
return emitError(*opLoc, "invalid operand type on stack. expecting ")
- << operandTypes[i] << ", value on stack is of type " << stackType
- << ".";
+ << operandTypes[i] << ", value on stack is of type " << stackType;
LDBG() << " POP: " << operand;
res.push_back(operand);
}
@@ -792,6 +974,151 @@ ExpressionParser::parse(OpBuilder &builder,
}
}
+llvm::FailureOr<FunctionType>
+ExpressionParser::parseBlockFuncType(OpBuilder &builder) {
+ return getFuncTypeFor(builder, parser.parseBlockType(builder.getContext()));
+}
+
+template <typename OpToCreate>
+parsed_inst_t ExpressionParser::parseBlockLikeOp(OpBuilder &builder) {
+ auto opLoc = currentOpLoc;
+ auto funcType = parseBlockFuncType(builder);
+ if (failed(funcType))
+ return failure();
+
+ auto inputTypes = funcType->getInputs();
+ auto inputOps = popOperands(inputTypes);
+ if (failed(inputOps))
+ return failure();
+
+ Block *curBlock = builder.getBlock();
+ Region *curRegion = curBlock->getParent();
+ auto resTypes = funcType->getResults();
+ llvm::SmallVector<Location> locations{};
+ locations.resize(resTypes.size(), *currentOpLoc);
+ auto *successor =
+ builder.createBlock(curRegion, curRegion->end(), resTypes, locations);
+ builder.setInsertionPointToEnd(curBlock);
+ auto blockOp =
+ builder.create<OpToCreate>(*currentOpLoc, *inputOps, successor);
+ auto *blockBody = blockOp.createBlock();
+ if (failed(parseBlockContent(builder, blockBody, resTypes, *opLoc, blockOp)))
+ return failure();
+ builder.setInsertionPointToStart(successor);
+ return {ValueRange{successor->getArguments()}};
+}
+
+template <>
+inline parsed_inst_t
+ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::block>(
+ OpBuilder &builder) {
+ return parseBlockLikeOp<BlockOp>(builder);
+}
+
+template <>
+inline parsed_inst_t
+ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::loop>(
+ OpBuilder &builder) {
+ return parseBlockLikeOp<LoopOp>(builder);
+}
+
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::ifOpCode>(OpBuilder &builder) {
+ auto opLoc = currentOpLoc;
+ auto funcType = parseBlockFuncType(builder);
+ if (failed(funcType))
+ return failure();
+
+ LDBG() << "Parsing an if instruction of type " << *funcType;
+ auto inputTypes = funcType->getInputs();
+ auto conditionValue = popOperands(builder.getI32Type());
+ if (failed(conditionValue))
+ return failure();
+ auto inputOps = popOperands(inputTypes);
+ if (failed(inputOps))
+ return failure();
+
+ Block *curBlock = builder.getBlock();
+ Region *curRegion = curBlock->getParent();
+ auto resTypes = funcType->getResults();
+ llvm::SmallVector<Location> locations{};
+ locations.resize(resTypes.size(), *currentOpLoc);
+ auto *successor =
+ builder.createBlock(curRegion, curRegion->end(), resTypes, locations);
+ builder.setInsertionPointToEnd(curBlock);
+ auto ifOp = builder.create<IfOp>(*currentOpLoc, conditionValue->front(),
+ *inputOps, successor);
+ auto *ifEntryBlock = ifOp.createIfBlock();
+ constexpr auto ifElseFilter =
+ ByteSequence<WasmBinaryEncoding::endByte,
+ WasmBinaryEncoding::OpCode::elseOpCode>{};
+ auto parseIfRes = parseBlockContent(builder, ifEntryBlock, resTypes, *opLoc,
+ ifOp, ifElseFilter);
+ if (failed(parseIfRes))
+ return failure();
+ if (*parseIfRes == WasmBinaryEncoding::OpCode::elseOpCode) {
+ LDBG() << " else block is present.";
+ Block *elseEntryBlock = ifOp.createElseBlock();
+ auto parseElseRes =
+ parseBlockContent(builder, elseEntryBlock, resTypes, *opLoc, ifOp);
+ if (failed(parseElseRes))
+ return failure();
+ }
+ builder.setInsertionPointToStart(successor);
+ return {ValueRange{successor->getArguments()}};
+}
+
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::branchIf>(OpBuilder &builder) {
+ auto level = parser.parseLiteral<uint32_t>();
+ if (failed(level))
+ return failure();
+ Block *curBlock = builder.getBlock();
+ Region *curRegion = curBlock->getParent();
+ auto sip = builder.saveInsertionPoint();
+ Block *elseBlock = builder.createBlock(curRegion, curRegion->end());
+ auto condition = popOperands(builder.getI32Type());
+ if (failed(condition))
+ return failure();
+ builder.restoreInsertionPoint(sip);
+ auto targetOp =
+ LabelBranchingOpInterface::getTargetOpFromBlock(curBlock, *level);
+ if (failed(targetOp))
+ return failure();
+ auto inputTypes = targetOp->getLabelTarget()->getArgumentTypes();
+ auto branchArgs = popOperands(inputTypes);
+ if (failed(branchArgs))
+ return failure();
+ builder.create<BranchIfOp>(*currentOpLoc, condition->front(),
+ builder.getUI32IntegerAttr(*level), *branchArgs,
+ elseBlock);
+ builder.setInsertionPointToStart(elseBlock);
+ return {*branchArgs};
+}
+
+template <>
+inline parsed_inst_t
+ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::call>(
+ OpBuilder &builder) {
+ auto loc = *currentOpLoc;
+ auto funcIdx = parser.parseLiteral<uint32_t>();
+ if (failed(funcIdx))
+ return failure();
+ if (*funcIdx >= symbols.funcSymbols.size())
+ return emitError(loc, "Invalid function index: ") << *funcIdx;
+ auto callee = symbols.funcSymbols[*funcIdx];
+ llvm::ArrayRef<Type> inTypes = callee.functionType.getInputs();
+ llvm::ArrayRef<Type> resTypes = callee.functionType.getResults();
+ parsed_inst_t inOperands = popOperands(inTypes);
+ if (failed(inOperands))
+ return failure();
+ auto callOp =
+ builder.create<FuncCallOp>(loc, resTypes, callee.symbol, *inOperands);
+ return {callOp.getResults()};
+}
+
template <>
inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
WasmBinaryEncoding::OpCode::localGet>(OpBuilder &builder) {
@@ -834,7 +1161,7 @@ parsed_inst_t ExpressionParser::parseSetOrTee(OpBuilder &builder) {
if (valueStack.empty())
return emitError(
*currentOpLoc,
- "invalid stack access, trying to access a value on an empty stack.");
+ "invalid stack access, trying to access a value on an empty stack");
parsed_inst_t poppedOp = popOperands(locals[*id].getType().getElementType());
if (failed(poppedOp))
@@ -956,7 +1283,7 @@ inline parsed_inst_t ExpressionParser::buildNumericOp(
<< ", type = " << ty << " ***";
auto tysToPop = SmallVector<Type, numOperands>();
tysToPop.resize(numOperands);
- std::fill(tysToPop.begin(), tysToPop.end(), ty);
+ llvm::fill(tysToPop, ty);
auto operands = popOperands(tysToPop);
if (failed(operands))
return failure();
@@ -1000,11 +1327,23 @@ inline parsed_inst_t ExpressionParser::buildNumericOp(
BUILD_NUMERIC_BINOP_FP(CopySignOp, copysign)
BUILD_NUMERIC_BINOP_FP(DivOp, div)
+BUILD_NUMERIC_BINOP_FP(GeOp, ge)
+BUILD_NUMERIC_BINOP_FP(GtOp, gt)
+BUILD_NUMERIC_BINOP_FP(LeOp, le)
+BUILD_NUMERIC_BINOP_FP(LtOp, lt)
BUILD_NUMERIC_BINOP_FP(MaxOp, max)
BUILD_NUMERIC_BINOP_FP(MinOp, min)
BUILD_NUMERIC_BINOP_INT(AndOp, and)
BUILD_NUMERIC_BINOP_INT(DivSIOp, divS)
BUILD_NUMERIC_BINOP_INT(DivUIOp, divU)
+BUILD_NUMERIC_BINOP_INT(GeSIOp, geS)
+BUILD_NUMERIC_BINOP_INT(GeUIOp, geU)
+BUILD_NUMERIC_BINOP_INT(GtSIOp, gtS)
+BUILD_NUMERIC_BINOP_INT(GtUIOp, gtU)
+BUILD_NUMERIC_BINOP_INT(LeSIOp, leS)
+BUILD_NUMERIC_BINOP_INT(LeUIOp, leU)
+BUILD_NUMERIC_BINOP_INT(LtSIOp, ltS)
+BUILD_NUMERIC_BINOP_INT(LtUIOp, ltU)
BUILD_NUMERIC_BINOP_INT(OrOp, or)
BUILD_NUMERIC_BINOP_INT(RemSIOp, remS)
BUILD_NUMERIC_BINOP_INT(RemUIOp, remU)
@@ -1015,7 +1354,9 @@ BUILD_NUMERIC_BINOP_INT(ShRSOp, shrS)
BUILD_NUMERIC_BINOP_INT(ShRUOp, shrU)
BUILD_NUMERIC_BINOP_INT(XOrOp, xor)
BUILD_NUMERIC_BINOP_INTFP(AddOp, add)
+BUILD_NUMERIC_BINOP_INTFP(EqOp, eq)
BUILD_NUMERIC_BINOP_INTFP(MulOp, mul)
+BUILD_NUMERIC_BINOP_INTFP(NeOp, ne)
BUILD_NUMERIC_BINOP_INTFP(SubOp, sub)
BUILD_NUMERIC_UNARY_OP_FP(AbsOp, abs)
BUILD_NUMERIC_UNARY_OP_FP(CeilOp, ceil)
@@ -1025,6 +1366,7 @@ BUILD_NUMERIC_UNARY_OP_FP(SqrtOp, sqrt)
BUILD_NUMERIC_UNARY_OP_FP(TruncOp, trunc)
BUILD_NUMERIC_UNARY_OP_INT(ClzOp, clz)
BUILD_NUMERIC_UNARY_OP_INT(CtzOp, ctz)
+BUILD_NUMERIC_UNARY_OP_INT(EqzOp, eqz)
BUILD_NUMERIC_UNARY_OP_INT(PopCntOp, popcnt)
// Don't need these anymore so let's undef them.
@@ -1036,6 +1378,105 @@ BUILD_NUMERIC_UNARY_OP_INT(PopCntOp, popcnt)
#undef BUILD_NUMERIC_OP
#undef BUILD_NUMERIC_CAST_OP
+template <typename opType, typename inputType, typename outputType,
+ typename... extraArgsT>
+inline parsed_inst_t ExpressionParser::buildConvertOp(OpBuilder &builder,
+ extraArgsT... extraArgs) {
+ static_assert(std::is_arithmetic_v<inputType>,
+ "InputType should be an arithmetic type");
+ static_assert(std::is_arithmetic_v<outputType>,
+ "OutputType should be an arithmetic type");
+  auto inType = buildLiteralType<inputType>(builder);
+  auto outType = buildLiteralType<outputType>(builder);
+  auto operand = popOperands(inType);
+ if (failed(operand))
+ return failure();
+ auto op = builder.create<opType>(*currentOpLoc, outType, operand->front(),
+ extraArgs...);
+ LDBG() << "Built operation: " << op;
+ return {{op.getResult()}};
+}
+
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::demoteF64ToF32>(OpBuilder &builder) {
+ return buildConvertOp<DemoteOp, double, float>(builder);
+}
+
+template <>
+inline parsed_inst_t
+ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::wrap>(
+ OpBuilder &builder) {
+ return buildConvertOp<WrapOp, int64_t, int32_t>(builder);
+}
+
+#define BUILD_CONVERSION_OP(IN_T, OUT_T, SOURCE_OP, TARGET_OP) \
+ template <> \
+ inline parsed_inst_t ExpressionParser::parseSpecificInstruction< \
+ WasmBinaryEncoding::OpCode::SOURCE_OP>(OpBuilder & builder) { \
+ return buildConvertOp<TARGET_OP, IN_T, OUT_T>(builder); \
+ }
+
+#define BUILD_CONVERT_OP_FOR(DEST_T, WIDTH) \
+ BUILD_CONVERSION_OP(uint32_t, DEST_T, convertUI32F##WIDTH, ConvertUOp) \
+ BUILD_CONVERSION_OP(int32_t, DEST_T, convertSI32F##WIDTH, ConvertSOp) \
+ BUILD_CONVERSION_OP(uint64_t, DEST_T, convertUI64F##WIDTH, ConvertUOp) \
+ BUILD_CONVERSION_OP(int64_t, DEST_T, convertSI64F##WIDTH, ConvertSOp)
+
+BUILD_CONVERT_OP_FOR(float, 32)
+BUILD_CONVERT_OP_FOR(double, 64)
+
+#undef BUILD_CONVERT_OP_FOR
+
+BUILD_CONVERSION_OP(int32_t, int64_t, extendS, ExtendSI32Op)
+BUILD_CONVERSION_OP(int32_t, int64_t, extendU, ExtendUI32Op)
+
+#undef BUILD_CONVERSION_OP
+
+#define BUILD_SLICE_EXTEND_PARSER(IT_WIDTH, EXTRACT_WIDTH) \
+ template <> \
+ parsed_inst_t ExpressionParser::parseSpecificInstruction< \
+ WasmBinaryEncoding::OpCode::extendI##IT_WIDTH##EXTRACT_WIDTH##S>( \
+ OpBuilder & builder) { \
+ using inout_t = int##IT_WIDTH##_t; \
+ auto attr = builder.getUI32IntegerAttr(EXTRACT_WIDTH); \
+ return buildConvertOp<ExtendLowBitsSOp, inout_t, inout_t>(builder, attr); \
+ }
+
+BUILD_SLICE_EXTEND_PARSER(32, 8)
+BUILD_SLICE_EXTEND_PARSER(32, 16)
+BUILD_SLICE_EXTEND_PARSER(64, 8)
+BUILD_SLICE_EXTEND_PARSER(64, 16)
+BUILD_SLICE_EXTEND_PARSER(64, 32)
+
+#undef BUILD_SLICE_EXTEND_PARSER
+
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::promoteF32ToF64>(OpBuilder &builder) {
+ return buildConvertOp<PromoteOp, float, double>(builder);
+}
+
+#define BUILD_REINTERPRET_PARSER(WIDTH, FP_TYPE) \
+ template <> \
+ inline parsed_inst_t ExpressionParser::parseSpecificInstruction< \
+ WasmBinaryEncoding::OpCode::reinterpretF##WIDTH##AsI##WIDTH>(OpBuilder & \
+ builder) { \
+ return buildConvertOp<ReinterpretOp, FP_TYPE, int##WIDTH##_t>(builder); \
+ } \
+ \
+ template <> \
+ inline parsed_inst_t ExpressionParser::parseSpecificInstruction< \
+ WasmBinaryEncoding::OpCode::reinterpretI##WIDTH##AsF##WIDTH>(OpBuilder & \
+ builder) { \
+ return buildConvertOp<ReinterpretOp, int##WIDTH##_t, FP_TYPE>(builder); \
+ }
+
+BUILD_REINTERPRET_PARSER(32, float)
+BUILD_REINTERPRET_PARSER(64, double)
+
+#undef BUILD_REINTERPRET_PARSER
+
class WasmBinaryParser {
private:
struct SectionRegistry {
@@ -1153,7 +1594,7 @@ private:
if (tid.id >= symbols.moduleFuncTypes.size())
return emitError(loc, "invalid type id: ")
<< tid.id << ". Only " << symbols.moduleFuncTypes.size()
- << " type registration.";
+ << " type registrations";
FunctionType type = symbols.moduleFuncTypes[tid.id];
std::string symbol = symbols.getNewFuncSymbolName();
auto funcOp = FuncImportOp::create(builder, loc, symbol, moduleName,
@@ -1221,7 +1662,7 @@ public:
FileLineColLoc magicLoc = parser.getLocation();
FailureOr<StringRef> magic = parser.consumeNBytes(wasmHeader.size());
if (failed(magic) || magic->compare(wasmHeader)) {
- emitError(magicLoc, "source file does not contain valid Wasm header.");
+ emitError(magicLoc, "source file does not contain valid Wasm header");
return;
}
auto const expectedVersionString = StringRef{"\1\0\0\0", 4};
@@ -1391,7 +1832,7 @@ WasmBinaryParser::parseSectionItem<WasmSectionType::EXPORT>(ParserHead &ph,
return failure();
Operation *op = SymbolTable::lookupSymbolIn(mOp, *currentSymbol);
- SymbolTable::setSymbolVisibility(op, SymbolTable::Visibility::Public);
+ op->setAttr("exported", UnitAttr::get(op->getContext()));
StringAttr symName = SymbolTable::getSymbolName(op);
return SymbolTable{mOp}.rename(symName, *exportName);
}
diff --git a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
index 9670285..3fda5a7 100644
--- a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
+++ b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
@@ -93,7 +93,7 @@ void CodeGen::generate(const ast::Module &astModule, ModuleOp module) {
// Emit function to add the generated matchers to the pattern list.
os << "template <typename... ConfigsT>\n"
- "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns("
+ "[[maybe_unused]] static void populateGeneratedPDLLPatterns("
"::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {\n";
for (const auto &name : patternNames)
os << " patterns.add<" << name
diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index c883baa..3236b4f 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -27,6 +27,7 @@
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/ScopedPrinter.h"
+#include "llvm/Support/VirtualFileSystem.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Parser.h"
#include <optional>
@@ -828,6 +829,7 @@ LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
llvm::SourceMgr tdSrcMgr;
tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());
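+ // Resolve TableGen includes through the real file system.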
+ tdSrcMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem());
// This class provides a context argument for the llvm::SourceMgr diagnostic
// handler.
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 30fd384..9ef405d 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -37,6 +37,7 @@
#include "llvm/ADT/StringRef.h"
#include "llvm/Remarks/RemarkFormat.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/ManagedStatic.h"
@@ -226,6 +227,18 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
"bitstream", "Print bitstream file")),
llvm::cl::cat(remarkCategory)};
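+ // Select which remarks are reported: "all" prints every remark, "final"
+ // prints only the final remarks.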
+ static llvm::cl::opt<RemarkPolicy, /*ExternalStorage=*/true> remarkPolicy{
+ "remark-policy",
+ llvm::cl::desc("Specify the policy for remark output."),
+ cl::location(remarkPolicyFlag),
+ llvm::cl::value_desc("format"),
+ llvm::cl::init(RemarkPolicy::REMARK_POLICY_ALL),
+ llvm::cl::values(clEnumValN(RemarkPolicy::REMARK_POLICY_ALL, "all",
+ "Print all remarks"),
+ clEnumValN(RemarkPolicy::REMARK_POLICY_FINAL, "final",
+ "Print final remarks")),
+ llvm::cl::cat(remarkCategory)};
+
static cl::opt<std::string, /*ExternalStorage=*/true> remarksAll(
"remarks-filter",
cl::desc("Show all remarks: passed, missed, failed, analysis"),
@@ -517,18 +530,28 @@ performActions(raw_ostream &os,
return failure();
context->enableMultithreading(wasThreadingEnabled);
-
+ // Set the remark categories and policy.
remark::RemarkCategories cats{
config.getRemarksAllFilter(), config.getRemarksPassedFilter(),
config.getRemarksMissedFilter(), config.getRemarksAnalyseFilter(),
config.getRemarksFailedFilter()};
mlir::MLIRContext &ctx = *context;
+ // Helper to create the remark emitting policy selected by the configuration.
+ auto createPolicy = [&config]()
+ -> std::unique_ptr<mlir::remark::detail::RemarkEmittingPolicyBase> {
+ if (config.getRemarkPolicy() == RemarkPolicy::REMARK_POLICY_ALL)
+ return std::make_unique<mlir::remark::RemarkEmittingPolicyAll>();
+ if (config.getRemarkPolicy() == RemarkPolicy::REMARK_POLICY_FINAL)
+ return std::make_unique<mlir::remark::RemarkEmittingPolicyFinal>();
+
+ llvm_unreachable("Invalid remark policy");
+ };
switch (config.getRemarkFormat()) {
case RemarkFormat::REMARK_FORMAT_STDOUT:
if (failed(mlir::remark::enableOptimizationRemarks(
- ctx, nullptr, cats, true /*printAsEmitRemarks*/)))
+ ctx, nullptr, createPolicy(), cats, true /*printAsEmitRemarks*/)))
return failure();
break;
@@ -537,7 +560,7 @@ performActions(raw_ostream &os,
? "mlir-remarks.yaml"
: config.getRemarksOutputFile();
if (failed(mlir::remark::enableOptimizationRemarksWithLLVMStreamer(
- ctx, file, llvm::remarks::Format::YAML, cats)))
+ ctx, file, llvm::remarks::Format::YAML, createPolicy(), cats)))
return failure();
break;
}
@@ -547,7 +570,7 @@ performActions(raw_ostream &os,
? "mlir-remarks.bitstream"
: config.getRemarksOutputFile();
if (failed(mlir::remark::enableOptimizationRemarksWithLLVMStreamer(
- ctx, file, llvm::remarks::Format::Bitstream, cats)))
+ ctx, file, llvm::remarks::Format::Bitstream, createPolicy(), cats)))
return failure();
break;
}
@@ -593,6 +616,12 @@ performActions(raw_ostream &os,
AsmState asmState(op.get(), OpPrintingFlags(), /*locationMap=*/nullptr,
&fallbackResourceMap);
os << OpWithState(op.get(), asmState) << '\n';
+
+ // Finalize the remark engine explicitly; with the "final" policy, remarks
+ // are only emitted once the policy is finalized.
+ if (remark::detail::RemarkEngine *engine = ctx.getRemarkEngine())
+ engine->getRemarkEmittingPolicy()->finalize();
+
return success();
}
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
index 60b9567..1dbe7eca 100644
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
@@ -31,6 +31,7 @@
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/LSP/Logging.h"
#include "llvm/Support/Path.h"
+#include "llvm/Support/VirtualFileSystem.h"
#include <optional>
using namespace mlir;
@@ -402,6 +403,7 @@ PDLDocument::PDLDocument(const llvm::lsp::URIForFile &uri, StringRef contents,
llvm::append_range(includeDirs, extraDirs);
sourceMgr.setIncludeDirs(includeDirs);
+ sourceMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem());
sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
astContext.getDiagEngine().setHandlerFn([&](const ast::Diagnostic &diag) {
diff --git a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp
index 3080b78..2d817be 100644
--- a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp
+++ b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp
@@ -17,6 +17,7 @@
#include "llvm/Support/LSP/Logging.h"
#include "llvm/Support/LSP/Protocol.h"
#include "llvm/Support/Path.h"
+#include "llvm/Support/VirtualFileSystem.h"
#include "llvm/TableGen/Parser.h"
#include "llvm/TableGen/Record.h"
#include <optional>
@@ -448,6 +449,7 @@ void TableGenTextFile::initialize(
return;
}
sourceMgr.setIncludeDirs(includeDirs);
+ sourceMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem());
sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
// This class provides a context argument for the SourceMgr diagnostic
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index 111f58e..5f3b04a 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -66,7 +66,9 @@ size_t mlir::moveLoopInvariantCode(
size_t numMoved = 0;
for (Region *region : regions) {
- LDBG() << "Original loop:\n" << *region->getParentOp();
+ LDBG() << "Original loop:\n"
+ << OpWithFlags(region->getParentOp(),
+ OpPrintingFlags().skipRegions());
std::queue<Operation *> worklist;
// Add top-level operations in the loop body to the worklist.
@@ -90,7 +92,8 @@ size_t mlir::moveLoopInvariantCode(
!canBeHoisted(op, definedOutside))
continue;
- LDBG() << "Moving loop-invariant op: " << *op;
+ LDBG() << "Moving loop-invariant op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
moveOutOfRegion(op, region);
++numMoved;
@@ -111,9 +114,7 @@ size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
[&](Value value, Region *) {
return loopLike.isDefinedOutsideOfLoop(value);
},
- [&](Operation *op, Region *) {
- return isMemoryEffectFree(op) && isSpeculatable(op);
- },
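+ // isPure(op) is equivalent to isMemoryEffectFree(op) && isSpeculatable(op).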
+ [&](Operation *op, Region *) { return isPure(op); },
[&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
}