aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp')
-rw-r--r--mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp110
1 files changed, 56 insertions, 54 deletions
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index d4deff5..e5496e5 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -35,7 +35,7 @@ static Op getOrDefineGlobal(ModuleOp &moduleOp, const Location loc,
if (!(ret = moduleOp.lookupSymbol<Op>(name))) {
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
- ret = rewriter.template create<Op>(loc, std::forward<Args>(args)...);
+ ret = Op::create(rewriter, loc, std::forward<Args>(args)...);
}
return ret;
}
@@ -54,18 +54,18 @@ std::pair<Value, Value> getRawPtrAndSize(const Location loc,
Value memRef, Type elType) {
Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
Value dataPtr =
- rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
- Value offset = rewriter.create<LLVM::ExtractValueOp>(
- loc, rewriter.getI64Type(), memRef, 2);
+ LLVM::ExtractValueOp::create(rewriter, loc, ptrType, memRef, 1);
+ Value offset = LLVM::ExtractValueOp::create(rewriter, loc,
+ rewriter.getI64Type(), memRef, 2);
Value resPtr =
- rewriter.create<LLVM::GEPOp>(loc, ptrType, elType, dataPtr, offset);
+ LLVM::GEPOp::create(rewriter, loc, ptrType, elType, dataPtr, offset);
Value size;
if (cast<LLVM::LLVMStructType>(memRef.getType()).getBody().size() > 3) {
- size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
- ArrayRef<int64_t>{3, 0});
- size = rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), size);
+ size = LLVM::ExtractValueOp::create(rewriter, loc, memRef,
+ ArrayRef<int64_t>{3, 0});
+ size = LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), size);
} else {
- size = rewriter.create<arith::ConstantIntOp>(loc, 1, 32);
+ size = arith::ConstantIntOp::create(rewriter, loc, 1, 32);
}
return {resPtr, size};
}
@@ -157,13 +157,13 @@ public:
Value getCommWorld(const Location loc,
ConversionPatternRewriter &rewriter) override {
static constexpr int MPI_COMM_WORLD = 0x44000000;
- return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
- MPI_COMM_WORLD);
+ return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
+ MPI_COMM_WORLD);
}
Value castComm(const Location loc, ConversionPatternRewriter &rewriter,
Value comm) override {
- return rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), comm);
+ return LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), comm);
}
intptr_t getStatusIgnore() override { return 1; }
@@ -195,7 +195,8 @@ public:
mtype = MPI_UINT8_T;
else
assert(false && "unsupported type");
- return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), mtype);
+ return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
+ mtype);
}
Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
@@ -245,7 +246,7 @@ public:
op = MPI_REPLACE;
break;
}
- return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), op);
+ return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), op);
}
};
@@ -281,16 +282,16 @@ public:
getOrDefineExternalStruct(loc, rewriter, name, commStructT);
// get address of symbol
- auto comm = rewriter.create<LLVM::AddressOfOp>(
- loc, LLVM::LLVMPointerType::get(context),
- SymbolRefAttr::get(context, name));
- return rewriter.create<LLVM::PtrToIntOp>(loc, rewriter.getI64Type(), comm);
+ auto comm = LLVM::AddressOfOp::create(rewriter, loc,
+ LLVM::LLVMPointerType::get(context),
+ SymbolRefAttr::get(context, name));
+ return LLVM::PtrToIntOp::create(rewriter, loc, rewriter.getI64Type(), comm);
}
Value castComm(const Location loc, ConversionPatternRewriter &rewriter,
Value comm) override {
- return rewriter.create<LLVM::IntToPtrOp>(
- loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm);
+ return LLVM::IntToPtrOp::create(
+ rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm);
}
intptr_t getStatusIgnore() override { return 0; }
@@ -330,9 +331,9 @@ public:
// make sure global op definition exists
getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT);
// get address of symbol
- return rewriter.create<LLVM::AddressOfOp>(
- loc, LLVM::LLVMPointerType::get(context),
- SymbolRefAttr::get(context, mtype));
+ return LLVM::AddressOfOp::create(rewriter, loc,
+ LLVM::LLVMPointerType::get(context),
+ SymbolRefAttr::get(context, mtype));
}
Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
@@ -389,9 +390,9 @@ public:
// make sure global op definition exists
getOrDefineExternalStruct(loc, rewriter, op, opStructT);
// get address of symbol
- return rewriter.create<LLVM::AddressOfOp>(
- loc, LLVM::LLVMPointerType::get(context),
- SymbolRefAttr::get(context, op));
+ return LLVM::AddressOfOp::create(rewriter, loc,
+ LLVM::LLVMPointerType::get(context),
+ SymbolRefAttr::get(context, op));
}
};
@@ -424,7 +425,7 @@ struct InitOpLowering : public ConvertOpToLLVMPattern<mpi::InitOp> {
Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
// instantiate nullptr `%nullptr = llvm.mlir.zero : !llvm.ptr`
- auto nullPtrOp = rewriter.create<LLVM::ZeroOp>(loc, ptrType);
+ auto nullPtrOp = LLVM::ZeroOp::create(rewriter, loc, ptrType);
Value llvmnull = nullPtrOp.getRes();
// grab a reference to the global module op:
@@ -513,9 +514,9 @@ struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
// get communicator
Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
- auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
+ auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
auto outPtr =
- rewriter.create<LLVM::AllocaOp>(loc, ptrType, comm.getType(), one);
+ LLVM::AllocaOp::create(rewriter, loc, ptrType, comm.getType(), one);
// int MPI_Comm_split(MPI_Comm comm, int color, int key, MPI_Comm * newcomm)
auto funcType =
@@ -524,14 +525,14 @@ struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(moduleOp, loc, rewriter,
"MPI_Comm_split", funcType);
- auto callOp = rewriter.create<LLVM::CallOp>(
- loc, funcDecl,
- ValueRange{comm, adaptor.getColor(), adaptor.getKey(),
- outPtr.getRes()});
+ auto callOp =
+ LLVM::CallOp::create(rewriter, loc, funcDecl,
+ ValueRange{comm, adaptor.getColor(),
+ adaptor.getKey(), outPtr.getRes()});
// load the communicator into a register
- Value res = rewriter.create<LLVM::LoadOp>(loc, i32, outPtr.getResult());
- res = rewriter.create<LLVM::SExtOp>(loc, rewriter.getI64Type(), res);
+ Value res = LLVM::LoadOp::create(rewriter, loc, i32, outPtr.getResult());
+ res = LLVM::SExtOp::create(rewriter, loc, rewriter.getI64Type(), res);
// if retval is checked, replace uses of retval with the results from the
// call op
@@ -580,14 +581,14 @@ struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType);
// replace with function call
- auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
- auto rankptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one);
- auto callOp = rewriter.create<LLVM::CallOp>(
- loc, initDecl, ValueRange{comm, rankptr.getRes()});
+ auto one = LLVM::ConstantOp::create(rewriter, loc, i32, 1);
+ auto rankptr = LLVM::AllocaOp::create(rewriter, loc, ptrType, i32, one);
+ auto callOp = LLVM::CallOp::create(rewriter, loc, initDecl,
+ ValueRange{comm, rankptr.getRes()});
// load the rank into a register
auto loadedRank =
- rewriter.create<LLVM::LoadOp>(loc, i32, rankptr.getResult());
+ LLVM::LoadOp::create(rewriter, loc, i32, rankptr.getResult());
// if retval is checked, replace uses of retval with the results from the
// call op
@@ -641,10 +642,10 @@ struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Send", funcType);
// replace op with function call
- auto funcCall = rewriter.create<LLVM::CallOp>(
- loc, funcDecl,
- ValueRange{dataPtr, size, dataType, adaptor.getDest(), adaptor.getTag(),
- comm});
+ auto funcCall = LLVM::CallOp::create(rewriter, loc, funcDecl,
+ ValueRange{dataPtr, size, dataType,
+ adaptor.getDest(),
+ adaptor.getTag(), comm});
if (op.getRetval())
rewriter.replaceOp(op, funcCall.getResult());
else
@@ -683,10 +684,10 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
auto mpiTraits = MPIImplTraits::get(moduleOp);
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
- Value statusIgnore = rewriter.create<LLVM::ConstantOp>(
- loc, i64, mpiTraits->getStatusIgnore());
+ Value statusIgnore = LLVM::ConstantOp::create(rewriter, loc, i64,
+ mpiTraits->getStatusIgnore());
statusIgnore =
- rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, statusIgnore);
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrType, statusIgnore);
// LLVM Function type representing `i32 MPI_Recv(data, count, datatype, dst,
// tag, comm)`
@@ -698,8 +699,8 @@ struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Recv", funcType);
// replace op with function call
- auto funcCall = rewriter.create<LLVM::CallOp>(
- loc, funcDecl,
+ auto funcCall = LLVM::CallOp::create(
+ rewriter, loc, funcDecl,
ValueRange{dataPtr, size, dataType, adaptor.getSource(),
adaptor.getTag(), comm, statusIgnore});
if (op.getRetval())
@@ -738,9 +739,10 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
// If input and output are the same, request in-place operation.
if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
- sendPtr = rewriter.create<LLVM::ConstantOp>(
- loc, i64, reinterpret_cast<int64_t>(mpiTraits->getInPlace()));
- sendPtr = rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, sendPtr);
+ sendPtr = LLVM::ConstantOp::create(
+ rewriter, loc, i64,
+ reinterpret_cast<int64_t>(mpiTraits->getInPlace()));
+ sendPtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, sendPtr);
}
Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
@@ -757,8 +759,8 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allreduce", funcType);
// replace op with function call
- auto funcCall = rewriter.create<LLVM::CallOp>(
- loc, funcDecl,
+ auto funcCall = LLVM::CallOp::create(
+ rewriter, loc, funcDecl,
ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld});
if (op.getRetval())