diff options
author | Nicolas Vasilache <ntv@google.com> | 2020-09-21 15:30:42 -0400 |
---|---|---|
committer | Nicolas Vasilache <ntv@google.com> | 2020-09-22 04:13:22 -0400 |
commit | ed229132f1c4ea2ba0644fc345d8279e47a00565 (patch) | |
tree | ddd4c9f341131c15e1c2d33c8c7181cb5142e313 /mlir/test/lib | |
parent | 0304c6da1006f6b472a1e5c1d8776a9f35c9439d (diff) | |
download | llvm-ed229132f1c4ea2ba0644fc345d8279e47a00565.zip llvm-ed229132f1c4ea2ba0644fc345d8279e47a00565.tar.gz llvm-ed229132f1c4ea2ba0644fc345d8279e47a00565.tar.bz2 |
[mlir][Linalg] Uniformize linalg.generic with named ops.
This revision allows representing a reduction at the level of linalg on tensors for generic ops by uniformizing with the named ops approach.
Diffstat (limited to 'mlir/test/lib')
-rw-r--r-- | mlir/test/lib/Transforms/TestBufferPlacement.cpp | 73 |
1 file changed, 46 insertions, 27 deletions
diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
index c338f0f..dd6629e 100644
--- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp
+++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
@@ -39,6 +39,11 @@ struct TestBufferPlacementPreparationPass
 
   /// Converts tensor-type generic linalg operations to memref ones using
   /// buffer assignment.
+  /// TODO: Avoid the copy-pasta by exposing the pattern from BufferPlacement.h
+  /// This is limited by not wanting BufferPlacement to depend on Linalg. Fixing
+  /// this probably requires an OpConversionPattern over generic Operation*. For
+  /// now only RewritePattern but not ConversionPattern allow this.
+
   class GenericOpConverter
       : public BufferAssignmentOpConversionPattern<linalg::GenericOp> {
   public:
@@ -48,34 +53,47 @@ struct TestBufferPlacementPreparationPass
     LogicalResult
     matchAndRewrite(linalg::GenericOp op, ArrayRef<Value> operands,
                     ConversionPatternRewriter &rewriter) const final {
+      linalg::GenericOpAdaptor adaptor(operands,
+                                       op.getOperation()->getAttrDictionary());
+
+      // TODO: support ops with reduction.
+      if (!op.init_tensors().empty())
+        return failure();
+
+      // All inputs need to be turned into buffers first. Until then, bail out.
+      if (llvm::any_of(adaptor.inputs(), [](Value in) {
+            return !in.getType().isa<MemRefType>();
+          }))
+        return failure();
+
       Location loc = op.getLoc();
-      ResultRange results = op.getOperation()->getResults();
-      SmallVector<Value, 2> newArgs, newResults;
-      newArgs.reserve(operands.size() + results.size());
-      newArgs.append(operands.begin(), operands.end());
-      newResults.reserve(results.size());
+      SmallVector<Value, 2> outputBuffers, newOutputBuffers;
+      outputBuffers.assign(adaptor.output_buffers().begin(),
+                           adaptor.output_buffers().end());
+      newOutputBuffers.reserve(op.getNumOutputs());
+      newOutputBuffers.append(adaptor.output_buffers().begin(),
+                              adaptor.output_buffers().end());
 
       // Update all types to memref types.
-      for (auto result : results) {
-        ShapedType type = result.getType().cast<ShapedType>();
-        assert(type && "Generic operations with non-shaped typed results are "
-                       "not currently supported.");
+      for (Type t : op.getResultTypes()) {
+        auto type = t.cast<ShapedType>();
         if (!type.hasStaticShape())
           return rewriter.notifyMatchFailure(
               op, "dynamic shapes not currently supported");
         auto memrefType =
             MemRefType::get(type.getShape(), type.getElementType());
         auto alloc = rewriter.create<AllocOp>(loc, memrefType);
-        newArgs.push_back(alloc);
-        newResults.push_back(alloc);
+        newOutputBuffers.push_back(alloc);
       }
 
       // Generate a new linalg operation that works on buffers.
       auto linalgOp = rewriter.create<linalg::GenericOp>(
-          loc, llvm::None, newArgs, rewriter.getI64IntegerAttr(operands.size()),
-          rewriter.getI64IntegerAttr(results.size()), op.indexing_maps(),
-          op.iterator_types(), op.docAttr(), op.library_callAttr(),
-          op.symbol_sourceAttr());
+          loc,
+          /*resultTensorTypes=*/ArrayRef<Type>{},
+          /*inputs=*/adaptor.inputs(),
+          /*outputBuffers=*/newOutputBuffers,
+          /*initTensors=*/ValueRange{}, op.indexing_maps(), op.iterator_types(),
+          op.docAttr(), op.library_callAttr(), op.symbol_sourceAttr());
 
       // Create a new block in the region of the new Generic Op.
       Block &oldBlock = op.getRegion().front();
@@ -83,23 +101,24 @@ struct TestBufferPlacementPreparationPass
       Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
                                              oldBlock.getArgumentTypes());
 
-      // Map the old block arguments to the new ones.
-      BlockAndValueMapping mapping;
-      mapping.map(oldBlock.getArguments(), newBlock->getArguments());
-
       // Add the result arguments to the new block.
-      for (auto result : newResults)
-        newBlock->addArgument(
-            result.getType().cast<ShapedType>().getElementType());
+      for (Value v : newOutputBuffers)
+        newBlock->addArgument(v.getType().cast<MemRefType>().getElementType());
 
       // Clone the body of the old block to the new block.
+      BlockAndValueMapping mapping;
+      for (unsigned i = 0; i < oldBlock.getNumArguments(); i++)
+        mapping.map(oldBlock.getArgument(i), newBlock->getArgument(i));
+
+      OpBuilder::InsertionGuard guard(rewriter);
       rewriter.setInsertionPointToEnd(newBlock);
-      for (auto &op : oldBlock.getOperations())
-        rewriter.clone(op, mapping);
+      for (auto &op : oldBlock.getOperations()) {
+        Operation *clonedOp = rewriter.clone(op, mapping);
+        mapping.map(op.getResults(), clonedOp->getResults());
+      }
 
-      // Replace the results of the old Generic Op with the results of the new
-      // one.
-      rewriter.replaceOp(op, newResults);
+      // Replace the results of the old op with the new output buffers.
+      rewriter.replaceOp(op, newOutputBuffers);
       return success();
     }
   };