author     Frank Schlimbach <frank.schlimbach@intel.com>   2025-04-01 08:58:55 +0200
committer  GitHub <noreply@github.com>                     2025-04-01 08:58:55 +0200
commit     49f080afc4466ddf415d7fc7e98989c0bd07d8ea (patch)
tree       925581a9985501d1a81a34b0cf5999025e38ad52
parent     aa889ed129ff26d9341c50a9eaba4db728ca6212 (diff)
[mlir][mpi] Mandatory Communicator (#133280)
This replaces #125361.
- the communicator operand is now mandatory
- new mpi.comm_world
- new mpi.comm_split
- lowering and tests
---------
Co-authored-by: Sergio Sánchez Ramírez <sergio.sanchez.ramirez+git@bsc.es>
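For orientation before the diff, a minimal usage sketch of the reworked dialect follows. The op syntax is taken from the tests in this patch (mlir/test/Dialect/MPI/mpiops.mlir); the function name and buffer shape are illustrative only.

// Every MPI op now takes an explicit communicator operand.
func.func @comm_usage(%buf : memref<100xf32>) {
  // New op: materialize the predefined world communicator.
  %comm = mpi.comm_world : !mpi.comm
  // comm_rank/comm_size now take the communicator explicitly.
  %rank = mpi.comm_rank(%comm) : i32
  %size = mpi.comm_size(%comm) : i32
  // New op: partition into disjoint sub-communicators by color/key.
  %sub = mpi.comm_split(%comm, %rank, %rank) : !mpi.comm
  // Point-to-point ops take the communicator as a trailing operand.
  mpi.send(%buf, %rank, %rank, %comm) : memref<100xf32>, i32, i32
  mpi.recv(%buf, %rank, %rank, %comm) : memref<100xf32>, i32, i32
  return
}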
-rw-r--r--  mlir/include/mlir/Dialect/MPI/IR/MPIOps.td               | 132
-rw-r--r--  mlir/include/mlir/Dialect/MPI/IR/MPITypes.td             |  11
-rw-r--r--  mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp              | 132
-rw-r--r--  mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp              |  24
-rw-r--r--  mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir            | 116
-rw-r--r--  mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir  |  62
-rw-r--r--  mlir/test/Dialect/MPI/mpiops.mlir                        |  87
7 files changed, 389 insertions, 175 deletions
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index a8267b1..d78aa92 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -38,25 +38,40 @@ def MPI_InitOp : MPI_Op<"init", []> { } //===----------------------------------------------------------------------===// +// CommWorldOp +//===----------------------------------------------------------------------===// + +def MPI_CommWorldOp : MPI_Op<"comm_world", []> { + let summary = "Get the World communicator, equivalent to `MPI_COMM_WORLD`"; + let description = [{ + This operation returns the predefined MPI_COMM_WORLD communicator. + }]; + + let results = (outs MPI_Comm : $comm); + + let assemblyFormat = "attr-dict `:` type(results)"; +} + +//===----------------------------------------------------------------------===// // CommRankOp //===----------------------------------------------------------------------===// def MPI_CommRankOp : MPI_Op<"comm_rank", []> { let summary = "Get the current rank, equivalent to " - "`MPI_Comm_rank(MPI_COMM_WORLD, &rank)`"; + "`MPI_Comm_rank(comm, &rank)`"; let description = [{ - Communicators other than `MPI_COMM_WORLD` are not supported for now. - This operation can optionally return an `!mpi.retval` value that can be used to check for errors. }]; + let arguments = (ins MPI_Comm : $comm); + let results = ( outs Optional<MPI_Retval> : $retval, I32 : $rank ); - let assemblyFormat = "attr-dict `:` type(results)"; + let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)"; } //===----------------------------------------------------------------------===// @@ -65,20 +80,48 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> { def MPI_CommSizeOp : MPI_Op<"comm_size", []> { let summary = "Get the size of the group associated to the communicator, " - "equivalent to `MPI_Comm_size(MPI_COMM_WORLD, &size)`"; + "equivalent to `MPI_Comm_size(comm, &size)`"; let description = [{ - Communicators other than `MPI_COMM_WORLD` are not supported for now. - This operation can optionally return an `!mpi.retval` value that can be used to check for errors. }]; + let arguments = (ins MPI_Comm : $comm); + let results = ( outs Optional<MPI_Retval> : $retval, I32 : $size ); - let assemblyFormat = "attr-dict `:` type(results)"; + let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)"; +} + +//===----------------------------------------------------------------------===// +// CommSplitOp +//===----------------------------------------------------------------------===// + +def MPI_CommSplitOp : MPI_Op<"comm_split", []> { + let summary = "Partition the group associated with the given communicator into " + "disjoint subgroups"; + let description = [{ + This operation splits the communicator into multiple sub-communicators. + The color value determines the group of processes that will be part of the + new communicator. The key value determines the rank of the calling process + in the new communicator. + + This operation can optionally return an `!mpi.retval` value that can be used + to check for errors. 
+ }]; + + let arguments = (ins MPI_Comm : $comm, I32 : $color, I32 : $key); + + let results = ( + outs Optional<MPI_Retval> : $retval, + MPI_Comm : $newcomm + ); + + let assemblyFormat = "`(` $comm `,` $color `,` $key `)` attr-dict `:` " + "type(results)"; } //===----------------------------------------------------------------------===// @@ -87,14 +130,12 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> { def MPI_SendOp : MPI_Op<"send", []> { let summary = - "Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`"; + "Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, comm)`"; let description = [{ MPI_Send performs a blocking send of `size` elements of type `dtype` to rank `dest`. The `tag` value and communicator enables the library to determine the matching of multiple sends and receives between the same ranks. - Communicators other than `MPI_COMM_WORLD` are not supported for now. - This operation can optionally return an `!mpi.retval` value that can be used to check for errors. }]; @@ -102,12 +143,13 @@ def MPI_SendOp : MPI_Op<"send", []> { let arguments = ( ins AnyMemRef : $ref, I32 : $tag, - I32 : $dest + I32 : $dest, + MPI_Comm : $comm ); let results = (outs Optional<MPI_Retval>:$retval); - let assemblyFormat = "`(` $ref `,` $tag `,` $dest `)` attr-dict `:` " + let assemblyFormat = "`(` $ref `,` $tag `,` $dest `,` $comm `)` attr-dict `:` " "type($ref) `,` type($tag) `,` type($dest)" "(`->` type($retval)^)?"; let hasCanonicalizer = 1; @@ -119,15 +161,13 @@ def MPI_SendOp : MPI_Op<"send", []> { def MPI_ISendOp : MPI_Op<"isend", []> { let summary = - "Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`"; + "Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, comm)`"; let description = [{ MPI_Isend begins a non-blocking send of `size` elements of type `dtype` to rank `dest`. The `tag` value and communicator enables the library to determine the matching of multiple sends and receives between the same ranks. - Communicators other than `MPI_COMM_WORLD` are not supported for now. - This operation can optionally return an `!mpi.retval` value that can be used to check for errors. }]; @@ -135,7 +175,8 @@ def MPI_ISendOp : MPI_Op<"isend", []> { let arguments = ( ins AnyMemRef : $ref, I32 : $tag, - I32 : $rank + I32 : $dest, + MPI_Comm : $comm ); let results = ( @@ -143,8 +184,8 @@ def MPI_ISendOp : MPI_Op<"isend", []> { MPI_Request : $req ); - let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict " - "`:` type($ref) `,` type($tag) `,` type($rank) " + let assemblyFormat = "`(` $ref `,` $tag `,` $dest `,` $comm`)` attr-dict " + "`:` type($ref) `,` type($tag) `,` type($dest) " "`->` type(results)"; let hasCanonicalizer = 1; } @@ -155,14 +196,13 @@ def MPI_ISendOp : MPI_Op<"isend", []> { def MPI_RecvOp : MPI_Op<"recv", []> { let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, source, tag, " - "MPI_COMM_WORLD, MPI_STATUS_IGNORE)`"; + "comm, MPI_STATUS_IGNORE)`"; let description = [{ MPI_Recv performs a blocking receive of `size` elements of type `dtype` from rank `source`. The `tag` value and communicator enables the library to determine the matching of multiple sends and receives between the same ranks. - Communicators other than `MPI_COMM_WORLD` are not supported for now. The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object is not yet ported to MLIR. 
@@ -172,13 +212,14 @@ def MPI_RecvOp : MPI_Op<"recv", []> { let arguments = ( ins AnyMemRef : $ref, - I32 : $tag, I32 : $source + I32 : $tag, I32 : $source, + MPI_Comm : $comm ); let results = (outs Optional<MPI_Retval>:$retval); - let assemblyFormat = "`(` $ref `,` $tag `,` $source `)` attr-dict `:` " - "type($ref) `,` type($tag) `,` type($source)" + let assemblyFormat = "`(` $ref `,` $tag `,` $source `,` $comm `)` attr-dict" + " `:` type($ref) `,` type($tag) `,` type($source) " "(`->` type($retval)^)?"; let hasCanonicalizer = 1; } @@ -188,16 +229,14 @@ def MPI_RecvOp : MPI_Op<"recv", []> { //===----------------------------------------------------------------------===// def MPI_IRecvOp : MPI_Op<"irecv", []> { - let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, dest, tag, " - "MPI_COMM_WORLD, &req)`"; + let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, source, tag, " + "comm, &req)`"; let description = [{ MPI_Irecv begins a non-blocking receive of `size` elements of type `dtype` - from rank `dest`. The `tag` value and communicator enables the library to + from rank `source`. The `tag` value and communicator enables the library to determine the matching of multiple sends and receives between the same ranks. - Communicators other than `MPI_COMM_WORLD` are not supported for now. - This operation can optionally return an `!mpi.retval` value that can be used to check for errors. }]; @@ -205,7 +244,8 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> { let arguments = ( ins AnyMemRef : $ref, I32 : $tag, - I32 : $rank + I32 : $source, + MPI_Comm : $comm ); let results = ( @@ -213,9 +253,9 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> { MPI_Request : $req ); - let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:`" - "type($ref) `,` type($tag) `,` type($rank) `->`" - "type(results)"; + let assemblyFormat = "`(` $ref `,` $tag `,` $source `,` $comm`)` attr-dict " + "`:` type($ref) `,` type($tag) `,` type($source)" + "`->` type(results)"; let hasCanonicalizer = 1; } @@ -224,8 +264,7 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> { //===----------------------------------------------------------------------===// def MPI_AllReduceOp : MPI_Op<"allreduce", []> { - let summary = "Equivalent to `MPI_Allreduce(sendbuf, recvbuf, op, " - "MPI_COMM_WORLD)`"; + let summary = "Equivalent to `MPI_Allreduce(sendbuf, recvbuf, op, comm)`"; let description = [{ MPI_Allreduce performs a reduction operation on the values in the sendbuf array and stores the result in the recvbuf array. The operation is @@ -235,8 +274,6 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> { Currently only the `MPI_Op` predefined in the standard (e.g. `MPI_SUM`) are supported. - Communicators other than `MPI_COMM_WORLD` are not supported for now. - This operation can optionally return an `!mpi.retval` value that can be used to check for errors. 
}]; @@ -244,13 +281,14 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> { let arguments = ( ins AnyMemRef : $sendbuf, AnyMemRef : $recvbuf, - MPI_OpClassEnum : $op + MPI_OpClassEnum : $op, + MPI_Comm : $comm ); let results = (outs Optional<MPI_Retval>:$retval); - let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op `)` attr-dict `:`" - "type($sendbuf) `,` type($recvbuf)" + let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op `,` $comm `)` " + "attr-dict `:` type($sendbuf) `,` type($recvbuf) " "(`->` type($retval)^)?"; } @@ -259,20 +297,23 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> { //===----------------------------------------------------------------------===// def MPI_Barrier : MPI_Op<"barrier", []> { - let summary = "Equivalent to `MPI_Barrier(MPI_COMM_WORLD)`"; + let summary = "Equivalent to `MPI_Barrier(comm)`"; let description = [{ MPI_Barrier blocks execution until all processes in the communicator have reached this routine. - Communicators other than `MPI_COMM_WORLD` are not supported for now. - This operation can optionally return an `!mpi.retval` value that can be used to check for errors. }]; + let arguments = (ins MPI_Comm : $comm); + let results = (outs Optional<MPI_Retval>:$retval); - let assemblyFormat = "attr-dict (`:` type($retval) ^)?"; + let assemblyFormat = [{ + `(` $comm `)` attr-dict + (`->` type($retval)^)? + }]; } //===----------------------------------------------------------------------===// @@ -295,8 +336,7 @@ def MPI_Wait : MPI_Op<"wait", []> { let results = (outs Optional<MPI_Retval>:$retval); - let assemblyFormat = "`(` $req `)` attr-dict `:` type($req) " - "(`->` type($retval) ^)?"; + let assemblyFormat = "`(` $req `)` attr-dict `:` type($req) (`->` type($retval) ^)?"; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td index a55d30e..adc35a7 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td @@ -41,6 +41,17 @@ def MPI_Retval : MPI_Type<"Retval", "retval"> { } //===----------------------------------------------------------------------===// +// mpi::CommType +//===----------------------------------------------------------------------===// + +def MPI_Comm : MPI_Type<"Comm", "comm"> { + let summary = "MPI communicator handler"; + let description = [{ + This type represents a handler for the MPI communicator. + }]; +} + +//===----------------------------------------------------------------------===// // mpi::RequestType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp index 4e0f593..9df5e99 100644 --- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp +++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp @@ -83,9 +83,17 @@ public: ModuleOp &getModuleOp() { return moduleOp; } /// Gets or creates MPI_COMM_WORLD as a Value. + /// Different MPI implementations have different communicator types. + /// Using i64 as a portable, intermediate type. + /// Appropriate cast needs to take place before calling MPI functions. virtual Value getCommWorld(const Location loc, ConversionPatternRewriter &rewriter) = 0; + /// Type converter provides i64 type for communicator type. + /// Converts to native type, which might be ptr or int or whatever. 
+ virtual Value castComm(const Location loc, + ConversionPatternRewriter &rewriter, Value comm) = 0; + /// Get the MPI_STATUS_IGNORE value (typically a pointer type). virtual intptr_t getStatusIgnore() = 0; @@ -139,10 +147,15 @@ public: Value getCommWorld(const Location loc, ConversionPatternRewriter &rewriter) override { static constexpr int MPI_COMM_WORLD = 0x44000000; - return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), + return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), MPI_COMM_WORLD); } + Value castComm(const Location loc, ConversionPatternRewriter &rewriter, + Value comm) override { + return rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), comm); + } + intptr_t getStatusIgnore() override { return 1; } Value getDataType(const Location loc, ConversionPatternRewriter &rewriter, @@ -256,9 +269,16 @@ public: getOrDefineExternalStruct(loc, rewriter, name, commStructT); // get address of symbol - return rewriter.create<LLVM::AddressOfOp>( + auto comm = rewriter.create<LLVM::AddressOfOp>( loc, LLVM::LLVMPointerType::get(context), SymbolRefAttr::get(context, name)); + return rewriter.create<LLVM::PtrToIntOp>(loc, rewriter.getI64Type(), comm); + } + + Value castComm(const Location loc, ConversionPatternRewriter &rewriter, + Value comm) override { + return rewriter.create<LLVM::IntToPtrOp>( + loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm); } intptr_t getStatusIgnore() override { return 0; } @@ -441,6 +461,78 @@ struct FinalizeOpLowering : public ConvertOpToLLVMPattern<mpi::FinalizeOp> { }; //===----------------------------------------------------------------------===// +// CommWorldOpLowering +//===----------------------------------------------------------------------===// + +struct CommWorldOpLowering : public ConvertOpToLLVMPattern<mpi::CommWorldOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(mpi::CommWorldOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // grab a reference to the global module op: + auto moduleOp = op->getParentOfType<ModuleOp>(); + auto mpiTraits = MPIImplTraits::get(moduleOp); + // get MPI_COMM_WORLD + rewriter.replaceOp(op, mpiTraits->getCommWorld(op.getLoc(), rewriter)); + + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// CommSplitOpLowering +//===----------------------------------------------------------------------===// + +struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(mpi::CommSplitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // grab a reference to the global module op: + auto moduleOp = op->getParentOfType<ModuleOp>(); + auto mpiTraits = MPIImplTraits::get(moduleOp); + Type i32 = rewriter.getI32Type(); + Type ptrType = LLVM::LLVMPointerType::get(op->getContext()); + Location loc = op.getLoc(); + + // get communicator + Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm()); + auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1); + auto outPtr = + rewriter.create<LLVM::AllocaOp>(loc, ptrType, comm.getType(), one); + + // int MPI_Comm_split(MPI_Comm comm, int color, int key, MPI_Comm * newcomm) + auto funcType = + LLVM::LLVMFunctionType::get(i32, {comm.getType(), i32, i32, ptrType}); + // get or create function declaration: + LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(moduleOp, 
loc, rewriter, + "MPI_Comm_split", funcType); + + auto callOp = rewriter.create<LLVM::CallOp>( + loc, funcDecl, + ValueRange{comm, adaptor.getColor(), adaptor.getKey(), + outPtr.getRes()}); + + // load the communicator into a register + auto res = rewriter.create<LLVM::LoadOp>(loc, i32, outPtr.getResult()); + + // if retval is checked, replace uses of retval with the results from the + // call op + SmallVector<Value> replacements; + if (op.getRetval()) + replacements.push_back(callOp.getResult()); + + // replace op + replacements.push_back(res.getRes()); + rewriter.replaceOp(op, replacements); + + return success(); + } +}; + +//===----------------------------------------------------------------------===// // CommRankOpLowering //===----------------------------------------------------------------------===// @@ -462,21 +554,21 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> { auto moduleOp = op->getParentOfType<ModuleOp>(); auto mpiTraits = MPIImplTraits::get(moduleOp); - // get MPI_COMM_WORLD - Value commWorld = mpiTraits->getCommWorld(loc, rewriter); + // get communicator + Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm()); // LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)` auto rankFuncType = - LLVM::LLVMFunctionType::get(i32, {commWorld.getType(), ptrType}); + LLVM::LLVMFunctionType::get(i32, {comm.getType(), ptrType}); // get or create function declaration: LLVM::LLVMFuncOp initDecl = getOrDefineFunction( moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType); - // replace init with function call + // replace with function call auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1); auto rankptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one); auto callOp = rewriter.create<LLVM::CallOp>( - loc, initDecl, ValueRange{commWorld, rankptr.getRes()}); + loc, initDecl, ValueRange{comm, rankptr.getRes()}); // load the rank into a register auto loadedRank = @@ -523,12 +615,12 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> { getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType); auto mpiTraits = MPIImplTraits::get(moduleOp); Value dataType = mpiTraits->getDataType(loc, rewriter, elemType); - Value commWorld = mpiTraits->getCommWorld(loc, rewriter); + Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm()); // LLVM Function type representing `i32 MPI_send(data, count, datatype, dst, // tag, comm)` auto funcType = LLVM::LLVMFunctionType::get( - i32, {ptrType, i32, dataType.getType(), i32, i32, commWorld.getType()}); + i32, {ptrType, i32, dataType.getType(), i32, i32, comm.getType()}); // get or create function declaration: LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Send", funcType); @@ -537,7 +629,7 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> { auto funcCall = rewriter.create<LLVM::CallOp>( loc, funcDecl, ValueRange{dataPtr, size, dataType, adaptor.getDest(), adaptor.getTag(), - commWorld}); + comm}); if (op.getRetval()) rewriter.replaceOp(op, funcCall.getResult()); else @@ -575,7 +667,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> { getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType); auto mpiTraits = MPIImplTraits::get(moduleOp); Value dataType = mpiTraits->getDataType(loc, rewriter, elemType); - Value commWorld = mpiTraits->getCommWorld(loc, rewriter); + Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm()); Value statusIgnore = rewriter.create<LLVM::ConstantOp>( loc, i64, 
mpiTraits->getStatusIgnore()); statusIgnore = @@ -585,7 +677,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> { // tag, comm)` auto funcType = LLVM::LLVMFunctionType::get(i32, {ptrType, i32, dataType.getType(), i32, - i32, commWorld.getType(), ptrType}); + i32, comm.getType(), ptrType}); // get or create function declaration: LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Recv", funcType); @@ -594,7 +686,7 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> { auto funcCall = rewriter.create<LLVM::CallOp>( loc, funcDecl, ValueRange{dataPtr, size, dataType, adaptor.getSource(), - adaptor.getTag(), commWorld, statusIgnore}); + adaptor.getTag(), comm, statusIgnore}); if (op.getRetval()) rewriter.replaceOp(op, funcCall.getResult()); else @@ -629,7 +721,8 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> { getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType); Value dataType = mpiTraits->getDataType(loc, rewriter, elemType); Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp()); - Value commWorld = mpiTraits->getCommWorld(loc, rewriter); + Value commWorld = mpiTraits->castComm(loc, rewriter, adaptor.getComm()); + // 'int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count, // MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)' auto funcType = LLVM::LLVMFunctionType::get( @@ -676,8 +769,15 @@ struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface { void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns) { - patterns.add<CommRankOpLowering, FinalizeOpLowering, InitOpLowering, - SendOpLowering, RecvOpLowering, AllReduceOpLowering>(converter); + // Using i64 as a portable, intermediate type for !mpi.comm. + // It would be nicer to somehow get the right type directly, but TLDI is not + // available here. + converter.addConversion([](mpi::CommType type) { + return IntegerType::get(type.getContext(), 64); + }); + patterns.add<CommRankOpLowering, CommSplitOpLowering, CommWorldOpLowering, + FinalizeOpLowering, InitOpLowering, SendOpLowering, + RecvOpLowering, AllReduceOpLowering>(converter); } void mpi::registerConvertMPIToLLVMInterface(DialectRegistry ®istry) { diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp index 87c2938..cafbf83 100644 --- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp +++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp @@ -310,11 +310,16 @@ public: } // Otherwise call create mpi::CommRankOp - auto rank = rewriter - .create<mpi::CommRankOp>( - loc, TypeRange{mpi::RetvalType::get(op->getContext()), - rewriter.getI32Type()}) - .getRank(); + auto ctx = op.getContext(); + Value commWorld = + rewriter.create<mpi::CommWorldOp>(loc, mpi::CommType::get(ctx)); + auto rank = + rewriter + .create<mpi::CommRankOp>( + loc, + TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()}, + commWorld) + .getRank(); rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(), rank); return success(); @@ -652,6 +657,9 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { auto upperSendOffset = rewriter.create<arith::SubIOp>( loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2])); + Value commWorld = rewriter.create<mpi::CommWorldOp>( + loc, mpi::CommType::get(op->getContext())); + // Make sure we send/recv in a way that does not lead to a dead-lock. 
// The current approach is by far not optimal, this should be at least // be a red-black pattern or using MPI_sendrecv. @@ -680,7 +688,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { auto subview = builder.create<memref::SubViewOp>( loc, array, offsets, dimSizes, strides); builder.create<memref::CopyOp>(loc, subview, buffer); - builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to); + builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to, + commWorld); builder.create<scf::YieldOp>(loc); }); // if has neighbor: receive halo data into buffer and copy to array @@ -688,7 +697,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { loc, hasFrom, [&](OpBuilder &builder, Location loc) { offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset) : OpFoldResult(lowerRecvOffset); - builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from); + builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from, + commWorld); auto subview = builder.create<memref::SubViewOp>( loc, array, offsets, dimSizes, strides); builder.create<memref::CopyOp>(loc, buffer, subview); diff --git a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir index b630ce3..174f7c7 100644 --- a/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir +++ b/mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir @@ -3,6 +3,7 @@ // COM: Test MPICH ABI // CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} { // CHECK: llvm.func @MPI_Finalize() -> i32 +// CHECK: llvm.func @MPI_Comm_split(i32, i32, i32, !llvm.ptr) -> i32 // CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 // CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, i32, i32, i32, i32) -> i32 // CHECK: llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32 @@ -22,11 +23,14 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} { // CHECK: [[v7:%.*]] = llvm.call @MPI_Init([[v6]], [[v6]]) : (!llvm.ptr, !llvm.ptr) -> i32 %0 = mpi.init : !mpi.retval - // CHECK: [[v8:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32 + // CHECK: [[comm:%.*]] = llvm.mlir.constant(1140850688 : i64) : i64 + %comm = mpi.comm_world : !mpi.comm + + // CHECK: [[v8:%.*]] = llvm.trunc [[comm]] : i64 to i32 // CHECK: [[v9:%.*]] = llvm.mlir.constant(1 : i32) : i32 // CHECK: [[v10:%.*]] = llvm.alloca [[v9]] x i32 : (i32) -> !llvm.ptr // CHECK: [[v11:%.*]] = llvm.call @MPI_Comm_rank([[v8]], [[v10]]) : (i32, !llvm.ptr) -> i32 - %retval, %rank = mpi.comm_rank : !mpi.retval, i32 + %retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32 // CHECK: [[v12:%.*]] = llvm.load [[v10]] : !llvm.ptr -> i32 // CHECK: [[v13:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> @@ -35,9 +39,9 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} { // CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: [[v17:%.*]] = llvm.trunc [[v16]] : i64 to i32 // CHECK: [[v18:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32 - // CHECK: [[v19:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32 - // CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[v19]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32 - mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 + // CHECK: [[comm_1:%.*]] = llvm.trunc [[comm]] : i64 to i32 + // CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], 
[[comm_1]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32 + mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 // CHECK: [[v21:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: [[v22:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> @@ -45,9 +49,9 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} { // CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: [[v25:%.*]] = llvm.trunc [[v24]] : i64 to i32 // CHECK: [[v26:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32 - // CHECK: [[v27:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32 - // CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[v27]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32 - %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval + // CHECK: [[comm_2:%.*]] = llvm.trunc [[comm]] : i64 to i32 + // CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[comm_2]]) : (!llvm.ptr, i32, i32, i32, i32, i32) -> i32 + %1 = mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval // CHECK: [[v29:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: [[v30:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> @@ -55,11 +59,11 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} { // CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: [[v33:%.*]] = llvm.trunc [[v32]] : i64 to i32 // CHECK: [[v34:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32 - // CHECK: [[v35:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32 + // CHECK: [[comm_3:%.*]] = llvm.trunc [[comm]] : i64 to i32 // CHECK: [[v36:%.*]] = llvm.mlir.constant(1 : i64) : i64 // CHECK: [[v37:%.*]] = llvm.inttoptr [[v36]] : i64 to !llvm.ptr - // CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[v35]], [[v37]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 - mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 + // CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[comm_3]], [[v37]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 + mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 // CHECK: [[v39:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: [[v40:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> @@ -67,27 +71,38 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} { // CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: [[v43:%.*]] = llvm.trunc [[v42]] : i64 to i32 // CHECK: [[v44:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32 - // CHECK: [[v45:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32 + // CHECK: [[comm_4:%.*]] = llvm.trunc [[comm]] : i64 to i32 // CHECK: [[v46:%.*]] = llvm.mlir.constant(1 : i64) : i64 // CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr - // CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> 
i32 - %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval - - // CHECK: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v51:%.*]] = llvm.getelementptr [[v49]][[[v50]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 - // CHECK: [[v52:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v53:%.*]] = llvm.trunc [[v52]] : i64 to i32 - // CHECK: [[v54:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v55:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v56:%.*]] = llvm.getelementptr [[v54]][[[v55]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 - // CHECK: [[v57:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32 - // CHECK: [[v59:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32 - // CHECK: [[v60:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32 - // CHECK: [[v61:%.*]] = llvm.mlir.constant(1140850688 : i32) : i32 - // CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[v51]], [[v56]], [[v53]], [[v59]], [[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32 - mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32> + // CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[comm_4]], [[v47]]) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 + %2 = mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval + + // CHECK: [[v51:%.*]] = llvm.mlir.constant(10 : i32) : i32 + %color = arith.constant 10 : i32 + // CHECK: [[v52:%.*]] = llvm.mlir.constant(22 : i32) : i32 + %key = arith.constant 22 : i32 + // CHECK: [[v53:%.*]] = llvm.trunc [[comm]] : i64 to i32 + // CHECK: [[v54:%.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: [[v55:%.*]] = llvm.alloca [[v54]] x i32 : (i32) -> !llvm.ptr + // CHECK: [[v56:%.*]] = llvm.call @MPI_Comm_split([[v53]], [[v51]], [[v52]], [[v55]]) : (i32, i32, i32, !llvm.ptr) -> i32 + // CHECK: [[v57:%.*]] = llvm.load [[v55]] : !llvm.ptr -> i32 + %split = mpi.comm_split(%comm, %color, %key) : !mpi.comm + + // CHECK: [[v59:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: [[v60:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: [[v61:%.*]] = llvm.getelementptr [[v59]][[[v60]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 + // CHECK: [[v62:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: [[v63:%.*]] = llvm.trunc [[v62]] : i64 to i32 + // CHECK: [[v64:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: [[v65:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: [[v66:%.*]] = llvm.getelementptr [[v64]][[[v65]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 + // CHECK: [[v67:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: [[v68:%.*]] = llvm.trunc [[v67]] : i64 to i32 + // CHECK: [[v69:%.*]] = llvm.mlir.constant(1275069450 : i32) : i32 + // CHECK: 
[[v70:%.*]] = llvm.mlir.constant(1476395011 : i32) : i32 + // CHECK: [[v71:%.*]] = llvm.trunc [[comm]] : i64 to i32 + // CHECK: [[v72:%.*]] = llvm.call @MPI_Allreduce([[v61]], [[v66]], [[v63]], [[v69]], [[v70]], [[v71]]) : (!llvm.ptr, !llvm.ptr, i32, i32, i32, i32) -> i32 + mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32> // CHECK: llvm.call @MPI_Finalize() : () -> i32 %3 = mpi.finalize : !mpi.retval @@ -101,6 +116,7 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "MPICH">} { // COM: Test OpenMPI ABI // CHECK: module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI">} { // CHECK: llvm.func @MPI_Finalize() -> i32 +// CHECK: llvm.func @MPI_Comm_split(!llvm.ptr, i32, i32, !llvm.ptr) -> i32 // CHECK: llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32 // CHECK: llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32 // CHECK: llvm.mlir.global external @ompi_mpi_float() {addr_space = 0 : i32} : !llvm.struct<"ompi_predefined_datatype_t", opaque> @@ -122,11 +138,14 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } { // CHECK: [[v7:%.*]] = llvm.call @MPI_Init([[v6]], [[v6]]) : (!llvm.ptr, !llvm.ptr) -> i32 %0 = mpi.init : !mpi.retval + %comm = mpi.comm_world : !mpi.comm // CHECK: [[v8:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr + // CHECK: [[comm:%.*]] = llvm.ptrtoint [[v8]] : !llvm.ptr to i64 + // CHECK: [[comm_1:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr // CHECK: [[v9:%.*]] = llvm.mlir.constant(1 : i32) : i32 // CHECK: [[v10:%.*]] = llvm.alloca [[v9]] x i32 : (i32) -> !llvm.ptr - // CHECK: [[v11:%.*]] = llvm.call @MPI_Comm_rank([[v8]], [[v10]]) : (!llvm.ptr, !llvm.ptr) -> i32 - %retval, %rank = mpi.comm_rank : !mpi.retval, i32 + // CHECK: [[v11:%.*]] = llvm.call @MPI_Comm_rank([[comm_1]], [[v10]]) : (!llvm.ptr, !llvm.ptr) -> i32 + %retval, %rank = mpi.comm_rank(%comm) : !mpi.retval, i32 // CHECK: [[v12:%.*]] = llvm.load [[v10]] : !llvm.ptr -> i32 // CHECK: [[v13:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> @@ -135,9 +154,9 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } { // CHECK: [[v16:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: [[v17:%.*]] = llvm.trunc [[v16]] : i64 to i32 // CHECK: [[v18:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr - // CHECK: [[v19:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr + // CHECK: [[v19:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr // CHECK: [[v20:%.*]] = llvm.call @MPI_Send([[v15]], [[v17]], [[v18]], [[v12]], [[v12]], [[v19]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32 - mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 + mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 // CHECK: [[v21:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: [[v22:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> @@ -145,9 +164,9 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } { // CHECK: [[v24:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: [[v25:%.*]] = llvm.trunc [[v24]] : i64 to i32 // CHECK: [[v26:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr - // CHECK: [[v27:%.*]] = 
llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr + // CHECK: [[v27:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr // CHECK: [[v28:%.*]] = llvm.call @MPI_Send([[v23]], [[v25]], [[v26]], [[v12]], [[v12]], [[v27]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32 - %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval + %1 = mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval // CHECK: [[v29:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: [[v30:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> @@ -155,11 +174,11 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } { // CHECK: [[v32:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: [[v33:%.*]] = llvm.trunc [[v32]] : i64 to i32 // CHECK: [[v34:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr - // CHECK: [[v35:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr + // CHECK: [[v35:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr // CHECK: [[v36:%.*]] = llvm.mlir.constant(0 : i64) : i64 // CHECK: [[v37:%.*]] = llvm.inttoptr [[v36]] : i64 to !llvm.ptr // CHECK: [[v38:%.*]] = llvm.call @MPI_Recv([[v31]], [[v33]], [[v34]], [[v12]], [[v12]], [[v35]], [[v37]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32 - mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 + mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 // CHECK: [[v39:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: [[v40:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> @@ -167,11 +186,11 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } { // CHECK: [[v42:%.*]] = llvm.extractvalue [[v5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: [[v43:%.*]] = llvm.trunc [[v42]] : i64 to i32 // CHECK: [[v44:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr - // CHECK: [[v45:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr + // CHECK: [[v45:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr // CHECK: [[v46:%.*]] = llvm.mlir.constant(0 : i64) : i64 // CHECK: [[v47:%.*]] = llvm.inttoptr [[v46]] : i64 to !llvm.ptr // CHECK: [[v48:%.*]] = llvm.call @MPI_Recv([[v41]], [[v43]], [[v44]], [[v12]], [[v12]], [[v45]], [[v47]]) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32 - %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval + %2 = mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval // CHECK: [[v49:%.*]] = llvm.extractvalue [[v5]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: [[v50:%.*]] = llvm.extractvalue [[v5]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> @@ -185,11 +204,22 @@ module attributes { dlti.map = #dlti.map<"MPI:Implementation" = "OpenMPI"> } { // CHECK: [[v58:%.*]] = llvm.trunc [[v57]] : i64 to i32 // CHECK: [[v59:%.*]] = llvm.mlir.addressof @ompi_mpi_float : !llvm.ptr // CHECK: [[v60:%.*]] = llvm.mlir.addressof @ompi_mpi_sum : !llvm.ptr - // CHECK: [[v61:%.*]] = llvm.mlir.addressof @ompi_mpi_comm_world : !llvm.ptr + // CHECK: [[v61:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr // CHECK: [[v62:%.*]] = llvm.call @MPI_Allreduce([[v51]], [[v56]], [[v53]], [[v59]], 
[[v60]], [[v61]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32 - mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32> + mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32> + + // CHECK: [[v71:%.*]] = llvm.mlir.constant(10 : i32) : i32 + %color = arith.constant 10 : i32 + // CHECK: [[v72:%.*]] = llvm.mlir.constant(22 : i32) : i32 + %key = arith.constant 22 : i32 + // CHECK: [[v73:%.*]] = llvm.inttoptr [[comm]] : i64 to !llvm.ptr + // CHECK: [[v74:%.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: [[v75:%.*]] = llvm.alloca [[v74]] x !llvm.ptr : (i32) -> !llvm.ptr + // CHECK: [[v76:%.*]] = llvm.call @MPI_Comm_split([[v73]], [[v71]], [[v72]], [[v75]]) : (!llvm.ptr, i32, i32, !llvm.ptr) -> i32 + // CHECK: [[v77:%.*]] = llvm.load [[v75]] : !llvm.ptr -> i32 + %split = mpi.comm_split(%comm, %color, %key) : !mpi.comm - // CHECK: [[v49:%.*]] = llvm.call @MPI_Finalize() : () -> i32 + // CHECK: llvm.call @MPI_Finalize() : () -> i32 %3 = mpi.finalize : !mpi.retval return diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir index 4e60c6f..23756bb 100644 --- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir +++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir @@ -4,7 +4,7 @@ // CHECK: mesh.mesh @mesh0 mesh.mesh @mesh0(shape = 3x4x5) func.func @process_multi_index() -> (index, index, index) { - // CHECK: mpi.comm_rank : !mpi.retval, i32 + // CHECK: mpi.comm_rank // CHECK-DAG: %[[v4:.*]] = arith.remsi // CHECK-DAG: %[[v0:.*]] = arith.remsi // CHECK-DAG: %[[v1:.*]] = arith.remsi @@ -15,7 +15,7 @@ func.func @process_multi_index() -> (index, index, index) { // CHECK-LABEL: func @process_linear_index func.func @process_linear_index() -> index { - // CHECK: %[[RES:.*]], %[[rank:.*]] = mpi.comm_rank : !mpi.retval, i32 + // CHECK: %[[RES:.*]], %[[rank:.*]] = mpi.comm_rank // CHECK: %[[cast:.*]] = arith.index_cast %[[rank]] : i32 to index %0 = mesh.process_linear_index on @mesh0 : index // CHECK: return %[[cast]] : index @@ -113,17 +113,17 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 1> } { // CHECK: [[vc91_i32:%.*]] = arith.constant 91 : i32 // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32 // CHECK-NEXT: [[vc2_i32:%.*]] = arith.constant 2 : i32 + // CHECK-NEXT: [[v0:%.*]] = mpi.comm_world : !mpi.comm // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x120x120xi8> - // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8 - // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<2x120x120xi8 - // CHECK-SAME: to memref<2x120x120xi8> - // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc2_i32]]) : memref<2x120x120xi8>, i32, i32 - // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc0_i32]]) : memref<2x120x120xi8>, i32, i32 - // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8 - // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<2x120x120xi8> to memref<2x120x120xi8 + // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>> + // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>> to memref<2x120x120xi8> + // CHECK-NEXT: mpi.send([[valloc]], 
[[vc91_i32]], [[vc2_i32]], [[v0]]) : memref<2x120x120xi8>, i32, i32 + // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc0_i32]], [[v0]]) : memref<2x120x120xi8>, i32, i32 + // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1]>> + // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1]>> // CHECK-NEXT: memref.dealloc [[valloc]] : memref<2x120x120xi8> %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 0] : memref<120x120x120xi8> - // CHECK: return [[res:%.*]] : memref<120x120x120xi8> + // CHECK: return [[varg0]] : memref<120x120x120xi8> return %res : memref<120x120x120xi8> } } @@ -140,41 +140,44 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } { // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32 // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32 // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32 + // CHECK-NEXT: [[v0:%.*]] = mpi.comm_world : !mpi.comm // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8> // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8> - // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32 - // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32 + // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]], [[v0]]) : memref<117x113x5xi8>, i32, i32 + // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]], [[v0]]) : memref<117x113x5xi8>, i32, i32 // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>> // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>> // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8> // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8> // CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> // CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8> - // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32 - // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32 + // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]], [[v0]]) : memref<117x113x6xi8>, i32, i32 + // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]], [[v0]]) : memref<117x113x6xi8>, i32, i32 // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>> // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>> // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8> + // CHECK-NEXT: [[v1:%.*]] = 
mpi.comm_world : !mpi.comm // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8> // CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[varg0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> // CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8> - // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32 + // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]], [[v1]]) : memref<117x3x120xi8>, i32, i32 // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8> // CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8> - // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32 + // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]], [[v1]]) : memref<117x4x120xi8>, i32, i32 // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>> // CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>> // CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8> + // CHECK-NEXT: [[v2:%.*]] = mpi.comm_world : !mpi.comm // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8> - // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32 + // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]], [[v2]]) : memref<1x120x120xi8>, i32, i32 // CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[varg0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>> // CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>> // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8> // CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8> // CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[varg0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> // CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8> - // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32 + // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]], [[v2]]) : memref<2x120x120xi8>, i32, i32 // CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8> %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8> // CHECK: return [[varg0]] : memref<120x120x120xi8> @@ -191,45 +194,48 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } { // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32 // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32 // CHECK-NEXT: [[v0:%.*]] = bufferization.to_memref [[varg0]] : tensor<120x120x120xi8> to memref<120x120x120xi8> + // CHECK-NEXT: [[v1:%.*]] = mpi.comm_world : !mpi.comm // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8> // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 
120, 1], offset: 14869>> // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8> - // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32 - // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32 + // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]], [[v1]]) : memref<117x113x5xi8>, i32, i32 + // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]], [[v1]]) : memref<117x113x5xi8>, i32, i32 // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>> // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>> // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8> // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8> // CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> // CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8> - // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32 - // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32 + // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]], [[v1]]) : memref<117x113x6xi8>, i32, i32 + // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]], [[v1]]) : memref<117x113x6xi8>, i32, i32 // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>> // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>> // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8> + // CHECK-NEXT: [[v2:%.*]] = mpi.comm_world : !mpi.comm // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8> // CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[v0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> // CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8> - // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32 + // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]], [[v2]]) : memref<117x3x120xi8>, i32, i32 // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8> // CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8> - // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32 + // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]], [[v2]]) : memref<117x4x120xi8>, i32, i32 // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>> // CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>> // CHECK-NEXT: memref.dealloc [[valloc_6]] 
: memref<117x4x120xi8> + // CHECK-NEXT: [[v3:%.*]] = mpi.comm_world : !mpi.comm // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8> - // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32 + // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]], [[v3]]) : memref<1x120x120xi8>, i32, i32 // CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[v0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>> // CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>> // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8> // CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8> // CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[v0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> // CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8> - // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32 + // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]], [[v3]]) : memref<2x120x120xi8>, i32, i32 // CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8> - // CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8> to tensor<120x120x120xi8> + // CHECK-NEXT: [[v4:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8> to tensor<120x120x120xi8> %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8> - // CHECK: return [[v1]] : tensor<120x120x120xi8> + // CHECK-NEXT: return [[v4]] : tensor<120x120x120xi8> return %res : tensor<120x120x120xi8> } } diff --git a/mlir/test/Dialect/MPI/mpiops.mlir b/mlir/test/Dialect/MPI/mpiops.mlir index fb43336..ef45762 100644 --- a/mlir/test/Dialect/MPI/mpiops.mlir +++ b/mlir/test/Dialect/MPI/mpiops.mlir @@ -1,66 +1,83 @@ // RUN: mlir-opt %s | mlir-opt | FileCheck %s +// CHECK-LABEL: func.func @mpi_test( +// CHECK-SAME: [[varg0:%.*]]: memref<100xf32>) { func.func @mpi_test(%ref : memref<100xf32>) -> () { // Note: the !mpi.retval result is optional on all operations except mpi.error_class - // CHECK: %0 = mpi.init : !mpi.retval + // CHECK-NEXT: [[v0:%.*]] = mpi.init : !mpi.retval %err = mpi.init : !mpi.retval - // CHECK-NEXT: %retval, %rank = mpi.comm_rank : !mpi.retval, i32 - %retval, %rank = mpi.comm_rank : !mpi.retval, i32 + // CHECK-NEXT: [[v1:%.*]] = mpi.comm_world : !mpi.comm + %comm = mpi.comm_world : !mpi.comm - // CHECK-NEXT: %retval_0, %size = mpi.comm_size : !mpi.retval, i32 - %retval_0, %size = mpi.comm_size : !mpi.retval, i32 + // CHECK-NEXT: [[vrank:%.*]] = mpi.comm_rank([[v1]]) : i32 + %rank = mpi.comm_rank(%comm) : i32 - // CHECK-NEXT: mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 - mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32 + // CHECK-NEXT: [[vretval:%.*]], [[vrank_0:%.*]] = mpi.comm_rank([[v1]]) : !mpi.retval, i32 + %retval, %rank_1 = mpi.comm_rank(%comm) : !mpi.retval, i32 - // CHECK-NEXT: %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval - %err2 = mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval + // CHECK-NEXT: [[vsize:%.*]] = mpi.comm_size([[v1]]) : i32 + %size = mpi.comm_size(%comm) : i32 - // 
diff --git a/mlir/test/Dialect/MPI/mpiops.mlir b/mlir/test/Dialect/MPI/mpiops.mlir
index fb43336..ef45762 100644
--- a/mlir/test/Dialect/MPI/mpiops.mlir
+++ b/mlir/test/Dialect/MPI/mpiops.mlir
@@ -1,66 +1,83 @@
 // RUN: mlir-opt %s | mlir-opt | FileCheck %s
 
+// CHECK-LABEL: func.func @mpi_test(
+// CHECK-SAME: [[varg0:%.*]]: memref<100xf32>) {
 func.func @mpi_test(%ref : memref<100xf32>) -> () {
     // Note: the !mpi.retval result is optional on all operations except mpi.error_class
-    // CHECK: %0 = mpi.init : !mpi.retval
+    // CHECK-NEXT: [[v0:%.*]] = mpi.init : !mpi.retval
     %err = mpi.init : !mpi.retval
 
-    // CHECK-NEXT: %retval, %rank = mpi.comm_rank : !mpi.retval, i32
-    %retval, %rank = mpi.comm_rank : !mpi.retval, i32
+    // CHECK-NEXT: [[v1:%.*]] = mpi.comm_world : !mpi.comm
+    %comm = mpi.comm_world : !mpi.comm
 
-    // CHECK-NEXT: %retval_0, %size = mpi.comm_size : !mpi.retval, i32
-    %retval_0, %size = mpi.comm_size : !mpi.retval, i32
+    // CHECK-NEXT: [[vrank:%.*]] = mpi.comm_rank([[v1]]) : i32
+    %rank = mpi.comm_rank(%comm) : i32
 
-    // CHECK-NEXT: mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
-    mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32
+    // CHECK-NEXT: [[vretval:%.*]], [[vrank_0:%.*]] = mpi.comm_rank([[v1]]) : !mpi.retval, i32
+    %retval, %rank_1 = mpi.comm_rank(%comm) : !mpi.retval, i32
 
-    // CHECK-NEXT: %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
-    %err2 = mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+    // CHECK-NEXT: [[vsize:%.*]] = mpi.comm_size([[v1]]) : i32
+    %size = mpi.comm_size(%comm) : i32
 
-    // CHECK-NEXT: mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32
-    mpi.recv(%ref, %rank, %rank) : memref<100xf32>, i32, i32
+    // CHECK-NEXT: [[vretval_1:%.*]], [[vsize_2:%.*]] = mpi.comm_size([[v1]]) : !mpi.retval, i32
+    %retval_0, %size_1 = mpi.comm_size(%comm) : !mpi.retval, i32
 
-    // CHECK-NEXT: %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
-    %err3 = mpi.recv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval
+    // CHECK-NEXT: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vrank]], [[vrank]]) : !mpi.comm
+    %new_comm = mpi.comm_split(%comm, %rank, %rank) : !mpi.comm
 
-    // CHECK-NEXT: %req = mpi.isend(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request
-    %req = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request
+    // CHECK-NEXT: [[vretval_3:%.*]], [[vnewcomm_4:%.*]] = mpi.comm_split([[v1]], [[vrank]], [[vrank]]) : !mpi.retval, !mpi.comm
+    %retval_1, %new_comm_1 = mpi.comm_split(%comm, %rank, %rank) : !mpi.retval, !mpi.comm
 
-    // CHECK-NEXT: %retval_1, %req_2 = mpi.isend(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
-    %err4, %req2 = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
+    // CHECK-NEXT: mpi.send([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32
+    mpi.send(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32
 
-    // CHECK-NEXT: %req_3 = mpi.irecv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request
-    %req3 = mpi.irecv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.request
+    // CHECK-NEXT: [[v2:%.*]] = mpi.send([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32 -> !mpi.retval
+    %retval_2 = mpi.send(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
 
-    // CHECK-NEXT: %retval_4, %req_5 = mpi.irecv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
-    %err5, %req4 = mpi.irecv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
+    // CHECK-NEXT: mpi.recv([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32
+    mpi.recv(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32
 
-    // CHECK-NEXT: mpi.wait(%req) : !mpi.request
-    mpi.wait(%req) : !mpi.request
+    // CHECK-NEXT: [[v3:%.*]] = mpi.recv([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32 -> !mpi.retval
+    %retval_3 = mpi.recv(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval
 
-    // CHECK-NEXT: %3 = mpi.wait(%req_2) : !mpi.request -> !mpi.retval
+    // CHECK-NEXT: [[vretval_5:%.*]], [[vreq:%.*]] = mpi.isend([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
+    %err4, %req2 = mpi.isend(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
+
+    // CHECK-NEXT: [[vreq_6:%.*]] = mpi.isend([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32 -> !mpi.request
+    %req1 = mpi.isend(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.request
+
+    // CHECK-NEXT: [[vreq_7:%.*]] = mpi.irecv([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32 -> !mpi.request
+    %req3 = mpi.irecv(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.request
+
+    // CHECK-NEXT: [[vretval_8:%.*]], [[vreq_9:%.*]] = mpi.irecv([[varg0]], [[vrank]], [[vrank]], [[v1]]) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
+    %err5, %req4 = mpi.irecv(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32 -> !mpi.retval, !mpi.request
+
+    // CHECK-NEXT: mpi.wait([[vreq_9]]) : !mpi.request
+    mpi.wait(%req4) : !mpi.request
+
+    // CHECK-NEXT: [[v4:%.*]] = mpi.wait([[vreq]]) : !mpi.request -> !mpi.retval
     %err6 = mpi.wait(%req2) : !mpi.request -> !mpi.retval
 
-    // CHECK-NEXT: mpi.barrier : !mpi.retval
-    mpi.barrier : !mpi.retval
+    // CHECK-NEXT: mpi.barrier([[v1]])
+    mpi.barrier(%comm)
 
-    // CHECK-NEXT: %5 = mpi.barrier : !mpi.retval
-    %err7 = mpi.barrier : !mpi.retval
+    // CHECK-NEXT: [[v5:%.*]] = mpi.barrier([[v1]]) -> !mpi.retval
+    %err7 = mpi.barrier(%comm) -> !mpi.retval
 
-    // CHECK-NEXT: mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32>
-    mpi.allreduce(%ref, %ref, MPI_SUM) : memref<100xf32>, memref<100xf32>
+    // CHECK-NEXT: [[v6:%.*]] = mpi.allreduce([[varg0]], [[varg0]], MPI_SUM, [[v1]]) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+    %err8 = mpi.allreduce(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32> -> !mpi.retval
 
-    // CHECK-NEXT: mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32> -> !mpi.retval
-    %err8 = mpi.allreduce(%ref, %ref, MPI_SUM) : memref<100xf32>, memref<100xf32> -> !mpi.retval
+    // CHECK-NEXT: mpi.allreduce([[varg0]], [[varg0]], MPI_SUM, [[v1]]) : memref<100xf32>, memref<100xf32>
+    mpi.allreduce(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>
 
-    // CHECK-NEXT: %7 = mpi.finalize : !mpi.retval
+    // CHECK-NEXT: [[v7:%.*]] = mpi.finalize : !mpi.retval
     %rval = mpi.finalize : !mpi.retval
 
-    // CHECK-NEXT: %8 = mpi.retval_check %retval = <MPI_SUCCESS> : i1
+    // CHECK-NEXT: [[v8:%.*]] = mpi.retval_check [[vretval:%.*]] = <MPI_SUCCESS> : i1
     %res = mpi.retval_check %retval = <MPI_SUCCESS> : i1
 
-    // CHECK-NEXT: %9 = mpi.error_class %0 : !mpi.retval
+    // CHECK-NEXT: [[v9:%.*]] = mpi.error_class [[v0]] : !mpi.retval
     %errclass = mpi.error_class %err : !mpi.retval
 
     // CHECK-NEXT: return
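[Editor's note] The reworked test also exercises the new mpi.comm_split op and the optional !mpi.retval results. A self-contained sketch combining those forms is below; the function name is hypothetical, but each op form appears verbatim in the test above:

// Hypothetical illustration: split a communicator and check the return value.
func.func @comm_split_sketch() -> i1 {
  %comm = mpi.comm_world : !mpi.comm
  %rank = mpi.comm_rank(%comm) : i32
  // As with MPI_Comm_split: processes sharing a color form a new communicator,
  // and the key orders ranks within it. Here both are the caller's rank.
  %retval, %newcomm = mpi.comm_split(%comm, %rank, %rank) : !mpi.retval, !mpi.comm
  // The optional !mpi.retval result can be tested against MPI_SUCCESS.
  %ok = mpi.retval_check %retval = <MPI_SUCCESS> : i1
  return %ok : i1
}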