aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Tosa
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Tosa')
-rw-r--r--mlir/lib/Dialect/Tosa/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp42
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp87
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp94
5 files changed, 147 insertions, 78 deletions
diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt
index c6a438d..a95906a 100644
--- a/mlir/lib/Dialect/Tosa/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRTosaDialect
IR/TosaOps.cpp
IR/TosaCanonicalizations.cpp
+ IR/TargetEnv.cpp
Utils/ConversionUtils.cpp
Utils/QuantUtils.cpp
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
new file mode 100644
index 0000000..5aad671
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -0,0 +1,42 @@
+//===-------------- TosaTarget.cpp - TOSA Target utilities ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
+
+namespace mlir {
+namespace tosa {
+
+TargetEnvAttr lookupTargetEnv(Operation *op) {
+ while (op) {
+ op = SymbolTable::getNearestSymbolTable(op);
+ if (!op)
+ break;
+
+ if (auto attr = op->getAttrOfType<TargetEnvAttr>(TargetEnvAttr::name))
+ return attr;
+
+ op = op->getParentOp();
+ }
+
+ return {};
+}
+
+TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) {
+ return TargetEnvAttr::get(context, Level::eightK,
+ {Profile::pro_int, Profile::pro_fp}, {});
+}
+
+TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) {
+ if (auto attr = lookupTargetEnv(op))
+ return attr;
+
+ return getDefaultTargetEnv(op->getContext());
+}
+
+} // namespace tosa
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 803993b..41b338d 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRTosaTransforms
+ TosaAttachTarget.cpp
TosaConvertIntegerTypeToSignless.cpp
TosaDecomposeTransposeConv.cpp
TosaDecomposeDepthwise.cpp
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
new file mode 100644
index 0000000..bcb880a
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
@@ -0,0 +1,87 @@
+//===- TosaAttachTarget.cpp
+//------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Attach target information to a TOSA module.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace tosa {
+
+#define GEN_PASS_DEF_TOSAATTACHTARGET
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+
+namespace {
+
+class TosaAttachTarget
+ : public tosa::impl::TosaAttachTargetBase<TosaAttachTarget> {
+ using Base::Base;
+
+public:
+ void runOnOperation() override {
+ llvm::SmallVector<Profile, 2> selectedProfiles;
+ if (!profiles.empty()) {
+ for (const std::string &prof : profiles) {
+ std::optional<Profile> profSymbol = symbolizeProfile(prof);
+ if (!profSymbol) {
+ llvm::SmallVector<Profile> allProfiles = ProfileAttr::getAllValues();
+ llvm::errs() << buildUnkownParameterErrorMessage(allProfiles,
+ "profile", prof);
+ return signalPassFailure();
+ }
+ selectedProfiles.push_back(profSymbol.value());
+ }
+ }
+
+ llvm::SmallVector<Extension, 10> selectedExtensions;
+ if (!extensions.empty()) {
+ for (const std::string &ext : extensions) {
+ std::optional<Extension> extSymbol = symbolizeExtension(ext);
+ if (!extSymbol) {
+ llvm::SmallVector<Extension> allExtensions =
+ ExtensionAttr::getAllValues();
+ llvm::errs() << buildUnkownParameterErrorMessage(allExtensions,
+ "extension", ext);
+ return signalPassFailure();
+ }
+ selectedExtensions.push_back(extSymbol.value());
+ }
+ }
+
+ ModuleOp mod = getOperation();
+ MLIRContext *ctx = &getContext();
+ const auto targetEnvAttr =
+ TargetEnvAttr::get(ctx, level, selectedProfiles, selectedExtensions);
+ mod->setAttr(TargetEnvAttr::name, targetEnvAttr);
+ }
+
+private:
+ template <typename T>
+ std::string buildUnkownParameterErrorMessage(llvm::SmallVector<T> &enumValues,
+ std::string enumName,
+ std::string unknownArgument) {
+ std::string message;
+ llvm::raw_string_ostream os(message);
+ os << "Unknown TOSA " << enumName << " name passed in '" << unknownArgument
+ << "', supported " << enumName << "s are: ";
+ llvm::interleaveComma(enumValues, os);
+ os << "\n";
+ return message;
+ }
+};
+
+} // namespace
+
+} // namespace tosa
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 4fc7ce8..82f2f7e 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -14,7 +14,6 @@
#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
#include "mlir/Dialect/Tosa/IR/TosaProfileCompliance.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
-#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
#include <string>
@@ -130,28 +129,6 @@ static LogicalResult checkConstantOperandNegate(Operation *op,
return success();
}
-struct TosaLevel {
- int32_t MAX_RANK = 0;
- int32_t MAX_KERNEL = 0;
- int32_t MAX_STRIDE = 0;
- int32_t MAX_SCALE = 0;
- int32_t MAX_LOG2_SIZE = 0;
- int32_t MAX_NESTING = 0;
- int32_t MAX_TENSOR_LIST_SIZE = 0;
-
- bool operator==(const TosaLevel &rhs) {
- return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
- MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
- MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
- MAX_NESTING == rhs.MAX_NESTING &&
- MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
- }
-};
-
-static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64};
-static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
- 63, 256, 256};
-
//===----------------------------------------------------------------------===//
// TOSA Validation Pass.
//===----------------------------------------------------------------------===//
@@ -162,12 +139,9 @@ public:
explicit TosaValidation(const TosaValidationOptions &options)
: TosaValidation() {
- this->profile = options.profile;
- this->extension = options.extension;
this->strictOpSpecAlignment = options.strictOpSpecAlignment;
this->allowInvalidOpDatatypeCombinations =
options.allowInvalidOpDatatypeCombinations;
- this->level = options.level;
}
void runOnOperation() final;
@@ -207,28 +181,28 @@ private:
LogicalResult levelCheckKernel(Operation *op, int32_t v,
const StringRef checkDesc) {
- if (v > tosaLevel.MAX_KERNEL)
+ if (v > targetEnv.getLevel().MAX_KERNEL)
return op->emitOpError() << "failed level check: " << checkDesc;
return success();
}
LogicalResult levelCheckStride(Operation *op, int32_t v,
const StringRef checkDesc) {
- if (v > tosaLevel.MAX_STRIDE)
+ if (v > targetEnv.getLevel().MAX_STRIDE)
return op->emitOpError() << "failed level check: " << checkDesc;
return success();
}
LogicalResult levelCheckScale(Operation *op, int32_t v,
const StringRef checkDesc) {
- if (v > tosaLevel.MAX_SCALE)
+ if (v > targetEnv.getLevel().MAX_SCALE)
return op->emitOpError() << "failed level check: " << checkDesc;
return success();
}
LogicalResult levelCheckListSize(Operation *op, int32_t v,
const StringRef checkDesc) {
- if (v > tosaLevel.MAX_TENSOR_LIST_SIZE)
+ if (v > targetEnv.getLevel().MAX_TENSOR_LIST_SIZE)
return op->emitOpError()
<< "failed level check for MAX_TENSOR_LIST_SIZE: " << checkDesc;
return success();
@@ -285,6 +259,7 @@ private:
template <typename T>
LogicalResult levelCheckRanks(T tosaOp) {
auto op = tosaOp.getOperation();
+ const TosaLevel tosaLevel = targetEnv.getLevel();
for (auto v : op->getOperands()) {
if (failed(levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK)))
return failure();
@@ -466,7 +441,7 @@ private:
int32_t maxNestedDepth = 0;
getMaxNestedDepth(op, maxNestedDepth);
- if (maxNestedDepth >= tosaLevel.MAX_NESTING) {
+ if (maxNestedDepth >= targetEnv.getLevel().MAX_NESTING) {
op->emitOpError() << "failed level check: " << maxNestedDepth
<< " >= MAX_NESTING";
return failure();
@@ -523,43 +498,6 @@ private:
return success();
}
- // configure profile and level values from pass options profileName and
- // levelName
- void configLevelAndProfile() {
- tosaLevel = TOSA_LEVEL_NONE;
- if (level == TosaLevelEnum::EightK) {
- tosaLevel = TOSA_LEVEL_EIGHTK;
- }
-
- if (!profile.empty()) {
- for (std::string &prof : profile) {
- auto profSymbol = symbolizeProfile(prof);
- if (profSymbol) {
- targetEnv.addProfile(profSymbol.value());
- } else {
- llvm::errs() << "unknown TOSA profile name passed in: " << prof
- << ", supported profiles are `pro_int` and `pro_fp`\n";
- return signalPassFailure();
- }
- }
- }
-
- if (!extension.empty()) {
- for (std::string &ext : extension) {
- auto extSymbol = symbolizeExtension(ext);
- if (extSymbol) {
- targetEnv.addExtension(extSymbol.value());
- } else {
- llvm::errs() << "unknown TOSA extension name passed in: " << ext
- << ", supported extension are int16, int4, bf16, "
- << "fp8e4m3, fp8e5m2, fft, variable, controlflow, "
- << "doubleround, inexactround and dynamic\n";
- return signalPassFailure();
- }
- }
- }
- }
-
LogicalResult CheckVariable(Operation *op);
LogicalResult CheckVariableReadOrWrite(Operation *op);
bool isValidElementType(Type type, const bool allowUnsigned = false);
@@ -567,7 +505,6 @@ private:
SmallVector<
std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
constCheckers;
- TosaLevel tosaLevel;
DenseMap<StringAttr, mlir::Type> variablesMap;
TosaProfileCompliance profileComp;
tosa::TargetEnv targetEnv;
@@ -576,13 +513,13 @@ private:
template <>
LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
auto *op = tosaOp.getOperation();
- if (failed(
- levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK)))
+ if (failed(levelCheckRank(op, tosaOp.getInput(), "operand",
+ targetEnv.getLevel().MAX_RANK)))
return failure();
// rank(output) = rank(input) - 1
if (failed(levelCheckRank(op, tosaOp.getOutput(), "result",
- tosaLevel.MAX_RANK - 1)))
+ targetEnv.getLevel().MAX_RANK - 1)))
return failure();
return success();
@@ -594,7 +531,7 @@ LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
// Only the condition input has rank limitation.
if (failed(levelCheckRank(op, tosaOp.getCondition(), "operand",
- tosaLevel.MAX_RANK)))
+ targetEnv.getLevel().MAX_RANK)))
return failure();
return success();
@@ -605,7 +542,7 @@ LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
auto *op = tosaOp.getOperation();
auto variableType = getVariableType(tosaOp);
if (failed(levelCheckRank(op, variableType, "variable type",
- tosaLevel.MAX_RANK)))
+ targetEnv.getLevel().MAX_RANK)))
return failure();
return success();
@@ -762,7 +699,8 @@ LogicalResult TosaValidation::levelCheckSize(Operation *op,
// defined in 1.7. Levels.
// For each tensor, the number of tensor elements multiplied by the
// element size in bytes must be representable as a tensor_size_t.
- const int64_t max_size = (INT64_C(1) << tosaLevel.MAX_LOG2_SIZE) - 1;
+ const int64_t max_size =
+ (INT64_C(1) << targetEnv.getLevel().MAX_LOG2_SIZE) - 1;
if (size > max_size)
return op->emitOpError()
<< "failed level check: " << operandOrResult
@@ -772,7 +710,7 @@ LogicalResult TosaValidation::levelCheckSize(Operation *op,
}
LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
- if (tosaLevel == TOSA_LEVEL_NONE) {
+ if (targetEnv.getLevel() == TOSA_LEVEL_NONE) {
// no need to do level checks
return success();
}
@@ -1282,12 +1220,12 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
}
void TosaValidation::runOnOperation() {
- configLevelAndProfile();
-
TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>();
if (!tosaDialect)
return;
+ targetEnv = tosa::TargetEnv(lookupTargetEnvOrDefault(getOperation()));
+
getOperation().walk([&](Operation *op) {
if (op->getDialect() != tosaDialect)
return;