diff options
4 files changed, 6 insertions, 8 deletions
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp index 0592009..ded1e65 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp @@ -823,8 +823,7 @@ public: rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity); if (enableBufferInitialization) { Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize); - Value fillValue = rewriter.create<arith::ConstantOp>( - loc, value.getType(), rewriter.getZeroAttr(value.getType())); + Value fillValue = constantZero(rewriter, loc, value.getType()); Value subBuffer = rewriter.create<memref::SubViewOp>( loc, newBuffer, /*offset=*/ValueRange{newSize}, /*size=*/ValueRange{fillSize}, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp index e059bd3..4c190bc 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -235,8 +235,7 @@ static Value createAllocation(OpBuilder &builder, Location loc, Value buffer = builder.create<memref::AllocOp>(loc, memRefType, sz); Type elemType = memRefType.getElementType(); if (enableInit) { - Value fillValue = builder.create<arith::ConstantOp>( - loc, elemType, builder.getZeroAttr(elemType)); + Value fillValue = constantZero(builder, loc, elemType); builder.create<linalg::FillOp>(loc, fillValue, buffer); } return buffer; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp index 06f57eb..6845aa0 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -216,9 +216,9 @@ struct SparseTensorCodegenPass // The following operations and dialects may be introduced by the // codegen rules, and are therefore marked as legal. target.addLegalOp<linalg::FillOp>(); - target.addLegalDialect<arith::ArithDialect, - bufferization::BufferizationDialect, - memref::MemRefDialect, scf::SCFDialect>(); + target.addLegalDialect< + arith::ArithDialect, bufferization::BufferizationDialect, + complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>(); target.addLegalOp<UnrealizedConversionCastOp>(); // Populate with rules and apply rewriting rules. populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir index 80947a2..6ee26b9 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir @@ -8,7 +8,7 @@ // RUN: %{command} // // Do the same run, but now with direct IR generation. -// REDEFINE: %{option} = enable-runtime-library=false +// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true" // RUN: %{command} #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> |