diff options
Diffstat (limited to 'mlir/lib')
19 files changed, 1004 insertions, 128 deletions
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt index 609cb34..db10ebc 100644 --- a/mlir/lib/Analysis/CMakeLists.txt +++ b/mlir/lib/Analysis/CMakeLists.txt @@ -40,6 +40,7 @@ add_mlir_library(MLIRAnalysis DataFlow/IntegerRangeAnalysis.cpp DataFlow/LivenessAnalysis.cpp DataFlow/SparseAnalysis.cpp + DataFlow/StridedMetadataRangeAnalysis.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Analysis @@ -53,6 +54,7 @@ add_mlir_library(MLIRAnalysis MLIRDataLayoutInterfaces MLIRFunctionInterfaces MLIRInferIntRangeInterface + MLIRInferStridedMetadataInterface MLIRInferTypeOpInterface MLIRLoopLikeInterface MLIRPresburger diff --git a/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp new file mode 100644 index 0000000..01c9daf --- /dev/null +++ b/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp @@ -0,0 +1,127 @@ +//===- StridedMetadataRangeAnalysis.cpp - Integer range analysis --------*- C++ +//-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the dataflow analysis class for integer range inference +// which is used in transformations over the `arith` dialect such as +// branch elimination or signed->unsigned rewriting +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h" +#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/DebugStringHelper.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" + +#define DEBUG_TYPE "strided-metadata-range-analysis" + +using namespace mlir; +using namespace mlir::dataflow; + +/// Get the entry state for a value. For any value that is not a ranked memref, +/// this function sets the metadata to a top state with no offsets, sizes, or +/// strides. For `memref` types, this function will use the metadata in the type +/// to try to deduce as much informaiton as possible. +static StridedMetadataRange getEntryStateImpl(Value v, int32_t indexBitwidth) { + // TODO: generalize this method with a type interface. + auto mTy = dyn_cast<BaseMemRefType>(v.getType()); + + // If not a memref or it's un-ranked, don't infer any metadata. + if (!mTy || !mTy.hasRank()) + return StridedMetadataRange::getMaxRanges(indexBitwidth, 0, 0, 0); + + // Get the top state. + auto metadata = + StridedMetadataRange::getMaxRanges(indexBitwidth, mTy.getRank()); + + // Compute the offset and strides. + int64_t offset; + SmallVector<int64_t> strides; + if (failed(cast<MemRefType>(mTy).getStridesAndOffset(strides, offset))) + return metadata; + + // Refine the metadata if we know it from the type. + if (!ShapedType::isDynamic(offset)) { + metadata.getOffsets()[0] = + ConstantIntRanges::constant(APInt(indexBitwidth, offset)); + } + for (auto &&[size, range] : + llvm::zip_equal(mTy.getShape(), metadata.getSizes())) { + if (ShapedType::isDynamic(size)) + continue; + range = ConstantIntRanges::constant(APInt(indexBitwidth, size)); + } + for (auto &&[stride, range] : + llvm::zip_equal(strides, metadata.getStrides())) { + if (ShapedType::isDynamic(stride)) + continue; + range = ConstantIntRanges::constant(APInt(indexBitwidth, stride)); + } + + return metadata; +} + +StridedMetadataRangeAnalysis::StridedMetadataRangeAnalysis( + DataFlowSolver &solver, int32_t indexBitwidth) + : SparseForwardDataFlowAnalysis(solver), indexBitwidth(indexBitwidth) { + assert(indexBitwidth > 0 && "invalid bitwidth"); +} + +void StridedMetadataRangeAnalysis::setToEntryState( + StridedMetadataRangeLattice *lattice) { + propagateIfChanged(lattice, lattice->join(getEntryStateImpl( + lattice->getAnchor(), indexBitwidth))); +} + +LogicalResult StridedMetadataRangeAnalysis::visitOperation( + Operation *op, ArrayRef<const StridedMetadataRangeLattice *> operands, + ArrayRef<StridedMetadataRangeLattice *> results) { + auto inferrable = dyn_cast<InferStridedMetadataOpInterface>(op); + + // Bail if we cannot reason about the op. + if (!inferrable) { + setAllToEntryStates(results); + return success(); + } + + LDBG() << "Inferring metadata for: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); + + // Helper function to retrieve int range values. + auto getIntRange = [&](Value value) -> IntegerValueRange { + auto lattice = getOrCreateFor<IntegerValueRangeLattice>( + getProgramPointAfter(op), value); + return lattice ? lattice->getValue() : IntegerValueRange(); + }; + + // Convert the arguments lattices to a vector. + SmallVector<StridedMetadataRange> argRanges = llvm::map_to_vector( + operands, [](const StridedMetadataRangeLattice *lattice) { + return lattice->getValue(); + }); + + // Callback to set metadata on a result. + auto joinCallback = [&](Value v, const StridedMetadataRange &md) { + auto result = cast<OpResult>(v); + assert(llvm::is_contained(op->getResults(), result)); + LDBG() << "- Inferred metadata: " << md; + StridedMetadataRangeLattice *lattice = results[result.getResultNumber()]; + ChangeResult changed = lattice->join(md); + LDBG() << "- Joined metadata: " << lattice->getValue(); + propagateIfChanged(lattice, changed); + }; + + // Infer the metadata. + inferrable.inferStridedMetadataRanges(argRanges, getIntRange, joinCallback, + indexBitwidth); + return success(); +} diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 71986f8..bebf1b8 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -40,6 +40,7 @@ add_subdirectory(MathToLibm) add_subdirectory(MathToLLVM) add_subdirectory(MathToROCDL) add_subdirectory(MathToSPIRV) +add_subdirectory(MathToXeVM) add_subdirectory(MemRefToEmitC) add_subdirectory(MemRefToLLVM) add_subdirectory(MemRefToSPIRV) diff --git a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt new file mode 100644 index 0000000..050c0ed --- /dev/null +++ b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt @@ -0,0 +1,22 @@ +add_mlir_conversion_library(MLIRMathToXeVM + MathToXeVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToXeVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRArithAttrToLLVMConversion + MLIRArithDialect + MLIRLLVMCommonConversion + MLIRLLVMDialect + MLIRMathDialect + MLIRXeVMDialect + MLIRPass + MLIRTransforms + ) diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp new file mode 100644 index 0000000..0fe31d0 --- /dev/null +++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp @@ -0,0 +1,167 @@ +//===-- MathToXeVM.cpp - conversion from Math to XeVM ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/MathToXeVM/MathToXeVM.h" +#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" +#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/Pass/Pass.h" +#include "llvm/Support/FormatVariadic.h" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTMATHTOXEVM +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +#define DEBUG_TYPE "math-to-xevm" + +/// Convert math ops marked with `fast` (`afn`) to native OpenCL intrinsics. +template <typename Op> +struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> { + + ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc, + PatternBenefit benefit = 1) + : OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {} + + LogicalResult + matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isSPIRVCompatibleFloatOrVec(op.getType())) + return failure(); + + arith::FastMathFlags fastFlags = op.getFastmath(); + if (!arith::bitEnumContainsAll(fastFlags, arith::FastMathFlags::afn)) + return rewriter.notifyMatchFailure(op, "not a fastmath `afn` operation"); + + SmallVector<Type, 1> operandTypes; + for (auto operand : adaptor.getOperands()) { + Type opTy = operand.getType(); + // This pass only supports operations on vectors that are already in SPIRV + // supported vector sizes: Distributing unsupported vector sizes to SPIRV + // supported vector sizes are done in other blocking optimization passes. + if (!isSPIRVCompatibleFloatOrVec(opTy)) + return rewriter.notifyMatchFailure( + op, llvm::formatv("incompatible operand type: '{0}'", opTy)); + operandTypes.push_back(opTy); + } + + auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>(); + auto funcOpRes = LLVM::lookupOrCreateFn( + rewriter, moduleOp, getMangledNativeFuncName(operandTypes), + operandTypes, op.getType()); + assert(!failed(funcOpRes)); + LLVM::LLVMFuncOp funcOp = funcOpRes.value(); + + auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>( + op, funcOp, adaptor.getOperands()); + // Preserve fastmath flags in our MLIR op when converting to llvm function + // calls, in order to allow further fastmath optimizations: We thus need to + // convert arith fastmath attrs into attrs recognized by llvm. + arith::AttrConvertFastMathToLLVM<Op, LLVM::CallOp> fastAttrConverter(op); + mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0]; + callOp->setAttr(fastAttr.getName(), fastAttr.getValue()); + return success(); + } + + inline bool isSPIRVCompatibleFloatOrVec(Type type) const { + if (type.isFloat()) + return true; + if (auto vecType = dyn_cast<VectorType>(type)) { + if (!vecType.getElementType().isFloat()) + return false; + // SPIRV distinguishes between vectors and matrices: OpenCL native math + // intrsinics are not compatible with matrices. + ArrayRef<int64_t> shape = vecType.getShape(); + if (shape.size() != 1) + return false; + // SPIRV only allows vectors of size 2, 3, 4, 8, 16. + if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 || + shape[0] == 16) + return true; + } + return false; + } + + inline std::string + getMangledNativeFuncName(const ArrayRef<Type> operandTypes) const { + std::string mangledFuncName = + "_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str(); + + auto appendFloatToMangledFunc = [&mangledFuncName](Type type) { + if (type.isF32()) + mangledFuncName += "f"; + else if (type.isF16()) + mangledFuncName += "Dh"; + else if (type.isF64()) + mangledFuncName += "d"; + }; + + for (auto type : operandTypes) { + if (auto vecType = dyn_cast<VectorType>(type)) { + mangledFuncName += "Dv" + std::to_string(vecType.getShape()[0]) + "_"; + appendFloatToMangledFunc(vecType.getElementType()); + } else + appendFloatToMangledFunc(type); + } + + return mangledFuncName; + } + + const StringRef nativeFunc; +}; + +void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns, + bool convertArith) { + patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(patterns.getContext(), + "__spirv_ocl_native_exp"); + patterns.add<ConvertNativeFuncPattern<math::CosOp>>(patterns.getContext(), + "__spirv_ocl_native_cos"); + patterns.add<ConvertNativeFuncPattern<math::Exp2Op>>( + patterns.getContext(), "__spirv_ocl_native_exp2"); + patterns.add<ConvertNativeFuncPattern<math::LogOp>>(patterns.getContext(), + "__spirv_ocl_native_log"); + patterns.add<ConvertNativeFuncPattern<math::Log2Op>>( + patterns.getContext(), "__spirv_ocl_native_log2"); + patterns.add<ConvertNativeFuncPattern<math::Log10Op>>( + patterns.getContext(), "__spirv_ocl_native_log10"); + patterns.add<ConvertNativeFuncPattern<math::PowFOp>>( + patterns.getContext(), "__spirv_ocl_native_powr"); + patterns.add<ConvertNativeFuncPattern<math::RsqrtOp>>( + patterns.getContext(), "__spirv_ocl_native_rsqrt"); + patterns.add<ConvertNativeFuncPattern<math::SinOp>>(patterns.getContext(), + "__spirv_ocl_native_sin"); + patterns.add<ConvertNativeFuncPattern<math::SqrtOp>>( + patterns.getContext(), "__spirv_ocl_native_sqrt"); + patterns.add<ConvertNativeFuncPattern<math::TanOp>>(patterns.getContext(), + "__spirv_ocl_native_tan"); + if (convertArith) + patterns.add<ConvertNativeFuncPattern<arith::DivFOp>>( + patterns.getContext(), "__spirv_ocl_native_divide"); +} + +namespace { +struct ConvertMathToXeVMPass + : public impl::ConvertMathToXeVMBase<ConvertMathToXeVMPass> { + using Base::Base; + void runOnOperation() override; +}; +} // namespace + +void ConvertMathToXeVMPass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + populateMathToXeVMConversionPatterns(patterns, convertArith); + ConversionTarget target(getContext()); + target.addLegalDialect<BuiltinDialect, LLVM::LLVMDialect>(); + if (failed( + applyPartialConversion(getOperation(), target, std::move(patterns)))) + signalPassFailure(); +} diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index a5336ed..00df14b1 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1392,6 +1392,137 @@ public: } }; +// Collapse tensor<1xiN> into tensor<iN> +// E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16> +static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input, + Location loc) { + SmallVector<ReassociationExprs, 1> reassociation; + // Create the collapsed type + auto inputType = cast<RankedTensorType>(input.getType()); + auto elemType = inputType.getElementType(); + auto collapsedType = RankedTensorType::get({}, elemType); + // Emit the collapse op + return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, input, + reassociation); +} + +static llvm::SmallVector<int8_t> +convertToI8(const llvm::SmallVector<int32_t> &input) { + llvm::SmallVector<int8_t> output; + output.reserve(input.size()); + + for (auto v : llvm::map_range( + input, [](int32_t val) { return static_cast<int8_t>(val); })) { + output.push_back(v); + } + return output; +} + +// The shift or multiplier may be either constant or non-constant, depending on +// whether dynamic extension is enabled. +// - If the shift or multiplier is non-constant, add it as an input to +// linalg::GenericOp by: +// 1. Pushing it into 'genericInputs'. +// 2. Appending a corresponding affine map to 'indexingMaps'. +// - If the shift or multiplier is constant, set 'constant' instead. +static void setupLinalgGenericOpInputAndIndexingMap( + PatternRewriter &rewriter, llvm::SmallVector<int32_t> &values, + SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps, + bool isConstant, tosa::RescaleOp op, Value &constant, int64_t &arg, + bool isShift = false) { + + auto loc = op.getLoc(); + auto inputTy = cast<ShapedType>(op.getInput().getType()); + unsigned rank = inputTy.getRank(); + SmallVector<AffineExpr, 2> exprs = {rewriter.getAffineDimExpr(rank - 1)}; + + if (isConstant) { + // If we are rescaling per-channel then we need to store the + // values in a buffer. + if (values.size() == 1) { + IntegerAttr intAttr = isShift + ? rewriter.getI8IntegerAttr(values.front()) + : rewriter.getI32IntegerAttr(values.front()); + constant = rewriter.create<arith::ConstantOp>(loc, intAttr); + } else { + auto elementType = + isShift ? rewriter.getIntegerType(8) : rewriter.getI32Type(); + auto tensorType = RankedTensorType::get( + {static_cast<int64_t>(values.size())}, elementType); + DenseIntElementsAttr EltAttr; + if (isShift) + EltAttr = DenseIntElementsAttr::get(tensorType, convertToI8(values)); + else + EltAttr = DenseIntElementsAttr::get(tensorType, values); + genericInputs.push_back( + arith::ConstantOp::create(rewriter, loc, EltAttr)); + indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, + /*symbolCount=*/0, exprs, + rewriter.getContext())); + } + } else { + // If we are not rescaling per-channel then we need to collapse 1xN to N + // and push broadcastMap. + auto operand = isShift ? op.getShift() : op.getMultiplier(); + auto tensorType = dyn_cast<RankedTensorType>(operand.getType()); + if (tensorType && tensorType.hasStaticShape() && + tensorType.getShape()[0] == 1) { + // broadcastMap = affine_map<(d0, d1) -> ()> + // It would affect as broadcast for scalar values in linalg::GenericOp. + AffineMap broadcastMap = + AffineMap::get(rank, 0, {}, rewriter.getContext()); + genericInputs.push_back(collapse1xNTensorToN(rewriter, operand, loc)); + indexingMaps.push_back(broadcastMap); + } else { + genericInputs.push_back(operand); + indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, + /*symbolCount=*/0, exprs, + rewriter.getContext())); + } + } + arg = indexingMaps.size() - 1; +} + +// Return the extended Zp to be used in subsequent arithmetic operations. +static Value getExtendZp(OpBuilder &builder, Type valueTy, + FailureOr<int64_t> maybeZp, Location loc, + ValueRange blockArgs, int64_t zpArg, + bool isOutputZp = false) { + Value result; + const int32_t bitwidth = valueTy.getIntOrFloatBitWidth(); + const uint32_t attrBitwidth = + isOutputZp ? 32 : (bitwidth > 32 ? bitwidth : 32); + auto extendType = builder.getIntegerType(attrBitwidth); + // The Zp value can be either constant or non-constant, depending on + // whether dynamic extension is enabled. + // If 'maybeZp' fails, it indicates that Zp is non-constant and will + // be passed as an input to linalg::GenericOp. + if (failed(maybeZp)) { + result = blockArgs[zpArg]; + auto zpTy = result.getType(); + if (zpTy.getIntOrFloatBitWidth() < attrBitwidth) { + // For ExtUIOp, the input must be signless. + // UnrealizedConversionCastOp will cast the input to signless type. + if (zpTy.isUnsignedInteger()) { + result = + UnrealizedConversionCastOp::create( + builder, loc, + builder.getIntegerType(zpTy.getIntOrFloatBitWidth()), result) + .getResult(0); + } + if (zpTy.isUnsignedInteger()) { + return builder.create<arith::ExtUIOp>(loc, extendType, result); + } else { + return builder.create<arith::ExtSIOp>(loc, extendType, result); + } + } + } else { + return builder.create<arith::ConstantOp>( + loc, IntegerAttr::get(extendType, *maybeZp)); + } + return result; +} + class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> { public: using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern; @@ -1423,40 +1554,46 @@ public: } } - // The shift and multiplier values. DenseElementsAttr shiftElems; - if (!matchPattern(op.getShift(), m_Constant(&shiftElems))) - return rewriter.notifyMatchFailure( - op, "tosa.rescale requires constant shift input values"); + bool isShiftConstant = false; + if (matchPattern(op.getShift(), m_Constant(&shiftElems))) + isShiftConstant = true; DenseElementsAttr multiplierElems; - if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems))) - return rewriter.notifyMatchFailure( - op, "tosa.rescale requires constant multiplier input values"); - - llvm::SmallVector<int8_t> shiftValues = - llvm::to_vector(shiftElems.getValues<int8_t>()); - // explicit cast is required here - llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector( - llvm::map_range(multiplierElems.getValues<IntegerAttr>(), - [](IntegerAttr attr) -> int32_t { - return static_cast<int32_t>(attr.getInt()); - })); - - // If we shift by more than the bitwidth, this just sets to 0. - for (int i = 0, s = multiplierValues.size(); i < s; i++) { - if (shiftValues[i] > 63) { - shiftValues[i] = 0; - multiplierValues[i] = 0; + bool isMultiplierConstant = false; + if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems))) + isMultiplierConstant = true; + + llvm::SmallVector<int32_t> shiftValues; + llvm::SmallVector<int32_t> multiplierValues; + bool doubleRound; + + if (isMultiplierConstant && isShiftConstant) { + // explicit cast is required here + shiftValues = llvm::to_vector(llvm::map_range( + shiftElems.getValues<IntegerAttr>(), [](IntegerAttr attr) -> int32_t { + return static_cast<int32_t>(attr.getInt()); + })); + multiplierValues = llvm::to_vector( + llvm::map_range(multiplierElems.getValues<IntegerAttr>(), + [](IntegerAttr attr) -> int32_t { + return static_cast<int32_t>(attr.getInt()); + })); + + // If we shift by more than the bitwidth, this just sets to 0. + for (int i = 0, s = multiplierValues.size(); i < s; i++) { + if (shiftValues[i] > 63) { + shiftValues[i] = 0; + multiplierValues[i] = 0; + } } - } + // Double round only occurs if shift is greater than 31, check that this + // is ever true. + doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND && + llvm::any_of(shiftValues, [](int32_t v) { return v > 31; }); + } else + doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND; - // Double round only occurs if shift is greater than 31, check that this - // is ever true. - - bool doubleRound = - op.getRoundingMode() == RoundingMode::DOUBLE_ROUND && - llvm::any_of(shiftValues, [](int32_t v) { return v > 31; }); RoundingMode roundingMode = doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND; @@ -1468,45 +1605,43 @@ public: // values in a buffer. Value multiplierConstant; int64_t multiplierArg = 0; - if (multiplierValues.size() == 1) { - multiplierConstant = arith::ConstantOp::create( - rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front())); - } else { - SmallVector<AffineExpr, 2> multiplierExprs{ - rewriter.getAffineDimExpr(rank - 1)}; - auto multiplierType = - RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())}, - rewriter.getI32Type()); - genericInputs.push_back(arith::ConstantOp::create( - rewriter, loc, - DenseIntElementsAttr::get(multiplierType, multiplierValues))); - - indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, - /*symbolCount=*/0, multiplierExprs, - rewriter.getContext())); - - multiplierArg = indexingMaps.size() - 1; - } + setupLinalgGenericOpInputAndIndexingMap( + rewriter, multiplierValues, genericInputs, indexingMaps, + isMultiplierConstant, op, multiplierConstant, multiplierArg); // If we are rescaling per-channel then we need to store the shift // values in a buffer. Value shiftConstant; int64_t shiftArg = 0; - if (shiftValues.size() == 1) { - shiftConstant = arith::ConstantOp::create( - rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front())); - } else { - SmallVector<AffineExpr, 2> shiftExprs = { - rewriter.getAffineDimExpr(rank - 1)}; - auto shiftType = - RankedTensorType::get({static_cast<int64_t>(shiftValues.size())}, - rewriter.getIntegerType(8)); - genericInputs.push_back(arith::ConstantOp::create( - rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues))); - indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, - /*symbolCount=*/0, shiftExprs, - rewriter.getContext())); - shiftArg = indexingMaps.size() - 1; + setupLinalgGenericOpInputAndIndexingMap( + rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op, + shiftConstant, shiftArg, true); + + // broadcastMap = affine_map<(d0, d1) -> ()> + // It would affect as broadcast for scalar values in linalg::GenericOp. + AffineMap broadcastMap = AffineMap::get(rank, 0, {}, rewriter.getContext()); + FailureOr<int64_t> maybeIZp = op.getInputZeroPoint(); + FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint(); + // The inputZp and outputZp may be either constant or non-constant, + // depending on whether dynamic extension is enabled. + // - If the zp's are non-constant, add them as an inputs to + // linalg::GenericOp by: + // 1. Pushing it into 'genericInputs'. + // 2. Appending a corresponding affine map to 'indexingMaps'. + // - If the zp's are constant, they would be generated as arith.constant. + int64_t iZpArg = 0; + if (failed(maybeIZp)) { + genericInputs.push_back( + collapse1xNTensorToN(rewriter, op->getOperand(3), loc)); + indexingMaps.push_back(broadcastMap); + iZpArg = indexingMaps.size() - 1; + } + int64_t oZpArg = 0; + if (failed(maybeOZp)) { + genericInputs.push_back( + collapse1xNTensorToN(rewriter, op->getOperand(4), loc)); + indexingMaps.push_back(broadcastMap); + oZpArg = indexingMaps.size() - 1; } // Indexing maps for output values. @@ -1526,36 +1661,17 @@ public: Type valueTy = value.getType(); FailureOr<int64_t> maybeIZp = op.getInputZeroPoint(); - if (failed(maybeIZp)) { - (void)rewriter.notifyMatchFailure( - op, "input zero point cannot be statically determined"); - return; - } - - const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth(); - // Extend zeropoint for sub-32bits widths. - const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32; - auto inputZp = arith::ConstantOp::create( - nestedBuilder, loc, - IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth), - *maybeIZp)); + auto inputZp = getExtendZp(nestedBuilder, valueTy, maybeIZp, + nestedLoc, blockArgs, iZpArg); FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint(); - if (failed(maybeOZp)) { - (void)rewriter.notifyMatchFailure( - op, "output zero point cannot be statically determined"); - return; - }; + auto outputZp = getExtendZp(nestedBuilder, valueTy, maybeOZp, + nestedLoc, blockArgs, oZpArg, true); IntegerType outIntType = cast<IntegerType>(blockArgs.back().getType()); unsigned outBitWidth = outIntType.getWidth(); - const int32_t outAttrBitwidth = 32; assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth"); - auto outputZp = arith::ConstantOp::create( - nestedBuilder, loc, - IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth), - *maybeOZp)); Value multiplier = multiplierConstant ? multiplierConstant : blockArgs[multiplierArg]; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp index 624519f..70faa71 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp @@ -64,12 +64,13 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) { module.walk([&](func::CallOp callOp) { if (func::FuncOp calledFunc = dyn_cast_or_null<func::FuncOp>(callOp.resolveCallable())) { - callerMap[calledFunc].insert(callOp); + if (!calledFunc.isPublic() && !calledFunc.isExternal()) + callerMap[calledFunc].insert(callOp); } }); for (auto funcOp : module.getOps<func::FuncOp>()) { - if (funcOp.isExternal()) + if (funcOp.isExternal() || funcOp.isPublic()) continue; func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); // TODO: Support functions with multiple blocks. diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt index e25a012..1382c7ac 100644 --- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt @@ -5,7 +5,7 @@ add_mlir_dialect_library(MLIRMemRefDialect ValueBoundsOpInterfaceImpl.cpp ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect + ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRef/IR DEPENDS MLIRMemRefOpsIncGen @@ -18,6 +18,7 @@ add_mlir_dialect_library(MLIRMemRefDialect MLIRDialectUtils MLIRInferIntRangeCommon MLIRInferIntRangeInterface + MLIRInferStridedMetadataInterface MLIRInferTypeOpInterface MLIRIR MLIRMemOpInterfaces diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index e9bdcda..507597b 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -3437,6 +3437,65 @@ SubViewOp::bubbleDownCasts(OpBuilder &builder) { return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable()); } +void SubViewOp::inferStridedMetadataRanges( + ArrayRef<StridedMetadataRange> ranges, GetIntRangeFn getIntRange, + SetStridedMetadataRangeFn setMetadata, int32_t indexBitwidth) { + auto isUninitialized = + +[](IntegerValueRange range) { return range.isUninitialized(); }; + + // Bail early if any of the operands metadata is not ready: + SmallVector<IntegerValueRange> offsetOperands = + getIntValueRanges(getMixedOffsets(), getIntRange, indexBitwidth); + if (llvm::any_of(offsetOperands, isUninitialized)) + return; + + SmallVector<IntegerValueRange> sizeOperands = + getIntValueRanges(getMixedSizes(), getIntRange, indexBitwidth); + if (llvm::any_of(sizeOperands, isUninitialized)) + return; + + SmallVector<IntegerValueRange> stridesOperands = + getIntValueRanges(getMixedStrides(), getIntRange, indexBitwidth); + if (llvm::any_of(stridesOperands, isUninitialized)) + return; + + StridedMetadataRange sourceRange = + ranges[getSourceMutable().getOperandNumber()]; + if (sourceRange.isUninitialized()) + return; + + ArrayRef<ConstantIntRanges> srcStrides = sourceRange.getStrides(); + + // Get the dropped dims. + llvm::SmallBitVector droppedDims = getDroppedDims(); + + // Compute the new offset, strides and sizes. + ConstantIntRanges offset = sourceRange.getOffsets()[0]; + SmallVector<ConstantIntRanges> strides, sizes; + + for (size_t i = 0, e = droppedDims.size(); i < e; ++i) { + bool dropped = droppedDims.test(i); + // Compute the new offset. + ConstantIntRanges off = + intrange::inferMul({offsetOperands[i].getValue(), srcStrides[i]}); + offset = intrange::inferAdd({offset, off}); + + // Skip dropped dimensions. + if (dropped) + continue; + // Multiply the strides. + strides.push_back( + intrange::inferMul({stridesOperands[i].getValue(), srcStrides[i]})); + // Get the sizes. + sizes.push_back(sizeOperands[i].getValue()); + } + + setMetadata(getResult(), + StridedMetadataRange::getRanked( + SmallVector<ConstantIntRanges>({std::move(offset)}), + std::move(sizes), std::move(strides))); +} + //===----------------------------------------------------------------------===// // TransposeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 6564a4e..642ced9 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SmallSet.h" @@ -74,14 +75,16 @@ struct MemRefPointerLikeModel } mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc, - StringRef varName, Type varType, - Value originalVar) const { + StringRef varName, Type varType, Value originalVar, + bool &needsFree) const { auto memrefTy = cast<MemRefType>(pointer); // Check if this is a static memref (all dimensions are known) - if yes // then we can generate an alloca operation. - if (memrefTy.hasStaticShape()) + if (memrefTy.hasStaticShape()) { + needsFree = false; // alloca doesn't need deallocation return memref::AllocaOp::create(builder, loc, memrefTy).getResult(); + } // For dynamic memrefs, extract sizes from the original variable if // provided. Otherwise they cannot be handled. @@ -99,6 +102,7 @@ struct MemRefPointerLikeModel // Note: We only add dynamic sizes to the dynamicSizes array // Static dimensions are handled automatically by AllocOp } + needsFree = true; // alloc needs deallocation return memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes) .getResult(); } @@ -108,10 +112,14 @@ struct MemRefPointerLikeModel } bool genFree(Type pointer, OpBuilder &builder, Location loc, - TypedValue<PointerLikeType> varPtr, Type varType) const { - if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varPtr)) { + TypedValue<PointerLikeType> varToFree, Value allocRes, + Type varType) const { + if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varToFree)) { + // Use allocRes if provided to determine the allocation type + Value valueToInspect = allocRes ? allocRes : memrefValue; + // Walk through casts to find the original allocation - Value currentValue = memrefValue; + Value currentValue = valueToInspect; Operation *originalAlloc = nullptr; // Follow the chain of operations to find the original allocation @@ -150,7 +158,7 @@ struct MemRefPointerLikeModel return true; } if (isa<memref::AllocOp>(originalAlloc)) { - // This is an alloc - generate dealloc + // This is an alloc - generate dealloc on varToFree memref::DeallocOp::create(builder, loc, memrefValue); return true; } @@ -1003,6 +1011,142 @@ struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> { } }; +//===----------------------------------------------------------------------===// +// Recipe Region Helpers +//===----------------------------------------------------------------------===// + +/// Create and populate an init region for privatization recipes. +/// Returns the init block on success, or nullptr on failure. +/// Sets needsFree to indicate if the allocated memory requires deallocation. +static std::unique_ptr<Block> createInitRegion(OpBuilder &builder, Location loc, + Type varType, StringRef varName, + ValueRange bounds, + bool &needsFree) { + // Create init block with arguments: original value + bounds + SmallVector<Type> argTypes{varType}; + SmallVector<Location> argLocs{loc}; + for (Value bound : bounds) { + argTypes.push_back(bound.getType()); + argLocs.push_back(loc); + } + + auto initBlock = std::make_unique<Block>(); + initBlock->addArguments(argTypes, argLocs); + builder.setInsertionPointToStart(initBlock.get()); + + Value privatizedValue; + + // Get the block argument that represents the original variable + Value blockArgVar = initBlock->getArgument(0); + + // Generate init region body based on variable type + if (isa<MappableType>(varType)) { + auto mappableTy = cast<MappableType>(varType); + auto typedVar = cast<TypedValue<MappableType>>(blockArgVar); + privatizedValue = mappableTy.generatePrivateInit( + builder, loc, typedVar, varName, bounds, {}, needsFree); + if (!privatizedValue) + return nullptr; + } else { + assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType"); + auto pointerLikeTy = cast<PointerLikeType>(varType); + // Use PointerLikeType's allocation API with the block argument + privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType, + blockArgVar, needsFree); + if (!privatizedValue) + return nullptr; + } + + // Add yield operation to init block + acc::YieldOp::create(builder, loc, privatizedValue); + + return initBlock; +} + +/// Create and populate a copy region for firstprivate recipes. +/// Returns the copy block on success, or nullptr on failure. +/// TODO: Handle MappableType - it does not yet have a copy API. +static std::unique_ptr<Block> createCopyRegion(OpBuilder &builder, Location loc, + Type varType, + ValueRange bounds) { + // Create copy block with arguments: original value + privatized value + + // bounds + SmallVector<Type> copyArgTypes{varType, varType}; + SmallVector<Location> copyArgLocs{loc, loc}; + for (Value bound : bounds) { + copyArgTypes.push_back(bound.getType()); + copyArgLocs.push_back(loc); + } + + auto copyBlock = std::make_unique<Block>(); + copyBlock->addArguments(copyArgTypes, copyArgLocs); + builder.setInsertionPointToStart(copyBlock.get()); + + bool isMappable = isa<MappableType>(varType); + bool isPointerLike = isa<PointerLikeType>(varType); + // TODO: Handle MappableType - it does not yet have a copy API. + // Otherwise, for now just fallback to pointer-like behavior. + if (isMappable && !isPointerLike) + return nullptr; + + // Generate copy region body based on variable type + if (isPointerLike) { + auto pointerLikeTy = cast<PointerLikeType>(varType); + Value originalArg = copyBlock->getArgument(0); + Value privatizedArg = copyBlock->getArgument(1); + + // Generate copy operation using PointerLikeType interface + if (!pointerLikeTy.genCopy( + builder, loc, cast<TypedValue<PointerLikeType>>(privatizedArg), + cast<TypedValue<PointerLikeType>>(originalArg), varType)) + return nullptr; + } + + // Add terminator to copy block + acc::TerminatorOp::create(builder, loc); + + return copyBlock; +} + +/// Create and populate a destroy region for privatization recipes. +/// Returns the destroy block on success, or nullptr if not needed. +static std::unique_ptr<Block> createDestroyRegion(OpBuilder &builder, + Location loc, Type varType, + Value allocRes, + ValueRange bounds) { + // Create destroy block with arguments: original value + privatized value + + // bounds + SmallVector<Type> destroyArgTypes{varType, varType}; + SmallVector<Location> destroyArgLocs{loc, loc}; + for (Value bound : bounds) { + destroyArgTypes.push_back(bound.getType()); + destroyArgLocs.push_back(loc); + } + + auto destroyBlock = std::make_unique<Block>(); + destroyBlock->addArguments(destroyArgTypes, destroyArgLocs); + builder.setInsertionPointToStart(destroyBlock.get()); + + bool isMappable = isa<MappableType>(varType); + bool isPointerLike = isa<PointerLikeType>(varType); + // TODO: Handle MappableType - it does not yet have a deallocation API. + // Otherwise, for now just fallback to pointer-like behavior. + if (isMappable && !isPointerLike) + return nullptr; + + assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType"); + auto pointerLikeTy = cast<PointerLikeType>(varType); + auto privatizedArg = + cast<TypedValue<PointerLikeType>>(destroyBlock->getArgument(1)); + // Pass allocRes to help determine the allocation type + if (!pointerLikeTy.genFree(builder, loc, privatizedArg, allocRes, varType)) + return nullptr; + + acc::TerminatorOp::create(builder, loc); + + return destroyBlock; +} + } // namespace //===----------------------------------------------------------------------===// @@ -1050,6 +1194,55 @@ LogicalResult acc::PrivateRecipeOp::verifyRegions() { return success(); } +std::optional<PrivateRecipeOp> +PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc, + StringRef recipeName, Type varType, + StringRef varName, ValueRange bounds) { + // First, validate that we can handle this variable type + bool isMappable = isa<MappableType>(varType); + bool isPointerLike = isa<PointerLikeType>(varType); + + // Unsupported type + if (!isMappable && !isPointerLike) + return std::nullopt; + + // Create init and destroy blocks using shared helpers + OpBuilder::InsertionGuard guard(builder); + + // Save the original insertion point for creating the recipe operation later + auto originalInsertionPoint = builder.saveInsertionPoint(); + + bool needsFree = false; + auto initBlock = + createInitRegion(builder, loc, varType, varName, bounds, needsFree); + if (!initBlock) + return std::nullopt; + + // Only create destroy region if the allocation needs deallocation + std::unique_ptr<Block> destroyBlock; + if (needsFree) { + // Extract the allocated value from the init block's yield operation + auto yieldOp = cast<acc::YieldOp>(initBlock->getTerminator()); + Value allocRes = yieldOp.getOperand(0); + + destroyBlock = createDestroyRegion(builder, loc, varType, allocRes, bounds); + if (!destroyBlock) + return std::nullopt; + } + + // Now create the recipe operation at the original insertion point and attach + // the blocks + builder.restoreInsertionPoint(originalInsertionPoint); + auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType); + + // Move the blocks into the recipe's regions + recipe.getInitRegion().push_back(initBlock.release()); + if (destroyBlock) + recipe.getDestroyRegion().push_back(destroyBlock.release()); + + return recipe; +} + //===----------------------------------------------------------------------===// // FirstprivateRecipeOp //===----------------------------------------------------------------------===// @@ -1080,6 +1273,60 @@ LogicalResult acc::FirstprivateRecipeOp::verifyRegions() { return success(); } +std::optional<FirstprivateRecipeOp> +FirstprivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc, + StringRef recipeName, Type varType, + StringRef varName, ValueRange bounds) { + // First, validate that we can handle this variable type + bool isMappable = isa<MappableType>(varType); + bool isPointerLike = isa<PointerLikeType>(varType); + + // Unsupported type + if (!isMappable && !isPointerLike) + return std::nullopt; + + // Create init, copy, and destroy blocks using shared helpers + OpBuilder::InsertionGuard guard(builder); + + // Save the original insertion point for creating the recipe operation later + auto originalInsertionPoint = builder.saveInsertionPoint(); + + bool needsFree = false; + auto initBlock = + createInitRegion(builder, loc, varType, varName, bounds, needsFree); + if (!initBlock) + return std::nullopt; + + auto copyBlock = createCopyRegion(builder, loc, varType, bounds); + if (!copyBlock) + return std::nullopt; + + // Only create destroy region if the allocation needs deallocation + std::unique_ptr<Block> destroyBlock; + if (needsFree) { + // Extract the allocated value from the init block's yield operation + auto yieldOp = cast<acc::YieldOp>(initBlock->getTerminator()); + Value allocRes = yieldOp.getOperand(0); + + destroyBlock = createDestroyRegion(builder, loc, varType, allocRes, bounds); + if (!destroyBlock) + return std::nullopt; + } + + // Now create the recipe operation at the original insertion point and attach + // the blocks + builder.restoreInsertionPoint(originalInsertionPoint); + auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType); + + // Move the blocks into the recipe's regions + recipe.getInitRegion().push_back(initBlock.release()); + recipe.getCopyRegion().push_back(copyBlock.release()); + if (destroyBlock) + recipe.getDestroyRegion().push_back(destroyBlock.release()); + + return recipe; +} + //===----------------------------------------------------------------------===// // ReductionRecipeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 1fa04ed..89b81cf 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -121,6 +121,11 @@ namespace mlir { class MLIRContextImpl { public: //===--------------------------------------------------------------------===// + // Remark + //===--------------------------------------------------------------------===// + std::unique_ptr<remark::detail::RemarkEngine> remarkEngine; + + //===--------------------------------------------------------------------===// // Debugging //===--------------------------------------------------------------------===// @@ -135,11 +140,6 @@ public: DiagnosticEngine diagEngine; //===--------------------------------------------------------------------===// - // Remark - //===--------------------------------------------------------------------===// - std::unique_ptr<remark::detail::RemarkEngine> remarkEngine; - - //===--------------------------------------------------------------------===// // Options //===--------------------------------------------------------------------===// @@ -357,7 +357,10 @@ MLIRContext::MLIRContext(const DialectRegistry ®istry, Threading setting) impl->affineUniquer.registerParametricStorageType<IntegerSetStorage>(); } -MLIRContext::~MLIRContext() = default; +MLIRContext::~MLIRContext() { + // finalize remark engine before destroying anything else. + impl->remarkEngine.reset(); +} /// Copy the specified array of elements into memory managed by the provided /// bump pointer allocator. This assumes the elements are all PODs. diff --git a/mlir/lib/IR/Remarks.cpp b/mlir/lib/IR/Remarks.cpp index a55f61a..031eae2 100644 --- a/mlir/lib/IR/Remarks.cpp +++ b/mlir/lib/IR/Remarks.cpp @@ -16,7 +16,7 @@ #include "llvm/ADT/StringRef.h" using namespace mlir::remark::detail; - +using namespace mlir::remark; //------------------------------------------------------------------------------ // Remark //------------------------------------------------------------------------------ @@ -70,7 +70,7 @@ static void printArgs(llvm::raw_ostream &os, llvm::ArrayRef<Remark::Arg> args) { void Remark::print(llvm::raw_ostream &os, bool printLocation) const { // Header: [Type] pass:remarkName StringRef type = getRemarkTypeString(); - StringRef categoryName = getFullCategoryName(); + StringRef categoryName = getCombinedCategoryName(); StringRef name = remarkName; os << '[' << type << "] "; @@ -81,9 +81,10 @@ void Remark::print(llvm::raw_ostream &os, bool printLocation) const { os << "Function=" << getFunction() << " | "; if (printLocation) { - if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(getLocation())) + if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(getLocation())) { os << " @" << flc.getFilename() << ":" << flc.getLine() << ":" << flc.getColumn(); + } } printArgs(os, getArgs()); @@ -140,7 +141,7 @@ llvm::remarks::Remark Remark::generateRemark() const { r.RemarkType = getRemarkType(); r.RemarkName = getRemarkName(); // MLIR does not use passes; instead, it has categories and sub-categories. - r.PassName = getFullCategoryName(); + r.PassName = getCombinedCategoryName(); r.FunctionName = getFunction(); r.Loc = locLambda(); for (const Remark::Arg &arg : getArgs()) { @@ -225,26 +226,42 @@ InFlightRemark RemarkEngine::emitOptimizationRemarkAnalysis(Location loc, // RemarkEngine //===----------------------------------------------------------------------===// -void RemarkEngine::report(const Remark &&remark) { +void RemarkEngine::reportImpl(const Remark &remark) { // Stream the remark - if (remarkStreamer) + if (remarkStreamer) { remarkStreamer->streamOptimizationRemark(remark); + } // Print using MLIR's diagnostic if (printAsEmitRemarks) emitRemark(remark.getLocation(), remark.getMsg()); } +void RemarkEngine::report(const Remark &&remark) { + if (remarkEmittingPolicy) + remarkEmittingPolicy->reportRemark(remark); +} + RemarkEngine::~RemarkEngine() { + if (remarkEmittingPolicy) + remarkEmittingPolicy->finalize(); + if (remarkStreamer) remarkStreamer->finalize(); } -llvm::LogicalResult -RemarkEngine::initialize(std::unique_ptr<MLIRRemarkStreamerBase> streamer, - std::string *errMsg) { - // If you need to validate categories/filters, do so here and set errMsg. +llvm::LogicalResult RemarkEngine::initialize( + std::unique_ptr<MLIRRemarkStreamerBase> streamer, + std::unique_ptr<RemarkEmittingPolicyBase> remarkEmittingPolicy, + std::string *errMsg) { + remarkStreamer = std::move(streamer); + + auto reportFunc = + std::bind(&RemarkEngine::reportImpl, this, std::placeholders::_1); + remarkEmittingPolicy->initialize(ReportFn(std::move(reportFunc))); + + this->remarkEmittingPolicy = std::move(remarkEmittingPolicy); return success(); } @@ -301,14 +318,15 @@ RemarkEngine::RemarkEngine(bool printAsEmitRemarks, } llvm::LogicalResult mlir::remark::enableOptimizationRemarks( - MLIRContext &ctx, - std::unique_ptr<remark::detail::MLIRRemarkStreamerBase> streamer, - const remark::RemarkCategories &cats, bool printAsEmitRemarks) { + MLIRContext &ctx, std::unique_ptr<detail::MLIRRemarkStreamerBase> streamer, + std::unique_ptr<detail::RemarkEmittingPolicyBase> remarkEmittingPolicy, + const RemarkCategories &cats, bool printAsEmitRemarks) { auto engine = - std::make_unique<remark::detail::RemarkEngine>(printAsEmitRemarks, cats); + std::make_unique<detail::RemarkEngine>(printAsEmitRemarks, cats); std::string errMsg; - if (failed(engine->initialize(std::move(streamer), &errMsg))) { + if (failed(engine->initialize(std::move(streamer), + std::move(remarkEmittingPolicy), &errMsg))) { llvm::report_fatal_error( llvm::Twine("Failed to initialize remark engine. Error: ") + errMsg); } @@ -316,3 +334,12 @@ llvm::LogicalResult mlir::remark::enableOptimizationRemarks( return success(); } + +//===----------------------------------------------------------------------===// +// Remark emitting policies +//===----------------------------------------------------------------------===// + +namespace mlir::remark { +RemarkEmittingPolicyAll::RemarkEmittingPolicyAll() = default; +RemarkEmittingPolicyFinal::RemarkEmittingPolicyFinal() = default; +} // namespace mlir::remark diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt index 388de1c..f96af02 100644 --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -9,6 +9,7 @@ set(LLVM_OPTIONAL_SOURCES FunctionInterfaces.cpp IndexingMapOpInterface.cpp InferIntRangeInterface.cpp + InferStridedMetadataInterface.cpp InferTypeOpInterface.cpp LoopLikeInterface.cpp MemOpInterfaces.cpp @@ -64,6 +65,21 @@ add_mlir_library(MLIRFunctionInterfaces add_mlir_interface_library(IndexingMapOpInterface) add_mlir_interface_library(InferIntRangeInterface) + +add_mlir_library(MLIRInferStridedMetadataInterface + InferStridedMetadataInterface.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces + + DEPENDS + MLIRInferStridedMetadataInterfaceIncGen + + LINK_LIBS PUBLIC + MLIRInferIntRangeInterface + MLIRIR +) + add_mlir_interface_library(InferTypeOpInterface) add_mlir_library(MLIRLoopLikeInterface diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp index 9f3e97d..84fc9b8 100644 --- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp +++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp @@ -146,6 +146,25 @@ raw_ostream &mlir::operator<<(raw_ostream &os, const IntegerValueRange &range) { return os; } +SmallVector<IntegerValueRange> +mlir::getIntValueRanges(ArrayRef<OpFoldResult> values, + GetIntRangeFn getIntRange, int32_t indexBitwidth) { + SmallVector<IntegerValueRange> ranges; + ranges.reserve(values.size()); + for (OpFoldResult ofr : values) { + if (auto value = dyn_cast<Value>(ofr)) { + ranges.push_back(getIntRange(value)); + continue; + } + + // Create a constant range. + auto attr = cast<IntegerAttr>(cast<Attribute>(ofr)); + ranges.emplace_back(ConstantIntRanges::constant( + attr.getValue().sextOrTrunc(indexBitwidth))); + } + return ranges; +} + void mlir::intrange::detail::defaultInferResultRanges( InferIntRangeInterface interface, ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRanges) { diff --git a/mlir/lib/Interfaces/InferStridedMetadataInterface.cpp b/mlir/lib/Interfaces/InferStridedMetadataInterface.cpp new file mode 100644 index 0000000..483e9f1 --- /dev/null +++ b/mlir/lib/Interfaces/InferStridedMetadataInterface.cpp @@ -0,0 +1,36 @@ +//===- InferStridedMetadataInterface.cpp - Strided md inference interface -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/InferStridedMetadataInterface.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include <optional> + +using namespace mlir; + +#include "mlir/Interfaces/InferStridedMetadataInterface.cpp.inc" + +void StridedMetadataRange::print(raw_ostream &os) const { + if (isUninitialized()) { + os << "strided_metadata<None>"; + return; + } + os << "strided_metadata<offset = ["; + llvm::interleaveComma(*offsets, os, [&](const ConstantIntRanges &range) { + os << "{" << range << "}"; + }); + os << "], sizes = ["; + llvm::interleaveComma(sizes, os, [&](const ConstantIntRanges &range) { + os << "{" << range << "}"; + }); + os << "], strides = ["; + llvm::interleaveComma(strides, os, [&](const ConstantIntRanges &range) { + os << "{" << range << "}"; + }); + os << "]>"; +} diff --git a/mlir/lib/Remark/RemarkStreamer.cpp b/mlir/lib/Remark/RemarkStreamer.cpp index d213a1a..bf36286 100644 --- a/mlir/lib/Remark/RemarkStreamer.cpp +++ b/mlir/lib/Remark/RemarkStreamer.cpp @@ -60,6 +60,7 @@ void LLVMRemarkStreamer::finalize() { namespace mlir::remark { LogicalResult enableOptimizationRemarksWithLLVMStreamer( MLIRContext &ctx, StringRef path, llvm::remarks::Format fmt, + std::unique_ptr<detail::RemarkEmittingPolicyBase> remarkEmittingPolicy, const RemarkCategories &cat, bool printAsEmitRemarks) { FailureOr<std::unique_ptr<detail::MLIRRemarkStreamerBase>> sOr = @@ -67,7 +68,8 @@ LogicalResult enableOptimizationRemarksWithLLVMStreamer( if (failed(sOr)) return failure(); - return remark::enableOptimizationRemarks(ctx, std::move(*sOr), cat, + return remark::enableOptimizationRemarks(ctx, std::move(*sOr), + std::move(remarkEmittingPolicy), cat, printAsEmitRemarks); } diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp index 132be4e..51c6077 100644 --- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp +++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp @@ -956,7 +956,7 @@ inline parsed_inst_t ExpressionParser::buildNumericOp( << ", type = " << ty << " ***"; auto tysToPop = SmallVector<Type, numOperands>(); tysToPop.resize(numOperands); - std::fill(tysToPop.begin(), tysToPop.end(), ty); + llvm::fill(tysToPop, ty); auto operands = popOperands(tysToPop); if (failed(operands)) return failure(); diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp index 30fd384..9ef405d 100644 --- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp @@ -37,6 +37,7 @@ #include "llvm/ADT/StringRef.h" #include "llvm/Remarks/RemarkFormat.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/LogicalResult.h" #include "llvm/Support/ManagedStatic.h" @@ -226,6 +227,18 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig { "bitstream", "Print bitstream file")), llvm::cl::cat(remarkCategory)}; + static llvm::cl::opt<RemarkPolicy, /*ExternalStorage=*/true> remarkPolicy{ + "remark-policy", + llvm::cl::desc("Specify the policy for remark output."), + cl::location(remarkPolicyFlag), + llvm::cl::value_desc("format"), + llvm::cl::init(RemarkPolicy::REMARK_POLICY_ALL), + llvm::cl::values(clEnumValN(RemarkPolicy::REMARK_POLICY_ALL, "all", + "Print all remarks"), + clEnumValN(RemarkPolicy::REMARK_POLICY_FINAL, "final", + "Print final remarks")), + llvm::cl::cat(remarkCategory)}; + static cl::opt<std::string, /*ExternalStorage=*/true> remarksAll( "remarks-filter", cl::desc("Show all remarks: passed, missed, failed, analysis"), @@ -517,18 +530,28 @@ performActions(raw_ostream &os, return failure(); context->enableMultithreading(wasThreadingEnabled); - + // Set the remark categories and policy. remark::RemarkCategories cats{ config.getRemarksAllFilter(), config.getRemarksPassedFilter(), config.getRemarksMissedFilter(), config.getRemarksAnalyseFilter(), config.getRemarksFailedFilter()}; mlir::MLIRContext &ctx = *context; + // Helper to create the appropriate policy based on configuration + auto createPolicy = [&config]() + -> std::unique_ptr<mlir::remark::detail::RemarkEmittingPolicyBase> { + if (config.getRemarkPolicy() == RemarkPolicy::REMARK_POLICY_ALL) + return std::make_unique<mlir::remark::RemarkEmittingPolicyAll>(); + if (config.getRemarkPolicy() == RemarkPolicy::REMARK_POLICY_FINAL) + return std::make_unique<mlir::remark::RemarkEmittingPolicyFinal>(); + + llvm_unreachable("Invalid remark policy"); + }; switch (config.getRemarkFormat()) { case RemarkFormat::REMARK_FORMAT_STDOUT: if (failed(mlir::remark::enableOptimizationRemarks( - ctx, nullptr, cats, true /*printAsEmitRemarks*/))) + ctx, nullptr, createPolicy(), cats, true /*printAsEmitRemarks*/))) return failure(); break; @@ -537,7 +560,7 @@ performActions(raw_ostream &os, ? "mlir-remarks.yaml" : config.getRemarksOutputFile(); if (failed(mlir::remark::enableOptimizationRemarksWithLLVMStreamer( - ctx, file, llvm::remarks::Format::YAML, cats))) + ctx, file, llvm::remarks::Format::YAML, createPolicy(), cats))) return failure(); break; } @@ -547,7 +570,7 @@ performActions(raw_ostream &os, ? "mlir-remarks.bitstream" : config.getRemarksOutputFile(); if (failed(mlir::remark::enableOptimizationRemarksWithLLVMStreamer( - ctx, file, llvm::remarks::Format::Bitstream, cats))) + ctx, file, llvm::remarks::Format::Bitstream, createPolicy(), cats))) return failure(); break; } @@ -593,6 +616,12 @@ performActions(raw_ostream &os, AsmState asmState(op.get(), OpPrintingFlags(), /*locationMap=*/nullptr, &fallbackResourceMap); os << OpWithState(op.get(), asmState) << '\n'; + + // This is required if the remark policy is final. Otherwise, the remarks are + // not emitted. + if (remark::detail::RemarkEngine *engine = ctx.getRemarkEngine()) + engine->getRemarkEmittingPolicy()->finalize(); + return success(); } diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp index 111f58e..5f3b04a 100644 --- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp @@ -66,7 +66,9 @@ size_t mlir::moveLoopInvariantCode( size_t numMoved = 0; for (Region *region : regions) { - LDBG() << "Original loop:\n" << *region->getParentOp(); + LDBG() << "Original loop:\n" + << OpWithFlags(region->getParentOp(), + OpPrintingFlags().skipRegions()); std::queue<Operation *> worklist; // Add top-level operations in the loop body to the worklist. @@ -90,7 +92,8 @@ size_t mlir::moveLoopInvariantCode( !canBeHoisted(op, definedOutside)) continue; - LDBG() << "Moving loop-invariant op: " << *op; + LDBG() << "Moving loop-invariant op: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); moveOutOfRegion(op, region); ++numMoved; @@ -111,9 +114,7 @@ size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) { [&](Value value, Region *) { return loopLike.isDefinedOutsideOfLoop(value); }, - [&](Operation *op, Region *) { - return isMemoryEffectFree(op) && isSpeculatable(op); - }, + [&](Operation *op, Region *) { return isPure(op); }, [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); }); } |