aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp3
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp3
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp6
-rw-r--r--mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir2
4 files changed, 6 insertions, 8 deletions
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 0592009..ded1e65 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -823,8 +823,7 @@ public:
rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
if (enableBufferInitialization) {
Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize);
- Value fillValue = rewriter.create<arith::ConstantOp>(
- loc, value.getType(), rewriter.getZeroAttr(value.getType()));
+ Value fillValue = constantZero(rewriter, loc, value.getType());
Value subBuffer = rewriter.create<memref::SubViewOp>(
loc, newBuffer, /*offset=*/ValueRange{newSize},
/*size=*/ValueRange{fillSize},
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index e059bd3..4c190bc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -235,8 +235,7 @@ static Value createAllocation(OpBuilder &builder, Location loc,
Value buffer = builder.create<memref::AllocOp>(loc, memRefType, sz);
Type elemType = memRefType.getElementType();
if (enableInit) {
- Value fillValue = builder.create<arith::ConstantOp>(
- loc, elemType, builder.getZeroAttr(elemType));
+ Value fillValue = constantZero(builder, loc, elemType);
builder.create<linalg::FillOp>(loc, fillValue, buffer);
}
return buffer;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 06f57eb..6845aa0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -216,9 +216,9 @@ struct SparseTensorCodegenPass
// The following operations and dialects may be introduced by the
// codegen rules, and are therefore marked as legal.
target.addLegalOp<linalg::FillOp>();
- target.addLegalDialect<arith::ArithDialect,
- bufferization::BufferizationDialect,
- memref::MemRefDialect, scf::SCFDialect>();
+ target.addLegalDialect<
+ arith::ArithDialect, bufferization::BufferizationDialect,
+ complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();
target.addLegalOp<UnrealizedConversionCastOp>();
// Populate with rules and apply rewriting rules.
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir
index 80947a2..6ee26b9 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir
@@ -8,7 +8,7 @@
// RUN: %{command}
//
// Do the same run, but now with direct IR generation.
-// REDEFINE: %{option} = enable-runtime-library=false
+// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true"
// RUN: %{command}
#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>