aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--llvm/include/llvm/Support/SourceMgr.h22
-rw-r--r--llvm/lib/Support/SourceMgr.cpp16
-rw-r--r--mlir/include/mlir/IR/OpBase.td6
-rw-r--r--mlir/include/mlir/TableGen/Constraint.h5
-rw-r--r--mlir/include/mlir/Tools/PDLL/AST/Context.h13
-rw-r--r--mlir/include/mlir/Tools/PDLL/ODS/Constraint.h98
-rw-r--r--mlir/include/mlir/Tools/PDLL/ODS/Context.h78
-rw-r--r--mlir/include/mlir/Tools/PDLL/ODS/Dialect.h64
-rw-r--r--mlir/include/mlir/Tools/PDLL/ODS/Operation.h189
-rw-r--r--mlir/lib/TableGen/Constraint.cpp23
-rw-r--r--mlir/lib/Tools/PDLL/AST/CMakeLists.txt1
-rw-r--r--mlir/lib/Tools/PDLL/AST/Context.cpp2
-rw-r--r--mlir/lib/Tools/PDLL/CMakeLists.txt1
-rw-r--r--mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp31
-rw-r--r--mlir/lib/Tools/PDLL/ODS/CMakeLists.txt8
-rw-r--r--mlir/lib/Tools/PDLL/ODS/Context.cpp174
-rw-r--r--mlir/lib/Tools/PDLL/ODS/Dialect.cpp39
-rw-r--r--mlir/lib/Tools/PDLL/ODS/Operation.cpp26
-rw-r--r--mlir/lib/Tools/PDLL/Parser/CMakeLists.txt6
-rw-r--r--mlir/lib/Tools/PDLL/Parser/Parser.cpp344
-rw-r--r--mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll20
-rw-r--r--mlir/test/mlir-pdll/CodeGen/MLIR/include/ops.td9
-rw-r--r--mlir/test/mlir-pdll/Parser/directive-failure.pdll2
-rw-r--r--mlir/test/mlir-pdll/Parser/expr-failure.pdll22
-rw-r--r--mlir/test/mlir-pdll/Parser/expr.pdll21
-rw-r--r--mlir/test/mlir-pdll/Parser/include/interfaces.td5
-rw-r--r--mlir/test/mlir-pdll/Parser/include/ops.td26
-rw-r--r--mlir/test/mlir-pdll/Parser/include_td.pdll52
-rw-r--r--mlir/test/mlir-pdll/Parser/stmt-failure.pdll24
-rw-r--r--mlir/tools/mlir-pdll/mlir-pdll.cpp24
30 files changed, 1312 insertions, 39 deletions
diff --git a/llvm/include/llvm/Support/SourceMgr.h b/llvm/include/llvm/Support/SourceMgr.h
index 28716b4..fc6d651 100644
--- a/llvm/include/llvm/Support/SourceMgr.h
+++ b/llvm/include/llvm/Support/SourceMgr.h
@@ -100,6 +100,9 @@ public:
SourceMgr &operator=(SourceMgr &&) = default;
~SourceMgr() = default;
+ /// Return the include directories of this source manager.
+ ArrayRef<std::string> getIncludeDirs() const { return IncludeDirectories; }
+
void setIncludeDirs(const std::vector<std::string> &Dirs) {
IncludeDirectories = Dirs;
}
@@ -147,6 +150,14 @@ public:
return Buffers.size();
}
+ /// Takes the source buffers from the given source manager and append them to
+ /// the current manager.
+ void takeSourceBuffersFrom(SourceMgr &SrcMgr) {
+ std::move(SrcMgr.Buffers.begin(), SrcMgr.Buffers.end(),
+ std::back_inserter(Buffers));
+ SrcMgr.Buffers.clear();
+ }
+
/// Search for a file with the specified name in the current directory or in
/// one of the IncludeDirs.
///
@@ -156,6 +167,17 @@ public:
unsigned AddIncludeFile(const std::string &Filename, SMLoc IncludeLoc,
std::string &IncludedFile);
+ /// Search for a file with the specified name in the current directory or in
+ /// one of the IncludeDirs, and try to open it **without** adding to the
+ /// SourceMgr. If the opened file is intended to be added to the source
+ /// manager, prefer `AddIncludeFile` instead.
+ ///
+ /// If no file is found, this returns an Error, otherwise it returns the
+ /// buffer of the stacked file. The full path to the included file can be
+ /// found in \p IncludedFile.
+ ErrorOr<std::unique_ptr<MemoryBuffer>>
+ OpenIncludeFile(const std::string &Filename, std::string &IncludedFile);
+
/// Return the ID of the buffer containing the specified location.
///
/// 0 is returned if the buffer is not found.
diff --git a/llvm/lib/Support/SourceMgr.cpp b/llvm/lib/Support/SourceMgr.cpp
index 2eb2989b..42982b4 100644
--- a/llvm/lib/Support/SourceMgr.cpp
+++ b/llvm/lib/Support/SourceMgr.cpp
@@ -40,6 +40,17 @@ static const size_t TabStop = 8;
unsigned SourceMgr::AddIncludeFile(const std::string &Filename,
SMLoc IncludeLoc,
std::string &IncludedFile) {
+ ErrorOr<std::unique_ptr<MemoryBuffer>> NewBufOrErr =
+ OpenIncludeFile(Filename, IncludedFile);
+ if (!NewBufOrErr)
+ return 0;
+
+ return AddNewSourceBuffer(std::move(*NewBufOrErr), IncludeLoc);
+}
+
+ErrorOr<std::unique_ptr<MemoryBuffer>>
+SourceMgr::OpenIncludeFile(const std::string &Filename,
+ std::string &IncludedFile) {
IncludedFile = Filename;
ErrorOr<std::unique_ptr<MemoryBuffer>> NewBufOrErr =
MemoryBuffer::getFile(IncludedFile);
@@ -52,10 +63,7 @@ unsigned SourceMgr::AddIncludeFile(const std::string &Filename,
NewBufOrErr = MemoryBuffer::getFile(IncludedFile);
}
- if (!NewBufOrErr)
- return 0;
-
- return AddNewSourceBuffer(std::move(*NewBufOrErr), IncludeLoc);
+ return NewBufOrErr;
}
unsigned SourceMgr::FindBufferContainingLoc(SMLoc Loc) const {
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 11bc1a6..fa40ca7 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -363,7 +363,8 @@ class DialectType<Dialect d, Pred condition, string descr = "",
// A variadic type constraint. It expands to zero or more of the base type. This
// class is used for supporting variadic operands/results.
-class Variadic<Type type> : TypeConstraint<type.predicate, type.summary> {
+class Variadic<Type type> : TypeConstraint<type.predicate, type.summary,
+ type.cppClassName> {
Type baseType = type;
}
@@ -379,7 +380,8 @@ class VariadicOfVariadic<Type type, string variadicSegmentAttrName>
// An optional type constraint. It expands to either zero or one of the base
// type. This class is used for supporting optional operands/results.
-class Optional<Type type> : TypeConstraint<type.predicate, type.summary> {
+class Optional<Type type> : TypeConstraint<type.predicate, type.summary,
+ type.cppClassName> {
Type baseType = type;
}
diff --git a/mlir/include/mlir/TableGen/Constraint.h b/mlir/include/mlir/TableGen/Constraint.h
index 4e099aa..b24b9b7 100644
--- a/mlir/include/mlir/TableGen/Constraint.h
+++ b/mlir/include/mlir/TableGen/Constraint.h
@@ -54,6 +54,11 @@ public:
// description is not provided, returns the TableGen def name.
StringRef getSummary() const;
+ /// Returns the name of the TablGen def of this constraint. In some cases
+ /// where the current def is anonymous, the name of the base def is used (e.g.
+ /// `Optional<>`/`Variadic<>` type constraints).
+ StringRef getDefName() const;
+
Kind getKind() const { return kind; }
protected:
diff --git a/mlir/include/mlir/Tools/PDLL/AST/Context.h b/mlir/include/mlir/Tools/PDLL/AST/Context.h
index 9781589..f9a9424 100644
--- a/mlir/include/mlir/Tools/PDLL/AST/Context.h
+++ b/mlir/include/mlir/Tools/PDLL/AST/Context.h
@@ -14,13 +14,17 @@
namespace mlir {
namespace pdll {
+namespace ods {
+class Context;
+} // namespace ods
+
namespace ast {
/// This class represents the main context of the PDLL AST. It handles
/// allocating all of the AST constructs, and manages all state necessary for
/// the AST.
class Context {
public:
- Context();
+ explicit Context(ods::Context &odsContext);
Context(const Context &) = delete;
Context &operator=(const Context &) = delete;
@@ -30,6 +34,10 @@ public:
/// Return the storage uniquer used for AST types.
StorageUniquer &getTypeUniquer() { return typeUniquer; }
+ /// Return the ODS context used by the AST.
+ ods::Context &getODSContext() { return odsContext; }
+ const ods::Context &getODSContext() const { return odsContext; }
+
/// Return the diagnostic engine of this context.
DiagnosticEngine &getDiagEngine() { return diagEngine; }
@@ -37,6 +45,9 @@ private:
/// The diagnostic engine of this AST context.
DiagnosticEngine diagEngine;
+ /// The ODS context used by the AST.
+ ods::Context &odsContext;
+
/// The allocator used for AST nodes, and other entities allocated within the
/// context.
llvm::BumpPtrAllocator allocator;
diff --git a/mlir/include/mlir/Tools/PDLL/ODS/Constraint.h b/mlir/include/mlir/Tools/PDLL/ODS/Constraint.h
new file mode 100644
index 0000000..2703309
--- /dev/null
+++ b/mlir/include/mlir/Tools/PDLL/ODS/Constraint.h
@@ -0,0 +1,98 @@
+//===- Constraint.h - MLIR PDLL ODS Constraints -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a PDLL description of ODS constraints. These are used to
+// support the import of constraints defined outside of PDLL.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_PDLL_ODS_CONSTRAINT_H_
+#define MLIR_TOOLS_PDLL_ODS_CONSTRAINT_H_
+
+#include <string>
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringMap.h"
+
+namespace mlir {
+namespace pdll {
+namespace ods {
+
+//===----------------------------------------------------------------------===//
+// Constraint
+//===----------------------------------------------------------------------===//
+
+/// This class represents a generic ODS constraint.
+class Constraint {
+public:
+ /// Return the name of this constraint.
+ StringRef getName() const { return name; }
+
+ /// Return the summary of this constraint.
+ StringRef getSummary() const { return summary; }
+
+protected:
+ Constraint(StringRef name, StringRef summary)
+ : name(name.str()), summary(summary.str()) {}
+ Constraint(const Constraint &) = delete;
+
+private:
+ /// The name of the constraint.
+ std::string name;
+ /// A summary of the constraint.
+ std::string summary;
+};
+
+//===----------------------------------------------------------------------===//
+// AttributeConstraint
+//===----------------------------------------------------------------------===//
+
+/// This class represents a generic ODS Attribute constraint.
+class AttributeConstraint : public Constraint {
+public:
+ /// Return the name of the underlying c++ class of this constraint.
+ StringRef getCppClass() const { return cppClassName; }
+
+private:
+ AttributeConstraint(StringRef name, StringRef summary, StringRef cppClassName)
+ : Constraint(name, summary), cppClassName(cppClassName.str()) {}
+
+ /// The c++ class of the constraint.
+ std::string cppClassName;
+
+ /// Allow access to the constructor.
+ friend class Context;
+};
+
+//===----------------------------------------------------------------------===//
+// TypeConstraint
+//===----------------------------------------------------------------------===//
+
+/// This class represents a generic ODS Type constraint.
+class TypeConstraint : public Constraint {
+public:
+ /// Return the name of the underlying c++ class of this constraint.
+ StringRef getCppClass() const { return cppClassName; }
+
+private:
+ TypeConstraint(StringRef name, StringRef summary, StringRef cppClassName)
+ : Constraint(name, summary), cppClassName(cppClassName.str()) {}
+
+ /// The c++ class of the constraint.
+ std::string cppClassName;
+
+ /// Allow access to the constructor.
+ friend class Context;
+};
+
+} // namespace ods
+} // namespace pdll
+} // namespace mlir
+
+#endif // MLIR_TOOLS_PDLL_ODS_CONSTRAINT_H_
diff --git a/mlir/include/mlir/Tools/PDLL/ODS/Context.h b/mlir/include/mlir/Tools/PDLL/ODS/Context.h
new file mode 100644
index 0000000..d0955ab
--- /dev/null
+++ b/mlir/include/mlir/Tools/PDLL/ODS/Context.h
@@ -0,0 +1,78 @@
+//===- Context.h - MLIR PDLL ODS Context ------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_PDLL_ODS_CONTEXT_H_
+#define MLIR_TOOLS_PDLL_ODS_CONTEXT_H_
+
+#include <string>
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringMap.h"
+
+namespace llvm {
+class SMLoc;
+} // namespace llvm
+
+namespace mlir {
+namespace pdll {
+namespace ods {
+class AttributeConstraint;
+class Dialect;
+class Operation;
+class TypeConstraint;
+
+/// This class contains all of the registered ODS operation classes.
+class Context {
+public:
+ Context();
+ ~Context();
+
+ /// Insert a new attribute constraint with the context. Returns the inserted
+ /// constraint, or a previously inserted constraint with the same name.
+ const AttributeConstraint &insertAttributeConstraint(StringRef name,
+ StringRef summary,
+ StringRef cppClass);
+
+ /// Insert a new type constraint with the context. Returns the inserted
+ /// constraint, or a previously inserted constraint with the same name.
+ const TypeConstraint &insertTypeConstraint(StringRef name, StringRef summary,
+ StringRef cppClass);
+
+ /// Insert a new dialect with the context. Returns the inserted dialect, or a
+ /// previously inserted dialect with the same name.
+ Dialect &insertDialect(StringRef name);
+
+ /// Lookup a dialect registered with the given name, or null if no dialect
+ /// with that name was inserted.
+ const Dialect *lookupDialect(StringRef name) const;
+
+ /// Insert a new operation with the context. Returns the inserted operation,
+ /// and a boolean indicating if the operation newly inserted (false if the
+ /// operation already existed).
+ std::pair<Operation *, bool>
+ insertOperation(StringRef name, StringRef summary, StringRef desc, SMLoc loc);
+
+ /// Lookup an operation registered with the given name, or null if no
+ /// operation with that name is registered.
+ const Operation *lookupOperation(StringRef name) const;
+
+ /// Print the contents of this context to the provided stream.
+ void print(raw_ostream &os) const;
+
+private:
+ llvm::StringMap<std::unique_ptr<AttributeConstraint>> attributeConstraints;
+ llvm::StringMap<std::unique_ptr<Dialect>> dialects;
+ llvm::StringMap<std::unique_ptr<TypeConstraint>> typeConstraints;
+};
+} // namespace ods
+} // namespace pdll
+} // namespace mlir
+
+#endif // MLIR_PDL_pdll_ODS_CONTEXT_H_
diff --git a/mlir/include/mlir/Tools/PDLL/ODS/Dialect.h b/mlir/include/mlir/Tools/PDLL/ODS/Dialect.h
new file mode 100644
index 0000000..f75d497
--- /dev/null
+++ b/mlir/include/mlir/Tools/PDLL/ODS/Dialect.h
@@ -0,0 +1,64 @@
+//===- Dialect.h - PDLL ODS Dialect -----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_PDLL_ODS_DIALECT_H_
+#define MLIR_TOOLS_PDLL_ODS_DIALECT_H_
+
+#include <string>
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringMap.h"
+
+namespace mlir {
+namespace pdll {
+namespace ods {
+class Operation;
+
+/// This class represents an ODS dialect, and contains information on the
+/// constructs held within the dialect.
+class Dialect {
+public:
+ ~Dialect();
+
+ /// Return the name of this dialect.
+ StringRef getName() const { return name; }
+
+ /// Insert a new operation with the dialect. Returns the inserted operation,
+ /// and a boolean indicating if the operation newly inserted (false if the
+ /// operation already existed).
+ std::pair<Operation *, bool>
+ insertOperation(StringRef name, StringRef summary, StringRef desc, SMLoc loc);
+
+ /// Lookup an operation registered with the given name, or null if no
+ /// operation with that name is registered.
+ Operation *lookupOperation(StringRef name) const;
+
+ /// Return a map of all of the operations registered to this dialect.
+ const llvm::StringMap<std::unique_ptr<Operation>> &getOperations() const {
+ return operations;
+ }
+
+private:
+ explicit Dialect(StringRef name);
+
+ /// The name of the dialect.
+ std::string name;
+
+ /// The operations defined by the dialect.
+ llvm::StringMap<std::unique_ptr<Operation>> operations;
+
+ /// Allow access to the constructor.
+ friend class Context;
+};
+} // namespace ods
+} // namespace pdll
+} // namespace mlir
+
+#endif // MLIR_TOOLS_PDLL_ODS_DIALECT_H_
diff --git a/mlir/include/mlir/Tools/PDLL/ODS/Operation.h b/mlir/include/mlir/Tools/PDLL/ODS/Operation.h
new file mode 100644
index 0000000..c5b86e1
--- /dev/null
+++ b/mlir/include/mlir/Tools/PDLL/ODS/Operation.h
@@ -0,0 +1,189 @@
+//===- Operation.h - MLIR PDLL ODS Operation --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_PDLL_ODS_OPERATION_H_
+#define MLIR_TOOLS_PDLL_ODS_OPERATION_H_
+
+#include <string>
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/SMLoc.h"
+
+namespace mlir {
+namespace pdll {
+namespace ods {
+class AttributeConstraint;
+class TypeConstraint;
+
+//===----------------------------------------------------------------------===//
+// VariableLengthKind
+//===----------------------------------------------------------------------===//
+
+enum VariableLengthKind { Single, Optional, Variadic };
+
+//===----------------------------------------------------------------------===//
+// Attribute
+//===----------------------------------------------------------------------===//
+
+/// This class provides an ODS representation of a specific operation attribute.
+/// This includes the name, optionality, and more.
+class Attribute {
+public:
+ /// Return the name of this operand.
+ StringRef getName() const { return name; }
+
+ /// Return true if this attribute is optional.
+ bool isOptional() const { return optional; }
+
+ /// Return the constraint of this attribute.
+ const AttributeConstraint &getConstraint() const { return constraint; }
+
+private:
+ Attribute(StringRef name, bool optional,
+ const AttributeConstraint &constraint)
+ : name(name.str()), optional(optional), constraint(constraint) {}
+
+ /// The ODS name of the attribute.
+ std::string name;
+
+ /// A flag indicating if the attribute is optional.
+ bool optional;
+
+ /// The ODS constraint of this attribute.
+ const AttributeConstraint &constraint;
+
+ /// Allow access to the private constructor.
+ friend class Operation;
+};
+
+//===----------------------------------------------------------------------===//
+// OperandOrResult
+//===----------------------------------------------------------------------===//
+
+/// This class provides an ODS representation of a specific operation operand or
+/// result. This includes the name, variable length flags, and more.
+class OperandOrResult {
+public:
+ /// Return the name of this value.
+ StringRef getName() const { return name; }
+
+ /// Returns true if this value is variadic (Note this is false if the value is
+ /// Optional).
+ bool isVariadic() const {
+ return variableLengthKind == VariableLengthKind::Variadic;
+ }
+
+ /// Returns the variable length kind of this value.
+ VariableLengthKind getVariableLengthKind() const {
+ return variableLengthKind;
+ }
+
+ /// Return the constraint of this value.
+ const TypeConstraint &getConstraint() const { return constraint; }
+
+private:
+ OperandOrResult(StringRef name, VariableLengthKind variableLengthKind,
+ const TypeConstraint &constraint)
+ : name(name.str()), variableLengthKind(variableLengthKind),
+ constraint(constraint) {}
+
+ /// The ODS name of this value.
+ std::string name;
+
+ /// The variable length kind of this value.
+ VariableLengthKind variableLengthKind;
+
+ /// The ODS constraint of this value.
+ const TypeConstraint &constraint;
+
+ /// Allow access to the private constructor.
+ friend class Operation;
+};
+
+//===----------------------------------------------------------------------===//
+// Operation
+//===----------------------------------------------------------------------===//
+
+/// This class provides an ODS representation of a specific operation. This
+/// includes all of the information necessary for use by the PDL frontend for
+/// generating code for a pattern rewrite.
+class Operation {
+public:
+ /// Return the source location of this operation.
+ SMRange getLoc() const { return location; }
+
+ /// Append an attribute to this operation.
+ void appendAttribute(StringRef name, bool optional,
+ const AttributeConstraint &constraint) {
+ attributes.emplace_back(Attribute(name, optional, constraint));
+ }
+
+ /// Append an operand to this operation.
+ void appendOperand(StringRef name, VariableLengthKind variableLengthKind,
+ const TypeConstraint &constraint) {
+ operands.emplace_back(
+ OperandOrResult(name, variableLengthKind, constraint));
+ }
+
+ /// Append a result to this operation.
+ void appendResult(StringRef name, VariableLengthKind variableLengthKind,
+ const TypeConstraint &constraint) {
+ results.emplace_back(OperandOrResult(name, variableLengthKind, constraint));
+ }
+
+ /// Returns the name of the operation.
+ StringRef getName() const { return name; }
+
+ /// Returns the summary of the operation.
+ StringRef getSummary() const { return summary; }
+
+ /// Returns the description of the operation.
+ StringRef getDescription() const { return description; }
+
+ /// Returns the attributes of this operation.
+ ArrayRef<Attribute> getAttributes() const { return attributes; }
+
+ /// Returns the operands of this operation.
+ ArrayRef<OperandOrResult> getOperands() const { return operands; }
+
+ /// Returns the results of this operation.
+ ArrayRef<OperandOrResult> getResults() const { return results; }
+
+private:
+ Operation(StringRef name, StringRef summary, StringRef desc, SMLoc loc);
+
+ /// The name of the operation.
+ std::string name;
+
+ /// The documentation of the operation.
+ std::string summary;
+ std::string description;
+
+ /// The source location of this operation.
+ SMRange location;
+
+ /// The operands of the operation.
+ SmallVector<OperandOrResult> operands;
+
+ /// The results of the operation.
+ SmallVector<OperandOrResult> results;
+
+ /// The attributes of the operation.
+ SmallVector<Attribute> attributes;
+
+ /// Allow access to the private constructor.
+ friend class Dialect;
+};
+} // namespace ods
+} // namespace pdll
+} // namespace mlir
+
+#endif // MLIR_TOOLS_PDLL_ODS_OPERATION_H_
diff --git a/mlir/lib/TableGen/Constraint.cpp b/mlir/lib/TableGen/Constraint.cpp
index 759e28f..249c22e 100644
--- a/mlir/lib/TableGen/Constraint.cpp
+++ b/mlir/lib/TableGen/Constraint.cpp
@@ -57,6 +57,29 @@ StringRef Constraint::getSummary() const {
return def->getName();
}
+StringRef Constraint::getDefName() const {
+ // Functor used to check a base def in the case where the current def is
+ // anonymous.
+ auto checkBaseDefFn = [&](StringRef baseName) {
+ if (const auto *init = dyn_cast<llvm::DefInit>(def->getValueInit(baseName)))
+ return Constraint(init->getDef(), kind).getDefName();
+ return def->getName();
+ };
+
+ switch (kind) {
+ case CK_Attr:
+ if (def->isAnonymous())
+ return checkBaseDefFn("baseAttr");
+ return def->getName();
+ case CK_Type:
+ if (def->isAnonymous())
+ return checkBaseDefFn("baseType");
+ return def->getName();
+ default:
+ return def->getName();
+ }
+}
+
AppliedConstraint::AppliedConstraint(Constraint &&constraint,
llvm::StringRef self,
std::vector<std::string> &&entities)
diff --git a/mlir/lib/Tools/PDLL/AST/CMakeLists.txt b/mlir/lib/Tools/PDLL/AST/CMakeLists.txt
index 3eb9c62..5e67ee0 100644
--- a/mlir/lib/Tools/PDLL/AST/CMakeLists.txt
+++ b/mlir/lib/Tools/PDLL/AST/CMakeLists.txt
@@ -6,5 +6,6 @@ add_mlir_library(MLIRPDLLAST
Types.cpp
LINK_LIBS PUBLIC
+ MLIRPDLLODS
MLIRSupport
)
diff --git a/mlir/lib/Tools/PDLL/AST/Context.cpp b/mlir/lib/Tools/PDLL/AST/Context.cpp
index 09ae0e6..6f2e4cd 100644
--- a/mlir/lib/Tools/PDLL/AST/Context.cpp
+++ b/mlir/lib/Tools/PDLL/AST/Context.cpp
@@ -12,7 +12,7 @@
using namespace mlir;
using namespace mlir::pdll::ast;
-Context::Context() {
+Context::Context(ods::Context &odsContext) : odsContext(odsContext) {
typeUniquer.registerSingletonStorageType<detail::AttributeTypeStorage>();
typeUniquer.registerSingletonStorageType<detail::ConstraintTypeStorage>();
typeUniquer.registerSingletonStorageType<detail::RewriteTypeStorage>();
diff --git a/mlir/lib/Tools/PDLL/CMakeLists.txt b/mlir/lib/Tools/PDLL/CMakeLists.txt
index ac83f5e..522429b 100644
--- a/mlir/lib/Tools/PDLL/CMakeLists.txt
+++ b/mlir/lib/Tools/PDLL/CMakeLists.txt
@@ -1,3 +1,4 @@
add_subdirectory(AST)
add_subdirectory(CodeGen)
+add_subdirectory(ODS)
add_subdirectory(Parser)
diff --git a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
index 81b719c..1f8466f 100644
--- a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
+++ b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
@@ -17,6 +17,8 @@
#include "mlir/Tools/PDLL/AST/Context.h"
#include "mlir/Tools/PDLL/AST/Nodes.h"
#include "mlir/Tools/PDLL/AST/Types.h"
+#include "mlir/Tools/PDLL/ODS/Context.h"
+#include "mlir/Tools/PDLL/ODS/Operation.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -33,7 +35,8 @@ class CodeGen {
public:
CodeGen(MLIRContext *mlirContext, const ast::Context &context,
const llvm::SourceMgr &sourceMgr)
- : builder(mlirContext), sourceMgr(sourceMgr) {
+ : builder(mlirContext), odsContext(context.getODSContext()),
+ sourceMgr(sourceMgr) {
// Make sure that the PDL dialect is loaded.
mlirContext->loadDialect<pdl::PDLDialect>();
}
@@ -117,6 +120,9 @@ private:
llvm::ScopedHashTable<const ast::VariableDecl *, SmallVector<Value>>;
VariableMapTy variables;
+ /// A reference to the ODS context.
+ const ods::Context &odsContext;
+
/// The source manager of the PDLL ast.
const llvm::SourceMgr &sourceMgr;
};
@@ -435,7 +441,28 @@ Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) {
builder.getI32IntegerAttr(0));
return builder.create<pdl::ResultsOp>(loc, mlirType, parentExprs[0]);
}
- llvm_unreachable("unhandled operation member access expression");
+
+ assert(opType.getName() && "expected valid operation name");
+ const ods::Operation *odsOp = odsContext.lookupOperation(*opType.getName());
+ assert(odsOp && "expected valid ODS operation information");
+
+ // Find the result with the member name or by index.
+ ArrayRef<ods::OperandOrResult> results = odsOp->getResults();
+ unsigned resultIndex = results.size();
+ if (llvm::isDigit(name[0])) {
+ name.getAsInteger(/*Radix=*/10, resultIndex);
+ } else {
+ auto findFn = [&](const ods::OperandOrResult &result) {
+ return result.getName() == name;
+ };
+ resultIndex = llvm::find_if(results, findFn) - results.begin();
+ }
+ assert(resultIndex < results.size() && "invalid result index");
+
+ // Generate the result access.
+ IntegerAttr index = builder.getI32IntegerAttr(resultIndex);
+ return builder.create<pdl::ResultsOp>(loc, genType(expr->getType()),
+ parentExprs[0], index);
}
// Handle tuple based member access.
diff --git a/mlir/lib/Tools/PDLL/ODS/CMakeLists.txt b/mlir/lib/Tools/PDLL/ODS/CMakeLists.txt
new file mode 100644
index 0000000..3abbaab
--- /dev/null
+++ b/mlir/lib/Tools/PDLL/ODS/CMakeLists.txt
@@ -0,0 +1,8 @@
+add_mlir_library(MLIRPDLLODS
+ Context.cpp
+ Dialect.cpp
+ Operation.cpp
+
+ LINK_LIBS PUBLIC
+ MLIRSupport
+ )
diff --git a/mlir/lib/Tools/PDLL/ODS/Context.cpp b/mlir/lib/Tools/PDLL/ODS/Context.cpp
new file mode 100644
index 0000000..7684da5
--- /dev/null
+++ b/mlir/lib/Tools/PDLL/ODS/Context.cpp
@@ -0,0 +1,174 @@
+//===- Context.cpp --------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Tools/PDLL/ODS/Context.h"
+#include "mlir/Tools/PDLL/ODS/Constraint.h"
+#include "mlir/Tools/PDLL/ODS/Dialect.h"
+#include "mlir/Tools/PDLL/ODS/Operation.h"
+#include "llvm/Support/ScopedPrinter.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::pdll::ods;
+
+//===----------------------------------------------------------------------===//
+// Context
+//===----------------------------------------------------------------------===//
+
+Context::Context() = default;
+Context::~Context() = default;
+
+const AttributeConstraint &
+Context::insertAttributeConstraint(StringRef name, StringRef summary,
+ StringRef cppClass) {
+ std::unique_ptr<AttributeConstraint> &constraint = attributeConstraints[name];
+ if (!constraint) {
+ constraint.reset(new AttributeConstraint(name, summary, cppClass));
+ } else {
+ assert(constraint->getCppClass() == cppClass &&
+ constraint->getSummary() == summary &&
+ "constraint with the same name was already registered with a "
+ "different class");
+ }
+ return *constraint;
+}
+
+const TypeConstraint &Context::insertTypeConstraint(StringRef name,
+ StringRef summary,
+ StringRef cppClass) {
+ std::unique_ptr<TypeConstraint> &constraint = typeConstraints[name];
+ if (!constraint)
+ constraint.reset(new TypeConstraint(name, summary, cppClass));
+ return *constraint;
+}
+
+Dialect &Context::insertDialect(StringRef name) {
+ std::unique_ptr<Dialect> &dialect = dialects[name];
+ if (!dialect)
+ dialect.reset(new Dialect(name));
+ return *dialect;
+}
+
+const Dialect *Context::lookupDialect(StringRef name) const {
+ auto it = dialects.find(name);
+ return it == dialects.end() ? nullptr : &*it->second;
+}
+
+std::pair<Operation *, bool> Context::insertOperation(StringRef name,
+ StringRef summary,
+ StringRef desc,
+ SMLoc loc) {
+ std::pair<StringRef, StringRef> dialectAndName = name.split('.');
+ return insertDialect(dialectAndName.first)
+ .insertOperation(name, summary, desc, loc);
+}
+
+const Operation *Context::lookupOperation(StringRef name) const {
+ std::pair<StringRef, StringRef> dialectAndName = name.split('.');
+ if (const Dialect *dialect = lookupDialect(dialectAndName.first))
+ return dialect->lookupOperation(name);
+ return nullptr;
+}
+
+template <typename T>
+SmallVector<T *> sortMapByName(const llvm::StringMap<std::unique_ptr<T>> &map) {
+ SmallVector<T *> storage;
+ for (auto &entry : map)
+ storage.push_back(entry.second.get());
+ llvm::sort(storage, [](const auto &lhs, const auto &rhs) {
+ return lhs->getName() < rhs->getName();
+ });
+ return storage;
+}
+
+void Context::print(raw_ostream &os) const {
+ auto printVariableLengthCst = [&](StringRef cst, VariableLengthKind kind) {
+ switch (kind) {
+ case VariableLengthKind::Optional:
+ os << "Optional<" << cst << ">";
+ break;
+ case VariableLengthKind::Single:
+ os << cst;
+ break;
+ case VariableLengthKind::Variadic:
+ os << "Variadic<" << cst << ">";
+ break;
+ }
+ };
+
+ llvm::ScopedPrinter printer(os);
+ llvm::DictScope odsScope(printer, "ODSContext");
+ for (const Dialect *dialect : sortMapByName(dialects)) {
+ printer.startLine() << "Dialect `" << dialect->getName() << "` {\n";
+ printer.indent();
+
+ for (const Operation *op : sortMapByName(dialect->getOperations())) {
+ printer.startLine() << "Operation `" << op->getName() << "` {\n";
+ printer.indent();
+
+ // Attributes.
+ ArrayRef<Attribute> attributes = op->getAttributes();
+ if (!attributes.empty()) {
+ printer.startLine() << "Attributes { ";
+ llvm::interleaveComma(attributes, os, [&](const Attribute &attr) {
+ os << attr.getName() << " : ";
+
+ auto kind = attr.isOptional() ? VariableLengthKind::Optional
+ : VariableLengthKind::Single;
+ printVariableLengthCst(attr.getConstraint().getName(), kind);
+ });
+ os << " }\n";
+ }
+
+ // Operands.
+ ArrayRef<OperandOrResult> operands = op->getOperands();
+ if (!operands.empty()) {
+ printer.startLine() << "Operands { ";
+ llvm::interleaveComma(
+ operands, os, [&](const OperandOrResult &operand) {
+ os << operand.getName() << " : ";
+ printVariableLengthCst(operand.getConstraint().getName(),
+ operand.getVariableLengthKind());
+ });
+ os << " }\n";
+ }
+
+ // Results.
+ ArrayRef<OperandOrResult> results = op->getResults();
+ if (!results.empty()) {
+ printer.startLine() << "Results { ";
+ llvm::interleaveComma(results, os, [&](const OperandOrResult &result) {
+ os << result.getName() << " : ";
+ printVariableLengthCst(result.getConstraint().getName(),
+ result.getVariableLengthKind());
+ });
+ os << " }\n";
+ }
+
+ printer.objectEnd();
+ }
+ printer.objectEnd();
+ }
+ for (const AttributeConstraint *cst : sortMapByName(attributeConstraints)) {
+ printer.startLine() << "AttributeConstraint `" << cst->getName() << "` {\n";
+ printer.indent();
+
+ printer.startLine() << "Summary: " << cst->getSummary() << "\n";
+ printer.startLine() << "CppClass: " << cst->getCppClass() << "\n";
+ printer.objectEnd();
+ }
+ for (const TypeConstraint *cst : sortMapByName(typeConstraints)) {
+ printer.startLine() << "TypeConstraint `" << cst->getName() << "` {\n";
+ printer.indent();
+
+ printer.startLine() << "Summary: " << cst->getSummary() << "\n";
+ printer.startLine() << "CppClass: " << cst->getCppClass() << "\n";
+ printer.objectEnd();
+ }
+ printer.objectEnd();
+}
diff --git a/mlir/lib/Tools/PDLL/ODS/Dialect.cpp b/mlir/lib/Tools/PDLL/ODS/Dialect.cpp
new file mode 100644
index 0000000..ce9c234
--- /dev/null
+++ b/mlir/lib/Tools/PDLL/ODS/Dialect.cpp
@@ -0,0 +1,39 @@
+//===- Dialect.cpp --------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Tools/PDLL/ODS/Dialect.h"
+#include "mlir/Tools/PDLL/ODS/Constraint.h"
+#include "mlir/Tools/PDLL/ODS/Operation.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::pdll::ods;
+
+//===----------------------------------------------------------------------===//
+// Dialect
+//===----------------------------------------------------------------------===//
+
+Dialect::Dialect(StringRef name) : name(name.str()) {}
+Dialect::~Dialect() = default;
+
+std::pair<Operation *, bool> Dialect::insertOperation(StringRef name,
+ StringRef summary,
+ StringRef desc,
+ llvm::SMLoc loc) {
+ std::unique_ptr<Operation> &operation = operations[name];
+ if (operation)
+ return std::make_pair(&*operation, /*wasInserted*/ false);
+
+ operation.reset(new Operation(name, summary, desc, loc));
+ return std::make_pair(&*operation, /*wasInserted*/ true);
+}
+
+Operation *Dialect::lookupOperation(StringRef name) const {
+ auto it = operations.find(name);
+ return it != operations.end() ? it->second.get() : nullptr;
+}
diff --git a/mlir/lib/Tools/PDLL/ODS/Operation.cpp b/mlir/lib/Tools/PDLL/ODS/Operation.cpp
new file mode 100644
index 0000000..121c6c8
--- /dev/null
+++ b/mlir/lib/Tools/PDLL/ODS/Operation.cpp
@@ -0,0 +1,26 @@
+//===- Operation.cpp ------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Tools/PDLL/ODS/Operation.h"
+#include "mlir/Support/IndentedOstream.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::pdll::ods;
+
+//===----------------------------------------------------------------------===//
+// Operation
+//===----------------------------------------------------------------------===//
+
+Operation::Operation(StringRef name, StringRef summary, StringRef desc,
+ llvm::SMLoc loc)
+ : name(name.str()), summary(summary.str()),
+ location(loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)) {
+ llvm::raw_string_ostream descOS(description);
+ raw_indented_ostream(descOS).printReindented(desc.rtrim(" \t"));
+}
diff --git a/mlir/lib/Tools/PDLL/Parser/CMakeLists.txt b/mlir/lib/Tools/PDLL/Parser/CMakeLists.txt
index fb933cd..5d466cf 100644
--- a/mlir/lib/Tools/PDLL/Parser/CMakeLists.txt
+++ b/mlir/lib/Tools/PDLL/Parser/CMakeLists.txt
@@ -1,3 +1,8 @@
+set(LLVM_LINK_COMPONENTS
+ Support
+ TableGen
+)
+
add_mlir_library(MLIRPDLLParser
Lexer.cpp
Parser.cpp
@@ -5,4 +10,5 @@ add_mlir_library(MLIRPDLLParser
LINK_LIBS PUBLIC
MLIRPDLLAST
MLIRSupport
+ MLIRTableGen
)
diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index da72837..3da7783 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -9,15 +9,26 @@
#include "mlir/Tools/PDLL/Parser/Parser.h"
#include "Lexer.h"
#include "mlir/Support/LogicalResult.h"
+#include "mlir/TableGen/Argument.h"
+#include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/Constraint.h"
+#include "mlir/TableGen/Format.h"
+#include "mlir/TableGen/Operator.h"
#include "mlir/Tools/PDLL/AST/Context.h"
#include "mlir/Tools/PDLL/AST/Diagnostic.h"
#include "mlir/Tools/PDLL/AST/Nodes.h"
#include "mlir/Tools/PDLL/AST/Types.h"
+#include "mlir/Tools/PDLL/ODS/Constraint.h"
+#include "mlir/Tools/PDLL/ODS/Context.h"
+#include "mlir/Tools/PDLL/ODS/Operation.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/ScopedPrinter.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Parser.h"
#include <string>
using namespace mlir;
@@ -36,7 +47,8 @@ public:
valueTy(ast::ValueType::get(ctx)),
valueRangeTy(ast::ValueRangeType::get(ctx)),
typeTy(ast::TypeType::get(ctx)),
- typeRangeTy(ast::TypeRangeType::get(ctx)) {}
+ typeRangeTy(ast::TypeRangeType::get(ctx)),
+ attrTy(ast::AttributeType::get(ctx)) {}
/// Try to parse a new module. Returns nullptr in the case of failure.
FailureOr<ast::Module *> parseModule();
@@ -78,7 +90,7 @@ private:
void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }
/// Parse the body of an AST module.
- LogicalResult parseModuleBody(SmallVector<ast::Decl *> &decls);
+ LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls);
/// Try to convert the given expression to `type`. Returns failure and emits
/// an error if a conversion is not viable. On failure, `noteAttachFn` is
@@ -92,11 +104,34 @@ private:
/// typed expression.
ast::Expr *convertOpToValue(const ast::Expr *opExpr);
+ /// Lookup ODS information for the given operation, returns nullptr if no
+ /// information is found.
+ const ods::Operation *lookupODSOperation(Optional<StringRef> opName) {
+ return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr;
+ }
+
//===--------------------------------------------------------------------===//
// Directives
- LogicalResult parseDirective(SmallVector<ast::Decl *> &decls);
- LogicalResult parseInclude(SmallVector<ast::Decl *> &decls);
+ LogicalResult parseDirective(SmallVectorImpl<ast::Decl *> &decls);
+ LogicalResult parseInclude(SmallVectorImpl<ast::Decl *> &decls);
+ LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc,
+ SmallVectorImpl<ast::Decl *> &decls);
+
+ /// Process the records of a parsed tablegen include file.
+ void processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
+ SmallVectorImpl<ast::Decl *> &decls);
+
+ /// Create a user defined native constraint for a constraint imported from
+ /// ODS.
+ template <typename ConstraintT>
+ ast::Decl *createODSNativePDLLConstraintDecl(StringRef name,
+ StringRef codeBlock, SMRange loc,
+ ast::Type type);
+ template <typename ConstraintT>
+ ast::Decl *
+ createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
+ SMRange loc, ast::Type type);
//===--------------------------------------------------------------------===//
// Decls
@@ -340,13 +375,16 @@ private:
MutableArrayRef<ast::Expr *> results);
LogicalResult
validateOperationOperands(SMRange loc, Optional<StringRef> name,
+ const ods::Operation *odsOp,
MutableArrayRef<ast::Expr *> operands);
LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name,
+ const ods::Operation *odsOp,
MutableArrayRef<ast::Expr *> results);
- LogicalResult
- validateOperationOperandsOrResults(SMRange loc, Optional<StringRef> name,
- MutableArrayRef<ast::Expr *> values,
- ast::Type singleTy, ast::Type rangeTy);
+ LogicalResult validateOperationOperandsOrResults(
+ StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
+ Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
+ ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
+ ast::Type rangeTy);
FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
ArrayRef<ast::Expr *> elements,
ArrayRef<StringRef> elementNames);
@@ -440,6 +478,7 @@ private:
/// Cached types to simplify verification and expression creation.
ast::Type valueTy, valueRangeTy;
ast::Type typeTy, typeRangeTy;
+ ast::Type attrTy;
/// A counter used when naming anonymous constraints and rewrites.
unsigned anonymousDeclNameCounter = 0;
@@ -459,7 +498,7 @@ FailureOr<ast::Module *> Parser::parseModule() {
return ast::Module::create(ctx, moduleLoc, decls);
}
-LogicalResult Parser::parseModuleBody(SmallVector<ast::Decl *> &decls) {
+LogicalResult Parser::parseModuleBody(SmallVectorImpl<ast::Decl *> &decls) {
while (curToken.isNot(Token::eof)) {
if (curToken.is(Token::directive)) {
if (failed(parseDirective(decls)))
@@ -516,6 +555,32 @@ LogicalResult Parser::convertExpressionTo(
// Allow conversion to a single value by constraining the result range.
if (type == valueTy) {
+ // If the operation is registered, we can verify if it can ever have a
+ // single result.
+ Optional<StringRef> opName = exprOpType.getName();
+ if (const ods::Operation *odsOp = lookupODSOperation(opName)) {
+ if (odsOp->getResults().empty()) {
+ return emitConvertError()->attachNote(
+ llvm::formatv("see the definition of `{0}`, which was defined "
+ "with zero results",
+ odsOp->getName()),
+ odsOp->getLoc());
+ }
+
+ unsigned numSingleResults = llvm::count_if(
+ odsOp->getResults(), [](const ods::OperandOrResult &result) {
+ return result.getVariableLengthKind() ==
+ ods::VariableLengthKind::Single;
+ });
+ if (numSingleResults > 1) {
+ return emitConvertError()->attachNote(
+ llvm::formatv("see the definition of `{0}`, which was defined "
+ "with at least {1} results",
+ odsOp->getName(), numSingleResults),
+ odsOp->getLoc());
+ }
+ }
+
expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
valueTy);
return success();
@@ -569,7 +634,7 @@ LogicalResult Parser::convertExpressionTo(
//===----------------------------------------------------------------------===//
// Directives
-LogicalResult Parser::parseDirective(SmallVector<ast::Decl *> &decls) {
+LogicalResult Parser::parseDirective(SmallVectorImpl<ast::Decl *> &decls) {
StringRef directive = curToken.getSpelling();
if (directive == "#include")
return parseInclude(decls);
@@ -577,7 +642,7 @@ LogicalResult Parser::parseDirective(SmallVector<ast::Decl *> &decls) {
return emitError("unknown directive `" + directive + "`");
}
-LogicalResult Parser::parseInclude(SmallVector<ast::Decl *> &decls) {
+LogicalResult Parser::parseInclude(SmallVectorImpl<ast::Decl *> &decls) {
SMRange loc = curToken.getLoc();
consumeToken(Token::directive);
@@ -607,7 +672,193 @@ LogicalResult Parser::parseInclude(SmallVector<ast::Decl *> &decls) {
return result;
}
- return emitError(fileLoc, "expected include filename to end with `.pdll`");
+ // Otherwise, this must be a `.td` include.
+ if (filename.endswith(".td"))
+ return parseTdInclude(filename, fileLoc, decls);
+
+ return emitError(fileLoc,
+ "expected include filename to end with `.pdll` or `.td`");
+}
+
+LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
+ SmallVectorImpl<ast::Decl *> &decls) {
+ llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();
+
+ // This class provides a context argument for the llvm::SourceMgr diagnostic
+ // handler.
+ struct DiagHandlerContext {
+ Parser &parser;
+ StringRef filename;
+ llvm::SMRange loc;
+ } handlerContext{*this, filename, fileLoc};
+
+ // Set the diagnostic handler for the tablegen source manager.
+ llvm::SrcMgr.setDiagHandler(
+ [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) {
+ auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext);
+ (void)ctx->parser.emitError(
+ ctx->loc,
+ llvm::formatv("error while processing include file `{0}`: {1}",
+ ctx->filename, diag.getMessage()));
+ },
+ &handlerContext);
+
+ // Use the source manager to open the file, but don't yet add it.
+ std::string includedFile;
+ llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
+ parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
+ if (!includeBuffer)
+ return emitError(fileLoc, "unable to open include file `" + filename + "`");
+
+ auto processFn = [&](llvm::RecordKeeper &records) {
+ processTdIncludeRecords(records, decls);
+
+ // After we are done processing, move all of the tablegen source buffers to
+ // the main parser source mgr. This allows for directly using source
+ // locations from the .td files without needing to remap them.
+ parserSrcMgr.takeSourceBuffersFrom(llvm::SrcMgr);
+ return false;
+ };
+ if (llvm::TableGenParseFile(std::move(*includeBuffer),
+ parserSrcMgr.getIncludeDirs(), processFn))
+ return failure();
+
+ return success();
+}
+
+void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
+ SmallVectorImpl<ast::Decl *> &decls) {
+ // Return the length kind of the given value.
+ auto getLengthKind = [](const auto &value) {
+ if (value.isOptional())
+ return ods::VariableLengthKind::Optional;
+ return value.isVariadic() ? ods::VariableLengthKind::Variadic
+ : ods::VariableLengthKind::Single;
+ };
+
+ // Insert a type constraint into the ODS context.
+ ods::Context &odsContext = ctx.getODSContext();
+ auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
+ -> const ods::TypeConstraint & {
+ return odsContext.insertTypeConstraint(cst.constraint.getDefName(),
+ cst.constraint.getSummary(),
+ cst.constraint.getCPPClassName());
+ };
+ auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
+ return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
+ };
+
+ // Process the parsed tablegen records to build ODS information.
+ /// Operations.
+ for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) {
+ tblgen::Operator op(def);
+
+ bool inserted = false;
+ ods::Operation *odsOp = nullptr;
+ std::tie(odsOp, inserted) =
+ odsContext.insertOperation(op.getOperationName(), op.getSummary(),
+ op.getDescription(), op.getLoc().front());
+
+ // Ignore operations that have already been added.
+ if (!inserted)
+ continue;
+
+ for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
+ odsOp->appendAttribute(
+ attr.name, attr.attr.isOptional(),
+ odsContext.insertAttributeConstraint(attr.attr.getAttrDefName(),
+ attr.attr.getSummary(),
+ attr.attr.getStorageType()));
+ }
+ for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) {
+ odsOp->appendOperand(operand.name, getLengthKind(operand),
+ addTypeConstraint(operand));
+ }
+ for (const tblgen::NamedTypeConstraint &result : op.getResults()) {
+ odsOp->appendResult(result.name, getLengthKind(result),
+ addTypeConstraint(result));
+ }
+ }
+ /// Attr constraints.
+ for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
+ if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) {
+ decls.push_back(
+ createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
+ tblgen::AttrConstraint(def),
+ convertLocToRange(def->getLoc().front()), attrTy));
+ }
+ }
+ /// Type constraints.
+ for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
+ if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) {
+ decls.push_back(
+ createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
+ tblgen::TypeConstraint(def),
+ convertLocToRange(def->getLoc().front()), typeTy));
+ }
+ }
+ /// Interfaces.
+ ast::Type opTy = ast::OperationType::get(ctx);
+ for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Interface")) {
+ StringRef name = def->getName();
+ if (def->isAnonymous() || curDeclScope->lookup(name) ||
+ def->isSubClassOf("DeclareInterfaceMethods"))
+ continue;
+ SMRange loc = convertLocToRange(def->getLoc().front());
+
+ StringRef className = def->getValueAsString("cppClassName");
+ StringRef cppNamespace = def->getValueAsString("cppNamespace");
+ std::string codeBlock =
+ llvm::formatv("llvm::isa<{0}::{1}>(self)", cppNamespace, className)
+ .str();
+
+ if (def->isSubClassOf("OpInterface")) {
+ decls.push_back(createODSNativePDLLConstraintDecl<ast::OpConstraintDecl>(
+ name, codeBlock, loc, opTy));
+ } else if (def->isSubClassOf("AttrInterface")) {
+ decls.push_back(
+ createODSNativePDLLConstraintDecl<ast::AttrConstraintDecl>(
+ name, codeBlock, loc, attrTy));
+ } else if (def->isSubClassOf("TypeInterface")) {
+ decls.push_back(
+ createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
+ name, codeBlock, loc, typeTy));
+ }
+ }
+}
+
+template <typename ConstraintT>
+ast::Decl *
+Parser::createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
+ SMRange loc, ast::Type type) {
+ // Build the single input parameter.
+ ast::DeclScope *argScope = pushDeclScope();
+ auto *paramVar = ast::VariableDecl::create(
+ ctx, ast::Name::create(ctx, "self", loc), type,
+ /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc)));
+ argScope->add(paramVar);
+ popDeclScope();
+
+ // Build the native constraint.
+ auto *constraintDecl = ast::UserConstraintDecl::createNative(
+ ctx, ast::Name::create(ctx, name, loc), paramVar,
+ /*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx));
+ curDeclScope->add(constraintDecl);
+ return constraintDecl;
+}
+
+template <typename ConstraintT>
+ast::Decl *
+Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
+ SMRange loc, ast::Type type) {
+ // Format the condition template.
+ tblgen::FmtContext fmtContext;
+ fmtContext.withSelf("self");
+ std::string codeBlock =
+ tblgen::tgfmt(constraint.getConditionTemplate(), &fmtContext);
+
+ return createODSNativePDLLConstraintDecl<ConstraintT>(constraint.getDefName(),
+ codeBlock, loc, type);
}
//===----------------------------------------------------------------------===//
@@ -2302,9 +2553,29 @@ Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
StringRef name, SMRange loc) {
ast::Type parentType = parentExpr->getType();
- if (parentType.isa<ast::OperationType>()) {
+ if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
if (name == ast::AllResultsMemberAccessExpr::getMemberName())
return valueRangeTy;
+
+ // Verify member access based on the operation type.
+ if (const ods::Operation *odsOp = lookupODSOperation(opType.getName())) {
+ auto results = odsOp->getResults();
+
+ // Handle indexed results.
+ unsigned index = 0;
+ if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
+ index < results.size()) {
+ return results[index].isVariadic() ? valueRangeTy : valueTy;
+ }
+
+ // Handle named results.
+ const auto *it = llvm::find_if(results, [&](const auto &result) {
+ return result.getName() == name;
+ });
+ if (it != results.end())
+ return it->isVariadic() ? valueRangeTy : valueTy;
+ }
+
} else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
// Handle indexed results.
unsigned index = 0;
@@ -2331,9 +2602,10 @@ FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
MutableArrayRef<ast::NamedAttributeDecl *> attributes,
MutableArrayRef<ast::Expr *> results) {
Optional<StringRef> opNameRef = name->getName();
+ const ods::Operation *odsOp = lookupODSOperation(opNameRef);
// Verify the inputs operands.
- if (failed(validateOperationOperands(loc, opNameRef, operands)))
+ if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands)))
return failure();
// Verify the attribute list.
@@ -2348,7 +2620,7 @@ FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
}
// Verify the result types.
- if (failed(validateOperationResults(loc, opNameRef, results)))
+ if (failed(validateOperationResults(loc, opNameRef, odsOp, results)))
return failure();
return ast::OperationExpr::create(ctx, loc, name, operands, results,
@@ -2357,21 +2629,28 @@ FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
LogicalResult
Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name,
+ const ods::Operation *odsOp,
MutableArrayRef<ast::Expr *> operands) {
- return validateOperationOperandsOrResults(loc, name, operands, valueTy,
- valueRangeTy);
+ return validateOperationOperandsOrResults(
+ "operand", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name,
+ operands, odsOp ? odsOp->getOperands() : llvm::None, valueTy,
+ valueRangeTy);
}
LogicalResult
Parser::validateOperationResults(SMRange loc, Optional<StringRef> name,
+ const ods::Operation *odsOp,
MutableArrayRef<ast::Expr *> results) {
- return validateOperationOperandsOrResults(loc, name, results, typeTy,
- typeRangeTy);
+ return validateOperationOperandsOrResults(
+ "result", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name,
+ results, odsOp ? odsOp->getResults() : llvm::None, typeTy, typeRangeTy);
}
LogicalResult Parser::validateOperationOperandsOrResults(
- SMRange loc, Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
- ast::Type singleTy, ast::Type rangeTy) {
+ StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
+ Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
+ ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
+ ast::Type rangeTy) {
// All operation types accept a single range parameter.
if (values.size() == 1) {
if (failed(convertExpressionTo(values[0], rangeTy)))
@@ -2379,6 +2658,29 @@ LogicalResult Parser::validateOperationOperandsOrResults(
return success();
}
+ /// If the operation has ODS information, we can more accurately verify the
+ /// values.
+ if (odsOpLoc) {
+ if (odsValues.size() != values.size()) {
+ return emitErrorAndNote(
+ loc,
+ llvm::formatv("invalid number of {0} groups for `{1}`; expected "
+ "{2}, but got {3}",
+ groupName, *name, odsValues.size(), values.size()),
+ *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name));
+ }
+ auto diagFn = [&](ast::Diagnostic &diag) {
+ diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name),
+ *odsOpLoc);
+ };
+ for (unsigned i = 0, e = values.size(); i < e; ++i) {
+ ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy;
+ if (failed(convertExpressionTo(values[i], expectedType, diagFn)))
+ return failure();
+ }
+ return success();
+ }
+
// Otherwise, accept the value groups as they have been defined and just
// ensure they are one of the expected types.
for (ast::Expr *&valueExpr : values) {
diff --git a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll
index 3e652ad..e8db46c 100644
--- a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll
+++ b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll
@@ -1,4 +1,4 @@
-// RUN: mlir-pdll %s -I %S -split-input-file -x mlir | FileCheck %s
+// RUN: mlir-pdll %s -I %S -I %S/../../../../include -split-input-file -x mlir | FileCheck %s
//===----------------------------------------------------------------------===//
// AttributeExpr
@@ -55,6 +55,24 @@ Pattern OpAllResultMemberAccess {
// -----
+// Handle implicit "named" operation results access.
+
+#include "include/ops.td"
+
+// CHECK: pdl.pattern @OpResultMemberAccess
+// CHECK: %[[OP0:.*]] = operation
+// CHECK: %[[RES:.*]] = results 0 of %[[OP0]] -> !pdl.value
+// CHECK: %[[RES1:.*]] = results 0 of %[[OP0]] -> !pdl.value
+// CHECK: %[[RES2:.*]] = results 1 of %[[OP0]] -> !pdl.range<value>
+// CHECK: %[[RES3:.*]] = results 1 of %[[OP0]] -> !pdl.range<value>
+// CHECK: operation(%[[RES]], %[[RES1]], %[[RES2]], %[[RES3]] : !pdl.value, !pdl.value, !pdl.range<value>, !pdl.range<value>)
+Pattern OpResultMemberAccess {
+ let op: Op<test.with_results>;
+ erase op<>(op.0, op.result, op.1, op.var_result);
+}
+
+// -----
+
// CHECK: pdl.pattern @TupleMemberAccessNumber
// CHECK: %[[FIRST:.*]] = operation "test.first"
// CHECK: %[[SECOND:.*]] = operation "test.second"
diff --git a/mlir/test/mlir-pdll/CodeGen/MLIR/include/ops.td b/mlir/test/mlir-pdll/CodeGen/MLIR/include/ops.td
new file mode 100644
index 0000000..588b290
--- /dev/null
+++ b/mlir/test/mlir-pdll/CodeGen/MLIR/include/ops.td
@@ -0,0 +1,9 @@
+include "mlir/IR/OpBase.td"
+
+def Test_Dialect : Dialect {
+ let name = "test";
+}
+
+def OpWithResults : Op<Test_Dialect, "with_results"> {
+ let results = (outs I64:$result, Variadic<I64>:$var_result);
+}
diff --git a/mlir/test/mlir-pdll/Parser/directive-failure.pdll b/mlir/test/mlir-pdll/Parser/directive-failure.pdll
index 14924fe..14f8db8 100644
--- a/mlir/test/mlir-pdll/Parser/directive-failure.pdll
+++ b/mlir/test/mlir-pdll/Parser/directive-failure.pdll
@@ -19,5 +19,5 @@
// -----
-// CHECK: expected include filename to end with `.pdll`
+// CHECK: expected include filename to end with `.pdll` or `.td`
#include "unknown_file.foo"
diff --git a/mlir/test/mlir-pdll/Parser/expr-failure.pdll b/mlir/test/mlir-pdll/Parser/expr-failure.pdll
index 7ed3ba8..08174de 100644
--- a/mlir/test/mlir-pdll/Parser/expr-failure.pdll
+++ b/mlir/test/mlir-pdll/Parser/expr-failure.pdll
@@ -1,4 +1,4 @@
-// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s
+// RUN: not mlir-pdll %s -I %S -I %S/../../../include -split-input-file 2>&1 | FileCheck %s
//===----------------------------------------------------------------------===//
// Reference Expr
@@ -276,6 +276,26 @@ Pattern {
// -----
+#include "include/ops.td"
+
+Pattern {
+ // CHECK: invalid number of operand groups for `test.all_empty`; expected 0, but got 2
+ // CHECK: see the definition of `test.all_empty` here
+ let foo = op<test.all_empty>(operand1: Value, operand2: Value);
+}
+
+// -----
+
+#include "include/ops.td"
+
+Pattern {
+ // CHECK: invalid number of result groups for `test.all_empty`; expected 0, but got 2
+ // CHECK: see the definition of `test.all_empty` here
+ let foo = op<test.all_empty> -> (result1: Type, result2: Type);
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// `type` Expr
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-pdll/Parser/expr.pdll b/mlir/test/mlir-pdll/Parser/expr.pdll
index 9919fe5..c7d9603 100644
--- a/mlir/test/mlir-pdll/Parser/expr.pdll
+++ b/mlir/test/mlir-pdll/Parser/expr.pdll
@@ -1,4 +1,4 @@
-// RUN: mlir-pdll %s -I %S -split-input-file | FileCheck %s
+// RUN: mlir-pdll %s -I %S -I %S/../../../include -split-input-file | FileCheck %s
//===----------------------------------------------------------------------===//
// AttrExpr
@@ -71,6 +71,25 @@ Pattern {
// -----
+#include "include/ops.td"
+
+// CHECK: Module
+// CHECK: `-VariableDecl {{.*}} Name<firstEltIndex> Type<Value>
+// CHECK: `-MemberAccessExpr {{.*}} Member<0> Type<Value>
+// CHECK: `-DeclRefExpr {{.*}} Type<Op<test.all_single>>
+// CHECK: `-VariableDecl {{.*}} Name<firstEltName> Type<Value>
+// CHECK: `-MemberAccessExpr {{.*}} Member<result> Type<Value>
+// CHECK: `-DeclRefExpr {{.*}} Type<Op<test.all_single>>
+Pattern {
+ let op: Op<test.all_single>;
+ let firstEltIndex = op.0;
+ let firstEltName = op.result;
+
+ erase op;
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// OperationExpr
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-pdll/Parser/include/interfaces.td b/mlir/test/mlir-pdll/Parser/include/interfaces.td
new file mode 100644
index 0000000..eea8783
--- /dev/null
+++ b/mlir/test/mlir-pdll/Parser/include/interfaces.td
@@ -0,0 +1,5 @@
+include "mlir/IR/OpBase.td"
+
+def TestAttrInterface : AttrInterface<"TestAttrInterface">;
+def TestOpInterface : OpInterface<"TestOpInterface">;
+def TestTypeInterface : TypeInterface<"TestTypeInterface">;
diff --git a/mlir/test/mlir-pdll/Parser/include/ops.td b/mlir/test/mlir-pdll/Parser/include/ops.td
new file mode 100644
index 0000000..1727d1a
--- /dev/null
+++ b/mlir/test/mlir-pdll/Parser/include/ops.td
@@ -0,0 +1,26 @@
+include "include/interfaces.td"
+
+def Test_Dialect : Dialect {
+ let name = "test";
+}
+
+def OpAllEmpty : Op<Test_Dialect, "all_empty">;
+
+def OpAllSingle : Op<Test_Dialect, "all_single"> {
+ let arguments = (ins I64:$operand, I64Attr:$attr);
+ let results = (outs I64:$result);
+}
+
+def OpAllOptional : Op<Test_Dialect, "all_optional"> {
+ let arguments = (ins Optional<I64>:$operand, OptionalAttr<I64Attr>:$attr);
+ let results = (outs Optional<I64>:$result);
+}
+
+def OpAllVariadic : Op<Test_Dialect, "all_variadic"> {
+ let arguments = (ins Variadic<I64>:$operands);
+ let results = (outs Variadic<I64>:$results);
+}
+
+def OpMultipleSingleResult : Op<Test_Dialect, "multiple_single_result"> {
+ let results = (outs I64:$result, I64:$result2);
+}
diff --git a/mlir/test/mlir-pdll/Parser/include_td.pdll b/mlir/test/mlir-pdll/Parser/include_td.pdll
new file mode 100644
index 0000000..c55ed1d
--- /dev/null
+++ b/mlir/test/mlir-pdll/Parser/include_td.pdll
@@ -0,0 +1,52 @@
+// RUN: mlir-pdll %s -I %S -I %S/../../../include -dump-ods 2>&1 | FileCheck %s
+
+#include "include/ops.td"
+
+// CHECK: Operation `test.all_empty` {
+// CHECK-NEXT: }
+
+// CHECK: Operation `test.all_optional` {
+// CHECK-NEXT: Attributes { attr : Optional<I64Attr> }
+// CHECK-NEXT: Operands { operand : Optional<I64> }
+// CHECK-NEXT: Results { result : Optional<I64> }
+// CHECK-NEXT: }
+
+// CHECK: Operation `test.all_single` {
+// CHECK-NEXT: Attributes { attr : I64Attr }
+// CHECK-NEXT: Operands { operand : I64 }
+// CHECK-NEXT: Results { result : I64 }
+// CHECK-NEXT: }
+
+// CHECK: Operation `test.all_variadic` {
+// CHECK-NEXT: Operands { operands : Variadic<I64> }
+// CHECK-NEXT: Results { results : Variadic<I64> }
+// CHECK-NEXT: }
+
+// CHECK: AttributeConstraint `I64Attr` {
+// CHECK-NEXT: Summary: 64-bit signless integer attribute
+// CHECK-NEXT: CppClass: ::mlir::IntegerAttr
+// CHECK-NEXT: }
+
+// CHECK: TypeConstraint `I64` {
+// CHECK-NEXT: Summary: 64-bit signless integer
+// CHECK-NEXT: CppClass: ::mlir::IntegerType
+// CHECK-NEXT: }
+
+// CHECK: UserConstraintDecl {{.*}} Name<TestAttrInterface> ResultType<Tuple<>> Code<llvm::isa<::TestAttrInterface>(self)>
+// CHECK: `Inputs`
+// CHECK: `-VariableDecl {{.*}} Name<self> Type<Attr>
+// CHECK: `Constraints`
+// CHECK: `-AttrConstraintDecl
+
+// CHECK: UserConstraintDecl {{.*}} Name<TestOpInterface> ResultType<Tuple<>> Code<llvm::isa<::TestOpInterface>(self)>
+// CHECK: `Inputs`
+// CHECK: `-VariableDecl {{.*}} Name<self> Type<Op>
+// CHECK: `Constraints`
+// CHECK: `-OpConstraintDecl
+// CHECK: `-OpNameDecl
+
+// CHECK: UserConstraintDecl {{.*}} Name<TestTypeInterface> ResultType<Tuple<>> Code<llvm::isa<::TestTypeInterface>(self)>
+// CHECK: `Inputs`
+// CHECK: `-VariableDecl {{.*}} Name<self> Type<Type>
+// CHECK: `Constraints`
+// CHECK: `-TypeConstraintDecl {{.*}}
diff --git a/mlir/test/mlir-pdll/Parser/stmt-failure.pdll b/mlir/test/mlir-pdll/Parser/stmt-failure.pdll
index d52a2d2..4220259 100644
--- a/mlir/test/mlir-pdll/Parser/stmt-failure.pdll
+++ b/mlir/test/mlir-pdll/Parser/stmt-failure.pdll
@@ -1,4 +1,4 @@
-// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s
+// RUN: not mlir-pdll %s -I %S -I %S/../../../include -split-input-file 2>&1 | FileCheck %s
// CHECK: expected top-level declaration, such as a `Pattern`
10
@@ -250,6 +250,28 @@ Pattern {
// -----
+#include "include/ops.td"
+
+Pattern {
+ // CHECK: unable to convert expression of type `Op<test.all_empty>` to the expected type of `Value`
+ // CHECK: see the definition of `test.all_empty`, which was defined with zero results
+ let value: Value = op<test.all_empty>;
+ erase _: Op;
+}
+
+// -----
+
+#include "include/ops.td"
+
+Pattern {
+ // CHECK: unable to convert expression of type `Op<test.multiple_single_result>` to the expected type of `Value`
+ // CHECK: see the definition of `test.multiple_single_result`, which was defined with at least 2 results
+ let value: Value = op<test.multiple_single_result>;
+ erase _: Op;
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// `replace`
//===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-pdll/mlir-pdll.cpp b/mlir/tools/mlir-pdll/mlir-pdll.cpp
index e133d9e..904fb77 100644
--- a/mlir/tools/mlir-pdll/mlir-pdll.cpp
+++ b/mlir/tools/mlir-pdll/mlir-pdll.cpp
@@ -13,6 +13,7 @@
#include "mlir/Tools/PDLL/AST/Nodes.h"
#include "mlir/Tools/PDLL/CodeGen/CPPGen.h"
#include "mlir/Tools/PDLL/CodeGen/MLIRGen.h"
+#include "mlir/Tools/PDLL/ODS/Context.h"
#include "mlir/Tools/PDLL/Parser/Parser.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
@@ -35,16 +36,23 @@ enum class OutputType {
static LogicalResult
processBuffer(raw_ostream &os, std::unique_ptr<llvm::MemoryBuffer> chunkBuffer,
- OutputType outputType, std::vector<std::string> &includeDirs) {
+ OutputType outputType, std::vector<std::string> &includeDirs,
+ bool dumpODS) {
llvm::SourceMgr sourceMgr;
sourceMgr.setIncludeDirs(includeDirs);
sourceMgr.AddNewSourceBuffer(std::move(chunkBuffer), SMLoc());
- ast::Context astContext;
+ ods::Context odsContext;
+ ast::Context astContext(odsContext);
FailureOr<ast::Module *> module = parsePDLAST(astContext, sourceMgr);
if (failed(module))
return failure();
+ // Print out the ODS information if requested.
+ if (dumpODS)
+ odsContext.print(llvm::errs());
+
+ // Generate the output.
if (outputType == OutputType::AST) {
(*module)->print(os);
return success();
@@ -66,6 +74,10 @@ processBuffer(raw_ostream &os, std::unique_ptr<llvm::MemoryBuffer> chunkBuffer,
}
int main(int argc, char **argv) {
+ // FIXME: This is necessary because we link in TableGen, which defines its
+ // options as static variables.. some of which overlap with our options.
+ llvm::cl::ResetCommandLineParser();
+
llvm::cl::opt<std::string> inputFilename(
llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"),
llvm::cl::value_desc("filename"));
@@ -78,6 +90,11 @@ int main(int argc, char **argv) {
"I", llvm::cl::desc("Directory of include files"),
llvm::cl::value_desc("directory"), llvm::cl::Prefix);
+ llvm::cl::opt<bool> dumpODS(
+ "dump-ods",
+ llvm::cl::desc(
+ "Print out the parsed ODS information from the input file"),
+ llvm::cl::init(false));
llvm::cl::opt<bool> splitInputFile(
"split-input-file",
llvm::cl::desc("Split the input file into pieces and process each "
@@ -118,7 +135,8 @@ int main(int argc, char **argv) {
// up into small pieces and checks each independently.
auto processFn = [&](std::unique_ptr<llvm::MemoryBuffer> chunkBuffer,
raw_ostream &os) {
- return processBuffer(os, std::move(chunkBuffer), outputType, includeDirs);
+ return processBuffer(os, std::move(chunkBuffer), outputType, includeDirs,
+ dumpODS);
};
if (splitInputFile) {
if (failed(splitAndProcessBuffer(std::move(inputFile), processFn,