diff options
Diffstat (limited to 'mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp')
-rw-r--r-- | mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 110 |
1 files changed, 56 insertions, 54 deletions
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp index d4deff5..e5496e5 100644 --- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp +++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp @@ -35,7 +35,7 @@ static Op getOrDefineGlobal(ModuleOp &moduleOp, const Location loc, if (!(ret = moduleOp.lookupSymbol<Op>(name))) { ConversionPatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(moduleOp.getBody()); - ret = rewriter.template create<Op>(loc, std::forward<Args>(args)...); + ret = Op::create(rewriter, loc, std::forward<Args>(args)...); } return ret; } @@ -54,18 +54,18 @@ std::pair<Value, Value> getRawPtrAndSize(const Location loc, Value memRef, Type elType) { Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); Value dataPtr = - rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1); - Value offset = rewriter.create<LLVM::ExtractValueOp>( - loc, rewriter.getI64Type(), memRef, 2); + LLVM::ExtractValueOp::create(rewriter, loc, ptrType, memRef, 1); + Value offset = LLVM::ExtractValueOp::create(rewriter, loc, + rewriter.getI64Type(), memRef, 2); Value resPtr = - rewriter.create<LLVM::GEPOp>(loc, ptrType, elType, dataPtr, offset); + LLVM::GEPOp::create(rewriter, loc, ptrType, elType, dataPtr, offset); Value size; if (cast<LLVM::LLVMStructType>(memRef.getType()).getBody().size() > 3) { - size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef, - ArrayRef<int64_t>{3, 0}); - size = rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), size); + size = LLVM::ExtractValueOp::create(rewriter, loc, memRef, + ArrayRef<int64_t>{3, 0}); + size = LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), size); } else { - size = rewriter.create<arith::ConstantIntOp>(loc, 1, 32); + size = arith::ConstantIntOp::create(rewriter, loc, 1, 32); } return {resPtr, size}; } @@ -157,13 +157,13 @@ public: Value getCommWorld(const Location loc, ConversionPatternRewriter &rewriter) override { static constexpr int MPI_COMM_WORLD = 0x44000000; - return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), - MPI_COMM_WORLD); + return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), + MPI_COMM_WORLD); } Value castComm(const Location loc, ConversionPatternRewriter &rewriter, Value comm) override { - return rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), comm); + return LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), comm); } intptr_t getStatusIgnore() override { return 1; } @@ -195,7 +195,8 @@ public: mtype = MPI_UINT8_T; else assert(false && "unsupported type"); - return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), mtype); + return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), + mtype); } Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter, @@ -245,7 +246,7 @@ public: op = MPI_REPLACE; break; } - return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), op); + return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), op); } }; @@ -281,16 +282,16 @@ public: getOrDefineExternalStruct(loc, rewriter, name, commStructT); // get address of symbol - auto comm = rewriter.create<LLVM::AddressOfOp>( - loc, LLVM::LLVMPointerType::get(context), - SymbolRefAttr::get(context, name)); - return rewriter.create<LLVM::PtrToIntOp>(loc, rewriter.getI64Type(), comm); + auto comm = LLVM::AddressOfOp::create(rewriter, loc, + LLVM::LLVMPointerType::get(context), + SymbolRefAttr::get(context, name)); + return LLVM::PtrToIntOp::create(rewriter, loc, rewriter.getI64Type(), comm); } Value castComm(const Location loc, ConversionPatternRewriter &rewriter, Value comm) override { - return rewriter.create<LLVM::IntToPtrOp>( - loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm); + return LLVM::IntToPtrOp::create( + rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm); } intptr_t getStatusIgnore() override { return 0; } @@ -330,9 +331,9 @@ public: // make sure global op definition exists getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT); // get address of symbol - return rewriter.create<LLVM::AddressOfOp>( - loc, LLVM::LLVMPointerType::get(context), - SymbolRefAttr::get(context, mtype)); + return LLVM::AddressOfOp::create(rewriter, loc, + LLVM::LLVMPointerType::get(context), + SymbolRefAttr::get(context, mtype)); } Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter, @@ -389,9 +390,9 @@ public: // make sure global op definition exists getOrDefineExternalStruct(loc, rewriter, op, opStructT); // get address of symbol - return rewriter.create<LLVM::AddressOfOp>( - loc, LLVM::LLVMPointerType::get(context), - SymbolRefAttr::get(context, op)); + return LLVM::AddressOfOp::create(rewriter, loc, + LLVM::LLVMPointerType::get(context), + SymbolRefAttr::get(context, op)); } }; @@ -424,7 +425,7 @@ struct InitOpLowering : public ConvertOpToLLVMPattern<mpi::InitOp> { Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); // instantiate nullptr `%nullptr = llvm.mlir.zero : !llvm.ptr` - auto nullPtrOp = rewriter.create<LLVM::ZeroOp>(loc, ptrType); + auto nullPtrOp = LLVM::ZeroOp::create(rewriter, loc, ptrType); Value llvmnull = nullPtrOp.getRes(); // grab a reference to the global module op: @@ -513,9 +514,9 @@ struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> { // get communicator Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm()); - auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1); + auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1); auto outPtr = - rewriter.create<LLVM::AllocaOp>(loc, ptrType, comm.getType(), one); + LLVM::AllocaOp::create(rewriter, loc, ptrType, comm.getType(), one); // int MPI_Comm_split(MPI_Comm comm, int color, int key, MPI_Comm * newcomm) auto funcType = @@ -524,14 +525,14 @@ struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> { LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Comm_split", funcType); - auto callOp = rewriter.create<LLVM::CallOp>( - loc, funcDecl, - ValueRange{comm, adaptor.getColor(), adaptor.getKey(), - outPtr.getRes()}); + auto callOp = + LLVM::CallOp::create(rewriter, loc, funcDecl, + ValueRange{comm, adaptor.getColor(), + adaptor.getKey(), outPtr.getRes()}); // load the communicator into a register - Value res = rewriter.create<LLVM::LoadOp>(loc, i32, outPtr.getResult()); - res = rewriter.create<LLVM::SExtOp>(loc, rewriter.getI64Type(), res); + Value res = LLVM::LoadOp::create(rewriter, loc, i32, outPtr.getResult()); + res = LLVM::SExtOp::create(rewriter, loc, rewriter.getI64Type(), res); // if retval is checked, replace uses of retval with the results from the // call op @@ -580,14 +581,14 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> { moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType); // replace with function call - auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1); - auto rankptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one); - auto callOp = rewriter.create<LLVM::CallOp>( - loc, initDecl, ValueRange{comm, rankptr.getRes()}); + auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1); + auto rankptr = LLVM::AllocaOp::create(rewriter, loc, ptrType, i32, one); + auto callOp = LLVM::CallOp::create(rewriter, loc, initDecl, + ValueRange{comm, rankptr.getRes()}); // load the rank into a register auto loadedRank = - rewriter.create<LLVM::LoadOp>(loc, i32, rankptr.getResult()); + LLVM::LoadOp::create(rewriter, loc, i32, rankptr.getResult()); // if retval is checked, replace uses of retval with the results from the // call op @@ -641,10 +642,10 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> { getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Send", funcType); // replace op with function call - auto funcCall = rewriter.create<LLVM::CallOp>( - loc, funcDecl, - ValueRange{dataPtr, size, dataType, adaptor.getDest(), adaptor.getTag(), - comm}); + auto funcCall = LLVM::CallOp::create(rewriter, loc, funcDecl, + ValueRange{dataPtr, size, dataType, + adaptor.getDest(), + adaptor.getTag(), comm}); if (op.getRetval()) rewriter.replaceOp(op, funcCall.getResult()); else @@ -683,10 +684,10 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> { auto mpiTraits = MPIImplTraits::get(moduleOp); Value dataType = mpiTraits->getDataType(loc, rewriter, elemType); Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm()); - Value statusIgnore = rewriter.create<LLVM::ConstantOp>( - loc, i64, mpiTraits->getStatusIgnore()); + Value statusIgnore = LLVM::ConstantOp::create(rewriter, loc, i64, + mpiTraits->getStatusIgnore()); statusIgnore = - rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, statusIgnore); + LLVM::IntToPtrOp::create(rewriter, loc, ptrType, statusIgnore); // LLVM Function type representing `i32 MPI_Recv(data, count, datatype, dst, // tag, comm)` @@ -698,8 +699,8 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> { getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Recv", funcType); // replace op with function call - auto funcCall = rewriter.create<LLVM::CallOp>( - loc, funcDecl, + auto funcCall = LLVM::CallOp::create( + rewriter, loc, funcDecl, ValueRange{dataPtr, size, dataType, adaptor.getSource(), adaptor.getTag(), comm, statusIgnore}); if (op.getRetval()) @@ -738,9 +739,10 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> { // If input and output are the same, request in-place operation. if (adaptor.getSendbuf() == adaptor.getRecvbuf()) { - sendPtr = rewriter.create<LLVM::ConstantOp>( - loc, i64, reinterpret_cast<int64_t>(mpiTraits->getInPlace())); - sendPtr = rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, sendPtr); + sendPtr = LLVM::ConstantOp::create( + rewriter, loc, i64, + reinterpret_cast<int64_t>(mpiTraits->getInPlace())); + sendPtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, sendPtr); } Value dataType = mpiTraits->getDataType(loc, rewriter, elemType); @@ -757,8 +759,8 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> { getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allreduce", funcType); // replace op with function call - auto funcCall = rewriter.create<LLVM::CallOp>( - loc, funcDecl, + auto funcCall = LLVM::CallOp::create( + rewriter, loc, funcDecl, ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld}); if (op.getRetval()) |