aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/LLVMCommon/Pattern.cpp')
-rw-r--r--mlir/lib/Conversion/LLVMCommon/Pattern.cpp99
1 files changed, 49 insertions, 50 deletions
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index c5f72f7..2568044 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -57,8 +57,8 @@ Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder,
Location loc,
Type resultType,
int64_t value) {
- return builder.create<LLVM::ConstantOp>(loc, resultType,
- builder.getIndexAttr(value));
+ return LLVM::ConstantOp::create(builder, loc, resultType,
+ builder.getIndexAttr(value));
}
Value ConvertToLLVMPattern::getStridedElementPtr(
@@ -123,7 +123,7 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
runningStride = sizes[i];
else if (stride == ShapedType::kDynamic)
runningStride =
- rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]);
+ LLVM::MulOp::create(rewriter, loc, runningStride, sizes[i]);
else
runningStride = createIndexAttrConstant(rewriter, loc, indexType, stride);
}
@@ -131,10 +131,10 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
// Buffer size in bytes.
Type elementType = typeConverter->convertType(memRefType.getElementType());
auto elementPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
- Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType);
- Value gepPtr = rewriter.create<LLVM::GEPOp>(
- loc, elementPtrType, elementType, nullPtr, runningStride);
- size = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
+ Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType);
+ Value gepPtr = LLVM::GEPOp::create(rewriter, loc, elementPtrType,
+ elementType, nullPtr, runningStride);
+ size = LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gepPtr);
} else {
size = runningStride;
}
@@ -149,10 +149,10 @@ Value ConvertToLLVMPattern::getSizeInBytes(
// which is a common pattern of getting the size of a type in bytes.
Type llvmType = typeConverter->convertType(type);
auto convertedPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
- auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, convertedPtrType);
- auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType, llvmType,
- nullPtr, ArrayRef<LLVM::GEPArg>{1});
- return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
+ auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, convertedPtrType);
+ auto gep = LLVM::GEPOp::create(rewriter, loc, convertedPtrType, llvmType,
+ nullPtr, ArrayRef<LLVM::GEPArg>{1});
+ return LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gep);
}
Value ConvertToLLVMPattern::getNumElements(
@@ -175,7 +175,7 @@ Value ConvertToLLVMPattern::getNumElements(
staticSize == ShapedType::kDynamic
? dynamicSizes[dynamicIndex++]
: createIndexAttrConstant(rewriter, loc, indexType, staticSize);
- numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
+ numElements = LLVM::MulOp::create(rewriter, loc, numElements, size);
} else {
numElements =
staticSize == ShapedType::kDynamic
@@ -272,18 +272,17 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
// Allocate memory, copy, and free the source if necessary.
Value memory =
- toDynamic
- ? builder
- .create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize)
- .getResult()
- : builder.create<LLVM::AllocaOp>(loc, getPtrType(),
- IntegerType::get(getContext(), 8),
- allocationSize,
- /*alignment=*/0);
+ toDynamic ? LLVM::CallOp::create(builder, loc, mallocFunc.value(),
+ allocationSize)
+ .getResult()
+ : LLVM::AllocaOp::create(builder, loc, getPtrType(),
+ IntegerType::get(getContext(), 8),
+ allocationSize,
+ /*alignment=*/0);
Value source = desc.memRefDescPtr(builder, loc);
- builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, false);
+ LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false);
if (!toDynamic)
- builder.create<LLVM::CallOp>(loc, freeFunc.value(), source);
+ LLVM::CallOp::create(builder, loc, freeFunc.value(), source);
// Create a new descriptor. The same descriptor can be returned multiple
// times, attempting to modify its pointer can lead to memory leaks
@@ -349,8 +348,8 @@ LogicalResult LLVM::detail::oneToOneRewrite(
SmallVector<Value, 4> results;
results.reserve(numResults);
for (unsigned i = 0; i < numResults; ++i) {
- results.push_back(rewriter.create<LLVM::ExtractValueOp>(
- op->getLoc(), newOp->getResult(0), i));
+ results.push_back(LLVM::ExtractValueOp::create(rewriter, op->getLoc(),
+ newOp->getResult(0), i));
}
rewriter.replaceOp(op, results);
return success();
@@ -371,8 +370,8 @@ LogicalResult LLVM::detail::intrinsicRewrite(
if (numResults != 0)
resType = typeConverter.packOperationResults(op->getResultTypes());
- auto callIntrOp = rewriter.create<LLVM::CallIntrinsicOp>(
- loc, resType, rewriter.getStringAttr(intrinsic), operands);
+ auto callIntrOp = LLVM::CallIntrinsicOp::create(
+ rewriter, loc, resType, rewriter.getStringAttr(intrinsic), operands);
// Propagate attributes.
callIntrOp->setAttrs(op->getAttrDictionary());
@@ -388,7 +387,7 @@ LogicalResult LLVM::detail::intrinsicRewrite(
results.reserve(numResults);
Value intrRes = callIntrOp.getResults();
for (unsigned i = 0; i < numResults; ++i)
- results.push_back(rewriter.create<LLVM::ExtractValueOp>(loc, intrRes, i));
+ results.push_back(LLVM::ExtractValueOp::create(rewriter, loc, intrRes, i));
rewriter.replaceOp(op, results);
return success();
@@ -406,7 +405,7 @@ static unsigned getBitWidth(Type type) {
static Value createI32Constant(OpBuilder &builder, Location loc,
int32_t value) {
Type i32 = builder.getI32Type();
- return builder.create<LLVM::ConstantOp>(loc, i32, value);
+ return LLVM::ConstantOp::create(builder, loc, i32, value);
}
SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc,
@@ -418,17 +417,17 @@ SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc,
unsigned srcBitWidth = getBitWidth(srcType);
unsigned dstBitWidth = getBitWidth(dstType);
if (srcBitWidth == dstBitWidth) {
- Value cast = builder.create<LLVM::BitcastOp>(loc, dstType, src);
+ Value cast = LLVM::BitcastOp::create(builder, loc, dstType, src);
return {cast};
}
if (dstBitWidth > srcBitWidth) {
auto smallerInt = builder.getIntegerType(srcBitWidth);
if (srcType != smallerInt)
- src = builder.create<LLVM::BitcastOp>(loc, smallerInt, src);
+ src = LLVM::BitcastOp::create(builder, loc, smallerInt, src);
auto largerInt = builder.getIntegerType(dstBitWidth);
- Value res = builder.create<LLVM::ZExtOp>(loc, largerInt, src);
+ Value res = LLVM::ZExtOp::create(builder, loc, largerInt, src);
return {res};
}
assert(srcBitWidth % dstBitWidth == 0 &&
@@ -436,12 +435,12 @@ SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc,
int64_t numElements = srcBitWidth / dstBitWidth;
auto vecType = VectorType::get(numElements, dstType);
- src = builder.create<LLVM::BitcastOp>(loc, vecType, src);
+ src = LLVM::BitcastOp::create(builder, loc, vecType, src);
SmallVector<Value> res;
for (auto i : llvm::seq(numElements)) {
Value idx = createI32Constant(builder, loc, i);
- Value elem = builder.create<LLVM::ExtractElementOp>(loc, src, idx);
+ Value elem = LLVM::ExtractElementOp::create(builder, loc, src, idx);
res.emplace_back(elem);
}
@@ -461,28 +460,28 @@ Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src,
if (dstBitWidth < srcBitWidth) {
auto largerInt = builder.getIntegerType(srcBitWidth);
if (res.getType() != largerInt)
- res = builder.create<LLVM::BitcastOp>(loc, largerInt, res);
+ res = LLVM::BitcastOp::create(builder, loc, largerInt, res);
auto smallerInt = builder.getIntegerType(dstBitWidth);
- res = builder.create<LLVM::TruncOp>(loc, smallerInt, res);
+ res = LLVM::TruncOp::create(builder, loc, smallerInt, res);
}
if (res.getType() != dstType)
- res = builder.create<LLVM::BitcastOp>(loc, dstType, res);
+ res = LLVM::BitcastOp::create(builder, loc, dstType, res);
return res;
}
int64_t numElements = src.size();
auto srcType = VectorType::get(numElements, src.front().getType());
- Value res = builder.create<LLVM::PoisonOp>(loc, srcType);
+ Value res = LLVM::PoisonOp::create(builder, loc, srcType);
for (auto &&[i, elem] : llvm::enumerate(src)) {
Value idx = createI32Constant(builder, loc, i);
- res = builder.create<LLVM::InsertElementOp>(loc, srcType, res, elem, idx);
+ res = LLVM::InsertElementOp::create(builder, loc, srcType, res, elem, idx);
}
if (res.getType() != dstType)
- res = builder.create<LLVM::BitcastOp>(loc, dstType, res);
+ res = LLVM::BitcastOp::create(builder, loc, dstType, res);
return res;
}
@@ -518,20 +517,20 @@ Value mlir::LLVM::getStridedElementPtr(OpBuilder &builder, Location loc,
Value stride =
ShapedType::isDynamic(strides[i])
? memRefDescriptor.stride(builder, loc, i)
- : builder.create<LLVM::ConstantOp>(
- loc, indexType, builder.getIndexAttr(strides[i]));
- increment =
- builder.create<LLVM::MulOp>(loc, increment, stride, intOverflowFlags);
+ : LLVM::ConstantOp::create(builder, loc, indexType,
+ builder.getIndexAttr(strides[i]));
+ increment = LLVM::MulOp::create(builder, loc, increment, stride,
+ intOverflowFlags);
}
- index = index ? builder.create<LLVM::AddOp>(loc, index, increment,
- intOverflowFlags)
+ index = index ? LLVM::AddOp::create(builder, loc, index, increment,
+ intOverflowFlags)
: increment;
}
Type elementPtrType = memRefDescriptor.getElementPtrType();
- return index ? builder.create<LLVM::GEPOp>(
- loc, elementPtrType,
- converter.convertType(type.getElementType()), base, index,
- noWrapFlags)
- : base;
+ return index
+ ? LLVM::GEPOp::create(builder, loc, elementPtrType,
+ converter.convertType(type.getElementType()),
+ base, index, noWrapFlags)
+ : base;
}