diff options
Diffstat (limited to 'mlir/lib/Dialect/Tosa/IR/TosaOps.cpp')
-rw-r--r-- | mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 67 |
1 files changed, 15 insertions, 52 deletions
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 332f1a0..c51b5e9 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -905,56 +905,29 @@ static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) { return shapeAdaptor.getNumElements() == 1 ? success() : failure(); } -// Returns the first declaration point prior to this operation or failure if -// not found. -static FailureOr<tosa::VariableOp> findVariableDecl(Operation *op, - StringRef symName) { - ModuleOp module = op->getParentOfType<ModuleOp>(); - tosa::VariableOp varOp = nullptr; - - // TODO: Adopt SymbolTable trait to Varible ops. - // Currently, the variable's definition point is searched via walk(), - // starting from the top-level ModuleOp and stopping at the point of use. Once - // TOSA control flow and variable extensions reach the complete state, may - // leverage MLIR's Symbol Table functionality to look up symbol and enhance - // the search to a TOSA specific graph traversal over the IR structure. - module.walk([&](Operation *tempOp) { - // Reach this op itself. - if (tempOp == op) { - return WalkResult::interrupt(); - } - - if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) { - if (symName == tosaOp.getName()) { - varOp = tosaOp; - return WalkResult::interrupt(); - } - } - - return WalkResult::advance(); - }); - - if (varOp) - return varOp; - - return failure(); -} - template <typename T> static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) { - StringRef symName = op.getName(); - FailureOr<tosa::VariableOp> varOp = findVariableDecl(op, symName); - if (failed(varOp)) + Operation *symTableOp = + op->template getParentWithTrait<OpTrait::SymbolTable>(); + if (!symTableOp) + // If the operation is not the scope of a symbol table, we cannot + // verify it against it's declaration. + return success(); + + SymbolTable symTable(symTableOp); + const auto varOp = symTable.lookup<tosa::VariableOp>(op.getName()); + + // Verify prior declaration + if (!varOp) return op->emitOpError("'") - << symName << "' has not been declared by 'tosa.variable'"; + << op.getName() << "' has not been declared by 'tosa.variable'"; // Verify type and shape - auto variableType = getVariableType(varOp.value()); + auto variableType = getVariableType(varOp); if (errorIfTypeOrShapeMismatch(op, type, name, variableType, "the input tensor") .failed()) return failure(); - return success(); } @@ -1418,7 +1391,7 @@ static void buildVariableOp(OpBuilder &builder, OperationState &result, ArrayRef<int64_t> shape = shapedType.getShape(); auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape)); - result.addAttribute("name", nameAttr); + result.addAttribute("sym_name", nameAttr); result.addAttribute("var_shape", varShapeAttr); result.addAttribute("type", elementTypeAttr); result.addAttribute("initial_value", initialValue); @@ -4160,16 +4133,6 @@ LogicalResult tosa::SelectOp::verify() { return success(); } -LogicalResult tosa::VariableOp::verify() { - StringRef symName = getName(); - FailureOr<tosa::VariableOp> varOp = findVariableDecl(*this, symName); - if (succeeded(varOp)) - return emitOpError("illegal to have multiple declaration of '") - << symName << "'"; - - return success(); -} - LogicalResult tosa::VariableReadOp::verify() { if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'") .failed()) |