aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIvan Butygin <ivan.butygin@gmail.com>2025-04-01 18:28:53 +0200
committerGitHub <noreply@github.com>2025-04-01 19:28:53 +0300
commit1f194ff34e4e861a18f7108c7874bccbd6459f30 (patch)
tree08375e49d07e1803c660f52d3c9857d720709a51
parent7e25b240731413d2cfca2b78ab1d0ed33d851622 (diff)
downloadllvm-1f194ff34e4e861a18f7108c7874bccbd6459f30.zip
llvm-1f194ff34e4e861a18f7108c7874bccbd6459f30.tar.gz
llvm-1f194ff34e4e861a18f7108c7874bccbd6459f30.tar.bz2
[mlir] Expose `simplifyAffineExpr` through python api (#133926)
-rw-r--r--mlir/include/mlir-c/AffineExpr.h10
-rw-r--r--mlir/lib/Bindings/Python/IRAffine.cpp10
-rw-r--r--mlir/lib/CAPI/IR/AffineExpr.cpp5
-rw-r--r--mlir/test/python/ir/affine_expr.py8
4 files changed, 33 insertions, 0 deletions
diff --git a/mlir/include/mlir-c/AffineExpr.h b/mlir/include/mlir-c/AffineExpr.h
index ab768eb..161db626 100644
--- a/mlir/include/mlir-c/AffineExpr.h
+++ b/mlir/include/mlir-c/AffineExpr.h
@@ -104,6 +104,16 @@ MLIR_CAPI_EXPORTED MlirAffineExpr
mlirAffineExprShiftSymbols(MlirAffineExpr affineExpr, uint32_t numSymbols,
uint32_t shift, uint32_t offset);
+/// Simplify an affine expression by flattening and some amount of simple
+/// analysis. This has complexity linear in the number of nodes in 'expr'.
+/// Returns the simplified expression, which is the same as the input expression
+/// if it can't be simplified. When `expr` is semi-affine, a simplified
+/// semi-affine expression is constructed in the sorted order of dimension and
+/// symbol positions.
+MLIR_CAPI_EXPORTED MlirAffineExpr mlirSimplifyAffineExpr(MlirAffineExpr expr,
+ uint32_t numDims,
+ uint32_t numSymbols);
+
//===----------------------------------------------------------------------===//
// Affine Dimension Expression.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index 3c95d29..50f2a4f 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -600,6 +600,16 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
nb::arg("num_symbols"), nb::arg("shift"),
nb::arg("offset").none() = 0)
.def_static(
+ "simplify_affine_expr",
+ [](PyAffineExpr &self, uint32_t numDims, uint32_t numSymbols) {
+ return PyAffineExpr(
+ self.getContext(),
+ mlirSimplifyAffineExpr(self, numDims, numSymbols));
+ },
+ nb::arg("expr"), nb::arg("num_dims"), nb::arg("num_symbols"),
+ "Simplify an affine expression by flattening and some amount of "
+ "simple analysis.")
+ .def_static(
"get_add", &PyAffineAddExpr::get,
"Gets an affine expression containing a sum of two expressions.")
.def_static("get_add", &PyAffineAddExpr::getLHSConstant,
diff --git a/mlir/lib/CAPI/IR/AffineExpr.cpp b/mlir/lib/CAPI/IR/AffineExpr.cpp
index bc3dcd4..5a0a03b 100644
--- a/mlir/lib/CAPI/IR/AffineExpr.cpp
+++ b/mlir/lib/CAPI/IR/AffineExpr.cpp
@@ -73,6 +73,11 @@ MlirAffineExpr mlirAffineExprShiftSymbols(MlirAffineExpr affineExpr,
return wrap(unwrap(affineExpr).shiftSymbols(numSymbols, shift, offset));
}
+MlirAffineExpr mlirSimplifyAffineExpr(MlirAffineExpr expr, uint32_t numDims,
+ uint32_t numSymbols) {
+ return wrap(simplifyAffineExpr(unwrap(expr), numDims, numSymbols));
+}
+
//===----------------------------------------------------------------------===//
// Affine Dimension Expression.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/ir/affine_expr.py b/mlir/test/python/ir/affine_expr.py
index 2f64aff..c2a2ab3 100644
--- a/mlir/test/python/ir/affine_expr.py
+++ b/mlir/test/python/ir/affine_expr.py
@@ -416,3 +416,11 @@ def testAffineExprShift():
assert (dims[2] + dims[3]) == (dims[0] + dims[1]).shift_dims(2, 2)
assert (syms[2] + syms[3]) == (syms[0] + syms[1]).shift_symbols(2, 2, 0)
+
+
+# CHECK-LABEL: TEST: testAffineExprSimplify
+@run
+def testAffineExprSimplify():
+ with Context() as ctx:
+ expr = AffineExpr.get_dim(0) + AffineExpr.get_symbol(0)
+ assert expr == AffineExpr.simplify_affine_expr(expr, 1, 1)