aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp18
-rw-r--r--mlir/lib/Dialect/Tosa/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp42
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp87
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp94
6 files changed, 158 insertions, 85 deletions
diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
index 8295492..04e8836 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
@@ -310,26 +310,30 @@ LogicalResult BlockPrefetch2dOp::verify() {
template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
OpType, BlockLoadOp, BlockStoreOp>::value>>
LogicalResult verify1DBlockArg(OpType op) {
- VectorType vTy;
+ Type srcOrDstTy;
if constexpr (std::is_same_v<OpType, BlockLoadOp>)
- vTy = op.getResult().getType();
+ srcOrDstTy = op.getResult().getType();
else
- vTy = op.getVal().getType();
+ srcOrDstTy = op.getVal().getType();
+ VectorType vTy = dyn_cast<VectorType>(srcOrDstTy);
+ // scalar case is always valid
+ if (!vTy)
+ return success();
int elemTySize = vTy.getElementType().getIntOrFloatBitWidth() / 8;
if (elemTySize == 1) {
- llvm::SmallSet<int, 5> validSizes{1, 2, 4, 8, 16};
+ llvm::SmallSet<int, 4> validSizes{2, 4, 8, 16};
if (validSizes.contains(vTy.getNumElements()))
return success();
else
return op.emitOpError(
- "vector size must be 1, 2, 4, 8 or 16 for 8-bit element type");
+ "vector size must be 2, 4, 8 or 16 for 8-bit element type");
} else {
- llvm::SmallSet<int, 4> validSizes{1, 2, 4, 8};
+ llvm::SmallSet<int, 3> validSizes{2, 4, 8};
if (validSizes.contains(vTy.getNumElements()))
return success();
else
return op.emitOpError(
- "vector size must be 1, 2, 4 or 8 for element type > 8 bits");
+ "vector size must be 2, 4 or 8 for element type > 8 bits");
}
}
diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt
index c6a438d..a95906a 100644
--- a/mlir/lib/Dialect/Tosa/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRTosaDialect
IR/TosaOps.cpp
IR/TosaCanonicalizations.cpp
+ IR/TargetEnv.cpp
Utils/ConversionUtils.cpp
Utils/QuantUtils.cpp
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
new file mode 100644
index 0000000..5aad671
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -0,0 +1,42 @@
+//===-------------- TosaTarget.cpp - TOSA Target utilities ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
+
+namespace mlir {
+namespace tosa {
+
+TargetEnvAttr lookupTargetEnv(Operation *op) {
+ while (op) {
+ op = SymbolTable::getNearestSymbolTable(op);
+ if (!op)
+ break;
+
+ if (auto attr = op->getAttrOfType<TargetEnvAttr>(TargetEnvAttr::name))
+ return attr;
+
+ op = op->getParentOp();
+ }
+
+ return {};
+}
+
+TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) {
+ return TargetEnvAttr::get(context, Level::eightK,
+ {Profile::pro_int, Profile::pro_fp}, {});
+}
+
+TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) {
+ if (auto attr = lookupTargetEnv(op))
+ return attr;
+
+ return getDefaultTargetEnv(op->getContext());
+}
+
+} // namespace tosa
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 803993b..41b338d 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRTosaTransforms
+ TosaAttachTarget.cpp
TosaConvertIntegerTypeToSignless.cpp
TosaDecomposeTransposeConv.cpp
TosaDecomposeDepthwise.cpp
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
new file mode 100644
index 0000000..bcb880a
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
@@ -0,0 +1,87 @@
+//===- TosaAttachTarget.cpp
+//------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Attach target information to a TOSA module.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace tosa {
+
+#define GEN_PASS_DEF_TOSAATTACHTARGET
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+
+namespace {
+
+class TosaAttachTarget
+ : public tosa::impl::TosaAttachTargetBase<TosaAttachTarget> {
+ using Base::Base;
+
+public:
+ void runOnOperation() override {
+ llvm::SmallVector<Profile, 2> selectedProfiles;
+ if (!profiles.empty()) {
+ for (const std::string &prof : profiles) {
+ std::optional<Profile> profSymbol = symbolizeProfile(prof);
+ if (!profSymbol) {
+ llvm::SmallVector<Profile> allProfiles = ProfileAttr::getAllValues();
+ llvm::errs() << buildUnkownParameterErrorMessage(allProfiles,
+ "profile", prof);
+ return signalPassFailure();
+ }
+ selectedProfiles.push_back(profSymbol.value());
+ }
+ }
+
+ llvm::SmallVector<Extension, 10> selectedExtensions;
+ if (!extensions.empty()) {
+ for (const std::string &ext : extensions) {
+ std::optional<Extension> extSymbol = symbolizeExtension(ext);
+ if (!extSymbol) {
+ llvm::SmallVector<Extension> allExtensions =
+ ExtensionAttr::getAllValues();
+ llvm::errs() << buildUnkownParameterErrorMessage(allExtensions,
+ "extension", ext);
+ return signalPassFailure();
+ }
+ selectedExtensions.push_back(extSymbol.value());
+ }
+ }
+
+ ModuleOp mod = getOperation();
+ MLIRContext *ctx = &getContext();
+ const auto targetEnvAttr =
+ TargetEnvAttr::get(ctx, level, selectedProfiles, selectedExtensions);
+ mod->setAttr(TargetEnvAttr::name, targetEnvAttr);
+ }
+
+private:
+ template <typename T>
+ std::string buildUnkownParameterErrorMessage(llvm::SmallVector<T> &enumValues,
+ std::string enumName,
+ std::string unknownArgument) {
+ std::string message;
+ llvm::raw_string_ostream os(message);
+ os << "Unknown TOSA " << enumName << " name passed in '" << unknownArgument
+ << "', supported " << enumName << "s are: ";
+ llvm::interleaveComma(enumValues, os);
+ os << "\n";
+ return message;
+ }
+};
+
+} // namespace
+
+} // namespace tosa
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 4fc7ce8..82f2f7e 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -14,7 +14,6 @@
#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
#include "mlir/Dialect/Tosa/IR/TosaProfileCompliance.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
-#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
#include <string>
@@ -130,28 +129,6 @@ static LogicalResult checkConstantOperandNegate(Operation *op,
return success();
}
-struct TosaLevel {
- int32_t MAX_RANK = 0;
- int32_t MAX_KERNEL = 0;
- int32_t MAX_STRIDE = 0;
- int32_t MAX_SCALE = 0;
- int32_t MAX_LOG2_SIZE = 0;
- int32_t MAX_NESTING = 0;
- int32_t MAX_TENSOR_LIST_SIZE = 0;
-
- bool operator==(const TosaLevel &rhs) {
- return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
- MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
- MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
- MAX_NESTING == rhs.MAX_NESTING &&
- MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
- }
-};
-
-static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64};
-static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
- 63, 256, 256};
-
//===----------------------------------------------------------------------===//
// TOSA Validation Pass.
//===----------------------------------------------------------------------===//
@@ -162,12 +139,9 @@ public:
explicit TosaValidation(const TosaValidationOptions &options)
: TosaValidation() {
- this->profile = options.profile;
- this->extension = options.extension;
this->strictOpSpecAlignment = options.strictOpSpecAlignment;
this->allowInvalidOpDatatypeCombinations =
options.allowInvalidOpDatatypeCombinations;
- this->level = options.level;
}
void runOnOperation() final;
@@ -207,28 +181,28 @@ private:
LogicalResult levelCheckKernel(Operation *op, int32_t v,
const StringRef checkDesc) {
- if (v > tosaLevel.MAX_KERNEL)
+ if (v > targetEnv.getLevel().MAX_KERNEL)
return op->emitOpError() << "failed level check: " << checkDesc;
return success();
}
LogicalResult levelCheckStride(Operation *op, int32_t v,
const StringRef checkDesc) {
- if (v > tosaLevel.MAX_STRIDE)
+ if (v > targetEnv.getLevel().MAX_STRIDE)
return op->emitOpError() << "failed level check: " << checkDesc;
return success();
}
LogicalResult levelCheckScale(Operation *op, int32_t v,
const StringRef checkDesc) {
- if (v > tosaLevel.MAX_SCALE)
+ if (v > targetEnv.getLevel().MAX_SCALE)
return op->emitOpError() << "failed level check: " << checkDesc;
return success();
}
LogicalResult levelCheckListSize(Operation *op, int32_t v,
const StringRef checkDesc) {
- if (v > tosaLevel.MAX_TENSOR_LIST_SIZE)
+ if (v > targetEnv.getLevel().MAX_TENSOR_LIST_SIZE)
return op->emitOpError()
<< "failed level check for MAX_TENSOR_LIST_SIZE: " << checkDesc;
return success();
@@ -285,6 +259,7 @@ private:
template <typename T>
LogicalResult levelCheckRanks(T tosaOp) {
auto op = tosaOp.getOperation();
+ const TosaLevel tosaLevel = targetEnv.getLevel();
for (auto v : op->getOperands()) {
if (failed(levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK)))
return failure();
@@ -466,7 +441,7 @@ private:
int32_t maxNestedDepth = 0;
getMaxNestedDepth(op, maxNestedDepth);
- if (maxNestedDepth >= tosaLevel.MAX_NESTING) {
+ if (maxNestedDepth >= targetEnv.getLevel().MAX_NESTING) {
op->emitOpError() << "failed level check: " << maxNestedDepth
<< " >= MAX_NESTING";
return failure();
@@ -523,43 +498,6 @@ private:
return success();
}
- // configure profile and level values from pass options profileName and
- // levelName
- void configLevelAndProfile() {
- tosaLevel = TOSA_LEVEL_NONE;
- if (level == TosaLevelEnum::EightK) {
- tosaLevel = TOSA_LEVEL_EIGHTK;
- }
-
- if (!profile.empty()) {
- for (std::string &prof : profile) {
- auto profSymbol = symbolizeProfile(prof);
- if (profSymbol) {
- targetEnv.addProfile(profSymbol.value());
- } else {
- llvm::errs() << "unknown TOSA profile name passed in: " << prof
- << ", supported profiles are `pro_int` and `pro_fp`\n";
- return signalPassFailure();
- }
- }
- }
-
- if (!extension.empty()) {
- for (std::string &ext : extension) {
- auto extSymbol = symbolizeExtension(ext);
- if (extSymbol) {
- targetEnv.addExtension(extSymbol.value());
- } else {
- llvm::errs() << "unknown TOSA extension name passed in: " << ext
- << ", supported extension are int16, int4, bf16, "
- << "fp8e4m3, fp8e5m2, fft, variable, controlflow, "
- << "doubleround, inexactround and dynamic\n";
- return signalPassFailure();
- }
- }
- }
- }
-
LogicalResult CheckVariable(Operation *op);
LogicalResult CheckVariableReadOrWrite(Operation *op);
bool isValidElementType(Type type, const bool allowUnsigned = false);
@@ -567,7 +505,6 @@ private:
SmallVector<
std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
constCheckers;
- TosaLevel tosaLevel;
DenseMap<StringAttr, mlir::Type> variablesMap;
TosaProfileCompliance profileComp;
tosa::TargetEnv targetEnv;
@@ -576,13 +513,13 @@ private:
template <>
LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
auto *op = tosaOp.getOperation();
- if (failed(
- levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK)))
+ if (failed(levelCheckRank(op, tosaOp.getInput(), "operand",
+ targetEnv.getLevel().MAX_RANK)))
return failure();
// rank(output) = rank(input) - 1
if (failed(levelCheckRank(op, tosaOp.getOutput(), "result",
- tosaLevel.MAX_RANK - 1)))
+ targetEnv.getLevel().MAX_RANK - 1)))
return failure();
return success();
@@ -594,7 +531,7 @@ LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
// Only the condition input has rank limitation.
if (failed(levelCheckRank(op, tosaOp.getCondition(), "operand",
- tosaLevel.MAX_RANK)))
+ targetEnv.getLevel().MAX_RANK)))
return failure();
return success();
@@ -605,7 +542,7 @@ LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
auto *op = tosaOp.getOperation();
auto variableType = getVariableType(tosaOp);
if (failed(levelCheckRank(op, variableType, "variable type",
- tosaLevel.MAX_RANK)))
+ targetEnv.getLevel().MAX_RANK)))
return failure();
return success();
@@ -762,7 +699,8 @@ LogicalResult TosaValidation::levelCheckSize(Operation *op,
// defined in 1.7. Levels.
// For each tensor, the number of tensor elements multiplied by the
// element size in bytes must be representable as a tensor_size_t.
- const int64_t max_size = (INT64_C(1) << tosaLevel.MAX_LOG2_SIZE) - 1;
+ const int64_t max_size =
+ (INT64_C(1) << targetEnv.getLevel().MAX_LOG2_SIZE) - 1;
if (size > max_size)
return op->emitOpError()
<< "failed level check: " << operandOrResult
@@ -772,7 +710,7 @@ LogicalResult TosaValidation::levelCheckSize(Operation *op,
}
LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
- if (tosaLevel == TOSA_LEVEL_NONE) {
+ if (targetEnv.getLevel() == TOSA_LEVEL_NONE) {
// no need to do level checks
return success();
}
@@ -1282,12 +1220,12 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
}
void TosaValidation::runOnOperation() {
- configLevelAndProfile();
-
TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>();
if (!tosaDialect)
return;
+ targetEnv = tosa::TargetEnv(lookupTargetEnvOrDefault(getOperation()));
+
getOperation().walk([&](Operation *op) {
if (op->getDialect() != tosaDialect)
return;