diff options
author | Eugene Zhulenev <ezhulenev@google.com> | 2021-10-07 15:52:17 -0700 |
---|---|---|
committer | Eugene Zhulenev <ezhulenev@google.com> | 2021-10-08 03:17:20 -0700 |
commit | e2a37bb5407e6ccf465bd870e89505768497ca50 (patch) | |
tree | 556c516059077dc162e7513b71486da0c0a8a99c | |
parent | 4b46a4134385eda6449e93ad124f9b2685b25bb2 (diff) | |
download | llvm-e2a37bb5407e6ccf465bd870e89505768497ca50.zip llvm-e2a37bb5407e6ccf465bd870e89505768497ca50.tar.gz llvm-e2a37bb5407e6ccf465bd870e89505768497ca50.tar.bz2 |
[mlir] Add alignment option to constant tensor bufferization pass
Reviewed By: bkramer
Differential Revision: https://reviews.llvm.org/D111364
5 files changed, 33 insertions, 7 deletions
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h index c7e331e..58eb7f0 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h @@ -39,7 +39,7 @@ void populateTensorConstantBufferizePatterns( RewritePatternSet &patterns); /// Creates an instance of tensor constant bufferization pass. -std::unique_ptr<Pass> createTensorConstantBufferizePass(); +std::unique_ptr<Pass> createTensorConstantBufferizePass(unsigned alignment = 0); /// Creates an instance of the StdExpand pass that legalizes Std /// dialect ops to be convertible to LLVM. For example, diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td index 15f63e9..286b685 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td @@ -62,6 +62,10 @@ def TensorConstantBufferize : Pass<"tensor-constant-bufferize", "ModuleOp"> { }]; let constructor = "mlir::createTensorConstantBufferizePass()"; let dependentDialects = ["memref::MemRefDialect"]; + let options = [ + Option<"alignment", "alignment", "unsigned", /*default=*/"0", + "Create global memrefs with a specified alignment">, + ]; } #endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/Transforms/BufferUtils.h b/mlir/include/mlir/Transforms/BufferUtils.h index f73c97d..5d1b9e5 100644 --- a/mlir/include/mlir/Transforms/BufferUtils.h +++ b/mlir/include/mlir/Transforms/BufferUtils.h @@ -125,11 +125,13 @@ class GlobalOp; // names. Duplicates are avoided. class GlobalCreator { public: - explicit GlobalCreator(ModuleOp module) : moduleOp(module) {} + GlobalCreator(ModuleOp module, unsigned alignment = 0) + : moduleOp(module), alignment(alignment) {} memref::GlobalOp getGlobalFor(ConstantOp constantOp); private: ModuleOp moduleOp; + unsigned alignment; // This could use memref::GlobalOp key but we avoid introducing a new // dependence to the memref dialect for this. DenseMap<Attribute, Operation *> globals; diff --git a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp index 39dbd8e..df61b1d 100644 --- a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp @@ -43,13 +43,18 @@ memref::GlobalOp GlobalCreator::getGlobalFor(ConstantOp constantOp) { interleave(type.getShape(), os, "x"); os << "x" << type.getElementType(); + // Add an optional alignment to the global memref. + IntegerAttr memrefAlignment = + alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment) + : IntegerAttr(); + auto global = globalBuilder.create<memref::GlobalOp>( constantOp.getLoc(), (Twine("__constant_") + os.str()).str(), /*sym_visibility=*/globalBuilder.getStringAttr("private"), /*type=*/typeConverter.convertType(type).cast<MemRefType>(), /*initial_value=*/constantOp.getValue().cast<ElementsAttr>(), /*constant=*/true, - /*alignment=*/IntegerAttr()); + /*alignment=*/memrefAlignment); symbolTable.insert(global); // The symbol table inserts at the end of the module, but globals are a bit // nicer if they are at the beginning. @@ -90,11 +95,17 @@ void mlir::populateTensorConstantBufferizePatterns( } namespace { -struct TensorConstantBufferizePass +class TensorConstantBufferizePass : public TensorConstantBufferizeBase<TensorConstantBufferizePass> { +public: + explicit TensorConstantBufferizePass(unsigned alignment) { + if (alignment) + this->alignment = alignment; + } + void runOnOperation() override { auto module = getOperation(); - GlobalCreator globals(module); + GlobalCreator globals(module, alignment); auto *context = &getContext(); BufferizeTypeConverter typeConverter; @@ -111,6 +122,7 @@ struct TensorConstantBufferizePass }; } // namespace -std::unique_ptr<Pass> mlir::createTensorConstantBufferizePass() { - return std::make_unique<TensorConstantBufferizePass>(); +std::unique_ptr<Pass> +mlir::createTensorConstantBufferizePass(unsigned alignment) { + return std::make_unique<TensorConstantBufferizePass>(alignment); } diff --git a/mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir b/mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir index 16a7d7d..cdaccf3 100644 --- a/mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir +++ b/mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir @@ -1,9 +1,17 @@ // RUN: mlir-opt %s -tensor-constant-bufferize -split-input-file | FileCheck %s +// RUN: mlir-opt %s -tensor-constant-bufferize=alignment=64 -split-input-file | FileCheck --check-prefix=ALIGNED %s // CHECK-LABEL: module { + // We check the debug name too since we put some effort into making that readable. // The name isn't load-bearing though. + // CHECK: memref.global "private" constant @__constant_3x4xf32 : memref<3x4xf32> = dense<7.000000e+00> +// CHECK-NOT: alignment + +// ALIGNED: memref.global "private" constant @__constant_3x4xf32 : memref<3x4xf32> = dense<7.000000e+00> +// ALIGNED-SAME: {alignment = 64 : i64} + // CHECK: @basic func @basic() -> tensor<3x4xf32> { // CHECK: %[[MEMREF:.*]] = memref.get_global @__constant_3x4xf32 : memref<3x4xf32> |