aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEugene Zhulenev <ezhulenev@google.com>2021-10-07 15:52:17 -0700
committerEugene Zhulenev <ezhulenev@google.com>2021-10-08 03:17:20 -0700
commite2a37bb5407e6ccf465bd870e89505768497ca50 (patch)
tree556c516059077dc162e7513b71486da0c0a8a99c
parent4b46a4134385eda6449e93ad124f9b2685b25bb2 (diff)
downloadllvm-e2a37bb5407e6ccf465bd870e89505768497ca50.zip
llvm-e2a37bb5407e6ccf465bd870e89505768497ca50.tar.gz
llvm-e2a37bb5407e6ccf465bd870e89505768497ca50.tar.bz2
[mlir] Add alignment option to constant tensor bufferization pass
Reviewed By: bkramer Differential Revision: https://reviews.llvm.org/D111364
-rw-r--r--mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h2
-rw-r--r--mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td4
-rw-r--r--mlir/include/mlir/Transforms/BufferUtils.h4
-rw-r--r--mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp22
-rw-r--r--mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir8
5 files changed, 33 insertions, 7 deletions
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
index c7e331e..58eb7f0 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
@@ -39,7 +39,7 @@ void populateTensorConstantBufferizePatterns(
RewritePatternSet &patterns);
/// Creates an instance of tensor constant bufferization pass.
-std::unique_ptr<Pass> createTensorConstantBufferizePass();
+std::unique_ptr<Pass> createTensorConstantBufferizePass(unsigned alignment = 0);
/// Creates an instance of the StdExpand pass that legalizes Std
/// dialect ops to be convertible to LLVM. For example,
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
index 15f63e9..286b685 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
@@ -62,6 +62,10 @@ def TensorConstantBufferize : Pass<"tensor-constant-bufferize", "ModuleOp"> {
}];
let constructor = "mlir::createTensorConstantBufferizePass()";
let dependentDialects = ["memref::MemRefDialect"];
+ let options = [
+ Option<"alignment", "alignment", "unsigned", /*default=*/"0",
+ "Create global memrefs with a specified alignment">,
+ ];
}
#endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Transforms/BufferUtils.h b/mlir/include/mlir/Transforms/BufferUtils.h
index f73c97d..5d1b9e5 100644
--- a/mlir/include/mlir/Transforms/BufferUtils.h
+++ b/mlir/include/mlir/Transforms/BufferUtils.h
@@ -125,11 +125,13 @@ class GlobalOp;
// names. Duplicates are avoided.
class GlobalCreator {
public:
- explicit GlobalCreator(ModuleOp module) : moduleOp(module) {}
+ GlobalCreator(ModuleOp module, unsigned alignment = 0)
+ : moduleOp(module), alignment(alignment) {}
memref::GlobalOp getGlobalFor(ConstantOp constantOp);
private:
ModuleOp moduleOp;
+ unsigned alignment;
// This could use memref::GlobalOp key but we avoid introducing a new
// dependence to the memref dialect for this.
DenseMap<Attribute, Operation *> globals;
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
index 39dbd8e..df61b1d 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
@@ -43,13 +43,18 @@ memref::GlobalOp GlobalCreator::getGlobalFor(ConstantOp constantOp) {
interleave(type.getShape(), os, "x");
os << "x" << type.getElementType();
+ // Add an optional alignment to the global memref.
+ IntegerAttr memrefAlignment =
+ alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment)
+ : IntegerAttr();
+
auto global = globalBuilder.create<memref::GlobalOp>(
constantOp.getLoc(), (Twine("__constant_") + os.str()).str(),
/*sym_visibility=*/globalBuilder.getStringAttr("private"),
/*type=*/typeConverter.convertType(type).cast<MemRefType>(),
/*initial_value=*/constantOp.getValue().cast<ElementsAttr>(),
/*constant=*/true,
- /*alignment=*/IntegerAttr());
+ /*alignment=*/memrefAlignment);
symbolTable.insert(global);
// The symbol table inserts at the end of the module, but globals are a bit
// nicer if they are at the beginning.
@@ -90,11 +95,17 @@ void mlir::populateTensorConstantBufferizePatterns(
}
namespace {
-struct TensorConstantBufferizePass
+class TensorConstantBufferizePass
: public TensorConstantBufferizeBase<TensorConstantBufferizePass> {
+public:
+ explicit TensorConstantBufferizePass(unsigned alignment) {
+ if (alignment)
+ this->alignment = alignment;
+ }
+
void runOnOperation() override {
auto module = getOperation();
- GlobalCreator globals(module);
+ GlobalCreator globals(module, alignment);
auto *context = &getContext();
BufferizeTypeConverter typeConverter;
@@ -111,6 +122,7 @@ struct TensorConstantBufferizePass
};
} // namespace
-std::unique_ptr<Pass> mlir::createTensorConstantBufferizePass() {
- return std::make_unique<TensorConstantBufferizePass>();
+std::unique_ptr<Pass>
+mlir::createTensorConstantBufferizePass(unsigned alignment) {
+ return std::make_unique<TensorConstantBufferizePass>(alignment);
}
diff --git a/mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir b/mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir
index 16a7d7d..cdaccf3 100644
--- a/mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir
+++ b/mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir
@@ -1,9 +1,17 @@
// RUN: mlir-opt %s -tensor-constant-bufferize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -tensor-constant-bufferize=alignment=64 -split-input-file | FileCheck --check-prefix=ALIGNED %s
// CHECK-LABEL: module {
+
// We check the debug name too since we put some effort into making that readable.
// The name isn't load-bearing though.
+
// CHECK: memref.global "private" constant @__constant_3x4xf32 : memref<3x4xf32> = dense<7.000000e+00>
+// CHECK-NOT: alignment
+
+// ALIGNED: memref.global "private" constant @__constant_3x4xf32 : memref<3x4xf32> = dense<7.000000e+00>
+// ALIGNED-SAME: {alignment = 64 : i64}
+
// CHECK: @basic
func @basic() -> tensor<3x4xf32> {
// CHECK: %[[MEMREF:.*]] = memref.get_global @__constant_3x4xf32 : memref<3x4xf32>