diff options
Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp | 18 | ||||
-rw-r--r-- | mlir/lib/Dialect/Tosa/CMakeLists.txt | 1 | ||||
-rw-r--r-- | mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp | 42 | ||||
-rw-r--r-- | mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt | 1 | ||||
-rw-r--r-- | mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp | 87 | ||||
-rw-r--r-- | mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp | 94 |
6 files changed, 158 insertions, 85 deletions
diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp index 8295492..04e8836 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp @@ -310,26 +310,30 @@ LogicalResult BlockPrefetch2dOp::verify() { template <typename OpType, typename = std::enable_if_t<llvm::is_one_of< OpType, BlockLoadOp, BlockStoreOp>::value>> LogicalResult verify1DBlockArg(OpType op) { - VectorType vTy; + Type srcOrDstTy; if constexpr (std::is_same_v<OpType, BlockLoadOp>) - vTy = op.getResult().getType(); + srcOrDstTy = op.getResult().getType(); else - vTy = op.getVal().getType(); + srcOrDstTy = op.getVal().getType(); + VectorType vTy = dyn_cast<VectorType>(srcOrDstTy); + // scalar case is always valid + if (!vTy) + return success(); int elemTySize = vTy.getElementType().getIntOrFloatBitWidth() / 8; if (elemTySize == 1) { - llvm::SmallSet<int, 5> validSizes{1, 2, 4, 8, 16}; + llvm::SmallSet<int, 4> validSizes{2, 4, 8, 16}; if (validSizes.contains(vTy.getNumElements())) return success(); else return op.emitOpError( - "vector size must be 1, 2, 4, 8 or 16 for 8-bit element type"); + "vector size must be 2, 4, 8 or 16 for 8-bit element type"); } else { - llvm::SmallSet<int, 4> validSizes{1, 2, 4, 8}; + llvm::SmallSet<int, 3> validSizes{2, 4, 8}; if (validSizes.contains(vTy.getNumElements())) return success(); else return op.emitOpError( - "vector size must be 1, 2, 4 or 8 for element type > 8 bits"); + "vector size must be 2, 4 or 8 for element type > 8 bits"); } } diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt index c6a438d..a95906a 100644 --- a/mlir/lib/Dialect/Tosa/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRTosaDialect IR/TosaOps.cpp IR/TosaCanonicalizations.cpp + IR/TargetEnv.cpp Utils/ConversionUtils.cpp Utils/QuantUtils.cpp diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp new file mode 100644 index 0000000..5aad671 --- /dev/null +++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp @@ -0,0 +1,42 @@ +//===-------------- TosaTarget.cpp - TOSA Target utilities ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TargetEnv.h" + +namespace mlir { +namespace tosa { + +TargetEnvAttr lookupTargetEnv(Operation *op) { + while (op) { + op = SymbolTable::getNearestSymbolTable(op); + if (!op) + break; + + if (auto attr = op->getAttrOfType<TargetEnvAttr>(TargetEnvAttr::name)) + return attr; + + op = op->getParentOp(); + } + + return {}; +} + +TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) { + return TargetEnvAttr::get(context, Level::eightK, + {Profile::pro_int, Profile::pro_fp}, {}); +} + +TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) { + if (auto attr = lookupTargetEnv(op)) + return attr; + + return getDefaultTargetEnv(op->getContext()); +} + +} // namespace tosa +} // namespace mlir diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt index 803993b..41b338d 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_dialect_library(MLIRTosaTransforms + TosaAttachTarget.cpp TosaConvertIntegerTypeToSignless.cpp TosaDecomposeTransposeConv.cpp TosaDecomposeDepthwise.cpp diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp new file mode 100644 index 0000000..bcb880a --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp @@ -0,0 +1,87 @@ +//===- TosaAttachTarget.cpp +//------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Attach target information to a TOSA module. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tosa/IR/TargetEnv.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace tosa { + +#define GEN_PASS_DEF_TOSAATTACHTARGET +#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" + +namespace { + +class TosaAttachTarget + : public tosa::impl::TosaAttachTargetBase<TosaAttachTarget> { + using Base::Base; + +public: + void runOnOperation() override { + llvm::SmallVector<Profile, 2> selectedProfiles; + if (!profiles.empty()) { + for (const std::string &prof : profiles) { + std::optional<Profile> profSymbol = symbolizeProfile(prof); + if (!profSymbol) { + llvm::SmallVector<Profile> allProfiles = ProfileAttr::getAllValues(); + llvm::errs() << buildUnkownParameterErrorMessage(allProfiles, + "profile", prof); + return signalPassFailure(); + } + selectedProfiles.push_back(profSymbol.value()); + } + } + + llvm::SmallVector<Extension, 10> selectedExtensions; + if (!extensions.empty()) { + for (const std::string &ext : extensions) { + std::optional<Extension> extSymbol = symbolizeExtension(ext); + if (!extSymbol) { + llvm::SmallVector<Extension> allExtensions = + ExtensionAttr::getAllValues(); + llvm::errs() << buildUnkownParameterErrorMessage(allExtensions, + "extension", ext); + return signalPassFailure(); + } + selectedExtensions.push_back(extSymbol.value()); + } + } + + ModuleOp mod = getOperation(); + MLIRContext *ctx = &getContext(); + const auto targetEnvAttr = + TargetEnvAttr::get(ctx, level, selectedProfiles, selectedExtensions); + mod->setAttr(TargetEnvAttr::name, targetEnvAttr); + } + +private: + template <typename T> + std::string buildUnkownParameterErrorMessage(llvm::SmallVector<T> &enumValues, + std::string enumName, + std::string unknownArgument) { + std::string message; + llvm::raw_string_ostream os(message); + os << "Unknown TOSA " << enumName << " name passed in '" << unknownArgument + << "', supported " << enumName << "s are: "; + llvm::interleaveComma(enumValues, os); + os << "\n"; + return message; + } +}; + +} // namespace + +} // namespace tosa +} // namespace mlir diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 4fc7ce8..82f2f7e 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -14,7 +14,6 @@ #include "mlir/Dialect/Tosa/IR/TargetEnv.h" #include "mlir/Dialect/Tosa/IR/TosaProfileCompliance.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h" -#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc" #include <string> @@ -130,28 +129,6 @@ static LogicalResult checkConstantOperandNegate(Operation *op, return success(); } -struct TosaLevel { - int32_t MAX_RANK = 0; - int32_t MAX_KERNEL = 0; - int32_t MAX_STRIDE = 0; - int32_t MAX_SCALE = 0; - int32_t MAX_LOG2_SIZE = 0; - int32_t MAX_NESTING = 0; - int32_t MAX_TENSOR_LIST_SIZE = 0; - - bool operator==(const TosaLevel &rhs) { - return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL && - MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE && - MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE && - MAX_NESTING == rhs.MAX_NESTING && - MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE; - } -}; - -static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64}; -static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048, - 63, 256, 256}; - //===----------------------------------------------------------------------===// // TOSA Validation Pass. //===----------------------------------------------------------------------===// @@ -162,12 +139,9 @@ public: explicit TosaValidation(const TosaValidationOptions &options) : TosaValidation() { - this->profile = options.profile; - this->extension = options.extension; this->strictOpSpecAlignment = options.strictOpSpecAlignment; this->allowInvalidOpDatatypeCombinations = options.allowInvalidOpDatatypeCombinations; - this->level = options.level; } void runOnOperation() final; @@ -207,28 +181,28 @@ private: LogicalResult levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) { - if (v > tosaLevel.MAX_KERNEL) + if (v > targetEnv.getLevel().MAX_KERNEL) return op->emitOpError() << "failed level check: " << checkDesc; return success(); } LogicalResult levelCheckStride(Operation *op, int32_t v, const StringRef checkDesc) { - if (v > tosaLevel.MAX_STRIDE) + if (v > targetEnv.getLevel().MAX_STRIDE) return op->emitOpError() << "failed level check: " << checkDesc; return success(); } LogicalResult levelCheckScale(Operation *op, int32_t v, const StringRef checkDesc) { - if (v > tosaLevel.MAX_SCALE) + if (v > targetEnv.getLevel().MAX_SCALE) return op->emitOpError() << "failed level check: " << checkDesc; return success(); } LogicalResult levelCheckListSize(Operation *op, int32_t v, const StringRef checkDesc) { - if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) + if (v > targetEnv.getLevel().MAX_TENSOR_LIST_SIZE) return op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: " << checkDesc; return success(); @@ -285,6 +259,7 @@ private: template <typename T> LogicalResult levelCheckRanks(T tosaOp) { auto op = tosaOp.getOperation(); + const TosaLevel tosaLevel = targetEnv.getLevel(); for (auto v : op->getOperands()) { if (failed(levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK))) return failure(); @@ -466,7 +441,7 @@ private: int32_t maxNestedDepth = 0; getMaxNestedDepth(op, maxNestedDepth); - if (maxNestedDepth >= tosaLevel.MAX_NESTING) { + if (maxNestedDepth >= targetEnv.getLevel().MAX_NESTING) { op->emitOpError() << "failed level check: " << maxNestedDepth << " >= MAX_NESTING"; return failure(); @@ -523,43 +498,6 @@ private: return success(); } - // configure profile and level values from pass options profileName and - // levelName - void configLevelAndProfile() { - tosaLevel = TOSA_LEVEL_NONE; - if (level == TosaLevelEnum::EightK) { - tosaLevel = TOSA_LEVEL_EIGHTK; - } - - if (!profile.empty()) { - for (std::string &prof : profile) { - auto profSymbol = symbolizeProfile(prof); - if (profSymbol) { - targetEnv.addProfile(profSymbol.value()); - } else { - llvm::errs() << "unknown TOSA profile name passed in: " << prof - << ", supported profiles are `pro_int` and `pro_fp`\n"; - return signalPassFailure(); - } - } - } - - if (!extension.empty()) { - for (std::string &ext : extension) { - auto extSymbol = symbolizeExtension(ext); - if (extSymbol) { - targetEnv.addExtension(extSymbol.value()); - } else { - llvm::errs() << "unknown TOSA extension name passed in: " << ext - << ", supported extension are int16, int4, bf16, " - << "fp8e4m3, fp8e5m2, fft, variable, controlflow, " - << "doubleround, inexactround and dynamic\n"; - return signalPassFailure(); - } - } - } - } - LogicalResult CheckVariable(Operation *op); LogicalResult CheckVariableReadOrWrite(Operation *op); bool isValidElementType(Type type, const bool allowUnsigned = false); @@ -567,7 +505,6 @@ private: SmallVector< std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>> constCheckers; - TosaLevel tosaLevel; DenseMap<StringAttr, mlir::Type> variablesMap; TosaProfileCompliance profileComp; tosa::TargetEnv targetEnv; @@ -576,13 +513,13 @@ private: template <> LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) { auto *op = tosaOp.getOperation(); - if (failed( - levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK))) + if (failed(levelCheckRank(op, tosaOp.getInput(), "operand", + targetEnv.getLevel().MAX_RANK))) return failure(); // rank(output) = rank(input) - 1 if (failed(levelCheckRank(op, tosaOp.getOutput(), "result", - tosaLevel.MAX_RANK - 1))) + targetEnv.getLevel().MAX_RANK - 1))) return failure(); return success(); @@ -594,7 +531,7 @@ LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) { // Only the condition input has rank limitation. if (failed(levelCheckRank(op, tosaOp.getCondition(), "operand", - tosaLevel.MAX_RANK))) + targetEnv.getLevel().MAX_RANK))) return failure(); return success(); @@ -605,7 +542,7 @@ LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) { auto *op = tosaOp.getOperation(); auto variableType = getVariableType(tosaOp); if (failed(levelCheckRank(op, variableType, "variable type", - tosaLevel.MAX_RANK))) + targetEnv.getLevel().MAX_RANK))) return failure(); return success(); @@ -762,7 +699,8 @@ LogicalResult TosaValidation::levelCheckSize(Operation *op, // defined in 1.7. Levels. // For each tensor, the number of tensor elements multiplied by the // element size in bytes must be representable as a tensor_size_t. - const int64_t max_size = (INT64_C(1) << tosaLevel.MAX_LOG2_SIZE) - 1; + const int64_t max_size = + (INT64_C(1) << targetEnv.getLevel().MAX_LOG2_SIZE) - 1; if (size > max_size) return op->emitOpError() << "failed level check: " << operandOrResult @@ -772,7 +710,7 @@ LogicalResult TosaValidation::levelCheckSize(Operation *op, } LogicalResult TosaValidation::applyLevelCheck(Operation *op) { - if (tosaLevel == TOSA_LEVEL_NONE) { + if (targetEnv.getLevel() == TOSA_LEVEL_NONE) { // no need to do level checks return success(); } @@ -1282,12 +1220,12 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) { } void TosaValidation::runOnOperation() { - configLevelAndProfile(); - TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>(); if (!tosaDialect) return; + targetEnv = tosa::TargetEnv(lookupTargetEnvOrDefault(getOperation())); + getOperation().walk([&](Operation *op) { if (op->getDialect() != tosaDialect) return; |