aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h3
-rw-r--r--mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h56
-rw-r--r--mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td50
-rw-r--r--mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h1
-rw-r--r--mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td64
-rw-r--r--mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp3
-rw-r--r--mlir/lib/Dialect/Tosa/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp42
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp87
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp94
-rw-r--r--mlir/test/Dialect/Tosa/dynamic_extension.mlir2
-rw-r--r--mlir/test/Dialect/Tosa/error_if_check.mlir2
-rw-r--r--mlir/test/Dialect/Tosa/invalid.mlir2
-rw-r--r--mlir/test/Dialect/Tosa/invalid_extension.mlir2
-rw-r--r--mlir/test/Dialect/Tosa/level_check.mlir2
-rw-r--r--mlir/test/Dialect/Tosa/profile_all_unsupported.mlir2
-rw-r--r--mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir2
-rw-r--r--mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir2
-rw-r--r--mlir/test/Dialect/Tosa/tosa-attach-target.mlir14
-rw-r--r--mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir2
-rw-r--r--mlir/test/Dialect/Tosa/tosa-validation-valid.mlir2
23 files changed, 307 insertions, 131 deletions
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index f482385..ab9b9f2 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -39,8 +39,7 @@ void addTosaToLinalgPasses(
TosaToLinalgNamedOptions(),
// Note: Default to 'none' level unless otherwise specified.
std::optional<tosa::TosaValidationOptions> validationOptions =
- tosa::TosaValidationOptions{
- {"none"}, {"none"}, false, false, tosa::TosaLevelEnum::None});
+ tosa::TosaValidationOptions{false, false});
/// Populates TOSA to linalg pipelines
/// Currently, this includes only the "tosa-to-linalg-pipeline".
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
index 9ee5079..10491f6 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
@@ -20,24 +20,67 @@
namespace mlir {
namespace tosa {
+struct TosaLevel {
+ int32_t MAX_RANK = 0;
+ int32_t MAX_KERNEL = 0;
+ int32_t MAX_STRIDE = 0;
+ int32_t MAX_SCALE = 0;
+ int32_t MAX_LOG2_SIZE = 0;
+ int32_t MAX_NESTING = 0;
+ int32_t MAX_TENSOR_LIST_SIZE = 0;
+
+ bool operator==(const TosaLevel &rhs) {
+ return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
+ MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
+ MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
+ MAX_NESTING == rhs.MAX_NESTING &&
+ MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
+ }
+};
+
+static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64};
+static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
+ 63, 256, 256};
+
+TargetEnvAttr lookupTargetEnv(Operation *op);
+TargetEnvAttr getDefaultTargetEnv(MLIRContext *context);
+
+/// Queries the target environment recursively from enclosing symbol table ops
+/// containing the given `op` or returns the default target environment as
+/// returned by getDefaultTargetEnv() if not provided.
+TargetEnvAttr lookupTargetEnvOrDefault(Operation *op);
+
/// This class represents the capability enabled in the target implementation
-/// such as profile, extension, and level.
+/// such as profile, extension, and level. It's a wrapper class around
+/// tosa::TargetEnvAttr.
class TargetEnv {
public:
TargetEnv() {}
- explicit TargetEnv(const SmallVectorImpl<Profile> &profiles,
- const SmallVectorImpl<Extension> &extensions) {
+ explicit TargetEnv(Level level, const ArrayRef<Profile> &profiles,
+ const ArrayRef<Extension> &extensions)
+ : level(level) {
enabledProfiles.insert_range(profiles);
-
enabledExtensions.insert_range(extensions);
}
+ explicit TargetEnv(TargetEnvAttr targetAttr)
+ : TargetEnv(targetAttr.getLevel(), targetAttr.getProfiles(),
+ targetAttr.getExtensions()) {}
+
void addProfile(Profile p) { enabledProfiles.insert(p); }
void addExtension(Extension e) { enabledExtensions.insert(e); }
// TODO implement the following utilities.
// Version getSpecVersion() const;
- // TosaLevel getLevel() const;
+
+ TosaLevel getLevel() const {
+ if (level == Level::eightK)
+ return TOSA_LEVEL_EIGHTK;
+ else if (level == Level::none)
+ return TOSA_LEVEL_NONE;
+ else
+ llvm_unreachable("Unknown TOSA level");
+ };
// Returns true if the given profile is allowed.
bool allows(Profile prof) const { return enabledProfiles.count(prof) != 0; }
@@ -62,8 +105,9 @@ public:
}
private:
+ Level level;
llvm::SmallSet<Profile, 3> enabledProfiles;
- llvm::SmallSet<Extension, 8> enabledExtensions;
+ llvm::SmallSet<Extension, 13> enabledExtensions;
};
} // namespace tosa
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 80337fc..38cb293 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -245,6 +245,19 @@ def Tosa_NONE : I32EnumAttrCase<"none", 0>;
def Tosa_PRO_INT : I32EnumAttrCase<"pro_int", 1>;
def Tosa_PRO_FP : I32EnumAttrCase<"pro_fp", 2>;
+def Tosa_ProfileAttr
+ : Tosa_I32EnumAttr<"Profile", "supported TOSA profiles", "prof",
+ [Tosa_PRO_INT, Tosa_PRO_FP, Tosa_NONE]> {
+ let extraClassDeclaration = [{
+ static llvm::SmallVector<Profile, 2> getAllValues() {
+ return {Profile::pro_int, Profile::pro_fp};
+ }
+ }];
+}
+
+def Tosa_ProfileArrayAttr
+ : TypedArrayAttrBase<Tosa_ProfileAttr, "TOSA profile array attribute">;
+
def Tosa_EXT_NONE : I32EnumAttrCase<"none", 0>;
def Tosa_EXT_INT16 : I32EnumAttrCase<"int16", 1>;
def Tosa_EXT_INT4 : I32EnumAttrCase<"int4", 2>;
@@ -264,17 +277,27 @@ def Tosa_ExtensionAttr
Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE,
Tosa_EXT_CONTROLFLOW, Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND,
Tosa_EXT_DYNAMIC
- ]>;
+ ]> {
+ let extraClassDeclaration = [{
+ static llvm::SmallVector<Extension, 11> getAllValues() {
+ return {
+ Extension::int16, Extension::int4, Extension::bf16,
+ Extension::fp8e4m3, Extension::fp8e5m2, Extension::fft,
+ Extension::variable, Extension::controlflow, Extension::doubleround,
+ Extension::inexactround, Extension::dynamic
+ };
+ }
+ }];
+}
def Tosa_ExtensionArrayAttr
: TypedArrayAttrBase<Tosa_ExtensionAttr, "TOSA extension array attribute">;
-def Tosa_ProfileAttr
- : Tosa_I32EnumAttr<"Profile", "supported TOSA profiles", "prof",
- [Tosa_PRO_INT, Tosa_PRO_FP, Tosa_NONE]>;
+def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>;
+def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">;
-def Tosa_ProfileArrayAttr
- : TypedArrayAttrBase<Tosa_ProfileAttr, "TOSA profile array attribute">;
+def Tosa_LevelAttr
+ : Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>;
// The base class for defining op availability dimensions.
class Availability {
@@ -382,6 +405,21 @@ class Extension<list<I32EnumAttrCase> extensions> : Availability {
}
//===----------------------------------------------------------------------===//
+// TOSA target environment.
+//===----------------------------------------------------------------------===//
+def Tosa_TargetEnv : Tosa_Attr<"TargetEnv", "target_env"> {
+ let summary = "Target environment information.";
+ let parameters = ( ins
+ "Level": $level,
+ ArrayRefParameter<"Profile">: $profiles,
+ ArrayRefParameter<"Extension">: $extensions
+ );
+
+ let assemblyFormat = "`<` `level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` "
+ "`extensions` `=` `[` $extensions `]` `>`";
+}
+
+//===----------------------------------------------------------------------===//
// Iterable attributes.
//===----------------------------------------------------------------------===//
// Defined in `section 3. Enumerations` of the TOSA specification.
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
index 7484473..f52b82a 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,7 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name TosaOpt)
-mlir_tablegen(PassesEnums.h.inc -gen-enum-decls)
-mlir_tablegen(PassesEnums.cpp.inc -gen-enum-defs)
add_mlir_dialect_tablegen_target(MLIRTosaPassIncGen)
add_mlir_doc(Passes TosaPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 306e4b1..ba99d2f 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -15,7 +15,6 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/Dialect/Tosa/Transforms/PassesEnums.h.inc"
#include "mlir/Pass/Pass.h"
namespace mlir {
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index b966828..6ae19d8 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -65,14 +65,6 @@ def TosaOptionalDecompositionsPass
}];
}
-def TosaLevelType : I32EnumAttr<"TosaLevelEnum", "Tosa level",
- [
- I32EnumAttrCase<"None", 0, "none">,
- I32EnumAttrCase<"EightK", 1, "8k">,
- ]>{
- let cppNamespace = "mlir::tosa";
-}
-
def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
let summary = "Validates TOSA dialect";
let description = [{
@@ -81,10 +73,6 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
}];
let options = [
- ListOption<"profile", "profile", "std::string",
- "Validate if operations match for the given profile set">,
- ListOption<"extension", "extension", "std::string",
- "Validate if operations match for the given extension set">,
Option<"strictOpSpecAlignment", "strict-op-spec-alignment", "bool",
/*default=*/"false",
"Verify if the properties of certain operations align the spec requirement">,
@@ -92,17 +80,7 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
/*default=*/"false",
"Disable checks for operations that are determined to be invalid due to their "
"operand/result datatypes not aligning with the 'Supported Data Types' "
- "sections of the specifciation">,
- Option<"level", "level", "mlir::tosa::TosaLevelEnum",
- /*default=*/"mlir::tosa::TosaLevelEnum::EightK",
- "Validate if operator parameters are within specfication for the given level",
- [{::llvm::cl::values(
- clEnumValN(mlir::tosa::TosaLevelEnum::EightK, "8k",
- "Ranges are expected to be sufficient for applications with frame sizes up to 8K."),
- clEnumValN(mlir::tosa::TosaLevelEnum::None, "none",
- "Allows the full range of arguments specified by the operations according "
- "to the operation data types.")
- )}]>
+ "sections of the specifciation">
];
}
@@ -141,4 +119,44 @@ def TosaConvertIntegerTypeToSignless : Pass<"tosa-convert-integer-type-to-signle
}];
}
+def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> {
+ let summary = "Attach tosa.target_env information to the given module.";
+
+ let description = [{
+ This pass allows the user to specify a TOSA target environment consisting of
+ the following components: level, profiles and extensions.
+
+ The target environment is attached to the module as an attribute, allowing other
+ transformations to query the selected target and adapt their behaviour based on
+ this information.
+ }];
+
+ let dependentDialects = [
+ "func::FuncDialect",
+ "tosa::TosaDialect",
+ ];
+
+ let options = [
+ Option<"level", "level", "mlir::tosa::Level",
+ /*default=*/"mlir::tosa::Level::eightK",
+ "The TOSA level that operators should conform to. A TOSA level defines "
+ "operator argument ranges that an implementation shall support.",
+ [{::llvm::cl::values(
+ clEnumValN(mlir::tosa::Level::eightK, "8k",
+ "Ranges are expected to be sufficient for applications with frame "
+ "sizes up to 8K."),
+ clEnumValN(mlir::tosa::Level::none, "none",
+ "Allows the full range of arguments specified by the operations according "
+ "to the operation data types.")
+ )}]>,
+ ListOption<"profiles", "profiles", "std::string",
+ "The TOSA profile(s) that operators should conform to. TOSA profiles "
+ "enable efficient implementation on different classes of device. Each "
+ "profile is an independent set of operations and data type combinations.">,
+ ListOption<"extensions", "extensions", "std::string",
+ "The TOSA extension(s) that operators should conform to. TOSA profile "
+ "extensions define optional operation and data type combinations.">
+ ];
+}
+
#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index c6a3ba9..e7602b4 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -115,11 +115,8 @@ void mlir::tosa::registerTosaToLinalgPipelines() {
TosaToLinalgOptions tosaToLinalgOptions;
TosaToLinalgNamedOptions tosaToLinalgNamedOptions;
TosaValidationOptions validationOptions;
- validationOptions.profile = {"none"};
- validationOptions.extension = {"none"};
validationOptions.strictOpSpecAlignment = false;
validationOptions.allowInvalidOpDatatypeCombinations = false;
- validationOptions.level = tosa::TosaLevelEnum::EightK;
tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
tosaToLinalgNamedOptions,
validationOptions);
diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt
index c6a438d..a95906a 100644
--- a/mlir/lib/Dialect/Tosa/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRTosaDialect
IR/TosaOps.cpp
IR/TosaCanonicalizations.cpp
+ IR/TargetEnv.cpp
Utils/ConversionUtils.cpp
Utils/QuantUtils.cpp
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
new file mode 100644
index 0000000..5aad671
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -0,0 +1,42 @@
+//===-------------- TosaTarget.cpp - TOSA Target utilities ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
+
+namespace mlir {
+namespace tosa {
+
+TargetEnvAttr lookupTargetEnv(Operation *op) {
+ while (op) {
+ op = SymbolTable::getNearestSymbolTable(op);
+ if (!op)
+ break;
+
+ if (auto attr = op->getAttrOfType<TargetEnvAttr>(TargetEnvAttr::name))
+ return attr;
+
+ op = op->getParentOp();
+ }
+
+ return {};
+}
+
+TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) {
+ return TargetEnvAttr::get(context, Level::eightK,
+ {Profile::pro_int, Profile::pro_fp}, {});
+}
+
+TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) {
+ if (auto attr = lookupTargetEnv(op))
+ return attr;
+
+ return getDefaultTargetEnv(op->getContext());
+}
+
+} // namespace tosa
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 803993b..41b338d 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRTosaTransforms
+ TosaAttachTarget.cpp
TosaConvertIntegerTypeToSignless.cpp
TosaDecomposeTransposeConv.cpp
TosaDecomposeDepthwise.cpp
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
new file mode 100644
index 0000000..bcb880a
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
@@ -0,0 +1,87 @@
+//===- TosaAttachTarget.cpp
+//------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Attach target information to a TOSA module.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace tosa {
+
+#define GEN_PASS_DEF_TOSAATTACHTARGET
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+
+namespace {
+
+class TosaAttachTarget
+ : public tosa::impl::TosaAttachTargetBase<TosaAttachTarget> {
+ using Base::Base;
+
+public:
+ void runOnOperation() override {
+ llvm::SmallVector<Profile, 2> selectedProfiles;
+ if (!profiles.empty()) {
+ for (const std::string &prof : profiles) {
+ std::optional<Profile> profSymbol = symbolizeProfile(prof);
+ if (!profSymbol) {
+ llvm::SmallVector<Profile> allProfiles = ProfileAttr::getAllValues();
+ llvm::errs() << buildUnkownParameterErrorMessage(allProfiles,
+ "profile", prof);
+ return signalPassFailure();
+ }
+ selectedProfiles.push_back(profSymbol.value());
+ }
+ }
+
+ llvm::SmallVector<Extension, 10> selectedExtensions;
+ if (!extensions.empty()) {
+ for (const std::string &ext : extensions) {
+ std::optional<Extension> extSymbol = symbolizeExtension(ext);
+ if (!extSymbol) {
+ llvm::SmallVector<Extension> allExtensions =
+ ExtensionAttr::getAllValues();
+ llvm::errs() << buildUnkownParameterErrorMessage(allExtensions,
+ "extension", ext);
+ return signalPassFailure();
+ }
+ selectedExtensions.push_back(extSymbol.value());
+ }
+ }
+
+ ModuleOp mod = getOperation();
+ MLIRContext *ctx = &getContext();
+ const auto targetEnvAttr =
+ TargetEnvAttr::get(ctx, level, selectedProfiles, selectedExtensions);
+ mod->setAttr(TargetEnvAttr::name, targetEnvAttr);
+ }
+
+private:
+ template <typename T>
+ std::string buildUnkownParameterErrorMessage(llvm::SmallVector<T> &enumValues,
+ std::string enumName,
+ std::string unknownArgument) {
+ std::string message;
+ llvm::raw_string_ostream os(message);
+ os << "Unknown TOSA " << enumName << " name passed in '" << unknownArgument
+ << "', supported " << enumName << "s are: ";
+ llvm::interleaveComma(enumValues, os);
+ os << "\n";
+ return message;
+ }
+};
+
+} // namespace
+
+} // namespace tosa
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 4fc7ce8..82f2f7e 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -14,7 +14,6 @@
#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
#include "mlir/Dialect/Tosa/IR/TosaProfileCompliance.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
-#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
#include <string>
@@ -130,28 +129,6 @@ static LogicalResult checkConstantOperandNegate(Operation *op,
return success();
}
-struct TosaLevel {
- int32_t MAX_RANK = 0;
- int32_t MAX_KERNEL = 0;
- int32_t MAX_STRIDE = 0;
- int32_t MAX_SCALE = 0;
- int32_t MAX_LOG2_SIZE = 0;
- int32_t MAX_NESTING = 0;
- int32_t MAX_TENSOR_LIST_SIZE = 0;
-
- bool operator==(const TosaLevel &rhs) {
- return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
- MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
- MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
- MAX_NESTING == rhs.MAX_NESTING &&
- MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
- }
-};
-
-static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64};
-static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
- 63, 256, 256};
-
//===----------------------------------------------------------------------===//
// TOSA Validation Pass.
//===----------------------------------------------------------------------===//
@@ -162,12 +139,9 @@ public:
explicit TosaValidation(const TosaValidationOptions &options)
: TosaValidation() {
- this->profile = options.profile;
- this->extension = options.extension;
this->strictOpSpecAlignment = options.strictOpSpecAlignment;
this->allowInvalidOpDatatypeCombinations =
options.allowInvalidOpDatatypeCombinations;
- this->level = options.level;
}
void runOnOperation() final;
@@ -207,28 +181,28 @@ private:
LogicalResult levelCheckKernel(Operation *op, int32_t v,
const StringRef checkDesc) {
- if (v > tosaLevel.MAX_KERNEL)
+ if (v > targetEnv.getLevel().MAX_KERNEL)
return op->emitOpError() << "failed level check: " << checkDesc;
return success();
}
LogicalResult levelCheckStride(Operation *op, int32_t v,
const StringRef checkDesc) {
- if (v > tosaLevel.MAX_STRIDE)
+ if (v > targetEnv.getLevel().MAX_STRIDE)
return op->emitOpError() << "failed level check: " << checkDesc;
return success();
}
LogicalResult levelCheckScale(Operation *op, int32_t v,
const StringRef checkDesc) {
- if (v > tosaLevel.MAX_SCALE)
+ if (v > targetEnv.getLevel().MAX_SCALE)
return op->emitOpError() << "failed level check: " << checkDesc;
return success();
}
LogicalResult levelCheckListSize(Operation *op, int32_t v,
const StringRef checkDesc) {
- if (v > tosaLevel.MAX_TENSOR_LIST_SIZE)
+ if (v > targetEnv.getLevel().MAX_TENSOR_LIST_SIZE)
return op->emitOpError()
<< "failed level check for MAX_TENSOR_LIST_SIZE: " << checkDesc;
return success();
@@ -285,6 +259,7 @@ private:
template <typename T>
LogicalResult levelCheckRanks(T tosaOp) {
auto op = tosaOp.getOperation();
+ const TosaLevel tosaLevel = targetEnv.getLevel();
for (auto v : op->getOperands()) {
if (failed(levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK)))
return failure();
@@ -466,7 +441,7 @@ private:
int32_t maxNestedDepth = 0;
getMaxNestedDepth(op, maxNestedDepth);
- if (maxNestedDepth >= tosaLevel.MAX_NESTING) {
+ if (maxNestedDepth >= targetEnv.getLevel().MAX_NESTING) {
op->emitOpError() << "failed level check: " << maxNestedDepth
<< " >= MAX_NESTING";
return failure();
@@ -523,43 +498,6 @@ private:
return success();
}
- // configure profile and level values from pass options profileName and
- // levelName
- void configLevelAndProfile() {
- tosaLevel = TOSA_LEVEL_NONE;
- if (level == TosaLevelEnum::EightK) {
- tosaLevel = TOSA_LEVEL_EIGHTK;
- }
-
- if (!profile.empty()) {
- for (std::string &prof : profile) {
- auto profSymbol = symbolizeProfile(prof);
- if (profSymbol) {
- targetEnv.addProfile(profSymbol.value());
- } else {
- llvm::errs() << "unknown TOSA profile name passed in: " << prof
- << ", supported profiles are `pro_int` and `pro_fp`\n";
- return signalPassFailure();
- }
- }
- }
-
- if (!extension.empty()) {
- for (std::string &ext : extension) {
- auto extSymbol = symbolizeExtension(ext);
- if (extSymbol) {
- targetEnv.addExtension(extSymbol.value());
- } else {
- llvm::errs() << "unknown TOSA extension name passed in: " << ext
- << ", supported extension are int16, int4, bf16, "
- << "fp8e4m3, fp8e5m2, fft, variable, controlflow, "
- << "doubleround, inexactround and dynamic\n";
- return signalPassFailure();
- }
- }
- }
- }
-
LogicalResult CheckVariable(Operation *op);
LogicalResult CheckVariableReadOrWrite(Operation *op);
bool isValidElementType(Type type, const bool allowUnsigned = false);
@@ -567,7 +505,6 @@ private:
SmallVector<
std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
constCheckers;
- TosaLevel tosaLevel;
DenseMap<StringAttr, mlir::Type> variablesMap;
TosaProfileCompliance profileComp;
tosa::TargetEnv targetEnv;
@@ -576,13 +513,13 @@ private:
template <>
LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
auto *op = tosaOp.getOperation();
- if (failed(
- levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK)))
+ if (failed(levelCheckRank(op, tosaOp.getInput(), "operand",
+ targetEnv.getLevel().MAX_RANK)))
return failure();
// rank(output) = rank(input) - 1
if (failed(levelCheckRank(op, tosaOp.getOutput(), "result",
- tosaLevel.MAX_RANK - 1)))
+ targetEnv.getLevel().MAX_RANK - 1)))
return failure();
return success();
@@ -594,7 +531,7 @@ LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
// Only the condition input has rank limitation.
if (failed(levelCheckRank(op, tosaOp.getCondition(), "operand",
- tosaLevel.MAX_RANK)))
+ targetEnv.getLevel().MAX_RANK)))
return failure();
return success();
@@ -605,7 +542,7 @@ LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
auto *op = tosaOp.getOperation();
auto variableType = getVariableType(tosaOp);
if (failed(levelCheckRank(op, variableType, "variable type",
- tosaLevel.MAX_RANK)))
+ targetEnv.getLevel().MAX_RANK)))
return failure();
return success();
@@ -762,7 +699,8 @@ LogicalResult TosaValidation::levelCheckSize(Operation *op,
// defined in 1.7. Levels.
// For each tensor, the number of tensor elements multiplied by the
// element size in bytes must be representable as a tensor_size_t.
- const int64_t max_size = (INT64_C(1) << tosaLevel.MAX_LOG2_SIZE) - 1;
+ const int64_t max_size =
+ (INT64_C(1) << targetEnv.getLevel().MAX_LOG2_SIZE) - 1;
if (size > max_size)
return op->emitOpError()
<< "failed level check: " << operandOrResult
@@ -772,7 +710,7 @@ LogicalResult TosaValidation::levelCheckSize(Operation *op,
}
LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
- if (tosaLevel == TOSA_LEVEL_NONE) {
+ if (targetEnv.getLevel() == TOSA_LEVEL_NONE) {
// no need to do level checks
return success();
}
@@ -1282,12 +1220,12 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
}
void TosaValidation::runOnOperation() {
- configLevelAndProfile();
-
TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>();
if (!tosaDialect)
return;
+ targetEnv = tosa::TargetEnv(lookupTargetEnvOrDefault(getOperation()));
+
getOperation().walk([&](Operation *op) {
if (op->getDialect() != tosaDialect)
return;
diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
index aaf8371..60b70b8 100644
--- a/mlir/test/Dialect/Tosa/dynamic_extension.mlir
+++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
@@ -2,7 +2,7 @@
// Check operations when the dynamic extension is enabled.
//--------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic allow-invalid-op-datatype-combinations"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=dynamic" -tosa-validate="strict-op-spec-alignment allow-invalid-op-datatype-combinations"
// -----
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index 2f9421c..334f52a 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="level=none profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="level=none profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic" -tosa-validate="strict-op-spec-alignment"
// -----
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index e60f1c9b..2a3985c 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -4,7 +4,7 @@
// validation flow.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
func.func @test_cast(%arg0: tensor<i1>) -> tensor<5xi32> {
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 1daabe9..e5c9402 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -2,7 +2,7 @@
// Enable all supported profiles to focus the verification of expected extension requirement errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp" -tosa-validate="strict-op-spec-alignment"
// -----
func.func @test_argmax(%arg0: tensor<14x19xbf16>) -> tensor<14xi32> {
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 5bf2dbb8..8cc357e 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="extension=dynamic"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="extensions=dynamic" -tosa-validate
func.func @test_argmax_rank_invalid(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> {
// expected-error@+1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}}
diff --git a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
index 225b962..09e96ec 100644
--- a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
@@ -2,7 +2,7 @@
// Enable all supported extensions to focus the verification of expected profile requirement errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
// -----
func.func @test_add_i32(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
index 58a73d6..7ff8065 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
@@ -2,7 +2,7 @@
// Enable all supported extensions to focus the verification of expected profile requirement errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
// -----
func.func @test_const_f16() -> tensor<3x11x11x3xf16> {
diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
index a5784b3..48e79e4 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
@@ -2,7 +2,7 @@
// Enable all supported extensions to focus the verification of expected profile requirement errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
// -----
func.func @test_const_i1() -> tensor<3x11x11x3xi1> {
diff --git a/mlir/test/Dialect/Tosa/tosa-attach-target.mlir b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir
new file mode 100644
index 0000000..d6c886c
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -split-input-file -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,dynamic level=none" | FileCheck %s --check-prefix=CHECK-ALL
+// RUN: mlir-opt %s -split-input-file -tosa-attach-target="level=8k" | FileCheck %s --check-prefix=CHECK-LVL-8K
+// RUN: mlir-opt %s -split-input-file -tosa-attach-target | FileCheck %s --check-prefix=CHECK-DEFAULT
+
+// -----
+
+// CHECK-ALL: module attributes {tosa.target_env = #tosa.target_env<level = none, profiles = [pro_int, pro_fp], extensions = [int16, int4, bf16, fp8e4m3, fp8e5m2, fft, variable, controlflow, doubleround, inexactround, dynamic]>}
+// CHECK-LVL-8K: module attributes {tosa.target_env = #tosa.target_env<level = "8k", profiles = [], extensions = []>}
+// CHECK-DEFAULT: module attributes {tosa.target_env = #tosa.target_env<level = "8k", profiles = [], extensions = []>}
+// CHECK-LABEL: test_simple
+func.func @test_simple(%arg0 : tensor<1x1x1x1xf32>, %arg1 : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> {
+ %1 = tosa.add %arg0, %arg1 : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
+ return %1 : tensor<1x1x1x1xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir b/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir
index f05ae7f..8e0ad0a 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround strict-op-spec-alignment" | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" --tosa-validate="strict-op-spec-alignment" | FileCheck %s
// -----
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
index 88ec027..663159e 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
@@ -4,7 +4,7 @@
// validation flow.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate | FileCheck %s
// -----