aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Tosa/IR/TosaOps.cpp')
-rw-r--r--mlir/lib/Dialect/Tosa/IR/TosaOps.cpp67
1 files changed, 15 insertions, 52 deletions
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 332f1a0..c51b5e9 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -905,56 +905,29 @@ static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
return shapeAdaptor.getNumElements() == 1 ? success() : failure();
}
-// Returns the first declaration point prior to this operation or failure if
-// not found.
-static FailureOr<tosa::VariableOp> findVariableDecl(Operation *op,
- StringRef symName) {
- ModuleOp module = op->getParentOfType<ModuleOp>();
- tosa::VariableOp varOp = nullptr;
-
- // TODO: Adopt SymbolTable trait to Varible ops.
- // Currently, the variable's definition point is searched via walk(),
- // starting from the top-level ModuleOp and stopping at the point of use. Once
- // TOSA control flow and variable extensions reach the complete state, may
- // leverage MLIR's Symbol Table functionality to look up symbol and enhance
- // the search to a TOSA specific graph traversal over the IR structure.
- module.walk([&](Operation *tempOp) {
- // Reach this op itself.
- if (tempOp == op) {
- return WalkResult::interrupt();
- }
-
- if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
- if (symName == tosaOp.getName()) {
- varOp = tosaOp;
- return WalkResult::interrupt();
- }
- }
-
- return WalkResult::advance();
- });
-
- if (varOp)
- return varOp;
-
- return failure();
-}
-
template <typename T>
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
- StringRef symName = op.getName();
- FailureOr<tosa::VariableOp> varOp = findVariableDecl(op, symName);
- if (failed(varOp))
+ Operation *symTableOp =
+ op->template getParentWithTrait<OpTrait::SymbolTable>();
+ if (!symTableOp)
+ // If the operation is not the scope of a symbol table, we cannot
+ // verify it against it's declaration.
+ return success();
+
+ SymbolTable symTable(symTableOp);
+ const auto varOp = symTable.lookup<tosa::VariableOp>(op.getName());
+
+ // Verify prior declaration
+ if (!varOp)
return op->emitOpError("'")
- << symName << "' has not been declared by 'tosa.variable'";
+ << op.getName() << "' has not been declared by 'tosa.variable'";
// Verify type and shape
- auto variableType = getVariableType(varOp.value());
+ auto variableType = getVariableType(varOp);
if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
"the input tensor")
.failed())
return failure();
-
return success();
}
@@ -1418,7 +1391,7 @@ static void buildVariableOp(OpBuilder &builder, OperationState &result,
ArrayRef<int64_t> shape = shapedType.getShape();
auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
- result.addAttribute("name", nameAttr);
+ result.addAttribute("sym_name", nameAttr);
result.addAttribute("var_shape", varShapeAttr);
result.addAttribute("type", elementTypeAttr);
result.addAttribute("initial_value", initialValue);
@@ -4160,16 +4133,6 @@ LogicalResult tosa::SelectOp::verify() {
return success();
}
-LogicalResult tosa::VariableOp::verify() {
- StringRef symName = getName();
- FailureOr<tosa::VariableOp> varOp = findVariableDecl(*this, symName);
- if (succeeded(varOp))
- return emitOpError("illegal to have multiple declaration of '")
- << symName << "'";
-
- return success();
-}
-
LogicalResult tosa::VariableReadOp::verify() {
if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
.failed())