aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYinying Li <yinyingli@google.com>2024-05-02 12:28:34 -0400
committerGitHub <noreply@github.com>2024-05-02 12:28:34 -0400
commite71eacc5b19785bc46ce9c3d8541a0c83c65660e (patch)
treea44b540654368d137439cd1bf958610d1a373fa2
parent0708500ce0149c501e802b7ab6581770cc7a5334 (diff)
downloadllvm-e71eacc5b19785bc46ce9c3d8541a0c83c65660e.zip
llvm-e71eacc5b19785bc46ce9c3d8541a0c83c65660e.tar.gz
llvm-e71eacc5b19785bc46ce9c3d8541a0c83c65660e.tar.bz2
[mlir][sparse] Support explicit/implicit value for complex type (#90771)
-rw-r--r--mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp5
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h9
-rw-r--r--mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir15
-rwxr-xr-xmlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir20
-rw-r--r--utils/bazel/llvm-project-overlay/mlir/BUILD.bazel1
6 files changed, 42 insertions, 9 deletions
diff --git a/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
index dd6f103..6f59b69 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
@@ -45,6 +45,7 @@ add_mlir_dialect_library(MLIRSparseTensorDialect
LINK_LIBS PUBLIC
MLIRArithDialect
+ MLIRComplexDialect
MLIRDialect
MLIRDialectUtils
MLIRIR
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 028a69d..de3d300 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
@@ -663,6 +664,8 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
explicitVal = result;
} else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
explicitVal = result;
+ } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
+ explicitVal = result;
} else {
parser.emitError(parser.getNameLoc(),
"expected a numeric value for explicitVal");
@@ -678,6 +681,8 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
implicitVal = result;
} else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
implicitVal = result;
+ } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
+ implicitVal = result;
} else {
parser.emitError(parser.getNameLoc(),
"expected a numeric value for implicitVal");
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
index cf3c35f..d0ef8a6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
@@ -401,9 +401,12 @@ inline Value constantLevelTypeEncoding(OpBuilder &builder, Location loc,
// Generates a constant from a validated value carrying attribute.
inline Value genValFromAttr(OpBuilder &builder, Location loc, Attribute attr) {
- if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
- Type tp = cast<TypedAttr>(arrayAttr[0]).getType();
- return builder.create<complex::ConstantOp>(loc, tp, arrayAttr);
+ if (auto complexAttr = dyn_cast<complex::NumberAttr>(attr)) {
+ Type tp = cast<ComplexType>(complexAttr.getType()).getElementType();
+ return builder.create<complex::ConstantOp>(
+ loc, complexAttr.getType(),
+ builder.getArrayAttr({FloatAttr::get(tp, complexAttr.getReal()),
+ FloatAttr::get(tp, complexAttr.getImag())}));
}
return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(attr));
}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index 7eeda9a..7fb1c76 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -80,6 +80,21 @@ func.func private @sparse_csr(tensor<?x?xi64, #CSR_OnlyOnes>)
// -----
+#CSR_OnlyOnes = #sparse_tensor.encoding<{
+ map = (d0, d1) -> (d0 : dense, d1 : compressed),
+ posWidth = 64,
+ crdWidth = 64,
+ explicitVal = #complex.number<:f32 1.0, 0.0>,
+ implicitVal = #complex.number<:f32 0.0, 0.0>
+}>
+
+// CHECK: #[[$CSR_OnlyOnes:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = #complex.number<:f32 1.000000e+00, 0.000000e+00> : complex<f32>, implicitVal = #complex.number<:f32 0.000000e+00, 0.000000e+00> : complex<f32> }>
+// CHECK-LABEL: func private @sparse_csr(
+// CHECK-SAME: tensor<?x?xcomplex<f32>, #[[$CSR_OnlyOnes]]>)
+func.func private @sparse_csr(tensor<?x?xcomplex<f32>, #CSR_OnlyOnes>)
+
+// -----
+
#BCSR = #sparse_tensor.encoding<{
map = (d0, d1, d2) -> (d0 : batch, d1: dense, d2 : compressed),
}>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir
index 82f3147..be21725 100755
--- a/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir
@@ -2,9 +2,9 @@
// RUN: --sparsification-and-bufferization | FileCheck %s
#CSR_ones_complex = #sparse_tensor.encoding<{
- map = (d0, d1) -> (d0 : dense, d1 : compressed)
-// explicitVal = (1.0, 0.0) : complex<f32>,
-// implicitVal = (0.0, 0.0) : complex<f32>
+ map = (d0, d1) -> (d0 : dense, d1 : compressed),
+ explicitVal = #complex.number<:f32 1.0, 0.0>,
+ implicitVal = #complex.number<:f32 0.0, 0.0>
}>
#CSR_ones_fp = #sparse_tensor.encoding<{
@@ -20,9 +20,17 @@
}>
// CHECK-LABEL: func.func @matmul_complex
-//
-// TODO: make this work
-//
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: %[[X:.*]] = memref.load
+// CHECK: scf.for
+// CHECK: %[[I:.*]] = memref.load
+// CHECK: %[[Y:.*]] = memref.load
+// CHECK: %[[M:.*]] = complex.add %[[Y]], %[[X]] : complex<f32>
+// CHECK: memref.store %[[M]]
+// CHECK: }
+// CHECK: }
+// CHECK: }
func.func @matmul_complex(%a: tensor<10x20xcomplex<f32>>,
%b: tensor<20x30xcomplex<f32>, #CSR_ones_complex>,
%c: tensor<10x30xcomplex<f32>>) -> tensor<10x30xcomplex<f32>> {
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index acd2d3a..13c246a 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -3066,6 +3066,7 @@ cc_library(
":ArithDialect",
":BufferizationInterfaces",
":BytecodeOpInterface",
+ ":ComplexDialect",
":DialectUtils",
":IR",
":InferTypeOpInterface",