Diffstat (limited to 'mlir/lib')
43 files changed, 1820 insertions, 276 deletions
diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
index 8062b474..a84d10d 100644
--- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
+++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
@@ -258,6 +258,39 @@ getAllocEffectFor(Value value,
   return success();
 }
 
+static Operation *isDistinctObjectsOp(Operation *op) {
+  if (op && op->hasTrait<OpTrait::DistinctObjectsTrait>())
+    return op;
+
+  return nullptr;
+}
+
+static Value getDistinctObjectsOperand(Operation *op, Value value) {
+  unsigned argNumber = cast<OpResult>(value).getResultNumber();
+  return op->getOperand(argNumber);
+}
+
+static std::optional<AliasResult> checkDistinctObjects(Value lhs, Value rhs) {
+  // We should have already checked that lhs and rhs are different.
+  assert(lhs != rhs && "lhs and rhs must be different");
+
+  // Result and corresponding operand must alias.
+  auto lhsOp = isDistinctObjectsOp(lhs.getDefiningOp());
+  if (lhsOp && getDistinctObjectsOperand(lhsOp, lhs) == rhs)
+    return AliasResult::MustAlias;
+
+  auto rhsOp = isDistinctObjectsOp(rhs.getDefiningOp());
+  if (rhsOp && getDistinctObjectsOperand(rhsOp, rhs) == lhs)
+    return AliasResult::MustAlias;
+
+  // If two different values come from the same `DistinctObjects` operation,
+  // they don't alias.
+  if (lhsOp && lhsOp == rhsOp)
+    return AliasResult::NoAlias;
+
+  return std::nullopt;
+}
+
 /// Given the two values, return their aliasing behavior.
 AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) {
   if (lhs == rhs)
@@ -289,6 +322,9 @@ AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) {
                            : AliasResult::MayAlias;
   }
 
+  if (std::optional<AliasResult> result = checkDistinctObjects(lhs, rhs))
+    return *result;
+
   // Otherwise, neither of the values are constant so check to see if either has
   // an allocation effect.
   bool lhsHasAlloc = succeeded(getAllocEffectFor(lhs, lhsAlloc, lhsAllocScope));
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index 609cb34..db10ebc 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -40,6 +40,7 @@ add_mlir_library(MLIRAnalysis
   DataFlow/IntegerRangeAnalysis.cpp
   DataFlow/LivenessAnalysis.cpp
   DataFlow/SparseAnalysis.cpp
+  DataFlow/StridedMetadataRangeAnalysis.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Analysis
@@ -53,6 +54,7 @@ add_mlir_library(MLIRAnalysis
   MLIRDataLayoutInterfaces
   MLIRFunctionInterfaces
   MLIRInferIntRangeInterface
+  MLIRInferStridedMetadataInterface
   MLIRInferTypeOpInterface
   MLIRLoopLikeInterface
   MLIRPresburger
diff --git a/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp
new file mode 100644
index 0000000..01c9daf
--- /dev/null
+++ b/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp
@@ -0,0 +1,127 @@
+//===- StridedMetadataRangeAnalysis.cpp - Strided metadata analysis -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the dataflow analysis class for strided metadata range
+// inference, which propagates integer ranges for the offsets, sizes, and
+// strides of ranked memrefs.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h"
+#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/DebugStringHelper.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
+
+#define DEBUG_TYPE "strided-metadata-range-analysis"
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+/// Get the entry state for a value. For any value that is not a ranked memref,
+/// this function sets the metadata to a top state with no offsets, sizes, or
+/// strides. For `memref` types, this function will use the metadata in the type
+/// to try to deduce as much information as possible.
+static StridedMetadataRange getEntryStateImpl(Value v, int32_t indexBitwidth) {
+  // TODO: generalize this method with a type interface.
+  auto mTy = dyn_cast<BaseMemRefType>(v.getType());
+
+  // If not a memref or it's unranked, don't infer any metadata.
+  if (!mTy || !mTy.hasRank())
+    return StridedMetadataRange::getMaxRanges(indexBitwidth, 0, 0, 0);
+
+  // Get the top state.
+  auto metadata =
+      StridedMetadataRange::getMaxRanges(indexBitwidth, mTy.getRank());
+
+  // Compute the offset and strides.
+  int64_t offset;
+  SmallVector<int64_t> strides;
+  if (failed(cast<MemRefType>(mTy).getStridesAndOffset(strides, offset)))
+    return metadata;
+
+  // Refine the metadata if we know it from the type.
+  if (!ShapedType::isDynamic(offset)) {
+    metadata.getOffsets()[0] =
+        ConstantIntRanges::constant(APInt(indexBitwidth, offset));
+  }
+  for (auto &&[size, range] :
+       llvm::zip_equal(mTy.getShape(), metadata.getSizes())) {
+    if (ShapedType::isDynamic(size))
+      continue;
+    range = ConstantIntRanges::constant(APInt(indexBitwidth, size));
+  }
+  for (auto &&[stride, range] :
+       llvm::zip_equal(strides, metadata.getStrides())) {
+    if (ShapedType::isDynamic(stride))
+      continue;
+    range = ConstantIntRanges::constant(APInt(indexBitwidth, stride));
+  }
+
+  return metadata;
+}
+
+StridedMetadataRangeAnalysis::StridedMetadataRangeAnalysis(
+    DataFlowSolver &solver, int32_t indexBitwidth)
+    : SparseForwardDataFlowAnalysis(solver), indexBitwidth(indexBitwidth) {
+  assert(indexBitwidth > 0 && "invalid bitwidth");
+}
+
+void StridedMetadataRangeAnalysis::setToEntryState(
+    StridedMetadataRangeLattice *lattice) {
+  propagateIfChanged(lattice, lattice->join(getEntryStateImpl(
+                                  lattice->getAnchor(), indexBitwidth)));
+}
+
+LogicalResult StridedMetadataRangeAnalysis::visitOperation(
+    Operation *op, ArrayRef<const StridedMetadataRangeLattice *> operands,
+    ArrayRef<StridedMetadataRangeLattice *> results) {
+  auto inferrable = dyn_cast<InferStridedMetadataOpInterface>(op);
+
+  // Bail if we cannot reason about the op.
+  if (!inferrable) {
+    setAllToEntryStates(results);
+    return success();
+  }
+
+  LDBG() << "Inferring metadata for: "
+         << OpWithFlags(op, OpPrintingFlags().skipRegions());
+
+  // Helper function to retrieve int range values.
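+  // If no integer range lattice entry is available for a value, fall back to
+  // an uninitialized (unknown) range.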
+  auto getIntRange = [&](Value value) -> IntegerValueRange {
+    auto lattice = getOrCreateFor<IntegerValueRangeLattice>(
+        getProgramPointAfter(op), value);
+    return lattice ? lattice->getValue() : IntegerValueRange();
+  };
+
+  // Convert the arguments lattices to a vector.
+  SmallVector<StridedMetadataRange> argRanges = llvm::map_to_vector(
+      operands, [](const StridedMetadataRangeLattice *lattice) {
+        return lattice->getValue();
+      });
+
+  // Callback to set metadata on a result.
+  auto joinCallback = [&](Value v, const StridedMetadataRange &md) {
+    auto result = cast<OpResult>(v);
+    assert(llvm::is_contained(op->getResults(), result));
+    LDBG() << "- Inferred metadata: " << md;
+    StridedMetadataRangeLattice *lattice = results[result.getResultNumber()];
+    ChangeResult changed = lattice->join(md);
+    LDBG() << "- Joined metadata: " << lattice->getValue();
+    propagateIfChanged(lattice, changed);
+  };
+
+  // Infer the metadata.
+  inferrable.inferStridedMetadataRanges(argRanges, getIntRange, joinCallback,
+                                        indexBitwidth);
+  return success();
+}
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 71986f8..bebf1b8 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -40,6 +40,7 @@ add_subdirectory(MathToLibm)
 add_subdirectory(MathToLLVM)
 add_subdirectory(MathToROCDL)
 add_subdirectory(MathToSPIRV)
+add_subdirectory(MathToXeVM)
 add_subdirectory(MemRefToEmitC)
 add_subdirectory(MemRefToLLVM)
 add_subdirectory(MemRefToSPIRV)
diff --git a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
new file mode 100644
index 0000000..050c0ed
--- /dev/null
+++ b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
@@ -0,0 +1,22 @@
+add_mlir_conversion_library(MLIRMathToXeVM
+  MathToXeVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToXeVM
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRArithAttrToLLVMConversion
+  MLIRArithDialect
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRMathDialect
+  MLIRXeVMDialect
+  MLIRPass
+  MLIRTransforms
+  )
diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
new file mode 100644
index 0000000..0fe31d0
--- /dev/null
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -0,0 +1,167 @@
+//===-- MathToXeVM.cpp - conversion from Math to XeVM ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
+#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/FormatVariadic.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMATHTOXEVM
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "math-to-xevm"
+
+/// Convert math ops marked with `fast` (`afn`) to native OpenCL intrinsics.
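+///
+/// For example (illustrative), an op such as
+///   %y = math.exp %x fastmath<afn> : f32
+/// becomes a call to the mangled native built-in
+///   llvm.call @_Z22__spirv_ocl_native_expf(%x) : (f32) -> f32
+/// with the fastmath flags carried over to the call.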
+template <typename Op>
+struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
+
+  ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc,
+                           PatternBenefit benefit = 1)
+      : OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {}
+
+  LogicalResult
+  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!isSPIRVCompatibleFloatOrVec(op.getType()))
+      return failure();
+
+    arith::FastMathFlags fastFlags = op.getFastmath();
+    if (!arith::bitEnumContainsAll(fastFlags, arith::FastMathFlags::afn))
+      return rewriter.notifyMatchFailure(op, "not a fastmath `afn` operation");
+
+    SmallVector<Type, 1> operandTypes;
+    for (auto operand : adaptor.getOperands()) {
+      Type opTy = operand.getType();
+      // This pass only supports operations on vectors that are already in
+      // SPIRV-supported vector sizes: distributing unsupported vector sizes to
+      // SPIRV-supported vector sizes is done in other blocking optimization
+      // passes.
+      if (!isSPIRVCompatibleFloatOrVec(opTy))
+        return rewriter.notifyMatchFailure(
+            op, llvm::formatv("incompatible operand type: '{0}'", opTy));
+      operandTypes.push_back(opTy);
+    }
+
+    auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
+    auto funcOpRes = LLVM::lookupOrCreateFn(
+        rewriter, moduleOp, getMangledNativeFuncName(operandTypes),
+        operandTypes, op.getType());
+    assert(!failed(funcOpRes));
+    LLVM::LLVMFuncOp funcOp = funcOpRes.value();
+
+    auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+        op, funcOp, adaptor.getOperands());
+    // Preserve fastmath flags in our MLIR op when converting to llvm function
+    // calls, in order to allow further fastmath optimizations: we thus need to
+    // convert arith fastmath attrs into attrs recognized by llvm.
+    arith::AttrConvertFastMathToLLVM<Op, LLVM::CallOp> fastAttrConverter(op);
+    mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0];
+    callOp->setAttr(fastAttr.getName(), fastAttr.getValue());
+    return success();
+  }
+
+  inline bool isSPIRVCompatibleFloatOrVec(Type type) const {
+    if (type.isFloat())
+      return true;
+    if (auto vecType = dyn_cast<VectorType>(type)) {
+      if (!vecType.getElementType().isFloat())
+        return false;
+      // SPIRV distinguishes between vectors and matrices: OpenCL native math
+      // intrinsics are not compatible with matrices.
+      ArrayRef<int64_t> shape = vecType.getShape();
+      if (shape.size() != 1)
+        return false;
+      // SPIRV only allows vectors of size 2, 3, 4, 8, 16.
+      if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 ||
+          shape[0] == 16)
+        return true;
+    }
+    return false;
+  }
+
+  inline std::string
+  getMangledNativeFuncName(const ArrayRef<Type> operandTypes) const {
+    std::string mangledFuncName =
+        "_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str();
+
+    auto appendFloatToMangledFunc = [&mangledFuncName](Type type) {
+      if (type.isF32())
+        mangledFuncName += "f";
+      else if (type.isF16())
+        mangledFuncName += "Dh";
+      else if (type.isF64())
+        mangledFuncName += "d";
+    };
+
+    for (auto type : operandTypes) {
+      if (auto vecType = dyn_cast<VectorType>(type)) {
+        mangledFuncName += "Dv" + std::to_string(vecType.getShape()[0]) + "_";
+        appendFloatToMangledFunc(vecType.getElementType());
+      } else
+        appendFloatToMangledFunc(type);
+    }
+
+    return mangledFuncName;
+  }
+
+  const StringRef nativeFunc;
+};
+
+void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns,
+                                                bool convertArith) {
+  patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(patterns.getContext(),
+                                                      "__spirv_ocl_native_exp");
+  patterns.add<ConvertNativeFuncPattern<math::CosOp>>(patterns.getContext(),
+                                                      "__spirv_ocl_native_cos");
+  patterns.add<ConvertNativeFuncPattern<math::Exp2Op>>(
+      patterns.getContext(), "__spirv_ocl_native_exp2");
+  patterns.add<ConvertNativeFuncPattern<math::LogOp>>(patterns.getContext(),
+                                                      "__spirv_ocl_native_log");
+  patterns.add<ConvertNativeFuncPattern<math::Log2Op>>(
+      patterns.getContext(), "__spirv_ocl_native_log2");
+  patterns.add<ConvertNativeFuncPattern<math::Log10Op>>(
+      patterns.getContext(), "__spirv_ocl_native_log10");
+  patterns.add<ConvertNativeFuncPattern<math::PowFOp>>(
+      patterns.getContext(), "__spirv_ocl_native_powr");
+  patterns.add<ConvertNativeFuncPattern<math::RsqrtOp>>(
+      patterns.getContext(), "__spirv_ocl_native_rsqrt");
+  patterns.add<ConvertNativeFuncPattern<math::SinOp>>(patterns.getContext(),
+                                                      "__spirv_ocl_native_sin");
+  patterns.add<ConvertNativeFuncPattern<math::SqrtOp>>(
+      patterns.getContext(), "__spirv_ocl_native_sqrt");
+  patterns.add<ConvertNativeFuncPattern<math::TanOp>>(patterns.getContext(),
+                                                      "__spirv_ocl_native_tan");
+  if (convertArith)
+    patterns.add<ConvertNativeFuncPattern<arith::DivFOp>>(
+        patterns.getContext(), "__spirv_ocl_native_divide");
+}
+
+namespace {
+struct ConvertMathToXeVMPass
+    : public impl::ConvertMathToXeVMBase<ConvertMathToXeVMPass> {
+  using Base::Base;
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertMathToXeVMPass::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  populateMathToXeVMConversionPatterns(patterns, convertArith);
+  ConversionTarget target(getContext());
+  target.addLegalDialect<BuiltinDialect, LLVM::LLVMDialect>();
+  if (failed(
+          applyPartialConversion(getOperation(), target, std::move(patterns))))
+    signalPassFailure();
+}
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index a5336ed..00df14b1 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1392,6 +1392,137 @@ public:
   }
 };
 
+// Collapse tensor<1xiN> into tensor<iN>
+// E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
+static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
+                                  Location loc) {
+  SmallVector<ReassociationExprs, 1> reassociation;
+  // Create the collapsed type.
+  auto inputType = cast<RankedTensorType>(input.getType());
+  auto elemType = inputType.getElementType();
+  auto collapsedType = RankedTensorType::get({}, elemType);
+  // Emit the collapse op.
+  return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, input,
+                                                  reassociation);
+}
+
+static llvm::SmallVector<int8_t>
+convertToI8(const llvm::SmallVector<int32_t> &input) {
+  llvm::SmallVector<int8_t> output;
+  output.reserve(input.size());
+
+  for (auto v : llvm::map_range(
+           input, [](int32_t val) { return static_cast<int8_t>(val); })) {
+    output.push_back(v);
+  }
+  return output;
+}
+
+// The shift or multiplier may be either constant or non-constant, depending on
+// whether dynamic extension is enabled.
+// - If the shift or multiplier is non-constant, add it as an input to
+//   linalg::GenericOp by:
+//   1. Pushing it into 'genericInputs'.
+//   2. Appending a corresponding affine map to 'indexingMaps'.
+// - If the shift or multiplier is constant, set 'constant' instead.
+static void setupLinalgGenericOpInputAndIndexingMap(
+    PatternRewriter &rewriter, llvm::SmallVector<int32_t> &values,
+    SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
+    bool isConstant, tosa::RescaleOp op, Value &constant, int64_t &arg,
+    bool isShift = false) {
+
+  auto loc = op.getLoc();
+  auto inputTy = cast<ShapedType>(op.getInput().getType());
+  unsigned rank = inputTy.getRank();
+  SmallVector<AffineExpr, 2> exprs = {rewriter.getAffineDimExpr(rank - 1)};
+
+  if (isConstant) {
+    // If we are rescaling per-channel then we need to store the
+    // values in a buffer.
+    if (values.size() == 1) {
+      IntegerAttr intAttr = isShift
+                                ? rewriter.getI8IntegerAttr(values.front())
+                                : rewriter.getI32IntegerAttr(values.front());
+      constant = rewriter.create<arith::ConstantOp>(loc, intAttr);
+    } else {
+      auto elementType =
+          isShift ? rewriter.getIntegerType(8) : rewriter.getI32Type();
+      auto tensorType = RankedTensorType::get(
+          {static_cast<int64_t>(values.size())}, elementType);
+      DenseIntElementsAttr EltAttr;
+      if (isShift)
+        EltAttr = DenseIntElementsAttr::get(tensorType, convertToI8(values));
+      else
+        EltAttr = DenseIntElementsAttr::get(tensorType, values);
+      genericInputs.push_back(
+          arith::ConstantOp::create(rewriter, loc, EltAttr));
+      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                            /*symbolCount=*/0, exprs,
+                                            rewriter.getContext()));
+    }
+  } else {
+    // If we are not rescaling per-channel then we need to collapse 1xN to N
+    // and push broadcastMap.
+    auto operand = isShift ? op.getShift() : op.getMultiplier();
+    auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
+    if (tensorType && tensorType.hasStaticShape() &&
+        tensorType.getShape()[0] == 1) {
+      // broadcastMap = affine_map<(d0, d1) -> ()>
+      // It acts as a broadcast for scalar values in linalg::GenericOp.
+      AffineMap broadcastMap =
+          AffineMap::get(rank, 0, {}, rewriter.getContext());
+      genericInputs.push_back(collapse1xNTensorToN(rewriter, operand, loc));
+      indexingMaps.push_back(broadcastMap);
+    } else {
+      genericInputs.push_back(operand);
+      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                            /*symbolCount=*/0, exprs,
+                                            rewriter.getContext()));
+    }
+  }
+  arg = indexingMaps.size() - 1;
+}
+
+// Return the extended Zp to be used in subsequent arithmetic operations.
+static Value getExtendZp(OpBuilder &builder, Type valueTy,
+                         FailureOr<int64_t> maybeZp, Location loc,
+                         ValueRange blockArgs, int64_t zpArg,
+                         bool isOutputZp = false) {
+  Value result;
+  const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
+  const uint32_t attrBitwidth =
+      isOutputZp ? 32 : (bitwidth > 32 ? bitwidth : 32);
+  auto extendType = builder.getIntegerType(attrBitwidth);
+  // The Zp value can be either constant or non-constant, depending on
+  // whether dynamic extension is enabled.
+  // If 'maybeZp' fails, it indicates that Zp is non-constant and will
+  // be passed as an input to linalg::GenericOp.
+  if (failed(maybeZp)) {
+    result = blockArgs[zpArg];
+    auto zpTy = result.getType();
+    if (zpTy.getIntOrFloatBitWidth() < attrBitwidth) {
+      // For ExtUIOp, the input must be signless.
+      // UnrealizedConversionCastOp will cast the input to signless type.
+      if (zpTy.isUnsignedInteger()) {
+        result =
+            UnrealizedConversionCastOp::create(
+                builder, loc,
+                builder.getIntegerType(zpTy.getIntOrFloatBitWidth()), result)
+                .getResult(0);
+      }
+      if (zpTy.isUnsignedInteger()) {
+        return builder.create<arith::ExtUIOp>(loc, extendType, result);
+      } else {
+        return builder.create<arith::ExtSIOp>(loc, extendType, result);
+      }
+    }
+  } else {
+    return builder.create<arith::ConstantOp>(
+        loc, IntegerAttr::get(extendType, *maybeZp));
+  }
+  return result;
+}
+
 class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
 public:
   using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
 
@@ -1423,40 +1554,46 @@ public:
     }
   }
 
-    // The shift and multiplier values.
     DenseElementsAttr shiftElems;
-    if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
-      return rewriter.notifyMatchFailure(
-          op, "tosa.rescale requires constant shift input values");
+    bool isShiftConstant = false;
+    if (matchPattern(op.getShift(), m_Constant(&shiftElems)))
+      isShiftConstant = true;
 
     DenseElementsAttr multiplierElems;
-    if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
-      return rewriter.notifyMatchFailure(
-          op, "tosa.rescale requires constant multiplier input values");
-
-    llvm::SmallVector<int8_t> shiftValues =
-        llvm::to_vector(shiftElems.getValues<int8_t>());
-    // explicit cast is required here
-    llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
-        llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
-                        [](IntegerAttr attr) -> int32_t {
-                          return static_cast<int32_t>(attr.getInt());
-                        }));
-
-    // If we shift by more than the bitwidth, this just sets to 0.
-    for (int i = 0, s = multiplierValues.size(); i < s; i++) {
-      if (shiftValues[i] > 63) {
-        shiftValues[i] = 0;
-        multiplierValues[i] = 0;
+    bool isMultiplierConstant = false;
+    if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
+      isMultiplierConstant = true;
+
+    llvm::SmallVector<int32_t> shiftValues;
+    llvm::SmallVector<int32_t> multiplierValues;
+    bool doubleRound;
+
+    if (isMultiplierConstant && isShiftConstant) {
+      // explicit cast is required here
+      shiftValues = llvm::to_vector(llvm::map_range(
+          shiftElems.getValues<IntegerAttr>(), [](IntegerAttr attr) -> int32_t {
+            return static_cast<int32_t>(attr.getInt());
+          }));
+      multiplierValues = llvm::to_vector(
+          llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
+                          [](IntegerAttr attr) -> int32_t {
+                            return static_cast<int32_t>(attr.getInt());
+                          }));
+
+      // If we shift by more than the bitwidth, this just sets to 0.
+      for (int i = 0, s = multiplierValues.size(); i < s; i++) {
+        if (shiftValues[i] > 63) {
+          shiftValues[i] = 0;
+          multiplierValues[i] = 0;
+        }
       }
-    }
 
+      // Double round only occurs if shift is greater than 31; check whether
+      // this is ever true.
+      doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
+                    llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
+    } else
+      doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND;
 
-    // Double round only occurs if shift is greater than 31, check that this
-    // is ever true.
-
-    bool doubleRound =
-        op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
-        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
     RoundingMode roundingMode =
         doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;
 
@@ -1468,45 +1605,43 @@ public:
     // values in a buffer.
     Value multiplierConstant;
     int64_t multiplierArg = 0;
-    if (multiplierValues.size() == 1) {
-      multiplierConstant = arith::ConstantOp::create(
-          rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
-    } else {
-      SmallVector<AffineExpr, 2> multiplierExprs{
-          rewriter.getAffineDimExpr(rank - 1)};
-      auto multiplierType =
-          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
-                                rewriter.getI32Type());
-      genericInputs.push_back(arith::ConstantOp::create(
-          rewriter, loc,
-          DenseIntElementsAttr::get(multiplierType, multiplierValues)));
-
-      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
-                                            /*symbolCount=*/0, multiplierExprs,
-                                            rewriter.getContext()));
-
-      multiplierArg = indexingMaps.size() - 1;
-    }
+    setupLinalgGenericOpInputAndIndexingMap(
+        rewriter, multiplierValues, genericInputs, indexingMaps,
+        isMultiplierConstant, op, multiplierConstant, multiplierArg);
 
     // If we are rescaling per-channel then we need to store the shift
     // values in a buffer.
    Value shiftConstant;
    int64_t shiftArg = 0;
-    if (shiftValues.size() == 1) {
-      shiftConstant = arith::ConstantOp::create(
-          rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front()));
-    } else {
-      SmallVector<AffineExpr, 2> shiftExprs = {
-          rewriter.getAffineDimExpr(rank - 1)};
-      auto shiftType =
-          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
-                                rewriter.getIntegerType(8));
-      genericInputs.push_back(arith::ConstantOp::create(
-          rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
-      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
-                                            /*symbolCount=*/0, shiftExprs,
-                                            rewriter.getContext()));
-      shiftArg = indexingMaps.size() - 1;
+    setupLinalgGenericOpInputAndIndexingMap(
+        rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
+        shiftConstant, shiftArg, true);
+
+    // broadcastMap = affine_map<(d0, d1) -> ()>
+    // It acts as a broadcast for scalar values in linalg::GenericOp.
+    AffineMap broadcastMap = AffineMap::get(rank, 0, {}, rewriter.getContext());
+    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
+    FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
+    // The inputZp and outputZp may be either constant or non-constant,
+    // depending on whether dynamic extension is enabled.
+    // - If the zp's are non-constant, add them as inputs to
+    //   linalg::GenericOp by:
+    //   1. Pushing them into 'genericInputs'.
+    //   2. Appending a corresponding affine map to 'indexingMaps'.
+    // - If the zp's are constant, they are generated as arith.constant ops.
+    int64_t iZpArg = 0;
+    if (failed(maybeIZp)) {
+      genericInputs.push_back(
+          collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
+      indexingMaps.push_back(broadcastMap);
+      iZpArg = indexingMaps.size() - 1;
+    }
+    int64_t oZpArg = 0;
+    if (failed(maybeOZp)) {
+      genericInputs.push_back(
+          collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
+      indexingMaps.push_back(broadcastMap);
+      oZpArg = indexingMaps.size() - 1;
    }

    // Indexing maps for output values.
@@ -1526,36 +1661,17 @@ public:
          Type valueTy = value.getType();

          FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
-          if (failed(maybeIZp)) {
-            (void)rewriter.notifyMatchFailure(
-                op, "input zero point cannot be statically determined");
-            return;
-          }
-
-          const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
-          // Extend zeropoint for sub-32bits widths.
-          const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
-          auto inputZp = arith::ConstantOp::create(
-              nestedBuilder, loc,
-              IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
-                               *maybeIZp));
+          auto inputZp = getExtendZp(nestedBuilder, valueTy, maybeIZp,
+                                     nestedLoc, blockArgs, iZpArg);

          FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
-          if (failed(maybeOZp)) {
-            (void)rewriter.notifyMatchFailure(
-                op, "output zero point cannot be statically determined");
-            return;
-          };
+          auto outputZp = getExtendZp(nestedBuilder, valueTy, maybeOZp,
+                                      nestedLoc, blockArgs, oZpArg, true);

          IntegerType outIntType =
              cast<IntegerType>(blockArgs.back().getType());
          unsigned outBitWidth = outIntType.getWidth();
-          const int32_t outAttrBitwidth = 32;
          assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");

-          auto outputZp = arith::ConstantOp::create(
-              nestedBuilder, loc,
-              IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
-                               *maybeOZp));
          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 5355909..41d8d53 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1723,17 +1723,18 @@ struct VectorBroadcastScalarToLowRankLowering
       return success();
     }
 
-    // For 1-d vector, we additionally do a `vectorshuffle`.
     auto v = LLVM::InsertElementOp::create(rewriter, broadcast.getLoc(),
                                            vectorType, poison,
                                            adaptor.getSource(), zero);
 
+    // For 1-d vector, we additionally do a `shufflevector`.
     int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0);
     SmallVector<int32_t> zeroValues(width, 0);
 
     // Shuffle the value across the desired number of elements.
-    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(broadcast, v, poison,
-                                                       zeroValues);
+    auto shuffle = rewriter.createOrFold<LLVM::ShuffleVectorOp>(
+        broadcast.getLoc(), v, poison, zeroValues);
+    rewriter.replaceOp(broadcast, shuffle);
     return success();
   }
 };
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 71687b1..ddcbc44 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
@@ -390,7 +391,8 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
     // Load result or store value type can be vector or scalar.
     Type valOrResTy;
     if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
-      valOrResTy = op.getResult().getType();
+      valOrResTy =
+          this->getTypeConverter()->convertType(op.getResult().getType());
     else
       valOrResTy = adaptor.getValue().getType();
     VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
@@ -878,10 +880,30 @@ struct ConvertXeGPUToXeVMPass
       }
       return {};
     };
-    typeConverter.addSourceMaterialization(memrefMaterializationCast);
-    typeConverter.addSourceMaterialization(ui64MaterializationCast);
-    typeConverter.addSourceMaterialization(ui32MaterializationCast);
-    typeConverter.addSourceMaterialization(vectorMaterializationCast);
+
+    // If the result type of the original op is a single-element vector and the
+    // lowered type is a scalar, this materialization cast creates a
+    // single-element vector by broadcasting the scalar value.
+    auto singleElementVectorMaterializationCast =
+        [](OpBuilder &builder, Type type, ValueRange inputs,
+           Location loc) -> Value {
+      if (inputs.size() != 1)
+        return {};
+      auto input = inputs.front();
+      if (input.getType().isIntOrIndexOrFloat()) {
+        // If the input is a scalar, and the target type is a single-element
+        // vector, create a single-element vector by broadcasting.
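+        // E.g. an i32 scalar is materialized as vector<1xi32> via
+        // vector.broadcast.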
+        if (auto vecTy = dyn_cast<VectorType>(type)) {
+          if (vecTy.getNumElements() == 1) {
+            return vector::BroadcastOp::create(builder, loc, vecTy, input)
+                .getResult();
+          }
+        }
+      }
+      return {};
+    };
+    typeConverter.addSourceMaterialization(
+        singleElementVectorMaterializationCast);
     typeConverter.addTargetMaterialization(memrefMaterializationCast);
     typeConverter.addTargetMaterialization(ui32MaterializationCast);
     typeConverter.addTargetMaterialization(ui64MaterializationCast);
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index 68990ef..d9c097c 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -80,10 +80,22 @@ static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType,
           LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)};
 }
 
+/// Returns the stride in bytes for the given `elementStride`, which is
+/// expressed in number of elements of the `mType` element type.
+static Value computeStrideInBytes(Location loc, MemRefType mType,
+                                  Value elementStride, RewriterBase &rewriter) {
+  Type llvmInt64Type = rewriter.getIntegerType(64);
+  unsigned bytes = mType.getElementType().getIntOrFloatBitWidth() / 8;
+  auto attr = rewriter.getI64IntegerAttr(bytes);
+  Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
+  return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, elementStride)
+      .getResult();
+}
+
 /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
 /// shape may "envelop" the actual tile shape, and may be dynamically sized.
-static Value getStride(Location loc, MemRefType mType, Value base,
-                       RewriterBase &rewriter) {
+static Value inferStride(Location loc, MemRefType mType, Value base,
+                         RewriterBase &rewriter) {
   assert(mType.getRank() >= 2 && "Invalid shape for AMX strides");
   int64_t preLast = mType.getRank() - 2;
   Type llvmInt64Type = rewriter.getIntegerType(64);
@@ -94,11 +106,8 @@ static Value getStride(Location loc, MemRefType mType, Value base,
   if (strides[preLast] == ShapedType::kDynamic) {
     // Dynamic stride needs code to compute the stride at runtime.
     MemRefDescriptor memrefDescriptor(base);
-    auto attr = rewriter.getI64IntegerAttr(bytes);
-    Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
-    return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale,
-                               memrefDescriptor.stride(rewriter, loc, preLast))
-        .getResult();
+    return computeStrideInBytes(
+        loc, mType, memrefDescriptor.stride(rewriter, loc, preLast), rewriter);
   }
   // Use direct constant for static stride.
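  // E.g. for memref<?x32xf16> the pre-last stride is 32 elements, i.e.
  // 32 * 2 = 64 bytes.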
   auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
@@ -117,21 +126,39 @@ amx::TileZeroOp::getIntrinsicOperands(ArrayRef<Value> operands,
   return getTileSizes(getLoc(), getTileType(), rewriter);
 }
 
-LogicalResult amx::TileLoadOp::verify() {
-  MemRefType memrefTy = getMemRefType();
+template <typename OpTy,
+          typename = std::enable_if_t<std::is_same_v<OpTy, amx::TileLoadOp> ||
+                                      std::is_same_v<OpTy, amx::TileStoreOp>>>
+static LogicalResult tileTransferVerifier(OpTy op) {
+  MemRefType memrefTy = op.getMemRefType();
   unsigned rank = memrefTy.getRank();
-  if (rank < 2)
-    return emitOpError("requires at least 2D memref");
-  if (getIndices().size() != rank)
-    return emitOpError("requires ") << rank << " indices";
-  SmallVector<int64_t> strides;
-  int64_t offset;
-  if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
-      strides.back() != 1)
-    return emitOpError("requires memref with unit innermost stride");
-  return verifyTileSize(*this, getTileType());
+  if (op.getIndices().size() != rank)
+    return op.emitOpError("requires ") << rank << " indices";
+
+  if (failed(verifyTileSize(op, op.getTileType())))
+    return failure();
+
+  // Validate basic buffer properties when the stride is implicit.
+  if (!op.getStride()) {
+    if (rank < 2)
+      return op.emitOpError("requires at least 2D memref");
+    SmallVector<int64_t> strides;
+    int64_t offset;
+    if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
+        strides.back() != 1)
+      return op.emitOpError("requires memref with unit innermost stride");
+  }
+
+  return success();
+}
+
+void amx::TileLoadOp::build(OpBuilder &builder, OperationState &state, Type res,
+                            Value base, ValueRange indices) {
+  build(builder, state, res, base, indices, /*stride=*/nullptr);
 }
 
+LogicalResult amx::TileLoadOp::verify() { return tileTransferVerifier(*this); }
+
 SmallVector<Value>
 amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands,
                                       const LLVMTypeConverter &typeConverter,
@@ -144,27 +171,23 @@ amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands,
   intrinsicOperands.push_back(
       LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
                                  adaptor.getBase(), adaptor.getIndices()));
-  intrinsicOperands.push_back(
-      getStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
+  if (Value stride = adaptor.getStride())
+    intrinsicOperands.push_back(
+        computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
+  else
+    intrinsicOperands.push_back(
+        inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
 
   return intrinsicOperands;
 }
 
-LogicalResult amx::TileStoreOp::verify() {
-  MemRefType memrefTy = getMemRefType();
-  unsigned rank = memrefTy.getRank();
-  if (rank < 2)
-    return emitOpError("requires at least 2D memref");
-  if (getIndices().size() != rank)
-    return emitOpError("requires ") << rank << " indices";
-  SmallVector<int64_t> strides;
-  int64_t offset;
-  if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
-      strides.back() != 1)
-    return emitOpError("requires memref with unit innermost stride");
-  return verifyTileSize(*this, getTileType());
+void amx::TileStoreOp::build(OpBuilder &builder, OperationState &state,
+                             Value base, ValueRange indices, Value val) {
+  build(builder, state, base, indices, val, /*stride=*/nullptr);
 }
 
+LogicalResult amx::TileStoreOp::verify() { return tileTransferVerifier(*this); }
+
 SmallVector<Value>
 amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands,
                                        const LLVMTypeConverter &typeConverter,
   intrinsicOperands.push_back(
       LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
                                  adaptor.getBase(), adaptor.getIndices()));
-  intrinsicOperands.push_back(
-      getStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
+  if (Value stride = adaptor.getStride())
+    intrinsicOperands.push_back(
+        computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
+  else
+    intrinsicOperands.push_back(
+        inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
   intrinsicOperands.push_back(adaptor.getVal());
 
   return intrinsicOperands;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
index 624519f..70faa71 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
@@ -64,12 +64,13 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
   module.walk([&](func::CallOp callOp) {
     if (func::FuncOp calledFunc =
             dyn_cast_or_null<func::FuncOp>(callOp.resolveCallable())) {
-      callerMap[calledFunc].insert(callOp);
+      if (!calledFunc.isPublic() && !calledFunc.isExternal())
+        callerMap[calledFunc].insert(callOp);
     }
   });
 
   for (auto funcOp : module.getOps<func::FuncOp>()) {
-    if (funcOp.isExternal())
+    if (funcOp.isExternal() || funcOp.isPublic())
       continue;
     func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
     // TODO: Support functions with multiple blocks.
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index ec581ac..cc66fac 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -8,11 +8,13 @@ add_mlir_dialect_library(MLIRLLVMDialect
   IR/LLVMMemorySlot.cpp
   IR/LLVMTypes.cpp
   IR/LLVMTypeSyntax.cpp
+  IR/LLVMDialectBytecode.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR
 
   DEPENDS
+  MLIRLLVMDialectBytecodeIncGen
  MLIRLLVMOpsIncGen
  MLIRLLVMTypesIncGen
  MLIRLLVMIntrinsicOpsIncGen
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 5d08ccc..3eae67f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -29,6 +29,8 @@
 #include "llvm/IR/DataLayout.h"
 #include "llvm/Support/Error.h"
 
+#include "LLVMDialectBytecode.h"
+
 #include <numeric>
 #include <optional>
 
@@ -2824,6 +2826,20 @@ LogicalResult ShuffleVectorOp::verify() {
   return success();
 }
 
+// Fold a shufflevector op when v1 is a single-element 1-D vector and the mask
+// is a single zero; the fold result is v1.
+OpFoldResult ShuffleVectorOp::fold(FoldAdaptor adaptor) {
+  // Check if operand 0 is a single element vector.
+  auto vecType = llvm::dyn_cast<VectorType>(getV1().getType());
+  if (!vecType || vecType.getRank() != 1 || vecType.getNumElements() != 1)
+    return {};
+  // Check if the mask is a single zero.
+  // Note: The mask is guaranteed to be non-empty.
+  if (getMask().size() != 1 || getMask()[0] != 0)
+    return {};
+  return getV1();
+}
+
 //===----------------------------------------------------------------------===//
 // Implementations for LLVM::LLVMFuncOp.
 //===----------------------------------------------------------------------===//
@@ -4237,6 +4253,7 @@ void LLVMDialect::initialize() {
   // Support unknown operations because not all LLVM operations are registered.
   allowUnknownOperations();
   declarePromisedInterface<DialectInlinerInterface, LLVMDialect>();
+  detail::addBytecodeInterface(this);
 }
 
 #define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp
new file mode 100644
index 0000000..41d1f80
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp
@@ -0,0 +1,154 @@
+//===- LLVMDialectBytecode.cpp - LLVM Bytecode Implementation -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "LLVMDialectBytecode.h"
+#include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <type_traits>
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+namespace {
+
+// Provide some forward declarations of the functions that will be generated by
+// the include below.
+static void write(DIExpressionElemAttr attribute,
+                  DialectBytecodeWriter &writer);
+static LogicalResult writeAttribute(Attribute attribute,
+                                    DialectBytecodeWriter &writer);
+
+//===--------------------------------------------------------------------===//
+// Optional ArrayRefs
+//
+// Note that both the writer and reader functions treat the array as optional:
+// an empty array is encoded as a single "absent" flag.
+//===--------------------------------------------------------------------===//
+
+template <class EntryTy>
+static void writeOptionalArrayRef(DialectBytecodeWriter &writer,
+                                  ArrayRef<EntryTy> storage) {
+  if (storage.empty()) {
+    writer.writeOwnedBool(false);
+    return;
+  }
+
+  writer.writeOwnedBool(true);
+  writer.writeList(storage, [&](EntryTy val) {
+    if constexpr (std::is_base_of_v<Attribute, EntryTy>) {
+      (void)writer.writeOptionalAttribute(val);
+    } else if constexpr (std::is_integral_v<EntryTy>) {
+      (void)writer.writeVarInt(val);
+    } else {
+      static_assert(sizeof(EntryTy) == 0, "EntryTy not supported");
+    }
+  });
+}
+
+template <class EntryTy>
+static LogicalResult readOptionalArrayRef(DialectBytecodeReader &reader,
+                                          SmallVectorImpl<EntryTy> &storage) {
+  bool isPresent = false;
+  if (failed(reader.readBool(isPresent)))
+    return failure();
+  // Nothing to do here, the array is empty.
+  if (!isPresent)
+    return success();
+
+  auto readEntry = [&]() -> FailureOr<EntryTy> {
+    EntryTy temp;
+    if constexpr (std::is_base_of_v<Attribute, EntryTy>) {
+      if (succeeded(reader.readOptionalAttribute(temp)))
+        return temp;
+    } else if constexpr (std::is_integral_v<EntryTy>) {
+      if (succeeded(reader.readVarInt(temp)))
+        return temp;
+    } else {
+      static_assert(sizeof(EntryTy) == 0, "EntryTy not supported");
+    }
+    return failure();
+  };
+
+  return reader.readList(storage, readEntry);
+}
+
+//===--------------------------------------------------------------------===//
+// Optional integral types
+//===--------------------------------------------------------------------===//
+
+template <class EntryTy>
+static void writeOptionalInt(DialectBytecodeWriter &writer,
+                             std::optional<EntryTy> storage) {
+  static_assert(std::is_integral_v<EntryTy>,
+                "EntryTy must be an integral type");
+  EntryTy val = storage.value_or(0);
+  writer.writeVarIntWithFlag(val, storage.has_value());
+}
+
+template <class EntryTy>
+static LogicalResult readOptionalInt(DialectBytecodeReader &reader,
+                                     std::optional<EntryTy> &storage) {
+  static_assert(std::is_integral_v<EntryTy>,
+                "EntryTy must be an integral type");
+  uint64_t result = 0;
+  bool flag = false;
+  if (failed(reader.readVarIntWithFlag(result, flag)))
+    return failure();
+  if (flag)
+    storage = static_cast<EntryTy>(result);
+  else
+    storage = std::nullopt;
+  return success();
+}
+
+//===--------------------------------------------------------------------===//
+// Tablegen generated bytecode functions
+//===--------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/LLVMDialectBytecode.cpp.inc"
+
+//===--------------------------------------------------------------------===//
+// LLVMDialectBytecodeInterface
+//===--------------------------------------------------------------------===//
+
+/// This class implements the bytecode interface for the LLVM dialect.
+struct LLVMDialectBytecodeInterface : public BytecodeDialectInterface {
+  LLVMDialectBytecodeInterface(Dialect *dialect)
+      : BytecodeDialectInterface(dialect) {}
+
+  // Attributes
+  Attribute readAttribute(DialectBytecodeReader &reader) const override {
+    return ::readAttribute(getContext(), reader);
+  }
+
+  LogicalResult writeAttribute(Attribute attr,
+                               DialectBytecodeWriter &writer) const override {
+    return ::writeAttribute(attr, writer);
+  }
+
+  // Types
+  Type readType(DialectBytecodeReader &reader) const override {
+    return ::readType(getContext(), reader);
+  }
+
+  LogicalResult writeType(Type type,
+                          DialectBytecodeWriter &writer) const override {
+    return ::writeType(type, writer);
+  }
+};
+} // namespace
+
+void LLVM::detail::addBytecodeInterface(LLVMDialect *dialect) {
+  dialect->addInterfaces<LLVMDialectBytecodeInterface>();
+}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h
new file mode 100644
index 0000000..1a17cb4
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h
@@ -0,0 +1,27 @@
+//===- LLVMDialectBytecode.h - LLVM Bytecode Implementation -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header defines hooks into the LLVM dialect bytecode
+// implementation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LIB_MLIR_DIALECT_LLVM_IR_LLVMDIALECTBYTECODE_H
+#define LIB_MLIR_DIALECT_LLVM_IR_LLVMDIALECTBYTECODE_H
+
+namespace mlir::LLVM {
+class LLVMDialect;
+
+namespace detail {
+/// Add the interfaces necessary for encoding the LLVM dialect components in
+/// bytecode.
+void addBytecodeInterface(LLVMDialect *dialect);
+} // namespace detail
+} // namespace mlir::LLVM
+
+#endif // LIB_MLIR_DIALECT_LLVM_IR_LLVMDIALECTBYTECODE_H
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 5edcc40b..2a8c330 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -309,6 +309,17 @@ LogicalResult ConvertBF16x2ToF8x2Op::verify() {
   return success();
 }
 
+LogicalResult ConvertF32x2ToF4x2Op::verify() {
+  mlir::MLIRContext *ctx = getContext();
+
+  if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
+    return emitOpError("Only ")
+           << mlir::Float4E2M1FNType::get(ctx)
+           << " type is supported for conversions from f32x2 to f4x2.";
+
+  return success();
+}
+
 LogicalResult BulkStoreOp::verify() {
   if (getInitVal() != 0)
     return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -787,6 +798,26 @@ LogicalResult MmaOp::verify() {
                        " attribute");
   }
 
+  // Validate layout combinations. According to the operation description, most
+  // MMA operations require layoutA=row and layoutB=col. Only m8n8k4 with f16
+  // can use other layout combinations.
+  bool isM8N8K4_F16 =
+      (mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 &&
+       getMultiplicandAPtxType() == MMATypes::f16);
+
+  if (!isM8N8K4_F16) {
+    // For all other shapes/types, layoutA must be row and layoutB must be col.
+    if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) {
+      return emitOpError("requires layoutA = #nvvm.mma_layout<row> and "
+                         "layoutB = #nvvm.mma_layout<col> for shape <")
+             << mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2]
+             << "> with element types "
+             << stringifyEnum(*getMultiplicandAPtxType()) << " and "
+             << stringifyEnum(*getMultiplicandBPtxType())
+             << ". Only m8n8k4 with f16 supports other layouts.";
+    }
+  }
+
   return success();
 }
 
@@ -2047,6 +2078,23 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
   }
 }
 
+NVVM::IDArgPair
+ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
+                                            LLVM::ModuleTranslation &mt,
+                                            llvm::IRBuilderBase &builder) {
+  llvm::SmallVector<llvm::Value *> args;
+  args.push_back(mt.lookupValue(op.getA()));
+  args.push_back(mt.lookupValue(op.getB()));
+
+  bool hasRelu = op.getRelu();
+
+  llvm::Intrinsic::ID intId =
+      hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
+              : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
+
+  return {intId, std::move(args)};
+}
+
 #define GET_F32x2_TO_F6x2_ID(type, has_relu)                                   \
   has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite            \
            : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
@@ -2306,6 +2354,32 @@ static void nvvmInferResultRanges(Operation *op, Value result,
   }
 }
 
+/// Verify the range attribute satisfies LLVM ConstantRange constructor
+/// requirements for NVVM SpecialRangeableRegisterOp.
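+/// (LLVM permits Lower == Upper only when both are the minimum or the maximum
+/// value; those two encodings are reserved for the empty and full sets.)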
+static LogicalResult
+verifyConstantRangeAttr(Operation *op,
+                        std::optional<LLVM::ConstantRangeAttr> rangeAttr) {
+  if (!rangeAttr)
+    return success();
+
+  const llvm::APInt &lower = rangeAttr->getLower();
+  const llvm::APInt &upper = rangeAttr->getUpper();
+
+  // Check LLVM ConstantRange constructor condition.
+  if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
+    unsigned bitWidth = lower.getBitWidth();
+    llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth);
+    llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth);
+    return op->emitOpError(
+               "invalid range attribute: Lower == Upper, but they aren't min (")
+           << llvm::toString(minVal, 10, false) << ") or max ("
+           << llvm::toString(maxVal, 10, false)
+           << ") value! This is an invalid constant range.";
+  }
+
+  return success();
+}
+
 static llvm::Value *getAsPackedI32(llvm::Value *arg,
                                    llvm::IRBuilderBase &builder) {
   return builder.CreateBitCast(arg,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index c477c6c..dcc1ef9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -315,7 +315,8 @@ bool mlir::linalg::detail::isContractionBody(
   Value yielded = getSourceSkipUnary(terminator->getOperand(0));
 
   Operation *reductionOp = yielded.getDefiningOp();
-  if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) {
+  if (!reductionOp || reductionOp->getNumResults() != 1 ||
+      reductionOp->getNumOperands() != 2) {
     errs << "expected reduction op to be binary";
     return false;
   }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 59013a2..cbc565b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5272,11 +5272,18 @@ ArrayRef<int64_t> PackOp::getAllOuterDims() {
 
 SmallVector<int64_t> PackOp::getTiledOuterDims() {
   auto innerDimsPos = getInnerDimsPos();
-  auto packedShape = getDestType().getShape();
+  SmallVector<int64_t> outerDims(getAllOuterDims());
   SmallVector<int64_t> res;
 
+  // Recover the original order of the outer dims.
+  SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
+  invertPermutationVector(outerDimPermInv);
+  if (!outerDimPermInv.empty())
+    applyPermutationToVector(outerDims, outerDimPermInv);
+
+  // Collect the outer dims corresponding to the tiled inner dims.
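+  // E.g. with un-permuted outer dims [4, 2] and inner_dims_pos = [1], this
+  // collects {2}.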
   for (auto index : innerDimsPos)
-    res.push_back(packedShape[index]);
+    res.push_back(outerDims[index]);
 
   return res;
 }
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index dd9b4c2..d8f983f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -576,6 +576,86 @@ transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
 // FuseOp
 //===----------------------------------------------------------------------===//
 
+void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
+                              TypeRange loopTypes, Value target,
+                              ArrayRef<int64_t> staticTileSizes,
+                              ArrayRef<int64_t> staticTileInterchange,
+                              bool applyCleanup, bool useForall) {
+  return build(
+      builder, result, loopTypes,
+      /*target=*/target,
+      /*mixedTileSizes=*/
+      getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
+      /*mixedTileInterchange=*/
+      getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)),
+      applyCleanup, useForall);
+}
+
+void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
+                              Value target, ArrayRef<int64_t> staticTileSizes,
+                              ArrayRef<int64_t> staticTileInterchange,
+                              bool applyCleanup, bool useForall) {
+  return build(
+      builder, result,
+      /*target=*/target,
+      /*mixedTileSizes=*/
+      getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
+      /*mixedTileInterchange=*/
+      getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)),
+      applyCleanup, useForall);
+}
+
+void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
+                              Value target,
+                              ArrayRef<OpFoldResult> mixedTileSizes,
+                              ArrayRef<OpFoldResult> mixedTileInterchange,
+                              bool applyCleanup, bool useForall) {
+  // Loop types are automatically splat by the callee; setting up one is
+  // enough.
+  SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
+  build(builder, result, loopTypes, target, mixedTileSizes,
+        mixedTileInterchange, applyCleanup, useForall);
+}
+
+void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
+                              TypeRange loopTypes, Value target,
+                              ArrayRef<OpFoldResult> mixedTileSizes,
+                              ArrayRef<OpFoldResult> mixedTileInterchange,
+                              bool applyCleanup, bool useForall) {
+  SmallVector<int64_t> staticTileSizes;
+  SmallVector<Value> dynamicTileSizes;
+  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
+  SmallVector<int64_t> staticTileInterchange;
+  SmallVector<Value> dynamicTileInterchange;
+  dispatchIndexOpFoldResults(mixedTileInterchange, dynamicTileInterchange,
+                             staticTileInterchange);
+  // Call the default builder which sets up the proper operands segment sizes
+  // attributes for multiple variadic operands. In the absence of this,
+  // horrible bugs ensue.
+  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
+  auto staticTileInterchangeAttr =
+      builder.getDenseI64ArrayAttr(staticTileInterchange);
+  unsigned numExpectedLoops =
+      useForall ? 1 : staticTileSizes.size() - llvm::count(staticTileSizes, 0);
+  SmallVector<Type> resultTypes;
+  resultTypes.reserve(numExpectedLoops);
+  assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
+         "expected one loop type or as many as loops");
+  if (loopTypes.size() == 1)
+    resultTypes.append(numExpectedLoops, loopTypes[0]);
+  else
+    llvm::append_range(resultTypes, loopTypes);
+  build(builder, result, /*transformed=*/target.getType(),
+        /*loops=*/resultTypes,
+        /*target=*/target,
+        /*tile_sizes=*/dynamicTileSizes,
+        /*tile_interchange=*/dynamicTileInterchange,
+        /*static_tile_sizes=*/staticTileSizesAttr,
+        /*static_tile_interchange=*/staticTileInterchangeAttr,
+        /*apply_cleanup=*/applyCleanup,
+        /*use_forall=*/useForall);
+}
+
 /// Apply a tiling transformation to all payload ops and store both the
 /// tiled operation as well as the created tile loops.
 template <typename Range>
@@ -630,13 +710,25 @@ DiagnosedSilenceableFailure
 transform::FuseOp::apply(transform::TransformRewriter &rewriter,
                          mlir::transform::TransformResults &transformResults,
                          mlir::transform::TransformState &state) {
-  SmallVector<int64_t> tileSizes =
-      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
-  SmallVector<int64_t> tileInterchange =
-      extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
+  auto transformOp = cast<TransformOpInterface>(getOperation());
+
+  SmallVector<int64_t> tileSizes;
+  DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
+      state, transformOp, getMixedTileSizes(), tileSizes);
+  if (!status.succeeded())
+    return status;
+  SmallVector<int64_t> tileInterchange;
+  status = reifyMixedParamAndHandleResults(
+      state, transformOp, getMixedTileInterchange(), tileInterchange);
+  if (!status.succeeded())
+    return status;
 
   scf::SCFTilingOptions tilingOptions;
   tilingOptions.interchangeVector = tileInterchange;
+  bool useForall = getUseForall();
+  tilingOptions.setLoopType(useForall
+                                ? scf::SCFTilingOptions::LoopType::ForallOp
+                                : scf::SCFTilingOptions::LoopType::ForOp);
   SmallVector<OpFoldResult> tileSizesOfr =
       getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
   tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
@@ -652,9 +744,11 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
     tileAndFuseOptions.cleanupPatterns = std::move(patterns);
   }
 
+  size_t numLoops =
+      useForall ? 1 : tileSizes.size() - llvm::count(tileSizes, 0);
1 : tileSizes.size() - llvm::count(tileSizes, 0); LogicalResult result = applyTilingToAll( - rewriter, getOperation(), state.getPayloadOps(getTarget()), - tileSizes.size() - llvm::count(tileSizes, 0), transformResults, + rewriter, getOperation(), state.getPayloadOps(getTarget()), numLoops, + transformResults, [&](TilingInterface tilingInterfaceOp) -> FailureOr<scf::SCFTileAndFuseResult> { return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp, @@ -665,24 +759,51 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter, } LogicalResult transform::FuseOp::verify() { - SmallVector<int64_t> permutation = - extractFromIntegerArrayAttr<int64_t>(getTileInterchange()); - auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size())); - if (!std::is_permutation(sequence.begin(), sequence.end(), - permutation.begin(), permutation.end())) { - return emitOpError() << "expects interchange to be a permutation, found " - << getTileInterchange(); + auto iterspace_rank = getStaticTileSizes().size(); + ArrayRef<int64_t> permutation = getStaticTileInterchange(); + if (permutation.size() > iterspace_rank) + return emitOpError() + << "interchange length exceeds iteration space dimensions (" + << iterspace_rank << "), found " << getTileInterchange(); + SmallVector<bool> seen(iterspace_rank, false); + for (int64_t v : permutation) { + if (!ShapedType::isDynamic(v)) { + if (v < 0 || v >= static_cast<int64_t>(iterspace_rank)) + return emitOpError() << "expects interchange values to be in range [0, " + << iterspace_rank << "), found: " << v; + if (seen[v]) + return emitOpError() << "found duplicate interchange value: " << v; + seen[v] = true; + } } - SmallVector<int64_t> sizes = - extractFromIntegerArrayAttr<int64_t>(getTileSizes()); - size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0); + ArrayRef<int64_t> sizes = getStaticTileSizes(); + size_t numExpectedLoops = + getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0); if (numExpectedLoops != getNumResults() - 1) return emitOpError() << "expects " << numExpectedLoops << " loop results"; return success(); } +SmallVector<OpFoldResult> transform::FuseOp::getMixedTileSizes() { + return getMixedValues(getStaticTileSizes(), getTileSizes(), getContext()); +} + +SmallVector<OpFoldResult> transform::FuseOp::getMixedTileInterchange() { + return getMixedValues(getStaticTileInterchange(), getTileInterchange(), + getContext()); +} + +void transform::FuseOp::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + consumesHandle(getTargetMutable(), effects); + onlyReadsHandle(getTileSizesMutable(), effects); + onlyReadsHandle(getTileInterchangeMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + modifiesPayload(effects); +} + //===----------------------------------------------------------------------===// // FuseIntoContainingOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 0dac688..eb2d825 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -1134,22 +1134,45 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape, LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( linalg::PackOp packOp, PatternRewriter &rewriter) const { - // TODO: support the case that outer dimensions are not all 1s. 
A - // tensor.expand_shape will be generated in this case. + if (llvm::any_of(packOp.getTiledOuterDims(), [](int64_t dim) { return dim != 1; })) { return rewriter.notifyMatchFailure( packOp, "not all outer dimensions of the result are 1s"); } + ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos(); + auto outerDimsPerm = packOp.getOuterDimsPerm(); + + // Bail out if any outer dim that is both non-unit and un-tiled is permuted. + // Supporting such cases would require refining the logic that generates the + // transpose op. + // Note: `prev` is hoisted out of the lambda and captured by reference; a + // `static` local inside the lambda would leak state across pattern + // applications. + int64_t prev = 0; + if (!llvm::all_of(outerDimsPerm, [&](int64_t dim) { + // Skip tiled dims - these can be permuted. + if (llvm::is_contained(innerDimsPos, dim)) + return true; + + // Check whether this dim has been permuted. Permuting unit dims is fine + // as that's effectively a no-op. + if (dim < prev && (packOp.getType().getShape()[prev] != 1 || + packOp.getType().getShape()[dim] != 1)) + return false; + + prev = dim; + return true; + })) { + return rewriter.notifyMatchFailure( + packOp, "at least one non-unit, un-tiled outer dim is permuted; " + "this is not supported yet"); + } + Attribute zeroIdxAttr = rewriter.getIndexAttr(0); Attribute oneIdxAttr = rewriter.getIndexAttr(1); Location loc = packOp.getLoc(); int64_t srcRank = packOp.getSourceRank(); int64_t destRank = packOp.getDestRank(); - ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos(); - int64_t numberOfTiles = innerDimsPos.size(); // 1. Get the input that is going to be packed. If the input requires padding, // add a padding operation and return that as the input. @@ -1160,10 +1183,13 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( // %transposed_tile = linalg.transpose ins(%source_or_padded_source), // outs(%init) // Assumptions made: - // - All outer dims are 1 - the corresponding transposition order doesn't - // matter, but requires all dim indices to be present. + // - All tiled outer dims are 1 - the corresponding transposition order + // doesn't matter, but requires all dim indices to be present. + // - Un-tiled outer dims remain un-permuted. - // 2.1 Get the permutation for linalg.transpose + // 2.1 Get the permutation for linalg.transpose: + // [ untiled-dims, inner-dims-pos ] + // Note: this logic assumes that the untiled dims are not permuted. SmallVector<int64_t> srcPermForTranspose; for (int64_t i = 0; i < srcRank; i++) { // We assume the `k` dimensions of the inner dim position, where `k` is the @@ -1179,9 +1205,21 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( } srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end()); - // 2.2 Create the init tensor for linalg.transpose with the correct shape - SmallVector<OpFoldResult> shapeForEmptyOp(srcRank - numberOfTiles, - oneIdxAttr); + // 2.2 Create the init tensor for linalg.transpose with the correct shape: + // [ untiled-dims, tiled-dims ] + ShapedType inputTy = cast<ShapedType>(input.getType()); + SmallVector<OpFoldResult> shapeForEmptyOp; + for (int64_t i = 0; i < srcRank; i++) { + if (llvm::is_contained(innerDimsPos, i)) { + // The tiled dims are appended after this loop. 
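// Editor's note, not part of the patch: a worked example of the shape
// assembled by this loop. For a hypothetical pack of tensor<5x1x32xf32> with
// inner_dims_pos = [2] and inner tile size 32, the loop keeps the un-tiled
// dims [5, 1]; getMixedTiles() below then appends [32], so the
// linalg.transpose init gets the shape [ untiled-dims, tile-sizes ],
// i.e. [5, 1, 32].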
+ continue; + } + if (inputTy.isStaticDim(i)) + shapeForEmptyOp.push_back(rewriter.getIndexAttr(inputTy.getShape()[i])); + else + shapeForEmptyOp.emplace_back( + tensor::DimOp::create(rewriter, loc, input, i).getResult()); + } shapeForEmptyOp.append(packOp.getMixedTiles()); // getMixedTiles() may contain Values pointing to constant ops, not the @@ -1204,25 +1242,36 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty, srcPermForTranspose); - // 3. Insert the inner tile to the destination: + // 3. Insert the inner tile into the destination tensor: // %inserted_tile = tensor.insert_slice(%transposed_tile) - SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr); - SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr); - // Outer dims are all 1s! - SmallVector<OpFoldResult> writeSizes(destRank - numberOfTiles, oneIdxAttr); - SmallVector<int64_t> writeShape; + + // Compute the sizes attribute: + // [ outer-dims, tile-sizes ] + // Note that the output from the transpose Op excludes the tiled outer dims. + // However, given the assumption that: + // * all tiled outer dims == 1, + // we can just use a rank-expanding tensor.insert_slice. + SmallVector<OpFoldResult> writeSizes; + for (auto size : packOp.getAllOuterDims()) { + writeSizes.push_back(rewriter.getIndexAttr(size)); + } for (auto tileSize : packOp.getMixedTiles()) { - auto [tileSizeStatic, tileSizeOfr] = + auto [_, tileSizeOfr] = getSimplifiedOfrAndStaticSizePair(tileSize, rewriter); writeSizes.push_back(tileSizeOfr); - writeShape.push_back(tileSizeStatic); } - // 4. Replace tensor.packOp with tensor.insert_slice created above + // TODO: Add a constructor for tensor.insert_slice that doesn't require + // strides nor offsets. + SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr); + SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr); + auto insert = tensor::InsertSliceOp::create( rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets, writeSizes, writeStrides); + + // 4. 
Replace tensor.packOp with the tensor.insert_slice created above rewriter.replaceOp(packOp, insert.getResult()); return success(); diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt index e25a012..1382c7ac 100644 --- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt @@ -5,7 +5,7 @@ add_mlir_dialect_library(MLIRMemRefDialect ValueBoundsOpInterfaceImpl.cpp ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect + ${PROJECT_SOURCE_DIR}/include/mlir/Dialect/MemRef/IR DEPENDS MLIRMemRefOpsIncGen @@ -18,6 +18,7 @@ add_mlir_dialect_library(MLIRMemRefDialect MLIRDialectUtils MLIRInferIntRangeCommon MLIRInferIntRangeInterface + MLIRInferStridedMetadataInterface MLIRInferTypeOpInterface MLIRIR MLIRMemOpInterfaces diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index e9bdcda..507597b 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -3437,6 +3437,65 @@ SubViewOp::bubbleDownCasts(OpBuilder &builder) { return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable()); } +void SubViewOp::inferStridedMetadataRanges( + ArrayRef<StridedMetadataRange> ranges, GetIntRangeFn getIntRange, + SetStridedMetadataRangeFn setMetadata, int32_t indexBitwidth) { + auto isUninitialized = + +[](IntegerValueRange range) { return range.isUninitialized(); }; + + // Bail early if any of the operands' metadata is not ready: + SmallVector<IntegerValueRange> offsetOperands = + getIntValueRanges(getMixedOffsets(), getIntRange, indexBitwidth); + if (llvm::any_of(offsetOperands, isUninitialized)) + return; + + SmallVector<IntegerValueRange> sizeOperands = + getIntValueRanges(getMixedSizes(), getIntRange, indexBitwidth); + if (llvm::any_of(sizeOperands, isUninitialized)) + return; + + SmallVector<IntegerValueRange> stridesOperands = + getIntValueRanges(getMixedStrides(), getIntRange, indexBitwidth); + if (llvm::any_of(stridesOperands, isUninitialized)) + return; + + StridedMetadataRange sourceRange = + ranges[getSourceMutable().getOperandNumber()]; + if (sourceRange.isUninitialized()) + return; + + ArrayRef<ConstantIntRanges> srcStrides = sourceRange.getStrides(); + + // Get the dropped dims. + llvm::SmallBitVector droppedDims = getDroppedDims(); + + // Compute the new offset, strides and sizes. + ConstantIntRanges offset = sourceRange.getOffsets()[0]; + SmallVector<ConstantIntRanges> strides, sizes; + + for (size_t i = 0, e = droppedDims.size(); i < e; ++i) { + bool dropped = droppedDims.test(i); + // Compute the new offset. + ConstantIntRanges off = + intrange::inferMul({offsetOperands[i].getValue(), srcStrides[i]}); + offset = intrange::inferAdd({offset, off}); + + // Skip dropped dimensions. + if (dropped) + continue; + // Multiply the strides. + strides.push_back( + intrange::inferMul({stridesOperands[i].getValue(), srcStrides[i]})); + // Get the sizes. 
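// Editor's note, not part of the patch: interval arithmetic aside, this loop
// computes the usual strided-metadata composition for a subview. Assuming a
// source with offset o and strides s[i], and a subview with offsets off[i],
// sizes sz[i] and steps st[i]:
//
//   new_offset    = o + sum_i(off[i] * s[i])
//   new_stride[i] = st[i] * s[i]   // kept (non-dropped) dims only
//   new_size[i]   = sz[i]          // kept (non-dropped) dims only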
+ sizes.push_back(sizeOperands[i].getValue()); + } + + setMetadata(getResult(), + StridedMetadataRange::getRanked( + SmallVector<ConstantIntRanges>({std::move(offset)}), + std::move(sizes), std::move(strides))); +} + //===----------------------------------------------------------------------===// // TransposeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 6564a4e..642ced9 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SmallSet.h" @@ -74,14 +75,16 @@ struct MemRefPointerLikeModel } mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc, - StringRef varName, Type varType, - Value originalVar) const { + StringRef varName, Type varType, Value originalVar, + bool &needsFree) const { auto memrefTy = cast<MemRefType>(pointer); // Check if this is a static memref (all dimensions are known) - if yes // then we can generate an alloca operation. - if (memrefTy.hasStaticShape()) + if (memrefTy.hasStaticShape()) { + needsFree = false; // alloca doesn't need deallocation return memref::AllocaOp::create(builder, loc, memrefTy).getResult(); + } // For dynamic memrefs, extract sizes from the original variable if // provided. Otherwise they cannot be handled. @@ -99,6 +102,7 @@ struct MemRefPointerLikeModel // Note: We only add dynamic sizes to the dynamicSizes array // Static dimensions are handled automatically by AllocOp } + needsFree = true; // alloc needs deallocation return memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes) .getResult(); } @@ -108,10 +112,14 @@ struct MemRefPointerLikeModel } bool genFree(Type pointer, OpBuilder &builder, Location loc, - TypedValue<PointerLikeType> varPtr, Type varType) const { - if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varPtr)) { + TypedValue<PointerLikeType> varToFree, Value allocRes, + Type varType) const { + if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varToFree)) { + // Use allocRes if provided to determine the allocation type + Value valueToInspect = allocRes ? allocRes : memrefValue; + // Walk through casts to find the original allocation - Value currentValue = memrefValue; + Value currentValue = valueToInspect; Operation *originalAlloc = nullptr; // Follow the chain of operations to find the original allocation @@ -150,7 +158,7 @@ struct MemRefPointerLikeModel return true; } if (isa<memref::AllocOp>(originalAlloc)) { - // This is an alloc - generate dealloc + // This is an alloc - generate dealloc on varToFree memref::DeallocOp::create(builder, loc, memrefValue); return true; } @@ -1003,6 +1011,142 @@ struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> { } }; +//===----------------------------------------------------------------------===// +// Recipe Region Helpers +//===----------------------------------------------------------------------===// + +/// Create and populate an init region for privatization recipes. +/// Returns the init block on success, or nullptr on failure. +/// Sets needsFree to indicate if the allocated memory requires deallocation. 
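// Editor's sketch, not part of the patch: for a privatized memref<10xf32>,
// the init region built below is expected to look roughly like
//
//   ^bb0(%arg0: memref<10xf32>):
//     %0 = memref.alloca() : memref<10xf32>  // static shape: needsFree=false
//     acc.yield %0 : memref<10xf32>
//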
+static std::unique_ptr<Block> createInitRegion(OpBuilder &builder, Location loc, + Type varType, StringRef varName, + ValueRange bounds, + bool &needsFree) { + // Create init block with arguments: original value + bounds + SmallVector<Type> argTypes{varType}; + SmallVector<Location> argLocs{loc}; + for (Value bound : bounds) { + argTypes.push_back(bound.getType()); + argLocs.push_back(loc); + } + + auto initBlock = std::make_unique<Block>(); + initBlock->addArguments(argTypes, argLocs); + builder.setInsertionPointToStart(initBlock.get()); + + Value privatizedValue; + + // Get the block argument that represents the original variable + Value blockArgVar = initBlock->getArgument(0); + + // Generate init region body based on variable type + if (isa<MappableType>(varType)) { + auto mappableTy = cast<MappableType>(varType); + auto typedVar = cast<TypedValue<MappableType>>(blockArgVar); + privatizedValue = mappableTy.generatePrivateInit( + builder, loc, typedVar, varName, bounds, {}, needsFree); + if (!privatizedValue) + return nullptr; + } else { + assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType"); + auto pointerLikeTy = cast<PointerLikeType>(varType); + // Use PointerLikeType's allocation API with the block argument + privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType, + blockArgVar, needsFree); + if (!privatizedValue) + return nullptr; + } + + // Add yield operation to init block + acc::YieldOp::create(builder, loc, privatizedValue); + + return initBlock; +} + +/// Create and populate a copy region for firstprivate recipes. +/// Returns the copy block on success, or nullptr on failure. +/// TODO: Handle MappableType - it does not yet have a copy API. +static std::unique_ptr<Block> createCopyRegion(OpBuilder &builder, Location loc, + Type varType, + ValueRange bounds) { + // Create copy block with arguments: original value + privatized value + + // bounds + SmallVector<Type> copyArgTypes{varType, varType}; + SmallVector<Location> copyArgLocs{loc, loc}; + for (Value bound : bounds) { + copyArgTypes.push_back(bound.getType()); + copyArgLocs.push_back(loc); + } + + auto copyBlock = std::make_unique<Block>(); + copyBlock->addArguments(copyArgTypes, copyArgLocs); + builder.setInsertionPointToStart(copyBlock.get()); + + bool isMappable = isa<MappableType>(varType); + bool isPointerLike = isa<PointerLikeType>(varType); + // TODO: Handle MappableType - it does not yet have a copy API. + // Otherwise, for now just fallback to pointer-like behavior. + if (isMappable && !isPointerLike) + return nullptr; + + // Generate copy region body based on variable type + if (isPointerLike) { + auto pointerLikeTy = cast<PointerLikeType>(varType); + Value originalArg = copyBlock->getArgument(0); + Value privatizedArg = copyBlock->getArgument(1); + + // Generate copy operation using PointerLikeType interface + if (!pointerLikeTy.genCopy( + builder, loc, cast<TypedValue<PointerLikeType>>(privatizedArg), + cast<TypedValue<PointerLikeType>>(originalArg), varType)) + return nullptr; + } + + // Add terminator to copy block + acc::TerminatorOp::create(builder, loc); + + return copyBlock; +} + +/// Create and populate a destroy region for privatization recipes. +/// Returns the destroy block on success, or nullptr if not needed. 
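// Editor's sketch, not part of the patch: the destroy region pairs with the
// dynamic-shape init case (memref.alloc, needsFree = true) and would roughly
// look like
//
//   ^bb0(%arg0: memref<?xf32>, %arg1: memref<?xf32>):
//     memref.dealloc %arg1 : memref<?xf32>
//     acc.terminator
//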
+static std::unique_ptr<Block> createDestroyRegion(OpBuilder &builder, + Location loc, Type varType, + Value allocRes, + ValueRange bounds) { + // Create destroy block with arguments: original value + privatized value + + // bounds + SmallVector<Type> destroyArgTypes{varType, varType}; + SmallVector<Location> destroyArgLocs{loc, loc}; + for (Value bound : bounds) { + destroyArgTypes.push_back(bound.getType()); + destroyArgLocs.push_back(loc); + } + + auto destroyBlock = std::make_unique<Block>(); + destroyBlock->addArguments(destroyArgTypes, destroyArgLocs); + builder.setInsertionPointToStart(destroyBlock.get()); + + bool isMappable = isa<MappableType>(varType); + bool isPointerLike = isa<PointerLikeType>(varType); + // TODO: Handle MappableType - it does not yet have a deallocation API. + // Otherwise, for now just fallback to pointer-like behavior. + if (isMappable && !isPointerLike) + return nullptr; + + assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType"); + auto pointerLikeTy = cast<PointerLikeType>(varType); + auto privatizedArg = + cast<TypedValue<PointerLikeType>>(destroyBlock->getArgument(1)); + // Pass allocRes to help determine the allocation type + if (!pointerLikeTy.genFree(builder, loc, privatizedArg, allocRes, varType)) + return nullptr; + + acc::TerminatorOp::create(builder, loc); + + return destroyBlock; +} + } // namespace //===----------------------------------------------------------------------===// @@ -1050,6 +1194,55 @@ LogicalResult acc::PrivateRecipeOp::verifyRegions() { return success(); } +std::optional<PrivateRecipeOp> +PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc, + StringRef recipeName, Type varType, + StringRef varName, ValueRange bounds) { + // First, validate that we can handle this variable type + bool isMappable = isa<MappableType>(varType); + bool isPointerLike = isa<PointerLikeType>(varType); + + // Unsupported type + if (!isMappable && !isPointerLike) + return std::nullopt; + + // Create init and destroy blocks using shared helpers + OpBuilder::InsertionGuard guard(builder); + + // Save the original insertion point for creating the recipe operation later + auto originalInsertionPoint = builder.saveInsertionPoint(); + + bool needsFree = false; + auto initBlock = + createInitRegion(builder, loc, varType, varName, bounds, needsFree); + if (!initBlock) + return std::nullopt; + + // Only create destroy region if the allocation needs deallocation + std::unique_ptr<Block> destroyBlock; + if (needsFree) { + // Extract the allocated value from the init block's yield operation + auto yieldOp = cast<acc::YieldOp>(initBlock->getTerminator()); + Value allocRes = yieldOp.getOperand(0); + + destroyBlock = createDestroyRegion(builder, loc, varType, allocRes, bounds); + if (!destroyBlock) + return std::nullopt; + } + + // Now create the recipe operation at the original insertion point and attach + // the blocks + builder.restoreInsertionPoint(originalInsertionPoint); + auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType); + + // Move the blocks into the recipe's regions + recipe.getInitRegion().push_back(initBlock.release()); + if (destroyBlock) + recipe.getDestroyRegion().push_back(destroyBlock.release()); + + return recipe; +} + //===----------------------------------------------------------------------===// // FirstprivateRecipeOp //===----------------------------------------------------------------------===// @@ -1080,6 +1273,60 @@ LogicalResult acc::FirstprivateRecipeOp::verifyRegions() { return success(); 
} +std::optional<FirstprivateRecipeOp> +FirstprivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc, + StringRef recipeName, Type varType, + StringRef varName, ValueRange bounds) { + // First, validate that we can handle this variable type + bool isMappable = isa<MappableType>(varType); + bool isPointerLike = isa<PointerLikeType>(varType); + + // Unsupported type + if (!isMappable && !isPointerLike) + return std::nullopt; + + // Create init, copy, and destroy blocks using shared helpers + OpBuilder::InsertionGuard guard(builder); + + // Save the original insertion point for creating the recipe operation later + auto originalInsertionPoint = builder.saveInsertionPoint(); + + bool needsFree = false; + auto initBlock = + createInitRegion(builder, loc, varType, varName, bounds, needsFree); + if (!initBlock) + return std::nullopt; + + auto copyBlock = createCopyRegion(builder, loc, varType, bounds); + if (!copyBlock) + return std::nullopt; + + // Only create destroy region if the allocation needs deallocation + std::unique_ptr<Block> destroyBlock; + if (needsFree) { + // Extract the allocated value from the init block's yield operation + auto yieldOp = cast<acc::YieldOp>(initBlock->getTerminator()); + Value allocRes = yieldOp.getOperand(0); + + destroyBlock = createDestroyRegion(builder, loc, varType, allocRes, bounds); + if (!destroyBlock) + return std::nullopt; + } + + // Now create the recipe operation at the original insertion point and attach + // the blocks + builder.restoreInsertionPoint(originalInsertionPoint); + auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType); + + // Move the blocks into the recipe's regions + recipe.getInitRegion().push_back(initBlock.release()); + recipe.getCopyRegion().push_back(copyBlock.release()); + if (destroyBlock) + recipe.getDestroyRegion().push_back(destroyBlock.release()); + + return recipe; +} + //===----------------------------------------------------------------------===// // ReductionRecipeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index fa97b49..ac72002 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -2310,6 +2310,7 @@ RankedTensorType ExtractSliceOp::inferResultType( sourceTensorType.getEncoding()); } +// TODO: This uses neither offsets nor strides! RankedTensorType ExtractSliceOp::inferResultType( RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) { diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 58256b0..45c54c7 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -7601,6 +7601,111 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, setResultRanges(getResult(), result); } +namespace { + +/// Fold `vector.step -> arith.cmpi` when the step value is compared to a +/// constant large enough such that the result is the same at all indices. +/// +/// For example, rewrite the 'greater than' comparison below, +/// +/// ```mlir +/// %cst = arith.constant dense<7> : vector<3xindex> +/// %stp = vector.step : vector<3xindex> +/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex> +/// ``` +/// +/// as, +/// +/// ```mlir +/// %out = arith.constant dense<false> : vector<3xi1>. 
+/// ``` +/// +/// Above `[0, 1, 2] > [7, 7, 7]` => `[false, false, false]`. Because the result +/// is false at ALL indices, we fold. If the constant were 1, then +/// `[0, 1, 2] > [1, 1, 1]` => `[false, false, true]` and we do not fold, +/// conservatively preferring the 'compact' vector.step representation. +/// +/// Note: this folder only works for the case where the constant (`%cst` above) +/// is the second operand of the comparison. The arith.cmpi canonicalizer will +/// ensure that constants are always second (on the right). +struct StepCompareFolder : public OpRewritePattern<StepOp> { + using Base::Base; + + LogicalResult matchAndRewrite(StepOp stepOp, + PatternRewriter &rewriter) const override { + const int64_t stepSize = stepOp.getResult().getType().getNumElements(); + + for (OpOperand &use : stepOp.getResult().getUses()) { + auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner()); + if (!cmpiOp) + continue; + + // arith.cmpi canonicalizer makes constants final operands. + const unsigned stepOperandNumber = use.getOperandNumber(); + if (stepOperandNumber != 0) + continue; + + // Check that operand 1 is a constant. + unsigned constOperandNumber = 1; + Value otherOperand = cmpiOp.getOperand(constOperandNumber); + std::optional<int64_t> maybeConstValue = + getConstantIntValue(otherOperand); + if (!maybeConstValue.has_value()) + continue; + + int64_t constValue = maybeConstValue.value(); + arith::CmpIPredicate pred = cmpiOp.getPredicate(); + + auto maybeSplat = [&]() -> std::optional<bool> { + // Handle ult (unsigned less than) and uge (unsigned greater or equal). + if ((pred == arith::CmpIPredicate::ult || + pred == arith::CmpIPredicate::uge) && + stepSize <= constValue) + return pred == arith::CmpIPredicate::ult; + + // Handle ule and ugt. + if ((pred == arith::CmpIPredicate::ule || + pred == arith::CmpIPredicate::ugt) && + stepSize - 1 <= constValue) { + return pred == arith::CmpIPredicate::ule; + } + + // Handle eq and ne. + if ((pred == arith::CmpIPredicate::eq || + pred == arith::CmpIPredicate::ne) && + stepSize <= constValue) + return pred == arith::CmpIPredicate::ne; + + return std::nullopt; + }(); + + if (!maybeSplat.has_value()) + continue; + + rewriter.setInsertionPointAfter(cmpiOp); + + auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType()); + if (!type) + continue; + + auto boolAttr = DenseElementsAttr::get(type, maybeSplat.value()); + Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(), + type, boolAttr); + + rewriter.replaceOp(cmpiOp, splat); + return success(); + } + + return failure(); + } +}; +} // namespace + +void StepOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add<StepCompareFolder>(context); +} + //===----------------------------------------------------------------------===// // Vector Masking Utilities //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index e95338f..12e6475 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -928,17 +928,20 @@ struct WarpOpDeadResult : public WarpDistributionPattern { // Some values may be yielded multiple times and correspond to multiple // results. Deduplicating occurs by taking each result with its matching // yielded value, and: - // 1. recording the unique first position at which the value is yielded. + // 1. 
recording the unique first position at which the value with uses is + // yielded. // 2. recording for the result, the first position at which the dedup'ed // value is yielded. // 3. skipping from the new result types / new yielded values any result // that has no use or whose yielded value has already been seen. for (OpResult result : warpOp.getResults()) { + if (result.use_empty()) + continue; Value yieldOperand = yield.getOperand(result.getResultNumber()); auto it = dedupYieldOperandPositionMap.insert( std::make_pair(yieldOperand, newResultTypes.size())); dedupResultPositionMap.insert(std::make_pair(result, it.first->second)); - if (result.use_empty() || !it.second) + if (!it.second) continue; newResultTypes.push_back(result.getType()); newYieldValues.push_back(yieldOperand); @@ -1843,16 +1846,16 @@ struct WarpOpScfIfOp : public WarpDistributionPattern { newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(), escapingValueDistTypesElse.end()); - llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx; for (auto [idx, val] : llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) { - origToNewYieldIdx[idx] = newWarpOpYieldValues.size(); newWarpOpYieldValues.push_back(val); newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType()); } - // Create the new `WarpOp` with the updated yield values and types. - WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( - rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes); + // Replace the old `WarpOp` with the new one that has additional yield + // values and types. + SmallVector<size_t> newIndices; + WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices); // `ifOp` returns the result of the inner warp op. SmallVector<Type> newIfOpDistResTypes; for (auto [i, res] : llvm::enumerate(ifOp.getResults())) { @@ -1870,8 +1873,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern { OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(newWarpOp); auto newIfOp = scf::IfOp::create( - rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0), - static_cast<bool>(ifOp.thenBlock()), + rewriter, ifOp.getLoc(), newIfOpDistResTypes, + newWarpOp.getResult(newIndices[0]), static_cast<bool>(ifOp.thenBlock()), static_cast<bool>(ifOp.elseBlock())); auto encloseRegionInWarpOp = [&](Block *oldIfBranch, Block *newIfBranch, @@ -1888,7 +1891,7 @@ struct WarpOpScfIfOp : public WarpDistributionPattern { for (size_t i = 0; i < escapingValues.size(); ++i, ++warpResRangeStart) { innerWarpInputVals.push_back( - newWarpOp.getResult(warpResRangeStart)); + newWarpOp.getResult(newIndices[warpResRangeStart])); escapeValToBlockArgIndex[escapingValues[i]] = innerWarpInputTypes.size(); innerWarpInputTypes.push_back(escapingValueInputTypes[i]); @@ -1936,17 +1939,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern { // Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp` // result. for (auto [origIdx, newIdx] : ifResultMapping) - rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx), + rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx), newIfOp.getResult(newIdx), newIfOp); - // Similarly, update any users of the `WarpOp` results that were not - // results of the `IfOp`. 
- for (auto [origIdx, newIdx] : origToNewYieldIdx) - rewriter.replaceAllUsesWith(warpOp.getResult(origIdx), - newWarpOp.getResult(newIdx)); - // Remove the original `WarpOp` and `IfOp`, they should not have any uses - // at this point. - rewriter.eraseOp(ifOp); - rewriter.eraseOp(warpOp); return success(); } @@ -2065,19 +2059,16 @@ struct WarpOpScfForOp : public WarpDistributionPattern { escapingValueDistTypes.begin(), escapingValueDistTypes.end()); // Next, we insert all non-`ForOp` yielded values and their distributed - // types. We also create a mapping between the non-`ForOp` yielded value - // index and the corresponding new `WarpOp` yield value index (needed to - // update users later). - llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping; + // types. for (auto [i, v] : llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) { - nonForResultMapping[i] = newWarpOpYieldValues.size(); newWarpOpYieldValues.push_back(v); newWarpOpDistTypes.push_back(warpOp.getResult(i).getType()); } // Create the new `WarpOp` with the updated yield values and types. - WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( - rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes); + SmallVector<size_t> newIndices; + WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices); // Next, we create a new `ForOp` with the init args yielded by the new // `WarpOp`. @@ -2086,7 +2077,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern { // escaping values in the new `WarpOp`. SmallVector<Value> newForOpOperands; for (size_t i = 0; i < escapingValuesStartIdx; ++i) - newForOpOperands.push_back(newWarpOp.getResult(i)); + newForOpOperands.push_back(newWarpOp.getResult(newIndices[i])); // Create a new `ForOp` outside the new `WarpOp` region. OpBuilder::InsertionGuard g(rewriter); @@ -2110,7 +2101,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern { llvm::SmallDenseMap<Value, int64_t> argIndexMapping; for (size_t i = escapingValuesStartIdx; i < escapingValuesStartIdx + escapingValues.size(); ++i) { - innerWarpInput.push_back(newWarpOp.getResult(i)); + innerWarpInput.push_back(newWarpOp.getResult(newIndices[i])); argIndexMapping[escapingValues[i - escapingValuesStartIdx]] = innerWarpInputType.size(); innerWarpInputType.push_back( @@ -2146,20 +2137,11 @@ struct WarpOpScfForOp : public WarpDistributionPattern { if (!innerWarp.getResults().empty()) scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults()); - // Update the users of original `WarpOp` results that were coming from the + // Update the users of the new `WarpOp` results that were coming from the // original `ForOp` to the corresponding new `ForOp` result. for (auto [origIdx, newIdx] : forResultMapping) - rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx), + rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx), newForOp.getResult(newIdx), newForOp); - // Similarly, update any users of the `WarpOp` results that were not - // results of the `ForOp`. - for (auto [origIdx, newIdx] : nonForResultMapping) - rewriter.replaceAllUsesWith(warpOp.getResult(origIdx), - newWarpOp.getResult(newIdx)); - // Remove the original `WarpOp` and `ForOp`, they should not have any uses - // at this point. - rewriter.eraseOp(forOp); - rewriter.eraseOp(warpOp); // Update any users of escaping values that were forwarded to the // inner `WarpOp`. These values are now arguments of the inner `WarpOp`. 
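// Editor's note, not part of the patch: the switch from
// moveRegionToNewWarpOpAndReplaceReturns to ...AndAppendReturns is what makes
// the `newIndices` indirection above necessary; the appended yield values may
// be deduplicated against values the warp op already returned, so the i-th
// appended value is not guaranteed to sit at result position
// numOriginalResults + i of the new warp op.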
newForOp.walk([&](Operation *op) { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index 14639c5..fbae098 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -465,26 +465,33 @@ struct UnrollElementwisePattern : public RewritePattern { auto targetShape = getTargetShape(options, op); if (!targetShape) return failure(); + int64_t targetShapeRank = targetShape->size(); auto dstVecType = cast<VectorType>(op->getResult(0).getType()); SmallVector<int64_t> originalSize = *cast<VectorUnrollOpInterface>(op).getShapeForUnroll(); - // Bail-out if rank(source) != rank(target). The main limitation here is the - // fact that `ExtractStridedSlice` requires the rank for the input and - // output to match. If needed, we can relax this later. - if (originalSize.size() != targetShape->size()) - return rewriter.notifyMatchFailure( - op, "expected input vector rank to match target shape rank"); + int64_t originalShapeRank = originalSize.size(); + Location loc = op->getLoc(); + + // Handle rank mismatch by adding leading unit dimensions to targetShape + SmallVector<int64_t> adjustedTargetShape(originalShapeRank); + int64_t rankDiff = originalShapeRank - targetShapeRank; + std::fill(adjustedTargetShape.begin(), + adjustedTargetShape.begin() + rankDiff, 1); + std::copy(targetShape->begin(), targetShape->end(), + adjustedTargetShape.begin() + rankDiff); + + int64_t adjustedTargetShapeRank = adjustedTargetShape.size(); // Prepare the result vector. Value result = arith::ConstantOp::create(rewriter, loc, dstVecType, rewriter.getZeroAttr(dstVecType)); - SmallVector<int64_t> strides(targetShape->size(), 1); - VectorType newVecType = + SmallVector<int64_t> strides(adjustedTargetShapeRank, 1); + VectorType unrolledVecType = VectorType::get(*targetShape, dstVecType.getElementType()); // Create the unrolled computation. for (SmallVector<int64_t> offsets : - StaticTileOffsetRange(originalSize, *targetShape)) { + StaticTileOffsetRange(originalSize, adjustedTargetShape)) { SmallVector<Value> extractOperands; for (OpOperand &operand : op->getOpOperands()) { auto vecType = dyn_cast<VectorType>(operand.get().getType()); @@ -492,14 +499,31 @@ struct UnrollElementwisePattern : public RewritePattern { extractOperands.push_back(operand.get()); continue; } - extractOperands.push_back( - rewriter.createOrFold<vector::ExtractStridedSliceOp>( - loc, operand.get(), offsets, *targetShape, strides)); + Value extracted = rewriter.createOrFold<vector::ExtractStridedSliceOp>( + loc, operand.get(), offsets, adjustedTargetShape, strides); + + // Reshape to remove leading unit dims if needed + if (adjustedTargetShapeRank > targetShapeRank) { + extracted = rewriter.createOrFold<vector::ShapeCastOp>( + loc, VectorType::get(*targetShape, vecType.getElementType()), + extracted); + } + extractOperands.push_back(extracted); } + Operation *newOp = cloneOpWithOperandsAndTypes( - rewriter, loc, op, extractOperands, newVecType); + rewriter, loc, op, extractOperands, unrolledVecType); + + Value computeResult = newOp->getResult(0); + + // Use strides sized to targetShape for proper insertion + SmallVector<int64_t> insertStrides = + (adjustedTargetShapeRank > targetShapeRank) + ? 
SmallVector<int64_t>(targetShapeRank, 1) + : strides; + result = rewriter.createOrFold<vector::InsertStridedSliceOp>( - loc, newOp->getResult(0), result, offsets, strides); + loc, computeResult, result, offsets, insertStrides); } rewriter.replaceOp(op, result); return success(); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp index 36c498e..f77784a 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp @@ -161,11 +161,24 @@ XeGPUBlockingPass::getTileShape(Operation *op) const { xegpu::UpdateOffsetOp, xegpu::LoadMatrixOp>(op)) return getTileShape(op->getOpResult(0)); if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp, - xegpu::LoadGatherOp, xegpu::StoreMatrixOp>(op)) + xegpu::StoreMatrixOp>(op)) return getTileShape(op->getOpOperand(0)); - if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(op)) + if (isa<xegpu::StoreNdOp>(op)) return getTileShape(op->getOpOperand(1)); + // Handle LoadGatherOp and StoreScatterOp (with and without offset) + if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(op)) { + if (loadGatherOp.getOffsets()) + return getTileShape(loadGatherOp->getOpResult(0)); + else + return getTileShape(loadGatherOp->getOpOperand(0)); + } + + if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op)) + return getTileShape(storeScatterOp.getOffsets() + ? storeScatterOp->getOpOperand(0) + : storeScatterOp->getOpOperand(1)); + if (isa<xegpu::DpasOp>(op)) { std::optional<SmallVector<int64_t>> aTile = getTileShape(op->getOpOperand(0)); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 3d19c5a..9b23dd6 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -2200,10 +2200,9 @@ void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty, os << '>'; } os << '['; - interleave( - loc.getLocations(), - [&](Location loc) { printLocationInternal(loc, pretty); }, - [&]() { os << ", "; }); + interleaveComma(loc.getLocations(), [&](Location loc) { + printLocationInternal(loc, pretty); + }); os << ']'; }) .Default([&](LocationAttr loc) { diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp index 776b5c6..4d81918 100644 --- a/mlir/lib/IR/Diagnostics.cpp +++ b/mlir/lib/IR/Diagnostics.cpp @@ -378,8 +378,10 @@ struct SourceMgrDiagnosticHandlerImpl { } // Otherwise, try to load the source file. 
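// Editor's note, not part of the patch: the replacement below loads the file
// directly instead of going through SourceMgr::AddIncludeFile. That appears
// to be a deliberate decoupling from SourceMgr's include-directory/VFS
// machinery (note the setVirtualFileSystem() calls this commit adds to other
// SourceMgr users), and it also lets the handler return buffer id 0 cleanly
// when the file cannot be opened.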
- std::string ignored; - unsigned id = mgr.AddIncludeFile(std::string(filename), SMLoc(), ignored); + auto bufferOrErr = llvm::MemoryBuffer::getFile(filename); + if (!bufferOrErr) + return 0; + unsigned id = mgr.AddNewSourceBuffer(std::move(*bufferOrErr), SMLoc()); filenameToBufId[filename] = id; return id; } diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 1fa04ed..89b81cf 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -121,6 +121,11 @@ namespace mlir { class MLIRContextImpl { public: //===--------------------------------------------------------------------===// + // Remark + //===--------------------------------------------------------------------===// + std::unique_ptr<remark::detail::RemarkEngine> remarkEngine; + + //===--------------------------------------------------------------------===// // Debugging //===--------------------------------------------------------------------===// @@ -135,11 +140,6 @@ public: DiagnosticEngine diagEngine; //===--------------------------------------------------------------------===// - // Remark - //===--------------------------------------------------------------------===// - std::unique_ptr<remark::detail::RemarkEngine> remarkEngine; - - //===--------------------------------------------------------------------===// // Options //===--------------------------------------------------------------------===// @@ -357,7 +357,10 @@ MLIRContext::MLIRContext(const DialectRegistry ®istry, Threading setting) impl->affineUniquer.registerParametricStorageType<IntegerSetStorage>(); } -MLIRContext::~MLIRContext() = default; +MLIRContext::~MLIRContext() { + // finalize remark engine before destroying anything else. + impl->remarkEngine.reset(); +} /// Copy the specified array of elements into memory managed by the provided /// bump pointer allocator. This assumes the elements are all PODs. diff --git a/mlir/lib/IR/Remarks.cpp b/mlir/lib/IR/Remarks.cpp index a55f61a..031eae2 100644 --- a/mlir/lib/IR/Remarks.cpp +++ b/mlir/lib/IR/Remarks.cpp @@ -16,7 +16,7 @@ #include "llvm/ADT/StringRef.h" using namespace mlir::remark::detail; - +using namespace mlir::remark; //------------------------------------------------------------------------------ // Remark //------------------------------------------------------------------------------ @@ -70,7 +70,7 @@ static void printArgs(llvm::raw_ostream &os, llvm::ArrayRef<Remark::Arg> args) { void Remark::print(llvm::raw_ostream &os, bool printLocation) const { // Header: [Type] pass:remarkName StringRef type = getRemarkTypeString(); - StringRef categoryName = getFullCategoryName(); + StringRef categoryName = getCombinedCategoryName(); StringRef name = remarkName; os << '[' << type << "] "; @@ -81,9 +81,10 @@ void Remark::print(llvm::raw_ostream &os, bool printLocation) const { os << "Function=" << getFunction() << " | "; if (printLocation) { - if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(getLocation())) + if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(getLocation())) { os << " @" << flc.getFilename() << ":" << flc.getLine() << ":" << flc.getColumn(); + } } printArgs(os, getArgs()); @@ -140,7 +141,7 @@ llvm::remarks::Remark Remark::generateRemark() const { r.RemarkType = getRemarkType(); r.RemarkName = getRemarkName(); // MLIR does not use passes; instead, it has categories and sub-categories. 
- r.PassName = getFullCategoryName(); + r.PassName = getCombinedCategoryName(); r.FunctionName = getFunction(); r.Loc = locLambda(); for (const Remark::Arg &arg : getArgs()) { @@ -225,26 +226,42 @@ InFlightRemark RemarkEngine::emitOptimizationRemarkAnalysis(Location loc, // RemarkEngine //===----------------------------------------------------------------------===// -void RemarkEngine::report(const Remark &&remark) { +void RemarkEngine::reportImpl(const Remark &remark) { // Stream the remark - if (remarkStreamer) + if (remarkStreamer) { remarkStreamer->streamOptimizationRemark(remark); + } // Print using MLIR's diagnostic if (printAsEmitRemarks) emitRemark(remark.getLocation(), remark.getMsg()); } +void RemarkEngine::report(const Remark &&remark) { + if (remarkEmittingPolicy) + remarkEmittingPolicy->reportRemark(remark); +} + RemarkEngine::~RemarkEngine() { + if (remarkEmittingPolicy) + remarkEmittingPolicy->finalize(); + if (remarkStreamer) remarkStreamer->finalize(); } -llvm::LogicalResult -RemarkEngine::initialize(std::unique_ptr<MLIRRemarkStreamerBase> streamer, - std::string *errMsg) { - // If you need to validate categories/filters, do so here and set errMsg. +llvm::LogicalResult RemarkEngine::initialize( + std::unique_ptr<MLIRRemarkStreamerBase> streamer, + std::unique_ptr<RemarkEmittingPolicyBase> remarkEmittingPolicy, + std::string *errMsg) { + remarkStreamer = std::move(streamer); + + auto reportFunc = + std::bind(&RemarkEngine::reportImpl, this, std::placeholders::_1); + remarkEmittingPolicy->initialize(ReportFn(std::move(reportFunc))); + + this->remarkEmittingPolicy = std::move(remarkEmittingPolicy); return success(); } @@ -301,14 +318,15 @@ RemarkEngine::RemarkEngine(bool printAsEmitRemarks, } llvm::LogicalResult mlir::remark::enableOptimizationRemarks( - MLIRContext &ctx, - std::unique_ptr<remark::detail::MLIRRemarkStreamerBase> streamer, - const remark::RemarkCategories &cats, bool printAsEmitRemarks) { + MLIRContext &ctx, std::unique_ptr<detail::MLIRRemarkStreamerBase> streamer, + std::unique_ptr<detail::RemarkEmittingPolicyBase> remarkEmittingPolicy, + const RemarkCategories &cats, bool printAsEmitRemarks) { auto engine = - std::make_unique<remark::detail::RemarkEngine>(printAsEmitRemarks, cats); + std::make_unique<detail::RemarkEngine>(printAsEmitRemarks, cats); std::string errMsg; - if (failed(engine->initialize(std::move(streamer), &errMsg))) { + if (failed(engine->initialize(std::move(streamer), + std::move(remarkEmittingPolicy), &errMsg))) { llvm::report_fatal_error( llvm::Twine("Failed to initialize remark engine. 
Error: ") + errMsg); } @@ -316,3 +334,12 @@ llvm::LogicalResult mlir::remark::enableOptimizationRemarks( return success(); } + +//===----------------------------------------------------------------------===// +// Remark emitting policies +//===----------------------------------------------------------------------===// + +namespace mlir::remark { +RemarkEmittingPolicyAll::RemarkEmittingPolicyAll() = default; +RemarkEmittingPolicyFinal::RemarkEmittingPolicyFinal() = default; +} // namespace mlir::remark diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt index 388de1c..f96af02 100644 --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -9,6 +9,7 @@ set(LLVM_OPTIONAL_SOURCES FunctionInterfaces.cpp IndexingMapOpInterface.cpp InferIntRangeInterface.cpp + InferStridedMetadataInterface.cpp InferTypeOpInterface.cpp LoopLikeInterface.cpp MemOpInterfaces.cpp @@ -64,6 +65,21 @@ add_mlir_library(MLIRFunctionInterfaces add_mlir_interface_library(IndexingMapOpInterface) add_mlir_interface_library(InferIntRangeInterface) + +add_mlir_library(MLIRInferStridedMetadataInterface + InferStridedMetadataInterface.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces + + DEPENDS + MLIRInferStridedMetadataInterfaceIncGen + + LINK_LIBS PUBLIC + MLIRInferIntRangeInterface + MLIRIR +) + add_mlir_interface_library(InferTypeOpInterface) add_mlir_library(MLIRLoopLikeInterface diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp index 9f3e97d..84fc9b8 100644 --- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp +++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp @@ -146,6 +146,25 @@ raw_ostream &mlir::operator<<(raw_ostream &os, const IntegerValueRange &range) { return os; } +SmallVector<IntegerValueRange> +mlir::getIntValueRanges(ArrayRef<OpFoldResult> values, + GetIntRangeFn getIntRange, int32_t indexBitwidth) { + SmallVector<IntegerValueRange> ranges; + ranges.reserve(values.size()); + for (OpFoldResult ofr : values) { + if (auto value = dyn_cast<Value>(ofr)) { + ranges.push_back(getIntRange(value)); + continue; + } + + // Create a constant range. + auto attr = cast<IntegerAttr>(cast<Attribute>(ofr)); + ranges.emplace_back(ConstantIntRanges::constant( + attr.getValue().sextOrTrunc(indexBitwidth))); + } + return ranges; +} + void mlir::intrange::detail::defaultInferResultRanges( InferIntRangeInterface interface, ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRanges) { diff --git a/mlir/lib/Interfaces/InferStridedMetadataInterface.cpp b/mlir/lib/Interfaces/InferStridedMetadataInterface.cpp new file mode 100644 index 0000000..483e9f1 --- /dev/null +++ b/mlir/lib/Interfaces/InferStridedMetadataInterface.cpp @@ -0,0 +1,36 @@ +//===- InferStridedMetadataInterface.cpp - Strided md inference interface -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/InferStridedMetadataInterface.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include <optional> + +using namespace mlir; + +#include "mlir/Interfaces/InferStridedMetadataInterface.cpp.inc" + +void StridedMetadataRange::print(raw_ostream &os) const { + if (isUninitialized()) { + os << "strided_metadata<None>"; + return; + } + os << "strided_metadata<offset = ["; + llvm::interleaveComma(*offsets, os, [&](const ConstantIntRanges &range) { + os << "{" << range << "}"; + }); + os << "], sizes = ["; + llvm::interleaveComma(sizes, os, [&](const ConstantIntRanges &range) { + os << "{" << range << "}"; + }); + os << "], strides = ["; + llvm::interleaveComma(strides, os, [&](const ConstantIntRanges &range) { + os << "{" << range << "}"; + }); + os << "]>"; +} diff --git a/mlir/lib/Remark/RemarkStreamer.cpp b/mlir/lib/Remark/RemarkStreamer.cpp index d213a1a..bf36286 100644 --- a/mlir/lib/Remark/RemarkStreamer.cpp +++ b/mlir/lib/Remark/RemarkStreamer.cpp @@ -60,6 +60,7 @@ void LLVMRemarkStreamer::finalize() { namespace mlir::remark { LogicalResult enableOptimizationRemarksWithLLVMStreamer( MLIRContext &ctx, StringRef path, llvm::remarks::Format fmt, + std::unique_ptr<detail::RemarkEmittingPolicyBase> remarkEmittingPolicy, const RemarkCategories &cat, bool printAsEmitRemarks) { FailureOr<std::unique_ptr<detail::MLIRRemarkStreamerBase>> sOr = @@ -67,7 +68,8 @@ LogicalResult enableOptimizationRemarksWithLLVMStreamer( if (failed(sOr)) return failure(); - return remark::enableOptimizationRemarks(ctx, std::move(*sOr), cat, + return remark::enableOptimizationRemarks(ctx, std::move(*sOr), + std::move(remarkEmittingPolicy), cat, printAsEmitRemarks); } diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp index 4bbcd8e..db39c70 100644 --- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp +++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp @@ -34,11 +34,9 @@ Location DebugImporter::translateFuncLocation(llvm::Function *func) { return UnknownLoc::get(context); // Add a fused location to link the subprogram information. 
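// Editor's note, not part of the patch: with the NameLoc dropped, the
// imported function location changes roughly as follows (names illustrative):
//   before: loc(fused<#di_subprogram>["myFunc", "f.cpp":42:0])
//   after:  loc(fused<#di_subprogram>["f.cpp":42:0])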
- StringAttr funcName = StringAttr::get(context, subprogram->getName()); StringAttr fileName = StringAttr::get(context, subprogram->getFilename()); return FusedLocWith<DISubprogramAttr>::get( - {NameLoc::get(funcName), - FileLineColLoc::get(fileName, subprogram->getLine(), /*column=*/0)}, + {FileLineColLoc::get(fileName, subprogram->getLine(), /*column=*/0)}, translate(subprogram), context); } diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp index 132be4e..51c6077 100644 --- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp +++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp @@ -956,7 +956,7 @@ inline parsed_inst_t ExpressionParser::buildNumericOp( << ", type = " << ty << " ***"; auto tysToPop = SmallVector<Type, numOperands>(); tysToPop.resize(numOperands); - std::fill(tysToPop.begin(), tysToPop.end(), ty); + llvm::fill(tysToPop, ty); auto operands = popOperands(tysToPop); if (failed(operands)) return failure(); diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp index c883baa..3236b4f 100644 --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -27,6 +27,7 @@ #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/ScopedPrinter.h" +#include "llvm/Support/VirtualFileSystem.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Parser.h" #include <optional> @@ -828,6 +829,7 @@ LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc, llvm::SourceMgr tdSrcMgr; tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc()); tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs()); + tdSrcMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem()); // This class provides a context argument for the llvm::SourceMgr diagnostic // handler. diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp index 30fd384..9ef405d 100644 --- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp @@ -37,6 +37,7 @@ #include "llvm/ADT/StringRef.h" #include "llvm/Remarks/RemarkFormat.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/LogicalResult.h" #include "llvm/Support/ManagedStatic.h" @@ -226,6 +227,18 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig { "bitstream", "Print bitstream file")), llvm::cl::cat(remarkCategory)}; + static llvm::cl::opt<RemarkPolicy, /*ExternalStorage=*/true> remarkPolicy{ + "remark-policy", + llvm::cl::desc("Specify the policy for remark output."), + cl::location(remarkPolicyFlag), + llvm::cl::value_desc("format"), + llvm::cl::init(RemarkPolicy::REMARK_POLICY_ALL), + llvm::cl::values(clEnumValN(RemarkPolicy::REMARK_POLICY_ALL, "all", + "Print all remarks"), + clEnumValN(RemarkPolicy::REMARK_POLICY_FINAL, "final", + "Print final remarks")), + llvm::cl::cat(remarkCategory)}; + static cl::opt<std::string, /*ExternalStorage=*/true> remarksAll( "remarks-filter", cl::desc("Show all remarks: passed, missed, failed, analysis"), @@ -517,18 +530,28 @@ performActions(raw_ostream &os, return failure(); context->enableMultithreading(wasThreadingEnabled); - + // Set the remark categories and policy. 
remark::RemarkCategories cats{ config.getRemarksAllFilter(), config.getRemarksPassedFilter(), config.getRemarksMissedFilter(), config.getRemarksAnalyseFilter(), config.getRemarksFailedFilter()}; mlir::MLIRContext &ctx = *context; + // Helper to create the appropriate policy based on configuration + auto createPolicy = [&config]() + -> std::unique_ptr<mlir::remark::detail::RemarkEmittingPolicyBase> { + if (config.getRemarkPolicy() == RemarkPolicy::REMARK_POLICY_ALL) + return std::make_unique<mlir::remark::RemarkEmittingPolicyAll>(); + if (config.getRemarkPolicy() == RemarkPolicy::REMARK_POLICY_FINAL) + return std::make_unique<mlir::remark::RemarkEmittingPolicyFinal>(); + + llvm_unreachable("Invalid remark policy"); + }; switch (config.getRemarkFormat()) { case RemarkFormat::REMARK_FORMAT_STDOUT: if (failed(mlir::remark::enableOptimizationRemarks( - ctx, nullptr, cats, true /*printAsEmitRemarks*/))) + ctx, nullptr, createPolicy(), cats, true /*printAsEmitRemarks*/))) return failure(); break; @@ -537,7 +560,7 @@ performActions(raw_ostream &os, ? "mlir-remarks.yaml" : config.getRemarksOutputFile(); if (failed(mlir::remark::enableOptimizationRemarksWithLLVMStreamer( - ctx, file, llvm::remarks::Format::YAML, cats))) + ctx, file, llvm::remarks::Format::YAML, createPolicy(), cats))) return failure(); break; } @@ -547,7 +570,7 @@ performActions(raw_ostream &os, ? "mlir-remarks.bitstream" : config.getRemarksOutputFile(); if (failed(mlir::remark::enableOptimizationRemarksWithLLVMStreamer( - ctx, file, llvm::remarks::Format::Bitstream, cats))) + ctx, file, llvm::remarks::Format::Bitstream, createPolicy(), cats))) return failure(); break; } @@ -593,6 +616,12 @@ performActions(raw_ostream &os, AsmState asmState(op.get(), OpPrintingFlags(), /*locationMap=*/nullptr, &fallbackResourceMap); os << OpWithState(op.get(), asmState) << '\n'; + + // This is required if the remark policy is final. Otherwise, the remarks are + // not emitted. 
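// Editor's note, not part of the patch: per the Remarks.cpp hunk above,
// RemarkEngine::report() now hands every remark to the emitting policy.
// RemarkEmittingPolicyAll presumably forwards each remark immediately, while
// RemarkEmittingPolicyFinal buffers remarks and only streams them when the
// policy is finalized; hence the explicit flush below before the context is
// torn down.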
+ if (remark::detail::RemarkEngine *engine = ctx.getRemarkEngine()) + engine->getRemarkEmittingPolicy()->finalize(); + return success(); } diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp index 60b9567..1dbe7eca 100644 --- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp @@ -31,6 +31,7 @@ #include "llvm/Support/FileSystem.h" #include "llvm/Support/LSP/Logging.h" #include "llvm/Support/Path.h" +#include "llvm/Support/VirtualFileSystem.h" #include <optional> using namespace mlir; @@ -402,6 +403,7 @@ PDLDocument::PDLDocument(const llvm::lsp::URIForFile &uri, StringRef contents, llvm::append_range(includeDirs, extraDirs); sourceMgr.setIncludeDirs(includeDirs); + sourceMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem()); sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); astContext.getDiagEngine().setHandlerFn([&](const ast::Diagnostic &diag) { diff --git a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp index 3080b78..2d817be 100644 --- a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp +++ b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp @@ -17,6 +17,7 @@ #include "llvm/Support/LSP/Logging.h" #include "llvm/Support/LSP/Protocol.h" #include "llvm/Support/Path.h" +#include "llvm/Support/VirtualFileSystem.h" #include "llvm/TableGen/Parser.h" #include "llvm/TableGen/Record.h" #include <optional> @@ -448,6 +449,7 @@ void TableGenTextFile::initialize( return; } sourceMgr.setIncludeDirs(includeDirs); + sourceMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem()); sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); // This class provides a context argument for the SourceMgr diagnostic diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp index 111f58e..5f3b04a 100644 --- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp @@ -66,7 +66,9 @@ size_t mlir::moveLoopInvariantCode( size_t numMoved = 0; for (Region *region : regions) { - LDBG() << "Original loop:\n" << *region->getParentOp(); + LDBG() << "Original loop:\n" + << OpWithFlags(region->getParentOp(), + OpPrintingFlags().skipRegions()); std::queue<Operation *> worklist; // Add top-level operations in the loop body to the worklist. @@ -90,7 +92,8 @@ size_t mlir::moveLoopInvariantCode( !canBeHoisted(op, definedOutside)) continue; - LDBG() << "Moving loop-invariant op: " << *op; + LDBG() << "Moving loop-invariant op: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); moveOutOfRegion(op, region); ++numMoved; @@ -111,9 +114,7 @@ size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) { [&](Value value, Region *) { return loopLike.isDefinedOutsideOfLoop(value); }, - [&](Operation *op, Region *) { - return isMemoryEffectFree(op) && isSpeculatable(op); - }, + [&](Operation *op, Region *) { return isPure(op); }, [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); }); } |
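// Editor's sketch, not part of the patch: how a host tool other than mlir-opt
// might opt into the new remark-emitting-policy parameter, mirroring the
// wiring in performActions() above. The RemarkStreamer.h path and the
// RemarkCategories field types/order are assumptions inferred from the hunks
// in this commit.
#include <memory>
#include <optional>

#include "mlir/IR/MLIRContext.h"
#include "mlir/Remark/RemarkStreamer.h" // assumed header for the LLVM streamer
#include "llvm/Remarks/RemarkFormat.h"

static llvm::LogicalResult enableFinalYamlRemarks(mlir::MLIRContext &ctx) {
  // Category filters, in the order used by performActions() above:
  // all, passed, missed, analysis, failed.
  mlir::remark::RemarkCategories cats{/*all=*/".*", /*passed=*/std::nullopt,
                                      /*missed=*/std::nullopt,
                                      /*analyse=*/std::nullopt,
                                      /*failed=*/std::nullopt};
  // With the "final" policy, remarks are buffered and only streamed when the
  // policy is finalized (see the explicit finalize() call added above).
  auto policy = std::make_unique<mlir::remark::RemarkEmittingPolicyFinal>();
  return mlir::remark::enableOptimizationRemarksWithLLVMStreamer(
      ctx, "remarks.yaml", llvm::remarks::Format::YAML, std::move(policy),
      cats, /*printAsEmitRemarks=*/false);
}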