diff options
Diffstat (limited to 'mlir')
23 files changed, 307 insertions, 131 deletions
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h index f482385..ab9b9f2 100644 --- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h +++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h @@ -39,8 +39,7 @@ void addTosaToLinalgPasses( TosaToLinalgNamedOptions(), // Note: Default to 'none' level unless otherwise specified. std::optional<tosa::TosaValidationOptions> validationOptions = - tosa::TosaValidationOptions{ - {"none"}, {"none"}, false, false, tosa::TosaLevelEnum::None}); + tosa::TosaValidationOptions{false, false}); /// Populates TOSA to linalg pipelines /// Currently, this includes only the "tosa-to-linalg-pipeline". diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h index 9ee5079..10491f6 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h @@ -20,24 +20,67 @@ namespace mlir { namespace tosa { +struct TosaLevel { + int32_t MAX_RANK = 0; + int32_t MAX_KERNEL = 0; + int32_t MAX_STRIDE = 0; + int32_t MAX_SCALE = 0; + int32_t MAX_LOG2_SIZE = 0; + int32_t MAX_NESTING = 0; + int32_t MAX_TENSOR_LIST_SIZE = 0; + + bool operator==(const TosaLevel &rhs) { + return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL && + MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE && + MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE && + MAX_NESTING == rhs.MAX_NESTING && + MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE; + } +}; + +static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64}; +static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048, + 63, 256, 256}; + +TargetEnvAttr lookupTargetEnv(Operation *op); +TargetEnvAttr getDefaultTargetEnv(MLIRContext *context); + +/// Queries the target environment recursively from enclosing symbol table ops +/// containing the given `op` or returns the default target environment as +/// returned by getDefaultTargetEnv() if not provided. +TargetEnvAttr lookupTargetEnvOrDefault(Operation *op); + /// This class represents the capability enabled in the target implementation -/// such as profile, extension, and level. +/// such as profile, extension, and level. It's a wrapper class around +/// tosa::TargetEnvAttr. class TargetEnv { public: TargetEnv() {} - explicit TargetEnv(const SmallVectorImpl<Profile> &profiles, - const SmallVectorImpl<Extension> &extensions) { + explicit TargetEnv(Level level, const ArrayRef<Profile> &profiles, + const ArrayRef<Extension> &extensions) + : level(level) { enabledProfiles.insert_range(profiles); - enabledExtensions.insert_range(extensions); } + explicit TargetEnv(TargetEnvAttr targetAttr) + : TargetEnv(targetAttr.getLevel(), targetAttr.getProfiles(), + targetAttr.getExtensions()) {} + void addProfile(Profile p) { enabledProfiles.insert(p); } void addExtension(Extension e) { enabledExtensions.insert(e); } // TODO implement the following utilities. // Version getSpecVersion() const; - // TosaLevel getLevel() const; + + TosaLevel getLevel() const { + if (level == Level::eightK) + return TOSA_LEVEL_EIGHTK; + else if (level == Level::none) + return TOSA_LEVEL_NONE; + else + llvm_unreachable("Unknown TOSA level"); + }; // Returns true if the given profile is allowed. bool allows(Profile prof) const { return enabledProfiles.count(prof) != 0; } @@ -62,8 +105,9 @@ public: } private: + Level level; llvm::SmallSet<Profile, 3> enabledProfiles; - llvm::SmallSet<Extension, 8> enabledExtensions; + llvm::SmallSet<Extension, 13> enabledExtensions; }; } // namespace tosa diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td index 80337fc..38cb293 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -245,6 +245,19 @@ def Tosa_NONE : I32EnumAttrCase<"none", 0>; def Tosa_PRO_INT : I32EnumAttrCase<"pro_int", 1>; def Tosa_PRO_FP : I32EnumAttrCase<"pro_fp", 2>; +def Tosa_ProfileAttr + : Tosa_I32EnumAttr<"Profile", "supported TOSA profiles", "prof", + [Tosa_PRO_INT, Tosa_PRO_FP, Tosa_NONE]> { + let extraClassDeclaration = [{ + static llvm::SmallVector<Profile, 2> getAllValues() { + return {Profile::pro_int, Profile::pro_fp}; + } + }]; +} + +def Tosa_ProfileArrayAttr + : TypedArrayAttrBase<Tosa_ProfileAttr, "TOSA profile array attribute">; + def Tosa_EXT_NONE : I32EnumAttrCase<"none", 0>; def Tosa_EXT_INT16 : I32EnumAttrCase<"int16", 1>; def Tosa_EXT_INT4 : I32EnumAttrCase<"int4", 2>; @@ -264,17 +277,27 @@ def Tosa_ExtensionAttr Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW, Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND, Tosa_EXT_DYNAMIC - ]>; + ]> { + let extraClassDeclaration = [{ + static llvm::SmallVector<Extension, 11> getAllValues() { + return { + Extension::int16, Extension::int4, Extension::bf16, + Extension::fp8e4m3, Extension::fp8e5m2, Extension::fft, + Extension::variable, Extension::controlflow, Extension::doubleround, + Extension::inexactround, Extension::dynamic + }; + } + }]; +} def Tosa_ExtensionArrayAttr : TypedArrayAttrBase<Tosa_ExtensionAttr, "TOSA extension array attribute">; -def Tosa_ProfileAttr - : Tosa_I32EnumAttr<"Profile", "supported TOSA profiles", "prof", - [Tosa_PRO_INT, Tosa_PRO_FP, Tosa_NONE]>; +def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>; +def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">; -def Tosa_ProfileArrayAttr - : TypedArrayAttrBase<Tosa_ProfileAttr, "TOSA profile array attribute">; +def Tosa_LevelAttr + : Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>; // The base class for defining op availability dimensions. class Availability { @@ -382,6 +405,21 @@ class Extension<list<I32EnumAttrCase> extensions> : Availability { } //===----------------------------------------------------------------------===// +// TOSA target environment. +//===----------------------------------------------------------------------===// +def Tosa_TargetEnv : Tosa_Attr<"TargetEnv", "target_env"> { + let summary = "Target environment information."; + let parameters = ( ins + "Level": $level, + ArrayRefParameter<"Profile">: $profiles, + ArrayRefParameter<"Extension">: $extensions + ); + + let assemblyFormat = "`<` `level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` " + "`extensions` `=` `[` $extensions `]` `>`"; +} + +//===----------------------------------------------------------------------===// // Iterable attributes. //===----------------------------------------------------------------------===// // Defined in `section 3. Enumerations` of the TOSA specification. diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt index 7484473..f52b82a 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt @@ -1,7 +1,5 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls -name TosaOpt) -mlir_tablegen(PassesEnums.h.inc -gen-enum-decls) -mlir_tablegen(PassesEnums.cpp.inc -gen-enum-defs) add_mlir_dialect_tablegen_target(MLIRTosaPassIncGen) add_mlir_doc(Passes TosaPasses ./ -gen-pass-doc) diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h index 306e4b1..ba99d2f 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -15,7 +15,6 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Tosa/Transforms/PassesEnums.h.inc" #include "mlir/Pass/Pass.h" namespace mlir { diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td index b966828..6ae19d8 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -65,14 +65,6 @@ def TosaOptionalDecompositionsPass }]; } -def TosaLevelType : I32EnumAttr<"TosaLevelEnum", "Tosa level", - [ - I32EnumAttrCase<"None", 0, "none">, - I32EnumAttrCase<"EightK", 1, "8k">, - ]>{ - let cppNamespace = "mlir::tosa"; -} - def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> { let summary = "Validates TOSA dialect"; let description = [{ @@ -81,10 +73,6 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> { }]; let options = [ - ListOption<"profile", "profile", "std::string", - "Validate if operations match for the given profile set">, - ListOption<"extension", "extension", "std::string", - "Validate if operations match for the given extension set">, Option<"strictOpSpecAlignment", "strict-op-spec-alignment", "bool", /*default=*/"false", "Verify if the properties of certain operations align the spec requirement">, @@ -92,17 +80,7 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> { /*default=*/"false", "Disable checks for operations that are determined to be invalid due to their " "operand/result datatypes not aligning with the 'Supported Data Types' " - "sections of the specifciation">, - Option<"level", "level", "mlir::tosa::TosaLevelEnum", - /*default=*/"mlir::tosa::TosaLevelEnum::EightK", - "Validate if operator parameters are within specfication for the given level", - [{::llvm::cl::values( - clEnumValN(mlir::tosa::TosaLevelEnum::EightK, "8k", - "Ranges are expected to be sufficient for applications with frame sizes up to 8K."), - clEnumValN(mlir::tosa::TosaLevelEnum::None, "none", - "Allows the full range of arguments specified by the operations according " - "to the operation data types.") - )}]> + "sections of the specifciation"> ]; } @@ -141,4 +119,44 @@ def TosaConvertIntegerTypeToSignless : Pass<"tosa-convert-integer-type-to-signle }]; } +def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> { + let summary = "Attach tosa.target_env information to the given module."; + + let description = [{ + This pass allows the user to specify a TOSA target environment consisting of + the following components: level, profiles and extensions. + + The target environment is attached to the module as an attribute, allowing other + transformations to query the selected target and adapt their behaviour based on + this information. + }]; + + let dependentDialects = [ + "func::FuncDialect", + "tosa::TosaDialect", + ]; + + let options = [ + Option<"level", "level", "mlir::tosa::Level", + /*default=*/"mlir::tosa::Level::eightK", + "The TOSA level that operators should conform to. A TOSA level defines " + "operator argument ranges that an implementation shall support.", + [{::llvm::cl::values( + clEnumValN(mlir::tosa::Level::eightK, "8k", + "Ranges are expected to be sufficient for applications with frame " + "sizes up to 8K."), + clEnumValN(mlir::tosa::Level::none, "none", + "Allows the full range of arguments specified by the operations according " + "to the operation data types.") + )}]>, + ListOption<"profiles", "profiles", "std::string", + "The TOSA profile(s) that operators should conform to. TOSA profiles " + "enable efficient implementation on different classes of device. Each " + "profile is an independent set of operations and data type combinations.">, + ListOption<"extensions", "extensions", "std::string", + "The TOSA extension(s) that operators should conform to. TOSA profile " + "extensions define optional operation and data type combinations."> + ]; +} + #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp index c6a3ba9..e7602b4 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp @@ -115,11 +115,8 @@ void mlir::tosa::registerTosaToLinalgPipelines() { TosaToLinalgOptions tosaToLinalgOptions; TosaToLinalgNamedOptions tosaToLinalgNamedOptions; TosaValidationOptions validationOptions; - validationOptions.profile = {"none"}; - validationOptions.extension = {"none"}; validationOptions.strictOpSpecAlignment = false; validationOptions.allowInvalidOpDatatypeCombinations = false; - validationOptions.level = tosa::TosaLevelEnum::EightK; tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions, tosaToLinalgNamedOptions, validationOptions); diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt index c6a438d..a95906a 100644 --- a/mlir/lib/Dialect/Tosa/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRTosaDialect IR/TosaOps.cpp IR/TosaCanonicalizations.cpp + IR/TargetEnv.cpp Utils/ConversionUtils.cpp Utils/QuantUtils.cpp diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp new file mode 100644 index 0000000..5aad671 --- /dev/null +++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp @@ -0,0 +1,42 @@ +//===-------------- TosaTarget.cpp - TOSA Target utilities ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TargetEnv.h" + +namespace mlir { +namespace tosa { + +TargetEnvAttr lookupTargetEnv(Operation *op) { + while (op) { + op = SymbolTable::getNearestSymbolTable(op); + if (!op) + break; + + if (auto attr = op->getAttrOfType<TargetEnvAttr>(TargetEnvAttr::name)) + return attr; + + op = op->getParentOp(); + } + + return {}; +} + +TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) { + return TargetEnvAttr::get(context, Level::eightK, + {Profile::pro_int, Profile::pro_fp}, {}); +} + +TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) { + if (auto attr = lookupTargetEnv(op)) + return attr; + + return getDefaultTargetEnv(op->getContext()); +} + +} // namespace tosa +} // namespace mlir diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt index 803993b..41b338d 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_dialect_library(MLIRTosaTransforms + TosaAttachTarget.cpp TosaConvertIntegerTypeToSignless.cpp TosaDecomposeTransposeConv.cpp TosaDecomposeDepthwise.cpp diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp new file mode 100644 index 0000000..bcb880a --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp @@ -0,0 +1,87 @@ +//===- TosaAttachTarget.cpp +//------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Attach target information to a TOSA module. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tosa/IR/TargetEnv.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace tosa { + +#define GEN_PASS_DEF_TOSAATTACHTARGET +#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" + +namespace { + +class TosaAttachTarget + : public tosa::impl::TosaAttachTargetBase<TosaAttachTarget> { + using Base::Base; + +public: + void runOnOperation() override { + llvm::SmallVector<Profile, 2> selectedProfiles; + if (!profiles.empty()) { + for (const std::string &prof : profiles) { + std::optional<Profile> profSymbol = symbolizeProfile(prof); + if (!profSymbol) { + llvm::SmallVector<Profile> allProfiles = ProfileAttr::getAllValues(); + llvm::errs() << buildUnkownParameterErrorMessage(allProfiles, + "profile", prof); + return signalPassFailure(); + } + selectedProfiles.push_back(profSymbol.value()); + } + } + + llvm::SmallVector<Extension, 10> selectedExtensions; + if (!extensions.empty()) { + for (const std::string &ext : extensions) { + std::optional<Extension> extSymbol = symbolizeExtension(ext); + if (!extSymbol) { + llvm::SmallVector<Extension> allExtensions = + ExtensionAttr::getAllValues(); + llvm::errs() << buildUnkownParameterErrorMessage(allExtensions, + "extension", ext); + return signalPassFailure(); + } + selectedExtensions.push_back(extSymbol.value()); + } + } + + ModuleOp mod = getOperation(); + MLIRContext *ctx = &getContext(); + const auto targetEnvAttr = + TargetEnvAttr::get(ctx, level, selectedProfiles, selectedExtensions); + mod->setAttr(TargetEnvAttr::name, targetEnvAttr); + } + +private: + template <typename T> + std::string buildUnkownParameterErrorMessage(llvm::SmallVector<T> &enumValues, + std::string enumName, + std::string unknownArgument) { + std::string message; + llvm::raw_string_ostream os(message); + os << "Unknown TOSA " << enumName << " name passed in '" << unknownArgument + << "', supported " << enumName << "s are: "; + llvm::interleaveComma(enumValues, os); + os << "\n"; + return message; + } +}; + +} // namespace + +} // namespace tosa +} // namespace mlir diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 4fc7ce8..82f2f7e 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -14,7 +14,6 @@ #include "mlir/Dialect/Tosa/IR/TargetEnv.h" #include "mlir/Dialect/Tosa/IR/TosaProfileCompliance.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h" -#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc" #include <string> @@ -130,28 +129,6 @@ static LogicalResult checkConstantOperandNegate(Operation *op, return success(); } -struct TosaLevel { - int32_t MAX_RANK = 0; - int32_t MAX_KERNEL = 0; - int32_t MAX_STRIDE = 0; - int32_t MAX_SCALE = 0; - int32_t MAX_LOG2_SIZE = 0; - int32_t MAX_NESTING = 0; - int32_t MAX_TENSOR_LIST_SIZE = 0; - - bool operator==(const TosaLevel &rhs) { - return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL && - MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE && - MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE && - MAX_NESTING == rhs.MAX_NESTING && - MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE; - } -}; - -static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64}; -static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048, - 63, 256, 256}; - //===----------------------------------------------------------------------===// // TOSA Validation Pass. //===----------------------------------------------------------------------===// @@ -162,12 +139,9 @@ public: explicit TosaValidation(const TosaValidationOptions &options) : TosaValidation() { - this->profile = options.profile; - this->extension = options.extension; this->strictOpSpecAlignment = options.strictOpSpecAlignment; this->allowInvalidOpDatatypeCombinations = options.allowInvalidOpDatatypeCombinations; - this->level = options.level; } void runOnOperation() final; @@ -207,28 +181,28 @@ private: LogicalResult levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) { - if (v > tosaLevel.MAX_KERNEL) + if (v > targetEnv.getLevel().MAX_KERNEL) return op->emitOpError() << "failed level check: " << checkDesc; return success(); } LogicalResult levelCheckStride(Operation *op, int32_t v, const StringRef checkDesc) { - if (v > tosaLevel.MAX_STRIDE) + if (v > targetEnv.getLevel().MAX_STRIDE) return op->emitOpError() << "failed level check: " << checkDesc; return success(); } LogicalResult levelCheckScale(Operation *op, int32_t v, const StringRef checkDesc) { - if (v > tosaLevel.MAX_SCALE) + if (v > targetEnv.getLevel().MAX_SCALE) return op->emitOpError() << "failed level check: " << checkDesc; return success(); } LogicalResult levelCheckListSize(Operation *op, int32_t v, const StringRef checkDesc) { - if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) + if (v > targetEnv.getLevel().MAX_TENSOR_LIST_SIZE) return op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: " << checkDesc; return success(); @@ -285,6 +259,7 @@ private: template <typename T> LogicalResult levelCheckRanks(T tosaOp) { auto op = tosaOp.getOperation(); + const TosaLevel tosaLevel = targetEnv.getLevel(); for (auto v : op->getOperands()) { if (failed(levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK))) return failure(); @@ -466,7 +441,7 @@ private: int32_t maxNestedDepth = 0; getMaxNestedDepth(op, maxNestedDepth); - if (maxNestedDepth >= tosaLevel.MAX_NESTING) { + if (maxNestedDepth >= targetEnv.getLevel().MAX_NESTING) { op->emitOpError() << "failed level check: " << maxNestedDepth << " >= MAX_NESTING"; return failure(); @@ -523,43 +498,6 @@ private: return success(); } - // configure profile and level values from pass options profileName and - // levelName - void configLevelAndProfile() { - tosaLevel = TOSA_LEVEL_NONE; - if (level == TosaLevelEnum::EightK) { - tosaLevel = TOSA_LEVEL_EIGHTK; - } - - if (!profile.empty()) { - for (std::string &prof : profile) { - auto profSymbol = symbolizeProfile(prof); - if (profSymbol) { - targetEnv.addProfile(profSymbol.value()); - } else { - llvm::errs() << "unknown TOSA profile name passed in: " << prof - << ", supported profiles are `pro_int` and `pro_fp`\n"; - return signalPassFailure(); - } - } - } - - if (!extension.empty()) { - for (std::string &ext : extension) { - auto extSymbol = symbolizeExtension(ext); - if (extSymbol) { - targetEnv.addExtension(extSymbol.value()); - } else { - llvm::errs() << "unknown TOSA extension name passed in: " << ext - << ", supported extension are int16, int4, bf16, " - << "fp8e4m3, fp8e5m2, fft, variable, controlflow, " - << "doubleround, inexactround and dynamic\n"; - return signalPassFailure(); - } - } - } - } - LogicalResult CheckVariable(Operation *op); LogicalResult CheckVariableReadOrWrite(Operation *op); bool isValidElementType(Type type, const bool allowUnsigned = false); @@ -567,7 +505,6 @@ private: SmallVector< std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>> constCheckers; - TosaLevel tosaLevel; DenseMap<StringAttr, mlir::Type> variablesMap; TosaProfileCompliance profileComp; tosa::TargetEnv targetEnv; @@ -576,13 +513,13 @@ private: template <> LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) { auto *op = tosaOp.getOperation(); - if (failed( - levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK))) + if (failed(levelCheckRank(op, tosaOp.getInput(), "operand", + targetEnv.getLevel().MAX_RANK))) return failure(); // rank(output) = rank(input) - 1 if (failed(levelCheckRank(op, tosaOp.getOutput(), "result", - tosaLevel.MAX_RANK - 1))) + targetEnv.getLevel().MAX_RANK - 1))) return failure(); return success(); @@ -594,7 +531,7 @@ LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) { // Only the condition input has rank limitation. if (failed(levelCheckRank(op, tosaOp.getCondition(), "operand", - tosaLevel.MAX_RANK))) + targetEnv.getLevel().MAX_RANK))) return failure(); return success(); @@ -605,7 +542,7 @@ LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) { auto *op = tosaOp.getOperation(); auto variableType = getVariableType(tosaOp); if (failed(levelCheckRank(op, variableType, "variable type", - tosaLevel.MAX_RANK))) + targetEnv.getLevel().MAX_RANK))) return failure(); return success(); @@ -762,7 +699,8 @@ LogicalResult TosaValidation::levelCheckSize(Operation *op, // defined in 1.7. Levels. // For each tensor, the number of tensor elements multiplied by the // element size in bytes must be representable as a tensor_size_t. - const int64_t max_size = (INT64_C(1) << tosaLevel.MAX_LOG2_SIZE) - 1; + const int64_t max_size = + (INT64_C(1) << targetEnv.getLevel().MAX_LOG2_SIZE) - 1; if (size > max_size) return op->emitOpError() << "failed level check: " << operandOrResult @@ -772,7 +710,7 @@ LogicalResult TosaValidation::levelCheckSize(Operation *op, } LogicalResult TosaValidation::applyLevelCheck(Operation *op) { - if (tosaLevel == TOSA_LEVEL_NONE) { + if (targetEnv.getLevel() == TOSA_LEVEL_NONE) { // no need to do level checks return success(); } @@ -1282,12 +1220,12 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) { } void TosaValidation::runOnOperation() { - configLevelAndProfile(); - TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>(); if (!tosaDialect) return; + targetEnv = tosa::TargetEnv(lookupTargetEnvOrDefault(getOperation())); + getOperation().walk([&](Operation *op) { if (op->getDialect() != tosaDialect) return; diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir index aaf8371..60b70b8 100644 --- a/mlir/test/Dialect/Tosa/dynamic_extension.mlir +++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir @@ -2,7 +2,7 @@ // Check operations when the dynamic extension is enabled. //-------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic allow-invalid-op-datatype-combinations" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=dynamic" -tosa-validate="strict-op-spec-alignment allow-invalid-op-datatype-combinations" // ----- diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir index 2f9421c..334f52a 100644 --- a/mlir/test/Dialect/Tosa/error_if_check.mlir +++ b/mlir/test/Dialect/Tosa/error_if_check.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="level=none profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="level=none profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic" -tosa-validate="strict-op-spec-alignment" // ----- diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index e60f1c9b..2a3985c 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -4,7 +4,7 @@ // validation flow. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" func.func @test_cast(%arg0: tensor<i1>) -> tensor<5xi32> { diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir index 1daabe9..e5c9402 100644 --- a/mlir/test/Dialect/Tosa/invalid_extension.mlir +++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir @@ -2,7 +2,7 @@ // Enable all supported profiles to focus the verification of expected extension requirement errors. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp" -tosa-validate="strict-op-spec-alignment" // ----- func.func @test_argmax(%arg0: tensor<14x19xbf16>) -> tensor<14xi32> { diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index 5bf2dbb8..8cc357e 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="extension=dynamic" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="extensions=dynamic" -tosa-validate func.func @test_argmax_rank_invalid(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> { // expected-error@+1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}} diff --git a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir index 225b962..09e96ec 100644 --- a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir @@ -2,7 +2,7 @@ // Enable all supported extensions to focus the verification of expected profile requirement errors. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" // ----- func.func @test_add_i32(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> { diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir index 58a73d6..7ff8065 100644 --- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir @@ -2,7 +2,7 @@ // Enable all supported extensions to focus the verification of expected profile requirement errors. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" // ----- func.func @test_const_f16() -> tensor<3x11x11x3xf16> { diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir index a5784b3..48e79e4 100644 --- a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir @@ -2,7 +2,7 @@ // Enable all supported extensions to focus the verification of expected profile requirement errors. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" // ----- func.func @test_const_i1() -> tensor<3x11x11x3xi1> { diff --git a/mlir/test/Dialect/Tosa/tosa-attach-target.mlir b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir new file mode 100644 index 0000000..d6c886c --- /dev/null +++ b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt %s -split-input-file -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,dynamic level=none" | FileCheck %s --check-prefix=CHECK-ALL +// RUN: mlir-opt %s -split-input-file -tosa-attach-target="level=8k" | FileCheck %s --check-prefix=CHECK-LVL-8K +// RUN: mlir-opt %s -split-input-file -tosa-attach-target | FileCheck %s --check-prefix=CHECK-DEFAULT + +// ----- + +// CHECK-ALL: module attributes {tosa.target_env = #tosa.target_env<level = none, profiles = [pro_int, pro_fp], extensions = [int16, int4, bf16, fp8e4m3, fp8e5m2, fft, variable, controlflow, doubleround, inexactround, dynamic]>} +// CHECK-LVL-8K: module attributes {tosa.target_env = #tosa.target_env<level = "8k", profiles = [], extensions = []>} +// CHECK-DEFAULT: module attributes {tosa.target_env = #tosa.target_env<level = "8k", profiles = [], extensions = []>} +// CHECK-LABEL: test_simple +func.func @test_simple(%arg0 : tensor<1x1x1x1xf32>, %arg1 : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> { + %1 = tosa.add %arg0, %arg1 : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> + return %1 : tensor<1x1x1x1xf32> +} diff --git a/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir b/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir index f05ae7f..8e0ad0a 100644 --- a/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir +++ b/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround strict-op-spec-alignment" | FileCheck %s +// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" --tosa-validate="strict-op-spec-alignment" | FileCheck %s // ----- diff --git a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir index 88ec027..663159e 100644 --- a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir +++ b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir @@ -4,7 +4,7 @@ // validation flow. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" | FileCheck %s +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate | FileCheck %s // ----- |