From 46e41c8631bd6c1a6c91d6cc4a5e4f1671078ccd Mon Sep 17 00:00:00 2001 From: Will Dietz Date: Mon, 10 Jun 2024 19:12:34 -0500 Subject: [mlir] Sanitize identifiers with leading symbol. (#94795) Presently, if name starts with a symbol it's converted to hex which may cause the result to be invalid by starting with a digit. Address this and add a small test. Co-authored-by: Will Dietz --- mlir/lib/IR/AsmPrinter.cpp | 10 +++++++--- mlir/test/IR/print-attr-type-aliases.mlir | 3 +++ mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp | 1 + 3 files changed, 11 insertions(+), 3 deletions(-) (limited to 'mlir') diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 6a362af..2c43a6f 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -999,9 +999,13 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer, bool allowTrailingDigit = true) { assert(!name.empty() && "Shouldn't have an empty name here"); + auto validChar = [&](char ch) { + return llvm::isAlnum(ch) || allowedPunctChars.contains(ch); + }; + auto copyNameToBuffer = [&] { for (char ch : name) { - if (llvm::isAlnum(ch) || allowedPunctChars.contains(ch)) + if (validChar(ch)) buffer.push_back(ch); else if (ch == ' ') buffer.push_back('_'); @@ -1013,7 +1017,7 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer, // Check to see if this name is valid. If it starts with a digit, then it // could conflict with the autogenerated numeric ID's, so add an underscore // prefix to avoid problems. - if (isdigit(name[0])) { + if (isdigit(name[0]) || (!validChar(name[0]) && name[0] != ' ')) { buffer.push_back('_'); copyNameToBuffer(); return buffer; @@ -1029,7 +1033,7 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer, // Check to see that the name consists of only valid identifier characters. for (char ch : name) { - if (!llvm::isAlnum(ch) && !allowedPunctChars.contains(ch)) { + if (!validChar(ch)) { copyNameToBuffer(); return buffer; } diff --git a/mlir/test/IR/print-attr-type-aliases.mlir b/mlir/test/IR/print-attr-type-aliases.mlir index 162eacd..27c5a75 100644 --- a/mlir/test/IR/print-attr-type-aliases.mlir +++ b/mlir/test/IR/print-attr-type-aliases.mlir @@ -11,6 +11,9 @@ // CHECK-DAG: #_0_test_alias = "alias_test:prefixed_digit" "test.op"() {alias_test = "alias_test:prefixed_digit"} : () -> () +// CHECK-DAG: #_25test = "alias_test:prefixed_symbol" +"test.op"() {alias_test = "alias_test:prefixed_symbol"} : () -> () + // CHECK-DAG: #test_alias_conflict0_ = "alias_test:sanitize_conflict_a" // CHECK-DAG: #test_alias_conflict0_1 = "alias_test:sanitize_conflict_b" "test.op"() {alias_test = ["alias_test:sanitize_conflict_a", "alias_test:sanitize_conflict_b"]} : () -> () diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp index a3a8913..64add8c 100644 --- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp @@ -188,6 +188,7 @@ struct TestOpAsmInterface : public OpAsmDialectInterface { .Case("alias_test:dot_in_name", StringRef("test.alias")) .Case("alias_test:trailing_digit", StringRef("test_alias0")) .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) + .Case("alias_test:prefixed_symbol", StringRef("%test")) .Case("alias_test:sanitize_conflict_a", StringRef("test_alias_conflict0")) .Case("alias_test:sanitize_conflict_b", -- cgit v1.1