aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Bindings/Python/IRCore.cpp58
-rw-r--r--mlir/lib/Bindings/Python/IRModule.h6
-rw-r--r--mlir/lib/Bindings/Python/IRTypes.cpp4
-rw-r--r--mlir/lib/Bindings/Python/MainModule.cpp6
-rw-r--r--mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp291
-rw-r--r--mlir/lib/Dialect/Transform/IR/TransformOps.cpp7
6 files changed, 259 insertions, 113 deletions
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index c20b211..32b2b0c 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3219,13 +3219,11 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("end_line"), nb::arg("end_col"),
nb::arg("context") = nb::none(), kContextGetFileRangeDocstring)
.def("is_a_file", mlirLocationIsAFileLineColRange)
- .def_prop_ro(
- "filename",
- [](MlirLocation loc) {
- return mlirIdentifierStr(
- mlirLocationFileLineColRangeGetFilename(loc));
- },
- nb::sig("def filename(self) -> str"))
+ .def_prop_ro("filename",
+ [](MlirLocation loc) {
+ return mlirIdentifierStr(
+ mlirLocationFileLineColRangeGetFilename(loc));
+ })
.def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine)
.def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn)
.def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine)
@@ -3274,12 +3272,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("name"), nb::arg("childLoc") = nb::none(),
nb::arg("context") = nb::none(), kContextGetNameLocationDocString)
.def("is_a_name", mlirLocationIsAName)
- .def_prop_ro(
- "name_str",
- [](MlirLocation loc) {
- return mlirIdentifierStr(mlirLocationNameGetName(loc));
- },
- nb::sig("def name_str(self) -> str"))
+ .def_prop_ro("name_str",
+ [](MlirLocation loc) {
+ return mlirIdentifierStr(mlirLocationNameGetName(loc));
+ })
.def_prop_ro("child_loc",
[](PyLocation &self) {
return PyLocation(self.getContext(),
@@ -3453,15 +3449,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
return concreteOperation.getContext().getObject();
},
"Context that owns the Operation")
- .def_prop_ro(
- "name",
- [](PyOperationBase &self) {
- auto &concreteOperation = self.getOperation();
- concreteOperation.checkValid();
- MlirOperation operation = concreteOperation.get();
- return mlirIdentifierStr(mlirOperationGetName(operation));
- },
- nb::sig("def name(self) -> str"))
+ .def_prop_ro("name",
+ [](PyOperationBase &self) {
+ auto &concreteOperation = self.getOperation();
+ concreteOperation.checkValid();
+ MlirOperation operation = concreteOperation.get();
+ return mlirIdentifierStr(mlirOperationGetName(operation));
+ })
.def_prop_ro("operands",
[](PyOperationBase &self) {
return PyOpOperandList(self.getOperation().getRef());
@@ -3603,12 +3597,11 @@ void mlir::python::populateIRCore(nb::module_ &m) {
},
"Reports if the operation is attached to its parent block.")
.def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
- .def(
- "walk", &PyOperationBase::walk, nb::arg("callback"),
- nb::arg("walk_order") = MlirWalkPostOrder,
- // clang-format off
- nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder = " MAKE_MLIR_PYTHON_QUALNAME("ir.WalkOrder.POST_ORDER") ") -> None")
- // clang-format on
+ .def("walk", &PyOperationBase::walk, nb::arg("callback"),
+ nb::arg("walk_order") = MlirWalkPostOrder,
+ // clang-format off
+ nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None")
+ // clang-format on
);
nb::class_<PyOperation, PyOperationBase>(m, "Operation")
@@ -4124,7 +4117,6 @@ void mlir::python::populateIRCore(nb::module_ &m) {
[](PyNamedAttribute &self) {
return mlirIdentifierStr(self.namedAttr.name);
},
- nb::sig("def name(self) -> str"),
"The name of the NamedAttribute binding")
.def_prop_ro(
"attr",
@@ -4342,17 +4334,15 @@ void mlir::python::populateIRCore(nb::module_ &m) {
kValueReplaceAllUsesWithDocstring)
.def(
"replace_all_uses_except",
- [](MlirValue self, MlirValue with, PyOperation &exception) {
+ [](PyValue &self, PyValue &with, PyOperation &exception) {
MlirOperation exceptedUser = exception.get();
mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
},
nb::arg("with_"), nb::arg("exceptions"),
- nb::sig("def replace_all_uses_except(self, with_: Value, exceptions: "
- "Operation) -> None"),
kValueReplaceAllUsesExceptDocstring)
.def(
"replace_all_uses_except",
- [](MlirValue self, MlirValue with, nb::list exceptions) {
+ [](PyValue &self, PyValue &with, const nb::list &exceptions) {
// Convert Python list to a SmallVector of MlirOperations
llvm::SmallVector<MlirOperation> exceptionOps;
for (nb::handle exception : exceptions) {
@@ -4364,8 +4354,6 @@ void mlir::python::populateIRCore(nb::module_ &m) {
exceptionOps.data());
},
nb::arg("with_"), nb::arg("exceptions"),
- nb::sig("def replace_all_uses_except(self, with_: Value, exceptions: "
- "Sequence[Operation]) -> None"),
kValueReplaceAllUsesExceptDocstring)
.def(
"replace_all_uses_except",
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 598ae01..edbd73e 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -273,8 +273,7 @@ class DefaultingPyMlirContext
: public Defaulting<DefaultingPyMlirContext, PyMlirContext> {
public:
using Defaulting::Defaulting;
- static constexpr const char kTypeDescription[] =
- MAKE_MLIR_PYTHON_QUALNAME("ir.Context");
+ static constexpr const char kTypeDescription[] = "Context";
static PyMlirContext &resolve();
};
@@ -500,8 +499,7 @@ class DefaultingPyLocation
: public Defaulting<DefaultingPyLocation, PyLocation> {
public:
using Defaulting::Defaulting;
- static constexpr const char kTypeDescription[] =
- MAKE_MLIR_PYTHON_QUALNAME("ir.Location");
+ static constexpr const char kTypeDescription[] = "Location";
static PyLocation &resolve();
operator MlirLocation() const { return *get(); }
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 3488d92..34c5b8d 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -1010,7 +1010,7 @@ public:
},
nb::arg("elements"), nb::arg("context") = nb::none(),
// clang-format off
- nb::sig("def get_tuple(elements: Sequence[Type], context: mlir.ir.Context | None = None) -> TupleType"),
+ nb::sig("def get_tuple(elements: Sequence[Type], context: Context | None = None) -> TupleType"),
// clang-format on
"Create a tuple type");
c.def(
@@ -1070,7 +1070,7 @@ public:
},
nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
// clang-format off
- nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: mlir.ir.Context | None = None) -> FunctionType"),
+ nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: Context | None = None) -> FunctionType"),
// clang-format on
"Gets a FunctionType from a list of input and result types");
c.def_prop_ro(
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 52656138..a14f09f 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -115,9 +115,6 @@ NB_MODULE(_mlir, m) {
});
},
"typeid"_a, nb::kw_only(), "replace"_a = false,
- // clang-format off
- nb::sig("def register_type_caster(typeid: " MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID") ", *, replace: bool = False) -> object"),
- // clang-format on
"Register a type caster for casting MLIR types to custom user types.");
m.def(
MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR,
@@ -130,9 +127,6 @@ NB_MODULE(_mlir, m) {
});
},
"typeid"_a, nb::kw_only(), "replace"_a = false,
- // clang-format off
- nb::sig("def register_value_caster(typeid: " MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID") ", *, replace: bool = False) -> object"),
- // clang-format on
"Register a value caster for casting MLIR values to custom user values.");
// Define and populate IR submodule.
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index a173cf1..32ebe06 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -77,6 +77,232 @@ struct LLVMPointerPointerLikeModel
};
} // namespace
+/// Generate a name of a canonical loop nest of the format
+/// `<prefix>(_r<idx>_s<idx>)*`. Hereby, `_r<idx>` identifies the region
+/// argument index of an operation that has multiple regions, if the operation
+/// has multiple regions.
+/// `_s<idx>` identifies the position of an operation within a region, where
+/// only operations that may potentially contain loops ("container operations"
+/// i.e. have region arguments) are counted. Again, it is omitted if there is
+/// only one such operation in a region. If there are canonical loops nested
+/// inside each other, also may also use the format `_d<num>` where <num> is the
+/// nesting depth of the loop.
+///
+/// The generated name is a best-effort to make canonical loop unique within an
+/// SSA namespace. This also means that regions with IsolatedFromAbove property
+/// do not consider any parents or siblings.
+static std::string generateLoopNestingName(StringRef prefix,
+ CanonicalLoopOp op) {
+ struct Component {
+ /// If true, this component describes a region operand of an operation (the
+ /// operand's owner) If false, this component describes an operation located
+ /// in a parent region
+ bool isRegionArgOfOp;
+ bool skip = false;
+ bool isUnique = false;
+
+ size_t idx;
+ Operation *op;
+ Region *parentRegion;
+ size_t loopDepth;
+
+ Operation *&getOwnerOp() {
+ assert(isRegionArgOfOp && "Must describe a region operand");
+ return op;
+ }
+ size_t &getArgIdx() {
+ assert(isRegionArgOfOp && "Must describe a region operand");
+ return idx;
+ }
+
+ Operation *&getContainerOp() {
+ assert(!isRegionArgOfOp && "Must describe a operation of a region");
+ return op;
+ }
+ size_t &getOpPos() {
+ assert(!isRegionArgOfOp && "Must describe a operation of a region");
+ return idx;
+ }
+ bool isLoopOp() const {
+ assert(!isRegionArgOfOp && "Must describe a operation of a region");
+ return isa<CanonicalLoopOp>(op);
+ }
+ Region *&getParentRegion() {
+ assert(!isRegionArgOfOp && "Must describe a operation of a region");
+ return parentRegion;
+ }
+ size_t &getLoopDepth() {
+ assert(!isRegionArgOfOp && "Must describe a operation of a region");
+ return loopDepth;
+ }
+
+ void skipIf(bool v = true) { skip = skip || v; }
+ };
+
+ // List of ancestors, from inner to outer.
+ // Alternates between
+ // * region argument of an operation
+ // * operation within a region
+ SmallVector<Component> components;
+
+ // Gather a list of parent regions and operations, and the position within
+ // their parent
+ Operation *o = op.getOperation();
+ while (o) {
+ // Operation within a region
+ Region *r = o->getParentRegion();
+ if (!r)
+ break;
+
+ llvm::ReversePostOrderTraversal<Block *> traversal(&r->getBlocks().front());
+ size_t idx = 0;
+ bool found = false;
+ size_t sequentialIdx = -1;
+ bool isOnlyContainerOp = true;
+ for (Block *b : traversal) {
+ for (Operation &op : *b) {
+ if (&op == o && !found) {
+ sequentialIdx = idx;
+ found = true;
+ }
+ if (op.getNumRegions()) {
+ idx += 1;
+ if (idx > 1)
+ isOnlyContainerOp = false;
+ }
+ if (found && !isOnlyContainerOp)
+ break;
+ }
+ }
+
+ Component &containerOpInRegion = components.emplace_back();
+ containerOpInRegion.isRegionArgOfOp = false;
+ containerOpInRegion.isUnique = isOnlyContainerOp;
+ containerOpInRegion.getContainerOp() = o;
+ containerOpInRegion.getOpPos() = sequentialIdx;
+ containerOpInRegion.getParentRegion() = r;
+
+ Operation *parent = r->getParentOp();
+
+ // Region argument of an operation
+ Component &regionArgOfOperation = components.emplace_back();
+ regionArgOfOperation.isRegionArgOfOp = true;
+ regionArgOfOperation.isUnique = true;
+ regionArgOfOperation.getArgIdx() = 0;
+ regionArgOfOperation.getOwnerOp() = parent;
+
+ // The IsolatedFromAbove trait of the parent operation implies that each
+ // individual region argument has its own separate namespace, so no
+ // ambiguity.
+ if (!parent || parent->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>())
+ break;
+
+ // Component only needed if operation has multiple region operands. Region
+ // arguments may be optional, but we currently do not consider this.
+ if (parent->getRegions().size() > 1) {
+ auto getRegionIndex = [](Operation *o, Region *r) {
+ for (auto [idx, region] : llvm::enumerate(o->getRegions())) {
+ if (&region == r)
+ return idx;
+ }
+ llvm_unreachable("Region not child of its parent operation");
+ };
+ regionArgOfOperation.isUnique = false;
+ regionArgOfOperation.getArgIdx() = getRegionIndex(parent, r);
+ }
+
+ // next parent
+ o = parent;
+ }
+
+ // Determine whether a region-argument component is not needed
+ for (Component &c : components)
+ c.skipIf(c.isRegionArgOfOp && c.isUnique);
+
+ // Find runs of nested loops and determine each loop's depth in the loop nest
+ size_t numSurroundingLoops = 0;
+ for (Component &c : llvm::reverse(components)) {
+ if (c.skip)
+ continue;
+
+ // non-skipped multi-argument operands interrupt the loop nest
+ if (c.isRegionArgOfOp) {
+ numSurroundingLoops = 0;
+ continue;
+ }
+
+ // Multiple loops in a region means each of them is the outermost loop of a
+ // new loop nest
+ if (!c.isUnique)
+ numSurroundingLoops = 0;
+
+ c.getLoopDepth() = numSurroundingLoops;
+
+ // Next loop is surrounded by one more loop
+ if (isa<CanonicalLoopOp>(c.getContainerOp()))
+ numSurroundingLoops += 1;
+ }
+
+ // In loop nests, skip all but the innermost loop that contains the depth
+ // number
+ bool isLoopNest = false;
+ for (Component &c : components) {
+ if (c.skip || c.isRegionArgOfOp)
+ continue;
+
+ if (!isLoopNest && c.getLoopDepth() >= 1) {
+ // Innermost loop of a loop nest of at least two loops
+ isLoopNest = true;
+ } else if (isLoopNest) {
+ // Non-innermost loop of a loop nest
+ c.skipIf(c.isUnique);
+
+ // If there is no surrounding loop left, this must have been the outermost
+ // loop; leave loop-nest mode for the next iteration
+ if (c.getLoopDepth() == 0)
+ isLoopNest = false;
+ }
+ }
+
+ // Skip non-loop unambiguous regions (but they should interrupt loop nests, so
+ // we mark them as skipped only after computing loop nests)
+ for (Component &c : components)
+ c.skipIf(!c.isRegionArgOfOp && c.isUnique &&
+ !isa<CanonicalLoopOp>(c.getContainerOp()));
+
+ // Components can be skipped if they are already disambiguated by their parent
+ // (or does not have a parent)
+ bool newRegion = true;
+ for (Component &c : llvm::reverse(components)) {
+ c.skipIf(newRegion && c.isUnique);
+
+ // non-skipped components disambiguate unique children
+ if (!c.skip)
+ newRegion = true;
+
+ // ...except canonical loops that need a suffix for each nest
+ if (!c.isRegionArgOfOp && c.getContainerOp())
+ newRegion = false;
+ }
+
+ // Compile the nesting name string
+ SmallString<64> Name{prefix};
+ llvm::raw_svector_ostream NameOS(Name);
+ for (auto &c : llvm::reverse(components)) {
+ if (c.skip)
+ continue;
+
+ if (c.isRegionArgOfOp)
+ NameOS << "_r" << c.getArgIdx();
+ else if (c.getLoopDepth() >= 1)
+ NameOS << "_d" << c.getLoopDepth();
+ else
+ NameOS << "_s" << c.getOpPos();
+ }
+
+ return NameOS.str().str();
+}
+
void OpenMPDialect::initialize() {
addOperations<
#define GET_OP_LIST
@@ -3172,67 +3398,7 @@ void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
cliName =
TypeSwitch<Operation *, std::string>(gen->getOwner())
.Case([&](CanonicalLoopOp op) {
- // Find the canonical loop nesting: For each ancestor add a
- // "+_r<idx>" suffix (in reverse order)
- SmallVector<std::string> components;
- Operation *o = op.getOperation();
- while (o) {
- if (o->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>())
- break;
-
- Region *r = o->getParentRegion();
- if (!r)
- break;
-
- auto getSequentialIndex = [](Region *r, Operation *o) {
- llvm::ReversePostOrderTraversal<Block *> traversal(
- &r->getBlocks().front());
- size_t idx = 0;
- for (Block *b : traversal) {
- for (Operation &op : *b) {
- if (&op == o)
- return idx;
- // Only consider operations that are containers as
- // possible children
- if (!op.getRegions().empty())
- idx += 1;
- }
- }
- llvm_unreachable("Operation not part of the region");
- };
- size_t sequentialIdx = getSequentialIndex(r, o);
- components.push_back(("s" + Twine(sequentialIdx)).str());
-
- Operation *parent = r->getParentOp();
- if (!parent)
- break;
-
- // If the operation has more than one region, also count in
- // which of the regions
- if (parent->getRegions().size() > 1) {
- auto getRegionIndex = [](Operation *o, Region *r) {
- for (auto [idx, region] :
- llvm::enumerate(o->getRegions())) {
- if (&region == r)
- return idx;
- }
- llvm_unreachable("Region not child its parent operation");
- };
- size_t regionIdx = getRegionIndex(parent, r);
- components.push_back(("r" + Twine(regionIdx)).str());
- }
-
- // next parent
- o = parent;
- }
-
- SmallString<64> Name("canonloop");
- for (const std::string &s : reverse(components)) {
- Name += '_';
- Name += s;
- }
-
- return Name;
+ return generateLoopNestingName("canonloop", op);
})
.Case([&](UnrollHeuristicOp op) -> std::string {
llvm_unreachable("heuristic unrolling does not generate a loop");
@@ -3323,7 +3489,8 @@ void CanonicalLoopOp::getAsmBlockNames(OpAsmSetBlockNameFn setNameFn) {
void CanonicalLoopOp::getAsmBlockArgumentNames(Region &region,
OpAsmSetValueNameFn setNameFn) {
- setNameFn(region.getArgument(0), "iv");
+ std::string ivName = generateLoopNestingName("iv", *this);
+ setNameFn(region.getArgument(0), ivName);
}
void CanonicalLoopOp::print(OpAsmPrinter &p) {
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 132ed81..3385b2a 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -616,11 +616,10 @@ DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
if (diag.succeeded()) {
// Tracking failure is the only failure.
return trackingFailure;
- } else {
- diag.attachNote() << "tracking listener also failed: "
- << trackingFailure.getMessage();
- (void)trackingFailure.silence();
}
+ diag.attachNote() << "tracking listener also failed: "
+ << trackingFailure.getMessage();
+ (void)trackingFailure.silence();
}
if (!diag.succeeded())