//===- SymbolTableTest.cpp - SymbolTable unit tests -----------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/IR/SymbolTable.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Verifier.h" #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Parser/Parser.h" #include "gtest/gtest.h" using namespace mlir; namespace test { void registerTestDialect(DialectRegistry &); } // namespace test class ReplaceAllSymbolUsesTest : public ::testing::Test { protected: using ReplaceFnType = llvm::function_ref; void SetUp() override { ::test::registerTestDialect(registry); context = std::make_unique(registry); } void testReplaceAllSymbolUses(ReplaceFnType replaceFn) { // Set up IR and find func ops. OwningOpRef module = parseSourceString(kInput, context.get()); SymbolTable symbolTable(module.get()); auto opIterator = module->getBody(0)->getOperations().begin(); auto fooOp = cast(opIterator++); auto barOp = cast(opIterator++); ASSERT_EQ(fooOp.getNameAttr(), "foo"); ASSERT_EQ(barOp.getNameAttr(), "bar"); // Call test function that does symbol replacement. LogicalResult res = replaceFn(symbolTable, module.get(), fooOp, barOp); ASSERT_TRUE(succeeded(res)); ASSERT_TRUE(succeeded(verify(module.get()))); // Check that it got renamed. bool calleeFound = false; fooOp->walk([&](CallOpInterface callOp) { StringAttr callee = dyn_cast(callOp.getCallableForCallee()) .getLeafReference(); EXPECT_EQ(callee, "baz"); calleeFound = true; }); EXPECT_TRUE(calleeFound); } std::unique_ptr context; private: constexpr static llvm::StringLiteral kInput = R"MLIR( module { test.conversion_func_op private @foo() { "test.conversion_call_op"() { callee=@bar } : () -> () "test.return"() : () -> () } test.conversion_func_op private @bar() } )MLIR"; DialectRegistry registry; }; namespace { TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleOp) { // Symbol as `Operation *`, rename within module. testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, auto barOp) -> LogicalResult { return symbolTable.replaceAllSymbolUses( barOp, StringAttr::get(context.get(), "baz"), module); }); } TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleOp) { // Symbol as `StringAttr`, rename within module. testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, auto barOp) -> LogicalResult { return symbolTable.replaceAllSymbolUses( StringAttr::get(context.get(), "bar"), StringAttr::get(context.get(), "baz"), module); }); } TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleBody) { // Symbol as `Operation *`, rename within module body. testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, auto barOp) -> LogicalResult { return symbolTable.replaceAllSymbolUses( barOp, StringAttr::get(context.get(), "baz"), &module->getRegion(0)); }); } TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleBody) { // Symbol as `StringAttr`, rename within module body. testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, auto barOp) -> LogicalResult { return symbolTable.replaceAllSymbolUses( StringAttr::get(context.get(), "bar"), StringAttr::get(context.get(), "baz"), &module->getRegion(0)); }); } TEST_F(ReplaceAllSymbolUsesTest, OperationInFuncOp) { // Symbol as `Operation *`, rename within function. testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, auto barOp) -> LogicalResult { return symbolTable.replaceAllSymbolUses( barOp, StringAttr::get(context.get(), "baz"), fooOp); }); } TEST_F(ReplaceAllSymbolUsesTest, StringAttrInFuncOp) { // Symbol as `StringAttr`, rename within function. testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, auto barOp) -> LogicalResult { return symbolTable.replaceAllSymbolUses( StringAttr::get(context.get(), "bar"), StringAttr::get(context.get(), "baz"), fooOp); }); } TEST(SymbolOpInterface, Visibility) { DialectRegistry registry; ::test::registerTestDialect(registry); MLIRContext context(registry); constexpr static StringLiteral kInput = R"MLIR( "test.overridden_symbol_visibility"() {sym_name = "symbol_name"} : () -> () )MLIR"; OwningOpRef module = parseSourceString(kInput, &context); auto symOp = cast(module->getBody()->front()); ASSERT_TRUE(symOp.isPrivate()); ASSERT_FALSE(symOp.isPublic()); ASSERT_FALSE(symOp.isNested()); ASSERT_TRUE(symOp.canDiscardOnUseEmpty()); std::string diagStr; context.getDiagEngine().registerHandler( [&](Diagnostic &diag) { diagStr += diag.str(); }); std::string expectedDiag; symOp.setPublic(); expectedDiag += "'test.overridden_symbol_visibility' op cannot change " "visibility of symbol to public"; symOp.setNested(); expectedDiag += "'test.overridden_symbol_visibility' op cannot change " "visibility of symbol to nested"; symOp.setPrivate(); expectedDiag += "'test.overridden_symbol_visibility' op cannot change " "visibility of symbol to private"; ASSERT_EQ(diagStr, expectedDiag); } } // namespace