diff options
author | Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com> | 2025-02-01 14:33:22 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-02-01 07:33:22 -0600 |
commit | 48f88651a01b050a28be99e5cdffe495754ea79a (patch) | |
tree | fcd3c9d3bdff6dbcac2f32cefd7175ec05660b3a | |
parent | 9725595f3acc0c1aaa354e15ac4ee2b1f8ff4cc9 (diff) | |
download | llvm-48f88651a01b050a28be99e5cdffe495754ea79a.zip llvm-48f88651a01b050a28be99e5cdffe495754ea79a.tar.gz llvm-48f88651a01b050a28be99e5cdffe495754ea79a.tar.bz2 |
[MLIR] Extend MPI dialect (#123255)
cc @tobiasgrosser @wsmoses
This PR adds some new ops and types to the MLIR MPI dialect. The goal is
to get the minimum required ops in here to get a project of ours working, and,
if everything works well, continue adding ops to the MPI dialect in
subsequent PRs until we achieve some level of compliance with the MPI
standard.
---
Things left to do in subsequent PRs:
- Add back the `mpi.comm` type and add it as an optional argument of the
currently implemented ops that should support it (i.e. `send`, `recv`, `isend`,
`irecv`, `allreduce`, `barrier`).
- Support defining custom `MPI_Op`s (the MPI operations, not the
tablegen `MPI_Op`) as regions.
- Add more ops.
-rw-r--r-- | mlir/include/mlir/Dialect/MPI/IR/MPI.td | 39 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 195 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/MPI/IR/MPITypes.td | 22 | ||||
-rw-r--r-- | mlir/lib/Dialect/MPI/IR/MPIOps.cpp | 10 | ||||
-rw-r--r-- | mlir/test/Dialect/MPI/ops.mlir | 39 |
5 files changed, 293 insertions, 12 deletions
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPI.td b/mlir/include/mlir/Dialect/MPI/IR/MPI.td index 643612e..7c84443 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPI.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPI.td @@ -215,4 +215,43 @@ def MPI_ErrorClassAttr : EnumAttr<MPI_Dialect, MPI_ErrorClassEnum, "errclass"> { let assemblyFormat = "`<` $value `>`"; } +def MPI_OpNull : I32EnumAttrCase<"MPI_OP_NULL", 0, "MPI_OP_NULL">; +def MPI_OpMax : I32EnumAttrCase<"MPI_MAX", 1, "MPI_MAX">; +def MPI_OpMin : I32EnumAttrCase<"MPI_MIN", 2, "MPI_MIN">; +def MPI_OpSum : I32EnumAttrCase<"MPI_SUM", 3, "MPI_SUM">; +def MPI_OpProd : I32EnumAttrCase<"MPI_PROD", 4, "MPI_PROD">; +def MPI_OpLand : I32EnumAttrCase<"MPI_LAND", 5, "MPI_LAND">; +def MPI_OpBand : I32EnumAttrCase<"MPI_BAND", 6, "MPI_BAND">; +def MPI_OpLor : I32EnumAttrCase<"MPI_LOR", 7, "MPI_LOR">; +def MPI_OpBor : I32EnumAttrCase<"MPI_BOR", 8, "MPI_BOR">; +def MPI_OpLxor : I32EnumAttrCase<"MPI_LXOR", 9, "MPI_LXOR">; +def MPI_OpBxor : I32EnumAttrCase<"MPI_BXOR", 10, "MPI_BXOR">; +def MPI_OpMinloc : I32EnumAttrCase<"MPI_MINLOC", 11, "MPI_MINLOC">; +def MPI_OpMaxloc : I32EnumAttrCase<"MPI_MAXLOC", 12, "MPI_MAXLOC">; +def MPI_OpReplace : I32EnumAttrCase<"MPI_REPLACE", 13, "MPI_REPLACE">; + +def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [ + MPI_OpNull, + MPI_OpMax, + MPI_OpMin, + MPI_OpSum, + MPI_OpProd, + MPI_OpLand, + MPI_OpBand, + MPI_OpLor, + MPI_OpBor, + MPI_OpLxor, + MPI_OpBxor, + MPI_OpMinloc, + MPI_OpMaxloc, + MPI_OpReplace + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::mpi"; +} + +def MPI_OpClassAttr : EnumAttr<MPI_Dialect, MPI_OpClassEnum, "opclass"> { + let assemblyFormat = "`<` $value `>`"; +} + #endif // MLIR_DIALECT_MPI_IR_MPI_TD diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index 240fac5..284ba72 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -60,6 
+60,28 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> { } //===----------------------------------------------------------------------===// +// CommSizeOp +//===----------------------------------------------------------------------===// + +def MPI_CommSizeOp : MPI_Op<"comm_size", []> { + let summary = "Get the size of the group associated to the communicator, " + "equivalent to `MPI_Comm_size(MPI_COMM_WORLD, &size)`"; + let description = [{ + Communicators other than `MPI_COMM_WORLD` are not supported for now. + + This operation can optionally return an `!mpi.retval` value that can be used + to check for errors. + }]; + + let results = ( + outs Optional<MPI_Retval> : $retval, + I32 : $size + ); + + let assemblyFormat = "attr-dict `:` type(results)"; +} + +//===----------------------------------------------------------------------===// // SendOp //===----------------------------------------------------------------------===// @@ -71,13 +93,17 @@ def MPI_SendOp : MPI_Op<"send", []> { `dest`. The `tag` value and communicator enables the library to determine the matching of multiple sends and receives between the same ranks. - Communicators other than `MPI_COMM_WORLD` are not supprted for now. + Communicators other than `MPI_COMM_WORLD` are not supported for now. This operation can optionally return an `!mpi.retval` value that can be used to check for errors. 
}]; - let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank); + let arguments = ( + ins AnyMemRef : $ref, + I32 : $tag, + I32 : $rank + ); let results = (outs Optional<MPI_Retval>:$retval); @@ -88,6 +114,42 @@ def MPI_SendOp : MPI_Op<"send", []> { } //===----------------------------------------------------------------------===// +// ISendOp +//===----------------------------------------------------------------------===// + +def MPI_ISendOp : MPI_Op<"isend", []> { + let summary = + "Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`"; + let description = [{ + MPI_Isend begins a non-blocking send of `size` elements of type `dtype` to + rank `dest`. The `tag` value and communicator enables the library to + determine the matching of multiple sends and receives between the same + ranks. + + Communicators other than `MPI_COMM_WORLD` are not supported for now. + + This operation can optionally return an `!mpi.retval` value that can be used + to check for errors. + }]; + + let arguments = ( + ins AnyMemRef : $ref, + I32 : $tag, + I32 : $rank + ); + + let results = ( + outs Optional<MPI_Retval>:$retval, + MPI_Request : $req + ); + + let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict " + "`:` type($ref) `,` type($tag) `,` type($rank) " + "`->` type(results)"; + let hasCanonicalizer = 1; +} + +//===----------------------------------------------------------------------===// // RecvOp //===----------------------------------------------------------------------===// @@ -100,7 +162,7 @@ def MPI_RecvOp : MPI_Op<"recv", []> { determine the matching of multiple sends and receives between the same ranks. - Communicators other than `MPI_COMM_WORLD` are not supprted for now. + Communicators other than `MPI_COMM_WORLD` are not supported for now. The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object is not yet ported to MLIR. @@ -108,16 +170,134 @@ def MPI_RecvOp : MPI_Op<"recv", []> { to check for errors. 
}]; - let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank); + let arguments = ( + ins AnyMemRef : $ref, + I32 : $tag, I32 : $rank + ); let results = (outs Optional<MPI_Retval>:$retval); - let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` " + let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:`" "type($ref) `,` type($tag) `,` type($rank)" "(`->` type($retval)^)?"; let hasCanonicalizer = 1; } +//===----------------------------------------------------------------------===// +// IRecvOp +//===----------------------------------------------------------------------===// + +def MPI_IRecvOp : MPI_Op<"irecv", []> { + let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, dest, tag, " + "MPI_COMM_WORLD, &req)`"; + let description = [{ + MPI_Irecv begins a non-blocking receive of `size` elements of type `dtype` + from rank `dest`. The `tag` value and communicator enables the library to + determine the matching of multiple sends and receives between the same + ranks. + + Communicators other than `MPI_COMM_WORLD` are not supported for now. + + This operation can optionally return an `!mpi.retval` value that can be used + to check for errors. 
+ }]; + + let arguments = ( + ins AnyMemRef : $ref, + I32 : $tag, + I32 : $rank + ); + + let results = ( + outs Optional<MPI_Retval>:$retval, + MPI_Request : $req + ); + + let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:`" + "type($ref) `,` type($tag) `,` type($rank) `->`" + "type(results)"; + let hasCanonicalizer = 1; +} + +//===----------------------------------------------------------------------===// +// AllReduceOp +//===----------------------------------------------------------------------===// + +def MPI_AllReduceOp : MPI_Op<"allreduce", []> { + let summary = "Equivalent to `MPI_Allreduce(sendbuf, recvbuf, op, " + "MPI_COMM_WORLD)`"; + let description = [{ + MPI_Allreduce performs a reduction operation on the values in the sendbuf + array and stores the result in the recvbuf array. The operation is + performed across all processes in the communicator. + + The `op` attribute specifies the reduction operation to be performed. + Currently only the `MPI_Op` predefined in the standard (e.g. `MPI_SUM`) are + supported. + + Communicators other than `MPI_COMM_WORLD` are not supported for now. + + This operation can optionally return an `!mpi.retval` value that can be used + to check for errors. + }]; + + let arguments = ( + ins AnyMemRef : $sendbuf, + AnyMemRef : $recvbuf, + MPI_OpClassAttr : $op + ); + + let results = (outs Optional<MPI_Retval>:$retval); + + let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op `)` attr-dict `:`" + "type($sendbuf) `,` type($recvbuf)" + "(`->` type($retval)^)?"; +} + +//===----------------------------------------------------------------------===// +// BarrierOp +//===----------------------------------------------------------------------===// + +def MPI_Barrier : MPI_Op<"barrier", []> { + let summary = "Equivalent to `MPI_Barrier(MPI_COMM_WORLD)`"; + let description = [{ + MPI_Barrier blocks execution until all processes in the communicator have + reached this routine. 
+ + Communicators other than `MPI_COMM_WORLD` are not supported for now. + + This operation can optionally return an `!mpi.retval` value that can be used + to check for errors. + }]; + + let results = (outs Optional<MPI_Retval>:$retval); + + let assemblyFormat = "attr-dict (`:` type($retval) ^)?"; +} + +//===----------------------------------------------------------------------===// +// WaitOp +//===----------------------------------------------------------------------===// + +def MPI_Wait : MPI_Op<"wait", []> { + let summary = "Equivalent to `MPI_Wait(req, MPI_STATUS_IGNORE)`"; + let description = [{ + MPI_Wait blocks execution until the request has completed. + + The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object + is not yet ported to MLIR. + + This operation can optionally return an `!mpi.retval` value that can be used + to check for errors. + }]; + + let arguments = (ins MPI_Request : $req); + + let results = (outs Optional<MPI_Retval>:$retval); + + let assemblyFormat = "`(` $req `)` attr-dict `:` type($req) " + "(`->` type($retval) ^)?"; +} //===----------------------------------------------------------------------===// // FinalizeOp @@ -139,7 +319,6 @@ def MPI_FinalizeOp : MPI_Op<"finalize", []> { let assemblyFormat = "attr-dict (`:` type($retval)^)?"; } - //===----------------------------------------------------------------------===// // RetvalCheckOp //===----------------------------------------------------------------------===// @@ -163,10 +342,8 @@ def MPI_RetvalCheckOp : MPI_Op<"retval_check", []> { let assemblyFormat = "$val `=` $errclass attr-dict `:` type($res)"; } - - //===----------------------------------------------------------------------===// -// RetvalCheckOp +// ErrorClassOp //===----------------------------------------------------------------------===// def MPI_ErrorClassOp : MPI_Op<"error_class", []> { diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td index 
87eefa7..fafea0e 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td @@ -40,4 +40,26 @@ def MPI_Retval : MPI_Type<"Retval", "retval"> { }]; } +//===----------------------------------------------------------------------===// +// mpi::RequestType +//===----------------------------------------------------------------------===// + +def MPI_Request : MPI_Type<"Request", "request"> { + let summary = "MPI asynchronous request handler"; + let description = [{ + This type represents a handler to an asynchronous request. + }]; +} + +//===----------------------------------------------------------------------===// +// mpi::StatusType +//===----------------------------------------------------------------------===// + +def MPI_Status : MPI_Type<"Status", "status"> { + let summary = "MPI reception operation status type"; + let description = [{ + This type represents the status of a reception operation. + }]; +} + #endif // MLIR_DIALECT_MPI_IR_MPITYPES_TD diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp index dcb55d8..56d8edf 100644 --- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp +++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp @@ -53,6 +53,16 @@ void mlir::mpi::RecvOp::getCanonicalizationPatterns( results.add<FoldCast<mlir::mpi::RecvOp>>(context); } +void mlir::mpi::ISendOp::getCanonicalizationPatterns( + mlir::RewritePatternSet &results, mlir::MLIRContext *context) { + results.add<FoldCast<mlir::mpi::ISendOp>>(context); +} + +void mlir::mpi::IRecvOp::getCanonicalizationPatterns( + mlir::RewritePatternSet &results, mlir::MLIRContext *context) { + results.add<FoldCast<mlir::mpi::IRecvOp>>(context); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/MPI/ops.mlir b/mlir/test/Dialect/MPI/ops.mlir index 8f2421a..f23a7e18 100644 
--- a/mlir/test/Dialect/MPI/ops.mlir +++ b/mlir/test/Dialect/MPI/ops.mlir @@ -9,6 +9,9 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () { // CHECK-NEXT: %retval, %rank = mpi.comm_rank : !mpi.retval, i32 %retval, %rank = mpi.comm_rank : !mpi.retval, i32 + // CHECK-NEXT: %retval_0, %size = mpi.comm_size : !mpi.retval, i32 + %retval_0, %size = mpi.comm_size : !mpi.retval, i32 + // CHECK-NEXT: mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32 @@ -21,13 +24,43 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () { // CHECK-NEXT: %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval %err3 = mpi.recv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval - // CHECK-NEXT: %3 = mpi.finalize : !mpi.retval + // CHECK-NEXT: %req = mpi.isend(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request + %req = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request + + // CHECK-NEXT: %retval_1, %req_2 = mpi.isend(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request + %err4, %req2 = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request + + // CHECK-NEXT: %req_3 = mpi.irecv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request + %req3 = mpi.irecv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request + + // CHECK-NEXT: %retval_4, %req_5 = mpi.irecv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request + %err5, %req4 = mpi.irecv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request + + // CHECK-NEXT: mpi.wait(%req) : !mpi.request + mpi.wait(%req) : !mpi.request + + // CHECK-NEXT: %3 = mpi.wait(%req_2) : !mpi.request -> !mpi.retval + %err6 = mpi.wait(%req2) : !mpi.request -> !mpi.retval + + // CHECK-NEXT: mpi.barrier : !mpi.retval + mpi.barrier : !mpi.retval + + // CHECK-NEXT: %5 = mpi.barrier : !mpi.retval + %err7 = mpi.barrier : 
!mpi.retval + + // CHECK-NEXT: mpi.allreduce(%arg0, %arg0, <MPI_SUM>) : memref<100xf32>, memref<100xf32> + mpi.allreduce(%ref, %ref, <MPI_SUM>) : memref<100xf32>, memref<100xf32> + + // CHECK-NEXT: mpi.allreduce(%arg0, %arg0, <MPI_SUM>) : memref<100xf32>, memref<100xf32> -> !mpi.retval + %err8 = mpi.allreduce(%ref, %ref, <MPI_SUM>) : memref<100xf32>, memref<100xf32> -> !mpi.retval + + // CHECK-NEXT: %7 = mpi.finalize : !mpi.retval %rval = mpi.finalize : !mpi.retval - // CHECK-NEXT: %4 = mpi.retval_check %retval = <MPI_SUCCESS> : i1 + // CHECK-NEXT: %8 = mpi.retval_check %retval = <MPI_SUCCESS> : i1 %res = mpi.retval_check %retval = <MPI_SUCCESS> : i1 - // CHECK-NEXT: %5 = mpi.error_class %0 : !mpi.retval + // CHECK-NEXT: %9 = mpi.error_class %0 : !mpi.retval %errclass = mpi.error_class %err : !mpi.retval // CHECK-NEXT: return |