aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthias Springer <mspringer@nvidia.com>2025-03-08 12:21:15 +0100
committerMatthias Springer <mspringer@nvidia.com>2025-04-01 01:52:53 +0200
commit4e7246ae3ac7166b40432828dcdc7123dffaadd6 (patch)
tree49dbf45b8abb6d88071ffdfddcf5776846d867a0
parent799e9053641a6478d3144866a97737b37b87c260 (diff)
downloadllvm-users/matthias-springer/memref_1_to_n.zip
llvm-users/matthias-springer/memref_1_to_n.tar.gz
llvm-users/matthias-springer/memref_1_to_n.tar.bz2
update some more code update update update update some progress update update more improvements
-rw-r--r--mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h57
-rw-r--r--mlir/include/mlir/Conversion/LLVMCommon/Pattern.h2
-rw-r--r--mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h20
-rw-r--r--mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp5
-rw-r--r--mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp40
-rw-r--r--mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp80
-rw-r--r--mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp304
-rw-r--r--mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp24
-rw-r--r--mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp263
-rw-r--r--mlir/lib/Conversion/LLVMCommon/Pattern.cpp45
-rw-r--r--mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp224
-rw-r--r--mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp2
-rw-r--r--mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp422
-rw-r--r--mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp6
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp77
15 files changed, 669 insertions, 902 deletions
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
index d5055f0..119106e 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
@@ -30,13 +30,13 @@ class LLVMPointerType;
/// Helper class to produce LLVM dialect operations extracting or inserting
/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
/// The Value may be null, in which case none of the operations are valid.
-class MemRefDescriptor : public StructBuilder {
+class MemRefDescriptor {
public:
/// Construct a helper for the given descriptor value.
- explicit MemRefDescriptor(Value descriptor);
+ explicit MemRefDescriptor(ValueRange elements);
/// Builds IR creating a `poison` value of the descriptor type.
static MemRefDescriptor poison(OpBuilder &builder, Location loc,
- Type descriptorType);
+ TypeRange descriptorTypes);
/// Builds IR creating a MemRef descriptor that represents `type` and
/// populates it with static shape and stride information extracted from the
/// type.
@@ -49,6 +49,11 @@ public:
const LLVMTypeConverter &typeConverter, MemRefType type,
Value memory, Value alignedMemory);
+ /// Builds IR extracting the individual elements of a packed MemRef descriptor
+ /// struct and returns them wrapped in a `MemRefDescriptor`.
+ static MemRefDescriptor fromPackedStruct(OpBuilder &builder, Location loc,
+ Value packed);
+
/// Builds IR extracting the allocated pointer from the descriptor.
Value allocatedPtr(OpBuilder &builder, Location loc);
/// Builds IR inserting the allocated pointer into the descriptor.
@@ -98,6 +103,8 @@ public:
Value bufferPtr(OpBuilder &builder, Location loc,
const LLVMTypeConverter &converter, MemRefType type);
+ int64_t getRank();
+
/// Builds IR populating a MemRef descriptor structure from a list of
/// individual values composing that descriptor, in the following order:
/// - allocated pointer;
@@ -106,20 +113,21 @@ public:
/// - <rank> sizes;
/// - <rank> strides;
/// where <rank> is the MemRef rank as provided in `type`.
- static Value pack(OpBuilder &builder, Location loc,
- const LLVMTypeConverter &converter, MemRefType type,
- ValueRange values);
-
- /// Builds IR extracting individual elements of a MemRef descriptor structure
- /// and returning them as `results` list.
- static void unpack(OpBuilder &builder, Location loc, Value packed,
- MemRefType type, SmallVectorImpl<Value> &results);
+ Value packStruct(OpBuilder &builder, Location loc);
/// Returns the number of non-aggregate values that would be produced by
/// `unpack`.
static unsigned getNumUnpackedValues(MemRefType type);
+ ValueRange getElements() { return elements; }
+
+ /*implicit*/ operator ValueRange() { return elements; }
+
private:
+ SmallVector<Value> elements;
+ // Value allocatedPtrVal, alignedPtrVal, offsetVal;
+ // SmallVector<Value> sizeVals, strideVals;
+
// Cached index type.
Type indexType;
};
@@ -155,13 +163,18 @@ private:
ValueRange elements;
};
-class UnrankedMemRefDescriptor : public StructBuilder {
+class UnrankedMemRefDescriptor {
public:
/// Construct a helper for the given descriptor value.
- explicit UnrankedMemRefDescriptor(Value descriptor);
+ explicit UnrankedMemRefDescriptor(ValueRange elements);
/// Builds IR creating an `undef` value of the descriptor type.
static UnrankedMemRefDescriptor poison(OpBuilder &builder, Location loc,
- Type descriptorType);
+ TypeRange descriptorType);
+
+ /// Builds IR extracting the individual elements of a packed unranked MemRef
+ /// descriptor struct and returns them wrapped in an `UnrankedMemRefDescriptor`.
+ static UnrankedMemRefDescriptor fromPackedStruct(OpBuilder &builder,
+ Location loc, Value packed);
/// Builds IR extracting the rank from the descriptor
Value rank(OpBuilder &builder, Location loc) const;
@@ -176,14 +189,7 @@ public:
/// of individual constituent values in the following order:
/// - rank of the memref;
/// - pointer to the memref descriptor.
- static Value pack(OpBuilder &builder, Location loc,
- const LLVMTypeConverter &converter, UnrankedMemRefType type,
- ValueRange values);
-
- /// Builds IR extracting individual elements that compose an unranked memref
- /// descriptor and returns them as `results` list.
- static void unpack(OpBuilder &builder, Location loc, Value packed,
- SmallVectorImpl<Value> &results);
+ Value packStruct(OpBuilder &builder, Location loc);
/// Returns the number of non-aggregate values that would be produced by
/// `unpack`.
@@ -269,6 +275,13 @@ public:
static void setStride(OpBuilder &builder, Location loc,
const LLVMTypeConverter &typeConverter,
Value strideBasePtr, Value index, Value stride);
+
+ ValueRange getElements() { return elements; }
+
+ /*implicit*/ operator ValueRange() { return elements; }
+
+private:
+ SmallVector<Value> elements;
};
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index e78f174..2d743a9 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -76,7 +76,7 @@ protected:
// This is a strided getElementPtr variant that linearizes subscripts as:
// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
- Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc,
+ Value getStridedElementPtr(Location loc, MemRefType type, ValueRange memRefDesc,
ValueRange indices,
ConversionPatternRewriter &rewriter) const;
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index 38b5e49..a65f136 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -91,6 +91,17 @@ public:
Type convertCallingConventionType(Type type,
bool useBarePointerCallConv = false) const;
+ /// Convert a memref type into an LLVM type that captures the relevant data.
+ LogicalResult convertMemRefType(MemRefType type,
+ SmallVectorImpl<Type> &result,
+ bool packed = false) const;
+
+ /// Convert an unranked memref type to an LLVM type that captures the
+ /// runtime rank and a pointer to the static ranked memref desc
+ LogicalResult convertUnrankedMemRefType(UnrankedMemRefType type,
+ SmallVectorImpl<Type> &result,
+ bool packed = false) const;
+
/// Promote the bare pointers in 'values' that resulted from memrefs to
/// descriptors. 'stdTypes' holds the types of 'values' before the conversion
/// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
@@ -111,7 +122,7 @@ public:
/// of the platform-specific C/C++ ABI lowering related to struct argument
/// passing.
SmallVector<Value, 4> promoteOperands(Location loc, ValueRange opOperands,
- ValueRange operands, OpBuilder &builder,
+ ArrayRef<ValueRange> operands, OpBuilder &builder,
bool useBarePtrCallConv = false) const;
/// Promote the LLVM struct representation of one MemRef descriptor to stack
@@ -245,13 +256,6 @@ private:
/// `!llvm<"{ double, double }">`. `complex<bf16>` is not supported.
Type convertComplexType(ComplexType type) const;
- /// Convert a memref type into an LLVM type that captures the relevant data.
- Type convertMemRefType(MemRefType type) const;
-
- /// Convert an unranked memref type to an LLVM type that captures the
- /// runtime rank and a pointer to the static ranked memref desc
- Type convertUnrankedMemRefType(UnrankedMemRefType type) const;
-
/// Convert a memref type to a bare pointer to the memref element type.
Type convertMemRefToBarePtr(BaseMemRefType type) const;
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 3acd470..e5f70c4 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -178,6 +178,7 @@ struct FatRawBufferCastLowering
LogicalResult
matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ /*
Location loc = op.getLoc();
Value memRef = adaptor.getSource();
Value unconvertedMemref = op.getSource();
@@ -222,7 +223,7 @@ struct FatRawBufferCastLowering
Value fatPtr = makeBufferRsrc(
rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
- chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7);
+ chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=7);
Value result = MemRefDescriptor::poison(
rewriter, loc,
@@ -241,6 +242,8 @@ struct FatRawBufferCastLowering
}
rewriter.replaceOp(op, result);
return success();
+ */
+ return failure();
}
};
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index debfd00..bc6613d 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -125,24 +125,35 @@ static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
return rewriter.applySignatureConversion(block, *conversion, converter);
}
+static SmallVector<Value> flattenValueRanges(ArrayRef<ValueRange> ranges) {
+ SmallVector<Value> result;
+ for (ValueRange range : ranges)
+ llvm::append_range(result, range);
+ return result;
+}
+
/// Convert the destination block signature (if necessary) and lower the branch
/// op to llvm.br.
struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
+ using Adaptor =
+ typename ConvertOpToLLVMPattern<cf::BranchOp>::OneToNOpAdaptor;
LogicalResult
- matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
+ matchAndRewrite(cf::BranchOp op, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ SmallVector<Value> flattenedOperands =
+ flattenValueRanges(adaptor.getOperands());
FailureOr<Block *> convertedBlock =
getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
- TypeRange(adaptor.getOperands()));
+ TypeRange(ValueRange(flattenedOperands)));
if (failed(convertedBlock))
return failure();
Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
- op, adaptor.getOperands(), *convertedBlock);
+ op, flattenedOperands, *convertedBlock);
// TODO: We should not just forward all attributes like that. But there are
// existing Flang tests that depend on this behavior.
- newOp->setAttrs(op->getAttrDictionary());
+ newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
return success();
}
};
@@ -151,28 +162,33 @@ struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
/// branch op to llvm.cond_br.
struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
+ using Adaptor =
+ typename ConvertOpToLLVMPattern<cf::CondBranchOp>::OneToNOpAdaptor;
LogicalResult
- matchAndRewrite(cf::CondBranchOp op,
- typename cf::CondBranchOp::Adaptor adaptor,
+ matchAndRewrite(cf::CondBranchOp op, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ SmallVector<Value> flattenedTrueDestOperands =
+ flattenValueRanges(adaptor.getTrueDestOperands());
FailureOr<Block *> convertedTrueBlock =
getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
- TypeRange(adaptor.getTrueDestOperands()));
+ TypeRange(ValueRange(flattenedTrueDestOperands)));
if (failed(convertedTrueBlock))
return failure();
+ SmallVector<Value> flattenedFalseDestOperands =
+ flattenValueRanges(adaptor.getFalseDestOperands());
FailureOr<Block *> convertedFalseBlock =
getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
- TypeRange(adaptor.getFalseDestOperands()));
+ TypeRange(ValueRange(flattenedFalseDestOperands)));
if (failed(convertedFalseBlock))
return failure();
Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
- op, adaptor.getCondition(), *convertedTrueBlock,
- adaptor.getTrueDestOperands(), *convertedFalseBlock,
- adaptor.getFalseDestOperands());
+ op, llvm::getSingleElement(adaptor.getCondition()), *convertedTrueBlock,
+ flattenedTrueDestOperands, *convertedFalseBlock,
+ flattenedFalseDestOperands);
// TODO: We should not just forward all attributes like that. But there are
// existing Flang tests that depend on this behavior.
- newOp->setAttrs(op->getAttrDictionary());
+ newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
return success();
}
};
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 55f0a9a..c5c0817 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -140,15 +140,23 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
for (auto [index, argType] : llvm::enumerate(type.getInputs())) {
Value arg = wrapperFuncOp.getArgument(index + argOffset);
if (auto memrefType = dyn_cast<MemRefType>(argType)) {
+ SmallVector<Type> convertedType;
+ LogicalResult status = typeConverter.convertMemRefType(memrefType, convertedType, /*packed=*/true);
+ (void)status;
+ assert(succeeded(status) && "failed to convert memref type");
Value loaded = rewriter.create<LLVM::LoadOp>(
- loc, typeConverter.convertType(memrefType), arg);
- MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
+ loc, llvm::getSingleElement(convertedType), arg);
+ llvm::append_range(args, MemRefDescriptor::fromPackedStruct(rewriter, loc, loaded).getElements());
continue;
}
- if (isa<UnrankedMemRefType>(argType)) {
+ if (auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(argType)) {
+ SmallVector<Type> convertedType;
+ LogicalResult status = typeConverter.convertUnrankedMemRefType(unrankedMemrefType, convertedType, /*packed=*/true);
+ (void)status;
+ assert(succeeded(status) && "failed to convert memref type");
Value loaded = rewriter.create<LLVM::LoadOp>(
- loc, typeConverter.convertType(argType), arg);
- UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
+ loc, llvm::getSingleElement(convertedType), arg);
+ llvm::append_range(args, UnrankedMemRefDescriptor::fromPackedStruct(rewriter, loc, loaded).getElements());
continue;
}
@@ -231,14 +239,12 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
numToDrop = memRefType
? MemRefDescriptor::getNumUnpackedValues(memRefType)
: UnrankedMemRefDescriptor::getNumUnpackedValues();
- Value packed =
- memRefType
- ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType,
- wrapperArgsRange.take_front(numToDrop))
- : UnrankedMemRefDescriptor::pack(
- builder, loc, typeConverter, unrankedMemRefType,
- wrapperArgsRange.take_front(numToDrop));
-
+ Value packed;
+ if (memRefType) {
+ packed = MemRefDescriptor(wrapperArgsRange.take_front(numToDrop)).packStruct(builder, loc);
+ } else {
+ packed = UnrankedMemRefDescriptor(wrapperArgsRange.take_front(numToDrop)).packStruct(builder, loc);
+ }
auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
Value one = builder.create<LLVM::ConstantOp>(
loc, typeConverter.convertType(builder.getIndexType()),
@@ -515,9 +521,10 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
using Super = CallOpInterfaceLowering<CallOpType>;
using Base = ConvertOpToLLVMPattern<CallOpType>;
+ using Adaptor = typename ConvertOpToLLVMPattern<CallOpType>::OneToNOpAdaptor;
LogicalResult matchAndRewriteImpl(CallOpType callOp,
- typename CallOpType::Adaptor adaptor,
+ Adaptor adaptor,
ConversionPatternRewriter &rewriter,
bool useBarePtrCallConv = false) const {
// Pack the result types into a struct.
@@ -579,7 +586,18 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
return failure();
}
- rewriter.replaceOp(callOp, results);
+ SmallVector<SmallVector<Value>> unpackedResults;
+ for (auto it : llvm::zip_equal(resultTypes, results)) {
+ SmallVector<Value> &result = unpackedResults.emplace_back();
+ if (isa<MemRefType>(std::get<0>(it))) {
+ llvm::append_range(result, MemRefDescriptor::fromPackedStruct(rewriter, callOp.getLoc(), std::get<1>(it)).getElements());
+ } else if (isa<UnrankedMemRefType>(std::get<0>(it))) {
+ llvm::append_range(result, UnrankedMemRefDescriptor::fromPackedStruct(rewriter, callOp.getLoc(), std::get<1>(it)).getElements());
+ } else {
+ result.push_back(std::get<1>(it));
+ }
+ }
+ rewriter.replaceOpWithMultiple(callOp, unpackedResults);
return success();
}
};
@@ -593,7 +611,7 @@ public:
symbolTable(symbolTable) {}
LogicalResult
- matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
+ matchAndRewrite(func::CallOp callOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
bool useBarePtrCallConv = false;
if (getTypeConverter()->getOptions().useBarePtrCallConv) {
@@ -623,7 +641,7 @@ struct CallIndirectOpLowering
using Super::Super;
LogicalResult
- matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor,
+ matchAndRewrite(func::CallIndirectOp callIndirectOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter);
}
@@ -666,7 +684,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
+ matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
unsigned numArguments = op.getNumOperands();
@@ -680,20 +698,36 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
// be returned from the memref descriptor.
for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
Type oldTy = std::get<0>(it).getType();
- Value newOperand = std::get<1>(it);
+ ValueRange adaptorVal = std::get<1>(it);
if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
cast<BaseMemRefType>(oldTy))) {
- MemRefDescriptor memrefDesc(newOperand);
- newOperand = memrefDesc.allocatedPtr(rewriter, loc);
+ MemRefDescriptor memrefDesc(adaptorVal);
+ updatedOperands.push_back( memrefDesc.allocatedPtr(rewriter, loc));
} else if (isa<UnrankedMemRefType>(oldTy)) {
// Unranked memref is not supported in the bare pointer calling
// convention.
return failure();
+ } else {
+ assert(adaptorVal.size() == 1 && "1:N conversion not supported for non-memref types");
+ updatedOperands.push_back(adaptorVal.front());
}
- updatedOperands.push_back(newOperand);
}
} else {
- updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
+ // Pack operands.
+ for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
+ Value operand = std::get<0>(it);
+ ValueRange adaptorVal = std::get<1>(it);
+ if (isa<MemRefType>(operand.getType())) {
+ MemRefDescriptor memrefDesc(adaptorVal);
+ updatedOperands.push_back(memrefDesc.packStruct(rewriter, loc));
+ } else if (isa<UnrankedMemRefType>(operand.getType())) {
+ UnrankedMemRefDescriptor unrankedMemrefDesc(adaptorVal);
+ updatedOperands.push_back(unrankedMemrefDesc.packStruct(rewriter, loc));
+ } else {
+ assert(adaptorVal.size() == 1 && "1:N conversion not supported for non-memref types");
+ updatedOperands.push_back(adaptorVal.front());
+ }
+ }
(void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
updatedOperands,
/*toDynamic=*/true);
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index f22ad1f..79bd1582 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -76,310 +76,8 @@ mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
LogicalResult
GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- Location loc = gpuFuncOp.getLoc();
-
- SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
- if (encodeWorkgroupAttributionsAsArguments) {
- // Append an `llvm.ptr` argument to the function signature to encode
- // workgroup attributions.
-
- ArrayRef<BlockArgument> workgroupAttributions =
- gpuFuncOp.getWorkgroupAttributions();
- size_t numAttributions = workgroupAttributions.size();
-
- // Insert all arguments at the end.
- unsigned index = gpuFuncOp.getNumArguments();
- SmallVector<unsigned> argIndices(numAttributions, index);
-
- // New arguments will simply be `llvm.ptr` with the correct address space
- Type workgroupPtrType =
- rewriter.getType<LLVM::LLVMPointerType>(workgroupAddrSpace);
- SmallVector<Type> argTypes(numAttributions, workgroupPtrType);
-
- // Attributes: noalias, llvm.mlir.workgroup_attribution(<size>, <type>)
- std::array attrs{
- rewriter.getNamedAttr(LLVM::LLVMDialect::getNoAliasAttrName(),
- rewriter.getUnitAttr()),
- rewriter.getNamedAttr(
- getDialect().getWorkgroupAttributionAttrHelper().getName(),
- rewriter.getUnitAttr()),
- };
- SmallVector<DictionaryAttr> argAttrs;
- for (BlockArgument attribution : workgroupAttributions) {
- auto attributionType = cast<MemRefType>(attribution.getType());
- IntegerAttr numElements =
- rewriter.getI64IntegerAttr(attributionType.getNumElements());
- Type llvmElementType =
- getTypeConverter()->convertType(attributionType.getElementType());
- if (!llvmElementType)
- return failure();
- TypeAttr type = TypeAttr::get(llvmElementType);
- attrs.back().setValue(
- rewriter.getAttr<LLVM::WorkgroupAttributionAttr>(numElements, type));
- argAttrs.push_back(rewriter.getDictionaryAttr(attrs));
- }
- // Location match function location
- SmallVector<Location> argLocs(numAttributions, gpuFuncOp.getLoc());
-
- // Perform signature modification
- rewriter.modifyOpInPlace(
- gpuFuncOp, [gpuFuncOp, &argIndices, &argTypes, &argAttrs, &argLocs]() {
- static_cast<FunctionOpInterface>(gpuFuncOp).insertArguments(
- argIndices, argTypes, argAttrs, argLocs);
- });
- } else {
- workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
- for (auto [idx, attribution] :
- llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
- auto type = dyn_cast<MemRefType>(attribution.getType());
- assert(type && type.hasStaticShape() && "unexpected type in attribution");
-
- uint64_t numElements = type.getNumElements();
-
- auto elementType =
- cast<Type>(typeConverter->convertType(type.getElementType()));
- auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
- std::string name =
- std::string(llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), idx));
- uint64_t alignment = 0;
- if (auto alignAttr = dyn_cast_or_null<IntegerAttr>(
- gpuFuncOp.getWorkgroupAttributionAttr(
- idx, LLVM::LLVMDialect::getAlignAttrName())))
- alignment = alignAttr.getInt();
- auto globalOp = rewriter.create<LLVM::GlobalOp>(
- gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
- LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment,
- workgroupAddrSpace);
- workgroupBuffers.push_back(globalOp);
- }
- }
-
- // Remap proper input types.
- TypeConverter::SignatureConversion signatureConversion(
- gpuFuncOp.front().getNumArguments());
-
- Type funcType = getTypeConverter()->convertFunctionSignature(
- gpuFuncOp.getFunctionType(), /*isVariadic=*/false,
- getTypeConverter()->getOptions().useBarePtrCallConv, signatureConversion);
- if (!funcType) {
- return rewriter.notifyMatchFailure(gpuFuncOp, [&](Diagnostic &diag) {
- diag << "failed to convert function signature type for: "
- << gpuFuncOp.getFunctionType();
- });
- }
-
- // Create the new function operation. Only copy those attributes that are
- // not specific to function modeling.
- SmallVector<NamedAttribute, 4> attributes;
- ArrayAttr argAttrs;
- for (const auto &attr : gpuFuncOp->getAttrs()) {
- if (attr.getName() == SymbolTable::getSymbolAttrName() ||
- attr.getName() == gpuFuncOp.getFunctionTypeAttrName() ||
- attr.getName() ==
- gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName() ||
- attr.getName() == gpuFuncOp.getWorkgroupAttribAttrsAttrName() ||
- attr.getName() == gpuFuncOp.getPrivateAttribAttrsAttrName() ||
- attr.getName() == gpuFuncOp.getKnownBlockSizeAttrName() ||
- attr.getName() == gpuFuncOp.getKnownGridSizeAttrName())
- continue;
- if (attr.getName() == gpuFuncOp.getArgAttrsAttrName()) {
- argAttrs = gpuFuncOp.getArgAttrsAttr();
- continue;
- }
- attributes.push_back(attr);
- }
-
- DenseI32ArrayAttr knownBlockSize = gpuFuncOp.getKnownBlockSizeAttr();
- DenseI32ArrayAttr knownGridSize = gpuFuncOp.getKnownGridSizeAttr();
- // Ensure we don't lose information if the function is lowered before its
- // surrounding context.
- auto *gpuDialect = cast<gpu::GPUDialect>(gpuFuncOp->getDialect());
- if (knownBlockSize)
- attributes.emplace_back(gpuDialect->getKnownBlockSizeAttrHelper().getName(),
- knownBlockSize);
- if (knownGridSize)
- attributes.emplace_back(gpuDialect->getKnownGridSizeAttrHelper().getName(),
- knownGridSize);
-
- // Add a dialect specific kernel attribute in addition to GPU kernel
- // attribute. The former is necessary for further translation while the
- // latter is expected by gpu.launch_func.
- if (gpuFuncOp.isKernel()) {
- if (kernelAttributeName)
- attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
- // Set the dialect-specific block size attribute if there is one.
- if (kernelBlockSizeAttributeName && knownBlockSize) {
- attributes.emplace_back(kernelBlockSizeAttributeName, knownBlockSize);
- }
- }
- LLVM::CConv callingConvention = gpuFuncOp.isKernel()
- ? kernelCallingConvention
- : nonKernelCallingConvention;
- auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
- gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
- LLVM::Linkage::External, /*dsoLocal=*/false, callingConvention,
- /*comdat=*/nullptr, attributes);
-
- {
- // Insert operations that correspond to converted workgroup and private
- // memory attributions to the body of the function. This must operate on
- // the original function, before the body region is inlined in the new
- // function to maintain the relation between block arguments and the
- // parent operation that assigns their semantics.
- OpBuilder::InsertionGuard guard(rewriter);
-
- // Rewrite workgroup memory attributions to addresses of global buffers.
- rewriter.setInsertionPointToStart(&gpuFuncOp.front());
- unsigned numProperArguments = gpuFuncOp.getNumArguments();
-
- if (encodeWorkgroupAttributionsAsArguments) {
- // Build a MemRefDescriptor with each of the arguments added above.
-
- unsigned numAttributions = gpuFuncOp.getNumWorkgroupAttributions();
- assert(numProperArguments >= numAttributions &&
- "Expecting attributions to be encoded as arguments already");
-
- // Arguments encoding workgroup attributions will be in positions
- // [numProperArguments, numProperArguments+numAttributions)
- ArrayRef<BlockArgument> attributionArguments =
- gpuFuncOp.getArguments().slice(numProperArguments - numAttributions,
- numAttributions);
- for (auto [idx, vals] : llvm::enumerate(llvm::zip_equal(
- gpuFuncOp.getWorkgroupAttributions(), attributionArguments))) {
- auto [attribution, arg] = vals;
- auto type = cast<MemRefType>(attribution.getType());
-
- // Arguments are of llvm.ptr type and attributions are of memref type:
- // we need to wrap them in memref descriptors.
- Value descr = MemRefDescriptor::fromStaticShape(
- rewriter, loc, *getTypeConverter(), type, arg);
-
- // And remap the arguments
- signatureConversion.remapInput(numProperArguments + idx, descr);
- }
- } else {
- for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) {
- auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(),
- global.getAddrSpace());
- Value address = rewriter.create<LLVM::AddressOfOp>(
- loc, ptrType, global.getSymNameAttr());
- Value memory =
- rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getType(),
- address, ArrayRef<LLVM::GEPArg>{0, 0});
-
- // Build a memref descriptor pointing to the buffer to plug with the
- // existing memref infrastructure. This may use more registers than
- // otherwise necessary given that memref sizes are fixed, but we can try
- // and canonicalize that away later.
- Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx];
- auto type = cast<MemRefType>(attribution.getType());
- Value descr = MemRefDescriptor::fromStaticShape(
- rewriter, loc, *getTypeConverter(), type, memory);
- signatureConversion.remapInput(numProperArguments + idx, descr);
- }
- }
-
- // Rewrite private memory attributions to alloca'ed buffers.
- unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions();
- auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
- for (const auto [idx, attribution] :
- llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
- auto type = cast<MemRefType>(attribution.getType());
- assert(type && type.hasStaticShape() && "unexpected type in attribution");
-
- // Explicitly drop memory space when lowering private memory
- // attributions since NVVM models it as `alloca`s in the default
- // memory space and does not support `alloca`s with addrspace(5).
- Type elementType = typeConverter->convertType(type.getElementType());
- auto ptrType =
- LLVM::LLVMPointerType::get(rewriter.getContext(), allocaAddrSpace);
- Value numElements = rewriter.create<LLVM::ConstantOp>(
- gpuFuncOp.getLoc(), int64Ty, type.getNumElements());
- uint64_t alignment = 0;
- if (auto alignAttr =
- dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr(
- idx, LLVM::LLVMDialect::getAlignAttrName())))
- alignment = alignAttr.getInt();
- Value allocated = rewriter.create<LLVM::AllocaOp>(
- gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment);
- Value descr = MemRefDescriptor::fromStaticShape(
- rewriter, loc, *getTypeConverter(), type, allocated);
- signatureConversion.remapInput(
- numProperArguments + numWorkgroupAttributions + idx, descr);
- }
- }
-
- // Move the region to the new function, update the entry block signature.
- rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
- llvmFuncOp.end());
- if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter,
- &signatureConversion)))
- return failure();
-
- // Get memref type from function arguments and set the noalias to
- // pointer arguments.
- for (const auto [idx, argTy] :
- llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
- auto remapping = signatureConversion.getInputMapping(idx);
- NamedAttrList argAttr =
- argAttrs ? cast<DictionaryAttr>(argAttrs[idx]) : NamedAttrList();
- auto copyAttribute = [&](StringRef attrName) {
- Attribute attr = argAttr.erase(attrName);
- if (!attr)
- return;
- for (size_t i = 0, e = remapping->size; i < e; ++i)
- llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
- };
- auto copyPointerAttribute = [&](StringRef attrName) {
- Attribute attr = argAttr.erase(attrName);
-
- if (!attr)
- return;
- if (remapping->size > 1 &&
- attrName == LLVM::LLVMDialect::getNoAliasAttrName()) {
- emitWarning(llvmFuncOp.getLoc(),
- "Cannot copy noalias with non-bare pointers.\n");
- return;
- }
- for (size_t i = 0, e = remapping->size; i < e; ++i) {
- if (isa<LLVM::LLVMPointerType>(
- llvmFuncOp.getArgument(remapping->inputNo + i).getType())) {
- llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
- }
- }
- };
-
- if (argAttr.empty())
- continue;
-
- copyAttribute(LLVM::LLVMDialect::getReturnedAttrName());
- copyAttribute(LLVM::LLVMDialect::getNoUndefAttrName());
- copyAttribute(LLVM::LLVMDialect::getInRegAttrName());
- bool lowersToPointer = false;
- for (size_t i = 0, e = remapping->size; i < e; ++i) {
- lowersToPointer |= isa<LLVM::LLVMPointerType>(
- llvmFuncOp.getArgument(remapping->inputNo + i).getType());
- }
-
- if (lowersToPointer) {
- copyPointerAttribute(LLVM::LLVMDialect::getNoAliasAttrName());
- copyPointerAttribute(LLVM::LLVMDialect::getNoCaptureAttrName());
- copyPointerAttribute(LLVM::LLVMDialect::getNoFreeAttrName());
- copyPointerAttribute(LLVM::LLVMDialect::getAlignAttrName());
- copyPointerAttribute(LLVM::LLVMDialect::getReadonlyAttrName());
- copyPointerAttribute(LLVM::LLVMDialect::getWriteOnlyAttrName());
- copyPointerAttribute(LLVM::LLVMDialect::getReadnoneAttrName());
- copyPointerAttribute(LLVM::LLVMDialect::getNonNullAttrName());
- copyPointerAttribute(LLVM::LLVMDialect::getDereferenceableAttrName());
- copyPointerAttribute(
- LLVM::LLVMDialect::getDereferenceableOrNullAttrName());
- copyPointerAttribute(
- LLVM::LLVMDialect::WorkgroupAttributionAttrHelper::getNameStr());
- }
- }
- rewriter.eraseOp(gpuFuncOp);
- return success();
+ return failure();
}
LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 512820b..f0b1602 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -723,8 +723,10 @@ LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
auto elementSize = getSizeInBytes(loc, elementType, rewriter);
- auto arguments = getTypeConverter()->promoteOperands(
- loc, op->getOperands(), adaptor.getOperands(), rewriter);
+ llvm_unreachable("TODO");
+ SmallVector<Value> arguments;
+ //auto arguments = getTypeConverter()->promoteOperands(
+ // loc, op->getOperands(), adaptor.getOperands(), rewriter);
arguments.push_back(elementSize);
hostRegisterCallBuilder.create(loc, rewriter, arguments);
@@ -745,8 +747,10 @@ LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
auto elementSize = getSizeInBytes(loc, elementType, rewriter);
- auto arguments = getTypeConverter()->promoteOperands(
- loc, op->getOperands(), adaptor.getOperands(), rewriter);
+ llvm_unreachable("TODO");
+ SmallVector<Value> arguments;
+ //auto arguments = getTypeConverter()->promoteOperands(
+ // loc, op->getOperands(), adaptor.getOperands(), rewriter);
arguments.push_back(elementSize);
hostUnregisterCallBuilder.create(loc, rewriter, arguments);
@@ -805,9 +809,9 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
if (allocOp.getAsyncToken()) {
// Async alloc: make dependent ops use the same stream.
- rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
+ //rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
} else {
- rewriter.replaceOp(allocOp, {memRefDescriptor});
+ //rewriter.replaceOp(allocOp, {memRefDescriptor});
}
return success();
@@ -977,9 +981,11 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
// Note: If `useBarePtrCallConv` is set in the type converter's options,
// the value of `kernelBarePtrCallConv` will be ignored.
OperandRange origArguments = launchOp.getKernelOperands();
- SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
- loc, origArguments, adaptor.getKernelOperands(), rewriter,
- /*useBarePtrCallConv=*/kernelBarePtrCallConv);
+ llvm_unreachable("TODO");
+ SmallVector<Value,8> llvmArguments;
+ //SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
+ // loc, origArguments, adaptor.getKernelOperands(), rewriter,
+ // /*useBarePtrCallConv=*/kernelBarePtrCallConv);
SmallVector<Value, 8> llvmArgumentsWithSizes;
// Intersperse size information if requested.
diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
index 86d6643..9f8030a 100644
--- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -21,19 +21,23 @@ using namespace mlir;
//===----------------------------------------------------------------------===//
/// Construct a helper for the given descriptor value.
-MemRefDescriptor::MemRefDescriptor(Value descriptor)
- : StructBuilder(descriptor) {
- assert(value != nullptr && "value cannot be null");
- indexType = cast<LLVM::LLVMStructType>(value.getType())
- .getBody()[kOffsetPosInMemRefDescriptor];
+MemRefDescriptor::MemRefDescriptor(ValueRange elements) : elements(elements) {
+ indexType = elements[kOffsetPosInMemRefDescriptor].getType();
}
/// Builds IR creating an `undef` value of the descriptor type.
MemRefDescriptor MemRefDescriptor::poison(OpBuilder &builder, Location loc,
- Type descriptorType) {
-
- Value descriptor = builder.create<LLVM::PoisonOp>(loc, descriptorType);
- return MemRefDescriptor(descriptor);
+ TypeRange descriptorTypes) {
+ DenseMap<Type, Value> poisonValues;
+ SmallVector<Value> elements;
+ for (Type t : descriptorTypes) {
+ auto it = poisonValues.find(t);
+ if (it == poisonValues.end()) {
+ poisonValues[t] = builder.create<LLVM::PoisonOp>(loc, t);
+ }
+ elements.push_back(poisonValues[t]);
+ }
+ return MemRefDescriptor(elements);
}
/// Builds IR creating a MemRef descriptor that represents `type` and
@@ -57,10 +61,11 @@ MemRefDescriptor MemRefDescriptor::fromStaticShape(
assert(!llvm::any_of(strides, ShapedType::isDynamic) &&
"expected static strides");
- auto convertedType = typeConverter.convertType(type);
- assert(convertedType && "unexpected failure in memref type conversion");
+ SmallVector<Type> convertedTypes;
+ LogicalResult status = typeConverter.convertType(type, convertedTypes);
+ assert(succeeded(status) && "unexpected failure in memref type conversion");
- auto descr = MemRefDescriptor::poison(builder, loc, convertedType);
+ auto descr = MemRefDescriptor::poison(builder, loc, convertedTypes);
descr.setAllocatedPtr(builder, loc, memory);
descr.setAlignedPtr(builder, loc, alignedMemory);
descr.setConstantOffset(builder, loc, offset);
@@ -73,26 +78,81 @@ MemRefDescriptor MemRefDescriptor::fromStaticShape(
return descr;
}
+static Value extractStructElement(OpBuilder &builder, Location loc,
+ Value packed, ArrayRef<int64_t> idx) {
+ return builder.create<LLVM::ExtractValueOp>(loc, packed, idx);
+}
+
+static Value insertStructElement(OpBuilder &builder, Location loc, Value packed,
+ Value val, ArrayRef<int64_t> idx) {
+ return builder.create<LLVM::InsertValueOp>(loc, packed, val, idx);
+}
+MemRefDescriptor MemRefDescriptor::fromPackedStruct(OpBuilder &builder,
+ Location loc,
+ Value packed) {
+ auto llvmStruct = cast<LLVM::LLVMStructType>(packed.getType());
+ SmallVector<Value> elements;
+ elements.push_back(extractStructElement(builder, loc, packed, 0));
+ elements.push_back(extractStructElement(builder, loc, packed, 1));
+ elements.push_back(extractStructElement(builder, loc, packed, 2));
+ if (llvmStruct.getBody().size() > 3) {
+ auto llvmArray = cast<LLVM::LLVMArrayType>(llvmStruct.getBody()[3]);
+ int64_t rank = llvmArray.getNumElements();
+ for (int i = 0; i < rank; ++i)
+ elements.push_back(extractStructElement(builder, loc, packed, {3, i}));
+ for (int i = 0; i < rank; ++i)
+ elements.push_back(extractStructElement(builder, loc, packed, {4, i}));
+ }
+ return MemRefDescriptor(elements);
+}
+
+Value MemRefDescriptor::packStruct(OpBuilder &builder, Location loc) {
+ Type offsetStrideTy = elements[2].getType();
+ SmallVector<Type> fields;
+ fields.push_back(elements[0].getType());
+ fields.push_back(elements[1].getType());
+ fields.push_back(offsetStrideTy);
+ if (getRank() > 0) {
+ auto llvmArray = LLVM::LLVMArrayType::get(builder.getContext(),
+ offsetStrideTy, getRank());
+ fields.push_back(llvmArray);
+ fields.push_back(llvmArray);
+ }
+ Value desc = builder.create<LLVM::UndefOp>(
+ loc, LLVM::LLVMStructType::getLiteral(builder.getContext(), fields));
+ desc = insertStructElement(builder, loc, desc, elements[0], 0);
+ desc = insertStructElement(builder, loc, desc, elements[1], 1);
+ desc = insertStructElement(builder, loc, desc, elements[2], 2);
+ if(getRank() > 0) {
+ for (int i = 0; i < getRank(); ++i)
+ desc = insertStructElement(builder, loc, desc, elements[3 + i], {3, i});
+ for (int i = 0; i < getRank(); ++i)
+ desc = insertStructElement(builder, loc, desc, elements[3 + getRank() + i],
+ {4, i});
+ }
+ return desc;
+}
+
/// Builds IR extracting the allocated pointer from the descriptor.
Value MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) {
- return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor);
+ return elements[kAllocatedPtrPosInMemRefDescriptor];
}
/// Builds IR inserting the allocated pointer into the descriptor.
void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
Value ptr) {
- setPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor, ptr);
+ elements[kAllocatedPtrPosInMemRefDescriptor] = ptr;
}
/// Builds IR extracting the aligned pointer from the descriptor.
Value MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) {
- return extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor);
+ return elements[kAlignedPtrPosInMemRefDescriptor];
}
/// Builds IR inserting the aligned pointer into the descriptor.
void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
Value ptr) {
- setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr);
+ elements[kAlignedPtrPosInMemRefDescriptor] = ptr;
}
// Creates a constant Op producing a value of `resultType` from an index-typed
@@ -105,28 +165,25 @@ static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
/// Builds IR extracting the offset from the descriptor.
Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) {
- return builder.create<LLVM::ExtractValueOp>(loc, value,
- kOffsetPosInMemRefDescriptor);
+ return elements[kOffsetPosInMemRefDescriptor];
}
/// Builds IR inserting the offset into the descriptor.
void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
Value offset) {
- value = builder.create<LLVM::InsertValueOp>(loc, value, offset,
- kOffsetPosInMemRefDescriptor);
+ elements[kOffsetPosInMemRefDescriptor] = offset;
}
/// Builds IR inserting the offset into the descriptor.
void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc,
uint64_t offset) {
- setOffset(builder, loc,
- createIndexAttrConstant(builder, loc, indexType, offset));
+ elements[kOffsetPosInMemRefDescriptor] =
+ createIndexAttrConstant(builder, loc, indexType, offset);
}
/// Builds IR extracting the pos-th size from the descriptor.
Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
- return builder.create<LLVM::ExtractValueOp>(
- loc, value, ArrayRef<int64_t>({kSizePosInMemRefDescriptor, pos}));
+ return elements[kSizePosInMemRefDescriptor + pos];
}
Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
@@ -137,8 +194,14 @@ Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
// Copy size values to stack-allocated memory.
auto one = createIndexAttrConstant(builder, loc, indexType, 1);
- auto sizes = builder.create<LLVM::ExtractValueOp>(
- loc, value, llvm::ArrayRef<int64_t>({kSizePosInMemRefDescriptor}));
+ SmallVector<Type> structElems(rank, indexType);
+ Value sizes = builder.create<LLVM::UndefOp>(
+ loc, LLVM::LLVMStructType::getLiteral(builder.getContext(), structElems));
+ ValueRange sizeVals =
+ ValueRange(elements).slice(kSizePosInMemRefDescriptor, rank);
+ for (auto it : llvm::enumerate(sizeVals))
+ sizes =
+ builder.create<LLVM::InsertValueOp>(loc, sizes, it.value(), it.index());
auto sizesPtr = builder.create<LLVM::AllocaOp>(loc, ptrTy, arrayTy, one,
/*alignment=*/0);
builder.create<LLVM::StoreOp>(loc, sizes, sizesPtr);
@@ -152,40 +215,35 @@ Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
/// Builds IR inserting the pos-th size into the descriptor
void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
Value size) {
- value = builder.create<LLVM::InsertValueOp>(
- loc, value, size, ArrayRef<int64_t>({kSizePosInMemRefDescriptor, pos}));
+ elements[kSizePosInMemRefDescriptor + pos] = size;
}
void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc,
unsigned pos, uint64_t size) {
- setSize(builder, loc, pos,
- createIndexAttrConstant(builder, loc, indexType, size));
+ elements[kSizePosInMemRefDescriptor + pos] =
+ createIndexAttrConstant(builder, loc, indexType, size);
}
/// Builds IR extracting the pos-th stride from the descriptor.
Value MemRefDescriptor::stride(OpBuilder &builder, Location loc, unsigned pos) {
- return builder.create<LLVM::ExtractValueOp>(
- loc, value, ArrayRef<int64_t>({kStridePosInMemRefDescriptor, pos}));
+ return elements[kSizePosInMemRefDescriptor + getRank() + pos];
}
/// Builds IR inserting the pos-th stride into the descriptor
void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
Value stride) {
- value = builder.create<LLVM::InsertValueOp>(
- loc, value, stride,
- ArrayRef<int64_t>({kStridePosInMemRefDescriptor, pos}));
+ elements[kSizePosInMemRefDescriptor + getRank() + pos] = stride;
}
void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc,
unsigned pos, uint64_t stride) {
- setStride(builder, loc, pos,
- createIndexAttrConstant(builder, loc, indexType, stride));
+ elements[kSizePosInMemRefDescriptor + getRank() + pos] =
+ createIndexAttrConstant(builder, loc, indexType, stride);
}
LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() {
return cast<LLVM::LLVMPointerType>(
- cast<LLVM::LLVMStructType>(value.getType())
- .getBody()[kAlignedPtrPosInMemRefDescriptor]);
+ elements[kAlignedPtrPosInMemRefDescriptor].getType());
}
Value MemRefDescriptor::bufferPtr(OpBuilder &builder, Location loc,
@@ -212,51 +270,6 @@ Value MemRefDescriptor::bufferPtr(OpBuilder &builder, Location loc,
return ptr;
}
-/// Creates a MemRef descriptor structure from a list of individual values
-/// composing that descriptor, in the following order:
-/// - allocated pointer;
-/// - aligned pointer;
-/// - offset;
-/// - <rank> sizes;
-/// - <rank> strides;
-/// where <rank> is the MemRef rank as provided in `type`.
-Value MemRefDescriptor::pack(OpBuilder &builder, Location loc,
- const LLVMTypeConverter &converter,
- MemRefType type, ValueRange values) {
- Type llvmType = converter.convertType(type);
- auto d = MemRefDescriptor::poison(builder, loc, llvmType);
-
- d.setAllocatedPtr(builder, loc, values[kAllocatedPtrPosInMemRefDescriptor]);
- d.setAlignedPtr(builder, loc, values[kAlignedPtrPosInMemRefDescriptor]);
- d.setOffset(builder, loc, values[kOffsetPosInMemRefDescriptor]);
-
- int64_t rank = type.getRank();
- for (unsigned i = 0; i < rank; ++i) {
- d.setSize(builder, loc, i, values[kSizePosInMemRefDescriptor + i]);
- d.setStride(builder, loc, i, values[kSizePosInMemRefDescriptor + rank + i]);
- }
-
- return d;
-}
-
-/// Builds IR extracting individual elements of a MemRef descriptor structure
-/// and returning them as `results` list.
-void MemRefDescriptor::unpack(OpBuilder &builder, Location loc, Value packed,
- MemRefType type,
- SmallVectorImpl<Value> &results) {
- int64_t rank = type.getRank();
- results.reserve(results.size() + getNumUnpackedValues(type));
-
- MemRefDescriptor d(packed);
- results.push_back(d.allocatedPtr(builder, loc));
- results.push_back(d.alignedPtr(builder, loc));
- results.push_back(d.offset(builder, loc));
- for (int64_t i = 0; i < rank; ++i)
- results.push_back(d.size(builder, loc, i));
- for (int64_t i = 0; i < rank; ++i)
- results.push_back(d.stride(builder, loc, i));
-}
-
/// Returns the number of non-aggregate values that would be produced by
/// `unpack`.
unsigned MemRefDescriptor::getNumUnpackedValues(MemRefType type) {
@@ -264,6 +277,8 @@ unsigned MemRefDescriptor::getNumUnpackedValues(MemRefType type) {
return 3 + 2 * type.getRank();
}
+int64_t MemRefDescriptor::getRank() { return (elements.size() - 3) / 2; }
+
//===----------------------------------------------------------------------===//
// MemRefDescriptorView implementation.
//===----------------------------------------------------------------------===//
@@ -296,57 +311,61 @@ Value MemRefDescriptorView::stride(unsigned pos) {
//===----------------------------------------------------------------------===//
/// Construct a helper for the given descriptor value.
-UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor)
- : StructBuilder(descriptor) {}
+UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(ValueRange elements)
+ : elements(elements) {}
/// Builds IR creating an `undef` value of the descriptor type.
-UnrankedMemRefDescriptor UnrankedMemRefDescriptor::poison(OpBuilder &builder,
- Location loc,
- Type descriptorType) {
- Value descriptor = builder.create<LLVM::PoisonOp>(loc, descriptorType);
- return UnrankedMemRefDescriptor(descriptor);
+UnrankedMemRefDescriptor
+UnrankedMemRefDescriptor::poison(OpBuilder &builder, Location loc,
+ TypeRange descriptorTypes) {
+ DenseMap<Type, Value> poisonValues;
+ SmallVector<Value> elements;
+ for (Type t : descriptorTypes) {
+ auto it = poisonValues.find(t);
+ if (it == poisonValues.end()) {
+ poisonValues[t] = builder.create<LLVM::PoisonOp>(loc, t);
+ }
+ elements.push_back(poisonValues[t]);
+ }
+ return UnrankedMemRefDescriptor(elements);
+}
+
+/// Builds IR extracting the two fields of a packed unranked MemRef descriptor
+/// struct and returns them wrapped in an `UnrankedMemRefDescriptor`.
+UnrankedMemRefDescriptor
+UnrankedMemRefDescriptor::fromPackedStruct(OpBuilder &builder, Location loc,
+ Value packed) {
+ SmallVector<Value> elements;
+ elements.push_back(extractStructElement(builder, loc, packed, 0));
+ elements.push_back(extractStructElement(builder, loc, packed, 1));
+ return UnrankedMemRefDescriptor(elements);
+}
+
+Value UnrankedMemRefDescriptor::packStruct(OpBuilder &builder, Location loc) {
+ SmallVector<Type> fields;
+ fields.push_back(elements[0].getType());
+ fields.push_back(elements[1].getType());
+ Value desc = builder.create<LLVM::UndefOp>(
+ loc, LLVM::LLVMStructType::getLiteral(builder.getContext(), fields));
+ desc = insertStructElement(builder, loc, desc, elements[0], 0);
+ desc = insertStructElement(builder, loc, desc, elements[1], 1);
+ return desc;
}
+
Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) const {
- return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor);
+ return elements[kRankInUnrankedMemRefDescriptor];
}
void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc,
Value v) {
- setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v);
+ elements[kRankInUnrankedMemRefDescriptor] = v;
}
Value UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder,
Location loc) const {
- return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor);
+ return elements[kPtrInUnrankedMemRefDescriptor];
}
void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder,
Location loc, Value v) {
- setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v);
-}
-
-/// Builds IR populating an unranked MemRef descriptor structure from a list
-/// of individual constituent values in the following order:
-/// - rank of the memref;
-/// - pointer to the memref descriptor.
-Value UnrankedMemRefDescriptor::pack(OpBuilder &builder, Location loc,
- const LLVMTypeConverter &converter,
- UnrankedMemRefType type,
- ValueRange values) {
- Type llvmType = converter.convertType(type);
- auto d = UnrankedMemRefDescriptor::poison(builder, loc, llvmType);
-
- d.setRank(builder, loc, values[kRankInUnrankedMemRefDescriptor]);
- d.setMemRefDescPtr(builder, loc, values[kPtrInUnrankedMemRefDescriptor]);
- return d;
-}
-
-/// Builds IR extracting individual elements that compose an unranked memref
-/// descriptor and returns them as `results` list.
-void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
- Value packed,
- SmallVectorImpl<Value> &results) {
- UnrankedMemRefDescriptor d(packed);
- results.reserve(results.size() + 2);
- results.push_back(d.rank(builder, loc));
- results.push_back(d.memRefDescPtr(builder, loc));
+ elements[kPtrInUnrankedMemRefDescriptor] = v;
}
void UnrankedMemRefDescriptor::computeSizes(
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 71b68619..c5af470 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -59,7 +59,7 @@ Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder,
}
Value ConvertToLLVMPattern::getStridedElementPtr(
- Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
+ Location loc, MemRefType type, ValueRange memRefDesc, ValueRange indices,
ConversionPatternRewriter &rewriter) const {
auto [strides, offset] = type.getStridesAndOffset();
@@ -217,34 +217,20 @@ MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
ArrayRef<Value> sizes, ArrayRef<Value> strides,
ConversionPatternRewriter &rewriter) const {
- auto structType = typeConverter->convertType(memRefType);
- auto memRefDescriptor = MemRefDescriptor::poison(rewriter, loc, structType);
-
- // Field 1: Allocated pointer, used for malloc/free.
- memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
-
- // Field 2: Actual aligned pointer to payload.
- memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
-
- // Field 3: Offset in aligned pointer.
+ SmallVector<Value> elements;
+ elements.push_back(allocatedPtr);
+ elements.push_back(alignedPtr);
Type indexType = getIndexType();
- memRefDescriptor.setOffset(
- rewriter, loc, createIndexAttrConstant(rewriter, loc, indexType, 0));
-
- // Fields 4: Sizes.
- for (const auto &en : llvm::enumerate(sizes))
- memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());
-
- // Field 5: Strides.
- for (const auto &en : llvm::enumerate(strides))
- memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());
-
- return memRefDescriptor;
+ elements.push_back(createIndexAttrConstant(rewriter, loc, indexType, 0));
+ llvm::append_range(elements, sizes);
+ llvm::append_range(elements, strides);
+ return MemRefDescriptor(elements);
}
LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
OpBuilder &builder, Location loc, TypeRange origTypes,
SmallVectorImpl<Value> &operands, bool toDynamic) const {
+ // TODO: Pass unpacked structs to this function.
assert(origTypes.size() == operands.size() &&
"expected as may original types as operands");
@@ -253,7 +239,7 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
SmallVector<unsigned> unrankedAddressSpaces;
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
- unrankedMemrefs.emplace_back(operands[i]);
+ unrankedMemrefs.push_back(UnrankedMemRefDescriptor::fromPackedStruct(builder, loc, operands[i]));
FailureOr<unsigned> addressSpace =
getTypeConverter()->getMemRefAddressSpace(memRefType);
if (failed(addressSpace))
@@ -294,7 +280,7 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
if (!isa<UnrankedMemRefType>(type))
continue;
Value allocationSize = sizes[unrankedMemrefPos++];
- UnrankedMemRefDescriptor desc(operands[i]);
+ UnrankedMemRefDescriptor desc = UnrankedMemRefDescriptor::fromPackedStruct(builder, loc, operands[i]);
// Allocate memory, copy, and free the source if necessary.
Value memory =
@@ -315,16 +301,15 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
// times, attempting to modify its pointer can lead to memory leaks
// (allocated twice and overwritten) or double frees (the caller does not
// know if the descriptor points to the same memory).
- Type descriptorType = getTypeConverter()->convertType(type);
- if (!descriptorType)
+ SmallVector<Type> descriptorTypes;
+ if (failed(getTypeConverter()->convertType(type, descriptorTypes)))
return failure();
auto updatedDesc =
- UnrankedMemRefDescriptor::poison(builder, loc, descriptorType);
+ UnrankedMemRefDescriptor::poison(builder, loc, descriptorTypes);
Value rank = desc.rank(builder, loc);
updatedDesc.setRank(builder, loc, rank);
updatedDesc.setMemRefDescPtr(builder, loc, memory);
-
- operands[i] = updatedDesc;
+ operands[i] = updatedDesc.packStruct(builder, loc);
}
return success();
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ea251e4..2113bd3 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -50,68 +50,6 @@ static bool isBarePointer(ValueRange values) {
isa<LLVM::LLVMPointerType>(values.front().getType());
}
-/// Pack SSA values into an unranked memref descriptor struct.
-static Value packUnrankedMemRefDesc(OpBuilder &builder,
- UnrankedMemRefType resultType,
- ValueRange inputs, Location loc,
- const LLVMTypeConverter &converter) {
- // Note: Bare pointers are not supported for unranked memrefs because a
- // memref descriptor cannot be built just from a bare pointer.
- if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
- return Value();
- return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType,
- inputs);
-}
-
-/// Pack SSA values into a ranked memref descriptor struct.
-static Value packRankedMemRefDesc(OpBuilder &builder, MemRefType resultType,
- ValueRange inputs, Location loc,
- const LLVMTypeConverter &converter) {
- assert(resultType && "expected non-null result type");
- if (isBarePointer(inputs))
- return MemRefDescriptor::fromStaticShape(builder, loc, converter,
- resultType, inputs[0]);
- if (TypeRange(inputs) ==
- converter.getMemRefDescriptorFields(resultType,
- /*unpackAggregates=*/true))
- return MemRefDescriptor::pack(builder, loc, converter, resultType, inputs);
- // The inputs are neither a bare pointer nor an unpacked memref descriptor.
- // This materialization function cannot be used.
- return Value();
-}
-
-/// MemRef descriptor elements -> UnrankedMemRefType
-static Value unrankedMemRefMaterialization(OpBuilder &builder,
- UnrankedMemRefType resultType,
- ValueRange inputs, Location loc,
- const LLVMTypeConverter &converter) {
- // A source materialization must return a value of type
- // `resultType`, so insert a cast from the memref descriptor type
- // (!llvm.struct) to the original memref type.
- Value packed =
- packUnrankedMemRefDesc(builder, resultType, inputs, loc, converter);
- if (!packed)
- return Value();
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
- .getResult(0);
-}
-
-/// MemRef descriptor elements -> MemRefType
-static Value rankedMemRefMaterialization(OpBuilder &builder,
- MemRefType resultType,
- ValueRange inputs, Location loc,
- const LLVMTypeConverter &converter) {
- // A source materialization must return a value of type `resultType`,
- // so insert a cast from the memref descriptor type (!llvm.struct) to the
- // original memref type.
- Value packed =
- packRankedMemRefDesc(builder, resultType, inputs, loc, converter);
- if (!packed)
- return Value();
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
- .getResult(0);
-}
-
/// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
const LowerToLLVMOptions &options,
@@ -126,9 +64,22 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
addConversion([&](FunctionType type) { return convertFunctionType(type); });
addConversion([&](IndexType type) { return convertIndexType(type); });
addConversion([&](IntegerType type) { return convertIntegerType(type); });
- addConversion([&](MemRefType type) { return convertMemRefType(type); });
addConversion(
- [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
+ [&](MemRefType type,
+ SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+ LogicalResult status = convertMemRefType(type, result);
+ if (failed(status))
+ return std::nullopt;
+ return success();
+ });
+ addConversion(
+ [&](UnrankedMemRefType type,
+ SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+ LogicalResult status = convertUnrankedMemRefType(type, result);
+ if (failed(status))
+ return std::nullopt;
+ return success();
+ });
addConversion([&](VectorType type) -> std::optional<Type> {
FailureOr<Type> llvmType = convertVectorType(type);
if (failed(llvmType))
@@ -228,42 +179,26 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});
- addTargetMaterialization([&](OpBuilder &builder, Type resultType,
- ValueRange inputs, Location loc) {
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
- .getResult(0);
+ addTargetMaterialization([&](OpBuilder &builder, TypeRange resultTypes,
+ ValueRange inputs,
+ Location loc) -> SmallVector<Value> {
+ auto castOp =
+ builder.create<UnrealizedConversionCastOp>(loc, resultTypes, inputs);
+ return llvm::map_to_vector(castOp.getResults(),
+ [](OpResult r) -> Value { return r; });
});
- // Source materializations convert from the new block argument types
- // (multiple SSA values that make up a memref descriptor) back to the
- // original block argument type.
- addSourceMaterialization([&](OpBuilder &builder,
- UnrankedMemRefType resultType, ValueRange inputs,
- Location loc) {
- return unrankedMemRefMaterialization(builder, resultType, inputs, loc,
- *this);
- });
addSourceMaterialization([&](OpBuilder &builder, MemRefType resultType,
ValueRange inputs, Location loc) {
- return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this);
- });
-
- // Bare pointer -> Packed MemRef descriptor
- addTargetMaterialization([&](OpBuilder &builder, Type resultType,
- ValueRange inputs, Location loc,
- Type originalType) -> Value {
- // The original MemRef type is required to build a MemRef descriptor
- // because the sizes/strides of the MemRef cannot be inferred from just the
- // bare pointer.
- if (!originalType)
- return Value();
- if (resultType != convertType(originalType))
- return Value();
- if (auto memrefType = dyn_cast<MemRefType>(originalType))
- return packRankedMemRefDesc(builder, memrefType, inputs, loc, *this);
- if (auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(originalType))
- return packUnrankedMemRefDesc(builder, unrankedMemrefType, inputs, loc,
- *this);
+ if (isBarePointer(inputs)) {
+ MemRefDescriptor desc = MemRefDescriptor::fromStaticShape(
+ builder, loc, *this, resultType, inputs[0]);
+ return builder
+ .create<UnrealizedConversionCastOp>(loc, resultType,
+ desc.getElements())
+ .getResult(0);
+ }
+ // Not a bare pointer: return a null Value and leave it to the default materialization, which creates an unrealized_conversion_cast.
return Value();
});
@@ -430,8 +365,10 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) const {
Type resultType = type.getNumResults() == 0
? LLVM::LLVMVoidType::get(&getContext())
: packFunctionResults(type.getResults());
- if (!resultType)
+ if (!resultType) {
+ llvm_unreachable("no result type!");
return {};
+ }
auto ptrType = LLVM::LLVMPointerType::get(type.getContext());
auto structType = dyn_cast<LLVM::LLVMStructType>(resultType);
@@ -443,9 +380,11 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) const {
}
for (Type t : type.getInputs()) {
- auto converted = convertType(t);
- if (!converted || !LLVM::isCompatibleType(converted))
+ auto converted = convertCallingConventionType(t);
+ if (!converted || !LLVM::isCompatibleType(converted)) {
+ llvm_unreachable("could not convert input!");
return {};
+ }
if (isa<MemRefType, UnrankedMemRefType>(t))
converted = ptrType;
inputs.push_back(converted);
@@ -533,14 +472,18 @@ LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type,
/// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
/// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
-Type LLVMTypeConverter::convertMemRefType(MemRefType type) const {
- // When converting a MemRefType to a struct with descriptor fields, do not
- // unpack the `sizes` and `strides` arrays.
- SmallVector<Type, 5> types =
- getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
- if (types.empty())
- return {};
- return LLVM::LLVMStructType::getLiteral(&getContext(), types);
+LogicalResult LLVMTypeConverter::convertMemRefType(
+ MemRefType type, SmallVectorImpl<Type> &result, bool packed) const {
+ SmallVector<Type, 5> fields =
+ getMemRefDescriptorFields(type, /*unpackAggregates=*/!packed);
+ if (fields.empty())
+ return failure();
+ if (packed) {
+ result.push_back(LLVM::LLVMStructType::getLiteral(&getContext(), fields));
+ } else {
+ llvm::append_range(result, fields);
+ }
+ return success();
}
/// Convert an unranked memref type into a list of non-aggregate LLVM IR types
@@ -563,12 +506,17 @@ unsigned LLVMTypeConverter::getUnrankedMemRefDescriptorSize(
llvm::divideCeil(getPointerBitwidth(space), 8);
}
-Type LLVMTypeConverter::convertUnrankedMemRefType(
- UnrankedMemRefType type) const {
+LogicalResult LLVMTypeConverter::convertUnrankedMemRefType(
+ UnrankedMemRefType type, SmallVectorImpl<Type> &result, bool packed) const {
if (!convertType(type.getElementType()))
- return {};
- return LLVM::LLVMStructType::getLiteral(&getContext(),
- getUnrankedMemRefDescriptorFields());
+ return failure(); // The element type has no conversion.
+ if (packed) { // Packed: append one literal struct of the unranked descriptor fields.
+ result.push_back(LLVM::LLVMStructType::getLiteral(
+ &getContext(), getUnrankedMemRefDescriptorFields()));
+ } else { // Unpacked: append the unranked descriptor fields individually.
+ llvm::append_range(result, getUnrankedMemRefDescriptorFields());
+ }
+ return success();
}
FailureOr<unsigned>
@@ -665,6 +613,20 @@ Type LLVMTypeConverter::convertCallingConventionType(
if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
return convertMemRefToBarePtr(memrefTy);
+ if (auto memrefTy = dyn_cast<MemRefType>(type)) {
+ SmallVector<Type> convertedType;
+ LogicalResult status = convertMemRefType(memrefTy, convertedType, true);
+ if (failed(status)) return Type();
+ return llvm::getSingleElement(convertedType);
+ }
+
+ if (auto unrankedMemrefTy = dyn_cast<UnrankedMemRefType>(type)) {
+ SmallVector<Type> convertedType;
+ LogicalResult status = convertUnrankedMemRefType(unrankedMemrefTy, convertedType, true);
+ if (failed(status)) return Type();
+ return llvm::getSingleElement(convertedType);
+ }
+
return convertType(type);
}
@@ -674,12 +636,15 @@ Type LLVMTypeConverter::convertCallingConventionType(
void LLVMTypeConverter::promoteBarePtrsToDescriptors(
ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
SmallVectorImpl<Value> &values) const {
- assert(stdTypes.size() == values.size() &&
- "The number of types and values doesn't match");
- for (unsigned i = 0, end = values.size(); i < end; ++i)
- if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i]))
- values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
- memrefTy, values[i]);
+ /*
+ assert(stdTypes.size() == values.size() &&
+ "The number of types and values doesn't match");
+ for (unsigned i = 0, end = values.size(); i < end; ++i)
+ if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i]))
+ values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
+ memrefTy, values[i]);
+ */
+ llvm_unreachable("not implemented"); // FIXME: the previous body is disabled above, so `values` is never updated; restore or remove it.
}
/// Convert a non-empty list of types of values produced by an operation into an
@@ -743,38 +708,27 @@ Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
SmallVector<Value, 4>
LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
- ValueRange operands, OpBuilder &builder,
+ ArrayRef<ValueRange> operands, OpBuilder &builder, // One ValueRange of converted values per entry of `opOperands`.
bool useBarePtrCallConv) const {
SmallVector<Value, 4> promotedOperands;
promotedOperands.reserve(operands.size());
useBarePtrCallConv |= options.useBarePtrCallConv;
for (auto it : llvm::zip(opOperands, operands)) {
auto operand = std::get<0>(it);
- auto llvmOperand = std::get<1>(it);
+ auto llvmOperands = std::get<1>(it); // Converted value(s) paired with `operand`.
if (useBarePtrCallConv) {
// For the bare-ptr calling convention, we only have to extract the
// aligned pointer of a memref.
if (dyn_cast<MemRefType>(operand.getType())) {
- MemRefDescriptor desc(llvmOperand);
- llvmOperand = desc.alignedPtr(builder, loc);
+ MemRefDescriptor desc(llvmOperands);
+ promotedOperands.push_back(desc.alignedPtr(builder, loc));
+ continue; // Skip the generic append at the end of the loop body.
} else if (isa<UnrankedMemRefType>(operand.getType())) {
llvm_unreachable("Unranked memrefs are not supported");
}
- } else {
- if (isa<UnrankedMemRefType>(operand.getType())) {
- UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
- promotedOperands);
- continue;
- }
- if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
- MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType,
- promotedOperands);
- continue;
- }
}
-
- promotedOperands.push_back(llvmOperand);
+ llvm::append_range(promotedOperands, llvmOperands); // Forward every converted value of this operand unchanged.
}
return promotedOperands;
}
diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
index c5b2e83..c072723 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
@@ -195,6 +195,6 @@ LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
// Return the final value of the descriptor.
- rewriter.replaceOp(op, {memRefDescriptor});
+ rewriter.replaceOpWithMultiple(op, {memRefDescriptor});
return success();
}
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index cb4317e..a12507b 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -25,6 +25,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/MathExtras.h"
#include <optional>
@@ -185,15 +186,14 @@ struct AssumeAlignmentOpLowering
: ConvertOpToLLVMPattern<memref::AssumeAlignmentOp>(converter) {}
LogicalResult
- matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
+ matchAndRewrite(memref::AssumeAlignmentOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Value memref = adaptor.getMemref();
unsigned alignment = op.getAlignment();
auto loc = op.getLoc();
auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
- Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
- rewriter);
+ Value ptr = getStridedElementPtr(loc, srcMemRefType, adaptor.getMemref(),
+ /*indices=*/{}, rewriter);
// Emit llvm.assume(true) ["align"(memref, alignment)].
// This is more direct than ptrtoint-based checks, is explicitly supported,
@@ -220,7 +220,7 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
: ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
LogicalResult
- matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
+ matchAndRewrite(memref::DeallocOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Insert the `free` declaration if it is not already present.
FailureOr<LLVM::LLVMFuncOp> freeFunc =
@@ -253,21 +253,20 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
+ matchAndRewrite(memref::DimOp dimOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type operandType = dimOp.getSource().getType();
if (isa<UnrankedMemRefType>(operandType)) {
- FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
- operandType, dimOp, adaptor.getOperands(), rewriter);
+ FailureOr<Value> extractedSize =
+ extractSizeOfUnrankedMemRef(operandType, dimOp, adaptor, rewriter);
if (failed(extractedSize))
return failure();
rewriter.replaceOp(dimOp, {*extractedSize});
return success();
}
if (isa<MemRefType>(operandType)) {
- rewriter.replaceOp(
- dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
- adaptor.getOperands(), rewriter)});
+ rewriter.replaceOp(dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
+ adaptor, rewriter)});
return success();
}
llvm_unreachable("expected MemRefType or UnrankedMemRefType");
@@ -276,7 +275,7 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
private:
FailureOr<Value>
extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
- OpAdaptor adaptor,
+ OneToNOpAdaptor &adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = dimOp.getLoc();
@@ -298,20 +297,24 @@ private:
UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
- Type elementType = typeConverter->convertType(scalarMemRefType);
+ SmallVector<Type> convertedMemRefType;
+ if (failed(static_cast<const LLVMTypeConverter *>(typeConverter)
+ ->convertMemRefType(scalarMemRefType, convertedMemRefType,
+ /*packed=*/true)))
+ return failure();
// Get pointer to offset field of memref<element_type> descriptor.
auto indexPtrTy =
LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
Value offsetPtr = rewriter.create<LLVM::GEPOp>(
- loc, indexPtrTy, elementType, underlyingRankedDesc,
- ArrayRef<LLVM::GEPArg>{0, 2});
+ loc, indexPtrTy, llvm::getSingleElement(convertedMemRefType),
+ underlyingRankedDesc, ArrayRef<LLVM::GEPArg>{0, 2});
// The size value that we have to extract can be obtained using GEPop with
// `dimOp.index() + 1` index argument.
Value idxPlusOne = rewriter.create<LLVM::AddOp>(
loc, createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
- adaptor.getIndex());
+ llvm::getSingleElement(adaptor.getIndex()));
Value sizePtr = rewriter.create<LLVM::GEPOp>(
loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
idxPlusOne);
@@ -331,7 +334,7 @@ private:
}
Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
- OpAdaptor adaptor,
+ OneToNOpAdaptor &adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = dimOp.getLoc();
@@ -351,7 +354,7 @@ private:
return createIndexAttrConstant(rewriter, loc, indexType, dimSize);
}
}
- Value index = adaptor.getIndex();
+ Value index = llvm::getSingleElement(adaptor.getIndex());
int64_t rank = memRefType.getRank();
MemRefDescriptor memrefDescriptor(adaptor.getSource());
return memrefDescriptor.size(rewriter, loc, index, rank);
@@ -400,7 +403,7 @@ struct GenericAtomicRMWOpLowering
using Base::Base;
LogicalResult
- matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
+ matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = atomicOp.getLoc();
Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
@@ -416,8 +419,12 @@ struct GenericAtomicRMWOpLowering
// Compute the loaded value and branch to the loop block.
rewriter.setInsertionPointToEnd(initBlock);
auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
+ SmallVector<Value> indices =
+ llvm::map_to_vector(adaptor.getIndices(), [](ValueRange r) {
+ return llvm::getSingleElement(r);
+ });
auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
- adaptor.getIndices(), rewriter);
+ indices, rewriter);
Value init = rewriter.create<LLVM::LoadOp>(
loc, typeConverter->convertType(memRefType.getElementType()), dataPtr);
rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
@@ -579,13 +586,15 @@ struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
using Base::Base;
LogicalResult
- matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
+ matchAndRewrite(memref::LoadOp loadOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto type = loadOp.getMemRefType();
-
- Value dataPtr =
- getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
- adaptor.getIndices(), rewriter);
+ SmallVector<Value> indices =
+ llvm::map_to_vector(adaptor.getIndices(), [](ValueRange r) {
+ return llvm::getSingleElement(r);
+ });
+ Value dataPtr = getStridedElementPtr(
+ loadOp.getLoc(), type, adaptor.getMemref(), indices, rewriter);
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0,
false, loadOp.getNontemporal());
@@ -599,14 +608,18 @@ struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
using Base::Base;
LogicalResult
- matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
+ matchAndRewrite(memref::StoreOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto type = op.getMemRefType();
-
+ SmallVector<Value> indices = // Flatten the per-index ValueRanges into one Value each.
+ llvm::map_to_vector(adaptor.getIndices(), [](ValueRange r) {
+ return llvm::getSingleElement(r);
+ });
Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
- adaptor.getIndices(), rewriter);
- rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
- 0, false, op.getNontemporal());
+ indices, rewriter);
+ rewriter.replaceOpWithNewOp<LLVM::StoreOp>(
+ op, llvm::getSingleElement(adaptor.getValue()), dataPtr, 0, false, // The stored value is taken as a single converted Value.
+ op.getNontemporal());
return success();
}
};
@@ -617,13 +630,16 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
using Base::Base;
LogicalResult
- matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
+ matchAndRewrite(memref::PrefetchOp prefetchOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto type = prefetchOp.getMemRefType();
auto loc = prefetchOp.getLoc();
-
- Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
- adaptor.getIndices(), rewriter);
+ SmallVector<Value> indices =
+ llvm::map_to_vector(adaptor.getIndices(), [](ValueRange r) {
+ return llvm::getSingleElement(r);
+ });
+ Value dataPtr =
+ getStridedElementPtr(loc, type, adaptor.getMemref(), indices, rewriter);
// Replace with llvm.prefetch.
IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
@@ -640,7 +656,7 @@ struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
+ matchAndRewrite(memref::RankOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Type operandType = op.getMemref().getType();
@@ -664,8 +680,9 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
+ matchAndRewrite(memref::CastOp memRefCastOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ auto loc = memRefCastOp.getLoc();
Type srcType = memRefCastOp.getOperand().getType();
Type dstType = memRefCastOp.getType();
@@ -674,21 +691,21 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
// and require source and result type to have the same rank. Therefore,
// perform a sanity check that the underlying structs are the same. Once op
// semantics are relaxed we can revisit.
+ SmallVector<Type> convertedSrc, convertedDst;
+ if (failed(typeConverter->convertType(srcType, convertedSrc)) ||
+ failed(typeConverter->convertType(dstType, convertedDst)))
+ return failure();
if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
- if (typeConverter->convertType(srcType) !=
- typeConverter->convertType(dstType))
+ if (!llvm::equal(convertedSrc, convertedDst))
return failure();
// Unranked to unranked cast is disallowed
if (isa<UnrankedMemRefType>(srcType) && isa<UnrankedMemRefType>(dstType))
return failure();
- auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
- auto loc = memRefCastOp.getLoc();
-
// For ranked/ranked case, just keep the original descriptor.
if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType)) {
- rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
+ rewriter.replaceOpWithMultiple(memRefCastOp, {adaptor.getSource()});
return success();
}
@@ -701,19 +718,20 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
int64_t rank = srcMemRefType.getRank();
// ptr = AllocaOp sizeof(MemRefDescriptor)
auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
- loc, adaptor.getSource(), rewriter);
+ loc, MemRefDescriptor(adaptor.getSource()).packStruct(rewriter, loc),
+ rewriter);
// rank = ConstantOp srcRank
auto rankVal = rewriter.create<LLVM::ConstantOp>(
loc, getIndexType(), rewriter.getIndexAttr(rank));
// poison = PoisonOp
UnrankedMemRefDescriptor memRefDesc =
- UnrankedMemRefDescriptor::poison(rewriter, loc, targetStructType);
+ UnrankedMemRefDescriptor::poison(rewriter, loc, convertedDst);
// d1 = InsertValueOp poison, rank, 0
memRefDesc.setRank(rewriter, loc, rankVal);
// d2 = InsertValueOp d1, ptr, 1
memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
- rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
+ rewriter.replaceOpWithMultiple(memRefCastOp, {memRefDesc});
} else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
// Casting from unranked type to ranked.
@@ -722,10 +740,16 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
// ptr = ExtractValueOp src, 1
auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
-
// struct = LoadOp ptr
- auto loadOp = rewriter.create<LLVM::LoadOp>(loc, targetStructType, ptr);
- rewriter.replaceOp(memRefCastOp, loadOp.getResult());
+ SmallVector<Type> targetStructType;
+ if (failed(getTypeConverter()->convertMemRefType(
+ cast<MemRefType>(dstType), targetStructType, /*packed=*/true)))
+ return failure();
+ auto loadOp = rewriter.create<LLVM::LoadOp>(
+ loc, llvm::getSingleElement(targetStructType), ptr);
+ rewriter.replaceOpWithMultiple(memRefCastOp,
+ {MemRefDescriptor::fromPackedStruct(
+ rewriter, loc, loadOp.getResult())});
} else {
llvm_unreachable("Unsupported unranked memref to unranked memref cast");
}
@@ -743,7 +767,7 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
LogicalResult
- lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
+ lowerToMemCopyIntrinsic(memref::CopyOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
auto srcType = dyn_cast<MemRefType>(op.getSource().getType());
@@ -782,74 +806,75 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
return success();
}
- LogicalResult
- lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- auto loc = op.getLoc();
- auto srcType = cast<BaseMemRefType>(op.getSource().getType());
- auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
-
- // First make sure we have an unranked memref descriptor representation.
- auto makeUnranked = [&, this](Value ranked, MemRefType type) {
- auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
- type.getRank());
- auto *typeConverter = getTypeConverter();
- auto ptr =
- typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
-
- auto unrankedType =
- UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
- return UnrankedMemRefDescriptor::pack(
- rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
- };
-
- // Save stack position before promoting descriptors
- auto stackSaveOp =
- rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
-
- auto srcMemRefType = dyn_cast<MemRefType>(srcType);
- Value unrankedSource =
- srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
- : adaptor.getSource();
- auto targetMemRefType = dyn_cast<MemRefType>(targetType);
- Value unrankedTarget =
- targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
- : adaptor.getTarget();
-
- // Now promote the unranked descriptors to the stack.
- auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
- rewriter.getIndexAttr(1));
- auto promote = [&](Value desc) {
- auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
- auto allocated =
- rewriter.create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
- rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
- return allocated;
- };
-
- auto sourcePtr = promote(unrankedSource);
- auto targetPtr = promote(unrankedTarget);
-
- // Derive size from llvm.getelementptr which will account for any
- // potential alignment
- auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
- auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
- op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
- if (failed(copyFn))
- return failure();
- rewriter.create<LLVM::CallOp>(loc, copyFn.value(),
- ValueRange{elemSize, sourcePtr, targetPtr});
-
- // Restore stack used for descriptors
- rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
+ /*
+ LogicalResult
+ lowerToMemCopyFunctionCall(memref::CopyOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ auto loc = op.getLoc();
+ auto srcType = cast<BaseMemRefType>(op.getSource().getType());
+ auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
+
+ // First make sure we have an unranked memref descriptor representation.
+ auto makeUnranked = [&, this](Value ranked, MemRefType type) {
+ auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
+ type.getRank());
+ auto *typeConverter = getTypeConverter();
+ auto ptr =
+ typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
+
+ auto unrankedType =
+ UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
+ return UnrankedMemRefDescriptor::pack(
+ rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
+ };
+
+ // Save stack position before promoting descriptors
+ auto stackSaveOp =
+ rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
+
+ auto srcMemRefType = dyn_cast<MemRefType>(srcType);
+ Value unrankedSource =
+ srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
+ : adaptor.getSource();
+ auto targetMemRefType = dyn_cast<MemRefType>(targetType);
+ Value unrankedTarget =
+ targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
+ : adaptor.getTarget();
+
+ // Now promote the unranked descriptors to the stack.
+ auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
+ rewriter.getIndexAttr(1));
+ auto promote = [&](Value desc) {
+ auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
+ auto allocated =
+ rewriter.create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
+ rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
+ return allocated;
+ };
+
+ auto sourcePtr = promote(unrankedSource);
+ auto targetPtr = promote(unrankedTarget);
+
+ // Derive size from llvm.getelementptr which will account for any
+ // potential alignment
+ auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
+ auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
+ op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
+ if (failed(copyFn))
+ return failure();
+ rewriter.create<LLVM::CallOp>(loc, copyFn.value(),
+ ValueRange{elemSize, sourcePtr, targetPtr});
- rewriter.eraseOp(op);
+ // Restore stack used for descriptors
+ rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
- return success();
- }
+ rewriter.eraseOp(op);
+ return success();
+ }
+ */
LogicalResult
- matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
+ matchAndRewrite(memref::CopyOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto srcType = cast<BaseMemRefType>(op.getSource().getType());
auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
@@ -868,7 +893,8 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
- return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
+ return failure();
+ // return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
}
};
@@ -878,26 +904,23 @@ struct MemorySpaceCastOpLowering
memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
+ matchAndRewrite(memref::MemorySpaceCastOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
-
Type resultType = op.getDest().getType();
+ SmallVector<Type> convertedResultTypes;
+ if (failed(typeConverter->convertType(resultType, convertedResultTypes)))
+ return failure();
+
if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
- auto resultDescType =
- cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
- Type newPtrType = resultDescType.getBody()[0];
+ Type newPtrType = convertedResultTypes[0];
- SmallVector<Value> descVals;
- MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
- descVals);
+ SmallVector<Value> descVals = llvm::to_vector(adaptor.getSource());
descVals[0] =
rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
descVals[1] =
rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
- Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
- resultTypeR, descVals);
- rewriter.replaceOp(op, result);
+ rewriter.replaceOpWithMultiple(op, {descVals});
return success();
}
if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
@@ -922,8 +945,8 @@ struct MemorySpaceCastOpLowering
Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);
// Create and allocate storage for new memref descriptor.
- auto result = UnrankedMemRefDescriptor::poison(
- rewriter, loc, typeConverter->convertType(resultTypeU));
+ auto result =
+ UnrankedMemRefDescriptor::poison(rewriter, loc, convertedResultTypes);
result.setRank(rewriter, loc, rank);
SmallVector<Value, 1> sizes;
UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
@@ -972,7 +995,7 @@ struct MemorySpaceCastOpLowering
rewriter.create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
copySize, /*isVolatile=*/false);
- rewriter.replaceOp(op, ValueRange{result});
+ rewriter.replaceOpWithMultiple(op, ValueRange{result});
return success();
}
return rewriter.notifyMatchFailure(loc, "unexpected memref type");
@@ -986,7 +1009,7 @@ static void extractPointersAndOffset(Location loc,
ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &typeConverter,
Value originalOperand,
- Value convertedOperand,
+ ValueRange convertedOperand,
Value *allocatedPtr, Value *alignedPtr,
Value *offset = nullptr) {
Type operandType = originalOperand.getType();
@@ -1026,33 +1049,32 @@ struct MemRefReinterpretCastOpLowering
memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
+ matchAndRewrite(memref::ReinterpretCastOp castOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type srcType = castOp.getSource().getType();
- Value descriptor;
+ SmallVector<Value> descriptor; // Out-parameter filled by convertSourceMemRefToDescriptor below.
if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
adaptor, &descriptor)))
return failure();
- rewriter.replaceOp(castOp, {descriptor});
+ rewriter.replaceOpWithMultiple(castOp, {descriptor}); // The single op result is replaced by all collected values.
return success();
}
private:
LogicalResult convertSourceMemRefToDescriptor(
ConversionPatternRewriter &rewriter, Type srcType,
- memref::ReinterpretCastOp castOp,
- memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
+ memref::ReinterpretCastOp castOp, OneToNOpAdaptor adaptor,
+ SmallVector<Value> *descriptor) const {
MemRefType targetMemRefType =
cast<MemRefType>(castOp.getResult().getType());
- auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
- typeConverter->convertType(targetMemRefType));
- if (!llvmTargetDescriptorTy)
+ SmallVector<Type> convertedTypes;
+ if (failed(typeConverter->convertType(targetMemRefType, convertedTypes)))
return failure();
// Create descriptor.
Location loc = castOp.getLoc();
- auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
+ auto desc = MemRefDescriptor::poison(rewriter, loc, convertedTypes);
// Set allocated and aligned pointers.
Value allocatedPtr, alignedPtr;
@@ -1064,7 +1086,8 @@ private:
// Set offset.
if (castOp.isDynamicOffset(0))
- desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
+ desc.setOffset(rewriter, loc,
+ llvm::getSingleElement(adaptor.getOffsets()[0]));
else
desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
@@ -1073,16 +1096,19 @@ private:
unsigned dynStrideId = 0;
for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
if (castOp.isDynamicSize(i))
- desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
+ desc.setSize(rewriter, loc, i,
+ llvm::getSingleElement(adaptor.getSizes()[dynSizeId++]));
else
desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
if (castOp.isDynamicStride(i))
- desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
+ desc.setStride(
+ rewriter, loc, i,
+ llvm::getSingleElement(adaptor.getStrides()[dynStrideId++]));
else
desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
}
- *descriptor = desc;
+ llvm::append_range(*descriptor, desc.getElements());
return success();
}
};
@@ -1092,15 +1118,15 @@ struct MemRefReshapeOpLowering
using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
+ matchAndRewrite(memref::ReshapeOp reshapeOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type srcType = reshapeOp.getSource().getType();
- Value descriptor;
+ SmallVector<Value> descriptor; // Out-parameter filled by convertSourceMemRefToDescriptor below.
if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
adaptor, &descriptor)))
return failure();
- rewriter.replaceOp(reshapeOp, {descriptor});
+ rewriter.replaceOpWithMultiple(reshapeOp, {descriptor}); // The single op result is replaced by all collected values.
return success();
}
@@ -1108,21 +1134,19 @@ private:
LogicalResult
convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
Type srcType, memref::ReshapeOp reshapeOp,
- memref::ReshapeOp::Adaptor adaptor,
- Value *descriptor) const {
+ OneToNOpAdaptor adaptor,
+ SmallVector<Value> *descriptor) const {
auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
if (shapeMemRefType.hasStaticShape()) {
MemRefType targetMemRefType =
cast<MemRefType>(reshapeOp.getResult().getType());
- auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
- typeConverter->convertType(targetMemRefType));
- if (!llvmTargetDescriptorTy)
+ SmallVector<Type> convertedTypes;
+ if (failed(typeConverter->convertType(targetMemRefType, convertedTypes)))
return failure();
// Create descriptor.
Location loc = reshapeOp.getLoc();
- auto desc =
- MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
+ auto desc = MemRefDescriptor::poison(rewriter, loc, convertedTypes);
// Set allocated and aligned pointers.
Value allocatedPtr, alignedPtr;
@@ -1188,7 +1212,7 @@ private:
stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
}
- *descriptor = desc;
+ llvm::append_range(*descriptor, desc.getElements());
return success();
}
@@ -1204,8 +1228,11 @@ private:
// Create the unranked memref descriptor that holds the ranked one. The
// inner descriptor is allocated on stack.
+ SmallVector<Type> convertedTypes;
+ if (failed(typeConverter->convertType(targetType, convertedTypes)))
+ return failure();
auto targetDesc = UnrankedMemRefDescriptor::poison(
- rewriter, loc, typeConverter->convertType(targetType));
+ rewriter, loc, convertedTypes);
targetDesc.setRank(rewriter, loc, resultRank);
SmallVector<Value, 4> sizes;
UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
@@ -1303,7 +1330,7 @@ private:
// Reset position to beginning of new remainder block.
rewriter.setInsertionPointToStart(remainder);
- *descriptor = targetDesc;
+ llvm::append_range(*descriptor, targetDesc.getElements());
return success();
}
};
@@ -1315,10 +1342,11 @@ class ReassociatingReshapeOpConversion
: public ConvertOpToLLVMPattern<ReshapeOp> {
public:
using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
- using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
+ using ReshapeOpAdaptor =
+ typename ConvertOpToLLVMPattern<ReshapeOp>::OneToNOpAdaptor;
LogicalResult
- matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
+ matchAndRewrite(ReshapeOp reshapeOp, ReshapeOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return rewriter.notifyMatchFailure(
reshapeOp,
@@ -1332,7 +1360,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
+ matchAndRewrite(memref::SubViewOp subViewOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return rewriter.notifyMatchFailure(
subViewOp, "subview operations should have been expanded beforehand");
@@ -1351,7 +1379,7 @@ public:
using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
+ matchAndRewrite(memref::TransposeOp transposeOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = transposeOp.getLoc();
MemRefDescriptor viewMemRef(adaptor.getIn());
@@ -1360,9 +1388,11 @@ public:
if (transposeOp.getPermutation().isIdentity())
return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
- auto targetMemRef = MemRefDescriptor::poison(
- rewriter, loc,
- typeConverter->convertType(transposeOp.getIn().getType()));
+ SmallVector<Type> convertedTypes;
+ if (failed(typeConverter->convertType(transposeOp.getIn().getType(),
+ convertedTypes)))
+ return failure();
+ auto targetMemRef = MemRefDescriptor::poison(rewriter, loc, convertedTypes);
// Copy the base and aligned pointers from the old descriptor to the new
// one.
@@ -1388,7 +1418,7 @@ public:
viewMemRef.stride(rewriter, loc, sourcePos));
}
- rewriter.replaceOp(transposeOp, {targetMemRef});
+ rewriter.replaceOpWithMultiple(transposeOp, {targetMemRef});
return success();
}
};
@@ -1434,17 +1464,19 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
}
LogicalResult
- matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
+ matchAndRewrite(memref::ViewOp viewOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = viewOp.getLoc();
auto viewMemRefType = viewOp.getType();
auto targetElementTy =
typeConverter->convertType(viewMemRefType.getElementType());
- auto targetDescTy = typeConverter->convertType(viewMemRefType);
- if (!targetDescTy || !targetElementTy ||
- !LLVM::isCompatibleType(targetElementTy) ||
- !LLVM::isCompatibleType(targetDescTy))
+ SmallVector<Type> targetDescTy;
+ if (failed(typeConverter->convertType(viewMemRefType, targetDescTy)))
+ return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
+ failure();
+ // TODO: Check targetDescTy is LLVM compatible.
+ if (!targetElementTy || !LLVM::isCompatibleType(targetElementTy))
return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
failure();
@@ -1475,7 +1507,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
alignedPtr = rewriter.create<LLVM::GEPOp>(
loc, alignedPtr.getType(),
typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
- adaptor.getByteShift());
+ llvm::getSingleElement(adaptor.getByteShift()));
targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);
@@ -1493,10 +1525,14 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
// Fields 4 and 5: Update sizes and strides.
Value stride = nullptr, nextSize = nullptr;
+ SmallVector<Value> sizes =
+ llvm::map_to_vector(adaptor.getSizes(), [](ValueRange r) {
+ return llvm::getSingleElement(r);
+ });
for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
// Update size.
- Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
- adaptor.getSizes(), i, indexType);
+ Value size = getSize(rewriter, loc, viewMemRefType.getShape(), sizes, i,
+ indexType);
targetMemRef.setSize(rewriter, loc, i, size);
// Update stride.
stride =
@@ -1505,7 +1541,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
nextSize = size;
}
- rewriter.replaceOp(viewOp, {targetMemRef});
+ rewriter.replaceOpWithMultiple(viewOp, {targetMemRef.getElements()});
return success();
}
};
@@ -1551,7 +1587,7 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
using Base::Base;
LogicalResult
- matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
+ matchAndRewrite(memref::AtomicRMWOp atomicOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto maybeKind = matchSimpleAtomicOp(atomicOp);
if (!maybeKind)
@@ -1561,11 +1597,15 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
int64_t offset;
if (failed(memRefType.getStridesAndOffset(strides, offset)))
return failure();
- auto dataPtr =
- getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
- adaptor.getIndices(), rewriter);
+ SmallVector<Value> indices =
+ llvm::map_to_vector(adaptor.getIndices(), [](ValueRange r) {
+ return llvm::getSingleElement(r);
+ });
+ auto dataPtr = getStridedElementPtr(atomicOp.getLoc(), memRefType,
+ adaptor.getMemref(), indices, rewriter);
rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
- atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
+ atomicOp, *maybeKind, dataPtr,
+ llvm::getSingleElement(adaptor.getValue()),
LLVM::AtomicOrdering::acq_rel);
return success();
}
@@ -1580,7 +1620,7 @@ public:
LogicalResult
matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
- OpAdaptor adaptor,
+ OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
BaseMemRefType sourceTy = extractOp.getSource().getType();
@@ -1616,12 +1656,8 @@ public:
LogicalResult
matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
- OpAdaptor adaptor,
+ OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
-
- if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
- return failure();
-
// Create the descriptor.
MemRefDescriptor sourceMemRef(adaptor.getSource());
Location loc = extractStridedMetadataOp.getLoc();
@@ -1629,7 +1665,7 @@ public:
auto sourceMemRefType = cast<MemRefType>(source.getType());
int64_t rank = sourceMemRefType.getRank();
- SmallVector<Value> results;
+ SmallVector<ValueRange> results;
results.reserve(2 + rank * 2);
// Base buffer.
@@ -1639,19 +1675,11 @@ public:
rewriter, loc, *getTypeConverter(),
cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
baseBuffer, alignedBuffer);
- results.push_back((Value)dstMemRef);
-
- // Offset.
- results.push_back(sourceMemRef.offset(rewriter, loc));
-
- // Sizes.
- for (unsigned i = 0; i < rank; ++i)
- results.push_back(sourceMemRef.size(rewriter, loc, i));
- // Strides.
- for (unsigned i = 0; i < rank; ++i)
- results.push_back(sourceMemRef.stride(rewriter, loc, i));
-
- rewriter.replaceOp(extractStridedMetadataOp, results);
+ results.push_back(dstMemRef.getElements());
+ // Offset, sizes, strides of the source memref.
+ for (size_t i = 2, e = sourceMemRef.getElements().size(); i < e; ++i)
+ results.push_back(sourceMemRef.getElements().slice(i, 1));
+ rewriter.replaceOpWithMultiple(extractStridedMetadataOp, results);
return success();
}
};
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 51507c6..4613b90 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1179,8 +1179,10 @@ struct NVGPUTmaCreateDescriptorOpLowering
Value tensorElementType =
elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
- auto promotedOperands = getTypeConverter()->promoteOperands(
- b.getLoc(), op->getOperands(), adaptor.getOperands(), b);
+ llvm_unreachable("TODO");
+ SmallVector<Value> promotedOperands;
+ //auto promotedOperands = getTypeConverter()->promoteOperands(
+ // b.getLoc(), op->getOperands(), adaptor.getOperands(), b);
Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
makeI64Const(b, 5));
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 213f737..23525cc 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -95,7 +95,7 @@ static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
// Add an index vector component to a base pointer.
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
const LLVMTypeConverter &typeConverter,
- MemRefType memRefType, Value llvmMemref, Value base,
+ MemRefType memRefType, ValueRange llvmMemref, Value base,
Value index, VectorType vectorType) {
assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
"unsupported memref type");
@@ -185,8 +185,9 @@ public:
/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload and vector.maskedstore with their respective LLVM
/// couterparts.
+template<typename Adaptor>
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
- vector::LoadOpAdaptor adaptor,
+ Adaptor adaptor,
VectorType vectorTy, Value ptr, unsigned align,
ConversionPatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align,
@@ -194,29 +195,32 @@ static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
loadOp.getNontemporal());
}
+template<typename Adaptor>
static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
- vector::MaskedLoadOpAdaptor adaptor,
+ Adaptor adaptor,
VectorType vectorTy, Value ptr, unsigned align,
ConversionPatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
- loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
+ loadOp, vectorTy, ptr, llvm::getSingleElement(adaptor.getMask()), llvm::getSingleElement(adaptor.getPassThru()), align);
}
+template<typename Adaptor>
static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
- vector::StoreOpAdaptor adaptor,
+ Adaptor adaptor,
VectorType vectorTy, Value ptr, unsigned align,
ConversionPatternRewriter &rewriter) {
- rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
+ rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, llvm::getSingleElement(adaptor.getValueToStore()),
ptr, align, /*volatile_=*/false,
storeOp.getNontemporal());
}
+template<typename Adaptor>
static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
- vector::MaskedStoreOpAdaptor adaptor,
+ Adaptor adaptor,
VectorType vectorTy, Value ptr, unsigned align,
ConversionPatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
- storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
+ storeOp, llvm::getSingleElement(adaptor.getValueToStore()), ptr, llvm::getSingleElement(adaptor.getMask()), align);
}
/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
@@ -225,10 +229,11 @@ template <class LoadOrStoreOp>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
+ using Adaptor = typename ConvertOpToLLVMPattern<LoadOrStoreOp>::OneToNOpAdaptor;
LogicalResult
matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
- typename LoadOrStoreOp::Adaptor adaptor,
+ Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Only 1-D vectors can be lowered to LLVM.
VectorType vectorTy = loadOrStoreOp.getVectorType();
@@ -244,10 +249,11 @@ public:
return failure();
// Resolve address.
+ SmallVector<Value> indices = llvm::map_to_vector(adaptor.getIndices(), [](ValueRange range) { return llvm::getSingleElement(range); });
auto vtype = cast<VectorType>(
this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
- adaptor.getIndices(), rewriter);
+ indices, rewriter);
replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
rewriter);
return success();
@@ -261,7 +267,7 @@ public:
using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
+ matchAndRewrite(vector::GatherOp gather, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = gather->getLoc();
MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
@@ -284,17 +290,18 @@ public:
}
// Resolve address.
+ SmallVector<Value> indices = llvm::map_to_vector(adaptor.getIndices(), [](ValueRange range) { return llvm::getSingleElement(range); });
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
- adaptor.getIndices(), rewriter);
- Value base = adaptor.getBase();
+ indices, rewriter);
+ ValueRange base = adaptor.getBase();
Value ptrs =
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
- base, ptr, adaptor.getIndexVec(), vType);
+ base, ptr, llvm::getSingleElement(adaptor.getIndexVec()), vType);
// Replace with the gather intrinsic.
rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
- gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
- adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
+ gather, typeConverter->convertType(vType), ptrs, llvm::getSingleElement(adaptor.getMask()),
+ llvm::getSingleElement(adaptor.getPassThru()), rewriter.getI32IntegerAttr(align));
return success();
}
};
@@ -306,7 +313,7 @@ public:
using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
+ matchAndRewrite(vector::ScatterOp scatter, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = scatter->getLoc();
MemRefType memRefType = scatter.getMemRefType();
@@ -328,15 +335,16 @@ public:
}
// Resolve address.
+ SmallVector<Value> indices = llvm::map_to_vector(adaptor.getIndices(), [](ValueRange range) { return llvm::getSingleElement(range); });
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
- adaptor.getIndices(), rewriter);
+ indices, rewriter);
Value ptrs =
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
- adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
+ adaptor.getBase(), ptr, llvm::getSingleElement(adaptor.getIndexVec()), vType);
// Replace with the scatter intrinsic.
rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
- scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
+ scatter, llvm::getSingleElement(adaptor.getValueToStore()), ptrs, llvm::getSingleElement(adaptor.getMask()),
rewriter.getI32IntegerAttr(align));
return success();
}
@@ -349,18 +357,19 @@ public:
using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
+ matchAndRewrite(vector::ExpandLoadOp expand, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = expand->getLoc();
MemRefType memRefType = expand.getMemRefType();
// Resolve address.
auto vtype = typeConverter->convertType(expand.getVectorType());
+ SmallVector<Value> indices = llvm::map_to_vector(adaptor.getIndices(), [](ValueRange range) { return llvm::getSingleElement(range); });
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
- adaptor.getIndices(), rewriter);
+ indices, rewriter);
rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
- expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
+ expand, vtype, ptr, llvm::getSingleElement(adaptor.getMask()), llvm::getSingleElement(adaptor.getPassThru()));
return success();
}
};
@@ -372,17 +381,18 @@ public:
using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
+ matchAndRewrite(vector::CompressStoreOp compress, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = compress->getLoc();
MemRefType memRefType = compress.getMemRefType();
// Resolve address.
+ SmallVector<Value> indices = llvm::map_to_vector(adaptor.getIndices(), [](ValueRange range) { return llvm::getSingleElement(range); });
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
- adaptor.getIndices(), rewriter);
+ indices, rewriter);
rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
- compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
+ compress, llvm::getSingleElement(adaptor.getValueToStore()), ptr, llvm::getSingleElement(adaptor.getMask()));
return success();
}
};
@@ -1416,7 +1426,7 @@ public:
using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
+ matchAndRewrite(vector::TypeCastOp castOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = castOp->getLoc();
MemRefType sourceMemRefType =
@@ -1428,15 +1438,10 @@ public:
!targetMemRefType.hasStaticShape())
return failure();
- auto llvmSourceDescriptorTy =
- dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
- if (!llvmSourceDescriptorTy)
- return failure();
MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
- auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
- typeConverter->convertType(targetMemRefType));
- if (!llvmTargetDescriptorTy)
+ SmallVector<Type> llvmTargetDescriptorTypes;
+ if (failed(typeConverter->convertType(targetMemRefType, llvmTargetDescriptorTypes)))
return failure();
// Only contiguous source buffers supported atm.
@@ -1453,7 +1458,7 @@ public:
auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
// Create descriptor.
- auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
+ auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTypes);
// Set allocated ptr.
Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
desc.setAllocatedPtr(rewriter, loc, allocated);
@@ -1480,7 +1485,7 @@ public:
desc.setStride(rewriter, loc, index, stride);
}
- rewriter.replaceOp(castOp, {desc});
+ rewriter.replaceOpWithMultiple(castOp, {desc.getElements()});
return success();
}
};