diff options
-rwxr-xr-x | mlir/artifacts/jq-linux64 | bin | 0 -> 3953824 bytes | |||
-rw-r--r-- | mlir/include/mlir/Conversion/LLVMCommon/Pattern.h | 35 | ||||
-rw-r--r-- | mlir/include/mlir/Transforms/DialectConversion.h | 57 | ||||
-rw-r--r-- | mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp | 96 | ||||
-rw-r--r-- | mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp | 56 | ||||
-rw-r--r-- | mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp | 104 | ||||
-rw-r--r-- | mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp | 188 | ||||
-rw-r--r-- | mlir/lib/Transforms/Utils/DialectConversion.cpp | 540 | ||||
-rw-r--r-- | mlir/test/Transforms/decompose-call-graph-types.mlir | 38 | ||||
-rw-r--r-- | mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp | 2 | ||||
-rw-r--r-- | mlir/test/lib/Dialect/Test/TestPatterns.cpp | 1 |
11 files changed, 663 insertions, 454 deletions
diff --git a/mlir/artifacts/jq-linux64 b/mlir/artifacts/jq-linux64 Binary files differnew file mode 100755 index 0000000..f48b0ca --- /dev/null +++ b/mlir/artifacts/jq-linux64 diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h index f3bf5b6..6751c3e 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -143,6 +143,8 @@ template <typename SourceOp> class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { public: using OpAdaptor = typename SourceOp::Adaptor; + using OneToNOpAdaptor = + typename SourceOp::template GenericAdaptor<ArrayRef<ArrayRef<Value>>>; explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1) @@ -153,8 +155,13 @@ public: /// Wrappers around the RewritePattern methods that pass the derived op type. void rewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const final { - rewrite(cast<SourceOp>(op), OpAdaptor(operands, cast<SourceOp>(op)), - rewriter); + auto sourceOp = cast<SourceOp>(op); + rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter); + } + void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands, + ConversionPatternRewriter &rewriter) const final { + auto sourceOp = cast<SourceOp>(op); + rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter); } LogicalResult match(Operation *op) const final { return match(cast<SourceOp>(op)); @@ -162,8 +169,15 @@ public: LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const final { - return matchAndRewrite(cast<SourceOp>(op), - OpAdaptor(operands, cast<SourceOp>(op)), rewriter); + auto sourceOp = cast<SourceOp>(op); + return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter); + } + LogicalResult + matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands, + ConversionPatternRewriter &rewriter) const final { + auto sourceOp = cast<SourceOp>(op); + return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), + rewriter); } /// Rewrite and Match methods that operate on the SourceOp type. These must be @@ -175,6 +189,12 @@ public: ConversionPatternRewriter &rewriter) const { llvm_unreachable("must override rewrite or matchAndRewrite"); } + virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + SmallVector<Value> oneToOneOperands = + getOneToOneAdaptorOperands(adaptor.getOperands()); + rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); + } virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -183,6 +203,13 @@ public: rewrite(op, adaptor, rewriter); return success(); } + virtual LogicalResult + matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + SmallVector<Value> oneToOneOperands = + getOneToOneAdaptorOperands(adaptor.getOperands()); + return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); + } private: using ConvertToLLVMPattern::match; diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index de47765..4c555e1 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -537,6 +537,10 @@ public: ConversionPatternRewriter &rewriter) const { llvm_unreachable("unimplemented rewrite"); } + virtual void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands, + ConversionPatternRewriter &rewriter) const { + rewrite(op, getOneToOneAdaptorOperands(operands), rewriter); + } /// Hook for derived classes to implement combined matching and rewriting. virtual LogicalResult @@ -547,6 +551,11 @@ public: rewrite(op, operands, rewriter); return success(); } + virtual LogicalResult + matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands, + ConversionPatternRewriter &rewriter) const { + return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter); + } /// Attempt to match and rewrite the IR root at the specified operation. LogicalResult matchAndRewrite(Operation *op, @@ -574,6 +583,9 @@ protected: : RewritePattern(std::forward<Args>(args)...), typeConverter(&typeConverter) {} + SmallVector<Value> + getOneToOneAdaptorOperands(ArrayRef<ArrayRef<Value>> operands) const; + protected: /// An optional type converter for use by this pattern. const TypeConverter *typeConverter = nullptr; @@ -589,6 +601,8 @@ template <typename SourceOp> class OpConversionPattern : public ConversionPattern { public: using OpAdaptor = typename SourceOp::Adaptor; + using OneToNOpAdaptor = + typename SourceOp::template GenericAdaptor<ArrayRef<ArrayRef<Value>>>; OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1) : ConversionPattern(SourceOp::getOperationName(), benefit, context) {} @@ -607,12 +621,24 @@ public: auto sourceOp = cast<SourceOp>(op); rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter); } + void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands, + ConversionPatternRewriter &rewriter) const final { + auto sourceOp = cast<SourceOp>(op); + rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter); + } LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const final { auto sourceOp = cast<SourceOp>(op); return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter); } + LogicalResult + matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands, + ConversionPatternRewriter &rewriter) const final { + auto sourceOp = cast<SourceOp>(op); + return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), + rewriter); + } /// Rewrite and Match methods that operate on the SourceOp type. These must be /// overridden by the derived pattern class. @@ -623,6 +649,12 @@ public: ConversionPatternRewriter &rewriter) const { llvm_unreachable("must override matchAndRewrite or a rewrite method"); } + virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + SmallVector<Value> oneToOneOperands = + getOneToOneAdaptorOperands(adaptor.getOperands()); + rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); + } virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -631,6 +663,13 @@ public: rewrite(op, adaptor, rewriter); return success(); } + virtual LogicalResult + matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + SmallVector<Value> oneToOneOperands = + getOneToOneAdaptorOperands(adaptor.getOperands()); + return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); + } private: using ConversionPattern::matchAndRewrite; @@ -656,11 +695,20 @@ public: ConversionPatternRewriter &rewriter) const final { rewrite(cast<SourceOp>(op), operands, rewriter); } + void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands, + ConversionPatternRewriter &rewriter) const final { + rewrite(cast<SourceOp>(op), operands, rewriter); + } LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const final { return matchAndRewrite(cast<SourceOp>(op), operands, rewriter); } + LogicalResult + matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands, + ConversionPatternRewriter &rewriter) const final { + return matchAndRewrite(cast<SourceOp>(op), operands, rewriter); + } /// Rewrite and Match methods that operate on the SourceOp type. These must be /// overridden by the derived pattern class. @@ -668,6 +716,10 @@ public: ConversionPatternRewriter &rewriter) const { llvm_unreachable("must override matchAndRewrite or a rewrite method"); } + virtual void rewrite(SourceOp op, ArrayRef<ArrayRef<Value>> operands, + ConversionPatternRewriter &rewriter) const { + rewrite(op, getOneToOneAdaptorOperands(operands), rewriter); + } virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const { @@ -676,6 +728,11 @@ public: rewrite(op, operands, rewriter); return success(); } + virtual LogicalResult + matchAndRewrite(SourceOp op, ArrayRef<ArrayRef<Value>> operands, + ConversionPatternRewriter &rewriter) const { + return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter); + } private: using ConversionPattern::matchAndRewrite; diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp index ce91424..20a2a10 100644 --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -153,6 +153,7 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, type.isVarArg()); }); +/* // Argument materializations convert from the new block argument types // (multiple SSA values that make up a memref descriptor) back to the // original block argument type. The dialect conversion framework will then @@ -198,16 +199,62 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc) .getResult(0); }); + +*/ // Add generic source and target materializations to handle cases where // non-LLVM types persist after an LLVM conversion. addSourceMaterialization([&](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) { - if (inputs.size() != 1) - return Value(); + //if (inputs.size() != 1) + // return Value(); return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs) .getResult(0); }); + addSourceMaterialization([&](OpBuilder &builder, MemRefType resultType, + ValueRange inputs, Location loc) { + if (inputs.size()== 1 && isa<LLVM::LLVMStructType>(inputs.front().getType())) return Value(); + + Value desc; + if (inputs.size() == 1 && isa<LLVM::LLVMPointerType>(inputs.front().getType())) { + // This is a bare pointer. We allow bare pointers only for function entry + // blocks. + BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front()); + if (!barePtr) + return Value(); + Block *block = barePtr.getOwner(); + if (!block->isEntryBlock() || + !isa<FunctionOpInterface>(block->getParentOp())) + return Value(); + desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType, + inputs[0]); + } else { + //llvm::errs() << "pack elems: " << inputs.size() << "\n"; + //llvm::errs() << inputs[0] << "\n"; + desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs); + //llvm::errs() << "done packing\n"; + } + // An argument materialization must return a value of type `resultType`, + // so insert a cast from the memref descriptor type (!llvm.struct) to the + // original memref type. + return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc) + .getResult(0); + }); + addSourceMaterialization([&](OpBuilder &builder, UnrankedMemRefType resultType, + ValueRange inputs, Location loc) { + if (inputs.size() == 1) { + // Bare pointers are not supported for unranked memrefs because a + // memref descriptor cannot be built just from a bare pointer. + return Value(); + } + Value desc = + UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs); + // An argument materialization must return a value of type + // `resultType`, so insert a cast from the memref descriptor type + // (!llvm.struct) to the original memref type. + return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc) + .getResult(0); + }); addTargetMaterialization([&](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) { if (inputs.size() != 1) @@ -216,6 +263,51 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs) .getResult(0); }); + addTargetMaterialization([&](OpBuilder &builder, Type resultType, + ValueRange inputs, + Location loc, Type originalType) -> Value { + llvm::errs() << "TARGET MAT: -> " << resultType << "\n"; + if (!originalType) { + llvm::errs() << " -- no orig\n"; + return Value(); + } + if (auto memrefType = dyn_cast<MemRefType>(originalType)) { + assert(isa<LLVM::LLVMStructType>(resultType) && "expected struct type"); + if (inputs.size() == 1) { + Value input = inputs.front(); + if (auto castOp =input.getDefiningOp<UnrealizedConversionCastOp>()) { + if (castOp.getInputs().size() == 1 && isa<LLVM::LLVMPointerType>(castOp.getInputs()[0].getType())) { + input = castOp.getInputs()[0]; + } + } + if (!isa<LLVM::LLVMPointerType>(input.getType())) + return Value(); + BlockArgument barePtr = dyn_cast<BlockArgument>(input); + if (!barePtr) + return Value(); + Block *block = barePtr.getOwner(); + if (!block->isEntryBlock() || + !isa<FunctionOpInterface>(block->getParentOp())) + return Value(); + // Bare ptr + return MemRefDescriptor::fromStaticShape(builder, loc, *this, memrefType, + input); + } + return MemRefDescriptor::pack(builder, loc, *this, memrefType, inputs); + } + if (auto memrefType = dyn_cast<UnrankedMemRefType>(originalType)) { + assert(isa<LLVM::LLVMStructType>(resultType) && "expected struct type"); + if (inputs.size() == 1) { + // Bare pointers are not supported for unranked memrefs because a + // memref descriptor cannot be built just from a bare pointer. + return Value(); + } + return UnrankedMemRefDescriptor::pack(builder, loc, *this, + memrefType, inputs); + } + + return Value(); + }); // Integer memory spaces map to themselves. addTypeAttributeConversion( diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp index a087643..03be003 100644 --- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp +++ b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp @@ -14,40 +14,6 @@ using namespace mlir; using namespace mlir::func; //===----------------------------------------------------------------------===// -// Helper functions -//===----------------------------------------------------------------------===// - -/// If the given value can be decomposed with the type converter, decompose it. -/// Otherwise, return the given value. -// TODO: Value decomposition should happen automatically through a 1:N adaptor. -// This function will disappear when the 1:1 and 1:N drivers are merged. -static SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc, - Value value, - const TypeConverter *converter) { - // Try to convert the given value's type. If that fails, just return the - // given value. - SmallVector<Type> convertedTypes; - if (failed(converter->convertType(value.getType(), convertedTypes))) - return {value}; - if (convertedTypes.empty()) - return {}; - - // If the given value's type is already legal, just return the given value. - TypeRange convertedTypeRange(convertedTypes); - if (convertedTypeRange == TypeRange(value.getType())) - return {value}; - - // Try to materialize a target conversion. If the materialization did not - // produce values of the requested type, the materialization failed. Just - // return the given value in that case. - SmallVector<Value> result = converter->materializeTargetConversion( - builder, loc, convertedTypeRange, value); - if (result.empty()) - return {value}; - return result; -} - -//===----------------------------------------------------------------------===// // DecomposeCallGraphTypesForFuncArgs //===----------------------------------------------------------------------===// @@ -102,16 +68,11 @@ struct DecomposeCallGraphTypesForReturnOp using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ReturnOp op, OpAdaptor adaptor, + matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { SmallVector<Value, 2> newOperands; - for (Value operand : adaptor.getOperands()) { - // TODO: We can directly take the values from the adaptor once this is a - // 1:N conversion pattern. - llvm::append_range(newOperands, - decomposeValue(rewriter, operand.getLoc(), operand, - getTypeConverter())); - } + for (ValueRange operand : adaptor.getOperands()) + llvm::append_range(newOperands, operand); rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands); return success(); } @@ -128,18 +89,13 @@ struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CallOp op, OpAdaptor adaptor, + matchAndRewrite(CallOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { // Create the operands list of the new `CallOp`. SmallVector<Value, 2> newOperands; - for (Value operand : adaptor.getOperands()) { - // TODO: We can directly take the values from the adaptor once this is a - // 1:N conversion pattern. - llvm::append_range(newOperands, - decomposeValue(rewriter, operand.getLoc(), operand, - getTypeConverter())); - } + for (ValueRange operand : adaptor.getOperands()) + llvm::append_range(newOperands, operand); // Create the new result types for the new `CallOp` and track the number of // replacement types for each original op result. diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp index 93a7805..4d154b0 100644 --- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp @@ -16,20 +16,16 @@ using namespace mlir::scf; namespace { -// Unpacks the single unrealized_conversion_cast using the list of inputs -// e.g., return [%b, %c, %d] for %a = unrealized_conversion_cast(%b, %c, %d) -static void unpackUnrealizedConversionCast(Value v, - SmallVectorImpl<Value> &unpacked) { - if (auto cast = - dyn_cast_or_null<UnrealizedConversionCastOp>(v.getDefiningOp())) { - if (cast.getInputs().size() != 1) { - // 1 : N type conversion. - unpacked.append(cast.getInputs().begin(), cast.getInputs().end()); - return; - } - } - // 1 : 1 type conversion. - unpacked.push_back(v); +static SmallVector<Value> flattenValues(ArrayRef<ArrayRef<Value>> values) { + SmallVector<Value> result; + for (ArrayRef<Value> v : values) + llvm::append_range(result, v); + return result; +} + +static Value getSingleValue(ArrayRef<Value> values) { + assert(values.size() == 1 && "expected single value"); + return values.front(); } // CRTP @@ -40,19 +36,21 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> { public: using OpConversionPattern<SourceOp>::typeConverter; using OpConversionPattern<SourceOp>::OpConversionPattern; - using OpAdaptor = typename OpConversionPattern<SourceOp>::OpAdaptor; + using OneToNOpAdaptor = + typename OpConversionPattern<SourceOp>::OneToNOpAdaptor; // // Derived classes should provide the following method which performs the // actual conversion. It should return std::nullopt upon conversion failure // and return the converted operation upon success. // - // std::optional<SourceOp> convertSourceOp(SourceOp op, OpAdaptor adaptor, - // ConversionPatternRewriter &rewriter, - // TypeRange dstTypes) const; + // std::optional<SourceOp> convertSourceOp( + // SourceOp op, OneToNOpAdaptor adaptor, + // ConversionPatternRewriter &rewriter, + // TypeRange dstTypes) const; LogicalResult - matchAndRewrite(SourceOp op, OpAdaptor adaptor, + matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SmallVector<Type> dstTypes; SmallVector<unsigned> offsets; @@ -73,28 +71,15 @@ public: return rewriter.notifyMatchFailure(op, "could not convert operation"); // Packs the return value. - SmallVector<Value> packedRets; + SmallVector<ValueRange> packedRets; for (unsigned i = 1, e = offsets.size(); i < e; i++) { unsigned start = offsets[i - 1], end = offsets[i]; unsigned len = end - start; ValueRange mappedValue = newOp->getResults().slice(start, len); - if (len != 1) { - // 1 : N type conversion. - Type origType = op.getResultTypes()[i - 1]; - Value mat = typeConverter->materializeSourceConversion( - rewriter, op.getLoc(), origType, mappedValue); - if (!mat) { - return rewriter.notifyMatchFailure( - op, "Failed to materialize 1:N type conversion"); - } - packedRets.push_back(mat); - } else { - // 1 : 1 type conversion. - packedRets.push_back(mappedValue.front()); - } + packedRets.push_back(mappedValue); } - rewriter.replaceOp(op, packedRets); + rewriter.replaceOpWithMultiple(op, packedRets); return success(); } }; @@ -105,7 +90,7 @@ public: using Structural1ToNConversionPattern::Structural1ToNConversionPattern; // The callback required by CRTP. - std::optional<ForOp> convertSourceOp(ForOp op, OpAdaptor adaptor, + std::optional<ForOp> convertSourceOp(ForOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter, TypeRange dstTypes) const { // Create a empty new op and inline the regions from the old op. @@ -129,16 +114,13 @@ public: if (failed(rewriter.convertRegionTypes(&op.getRegion(), *typeConverter))) return std::nullopt; - // Unpacked the iteration arguments. - SmallVector<Value> flatArgs; - for (Value arg : adaptor.getInitArgs()) - unpackUnrealizedConversionCast(arg, flatArgs); - // We can not do clone as the number of result types after conversion // might be different. - ForOp newOp = rewriter.create<ForOp>(op.getLoc(), adaptor.getLowerBound(), - adaptor.getUpperBound(), - adaptor.getStep(), flatArgs); + ForOp newOp = rewriter.create<ForOp>( + op.getLoc(), getSingleValue(adaptor.getLowerBound()), + getSingleValue(adaptor.getUpperBound()), + getSingleValue(adaptor.getStep()), + flattenValues(adaptor.getInitArgs())); // Reserve whatever attributes in the original op. newOp->setAttrs(op->getAttrs()); @@ -160,12 +142,12 @@ class ConvertIfOpTypes public: using Structural1ToNConversionPattern::Structural1ToNConversionPattern; - std::optional<IfOp> convertSourceOp(IfOp op, OpAdaptor adaptor, + std::optional<IfOp> convertSourceOp(IfOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter, TypeRange dstTypes) const { - IfOp newOp = rewriter.create<IfOp>(op.getLoc(), dstTypes, - adaptor.getCondition(), true); + IfOp newOp = rewriter.create<IfOp>( + op.getLoc(), dstTypes, getSingleValue(adaptor.getCondition()), true); newOp->setAttrs(op->getAttrs()); // We do not need the empty blocks created by rewriter. @@ -189,15 +171,11 @@ class ConvertWhileOpTypes public: using Structural1ToNConversionPattern::Structural1ToNConversionPattern; - std::optional<WhileOp> convertSourceOp(WhileOp op, OpAdaptor adaptor, + std::optional<WhileOp> convertSourceOp(WhileOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter, TypeRange dstTypes) const { - // Unpacked the iteration arguments. - SmallVector<Value> flatArgs; - for (Value arg : adaptor.getOperands()) - unpackUnrealizedConversionCast(arg, flatArgs); - - auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes, flatArgs); + auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes, + flattenValues(adaptor.getOperands())); for (auto i : {0u, 1u}) { if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter))) @@ -218,13 +196,10 @@ class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor, + matchAndRewrite(scf::YieldOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - SmallVector<Value> unpackedYield; - for (Value operand : adaptor.getOperands()) - unpackUnrealizedConversionCast(operand, unpackedYield); - - rewriter.replaceOpWithNewOp<scf::YieldOp>(op, unpackedYield); + rewriter.replaceOpWithNewOp<scf::YieldOp>( + op, flattenValues(adaptor.getOperands())); return success(); } }; @@ -235,13 +210,10 @@ class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> { public: using OpConversionPattern<ConditionOp>::OpConversionPattern; LogicalResult - matchAndRewrite(ConditionOp op, OpAdaptor adaptor, + matchAndRewrite(ConditionOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - SmallVector<Value> unpackedYield; - for (Value operand : adaptor.getOperands()) - unpackUnrealizedConversionCast(operand, unpackedYield); - - rewriter.modifyOpInPlace(op, [&]() { op->setOperands(unpackedYield); }); + rewriter.modifyOpInPlace( + op, [&]() { op->setOperands(flattenValues(adaptor.getOperands())); }); return success(); } }; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp index 9abb1d3..0fa9f26 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -39,25 +39,16 @@ using namespace mlir::sparse_tensor; // Helper methods. //===----------------------------------------------------------------------===// -/// Flattens a list of operands that may contain sparse tensors. -static void flattenOperands(ValueRange operands, - SmallVectorImpl<Value> &flattened) { - // In case of - // sparse_tensor, c, sparse_tensor - // ==> - // memref ..., c, memref ... - for (auto operand : operands) { - if (getSparseTensorEncoding(operand.getType())) { - auto tuple = getTuple(operand); - // An unrealized_conversion_cast will be inserted by type converter to - // inter-mix the gap between 1:N conversion between sparse tensors and - // fields. In this case, take the operands in the cast and replace the - // sparse tensor output with the flattened type array. - flattened.append(tuple.getOperands().begin(), tuple.getOperands().end()); - } else { - flattened.push_back(operand); - } - } +static SmallVector<Value> flattenValues(ArrayRef<ArrayRef<Value>> values) { + SmallVector<Value> result; + for (ArrayRef<Value> v : values) + llvm::append_range(result, v); + return result; +} + +static Value getSingleValue(ArrayRef<Value> values) { + assert(values.size() == 1 && "expected single value"); + return values.front(); } /// Generates a load with proper `index` typing. @@ -567,12 +558,11 @@ class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, + matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - SmallVector<Value> flattened; - flattenOperands(adaptor.getOperands(), flattened); // Create a return with the flattened value extracted from sparse tensors. - rewriter.replaceOpWithNewOp<func::ReturnOp>(op, flattened); + rewriter.replaceOpWithNewOp<func::ReturnOp>( + op, flattenValues(adaptor.getOperands())); return success(); } }; @@ -583,7 +573,7 @@ public: // The default CallOp converter can not handle 1:N type conversion. using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(func::CallOp op, OpAdaptor adaptor, + matchAndRewrite(func::CallOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); // In case of: @@ -596,10 +586,8 @@ public: return failure(); // (1) Generates new call with flattened return value. - SmallVector<Value> flattened; - flattenOperands(adaptor.getOperands(), flattened); - auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(), - finalRetTy, flattened); + auto newCall = rewriter.create<func::CallOp>( + loc, op.getCallee(), finalRetTy, flattenValues(adaptor.getOperands())); // (2) Gather sparse tensor returns. SmallVector<SmallVector<Value>> packedResultVals; // Tracks the offset of current return value (of the original call) @@ -643,13 +631,15 @@ class SparseLvlOpConverter : public OpConversionPattern<LvlOp> { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(LvlOp op, OpAdaptor adaptor, + matchAndRewrite(LvlOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { std::optional<int64_t> lvl = op.getConstantLvlIndex(); - if (!lvl || !getSparseTensorEncoding(adaptor.getSource().getType())) + if (!lvl || !getSparseTensorEncoding(op.getSource().getType())) return failure(); - auto desc = getDescriptorFromTensorTuple(adaptor.getSource()); + SparseTensorDescriptor desc( + SparseTensorType(cast<RankedTensorType>(op.getSource().getType())), + adaptor.getSource()); auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl); rewriter.replaceOp(op, sz); @@ -661,7 +651,7 @@ public: struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ReorderCOOOp op, ReorderCOOOpAdaptor adaptor, + matchAndRewrite(ReorderCOOOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); MLIRContext *ctx = op.getContext(); @@ -675,8 +665,10 @@ struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> { assert(dstStt.hasSameDimToLvl(srcStt)); // We don't need a mutable descriptor here as we perform sorting in-place. - auto nnz = genValMemSize(rewriter, op.getLoc(), adaptor.getInputCoo()); - auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo()); + SparseTensorDescriptor desc( + SparseTensorType(cast<RankedTensorType>(op.getInputCoo().getType())), + adaptor.getInputCoo()); + auto nnz = desc.getValMemSize(rewriter, op.getLoc()); auto crd = desc.getAOSMemRef(); auto val = desc.getValMemRef(); @@ -691,7 +683,7 @@ struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> { // Since we do in-place sorting, the destinate tensor will have the same set // of memrefs as the source tensor. - rewriter.replaceOp(op, adaptor.getInputCoo()); + rewriter.replaceOpWithMultiple(op, {adaptor.getInputCoo()}); return success(); } }; @@ -701,10 +693,13 @@ class SparseSliceGetterOpConverter : public OpConversionPattern<Op> { public: using OpConversionPattern<Op>::OpConversionPattern; LogicalResult - matchAndRewrite(Op op, typename Op::Adaptor adaptor, + matchAndRewrite(Op op, + typename OpConversionPattern<Op>::OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Simply lowers to specifer.get <field> operation. - auto desc = getDescriptorFromTensorTuple(adaptor.getSlice()); + SparseTensorDescriptor desc( + SparseTensorType(cast<RankedTensorType>(op.getSlice().getType())), + adaptor.getSlice()); auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind, op.getDim().getZExtValue()); @@ -718,14 +713,14 @@ class SparseCastConverter : public OpConversionPattern<tensor::CastOp> { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor, + matchAndRewrite(tensor::CastOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Only rewrite identically annotated source/dest. auto encDst = getSparseTensorEncoding(op.getType()); auto encSrc = getSparseTensorEncoding(op.getSource().getType()); if (!encDst || encDst != encSrc) return failure(); - rewriter.replaceOp(op, adaptor.getOperands()); + rewriter.replaceOpWithMultiple(op, {adaptor.getSource()}); return success(); } }; @@ -734,10 +729,10 @@ class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor, + matchAndRewrite(ReinterpretMapOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Simply fold the operation. - rewriter.replaceOp(op, adaptor.getSource()); + rewriter.replaceOpWithMultiple(op, {adaptor.getSource()}); return success(); } }; @@ -753,7 +748,7 @@ public: enableBufferInitialization(enableInit) {} LogicalResult - matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor, + matchAndRewrite(bufferization::AllocTensorOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { const auto resType = getSparseTensorType(op); if (!resType.hasEncoding()) @@ -762,7 +757,9 @@ public: Location loc = op.getLoc(); // Deal with copy. if (op.getCopy()) { - auto desc = getDescriptorFromTensorTuple(adaptor.getCopy()); + SparseTensorDescriptor desc( + SparseTensorType(cast<RankedTensorType>(op.getCopy().getType())), + adaptor.getCopy()); SmallVector<Value> fields; fields.reserve(desc.getNumFields()); // Memcpy on memref fields. @@ -787,7 +784,8 @@ public: } // Level size equals to dimension size since lvl2dim map is an identity map. SmallVector<Value> lvlSizesValues; - createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(), + createDimSizes(rewriter, loc, resType, + flattenValues(adaptor.getDynamicSizes()), /*dimSizesValues=*/lvlSizesValues); // Construct allocation for each field. @@ -857,7 +855,7 @@ public: createDeallocs(createDeallocs) {} LogicalResult - matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor, + matchAndRewrite(bufferization::DeallocTensorOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto enc = getSparseTensorEncoding(op.getTensor().getType()); if (!enc) @@ -868,7 +866,9 @@ public: if (createDeallocs) { // Replace the sparse tensor deallocation with field deallocations. Location loc = op.getLoc(); - auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); + SparseTensorDescriptor desc( + SparseTensorType(cast<RankedTensorType>(op.getTensor().getType())), + adaptor.getTensor()); for (auto input : desc.getMemRefFields()) // Deallocate every buffer used to store the sparse tensor handler. rewriter.create<memref::DeallocOp>(loc, input); @@ -886,10 +886,12 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(LoadOp op, OpAdaptor adaptor, + matchAndRewrite(LoadOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Prepare descriptor. - auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); + SparseTensorDescriptor desc( + SparseTensorType(cast<RankedTensorType>(op.getTensor().getType())), + adaptor.getTensor()); // Generate optional insertion finalization code. if (op.getHasInserts()) genEndInsert(rewriter, op.getLoc(), desc); @@ -904,12 +906,14 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ExpandOp op, OpAdaptor adaptor, + matchAndRewrite(ExpandOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (!getSparseTensorEncoding(op.getTensor().getType())) return failure(); Location loc = op->getLoc(); - auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); + SparseTensorDescriptor desc( + SparseTensorType(cast<RankedTensorType>(op.getTensor().getType())), + adaptor.getTensor()); const auto srcType = getSparseTensorType(op.getTensor()); Type eltType = srcType.getElementType(); Type boolType = rewriter.getIntegerType(1); @@ -955,15 +959,18 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CompressOp op, OpAdaptor adaptor, + matchAndRewrite(CompressOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); SmallVector<Value> fields; - auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields); - Value values = adaptor.getValues(); - Value filled = adaptor.getFilled(); - Value added = adaptor.getAdded(); - Value count = adaptor.getCount(); + llvm::append_range(fields, adaptor.getTensor()); + MutSparseTensorDescriptor desc( + SparseTensorType(cast<RankedTensorType>(op.getTensor().getType())), + fields); + Value values = getSingleValue(adaptor.getValues()); + Value filled = getSingleValue(adaptor.getFilled()); + Value added = getSingleValue(adaptor.getAdded()); + Value count = getSingleValue(adaptor.getCount()); const SparseTensorType dstType(desc.getRankedTensorType()); Type eltType = dstType.getElementType(); @@ -996,7 +1003,8 @@ public: SmallVector<Value> params(desc.getFields().begin(), desc.getFields().end()); SmallVector<Type> flatSpTensorTps = llvm::to_vector( llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); })); - params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end()); + SmallVector<Value> flatLvlCoords = flattenValues(adaptor.getLvlCoords()); + params.append(flatLvlCoords.begin(), flatLvlCoords.end()); params.push_back(crd); params.push_back(value); SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps, @@ -1024,19 +1032,22 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor, + matchAndRewrite(tensor::InsertOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto stt = getSparseTensorType(adaptor.getDest()); + auto stt = getSparseTensorType(op.getDest()); if (!stt.hasEncoding()) return failure(); assert(stt.isIdentity() && "Run reinterpret-map before conversion."); Location loc = op.getLoc(); - auto desc = getDescriptorFromTensorTuple(adaptor.getDest()); + SparseTensorDescriptor desc( + SparseTensorType(cast<RankedTensorType>(op.getDest().getType())), + adaptor.getDest()); TypeRange flatSpTensorTps = desc.getFields().getTypes(); SmallVector<Value> params = llvm::to_vector(desc.getFields()); - params.append(adaptor.getIndices().begin(), adaptor.getIndices().end()); - params.push_back(adaptor.getScalar()); + SmallVector<Value> flatIndices = flattenValues(adaptor.getIndices()); + params.append(flatIndices.begin(), flatIndices.end()); + params.push_back(getSingleValue(adaptor.getScalar())); SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps, params, /*genCall=*/true); SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc); @@ -1052,14 +1063,16 @@ public: using OpAdaptor = typename ToPositionsOp::Adaptor; using OpConversionPattern<ToPositionsOp>::OpConversionPattern; LogicalResult - matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor, + matchAndRewrite(ToPositionsOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Replace the requested position access with corresponding field. // The view is restricted to the actual size to ensure clients // of this operation truly observe size, not capacity! Location loc = op.getLoc(); Level lvl = op.getLevel(); - auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); + SparseTensorDescriptor desc( + SparseTensorType(cast<RankedTensorType>(op.getTensor().getType())), + adaptor.getTensor()); auto mem = desc.getPosMemRef(lvl); auto size = desc.getPosMemSize(rewriter, loc, lvl); rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size)); @@ -1074,14 +1087,16 @@ public: using OpAdaptor = typename ToCoordinatesOp::Adaptor; using OpConversionPattern<ToCoordinatesOp>::OpConversionPattern; LogicalResult - matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor, + matchAndRewrite(ToCoordinatesOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Replace the requested coordinates access with corresponding field. // The view is restricted to the actual size to ensure clients // of this operation truly observe size, not capacity! Location loc = op.getLoc(); Level lvl = op.getLevel(); - auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); + SparseTensorDescriptor desc( + SparseTensorType(cast<RankedTensorType>(op.getTensor().getType())), + adaptor.getTensor()); auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl); if (lvl < getSparseTensorType(op.getTensor()).getAoSCOOStart()) { auto size = desc.getCrdMemSize(rewriter, loc, lvl); @@ -1099,14 +1114,16 @@ public: using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor; using OpConversionPattern<ToCoordinatesBufferOp>::OpConversionPattern; LogicalResult - matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor, + matchAndRewrite(ToCoordinatesBufferOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Replace the requested coordinates access with corresponding field. // The view is restricted to the actual size to ensure clients // of this operation truly observe size, not capacity! Location loc = op.getLoc(); Level lvl = getSparseTensorType(op.getTensor()).getAoSCOOStart(); - auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); + SparseTensorDescriptor desc( + SparseTensorType(cast<RankedTensorType>(op.getTensor().getType())), + adaptor.getTensor()); auto mem = desc.getAOSMemRef(); auto size = desc.getCrdMemSize(rewriter, loc, lvl); rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size)); @@ -1120,13 +1137,15 @@ public: using OpAdaptor = typename ToValuesOp::Adaptor; using OpConversionPattern<ToValuesOp>::OpConversionPattern; LogicalResult - matchAndRewrite(ToValuesOp op, OpAdaptor adaptor, + matchAndRewrite(ToValuesOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Replace the requested values access with corresponding field. // The view is restricted to the actual size to ensure clients // of this operation truly observe size, not capacity! Location loc = op.getLoc(); - auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); + SparseTensorDescriptor desc( + SparseTensorType(cast<RankedTensorType>(op.getTensor().getType())), + adaptor.getTensor()); auto mem = desc.getValMemRef(); auto size = desc.getValMemSize(rewriter, loc); rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size)); @@ -1139,7 +1158,7 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ConvertOp op, OpAdaptor adaptor, + matchAndRewrite(ConvertOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType()); SparseTensorEncodingAttr encSrc = @@ -1159,7 +1178,7 @@ public: Type srcElemTp = op.getSource().getType().getElementType(); // Fold the trivial cases. if (retElemTp == srcElemTp && encDst == encSrc) { - rewriter.replaceOp(op, adaptor.getSource()); + rewriter.replaceOpWithMultiple(op, {adaptor.getSource()}); return success(); } // @@ -1172,7 +1191,9 @@ public: // else: // dst = memref.copy(src) Location loc = op.getLoc(); - auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource()); + SparseTensorDescriptor srcDesc( + SparseTensorType(cast<RankedTensorType>(op.getSource().getType())), + adaptor.getSource()); SmallVector<Value> fields; foreachFieldAndTypeInSparseTensor( SparseTensorType(cast<RankedTensorType>(op.getResult().getType())), @@ -1224,7 +1245,7 @@ class SparseExtractSliceConverter public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor, + matchAndRewrite(tensor::ExtractSliceOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); MLIRContext *ctx = op.getContext(); @@ -1236,7 +1257,10 @@ public: assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices()); SmallVector<Value> fields; - auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields); + llvm::append_range(fields, adaptor.getSource()); + MutSparseTensorDescriptor desc( + SparseTensorType(cast<RankedTensorType>(op.getSource().getType())), + fields); auto newSpec = rewriter.create<StorageSpecifierInitOp>( loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier()); @@ -1280,13 +1304,15 @@ class SparseNumberOfEntriesConverter public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor, + matchAndRewrite(NumberOfEntriesOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Query memSizes for the actually stored values. // FIXME: the nse value computed in this way might be wrong when there is // any "loose_compressed" level. - rewriter.replaceOp( - op, genValMemSize(rewriter, op.getLoc(), adaptor.getTensor())); + SparseTensorDescriptor desc( + SparseTensorType(cast<RankedTensorType>(op.getTensor().getType())), + adaptor.getTensor()); + rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc())); return success(); } }; @@ -1413,9 +1439,11 @@ struct SparseDisassembleOpConverter : OpConversionPattern(typeConverter, context) {} LogicalResult - matchAndRewrite(DisassembleOp op, OpAdaptor adaptor, + matchAndRewrite(DisassembleOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); + SparseTensorDescriptor desc( + SparseTensorType(cast<RankedTensorType>(op.getTensor().getType())), + adaptor.getTensor()); Location loc = op.getLoc(); SmallVector<Value> retMem; SmallVector<Value> retLen; diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 33fa9e4..9b96821 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -54,8 +54,6 @@ static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) { }); } -/// Helper function that computes an insertion point where the given value is -/// defined and can be used without a dominance violation. static OpBuilder::InsertPoint computeInsertPoint(Value value) { Block *insertBlock = value.getParentBlock(); Block::iterator insertPt = insertBlock->begin(); @@ -64,6 +62,27 @@ static OpBuilder::InsertPoint computeInsertPoint(Value value) { return OpBuilder::InsertPoint(insertBlock, insertPt); } +/// Helper function that computes an insertion point where the given value is +/// defined and can be used without a dominance violation. +static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) { + assert(!vals.empty() && "expected at least one value"); + OpBuilder::InsertPoint pt = computeInsertPoint(vals.front()); + for (Value v : vals.drop_front()) { + OpBuilder::InsertPoint pt2 = computeInsertPoint(v); + assert(pt.getBlock() == pt2.getBlock()); + if (pt.getPoint() == pt.getBlock()->begin()) { + pt = pt2; + continue; + } + if (pt2.getPoint() == pt2.getBlock()->begin()) { + continue; + } + if (pt.getPoint()->isBeforeInBlock(&*pt2.getPoint())) + pt = pt2; + } + return pt; +} + //===----------------------------------------------------------------------===// // ConversionValueMapping //===----------------------------------------------------------------------===// @@ -73,89 +92,220 @@ static OpBuilder::InsertPoint computeInsertPoint(Value value) { using ReplacementValues = SmallVector<Value, 1>; namespace { +struct SmallVectorMapInfo { + static SmallVector<Value, 1> getEmptyKey() { return SmallVector<Value, 1>{}; } + static SmallVector<Value, 1> getTombstoneKey() { + return SmallVector<Value, 1>{}; + } + static ::llvm::hash_code getHashValue(SmallVector<Value, 1> val) { + return ::llvm::hash_combine_range(val.begin(), val.end()); + } + static bool isEqual(SmallVector<Value, 1> LHS, SmallVector<Value, 1> RHS) { + return LHS == RHS; + } +}; + /// This class wraps a IRMapping to provide recursive lookup /// functionality, i.e. we will traverse if the mapped value also has a mapping. struct ConversionValueMapping { - /// Lookup the most recently mapped value with the desired type in the - /// mapping. + /// Find the most recently mapped values for the given value. If the value is + /// not mapped at all, return the given value. + SmallVector<Value, 1> lookupOrDefault(Value from) const; + + /// TODO: Find most recently mapped or materialization with matching type. May + /// return the given value if the type matches. + SmallVector<Value, 1> + lookupOrDefault(Value from, SmallVector<Type, 1> desiredTypes) const; + + Value lookupDirectSingleReplacement(Value from) const { + auto it = mapping.find(from); + if (it == mapping.end()) + return Value(); + const SmallVector<Value, 1> &repl = it->second; + if (repl.size() != 1) return Value(); + return repl.front(); +/* + if (!mapping.contains(from)) return Value(); + auto it = llvm::find(mapping, from); + const SmallVector<Value, 1> &repl = it->second; + if (repl.size() != 1) return Value(); + return repl.front(); + */ + } + + /// Find the most recently mapped values for the given value. If the value is + /// not mapped at all, return an empty vector. + SmallVector<Value, 1> lookupOrNull(Value from) const; + + /// Find the most recently mapped values for the given value. If those values + /// have the desired types, return them. Otherwise, try to find a + /// materialization to the desired types. /// - /// Special cases: - /// - If the desired type is "null", simply return the most recently mapped - /// value. - /// - If there is no mapping to the desired type, also return the most - /// recently mapped value. - /// - If there is no mapping for the given value at all, return the given - /// value. - Value lookupOrDefault(Value from, Type desiredType = nullptr) const; - - /// Lookup a mapped value within the map, or return null if a mapping does not - /// exist. If a mapping exists, this follows the same behavior of - /// `lookupOrDefault`. - Value lookupOrNull(Value from, Type desiredType = nullptr) const; - - /// Map a value to the one provided. - void map(Value oldVal, Value newVal) { - LLVM_DEBUG({ - for (Value it = newVal; it; it = mapping.lookupOrNull(it)) - assert(it != oldVal && "inserting cyclic mapping"); - }); - mapping.map(oldVal, newVal); + /// If the given value is not mapped at all or if there are no mapped values/ + /// materialization results with the desired types, return an empty vector. + SmallVector<Value, 1> lookupOrNull(Value from, + SmallVector<Type, 1> desiredTypes) const; + + Value lookupOrNull(Value from, Type desiredType) { + SmallVector<Value, 1> vals = + lookupOrNull(from, SmallVector<Type, 1>{desiredType}); + if (vals.empty()) + return Value(); + assert(vals.size() == 1 && "expected single value"); + return vals.front(); } - /// Try to map a value to the one provided. Returns false if a transitive - /// mapping from the new value to the old value already exists, true if the - /// map was updated. - bool tryMap(Value oldVal, Value newVal); + void erase(Value from) { mapping.erase(from); } - /// Drop the last mapping for the given value. - void erase(Value value) { mapping.erase(value); } + void map(Value from, ArrayRef<BlockArgument> to) { + SmallVector<Value> vals; + for (Value v : to) + vals.push_back(v); + map(from, vals); + } + + void map(Value from, ArrayRef<Value> to) { +#ifndef NDEBUG + assert(from && "expected non-null value"); + assert(!to.empty() && "cannot map to zero values"); + for (Value v : to) + assert(v && "expected non-null value"); +#endif + // assert(from != to && "cannot map value to itself"); + // TODO: Check for cyclic mapping. + assert(!mapping.contains(from) && "value is already mapped"); + mapping[from].assign(to.begin(), to.end()); + } + + void mapMaterialization(SmallVector<Value, 1> from, + SmallVector<Value, 1> to) { +#ifndef NDEBUG + assert(!from.empty() && "from cannot be empty"); + assert(!to.empty() && "to cannot be empty"); + for (Value v : from) { + assert(v && "expected non-null value"); + assert(!mapping.contains(v) && + "cannot add materialization for mapped value"); + } + for (Value v : to) { + assert(v && "expected non-null value"); + } + assert(TypeRange(from) != TypeRange(to) && + "cannot add materialization for identical type"); + for (const SmallVector<Value, 1> &mat : materializations[from]) + assert(TypeRange(mat) != TypeRange(to) && + "cannot register duplicate materialization"); +#endif // NDEBUG + materializations[from].push_back(to); + } + + void eraseMaterialization(SmallVector<Value, 1> from, + SmallVector<Value, 1> to) { + auto it = llvm::find(materializations[from], to); + if (it == materializations[from].end()) + return; + materializations[from].erase(it); + } /// Returns the inverse raw value mapping (without recursive query support). DenseMap<Value, SmallVector<Value>> getInverse() const { DenseMap<Value, SmallVector<Value>> inverse; - for (auto &it : mapping.getValueMap()) - inverse[it.second].push_back(it.first); + + for (auto &it : mapping) + for (Value v : it.second) + inverse[v].push_back(it.first); + + for (auto &it : materializations) + for (const SmallVector<Value, 1> &mat : it.second) + for (Value v : mat) + for (Value v2 : it.first) + inverse[v].push_back(v2); + return inverse; } private: - /// Current value mappings. - IRMapping mapping; + /// Replacement mapping: Value -> ValueRange + DenseMap<Value, SmallVector<Value, 1>> mapping; + + /// Materializations: ValueRange -> ValueRange* + DenseMap<SmallVector<Value, 1>, SmallVector<SmallVector<Value, 1>>, + SmallVectorMapInfo> + materializations; }; } // namespace -Value ConversionValueMapping::lookupOrDefault(Value from, - Type desiredType) const { - // Try to find the deepest value that has the desired type. If there is no - // such value, simply return the deepest value. - Value desiredValue; - do { - if (!desiredType || from.getType() == desiredType) - desiredValue = from; - - Value mappedValue = mapping.lookupOrNull(from); - if (!mappedValue) - break; - from = mappedValue; - } while (true); +SmallVector<Value, 1> +ConversionValueMapping::lookupOrDefault(Value from) const { + SmallVector<Value, 1> to = lookupOrNull(from); + return to.empty() ? SmallVector<Value, 1>{from} : to; +} - // If the desired value was found use it, otherwise default to the leaf value. - return desiredValue ? desiredValue : from; +SmallVector<Value, 1> ConversionValueMapping::lookupOrDefault( + Value from, SmallVector<Type, 1> desiredTypes) const { +#ifndef NDEBUG + assert(desiredTypes.size() > 0 && "expected non-empty types"); + for (Type t : desiredTypes) + assert(t && "expected non-null type"); +#endif // NDEBUG + + SmallVector<Value, 1> vals = lookupOrNull(from); + if (vals.empty()) { + // Value is not mapped. Return if the type matches. + if (TypeRange(from) == desiredTypes) + return {from}; + // Check materializations. + auto it = materializations.find({from}); + if (it == materializations.end()) + return {}; + for (const SmallVector<Value, 1> &mat : it->second) + if (TypeRange(mat) == desiredTypes) + return mat; + return {}; + } + + return lookupOrNull(from, desiredTypes); } -Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const { - Value result = lookupOrDefault(from, desiredType); - if (result == from || (desiredType && result.getType() != desiredType)) - return nullptr; +SmallVector<Value, 1> ConversionValueMapping::lookupOrNull(Value from) const { + auto it = mapping.find(from); + if (it == mapping.end()) + return {}; + SmallVector<Value, 1> result; + for (Value v : it->second) { + llvm::append_range(result, lookupOrDefault(v)); + } return result; } -bool ConversionValueMapping::tryMap(Value oldVal, Value newVal) { - for (Value it = newVal; it; it = mapping.lookupOrNull(it)) - if (it == oldVal) - return false; - map(oldVal, newVal); - return true; +SmallVector<Value, 1> +ConversionValueMapping::lookupOrNull(Value from, + SmallVector<Type, 1> desiredTypes) const { +#ifndef NDEBUG + assert(desiredTypes.size() > 0 && "expected non-empty types"); + for (Type t : desiredTypes) + assert(t && "expected non-null type"); +#endif // NDEBUG + + SmallVector<Value, 1> vals = lookupOrNull(from); + if (vals.empty()) + return {}; + + // There is a mapping and the types match. + if (TypeRange(vals) == desiredTypes) + return vals; + + // There is a mapping, but the types do not match. Try to find a matching + // materialization. + auto it = materializations.find(vals); + if (it == materializations.end()) + return {}; + for (const SmallVector<Value, 1> &mat : it->second) + if (TypeRange(mat) == desiredTypes) + return mat; + + // No materialization found. Return an empty vector. + return {}; } //===----------------------------------------------------------------------===// @@ -781,7 +931,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { LogicalResult remapValues(StringRef valueDiagTag, std::optional<Location> inputLoc, PatternRewriter &rewriter, ValueRange values, - SmallVectorImpl<Value> &remapped); + SmallVector<SmallVector<Value, 1>> &remapped); /// Return "true" if the given operation is ignored, and does not need to be /// converted. @@ -817,27 +967,12 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// Build an unresolved materialization operation given an output type and set /// of input operands. - Value buildUnresolvedMaterialization(MaterializationKind kind, - OpBuilder::InsertPoint ip, Location loc, - ValueRange inputs, Type outputType, - Type originalType, - const TypeConverter *converter); - - /// Build an N:1 materialization for the given original value that was - /// replaced with the given replacement values. - /// - /// This is a workaround around incomplete 1:N support in the dialect - /// conversion driver. The conversion mapping can store only 1:1 replacements - /// and the conversion patterns only support single Value replacements in the - /// adaptor, so N values must be converted back to a single value. This - /// function will be deleted when full 1:N support has been added. - /// - /// This function inserts an argument materialization back to the original - /// type, followed by a target materialization to the legalized type (if - /// applicable). - void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc, - ValueRange replacements, Value originalValue, - const TypeConverter *converter); + ValueRange buildUnresolvedMaterialization(MaterializationKind kind, + OpBuilder::InsertPoint ip, + Location loc, ValueRange inputs, + TypeRange outputTypes, + Type originalType, + const TypeConverter *converter); //===--------------------------------------------------------------------===// // Rewriter Notification Hooks @@ -1072,10 +1207,8 @@ UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite( } void UnresolvedMaterializationRewrite::rollback() { - if (getMaterializationKind() == MaterializationKind::Target) { - for (Value input : op->getOperands()) - rewriterImpl.mapping.erase(input); - } + rewriterImpl.mapping.eraseMaterialization(op->getOperands(), + op->getResults()); rewriterImpl.unresolvedMaterializations.erase(getOperation()); op->erase(); } @@ -1120,7 +1253,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) { LogicalResult ConversionPatternRewriterImpl::remapValues( StringRef valueDiagTag, std::optional<Location> inputLoc, PatternRewriter &rewriter, ValueRange values, - SmallVectorImpl<Value> &remapped) { + SmallVector<SmallVector<Value, 1>> &remapped) { remapped.reserve(llvm::size(values)); for (const auto &it : llvm::enumerate(values)) { @@ -1132,7 +1265,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( // The current pattern does not have a type converter. I.e., it does not // distinguish between legal and illegal types. For each operand, simply // pass through the most recently mapped value. - remapped.push_back(mapping.lookupOrDefault(operand)); + SmallVector<Value, 1> vals = mapping.lookupOrDefault(operand); + remapped.push_back(vals); continue; } @@ -1146,36 +1280,29 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( return failure(); } - if (legalTypes.size() != 1) { - // TODO: Parts of the dialect conversion infrastructure do not support - // 1->N type conversions yet. Therefore, if a type is converted to 0 or - // multiple types, the only thing that we can do for now is passing - // through the most recently mapped value. Fixing this requires - // improvements to the `ConversionValueMapping` (to be able to store 1:N - // mappings) and to the `ConversionPattern` adaptor handling (to be able - // to pass multiple remapped values for a single operand to the adaptor). - remapped.push_back(mapping.lookupOrDefault(operand)); + // Try to find a mapped value with the desired type. + if (legalTypes.empty()) { + remapped.push_back({}); continue; } - // Handle 1->1 type conversions. - Type desiredType = legalTypes.front(); - // Try to find a mapped value with the desired type. (Or the operand itself - // if the value is not mapped at all.) - Value newOperand = mapping.lookupOrDefault(operand, desiredType); - if (newOperand.getType() != desiredType) { - // If the looked up value's type does not have the desired type, it means - // that the value was replaced with a value of different type and no - // source materialization was created yet. - Value castValue = buildUnresolvedMaterialization( - MaterializationKind::Target, computeInsertPoint(newOperand), - operandLoc, - /*inputs=*/newOperand, /*outputType=*/desiredType, - /*originalType=*/origType, currentTypeConverter); - mapping.map(newOperand, castValue); - newOperand = castValue; + SmallVector<Value, 1> mat = mapping.lookupOrDefault(operand, legalTypes); + if (!mat.empty()) { + // Mapped value has the correct type or there is an existing + // materialization. Or the value is not mapped at all and has the + // correct type. + remapped.push_back(mat); + continue; } - remapped.push_back(newOperand); + + // Create a materialization for the most recently mapped value. + SmallVector<Value, 1> vals = mapping.lookupOrDefault(operand); + ValueRange castValues = buildUnresolvedMaterialization( + MaterializationKind::Target, computeInsertPoint(vals), operandLoc, + /*inputs=*/vals, /*outputTypes=*/legalTypes, /*originalType=*/origType, currentTypeConverter); + + mapping.mapMaterialization(vals, castValues); + remapped.push_back(castValues); } return success(); } @@ -1287,7 +1414,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( MaterializationKind::Source, OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(), /*inputs=*/ValueRange(), - /*outputType=*/origArgType, /*originalType=*/Type(), converter); + /*outputTypes=*/origArgType, /*originalType=*/Type(), converter)[0]; mapping.map(origArg, repl); appendRewrite<ReplaceBlockArgRewrite>(block, origArg); continue; @@ -1303,15 +1430,10 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( continue; } - // This is a 1->1+ mapping. 1->N mappings are not fully supported in the - // dialect conversion. Therefore, we need an argument materialization to - // turn the replacement block arguments into a single SSA value that can be - // used as a replacement. + // Map to replacement arguments. auto replArgs = newBlock->getArguments().slice(inputMap->inputNo, inputMap->size); - insertNTo1Materialization( - OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(), - /*replacements=*/replArgs, /*outputValue=*/origArg, converter); + mapping.map(origArg, replArgs); appendRewrite<ReplaceBlockArgRewrite>(block, origArg); } @@ -1330,59 +1452,21 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( /// Build an unresolved materialization operation given an output type and set /// of input operands. -Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization( +ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, - ValueRange inputs, Type outputType, Type originalType, - const TypeConverter *converter) { - assert((!originalType || kind == MaterializationKind::Target) && - "original type is valid only for target materializations"); - + ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter) { // Avoid materializing an unnecessary cast. - if (inputs.size() == 1 && inputs.front().getType() == outputType) - return inputs.front(); + if (TypeRange(inputs) == outputTypes) + return inputs; // Create an unresolved materialization. We use a new OpBuilder to avoid // tracking the materialization like we do for other operations. - OpBuilder builder(outputType.getContext()); + OpBuilder builder(outputTypes.front().getContext()); builder.setInsertionPoint(ip.getBlock(), ip.getPoint()); auto convertOp = - builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs); - appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind, - originalType); - return convertOp.getResult(0); -} - -void ConversionPatternRewriterImpl::insertNTo1Materialization( - OpBuilder::InsertPoint ip, Location loc, ValueRange replacements, - Value originalValue, const TypeConverter *converter) { - // Insert argument materialization back to the original type. - Type originalType = originalValue.getType(); - Value argMat = - buildUnresolvedMaterialization(MaterializationKind::Argument, ip, loc, - /*inputs=*/replacements, originalType, - /*originalType=*/Type(), converter); - mapping.map(originalValue, argMat); - - // Insert target materialization to the legalized type. - Type legalOutputType; - if (converter) { - legalOutputType = converter->convertType(originalType); - } else if (replacements.size() == 1) { - // When there is no type converter, assume that the replacement value - // types are legal. This is reasonable to assume because they were - // specified by the user. - // FIXME: This won't work for 1->N conversions because multiple output - // types are not supported in parts of the dialect conversion. In such a - // case, we currently use the original value type. - legalOutputType = replacements[0].getType(); - } - if (legalOutputType && legalOutputType != originalType) { - Value targetMat = buildUnresolvedMaterialization( - MaterializationKind::Target, computeInsertPoint(argMat), loc, - /*inputs=*/argMat, /*outputType=*/legalOutputType, - /*originalType=*/originalType, converter); - mapping.map(argMat, targetMat); - } + builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs); + appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind, originalType); + return convertOp.getResults(); } //===----------------------------------------------------------------------===// @@ -1432,12 +1516,11 @@ void ConversionPatternRewriterImpl::notifyOpReplaced( } // Materialize a replacement value "out of thin air". - Value sourceMat = buildUnresolvedMaterialization( + repl = buildUnresolvedMaterialization( MaterializationKind::Source, computeInsertPoint(result), result.getLoc(), /*inputs=*/ValueRange(), /*outputType=*/result.getType(), /*originalType=*/Type(), currentTypeConverter); - repl.push_back(sourceMat); } else { // Make sure that the user does not mess with unresolved materializations // that were inserted by the conversion driver. We keep track of these @@ -1450,18 +1533,8 @@ void ConversionPatternRewriterImpl::notifyOpReplaced( } // Remap result to replacement value. - if (repl.empty()) - continue; - - if (repl.size() == 1) { - // Single replacement value: replace directly. - mapping.map(result, repl.front()); - } else { - // Multiple replacement values: insert N:1 materialization. - insertNTo1Materialization(computeInsertPoint(result), result.getLoc(), - /*replacements=*/repl, /*outputValue=*/result, - currentTypeConverter); - } + if (!repl.empty()) + mapping.map(result, repl); } appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter); @@ -1612,15 +1685,18 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, << "'(" << from.getOwner()->getParentOp() << ")\n"; }); impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from); - impl->mapping.map(impl->mapping.lookupOrDefault(from), to); + SmallVector<Value, 1> mapped = impl->mapping.lookupOrDefault(from); + assert(mapped.size() == 1 && "replaceUsesOfBlockArgument is not supported for 1:N replacements"); + impl->mapping.map(mapped.front(), to); } Value ConversionPatternRewriter::getRemappedValue(Value key) { - SmallVector<Value> remappedValues; + SmallVector<SmallVector<Value, 1>> remappedValues; if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key, remappedValues))) return nullptr; - return remappedValues.front(); + assert(remappedValues.front().size() == 1 && "1:N conversion not supported"); + return remappedValues.front().front(); } LogicalResult @@ -1628,8 +1704,15 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys, SmallVectorImpl<Value> &results) { if (keys.empty()) return success(); - return impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys, - results); + SmallVector<SmallVector<Value, 1>> remapped; + if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys, + remapped))) + return failure(); + for (const auto &values : remapped) { + assert(values.size() == 1 && "1:N conversion not supported"); + results.push_back(values.front()); + } + return success(); } void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, @@ -1723,6 +1806,19 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { // ConversionPattern //===----------------------------------------------------------------------===// +SmallVector<Value> ConversionPattern::getOneToOneAdaptorOperands( + ArrayRef<ArrayRef<Value>> operands) const { + SmallVector<Value> oneToOneOperands; + oneToOneOperands.reserve(operands.size()); + for (ArrayRef<Value> operand : operands) { + if (operand.size() != 1) + llvm::report_fatal_error("pattern '" + getDebugName() + + "' does not support 1:N conversion"); + oneToOneOperands.push_back(operand.front()); + } + return oneToOneOperands; +} + LogicalResult ConversionPattern::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { @@ -1734,12 +1830,18 @@ ConversionPattern::matchAndRewrite(Operation *op, getTypeConverter()); // Remap the operands of the operation. - SmallVector<Value, 4> operands; + SmallVector<SmallVector<Value, 1>> remapped; if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter, - op->getOperands(), operands))) { + op->getOperands(), remapped))) { return failure(); } - return matchAndRewrite(op, operands, dialectRewriter); + + // Convert to ArrayRef. + // TODO: This should not be necessary. + SmallVector<ArrayRef<Value>> remappedArrayRef; + for (const auto &vals : remapped) + remappedArrayRef.push_back(vals); + return matchAndRewrite(op, remappedArrayRef, dialectRewriter); } //===----------------------------------------------------------------------===// @@ -2483,45 +2585,40 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter, assert(!op.use_empty() && "expected that dead materializations have already been DCE'd"); Operation::operand_range inputOperands = op.getOperands(); - Type outputType = op.getResultTypes()[0]; // Try to materialize the conversion. if (const TypeConverter *converter = rewrite->getConverter()) { rewriter.setInsertionPoint(op); - Value newMaterialization; + SmallVector<Value> newMaterialization; switch (rewrite->getMaterializationKind()) { case MaterializationKind::Argument: - // Try to materialize an argument conversion. - newMaterialization = converter->materializeArgumentConversion( - rewriter, op->getLoc(), outputType, inputOperands); - if (newMaterialization) - break; - // If an argument materialization failed, fallback to trying a target - // materialization. - [[fallthrough]]; + llvm_unreachable("argument materializations have been removed"); case MaterializationKind::Target: newMaterialization = converter->materializeTargetConversion( - rewriter, op->getLoc(), outputType, inputOperands, + rewriter, op->getLoc(), op.getResultTypes(), inputOperands, rewrite->getOriginalType()); break; case MaterializationKind::Source: - newMaterialization = converter->materializeSourceConversion( - rewriter, op->getLoc(), outputType, inputOperands); + assert(op.getNumResults() == 1 && "*:N source materializations are not supported"); + Value sourceMat = converter->materializeSourceConversion( + rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands); + if (sourceMat) + newMaterialization.push_back(sourceMat); break; } - if (newMaterialization) { - assert(newMaterialization.getType() == outputType && + if (!newMaterialization.empty()) { + assert(TypeRange(newMaterialization) == op.getResultTypes() && "materialization callback produced value of incorrect type"); rewriter.replaceOp(op, newMaterialization); return success(); } } - InFlightDiagnostic diag = - op->emitError() << "failed to legalize unresolved materialization " - "from (" - << inputOperands.getTypes() << ") to (" << outputType - << ") that remained live after conversion"; + InFlightDiagnostic diag = op->emitError() + << "failed to legalize unresolved materialization " + "from (" + << inputOperands.getTypes() << ") to (" << op.getResultTypes() + << ") that remained live after conversion"; diag.attachNote(op->getUsers().begin()->getLoc()) << "see existing live user here: " << *op->getUsers().begin(); return failure(); @@ -2642,6 +2739,11 @@ void OperationConverter::finalize(ConversionPatternRewriter &rewriter) { std::tie(replacedValues, converter) = getReplacedValues(rewriterImpl.rewrites[i].get()); for (Value originalValue : replacedValues) { + // If this value is directly replaced with a value of the same type, + // there is nothing to do. + Value repl = rewriterImpl.mapping.lookupDirectSingleReplacement(originalValue); + if (repl && repl.getType() == originalValue.getType()) + continue; // If the type of this value changed and the value is still live, we need // to materialize a conversion. if (rewriterImpl.mapping.lookupOrNull(originalValue, @@ -2653,16 +2755,16 @@ void OperationConverter::finalize(ConversionPatternRewriter &rewriter) { continue; // Legalize this value replacement. - Value newValue = rewriterImpl.mapping.lookupOrNull(originalValue); - assert(newValue && "replacement value not found"); + SmallVector<Value, 1> newValues = + rewriterImpl.mapping.lookupOrNull(originalValue); + assert(!newValues.empty() && "replacement value not found"); Value castValue = rewriterImpl.buildUnresolvedMaterialization( - MaterializationKind::Source, computeInsertPoint(newValue), + MaterializationKind::Source, computeInsertPoint(newValues), originalValue.getLoc(), - /*inputs=*/newValue, /*outputType=*/originalValue.getType(), - /*originalType=*/Type(), converter); - rewriterImpl.mapping.map(originalValue, castValue); - inverseMapping[castValue].push_back(originalValue); - llvm::erase(inverseMapping[newValue], originalValue); + /*inputs=*/newValues, /*outputTypes=*/originalValue.getType(), /*originalType=*/Type(), + converter)[0]; + rewriterImpl.mapping.mapMaterialization(newValues, {castValue}); + llvm::append_range(inverseMapping[castValue], newValues); } } } diff --git a/mlir/test/Transforms/decompose-call-graph-types.mlir b/mlir/test/Transforms/decompose-call-graph-types.mlir index b8fad63..4e64131 100644 --- a/mlir/test/Transforms/decompose-call-graph-types.mlir +++ b/mlir/test/Transforms/decompose-call-graph-types.mlir @@ -9,10 +9,7 @@ // CHECK-LABEL: func @identity( // CHECK-SAME: %[[ARG0:.*]]: i1, // CHECK-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { -// CHECK: %[[ARG_MATERIALIZED:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> tuple<i1, i32> -// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1 -// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32 -// CHECK: return %[[RET0]], %[[RET1]] : i1, i32 +// CHECK: return %[[ARG0]], %[[ARG1]] : i1, i32 // CHECK-12N-LABEL: func @identity( // CHECK-12N-SAME: %[[ARG0:.*]]: i1, // CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { @@ -56,18 +53,7 @@ func.func @recursive_decomposition(%arg0: tuple<tuple<tuple<i1>>>) -> tuple<tupl // CHECK-LABEL: func @mixed_recursive_decomposition( // CHECK-SAME: %[[ARG0:.*]]: i1, // CHECK-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) { -// CHECK: %[[V0:.*]] = "test.make_tuple"() : () -> tuple<> -// CHECK: %[[V1:.*]] = "test.make_tuple"(%[[ARG0]]) : (i1) -> tuple<i1> -// CHECK: %[[V2:.*]] = "test.make_tuple"(%[[ARG1]]) : (i2) -> tuple<i2> -// CHECK: %[[V3:.*]] = "test.make_tuple"(%[[V2]]) : (tuple<i2>) -> tuple<tuple<i2>> -// CHECK: %[[V4:.*]] = "test.make_tuple"(%[[V0]], %[[V1]], %[[V3]]) : (tuple<>, tuple<i1>, tuple<tuple<i2>>) -> tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>> -// CHECK: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 0 : i32}> : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<> -// CHECK: %[[V6:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 1 : i32}> : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<i1> -// CHECK: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 0 : i32}> : (tuple<i1>) -> i1 -// CHECK: %[[V8:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 2 : i32}> : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<tuple<i2>> -// CHECK: %[[V9:.*]] = "test.get_tuple_element"(%[[V8]]) <{index = 0 : i32}> : (tuple<tuple<i2>>) -> tuple<i2> -// CHECK: %[[V10:.*]] = "test.get_tuple_element"(%[[V9]]) <{index = 0 : i32}> : (tuple<i2>) -> i2 -// CHECK: return %[[V7]], %[[V10]] : i1, i2 +// CHECK: return %[[ARG0]], %[[ARG1]] : i1, i2 // CHECK-12N-LABEL: func @mixed_recursive_decomposition( // CHECK-12N-SAME: %[[ARG0:.*]]: i1, // CHECK-12N-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) { @@ -87,14 +73,8 @@ func.func private @callee(tuple<i1, i32>) -> tuple<i1, i32> // CHECK-LABEL: func @caller( // CHECK-SAME: %[[ARG0:.*]]: i1, // CHECK-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { -// CHECK: %[[ARG_MATERIALIZED:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> tuple<i1, i32> -// CHECK: %[[CALL_ARG0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1 -// CHECK: %[[CALL_ARG1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32 -// CHECK: %[[DECOMPOSED:.*]]:2 = call @callee(%[[CALL_ARG0]], %[[CALL_ARG1]]) : (i1, i32) -> (i1, i32) -// CHECK: %[[CALL_RESULT_RECOMPOSED:.*]] = "test.make_tuple"(%[[DECOMPOSED]]#0, %[[DECOMPOSED]]#1) : (i1, i32) -> tuple<i1, i32> -// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1 -// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32 -// CHECK: return %[[RET0]], %[[RET1]] : i1, i32 +// CHECK: %[[V0:.*]]:2 = call @callee(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> (i1, i32) +// CHECK: return %[[V0]]#0, %[[V0]]#1 : i1, i32 // CHECK-12N-LABEL: func @caller( // CHECK-12N-SAME: %[[ARG0:.*]]: i1, // CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { @@ -190,14 +170,8 @@ func.func private @callee(tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6) -> (tup // CHECK-SAME: %[[I4:.*]]: i4, // CHECK-SAME: %[[I5:.*]]: i5, // CHECK-SAME: %[[I6:.*]]: i6) -> (i1, i2, i3, i4, i5, i6) { -// CHECK: %[[ARG_TUPLE:.*]] = "test.make_tuple"(%[[I4]], %[[I5]]) : (i4, i5) -> tuple<i4, i5> -// CHECK: %[[ARG_TUPLE_0:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) <{index = 0 : i32}> : (tuple<i4, i5>) -> i4 -// CHECK: %[[ARG_TUPLE_1:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) <{index = 1 : i32}> : (tuple<i4, i5>) -> i5 -// CHECK: %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[ARG_TUPLE_0]], %[[ARG_TUPLE_1]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6) -// CHECK: %[[RET_TUPLE:.*]] = "test.make_tuple"(%[[CALL]]#3, %[[CALL]]#4) : (i4, i5) -> tuple<i4, i5> -// CHECK: %[[RET_TUPLE_0:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) <{index = 0 : i32}> : (tuple<i4, i5>) -> i4 -// CHECK: %[[RET_TUPLE_1:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) <{index = 1 : i32}> : (tuple<i4, i5>) -> i5 -// CHECK: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[RET_TUPLE_0]], %[[RET_TUPLE_1]], %[[CALL]]#5 : i1, i2, i3, i4, i5, i6 +// CHECK: %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[I4]], %[[I5]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6) +// CHECK: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[CALL]]#3, %[[CALL]]#4, %[[CALL]]#5 : i1, i2, i3, i4, i5, i6 // CHECK-12N-LABEL: func @caller( // CHECK-12N-SAME: %[[I1:.*]]: i1, // CHECK-12N-SAME: %[[I2:.*]]: i2, diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp index de511c5..0b8d4c0 100644 --- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp +++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp @@ -139,7 +139,7 @@ struct TestDecomposeCallGraphTypes tupleType.getFlattenedTypes(types); return success(); }); - typeConverter.addArgumentMaterialization(buildMakeTupleOp); + typeConverter.addSourceMaterialization(buildMakeTupleOp); typeConverter.addTargetMaterialization(buildDecomposeTuple); populateDecomposeCallGraphTypesPatterns(context, typeConverter, patterns); diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 3df6cff..9154964 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -1076,6 +1076,7 @@ struct TestUpdateConsumerType : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const final { + llvm::errs() << "TestUpdateConsumerType operand: " << operands.front() << "\n"; // Verify that the incoming operand has been successfully remapped to F64. if (!operands[0].getType().isF64()) return failure(); |