Diffstat (limited to 'mlir/lib')
106 files changed, 3626 insertions, 916 deletions
diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp index 8062b474..a84d10d 100644 --- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp +++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp @@ -258,6 +258,39 @@ getAllocEffectFor(Value value, return success(); } +static Operation *isDistinctObjectsOp(Operation *op) { + if (op && op->hasTrait<OpTrait::DistinctObjectsTrait>()) + return op; + + return nullptr; +} + +static Value getDistinctObjectsOperand(Operation *op, Value value) { + unsigned argNumber = cast<OpResult>(value).getResultNumber(); + return op->getOperand(argNumber); +} + +static std::optional<AliasResult> checkDistinctObjects(Value lhs, Value rhs) { + // We should already checked that lhs and rhs are different. + assert(lhs != rhs && "lhs and rhs must be different"); + + // Result and corresponding operand must alias. + auto lhsOp = isDistinctObjectsOp(lhs.getDefiningOp()); + if (lhsOp && getDistinctObjectsOperand(lhsOp, lhs) == rhs) + return AliasResult::MustAlias; + + auto rhsOp = isDistinctObjectsOp(rhs.getDefiningOp()); + if (rhsOp && getDistinctObjectsOperand(rhsOp, rhs) == lhs) + return AliasResult::MustAlias; + + // If two different values come from the same `DistinctObjects` operation, + // they don't alias. + if (lhsOp && lhsOp == rhsOp) + return AliasResult::NoAlias; + + return std::nullopt; +} + /// Given the two values, return their aliasing behavior. AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) { if (lhs == rhs) @@ -289,6 +322,9 @@ AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) { : AliasResult::MayAlias; } + if (std::optional<AliasResult> result = checkDistinctObjects(lhs, rhs)) + return *result; + // Otherwise, neither of the values are constant so check to see if either has // an allocation effect. bool lhsHasAlloc = succeeded(getAllocEffectFor(lhs, lhsAlloc, lhsAllocScope)); diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt index 609cb34..db10ebc 100644 --- a/mlir/lib/Analysis/CMakeLists.txt +++ b/mlir/lib/Analysis/CMakeLists.txt @@ -40,6 +40,7 @@ add_mlir_library(MLIRAnalysis DataFlow/IntegerRangeAnalysis.cpp DataFlow/LivenessAnalysis.cpp DataFlow/SparseAnalysis.cpp + DataFlow/StridedMetadataRangeAnalysis.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Analysis @@ -53,6 +54,7 @@ add_mlir_library(MLIRAnalysis MLIRDataLayoutInterfaces MLIRFunctionInterfaces MLIRInferIntRangeInterface + MLIRInferStridedMetadataInterface MLIRInferTypeOpInterface MLIRLoopLikeInterface MLIRPresburger diff --git a/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp new file mode 100644 index 0000000..01c9daf --- /dev/null +++ b/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp @@ -0,0 +1,127 @@ +//===- StridedMetadataRangeAnalysis.cpp - Integer range analysis --------*- C++ +//-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the dataflow analysis class for integer range inference +// which is used in transformations over the `arith` dialect such as +// branch elimination or signed->unsigned rewriting +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h" +#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/DebugStringHelper.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" + +#define DEBUG_TYPE "strided-metadata-range-analysis" + +using namespace mlir; +using namespace mlir::dataflow; + +/// Get the entry state for a value. For any value that is not a ranked memref, +/// this function sets the metadata to a top state with no offsets, sizes, or +/// strides. For `memref` types, this function will use the metadata in the type +/// to try to deduce as much informaiton as possible. +static StridedMetadataRange getEntryStateImpl(Value v, int32_t indexBitwidth) { + // TODO: generalize this method with a type interface. + auto mTy = dyn_cast<BaseMemRefType>(v.getType()); + + // If not a memref or it's un-ranked, don't infer any metadata. + if (!mTy || !mTy.hasRank()) + return StridedMetadataRange::getMaxRanges(indexBitwidth, 0, 0, 0); + + // Get the top state. + auto metadata = + StridedMetadataRange::getMaxRanges(indexBitwidth, mTy.getRank()); + + // Compute the offset and strides. + int64_t offset; + SmallVector<int64_t> strides; + if (failed(cast<MemRefType>(mTy).getStridesAndOffset(strides, offset))) + return metadata; + + // Refine the metadata if we know it from the type. + if (!ShapedType::isDynamic(offset)) { + metadata.getOffsets()[0] = + ConstantIntRanges::constant(APInt(indexBitwidth, offset)); + } + for (auto &&[size, range] : + llvm::zip_equal(mTy.getShape(), metadata.getSizes())) { + if (ShapedType::isDynamic(size)) + continue; + range = ConstantIntRanges::constant(APInt(indexBitwidth, size)); + } + for (auto &&[stride, range] : + llvm::zip_equal(strides, metadata.getStrides())) { + if (ShapedType::isDynamic(stride)) + continue; + range = ConstantIntRanges::constant(APInt(indexBitwidth, stride)); + } + + return metadata; +} + +StridedMetadataRangeAnalysis::StridedMetadataRangeAnalysis( + DataFlowSolver &solver, int32_t indexBitwidth) + : SparseForwardDataFlowAnalysis(solver), indexBitwidth(indexBitwidth) { + assert(indexBitwidth > 0 && "invalid bitwidth"); +} + +void StridedMetadataRangeAnalysis::setToEntryState( + StridedMetadataRangeLattice *lattice) { + propagateIfChanged(lattice, lattice->join(getEntryStateImpl( + lattice->getAnchor(), indexBitwidth))); +} + +LogicalResult StridedMetadataRangeAnalysis::visitOperation( + Operation *op, ArrayRef<const StridedMetadataRangeLattice *> operands, + ArrayRef<StridedMetadataRangeLattice *> results) { + auto inferrable = dyn_cast<InferStridedMetadataOpInterface>(op); + + // Bail if we cannot reason about the op. + if (!inferrable) { + setAllToEntryStates(results); + return success(); + } + + LDBG() << "Inferring metadata for: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); + + // Helper function to retrieve int range values. 
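  // (The ranges come from IntegerValueRangeLattice state, so this analysis is
  // presumably meant to run in the same DataFlowSolver as IntegerRangeAnalysis;
  // if no lattice exists for a value yet, an uninitialized IntegerValueRange is
  // used instead.)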
+ auto getIntRange = [&](Value value) -> IntegerValueRange { + auto lattice = getOrCreateFor<IntegerValueRangeLattice>( + getProgramPointAfter(op), value); + return lattice ? lattice->getValue() : IntegerValueRange(); + }; + + // Convert the arguments lattices to a vector. + SmallVector<StridedMetadataRange> argRanges = llvm::map_to_vector( + operands, [](const StridedMetadataRangeLattice *lattice) { + return lattice->getValue(); + }); + + // Callback to set metadata on a result. + auto joinCallback = [&](Value v, const StridedMetadataRange &md) { + auto result = cast<OpResult>(v); + assert(llvm::is_contained(op->getResults(), result)); + LDBG() << "- Inferred metadata: " << md; + StridedMetadataRangeLattice *lattice = results[result.getResultNumber()]; + ChangeResult changed = lattice->join(md); + LDBG() << "- Joined metadata: " << lattice->getValue(); + propagateIfChanged(lattice, changed); + }; + + // Infer the metadata. + inferrable.inferStridedMetadataRanges(argRanges, getIntRange, joinCallback, + indexBitwidth); + return success(); +} diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp index 30ce1fb..6588b53 100644 --- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp +++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp @@ -1244,8 +1244,9 @@ bool FlatLinearValueConstraints::areVarsAlignedWithOther( /// Checks if the SSA values associated with `cst`'s variables in range /// [start, end) are unique. -static bool LLVM_ATTRIBUTE_UNUSED areVarsUnique( - const FlatLinearValueConstraints &cst, unsigned start, unsigned end) { +[[maybe_unused]] static bool +areVarsUnique(const FlatLinearValueConstraints &cst, unsigned start, + unsigned end) { assert(start <= cst.getNumDimAndSymbolVars() && "Start position out of bounds"); @@ -1267,14 +1268,14 @@ static bool LLVM_ATTRIBUTE_UNUSED areVarsUnique( } /// Checks if the SSA values associated with `cst`'s variables are unique. -static bool LLVM_ATTRIBUTE_UNUSED +[[maybe_unused]] static bool areVarsUnique(const FlatLinearValueConstraints &cst) { return areVarsUnique(cst, 0, cst.getNumDimAndSymbolVars()); } /// Checks if the SSA values associated with `cst`'s variables of kind `kind` /// are unique. 
-static bool LLVM_ATTRIBUTE_UNUSED +[[maybe_unused]] static bool areVarsUnique(const FlatLinearValueConstraints &cst, VarKind kind) { if (kind == VarKind::SetDim) diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp index a1cbe29..547a4c2 100644 --- a/mlir/lib/Analysis/Presburger/Simplex.cpp +++ b/mlir/lib/Analysis/Presburger/Simplex.cpp @@ -34,7 +34,7 @@ using Direction = Simplex::Direction; const int nullIndex = std::numeric_limits<int>::max(); // Return a + scale*b; -LLVM_ATTRIBUTE_UNUSED +[[maybe_unused]] static SmallVector<DynamicAPInt, 8> scaleAndAddForAssert(ArrayRef<DynamicAPInt> a, const DynamicAPInt &scale, ArrayRef<DynamicAPInt> b) { diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 7b17106..06d0256 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2730,6 +2730,17 @@ public: operation->get(), toMlirStringRef(name))); } + static void + forEachAttr(MlirOperation op, + llvm::function_ref<void(MlirStringRef, MlirAttribute)> fn) { + intptr_t n = mlirOperationGetNumAttributes(op); + for (intptr_t i = 0; i < n; ++i) { + MlirNamedAttribute na = mlirOperationGetAttribute(op, i); + MlirStringRef name = mlirIdentifierStr(na.name); + fn(name, na.attribute); + } + } + static void bind(nb::module_ &m) { nb::class_<PyOpAttributeMap>(m, "OpAttributeMap") .def("__contains__", &PyOpAttributeMap::dunderContains) @@ -2737,7 +2748,50 @@ public: .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) .def("__setitem__", &PyOpAttributeMap::dunderSetItem) - .def("__delitem__", &PyOpAttributeMap::dunderDelItem); + .def("__delitem__", &PyOpAttributeMap::dunderDelItem) + .def("__iter__", + [](PyOpAttributeMap &self) { + nb::list keys; + PyOpAttributeMap::forEachAttr( + self.operation->get(), + [&](MlirStringRef name, MlirAttribute) { + keys.append(nb::str(name.data, name.length)); + }); + return nb::iter(keys); + }) + .def("keys", + [](PyOpAttributeMap &self) { + nb::list out; + PyOpAttributeMap::forEachAttr( + self.operation->get(), + [&](MlirStringRef name, MlirAttribute) { + out.append(nb::str(name.data, name.length)); + }); + return out; + }) + .def("values", + [](PyOpAttributeMap &self) { + nb::list out; + PyOpAttributeMap::forEachAttr( + self.operation->get(), + [&](MlirStringRef, MlirAttribute attr) { + out.append(PyAttribute(self.operation->getContext(), attr) + .maybeDownCast()); + }); + return out; + }) + .def("items", [](PyOpAttributeMap &self) { + nb::list out; + PyOpAttributeMap::forEachAttr( + self.operation->get(), + [&](MlirStringRef name, MlirAttribute attr) { + out.append(nb::make_tuple( + nb::str(name.data, name.length), + PyAttribute(self.operation->getContext(), attr) + .maybeDownCast())); + }); + return out; + }); } private: diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index d506b7f..0f0ed22 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -197,10 +197,15 @@ public: MlirPatternRewriter rewriter, void *userData) -> MlirLogicalResult { nb::handle f(static_cast<PyObject *>(userData)); - nb::object res = f(op, PyPatternRewriter(rewriter)); + + PyMlirContextRef ctx = + PyMlirContext::forContext(mlirOperationGetContext(op)); + nb::object opView = PyOperation::forOperation(ctx, op)->createOpView(); + + nb::object res = f(opView, PyPatternRewriter(rewriter)); return logicalResultFromObject(res); }; - 
MlirRewritePattern pattern = mlirOpRewritePattenCreate( + MlirRewritePattern pattern = mlirOpRewritePatternCreate( rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(), /* nGeneratedNames */ 0, /* generatedNames */ nullptr); @@ -261,7 +266,6 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { //---------------------------------------------------------------------------- // Mapping of the RewritePatternSet //---------------------------------------------------------------------------- - nb::class_<MlirRewritePattern>(m, "RewritePattern"); nb::class_<PyRewritePatternSet>(m, "RewritePatternSet") .def( "__init__", diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index 70dee59..41ceb15 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -270,17 +270,6 @@ void mlirIRRewriterDestroy(MlirRewriterBase rewriter) { /// RewritePatternSet and FrozenRewritePatternSet API //===----------------------------------------------------------------------===// -static inline mlir::FrozenRewritePatternSet * -unwrap(MlirFrozenRewritePatternSet module) { - assert(module.ptr && "unexpected null module"); - return static_cast<mlir::FrozenRewritePatternSet *>(module.ptr); -} - -static inline MlirFrozenRewritePatternSet -wrap(mlir::FrozenRewritePatternSet *module) { - return {module}; -} - MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet set) { auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(set))); @@ -311,15 +300,6 @@ mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op, /// PatternRewriter API //===----------------------------------------------------------------------===// -inline mlir::PatternRewriter *unwrap(MlirPatternRewriter rewriter) { - assert(rewriter.ptr && "unexpected null rewriter"); - return static_cast<mlir::PatternRewriter *>(rewriter.ptr); -} - -inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) { - return {rewriter}; -} - MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) { return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter))); } @@ -361,7 +341,7 @@ private: } // namespace mlir -MlirRewritePattern mlirOpRewritePattenCreate( +MlirRewritePattern mlirOpRewritePatternCreate( MlirStringRef rootName, unsigned benefit, MlirContext context, MlirRewritePatternCallbacks callbacks, void *userData, size_t nGeneratedNames, MlirStringRef *generatedNames) { @@ -400,15 +380,6 @@ void mlirRewritePatternSetAdd(MlirRewritePatternSet set, //===----------------------------------------------------------------------===// #if MLIR_ENABLE_PDL_IN_PATTERNMATCH -static inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) { - assert(module.ptr && "unexpected null module"); - return static_cast<mlir::PDLPatternModule *>(module.ptr); -} - -static inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) { - return {module}; -} - MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) { return wrap(new mlir::PDLPatternModule( mlir::OwningOpRef<mlir::ModuleOp>(unwrap(op)))); @@ -426,22 +397,6 @@ mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) { return wrap(m); } -inline const mlir::PDLValue *unwrap(MlirPDLValue value) { - assert(value.ptr && "unexpected null PDL value"); - return static_cast<const mlir::PDLValue *>(value.ptr); -} - -inline MlirPDLValue wrap(const mlir::PDLValue *value) { return {value}; } - -inline mlir::PDLResultList *unwrap(MlirPDLResultList results) { - assert(results.ptr && 
"unexpected null PDL results"); - return static_cast<mlir::PDLResultList *>(results.ptr); -} - -inline MlirPDLResultList wrap(mlir::PDLResultList *results) { - return {results}; -} - MlirValue mlirPDLValueAsValue(MlirPDLValue value) { return wrap(unwrap(value)->dyn_cast<mlir::Value>()); } diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 71986f8..bebf1b8 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -40,6 +40,7 @@ add_subdirectory(MathToLibm) add_subdirectory(MathToLLVM) add_subdirectory(MathToROCDL) add_subdirectory(MathToSPIRV) +add_subdirectory(MathToXeVM) add_subdirectory(MemRefToEmitC) add_subdirectory(MemRefToLLVM) add_subdirectory(MemRefToSPIRV) diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp index b215211..c03f3a5 100644 --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -484,5 +484,5 @@ void mlir::populateGpuToROCDLConversionPatterns( GPUSubgroupBroadcastOpToROCDL>(converter); patterns.add<GPUSubgroupSizeOpToROCDL>(converter, chipset); - populateMathToROCDLConversionPatterns(converter, patterns); + populateMathToROCDLConversionPatterns(converter, patterns, chipset); } diff --git a/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt b/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt index 2771955a..8cc3fde 100644 --- a/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt +++ b/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt @@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRMathToROCDL Core LINK_LIBS PUBLIC + MLIRAMDGPUUtils MLIRDialectUtils MLIRFuncDialect MLIRGPUToGPURuntimeTransforms diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp index df219f3..a2dfc12 100644 --- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp +++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp @@ -10,6 +10,8 @@ #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" @@ -19,6 +21,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/DebugLog.h" #include "../GPUCommon/GPUOpsLowering.h" #include "../GPUCommon/OpToFuncCallLowering.h" @@ -42,8 +45,46 @@ static void populateOpPatterns(const LLVMTypeConverter &converter, f32ApproxFunc, f16Func); } +struct ClampFOpConversion final + : public ConvertOpToLLVMPattern<math::ClampFOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Only f16 and f32 types are supported by fmed3 + Type opTy = op.getType(); + Type resultType = getTypeConverter()->convertType(opTy); + + if (auto vectorType = dyn_cast<VectorType>(opTy)) + opTy = vectorType.getElementType(); + + if (!isa<Float16Type, Float32Type>(opTy)) + return rewriter.notifyMatchFailure( + op, "fmed3 only supports f16 and f32 types"); + + // Handle multi-dimensional vectors (converted to LLVM arrays) + if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType)) + 
return LLVM::detail::handleMultidimensionalVectors( + op.getOperation(), adaptor.getOperands(), *getTypeConverter(), + [&](Type llvm1DVectorTy, ValueRange operands) -> Value { + typename math::ClampFOp::Adaptor adaptor(operands); + return ROCDL::FMed3Op::create(rewriter, op.getLoc(), llvm1DVectorTy, + adaptor.getValue(), adaptor.getMin(), + adaptor.getMax()); + }, + rewriter); + + // Handle 1D vectors and scalars directly + rewriter.replaceOpWithNewOp<ROCDL::FMed3Op>(op, op.getType(), op.getValue(), + op.getMin(), op.getMax()); + return success(); + } +}; + void mlir::populateMathToROCDLConversionPatterns( - const LLVMTypeConverter &converter, RewritePatternSet &patterns) { + const LLVMTypeConverter &converter, RewritePatternSet &patterns, + std::optional<amdgpu::Chipset> chipset) { // Handled by mathToLLVM: math::AbsIOp // Handled by mathToLLVM: math::AbsFOp // Handled by mathToLLVM: math::CopySignOp @@ -118,15 +159,21 @@ void mlir::populateMathToROCDLConversionPatterns( // worth creating a separate pass for it. populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32", "__ocml_fmod_f64", "__ocml_fmod_f16"); + + if (chipset.has_value() && chipset->majorVersion >= 9) { + patterns.add<ClampFOpConversion>(converter); + } else { + LDBG() << "Chipset dependent patterns were not added"; + } } -namespace { -struct ConvertMathToROCDLPass - : public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> { - ConvertMathToROCDLPass() = default; +struct ConvertMathToROCDLPass final + : impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> { + using impl::ConvertMathToROCDLBase< + ConvertMathToROCDLPass>::ConvertMathToROCDLBase; + void runOnOperation() override; }; -} // namespace void ConvertMathToROCDLPass::runOnOperation() { auto m = getOperation(); @@ -135,10 +182,21 @@ void ConvertMathToROCDLPass::runOnOperation() { RewritePatternSet patterns(&getContext()); LowerToLLVMOptions options(ctx, DataLayout(m)); LLVMTypeConverter converter(ctx, options); - populateMathToROCDLConversionPatterns(converter, patterns); + + FailureOr<amdgpu::Chipset> maybeChipset; + if (!chipset.empty()) { + maybeChipset = amdgpu::Chipset::parse(chipset); + if (failed(maybeChipset)) + return signalPassFailure(); + } + populateMathToROCDLConversionPatterns( + converter, patterns, + succeeded(maybeChipset) ? 
std::optional(*maybeChipset) : std::nullopt); + ConversionTarget target(getContext()); - target.addLegalDialect<BuiltinDialect, func::FuncDialect, - vector::VectorDialect, LLVM::LLVMDialect>(); + target + .addLegalDialect<BuiltinDialect, func::FuncDialect, vector::VectorDialect, + LLVM::LLVMDialect, ROCDL::ROCDLDialect>(); target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp, diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp index f0d8b78..610ce1f 100644 --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -407,11 +407,11 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> { if (auto vectorType = dyn_cast<VectorType>(operandType)) nanAttr = DenseElementsAttr::get(vectorType, nan); - Value NanValue = + Value nanValue = spirv::ConstantOp::create(rewriter, loc, operandType, nanAttr); Value lhs = spirv::SelectOp::create(rewriter, loc, cmpNegativeWithFractionalExp, - NanValue, adaptor.getLhs()); + nanValue, adaptor.getLhs()); Value abs = spirv::GLFAbsOp::create(rewriter, loc, lhs); // TODO: The following just forcefully casts y into an integer value in diff --git a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt new file mode 100644 index 0000000..050c0ed --- /dev/null +++ b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt @@ -0,0 +1,22 @@ +add_mlir_conversion_library(MLIRMathToXeVM + MathToXeVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToXeVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRArithAttrToLLVMConversion + MLIRArithDialect + MLIRLLVMCommonConversion + MLIRLLVMDialect + MLIRMathDialect + MLIRXeVMDialect + MLIRPass + MLIRTransforms + ) diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp new file mode 100644 index 0000000..0fe31d0 --- /dev/null +++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp @@ -0,0 +1,167 @@ +//===-- MathToXeVM.cpp - conversion from Math to XeVM ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/MathToXeVM/MathToXeVM.h" +#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" +#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/Pass/Pass.h" +#include "llvm/Support/FormatVariadic.h" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTMATHTOXEVM +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +#define DEBUG_TYPE "math-to-xevm" + +/// Convert math ops marked with `fast` (`afn`) to native OpenCL intrinsics. 
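/// For illustration only (hypothetical IR, not taken from this patch): with
/// these patterns, an op such as
///   %y = math.exp %x fastmath<afn> : f32
/// would be rewritten into a call to the mangled native builtin, roughly
///   %y = llvm.call @_Z22__spirv_ocl_native_expf(%x) : (f32) -> f32
/// with the `afn` fastmath flag carried over onto the llvm.call.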
+template <typename Op> +struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> { + + ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc, + PatternBenefit benefit = 1) + : OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {} + + LogicalResult + matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isSPIRVCompatibleFloatOrVec(op.getType())) + return failure(); + + arith::FastMathFlags fastFlags = op.getFastmath(); + if (!arith::bitEnumContainsAll(fastFlags, arith::FastMathFlags::afn)) + return rewriter.notifyMatchFailure(op, "not a fastmath `afn` operation"); + + SmallVector<Type, 1> operandTypes; + for (auto operand : adaptor.getOperands()) { + Type opTy = operand.getType(); + // This pass only supports operations on vectors that are already in SPIRV + // supported vector sizes: Distributing unsupported vector sizes to SPIRV + // supported vector sizes are done in other blocking optimization passes. + if (!isSPIRVCompatibleFloatOrVec(opTy)) + return rewriter.notifyMatchFailure( + op, llvm::formatv("incompatible operand type: '{0}'", opTy)); + operandTypes.push_back(opTy); + } + + auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>(); + auto funcOpRes = LLVM::lookupOrCreateFn( + rewriter, moduleOp, getMangledNativeFuncName(operandTypes), + operandTypes, op.getType()); + assert(!failed(funcOpRes)); + LLVM::LLVMFuncOp funcOp = funcOpRes.value(); + + auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>( + op, funcOp, adaptor.getOperands()); + // Preserve fastmath flags in our MLIR op when converting to llvm function + // calls, in order to allow further fastmath optimizations: We thus need to + // convert arith fastmath attrs into attrs recognized by llvm. + arith::AttrConvertFastMathToLLVM<Op, LLVM::CallOp> fastAttrConverter(op); + mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0]; + callOp->setAttr(fastAttr.getName(), fastAttr.getValue()); + return success(); + } + + inline bool isSPIRVCompatibleFloatOrVec(Type type) const { + if (type.isFloat()) + return true; + if (auto vecType = dyn_cast<VectorType>(type)) { + if (!vecType.getElementType().isFloat()) + return false; + // SPIRV distinguishes between vectors and matrices: OpenCL native math + // intrsinics are not compatible with matrices. + ArrayRef<int64_t> shape = vecType.getShape(); + if (shape.size() != 1) + return false; + // SPIRV only allows vectors of size 2, 3, 4, 8, 16. 
+ if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 || + shape[0] == 16) + return true; + } + return false; + } + + inline std::string + getMangledNativeFuncName(const ArrayRef<Type> operandTypes) const { + std::string mangledFuncName = + "_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str(); + + auto appendFloatToMangledFunc = [&mangledFuncName](Type type) { + if (type.isF32()) + mangledFuncName += "f"; + else if (type.isF16()) + mangledFuncName += "Dh"; + else if (type.isF64()) + mangledFuncName += "d"; + }; + + for (auto type : operandTypes) { + if (auto vecType = dyn_cast<VectorType>(type)) { + mangledFuncName += "Dv" + std::to_string(vecType.getShape()[0]) + "_"; + appendFloatToMangledFunc(vecType.getElementType()); + } else + appendFloatToMangledFunc(type); + } + + return mangledFuncName; + } + + const StringRef nativeFunc; +}; + +void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns, + bool convertArith) { + patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(patterns.getContext(), + "__spirv_ocl_native_exp"); + patterns.add<ConvertNativeFuncPattern<math::CosOp>>(patterns.getContext(), + "__spirv_ocl_native_cos"); + patterns.add<ConvertNativeFuncPattern<math::Exp2Op>>( + patterns.getContext(), "__spirv_ocl_native_exp2"); + patterns.add<ConvertNativeFuncPattern<math::LogOp>>(patterns.getContext(), + "__spirv_ocl_native_log"); + patterns.add<ConvertNativeFuncPattern<math::Log2Op>>( + patterns.getContext(), "__spirv_ocl_native_log2"); + patterns.add<ConvertNativeFuncPattern<math::Log10Op>>( + patterns.getContext(), "__spirv_ocl_native_log10"); + patterns.add<ConvertNativeFuncPattern<math::PowFOp>>( + patterns.getContext(), "__spirv_ocl_native_powr"); + patterns.add<ConvertNativeFuncPattern<math::RsqrtOp>>( + patterns.getContext(), "__spirv_ocl_native_rsqrt"); + patterns.add<ConvertNativeFuncPattern<math::SinOp>>(patterns.getContext(), + "__spirv_ocl_native_sin"); + patterns.add<ConvertNativeFuncPattern<math::SqrtOp>>( + patterns.getContext(), "__spirv_ocl_native_sqrt"); + patterns.add<ConvertNativeFuncPattern<math::TanOp>>(patterns.getContext(), + "__spirv_ocl_native_tan"); + if (convertArith) + patterns.add<ConvertNativeFuncPattern<arith::DivFOp>>( + patterns.getContext(), "__spirv_ocl_native_divide"); +} + +namespace { +struct ConvertMathToXeVMPass + : public impl::ConvertMathToXeVMBase<ConvertMathToXeVMPass> { + using Base::Base; + void runOnOperation() override; +}; +} // namespace + +void ConvertMathToXeVMPass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + populateMathToXeVMConversionPatterns(patterns, convertArith); + ConversionTarget target(getContext()); + target.addLegalDialect<BuiltinDialect, LLVM::LLVMDialect>(); + if (failed( + applyPartialConversion(getOperation(), target, std::move(patterns)))) + signalPassFailure(); +} diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index 2b7bdc9..11f866c 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -22,6 +22,7 @@ #include "mlir/IR/TypeRange.h" #include "mlir/IR/Value.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/STLExtras.h" #include <cstdint> #include <numeric> @@ -110,9 +111,7 @@ static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType, {TypeAttr::get(memrefType.getElementType())})); IndexType indexType = builder.getIndexType(); - int64_t numElements = 
std::accumulate(memrefType.getShape().begin(), - memrefType.getShape().end(), int64_t{1}, - std::multiplies<int64_t>()); + int64_t numElements = llvm::product_of(memrefType.getShape()); emitc::ConstantOp numElementsValue = emitc::ConstantOp::create( builder, loc, indexType, builder.getIndexAttr(numElements)); diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index a5336ed..00df14b1 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1392,6 +1392,137 @@ public: } }; +// Collapse tensor<1xiN> into tensor<iN> +// E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16> +static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input, + Location loc) { + SmallVector<ReassociationExprs, 1> reassociation; + // Create the collapsed type + auto inputType = cast<RankedTensorType>(input.getType()); + auto elemType = inputType.getElementType(); + auto collapsedType = RankedTensorType::get({}, elemType); + // Emit the collapse op + return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, input, + reassociation); +} + +static llvm::SmallVector<int8_t> +convertToI8(const llvm::SmallVector<int32_t> &input) { + llvm::SmallVector<int8_t> output; + output.reserve(input.size()); + + for (auto v : llvm::map_range( + input, [](int32_t val) { return static_cast<int8_t>(val); })) { + output.push_back(v); + } + return output; +} + +// The shift or multiplier may be either constant or non-constant, depending on +// whether dynamic extension is enabled. +// - If the shift or multiplier is non-constant, add it as an input to +// linalg::GenericOp by: +// 1. Pushing it into 'genericInputs'. +// 2. Appending a corresponding affine map to 'indexingMaps'. +// - If the shift or multiplier is constant, set 'constant' instead. +static void setupLinalgGenericOpInputAndIndexingMap( + PatternRewriter &rewriter, llvm::SmallVector<int32_t> &values, + SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps, + bool isConstant, tosa::RescaleOp op, Value &constant, int64_t &arg, + bool isShift = false) { + + auto loc = op.getLoc(); + auto inputTy = cast<ShapedType>(op.getInput().getType()); + unsigned rank = inputTy.getRank(); + SmallVector<AffineExpr, 2> exprs = {rewriter.getAffineDimExpr(rank - 1)}; + + if (isConstant) { + // If we are rescaling per-channel then we need to store the + // values in a buffer. + if (values.size() == 1) { + IntegerAttr intAttr = isShift + ? rewriter.getI8IntegerAttr(values.front()) + : rewriter.getI32IntegerAttr(values.front()); + constant = rewriter.create<arith::ConstantOp>(loc, intAttr); + } else { + auto elementType = + isShift ? rewriter.getIntegerType(8) : rewriter.getI32Type(); + auto tensorType = RankedTensorType::get( + {static_cast<int64_t>(values.size())}, elementType); + DenseIntElementsAttr EltAttr; + if (isShift) + EltAttr = DenseIntElementsAttr::get(tensorType, convertToI8(values)); + else + EltAttr = DenseIntElementsAttr::get(tensorType, values); + genericInputs.push_back( + arith::ConstantOp::create(rewriter, loc, EltAttr)); + indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, + /*symbolCount=*/0, exprs, + rewriter.getContext())); + } + } else { + // If we are not rescaling per-channel then we need to collapse 1xN to N + // and push broadcastMap. + auto operand = isShift ? 
op.getShift() : op.getMultiplier(); + auto tensorType = dyn_cast<RankedTensorType>(operand.getType()); + if (tensorType && tensorType.hasStaticShape() && + tensorType.getShape()[0] == 1) { + // broadcastMap = affine_map<(d0, d1) -> ()> + // It would affect as broadcast for scalar values in linalg::GenericOp. + AffineMap broadcastMap = + AffineMap::get(rank, 0, {}, rewriter.getContext()); + genericInputs.push_back(collapse1xNTensorToN(rewriter, operand, loc)); + indexingMaps.push_back(broadcastMap); + } else { + genericInputs.push_back(operand); + indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, + /*symbolCount=*/0, exprs, + rewriter.getContext())); + } + } + arg = indexingMaps.size() - 1; +} + +// Return the extended Zp to be used in subsequent arithmetic operations. +static Value getExtendZp(OpBuilder &builder, Type valueTy, + FailureOr<int64_t> maybeZp, Location loc, + ValueRange blockArgs, int64_t zpArg, + bool isOutputZp = false) { + Value result; + const int32_t bitwidth = valueTy.getIntOrFloatBitWidth(); + const uint32_t attrBitwidth = + isOutputZp ? 32 : (bitwidth > 32 ? bitwidth : 32); + auto extendType = builder.getIntegerType(attrBitwidth); + // The Zp value can be either constant or non-constant, depending on + // whether dynamic extension is enabled. + // If 'maybeZp' fails, it indicates that Zp is non-constant and will + // be passed as an input to linalg::GenericOp. + if (failed(maybeZp)) { + result = blockArgs[zpArg]; + auto zpTy = result.getType(); + if (zpTy.getIntOrFloatBitWidth() < attrBitwidth) { + // For ExtUIOp, the input must be signless. + // UnrealizedConversionCastOp will cast the input to signless type. + if (zpTy.isUnsignedInteger()) { + result = + UnrealizedConversionCastOp::create( + builder, loc, + builder.getIntegerType(zpTy.getIntOrFloatBitWidth()), result) + .getResult(0); + } + if (zpTy.isUnsignedInteger()) { + return builder.create<arith::ExtUIOp>(loc, extendType, result); + } else { + return builder.create<arith::ExtSIOp>(loc, extendType, result); + } + } + } else { + return builder.create<arith::ConstantOp>( + loc, IntegerAttr::get(extendType, *maybeZp)); + } + return result; +} + class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> { public: using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern; @@ -1423,40 +1554,46 @@ public: } } - // The shift and multiplier values. DenseElementsAttr shiftElems; - if (!matchPattern(op.getShift(), m_Constant(&shiftElems))) - return rewriter.notifyMatchFailure( - op, "tosa.rescale requires constant shift input values"); + bool isShiftConstant = false; + if (matchPattern(op.getShift(), m_Constant(&shiftElems))) + isShiftConstant = true; DenseElementsAttr multiplierElems; - if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems))) - return rewriter.notifyMatchFailure( - op, "tosa.rescale requires constant multiplier input values"); - - llvm::SmallVector<int8_t> shiftValues = - llvm::to_vector(shiftElems.getValues<int8_t>()); - // explicit cast is required here - llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector( - llvm::map_range(multiplierElems.getValues<IntegerAttr>(), - [](IntegerAttr attr) -> int32_t { - return static_cast<int32_t>(attr.getInt()); - })); - - // If we shift by more than the bitwidth, this just sets to 0. 
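      // (For example, an entry with multiplier 5 and shift 64 is rewritten to
      // multiplier 0 and shift 0, so that channel simply produces zero instead
      // of relying on an out-of-range shift.)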
- for (int i = 0, s = multiplierValues.size(); i < s; i++) { - if (shiftValues[i] > 63) { - shiftValues[i] = 0; - multiplierValues[i] = 0; + bool isMultiplierConstant = false; + if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems))) + isMultiplierConstant = true; + + llvm::SmallVector<int32_t> shiftValues; + llvm::SmallVector<int32_t> multiplierValues; + bool doubleRound; + + if (isMultiplierConstant && isShiftConstant) { + // explicit cast is required here + shiftValues = llvm::to_vector(llvm::map_range( + shiftElems.getValues<IntegerAttr>(), [](IntegerAttr attr) -> int32_t { + return static_cast<int32_t>(attr.getInt()); + })); + multiplierValues = llvm::to_vector( + llvm::map_range(multiplierElems.getValues<IntegerAttr>(), + [](IntegerAttr attr) -> int32_t { + return static_cast<int32_t>(attr.getInt()); + })); + + // If we shift by more than the bitwidth, this just sets to 0. + for (int i = 0, s = multiplierValues.size(); i < s; i++) { + if (shiftValues[i] > 63) { + shiftValues[i] = 0; + multiplierValues[i] = 0; + } } - } + // Double round only occurs if shift is greater than 31, check that this + // is ever true. + doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND && + llvm::any_of(shiftValues, [](int32_t v) { return v > 31; }); + } else + doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND; - // Double round only occurs if shift is greater than 31, check that this - // is ever true. - - bool doubleRound = - op.getRoundingMode() == RoundingMode::DOUBLE_ROUND && - llvm::any_of(shiftValues, [](int32_t v) { return v > 31; }); RoundingMode roundingMode = doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND; @@ -1468,45 +1605,43 @@ public: // values in a buffer. Value multiplierConstant; int64_t multiplierArg = 0; - if (multiplierValues.size() == 1) { - multiplierConstant = arith::ConstantOp::create( - rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front())); - } else { - SmallVector<AffineExpr, 2> multiplierExprs{ - rewriter.getAffineDimExpr(rank - 1)}; - auto multiplierType = - RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())}, - rewriter.getI32Type()); - genericInputs.push_back(arith::ConstantOp::create( - rewriter, loc, - DenseIntElementsAttr::get(multiplierType, multiplierValues))); - - indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, - /*symbolCount=*/0, multiplierExprs, - rewriter.getContext())); - - multiplierArg = indexingMaps.size() - 1; - } + setupLinalgGenericOpInputAndIndexingMap( + rewriter, multiplierValues, genericInputs, indexingMaps, + isMultiplierConstant, op, multiplierConstant, multiplierArg); // If we are rescaling per-channel then we need to store the shift // values in a buffer. 
Value shiftConstant; int64_t shiftArg = 0; - if (shiftValues.size() == 1) { - shiftConstant = arith::ConstantOp::create( - rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front())); - } else { - SmallVector<AffineExpr, 2> shiftExprs = { - rewriter.getAffineDimExpr(rank - 1)}; - auto shiftType = - RankedTensorType::get({static_cast<int64_t>(shiftValues.size())}, - rewriter.getIntegerType(8)); - genericInputs.push_back(arith::ConstantOp::create( - rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues))); - indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, - /*symbolCount=*/0, shiftExprs, - rewriter.getContext())); - shiftArg = indexingMaps.size() - 1; + setupLinalgGenericOpInputAndIndexingMap( + rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op, + shiftConstant, shiftArg, true); + + // broadcastMap = affine_map<(d0, d1) -> ()> + // It would affect as broadcast for scalar values in linalg::GenericOp. + AffineMap broadcastMap = AffineMap::get(rank, 0, {}, rewriter.getContext()); + FailureOr<int64_t> maybeIZp = op.getInputZeroPoint(); + FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint(); + // The inputZp and outputZp may be either constant or non-constant, + // depending on whether dynamic extension is enabled. + // - If the zp's are non-constant, add them as an inputs to + // linalg::GenericOp by: + // 1. Pushing it into 'genericInputs'. + // 2. Appending a corresponding affine map to 'indexingMaps'. + // - If the zp's are constant, they would be generated as arith.constant. + int64_t iZpArg = 0; + if (failed(maybeIZp)) { + genericInputs.push_back( + collapse1xNTensorToN(rewriter, op->getOperand(3), loc)); + indexingMaps.push_back(broadcastMap); + iZpArg = indexingMaps.size() - 1; + } + int64_t oZpArg = 0; + if (failed(maybeOZp)) { + genericInputs.push_back( + collapse1xNTensorToN(rewriter, op->getOperand(4), loc)); + indexingMaps.push_back(broadcastMap); + oZpArg = indexingMaps.size() - 1; } // Indexing maps for output values. @@ -1526,36 +1661,17 @@ public: Type valueTy = value.getType(); FailureOr<int64_t> maybeIZp = op.getInputZeroPoint(); - if (failed(maybeIZp)) { - (void)rewriter.notifyMatchFailure( - op, "input zero point cannot be statically determined"); - return; - } - - const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth(); - // Extend zeropoint for sub-32bits widths. - const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32; - auto inputZp = arith::ConstantOp::create( - nestedBuilder, loc, - IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth), - *maybeIZp)); + auto inputZp = getExtendZp(nestedBuilder, valueTy, maybeIZp, + nestedLoc, blockArgs, iZpArg); FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint(); - if (failed(maybeOZp)) { - (void)rewriter.notifyMatchFailure( - op, "output zero point cannot be statically determined"); - return; - }; + auto outputZp = getExtendZp(nestedBuilder, valueTy, maybeOZp, + nestedLoc, blockArgs, oZpArg, true); IntegerType outIntType = cast<IntegerType>(blockArgs.back().getType()); unsigned outBitWidth = outIntType.getWidth(); - const int32_t outAttrBitwidth = 32; assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth"); - auto outputZp = arith::ConstantOp::create( - nestedBuilder, loc, - IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth), - *maybeOZp)); Value multiplier = multiplierConstant ? 
multiplierConstant : blockArgs[multiplierArg]; diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp index 802691c..9bf9ca3 100644 --- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp +++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/STLExtras.h" #include <numeric> @@ -70,8 +71,7 @@ TensorType inferReshapeExpandedType(TensorType inputType, // Calculate the product of all elements in 'newShape' except for the -1 // placeholder, which we discard by negating the result. - int64_t totalSizeNoPlaceholder = -std::accumulate( - newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>()); + int64_t totalSizeNoPlaceholder = -llvm::product_of(newShape); // If there is a 0 component in 'newShape', resolve the placeholder as // 0. diff --git a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp index 79c2f23..245a3ef 100644 --- a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp +++ b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp @@ -20,6 +20,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/DebugLog.h" #include <numeric> @@ -265,8 +266,7 @@ loadStoreFromTransfer(PatternRewriter &rewriter, if (isPacked) src = collapseLastDim(rewriter, src); int64_t rows = vecShape[0]; - int64_t cols = std::accumulate(vecShape.begin() + 1, vecShape.end(), 1, - std::multiplies<int64_t>()); + int64_t cols = llvm::product_of(vecShape.drop_front()); auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType()); Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0); @@ -336,8 +336,7 @@ static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter, ArrayRef<int64_t> shape = vecTy.getShape(); int64_t rows = shape[0]; - int64_t cols = std::accumulate(shape.begin() + 1, shape.end(), 1, - std::multiplies<int64_t>()); + int64_t cols = llvm::product_of(shape.drop_front()); auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType()); return amx::TileLoadOp::create(rewriter, loc, tileType, buf, diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 5355909..41d8d53 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1723,17 +1723,18 @@ struct VectorBroadcastScalarToLowRankLowering return success(); } - // For 1-d vector, we additionally do a `vectorshuffle`. auto v = LLVM::InsertElementOp::create(rewriter, broadcast.getLoc(), vectorType, poison, adaptor.getSource(), zero); + // For 1-d vector, we additionally do a `shufflevector`. int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0); SmallVector<int32_t> zeroValues(width, 0); // Shuffle the value across the desired number of elements. 
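    // (createOrFold gives the folder a chance to simplify the shuffle before
    // the original broadcast op is replaced.)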
- rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(broadcast, v, poison, - zeroValues); + auto shuffle = rewriter.createOrFold<LLVM::ShuffleVectorOp>( + broadcast.getLoc(), v, poison, zeroValues); + rewriter.replaceOp(broadcast, shuffle); return success(); } }; diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index c45c45e..c9eba69 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -26,6 +26,7 @@ #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/STLExtras.h" namespace mlir { #define GEN_PASS_DEF_CONVERTVECTORTOSCF @@ -760,8 +761,7 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> { if (vectorType.getRank() != 1) { // Flatten n-D vectors to 1D. This is done to allow indexing with a // non-constant value. - auto flatLength = std::accumulate(shape.begin(), shape.end(), 1, - std::multiplies<int64_t>()); + int64_t flatLength = llvm::product_of(shape); auto flatVectorType = VectorType::get({flatLength}, vectorType.getElementType()); value = vector::ShapeCastOp::create(rewriter, loc, flatVectorType, value); diff --git a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt index 84b2580..dd9edc4 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt +++ b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt @@ -21,6 +21,7 @@ add_mlir_conversion_library(MLIRXeGPUToXeVM MLIRIndexDialect MLIRSCFDialect MLIRXeGPUDialect + MLIRXeGPUUtils MLIRPass MLIRTransforms MLIRSCFTransforms diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 9ead1d8..fcbf66d 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -20,9 +20,12 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/IR/BuiltinTypes.h" @@ -61,6 +64,7 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) { case xegpu::MemorySpace::SLM: return static_cast<int>(xevm::AddrSpace::SHARED); } + llvm_unreachable("Unknown XeGPU memory space"); } // Get same bitwidth flat vector type of new element type. @@ -184,6 +188,7 @@ class CreateNdDescToXeVMPattern int64_t rank = mixedSizes.size(); if (rank != 2) return rewriter.notifyMatchFailure(op, "Expected 2D shape."); + auto sourceTy = source.getType(); auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy); // If source is a memref, we need to extract the aligned pointer as index. 
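// (Note on the renamed helper below: the byte-size constant is now built with
// baseAddr's own integer type instead of a hard-coded i64, so the same
// arithmetic serves both the i64 flat/global addresses and the i32 SLM
// addresses used by the matrix load/store lowering added further down.)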
@@ -362,10 +367,11 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> { // Add a builder that creates // offset * elemByteSize + baseAddr -static Value addOffset(ConversionPatternRewriter &rewriter, Location loc, - Value baseAddr, Value offset, int64_t elemByteSize) { +static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter, + Location loc, Value baseAddr, Value offset, + int64_t elemByteSize) { Value byteSize = arith::ConstantIntOp::create( - rewriter, loc, rewriter.getI64Type(), elemByteSize); + rewriter, loc, baseAddr.getType(), elemByteSize); Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize); Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset); return newAddr; @@ -389,7 +395,8 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> { // Load result or Store valye Type can be vector or scalar. Type valOrResTy; if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) - valOrResTy = op.getResult().getType(); + valOrResTy = + this->getTypeConverter()->convertType(op.getResult().getType()); else valOrResTy = adaptor.getValue().getType(); VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy); @@ -440,7 +447,8 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> { // If offset is provided, we add them to the base pointer. // Offset is in number of elements, we need to multiply by // element byte size. - basePtrI64 = addOffset(rewriter, loc, basePtrI64, offset, elemByteSize); + basePtrI64 = + addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize); } // Convert base pointer (i64) to LLVM pointer type. Value basePtrLLVM = @@ -503,6 +511,147 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> { } }; +// Lower xegpu::CreateMemDescOp to memref::ViewOp. Since SLM access instructions +// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than +// 32 bits will be converted to 32 bits. +class CreateMemDescOpPattern final + : public OpConversionPattern<xegpu::CreateMemDescOp> { +public: + using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto resTy = op.getMemDesc(); + + // Create the result MemRefType with the same shape, element type, and + // memory space + auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy); + + Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0); + auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy, + op.getSource(), zero, ValueRange()); + rewriter.replaceOp(op, viewOp); + return success(); + } +}; + +template <typename OpType, + typename = std::enable_if_t<llvm::is_one_of< + OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>> +class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> { + using OpConversionPattern<OpType>::OpConversionPattern; + LogicalResult + matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + SmallVector<OpFoldResult> offsets = op.getMixedOffsets(); + if (offsets.empty()) + return rewriter.notifyMatchFailure(op, "Expected offset to be provided."); + + auto loc = op.getLoc(); + auto ctxt = rewriter.getContext(); + Value basePtrStruct = adaptor.getMemDesc(); + Value mdescVal = op.getMemDesc(); + // Load result or Store value Type can be vector or scalar. 
+ Value data; + if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) + data = op.getResult(); + else + data = adaptor.getData(); + VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType()); + if (!valOrResVecTy) + valOrResVecTy = VectorType::get(1, data.getType()); + + int64_t elemBitWidth = + valOrResVecTy.getElementType().getIntOrFloatBitWidth(); + // Element type must be multiple of 8 bits. + if (elemBitWidth % 8 != 0) + return rewriter.notifyMatchFailure( + op, "Expected element type bit width to be multiple of 8."); + int64_t elemByteSize = elemBitWidth / 8; + + // Default memory space is SLM. + LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get( + ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM)); + + auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType()); + + Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create( + rewriter, loc, basePtrStruct); + + // Convert base pointer (ptr) to i32 + Value basePtrI32 = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI32Type(), basePtrLLVM); + + Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets); + linearOffset = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI32Type(), linearOffset); + basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset, + elemByteSize); + + // convert base pointer (i32) to LLVM pointer type + basePtrLLVM = + LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32); + + if (op.getSubgroupBlockIoAttr()) { + // if the attribute 'subgroup_block_io' is set to true, it lowers to + // xevm.blockload + + Type intElemTy = rewriter.getIntegerType(elemBitWidth); + VectorType intVecTy = + VectorType::get(valOrResVecTy.getShape(), intElemTy); + + if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) { + Value loadOp = + xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM); + if (intVecTy != valOrResVecTy) { + loadOp = + vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp); + } + rewriter.replaceOp(op, loadOp); + } else { + Value dataToStore = adaptor.getData(); + if (valOrResVecTy != intVecTy) { + dataToStore = + vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore); + } + xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore, + nullptr); + rewriter.eraseOp(op); + } + return success(); + } + + if (valOrResVecTy.getNumElements() >= 1) { + auto chipOpt = xegpu::getChipStr(op); + if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) { + // the lowering for chunk load only works for pvc and bmg + return rewriter.notifyMatchFailure( + op, "The lowering is specific to pvc or bmg."); + } + } + + if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) { + // if the size of valOrResVecTy is 1, it lowers to a scalar load/store + // operation. LLVM load/store does not support vector of size 1, so we + // need to handle this case separately. 
+ auto scalarTy = valOrResVecTy.getElementType(); + LLVM::LoadOp loadOp; + if (valOrResVecTy.getNumElements() == 1) + loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM); + else + loadOp = + LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM); + rewriter.replaceOp(op, loadOp); + } else { + LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM); + rewriter.eraseOp(op); + } + return success(); + } +}; + class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -545,8 +694,8 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> { op, "Expected element type bit width to be multiple of 8."); elemByteSize = elemBitWidth / 8; } - basePtrI64 = - addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize); + basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets, + elemByteSize); } } // Default memory space is global. @@ -774,9 +923,7 @@ struct ConvertXeGPUToXeVMPass if (rank < 1 || type.getNumElements() == 1) return elemType; // Otherwise, convert the vector to a flat vector type. - int64_t sum = - std::accumulate(type.getShape().begin(), type.getShape().end(), - int64_t{1}, std::multiplies<int64_t>()); + int64_t sum = llvm::product_of(type.getShape()); return VectorType::get(sum, elemType); }); typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type { @@ -785,6 +932,13 @@ struct ConvertXeGPUToXeVMPass auto i32Type = IntegerType::get(&getContext(), 32); return VectorType::get(8, i32Type); }); + // Convert MemDescType into flattened MemRefType for SLM + typeConverter.addConversion([&](xegpu::MemDescType type) -> Type { + Type elemTy = type.getElementType(); + int numElems = type.getNumElements(); + return MemRefType::get(numElems, elemTy, AffineMap(), 3); + }); + typeConverter.addConversion([&](MemRefType type) -> Type { // Convert MemRefType to i64 type. return IntegerType::get(&getContext(), 64); @@ -879,10 +1033,30 @@ struct ConvertXeGPUToXeVMPass } return {}; }; - typeConverter.addSourceMaterialization(memrefMaterializationCast); - typeConverter.addSourceMaterialization(ui64MaterializationCast); - typeConverter.addSourceMaterialization(ui32MaterializationCast); - typeConverter.addSourceMaterialization(vectorMaterializationCast); + + // If result type of original op is single element vector and lowered type + // is scalar. This materialization cast creates a single element vector by + // broadcasting the scalar value. + auto singleElementVectorMaterializationCast = + [](OpBuilder &builder, Type type, ValueRange inputs, + Location loc) -> Value { + if (inputs.size() != 1) + return {}; + auto input = inputs.front(); + if (input.getType().isIntOrIndexOrFloat()) { + // If the input is a scalar, and the target type is a vector of single + // element, create a single element vector by broadcasting. 
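        // (For instance, an f32 result of a scalar llvm.load is materialized
        // back into the vector<1xf32> expected by the original xegpu op via a
        // vector.broadcast.)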
+ if (auto vecTy = dyn_cast<VectorType>(type)) { + if (vecTy.getNumElements() == 1) { + return vector::BroadcastOp::create(builder, loc, vecTy, input) + .getResult(); + } + } + } + return {}; + }; + typeConverter.addSourceMaterialization( + singleElementVectorMaterializationCast); typeConverter.addTargetMaterialization(memrefMaterializationCast); typeConverter.addTargetMaterialization(ui32MaterializationCast); typeConverter.addTargetMaterialization(ui64MaterializationCast); @@ -919,6 +1093,9 @@ void mlir::populateXeGPUToXeVMConversionPatterns( LoadStoreToXeVMPattern<xegpu::LoadGatherOp>, LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>( typeConverter, patterns.getContext()); + patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>, + LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>, + CreateMemDescOpPattern>(typeConverter, patterns.getContext()); patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter, patterns.getContext()); } diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index f405d0c..61166db 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -339,6 +339,25 @@ void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns( } //===----------------------------------------------------------------------===// +// ScaledExtPacked816Op +//===----------------------------------------------------------------------===// +LogicalResult ScaledExtPacked816Op::verify() { + int blockSize = getBlockSize(); + assert((blockSize == 16 || blockSize == 32) && "invalid block size"); + int firstScaleByte = getFirstScaleByte(); + if (blockSize == 16 && !llvm::is_contained({0, 1}, firstScaleByte)) { + return emitOpError( + "blockSize of 16 can only have firstScaleByte be 0 or 1."); + } + if (blockSize == 32 && !llvm::is_contained({0, 2}, firstScaleByte)) { + return emitOpError( + "blockSize of 32 can only have firstScaleByte be 0 or 2."); + } + + return success(); +} + +//===----------------------------------------------------------------------===// // WMMAOp //===----------------------------------------------------------------------===// LogicalResult WMMAOp::verify() { @@ -757,13 +776,13 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> { offset = numElements - 4l; } Type scaleSrcElemType = scaleSrcType.getElementType(); - auto newSrcType = VectorType::get(SmallVector<int64_t>({numElements}), - scaleSrcElemType); + auto newSrcType = + VectorType::get(ArrayRef{numElements}, scaleSrcElemType); Value newScaleSrc = vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc); auto extract = vector::ExtractStridedSliceOp::create( - rewriter, loc, newScaleSrc, ArrayRef<int64_t>{offset}, - ArrayRef<int64_t>{size}, ArrayRef<int64_t>{1}); + rewriter, loc, newScaleSrc, ArrayRef{offset}, ArrayRef{size}, + ArrayRef{int64_t(1)}); rewriter.modifyOpInPlace(op, [&] { op->setOperand(opIdx, extract); setOpsel(opIdx, opsel); diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp index 68990ef..d9c097c 100644 --- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp +++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp @@ -80,10 +80,22 @@ static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType, LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)}; } +/// Returns stride expressed in number of bytes for the given `elementStride` +/// stride encoded in number of elements of the type `mType`. 
+static Value computeStrideInBytes(Location loc, MemRefType mType, + Value elementStride, RewriterBase &rewriter) { + Type llvmInt64Type = rewriter.getIntegerType(64); + unsigned bytes = mType.getElementType().getIntOrFloatBitWidth() / 8; + auto attr = rewriter.getI64IntegerAttr(bytes); + Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr); + return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, elementStride) + .getResult(); +} + /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer /// shape may "envelop" the actual tile shape, and may be dynamically sized. -static Value getStride(Location loc, MemRefType mType, Value base, - RewriterBase &rewriter) { +static Value inferStride(Location loc, MemRefType mType, Value base, + RewriterBase &rewriter) { assert(mType.getRank() >= 2 && "Invalid shape for AMX strides"); int64_t preLast = mType.getRank() - 2; Type llvmInt64Type = rewriter.getIntegerType(64); @@ -94,11 +106,8 @@ static Value getStride(Location loc, MemRefType mType, Value base, if (strides[preLast] == ShapedType::kDynamic) { // Dynamic stride needs code to compute the stride at runtime. MemRefDescriptor memrefDescriptor(base); - auto attr = rewriter.getI64IntegerAttr(bytes); - Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr); - return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, - memrefDescriptor.stride(rewriter, loc, preLast)) - .getResult(); + return computeStrideInBytes( + loc, mType, memrefDescriptor.stride(rewriter, loc, preLast), rewriter); } // Use direct constant for static stride. auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes); @@ -117,21 +126,39 @@ amx::TileZeroOp::getIntrinsicOperands(ArrayRef<Value> operands, return getTileSizes(getLoc(), getTileType(), rewriter); } -LogicalResult amx::TileLoadOp::verify() { - MemRefType memrefTy = getMemRefType(); +template <typename OpTy, + typename = std::enable_if_t<std::is_same_v<OpTy, amx::TileLoadOp> || + std::is_same_v<OpTy, amx::TileStoreOp>>> +static LogicalResult tileTransferVerifier(OpTy op) { + MemRefType memrefTy = op.getMemRefType(); unsigned rank = memrefTy.getRank(); - if (rank < 2) - return emitOpError("requires at least 2D memref"); - if (getIndices().size() != rank) - return emitOpError("requires ") << rank << " indices"; - SmallVector<int64_t> strides; - int64_t offset; - if (failed(memrefTy.getStridesAndOffset(strides, offset)) || - strides.back() != 1) - return emitOpError("requires memref with unit innermost stride"); - return verifyTileSize(*this, getTileType()); + if (op.getIndices().size() != rank) + return op.emitOpError("requires ") << rank << " indices"; + + if (failed(verifyTileSize(op, op.getTileType()))) + return failure(); + + // Validate basic buffer properties when the stride is implicit. 
+ if (!op.getStride()) { + if (rank < 2) + return op.emitOpError("requires at least 2D memref"); + SmallVector<int64_t> strides; + int64_t offset; + if (failed(memrefTy.getStridesAndOffset(strides, offset)) || + strides.back() != 1) + return op.emitOpError("requires memref with unit innermost stride"); + } + + return success(); +} + +void amx::TileLoadOp::build(OpBuilder &builder, OperationState &state, Type res, + Value base, ValueRange indices) { + build(builder, state, res, base, indices, /*stride=*/nullptr); } +LogicalResult amx::TileLoadOp::verify() { return tileTransferVerifier(*this); } + SmallVector<Value> amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter, @@ -144,27 +171,23 @@ amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands, intrinsicOperands.push_back( LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(), adaptor.getBase(), adaptor.getIndices())); - intrinsicOperands.push_back( - getStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); + if (Value stride = adaptor.getStride()) + intrinsicOperands.push_back( + computeStrideInBytes(loc, getMemRefType(), stride, rewriter)); + else + intrinsicOperands.push_back( + inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); return intrinsicOperands; } -LogicalResult amx::TileStoreOp::verify() { - MemRefType memrefTy = getMemRefType(); - unsigned rank = memrefTy.getRank(); - if (rank < 2) - return emitOpError("requires at least 2D memref"); - if (getIndices().size() != rank) - return emitOpError("requires ") << rank << " indices"; - SmallVector<int64_t> strides; - int64_t offset; - if (failed(memrefTy.getStridesAndOffset(strides, offset)) || - strides.back() != 1) - return emitOpError("requires memref with unit innermost stride"); - return verifyTileSize(*this, getTileType()); +void amx::TileStoreOp::build(OpBuilder &builder, OperationState &state, + Value base, ValueRange indices, Value val) { + build(builder, state, base, indices, val, /*stride=*/nullptr); } +LogicalResult amx::TileStoreOp::verify() { return tileTransferVerifier(*this); } + SmallVector<Value> amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter, @@ -177,8 +200,12 @@ amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands, intrinsicOperands.push_back( LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(), adaptor.getBase(), adaptor.getIndices())); - intrinsicOperands.push_back( - getStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); + if (Value stride = adaptor.getStride()) + intrinsicOperands.push_back( + computeStrideInBytes(loc, getMemRefType(), stride, rewriter)); + else + intrinsicOperands.push_back( + inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); intrinsicOperands.push_back(adaptor.getVal()); return intrinsicOperands; diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 7e5ce26..e0a53cd 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -125,9 +125,9 @@ static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest, // Use "unused attribute" marker to silence clang-tidy warning stemming from // the inability to see through "llvm::TypeSwitch". 
 template <>
-bool LLVM_ATTRIBUTE_UNUSED remainsLegalAfterInline(AffineApplyOp op,
-                                                   Region *src, Region *dest,
-                                                   const IRMapping &mapping) {
+[[maybe_unused]] bool remainsLegalAfterInline(AffineApplyOp op, Region *src,
+                                              Region *dest,
+                                              const IRMapping &mapping) {
   // If it's a valid dimension, we need to check that it remains so.
   if (isValidDim(op.getResult(), src))
     return remainsLegalAfterInline(
@@ -1032,8 +1032,8 @@ static void simplifyMinOrMaxExprWithOperands(AffineMap &map,
 /// Simplify the map while exploiting information on the values in `operands`.
 // Use "unused attribute" marker to silence warning stemming from the inability
 // to see through the template expansion.
-static void LLVM_ATTRIBUTE_UNUSED
-simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) {
+[[maybe_unused]] static void simplifyMapWithOperands(AffineMap &map,
+                                                     ArrayRef<Value> operands) {
   assert(map.getNumInputs() == operands.size() && "invalid operands for map");
   SmallVector<AffineExpr> newResults;
   newResults.reserve(map.getNumResults());
@@ -1125,6 +1125,141 @@ static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp,
   return success(*map != initialMap);
 }
 
+/// Recursively traverse `e`. If `e` or one of its sub-expressions has the form
+/// e1 + e2 + ... + eK, where the e_i are a super(multi)set of `exprsToRemove`,
+/// place a map between e and `newVal` + sum({e1, e2, .. eK} - exprsToRemove)
+/// into `replacementsMap`. If no entries were added to `replacementsMap`,
+/// nothing was found.
+static void shortenAddChainsContainingAll(
+    AffineExpr e, const llvm::SmallDenseSet<AffineExpr, 4> &exprsToRemove,
+    AffineExpr newVal, DenseMap<AffineExpr, AffineExpr> &replacementsMap) {
+  auto binOp = dyn_cast<AffineBinaryOpExpr>(e);
+  if (!binOp)
+    return;
+  AffineExpr lhs = binOp.getLHS();
+  AffineExpr rhs = binOp.getRHS();
+  if (binOp.getKind() != AffineExprKind::Add) {
+    shortenAddChainsContainingAll(lhs, exprsToRemove, newVal, replacementsMap);
+    shortenAddChainsContainingAll(rhs, exprsToRemove, newVal, replacementsMap);
+    return;
+  }
+  SmallVector<AffineExpr> toPreserve;
+  llvm::SmallDenseSet<AffineExpr, 4> ourTracker(exprsToRemove);
+  AffineExpr thisTerm = rhs;
+  AffineExpr nextTerm = lhs;
+
+  while (thisTerm) {
+    if (!ourTracker.erase(thisTerm)) {
+      toPreserve.push_back(thisTerm);
+      shortenAddChainsContainingAll(thisTerm, exprsToRemove, newVal,
+                                    replacementsMap);
+    }
+    auto nextBinOp = dyn_cast_if_present<AffineBinaryOpExpr>(nextTerm);
+    if (!nextBinOp || nextBinOp.getKind() != AffineExprKind::Add) {
+      thisTerm = nextTerm;
+      nextTerm = AffineExpr();
+    } else {
+      thisTerm = nextBinOp.getRHS();
+      nextTerm = nextBinOp.getLHS();
+    }
+  }
+  if (!ourTracker.empty())
+    return;
+  // We reverse the terms to be preserved here in order to preserve
+  // associativity between them.
+  AffineExpr newExpr = newVal;
+  for (AffineExpr preserved : llvm::reverse(toPreserve))
+    newExpr = newExpr + preserved;
+  replacementsMap.insert({e, newExpr});
+}
+
+/// If this map contains the expression `x_1 + x_2 * C_1 + ... + x_n * C_N +
+/// ...` (not necessarily in order) where the set of the `x_i` is the set of
+/// outputs of an `affine.delinearize_index` whose inverse is that expression,
+/// replace that expression with the input of that delinearize_index op.
+///
+/// `resultToReplace` is the result that was detected as the potential start
+/// to this replacement chain - if it isn't the rightmost result of the
+/// delinearization, this method fails.
(This is intended to ensure we don't have redundant scans +/// over the same expression). +/// +/// While this currently only handles delinearizations with a constant basis, +/// that isn't a fundamental limitation. +/// +/// This is a utility function for `replaceDimOrSym` below. +static LogicalResult replaceAffineDelinearizeIndexInverseExpression( + AffineDelinearizeIndexOp delinOp, Value resultToReplace, AffineMap *map, + SmallVectorImpl<Value> &dims, SmallVectorImpl<Value> &syms) { + if (!delinOp.getDynamicBasis().empty()) + return failure(); + if (resultToReplace != delinOp.getMultiIndex().back()) + return failure(); + + MLIRContext *ctx = delinOp.getContext(); + SmallVector<AffineExpr> resToExpr(delinOp.getNumResults(), AffineExpr()); + for (auto [pos, dim] : llvm::enumerate(dims)) { + auto asResult = dyn_cast_if_present<OpResult>(dim); + if (!asResult) + continue; + if (asResult.getOwner() == delinOp.getOperation()) + resToExpr[asResult.getResultNumber()] = getAffineDimExpr(pos, ctx); + } + for (auto [pos, sym] : llvm::enumerate(syms)) { + auto asResult = dyn_cast_if_present<OpResult>(sym); + if (!asResult) + continue; + if (asResult.getOwner() == delinOp.getOperation()) + resToExpr[asResult.getResultNumber()] = getAffineSymbolExpr(pos, ctx); + } + if (llvm::is_contained(resToExpr, AffineExpr())) + return failure(); + + bool isDimReplacement = llvm::all_of(resToExpr, llvm::IsaPred<AffineDimExpr>); + int64_t stride = 1; + llvm::SmallDenseSet<AffineExpr, 4> expectedExprs; + // This isn't zip_equal since sometimes the delinearize basis is missing a + // size for the first result. + for (auto [binding, size] : llvm::zip( + llvm::reverse(resToExpr), llvm::reverse(delinOp.getStaticBasis()))) { + expectedExprs.insert(binding * getAffineConstantExpr(stride, ctx)); + stride *= size; + } + if (resToExpr.size() != delinOp.getStaticBasis().size()) + expectedExprs.insert(resToExpr[0] * stride); + + DenseMap<AffineExpr, AffineExpr> replacements; + AffineExpr delinInExpr = isDimReplacement + ? getAffineDimExpr(dims.size(), ctx) + : getAffineSymbolExpr(syms.size(), ctx); + + for (AffineExpr e : map->getResults()) + shortenAddChainsContainingAll(e, expectedExprs, delinInExpr, replacements); + if (replacements.empty()) + return failure(); + + AffineMap origMap = *map; + if (isDimReplacement) + dims.push_back(delinOp.getLinearIndex()); + else + syms.push_back(delinOp.getLinearIndex()); + *map = origMap.replace(replacements, dims.size(), syms.size()); + + // Blank out dead dimensions and symbols + for (AffineExpr e : resToExpr) { + if (auto d = dyn_cast<AffineDimExpr>(e)) { + unsigned pos = d.getPosition(); + if (!map->isFunctionOfDim(pos)) + dims[pos] = nullptr; + } + if (auto s = dyn_cast<AffineSymbolExpr>(e)) { + unsigned pos = s.getPosition(); + if (!map->isFunctionOfSymbol(pos)) + syms[pos] = nullptr; + } + } + return success(); +} + /// Replace all occurrences of AffineExpr at position `pos` in `map` by the /// defining AffineApplyOp expression and operands. /// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced. 
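The fold added above is easiest to see on a concrete case. Below is a minimal illustrative sketch (not part of the patch), assuming only the standard AffineExpr C++ API: for a delinearization with static basis (4, 8), the rightmost result has stride 1 and the result before it has stride 8, so the inverse add-chain the helpers look for is d0 * 8 + d1.

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

// Builds the inverse expression of `%r0, %r1 = affine.delinearize_index %lin
// into (4, 8)`, assuming %r0 and %r1 are bound to d0 and d1 of the map being
// simplified. An add-chain that contains {d0 * 8, d1} as a sub(multi)set can
// drop those terms and use a fresh dim or symbol bound to %lin instead, which
// is what replaceAffineDelinearizeIndexInverseExpression records in the map.
static AffineExpr delinearizeInverseExample(MLIRContext &ctx) {
  AffineExpr d0 = getAffineDimExpr(0, &ctx);
  AffineExpr d1 = getAffineDimExpr(1, &ctx);
  // Strides accumulate right-to-left: d1 gets stride 1, d0 gets stride 8.
  return d0 * 8 + d1; // The linear index %lin that fed the delinearize.
}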
@@ -1157,6 +1292,11 @@ static LogicalResult replaceDimOrSym(AffineMap *map, syms); } + if (auto delinOp = v.getDefiningOp<affine::AffineDelinearizeIndexOp>()) { + return replaceAffineDelinearizeIndexInverseExpression(delinOp, v, map, dims, + syms); + } + auto affineApply = v.getDefiningOp<AffineApplyOp>(); if (!affineApply) return failure(); @@ -2460,6 +2600,65 @@ static LogicalResult foldLoopBounds(AffineForOp forOp) { return success(folded); } +/// Returns constant trip count in trivial cases. +static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) { + int64_t step = forOp.getStepAsInt(); + if (!forOp.hasConstantBounds() || step <= 0) + return std::nullopt; + int64_t lb = forOp.getConstantLowerBound(); + int64_t ub = forOp.getConstantUpperBound(); + return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step; +} + +/// Fold the empty loop. +static SmallVector<OpFoldResult> AffineForEmptyLoopFolder(AffineForOp forOp) { + if (!llvm::hasSingleElement(*forOp.getBody())) + return {}; + if (forOp.getNumResults() == 0) + return {}; + std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp); + if (tripCount == 0) { + // The initial values of the iteration arguments would be the op's + // results. + return forOp.getInits(); + } + SmallVector<Value, 4> replacements; + auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator()); + auto iterArgs = forOp.getRegionIterArgs(); + bool hasValDefinedOutsideLoop = false; + bool iterArgsNotInOrder = false; + for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) { + Value val = yieldOp.getOperand(i); + BlockArgument *iterArgIt = llvm::find(iterArgs, val); + // TODO: It should be possible to perform a replacement by computing the + // last value of the IV based on the bounds and the step. + if (val == forOp.getInductionVar()) + return {}; + if (iterArgIt == iterArgs.end()) { + // `val` is defined outside of the loop. + assert(forOp.isDefinedOutsideOfLoop(val) && + "must be defined outside of the loop"); + hasValDefinedOutsideLoop = true; + replacements.push_back(val); + } else { + unsigned pos = std::distance(iterArgs.begin(), iterArgIt); + if (pos != i) + iterArgsNotInOrder = true; + replacements.push_back(forOp.getInits()[pos]); + } + } + // Bail out when the trip count is unknown and the loop returns any value + // defined outside of the loop or any iterArg out of order. + if (!tripCount.has_value() && + (hasValDefinedOutsideLoop || iterArgsNotInOrder)) + return {}; + // Bail out when the loop iterates more than once and it returns any iterArg + // out of order. + if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder) + return {}; + return llvm::to_vector_of<OpFoldResult>(replacements); +} + /// Canonicalize the bounds of the given loop. static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) { SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands()); @@ -2491,79 +2690,30 @@ static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) { return success(); } -namespace { -/// Returns constant trip count in trivial cases. -static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) { - int64_t step = forOp.getStepAsInt(); - if (!forOp.hasConstantBounds() || step <= 0) - return std::nullopt; - int64_t lb = forOp.getConstantLowerBound(); - int64_t ub = forOp.getConstantUpperBound(); - return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step; +/// Returns true if the affine.for has zero iterations in trivial cases. 
+static bool hasTrivialZeroTripCount(AffineForOp op) { + return getTrivialConstantTripCount(op) == 0; } -/// This is a pattern to fold trivially empty loop bodies. -/// TODO: This should be moved into the folding hook. -struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> { - using OpRewritePattern<AffineForOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(AffineForOp forOp, - PatternRewriter &rewriter) const override { - // Check that the body only contains a yield. - if (!llvm::hasSingleElement(*forOp.getBody())) - return failure(); - if (forOp.getNumResults() == 0) - return success(); - std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp); - if (tripCount == 0) { - // The initial values of the iteration arguments would be the op's - // results. - rewriter.replaceOp(forOp, forOp.getInits()); - return success(); - } - SmallVector<Value, 4> replacements; - auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator()); - auto iterArgs = forOp.getRegionIterArgs(); - bool hasValDefinedOutsideLoop = false; - bool iterArgsNotInOrder = false; - for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) { - Value val = yieldOp.getOperand(i); - auto *iterArgIt = llvm::find(iterArgs, val); - // TODO: It should be possible to perform a replacement by computing the - // last value of the IV based on the bounds and the step. - if (val == forOp.getInductionVar()) - return failure(); - if (iterArgIt == iterArgs.end()) { - // `val` is defined outside of the loop. - assert(forOp.isDefinedOutsideOfLoop(val) && - "must be defined outside of the loop"); - hasValDefinedOutsideLoop = true; - replacements.push_back(val); - } else { - unsigned pos = std::distance(iterArgs.begin(), iterArgIt); - if (pos != i) - iterArgsNotInOrder = true; - replacements.push_back(forOp.getInits()[pos]); - } - } - // Bail out when the trip count is unknown and the loop returns any value - // defined outside of the loop or any iterArg out of order. - if (!tripCount.has_value() && - (hasValDefinedOutsideLoop || iterArgsNotInOrder)) - return failure(); - // Bail out when the loop iterates more than once and it returns any iterArg - // out of order. - if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder) - return failure(); - rewriter.replaceOp(forOp, replacements); - return success(); +LogicalResult AffineForOp::fold(FoldAdaptor adaptor, + SmallVectorImpl<OpFoldResult> &results) { + bool folded = succeeded(foldLoopBounds(*this)); + folded |= succeeded(canonicalizeLoopBounds(*this)); + if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) { + // The initial values of the loop-carried variables (iter_args) are the + // results of the op. But this must be avoided for an affine.for op that + // does not return any results. Since ops that do not return results cannot + // be folded away, we would enter an infinite loop of folds on the same + // affine.for op. 
+ results.assign(getInits().begin(), getInits().end()); + folded = true; } -}; -} // namespace - -void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add<AffineForEmptyLoopFolder>(context); + SmallVector<OpFoldResult> foldResults = AffineForEmptyLoopFolder(*this); + if (!foldResults.empty()) { + results.assign(foldResults); + folded = true; + } + return success(folded); } OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) { @@ -2606,27 +2756,6 @@ void AffineForOp::getSuccessorRegions( regions.push_back(RegionSuccessor(getResults())); } -/// Returns true if the affine.for has zero iterations in trivial cases. -static bool hasTrivialZeroTripCount(AffineForOp op) { - return getTrivialConstantTripCount(op) == 0; -} - -LogicalResult AffineForOp::fold(FoldAdaptor adaptor, - SmallVectorImpl<OpFoldResult> &results) { - bool folded = succeeded(foldLoopBounds(*this)); - folded |= succeeded(canonicalizeLoopBounds(*this)); - if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) { - // The initial values of the loop-carried variables (iter_args) are the - // results of the op. But this must be avoided for an affine.for op that - // does not return any results. Since ops that do not return results cannot - // be folded away, we would enter an infinite loop of folds on the same - // affine.for op. - results.assign(getInits().begin(), getInits().end()); - folded = true; - } - return success(folded); -} - AffineBound AffineForOp::getLowerBound() { return AffineBound(*this, getLowerBoundOperands(), getLowerBoundMap()); } diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp index cd216ef..4743941 100644 --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -1357,7 +1357,7 @@ bool mlir::affine::isValidLoopInterchangePermutation( /// Returns true if `loops` is a perfectly nested loop nest, where loops appear /// in it from outermost to innermost. -bool LLVM_ATTRIBUTE_UNUSED +[[maybe_unused]] bool mlir::affine::isPerfectlyNested(ArrayRef<AffineForOp> loops) { assert(!loops.empty() && "no loops provided"); @@ -1920,8 +1920,7 @@ generatePointWiseCopy(Location loc, Value memref, Value fastMemRef, return copyNestRoot; } -static InFlightDiagnostic LLVM_ATTRIBUTE_UNUSED -emitRemarkForBlock(Block &block) { +[[maybe_unused]] static InFlightDiagnostic emitRemarkForBlock(Block &block) { return block.getParentOp()->emitRemark(); } diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp index b1fc9aa..f54baff 100644 --- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp @@ -351,9 +351,9 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values, Value one = ConstantOp::create(builder, loc, resultType, builder.getOneAttr(resultType)); ArithBuilder arithBuilder(builder, loc); - return std::accumulate( - values.begin(), values.end(), one, - [&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); }); + return llvm::accumulate(values, one, [&arithBuilder](Value acc, Value v) { + return arithBuilder.mul(acc, v); + }); } /// Map strings to float types. 
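Several hunks in this change replace std::accumulate over a whole container with the range-based helpers from llvm/ADT/STLExtras.h (llvm::product_of, llvm::accumulate). A minimal sketch of the idiom, assuming those helpers behave like their std counterparts applied to the entire range:

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include <cstdint>

// Product of all dimensions, as used when flattening a vector type's shape.
static int64_t numElements(llvm::ArrayRef<int64_t> shape) {
  // Same result as std::accumulate(shape.begin(), shape.end(), int64_t{1},
  // std::multiplies<int64_t>()).
  return llvm::product_of(shape);
}

// General form with an explicit init value and binary operator, mirroring the
// rewritten createProduct() above.
static int64_t sumWithBias(llvm::ArrayRef<int64_t> values, int64_t bias) {
  return llvm::accumulate(values, bias,
                          [](int64_t acc, int64_t v) { return acc + v; });
}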
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp index a50ddbe..bc17990 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp @@ -41,28 +41,37 @@ namespace bufferization { using namespace mlir; -/// Return the unique ReturnOp that terminates `funcOp`. -/// Return nullptr if there is no such unique ReturnOp. -static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) { - func::ReturnOp returnOp; +/// Get all the ReturnOp in the funcOp. +static SmallVector<func::ReturnOp> getReturnOps(func::FuncOp funcOp) { + SmallVector<func::ReturnOp> returnOps; for (Block &b : funcOp.getBody()) { if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) { - if (returnOp) - return nullptr; - returnOp = candidateOp; + returnOps.push_back(candidateOp); } } - return returnOp; + return returnOps; } -/// Return the func::FuncOp called by `callOp`. -static func::FuncOp getCalledFunction(CallOpInterface callOp) { - SymbolRefAttr sym = - llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee()); - if (!sym) - return nullptr; - return dyn_cast_or_null<func::FuncOp>( - SymbolTable::lookupNearestSymbolFrom(callOp, sym)); +/// Get the operands at the specified position for all returnOps. +static SmallVector<Value> +getReturnOpsOperandInPos(ArrayRef<func::ReturnOp> returnOps, size_t pos) { + return llvm::map_to_vector(returnOps, [&](func::ReturnOp returnOp) { + return returnOp.getOperand(pos); + }); +} + +/// Check if all given values are the same buffer as the block argument (modulo +/// cast ops). +static bool operandsEqualFuncArgument(ArrayRef<Value> operands, + BlockArgument argument) { + for (Value val : operands) { + while (auto castOp = val.getDefiningOp<memref::CastOp>()) + val = castOp.getSource(); + + if (val != argument) + return false; + } + return true; } LogicalResult @@ -72,48 +81,55 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) { DenseMap<func::FuncOp, DenseSet<func::CallOp>> callerMap; // Collect the mapping of functions to their call sites. module.walk([&](func::CallOp callOp) { - if (func::FuncOp calledFunc = getCalledFunction(callOp)) { - callerMap[calledFunc].insert(callOp); + if (func::FuncOp calledFunc = + dyn_cast_or_null<func::FuncOp>(callOp.resolveCallable())) { + if (!calledFunc.isPublic() && !calledFunc.isExternal()) + callerMap[calledFunc].insert(callOp); } }); for (auto funcOp : module.getOps<func::FuncOp>()) { - if (funcOp.isExternal()) + if (funcOp.isExternal() || funcOp.isPublic()) continue; - func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); - // TODO: Support functions with multiple blocks. - if (!returnOp) + SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp); + if (returnOps.empty()) continue; // Compute erased results. 
- SmallVector<Value> newReturnValues; - BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults()); + size_t numReturnOps = returnOps.size(); + size_t numReturnValues = funcOp.getFunctionType().getNumResults(); + SmallVector<SmallVector<Value>> newReturnValues(numReturnOps); + BitVector erasedResultIndices(numReturnValues); DenseMap<int64_t, int64_t> resultToArgs; - for (const auto &it : llvm::enumerate(returnOp.getOperands())) { + for (size_t i = 0; i < numReturnValues; ++i) { bool erased = false; + SmallVector<Value> returnOperands = + getReturnOpsOperandInPos(returnOps, i); for (BlockArgument bbArg : funcOp.getArguments()) { - Value val = it.value(); - while (auto castOp = val.getDefiningOp<memref::CastOp>()) - val = castOp.getSource(); - - if (val == bbArg) { - resultToArgs[it.index()] = bbArg.getArgNumber(); + if (operandsEqualFuncArgument(returnOperands, bbArg)) { + resultToArgs[i] = bbArg.getArgNumber(); erased = true; break; } } if (erased) { - erasedResultIndices.set(it.index()); + erasedResultIndices.set(i); } else { - newReturnValues.push_back(it.value()); + for (auto [newReturnValue, operand] : + llvm::zip(newReturnValues, returnOperands)) { + newReturnValue.push_back(operand); + } } } // Update function. if (failed(funcOp.eraseResults(erasedResultIndices))) return failure(); - returnOp.getOperandsMutable().assign(newReturnValues); + + for (auto [returnOp, newReturnValue] : + llvm::zip(returnOps, newReturnValues)) + returnOp.getOperandsMutable().assign(newReturnValue); // Update function calls. for (func::CallOp callOp : callerMap[funcOp]) { diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 19eba6b..b5f8dda 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -2460,8 +2460,7 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed, << dDim << ")"; scales[i] = eDim / dDim; } - if (std::accumulate(scales.begin(), scales.end(), 1, - std::multiplies<int64_t>()) != warpSize) + if (llvm::product_of(scales) != warpSize) return op->emitOpError() << "incompatible distribution dimensions from " << expandedVecType << " to " << distributedVecType << " with warp size = " << warpSize; diff --git a/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt b/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt index 70a9c77..ec68acf 100644 --- a/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt +++ b/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRGPUPipelines GPUToNVVMPipeline.cpp + GPUToXeVMPipeline.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU @@ -11,12 +12,17 @@ add_mlir_dialect_library(MLIRGPUPipelines MLIRTransforms MLIRLinalgTransforms MLIRAffineToStandard + MLIRGPUToLLVMSPV MLIRGPUToNVVMTransforms MLIRIndexToLLVM MLIRMathToLLVM + MLIRMathToXeVM MLIRNVGPUToNVVM MLIRNVVMToLLVM MLIRReconcileUnrealizedCasts MLIRSCFToControlFlow MLIRVectorToSCF + MLIRXeGPUTransforms + MLIRXeGPUToXeVM + MLIRXeVMToLLVM ) diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp new file mode 100644 index 0000000..1a1485b --- /dev/null +++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp @@ -0,0 +1,139 @@ +//===- GPUToXeVMPipeline.cpp - Lowering pipeline to XeVM/LLVM -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass for testing the lowering to XeVM as a generally +// usable sink pass. If XeGPU ops are used, it expects the MLIR code to have +// XeGPU ops already embedded in gpu code. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/MathToXeVM/MathToXeVM.h" +#include "mlir/Conversion/Passes.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" +#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" +#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h" +#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/GPU/Pipelines/Passes.h" +#include "mlir/Dialect/GPU/Transforms/Passes.h" +#include "mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/XeGPU/Transforms/Passes.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassOptions.h" +#include "mlir/Target/LLVM/XeVM/Target.h" +#include "mlir/Transforms/Passes.h" + +using namespace mlir; + +namespace { +//===----------------------------------------------------------------------===// +// Pre-GPU common pipeline for both Host and GPU. +//===----------------------------------------------------------------------===// +void buildPreGPUCommonPassPipeline( + OpPassManager &pm, const mlir::gpu::GPUToXeVMPipelineOptions &options) { + // builtin.module scope passes. + pm.addPass(createCSEPass()); + pm.addPass(createConvertVectorToSCFPass()); + { + GpuXeVMAttachTargetOptions xevmTargetOptions; + xevmTargetOptions.moduleMatcher = options.xevmModuleMatcher; + xevmTargetOptions.triple = options.zebinTriple; + xevmTargetOptions.chip = options.zebinChip; + xevmTargetOptions.optLevel = options.optLevel; + xevmTargetOptions.cmdOptions = options.cmdOptions; + pm.addPass(createGpuXeVMAttachTarget(xevmTargetOptions)); + } + pm.addPass(createLowerAffinePass()); + pm.addNestedPass<func::FuncOp>(createGpuAsyncRegionPass()); +} + +//===----------------------------------------------------------------------===// +// GPUModule-specific stuff. 
+//===----------------------------------------------------------------------===// +void buildGPUPassPipeline(OpPassManager &pm, + const mlir::gpu::GPUToXeVMPipelineOptions &options) { + if (options.xegpuOpLevel == "workgroup") { + pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUWgToSgDistribute()); + pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass()); + pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUBlocking()); + pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass()); + pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass()); + } + if (options.xegpuOpLevel == "subgroup" || + options.xegpuOpLevel == "workgroup") { + pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUPropagateLayout()); + pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUSubgroupDistribute()); + pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass()); + pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass()); + pm.addNestedPass<gpu::GPUModuleOp>(createLoopInvariantCodeMotionPass()); + pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass()); + pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUVectorLinearize()); + } + pm.addNestedPass<gpu::GPUModuleOp>(createConvertMathToXeVM()); + pm.addNestedPass<gpu::GPUModuleOp>(createConvertXeGPUToXeVMPass()); + { + ConvertGpuOpsToLLVMSPVOpsOptions gpuToLLVMSPVOptions; + gpuToLLVMSPVOptions.use64bitIndex = options.use64bitIndex; + pm.addNestedPass<gpu::GPUModuleOp>( + createConvertGpuOpsToLLVMSPVOps(gpuToLLVMSPVOptions)); + } + pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass()); + pm.addNestedPass<gpu::GPUModuleOp>(createReconcileUnrealizedCastsPass()); +} + +//===----------------------------------------------------------------------===// +// Post-GPU pipeline for both Host and GPU. +//===----------------------------------------------------------------------===// +void buildPostGPUCommonPassPipeline( + OpPassManager &pm, const mlir::gpu::GPUToXeVMPipelineOptions &options) { + // builtin.module scope passes. + pm.addPass(createSCFToControlFlowPass()); + pm.addPass(memref::createExpandStridedMetadataPass()); + { + GpuToLLVMConversionPassOptions gpuToLLVMOptions; + gpuToLLVMOptions.hostBarePtrCallConv = options.hostBarePtrCallConv; + gpuToLLVMOptions.kernelBarePtrCallConv = options.kernelBarePtrCallConv; + pm.addPass(createGpuToLLVMConversionPass(gpuToLLVMOptions)); + } + pm.addPass(createLowerAffinePass()); + pm.addPass(createConvertToLLVMPass()); + pm.addPass(createReconcileUnrealizedCastsPass()); + // gpu-module-to-binary + { + GpuModuleToBinaryPassOptions gpuToModuleBinOptions; + gpuToModuleBinOptions.compilationTarget = options.binaryFormat; + gpuToModuleBinOptions.cmdOptions = options.cmdOptions; + pm.addPass(createGpuModuleToBinaryPass(gpuToModuleBinOptions)); + } +} +} // namespace + +void mlir::gpu::buildLowerToXeVMPassPipeline( + OpPassManager &pm, const GPUToXeVMPipelineOptions &options) { + // Pre-GPU common pipelines. + buildPreGPUCommonPassPipeline(pm, options); + + // GPUModule-specific stuff. + buildGPUPassPipeline(pm, options); + + // Post-GPU pipeline for both Host and GPU. + buildPostGPUCommonPassPipeline(pm, options); +} + +void mlir::gpu::registerGPUToXeVMPipeline() { + PassPipelineRegistration<GPUToXeVMPipelineOptions>( + "gpu-lower-to-xevm-pipeline", + "The default GPU to XeVM lowering pipeline. 
It starts by lowering GPU " + "code to the " + "specified compilation target (default is fatbin) then lowers the host " + "code.", + buildLowerToXeVMPassPipeline); +} diff --git a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp index 88f531f..572b746 100644 --- a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp +++ b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Value.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" #include <numeric> @@ -118,8 +119,7 @@ bool WarpDistributionPattern::delinearizeLaneId( return false; sizes.push_back(large / small); } - if (std::accumulate(sizes.begin(), sizes.end(), 1, - std::multiplies<int64_t>()) != warpSize) + if (llvm::product_of(sizes) != warpSize) return false; AffineExpr s0, s1; diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt index ec581ac..cc66fac 100644 --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -8,11 +8,13 @@ add_mlir_dialect_library(MLIRLLVMDialect IR/LLVMMemorySlot.cpp IR/LLVMTypes.cpp IR/LLVMTypeSyntax.cpp + IR/LLVMDialectBytecode.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR DEPENDS + MLIRLLVMDialectBytecodeIncGen MLIRLLVMOpsIncGen MLIRLLVMTypesIncGen MLIRLLVMIntrinsicOpsIncGen diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 5d08ccc..3eae67f 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -29,6 +29,8 @@ #include "llvm/IR/DataLayout.h" #include "llvm/Support/Error.h" +#include "LLVMDialectBytecode.h" + #include <numeric> #include <optional> @@ -2824,6 +2826,20 @@ LogicalResult ShuffleVectorOp::verify() { return success(); } +// Folding for shufflevector op when v1 is single element 1D vector +// and the mask is a single zero. OpFoldResult will be v1 in this case. +OpFoldResult ShuffleVectorOp::fold(FoldAdaptor adaptor) { + // Check if operand 0 is a single element vector. + auto vecType = llvm::dyn_cast<VectorType>(getV1().getType()); + if (!vecType || vecType.getRank() != 1 || vecType.getNumElements() != 1) + return {}; + // Check if the mask is a single zero. + // Note: The mask is guaranteed to be non-empty. + if (getMask().size() != 1 || getMask()[0] != 0) + return {}; + return getV1(); +} + //===----------------------------------------------------------------------===// // Implementations for LLVM::LLVMFuncOp. //===----------------------------------------------------------------------===// @@ -4237,6 +4253,7 @@ void LLVMDialect::initialize() { // Support unknown operations because not all LLVM operations are registered. allowUnknownOperations(); declarePromisedInterface<DialectInlinerInterface, LLVMDialect>(); + detail::addBytecodeInterface(this); } #define GET_OP_CLASSES diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp new file mode 100644 index 0000000..41d1f80 --- /dev/null +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.cpp @@ -0,0 +1,154 @@ +//===- LLVMDialectBytecode.cpp - LLVM Bytecode Implementation -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "LLVMDialectBytecode.h" +#include "mlir/Bytecode/BytecodeImplementation.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include <type_traits> + +using namespace mlir; +using namespace mlir::LLVM; + +namespace { + +// Provide some forward declarations of the functions that will be generated by +// the include below. +static void write(DIExpressionElemAttr attribute, + DialectBytecodeWriter &writer); +static LogicalResult writeAttribute(Attribute attribute, + DialectBytecodeWriter &writer); + +//===--------------------------------------------------------------------===// +// Optional ArrayRefs +// +// Note that both the writer and reader functions consider attributes to be +// optional. This is because the attribute may be present or empty. +//===--------------------------------------------------------------------===// + +template <class EntryTy> +static void writeOptionalArrayRef(DialectBytecodeWriter &writer, + ArrayRef<EntryTy> storage) { + if (storage.empty()) { + writer.writeOwnedBool(false); + return; + } + + writer.writeOwnedBool(true); + writer.writeList(storage, [&](EntryTy val) { + if constexpr (std::is_base_of_v<Attribute, EntryTy>) { + (void)writer.writeOptionalAttribute(val); + } else if constexpr (std::is_integral_v<EntryTy>) { + (void)writer.writeVarInt(val); + } else { + static_assert(true, "EntryTy not supported"); + } + }); +} + +template <class EntryTy> +static LogicalResult readOptionalArrayRef(DialectBytecodeReader &reader, + SmallVectorImpl<EntryTy> &storage) { + bool isPresent = false; + if (failed(reader.readBool(isPresent))) + return failure(); + // Nothing to do here, the array is empty. 
+ if (!isPresent) + return success(); + + auto readEntry = [&]() -> FailureOr<EntryTy> { + EntryTy temp; + if constexpr (std::is_base_of_v<Attribute, EntryTy>) { + if (succeeded(reader.readOptionalAttribute(temp))) + return temp; + } else if constexpr (std::is_integral_v<EntryTy>) { + if (succeeded(reader.readVarInt(temp))) + return temp; + } else { + static_assert(true, "EntryTy not supported"); + } + return failure(); + }; + + return reader.readList(storage, readEntry); +} + +//===--------------------------------------------------------------------===// +// Optional integral types +//===--------------------------------------------------------------------===// + +template <class EntryTy> +static void writeOptionalInt(DialectBytecodeWriter &writer, + std::optional<EntryTy> storage) { + static_assert(std::is_integral_v<EntryTy>, + "EntryTy must be an integral type"); + EntryTy val = storage.value_or(0); + writer.writeVarIntWithFlag(val, storage.has_value()); +} + +template <class EntryTy> +static LogicalResult readOptionalInt(DialectBytecodeReader &reader, + std::optional<EntryTy> &storage) { + static_assert(std::is_integral_v<EntryTy>, + "EntryTy must be an integral type"); + uint64_t result = 0; + bool flag = false; + if (failed(reader.readVarIntWithFlag(result, flag))) + return failure(); + if (flag) + storage = static_cast<EntryTy>(result); + else + storage = std::nullopt; + return success(); +} + +//===--------------------------------------------------------------------===// +// Tablegen generated bytecode functions +//===--------------------------------------------------------------------===// + +#include "mlir/Dialect/LLVMIR/LLVMDialectBytecode.cpp.inc" + +//===--------------------------------------------------------------------===// +// LLVMDialectBytecodeInterface +//===--------------------------------------------------------------------===// + +/// This class implements the bytecode interface for the LLVM dialect. +struct LLVMDialectBytecodeInterface : public BytecodeDialectInterface { + LLVMDialectBytecodeInterface(Dialect *dialect) + : BytecodeDialectInterface(dialect) {} + + // Attributes + Attribute readAttribute(DialectBytecodeReader &reader) const override { + return ::readAttribute(getContext(), reader); + } + + LogicalResult writeAttribute(Attribute attr, + DialectBytecodeWriter &writer) const override { + return ::writeAttribute(attr, writer); + } + + // Types + Type readType(DialectBytecodeReader &reader) const override { + return ::readType(getContext(), reader); + } + + LogicalResult writeType(Type type, + DialectBytecodeWriter &writer) const override { + return ::writeType(type, writer); + } +}; +} // namespace + +void LLVM::detail::addBytecodeInterface(LLVMDialect *dialect) { + dialect->addInterfaces<LLVMDialectBytecodeInterface>(); +} diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h new file mode 100644 index 0000000..1a17cb4 --- /dev/null +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialectBytecode.h @@ -0,0 +1,27 @@ +//===- LLVMDialectBytecode.h - LLVM Bytecode Implementation -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header defines hooks into the LLVM dialect bytecode +// implementation. 
+// +//===----------------------------------------------------------------------===// + +#ifndef LIB_MLIR_DIALECT_LLVM_IR_LLVMDIALECTBYTECODE_H +#define LIB_MLIR_DIALECT_LLVM_IR_LLVMDIALECTBYTECODE_H + +namespace mlir::LLVM { +class LLVMDialect; + +namespace detail { +/// Add the interfaces necessary for encoding the LLVM dialect components in +/// bytecode. +void addBytecodeInterface(LLVMDialect *dialect); +} // namespace detail +} // namespace mlir::LLVM + +#endif // LIB_MLIR_DIALECT_LLVM_IR_LLVMDIALECTBYTECODE_H diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index 01a16ce..ac35eea 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -134,10 +134,10 @@ static void printExtTypeParams(AsmPrinter &p, ArrayRef<Type> typeParams, /// These are unused for now. /// TODO: Move over to these once more types have been migrated to TypeDef. -LLVM_ATTRIBUTE_UNUSED static OptionalParseResult +[[maybe_unused]] static OptionalParseResult generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value); -LLVM_ATTRIBUTE_UNUSED static LogicalResult -generatedTypePrinter(Type def, AsmPrinter &printer); +[[maybe_unused]] static LogicalResult generatedTypePrinter(Type def, + AsmPrinter &printer); #include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc" diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 5edcc40b..2a8c330 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -309,6 +309,17 @@ LogicalResult ConvertBF16x2ToF8x2Op::verify() { return success(); } +LogicalResult ConvertF32x2ToF4x2Op::verify() { + mlir::MLIRContext *ctx = getContext(); + + if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy())) + return emitOpError("Only ") + << mlir::Float4E2M1FNType::get(ctx) + << " type is supported for conversions from f32x2 to f4x2."; + + return success(); +} + LogicalResult BulkStoreOp::verify() { if (getInitVal() != 0) return emitOpError("only 0 is supported for initVal, got ") << getInitVal(); @@ -787,6 +798,26 @@ LogicalResult MmaOp::verify() { " attribute"); } + // Validate layout combinations. According to the operation description, most + // MMA operations require layoutA=row and layoutB=col. Only m8n8k4 with f16 + // can use other layout combinations. + bool isM8N8K4_F16 = + (mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 && + getMultiplicandAPtxType() == MMATypes::f16); + + if (!isM8N8K4_F16) { + // For all other shapes/types, layoutA must be row and layoutB must be col + if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) { + return emitOpError("requires layoutA = #nvvm.mma_layout<row> and " + "layoutB = #nvvm.mma_layout<col> for shape <") + << mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2] + << "> with element types " + << stringifyEnum(*getMultiplicandAPtxType()) << " and " + << stringifyEnum(*getMultiplicandBPtxType()) + << ". Only m8n8k4 with f16 supports other layouts."; + } + } + return success(); } @@ -2047,6 +2078,23 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd, } } +NVVM::IDArgPair +ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op, + LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(op.getA())); + args.push_back(mt.lookupValue(op.getB())); + + bool hasRelu = op.getRelu(); + + llvm::Intrinsic::ID intId = + hasRelu ? 
llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite + : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite; + + return {intId, std::move(args)}; +} + #define GET_F32x2_TO_F6x2_ID(type, has_relu) \ has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \ : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite @@ -2306,6 +2354,32 @@ static void nvvmInferResultRanges(Operation *op, Value result, } } +/// Verify the range attribute satisfies LLVM ConstantRange constructor +/// requirements for NVVM SpecialRangeableRegisterOp. +static LogicalResult +verifyConstantRangeAttr(Operation *op, + std::optional<LLVM::ConstantRangeAttr> rangeAttr) { + if (!rangeAttr) + return success(); + + const llvm::APInt &lower = rangeAttr->getLower(); + const llvm::APInt &upper = rangeAttr->getUpper(); + + // Check LLVM ConstantRange constructor condition + if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) { + unsigned bitWidth = lower.getBitWidth(); + llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth); + llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth); + return op->emitOpError( + "invalid range attribute: Lower == Upper, but they aren't min (") + << llvm::toString(minVal, 10, false) << ") or max (" + << llvm::toString(maxVal, 10, false) + << ") value! This is an invalid constant range."; + } + + return success(); +} + static llvm::Value *getAsPackedI32(llvm::Value *arg, llvm::IRBuilderBase &builder) { return builder.CreateBitCast(arg, diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp index 17371ec..6d54bb6 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp @@ -23,6 +23,7 @@ #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" +#include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; @@ -180,6 +181,15 @@ void RawBufferAtomicUMinOp::print(mlir::OpAsmPrinter &p) { // ROCDLDialect initialization, type parsing, and registration. //===----------------------------------------------------------------------===// +namespace { +struct ROCDLInlinerInterface final : DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { + return true; + } +}; +} // namespace + // TODO: This should be the llvm.rocdl dialect once this is supported. void ROCDLDialect::initialize() { addOperations< @@ -194,6 +204,7 @@ void ROCDLDialect::initialize() { // Support unknown operations because not all ROCDL operations are registered. 
   allowUnknownOperations();
+  addInterfaces<ROCDLInlinerInterface>();
   declarePromisedInterface<gpu::TargetAttrInterface, ROCDLTargetAttr>();
 }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index c477c6c..dcc1ef9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -315,7 +315,8 @@ bool mlir::linalg::detail::isContractionBody(
   Value yielded = getSourceSkipUnary(terminator->getOperand(0));
 
   Operation *reductionOp = yielded.getDefiningOp();
-  if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) {
+  if (!reductionOp || reductionOp->getNumResults() != 1 ||
+      reductionOp->getNumOperands() != 2) {
     errs << "expected reduction op to be binary";
     return false;
   }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 59013a2..cbc565b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5272,11 +5272,18 @@ ArrayRef<int64_t> PackOp::getAllOuterDims() {
 
 SmallVector<int64_t> PackOp::getTiledOuterDims() {
   auto innerDimsPos = getInnerDimsPos();
-  auto packedShape = getDestType().getShape();
+  SmallVector<int64_t> outerDims(getAllOuterDims());
   SmallVector<int64_t> res;
 
+  // Recover the original order of the outer dims.
+  SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
+  invertPermutationVector(outerDimPermInv);
+  if (!outerDimPermInv.empty())
+    applyPermutationToVector(outerDims, outerDimPermInv);
+
+  // Collect the outer dims corresponding to the tiled inner dims.
   for (auto index : innerDimsPos)
-    res.push_back(packedShape[index]);
+    res.push_back(outerDims[index]);
 
   return res;
 }
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index dd9b4c2..9a8a63e 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -576,6 +576,86 @@ transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
 // FuseOp
 //===----------------------------------------------------------------------===//
 
+void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
+                              TypeRange loopTypes, Value target,
+                              ArrayRef<int64_t> staticTileSizes,
+                              ArrayRef<int64_t> staticTileInterchange,
+                              bool applyCleanup, bool useForall) {
+  return build(
+      builder, result, loopTypes,
+      /*target=*/target,
+      /*mixedTileSizes=*/
+      getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
+      /*mixedTileInterchange=*/
+      getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)),
+      applyCleanup, useForall);
+}
+
+void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
+                              Value target, ArrayRef<int64_t> staticTileSizes,
+                              ArrayRef<int64_t> staticTileInterchange,
+                              bool applyCleanup, bool useForall) {
+  return build(
+      builder, result,
+      /*target=*/target,
+      /*mixedTileSizes=*/
+      getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
+      /*mixedTileInterchange=*/
+      getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)),
+      applyCleanup, useForall);
+}
+
+void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
+                              Value target,
+                              ArrayRef<OpFoldResult> mixedTileSizes,
+                              ArrayRef<OpFoldResult> mixedTileInterchange,
+                              bool applyCleanup, bool useForall) {
+  // Loop types are automatically splat by the callee, setting up one is
+  // enough.
+ SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>()); + build(builder, result, loopTypes, target, mixedTileSizes, + mixedTileInterchange, applyCleanup, useForall); +} + +void transform::FuseOp::build(OpBuilder &builder, OperationState &result, + TypeRange loopTypes, Value target, + ArrayRef<OpFoldResult> mixedTileSizes, + ArrayRef<OpFoldResult> mixedTileInterchange, + bool applyCleanup, bool useForall) { + SmallVector<int64_t> staticTileSizes; + SmallVector<Value> dynamicTileSizes; + dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes); + SmallVector<int64_t> staticTileInterchange; + SmallVector<Value> dynamicTileInterchange; + dispatchIndexOpFoldResults(mixedTileInterchange, dynamicTileInterchange, + staticTileInterchange); + // Call the default builder which sets up the proper operands segment sizes + // attributes for multiple variadic operands. In the absence of this, + // horrible bugs ensue. + auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes); + auto staticTileInterchangeAttr = + builder.getDenseI64ArrayAttr(staticTileInterchange); + unsigned numExpectedLoops = + useForall ? 1 : staticTileSizes.size() - llvm::count(staticTileSizes, 0); + SmallVector<Type> resultTypes; + resultTypes.reserve(numExpectedLoops); + assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) && + "expected one loop type or as many as loops"); + if (loopTypes.size() == 1) + resultTypes.append(numExpectedLoops, loopTypes[0]); + else + llvm::append_range(resultTypes, loopTypes); + build(builder, result, /*transformed=*/target.getType(), + /*loops=*/resultTypes, + /*target=*/target, + /*tile_sizes=*/dynamicTileSizes, + /*tile_interchange=*/dynamicTileInterchange, + /*static_tile_sizes=*/staticTileSizesAttr, + /*static_tile_interchange=*/staticTileInterchangeAttr, + /*apply_cleanup=*/applyCleanup, + /*use_forall=*/useForall); +} + /// Apply a tiling transformation to all payload ops and store both the /// tiled operation as well as the created tile loops. template <typename Range> @@ -630,13 +710,25 @@ DiagnosedSilenceableFailure transform::FuseOp::apply(transform::TransformRewriter &rewriter, mlir::transform::TransformResults &transformResults, mlir::transform::TransformState &state) { - SmallVector<int64_t> tileSizes = - extractFromIntegerArrayAttr<int64_t>(getTileSizes()); - SmallVector<int64_t> tileInterchange = - extractFromIntegerArrayAttr<int64_t>(getTileInterchange()); + auto transformOp = cast<TransformOpInterface>(getOperation()); + + SmallVector<int64_t> tileSizes; + DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults( + state, transformOp, getMixedTileSizes(), tileSizes); + if (!status.succeeded()) + return status; + SmallVector<int64_t> tileInterchange; + status = reifyMixedParamAndHandleResults( + state, transformOp, getMixedTileInterchange(), tileInterchange); + if (!status.succeeded()) + return status; scf::SCFTilingOptions tilingOptions; tilingOptions.interchangeVector = tileInterchange; + bool useForall = getUseForall(); + tilingOptions.setLoopType(useForall + ? scf::SCFTilingOptions::LoopType::ForallOp + : scf::SCFTilingOptions::LoopType::ForOp); SmallVector<OpFoldResult> tileSizesOfr = getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); tilingOptions = tilingOptions.setTileSizes(tileSizesOfr); @@ -652,9 +744,11 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter, tileAndFuseOptions.cleanupPatterns = std::move(patterns); } + size_t numLoops = + useForall ? 
1 : tileSizes.size() - llvm::count(tileSizes, 0); LogicalResult result = applyTilingToAll( - rewriter, getOperation(), state.getPayloadOps(getTarget()), - tileSizes.size() - llvm::count(tileSizes, 0), transformResults, + rewriter, getOperation(), state.getPayloadOps(getTarget()), numLoops, + transformResults, [&](TilingInterface tilingInterfaceOp) -> FailureOr<scf::SCFTileAndFuseResult> { return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp, @@ -665,24 +759,51 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter, } LogicalResult transform::FuseOp::verify() { - SmallVector<int64_t> permutation = - extractFromIntegerArrayAttr<int64_t>(getTileInterchange()); - auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size())); - if (!std::is_permutation(sequence.begin(), sequence.end(), - permutation.begin(), permutation.end())) { - return emitOpError() << "expects interchange to be a permutation, found " - << getTileInterchange(); + auto iterspace_rank = getStaticTileSizes().size(); + ArrayRef<int64_t> permutation = getStaticTileInterchange(); + if (permutation.size() > iterspace_rank) + return emitOpError() + << "interchange length exceeds iteration space dimensions (" + << iterspace_rank << "), found " << getTileInterchange(); + SmallVector<bool> seen(iterspace_rank, false); + for (int64_t v : permutation) { + if (!ShapedType::isDynamic(v)) { + if (v < 0 || v >= static_cast<int64_t>(iterspace_rank)) + return emitOpError() << "expects interchange values to be in range [0, " + << iterspace_rank << "), found: " << v; + if (seen[v]) + return emitOpError() << "found duplicate interchange value: " << v; + seen[v] = true; + } } - SmallVector<int64_t> sizes = - extractFromIntegerArrayAttr<int64_t>(getTileSizes()); - size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0); + ArrayRef<int64_t> sizes = getStaticTileSizes(); + size_t numExpectedLoops = + getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0); if (numExpectedLoops != getNumResults() - 1) return emitOpError() << "expects " << numExpectedLoops << " loop results"; return success(); } +SmallVector<OpFoldResult> transform::FuseOp::getMixedTileSizes() { + return getMixedValues(getStaticTileSizes(), getTileSizes(), getContext()); +} + +SmallVector<OpFoldResult> transform::FuseOp::getMixedTileInterchange() { + return getMixedValues(getStaticTileInterchange(), getTileInterchange(), + getContext()); +} + +void transform::FuseOp::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + consumesHandle(getTargetMutable(), effects); + onlyReadsHandle(getTileSizesMutable(), effects); + onlyReadsHandle(getTileInterchangeMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + modifiesPayload(effects); +} + //===----------------------------------------------------------------------===// // FuseIntoContainingOp //===----------------------------------------------------------------------===// @@ -2336,26 +2457,24 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter, } // Set options. - TilingInterface paddedOp; PadTilingInterfaceOptions options; options.setPaddingValues(paddingValues) .setPaddingSizes(getMixedPaddingSizes()) .setPadToMultipleOf(getPadToMultipleOf()); - // Apply padding. 
- SmallVector<tensor::PadOp> newPadOps; - FailureOr<TilingInterface> maybePaddedOp = rewriteAsPaddedOp( - rewriter, cast<TilingInterface>(targetOp.getOperation()), options, - newPadOps); - if (failed(maybePaddedOp)) { + auto maybePadOps = rewriteAsPaddedOp( + rewriter, cast<TilingInterface>(targetOp.getOperation()), options); + if (failed(maybePadOps)) { auto diag = emitSilenceableError() << "failed to pad op"; diag.attachNote(target->getLoc()) << "target op"; return diag; } + const auto &[paddedOperands, paddedOp, slicedResults] = maybePadOps.value(); // Set transform results. - paddedOps.push_back(cast<TilingInterface>(maybePaddedOp->getOperation())); - padOps.append(newPadOps.begin(), newPadOps.end()); + paddedOps.push_back(paddedOp); + padOps.append(paddedOperands.begin(), paddedOperands.end()); + rewriter.replaceOp(targetOp.getOperation(), slicedResults); } results.set(cast<OpResult>(getPadded()), paddedOps); @@ -2903,10 +3022,10 @@ ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); } if (dynamicPointParseResult.has_value()) { - Type ChunkSizesType; + Type chunkSizesType; if (failed(*dynamicPointParseResult) || parser.parseComma() || - parser.parseType(ChunkSizesType) || - parser.resolveOperand(dynamicChunkSizes, ChunkSizesType, + parser.parseType(chunkSizesType) || + parser.resolveOperand(dynamicChunkSizes, chunkSizesType, result.operands)) { return failure(); } @@ -3278,9 +3397,9 @@ void transform::ContinuousTileSizesOp::getEffects( } static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op, - Type targetType, Type tile_sizes, + Type targetType, Type tileSizes, Type) { - printer.printFunctionalType(TypeRange{targetType}, TypeRange{tile_sizes}); + printer.printFunctionalType(TypeRange{targetType}, TypeRange{tileSizes}); } static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser, diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp index 0956c5d..3e787a2 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp @@ -95,10 +95,11 @@ static int64_t extractConstantMultiplier(AffineExpr expr) { /// - affine_map<(d0, d1) -> (d0 * 3 + d1)> /// In the future, more general interfaces can be devised to encode similar /// shape evolutions and map between an op and its operands. -SmallVector<OpFoldResult> linalg::computePaddedShape( - RewriterBase &rewriter, TypedValue<RankedTensorType> v, - AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes, - const PadTilingInterfaceOptions &options) { +SmallVector<OpFoldResult> +linalg::computePaddedShape(OpBuilder &builder, TypedValue<RankedTensorType> v, + AffineMap indexingMap, + ArrayRef<OpFoldResult> indexingSizes, + const PadTilingInterfaceOptions &options) { Location loc = v.getLoc(); SmallVector<OpFoldResult> paddedShape; auto tensorType = cast<RankedTensorType>(v.getType()); @@ -109,7 +110,7 @@ SmallVector<OpFoldResult> linalg::computePaddedShape( // "Full-rank" padding specification. SmallVector<OpFoldResult> paddingSizes = - getFullRankPaddingSizes(rewriter, indexingSizes, options); + getFullRankPaddingSizes(builder, indexingSizes, options); // For each dimension in the operand's shape, iterate over indexingSizes and // add the various term contributions. 
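A minimal sketch of how the builder-based entry points above can now be driven outside of a rewrite pattern, e.g. from a one-shot transform. The wrapper name `computePaddedShapeForOperand` and the pad-to-multiple-of interpretation of `paddingSizes` are assumptions for illustration only; `computeIndexingMapOpInterfacePaddedShape`, `PadTilingInterfaceOptions` and its chained setters, and `TilingInterface::getIterationDomain` are the pieces taken from this patch, with the usual `mlir`/`mlir::linalg` namespaces assumed in scope:

static FailureOr<SmallVector<OpFoldResult>>
computePaddedShapeForOperand(OpBuilder &builder, TilingInterface op,
                             OpOperand &operandToPad,
                             ArrayRef<OpFoldResult> paddingSizes) {
  // The iteration domain supplies the indexing sizes; linalg-style domains
  // are 0-offset/1-stride, which is what the callee asserts on.
  SmallVector<Range> iterationDomain = op.getIterationDomain(builder);

  // Hypothetical choice: treat paddingSizes as pad-to-multiple-of factors.
  PadTilingInterfaceOptions options;
  options.setPaddingSizes(paddingSizes).setPadToMultipleOf(true);

  // Resolves the operand's indexing map and forwards to computePaddedShape.
  return linalg::computeIndexingMapOpInterfacePaddedShape(
      builder, operandToPad, iterationDomain, options);
}
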
@@ -147,28 +148,27 @@ SmallVector<OpFoldResult> linalg::computePaddedShape( OpFoldResult paddingDimOfr; if (options.padToMultipleOf) { AffineExpr d0, s0; - bindDims(rewriter.getContext(), d0); - bindSymbols(rewriter.getContext(), s0); + bindDims(builder.getContext(), d0); + bindSymbols(builder.getContext(), s0); AffineMap ceilMap = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0); AffineMap composedMap = projectedMap.compose(ceilMap); paddingDimOfr = affine::makeComposedFoldedAffineApply( - rewriter, loc, composedMap, - {indexingSizes[paddingDim], paddingSize}, + builder, loc, composedMap, {indexingSizes[paddingDim], paddingSize}, /*composeAffineMin=*/true); } else { // Otherwise just set to paddingSize. paddingDimOfr = affine::makeComposedFoldedAffineApply( - rewriter, loc, projectedMap, paddingSize); + builder, loc, projectedMap, paddingSize); } // Adjust for the maximum accessed index, which is (paddingSize - 1) * // multiplier. AffineExpr d0; - bindDims(rewriter.getContext(), d0); + bindDims(builder.getContext(), d0); int64_t multiplier = extractConstantMultiplier(projectedMap.getResult(0)); AffineMap subtractMap = AffineMap::get(1, 0, d0 - multiplier); OpFoldResult maxAccessIdx = affine::makeComposedFoldedAffineApply( - rewriter, loc, subtractMap, {paddingDimOfr}); + builder, loc, subtractMap, {paddingDimOfr}); terms.push_back(maxAccessIdx); LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n"); @@ -177,19 +177,19 @@ SmallVector<OpFoldResult> linalg::computePaddedShape( // If there are no terms, just return the dim. if (terms.empty()) { paddedShape[resultIndex] = - createFoldedDimOp(rewriter, loc, v, resultIndex); + createFoldedDimOp(builder, loc, v, resultIndex); continue; } // Sum individual terms' contributions. SmallVector<AffineExpr> dims(terms.size()); - bindDimsList(rewriter.getContext(), MutableArrayRef{dims}); + bindDimsList(builder.getContext(), MutableArrayRef{dims}); AffineExpr sumExpr = dims.front(); for (unsigned i = 1; i < dims.size(); ++i) sumExpr = sumExpr + dims[i]; // Add 1 to the maximum accessed index and get the final padded size. 
- OpFoldResult paddedDimOfr = affine::makeComposedFoldedAffineApply( - rewriter, loc, sumExpr + 1, terms); + OpFoldResult paddedDimOfr = + affine::makeComposedFoldedAffineApply(builder, loc, sumExpr + 1, terms); paddedShape[resultIndex] = paddedDimOfr; } @@ -198,7 +198,7 @@ SmallVector<OpFoldResult> linalg::computePaddedShape( FailureOr<SmallVector<OpFoldResult>> linalg::computeIndexingMapOpInterfacePaddedShape( - RewriterBase &rewriter, OpOperand &operandToPad, + OpBuilder &builder, OpOperand &operandToPad, ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options) { auto transferOp = llvm::dyn_cast<IndexingMapOpInterface>(operandToPad.getOwner()); @@ -206,9 +206,9 @@ linalg::computeIndexingMapOpInterfacePaddedShape( return failure(); // clang-format off - assert(llvm::all_of(iterationDomain, [&rewriter](Range r) { - return r.offset == OpFoldResult(rewriter.getIndexAttr(0)) && - r.stride == OpFoldResult(rewriter.getIndexAttr(1)); + assert(llvm::all_of(iterationDomain, [&builder](Range r) { + return r.offset == OpFoldResult(builder.getIndexAttr(0)) && + r.stride == OpFoldResult(builder.getIndexAttr(1)); }) && "expected 0-offset 1-stride loop ranges"); // clang-format on SmallVector<OpFoldResult> loopUpperBounds; @@ -218,13 +218,13 @@ linalg::computeIndexingMapOpInterfacePaddedShape( AffineMap indexingMap = transferOp.getMatchingIndexingMap(&operandToPad); return computePaddedShape( - rewriter, cast<TypedValue<RankedTensorType>>(operandToPad.get()), + builder, cast<TypedValue<RankedTensorType>>(operandToPad.get()), indexingMap, loopUpperBounds, options); } /// Pad a single operand to `paddedShape` using `paddingValueAttr` as padding /// Value. -static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad, +static Value padOperand(OpBuilder &builder, TilingInterface opToPad, TypedValue<RankedTensorType> v, ArrayRef<OpFoldResult> paddedShape, Attribute paddingValueAttr) { @@ -232,15 +232,15 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad, if (auto complexTy = dyn_cast<ComplexType>(getElementTypeOrSelf(v.getType()))) { if (auto complexAttr = dyn_cast<ArrayAttr>(paddingValueAttr)) { - paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(), + paddingValue = complex::ConstantOp::create(builder, opToPad.getLoc(), complexTy, complexAttr); } } else if (isa<ub::PoisonAttr>(paddingValueAttr)) { - paddingValue = ub::PoisonOp::create(rewriter, opToPad.getLoc(), + paddingValue = ub::PoisonOp::create(builder, opToPad.getLoc(), getElementTypeOrSelf(v.getType())); } else if (auto typedAttr = dyn_cast<TypedAttr>(paddingValueAttr)) { paddingValue = - arith::ConstantOp::create(rewriter, opToPad.getLoc(), typedAttr); + arith::ConstantOp::create(builder, opToPad.getLoc(), typedAttr); } assert(paddingValue && "failed to create value from padding attribute"); @@ -259,49 +259,48 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad, RankedTensorType::get(tensorShape, getElementTypeOrSelf(v)); LLVM_DEBUG(DBGS() << "--SUCCESS, makeComposedPadHighOp with type: " << paddedTensorType); - return makeComposedPadHighOp(rewriter, opToPad.getLoc(), paddedTensorType, v, + return makeComposedPadHighOp(builder, opToPad.getLoc(), paddedTensorType, v, paddingValue, /*nofold=*/false, dynDims); } -FailureOr<TilingInterface> linalg::rewriteAsPaddedOp( - RewriterBase &rewriter, TilingInterface opToPad, - const PadTilingInterfaceOptions &constOptions, - SmallVector<tensor::PadOp> &padOps, +FailureOr<PadTilingInterfaceResult> 
linalg::rewriteAsPaddedOp( + OpBuilder &builder, TilingInterface toPad, + PadTilingInterfaceOptions options, const PadSizeComputationFunction &computePaddingSizeFun) { - LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n"); + LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << toPad << "\n"); + SmallVector<tensor::PadOp> padOps; + Location loc = toPad.getLoc(); - Location loc = opToPad.getLoc(); - PadTilingInterfaceOptions options(constOptions); // Allow inference of pad values if they are not explicitly specified. // TODO: be mindful about the value depending on the actual operation. if (options.paddingValues.empty()) { - SmallVector<Type> types(opToPad->getOperandTypes()); - llvm::append_range(types, opToPad->getResultTypes()); + SmallVector<Type> types(toPad->getOperandTypes()); + llvm::append_range(types, toPad->getResultTypes()); for (Type t : types) { options.paddingValues.push_back( - rewriter.getZeroAttr(getElementTypeOrSelf(t))); + builder.getZeroAttr(getElementTypeOrSelf(t))); } } - if (llvm::any_of(opToPad->getOperands(), + if (llvm::any_of(toPad->getOperands(), [](Value v) { return isa<MemRefType>(v.getType()); })) { - return rewriter.notifyMatchFailure(opToPad, - "expected operation on tensors"); + LLVM_DEBUG(DBGS() << "Not an operation on tensors: FAIL\n"); + return failure(); } - OpBuilder::InsertionGuard g(rewriter); - // Set IP after opToPad because we also take the dims of opToPad's output. - rewriter.setInsertionPointAfter(opToPad); + OpBuilder::InsertionGuard g(builder); + // Set IP after toPad because we also take the dims of toPad's output. + builder.setInsertionPointAfter(toPad); // 1. Get the loopUpperBounds from the TilingInterface. - SmallVector<Range> iterationDomain = opToPad.getIterationDomain(rewriter); + SmallVector<Range> iterationDomain = toPad.getIterationDomain(builder); // 2. For each operand. SmallVector<Value> newOperands; - newOperands.reserve(opToPad->getNumOperands()); - for (OpOperand &opOperand : opToPad->getOpOperands()) { + newOperands.reserve(toPad->getNumOperands()); + for (OpOperand &opOperand : toPad->getOpOperands()) { Value operand = opOperand.get(); - LLVM_DEBUG(DBGS() << "--start padding oprd: " << operand << "\n"); + LLVM_DEBUG(DBGS() << "--start padding operand: " << operand << "\n"); // 2.a. Skip scalar-like operands. Type operandType = operand.getType(); @@ -311,30 +310,31 @@ FailureOr<TilingInterface> linalg::rewriteAsPaddedOp( newOperands.push_back(operand); continue; } + // 2.a. Compute padded shape. FailureOr<SmallVector<OpFoldResult>> maybePaddedShape = - computePaddingSizeFun(rewriter, opOperand, iterationDomain, options); + computePaddingSizeFun(builder, opOperand, iterationDomain, options); if (failed(maybePaddedShape)) { - return rewriter.notifyMatchFailure(opToPad, "could not pad op"); + LLVM_DEBUG(DBGS() << "Could not get padded shape of operand: FAIL\n"); + return failure(); } // 2.b. Expect proper `paddingValues`. // TODO: we may want to allow garbage padding in the future, in which case // we would just not assert. if (opOperand.getOperandNumber() >= options.paddingValues.size()) { - return rewriter.notifyMatchFailure(opToPad, - "--no padding value specified"); + LLVM_DEBUG(DBGS() << "Too few padding values specified: FAIL\n"); + return failure(); } Attribute paddingValueAttr = options.paddingValues[opOperand.getOperandNumber()]; // 2.c. Perform actual padding. 
- Value paddedOperand = padOperand( - rewriter, opToPad, cast<TypedValue<RankedTensorType>>(operand), - *maybePaddedShape, paddingValueAttr); + Value paddedOperand = + padOperand(builder, toPad, cast<TypedValue<RankedTensorType>>(operand), + *maybePaddedShape, paddingValueAttr); LLVM_DEBUG(DBGS() << "--done padding operand: " << paddedOperand << "\n"); - // 2.d. Perform actual padding. newOperands.push_back(paddedOperand); if (auto padOp = paddedOperand.getDefiningOp<tensor::PadOp>()) padOps.push_back(padOp); @@ -342,38 +342,34 @@ FailureOr<TilingInterface> linalg::rewriteAsPaddedOp( // 3. Form the resulting tensor::ExtractSliceOp. ReifiedRankedShapedTypeDims reifiedResultShapes; - if (failed(reifyResultShapes(rewriter, opToPad, reifiedResultShapes))) { - LLVM_DEBUG(DBGS() << "--failed to reify result shapes -> FAIL\n"); - return rewriter.notifyMatchFailure(opToPad, - "failed to reify result shapes"); + if (failed(reifyResultShapes(builder, toPad, reifiedResultShapes))) { + LLVM_DEBUG(DBGS() << "Failed to reify result shapes: FAIL\n"); + return failure(); } - assert(reifiedResultShapes.size() == opToPad->getNumResults() && + assert(reifiedResultShapes.size() == toPad->getNumResults() && "expected same number of results"); - // Clone `opToPad` to operate on the statically padded shapes. + // Clone `toPad` to operate on the statically padded shapes. auto resultTensorTypes = - ValueRange(newOperands).take_back(opToPad->getNumResults()).getTypes(); - // clone **should** properly notify the rewriter. + ValueRange(newOperands).take_back(toPad->getNumResults()).getTypes(); + // clone **should** properly notify the builder. TilingInterface paddedOp = - clone(rewriter, opToPad, resultTensorTypes, newOperands); + clone(builder, toPad, resultTensorTypes, newOperands); LLVM_DEBUG(DBGS() << "--cloned padded op: " << paddedOp << "\n"); - // Recover the slice out of the new static results. This keeps the original - // opToPad around because it uses the dims of the original results. + // Recover the slice out of the new static results. 
SmallVector<Value> paddedSubtensorResults; - paddedSubtensorResults.reserve(opToPad->getNumResults()); + paddedSubtensorResults.reserve(toPad->getNumResults()); for (const auto &en : llvm::enumerate(paddedOp->getResults())) { Value paddedResult = en.value(); int64_t resultNumber = en.index(); int64_t rank = cast<RankedTensorType>(paddedResult.getType()).getRank(); - SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0)); - SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1)); + SmallVector<OpFoldResult> offsets(rank, builder.getIndexAttr(0)); + SmallVector<OpFoldResult> strides(rank, builder.getIndexAttr(1)); paddedSubtensorResults.push_back(tensor::ExtractSliceOp::create( - rewriter, loc, paddedResult, offsets, reifiedResultShapes[resultNumber], + builder, loc, paddedResult, offsets, reifiedResultShapes[resultNumber], strides)); } - rewriter.replaceOp(opToPad, paddedSubtensorResults); - - return paddedOp; + return PadTilingInterfaceResult{padOps, paddedOp, paddedSubtensorResults}; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp index f277c5f..0ae2a9c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp @@ -266,9 +266,8 @@ struct StructuredOpShardingInterface LinalgOp linalgOp = llvm::cast<LinalgOp>(op); SmallVector<utils::IteratorType> iteratorTypes = linalgOp.getIteratorTypesArray(); - unsigned reductionItersCount = std::accumulate( - iteratorTypes.begin(), iteratorTypes.end(), 0, - [](unsigned count, utils::IteratorType iter) { + unsigned reductionItersCount = llvm::accumulate( + iteratorTypes, 0u, [](unsigned count, utils::IteratorType iter) { return count + (iter == utils::IteratorType::reduction); }); shard::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 0dac688..eb2d825 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -1134,22 +1134,45 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape, LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( linalg::PackOp packOp, PatternRewriter &rewriter) const { - // TODO: support the case that outer dimensions are not all 1s. A - // tensor.expand_shape will be generated in this case. - if (llvm::any_of(packOp.getAllOuterDims(), + if (llvm::any_of(packOp.getTiledOuterDims(), [](int64_t dim) { return dim != 1; })) { return rewriter.notifyMatchFailure( packOp, "not all outer dimensions of the result are 1s"); } + ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos(); + auto outerDimsPerm = packOp.getOuterDimsPerm(); + + // Verify that there are no: + // * non-unit + un-tiled-outer-dims, + // that are permuted. Supporting such cases would require refining the logic + // that generates the Transpose Op. + if (!llvm::all_of(outerDimsPerm, [&innerDimsPos, &packOp](int64_t dim) { + static int prev = 0; + // Skip tiled dims - these can be permuted. + if (llvm::is_contained(innerDimsPos, dim)) + return true; + + // Check whether this dim has been permuted. Permuting unit dims is fine + // as that's effectively a no-op. 
+ if (dim < prev && (packOp.getType().getShape()[prev] != 1 || + packOp.getType().getShape()[dim] != 1)) + return false; + + prev = dim; + return true; + })) { + return rewriter.notifyMatchFailure( + packOp, "At least one non-unit and un-tiled outer dim is permuted, " + "this is not supported ATM!"); + } + Attribute zeroIdxAttr = rewriter.getIndexAttr(0); Attribute oneIdxAttr = rewriter.getIndexAttr(1); Location loc = packOp.getLoc(); int64_t srcRank = packOp.getSourceRank(); int64_t destRank = packOp.getDestRank(); - ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos(); - int64_t numberOfTiles = innerDimsPos.size(); // 1. Get the input that is going to be packed. If the input requires padding, // add a padding operation and return that as the input. @@ -1160,10 +1183,13 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( // %transposed_tile = linalg.transpose ins(%source_or_padded_source), // outs(%init) // Assumptions made: - // - All outer dims are 1 - the corresponding transposition order doesn't - // matter, but requires all dim indices to be present. + // - All tiled outer dims are 1 - the corresponding transposition order + // doesn't matter, but requires all dim indices to be present. + // - Un-tiled outer dims remain un-permuted. - // 2.1 Get the permutation for linalg.transpose + // 2.1 Get the permutation for linalg.transpose: + // [ untiled-dims, inner-dims-pos ] + // Note, this logic assumes that the untiled dims are not permuted. SmallVector<int64_t> srcPermForTranspose; for (int64_t i = 0; i < srcRank; i++) { // We assume the `k` dimensions of the inner dim position, where `k` is the @@ -1179,9 +1205,21 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( } srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end()); - // 2.2 Create the init tensor for linalg.transpose with the correct shape - SmallVector<OpFoldResult> shapeForEmptyOp(srcRank - numberOfTiles, - oneIdxAttr); + // 2.2 Create the init tensor for linalg.transpose with the correct shape: + // [ untiled-dims, tiled-dims ] + ShapedType inputTy = cast<ShapedType>(input.getType()); + SmallVector<OpFoldResult> shapeForEmptyOp; + for (int64_t i = 0; i < srcRank; i++) { + if (llvm::is_contained(innerDimsPos, i)) { + // The tiled dims are appended after this loop. + continue; + } + if (inputTy.isStaticDim(i)) + shapeForEmptyOp.push_back(rewriter.getIndexAttr(inputTy.getShape()[i])); + else + shapeForEmptyOp.emplace_back( + tensor::DimOp::create(rewriter, loc, input, i).getResult()); + } shapeForEmptyOp.append(packOp.getMixedTiles()); // getMixedTiles() may contain Values pointing to constant ops, not the @@ -1204,25 +1242,36 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty, srcPermForTranspose); - // 3. Insert the inner tile to the destination: + // 3. Insert the inner tile into the destination tensor: // %inserted_tile = tensor.insert_slice(%transposed_tile) - SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr); - SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr); - // Outer dims are all 1s! - SmallVector<OpFoldResult> writeSizes(destRank - numberOfTiles, oneIdxAttr); - SmallVector<int64_t> writeShape; + + // Compute the sizes attribute: + // [ outer-dims, tile-sizes ] + // Note that the output from the transpose Op excludes the tiled outer dims. 
+ // However, given the assumption that: + // * all tiled outer dims == 1, + // we can just use a rank-expanding tensor.insert_slice. + SmallVector<OpFoldResult> writeSizes; + for (auto size : packOp.getAllOuterDims()) { + writeSizes.push_back(rewriter.getIndexAttr(size)); + } for (auto tileSize : packOp.getMixedTiles()) { - auto [tileSizeStatic, tileSizeOfr] = + auto [_, tileSizeOfr] = getSimplifiedOfrAndStaticSizePair(tileSize, rewriter); writeSizes.push_back(tileSizeOfr); - writeShape.push_back(tileSizeStatic); } - // 4. Replace tensor.packOp with tensor.insert_slice created above + // TODO: Add a constructor for tensor.insert_slice that doesn't require + // strides nor offsets. + SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr); + SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr); + auto insert = tensor::InsertSliceOp::create( rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets, writeSizes, writeStrides); + + // 4. Replace tensor.packOp with tensor.insert_slice created above rewriter.replaceOp(packOp, insert.getResult()); return success(); diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt index e25a012..1382c7ac 100644 --- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt @@ -5,7 +5,7 @@ add_mlir_dialect_library(MLIRMemRefDialect ValueBoundsOpInterfaceImpl.cpp ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect + ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRef/IR DEPENDS MLIRMemRefOpsIncGen @@ -18,6 +18,7 @@ add_mlir_dialect_library(MLIRMemRefDialect MLIRDialectUtils MLIRInferIntRangeCommon MLIRInferIntRangeInterface + MLIRInferStridedMetadataInterface MLIRInferTypeOpInterface MLIRIR MLIRMemOpInterfaces diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index e9bdcda..94947b7 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2158,11 +2158,45 @@ public: return success(); } }; + +struct ReinterpretCastOpConstantFolder + : public OpRewritePattern<ReinterpretCastOp> { +public: + using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(ReinterpretCastOp op, + PatternRewriter &rewriter) const override { + unsigned srcStaticCount = llvm::count_if( + llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(), + op.getMixedStrides()), + [](OpFoldResult ofr) { return isa<Attribute>(ofr); }); + + SmallVector<OpFoldResult> offsets = {op.getConstifiedMixedOffset()}; + SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes(); + SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides(); + + // TODO: Using counting comparison instead of direct comparison because + // getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns + // IntegerAttrs, while constifyIndexValues (and therefore + // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs. 
+ if (srcStaticCount == + llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides), + [](OpFoldResult ofr) { return isa<Attribute>(ofr); })) + return failure(); + + auto newReinterpretCast = ReinterpretCastOp::create( + rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides); + + rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast); + return success(); + } +}; } // namespace void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context); + results.add<ReinterpretCastOpExtractStridedMetadataFolder, + ReinterpretCastOpConstantFolder>(context); } FailureOr<std::optional<SmallVector<Value>>> @@ -3437,6 +3471,65 @@ SubViewOp::bubbleDownCasts(OpBuilder &builder) { return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable()); } +void SubViewOp::inferStridedMetadataRanges( + ArrayRef<StridedMetadataRange> ranges, GetIntRangeFn getIntRange, + SetStridedMetadataRangeFn setMetadata, int32_t indexBitwidth) { + auto isUninitialized = + +[](IntegerValueRange range) { return range.isUninitialized(); }; + + // Bail early if any of the operands metadata is not ready: + SmallVector<IntegerValueRange> offsetOperands = + getIntValueRanges(getMixedOffsets(), getIntRange, indexBitwidth); + if (llvm::any_of(offsetOperands, isUninitialized)) + return; + + SmallVector<IntegerValueRange> sizeOperands = + getIntValueRanges(getMixedSizes(), getIntRange, indexBitwidth); + if (llvm::any_of(sizeOperands, isUninitialized)) + return; + + SmallVector<IntegerValueRange> stridesOperands = + getIntValueRanges(getMixedStrides(), getIntRange, indexBitwidth); + if (llvm::any_of(stridesOperands, isUninitialized)) + return; + + StridedMetadataRange sourceRange = + ranges[getSourceMutable().getOperandNumber()]; + if (sourceRange.isUninitialized()) + return; + + ArrayRef<ConstantIntRanges> srcStrides = sourceRange.getStrides(); + + // Get the dropped dims. + llvm::SmallBitVector droppedDims = getDroppedDims(); + + // Compute the new offset, strides and sizes. + ConstantIntRanges offset = sourceRange.getOffsets()[0]; + SmallVector<ConstantIntRanges> strides, sizes; + + for (size_t i = 0, e = droppedDims.size(); i < e; ++i) { + bool dropped = droppedDims.test(i); + // Compute the new offset. + ConstantIntRanges off = + intrange::inferMul({offsetOperands[i].getValue(), srcStrides[i]}); + offset = intrange::inferAdd({offset, off}); + + // Skip dropped dimensions. + if (dropped) + continue; + // Multiply the strides. + strides.push_back( + intrange::inferMul({stridesOperands[i].getValue(), srcStrides[i]})); + // Get the sizes. 
+ sizes.push_back(sizeOperands[i].getValue()); + } + + setMetadata(getResult(), + StridedMetadataRange::getRanked( + SmallVector<ConstantIntRanges>({std::move(offset)}), + std::move(sizes), std::move(strides))); +} + //===----------------------------------------------------------------------===// // TransposeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp index 49b7162..6f815ae 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp @@ -121,7 +121,7 @@ struct EmulateWideIntPass final [&typeConverter](Operation *op) { return typeConverter.isLegal(op); }); RewritePatternSet patterns(ctx); - // Add common pattenrs to support contants, functions, etc. + // Add common patterns to support contants, functions, etc. arith::populateArithWideIntEmulationPatterns(typeConverter, patterns); memref::populateMemRefWideIntEmulationPatterns(typeConverter, patterns); diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 6564a4e..dcfe2c7 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SmallSet.h" @@ -39,6 +40,16 @@ static bool isScalarLikeType(Type type) { return type.isIntOrIndexOrFloat() || isa<ComplexType>(type); } +/// Helper function to attach the `VarName` attribute to an operation +/// if a variable name is provided. +static void attachVarNameAttr(Operation *op, OpBuilder &builder, + StringRef varName) { + if (!varName.empty()) { + auto varNameAttr = acc::VarNameAttr::get(builder.getContext(), varName); + op->setAttr(acc::getVarNameAttrName(), varNameAttr); + } +} + struct MemRefPointerLikeModel : public PointerLikeType::ExternalModel<MemRefPointerLikeModel, MemRefType> { @@ -74,14 +85,18 @@ struct MemRefPointerLikeModel } mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc, - StringRef varName, Type varType, - Value originalVar) const { + StringRef varName, Type varType, Value originalVar, + bool &needsFree) const { auto memrefTy = cast<MemRefType>(pointer); // Check if this is a static memref (all dimensions are known) - if yes // then we can generate an alloca operation. - if (memrefTy.hasStaticShape()) - return memref::AllocaOp::create(builder, loc, memrefTy).getResult(); + if (memrefTy.hasStaticShape()) { + needsFree = false; // alloca doesn't need deallocation + auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy); + attachVarNameAttr(allocaOp, builder, varName); + return allocaOp.getResult(); + } // For dynamic memrefs, extract sizes from the original variable if // provided. Otherwise they cannot be handled. 
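The `needsFree` flag threaded out of `genAllocate` above is what lets the recipe helpers added further down in this patch drop the destroy region for stack-like allocations (memref.alloca) while still emitting a dealloc for heap-like ones (memref.alloc). A rough usage sketch of the resulting entry point; the helper name `buildPrivateRecipeForMemref` is an assumption, the builder is assumed to already be positioned where recipes belong (e.g. at module scope), and only `acc::PrivateRecipeOp::createAndPopulate` comes from this patch:

static std::optional<acc::PrivateRecipeOp>
buildPrivateRecipeForMemref(OpBuilder &builder, Location loc,
                            MemRefType varTy, StringRef recipeName,
                            StringRef varName) {
  // createAndPopulate builds the init region through PointerLikeType's
  // genAllocate; a destroy region is only attached when that allocation
  // reported needsFree = true.
  return acc::PrivateRecipeOp::createAndPopulate(builder, loc, recipeName,
                                                 varTy, varName,
                                                 /*bounds=*/ValueRange{});
}
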
@@ -99,8 +114,11 @@ struct MemRefPointerLikeModel // Note: We only add dynamic sizes to the dynamicSizes array // Static dimensions are handled automatically by AllocOp } - return memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes) - .getResult(); + needsFree = true; // alloc needs deallocation + auto allocOp = + memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes); + attachVarNameAttr(allocOp, builder, varName); + return allocOp.getResult(); } // TODO: Unranked not yet supported. @@ -108,10 +126,14 @@ struct MemRefPointerLikeModel } bool genFree(Type pointer, OpBuilder &builder, Location loc, - TypedValue<PointerLikeType> varPtr, Type varType) const { - if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varPtr)) { + TypedValue<PointerLikeType> varToFree, Value allocRes, + Type varType) const { + if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varToFree)) { + // Use allocRes if provided to determine the allocation type + Value valueToInspect = allocRes ? allocRes : memrefValue; + // Walk through casts to find the original allocation - Value currentValue = memrefValue; + Value currentValue = valueToInspect; Operation *originalAlloc = nullptr; // Follow the chain of operations to find the original allocation @@ -150,7 +172,7 @@ struct MemRefPointerLikeModel return true; } if (isa<memref::AllocOp>(originalAlloc)) { - // This is an alloc - generate dealloc + // This is an alloc - generate dealloc on varToFree memref::DeallocOp::create(builder, loc, memrefValue); return true; } @@ -1003,6 +1025,138 @@ struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> { } }; +//===----------------------------------------------------------------------===// +// Recipe Region Helpers +//===----------------------------------------------------------------------===// + +/// Create and populate an init region for privatization recipes. +/// Returns success if the region is populated, failure otherwise. +/// Sets needsFree to indicate if the allocated memory requires deallocation. 
+static LogicalResult createInitRegion(OpBuilder &builder, Location loc, + Region &initRegion, Type varType, + StringRef varName, ValueRange bounds, + bool &needsFree) { + // Create init block with arguments: original value + bounds + SmallVector<Type> argTypes{varType}; + SmallVector<Location> argLocs{loc}; + for (Value bound : bounds) { + argTypes.push_back(bound.getType()); + argLocs.push_back(loc); + } + + Block *initBlock = builder.createBlock(&initRegion); + initBlock->addArguments(argTypes, argLocs); + builder.setInsertionPointToStart(initBlock); + + Value privatizedValue; + + // Get the block argument that represents the original variable + Value blockArgVar = initBlock->getArgument(0); + + // Generate init region body based on variable type + if (isa<MappableType>(varType)) { + auto mappableTy = cast<MappableType>(varType); + auto typedVar = cast<TypedValue<MappableType>>(blockArgVar); + privatizedValue = mappableTy.generatePrivateInit( + builder, loc, typedVar, varName, bounds, {}, needsFree); + if (!privatizedValue) + return failure(); + } else { + assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType"); + auto pointerLikeTy = cast<PointerLikeType>(varType); + // Use PointerLikeType's allocation API with the block argument + privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType, + blockArgVar, needsFree); + if (!privatizedValue) + return failure(); + } + + // Add yield operation to init block + acc::YieldOp::create(builder, loc, privatizedValue); + + return success(); +} + +/// Create and populate a copy region for firstprivate recipes. +/// Returns success if the region is populated, failure otherwise. +/// TODO: Handle MappableType - it does not yet have a copy API. +static LogicalResult createCopyRegion(OpBuilder &builder, Location loc, + Region ©Region, Type varType, + ValueRange bounds) { + // Create copy block with arguments: original value + privatized value + + // bounds + SmallVector<Type> copyArgTypes{varType, varType}; + SmallVector<Location> copyArgLocs{loc, loc}; + for (Value bound : bounds) { + copyArgTypes.push_back(bound.getType()); + copyArgLocs.push_back(loc); + } + + Block *copyBlock = builder.createBlock(©Region); + copyBlock->addArguments(copyArgTypes, copyArgLocs); + builder.setInsertionPointToStart(copyBlock); + + bool isMappable = isa<MappableType>(varType); + bool isPointerLike = isa<PointerLikeType>(varType); + // TODO: Handle MappableType - it does not yet have a copy API. + // Otherwise, for now just fallback to pointer-like behavior. + if (isMappable && !isPointerLike) + return failure(); + + // Generate copy region body based on variable type + if (isPointerLike) { + auto pointerLikeTy = cast<PointerLikeType>(varType); + Value originalArg = copyBlock->getArgument(0); + Value privatizedArg = copyBlock->getArgument(1); + + // Generate copy operation using PointerLikeType interface + if (!pointerLikeTy.genCopy( + builder, loc, cast<TypedValue<PointerLikeType>>(privatizedArg), + cast<TypedValue<PointerLikeType>>(originalArg), varType)) + return failure(); + } + + // Add terminator to copy block + acc::TerminatorOp::create(builder, loc); + + return success(); +} + +/// Create and populate a destroy region for privatization recipes. +/// Returns success if the region is populated, failure otherwise. 
+static LogicalResult createDestroyRegion(OpBuilder &builder, Location loc, + Region &destroyRegion, Type varType, + Value allocRes, ValueRange bounds) { + // Create destroy block with arguments: original value + privatized value + + // bounds + SmallVector<Type> destroyArgTypes{varType, varType}; + SmallVector<Location> destroyArgLocs{loc, loc}; + for (Value bound : bounds) { + destroyArgTypes.push_back(bound.getType()); + destroyArgLocs.push_back(loc); + } + + Block *destroyBlock = builder.createBlock(&destroyRegion); + destroyBlock->addArguments(destroyArgTypes, destroyArgLocs); + builder.setInsertionPointToStart(destroyBlock); + + auto varToFree = + cast<TypedValue<PointerLikeType>>(destroyBlock->getArgument(1)); + if (isa<MappableType>(varType)) { + auto mappableTy = cast<MappableType>(varType); + if (!mappableTy.generatePrivateDestroy(builder, loc, varToFree)) + return failure(); + } else { + assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType"); + auto pointerLikeTy = cast<PointerLikeType>(varType); + if (!pointerLikeTy.genFree(builder, loc, varToFree, allocRes, varType)) + return failure(); + } + + acc::TerminatorOp::create(builder, loc); + return success(); +} + } // namespace //===----------------------------------------------------------------------===// @@ -1050,6 +1204,48 @@ LogicalResult acc::PrivateRecipeOp::verifyRegions() { return success(); } +std::optional<PrivateRecipeOp> +PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc, + StringRef recipeName, Type varType, + StringRef varName, ValueRange bounds) { + // First, validate that we can handle this variable type + bool isMappable = isa<MappableType>(varType); + bool isPointerLike = isa<PointerLikeType>(varType); + + // Unsupported type + if (!isMappable && !isPointerLike) + return std::nullopt; + + OpBuilder::InsertionGuard guard(builder); + + // Create the recipe operation first so regions have proper parent context + auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType); + + // Populate the init region + bool needsFree = false; + if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType, + varName, bounds, needsFree))) { + recipe.erase(); + return std::nullopt; + } + + // Only create destroy region if the allocation needs deallocation + if (needsFree) { + // Extract the allocated value from the init block's yield operation + auto yieldOp = + cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator()); + Value allocRes = yieldOp.getOperand(0); + + if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(), + varType, allocRes, bounds))) { + recipe.erase(); + return std::nullopt; + } + } + + return recipe; +} + //===----------------------------------------------------------------------===// // FirstprivateRecipeOp //===----------------------------------------------------------------------===// @@ -1080,6 +1276,55 @@ LogicalResult acc::FirstprivateRecipeOp::verifyRegions() { return success(); } +std::optional<FirstprivateRecipeOp> +FirstprivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc, + StringRef recipeName, Type varType, + StringRef varName, ValueRange bounds) { + // First, validate that we can handle this variable type + bool isMappable = isa<MappableType>(varType); + bool isPointerLike = isa<PointerLikeType>(varType); + + // Unsupported type + if (!isMappable && !isPointerLike) + return std::nullopt; + + OpBuilder::InsertionGuard guard(builder); + + // Create the recipe operation first so regions have proper parent 
context + auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType); + + // Populate the init region + bool needsFree = false; + if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType, + varName, bounds, needsFree))) { + recipe.erase(); + return std::nullopt; + } + + // Populate the copy region + if (failed(createCopyRegion(builder, loc, recipe.getCopyRegion(), varType, + bounds))) { + recipe.erase(); + return std::nullopt; + } + + // Only create destroy region if the allocation needs deallocation + if (needsFree) { + // Extract the allocated value from the init block's yield operation + auto yieldOp = + cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator()); + Value allocRes = yieldOp.getOperand(0); + + if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(), + varType, allocRes, bounds))) { + recipe.erase(); + return std::nullopt; + } + } + + return recipe; +} + //===----------------------------------------------------------------------===// // ReductionRecipeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp index b663908..8c4f80f 100644 --- a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp +++ b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Quant/Utils/UniformSupport.h" #include "mlir/IR/BuiltinTypes.h" +#include "llvm/ADT/STLExtras.h" #include <numeric> using namespace mlir; @@ -76,9 +77,7 @@ UniformQuantizedPerAxisValueConverter::convert(DenseFPElementsAttr attr) { // using the right quantization parameters. int64_t flattenIndex = 0; auto shape = type.getShape(); - int64_t chunkSize = - std::accumulate(std::next(shape.begin(), quantizationDim + 1), - shape.end(), 1, std::multiplies<int64_t>()); + int64_t chunkSize = llvm::product_of(shape.drop_front(quantizationDim + 1)); Type newElementType = IntegerType::get(attr.getContext(), storageBitWidth); return attr.mapValues(newElementType, [&](const APFloat &old) { int chunkIndex = (flattenIndex++) / chunkSize; diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 5511998..fe50865 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -400,7 +400,7 @@ LogicalResult spirv::CompositeConstructOp::verify() { return emitOpError("operand element type mismatch: expected to be ") << resultType.getElementType() << ", but provided " << elementType; } - unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0); + unsigned totalCount = llvm::sum_of(sizes); if (totalCount != cType.getNumElements()) return emitOpError("has incorrect number of operands: expected ") << cType.getNumElements() << ", but provided " << totalCount; diff --git a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp index 08fccfa..645cbff 100644 --- a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp +++ b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp @@ -158,7 +158,7 @@ static FailureOr<GridOp> getGridAndVerify(Operation *op, } template <typename It> -bool isUnique(It begin, It end) { +static bool isUnique(It begin, It end) { if (begin == end) { return true; } @@ -1010,18 +1010,6 @@ static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, return success(); } -template <typename It> -static auto product(It begin, It end) { - using ElementType = std::decay_t<decltype(*begin)>; - return 
std::accumulate(begin, end, static_cast<ElementType>(1), - std::multiplies<ElementType>()); -} - -template <typename R> -static auto product(R &&range) { - return product(adl_begin(range), adl_end(range)); -} - static LogicalResult verifyDimensionCompatibility(Location loc, int64_t expectedDimSize, int64_t resultDimSize, diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp index a1711a6..069191c 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp @@ -143,8 +143,8 @@ void VarInfo::setNum(Var::Num n) { /// Helper function for `assertUsageConsistency` to better handle SMLoc /// mismatches. -LLVM_ATTRIBUTE_UNUSED static llvm::SMLoc -minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) { +[[maybe_unused]] static llvm::SMLoc minSMLoc(AsmParser &parser, llvm::SMLoc sm1, + llvm::SMLoc sm2) { const auto loc1 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm1)); assert(loc1 && "Could not get `FileLineColLoc` for first `SMLoc`"); const auto loc2 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm2)); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp index f539502..684c088 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp @@ -43,8 +43,8 @@ using namespace mlir::sparse_tensor; //===----------------------------------------------------------------------===// #ifndef NDEBUG -LLVM_ATTRIBUTE_UNUSED static void dumpIndexMemRef(OpBuilder &builder, - Location loc, Value memref) { +[[maybe_unused]] static void dumpIndexMemRef(OpBuilder &builder, Location loc, + Value memref) { memref = memref::CastOp::create( builder, loc, UnrankedMemRefType::get(builder.getIndexType(), 0), memref); createFuncCall(builder, loc, "printMemrefInd", TypeRange{}, diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index fa97b49..ac72002 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -2310,6 +2310,7 @@ RankedTensorType ExtractSliceOp::inferResultType( sourceTensorType.getEncoding()); } +// TODO: This uses neither offsets nor strides! 
RankedTensorType ExtractSliceOp::inferResultType( RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) { diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp index 5aad671..1cba1bb 100644 --- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tosa/IR/TargetEnv.h" +#include "llvm/Support/FormatVariadic.h" namespace mlir { namespace tosa { @@ -27,7 +28,7 @@ TargetEnvAttr lookupTargetEnv(Operation *op) { } TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) { - return TargetEnvAttr::get(context, Level::eightK, + return TargetEnvAttr::get(context, SpecificationVersion::V_1_0, Level::eightK, {Profile::pro_int, Profile::pro_fp}, {}); } @@ -38,5 +39,9 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) { return getDefaultTargetEnv(op->getContext()); } +llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) { + return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor()); +} + } // namespace tosa } // namespace mlir diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index c51b5e9..00f84bc 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -2368,9 +2368,10 @@ llvm::LogicalResult tosa::ReshapeOp::verify() { } } - int64_t newShapeElementsNum = std::accumulate( - shapeValues.begin(), shapeValues.end(), 1LL, - [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; }); + int64_t newShapeElementsNum = + llvm::accumulate(shapeValues, int64_t(1), [](int64_t acc, int64_t dim) { + return (dim > 0) ? 
acc * dim : acc; + }); bool isStaticNewShape = llvm::all_of(shapeValues, [](int64_t s) { return s > 0; }); if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) || diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp index bcb880a..a0661e4 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp @@ -61,8 +61,8 @@ public: ModuleOp mod = getOperation(); MLIRContext *ctx = &getContext(); - const auto targetEnvAttr = - TargetEnvAttr::get(ctx, level, selectedProfiles, selectedExtensions); + const auto targetEnvAttr = TargetEnvAttr::get( + ctx, specificationVersion, level, selectedProfiles, selectedExtensions); mod->setAttr(TargetEnvAttr::name, targetEnvAttr); } diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp index d33ebe3..5786f53 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp @@ -20,6 +20,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectResourceBlobManager.h" #include "mlir/IR/Matchers.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" using namespace mlir; @@ -375,8 +376,7 @@ llvm::APInt calculateReducedValue(const mlir::ElementsAttr &oldTensorAttr, for (int64_t reductionAxisVal = 1; reductionAxisVal < oldShape[reductionAxis]; ++reductionAxisVal) { - int64_t stride = std::accumulate(oldShape.begin() + reductionAxis + 1, - oldShape.end(), 1, std::multiplies<int>()); + int64_t stride = llvm::product_of(oldShape.drop_front(reductionAxis + 1)); int64_t index = indexAtOldTensor + stride * reductionAxisVal; reducedValue = OperationType::calcOneElement(reducedValue, oldTensor[index]); @@ -424,8 +424,7 @@ struct ReduceConstantOptimization : public OpRewritePattern<OperationType> { auto oldShape = shapedOldElementsValues.getShape(); auto newShape = resultType.getShape(); - auto newNumOfElements = std::accumulate(newShape.begin(), newShape.end(), 1, - std::multiplies<int>()); + int64_t newNumOfElements = llvm::product_of(newShape); llvm::SmallVector<APInt> newReducedTensor(newNumOfElements); for (int64_t reductionIndex = 0; reductionIndex < newNumOfElements; diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp index 20f9333..f072e3e 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp @@ -335,16 +335,15 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) { //===----------------------------------------------------------------------===// template <typename T> -FailureOr<SmallVector<T>> -TosaProfileCompliance::getOperatorDefinition(Operation *op, - CheckCondition &condition) { +FailureOr<OpComplianceInfo<T>> +TosaProfileCompliance::getOperatorDefinition(Operation *op) { const std::string opName = op->getName().getStringRef().str(); const auto complianceMap = getProfileComplianceMap<T>(); const auto it = complianceMap.find(opName); if (it == complianceMap.end()) return {}; - return findMatchedProfile<T>(op, it->second, condition); + return findMatchedEntry<T>(op, it->second); } template <typename T> @@ -356,22 +355,21 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension( if (specRequiredModeSet.size() == 0) return success(); - CheckCondition condition = CheckCondition::invalid; - const auto 
maybeOpRequiredMode = getOperatorDefinition<T>(op, condition); - if (failed(maybeOpRequiredMode)) { + const auto maybeOpDefinition = getOperatorDefinition<T>(op); + if (failed(maybeOpDefinition)) { // Operators such as control-flow and shape ops do not have an operand type // restriction. When the profile compliance information of operation is not // found, confirm if the target have enabled the profile required from the // specification. - int mode_count = 0; + int modeCount = 0; for (const auto &cands : specRequiredModeSet) { if (targetEnv.allowsAnyOf(cands)) return success(); - mode_count += cands.size(); + modeCount += cands.size(); } op->emitOpError() << "illegal: requires" - << (mode_count > 1 ? " any of " : " ") << "[" + << (modeCount > 1 ? " any of " : " ") << "[" << llvm::join(stringifyProfile<T>(specRequiredModeSet), ", ") << "] but not enabled in target\n"; @@ -381,7 +379,10 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension( // Find the required profiles or extensions according to the operand type // combination. - const auto opRequiredMode = maybeOpRequiredMode.value(); + const auto opDefinition = maybeOpDefinition.value(); + const SmallVector<T> opRequiredMode = opDefinition.mode; + const CheckCondition condition = opDefinition.condition; + if (opRequiredMode.size() == 0) { // No matched restriction found. return success(); @@ -437,6 +438,21 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension( } } + // Ensure the matched op compliance version does not exceed the target + // specification version. + const VersionedTypeInfo versionedTypeInfo = + opDefinition.operandTypeInfoSet[0]; + const TosaSpecificationVersion complianceVersion{versionedTypeInfo.second}; + const TosaSpecificationVersion targetVersion{targetEnv.getSpecVersion()}; + if (!targetVersion.isBackwardsCompatibleWith(complianceVersion)) { + op->emitOpError() << "illegal: the target specification version (" + << stringifyVersion(targetVersion) + << ") is not backwards compatible with the op compliance " + "specification version (" + << stringifyVersion(complianceVersion) << ")\n"; + return failure(); + } + return success(); } @@ -461,14 +477,14 @@ TosaProfileCompliance::checkExtension(Operation *op, } LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) { - CheckCondition condition = CheckCondition::invalid; - const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition); - const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition); + const auto maybeProfDef = getOperatorDefinition<Profile>(op); + const auto maybeExtDef = getOperatorDefinition<Extension>(op); if (failed(maybeProfDef) && failed(maybeExtDef)) return success(); - const bool hasEntry = (succeeded(maybeProfDef) && !maybeProfDef->empty()) || - (succeeded(maybeExtDef) && !maybeExtDef->empty()); + const bool hasEntry = + (succeeded(maybeProfDef) && !maybeProfDef->mode.empty()) || + (succeeded(maybeExtDef) && !maybeExtDef->mode.empty()); if (!hasEntry) { std::string message; llvm::raw_string_ostream os(message); @@ -488,7 +504,9 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) { SmallVector<TypeInfo> bestTypeInfo; const auto searchBestMatch = [&](auto map) { for (const auto &complianceInfos : map[opName]) { - for (const auto &typeInfos : complianceInfos.operandTypeInfoSet) { + for (const auto &versionedTypeInfos : + complianceInfos.operandTypeInfoSet) { + const SmallVector<TypeInfo> typeInfos = versionedTypeInfos.first; const int matches = llvm::count_if( llvm::zip_equal(current, 
typeInfos), [&](const auto zipType) { return isSameTypeInfo(std::get<0>(zipType), @@ -520,9 +538,8 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) { // Find the profiles or extensions requirement according to the signature of // type of the operand list. template <typename T> -SmallVector<T> TosaProfileCompliance::findMatchedProfile( - Operation *op, SmallVector<OpComplianceInfo<T>> compInfo, - CheckCondition &condition) { +OpComplianceInfo<T> TosaProfileCompliance::findMatchedEntry( + Operation *op, SmallVector<OpComplianceInfo<T>> compInfo) { assert(compInfo.size() != 0 && "profile-based compliance information is empty"); @@ -533,27 +550,30 @@ SmallVector<T> TosaProfileCompliance::findMatchedProfile( return {}; for (size_t i = 0; i < compInfo.size(); i++) { - SmallVector<SmallVector<TypeInfo>> sets = compInfo[i].operandTypeInfoSet; - for (SmallVector<TypeInfo> expected : sets) { + SmallVector<VersionedTypeInfo> sets = compInfo[i].operandTypeInfoSet; + for (const auto &set : sets) { + SmallVector<TypeInfo> expected = set.first; assert(present.size() == expected.size() && "the entries for profile-based compliance do not match between " "the generated metadata and the type definition retrieved from " " the operation"); - bool is_found = true; + bool isFound = true; // Compare the type signature between the given operation and the // compliance metadata. for (size_t j = 0; j < expected.size(); j++) { if (!isSameTypeInfo(present[j], expected[j])) { // Verify the next mode set from the list. - is_found = false; + isFound = false; break; } } - if (is_found == true) { - condition = compInfo[i].condition; - return compInfo[i].mode; + if (isFound == true) { + SmallVector<VersionedTypeInfo> typeInfoSet{set}; + OpComplianceInfo<T> info{compInfo[i].mode, typeInfoSet, + compInfo[i].condition}; + return info; } } } diff --git a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp index 9a24c2b..a2cff6a 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp @@ -21,10 +21,10 @@ using namespace mlir; // These are automatically generated by ODS but are not used as the Transform // dialect uses a different dispatch mechanism to support dialect extensions. 
-LLVM_ATTRIBUTE_UNUSED static OptionalParseResult +[[maybe_unused]] static OptionalParseResult generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value); -LLVM_ATTRIBUTE_UNUSED static LogicalResult -generatedTypePrinter(Type def, AsmPrinter &printer); +[[maybe_unused]] static LogicalResult generatedTypePrinter(Type def, + AsmPrinter &printer); #define GET_TYPEDEF_CLASSES #include "mlir/Dialect/Transform/IR/TransformTypes.cpp.inc" diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp index e1648ab9..305b06eb 100644 --- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp +++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp @@ -81,21 +81,10 @@ SmallVector<int64_t> mlir::computeElementwiseMul(ArrayRef<int64_t> v1, return computeElementwiseMulImpl(v1, v2); } -int64_t mlir::computeSum(ArrayRef<int64_t> basis) { - assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) && - "basis must be nonnegative"); - if (basis.empty()) - return 0; - return std::accumulate(basis.begin(), basis.end(), 1, std::plus<int64_t>()); -} - int64_t mlir::computeProduct(ArrayRef<int64_t> basis) { assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) && "basis must be nonnegative"); - if (basis.empty()) - return 1; - return std::accumulate(basis.begin(), basis.end(), 1, - std::multiplies<int64_t>()); + return llvm::product_of(basis); } int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) { @@ -158,19 +147,11 @@ SmallVector<AffineExpr> mlir::computeElementwiseMul(ArrayRef<AffineExpr> v1, } AffineExpr mlir::computeSum(MLIRContext *ctx, ArrayRef<AffineExpr> basis) { - if (basis.empty()) - return getAffineConstantExpr(0, ctx); - return std::accumulate(basis.begin(), basis.end(), - getAffineConstantExpr(0, ctx), - std::plus<AffineExpr>()); + return llvm::sum_of(basis, getAffineConstantExpr(0, ctx)); } AffineExpr mlir::computeProduct(MLIRContext *ctx, ArrayRef<AffineExpr> basis) { - if (basis.empty()) - return getAffineConstantExpr(1, ctx); - return std::accumulate(basis.begin(), basis.end(), - getAffineConstantExpr(1, ctx), - std::multiplies<AffineExpr>()); + return llvm::product_of(basis, getAffineConstantExpr(1, ctx)); } AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets, diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp index 7b2734d..6e9118e 100644 --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -374,11 +374,11 @@ mlir::composeReassociationIndices( if (consumerReassociations.empty()) return composedIndices; - size_t consumerDims = std::accumulate( - consumerReassociations.begin(), consumerReassociations.end(), 0, - [](size_t all, ReassociationIndicesRef indices) { - return all + indices.size(); - }); + size_t consumerDims = + llvm::accumulate(consumerReassociations, size_t(0), + [](size_t all, ReassociationIndicesRef indices) { + return all + indices.size(); + }); if (producerReassociations.size() != consumerDims) return std::nullopt; diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index a7e3ba8..45c54c7 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2496,8 +2496,7 @@ struct ToElementsOfBroadcast final : OpRewritePattern<ToElementsOp> { auto srcElems = vector::ToElementsOp::create( rewriter, toElementsOp.getLoc(), bcastOp.getSource()); - int64_t dstCount = std::accumulate(dstShape.begin(), dstShape.end(), 1, - 
std::multiplies<int64_t>()); + int64_t dstCount = llvm::product_of(dstShape); SmallVector<Value> replacements; replacements.reserve(dstCount); @@ -7602,6 +7601,111 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, setResultRanges(getResult(), result); } +namespace { + +/// Fold `vector.step -> arith.cmpi` when the step value is compared to a +/// constant large enough such that the result is the same at all indices. +/// +/// For example, rewrite the 'greater than' comparison below, +/// +/// ```mlir +/// %cst = arith.constant dense<7> : vector<3xindex> +/// %stp = vector.step : vector<3xindex> +/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex> +/// ``` +/// +/// as, +/// +/// ```mlir +/// %out = arith.constant dense<false> : vector<3xi1> +/// ``` +/// +/// Above `[0, 1, 2] > [7, 7, 7]` => `[false, false, false]`. Because the result +/// is false at ALL indices, we fold. If the constant was 1, then +/// `[0, 1, 2] > [1, 1, 1]` => `[false, false, true]` and we do not fold, +/// conservatively preferring the 'compact' vector.step representation. +/// +/// Note: this folder only works for the case where the constant (`%cst` above) +/// is the second operand of the comparison. The arith.cmpi canonicalizer will +/// ensure that constants are always second (on the right). +struct StepCompareFolder : public OpRewritePattern<StepOp> { + using Base::Base; + + LogicalResult matchAndRewrite(StepOp stepOp, + PatternRewriter &rewriter) const override { + const int64_t stepSize = stepOp.getResult().getType().getNumElements(); + + for (OpOperand &use : stepOp.getResult().getUses()) { + auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner()); + if (!cmpiOp) + continue; + + // arith.cmpi canonicalizer makes constants final operands. + const unsigned stepOperandNumber = use.getOperandNumber(); + if (stepOperandNumber != 0) + continue; + + // Check that operand 1 is a constant. + unsigned constOperandNumber = 1; + Value otherOperand = cmpiOp.getOperand(constOperandNumber); + std::optional<int64_t> maybeConstValue = + getConstantIntValue(otherOperand); + if (!maybeConstValue.has_value()) + continue; + + int64_t constValue = maybeConstValue.value(); + arith::CmpIPredicate pred = cmpiOp.getPredicate(); + + auto maybeSplat = [&]() -> std::optional<bool> { + // Handle ult (unsigned less than) and uge (unsigned greater equal). + if ((pred == arith::CmpIPredicate::ult || + pred == arith::CmpIPredicate::uge) && + stepSize <= constValue) + return pred == arith::CmpIPredicate::ult; + + // Handle ule and ugt. + if ((pred == arith::CmpIPredicate::ule || + pred == arith::CmpIPredicate::ugt) && + stepSize - 1 <= constValue) { + return pred == arith::CmpIPredicate::ule; + } + + // Handle eq and ne. 
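+        // E.g. with `stepSize == 3` the step lanes are [0, 1, 2]: compared
+        // against a constant >= 3, `eq` is false for every lane and `ne` is
+        // true for every lane, so the compare can be replaced by a splat.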
+ if ((pred == arith::CmpIPredicate::eq || + pred == arith::CmpIPredicate::ne) && + stepSize <= constValue) + return pred == arith::CmpIPredicate::ne; + + return std::nullopt; + }(); + + if (!maybeSplat.has_value()) + continue; + + rewriter.setInsertionPointAfter(cmpiOp); + + auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType()); + if (!type) + continue; + + auto boolAttr = DenseElementsAttr::get(type, maybeSplat.value()); + Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(), + type, boolAttr); + + rewriter.replaceOp(cmpiOp, splat); + return success(); + } + + return failure(); + } +}; +} // namespace + +void StepOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add<StepCompareFolder>(context); +} + //===----------------------------------------------------------------------===// // Vector Masking Utilities //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp index c5f22b2..0eba0b1 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp @@ -21,6 +21,7 @@ #include "mlir/IR/Location.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" +#include "llvm/ADT/STLExtras.h" #include <numeric> #define DEBUG_TYPE "vector-shape-cast-lowering" @@ -166,10 +167,7 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> { const VectorType resultType = shapeCast.getResultVectorType(); const ArrayRef<int64_t> resultShape = resultType.getShape(); - const int64_t nSlices = - std::accumulate(sourceShape.begin(), sourceShape.begin() + sourceDim, 1, - std::multiplies<int64_t>()); - + const int64_t nSlices = llvm::product_of(sourceShape.take_front(sourceDim)); SmallVector<int64_t> extractIndex(sourceDim, 0); SmallVector<int64_t> insertIndex(resultDim, 0); Value result = ub::PoisonOp::create(rewriter, loc, resultType); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index e95338f..7c019e7 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -928,17 +928,20 @@ struct WarpOpDeadResult : public WarpDistributionPattern { // Some values may be yielded multiple times and correspond to multiple // results. Deduplicating occurs by taking each result with its matching // yielded value, and: - // 1. recording the unique first position at which the value is yielded. + // 1. recording the unique first position at which the value with uses is + // yielded. // 2. recording for the result, the first position at which the dedup'ed // value is yielded. // 3. skipping from the new result types / new yielded values any result // that has no use or whose yielded value has already been seen. 
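+    // E.g. if the same value is yielded for an unused result and a used
+    // result, the unused one is now skipped before it can claim the dedup
+    // slot, so the used result still receives a position in the new op.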
for (OpResult result : warpOp.getResults()) { + if (result.use_empty()) + continue; Value yieldOperand = yield.getOperand(result.getResultNumber()); auto it = dedupYieldOperandPositionMap.insert( std::make_pair(yieldOperand, newResultTypes.size())); dedupResultPositionMap.insert(std::make_pair(result, it.first->second)); - if (result.use_empty() || !it.second) + if (!it.second) continue; newResultTypes.push_back(result.getType()); newYieldValues.push_back(yieldOperand); @@ -1843,16 +1846,16 @@ struct WarpOpScfIfOp : public WarpDistributionPattern { newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(), escapingValueDistTypesElse.end()); - llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx; for (auto [idx, val] : llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) { - origToNewYieldIdx[idx] = newWarpOpYieldValues.size(); newWarpOpYieldValues.push_back(val); newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType()); } - // Create the new `WarpOp` with the updated yield values and types. - WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( - rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes); + // Replace the old `WarpOp` with the new one that has additional yield + // values and types. + SmallVector<size_t> newIndices; + WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices); // `ifOp` returns the result of the inner warp op. SmallVector<Type> newIfOpDistResTypes; for (auto [i, res] : llvm::enumerate(ifOp.getResults())) { @@ -1870,8 +1873,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern { OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(newWarpOp); auto newIfOp = scf::IfOp::create( - rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0), - static_cast<bool>(ifOp.thenBlock()), + rewriter, ifOp.getLoc(), newIfOpDistResTypes, + newWarpOp.getResult(newIndices[0]), static_cast<bool>(ifOp.thenBlock()), static_cast<bool>(ifOp.elseBlock())); auto encloseRegionInWarpOp = [&](Block *oldIfBranch, Block *newIfBranch, @@ -1888,7 +1891,7 @@ struct WarpOpScfIfOp : public WarpDistributionPattern { for (size_t i = 0; i < escapingValues.size(); ++i, ++warpResRangeStart) { innerWarpInputVals.push_back( - newWarpOp.getResult(warpResRangeStart)); + newWarpOp.getResult(newIndices[warpResRangeStart])); escapeValToBlockArgIndex[escapingValues[i]] = innerWarpInputTypes.size(); innerWarpInputTypes.push_back(escapingValueInputTypes[i]); @@ -1936,17 +1939,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern { // Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp` // result. for (auto [origIdx, newIdx] : ifResultMapping) - rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx), + rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx), newIfOp.getResult(newIdx), newIfOp); - // Similarly, update any users of the `WarpOp` results that were not - // results of the `IfOp`. - for (auto [origIdx, newIdx] : origToNewYieldIdx) - rewriter.replaceAllUsesWith(warpOp.getResult(origIdx), - newWarpOp.getResult(newIdx)); - // Remove the original `WarpOp` and `IfOp`, they should not have any uses - // at this point. - rewriter.eraseOp(ifOp); - rewriter.eraseOp(warpOp); return success(); } @@ -2038,11 +2032,19 @@ struct WarpOpScfForOp : public WarpDistributionPattern { } // Newly created `WarpOp` will yield values in following order: - // 1. All init args of the `ForOp`. - // 2. All escaping values. - // 3. 
All non-`ForOp` yielded values. + // 1. Loop bounds. + // 2. All init args of the `ForOp`. + // 3. All escaping values. + // 4. All non-`ForOp` yielded values. SmallVector<Value> newWarpOpYieldValues; SmallVector<Type> newWarpOpDistTypes; + newWarpOpYieldValues.insert( + newWarpOpYieldValues.end(), + {forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()}); + newWarpOpDistTypes.insert(newWarpOpDistTypes.end(), + {forOp.getLowerBound().getType(), + forOp.getUpperBound().getType(), + forOp.getStep().getType()}); for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) { newWarpOpYieldValues.push_back(initArg); // Compute the distributed type for this init arg. @@ -2065,36 +2067,37 @@ struct WarpOpScfForOp : public WarpDistributionPattern { escapingValueDistTypes.begin(), escapingValueDistTypes.end()); // Next, we insert all non-`ForOp` yielded values and their distributed - // types. We also create a mapping between the non-`ForOp` yielded value - // index and the corresponding new `WarpOp` yield value index (needed to - // update users later). - llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping; + // types. for (auto [i, v] : llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) { - nonForResultMapping[i] = newWarpOpYieldValues.size(); newWarpOpYieldValues.push_back(v); newWarpOpDistTypes.push_back(warpOp.getResult(i).getType()); } // Create the new `WarpOp` with the updated yield values and types. - WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( - rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes); + SmallVector<size_t> newIndices; + WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices); // Next, we create a new `ForOp` with the init args yielded by the new // `WarpOp`. + const unsigned initArgsStartIdx = 3; // After loop bounds. const unsigned escapingValuesStartIdx = + initArgsStartIdx + forOp.getInitArgs().size(); // `ForOp` init args are positioned before // escaping values in the new `WarpOp`. SmallVector<Value> newForOpOperands; - for (size_t i = 0; i < escapingValuesStartIdx; ++i) - newForOpOperands.push_back(newWarpOp.getResult(i)); + for (size_t i = initArgsStartIdx; i < escapingValuesStartIdx; ++i) + newForOpOperands.push_back(newWarpOp.getResult(newIndices[i])); // Create a new `ForOp` outside the new `WarpOp` region. OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(newWarpOp); auto newForOp = scf::ForOp::create( - rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), - forOp.getStep(), newForOpOperands, /*bodyBuilder=*/nullptr, - forOp.getUnsignedCmp()); + rewriter, forOp.getLoc(), + /**LowerBound=**/ newWarpOp.getResult(newIndices[0]), + /**UpperBound=**/ newWarpOp.getResult(newIndices[1]), + /**Step=**/ newWarpOp.getResult(newIndices[2]), newForOpOperands, + /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp()); // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the // newly created `ForOp`. This `WarpOp` will contain all ops that were // contained within the original `ForOp` body. 
@@ -2110,7 +2113,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern { llvm::SmallDenseMap<Value, int64_t> argIndexMapping; for (size_t i = escapingValuesStartIdx; i < escapingValuesStartIdx + escapingValues.size(); ++i) { - innerWarpInput.push_back(newWarpOp.getResult(i)); + innerWarpInput.push_back(newWarpOp.getResult(newIndices[i])); argIndexMapping[escapingValues[i - escapingValuesStartIdx]] = innerWarpInputType.size(); innerWarpInputType.push_back( @@ -2146,20 +2149,11 @@ struct WarpOpScfForOp : public WarpDistributionPattern { if (!innerWarp.getResults().empty()) scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults()); - // Update the users of original `WarpOp` results that were coming from the + // Update the users of the new `WarpOp` results that were coming from the // original `ForOp` to the corresponding new `ForOp` result. for (auto [origIdx, newIdx] : forResultMapping) - rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx), + rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx), newForOp.getResult(newIdx), newForOp); - // Similarly, update any users of the `WarpOp` results that were not - // results of the `ForOp`. - for (auto [origIdx, newIdx] : nonForResultMapping) - rewriter.replaceAllUsesWith(warpOp.getResult(origIdx), - newWarpOp.getResult(newIdx)); - // Remove the original `WarpOp` and `ForOp`, they should not have any uses - // at this point. - rewriter.eraseOp(forOp); - rewriter.eraseOp(warpOp); // Update any users of escaping values that were forwarded to the // inner `WarpOp`. These values are now arguments of the inner `WarpOp`. newForOp.walk([&](Operation *op) { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp index 963b2c8..aa2dd89 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/TypeUtilities.h" +#include "llvm/ADT/STLExtras.h" #define DEBUG_TYPE "vector-drop-unit-dim" @@ -557,8 +558,7 @@ struct CastAwayConstantMaskLeadingOneDim // If any of the dropped unit dims has a size of `0`, the entire mask is a // zero mask, else the unit dim has no effect on the mask. int64_t flatLeadingSize = - std::accumulate(dimSizes.begin(), dimSizes.begin() + dropDim + 1, - static_cast<int64_t>(1), std::multiplies<int64_t>()); + llvm::product_of(dimSizes.take_front(dropDim + 1)); SmallVector<int64_t> newDimSizes = {flatLeadingSize}; newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end()); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 1b656d8..ea93085 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -817,6 +817,50 @@ struct LinearizeVectorToElements final } }; +/// Convert broadcasts from scalars or 1-element vectors, such as +/// +/// ```mlir +/// vector.broadcast %value : f32 to vector<4x4xf32> +/// ``` +/// +/// to broadcasts to rank-1 vectors, with shape_casts before/after as needed. 
+/// The above becomes, +/// +/// ```mlir +/// %out_1d = vector.broadcast %value : f32 to vector<16xf32> +/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32> +/// ``` +struct LinearizeVectorBroadcast final + : public OpConversionPattern<vector::BroadcastOp> { + using Base::Base; + + LinearizeVectorBroadcast(const TypeConverter &typeConverter, + MLIRContext *context, PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit) {} + + LogicalResult + matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + int numElements = 1; + Type sourceType = broadcastOp.getSourceType(); + if (auto vecType = dyn_cast<VectorType>(sourceType)) { + numElements = vecType.getNumElements(); + } + + if (numElements != 1) { + return rewriter.notifyMatchFailure( + broadcastOp, "only broadcasts of single elements can be linearized."); + } + + auto dstTy = getTypeConverter()->convertType(broadcastOp.getType()); + rewriter.replaceOpWithNewOp<vector::BroadcastOp>(broadcastOp, dstTy, + adaptor.getSource()); + + return success(); + } +}; + } // namespace /// This method defines the set of operations that are linearizable, and hence @@ -909,8 +953,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns( patterns .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast, LinearizeVectorCreateMask, LinearizeVectorLoad, LinearizeVectorStore, - LinearizeVectorFromElements, LinearizeVectorToElements>( - typeConverter, patterns.getContext()); + LinearizeVectorBroadcast, LinearizeVectorFromElements, + LinearizeVectorToElements>(typeConverter, patterns.getContext()); } void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index 14639c5..fbae098 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -465,26 +465,33 @@ struct UnrollElementwisePattern : public RewritePattern { auto targetShape = getTargetShape(options, op); if (!targetShape) return failure(); + int64_t targetShapeRank = targetShape->size(); auto dstVecType = cast<VectorType>(op->getResult(0).getType()); SmallVector<int64_t> originalSize = *cast<VectorUnrollOpInterface>(op).getShapeForUnroll(); - // Bail-out if rank(source) != rank(target). The main limitation here is the - // fact that `ExtractStridedSlice` requires the rank for the input and - // output to match. If needed, we can relax this later. - if (originalSize.size() != targetShape->size()) - return rewriter.notifyMatchFailure( - op, "expected input vector rank to match target shape rank"); + int64_t originalShapeRank = originalSize.size(); + Location loc = op->getLoc(); + + // Handle rank mismatch by adding leading unit dimensions to targetShape + SmallVector<int64_t> adjustedTargetShape(originalShapeRank); + int64_t rankDiff = originalShapeRank - targetShapeRank; + std::fill(adjustedTargetShape.begin(), + adjustedTargetShape.begin() + rankDiff, 1); + std::copy(targetShape->begin(), targetShape->end(), + adjustedTargetShape.begin() + rankDiff); + + int64_t adjustedTargetShapeRank = adjustedTargetShape.size(); // Prepare the result vector. 
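+    // The loop below walks the tiles of adjustedTargetShape. E.g. unrolling
+    // an elementwise op on vector<2x4xf32> with target shape [4] uses
+    // adjustedTargetShape [1, 4]: each vector<1x4xf32> slice is shape_cast to
+    // vector<4xf32>, computed, and inserted back at offsets [0, 0] and [1, 0].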
Value result = arith::ConstantOp::create(rewriter, loc, dstVecType, rewriter.getZeroAttr(dstVecType)); - SmallVector<int64_t> strides(targetShape->size(), 1); - VectorType newVecType = + SmallVector<int64_t> strides(adjustedTargetShapeRank, 1); + VectorType unrolledVecType = VectorType::get(*targetShape, dstVecType.getElementType()); // Create the unrolled computation. for (SmallVector<int64_t> offsets : - StaticTileOffsetRange(originalSize, *targetShape)) { + StaticTileOffsetRange(originalSize, adjustedTargetShape)) { SmallVector<Value> extractOperands; for (OpOperand &operand : op->getOpOperands()) { auto vecType = dyn_cast<VectorType>(operand.get().getType()); @@ -492,14 +499,31 @@ struct UnrollElementwisePattern : public RewritePattern { extractOperands.push_back(operand.get()); continue; } - extractOperands.push_back( - rewriter.createOrFold<vector::ExtractStridedSliceOp>( - loc, operand.get(), offsets, *targetShape, strides)); + Value extracted = rewriter.createOrFold<vector::ExtractStridedSliceOp>( + loc, operand.get(), offsets, adjustedTargetShape, strides); + + // Reshape to remove leading unit dims if needed + if (adjustedTargetShapeRank > targetShapeRank) { + extracted = rewriter.createOrFold<vector::ShapeCastOp>( + loc, VectorType::get(*targetShape, vecType.getElementType()), + extracted); + } + extractOperands.push_back(extracted); } + Operation *newOp = cloneOpWithOperandsAndTypes( - rewriter, loc, op, extractOperands, newVecType); + rewriter, loc, op, extractOperands, unrolledVecType); + + Value computeResult = newOp->getResult(0); + + // Use strides sized to targetShape for proper insertion + SmallVector<int64_t> insertStrides = + (adjustedTargetShapeRank > targetShapeRank) + ? SmallVector<int64_t>(targetShapeRank, 1) + : strides; + result = rewriter.createOrFold<vector::InsertStridedSliceOp>( - loc, newOp->getResult(0), result, offsets, strides); + loc, computeResult, result, offsets, insertStrides); } rewriter.replaceOp(op, result); return success(); diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index 025ee9a..c809c502 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -91,7 +91,7 @@ mlir::vector::isTranspose2DSlice(vector::TransposeOp op) { // Check whether the two source vector dimensions that are greater than one // must be transposed with each other so that we can apply one of the 2-D - // transpose pattens. Otherwise, these patterns are not applicable. + // transpose patterns. Otherwise, these patterns are not applicable. 
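+  // E.g. a transpose of vector<1x4x8xf32> with permutation [0, 2, 1] swaps
+  // the two non-unit dims and therefore qualifies as a 2-D transpose slice.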
if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1], op.getPermutation())) return failure(); diff --git a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp index 89b62a2..a514ea9 100644 --- a/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp +++ b/mlir/lib/Dialect/WasmSSA/IR/WasmSSAOps.cpp @@ -12,6 +12,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Region.h" #include "mlir/IR/SymbolTable.h" @@ -39,28 +40,6 @@ void printElseRegion(OpAsmPrinter &opPrinter, Operation *op, opPrinter.printKeywordOrString("else "); opPrinter.printRegion(elseRegion); } - -ParseResult parseWasmVisibility(OpAsmParser &opParser, StringAttr &visibility) { - std::string keyword; - auto initLocation = opParser.getCurrentLocation(); - std::ignore = opParser.parseOptionalKeywordOrString(&keyword); - if (keyword == "nested" or keyword == "") { - visibility = StringAttr::get(opParser.getContext(), "nested"); - return ParseResult::success(); - } - - if (keyword == "public" || keyword == "private") { - visibility = StringAttr::get(opParser.getContext(), keyword); - return ParseResult::success(); - } - opParser.emitError(initLocation, "expecting symbol visibility"); - return ParseResult::failure(); -} - -void printWasmVisibility(OpAsmPrinter &opPrinter, Operation *op, - Attribute visibility) { - opPrinter.printKeywordOrString(cast<StringAttr>(visibility).strref()); -} } // namespace #define GET_OP_CLASSES @@ -167,10 +146,23 @@ Block *FuncOp::addEntryBlock() { void FuncOp::build(OpBuilder &odsBuilder, OperationState &odsState, StringRef symbol, FunctionType funcType) { - FuncOp::build(odsBuilder, odsState, symbol, funcType, {}, {}, "nested"); + FuncOp::build(odsBuilder, odsState, symbol, funcType, {}, {}); } ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { + auto *ctx = parser.getContext(); + std::string visibilityString; + auto loc = parser.getNameLoc(); + ParseResult res = parser.parseOptionalKeywordOrString(&visibilityString); + bool exported{false}; + if (res.succeeded()) { + if (visibilityString != "exported") + return parser.emitError( + loc, "expecting either `exported` or symbol name. got ") + << visibilityString; + exported = true; + } + auto buildFuncType = [&parser](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, function_interface_impl::VariadicFlag, @@ -191,11 +183,13 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { return builder.getFunctionType(argTypesWithoutLocal, results); }; - - return function_interface_impl::parseFunctionOp( + auto funcParseRes = function_interface_impl::parseFunctionOp( parser, result, /*allowVariadic=*/false, getFunctionTypeAttrName(result.name), buildFuncType, getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); + if (exported) + result.addAttribute(getExportedAttrName(result.name), UnitAttr::get(ctx)); + return funcParseRes; } LogicalResult FuncOp::verifyBody() { @@ -224,9 +218,18 @@ LogicalResult FuncOp::verifyBody() { } void FuncOp::print(OpAsmPrinter &p) { + /// If exported, print it before and mask it before printing + /// using generic interface. 
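+  /// Dropping the unit attribute here keeps the generic function printer from
+  /// emitting it a second time in the attribute dictionary; setExported(true)
+  /// below restores it after printing.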
+ auto exported = getExported(); + if (exported) { + p << " exported"; + removeExportedAttr(); + } function_interface_impl::printFunctionOp( p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName()); + if (exported) + setExported(true); } //===----------------------------------------------------------------------===// @@ -237,38 +240,37 @@ void FuncImportOp::build(OpBuilder &odsBuilder, OperationState &odsState, StringRef symbol, StringRef moduleName, StringRef importName, FunctionType type) { FuncImportOp::build(odsBuilder, odsState, symbol, moduleName, importName, - type, {}, {}, odsBuilder.getStringAttr("nested")); + type, {}, {}); } //===----------------------------------------------------------------------===// // GlobalOp //===----------------------------------------------------------------------===// - -void GlobalOp::build(OpBuilder &odsBuilder, OperationState &odsState, - StringRef symbol, Type type, bool isMutable) { - GlobalOp::build(odsBuilder, odsState, symbol, type, isMutable, - odsBuilder.getStringAttr("nested")); -} - // Custom formats ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) { StringAttr symbolName; Type globalType; auto *ctx = parser.getContext(); - ParseResult res = parser.parseSymbolName( - symbolName, SymbolTable::getSymbolAttrName(), result.attributes); + std::string visibilityString; + auto loc = parser.getNameLoc(); + ParseResult res = parser.parseOptionalKeywordOrString(&visibilityString); + if (res.succeeded()) { + if (visibilityString != "exported") + return parser.emitError( + loc, "expecting either `exported` or symbol name. got ") + << visibilityString; + result.addAttribute(getExportedAttrName(result.name), UnitAttr::get(ctx)); + } + res = parser.parseSymbolName(symbolName, SymbolTable::getSymbolAttrName(), + result.attributes); res = parser.parseType(globalType); result.addAttribute(getTypeAttrName(result.name), TypeAttr::get(globalType)); std::string mutableString; res = parser.parseOptionalKeywordOrString(&mutableString); if (res.succeeded() && mutableString == "mutable") result.addAttribute("isMutable", UnitAttr::get(ctx)); - std::string visibilityString; - res = parser.parseOptionalKeywordOrString(&visibilityString); - if (res.succeeded()) - result.addAttribute("sym_visibility", - StringAttr::get(ctx, visibilityString)); + res = parser.parseColon(); Region *globalInitRegion = result.addRegion(); res = parser.parseRegion(*globalInitRegion); @@ -276,11 +278,11 @@ ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) { } void GlobalOp::print(OpAsmPrinter &printer) { + if (getExported()) + printer << " exported"; printer << " @" << getSymName().str() << " " << getType(); if (getIsMutable()) printer << " mutable"; - if (auto vis = getSymVisibility()) - printer << " " << *vis; printer << " :"; Region &body = getRegion(); if (!body.empty()) { @@ -319,13 +321,6 @@ GlobalGetOp::verifySymbolUses(SymbolTableCollection &symbolTable) { // GlobalImportOp //===----------------------------------------------------------------------===// -void GlobalImportOp::build(OpBuilder &odsBuilder, OperationState &odsState, - StringRef symbol, StringRef moduleName, - StringRef importName, Type type, bool isMutable) { - GlobalImportOp::build(odsBuilder, odsState, symbol, moduleName, importName, - type, isMutable, odsBuilder.getStringAttr("nested")); -} - ParseResult GlobalImportOp::parse(OpAsmParser &parser, OperationState &result) { auto *ctx = parser.getContext(); ParseResult res = 
parseImportOp(parser, result); @@ -335,12 +330,8 @@ ParseResult GlobalImportOp::parse(OpAsmParser &parser, OperationState &result) { res = parser.parseOptionalKeywordOrString(&mutableOrSymVisString); if (res.succeeded() && mutableOrSymVisString == "mutable") { result.addAttribute("isMutable", UnitAttr::get(ctx)); - res = parser.parseOptionalKeywordOrString(&mutableOrSymVisString); } - if (res.succeeded()) - result.addAttribute("sym_visibility", - StringAttr::get(ctx, mutableOrSymVisString)); res = parser.parseColon(); Type importedType; @@ -356,8 +347,6 @@ void GlobalImportOp::print(OpAsmPrinter &printer) { << "\" as @" << getSymName(); if (getIsMutable()) printer << " mutable"; - if (auto vis = getSymVisibility()) - printer << " " << *vis; printer << " : " << getType(); } @@ -431,27 +420,6 @@ LogicalResult LocalTeeOp::verify() { Block *LoopOp::getLabelTarget() { return &getBody().front(); } //===----------------------------------------------------------------------===// -// MemOp -//===----------------------------------------------------------------------===// - -void MemOp::build(OpBuilder &odsBuilder, OperationState &odsState, - StringRef symbol, LimitType limit) { - MemOp::build(odsBuilder, odsState, symbol, limit, - odsBuilder.getStringAttr("nested")); -} - -//===----------------------------------------------------------------------===// -// MemImportOp -//===----------------------------------------------------------------------===// - -void MemImportOp::build(OpBuilder &odsBuilder, OperationState &odsState, - StringRef symbol, StringRef moduleName, - StringRef importName, LimitType limits) { - MemImportOp::build(odsBuilder, odsState, symbol, moduleName, importName, - limits, odsBuilder.getStringAttr("nested")); -} - -//===----------------------------------------------------------------------===// // ReinterpretOp //===----------------------------------------------------------------------===// @@ -471,24 +439,3 @@ LogicalResult ReinterpretOp::verify() { //===----------------------------------------------------------------------===// void ReturnOp::build(OpBuilder &odsBuilder, OperationState &odsState) {} - -//===----------------------------------------------------------------------===// -// TableOp -//===----------------------------------------------------------------------===// - -void TableOp::build(OpBuilder &odsBuilder, OperationState &odsState, - StringRef symbol, TableType type) { - TableOp::build(odsBuilder, odsState, symbol, type, - odsBuilder.getStringAttr("nested")); -} - -//===----------------------------------------------------------------------===// -// TableImportOp -//===----------------------------------------------------------------------===// - -void TableImportOp::build(OpBuilder &odsBuilder, OperationState &odsState, - StringRef symbol, StringRef moduleName, - StringRef importName, TableType type) { - TableImportOp::build(odsBuilder, odsState, symbol, moduleName, importName, - type, odsBuilder.getStringAttr("nested")); -} diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 9beb22d..1599ae9 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -727,6 +727,152 @@ void MemLayoutAttr::print(AsmPrinter &printer) const { } printer << ">"; } +// a helper utility to perform binary operation on OpFoldResult. +// If both a and b are attributes, it will simply return the result. 
+// Otherwise, the corresponding arith op will be generated, and a +// constant op will be created if one of them is an attribute. +template <typename ArithOp> +OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc, + OpBuilder &builder) { + auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a); + auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b); + return builder.create<ArithOp>(loc, aVal, bVal).getResult(); +} + +// a helper utility to perform division operation on OpFoldResult and int64_t. +#define div(a, b) \ + genBinOp<arith::DivSIOp>(a, builder.getIndexAttr(b), loc, builder) + +// a helper utility to perform remainder operation on OpFoldResult and int64_t. +#define rem(a, b) \ + genBinOp<arith::RemSIOp>(a, builder.getIndexAttr(b), loc, builder) + +// a helper utility to perform multiply operation on OpFoldResult and int64_t. +#define mul(a, b) \ + genBinOp<arith::MulIOp>(a, builder.getIndexAttr(b), loc, builder) + +// a helper utility to perform addition operation on two OpFoldResult. +#define add(a, b) genBinOp<arith::AddIOp>(a, b, loc, builder) + +// block the given offsets according to the block shape +// say the original offset is [y, x], and the block shape is [By, Bx], +// then the blocked offset is [y/By, x/Bx, y%By, x%Bx] +SmallVector<OpFoldResult> getBlockedOffsets(OpBuilder &builder, Location loc, + ArrayRef<OpFoldResult> offsets, + ArrayRef<int64_t> blockShape) { + + assert(offsets.size() == blockShape.size() && + "offsets and blockShape must have the same size"); + SmallVector<OpFoldResult> blockedOffsets; + SmallVector<OpFoldResult> divs, rems; + + for (auto [offset, block] : llvm::zip(offsets, blockShape)) { + divs.push_back(div(offset, block)); + rems.push_back(rem(offset, block)); + } + blockedOffsets.append(divs.begin(), divs.end()); + blockedOffsets.append(rems.begin(), rems.end()); + + return blockedOffsets; +} + +// Get strides as a vector of integers for MemDesc. +SmallVector<int64_t> MemDescType::getStrideShape() { + + SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end()); + + ArrayAttr strideAttr = getStrideAttr(); + SmallVector<int64_t> strides; + for (Attribute attr : strideAttr.getValue()) { + strides.push_back(cast<IntegerAttr>(attr).getInt()); + } + + SmallVector<int64_t> innerBlkShape = getBlockShape(); + + // get perm from FCD to LCD + // perm[i] = the dim with i-th smallest stride + SmallVector<int, 4> perm = + llvm::to_vector<4>(llvm::seq<int>(0, strides.size())); + llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; }); + + assert(strides[perm[0]] == 1 && "inner most dim must have stride 1"); + + SmallVector<int64_t> innerBlkStride(innerBlkShape.size()); + innerBlkStride[perm[0]] = 1; + for (size_t i = 1; i < perm.size(); ++i) + innerBlkStride[perm[i]] = + innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]]; + + // compute the original matrix shape using the stride info + // and compute the number of blocks in each dimension + // The shape of highest dim can't be derived from stride info, + // but doesn't impact the stride computation for blocked layout. 
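+  // E.g. for a 32x64 mem_desc with row-major strides [64, 1] and block shape
+  // [8, 16]: the within-block strides are [16, 1], each row of blocks holds 4
+  // blocks of 8x16 = 128 elements, and the returned blocked strides are
+  // [512, 128, 16, 1] (outer block strides followed by inner element strides).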
+ SmallVector<int64_t> matrixShapeOrig(matrixShape.size()); + SmallVector<int64_t> BlkShapeOrig(matrixShape.size()); + for (size_t i = 0; i < perm.size() - 1; ++i) { + matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]]; + BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]]; + } + + int64_t innerBlkSize = 1; + for (auto s : innerBlkShape) + innerBlkSize *= s; + + SmallVector<int64_t> outerBlkStride(matrixShape.size()); + outerBlkStride[perm[0]] = innerBlkSize; + for (size_t i = 0; i < perm.size() - 1; ++i) { + outerBlkStride[perm[i + 1]] = + outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]]; + } + + // combine the inner and outer strides + SmallVector<int64_t> blockedStrides; + blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end()); + blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end()); + + return blockedStrides; +} + +// Calculate the linear offset using the blocked offsets and stride +Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc, + ArrayRef<OpFoldResult> offsets) { + + SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end()); + SmallVector<int64_t> blockShape = getBlockShape(); + SmallVector<int64_t> strides = getStrideShape(); + SmallVector<OpFoldResult> blockedOffsets; + + // blockshape equal to matrixshape means no blocking + if (llvm::equal(blockShape, matrixShape)) { + // remove the outer dims from strides + strides.erase(strides.begin(), strides.begin() + matrixShape.size()); + } else { + assert(offsets.size() == blockShape.size() && + "offsets and blockShape must have the same size"); + // say the original offset is [y, x], and the block shape is [By, Bx], + // then the blocked offset is [y/By, x/Bx, y%By, x%Bx] + + SmallVector<OpFoldResult> divs, rems; + + for (auto [offset, block] : llvm::zip(offsets, blockShape)) { + divs.push_back(div(offset, block)); + rems.push_back(rem(offset, block)); + } + blockedOffsets.append(divs.begin(), divs.end()); + blockedOffsets.append(rems.begin(), rems.end()); + offsets = blockedOffsets; + } + + // Start with initial value as matrix descriptor's base offset. 
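+  // The loop below then accumulates offsets[i] * strides[i] on top of that
+  // base, i.e. for blocked offsets [y/By, x/Bx, y%By, x%Bx] the linear offset
+  // is their dot product with the blocked strides.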
+ Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0); + for (size_t i = 0; i < offsets.size(); ++i) { + OpFoldResult mulResult = mul(offsets[i], strides[i]); + Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult); + linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset); + } + + return linearOffset; +} } // namespace xegpu } // namespace mlir diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 81b5788..abd12e2 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -20,8 +20,8 @@ #define DEBUG_TYPE "xegpu" -namespace mlir { -namespace xegpu { +using namespace mlir; +using namespace mlir::xegpu; static bool isSharedMemory(const MemRefType &memrefTy) { Attribute attr = memrefTy.getMemorySpace(); @@ -173,6 +173,49 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy, return success(); } +LogicalResult +IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy, + UnitAttr subgroup_block_io, + function_ref<InFlightDiagnostic()> emitError) { + + if (!dataTy) { + if (subgroup_block_io) + return emitError() << "subgroup_block_io " + "are only allowed when result is a 1D VectorType."; + else + return success(); + } + + if (mdescTy.getRank() != 2) + return emitError() << "mem_desc must be 2D."; + + ArrayRef<int64_t> dataShape = dataTy.getShape(); + ArrayRef<int64_t> mdescShape = mdescTy.getShape(); + + if (dataShape.size() == 2) { + if (subgroup_block_io) + return emitError() << "subgroup_block_io " + "are only allowed when result is a 1D VectorType."; + if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape), + [](auto p) { return std::get<0>(p) > std::get<1>(p); })) + return emitError() << "data shape must not exceed mem_desc shape."; + } else { + SmallVector<int64_t> blockShape = mdescTy.getBlockShape(); + // if the subgroup_block_io attribute is set, mdescTy must have block + // attribute + if (subgroup_block_io && !blockShape.size()) + return emitError() << "mem_desc must have block attribute when " + "subgroup_block_io is set."; + // if the subgroup_block_io attribute is set, the memdesc should be row + // major + if (subgroup_block_io && mdescTy.isColMajor()) + return emitError() << "mem_desc should be row major when " + "subgroup_block_io is set."; + } + + return success(); +} + //===----------------------------------------------------------------------===// // XeGPU_CreateNdDescOp //===----------------------------------------------------------------------===// @@ -1049,23 +1092,20 @@ void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res, llvm::SmallVector<int64_t> staticOffsets; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); + // Call the generated builder with all parameters (including optional ones as + // nullptr/empty) build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr, - layout); + /*subgroup_block_io=*/nullptr, layout); } LogicalResult LoadMatrixOp::verify() { - VectorType resTy = getRes().getType(); - MemDescType mdescTy = getMemDesc().getType(); - if (mdescTy.getRank() != 2) - return emitOpError("mem_desc must be 2D."); + auto resTy = dyn_cast<VectorType>(getRes().getType()); + UnitAttr subgroup_block_io = getSubgroupBlockIoAttr(); + MemDescType mdescTy = getMemDesc().getType(); - ArrayRef<int64_t> valueShape = resTy.getShape(); - ArrayRef<int64_t> mdescShape = mdescTy.getShape(); - if 
(llvm::any_of(llvm::zip_equal(valueShape, mdescShape), - [](auto p) { return std::get<0>(p) > std::get<1>(p); })) - return emitOpError("result shape must not exceed mem_desc shape."); - return success(); + return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io, + [&]() { return emitError(); }); } //===----------------------------------------------------------------------===// @@ -1080,62 +1120,18 @@ void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data, dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr, - layout); + /*subgroup_block_io=*/nullptr, layout); } LogicalResult StoreMatrixOp::verify() { - VectorType dataTy = getData().getType(); - MemDescType mdescTy = getMemDesc().getType(); - - if (mdescTy.getRank() != 2) - return emitOpError("mem_desc must be 2D."); - - ArrayRef<int64_t> dataShape = dataTy.getShape(); - ArrayRef<int64_t> mdescShape = mdescTy.getShape(); - if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape), - [](auto p) { return std::get<0>(p) > std::get<1>(p); })) - return emitOpError("data shape must not exceed mem_desc shape."); - - return success(); -} - -//===----------------------------------------------------------------------===// -// XeGPU_MemDescSubviewOp -//===----------------------------------------------------------------------===// - -void MemDescSubviewOp::build(OpBuilder &builder, OperationState &state, - Type resTy, Value src, - llvm::ArrayRef<OpFoldResult> offsets) { - llvm::SmallVector<Value> dynamicOffsets; - llvm::SmallVector<int64_t> staticOffsets; - dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); - auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); - build(builder, state, resTy, src, dynamicOffsets, staticOffsetsAttr); -} - -LogicalResult MemDescSubviewOp::verify() { - MemDescType srcTy = getSrc().getType(); - MemDescType resTy = getRes().getType(); - ArrayRef<int64_t> srcShape = srcTy.getShape(); - ArrayRef<int64_t> resShape = resTy.getShape(); - - if (srcTy.getRank() < resTy.getRank()) - return emitOpError("result rank must not exceed source rank."); - if (llvm::any_of( - llvm::zip_equal(resShape, srcShape.take_back(resShape.size())), - [](auto p) { return std::get<0>(p) > std::get<1>(p); })) - return emitOpError("result shape must not exceed source shape."); - - if (srcTy.getStrides() != resTy.getStrides()) - return emitOpError("result must inherit the source strides."); - - return success(); + auto dataTy = dyn_cast<VectorType>(getData().getType()); + UnitAttr subgroup_block_io = getSubgroupBlockIoAttr(); + MemDescType mdescTy = getMemDesc().getType(); + return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io, + [&]() { return emitError(); }); } -} // namespace xegpu -} // namespace mlir - namespace mlir { #include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc> } // namespace mlir diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp index 36c498e..f77784a 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp @@ -161,11 +161,24 @@ XeGPUBlockingPass::getTileShape(Operation *op) const { xegpu::UpdateOffsetOp, xegpu::LoadMatrixOp>(op)) return getTileShape(op->getOpResult(0)); if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp, - xegpu::LoadGatherOp, 
xegpu::StoreMatrixOp>(op)) + xegpu::StoreMatrixOp>(op)) return getTileShape(op->getOpOperand(0)); - if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(op)) + if (isa<xegpu::StoreNdOp>(op)) return getTileShape(op->getOpOperand(1)); + // Handle LoadGatherOp and StoreScatterOp (with and without offset) + if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(op)) { + if (loadGatherOp.getOffsets()) + return getTileShape(loadGatherOp->getOpResult(0)); + else + return getTileShape(loadGatherOp->getOpOperand(0)); + } + + if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op)) + return getTileShape(storeScatterOp.getOffsets() + ? storeScatterOp->getOpOperand(0) + : storeScatterOp->getOpOperand(1)); + if (isa<xegpu::DpasOp>(op)) { std::optional<SmallVector<int64_t>> aTile = getTileShape(op->getOpOperand(0)); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index a178d0f..aafa1b7 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -941,7 +941,9 @@ struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> { LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - VectorType valueTy = op.getType(); + VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType()); + assert(valueTy && "the value type must be vector type!"); + std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op); if (!targetShape || targetShape->size() != (size_t)valueTy.getRank()) return failure(); @@ -984,7 +986,8 @@ struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> { return failure(); Location loc = op.getLoc(); - VectorType valueTy = op.getData().getType(); + VectorType valueTy = llvm::dyn_cast<VectorType>(op.getData().getType()); + assert(valueTy && "the value type must be vector type!"); ArrayRef<int64_t> shape = valueTy.getShape(); auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr()); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index c28d2fc..31a967d 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -991,7 +991,8 @@ struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> { return failure(); ArrayRef<int64_t> wgShape = op.getDataShape(); - VectorType valueTy = op.getRes().getType(); + VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType()); + assert(valueTy && "the value type must be vector type!"); Type elemTy = valueTy.getElementType(); xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp index b72d564..2c56a43 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp +++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp @@ -52,8 +52,7 @@ mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) { // compute sgSize by multiply elements of laneLayout // e.g. for 2D layout, sgSize = laneLayout[0] * laneLayout[1] // e.g. 
for 1D layout, sgSize = laneLayout[0] - auto sgSize = std::accumulate(laneLayout.begin(), laneLayout.end(), 1, - std::multiplies<int64_t>()); + int64_t sgSize = llvm::product_of(laneLayout); // Case 1: regular loads/stores auto scatterAttr = tdescTy.getEncodingOfType<ScatterTensorDescAttr>(); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 3d19c5a..9b23dd6 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -2200,10 +2200,9 @@ void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty, os << '>'; } os << '['; - interleave( - loc.getLocations(), - [&](Location loc) { printLocationInternal(loc, pretty); }, - [&]() { os << ", "; }); + interleaveComma(loc.getLocations(), [&](Location loc) { + printLocationInternal(loc, pretty); + }); os << ']'; }) .Default([&](LocationAttr loc) { diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp index 776b5c6..4d81918 100644 --- a/mlir/lib/IR/Diagnostics.cpp +++ b/mlir/lib/IR/Diagnostics.cpp @@ -378,8 +378,10 @@ struct SourceMgrDiagnosticHandlerImpl { } // Otherwise, try to load the source file. - std::string ignored; - unsigned id = mgr.AddIncludeFile(std::string(filename), SMLoc(), ignored); + auto bufferOrErr = llvm::MemoryBuffer::getFile(filename); + if (!bufferOrErr) + return 0; + unsigned id = mgr.AddNewSourceBuffer(std::move(*bufferOrErr), SMLoc()); filenameToBufId[filename] = id; return id; } diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 1fa04ed..5f63fe6 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -121,6 +121,11 @@ namespace mlir { class MLIRContextImpl { public: //===--------------------------------------------------------------------===// + // Remark + //===--------------------------------------------------------------------===// + std::unique_ptr<remark::detail::RemarkEngine> remarkEngine; + + //===--------------------------------------------------------------------===// // Debugging //===--------------------------------------------------------------------===// @@ -135,11 +140,6 @@ public: DiagnosticEngine diagEngine; //===--------------------------------------------------------------------===// - // Remark - //===--------------------------------------------------------------------===// - std::unique_ptr<remark::detail::RemarkEngine> remarkEngine; - - //===--------------------------------------------------------------------===// // Options //===--------------------------------------------------------------------===// @@ -357,7 +357,10 @@ MLIRContext::MLIRContext(const DialectRegistry ®istry, Threading setting) impl->affineUniquer.registerParametricStorageType<IntegerSetStorage>(); } -MLIRContext::~MLIRContext() = default; +MLIRContext::~MLIRContext() { + // finalize remark engine before destroying anything else. + impl->remarkEngine.reset(); +} /// Copy the specified array of elements into memory managed by the provided /// bump pointer allocator. This assumes the elements are all PODs. @@ -1201,7 +1204,7 @@ AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount, /// present in result expressions is less than `dimCount` and the highest index /// of symbolic identifier present in result expressions is less than /// `symbolCount`. 
-LLVM_ATTRIBUTE_UNUSED static bool +[[maybe_unused]] static bool willBeValidAffineMap(unsigned dimCount, unsigned symbolCount, ArrayRef<AffineExpr> results) { int64_t maxDimPosition = -1; diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 8bcfa46..ce421f4 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/FoldInterfaces.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/ErrorHandling.h" #include <numeric> @@ -1274,10 +1275,7 @@ LogicalResult OpTrait::impl::verifyValueSizeAttr(Operation *op, return op->emitOpError("'") << attrName << "' attribute cannot have negative elements"; - size_t totalCount = - std::accumulate(sizes.begin(), sizes.end(), 0, - [](unsigned all, int32_t one) { return all + one; }); - + size_t totalCount = llvm::sum_of(sizes, size_t(0)); if (totalCount != expectedCount) return op->emitOpError() << valueGroupName << " count (" << expectedCount diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp index 394ac77..2a37f38 100644 --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -406,15 +406,13 @@ OperandRangeRange::OperandRangeRange(OperandRange operands, OperandRange OperandRangeRange::join() const { const OwnerT &owner = getBase(); ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(owner.second); - return OperandRange(owner.first, - std::accumulate(sizeData.begin(), sizeData.end(), 0)); + return OperandRange(owner.first, llvm::sum_of(sizeData)); } OperandRange OperandRangeRange::dereference(const OwnerT &object, ptrdiff_t index) { ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(object.second); - uint32_t startIndex = - std::accumulate(sizeData.begin(), sizeData.begin() + index, 0); + uint32_t startIndex = llvm::sum_of(sizeData.take_front(index)); return OperandRange(object.first + startIndex, *(sizeData.begin() + index)); } @@ -565,8 +563,7 @@ MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object, ptrdiff_t index) { ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(object.second.getValue()); - uint32_t startIndex = - std::accumulate(sizeData.begin(), sizeData.begin() + index, 0); + uint32_t startIndex = llvm::sum_of(sizeData.take_front(index)); return object.first.slice( startIndex, *(sizeData.begin() + index), MutableOperandRange::OperandSegment(index, object.second)); diff --git a/mlir/lib/IR/Remarks.cpp b/mlir/lib/IR/Remarks.cpp index a55f61a..031eae2 100644 --- a/mlir/lib/IR/Remarks.cpp +++ b/mlir/lib/IR/Remarks.cpp @@ -16,7 +16,7 @@ #include "llvm/ADT/StringRef.h" using namespace mlir::remark::detail; - +using namespace mlir::remark; //------------------------------------------------------------------------------ // Remark //------------------------------------------------------------------------------ @@ -70,7 +70,7 @@ static void printArgs(llvm::raw_ostream &os, llvm::ArrayRef<Remark::Arg> args) { void Remark::print(llvm::raw_ostream &os, bool printLocation) const { // Header: [Type] pass:remarkName StringRef type = getRemarkTypeString(); - StringRef categoryName = getFullCategoryName(); + StringRef categoryName = getCombinedCategoryName(); StringRef name = remarkName; os << '[' << type << "] "; @@ -81,9 +81,10 @@ void Remark::print(llvm::raw_ostream &os, bool printLocation) const { os << "Function=" << getFunction() << " | "; if (printLocation) { - if (auto flc = 
mlir::dyn_cast<mlir::FileLineColLoc>(getLocation())) + if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(getLocation())) { os << " @" << flc.getFilename() << ":" << flc.getLine() << ":" << flc.getColumn(); + } } printArgs(os, getArgs()); @@ -140,7 +141,7 @@ llvm::remarks::Remark Remark::generateRemark() const { r.RemarkType = getRemarkType(); r.RemarkName = getRemarkName(); // MLIR does not use passes; instead, it has categories and sub-categories. - r.PassName = getFullCategoryName(); + r.PassName = getCombinedCategoryName(); r.FunctionName = getFunction(); r.Loc = locLambda(); for (const Remark::Arg &arg : getArgs()) { @@ -225,26 +226,42 @@ InFlightRemark RemarkEngine::emitOptimizationRemarkAnalysis(Location loc, // RemarkEngine //===----------------------------------------------------------------------===// -void RemarkEngine::report(const Remark &&remark) { +void RemarkEngine::reportImpl(const Remark &remark) { // Stream the remark - if (remarkStreamer) + if (remarkStreamer) { remarkStreamer->streamOptimizationRemark(remark); + } // Print using MLIR's diagnostic if (printAsEmitRemarks) emitRemark(remark.getLocation(), remark.getMsg()); } +void RemarkEngine::report(const Remark &&remark) { + if (remarkEmittingPolicy) + remarkEmittingPolicy->reportRemark(remark); +} + RemarkEngine::~RemarkEngine() { + if (remarkEmittingPolicy) + remarkEmittingPolicy->finalize(); + if (remarkStreamer) remarkStreamer->finalize(); } -llvm::LogicalResult -RemarkEngine::initialize(std::unique_ptr<MLIRRemarkStreamerBase> streamer, - std::string *errMsg) { - // If you need to validate categories/filters, do so here and set errMsg. +llvm::LogicalResult RemarkEngine::initialize( + std::unique_ptr<MLIRRemarkStreamerBase> streamer, + std::unique_ptr<RemarkEmittingPolicyBase> remarkEmittingPolicy, + std::string *errMsg) { + remarkStreamer = std::move(streamer); + + auto reportFunc = + std::bind(&RemarkEngine::reportImpl, this, std::placeholders::_1); + remarkEmittingPolicy->initialize(ReportFn(std::move(reportFunc))); + + this->remarkEmittingPolicy = std::move(remarkEmittingPolicy); return success(); } @@ -301,14 +318,15 @@ RemarkEngine::RemarkEngine(bool printAsEmitRemarks, } llvm::LogicalResult mlir::remark::enableOptimizationRemarks( - MLIRContext &ctx, - std::unique_ptr<remark::detail::MLIRRemarkStreamerBase> streamer, - const remark::RemarkCategories &cats, bool printAsEmitRemarks) { + MLIRContext &ctx, std::unique_ptr<detail::MLIRRemarkStreamerBase> streamer, + std::unique_ptr<detail::RemarkEmittingPolicyBase> remarkEmittingPolicy, + const RemarkCategories &cats, bool printAsEmitRemarks) { auto engine = - std::make_unique<remark::detail::RemarkEngine>(printAsEmitRemarks, cats); + std::make_unique<detail::RemarkEngine>(printAsEmitRemarks, cats); std::string errMsg; - if (failed(engine->initialize(std::move(streamer), &errMsg))) { + if (failed(engine->initialize(std::move(streamer), + std::move(remarkEmittingPolicy), &errMsg))) { llvm::report_fatal_error( llvm::Twine("Failed to initialize remark engine. 
Error: ") + errMsg); } @@ -316,3 +334,12 @@ llvm::LogicalResult mlir::remark::enableOptimizationRemarks( return success(); } + +//===----------------------------------------------------------------------===// +// Remark emitting policies +//===----------------------------------------------------------------------===// + +namespace mlir::remark { +RemarkEmittingPolicyAll::RemarkEmittingPolicyAll() = default; +RemarkEmittingPolicyFinal::RemarkEmittingPolicyFinal() = default; +} // namespace mlir::remark diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp index d2d115e..e438631 100644 --- a/mlir/lib/IR/TypeUtilities.cpp +++ b/mlir/lib/IR/TypeUtilities.cpp @@ -104,8 +104,8 @@ LogicalResult mlir::verifyCompatibleShapes(TypeRange types1, TypeRange types2) { LogicalResult mlir::verifyCompatibleDims(ArrayRef<int64_t> dims) { if (dims.empty()) return success(); - auto staticDim = std::accumulate( - dims.begin(), dims.end(), dims.front(), [](auto fold, auto dim) { + auto staticDim = + llvm::accumulate(dims, dims.front(), [](auto fold, auto dim) { return ShapedType::isDynamic(dim) ? fold : dim; }); return success(llvm::all_of(dims, [&](auto dim) { diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt index 388de1c..f96af02 100644 --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -9,6 +9,7 @@ set(LLVM_OPTIONAL_SOURCES FunctionInterfaces.cpp IndexingMapOpInterface.cpp InferIntRangeInterface.cpp + InferStridedMetadataInterface.cpp InferTypeOpInterface.cpp LoopLikeInterface.cpp MemOpInterfaces.cpp @@ -64,6 +65,21 @@ add_mlir_library(MLIRFunctionInterfaces add_mlir_interface_library(IndexingMapOpInterface) add_mlir_interface_library(InferIntRangeInterface) + +add_mlir_library(MLIRInferStridedMetadataInterface + InferStridedMetadataInterface.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces + + DEPENDS + MLIRInferStridedMetadataInterfaceIncGen + + LINK_LIBS PUBLIC + MLIRInferIntRangeInterface + MLIRIR +) + add_mlir_interface_library(InferTypeOpInterface) add_mlir_library(MLIRLoopLikeInterface diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp index 9f3e97d..84fc9b8 100644 --- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp +++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp @@ -146,6 +146,25 @@ raw_ostream &mlir::operator<<(raw_ostream &os, const IntegerValueRange &range) { return os; } +SmallVector<IntegerValueRange> +mlir::getIntValueRanges(ArrayRef<OpFoldResult> values, + GetIntRangeFn getIntRange, int32_t indexBitwidth) { + SmallVector<IntegerValueRange> ranges; + ranges.reserve(values.size()); + for (OpFoldResult ofr : values) { + if (auto value = dyn_cast<Value>(ofr)) { + ranges.push_back(getIntRange(value)); + continue; + } + + // Create a constant range. 
+ auto attr = cast<IntegerAttr>(cast<Attribute>(ofr)); + ranges.emplace_back(ConstantIntRanges::constant( + attr.getValue().sextOrTrunc(indexBitwidth))); + } + return ranges; +} + void mlir::intrange::detail::defaultInferResultRanges( InferIntRangeInterface interface, ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRanges) { diff --git a/mlir/lib/Interfaces/InferStridedMetadataInterface.cpp b/mlir/lib/Interfaces/InferStridedMetadataInterface.cpp new file mode 100644 index 0000000..483e9f1 --- /dev/null +++ b/mlir/lib/Interfaces/InferStridedMetadataInterface.cpp @@ -0,0 +1,36 @@ +//===- InferStridedMetadataInterface.cpp - Strided md inference interface -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/InferStridedMetadataInterface.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include <optional> + +using namespace mlir; + +#include "mlir/Interfaces/InferStridedMetadataInterface.cpp.inc" + +void StridedMetadataRange::print(raw_ostream &os) const { + if (isUninitialized()) { + os << "strided_metadata<None>"; + return; + } + os << "strided_metadata<offset = ["; + llvm::interleaveComma(*offsets, os, [&](const ConstantIntRanges &range) { + os << "{" << range << "}"; + }); + os << "], sizes = ["; + llvm::interleaveComma(sizes, os, [&](const ConstantIntRanges &range) { + os << "{" << range << "}"; + }); + os << "], strides = ["; + llvm::interleaveComma(strides, os, [&](const ConstantIntRanges &range) { + os << "{" << range << "}"; + }); + os << "]>"; +} diff --git a/mlir/lib/RegisterAllPasses.cpp b/mlir/lib/RegisterAllPasses.cpp index c67b242..dd413d2de 100644 --- a/mlir/lib/RegisterAllPasses.cpp +++ b/mlir/lib/RegisterAllPasses.cpp @@ -98,4 +98,5 @@ void mlir::registerAllPasses() { sparse_tensor::registerSparseTensorPipelines(); tosa::registerTosaToLinalgPipelines(); gpu::registerGPUToNVVMPipeline(); + gpu::registerGPUToXeVMPipeline(); } diff --git a/mlir/lib/Remark/RemarkStreamer.cpp b/mlir/lib/Remark/RemarkStreamer.cpp index d213a1a..bf36286 100644 --- a/mlir/lib/Remark/RemarkStreamer.cpp +++ b/mlir/lib/Remark/RemarkStreamer.cpp @@ -60,6 +60,7 @@ void LLVMRemarkStreamer::finalize() { namespace mlir::remark { LogicalResult enableOptimizationRemarksWithLLVMStreamer( MLIRContext &ctx, StringRef path, llvm::remarks::Format fmt, + std::unique_ptr<detail::RemarkEmittingPolicyBase> remarkEmittingPolicy, const RemarkCategories &cat, bool printAsEmitRemarks) { FailureOr<std::unique_ptr<detail::MLIRRemarkStreamerBase>> sOr = @@ -67,7 +68,8 @@ LogicalResult enableOptimizationRemarksWithLLVMStreamer( if (failed(sOr)) return failure(); - return remark::enableOptimizationRemarks(ctx, std::move(*sOr), cat, + return remark::enableOptimizationRemarks(ctx, std::move(*sOr), + std::move(remarkEmittingPolicy), cat, printAsEmitRemarks); } diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp index 33fbd2a..42843ea 100644 --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -1835,8 +1835,7 @@ executeGetOperandsResults(RangeT values, Operation *op, unsigned index, return nullptr; ArrayRef<int32_t> segments = segmentAttr; - unsigned startIndex = - std::accumulate(segments.begin(), segments.begin() + index, 0); + unsigned startIndex = 
llvm::sum_of(segments.take_front(index)); values = values.slice(startIndex, *std::next(segments.begin(), index)); LDBG() << " * Extracting range[" << startIndex << ", " diff --git a/mlir/lib/TableGen/CodeGenHelpers.cpp b/mlir/lib/TableGen/CodeGenHelpers.cpp index cb90ef8..d52d5e7 100644 --- a/mlir/lib/TableGen/CodeGenHelpers.cpp +++ b/mlir/lib/TableGen/CodeGenHelpers.cpp @@ -49,9 +49,7 @@ StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter( raw_ostream &os, const RecordKeeper &records, StringRef tag) : os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {} -void StaticVerifierFunctionEmitter::emitOpConstraints( - ArrayRef<const Record *> opDefs) { - NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace()); +void StaticVerifierFunctionEmitter::emitOpConstraints() { emitTypeConstraints(); emitAttrConstraints(); emitPropConstraints(); diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 5fe5f41..1243511 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -357,11 +357,6 @@ static bool shouldBeInlined(ExpressionOp expressionOp) { if (expressionOp.getDoNotInline()) return false; - // Do not inline expressions with side effects to prevent side-effect - // reordering. - if (expressionOp.hasSideEffects()) - return false; - // Do not inline expressions with multiple uses. Value result = expressionOp.getResult(); if (!result.hasOneUse()) @@ -377,7 +372,34 @@ static bool shouldBeInlined(ExpressionOp expressionOp) { // Do not inline expressions used by other expressions or by ops with the // CExpressionInterface. If this was intended, the user could have been merged // into the expression op. - return !isa<emitc::ExpressionOp, emitc::CExpressionInterface>(*user); + if (isa<emitc::ExpressionOp, emitc::CExpressionInterface>(*user)) + return false; + + // Expressions with no side-effects can safely be inlined. + if (!expressionOp.hasSideEffects()) + return true; + + // Expressions with side-effects can be only inlined if side-effect ordering + // in the program is provably retained. + + // Require the user to immediately follow the expression. + if (++Block::iterator(expressionOp) != Block::iterator(user)) + return false; + + // These single-operand ops are safe. + if (isa<emitc::IfOp, emitc::SwitchOp, emitc::ReturnOp>(user)) + return true; + + // For assignment look for specific cases to inline as evaluation order of + // its lvalue and rvalue is undefined in C. + if (auto assignOp = dyn_cast<emitc::AssignOp>(user)) { + // Inline if this assignment is of the form `<var> = <expression>`. + if (expressionOp.getResult() == assignOp.getValue() && + isa_and_present<VariableOp>(assignOp.getVar().getDefiningOp())) + return true; + } + + return false; } static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, diff --git a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp index e3f075f..8ecb084 100644 --- a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp +++ b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp @@ -464,12 +464,6 @@ static std::string generateOpDefinition(irdl::detail::dictionary &dict, auto opStrings = getStrings(op); fillDict(dict, opStrings); - const auto operandCount = opStrings.opOperandNames.size(); - const auto operandNames = - operandCount ? 
joinNameList(opStrings.opOperandNames) : "{\"\"}"; - - const auto resultNames = joinNameList(opStrings.opResultNames); - auto resultTypes = llvm::join( llvm::map_range(opStrings.opResultNames, [](StringRef attr) -> std::string { diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp index 4bbcd8e..db39c70 100644 --- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp +++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp @@ -34,11 +34,9 @@ Location DebugImporter::translateFuncLocation(llvm::Function *func) { return UnknownLoc::get(context); // Add a fused location to link the subprogram information. - StringAttr funcName = StringAttr::get(context, subprogram->getName()); StringAttr fileName = StringAttr::get(context, subprogram->getFilename()); return FusedLocWith<DISubprogramAttr>::get( - {NameLoc::get(funcName), - FileLineColLoc::get(fileName, subprogram->getLine(), /*column=*/0)}, + {FileLineColLoc::get(fileName, subprogram->getLine(), /*column=*/0)}, translate(subprogram), context); } diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 1e2099d..8de49dd 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -246,7 +246,7 @@ public: // Rewrite all uses of the original variable in `BBName` // with the linear variable in-place - void rewriteInPlace(llvm::IRBuilderBase &builder, std::string BBName, + void rewriteInPlace(llvm::IRBuilderBase &builder, const std::string &BBName, size_t varIndex) { llvm::SmallVector<llvm::User *> users; for (llvm::User *user : linearOrigVal[varIndex]->users()) diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 9603813..857e31b 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -2604,6 +2604,7 @@ static constexpr std::array kExplicitLLVMFuncOpAttributes{ StringLiteral("denormal-fp-math-f32"), StringLiteral("fp-contract"), StringLiteral("frame-pointer"), + StringLiteral("inlinehint"), StringLiteral("instrument-function-entry"), StringLiteral("instrument-function-exit"), StringLiteral("memory"), @@ -2643,6 +2644,8 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func, funcOp.setNoInline(true); if (func->hasFnAttribute(llvm::Attribute::AlwaysInline)) funcOp.setAlwaysInline(true); + if (func->hasFnAttribute(llvm::Attribute::InlineHint)) + funcOp.setInlineHint(true); if (func->hasFnAttribute(llvm::Attribute::OptimizeNone)) funcOp.setOptimizeNone(true); if (func->hasFnAttribute(llvm::Attribute::Convergent)) diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 5a3eb20..147613f 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -922,8 +922,7 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall( assert(opBundleSizes.size() == opBundleTagsAttr.size() && "operand bundles and tags do not match"); - numOpBundleOperands = - std::accumulate(opBundleSizes.begin(), opBundleSizes.end(), size_t(0)); + numOpBundleOperands = llvm::sum_of(opBundleSizes); assert(numOpBundleOperands <= intrOp->getNumOperands() && "operand bundle operands is more than the number of operands"); @@ -1653,6 +1652,8 @@ static void convertFunctionAttributes(LLVMFuncOp func, llvmFunc->addFnAttr(llvm::Attribute::NoInline); if 
(func.getAlwaysInlineAttr()) llvmFunc->addFnAttr(llvm::Attribute::AlwaysInline); + if (func.getInlineHintAttr()) + llvmFunc->addFnAttr(llvm::Attribute::InlineHint); if (func.getOptimizeNoneAttr()) llvmFunc->addFnAttr(llvm::Attribute::OptimizeNone); if (func.getConvergentAttr()) diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 0c3e87a..d9ad8fb 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -2619,6 +2619,11 @@ LogicalResult ControlFlowStructurizer::structurize() { // region. We cannot handle such cases given that once a value is sinked into // the SelectionOp/LoopOp's region, there is no escape for it. for (auto *block : constructBlocks) { + if (!block->use_empty()) + return emitError(block->getParent()->getLoc(), + "failed control flow structurization: " + "block has uses outside of the " + "enclosing selection/loop construct"); for (Operation &op : *block) if (!op.use_empty()) return op.emitOpError("failed control flow structurization: value has " diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp index 132be4e..366ba8f 100644 --- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp +++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp @@ -14,6 +14,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Location.h" #include "mlir/Support/LLVM.h" @@ -138,6 +139,10 @@ using ImportDesc = using parsed_inst_t = FailureOr<SmallVector<Value>>; +struct EmptyBlockMarker {}; +using BlockTypeParseResult = + std::variant<EmptyBlockMarker, TypeIdxRecord, Type>; + struct WasmModuleSymbolTables { SmallVector<FunctionSymbolRefContainer> funcSymbols; SmallVector<GlobalSymbolRefContainer> globalSymbols; @@ -175,6 +180,9 @@ class ParserHead; /// Wrapper around SmallVector to only allow access as push and pop on the /// stack. Makes sure that there are no "free accesses" on the stack to preserve /// its state. +/// This class also keeps track of the Wasm labels defined by different ops, +/// which can be targeted by control flow ops. This can be modeled as part of +/// the Value Stack as Wasm control flow ops can only target enclosing labels. class ValueStack { private: struct LabelLevel { @@ -206,6 +214,16 @@ public: /// if an error occurs. LogicalResult pushResults(ValueRange results, Location *opLoc); + void addLabelLevel(LabelLevelOpInterface levelOp) { + labelLevel.push_back({values.size(), levelOp}); + LDBG() << "Adding a new frame context to ValueStack"; + } + + void dropLabelLevel() { + assert(!labelLevel.empty() && "Trying to drop a frame from empty context"); + auto newSize = labelLevel.pop_back_val().stackIdx; + values.truncate(newSize); + } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) /// A simple dump function for debugging. /// Writes output to llvm::dbgs(). @@ -214,6 +232,7 @@ public: private: SmallVector<Value> values; + SmallVector<LabelLevel> labelLevel; }; using local_val_t = TypedValue<wasmssa::LocalRefType>; @@ -248,6 +267,19 @@ private: buildNumericOp(OpBuilder &builder, std::enable_if_t<std::is_arithmetic_v<valueType>> * = nullptr); + /// Construct a conversion operation of type \p opType that takes a value from + /// type \p inputType on the stack and will produce a value of type + /// \p outputType. 
+ /// + /// \p opType - The WASM dialect operation to build. + /// \p inputType - The operand type for the built instruction. + /// \p outputType - The result type for the built instruction. + /// + /// \returns The parsed instruction result, or failure. + template <typename opType, typename inputType, typename outputType, + typename... extraArgsT> + inline parsed_inst_t buildConvertOp(OpBuilder &builder, extraArgsT...); + /// This function generates a dispatch tree to associate an opcode with a /// parser. Parsers are registered by specialising the /// `parseSpecificInstruction` function for the op code to handle. @@ -280,11 +312,105 @@ private: } } + /// + /// RAII guard class for creating a nesting level + /// + struct NestingContextGuard { + NestingContextGuard(ExpressionParser &parser, LabelLevelOpInterface levelOp) + : parser{parser} { + parser.addNestingContextLevel(levelOp); + } + NestingContextGuard(NestingContextGuard &&other) : parser{other.parser} { + other.shouldDropOnDestruct = false; + } + NestingContextGuard(NestingContextGuard const &) = delete; + ~NestingContextGuard() { + if (shouldDropOnDestruct) + parser.dropNestingContextLevel(); + } + ExpressionParser &parser; + bool shouldDropOnDestruct = true; + }; + + void addNestingContextLevel(LabelLevelOpInterface levelOp) { + valueStack.addLabelLevel(levelOp); + } + + void dropNestingContextLevel() { + // Should always succeed as we are dropping the frame that was previously + // created. + valueStack.dropLabelLevel(); + } + + llvm::FailureOr<FunctionType> getFuncTypeFor(OpBuilder &builder, + EmptyBlockMarker) { + return builder.getFunctionType({}, {}); + } + + llvm::FailureOr<FunctionType> getFuncTypeFor(OpBuilder &builder, + TypeIdxRecord type) { + if (type.id >= symbols.moduleFuncTypes.size()) + return emitError(*currentOpLoc, + "type index references nonexistent type (") + << type.id << "). 
Only " << symbols.moduleFuncTypes.size() + << " types are registered"; + return symbols.moduleFuncTypes[type.id]; + } + + llvm::FailureOr<FunctionType> getFuncTypeFor(OpBuilder &builder, + Type valType) { + return builder.getFunctionType({}, {valType}); + } + + llvm::FailureOr<FunctionType> + getFuncTypeFor(OpBuilder &builder, BlockTypeParseResult parseResult) { + return std::visit( + [this, &builder](auto value) { return getFuncTypeFor(builder, value); }, + parseResult); + } + + llvm::FailureOr<FunctionType> + getFuncTypeFor(OpBuilder &builder, + llvm::FailureOr<BlockTypeParseResult> parseResult) { + if (llvm::failed(parseResult)) + return failure(); + return getFuncTypeFor(builder, *parseResult); + } + + llvm::FailureOr<FunctionType> parseBlockFuncType(OpBuilder &builder); + struct ParseResultWithInfo { SmallVector<Value> opResults; std::byte endingByte; }; + template <typename FilterT = ByteSequence<WasmBinaryEncoding::endByte>> + /// @param blockToFill: the block whose content will be populated + /// @param resTypes: the types that this block is supposed to return + llvm::FailureOr<std::byte> + parseBlockContent(OpBuilder &builder, Block *blockToFill, TypeRange resTypes, + Location opLoc, LabelLevelOpInterface levelOp, + FilterT parseEndBytes = {}) { + OpBuilder::InsertionGuard guard{builder}; + builder.setInsertionPointToStart(blockToFill); + LDBG() << "parsing a block of type " + << builder.getFunctionType(blockToFill->getArgumentTypes(), + resTypes); + auto nC = addNesting(levelOp); + + if (failed(pushResults(blockToFill->getArguments()))) + return failure(); + auto bodyParsingRes = parse(builder, parseEndBytes); + if (failed(bodyParsingRes)) + return failure(); + auto returnOperands = popOperands(resTypes); + if (failed(returnOperands)) + return failure(); + builder.create<BlockReturnOp>(opLoc, *returnOperands); + LDBG() << "end of parsing of a block"; + return bodyParsingRes->endingByte; + } + public: template <std::byte ParseEndByte = WasmBinaryEncoding::endByte> parsed_inst_t parse(OpBuilder &builder, UniqueByte<ParseEndByte> = {}); @@ -294,7 +420,11 @@ public: parse(OpBuilder &builder, ByteSequence<ExpressionParseEnd...> parsingEndFilters); - FailureOr<SmallVector<Value>> popOperands(TypeRange operandTypes) { + NestingContextGuard addNesting(LabelLevelOpInterface levelOp) { + return NestingContextGuard{*this, levelOp}; + } + + FailureOr<llvm::SmallVector<Value>> popOperands(TypeRange operandTypes) { return valueStack.popOperands(operandTypes, &currentOpLoc.value()); } @@ -308,6 +438,12 @@ public: template <typename OpToCreate> parsed_inst_t parseSetOrTee(OpBuilder &); + /// Blocks and Loops have a similar format and differ only in how their exit + /// is handled, which doesn't matter at parsing time, so both are parsed by + /// this one function. 
+ template <typename OpToCreate> + parsed_inst_t parseBlockLikeOp(OpBuilder &); + private: std::optional<Location> currentOpLoc; ParserHead &parser; @@ -586,6 +722,29 @@ public: return success(); } + llvm::FailureOr<BlockTypeParseResult> parseBlockType(MLIRContext *ctx) { + auto loc = getLocation(); + auto blockIndicator = peek(); + if (failed(blockIndicator)) + return failure(); + if (*blockIndicator == WasmBinaryEncoding::Type::emptyBlockType) { + offset += 1; + return {EmptyBlockMarker{}}; + } + if (isValueOneOf(*blockIndicator, valueTypesEncodings)) + return parseValueType(ctx); + /// Block type idx is a 32 bit positive integer encoded as a 33 bit signed + /// value + auto typeIdx = parseI64(); + if (failed(typeIdx)) + return failure(); + if (*typeIdx < 0 || *typeIdx > std::numeric_limits<uint32_t>::max()) + return emitError(loc, "type ID should be representable with an unsigned " + "32 bits integer. Got ") + << *typeIdx; + return {TypeIdxRecord{static_cast<uint32_t>(*typeIdx)}}; + } + bool end() const { return curHead().empty(); } ParserHead copy() const { return *this; } @@ -701,17 +860,41 @@ inline parsed_inst_t ExpressionParser::parseSpecificInstruction(OpBuilder &) { void ValueStack::dump() const { llvm::dbgs() << "================= Wasm ValueStack =======================\n"; llvm::dbgs() << "size: " << size() << "\n"; + llvm::dbgs() << "nbFrames: " << labelLevel.size() << '\n'; llvm::dbgs() << "<Top>" << "\n"; // Stack is pushed to via push_back. Therefore the top of the stack is the // end of the vector. Iterate in reverse so that the first thing we print // is the top of the stack. + auto indexGetter = [this]() { + size_t idx = labelLevel.size(); + return [this, idx]() mutable -> std::optional<std::pair<size_t, size_t>> { + llvm::dbgs() << "IDX: " << idx << '\n'; + if (idx == 0) + return std::nullopt; + auto frameId = idx - 1; + auto frameLimit = labelLevel[frameId].stackIdx; + idx -= 1; + return {{frameId, frameLimit}}; + }; + }; + auto getNextFrameIndex = indexGetter(); + auto nextFrameIdx = getNextFrameIndex(); size_t stackSize = size(); - for (size_t idx = 0; idx < stackSize; idx++) { + for (size_t idx = 0; idx < stackSize; ++idx) { size_t actualIdx = stackSize - 1 - idx; + while (nextFrameIdx && (nextFrameIdx->second > actualIdx)) { + llvm::dbgs() << " --------------- Frame (" << nextFrameIdx->first + << ")\n"; + nextFrameIdx = getNextFrameIndex(); + } llvm::dbgs() << " "; values[actualIdx].dump(); } + while (nextFrameIdx) { + llvm::dbgs() << " --------------- Frame (" << nextFrameIdx->first << ")\n"; + nextFrameIdx = getNextFrameIndex(); + } llvm::dbgs() << "<Bottom>" << "\n"; llvm::dbgs() << "=========================================================\n"; @@ -726,7 +909,7 @@ parsed_inst_t ValueStack::popOperands(TypeRange operandTypes, Location *opLoc) { return emitError(*opLoc, "stack doesn't contain enough values. trying to get ") << operandTypes.size() << " operands on a stack containing only " - << values.size() << " values."; + << values.size() << " values"; size_t stackIdxOffset = values.size() - operandTypes.size(); SmallVector<Value> res{}; res.reserve(operandTypes.size()); @@ -735,8 +918,7 @@ parsed_inst_t ValueStack::popOperands(TypeRange operandTypes, Location *opLoc) { Type stackType = operand.getType(); if (stackType != operandTypes[i]) return emitError(*opLoc, "invalid operand type on stack. 
expecting ") - << operandTypes[i] << ", value on stack is of type " << stackType - << "."; + << operandTypes[i] << ", value on stack is of type " << stackType; LDBG() << " POP: " << operand; res.push_back(operand); } @@ -792,6 +974,151 @@ ExpressionParser::parse(OpBuilder &builder, } } +llvm::FailureOr<FunctionType> +ExpressionParser::parseBlockFuncType(OpBuilder &builder) { + return getFuncTypeFor(builder, parser.parseBlockType(builder.getContext())); +} + +template <typename OpToCreate> +parsed_inst_t ExpressionParser::parseBlockLikeOp(OpBuilder &builder) { + auto opLoc = currentOpLoc; + auto funcType = parseBlockFuncType(builder); + if (failed(funcType)) + return failure(); + + auto inputTypes = funcType->getInputs(); + auto inputOps = popOperands(inputTypes); + if (failed(inputOps)) + return failure(); + + Block *curBlock = builder.getBlock(); + Region *curRegion = curBlock->getParent(); + auto resTypes = funcType->getResults(); + llvm::SmallVector<Location> locations{}; + locations.resize(resTypes.size(), *currentOpLoc); + auto *successor = + builder.createBlock(curRegion, curRegion->end(), resTypes, locations); + builder.setInsertionPointToEnd(curBlock); + auto blockOp = + builder.create<OpToCreate>(*currentOpLoc, *inputOps, successor); + auto *blockBody = blockOp.createBlock(); + if (failed(parseBlockContent(builder, blockBody, resTypes, *opLoc, blockOp))) + return failure(); + builder.setInsertionPointToStart(successor); + return {ValueRange{successor->getArguments()}}; +} + +template <> +inline parsed_inst_t +ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::block>( + OpBuilder &builder) { + return parseBlockLikeOp<BlockOp>(builder); +} + +template <> +inline parsed_inst_t +ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::loop>( + OpBuilder &builder) { + return parseBlockLikeOp<LoopOp>(builder); +} + +template <> +inline parsed_inst_t ExpressionParser::parseSpecificInstruction< + WasmBinaryEncoding::OpCode::ifOpCode>(OpBuilder &builder) { + auto opLoc = currentOpLoc; + auto funcType = parseBlockFuncType(builder); + if (failed(funcType)) + return failure(); + + LDBG() << "Parsing an if instruction of type " << *funcType; + auto inputTypes = funcType->getInputs(); + auto conditionValue = popOperands(builder.getI32Type()); + if (failed(conditionValue)) + return failure(); + auto inputOps = popOperands(inputTypes); + if (failed(inputOps)) + return failure(); + + Block *curBlock = builder.getBlock(); + Region *curRegion = curBlock->getParent(); + auto resTypes = funcType->getResults(); + llvm::SmallVector<Location> locations{}; + locations.resize(resTypes.size(), *currentOpLoc); + auto *successor = + builder.createBlock(curRegion, curRegion->end(), resTypes, locations); + builder.setInsertionPointToEnd(curBlock); + auto ifOp = builder.create<IfOp>(*currentOpLoc, conditionValue->front(), + *inputOps, successor); + auto *ifEntryBlock = ifOp.createIfBlock(); + constexpr auto ifElseFilter = + ByteSequence<WasmBinaryEncoding::endByte, + WasmBinaryEncoding::OpCode::elseOpCode>{}; + auto parseIfRes = parseBlockContent(builder, ifEntryBlock, resTypes, *opLoc, + ifOp, ifElseFilter); + if (failed(parseIfRes)) + return failure(); + if (*parseIfRes == WasmBinaryEncoding::OpCode::elseOpCode) { + LDBG() << " else block is present."; + Block *elseEntryBlock = ifOp.createElseBlock(); + auto parseElseRes = + parseBlockContent(builder, elseEntryBlock, resTypes, *opLoc, ifOp); + if (failed(parseElseRes)) + return failure(); + } + 
builder.setInsertionPointToStart(successor); + return {ValueRange{successor->getArguments()}}; +} + +template <> +inline parsed_inst_t ExpressionParser::parseSpecificInstruction< + WasmBinaryEncoding::OpCode::branchIf>(OpBuilder &builder) { + auto level = parser.parseLiteral<uint32_t>(); + if (failed(level)) + return failure(); + Block *curBlock = builder.getBlock(); + Region *curRegion = curBlock->getParent(); + auto sip = builder.saveInsertionPoint(); + Block *elseBlock = builder.createBlock(curRegion, curRegion->end()); + auto condition = popOperands(builder.getI32Type()); + if (failed(condition)) + return failure(); + builder.restoreInsertionPoint(sip); + auto targetOp = + LabelBranchingOpInterface::getTargetOpFromBlock(curBlock, *level); + if (failed(targetOp)) + return failure(); + auto inputTypes = targetOp->getLabelTarget()->getArgumentTypes(); + auto branchArgs = popOperands(inputTypes); + if (failed(branchArgs)) + return failure(); + builder.create<BranchIfOp>(*currentOpLoc, condition->front(), + builder.getUI32IntegerAttr(*level), *branchArgs, + elseBlock); + builder.setInsertionPointToStart(elseBlock); + return {*branchArgs}; +} + +template <> +inline parsed_inst_t +ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::call>( + OpBuilder &builder) { + auto loc = *currentOpLoc; + auto funcIdx = parser.parseLiteral<uint32_t>(); + if (failed(funcIdx)) + return failure(); + if (*funcIdx >= symbols.funcSymbols.size()) + return emitError(loc, "Invalid function index: ") << *funcIdx; + auto callee = symbols.funcSymbols[*funcIdx]; + llvm::ArrayRef<Type> inTypes = callee.functionType.getInputs(); + llvm::ArrayRef<Type> resTypes = callee.functionType.getResults(); + parsed_inst_t inOperands = popOperands(inTypes); + if (failed(inOperands)) + return failure(); + auto callOp = + builder.create<FuncCallOp>(loc, resTypes, callee.symbol, *inOperands); + return {callOp.getResults()}; +} + template <> inline parsed_inst_t ExpressionParser::parseSpecificInstruction< WasmBinaryEncoding::OpCode::localGet>(OpBuilder &builder) { @@ -834,7 +1161,7 @@ parsed_inst_t ExpressionParser::parseSetOrTee(OpBuilder &builder) { if (valueStack.empty()) return emitError( *currentOpLoc, - "invalid stack access, trying to access a value on an empty stack."); + "invalid stack access, trying to access a value on an empty stack"); parsed_inst_t poppedOp = popOperands(locals[*id].getType().getElementType()); if (failed(poppedOp)) @@ -956,7 +1283,7 @@ inline parsed_inst_t ExpressionParser::buildNumericOp( << ", type = " << ty << " ***"; auto tysToPop = SmallVector<Type, numOperands>(); tysToPop.resize(numOperands); - std::fill(tysToPop.begin(), tysToPop.end(), ty); + llvm::fill(tysToPop, ty); auto operands = popOperands(tysToPop); if (failed(operands)) return failure(); @@ -1000,11 +1327,23 @@ inline parsed_inst_t ExpressionParser::buildNumericOp( BUILD_NUMERIC_BINOP_FP(CopySignOp, copysign) BUILD_NUMERIC_BINOP_FP(DivOp, div) +BUILD_NUMERIC_BINOP_FP(GeOp, ge) +BUILD_NUMERIC_BINOP_FP(GtOp, gt) +BUILD_NUMERIC_BINOP_FP(LeOp, le) +BUILD_NUMERIC_BINOP_FP(LtOp, lt) BUILD_NUMERIC_BINOP_FP(MaxOp, max) BUILD_NUMERIC_BINOP_FP(MinOp, min) BUILD_NUMERIC_BINOP_INT(AndOp, and) BUILD_NUMERIC_BINOP_INT(DivSIOp, divS) BUILD_NUMERIC_BINOP_INT(DivUIOp, divU) +BUILD_NUMERIC_BINOP_INT(GeSIOp, geS) +BUILD_NUMERIC_BINOP_INT(GeUIOp, geU) +BUILD_NUMERIC_BINOP_INT(GtSIOp, gtS) +BUILD_NUMERIC_BINOP_INT(GtUIOp, gtU) +BUILD_NUMERIC_BINOP_INT(LeSIOp, leS) +BUILD_NUMERIC_BINOP_INT(LeUIOp, leU) +BUILD_NUMERIC_BINOP_INT(LtSIOp, ltS) 
+BUILD_NUMERIC_BINOP_INT(LtUIOp, ltU) BUILD_NUMERIC_BINOP_INT(OrOp, or) BUILD_NUMERIC_BINOP_INT(RemSIOp, remS) BUILD_NUMERIC_BINOP_INT(RemUIOp, remU) @@ -1015,7 +1354,9 @@ BUILD_NUMERIC_BINOP_INT(ShRSOp, shrS) BUILD_NUMERIC_BINOP_INT(ShRUOp, shrU) BUILD_NUMERIC_BINOP_INT(XOrOp, xor) BUILD_NUMERIC_BINOP_INTFP(AddOp, add) +BUILD_NUMERIC_BINOP_INTFP(EqOp, eq) BUILD_NUMERIC_BINOP_INTFP(MulOp, mul) +BUILD_NUMERIC_BINOP_INTFP(NeOp, ne) BUILD_NUMERIC_BINOP_INTFP(SubOp, sub) BUILD_NUMERIC_UNARY_OP_FP(AbsOp, abs) BUILD_NUMERIC_UNARY_OP_FP(CeilOp, ceil) @@ -1025,6 +1366,7 @@ BUILD_NUMERIC_UNARY_OP_FP(SqrtOp, sqrt) BUILD_NUMERIC_UNARY_OP_FP(TruncOp, trunc) BUILD_NUMERIC_UNARY_OP_INT(ClzOp, clz) BUILD_NUMERIC_UNARY_OP_INT(CtzOp, ctz) +BUILD_NUMERIC_UNARY_OP_INT(EqzOp, eqz) BUILD_NUMERIC_UNARY_OP_INT(PopCntOp, popcnt) // Don't need these anymore so let's undef them. @@ -1036,6 +1378,105 @@ BUILD_NUMERIC_UNARY_OP_INT(PopCntOp, popcnt) #undef BUILD_NUMERIC_OP #undef BUILD_NUMERIC_CAST_OP +template <typename opType, typename inputType, typename outputType, + typename... extraArgsT> +inline parsed_inst_t ExpressionParser::buildConvertOp(OpBuilder &builder, + extraArgsT... extraArgs) { + static_assert(std::is_arithmetic_v<inputType>, + "InputType should be an arithmetic type"); + static_assert(std::is_arithmetic_v<outputType>, + "OutputType should be an arithmetic type"); + auto intype = buildLiteralType<inputType>(builder); + auto outType = buildLiteralType<outputType>(builder); + auto operand = popOperands(intype); + if (failed(operand)) + return failure(); + auto op = builder.create<opType>(*currentOpLoc, outType, operand->front(), + extraArgs...); + LDBG() << "Built operation: " << op; + return {{op.getResult()}}; +} + +template <> +inline parsed_inst_t ExpressionParser::parseSpecificInstruction< + WasmBinaryEncoding::OpCode::demoteF64ToF32>(OpBuilder &builder) { + return buildConvertOp<DemoteOp, double, float>(builder); +} + +template <> +inline parsed_inst_t +ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::wrap>( + OpBuilder &builder) { + return buildConvertOp<WrapOp, int64_t, int32_t>(builder); +} + +#define BUILD_CONVERSION_OP(IN_T, OUT_T, SOURCE_OP, TARGET_OP) \ + template <> \ + inline parsed_inst_t ExpressionParser::parseSpecificInstruction< \ + WasmBinaryEncoding::OpCode::SOURCE_OP>(OpBuilder & builder) { \ + return buildConvertOp<TARGET_OP, IN_T, OUT_T>(builder); \ + } + +#define BUILD_CONVERT_OP_FOR(DEST_T, WIDTH) \ + BUILD_CONVERSION_OP(uint32_t, DEST_T, convertUI32F##WIDTH, ConvertUOp) \ + BUILD_CONVERSION_OP(int32_t, DEST_T, convertSI32F##WIDTH, ConvertSOp) \ + BUILD_CONVERSION_OP(uint64_t, DEST_T, convertUI64F##WIDTH, ConvertUOp) \ + BUILD_CONVERSION_OP(int64_t, DEST_T, convertSI64F##WIDTH, ConvertSOp) + +BUILD_CONVERT_OP_FOR(float, 32) +BUILD_CONVERT_OP_FOR(double, 64) + +#undef BUILD_CONVERT_OP_FOR + +BUILD_CONVERSION_OP(int32_t, int64_t, extendS, ExtendSI32Op) +BUILD_CONVERSION_OP(int32_t, int64_t, extendU, ExtendUI32Op) + +#undef BUILD_CONVERSION_OP + +#define BUILD_SLICE_EXTEND_PARSER(IT_WIDTH, EXTRACT_WIDTH) \ + template <> \ + parsed_inst_t ExpressionParser::parseSpecificInstruction< \ + WasmBinaryEncoding::OpCode::extendI##IT_WIDTH##EXTRACT_WIDTH##S>( \ + OpBuilder & builder) { \ + using inout_t = int##IT_WIDTH##_t; \ + auto attr = builder.getUI32IntegerAttr(EXTRACT_WIDTH); \ + return buildConvertOp<ExtendLowBitsSOp, inout_t, inout_t>(builder, attr); \ + } + +BUILD_SLICE_EXTEND_PARSER(32, 8) +BUILD_SLICE_EXTEND_PARSER(32, 16) +BUILD_SLICE_EXTEND_PARSER(64, 8) 
+BUILD_SLICE_EXTEND_PARSER(64, 16) +BUILD_SLICE_EXTEND_PARSER(64, 32) + +#undef BUILD_SLICE_EXTEND_PARSER + +template <> +inline parsed_inst_t ExpressionParser::parseSpecificInstruction< + WasmBinaryEncoding::OpCode::promoteF32ToF64>(OpBuilder &builder) { + return buildConvertOp<PromoteOp, float, double>(builder); +} + +#define BUILD_REINTERPRET_PARSER(WIDTH, FP_TYPE) \ + template <> \ + inline parsed_inst_t ExpressionParser::parseSpecificInstruction< \ + WasmBinaryEncoding::OpCode::reinterpretF##WIDTH##AsI##WIDTH>(OpBuilder & \ + builder) { \ + return buildConvertOp<ReinterpretOp, FP_TYPE, int##WIDTH##_t>(builder); \ + } \ + \ + template <> \ + inline parsed_inst_t ExpressionParser::parseSpecificInstruction< \ + WasmBinaryEncoding::OpCode::reinterpretI##WIDTH##AsF##WIDTH>(OpBuilder & \ + builder) { \ + return buildConvertOp<ReinterpretOp, int##WIDTH##_t, FP_TYPE>(builder); \ + } + +BUILD_REINTERPRET_PARSER(32, float) +BUILD_REINTERPRET_PARSER(64, double) + +#undef BUILD_REINTERPRET_PARSER + class WasmBinaryParser { private: struct SectionRegistry { @@ -1153,7 +1594,7 @@ private: if (tid.id >= symbols.moduleFuncTypes.size()) return emitError(loc, "invalid type id: ") << tid.id << ". Only " << symbols.moduleFuncTypes.size() - << " type registration."; + << " type registrations"; FunctionType type = symbols.moduleFuncTypes[tid.id]; std::string symbol = symbols.getNewFuncSymbolName(); auto funcOp = FuncImportOp::create(builder, loc, symbol, moduleName, @@ -1221,7 +1662,7 @@ public: FileLineColLoc magicLoc = parser.getLocation(); FailureOr<StringRef> magic = parser.consumeNBytes(wasmHeader.size()); if (failed(magic) || magic->compare(wasmHeader)) { - emitError(magicLoc, "source file does not contain valid Wasm header."); + emitError(magicLoc, "source file does not contain valid Wasm header"); return; } auto const expectedVersionString = StringRef{"\1\0\0\0", 4}; @@ -1391,7 +1832,7 @@ WasmBinaryParser::parseSectionItem<WasmSectionType::EXPORT>(ParserHead &ph, return failure(); Operation *op = SymbolTable::lookupSymbolIn(mOp, *currentSymbol); - SymbolTable::setSymbolVisibility(op, SymbolTable::Visibility::Public); + op->setAttr("exported", UnitAttr::get(op->getContext())); StringAttr symName = SymbolTable::getSymbolName(op); return SymbolTable{mOp}.rename(symName, *exportName); } diff --git a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp index 9670285..3fda5a7 100644 --- a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp +++ b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp @@ -93,7 +93,7 @@ void CodeGen::generate(const ast::Module &astModule, ModuleOp module) { // Emit function to add the generated matchers to the pattern list. os << "template <typename... 
ConfigsT>\n" - "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns(" + "[[maybe_unused]] static void populateGeneratedPDLLPatterns(" "::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {\n"; for (const auto &name : patternNames) os << " patterns.add<" << name diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp index c883baa..3236b4f 100644 --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -27,6 +27,7 @@ #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/ScopedPrinter.h" +#include "llvm/Support/VirtualFileSystem.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Parser.h" #include <optional> @@ -828,6 +829,7 @@ LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc, llvm::SourceMgr tdSrcMgr; tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc()); tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs()); + tdSrcMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem()); // This class provides a context argument for the llvm::SourceMgr diagnostic // handler. diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp index 30fd384..9ef405d 100644 --- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp @@ -37,6 +37,7 @@ #include "llvm/ADT/StringRef.h" #include "llvm/Remarks/RemarkFormat.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/LogicalResult.h" #include "llvm/Support/ManagedStatic.h" @@ -226,6 +227,18 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig { "bitstream", "Print bitstream file")), llvm::cl::cat(remarkCategory)}; + static llvm::cl::opt<RemarkPolicy, /*ExternalStorage=*/true> remarkPolicy{ + "remark-policy", + llvm::cl::desc("Specify the policy for remark output."), + cl::location(remarkPolicyFlag), + llvm::cl::value_desc("format"), + llvm::cl::init(RemarkPolicy::REMARK_POLICY_ALL), + llvm::cl::values(clEnumValN(RemarkPolicy::REMARK_POLICY_ALL, "all", + "Print all remarks"), + clEnumValN(RemarkPolicy::REMARK_POLICY_FINAL, "final", + "Print final remarks")), + llvm::cl::cat(remarkCategory)}; + static cl::opt<std::string, /*ExternalStorage=*/true> remarksAll( "remarks-filter", cl::desc("Show all remarks: passed, missed, failed, analysis"), @@ -517,18 +530,28 @@ performActions(raw_ostream &os, return failure(); context->enableMultithreading(wasThreadingEnabled); - + // Set the remark categories and policy. 
remark::RemarkCategories cats{ config.getRemarksAllFilter(), config.getRemarksPassedFilter(), config.getRemarksMissedFilter(), config.getRemarksAnalyseFilter(), config.getRemarksFailedFilter()}; mlir::MLIRContext &ctx = *context; + // Helper to create the appropriate policy based on configuration + auto createPolicy = [&config]() + -> std::unique_ptr<mlir::remark::detail::RemarkEmittingPolicyBase> { + if (config.getRemarkPolicy() == RemarkPolicy::REMARK_POLICY_ALL) + return std::make_unique<mlir::remark::RemarkEmittingPolicyAll>(); + if (config.getRemarkPolicy() == RemarkPolicy::REMARK_POLICY_FINAL) + return std::make_unique<mlir::remark::RemarkEmittingPolicyFinal>(); + + llvm_unreachable("Invalid remark policy"); + }; switch (config.getRemarkFormat()) { case RemarkFormat::REMARK_FORMAT_STDOUT: if (failed(mlir::remark::enableOptimizationRemarks( - ctx, nullptr, cats, true /*printAsEmitRemarks*/))) + ctx, nullptr, createPolicy(), cats, true /*printAsEmitRemarks*/))) return failure(); break; @@ -537,7 +560,7 @@ performActions(raw_ostream &os, ? "mlir-remarks.yaml" : config.getRemarksOutputFile(); if (failed(mlir::remark::enableOptimizationRemarksWithLLVMStreamer( - ctx, file, llvm::remarks::Format::YAML, cats))) + ctx, file, llvm::remarks::Format::YAML, createPolicy(), cats))) return failure(); break; } @@ -547,7 +570,7 @@ performActions(raw_ostream &os, ? "mlir-remarks.bitstream" : config.getRemarksOutputFile(); if (failed(mlir::remark::enableOptimizationRemarksWithLLVMStreamer( - ctx, file, llvm::remarks::Format::Bitstream, cats))) + ctx, file, llvm::remarks::Format::Bitstream, createPolicy(), cats))) return failure(); break; } @@ -593,6 +616,12 @@ performActions(raw_ostream &os, AsmState asmState(op.get(), OpPrintingFlags(), /*locationMap=*/nullptr, &fallbackResourceMap); os << OpWithState(op.get(), asmState) << '\n'; + + // This is required if the remark policy is final. Otherwise, the remarks are + // not emitted. 
+ if (remark::detail::RemarkEngine *engine = ctx.getRemarkEngine()) + engine->getRemarkEmittingPolicy()->finalize(); + return success(); } diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp index 60b9567..1dbe7eca 100644 --- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp @@ -31,6 +31,7 @@ #include "llvm/Support/FileSystem.h" #include "llvm/Support/LSP/Logging.h" #include "llvm/Support/Path.h" +#include "llvm/Support/VirtualFileSystem.h" #include <optional> using namespace mlir; @@ -402,6 +403,7 @@ PDLDocument::PDLDocument(const llvm::lsp::URIForFile &uri, StringRef contents, llvm::append_range(includeDirs, extraDirs); sourceMgr.setIncludeDirs(includeDirs); + sourceMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem()); sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); astContext.getDiagEngine().setHandlerFn([&](const ast::Diagnostic &diag) { diff --git a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp index 3080b78..2d817be 100644 --- a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp +++ b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp @@ -17,6 +17,7 @@ #include "llvm/Support/LSP/Logging.h" #include "llvm/Support/LSP/Protocol.h" #include "llvm/Support/Path.h" +#include "llvm/Support/VirtualFileSystem.h" #include "llvm/TableGen/Parser.h" #include "llvm/TableGen/Record.h" #include <optional> @@ -448,6 +449,7 @@ void TableGenTextFile::initialize( return; } sourceMgr.setIncludeDirs(includeDirs); + sourceMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem()); sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); // This class provides a context argument for the SourceMgr diagnostic diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp index 111f58e..5f3b04a 100644 --- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp @@ -66,7 +66,9 @@ size_t mlir::moveLoopInvariantCode( size_t numMoved = 0; for (Region *region : regions) { - LDBG() << "Original loop:\n" << *region->getParentOp(); + LDBG() << "Original loop:\n" + << OpWithFlags(region->getParentOp(), + OpPrintingFlags().skipRegions()); std::queue<Operation *> worklist; // Add top-level operations in the loop body to the worklist. @@ -90,7 +92,8 @@ size_t mlir::moveLoopInvariantCode( !canBeHoisted(op, definedOutside)) continue; - LDBG() << "Moving loop-invariant op: " << *op; + LDBG() << "Moving loop-invariant op: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); moveOutOfRegion(op, region); ++numMoved; @@ -111,9 +114,7 @@ size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) { [&](Value value, Region *) { return loopLike.isDefinedOutsideOfLoop(value); }, - [&](Operation *op, Region *) { - return isMemoryEffectFree(op) && isSpeculatable(op); - }, + [&](Operation *op, Region *) { return isPure(op); }, [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); }); } |
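Note on the remark-policy plumbing touched in Remarks.cpp, RemarkStreamer.cpp, and MlirOptMain.cpp above: enableOptimizationRemarks now takes an explicit RemarkEmittingPolicyBase argument. The following is a minimal sketch, not part of the patch, mirroring the REMARK_FORMAT_STDOUT branch of MlirOptMain.cpp; the include paths and the RemarkCategories field types are assumptions (they are not visible in this diff), and only signatures that appear in the diff are used.

// Sketch only: enable stdout remarks with the new "all" policy, following the
// call pattern shown in MlirOptMain.cpp above.
#include <memory>
#include <utility>

#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Remarks.h" // assumed header for the remark API
#include "llvm/Support/LogicalResult.h"

static llvm::LogicalResult enableStdoutRemarks(mlir::MLIRContext &ctx) {
  // Category filters in the order MlirOptMain passes them:
  // all, passed, missed, analysis, failed. Field types are assumed to accept
  // a regex string or an empty/absent value.
  mlir::remark::RemarkCategories cats{/*all=*/".*", /*passed=*/{},
                                      /*missed=*/{}, /*analysis=*/{},
                                      /*failed=*/{}};

  // RemarkEmittingPolicyAll reports every remark as it is emitted;
  // RemarkEmittingPolicyFinal defers reporting until finalize() runs
  // (which is why MlirOptMain now calls finalize() after printing the module).
  auto policy = std::make_unique<mlir::remark::RemarkEmittingPolicyAll>();

  // A null streamer with printAsEmitRemarks=true routes remarks through
  // emitRemark, as in the REMARK_FORMAT_STDOUT case of MlirOptMain.cpp.
  return mlir::remark::enableOptimizationRemarks(
      ctx, /*streamer=*/nullptr, std::move(policy), cats,
      /*printAsEmitRemarks=*/true);
}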