aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp36
-rw-r--r--mlir/lib/Analysis/CMakeLists.txt2
-rw-r--r--mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp127
-rw-r--r--mlir/lib/Conversion/CMakeLists.txt1
-rw-r--r--mlir/lib/Conversion/MathToXeVM/CMakeLists.txt22
-rw-r--r--mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp167
-rw-r--r--mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp288
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp7
-rw-r--r--mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp32
-rw-r--r--mlir/lib/Dialect/AMX/IR/AMXDialect.cpp99
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp5
-rw-r--r--mlir/lib/Dialect/LLVMIR/CMakeLists.txt2
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp17
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp154
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h27
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp74
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp3
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp11
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp153
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp89
-rw-r--r--mlir/lib/Dialect/MemRef/IR/CMakeLists.txt3
-rw-r--r--mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp59
-rw-r--r--mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp261
-rw-r--r--mlir/lib/Dialect/Tensor/IR/TensorOps.cpp1
-rw-r--r--mlir/lib/Dialect/Vector/IR/VectorOps.cpp105
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp62
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp52
-rw-r--r--mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp17
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp7
-rw-r--r--mlir/lib/IR/Diagnostics.cpp6
-rw-r--r--mlir/lib/IR/MLIRContext.cpp15
-rw-r--r--mlir/lib/IR/Remarks.cpp57
-rw-r--r--mlir/lib/Interfaces/CMakeLists.txt16
-rw-r--r--mlir/lib/Interfaces/InferIntRangeInterface.cpp19
-rw-r--r--mlir/lib/Interfaces/InferStridedMetadataInterface.cpp36
-rw-r--r--mlir/lib/Remark/RemarkStreamer.cpp4
-rw-r--r--mlir/lib/Target/LLVMIR/DebugImporter.cpp4
-rw-r--r--mlir/lib/Target/Wasm/TranslateFromWasm.cpp2
-rw-r--r--mlir/lib/Tools/PDLL/Parser/Parser.cpp2
-rw-r--r--mlir/lib/Tools/mlir-opt/MlirOptMain.cpp37
-rw-r--r--mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp2
-rw-r--r--mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp2
-rw-r--r--mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp11
43 files changed, 1820 insertions, 276 deletions
diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
index 8062b474..a84d10d 100644
--- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
+++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
@@ -258,6 +258,39 @@ getAllocEffectFor(Value value,
return success();
}
+static Operation *isDistinctObjectsOp(Operation *op) {
+ if (op && op->hasTrait<OpTrait::DistinctObjectsTrait>())
+ return op;
+
+ return nullptr;
+}
+
+static Value getDistinctObjectsOperand(Operation *op, Value value) {
+ unsigned argNumber = cast<OpResult>(value).getResultNumber();
+ return op->getOperand(argNumber);
+}
+
+static std::optional<AliasResult> checkDistinctObjects(Value lhs, Value rhs) {
+ // We should already checked that lhs and rhs are different.
+ assert(lhs != rhs && "lhs and rhs must be different");
+
+ // Result and corresponding operand must alias.
+ auto lhsOp = isDistinctObjectsOp(lhs.getDefiningOp());
+ if (lhsOp && getDistinctObjectsOperand(lhsOp, lhs) == rhs)
+ return AliasResult::MustAlias;
+
+ auto rhsOp = isDistinctObjectsOp(rhs.getDefiningOp());
+ if (rhsOp && getDistinctObjectsOperand(rhsOp, rhs) == lhs)
+ return AliasResult::MustAlias;
+
+ // If two different values come from the same `DistinctObjects` operation,
+ // they don't alias.
+ if (lhsOp && lhsOp == rhsOp)
+ return AliasResult::NoAlias;
+
+ return std::nullopt;
+}
+
/// Given the two values, return their aliasing behavior.
AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) {
if (lhs == rhs)
@@ -289,6 +322,9 @@ AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) {
: AliasResult::MayAlias;
}
+ if (std::optional<AliasResult> result = checkDistinctObjects(lhs, rhs))
+ return *result;
+
// Otherwise, neither of the values are constant so check to see if either has
// an allocation effect.
bool lhsHasAlloc = succeeded(getAllocEffectFor(lhs, lhsAlloc, lhsAllocScope));
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index 609cb34..db10ebc 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -40,6 +40,7 @@ add_mlir_library(MLIRAnalysis
DataFlow/IntegerRangeAnalysis.cpp
DataFlow/LivenessAnalysis.cpp
DataFlow/SparseAnalysis.cpp
+ DataFlow/StridedMetadataRangeAnalysis.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Analysis
@@ -53,6 +54,7 @@ add_mlir_library(MLIRAnalysis
MLIRDataLayoutInterfaces
MLIRFunctionInterfaces
MLIRInferIntRangeInterface
+ MLIRInferStridedMetadataInterface
MLIRInferTypeOpInterface
MLIRLoopLikeInterface
MLIRPresburger
diff --git a/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp
new file mode 100644
index 0000000..01c9daf
--- /dev/null
+++ b/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp
@@ -0,0 +1,127 @@
+//===- StridedMetadataRangeAnalysis.cpp - Integer range analysis --------*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the dataflow analysis class for integer range inference
+// which is used in transformations over the `arith` dialect such as
+// branch elimination or signed->unsigned rewriting
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h"
+#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/DebugStringHelper.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
+
+#define DEBUG_TYPE "strided-metadata-range-analysis"
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+/// Get the entry state for a value. For any value that is not a ranked memref,
+/// this function sets the metadata to a top state with no offsets, sizes, or
+/// strides. For `memref` types, this function will use the metadata in the type
+/// to try to deduce as much informaiton as possible.
+static StridedMetadataRange getEntryStateImpl(Value v, int32_t indexBitwidth) {
+ // TODO: generalize this method with a type interface.
+ auto mTy = dyn_cast<BaseMemRefType>(v.getType());
+
+ // If not a memref or it's un-ranked, don't infer any metadata.
+ if (!mTy || !mTy.hasRank())
+ return StridedMetadataRange::getMaxRanges(indexBitwidth, 0, 0, 0);
+
+ // Get the top state.
+ auto metadata =
+ StridedMetadataRange::getMaxRanges(indexBitwidth, mTy.getRank());
+
+ // Compute the offset and strides.
+ int64_t offset;
+ SmallVector<int64_t> strides;
+ if (failed(cast<MemRefType>(mTy).getStridesAndOffset(strides, offset)))
+ return metadata;
+
+ // Refine the metadata if we know it from the type.
+ if (!ShapedType::isDynamic(offset)) {
+ metadata.getOffsets()[0] =
+ ConstantIntRanges::constant(APInt(indexBitwidth, offset));
+ }
+ for (auto &&[size, range] :
+ llvm::zip_equal(mTy.getShape(), metadata.getSizes())) {
+ if (ShapedType::isDynamic(size))
+ continue;
+ range = ConstantIntRanges::constant(APInt(indexBitwidth, size));
+ }
+ for (auto &&[stride, range] :
+ llvm::zip_equal(strides, metadata.getStrides())) {
+ if (ShapedType::isDynamic(stride))
+ continue;
+ range = ConstantIntRanges::constant(APInt(indexBitwidth, stride));
+ }
+
+ return metadata;
+}
+
+StridedMetadataRangeAnalysis::StridedMetadataRangeAnalysis(
+ DataFlowSolver &solver, int32_t indexBitwidth)
+ : SparseForwardDataFlowAnalysis(solver), indexBitwidth(indexBitwidth) {
+ assert(indexBitwidth > 0 && "invalid bitwidth");
+}
+
+void StridedMetadataRangeAnalysis::setToEntryState(
+ StridedMetadataRangeLattice *lattice) {
+ propagateIfChanged(lattice, lattice->join(getEntryStateImpl(
+ lattice->getAnchor(), indexBitwidth)));
+}
+
+LogicalResult StridedMetadataRangeAnalysis::visitOperation(
+ Operation *op, ArrayRef<const StridedMetadataRangeLattice *> operands,
+ ArrayRef<StridedMetadataRangeLattice *> results) {
+ auto inferrable = dyn_cast<InferStridedMetadataOpInterface>(op);
+
+ // Bail if we cannot reason about the op.
+ if (!inferrable) {
+ setAllToEntryStates(results);
+ return success();
+ }
+
+ LDBG() << "Inferring metadata for: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
+
+ // Helper function to retrieve int range values.
+ auto getIntRange = [&](Value value) -> IntegerValueRange {
+ auto lattice = getOrCreateFor<IntegerValueRangeLattice>(
+ getProgramPointAfter(op), value);
+ return lattice ? lattice->getValue() : IntegerValueRange();
+ };
+
+ // Convert the arguments lattices to a vector.
+ SmallVector<StridedMetadataRange> argRanges = llvm::map_to_vector(
+ operands, [](const StridedMetadataRangeLattice *lattice) {
+ return lattice->getValue();
+ });
+
+ // Callback to set metadata on a result.
+ auto joinCallback = [&](Value v, const StridedMetadataRange &md) {
+ auto result = cast<OpResult>(v);
+ assert(llvm::is_contained(op->getResults(), result));
+ LDBG() << "- Inferred metadata: " << md;
+ StridedMetadataRangeLattice *lattice = results[result.getResultNumber()];
+ ChangeResult changed = lattice->join(md);
+ LDBG() << "- Joined metadata: " << lattice->getValue();
+ propagateIfChanged(lattice, changed);
+ };
+
+ // Infer the metadata.
+ inferrable.inferStridedMetadataRanges(argRanges, getIntRange, joinCallback,
+ indexBitwidth);
+ return success();
+}
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 71986f8..bebf1b8 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -40,6 +40,7 @@ add_subdirectory(MathToLibm)
add_subdirectory(MathToLLVM)
add_subdirectory(MathToROCDL)
add_subdirectory(MathToSPIRV)
+add_subdirectory(MathToXeVM)
add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
add_subdirectory(MemRefToSPIRV)
diff --git a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
new file mode 100644
index 0000000..050c0ed
--- /dev/null
+++ b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
@@ -0,0 +1,22 @@
+add_mlir_conversion_library(MLIRMathToXeVM
+ MathToXeVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToXeVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRArithAttrToLLVMConversion
+ MLIRArithDialect
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRMathDialect
+ MLIRXeVMDialect
+ MLIRPass
+ MLIRTransforms
+ )
diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
new file mode 100644
index 0000000..0fe31d0
--- /dev/null
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -0,0 +1,167 @@
+//===-- MathToXeVM.cpp - conversion from Math to XeVM ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
+#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/FormatVariadic.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMATHTOXEVM
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "math-to-xevm"
+
+/// Convert math ops marked with `fast` (`afn`) to native OpenCL intrinsics.
+template <typename Op>
+struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
+
+ ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {}
+
+ LogicalResult
+ matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!isSPIRVCompatibleFloatOrVec(op.getType()))
+ return failure();
+
+ arith::FastMathFlags fastFlags = op.getFastmath();
+ if (!arith::bitEnumContainsAll(fastFlags, arith::FastMathFlags::afn))
+ return rewriter.notifyMatchFailure(op, "not a fastmath `afn` operation");
+
+ SmallVector<Type, 1> operandTypes;
+ for (auto operand : adaptor.getOperands()) {
+ Type opTy = operand.getType();
+ // This pass only supports operations on vectors that are already in SPIRV
+ // supported vector sizes: Distributing unsupported vector sizes to SPIRV
+ // supported vector sizes are done in other blocking optimization passes.
+ if (!isSPIRVCompatibleFloatOrVec(opTy))
+ return rewriter.notifyMatchFailure(
+ op, llvm::formatv("incompatible operand type: '{0}'", opTy));
+ operandTypes.push_back(opTy);
+ }
+
+ auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
+ auto funcOpRes = LLVM::lookupOrCreateFn(
+ rewriter, moduleOp, getMangledNativeFuncName(operandTypes),
+ operandTypes, op.getType());
+ assert(!failed(funcOpRes));
+ LLVM::LLVMFuncOp funcOp = funcOpRes.value();
+
+ auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+ op, funcOp, adaptor.getOperands());
+ // Preserve fastmath flags in our MLIR op when converting to llvm function
+ // calls, in order to allow further fastmath optimizations: We thus need to
+ // convert arith fastmath attrs into attrs recognized by llvm.
+ arith::AttrConvertFastMathToLLVM<Op, LLVM::CallOp> fastAttrConverter(op);
+ mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0];
+ callOp->setAttr(fastAttr.getName(), fastAttr.getValue());
+ return success();
+ }
+
+ inline bool isSPIRVCompatibleFloatOrVec(Type type) const {
+ if (type.isFloat())
+ return true;
+ if (auto vecType = dyn_cast<VectorType>(type)) {
+ if (!vecType.getElementType().isFloat())
+ return false;
+ // SPIRV distinguishes between vectors and matrices: OpenCL native math
+ // intrsinics are not compatible with matrices.
+ ArrayRef<int64_t> shape = vecType.getShape();
+ if (shape.size() != 1)
+ return false;
+ // SPIRV only allows vectors of size 2, 3, 4, 8, 16.
+ if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 ||
+ shape[0] == 16)
+ return true;
+ }
+ return false;
+ }
+
+ inline std::string
+ getMangledNativeFuncName(const ArrayRef<Type> operandTypes) const {
+ std::string mangledFuncName =
+ "_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str();
+
+ auto appendFloatToMangledFunc = [&mangledFuncName](Type type) {
+ if (type.isF32())
+ mangledFuncName += "f";
+ else if (type.isF16())
+ mangledFuncName += "Dh";
+ else if (type.isF64())
+ mangledFuncName += "d";
+ };
+
+ for (auto type : operandTypes) {
+ if (auto vecType = dyn_cast<VectorType>(type)) {
+ mangledFuncName += "Dv" + std::to_string(vecType.getShape()[0]) + "_";
+ appendFloatToMangledFunc(vecType.getElementType());
+ } else
+ appendFloatToMangledFunc(type);
+ }
+
+ return mangledFuncName;
+ }
+
+ const StringRef nativeFunc;
+};
+
+void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns,
+ bool convertArith) {
+ patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(patterns.getContext(),
+ "__spirv_ocl_native_exp");
+ patterns.add<ConvertNativeFuncPattern<math::CosOp>>(patterns.getContext(),
+ "__spirv_ocl_native_cos");
+ patterns.add<ConvertNativeFuncPattern<math::Exp2Op>>(
+ patterns.getContext(), "__spirv_ocl_native_exp2");
+ patterns.add<ConvertNativeFuncPattern<math::LogOp>>(patterns.getContext(),
+ "__spirv_ocl_native_log");
+ patterns.add<ConvertNativeFuncPattern<math::Log2Op>>(
+ patterns.getContext(), "__spirv_ocl_native_log2");
+ patterns.add<ConvertNativeFuncPattern<math::Log10Op>>(
+ patterns.getContext(), "__spirv_ocl_native_log10");
+ patterns.add<ConvertNativeFuncPattern<math::PowFOp>>(
+ patterns.getContext(), "__spirv_ocl_native_powr");
+ patterns.add<ConvertNativeFuncPattern<math::RsqrtOp>>(
+ patterns.getContext(), "__spirv_ocl_native_rsqrt");
+ patterns.add<ConvertNativeFuncPattern<math::SinOp>>(patterns.getContext(),
+ "__spirv_ocl_native_sin");
+ patterns.add<ConvertNativeFuncPattern<math::SqrtOp>>(
+ patterns.getContext(), "__spirv_ocl_native_sqrt");
+ patterns.add<ConvertNativeFuncPattern<math::TanOp>>(patterns.getContext(),
+ "__spirv_ocl_native_tan");
+ if (convertArith)
+ patterns.add<ConvertNativeFuncPattern<arith::DivFOp>>(
+ patterns.getContext(), "__spirv_ocl_native_divide");
+}
+
+namespace {
+struct ConvertMathToXeVMPass
+ : public impl::ConvertMathToXeVMBase<ConvertMathToXeVMPass> {
+ using Base::Base;
+ void runOnOperation() override;
+};
+} // namespace
+
+void ConvertMathToXeVMPass::runOnOperation() {
+ RewritePatternSet patterns(&getContext());
+ populateMathToXeVMConversionPatterns(patterns, convertArith);
+ ConversionTarget target(getContext());
+ target.addLegalDialect<BuiltinDialect, LLVM::LLVMDialect>();
+ if (failed(
+ applyPartialConversion(getOperation(), target, std::move(patterns))))
+ signalPassFailure();
+}
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index a5336ed..00df14b1 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1392,6 +1392,137 @@ public:
}
};
+// Collapse tensor<1xiN> into tensor<iN>
+// E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
+static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
+ Location loc) {
+ SmallVector<ReassociationExprs, 1> reassociation;
+ // Create the collapsed type
+ auto inputType = cast<RankedTensorType>(input.getType());
+ auto elemType = inputType.getElementType();
+ auto collapsedType = RankedTensorType::get({}, elemType);
+ // Emit the collapse op
+ return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, input,
+ reassociation);
+}
+
+static llvm::SmallVector<int8_t>
+convertToI8(const llvm::SmallVector<int32_t> &input) {
+ llvm::SmallVector<int8_t> output;
+ output.reserve(input.size());
+
+ for (auto v : llvm::map_range(
+ input, [](int32_t val) { return static_cast<int8_t>(val); })) {
+ output.push_back(v);
+ }
+ return output;
+}
+
+// The shift or multiplier may be either constant or non-constant, depending on
+// whether dynamic extension is enabled.
+// - If the shift or multiplier is non-constant, add it as an input to
+// linalg::GenericOp by:
+// 1. Pushing it into 'genericInputs'.
+// 2. Appending a corresponding affine map to 'indexingMaps'.
+// - If the shift or multiplier is constant, set 'constant' instead.
+static void setupLinalgGenericOpInputAndIndexingMap(
+ PatternRewriter &rewriter, llvm::SmallVector<int32_t> &values,
+ SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
+ bool isConstant, tosa::RescaleOp op, Value &constant, int64_t &arg,
+ bool isShift = false) {
+
+ auto loc = op.getLoc();
+ auto inputTy = cast<ShapedType>(op.getInput().getType());
+ unsigned rank = inputTy.getRank();
+ SmallVector<AffineExpr, 2> exprs = {rewriter.getAffineDimExpr(rank - 1)};
+
+ if (isConstant) {
+ // If we are rescaling per-channel then we need to store the
+ // values in a buffer.
+ if (values.size() == 1) {
+ IntegerAttr intAttr = isShift
+ ? rewriter.getI8IntegerAttr(values.front())
+ : rewriter.getI32IntegerAttr(values.front());
+ constant = rewriter.create<arith::ConstantOp>(loc, intAttr);
+ } else {
+ auto elementType =
+ isShift ? rewriter.getIntegerType(8) : rewriter.getI32Type();
+ auto tensorType = RankedTensorType::get(
+ {static_cast<int64_t>(values.size())}, elementType);
+ DenseIntElementsAttr EltAttr;
+ if (isShift)
+ EltAttr = DenseIntElementsAttr::get(tensorType, convertToI8(values));
+ else
+ EltAttr = DenseIntElementsAttr::get(tensorType, values);
+ genericInputs.push_back(
+ arith::ConstantOp::create(rewriter, loc, EltAttr));
+ indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+ /*symbolCount=*/0, exprs,
+ rewriter.getContext()));
+ }
+ } else {
+ // If we are not rescaling per-channel then we need to collapse 1xN to N
+ // and push broadcastMap.
+ auto operand = isShift ? op.getShift() : op.getMultiplier();
+ auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
+ if (tensorType && tensorType.hasStaticShape() &&
+ tensorType.getShape()[0] == 1) {
+ // broadcastMap = affine_map<(d0, d1) -> ()>
+ // It would affect as broadcast for scalar values in linalg::GenericOp.
+ AffineMap broadcastMap =
+ AffineMap::get(rank, 0, {}, rewriter.getContext());
+ genericInputs.push_back(collapse1xNTensorToN(rewriter, operand, loc));
+ indexingMaps.push_back(broadcastMap);
+ } else {
+ genericInputs.push_back(operand);
+ indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+ /*symbolCount=*/0, exprs,
+ rewriter.getContext()));
+ }
+ }
+ arg = indexingMaps.size() - 1;
+}
+
+// Return the extended Zp to be used in subsequent arithmetic operations.
+static Value getExtendZp(OpBuilder &builder, Type valueTy,
+ FailureOr<int64_t> maybeZp, Location loc,
+ ValueRange blockArgs, int64_t zpArg,
+ bool isOutputZp = false) {
+ Value result;
+ const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
+ const uint32_t attrBitwidth =
+ isOutputZp ? 32 : (bitwidth > 32 ? bitwidth : 32);
+ auto extendType = builder.getIntegerType(attrBitwidth);
+ // The Zp value can be either constant or non-constant, depending on
+ // whether dynamic extension is enabled.
+ // If 'maybeZp' fails, it indicates that Zp is non-constant and will
+ // be passed as an input to linalg::GenericOp.
+ if (failed(maybeZp)) {
+ result = blockArgs[zpArg];
+ auto zpTy = result.getType();
+ if (zpTy.getIntOrFloatBitWidth() < attrBitwidth) {
+ // For ExtUIOp, the input must be signless.
+ // UnrealizedConversionCastOp will cast the input to signless type.
+ if (zpTy.isUnsignedInteger()) {
+ result =
+ UnrealizedConversionCastOp::create(
+ builder, loc,
+ builder.getIntegerType(zpTy.getIntOrFloatBitWidth()), result)
+ .getResult(0);
+ }
+ if (zpTy.isUnsignedInteger()) {
+ return builder.create<arith::ExtUIOp>(loc, extendType, result);
+ } else {
+ return builder.create<arith::ExtSIOp>(loc, extendType, result);
+ }
+ }
+ } else {
+ return builder.create<arith::ConstantOp>(
+ loc, IntegerAttr::get(extendType, *maybeZp));
+ }
+ return result;
+}
+
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
@@ -1423,40 +1554,46 @@ public:
}
}
- // The shift and multiplier values.
DenseElementsAttr shiftElems;
- if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
- return rewriter.notifyMatchFailure(
- op, "tosa.rescale requires constant shift input values");
+ bool isShiftConstant = false;
+ if (matchPattern(op.getShift(), m_Constant(&shiftElems)))
+ isShiftConstant = true;
DenseElementsAttr multiplierElems;
- if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
- return rewriter.notifyMatchFailure(
- op, "tosa.rescale requires constant multiplier input values");
-
- llvm::SmallVector<int8_t> shiftValues =
- llvm::to_vector(shiftElems.getValues<int8_t>());
- // explicit cast is required here
- llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
- llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
- [](IntegerAttr attr) -> int32_t {
- return static_cast<int32_t>(attr.getInt());
- }));
-
- // If we shift by more than the bitwidth, this just sets to 0.
- for (int i = 0, s = multiplierValues.size(); i < s; i++) {
- if (shiftValues[i] > 63) {
- shiftValues[i] = 0;
- multiplierValues[i] = 0;
+ bool isMultiplierConstant = false;
+ if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
+ isMultiplierConstant = true;
+
+ llvm::SmallVector<int32_t> shiftValues;
+ llvm::SmallVector<int32_t> multiplierValues;
+ bool doubleRound;
+
+ if (isMultiplierConstant && isShiftConstant) {
+ // explicit cast is required here
+ shiftValues = llvm::to_vector(llvm::map_range(
+ shiftElems.getValues<IntegerAttr>(), [](IntegerAttr attr) -> int32_t {
+ return static_cast<int32_t>(attr.getInt());
+ }));
+ multiplierValues = llvm::to_vector(
+ llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
+ [](IntegerAttr attr) -> int32_t {
+ return static_cast<int32_t>(attr.getInt());
+ }));
+
+ // If we shift by more than the bitwidth, this just sets to 0.
+ for (int i = 0, s = multiplierValues.size(); i < s; i++) {
+ if (shiftValues[i] > 63) {
+ shiftValues[i] = 0;
+ multiplierValues[i] = 0;
+ }
}
- }
+ // Double round only occurs if shift is greater than 31, check that this
+ // is ever true.
+ doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
+ llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
+ } else
+ doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND;
- // Double round only occurs if shift is greater than 31, check that this
- // is ever true.
-
- bool doubleRound =
- op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
- llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
RoundingMode roundingMode =
doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;
@@ -1468,45 +1605,43 @@ public:
// values in a buffer.
Value multiplierConstant;
int64_t multiplierArg = 0;
- if (multiplierValues.size() == 1) {
- multiplierConstant = arith::ConstantOp::create(
- rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
- } else {
- SmallVector<AffineExpr, 2> multiplierExprs{
- rewriter.getAffineDimExpr(rank - 1)};
- auto multiplierType =
- RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
- rewriter.getI32Type());
- genericInputs.push_back(arith::ConstantOp::create(
- rewriter, loc,
- DenseIntElementsAttr::get(multiplierType, multiplierValues)));
-
- indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
- /*symbolCount=*/0, multiplierExprs,
- rewriter.getContext()));
-
- multiplierArg = indexingMaps.size() - 1;
- }
+ setupLinalgGenericOpInputAndIndexingMap(
+ rewriter, multiplierValues, genericInputs, indexingMaps,
+ isMultiplierConstant, op, multiplierConstant, multiplierArg);
// If we are rescaling per-channel then we need to store the shift
// values in a buffer.
Value shiftConstant;
int64_t shiftArg = 0;
- if (shiftValues.size() == 1) {
- shiftConstant = arith::ConstantOp::create(
- rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front()));
- } else {
- SmallVector<AffineExpr, 2> shiftExprs = {
- rewriter.getAffineDimExpr(rank - 1)};
- auto shiftType =
- RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
- rewriter.getIntegerType(8));
- genericInputs.push_back(arith::ConstantOp::create(
- rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
- indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
- /*symbolCount=*/0, shiftExprs,
- rewriter.getContext()));
- shiftArg = indexingMaps.size() - 1;
+ setupLinalgGenericOpInputAndIndexingMap(
+ rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
+ shiftConstant, shiftArg, true);
+
+ // broadcastMap = affine_map<(d0, d1) -> ()>
+ // It would affect as broadcast for scalar values in linalg::GenericOp.
+ AffineMap broadcastMap = AffineMap::get(rank, 0, {}, rewriter.getContext());
+ FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
+ FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
+ // The inputZp and outputZp may be either constant or non-constant,
+ // depending on whether dynamic extension is enabled.
+ // - If the zp's are non-constant, add them as an inputs to
+ // linalg::GenericOp by:
+ // 1. Pushing it into 'genericInputs'.
+ // 2. Appending a corresponding affine map to 'indexingMaps'.
+ // - If the zp's are constant, they would be generated as arith.constant.
+ int64_t iZpArg = 0;
+ if (failed(maybeIZp)) {
+ genericInputs.push_back(
+ collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
+ indexingMaps.push_back(broadcastMap);
+ iZpArg = indexingMaps.size() - 1;
+ }
+ int64_t oZpArg = 0;
+ if (failed(maybeOZp)) {
+ genericInputs.push_back(
+ collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
+ indexingMaps.push_back(broadcastMap);
+ oZpArg = indexingMaps.size() - 1;
}
// Indexing maps for output values.
@@ -1526,36 +1661,17 @@ public:
Type valueTy = value.getType();
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
- if (failed(maybeIZp)) {
- (void)rewriter.notifyMatchFailure(
- op, "input zero point cannot be statically determined");
- return;
- }
-
- const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
- // Extend zeropoint for sub-32bits widths.
- const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
- auto inputZp = arith::ConstantOp::create(
- nestedBuilder, loc,
- IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
- *maybeIZp));
+ auto inputZp = getExtendZp(nestedBuilder, valueTy, maybeIZp,
+ nestedLoc, blockArgs, iZpArg);
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
- if (failed(maybeOZp)) {
- (void)rewriter.notifyMatchFailure(
- op, "output zero point cannot be statically determined");
- return;
- };
+ auto outputZp = getExtendZp(nestedBuilder, valueTy, maybeOZp,
+ nestedLoc, blockArgs, oZpArg, true);
IntegerType outIntType =
cast<IntegerType>(blockArgs.back().getType());
unsigned outBitWidth = outIntType.getWidth();
- const int32_t outAttrBitwidth = 32;
assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
- auto outputZp = arith::ConstantOp::create(
- nestedBuilder, loc,
- IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
- *maybeOZp));
Value multiplier = multiplierConstant ? multiplierConstant
: blockArgs[multiplierArg];
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 5355909..41d8d53 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1723,17 +1723,18 @@ struct VectorBroadcastScalarToLowRankLowering
return success();
}
- // For 1-d vector, we additionally do a `vectorshuffle`.
auto v =
LLVM::InsertElementOp::create(rewriter, broadcast.getLoc(), vectorType,
poison, adaptor.getSource(), zero);
+ // For 1-d vector, we additionally do a `shufflevector`.
int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0);
SmallVector<int32_t> zeroValues(width, 0);
// Shuffle the value across the desired number of elements.
- rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(broadcast, v, poison,
- zeroValues);
+ auto shuffle = rewriter.createOrFold<LLVM::ShuffleVectorOp>(
+ broadcast.getLoc(), v, poison, zeroValues);
+ rewriter.replaceOp(broadcast, shuffle);
return success();
}
};
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 71687b1..ddcbc44 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -20,6 +20,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
@@ -390,7 +391,8 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
// Load result or Store valye Type can be vector or scalar.
Type valOrResTy;
if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
- valOrResTy = op.getResult().getType();
+ valOrResTy =
+ this->getTypeConverter()->convertType(op.getResult().getType());
else
valOrResTy = adaptor.getValue().getType();
VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
@@ -878,10 +880,30 @@ struct ConvertXeGPUToXeVMPass
}
return {};
};
- typeConverter.addSourceMaterialization(memrefMaterializationCast);
- typeConverter.addSourceMaterialization(ui64MaterializationCast);
- typeConverter.addSourceMaterialization(ui32MaterializationCast);
- typeConverter.addSourceMaterialization(vectorMaterializationCast);
+
+ // If result type of original op is single element vector and lowered type
+ // is scalar. This materialization cast creates a single element vector by
+ // broadcasting the scalar value.
+ auto singleElementVectorMaterializationCast =
+ [](OpBuilder &builder, Type type, ValueRange inputs,
+ Location loc) -> Value {
+ if (inputs.size() != 1)
+ return {};
+ auto input = inputs.front();
+ if (input.getType().isIntOrIndexOrFloat()) {
+ // If the input is a scalar, and the target type is a vector of single
+ // element, create a single element vector by broadcasting.
+ if (auto vecTy = dyn_cast<VectorType>(type)) {
+ if (vecTy.getNumElements() == 1) {
+ return vector::BroadcastOp::create(builder, loc, vecTy, input)
+ .getResult();
+ }
+ }
+ }
+ return {};
+ };
+ typeConverter.addSourceMaterialization(
+ singleElementVectorMaterializationCast);
typeConverter.addTargetMaterialization(memrefMaterializationCast);
typeConverter.addTargetMaterialization(ui32MaterializationCast);
typeConverter.addTargetMaterialization(ui64MaterializationCast);
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index 68990ef..d9c097c 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -80,10 +80,22 @@ static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType,
LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)};
}
+/// Returns stride expressed in number of bytes for the given `elementStride`
+/// stride encoded in number of elements of the type `mType`.
+static Value computeStrideInBytes(Location loc, MemRefType mType,
+ Value elementStride, RewriterBase &rewriter) {
+ Type llvmInt64Type = rewriter.getIntegerType(64);
+ unsigned bytes = mType.getElementType().getIntOrFloatBitWidth() / 8;
+ auto attr = rewriter.getI64IntegerAttr(bytes);
+ Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
+ return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, elementStride)
+ .getResult();
+}
+
/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
/// shape may "envelop" the actual tile shape, and may be dynamically sized.
-static Value getStride(Location loc, MemRefType mType, Value base,
- RewriterBase &rewriter) {
+static Value inferStride(Location loc, MemRefType mType, Value base,
+ RewriterBase &rewriter) {
assert(mType.getRank() >= 2 && "Invalid shape for AMX strides");
int64_t preLast = mType.getRank() - 2;
Type llvmInt64Type = rewriter.getIntegerType(64);
@@ -94,11 +106,8 @@ static Value getStride(Location loc, MemRefType mType, Value base,
if (strides[preLast] == ShapedType::kDynamic) {
// Dynamic stride needs code to compute the stride at runtime.
MemRefDescriptor memrefDescriptor(base);
- auto attr = rewriter.getI64IntegerAttr(bytes);
- Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
- return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale,
- memrefDescriptor.stride(rewriter, loc, preLast))
- .getResult();
+ return computeStrideInBytes(
+ loc, mType, memrefDescriptor.stride(rewriter, loc, preLast), rewriter);
}
// Use direct constant for static stride.
auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
@@ -117,21 +126,39 @@ amx::TileZeroOp::getIntrinsicOperands(ArrayRef<Value> operands,
return getTileSizes(getLoc(), getTileType(), rewriter);
}
-LogicalResult amx::TileLoadOp::verify() {
- MemRefType memrefTy = getMemRefType();
+template <typename OpTy,
+ typename = std::enable_if_t<std::is_same_v<OpTy, amx::TileLoadOp> ||
+ std::is_same_v<OpTy, amx::TileStoreOp>>>
+static LogicalResult tileTransferVerifier(OpTy op) {
+ MemRefType memrefTy = op.getMemRefType();
unsigned rank = memrefTy.getRank();
- if (rank < 2)
- return emitOpError("requires at least 2D memref");
- if (getIndices().size() != rank)
- return emitOpError("requires ") << rank << " indices";
- SmallVector<int64_t> strides;
- int64_t offset;
- if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
- strides.back() != 1)
- return emitOpError("requires memref with unit innermost stride");
- return verifyTileSize(*this, getTileType());
+ if (op.getIndices().size() != rank)
+ return op.emitOpError("requires ") << rank << " indices";
+
+ if (failed(verifyTileSize(op, op.getTileType())))
+ return failure();
+
+ // Validate basic buffer properties when the stride is implicit.
+ if (!op.getStride()) {
+ if (rank < 2)
+ return op.emitOpError("requires at least 2D memref");
+ SmallVector<int64_t> strides;
+ int64_t offset;
+ if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
+ strides.back() != 1)
+ return op.emitOpError("requires memref with unit innermost stride");
+ }
+
+ return success();
+}
+
+void amx::TileLoadOp::build(OpBuilder &builder, OperationState &state, Type res,
+ Value base, ValueRange indices) {
+ build(builder, state, res, base, indices, /*stride=*/nullptr);
}
+LogicalResult amx::TileLoadOp::verify() { return tileTransferVerifier(*this); }
+
SmallVector<Value>
amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands,
const LLVMTypeConverter &typeConverter,
@@ -144,27 +171,23 @@ amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands,
intrinsicOperands.push_back(
LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
adaptor.getBase(), adaptor.getIndices()));
- intrinsicOperands.push_back(
- getStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
+ if (Value stride = adaptor.getStride())
+ intrinsicOperands.push_back(
+ computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
+ else
+ intrinsicOperands.push_back(
+ inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
return intrinsicOperands;
}
-LogicalResult amx::TileStoreOp::verify() {
- MemRefType memrefTy = getMemRefType();
- unsigned rank = memrefTy.getRank();
- if (rank < 2)
- return emitOpError("requires at least 2D memref");
- if (getIndices().size() != rank)
- return emitOpError("requires ") << rank << " indices";
- SmallVector<int64_t> strides;
- int64_t offset;
- if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
- strides.back() != 1)
- return emitOpError("requires memref with unit innermost stride");
- return verifyTileSize(*this, getTileType());
+void amx::TileStoreOp::build(OpBuilder &builder, OperationState &state,
+ Value base, ValueRange indices, Value val) {
+ build(builder, state, base, indices, val, /*stride=*/nullptr);
}
+LogicalResult amx::TileStoreOp::verify() { return tileTransferVerifier(*this); }
+
SmallVector<Value>
amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands,
const LLVMTypeConverter &typeConverter,
@@ -177,8 +200,12 @@ amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands,
intrinsicOperands.push_back(
LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
adaptor.getBase(), adaptor.getIndices()));
- intrinsicOperands.push_back(
- getStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
+ if (Value stride = adaptor.getStride())
+ intrinsicOperands.push_back(
+ computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
+ else
+ intrinsicOperands.push_back(
+ inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
intrinsicOperands.push_back(adaptor.getVal());
return intrinsicOperands;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
index 624519f..70faa71 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
@@ -64,12 +64,13 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
module.walk([&](func::CallOp callOp) {
if (func::FuncOp calledFunc =
dyn_cast_or_null<func::FuncOp>(callOp.resolveCallable())) {
- callerMap[calledFunc].insert(callOp);
+ if (!calledFunc.isPublic() && !calledFunc.isExternal())
+ callerMap[calledFunc].insert(callOp);
}
});
for (auto funcOp : module.getOps<func::FuncOp>()) {
- if (funcOp.isExternal())
+ if (funcOp.isExternal() || funcOp.isPublic())
continue;
func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
// TODO: Support functions with multiple blocks.
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index ec581ac..cc66fac 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -8,11 +8,13 @@ add_mlir_dialect_library(MLIRLLVMDialect
IR/LLVMMemorySlot.cpp
IR/LLVMTypes.cpp
IR/LLVMTypeSyntax.cpp
+ IR/LLVMDialectBytecode.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR
DEPENDS
+ MLIRLLVMDialectBytecodeIncGen
MLIRLLVMOpsIncGen
MLIRLLVMTypesIncGen
MLIRLLVMIntrinsicOpsIncGen
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 5d08ccc..3eae67f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -29,6 +29,8 @@
#include "llvm/IR/DataLayout.h"
#include "llvm/Support/Error.h"
+#include "LLVMDialectBytecode.h"
+
#include <numeric>
#include <optional>
@@ -2824,6 +2826,20 @@ LogicalResult ShuffleVectorOp::verify() {
return success();
}
+// Folding for shufflevector op when v1 is single element 1D vector
+// and the mask is a single zero. OpFoldResult will be v1 in this case.
+OpFoldResult ShuffleVectorOp::fold(FoldAdaptor adaptor) {
+ // Check if operand 0 is a single element vector.
+ auto vecType = llvm::dyn_cast<VectorType>(getV1().getType());
+ if (!vecType || vecType.getRank() != 1 || vecType.getNumElements() != 1)
+ return {};
+ // Check if the mask is a single zero.
+ // Note: The mask is guaranteed to be non-empty.
+ if (getMask().size() != 1 || getMask()[0] != 0)
+ return {};
+ return getV1();
+}
+
//===----------------------------------------------------------------------===//
// Implementations for LLVM::LLVMFuncOp.
//===----------------------------------------------------------------------===//
@@ -4237,6 +4253,7 @@ void LLVMDialect::initialize() {
// Support unknown operations because not all LLVM operations are registered.
allowUnknownOperations();
declarePromisedInterface<DialectInlinerInterface, LLVMDialect>();
+ detail::addBytecodeInterface(this);
}
#define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp
new file mode 100644
index 0000000..41d1f80
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp
@@ -0,0 +1,154 @@
+//===- LLVMDialectBytecode.cpp - LLVM Bytecode Implementation -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "LLVMDialectBytecode.h"
+#include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <type_traits>
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+namespace {
+
+// Provide some forward declarations of the functions that will be generated by
+// the include below.
+static void write(DIExpressionElemAttr attribute,
+ DialectBytecodeWriter &writer);
+static LogicalResult writeAttribute(Attribute attribute,
+ DialectBytecodeWriter &writer);
+
+//===--------------------------------------------------------------------===//
+// Optional ArrayRefs
+//
+// Note that both the writer and reader functions consider attributes to be
+// optional. This is because the attribute may be present or empty.
+//===--------------------------------------------------------------------===//
+
+template <class EntryTy>
+static void writeOptionalArrayRef(DialectBytecodeWriter &writer,
+ ArrayRef<EntryTy> storage) {
+ if (storage.empty()) {
+ writer.writeOwnedBool(false);
+ return;
+ }
+
+ writer.writeOwnedBool(true);
+ writer.writeList(storage, [&](EntryTy val) {
+ if constexpr (std::is_base_of_v<Attribute, EntryTy>) {
+ (void)writer.writeOptionalAttribute(val);
+ } else if constexpr (std::is_integral_v<EntryTy>) {
+ (void)writer.writeVarInt(val);
+ } else {
+ static_assert(true, "EntryTy not supported");
+ }
+ });
+}
+
+template <class EntryTy>
+static LogicalResult readOptionalArrayRef(DialectBytecodeReader &reader,
+ SmallVectorImpl<EntryTy> &storage) {
+ bool isPresent = false;
+ if (failed(reader.readBool(isPresent)))
+ return failure();
+ // Nothing to do here, the array is empty.
+ if (!isPresent)
+ return success();
+
+ auto readEntry = [&]() -> FailureOr<EntryTy> {
+ EntryTy temp;
+ if constexpr (std::is_base_of_v<Attribute, EntryTy>) {
+ if (succeeded(reader.readOptionalAttribute(temp)))
+ return temp;
+ } else if constexpr (std::is_integral_v<EntryTy>) {
+ if (succeeded(reader.readVarInt(temp)))
+ return temp;
+ } else {
+ static_assert(true, "EntryTy not supported");
+ }
+ return failure();
+ };
+
+ return reader.readList(storage, readEntry);
+}
+
+//===--------------------------------------------------------------------===//
+// Optional integral types
+//===--------------------------------------------------------------------===//
+
+template <class EntryTy>
+static void writeOptionalInt(DialectBytecodeWriter &writer,
+ std::optional<EntryTy> storage) {
+ static_assert(std::is_integral_v<EntryTy>,
+ "EntryTy must be an integral type");
+ EntryTy val = storage.value_or(0);
+ writer.writeVarIntWithFlag(val, storage.has_value());
+}
+
+template <class EntryTy>
+static LogicalResult readOptionalInt(DialectBytecodeReader &reader,
+ std::optional<EntryTy> &storage) {
+ static_assert(std::is_integral_v<EntryTy>,
+ "EntryTy must be an integral type");
+ uint64_t result = 0;
+ bool flag = false;
+ if (failed(reader.readVarIntWithFlag(result, flag)))
+ return failure();
+ if (flag)
+ storage = static_cast<EntryTy>(result);
+ else
+ storage = std::nullopt;
+ return success();
+}
+
+//===--------------------------------------------------------------------===//
+// Tablegen generated bytecode functions
+//===--------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/LLVMDialectBytecode.cpp.inc"
+
+//===--------------------------------------------------------------------===//
+// LLVMDialectBytecodeInterface
+//===--------------------------------------------------------------------===//
+
+/// This class implements the bytecode interface for the LLVM dialect.
+struct LLVMDialectBytecodeInterface : public BytecodeDialectInterface {
+ LLVMDialectBytecodeInterface(Dialect *dialect)
+ : BytecodeDialectInterface(dialect) {}
+
+ // Attributes
+ Attribute readAttribute(DialectBytecodeReader &reader) const override {
+ return ::readAttribute(getContext(), reader);
+ }
+
+ LogicalResult writeAttribute(Attribute attr,
+ DialectBytecodeWriter &writer) const override {
+ return ::writeAttribute(attr, writer);
+ }
+
+ // Types
+ Type readType(DialectBytecodeReader &reader) const override {
+ return ::readType(getContext(), reader);
+ }
+
+ LogicalResult writeType(Type type,
+ DialectBytecodeWriter &writer) const override {
+ return ::writeType(type, writer);
+ }
+};
+} // namespace
+
+void LLVM::detail::addBytecodeInterface(LLVMDialect *dialect) {
+ dialect->addInterfaces<LLVMDialectBytecodeInterface>();
+}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h
new file mode 100644
index 0000000..1a17cb4
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h
@@ -0,0 +1,27 @@
+//===- LLVMDialectBytecode.h - LLVM Bytecode Implementation -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header defines hooks into the LLVM dialect bytecode
+// implementation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LIB_MLIR_DIALECT_LLVM_IR_LLVMDIALECTBYTECODE_H
+#define LIB_MLIR_DIALECT_LLVM_IR_LLVMDIALECTBYTECODE_H
+
+namespace mlir::LLVM {
+class LLVMDialect;
+
+namespace detail {
+/// Add the interfaces necessary for encoding the LLVM dialect components in
+/// bytecode.
+void addBytecodeInterface(LLVMDialect *dialect);
+} // namespace detail
+} // namespace mlir::LLVM
+
+#endif // LIB_MLIR_DIALECT_LLVM_IR_LLVMDIALECTBYTECODE_H
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 5edcc40b..2a8c330 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -309,6 +309,17 @@ LogicalResult ConvertBF16x2ToF8x2Op::verify() {
return success();
}
+LogicalResult ConvertF32x2ToF4x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
+ return emitOpError("Only ")
+ << mlir::Float4E2M1FNType::get(ctx)
+ << " type is supported for conversions from f32x2 to f4x2.";
+
+ return success();
+}
+
LogicalResult BulkStoreOp::verify() {
if (getInitVal() != 0)
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -787,6 +798,26 @@ LogicalResult MmaOp::verify() {
" attribute");
}
+ // Validate layout combinations. According to the operation description, most
+ // MMA operations require layoutA=row and layoutB=col. Only m8n8k4 with f16
+ // can use other layout combinations.
+ bool isM8N8K4_F16 =
+ (mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 &&
+ getMultiplicandAPtxType() == MMATypes::f16);
+
+ if (!isM8N8K4_F16) {
+ // For all other shapes/types, layoutA must be row and layoutB must be col
+ if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) {
+ return emitOpError("requires layoutA = #nvvm.mma_layout<row> and "
+ "layoutB = #nvvm.mma_layout<col> for shape <")
+ << mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2]
+ << "> with element types "
+ << stringifyEnum(*getMultiplicandAPtxType()) << " and "
+ << stringifyEnum(*getMultiplicandBPtxType())
+ << ". Only m8n8k4 with f16 supports other layouts.";
+ }
+ }
+
return success();
}
@@ -2047,6 +2078,23 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
}
}
+NVVM::IDArgPair
+ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
+ LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(op.getA()));
+ args.push_back(mt.lookupValue(op.getB()));
+
+ bool hasRelu = op.getRelu();
+
+ llvm::Intrinsic::ID intId =
+ hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
+ : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
+
+ return {intId, std::move(args)};
+}
+
#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
: llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
@@ -2306,6 +2354,32 @@ static void nvvmInferResultRanges(Operation *op, Value result,
}
}
+/// Verify the range attribute satisfies LLVM ConstantRange constructor
+/// requirements for NVVM SpecialRangeableRegisterOp.
+static LogicalResult
+verifyConstantRangeAttr(Operation *op,
+ std::optional<LLVM::ConstantRangeAttr> rangeAttr) {
+ if (!rangeAttr)
+ return success();
+
+ const llvm::APInt &lower = rangeAttr->getLower();
+ const llvm::APInt &upper = rangeAttr->getUpper();
+
+ // Check LLVM ConstantRange constructor condition
+ if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
+ unsigned bitWidth = lower.getBitWidth();
+ llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth);
+ llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth);
+ return op->emitOpError(
+ "invalid range attribute: Lower == Upper, but they aren't min (")
+ << llvm::toString(minVal, 10, false) << ") or max ("
+ << llvm::toString(maxVal, 10, false)
+ << ") value! This is an invalid constant range.";
+ }
+
+ return success();
+}
+
static llvm::Value *getAsPackedI32(llvm::Value *arg,
llvm::IRBuilderBase &builder) {
return builder.CreateBitCast(arg,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index c477c6c..dcc1ef9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -315,7 +315,8 @@ bool mlir::linalg::detail::isContractionBody(
Value yielded = getSourceSkipUnary(terminator->getOperand(0));
Operation *reductionOp = yielded.getDefiningOp();
- if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) {
+ if (!reductionOp || reductionOp->getNumResults() != 1 ||
+ reductionOp->getNumOperands() != 2) {
errs << "expected reduction op to be binary";
return false;
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 59013a2..cbc565b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5272,11 +5272,18 @@ ArrayRef<int64_t> PackOp::getAllOuterDims() {
SmallVector<int64_t> PackOp::getTiledOuterDims() {
auto innerDimsPos = getInnerDimsPos();
- auto packedShape = getDestType().getShape();
+ SmallVector<int64_t> outerDims(getAllOuterDims());
SmallVector<int64_t> res;
+ // Recover the original order of the outer dims.
+ SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
+ invertPermutationVector(outerDimPermInv);
+ if (!outerDimPermInv.empty())
+ applyPermutationToVector(outerDims, outerDimPermInv);
+
+ // Collect the outer dims corresponding to the tilled inner dims.
for (auto index : innerDimsPos)
- res.push_back(packedShape[index]);
+ res.push_back(outerDims[index]);
return res;
}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index dd9b4c2..d8f983f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -576,6 +576,86 @@ transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
// FuseOp
//===----------------------------------------------------------------------===//
+void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
+ TypeRange loopTypes, Value target,
+ ArrayRef<int64_t> staticTileSizes,
+ ArrayRef<int64_t> staticTileInterchange,
+ bool applyCleanup, bool useForall) {
+ return build(
+ builder, result, loopTypes,
+ /*target=*/target,
+ /*mixedTileSizes=*/
+ getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
+ /*mixedTileInterchange=*/
+ getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)),
+ applyCleanup, useForall);
+}
+
+void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
+ Value target, ArrayRef<int64_t> staticTileSizes,
+ ArrayRef<int64_t> staticTileInterchange,
+ bool applyCleanup, bool useForall) {
+ return build(
+ builder, result,
+ /*target=*/target,
+ /*mixedTileSizes=*/
+ getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
+ /*mixedTileInterchange=*/
+ getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)),
+ applyCleanup, useForall);
+}
+
+void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
+ Value target,
+ ArrayRef<OpFoldResult> mixedTileSizes,
+ ArrayRef<OpFoldResult> mixedTileInterchange,
+ bool applyCleanup, bool useForall) {
+ // Loop types are automaticaly splat by the callee, setting up one is
+ // enough.
+ SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
+ build(builder, result, loopTypes, target, mixedTileSizes,
+ mixedTileInterchange, applyCleanup, useForall);
+}
+
+void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
+ TypeRange loopTypes, Value target,
+ ArrayRef<OpFoldResult> mixedTileSizes,
+ ArrayRef<OpFoldResult> mixedTileInterchange,
+ bool applyCleanup, bool useForall) {
+ SmallVector<int64_t> staticTileSizes;
+ SmallVector<Value> dynamicTileSizes;
+ dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
+ SmallVector<int64_t> staticTileInterchange;
+ SmallVector<Value> dynamicTileInterchange;
+ dispatchIndexOpFoldResults(mixedTileInterchange, dynamicTileInterchange,
+ staticTileInterchange);
+ // Call the default builder which sets up the proper operands segment sizes
+ // attributes for multiple variadic operands. In the absence of this,
+ // horrible bugs ensue.
+ auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
+ auto staticTileInterchangeAttr =
+ builder.getDenseI64ArrayAttr(staticTileInterchange);
+ unsigned numExpectedLoops =
+ useForall ? 1 : staticTileSizes.size() - llvm::count(staticTileSizes, 0);
+ SmallVector<Type> resultTypes;
+ resultTypes.reserve(numExpectedLoops);
+ assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
+ "expected one loop type or as many as loops");
+ if (loopTypes.size() == 1)
+ resultTypes.append(numExpectedLoops, loopTypes[0]);
+ else
+ llvm::append_range(resultTypes, loopTypes);
+ build(builder, result, /*transformed=*/target.getType(),
+ /*loops=*/resultTypes,
+ /*target=*/target,
+ /*tile_sizes=*/dynamicTileSizes,
+ /*tile_interchange=*/dynamicTileInterchange,
+ /*static_tile_sizes=*/staticTileSizesAttr,
+ /*static_tile_interchange=*/staticTileInterchangeAttr,
+ /*apply_cleanup=*/applyCleanup,
+ /*use_forall=*/useForall);
+}
+
/// Apply a tiling transformation to all payload ops and store both the
/// tiled operation as well as the created tile loops.
template <typename Range>
@@ -630,13 +710,25 @@ DiagnosedSilenceableFailure
transform::FuseOp::apply(transform::TransformRewriter &rewriter,
mlir::transform::TransformResults &transformResults,
mlir::transform::TransformState &state) {
- SmallVector<int64_t> tileSizes =
- extractFromIntegerArrayAttr<int64_t>(getTileSizes());
- SmallVector<int64_t> tileInterchange =
- extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
+ auto transformOp = cast<TransformOpInterface>(getOperation());
+
+ SmallVector<int64_t> tileSizes;
+ DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
+ state, transformOp, getMixedTileSizes(), tileSizes);
+ if (!status.succeeded())
+ return status;
+ SmallVector<int64_t> tileInterchange;
+ status = reifyMixedParamAndHandleResults(
+ state, transformOp, getMixedTileInterchange(), tileInterchange);
+ if (!status.succeeded())
+ return status;
scf::SCFTilingOptions tilingOptions;
tilingOptions.interchangeVector = tileInterchange;
+ bool useForall = getUseForall();
+ tilingOptions.setLoopType(useForall
+ ? scf::SCFTilingOptions::LoopType::ForallOp
+ : scf::SCFTilingOptions::LoopType::ForOp);
SmallVector<OpFoldResult> tileSizesOfr =
getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
@@ -652,9 +744,11 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
tileAndFuseOptions.cleanupPatterns = std::move(patterns);
}
+ size_t numLoops =
+ useForall ? 1 : tileSizes.size() - llvm::count(tileSizes, 0);
LogicalResult result = applyTilingToAll(
- rewriter, getOperation(), state.getPayloadOps(getTarget()),
- tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
+ rewriter, getOperation(), state.getPayloadOps(getTarget()), numLoops,
+ transformResults,
[&](TilingInterface tilingInterfaceOp)
-> FailureOr<scf::SCFTileAndFuseResult> {
return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
@@ -665,24 +759,51 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
}
LogicalResult transform::FuseOp::verify() {
- SmallVector<int64_t> permutation =
- extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
- auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
- if (!std::is_permutation(sequence.begin(), sequence.end(),
- permutation.begin(), permutation.end())) {
- return emitOpError() << "expects interchange to be a permutation, found "
- << getTileInterchange();
+ auto iterspace_rank = getStaticTileSizes().size();
+ ArrayRef<int64_t> permutation = getStaticTileInterchange();
+ if (permutation.size() > iterspace_rank)
+ return emitOpError()
+ << "interchange length exceeds iteration space dimensions ("
+ << iterspace_rank << "), found " << getTileInterchange();
+ SmallVector<bool> seen(iterspace_rank, false);
+ for (int64_t v : permutation) {
+ if (!ShapedType::isDynamic(v)) {
+ if (v < 0 || v >= static_cast<int64_t>(iterspace_rank))
+ return emitOpError() << "expects interchange values to be in range [0, "
+ << iterspace_rank << "), found: " << v;
+ if (seen[v])
+ return emitOpError() << "found duplicate interchange value: " << v;
+ seen[v] = true;
+ }
}
- SmallVector<int64_t> sizes =
- extractFromIntegerArrayAttr<int64_t>(getTileSizes());
- size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
+ ArrayRef<int64_t> sizes = getStaticTileSizes();
+ size_t numExpectedLoops =
+ getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0);
if (numExpectedLoops != getNumResults() - 1)
return emitOpError() << "expects " << numExpectedLoops << " loop results";
return success();
}
+SmallVector<OpFoldResult> transform::FuseOp::getMixedTileSizes() {
+ return getMixedValues(getStaticTileSizes(), getTileSizes(), getContext());
+}
+
+SmallVector<OpFoldResult> transform::FuseOp::getMixedTileInterchange() {
+ return getMixedValues(getStaticTileInterchange(), getTileInterchange(),
+ getContext());
+}
+
+void transform::FuseOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ consumesHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getTileSizesMutable(), effects);
+ onlyReadsHandle(getTileInterchangeMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
+ modifiesPayload(effects);
+}
+
//===----------------------------------------------------------------------===//
// FuseIntoContainingOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 0dac688..eb2d825 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1134,22 +1134,45 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
linalg::PackOp packOp, PatternRewriter &rewriter) const {
- // TODO: support the case that outer dimensions are not all 1s. A
- // tensor.expand_shape will be generated in this case.
- if (llvm::any_of(packOp.getAllOuterDims(),
+ if (llvm::any_of(packOp.getTiledOuterDims(),
[](int64_t dim) { return dim != 1; })) {
return rewriter.notifyMatchFailure(
packOp, "not all outer dimensions of the result are 1s");
}
+ ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
+ auto outerDimsPerm = packOp.getOuterDimsPerm();
+
+ // Verify that there are no:
+ // * non-unit + un-tiled-outer-dims,
+ // that are permuted. Supporting such cases would require refining the logic
+ // that generates the Transpose Op.
+ if (!llvm::all_of(outerDimsPerm, [&innerDimsPos, &packOp](int64_t dim) {
+ static int prev = 0;
+ // Skip tiled dims - these can be permuted.
+ if (llvm::is_contained(innerDimsPos, dim))
+ return true;
+
+ // Check whether this dim has been permuted. Permuting unit dims is fine
+ // as that's effectively a no-op.
+ if (dim < prev && (packOp.getType().getShape()[prev] != 1 ||
+ packOp.getType().getShape()[dim] != 1))
+ return false;
+
+ prev = dim;
+ return true;
+ })) {
+ return rewriter.notifyMatchFailure(
+ packOp, "At least one non-unit and un-tiled outer dim is permuted, "
+ "this is not supported ATM!");
+ }
+
Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
Location loc = packOp.getLoc();
int64_t srcRank = packOp.getSourceRank();
int64_t destRank = packOp.getDestRank();
- ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
- int64_t numberOfTiles = innerDimsPos.size();
// 1. Get the input that is going to be packed. If the input requires padding,
// add a padding operation and return that as the input.
@@ -1160,10 +1183,13 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
// %transposed_tile = linalg.transpose ins(%source_or_padded_source),
// outs(%init)
// Assumptions made:
- // - All outer dims are 1 - the corresponding transposition order doesn't
- // matter, but requires all dim indices to be present.
+ // - All tiled outer dims are 1 - the corresponding transposition order
+ // doesn't matter, but requires all dim indices to be present.
+ // - Un-tiled outer dims remain un-permuted.
- // 2.1 Get the permutation for linalg.transpose
+ // 2.1 Get the permutation for linalg.transpose:
+ // [ untiled-dims, inner-dims-pos ]
+ // Note, this logic assumes that the untiled dims are not permuted.
SmallVector<int64_t> srcPermForTranspose;
for (int64_t i = 0; i < srcRank; i++) {
// We assume the `k` dimensions of the inner dim position, where `k` is the
@@ -1179,9 +1205,21 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
}
srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());
- // 2.2 Create the init tensor for linalg.transpose with the correct shape
- SmallVector<OpFoldResult> shapeForEmptyOp(srcRank - numberOfTiles,
- oneIdxAttr);
+ // 2.2 Create the init tensor for linalg.transpose with the correct shape:
+ // [ untiled-dims, tiled-dims ]
+ ShapedType inputTy = cast<ShapedType>(input.getType());
+ SmallVector<OpFoldResult> shapeForEmptyOp;
+ for (int64_t i = 0; i < srcRank; i++) {
+ if (llvm::is_contained(innerDimsPos, i)) {
+ // The tiled dims are appended after this loop.
+ continue;
+ }
+ if (inputTy.isStaticDim(i))
+ shapeForEmptyOp.push_back(rewriter.getIndexAttr(inputTy.getShape()[i]));
+ else
+ shapeForEmptyOp.emplace_back(
+ tensor::DimOp::create(rewriter, loc, input, i).getResult());
+ }
shapeForEmptyOp.append(packOp.getMixedTiles());
// getMixedTiles() may contain Values pointing to constant ops, not the
@@ -1204,25 +1242,36 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
srcPermForTranspose);
- // 3. Insert the inner tile to the destination:
+ // 3. Insert the inner tile into the destination tensor:
// %inserted_tile = tensor.insert_slice(%transposed_tile)
- SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
- SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
- // Outer dims are all 1s!
- SmallVector<OpFoldResult> writeSizes(destRank - numberOfTiles, oneIdxAttr);
- SmallVector<int64_t> writeShape;
+
+ // Compute the sizes attribute:
+ // [ outer-dims, tile-sizes ]
+ // Note that the output from the transpose Op excludes the tiled outer dims.
+ // However, given the assumption that:
+ // * all tiled outer dims == 1,
+ // we can just use a rank-expanding tensor.insert_slice.
+ SmallVector<OpFoldResult> writeSizes;
+ for (auto size : packOp.getAllOuterDims()) {
+ writeSizes.push_back(rewriter.getIndexAttr(size));
+ }
for (auto tileSize : packOp.getMixedTiles()) {
- auto [tileSizeStatic, tileSizeOfr] =
+ auto [_, tileSizeOfr] =
getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
writeSizes.push_back(tileSizeOfr);
- writeShape.push_back(tileSizeStatic);
}
- // 4. Replace tensor.packOp with tensor.insert_slice created above
+ // TODO: Add a constructor for tensor.insert_slice that doesn't require
+ // strides nor offsets.
+ SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
+ SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
+
auto insert = tensor::InsertSliceOp::create(
rewriter, loc, transposedOp.getResult()[0], packOp.getDest(),
writeOffsets, writeSizes, writeStrides);
+
+ // 4. Replace tensor.packOp with tensor.insert_slice created above
rewriter.replaceOp(packOp, insert.getResult());
return success();
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index e25a012..1382c7ac 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -5,7 +5,7 @@ add_mlir_dialect_library(MLIRMemRefDialect
ValueBoundsOpInterfaceImpl.cpp
ADDITIONAL_HEADER_DIRS
- ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect
+ ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRef/IR
DEPENDS
MLIRMemRefOpsIncGen
@@ -18,6 +18,7 @@ add_mlir_dialect_library(MLIRMemRefDialect
MLIRDialectUtils
MLIRInferIntRangeCommon
MLIRInferIntRangeInterface
+ MLIRInferStridedMetadataInterface
MLIRInferTypeOpInterface
MLIRIR
MLIRMemOpInterfaces
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index e9bdcda..507597b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3437,6 +3437,65 @@ SubViewOp::bubbleDownCasts(OpBuilder &builder) {
return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
+void SubViewOp::inferStridedMetadataRanges(
+ ArrayRef<StridedMetadataRange> ranges, GetIntRangeFn getIntRange,
+ SetStridedMetadataRangeFn setMetadata, int32_t indexBitwidth) {
+ auto isUninitialized =
+ +[](IntegerValueRange range) { return range.isUninitialized(); };
+
+ // Bail early if any of the operands metadata is not ready:
+ SmallVector<IntegerValueRange> offsetOperands =
+ getIntValueRanges(getMixedOffsets(), getIntRange, indexBitwidth);
+ if (llvm::any_of(offsetOperands, isUninitialized))
+ return;
+
+ SmallVector<IntegerValueRange> sizeOperands =
+ getIntValueRanges(getMixedSizes(), getIntRange, indexBitwidth);
+ if (llvm::any_of(sizeOperands, isUninitialized))
+ return;
+
+ SmallVector<IntegerValueRange> stridesOperands =
+ getIntValueRanges(getMixedStrides(), getIntRange, indexBitwidth);
+ if (llvm::any_of(stridesOperands, isUninitialized))
+ return;
+
+ StridedMetadataRange sourceRange =
+ ranges[getSourceMutable().getOperandNumber()];
+ if (sourceRange.isUninitialized())
+ return;
+
+ ArrayRef<ConstantIntRanges> srcStrides = sourceRange.getStrides();
+
+ // Get the dropped dims.
+ llvm::SmallBitVector droppedDims = getDroppedDims();
+
+ // Compute the new offset, strides and sizes.
+ ConstantIntRanges offset = sourceRange.getOffsets()[0];
+ SmallVector<ConstantIntRanges> strides, sizes;
+
+ for (size_t i = 0, e = droppedDims.size(); i < e; ++i) {
+ bool dropped = droppedDims.test(i);
+ // Compute the new offset.
+ ConstantIntRanges off =
+ intrange::inferMul({offsetOperands[i].getValue(), srcStrides[i]});
+ offset = intrange::inferAdd({offset, off});
+
+ // Skip dropped dimensions.
+ if (dropped)
+ continue;
+ // Multiply the strides.
+ strides.push_back(
+ intrange::inferMul({stridesOperands[i].getValue(), srcStrides[i]}));
+ // Get the sizes.
+ sizes.push_back(sizeOperands[i].getValue());
+ }
+
+ setMetadata(getResult(),
+ StridedMetadataRange::getRanked(
+ SmallVector<ConstantIntRanges>({std::move(offset)}),
+ std::move(sizes), std::move(strides)));
+}
+
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 6564a4e..642ced9 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/SymbolTable.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallSet.h"
@@ -74,14 +75,16 @@ struct MemRefPointerLikeModel
}
mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
- StringRef varName, Type varType,
- Value originalVar) const {
+ StringRef varName, Type varType, Value originalVar,
+ bool &needsFree) const {
auto memrefTy = cast<MemRefType>(pointer);
// Check if this is a static memref (all dimensions are known) - if yes
// then we can generate an alloca operation.
- if (memrefTy.hasStaticShape())
+ if (memrefTy.hasStaticShape()) {
+ needsFree = false; // alloca doesn't need deallocation
return memref::AllocaOp::create(builder, loc, memrefTy).getResult();
+ }
// For dynamic memrefs, extract sizes from the original variable if
// provided. Otherwise they cannot be handled.
@@ -99,6 +102,7 @@ struct MemRefPointerLikeModel
// Note: We only add dynamic sizes to the dynamicSizes array
// Static dimensions are handled automatically by AllocOp
}
+ needsFree = true; // alloc needs deallocation
return memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes)
.getResult();
}
@@ -108,10 +112,14 @@ struct MemRefPointerLikeModel
}
bool genFree(Type pointer, OpBuilder &builder, Location loc,
- TypedValue<PointerLikeType> varPtr, Type varType) const {
- if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varPtr)) {
+ TypedValue<PointerLikeType> varToFree, Value allocRes,
+ Type varType) const {
+ if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varToFree)) {
+ // Use allocRes if provided to determine the allocation type
+ Value valueToInspect = allocRes ? allocRes : memrefValue;
+
// Walk through casts to find the original allocation
- Value currentValue = memrefValue;
+ Value currentValue = valueToInspect;
Operation *originalAlloc = nullptr;
// Follow the chain of operations to find the original allocation
@@ -150,7 +158,7 @@ struct MemRefPointerLikeModel
return true;
}
if (isa<memref::AllocOp>(originalAlloc)) {
- // This is an alloc - generate dealloc
+ // This is an alloc - generate dealloc on varToFree
memref::DeallocOp::create(builder, loc, memrefValue);
return true;
}
@@ -1003,6 +1011,142 @@ struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
}
};
+//===----------------------------------------------------------------------===//
+// Recipe Region Helpers
+//===----------------------------------------------------------------------===//
+
+/// Create and populate an init region for privatization recipes.
+/// Returns the init block on success, or nullptr on failure.
+/// Sets needsFree to indicate if the allocated memory requires deallocation.
+static std::unique_ptr<Block> createInitRegion(OpBuilder &builder, Location loc,
+ Type varType, StringRef varName,
+ ValueRange bounds,
+ bool &needsFree) {
+ // Create init block with arguments: original value + bounds
+ SmallVector<Type> argTypes{varType};
+ SmallVector<Location> argLocs{loc};
+ for (Value bound : bounds) {
+ argTypes.push_back(bound.getType());
+ argLocs.push_back(loc);
+ }
+
+ auto initBlock = std::make_unique<Block>();
+ initBlock->addArguments(argTypes, argLocs);
+ builder.setInsertionPointToStart(initBlock.get());
+
+ Value privatizedValue;
+
+ // Get the block argument that represents the original variable
+ Value blockArgVar = initBlock->getArgument(0);
+
+ // Generate init region body based on variable type
+ if (isa<MappableType>(varType)) {
+ auto mappableTy = cast<MappableType>(varType);
+ auto typedVar = cast<TypedValue<MappableType>>(blockArgVar);
+ privatizedValue = mappableTy.generatePrivateInit(
+ builder, loc, typedVar, varName, bounds, {}, needsFree);
+ if (!privatizedValue)
+ return nullptr;
+ } else {
+ assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
+ auto pointerLikeTy = cast<PointerLikeType>(varType);
+ // Use PointerLikeType's allocation API with the block argument
+ privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
+ blockArgVar, needsFree);
+ if (!privatizedValue)
+ return nullptr;
+ }
+
+ // Add yield operation to init block
+ acc::YieldOp::create(builder, loc, privatizedValue);
+
+ return initBlock;
+}
+
+/// Create and populate a copy region for firstprivate recipes.
+/// Returns the copy block on success, or nullptr on failure.
+/// TODO: Handle MappableType - it does not yet have a copy API.
+static std::unique_ptr<Block> createCopyRegion(OpBuilder &builder, Location loc,
+ Type varType,
+ ValueRange bounds) {
+ // Create copy block with arguments: original value + privatized value +
+ // bounds
+ SmallVector<Type> copyArgTypes{varType, varType};
+ SmallVector<Location> copyArgLocs{loc, loc};
+ for (Value bound : bounds) {
+ copyArgTypes.push_back(bound.getType());
+ copyArgLocs.push_back(loc);
+ }
+
+ auto copyBlock = std::make_unique<Block>();
+ copyBlock->addArguments(copyArgTypes, copyArgLocs);
+ builder.setInsertionPointToStart(copyBlock.get());
+
+ bool isMappable = isa<MappableType>(varType);
+ bool isPointerLike = isa<PointerLikeType>(varType);
+ // TODO: Handle MappableType - it does not yet have a copy API.
+ // Otherwise, for now just fallback to pointer-like behavior.
+ if (isMappable && !isPointerLike)
+ return nullptr;
+
+ // Generate copy region body based on variable type
+ if (isPointerLike) {
+ auto pointerLikeTy = cast<PointerLikeType>(varType);
+ Value originalArg = copyBlock->getArgument(0);
+ Value privatizedArg = copyBlock->getArgument(1);
+
+ // Generate copy operation using PointerLikeType interface
+ if (!pointerLikeTy.genCopy(
+ builder, loc, cast<TypedValue<PointerLikeType>>(privatizedArg),
+ cast<TypedValue<PointerLikeType>>(originalArg), varType))
+ return nullptr;
+ }
+
+ // Add terminator to copy block
+ acc::TerminatorOp::create(builder, loc);
+
+ return copyBlock;
+}
+
+/// Create and populate a destroy region for privatization recipes.
+/// Returns the destroy block on success, or nullptr if not needed.
+static std::unique_ptr<Block> createDestroyRegion(OpBuilder &builder,
+ Location loc, Type varType,
+ Value allocRes,
+ ValueRange bounds) {
+ // Create destroy block with arguments: original value + privatized value +
+ // bounds
+ SmallVector<Type> destroyArgTypes{varType, varType};
+ SmallVector<Location> destroyArgLocs{loc, loc};
+ for (Value bound : bounds) {
+ destroyArgTypes.push_back(bound.getType());
+ destroyArgLocs.push_back(loc);
+ }
+
+ auto destroyBlock = std::make_unique<Block>();
+ destroyBlock->addArguments(destroyArgTypes, destroyArgLocs);
+ builder.setInsertionPointToStart(destroyBlock.get());
+
+ bool isMappable = isa<MappableType>(varType);
+ bool isPointerLike = isa<PointerLikeType>(varType);
+ // TODO: Handle MappableType - it does not yet have a deallocation API.
+ // Otherwise, for now just fallback to pointer-like behavior.
+ if (isMappable && !isPointerLike)
+ return nullptr;
+
+ assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
+ auto pointerLikeTy = cast<PointerLikeType>(varType);
+ auto privatizedArg =
+ cast<TypedValue<PointerLikeType>>(destroyBlock->getArgument(1));
+ // Pass allocRes to help determine the allocation type
+ if (!pointerLikeTy.genFree(builder, loc, privatizedArg, allocRes, varType))
+ return nullptr;
+
+ acc::TerminatorOp::create(builder, loc);
+
+ return destroyBlock;
+}
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -1050,6 +1194,55 @@ LogicalResult acc::PrivateRecipeOp::verifyRegions() {
return success();
}
+std::optional<PrivateRecipeOp>
+PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
+ StringRef recipeName, Type varType,
+ StringRef varName, ValueRange bounds) {
+ // First, validate that we can handle this variable type
+ bool isMappable = isa<MappableType>(varType);
+ bool isPointerLike = isa<PointerLikeType>(varType);
+
+ // Unsupported type
+ if (!isMappable && !isPointerLike)
+ return std::nullopt;
+
+ // Create init and destroy blocks using shared helpers
+ OpBuilder::InsertionGuard guard(builder);
+
+ // Save the original insertion point for creating the recipe operation later
+ auto originalInsertionPoint = builder.saveInsertionPoint();
+
+ bool needsFree = false;
+ auto initBlock =
+ createInitRegion(builder, loc, varType, varName, bounds, needsFree);
+ if (!initBlock)
+ return std::nullopt;
+
+ // Only create destroy region if the allocation needs deallocation
+ std::unique_ptr<Block> destroyBlock;
+ if (needsFree) {
+ // Extract the allocated value from the init block's yield operation
+ auto yieldOp = cast<acc::YieldOp>(initBlock->getTerminator());
+ Value allocRes = yieldOp.getOperand(0);
+
+ destroyBlock = createDestroyRegion(builder, loc, varType, allocRes, bounds);
+ if (!destroyBlock)
+ return std::nullopt;
+ }
+
+ // Now create the recipe operation at the original insertion point and attach
+ // the blocks
+ builder.restoreInsertionPoint(originalInsertionPoint);
+ auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
+
+ // Move the blocks into the recipe's regions
+ recipe.getInitRegion().push_back(initBlock.release());
+ if (destroyBlock)
+ recipe.getDestroyRegion().push_back(destroyBlock.release());
+
+ return recipe;
+}
+
//===----------------------------------------------------------------------===//
// FirstprivateRecipeOp
//===----------------------------------------------------------------------===//
@@ -1080,6 +1273,60 @@ LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
return success();
}
+std::optional<FirstprivateRecipeOp>
+FirstprivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
+ StringRef recipeName, Type varType,
+ StringRef varName, ValueRange bounds) {
+ // First, validate that we can handle this variable type
+ bool isMappable = isa<MappableType>(varType);
+ bool isPointerLike = isa<PointerLikeType>(varType);
+
+ // Unsupported type
+ if (!isMappable && !isPointerLike)
+ return std::nullopt;
+
+ // Create init, copy, and destroy blocks using shared helpers
+ OpBuilder::InsertionGuard guard(builder);
+
+ // Save the original insertion point for creating the recipe operation later
+ auto originalInsertionPoint = builder.saveInsertionPoint();
+
+ bool needsFree = false;
+ auto initBlock =
+ createInitRegion(builder, loc, varType, varName, bounds, needsFree);
+ if (!initBlock)
+ return std::nullopt;
+
+ auto copyBlock = createCopyRegion(builder, loc, varType, bounds);
+ if (!copyBlock)
+ return std::nullopt;
+
+ // Only create destroy region if the allocation needs deallocation
+ std::unique_ptr<Block> destroyBlock;
+ if (needsFree) {
+ // Extract the allocated value from the init block's yield operation
+ auto yieldOp = cast<acc::YieldOp>(initBlock->getTerminator());
+ Value allocRes = yieldOp.getOperand(0);
+
+ destroyBlock = createDestroyRegion(builder, loc, varType, allocRes, bounds);
+ if (!destroyBlock)
+ return std::nullopt;
+ }
+
+ // Now create the recipe operation at the original insertion point and attach
+ // the blocks
+ builder.restoreInsertionPoint(originalInsertionPoint);
+ auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
+
+ // Move the blocks into the recipe's regions
+ recipe.getInitRegion().push_back(initBlock.release());
+ recipe.getCopyRegion().push_back(copyBlock.release());
+ if (destroyBlock)
+ recipe.getDestroyRegion().push_back(destroyBlock.release());
+
+ return recipe;
+}
+
//===----------------------------------------------------------------------===//
// ReductionRecipeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index fa97b49..ac72002 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2310,6 +2310,7 @@ RankedTensorType ExtractSliceOp::inferResultType(
sourceTensorType.getEncoding());
}
+// TODO: This uses neither offsets nor strides!
RankedTensorType ExtractSliceOp::inferResultType(
RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 58256b0..45c54c7 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -7601,6 +7601,111 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), result);
}
+namespace {
+
+/// Fold `vector.step -> arith.cmpi` when the step value is compared to a
+/// constant large enough such that the result is the same at all indices.
+///
+/// For example, rewrite the 'greater than' comparison below,
+///
+/// ```mlir
+/// %cst = arith.constant dense<7> : vector<3xindex>
+/// %stp = vector.step : vector<3xindex>
+/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex>
+/// ```
+///
+/// as,
+///
+/// ```mlir
+/// %out = arith.constant dense<false> : vector<3xi1>.
+/// ```
+///
+/// Above `[0, 1, 2] > [7, 7, 7]` => `[false, false, false]`. Because the result
+/// is false at ALL indices we fold. If the constant was 1, then
+/// `[0, 1, 2] > [1, 1, 1]` => `[false, false, true]` and we do fold,
+/// conservatively preferring the 'compact' vector.step representation.
+///
+/// Note: this folder only works for the case where the constant (`%cst` above)
+/// is the second operand of the comparison. The arith.cmpi canonicalizer will
+/// ensure that constants are always second (on the right).
+struct StepCompareFolder : public OpRewritePattern<StepOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(StepOp stepOp,
+ PatternRewriter &rewriter) const override {
+ const int64_t stepSize = stepOp.getResult().getType().getNumElements();
+
+ for (OpOperand &use : stepOp.getResult().getUses()) {
+ auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
+ if (!cmpiOp)
+ continue;
+
+ // arith.cmpi canonicalizer makes constants final operands.
+ const unsigned stepOperandNumber = use.getOperandNumber();
+ if (stepOperandNumber != 0)
+ continue;
+
+ // Check that operand 1 is a constant.
+ unsigned constOperandNumber = 1;
+ Value otherOperand = cmpiOp.getOperand(constOperandNumber);
+ std::optional<int64_t> maybeConstValue =
+ getConstantIntValue(otherOperand);
+ if (!maybeConstValue.has_value())
+ continue;
+
+ int64_t constValue = maybeConstValue.value();
+ arith::CmpIPredicate pred = cmpiOp.getPredicate();
+
+ auto maybeSplat = [&]() -> std::optional<bool> {
+ // Handle ult (unsigned less than) and uge (unsigned greater equal).
+ if ((pred == arith::CmpIPredicate::ult ||
+ pred == arith::CmpIPredicate::uge) &&
+ stepSize <= constValue)
+ return pred == arith::CmpIPredicate::ult;
+
+ // Handle ule and ugt.
+ if ((pred == arith::CmpIPredicate::ule ||
+ pred == arith::CmpIPredicate::ugt) &&
+ stepSize - 1 <= constValue) {
+ return pred == arith::CmpIPredicate::ule;
+ }
+
+ // Handle eq and ne.
+ if ((pred == arith::CmpIPredicate::eq ||
+ pred == arith::CmpIPredicate::ne) &&
+ stepSize <= constValue)
+ return pred == arith::CmpIPredicate::ne;
+
+ return std::nullopt;
+ }();
+
+ if (!maybeSplat.has_value())
+ continue;
+
+ rewriter.setInsertionPointAfter(cmpiOp);
+
+ auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
+ if (!type)
+ continue;
+
+ auto boolAttr = DenseElementsAttr::get(type, maybeSplat.value());
+ Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
+ type, boolAttr);
+
+ rewriter.replaceOp(cmpiOp, splat);
+ return success();
+ }
+
+ return failure();
+ }
+};
+} // namespace
+
+void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<StepCompareFolder>(context);
+}
+
//===----------------------------------------------------------------------===//
// Vector Masking Utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e95338f..12e6475 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -928,17 +928,20 @@ struct WarpOpDeadResult : public WarpDistributionPattern {
// Some values may be yielded multiple times and correspond to multiple
// results. Deduplicating occurs by taking each result with its matching
// yielded value, and:
- // 1. recording the unique first position at which the value is yielded.
+ // 1. recording the unique first position at which the value with uses is
+ // yielded.
// 2. recording for the result, the first position at which the dedup'ed
// value is yielded.
// 3. skipping from the new result types / new yielded values any result
// that has no use or whose yielded value has already been seen.
for (OpResult result : warpOp.getResults()) {
+ if (result.use_empty())
+ continue;
Value yieldOperand = yield.getOperand(result.getResultNumber());
auto it = dedupYieldOperandPositionMap.insert(
std::make_pair(yieldOperand, newResultTypes.size()));
dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
- if (result.use_empty() || !it.second)
+ if (!it.second)
continue;
newResultTypes.push_back(result.getType());
newYieldValues.push_back(yieldOperand);
@@ -1843,16 +1846,16 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
escapingValueDistTypesElse.end());
- llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx;
for (auto [idx, val] :
llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
- origToNewYieldIdx[idx] = newWarpOpYieldValues.size();
newWarpOpYieldValues.push_back(val);
newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
}
- // Create the new `WarpOp` with the updated yield values and types.
- WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
- rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+ // Replace the old `WarpOp` with the new one that has additional yield
+ // values and types.
+ SmallVector<size_t> newIndices;
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
// `ifOp` returns the result of the inner warp op.
SmallVector<Type> newIfOpDistResTypes;
for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
@@ -1870,8 +1873,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(newWarpOp);
auto newIfOp = scf::IfOp::create(
- rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0),
- static_cast<bool>(ifOp.thenBlock()),
+ rewriter, ifOp.getLoc(), newIfOpDistResTypes,
+ newWarpOp.getResult(newIndices[0]), static_cast<bool>(ifOp.thenBlock()),
static_cast<bool>(ifOp.elseBlock()));
auto encloseRegionInWarpOp =
[&](Block *oldIfBranch, Block *newIfBranch,
@@ -1888,7 +1891,7 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
for (size_t i = 0; i < escapingValues.size();
++i, ++warpResRangeStart) {
innerWarpInputVals.push_back(
- newWarpOp.getResult(warpResRangeStart));
+ newWarpOp.getResult(newIndices[warpResRangeStart]));
escapeValToBlockArgIndex[escapingValues[i]] =
innerWarpInputTypes.size();
innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
@@ -1936,17 +1939,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
// Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp`
// result.
for (auto [origIdx, newIdx] : ifResultMapping)
- rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+ rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
newIfOp.getResult(newIdx), newIfOp);
- // Similarly, update any users of the `WarpOp` results that were not
- // results of the `IfOp`.
- for (auto [origIdx, newIdx] : origToNewYieldIdx)
- rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
- newWarpOp.getResult(newIdx));
- // Remove the original `WarpOp` and `IfOp`, they should not have any uses
- // at this point.
- rewriter.eraseOp(ifOp);
- rewriter.eraseOp(warpOp);
return success();
}
@@ -2065,19 +2059,16 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
escapingValueDistTypes.begin(),
escapingValueDistTypes.end());
// Next, we insert all non-`ForOp` yielded values and their distributed
- // types. We also create a mapping between the non-`ForOp` yielded value
- // index and the corresponding new `WarpOp` yield value index (needed to
- // update users later).
- llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping;
+ // types.
for (auto [i, v] :
llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
- nonForResultMapping[i] = newWarpOpYieldValues.size();
newWarpOpYieldValues.push_back(v);
newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
}
// Create the new `WarpOp` with the updated yield values and types.
- WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
- rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+ SmallVector<size_t> newIndices;
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
// Next, we create a new `ForOp` with the init args yielded by the new
// `WarpOp`.
@@ -2086,7 +2077,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
// escaping values in the new `WarpOp`.
SmallVector<Value> newForOpOperands;
for (size_t i = 0; i < escapingValuesStartIdx; ++i)
- newForOpOperands.push_back(newWarpOp.getResult(i));
+ newForOpOperands.push_back(newWarpOp.getResult(newIndices[i]));
// Create a new `ForOp` outside the new `WarpOp` region.
OpBuilder::InsertionGuard g(rewriter);
@@ -2110,7 +2101,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
for (size_t i = escapingValuesStartIdx;
i < escapingValuesStartIdx + escapingValues.size(); ++i) {
- innerWarpInput.push_back(newWarpOp.getResult(i));
+ innerWarpInput.push_back(newWarpOp.getResult(newIndices[i]));
argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
innerWarpInputType.size();
innerWarpInputType.push_back(
@@ -2146,20 +2137,11 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
if (!innerWarp.getResults().empty())
scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults());
- // Update the users of original `WarpOp` results that were coming from the
+ // Update the users of the new `WarpOp` results that were coming from the
// original `ForOp` to the corresponding new `ForOp` result.
for (auto [origIdx, newIdx] : forResultMapping)
- rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+ rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
newForOp.getResult(newIdx), newForOp);
- // Similarly, update any users of the `WarpOp` results that were not
- // results of the `ForOp`.
- for (auto [origIdx, newIdx] : nonForResultMapping)
- rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
- newWarpOp.getResult(newIdx));
- // Remove the original `WarpOp` and `ForOp`, they should not have any uses
- // at this point.
- rewriter.eraseOp(forOp);
- rewriter.eraseOp(warpOp);
// Update any users of escaping values that were forwarded to the
// inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
newForOp.walk([&](Operation *op) {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 14639c5..fbae098 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -465,26 +465,33 @@ struct UnrollElementwisePattern : public RewritePattern {
auto targetShape = getTargetShape(options, op);
if (!targetShape)
return failure();
+ int64_t targetShapeRank = targetShape->size();
auto dstVecType = cast<VectorType>(op->getResult(0).getType());
SmallVector<int64_t> originalSize =
*cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
- // Bail-out if rank(source) != rank(target). The main limitation here is the
- // fact that `ExtractStridedSlice` requires the rank for the input and
- // output to match. If needed, we can relax this later.
- if (originalSize.size() != targetShape->size())
- return rewriter.notifyMatchFailure(
- op, "expected input vector rank to match target shape rank");
+ int64_t originalShapeRank = originalSize.size();
+
Location loc = op->getLoc();
+
+ // Handle rank mismatch by adding leading unit dimensions to targetShape
+ SmallVector<int64_t> adjustedTargetShape(originalShapeRank);
+ int64_t rankDiff = originalShapeRank - targetShapeRank;
+ std::fill(adjustedTargetShape.begin(),
+ adjustedTargetShape.begin() + rankDiff, 1);
+ std::copy(targetShape->begin(), targetShape->end(),
+ adjustedTargetShape.begin() + rankDiff);
+
+ int64_t adjustedTargetShapeRank = adjustedTargetShape.size();
// Prepare the result vector.
Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
rewriter.getZeroAttr(dstVecType));
- SmallVector<int64_t> strides(targetShape->size(), 1);
- VectorType newVecType =
+ SmallVector<int64_t> strides(adjustedTargetShapeRank, 1);
+ VectorType unrolledVecType =
VectorType::get(*targetShape, dstVecType.getElementType());
// Create the unrolled computation.
for (SmallVector<int64_t> offsets :
- StaticTileOffsetRange(originalSize, *targetShape)) {
+ StaticTileOffsetRange(originalSize, adjustedTargetShape)) {
SmallVector<Value> extractOperands;
for (OpOperand &operand : op->getOpOperands()) {
auto vecType = dyn_cast<VectorType>(operand.get().getType());
@@ -492,14 +499,31 @@ struct UnrollElementwisePattern : public RewritePattern {
extractOperands.push_back(operand.get());
continue;
}
- extractOperands.push_back(
- rewriter.createOrFold<vector::ExtractStridedSliceOp>(
- loc, operand.get(), offsets, *targetShape, strides));
+ Value extracted = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, operand.get(), offsets, adjustedTargetShape, strides);
+
+ // Reshape to remove leading unit dims if needed
+ if (adjustedTargetShapeRank > targetShapeRank) {
+ extracted = rewriter.createOrFold<vector::ShapeCastOp>(
+ loc, VectorType::get(*targetShape, vecType.getElementType()),
+ extracted);
+ }
+ extractOperands.push_back(extracted);
}
+
Operation *newOp = cloneOpWithOperandsAndTypes(
- rewriter, loc, op, extractOperands, newVecType);
+ rewriter, loc, op, extractOperands, unrolledVecType);
+
+ Value computeResult = newOp->getResult(0);
+
+ // Use strides sized to targetShape for proper insertion
+ SmallVector<int64_t> insertStrides =
+ (adjustedTargetShapeRank > targetShapeRank)
+ ? SmallVector<int64_t>(targetShapeRank, 1)
+ : strides;
+
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
- loc, newOp->getResult(0), result, offsets, strides);
+ loc, computeResult, result, offsets, insertStrides);
}
rewriter.replaceOp(op, result);
return success();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 36c498e..f77784a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -161,11 +161,24 @@ XeGPUBlockingPass::getTileShape(Operation *op) const {
xegpu::UpdateOffsetOp, xegpu::LoadMatrixOp>(op))
return getTileShape(op->getOpResult(0));
if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
- xegpu::LoadGatherOp, xegpu::StoreMatrixOp>(op))
+ xegpu::StoreMatrixOp>(op))
return getTileShape(op->getOpOperand(0));
- if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(op))
+ if (isa<xegpu::StoreNdOp>(op))
return getTileShape(op->getOpOperand(1));
+ // Handle LoadGatherOp and StoreScatterOp (with and without offset)
+ if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
+ if (loadGatherOp.getOffsets())
+ return getTileShape(loadGatherOp->getOpResult(0));
+ else
+ return getTileShape(loadGatherOp->getOpOperand(0));
+ }
+
+ if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op))
+ return getTileShape(storeScatterOp.getOffsets()
+ ? storeScatterOp->getOpOperand(0)
+ : storeScatterOp->getOpOperand(1));
+
if (isa<xegpu::DpasOp>(op)) {
std::optional<SmallVector<int64_t>> aTile =
getTileShape(op->getOpOperand(0));
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 3d19c5a..9b23dd6 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2200,10 +2200,9 @@ void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty,
os << '>';
}
os << '[';
- interleave(
- loc.getLocations(),
- [&](Location loc) { printLocationInternal(loc, pretty); },
- [&]() { os << ", "; });
+ interleaveComma(loc.getLocations(), [&](Location loc) {
+ printLocationInternal(loc, pretty);
+ });
os << ']';
})
.Default([&](LocationAttr loc) {
diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
index 776b5c6..4d81918 100644
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -378,8 +378,10 @@ struct SourceMgrDiagnosticHandlerImpl {
}
// Otherwise, try to load the source file.
- std::string ignored;
- unsigned id = mgr.AddIncludeFile(std::string(filename), SMLoc(), ignored);
+ auto bufferOrErr = llvm::MemoryBuffer::getFile(filename);
+ if (!bufferOrErr)
+ return 0;
+ unsigned id = mgr.AddNewSourceBuffer(std::move(*bufferOrErr), SMLoc());
filenameToBufId[filename] = id;
return id;
}
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 1fa04ed..89b81cf 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -121,6 +121,11 @@ namespace mlir {
class MLIRContextImpl {
public:
//===--------------------------------------------------------------------===//
+ // Remark
+ //===--------------------------------------------------------------------===//
+ std::unique_ptr<remark::detail::RemarkEngine> remarkEngine;
+
+ //===--------------------------------------------------------------------===//
// Debugging
//===--------------------------------------------------------------------===//
@@ -135,11 +140,6 @@ public:
DiagnosticEngine diagEngine;
//===--------------------------------------------------------------------===//
- // Remark
- //===--------------------------------------------------------------------===//
- std::unique_ptr<remark::detail::RemarkEngine> remarkEngine;
-
- //===--------------------------------------------------------------------===//
// Options
//===--------------------------------------------------------------------===//
@@ -357,7 +357,10 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
impl->affineUniquer.registerParametricStorageType<IntegerSetStorage>();
}
-MLIRContext::~MLIRContext() = default;
+MLIRContext::~MLIRContext() {
+ // finalize remark engine before destroying anything else.
+ impl->remarkEngine.reset();
+}
/// Copy the specified array of elements into memory managed by the provided
/// bump pointer allocator. This assumes the elements are all PODs.
diff --git a/mlir/lib/IR/Remarks.cpp b/mlir/lib/IR/Remarks.cpp
index a55f61a..031eae2 100644
--- a/mlir/lib/IR/Remarks.cpp
+++ b/mlir/lib/IR/Remarks.cpp
@@ -16,7 +16,7 @@
#include "llvm/ADT/StringRef.h"
using namespace mlir::remark::detail;
-
+using namespace mlir::remark;
//------------------------------------------------------------------------------
// Remark
//------------------------------------------------------------------------------
@@ -70,7 +70,7 @@ static void printArgs(llvm::raw_ostream &os, llvm::ArrayRef<Remark::Arg> args) {
void Remark::print(llvm::raw_ostream &os, bool printLocation) const {
// Header: [Type] pass:remarkName
StringRef type = getRemarkTypeString();
- StringRef categoryName = getFullCategoryName();
+ StringRef categoryName = getCombinedCategoryName();
StringRef name = remarkName;
os << '[' << type << "] ";
@@ -81,9 +81,10 @@ void Remark::print(llvm::raw_ostream &os, bool printLocation) const {
os << "Function=" << getFunction() << " | ";
if (printLocation) {
- if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(getLocation()))
+ if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(getLocation())) {
os << " @" << flc.getFilename() << ":" << flc.getLine() << ":"
<< flc.getColumn();
+ }
}
printArgs(os, getArgs());
@@ -140,7 +141,7 @@ llvm::remarks::Remark Remark::generateRemark() const {
r.RemarkType = getRemarkType();
r.RemarkName = getRemarkName();
// MLIR does not use passes; instead, it has categories and sub-categories.
- r.PassName = getFullCategoryName();
+ r.PassName = getCombinedCategoryName();
r.FunctionName = getFunction();
r.Loc = locLambda();
for (const Remark::Arg &arg : getArgs()) {
@@ -225,26 +226,42 @@ InFlightRemark RemarkEngine::emitOptimizationRemarkAnalysis(Location loc,
// RemarkEngine
//===----------------------------------------------------------------------===//
-void RemarkEngine::report(const Remark &&remark) {
+void RemarkEngine::reportImpl(const Remark &remark) {
// Stream the remark
- if (remarkStreamer)
+ if (remarkStreamer) {
remarkStreamer->streamOptimizationRemark(remark);
+ }
// Print using MLIR's diagnostic
if (printAsEmitRemarks)
emitRemark(remark.getLocation(), remark.getMsg());
}
+void RemarkEngine::report(const Remark &&remark) {
+ if (remarkEmittingPolicy)
+ remarkEmittingPolicy->reportRemark(remark);
+}
+
RemarkEngine::~RemarkEngine() {
+ if (remarkEmittingPolicy)
+ remarkEmittingPolicy->finalize();
+
if (remarkStreamer)
remarkStreamer->finalize();
}
-llvm::LogicalResult
-RemarkEngine::initialize(std::unique_ptr<MLIRRemarkStreamerBase> streamer,
- std::string *errMsg) {
- // If you need to validate categories/filters, do so here and set errMsg.
+llvm::LogicalResult RemarkEngine::initialize(
+ std::unique_ptr<MLIRRemarkStreamerBase> streamer,
+ std::unique_ptr<RemarkEmittingPolicyBase> remarkEmittingPolicy,
+ std::string *errMsg) {
+
remarkStreamer = std::move(streamer);
+
+ auto reportFunc =
+ std::bind(&RemarkEngine::reportImpl, this, std::placeholders::_1);
+ remarkEmittingPolicy->initialize(ReportFn(std::move(reportFunc)));
+
+ this->remarkEmittingPolicy = std::move(remarkEmittingPolicy);
return success();
}
@@ -301,14 +318,15 @@ RemarkEngine::RemarkEngine(bool printAsEmitRemarks,
}
llvm::LogicalResult mlir::remark::enableOptimizationRemarks(
- MLIRContext &ctx,
- std::unique_ptr<remark::detail::MLIRRemarkStreamerBase> streamer,
- const remark::RemarkCategories &cats, bool printAsEmitRemarks) {
+ MLIRContext &ctx, std::unique_ptr<detail::MLIRRemarkStreamerBase> streamer,
+ std::unique_ptr<detail::RemarkEmittingPolicyBase> remarkEmittingPolicy,
+ const RemarkCategories &cats, bool printAsEmitRemarks) {
auto engine =
- std::make_unique<remark::detail::RemarkEngine>(printAsEmitRemarks, cats);
+ std::make_unique<detail::RemarkEngine>(printAsEmitRemarks, cats);
std::string errMsg;
- if (failed(engine->initialize(std::move(streamer), &errMsg))) {
+ if (failed(engine->initialize(std::move(streamer),
+ std::move(remarkEmittingPolicy), &errMsg))) {
llvm::report_fatal_error(
llvm::Twine("Failed to initialize remark engine. Error: ") + errMsg);
}
@@ -316,3 +334,12 @@ llvm::LogicalResult mlir::remark::enableOptimizationRemarks(
return success();
}
+
+//===----------------------------------------------------------------------===//
+// Remark emitting policies
+//===----------------------------------------------------------------------===//
+
+namespace mlir::remark {
+RemarkEmittingPolicyAll::RemarkEmittingPolicyAll() = default;
+RemarkEmittingPolicyFinal::RemarkEmittingPolicyFinal() = default;
+} // namespace mlir::remark
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index 388de1c..f96af02 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -9,6 +9,7 @@ set(LLVM_OPTIONAL_SOURCES
FunctionInterfaces.cpp
IndexingMapOpInterface.cpp
InferIntRangeInterface.cpp
+ InferStridedMetadataInterface.cpp
InferTypeOpInterface.cpp
LoopLikeInterface.cpp
MemOpInterfaces.cpp
@@ -64,6 +65,21 @@ add_mlir_library(MLIRFunctionInterfaces
add_mlir_interface_library(IndexingMapOpInterface)
add_mlir_interface_library(InferIntRangeInterface)
+
+add_mlir_library(MLIRInferStridedMetadataInterface
+ InferStridedMetadataInterface.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces
+
+ DEPENDS
+ MLIRInferStridedMetadataInterfaceIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRInferIntRangeInterface
+ MLIRIR
+)
+
add_mlir_interface_library(InferTypeOpInterface)
add_mlir_library(MLIRLoopLikeInterface
diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
index 9f3e97d..84fc9b8 100644
--- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -146,6 +146,25 @@ raw_ostream &mlir::operator<<(raw_ostream &os, const IntegerValueRange &range) {
return os;
}
+SmallVector<IntegerValueRange>
+mlir::getIntValueRanges(ArrayRef<OpFoldResult> values,
+ GetIntRangeFn getIntRange, int32_t indexBitwidth) {
+ SmallVector<IntegerValueRange> ranges;
+ ranges.reserve(values.size());
+ for (OpFoldResult ofr : values) {
+ if (auto value = dyn_cast<Value>(ofr)) {
+ ranges.push_back(getIntRange(value));
+ continue;
+ }
+
+ // Create a constant range.
+ auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
+ ranges.emplace_back(ConstantIntRanges::constant(
+ attr.getValue().sextOrTrunc(indexBitwidth)));
+ }
+ return ranges;
+}
+
void mlir::intrange::detail::defaultInferResultRanges(
InferIntRangeInterface interface, ArrayRef<IntegerValueRange> argRanges,
SetIntLatticeFn setResultRanges) {
diff --git a/mlir/lib/Interfaces/InferStridedMetadataInterface.cpp b/mlir/lib/Interfaces/InferStridedMetadataInterface.cpp
new file mode 100644
index 0000000..483e9f1
--- /dev/null
+++ b/mlir/lib/Interfaces/InferStridedMetadataInterface.cpp
@@ -0,0 +1,36 @@
+//===- InferStridedMetadataInterface.cpp - Strided md inference interface -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/InferStridedMetadataInterface.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include <optional>
+
+using namespace mlir;
+
+#include "mlir/Interfaces/InferStridedMetadataInterface.cpp.inc"
+
+void StridedMetadataRange::print(raw_ostream &os) const {
+ if (isUninitialized()) {
+ os << "strided_metadata<None>";
+ return;
+ }
+ os << "strided_metadata<offset = [";
+ llvm::interleaveComma(*offsets, os, [&](const ConstantIntRanges &range) {
+ os << "{" << range << "}";
+ });
+ os << "], sizes = [";
+ llvm::interleaveComma(sizes, os, [&](const ConstantIntRanges &range) {
+ os << "{" << range << "}";
+ });
+ os << "], strides = [";
+ llvm::interleaveComma(strides, os, [&](const ConstantIntRanges &range) {
+ os << "{" << range << "}";
+ });
+ os << "]>";
+}
diff --git a/mlir/lib/Remark/RemarkStreamer.cpp b/mlir/lib/Remark/RemarkStreamer.cpp
index d213a1a..bf36286 100644
--- a/mlir/lib/Remark/RemarkStreamer.cpp
+++ b/mlir/lib/Remark/RemarkStreamer.cpp
@@ -60,6 +60,7 @@ void LLVMRemarkStreamer::finalize() {
namespace mlir::remark {
LogicalResult enableOptimizationRemarksWithLLVMStreamer(
MLIRContext &ctx, StringRef path, llvm::remarks::Format fmt,
+ std::unique_ptr<detail::RemarkEmittingPolicyBase> remarkEmittingPolicy,
const RemarkCategories &cat, bool printAsEmitRemarks) {
FailureOr<std::unique_ptr<detail::MLIRRemarkStreamerBase>> sOr =
@@ -67,7 +68,8 @@ LogicalResult enableOptimizationRemarksWithLLVMStreamer(
if (failed(sOr))
return failure();
- return remark::enableOptimizationRemarks(ctx, std::move(*sOr), cat,
+ return remark::enableOptimizationRemarks(ctx, std::move(*sOr),
+ std::move(remarkEmittingPolicy), cat,
printAsEmitRemarks);
}
diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
index 4bbcd8e..db39c70 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
@@ -34,11 +34,9 @@ Location DebugImporter::translateFuncLocation(llvm::Function *func) {
return UnknownLoc::get(context);
// Add a fused location to link the subprogram information.
- StringAttr funcName = StringAttr::get(context, subprogram->getName());
StringAttr fileName = StringAttr::get(context, subprogram->getFilename());
return FusedLocWith<DISubprogramAttr>::get(
- {NameLoc::get(funcName),
- FileLineColLoc::get(fileName, subprogram->getLine(), /*column=*/0)},
+ {FileLineColLoc::get(fileName, subprogram->getLine(), /*column=*/0)},
translate(subprogram), context);
}
diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
index 132be4e..51c6077 100644
--- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
+++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
@@ -956,7 +956,7 @@ inline parsed_inst_t ExpressionParser::buildNumericOp(
<< ", type = " << ty << " ***";
auto tysToPop = SmallVector<Type, numOperands>();
tysToPop.resize(numOperands);
- std::fill(tysToPop.begin(), tysToPop.end(), ty);
+ llvm::fill(tysToPop, ty);
auto operands = popOperands(tysToPop);
if (failed(operands))
return failure();
diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index c883baa..3236b4f 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -27,6 +27,7 @@
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/ScopedPrinter.h"
+#include "llvm/Support/VirtualFileSystem.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Parser.h"
#include <optional>
@@ -828,6 +829,7 @@ LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
llvm::SourceMgr tdSrcMgr;
tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());
+ tdSrcMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem());
// This class provides a context argument for the llvm::SourceMgr diagnostic
// handler.
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 30fd384..9ef405d 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -37,6 +37,7 @@
#include "llvm/ADT/StringRef.h"
#include "llvm/Remarks/RemarkFormat.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/ManagedStatic.h"
@@ -226,6 +227,18 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
"bitstream", "Print bitstream file")),
llvm::cl::cat(remarkCategory)};
+ static llvm::cl::opt<RemarkPolicy, /*ExternalStorage=*/true> remarkPolicy{
+ "remark-policy",
+ llvm::cl::desc("Specify the policy for remark output."),
+ cl::location(remarkPolicyFlag),
+ llvm::cl::value_desc("format"),
+ llvm::cl::init(RemarkPolicy::REMARK_POLICY_ALL),
+ llvm::cl::values(clEnumValN(RemarkPolicy::REMARK_POLICY_ALL, "all",
+ "Print all remarks"),
+ clEnumValN(RemarkPolicy::REMARK_POLICY_FINAL, "final",
+ "Print final remarks")),
+ llvm::cl::cat(remarkCategory)};
+
static cl::opt<std::string, /*ExternalStorage=*/true> remarksAll(
"remarks-filter",
cl::desc("Show all remarks: passed, missed, failed, analysis"),
@@ -517,18 +530,28 @@ performActions(raw_ostream &os,
return failure();
context->enableMultithreading(wasThreadingEnabled);
-
+ // Set the remark categories and policy.
remark::RemarkCategories cats{
config.getRemarksAllFilter(), config.getRemarksPassedFilter(),
config.getRemarksMissedFilter(), config.getRemarksAnalyseFilter(),
config.getRemarksFailedFilter()};
mlir::MLIRContext &ctx = *context;
+ // Helper to create the appropriate policy based on configuration
+ auto createPolicy = [&config]()
+ -> std::unique_ptr<mlir::remark::detail::RemarkEmittingPolicyBase> {
+ if (config.getRemarkPolicy() == RemarkPolicy::REMARK_POLICY_ALL)
+ return std::make_unique<mlir::remark::RemarkEmittingPolicyAll>();
+ if (config.getRemarkPolicy() == RemarkPolicy::REMARK_POLICY_FINAL)
+ return std::make_unique<mlir::remark::RemarkEmittingPolicyFinal>();
+
+ llvm_unreachable("Invalid remark policy");
+ };
switch (config.getRemarkFormat()) {
case RemarkFormat::REMARK_FORMAT_STDOUT:
if (failed(mlir::remark::enableOptimizationRemarks(
- ctx, nullptr, cats, true /*printAsEmitRemarks*/)))
+ ctx, nullptr, createPolicy(), cats, true /*printAsEmitRemarks*/)))
return failure();
break;
@@ -537,7 +560,7 @@ performActions(raw_ostream &os,
? "mlir-remarks.yaml"
: config.getRemarksOutputFile();
if (failed(mlir::remark::enableOptimizationRemarksWithLLVMStreamer(
- ctx, file, llvm::remarks::Format::YAML, cats)))
+ ctx, file, llvm::remarks::Format::YAML, createPolicy(), cats)))
return failure();
break;
}
@@ -547,7 +570,7 @@ performActions(raw_ostream &os,
? "mlir-remarks.bitstream"
: config.getRemarksOutputFile();
if (failed(mlir::remark::enableOptimizationRemarksWithLLVMStreamer(
- ctx, file, llvm::remarks::Format::Bitstream, cats)))
+ ctx, file, llvm::remarks::Format::Bitstream, createPolicy(), cats)))
return failure();
break;
}
@@ -593,6 +616,12 @@ performActions(raw_ostream &os,
AsmState asmState(op.get(), OpPrintingFlags(), /*locationMap=*/nullptr,
&fallbackResourceMap);
os << OpWithState(op.get(), asmState) << '\n';
+
+ // This is required if the remark policy is final. Otherwise, the remarks are
+ // not emitted.
+ if (remark::detail::RemarkEngine *engine = ctx.getRemarkEngine())
+ engine->getRemarkEmittingPolicy()->finalize();
+
return success();
}
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
index 60b9567..1dbe7eca 100644
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
@@ -31,6 +31,7 @@
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/LSP/Logging.h"
#include "llvm/Support/Path.h"
+#include "llvm/Support/VirtualFileSystem.h"
#include <optional>
using namespace mlir;
@@ -402,6 +403,7 @@ PDLDocument::PDLDocument(const llvm::lsp::URIForFile &uri, StringRef contents,
llvm::append_range(includeDirs, extraDirs);
sourceMgr.setIncludeDirs(includeDirs);
+ sourceMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem());
sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
astContext.getDiagEngine().setHandlerFn([&](const ast::Diagnostic &diag) {
diff --git a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp
index 3080b78..2d817be 100644
--- a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp
+++ b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp
@@ -17,6 +17,7 @@
#include "llvm/Support/LSP/Logging.h"
#include "llvm/Support/LSP/Protocol.h"
#include "llvm/Support/Path.h"
+#include "llvm/Support/VirtualFileSystem.h"
#include "llvm/TableGen/Parser.h"
#include "llvm/TableGen/Record.h"
#include <optional>
@@ -448,6 +449,7 @@ void TableGenTextFile::initialize(
return;
}
sourceMgr.setIncludeDirs(includeDirs);
+ sourceMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem());
sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
// This class provides a context argument for the SourceMgr diagnostic
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index 111f58e..5f3b04a 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -66,7 +66,9 @@ size_t mlir::moveLoopInvariantCode(
size_t numMoved = 0;
for (Region *region : regions) {
- LDBG() << "Original loop:\n" << *region->getParentOp();
+ LDBG() << "Original loop:\n"
+ << OpWithFlags(region->getParentOp(),
+ OpPrintingFlags().skipRegions());
std::queue<Operation *> worklist;
// Add top-level operations in the loop body to the worklist.
@@ -90,7 +92,8 @@ size_t mlir::moveLoopInvariantCode(
!canBeHoisted(op, definedOutside))
continue;
- LDBG() << "Moving loop-invariant op: " << *op;
+ LDBG() << "Moving loop-invariant op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
moveOutOfRegion(op, region);
++numMoved;
@@ -111,9 +114,7 @@ size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
[&](Value value, Region *) {
return loopLike.isDefinedOutsideOfLoop(value);
},
- [&](Operation *op, Region *) {
- return isMemoryEffectFree(op) && isSpeculatable(op);
- },
+ [&](Operation *op, Region *) { return isPure(op); },
[&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
}